From 800934f8d417745b75bad7f93a540f280a053a08 Mon Sep 17 00:00:00 2001 From: Tom Thorogood Date: Fri, 6 Apr 2018 23:21:04 +0930 Subject: [PATCH] Cleanup udp.go somewhat (#662) * Avoid hard-coding required UDP OOB buffer size * Simplify parseDstFromOOB if statements * Use self-invoking function for udpOOBSize variable --- udp.go | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/udp.go b/udp.go index f3f31a7a..a4826ee2 100644 --- a/udp.go +++ b/udp.go @@ -9,6 +9,22 @@ import ( "golang.org/x/net/ipv6" ) +// This is the required size of the OOB buffer to pass to ReadMsgUDP. +var udpOOBSize = func() int { + // We can't know whether we'll get an IPv4 control message or an + // IPv6 control message ahead of time. To get around this, we size + // the buffer equal to the largest of the two. + + oob4 := ipv4.NewControlMessage(ipv4.FlagDst | ipv4.FlagInterface) + oob6 := ipv6.NewControlMessage(ipv6.FlagDst | ipv6.FlagInterface) + + if len(oob4) > len(oob6) { + return len(oob4) + } + + return len(oob6) +}() + // SessionUDP holds the remote address and the associated // out-of-band data. type SessionUDP struct { @@ -22,7 +38,7 @@ func (s *SessionUDP) RemoteAddr() net.Addr { return s.raddr } // ReadFromSessionUDP acts just like net.UDPConn.ReadFrom(), but returns a session object instead of a // net.UDPAddr. func ReadFromSessionUDP(conn *net.UDPConn, b []byte) (int, *SessionUDP, error) { - oob := make([]byte, 40) + oob := make([]byte, udpOOBSize) n, oobn, _, raddr, err := conn.ReadMsgUDP(b, oob) if err != nil { return n, nil, err @@ -53,18 +69,15 @@ func parseDstFromOOB(oob []byte) net.IP { // Start with IPv6 and then fallback to IPv4 // TODO(fastest963): Figure out a way to prefer one or the other. Looking at // the lvl of the header for a 0 or 41 isn't cross-platform. - var dst net.IP cm6 := new(ipv6.ControlMessage) - if cm6.Parse(oob) == nil { - dst = cm6.Dst + if cm6.Parse(oob) == nil && cm6.Dst != nil { + return cm6.Dst } - if dst == nil { - cm4 := new(ipv4.ControlMessage) - if cm4.Parse(oob) == nil { - dst = cm4.Dst - } + cm4 := new(ipv4.ControlMessage) + if cm4.Parse(oob) == nil && cm4.Dst != nil { + return cm4.Dst } - return dst + return nil } // correctSource takes oob data and returns new oob data with the Src equal to the Dst