[tls] Carry TLS state within (possibly) response writer (#728)

* [tls] Carry TLS state within (possibly) response writer

This allows a server to make decision wether or not the link used to
connect to the DNS server is using TLS.
This can be used by the handler for instance to (but not limited to):
- log that the request was TLS vs TCP
- craft specific responsed knowing that the link is secured
- return custom answers based on client cert (if provided)
...

Fixes #711

* Address @tmthrgd comments:
- do not check whether w.tcp is nil
- create RR after setting txt value

* Address @miekg comments.

Attempt to make a TLS connection state specific test, it goes over
testing each individual server types (TLS, TCP, UDP) and validate that
tls.Connectionstate is only accessible when expected.

* ConnectionState() returns value instead of pointer

* * make ConnectionStater.ConnectionState() return a pointer again
* rename interface ConnectionState to ConnectionStater
* fix nits pointed by @tmthrgd

* @tmthrgd comment: Do not use concret type in `ConnectionState`
This commit is contained in:
chantra 2018-09-22 10:34:55 -07:00 committed by Miek Gieben
parent 426ea785a9
commit 833bf76c28
2 changed files with 110 additions and 0 deletions

View File

@ -63,6 +63,12 @@ type ResponseWriter interface {
Hijack()
}
// A ConnectionStater interface is used by a DNS Handler to access TLS connection state
// when available.
type ConnectionStater interface {
ConnectionState() *tls.ConnectionState
}
type response struct {
msg []byte
hijacked bool // connection has been hijacked by handler
@ -894,3 +900,15 @@ func (w *response) Close() error {
}
return nil
}
// ConnectionState() implements the ConnectionStater.ConnectionState() interface.
func (w *response) ConnectionState() *tls.ConnectionState {
type tlsConnectionStater interface {
ConnectionState() tls.ConnectionState
}
if v, ok := w.tcp.(tlsConnectionStater); ok {
t := v.ConnectionState()
return &t
}
return nil
}

View File

@ -263,6 +263,98 @@ func TestServingTLS(t *testing.T) {
}
}
// TestServingTLSConnectionState tests that we only can access
// tls.ConnectionState under a DNS query handled by a TLS DNS server.
// This test will sequentially create a TLS, UDP and TCP server, attach a custom
// handler which will set a testing error if tls.ConnectionState is available
// when it is not expected, or the other way around.
func TestServingTLSConnectionState(t *testing.T) {
handlerResponse := "Hello example"
// tlsHandlerTLS is a HandlerFunc that can be set to expect or not TLS
// connection state.
tlsHandlerTLS := func(tlsExpected bool) func(ResponseWriter, *Msg) {
return func(w ResponseWriter, req *Msg) {
m := new(Msg)
m.SetReply(req)
tlsFound := true
if connState := w.(ConnectionStater).ConnectionState(); connState == nil {
tlsFound = false
}
if tlsFound != tlsExpected {
t.Errorf("TLS connection state available: %t, expected: %t", tlsFound, tlsExpected)
}
m.Extra = make([]RR, 1)
m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{handlerResponse}}
w.WriteMsg(m)
}
}
// Question used in tests
m := new(Msg)
m.SetQuestion("tlsstate.example.net.", TypeTXT)
// TLS DNS server
HandleFunc(".", tlsHandlerTLS(true))
cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
if err != nil {
t.Fatalf("unable to build certificate: %v", err)
}
config := tls.Config{
Certificates: []tls.Certificate{cert},
}
s, addrstr, err := RunLocalTLSServer(":0", &config)
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
defer s.Shutdown()
// TLS DNS query
c := &Client{
Net: "tcp-tls",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
_, _, err = c.Exchange(m, addrstr)
if err != nil {
t.Error("failed to exchange tlsstate.example.net", err)
}
HandleRemove(".")
// UDP DNS Server
HandleFunc(".", tlsHandlerTLS(false))
defer HandleRemove(".")
s, addrstr, err = RunLocalUDPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
defer s.Shutdown()
// UDP DNS query
c = new(Client)
_, _, err = c.Exchange(m, addrstr)
if err != nil {
t.Error("failed to exchange tlsstate.example.net", err)
}
// TCP DNS Server
s, addrstr, err = RunLocalTCPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
defer s.Shutdown()
// TCP DNS query
c = &Client{Net: "tcp"}
_, _, err = c.Exchange(m, addrstr)
if err != nil {
t.Error("failed to exchange tlsstate.example.net", err)
}
}
func TestServingListenAndServe(t *testing.T) {
HandleFunc("example.com.", AnotherHelloServer)
defer HandleRemove("example.com.")