diff --git a/tsig.go b/tsig.go index 0b010091..1c9e9024 100644 --- a/tsig.go +++ b/tsig.go @@ -111,7 +111,10 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s if err != nil { return nil, "", err } - buf := tsigBuffer(mbuf, rr, requestMAC, timersOnly) + buf, err := tsigBuffer(mbuf, rr, requestMAC, timersOnly) + if err != nil { + return nil, "", err + } t := new(TSIG) var h hash.Hash @@ -173,7 +176,10 @@ func tsigVerify(msg []byte, secret, requestMAC string, timersOnly bool, now uint return err } - buf := tsigBuffer(stripped, tsig, requestMAC, timersOnly) + buf, err := tsigBuffer(stripped, tsig, requestMAC, timersOnly) + if err != nil { + return err + } var h hash.Hash switch CanonicalName(tsig.Algorithm) { @@ -209,7 +215,7 @@ func tsigVerify(msg []byte, secret, requestMAC string, timersOnly bool, now uint } // Create a wiredata buffer for the MAC calculation. -func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []byte { +func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) ([]byte, error) { var buf []byte if rr.TimeSigned == 0 { rr.TimeSigned = uint64(time.Now().Unix()) @@ -226,7 +232,10 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b m.MACSize = uint16(len(requestMAC) / 2) m.MAC = requestMAC buf = make([]byte, len(requestMAC)) // long enough - n, _ := packMacWire(m, buf) + n, err := packMacWire(m, buf) + if err != nil { + return nil, err + } buf = buf[:n] } @@ -235,7 +244,10 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b tsig := new(timerWireFmt) tsig.TimeSigned = rr.TimeSigned tsig.Fudge = rr.Fudge - n, _ := packTimerWire(tsig, tsigvar) + n, err := packTimerWire(tsig, tsigvar) + if err != nil { + return nil, err + } tsigvar = tsigvar[:n] } else { tsig := new(tsigWireFmt) @@ -248,7 +260,10 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b tsig.Error = rr.Error tsig.OtherLen = rr.OtherLen tsig.OtherData = rr.OtherData - n, _ := packTsigWire(tsig, tsigvar) + n, err := packTsigWire(tsig, tsigvar) + if err != nil { + return nil, err + } tsigvar = tsigvar[:n] } @@ -258,7 +273,7 @@ func tsigBuffer(msgbuf []byte, rr *TSIG, requestMAC string, timersOnly bool) []b } else { buf = append(msgbuf, tsigvar...) } - return buf + return buf, nil } // Strip the TSIG from the raw message. diff --git a/tsig_test.go b/tsig_test.go index 8e4cc91b..65d4b363 100644 --- a/tsig_test.go +++ b/tsig_test.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "encoding/hex" "fmt" + "strings" "testing" "time" ) @@ -99,11 +100,28 @@ func TestTsigErrors(t *testing.T) { } // call TsigVerify with a message that doesn't contain a TSIG - msgData, _, err := stripTsig(buildMsgData(timeSigned)) + msgData, tsig, err := stripTsig(buildMsgData(timeSigned)) if err != nil { t.Fatal(err) } if err := tsigVerify(msgData, testSecret, "", false, timeSigned); err != ErrNoSig { t.Fatalf("expected an error '%v' but got '%v'", ErrNoSig, err) } + + // replace the test TSIG with a bogus one with large "other data", which would cause overflow in TsigVerify. + // The overflow should be caught without disruption. + tsig.OtherData = strings.Repeat("00", 4096) + tsig.OtherLen = uint16(len(tsig.OtherData) / 2) + msg := new(Msg) + if err = msg.Unpack(msgData); err != nil { + t.Fatal(err) + } + msg.Extra = append(msg.Extra, tsig) + if msgData, err = msg.Pack(); err != nil { + t.Fatal(err) + } + err = tsigVerify(msgData, testSecret, "", false, timeSigned) + if err == nil || !strings.Contains(err.Error(), "overflow") { + t.Errorf("expected error to contain %q, but got %v", "overflow", err) + } }