dns/server.go

349 lines
8.2 KiB
Go
Raw Normal View History

2011-02-08 20:25:01 +00:00
// Copyright 2011 Miek Gieben. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// DNS server implementation.
2011-02-08 20:25:01 +00:00
package dns
2011-02-08 20:25:01 +00:00
import (
2011-04-03 09:49:23 +00:00
"io"
2011-02-08 20:25:01 +00:00
"net"
)
2011-09-11 20:10:04 +00:00
// TODO(mg): decide how TSIG signing and verification should be hooked into request handling.
2011-04-01 11:15:36 +00:00
type Handler interface {
2011-04-03 09:49:23 +00:00
ServeDNS(w ResponseWriter, r *Msg)
// IP based ACL mapping. The contains the string representation
// of the IP address and a boolean saying it may connect (true) or not.
2011-04-01 11:15:36 +00:00
}
2011-04-03 09:14:54 +00:00
// ResponseWriter is the interface a DNS handler uses to send its
// reply to the client that made the request.
type ResponseWriter interface {
	// Write sends a reply (normally the output of Msg.Pack) back to the client.
	Write([]byte) (int, error)
	// RemoteAddr returns the net.Addr of the client that sent the current request.
	RemoteAddr() net.Addr
}
2011-09-11 20:10:04 +00:00
// port?
2011-04-02 07:22:05 +00:00
type conn struct {
2011-04-03 09:49:23 +00:00
remoteAddr net.Addr // address of remote side (sans port)
handler Handler // request handler
request []byte // bytes read
_UDP *net.UDPConn // i/o connection if UDP was used
_TCP *net.TCPConn // i/o connection if TCP was used
hijacked bool // connection has been hijacked by hander TODO(mg)
2011-04-02 07:22:05 +00:00
}
2011-04-01 08:53:31 +00:00
2011-04-02 07:22:05 +00:00
type response struct {
2011-04-03 09:49:23 +00:00
conn *conn
req *Msg
2011-04-01 08:53:31 +00:00
}
2011-04-01 11:15:36 +00:00
// ServeMux is an DNS request multiplexer. It matches the
// zone name of each incoming request against a list of
// registered patterns add calls the handler for the pattern
// that most closely matches the zone name.
type ServeMux struct {
2011-04-03 09:49:23 +00:00
m map[string]Handler
2011-04-01 11:15:36 +00:00
}
2011-04-02 07:22:05 +00:00
// NewServeMux allocates and returns a new ServeMux.
func NewServeMux() *ServeMux { return &ServeMux{make(map[string]Handler)} }
2011-04-01 11:15:36 +00:00
2011-04-02 07:22:05 +00:00
// DefaultServeMux is the default ServeMux used by Serve.
var DefaultServeMux = NewServeMux()
2011-04-01 11:15:36 +00:00
2011-04-02 07:22:05 +00:00
// The HandlerFunc type is an adapter to allow the use of
// ordinary functions as DNS handlers. If f is a function
// with the appropriate signature, HandlerFunc(f) is a
// Handler object that calls f.
type HandlerFunc func(ResponseWriter, *Msg)
2011-04-01 11:15:36 +00:00
2011-07-05 19:08:22 +00:00
// ServerDNS calls f(w, r)
2011-04-02 07:22:05 +00:00
func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
2011-04-03 09:49:23 +00:00
f(w, r)
2011-04-01 11:15:36 +00:00
}
2011-07-23 21:32:42 +00:00
// Helper handler that returns an answer with
// RCODE = refused for every request.
2011-04-03 11:43:46 +00:00
func Refused(w ResponseWriter, r *Msg) {
2011-04-18 20:08:12 +00:00
m := new(Msg)
2011-11-02 22:06:54 +00:00
m.SetRcode(r, RcodeRefused)
buf, _ := m.Pack()
2011-04-18 20:08:12 +00:00
w.Write(buf)
2011-04-03 11:43:46 +00:00
}
2011-04-01 11:15:36 +00:00
2011-07-23 21:32:42 +00:00
// RefusedHandler returns HandlerFunc with Refused.
2011-04-03 11:43:46 +00:00
func RefusedHandler() Handler { return HandlerFunc(Refused) }
2011-04-01 08:53:31 +00:00
2011-07-23 21:32:42 +00:00
// ...
2012-01-12 21:47:36 +00:00
func ListenAndServe(addr string, network string, handler Handler, size int) error {
server := &Server{Addr: addr, Net: network, Handler: handler, UDPSize: size}
2011-04-03 09:49:23 +00:00
return server.ListenAndServe()
2011-04-02 07:22:05 +00:00
}
2011-04-01 11:15:36 +00:00
func (mux *ServeMux) match(zone string) Handler {
2011-04-03 09:49:23 +00:00
var h Handler
var n = 0
for k, v := range mux.m {
if !zoneMatch(k, zone) {
continue
}
if h == nil || len(k) > n {
n = len(k)
h = v
}
}
return h
2011-04-01 11:15:36 +00:00
}
func (mux *ServeMux) Handle(pattern string, handler Handler) {
2011-04-03 09:49:23 +00:00
if pattern == "" {
panic("dns: invalid pattern " + pattern)
}
2011-07-04 20:27:23 +00:00
// Should this go
//if pattern[len(pattern)-1] != '.' { // no ending .
// mux.m[pattern+"."] = handler
//} else {
mux.m[pattern] = handler
2011-04-01 11:15:36 +00:00
}
2011-04-02 07:22:05 +00:00
func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
2011-04-03 09:49:23 +00:00
mux.Handle(pattern, HandlerFunc(handler))
2011-04-02 07:22:05 +00:00
}
// ServeDNS dispatches the request to the handler whose
// pattern most closely matches the request message.
func (mux *ServeMux) ServeDNS(w ResponseWriter, request *Msg) {
2011-04-03 09:49:23 +00:00
h := mux.match(request.Question[0].Name)
if h == nil {
2011-04-03 11:43:46 +00:00
h = RefusedHandler()
2011-04-03 09:49:23 +00:00
}
h.ServeDNS(w, request)
2011-04-02 07:22:05 +00:00
}
// Handle registers handler for pattern in DefaultServeMux. The
// documentation for ServeMux explains how patterns are matched.
func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) }
func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
2011-04-03 09:49:23 +00:00
DefaultServeMux.HandleFunc(pattern, handler)
2011-04-02 07:22:05 +00:00
}
2011-04-12 19:44:56 +00:00
// A Server defines parameters for running an DNS server.
2011-07-06 07:25:05 +00:00
// Note how much it starts to look like 'Client struct'
2011-04-12 19:44:56 +00:00
type Server struct {
2011-07-23 21:43:43 +00:00
Addr string // address to listen on, ":dns" if empty
Net string // if "tcp" it will invoke a TCP listener, otherwise an UDP one
Handler Handler // handler to invoke, dns.DefaultServeMux if nil
2012-01-12 21:47:36 +00:00
UDPSize int // default buffer to use to read incoming UDP messages
2011-07-23 21:43:43 +00:00
ReadTimeout int64 // the net.Conn.SetReadTimeout value for new connections
WriteTimeout int64 // the net.Conn.SetWriteTimeout value for new connections
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>
2011-04-12 19:44:56 +00:00
}
2012-01-12 21:47:36 +00:00
// ListenAndServe starts a nameserver on the configured address.
2011-11-02 22:06:54 +00:00
func (srv *Server) ListenAndServe() error {
2011-04-03 09:49:23 +00:00
addr := srv.Addr
if addr == "" {
addr = ":domain"
}
2011-04-18 20:08:12 +00:00
switch srv.Net {
2011-07-05 17:44:46 +00:00
case "tcp", "tcp4", "tcp6":
a, e := net.ResolveTCPAddr(srv.Net, addr)
2011-04-03 09:49:23 +00:00
if e != nil {
return e
}
2011-07-05 17:44:46 +00:00
l, e := net.ListenTCP(srv.Net, a)
2011-04-03 09:49:23 +00:00
if e != nil {
return e
}
return srv.ServeTCP(l)
2011-07-05 17:44:46 +00:00
case "udp", "udp4", "udp6":
a, e := net.ResolveUDPAddr(srv.Net, addr)
2011-04-03 09:49:23 +00:00
if e != nil {
return e
}
2011-07-05 17:44:46 +00:00
l, e := net.ListenUDP(srv.Net, a)
2011-04-03 09:49:23 +00:00
if e != nil {
return e
}
return srv.ServeUDP(l)
}
return nil // os.Error with wrong network
2011-04-02 07:22:05 +00:00
}
2011-11-02 22:06:54 +00:00
func (srv *Server) ServeTCP(l *net.TCPListener) error {
2011-04-03 09:49:23 +00:00
defer l.Close()
handler := srv.Handler
if handler == nil {
handler = DefaultServeMux
}
forever:
for {
rw, e := l.AcceptTCP()
if e != nil {
return e
}
if srv.ReadTimeout != 0 {
rw.SetReadTimeout(srv.ReadTimeout)
}
if srv.WriteTimeout != 0 {
rw.SetWriteTimeout(srv.WriteTimeout)
}
l := make([]byte, 2)
n, err := rw.Read(l)
if err != nil || n != 2 {
continue
}
length, _ := unpackUint16(l, 0)
if length == 0 {
continue
}
m := make([]byte, int(length))
n, err = rw.Read(m[:int(length)])
if err != nil {
continue
}
i := n
for i < int(length) {
j, err := rw.Read(m[i:int(length)])
if err != nil {
continue forever
}
i += j
}
n = i
d, err := newConn(rw, nil, rw.RemoteAddr(), m, handler)
if err != nil {
continue
}
go d.serve()
}
panic("not reached")
2011-04-02 07:22:05 +00:00
}
2011-11-02 22:06:54 +00:00
func (srv *Server) ServeUDP(l *net.UDPConn) error {
2011-04-03 09:49:23 +00:00
defer l.Close()
handler := srv.Handler
if handler == nil {
handler = DefaultServeMux
}
2012-01-12 21:47:36 +00:00
if srv.UDPSize == 0 {
2012-01-13 10:38:08 +00:00
srv.UDPSize = UDPReceiveMsgSize
2012-01-12 21:47:36 +00:00
}
2011-04-03 09:49:23 +00:00
for {
2012-01-12 21:34:53 +00:00
m := make([]byte, srv.UDPSize)
2011-04-03 09:49:23 +00:00
n, a, e := l.ReadFromUDP(m)
if e != nil {
return e
}
m = m[:n]
2011-04-02 07:22:05 +00:00
2011-04-03 09:49:23 +00:00
if srv.ReadTimeout != 0 {
l.SetReadTimeout(srv.ReadTimeout)
}
if srv.WriteTimeout != 0 {
l.SetWriteTimeout(srv.WriteTimeout)
}
d, err := newConn(nil, l, a, m, handler)
if err != nil {
continue
}
go d.serve()
}
panic("not reached")
2011-04-02 07:22:05 +00:00
}
2011-11-02 22:06:54 +00:00
func newConn(t *net.TCPConn, u *net.UDPConn, a net.Addr, buf []byte, handler Handler) (*conn, error) {
2011-04-03 11:16:33 +00:00
c := new(conn)
2011-04-03 09:49:23 +00:00
c.handler = handler
c._TCP = t
c._UDP = u
c.remoteAddr = a
c.request = buf
2011-04-03 11:16:33 +00:00
return c, nil
2011-04-02 07:22:05 +00:00
}
2011-04-03 09:14:54 +00:00
// Close the connection.
func (c *conn) close() {
2011-04-03 09:49:23 +00:00
switch {
case c._UDP != nil:
c._UDP.Close()
c._UDP = nil
case c._TCP != nil:
c._TCP.Close()
c._TCP = nil
}
2011-04-03 09:14:54 +00:00
}
// Serve a new connection.
2011-04-02 07:22:05 +00:00
func (c *conn) serve() {
2011-04-18 20:08:12 +00:00
for {
// Request has been read in ServeUDP or ServeTCP
w := new(response)
w.conn = c
req := new(Msg)
if !req.Unpack(c.request) {
2011-11-02 22:06:54 +00:00
// Send a format error back
x := new(Msg)
x.SetRcodeFormatError(req)
buf, _ := x.Pack()
w.Write(buf)
2011-04-18 20:08:12 +00:00
break
}
w.req = req
c.handler.ServeDNS(w, w.req) // this does the writing back to the client
if c.hijacked {
return
}
break // TODO(mg) Why is this a loop anyway
}
if c._TCP != nil {
c.close() // Listen and Serve is closed then
}
2011-04-02 07:22:05 +00:00
}
2011-11-02 22:06:54 +00:00
func (w *response) Write(data []byte) (n int, err error) {
2011-04-18 20:08:12 +00:00
switch {
case w.conn._UDP != nil:
n, err = w.conn._UDP.WriteTo(data, w.conn.remoteAddr)
if err != nil {
return 0, err
}
case w.conn._TCP != nil:
// TODO(mg) len(data) > 64K
l := make([]byte, 2)
l[0], l[1] = packUint16(uint16(len(data)))
n, err = w.conn._TCP.Write(l)
if err != nil {
return n, err
}
if n != 2 {
return n, io.ErrShortWrite
}
n, err = w.conn._TCP.Write(data)
if err != nil {
return n, err
}
i := n
if i < len(data) {
j, err := w.conn._TCP.Write(data[i:len(data)])
if err != nil {
return i, err
}
i += j
}
n = i
}
return n, nil
2011-04-03 09:49:23 +00:00
}
2011-04-03 09:14:54 +00:00
2011-04-03 09:49:23 +00:00
// RemoteAddr implements the ResponseWriter.RemoteAddr method
2011-07-05 17:17:29 +00:00
func (w *response) RemoteAddr() net.Addr { return w.conn.remoteAddr }