diff --git a/conn/bind_linux.go b/conn/bind_linux.go deleted file mode 100644 index b6bc0dc..0000000 --- a/conn/bind_linux.go +++ /dev/null @@ -1,587 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - */ - -package conn - -import ( - "errors" - "net" - "net/netip" - "strconv" - "sync" - "syscall" - "unsafe" - - "golang.org/x/sys/unix" -) - -type ipv4Source struct { - Src [4]byte - Ifindex int32 -} - -type ipv6Source struct { - src [16]byte - // ifindex belongs in dst.ZoneId -} - -type LinuxSocketEndpoint struct { - mu sync.Mutex - dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte - src [unsafe.Sizeof(ipv6Source{})]byte - isV6 bool -} - -func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source { return endpoint.src4() } -func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() } -func (endpoint *LinuxSocketEndpoint) IsV6() bool { return endpoint.isV6 } - -func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source { - return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0])) -} - -func (endpoint *LinuxSocketEndpoint) src6() *ipv6Source { - return (*ipv6Source)(unsafe.Pointer(&endpoint.src[0])) -} - -func (endpoint *LinuxSocketEndpoint) dst4() *unix.SockaddrInet4 { - return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0])) -} - -func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 { - return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0])) -} - -// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux. -type LinuxSocketBind struct { - // mu guards sock4 and sock6 and the associated fds. - // As long as someone holds mu (read or write), the associated fds are valid. - mu sync.RWMutex - sock4 int - sock6 int -} - -func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} } -func NewDefaultBind() Bind { return NewLinuxSocketBind() } - -var ( - _ Endpoint = (*LinuxSocketEndpoint)(nil) - _ Bind = (*LinuxSocketBind)(nil) -) - -func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) { - var end LinuxSocketEndpoint - e, err := netip.ParseAddrPort(s) - if err != nil { - return nil, err - } - - if e.Addr().Is4() { - dst := end.dst4() - end.isV6 = false - dst.Port = int(e.Port()) - dst.Addr = e.Addr().As4() - end.ClearSrc() - return &end, nil - } - - if e.Addr().Is6() { - zone, err := zoneToUint32(e.Addr().Zone()) - if err != nil { - return nil, err - } - dst := end.dst6() - end.isV6 = true - dst.Port = int(e.Port()) - dst.ZoneId = zone - dst.Addr = e.Addr().As16() - end.ClearSrc() - return &end, nil - } - - return nil, errors.New("invalid IP address") -} - -func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) { - bind.mu.Lock() - defer bind.mu.Unlock() - - var err error - var newPort uint16 - var tries int - - if bind.sock4 != -1 || bind.sock6 != -1 { - return nil, 0, ErrBindAlreadyOpen - } - - originalPort := port - -again: - port = originalPort - var sock4, sock6 int - // Attempt ipv6 bind, update port if successful. - sock6, newPort, err = create6(port) - if err != nil { - if !errors.Is(err, syscall.EAFNOSUPPORT) { - return nil, 0, err - } - } else { - port = newPort - } - - // Attempt ipv4 bind, update port if successful. 
- sock4, newPort, err = create4(port) - if err != nil { - if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { - unix.Close(sock6) - tries++ - goto again - } - if !errors.Is(err, syscall.EAFNOSUPPORT) { - unix.Close(sock6) - return nil, 0, err - } - } else { - port = newPort - } - - var fns []ReceiveFunc - if sock4 != -1 { - bind.sock4 = sock4 - fns = append(fns, bind.receiveIPv4) - } - if sock6 != -1 { - bind.sock6 = sock6 - fns = append(fns, bind.receiveIPv6) - } - if len(fns) == 0 { - return nil, 0, syscall.EAFNOSUPPORT - } - return fns, port, nil -} - -func (bind *LinuxSocketBind) SetMark(value uint32) error { - bind.mu.RLock() - defer bind.mu.RUnlock() - - if bind.sock6 != -1 { - err := unix.SetsockoptInt( - bind.sock6, - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) - if err != nil { - return err - } - } - - if bind.sock4 != -1 { - err := unix.SetsockoptInt( - bind.sock4, - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) - if err != nil { - return err - } - } - - return nil -} - -func (bind *LinuxSocketBind) BatchSize() int { - return 1 -} - -func (bind *LinuxSocketBind) Close() error { - // Take a readlock to shut down the sockets... - bind.mu.RLock() - if bind.sock6 != -1 { - unix.Shutdown(bind.sock6, unix.SHUT_RDWR) - } - if bind.sock4 != -1 { - unix.Shutdown(bind.sock4, unix.SHUT_RDWR) - } - bind.mu.RUnlock() - // ...and a write lock to close the fd. - // This ensures that no one else is using the fd. - bind.mu.Lock() - defer bind.mu.Unlock() - var err1, err2 error - if bind.sock6 != -1 { - err1 = unix.Close(bind.sock6) - bind.sock6 = -1 - } - if bind.sock4 != -1 { - err2 = unix.Close(bind.sock4) - bind.sock4 = -1 - } - - if err1 != nil { - return err1 - } - return err2 -} - -func (bind *LinuxSocketBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) { - bind.mu.RLock() - defer bind.mu.RUnlock() - if bind.sock4 == -1 { - return 0, net.ErrClosed - } - var end LinuxSocketEndpoint - n, err := receive4(bind.sock4, buffs[0], &end) - if err != nil { - return 0, err - } - eps[0] = &end - sizes[0] = n - return 1, nil -} - -func (bind *LinuxSocketBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) { - bind.mu.RLock() - defer bind.mu.RUnlock() - if bind.sock6 == -1 { - return 0, net.ErrClosed - } - var end LinuxSocketEndpoint - n, err := receive6(bind.sock6, buffs[0], &end) - if err != nil { - return 0, err - } - eps[0] = &end - sizes[0] = n - return 1, nil -} - -func (bind *LinuxSocketBind) Send(buffs [][]byte, end Endpoint) error { - nend, ok := end.(*LinuxSocketEndpoint) - if !ok { - return ErrWrongEndpointType - } - bind.mu.RLock() - defer bind.mu.RUnlock() - if !nend.isV6 { - if bind.sock4 == -1 { - return net.ErrClosed - } - for _, buff := range buffs { - err := send4(bind.sock4, nend, buff) - if err != nil { - return err - } - } - } else { - if bind.sock6 == -1 { - return net.ErrClosed - } - for _, buff := range buffs { - err := send6(bind.sock6, nend, buff) - if err != nil { - return err - } - } - } - return nil -} - -func (end *LinuxSocketEndpoint) SrcIP() netip.Addr { - if !end.isV6 { - return netip.AddrFrom4(end.src4().Src) - } else { - return netip.AddrFrom16(end.src6().src) - } -} - -func (end *LinuxSocketEndpoint) DstIP() netip.Addr { - if !end.isV6 { - return netip.AddrFrom4(end.dst4().Addr) - } else { - return netip.AddrFrom16(end.dst6().Addr) - } -} - -func (end *LinuxSocketEndpoint) DstToBytes() []byte { - if !end.isV6 { - return (*[unsafe.Offsetof(end.dst4().Addr) + 
unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:] - } else { - return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:] - } -} - -func (end *LinuxSocketEndpoint) SrcToString() string { - return end.SrcIP().String() -} - -func (end *LinuxSocketEndpoint) DstToString() string { - var port int - if !end.isV6 { - port = end.dst4().Port - } else { - port = end.dst6().Port - } - return netip.AddrPortFrom(end.DstIP(), uint16(port)).String() -} - -func (end *LinuxSocketEndpoint) ClearDst() { - for i := range end.dst { - end.dst[i] = 0 - } -} - -func (end *LinuxSocketEndpoint) ClearSrc() { - for i := range end.src { - end.src[i] = 0 - } -} - -func zoneToUint32(zone string) (uint32, error) { - if zone == "" { - return 0, nil - } - if intr, err := net.InterfaceByName(zone); err == nil { - return uint32(intr.Index), nil - } - n, err := strconv.ParseUint(zone, 10, 32) - return uint32(n), err -} - -func create4(port uint16) (int, uint16, error) { - // create socket - - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, - 0, - ) - if err != nil { - return -1, 0, err - } - - addr := unix.SockaddrInet4{ - Port: int(port), - } - - // set sockopts and bind - - if err := func() error { - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IP, - unix.IP_PKTINFO, - 1, - ); err != nil { - return err - } - - return unix.Bind(fd, &addr) - }(); err != nil { - unix.Close(fd) - return -1, 0, err - } - - sa, err := unix.Getsockname(fd) - if err == nil { - addr.Port = sa.(*unix.SockaddrInet4).Port - } - - return fd, uint16(addr.Port), err -} - -func create6(port uint16) (int, uint16, error) { - // create socket - - fd, err := unix.Socket( - unix.AF_INET6, - unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, - 0, - ) - if err != nil { - return -1, 0, err - } - - // set sockopts and bind - - addr := unix.SockaddrInet6{ - Port: int(port), - } - - if err := func() error { - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IPV6, - unix.IPV6_RECVPKTINFO, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IPV6, - unix.IPV6_V6ONLY, - 1, - ); err != nil { - return err - } - - return unix.Bind(fd, &addr) - }(); err != nil { - unix.Close(fd) - return -1, 0, err - } - - sa, err := unix.Getsockname(fd) - if err == nil { - addr.Port = sa.(*unix.SockaddrInet6).Port - } - - return fd, uint16(addr.Port), err -} - -func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error { - // construct message header - - cmsg := struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo - }{ - unix.Cmsghdr{ - Level: unix.IPPROTO_IP, - Type: unix.IP_PKTINFO, - Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, - }, - unix.Inet4Pktinfo{ - Spec_dst: end.src4().Src, - Ifindex: end.src4().Ifindex, - }, - } - - end.mu.Lock() - _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) - end.mu.Unlock() - - if err == nil { - return nil - } - - // clear src and retry - - if err == unix.EINVAL { - end.ClearSrc() - cmsg.pktinfo = unix.Inet4Pktinfo{} - end.mu.Lock() - _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) - end.mu.Unlock() - } - - return err -} - -func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error { - // construct message header - - cmsg := struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo - }{ - unix.Cmsghdr{ - Level: unix.IPPROTO_IPV6, - Type: unix.IPV6_PKTINFO, - Len: 
unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr, - }, - unix.Inet6Pktinfo{ - Addr: end.src6().src, - Ifindex: end.dst6().ZoneId, - }, - } - - if cmsg.pktinfo.Addr == [16]byte{} { - cmsg.pktinfo.Ifindex = 0 - } - - end.mu.Lock() - _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) - end.mu.Unlock() - - if err == nil { - return nil - } - - // clear src and retry - - if err == unix.EINVAL { - end.ClearSrc() - cmsg.pktinfo = unix.Inet6Pktinfo{} - end.mu.Lock() - _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) - end.mu.Unlock() - } - - return err -} - -func receive4(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) { - // construct message header - - var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo - } - - size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) - if err != nil { - return 0, err - } - end.isV6 = false - - if newDst4, ok := newDst.(*unix.SockaddrInet4); ok { - *end.dst4() = *newDst4 - } - - // update source cache - - if cmsg.cmsghdr.Level == unix.IPPROTO_IP && - cmsg.cmsghdr.Type == unix.IP_PKTINFO && - cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { - end.src4().Src = cmsg.pktinfo.Spec_dst - end.src4().Ifindex = cmsg.pktinfo.Ifindex - } - - return size, nil -} - -func receive6(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) { - // construct message header - - var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo - } - - size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) - if err != nil { - return 0, err - } - end.isV6 = true - - if newDst6, ok := newDst.(*unix.SockaddrInet6); ok { - *end.dst6() = *newDst6 - } - - // update source cache - - if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && - cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && - cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { - end.src6().src = cmsg.pktinfo.Addr - end.dst6().ZoneId = cmsg.pktinfo.Ifindex - } - - return size, nil -} diff --git a/conn/bind_std.go b/conn/bind_std.go index 98fe23c..a164f56 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -6,32 +6,91 @@ package conn import ( + "context" "errors" "net" "net/netip" + "strconv" "sync" "syscall" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" ) -// StdNetBind is meant to be a temporary solution on platforms for which -// the sticky socket / source caching behavior has not yet been implemented. -// It uses the Go's net package to implement networking. -// See LinuxSocketBind for a proper implementation on the Linux platform. +var ( + _ Bind = (*StdNetBind)(nil) +) + +// StdNetBind implements Bind for all platforms except Windows. 
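+// It relies on x/net's ipv4 and ipv6 PacketConn types for batched reads and
+// writes (recvmmsg and sendmmsg on Linux), and on the per-platform control
+// functions in controlfns*.go for socket options such as PKTINFO.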
type StdNetBind struct { - mu sync.Mutex // protects following fields - ipv4 *net.UDPConn - ipv6 *net.UDPConn - blackhole4 bool - blackhole6 bool + mu sync.Mutex // protects following fields + ipv4 *net.UDPConn + ipv6 *net.UDPConn + blackhole4 bool + blackhole6 bool + ipv4PC *ipv4.PacketConn + ipv6PC *ipv6.PacketConn + batchSize int + udpAddrPool sync.Pool + ipv4MsgsPool sync.Pool + ipv6MsgsPool sync.Pool } -func NewStdNetBind() Bind { return &StdNetBind{} } +func NewStdNetBind() Bind { return NewStdNetBindBatch(DefaultBatchSize) } -type StdNetEndpoint netip.AddrPort +func NewStdNetBindBatch(maxBatchSize int) Bind { + if maxBatchSize == 0 { + maxBatchSize = DefaultBatchSize + } + return &StdNetBind{ + batchSize: maxBatchSize, + + udpAddrPool: sync.Pool{ + New: func() any { + return &net.UDPAddr{ + IP: make([]byte, 16), + } + }, + }, + + ipv4MsgsPool: sync.Pool{ + New: func() any { + msgs := make([]ipv4.Message, maxBatchSize) + for i := range msgs { + msgs[i].Buffers = make(net.Buffers, 1) + msgs[i].OOB = make([]byte, srcControlSize) + } + return &msgs + }, + }, + + ipv6MsgsPool: sync.Pool{ + New: func() any { + msgs := make([]ipv6.Message, maxBatchSize) + for i := range msgs { + msgs[i].Buffers = make(net.Buffers, 1) + msgs[i].OOB = make([]byte, srcControlSize) + } + return &msgs + }, + }, + } +} + +type StdNetEndpoint struct { + // AddrPort is the endpoint destination. + netip.AddrPort + // src is the current sticky source address and interface index, if supported. + src struct { + netip.Addr + ifidx int32 + } +} var ( _ Bind = (*StdNetBind)(nil) - _ Endpoint = StdNetEndpoint{} + _ Endpoint = &StdNetEndpoint{} ) func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { @@ -39,31 +98,38 @@ func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { return asEndpoint(e), err } -func (StdNetEndpoint) ClearSrc() {} - -func (e StdNetEndpoint) DstIP() netip.Addr { - return (netip.AddrPort)(e).Addr() +func (e *StdNetEndpoint) ClearSrc() { + e.src.ifidx = 0 + e.src.Addr = netip.Addr{} } -func (e StdNetEndpoint) SrcIP() netip.Addr { - return netip.Addr{} // not supported +func (e *StdNetEndpoint) DstIP() netip.Addr { + return e.AddrPort.Addr() } -func (e StdNetEndpoint) DstToBytes() []byte { - b, _ := (netip.AddrPort)(e).MarshalBinary() +func (e *StdNetEndpoint) SrcIP() netip.Addr { + return e.src.Addr +} + +func (e *StdNetEndpoint) SrcIfidx() int32 { + return e.src.ifidx +} + +func (e *StdNetEndpoint) DstToBytes() []byte { + b, _ := e.AddrPort.MarshalBinary() return b } -func (e StdNetEndpoint) DstToString() string { - return (netip.AddrPort)(e).String() +func (e *StdNetEndpoint) DstToString() string { + return e.AddrPort.String() } -func (e StdNetEndpoint) SrcToString() string { - return "" +func (e *StdNetEndpoint) SrcToString() string { + return e.src.Addr.String() } func listenNet(network string, port int) (*net.UDPConn, int, error) { - conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) + conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) if err != nil { return nil, 0, err } @@ -77,17 +143,17 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) { if err != nil { return nil, 0, err } - return conn, uaddr.Port, nil + return conn.(*net.UDPConn), uaddr.Port, nil } -func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { - bind.mu.Lock() - defer bind.mu.Unlock() +func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { + s.mu.Lock() + defer s.mu.Unlock() var err error var tries int 
- if bind.ipv4 != nil || bind.ipv6 != nil { + if s.ipv4 != nil || s.ipv6 != nil { return nil, 0, ErrBindAlreadyOpen } @@ -95,104 +161,121 @@ func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { // If uport is 0, we can retry on failure. again: port := int(uport) - var ipv4, ipv6 *net.UDPConn + var v4conn, v6conn *net.UDPConn - ipv4, port, err = listenNet("udp4", port) + v4conn, port, err = listenNet("udp4", port) if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { return nil, 0, err } // Listen on the same port as we're using for ipv4. - ipv6, port, err = listenNet("udp6", port) + v6conn, port, err = listenNet("udp6", port) if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { - ipv4.Close() + v4conn.Close() tries++ goto again } if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { - ipv4.Close() + v4conn.Close() return nil, 0, err } var fns []ReceiveFunc - if ipv4 != nil { - fns = append(fns, bind.makeReceiveIPv4(ipv4)) - bind.ipv4 = ipv4 + if v4conn != nil { + fns = append(fns, s.receiveIPv4) + s.ipv4 = v4conn } - if ipv6 != nil { - fns = append(fns, bind.makeReceiveIPv6(ipv6)) - bind.ipv6 = ipv6 + if v6conn != nil { + fns = append(fns, s.receiveIPv6) + s.ipv6 = v6conn } if len(fns) == 0 { return nil, 0, syscall.EAFNOSUPPORT } + + s.ipv4PC = ipv4.NewPacketConn(s.ipv4) + s.ipv6PC = ipv6.NewPacketConn(s.ipv6) + return fns, uint16(port), nil } -func (bind *StdNetBind) BatchSize() int { - return 1 +func (s *StdNetBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) + defer s.ipv4MsgsPool.Put(msgs) + for i := range buffs { + (*msgs)[i].Buffers[0] = buffs[i] + } + numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0) + if err != nil { + return 0, err + } + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + sizes[i] = msg.N + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := asEndpoint(addrPort) + getSrcFromControl(msg.OOB, ep) + eps[i] = ep + } + return numMsgs, nil } -func (bind *StdNetBind) Close() error { - bind.mu.Lock() - defer bind.mu.Unlock() +func (s *StdNetBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) + defer s.ipv6MsgsPool.Put(msgs) + for i := range buffs { + (*msgs)[i].Buffers[0] = buffs[i] + } + numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0) + if err != nil { + return 0, err + } + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + sizes[i] = msg.N + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := asEndpoint(addrPort) + getSrcFromControl(msg.OOB, ep) + eps[i] = ep + } + return numMsgs, nil +} + +func (s *StdNetBind) BatchSize() int { + return s.batchSize +} + +func (s *StdNetBind) Close() error { + s.mu.Lock() + defer s.mu.Unlock() var err1, err2 error - if bind.ipv4 != nil { - err1 = bind.ipv4.Close() - bind.ipv4 = nil + if s.ipv4 != nil { + err1 = s.ipv4.Close() + s.ipv4 = nil } - if bind.ipv6 != nil { - err2 = bind.ipv6.Close() - bind.ipv6 = nil + if s.ipv6 != nil { + err2 = s.ipv6.Close() + s.ipv6 = nil } - bind.blackhole4 = false - bind.blackhole6 = false + s.blackhole4 = false + s.blackhole6 = false if err1 != nil { return err1 } return err2 } -func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc { - return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0]) - if err == nil { - sizes[0] = size - eps[0] = asEndpoint(endpoint) - return 1, nil - } - return 0, err +func (s 
*StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error { + s.mu.Lock() + blackhole := s.blackhole4 + conn := s.ipv4 + is6 := false + if endpoint.DstIP().Is6() { + blackhole = s.blackhole6 + conn = s.ipv6 + is6 = true } -} - -func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc { - return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0]) - if err == nil { - sizes[0] = size - eps[0] = asEndpoint(endpoint) - return 1, nil - } - return 0, err - } -} - -func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error { - var err error - nend, ok := endpoint.(StdNetEndpoint) - if !ok { - return ErrWrongEndpointType - } - addrPort := netip.AddrPort(nend) - - bind.mu.Lock() - blackhole := bind.blackhole4 - conn := bind.ipv4 - if addrPort.Addr().Is6() { - blackhole = bind.blackhole6 - conn = bind.ipv6 - } - bind.mu.Unlock() + s.mu.Unlock() if blackhole { return nil @@ -200,13 +283,69 @@ func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error { if conn == nil { return syscall.EAFNOSUPPORT } - for _, buff := range buffs { - _, err = conn.WriteToUDPAddrPort(buff, addrPort) - if err != nil { - return err - } + if is6 { + return s.send6(s.ipv6PC, endpoint, buffs) + } else { + return s.send4(s.ipv4PC, endpoint, buffs) } - return nil +} + +func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error { + ua := s.udpAddrPool.Get().(*net.UDPAddr) + as4 := ep.DstIP().As4() + copy(ua.IP, as4[:]) + ua.IP = ua.IP[:4] + ua.Port = int(ep.(*StdNetEndpoint).Port()) + msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) + for i, buff := range buffs { + (*msgs)[i].Buffers[0] = buff + (*msgs)[i].Addr = ua + setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) + } + var ( + n int + err error + start int + ) + for { + n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0) + if err != nil || n == len((*msgs)[start:len(buffs)]) { + break + } + start += n + } + s.udpAddrPool.Put(ua) + s.ipv4MsgsPool.Put(msgs) + return err +} + +func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error { + ua := s.udpAddrPool.Get().(*net.UDPAddr) + as16 := ep.DstIP().As16() + copy(ua.IP, as16[:]) + ua.IP = ua.IP[:16] + ua.Port = int(ep.(*StdNetEndpoint).Port()) + msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) + for i, buff := range buffs { + (*msgs)[i].Buffers[0] = buff + (*msgs)[i].Addr = ua + setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) + } + var ( + n int + err error + start int + ) + for { + n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0) + if err != nil || n == len((*msgs)[start:len(buffs)]) { + break + } + start += n + } + s.udpAddrPool.Put(ua) + s.ipv6MsgsPool.Put(msgs) + return err } // endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint. @@ -214,17 +353,17 @@ func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error { // but Endpoints are immutable, so we can re-use them. var endpointPool = sync.Pool{ New: func() any { - return make(map[netip.AddrPort]Endpoint) + return make(map[netip.AddrPort]*StdNetEndpoint) }, } // asEndpoint returns an Endpoint containing ap. 
-func asEndpoint(ap netip.AddrPort) Endpoint { - m := endpointPool.Get().(map[netip.AddrPort]Endpoint) +func asEndpoint(ap netip.AddrPort) *StdNetEndpoint { + m := endpointPool.Get().(map[netip.AddrPort]*StdNetEndpoint) defer endpointPool.Put(m) e, ok := m[ap] if !ok { - e = Endpoint(StdNetEndpoint(ap)) + e = &StdNetEndpoint{AddrPort: ap} m[ap] = e } return e diff --git a/conn/boundif_android.go b/conn/boundif_android.go index 818e4e6..dd3ca5b 100644 --- a/conn/boundif_android.go +++ b/conn/boundif_android.go @@ -5,8 +5,8 @@ package conn -func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) { - sysconn, err := bind.ipv4.SyscallConn() +func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) { + sysconn, err := s.ipv4.SyscallConn() if err != nil { return -1, err } @@ -19,8 +19,8 @@ func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) { return } -func (bind *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) { - sysconn, err := bind.ipv6.SyscallConn() +func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) { + sysconn, err := s.ipv6.SyscallConn() if err != nil { return -1, err } diff --git a/conn/conn.go b/conn/conn.go index 8c0a827..9cbd0af 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -16,7 +16,7 @@ import ( ) const ( - DefaultBatchSize = 1 // maximum number of packets handled per read and write + DefaultBatchSize = 128 // maximum number of packets handled per read and write ) // A ReceiveFunc receives at least one packet from the network and writes them diff --git a/conn/controlfns.go b/conn/controlfns.go new file mode 100644 index 0000000..fe32871 --- /dev/null +++ b/conn/controlfns.go @@ -0,0 +1,36 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "net" + "syscall" +) + +// controlFn is the callback function signature from net.ListenConfig.Control. +// It is used to apply platform specific configuration to the socket prior to +// bind. +type controlFn func(network, address string, c syscall.RawConn) error + +// controlFns is a list of functions that are called from the listen config +// that can apply socket options. +var controlFns = []controlFn{} + +// listenConfig returns a net.ListenConfig that applies the controlFns to the +// socket prior to bind. This is used to apply socket buffer sizing and packet +// information OOB configuration for sticky sockets. +func listenConfig() *net.ListenConfig { + return &net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + for _, fn := range controlFns { + if err := fn(network, address, c); err != nil { + return err + } + } + return nil + }, + } +} diff --git a/conn/controlfns_linux.go b/conn/controlfns_linux.go new file mode 100644 index 0000000..9e26d95 --- /dev/null +++ b/conn/controlfns_linux.go @@ -0,0 +1,41 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "fmt" + "syscall" + + "golang.org/x/sys/unix" +) + +func init() { + controlFns = append(controlFns, + + // Enable receiving of the packet information (IP_PKTINFO for IPv4, + // IPV6_PKTINFO for IPv6) that is used to implement sticky socket support. 
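+		// Each function below runs via net.ListenConfig.Control, that is,
+		// after the socket has been created but before bind(2), so the options
+		// take effect for every socket opened through listenConfig().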
+		func(network, address string, c syscall.RawConn) error {
+			var err error
+			switch network {
+			case "udp4":
+				c.Control(func(fd uintptr) {
+					err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
+				})
+			case "udp6":
+				c.Control(func(fd uintptr) {
+					err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
+					if err != nil {
+						return
+					}
+					err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
+				})
+			default:
+				err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
+			}
+			return err
+		},
+	)
+}
diff --git a/conn/controlfns_unix.go b/conn/controlfns_unix.go
new file mode 100644
index 0000000..9738c73
--- /dev/null
+++ b/conn/controlfns_unix.go
@@ -0,0 +1,28 @@
+//go:build !windows && !linux && !js
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+	"syscall"
+
+	"golang.org/x/sys/unix"
+)
+
+func init() {
+	controlFns = append(controlFns,
+		func(network, address string, c syscall.RawConn) error {
+			var err error
+			if network == "udp6" {
+				c.Control(func(fd uintptr) {
+					err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
+				})
+			}
+			return err
+		},
+	)
+}
diff --git a/conn/default.go b/conn/default.go
index c7b4a84..b6f761b 100644
--- a/conn/default.go
+++ b/conn/default.go
@@ -1,4 +1,4 @@
-//go:build !linux && !windows
+//go:build !windows
 
 /* SPDX-License-Identifier: MIT
  *
diff --git a/conn/mark_default.go b/conn/mark_default.go
index 9944c38..3102384 100644
--- a/conn/mark_default.go
+++ b/conn/mark_default.go
@@ -7,6 +7,6 @@
 
 package conn
 
-func (bind *StdNetBind) SetMark(mark uint32) error {
+func (s *StdNetBind) SetMark(mark uint32) error {
 	return nil
 }
diff --git a/conn/mark_unix.go b/conn/mark_unix.go
index 5566b28..d9e46ee 100644
--- a/conn/mark_unix.go
+++ b/conn/mark_unix.go
@@ -26,13 +26,13 @@ func init() {
 	}
 }
 
-func (bind *StdNetBind) SetMark(mark uint32) error {
+func (s *StdNetBind) SetMark(mark uint32) error {
 	var operr error
 	if fwmarkIoctl == 0 {
 		return nil
 	}
-	if bind.ipv4 != nil {
-		fd, err := bind.ipv4.SyscallConn()
+	if s.ipv4 != nil {
+		fd, err := s.ipv4.SyscallConn()
 		if err != nil {
 			return err
 		}
@@ -46,8 +46,8 @@ func (bind *StdNetBind) SetMark(mark uint32) error {
 			return err
 		}
 	}
-	if bind.ipv6 != nil {
-		fd, err := bind.ipv6.SyscallConn()
+	if s.ipv6 != nil {
+		fd, err := s.ipv6.SyscallConn()
 		if err != nil {
 			return err
 		}
diff --git a/conn/sticky_default.go b/conn/sticky_default.go
new file mode 100644
index 0000000..3ce9a56
--- /dev/null
+++ b/conn/sticky_default.go
@@ -0,0 +1,28 @@
+//go:build !linux
+// +build !linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+// TODO: macOS, FreeBSD, and other BSDs likely support this feature set, but
+// they use differently named flags; porting and testing are still required.
+
+// getSrcFromControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
+}
+
+// setSrcControl sets an IP{V6}_PKTINFO in control based on the sticky source
+// address and interface index found in ep, if any.
+func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
+}
+
+// srcControlSize is the recommended buffer size for pooling sticky control
+// data.
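+// On platforms without sticky socket support there is no control data to
+// carry, so no buffer space is needed.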
+const srcControlSize = 0
diff --git a/conn/sticky_linux.go b/conn/sticky_linux.go
new file mode 100644
index 0000000..bf17839
--- /dev/null
+++ b/conn/sticky_linux.go
@@ -0,0 +1,114 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+	"net/netip"
+	"unsafe"
+
+	"golang.org/x/sys/unix"
+)
+
+// getSrcFromControl parses the control for PKTINFO and if found updates ep with
+// the source information found.
+func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
+	ep.ClearSrc()
+
+	var (
+		hdr  unix.Cmsghdr
+		data []byte
+		rem  []byte = control
+		err  error
+	)
+
+	for len(rem) > unix.SizeofCmsghdr {
+		// Parse the remainder, not the full control buffer, so the loop
+		// advances past messages that are not PKTINFO.
+		hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
+		if err != nil {
+			return
+		}
+
+		if hdr.Level == unix.IPPROTO_IP &&
+			hdr.Type == unix.IP_PKTINFO {
+
+			info := pktInfoFromBuf[unix.Inet4Pktinfo](data)
+			ep.src.Addr = netip.AddrFrom4(info.Spec_dst)
+			ep.src.ifidx = info.Ifindex
+
+			return
+		}
+
+		if hdr.Level == unix.IPPROTO_IPV6 &&
+			hdr.Type == unix.IPV6_PKTINFO {
+
+			info := pktInfoFromBuf[unix.Inet6Pktinfo](data)
+			ep.src.Addr = netip.AddrFrom16(info.Addr)
+			ep.src.ifidx = int32(info.Ifindex)
+
+			return
+		}
+	}
+}
+
+// pktInfoFromBuf returns type T populated from the provided buf via copy(). It
+// panics if buf is of insufficient size.
+func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) {
+	size := int(unsafe.Sizeof(t))
+	if len(buf) < size {
+		panic("pktInfoFromBuf: buffer too small")
+	}
+	copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf)
+	return t
+}
+
+// setSrcControl sets an IP{V6}_PKTINFO in control based on the sticky source
+// address and interface index found in ep. control is truncated to zero length
+// when ep carries no source information.
+func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
+	*control = (*control)[:cap(*control)]
+	if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
+		*control = (*control)[:0]
+		return
+	}
+
+	if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() {
+		*control = (*control)[:0]
+		return
+	}
+
+	if len(*control) < srcControlSize {
+		*control = (*control)[:0]
+		return
+	}
+
+	hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
+	if ep.SrcIP().Is4() {
+		hdr.Level = unix.IPPROTO_IP
+		hdr.Type = unix.IP_PKTINFO
+		hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
+
+		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
+		info.Ifindex = ep.src.ifidx
+		if ep.SrcIP().IsValid() {
+			info.Spec_dst = ep.SrcIP().As4()
+		}
+	} else {
+		hdr.Level = unix.IPPROTO_IPV6
+		hdr.Type = unix.IPV6_PKTINFO
+		hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
+
+		info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
+		info.Ifindex = uint32(ep.src.ifidx)
+		if ep.SrcIP().IsValid() {
+			info.Addr = ep.SrcIP().As16()
+		}
+	}
+
+	*control = (*control)[:hdr.Len]
+}
+
+var srcControlSize = unix.CmsgLen(unix.SizeofInet6Pktinfo)
diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go
new file mode 100644
index 0000000..a42c89e
--- /dev/null
+++ b/conn/sticky_linux_test.go
@@ -0,0 +1,208 @@
+//go:build linux
+// +build linux
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */ + +package conn + +import ( + "context" + "net" + "net/netip" + "runtime" + "testing" + "unsafe" + + "golang.org/x/sys/unix" +) + +func Test_setSrcControl(t *testing.T) { + t.Run("IPv4", func(t *testing.T) { + ep := &StdNetEndpoint{ + AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"), + } + ep.src.Addr = netip.MustParseAddr("127.0.0.1") + ep.src.ifidx = 5 + + control := make([]byte, srcControlSize) + + setSrcControl(&control, ep) + + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + if hdr.Level != unix.IPPROTO_IP { + t.Errorf("unexpected level: %d", hdr.Level) + } + if hdr.Type != unix.IP_PKTINFO { + t.Errorf("unexpected type: %d", hdr.Type) + } + if hdr.Len != uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) { + t.Errorf("unexpected length: %d", hdr.Len) + } + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 { + t.Errorf("unexpected address: %v", info.Spec_dst) + } + if info.Ifindex != 5 { + t.Errorf("unexpected ifindex: %d", info.Ifindex) + } + }) + + t.Run("IPv6", func(t *testing.T) { + ep := &StdNetEndpoint{ + AddrPort: netip.MustParseAddrPort("[::1]:1234"), + } + ep.src.Addr = netip.MustParseAddr("::1") + ep.src.ifidx = 5 + + control := make([]byte, srcControlSize) + + setSrcControl(&control, ep) + + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + if hdr.Level != unix.IPPROTO_IPV6 { + t.Errorf("unexpected level: %d", hdr.Level) + } + if hdr.Type != unix.IPV6_PKTINFO { + t.Errorf("unexpected type: %d", hdr.Type) + } + if hdr.Len != uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) { + t.Errorf("unexpected length: %d", hdr.Len) + } + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + if info.Addr != ep.SrcIP().As16() { + t.Errorf("unexpected address: %v", info.Addr) + } + if info.Ifindex != 5 { + t.Errorf("unexpected ifindex: %d", info.Ifindex) + } + }) + + t.Run("ClearOnNoSrc", func(t *testing.T) { + control := make([]byte, srcControlSize) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = 1 + hdr.Type = 2 + hdr.Len = 3 + + setSrcControl(&control, &StdNetEndpoint{}) + + if len(control) != 0 { + t.Errorf("unexpected control: %v", control) + } + }) +} + +func Test_getSrcFromControl(t *testing.T) { + t.Run("IPv4", func(t *testing.T) { + control := make([]byte, srcControlSize) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = unix.IPPROTO_IP + hdr.Type = unix.IP_PKTINFO + hdr.Len = uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + info.Spec_dst = [4]byte{127, 0, 0, 1} + info.Ifindex = 5 + + ep := &StdNetEndpoint{} + getSrcFromControl(control, ep) + + if ep.src.Addr != netip.MustParseAddr("127.0.0.1") { + t.Errorf("unexpected address: %v", ep.src.Addr) + } + if ep.src.ifidx != 5 { + t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + } + }) + t.Run("IPv6", func(t *testing.T) { + control := make([]byte, srcControlSize) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = unix.IPPROTO_IPV6 + hdr.Type = unix.IPV6_PKTINFO + hdr.Len = uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + info.Ifindex = 5 + + ep := &StdNetEndpoint{} + getSrcFromControl(control, ep) + + if ep.SrcIP() != netip.MustParseAddr("::1") { + 
t.Errorf("unexpected address: %v", ep.SrcIP()) + } + if ep.src.ifidx != 5 { + t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + } + }) + t.Run("ClearOnEmpty", func(t *testing.T) { + control := make([]byte, srcControlSize) + ep := &StdNetEndpoint{} + ep.src.Addr = netip.MustParseAddr("::1") + ep.src.ifidx = 5 + + getSrcFromControl(control, ep) + if ep.SrcIP().IsValid() { + t.Errorf("unexpected address: %v", ep.src.Addr) + } + if ep.src.ifidx != 0 { + t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + } + }) +} + +func Test_listenConfig(t *testing.T) { + t.Run("IPv4", func(t *testing.T) { + conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0") + if err != nil { + t.Fatal(err) + } + defer conn.Close() + sc, err := conn.(*net.UDPConn).SyscallConn() + if err != nil { + t.Fatal(err) + } + + if runtime.GOOS == "linux" { + var i int + sc.Control(func(fd uintptr) { + i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO) + }) + if err != nil { + t.Fatal(err) + } + if i != 1 { + t.Error("IP_PKTINFO not set!") + } + } else { + t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS) + } + }) + t.Run("IPv6", func(t *testing.T) { + conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0") + if err != nil { + t.Fatal(err) + } + sc, err := conn.(*net.UDPConn).SyscallConn() + if err != nil { + t.Fatal(err) + } + + if runtime.GOOS == "linux" { + var i int + sc.Control(func(fd uintptr) { + i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO) + }) + if err != nil { + t.Fatal(err) + } + if i != 1 { + t.Error("IPV6_PKTINFO not set!") + } + } else { + t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS) + } + }) +} diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go index fc937b3..1158387 100644 --- a/device/queueconstants_android.go +++ b/device/queueconstants_android.go @@ -5,10 +5,12 @@ package device +import "golang.zx2c4.com/wireguard/conn" + /* Reduce memory consumption for Android */ const ( - QueueStagedSize = 128 + QueueStagedSize = conn.DefaultBatchSize QueueOutboundSize = 1024 QueueInboundSize = 1024 QueueHandshakeSize = 1024 diff --git a/device/queueconstants_default.go b/device/queueconstants_default.go index 6b69150..7ed70a1 100644 --- a/device/queueconstants_default.go +++ b/device/queueconstants_default.go @@ -7,8 +7,10 @@ package device +import "golang.zx2c4.com/wireguard/conn" + const ( - QueueStagedSize = 128 + QueueStagedSize = conn.DefaultBatchSize QueueOutboundSize = 1024 QueueInboundSize = 1024 QueueHandshakeSize = 1024 diff --git a/device/sticky_linux.go b/device/sticky_linux.go index 7afdf28..3ce0769 100644 --- a/device/sticky_linux.go +++ b/device/sticky_linux.go @@ -25,7 +25,7 @@ import ( ) func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { - if _, ok := bind.(*conn.LinuxSocketBind); !ok { + if _, ok := bind.(*conn.StdNetBind); !ok { return nil, nil } @@ -112,11 +112,11 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl pePtr.peer.Unlock() break } - if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx { + if uint32(pePtr.peer.endpoint.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx { pePtr.peer.Unlock() break } - pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc() + pePtr.peer.endpoint.(*conn.StdNetEndpoint).ClearSrc() pePtr.peer.Unlock() } attr = attr[attrhdr.Len:] @@ -136,12 +136,12 @@ func (device *Device) 
routineRouteListener(bind conn.Bind, netlinkSock int, netl peer.RUnlock() continue } - nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint) + nativeEP, _ := peer.endpoint.(*conn.StdNetEndpoint) if nativeEP == nil { peer.RUnlock() continue } - if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 { + if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 { peer.RUnlock() break } @@ -169,12 +169,12 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl Len: 8, Type: unix.RTA_DST, }, - nativeEP.Dst4().Addr, + nativeEP.DstIP().As4(), unix.RtAttr{ Len: 8, Type: unix.RTA_SRC, }, - nativeEP.Src4().Src, + nativeEP.SrcIP().As4(), unix.RtAttr{ Len: 8, Type: unix.RTA_MARK, diff --git a/go.mod b/go.mod index 46427df..a6f7254 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,9 @@ module golang.zx2c4.com/wireguard go 1.19 require ( - golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd - golang.org/x/net v0.0.0-20220225172249-27dd8689420f - golang.org/x/sys v0.2.0 + golang.org/x/crypto v0.3.0 + golang.org/x/net v0.2.0 + golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 ) diff --git a/go.sum b/go.sum index 97ae95d..a967a10 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,11 @@ github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= -golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd h1:XcWmESyNjXJMLahc3mqVQJcgSTDxFxhETVlfk9uGc38= -golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/net v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc= -golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= -golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/crypto v0.3.0 h1:a06MkbcxBrEFc0w0QIZWXrH/9cCX6KJyWbBOIwAn+7A= +golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= +golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU= +golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= +golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 h1:260HNjMTPDya+jq5AM1zZLgG9pv9GASPAGiEEJUbRg4= +golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY= diff --git a/tun/checksum.go b/tun/checksum.go new file mode 100644 index 0000000..f4f8471 --- /dev/null +++ b/tun/checksum.go @@ -0,0 +1,42 @@ +package tun + +import "encoding/binary" + +// TODO: Explore SIMD and/or other assembly optimizations. 
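+//
+// The fold performed in checksum() below implements the RFC 1071 internet
+// checksum: 16-bit words are summed into a wide accumulator, and the carries
+// are then added back in until the result fits in 16 bits. As a worked
+// example from RFC 1071, the words 0x0001, 0xf203, 0xf4f5, 0xf6f7 sum to
+// 0x2ddf0, which folds to 0x2 + 0xddf0 = 0xddf2; the transmitted checksum
+// field is its one's complement, 0x220d.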
+func checksumNoFold(b []byte, initial uint64) uint64 { + ac := initial + i := 0 + n := len(b) + for n >= 4 { + ac += uint64(binary.BigEndian.Uint32(b[i : i+4])) + n -= 4 + i += 4 + } + for n >= 2 { + ac += uint64(binary.BigEndian.Uint16(b[i : i+2])) + n -= 2 + i += 2 + } + if n == 1 { + ac += uint64(b[i]) << 8 + } + return ac +} + +func checksum(b []byte, initial uint64) uint16 { + ac := checksumNoFold(b, initial) + ac = (ac >> 16) + (ac & 0xffff) + ac = (ac >> 16) + (ac & 0xffff) + ac = (ac >> 16) + (ac & 0xffff) + ac = (ac >> 16) + (ac & 0xffff) + return uint16(ac) +} + +func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 { + sum := checksumNoFold(srcAddr, 0) + sum = checksumNoFold(dstAddr, sum) + sum = checksumNoFold([]byte{0, protocol}, sum) + tmp := make([]byte, 2) + binary.BigEndian.PutUint16(tmp, totalLen) + return checksumNoFold(tmp, sum) +} diff --git a/tun/tcp_offload_linux.go b/tun/tcp_offload_linux.go new file mode 100644 index 0000000..f3ffa75 --- /dev/null +++ b/tun/tcp_offload_linux.go @@ -0,0 +1,612 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package tun + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "unsafe" + + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" +) + +const tcpFlagsOffset = 13 + +const ( + tcpFlagFIN uint8 = 0x01 + tcpFlagPSH uint8 = 0x08 + tcpFlagACK uint8 = 0x10 +) + +// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The +// kernel symbol is virtio_net_hdr. +type virtioNetHdr struct { + flags uint8 + gsoType uint8 + hdrLen uint16 + gsoSize uint16 + csumStart uint16 + csumOffset uint16 +} + +func (v *virtioNetHdr) decode(b []byte) error { + if len(b) < virtioNetHdrLen { + return io.ErrShortBuffer + } + copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen]) + return nil +} + +func (v *virtioNetHdr) encode(b []byte) error { + if len(b) < virtioNetHdrLen { + return io.ErrShortBuffer + } + copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen)) + return nil +} + +const ( + // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the + // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr). + virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{})) +) + +// flowKey represents the key for a flow. +type flowKey struct { + srcAddr, dstAddr [16]byte + srcPort, dstPort uint16 + rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows. +} + +// tcpGROTable holds flow and coalescing information for the purposes of GRO. 
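+// Flows are keyed by source and destination address, source and destination
+// port, and the received ACK number (see flowKey), so segments acknowledging
+// different data are tracked as separate flows.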
+type tcpGROTable struct { + itemsByFlow map[flowKey][]tcpGROItem + itemsPool [][]tcpGROItem +} + +func newTCPGROTable() *tcpGROTable { + t := &tcpGROTable{ + itemsByFlow: make(map[flowKey][]tcpGROItem, conn.DefaultBatchSize), + itemsPool: make([][]tcpGROItem, conn.DefaultBatchSize), + } + for i := range t.itemsPool { + t.itemsPool[i] = make([]tcpGROItem, 0, conn.DefaultBatchSize) + } + return t +} + +func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey { + key := flowKey{} + addrSize := dstAddr - srcAddr + copy(key.srcAddr[:], pkt[srcAddr:dstAddr]) + copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize]) + key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:]) + key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:]) + key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:]) + return key +} + +// lookupOrInsert looks up a flow for the provided packet and metadata, +// returning the packets found for the flow, or inserting a new one if none +// is found. +func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, buffsIndex int) ([]tcpGROItem, bool) { + key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + items, ok := t.itemsByFlow[key] + if ok { + return items, ok + } + // TODO: insert() performs another map lookup. This could be rearranged to avoid. + t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, buffsIndex) + return nil, false +} + +// insert an item in the table for the provided packet and packet metadata. +func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, buffsIndex int) { + key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + item := tcpGROItem{ + key: key, + buffsIndex: uint16(buffsIndex), + gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])), + iphLen: uint8(tcphOffset), + tcphLen: uint8(tcphLen), + sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]), + pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0, + } + items, ok := t.itemsByFlow[key] + if !ok { + items = t.newItems() + } + items = append(items, item) + t.itemsByFlow[key] = items +} + +func (t *tcpGROTable) updateAt(item tcpGROItem, i int) { + items, _ := t.itemsByFlow[item.key] + items[i] = item +} + +func (t *tcpGROTable) deleteAt(key flowKey, i int) { + items, _ := t.itemsByFlow[key] + items = append(items[:i], items[i+1:]...) + t.itemsByFlow[key] = items +} + +// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime +// of a GRO evaluation across a vector of packets. +type tcpGROItem struct { + key flowKey + sentSeq uint32 // the sequence number + buffsIndex uint16 // the index into the original buffs slice + numMerged uint16 // the number of packets merged into this item + gsoSize uint16 // payload size + iphLen uint8 // ip header len + tcphLen uint8 // tcp header len + pshSet bool // psh flag is set +} + +func (t *tcpGROTable) newItems() []tcpGROItem { + var items []tcpGROItem + items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1] + return items +} + +func (t *tcpGROTable) reset() { + for k, items := range t.itemsByFlow { + items = items[:0] + t.itemsPool = append(t.itemsPool, items) + delete(t.itemsByFlow, k) + } +} + +// canCoalesce represents the outcome of checking if two TCP packets are +// candidates for coalescing. 
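+// A candidate packet may be appended after the existing item, prepended in
+// front of it, or rejected, as encoded by the constants below.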
+type canCoalesce int + +const ( + coalescePrepend canCoalesce = -1 + coalesceUnavailable canCoalesce = 0 + coalesceAppend canCoalesce = 1 +) + +// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet +// described by item. This function makes considerations that match the kernel's +// GRO self tests, which can be found in tools/testing/selftests/net/gro.c. +func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, buffs [][]byte, buffsOffset int) canCoalesce { + pktTarget := buffs[item.buffsIndex][buffsOffset:] + if tcphLen != item.tcphLen { + // cannot coalesce with unequal tcp options len + return coalesceUnavailable + } + if tcphLen > 20 { + if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) { + // cannot coalesce with unequal tcp options + return coalesceUnavailable + } + } + if pkt[1] != pktTarget[1] { + // cannot coalesce with unequal ToS values + return coalesceUnavailable + } + if pkt[6]>>5 != pktTarget[6]>>5 { + // cannot coalesce with unequal DF or reserved bits. MF is checked + // further up the stack. + return coalesceUnavailable + } + // seq adjacency + lhsLen := item.gsoSize + lhsLen += item.numMerged * item.gsoSize + if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective + if item.pshSet { + // We cannot append to a segment that has the PSH flag set, PSH + // can only be set on the final segment in a reassembled group. + return coalesceUnavailable + } + if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 { + // A smaller than gsoSize packet has been appended previously. + // Nothing can come after a smaller packet on the end. + return coalesceUnavailable + } + if gsoSize > item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + return coalesceAppend + } else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective + if pshSet { + // We cannot prepend with a segment that has the PSH flag set, PSH + // can only be set on the final segment in a reassembled group. + return coalesceUnavailable + } + if gsoSize < item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + if gsoSize > item.gsoSize && item.numMerged > 0 { + // There's at least one previous merge, and we're larger than all + // previous. This would put multiple smaller packets on the end. + return coalesceUnavailable + } + return coalescePrepend + } + return coalesceUnavailable +} + +func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool { + srcAddrAt := ipv4SrcAddrOffset + addrSize := 4 + if isV6 { + srcAddrAt = ipv6SrcAddrOffset + addrSize = 16 + } + tcpTotalLen := uint16(len(pkt) - int(iphLen)) + tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen) + return ^checksum(pkt[iphLen:], tcpCSumNoFold) == 0 +} + +// coalesceResult represents the result of attempting to coalesce two TCP +// packets. +type coalesceResult int + +const ( + coalesceInsufficientCap coalesceResult = 0 + coalescePSHEnding coalesceResult = 1 + coalesceItemInvalidCSum coalesceResult = 2 + coalescePktInvalidCSum coalesceResult = 3 + coalesceSuccess coalesceResult = 4 +) + +// coalesceTCPPackets attempts to coalesce pkt with the packet described by +// item, returning the outcome. 
This function may swap buffs elements in the +// event of a prepend as item's buffs index is already being tracked for writing +// to a Device. +func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, buffs [][]byte, buffsOffset int, isV6 bool) coalesceResult { + var pktHead []byte // the packet that will end up at the front + headersLen := item.iphLen + item.tcphLen + coalescedLen := len(buffs[item.buffsIndex][buffsOffset:]) + len(pkt) - int(headersLen) + + // Copy data + if mode == coalescePrepend { + pktHead = pkt + if cap(pkt)-buffsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if pshSet { + return coalescePSHEnding + } + if item.numMerged == 0 { + if !tcpChecksumValid(buffs[item.buffsIndex][buffsOffset:], item.iphLen, isV6) { + return coalesceItemInvalidCSum + } + } + if !tcpChecksumValid(pkt, item.iphLen, isV6) { + return coalescePktInvalidCSum + } + item.sentSeq = seq + extendBy := coalescedLen - len(pktHead) + buffs[pktBuffsIndex] = append(buffs[pktBuffsIndex], make([]byte, extendBy)...) + copy(buffs[pktBuffsIndex][buffsOffset+len(pkt):], buffs[item.buffsIndex][buffsOffset+int(headersLen):]) + // Flip the slice headers in buffs as part of prepend. The index of item + // is already being tracked for writing. + buffs[item.buffsIndex], buffs[pktBuffsIndex] = buffs[pktBuffsIndex], buffs[item.buffsIndex] + } else { + pktHead = buffs[item.buffsIndex][buffsOffset:] + if cap(pktHead)-buffsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if item.numMerged == 0 { + if !tcpChecksumValid(buffs[item.buffsIndex][buffsOffset:], item.iphLen, isV6) { + return coalesceItemInvalidCSum + } + } + if !tcpChecksumValid(pkt, item.iphLen, isV6) { + return coalescePktInvalidCSum + } + if pshSet { + // We are appending a segment with PSH set. + item.pshSet = pshSet + pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH + } + extendBy := len(pkt) - int(headersLen) + buffs[item.buffsIndex] = append(buffs[item.buffsIndex], make([]byte, extendBy)...) + copy(buffs[item.buffsIndex][buffsOffset+len(pktHead):], pkt[headersLen:]) + } + + if gsoSize > item.gsoSize { + item.gsoSize = gsoSize + } + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(headersLen), + gsoSize: uint16(item.gsoSize), + csumStart: uint16(item.iphLen), + csumOffset: 16, + } + + // Recalculate the total len (IPv4) or payload len (IPv6). Recalculate the + // (IPv4) header checksum. + if isV6 { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 + binary.BigEndian.PutUint16(pktHead[4:], uint16(coalescedLen)-uint16(item.iphLen)) // set new payload len + } else { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 + pktHead[10], pktHead[11] = 0, 0 // clear checksum field + binary.BigEndian.PutUint16(pktHead[2:], uint16(coalescedLen)) // set new total length + iphCSum := ^checksum(pktHead[:item.iphLen], 0) // compute checksum + binary.BigEndian.PutUint16(pktHead[10:], iphCSum) // set checksum field + } + hdr.encode(buffs[item.buffsIndex][buffsOffset-virtioNetHdrLen:]) + + // Calculate the pseudo header checksum and place it at the TCP checksum + // offset. Downstream checksum offloading will combine this with computation + // of the tcp header and payload checksum. 
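+	// The TCP pseudo header covers the source address, destination address,
+	// protocol number, and TCP segment length (RFC 9293, Section 3.1), which
+	// is exactly the sum pseudoHeaderChecksumNoFold computes.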
+ addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := buffsOffset + addrOffset + srcAddr := buffs[item.buffsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := buffs[item.buffsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(coalescedLen-int(item.iphLen))) + binary.BigEndian.PutUint16(pktHead[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) + + item.numMerged++ + return coalesceSuccess +} + +const ( + ipv4FlagMoreFragments = 0x80 +) + +const ( + ipv4SrcAddrOffset = 12 + ipv6SrcAddrOffset = 8 + maxUint16 = 1<<16 - 1 +) + +// tcpGRO evaluates the TCP packet at pktI in buffs for coalescing with +// existing packets tracked in table. It will return false when pktI is not +// coalesced, otherwise true. This indicates to the caller if buffs[pktI] +// should be written to the Device. +func tcpGRO(buffs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) (pktCoalesced bool) { + pkt := buffs[pktI][offset:] + if len(pkt) > maxUint16 { + // A valid IPv4 or IPv6 packet will never exceed this. + return false + } + iphLen := int((pkt[0] & 0x0F) * 4) + if isV6 { + iphLen = 40 + ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) + if ipv6HPayloadLen != len(pkt)-iphLen { + return false + } + } else { + totalLen := int(binary.BigEndian.Uint16(pkt[2:])) + if totalLen != len(pkt) { + return false + } + if iphLen < 20 || iphLen > 60 { + return false + } + } + if len(pkt) < iphLen { + return false + } + tcphLen := int((pkt[iphLen+12] >> 4) * 4) + if tcphLen < 20 || tcphLen > 60 { + return false + } + if len(pkt) < iphLen+tcphLen { + return false + } + if !isV6 { + if pkt[6]&ipv4FlagMoreFragments != 0 || (pkt[6]<<3 != 0 || pkt[7] != 0) { + // no GRO support for fragmented segments for now + return false + } + } + tcpFlags := pkt[iphLen+tcpFlagsOffset] + var pshSet bool + // not a candidate if any non-ACK flags (except PSH+ACK) are set + if tcpFlags != tcpFlagACK { + if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH { + return false + } + pshSet = true + } + gsoSize := uint16(len(pkt) - tcphLen - iphLen) + // not a candidate if payload len is 0 + if gsoSize < 1 { + return false + } + seq := binary.BigEndian.Uint32(pkt[iphLen+4:]) + srcAddrOffset := ipv4SrcAddrOffset + addrLen := 4 + if isV6 { + srcAddrOffset = ipv6SrcAddrOffset + addrLen = 16 + } + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + if !existing { + return false + } + for i := len(items) - 1; i >= 0; i-- { + // In the best case of packets arriving in order iterating in reverse is + // more efficient if there are multiple items for a given flow. This + // also enables a natural table.deleteAt() in the + // coalesceItemInvalidCSum case without the need for index tracking. + // This algorithm makes a best effort to coalesce in the event of + // unordered packets, where pkt may land anywhere in items from a + // sequence number perspective, however once an item is inserted into + // the table it is never compared across other items later. 
+ item := items[i] + can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, buffs, offset) + if can != coalesceUnavailable { + result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, buffs, offset, isV6) + switch result { + case coalesceSuccess: + table.updateAt(item, i) + return true + case coalesceItemInvalidCSum: + // delete the item with an invalid csum + table.deleteAt(item.key, i) + case coalescePktInvalidCSum: + // no point in inserting an item that we can't coalesce + return false + default: + } + } + } + // failed to coalesce with any other packets; store the item in the flow + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + return false +} + +func isTCP4(b []byte) bool { + if len(b) < 40 { + return false + } + if b[0]>>4 != 4 { + return false + } + if b[9] != unix.IPPROTO_TCP { + return false + } + return true +} + +func isTCP6NoEH(b []byte) bool { + if len(b) < 60 { + return false + } + if b[0]>>4 != 6 { + return false + } + if b[6] != unix.IPPROTO_TCP { + return false + } + return true +} + +// handleGRO evaluates buffs for GRO, and writes the indices of the resulting +// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be +// empty (but non-nil), and are passed in to save allocs as the caller may reset +// and recycle them across vectors of packets. +func handleGRO(buffs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error { + for i := range buffs { + if offset < virtioNetHdrLen || offset > len(buffs[i])-1 { + return errors.New("invalid offset") + } + var coalesced bool + switch { + case isTCP4(buffs[i][offset:]): + coalesced = tcpGRO(buffs, offset, i, tcp4Table, false) + case isTCP6NoEH(buffs[i][offset:]): // ipv6 packets w/extension headers do not coalesce + coalesced = tcpGRO(buffs, offset, i, tcp6Table, true) + } + if !coalesced { + hdr := virtioNetHdr{} + err := hdr.encode(buffs[i][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + *toWrite = append(*toWrite, i) + } + } + return nil +} + +// tcpTSO splits packets from in into outBuffs, writing the size of each +// element into sizes. It returns the number of buffers populated, and/or an +// error. +func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) { + iphLen := int(hdr.csumStart) + srcAddrOffset := ipv6SrcAddrOffset + addrLen := 16 + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 { + in[10], in[11] = 0, 0 // clear ipv4 header checksum + srcAddrOffset = ipv4SrcAddrOffset + addrLen = 4 + } + tcpCSumAt := int(hdr.csumStart + hdr.csumOffset) + in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum + firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:]) + nextSegmentDataAt := int(hdr.hdrLen) + i := 0 + for ; nextSegmentDataAt < len(in); i++ { + if i == len(outBuffs) { + return i - 1, ErrTooManySegments + } + nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize) + if nextSegmentEnd > len(in) { + nextSegmentEnd = len(in) + } + segmentDataLen := nextSegmentEnd - nextSegmentDataAt + totalLen := int(hdr.hdrLen) + segmentDataLen + sizes[i] = totalLen + out := outBuffs[i][outOffset:] + + copy(out, in[:iphLen]) + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 { + // For IPv4 we are responsible for incrementing the ID field, + // updating the total len field, and recalculating the header + // checksum. 
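+ // out was just populated with a fresh copy of the original IPv4 header
+ // via copy(out, in[:iphLen]) above, so adding i to the ID it carries
+ // yields a distinct, increasing ID for each segment.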
+ if i > 0 { + id := binary.BigEndian.Uint16(out[4:]) + id += uint16(i) + binary.BigEndian.PutUint16(out[4:], id) + } + binary.BigEndian.PutUint16(out[2:], uint16(totalLen)) + ipv4CSum := ^checksum(out[:iphLen], 0) + binary.BigEndian.PutUint16(out[10:], ipv4CSum) + } else { + // For IPv6 we are responsible for updating the payload length field. + binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen)) + } + + // TCP header + copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen]) + tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i)) + binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq) + if nextSegmentEnd != len(in) { + // FIN and PSH should only be set on last segment + clearFlags := tcpFlagFIN | tcpFlagPSH + out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags + } + + // payload + copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd]) + + // TCP checksum + tcpHLen := int(hdr.hdrLen - hdr.csumStart) + tcpLenForPseudo := uint16(tcpHLen + segmentDataLen) + tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo) + tcpCSum := ^checksum(out[hdr.csumStart:totalLen], tcpCSumNoFold) + binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum) + + nextSegmentDataAt += int(hdr.gsoSize) + } + return i, nil +} + +func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error { + cSumAt := cSumStart + cSumOffset + // The initial value at the checksum offset should be summed with the + // checksum we compute. This is typically the pseudo-header checksum. + initial := binary.BigEndian.Uint16(in[cSumAt:]) + in[cSumAt], in[cSumAt+1] = 0, 0 + binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial))) + return nil +} diff --git a/tun/tcp_offload_linux_test.go b/tun/tcp_offload_linux_test.go new file mode 100644 index 0000000..7fa0777 --- /dev/null +++ b/tun/tcp_offload_linux_test.go @@ -0,0 +1,273 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
+ */ + +package tun + +import ( + "net/netip" + "testing" + + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +const ( + offset = virtioNetHdrLen +) + +var ( + ip4PortA = netip.MustParseAddrPort("192.0.2.1:1") + ip4PortB = netip.MustParseAddrPort("192.0.2.2:1") + ip4PortC = netip.MustParseAddrPort("192.0.2.3:1") + ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1") + ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1") + ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1") +) + +func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { + totalLen := 40 + segmentSize + b := make([]byte, offset+int(totalLen), 65535) + ipv4H := header.IPv4(b[offset:]) + srcAs4 := srcIPPort.Addr().As4() + dstAs4 := dstIPPort.Addr().As4() + ipv4H.Encode(&header.IPv4Fields{ + SrcAddr: tcpip.Address(srcAs4[:]), + DstAddr: tcpip.Address(dstAs4[:]), + Protocol: unix.IPPROTO_TCP, + TTL: 64, + TotalLength: uint16(totalLen), + }) + tcpH := header.TCP(b[offset+20:]) + tcpH.Encode(&header.TCPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + SeqNum: seq, + AckNum: 1, + DataOffset: 20, + Flags: flags, + WindowSize: 3000, + }) + ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize)) + tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) + return b +} + +func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { + totalLen := 60 + segmentSize + b := make([]byte, offset+int(totalLen), 65535) + ipv6H := header.IPv6(b[offset:]) + srcAs16 := srcIPPort.Addr().As16() + dstAs16 := dstIPPort.Addr().As16() + ipv6H.Encode(&header.IPv6Fields{ + SrcAddr: tcpip.Address(srcAs16[:]), + DstAddr: tcpip.Address(dstAs16[:]), + TransportProtocol: unix.IPPROTO_TCP, + HopLimit: 64, + PayloadLength: uint16(segmentSize + 20), + }) + tcpH := header.TCP(b[offset+40:]) + tcpH.Encode(&header.TCPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + SeqNum: seq, + AckNum: 1, + DataOffset: 20, + Flags: flags, + WindowSize: 3000, + }) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize)) + tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) + return b +} + +func Test_handleVirtioRead(t *testing.T) { + tests := []struct { + name string + hdr virtioNetHdr + pktIn []byte + wantLens []int + wantErr bool + }{ + { + "tcp4", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4, + gsoSize: 100, + hdrLen: 40, + csumStart: 20, + csumOffset: 16, + }, + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), + []int{140, 140}, + false, + }, + { + "tcp6", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6, + gsoSize: 100, + hdrLen: 60, + csumStart: 40, + csumOffset: 16, + }, + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), + []int{160, 160}, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out := make([][]byte, conn.DefaultBatchSize) + sizes := make([]int, conn.DefaultBatchSize) + for i := range out { + out[i] = make([]byte, 65535) + } + tt.hdr.encode(tt.pktIn) + n, err := handleVirtioRead(tt.pktIn, out, sizes, offset) + if err != nil { + if tt.wantErr 
{ + return + } + t.Fatalf("got err: %v", err) + } + if n != len(tt.wantLens) { + t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens)) + } + for i := range tt.wantLens { + if tt.wantLens[i] != sizes[i] { + t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i]) + } + } + }) + } +} + +func flipTCP4Checksum(b []byte) []byte { + at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16 + b[at] ^= 0xFF + b[at+1] ^= 0xFF + return b +} + +func Fuzz_handleGRO(f *testing.F) { + pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1) + pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101) + pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201) + pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1) + pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101) + pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201) + f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, offset) + f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5 []byte, offset int) { + pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5} + toWrite := make([]int, 0, len(pkts)) + handleGRO(pkts, offset, newTCPGROTable(), newTCPGROTable(), &toWrite) + if len(toWrite) > len(pkts) { + t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) + } + seenWriteI := make(map[int]bool) + for _, writeI := range toWrite { + if writeI < 0 || writeI > len(pkts)-1 { + t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts)) + } + if seenWriteI[writeI] { + t.Errorf("duplicate toWrite value: %d", writeI) + } + seenWriteI[writeI] = true + } + }) +} + +func Test_handleGRO(t *testing.T) { + tests := []struct { + name string + pktsIn [][]byte + wantToWrite []int + wantLens []int + wantErr bool + }{ + { + "multiple flows", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // v4 flow 2 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // v6 flow 2 + }, + []int{0, 2, 3, 5}, + []int{240, 140, 260, 160}, + false, + }, + { + "PSH interleaved", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1 + }, + []int{0, 2, 4, 6}, + []int{240, 240, 260, 260}, + false, + }, + { + "coalesceItemInvalidCSum", + [][]byte{ + flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 + }, + []int{0, 1}, + []int{140, 240}, + false, + }, + { + "out of order", + [][]byte{ + 
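+ // Packets arrive as seq 101, then 1, then 201: seq 1 should prepend to
+ // the existing item and seq 201 should append, leaving a single 340 byte
+ // packet (40 bytes of headers plus 300 bytes of payload).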
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 + }, + []int{0}, + []int{340}, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + toWrite := make([]int, 0, len(tt.pktsIn)) + err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newTCPGROTable(), &toWrite) + if err != nil { + if tt.wantErr { + return + } + t.Fatalf("got err: %v", err) + } + if len(toWrite) != len(tt.wantToWrite) { + t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite)) + } + for i, pktI := range tt.wantToWrite { + if tt.wantToWrite[i] != toWrite[i] { + t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i]) + } + if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) { + t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:])) + } + } + }) + } +} diff --git a/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 b/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 new file mode 100644 index 0000000..5461e79 --- /dev/null +++ b/tun/testdata/fuzz/Fuzz_handleGRO/032aec0105f26f709c118365e4830d6dc087cab24cd1e154c2e790589a309b77 @@ -0,0 +1,8 @@ +go test fuzz v1 +[]byte("0") +[]byte("0") +[]byte("0") +[]byte("0") +[]byte("0") +[]byte("0") +int(34) diff --git a/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d b/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d new file mode 100644 index 0000000..b441819 --- /dev/null +++ b/tun/testdata/fuzz/Fuzz_handleGRO/0da283f9a2098dec30d1c86784411a8ce2e8e03aa3384105e581f2c67494700d @@ -0,0 +1,8 @@ +go test fuzz v1 +[]byte("0") +[]byte("0") +[]byte("0") +[]byte("0") +[]byte("0") +[]byte("0") +int(-48) diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 21984ca..d56e3c1 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -17,9 +17,8 @@ import ( "time" "unsafe" - "golang.org/x/net/ipv6" "golang.org/x/sys/unix" - + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/rwcancel" ) @@ -33,17 +32,25 @@ type NativeTun struct { index int32 // if index errors chan error // async error handling events chan Event // device related events - nopi bool // the device was passed IFF_NO_PI netlinkSock int netlinkCancel *rwcancel.RWCancel hackListenerClosed sync.Mutex statusListenersShutdown chan struct{} + batchSize int + vnetHdr bool closeOnce sync.Once nameOnce sync.Once // guards calling initNameCache, which sets following fields nameCache string // name of interface nameErr error + + readOpMu sync.Mutex // readOpMu guards readBuff + readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr + + writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable + toWrite []int + tcp4GROTable, tcp6GROTable *tcpGROTable } func (tun *NativeTun) File() *os.File { @@ -323,60 +330,142 @@ func (tun *NativeTun) nameSlow() (string, error) { return unix.ByteSliceToString(ifr[:]), nil } -func (tun *NativeTun) Write(buffs [][]byte, offset int) (n int, err error) { - var buf []byte - if tun.nopi { - buf = buffs[0][offset:] +func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) { + tun.writeOpMu.Lock() + defer func() { + 
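+ // Reset the GRO tables before releasing writeOpMu: their per-flow state
+ // is only meaningful for this call's buffs and must not carry over into
+ // the next Write.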
tun.tcp4GROTable.reset() + tun.tcp6GROTable.reset() + tun.writeOpMu.Unlock() + }() + var ( + errs []error + total int + ) + tun.toWrite = tun.toWrite[:0] + if tun.vnetHdr { + err := handleGRO(buffs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite) + if err != nil { + return 0, err + } + offset -= virtioNetHdrLen + } else { - // reserve space for header - buf = buffs[0][offset-4:] - - // add packet information header - buf[0] = 0x00 - buf[1] = 0x00 - if buf[4]>>4 == ipv6.Version { - buf[2] = 0x86 - buf[3] = 0xdd - } else { - buf[2] = 0x08 - buf[3] = 0x00 + for i := range buffs { + tun.toWrite = append(tun.toWrite, i) + } } - - _, err = tun.tunFile.Write(buf) - if errors.Is(err, syscall.EBADFD) { - err = os.ErrClosed - } else if err == nil { - n = 1 + for _, buffsI := range tun.toWrite { + n, err := tun.tunFile.Write(buffs[buffsI][offset:]) + if errors.Is(err, syscall.EBADFD) { + return total, os.ErrClosed + } + if err != nil { + errs = append(errs, err) + } else { + total += n + } } - return n, err + return total, ErrorBatch(errs) } -func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (n int, err error) { - select { - case err = <-tun.errors: - default: - if tun.nopi { - sizes[0], err = tun.tunFile.Read(buffs[0][offset:]) - if err == nil { - n = 1 - } - } else { - buff := buffs[0][offset-4:] - sizes[0], err = tun.tunFile.Read(buff[:]) - if errors.Is(err, syscall.EBADFD) { - err = os.ErrClosed - } else if err == nil { - n = 1 - } - if sizes[0] < 4 { - sizes[0] = 0 - } else { - sizes[0] -= 4 +// handleVirtioRead splits in into buffs, leaving offset bytes at the front of +// each buffer. It mutates sizes to reflect the size of each element of buffs, +// and returns the number of packets read. +func handleVirtioRead(in []byte, buffs [][]byte, sizes []int, offset int) (int, error) { + var hdr virtioNetHdr + err := hdr.decode(in) + if err != nil { + return 0, err + } + in = in[virtioNetHdrLen:] + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE { + if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 { + // This means CHECKSUM_PARTIAL in skb context. We are responsible + // for computing the checksum starting at hdr.csumStart and placing + // at hdr.csumOffset. + err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset) + if err != nil { + return 0, err + } + } + if len(in) > len(buffs[0][offset:]) { + return 0, fmt.Errorf("read len %d overflows buffs element len %d", len(in), len(buffs[0][offset:])) + } + n := copy(buffs[0][offset:], in) + sizes[0] = n + return 1, nil + } + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 { + return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType) + } + + ipVersion := in[0] >> 4 + switch ipVersion { + case 4: + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 { + return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) + } + case 6: + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 { + return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) + } + default: + return 0, fmt.Errorf("invalid ip header version: %d", ipVersion) + } + + if len(in) <= int(hdr.csumStart+12) { + return 0, errors.New("packet is too short") + } + // Don't trust hdr.hdrLen from the kernel as it can be equal to the length + // of the entire first packet when the kernel is handling it as part of a + // FORWARD path. Instead, parse the TCP header length and add it onto + // csumStart, which is synonymous with IP header length.
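+ // The TCP data offset sits in the upper nibble of byte 12 of the TCP
+ // header (csumStart+12) and counts 32-bit words, hence the >>4 then *4.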
+ tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4) + if tcpHLen < 20 || tcpHLen > 60 { + // A TCP header must be between 20 and 60 bytes in length. + return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) + } + hdr.hdrLen = hdr.csumStart + tcpHLen + + if len(in) < int(hdr.hdrLen) { + return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen) + } + + if hdr.hdrLen < hdr.csumStart { + return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart) + } + cSumAt := int(hdr.csumStart + hdr.csumOffset) + if cSumAt+1 >= len(in) { + return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) + } + + return tcpTSO(in, hdr, buffs, sizes, offset) +} + +func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) { + tun.readOpMu.Lock() + defer tun.readOpMu.Unlock() + select { + case err := <-tun.errors: + return 0, err + default: + readInto := buffs[0][offset:] + if tun.vnetHdr { + readInto = tun.readBuff[:] + } + n, err := tun.tunFile.Read(readInto) + if errors.Is(err, syscall.EBADFD) { + err = os.ErrClosed + } + if err != nil { + return 0, err + } + if tun.vnetHdr { + return handleVirtioRead(readInto[:n], buffs, sizes, offset) + } else { + sizes[0] = n + return 1, nil + } } - return } func (tun *NativeTun) Events() <-chan Event { @@ -403,9 +492,49 @@ func (tun *NativeTun) Close() error { } func (tun *NativeTun) BatchSize() int { - return 1 + return tun.batchSize } +const ( + // TODO: support TSO with ECN bits + tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 +) + +func (tun *NativeTun) initFromFlags(name string) error { + sc, err := tun.tunFile.SyscallConn() + if err != nil { + return err + } + if e := sc.Control(func(fd uintptr) { + var ( + ifr *unix.Ifreq + ) + ifr, err = unix.NewIfreq(name) + if err != nil { + return + } + err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr) + if err != nil { + return + } + got := ifr.Uint16() + if got&unix.IFF_VNET_HDR != 0 { + err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads) + if err != nil { + return + } + tun.vnetHdr = true + tun.batchSize = conn.DefaultBatchSize + } else { + tun.batchSize = 1 + } + }); e != nil { + return e + } + return err +} + +// CreateTUN creates a Device with the provided name and MTU. func CreateTUN(name string, mtu int) (Device, error) { nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0) if err != nil { @@ -415,25 +544,16 @@ func CreateTUN(name string, mtu int) (Device, error) { return nil, err } - var ifr [ifReqSize]byte - var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack) - nameBytes := []byte(name) - if len(nameBytes) >= unix.IFNAMSIZ { - unix.Close(nfd) - return nil, fmt.Errorf("interface name too long: %w", unix.ENAMETOOLONG) + ifr, err := unix.NewIfreq(name) + if err != nil { + return nil, err } - copy(ifr[:], nameBytes) - *(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags - - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(nfd), - uintptr(unix.TUNSETIFF), - uintptr(unsafe.Pointer(&ifr[0])), - ) - if errno != 0 { - unix.Close(nfd) - return nil, errno + // IFF_VNET_HDR enables the "tun status hack" via routineHackListener() + // where a null write will return EINVAL indicating the TUN is up. 
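+ // IFF_NO_PI drops the legacy 4 byte packet information prefix; with
+ // IFF_VNET_HDR negotiated, each read and write is instead preceded by a
+ // virtio_net_hdr describing checksum and GSO state.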
+ ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR) + err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr) + if err != nil { + return nil, err } err = unix.SetNonblock(nfd, true) @@ -448,13 +568,16 @@ func CreateTUN(name string, mtu int) (Device, error) { return CreateTUNFromFile(fd, mtu) } +// CreateTUNFromFile creates a Device from an os.File with the provided MTU. func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { tun := &NativeTun{ tunFile: file, events: make(chan Event, 5), errors: make(chan error, 5), statusListenersShutdown: make(chan struct{}), - nopi: false, + tcp4GROTable: newTCPGROTable(), + tcp6GROTable: newTCPGROTable(), + toWrite: make([]int, 0, conn.DefaultBatchSize), } name, err := tun.Name() @@ -462,8 +585,12 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { return nil, err } - // start event listener + err = tun.initFromFlags(name) + if err != nil { + return nil, err + } + // start event listener tun.index, err = getIFIndex(name) if err != nil { return nil, err @@ -492,6 +619,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { return tun, nil } +// CreateUnmonitoredTUNFromFD creates a Device from the provided file +// descriptor. func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { err := unix.SetNonblock(fd, true) if err != nil { @@ -499,14 +628,20 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { } file := os.NewFile(uintptr(fd), "/dev/tun") tun := &NativeTun{ - tunFile: file, - events: make(chan Event, 5), - errors: make(chan error, 5), - nopi: true, + tunFile: file, + events: make(chan Event, 5), + errors: make(chan error, 5), + tcp4GROTable: newTCPGROTable(), + tcp6GROTable: newTCPGROTable(), + toWrite: make([]int, 0, conn.DefaultBatchSize), } name, err := tun.Name() if err != nil { return nil, "", err } - return tun, name, nil + err = tun.initFromFlags(name) + if err != nil { + return nil, "", err + } + return tun, name, err }