Add lowlevel read/write primatives - and make it much more Go-like
This commit is contained in:
parent
b722229700
commit
de9a1da6aa
167
dns.go
167
dns.go
|
@ -15,7 +15,11 @@
|
|||
//
|
||||
package dns
|
||||
|
||||
// ErrShortWrite is defined in io, use that!
|
||||
|
||||
import (
|
||||
"os"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
|
@ -30,7 +34,7 @@ const (
|
|||
type Error struct {
|
||||
Error string
|
||||
Name string
|
||||
Server string
|
||||
Server net.Addr
|
||||
Timeout bool
|
||||
}
|
||||
|
||||
|
@ -41,6 +45,167 @@ func (e *Error) String() string {
|
|||
return e.Error
|
||||
}
|
||||
|
||||
// A Conn is the lowest primative in this DNS library
|
||||
// A hold both the UDP and TCP connection, but only one
|
||||
// can be active at any one time.
|
||||
type Conn struct {
|
||||
// The current UDP connection.
|
||||
UDP *net.UDPConn
|
||||
// The current TCP connection.
|
||||
TCP *net.TCPConn
|
||||
// The remote side of the connection.
|
||||
Addr net.Addr
|
||||
|
||||
// Timeout in sec
|
||||
Timeout int
|
||||
|
||||
// Number of attempts to try
|
||||
Attempts int
|
||||
}
|
||||
|
||||
func (d *Conn) Read(p []byte) (n int, err os.Error) {
|
||||
if d.UDP != nil && d.TCP != nil {
|
||||
return 0, &Error{Error: "UDP and TCP or both non-nil"}
|
||||
}
|
||||
switch {
|
||||
case d.UDP != nil:
|
||||
n, err = d.UDP.Read(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
case d.TCP != nil:
|
||||
n, err = d.TCP.Read(p[0:1])
|
||||
if err != nil || n != 2 {
|
||||
return n, err
|
||||
}
|
||||
l, _ := unpackUint16(p[0:1], 0)
|
||||
if l == 0 {
|
||||
return 0, &Error{Error: "received nil msg length", Server: d.Addr}
|
||||
}
|
||||
if int(l) > len(p) {
|
||||
return int(l), &Error{Error: "Buffer too small to read"}
|
||||
}
|
||||
n, err = d.TCP.Read(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
i := n
|
||||
for i < int(l) {
|
||||
n, err = d.TCP.Read(p[i:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
i += n
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Conn) Write(p []byte) (n int, err os.Error) {
|
||||
if d.UDP != nil && d.TCP != nil {
|
||||
return 0, &Error{Error: "UDP and TCP or both non-nil"}
|
||||
}
|
||||
|
||||
var attempts int
|
||||
if d.Attempts == 0 {
|
||||
attempts = 1
|
||||
} else {
|
||||
attempts = d.Attempts
|
||||
}
|
||||
d.SetTimeout()
|
||||
|
||||
switch {
|
||||
case d.UDP != nil:
|
||||
for a := 0; a < attempts; a++ {
|
||||
n, err = d.UDP.WriteTo(p, d.Addr)
|
||||
if err != nil {
|
||||
if e, ok := err.(net.Error); ok && e.Timeout() {
|
||||
continue
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
case d.TCP != nil:
|
||||
for a := 0; a < attempts; a++ {
|
||||
l := make([]byte, 2)
|
||||
l[0], l[1] = packUint16(uint16(len(p)))
|
||||
n, err = d.TCP.Write(l)
|
||||
if err != nil {
|
||||
if e, ok := err.(net.Error); ok && e.Timeout() {
|
||||
continue
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
if n != 2 {
|
||||
return n, &Error{Error: "Write failure"}
|
||||
}
|
||||
n, err = d.TCP.Write(p)
|
||||
if err != nil {
|
||||
if e, ok := err.(net.Error); ok && e.Timeout() {
|
||||
continue
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Conn) Close() (err os.Error) {
|
||||
if d.UDP != nil && d.TCP != nil {
|
||||
return &Error{Error: "UDP and TCP or both non-nil"}
|
||||
}
|
||||
switch {
|
||||
case d.UDP != nil:
|
||||
err = d.UDP.Close()
|
||||
case d.TCP != nil:
|
||||
err = d.TCP.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Conn) SetTimeout() (err os.Error) {
|
||||
var sec int64
|
||||
if d.UDP != nil && d.TCP != nil {
|
||||
return &Error{Error: "UDP and TCP or both non-nil"}
|
||||
}
|
||||
sec = int64(d.Timeout)
|
||||
if sec == 0 {
|
||||
sec = 1
|
||||
}
|
||||
if d.UDP != nil {
|
||||
err = d.TCP.SetTimeout(sec * 1e9)
|
||||
}
|
||||
if d.TCP != nil {
|
||||
err = d.TCP.SetTimeout(sec * 1e9)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Fix those here...!
|
||||
// ReadTsig
|
||||
// WriteTsig
|
||||
|
||||
func (d *Conn) Exchange(request []byte, nosend bool) (reply []byte, err os.Error) {
|
||||
var n int
|
||||
n, err = d.Write(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Layer violation to safe memory. (Its okay then.)
|
||||
if d.UDP == nil {
|
||||
reply = make([]byte, MaxMsgSize)
|
||||
} else {
|
||||
reply = make([]byte, DefaultMsgSize)
|
||||
}
|
||||
n, err = d.Read(reply)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reply = reply[:n]
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
type RR interface {
|
||||
Header() *RR_Header
|
||||
|
|
32
resolver.go
32
resolver.go
|
@ -32,10 +32,6 @@ type Resolver struct {
|
|||
Rrb int // Last used server (for round robin)
|
||||
}
|
||||
|
||||
func (res *Resolver) QueryTSIG(q *Msg, secret *string) (d *Msg, err os.Error) {
|
||||
return nil,nil
|
||||
}
|
||||
|
||||
// Basic usage pattern for setting up a resolver:
|
||||
//
|
||||
// res := new(Resolver)
|
||||
|
@ -49,7 +45,6 @@ func (res *Resolver) QueryTSIG(q *Msg, secret *string) (d *Msg, err os.Error) {
|
|||
//
|
||||
// Note that message id checking is left to the caller.
|
||||
func (res *Resolver) Query(q *Msg) (d *Msg, err os.Error) {
|
||||
// Check if there is a TSIG appended, if so, check it
|
||||
var (
|
||||
c net.Conn
|
||||
port string
|
||||
|
@ -78,31 +73,28 @@ func (res *Resolver) Query(q *Msg) (d *Msg, err os.Error) {
|
|||
}
|
||||
|
||||
for i := 0; i < len(res.Servers); i++ {
|
||||
d := new(Conn)
|
||||
server := res.Servers[i] + ":" + port
|
||||
t := time.Nanoseconds()
|
||||
if res.Tcp {
|
||||
c, err = net.Dial("tcp", "", server)
|
||||
d.TCP = c.(*net.TCPConn)
|
||||
d.Addr = d.TCP.RemoteAddr()
|
||||
} else {
|
||||
c, err = net.Dial("udp", "", server)
|
||||
d.UDP = c.(*net.UDPConn)
|
||||
d.Addr = d.UDP.RemoteAddr()
|
||||
}
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if res.Tcp {
|
||||
inb, err = exchangeTCP(c, sending, res, true)
|
||||
in.Unpack(inb)
|
||||
|
||||
} else {
|
||||
inb, err = exchangeUDP(c, sending, res, true)
|
||||
in.Unpack(inb)
|
||||
}
|
||||
}
|
||||
inb, err = d.Exchange(sending, false)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
in.Unpack(inb)
|
||||
res.Rtt[server] = time.Nanoseconds() - t
|
||||
|
||||
// Check id in.id != out.id, should be checked in the client!
|
||||
c.Close()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
|
@ -550,7 +542,7 @@ func recvTCP(c net.Conn) ([]byte, os.Error) {
|
|||
}
|
||||
length := uint16(l[0])<<8 | uint16(l[1])
|
||||
if length == 0 {
|
||||
return nil, &Error{Error: "received nil msg length", Server: c.RemoteAddr().String()}
|
||||
return nil, &Error{Error: "received nil msg length", Server: c.RemoteAddr()}
|
||||
}
|
||||
m := make([]byte, length)
|
||||
n, cerr := c.Read(m)
|
||||
|
|
Loading…
Reference in New Issue