diff --git a/Makefile b/Makefile index 7f873bc8..e2a7c0b6 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,6 @@ include $(GOROOT)/src/Make.inc TARG=dns GOFILES=\ - xfr.go\ config.go\ defaults.go\ dns.go\ @@ -21,7 +20,7 @@ GOFILES=\ string.go\ tsig.go\ types.go\ -# y.go\ + xfr.go\ include $(GOROOT)/src/Make.pkg diff --git a/README.markdown b/README.markdown index c59c9f09..085a05fb 100644 --- a/README.markdown +++ b/README.markdown @@ -38,6 +38,7 @@ Miek Gieben - 2010, 2011 - miek@miek.nl * 403{3,4,5} - DNSSEC + validation functions * 4255 - SSHFP * 4408 - SPF +* 4509 - SHA256 Hash in DS * 4635 - HMAC SHA TSIG * 5001 - NSID * 5155 - NSEC diff --git a/TODO b/TODO index 6ff81661..9d742bef 100644 --- a/TODO +++ b/TODO @@ -1,18 +1,30 @@ +Guidelines for the API: + +o symmetrical, client side stuff should be mirrored in the server stuff +o clean, small API +o fast data structures (rb-tree, when they come available) +o api-use should lead to self documenting code + +o zone structure -- only as rb-tree +o compression (only ownernames?) + +o Key2DS, also for offline keys -- need to parse them ofcourse + +o Closing of tcp connections? +o Tsig will probably become an interface which has all configuration + stuff, but this will come later. Config which has Tsig function + Todo: * Parsing from strings, going with goyacc and .cz lexer? * encoding NSEC3/NSEC bitmaps, DEcoding works * HIP RR (needs list of domain names, need slice stuff for that) -* Resolver can see if you want ixfr or axfr from the msg no - need for seperate functions +* Is subdomain, is glue helper functions for this kind of stuff Issues: * Check the network order, it works now, but this is on Intel?? * Make the testsuite work with public DNS servers * pack/Unpack smaller. EDNS 'n stuff can be folded in * SetDefaults() for *all* types? -* Closing of tcp connections? - -* Refacter the IXFR/AXFR/TSIG code Examples: * Test impl of nameserver, with a small zone, 1 KSK and online signing diff --git a/_examples/Makefile b/_examples/Makefile index 627a9602..30c47f6b 100644 --- a/_examples/Makefile +++ b/_examples/Makefile @@ -4,7 +4,7 @@ chaos \ axfr \ reflect \ funkensturm \ -ns \ +ds2key \ all: for i in $(EXAMPLES); do gomake -C $$i; done diff --git a/_examples/axfr/axfr.go b/_examples/axfr/axfr.go index 0c52ddc2..98a81348 100644 --- a/_examples/axfr/axfr.go +++ b/_examples/axfr/axfr.go @@ -18,9 +18,9 @@ func main() { res.Servers[0] = *nameserver c := make(chan dns.Xfr) - m := new(dns.Msg) m.Question = make([]dns.Question, 1) + if *serial > 0 { m.Question[0] = dns.Question{zone, dns.TypeIXFR, dns.ClassINET} soa := new(dns.RR_SOA) @@ -28,11 +28,10 @@ func main() { soa.Serial = uint32(*serial) m.Ns = make([]dns.RR, 1) m.Ns[0] = soa - go res.Ixfr(m, c) } else { m.Question[0] = dns.Question{zone, dns.TypeAXFR, dns.ClassINET} - go res.Axfr(m, c) } + go res.Xfr(m, nil, c) for x := range c { fmt.Printf("%v %v\n",x.Add, x.RR) } diff --git a/_examples/ns/Makefile b/_examples/key2ds/Makefile similarity index 88% rename from _examples/ns/Makefile rename to _examples/key2ds/Makefile index 676dbf01..be182665 100644 --- a/_examples/ns/Makefile +++ b/_examples/key2ds/Makefile @@ -2,7 +2,7 @@ # Use of this source code is governed by a BSD-style # license that can be found in the LICENSE file. include $(GOROOT)/src/Make.inc -TARG=ns -GOFILES=ns.go +TARG=key2ds +GOFILES=key2ds.go DEPS=../../ include $(GOROOT)/src/Make.cmd diff --git a/_examples/key2ds/key2ds.go b/_examples/key2ds/key2ds.go new file mode 100644 index 00000000..f561fd61 --- /dev/null +++ b/_examples/key2ds/key2ds.go @@ -0,0 +1,42 @@ +package main + +// Print the DNSKEY records of a domain as DS records +// (c) Miek Gieben - 2011 +import ( + "dns" + "os" + "fmt" +) + +func main() { + r := new(dns.Resolver) + r.FromFile("/etc/resolv.conf") + if len(os.Args) != 2 { + fmt.Printf("%s DOMAIN\n", os.Args[0]) + os.Exit(1) + } + m := new(dns.Msg) + m.MsgHdr.RecursionDesired = true //only set this bit + m.Question = make([]dns.Question, 1) + m.Question[0] = dns.Question{os.Args[1], dns.TypeDNSKEY, dns.ClassINET} + + in, err := r.Query(m) + if in != nil { + if in.Rcode != dns.RcodeSuccess { + fmt.Printf(" *** invalid answer name %s after DNSKEY query for %s\n", os.Args[1], os.Args[1]) + os.Exit(1) + } + // Stuff must be in the answer section + for _, k := range in.Answer { + // Foreach key would need to provide a DS records, both sha1 and sha256 + if key, ok := k.(*dns.RR_DNSKEY); ok { + ds := key.ToDS(dns.HashSHA1) + fmt.Printf("%v\n", ds) + ds = key.ToDS(dns.HashSHA256) + fmt.Printf("%v\n", ds) + } + } + } else { + fmt.Printf("*** error: %s\n", err.String()) + } +} diff --git a/_examples/ns/ns.go b/_examples/ns/ns.go deleted file mode 100644 index 47b61d7d..00000000 --- a/_examples/ns/ns.go +++ /dev/null @@ -1,130 +0,0 @@ -package main - -import ( - "os" - "dns" - "net" - "fmt" - "flag" - "os/signal" -// "json" -) - -var counter int - -func main() { -// var zone *string = flag.String("zone", "", "The zone to serve") - flag.Usage = func() { - fmt.Fprintf(os.Stderr, "Usage: %s zone...\n", os.Args[0]) - flag.PrintDefaults() - } - flag.Parse() - - m := new(dns.Msg) - m.MsgHdr.Id = dns.Id() - m.MsgHdr.Authoritative = true - m.MsgHdr.AuthenticatedData = false - m.MsgHdr.RecursionAvailable = true - m.MsgHdr.Response = true - m.MsgHdr.Opcode = dns.OpcodeQuery - m.MsgHdr.Rcode = dns.RcodeSuccess - m.Question = make([]dns.Question, 1) - m.Question[0] = dns.Question{"miek.nl.", dns.TypeTXT, dns.ClassINET} - m.Answer = make([]dns.RR, 1) - t := new(dns.RR_TXT) - t.Hdr = dns.RR_Header{Name: "miek.nl.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 3600} - t.Txt = "Een antwoord" - m.Answer[0] = t - - errchan := make(chan os.Error) - go udp("127.0.0.1:8054", errchan) - go tcp("127.0.0.1:8054", errchan) - - forever: - for { - select { - case e := <-errchan: - fmt.Printf("Error received, stopping: %s\n", e.String()) - break forever - case <-signal.Incoming: - fmt.Printf("Signal received, stopping\n") - break forever - } - } - close(errchan) - fmt.Printf("Queries answered: %d\n", counter) -} - -func tcp(addr string, e chan os.Error) { - a, err := net.ResolveTCPAddr(addr) - if err != nil { - e <- err - } - l, err := net.ListenTCP("tcp", a) - if err != nil { - e <- err - } - err = dns.ServeTCP(l, replyTCP) - e <- err - return -} - -func udp(addr string, e chan os.Error) { - a, err := net.ResolveUDPAddr(addr) - if err != nil { - e <- err - } - l, err := net.ListenUDP("udp", a) - if err != nil { - e <- err - } - err = dns.ServeUDP(l, replyUDP) - e <- err - return -} - - -func createpkg(id uint16, tcp bool, remove net.Addr) []byte { - m := new(dns.Msg) - m.MsgHdr.Id = id - m.MsgHdr.Authoritative = true - m.MsgHdr.AuthenticatedData = false - m.MsgHdr.RecursionAvailable = true - m.MsgHdr.Response = true - m.MsgHdr.Opcode = dns.OpcodeQuery - m.MsgHdr.Rcode = dns.RcodeSuccess - m.Question = make([]dns.Question, 1) - m.Question[0] = dns.Question{"miek.nl.", dns.TypeTXT, dns.ClassINET} - m.Answer = make([]dns.RR, 1) - t := new(dns.RR_TXT) - t.Hdr = dns.RR_Header{Name: "miek.nl.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 3600} - if tcp { - t.Txt = "Dit is iets anders TCP" - } else { - t.Txt = "Dit is iets anders UDP" - } - m.Answer[0] = t - out, _ := m.Pack() - return out -} - -func replyUDP(c *net.UDPConn, a net.Addr, in *dns.Msg) { - if in.MsgHdr.Response == true { - // Uh... answering to an response?? - // dont think so - return - } - out := createpkg(in.MsgHdr.Id, false, a) - dns.SendUDP(out, c, a) - counter++ -} - -func replyTCP(c *net.TCPConn, a net.Addr, in *dns.Msg) { - if in.MsgHdr.Response == true { - return - } - out := createpkg(in.MsgHdr.Id, true, a) - dns.SendTCP(out, c, a) - counter++ -} - diff --git a/_examples/reflect/reflect.go b/_examples/reflect/reflect.go index 41461a9b..4cd7dbcc 100644 --- a/_examples/reflect/reflect.go +++ b/_examples/reflect/reflect.go @@ -18,99 +18,74 @@ package main import ( "os" - "os/signal" "net" "dns" "fmt" + "os/signal" "strconv" ) -func reply(a net.Addr, in *dns.Msg, tcp bool) *dns.Msg { - if in.MsgHdr.Response == true { - return nil // Don't answer responses - } +func reply(c *dns.Conn, in *dns.Msg) []byte { m := new(dns.Msg) - m.MsgHdr.Id = in.MsgHdr.Id - m.MsgHdr.Authoritative = true - m.MsgHdr.Response = true - m.MsgHdr.Opcode = dns.OpcodeQuery + m.SetReply(in.MsgHdr.Id) - m.MsgHdr.Rcode = dns.RcodeSuccess m.Question = make([]dns.Question, 1) m.Answer = make([]dns.RR, 1) m.Extra = make([]dns.RR, 1) - r := new(dns.RR_A) - r.Hdr = dns.RR_Header{Name: "whoami.miek.nl.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0} - ip, _ := net.ResolveUDPAddr(a.String()) - r.A = ip.IP + // Copy the question. + m.Question[0] = in.Question[0] + + // Some foo to check if we are called through ip6 or ip4. + // We add the correct reply RR. + var ad net.IP + if c.UDP != nil { + ad = c.Addr.(*net.UDPAddr).IP + } else { + ad = c.Addr.(*net.TCPAddr).IP + } + + if ad.To4() != nil { + r := new(dns.RR_A) + r.Hdr = dns.RR_Header{Name: "whoami.miek.nl.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0} + r.A = ad + m.Answer[0] = r + } else { + r := new(dns.RR_AAAA) + r.Hdr = dns.RR_Header{Name: "whoami.miek.nl.", Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 0} + r.AAAA = ad + m.Answer[0] = r + } t := new(dns.RR_TXT) t.Hdr = dns.RR_Header{Name: "whoami.miek.nl.", Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0} - if tcp { - t.Txt = "Port: " + strconv.Itoa(ip.Port) + " (tcp)" + if c.TCP != nil { + t.Txt = "Port: " + strconv.Itoa(c.Port) + " (tcp)" } else { - t.Txt = "Port: " + strconv.Itoa(ip.Port) + " (udp)" + t.Txt = "Port: " + strconv.Itoa(c.Port) + " (udp)" } - - m.Question[0] = in.Question[0] - m.Answer[0] = r m.Extra[0] = t - return m + + b, _ := m.Pack() + return b } -func replyUDP(c *net.UDPConn, a net.Addr, in *dns.Msg) { - m := reply(a, in, false) - if m == nil { - return +func handle(c *dns.Conn, in *dns.Msg) { + if in.MsgHdr.Response == true { + return // We don't do responses } - fmt.Fprintf(os.Stderr, "%v\n", m) - out, ok := m.Pack() - if !ok { - println("Failed to pack") - return - } - dns.SendUDP(out, c, a) -} - -func replyTCP(c *net.TCPConn, a net.Addr, in *dns.Msg) { - m := reply(a, in, true) - if m == nil { - return - } - fmt.Fprintf(os.Stderr, "%v\n", m) - out, ok := m.Pack() - if !ok { - println("Failed to pack") - return - } - dns.SendTCP(out, c, a) + answer := reply(c, in) + c.Write(answer) } func tcp(addr string, e chan os.Error) { - a, err := net.ResolveTCPAddr(addr) - if err != nil { - e <- err - } - l, err := net.ListenTCP("tcp", a) - if err != nil { - e <- err - } - err = dns.ServeTCP(l, replyTCP) + err := dns.ListenAndServeTCP(addr, handle) e <- err return } func udp(addr string, e chan os.Error) { - a, err := net.ResolveUDPAddr(addr) - if err != nil { - e <- err - } - l, err := net.ListenUDP("udp", a) - if err != nil { - e <- err - } - err = dns.ServeUDP(l, replyUDP) + err := dns.ListenAndServeUDP(addr, handle) e <- err return } diff --git a/defaults.go b/defaults.go index 719e3e88..2eff77aa 100644 --- a/defaults.go +++ b/defaults.go @@ -1,5 +1,14 @@ package dns +// Create a reply packet. +func (dns *Msg) SetReply(id uint16) { + dns.MsgHdr.Id = id + dns.MsgHdr.Authoritative = true + dns.MsgHdr.Response = true + dns.MsgHdr.Opcode = OpcodeQuery + dns.MsgHdr.Rcode = RcodeSuccess +} + // Create a notify packet. func (dns *Msg) SetNotify(z string, class uint16) { dns.MsgHdr.Opcode = OpcodeNotify @@ -37,5 +46,4 @@ func (dns *Msg) SetAxfr(z string, class uint16) { dns.Question = make([]Question, 1) dns.Question[0] = Question{z, TypeAXFR, class} } - // IsIxfr/IsAxfr? diff --git a/dns.go b/dns.go index 8b04e843..953e7bb4 100644 --- a/dns.go +++ b/dns.go @@ -15,7 +15,11 @@ // package dns +// ErrShortWrite is defined in io, use that! + import ( + "os" + "net" "strconv" ) @@ -30,7 +34,7 @@ const ( type Error struct { Error string Name string - Server string + Server net.Addr Timeout bool } @@ -41,6 +45,222 @@ func (e *Error) String() string { return e.Error } +// A Conn is the lowest primative in this DNS library +// A hold both the UDP and TCP connection, but only one +// can be active at any one time. +type Conn struct { + // The current UDP connection. + UDP *net.UDPConn + + // The current TCP connection. + TCP *net.TCPConn + + // The remote side of the connection. + Addr net.Addr + + // The remote port number of the connection. + Port int + + // If TSIG is used, this holds all the information + Tsig *Tsig + + // Timeout in sec + Timeout int + + // Number of attempts to try + Attempts int +} + +// Create a new buffer of the appropiate size. +func (d *Conn) NewBuffer() []byte { + if d.TCP != nil { + b := make([]byte, MaxMsgSize) + return b + } + if d.UDP != nil { + b := make([]byte, DefaultMsgSize) + return b + } + return nil +} + + +func (d *Conn) Read(p []byte) (n int, err os.Error) { + if d.UDP != nil && d.TCP != nil { + return 0, &Error{Error: "UDP and TCP or both non-nil"} + } + switch { + case d.UDP != nil: + var addr net.Addr + n, addr, err = d.UDP.ReadFromUDP(p) + if err != nil { + return n, err + } + d.Addr = addr + d.Port = addr.(*net.UDPAddr).Port + case d.TCP != nil: + if len(p) < 1 { + return 0, &Error{Error: "Buffer too small to read"} + } + n, err = d.TCP.Read(p[0:2]) + if err != nil || n != 2 { + return n, err + } + d.Addr = d.TCP.RemoteAddr() + d.Port = d.TCP.RemoteAddr().(*net.TCPAddr).Port + l, _ := unpackUint16(p[0:2], 0) + if l == 0 { + return 0, &Error{Error: "received nil msg length", Server: d.Addr} + } + if int(l) > len(p) { + return int(l), &Error{Error: "Buffer too small to read"} + } + n, err = d.TCP.Read(p) + if err != nil { + return n, err + } + i := n + for i < int(l) { + n, err = d.TCP.Read(p[i:]) + if err != nil { + return i, err + } + i += n + } + n = i + } + if d.Tsig != nil { + // Check the TSIG that we should be read + _, err = d.Tsig.Verify(p) + if err != nil { + return + } + } + return +} + +func (d *Conn) Write(p []byte) (n int, err os.Error) { + if d.UDP != nil && d.TCP != nil { + return 0, &Error{Error: "UDP and TCP or both non-nil"} + } + + var attempts int + var q []byte + if d.Attempts == 0 { + attempts = 1 + } else { + attempts = d.Attempts + } + d.SetTimeout() + if d.Tsig != nil { + // Create a new buffer with the TSIG added. + q, err = d.Tsig.Generate(p) + if err != nil { + return 0, err + } + } else { + q = p + } + + switch { + case d.UDP != nil: + for a := 0; a < attempts; a++ { + n, err = d.UDP.WriteTo(q, d.Addr) + if err != nil { + if e, ok := err.(net.Error); ok && e.Timeout() { + continue + } + return 0, err + } + } + case d.TCP != nil: + for a := 0; a < attempts; a++ { + l := make([]byte, 2) + l[0], l[1] = packUint16(uint16(len(q))) + n, err = d.TCP.Write(l) + if err != nil { + if e, ok := err.(net.Error); ok && e.Timeout() { + continue + } + return n, err + } + if n != 2 { + return n, &Error{Error: "Write failure"} + } + n, err = d.TCP.Write(q) + if err != nil { + if e, ok := err.(net.Error); ok && e.Timeout() { + continue + } + return n, err + } + i := n + if i < len(q) { + n, err = d.TCP.Write(q) + if err != nil { + if e, ok := err.(net.Error); ok && e.Timeout() { + // We are half way in our write... + continue + } + return n, err + } + i += n + } + n = i + } + } + return +} + +func (d *Conn) Close() (err os.Error) { + if d.UDP != nil && d.TCP != nil { + return &Error{Error: "UDP and TCP or both non-nil"} + } + switch { + case d.UDP != nil: + err = d.UDP.Close() + case d.TCP != nil: + err = d.TCP.Close() + } + return +} + +func (d *Conn) SetTimeout() (err os.Error) { + var sec int64 + if d.UDP != nil && d.TCP != nil { + return &Error{Error: "UDP and TCP or both non-nil"} + } + sec = int64(d.Timeout) + if sec == 0 { + sec = 1 + } + if d.UDP != nil { + err = d.TCP.SetTimeout(sec * 1e9) + } + if d.TCP != nil { + err = d.TCP.SetTimeout(sec * 1e9) + } + return +} + +func (d *Conn) Exchange(request []byte, nosend bool) (reply []byte, err os.Error) { + var n int + if !nosend { + n, err = d.Write(request) + if err != nil { + return nil, err + } + } + // Layer violation to save memory. Its okay then... + reply = d.NewBuffer() + n, err = d.Read(reply) + if err != nil { + return nil, err + } + reply = reply[:n] + return +} + type RR interface { Header() *RR_Header diff --git a/dns.y b/dns.y deleted file mode 100644 index b4b8e5ce..00000000 --- a/dns.y +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright Miek Gieben 2011 -// Heavily influenced by the zone-parser from NSD - -%{ - -package dns - -import ( - "fmt" -) - -// A yacc parser for DNS Resource Records contained in strings - -%} - -%union { - val string - rrtype uint16 - class uint16 - ttl uint16 - -} - -/* - * Types known to package dns - */ -%token YA YNS - -/* - * Other elements of the Resource Records - */ -%token TTL -%token CLASS -%token VAL -%% -rr: name TTL CLASS rrtype - { - - }; - -name: label - | name '.' label - -label: VAL - -rrtype: - /* All supported RR types */ - YA - | YNS -%% - -type DnsLex int - -func (DnsLex) Lex(lval *yySymType) int { - - // yylval.rrtype = Str_rr($XX) //give back TypeA, TypeNS - // return Y_A this should be the token, another map? - -//func scan(s string) (string, int) { - if len(s) == 0 { - println("a bit short") - } - raw := []byte(s) - chunk := "" - off := 0 - brace := 0 -redo: - for off < len(raw) { - c := raw[off] -// println(c, string(c)) - switch c { - case '\n': - // normal case?? - if brace > 0 { - off++ - continue - } - case '.': -// println("off", off) - if off == 0 { - print("DOT") - return ".", off + 1 - } else { - return chunk, off - } - case ' ','\t': - if brace != 0 { - off++ - continue - } - // eat whitespace - // Look at next char - if raw[off+1] == ' ' { - off++ - continue - } else { - // if chunk is empty, we have skipped whitespace, and seen nothing - if len(chunk) == 0 { - off++ - goto redo - } - print("VAL ") - return chunk, off - } - case '(': - brace++ - off++ - continue - case ')': - brace-- - if brace < 0 { - println("syntax error") - } - off++ - continue - } - if c == ' ' { println("adding space") } - if c == '\t' { println("adding tab") } - chunk += string(c) - off++ - } - print("VAL ") - return chunk, off -} diff --git a/dnssec.go b/dnssec.go index 93422076..8733ccbf 100644 --- a/dnssec.go +++ b/dnssec.go @@ -32,7 +32,8 @@ const ( // DNSSEC hashing codes. const ( - HashSHA1 = iota + _ = iota + HashSHA1 HashSHA256 HashGOST94 ) @@ -104,6 +105,7 @@ func (k *RR_DNSKEY) ToDS(h int) *RR_DS { ds := new(RR_DS) ds.Hdr.Name = k.Hdr.Name ds.Hdr.Class = k.Hdr.Class + ds.Hdr.Rrtype = TypeDS ds.Hdr.Ttl = k.Hdr.Ttl ds.Algorithm = k.Algorithm ds.DigestType = uint8(h) diff --git a/edns.go b/edns.go index 497d9431..0b19202e 100644 --- a/edns.go +++ b/edns.go @@ -7,9 +7,10 @@ import ( // EDNS0 Options const ( - OptionCodeLLQ = 1 // not used - OptionCodeUL = 2 // not used - OptionCodeNSID = 3 // NSID, RFC5001 + _ = iota + OptionCodeLLQ // not used + OptionCodeUL // not used + OptionCodeNSID // NSID, RFC5001 _DO = 1 << 7 // dnssec ok ) diff --git a/resolver.go b/resolver.go index 8b2701fc..61e1fb73 100644 --- a/resolver.go +++ b/resolver.go @@ -5,7 +5,6 @@ // DNS resolver client: see RFC 1035. package dns -// TODO: refacter this import ( "os" @@ -32,10 +31,6 @@ type Resolver struct { Rrb int // Last used server (for round robin) } -func (res *Resolver) QueryTSIG(q *Msg, secret *string) (d *Msg, err os.Error) { - return nil,nil -} - // Basic usage pattern for setting up a resolver: // // res := new(Resolver) @@ -45,64 +40,50 @@ func (res *Resolver) QueryTSIG(q *Msg, secret *string) (d *Msg, err os.Error) { // m.MsgHdr.Recursion_desired = true // header bits // m.Question = make([]Question, 1) // 1 RR in question section // m.Question[0] = Question{"miek.nl", TypeSOA, ClassINET} -// in, err := res.Query(m) // Ask the question +// in, err := res.Query(m, nil) // Ask the question // // Note that message id checking is left to the caller. func (res *Resolver) Query(q *Msg) (d *Msg, err os.Error) { - // Check if there is a TSIG appended, if so, check it - var ( - c net.Conn - port string - inb []byte - ) - in := new(Msg) - if len(res.Servers) == 0 { - return nil, &Error{Error: "No servers defined"} - } - if res.Rtt == nil { - res.Rtt = make(map[string]int64) - } - if res.Port == "" { - port = "53" - } else { - port = res.Port + return res.QueryTsig(q, nil) +} + +func (res *Resolver) QueryTsig(q *Msg, tsig *Tsig) (d *Msg, err os.Error) { + var c net.Conn + var inb []byte + in := new(Msg) + port, err := check(res, q) + if err != nil { + return nil, err } - if q.Id == 0 { - // No Id sed, set it - q.Id = Id() - } sending, ok := q.Pack() if !ok { return nil, &Error{Error: ErrPack} } for i := 0; i < len(res.Servers); i++ { + d := new(Conn) server := res.Servers[i] + ":" + port t := time.Nanoseconds() if res.Tcp { c, err = net.Dial("tcp", "", server) + d.TCP = c.(*net.TCPConn) + d.Addr = d.TCP.RemoteAddr() } else { c, err = net.Dial("udp", "", server) + d.UDP = c.(*net.UDPConn) + d.Addr = d.UDP.RemoteAddr() } if err != nil { continue } - if res.Tcp { - inb, err = exchangeTCP(c, sending, res, true) - in.Unpack(inb) - - } else { - inb, err = exchangeUDP(c, sending, res, true) - in.Unpack(inb) + inb, err = d.Exchange(sending, false) + if err != nil { + continue } + in.Unpack(inb) // Discard error. res.Rtt[server] = time.Nanoseconds() - t - - // Check id in.id != out.id, should be checked in the client! c.Close() - if err != nil { - continue - } break } if err != nil { @@ -111,40 +92,15 @@ func (res *Resolver) Query(q *Msg) (d *Msg, err os.Error) { return in, nil } -// Xfr is used in communicating with *xfr functions. -// This structure is returned on the channel. -type Xfr struct { - Add bool // true is to be added, otherwise false - RR - Err os.Error +func (res *Resolver) Xfr(q *Msg, m chan Xfr) { + res.XfrTsig(q, nil, m) } -// Start an IXFR, q should contain a *Msg with the question -// for an IXFR: "miek.nl" ANY IXFR. RRs that should be added -// have Xfr.Add set to true otherwise it is false. -// Channel m is closed when the IXFR ends. -func (res *Resolver) Ixfr(q *Msg, m chan Xfr) { - // TSIG - var ( - port string - x Xfr - inb []byte - ) - in := new(Msg) - if res.Port == "" { - port = "53" - } else { - port = res.Port +func (res *Resolver) XfrTsig(q *Msg, t *Tsig, m chan Xfr) { + port, err := check(res, q) + if err != nil { + return } - if res.Rtt == nil { - res.Rtt = make(map[string]int64) - } - - if q.Id == 0 { - q.Id = Id() - } - - defer close(m) sending, ok := q.Pack() if !ok { return @@ -157,197 +113,22 @@ Server: if err != nil { continue Server } - first := true - var serial uint32 // The first serial seen is the current server serial + d := new(Conn) + d.TCP = c.(*net.TCPConn) + d.Addr = d.TCP.RemoteAddr() + d.Tsig = t - defer c.Close() - for { - if first { - inb, err = exchangeTCP(c, sending, res, true) - in.Unpack(inb) - } else { - inb, err = exchangeTCP(c, sending, res, false) - in.Unpack(inb) - } - - if err != nil { - // Failed to send, try the next - c.Close() - continue Server - } - if in.Id != q.Id { - return - } - - if first { - // A single SOA RR signals "no changes" - if len(in.Answer) == 1 && checkAxfrSOA(in, true) { - return - } - - // But still check if the returned answer is ok - if !checkAxfrSOA(in, true) { - c.Close() - continue Server - } - // This serial is important - serial = in.Answer[0].(*RR_SOA).Serial - first = !first - } - - // Now we need to check each message for SOA records, to see what we need to do - x.Add = true - if !first { - for k, r := range in.Answer { - // If the last record in the IXFR contains the servers' SOA, we should quit - if r.Header().Rrtype == TypeSOA { - switch { - case r.(*RR_SOA).Serial == serial: - if k == len(in.Answer)-1 { - // last rr is SOA with correct serial - //m <- r dont' send it - return - } - x.Add = true - if k != 0 { - // Intermediate SOA - continue - } - case r.(*RR_SOA).Serial != serial: - x.Add = false - continue // Don't need to see this SOA - } - } - x.RR = r - m <- x - } - } - return - } - panic("not reached") - return - } - return -} - -// Start an AXFR, q should contain a message with the question -// for an AXFR: "miek.nl" ANY AXFR. The closing SOA isn't -// returned over the channel, so the caller will receive -// the zone as-is. Xfr.Add is always true. -// The channel is closed to signal the end of the AXFR. -func (res *Resolver) AxfrTSIG(q *Msg, m chan Xfr, secret string) { - var ( - port string - inb []byte - ) - in := new(Msg) - if res.Port == "" { - port = "53" - } else { - port = res.Port - } - if res.Rtt == nil { - res.Rtt = make(map[string]int64) - } - - if q.Id == 0 { - q.Id = Id() - } - - defer close(m) - sending, ok := q.Pack() - if !ok { - return - } - - var tsig bool - var reqmac string - // Check if there is a TSIG added to the request msg - if len(q.Extra) > 0 { - tsig = q.Extra[len(q.Extra)-1].Header().Rrtype == TypeTSIG - if tsig { - reqmac = q.Extra[len(q.Extra)-1].(*RR_TSIG).MAC + _, err = d.Write(sending) + if err != nil { + continue Server } - } - -Server: - for i := 0; i < len(res.Servers); i++ { - server := res.Servers[i] + ":" + port - c, err := net.Dial("tcp", "", server) - if err != nil { - continue Server - } - first := true - defer c.Close() // TODO(mg): if not open? - for { - if first { - inb, err = exchangeTCP(c, sending, res, true) - in.Unpack(inb) - } else { - inb, err = exchangeTCP(c, sending, res, false) - in.Unpack(inb) - } - - if err != nil { - // Failed to send, try the next - c.Close() - continue Server - } - if in.Id != q.Id { - c.Close() - return - } - - if tsig && len(in.Extra) > 0 { // What if not included? - t := in.Extra[len(in.Extra)-1] - if t.Header().Rrtype == TypeTSIG { - if t.(*RR_TSIG).Verify(inb, secret, reqmac, first) { - // Set the MAC for the next round. - reqmac = t.(*RR_TSIG).MAC - } else { - c.Close() - return - } - } - } - - if first { - if !checkAxfrSOA(in, true) { - c.Close() - continue Server - } - first = !first - } - - if !first { - if !checkAxfrSOA(in, false) { - // Soa record not the last one - sendFromMsg(in, m, false) - continue - } else { - sendFromMsg(in, m, true) - return - } - } - } - panic("not reached") - return + d.XfrRead(q, m) // check } return } - -// Start an AXFR, q should contain a message with the question -// for an AXFR: "miek.nl" ANY AXFR. The closing SOA isn't -// returned over the channel, so the caller will receive -// the zone as-is. Xfr.Add is always true. -// The channel is closed to signal the end of the AXFR. -func (res *Resolver) Axfr(q *Msg, m chan Xfr) { - var ( - port string - inb []byte - ) - in := new(Msg) +// Some assorted checks on the resolver +func check(res *Resolver, q *Msg) (port string, err os.Error) { if res.Port == "" { port = "53" } else { @@ -356,240 +137,8 @@ func (res *Resolver) Axfr(q *Msg, m chan Xfr) { if res.Rtt == nil { res.Rtt = make(map[string]int64) } - if q.Id == 0 { q.Id = Id() } - - defer close(m) - sending, ok := q.Pack() - if !ok { - return - } - -Server: - for i := 0; i < len(res.Servers); i++ { - server := res.Servers[i] + ":" + port - c, err := net.Dial("tcp", "", server) - if err != nil { - continue Server - } - first := true - defer c.Close() // TODO(mg): if not open? - for { - if first { - inb, err = exchangeTCP(c, sending, res, true) - in.Unpack(inb) - } else { - inb, err = exchangeTCP(c, sending, res, false) - in.Unpack(inb) - } - - if err != nil { - // Failed to send, try the next - c.Close() - continue Server - } - if in.Id != q.Id { - c.Close() - return - } - - if first { - if !checkAxfrSOA(in, true) { - c.Close() - continue Server - } - first = !first - } - - if !first { - if !checkAxfrSOA(in, false) { - // Soa record not the last one - sendFromMsg(in, m, false) - continue - } else { - sendFromMsg(in, m, true) - return - } - } - } - panic("not reached") - return - } return } - -// Send a request on the connection and hope for a reply. -// Up to res.Attempts attempts. If send is false, nothing -// is send. -func exchangeUDP(c net.Conn, m []byte, r *Resolver, send bool) ([]byte, os.Error) { - var timeout int64 - var attempts int - if r.Mangle != nil { - m = r.Mangle(m) - } - if r.Timeout == 0 { - timeout = 1 - } else { - timeout = int64(r.Timeout) - } - if r.Attempts == 0 { - attempts = 1 - } else { - attempts = r.Attempts - } - for a := 0; a < attempts; a++ { - if send { - err := sendUDP(m, c) - if err != nil { - if e, ok := err.(net.Error); ok && e.Timeout() { - continue - } - return nil, err - } - } - - c.SetReadTimeout(timeout * 1e9) // nanoseconds - buf, err := recvUDP(c) - if err != nil { - if e, ok := err.(net.Error); ok && e.Timeout() { - continue - } - return nil, err - } - return buf, nil - } - return nil, &Error{Error: ErrServ} -} - -// Up to res.Attempts attempts. -func exchangeTCP(c net.Conn, m []byte, r *Resolver, send bool) ([]byte, os.Error) { - var timeout int64 - var attempts int - if r.Mangle != nil { - m = r.Mangle(m) - } - if r.Timeout == 0 { - timeout = 1 - } else { - timeout = int64(r.Timeout) - } - if r.Attempts == 0 { - attempts = 1 - } else { - attempts = r.Attempts - } - - for a := 0; a < attempts; a++ { - // only send something when told so - if send { - err := sendTCP(m, c) - if err != nil { - if e, ok := err.(net.Error); ok && e.Timeout() { - continue - } - return nil, err - } - } - - c.SetReadTimeout(timeout * 1e9) // nanoseconds - // The server replies with two bytes length. - buf, err := recvTCP(c) - if err != nil { - if e, ok := err.(net.Error); ok && e.Timeout() { - continue - } - return nil, err - } - return buf, nil - } - return nil, &Error{Error: ErrServ} -} - -func sendUDP(m []byte, c net.Conn) os.Error { - _, err := c.Write(m) - if err != nil { - return err - } - return nil -} - -func recvUDP(c net.Conn) ([]byte, os.Error) { - m := make([]byte, DefaultMsgSize) - n, err := c.Read(m) - if err != nil { - return nil, err - } - m = m[:n] - return m, nil -} - -func sendTCP(m []byte, c net.Conn) os.Error { - l := make([]byte, 2) - l[0] = byte(len(m) >> 8) - l[1] = byte(len(m)) - // First we send the length - _, err := c.Write(l) - if err != nil { - return err - } - // And the the message - _, err = c.Write(m) - if err != nil { - return err - } - return nil -} - -func recvTCP(c net.Conn) ([]byte, os.Error) { - l := make([]byte, 2) // The server replies with two bytes length. - _, err := c.Read(l) - if err != nil { - return nil, err - } - length := uint16(l[0])<<8 | uint16(l[1]) - if length == 0 { - return nil, &Error{Error: "received nil msg length", Server: c.RemoteAddr().String()} - } - m := make([]byte, length) - n, cerr := c.Read(m) - if cerr != nil { - return nil, cerr - } - i := n - for i < int(length) { - n, err = c.Read(m[i:]) - if err != nil { - return nil, err - } - i += n - } - return m, nil -} - -// Check if he SOA record exists in the Answer section of -// the packet. If first is true the first RR must be a soa -// if false, the last one should be a SOA -func checkAxfrSOA(in *Msg, first bool) bool { - if len(in.Answer) > 0 { - if first { - return in.Answer[0].Header().Rrtype == TypeSOA - } else { - return in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA - } - } - return false -} - -// Send the answer section to the channel -func sendFromMsg(in *Msg, c chan Xfr, nosoa bool) { - x := Xfr{Add: true} - for k, r := range in.Answer { - if nosoa && k == len(in.Answer)-1 { - continue - } - x.RR = r - c <- x - } -} diff --git a/server.go b/server.go index 62241139..4ecdcc43 100644 --- a/server.go +++ b/server.go @@ -11,70 +11,65 @@ import ( "net" ) -type Server struct { - ServeUDP func(*net.UDPConn, net.Addr, *Msg) os.Error - ServeTCP func(*net.TCPConn, net.Addr, *Msg) os.Error - /* notify stuff here? */ - /* tsig here */ -} +// For both -> logging +// Add tsig stuff as in resolver.go -func ServeUDP(l *net.UDPConn, f func(*net.UDPConn, net.Addr, *Msg)) os.Error { +func HandleUDP(l *net.UDPConn, f func(*Conn, *Msg)) os.Error { for { m := make([]byte, DefaultMsgSize) - n, radd, e := l.ReadFromUDP(m) + n, addr, e := l.ReadFromUDP(m) if e != nil { continue } m = m[:n] + + d := new(Conn) + d.UDP = l + d.Addr = addr + d.Port = addr.Port // Why not the same as in dns.go, line 96 + msg := new(Msg) if !msg.Unpack(m) { continue } - go f(l, radd, msg) + go f(d, msg) } panic("not reached") } -func ServeTCP(l *net.TCPListener, f func(*net.TCPConn, net.Addr, *Msg)) os.Error { - b := make([]byte, 2) +func HandleTCP(l *net.TCPListener, f func(*Conn, *Msg)) os.Error { for { c, e := l.AcceptTCP() if e != nil { return e } - n, e := c.Read(b) - if e != nil { - continue - } + d := new(Conn) + d.TCP = c + d.Addr = c.RemoteAddr() + d.Port = d.TCP.RemoteAddr().(*net.TCPAddr).Port - length := uint16(b[0])<<8 | uint16(b[1]) - if length == 0 { - return &Error{Error: "received nil msg length"} - } - m := make([]byte, length) + m := d.NewBuffer() + n, e := d.Read(m) + if e != nil { + continue + } + m = m[:n] - n, e = c.Read(m) - if e != nil { - continue - } - i := n - if i < int(length) { - n, e = c.Read(m[i:]) - if e != nil { - continue - } - i += n - } msg := new(Msg) if !msg.Unpack(m) { + // Logging?? continue } - go f(c, c.RemoteAddr(), msg) + go f(d, msg) } panic("not reached") } -func ListenAndServeTCP(addr string, f func(*net.TCPConn, net.Addr, *Msg)) os.Error { +// config functions Config +// ListenAndServeTCPTsig +// ListenAndServeUDPTsig + +func ListenAndServeTCP(addr string, f func(*Conn, *Msg)) os.Error { a, err := net.ResolveTCPAddr(addr) if err != nil { return err @@ -83,11 +78,11 @@ func ListenAndServeTCP(addr string, f func(*net.TCPConn, net.Addr, *Msg)) os.Err if err != nil { return err } - err = ServeTCP(l, f) + err = HandleTCP(l, f) return err } -func ListenAndServeUDP(addr string, f func(*net.UDPConn, net.Addr, *Msg)) os.Error { +func ListenAndServeUDP(addr string, f func(*Conn, *Msg)) os.Error { a, err := net.ResolveUDPAddr(addr) if err != nil { return err @@ -96,42 +91,6 @@ func ListenAndServeUDP(addr string, f func(*net.UDPConn, net.Addr, *Msg)) os.Err if err != nil { return err } - err = ServeUDP(l, f) + err = HandleUDP(l, f) return err } - -// Send a buffer on the TCP connection. -func SendTCP(m []byte, c *net.TCPConn, a net.Addr) os.Error { - l := make([]byte, 2) - l[0] = byte(len(m) >> 8) - l[1] = byte(len(m)) - // First we send the length - n, err := c.Write(l) - if err != nil { - return err - } - // And the the message - n, err = c.Write(m) - if err != nil { - return err - } - i := n - for i < len(m) { - n, err = c.Write(m) - if err != nil { - return err - } - i += n - } - return nil -} - -// Send a buffer to the remove address. Only here because -// of the symmetry with SendTCP(). -func SendUDP(m []byte, c *net.UDPConn, a net.Addr) os.Error { - _, err := c.WriteTo(m, a) - if err != nil { - return err - } - return nil -} diff --git a/tsig.go b/tsig.go index 0c16032d..b1bfcd5e 100644 --- a/tsig.go +++ b/tsig.go @@ -4,12 +4,33 @@ package dns // RFC 2845 and RFC 4635 import ( "io" - "strconv" + "os" + "time" "strings" "crypto/hmac" "encoding/hex" ) +// Return os.Error with real tsig errors + +// Structure used in Read/Write lowlevel functions +// for TSIG generation and verification. +type Tsig struct { + // The name of the key. + Name string + Fudge uint16 + TimeSigned uint64 + Algorithm string + // Tsig secret encoded in base64. + Secret string + // MAC (if known) + MAC string + // Request MAC + RequestMAC string + // Only include the timers if true. + TimersOnly bool +} + // HMAC hashing codes. These are transmitted as domain names. const ( HmacMD5 = "hmac-md5.sig-alg.reg.int." @@ -17,46 +38,6 @@ const ( HmacSHA256 = "hmac-sha256." ) -type RR_TSIG struct { - Hdr RR_Header - Algorithm string "domain-name" - TimeSigned uint64 - Fudge uint16 - MACSize uint16 - MAC string "size-hex" - OrigId uint16 - Error uint16 - OtherLen uint16 - OtherData string "size-hex" -} - -func (rr *RR_TSIG) Header() *RR_Header { - return &rr.Hdr -} - -// move to defaults.go? -func (rr *RR_TSIG) SetDefaults() { - rr.Header().Ttl = 0 - rr.Header().Class = ClassANY - rr.Header().Rrtype = TypeTSIG - rr.Fudge = 300 - rr.Algorithm = HmacMD5 -} - -// TSIG has no official presentation format, but this will suffice. -func (rr *RR_TSIG) String() string { - return rr.Hdr.String() + - " " + rr.Algorithm + - " " + tsigTimeToDate(rr.TimeSigned) + - " " + strconv.Itoa(int(rr.Fudge)) + - " " + strconv.Itoa(int(rr.MACSize)) + - " " + strings.ToUpper(rr.MAC) + - " " + strconv.Itoa(int(rr.OrigId)) + - " " + strconv.Itoa(int(rr.Error)) + - " " + strconv.Itoa(int(rr.OtherLen)) + - " " + rr.OtherData -} - // The following values must be put in wireformat, so that the MAC can be calculated. // RFC 2845, section 3.4.2. TSIG Variables. type tsigWireFmt struct { @@ -87,123 +68,133 @@ type timerWireFmt struct { Fudge uint16 } -// Generate the HMAC for message. The TSIG RR is modified -// to include the MAC and MACSize. Note the the msg Id must -// already be set, otherwise the MAC will not be correct when -// the message is send. -// The string 'secret' must be encoded in base64. -func (t *RR_TSIG) Generate(m *Msg, secret string) bool { - rawsecret, err := packBase64([]byte(secret)) +// In a message and out a new message with the tsig added +func (t *Tsig) Generate(msg []byte) ([]byte, os.Error) { + rawsecret, err := packBase64([]byte(t.Secret)) if err != nil { - return false + return nil, err } - t.OrigId = m.MsgHdr.Id + if t.Fudge == 0 { + t.Fudge = 300 + } + if t.TimeSigned == 0 { + t.TimeSigned = uint64(time.Seconds()) + } - msg, ok := m.Pack() - if !ok { - return false - } - buf, ok1 := tsigToBuf(t, msg, "", true) - if !ok1 { - return false + buf, err := t.Buffer(msg) + if err != nil { + return nil, err } h := hmac.NewMD5([]byte(rawsecret)) io.WriteString(h, string(buf)) + t.MAC = hex.EncodeToString(h.Sum()) // Size is half! - t.MAC = hex.EncodeToString(h.Sum()) - t.MACSize = uint16(len(h.Sum())) // Needs to be "on-the-wire" size. - if !ok { - return false - } - return true + // Create TSIG and add it to the message. + q := new(Msg) + if !q.Unpack(msg) { + return nil, &Error{Error: "Failed to unpack"} + } + + rr := new(RR_TSIG) + rr.Hdr = RR_Header{Name: t.Name, Rrtype: TypeTSIG, Class: ClassANY, Ttl: 0} + rr.Fudge = t.Fudge + rr.TimeSigned = t.TimeSigned + rr.Algorithm = t.Algorithm + rr.OrigId = q.Id + rr.MAC = t.MAC + rr.MACSize = uint16(len(t.MAC) / 2) + + q.Extra = append(q.Extra, rr) + send, ok := q.Pack() + if !ok { + return send, &Error{Error: "Failed to pack"} + } + return send, nil } -// Verify a TSIG. The message should be the complete with -// the TSIG record still attached (as the last rr in the Additional -// section). Return true on success. -// The secret is a base64 encoded string with the secret. -func (t *RR_TSIG) Verify(msg []byte, secret, reqmac string, timers bool) bool { - rawsecret, err := packBase64([]byte(secret)) +// Verify a TSIG on a message. All relevant data should +// be set in the Tsig structure. +func (t *Tsig) Verify(msg []byte) (bool, os.Error) { + rawsecret, err := packBase64([]byte(t.Secret)) if err != nil { - return false + return false, err + } + // Stipped the TSIG from the incoming msg + stripped, ok := stripTsig(msg) + if !ok { + return false, &Error{Error: "Failed to strip tsig"} } - if t.Header().Rrtype != TypeTSIG { - return false + buf,err := t.Buffer(stripped) + if err != nil { + return false, err } - // t.OrigId -- need to check - stripped, ok := stripTSIG(msg) - if !ok { - return false - } - buf, ok := tsigToBuf(t, stripped, reqmac, timers) - if !ok { - return false - } + // Time needs to be checked */ + // Generic time error h := hmac.NewMD5([]byte(rawsecret)) io.WriteString(h, string(buf)) - return strings.ToUpper(hex.EncodeToString(h.Sum())) == strings.ToUpper(t.MAC) + return strings.ToUpper(hex.EncodeToString(h.Sum())) == strings.ToUpper(t.MAC), nil } -// Create the buffer which we use for the MAC calculation. -func tsigToBuf(rr *RR_TSIG, msg []byte, reqmac string, timers bool) ([]byte, bool) { +// Create a wiredata buffer for the MAC calculation +func (t *Tsig) Buffer(msg []byte) ([]byte, os.Error) { var ( macbuf []byte buf []byte ) - if reqmac != "" { + if t.RequestMAC != "" { m := new(macWireFmt) - m.MACSize = uint16(len(reqmac) / 2) - m.MAC = reqmac - macbuf = make([]byte, len(reqmac)) // reqmac should be twice as long + m.MACSize = uint16(len(t.RequestMAC) / 2) + m.MAC = t.RequestMAC + macbuf = make([]byte, len(t.RequestMAC)) // reqmac should be twice as long n, ok := packStruct(m, macbuf, 0) if !ok { - return nil, false + return nil, &Error{Error: "Failed to pack request mac"} } macbuf = macbuf[:n] } tsigvar := make([]byte, DefaultMsgSize) - if timers { - tsig := new(tsigWireFmt) - tsig.Name = strings.ToLower(rr.Header().Name) - tsig.Class = rr.Header().Class - tsig.Ttl = rr.Header().Ttl - tsig.Algorithm = strings.ToLower(rr.Algorithm) - tsig.TimeSigned = rr.TimeSigned - tsig.Fudge = rr.Fudge - tsig.Error = rr.Error - tsig.OtherLen = rr.OtherLen - tsig.OtherData = rr.OtherData + if t.TimersOnly { + tsig := new(timerWireFmt) + tsig.TimeSigned = t.TimeSigned + tsig.Fudge = t.Fudge n, ok1 := packStruct(tsig, tsigvar, 0) if !ok1 { - return nil, false + return nil, &Error{Error: "Failed to pack timers"} } tsigvar = tsigvar[:n] } else { - tsig := new(timerWireFmt) - tsig.TimeSigned = rr.TimeSigned - tsig.Fudge = rr.Fudge + tsig := new(tsigWireFmt) + tsig.Name = strings.ToLower(t.Name) + tsig.Class = ClassANY + tsig.Ttl = 0 + tsig.Algorithm = strings.ToLower(t.Algorithm) + tsig.TimeSigned = t.TimeSigned + tsig.Fudge = t.Fudge + tsig.Error = 0 + tsig.OtherLen = 0 + tsig.OtherData = "" n, ok1 := packStruct(tsig, tsigvar, 0) if !ok1 { - return nil, false + return nil, &Error{Error: "Failed to pack tsig variables"} } tsigvar = tsigvar[:n] } - if reqmac != "" { + if t.RequestMAC != "" { x := append(macbuf, msg...) buf = append(x, tsigvar...) } else { buf = append(msg, tsigvar...) } - return buf, true + return buf, nil } // Strip the TSIG from the pkt. -func stripTSIG(orig []byte) ([]byte, bool) { +func stripTsig(orig []byte) ([]byte, bool) { // Copied from msg.go's Unpack() // Header. var dh Header diff --git a/types.go b/types.go index 774bf4a7..cdf3b620 100644 --- a/types.go +++ b/types.go @@ -739,6 +739,38 @@ func (rr *RR_DHCID) String() string { return rr.Hdr.String() + rr.Digest } +// RFC 2845. +type RR_TSIG struct { + Hdr RR_Header + Algorithm string "domain-name" + TimeSigned uint64 + Fudge uint16 + MACSize uint16 + MAC string "size-hex" + OrigId uint16 + Error uint16 + OtherLen uint16 + OtherData string "size-hex" +} + +func (rr *RR_TSIG) Header() *RR_Header { + return &rr.Hdr +} + +// TSIG has no official presentation format, but this will suffice. +func (rr *RR_TSIG) String() string { + return rr.Hdr.String() + + " " + rr.Algorithm + + " " + tsigTimeToDate(rr.TimeSigned) + + " " + strconv.Itoa(int(rr.Fudge)) + + " " + strconv.Itoa(int(rr.MACSize)) + + " " + strings.ToUpper(rr.MAC) + + " " + strconv.Itoa(int(rr.OrigId)) + + " " + strconv.Itoa(int(rr.Error)) + + " " + strconv.Itoa(int(rr.OtherLen)) + + " " + rr.OtherData +} + // Translate the RRSIG's incep. and expir. time to the correct date. // Taking into account serial arithmetic (RFC 1982) func timeToDate(t uint32) string { diff --git a/xfr.go b/xfr.go index 530accbd..4e5898a2 100644 --- a/xfr.go +++ b/xfr.go @@ -1,3 +1,223 @@ package dns +import ( + "os" +) + // Outgoing AXFR and IXFR implementations +// error handling?? + +// Xfr is used in communicating with *xfr functions. +// This structure is returned on the channel. +type Xfr struct { + Add bool // true is to be added, otherwise false + RR + Err os.Error +} + +// Msg tells use what to do +func (d *Conn) XfrRead(q *Msg, m chan Xfr) { + switch q.Question[0].Qtype { + case TypeAXFR: + d.axfrRead(q, m) + case TypeIXFR: + d.ixfrRead(q, m) + } +} + +func (d *Conn) XfrWrite(q *Msg, m chan Xfr) { + switch q.Question[0].Qtype { + case TypeAXFR: + d.axfrWrite(q, m) + case TypeIXFR: + // d.ixfrWrite(q, m) + } +} + +func (d *Conn) axfrRead(q *Msg, m chan Xfr) { + defer close(m) + first := true + in := new(Msg) + for { + inb := d.NewBuffer() + n, err := d.Read(inb) + if err != nil { + m <- Xfr{true, nil, err} + return + } + inb = inb[:n] + + if !in.Unpack(inb) { + m <- Xfr{true, nil, &Error{Error: "Failed to unpack"}} + return + } + if in.Id != q.Id { + m <- Xfr{true, nil, &Error{Error: "Id mismatch"}} + return + } + + if first { + if !checkXfrSOA(in, true) { + m <- Xfr{true, nil, &Error{Error: "SOA not first record"}} + return + } + first = !first + } + + if !first { + if d.Tsig != nil { + d.Tsig.TimersOnly = true // Subsequent envelopes use this + } + if !checkXfrSOA(in, false) { + // Soa record not the last one + sendMsg(in, m, false) + continue + } else { + sendMsg(in, m, true) + return + } + } + } + panic("not reached") + return +} + +// Just send the zone +func (d *Conn) axfrWrite(q *Msg, m chan Xfr) { + out := new(Msg) + out.Id = q.Id + out.Question = q.Question + out.Answer = make([]RR, 1000) + var soa *RR_SOA + i := 0 + for r := range m { + out.Answer[i] = r.RR + if soa == nil { + if r.RR.Header().Rrtype != TypeSOA { + return + } else { + soa = r.RR.(*RR_SOA) + } + } + i++ + if i > 1000 { + // Send it + send, _ := out.Pack() + _, err := d.Write(send) + if err != nil { + /* ... */ + } + i = 0 + out.Answer = out.Answer[:0] + } + // TimersOnly foo + } + // Everything is sent, only the closing soa is left. + out.Answer[i] = soa + send, _ := out.Pack() + _, err := d.Write(send) + if err != nil { + /* ... */ + } +} + +func (d *Conn) ixfrRead(q *Msg, m chan Xfr) { + defer close(m) + var serial uint32 // The first serial seen is the current server serial + var x Xfr + first := true + in := new(Msg) + for { + inb := d.NewBuffer() + n, err := d.Read(inb) + if err != nil { + m <- Xfr{true, nil, err} + return + } + inb = inb[:n] + + if !in.Unpack(inb) { + m <- Xfr{true, nil, &Error{Error: "Failed to unpack"}} + return + } + if in.Id != q.Id { + m <- Xfr{true, nil, &Error{Error: "Id mismatch"}} + return + } + + if first { + // A single SOA RR signals "no changes" + if len(in.Answer) == 1 && checkXfrSOA(in, true) { + return + } + + // But still check if the returned answer is ok + if !checkXfrSOA(in, true) { + m <- Xfr{true, nil, &Error{Error: "SOA not first record"}} + return + } + // This serial is important + serial = in.Answer[0].(*RR_SOA).Serial + first = !first + } + + // Now we need to check each message for SOA records, to see what we need to do + x.Add = true + if !first { + if d.Tsig != nil { + d.Tsig.TimersOnly = true + } + for k, r := range in.Answer { + // If the last record in the IXFR contains the servers' SOA, we should quit + if r.Header().Rrtype == TypeSOA { + switch { + case r.(*RR_SOA).Serial == serial: + if k == len(in.Answer)-1 { + // last rr is SOA with correct serial + //m <- r dont' send it + return + } + x.Add = true + if k != 0 { + // Intermediate SOA + continue + } + case r.(*RR_SOA).Serial != serial: + x.Add = false + continue // Don't need to see this SOA + } + } + x.RR = r + m <- x + } + } + } + panic("not reached") + return +} + +// Check if he SOA record exists in the Answer section of +// the packet. If first is true the first RR must be a soa +// if false, the last one should be a SOA +func checkXfrSOA(in *Msg, first bool) bool { + if len(in.Answer) > 0 { + if first { + return in.Answer[0].Header().Rrtype == TypeSOA + } else { + return in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA + } + } + return false +} + +// Send the answer section to the channel +func sendMsg(in *Msg, c chan Xfr, nosoa bool) { + x := Xfr{Add: true} + for k, r := range in.Answer { + if nosoa && k == len(in.Answer)-1 { + continue + } + x.RR = r + c <- x + } +}