The example test performs an axfr, but as ixfr differs slightly
it should also support ixfr
This commit is contained in:
Miek Gieben 2011-01-01 17:42:13 +01:00
parent 1c9282ed7e
commit 43ebf75fac
4 changed files with 226 additions and 168 deletions

4
TODO
View File

@ -1,8 +1,7 @@
Todo:
* fix os.Erros usage, add DNSSEC related errors
* DNSSEC validation
* NSEC(3) secure denial of existence
* fix os.Erros usage, add DNSSEC related errors
* AXFR/IXFR support
* Notify support (send/receive)
* TKEY -- RFC 2930
* TSIG -- RFC 4635
@ -16,7 +15,6 @@ Todo:
Issues:
* Better sized buffers
* TC bit handling
* Fix TCP: should read 2 bytes of length
* shortened ipv6 addresses are not parsed correctly
* quoted quotes in txt records
* Convience functions?

7
msg.go
View File

@ -20,6 +20,8 @@ import (
"os"
"reflect"
"net"
"rand"
"time"
"strconv"
"encoding/base64"
"encoding/hex"
@ -751,3 +753,8 @@ func (dns *Msg) String() string {
}
return s
}
// Set an Msg Id to a random value
func (m *Msg) SetId() {
m.Id = uint16(rand.Int()) ^ uint16(time.Nanoseconds())
}

24
resolver/axfr_test.go Normal file
View File

@ -0,0 +1,24 @@
package resolver
import (
"testing"
"fmt"
"dns"
)
func TestAXFR(t *testing.T) {
res := new(Resolver)
ch := NewXfer(res)
res.Servers = []string{"127.0.0.1"}
m := new(dns.Msg)
m.Question = make([]dns.Question, 1)
// ask something
m.Question[0] = dns.Question{"miek.nl", dns.TypeAXFR, dns.ClassINET}
ch <- DnsMsg{m, nil}
for dm := range ch {
fmt.Printf("%v\n",dm.Dns)
}
}

View File

