diff --git a/CONTRIBUTORS b/CONTRIBUTORS index f77e8a89..5903779d 100644 --- a/CONTRIBUTORS +++ b/CONTRIBUTORS @@ -7,3 +7,4 @@ Marek Majkowski Peter van Dijk Omri Bahumi Alex Sergeyev +James Hartig diff --git a/internal/socket/cmsghdr.go b/internal/socket/cmsghdr.go new file mode 100644 index 00000000..62f2d2f7 --- /dev/null +++ b/internal/socket/cmsghdr.go @@ -0,0 +1,7 @@ +// +build linux + +package socket + +func (h *cmsghdr) len() int { return int(h.Len) } +func (h *cmsghdr) lvl() int { return int(h.Level) } +func (h *cmsghdr) typ() int { return int(h.Type) } diff --git a/internal/socket/cmsghdr_linux_32bit.go b/internal/socket/cmsghdr_linux_32bit.go new file mode 100644 index 00000000..e92e8580 --- /dev/null +++ b/internal/socket/cmsghdr_linux_32bit.go @@ -0,0 +1,20 @@ +// +build arm mips mipsle 386 +// +build linux + +package socket + +type cmsghdr struct { + Len uint32 + Level int32 + Type int32 +} + +const ( + sizeofCmsghdr = 0xc +) + +func (h *cmsghdr) set(l, lvl, typ int) { + h.Len = uint32(l) + h.Level = int32(lvl) + h.Type = int32(typ) +} diff --git a/internal/socket/cmsghdr_linux_64bit.go b/internal/socket/cmsghdr_linux_64bit.go new file mode 100644 index 00000000..ddfc9e09 --- /dev/null +++ b/internal/socket/cmsghdr_linux_64bit.go @@ -0,0 +1,20 @@ +// +build arm64 amd64 ppc64 ppc64le mips64 mips64le s390x +// +build linux + +package socket + +type cmsghdr struct { + Len uint64 + Level int32 + Type int32 +} + +const ( + sizeofCmsghdr = 0x10 +) + +func (h *cmsghdr) set(l, lvl, typ int) { + h.Len = uint64(l) + h.Level = int32(lvl) + h.Type = int32(typ) +} diff --git a/internal/socket/cmsghdr_other.go b/internal/socket/cmsghdr_other.go new file mode 100644 index 00000000..8078487c --- /dev/null +++ b/internal/socket/cmsghdr_other.go @@ -0,0 +1,13 @@ +// +build !linux + +package socket + +type cmsghdr struct{} + +const sizeofCmsghdr = 0 + +func (h *cmsghdr) len() int { return 0 } +func (h *cmsghdr) lvl() int { return 0 } +func (h *cmsghdr) typ() int { return 0 } + +func (h *cmsghdr) set(l, lvl, typ int) {} diff --git a/internal/socket/controlmessage.go b/internal/socket/controlmessage.go new file mode 100644 index 00000000..3176e960 --- /dev/null +++ b/internal/socket/controlmessage.go @@ -0,0 +1,118 @@ +package socket + +import ( + "errors" + "unsafe" +) + +func controlHeaderLen() int { + return roundup(sizeofCmsghdr) +} + +func controlMessageLen(dataLen int) int { + return roundup(sizeofCmsghdr) + dataLen +} + +// returns the whole length of control message. +func ControlMessageSpace(dataLen int) int { + return roundup(sizeofCmsghdr) + roundup(dataLen) +} + +// A ControlMessage represents the head message in a stream of control +// messages. +// +// A control message comprises of a header, data and a few padding +// fields to conform to the interface to the kernel. +// +// See RFC 3542 for further information. +type ControlMessage []byte + +// Data returns the data field of the control message at the head. +func (m ControlMessage) Data(dataLen int) []byte { + l := controlHeaderLen() + if len(m) < l || len(m) < l+dataLen { + return nil + } + return m[l : l+dataLen] +} + +// ParseHeader parses and returns the header fields of the control +// message at the head. +func (m ControlMessage) ParseHeader() (lvl, typ, dataLen int, err error) { + l := controlHeaderLen() + if len(m) < l { + return 0, 0, 0, errors.New("short message") + } + h := (*cmsghdr)(unsafe.Pointer(&m[0])) + return h.lvl(), h.typ(), int(uint64(h.len()) - uint64(l)), nil +} + +// Next returns the control message at the next. +func (m ControlMessage) Next(dataLen int) ControlMessage { + l := ControlMessageSpace(dataLen) + if len(m) < l { + return nil + } + return m[l:] +} + +// MarshalHeader marshals the header fields of the control message at +// the head. +func (m ControlMessage) MarshalHeader(lvl, typ, dataLen int) error { + if len(m) < controlHeaderLen() { + return errors.New("short message") + } + h := (*cmsghdr)(unsafe.Pointer(&m[0])) + h.set(controlMessageLen(dataLen), lvl, typ) + return nil +} + +// Marshal marshals the control message at the head, and returns the next +// control message. +func (m ControlMessage) Marshal(lvl, typ int, data []byte) (ControlMessage, error) { + l := len(data) + if len(m) < ControlMessageSpace(l) { + return nil, errors.New("short message") + } + h := (*cmsghdr)(unsafe.Pointer(&m[0])) + h.set(controlMessageLen(l), lvl, typ) + if l > 0 { + copy(m.Data(l), data) + } + return m.Next(l), nil +} + +// Parse parses as a single or multiple control messages. +func (m ControlMessage) Parse() ([]ControlMessage, error) { + var ms []ControlMessage + for len(m) >= controlHeaderLen() { + h := (*cmsghdr)(unsafe.Pointer(&m[0])) + l := h.len() + if l <= 0 { + return nil, errors.New("invalid header length") + } + if uint64(l) < uint64(controlHeaderLen()) { + return nil, errors.New("invalid message length") + } + if uint64(l) > uint64(len(m)) { + return nil, errors.New("short buffer") + } + ms = append(ms, ControlMessage(m[:l])) + ll := l - controlHeaderLen() + if len(m) >= ControlMessageSpace(ll) { + m = m[ControlMessageSpace(ll):] + } else { + m = m[controlMessageLen(ll):] + } + } + return ms, nil +} + +// NewControlMessage returns a new stream of control messages. +func NewControlMessage(dataLen []int) ControlMessage { + var l int + for i := range dataLen { + l += ControlMessageSpace(dataLen[i]) + } + return make([]byte, l) +} diff --git a/internal/socket/controlmessage_test.go b/internal/socket/controlmessage_test.go new file mode 100644 index 00000000..e9fff4d4 --- /dev/null +++ b/internal/socket/controlmessage_test.go @@ -0,0 +1,103 @@ +// +build linux + +package socket + +import ( + "bytes" + "testing" +) + +type mockControl struct { + Level int + Type int + Data []byte +} + +func TestControlMessage(t *testing.T) { + for _, tt := range []struct { + cs []mockControl + }{ + { + []mockControl{ + {Level: 1, Type: 1}, + }, + }, + { + []mockControl{ + {Level: 2, Type: 2, Data: []byte{0xfe}}, + }, + }, + { + []mockControl{ + {Level: 3, Type: 3, Data: []byte{0xfe, 0xff, 0xff, 0xfe}}, + }, + }, + { + []mockControl{ + {Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}}, + }, + }, + { + []mockControl{ + {Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}}, + {Level: 2, Type: 2, Data: []byte{0xfe}}, + }, + }, + } { + var w []byte + var tailPadLen int + mm := NewControlMessage([]int{0}) + for i, c := range tt.cs { + m := NewControlMessage([]int{len(c.Data)}) + l := len(m) - len(mm) + if i == len(tt.cs)-1 && l > len(c.Data) { + tailPadLen = l - len(c.Data) + } + w = append(w, m...) + } + + var err error + ww := make([]byte, len(w)) + copy(ww, w) + m := ControlMessage(ww) + for _, c := range tt.cs { + if err = m.MarshalHeader(c.Level, c.Type, len(c.Data)); err != nil { + t.Fatalf("(%v).MarshalHeader() = %v", tt.cs, err) + } + copy(m.Data(len(c.Data)), c.Data) + m = m.Next(len(c.Data)) + } + m = ControlMessage(w) + for _, c := range tt.cs { + m, err = m.Marshal(c.Level, c.Type, c.Data) + if err != nil { + t.Fatalf("(%v).Marshal() = %v", tt.cs, err) + } + } + if !bytes.Equal(ww, w) { + t.Fatalf("got %#v; want %#v", ww, w) + } + + ws := [][]byte{w} + if tailPadLen > 0 { + // Test a message with no tail padding. + nopad := w[:len(w)-tailPadLen] + ws = append(ws, [][]byte{nopad}...) + } + for _, w := range ws { + ms, err := ControlMessage(w).Parse() + if err != nil { + t.Fatalf("(%v).Parse() = %v", tt.cs, err) + } + for i, m := range ms { + lvl, typ, dataLen, err := m.ParseHeader() + if err != nil { + t.Fatalf("(%v).ParseHeader() = %v", tt.cs, err) + } + if lvl != tt.cs[i].Level || typ != tt.cs[i].Type || dataLen != len(tt.cs[i].Data) { + t.Fatalf("%v: got %d, %d, %d; want %d, %d, %d", tt.cs[i], lvl, typ, dataLen, tt.cs[i].Level, tt.cs[i].Type, len(tt.cs[i].Data)) + } + } + } + } +} diff --git a/internal/socket/socket.go b/internal/socket/socket.go new file mode 100644 index 00000000..edb58e29 --- /dev/null +++ b/internal/socket/socket.go @@ -0,0 +1,4 @@ +// Package socket contains ControlMessage parsing code from +// golang.org/x/net/internal/socket. Instead of supporting all possible +// architectures, we're only supporting linux 32/64 bit. +package socket diff --git a/internal/socket/sys.go b/internal/socket/sys.go new file mode 100644 index 00000000..2f3f5cfe --- /dev/null +++ b/internal/socket/sys.go @@ -0,0 +1,14 @@ +package socket + +import "unsafe" + +var ( + kernelAlign = func() int { + var p uintptr + return int(unsafe.Sizeof(p)) + }() +) + +func roundup(l int) int { + return (l + kernelAlign - 1) & ^(kernelAlign - 1) +} diff --git a/udp.go b/udp.go index af111b9a..12a20967 100644 --- a/udp.go +++ b/udp.go @@ -27,8 +27,19 @@ func ReadFromSessionUDP(conn *net.UDPConn, b []byte) (int, *SessionUDP, error) { return n, &SessionUDP{raddr, oob[:oobn]}, err } -// WriteToSessionUDP acts just like net.UDPConn.WritetTo(), but uses a *SessionUDP instead of a net.Addr. +// WriteToSessionUDP acts just like net.UDPConn.WriteTo(), but uses a *SessionUDP instead of a net.Addr. func WriteToSessionUDP(conn *net.UDPConn, b []byte, session *SessionUDP) (int, error) { - n, _, err := conn.WriteMsgUDP(b, session.context, session.raddr) + oob := correctSource(session.context) + n, _, err := conn.WriteMsgUDP(b, oob, session.raddr) return n, err } + +// correctSource takes oob data and returns new oob data with the Src equal to the Dst +func correctSource(oob []byte) []byte { + dst, err := parseUDPSocketDst(oob) + // If the destination could not be determined, ignore. + if err != nil || dst == nil { + return nil + } + return marshalUDPSocketSrc(dst) +} diff --git a/udp_linux.go b/udp_linux.go index 033df423..13747ed3 100644 --- a/udp_linux.go +++ b/udp_linux.go @@ -13,8 +13,34 @@ package dns import ( "net" "syscall" + "unsafe" + + "github.com/miekg/dns/internal/socket" ) +const ( + sizeofInet6Pktinfo = 0x14 + sizeofInetPktinfo = 0xc + protocolIP = 0 + protocolIPv6 = 41 +) + +type inetPktinfo struct { + Ifindex int32 + Spec_dst [4]byte /* in_addr */ + Addr [4]byte /* in_addr */ +} + +type inet6Pktinfo struct { + Addr [16]byte /* in6_addr */ + Ifindex int32 +} + +type inetControlMessage struct { + Src net.IP // source address, specifying only + Dst net.IP // destination address, receiving only +} + // setUDPSocketOptions sets the UDP socket options. // This function is implemented on a per platform basis. See udp_*.go for more details func setUDPSocketOptions(conn *net.UDPConn) error { @@ -103,3 +129,92 @@ func getUDPSocketName(conn *net.UDPConn) (syscall.Sockaddr, error) { defer file.Close() return syscall.Getsockname(int(file.Fd())) } + +// marshalInetPacketInfo marshals a ipv4 control message, returning +// the byte slice for the next marshal, if any +func marshalInetPacketInfo(b []byte, cm *inetControlMessage) []byte { + m := socket.ControlMessage(b) + m.MarshalHeader(protocolIP, syscall.IP_PKTINFO, sizeofInetPktinfo) + if cm != nil { + pi := (*inetPktinfo)(unsafe.Pointer(&m.Data(sizeofInetPktinfo)[0])) + if ip := cm.Src.To4(); ip != nil { + copy(pi.Spec_dst[:], ip) + } + } + return m.Next(sizeofInetPktinfo) +} + +// marshalInet6PacketInfo marshals a ipv6 control message, returning +// the byte slice for the next marshal, if any +func marshalInet6PacketInfo(b []byte, cm *inetControlMessage) []byte { + m := socket.ControlMessage(b) + m.MarshalHeader(protocolIPv6, syscall.IPV6_PKTINFO, sizeofInet6Pktinfo) + if cm != nil { + pi := (*inet6Pktinfo)(unsafe.Pointer(&m.Data(sizeofInet6Pktinfo)[0])) + if ip := cm.Src.To16(); ip != nil && ip.To4() == nil { + copy(pi.Addr[:], ip) + } + } + return m.Next(sizeofInet6Pktinfo) +} + +func parseInetPacketInfo(cm *inetControlMessage, b []byte) { + pi := (*inetPktinfo)(unsafe.Pointer(&b[0])) + if len(cm.Dst) < net.IPv4len { + cm.Dst = make(net.IP, net.IPv4len) + } + copy(cm.Dst, pi.Addr[:]) +} + +func parseInet6PacketInfo(cm *inetControlMessage, b []byte) { + pi := (*inet6Pktinfo)(unsafe.Pointer(&b[0])) + if len(cm.Dst) < net.IPv6len { + cm.Dst = make(net.IP, net.IPv6len) + } + copy(cm.Dst, pi.Addr[:]) +} + +// parseUDPSocketDst takes out-of-band data from ReadMsgUDP and parses it for +// the Dst address +func parseUDPSocketDst(oob []byte) (net.IP, error) { + cm := new(inetControlMessage) + ms, err := socket.ControlMessage(oob).Parse() + if err != nil { + return nil, err + } + for _, m := range ms { + lvl, typ, l, err := m.ParseHeader() + if err != nil { + return nil, err + } + if lvl == protocolIPv6 { // IPv6 + if typ == syscall.IPV6_PKTINFO && l >= sizeofInet6Pktinfo { + parseInet6PacketInfo(cm, m.Data(l)) + } + } else if lvl == protocolIP { // IPv4 + if typ == syscall.IP_PKTINFO && l >= sizeofInetPktinfo { + parseInetPacketInfo(cm, m.Data(l)) + } + } + } + return cm.Dst, nil +} + +// marshalUDPSocketSrc takes the given src address and returns out-of-band data +// to give to WriteMsgUDP +func marshalUDPSocketSrc(src net.IP) []byte { + var oob []byte + // If the dst is definitely an ipv6, then use ipv6 control to respond + // otherwise use ipv4 because the ipv6 marshal ignores ipv4 messages. + // See marshalInet6PacketInfo + cm := new(inetControlMessage) + cm.Src = src + if src.To4() == nil { + oob = make([]byte, socket.ControlMessageSpace(sizeofInet6Pktinfo)) + marshalInet6PacketInfo(oob, cm) + } else { + oob = make([]byte, socket.ControlMessageSpace(sizeofInetPktinfo)) + marshalInetPacketInfo(oob, cm) + } + return oob +} diff --git a/udp_linux_test.go b/udp_linux_test.go new file mode 100644 index 00000000..14a8acef --- /dev/null +++ b/udp_linux_test.go @@ -0,0 +1,68 @@ +// +build linux,!appengine + +package dns + +import ( + "bytes" + "net" + "testing" +) + +func TestParseUDPSocketDst(t *testing.T) { + // dst is :ffff:100.100.100.100 + oob := []byte{36, 0, 0, 0, 0, 0, 0, 0, 41, 0, 0, 0, 50, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 100, 100, 100, 100, 2, 0, 0, 0} + dst, err := parseUDPSocketDst(oob) + if err != nil { + t.Fatalf("error parsing ipv6 oob: %v", err) + } + dst4 := dst.To4() + if dst4 == nil { + t.Errorf("failed to parse ipv4: %v", dst) + } else if dst4.String() != "100.100.100.100" { + t.Errorf("unexpected ipv4: %v", dst4) + } + + // dst is 2001:db8::1 + oob = []byte{36, 0, 0, 0, 0, 0, 0, 0, 41, 0, 0, 0, 50, 0, 0, 0, 32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0} + dst, err = parseUDPSocketDst(oob) + if err != nil { + t.Fatalf("error parsing ipv6 oob: %v", err) + } + dst6 := dst.To16() + if dst6 == nil { + t.Errorf("failed to parse ipv6: %v", dst) + } else if dst6.String() != "2001:db8::1" { + t.Errorf("unexpected ipv6: %v", dst4) + } + + // dst is 100.100.100.100 but was received on 10.10.10.10 + oob = []byte{28, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 2, 0, 0, 0, 10, 10, 10, 10, 100, 100, 100, 100, 0, 0, 0, 0} + dst, err = parseUDPSocketDst(oob) + if err != nil { + t.Fatalf("error parsing ipv4 oob: %v", err) + } + dst4 = dst.To4() + if dst4 == nil { + t.Errorf("failed to parse ipv4: %v", dst) + } else if dst4.String() != "100.100.100.100" { + t.Errorf("unexpected ipv4: %v", dst4) + } +} + +func TestMarshalUDPSocketSrc(t *testing.T) { + // src is 100.100.100.100 + exoob := []byte{28, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 100, 100, 100, 100, 0, 0, 0, 0, 0, 0, 0, 0} + oob := marshalUDPSocketSrc(net.ParseIP("100.100.100.100")) + if !bytes.Equal(exoob, oob) { + t.Errorf("expected ipv4 oob:\n%v", exoob) + t.Errorf("actual ipv4 oob:\n%v", oob) + } + + // src is 2001:db8::1 + exoob = []byte{36, 0, 0, 0, 0, 0, 0, 0, 41, 0, 0, 0, 50, 0, 0, 0, 32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0} + oob = marshalUDPSocketSrc(net.ParseIP("2001:db8::1")) + if !bytes.Equal(exoob, oob) { + t.Errorf("expected ipv6 oob:\n%v", exoob) + t.Errorf("actual ipv6 oob:\n%v", oob) + } +} diff --git a/udp_other.go b/udp_other.go index 488a282b..531f4ebc 100644 --- a/udp_other.go +++ b/udp_other.go @@ -9,7 +9,9 @@ import ( // These do nothing. See udp_linux.go for an example of how to implement this. // We tried to adhire to some kind of naming scheme. -func setUDPSocketOptions(conn *net.UDPConn) error { return nil } -func setUDPSocketOptions4(conn *net.UDPConn) error { return nil } -func setUDPSocketOptions6(conn *net.UDPConn) error { return nil } -func getUDPSocketOptions6Only(conn *net.UDPConn) (bool, error) { return false, nil } +func setUDPSocketOptions(conn *net.UDPConn) error { return nil } +func setUDPSocketOptions4(conn *net.UDPConn) error { return nil } +func setUDPSocketOptions6(conn *net.UDPConn) error { return nil } +func getUDPSocketOptions6Only(conn *net.UDPConn) (bool, error) { return false, nil } +func parseUDPSocketDst(oob []byte) (net.IP, error) { return nil, nil } +func marshalUDPSocketSrc(src net.IP) []byte { return nil } diff --git a/udp_windows.go b/udp_windows.go index 51e532ac..2ad4ede7 100644 --- a/udp_windows.go +++ b/udp_windows.go @@ -4,10 +4,12 @@ package dns import "net" +// SessionUDP holds the remote address type SessionUDP struct { raddr *net.UDPAddr } +// RemoteAddr returns the remote network address. func (s *SessionUDP) RemoteAddr() net.Addr { return s.raddr } // ReadFromSessionUDP acts just like net.UDPConn.ReadFrom(), but returns a session object instead of a @@ -21,9 +23,8 @@ func ReadFromSessionUDP(conn *net.UDPConn, b []byte) (int, *SessionUDP, error) { return n, session, err } -// WriteToSessionUDP acts just like net.UDPConn.WritetTo(), but uses a *SessionUDP instead of a net.Addr. +// WriteToSessionUDP acts just like net.UDPConn.WriteTo(), but uses a *SessionUDP instead of a net.Addr. func WriteToSessionUDP(conn *net.UDPConn, b []byte, session *SessionUDP) (int, error) { n, err := conn.WriteTo(b, session.raddr) return n, err } -