refactor the error handling and optimize the heap allocation
This commit is contained in:
191
bounce.go
191
bounce.go
@@ -18,6 +18,7 @@ package netbounce
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
@@ -35,17 +36,80 @@ const (
|
||||
BUF_SIZE = 4096
|
||||
)
|
||||
|
||||
type RetryConfig interface {
|
||||
ReadRetry() int
|
||||
WriteRetry() int
|
||||
}
|
||||
|
||||
type BufPool struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
func (b *BufPool) Get() []byte {
|
||||
return b.pool.Get().([]byte)
|
||||
}
|
||||
|
||||
func (b *BufPool) Put(buf []byte) {
|
||||
//lint:ignore SA6002 copying a 3-word slice header is intentional and cheap
|
||||
b.pool.Put(buf[:cap(buf)])
|
||||
}
|
||||
|
||||
type retryCounter struct {
|
||||
max int
|
||||
counter int
|
||||
}
|
||||
|
||||
func ReadRetry(cfg RetryConfig) *retryCounter {
|
||||
return &retryCounter{
|
||||
max: cfg.ReadRetry(),
|
||||
counter: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func WriteRetry(cfg RetryConfig) *retryCounter {
|
||||
return &retryCounter{
|
||||
max: cfg.WriteRetry(),
|
||||
counter: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *retryCounter) IsContinue() bool {
|
||||
if r.counter == r.max {
|
||||
return false
|
||||
}
|
||||
|
||||
r.counter++
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *retryCounter) Reset() {
|
||||
r.counter = 0
|
||||
}
|
||||
|
||||
func (r *retryCounter) MaxCounterExceeded() bool {
|
||||
return r.counter == r.max
|
||||
}
|
||||
|
||||
func Bounce(ctx context.Context) error {
|
||||
var (
|
||||
names []string
|
||||
err error
|
||||
cc abstract.ConnectionConfig
|
||||
bp *BufPool
|
||||
)
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
if names, err = config.ListConnection(); err != nil {
|
||||
return fmt.Errorf("bounce: connection list: %w", err)
|
||||
}
|
||||
|
||||
bp = &BufPool{
|
||||
pool: sync.Pool{
|
||||
New: func() any {
|
||||
return make([]byte, BUF_SIZE)
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, name := range names {
|
||||
if cc, err = config.GetConnection(name); err != nil {
|
||||
return fmt.Errorf("bounce: connection get %s: %w", name, err)
|
||||
@@ -53,16 +117,16 @@ func Bounce(ctx context.Context) error {
|
||||
|
||||
wg.Add(1)
|
||||
if cc.Type() == abstract.TCP {
|
||||
go RunTCP(wg, ctx, cc.AsTCP())
|
||||
go RunTCP(wg, ctx, bp, cc.AsTCP())
|
||||
} else {
|
||||
go RunUDP(wg, ctx, cc.AsUDP())
|
||||
go RunUDP(wg, ctx, bp, cc.AsUDP())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RunTCP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig) {
|
||||
func RunTCP(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig) {
|
||||
defer wg.Done()
|
||||
|
||||
var (
|
||||
@@ -86,6 +150,10 @@ func RunTCP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionC
|
||||
acceptIncoming:
|
||||
for {
|
||||
if c, err = listener.Accept(); err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
log.Error().Err(err).Msg("stop accepting new connection")
|
||||
break acceptIncoming
|
||||
}
|
||||
log.Error().Err(err).Msg("accept TCP connection")
|
||||
continue
|
||||
}
|
||||
@@ -103,14 +171,14 @@ connStarter:
|
||||
select {
|
||||
case c := <-accepted:
|
||||
wg.Add(1)
|
||||
go makeTCPConnection(wg, ctx, cfg, c)
|
||||
go makeTCPConnection(wg, ctx, bp, cfg, c)
|
||||
case <-ctx.Done():
|
||||
break connStarter
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func makeTCPConnection(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig, conn net.Conn) {
|
||||
func makeTCPConnection(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig, conn net.Conn) {
|
||||
defer wg.Done()
|
||||
|
||||
var (
|
||||
@@ -126,35 +194,45 @@ func makeTCPConnection(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCP
|
||||
ctxConn, cancelConn := context.WithCancel(ctx)
|
||||
|
||||
wg.Add(1)
|
||||
go tcpClient2Backend(wg, ctxConn, cfg, cancelConn, conn, backend)
|
||||
go tcpClient2Backend(wg, ctxConn, bp, cfg, cancelConn, conn, backend)
|
||||
|
||||
wg.Add(1)
|
||||
go tcpBackend2Client(wg, ctxConn, cfg, cancelConn, backend, conn)
|
||||
go tcpBackend2Client(wg, ctxConn, bp, cfg, cancelConn, backend, conn)
|
||||
|
||||
<-ctxConn.Done()
|
||||
_ = backend.Close()
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
func tcpBackend2Client(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, backend, client net.Conn) {
|
||||
func tcpBackend2Client(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, backend, client net.Conn) {
|
||||
defer wg.Done()
|
||||
|
||||
var (
|
||||
buf []byte
|
||||
n, wn int
|
||||
cancel context.CancelFunc
|
||||
err error
|
||||
buf []byte
|
||||
n, wn int
|
||||
cancel context.CancelFunc
|
||||
err error
|
||||
readRetryCounter *retryCounter
|
||||
writeRetryCounter *retryCounter
|
||||
)
|
||||
|
||||
cancel = cf
|
||||
buf = make([]byte, BUF_SIZE)
|
||||
buf = bp.Get()
|
||||
defer bp.Put(buf)
|
||||
|
||||
readRetryCounter = ReadRetry(cfg)
|
||||
writeRetryCounter = WriteRetry(cfg)
|
||||
backendRead:
|
||||
for {
|
||||
for readRetryCounter.IsContinue() {
|
||||
if n, err = backend.Read(buf); err != nil {
|
||||
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp read error")
|
||||
cancel()
|
||||
break backendRead
|
||||
if errors.Is(err, net.ErrClosed) && n == 0 {
|
||||
cancel()
|
||||
break backendRead
|
||||
}
|
||||
continue backendRead
|
||||
}
|
||||
readRetryCounter.Reset()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -162,8 +240,22 @@ backendRead:
|
||||
default:
|
||||
}
|
||||
|
||||
if wn, err = client.Write(buf[:n]); err != nil {
|
||||
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp read error")
|
||||
backendWrite:
|
||||
for writeRetryCounter.IsContinue() {
|
||||
if wn, err = client.Write(buf[:n]); err != nil {
|
||||
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp read error")
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
cancel()
|
||||
break backendRead
|
||||
}
|
||||
continue backendWrite
|
||||
}
|
||||
|
||||
writeRetryCounter.Reset()
|
||||
break backendWrite
|
||||
}
|
||||
if writeRetryCounter.MaxCounterExceeded() {
|
||||
log.Error().Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp write retry exceeded")
|
||||
cancel()
|
||||
break backendRead
|
||||
}
|
||||
@@ -180,27 +272,41 @@ backendRead:
|
||||
default:
|
||||
}
|
||||
}
|
||||
if readRetryCounter.MaxCounterExceeded() {
|
||||
log.Error().Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp read retry exceeded")
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
func tcpClient2Backend(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, client, backend net.Conn) {
|
||||
func tcpClient2Backend(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, client, backend net.Conn) {
|
||||
defer wg.Done()
|
||||
|
||||
var (
|
||||
buf []byte
|
||||
n, wn int
|
||||
cancel context.CancelFunc
|
||||
err error
|
||||
buf []byte
|
||||
n, wn int
|
||||
cancel context.CancelFunc
|
||||
err error
|
||||
readRetryCounter *retryCounter
|
||||
writeRetryCounter *retryCounter
|
||||
)
|
||||
|
||||
cancel = cf
|
||||
buf = make([]byte, BUF_SIZE)
|
||||
buf = bp.Get()
|
||||
defer bp.Put(buf)
|
||||
|
||||
readRetryCounter = ReadRetry(cfg)
|
||||
writeRetryCounter = WriteRetry(cfg)
|
||||
clientRead:
|
||||
for {
|
||||
for readRetryCounter.IsContinue() {
|
||||
if n, err = client.Read(buf); err != nil {
|
||||
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp read error")
|
||||
cancel()
|
||||
break clientRead
|
||||
if errors.Is(err, net.ErrClosed) && n == 0 {
|
||||
cancel()
|
||||
break clientRead
|
||||
}
|
||||
continue clientRead
|
||||
}
|
||||
readRetryCounter.Reset()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -208,8 +314,22 @@ clientRead:
|
||||
default:
|
||||
}
|
||||
|
||||
if wn, err = backend.Write(buf[:n]); err != nil {
|
||||
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp write error")
|
||||
clientWrite:
|
||||
for writeRetryCounter.IsContinue() {
|
||||
if wn, err = backend.Write(buf[:n]); err != nil {
|
||||
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp write error")
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
cancel()
|
||||
break clientRead
|
||||
}
|
||||
continue clientWrite
|
||||
}
|
||||
|
||||
writeRetryCounter.Reset()
|
||||
break clientWrite
|
||||
}
|
||||
if writeRetryCounter.MaxCounterExceeded() {
|
||||
log.Error().Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp write retry exceeded")
|
||||
cancel()
|
||||
break clientRead
|
||||
}
|
||||
@@ -226,9 +346,13 @@ clientRead:
|
||||
default:
|
||||
}
|
||||
}
|
||||
if readRetryCounter.MaxCounterExceeded() {
|
||||
log.Error().Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp read retry exceeded")
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
func RunUDP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.UDPConnectionConfig) {
|
||||
func RunUDP(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.UDPConnectionConfig) {
|
||||
defer wg.Done()
|
||||
|
||||
var (
|
||||
@@ -245,8 +369,9 @@ func RunUDP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.UDPConnectionC
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
backend = initUDP(wg, ctx, cfg, client)
|
||||
buf = make([]byte, BUF_SIZE)
|
||||
backend = initUDP(wg, ctx, bp, cfg, client)
|
||||
buf = bp.Get()
|
||||
defer bp.Put(buf)
|
||||
|
||||
udpReadLoop:
|
||||
for {
|
||||
@@ -262,7 +387,7 @@ udpReadLoop:
|
||||
|
||||
if err = backend.Send(ctx, addr.String(), buf[:n]); err != nil {
|
||||
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("send udp message")
|
||||
//continue udpReadLoop
|
||||
//TODO: continue udpReadLoop
|
||||
}
|
||||
|
||||
select {
|
||||
|
||||
Reference in New Issue
Block a user