Files
netbounce/bounce.go

402 lines
9.3 KiB
Go

package netbounce
/*
Copyright 2025 Suyono <suyono3484@gmail.com>
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"sync"

	"gitea.suyono.dev/suyono/netbounce/abstract"
	"gitea.suyono.dev/suyono/netbounce/config"
	"github.com/rs/zerolog/log"
)
const (
CONNECTION = "connection"
DIRECTION = "direction"
CLIENT_TO_BACKEND = "client-to-backend"
BACKEND_TO_CLIENT = "backend-to-client"
BUF_SIZE = 4096
)
type RetryConfig interface {
ReadRetry() int
WriteRetry() int
}
type BufPool struct {
pool sync.Pool
}
func (b *BufPool) Get() []byte {
return b.pool.Get().([]byte)
}
func (b *BufPool) Put(buf []byte) {
//lint:ignore SA6002 copying a 3-word slice header is intentional and cheap
b.pool.Put(buf[:cap(buf)])
}
// retryCounter is a budget of consecutive attempts.
//
// IsContinue consumes one attempt, Reset refunds them all after a success,
// and MaxCounterExceeded reports whether the budget is fully spent. A budget
// of 0 allows no attempts. Exhaustion is detected by exact equality, so a
// negative budget effectively never runs out.
type retryCounter struct {
	max     int
	counter int
}

// ReadRetry returns a fresh counter whose budget is cfg.ReadRetry().
func ReadRetry(cfg RetryConfig) *retryCounter {
	return &retryCounter{max: cfg.ReadRetry()}
}

// WriteRetry returns a fresh counter whose budget is cfg.WriteRetry().
func WriteRetry(cfg RetryConfig) *retryCounter {
	return &retryCounter{max: cfg.WriteRetry()}
}

// IsContinue reports whether another attempt is allowed and, if so, records
// that attempt.
func (r *retryCounter) IsContinue() bool {
	if r.MaxCounterExceeded() {
		return false
	}
	r.counter++
	return true
}

// Reset refunds every consumed attempt.
func (r *retryCounter) Reset() { r.counter = 0 }

// MaxCounterExceeded reports whether the consumed attempts equal the budget.
func (r *retryCounter) MaxCounterExceeded() bool { return r.counter == r.max }
func Bounce(ctx context.Context) error {
var (
names []string
err error
cc abstract.ConnectionConfig
bp *BufPool
)
wg := &sync.WaitGroup{}
if names, err = config.ListConnection(); err != nil {
return fmt.Errorf("bounce: connection list: %w", err)
}
bp = &BufPool{
pool: sync.Pool{
New: func() any {
return make([]byte, BUF_SIZE)
},
},
}
for _, name := range names {
if cc, err = config.GetConnection(name); err != nil {
return fmt.Errorf("bounce: connection get %s: %w", name, err)
}
wg.Add(1)
if cc.Type() == abstract.TCP {
go RunTCP(wg, ctx, bp, cc.AsTCP())
} else {
go RunUDP(wg, ctx, bp, cc.AsUDP())
}
}
return nil
}
// RunTCP listens on cfg.Listen() and starts makeTCPConnection for every
// accepted client. It returns once ctx is done and panics (via log.Panic) if
// the listening socket cannot be opened.
//
// wg.Done is called when RunTCP returns. The accept goroutine and every
// per-connection goroutine are registered with wg as well.
func RunTCP(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig) {
	defer wg.Done()

	l, err := net.Listen("tcp", cfg.Listen())
	if err != nil {
		log.Panic().Caller().Err(err).Msg("listen tcp")
	}
	// RunTCP only returns after ctx is done. Closing the listener then makes
	// the blocked Accept fail with net.ErrClosed, so the accept goroutine
	// exits and the port is released instead of both leaking.
	defer func() { _ = l.Close() }()

	accepted := make(chan net.Conn)
	wg.Add(1)
	go func(ctx context.Context, cfg abstract.TCPConnectionConfig, listener net.Listener, acceptChan chan<- net.Conn) {
		defer wg.Done()
		defer log.Info().Caller().Msgf("go routine for accepting connection quited: %s", cfg.Name())
		for {
			c, err := listener.Accept()
			if err != nil {
				if errors.Is(err, net.ErrClosed) {
					log.Error().Caller().Err(err).Msg("stop accepting new connection")
					return
				}
				log.Error().Caller().Err(err).Msg("accept TCP connection")
				continue
			}
			select {
			case acceptChan <- c:
			case <-ctx.Done():
				// Nobody will take ownership of c any more; close it instead
				// of leaking the socket.
				_ = c.Close()
				return
			}
		}
	}(ctx, cfg, l, accepted)

	for {
		select {
		case c := <-accepted:
			wg.Add(1)
			go makeTCPConnection(wg, ctx, bp, cfg, c)
		case <-ctx.Done():
			return
		}
	}
}
// makeTCPConnection bridges an accepted client conn to a new connection to
// cfg.Backend(), copying in both directions until either direction stops or
// ctx is done, and then closes both sockets. It owns conn: conn is closed on
// every path, including a failed backend dial. wg.Done is called on return.
func makeTCPConnection(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig, conn net.Conn) {
	defer wg.Done()

	// DialContext lets shutdown abort a dial that would otherwise block until
	// the OS connect timeout.
	var dialer net.Dialer
	backend, err := dialer.DialContext(ctx, "tcp", cfg.Backend())
	if err != nil {
		log.Error().Caller().Err(err).Str(CONNECTION, cfg.Name()).Msg("connection to backend failed")
		_ = conn.Close()
		return
	}

	ctxConn, cancelConn := context.WithCancel(ctx)
	defer cancelConn()

	wg.Add(2)
	go tcpClient2Backend(wg, ctxConn, bp, cfg, cancelConn, conn, backend)
	go tcpBackend2Client(wg, ctxConn, bp, cfg, cancelConn, backend, conn)

	// Either direction may cancel ctxConn (or the parent ctx may end).
	// Closing both sockets then unblocks whatever Read/Write the other
	// direction is still waiting on.
	<-ctxConn.Done()
	_ = backend.Close()
	_ = conn.Close()
}
// tcpBackend2Client forwards bytes read from backend to client for one proxied
// TCP connection and calls wg.Done on return. cf cancels the per-connection
// context shared with the opposite direction; see tcpForward for when it is
// invoked.
func tcpBackend2Client(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, backend, client net.Conn) {
	defer wg.Done()
	tcpForward(ctx, bp, cfg, cf, backend, client, BACKEND_TO_CLIENT)
}

// tcpClient2Backend forwards bytes read from client to backend for one proxied
// TCP connection and calls wg.Done on return. cf cancels the per-connection
// context shared with the opposite direction; see tcpForward for when it is
// invoked.
func tcpClient2Backend(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, client, backend net.Conn) {
	defer wg.Done()
	tcpForward(ctx, bp, cfg, cf, client, backend, CLIENT_TO_BACKEND)
}

// tcpForward pumps bytes from src to dst for one direction of a proxied TCP
// connection, tagging its log entries with direction.
//
// It calls cancel (tearing down the other direction too) and returns when src
// reaches EOF, either socket is closed, the read or write retry budget from
// cfg is used up, or a write misbehaves. It returns without calling cancel
// when ctx is done.
func tcpForward(ctx context.Context, bp *BufPool, cfg abstract.TCPConnectionConfig, cancel context.CancelFunc, src, dst net.Conn, direction string) {
	buf := bp.Get()
	defer bp.Put(buf)

	readRetry := ReadRetry(cfg)
	writeRetry := WriteRetry(cfg)

	for readRetry.IsContinue() {
		n, err := src.Read(buf)

		// io.Reader contract: bytes returned together with an error are valid
		// and must be forwarded before the error is considered.
		if n > 0 {
			if ctx.Err() != nil {
				return
			}
			if !tcpWriteAll(dst, buf[:n], writeRetry, cfg, direction) {
				cancel()
				return
			}
		}

		if err != nil {
			if errors.Is(err, io.EOF) {
				// Orderly close by the peer: nothing more will arrive, so
				// retrying would only burn through the read budget.
				log.Debug().Caller().Str(CONNECTION, cfg.Name()).Str(DIRECTION, direction).Msg("tcp stream ended")
				cancel()
				return
			}
			log.Error().Caller().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, direction).Msg("tcp read error")
			if errors.Is(err, net.ErrClosed) {
				cancel()
				return
			}
			continue
		}

		readRetry.Reset()
		if ctx.Err() != nil {
			return
		}
	}

	// The loop only falls through once the read budget is exhausted.
	log.Error().Caller().Str(CONNECTION, cfg.Name()).Str(DIRECTION, direction).Msg("tcp read retry exceeded")
	cancel()
}

// tcpWriteAll writes all of p to dst. Each Write attempt consumes one unit of
// retry and the budget is reset once p is fully written, so it bounds
// consecutive failed attempts. A failed Write may still have accepted a prefix
// of p; retries resume after that prefix so no byte is sent twice. It reports
// whether p was written completely, logging the reason before returning false.
func tcpWriteAll(dst net.Conn, p []byte, retry *retryCounter, cfg abstract.TCPConnectionConfig, direction string) bool {
	for len(p) > 0 {
		if !retry.IsContinue() {
			log.Error().Caller().Str(CONNECTION, cfg.Name()).Str(DIRECTION, direction).Msg("tcp write retry exceeded")
			return false
		}
		wn, err := dst.Write(p)
		if wn < 0 || wn > len(p) {
			log.Error().Caller().Err(errors.New("mismatch length between read and write")).Str(CONNECTION, cfg.Name()).Str(DIRECTION, direction).Msg("tcp write problem")
			return false
		}
		p = p[wn:]
		if err != nil {
			log.Error().Caller().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, direction).Msg("tcp write error")
			if errors.Is(err, net.ErrClosed) {
				return false
			}
			continue
		}
		if len(p) > 0 {
			// io.Writer must report an error on a short write; treat a writer
			// that does not as broken rather than looping on it.
			log.Error().Caller().Err(errors.New("mismatch length between read and write")).Str(CONNECTION, cfg.Name()).Str(DIRECTION, direction).Msg("tcp write problem")
			return false
		}
	}
	retry.Reset()
	return true
}
// RunUDP receives datagrams on cfg.Listen() and passes each one, keyed by the
// sender address string, to the relay returned by initUDP. It returns when the
// socket is closed (which happens once ctx is done) or when Send reports
// context.Canceled. It panics (via log.Panic) if the socket cannot be bound or
// initUDP fails. wg.Done is called on return.
func RunUDP(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.UDPConnectionConfig) {
	defer wg.Done()

	conn, err := net.ListenPacket("udp", cfg.Listen())
	if err != nil {
		log.Panic().Err(err).Msgf("failed to bind for UDP %s", cfg.Listen())
	}

	// NOTE(review): this Add is presumably balanced by a Done inside initUDP
	// or a goroutine it starts; initUDP is outside this view — confirm.
	wg.Add(1)
	relay, err := initUDP(wg, ctx, bp, cfg, conn)
	if err != nil {
		log.Panic().Caller().Err(err).Msg("failed to init UDP")
	}

	// NOTE(review): packet is reused for every datagram, which is only safe
	// if relay.Send does not retain its payload after returning — confirm.
	packet := bp.Get()
	defer bp.Put(packet)

	// Closing the socket once ctx is done unblocks ReadFrom below; that is
	// how the read loop exits.
	go func(pc net.PacketConn, done <-chan struct{}) {
		<-done
		_ = pc.Close()
	}(conn, ctx.Done())

	for {
		n, from, err := conn.ReadFrom(packet)
		if err != nil {
			log.Error().Caller().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("udp read error")
			if errors.Is(err, net.ErrClosed) {
				return
			}
			continue
		}

		err = relay.Send(ctx, from.String(), packet[:n])
		switch {
		case err == nil:
		case errors.Is(err, context.Canceled):
			return
		default:
			log.Error().Caller().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("send udp message")
		}
	}
}