diff --git a/server.go b/server.go index 66594d7f..e6772549 100644 --- a/server.go +++ b/server.go @@ -44,9 +44,10 @@ type response struct { tsigTimersOnly bool tsigRequestMAC string tsigSecret map[string]string // the tsig secrets - udp *net.UDPConn // i/o connection if UDP was used - tcp *net.TCPConn // i/o connection if TCP was used - remoteAddr net.Addr // address of the client + udp *UDPConn // i/o connection if UDP was used + udpSession *UDPSession + tcp *net.TCPConn // i/o connection if TCP was used + remoteAddr net.Addr // address of the client } // ServeMux is an DNS request multiplexer. It matches the @@ -242,7 +243,13 @@ func (srv *Server) ListenAndServe() error { if e != nil { return e } - return srv.serveUDP(l) + + ll, e := NewUDPConn(l) + if e != nil { + return e + } + + return srv.serveUDP(ll) } return &Error{err: "bad network"} } @@ -268,14 +275,14 @@ func (srv *Server) serveTCP(l *net.TCPListener) error { if e != nil { continue } - go srv.serve(rw.RemoteAddr(), handler, m, nil, rw) + go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw) } panic("dns: not reached") } // serveUDP starts a UDP listener for the server. // Each request is handled in a seperate goroutine. -func (srv *Server) serveUDP(l *net.UDPConn) error { +func (srv *Server) serveUDP(l *UDPConn) error { defer l.Close() handler := srv.Handler if handler == nil { @@ -286,19 +293,19 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { rtimeout = srv.ReadTimeout } for { - m, a, e := srv.readUDP(l, rtimeout) + m, s, e := srv.readUDP(l, rtimeout) if e != nil { // TODO(miek): logging? continue } - go srv.serve(a, handler, m, l, nil) + go srv.serve(s.RemoteAddr(), handler, m, l, s, nil) } panic("dns: not reached") } // Serve a new connection. -func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, t *net.TCPConn) { - w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a} +func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *UDPConn, s *UDPSession, t *net.TCPConn) { + w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s} q := 0 Redo: req := new(Msg) @@ -385,9 +392,9 @@ func (srv *Server) readTCP(conn *net.TCPConn, timeout time.Duration) ([]byte, er return m, nil } -func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, net.Addr, error) { +func (srv *Server) readUDP(conn *UDPConn, timeout time.Duration) ([]byte, *UDPSession, error) { m := make([]byte, srv.UDPSize) - n, a, e := conn.ReadFromUDP(m) + n, s, e := conn.ReadFromSessionUDP(m) if e != nil || n == 0 { if e != nil { return nil, nil, e @@ -395,7 +402,7 @@ func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, ne return nil, nil, ErrShortRead } m = m[:n] - return m, a, nil + return m, s, nil } // WriteMsg implements the ResponseWriter.WriteMsg method. @@ -423,7 +430,7 @@ func (w *response) WriteMsg(m *Msg) (err error) { func (w *response) Write(m []byte) (int, error) { switch { case w.udp != nil: - n, err := w.udp.WriteTo(m, w.remoteAddr) + n, err := w.udp.WriteToSessionUDP(m, w.udpSession) return n, err case w.tcp != nil: lm := len(m) diff --git a/udp.go b/udp.go new file mode 100644 index 00000000..acc84a61 --- /dev/null +++ b/udp.go @@ -0,0 +1,45 @@ +package dns + +import ( + "net" +) + +type UDPSession struct { + raddr *net.UDPAddr + context []byte +} + +func (session *UDPSession) RemoteAddr() net.Addr { + return session.raddr +} + +type UDPConn struct { + *net.UDPConn +} + +func NewUDPConn(conn *net.UDPConn) (newconn *UDPConn, err error) { + err = udpSocketOobData(conn) + if err != nil { + return + } + + return &UDPConn{conn}, nil +} + +func (conn *UDPConn) ReadFromSessionUDP(b []byte) (n int, session *UDPSession, err error) { + oob := make([]byte, 1024) + + n, oobn, _, raddr, err := conn.ReadMsgUDP(b, oob) + if err != nil { + return + } + + session = &UDPSession{raddr, oob[:oobn]} + + return +} + +func (conn *UDPConn) WriteToSessionUDP(b []byte, session *UDPSession) (n int, err error) { + n, _, err = conn.WriteMsgUDP(b, session.context, session.raddr) + return +} diff --git a/udp_linux.go b/udp_linux.go new file mode 100644 index 00000000..044b1eb2 --- /dev/null +++ b/udp_linux.go @@ -0,0 +1,19 @@ +// +build linux + +package dns + +import ( + "net" + "syscall" +) + +func udpSocketOobData(conn *net.UDPConn) (err error) { + file, err := conn.File() + if err != nil { + return + } + + err = syscall.SetsockoptInt(int(file.Fd()), syscall.IPPROTO_IP, syscall.IP_PKTINFO, 1) + + return +} diff --git a/udp_other.go b/udp_other.go new file mode 100644 index 00000000..0f7adfab --- /dev/null +++ b/udp_other.go @@ -0,0 +1,11 @@ +// +build !linux + +package dns + +import ( + "net" +) + +func udpSocketOobData(conn *net.UDPConn) (err error) { + return +}