it compiles + more tweaks
This commit is contained in:
parent
c7dbb1edc2
commit
064bfe4f2e
441
server.go
441
server.go
|
@ -7,46 +7,48 @@
|
|||
package dns
|
||||
|
||||
import (
|
||||
// "io"
|
||||
"io"
|
||||
"os"
|
||||
"net"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
ServeDNS(w ResponseWriter, r *Msg)
|
||||
ServeDNS(w ResponseWriter, r *Msg)
|
||||
}
|
||||
|
||||
// TODO(mg): fit axfr responses in here too
|
||||
// A ResponseWriter interface is used by an DNS handler to
|
||||
// construct an DNS response.
|
||||
type ResponseWriter interface {
|
||||
// RemoteAddr returns the address of the client that sent the current request
|
||||
RemoteAddr() string
|
||||
// RemoteAddr returns the address of the client that sent the current request
|
||||
RemoteAddr() string
|
||||
|
||||
Write([]byte) (int, os.Error)
|
||||
Write([]byte) (int, os.Error)
|
||||
|
||||
// IP based ACL mapping. The contains the string representation
|
||||
// of the IP address and a boolean saying it may connect (true) or not.
|
||||
Acl() map[string]bool
|
||||
// IP based ACL mapping. The contains the string representation
|
||||
// of the IP address and a boolean saying it may connect (true) or not.
|
||||
Acl() map[string]bool
|
||||
|
||||
// Tsig secrets. Its a mapping of key names to secrets.
|
||||
Tsig() map[string]string
|
||||
// Tsig secrets. Its a mapping of key names to secrets.
|
||||
Tsig() map[string]string
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
remoteAddr net.Addr // address of remote side (sans port)
|
||||
port int // port of the remote side
|
||||
handler Handler // request handler
|
||||
request []byte // bytes read
|
||||
_UDP *net.UDPConn // i/o connection if UDP was used
|
||||
_TCP *net.TCPConn // i/o connection if TCP was used
|
||||
hijacked bool // connection has been hijacked by hander TODO(mg)
|
||||
remoteAddr net.Addr // address of remote side (sans port)
|
||||
port int // port of the remote side, needed TODO(mg)
|
||||
handler Handler // request handler
|
||||
request []byte // bytes read
|
||||
_UDP *net.UDPConn // i/o connection if UDP was used
|
||||
_TCP *net.TCPConn // i/o connection if TCP was used
|
||||
hijacked bool // connection has been hijacked by hander TODO(mg)
|
||||
tsig map[string]string // tsig secrets
|
||||
acl map[string]bool // ip acl list
|
||||
}
|
||||
|
||||
type response struct {
|
||||
conn *conn
|
||||
req *Msg
|
||||
xfr bool // {i/a}xfr was requested
|
||||
conn *conn
|
||||
req *Msg
|
||||
xfr bool // {i/a}xfr was requested
|
||||
}
|
||||
|
||||
// ServeMux is an DNS request multiplexer. It matches the
|
||||
|
@ -54,7 +56,7 @@ type response struct {
|
|||
// registered patterns add calls the handler for the pattern
|
||||
// that most closely matches the zone name.
|
||||
type ServeMux struct {
|
||||
m map[string]Handler
|
||||
m map[string]Handler
|
||||
}
|
||||
|
||||
// NewServeMux allocates and returns a new ServeMux.
|
||||
|
@ -71,7 +73,7 @@ type HandlerFunc func(ResponseWriter, *Msg)
|
|||
|
||||
// ServerDNS calls f(w, reg)
|
||||
func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
|
||||
f(w, r)
|
||||
f(w, r)
|
||||
}
|
||||
|
||||
// Helper handlers
|
||||
|
@ -100,8 +102,8 @@ func HandleUDP(l *net.UDPConn, f func(*Conn, *Msg)) os.Error {
|
|||
m = m[:n]
|
||||
|
||||
d := new(Conn)
|
||||
// Use the remote addr as we got from ReadFromUDP
|
||||
d.SetUDPConn(l, addr)
|
||||
// Use the remote addr as we got from ReadFromUDP
|
||||
d.SetUDPConn(l, addr)
|
||||
|
||||
msg := new(Msg)
|
||||
if !msg.Unpack(m) {
|
||||
|
@ -123,7 +125,7 @@ func HandleTCP(l *net.TCPListener, f func(*Conn, *Msg)) os.Error {
|
|||
return e
|
||||
}
|
||||
d := new(Conn)
|
||||
d.SetTCPConn(c, nil)
|
||||
d.SetTCPConn(c, nil)
|
||||
|
||||
msg := new(Msg)
|
||||
err := d.ReadMsg(msg)
|
||||
|
@ -138,53 +140,53 @@ func HandleTCP(l *net.TCPListener, f func(*Conn, *Msg)) os.Error {
|
|||
}
|
||||
|
||||
func ListenAndServe(addr string, network string, handler Handler) os.Error {
|
||||
server := &Server{Addr: addr, Network: network, Handler: handler}
|
||||
return server.ListenAndServe()
|
||||
server := &Server{Addr: addr, Network: network, Handler: handler}
|
||||
return server.ListenAndServe()
|
||||
|
||||
}
|
||||
|
||||
func zoneMatch(pattern, zone string) bool {
|
||||
if len(pattern) == 0 {
|
||||
return false
|
||||
}
|
||||
n := len(pattern)
|
||||
return zone[:n] == pattern
|
||||
if len(pattern) == 0 {
|
||||
return false
|
||||
}
|
||||
n := len(pattern)
|
||||
return zone[:n] == pattern
|
||||
}
|
||||
|
||||
func (mux *ServeMux) match(zone string) Handler {
|
||||
var h Handler
|
||||
var n = 0
|
||||
for k, v := range mux.m {
|
||||
if !zoneMatch(k, zone) {
|
||||
continue
|
||||
}
|
||||
if h == nil || len(k) > n {
|
||||
n = len(k)
|
||||
h = v
|
||||
}
|
||||
}
|
||||
return h
|
||||
var h Handler
|
||||
var n = 0
|
||||
for k, v := range mux.m {
|
||||
if !zoneMatch(k, zone) {
|
||||
continue
|
||||
}
|
||||
if h == nil || len(k) > n {
|
||||
n = len(k)
|
||||
h = v
|
||||
}
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func (mux *ServeMux) Handle(pattern string, handler Handler) {
|
||||
if pattern == "" {
|
||||
panic("dns: invalid pattern " + pattern)
|
||||
}
|
||||
mux.m[pattern] = handler
|
||||
if pattern == "" {
|
||||
panic("dns: invalid pattern " + pattern)
|
||||
}
|
||||
mux.m[pattern] = handler
|
||||
}
|
||||
|
||||
func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
|
||||
mux.Handle(pattern, HandlerFunc(handler))
|
||||
mux.Handle(pattern, HandlerFunc(handler))
|
||||
}
|
||||
|
||||
// ServeDNS dispatches the request to the handler whose
|
||||
// pattern most closely matches the request message.
|
||||
func (mux *ServeMux) ServeDNS(w ResponseWriter, request *Msg) {
|
||||
h := mux.match(request.Question[0].Name)
|
||||
if h == nil {
|
||||
// h = NotFoundHandler()
|
||||
}
|
||||
h.ServeDNS(w, request)
|
||||
h := mux.match(request.Question[0].Name)
|
||||
if h == nil {
|
||||
// h = NotFoundHandler()
|
||||
}
|
||||
h.ServeDNS(w, request)
|
||||
}
|
||||
|
||||
// Handle register the handler the given pattern
|
||||
|
@ -193,7 +195,7 @@ func (mux *ServeMux) ServeDNS(w ResponseWriter, request *Msg) {
|
|||
func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) }
|
||||
|
||||
func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
|
||||
DefaultServeMux.HandleFunc(pattern, handler)
|
||||
DefaultServeMux.HandleFunc(pattern, handler)
|
||||
}
|
||||
|
||||
// Serve accepts incoming DNS request on the TCP listener l,
|
||||
|
@ -201,8 +203,8 @@ func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
|
|||
// read requests and then call handler to reply to them.
|
||||
// Handler is typically nil, in which case the DefaultServeMux is used.
|
||||
func ServeTCP(l *net.TCPListener, handler Handler) os.Error {
|
||||
srv := &Server{Handler: handler, Network: "tcp"}
|
||||
return srv.ServeTCP(l)
|
||||
srv := &Server{Handler: handler, Network: "tcp"}
|
||||
return srv.ServeTCP(l)
|
||||
}
|
||||
|
||||
// Serve accepts incoming DNS request on the UDP Conn l,
|
||||
|
@ -210,171 +212,222 @@ func ServeTCP(l *net.TCPListener, handler Handler) os.Error {
|
|||
// read requests and then call handler to reply to them.
|
||||
// Handler is typically nil, in which case the DefaultServeMux is used.
|
||||
func ServeUDP(l *net.UDPConn, handler Handler) os.Error {
|
||||
srv := &Server{Handler: handler, Network: "udp"}
|
||||
return srv.ServeUDP(l)
|
||||
srv := &Server{Handler: handler, Network: "udp"}
|
||||
return srv.ServeUDP(l)
|
||||
}
|
||||
|
||||
// A Server defines parameters for running an HTTP server.
|
||||
type Server struct {
|
||||
Addr string // address to listen on, ":dns" if empty
|
||||
Network string // If "tcp" it will invoke a TCP listener, otherwise an UDP one
|
||||
Handler Handler // handler to invoke, http.DefaultServeMux if nil
|
||||
ReadTimeout int64 // the net.Conn.SetReadTimeout value for new connections
|
||||
WriteTimeout int64 // the net.Conn.SetWriteTimeout value for new connections
|
||||
Addr string // address to listen on, ":dns" if empty
|
||||
Network string // If "tcp" it will invoke a TCP listener, otherwise an UDP one
|
||||
Handler Handler // handler to invoke, http.DefaultServeMux if nil
|
||||
ReadTimeout int64 // the net.Conn.SetReadTimeout value for new connections
|
||||
WriteTimeout int64 // the net.Conn.SetWriteTimeout value for new connections
|
||||
}
|
||||
|
||||
// Fixes for udp/tcp
|
||||
func (srv *Server) ListenAndServe() os.Error {
|
||||
addr := srv.Addr
|
||||
if addr == "" {
|
||||
addr = ":domain"
|
||||
}
|
||||
switch srv.Network {
|
||||
case "tcp":
|
||||
a, e := net.ResolveTCPAddr(addr)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
l, e := net.ListenTCP("tcp", a)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
return srv.ServeTCP(l)
|
||||
case "udp":
|
||||
a, e := net.ResolveUDPAddr(addr)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
l, e := net.ListenUDP("udp", a)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
return srv.ServeUDP(l)
|
||||
}
|
||||
return nil // os.Error with wrong network
|
||||
addr := srv.Addr
|
||||
if addr == "" {
|
||||
addr = ":domain"
|
||||
}
|
||||
switch srv.Network {
|
||||
case "tcp":
|
||||
a, e := net.ResolveTCPAddr(addr)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
l, e := net.ListenTCP("tcp", a)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
return srv.ServeTCP(l)
|
||||
case "udp":
|
||||
a, e := net.ResolveUDPAddr(addr)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
l, e := net.ListenUDP("udp", a)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
return srv.ServeUDP(l)
|
||||
}
|
||||
return nil // os.Error with wrong network
|
||||
}
|
||||
|
||||
func (srv *Server) ServeTCP(l *net.TCPListener) os.Error {
|
||||
defer l.Close()
|
||||
handler := srv.Handler
|
||||
if handler == nil {
|
||||
handler = DefaultServeMux
|
||||
}
|
||||
forever:
|
||||
for {
|
||||
rw, e := l.AcceptTCP()
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
if srv.ReadTimeout != 0 {
|
||||
rw.SetReadTimeout(srv.ReadTimeout)
|
||||
}
|
||||
if srv.WriteTimeout != 0 {
|
||||
rw.SetWriteTimeout(srv.WriteTimeout)
|
||||
}
|
||||
l := make([]byte, 2)
|
||||
n, err := rw.Read(l)
|
||||
if err != nil || n != 2 {
|
||||
continue
|
||||
}
|
||||
length, _ := unpackUint16(l, 0)
|
||||
if length == 0 {
|
||||
continue
|
||||
}
|
||||
m := make([]byte, int(length))
|
||||
n, err = rw.Read(m[:int(length)])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
i := n
|
||||
for i < int(length) {
|
||||
j, err := rw.Read(m[i:int(length)])
|
||||
if err != nil {
|
||||
continue forever
|
||||
}
|
||||
i += j
|
||||
}
|
||||
n = i
|
||||
d, err := newConn(rw, nil, rw.RemoteAddr(), m, handler)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
go d.serve()
|
||||
}
|
||||
panic("not reached")
|
||||
defer l.Close()
|
||||
handler := srv.Handler
|
||||
if handler == nil {
|
||||
handler = DefaultServeMux
|
||||
}
|
||||
forever:
|
||||
for {
|
||||
rw, e := l.AcceptTCP()
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
if srv.ReadTimeout != 0 {
|
||||
rw.SetReadTimeout(srv.ReadTimeout)
|
||||
}
|
||||
if srv.WriteTimeout != 0 {
|
||||
rw.SetWriteTimeout(srv.WriteTimeout)
|
||||
}
|
||||
l := make([]byte, 2)
|
||||
n, err := rw.Read(l)
|
||||
if err != nil || n != 2 {
|
||||
continue
|
||||
}
|
||||
length, _ := unpackUint16(l, 0)
|
||||
if length == 0 {
|
||||
continue
|
||||
}
|
||||
m := make([]byte, int(length))
|
||||
n, err = rw.Read(m[:int(length)])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
i := n
|
||||
for i < int(length) {
|
||||
j, err := rw.Read(m[i:int(length)])
|
||||
if err != nil {
|
||||
continue forever
|
||||
}
|
||||
i += j
|
||||
}
|
||||
n = i
|
||||
d, err := newConn(rw, nil, rw.RemoteAddr(), m, handler)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
go d.serve()
|
||||
}
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func (srv *Server) ServeUDP(l *net.UDPConn) os.Error {
|
||||
defer l.Close()
|
||||
handler := srv.Handler
|
||||
if handler == nil {
|
||||
handler = DefaultServeMux
|
||||
}
|
||||
for {
|
||||
m := make([]byte, DefaultMsgSize)
|
||||
n, a, e := l.ReadFromUDP(m)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
m = m[:n]
|
||||
defer l.Close()
|
||||
handler := srv.Handler
|
||||
if handler == nil {
|
||||
handler = DefaultServeMux
|
||||
}
|
||||
for {
|
||||
m := make([]byte, DefaultMsgSize)
|
||||
n, a, e := l.ReadFromUDP(m)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
m = m[:n]
|
||||
|
||||
if srv.ReadTimeout != 0 {
|
||||
l.SetReadTimeout(srv.ReadTimeout)
|
||||
}
|
||||
if srv.WriteTimeout != 0 {
|
||||
l.SetWriteTimeout(srv.WriteTimeout)
|
||||
}
|
||||
d, err := newConn(nil, l, a, m, handler)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
go d.serve()
|
||||
}
|
||||
panic("not reached")
|
||||
if srv.ReadTimeout != 0 {
|
||||
l.SetReadTimeout(srv.ReadTimeout)
|
||||
}
|
||||
if srv.WriteTimeout != 0 {
|
||||
l.SetWriteTimeout(srv.WriteTimeout)
|
||||
}
|
||||
d, err := newConn(nil, l, a, m, handler)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
go d.serve()
|
||||
}
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func newConn(t *net.TCPConn, u *net.UDPConn, a net.Addr, buf []byte, handler Handler) (c *conn, err os.Error) {
|
||||
c = new(conn)
|
||||
c.handler = handler
|
||||
c._TCP = t
|
||||
c._UDP = u
|
||||
c.remoteAddr = a
|
||||
c.request = buf
|
||||
if t != nil {
|
||||
c.port = a.(*net.TCPAddr).Port
|
||||
}
|
||||
if u != nil {
|
||||
c.port = a.(*net.UDPAddr).Port
|
||||
}
|
||||
return c, err
|
||||
c = new(conn)
|
||||
c.handler = handler
|
||||
c._TCP = t
|
||||
c._UDP = u
|
||||
c.remoteAddr = a
|
||||
c.request = buf
|
||||
if t != nil {
|
||||
c.port = a.(*net.TCPAddr).Port
|
||||
}
|
||||
if u != nil {
|
||||
c.port = a.(*net.UDPAddr).Port
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
|
||||
|
||||
// Close the connection.
|
||||
func (c *conn) close() {
|
||||
switch {
|
||||
case c._UDP != nil:
|
||||
c._UDP.Close()
|
||||
c._UDP = nil
|
||||
case c._TCP != nil:
|
||||
c._TCP.Close()
|
||||
c._TCP = nil
|
||||
}
|
||||
switch {
|
||||
case c._UDP != nil:
|
||||
c._UDP.Close()
|
||||
c._UDP = nil
|
||||
case c._TCP != nil:
|
||||
c._TCP.Close()
|
||||
c._TCP = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Serve a new connection.
|
||||
func (c *conn) serve() {
|
||||
// c.ReadRequest
|
||||
|
||||
// c.Handler.ServeDNS(w, w.req) // this does the writing
|
||||
// Request has been read in ServeUDP or ServeTCP
|
||||
w := new(response)
|
||||
w.conn = c
|
||||
w.xfr = false
|
||||
req := new(Msg)
|
||||
if !req.Unpack(c.request) {
|
||||
return
|
||||
}
|
||||
c.handler.ServeDNS(w, w.req) // this does the writing back to the client
|
||||
if c.hijacked {
|
||||
return
|
||||
}
|
||||
c.close()
|
||||
}
|
||||
|
||||
|
||||
func (c *conn) readRequest() (w *response, err os.Error) {
|
||||
|
||||
|
||||
|
||||
|
||||
w = new(response)
|
||||
return w, nil
|
||||
func (w *response) Write(data []byte) (n int, err os.Error) {
|
||||
switch {
|
||||
case w.conn._UDP != nil:
|
||||
n, err = w.conn._UDP.WriteTo(data, w.conn.remoteAddr)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
case w.conn._TCP != nil:
|
||||
// TODO(mg) len(data) > 64K
|
||||
l := make([]byte, 2)
|
||||
l[0], l[1] = packUint16(uint16(len(data)))
|
||||
n, err = w.conn._TCP.Write(data)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
if n != 2 {
|
||||
return n, io.ErrShortWrite
|
||||
}
|
||||
n, err = w.conn._TCP.Write(data)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
i := n
|
||||
if i < len(data) {
|
||||
j, err := w.conn._TCP.Write(data[i:len(data)])
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
i += j
|
||||
}
|
||||
n = i
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Acl implements the ResponseWriter.Acl
|
||||
func (w *response) Acl() map[string]bool {
|
||||
return w.conn.acl
|
||||
}
|
||||
|
||||
|
||||
// Tsig implements the ResponseWriter.Tsig
|
||||
func (w *response) Tsig() map[string]string {
|
||||
return w.conn.tsig
|
||||
}
|
||||
|
||||
// RemoteAddr implements the ResponseWriter.RemoteAddr method
|
||||
func (w *response) RemoteAddr() string { return w.conn.remoteAddr.String() }
|
||||
|
||||
|
|
Loading…
Reference in New Issue