package netbounce

/*
Copyright 2025 Suyono

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"sync"

	"gitea.suyono.dev/suyono/netbounce/abstract"
	"gitea.suyono.dev/suyono/netbounce/config"
	"github.com/rs/zerolog/log"
)

// Log field keys/values used by the forwarders, and the size of the pooled
// copy buffers handed out by BufPool.
const (
	CONNECTION        = "connection"
	DIRECTION         = "direction"
	CLIENT_TO_BACKEND = "client-to-backend"
	BACKEND_TO_CLIENT = "backend-to-client"
	BUF_SIZE          = 4096
)

// RetryConfig supplies the retry budgets used when reading from and writing
// to a forwarded connection.
type RetryConfig interface {
	ReadRetry() int
	WriteRetry() int
}

// BufPool is a pool of reusable byte buffers shared by all forwarders.
type BufPool struct {
	pool sync.Pool
}

// Get returns a buffer from the pool. The pool's New function must produce
// []byte values; Bounce configures it to allocate BUF_SIZE bytes.
func (b *BufPool) Get() []byte {
	return b.pool.Get().([]byte)
}

// Put returns buf to the pool, restored to its full capacity so the next
// Get sees the whole buffer regardless of how it was resliced.
func (b *BufPool) Put(buf []byte) {
	//lint:ignore SA6002 copying a 3-word slice header is intentional and cheap
	b.pool.Put(buf[:cap(buf)])
}

// retryCounter counts consecutive attempts against a fixed budget.
type retryCounter struct {
	max     int
	counter int
}

// ReadRetry returns a counter initialised with cfg's read-retry budget.
func ReadRetry(cfg RetryConfig) *retryCounter {
	return &retryCounter{
		max:     cfg.ReadRetry(),
		counter: 0,
	}
}

// WriteRetry returns a counter initialised with cfg's write-retry budget.
func WriteRetry(cfg RetryConfig) *retryCounter {
	return &retryCounter{
		max:     cfg.WriteRetry(),
		counter: 0,
	}
}

// IsContinue reports whether another attempt is allowed and, if so, consumes
// one. Because the check is an equality test, a negative max is never reached
// and therefore allows unlimited attempts.
func (r *retryCounter) IsContinue() bool {
	if r.counter == r.max {
		return false
	}
	r.counter++
	return true
}

// Reset clears the consumed attempts, e.g. after a successful operation.
func (r *retryCounter) Reset() {
	r.counter = 0
}

// MaxCounterExceeded reports whether the whole budget has been consumed.
func (r *retryCounter) MaxCounterExceeded() bool {
	return r.counter == r.max
}

// Bounce starts one forwarder goroutine per configured connection (RunTCP for
// TCP entries, RunUDP otherwise) and returns without waiting for them; they
// stop when ctx is cancelled.
//
// Every connection config is resolved before any forwarder is started, so a
// lookup error is returned with nothing left running.
func Bounce(ctx context.Context) error {
	names, err := config.ListConnection()
	if err != nil {
		return fmt.Errorf("bounce: connection list: %w", err)
	}

	configs := make([]abstract.ConnectionConfig, 0, len(names))
	for _, name := range names {
		var cc abstract.ConnectionConfig
		if cc, err = config.GetConnection(name); err != nil {
			return fmt.Errorf("bounce: connection get %s: %w", name, err)
		}
		configs = append(configs, cc)
	}

	bp := &BufPool{
		pool: sync.Pool{
			New: func() any {
				return make([]byte, BUF_SIZE)
			},
		},
	}

	// NOTE(review): wg is never waited on here, so it does not delay Bounce's
	// return; callers must rely on ctx for shutdown. Confirm this is intended.
	wg := &sync.WaitGroup{}
	for _, cc := range configs {
		wg.Add(1)
		if cc.Type() == abstract.TCP {
			go RunTCP(wg, ctx, bp, cc.AsTCP())
		} else {
			go RunUDP(wg, ctx, bp, cc.AsUDP())
		}
	}
	return nil
}

// RunTCP listens on cfg.Listen() and starts a makeTCPConnection goroutine for
// every accepted client until ctx is cancelled. It panics (via log.Panic) if
// the listen address cannot be bound. It calls wg.Done on return.
func RunTCP(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig) {
	defer wg.Done()

	l, err := net.Listen("tcp", cfg.Listen())
	if err != nil {
		log.Panic().Caller().Err(err).Msg("listen tcp")
	}
	// Closing the listener is what unblocks Accept in the goroutine below;
	// without it the socket and that goroutine outlive ctx.
	defer func() { _ = l.Close() }()

	accepted := make(chan net.Conn)

	go func(ctx context.Context, cfg abstract.TCPConnectionConfig, listener net.Listener, acceptChan chan<- net.Conn) {
		defer log.Info().Caller().Msgf("go routine for accepting connection quited: %s", cfg.Name())

		for {
			c, err := listener.Accept()
			if err != nil {
				if errors.Is(err, net.ErrClosed) {
					log.Error().Caller().Err(err).Msg("stop accepting new connection")
					return
				}
				log.Error().Caller().Err(err).Msg("accept TCP connection")
				continue
			}

			select {
			case acceptChan <- c:
			case <-ctx.Done():
				// Nobody will take ownership of c any more; close it here.
				_ = c.Close()
				return
			}
		}
	}(ctx, cfg, l, accepted)

	for {
		select {
		case c := <-accepted:
			wg.Add(1)
			go makeTCPConnection(wg, ctx, bp, cfg, c)
		case <-ctx.Done():
			return
		}
	}
}

// makeTCPConnection dials cfg.Backend() for the accepted client conn and
// pumps bytes in both directions until either direction cancels the
// per-connection context (or ctx itself is cancelled). It then closes both
// sockets. It owns conn and closes it on every path. It calls wg.Done on
// return.
func makeTCPConnection(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig, conn net.Conn) {
	defer wg.Done()

	var dialer net.Dialer
	backend, err := dialer.DialContext(ctx, "tcp", cfg.Backend())
	if err != nil {
		log.Error().Caller().Err(err).Str(CONNECTION, cfg.Name()).Msg("connection to backend failed")
		// Without this the accepted client socket would leak on dial failure.
		_ = conn.Close()
		return
	}

	ctxConn, cancelConn := context.WithCancel(ctx)
	defer cancelConn()

	wg.Add(2)
	go tcpClient2Backend(wg, ctxConn, bp, cfg, cancelConn, conn, backend)
	go tcpBackend2Client(wg, ctxConn, bp, cfg, cancelConn, backend, conn)

	<-ctxConn.Done()
	// Closing both sockets unblocks whichever pipe is still inside Read/Write.
	_ = backend.Close()
	_ = conn.Close()
}

// tcpBackend2Client forwards bytes from backend to client for one proxied
// connection until tcpPipe decides to stop. It calls wg.Done on return.
func tcpBackend2Client(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, backend, client net.Conn) {
	defer wg.Done()
	tcpPipe(ctx, bp, cfg, cf, backend, client, BACKEND_TO_CLIENT)
}

// tcpClient2Backend forwards bytes from client to backend for one proxied
// connection until tcpPipe decides to stop. It calls wg.Done on return.
func tcpClient2Backend(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, client, backend net.Conn) {
	defer wg.Done()
	tcpPipe(ctx, bp, cfg, cf, client, backend, CLIENT_TO_BACKEND)
}

// tcpPipe copies bytes from src to dst until one of the following happens:
// src reports EOF, either socket is closed, ctx is cancelled, or a retry
// budget from cfg is exhausted. On every exit except ctx cancellation it
// calls cancel, which makes makeTCPConnection close both sockets and thereby
// stop the opposite direction too. direction is used only as a log field.
func tcpPipe(ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig, cancel context.CancelFunc, src, dst net.Conn, direction string) {
	buf := bp.Get()
	defer bp.Put(buf)

	readRetry := ReadRetry(cfg)
	writeRetry := WriteRetry(cfg)

	for readRetry.IsContinue() {
		n, err := src.Read(buf)
		if err == nil {
			readRetry.Reset()
		}
		if ctx.Err() != nil {
			// Connection is being torn down; errors from here on are expected.
			return
		}

		// io.Reader may return data together with an error; deliver it before
		// acting on the error so no bytes are silently dropped.
		if n > 0 && !tcpWriteAll(cfg, writeRetry, dst, buf[:n], direction) {
			cancel()
			return
		}
		if ctx.Err() != nil {
			return
		}
		if err == nil {
			continue
		}

		if errors.Is(err, io.EOF) {
			// Orderly shutdown by the peer. Retrying would only yield io.EOF
			// again, so stop immediately.
			// NOTE(review): this tears down both directions; a half-close via
			// (*net.TCPConn).CloseWrite would let in-flight data in the other
			// direction drain first.
			log.Debug().Str(CONNECTION, cfg.Name()).Str(DIRECTION, direction).Msg("tcp connection closed by peer")
			cancel()
			return
		}

		log.Error().Caller().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, direction).Msg("tcp read error")
		if errors.Is(err, net.ErrClosed) {
			cancel()
			return
		}
		// Any other read error consumes one read retry.
	}

	log.Error().Caller().Str(CONNECTION, cfg.Name()).Str(DIRECTION, direction).Msg("tcp read retry exceeded")
	cancel()
}

// tcpWriteAll writes all of data to dst. A failed Write is retried, resuming
// after any bytes the failed call already accepted, until retry is
// exhausted. retry is reset after each successful Write, so it counts
// consecutive failures. It reports whether all of data was written. Failures
// are logged here, so callers only need to tear the connection down.
func tcpWriteAll(cfg abstract.TCPConnectionConfig, retry *retryCounter, dst net.Conn, data []byte, direction string) bool {
	for len(data) > 0 {
		if !retry.IsContinue() {
			log.Error().Caller().Str(CONNECTION, cfg.Name()).Str(DIRECTION, direction).Msg("tcp write retry exceeded")
			return false
		}

		wn, err := dst.Write(data)
		if wn < 0 || wn > len(data) {
			log.Error().Caller().Err(errors.New("mismatch length between read and write")).Str(CONNECTION, cfg.Name()).Str(DIRECTION, direction).Msg("tcp write problem")
			return false
		}
		// Skip bytes already delivered so a retry never duplicates data.
		data = data[wn:]
		if err == nil && len(data) > 0 {
			// A short write without an error violates the io.Writer contract;
			// treat it as a failed attempt.
			err = io.ErrShortWrite
		}
		if err != nil {
			log.Error().Caller().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, direction).Msg("tcp write error")
			if errors.Is(err, net.ErrClosed) {
				return false
			}
			continue
		}
		retry.Reset()
	}
	return true
}

// RunUDP binds cfg.Listen() and forwards every received datagram to the
// backend handler created by initUDP, keyed by the sender's address string.
// It stops when ctx is cancelled, which closes the socket. It panics (via
// log.Panic) if the socket cannot be bound or initUDP fails. It calls wg.Done
// on return.
func RunUDP(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.UDPConnectionConfig) {
	defer wg.Done()

	client, err := net.ListenPacket("udp", cfg.Listen())
	if err != nil {
		log.Panic().Err(err).Msgf("failed to bind for UDP %s", cfg.Listen())
	}

	// NOTE(review): this wg slot is handed to initUDP; presumably something it
	// starts calls wg.Done (not visible here).
	wg.Add(1)
	backend, err := initUDP(wg, ctx, bp, cfg, client)
	if err != nil {
		log.Panic().Caller().Err(err).Msg("failed to init UDP")
	}

	buf := bp.Get()
	defer bp.Put(buf)

	// Exit mechanism: closing the socket on cancellation makes ReadFrom below
	// return net.ErrClosed.
	go func() {
		<-ctx.Done()
		_ = client.Close()
	}()

	for {
		n, addr, err := client.ReadFrom(buf)
		if err != nil {
			log.Error().Caller().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("udp read error")
			if errors.Is(err, net.ErrClosed) {
				return
			}
			continue
		}

		if err = backend.Send(ctx, addr.String(), buf[:n]); err != nil {
			if errors.Is(err, context.Canceled) {
				return
			}
			log.Error().Caller().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("send udp message")
		}
	}
}