Support TsigProvider for Server and Transfer (#1331)
Automatically submitted.
This commit is contained in:
parent
51afb90ed3
commit
33e64002b6
32
client.go
32
client.go
@ -39,6 +39,14 @@ type Conn struct {
|
||||
tsigRequestMAC string
|
||||
}
|
||||
|
||||
func (co *Conn) tsigProvider() TsigProvider {
|
||||
if co.TsigProvider != nil {
|
||||
return co.TsigProvider
|
||||
}
|
||||
// tsigSecretProvider will return ErrSecret if co.TsigSecret is nil.
|
||||
return tsigSecretProvider(co.TsigSecret)
|
||||
}
|
||||
|
||||
// A Client defines parameters for a DNS client.
|
||||
type Client struct {
|
||||
Net string // if "tcp" or "tcp-tls" (DNS over TLS) a TCP query will be initiated, otherwise an UDP one (default is "" for UDP)
|
||||
@ -271,15 +279,8 @@ func (co *Conn) ReadMsg() (*Msg, error) {
|
||||
return m, err
|
||||
}
|
||||
if t := m.IsTsig(); t != nil {
|
||||
if co.TsigProvider != nil {
|
||||
err = tsigVerifyProvider(p, co.TsigProvider, co.tsigRequestMAC, false)
|
||||
} else {
|
||||
if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
|
||||
return m, ErrSecret
|
||||
}
|
||||
// Need to work on the original message p, as that was used to calculate the tsig.
|
||||
err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
|
||||
}
|
||||
// Need to work on the original message p, as that was used to calculate the tsig.
|
||||
err = tsigVerifyProvider(p, co.tsigProvider(), co.tsigRequestMAC, false)
|
||||
}
|
||||
return m, err
|
||||
}
|
||||
@ -356,17 +357,8 @@ func (co *Conn) Read(p []byte) (n int, err error) {
|
||||
func (co *Conn) WriteMsg(m *Msg) (err error) {
|
||||
var out []byte
|
||||
if t := m.IsTsig(); t != nil {
|
||||
mac := ""
|
||||
if co.TsigProvider != nil {
|
||||
out, mac, err = tsigGenerateProvider(m, co.TsigProvider, co.tsigRequestMAC, false)
|
||||
} else {
|
||||
if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
|
||||
return ErrSecret
|
||||
}
|
||||
out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
|
||||
}
|
||||
// Set for the next read, although only used in zone transfers
|
||||
co.tsigRequestMAC = mac
|
||||
// Set tsigRequestMAC for the next read, although only used in zone transfers.
|
||||
out, co.tsigRequestMAC, err = tsigGenerateProvider(m, co.tsigProvider(), co.tsigRequestMAC, false)
|
||||
} else {
|
||||
out, err = m.Pack()
|
||||
}
|
||||
|
42
server.go
42
server.go
@ -71,12 +71,12 @@ type response struct {
|
||||
tsigTimersOnly bool
|
||||
tsigStatus error
|
||||
tsigRequestMAC string
|
||||
tsigSecret map[string]string // the tsig secrets
|
||||
udp net.PacketConn // i/o connection if UDP was used
|
||||
tcp net.Conn // i/o connection if TCP was used
|
||||
udpSession *SessionUDP // oob data to get egress interface right
|
||||
pcSession net.Addr // address to use when writing to a generic net.PacketConn
|
||||
writer Writer // writer to output the raw DNS bits
|
||||
tsigProvider TsigProvider
|
||||
udp net.PacketConn // i/o connection if UDP was used
|
||||
tcp net.Conn // i/o connection if TCP was used
|
||||
udpSession *SessionUDP // oob data to get egress interface right
|
||||
pcSession net.Addr // address to use when writing to a generic net.PacketConn
|
||||
writer Writer // writer to output the raw DNS bits
|
||||
}
|
||||
|
||||
// handleRefused returns a HandlerFunc that returns REFUSED for every request it gets.
|
||||
@ -211,6 +211,8 @@ type Server struct {
|
||||
WriteTimeout time.Duration
|
||||
// TCP idle timeout for multiple queries, if nil, defaults to 8 * time.Second (RFC 5966).
|
||||
IdleTimeout func() time.Duration
|
||||
// An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations.
|
||||
TsigProvider TsigProvider
|
||||
// Secret(s) for Tsig map[<zonename>]<base64 secret>. The zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2).
|
||||
TsigSecret map[string]string
|
||||
// If NotifyStartedFunc is set it is called once the server has started listening.
|
||||
@ -238,6 +240,16 @@ type Server struct {
|
||||
udpPool sync.Pool
|
||||
}
|
||||
|
||||
func (srv *Server) tsigProvider() TsigProvider {
|
||||
if srv.TsigProvider != nil {
|
||||
return srv.TsigProvider
|
||||
}
|
||||
if srv.TsigSecret != nil {
|
||||
return tsigSecretProvider(srv.TsigSecret)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (srv *Server) isStarted() bool {
|
||||
srv.lock.RLock()
|
||||
started := srv.started
|
||||
@ -526,7 +538,7 @@ func (srv *Server) serveUDP(l net.PacketConn) error {
|
||||
|
||||
// Serve a new TCP connection.
|
||||
func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
|
||||
w := &response{tsigSecret: srv.TsigSecret, tcp: rw}
|
||||
w := &response{tsigProvider: srv.tsigProvider(), tcp: rw}
|
||||
if srv.DecorateWriter != nil {
|
||||
w.writer = srv.DecorateWriter(w)
|
||||
} else {
|
||||
@ -581,7 +593,7 @@ func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
|
||||
|
||||
// Serve a new UDP request.
|
||||
func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) {
|
||||
w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: udpSession, pcSession: pcSession}
|
||||
w := &response{tsigProvider: srv.tsigProvider(), udp: u, udpSession: udpSession, pcSession: pcSession}
|
||||
if srv.DecorateWriter != nil {
|
||||
w.writer = srv.DecorateWriter(w)
|
||||
} else {
|
||||
@ -632,15 +644,11 @@ func (srv *Server) serveDNS(m []byte, w *response) {
|
||||
}
|
||||
|
||||
w.tsigStatus = nil
|
||||
if w.tsigSecret != nil {
|
||||
if w.tsigProvider != nil {
|
||||
if t := req.IsTsig(); t != nil {
|
||||
if secret, ok := w.tsigSecret[t.Hdr.Name]; ok {
|
||||
w.tsigStatus = TsigVerify(m, secret, "", false)
|
||||
} else {
|
||||
w.tsigStatus = ErrSecret
|
||||
}
|
||||
w.tsigStatus = tsigVerifyProvider(m, w.tsigProvider, "", false)
|
||||
w.tsigTimersOnly = false
|
||||
w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
|
||||
w.tsigRequestMAC = t.MAC
|
||||
}
|
||||
}
|
||||
|
||||
@ -718,9 +726,9 @@ func (w *response) WriteMsg(m *Msg) (err error) {
|
||||
}
|
||||
|
||||
var data []byte
|
||||
if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check)
|
||||
if w.tsigProvider != nil { // if no provider, dont check for the tsig (which is a longer check)
|
||||
if t := m.IsTsig(); t != nil {
|
||||
data, w.tsigRequestMAC, err = TsigGenerate(m, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly)
|
||||
data, w.tsigRequestMAC, err = tsigGenerateProvider(m, w.tsigProvider, w.tsigRequestMAC, w.tsigTimersOnly)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
18
tsig.go
18
tsig.go
@ -74,6 +74,24 @@ func (key tsigHMACProvider) Verify(msg []byte, t *TSIG) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type tsigSecretProvider map[string]string
|
||||
|
||||
func (ts tsigSecretProvider) Generate(msg []byte, t *TSIG) ([]byte, error) {
|
||||
key, ok := ts[t.Hdr.Name]
|
||||
if !ok {
|
||||
return nil, ErrSecret
|
||||
}
|
||||
return tsigHMACProvider(key).Generate(msg, t)
|
||||
}
|
||||
|
||||
func (ts tsigSecretProvider) Verify(msg []byte, t *TSIG) error {
|
||||
key, ok := ts[t.Hdr.Name]
|
||||
if !ok {
|
||||
return ErrSecret
|
||||
}
|
||||
return tsigHMACProvider(key).Verify(msg, t)
|
||||
}
|
||||
|
||||
// TSIG is the RR the holds the transaction signature of a message.
|
||||
// See RFC 2845 and RFC 4635.
|
||||
type TSIG struct {
|
||||
|
27
xfr.go
27
xfr.go
@ -17,11 +17,22 @@ type Transfer struct {
|
||||
DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds
|
||||
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds
|
||||
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds
|
||||
TsigProvider TsigProvider // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations.
|
||||
TsigSecret map[string]string // Secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
|
||||
tsigTimersOnly bool
|
||||
}
|
||||
|
||||
// Think we need to away to stop the transfer
|
||||
func (t *Transfer) tsigProvider() TsigProvider {
|
||||
if t.TsigProvider != nil {
|
||||
return t.TsigProvider
|
||||
}
|
||||
if t.TsigSecret != nil {
|
||||
return tsigSecretProvider(t.TsigSecret)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: Think we need to away to stop the transfer
|
||||
|
||||
// In performs an incoming transfer with the server in a.
|
||||
// If you would like to set the source IP, or some other attribute
|
||||
@ -224,12 +235,9 @@ func (t *Transfer) ReadMsg() (*Msg, error) {
|
||||
if err := m.Unpack(p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil {
|
||||
if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok {
|
||||
return m, ErrSecret
|
||||
}
|
||||
if ts, tp := m.IsTsig(), t.tsigProvider(); ts != nil && tp != nil {
|
||||
// Need to work on the original message p, as that was used to calculate the tsig.
|
||||
err = TsigVerify(p, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly)
|
||||
err = tsigVerifyProvider(p, tp, t.tsigRequestMAC, t.tsigTimersOnly)
|
||||
t.tsigRequestMAC = ts.MAC
|
||||
}
|
||||
return m, err
|
||||
@ -238,11 +246,8 @@ func (t *Transfer) ReadMsg() (*Msg, error) {
|
||||
// WriteMsg writes a message through the transfer connection t.
|
||||
func (t *Transfer) WriteMsg(m *Msg) (err error) {
|
||||
var out []byte
|
||||
if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil {
|
||||
if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok {
|
||||
return ErrSecret
|
||||
}
|
||||
out, t.tsigRequestMAC, err = TsigGenerate(m, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly)
|
||||
if ts, tp := m.IsTsig(), t.tsigProvider(); ts != nil && tp != nil {
|
||||
out, t.tsigRequestMAC, err = tsigGenerateProvider(m, tp, t.tsigRequestMAC, t.tsigTimersOnly)
|
||||
} else {
|
||||
out, err = m.Pack()
|
||||
}
|
||||
|
56
xfr_test.go
56
xfr_test.go
@ -1,6 +1,9 @@
|
||||
package dns
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
tsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
|
||||
@ -127,3 +130,54 @@ func axfrTestingSuite(t *testing.T, addrstr string) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func axfrTestingSuiteWithCustomTsig(t *testing.T, addrstr string, provider TsigProvider) {
|
||||
tr := new(Transfer)
|
||||
m := new(Msg)
|
||||
var err error
|
||||
tr.Conn, err = Dial("tcp", addrstr)
|
||||
if err != nil {
|
||||
t.Fatal("failed to dial", err)
|
||||
}
|
||||
tr.TsigProvider = provider
|
||||
m.SetAxfr("miek.nl.")
|
||||
m.SetTsig("axfr.", HmacSHA256, 300, time.Now().Unix())
|
||||
|
||||
c, err := tr.In(m, addrstr)
|
||||
if err != nil {
|
||||
t.Fatal("failed to zone transfer in", err)
|
||||
}
|
||||
|
||||
var records []RR
|
||||
for msg := range c {
|
||||
if msg.Error != nil {
|
||||
t.Fatal(msg.Error)
|
||||
}
|
||||
records = append(records, msg.RR...)
|
||||
}
|
||||
|
||||
if len(records) != len(xfrTestData) {
|
||||
t.Fatalf("bad axfr: expected %v, got %v", records, xfrTestData)
|
||||
}
|
||||
|
||||
for i, rr := range records {
|
||||
if !IsDuplicate(rr, xfrTestData[i]) {
|
||||
t.Errorf("bad axfr: expected %v, got %v", records, xfrTestData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomTsigProvider(t *testing.T) {
|
||||
HandleFunc("miek.nl.", SingleEnvelopeXfrServer)
|
||||
defer HandleRemove("miek.nl.")
|
||||
|
||||
s, addrstr, _, err := RunLocalTCPServer(":0", func(srv *Server) {
|
||||
srv.TsigProvider = tsigSecretProvider(tsigSecret)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unable to run test server: %s", err)
|
||||
}
|
||||
defer s.Shutdown()
|
||||
|
||||
axfrTestingSuiteWithCustomTsig(t, addrstr, tsigSecretProvider(tsigSecret))
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user