301 lines
6.5 KiB
Go
301 lines
6.5 KiB
Go
package netbounce
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"gitea.suyono.dev/suyono/netbounce/abstract"
|
|
"github.com/rs/zerolog/log"
|
|
)
|
|
|
|
// udpRel is one client<->backend UDP relation: a dedicated backend-facing
// socket opened for a single client address, plus the bookkeeping used to
// expire it. udpRel also implements net.Addr (Network and String), which lets
// it be passed as the destination address when writing replies to the client.
type udpRel struct {
	// backend is the socket created by createRelSend. Datagrams for the backend
	// are written through it, and the backend's replies are read from it.
	backend net.PacketConn
	// clientAddr is the client's address string; it is also this relation's
	// key in backendUDP.relations.
	clientAddr string
	// expiry is when the relation goes stale. Traffic pushes it forward by
	// cfg.Timeout(); findRel tears the relation down once it is in the past.
	expiry time.Time
	// ctx scopes the relation's lifetime: when it is done, the backend socket
	// is closed and the udpBackend2Client loop for this relation stops.
	ctx context.Context
	// ctxCancel cancels ctx.
	ctxCancel context.CancelFunc
}
|
|
|
|
// backendUDP relays datagrams between clients and one UDP backend, keeping one
// udpRel (and thus one backend-facing socket) per client address. Client
// datagrams are fed in through Send; replies go back out through the shared
// client PacketConn.
type backendUDP struct {
	// relations maps a client address string to its relation.
	relations map[string]udpRel
	// relMtx guards relations; findRel and addUpdateRel lock it.
	relMtx *sync.Mutex
	// client is the socket used to write replies to every client. It is shared
	// by all relations and must never be closed on behalf of a single relation.
	client net.PacketConn
	// cfg provides the backend address (Backend), the connection name used in
	// log fields (Name) and the relation idle timeout (Timeout).
	cfg abstract.UDPConnectionConfig
	// bufPool supplies datagram buffers.
	bufPool *BufPool
	// msgChan carries queued client datagrams from Send to the handler workers.
	msgChan chan udpMessage
	// backendAddr is cfg.Backend() resolved as a UDP address.
	backendAddr *net.UDPAddr
}
|
|
|
|
// udpMessage is one client datagram queued by Send for the handler workers.
type udpMessage struct {
	// addr is the client address string the datagram belongs to; handle uses it
	// as the relation key.
	addr string
	// buf holds the payload in a buffer taken from bufPool; handle returns it
	// to the pool when done.
	buf []byte
}
|
|
|
|
func (r udpRel) Network() string {
|
|
return "udp"
|
|
}
|
|
|
|
func (r udpRel) String() string {
|
|
return r.clientAddr
|
|
}
|
|
|
|
func initUDP(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.UDPConnectionConfig, client net.PacketConn) (*backendUDP, error) {
|
|
var (
|
|
addr *net.UDPAddr
|
|
err error
|
|
)
|
|
|
|
if addr, err = net.ResolveUDPAddr("udp", cfg.Backend()); err != nil {
|
|
return nil, fmt.Errorf("failed to resolve UDP address: %v", err)
|
|
}
|
|
|
|
backend := &backendUDP{
|
|
relations: make(map[string]udpRel),
|
|
relMtx: new(sync.Mutex),
|
|
cfg: cfg,
|
|
client: client,
|
|
msgChan: make(chan udpMessage),
|
|
bufPool: bp,
|
|
backendAddr: addr,
|
|
}
|
|
defer wg.Done()
|
|
|
|
//TODO: make the number of handler spawn configurable
|
|
wg.Add(1)
|
|
handlerSpawn(wg, ctx, backend)
|
|
|
|
wg.Add(1)
|
|
handlerSpawn(wg, ctx, backend)
|
|
|
|
wg.Add(1)
|
|
handlerSpawn(wg, ctx, backend)
|
|
|
|
wg.Add(1)
|
|
handlerSpawn(wg, ctx, backend)
|
|
|
|
return backend, nil
|
|
}
|
|
|
|
func handlerSpawn(wg *sync.WaitGroup, ctx context.Context, backend *backendUDP) {
|
|
defer wg.Done()
|
|
|
|
incoming := backend.msgChan
|
|
readIncoming:
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
break readIncoming
|
|
case m := <-incoming:
|
|
backend.handle(wg, ctx, m)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (b *backendUDP) findRel(clientAddr string) (udpRel, bool) {
|
|
b.relMtx.Lock()
|
|
defer b.relMtx.Unlock()
|
|
|
|
rel, ok := b.relations[clientAddr]
|
|
if ok && rel.expiry.Before(time.Now()) {
|
|
// expired
|
|
if rel.ctxCancel != nil {
|
|
rel.ctxCancel()
|
|
}
|
|
|
|
// cleans up connection; do it in a separate goroutine just in case the close operation might take some time
|
|
go func(c net.PacketConn) {
|
|
_ = rel.backend.Close()
|
|
}(rel.backend)
|
|
|
|
delete(b.relations, clientAddr)
|
|
return rel, false
|
|
}
|
|
|
|
return rel, ok
|
|
}
|
|
|
|
func (b *backendUDP) addUpdateRel(clientAddr string, rel udpRel) {
|
|
b.relMtx.Lock()
|
|
defer b.relMtx.Unlock()
|
|
|
|
b.relations[clientAddr] = rel
|
|
}
|
|
|
|
func (b *backendUDP) createRelSend(clientAddr string, buf []byte) (udpRel, error) {
|
|
var (
|
|
laddr *net.UDPAddr
|
|
udpConn *net.UDPConn
|
|
err error
|
|
)
|
|
|
|
rel := udpRel{
|
|
clientAddr: clientAddr,
|
|
}
|
|
|
|
if laddr, err = net.ResolveUDPAddr("udp", ""); err != nil {
|
|
return rel, fmt.Errorf("create udp relation and send message: resolve local/self address for UDP: %w", err)
|
|
}
|
|
|
|
if udpConn, err = net.ListenUDP("udp", laddr); err != nil {
|
|
return rel, fmt.Errorf("create udp relation and send message: bind local/self address for UDP: %w", err)
|
|
}
|
|
|
|
if _, err = udpConn.WriteTo(buf, b.backendAddr); err != nil {
|
|
if errors.Is(err, net.ErrClosed) {
|
|
return rel, fmt.Errorf("create udp relation and send message: write udp: %w", err)
|
|
}
|
|
|
|
log.Error().Caller().Err(err).Str(CONNECTION, b.cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("create udp relation & send")
|
|
}
|
|
|
|
rel.backend = udpConn
|
|
rel.expiry = time.Now().Add(b.cfg.Timeout())
|
|
|
|
return rel, nil
|
|
}
|
|
|
|
func (b *backendUDP) relSend(rel udpRel, buf []byte) error {
|
|
var (
|
|
n int
|
|
err error
|
|
)
|
|
|
|
if n, err = rel.backend.WriteTo(buf, b.backendAddr); err != nil && n == 0 {
|
|
return fmt.Errorf("relSend: %w", err)
|
|
}
|
|
|
|
if len(buf) != n {
|
|
log.Error().Caller().Msg("relSend mismatch size")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (b *backendUDP) handle(wg *sync.WaitGroup, ctx context.Context, msg udpMessage) {
|
|
var (
|
|
rel udpRel
|
|
ok bool
|
|
err error
|
|
)
|
|
|
|
defer b.bufPool.Put(msg.buf)
|
|
if rel, ok = b.findRel(msg.addr); !ok {
|
|
if rel, err = b.createRelSend(msg.addr, msg.buf); err != nil {
|
|
log.Error().Caller().Err(err).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("establish relation with udp backend")
|
|
return
|
|
}
|
|
|
|
rel.ctx, rel.ctxCancel = context.WithCancel(ctx)
|
|
b.addUpdateRel(msg.addr, rel)
|
|
|
|
wg.Add(1)
|
|
go b.udpBackend2Client(wg, rel)
|
|
return
|
|
}
|
|
|
|
if err = b.relSend(rel, msg.buf); err != nil {
|
|
log.Error().Caller().Err(err).Msg("handle: send for existing relation")
|
|
return
|
|
}
|
|
|
|
rel.expiry = time.Now().Add(b.cfg.Timeout())
|
|
b.addUpdateRel(rel.clientAddr, rel)
|
|
}
|
|
|
|
func (b *backendUDP) Send(ctx context.Context, addr string, p []byte) error {
|
|
var (
|
|
n int
|
|
)
|
|
|
|
buf := b.bufPool.Get() // the buf will be released in the handle
|
|
n = copy(buf, p)
|
|
|
|
if len(p) != n {
|
|
return fmt.Errorf("send udp message to handler: failed to write complete message")
|
|
}
|
|
|
|
select {
|
|
case b.msgChan <- udpMessage{
|
|
addr: addr,
|
|
buf: buf[:n],
|
|
}:
|
|
case <-ctx.Done():
|
|
return context.Canceled
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (b *backendUDP) udpBackend2Client(wg *sync.WaitGroup, rel udpRel) {
|
|
defer wg.Done()
|
|
|
|
var (
|
|
n, wn int
|
|
err error
|
|
ok bool
|
|
clientAddr string
|
|
)
|
|
|
|
buf := b.bufPool.Get()
|
|
defer b.bufPool.Put(buf)
|
|
|
|
go func(ctx context.Context, backend net.PacketConn) {
|
|
<-ctx.Done()
|
|
_ = backend.Close()
|
|
}(rel.ctx, rel.backend)
|
|
|
|
udpBackendLoop:
|
|
for {
|
|
if n, _, err = rel.backend.ReadFrom(buf); err != nil {
|
|
if errors.Is(err, net.ErrClosed) {
|
|
rel.ctxCancel()
|
|
break udpBackendLoop
|
|
}
|
|
|
|
log.Error().Caller().Err(err).Str(CONNECTION, rel.clientAddr).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("udpBackend2Client: read from udp")
|
|
continue udpBackendLoop
|
|
}
|
|
|
|
clientAddr = rel.clientAddr
|
|
if rel, ok = b.findRel(clientAddr); !ok {
|
|
log.Error().Caller().Msg("relation not found")
|
|
return
|
|
}
|
|
|
|
select {
|
|
case <-rel.ctx.Done():
|
|
break udpBackendLoop
|
|
default:
|
|
}
|
|
|
|
if wn, err = b.client.WriteTo(buf[:n], rel); err != nil {
|
|
// In case of error, never close b.client, it's a shared Packet Conn.
|
|
// All UDP relations use a shared connection/socket back to the client
|
|
if errors.Is(err, net.ErrClosed) {
|
|
rel.ctxCancel()
|
|
break udpBackendLoop
|
|
}
|
|
|
|
log.Error().Caller().Err(err).Str(CONNECTION, rel.clientAddr).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("udpBackend2Client: write to client")
|
|
continue udpBackendLoop
|
|
}
|
|
|
|
if wn != n {
|
|
log.Warn().Caller().Str(CONNECTION, rel.clientAddr).Str(DIRECTION, BACKEND_TO_CLIENT).Msgf("failed to write complete message: %d vs %d", wn, n)
|
|
}
|
|
|
|
rel.expiry = time.Now().Add(b.cfg.Timeout())
|
|
b.addUpdateRel(clientAddr, rel)
|
|
|
|
select {
|
|
case <-rel.ctx.Done():
|
|
break udpBackendLoop
|
|
default:
|
|
}
|
|
}
|
|
}
|