WIP: backend udp
This commit is contained in:
parent
d0b3cbf967
commit
7f69553ead
168
backend_udp.go
Normal file
168
backend_udp.go
Normal file
@ -0,0 +1,168 @@
|
||||
package netbounce
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitea.suyono.dev/suyono/netbounce/abstract"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type udpRel struct {
|
||||
backend net.PacketConn
|
||||
clientAddr string
|
||||
expiry time.Time
|
||||
}
|
||||
|
||||
type backendUDP struct {
|
||||
relations map[string]udpRel
|
||||
relMtx *sync.Mutex
|
||||
client net.PacketConn
|
||||
cfg abstract.UDPConnectionConfig
|
||||
bufPool sync.Pool
|
||||
msgChan chan udpMessage
|
||||
}
|
||||
|
||||
type udpMessage struct {
|
||||
addr string
|
||||
buf *bytes.Buffer
|
||||
}
|
||||
|
||||
func initUDP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.UDPConnectionConfig, client net.PacketConn) *backendUDP {
|
||||
backend := &backendUDP{
|
||||
relations: make(map[string]udpRel),
|
||||
relMtx: new(sync.Mutex),
|
||||
cfg: cfg,
|
||||
client: client,
|
||||
msgChan: make(chan udpMessage),
|
||||
bufPool: sync.Pool{
|
||||
New: func() any {
|
||||
return new(bytes.Buffer)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go func(wg *sync.WaitGroup, ctx context.Context, cfg abstract.UDPConnectionConfig, backend *backendUDP) {
|
||||
defer wg.Done()
|
||||
|
||||
incoming := backend.msgChan
|
||||
readIncoming:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
break readIncoming
|
||||
case m := <-incoming:
|
||||
backend.handle(m)
|
||||
}
|
||||
}
|
||||
}(wg, ctx, cfg, backend)
|
||||
|
||||
return backend
|
||||
}
|
||||
|
||||
func (b *backendUDP) findRel(clientAddr string) (udpRel, bool) {
|
||||
b.relMtx.Lock()
|
||||
defer b.relMtx.Unlock()
|
||||
|
||||
rel, ok := b.relations[clientAddr]
|
||||
if ok && rel.expiry.Before(time.Now()) {
|
||||
delete(b.relations, clientAddr)
|
||||
return rel, false
|
||||
}
|
||||
|
||||
return rel, ok
|
||||
}
|
||||
|
||||
func (b *backendUDP) addRel(clientAddr string, rel udpRel) {
|
||||
b.relMtx.Lock()
|
||||
defer b.relMtx.Unlock()
|
||||
|
||||
b.relations[clientAddr] = rel
|
||||
}
|
||||
|
||||
func (b *backendUDP) createRelSend(clientAddr string, buf []byte) (udpRel, error) {
|
||||
var (
|
||||
udpAddr *net.UDPAddr
|
||||
udpConn *net.UDPConn
|
||||
n int
|
||||
err error
|
||||
)
|
||||
|
||||
rel := udpRel{
|
||||
clientAddr: clientAddr,
|
||||
}
|
||||
|
||||
if udpAddr, err = net.ResolveUDPAddr("udp", b.cfg.Backend()); err != nil {
|
||||
return rel, fmt.Errorf("create udp relation and send message: resolve udp addr: %w", err)
|
||||
}
|
||||
|
||||
if udpConn, err = net.DialUDP("udp", nil, udpAddr); err != nil {
|
||||
return rel, fmt.Errorf("create udp relation and send message: dial udp: %w", err)
|
||||
}
|
||||
|
||||
if n, err = udpConn.Write(buf); err != nil && n == 0 {
|
||||
_ = udpConn.Close()
|
||||
return rel, fmt.Errorf("create udp relation and send message: write udp: %w", err)
|
||||
}
|
||||
|
||||
rel.backend = udpConn
|
||||
rel.expiry = time.Now().Add(b.cfg.Timeout())
|
||||
|
||||
return rel, nil
|
||||
}
|
||||
|
||||
func (b *backendUDP) handle(msg udpMessage) {
|
||||
var (
|
||||
rel udpRel
|
||||
ok bool
|
||||
err error
|
||||
)
|
||||
|
||||
if rel, ok = b.findRel(msg.addr); !ok {
|
||||
if rel, err = b.createRelSend(msg.addr, msg.buf.Bytes()); err != nil {
|
||||
log.Error().Err(err).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("establish relation with udp backend")
|
||||
}
|
||||
|
||||
b.addRel(msg.addr, rel)
|
||||
return
|
||||
}
|
||||
|
||||
_ = rel
|
||||
// if rel.expiry.Before(time.Now()) {
|
||||
// //TODO: handle expiry
|
||||
// }
|
||||
|
||||
}
|
||||
|
||||
func (b *backendUDP) Send(ctx context.Context, addr string, p []byte) error {
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
)
|
||||
buf := b.bufPool.Get().(*bytes.Buffer)
|
||||
if n, err = buf.Write(p); err != nil {
|
||||
buf.Reset()
|
||||
b.bufPool.Put(buf)
|
||||
return fmt.Errorf("send udp message to handler: %w", err)
|
||||
}
|
||||
|
||||
if len(p) != n {
|
||||
buf.Reset()
|
||||
b.bufPool.Put(buf)
|
||||
return fmt.Errorf("send udp message to handler: failed to write complete message")
|
||||
}
|
||||
|
||||
select {
|
||||
case b.msgChan <- udpMessage{
|
||||
addr: addr,
|
||||
buf: buf,
|
||||
}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
223
bounce.go
223
bounce.go
@ -17,18 +17,32 @@ package netbounce
|
||||
*/
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"gitea.suyono.dev/suyono/netbounce/abstract"
|
||||
"gitea.suyono.dev/suyono/netbounce/config"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func Bounce() error {
|
||||
const (
|
||||
CONNECTION = "connection"
|
||||
DIRECTION = "direction"
|
||||
CLIENT_TO_BACKEND = "client-to-backend"
|
||||
BACKEND_TO_CLIENT = "backend-to-client"
|
||||
BUF_SIZE = 4096
|
||||
)
|
||||
|
||||
func Bounce(ctx context.Context) error {
|
||||
var (
|
||||
names []string
|
||||
err error
|
||||
cc abstract.ConnectionConfig
|
||||
)
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
if names, err = config.ListConnection(); err != nil {
|
||||
return fmt.Errorf("bounce: connection list: %w", err)
|
||||
}
|
||||
@ -37,20 +51,219 @@ func Bounce() error {
|
||||
return fmt.Errorf("bounce: connection get %s: %w", name, err)
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
if cc.Type() == abstract.TCP {
|
||||
go RunTCP(cc.AsTCP())
|
||||
go RunTCP(wg, ctx, cc.AsTCP())
|
||||
} else {
|
||||
go RunUDP(cc.AsUDP())
|
||||
go RunUDP(wg, ctx, cc.AsUDP())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RunTCP(cfg abstract.TCPConnectionConfig) {
|
||||
func RunTCP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig) {
|
||||
defer wg.Done()
|
||||
|
||||
var (
|
||||
err error
|
||||
l net.Listener
|
||||
accepted chan net.Conn
|
||||
)
|
||||
|
||||
if l, err = net.Listen("tcp", cfg.Listen()); err != nil {
|
||||
panic(fmt.Errorf("failed to listen tcp: %w", err))
|
||||
}
|
||||
|
||||
accepted = make(chan net.Conn)
|
||||
go func(ctx context.Context, cfg abstract.TCPConnectionConfig, listener net.Listener, acceptChan chan<- net.Conn) {
|
||||
defer log.Info().Msgf("go routine for accepting connection quited: %s", cfg.Name())
|
||||
var (
|
||||
err error
|
||||
c net.Conn
|
||||
)
|
||||
|
||||
acceptIncoming:
|
||||
for {
|
||||
if c, err = listener.Accept(); err != nil {
|
||||
log.Error().Err(err).Msg("accept TCP connection")
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case acceptChan <- c:
|
||||
case <-ctx.Done():
|
||||
break acceptIncoming
|
||||
}
|
||||
}
|
||||
}(ctx, cfg, l, accepted)
|
||||
|
||||
connStarter:
|
||||
for {
|
||||
select {
|
||||
case c := <-accepted:
|
||||
wg.Add(1)
|
||||
go makeTCPConnection(wg, ctx, cfg, c)
|
||||
case <-ctx.Done():
|
||||
break connStarter
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func RunUDP(cfg abstract.UDPConnectionConfig) {
|
||||
func makeTCPConnection(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig, conn net.Conn) {
|
||||
defer wg.Done()
|
||||
|
||||
var (
|
||||
backend net.Conn
|
||||
err error
|
||||
)
|
||||
|
||||
if backend, err = net.Dial("tcp", cfg.Backend()); err != nil {
|
||||
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Msg("connection to backend failed")
|
||||
return
|
||||
}
|
||||
|
||||
ctxConn, cancelConn := context.WithCancel(ctx)
|
||||
|
||||
wg.Add(1)
|
||||
go tcpClient2Backend(wg, ctxConn, cfg, cancelConn, conn, backend)
|
||||
|
||||
wg.Add(1)
|
||||
go tcpBackend2Client(wg, ctxConn, cfg, cancelConn, backend, conn)
|
||||
|
||||
<-ctxConn.Done()
|
||||
_ = backend.Close()
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
func tcpBackend2Client(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, backend, client net.Conn) {
|
||||
defer wg.Done()
|
||||
|
||||
var (
|
||||
buf []byte
|
||||
n, wn int
|
||||
cancel context.CancelFunc
|
||||
err error
|
||||
)
|
||||
|
||||
cancel = cf
|
||||
buf = make([]byte, BUF_SIZE)
|
||||
backendRead:
|
||||
for {
|
||||
if n, err = backend.Read(buf); err != nil {
|
||||
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp read error")
|
||||
cancel()
|
||||
break backendRead
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
break backendRead
|
||||
default:
|
||||
}
|
||||
|
||||
if wn, err = client.Write(buf[:n]); err != nil {
|
||||
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp read error")
|
||||
cancel()
|
||||
break backendRead
|
||||
}
|
||||
|
||||
if wn != n {
|
||||
log.Error().Err(fmt.Errorf("mismatch length between read and write")).Str(CONNECTION, cfg.Name()).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("tcp read problem")
|
||||
cancel()
|
||||
break backendRead
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
break backendRead
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func tcpClient2Backend(wg *sync.WaitGroup, ctx context.Context, cfg abstract.TCPConnectionConfig, cf context.CancelFunc, client, backend net.Conn) {
|
||||
defer wg.Done()
|
||||
|
||||
var (
|
||||
buf []byte
|
||||
n, wn int
|
||||
cancel context.CancelFunc
|
||||
err error
|
||||
)
|
||||
|
||||
cancel = cf
|
||||
buf = make([]byte, BUF_SIZE)
|
||||
clientRead:
|
||||
for {
|
||||
if n, err = client.Read(buf); err != nil {
|
||||
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp read error")
|
||||
cancel()
|
||||
break clientRead
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
break clientRead
|
||||
default:
|
||||
}
|
||||
|
||||
if wn, err = backend.Write(buf[:n]); err != nil {
|
||||
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp write error")
|
||||
cancel()
|
||||
break clientRead
|
||||
}
|
||||
|
||||
if wn != n {
|
||||
log.Error().Err(fmt.Errorf("mismatch length between read and write")).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("tcp write problem")
|
||||
cancel()
|
||||
break clientRead
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
break clientRead
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func RunUDP(wg *sync.WaitGroup, ctx context.Context, cfg abstract.UDPConnectionConfig) {
|
||||
defer wg.Done()
|
||||
|
||||
var (
|
||||
client net.PacketConn
|
||||
buf []byte
|
||||
addr net.Addr
|
||||
err error
|
||||
n int
|
||||
backend *backendUDP
|
||||
)
|
||||
|
||||
if client, err = net.ListenPacket("udp", cfg.Listen()); err != nil {
|
||||
panic(fmt.Errorf("failed to bind for UDP %s: %w", cfg.Name(), err))
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
backend = initUDP(wg, ctx, cfg, client)
|
||||
buf = make([]byte, BUF_SIZE)
|
||||
|
||||
udpReadLoop:
|
||||
for {
|
||||
if n, addr, err = client.ReadFrom(buf); err != nil {
|
||||
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("udp read error")
|
||||
continue udpReadLoop
|
||||
}
|
||||
|
||||
if err = backend.Send(ctx, addr.String(), buf[:n]); err != nil {
|
||||
log.Error().Err(err).Str(CONNECTION, cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("send udp message")
|
||||
continue udpReadLoop
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
break udpReadLoop
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,23 +17,30 @@ package main
|
||||
*/
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/signal"
|
||||
|
||||
"gitea.suyono.dev/suyono/netbounce"
|
||||
"gitea.suyono.dev/suyono/netbounce/config"
|
||||
"gitea.suyono.dev/suyono/netbounce/flag"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func main() {
|
||||
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), unix.SIGINT, unix.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
flag.Parse()
|
||||
config.CollectEnv()
|
||||
if err := config.ReadConfig(); err != nil {
|
||||
log.Fatal().Err(err).Msg("error reading config")
|
||||
}
|
||||
|
||||
if err := netbounce.Bounce(); err != nil {
|
||||
if err := netbounce.Bounce(ctx); err != nil {
|
||||
log.Fatal().Err(err).Msg("starting netbounce")
|
||||
}
|
||||
}
|
||||
|
||||
2
go.mod
2
go.mod
@ -6,6 +6,7 @@ require (
|
||||
github.com/rs/zerolog v1.34.0
|
||||
github.com/spf13/pflag v1.0.6
|
||||
github.com/spf13/viper v1.20.1
|
||||
golang.org/x/sys v0.32.0
|
||||
)
|
||||
|
||||
require (
|
||||
@ -21,7 +22,6 @@ require (
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
go.uber.org/atomic v1.9.0 // indirect
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
golang.org/x/sys v0.32.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user