diff --git a/dnssec_keyscan.go b/dnssec_keyscan.go index 4f8d830b..063094d6 100644 --- a/dnssec_keyscan.go +++ b/dnssec_keyscan.go @@ -169,10 +169,20 @@ func readPrivateKeyECDSA(m map[string]string) (*ecdsa.PrivateKey, error) { // parseKey reads a private key from r. It returns a map[string]string, // with the key-value pairs, or an error when the file is not correct. func parseKey(r io.Reader, file string) (map[string]string, error) { - s := scanInit(r) + s, cancel := scanInit(r) m := make(map[string]string) c := make(chan lex) k := "" + defer func() { + cancel() + // zlexer can send up to two tokens, the next one and possibly 1 remainders. + // Do a non-blocking read. + _, ok := <-c + _, ok = <-c + if !ok { + // too bad + } + }() // Start the lexer go klexer(s, c) for l := range c { diff --git a/leak_test.go b/leak_test.go new file mode 100644 index 00000000..af37011d --- /dev/null +++ b/leak_test.go @@ -0,0 +1,71 @@ +package dns + +import ( + "fmt" + "os" + "runtime" + "sort" + "strings" + "testing" + "time" +) + +// copied from net/http/main_test.go + +func interestingGoroutines() (gs []string) { + buf := make([]byte, 2<<20) + buf = buf[:runtime.Stack(buf, true)] + for _, g := range strings.Split(string(buf), "\n\n") { + sl := strings.SplitN(g, "\n", 2) + if len(sl) != 2 { + continue + } + stack := strings.TrimSpace(sl[1]) + if stack == "" || + strings.Contains(stack, "testing.(*M).before.func1") || + strings.Contains(stack, "os/signal.signal_recv") || + strings.Contains(stack, "created by net.startServer") || + strings.Contains(stack, "created by testing.RunTests") || + strings.Contains(stack, "closeWriteAndWait") || + strings.Contains(stack, "testing.Main(") || + strings.Contains(stack, "testing.(*T).Run(") || + // These only show up with GOTRACEBACK=2; Issue 5005 (comment 28) + strings.Contains(stack, "runtime.goexit") || + strings.Contains(stack, "created by runtime.gc") || + strings.Contains(stack, "dns.interestingGoroutines") || + strings.Contains(stack, "runtime.MHeap_Scavenger") { + continue + } + gs = append(gs, stack) + } + sort.Strings(gs) + return +} + +func goroutineLeaked() error { + if testing.Short() { + // Don't worry about goroutine leaks in -short mode or in + // benchmark mode. Too distracting when there are false positives. + return nil + } + + var stackCount map[string]int + for i := 0; i < 5; i++ { + n := 0 + stackCount = make(map[string]int) + gs := interestingGoroutines() + for _, g := range gs { + stackCount[g]++ + n++ + } + if n == 0 { + return nil + } + // Wait for goroutines to schedule and die off: + time.Sleep(100 * time.Millisecond) + } + for stack, count := range stackCount { + fmt.Fprintf(os.Stderr, "%d instances of:\n%s\n", count, stack) + } + return fmt.Errorf("too many goroutines running after dns test(s)") +} diff --git a/parse_test.go b/parse_test.go index e7ffd408..a61943cd 100644 --- a/parse_test.go +++ b/parse_test.go @@ -1406,6 +1406,18 @@ func TestParseAVC(t *testing.T) { } } +func TestParseBadNAPTR(t *testing.T) { + // Should look like: mplus.ims.vodafone.com. 3600 IN NAPTR 10 100 "S" "SIP+D2U" "" _sip._udp.mplus.ims.vodafone.com. + naptr := `mplus.ims.vodafone.com. 3600 IN NAPTR 10 100 S SIP+D2U _sip._udp.mplus.ims.vodafone.com.` + _, err := NewRR(naptr) // parse fails, we should not have leaked a goroutine. + if err == nil { + t.Fatalf("parsing NAPTR should have failed: %s", naptr) + } + if err := goroutineLeaked(); err != nil { + t.Errorf("leaked goroutines: %s", err) + } +} + func TestUnbalancedParens(t *testing.T) { sig := `example.com. 3600 IN RRSIG MX 15 2 3600 ( 1440021600 1438207200 3613 example.com. ( diff --git a/scan.go b/scan.go index 243b9cf1..ba74b6a5 100644 --- a/scan.go +++ b/scan.go @@ -179,10 +179,22 @@ func parseZone(r io.Reader, origin string, defttl *ttlState, f string, t chan *T close(t) } }() - s := scanInit(r) + s, cancel := scanInit(r) c := make(chan lex) // Start the lexer go zlexer(s, c) + + defer func() { + cancel() + // zlexer can send up to three tokens, the next one and possibly 2 remainders. + // Do a non-blocking read. + _, ok := <-c + _, ok = <-c + _, ok = <-c + if !ok { + // too bad + } + }() // 6 possible beginnings of a line, _ is a space // 0. zRRTYPE -> all omitted until the rrtype // 1. zOwner _ zRrtype -> class/ttl omitted diff --git a/scanner.go b/scanner.go index c29bc2f3..424e5af9 100644 --- a/scanner.go +++ b/scanner.go @@ -4,6 +4,7 @@ package dns import ( "bufio" + "context" "io" "text/scanner" ) @@ -12,13 +13,18 @@ type scan struct { src *bufio.Reader position scanner.Position eof bool // Have we just seen a eof + ctx context.Context } -func scanInit(r io.Reader) *scan { +func scanInit(r io.Reader) (*scan, context.CancelFunc) { s := new(scan) s.src = bufio.NewReader(r) s.position.Line = 1 - return s + + ctx, cancel := context.WithCancel(context.Background()) + s.ctx = ctx + + return s, cancel } // tokenText returns the next byte from the input @@ -27,6 +33,13 @@ func (s *scan) tokenText() (byte, error) { if err != nil { return c, err } + select { + case <-s.ctx.Done(): + return c, context.Canceled + default: + break + } + // delay the newline handling until the next token is delivered, // fixes off-by-one errors when reporting a parse error. if s.eof == true {