diff --git a/generate.go b/generate.go index e4481a4b..3a559793 100644 --- a/generate.go +++ b/generate.go @@ -133,10 +133,20 @@ BuildRR: // Convert a $GENERATE modifier 0,0,d to something Printf can deal with. func modToPrintf(s string) (string, int, error) { - xs := strings.SplitN(s, ",", 3) - if len(xs) != 3 { + xs := strings.Split(s, ",") + + // Modifier is { offset [ ,width [ ,base ] ] } - provide default + // values for optional width and type, if necessary. + switch len(xs) { + case 1: + xs = append(xs, "0", "d") + case 2: + xs = append(xs, "d") + case 3: + default: return "", 0, errors.New("bad modifier in $GENERATE") } + // xs[0] is offset, xs[1] is width, xs[2] is base if xs[2] != "o" && xs[2] != "d" && xs[2] != "x" && xs[2] != "X" { return "", 0, errors.New("bad base in $GENERATE") diff --git a/generate_test.go b/generate_test.go new file mode 100644 index 00000000..1df6abd6 --- /dev/null +++ b/generate_test.go @@ -0,0 +1,39 @@ +package dns + +import ( + "testing" +) + +func TestGenerateModToPrintf(t *testing.T) { + tests := []struct { + mod string + wantFmt string + wantOffset int + wantErr bool + }{ + {"0,0,d", "%0d", 0, false}, + {"0,0", "%0d", 0, false}, + {"0", "%0d", 0, false}, + {"3,2,d", "%02d", 3, false}, + {"3,2", "%02d", 3, false}, + {"3", "%0d", 3, false}, + {"0,0,o", "%0o", 0, false}, + {"0,0,x", "%0x", 0, false}, + {"0,0,X", "%0X", 0, false}, + {"0,0,z", "", 0, true}, + {"0,0,0,d", "", 0, true}, + } + for _, test := range tests { + gotFmt, gotOffset, err := modToPrintf(test.mod) + switch { + case err != nil && !test.wantErr: + t.Errorf("modToPrintf(%q) - expected nil-error, but got %v", test.mod, err) + case err == nil && test.wantErr: + t.Errorf("modToPrintf(%q) - expected error, but got nil-error", test.mod) + case gotFmt != test.wantFmt: + t.Errorf("modToPrintf(%q) - expected format %q, but got %q", test.mod, test.wantFmt, gotFmt) + case gotOffset != test.wantOffset: + t.Errorf("modToPrintf(%q) - expected offset %d, but got %d", test.mod, test.wantOffset, gotOffset) + } + } +} diff --git a/scan_test.go b/scan_test.go index e43ad447..88baa602 100644 --- a/scan_test.go +++ b/scan_test.go @@ -2,11 +2,44 @@ package dns import ( "io/ioutil" + "net" "os" "strings" "testing" ) +func TestParseZoneGenerate(t *testing.T) { + zone := "$ORIGIN example.org.\n$GENERATE 10-12 foo${2,3,d} IN A 127.0.0.$" + + wantRRs := []RR{ + &A{Hdr: RR_Header{Name: "foo012.example.org."}, A: net.ParseIP("127.0.0.10")}, + &A{Hdr: RR_Header{Name: "foo013.example.org."}, A: net.ParseIP("127.0.0.11")}, + &A{Hdr: RR_Header{Name: "foo014.example.org."}, A: net.ParseIP("127.0.0.12")}, + } + wantIdx := 0 + + tok := ParseZone(strings.NewReader(zone), "", "") + for x := range tok { + if wantIdx >= len(wantRRs) { + t.Fatalf("expected %d RRs, but got more", len(wantRRs)) + } + if x.Error != nil { + t.Fatalf("expected no error, but got %s", x.Error) + } + if got, want := x.RR.Header().Name, wantRRs[wantIdx].Header().Name; got != want { + t.Fatalf("expected name %s, but got %s", want, got) + } + a, ok := x.RR.(*A) + if !ok { + t.Fatalf("expected *A RR, but got %T", x.RR) + } + if got, want := a.A, wantRRs[wantIdx].(*A).A; !got.Equal(want) { + t.Fatalf("expected A with IP %v, but got %v", got, want) + } + wantIdx++ + } +} + func TestParseZoneInclude(t *testing.T) { tmpfile, err := ioutil.TempFile("", "dns")