//go:build !windows // +build !windows package dns import ( "net" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) // This is the required size of the OOB buffer to pass to ReadMsgUDP. var udpOOBSize = func() int { // We can't know whether we'll get an IPv4 control message or an // IPv6 control message ahead of time. To get around this, we size // the buffer equal to the largest of the two. oob4 := ipv4.NewControlMessage(ipv4.FlagDst | ipv4.FlagInterface) oob6 := ipv6.NewControlMessage(ipv6.FlagDst | ipv6.FlagInterface) if len(oob4) > len(oob6) { return len(oob4) } return len(oob6) }() // SessionUDP holds the remote address and the associated // out-of-band data. type SessionUDP struct { raddr *net.UDPAddr context []byte } // RemoteAddr returns the remote network address. func (s *SessionUDP) RemoteAddr() net.Addr { return s.raddr } // ReadFromSessionUDP acts just like net.UDPConn.ReadFrom(), but returns a session object instead of a // net.UDPAddr. func ReadFromSessionUDP(conn *net.UDPConn, b []byte) (int, *SessionUDP, error) { oob := make([]byte, udpOOBSize) n, oobn, _, raddr, err := conn.ReadMsgUDP(b, oob) if err != nil { return n, nil, err } return n, &SessionUDP{raddr, oob[:oobn]}, err } // WriteToSessionUDP acts just like net.UDPConn.WriteTo(), but uses a *SessionUDP instead of a net.Addr. func WriteToSessionUDP(conn *net.UDPConn, b []byte, session *SessionUDP) (int, error) { oob := correctSource(session.context) n, _, err := conn.WriteMsgUDP(b, oob, session.raddr) return n, err } func setUDPSocketOptions(conn *net.UDPConn) error { // Try setting the flags for both families and ignore the errors unless they // both error. err6 := ipv6.NewPacketConn(conn).SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true) err4 := ipv4.NewPacketConn(conn).SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true) if err6 != nil && err4 != nil { return err4 } return nil } // parseDstFromOOB takes oob data and returns the destination IP. func parseDstFromOOB(oob []byte) net.IP { // Start with IPv6 and then fallback to IPv4 // TODO(fastest963): Figure out a way to prefer one or the other. Looking at // the lvl of the header for a 0 or 41 isn't cross-platform. cm6 := new(ipv6.ControlMessage) if cm6.Parse(oob) == nil && cm6.Dst != nil { return cm6.Dst } cm4 := new(ipv4.ControlMessage) if cm4.Parse(oob) == nil && cm4.Dst != nil { return cm4.Dst } return nil } // correctSource takes oob data and returns new oob data with the Src equal to the Dst func correctSource(oob []byte) []byte { dst := parseDstFromOOB(oob) if dst == nil { return nil } // If the dst is definitely an IPv6, then use ipv6's ControlMessage to // respond otherwise use ipv4's because ipv6's marshal ignores ipv4 // addresses. if dst.To4() == nil { cm := new(ipv6.ControlMessage) cm.Src = dst oob = cm.Marshal() } else { cm := new(ipv4.ControlMessage) cm.Src = dst oob = cm.Marshal() } return oob }