@ -24,10 +24,8 @@ package resolver
import (
"os"
"rand"
"time"
"net"
"dns"
"dns"
)
// When communicating with a resolver, we use this structure
@ -64,13 +62,13 @@ func query(res *Resolver, msg chan DnsMsg) {
var c net.Conn
var err os.Error
var in *dns.Msg
var port string
// len(res.Server) == 0 can be perfectly valid, when setting up the resolver
if res.Port == "" {
port = "53"
} else {
port = res.Port
}
var port string
// len(res.Server) == 0 can be perfectly valid, when setting up the resolver
if res.Port == "" {
port = "53"
} else {
port = res.Port
}
for {
select {
@ -83,14 +81,13 @@ func query(res *Resolver, msg chan DnsMsg) {
}
var cerr os.Error
// Set an id
//if len(name) >= 256 {
out.Dns.Id = uint16(rand.Int()) ^ uint16(time.Nanoseconds())
out.Dns.SetId()
sending, ok := out.Dns.Pack()
if !ok {
//println("pack failed")
//println("pack failed")
msg <- DnsMsg{nil, nil} // todo error
continue;
continue
}
for i := 0; i < len(res.Servers); i++ {
@ -100,18 +97,19 @@ func query(res *Resolver, msg chan DnsMsg) {
} else {
c, cerr = net.Dial("udp", "", server)
}
defer c.Close()
if cerr != nil {
err = cerr
continue
}
if res.Tcp {
in, err = exchange_tcp(c, sending, res)
} else {
in, err = exchange_udp(c, sending, res)
}
if res.Tcp {
in, err = exchange_tcp(c, sending, res, true)
} else {
in, err = exchange_udp(c, sending, res, true)
}
// Check id in.id != out.id
// TODO(mg)
// TODO(mg)
c.Close()
if err != nil {
@ -132,102 +130,130 @@ func query(res *Resolver, msg chan DnsMsg) {
// Channel will be closed when the axfr is finished, until
// that time new messages will appear on the channel
func NewXfer(res *Resolver) (ch chan DnsMsg) {
ch = make(chan DnsMsg)
go axfr(res, ch)
return
ch = make(chan DnsMsg)
go axfr(res, ch)
return
}
func axfr(res *Resolver, msg chan DnsMsg) {
// open socket
// call exchange_tcp
// check for soa
// repeat if no soa
// close channel
var port string
var err os.Error
var in *dns.Msg
if res.Port == "" {
port = "53"
} else {
port = res.Port
}
var sending []byte
// out.Dns.Id = uint16(rand.Int()) ^ uint16(time.Nanoseconds())
// sending, ok := out.Dns.Pack()
for i:=0; i<len(res.Servers); i++ {
server := res.Servers[i] + port
c, cerr := net.Dial("tcp", "", server)
if cerr != nil {
err = cerr
continue
}
var port string
var err os.Error
var in *dns.Msg
if res.Port == "" {
port = "53"
} else {
port = res.Port
}
first := true
for {
for {
select {
case out := <-msg: // msg received
if out.Dns == nil {
// stop
msg <- DnsMsg{nil, nil}
close(msg)
return
}
in, cerr = exchange_tcp(c, sending, res)
if cerr != nil {
err = cerr
continue
}
if first {
if ! checkSOA(in, true) {
// SOA record not there...
return // ?
}
first = !first
// send in to message
out.Dns.SetId()
sending, ok := out.Dns.Pack()
if !ok {
msg <- DnsMsg{nil, nil}
}
SERVER:
for i := 0; i < len(res.Servers); i++ {
server := res.Servers[i] + ":" + port
c, cerr := net.Dial("tcp", "", server)
if cerr != nil {
err = cerr
continue SERVER
}
defer c.Close()
continue
} else {
if ! checkSOA(in, false) {
// Soa record not the last one
// return packet
// next
} else {
// last one
// close channel
}
}
}
}
err = err
first := true
// Start the AXFR
for {
if first {
in, cerr = exchange_tcp(c, sending, res, true)
} else {
in, cerr = exchange_tcp(c, sending, res, false)
}
return
if cerr != nil {
// Failed to send, try the next
err = cerr
continue SERVER
}
if first {
if !checkSOA(in, true) {
// SOA record not there...
c.Close()
continue SERVER
}
first = !first
}
if !first {
if !checkSOA(in, false) {
// Soa record not the last one
msg <- DnsMsg{in, nil}
continue
// next
} else {
c.Close()
msg <- DnsMsg{in, nil}
close(msg)
return
}
}
}
close(msg)
return
}
// With 1 successfull server, we dont get here, so
// We've failed
msg <- DnsMsg{nil, err} // TODO Err
close(msg)
return
}
}
return
}
// Send a request on the connection and hope for a reply.
// Up to res.Attempts attempts.
func exchange_udp(c net.Conn, m []byte, r *Resolver) (*dns.Msg, os.Error) {
var timeout int64
var attempts int
func exchange_udp(c net.Conn, m []byte, r *Resolver, send bool) (*dns.Msg, os.Error) {
var timeout int64
var attempts int
if r.Mangle != nil {
m = r.Mangle(m)
}
if r.Timeout == 0 {
timeout = 1
} else {
timeout = int64(r.Timeout)
}
if r.Attempts == 0 {
attempts = 1
} else {
attempts = r.Attempts
}
for a:= 0; a < attempts; a++ {
n, err := c.Write(m)
if err != nil {
//println("error writing")
return nil, err
}
if r.Timeout == 0 {
timeout = 1
} else {
timeout = int64(r.Timeout)
}
if r.Attempts == 0 {
attempts = 1
} else {
attempts = r.Attempts
}
for a := 0; a < attempts; a++ {
if send {
_, err := c.Write(m)
if err != nil {
//println("error writing")
return nil, err
}
}
c.SetReadTimeout(timeout * 1e9) // nanoseconds
buf := make([]byte, dns.DefaultMsgSize) // More than enough???
n, err = c.Read(buf)
n, err := c.Read(buf)
if err != nil {
//println("error reading")
//println(err.String())
//println("error reading")
//println(err.String())
// More Go foo needed
//if e, ok := err.(Error); ok && e.Timeout() {
// continue
@ -245,81 +271,84 @@ func exchange_udp(c net.Conn, m []byte, r *Resolver) (*dns.Msg, os.Error) {
}
// Up to res.Attempts attempts.
func exchange_tcp(c net.Conn, m []byte, r *Resolver) (*dns.Msg, os.Error) {
var timeout int64
var attempts int
if r.Mangle != nil {
m = r.Mangle(m)
}
if r.Timeout == 0 {
timeout = 1
} else {
timeout = int64(r.Timeout)
}
if r.Attempts == 0 {
attempts = 1
} else {
attempts = r.Attempts
}
func exchange_tcp(c net.Conn, m []byte, r *Resolver, send bool) (*dns.Msg, os.Error) {
var timeout int64
var attempts int
if r.Mangle != nil {
m = r.Mangle(m)
}
if r.Timeout == 0 {
timeout = 1
} else {
timeout = int64(r.Timeout)
}
if r.Attempts == 0 {
attempts = 1
} else {
attempts = r.Attempts
}
ls := make([]byte, 2) // sender length
lr := make([]byte, 2) // receiver length
var length uint16
ls[0] = byte(len(m) >> 8)
ls[1] = byte(len(m))
for a := 0; a < attempts; a++ {
// With DNS over TCP we first send the length
_, err := c.Write(ls)
if err != nil {
return nil, err
ls := make([]byte, 2) // sender length
lr := make([]byte, 2) // receiver length
var length uint16
ls[0] = byte(len(m) >> 8)
ls[1] = byte(len(m))
for a := 0; a < attempts; a++ {
// only send something when told so
if send {
// With DNS over TCP we first send the length
_, err := c.Write(ls)
if err != nil {
return nil, err
}
// And then send the message
_, err = c.Write(m)
if err != nil {
return nil, err
}
}
// And then send the message
_, err = c.Write(m)
if err != nil {
return nil, err
}
c.SetReadTimeout(timeout * 1e9) // nanoseconds
// The server replies with two bytes length
_, err = c.Read(lr)
if err != nil {
return nil, err
}
length = uint16(lr[0])<<8 | uint16(lr[1])
// if length is 0??
// And then the message
buf := make([]byte, length)
_, err = c.Read(buf)
if err != nil {
//println("error reading")
//println(err.String())
// More Go foo needed
//if e, ok := err.(Error); ok && e.Timeout() {
// continue
//}
return nil, err
}
in := new(dns.Msg)
if !in.Unpack(buf) {
// println("unpacking went wrong")
continue
}
return in, nil
}
return nil, nil // todo error
c.SetReadTimeout(timeout * 1e9) // nanoseconds
// The server replies with two bytes length
_, err := c.Read(lr)
if err != nil {
return nil, err
}
length = uint16(lr[0])<<8 | uint16(lr[1])
// if length is 0??
// And then the message
buf := make([]byte, length)
_, err = c.Read(buf)
if err != nil {
//println("error reading")
//println(err.String())
// More Go foo needed
//if e, ok := err.(Error); ok && e.Timeout() {
// continue
//}
return nil, err
}
in := new(dns.Msg)
if !in.Unpack(buf) {
// println("unpacking went wrong")
continue
}
return in, nil
}
return nil, nil // todo error
}
// Check if he SOA record exists in the Answer section of
// the packet. If first is true the first RR must be a soa
// if false, the last one should be a SOA
func checkSOA(in *dns.Msg, first bool) bool {
if len(in.Answer) > 0 {
if first {
return in.Answer[0].Header().Rrtype == dns.TypeSOA
} else {
return in.Answer[len(in.Answer)].Header().Rrtype == dns.TypeSOA
}
}
return false
if len(in.Answer) > 0 {
if first {
return in.Answer[0].Header().Rrtype == dns.TypeSOA
} else {
return in.Answer[len(in.Answer)-1].Header().Rrtype == dns.TypeSOA
}
}
return false
}