TSIG overhauled
This lead to some changes in the Write() function. Both server side and client side are now more similar.
This commit is contained in:
parent
58a0addc8a
commit
c53cddf38c
|
@ -25,7 +25,7 @@ type QueryHandler interface {
|
|||
// construct a DNS request.
|
||||
type RequestWriter interface {
|
||||
// Write returns the request message and the reply back to the client.
|
||||
Write(*Msg)
|
||||
Write(*Msg) error
|
||||
// Send sends the message to the server.
|
||||
Send(*Msg) error
|
||||
// Receive waits for the reply of the servers.
|
||||
|
@ -201,8 +201,9 @@ func ListenAndQuery(request chan *Request, handler QueryHandler) {
|
|||
|
||||
// Write returns the original question and the answer on the
|
||||
// reply channel of the client.
|
||||
func (w *reply) Write(m *Msg) {
|
||||
func (w *reply) Write(m *Msg) error {
|
||||
w.Client().ReplyChan <- &Exchange{Request: w.req, Reply: m}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Do performs an asynchronous query. The result is returned on the
|
||||
|
|
|
@ -82,8 +82,8 @@ func (dns *Msg) SetAxfr(z string) {
|
|||
|
||||
// SetTsig appends a TSIG RR to the message.
|
||||
// This is only a skeleton Tsig RR that is added as the last RR in the
|
||||
// additional section. The caller should then call TsigGenerate,
|
||||
// to generate the complete TSIG with the secret.
|
||||
// additional section. The Tsig is calculated when the message is being
|
||||
// send.
|
||||
func (dns *Msg) SetTsig(z, algo string, fudge uint16, timesigned int64) {
|
||||
t := new(RR_TSIG)
|
||||
t.Hdr = RR_Header{z, TypeTSIG, ClassANY, 0, 0}
|
||||
|
|
|
@ -27,6 +27,8 @@ import (
|
|||
"os/signal"
|
||||
"runtime/pprof"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -84,25 +86,33 @@ func handleReflect(w dns.ResponseWriter, r *dns.Msg) {
|
|||
m.Extra = append(m.Extra, t)
|
||||
}
|
||||
|
||||
b, ok := m.Pack()
|
||||
if r.IsTsig() {
|
||||
println("Checking TSIG")
|
||||
if w.TsigStatus() == nil {
|
||||
println("TSIG OK")
|
||||
m.SetTsig(r.Extra[len(r.Extra)-1].(*dns.RR_TSIG).Hdr.Name, dns.HmacMD5, 300, time.Now().Unix())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if *printf {
|
||||
fmt.Printf("%v\n", m.String())
|
||||
}
|
||||
if !ok {
|
||||
log.Print("Packing failed")
|
||||
m.SetRcode(r, dns.RcodeServerFailure)
|
||||
m.Extra = nil
|
||||
m.Answer = nil
|
||||
b, _ = m.Pack()
|
||||
}
|
||||
// The reply is smaller then 512 bytes, so it will always "fit"
|
||||
w.Write(b)
|
||||
w.Write(m) // Discard the error?
|
||||
}
|
||||
|
||||
func serve(net string) {
|
||||
err := dns.ListenAndServe(":8053", net, nil)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to setup the "+net+" server: %s\n", err.Error())
|
||||
func serve(net, name, secret string) {
|
||||
switch name {
|
||||
case "":
|
||||
err := dns.ListenAndServe(":8053", net, nil)
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to setup the "+net+" server: %s\n", err.Error())
|
||||
}
|
||||
default:
|
||||
err := dns.ListenAndServeTsig(":8053", net, nil, map[string]string{name: secret})
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to setup the "+net+" server: %s\n", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -110,12 +120,16 @@ func main() {
|
|||
cpuprofile := flag.String("cpuprofile", "", "write cpu profile to file")
|
||||
printf = flag.Bool("print", false, "print replies")
|
||||
compress = flag.Bool("compress", false, "compress replies")
|
||||
tsig = flag.String("tisg", "", "use SHA1 hmac tsig: keyname:base64")
|
||||
tsig = flag.String("tsig", "", "use MD5 hmac tsig: keyname:base64")
|
||||
var name, secret string
|
||||
flag.Usage = func() {
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
flag.Parse()
|
||||
tsig = tsig //TODO
|
||||
if *tsig != "" {
|
||||
a := strings.SplitN(*tsig, ":", 2)
|
||||
name, secret = a[0], a[1]
|
||||
}
|
||||
if *cpuprofile != "" {
|
||||
f, err := os.Create(*cpuprofile)
|
||||
if err != nil {
|
||||
|
@ -126,8 +140,8 @@ func main() {
|
|||
}
|
||||
|
||||
dns.HandleFunc(".", handleReflect)
|
||||
go serve("tcp")
|
||||
go serve("udp")
|
||||
go serve("tcp", name, secret)
|
||||
go serve("udp", name, secret)
|
||||
sig := make(chan os.Signal)
|
||||
signal.Notify(sig)
|
||||
forever:
|
||||
|
|
44
server.go
44
server.go
|
@ -26,7 +26,7 @@ type ResponseWriter interface {
|
|||
// Return the status of the Tsig (TsigNone, TsigVerified or TsigBad)
|
||||
TsigStatus() error
|
||||
// Write writes a reply back to the client.
|
||||
Write([]byte) (int, error)
|
||||
Write(*Msg) error
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
|
@ -40,9 +40,11 @@ type conn struct {
|
|||
}
|
||||
|
||||
type response struct {
|
||||
conn *conn
|
||||
req *Msg
|
||||
tsigStatus error
|
||||
conn *conn
|
||||
req *Msg
|
||||
tsigStatus error
|
||||
tsigTimersOnly bool
|
||||
tsigRequestMAC string
|
||||
}
|
||||
|
||||
// ServeMux is an DNS request multiplexer. It matches the
|
||||
|
@ -75,8 +77,7 @@ func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
|
|||
func Refused(w ResponseWriter, r *Msg) {
|
||||
m := new(Msg)
|
||||
m.SetRcode(r, RcodeRefused)
|
||||
buf, _ := m.Pack()
|
||||
w.Write(buf)
|
||||
w.Write(m)
|
||||
}
|
||||
|
||||
// RefusedHandler returns HandlerFunc with Refused.
|
||||
|
@ -310,8 +311,7 @@ func (c *conn) serve() {
|
|||
// Send a format error back
|
||||
x := new(Msg)
|
||||
x.SetRcodeFormatError(req)
|
||||
buf, _ := x.Pack()
|
||||
w.Write(buf)
|
||||
w.Write(x)
|
||||
break
|
||||
}
|
||||
|
||||
|
@ -321,8 +321,9 @@ func (c *conn) serve() {
|
|||
if _, ok := w.conn.tsigSecret[secret]; !ok {
|
||||
w.tsigStatus = ErrKeyAlg
|
||||
}
|
||||
// Do I *ever* need Tsig.Mac here? Or timersOnly? TODO(mg)
|
||||
w.tsigStatus = TsigVerify(c.request, w.conn.tsigSecret[secret], "", false)
|
||||
w.tsigTimersOnly = false // Will this ever be true?
|
||||
w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*RR_TSIG).MAC
|
||||
}
|
||||
w.req = req
|
||||
c.handler.ServeDNS(w, w.req) // this does the writing back to the client
|
||||
|
@ -336,41 +337,46 @@ func (c *conn) serve() {
|
|||
}
|
||||
}
|
||||
|
||||
func (w *response) Write(data []byte) (n int, err error) {
|
||||
func (w *response) Write(m *Msg) error {
|
||||
//data []byte) (n int, err error) {
|
||||
data, ok := m.Pack()
|
||||
if !ok {
|
||||
return ErrPack
|
||||
}
|
||||
switch {
|
||||
case w.conn._UDP != nil:
|
||||
n, err = w.conn._UDP.WriteTo(data, w.conn.remoteAddr)
|
||||
_, err := w.conn._UDP.WriteTo(data, w.conn.remoteAddr)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return err
|
||||
}
|
||||
case w.conn._TCP != nil:
|
||||
if len(data) > MaxMsgSize {
|
||||
return 0, ErrBuf
|
||||
return ErrBuf
|
||||
}
|
||||
l := make([]byte, 2)
|
||||
l[0], l[1] = packUint16(uint16(len(data)))
|
||||
n, err = w.conn._TCP.Write(l)
|
||||
n, err := w.conn._TCP.Write(l)
|
||||
if err != nil {
|
||||
return n, err
|
||||
return err
|
||||
}
|
||||
if n != 2 {
|
||||
return n, io.ErrShortWrite
|
||||
return io.ErrShortWrite
|
||||
}
|
||||
n, err = w.conn._TCP.Write(data)
|
||||
if err != nil {
|
||||
return n, err
|
||||
return err
|
||||
}
|
||||
i := n
|
||||
if i < len(data) {
|
||||
j, err := w.conn._TCP.Write(data[i:len(data)])
|
||||
if err != nil {
|
||||
return i, err
|
||||
return err
|
||||
}
|
||||
i += j
|
||||
}
|
||||
n = i
|
||||
}
|
||||
return n, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoteAddr implements the ResponseWriter.RemoteAddr method
|
||||
|
|
Loading…
Reference in New Issue