refactor the error handling and optimize the heap allocation

This commit is contained in:
2025-05-04 13:32:31 +10:00
parent 4abc00f07c
commit 0c1a9fb7e7
6 changed files with 242 additions and 66 deletions

191
bounce.go
View File

@@ -18,6 +18,7 @@ package netbounce
import (
"context"
"errors"
"fmt"
"net"
"sync"
@@ -35,17 +36,80 @@ const (
BUF_SIZE = 4096
)
type RetryConfig interface {
ReadRetry() int
WriteRetry() int
}
type BufPool struct {
pool sync.Pool
}
func (b *BufPool) Get() []byte {
return b.pool.Get().([]byte)
}
func (b *BufPool) Put(buf []byte) {
//lint:ignore SA6002 copying a 3-word slice header is intentional and cheap
b.pool.Put(buf[:cap(buf)])
}
type retryCounter struct {
max int
counter int
}
func ReadRetry(cfg RetryConfig) *retryCounter {
return &retryCounter{
max: cfg.ReadRetry(),
counter: 0,
}
}
func WriteRetry(cfg RetryConfig) *retryCounter {
return &retryCounter{
max: cfg.WriteRetry(),
counter: 0,
}
}
func (r *retryCounter) IsContinue() bool {
if r.counter == r.max {
return false
}
r.counter++
return true
}
func (r *retryCounter) Reset() {
r.counter = 0
}
func (r *retryCounter) MaxCounterExceeded() bool {
return r.counter == r.max
}
func Bounce(ctx context.Context) error {
var (
names []string
err error
cc abstract.ConnectionConfig
bp *BufPool
)
wg := &sync.WaitGroup{}
if names, err = config.ListConnection(); err != nil {
return fmt.Errorf("bounce: connection list: %w", err)
}
bp = &BufPool{
pool: sync.Pool{
New: func() any {
return make([]byte, BUF_SIZE)
},
},
}
for _, name := range names {
if cc, err = config.GetConnection(name); err != nil {
return fmt.Errorf("bounce: connection get %s: %w", name, err)
@@ -53,16 +117,16 @@ func Bounce(ctx context.Context) error {
wg.Add(1)
if cc.Type() == abstract.TCP {
go RunTCP(wg, ctx, cc.AsTCP())
go RunTCP(wg, ctx, bp, cc.AsTCP())
} else {
go RunUDP(wg, ctx, cc.AsUDP())
go RunUDP(wg, ctx, bp, cc.AsUDP())
}
}
return nil
}
func RunTCP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig) {
func RunTCP(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig) {
defer wg.Done()
var (
@@ -86,6 +150,10 @@ func RunTCP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionC
acceptIncoming:
for {
if c, err = listener.Accept(); err != nil {
if errors.Is(err, net.ErrClosed) {
log.Error().Err(err).Msg("stop accepting new connection")
break acceptIncoming
}
log.Error().Err(err).Msg("accept TCP connection")
continue
}
@@ -103,14 +171,14 @@ connStarter:
select {
case c := <-accepted:
wg.Add(1)
go makeTCPConnection(wg, ctx, cfg, c)
go makeTCPConnection(wg, ctx, bp, cfg, c)
case <-ctx.Done():
break connStarter
}
}
}
func makeTCPConnection(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig, conn net.Conn) {
func makeTCPConnection(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig, conn net.Conn) {
defer wg.Done()
var (
@@ -126,35 +194,45 @@ func makeTCPConnection(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCP
ctxConn, cancelConn := context.WithCancel(ctx)
wg.Add(1)
go tcpClient2Backend(wg, ctxConn, cfg, cancelConn, conn, backend)
go tcpClient2Backend(wg, ctxConn, bp, cfg, cancelConn, conn, backend)
wg.Add(1)
go tcpBackend2Client(wg, ctxConn, cfg, cancelConn, backend, conn)
go tcpBackend2Client(wg, ctxConn, bp, cfg, cancelConn, backend, conn)
<-ctxConn.Done()
_ = backend.Close()
_ = conn.Close()
}
func tcpBackend2Client(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, backend, client net.Conn) {
func tcpBackend2Client(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, backend, client net.Conn) {
defer wg.Done()
var (
buf []byte
n, wn int
cancel context.CancelFunc
err error
buf []byte
n, wn int
cancel context.CancelFunc
err error
readRetryCounter *retryCounter
writeRetryCounter *retryCounter
)
cancel = cf
buf = make([]byte, BUF_SIZE)
buf = bp.Get()
defer bp.Put(buf)
readRetryCounter = ReadRetry(cfg)
writeRetryCounter = WriteRetry(cfg)
backendRead:
for {
for readRetryCounter.IsContinue() {
if n, err = backend.Read(buf); err != nil {
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp read error")
cancel()
break backendRead
if errors.Is(err, net.ErrClosed) && n == 0 {
cancel()
break backendRead
}
continue backendRead
}
readRetryCounter.Reset()
select {
case <-ctx.Done():
@@ -162,8 +240,22 @@ backendRead:
default:
}
if wn, err = client.Write(buf[:n]); err != nil {
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp read error")
backendWrite:
for writeRetryCounter.IsContinue() {
if wn, err = client.Write(buf[:n]); err != nil {
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp read error")
if errors.Is(err, net.ErrClosed) {
cancel()
break backendRead
}
continue backendWrite
}
writeRetryCounter.Reset()
break backendWrite
}
if writeRetryCounter.MaxCounterExceeded() {
log.Error().Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp write retry exceeded")
cancel()
break backendRead
}
@@ -180,27 +272,41 @@ backendRead:
default:
}
}
if readRetryCounter.MaxCounterExceeded() {
log.Error().Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp read retry exceeded")
cancel()
}
}
func tcpClient2Backend(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, client, backend net.Conn) {
func tcpClient2Backend(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, client, backend net.Conn) {
defer wg.Done()
var (
buf []byte
n, wn int
cancel context.CancelFunc
err error
buf []byte
n, wn int
cancel context.CancelFunc
err error
readRetryCounter *retryCounter
writeRetryCounter *retryCounter
)
cancel = cf
buf = make([]byte, BUF_SIZE)
buf = bp.Get()
defer bp.Put(buf)
readRetryCounter = ReadRetry(cfg)
writeRetryCounter = WriteRetry(cfg)
clientRead:
for {
for readRetryCounter.IsContinue() {
if n, err = client.Read(buf); err != nil {
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp read error")
cancel()
break clientRead
if errors.Is(err, net.ErrClosed) && n == 0 {
cancel()
break clientRead
}
continue clientRead
}
readRetryCounter.Reset()
select {
case <-ctx.Done():
@@ -208,8 +314,22 @@ clientRead:
default:
}
if wn, err = backend.Write(buf[:n]); err != nil {
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp write error")
clientWrite:
for writeRetryCounter.IsContinue() {
if wn, err = backend.Write(buf[:n]); err != nil {
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp write error")
if errors.Is(err, net.ErrClosed) {
cancel()
break clientRead
}
continue clientWrite
}
writeRetryCounter.Reset()
break clientWrite
}
if writeRetryCounter.MaxCounterExceeded() {
log.Error().Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp write retry exceeded")
cancel()
break clientRead
}
@@ -226,9 +346,13 @@ clientRead:
default:
}
}
if readRetryCounter.MaxCounterExceeded() {
log.Error().Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp read retry exceeded")
cancel()
}
}
func RunUDP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.UDPConnectionConfig) {
func RunUDP(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.UDPConnectionConfig) {
defer wg.Done()
var (
@@ -245,8 +369,9 @@ func RunUDP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.UDPConnectionC
}
wg.Add(1)
backend = initUDP(wg, ctx, cfg, client)
buf = make([]byte, BUF_SIZE)
backend = initUDP(wg, ctx, bp, cfg, client)
buf = bp.Get()
defer bp.Put(buf)
udpReadLoop:
for {
@@ -262,7 +387,7 @@ udpReadLoop:
if err = backend.Send(ctx, addr.String(), buf[:n]); err != nil {
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("send udp message")
//continue udpReadLoop
//TODO: continue udpReadLoop
}
select {