From 032fbabc82455d0649e332bdf83767f134fc9431 Mon Sep 17 00:00:00 2001 From: James Hartig Date: Fri, 29 Sep 2017 05:38:26 -0400 Subject: [PATCH] Correctly set the Source IP to the received Destination IP (#524) Previously, the oob data was just stored and sent to WriteMsgUDP but it ignores the Src field when writing. Instead, now it is setting the Src to the original Dst and handling IPv4 IPs over IPv6 correctly. --- CONTRIBUTORS | 1 + internal/socket/cmsghdr.go | 7 ++ internal/socket/cmsghdr_linux_32bit.go | 20 +++++ internal/socket/cmsghdr_linux_64bit.go | 20 +++++ internal/socket/cmsghdr_other.go | 13 +++ internal/socket/controlmessage.go | 118 +++++++++++++++++++++++++ internal/socket/controlmessage_test.go | 103 +++++++++++++++++++++ internal/socket/socket.go | 4 + internal/socket/sys.go | 14 +++ udp.go | 15 +++- udp_linux.go | 115 ++++++++++++++++++++++++ udp_linux_test.go | 68 ++++++++++++++ udp_other.go | 10 ++- udp_windows.go | 5 +- 14 files changed, 505 insertions(+), 8 deletions(-) create mode 100644 internal/socket/cmsghdr.go create mode 100644 internal/socket/cmsghdr_linux_32bit.go create mode 100644 internal/socket/cmsghdr_linux_64bit.go create mode 100644 internal/socket/cmsghdr_other.go create mode 100644 internal/socket/controlmessage.go create mode 100644 internal/socket/controlmessage_test.go create mode 100644 internal/socket/socket.go create mode 100644 internal/socket/sys.go create mode 100644 udp_linux_test.go diff --git a/CONTRIBUTORS b/CONTRIBUTORS index f77e8a89..5903779d 100644 --- a/CONTRIBUTORS +++ b/CONTRIBUTORS @@ -7,3 +7,4 @@ Marek Majkowski Peter van Dijk Omri Bahumi Alex Sergeyev +James Hartig diff --git a/internal/socket/cmsghdr.go b/internal/socket/cmsghdr.go new file mode 100644 index 00000000..62f2d2f7 --- /dev/null +++ b/internal/socket/cmsghdr.go @@ -0,0 +1,7 @@ +// +build linux + +package socket + +func (h *cmsghdr) len() int { return int(h.Len) } +func (h *cmsghdr) lvl() int { return int(h.Level) } +func (h *cmsghdr) typ() int { return int(h.Type) } diff --git a/internal/socket/cmsghdr_linux_32bit.go b/internal/socket/cmsghdr_linux_32bit.go new file mode 100644 index 00000000..e92e8580 --- /dev/null +++ b/internal/socket/cmsghdr_linux_32bit.go @@ -0,0 +1,20 @@ +// +build arm mips mipsle 386 +// +build linux + +package socket + +type cmsghdr struct { + Len uint32 + Level int32 + Type int32 +} + +const ( + sizeofCmsghdr = 0xc +) + +func (h *cmsghdr) set(l, lvl, typ int) { + h.Len = uint32(l) + h.Level = int32(lvl) + h.Type = int32(typ) +} diff --git a/internal/socket/cmsghdr_linux_64bit.go b/internal/socket/cmsghdr_linux_64bit.go new file mode 100644 index 00000000..ddfc9e09 --- /dev/null +++ b/internal/socket/cmsghdr_linux_64bit.go @@ -0,0 +1,20 @@ +// +build arm64 amd64 ppc64 ppc64le mips64 mips64le s390x +// +build linux + +package socket + +type cmsghdr struct { + Len uint64 + Level int32 + Type int32 +} + +const ( + sizeofCmsghdr = 0x10 +) + +func (h *cmsghdr) set(l, lvl, typ int) { + h.Len = uint64(l) + h.Level = int32(lvl) + h.Type = int32(typ) +} diff --git a/internal/socket/cmsghdr_other.go b/internal/socket/cmsghdr_other.go new file mode 100644 index 00000000..8078487c --- /dev/null +++ b/internal/socket/cmsghdr_other.go @@ -0,0 +1,13 @@ +// +build !linux + +package socket + +type cmsghdr struct{} + +const sizeofCmsghdr = 0 + +func (h *cmsghdr) len() int { return 0 } +func (h *cmsghdr) lvl() int { return 0 } +func (h *cmsghdr) typ() int { return 0 } + +func (h *cmsghdr) set(l, lvl, typ int) {} diff --git a/internal/socket/controlmessage.go b/internal/socket/controlmessage.go new file mode 100644 index 00000000..3176e960 --- /dev/null +++ b/internal/socket/controlmessage.go @@ -0,0 +1,118 @@ +package socket + +import ( + "errors" + "unsafe" +) + +func controlHeaderLen() int { + return roundup(sizeofCmsghdr) +} + +func controlMessageLen(dataLen int) int { + return roundup(sizeofCmsghdr) + dataLen +} + +// returns the whole length of control message. +func ControlMessageSpace(dataLen int) int { + return roundup(sizeofCmsghdr) + roundup(dataLen) +} + +// A ControlMessage represents the head message in a stream of control +// messages. +// +// A control message comprises of a header, data and a few padding +// fields to conform to the interface to the kernel. +// +// See RFC 3542 for further information. +type ControlMessage []byte + +// Data returns the data field of the control message at the head. +func (m ControlMessage) Data(dataLen int) []byte { + l := controlHeaderLen() + if len(m) < l || len(m) < l+dataLen { + return nil + } + return m[l : l+dataLen] +} + +// ParseHeader parses and returns the header fields of the control +// message at the head. +func (m ControlMessage) ParseHeader() (lvl, typ, dataLen int, err error) { + l := controlHeaderLen() + if len(m) < l { + return 0, 0, 0, errors.New("short message") + } + h := (*cmsghdr)(unsafe.Pointer(&m[0])) + return h.lvl(), h.typ(), int(uint64(h.len()) - uint64(l)), nil +} + +// Next returns the control message at the next. +func (m ControlMessage) Next(dataLen int) ControlMessage { + l := ControlMessageSpace(dataLen) + if len(m) < l { + return nil + } + return m[l:] +} + +// MarshalHeader marshals the header fields of the control message at +// the head. +func (m ControlMessage) MarshalHeader(lvl, typ, dataLen int) error { + if len(m) < controlHeaderLen() { + return errors.New("short message") + } + h := (*cmsghdr)(unsafe.Pointer(&m[0])) + h.set(controlMessageLen(dataLen), lvl, typ) + return nil +} + +// Marshal marshals the control message at the head, and returns the next +// control message. +func (m ControlMessage) Marshal(lvl, typ int, data []byte) (ControlMessage, error) { + l := len(data) + if len(m) < ControlMessageSpace(l) { + return nil, errors.New("short message") + } + h := (*cmsghdr)(unsafe.Pointer(&m[0])) + h.set(controlMessageLen(l), lvl, typ) + if l > 0 { + copy(m.Data(l), data) + } + return m.Next(l), nil +} + +// Parse parses as a single or multiple control messages. +func (m ControlMessage) Parse() ([]ControlMessage, error) { + var ms []ControlMessage + for len(m) >= controlHeaderLen() { + h := (*cmsghdr)(unsafe.Pointer(&m[0])) + l := h.len() + if l <= 0 { + return nil, errors.New("invalid header length") + } + if uint64(l) < uint64(controlHeaderLen()) { + return nil, errors.New("invalid message length") + } + if uint64(l) > uint64(len(m)) { + return nil, errors.New("short buffer") + } + ms = append(ms, ControlMessage(m[:l])) + ll := l - controlHeaderLen() + if len(m) >= ControlMessageSpace(ll) { + m = m[ControlMessageSpace(ll):] + } else { + m = m[controlMessageLen(ll):] + } + } + return ms, nil +} + +// NewControlMessage returns a new stream of control messages. +func NewControlMessage(dataLen []int) ControlMessage { + var l int + for i := range dataLen { + l += ControlMessageSpace(dataLen[i]) + } + return make([]byte, l) +} diff --git a/internal/socket/controlmessage_test.go b/internal/socket/controlmessage_test.go new file mode 100644 index 00000000..e9fff4d4 --- /dev/null +++ b/internal/socket/controlmessage_test.go @@ -0,0 +1,103 @@ +// +build linux + +package socket + +import ( + "bytes" + "testing" +) + +type mockControl struct { + Level int + Type int + Data []byte +} + +func TestControlMessage(t *testing.T) { + for _, tt := range []struct { + cs []mockControl + }{ + { + []mockControl{ + {Level: 1, Type: 1}, + }, + }, + { + []mockControl{ + {Level: 2, Type: 2, Data: []byte{0xfe}}, + }, + }, + { + []mockControl{ + {Level: 3, Type: 3, Data: []byte{0xfe, 0xff, 0xff, 0xfe}}, + }, + }, + { + []mockControl{ + {Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}}, + }, + }, + { + []mockControl{ + {Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}}, + {Level: 2, Type: 2, Data: []byte{0xfe}}, + }, + }, + } { + var w []byte + var tailPadLen int + mm := NewControlMessage([]int{0}) + for i, c := range tt.cs { + m := NewControlMessage([]int{len(c.Data)}) + l := len(m) - len(mm) + if i == len(tt.cs)-1 && l > len(c.Data) { + tailPadLen = l - len(c.Data) + } + w = append(w, m...) + } + + var err error + ww := make([]byte, len(w)) + copy(ww, w) + m := ControlMessage(ww) + for _, c := range tt.cs { + if err = m.MarshalHeader(c.Level, c.Type, len(c.Data)); err != nil { + t.Fatalf("(%v).MarshalHeader() = %v", tt.cs, err) + } + copy(m.Data(len(c.Data)), c.Data) + m = m.Next(len(c.Data)) + } + m = ControlMessage(w) + for _, c := range tt.cs { + m, err = m.Marshal(c.Level, c.Type, c.Data) + if err != nil { + t.Fatalf("(%v).Marshal() = %v", tt.cs, err) + } + } + if !bytes.Equal(ww, w) { + t.Fatalf("got %#v; want %#v", ww, w) + } + + ws := [][]byte{w} + if tailPadLen > 0 { + // Test a message with no tail padding. + nopad := w[:len(w)-tailPadLen] + ws = append(ws, [][]byte{nopad}...) + } + for _, w := range ws { + ms, err := ControlMessage(w).Parse() + if err != nil { + t.Fatalf("(%v).Parse() = %v", tt.cs, err) + } + for i, m := range ms { + lvl, typ, dataLen, err := m.ParseHeader() + if err != nil { + t.Fatalf("(%v).ParseHeader() = %v", tt.cs, err) + } + if lvl != tt.cs[i].Level || typ != tt.cs[i].Type || dataLen != len(tt.cs[i].Data) { + t.Fatalf("%v: got %d, %d, %d; want %d, %d, %d", tt.cs[i], lvl, typ, dataLen, tt.cs[i].Level, tt.cs[i].Type, len(tt.cs[i].Data)) + } + } + } + } +} diff --git a/internal/socket/socket.go b/internal/socket/socket.go new file mode 100644 index 00000000..edb58e29 --- /dev/null +++ b/internal/socket/socket.go @@ -0,0 +1,4 @@ +// Package socket contains ControlMessage parsing code from +// golang.org/x/net/internal/socket. Instead of supporting all possible +// architectures, we're only supporting linux 32/64 bit. +package socket diff --git a/internal/socket/sys.go b/internal/socket/sys.go new file mode 100644 index 00000000..2f3f5cfe --- /dev/null +++ b/internal/socket/sys.go @@ -0,0 +1,14 @@ +package socket + +import "unsafe" + +var ( + kernelAlign = func() int { + var p uintptr + return int(unsafe.Sizeof(p)) + }() +) + +func roundup(l int) int { + return (l + kernelAlign - 1) & ^(kernelAlign - 1) +} diff --git a/udp.go b/udp.go index af111b9a..12a20967 100644 --- a/udp.go +++ b/udp.go @@ -27,8 +27,19 @@ func ReadFromSessionUDP(conn *net.UDPConn, b []byte) (int, *SessionUDP, error) { return n, &SessionUDP{raddr, oob[:oobn]}, err } -// WriteToSessionUDP acts just like net.UDPConn.WritetTo(), but uses a *SessionUDP instead of a net.Addr. +// WriteToSessionUDP acts just like net.UDPConn.WriteTo(), but uses a *SessionUDP instead of a net.Addr. func WriteToSessionUDP(conn *net.UDPConn, b []byte, session *SessionUDP) (int, error) { - n, _, err := conn.WriteMsgUDP(b, session.context, session.raddr) + oob := correctSource(session.context) + n, _, err := conn.WriteMsgUDP(b, oob, session.raddr) return n, err } + +// correctSource takes oob data and returns new oob data with the Src equal to the Dst +func correctSource(oob []byte) []byte { + dst, err := parseUDPSocketDst(oob) + // If the destination could not be determined, ignore. + if err != nil || dst == nil { + return nil + } + return marshalUDPSocketSrc(dst) +} diff --git a/udp_linux.go b/udp_linux.go index 033df423..13747ed3 100644 --- a/udp_linux.go +++ b/udp_linux.go @@ -13,8 +13,34 @@ package dns import ( "net" "syscall" + "unsafe" + + "github.com/miekg/dns/internal/socket" ) +const ( + sizeofInet6Pktinfo = 0x14 + sizeofInetPktinfo = 0xc + protocolIP = 0 + protocolIPv6 = 41 +) + +type inetPktinfo struct { + Ifindex int32 + Spec_dst [4]byte /* in_addr */ + Addr [4]byte /* in_addr */ +} + +type inet6Pktinfo struct { + Addr [16]byte /* in6_addr */ + Ifindex int32 +} + +type inetControlMessage struct { + Src net.IP // source address, specifying only + Dst net.IP // destination address, receiving only +} + // setUDPSocketOptions sets the UDP socket options. // This function is implemented on a per platform basis. See udp_*.go for more details func setUDPSocketOptions(conn *net.UDPConn) error { @@ -103,3 +129,92 @@ func getUDPSocketName(conn *net.UDPConn) (syscall.Sockaddr, error) { defer file.Close() return syscall.Getsockname(int(file.Fd())) } + +// marshalInetPacketInfo marshals a ipv4 control message, returning +// the byte slice for the next marshal, if any +func marshalInetPacketInfo(b []byte, cm *inetControlMessage) []byte { + m := socket.ControlMessage(b) + m.MarshalHeader(protocolIP, syscall.IP_PKTINFO, sizeofInetPktinfo) + if cm != nil { + pi := (*inetPktinfo)(unsafe.Pointer(&m.Data(sizeofInetPktinfo)[0])) + if ip := cm.Src.To4(); ip != nil { + copy(pi.Spec_dst[:], ip) + } + } + return m.Next(sizeofInetPktinfo) +} + +// marshalInet6PacketInfo marshals a ipv6 control message, returning +// the byte slice for the next marshal, if any +func marshalInet6PacketInfo(b []byte, cm *inetControlMessage) []byte { + m := socket.ControlMessage(b) + m.MarshalHeader(protocolIPv6, syscall.IPV6_PKTINFO, sizeofInet6Pktinfo) + if cm != nil { + pi := (*inet6Pktinfo)(unsafe.Pointer(&m.Data(sizeofInet6Pktinfo)[0])) + if ip := cm.Src.To16(); ip != nil && ip.To4() == nil { + copy(pi.Addr[:], ip) + } + } + return m.Next(sizeofInet6Pktinfo) +} + +func parseInetPacketInfo(cm *inetControlMessage, b []byte) { + pi := (*inetPktinfo)(unsafe.Pointer(&b[0])) + if len(cm.Dst) < net.IPv4len { + cm.Dst = make(net.IP, net.IPv4len) + } + copy(cm.Dst, pi.Addr[:]) +} + +func parseInet6PacketInfo(cm *inetControlMessage, b []byte) { + pi := (*inet6Pktinfo)(unsafe.Pointer(&b[0])) + if len(cm.Dst) < net.IPv6len { + cm.Dst = make(net.IP, net.IPv6len) + } + copy(cm.Dst, pi.Addr[:]) +} + +// parseUDPSocketDst takes out-of-band data from ReadMsgUDP and parses it for +// the Dst address +func parseUDPSocketDst(oob []byte) (net.IP, error) { + cm := new(inetControlMessage) + ms, err := socket.ControlMessage(oob).Parse() + if err != nil { + return nil, err + } + for _, m := range ms { + lvl, typ, l, err := m.ParseHeader() + if err != nil { + return nil, err + } + if lvl == protocolIPv6 { // IPv6 + if typ == syscall.IPV6_PKTINFO && l >= sizeofInet6Pktinfo { + parseInet6PacketInfo(cm, m.Data(l)) + } + } else if lvl == protocolIP { // IPv4 + if typ == syscall.IP_PKTINFO && l >= sizeofInetPktinfo { + parseInetPacketInfo(cm, m.Data(l)) + } + } + } + return cm.Dst, nil +} + +// marshalUDPSocketSrc takes the given src address and returns out-of-band data +// to give to WriteMsgUDP +func marshalUDPSocketSrc(src net.IP) []byte { + var oob []byte + // If the dst is definitely an ipv6, then use ipv6 control to respond + // otherwise use ipv4 because the ipv6 marshal ignores ipv4 messages. + // See marshalInet6PacketInfo + cm := new(inetControlMessage) + cm.Src = src + if src.To4() == nil { + oob = make([]byte, socket.ControlMessageSpace(sizeofInet6Pktinfo)) + marshalInet6PacketInfo(oob, cm) + } else { + oob = make([]byte, socket.ControlMessageSpace(sizeofInetPktinfo)) + marshalInetPacketInfo(oob, cm) + } + return oob +} diff --git a/udp_linux_test.go b/udp_linux_test.go new file mode 100644 index 00000000..14a8acef --- /dev/null +++ b/udp_linux_test.go @@ -0,0 +1,68 @@ +// +build linux,!appengine + +package dns + +import ( + "bytes" + "net" + "testing" +) + +func TestParseUDPSocketDst(t *testing.T) { + // dst is :ffff:100.100.100.100 + oob := []byte{36, 0, 0, 0, 0, 0, 0, 0, 41, 0, 0, 0, 50, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 100, 100, 100, 100, 2, 0, 0, 0} + dst, err := parseUDPSocketDst(oob) + if err != nil { + t.Fatalf("error parsing ipv6 oob: %v", err) + } + dst4 := dst.To4() + if dst4 == nil { + t.Errorf("failed to parse ipv4: %v", dst) + } else if dst4.String() != "100.100.100.100" { + t.Errorf("unexpected ipv4: %v", dst4) + } + + // dst is 2001:db8::1 + oob = []byte{36, 0, 0, 0, 0, 0, 0, 0, 41, 0, 0, 0, 50, 0, 0, 0, 32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0} + dst, err = parseUDPSocketDst(oob) + if err != nil { + t.Fatalf("error parsing ipv6 oob: %v", err) + } + dst6 := dst.To16() + if dst6 == nil { + t.Errorf("failed to parse ipv6: %v", dst) + } else if dst6.String() != "2001:db8::1" { + t.Errorf("unexpected ipv6: %v", dst4) + } + + // dst is 100.100.100.100 but was received on 10.10.10.10 + oob = []byte{28, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 2, 0, 0, 0, 10, 10, 10, 10, 100, 100, 100, 100, 0, 0, 0, 0} + dst, err = parseUDPSocketDst(oob) + if err != nil { + t.Fatalf("error parsing ipv4 oob: %v", err) + } + dst4 = dst.To4() + if dst4 == nil { + t.Errorf("failed to parse ipv4: %v", dst) + } else if dst4.String() != "100.100.100.100" { + t.Errorf("unexpected ipv4: %v", dst4) + } +} + +func TestMarshalUDPSocketSrc(t *testing.T) { + // src is 100.100.100.100 + exoob := []byte{28, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 100, 100, 100, 100, 0, 0, 0, 0, 0, 0, 0, 0} + oob := marshalUDPSocketSrc(net.ParseIP("100.100.100.100")) + if !bytes.Equal(exoob, oob) { + t.Errorf("expected ipv4 oob:\n%v", exoob) + t.Errorf("actual ipv4 oob:\n%v", oob) + } + + // src is 2001:db8::1 + exoob = []byte{36, 0, 0, 0, 0, 0, 0, 0, 41, 0, 0, 0, 50, 0, 0, 0, 32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0} + oob = marshalUDPSocketSrc(net.ParseIP("2001:db8::1")) + if !bytes.Equal(exoob, oob) { + t.Errorf("expected ipv6 oob:\n%v", exoob) + t.Errorf("actual ipv6 oob:\n%v", oob) + } +} diff --git a/udp_other.go b/udp_other.go index 488a282b..531f4ebc 100644 --- a/udp_other.go +++ b/udp_other.go @@ -9,7 +9,9 @@ import ( // These do nothing. See udp_linux.go for an example of how to implement this. // We tried to adhire to some kind of naming scheme. -func setUDPSocketOptions(conn *net.UDPConn) error { return nil } -func setUDPSocketOptions4(conn *net.UDPConn) error { return nil } -func setUDPSocketOptions6(conn *net.UDPConn) error { return nil } -func getUDPSocketOptions6Only(conn *net.UDPConn) (bool, error) { return false, nil } +func setUDPSocketOptions(conn *net.UDPConn) error { return nil } +func setUDPSocketOptions4(conn *net.UDPConn) error { return nil } +func setUDPSocketOptions6(conn *net.UDPConn) error { return nil } +func getUDPSocketOptions6Only(conn *net.UDPConn) (bool, error) { return false, nil } +func parseUDPSocketDst(oob []byte) (net.IP, error) { return nil, nil } +func marshalUDPSocketSrc(src net.IP) []byte { return nil } diff --git a/udp_windows.go b/udp_windows.go index 51e532ac..2ad4ede7 100644 --- a/udp_windows.go +++ b/udp_windows.go @@ -4,10 +4,12 @@ package dns import "net" +// SessionUDP holds the remote address type SessionUDP struct { raddr *net.UDPAddr } +// RemoteAddr returns the remote network address. func (s *SessionUDP) RemoteAddr() net.Addr { return s.raddr } // ReadFromSessionUDP acts just like net.UDPConn.ReadFrom(), but returns a session object instead of a @@ -21,9 +23,8 @@ func ReadFromSessionUDP(conn *net.UDPConn, b []byte) (int, *SessionUDP, error) { return n, session, err } -// WriteToSessionUDP acts just like net.UDPConn.WritetTo(), but uses a *SessionUDP instead of a net.Addr. +// WriteToSessionUDP acts just like net.UDPConn.WriteTo(), but uses a *SessionUDP instead of a net.Addr. func WriteToSessionUDP(conn *net.UDPConn, b []byte, session *SessionUDP) (int, error) { n, err := conn.WriteTo(b, session.raddr) return n, err } -