diff --git a/server.go b/server.go index bb0d074a..7680b69d 100644 --- a/server.go +++ b/server.go @@ -4,6 +4,7 @@ package dns import ( "bytes" + "crypto/tls" "io" "net" "sync" @@ -47,7 +48,7 @@ type response struct { tsigRequestMAC string tsigSecret map[string]string // the tsig secrets udp *net.UDPConn // i/o connection if UDP was used - tcp *net.TCPConn // i/o connection if TCP was used + tcp net.Conn // i/o connection if TCP was used udpSession *SessionUDP // oob data to get egress interface right remoteAddr net.Addr // address of the client writer Writer // writer to output the raw DNS bits @@ -210,7 +211,7 @@ type Writer interface { type Reader interface { // ReadTCP reads a raw message from a TCP connection. Implementations may alter // connection properties, for example the read-deadline. - ReadTCP(conn *net.TCPConn, timeout time.Duration) ([]byte, error) + ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) // ReadUDP reads a raw message from a UDP connection. Implementations may alter // connection properties, for example the read-deadline. ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) @@ -222,7 +223,7 @@ type defaultReader struct { *Server } -func (dr *defaultReader) ReadTCP(conn *net.TCPConn, timeout time.Duration) ([]byte, error) { +func (dr *defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) { return dr.readTCP(conn, timeout) } @@ -242,10 +243,12 @@ type DecorateWriter func(Writer) Writer type Server struct { // Address to listen on, ":dns" if empty. Addr string - // if "tcp" it will invoke a TCP listener, otherwise an UDP one. + // if "tcp" or "tcp-tls" (DNS over TLS) it will invoke a TCP listener, otherwise an UDP one Net string // TCP Listener to use, this is to aid in systemd's socket activation. Listener net.Listener + // TLS connection configuration + TLSConfig *tls.Config // UDP "Listener" to use, this is to aid in systemd's socket activation. PacketConn net.PacketConn // Handler to invoke, dns.DefaultServeMux if nil. @@ -309,6 +312,24 @@ func (srv *Server) ListenAndServe() error { e = srv.serveTCP(l) srv.lock.Lock() // to satisfy the defer at the top return e + case "tcp-tls", "tcp4-tls", "tcp6-tls": + network := "tcp" + if srv.Net == "tcp4-tls" { + network = "tcp4" + } else if srv.Net == "tcp6" { + network = "tcp6" + } + + l, e := tls.Listen(network, addr, srv.TLSConfig) + if e != nil { + return e + } + srv.Listener = l + srv.started = true + srv.lock.Unlock() + e = srv.serveTCP(l) + srv.lock.Lock() // to satisfy the defer at the top + return e case "udp", "udp4", "udp6": a, e := net.ResolveUDPAddr(srv.Net, addr) if e != nil { @@ -357,13 +378,11 @@ func (srv *Server) ActivateAndServe() error { } } if l != nil { - if t, ok := l.(*net.TCPListener); ok { - srv.started = true - srv.lock.Unlock() - e := srv.serveTCP(t) - srv.lock.Lock() // to satisfy the defer at the top - return e - } + srv.started = true + srv.lock.Unlock() + e := srv.serveTCP(l) + srv.lock.Lock() // to satisfy the defer at the top + return e } return &Error{err: "bad listeners"} } @@ -413,7 +432,7 @@ func (srv *Server) getReadTimeout() time.Duration { // serveTCP starts a TCP listener for the server. // Each request is handled in a separate goroutine. -func (srv *Server) serveTCP(l *net.TCPListener) error { +func (srv *Server) serveTCP(l net.Listener) error { defer l.Close() if srv.NotifyStartedFunc != nil { @@ -432,7 +451,7 @@ func (srv *Server) serveTCP(l *net.TCPListener) error { rtimeout := srv.getReadTimeout() // deadline is not used here for { - rw, e := l.AcceptTCP() + rw, e := l.Accept() if e != nil { if neterr, ok := e.(net.Error); ok && neterr.Temporary() { continue @@ -491,7 +510,7 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { } // Serve a new connection. -func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t *net.TCPConn) { +func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t net.Conn) { defer srv.inFlight.Done() w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s} @@ -564,7 +583,7 @@ Exit: return } -func (srv *Server) readTCP(conn *net.TCPConn, timeout time.Duration) ([]byte, error) { +func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) { conn.SetReadDeadline(time.Now().Add(timeout)) l := make([]byte, 2) n, err := conn.Read(l) diff --git a/server_test.go b/server_test.go index c7ae3737..1b5cbc97 100644 --- a/server_test.go +++ b/server_test.go @@ -1,6 +1,7 @@ package dns import ( + "crypto/tls" "fmt" "io" "net" @@ -109,6 +110,27 @@ func RunLocalTCPServer(laddr string) (*Server, string, error) { return server, l.Addr().String(), nil } +func RunLocalTLSServer(laddr string, config *tls.Config) (*Server, string, error) { + l, err := tls.Listen("tcp", laddr, config) + if err != nil { + return nil, "", err + } + + server := &Server{Listener: l, ReadTimeout: time.Hour, WriteTimeout: time.Hour} + + waitLock := sync.Mutex{} + waitLock.Lock() + server.NotifyStartedFunc = waitLock.Unlock + + go func() { + server.ActivateAndServe() + l.Close() + }() + + waitLock.Lock() + return server, l.Addr().String(), nil +} + func TestServing(t *testing.T) { HandleFunc("miek.nl.", HelloServer) HandleFunc("example.com.", AnotherHelloServer) @@ -155,6 +177,66 @@ func TestServing(t *testing.T) { } } +func TestServingTLS(t *testing.T) { + HandleFunc("miek.nl.", HelloServer) + HandleFunc("example.com.", AnotherHelloServer) + defer HandleRemove("miek.nl.") + defer HandleRemove("example.com.") + + cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) + if err != nil { + t.Fatalf("unable to build certificate: %v", err) + } + + config := tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + s, addrstr, err := RunLocalTLSServer("127.0.0.1:0", &config) + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + c := new(Client) + c.Net = "tcp-tls" + c.TLSConfig = &tls.Config{ + InsecureSkipVerify: true, + } + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeTXT) + r, _, err := c.Exchange(m, addrstr) + if err != nil || len(r.Extra) == 0 { + t.Fatal("failed to exchange miek.nl", err) + } + txt := r.Extra[0].(*TXT).Txt[0] + if txt != "Hello world" { + t.Error("unexpected result for miek.nl", txt, "!= Hello world") + } + + m.SetQuestion("example.com.", TypeTXT) + r, _, err = c.Exchange(m, addrstr) + if err != nil { + t.Fatal("failed to exchange example.com", err) + } + txt = r.Extra[0].(*TXT).Txt[0] + if txt != "Hello example" { + t.Error("unexpected result for example.com", txt, "!= Hello example") + } + + // Test Mixes cased as noticed by Ask. + m.SetQuestion("eXaMplE.cOm.", TypeTXT) + r, _, err = c.Exchange(m, addrstr) + if err != nil { + t.Error("failed to exchange eXaMplE.cOm", err) + } + txt = r.Extra[0].(*TXT).Txt[0] + if txt != "Hello example" { + t.Error("unexpected result for example.com", txt, "!= Hello example") + } +} + func BenchmarkServe(b *testing.B) { b.StopTimer() HandleFunc("miek.nl.", HelloServer) @@ -399,6 +481,26 @@ func TestShutdownTCP(t *testing.T) { } } +func TestShutdownTLS(t *testing.T) { + cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) + if err != nil { + t.Fatalf("unable to build certificate: %v", err) + } + + config := tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + s, _, err := RunLocalTLSServer("127.0.0.1:0", &config) + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + err = s.Shutdown() + if err != nil { + t.Errorf("could not shutdown test TLS server, %v", err) + } +} + type trigger struct { done bool sync.RWMutex @@ -523,3 +625,55 @@ func ExampleDecorateWriter() { } // Output: writing raw DNS message of length 56 } + +var ( + // CertPEMBlock is a X509 data used to test TLS servers (used with tls.X509KeyPair) + CertPEMBlock = []byte(`-----BEGIN CERTIFICATE----- +MIIDAzCCAeugAwIBAgIRAJFYMkcn+b8dpU15wjf++GgwDQYJKoZIhvcNAQELBQAw +EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xNjAxMDgxMjAzNTNaFw0xNzAxMDcxMjAz +NTNaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw +ggEKAoIBAQDXjqO6skvP03k58CNjQggd9G/mt+Wa+xRU+WXiKCCHttawM8x+slq5 +yfsHCwxlwsGn79HmJqecNqgHb2GWBXAvVVokFDTcC1hUP4+gp2gu9Ny27UHTjlLm +O0l/xZ5MN8tfKyYlFw18tXu3fkaPyHj8v/D1RDkuo4ARdFvGSe8TqisbhLk2+9ow +xfIGbEM9Fdiw8qByC2+d+FfvzIKz3GfQVwn0VoRom8L6NBIANq1IGrB5JefZB6nv +DnfuxkBmY7F1513HKuEJ8KsLWWZWV9OPU4j4I4Rt+WJNlKjbD2srHxyrS2RDsr91 +8nCkNoWVNO3sZq0XkWKecdc921vL4ginAgMBAAGjVDBSMA4GA1UdDwEB/wQEAwIC +pDATBgNVHSUEDDAKBggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MBoGA1UdEQQT +MBGCCWxvY2FsaG9zdIcEfwAAATANBgkqhkiG9w0BAQsFAAOCAQEAGcU3iyLBIVZj +aDzSvEDHUd1bnLBl1C58Xu/CyKlPqVU7mLfK0JcgEaYQTSX6fCJVNLbbCrcGLsPJ +fbjlBbyeLjTV413fxPVuona62pBFjqdtbli2Qe8FRH2KBdm41JUJGdo+SdsFu7nc +BFOcubdw6LLIXvsTvwndKcHWx1rMX709QU1Vn1GAIsbJV/DWI231Jyyb+lxAUx/C +8vce5uVxiKcGS+g6OjsN3D3TtiEQGSXLh013W6Wsih8td8yMCMZ3w8LQ38br1GUe +ahLIgUJ9l6HDguM17R7kGqxNvbElsMUHfTtXXP7UDQUiYXDakg8xDP6n9DCDhJ8Y +bSt7OLB7NQ== +-----END CERTIFICATE-----`) + + // KeyPEMBlock is a X509 data used to test TLS servers (used with tls.X509KeyPair) + KeyPEMBlock = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEA146jurJLz9N5OfAjY0IIHfRv5rflmvsUVPll4iggh7bWsDPM +frJaucn7BwsMZcLBp+/R5iannDaoB29hlgVwL1VaJBQ03AtYVD+PoKdoLvTctu1B +045S5jtJf8WeTDfLXysmJRcNfLV7t35Gj8h4/L/w9UQ5LqOAEXRbxknvE6orG4S5 +NvvaMMXyBmxDPRXYsPKgcgtvnfhX78yCs9xn0FcJ9FaEaJvC+jQSADatSBqweSXn +2Qep7w537sZAZmOxdeddxyrhCfCrC1lmVlfTj1OI+COEbfliTZSo2w9rKx8cq0tk +Q7K/dfJwpDaFlTTt7GatF5FinnHXPdtby+IIpwIDAQABAoIBAAJK4RDmPooqTJrC +JA41MJLo+5uvjwCT9QZmVKAQHzByUFw1YNJkITTiognUI0CdzqNzmH7jIFs39ZeG +proKusO2G6xQjrNcZ4cV2fgyb5g4QHStl0qhs94A+WojduiGm2IaumAgm6Mc5wDv +ld6HmknN3Mku/ZCyanVFEIjOVn2WB7ZQLTBs6ZYaebTJG2Xv6p9t2YJW7pPQ9Xce +s9ohAWohyM4X/OvfnfnLtQp2YLw/BxwehBsCR5SXM3ibTKpFNtxJC8hIfTuWtxZu +2ywrmXShYBRB1WgtZt5k04bY/HFncvvcHK3YfI1+w4URKtwdaQgPUQRbVwDwuyBn +flfkCJECgYEA/eWt01iEyE/lXkGn6V9lCocUU7lCU6yk5UT8VXVUc5If4KZKPfCk +p4zJDOqwn2eM673aWz/mG9mtvAvmnugaGjcaVCyXOp/D/GDmKSoYcvW5B/yjfkLy +dK6Yaa5LDRVYlYgyzcdCT5/9Qc626NzFwKCZNI4ncIU8g7ViATRxWJ8CgYEA2Ver +vZ0M606sfgC0H3NtwNBxmuJ+lIF5LNp/wDi07lDfxRR1rnZMX5dnxjcpDr/zvm8J +WtJJX3xMgqjtHuWKL3yKKony9J5ZPjichSbSbhrzfovgYIRZLxLLDy4MP9L3+CX/ +yBXnqMWuSnFX+M5fVGxdDWiYF3V+wmeOv9JvavkCgYEAiXAPDFzaY+R78O3xiu7M +r0o3wqqCMPE/wav6O/hrYrQy9VSO08C0IM6g9pEEUwWmzuXSkZqhYWoQFb8Lc/GI +T7CMXAxXQLDDUpbRgG79FR3Wr3AewHZU8LyiXHKwxcBMV4WGmsXGK3wbh8fyU1NO +6NsGk+BvkQVOoK1LBAPzZ1kCgYEAsBSmD8U33T9s4dxiEYTrqyV0lH3g/SFz8ZHH +pAyNEPI2iC1ONhyjPWKlcWHpAokiyOqeUpVBWnmSZtzC1qAydsxYB6ShT+sl9BHb +RMix/QAauzBJhQhUVJ3OIys0Q1UBDmqCsjCE8SfOT4NKOUnA093C+YT+iyrmmktZ +zDCJkckCgYEAndqM5KXGk5xYo+MAA1paZcbTUXwaWwjLU+XSRSSoyBEi5xMtfvUb +7+a1OMhLwWbuz+pl64wFKrbSUyimMOYQpjVE/1vk/kb99pxbgol27hdKyTH1d+ov +kFsxKCqxAnBVGEWAvVZAiiTOxleQFjz5RnL0BQp9Lg2cQe+dvuUmIAA= +-----END RSA PRIVATE KEY-----`) +)