More TSIG changes. Curious if they amount to something

This commit is contained in:
Miek Gieben 2011-03-15 16:18:13 +01:00
parent 566b5f7d1a
commit 0059556516
3 changed files with 147 additions and 83 deletions

81
msg.go
View File

@ -16,7 +16,7 @@ package dns
import (
"os"
// "fmt"
// "fmt"
"reflect"
"net"
"rand"
@ -67,7 +67,7 @@ var Rr_str = map[uint16]string{
TypeTXT: "TXT",
TypeSRV: "SRV",
TypeNAPTR: "NAPTR",
TypeKX: "KX",
TypeKX: "KX",
TypeCERT: "CERT",
TypeDNAME: "DNAME",
TypeA: "A",
@ -75,7 +75,7 @@ var Rr_str = map[uint16]string{
TypeLOC: "LOC",
TypeOPT: "OPT",
TypeDS: "DS",
TypeDHCID: "DHCID",
TypeDHCID: "DHCID",
TypeIPSECKEY: "IPSECKEY",
TypeSSHFP: "SSHFP",
TypeRRSIG: "RRSIG",
@ -332,7 +332,7 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o
var _ = byte(fv.Elem(j).(*reflect.UintValue).Get())
}
// handle type bit maps
// TODO(mg)
// TODO(mg)
}
case *reflect.StructValue:
off, ok = packStructValue(fv, msg, off)
@ -426,10 +426,10 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o
// length of string. String is RAW (not encoded in hex, nor base64)
copy(msg[off:off+len(s)], s)
off += len(s)
case "txt":
// Counted string: 1 byte length, but the string may be longer
// than 255, in that case it should be multiple strings, for now:
fallthrough
case "txt":
// Counted string: 1 byte length, but the string may be longer
// than 255, in that case it should be multiple strings, for now:
fallthrough
case "":
// Counted string: 1 byte length.
if len(s) > 255 || off+1+len(s) > len(msg) {
@ -520,13 +520,13 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
//fmt.Fprintf(os.Stderr, "dns: overflow unpacking NSEC")
return len(msg), false
}
if blocks == 0 {
// Nothing encoded in this window
// Kinda lame to alloc above and to clear it here
nsec = nsec[:ni]
fv.Set(reflect.NewValue(nsec).(*reflect.SliceValue))
break
}
if blocks == 0 {
// Nothing encoded in this window
// Kinda lame to alloc above and to clear it here
nsec = nsec[:ni]
fv.Set(reflect.NewValue(nsec).(*reflect.SliceValue))
break
}
off += 2
for j := 0; j < blocks; j++ {
@ -689,15 +689,15 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
name := val.FieldByName("HashLength")
size = int(name.(*reflect.UintValue).Get())
}
case "RR_TSIG":
switch f.Name {
case "MAC":
name := val.FieldByName("MACSize")
size = int(name.(*reflect.UintValue).Get())
case "OtherData":
name := val.FieldByName("OtherLen")
size = int(name.(*reflect.UintValue).Get())
}
case "RR_TSIG":
switch f.Name {
case "MAC":
name := val.FieldByName("MACSize")
size = int(name.(*reflect.UintValue).Get())
case "OtherData":
name := val.FieldByName("OtherLen")
size = int(name.(*reflect.UintValue).Get())
}
}
if off+size > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: failure unpacking size-hex string")
@ -705,10 +705,10 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
}
s = hex.EncodeToString(msg[off : off+size])
off += size
case "txt":
// 1 or multiple txt pieces
case "txt":
// 1 or multiple txt pieces
rdlength := int(val.FieldByName("Hdr").(*reflect.StructValue).FieldByName("Rdlength").(*reflect.UintValue).Get())
Txt:
Txt:
if off >= len(msg) || off+1+int(msg[off]) > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: failure unpacking txt string")
return len(msg), false
@ -721,10 +721,10 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
}
off += n
s += string(b)
if off < rdlength {
// More to come
goto Txt
}
if off < rdlength {
// More to come
goto Txt
}
case "":
if off >= len(msg) || off+1+int(msg[off]) > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: failure unpacking string")
@ -770,6 +770,10 @@ func unpackBase64(b []byte) string {
}
// Helper function for packing, mostly used in dnssec.go
func packUint16(i uint16) (byte, byte) {
return byte(i >> 8), byte(i)
}
func packBase64(s []byte) ([]byte, os.Error) {
b64len := base64.StdEncoding.DecodedLen(len(s))
buf := make([]byte, b64len)
@ -1022,25 +1026,32 @@ func (dns *Msg) String() string {
if len(dns.Question) > 0 {
s += "\n;; QUESTION SECTION:\n"
for i := 0; i < len(dns.Question); i++ {
s += dns.Question[i].String() + "\n"
// Need check if it exists? TODO(mg)
s += dns.Question[i].String() + "\n"
}
}
if len(dns.Answer) > 0 {
s += "\n;; ANSWER SECTION:\n"
for i := 0; i < len(dns.Answer); i++ {
s += dns.Answer[i].String() + "\n"
if dns.Answer[i] != nil {
s += dns.Answer[i].String() + "\n"
}
}
}
if len(dns.Ns) > 0 {
s += "\n;; AUTHORITY SECTION:\n"
for i := 0; i < len(dns.Ns); i++ {
s += dns.Ns[i].String() + "\n"
if dns.Ns[i] != nil {
s += dns.Ns[i].String() + "\n"
}
}
}
if len(dns.Extra) > 0 {
s += "\n;; ADDITIONAL SECTION:\n"
for i := 0; i < len(dns.Extra); i++ {
s += dns.Extra[i].String() + "\n"
if dns.Extra[i] != nil {
s += dns.Extra[i].String() + "\n"
}
}
}
return s

View File

@ -45,9 +45,10 @@ func (res *Resolver) Query(q *Msg) (d *Msg, err os.Error) {
// Check if there is a TSIG appended, if so, check it
var (
c net.Conn
in *Msg
port string
inb []byte
)
in := new(Msg)
if len(res.Servers) == 0 {
return nil, &Error{Error: "No servers defined"}
}
@ -81,9 +82,12 @@ func (res *Resolver) Query(q *Msg) (d *Msg, err os.Error) {
continue
}
if res.Tcp {
in, err = exchangeTCP(c, sending, res, true)
inb, err = exchangeTCP(c, sending, res, true)
in.Unpack(inb)
} else {
in, err = exchangeUDP(c, sending, res, true)
inb, err = exchangeUDP(c, sending, res, true)
in.Unpack(inb)
}
res.Rtt[server] = time.Nanoseconds() - t
@ -114,9 +118,12 @@ type Xfr struct {
// Channel m is closed when the IXFR ends.
func (res *Resolver) Ixfr(q *Msg, m chan Xfr) {
// TSIG
var port string
var in *Msg
var x Xfr
var (
port string
x Xfr
inb []byte
)
in := new(Msg)
if res.Port == "" {
port = "53"
} else {
@ -149,9 +156,11 @@ Server:
defer c.Close()
for {
if first {
in, err = exchangeTCP(c, sending, res, true)
inb, err = exchangeTCP(c, sending, res, true)
in.Unpack(inb)
} else {
in, err = exchangeTCP(c, sending, res, false)
inb, err = exchangeTCP(c, sending, res, false)
in.Unpack(inb)
}
if err != nil {
@ -220,8 +229,11 @@ Server:
// the zone as-is. Xfr.Add is always true.
// The channel is closed to signal the end of the AXFR.
func (res *Resolver) AxfrTSIG(q *Msg, m chan Xfr, secret string) {
var port string
var in *Msg
var (
port string
inb []byte
)
in := new(Msg)
if res.Port == "" {
port = "53"
} else {
@ -263,9 +275,17 @@ Server:
defer c.Close() // TODO(mg): if not open?
for {
if first {
in, err = exchangeTCP(c, sending, res, true)
inb, err = exchangeTCP(c, sending, res, true)
stripTSIG(inb)
/*
pt2 := new(Msg)
pt2.Unpack(t2)
//println("P", pt2.String())
*/
in.Unpack(inb)
} else {
in, err = exchangeTCP(c, sending, res, false)
inb, err = exchangeTCP(c, sending, res, false)
in.Unpack(inb)
}
if err != nil {
@ -282,7 +302,7 @@ Server:
t := in.Extra[len(in.Extra)-1]
switch t.(type) {
case *RR_TSIG:
if t.(*RR_TSIG).Verify(in, secret, reqmac) {
if t.(*RR_TSIG).Verify(inb, secret, reqmac) {
println("Validates")
} else {
println("DOES NOT validate")
@ -322,8 +342,11 @@ Server:
// the zone as-is. Xfr.Add is always true.
// The channel is closed to signal the end of the AXFR.
func (res *Resolver) Axfr(q *Msg, m chan Xfr) {
var port string
var in *Msg
var (
port string
inb []byte
)
in := new(Msg)
if res.Port == "" {
port = "53"
} else {
@ -343,17 +366,6 @@ func (res *Resolver) Axfr(q *Msg, m chan Xfr) {
return
}
/*
// Need the secret!
var tsig *RR_TSIG
// Check if there is a TSIG added
if len(q.Extra) > 0 {
lastrr := q.Extra[len(q.Extra)-1]
if lastrr.Header().Rrtype == TypeTSIG {
tsig = lastrr.(*RR_TSIG)
}
}
*/
Server:
for i := 0; i < len(res.Servers); i++ {
server := res.Servers[i] + ":" + port
@ -365,9 +377,11 @@ Server:
defer c.Close() // TODO(mg): if not open?
for {
if first {
in, err = exchangeTCP(c, sending, res, true)
inb, err = exchangeTCP(c, sending, res, true)
in.Unpack(inb)
} else {
in, err = exchangeTCP(c, sending, res, false)
inb, err = exchangeTCP(c, sending, res, false)
in.Unpack(inb)
}
if err != nil {
@ -408,7 +422,7 @@ Server:
// Send a request on the connection and hope for a reply.
// Up to res.Attempts attempts. If send is false, nothing
// is send.
func exchangeUDP(c net.Conn, m []byte, r *Resolver, send bool) (*Msg, os.Error) {
func exchangeUDP(c net.Conn, m []byte, r *Resolver, send bool) ([]byte, os.Error) {
var timeout int64
var attempts int
if r.Mangle != nil {
@ -443,18 +457,13 @@ func exchangeUDP(c net.Conn, m []byte, r *Resolver, send bool) (*Msg, os.Error)
}
return nil, err
}
in := new(Msg)
if !in.Unpack(buf) {
continue
}
return in, nil
return buf, nil
}
return nil, &Error{Error: servErr}
}
// Up to res.Attempts attempts.
func exchangeTCP(c net.Conn, m []byte, r *Resolver, send bool) (*Msg, os.Error) {
func exchangeTCP(c net.Conn, m []byte, r *Resolver, send bool) ([]byte, os.Error) {
var timeout int64
var attempts int
if r.Mangle != nil {
@ -484,7 +493,7 @@ func exchangeTCP(c net.Conn, m []byte, r *Resolver, send bool) (*Msg, os.Error)
}
c.SetReadTimeout(timeout * 1e9) // nanoseconds
// The server replies with two bytes length
// The server replies with two bytes length.
buf, err := recvTCP(c)
if err != nil {
if e, ok := err.(net.Error); ok && e.Timeout() {
@ -492,11 +501,7 @@ func exchangeTCP(c net.Conn, m []byte, r *Resolver, send bool) (*Msg, os.Error)
}
return nil, err
}
in := new(Msg)
if !in.Unpack(buf) {
continue
}
return in, nil
return buf, nil
}
return nil, &Error{Error: servErr}
}
@ -510,7 +515,7 @@ func sendUDP(m []byte, c net.Conn) os.Error {
}
func recvUDP(c net.Conn) ([]byte, os.Error) {
m := make([]byte, DefaultMsgSize) // More than enough???
m := make([]byte, DefaultMsgSize)
n, err := c.Read(m)
if err != nil {
return nil, err
@ -537,8 +542,7 @@ func sendTCP(m []byte, c net.Conn) os.Error {
}
func recvTCP(c net.Conn) ([]byte, os.Error) {
l := make([]byte, 2) // receiver length
// The server replies with two bytes length
l := make([]byte, 2) // The server replies with two bytes length.
_, err := c.Read(l)
if err != nil {
return nil, err

55
tsig.go
View File

@ -107,7 +107,7 @@ func (t *RR_TSIG) Generate(m *Msg, secret string) bool {
// the TSIG record still attached (as the last rr in the Additional
// section). Return true on success.
// The secret is a base64 encoded string with the secret.
func (t *RR_TSIG) Verify(m *Msg, secret, reqmac string) bool {
func (t *RR_TSIG) Verify(m []byte, secret, reqmac string) bool {
rawsecret, err := packBase64([]byte(secret))
if err != nil {
return false
@ -121,9 +121,8 @@ func (t *RR_TSIG) Verify(m *Msg, secret, reqmac string) bool {
if t.Header().Rrtype != TypeTSIG {
return false
}
println(msg2.String())
msg2.MsgHdr.Id = t.OrigId
println(msg2.String())
msg2.Extra = msg2.Extra[:len(msg2.Extra)-1] // Strip off the TSIG
buf, ok := tsigToBuf(t, msg2, reqmac)
if !ok {
@ -182,3 +181,53 @@ func tsigToBuf(rr *RR_TSIG, msg *Msg, reqmac string) ([]byte, bool) {
}
return buf, true
}
// Strip the TSIG from the pkt.
func stripTSIG(orig []byte) ([]byte, bool) {
// Copied from msg.go's Unpack()
// Header.
var dh Header
dns := new(Msg)
msg := make([]byte, len(orig))
copy(msg, orig) // fhhh.. another copy
off := 0
tsigoff := 0
var ok bool
if off, ok = unpackStruct(&dh, msg, off); !ok {
return nil, false
}
if dh.Arcount == 0 {
// No records at all in the additional.
return nil, false
}
// Arrays.
dns.Question = make([]Question, dh.Qdcount)
dns.Answer = make([]RR, dh.Ancount)
dns.Ns = make([]RR, dh.Nscount)
dns.Extra = make([]RR, dh.Arcount)
for i := 0; i < len(dns.Question); i++ {
off, ok = unpackStruct(&dns.Question[i], msg, off)
}
for i := 0; i < len(dns.Answer); i++ {
dns.Answer[i], off, ok = unpackRR(msg, off)
}
for i := 0; i < len(dns.Ns); i++ {
dns.Ns[i], off, ok = unpackRR(msg, off)
}
for i := 0; i < len(dns.Extra); i++ {
tsigoff = off
dns.Extra[i], off, ok = unpackRR(msg, off)
if dns.Extra[i].Header().Rrtype == TypeTSIG {
// Adjust Arcount.
arcount, _ := unpackUint16(msg, 10)
msg[10], msg[11] = packUint16(arcount-1)
break
}
}
if !ok {
return nil, false
}
return msg[:tsigoff], true
}