diff --git a/server.go b/server.go index 5e4ec92b..36fee173 100644 --- a/server.go +++ b/server.go @@ -47,6 +47,7 @@ type response struct { tcp *net.TCPConn // i/o connection if TCP was used udpSession *SessionUDP // oob data to get egress interface right remoteAddr net.Addr // address of the client + writer Writer // writer to output the raw DNS bits } // ServeMux is an DNS request multiplexer. It matches the @@ -197,6 +198,31 @@ func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) { DefaultServeMux.HandleFunc(pattern, handler) } +type Writer interface { + io.Writer +} + +type Reader interface { + ReadTCP(conn *net.TCPConn, timeout time.Duration) ([]byte, error) + ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) +} + +type defaultReader struct { + *Server +} + +func (dr *defaultReader) ReadTCP(conn *net.TCPConn, timeout time.Duration) ([]byte, error) { + return dr.readTCP(conn, timeout) +} + +func (dr *defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) { + return dr.readUDP(conn, timeout) +} + +type ReaderBuilder func(Reader) Reader + +type WriterBuilder func(Writer) Writer + // A Server defines parameters for running an DNS server. type Server struct { // Address to listen on, ":dns" if empty. @@ -225,6 +251,10 @@ type Server struct { Unsafe bool // If NotifyStartedFunc is set is is called, once the server has started listening. NotifyStartedFunc func() + // ReaderBuilder is optional, allows customization of the process that reads DNS frames + ReaderBuilder ReaderBuilder + // WriterBuilder is optional, allows customization of the process that writes DNS frames + WriterBuilder WriterBuilder // For graceful shutdown. stopUDP chan bool @@ -382,6 +412,11 @@ func (srv *Server) serveTCP(l *net.TCPListener) error { srv.NotifyStartedFunc() } + reader := Reader(&defaultReader{srv}) + if srv.ReaderBuilder != nil { + reader = srv.ReaderBuilder(reader) + } + handler := srv.Handler if handler == nil { handler = DefaultServeMux @@ -393,7 +428,7 @@ func (srv *Server) serveTCP(l *net.TCPListener) error { if e != nil { continue } - m, e := srv.readTCP(rw, rtimeout) + m, e := reader.ReadTCP(rw, rtimeout) select { case <-srv.stopTCP: return nil @@ -417,6 +452,11 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { srv.NotifyStartedFunc() } + reader := Reader(&defaultReader{srv}) + if srv.ReaderBuilder != nil { + reader = srv.ReaderBuilder(reader) + } + handler := srv.Handler if handler == nil { handler = DefaultServeMux @@ -424,7 +464,7 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { rtimeout := srv.getReadTimeout() // deadline is not used here for { - m, s, e := srv.readUDP(l, rtimeout) + m, s, e := reader.ReadUDP(l, rtimeout) select { case <-srv.stopUDP: return nil @@ -442,6 +482,12 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { // Serve a new connection. func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t *net.TCPConn) { w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s} + if srv.WriterBuilder != nil { + w.writer = srv.WriterBuilder(w) + } else { + w.writer = w + } + q := 0 defer func() { if u != nil { @@ -451,6 +497,11 @@ func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *Ses srv.wgTCP.Done() } }() + + reader := Reader(&defaultReader{srv}) + if srv.ReaderBuilder != nil { + reader = srv.ReaderBuilder(reader) + } Redo: req := new(Msg) err := req.Unpack(m) @@ -490,7 +541,7 @@ Exit: if srv.IdleTimeout != nil { idleTimeout = srv.IdleTimeout() } - m, e := srv.readTCP(w.tcp, idleTimeout) + m, e := reader.ReadTCP(w.tcp, idleTimeout) if e == nil { q++ // TODO(miek): make this number configurable? @@ -562,7 +613,7 @@ func (w *response) WriteMsg(m *Msg) (err error) { if err != nil { return err } - _, err = w.Write(data) + _, err = w.writer.Write(data) return err } } @@ -570,7 +621,7 @@ func (w *response) WriteMsg(m *Msg) (err error) { if err != nil { return err } - _, err = w.Write(data) + _, err = w.writer.Write(data) return err }