diff --git a/conn/bind_std.go b/conn/bind_std.go index 69789b3..c701ef8 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -81,11 +81,10 @@ func NewStdNetBind() Bind { type StdNetEndpoint struct { // AddrPort is the endpoint destination. netip.AddrPort - // src is the current sticky source address and interface index, if supported. - src struct { - netip.Addr - ifidx int32 - } + // src is the current sticky source address and interface index, if + // supported. Typically this is a PKTINFO structure from/for control + // messages, see unix.PKTINFO for an example. + src []byte } var ( @@ -104,21 +103,17 @@ func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { } func (e *StdNetEndpoint) ClearSrc() { - e.src.ifidx = 0 - e.src.Addr = netip.Addr{} + if e.src != nil { + // Truncate src, no need to reallocate. + e.src = e.src[:0] + } } func (e *StdNetEndpoint) DstIP() netip.Addr { return e.AddrPort.Addr() } -func (e *StdNetEndpoint) SrcIP() netip.Addr { - return e.src.Addr -} - -func (e *StdNetEndpoint) SrcIfidx() int32 { - return e.src.ifidx -} +// See sticky_default,linux, etc for implementations of SrcIP and SrcIfidx. func (e *StdNetEndpoint) DstToBytes() []byte { b, _ := e.AddrPort.MarshalBinary() @@ -129,10 +124,6 @@ func (e *StdNetEndpoint) DstToString() string { return e.AddrPort.String() } -func (e *StdNetEndpoint) SrcToString() string { - return e.src.Addr.String() -} - func listenNet(network string, port int) (*net.UDPConn, int, error) { conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) if err != nil { diff --git a/conn/sticky_default.go b/conn/sticky_default.go index 05f00ea..1fa8a0c 100644 --- a/conn/sticky_default.go +++ b/conn/sticky_default.go @@ -7,6 +7,20 @@ package conn +import "net/netip" + +func (e *StdNetEndpoint) SrcIP() netip.Addr { + return netip.Addr{} +} + +func (e *StdNetEndpoint) SrcIfidx() int32 { + return 0 +} + +func (e *StdNetEndpoint) SrcToString() string { + return "" +} + // TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but // use alternatively named flags and need ports and require testing. diff --git a/conn/sticky_linux.go b/conn/sticky_linux.go index 274fa38..a30ccc7 100644 --- a/conn/sticky_linux.go +++ b/conn/sticky_linux.go @@ -14,6 +14,37 @@ import ( "golang.org/x/sys/unix" ) +func (e *StdNetEndpoint) SrcIP() netip.Addr { + switch len(e.src) { + case unix.CmsgSpace(unix.SizeofInet4Pktinfo): + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return netip.AddrFrom4(info.Spec_dst) + case unix.CmsgSpace(unix.SizeofInet6Pktinfo): + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + // TODO: set zone. in order to do so we need to check if the address is + // link local, and if it is perform a syscall to turn the ifindex into a + // zone string because netip uses string zones. + return netip.AddrFrom16(info.Addr) + } + return netip.Addr{} +} + +func (e *StdNetEndpoint) SrcIfidx() int32 { + switch len(e.src) { + case unix.CmsgSpace(unix.SizeofInet4Pktinfo): + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return info.Ifindex + case unix.CmsgSpace(unix.SizeofInet6Pktinfo): + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return int32(info.Ifindex) + } + return 0 +} + +func (e *StdNetEndpoint) SrcToString() string { + return e.SrcIP().String() +} + // getSrcFromControl parses the control for PKTINFO and if found updates ep with // the source information found. func getSrcFromControl(control []byte, ep *StdNetEndpoint) { @@ -35,81 +66,43 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) { if hdr.Level == unix.IPPROTO_IP && hdr.Type == unix.IP_PKTINFO { - info := pktInfoFromBuf[unix.Inet4Pktinfo](data) - ep.src.Addr = netip.AddrFrom4(info.Spec_dst) - ep.src.ifidx = info.Ifindex + if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) { + ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + } + ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)] + hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) + copy(ep.src, hdrBuf) + copy(ep.src[unix.CmsgLen(0):], data) return } if hdr.Level == unix.IPPROTO_IPV6 && hdr.Type == unix.IPV6_PKTINFO { - info := pktInfoFromBuf[unix.Inet6Pktinfo](data) - ep.src.Addr = netip.AddrFrom16(info.Addr) - ep.src.ifidx = int32(info.Ifindex) + if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) { + ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) + } + ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)] + + hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) + copy(ep.src, hdrBuf) + copy(ep.src[unix.CmsgLen(0):], data) return } } } -// pktInfoFromBuf returns type T populated from the provided buf via copy(). It -// panics if buf is of insufficient size. -func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) { - size := int(unsafe.Sizeof(t)) - if len(buf) < size { - panic("pktInfoFromBuf: buffer too small") - } - copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf) - return t -} - // setSrcControl sets an IP{V6}_PKTINFO in control based on the source address // and source ifindex found in ep. control's len will be set to 0 in the event // that ep is a default value. func setSrcControl(control *[]byte, ep *StdNetEndpoint) { - *control = (*control)[:cap(*control)] - if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) { - *control = (*control)[:0] + if cap(*control) < len(ep.src) { return } - - if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() { - *control = (*control)[:0] - return - } - - if len(*control) < srcControlSize { - *control = (*control)[:0] - return - } - - hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0])) - if ep.SrcIP().Is4() { - hdr.Level = unix.IPPROTO_IP - hdr.Type = unix.IP_PKTINFO - hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo)) - - info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr])) - info.Ifindex = ep.src.ifidx - if ep.SrcIP().IsValid() { - info.Spec_dst = ep.SrcIP().As4() - } - *control = (*control)[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)] - } else { - hdr.Level = unix.IPPROTO_IPV6 - hdr.Type = unix.IPV6_PKTINFO - hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo)) - - info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr])) - info.Ifindex = uint32(ep.src.ifidx) - if ep.SrcIP().IsValid() { - info.Addr = ep.SrcIP().As16() - } - *control = (*control)[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)] - } - + *control = (*control)[:0] + *control = append(*control, ep.src...) } var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go index 0219ac3..679213a 100644 --- a/conn/sticky_linux_test.go +++ b/conn/sticky_linux_test.go @@ -18,13 +18,47 @@ import ( "golang.org/x/sys/unix" ) +func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) { + var buf []byte + if addr.Is4() { + buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + hdr := unix.Cmsghdr{ + Level: unix.IPPROTO_IP, + Type: unix.IP_PKTINFO, + } + hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo)) + copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) + + info := unix.Inet4Pktinfo{ + Ifindex: ifidx, + Spec_dst: addr.As4(), + } + copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo)) + } else { + buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) + hdr := unix.Cmsghdr{ + Level: unix.IPPROTO_IPV6, + Type: unix.IPV6_PKTINFO, + } + hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo)) + copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) + + info := unix.Inet6Pktinfo{ + Ifindex: uint32(ifidx), + Addr: addr.As16(), + } + copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo)) + } + + ep.src = buf +} + func Test_setSrcControl(t *testing.T) { t.Run("IPv4", func(t *testing.T) { ep := &StdNetEndpoint{ AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"), } - ep.src.Addr = netip.MustParseAddr("127.0.0.1") - ep.src.ifidx = 5 + setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5) control := make([]byte, srcControlSize) @@ -53,8 +87,7 @@ func Test_setSrcControl(t *testing.T) { ep := &StdNetEndpoint{ AddrPort: netip.MustParseAddrPort("[::1]:1234"), } - ep.src.Addr = netip.MustParseAddr("::1") - ep.src.ifidx = 5 + setSrc(ep, netip.MustParseAddr("::1"), 5) control := make([]byte, srcControlSize) @@ -80,7 +113,7 @@ func Test_setSrcControl(t *testing.T) { }) t.Run("ClearOnNoSrc", func(t *testing.T) { - control := make([]byte, srcControlSize) + control := make([]byte, unix.CmsgLen(0)) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = 1 hdr.Type = 2 @@ -96,7 +129,7 @@ func Test_setSrcControl(t *testing.T) { func Test_getSrcFromControl(t *testing.T) { t.Run("IPv4", func(t *testing.T) { - control := make([]byte, srcControlSize) + control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = unix.IPPROTO_IP hdr.Type = unix.IP_PKTINFO @@ -108,15 +141,15 @@ func Test_getSrcFromControl(t *testing.T) { ep := &StdNetEndpoint{} getSrcFromControl(control, ep) - if ep.src.Addr != netip.MustParseAddr("127.0.0.1") { - t.Errorf("unexpected address: %v", ep.src.Addr) + if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { + t.Errorf("unexpected address: %v", ep.SrcIP()) } - if ep.src.ifidx != 5 { - t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) } }) t.Run("IPv6", func(t *testing.T) { - control := make([]byte, srcControlSize) + control := make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = unix.IPPROTO_IPV6 hdr.Type = unix.IPV6_PKTINFO @@ -131,22 +164,21 @@ func Test_getSrcFromControl(t *testing.T) { if ep.SrcIP() != netip.MustParseAddr("::1") { t.Errorf("unexpected address: %v", ep.SrcIP()) } - if ep.src.ifidx != 5 { - t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) } }) t.Run("ClearOnEmpty", func(t *testing.T) { - control := make([]byte, srcControlSize) + var control []byte ep := &StdNetEndpoint{} - ep.src.Addr = netip.MustParseAddr("::1") - ep.src.ifidx = 5 + setSrc(ep, netip.MustParseAddr("::1"), 5) getSrcFromControl(control, ep) if ep.SrcIP().IsValid() { - t.Errorf("unexpected address: %v", ep.src.Addr) + t.Errorf("unexpected address: %v", ep.SrcIP()) } - if ep.src.ifidx != 0 { - t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + if ep.SrcIfidx() != 0 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) } }) t.Run("Multiple", func(t *testing.T) { @@ -154,7 +186,7 @@ func Test_getSrcFromControl(t *testing.T) { zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0])) zeroHdr.SetLen(unix.CmsgLen(0)) - control := make([]byte, srcControlSize) + control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = unix.IPPROTO_IP hdr.Type = unix.IP_PKTINFO @@ -170,11 +202,11 @@ func Test_getSrcFromControl(t *testing.T) { ep := &StdNetEndpoint{} getSrcFromControl(combined, ep) - if ep.src.Addr != netip.MustParseAddr("127.0.0.1") { - t.Errorf("unexpected address: %v", ep.src.Addr) + if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { + t.Errorf("unexpected address: %v", ep.SrcIP()) } - if ep.src.ifidx != 5 { - t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) } }) }