add interfaces to allow packet-level inspection for pre/post processing
This commit is contained in:
parent
16c1d54948
commit
e0f83dee9a
61
server.go
61
server.go
|
@ -47,6 +47,7 @@ type response struct {
|
|||
tcp *net.TCPConn // i/o connection if TCP was used
|
||||
udpSession *SessionUDP // oob data to get egress interface right
|
||||
remoteAddr net.Addr // address of the client
|
||||
writer Writer // writer to output the raw DNS bits
|
||||
}
|
||||
|
||||
// ServeMux is an DNS request multiplexer. It matches the
|
||||
|
@ -197,6 +198,31 @@ func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
|
|||
DefaultServeMux.HandleFunc(pattern, handler)
|
||||
}
|
||||
|
||||
type Writer interface {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
type Reader interface {
|
||||
ReadTCP(conn *net.TCPConn, timeout time.Duration) ([]byte, error)
|
||||
ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
|
||||
}
|
||||
|
||||
type defaultReader struct {
|
||||
*Server
|
||||
}
|
||||
|
||||
func (dr *defaultReader) ReadTCP(conn *net.TCPConn, timeout time.Duration) ([]byte, error) {
|
||||
return dr.readTCP(conn, timeout)
|
||||
}
|
||||
|
||||
func (dr *defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
|
||||
return dr.readUDP(conn, timeout)
|
||||
}
|
||||
|
||||
type ReaderBuilder func(Reader) Reader
|
||||
|
||||
type WriterBuilder func(Writer) Writer
|
||||
|
||||
// A Server defines parameters for running an DNS server.
|
||||
type Server struct {
|
||||
// Address to listen on, ":dns" if empty.
|
||||
|
@ -225,6 +251,10 @@ type Server struct {
|
|||
Unsafe bool
|
||||
// If NotifyStartedFunc is set is is called, once the server has started listening.
|
||||
NotifyStartedFunc func()
|
||||
// ReaderBuilder is optional, allows customization of the process that reads DNS frames
|
||||
ReaderBuilder ReaderBuilder
|
||||
// WriterBuilder is optional, allows customization of the process that writes DNS frames
|
||||
WriterBuilder WriterBuilder
|
||||
|
||||
// For graceful shutdown.
|
||||
stopUDP chan bool
|
||||
|
@ -382,6 +412,11 @@ func (srv *Server) serveTCP(l *net.TCPListener) error {
|
|||
srv.NotifyStartedFunc()
|
||||
}
|
||||
|
||||
reader := Reader(&defaultReader{srv})
|
||||
if srv.ReaderBuilder != nil {
|
||||
reader = srv.ReaderBuilder(reader)
|
||||
}
|
||||
|
||||
handler := srv.Handler
|
||||
if handler == nil {
|
||||
handler = DefaultServeMux
|
||||
|
@ -393,7 +428,7 @@ func (srv *Server) serveTCP(l *net.TCPListener) error {
|
|||
if e != nil {
|
||||
continue
|
||||
}
|
||||
m, e := srv.readTCP(rw, rtimeout)
|
||||
m, e := reader.ReadTCP(rw, rtimeout)
|
||||
select {
|
||||
case <-srv.stopTCP:
|
||||
return nil
|
||||
|
@ -417,6 +452,11 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
|
|||
srv.NotifyStartedFunc()
|
||||
}
|
||||
|
||||
reader := Reader(&defaultReader{srv})
|
||||
if srv.ReaderBuilder != nil {
|
||||
reader = srv.ReaderBuilder(reader)
|
||||
}
|
||||
|
||||
handler := srv.Handler
|
||||
if handler == nil {
|
||||
handler = DefaultServeMux
|
||||
|
@ -424,7 +464,7 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
|
|||
rtimeout := srv.getReadTimeout()
|
||||
// deadline is not used here
|
||||
for {
|
||||
m, s, e := srv.readUDP(l, rtimeout)
|
||||
m, s, e := reader.ReadUDP(l, rtimeout)
|
||||
select {
|
||||
case <-srv.stopUDP:
|
||||
return nil
|
||||
|
@ -442,6 +482,12 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
|
|||
// Serve a new connection.
|
||||
func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t *net.TCPConn) {
|
||||
w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s}
|
||||
if srv.WriterBuilder != nil {
|
||||
w.writer = srv.WriterBuilder(w)
|
||||
} else {
|
||||
w.writer = w
|
||||
}
|
||||
|
||||
q := 0
|
||||
defer func() {
|
||||
if u != nil {
|
||||
|
@ -451,6 +497,11 @@ func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *Ses
|
|||
srv.wgTCP.Done()
|
||||
}
|
||||
}()
|
||||
|
||||
reader := Reader(&defaultReader{srv})
|
||||
if srv.ReaderBuilder != nil {
|
||||
reader = srv.ReaderBuilder(reader)
|
||||
}
|
||||
Redo:
|
||||
req := new(Msg)
|
||||
err := req.Unpack(m)
|
||||
|
@ -490,7 +541,7 @@ Exit:
|
|||
if srv.IdleTimeout != nil {
|
||||
idleTimeout = srv.IdleTimeout()
|
||||
}
|
||||
m, e := srv.readTCP(w.tcp, idleTimeout)
|
||||
m, e := reader.ReadTCP(w.tcp, idleTimeout)
|
||||
if e == nil {
|
||||
q++
|
||||
// TODO(miek): make this number configurable?
|
||||
|
@ -562,7 +613,7 @@ func (w *response) WriteMsg(m *Msg) (err error) {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.Write(data)
|
||||
_, err = w.writer.Write(data)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -570,7 +621,7 @@ func (w *response) WriteMsg(m *Msg) (err error) {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.Write(data)
|
||||
_, err = w.writer.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue