dns/types_generate.go

227 lines
4.9 KiB
Go
Raw Normal View History

//+build ignore
package main
import (
"bytes"
"fmt"
"go/format"
"log"
"os"
"strings"
"text/template"
"golang.org/x/tools/go/loader"
"golang.org/x/tools/go/types"
)
// skip holds RR names that the generator ignores completely. A name listed
// here is dropped both as a TypeX constant suffix and as a struct type name,
// so it gets no map entries and no generated methods.
var skip = map[string]struct{}{
	"PrivateRR": {},
}
// skipLen holds RR type names for which no len() method is generated.
// These types still get their map entries and Header() method. Presumably
// they provide a hand-written len() elsewhere in the package (TODO confirm).
var skipLen = map[string]struct{}{
	"IPSECKEY": {},
	"NSEC":     {},
	"NSEC3":    {},
	"OPT":      {},
	"WKS":      {},
}
// packageHdr is the preamble written at the top of the generated
// types_auto.go, before any template output.
//
// The first comment uses Go's standard generated-file marker
// ("// Code generated ... DO NOT EDIT."), so linters, editors and code review
// tools recognize the file as machine-written. The blank line after the
// marker keeps it from being attached to the package clause as package
// documentation. Import indentation is irrelevant because the output is run
// through go/format.
var packageHdr = `
// Code generated by go generate; DO NOT EDIT.

package dns

import (
"encoding/base64"
"net"
)
`
// typeToRR renders the unexported typeToRR map of the generated file. The map
// is keyed by each TypeX constant and holds a constructor returning new(X).
// It is executed with the list of named RR struct types. RFC3597 is left out
// of the map; main likewise does not require a TypeRFC3597 constant for it.
var typeToRR = template.Must(template.New("typeToRR").Parse(`
// Map of constructors for each RR type.
var typeToRR = map[uint16]func() RR{
{{range .}}{{if ne . "RFC3597"}} Type{{.}}: func() RR { return new({{.}}) },
{{end}}{{end}} }
`))
// typeToString renders the exported TypeToString map of the generated file.
// It is executed with the suffixes X of the TypeX constants and maps each
// constant to the string "X". NSAPPTR is excluded from the range and emitted
// separately, so its string is "NSAP-PTR" rather than the bare suffix.
var typeToString = template.Must(template.New("typeToString").Parse(`
// TypeToString is a map of strings for each RR wire type.
var TypeToString = map[uint16]string{
{{range .}}{{if ne . "NSAPPTR"}} Type{{.}}: "{{.}}",
{{end}}{{end}} TypeNSAPPTR: "NSAP-PTR",
}
`))
// headerFunc renders one Header() method per named RR struct type. Each
// method returns a pointer to the struct's Hdr field. It is executed with the
// same list of type names as typeToRR.
var headerFunc = template.Must(template.New("headerFunc").Parse(`
// Header() functions
{{range .}} func (rr *{{.}}) Header() *RR_Header { return &rr.Hdr }
{{end}}
`))
// getTypeStruct reports whether t is an RR struct type of the package whose
// scope is given. It returns the underlying struct and whether that struct
// was reached through an embedded first field.
//
//   - If t's underlying type is a struct whose first field has exactly the
//     RR_Header type declared in scope, that struct is returned with false.
//   - If the first field is instead an embedded (anonymous) field, the lookup
//     recurses into the embedded type and returns its struct with true. The
//     struct is nil if the embedded type itself does not qualify.
//   - Otherwise it returns (nil, false). This covers a non-struct t, a struct
//     with no fields, and a first field that is neither RR_Header nor
//     embedded.
//
// scope must contain an RR_Header declaration.
func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
	st, ok := t.Underlying().(*types.Struct)
	if !ok || st.NumFields() == 0 {
		// An empty struct cannot be an RR type, and st.Field(0) would panic.
		return nil, false
	}
	first := st.Field(0)
	// Named types are canonical in go/types, so identity comparison is
	// sufficient to detect the RR_Header type.
	if first.Type() == scope.Lookup("RR_Header").Type() {
		return st, false
	}
	if first.Anonymous() {
		embedded, _ := getTypeStruct(first.Type(), scope)
		return embedded, true
	}
	return nil, false
}
// main loads the package in the current directory (expected to be package
// dns, matching packageHdr) and writes types_auto.go. The output contains:
//   - the typeToRR constructor map,
//   - the TypeToString map,
//   - a Header() method for every RR struct type,
//   - a len() method for every RR struct type not listed in skipLen and not
//     built by embedding another type.
//
// Any unexpected declaration aborts generation via log.Fatal.
func main() {
	var conf loader.Config
	conf.Import(".")
	prog, err := conf.Load()
	fatalIfErr(err)
	scope := prog.Package(".").Pkg.Scope()

	numberedTypes := collectNumberedTypes(scope)
	namedTypes := collectNamedTypes(scope)

	b := &bytes.Buffer{}
	b.WriteString(packageHdr)

	// Generate typeToRR
	fatalIfErr(typeToRR.Execute(b, namedTypes))
	// Generate typeToString
	fatalIfErr(typeToString.Execute(b, numberedTypes))
	// Generate headerFunc
	fatalIfErr(headerFunc.Execute(b, namedTypes))

	// Generate len()
	fmt.Fprint(b, "// len() functions\n")
	for _, name := range namedTypes {
		if _, ok := skipLen[name]; ok {
			continue
		}
		st, isEmbedded := getTypeStruct(scope.Lookup(name).Type(), scope)
		if isEmbedded {
			// These types rely on the len() promoted from the type they
			// embed, so no method is generated for them.
			continue
		}
		writeLenFunc(b, name, st)
	}

	res, err := format.Source(b.Bytes())
	fatalIfErr(err)

	f, err := os.Create("types_auto.go")
	fatalIfErr(err)
	// log.Fatal exits without running deferred calls, so a deferred Close
	// would not help on the error path. Check both Write and Close so a
	// truncated or unflushed output file is reported rather than silently
	// accepted.
	_, err = f.Write(res)
	fatalIfErr(err)
	fatalIfErr(f.Close())
}

