diff --git a/edns.go b/edns.go index d070e41b..862e5343 100644 --- a/edns.go +++ b/edns.go @@ -98,6 +98,8 @@ func (rr *OPT) String() string { s += "\n; SUBNET: " + o.String() case *EDNS0_COOKIE: s += "\n; COOKIE: " + o.String() + case *EDNS0_TCP_KEEPALIVE: + s += "\n; KEEPALIVE: " + o.String() case *EDNS0_UL: s += "\n; UPDATE LEASE: " + o.String() case *EDNS0_LLQ: @@ -657,57 +659,52 @@ func (e *EDNS0_LOCAL) unpack(b []byte) error { // EDNS0_TCP_KEEPALIVE is an EDNS0 option that instructs the server to keep // the TCP connection alive. See RFC 7828. type EDNS0_TCP_KEEPALIVE struct { - Code uint16 // Always EDNSTCPKEEPALIVE - Length uint16 // the value 0 if the TIMEOUT is omitted, the value 2 if it is present; - Timeout uint16 // an idle timeout value for the TCP connection, specified in units of 100 milliseconds, encoded in network byte order. + Code uint16 // Always EDNSTCPKEEPALIVE + + // Timeout is an idle timeout value for the TCP connection, specified in + // units of 100 milliseconds, encoded in network byte order. If set to 0, + // pack will return a nil slice. + Timeout uint16 + + // Length is the option's length. + // Deprecated: this field is deprecated and is always equal to 0. + Length uint16 } // Option implements the EDNS0 interface. func (e *EDNS0_TCP_KEEPALIVE) Option() uint16 { return EDNS0TCPKEEPALIVE } func (e *EDNS0_TCP_KEEPALIVE) pack() ([]byte, error) { - if e.Timeout != 0 && e.Length != 2 { - return nil, errors.New("dns: timeout specified but length is not 2") + if e.Timeout > 0 { + b := make([]byte, 2) + binary.BigEndian.PutUint16(b, e.Timeout) + return b, nil } - if e.Timeout == 0 && e.Length != 0 { - return nil, errors.New("dns: timeout not specified but length is not 0") - } - b := make([]byte, 4+e.Length) - binary.BigEndian.PutUint16(b[0:], e.Code) - binary.BigEndian.PutUint16(b[2:], e.Length) - if e.Length == 2 { - binary.BigEndian.PutUint16(b[4:], e.Timeout) - } - return b, nil + return nil, nil } func (e *EDNS0_TCP_KEEPALIVE) unpack(b []byte) error { - if len(b) < 4 { - return ErrBuf - } - e.Length = binary.BigEndian.Uint16(b[2:4]) - if e.Length != 0 && e.Length != 2 { - return errors.New("dns: length mismatch, want 0/2 but got " + strconv.FormatUint(uint64(e.Length), 10)) - } - if e.Length == 2 { - if len(b) < 6 { - return ErrBuf - } - e.Timeout = binary.BigEndian.Uint16(b[4:6]) + switch len(b) { + case 0: + case 2: + e.Timeout = binary.BigEndian.Uint16(b) + default: + return fmt.Errorf("dns: length mismatch, want 0/2 but got %d", len(b)) } return nil } -func (e *EDNS0_TCP_KEEPALIVE) String() (s string) { - s = "use tcp keep-alive" - if e.Length == 0 { +func (e *EDNS0_TCP_KEEPALIVE) String() string { + s := "use tcp keep-alive" + if e.Timeout == 0 { s += ", timeout omitted" } else { s += fmt.Sprintf(", timeout %dms", e.Timeout*100) } - return + return s } -func (e *EDNS0_TCP_KEEPALIVE) copy() EDNS0 { return &EDNS0_TCP_KEEPALIVE{e.Code, e.Length, e.Timeout} } + +func (e *EDNS0_TCP_KEEPALIVE) copy() EDNS0 { return &EDNS0_TCP_KEEPALIVE{e.Code, e.Timeout, e.Length} } // EDNS0_PADDING option is used to add padding to a request/response. The default // value of padding SHOULD be 0x0 but other values MAY be used, for instance if diff --git a/edns_test.go b/edns_test.go index dc4ea7b3..b7c15f7e 100644 --- a/edns_test.go +++ b/edns_test.go @@ -1,6 +1,7 @@ package dns import ( + "bytes" "net" "testing" ) @@ -192,3 +193,87 @@ func TestEDNS0_ESU(t *testing.T) { t.Errorf("unpacked option is different; expected %v, got %v", expect, esu.Uri) } } + +func TestEDNS0_TCP_KEEPALIVE_unpack(t *testing.T) { + cases := []struct { + name string + b []byte + expected uint16 + expectedErr bool + }{ + { + name: "empty", + b: []byte{}, + expected: 0, + }, + { + name: "timeout 1", + b: []byte{0, 1}, + expected: 1, + }, + { + name: "invalid", + b: []byte{0, 1, 3}, + expectedErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + e := &EDNS0_TCP_KEEPALIVE{} + err := e.unpack(tc.b) + if err != nil && !tc.expectedErr { + t.Error("failed to unpack, expected no error") + } + if err == nil && tc.expectedErr { + t.Error("unpacked, but expected an error") + } + if e.Timeout != tc.expected { + t.Errorf("invalid timeout, actual: %d, expected: %d", e.Timeout, tc.expected) + } + }) + } +} + +func TestEDNS0_TCP_KEEPALIVE_pack(t *testing.T) { + cases := []struct { + name string + edns *EDNS0_TCP_KEEPALIVE + expected []byte + }{ + { + name: "empty", + edns: &EDNS0_TCP_KEEPALIVE{ + Code: EDNS0TCPKEEPALIVE, + Timeout: 0, + }, + expected: nil, + }, + { + name: "timeout 1", + edns: &EDNS0_TCP_KEEPALIVE{ + Code: EDNS0TCPKEEPALIVE, + Timeout: 1, + }, + expected: []byte{0, 1}, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + b, err := tc.edns.pack() + if err != nil { + t.Error("expected no error") + } + + if tc.expected == nil && b != nil { + t.Errorf("invalid result, expected nil") + } + + res := bytes.Compare(b, tc.expected) + if res != 0 { + t.Errorf("invalid result, expected: %v, actual: %v", tc.expected, b) + } + }) + } +}