package netbounce import ( "context" "errors" "fmt" "net" "sync" "time" "gitea.suyono.dev/suyono/netbounce/abstract" "github.com/rs/zerolog/log" ) type udpRel struct { backend net.PacketConn clientAddr string expiry time.Time ctx context.Context ctxCancel context.CancelFunc } type backendUDP struct { relations map[string]udpRel relMtx *sync.Mutex client net.PacketConn cfg abstract.UDPConnectionConfig bufPool *BufPool msgChan chan udpMessage backendAddr *net.UDPAddr } type udpMessage struct { addr string buf []byte } func (r udpRel) Network() string { return "udp" } func (r udpRel) String() string { return r.clientAddr } func initUDP(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.UDPConnectionConfig, client net.PacketConn) (*backendUDP, error) { var ( addr *net.UDPAddr err error ) if addr, err = net.ResolveUDPAddr("udp", cfg.Backend()); err != nil { return nil, fmt.Errorf("failed to resolve UDP address: %v", err) } backend := &backendUDP{ relations: make(map[string]udpRel), relMtx: new(sync.Mutex), cfg: cfg, client: client, msgChan: make(chan udpMessage), bufPool: bp, backendAddr: addr, } defer wg.Done() //TODO: make the number of handler spawn configurable wg.Add(1) handlerSpawn(wg, ctx, backend) wg.Add(1) handlerSpawn(wg, ctx, backend) wg.Add(1) handlerSpawn(wg, ctx, backend) wg.Add(1) handlerSpawn(wg, ctx, backend) return backend, nil } func handlerSpawn(wg *sync.WaitGroup, ctx context.Context, backend *backendUDP) { defer wg.Done() incoming := backend.msgChan readIncoming: for { select { case <-ctx.Done(): break readIncoming case m := <-incoming: backend.handle(wg, ctx, m) } } } func (b *backendUDP) findRel(clientAddr string) (udpRel, bool) { b.relMtx.Lock() defer b.relMtx.Unlock() rel, ok := b.relations[clientAddr] if ok && rel.expiry.Before(time.Now()) { // expired if rel.ctxCancel != nil { rel.ctxCancel() } // cleans up connection; do it in a separate goroutine just in case the close operation might take some time go func(c net.PacketConn) { _ = rel.backend.Close() }(rel.backend) delete(b.relations, clientAddr) return rel, false } return rel, ok } func (b *backendUDP) addUpdateRel(clientAddr string, rel udpRel) { b.relMtx.Lock() defer b.relMtx.Unlock() b.relations[clientAddr] = rel } func (b *backendUDP) createRelSend(clientAddr string, buf []byte) (udpRel, error) { var ( laddr *net.UDPAddr udpConn *net.UDPConn err error ) rel := udpRel{ clientAddr: clientAddr, } if laddr, err = net.ResolveUDPAddr("udp", ""); err != nil { return rel, fmt.Errorf("create udp relation and send message: resolve local/self address for UDP: %w", err) } if udpConn, err = net.ListenUDP("udp", laddr); err != nil { return rel, fmt.Errorf("create udp relation and send message: bind local/self address for UDP: %w", err) } if _, err = udpConn.WriteTo(buf, b.backendAddr); err != nil { if errors.Is(err, net.ErrClosed) { return rel, fmt.Errorf("create udp relation and send message: write udp: %w", err) } log.Error().Caller().Err(err).Str(CONNECTION, b.cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("create udp relation & send") } rel.backend = udpConn rel.expiry = time.Now().Add(b.cfg.Timeout()) return rel, nil } func (b *backendUDP) relSend(rel udpRel, buf []byte) error { var ( n int err error ) if n, err = rel.backend.WriteTo(buf, b.backendAddr); err != nil && n == 0 { return fmt.Errorf("relSend: %w", err) } if len(buf) != n { log.Error().Caller().Msg("relSend mismatch size") } return nil } func (b *backendUDP) handle(wg *sync.WaitGroup, ctx context.Context, msg udpMessage) { var ( rel udpRel ok bool err error ) defer b.bufPool.Put(msg.buf) if rel, ok = b.findRel(msg.addr); !ok { if rel, err = b.createRelSend(msg.addr, msg.buf); err != nil { log.Error().Caller().Err(err).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("establish relation with udp backend") return } rel.ctx, rel.ctxCancel = context.WithCancel(ctx) b.addUpdateRel(msg.addr, rel) wg.Add(1) go b.udpBackend2Client(wg, rel) return } if err = b.relSend(rel, msg.buf); err != nil { log.Error().Caller().Err(err).Msg("handle: send for existing relation") return } rel.expiry = time.Now().Add(b.cfg.Timeout()) b.addUpdateRel(rel.clientAddr, rel) } func (b *backendUDP) Send(ctx context.Context, addr string, p []byte) error { var ( n int ) buf := b.bufPool.Get() // the buf will be released in the handle n = copy(buf, p) if len(p) != n { return fmt.Errorf("send udp message to handler: failed to write complete message") } select { case b.msgChan <- udpMessage{ addr: addr, buf: buf[:n], }: case <-ctx.Done(): return context.Canceled } return nil } func (b *backendUDP) udpBackend2Client(wg *sync.WaitGroup, rel udpRel) { defer wg.Done() var ( n, wn int err error ok bool clientAddr string ) buf := b.bufPool.Get() defer b.bufPool.Put(buf) go func(ctx context.Context, backend net.PacketConn) { <-ctx.Done() _ = backend.Close() }(rel.ctx, rel.backend) udpBackendLoop: for { if n, _, err = rel.backend.ReadFrom(buf); err != nil { if errors.Is(err, net.ErrClosed) { rel.ctxCancel() break udpBackendLoop } log.Error().Caller().Err(err).Str(CONNECTION, rel.clientAddr).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("udpBackend2Client: read from udp") continue udpBackendLoop } clientAddr = rel.clientAddr if rel, ok = b.findRel(clientAddr); !ok { log.Error().Caller().Msg("relation not found") return } select { case <-rel.ctx.Done(): break udpBackendLoop default: } if wn, err = b.client.WriteTo(buf[:n], rel); err != nil { // In case of error, never close b.client, it's a shared Packet Conn. // All UDP relations use a shared connection/socket back to the client if errors.Is(err, net.ErrClosed) { rel.ctxCancel() break udpBackendLoop } log.Error().Caller().Err(err).Str(CONNECTION, rel.clientAddr).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("udpBackend2Client: write to client") continue udpBackendLoop } if wn != n { log.Warn().Caller().Str(CONNECTION, rel.clientAddr).Str(DIRECTION, BACKEND_TO_CLIENT).Msgf("failed to write complete message: %d vs %d", wn, n) } rel.expiry = time.Now().Add(b.cfg.Timeout()) b.addUpdateRel(clientAddr, rel) select { case <-rel.ctx.Done(): break udpBackendLoop default: } } }