it compiles + more tweaks

This commit is contained in:
Miek Gieben 2011-04-03 11:49:23 +02:00
parent c7dbb1edc2
commit 064bfe4f2e
1 changed files with 247 additions and 194 deletions

441
server.go
View File

@ -7,46 +7,48 @@
package dns package dns
import ( import (
// "io" "io"
"os" "os"
"net" "net"
) )
type Handler interface { type Handler interface {
ServeDNS(w ResponseWriter, r *Msg) ServeDNS(w ResponseWriter, r *Msg)
} }
// TODO(mg): fit axfr responses in here too // TODO(mg): fit axfr responses in here too
// A ResponseWriter interface is used by an DNS handler to // A ResponseWriter interface is used by an DNS handler to
// construct an DNS response. // construct an DNS response.
type ResponseWriter interface { type ResponseWriter interface {
// RemoteAddr returns the address of the client that sent the current request // RemoteAddr returns the address of the client that sent the current request
RemoteAddr() string RemoteAddr() string
Write([]byte) (int, os.Error) Write([]byte) (int, os.Error)
// IP based ACL mapping. The contains the string representation // IP based ACL mapping. The contains the string representation
// of the IP address and a boolean saying it may connect (true) or not. // of the IP address and a boolean saying it may connect (true) or not.
Acl() map[string]bool Acl() map[string]bool
// Tsig secrets. Its a mapping of key names to secrets. // Tsig secrets. Its a mapping of key names to secrets.
Tsig() map[string]string Tsig() map[string]string
} }
type conn struct { type conn struct {
remoteAddr net.Addr // address of remote side (sans port) remoteAddr net.Addr // address of remote side (sans port)
port int // port of the remote side port int // port of the remote side, needed TODO(mg)
handler Handler // request handler handler Handler // request handler
request []byte // bytes read request []byte // bytes read
_UDP *net.UDPConn // i/o connection if UDP was used _UDP *net.UDPConn // i/o connection if UDP was used
_TCP *net.TCPConn // i/o connection if TCP was used _TCP *net.TCPConn // i/o connection if TCP was used
hijacked bool // connection has been hijacked by hander TODO(mg) hijacked bool // connection has been hijacked by hander TODO(mg)
tsig map[string]string // tsig secrets
acl map[string]bool // ip acl list
} }
type response struct { type response struct {
conn *conn conn *conn
req *Msg req *Msg
xfr bool // {i/a}xfr was requested xfr bool // {i/a}xfr was requested
} }
// ServeMux is an DNS request multiplexer. It matches the // ServeMux is an DNS request multiplexer. It matches the
@ -54,7 +56,7 @@ type response struct {
// registered patterns add calls the handler for the pattern // registered patterns add calls the handler for the pattern
// that most closely matches the zone name. // that most closely matches the zone name.
type ServeMux struct { type ServeMux struct {
m map[string]Handler m map[string]Handler
} }
// NewServeMux allocates and returns a new ServeMux. // NewServeMux allocates and returns a new ServeMux.
@ -71,7 +73,7 @@ type HandlerFunc func(ResponseWriter, *Msg)
// ServerDNS calls f(w, reg) // ServerDNS calls f(w, reg)
func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) { func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
f(w, r) f(w, r)
} }
// Helper handlers // Helper handlers
@ -100,8 +102,8 @@ func HandleUDP(l *net.UDPConn, f func(*Conn, *Msg)) os.Error {
m = m[:n] m = m[:n]
d := new(Conn) d := new(Conn)
// Use the remote addr as we got from ReadFromUDP // Use the remote addr as we got from ReadFromUDP
d.SetUDPConn(l, addr) d.SetUDPConn(l, addr)
msg := new(Msg) msg := new(Msg)
if !msg.Unpack(m) { if !msg.Unpack(m) {
@ -123,7 +125,7 @@ func HandleTCP(l *net.TCPListener, f func(*Conn, *Msg)) os.Error {
return e return e
} }
d := new(Conn) d := new(Conn)
d.SetTCPConn(c, nil) d.SetTCPConn(c, nil)
msg := new(Msg) msg := new(Msg)
err := d.ReadMsg(msg) err := d.ReadMsg(msg)
@ -138,53 +140,53 @@ func HandleTCP(l *net.TCPListener, f func(*Conn, *Msg)) os.Error {
} }
func ListenAndServe(addr string, network string, handler Handler) os.Error { func ListenAndServe(addr string, network string, handler Handler) os.Error {
server := &Server{Addr: addr, Network: network, Handler: handler} server := &Server{Addr: addr, Network: network, Handler: handler}
return server.ListenAndServe() return server.ListenAndServe()
} }
func zoneMatch(pattern, zone string) bool { func zoneMatch(pattern, zone string) bool {
if len(pattern) == 0 { if len(pattern) == 0 {
return false return false
} }
n := len(pattern) n := len(pattern)
return zone[:n] == pattern return zone[:n] == pattern
} }
func (mux *ServeMux) match(zone string) Handler { func (mux *ServeMux) match(zone string) Handler {
var h Handler var h Handler
var n = 0 var n = 0
for k, v := range mux.m { for k, v := range mux.m {
if !zoneMatch(k, zone) { if !zoneMatch(k, zone) {
continue continue
} }
if h == nil || len(k) > n { if h == nil || len(k) > n {
n = len(k) n = len(k)
h = v h = v
} }
} }
return h return h
} }
func (mux *ServeMux) Handle(pattern string, handler Handler) { func (mux *ServeMux) Handle(pattern string, handler Handler) {
if pattern == "" { if pattern == "" {
panic("dns: invalid pattern " + pattern) panic("dns: invalid pattern " + pattern)
} }
mux.m[pattern] = handler mux.m[pattern] = handler
} }
func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) { func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
mux.Handle(pattern, HandlerFunc(handler)) mux.Handle(pattern, HandlerFunc(handler))
} }
// ServeDNS dispatches the request to the handler whose // ServeDNS dispatches the request to the handler whose
// pattern most closely matches the request message. // pattern most closely matches the request message.
func (mux *ServeMux) ServeDNS(w ResponseWriter, request *Msg) { func (mux *ServeMux) ServeDNS(w ResponseWriter, request *Msg) {
h := mux.match(request.Question[0].Name) h := mux.match(request.Question[0].Name)
if h == nil { if h == nil {
// h = NotFoundHandler() // h = NotFoundHandler()
} }
h.ServeDNS(w, request) h.ServeDNS(w, request)
} }
// Handle register the handler the given pattern // Handle register the handler the given pattern
@ -193,7 +195,7 @@ func (mux *ServeMux) ServeDNS(w ResponseWriter, request *Msg) {
func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) } func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) }
func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) { func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
DefaultServeMux.HandleFunc(pattern, handler) DefaultServeMux.HandleFunc(pattern, handler)
} }
// Serve accepts incoming DNS request on the TCP listener l, // Serve accepts incoming DNS request on the TCP listener l,
@ -201,8 +203,8 @@ func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
// read requests and then call handler to reply to them. // read requests and then call handler to reply to them.
// Handler is typically nil, in which case the DefaultServeMux is used. // Handler is typically nil, in which case the DefaultServeMux is used.
func ServeTCP(l *net.TCPListener, handler Handler) os.Error { func ServeTCP(l *net.TCPListener, handler Handler) os.Error {
srv := &Server{Handler: handler, Network: "tcp"} srv := &Server{Handler: handler, Network: "tcp"}
return srv.ServeTCP(l) return srv.ServeTCP(l)
} }
// Serve accepts incoming DNS request on the UDP Conn l, // Serve accepts incoming DNS request on the UDP Conn l,
@ -210,171 +212,222 @@ func ServeTCP(l *net.TCPListener, handler Handler) os.Error {
// read requests and then call handler to reply to them. // read requests and then call handler to reply to them.
// Handler is typically nil, in which case the DefaultServeMux is used. // Handler is typically nil, in which case the DefaultServeMux is used.
func ServeUDP(l *net.UDPConn, handler Handler) os.Error { func ServeUDP(l *net.UDPConn, handler Handler) os.Error {
srv := &Server{Handler: handler, Network: "udp"} srv := &Server{Handler: handler, Network: "udp"}
return srv.ServeUDP(l) return srv.ServeUDP(l)
} }
// A Server defines parameters for running an HTTP server. // A Server defines parameters for running an HTTP server.
type Server struct { type Server struct {
Addr string // address to listen on, ":dns" if empty Addr string // address to listen on, ":dns" if empty
Network string // If "tcp" it will invoke a TCP listener, otherwise an UDP one Network string // If "tcp" it will invoke a TCP listener, otherwise an UDP one
Handler Handler // handler to invoke, http.DefaultServeMux if nil Handler Handler // handler to invoke, http.DefaultServeMux if nil
ReadTimeout int64 // the net.Conn.SetReadTimeout value for new connections ReadTimeout int64 // the net.Conn.SetReadTimeout value for new connections
WriteTimeout int64 // the net.Conn.SetWriteTimeout value for new connections WriteTimeout int64 // the net.Conn.SetWriteTimeout value for new connections
} }
// Fixes for udp/tcp // Fixes for udp/tcp
func (srv *Server) ListenAndServe() os.Error { func (srv *Server) ListenAndServe() os.Error {
addr := srv.Addr addr := srv.Addr
if addr == "" { if addr == "" {
addr = ":domain" addr = ":domain"
} }
switch srv.Network { switch srv.Network {
case "tcp": case "tcp":
a, e := net.ResolveTCPAddr(addr) a, e := net.ResolveTCPAddr(addr)
if e != nil { if e != nil {
return e return e
} }
l, e := net.ListenTCP("tcp", a) l, e := net.ListenTCP("tcp", a)
if e != nil { if e != nil {
return e return e
} }
return srv.ServeTCP(l) return srv.ServeTCP(l)
case "udp": case "udp":
a, e := net.ResolveUDPAddr(addr) a, e := net.ResolveUDPAddr(addr)
if e != nil { if e != nil {
return e return e
} }
l, e := net.ListenUDP("udp", a) l, e := net.ListenUDP("udp", a)
if e != nil { if e != nil {
return e return e
} }
return srv.ServeUDP(l) return srv.ServeUDP(l)
} }
return nil // os.Error with wrong network return nil // os.Error with wrong network
} }
func (srv *Server) ServeTCP(l *net.TCPListener) os.Error { func (srv *Server) ServeTCP(l *net.TCPListener) os.Error {
defer l.Close() defer l.Close()
handler := srv.Handler handler := srv.Handler
if handler == nil { if handler == nil {
handler = DefaultServeMux handler = DefaultServeMux
} }
forever: forever:
for { for {
rw, e := l.AcceptTCP() rw, e := l.AcceptTCP()
if e != nil { if e != nil {
return e return e
} }
if srv.ReadTimeout != 0 { if srv.ReadTimeout != 0 {
rw.SetReadTimeout(srv.ReadTimeout) rw.SetReadTimeout(srv.ReadTimeout)
} }
if srv.WriteTimeout != 0 { if srv.WriteTimeout != 0 {
rw.SetWriteTimeout(srv.WriteTimeout) rw.SetWriteTimeout(srv.WriteTimeout)
} }
l := make([]byte, 2) l := make([]byte, 2)
n, err := rw.Read(l) n, err := rw.Read(l)
if err != nil || n != 2 { if err != nil || n != 2 {
continue continue
} }
length, _ := unpackUint16(l, 0) length, _ := unpackUint16(l, 0)
if length == 0 { if length == 0 {
continue continue
} }
m := make([]byte, int(length)) m := make([]byte, int(length))
n, err = rw.Read(m[:int(length)]) n, err = rw.Read(m[:int(length)])
if err != nil { if err != nil {
continue continue
} }
i := n i := n
for i < int(length) { for i < int(length) {
j, err := rw.Read(m[i:int(length)]) j, err := rw.Read(m[i:int(length)])
if err != nil { if err != nil {
continue forever continue forever
} }
i += j i += j
} }
n = i n = i
d, err := newConn(rw, nil, rw.RemoteAddr(), m, handler) d, err := newConn(rw, nil, rw.RemoteAddr(), m, handler)
if err != nil { if err != nil {
continue continue
} }
go d.serve() go d.serve()
} }
panic("not reached") panic("not reached")
} }
func (srv *Server) ServeUDP(l *net.UDPConn) os.Error { func (srv *Server) ServeUDP(l *net.UDPConn) os.Error {
defer l.Close() defer l.Close()
handler := srv.Handler handler := srv.Handler
if handler == nil { if handler == nil {
handler = DefaultServeMux handler = DefaultServeMux
} }
for { for {
m := make([]byte, DefaultMsgSize) m := make([]byte, DefaultMsgSize)
n, a, e := l.ReadFromUDP(m) n, a, e := l.ReadFromUDP(m)
if e != nil { if e != nil {
return e return e
} }
m = m[:n] m = m[:n]
if srv.ReadTimeout != 0 { if srv.ReadTimeout != 0 {
l.SetReadTimeout(srv.ReadTimeout) l.SetReadTimeout(srv.ReadTimeout)
} }
if srv.WriteTimeout != 0 { if srv.WriteTimeout != 0 {
l.SetWriteTimeout(srv.WriteTimeout) l.SetWriteTimeout(srv.WriteTimeout)
} }
d, err := newConn(nil, l, a, m, handler) d, err := newConn(nil, l, a, m, handler)
if err != nil { if err != nil {
continue continue
} }
go d.serve() go d.serve()
} }
panic("not reached") panic("not reached")
} }
func newConn(t *net.TCPConn, u *net.UDPConn, a net.Addr, buf []byte, handler Handler) (c *conn, err os.Error) { func newConn(t *net.TCPConn, u *net.UDPConn, a net.Addr, buf []byte, handler Handler) (c *conn, err os.Error) {
c = new(conn) c = new(conn)
c.handler = handler c.handler = handler
c._TCP = t c._TCP = t
c._UDP = u c._UDP = u
c.remoteAddr = a c.remoteAddr = a
c.request = buf c.request = buf
if t != nil { if t != nil {
c.port = a.(*net.TCPAddr).Port c.port = a.(*net.TCPAddr).Port
} }
if u != nil { if u != nil {
c.port = a.(*net.UDPAddr).Port c.port = a.(*net.UDPAddr).Port
} }
return c, err return c, err
} }
// Close the connection. // Close the connection.
func (c *conn) close() { func (c *conn) close() {
switch { switch {
case c._UDP != nil: case c._UDP != nil:
c._UDP.Close() c._UDP.Close()
c._UDP = nil c._UDP = nil
case c._TCP != nil: case c._TCP != nil:
c._TCP.Close() c._TCP.Close()
c._TCP = nil c._TCP = nil
} }
} }
// Serve a new connection. // Serve a new connection.
func (c *conn) serve() { func (c *conn) serve() {
// c.ReadRequest // Request has been read in ServeUDP or ServeTCP
w := new(response)
// c.Handler.ServeDNS(w, w.req) // this does the writing w.conn = c
w.xfr = false
req := new(Msg)
if !req.Unpack(c.request) {
return
}
c.handler.ServeDNS(w, w.req) // this does the writing back to the client
if c.hijacked {
return
}
c.close()
} }
func (c *conn) readRequest() (w *response, err os.Error) { func (w *response) Write(data []byte) (n int, err os.Error) {
switch {
case w.conn._UDP != nil:
n, err = w.conn._UDP.WriteTo(data, w.conn.remoteAddr)
if err != nil {
w = new(response) return 0, err
return w, nil }
case w.conn._TCP != nil:
// TODO(mg) len(data) > 64K
l := make([]byte, 2)
l[0], l[1] = packUint16(uint16(len(data)))
n, err = w.conn._TCP.Write(data)
if err != nil {
return n, err
}
if n != 2 {
return n, io.ErrShortWrite
}
n, err = w.conn._TCP.Write(data)
if err != nil {
return n, err
}
i := n
if i < len(data) {
j, err := w.conn._TCP.Write(data[i:len(data)])
if err != nil {
return i, err
}
i += j
}
n = i
}
return n, nil
} }
// Acl implements the ResponseWriter.Acl
func (w *response) Acl() map[string]bool {
return w.conn.acl
}
// Tsig implements the ResponseWriter.Tsig
func (w *response) Tsig() map[string]string {
return w.conn.tsig
}
// RemoteAddr implements the ResponseWriter.RemoteAddr method
func (w *response) RemoteAddr() string { return w.conn.remoteAddr.String() }