Merge pull request #236 from jdef/serve-wrapper
add interfaces to allow packet-level inspection for pre/post processing
This commit is contained in:
commit
9de5f84650
73
server.go
73
server.go
|
@ -47,6 +47,7 @@ type response struct {
|
||||||
tcp *net.TCPConn // i/o connection if TCP was used
|
tcp *net.TCPConn // i/o connection if TCP was used
|
||||||
udpSession *SessionUDP // oob data to get egress interface right
|
udpSession *SessionUDP // oob data to get egress interface right
|
||||||
remoteAddr net.Addr // address of the client
|
remoteAddr net.Addr // address of the client
|
||||||
|
writer Writer // writer to output the raw DNS bits
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeMux is an DNS request multiplexer. It matches the
|
// ServeMux is an DNS request multiplexer. It matches the
|
||||||
|
@ -197,6 +198,43 @@ func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
|
||||||
DefaultServeMux.HandleFunc(pattern, handler)
|
DefaultServeMux.HandleFunc(pattern, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Writer writes raw DNS messages; each call to Write should send an entire message.
|
||||||
|
type Writer interface {
|
||||||
|
io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reader reads raw DNS messages; each call to ReadTCP or ReadUDP should return an entire message.
|
||||||
|
type Reader interface {
|
||||||
|
// ReadTCP reads a raw message from a TCP connection. Implementations may alter
|
||||||
|
// connection properties, for example the read-deadline.
|
||||||
|
ReadTCP(conn *net.TCPConn, timeout time.Duration) ([]byte, error)
|
||||||
|
// ReadUDP reads a raw message from a UDP connection. Implementations may alter
|
||||||
|
// connection properties, for example the read-deadline.
|
||||||
|
ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultReader is an adapter for the Server struct that implements the Reader interface
|
||||||
|
// using the readTCP and readUDP func of the embedded Server.
|
||||||
|
type defaultReader struct {
|
||||||
|
*Server
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dr *defaultReader) ReadTCP(conn *net.TCPConn, timeout time.Duration) ([]byte, error) {
|
||||||
|
return dr.readTCP(conn, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dr *defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
|
||||||
|
return dr.readUDP(conn, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader.
|
||||||
|
// Implementations should never return a nil Reader.
|
||||||
|
type DecorateReader func(Reader) Reader
|
||||||
|
|
||||||
|
// DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer.
|
||||||
|
// Implementations should never return a nil Writer.
|
||||||
|
type DecorateWriter func(Writer) Writer
|
||||||
|
|
||||||
// A Server defines parameters for running an DNS server.
|
// A Server defines parameters for running an DNS server.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
// Address to listen on, ":dns" if empty.
|
// Address to listen on, ":dns" if empty.
|
||||||
|
@ -225,6 +263,10 @@ type Server struct {
|
||||||
Unsafe bool
|
Unsafe bool
|
||||||
// If NotifyStartedFunc is set is is called, once the server has started listening.
|
// If NotifyStartedFunc is set is is called, once the server has started listening.
|
||||||
NotifyStartedFunc func()
|
NotifyStartedFunc func()
|
||||||
|
// DecorateReader is optional, allows customization of the process that reads raw DNS messages.
|
||||||
|
DecorateReader DecorateReader
|
||||||
|
// DecorateWriter is optional, allows customization of the process that writes raw DNS messages.
|
||||||
|
DecorateWriter DecorateWriter
|
||||||
|
|
||||||
// For graceful shutdown.
|
// For graceful shutdown.
|
||||||
stopUDP chan bool
|
stopUDP chan bool
|
||||||
|
@ -382,6 +424,11 @@ func (srv *Server) serveTCP(l *net.TCPListener) error {
|
||||||
srv.NotifyStartedFunc()
|
srv.NotifyStartedFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
reader := Reader(&defaultReader{srv})
|
||||||
|
if srv.DecorateReader != nil {
|
||||||
|
reader = srv.DecorateReader(reader)
|
||||||
|
}
|
||||||
|
|
||||||
handler := srv.Handler
|
handler := srv.Handler
|
||||||
if handler == nil {
|
if handler == nil {
|
||||||
handler = DefaultServeMux
|
handler = DefaultServeMux
|
||||||
|
@ -393,7 +440,7 @@ func (srv *Server) serveTCP(l *net.TCPListener) error {
|
||||||
if e != nil {
|
if e != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
m, e := srv.readTCP(rw, rtimeout)
|
m, e := reader.ReadTCP(rw, rtimeout)
|
||||||
select {
|
select {
|
||||||
case <-srv.stopTCP:
|
case <-srv.stopTCP:
|
||||||
return nil
|
return nil
|
||||||
|
@ -417,6 +464,11 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
|
||||||
srv.NotifyStartedFunc()
|
srv.NotifyStartedFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
reader := Reader(&defaultReader{srv})
|
||||||
|
if srv.DecorateReader != nil {
|
||||||
|
reader = srv.DecorateReader(reader)
|
||||||
|
}
|
||||||
|
|
||||||
handler := srv.Handler
|
handler := srv.Handler
|
||||||
if handler == nil {
|
if handler == nil {
|
||||||
handler = DefaultServeMux
|
handler = DefaultServeMux
|
||||||
|
@ -424,7 +476,7 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
|
||||||
rtimeout := srv.getReadTimeout()
|
rtimeout := srv.getReadTimeout()
|
||||||
// deadline is not used here
|
// deadline is not used here
|
||||||
for {
|
for {
|
||||||
m, s, e := srv.readUDP(l, rtimeout)
|
m, s, e := reader.ReadUDP(l, rtimeout)
|
||||||
select {
|
select {
|
||||||
case <-srv.stopUDP:
|
case <-srv.stopUDP:
|
||||||
return nil
|
return nil
|
||||||
|
@ -442,6 +494,12 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
|
||||||
// Serve a new connection.
|
// Serve a new connection.
|
||||||
func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t *net.TCPConn) {
|
func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t *net.TCPConn) {
|
||||||
w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s}
|
w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s}
|
||||||
|
if srv.DecorateWriter != nil {
|
||||||
|
w.writer = srv.DecorateWriter(w)
|
||||||
|
} else {
|
||||||
|
w.writer = w
|
||||||
|
}
|
||||||
|
|
||||||
q := 0
|
q := 0
|
||||||
defer func() {
|
defer func() {
|
||||||
if u != nil {
|
if u != nil {
|
||||||
|
@ -451,6 +509,11 @@ func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *Ses
|
||||||
srv.wgTCP.Done()
|
srv.wgTCP.Done()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
reader := Reader(&defaultReader{srv})
|
||||||
|
if srv.DecorateReader != nil {
|
||||||
|
reader = srv.DecorateReader(reader)
|
||||||
|
}
|
||||||
Redo:
|
Redo:
|
||||||
req := new(Msg)
|
req := new(Msg)
|
||||||
err := req.Unpack(m)
|
err := req.Unpack(m)
|
||||||
|
@ -490,7 +553,7 @@ Exit:
|
||||||
if srv.IdleTimeout != nil {
|
if srv.IdleTimeout != nil {
|
||||||
idleTimeout = srv.IdleTimeout()
|
idleTimeout = srv.IdleTimeout()
|
||||||
}
|
}
|
||||||
m, e := srv.readTCP(w.tcp, idleTimeout)
|
m, e := reader.ReadTCP(w.tcp, idleTimeout)
|
||||||
if e == nil {
|
if e == nil {
|
||||||
q++
|
q++
|
||||||
// TODO(miek): make this number configurable?
|
// TODO(miek): make this number configurable?
|
||||||
|
@ -562,7 +625,7 @@ func (w *response) WriteMsg(m *Msg) (err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = w.Write(data)
|
_, err = w.writer.Write(data)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -570,7 +633,7 @@ func (w *response) WriteMsg(m *Msg) (err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = w.Write(data)
|
_, err = w.writer.Write(data)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -397,3 +397,54 @@ func TestShutdownUDP(t *testing.T) {
|
||||||
t.Errorf("Could not shutdown test UDP server, %v", err)
|
t.Errorf("Could not shutdown test UDP server, %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ExampleFrameLengthWriter struct {
|
||||||
|
Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ExampleFrameLengthWriter) Write(m []byte) (int, error) {
|
||||||
|
fmt.Println("writing raw DNS message of length", len(m))
|
||||||
|
return e.Writer.Write(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleDecorateWriter() {
|
||||||
|
// instrument raw DNS message writing
|
||||||
|
wf := DecorateWriter(func(w Writer) Writer {
|
||||||
|
return &ExampleFrameLengthWriter{w}
|
||||||
|
})
|
||||||
|
|
||||||
|
// simple UDP server
|
||||||
|
pc, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
server := &Server{
|
||||||
|
PacketConn: pc,
|
||||||
|
DecorateWriter: wf,
|
||||||
|
}
|
||||||
|
|
||||||
|
waitLock := sync.Mutex{}
|
||||||
|
waitLock.Lock()
|
||||||
|
server.NotifyStartedFunc = waitLock.Unlock
|
||||||
|
defer server.Shutdown()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
server.ActivateAndServe()
|
||||||
|
pc.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
waitLock.Lock()
|
||||||
|
|
||||||
|
HandleFunc("miek.nl.", HelloServer)
|
||||||
|
|
||||||
|
c := new(Client)
|
||||||
|
m := new(Msg)
|
||||||
|
m.SetQuestion("miek.nl.", TypeTXT)
|
||||||
|
_, _, err = c.Exchange(m, pc.LocalAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("failed to exchange", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Output: writing raw DNS message of length 56
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue