diff --git a/client_test.go b/client_test.go index 2cc11eac..f74b1e1b 100644 --- a/client_test.go +++ b/client_test.go @@ -62,6 +62,79 @@ func TestClientEDNS0(t *testing.T) { } } +// Validates the transmission and parsing of custom EDNS0 options. +func TestClientEDNS0Custom(t *testing.T) { + handler := func(w ResponseWriter, req *Msg) { + m := new(Msg) + m.SetReply(req) + + m.Extra = make([]RR, 1, 2) + m.Extra[0] = &TXT{Hdr: RR_Header{Name: m.Question[0].Name, Rrtype: TypeTXT, Class: ClassINET, Ttl: 0}, Txt: []string{"Hello custom edns"}} + + // If the custom options are what we expect, then reflect them back. + ec1 := req.Extra[0].(*OPT).Option[0].(*EDNS0_CUSTOM).String() + ec2 := req.Extra[0].(*OPT).Option[1].(*EDNS0_CUSTOM).String() + if ec1 == "1979:0x0707" && ec2 == "1997:0x0601" { + m.Extra = append(m.Extra, req.Extra[0]) + } + + w.WriteMsg(m) + } + + HandleFunc("miek.nl.", handler) + defer HandleRemove("miek.nl.") + + s, addrstr, err := RunLocalUDPServer("127.0.0.1:0") + if err != nil { + t.Fatalf("Unable to run test server: %s", err) + } + defer s.Shutdown() + + m := new(Msg) + m.SetQuestion("miek.nl.", TypeTXT) + + ec1 := &EDNS0_CUSTOM{Code: 1979, Data: []byte{7, 7}} + ec2 := &EDNS0_CUSTOM{Code: 1997, Data: []byte{6, 1}} + o := &OPT{Hdr: RR_Header{Name: ".", Rrtype: TypeOPT}, Option: []EDNS0{ec1, ec2}} + m.Extra = append(m.Extra, o) + + c := new(Client) + r, _, e := c.Exchange(m, addrstr) + if e != nil { + t.Logf("failed to exchange: %s", e.Error()) + t.Fail() + } + + if r != nil && r.Rcode != RcodeSuccess { + t.Log("failed to get a valid answer") + t.Fail() + t.Logf("%v\n", r) + } + + txt := r.Extra[0].(*TXT).Txt[0] + if txt != "Hello custom edns" { + t.Log("Unexpected result for miek.nl", txt, "!= Hello custom edns") + t.Fail() + } + + // Validate the custom options in the reply. + exp := "1979:0x0707" + got := r.Extra[1].(*OPT).Option[0].(*EDNS0_CUSTOM).String() + if got != exp { + t.Log("failed to get custom edns0 answer; got %s, expected %s", got, exp) + t.Fail() + t.Logf("%v\n", r) + } + + exp = "1997:0x0601" + got = r.Extra[1].(*OPT).Option[1].(*EDNS0_CUSTOM).String() + if got != exp { + t.Log("failed to get custom edns0 answer; got %s, expected %s", got, exp) + t.Fail() + t.Logf("%v\n", r) + } +} + func TestSingleSingleInflight(t *testing.T) { HandleFunc("miek.nl.", HelloServer) defer HandleRemove("miek.nl.") diff --git a/edns.go b/edns.go index 59d5e6cc..56f22732 100644 --- a/edns.go +++ b/edns.go @@ -68,6 +68,8 @@ func (rr *OPT) String() string { s += "\n; DS HASH UNDERSTOOD: " + o.String() case *EDNS0_N3U: s += "\n; NSEC3 HASH UNDERSTOOD: " + o.String() + case *EDNS0_CUSTOM: + s += "\n; CUSTOM OPT: " + o.String() } } return s @@ -76,8 +78,9 @@ func (rr *OPT) String() string { func (rr *OPT) len() int { l := rr.Hdr.len() for i := 0; i < len(rr.Option); i++ { + l += 4 // Account for 2-byte option code and 2-byte option length. lo, _ := rr.Option[i].pack() - l += 2 + len(lo) + l += len(lo) } return l } @@ -475,3 +478,31 @@ func (e *EDNS0_EXPIRE) unpack(b []byte) error { e.Expire = uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]) return nil } + +type EDNS0_CUSTOM struct { + Code uint16 + Data []byte +} + +func (e *EDNS0_CUSTOM) Option() uint16 { return e.Code } +func (e *EDNS0_CUSTOM) String() string { + return strconv.FormatInt(int64(e.Code), 10) + ":0x" + hex.EncodeToString(e.Data) +} + +func (e *EDNS0_CUSTOM) pack() ([]byte, error) { + b := make([]byte, len(e.Data)) + copied := copy(b, e.Data) + if copied != len(e.Data) { + return nil, ErrBuf + } + return b, nil +} + +func (e *EDNS0_CUSTOM) unpack(b []byte) error { + e.Data = make([]byte, len(b)) + copied := copy(e.Data, b) + if copied != len(b) { + return ErrBuf + } + return nil +} diff --git a/msg.go b/msg.go index 0b63e986..06a99bfc 100644 --- a/msg.go +++ b/msg.go @@ -1048,7 +1048,12 @@ func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, err er edns = append(edns, e) off = off1 + int(optlen) default: - // do nothing? + e := new(EDNS0_CUSTOM) + e.Code = code + if err := e.unpack(msg[off1 : off1+int(optlen)]); err != nil { + return lenmsg, err + } + edns = append(edns, e) off = off1 + int(optlen) } if off < lenrd {