looks good, it should be test worthy

This commit is contained in:
2025-05-10 11:33:30 +10:00
parent 0c1a9fb7e7
commit 9eeda34e19
3 changed files with 90 additions and 67 deletions

View File

@@ -51,36 +51,54 @@ func initUDP(wg *sync.WaitGroup, ctx context.Context, bp *BufPool, cfg abstract.
msgChan: make(chan udpMessage),
bufPool: bp,
}
defer wg.Done()
go func(wg *sync.WaitGroup, ctx context.Context, cfg abstract.UDPConnectionConfig, backend *backendUDP) {
defer wg.Done()
wg.Add(1)
handlerSpawn(wg, ctx, backend)
incoming := backend.msgChan
readIncoming:
for {
select {
case <-ctx.Done():
break readIncoming
case m := <-incoming:
backend.handle(wg, ctx, m)
}
}
}(wg, ctx, cfg, backend)
wg.Add(1)
handlerSpawn(wg, ctx, backend)
wg.Add(1)
handlerSpawn(wg, ctx, backend)
wg.Add(1)
handlerSpawn(wg, ctx, backend)
return backend
}
func handlerSpawn(wg *sync.WaitGroup, ctx context.Context, backend *backendUDP) {
defer wg.Done()
incoming := backend.msgChan
readIncoming:
for {
select {
case <-ctx.Done():
break readIncoming
case m := <-incoming:
backend.handle(wg, ctx, m)
}
}
}
func (b *backendUDP) findRel(clientAddr string) (udpRel, bool) {
b.relMtx.Lock()
defer b.relMtx.Unlock()
rel, ok := b.relations[clientAddr]
if ok && rel.expiry.Before(time.Now()) {
// expired
if rel.ctxCancel != nil {
rel.ctxCancel()
}
_ = rel.backend.Close()
// cleans up connection; do it in a separate goroutine just in case the close operation might take some time
go func(c net.PacketConn) {
_ = rel.backend.Close()
}(rel.backend)
delete(b.relations, clientAddr)
return rel, false
}
@@ -115,9 +133,11 @@ func (b *backendUDP) createRelSend(clientAddr string, buf []byte) (udpRel, error
}
if _, err = udpConn.WriteTo(buf, b.cfg.BackendAddr()); err != nil {
//TODO: I think I need to fix this. This error handling is odd.
_ = udpConn.Close()
return rel, fmt.Errorf("create udp relation and send message: write udp: %w", err)
if errors.Is(err, net.ErrClosed) {
return rel, fmt.Errorf("create udp relation and send message: write udp: %w", err)
}
log.Error().Caller().Err(err).Str(CONNECTION, b.cfg.Name()).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("create udp relation & send")
}
rel.backend = udpConn
@@ -137,7 +157,7 @@ func (b *backendUDP) relSend(rel udpRel, buf []byte) error {
}
if len(buf) != n {
log.Error().Msg("relSend mismatch size")
log.Error().Caller().Msg("relSend mismatch size")
}
return nil
@@ -150,15 +170,13 @@ func (b *backendUDP) handle(wg *sync.WaitGroup, ctx context.Context, msg udpMess
err error
)
defer b.bufPool.Put(msg.buf)
if rel, ok = b.findRel(msg.addr); !ok {
if rel, err = b.createRelSend(msg.addr, msg.buf); err != nil {
b.bufPool.Put(msg.buf)
log.Error().Err(err).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("establish relation with udp backend")
log.Error().Caller().Err(err).Str(DIRECTION, CLIENT_TO_BACKEND).Msg("establish relation with udp backend")
return
}
b.bufPool.Put(msg.buf)
rel.ctx, rel.ctxCancel = context.WithCancel(ctx)
b.addUpdateRel(msg.addr, rel)
@@ -168,13 +186,10 @@ func (b *backendUDP) handle(wg *sync.WaitGroup, ctx context.Context, msg udpMess
}
if err = b.relSend(rel, msg.buf); err != nil {
b.bufPool.Put(msg.buf)
log.Error().Err(err).Msg("handle: send for existing relation")
log.Error().Caller().Err(err).Msg("handle: send for existing relation")
return
}
b.bufPool.Put(msg.buf)
rel.expiry = time.Now().Add(b.cfg.Timeout())
b.addUpdateRel(rel.clientAddr, rel)
}
@@ -184,11 +199,10 @@ func (b *backendUDP) Send(ctx context.Context, addr string, p []byte) error {
n int
)
buf := b.bufPool.Get()
buf := b.bufPool.Get() // the buf will be released in handle
n = copy(buf, p)
if len(p) != n {
b.bufPool.Put(buf)
return fmt.Errorf("send udp message to handler: failed to write complete message")
}
@@ -198,6 +212,7 @@ func (b *backendUDP) Send(ctx context.Context, addr string, p []byte) error {
buf: buf[:n],
}:
case <-ctx.Done():
return context.Canceled
}
return nil
@@ -213,29 +228,29 @@ func (b *backendUDP) udpBackend2Client(wg *sync.WaitGroup, rel udpRel) {
clientAddr string
)
buf := make([]byte, BUF_SIZE)
buf := b.bufPool.Get()
defer b.bufPool.Put(buf)
go func(ctx context.Context, backend net.PacketConn) {
<-ctx.Done()
_ = backend.Close()
}(rel.ctx, rel.backend)
udpBackendLoop:
for {
if n, _, err = rel.backend.ReadFrom(buf); err != nil {
select {
case <-rel.ctx.Done():
break udpBackendLoop
default:
}
if errors.Is(err, net.ErrClosed) { //TODO: implement this error handling to all socket read/write operation
if errors.Is(err, net.ErrClosed) {
rel.ctxCancel()
break udpBackendLoop
}
log.Error().Err(err).Str(CONNECTION, rel.clientAddr).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("udpBackend2Client: read from udp")
log.Error().Caller().Err(err).Str(CONNECTION, rel.clientAddr).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("udpBackend2Client: read from udp")
continue udpBackendLoop
}
clientAddr = rel.clientAddr
if rel, ok = b.findRel(clientAddr); !ok {
log.Error().Msg("relation not found")
log.Error().Caller().Msg("relation not found")
return
}
@@ -246,11 +261,19 @@ udpBackendLoop:
}
if wn, err = b.client.WriteTo(buf[:n], rel); err != nil {
//TODO: error handling
// in case of error, never close b.client, it's a shared Packet Conn.
// All UDP relation use a shared connection/socket back to the client
if errors.Is(err, net.ErrClosed) {
rel.ctxCancel()
break udpBackendLoop
}
log.Error().Caller().Err(err).Str(CONNECTION, rel.clientAddr).Str(DIRECTION, BACKEND_TO_CLIENT).Msg("udpBackend2Client: write to client")
continue udpBackendLoop
}
if wn != n {
//TODO: error when mismatch length
log.Warn().Caller().Str(CONNECTION, rel.clientAddr).Str(DIRECTION, BACKEND_TO_CLIENT).Msgf("failed to write complete message: %d vs %d", wn, n)
}
rel.expiry = time.Now().Add(b.cfg.Timeout())