diff --git a/server.go b/server.go index 5e4ec92b..6a74ea87 100644 --- a/server.go +++ b/server.go @@ -47,6 +47,7 @@ type response struct { tcp *net.TCPConn // i/o connection if TCP was used udpSession *SessionUDP // oob data to get egress interface right remoteAddr net.Addr // address of the client + writer Writer // writer to output the raw DNS bits } // ServeMux is an DNS request multiplexer. It matches the @@ -197,6 +198,43 @@ func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) { DefaultServeMux.HandleFunc(pattern, handler) } +// Writer writes raw DNS messages; each call to Write should send an entire message. +type Writer interface { + io.Writer +} + +// Reader reads raw DNS messages; each call to ReadTCP or ReadUDP should return an entire message. +type Reader interface { + // ReadTCP reads a raw message from a TCP connection. Implementations may alter + // connection properties, for example the read-deadline. + ReadTCP(conn *net.TCPConn, timeout time.Duration) ([]byte, error) + // ReadUDP reads a raw message from a UDP connection. Implementations may alter + // connection properties, for example the read-deadline. + ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) +} + +// defaultReader is an adapter for the Server struct that implements the Reader interface +// using the readTCP and readUDP func of the embedded Server. +type defaultReader struct { + *Server +} + +func (dr *defaultReader) ReadTCP(conn *net.TCPConn, timeout time.Duration) ([]byte, error) { + return dr.readTCP(conn, timeout) +} + +func (dr *defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) { + return dr.readUDP(conn, timeout) +} + +// DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader. +// Implementations should never return a nil Reader. +type DecorateReader func(Reader) Reader + +// DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer. +// Implementations should never return a nil Writer. +type DecorateWriter func(Writer) Writer + // A Server defines parameters for running an DNS server. type Server struct { // Address to listen on, ":dns" if empty. @@ -225,6 +263,10 @@ type Server struct { Unsafe bool // If NotifyStartedFunc is set is is called, once the server has started listening. NotifyStartedFunc func() + // DecorateReader is optional, allows customization of the process that reads raw DNS messages. + DecorateReader DecorateReader + // DecorateWriter is optional, allows customization of the process that writes raw DNS messages. + DecorateWriter DecorateWriter // For graceful shutdown. stopUDP chan bool @@ -382,6 +424,11 @@ func (srv *Server) serveTCP(l *net.TCPListener) error { srv.NotifyStartedFunc() } + reader := Reader(&defaultReader{srv}) + if srv.DecorateReader != nil { + reader = srv.DecorateReader(reader) + } + handler := srv.Handler if handler == nil { handler = DefaultServeMux @@ -393,7 +440,7 @@ func (srv *Server) serveTCP(l *net.TCPListener) error { if e != nil { continue } - m, e := srv.readTCP(rw, rtimeout) + m, e := reader.ReadTCP(rw, rtimeout) select { case <-srv.stopTCP: return nil @@ -417,6 +464,11 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { srv.NotifyStartedFunc() } + reader := Reader(&defaultReader{srv}) + if srv.DecorateReader != nil { + reader = srv.DecorateReader(reader) + } + handler := srv.Handler if handler == nil { handler = DefaultServeMux @@ -424,7 +476,7 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { rtimeout := srv.getReadTimeout() // deadline is not used here for { - m, s, e := srv.readUDP(l, rtimeout) + m, s, e := reader.ReadUDP(l, rtimeout) select { case <-srv.stopUDP: return nil @@ -442,6 +494,12 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { // Serve a new connection. func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t *net.TCPConn) { w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s} + if srv.DecorateWriter != nil { + w.writer = srv.DecorateWriter(w) + } else { + w.writer = w + } + q := 0 defer func() { if u != nil { @@ -451,6 +509,11 @@ func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *Ses srv.wgTCP.Done() } }() + + reader := Reader(&defaultReader{srv}) + if srv.DecorateReader != nil { + reader = srv.DecorateReader(reader) + } Redo: req := new(Msg) err := req.Unpack(m) @@ -490,7 +553,7 @@ Exit: if srv.IdleTimeout != nil { idleTimeout = srv.IdleTimeout() } - m, e := srv.readTCP(w.tcp, idleTimeout) + m, e := reader.ReadTCP(w.tcp, idleTimeout) if e == nil { q++ // TODO(miek): make this number configurable? @@ -562,7 +625,7 @@ func (w *response) WriteMsg(m *Msg) (err error) { if err != nil { return err } - _, err = w.Write(data) + _, err = w.writer.Write(data) return err } } @@ -570,7 +633,7 @@ func (w *response) WriteMsg(m *Msg) (err error) { if err != nil { return err } - _, err = w.Write(data) + _, err = w.writer.Write(data) return err } diff --git a/server_test.go b/server_test.go index dff0fb52..2ff606ac 100644 --- a/server_test.go +++ b/server_test.go @@ -397,3 +397,54 @@ func TestShutdownUDP(t *testing.T) { t.Errorf("Could not shutdown test UDP server, %v", err) } } + +type ExampleFrameLengthWriter struct { + Writer +} + +func (e *ExampleFrameLengthWriter) Write(m []byte) (int, error) { + fmt.Println("writing raw DNS message of length", len(m)) + return e.Writer.Write(m) +} + +func ExampleDecorateWriter() { + // instrument raw DNS message writing + wf := DecorateWriter(func(w Writer) Writer { + return &ExampleFrameLengthWriter{w} + }) + + // simple UDP server + pc, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + fmt.Println(err.Error()) + return + } + server := &Server{ + PacketConn: pc, + DecorateWriter: wf, + } + + waitLock := sync.Mutex{} + waitLock.Lock() + server.NotifyStartedFunc = waitLock.Unlock + defer server.Shutdown() + + go func() { + server.ActivateAndServe() + pc.Close() + }() + + waitLock.Lock() + + HandleFunc("miek.nl.", HelloServer) + + c := new(Client) + m := new(Msg) + m.SetQuestion("miek.nl.", TypeTXT) + _, _, err = c.Exchange(m, pc.LocalAddr().String()) + if err != nil { + fmt.Println("failed to exchange", err.Error()) + return + } + // Output: writing raw DNS message of length 56 +}