TSIG overhauled
This lead to some changes in the Write() function. Both server side and client side are now more similar.
This commit is contained in:
parent
58a0addc8a
commit
c53cddf38c
|
@ -25,7 +25,7 @@ type QueryHandler interface {
|
||||||
// construct a DNS request.
|
// construct a DNS request.
|
||||||
type RequestWriter interface {
|
type RequestWriter interface {
|
||||||
// Write returns the request message and the reply back to the client.
|
// Write returns the request message and the reply back to the client.
|
||||||
Write(*Msg)
|
Write(*Msg) error
|
||||||
// Send sends the message to the server.
|
// Send sends the message to the server.
|
||||||
Send(*Msg) error
|
Send(*Msg) error
|
||||||
// Receive waits for the reply of the servers.
|
// Receive waits for the reply of the servers.
|
||||||
|
@ -201,8 +201,9 @@ func ListenAndQuery(request chan *Request, handler QueryHandler) {
|
||||||
|
|
||||||
// Write returns the original question and the answer on the
|
// Write returns the original question and the answer on the
|
||||||
// reply channel of the client.
|
// reply channel of the client.
|
||||||
func (w *reply) Write(m *Msg) {
|
func (w *reply) Write(m *Msg) error {
|
||||||
w.Client().ReplyChan <- &Exchange{Request: w.req, Reply: m}
|
w.Client().ReplyChan <- &Exchange{Request: w.req, Reply: m}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do performs an asynchronous query. The result is returned on the
|
// Do performs an asynchronous query. The result is returned on the
|
||||||
|
|
|
@ -82,8 +82,8 @@ func (dns *Msg) SetAxfr(z string) {
|
||||||
|
|
||||||
// SetTsig appends a TSIG RR to the message.
|
// SetTsig appends a TSIG RR to the message.
|
||||||
// This is only a skeleton Tsig RR that is added as the last RR in the
|
// This is only a skeleton Tsig RR that is added as the last RR in the
|
||||||
// additional section. The caller should then call TsigGenerate,
|
// additional section. The Tsig is calculated when the message is being
|
||||||
// to generate the complete TSIG with the secret.
|
// send.
|
||||||
func (dns *Msg) SetTsig(z, algo string, fudge uint16, timesigned int64) {
|
func (dns *Msg) SetTsig(z, algo string, fudge uint16, timesigned int64) {
|
||||||
t := new(RR_TSIG)
|
t := new(RR_TSIG)
|
||||||
t.Hdr = RR_Header{z, TypeTSIG, ClassANY, 0, 0}
|
t.Hdr = RR_Header{z, TypeTSIG, ClassANY, 0, 0}
|
||||||
|
|
|
@ -27,6 +27,8 @@ import (
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"runtime/pprof"
|
"runtime/pprof"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -84,25 +86,33 @@ func handleReflect(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
m.Extra = append(m.Extra, t)
|
m.Extra = append(m.Extra, t)
|
||||||
}
|
}
|
||||||
|
|
||||||
b, ok := m.Pack()
|
if r.IsTsig() {
|
||||||
|
println("Checking TSIG")
|
||||||
|
if w.TsigStatus() == nil {
|
||||||
|
println("TSIG OK")
|
||||||
|
m.SetTsig(r.Extra[len(r.Extra)-1].(*dns.RR_TSIG).Hdr.Name, dns.HmacMD5, 300, time.Now().Unix())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
if *printf {
|
if *printf {
|
||||||
fmt.Printf("%v\n", m.String())
|
fmt.Printf("%v\n", m.String())
|
||||||
}
|
}
|
||||||
if !ok {
|
w.Write(m) // Discard the error?
|
||||||
log.Print("Packing failed")
|
|
||||||
m.SetRcode(r, dns.RcodeServerFailure)
|
|
||||||
m.Extra = nil
|
|
||||||
m.Answer = nil
|
|
||||||
b, _ = m.Pack()
|
|
||||||
}
|
|
||||||
// The reply is smaller then 512 bytes, so it will always "fit"
|
|
||||||
w.Write(b)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func serve(net string) {
|
func serve(net, name, secret string) {
|
||||||
err := dns.ListenAndServe(":8053", net, nil)
|
switch name {
|
||||||
if err != nil {
|
case "":
|
||||||
fmt.Printf("Failed to setup the "+net+" server: %s\n", err.Error())
|
err := dns.ListenAndServe(":8053", net, nil)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to setup the "+net+" server: %s\n", err.Error())
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
err := dns.ListenAndServeTsig(":8053", net, nil, map[string]string{name: secret})
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to setup the "+net+" server: %s\n", err.Error())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,12 +120,16 @@ func main() {
|
||||||
cpuprofile := flag.String("cpuprofile", "", "write cpu profile to file")
|
cpuprofile := flag.String("cpuprofile", "", "write cpu profile to file")
|
||||||
printf = flag.Bool("print", false, "print replies")
|
printf = flag.Bool("print", false, "print replies")
|
||||||
compress = flag.Bool("compress", false, "compress replies")
|
compress = flag.Bool("compress", false, "compress replies")
|
||||||
tsig = flag.String("tisg", "", "use SHA1 hmac tsig: keyname:base64")
|
tsig = flag.String("tsig", "", "use MD5 hmac tsig: keyname:base64")
|
||||||
|
var name, secret string
|
||||||
flag.Usage = func() {
|
flag.Usage = func() {
|
||||||
flag.PrintDefaults()
|
flag.PrintDefaults()
|
||||||
}
|
}
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
tsig = tsig //TODO
|
if *tsig != "" {
|
||||||
|
a := strings.SplitN(*tsig, ":", 2)
|
||||||
|
name, secret = a[0], a[1]
|
||||||
|
}
|
||||||
if *cpuprofile != "" {
|
if *cpuprofile != "" {
|
||||||
f, err := os.Create(*cpuprofile)
|
f, err := os.Create(*cpuprofile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -126,8 +140,8 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
dns.HandleFunc(".", handleReflect)
|
dns.HandleFunc(".", handleReflect)
|
||||||
go serve("tcp")
|
go serve("tcp", name, secret)
|
||||||
go serve("udp")
|
go serve("udp", name, secret)
|
||||||
sig := make(chan os.Signal)
|
sig := make(chan os.Signal)
|
||||||
signal.Notify(sig)
|
signal.Notify(sig)
|
||||||
forever:
|
forever:
|
||||||
|
|
44
server.go
44
server.go
|
@ -26,7 +26,7 @@ type ResponseWriter interface {
|
||||||
// Return the status of the Tsig (TsigNone, TsigVerified or TsigBad)
|
// Return the status of the Tsig (TsigNone, TsigVerified or TsigBad)
|
||||||
TsigStatus() error
|
TsigStatus() error
|
||||||
// Write writes a reply back to the client.
|
// Write writes a reply back to the client.
|
||||||
Write([]byte) (int, error)
|
Write(*Msg) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type conn struct {
|
type conn struct {
|
||||||
|
@ -40,9 +40,11 @@ type conn struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type response struct {
|
type response struct {
|
||||||
conn *conn
|
conn *conn
|
||||||
req *Msg
|
req *Msg
|
||||||
tsigStatus error
|
tsigStatus error
|
||||||
|
tsigTimersOnly bool
|
||||||
|
tsigRequestMAC string
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeMux is an DNS request multiplexer. It matches the
|
// ServeMux is an DNS request multiplexer. It matches the
|
||||||
|
@ -75,8 +77,7 @@ func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
|
||||||
func Refused(w ResponseWriter, r *Msg) {
|
func Refused(w ResponseWriter, r *Msg) {
|
||||||
m := new(Msg)
|
m := new(Msg)
|
||||||
m.SetRcode(r, RcodeRefused)
|
m.SetRcode(r, RcodeRefused)
|
||||||
buf, _ := m.Pack()
|
w.Write(m)
|
||||||
w.Write(buf)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefusedHandler returns HandlerFunc with Refused.
|
// RefusedHandler returns HandlerFunc with Refused.
|
||||||
|
@ -310,8 +311,7 @@ func (c *conn) serve() {
|
||||||
// Send a format error back
|
// Send a format error back
|
||||||
x := new(Msg)
|
x := new(Msg)
|
||||||
x.SetRcodeFormatError(req)
|
x.SetRcodeFormatError(req)
|
||||||
buf, _ := x.Pack()
|
w.Write(x)
|
||||||
w.Write(buf)
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -321,8 +321,9 @@ func (c *conn) serve() {
|
||||||
if _, ok := w.conn.tsigSecret[secret]; !ok {
|
if _, ok := w.conn.tsigSecret[secret]; !ok {
|
||||||
w.tsigStatus = ErrKeyAlg
|
w.tsigStatus = ErrKeyAlg
|
||||||
}
|
}
|
||||||
// Do I *ever* need Tsig.Mac here? Or timersOnly? TODO(mg)
|
|
||||||
w.tsigStatus = TsigVerify(c.request, w.conn.tsigSecret[secret], "", false)
|
w.tsigStatus = TsigVerify(c.request, w.conn.tsigSecret[secret], "", false)
|
||||||
|
w.tsigTimersOnly = false // Will this ever be true?
|
||||||
|
w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*RR_TSIG).MAC
|
||||||
}
|
}
|
||||||
w.req = req
|
w.req = req
|
||||||
c.handler.ServeDNS(w, w.req) // this does the writing back to the client
|
c.handler.ServeDNS(w, w.req) // this does the writing back to the client
|
||||||
|
@ -336,41 +337,46 @@ func (c *conn) serve() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *response) Write(data []byte) (n int, err error) {
|
func (w *response) Write(m *Msg) error {
|
||||||
|
//data []byte) (n int, err error) {
|
||||||
|
data, ok := m.Pack()
|
||||||
|
if !ok {
|
||||||
|
return ErrPack
|
||||||
|
}
|
||||||
switch {
|
switch {
|
||||||
case w.conn._UDP != nil:
|
case w.conn._UDP != nil:
|
||||||
n, err = w.conn._UDP.WriteTo(data, w.conn.remoteAddr)
|
_, err := w.conn._UDP.WriteTo(data, w.conn.remoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return err
|
||||||
}
|
}
|
||||||
case w.conn._TCP != nil:
|
case w.conn._TCP != nil:
|
||||||
if len(data) > MaxMsgSize {
|
if len(data) > MaxMsgSize {
|
||||||
return 0, ErrBuf
|
return ErrBuf
|
||||||
}
|
}
|
||||||
l := make([]byte, 2)
|
l := make([]byte, 2)
|
||||||
l[0], l[1] = packUint16(uint16(len(data)))
|
l[0], l[1] = packUint16(uint16(len(data)))
|
||||||
n, err = w.conn._TCP.Write(l)
|
n, err := w.conn._TCP.Write(l)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return n, err
|
return err
|
||||||
}
|
}
|
||||||
if n != 2 {
|
if n != 2 {
|
||||||
return n, io.ErrShortWrite
|
return io.ErrShortWrite
|
||||||
}
|
}
|
||||||
n, err = w.conn._TCP.Write(data)
|
n, err = w.conn._TCP.Write(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return n, err
|
return err
|
||||||
}
|
}
|
||||||
i := n
|
i := n
|
||||||
if i < len(data) {
|
if i < len(data) {
|
||||||
j, err := w.conn._TCP.Write(data[i:len(data)])
|
j, err := w.conn._TCP.Write(data[i:len(data)])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return i, err
|
return err
|
||||||
}
|
}
|
||||||
i += j
|
i += j
|
||||||
}
|
}
|
||||||
n = i
|
n = i
|
||||||
}
|
}
|
||||||
return n, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoteAddr implements the ResponseWriter.RemoteAddr method
|
// RemoteAddr implements the ResponseWriter.RemoteAddr method
|
||||||
|
|
Loading…
Reference in New Issue