commit f67087a17651bea65fd882785d4fd88fb76919fd Author: Miek Gieben Date: Tue Aug 3 23:57:59 2010 +0200 added files diff --git a/Changes b/Changes new file mode 100644 index 00000000..989a612f --- /dev/null +++ b/Changes @@ -0,0 +1,6 @@ +Add resolver type and use the for querying +IPv6 support +DNSSEC support +uint8 support +base64 support (only for unpacking atm) +Split of the type definition into dnstypes.go diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..3064e26c --- /dev/null +++ b/Makefile @@ -0,0 +1,18 @@ +# Copyright 2009 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include $(GOROOT)/src/Make.$(GOARCH) + +TARG=dns +GOFILES=\ + parse.go\ + dns.go\ + dnsmsg.go\ + dnsconfig.go\ + dnstypes.go\ + +include $(GOROOT)/src/Make.pkg + +restest: restest.go + 6g -I _obj restest.go && 6l -L _obj -o restest restest.6 diff --git a/TODO b/TODO new file mode 100644 index 00000000..5ad17a50 --- /dev/null +++ b/TODO @@ -0,0 +1,4 @@ +EDNS -- add EDNS0 support +DNSSEC - validation and the remaining records + DS, RRSIG, NSEC, etc. +Generic version of TryOneName diff --git a/dns.go b/dns.go new file mode 100644 index 00000000..c6216125 --- /dev/null +++ b/dns.go @@ -0,0 +1,311 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// DNS client: see RFC 1035. +// Has to be linked into package net for Dial. + +// TODO(rsc): +// Check periodically whether /etc/resolv.conf has changed. +// Could potentially handle many outstanding lookups faster. +// Could have a small cache. +// Random UDP source port (net.Dial should do that for us). +// Random request IDs. + +package dns + +import ( + "os" + "rand" + "time" + "net" +) + +// DnsError represents a DNS lookup error. +type DnsError struct { + Error string // description of the error + Name string // name looked for + Server string // server used + IsTimeout bool +} + +func (e *DnsError) String() string { + s := "lookup " + e.Name + if e.Server != "" { + s += " on " + e.Server + } + s += ": " + e.Error + return s +} + +func (e *DnsError) Timeout() bool { return e.IsTimeout } +func (e *DnsError) Temporary() bool { return e.IsTimeout } + +const noSuchHost = "no such host" + +type Resolver struct { + Servers []string // servers to use + Search []string // suffixes to append to local name + Ndots int // number of dots in name to trigger absolute lookup + Timeout int // seconds before giving up on packet + Attempts int // lost packets before giving up on server + Rotate bool // round robin among servers +} + +// Send a request on the connection and hope for a reply. +// Up to res.Attempts attempts. +func Exchange(res *Resolver, c net.Conn, name string, qtype uint16, qclass uint16) (*Msg, os.Error) { + if len(name) >= 256 { + return nil, &DnsError{Error: "name too long", Name: name} + } + out := new(Msg) + out.id = uint16(rand.Int()) ^ uint16(time.Nanoseconds()) + out.question = []Question{ + Question{name, qtype, qclass}, + } + out.recursion_desired = true + msg, ok := out.Pack() + if !ok { + return nil, &DnsError{Error: "internal error - cannot pack message", Name: name} + } + + for attempt := 0; attempt < res.Attempts; attempt++ { + n, err := c.Write(msg) + if err != nil { + return nil, err + } + + c.SetReadTimeout(int64(res.Timeout) * 1e9) // nanoseconds + // EDNS + buf := make([]byte, 2000) // More than enough. + n, err = c.Read(buf) + if err != nil { + // if e, ok := err.(Error); ok && e.Timeout() { + // continue + // } + return nil, err + } + buf = buf[0:n] + in := new(Msg) + if !in.Unpack(buf) || in.id != out.id { + continue + } + return in, nil + } + var server string + if a := c.RemoteAddr(); a != nil { + server = a.String() + } + return nil, &DnsError{Error: "no answer from server", Name: name, Server: server, IsTimeout: true} +} + +// Find answer for name in dns message. +// On return, if err == nil, addrs != nil. +func answer(name, server string, dns *Msg, qtype uint16) (addrs []RR, err os.Error) { + addrs = make([]RR, 0, len(dns.answer)) + + if dns.rcode == RcodeNameError && dns.recursion_available { + return nil, &DnsError{Error: noSuchHost, Name: name} + } + if dns.rcode != RcodeSuccess { + // None of the error codes make sense + // for the query we sent. If we didn't get + // a name error and we didn't get success, + // the server is behaving incorrectly. + return nil, &DnsError{Error: "server misbehaving", Name: name, Server: server} + } + + // Look for the name. + // Presotto says it's okay to assume that servers listed in + // /etc/resolv.conf are recursive resolvers. + // We asked for recursion, so it should have included + // all the answers we need in this one packet. +Cname: + for cnameloop := 0; cnameloop < 10; cnameloop++ { + addrs = addrs[0:0] + for i := 0; i < len(dns.answer); i++ { + rr := dns.answer[i] + h := rr.Header() + if h.Class == ClassINET && h.Name == name { + switch h.Rrtype { + case qtype: + n := len(addrs) + addrs = addrs[0 : n+1] + addrs[n] = rr + case TypeCNAME: + // redirect to cname + name = rr.(*RR_CNAME).Cname + continue Cname + } + } + } + if len(addrs) == 0 { + return nil, &DnsError{Error: noSuchHost, Name: name, Server: server} + } + return addrs, nil + } + + return nil, &DnsError{Error: "too many redirects", Name: name, Server: server} +} + +// Look up a single name + +func (res *Resolver) Query(name string, qtype uint16, qclass uint16) (msg *Msg, err os.Error) { + if len(res.Servers) == 0 { + return nil, &DnsError{Error: "no DNS servers", Name: name} + } + for i := 0; i < len(res.Servers); i++ { + // Calling Dial here is scary -- we have to be sure + // not to dial a name that will require a DNS lookup, + // or Dial will call back here to translate it. + // The DNS config parser has already checked that + // all the res.Servers[i] are IP addresses, which + // Dial will use without a DNS lookup. + server := res.Servers[i] + ":53" + c, cerr := net.Dial("udp", "", server) + if cerr != nil { + err = cerr + continue + } + msg, err = Exchange(res, c, name, qtype, qclass) + c.Close() + if err != nil { + continue + } + } + return +} + +// Do a lookup for a single name, which must be rooted +// (otherwise answer will not find the answers). +func (res *Resolver) TryOneName(name string, qtype uint16) (addrs []RR, err os.Error) { + if len(res.Servers) == 0 { + return nil, &DnsError{Error: "no DNS servers", Name: name} + } + for i := 0; i < len(res.Servers); i++ { + // Calling Dial here is scary -- we have to be sure + // not to dial a name that will require a DNS lookup, + // or Dial will call back here to translate it. + // The DNS config parser has already checked that + // all the res.Servers[i] are IP addresses, which + // Dial will use without a DNS lookup. + server := res.Servers[i] + ":53" + c, cerr := net.Dial("udp", "", server) + if cerr != nil { + err = cerr + continue + } + msg, merr := Exchange(res, c, name, qtype, ClassINET) + c.Close() + if merr != nil { + err = merr + continue + } + addrs, err = answer(name, server, msg, qtype) + if err == nil || err.(*DnsError).Error == noSuchHost { + break + } + } + return +} + +var res *Resolver +var dnserr os.Error + +func isDomainName(s string) bool { + // Requirements on DNS name: + // * must not be empty. + // * must be alphanumeric plus - and . + // * each of the dot-separated elements must begin + // and end with a letter or digit. + // RFC 1035 required the element to begin with a letter, + // but RFC 3696 says this has been relaxed to allow digits too. + // still, there must be a letter somewhere in the entire name. + if len(s) == 0 { + return false + } + if s[len(s)-1] != '.' { // simplify checking loop: make name end in dot + s += "." + } + + last := byte('.') + ok := false // ok once we've seen a letter + for i := 0; i < len(s); i++ { + c := s[i] + switch { + default: + return false + case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z': + ok = true + case '0' <= c && c <= '9': + // fine + case c == '-': + // byte before dash cannot be dot + if last == '.' { + return false + } + case c == '.': + // byte before dot cannot be dot, dash + if last == '.' || last == '-' { + return false + } + } + last = c + } + + return ok +} + +func lookup(name string, qtype uint16) (cname string, addrs []RR, err os.Error) { + if !isDomainName(name) { + return name, nil, &DnsError{Error: "invalid domain name", Name: name} + } + + if dnserr != nil || res == nil { + err = dnserr + return + } + // If name is rooted (trailing dot) or has enough dots, + // try it by itself first. + rooted := len(name) > 0 && name[len(name)-1] == '.' + if rooted || count(name, '.') >= res.Ndots { + rname := name + if !rooted { + rname += "." + } + // Can try as ordinary name. + addrs, err = res.TryOneName(rname, qtype) + if err == nil { + cname = rname + return + } + } + if rooted { + return + } + + // Otherwise, try suffixes. + for i := 0; i < len(res.Search); i++ { + rname := name + "." + res.Search[i] + if rname[len(rname)-1] != '.' { + rname += "." + } + addrs, err = res.TryOneName(rname, qtype) + if err == nil { + cname = rname + return + } + } + + // Last ditch effort: try unsuffixed. + rname := name + if !rooted { + rname += "." + } + addrs, err = res.TryOneName(rname, qtype) + if err == nil { + cname = rname + return + } + return +} diff --git a/dnsconfig.go b/dnsconfig.go new file mode 100644 index 00000000..536fe397 --- /dev/null +++ b/dnsconfig.go @@ -0,0 +1,105 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Read system DNS config from /etc/resolv.conf +// Or any other file with the same format. + +package dns + +import ( + "os" + "net" +) + +// See resolv.conf(5) on a Linux machine. +// TODO(rsc): Supposed to call uname() and chop the beginning +// of the host name to get the default search domain. +// We assume it's in resolv.conf anyway. +func ReadConfig(f string) (*Resolver, os.Error) { + if f == "" { + f = "/etc/resolv.conf" + } + + file, err := open(f) + if err != nil { + return nil, err // DnsError + } + r := new(Resolver) + r.Servers = make([]string, 3)[0:0] // small, but the standard limit + r.Search = make([]string, 0) + r.Ndots = 1 + r.Timeout = 5 + r.Attempts = 2 + r.Rotate = false + for line, ok := file.readLine(); ok; line, ok = file.readLine() { + f := getFields(line) + if len(f) < 1 { + continue + } + switch f[0] { + case "nameserver": // add one name server + a := r.Servers + n := len(a) + if len(f) > 1 && n < cap(a) { + // One more check: make sure server name is + // just an IP address. Otherwise we need DNS + // to look it up. + name := f[1] + switch len(net.ParseIP(name)) { + case 16: + name = "[" + name + "]" + fallthrough + case 4: + a = a[0 : n+1] + a[n] = name + r.Servers = a + } + } + + case "domain": // set search path to just this domain + if len(f) > 1 { + r.Search = make([]string, 1) + r.Search[0] = f[1] + } else { + r.Search = make([]string, 0) + } + + case "search": // set search path to given servers + r.Search = make([]string, len(f)-1) + for i := 0; i < len(r.Search); i++ { + r.Search[i] = f[i+1] + } + + case "options": // magic options + for i := 1; i < len(f); i++ { + s := f[i] + switch { + case len(s) >= 6 && s[0:6] == "ndots:": + n, _, _ := dtoi(s, 6) + if n < 1 { + n = 1 + } + r.Ndots = n + case len(s) >= 8 && s[0:8] == "timeout:": + n, _, _ := dtoi(s, 8) + if n < 1 { + n = 1 + } + r.Timeout = n + case len(s) >= 8 && s[0:9] == "attempts:": + n, _, _ := dtoi(s, 9) + if n < 1 { + n = 1 + } + r.Attempts = n + case s == "rotate": + r.Rotate = true + } + } + } + } + file.close() + + return r, nil +} diff --git a/dnsmsg.go b/dnsmsg.go new file mode 100644 index 00000000..0eda4ba1 --- /dev/null +++ b/dnsmsg.go @@ -0,0 +1,631 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// DNS packet assembly. See RFC 1035. +// +// This is intended to support name resolution during net.Dial. +// It doesn't have to be blazing fast. +// +// Rather than write the usual handful of routines to pack and +// unpack every message that can appear on the wire, we use +// reflection to write a generic pack/unpack for structs and then +// use it. Thus, if in the future we need to define new message +// structs, no new pack/unpack/printing code needs to be written. +// +// The first half of this file defines the DNS message formats. +// The second half implements the conversion to and from wire format. +// A few of the structure elements have string tags to aid the +// generic pack/unpack routines. +// +// TODO(miekg): + +package dns + +import ( + "fmt" + "os" + "reflect" + "net" + "strconv" + "encoding/base64" + "encoding/hex" +) + +// Packing and unpacking. +// +// All the packers and unpackers take a (msg []byte, off int) +// and return (off1 int, ok bool). If they return ok==false, they +// also return off1==len(msg), so that the next unpacker will +// also fail. This lets us avoid checks of ok until the end of a +// packing sequence. + +// Map of strings for each RR wire type. +var class_str = map[uint16]string{ + ClassINET: "IN", + ClassCSNET: "CS", + ClassCHAOS: "CH", + ClassHESIOD: "HS", + ClassANY: "ANY", +} + +// Map of strings for opcodes. +var opcode_str = map[int]string{ + 0: "QUERY", +} + +// Map of strings for rcode +var rcode_str = map[int]string{ + 0: "NOERROR", + + + 3: "NXDOMAIN", +} + +// Pack a domain name s into msg[off:]. +// Domain names are a sequence of counted strings +// split at the dots. They end with a zero-length string. +func packDomainName(s string, msg []byte, off int) (off1 int, ok bool) { + // Add trailing dot to canonicalize name. + if n := len(s); n == 0 || s[n-1] != '.' { + s += "." + } + + // Each dot ends a segment of the name. + // We trade each dot byte for a length byte. + // There is also a trailing zero. + // Check that we have all the space we need. + tot := len(s) + 1 + if off+tot > len(msg) { + return len(msg), false + } + + // Emit sequence of counted strings, chopping at dots. + begin := 0 + for i := 0; i < len(s); i++ { + if s[i] == '.' { + if i-begin >= 1<<6 { // top two bits of length must be clear + return len(msg), false + } + msg[off] = byte(i - begin) + off++ + for j := begin; j < i; j++ { + msg[off] = s[j] + off++ + } + begin = i + 1 + } + } + msg[off] = 0 + off++ + return off, true +} + +// Unpack a domain name. +// In addition to the simple sequences of counted strings above, +// domain names are allowed to refer to strings elsewhere in the +// packet, to avoid repeating common suffixes when returning +// many entries in a single domain. The pointers are marked +// by a length byte with the top two bits set. Ignoring those +// two bits, that byte and the next give a 14 bit offset from msg[0] +// where we should pick up the trail. +// Note that if we jump elsewhere in the packet, +// we return off1 == the offset after the first pointer we found, +// which is where the next record will start. +// In theory, the pointers are only allowed to jump backward. +// We let them jump anywhere and stop jumping after a while. +func unpackDomainName(msg []byte, off int) (s string, off1 int, ok bool) { + s = "" + ptr := 0 // number of pointers followed +Loop: + for { + if off >= len(msg) { + return "", len(msg), false + } + c := int(msg[off]) + off++ + switch c & 0xC0 { + case 0x00: + if c == 0x00 { + // end of name + break Loop + } + // literal string + if off+c > len(msg) { + return "", len(msg), false + } + s += string(msg[off:off+c]) + "." + off += c + case 0xC0: + // pointer to somewhere else in msg. + // remember location after first ptr, + // since that's how many bytes we consumed. + // also, don't follow too many pointers -- + // maybe there's a loop. + if off >= len(msg) { + return "", len(msg), false + } + c1 := msg[off] + off++ + if ptr == 0 { + off1 = off + } + if ptr++; ptr > 10 { + return "", len(msg), false + } + off = (c^0xC0)<<8 | int(c1) + default: + // 0x80 and 0x40 are reserved + return "", len(msg), false + } + } + if ptr == 0 { + off1 = off + } + return s, off1, true +} + +// TODO(rsc): Move into generic library? +// Pack a reflect.StructValue into msg. Struct members can only be uint16, uint32, string, +// and other (often anonymous) structs. +// IPV6 IPV4 still to do +func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, ok bool) { + for i := 0; i < val.NumField(); i++ { + f := val.Type().(*reflect.StructType).Field(i) + switch fv := val.Field(i).(type) { + default: + BadType: + fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type) + return len(msg), false + case *reflect.StructValue: + off, ok = packStructValue(fv, msg, off) + case *reflect.UintValue: + i := fv.Get() + switch fv.Type().Kind() { + default: + goto BadType + case reflect.Uint8: + if off+1 > len(msg) { + return len(msg), false + } + msg[off] = byte(i) + off++ + case reflect.Uint16: + if off+2 > len(msg) { + return len(msg), false + } + msg[off] = byte(i >> 8) + msg[off+1] = byte(i) + off += 2 + case reflect.Uint32: + if off+4 > len(msg) { + return len(msg), false + } + msg[off] = byte(i >> 24) + msg[off+1] = byte(i >> 16) + msg[off+2] = byte(i >> 8) + msg[off+3] = byte(i) + off += 4 + } + case *reflect.StringValue: + // There are multiple string encodings. + // The tag distinguishes ordinary strings from domain names. + s := fv.Get() + switch f.Tag { + default: + return len(msg), false + case "base64": + //TODO + case "domain-name": + off, ok = packDomainName(s, msg, off) + if !ok { + return len(msg), false + } + case "": + // Counted string: 1 byte length. + if len(s) > 255 || off+1+len(s) > len(msg) { + return len(msg), false + } + msg[off] = byte(len(s)) + off++ + for i := 0; i < len(s); i++ { + msg[off+i] = s[i] + } + off += len(s) + } + } + } + return off, true +} + +func structValue(any interface{}) *reflect.StructValue { + return reflect.NewValue(any).(*reflect.PtrValue).Elem().(*reflect.StructValue) +} + +func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { + off, ok = packStructValue(structValue(any), msg, off) + return off, ok +} + +// Unpack a reflect.StructValue from msg. +// Same restrictions as packStructValue. +func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, ok bool) { + for i := 0; i < val.NumField(); i++ { + f := val.Type().(*reflect.StructType).Field(i) + switch fv := val.Field(i).(type) { + default: + BadType: + fmt.Fprintf(os.Stderr, "net: dns: unknown packing type %v", f.Type) + return len(msg), false + case *reflect.SliceValue: + switch f.Tag { + default: + fmt.Fprintf(os.Stderr, "net: dns: unknown IP tag %v", f.Tag) + return len(msg), false + case "ipv4": + if off+net.IPv4len > len(msg) { + return len(msg), false + } + b := net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3]) + fv.Set(reflect.NewValue(b).(*reflect.SliceValue)) + off += net.IPv4len + case "ipv6": + if off+net.IPv6len > len(msg) { + return len(msg), false + } + p := make(net.IP, net.IPv6len) + copy(p, msg[off:off+net.IPv6len]) + b := net.IP(p) + fv.Set(reflect.NewValue(b).(*reflect.SliceValue)) + off += net.IPv6len + } + case *reflect.StructValue: + off, ok = unpackStructValue(fv, msg, off) + case *reflect.UintValue: + switch fv.Type().Kind() { + default: + goto BadType + case reflect.Uint8: + if off+1 > len(msg) { + return len(msg), false + } + i := uint8(msg[off]) + fv.Set(uint64(i)) + off++ + case reflect.Uint16: + if off+2 > len(msg) { + return len(msg), false + } + i := uint16(msg[off])<<8 | uint16(msg[off+1]) + fv.Set(uint64(i)) + off += 2 + case reflect.Uint32: + if off+4 > len(msg) { + return len(msg), false + } + i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3]) + fv.Set(uint64(i)) + off += 4 + } + case *reflect.StringValue: + var s string + switch f.Tag { + default: + fmt.Fprintf(os.Stderr, "net: dns: unknown string tag %v", f.Tag) + return len(msg), false + case "hex": + // Rest of the RR is hex encoded + rdlength := int(val.FieldByName("Hdr").(*reflect.StructValue).FieldByName("Rdlength").(*reflect.UintValue).Get()) + var consumed int + switch val.Type().Name() { + case "RR_DS": + consumed = 4 // KeyTag(2) + Algorithm(1) + DigestType(1) + default: + consumed = 0 // TODO + } + s = hex.EncodeToString(msg[off:off+rdlength-consumed]) + off += rdlength-consumed + case "base64": + // Rest of the RR is base64 encoded value + rdlength := int(val.FieldByName("Hdr").(*reflect.StructValue).FieldByName("Rdlength").(*reflect.UintValue).Get()) + // Need to know how much of rdlength is already consumed + var consumed int + // Can't I figure out via reflect how many bytes there are already consumed?? + switch val.Type().Name() { + case "RR_DNSKEY": + consumed = 4 // Flags(2) + Protocol(1) + Algorithm(1) + case "RR_DS": + consumed = 4 // KeyTag(2) + Algorithm(1) + DigestType(1) + default: + consumed = 0 // TODO + } + b64 := make([]byte, base64.StdEncoding.EncodedLen(len(msg[off:off+rdlength-consumed]))) + base64.StdEncoding.Encode(b64, msg[off:off+rdlength-consumed]) + s = string(b64) + off += rdlength-consumed + case "domain-name": + s, off, ok = unpackDomainName(msg, off) + if !ok { + return len(msg), false + } + case "": + if off >= len(msg) || off+1+int(msg[off]) > len(msg) { + return len(msg), false + } + n := int(msg[off]) + off++ + b := make([]byte, n) + for i := 0; i < n; i++ { + b[i] = msg[off+i] + } + off += n + s = string(b) + } + fv.Set(s) + } + } + return off, true +} + +func unpackStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) { + off, ok = unpackStructValue(structValue(any), msg, off) + return off, ok +} + +// Generic struct printer. +// Doesn't care about the string tag "domain-name", +// but does look for an "ipv4" tag on uint32 variables, +// printing them as IP addresses. +func printStructValue(val *reflect.StructValue) string { + s := "{" + for i := 0; i < val.NumField(); i++ { + if i > 0 { + s += ", " + } + f := val.Type().(*reflect.StructType).Field(i) + if !f.Anonymous { + s += f.Name + "=" + } + fval := val.Field(i) + if fv, ok := fval.(*reflect.StructValue); ok { + s += printStructValue(fv) + } else if fv, ok := fval.(*reflect.UintValue); ok && f.Tag == "ipv4" { + i := fv.Get() + s += net.IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String() + } else { + s += fmt.Sprint(fval.Interface()) + } + } + s += "}" + return s +} + +func PrintStruct(any interface{}) string { return printStructValue(structValue(any)) } + +// Resource record packer. +func packRR(rr RR, msg []byte, off int) (off2 int, ok bool) { + var off1 int + // pack twice, once to find end of header + // and again to find end of packet. + // a bit inefficient but this doesn't need to be fast. + // off1 is end of header + // off2 is end of rr + off1, ok = packStruct(rr.Header(), msg, off) + off2, ok = packStruct(rr, msg, off) + if !ok { + return len(msg), false + } + // pack a third time; redo header with correct data length + rr.Header().Rdlength = uint16(off2 - off1) + packStruct(rr.Header(), msg, off) + return off2, true +} + +// Resource record unpacker. +func unpackRR(msg []byte, off int) (rr RR, off1 int, ok bool) { + // unpack just the header, to find the rr type and length + var h RR_Header + off0 := off + if off, ok = unpackStruct(&h, msg, off); !ok { + return nil, len(msg), false + } + end := off + int(h.Rdlength) + + // make an rr of that type and re-unpack. + // again inefficient but doesn't need to be fast. + mk, known := rr_mk[int(h.Rrtype)] + if !known { + return &h, end, true + } + rr = mk() + off, ok = unpackStruct(rr, msg, off0) + if off != end { + return &h, end, true + } + return rr, off, ok +} + +// Usable representation of a DNS packet. + +// A manually-unpacked version of (id, bits). +// This is in its own struct for easy printing. +type MsgHdr struct { + id uint16 + response bool + opcode int + authoritative bool + truncated bool + recursion_desired bool + recursion_available bool + rcode int +} + +//;; ->>HEADER<<- opcode: QUERY, status: NOERROR, id: 48404 +//;; flags: qr aa rd ra; +func (h *MsgHdr) String() string { + s := ";; ->>HEADER<<- opcode: " + opcode_str[h.opcode] + s += ", status: " + rcode_str[h.rcode] + s += ", id: " + strconv.Itoa(int(h.id)) + "\n" + + s += ";; flags: " + if h.authoritative { + s += "aa " + } + if h.truncated { + s += "tc " + } + if h.recursion_desired { + s += "rd " + } + if h.recursion_available { + s += "ra " + } + s += ";" + return s +} + +type Msg struct { + MsgHdr + question []Question + edns []Edns + answer []RR + ns []RR + extra []RR +} + + +func (dns *Msg) Pack() (msg []byte, ok bool) { + var dh Header + + // Convert convenient Msg into wire-like Header. + dh.Id = dns.id + dh.Bits = uint16(dns.opcode)<<11 | uint16(dns.rcode) + if dns.recursion_available { + dh.Bits |= _RA + } + if dns.recursion_desired { + dh.Bits |= _RD + } + if dns.truncated { + dh.Bits |= _TC + } + if dns.authoritative { + dh.Bits |= _AA + } + if dns.response { + dh.Bits |= _QR + } + + // Prepare variable sized arrays. + question := dns.question + answer := dns.answer + ns := dns.ns + extra := dns.extra + + dh.Qdcount = uint16(len(question)) + dh.Ancount = uint16(len(answer)) + dh.Nscount = uint16(len(ns)) + dh.Arcount = uint16(len(extra)) + + // Could work harder to calculate message size, + // but this is far more than we need and not + // big enough to hurt the allocator. + msg = make([]byte, 2000) + + // Pack it in: header and then the pieces. + off := 0 + off, ok = packStruct(&dh, msg, off) + for i := 0; i < len(question); i++ { + off, ok = packStruct(&question[i], msg, off) + } + for i := 0; i < len(answer); i++ { + off, ok = packRR(answer[i], msg, off) + } + for i := 0; i < len(ns); i++ { + off, ok = packRR(ns[i], msg, off) + } + for i := 0; i < len(extra); i++ { + off, ok = packRR(extra[i], msg, off) + } + if !ok { + return nil, false + } + return msg[0:off], true +} + +func (dns *Msg) Unpack(msg []byte) bool { + // Header. + var dh Header + off := 0 + var ok bool + if off, ok = unpackStruct(&dh, msg, off); !ok { + return false + } + dns.id = dh.Id + dns.response = (dh.Bits & _QR) != 0 + dns.opcode = int(dh.Bits>>11) & 0xF + dns.authoritative = (dh.Bits & _AA) != 0 + dns.truncated = (dh.Bits & _TC) != 0 + dns.recursion_desired = (dh.Bits & _RD) != 0 + dns.recursion_available = (dh.Bits & _RA) != 0 + dns.rcode = int(dh.Bits & 0xF) + + // Arrays. + dns.question = make([]Question, dh.Qdcount) + dns.answer = make([]RR, dh.Ancount) + dns.ns = make([]RR, dh.Nscount) + dns.extra = make([]RR, dh.Arcount) + + for i := 0; i < len(dns.question); i++ { + off, ok = unpackStruct(&dns.question[i], msg, off) + } + for i := 0; i < len(dns.answer); i++ { + dns.answer[i], off, ok = unpackRR(msg, off) + } + for i := 0; i < len(dns.ns); i++ { + dns.ns[i], off, ok = unpackRR(msg, off) + } + for i := 0; i < len(dns.extra); i++ { + dns.extra[i], off, ok = unpackRR(msg, off) + } + if !ok { + return false + } + if off != len(msg) { + println("extra bytes in dns packet", off, "<", len(msg)) + } + return true +} + +func (dns *Msg) String() string { + s := dns.MsgHdr.String() + " " + s += "QUERY: " + strconv.Itoa(len(dns.question)) + ", " + s += "ANSWER: " + strconv.Itoa(len(dns.answer)) + ", " + s += "AUTHORITY: " + strconv.Itoa(len(dns.ns)) + ", " + s += "ADDITIONAL: " + strconv.Itoa(len(dns.extra)) + "\n" + if len(dns.question) > 0 { + s += "\n;; QUESTION SECTION:\n" + for i := 0; i < len(dns.question); i++ { + s += dns.question[i].String() + "\n" + } + } + if len(dns.answer) > 0 { + s += "\n;; ANSWER SECTION:\n" + for i := 0; i < len(dns.answer); i++ { + s += dns.answer[i].String() + "\n" + } + } + if len(dns.ns) > 0 { + s += "\n;; AUTHORITY SECTION:\n" + for i := 0; i < len(dns.ns); i++ { + s += dns.ns[i].String() + "\n" + } + } + if len(dns.extra) > 0 { + s += "\n;; ADDITIONAL SECTION:\n" + for i := 0; i < len(dns.extra); i++ { + s += dns.extra[i].String() + "\n" + } + } + return s +} diff --git a/dnstypes.go b/dnstypes.go new file mode 100644 index 00000000..18dd5ca1 --- /dev/null +++ b/dnstypes.go @@ -0,0 +1,509 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// DNS Resource Records Types. See RFC 1035 and ... +// + +package dns + +import ( + "net" + "strconv" +) + +// Packet formats + +// Wire constants. +const ( + // valid RR_Header.Rrtype and Question.qtype + TypeA = 1 + TypeNS = 2 + TypeMD = 3 + TypeMF = 4 + TypeCNAME = 5 + TypeSOA = 6 + TypeMB = 7 + TypeMG = 8 + TypeMR = 9 + TypeNULL = 10 + TypeWKS = 11 + TypePTR = 12 + TypeHINFO = 13 + TypeMINFO = 14 + TypeMX = 15 + TypeTXT = 16 + TypeAAAA = 28 + TypeSRV = 33 + + // EDNS + TypeOPT = 41 + + // DNSSEC + TypeDS = 43 + TypeRRSIG = 46 + TypeNSEC = 47 + TypeDNSKEY = 48 + TypeNSEC3 = 50 + TypeNSEC3PARAM = 51 + + // valid Question.qtype only + TypeAXFR = 252 + TypeMAILB = 253 + TypeMAILA = 254 + TypeALL = 255 + + // valid Question.qclass + ClassINET = 1 + ClassCSNET = 2 + ClassCHAOS = 3 + ClassHESIOD = 4 + ClassANY = 255 + + // Msg.rcode + RcodeSuccess = 0 + RcodeFormatError = 1 + RcodeServerFailure = 2 + RcodeNameError = 3 + RcodeNotImplemented = 4 + RcodeRefused = 5 +) + +// The wire format for the DNS packet header. +type Header struct { + Id uint16 + Bits uint16 + Qdcount, Ancount, Nscount, Arcount uint16 +} + +const ( + // Header.Bits + _QR = 1 << 15 // query/response (response=1) + _AA = 1 << 10 // authoritative + _TC = 1 << 9 // truncated + _RD = 1 << 8 // recursion desired + _RA = 1 << 7 // recursion available + // _AD = 1 << ? // authenticated data + // _CD = 1 << ? // checking disabled +) + +const ( + // DNSSEC algorithms + AlgRSAMD5 = 1 + AlgDH = 2 + AlgDSA = 3 + AlgECC = 4 + AlgRSASHA1 = 5 + AlgRSASHA256 = 8 + AlgRSASHA512 = 10 + AlgECCGOST = 12 +) + +// DNS queries. +type Question struct { + Name string "domain-name" // "domain-name" specifies encoding; see packers below + Qtype uint16 + Qclass uint16 +} + +// Rcode needs some setting and getting work for _z and _version +type Edns struct { + Name string "domain-name" + Opt uint16 // was type + UDPSize uint16 // was class + Rcode uint32 // was TTL + Rdlength uint16 +} + + +func (q *Question) String() string { + // prefix with ; (as in dig) + s := ";" + q.Name + "\t" + s = s + class_str[q.Qclass] + "\t" + s = s + rr_str[q.Qtype] + return s +} + +// DNS responses (resource records). +// There are many types of messages, +// but they all share the same header. +type RR_Header struct { + Name string "domain-name" + Rrtype uint16 + Class uint16 + Ttl uint32 + Rdlength uint16 // length of data after header +} + +func (h *RR_Header) Header() *RR_Header { + return h +} + +func (h *RR_Header) String() string { + var s string + if len(h.Name) == 0 { + s = ".\t" + } else { + s = h.Name + "\t" + } + s = s + strconv.Itoa(int(h.Ttl)) + "\t" // why no strconv.Uint16?? + s = s + class_str[h.Class] + "\t" + s = s + rr_str[h.Rrtype] + "\t" + return s +} + +type RR interface { + Header() *RR_Header + String() string +} + +// Specific DNS RR formats for each query type. + +type RR_CNAME struct { + Hdr RR_Header + Cname string "domain-name" +} + +func (rr *RR_CNAME) Header() *RR_Header { + return &rr.Hdr +} + +func (rr *RR_CNAME) String() string { + return rr.Hdr.String() + rr.Cname +} + +type RR_HINFO struct { + Hdr RR_Header + Cpu string + Os string +} + +func (rr *RR_HINFO) Header() *RR_Header { + return &rr.Hdr +} + +func (rr *RR_HINFO) String() string { + return rr.Hdr.String() + rr.Cpu + " " + rr.Os +} + +type RR_MB struct { + Hdr RR_Header + Mb string "domain-name" +} + +func (rr *RR_MB) Header() *RR_Header { + return &rr.Hdr +} + +func (rr *RR_MB) String() string { + return rr.Hdr.String() + rr.Mb +} + +type RR_MG struct { + Hdr RR_Header + Mg string "domain-name" +} + +func (rr *RR_MG) Header() *RR_Header { + return &rr.Hdr +} + +func (rr *RR_MG) String() string { + return rr.Hdr.String() + rr.Mg +} + +type RR_MINFO struct { + Hdr RR_Header + Rmail string "domain-name" + Email string "domain-name" +} + +func (rr *RR_MINFO) Header() *RR_Header { + return &rr.Hdr +} + +func (rr *RR_MINFO) String() string { + return rr.Hdr.String() + rr.Rmail + " " + rr.Email +} + +type RR_MR struct { + Hdr RR_Header + Mr string "domain-name" +} + +func (rr *RR_MR) Header() *RR_Header { + return &rr.Hdr +} + +func (rr *RR_MR) String() string { + return rr.Hdr.String() + rr.Mr +} + +type RR_MX struct { + Hdr RR_Header + Pref uint16 + Mx string "domain-name" +} + +func (rr *RR_MX) Header() *RR_Header { + return &rr.Hdr +} + +func (rr *RR_MX) String() string { + return rr.Hdr.String() + strconv.Itoa(int(rr.Pref)) + " " + rr.Mx +} + +type RR_NS struct { + Hdr RR_Header + Ns string "domain-name" +} + +func (rr *RR_NS) Header() *RR_Header { + return &rr.Hdr +} + +func (rr *RR_NS) String() string { + return rr.Hdr.String() + rr.Ns +} + +type RR_PTR struct { + Hdr RR_Header + Ptr string "domain-name" +} + +func (rr *RR_PTR) Header() *RR_Header { + return &rr.Hdr +} + +func (rr *RR_PTR) String() string { + return rr.Hdr.String() + rr.Ptr +} + +type RR_SOA struct { + Hdr RR_Header + Ns string "domain-name" + Mbox string "domain-name" + Serial uint32 + Refresh uint32 + Retry uint32 + Expire uint32 + Minttl uint32 +} + +func (rr *RR_SOA) Header() *RR_Header { + return &rr.Hdr +} + +func (rr *RR_SOA) String() string { + return rr.Hdr.String() + rr.Ns + " " + rr.Mbox + + " " + strconv.Itoa(int(rr.Serial)) + + " " + strconv.Itoa(int(rr.Refresh)) + + " " + strconv.Itoa(int(rr.Retry)) + + " " + strconv.Itoa(int(rr.Expire)) + + " " + strconv.Itoa(int(rr.Minttl)) +} + +type RR_TXT struct { + Hdr RR_Header + Txt string // not domain name +} + +func (rr *RR_TXT) Header() *RR_Header { + return &rr.Hdr +} + +func (rr *RR_TXT) String() string { + return rr.Hdr.String() + "\"" + rr.Txt + "\"" +} + +type RR_SRV struct { + Hdr RR_Header + Priority uint16 + Weight uint16 + Port uint16 + Target string "domain-name" +} + +func (rr *RR_SRV) Header() *RR_Header { + return &rr.Hdr +} + +func (rr *RR_SRV) String() string { + return rr.Hdr.String() + + strconv.Itoa(int(rr.Priority)) + " " + + strconv.Itoa(int(rr.Weight)) + " " + + strconv.Itoa(int(rr.Port)) + " " + rr.Target +} + +type RR_A struct { + Hdr RR_Header + A net.IP "ipv4" +} + +func (rr *RR_A) Header() *RR_Header { + return &rr.Hdr +} + +func (rr *RR_A) String() string { + return rr.Hdr.String() + rr.A.String() +} + +type RR_AAAA struct { + Hdr RR_Header + AAAA net.IP "ipv6" +} + +func (rr *RR_AAAA) Header() *RR_Header { + return &rr.Hdr +} + +func (rr *RR_AAAA) String() string { + return rr.Hdr.String() + rr.AAAA.String() +} + +// DNSSEC types +type RR_RRSIG struct { + Hdr RR_Header +} + +func (rr *RR_RRSIG) Header() *RR_Header { + return &rr.Hdr +} +func (rr *RR_RRSIG) String() string { + return "BLAH" +} + +type RR_NSEC struct { + Hdr RR_Header +} + +func (rr *RR_NSEC) Header() *RR_Header { + return &rr.Hdr +} + +func (rr *RR_NSEC) String() string { + return "BLAH" +} + +type RR_DS struct { + Hdr RR_Header + KeyTag uint16 + Algorithm uint8 + DigestType uint8 + Digest string "hex" +} + +func (rr *RR_DS) Header() *RR_Header { + return &rr.Hdr +} + +func (rr *RR_DS) String() string { + return rr.Hdr.String() + + " " + strconv.Itoa(int(rr.KeyTag)) + + " " + alg_str[rr.Algorithm] + + " " + strconv.Itoa(int(rr.DigestType)) + + " " + rr.Digest +} + +type RR_DNSKEY struct { + Hdr RR_Header + Flags uint16 + Protocol uint8 + Algorithm uint8 + PubKey string "base64" +} + +func (rr *RR_DNSKEY) Header() *RR_Header { + return &rr.Hdr +} + +func (rr *RR_DNSKEY) String() string { + return rr.Hdr.String() + + " " + strconv.Itoa(int(rr.Flags)) + + " " + strconv.Itoa(int(rr.Protocol)) + + " " + alg_str[rr.Algorithm] + + " " + rr.PubKey // encoding/base64 +} + +type RR_NSEC3 struct { + Hdr RR_Header +} + +func (rr *RR_NSEC3) Header() *RR_Header { + return &rr.Hdr +} + +func (rr *RR_NSEC3) String() string { + return "BLAH" +} + +type RR_NSEC3PARAM struct { + Hdr RR_Header +} + +func (rr *RR_NSEC3PARAM) Header() *RR_Header { + return &rr.Hdr +} + +func (rr *RR_NSEC3PARAM) String() string { + return "BLAH" +} + +// Map of constructors for each RR wire type. +var rr_mk = map[int]func() RR{ + TypeCNAME: func() RR { return new(RR_CNAME) }, + TypeHINFO: func() RR { return new(RR_HINFO) }, + TypeMB: func() RR { return new(RR_MB) }, + TypeMG: func() RR { return new(RR_MG) }, + TypeMINFO: func() RR { return new(RR_MINFO) }, + TypeMR: func() RR { return new(RR_MR) }, + TypeMX: func() RR { return new(RR_MX) }, + TypeNS: func() RR { return new(RR_NS) }, + TypePTR: func() RR { return new(RR_PTR) }, + TypeSOA: func() RR { return new(RR_SOA) }, + TypeTXT: func() RR { return new(RR_TXT) }, + TypeSRV: func() RR { return new(RR_SRV) }, + TypeA: func() RR { return new(RR_A) }, + TypeAAAA: func() RR { return new(RR_AAAA) }, + TypeDS: func() RR { return new(RR_DS) }, + TypeRRSIG: func() RR { return new(RR_RRSIG) }, + TypeNSEC: func() RR { return new(RR_NSEC) }, + TypeDNSKEY: func() RR { return new(RR_DNSKEY) }, + TypeNSEC3: func() RR { return new(RR_NSEC3) }, + TypeNSEC3PARAM: func() RR { return new(RR_NSEC3PARAM) }, +} + +// Map of strings for each RR wire type. +var rr_str = map[uint16]string{ + TypeCNAME: "CNAME", + TypeHINFO: "HINFO", + TypeMB: "MB", + TypeMG: "MG", + TypeMINFO: "MINFO", + TypeMR: "MR", + TypeMX: "MX", + TypeNS: "NS", + TypePTR: "PTR", + TypeSOA: "SOA", + TypeTXT: "TXT", + TypeSRV: "SRV", + TypeA: "A", + TypeAAAA: "AAAA", + TypeDS: "DS", + TypeRRSIG: "RRSIG", + TypeNSEC: "NSEC", + TypeDNSKEY: "DNSKEY", + TypeNSEC3: "NSEC3", + TypeNSEC3PARAM: "NSEC3PARAM", +} + +// Map for algorithm names. +var alg_str = map[uint8]string{ + AlgRSAMD5: "RSAMD5", + AlgDH: "DH", + AlgDSA: "DSA", + AlgRSASHA1: "RSASHA1", + AlgRSASHA256: "RSASHA256", + AlgRSASHA512: "RSASHA512", + AlgECCGOST: "ECC-GOST", +} diff --git a/parse.go b/parse.go new file mode 100644 index 00000000..a621f73e --- /dev/null +++ b/parse.go @@ -0,0 +1,216 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Simple file i/o and string manipulation, to avoid +// depending on strconv and bufio and strings. + +// Don't want or NEED this. Must GO + +package dns + +import ( + "io" + "os" +) + +type file struct { + file *os.File + data []byte + atEOF bool +} + +func (f *file) close() { f.file.Close() } + +func (f *file) getLineFromData() (s string, ok bool) { + data := f.data + i := 0 + for i = 0; i < len(data); i++ { + if data[i] == '\n' { + s = string(data[0:i]) + ok = true + // move data + i++ + n := len(data) - i + copy(data[0:], data[i:]) + f.data = data[0:n] + return + } + } + if f.atEOF && len(f.data) > 0 { + // EOF, return all we have + s = string(data) + f.data = f.data[0:0] + ok = true + } + return +} + +func (f *file) readLine() (s string, ok bool) { + if s, ok = f.getLineFromData(); ok { + return + } + if len(f.data) < cap(f.data) { + ln := len(f.data) + n, err := io.ReadFull(f.file, f.data[ln:cap(f.data)]) + if n >= 0 { + f.data = f.data[0 : ln+n] + } + if err == os.EOF { + f.atEOF = true + } + } + s, ok = f.getLineFromData() + return +} + +func open(name string) (*file, os.Error) { + fd, err := os.Open(name, os.O_RDONLY, 0) + if err != nil { + return nil, err + } + return &file{fd, make([]byte, 1024)[0:0], false}, nil +} + +func byteIndex(s string, c byte) int { + for i := 0; i < len(s); i++ { + if s[i] == c { + return i + } + } + return -1 +} + +// Count occurrences in s of any bytes in t. +func countAnyByte(s string, t string) int { + n := 0 + for i := 0; i < len(s); i++ { + if byteIndex(t, s[i]) >= 0 { + n++ + } + } + return n +} + +// Split s at any bytes in t. +func splitAtBytes(s string, t string) []string { + a := make([]string, 1+countAnyByte(s, t)) + n := 0 + last := 0 + for i := 0; i < len(s); i++ { + if byteIndex(t, s[i]) >= 0 { + if last < i { + a[n] = string(s[last:i]) + n++ + } + last = i + 1 + } + } + if last < len(s) { + a[n] = string(s[last:]) + n++ + } + return a[0:n] +} + +func getFields(s string) []string { return splitAtBytes(s, " \r\t\n") } + +// Bigger than we need, not too big to worry about overflow +const big = 0xFFFFFF + +// Decimal to integer starting at &s[i0]. +// Returns number, new offset, success. +func dtoi(s string, i0 int) (n int, i int, ok bool) { + n = 0 + for i = i0; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ { + n = n*10 + int(s[i]-'0') + if n >= big { + return 0, i, false + } + } + if i == i0 { + return 0, i, false + } + return n, i, true +} + +// Hexadecimal to integer starting at &s[i0]. +// Returns number, new offset, success. +func xtoi(s string, i0 int) (n int, i int, ok bool) { + n = 0 + for i = i0; i < len(s); i++ { + if '0' <= s[i] && s[i] <= '9' { + n *= 16 + n += int(s[i] - '0') + } else if 'a' <= s[i] && s[i] <= 'f' { + n *= 16 + n += int(s[i]-'a') + 10 + } else if 'A' <= s[i] && s[i] <= 'F' { + n *= 16 + n += int(s[i]-'A') + 10 + } else { + break + } + if n >= big { + return 0, i, false + } + } + if i == i0 { + return 0, i, false + } + return n, i, true +} + +// Integer to decimal. +func itoa(i int) string { + var buf [30]byte + n := len(buf) + neg := false + if i < 0 { + i = -i + neg = true + } + ui := uint(i) + for ui > 0 || n == len(buf) { + n-- + buf[n] = byte('0' + ui%10) + ui /= 10 + } + if neg { + n-- + buf[n] = '-' + } + return string(buf[n:]) +} + +// Number of occurrences of b in s. +func count(s string, b byte) int { + n := 0 + for i := 0; i < len(s); i++ { + if s[i] == b { + n++ + } + } + return n +} + +// Returns the prefix of s up to but not including the character c +func prefixBefore(s string, c byte) string { + for i, v := range s { + if v == int(c) { + return s[0:i] + } + } + return s +} + +// Index of rightmost occurrence of b in s. +func last(s string, b byte) int { + i := len(s) + for i--; i >= 0; i-- { + if s[i] == b { + break + } + } + return i +} diff --git a/restest.go b/restest.go new file mode 100644 index 00000000..79cedaed --- /dev/null +++ b/restest.go @@ -0,0 +1,35 @@ +package main + +import ( + "dns" + "fmt" + "net" +) + +func main() { + res := new(dns.Resolver) + res.Servers = []string{"192.168.1.2"} + res.Timeout = 2 + res.Attempts = 1 + + a := new(dns.RR_A) + a.A = net.ParseIP("192.168.1.2").To4() + + aaaa := new(dns.RR_AAAA) + aaaa.AAAA = net.ParseIP("2003::53").To16() + + fmt.Printf("%v\n", a) + fmt.Printf("%v\n", aaaa) + +// msg, _ := res.Query("miek.nl.", dns.TypeTXT, dns.ClassINET) +// fmt.Printf("%v\n", msg) +// +// msg, _ = res.Query("www.nlnetlabs.nl", dns.TypeAAAA, dns.ClassINET) +// fmt.Printf("%v\n", msg) +// + msg, _ := res.Query("nlnetlabs.nl", dns.TypeDNSKEY, dns.ClassINET) + fmt.Printf("%v\n", msg) + + msg, _ = res.Query("jelte.nlnetlabs.nl", dns.TypeDS, dns.ClassINET) + fmt.Printf("%v\n", msg) +}