refactor the error handling and optimize the heap allocation

This commit is contained in:
Suyono 2025-05-04 13:32:31 +10:00
parent 4abc00f07c
commit 0c1a9fb7e7
6 changed files with 242 additions and 66 deletions

View File

@ -40,6 +40,8 @@ type ConnectionConfig interface {
type TCPConnectionConfig interface {
ConnectionConfig
KeepAlive() bool
ReadRetry() int
WriteRetry() int
}
type UDPConnectionConfig interface {

View File

@ -1,7 +1,6 @@
package netbounce
import (
"bytes"
"context"
"errors"
"fmt"
@ -26,13 +25,13 @@ type backendUDP struct {
relMtx *sync.Mutex
client net.PacketConn
cfg abstract.UDPConnectionConfig
bufPool sync.Pool
bufPool *BufPool
msgChan chan udpMessage
}
type udpMessage struct {
addr string
buf *bytes.Buffer
buf []byte
}
func (r udpRel) Network() string {
@ -43,18 +42,14 @@ func (r udpRel) String() string {
return r.clientAddr
}
func initUDP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.UDPConnectionConfig, client net.PacketConn) *backendUDP {
func initUDP(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.UDPConnectionConfig, client net.PacketConn) *backendUDP {
backend := &backendUDP{
relations: make(map[string]udpRel),
relMtx: new(sync.Mutex),
cfg: cfg,
client: client,
msgChan: make(chan udpMessage),
bufPool: sync.Pool{
New: func() any {
return new(bytes.Buffer)
},
},
bufPool: bp,
}
go func(wg *sync.WaitGroup, ctx context.Context, cfg abstract.UDPConnectionConfig, backend *backendUDP) {
@ -104,7 +99,6 @@ func (b *backendUDP) createRelSend(clientAddr string, buf []byte) (udpRel, error
var (
udpAddr *net.UDPAddr
udpConn *net.UDPConn
n int
err error
)
@ -120,7 +114,8 @@ func (b *backendUDP) createRelSend(clientAddr string, buf []byte) (udpRel, error
return rel, fmt.Errorf("create udp relation and send message: dial udp: %w", err)
}
if n, err = udpConn.WriteTo(buf, b.cfg.BackendAddr()); err != nil && n == 0 {
if _, err = udpConn.WriteTo(buf, b.cfg.BackendAddr()); err != nil {
//TODO: I think I need to fix this. This error handling is odd.
_ = udpConn.Close()
return rel, fmt.Errorf("create udp relation and send message: write udp: %w", err)
}
@ -156,14 +151,12 @@ func (b *backendUDP) handle(wg *sync.WaitGroup, ctx context.Context, msg udpMess
)
if rel, ok = b.findRel(msg.addr); !ok {
if rel, err = b.createRelSend(msg.addr, msg.buf.Bytes()); err != nil {
msg.buf.Reset()
if rel, err = b.createRelSend(msg.addr, msg.buf); err != nil {
b.bufPool.Put(msg.buf)
log.Error().Err(err).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("establish relation with udp backend")
return
}
msg.buf.Reset()
b.bufPool.Put(msg.buf)
rel.ctx, rel.ctxCancel = context.WithCancel(ctx)
@ -174,14 +167,12 @@ func (b *backendUDP) handle(wg *sync.WaitGroup, ctx context.Context, msg udpMess
return
}
if err = b.relSend(rel, msg.buf.Bytes()); err != nil {
msg.buf.Reset()
if err = b.relSend(rel, msg.buf); err != nil {
b.bufPool.Put(msg.buf)
log.Error().Err(err).Msg("handle: send for existing relation")
return
}
msg.buf.Reset()
b.bufPool.Put(msg.buf)
rel.expiry = time.Now().Add(b.cfg.Timeout())
@ -190,18 +181,13 @@ func (b *backendUDP) handle(wg *sync.WaitGroup, ctx context.Context, msg udpMess
func (b *backendUDP) Send(ctx context.Context, addr string, p []byte) error {
var (
n int
err error
n int
)
buf := b.bufPool.Get().(*bytes.Buffer)
if n, err = buf.Write(p); err != nil {
buf.Reset()
b.bufPool.Put(buf)
return fmt.Errorf("send udp message to handler: %w", err)
}
buf := b.bufPool.Get()
n = copy(buf, p)
if len(p) != n {
buf.Reset()
b.bufPool.Put(buf)
return fmt.Errorf("send udp message to handler: failed to write complete message")
}
@ -209,7 +195,7 @@ func (b *backendUDP) Send(ctx context.Context, addr string, p []byte) error {
select {
case b.msgChan <- udpMessage{
addr: addr,
buf: buf,
buf: buf[:n],
}:
case <-ctx.Done():
}

191
bounce.go
View File

@ -18,6 +18,7 @@ package netbounce
import (
"context"
"errors"
"fmt"
"net"
"sync"
@ -35,17 +36,80 @@ const (
BUF_SIZE = 4096
)
type RetryConfig interface {
ReadRetry() int
WriteRetry() int
}
type BufPool struct {
pool sync.Pool
}
func (b *BufPool) Get() []byte {
return b.pool.Get().([]byte)
}
func (b *BufPool) Put(buf []byte) {
//lint:ignore SA6002 copying a 3-word slice header is intentional and cheap
b.pool.Put(buf[:cap(buf)])
}
type retryCounter struct {
max int
counter int
}
func ReadRetry(cfg RetryConfig) *retryCounter {
return &retryCounter{
max: cfg.ReadRetry(),
counter: 0,
}
}
func WriteRetry(cfg RetryConfig) *retryCounter {
return &retryCounter{
max: cfg.WriteRetry(),
counter: 0,
}
}
func (r *retryCounter) IsContinue() bool {
if r.counter == r.max {
return false
}
r.counter++
return true
}
func (r *retryCounter) Reset() {
r.counter = 0
}
func (r *retryCounter) MaxCounterExceeded() bool {
return r.counter == r.max
}
func Bounce(ctx context.Context) error {
var (
names []string
err error
cc abstract.ConnectionConfig
bp *BufPool
)
wg := &sync.WaitGroup{}
if names, err = config.ListConnection(); err != nil {
return fmt.Errorf("bounce: connection list: %w", err)
}
bp = &BufPool{
pool: sync.Pool{
New: func() any {
return make([]byte, BUF_SIZE)
},
},
}
for _, name := range names {
if cc, err = config.GetConnection(name); err != nil {
return fmt.Errorf("bounce: connection get %s: %w", name, err)
@ -53,16 +117,16 @@ func Bounce(ctx context.Context) error {
wg.Add(1)
if cc.Type() == abstract.TCP {
go RunTCP(wg, ctx, cc.AsTCP())
go RunTCP(wg, ctx, bp, cc.AsTCP())
} else {
go RunUDP(wg, ctx, cc.AsUDP())
go RunUDP(wg, ctx, bp, cc.AsUDP())
}
}
return nil
}
func RunTCP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig) {
func RunTCP(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig) {
defer wg.Done()
var (
@ -86,6 +150,10 @@ func RunTCP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionC
acceptIncoming:
for {
if c, err = listener.Accept(); err != nil {
if errors.Is(err, net.ErrClosed) {
log.Error().Err(err).Msg("stop accepting new connection")
break acceptIncoming
}
log.Error().Err(err).Msg("accept TCP connection")
continue
}
@ -103,14 +171,14 @@ connStarter:
select {
case c := <-accepted:
wg.Add(1)
go makeTCPConnection(wg, ctx, cfg, c)
go makeTCPConnection(wg, ctx, bp, cfg, c)
case <-ctx.Done():
break connStarter
}
}
}
func makeTCPConnection(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig, conn net.Conn) {
func makeTCPConnection(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig, conn net.Conn) {
defer wg.Done()
var (
@ -126,35 +194,45 @@ func makeTCPConnection(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCP
ctxConn, cancelConn := context.WithCancel(ctx)
wg.Add(1)
go tcpClient2Backend(wg, ctxConn, cfg, cancelConn, conn, backend)
go tcpClient2Backend(wg, ctxConn, bp, cfg, cancelConn, conn, backend)
wg.Add(1)
go tcpBackend2Client(wg, ctxConn, cfg, cancelConn, backend, conn)
go tcpBackend2Client(wg, ctxConn, bp, cfg, cancelConn, backend, conn)
<-ctxConn.Done()
_ = backend.Close()
_ = conn.Close()
}
func tcpBackend2Client(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, backend, client net.Conn) {
func tcpBackend2Client(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, backend, client net.Conn) {
defer wg.Done()
var (
buf []byte
n, wn int
cancel context.CancelFunc
err error
buf []byte
n, wn int
cancel context.CancelFunc
err error
readRetryCounter *retryCounter
writeRetryCounter *retryCounter
)
cancel = cf
buf = make([]byte, BUF_SIZE)
buf = bp.Get()
defer bp.Put(buf)
readRetryCounter = ReadRetry(cfg)
writeRetryCounter = WriteRetry(cfg)
backendRead:
for {
for readRetryCounter.IsContinue() {
if n, err = backend.Read(buf); err != nil {
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp read error")
cancel()
break backendRead
if errors.Is(err, net.ErrClosed) && n == 0 {
cancel()
break backendRead
}
continue backendRead
}
readRetryCounter.Reset()
select {
case <-ctx.Done():
@ -162,8 +240,22 @@ backendRead:
default:
}
if wn, err = client.Write(buf[:n]); err != nil {
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp read error")
backendWrite:
for writeRetryCounter.IsContinue() {
if wn, err = client.Write(buf[:n]); err != nil {
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp read error")
if errors.Is(err, net.ErrClosed) {
cancel()
break backendRead
}
continue backendWrite
}
writeRetryCounter.Reset()
break backendWrite
}
if writeRetryCounter.MaxCounterExceeded() {
log.Error().Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp write retry exceeded")
cancel()
break backendRead
}
@ -180,27 +272,41 @@ backendRead:
default:
}
}
if readRetryCounter.MaxCounterExceeded() {
log.Error().Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp read retry exceeded")
cancel()
}
}
func tcpClient2Backend(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, client, backend net.Conn) {
func tcpClient2Backend(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, client, backend net.Conn) {
defer wg.Done()
var (
buf []byte
n, wn int
cancel context.CancelFunc
err error
buf []byte
n, wn int
cancel context.CancelFunc
err error
readRetryCounter *retryCounter
writeRetryCounter *retryCounter
)
cancel = cf
buf = make([]byte, BUF_SIZE)
buf = bp.Get()
defer bp.Put(buf)
readRetryCounter = ReadRetry(cfg)
writeRetryCounter = WriteRetry(cfg)
clientRead:
for {
for readRetryCounter.IsContinue() {
if n, err = client.Read(buf); err != nil {
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp read error")
cancel()
break clientRead
if errors.Is(err, net.ErrClosed) && n == 0 {
cancel()
break clientRead
}
continue clientRead
}
readRetryCounter.Reset()
select {
case <-ctx.Done():
@ -208,8 +314,22 @@ clientRead:
default:
}
if wn, err = backend.Write(buf[:n]); err != nil {
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp write error")
clientWrite:
for writeRetryCounter.IsContinue() {
if wn, err = backend.Write(buf[:n]); err != nil {
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp write error")
if errors.Is(err, net.ErrClosed) {
cancel()
break clientRead
}
continue clientWrite
}
writeRetryCounter.Reset()
break clientWrite
}
if writeRetryCounter.MaxCounterExceeded() {
log.Error().Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp write retry exceeded")
cancel()
break clientRead
}
@ -226,9 +346,13 @@ clientRead:
default:
}
}
if readRetryCounter.MaxCounterExceeded() {
log.Error().Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp read retry exceeded")
cancel()
}
}
func RunUDP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.UDPConnectionConfig) {
func RunUDP(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.UDPConnectionConfig) {
defer wg.Done()
var (
@ -245,8 +369,9 @@ func RunUDP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.UDPConnectionC
}
wg.Add(1)
backend = initUDP(wg, ctx, cfg, client)
buf = make([]byte, BUF_SIZE)
backend = initUDP(wg, ctx, bp, cfg, client)
buf = bp.Get()
defer bp.Put(buf)
udpReadLoop:
for {
@ -262,7 +387,7 @@ udpReadLoop:
if err = backend.Send(ctx, addr.String(), buf[:n]); err != nil {
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("send udp message")
//continue udpReadLoop
//TODO: continue udpReadLoop
}
select {

View File

@ -0,0 +1,40 @@
package slicewriter
import (
"sync"
"testing"
)
func TestDummy(t *testing.T) {
var (
p sync.Pool
a, b []byte
)
p = sync.Pool{
New: func() any {
return make([]byte, 4096)
},
}
a = p.Get().([]byte)
t.Logf("retrieve a from pool: len %d, cap %d", len(a), cap(a))
b = p.Get().([]byte)
t.Logf("retrieve b from pool: len %d, cap %d", len(b), cap(b))
a = a[:1024]
b = b[:1024]
t.Logf("resize a : len %d, cap %d", len(a), cap(a))
t.Logf("resize b : len %d, cap %d", len(b), cap(b))
p.Put(a[:cap(a)])
p.Put(b)
t.Log("after putting back")
a = p.Get().([]byte)
t.Logf("retrieve a from pool: len %d, cap %d", len(a), cap(a))
b = p.Get().([]byte)
t.Logf("retrieve b from pool: len %d, cap %d", len(b), cap(b))
}

View File

@ -18,10 +18,11 @@ package config
import (
"fmt"
"gitea.suyono.dev/suyono/netbounce/abstract"
"maps"
"slices"
"gitea.suyono.dev/suyono/netbounce/abstract"
"github.com/spf13/viper"
)
@ -83,6 +84,10 @@ func GetConnection(name string) (abstract.ConnectionConfig, error) {
tcp.Name = name
tcp.Type = abstract.TCP
module := tcpModule{config: tcp}
if err = module.Validate(); err != nil {
return nil, fmt.Errorf("tcp config: %w", err)
}
return tcpModule{config: tcp}, nil
default:
return nil, fmt.Errorf("connection %s: invalid connection type: %s", name, configType)

View File

@ -23,11 +23,13 @@ import (
)
type tcpConfig struct {
Name string `mapstructure:"-,"`
Type abstract.ConnectionType `mapstructure:"-,"`
Listen string `mapstructure:"listen"`
Backend string `mapstructure:"backend"`
KeepAlive bool `mapstructure:"keepalive,false"`
Name string `mapstructure:"-,"`
Type abstract.ConnectionType `mapstructure:"-,"`
Listen string `mapstructure:"listen"`
Backend string `mapstructure:"backend"`
KeepAlive bool `mapstructure:"keepalive,false"`
ReadRetry int `mapstructure:"read-retry,5"`
WriteRetry int `mapstructure:"write-retry,5"`
}
type tcpModule struct {
@ -62,3 +64,19 @@ func (t tcpModule) AsUDP() abstract.UDPConnectionConfig {
panic(fmt.Errorf("not UDP"))
// return nil
}
func (t tcpModule) ReadRetry() int {
return t.config.ReadRetry
}
func (t tcpModule) WriteRetry() int {
return t.config.WriteRetry
}
func (t tcpModule) Validate() error {
if t.ReadRetry() <= 0 || t.WriteRetry() <= 0 {
return fmt.Errorf("invalid retry config value")
}
return nil
}