From 833bf76c282d338e307ff7ec181b95cfc117deb2 Mon Sep 17 00:00:00 2001 From: chantra Date: Sat, 22 Sep 2018 10:34:55 -0700 Subject: [PATCH] [tls] Carry TLS state within (possibly) response writer (#728) * [tls] Carry TLS state within (possibly) response writer This allows a server to make decision wether or not the link used to connect to the DNS server is using TLS. This can be used by the handler for instance to (but not limited to): - log that the request was TLS vs TCP - craft specific responsed knowing that the link is secured - return custom answers based on client cert (if provided) ... Fixes #711 * Address @tmthrgd comments: - do not check whether w.tcp is nil - create RR after setting txt value * Address @miekg comments. Attempt to make a TLS connection state specific test, it goes over testing each individual server types (TLS, TCP, UDP) and validate that tls.Connectionstate is only accessible when expected. * ConnectionState() returns value instead of pointer * * make ConnectionStater.ConnectionState() return a pointer again * rename interface ConnectionState to ConnectionStater * fix nits pointed by @tmthrgd * @tmthrgd comment: Do not use concret type in `ConnectionState` --- server.go | 18 ++++++++++ server_test.go | 92 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+) diff --git a/server.go b/server.go index 4fbf7db6..2901f872 100644 --- a/server.go +++ b/server.go @@ -63,6 +63,12 @@ type ResponseWriter interface { Hijack() } +// A ConnectionStater interface is used by a DNS Handler to access TLS connection state +// when available. +type ConnectionStater interface { + ConnectionState() *tls.ConnectionState +} + type response struct { msg []byte hijacked bool // connection has been hijacked by handler @@ -894,3 +900,15 @@ func (w *response) Close() error { } return nil } + +// ConnectionState() implements the ConnectionStater.ConnectionState() interface. +func (w *response) ConnectionState() *tls.ConnectionState { + type tlsConnectionStater interface { + ConnectionState() tls.ConnectionState + } + if v, ok := w.tcp.(tlsConnectionStater); ok { + t := v.ConnectionState() + return &t + } + return nil +} diff --git a/server_test.go b/server_test.go index c53edd03..d60f0cd0 100644 --- a/server_test.go +++ b/server_test.go @@ -263,6 +263,98 @@ func TestServingTLS(t *testing.T) { } } +// TestServingTLSConnectionState tests that we only can access +// tls.ConnectionState under a DNS query handled by a TLS DNS server. +// This test will sequentially create a TLS, UDP and TCP server, attach a custom +// handler which will set a testing error if tls.ConnectionState is available +// when it is not expected, or the other way around. +func TestServingTLSConnectionState(t *testing.T) { + handlerResponse := "Hello example" + // tlsHandlerTLS is a HandlerFunc that can be set to expect or not TLS + // connection state. + tlsHandlerTLS := func(tlsExpected bool) func(ResponseWriter, *Msg) { + return func(w ResponseWriter, req *Msg) { + m := new(Msg) + m.SetReply(req) + tlsFound := true + if connState := w.(ConnectionStater).ConnectionState(); connState == nil { + tlsFound = false + } + if tlsFound != tlsExpected { + t.Errorf("TLS connection state available: %t, expected: %t", tlsFound, tlsExpected) + } + m.Extra = make([]RR, 1) + m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{handlerResponse}} + w.WriteMsg(m) + } + } + + // Question used in tests + m := new(Msg) + m.SetQuestion("tlsstate.example.net.", TypeTXT) + + // TLS DNS server + HandleFunc(".", tlsHandlerTLS(true)) + cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) + if err != nil { + t.Fatalf("unable to build certificate: %v", err) + } + + config := tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + s, addrstr, err := RunLocalTLSServer(":0", &config) + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + // TLS DNS query + c := &Client{ + Net: "tcp-tls", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } + + _, _, err = c.Exchange(m, addrstr) + if err != nil { + t.Error("failed to exchange tlsstate.example.net", err) + } + + HandleRemove(".") + // UDP DNS Server + HandleFunc(".", tlsHandlerTLS(false)) + defer HandleRemove(".") + s, addrstr, err = RunLocalUDPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + // UDP DNS query + c = new(Client) + _, _, err = c.Exchange(m, addrstr) + if err != nil { + t.Error("failed to exchange tlsstate.example.net", err) + } + + // TCP DNS Server + s, addrstr, err = RunLocalTCPServer(":0") + if err != nil { + t.Fatalf("unable to run test server: %v", err) + } + defer s.Shutdown() + + // TCP DNS query + c = &Client{Net: "tcp"} + _, _, err = c.Exchange(m, addrstr) + if err != nil { + t.Error("failed to exchange tlsstate.example.net", err) + } +} + func TestServingListenAndServe(t *testing.T) { HandleFunc("example.com.", AnotherHelloServer) defer HandleRemove("example.com.")