Merge branch 'master' of github.com:miekg/dns

This commit is contained in:
Miek Gieben 2012-08-31 15:10:46 +02:00
commit 0a586f5ebb
3 changed files with 60 additions and 58 deletions

View File

@ -126,13 +126,13 @@ func TestPack(t *testing.T) {
x.Answer = make([]RR, 1)
x.Answer[0], err = NewRR(rr[0])
if _, ok := x.Pack(); ok {
t.Log("Packing failed")
t.Log("Packing should fail")
t.Fail()
}
x.Question = make([]Question, 1)
x.Question[0] = Question{";sd#edddds鍛↙赏‘℅∥↙xzztsestxssweewwsssstx@s@Z嵌e@cn.pool.ntp.org.", TypeA, ClassINET}
if _, ok := x.Pack(); !ok {
t.Log("Packing failed")
if _, ok := x.Pack(); ok {
t.Log("Packing should fail")
t.Fail()
}
}

92
msg.go
View File

@ -391,14 +391,14 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
case reflect.Slice:
switch val.Type().Field(i).Tag.Get("dns") {
default:
println("dns: unknown tag packing slice", val.Type().Field(i).Tag.Get("dns"), '"', val.Type().Field(i).Tag, '"')
// println("dns: unknown tag packing slice", val.Type().Field(i).Tag.Get("dns"), '"', val.Type().Field(i).Tag, '"')
return lenmsg, false
case "domain-name":
for j := 0; j < val.Field(i).Len(); j++ {
element := val.Field(i).Index(j).String()
off, ok = PackDomainName(element, msg, off, compression, false && compress)
if !ok {
println("dns: overflow packing domain-name", off)
// println("dns: overflow packing domain-name", off)
return lenmsg, false
}
}
@ -407,7 +407,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
element := val.Field(i).Index(j).String()
// Counted string: 1 byte length.
if len(element) > 255 || off+1+len(element) > lenmsg {
println("dns: overflow packing TXT string")
// println("dns: overflow packing TXT string")
return lenmsg, false
}
msg[off] = byte(len(element))
@ -422,7 +422,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
element := val.Field(i).Index(j).Interface()
b, e := element.(EDNS0).Pack()
if e != nil {
println("dns: failure packing OPT")
// println("dns: failure packing OPT")
return lenmsg, false
}
// Option code
@ -440,7 +440,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
switch fv.Len() {
case net.IPv6len:
if off+net.IPv4len > lenmsg {
println("dns: overflow packing A", off, lenmsg)
// println("dns: overflow packing A", off, lenmsg)
return lenmsg, false
}
msg[off] = byte(fv.Index(12).Uint())
@ -450,7 +450,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
off += net.IPv4len
case net.IPv4len:
if off+net.IPv4len > lenmsg {
println("dns: overflow packing A", off, lenmsg)
// println("dns: overflow packing A", off, lenmsg)
return lenmsg, false
}
msg[off] = byte(fv.Index(0).Uint())
@ -461,12 +461,12 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
case 0:
// Allowed, for dynamic updates
default:
println("dns: overflow packing A")
// println("dns: overflow packing A")
return lenmsg, false
}
case "aaaa":
if fv.Len() > net.IPv6len || off+fv.Len() > lenmsg {
println("dns: overflow packing AAAA")
// println("dns: overflow packing AAAA")
return lenmsg, false
}
for j := 0; j < net.IPv6len; j++ {
@ -482,7 +482,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
serv := uint16((fv.Index(j).Uint()))
bitmapbyte = uint16(serv / 8)
if int(bitmapbyte) > lenmsg {
println("dns: overflow packing WKS")
// println("dns: overflow packing WKS")
return lenmsg, false
}
bit := uint16(serv) - bitmapbyte*8
@ -499,7 +499,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
lastwindow := uint16(0)
length := uint16(0)
if off+2 > lenmsg {
println("dns: overflow packing NSECx bitmap")
// println("dns: overflow packing NSECx bitmap")
return lenmsg, false
}
for j := 0; j < val.Field(i).Len(); j++ {
@ -509,14 +509,14 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
// New window, jump to the new offset
off += int(length) + 3
if off > lenmsg {
println("dns: overflow packing NSECx bitmap")
// println("dns: overflow packing NSECx bitmap")
return lenmsg, false
}
}
length = (t - window*256) / 8
bit := t - (window * 256) - (length * 8)
if off+2+int(length) > lenmsg {
println("dns: overflow packing NSECx bitmap")
// println("dns: overflow packing NSECx bitmap")
return lenmsg, false
}
@ -531,7 +531,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
off += 2 + int(length)
off++
if off > lenmsg {
println("dns: overflow packing NSECx bitmap")
// println("dns: overflow packing NSECx bitmap")
return lenmsg, false
}
}
@ -539,14 +539,14 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
off, ok = packStructValue(fv, msg, off, compression, compress)
case reflect.Uint8:
if off+1 > lenmsg {
println("dns: overflow packing uint8")
// println("dns: overflow packing uint8")
return lenmsg, false
}
msg[off] = byte(fv.Uint())
off++
case reflect.Uint16:
if off+2 > lenmsg {
println("dns: overflow packing uint16")
// println("dns: overflow packing uint16")
return lenmsg, false
}
i := fv.Uint()
@ -555,7 +555,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
off += 2
case reflect.Uint32:
if off+4 > lenmsg {
println("dns: overflow packing uint32")
// println("dns: overflow packing uint32")
return lenmsg, false
}
i := fv.Uint()
@ -567,7 +567,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
case reflect.Uint64:
// Only used in TSIG, where it stops at 48 bits, so we discard the upper 16
if off+6 > lenmsg {
println("dns: overflow packing uint64")
// println("dns: overflow packing uint64")
return lenmsg, false
}
i := fv.Uint()
@ -588,19 +588,19 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
case "base64":
b64, err := packBase64([]byte(s))
if err != nil {
println("dns: overflow packing base64")
// println("dns: overflow packing base64")
return lenmsg, false
}
copy(msg[off:off+len(b64)], b64)
off += len(b64)
case "domain-name":
if off, ok = PackDomainName(s, msg, off, compression, false && compress); !ok {
println("dns: overflow packing domain-name", off)
// println("dns: overflow packing domain-name", off)
return lenmsg, false
}
case "cdomain-name":
if off, ok = PackDomainName(s, msg, off, compression, true && compress); !ok {
println("dns: overflow packing domain-name", off)
// println("dns: overflow packing domain-name", off)
return lenmsg, false
}
case "size-base32":
@ -612,7 +612,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
case "base32":
b32, err := packBase32([]byte(s))
if err != nil {
println("dns: overflow packing base32")
// println("dns: overflow packing base32")
return lenmsg, false
}
copy(msg[off:off+len(b32)], b32)
@ -623,7 +623,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
// There is no length encoded here
h, e := hex.DecodeString(s)
if e != nil {
println("dns: overflow packing (size-)hex string")
// println("dns: overflow packing (size-)hex string")
return lenmsg, false
}
if off+hex.DecodedLen(len(s)) > lenmsg {
@ -642,7 +642,7 @@ func packStructValue(val reflect.Value, msg []byte, off int, compression map[str
case "":
// Counted string: 1 byte length.
if len(s) > 255 || off+1+len(s) > lenmsg {
println("dns: overflow packing string")
// println("dns: overflow packing string")
return lenmsg, false
}
msg[off] = byte(len(s))
@ -680,12 +680,12 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
lenmsg := len(msg)
switch fv := val.Field(i); fv.Kind() {
default:
println("dns: unknown case unpacking struct")
// println("dns: unknown case unpacking struct")
return lenmsg, false
case reflect.Slice:
switch val.Type().Field(i).Tag.Get("dns") {
default:
println("dns: unknown tag unpacking slice", val.Type().Field(i).Tag)
// println("dns: unknown tag unpacking slice", val.Type().Field(i).Tag)
return lenmsg, false
case "domain-name":
// HIP record slice of name (or none)
@ -694,7 +694,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
for off < lenmsg {
s, off, ok = UnpackDomainName(msg, off)
if !ok {
println("dns: failure unpacking domain-name")
// println("dns: failure unpacking domain-name")
return lenmsg, false
}
servers = append(servers, s)
@ -706,7 +706,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
Txts:
l := int(msg[off])
if off+l+1 > lenmsg {
println("dns: failure unpacking txt strings")
// println("dns: failure unpacking txt strings")
return lenmsg, false
}
txt = append(txt, string(msg[off+1:off+l+1]))
@ -731,7 +731,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
code, off = unpackUint16(msg, off) // Overflow? TODO
optlen, off1 := unpackUint16(msg, off)
if off1+int(optlen) > off+rdlength {
println("dns: overflow unpacking OPT")
// println("dns: overflow unpacking OPT")
return lenmsg, false
}
switch code {
@ -750,14 +750,14 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
// goto ??
case "a":
if off+net.IPv4len > len(msg) {
println("dns: overflow unpacking A")
// println("dns: overflow unpacking A")
return lenmsg, false
}
fv.Set(reflect.ValueOf(net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3])))
off += net.IPv4len
case "aaaa":
if off+net.IPv6len > lenmsg {
println("dns: overflow unpacking AAAA")
// println("dns: overflow unpacking AAAA")
return lenmsg, false
}
fv.Set(reflect.ValueOf(net.IP{msg[off], msg[off+1], msg[off+2], msg[off+3], msg[off+4],
@ -807,7 +807,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
endrr := rdstart + rdlength
if off+2 > lenmsg {
println("dns: overflow unpacking NSEC")
// println("dns: overflow unpacking NSEC")
return lenmsg, false
}
nsec := make([]uint16, 0)
@ -820,11 +820,11 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
if length == 0 {
// A length window of zero is strange. If there
// the window should not have been specified. Bail out
println("dns: length == 0 when unpacking NSEC")
// println("dns: length == 0 when unpacking NSEC")
return lenmsg, false
}
if length > 32 {
println("dns: length > 32 when unpacking NSEC")
// println("dns: length > 32 when unpacking NSEC")
return lenmsg, false
}
@ -870,7 +870,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
}
case reflect.Uint8:
if off+1 > lenmsg {
println("dns: overflow unpacking uint8")
// println("dns: overflow unpacking uint8")
return lenmsg, false
}
fv.SetUint(uint64(uint8(msg[off])))
@ -878,14 +878,14 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
case reflect.Uint16:
var i uint16
if off+2 > lenmsg {
println("dns: overflow unpacking uint16")
// println("dns: overflow unpacking uint16")
return lenmsg, false
}
i, off = unpackUint16(msg, off)
fv.SetUint(uint64(i))
case reflect.Uint32:
if off+4 > lenmsg {
println("dns: overflow unpacking uint32")
// println("dns: overflow unpacking uint32")
return lenmsg, false
}
fv.SetUint(uint64(uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3])))
@ -894,7 +894,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
// This is *only* used in TSIG where the last 48 bits are occupied
// So for now, assume a uint48 (6 bytes)
if off+6 > lenmsg {
println("dns: overflow unpacking uint64")
// println("dns: overflow unpacking uint64")
return lenmsg, false
}
fv.SetUint(uint64(uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 |
@ -904,14 +904,14 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
var s string
switch val.Type().Field(i).Tag.Get("dns") {
default:
println("dns: unknown tag unpacking string")
// println("dns: unknown tag unpacking string")
return lenmsg, false
case "hex":
// Rest of the RR is hex encoded, network order an issue here?
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
endrr := rdstart + rdlength
if endrr > lenmsg {
println("dns: overflow when unpacking hex string")
// println("dns: overflow when unpacking hex string")
return lenmsg, false
}
s = hex.EncodeToString(msg[off:endrr])
@ -921,7 +921,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
endrr := rdstart + rdlength
if endrr > lenmsg {
println("dns: failure unpacking base64")
// println("dns: failure unpacking base64")
return lenmsg, false
}
s = unpackBase64(msg[off:endrr])
@ -931,7 +931,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
case "domain-name":
s, off, ok = UnpackDomainName(msg, off)
if !ok {
println("dns: failure unpacking domain-name")
// println("dns: failure unpacking domain-name")
return lenmsg, false
}
case "size-base32":
@ -945,7 +945,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
}
}
if off+size > lenmsg {
println("dns: failure unpacking size-base32 string")
// println("dns: failure unpacking size-base32 string")
return lenmsg, false
}
s = unpackBase32(msg[off : off+size])
@ -974,7 +974,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
}
}
if off+size > lenmsg {
println("dns: failure unpacking size-hex string")
// println("dns: failure unpacking size-hex string")
return lenmsg, false
}
s = hex.EncodeToString(msg[off : off+size])
@ -984,7 +984,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
Txt:
if off >= lenmsg || off+1+int(msg[off]) > lenmsg {
println("dns: failure unpacking txt string")
// println("dns: failure unpacking txt string")
return lenmsg, false
}
n := int(msg[off])
@ -999,7 +999,7 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok boo
}
case "":
if off >= lenmsg || off+1+int(msg[off]) > lenmsg {
println("dns: failure unpacking string")
// println("dns: failure unpacking string")
return lenmsg, false
}
n := int(msg[off])
@ -1290,7 +1290,7 @@ func (dns *Msg) Unpack(msg []byte) bool {
}
if off != len(msg) {
// TODO(mg) remove eventually
println("extra bytes in dns packet", off, "<", len(msg))
// println("extra bytes in dns packet", off, "<", len(msg))
}
return true
}

View File

@ -198,7 +198,7 @@ func (srv *Server) ListenAndServe() error {
if e != nil {
return e
}
return srv.ServeTCP(l)
return srv.serveTCP(l)
case "udp", "udp4", "udp6":
a, e := net.ResolveUDPAddr(srv.Net, addr)
if e != nil {
@ -208,14 +208,14 @@ func (srv *Server) ListenAndServe() error {
if e != nil {
return e
}
return srv.ServeUDP(l)
return srv.serveUDP(l)
}
return &Error{Err: "bad network"}
}
// ServeTCP starts a TCP listener for the server.
// serveTCP starts a TCP listener for the server.
// Each request is handled in a seperate goroutine.
func (srv *Server) ServeTCP(l *net.TCPListener) error {
func (srv *Server) serveTCP(l *net.TCPListener) error {
defer l.Close()
handler := srv.Handler
if handler == nil {
@ -225,7 +225,8 @@ forever:
for {
rw, e := l.AcceptTCP()
if e != nil {
return e
// don't bail out, but wait for a new request
continue
}
if srv.ReadTimeout != 0 {
rw.SetReadDeadline(time.Now().Add(srv.ReadTimeout))
@ -265,9 +266,9 @@ forever:
panic("dns: not reached")
}
// ServeUDP starts a UDP listener for the server.
// serveUDP starts a UDP listener for the server.
// Each request is handled in a seperate goroutine.
func (srv *Server) ServeUDP(l *net.UDPConn) error {
func (srv *Server) serveUDP(l *net.UDPConn) error {
defer l.Close()
handler := srv.Handler
if handler == nil {
@ -280,7 +281,8 @@ func (srv *Server) ServeUDP(l *net.UDPConn) error {
m := make([]byte, srv.UDPSize)
n, a, e := l.ReadFromUDP(m)
if e != nil || n == 0 {
return e
// don't bail out, but wait for a new request
continue
}
m = m[:n]
@ -314,7 +316,7 @@ func newConn(t *net.TCPConn, u *net.UDPConn, a net.Addr, buf []byte, handler Han
func (c *conn) serve() {
// for block to make it easy to break out to close the tcp connection
for {
// Request has been read in ServeUDP or ServeTCP
// Request has been read in serveUDP or serveTCP
w := new(response)
w.conn = c
req := new(Msg)