From 05a68526632cf11f656db231533b2767b138c540 Mon Sep 17 00:00:00 2001 From: Miek Gieben Date: Mon, 18 Apr 2011 18:27:59 +0200 Subject: [PATCH] port tcp read --- _examples/axfr/axfr.go | 17 +++++++++-------- client.go | 28 +++++++++++++++++++++++++++- defaults.go | 2 ++ xfr.go | 5 ++--- 4 files changed, 40 insertions(+), 12 deletions(-) diff --git a/_examples/axfr/axfr.go b/_examples/axfr/axfr.go index f25cebda..12f48f32 100644 --- a/_examples/axfr/axfr.go +++ b/_examples/axfr/axfr.go @@ -12,18 +12,19 @@ func main() { flag.Parse() zone := flag.Arg(flag.NArg() - 1) - c := make(chan *dns.Xfr) - d := new(dns.Conn) - m := new(dns.Msg) - - d.RemoteAddr = *nameserver + // only UDP works atm + client := dns.NewClient() + m := new(dns.Msg) if *serial > 0 { m.SetIxfr(zone, uint32(*serial)) } else { m.SetAxfr(zone) } - go d.XfrRead(m, c) - for x := range c { - fmt.Printf("%v %v %v\n", x.Add, x.RR, x.Err) + axfr, err := client.XfrReceive(m, *nameserver) + if err != nil { + println(err.String()) + } + for _, v := range axfr { + fmt.Printf("%v\n", v) } } diff --git a/client.go b/client.go index 6baa1a1d..f626a3fa 100644 --- a/client.go +++ b/client.go @@ -252,7 +252,33 @@ func (w *reply) readClient(p []byte) (n int, err os.Error) { } switch w.Client().Net { case "tcp": - // + if len(p) < 1 { + return 0, io.ErrShortBuffer + } + n, err = w.conn.(*net.TCPConn).Read(p[0:2]) + if err != nil || n != 2 { + return n, err + } + l, _ := unpackUint16(p[0:2], 0) + if l == 0 { + return 0, ErrShortRead + } + if int(l) > len(p) { + return int(l), io.ErrShortBuffer + } + n, err = w.conn.(*net.TCPConn).Read(p[:l]) + if err != nil { + return n, err + } + i := n + for i < int(l) { + j, err := w.conn.(*net.TCPConn).Read(p[i:int(l)]) + if err != nil { + return i, err + } + i += j + } + n = i case "udp": n, _, err = w.conn.(*net.UDPConn).ReadFromUDP(p) if err != nil { diff --git a/defaults.go b/defaults.go index 8d538fb7..3f9eb120 100644 --- a/defaults.go +++ b/defaults.go @@ -52,6 +52,7 @@ func (dns *Msg) IsNotify() (ok bool) { // Create a dns msg suitable for requesting an ixfr. func (dns *Msg) SetIxfr(z string, serial uint32) { + dns.MsgHdr.Id = Id() dns.Question = make([]Question, 1) dns.Ns = make([]RR, 1) s := new(RR_SOA) @@ -64,6 +65,7 @@ func (dns *Msg) SetIxfr(z string, serial uint32) { // Create a dns msg suitable for requesting an axfr. func (dns *Msg) SetAxfr(z string) { + dns.MsgHdr.Id = Id() dns.Question = make([]Question, 1) dns.Question[0] = Question{z, TypeAXFR, ClassINET} } diff --git a/xfr.go b/xfr.go index e0b7e197..868b1433 100644 --- a/xfr.go +++ b/xfr.go @@ -28,10 +28,11 @@ func (c *Client) XfrReceive(q *Msg, a string) ([]*Msg, os.Error) { } func (w *reply) axfrReceive() ([]*Msg, os.Error) { - axfr := make([]*Msg, 1) // use append ALL the time? + axfr := make([]*Msg, 0) // use append ALL the time? first := true for { in, err := w.Receive() + axfr = append(axfr, in) if err != nil { return axfr, err } @@ -49,10 +50,8 @@ func (w *reply) axfrReceive() ([]*Msg, os.Error) { //} if !checkXfrSOA(in, false) { // Soa record not the last one - axfr = append(axfr, in) continue } else { - axfr = append(axfr, in) return axfr, nil } }