Ran gofix, and manually bring code up to latest go release.

1) Ran gofix on all files.
2) Added "tcp" and "udp" to Resolve* functions in server.go
3) Generated primes to the primes array and not to two predefined
   struct members (P and Q), since now rsa support multi-factor primes.
This commit is contained in:
Elazar 2011-06-02 13:31:12 +03:00
parent 7dbb865c03
commit 904e322dfd
5 changed files with 79 additions and 78 deletions

View File

@ -149,7 +149,7 @@ func (q *Query) Query() os.Error {
if handler == nil {
handler = DefaultQueryMux
}
forever:
//forever:
for {
select {
case in := <-q.ChannelQuery:
@ -314,7 +314,7 @@ func (w *reply) Send(m *Msg) os.Error {
return ErrNoSig
}
m, _ = TsigGenerate(m, w.Client().TsigSecret[secret], w.tsigRequestMAC, w.tsigTimersOnly)
w.tsigRequestMAC = m.Extra[len(m.Extra)-1].(*RR_TSIG).MAC // Safe the requestMAC
w.tsigRequestMAC = m.Extra[len(m.Extra)-1].(*RR_TSIG).MAC // Safe the requestMAC
}
out, ok := m.Pack()
if !ok {
@ -336,7 +336,7 @@ func (w *reply) writeClient(p []byte) (n int, err os.Error) {
panic("c.Net empty")
}
conn, err := net.Dial(c.Net, "", w.addr)
conn, err := net.Dial(c.Net, w.addr)
if err != nil {
return 0, err
}

View File

@ -17,24 +17,24 @@ import (
// Wrap the contents of the /etc/resolv.conf.
type ClientConfig struct {
Servers []string // servers to use
Search []string // suffixes to append to local name
Port string // what port to use
Ndots int // number of dots in name to trigger absolute lookup
Timeout int // seconds before giving up on packet
Attempts int // lost packets before giving up on server
Servers []string // servers to use
Search []string // suffixes to append to local name
Port string // what port to use
Ndots int // number of dots in name to trigger absolute lookup
Timeout int // seconds before giving up on packet
Attempts int // lost packets before giving up on server
}
// See resolv.conf(5) on a Linux machine.
// Parse a /etc/resolv.conf like file and return a filled out ClientConfig. Note
// that all nameservers will have the port number appendend (:53)
func ClientConfigFromFile(conf string) (*ClientConfig, os.Error) {
file, err := os.Open(conf, os.O_RDONLY, 0)
file, err := os.Open(conf)
defer file.Close()
if err != nil {
return nil, err
}
c := new(ClientConfig)
c := new(ClientConfig)
b := bufio.NewReader(file)
c.Servers = make([]string, 3)[0:0] // small, but the standard limit
c.Search = make([]string, 0)

View File

@ -58,17 +58,17 @@ func (r *RR_DNSKEY) PrivateKeyString(p PrivateKey) (s string) {
e := big.NewInt(int64(t.PublicKey.E))
publicExponent := unpackBase64(e.Bytes())
privateExponent := unpackBase64(t.D.Bytes())
prime1 := unpackBase64(t.P.Bytes())
prime2 := unpackBase64(t.Q.Bytes())
prime1 := unpackBase64(t.Primes[0].Bytes())
prime2 := unpackBase64(t.Primes[1].Bytes())
// Calculate Exponent1/2 and Coefficient as per: http://en.wikipedia.org/wiki/RSA#Using_the_Chinese_remainder_algorithm
// and from: http://code.google.com/p/go/issues/detail?id=987
one := big.NewInt(1)
minusone := big.NewInt(-1)
p_1 := big.NewInt(0).Sub(t.P, one)
q_1 := big.NewInt(0).Sub(t.Q, one)
p_1 := big.NewInt(0).Sub(t.Primes[0], one)
q_1 := big.NewInt(0).Sub(t.Primes[1], one)
exp1 := big.NewInt(0).Mod(t.D, p_1)
exp2 := big.NewInt(0).Mod(t.D, q_1)
coeff := big.NewInt(0).Exp(t.Q, minusone, t.P)
coeff := big.NewInt(0).Exp(t.Primes[1], minusone, t.Primes[0])
exponent1 := unpackBase64(exp1.Bytes())
exponent2 := unpackBase64(exp2.Bytes())
@ -91,6 +91,7 @@ func (r *RR_DNSKEY) PrivateKeyString(p PrivateKey) (s string) {
// Read a private key (file) string and create a public key. Return the private key.
func (k *RR_DNSKEY) ReadPrivateKey(q io.Reader) (PrivateKey, os.Error) {
p := new(rsa.PrivateKey)
p.Primes = []*big.Int{nil,nil}
var left, right string
r := line.NewReader(q, 300)
line, _, err := r.ReadLine()
@ -128,12 +129,12 @@ func (k *RR_DNSKEY) ReadPrivateKey(q io.Reader) (PrivateKey, os.Error) {
p.D.SetBytes(v)
}
if left == "Prime1:" {
p.P = big.NewInt(0)
p.P.SetBytes(v)
p.Primes[0] = big.NewInt(0)
p.Primes[0].SetBytes(v)
}
if left == "Prime2:" {
p.Q = big.NewInt(0)
p.Q.SetBytes(v)
p.Primes[1] = big.NewInt(0)
p.Primes[1].SetBytes(v)
}
case "Exponent1:", "Exponent2:", "Coefficient:":
/* not used in Go (yet) */

112
msg.go
View File

@ -40,13 +40,13 @@ var (
ErrKeySize os.Error = &Error{Error: "bad key size"}
ErrAlg os.Error = &Error{Error: "bad algorithm"}
ErrTime os.Error = &Error{Error: "bad time"}
ErrNoSig os.Error = &Error{Error: "no signature found"}
ErrNoSig os.Error = &Error{Error: "no signature found"}
ErrSig os.Error = &Error{Error: "bad signature"}
ErrSigGen os.Error = &Error{Error: "bad signature generation"}
ErrAuth os.Error = &Error{Error: "bad authentication"}
ErrAuth os.Error = &Error{Error: "bad authentication"}
ErrXfrSoa os.Error = &Error{Error: "no SOA seen"}
ErrHandle os.Error = &Error{Error: "handle is nil"}
ErrChan os.Error = &Error{Error: "channel is nil"}
ErrHandle os.Error = &Error{Error: "handle is nil"}
ErrChan os.Error = &Error{Error: "channel is nil"}
)
// A manually-unpacked version of (id, bits).
@ -287,25 +287,25 @@ Loop:
// Pack a reflect.StructValue into msg. Struct members can only be uint8, uint16, uint32, string,
// slices and other (often anonymous) structs.
func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, ok bool) {
func packStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) {
for i := 0; i < val.NumField(); i++ {
f := val.Type().(*reflect.StructType).Field(i)
switch fv := val.Field(i).(type) {
f := val.Type().Field(i)
switch fv := val.Field(i); fv.Kind() {
default:
BadType:
//fmt.Fprintf(os.Stderr, "dns: unknown packing type %v\n", f.Type)
return len(msg), false
case *reflect.SliceValue:
case reflect.Slice:
switch f.Tag {
default:
//fmt.Fprintf(os.Stderr, "dns: unknown packing slice tag %v\n", f.Tag)
return len(msg), false
case "OPT": // edns
for j := 0; j < val.Field(i).(*reflect.SliceValue).Len(); j++ {
element := val.Field(i).(*reflect.SliceValue).Elem(j)
code := uint16(element.(*reflect.StructValue).Field(0).(*reflect.UintValue).Get())
for j := 0; j < val.Field(i).Len(); j++ {
element := val.Field(i).Index(j)
code := uint16(element.Field(0).Uint())
// for each code we should do something else
h, e := hex.DecodeString(string(element.(*reflect.StructValue).Field(1).(*reflect.StringValue).Get()))
h, e := hex.DecodeString(string(element.Field(1).String()))
if e != nil {
//fmt.Fprintf(os.Stderr, "dns: failure packing OTP")
return len(msg), false
@ -329,15 +329,15 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o
return len(msg), false
}
if fv.Len() == net.IPv6len {
msg[off] = byte(fv.Elem(12).(*reflect.UintValue).Get())
msg[off+1] = byte(fv.Elem(13).(*reflect.UintValue).Get())
msg[off+2] = byte(fv.Elem(14).(*reflect.UintValue).Get())
msg[off+3] = byte(fv.Elem(15).(*reflect.UintValue).Get())
msg[off] = byte(fv.Index(12).Uint())
msg[off+1] = byte(fv.Index(13).Uint())
msg[off+2] = byte(fv.Index(14).Uint())
msg[off+3] = byte(fv.Index(15).Uint())
} else {
msg[off] = byte(fv.Elem(0).(*reflect.UintValue).Get())
msg[off+1] = byte(fv.Elem(1).(*reflect.UintValue).Get())
msg[off+2] = byte(fv.Elem(2).(*reflect.UintValue).Get())
msg[off+3] = byte(fv.Elem(3).(*reflect.UintValue).Get())
msg[off] = byte(fv.Index(0).Uint())
msg[off+1] = byte(fv.Index(1).Uint())
msg[off+2] = byte(fv.Index(2).Uint())
msg[off+3] = byte(fv.Index(3).Uint())
}
off += net.IPv4len
case "AAAA":
@ -346,20 +346,20 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o
return len(msg), false
}
for j := 0; j < net.IPv6len; j++ {
msg[off] = byte(fv.Elem(j).(*reflect.UintValue).Get())
msg[off] = byte(fv.Index(j).Uint())
off++
}
case "NSEC": // NSEC/NSEC3
for j := 0; j < val.Field(i).(*reflect.SliceValue).Len(); j++ {
var _ = byte(fv.Elem(j).(*reflect.UintValue).Get())
for j := 0; j < val.Field(i).Len(); j++ {
var _ = byte(fv.Index(j).Uint())
}
// handle type bit maps
// TODO(mg)
}
case *reflect.StructValue:
case reflect.Struct:
off, ok = packStructValue(fv, msg, off)
case *reflect.UintValue:
i := fv.Get()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
i := fv.Uint()
switch fv.Type().Kind() {
default:
goto BadType
@ -402,10 +402,10 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o
msg[off+5] = byte(i)
off += 6
}
case *reflect.StringValue:
case reflect.String:
// There are multiple string encodings.
// The tag distinguishes ordinary strings from domain names.
s := fv.Get()
s := fv.String()
switch f.Tag {
default:
//fmt.Fprintf(os.Stderr, "dns: unknown packing string tag %v", f.Tag)
@ -470,8 +470,8 @@ func packStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, o
return off, true
}
func structValue(any interface{}) *reflect.StructValue {
return reflect.NewValue(any).(*reflect.PtrValue).Elem().(*reflect.StructValue)
func structValue(any interface{}) reflect.Value {
return reflect.ValueOf(any).Elem()
}
func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) {
@ -481,15 +481,15 @@ func packStruct(any interface{}, msg []byte, off int) (off1 int, ok bool) {
// Unpack a reflect.StructValue from msg.
// Same restrictions as packStructValue.
func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int, ok bool) {
func unpackStructValue(val reflect.Value, msg []byte, off int) (off1 int, ok bool) {
for i := 0; i < val.NumField(); i++ {
f := val.Type().(*reflect.StructType).Field(i)
switch fv := val.Field(i).(type) {
f := val.Type().Field(i)
switch fv := val.Field(i); fv.Kind() {
default:
BadType:
//fmt.Fprintf(os.Stderr, "dns: unknown unpacking type %v", f.Type)
return len(msg), false
case *reflect.SliceValue:
case reflect.Slice:
switch f.Tag {
default:
//fmt.Fprintf(os.Stderr, "dns: unknown unpacking slice tag %v", f.Tag)
@ -500,7 +500,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
return len(msg), false
}
b := net.IPv4(msg[off], msg[off+1], msg[off+2], msg[off+3])
fv.Set(reflect.NewValue(b).(*reflect.SliceValue))
fv.Set(reflect.ValueOf(b))
off += net.IPv4len
case "AAAA":
if off+net.IPv6len > len(msg) {
@ -510,7 +510,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
p := make(net.IP, net.IPv6len)
copy(p, msg[off:off+net.IPv6len])
b := net.IP(p)
fv.Set(reflect.NewValue(b).(*reflect.SliceValue))
fv.Set(reflect.ValueOf(b))
off += net.IPv6len
case "OPT": // EDNS
if off+2 > len(msg) {
@ -526,7 +526,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
return len(msg), false
}
opt[0].Data = hex.EncodeToString(msg[off1 : off1+int(optlen)])
fv.Set(reflect.NewValue(opt).(*reflect.SliceValue))
fv.Set(reflect.ValueOf(opt))
off = off1 + int(optlen)
case "NSEC": // NSEC/NSEC3
if off+1 > len(msg) {
@ -546,7 +546,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
// Nothing encoded in this window
// Kinda lame to alloc above and to clear it here
nsec = nsec[:ni]
fv.Set(reflect.NewValue(nsec).(*reflect.SliceValue))
fv.Set(reflect.ValueOf(nsec))
break
}
@ -588,12 +588,12 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
}
}
nsec = nsec[:ni]
fv.Set(reflect.NewValue(nsec).(*reflect.SliceValue))
fv.Set(reflect.ValueOf(nsec))
off += blocks
}
case *reflect.StructValue:
case reflect.Struct:
off, ok = unpackStructValue(fv, msg, off)
case *reflect.UintValue:
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
switch fv.Type().Kind() {
default:
goto BadType
@ -603,7 +603,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
return len(msg), false
}
i := uint8(msg[off])
fv.Set(uint64(i))
fv.SetUint(uint64(i))
off++
case reflect.Uint16:
var i uint16
@ -612,14 +612,14 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
return len(msg), false
}
i, off = unpackUint16(msg, off)
fv.Set(uint64(i))
fv.SetUint(uint64(i))
case reflect.Uint32:
if off+4 > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: overflow unpacking uint32")
return len(msg), false
}
i := uint32(msg[off])<<24 | uint32(msg[off+1])<<16 | uint32(msg[off+2])<<8 | uint32(msg[off+3])
fv.Set(uint64(i))
fv.SetUint(uint64(i))
off += 4
case reflect.Uint64:
// This is *only* used in TSIG where the last 48 bits are occupied
@ -630,10 +630,10 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
}
i := uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 |
uint64(msg[off+4])<<8 | uint64(msg[off+5])
fv.Set(uint64(i))
fv.SetUint(uint64(i))
off += 6
}
case *reflect.StringValue:
case reflect.String:
var s string
switch f.Tag {
default:
@ -641,7 +641,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
return len(msg), false
case "hex":
// Rest of the RR is hex encoded, network order an issue here?
rdlength := int(val.FieldByName("Hdr").(*reflect.StructValue).FieldByName("Rdlength").(*reflect.UintValue).Get())
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
var consumed int
switch val.Type().Name() {
case "RR_DS":
@ -659,7 +659,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
off += rdlength - consumed
case "base64":
// Rest of the RR is base64 encoded value
rdlength := int(val.FieldByName("Hdr").(*reflect.StructValue).FieldByName("Rdlength").(*reflect.UintValue).Get())
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
// Need to know how much of rdlength is already consumed, in this packet
var consumed int
switch val.Type().Name() {
@ -670,7 +670,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
// OrigTTL(4) + SigExpir(4) + SigIncep(4) + KeyTag(2) + len(signername)
// Should already be set in the sequence of parsing (comes before)
// Work because of rfc4034, section 3.17
consumed += len(val.FieldByName("SignerName").(*reflect.StringValue).Get()) + 1
consumed += len(val.FieldByName("SignerName").String()) + 1
default:
consumed = 0 // TODO
}
@ -689,7 +689,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
switch f.Name {
case "NextDomain":
name := val.FieldByName("HashLength")
size = int(name.(*reflect.UintValue).Get())
size = int(name.Uint())
}
}
if off+size > len(msg) {
@ -706,19 +706,19 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
switch f.Name {
case "Salt":
name := val.FieldByName("SaltLength")
size = int(name.(*reflect.UintValue).Get())
size = int(name.Uint())
case "NextDomain":
name := val.FieldByName("HashLength")
size = int(name.(*reflect.UintValue).Get())
size = int(name.Uint())
}
case "RR_TSIG":
switch f.Name {
case "MAC":
name := val.FieldByName("MACSize")
size = int(name.(*reflect.UintValue).Get())
size = int(name.Uint())
case "OtherData":
name := val.FieldByName("OtherLen")
size = int(name.(*reflect.UintValue).Get())
size = int(name.Uint())
}
}
if off+size > len(msg) {
@ -729,7 +729,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
off += size
case "txt":
// 1 or multiple txt pieces
rdlength := int(val.FieldByName("Hdr").(*reflect.StructValue).FieldByName("Rdlength").(*reflect.UintValue).Get())
rdlength := int(val.FieldByName("Hdr").FieldByName("Rdlength").Uint())
Txt:
if off >= len(msg) || off+1+int(msg[off]) > len(msg) {
//fmt.Fprintf(os.Stderr, "dns: failure unpacking txt string")
@ -761,7 +761,7 @@ func unpackStructValue(val *reflect.StructValue, msg []byte, off int) (off1 int,
off += n
s = string(b)
}
fv.Set(s)
fv.SetString(s)
}
}
return off, true

View File

@ -190,7 +190,7 @@ func (srv *Server) ListenAndServe() os.Error {
}
switch srv.Net {
case "tcp":
a, e := net.ResolveTCPAddr(addr)
a, e := net.ResolveTCPAddr("tcp",addr)
if e != nil {
return e
}
@ -200,7 +200,7 @@ func (srv *Server) ListenAndServe() os.Error {
}
return srv.ServeTCP(l)
case "udp":
a, e := net.ResolveUDPAddr(addr)
a, e := net.ResolveUDPAddr("udp",addr)
if e != nil {
return e
}