"port" the write function
This commit is contained in:
parent
8d01350deb
commit
031dbba174
101
client.go
101
client.go
|
@ -10,7 +10,8 @@ package dns
|
|||
// This completely mirrors server.go impl.
|
||||
import (
|
||||
"os"
|
||||
// "net"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
type QueryHandler interface {
|
||||
|
@ -21,7 +22,7 @@ type QueryHandler interface {
|
|||
// construct an DNS request.
|
||||
type RequestWriter interface {
|
||||
WriteMessages([]*Msg)
|
||||
Write(*Msg)
|
||||
Write(*Msg)
|
||||
}
|
||||
|
||||
// hijacked connections...?
|
||||
|
@ -45,7 +46,7 @@ func NewQueryMux() *QueryMux { return &QueryMux{make(map[string]QueryHandler)} }
|
|||
var DefaultQueryMux = NewQueryMux()
|
||||
|
||||
func newQueryChanSlice() chan []*Msg { return make(chan []*Msg) }
|
||||
func newQueryChan() chan *Msg { return make(chan *Msg) }
|
||||
func newQueryChan() chan *Msg { return make(chan *Msg) }
|
||||
|
||||
// Default channel to use for the resolver
|
||||
var DefaultReplyChan = newQueryChanSlice()
|
||||
|
@ -110,17 +111,17 @@ func (mux *QueryMux) QueryDNS(w RequestWriter, request *Msg) {
|
|||
}
|
||||
|
||||
type Client struct {
|
||||
Network string // if "tcp" a TCP query will be initiated, otherwise an UDP one
|
||||
Net string // if "tcp" a TCP query will be initiated, otherwise an UDP one
|
||||
Addr string // address to call
|
||||
Attempts int // number of attempts
|
||||
Retry bool // retry with TCP
|
||||
ChannelQuery chan *Msg // read DNS request from this channel
|
||||
ChannelReply chan []*Msg // read DNS request from this channel
|
||||
ChannelReply chan []*Msg // read DNS request from this channel
|
||||
Handler QueryHandler // handler to invoke, dns.DefaultQueryMux if nil
|
||||
ReadTimeout int64 // the net.Conn.SetReadTimeout value for new connections
|
||||
WriteTimeout int64 // the net.Conn.SetWriteTimeout value for new connections
|
||||
}
|
||||
|
||||
|
||||
// Query accepts incoming DNS request,
|
||||
// Write to in
|
||||
// creating a new service thread for each. The service threads
|
||||
|
@ -160,10 +161,16 @@ func (c *Client) ListenAndQuery() os.Error {
|
|||
}
|
||||
|
||||
func (c *Client) Do(m *Msg, addr string) {
|
||||
// addr !!!
|
||||
if c.ChannelQuery == nil {
|
||||
DefaultQueryChan <- m
|
||||
}
|
||||
if c.Net == "" {
|
||||
c.Net = "udp"
|
||||
}
|
||||
if c.Attempts == 0 {
|
||||
c.Attempts = 1
|
||||
}
|
||||
c.Addr = addr
|
||||
}
|
||||
|
||||
func ListenAndQuery(c chan *Msg, handler QueryHandler) {
|
||||
|
@ -178,6 +185,84 @@ func (w *reply) Write(m *Msg) {
|
|||
|
||||
func (w *reply) WriteMessages(m []*Msg) {
|
||||
// Write to the channel
|
||||
m1 := append([]*Msg{w.req}, m...) // Really the way?
|
||||
m1 := append([]*Msg{w.req}, m...) // Really the way?
|
||||
w.Client.ChannelReply <- m1
|
||||
}
|
||||
|
||||
func (c *Client) Read() {
|
||||
|
||||
}
|
||||
|
||||
func (c *Client) Write(m *Msg) os.Error {
|
||||
out, ok := m.Pack()
|
||||
if !ok {
|
||||
return ErrPack
|
||||
}
|
||||
_, err := c.write(out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fill Client.Conn with the connection
|
||||
func (c *Client) write(p []byte) (n int, err os.Error) {
|
||||
conn, err := net.Dial(c.Net, "", c.Addr)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if c.Attempts == 0 {
|
||||
panic("client attempts 0")
|
||||
}
|
||||
switch c.Net {
|
||||
case "tcp":
|
||||
if len(p) < 2 {
|
||||
return 0, io.ErrShortBuffer
|
||||
}
|
||||
for a := 0; a < c.Attempts; a++ {
|
||||
l := make([]byte, 2)
|
||||
l[0], l[1] = packUint16(uint16(len(p)))
|
||||
n, err = conn.Write(l)
|
||||
if err != nil {
|
||||
if e, ok := err.(net.Error); ok && e.Timeout() {
|
||||
continue
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
if n != 2 {
|
||||
return n, io.ErrShortWrite
|
||||
}
|
||||
n, err = conn.Write(p)
|
||||
if err != nil {
|
||||
if e, ok := err.(net.Error); ok && e.Timeout() {
|
||||
continue
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
i := n
|
||||
if i < len(p) {
|
||||
j, err := conn.Write(p[i:len(p)])
|
||||
if err != nil {
|
||||
if e, ok := err.(net.Error); ok && e.Timeout() {
|
||||
// We are half way in our write...
|
||||
continue
|
||||
}
|
||||
return i, err
|
||||
}
|
||||
i += j
|
||||
}
|
||||
n = i
|
||||
}
|
||||
case "udp":
|
||||
for a := 0; a < c.Attempts; a++ {
|
||||
n, err = conn.(*net.UDPConn).WriteTo(p, conn.RemoteAddr())
|
||||
if err != nil {
|
||||
if e, ok := err.(net.Error); ok && e.Timeout() {
|
||||
continue
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue