diff --git a/scan.go b/scan.go index 8d4773c3..5f7f6442 100644 --- a/scan.go +++ b/scan.go @@ -278,8 +278,7 @@ func parseZone(r io.Reader, origin, f string, t chan *Token, include int) { return } neworigin := origin // There may be optionally a new origin set after the filename, if not use current one - l := <-c - switch l.value { + switch l := <-c; l.value { case zBlank: l := <-c if l.value == zString { diff --git a/scan_test.go b/scan_test.go new file mode 100644 index 00000000..b31c4c77 --- /dev/null +++ b/scan_test.go @@ -0,0 +1,45 @@ +package dns + +import ( + "io/ioutil" + "os" + "strings" + "testing" +) + +func TestParseZoneInclude(t *testing.T) { + + tmpfile, err := ioutil.TempFile("", "dns") + if err != nil { + t.Fatalf("could not create tmpfile for test: %s", err) + } + + if _, err := tmpfile.WriteString("foo\tIN\tA\t127.0.0.1"); err != nil { + t.Fatalf("unable to write content to tmpfile %q: %s", tmpfile.Name(), err) + } + if err := tmpfile.Close(); err != nil { + t.Fatalf("could not close tmpfile %q: %s", tmpfile.Name(), err) + } + + zone := "$INCLUDE " + tmpfile.Name() + + tok := ParseZone(strings.NewReader(zone), "", "") + for x := range tok { + if x.Error != nil { + t.Fatalf("expected no error, but got %s", x.Error) + } + } + + os.Remove(tmpfile.Name()) + + tok = ParseZone(strings.NewReader(zone), "", "") + for x := range tok { + if x.Error == nil { + t.Fatalf("expected first token to contain an error but it didn't") + } + if !strings.Contains(x.Error.Error(), "failed to open") || + !strings.Contains(x.Error.Error(), tmpfile.Name()) { + t.Fatalf(`expected error to contain: "failed to open" and %q but got: %s`, tmpfile.Name(), x.Error) + } + } +}