refactor the error handling and optimize the heap allocation

Suyono 2025-05-04 13:32:31 +10:00
parent 4abc00f07c
commit 0c1a9fb7e7
6 changed files with 242 additions and 66 deletions


@@ -40,6 +40,8 @@ type ConnectionConfig interface {
type TCPConnectionConfig interface {
    ConnectionConfig
    KeepAlive() bool
+   ReadRetry() int
+   WriteRetry() int
}
type UDPConnectionConfig interface {


@@ -1,7 +1,6 @@
package netbounce

import (
-   "bytes"
    "context"
    "errors"
    "fmt"
@@ -26,13 +25,13 @@ type backendUDP struct {
    relMtx  *sync.Mutex
    client  net.PacketConn
    cfg     abstract.UDPConnectionConfig
-   bufPool sync.Pool
+   bufPool *BufPool
    msgChan chan udpMessage
}

type udpMessage struct {
    addr string
-   buf  *bytes.Buffer
+   buf  []byte
}

func (r udpRel) Network() string {
@@ -43,18 +42,14 @@ func (r udpRel) String() string {
    return r.clientAddr
}

-func initUDP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.UDPConnectionConfig, client net.PacketConn) *backendUDP {
+func initUDP(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.UDPConnectionConfig, client net.PacketConn) *backendUDP {
    backend := &backendUDP{
        relations: make(map[string]udpRel),
        relMtx:    new(sync.Mutex),
        cfg:       cfg,
        client:    client,
        msgChan:   make(chan udpMessage),
-       bufPool: sync.Pool{
-           New: func() any {
-               return new(bytes.Buffer)
-           },
-       },
+       bufPool:   bp,
    }

    go func(wg *sync.WaitGroup, ctx context.Context, cfg abstract.UDPConnectionConfig, backend *backendUDP) {
@@ -104,7 +99,6 @@ func (b *backendUDP) createRelSend(clientAddr string, buf []byte) (udpRel, error
    var (
        udpAddr *net.UDPAddr
        udpConn *net.UDPConn
-       n       int
        err     error
    )
@@ -120,7 +114,8 @@ func (b *backendUDP) createRelSend(clientAddr string, buf []byte) (udpRel, error
        return rel, fmt.Errorf("create udp relation and send message: dial udp: %w", err)
    }
-   if n, err = udpConn.WriteTo(buf, b.cfg.BackendAddr()); err != nil && n == 0 {
+   if _, err = udpConn.WriteTo(buf, b.cfg.BackendAddr()); err != nil {
+       //TODO: I think I need to fix this. This error handling is odd.
        _ = udpConn.Close()
        return rel, fmt.Errorf("create udp relation and send message: write udp: %w", err)
    }
@@ -156,14 +151,12 @@ func (b *backendUDP) handle(wg *sync.WaitGroup, ctx context.Context, msg udpMess
    )

    if rel, ok = b.findRel(msg.addr); !ok {
-       if rel, err = b.createRelSend(msg.addr, msg.buf.Bytes()); err != nil {
+       if rel, err = b.createRelSend(msg.addr, msg.buf); err != nil {
-           msg.buf.Reset()
            b.bufPool.Put(msg.buf)
            log.Error().Err(err).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("establish relation with udp backend")
            return
        }
-       msg.buf.Reset()
        b.bufPool.Put(msg.buf)

        rel.ctx, rel.ctxCancel = context.WithCancel(ctx)
@@ -174,14 +167,12 @@ func (b *backendUDP) handle(wg *sync.WaitGroup, ctx context.Context, msg udpMess
        return
    }

-   if err = b.relSend(rel, msg.buf.Bytes()); err != nil {
+   if err = b.relSend(rel, msg.buf); err != nil {
-       msg.buf.Reset()
        b.bufPool.Put(msg.buf)
        log.Error().Err(err).Msg("handle: send for existing relation")
        return
    }
-   msg.buf.Reset()
    b.bufPool.Put(msg.buf)

    rel.expiry = time.Now().Add(b.cfg.Timeout())
@@ -191,17 +182,12 @@ func (b *backendUDP) handle(wg *sync.WaitGroup, ctx context.Context, msg udpMess
func (b *backendUDP) Send(ctx context.Context, addr string, p []byte) error {
    var (
        n   int
-       err error
    )

-   buf := b.bufPool.Get().(*bytes.Buffer)
-   if n, err = buf.Write(p); err != nil {
-       buf.Reset()
-       b.bufPool.Put(buf)
-       return fmt.Errorf("send udp message to handler: %w", err)
-   }
+   buf := b.bufPool.Get()
+   n = copy(buf, p)
    if len(p) != n {
-       buf.Reset()
        b.bufPool.Put(buf)
        return fmt.Errorf("send udp message to handler: failed to write complete message")
    }
@@ -209,7 +195,7 @@ func (b *backendUDP) Send(ctx context.Context, addr string, p []byte) error {
    select {
    case b.msgChan <- udpMessage{
        addr: addr,
-       buf:  buf,
+       buf:  buf[:n],
    }:
    case <-ctx.Done():
    }
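The change above swaps the pooled *bytes.Buffer for a plain pooled []byte: Send copies the datagram into a slice taken from the pool, hands only the written prefix buf[:n] to the handler over msgChan, and the handler returns the slice with Put once it is done with it. Below is a minimal, self-contained sketch of that handoff. It assumes the BufPool type added in bounce.go further down (repeated here so the sketch compiles on its own); none of the identifiers in the sketch are the project's code.

package main

import (
    "fmt"
    "sync"
)

// BufPool mirrors the commit's thin wrapper around sync.Pool for fixed-size buffers.
type BufPool struct{ pool sync.Pool }

func (b *BufPool) Get() []byte  { return b.pool.Get().([]byte) }
func (b *BufPool) Put(p []byte) { b.pool.Put(p[:cap(p)]) } // restore full length before pooling

func main() {
    bp := &BufPool{pool: sync.Pool{New: func() any { return make([]byte, 4096) }}}
    msgChan := make(chan []byte, 1)

    // Producer side (cf. backendUDP.Send): copy the packet into a pooled buffer
    // and pass along only the prefix that was actually written.
    payload := []byte("datagram payload")
    buf := bp.Get()
    n := copy(buf, payload)
    msgChan <- buf[:n]

    // Consumer side (cf. backendUDP.handle): use the bytes, then give the buffer back.
    msg := <-msgChan
    fmt.Printf("relay %d bytes: %q\n", len(msg), msg)
    bp.Put(msg)
}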

bounce.go

@@ -18,6 +18,7 @@ package netbounce
import (
    "context"
+   "errors"
    "fmt"
    "net"
    "sync"
@@ -35,17 +36,80 @@ const (
    BUF_SIZE = 4096
)

+type RetryConfig interface {
+   ReadRetry() int
+   WriteRetry() int
+}
+
+type BufPool struct {
+   pool sync.Pool
+}
+
+func (b *BufPool) Get() []byte {
+   return b.pool.Get().([]byte)
+}
+
+func (b *BufPool) Put(buf []byte) {
+   //lint:ignore SA6002 copying a 3-word slice header is intentional and cheap
+   b.pool.Put(buf[:cap(buf)])
+}
+
+type retryCounter struct {
+   max     int
+   counter int
+}
+
+func ReadRetry(cfg RetryConfig) *retryCounter {
+   return &retryCounter{
+       max:     cfg.ReadRetry(),
+       counter: 0,
+   }
+}
+
+func WriteRetry(cfg RetryConfig) *retryCounter {
+   return &retryCounter{
+       max:     cfg.WriteRetry(),
+       counter: 0,
+   }
+}
+
+func (r *retryCounter) IsContinue() bool {
+   if r.counter == r.max {
+       return false
+   }
+   r.counter++
+   return true
+}
+
+func (r *retryCounter) Reset() {
+   r.counter = 0
+}
+
+func (r *retryCounter) MaxCounterExceeded() bool {
+   return r.counter == r.max
+}
+
func Bounce(ctx context.Context) error {
    var (
        names []string
        err   error
        cc    abstract.ConnectionConfig
+       bp    *BufPool
    )

    wg := &sync.WaitGroup{}
    if names, err = config.ListConnection(); err != nil {
        return fmt.Errorf("bounce: connection list: %w", err)
    }

+   bp = &BufPool{
+       pool: sync.Pool{
+           New: func() any {
+               return make([]byte, BUF_SIZE)
+           },
+       },
+   }
+
    for _, name := range names {
        if cc, err = config.GetConnection(name); err != nil {
            return fmt.Errorf("bounce: connection get %s: %w", name, err)
@@ -53,16 +117,16 @@ func Bounce(ctx context.Context) error {
        wg.Add(1)
        if cc.Type() == abstract.TCP {
-           go RunTCP(wg, ctx, cc.AsTCP())
+           go RunTCP(wg, ctx, bp, cc.AsTCP())
        } else {
-           go RunUDP(wg, ctx, cc.AsUDP())
+           go RunUDP(wg, ctx, bp, cc.AsUDP())
        }
    }

    return nil
}

-func RunTCP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig) {
+func RunTCP(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig) {
    defer wg.Done()

    var (
@@ -86,6 +150,10 @@ func RunTCP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionC
acceptIncoming:
    for {
        if c, err = listener.Accept(); err != nil {
+           if errors.Is(err, net.ErrClosed) {
+               log.Error().Err(err).Msg("stop accepting new connection")
+               break acceptIncoming
+           }
            log.Error().Err(err).Msg("accept TCP connection")
            continue
        }
@@ -103,14 +171,14 @@ connStarter:
        select {
        case c := <-accepted:
            wg.Add(1)
-           go makeTCPConnection(wg, ctx, cfg, c)
+           go makeTCPConnection(wg, ctx, bp, cfg, c)
        case <-ctx.Done():
            break connStarter
        }
    }
}

-func makeTCPConnection(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig, conn net.Conn) {
+func makeTCPConnection(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig, conn net.Conn) {
    defer wg.Done()

    var (
@@ -126,17 +194,17 @@ func makeTCPConnection(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCP
    ctxConn, cancelConn := context.WithCancel(ctx)

    wg.Add(1)
-   go tcpClient2Backend(wg, ctxConn, cfg, cancelConn, conn, backend)
+   go tcpClient2Backend(wg, ctxConn, bp, cfg, cancelConn, conn, backend)
    wg.Add(1)
-   go tcpBackend2Client(wg, ctxConn, cfg, cancelConn, backend, conn)
+   go tcpBackend2Client(wg, ctxConn, bp, cfg, cancelConn, backend, conn)

    <-ctxConn.Done()
    _ = backend.Close()
    _ = conn.Close()
}

-func tcpBackend2Client(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, backend, client net.Conn) {
+func tcpBackend2Client(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, backend, client net.Conn) {
    defer wg.Done()

    var (
@@ -144,17 +212,27 @@ func tcpBackend2Client(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCP
        n, wn  int
        cancel context.CancelFunc
        err    error
+       readRetryCounter  *retryCounter
+       writeRetryCounter *retryCounter
    )

    cancel = cf
-   buf = make([]byte, BUF_SIZE)
+   buf = bp.Get()
+   defer bp.Put(buf)
+   readRetryCounter = ReadRetry(cfg)
+   writeRetryCounter = WriteRetry(cfg)

backendRead:
-   for {
+   for readRetryCounter.IsContinue() {
        if n, err = backend.Read(buf); err != nil {
            log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp read error")
+           if errors.Is(err, net.ErrClosed) && n == 0 {
                cancel()
                break backendRead
            }
+           continue backendRead
+       }
+       readRetryCounter.Reset()

        select {
        case <-ctx.Done():
@@ -162,8 +240,22 @@ backendRead:
        default:
        }

+   backendWrite:
+       for writeRetryCounter.IsContinue() {
            if wn, err = client.Write(buf[:n]); err != nil {
                log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp read error")
+               if errors.Is(err, net.ErrClosed) {
+                   cancel()
+                   break backendRead
+               }
+               continue backendWrite
+           }
+           writeRetryCounter.Reset()
+           break backendWrite
+       }
+       if writeRetryCounter.MaxCounterExceeded() {
+           log.Error().Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp write retry exceeded")
            cancel()
            break backendRead
        }
@@ -180,9 +272,13 @@ backendRead:
        default:
        }
    }
+   if readRetryCounter.MaxCounterExceeded() {
+       log.Error().Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp read retry exceeded")
+       cancel()
+   }
}

-func tcpClient2Backend(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, client, backend net.Conn) {
+func tcpClient2Backend(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, client, backend net.Conn) {
    defer wg.Done()

    var (
@@ -190,17 +286,27 @@ func tcpClient2Backend(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCP
        n, wn  int
        cancel context.CancelFunc
        err    error
+       readRetryCounter  *retryCounter
+       writeRetryCounter *retryCounter
    )

    cancel = cf
-   buf = make([]byte, BUF_SIZE)
+   buf = bp.Get()
+   defer bp.Put(buf)
+   readRetryCounter = ReadRetry(cfg)
+   writeRetryCounter = WriteRetry(cfg)

clientRead:
-   for {
+   for readRetryCounter.IsContinue() {
        if n, err = client.Read(buf); err != nil {
            log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp read error")
+           if errors.Is(err, net.ErrClosed) && n == 0 {
                cancel()
                break clientRead
            }
+           continue clientRead
+       }
+       readRetryCounter.Reset()

        select {
        case <-ctx.Done():
@@ -208,8 +314,22 @@ clientRead:
        default:
        }

+   clientWrite:
+       for writeRetryCounter.IsContinue() {
            if wn, err = backend.Write(buf[:n]); err != nil {
                log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp write error")
+               if errors.Is(err, net.ErrClosed) {
+                   cancel()
+                   break clientRead
+               }
+               continue clientWrite
+           }
+           writeRetryCounter.Reset()
+           break clientWrite
+       }
+       if writeRetryCounter.MaxCounterExceeded() {
+           log.Error().Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp write retry exceeded")
            cancel()
            break clientRead
        }
@@ -226,9 +346,13 @@ clientRead:
        default:
        }
    }
+   if readRetryCounter.MaxCounterExceeded() {
+       log.Error().Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp read retry exceeded")
+       cancel()
+   }
}

-func RunUDP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.UDPConnectionConfig) {
+func RunUDP(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.UDPConnectionConfig) {
    defer wg.Done()

    var (
@@ -245,8 +369,9 @@ func RunUDP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.UDPConnectionC
    }

    wg.Add(1)
-   backend = initUDP(wg, ctx, cfg, client)
+   backend = initUDP(wg, ctx, bp, cfg, client)
-   buf = make([]byte, BUF_SIZE)
+   buf = bp.Get()
+   defer bp.Put(buf)

udpReadLoop:
    for {
@@ -262,7 +387,7 @@ udpReadLoop:
        if err = backend.Send(ctx, addr.String(), buf[:n]); err != nil {
            log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("send udp message")
-           //continue udpReadLoop
+           //TODO: continue udpReadLoop
        }

        select {


@@ -0,0 +1,40 @@
+package slicewriter
+
+import (
+   "sync"
+   "testing"
+)
+
+func TestDummy(t *testing.T) {
+   var (
+       p    sync.Pool
+       a, b []byte
+   )
+
+   p = sync.Pool{
+       New: func() any {
+           return make([]byte, 4096)
+       },
+   }
+
+   a = p.Get().([]byte)
+   t.Logf("retrieve a from pool: len %d, cap %d", len(a), cap(a))
+   b = p.Get().([]byte)
+   t.Logf("retrieve b from pool: len %d, cap %d", len(b), cap(b))
+
+   a = a[:1024]
+   b = b[:1024]
+   t.Logf("resize a : len %d, cap %d", len(a), cap(a))
+   t.Logf("resize b : len %d, cap %d", len(b), cap(b))
+
+   p.Put(a[:cap(a)])
+   p.Put(b)
+
+   t.Log("after putting back")
+   a = p.Get().([]byte)
+   t.Logf("retrieve a from pool: len %d, cap %d", len(a), cap(a))
+   b = p.Get().([]byte)
+   t.Logf("retrieve b from pool: len %d, cap %d", len(b), cap(b))
+}
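TestDummy above is a scratch test that motivates why BufPool.Put re-slices to capacity: sync.Pool stores the slice header exactly as it was handed in, so a buffer that was shrunk with buf[:n] comes back shrunk unless its length is restored before Put. A small stand-alone distillation of the same observation follows; it is a hypothetical example, not part of the commit, and since sync.Pool makes no hard guarantee about which object Get returns, the printed lengths are the typical outcome rather than a contract.

package main

import (
    "fmt"
    "sync"
)

func main() {
    p := sync.Pool{New: func() any { return make([]byte, 4096) }}

    b := p.Get().([]byte) // len 4096, cap 4096
    b = b[:1024]          // caller shrinks its view of the buffer

    p.Put(b) // the slice header is stored as-is: len 1024
    short := p.Get().([]byte)
    fmt.Println(len(short), cap(short)) // typically "1024 4096"

    p.Put(short[:cap(short)]) // restore full length first, as BufPool.Put does
    full := p.Get().([]byte)
    fmt.Println(len(full), cap(full)) // typically "4096 4096"
}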


@@ -18,10 +18,11 @@ package config
import (
    "fmt"
-   "gitea.suyono.dev/suyono/netbounce/abstract"
    "maps"
    "slices"

+   "gitea.suyono.dev/suyono/netbounce/abstract"
    "github.com/spf13/viper"
)
@@ -83,6 +84,10 @@ func GetConnection(name string) (abstract.ConnectionConfig, error) {
        tcp.Name = name
        tcp.Type = abstract.TCP

+       module := tcpModule{config: tcp}
+       if err = module.Validate(); err != nil {
+           return nil, fmt.Errorf("tcp config: %w", err)
+       }
        return tcpModule{config: tcp}, nil
    default:
        return nil, fmt.Errorf("connection %s: invalid connection type: %s", name, configType)


@@ -28,6 +28,8 @@ type tcpConfig struct {
    Listen     string `mapstructure:"listen"`
    Backend    string `mapstructure:"backend"`
    KeepAlive  bool   `mapstructure:"keepalive,false"`
+   ReadRetry  int    `mapstructure:"read-retry,5"`
+   WriteRetry int    `mapstructure:"write-retry,5"`
}

type tcpModule struct {
@@ -62,3 +64,19 @@ func (t tcpModule) AsUDP() abstract.UDPConnectionConfig {
    panic(fmt.Errorf("not UDP"))
    // return nil
}
+
+func (t tcpModule) ReadRetry() int {
+   return t.config.ReadRetry
+}
+
+func (t tcpModule) WriteRetry() int {
+   return t.config.WriteRetry
+}
+
+func (t tcpModule) Validate() error {
+   if t.ReadRetry() <= 0 || t.WriteRetry() <= 0 {
+       return fmt.Errorf("invalid retry config value")
+   }
+
+   return nil
+}
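With ReadRetry and WriteRetry added here, the TCP module keeps satisfying the extended abstract.TCPConnectionConfig from the first hunk of this commit, and the same two accessors are what the ReadRetry/WriteRetry constructors in bounce.go consume through their RetryConfig parameter. A hedged compile-time check of the first point, assuming it lives inside package config next to tcpModule (not part of the commit):

// Hypothetical assertion: fails to compile if tcpModule stops implementing
// the extended TCP connection interface (including the new retry accessors).
var _ abstract.TCPConnectionConfig = tcpModule{}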