Correctly set the Source IP to the received Destination IP (#524)
Previously, the oob data was just stored and sent to WriteMsgUDP but it ignores the Src field when writing. Instead, now it is setting the Src to the original Dst and handling IPv4 IPs over IPv6 correctly.master
parent
eccf8bbe83
commit
032fbabc82
@ -0,0 +1,7 @@
|
||||
// +build linux
|
||||
|
||||
package socket
|
||||
|
||||
func (h *cmsghdr) len() int { return int(h.Len) }
|
||||
func (h *cmsghdr) lvl() int { return int(h.Level) }
|
||||
func (h *cmsghdr) typ() int { return int(h.Type) }
|
@ -0,0 +1,20 @@
|
||||
// +build arm mips mipsle 386
|
||||
// +build linux
|
||||
|
||||
package socket
|
||||
|
||||
type cmsghdr struct {
|
||||
Len uint32
|
||||
Level int32
|
||||
Type int32
|
||||
}
|
||||
|
||||
const (
|
||||
sizeofCmsghdr = 0xc
|
||||
)
|
||||
|
||||
func (h *cmsghdr) set(l, lvl, typ int) {
|
||||
h.Len = uint32(l)
|
||||
h.Level = int32(lvl)
|
||||
h.Type = int32(typ)
|
||||
}
|
@ -0,0 +1,20 @@
|
||||
// +build arm64 amd64 ppc64 ppc64le mips64 mips64le s390x
|
||||
// +build linux
|
||||
|
||||
package socket
|
||||
|
||||
type cmsghdr struct {
|
||||
Len uint64
|
||||
Level int32
|
||||
Type int32
|
||||
}
|
||||
|
||||
const (
|
||||
sizeofCmsghdr = 0x10
|
||||
)
|
||||
|
||||
func (h *cmsghdr) set(l, lvl, typ int) {
|
||||
h.Len = uint64(l)
|
||||
h.Level = int32(lvl)
|
||||
h.Type = int32(typ)
|
||||
}
|
@ -0,0 +1,13 @@
|
||||
// +build !linux
|
||||
|
||||
package socket
|
||||
|
||||
type cmsghdr struct{}
|
||||
|
||||
const sizeofCmsghdr = 0
|
||||
|
||||
func (h *cmsghdr) len() int { return 0 }
|
||||
func (h *cmsghdr) lvl() int { return 0 }
|
||||
func (h *cmsghdr) typ() int { return 0 }
|
||||
|
||||
func (h *cmsghdr) set(l, lvl, typ int) {}
|
@ -0,0 +1,118 @@
|
||||
package socket
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func controlHeaderLen() int {
|
||||
return roundup(sizeofCmsghdr)
|
||||
}
|
||||
|
||||
func controlMessageLen(dataLen int) int {
|
||||
return roundup(sizeofCmsghdr) + dataLen
|
||||
}
|
||||
|
||||
// returns the whole length of control message.
|
||||
func ControlMessageSpace(dataLen int) int {
|
||||
return roundup(sizeofCmsghdr) + roundup(dataLen)
|
||||
}
|
||||
|
||||
// A ControlMessage represents the head message in a stream of control
|
||||
// messages.
|
||||
//
|
||||
// A control message comprises of a header, data and a few padding
|
||||
// fields to conform to the interface to the kernel.
|
||||
//
|
||||
// See RFC 3542 for further information.
|
||||
type ControlMessage []byte
|
||||
|
||||
// Data returns the data field of the control message at the head.
|
||||
func (m ControlMessage) Data(dataLen int) []byte {
|
||||
l := controlHeaderLen()
|
||||
if len(m) < l || len(m) < l+dataLen {
|
||||
return nil
|
||||
}
|
||||
return m[l : l+dataLen]
|
||||
}
|
||||
|
||||
// ParseHeader parses and returns the header fields of the control
|
||||
// message at the head.
|
||||
func (m ControlMessage) ParseHeader() (lvl, typ, dataLen int, err error) {
|
||||
l := controlHeaderLen()
|
||||
if len(m) < l {
|
||||
return 0, 0, 0, errors.New("short message")
|
||||
}
|
||||
h := (*cmsghdr)(unsafe.Pointer(&m[0]))
|
||||
return h.lvl(), h.typ(), int(uint64(h.len()) - uint64(l)), nil
|
||||
}
|
||||
|
||||
// Next returns the control message at the next.
|
||||
func (m ControlMessage) Next(dataLen int) ControlMessage {
|
||||
l := ControlMessageSpace(dataLen)
|
||||
if len(m) < l {
|
||||
return nil
|
||||
}
|
||||
return m[l:]
|
||||
}
|
||||
|
||||
// MarshalHeader marshals the header fields of the control message at
|
||||
// the head.
|
||||
func (m ControlMessage) MarshalHeader(lvl, typ, dataLen int) error {
|
||||
if len(m) < controlHeaderLen() {
|
||||
return errors.New("short message")
|
||||
}
|
||||
h := (*cmsghdr)(unsafe.Pointer(&m[0]))
|
||||
h.set(controlMessageLen(dataLen), lvl, typ)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Marshal marshals the control message at the head, and returns the next
|
||||
// control message.
|
||||
func (m ControlMessage) Marshal(lvl, typ int, data []byte) (ControlMessage, error) {
|
||||
l := len(data)
|
||||
if len(m) < ControlMessageSpace(l) {
|
||||
return nil, errors.New("short message")
|
||||
}
|
||||
h := (*cmsghdr)(unsafe.Pointer(&m[0]))
|
||||
h.set(controlMessageLen(l), lvl, typ)
|
||||
if l > 0 {
|
||||
copy(m.Data(l), data)
|
||||
}
|
||||
return m.Next(l), nil
|
||||
}
|
||||
|
||||
// Parse parses as a single or multiple control messages.
|
||||
func (m ControlMessage) Parse() ([]ControlMessage, error) {
|
||||
var ms []ControlMessage
|
||||
for len(m) >= controlHeaderLen() {
|
||||
h := (*cmsghdr)(unsafe.Pointer(&m[0]))
|
||||
l := h.len()
|
||||
if l <= 0 {
|
||||
return nil, errors.New("invalid header length")
|
||||
}
|
||||
if uint64(l) < uint64(controlHeaderLen()) {
|
||||
return nil, errors.New("invalid message length")
|
||||
}
|
||||
if uint64(l) > uint64(len(m)) {
|
||||
return nil, errors.New("short buffer")
|
||||
}
|
||||
ms = append(ms, ControlMessage(m[:l]))
|
||||
ll := l - controlHeaderLen()
|
||||
if len(m) >= ControlMessageSpace(ll) {
|
||||
m = m[ControlMessageSpace(ll):]
|
||||
} else {
|
||||
m = m[controlMessageLen(ll):]
|
||||
}
|
||||
}
|
||||
return ms, nil
|
||||
}
|
||||
|
||||
// NewControlMessage returns a new stream of control messages.
|
||||
func NewControlMessage(dataLen []int) ControlMessage {
|
||||
var l int
|
||||
for i := range dataLen {
|
||||
l += ControlMessageSpace(dataLen[i])
|
||||
}
|
||||
return make([]byte, l)
|
||||
}
|
@ -0,0 +1,103 @@
|
||||
// +build linux
|
||||
|
||||
package socket
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type mockControl struct {
|
||||
Level int
|
||||
Type int
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func TestControlMessage(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
cs []mockControl
|
||||
}{
|
||||
{
|
||||
[]mockControl{
|
||||
{Level: 1, Type: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
[]mockControl{
|
||||
{Level: 2, Type: 2, Data: []byte{0xfe}},
|
||||
},
|
||||
},
|
||||
{
|
||||
[]mockControl{
|
||||
{Level: 3, Type: 3, Data: []byte{0xfe, 0xff, 0xff, 0xfe}},
|
||||
},
|
||||
},
|
||||
{
|
||||
[]mockControl{
|
||||
{Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}},
|
||||
},
|
||||
},
|
||||
{
|
||||
[]mockControl{
|
||||
{Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}},
|
||||
{Level: 2, Type: 2, Data: []byte{0xfe}},
|
||||
},
|
||||
},
|
||||
} {
|
||||
var w []byte
|
||||
var tailPadLen int
|
||||
mm := NewControlMessage([]int{0})
|
||||
for i, c := range tt.cs {
|
||||
m := NewControlMessage([]int{len(c.Data)})
|
||||
l := len(m) - len(mm)
|
||||
if i == len(tt.cs)-1 && l > len(c.Data) {
|
||||
tailPadLen = l - len(c.Data)
|
||||
}
|
||||
w = append(w, m...)
|
||||
}
|
||||
|
||||
var err error
|
||||
ww := make([]byte, len(w))
|
||||
copy(ww, w)
|
||||
m := ControlMessage(ww)
|
||||
for _, c := range tt.cs {
|
||||
if err = m.MarshalHeader(c.Level, c.Type, len(c.Data)); err != nil {
|
||||
t.Fatalf("(%v).MarshalHeader() = %v", tt.cs, err)
|
||||
}
|
||||
copy(m.Data(len(c.Data)), c.Data)
|
||||
m = m.Next(len(c.Data))
|
||||
}
|
||||
m = ControlMessage(w)
|
||||
for _, c := range tt.cs {
|
||||
m, err = m.Marshal(c.Level, c.Type, c.Data)
|
||||
if err != nil {
|
||||
t.Fatalf("(%v).Marshal() = %v", tt.cs, err)
|
||||
}
|
||||
}
|
||||
if !bytes.Equal(ww, w) {
|
||||
t.Fatalf("got %#v; want %#v", ww, w)
|
||||
}
|
||||
|
||||
ws := [][]byte{w}
|
||||
if tailPadLen > 0 {
|
||||
// Test a message with no tail padding.
|
||||
nopad := w[:len(w)-tailPadLen]
|
||||
ws = append(ws, [][]byte{nopad}...)
|
||||
}
|
||||
for _, w := range ws {
|
||||
ms, err := ControlMessage(w).Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("(%v).Parse() = %v", tt.cs, err)
|
||||
}
|
||||
for i, m := range ms {
|
||||
lvl, typ, dataLen, err := m.ParseHeader()
|
||||
if err != nil {
|
||||
t.Fatalf("(%v).ParseHeader() = %v", tt.cs, err)
|
||||
}
|
||||
if lvl != tt.cs[i].Level || typ != tt.cs[i].Type || dataLen != len(tt.cs[i].Data) {
|
||||
t.Fatalf("%v: got %d, %d, %d; want %d, %d, %d", tt.cs[i], lvl, typ, dataLen, tt.cs[i].Level, tt.cs[i].Type, len(tt.cs[i].Data))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
// Package socket contains ControlMessage parsing code from
|
||||
// golang.org/x/net/internal/socket. Instead of supporting all possible
|
||||
// architectures, we're only supporting linux 32/64 bit.
|
||||
package socket
|
@ -0,0 +1,14 @@
|
||||
package socket
|
||||
|
||||
import "unsafe"
|
||||
|
||||
var (
|
||||
kernelAlign = func() int {
|
||||
var p uintptr
|
||||
return int(unsafe.Sizeof(p))
|
||||
}()
|
||||
)
|
||||
|
||||
func roundup(l int) int {
|
||||
return (l + kernelAlign - 1) & ^(kernelAlign - 1)
|
||||
}
|
@ -0,0 +1,68 @@
|
||||
// +build linux,!appengine
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseUDPSocketDst(t *testing.T) {
|
||||
// dst is :ffff:100.100.100.100
|
||||
oob := []byte{36, 0, 0, 0, 0, 0, 0, 0, 41, 0, 0, 0, 50, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 100, 100, 100, 100, 2, 0, 0, 0}
|
||||
dst, err := parseUDPSocketDst(oob)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing ipv6 oob: %v", err)
|
||||
}
|
||||
dst4 := dst.To4()
|
||||
if dst4 == nil {
|
||||
t.Errorf("failed to parse ipv4: %v", dst)
|
||||
} else if dst4.String() != "100.100.100.100" {
|
||||
t.Errorf("unexpected ipv4: %v", dst4)
|
||||
}
|
||||
|
||||
// dst is 2001:db8::1
|
||||
oob = []byte{36, 0, 0, 0, 0, 0, 0, 0, 41, 0, 0, 0, 50, 0, 0, 0, 32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0}
|
||||
dst, err = parseUDPSocketDst(oob)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing ipv6 oob: %v", err)
|
||||
}
|
||||
dst6 := dst.To16()
|
||||
if dst6 == nil {
|
||||
t.Errorf("failed to parse ipv6: %v", dst)
|
||||
} else if dst6.String() != "2001:db8::1" {
|
||||
t.Errorf("unexpected ipv6: %v", dst4)
|
||||
}
|
||||
|
||||
// dst is 100.100.100.100 but was received on 10.10.10.10
|
||||
oob = []byte{28, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 2, 0, 0, 0, 10, 10, 10, 10, 100, 100, 100, 100, 0, 0, 0, 0}
|
||||
dst, err = parseUDPSocketDst(oob)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing ipv4 oob: %v", err)
|
||||
}
|
||||
dst4 = dst.To4()
|
||||
if dst4 == nil {
|
||||
t.Errorf("failed to parse ipv4: %v", dst)
|
||||
} else if dst4.String() != "100.100.100.100" {
|
||||
t.Errorf("unexpected ipv4: %v", dst4)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalUDPSocketSrc(t *testing.T) {
|
||||
// src is 100.100.100.100
|
||||
exoob := []byte{28, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 100, 100, 100, 100, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
oob := marshalUDPSocketSrc(net.ParseIP("100.100.100.100"))
|
||||
if !bytes.Equal(exoob, oob) {
|
||||
t.Errorf("expected ipv4 oob:\n%v", exoob)
|
||||
t.Errorf("actual ipv4 oob:\n%v", oob)
|
||||
}
|
||||
|
||||
// src is 2001:db8::1
|
||||
exoob = []byte{36, 0, 0, 0, 0, 0, 0, 0, 41, 0, 0, 0, 50, 0, 0, 0, 32, 1, 13, 184, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
oob = marshalUDPSocketSrc(net.ParseIP("2001:db8::1"))
|
||||
if !bytes.Equal(exoob, oob) {
|
||||
t.Errorf("expected ipv6 oob:\n%v", exoob)
|
||||
t.Errorf("actual ipv6 oob:\n%v", oob)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue