WIP: backend udp

This commit is contained in:
Suyono 2025-04-20 12:20:36 +10:00
parent d0b3cbf967
commit 7f69553ead
4 changed files with 395 additions and 7 deletions

168
backend_udp.go Normal file
View File

@ -0,0 +1,168 @@
package netbounce
import (
"bytes"
"context"
"fmt"
"net"
"sync"
"time"
"gitea.suyono.dev/suyono/netbounce/abstract"
"github.com/rs/zerolog/log"
)
type udpRel struct {
backend net.PacketConn
clientAddr string
expiry time.Time
}
type backendUDP struct {
relations map[string]udpRel
relMtx *sync.Mutex
client net.PacketConn
cfg abstract.UDPConnectionConfig
bufPool sync.Pool
msgChan chan udpMessage
}
type udpMessage struct {
addr string
buf *bytes.Buffer
}
func initUDP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.UDPConnectionConfig, client net.PacketConn) *backendUDP {
backend := &backendUDP{
relations: make(map[string]udpRel),
relMtx: new(sync.Mutex),
cfg: cfg,
client: client,
msgChan: make(chan udpMessage),
bufPool: sync.Pool{
New: func() any {
return new(bytes.Buffer)
},
},
}
go func(wg *sync.WaitGroup, ctx context.Context, cfg abstract.UDPConnectionConfig, backend *backendUDP) {
defer wg.Done()
incoming := backend.msgChan
readIncoming:
for {
select {
case <-ctx.Done():
break readIncoming
case m := <-incoming:
backend.handle(m)
}
}
}(wg, ctx, cfg, backend)
return backend
}
func (b *backendUDP) findRel(clientAddr string) (udpRel, bool) {
b.relMtx.Lock()
defer b.relMtx.Unlock()
rel, ok := b.relations[clientAddr]
if ok && rel.expiry.Before(time.Now()) {
delete(b.relations, clientAddr)
return rel, false
}
return rel, ok
}
func (b *backendUDP) addRel(clientAddr string, rel udpRel) {
b.relMtx.Lock()
defer b.relMtx.Unlock()
b.relations[clientAddr] = rel
}
func (b *backendUDP) createRelSend(clientAddr string, buf []byte) (udpRel, error) {
var (
udpAddr *net.UDPAddr
udpConn *net.UDPConn
n int
err error
)
rel := udpRel{
clientAddr: clientAddr,
}
if udpAddr, err = net.ResolveUDPAddr("udp", b.cfg.Backend()); err != nil {
return rel, fmt.Errorf("create udp relation and send message: resolve udp addr: %w", err)
}
if udpConn, err = net.DialUDP("udp", nil, udpAddr); err != nil {
return rel, fmt.Errorf("create udp relation and send message: dial udp: %w", err)
}
if n, err = udpConn.Write(buf); err != nil && n == 0 {
_ = udpConn.Close()
return rel, fmt.Errorf("create udp relation and send message: write udp: %w", err)
}
rel.backend = udpConn
rel.expiry = time.Now().Add(b.cfg.Timeout())
return rel, nil
}
func (b *backendUDP) handle(msg udpMessage) {
var (
rel udpRel
ok bool
err error
)
if rel, ok = b.findRel(msg.addr); !ok {
if rel, err = b.createRelSend(msg.addr, msg.buf.Bytes()); err != nil {
log.Error().Err(err).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("establish relation with udp backend")
}
b.addRel(msg.addr, rel)
return
}
_ = rel
// if rel.expiry.Before(time.Now()) {
// //TODO: handle expiry
// }
}
func (b *backendUDP) Send(ctx context.Context, addr string, p []byte) error {
var (
n int
err error
)
buf := b.bufPool.Get().(*bytes.Buffer)
if n, err = buf.Write(p); err != nil {
buf.Reset()
b.bufPool.Put(buf)
return fmt.Errorf("send udp message to handler: %w", err)
}
if len(p) != n {
buf.Reset()
b.bufPool.Put(buf)
return fmt.Errorf("send udp message to handler: failed to write complete message")
}
select {
case b.msgChan <- udpMessage{
addr: addr,
buf: buf,
}:
case <-ctx.Done():
}
return nil
}

223
bounce.go
View File

