//+build ignore package main import ( "bytes" "fmt" "go/format" "log" "os" "strings" "text/template" "golang.org/x/tools/go/loader" "golang.org/x/tools/go/types" ) var skip = map[string]struct{}{ "PrivateRR": struct{}{}, } var skipLen = map[string]struct{}{ "NSEC": struct{}{}, "NSEC3": struct{}{}, "OPT": struct{}{}, "WKS": struct{}{}, "IPSECKEY": struct{}{}, } var packageHdr = ` // *** DO NOT MODIFY *** // AUTOGENERATED BY go generate package dns import ( "encoding/base64" "net" ) ` var typeToRR = template.Must(template.New("typeToRR").Parse(` // Map of constructors for each RR type. var typeToRR = map[uint16]func() RR{ {{range .}}{{if ne . "RFC3597"}} Type{{.}}: func() RR { return new({{.}}) }, {{end}}{{end}} } `)) var typeToString = template.Must(template.New("typeToString").Parse(` // TypeToString is a map of strings for each RR wire type. var TypeToString = map[uint16]string{ {{range .}}{{if ne . "NSAPPTR"}} Type{{.}}: "{{.}}", {{end}}{{end}} TypeNSAPPTR: "NSAP-PTR", } `)) var headerFunc = template.Must(template.New("headerFunc").Parse(` // Header() functions {{range .}} func (rr *{{.}}) Header() *RR_Header { return &rr.Hdr } {{end}} `)) func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) { st, ok := t.Underlying().(*types.Struct) if !ok { return nil, false } if st.Field(0).Type() == scope.Lookup("RR_Header").Type() { return st, false } if st.Field(0).Anonymous() { st, _ := getTypeStruct(st.Field(0).Type(), scope) return st, true } return nil, false } func main() { var conf loader.Config conf.Import(".") prog, err := conf.Load() fatalIfErr(err) scope := prog.Package(".").Pkg.Scope() // Collect constants like TypeX var numberedTypes []string for _, name := range scope.Names() { o := scope.Lookup(name) if o == nil || !o.Exported() { continue } b, ok := o.Type().(*types.Basic) if !ok || b.Kind() != types.Uint16 { continue } if !strings.HasPrefix(o.Name(), "Type") { continue } name := strings.TrimPrefix(o.Name(), "Type") if _, ok := skip[name]; ok { continue } numberedTypes = append(numberedTypes, name) } // Collect actual types (*X) var namedTypes []string for _, name := range scope.Names() { o := scope.Lookup(name) if o == nil || !o.Exported() { continue } if st, _ := getTypeStruct(o.Type(), scope); st == nil { continue } if _, ok := skip[o.Name()]; ok { continue } // Check if corresponding TypeX exists if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" { log.Fatalf("Constant Type%s does not exist.", o.Name()) } namedTypes = append(namedTypes, o.Name()) } b := &bytes.Buffer{} b.WriteString(packageHdr) // Generate typeToRR fatalIfErr(typeToRR.Execute(b, namedTypes)) // Generate typeToString fatalIfErr(typeToString.Execute(b, numberedTypes)) // Generate headerFunc fatalIfErr(headerFunc.Execute(b, namedTypes)) // Generate len() fmt.Fprint(b, "// len() functions\n") for _, name := range namedTypes { if _, ok := skipLen[name]; ok { continue } o := scope.Lookup(name) st, isEmbedded := getTypeStruct(o.Type(), scope) if isEmbedded { continue } fmt.Fprintf(b, "func (rr *%s) len() int {\n", name) fmt.Fprintf(b, "l := rr.Hdr.len()\n") for i := 1; i < st.NumFields(); i++ { o := func(s string) { fmt.Fprintf(b, s, st.Field(i).Name()) } if _, ok := st.Field(i).Type().(*types.Slice); ok { switch st.Tag(i) { case `dns:"-"`: // ignored case `dns:"cdomain-name"`, `dns:"domain-name"`, `dns:"txt"`: o("for _, x := range rr.%s { l += len(x) + 1 }\n") default: log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) } continue } switch st.Tag(i) { case `dns:"-"`: // ignored case `dns:"cdomain-name"`, `dns:"domain-name"`: o("l += len(rr.%s) + 1\n") case `dns:"octet"`: o("l += len(rr.%s)\n") case `dns:"base64"`: o("l += base64.StdEncoding.DecodedLen(len(rr.%s))\n") case `dns:"size-hex"`, `dns:"hex"`: o("l += len(rr.%s)/2 + 1\n") case `dns:"a"`: o("l += net.IPv4len // %s\n") case `dns:"aaaa"`: o("l += net.IPv6len // %s\n") case `dns:"txt"`: o("for _, t := range rr.%s { l += len(t) + 1 }\n") case `dns:"uint48"`: o("l += 6 // %s\n") case "": switch st.Field(i).Type().(*types.Basic).Kind() { case types.Uint8: o("l += 1 // %s\n") case types.Uint16: o("l += 2 // %s\n") case types.Uint32: o("l += 4 // %s\n") case types.Uint64: o("l += 8 // %s\n") case types.String: o("l += len(rr.%s) + 1\n") default: log.Fatalln(name, st.Field(i).Name()) } default: log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) } } fmt.Fprintf(b, "return l }\n") } res, err := format.Source(b.Bytes()) fatalIfErr(err) f, err := os.Create("types_auto.go") fatalIfErr(err) defer f.Close() f.Write(res) } func fatalIfErr(err error) { if err != nil { log.Fatal(err) } }