// collectNumberedTypes returns the suffix X of every exported uint16 object
// named TypeX in scope, in scope.Names() order. Suffixes listed in skip are
// omitted.
func collectNumberedTypes(scope *types.Scope) []string {
	var numbered []string
	for _, ident := range scope.Names() {
		o := scope.Lookup(ident)
		if o == nil || !o.Exported() {
			continue
		}
		if basic, ok := o.Type().(*types.Basic); !ok || basic.Kind() != types.Uint16 {
			continue
		}
		if !strings.HasPrefix(o.Name(), "Type") {
			continue
		}
		name := strings.TrimPrefix(o.Name(), "Type")
		if _, ok := skip[name]; ok {
			continue
		}
		numbered = append(numbered, name)
	}
	return numbered
}

// collectNamedTypes returns the name of every exported object in scope whose
// type getTypeStruct recognizes as an RR struct, in scope.Names() order.
// Names listed in skip are omitted. Generation aborts if any such type other
// than RFC3597 lacks a corresponding TypeX constant.
func collectNamedTypes(scope *types.Scope) []string {
	var named []string
	for _, ident := range scope.Names() {
		o := scope.Lookup(ident)
		if o == nil || !o.Exported() {
			continue
		}
		if st, _ := getTypeStruct(o.Type(), scope); st == nil {
			continue
		}
		if _, ok := skip[o.Name()]; ok {
			continue
		}
		// Check if corresponding TypeX exists
		if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" {
			log.Fatalf("Constant Type%s does not exist.", o.Name())
		}
		named = append(named, o.Name())
	}
	return named
}

// writeLenFunc appends to b a len() method for the RR struct type name, whose
// fields are described by st. Field 0 is the RR_Header (as established by
// getTypeStruct) and is accounted for with rr.Hdr.len(). Each remaining field
// adds to the total according to its dns struct tag. Untagged fields are
// sized by their basic kind. Unknown tags, or untagged fields that are not
// of a supported basic type, abort generation.
func writeLenFunc(b *bytes.Buffer, name string, st *types.Struct) {
	fmt.Fprintf(b, "func (rr *%s) len() int {\n", name)
	fmt.Fprint(b, "l := rr.Hdr.len()\n")
	for i := 1; i < st.NumFields(); i++ {
		field := st.Field(i)
		tag := st.Tag(i)
		// emit writes a code line whose single %s is the field name.
		emit := func(s string) { fmt.Fprintf(b, s, field.Name()) }

		if _, ok := field.Type().(*types.Slice); ok {
			switch tag {
			case `dns:"-"`:
				// ignored
			case `dns:"cdomain-name"`, `dns:"domain-name"`, `dns:"txt"`:
				emit("for _, x := range rr.%s { l += len(x) + 1 }\n")
			default:
				log.Fatalln(name, field.Name(), tag)
			}
			continue
		}

		switch tag {
		case `dns:"-"`:
			// ignored
		case `dns:"cdomain-name"`, `dns:"domain-name"`:
			emit("l += len(rr.%s) + 1\n")
		case `dns:"octet"`:
			emit("l += len(rr.%s)\n")
		case `dns:"base64"`:
			emit("l += base64.StdEncoding.DecodedLen(len(rr.%s))\n")
		case `dns:"size-hex"`, `dns:"hex"`:
			emit("l += len(rr.%s)/2 + 1\n")
		case `dns:"a"`:
			emit("l += net.IPv4len // %s\n")
		case `dns:"aaaa"`:
			emit("l += net.IPv6len // %s\n")
		case `dns:"txt"`:
			emit("for _, t := range rr.%s { l += len(t) + 1 }\n")
		case `dns:"uint48"`:
			emit("l += 6 // %s\n")
		case "":
			// Untagged fields must be basic types. Report anything else
			// instead of panicking on an unchecked type assertion.
			basic, ok := field.Type().(*types.Basic)
			if !ok {
				log.Fatalln(name, field.Name(), field.Type())
			}
			switch basic.Kind() {
			case types.Uint8:
				emit("l += 1 // %s\n")
			case types.Uint16:
				emit("l += 2 // %s\n")
			case types.Uint32:
				emit("l += 4 // %s\n")
			case types.Uint64:
				emit("l += 8 // %s\n")
			case types.String:
				emit("l += len(rr.%s) + 1\n")
			default:
				log.Fatalln(name, field.Name())
			}
		default:
			log.Fatalln(name, field.Name(), tag)
		}
	}
	fmt.Fprint(b, "return l }\n")
}
// fatalIfErr aborts the generator through log.Fatal when err is non-nil. A nil
// err is a no-op, so it can wrap any call whose only failure mode is fatal.
func fatalIfErr(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}