@ -17,18 +17,32 @@ package netbounce
*/
import (
"context"
"fmt"
"net"
"sync"
"gitea.suyono.dev/suyono/netbounce/abstract"
"gitea.suyono.dev/suyono/netbounce/config"
"github.com/rs/zerolog/log"
)
func Bounce() error {
const (
CONNECTION = "connection"
DIRECTION = "direction"
CLIENT_TO_BACKEND = "client-to-backend"
BACKEND_TO_CLIENT = "backend-to-client"
BUF_SIZE = 4096
)
func Bounce(ctx context.Context) error {
var (
names []string
err error
cc abstract.ConnectionConfig
)
wg := &sync.WaitGroup{}
if names, err = config.ListConnection(); err != nil {
return fmt.Errorf("bounce: connection list: %w", err)
}
@ -37,20 +51,219 @@ func Bounce() error {
return fmt.Errorf("bounce: connection get %s: %w", name, err)
}
wg.Add(1)
if cc.Type() == abstract.TCP {
go RunTCP(cc.AsTCP())
go RunTCP(wg, ctx, cc.AsTCP())
} else {
go RunUDP(cc.AsUDP())
go RunUDP(wg, ctx, cc.AsUDP())
}
}
return nil
}
func RunTCP(cfg abstract.TCPConnectionConfig) {
func RunTCP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig) {
defer wg.Done()
var (
err error
l net.Listener
accepted chan net.Conn
)
if l, err = net.Listen("tcp", cfg.Listen()); err != nil {
panic(fmt.Errorf("failed to listen tcp: %w", err))
}
accepted = make(chan net.Conn)
go func(ctx context.Context, cfg abstract.TCPConnectionConfig, listener net.Listener, acceptChan chan<- net.Conn) {
defer log.Info().Msgf("go routine for accepting connection quited: %s", cfg.Name())
var (
err error
c net.Conn
)
acceptIncoming:
for {
if c, err = listener.Accept(); err != nil {
log.Error().Err(err).Msg("accept TCP connection")
continue
}
select {
case acceptChan <- c:
case <-ctx.Done():
break acceptIncoming
}
}
}(ctx, cfg, l, accepted)
connStarter:
for {
select {
case c := <-accepted:
wg.Add(1)
go makeTCPConnection(wg, ctx, cfg, c)
case <-ctx.Done():
break connStarter
}
}
}
func RunUDP(cfg abstract.UDPConnectionConfig) {
func makeTCPConnection(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig, conn net.Conn) {
defer wg.Done()
var (
backend net.Conn
err error
)
if backend, err = net.Dial("tcp", cfg.Backend()); err != nil {
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Msg("connection to backend failed")
return
}
ctxConn, cancelConn := context.WithCancel(ctx)
wg.Add(1)
go tcpClient2Backend(wg, ctxConn, cfg, cancelConn, conn, backend)
wg.Add(1)
go tcpBackend2Client(wg, ctxConn, cfg, cancelConn, backend, conn)
<-ctxConn.Done()
_ = backend.Close()
_ = conn.Close()
}
func tcpBackend2Client(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, backend, client net.Conn) {
defer wg.Done()
var (
buf []byte
n, wn int
cancel context.CancelFunc
err error
)
cancel = cf
buf = make([]byte, BUF_SIZE)
backendRead:
for {
if n, err = backend.Read(buf); err != nil {
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp read error")
cancel()
break backendRead
}
select {
case <-ctx.Done():
break backendRead
default:
}
if wn, err = client.Write(buf[:n]); err != nil {
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp read error")
cancel()
break backendRead
}
if wn != n {
log.Error().Err(fmt.Errorf("mismatch length between read and write")).Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp read problem")
cancel()
break backendRead
}
select {
case <-ctx.Done():
break backendRead
default:
}
}
}
func tcpClient2Backend(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, client, backend net.Conn) {
defer wg.Done()
var (
buf []byte
n, wn int
cancel context.CancelFunc
err error
)
cancel = cf
buf = make([]byte, BUF_SIZE)
clientRead:
for {
if n, err = client.Read(buf); err != nil {
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp read error")
cancel()
break clientRead
}
select {
case <-ctx.Done():
break clientRead
default:
}
if wn, err = backend.Write(buf[:n]); err != nil {
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp write error")
cancel()
break clientRead
}
if wn != n {
log.Error().Err(fmt.Errorf("mismatch length between read and write")).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp write problem")
cancel()
break clientRead
}
select {
case <-ctx.Done():
break clientRead
default:
}
}
}
func RunUDP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.UDPConnectionConfig) {
defer wg.Done()
var (
client net.PacketConn
buf []byte
addr net.Addr
err error
n int
backend *backendUDP
)
if client, err = net.ListenPacket("udp", cfg.Listen()); err != nil {
panic(fmt.Errorf("failed to bind for UDP %s: %w", cfg.Name(), err))
}
wg.Add(1)
backend = initUDP(wg, ctx, cfg, client)
buf = make([]byte, BUF_SIZE)
udpReadLoop:
for {
if n, addr, err = client.ReadFrom(buf); err != nil {
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("udp read error")
continue udpReadLoop
}
if err = backend.Send(ctx, addr.String(), buf[:n]); err != nil {
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("send udp message")
continue udpReadLoop
}
select {
case <-ctx.Done():
break udpReadLoop
default:
}
}
}

View File

@ -17,23 +17,30 @@ package main
*/
import (
"context"
"os/signal"
"gitea.suyono.dev/suyono/netbounce"
"gitea.suyono.dev/suyono/netbounce/config"
"gitea.suyono.dev/suyono/netbounce/flag"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"golang.org/x/sys/unix"
)
func main() {
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
ctx, stop := signal.NotifyContext(context.Background(), unix.SIGINT, unix.SIGTERM)
defer stop()
flag.Parse()
config.CollectEnv()
if err := config.ReadConfig(); err != nil {
log.Fatal().Err(err).Msg("error reading config")
}
if err := netbounce.Bounce(); err != nil {
if err := netbounce.Bounce(ctx); err != nil {
log.Fatal().Err(err).Msg("starting netbounce")
}
}

2
go.mod
View File

@ -6,6 +6,7 @@ require (
github.com/rs/zerolog v1.34.0
github.com/spf13/pflag v1.0.6
github.com/spf13/viper v1.20.1
golang.org/x/sys v0.32.0
)
require (
@ -21,7 +22,6 @@ require (
github.com/subosito/gotenv v1.6.0 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/sys v0.32.0 // indirect
golang.org/x/text v0.21.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)