From c53cddf38c8e7af3507f098bac9397d87e115c54 Mon Sep 17 00:00:00 2001 From: Miek Gieben Date: Sun, 4 Mar 2012 21:00:09 +0100 Subject: [PATCH] TSIG overhauled This lead to some changes in the Write() function. Both server side and client side are now more similar. --- client.go | 5 +++-- defaults.go | 4 ++-- ex/reflect/reflect.go | 50 +++++++++++++++++++++++++++---------------- server.go | 44 +++++++++++++++++++++---------------- 4 files changed, 62 insertions(+), 41 deletions(-) diff --git a/client.go b/client.go index e8d096f3..8f48d2f1 100644 --- a/client.go +++ b/client.go @@ -25,7 +25,7 @@ type QueryHandler interface { // construct a DNS request. type RequestWriter interface { // Write returns the request message and the reply back to the client. - Write(*Msg) + Write(*Msg) error // Send sends the message to the server. Send(*Msg) error // Receive waits for the reply of the servers. @@ -201,8 +201,9 @@ func ListenAndQuery(request chan *Request, handler QueryHandler) { // Write returns the original question and the answer on the // reply channel of the client. -func (w *reply) Write(m *Msg) { +func (w *reply) Write(m *Msg) error { w.Client().ReplyChan <- &Exchange{Request: w.req, Reply: m} + return nil } // Do performs an asynchronous query. The result is returned on the diff --git a/defaults.go b/defaults.go index 8642fdda..74a2e6c4 100644 --- a/defaults.go +++ b/defaults.go @@ -82,8 +82,8 @@ func (dns *Msg) SetAxfr(z string) { // SetTsig appends a TSIG RR to the message. // This is only a skeleton Tsig RR that is added as the last RR in the -// additional section. The caller should then call TsigGenerate, -// to generate the complete TSIG with the secret. +// additional section. The Tsig is calculated when the message is being +// send. func (dns *Msg) SetTsig(z, algo string, fudge uint16, timesigned int64) { t := new(RR_TSIG) t.Hdr = RR_Header{z, TypeTSIG, ClassANY, 0, 0} diff --git a/ex/reflect/reflect.go b/ex/reflect/reflect.go index f728629b..e46c8fb2 100644 --- a/ex/reflect/reflect.go +++ b/ex/reflect/reflect.go @@ -27,6 +27,8 @@ import ( "os/signal" "runtime/pprof" "strconv" + "strings" + "time" ) var ( @@ -84,25 +86,33 @@ func handleReflect(w dns.ResponseWriter, r *dns.Msg) { m.Extra = append(m.Extra, t) } - b, ok := m.Pack() + if r.IsTsig() { + println("Checking TSIG") + if w.TsigStatus() == nil { + println("TSIG OK") + m.SetTsig(r.Extra[len(r.Extra)-1].(*dns.RR_TSIG).Hdr.Name, dns.HmacMD5, 300, time.Now().Unix()) + } + } + + if *printf { fmt.Printf("%v\n", m.String()) } - if !ok { - log.Print("Packing failed") - m.SetRcode(r, dns.RcodeServerFailure) - m.Extra = nil - m.Answer = nil - b, _ = m.Pack() - } - // The reply is smaller then 512 bytes, so it will always "fit" - w.Write(b) + w.Write(m) // Discard the error? } -func serve(net string) { - err := dns.ListenAndServe(":8053", net, nil) - if err != nil { - fmt.Printf("Failed to setup the "+net+" server: %s\n", err.Error()) +func serve(net, name, secret string) { + switch name { + case "": + err := dns.ListenAndServe(":8053", net, nil) + if err != nil { + fmt.Printf("Failed to setup the "+net+" server: %s\n", err.Error()) + } + default: + err := dns.ListenAndServeTsig(":8053", net, nil, map[string]string{name: secret}) + if err != nil { + fmt.Printf("Failed to setup the "+net+" server: %s\n", err.Error()) + } } } @@ -110,12 +120,16 @@ func main() { cpuprofile := flag.String("cpuprofile", "", "write cpu profile to file") printf = flag.Bool("print", false, "print replies") compress = flag.Bool("compress", false, "compress replies") - tsig = flag.String("tisg", "", "use SHA1 hmac tsig: keyname:base64") + tsig = flag.String("tsig", "", "use MD5 hmac tsig: keyname:base64") + var name, secret string flag.Usage = func() { flag.PrintDefaults() } flag.Parse() - tsig = tsig //TODO + if *tsig != "" { + a := strings.SplitN(*tsig, ":", 2) + name, secret = a[0], a[1] + } if *cpuprofile != "" { f, err := os.Create(*cpuprofile) if err != nil { @@ -126,8 +140,8 @@ func main() { } dns.HandleFunc(".", handleReflect) - go serve("tcp") - go serve("udp") + go serve("tcp", name, secret) + go serve("udp", name, secret) sig := make(chan os.Signal) signal.Notify(sig) forever: diff --git a/server.go b/server.go index 297baa82..b065b464 100644 --- a/server.go +++ b/server.go @@ -26,7 +26,7 @@ type ResponseWriter interface { // Return the status of the Tsig (TsigNone, TsigVerified or TsigBad) TsigStatus() error // Write writes a reply back to the client. - Write([]byte) (int, error) + Write(*Msg) error } type conn struct { @@ -40,9 +40,11 @@ type conn struct { } type response struct { - conn *conn - req *Msg - tsigStatus error + conn *conn + req *Msg + tsigStatus error + tsigTimersOnly bool + tsigRequestMAC string } // ServeMux is an DNS request multiplexer. It matches the @@ -75,8 +77,7 @@ func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) { func Refused(w ResponseWriter, r *Msg) { m := new(Msg) m.SetRcode(r, RcodeRefused) - buf, _ := m.Pack() - w.Write(buf) + w.Write(m) } // RefusedHandler returns HandlerFunc with Refused. @@ -310,8 +311,7 @@ func (c *conn) serve() { // Send a format error back x := new(Msg) x.SetRcodeFormatError(req) - buf, _ := x.Pack() - w.Write(buf) + w.Write(x) break } @@ -321,8 +321,9 @@ func (c *conn) serve() { if _, ok := w.conn.tsigSecret[secret]; !ok { w.tsigStatus = ErrKeyAlg } - // Do I *ever* need Tsig.Mac here? Or timersOnly? TODO(mg) w.tsigStatus = TsigVerify(c.request, w.conn.tsigSecret[secret], "", false) + w.tsigTimersOnly = false // Will this ever be true? + w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*RR_TSIG).MAC } w.req = req c.handler.ServeDNS(w, w.req) // this does the writing back to the client @@ -336,41 +337,46 @@ func (c *conn) serve() { } } -func (w *response) Write(data []byte) (n int, err error) { +func (w *response) Write(m *Msg) error { + //data []byte) (n int, err error) { + data, ok := m.Pack() + if !ok { + return ErrPack + } switch { case w.conn._UDP != nil: - n, err = w.conn._UDP.WriteTo(data, w.conn.remoteAddr) + _, err := w.conn._UDP.WriteTo(data, w.conn.remoteAddr) if err != nil { - return 0, err + return err } case w.conn._TCP != nil: if len(data) > MaxMsgSize { - return 0, ErrBuf + return ErrBuf } l := make([]byte, 2) l[0], l[1] = packUint16(uint16(len(data))) - n, err = w.conn._TCP.Write(l) + n, err := w.conn._TCP.Write(l) if err != nil { - return n, err + return err } if n != 2 { - return n, io.ErrShortWrite + return io.ErrShortWrite } n, err = w.conn._TCP.Write(data) if err != nil { - return n, err + return err } i := n if i < len(data) { j, err := w.conn._TCP.Write(data[i:len(data)]) if err != nil { - return i, err + return err } i += j } n = i } - return n, nil + return nil } // RemoteAddr implements the ResponseWriter.RemoteAddr method