conn, device, tun: implement vectorized I/O on Linux

Implement TCP offloading via TSO and GRO for the Linux tun.Device, which
is made possible by virtio extensions in the kernel's TUN driver.

Delete conn.LinuxSocketEndpoint in favor of a collapsed conn.StdNetBind.
conn.StdNetBind makes use of recvmmsg() and sendmmsg() on Linux. All
platforms now fall under conn.StdNetBind, except for Windows, which
remains in conn.WinRingBind, which still needs to be adjusted to handle
multiple packets.

Also refactor sticky sockets support to eventually be applicable on
platforms other than just Linux. However Linux remains the sole platform
that fully implements it for now.

Co-authored-by: James Tucker <james@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jordan Whited 2023-03-02 15:08:28 -08:00 committed by Jason A. Donenfeld
parent 3bb8fec7e4
commit 9e2f386022
24 changed files with 1877 additions and 794 deletions

View File

@ -1,587 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"errors"
"net"
"net/netip"
"strconv"
"sync"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
)
type ipv4Source struct {
Src [4]byte
Ifindex int32
}
type ipv6Source struct {
src [16]byte
// ifindex belongs in dst.ZoneId
}
type LinuxSocketEndpoint struct {
mu sync.Mutex
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
src [unsafe.Sizeof(ipv6Source{})]byte
isV6 bool
}
func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source { return endpoint.src4() }
func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
func (endpoint *LinuxSocketEndpoint) IsV6() bool { return endpoint.isV6 }
func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source {
return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0]))
}
func (endpoint *LinuxSocketEndpoint) src6() *ipv6Source {
return (*ipv6Source)(unsafe.Pointer(&endpoint.src[0]))
}
func (endpoint *LinuxSocketEndpoint) dst4() *unix.SockaddrInet4 {
return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
}
func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
}
// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
type LinuxSocketBind struct {
// mu guards sock4 and sock6 and the associated fds.
// As long as someone holds mu (read or write), the associated fds are valid.
mu sync.RWMutex
sock4 int
sock6 int
}
func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
func NewDefaultBind() Bind { return NewLinuxSocketBind() }
var (
_ Endpoint = (*LinuxSocketEndpoint)(nil)
_ Bind = (*LinuxSocketBind)(nil)
)
func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
var end LinuxSocketEndpoint
e, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
if e.Addr().Is4() {
dst := end.dst4()
end.isV6 = false
dst.Port = int(e.Port())
dst.Addr = e.Addr().As4()
end.ClearSrc()
return &end, nil
}
if e.Addr().Is6() {
zone, err := zoneToUint32(e.Addr().Zone())
if err != nil {
return nil, err
}
dst := end.dst6()
end.isV6 = true
dst.Port = int(e.Port())
dst.ZoneId = zone
dst.Addr = e.Addr().As16()
end.ClearSrc()
return &end, nil
}
return nil, errors.New("invalid IP address")
}
func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
bind.mu.Lock()
defer bind.mu.Unlock()
var err error
var newPort uint16
var tries int
if bind.sock4 != -1 || bind.sock6 != -1 {
return nil, 0, ErrBindAlreadyOpen
}
originalPort := port
again:
port = originalPort
var sock4, sock6 int
// Attempt ipv6 bind, update port if successful.
sock6, newPort, err = create6(port)
if err != nil {
if !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
} else {
port = newPort
}
// Attempt ipv4 bind, update port if successful.
sock4, newPort, err = create4(port)
if err != nil {
if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
unix.Close(sock6)
tries++
goto again
}
if !errors.Is(err, syscall.EAFNOSUPPORT) {
unix.Close(sock6)
return nil, 0, err
}
} else {
port = newPort
}
var fns []ReceiveFunc
if sock4 != -1 {
bind.sock4 = sock4
fns = append(fns, bind.receiveIPv4)
}
if sock6 != -1 {
bind.sock6 = sock6
fns = append(fns, bind.receiveIPv6)
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
return fns, port, nil
}
func (bind *LinuxSocketBind) SetMark(value uint32) error {
bind.mu.RLock()
defer bind.mu.RUnlock()
if bind.sock6 != -1 {
err := unix.SetsockoptInt(
bind.sock6,
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
)
if err != nil {
return err
}
}
if bind.sock4 != -1 {
err := unix.SetsockoptInt(
bind.sock4,
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
)
if err != nil {
return err
}
}
return nil
}
func (bind *LinuxSocketBind) BatchSize() int {
return 1
}
func (bind *LinuxSocketBind) Close() error {
// Take a readlock to shut down the sockets...
bind.mu.RLock()
if bind.sock6 != -1 {
unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
}
if bind.sock4 != -1 {
unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
}
bind.mu.RUnlock()
// ...and a write lock to close the fd.
// This ensures that no one else is using the fd.
bind.mu.Lock()
defer bind.mu.Unlock()
var err1, err2 error
if bind.sock6 != -1 {
err1 = unix.Close(bind.sock6)
bind.sock6 = -1
}
if bind.sock4 != -1 {
err2 = unix.Close(bind.sock4)
bind.sock4 = -1
}
if err1 != nil {
return err1
}
return err2
}
func (bind *LinuxSocketBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
if bind.sock4 == -1 {
return 0, net.ErrClosed
}
var end LinuxSocketEndpoint
n, err := receive4(bind.sock4, buffs[0], &end)
if err != nil {
return 0, err
}
eps[0] = &end
sizes[0] = n
return 1, nil
}
func (bind *LinuxSocketBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
if bind.sock6 == -1 {
return 0, net.ErrClosed
}
var end LinuxSocketEndpoint
n, err := receive6(bind.sock6, buffs[0], &end)
if err != nil {
return 0, err
}
eps[0] = &end
sizes[0] = n
return 1, nil
}
func (bind *LinuxSocketBind) Send(buffs [][]byte, end Endpoint) error {
nend, ok := end.(*LinuxSocketEndpoint)
if !ok {
return ErrWrongEndpointType
}
bind.mu.RLock()
defer bind.mu.RUnlock()
if !nend.isV6 {
if bind.sock4 == -1 {
return net.ErrClosed
}
for _, buff := range buffs {
err := send4(bind.sock4, nend, buff)
if err != nil {
return err
}
}
} else {
if bind.sock6 == -1 {
return net.ErrClosed
}
for _, buff := range buffs {
err := send6(bind.sock6, nend, buff)
if err != nil {
return err
}
}
}
return nil
}
func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
if !end.isV6 {
return netip.AddrFrom4(end.src4().Src)
} else {
return netip.AddrFrom16(end.src6().src)
}
}
func (end *LinuxSocketEndpoint) DstIP() netip.Addr {
if !end.isV6 {
return netip.AddrFrom4(end.dst4().Addr)
} else {
return netip.AddrFrom16(end.dst6().Addr)
}
}
func (end *LinuxSocketEndpoint) DstToBytes() []byte {
if !end.isV6 {
return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
} else {
return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
}
}
func (end *LinuxSocketEndpoint) SrcToString() string {
return end.SrcIP().String()
}
func (end *LinuxSocketEndpoint) DstToString() string {
var port int
if !end.isV6 {
port = end.dst4().Port
} else {
port = end.dst6().Port
}
return netip.AddrPortFrom(end.DstIP(), uint16(port)).String()
}
func (end *LinuxSocketEndpoint) ClearDst() {
for i := range end.dst {
end.dst[i] = 0
}
}
func (end *LinuxSocketEndpoint) ClearSrc() {
for i := range end.src {
end.src[i] = 0
}
}
func zoneToUint32(zone string) (uint32, error) {
if zone == "" {
return 0, nil
}
if intr, err := net.InterfaceByName(zone); err == nil {
return uint32(intr.Index), nil
}
n, err := strconv.ParseUint(zone, 10, 32)
return uint32(n), err
}
func create4(port uint16) (int, uint16, error) {
// create socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
if err != nil {
return -1, 0, err
}
addr := unix.SockaddrInet4{
Port: int(port),
}
// set sockopts and bind
if err := func() error {
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IP,
unix.IP_PKTINFO,
1,
); err != nil {
return err
}
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
return -1, 0, err
}
sa, err := unix.Getsockname(fd)
if err == nil {
addr.Port = sa.(*unix.SockaddrInet4).Port
}
return fd, uint16(addr.Port), err
}
func create6(port uint16) (int, uint16, error) {
// create socket
fd, err := unix.Socket(
unix.AF_INET6,
unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
if err != nil {
return -1, 0, err
}
// set sockopts and bind
addr := unix.SockaddrInet6{
Port: int(port),
}
if err := func() error {
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IPV6,
unix.IPV6_RECVPKTINFO,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IPV6,
unix.IPV6_V6ONLY,
1,
); err != nil {
return err
}
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
return -1, 0, err
}
sa, err := unix.Getsockname(fd)
if err == nil {
addr.Port = sa.(*unix.SockaddrInet6).Port
}
return fd, uint16(addr.Port), err
}
func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error {
// construct message header
cmsg := struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet4Pktinfo
}{
unix.Cmsghdr{
Level: unix.IPPROTO_IP,
Type: unix.IP_PKTINFO,
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
},
unix.Inet4Pktinfo{
Spec_dst: end.src4().Src,
Ifindex: end.src4().Ifindex,
},
}
end.mu.Lock()
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
end.mu.Unlock()
if err == nil {
return nil
}
// clear src and retry
if err == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet4Pktinfo{}
end.mu.Lock()
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
end.mu.Unlock()
}
return err
}
func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error {
// construct message header
cmsg := struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo
}{
unix.Cmsghdr{
Level: unix.IPPROTO_IPV6,
Type: unix.IPV6_PKTINFO,
Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
},
unix.Inet6Pktinfo{
Addr: end.src6().src,
Ifindex: end.dst6().ZoneId,
},
}
if cmsg.pktinfo.Addr == [16]byte{} {
cmsg.pktinfo.Ifindex = 0
}
end.mu.Lock()
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
end.mu.Unlock()
if err == nil {
return nil
}
// clear src and retry
if err == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet6Pktinfo{}
end.mu.Lock()
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
end.mu.Unlock()
}
return err
}
func receive4(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
// construct message header
var cmsg struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet4Pktinfo
}
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
if err != nil {
return 0, err
}
end.isV6 = false
if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
*end.dst4() = *newDst4
}
// update source cache
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
end.src4().Src = cmsg.pktinfo.Spec_dst
end.src4().Ifindex = cmsg.pktinfo.Ifindex
}
return size, nil
}
func receive6(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
// construct message header
var cmsg struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo
}
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
if err != nil {
return 0, err
}
end.isV6 = true
if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
*end.dst6() = *newDst6
}
// update source cache
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
end.src6().src = cmsg.pktinfo.Addr
end.dst6().ZoneId = cmsg.pktinfo.Ifindex
}
return size, nil
}

View File

@ -6,32 +6,91 @@
package conn package conn
import ( import (
"context"
"errors" "errors"
"net" "net"
"net/netip" "net/netip"
"strconv"
"sync" "sync"
"syscall" "syscall"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
) )
// StdNetBind is meant to be a temporary solution on platforms for which var (
// the sticky socket / source caching behavior has not yet been implemented. _ Bind = (*StdNetBind)(nil)
// It uses the Go's net package to implement networking. )
// See LinuxSocketBind for a proper implementation on the Linux platform.
// StdNetBind implements Bind for all platforms except Windows.
type StdNetBind struct { type StdNetBind struct {
mu sync.Mutex // protects following fields mu sync.Mutex // protects following fields
ipv4 *net.UDPConn ipv4 *net.UDPConn
ipv6 *net.UDPConn ipv6 *net.UDPConn
blackhole4 bool blackhole4 bool
blackhole6 bool blackhole6 bool
ipv4PC *ipv4.PacketConn
ipv6PC *ipv6.PacketConn
batchSize int
udpAddrPool sync.Pool
ipv4MsgsPool sync.Pool
ipv6MsgsPool sync.Pool
} }
func NewStdNetBind() Bind { return &StdNetBind{} } func NewStdNetBind() Bind { return NewStdNetBindBatch(DefaultBatchSize) }
type StdNetEndpoint netip.AddrPort func NewStdNetBindBatch(maxBatchSize int) Bind {
if maxBatchSize == 0 {
maxBatchSize = DefaultBatchSize
}
return &StdNetBind{
batchSize: maxBatchSize,
udpAddrPool: sync.Pool{
New: func() any {
return &net.UDPAddr{
IP: make([]byte, 16),
}
},
},
ipv4MsgsPool: sync.Pool{
New: func() any {
msgs := make([]ipv4.Message, maxBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, srcControlSize)
}
return &msgs
},
},
ipv6MsgsPool: sync.Pool{
New: func() any {
msgs := make([]ipv6.Message, maxBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, srcControlSize)
}
return &msgs
},
},
}
}
type StdNetEndpoint struct {
// AddrPort is the endpoint destination.
netip.AddrPort
// src is the current sticky source address and interface index, if supported.
src struct {
netip.Addr
ifidx int32
}
}
var ( var (
_ Bind = (*StdNetBind)(nil) _ Bind = (*StdNetBind)(nil)
_ Endpoint = StdNetEndpoint{} _ Endpoint = &StdNetEndpoint{}
) )
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
@ -39,31 +98,38 @@ func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
return asEndpoint(e), err return asEndpoint(e), err
} }
func (StdNetEndpoint) ClearSrc() {} func (e *StdNetEndpoint) ClearSrc() {
e.src.ifidx = 0
func (e StdNetEndpoint) DstIP() netip.Addr { e.src.Addr = netip.Addr{}
return (netip.AddrPort)(e).Addr()
} }
func (e StdNetEndpoint) SrcIP() netip.Addr { func (e *StdNetEndpoint) DstIP() netip.Addr {
return netip.Addr{} // not supported return e.AddrPort.Addr()
} }
func (e StdNetEndpoint) DstToBytes() []byte { func (e *StdNetEndpoint) SrcIP() netip.Addr {
b, _ := (netip.AddrPort)(e).MarshalBinary() return e.src.Addr
}
func (e *StdNetEndpoint) SrcIfidx() int32 {
return e.src.ifidx
}
func (e *StdNetEndpoint) DstToBytes() []byte {
b, _ := e.AddrPort.MarshalBinary()
return b return b
} }
func (e StdNetEndpoint) DstToString() string { func (e *StdNetEndpoint) DstToString() string {
return (netip.AddrPort)(e).String() return e.AddrPort.String()
} }
func (e StdNetEndpoint) SrcToString() string { func (e *StdNetEndpoint) SrcToString() string {
return "" return e.src.Addr.String()
} }
func listenNet(network string, port int) (*net.UDPConn, int, error) { func listenNet(network string, port int) (*net.UDPConn, int, error) {
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
@ -77,17 +143,17 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
return conn, uaddr.Port, nil return conn.(*net.UDPConn), uaddr.Port, nil
} }
func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
bind.mu.Lock() s.mu.Lock()
defer bind.mu.Unlock() defer s.mu.Unlock()
var err error var err error
var tries int var tries int
if bind.ipv4 != nil || bind.ipv6 != nil { if s.ipv4 != nil || s.ipv6 != nil {
return nil, 0, ErrBindAlreadyOpen return nil, 0, ErrBindAlreadyOpen
} }
@ -95,104 +161,121 @@ func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
// If uport is 0, we can retry on failure. // If uport is 0, we can retry on failure.
again: again:
port := int(uport) port := int(uport)
var ipv4, ipv6 *net.UDPConn var v4conn, v6conn *net.UDPConn
ipv4, port, err = listenNet("udp4", port) v4conn, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err return nil, 0, err
} }
// Listen on the same port as we're using for ipv4. // Listen on the same port as we're using for ipv4.
ipv6, port, err = listenNet("udp6", port) v6conn, port, err = listenNet("udp6", port)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
ipv4.Close() v4conn.Close()
tries++ tries++
goto again goto again
} }
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
ipv4.Close() v4conn.Close()
return nil, 0, err return nil, 0, err
} }
var fns []ReceiveFunc var fns []ReceiveFunc
if ipv4 != nil { if v4conn != nil {
fns = append(fns, bind.makeReceiveIPv4(ipv4)) fns = append(fns, s.receiveIPv4)
bind.ipv4 = ipv4 s.ipv4 = v4conn
} }
if ipv6 != nil { if v6conn != nil {
fns = append(fns, bind.makeReceiveIPv6(ipv6)) fns = append(fns, s.receiveIPv6)
bind.ipv6 = ipv6 s.ipv6 = v6conn
} }
if len(fns) == 0 { if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT return nil, 0, syscall.EAFNOSUPPORT
} }
s.ipv4PC = ipv4.NewPacketConn(s.ipv4)
s.ipv6PC = ipv6.NewPacketConn(s.ipv6)
return fns, uint16(port), nil return fns, uint16(port), nil
} }
func (bind *StdNetBind) BatchSize() int { func (s *StdNetBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
return 1 msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
defer s.ipv4MsgsPool.Put(msgs)
for i := range buffs {
(*msgs)[i].Buffers[0] = buffs[i]
}
numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := asEndpoint(addrPort)
getSrcFromControl(msg.OOB, ep)
eps[i] = ep
}
return numMsgs, nil
} }
func (bind *StdNetBind) Close() error { func (s *StdNetBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
bind.mu.Lock() msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
defer bind.mu.Unlock() defer s.ipv6MsgsPool.Put(msgs)
for i := range buffs {
(*msgs)[i].Buffers[0] = buffs[i]
}
numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := asEndpoint(addrPort)
getSrcFromControl(msg.OOB, ep)
eps[i] = ep
}
return numMsgs, nil
}
func (s *StdNetBind) BatchSize() int {
return s.batchSize
}
func (s *StdNetBind) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
var err1, err2 error var err1, err2 error
if bind.ipv4 != nil { if s.ipv4 != nil {
err1 = bind.ipv4.Close() err1 = s.ipv4.Close()
bind.ipv4 = nil s.ipv4 = nil
} }
if bind.ipv6 != nil { if s.ipv6 != nil {
err2 = bind.ipv6.Close() err2 = s.ipv6.Close()
bind.ipv6 = nil s.ipv6 = nil
} }
bind.blackhole4 = false s.blackhole4 = false
bind.blackhole6 = false s.blackhole6 = false
if err1 != nil { if err1 != nil {
return err1 return err1
} }
return err2 return err2
} }
func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc { func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { s.mu.Lock()
size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0]) blackhole := s.blackhole4
if err == nil { conn := s.ipv4
sizes[0] = size is6 := false
eps[0] = asEndpoint(endpoint) if endpoint.DstIP().Is6() {
return 1, nil blackhole = s.blackhole6
} conn = s.ipv6
return 0, err is6 = true
} }
} s.mu.Unlock()
func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0])
if err == nil {
sizes[0] = size
eps[0] = asEndpoint(endpoint)
return 1, nil
}
return 0, err
}
}
func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
var err error
nend, ok := endpoint.(StdNetEndpoint)
if !ok {
return ErrWrongEndpointType
}
addrPort := netip.AddrPort(nend)
bind.mu.Lock()
blackhole := bind.blackhole4
conn := bind.ipv4
if addrPort.Addr().Is6() {
blackhole = bind.blackhole6
conn = bind.ipv6
}
bind.mu.Unlock()
if blackhole { if blackhole {
return nil return nil
@ -200,13 +283,69 @@ func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
if conn == nil { if conn == nil {
return syscall.EAFNOSUPPORT return syscall.EAFNOSUPPORT
} }
for _, buff := range buffs { if is6 {
_, err = conn.WriteToUDPAddrPort(buff, addrPort) return s.send6(s.ipv6PC, endpoint, buffs)
if err != nil { } else {
return err return s.send4(s.ipv4PC, endpoint, buffs)
}
} }
return nil }
func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr)
as4 := ep.DstIP().As4()
copy(ua.IP, as4[:])
ua.IP = ua.IP[:4]
ua.Port = int(ep.(*StdNetEndpoint).Port())
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
for i, buff := range buffs {
(*msgs)[i].Buffers[0] = buff
(*msgs)[i].Addr = ua
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
}
var (
n int
err error
start int
)
for {
n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
if err != nil || n == len((*msgs)[start:len(buffs)]) {
break
}
start += n
}
s.udpAddrPool.Put(ua)
s.ipv4MsgsPool.Put(msgs)
return err
}
func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr)
as16 := ep.DstIP().As16()
copy(ua.IP, as16[:])
ua.IP = ua.IP[:16]
ua.Port = int(ep.(*StdNetEndpoint).Port())
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
for i, buff := range buffs {
(*msgs)[i].Buffers[0] = buff
(*msgs)[i].Addr = ua
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
}
var (
n int
err error
start int
)
for {
n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
if err != nil || n == len((*msgs)[start:len(buffs)]) {
break
}
start += n
}
s.udpAddrPool.Put(ua)
s.ipv6MsgsPool.Put(msgs)
return err
} }
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint. // endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
@ -214,17 +353,17 @@ func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
// but Endpoints are immutable, so we can re-use them. // but Endpoints are immutable, so we can re-use them.
var endpointPool = sync.Pool{ var endpointPool = sync.Pool{
New: func() any { New: func() any {
return make(map[netip.AddrPort]Endpoint) return make(map[netip.AddrPort]*StdNetEndpoint)
}, },
} }
// asEndpoint returns an Endpoint containing ap. // asEndpoint returns an Endpoint containing ap.
func asEndpoint(ap netip.AddrPort) Endpoint { func asEndpoint(ap netip.AddrPort) *StdNetEndpoint {
m := endpointPool.Get().(map[netip.AddrPort]Endpoint) m := endpointPool.Get().(map[netip.AddrPort]*StdNetEndpoint)
defer endpointPool.Put(m) defer endpointPool.Put(m)
e, ok := m[ap] e, ok := m[ap]
if !ok { if !ok {
e = Endpoint(StdNetEndpoint(ap)) e = &StdNetEndpoint{AddrPort: ap}
m[ap] = e m[ap] = e
} }
return e return e

View File

@ -5,8 +5,8 @@
package conn package conn
func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) { func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
sysconn, err := bind.ipv4.SyscallConn() sysconn, err := s.ipv4.SyscallConn()
if err != nil { if err != nil {
return -1, err return -1, err
} }
@ -19,8 +19,8 @@ func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
return return
} }
func (bind *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) { func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
sysconn, err := bind.ipv6.SyscallConn() sysconn, err := s.ipv6.SyscallConn()
if err != nil { if err != nil {
return -1, err return -1, err
} }

View File

@ -16,7 +16,7 @@ import (
) )
const ( const (
DefaultBatchSize = 1 // maximum number of packets handled per read and write DefaultBatchSize = 128 // maximum number of packets handled per read and write
) )
// A ReceiveFunc receives at least one packet from the network and writes them // A ReceiveFunc receives at least one packet from the network and writes them

36
conn/controlfns.go Normal file
View File

@ -0,0 +1,36 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"net"
"syscall"
)
// controlFn is the callback function signature from net.ListenConfig.Control.
// It is used to apply platform specific configuration to the socket prior to
// bind.
type controlFn func(network, address string, c syscall.RawConn) error
// controlFns is a list of functions that are called from the listen config
// that can apply socket options.
var controlFns = []controlFn{}
// listenConfig returns a net.ListenConfig that applies the controlFns to the
// socket prior to bind. This is used to apply socket buffer sizing and packet
// information OOB configuration for sticky sockets.
func listenConfig() *net.ListenConfig {
return &net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
for _, fn := range controlFns {
if err := fn(network, address, c); err != nil {
return err
}
}
return nil
},
}
}

41
conn/controlfns_linux.go Normal file
View File

@ -0,0 +1,41 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"fmt"
"syscall"
"golang.org/x/sys/unix"
)
func init() {
controlFns = append(controlFns,
// Enable receiving of the packet information (IP_PKTINFO for IPv4,
// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
func(network, address string, c syscall.RawConn) error {
var err error
switch network {
case "udp4":
c.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
})
case "udp6":
c.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
if err != nil {
return
}
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
})
default:
err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
}
return err
},
)
}

28
conn/controlfns_unix.go Normal file
View File

@ -0,0 +1,28 @@
//go:build !windows && !linux && !js
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"syscall"
"golang.org/x/sys/unix"
)
func init() {
controlFns = append(controlFns,
func(network, address string, c syscall.RawConn) error {
var err error
if network == "udp6" {
c.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
})
}
return err
},
)
}

View File

@ -1,4 +1,4 @@
//go:build !linux && !windows //go:build !windows
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *

View File

@ -7,6 +7,6 @@
package conn package conn
func (bind *StdNetBind) SetMark(mark uint32) error { func (s *StdNetBind) SetMark(mark uint32) error {
return nil return nil
} }

View File

@ -26,13 +26,13 @@ func init() {
} }
} }
func (bind *StdNetBind) SetMark(mark uint32) error { func (s *StdNetBind) SetMark(mark uint32) error {
var operr error var operr error
if fwmarkIoctl == 0 { if fwmarkIoctl == 0 {
return nil return nil
} }
if bind.ipv4 != nil { if s.ipv4 != nil {
fd, err := bind.ipv4.SyscallConn() fd, err := s.ipv4.SyscallConn()
if err != nil { if err != nil {
return err return err
} }
@ -46,8 +46,8 @@ func (bind *StdNetBind) SetMark(mark uint32) error {
return err return err
} }
} }
if bind.ipv6 != nil { if s.ipv6 != nil {
fd, err := bind.ipv6.SyscallConn() fd, err := s.ipv6.SyscallConn()
if err != nil { if err != nil {
return err return err
} }

26
conn/sticky_default.go Normal file
View File

@ -0,0 +1,26 @@
//go:build !linux
// +build !linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but
// use alternatively named flags and need ports and require testing.
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
}
// setSrcControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
}
// srcControlSize returns the recommended buffer size for pooling sticky control
// data.
const srcControlSize = 0

111
conn/sticky_linux.go Normal file
View File

@ -0,0 +1,111 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"net/netip"
"unsafe"
"golang.org/x/sys/unix"
)
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
ep.ClearSrc()
var (
hdr unix.Cmsghdr
data []byte
rem []byte = control
err error
)
for len(rem) > unix.SizeofCmsghdr {
hdr, data, rem, err = unix.ParseOneSocketControlMessage(control)
if err != nil {
return
}
if hdr.Level == unix.IPPROTO_IP &&
hdr.Type == unix.IP_PKTINFO {
info := pktInfoFromBuf[unix.Inet4Pktinfo](data)
ep.src.Addr = netip.AddrFrom4(info.Spec_dst)
ep.src.ifidx = info.Ifindex
return
}
if hdr.Level == unix.IPPROTO_IPV6 &&
hdr.Type == unix.IPV6_PKTINFO {
info := pktInfoFromBuf[unix.Inet6Pktinfo](data)
ep.src.Addr = netip.AddrFrom16(info.Addr)
ep.src.ifidx = int32(info.Ifindex)
return
}
}
}
// pktInfoFromBuf returns type T populated from the provided buf via copy(). It
// panics if buf is of insufficient size.
func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) {
size := int(unsafe.Sizeof(t))
if len(buf) < size {
panic("pktInfoFromBuf: buffer too small")
}
copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf)
return t
}
// setSrcControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
*control = (*control)[:cap(*control)]
if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
*control = (*control)[:0]
return
}
if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() {
*control = (*control)[:0]
return
}
if len(*control) < srcControlSize {
*control = (*control)[:0]
return
}
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
if ep.SrcIP().Is4() {
hdr.Level = unix.IPPROTO_IP
hdr.Type = unix.IP_PKTINFO
hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
info.Ifindex = ep.src.ifidx
if ep.SrcIP().IsValid() {
info.Spec_dst = ep.SrcIP().As4()
}
} else {
hdr.Level = unix.IPPROTO_IPV6
hdr.Type = unix.IPV6_PKTINFO
hdr.Len = unix.SizeofCmsghdr + unix.SizeofInet6Pktinfo
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
info.Ifindex = uint32(ep.src.ifidx)
if ep.SrcIP().IsValid() {
info.Addr = ep.SrcIP().As16()
}
}
*control = (*control)[:hdr.Len]
}
var srcControlSize = unix.CmsgLen(unix.SizeofInet6Pktinfo)

207
conn/sticky_linux_test.go Normal file
View File

@ -0,0 +1,207 @@
//go:build linux
// +build linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"context"
"net"
"net/netip"
"runtime"
"testing"
"unsafe"
"golang.org/x/sys/unix"
)
func Test_setSrcControl(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
ep := &StdNetEndpoint{
AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
}
ep.src.Addr = netip.MustParseAddr("127.0.0.1")
ep.src.ifidx = 5
control := make([]byte, srcControlSize)
setSrcControl(&control, ep)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
if hdr.Level != unix.IPPROTO_IP {
t.Errorf("unexpected level: %d", hdr.Level)
}
if hdr.Type != unix.IP_PKTINFO {
t.Errorf("unexpected type: %d", hdr.Type)
}
if hdr.Len != uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
t.Errorf("unexpected length: %d", hdr.Len)
}
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
t.Errorf("unexpected address: %v", info.Spec_dst)
}
if info.Ifindex != 5 {
t.Errorf("unexpected ifindex: %d", info.Ifindex)
}
})
t.Run("IPv6", func(t *testing.T) {
ep := &StdNetEndpoint{
AddrPort: netip.MustParseAddrPort("[::1]:1234"),
}
ep.src.Addr = netip.MustParseAddr("::1")
ep.src.ifidx = 5
control := make([]byte, srcControlSize)
setSrcControl(&control, ep)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
if hdr.Level != unix.IPPROTO_IPV6 {
t.Errorf("unexpected level: %d", hdr.Level)
}
if hdr.Type != unix.IPV6_PKTINFO {
t.Errorf("unexpected type: %d", hdr.Type)
}
if hdr.Len != uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
t.Errorf("unexpected length: %d", hdr.Len)
}
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
if info.Addr != ep.SrcIP().As16() {
t.Errorf("unexpected address: %v", info.Addr)
}
if info.Ifindex != 5 {
t.Errorf("unexpected ifindex: %d", info.Ifindex)
}
})
t.Run("ClearOnNoSrc", func(t *testing.T) {
control := make([]byte, srcControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = 1
hdr.Type = 2
hdr.Len = 3
setSrcControl(&control, &StdNetEndpoint{})
if len(control) != 0 {
t.Errorf("unexpected control: %v", control)
}
})
}
func Test_getSrcFromControl(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
control := make([]byte, srcControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IP
hdr.Type = unix.IP_PKTINFO
hdr.Len = uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
info.Spec_dst = [4]byte{127, 0, 0, 1}
info.Ifindex = 5
ep := &StdNetEndpoint{}
getSrcFromControl(control, ep)
if ep.src.Addr != netip.MustParseAddr("127.0.0.1") {
t.Errorf("unexpected address: %v", ep.src.Addr)
}
if ep.src.ifidx != 5 {
t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
}
})
t.Run("IPv6", func(t *testing.T) {
control := make([]byte, srcControlSize)
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IPV6
hdr.Type = unix.IPV6_PKTINFO
hdr.Len = uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
info.Ifindex = 5
ep := &StdNetEndpoint{}
getSrcFromControl(control, ep)
if ep.SrcIP() != netip.MustParseAddr("::1") {
t.Errorf("unexpected address: %v", ep.SrcIP())
}
if ep.src.ifidx != 5 {
t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
}
})
t.Run("ClearOnEmpty", func(t *testing.T) {
control := make([]byte, srcControlSize)
ep := &StdNetEndpoint{}
ep.src.Addr = netip.MustParseAddr("::1")
ep.src.ifidx = 5
getSrcFromControl(control, ep)
if ep.SrcIP().IsValid() {
t.Errorf("unexpected address: %v", ep.src.Addr)
}
if ep.src.ifidx != 0 {
t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
}
})
}
func Test_listenConfig(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
if err != nil {
t.Fatal(err)
}
defer conn.Close()
sc, err := conn.(*net.UDPConn).SyscallConn()
if err != nil {
t.Fatal(err)
}
if runtime.GOOS == "linux" {
var i int
sc.Control(func(fd uintptr) {
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
})
if err != nil {
t.Fatal(err)
}
if i != 1 {
t.Error("IP_PKTINFO not set!")
}
} else {
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
}
})
t.Run("IPv6", func(t *testing.T) {
conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
if err != nil {
t.Fatal(err)
}
sc, err := conn.(*net.UDPConn).SyscallConn()
if err != nil {
t.Fatal(err)
}
if runtime.GOOS == "linux" {
var i int
sc.Control(func(fd uintptr) {
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
})
if err != nil {
t.Fatal(err)
}
if i != 1 {
t.Error("IPV6_PKTINFO not set!")
}
} else {
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
}
})
}

View File

@ -5,10 +5,12 @@
package device package device
import "golang.zx2c4.com/wireguard/conn"
/* Reduce memory consumption for Android */ /* Reduce memory consumption for Android */
const ( const (
QueueStagedSize = 128 QueueStagedSize = conn.DefaultBatchSize
QueueOutboundSize = 1024 QueueOutboundSize = 1024
QueueInboundSize = 1024 QueueInboundSize = 1024
QueueHandshakeSize = 1024 QueueHandshakeSize = 1024

View File

@ -7,8 +7,10 @@
package device package device
import "golang.zx2c4.com/wireguard/conn"
const ( const (
QueueStagedSize = 128 QueueStagedSize = conn.DefaultBatchSize
QueueOutboundSize = 1024 QueueOutboundSize = 1024
QueueInboundSize = 1024 QueueInboundSize = 1024
QueueHandshakeSize = 1024 QueueHandshakeSize = 1024

View File

@ -25,7 +25,7 @@ import (
) )
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
if _, ok := bind.(*conn.LinuxSocketBind); !ok { if _, ok := bind.(*conn.StdNetBind); !ok {
return nil, nil return nil, nil
} }
@ -112,11 +112,11 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
pePtr.peer.Unlock() pePtr.peer.Unlock()
break break
} }
if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx { if uint32(pePtr.peer.endpoint.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
pePtr.peer.Unlock() pePtr.peer.Unlock()
break break
} }
pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc() pePtr.peer.endpoint.(*conn.StdNetEndpoint).ClearSrc()
pePtr.peer.Unlock() pePtr.peer.Unlock()
} }
attr = attr[attrhdr.Len:] attr = attr[attrhdr.Len:]
@ -136,12 +136,12 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
peer.RUnlock() peer.RUnlock()
continue continue
} }
nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint) nativeEP, _ := peer.endpoint.(*conn.StdNetEndpoint)
if nativeEP == nil { if nativeEP == nil {
peer.RUnlock() peer.RUnlock()
continue continue
} }
if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 { if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
peer.RUnlock() peer.RUnlock()
break break
} }
@ -169,12 +169,12 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
Len: 8, Len: 8,
Type: unix.RTA_DST, Type: unix.RTA_DST,
}, },
nativeEP.Dst4().Addr, nativeEP.DstIP().As4(),
unix.RtAttr{ unix.RtAttr{
Len: 8, Len: 8,
Type: unix.RTA_SRC, Type: unix.RTA_SRC,
}, },
nativeEP.Src4().Src, nativeEP.SrcIP().As4(),
unix.RtAttr{ unix.RtAttr{
Len: 8, Len: 8,
Type: unix.RTA_MARK, Type: unix.RTA_MARK,

6
go.mod
View File

@ -3,9 +3,9 @@ module golang.zx2c4.com/wireguard
go 1.19 go 1.19
require ( require (
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd golang.org/x/crypto v0.3.0
golang.org/x/net v0.0.0-20220225172249-27dd8689420f golang.org/x/net v0.2.0
golang.org/x/sys v0.2.0 golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224
gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0
) )

12
go.sum
View File

@ -1,11 +1,11 @@
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd h1:XcWmESyNjXJMLahc3mqVQJcgSTDxFxhETVlfk9uGc38= golang.org/x/crypto v0.3.0 h1:a06MkbcxBrEFc0w0QIZWXrH/9cCX6KJyWbBOIwAn+7A=
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4=
golang.org/x/net v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc= golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU=
golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 h1:260HNjMTPDya+jq5AM1zZLgG9pv9GASPAGiEEJUbRg4=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY= golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=

42
tun/checksum.go Normal file
View File

@ -0,0 +1,42 @@
package tun
import "encoding/binary"
// TODO: Explore SIMD and/or other assembly optimizations.
func checksumNoFold(b []byte, initial uint64) uint64 {
ac := initial
i := 0
n := len(b)
for n >= 4 {
ac += uint64(binary.BigEndian.Uint32(b[i : i+4]))
n -= 4
i += 4
}
for n >= 2 {
ac += uint64(binary.BigEndian.Uint16(b[i : i+2]))
n -= 2
i += 2
}
if n == 1 {
ac += uint64(b[i]) << 8
}
return ac
}
func checksum(b []byte, initial uint64) uint16 {
ac := checksumNoFold(b, initial)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
return uint16(ac)
}
func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
sum := checksumNoFold(srcAddr, 0)
sum = checksumNoFold(dstAddr, sum)
sum = checksumNoFold([]byte{0, protocol}, sum)
tmp := make([]byte, 2)
binary.BigEndian.PutUint16(tmp, totalLen)
return checksumNoFold(tmp, sum)
}

612
tun/tcp_offload_linux.go Normal file
View File

@ -0,0 +1,612 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
"bytes"
"encoding/binary"
"errors"
"io"
"unsafe"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn"
)
const tcpFlagsOffset = 13
const (
tcpFlagFIN uint8 = 0x01
tcpFlagPSH uint8 = 0x08
tcpFlagACK uint8 = 0x10
)
// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
// kernel symbol is virtio_net_hdr.
type virtioNetHdr struct {
flags uint8
gsoType uint8
hdrLen uint16
gsoSize uint16
csumStart uint16
csumOffset uint16
}
func (v *virtioNetHdr) decode(b []byte) error {
if len(b) < virtioNetHdrLen {
return io.ErrShortBuffer
}
copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
return nil
}
func (v *virtioNetHdr) encode(b []byte) error {
if len(b) < virtioNetHdrLen {
return io.ErrShortBuffer
}
copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
return nil
}
const (
// virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
// shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
)
// flowKey represents the key for a flow.
type flowKey struct {
srcAddr, dstAddr [16]byte
srcPort, dstPort uint16
rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
}
// tcpGROTable holds flow and coalescing information for the purposes of GRO.
type tcpGROTable struct {
itemsByFlow map[flowKey][]tcpGROItem
itemsPool [][]tcpGROItem
}
func newTCPGROTable() *tcpGROTable {
t := &tcpGROTable{
itemsByFlow: make(map[flowKey][]tcpGROItem, conn.DefaultBatchSize),
itemsPool: make([][]tcpGROItem, conn.DefaultBatchSize),
}
for i := range t.itemsPool {
t.itemsPool[i] = make([]tcpGROItem, 0, conn.DefaultBatchSize)
}
return t
}
func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
key := flowKey{}
addrSize := dstAddr - srcAddr
copy(key.srcAddr[:], pkt[srcAddr:dstAddr])
copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize])
key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
return key
}
// lookupOrInsert looks up a flow for the provided packet and metadata,
// returning the packets found for the flow, or inserting a new one if none
// is found.
func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, buffsIndex int) ([]tcpGROItem, bool) {
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
items, ok := t.itemsByFlow[key]
if ok {
return items, ok
}
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, buffsIndex)
return nil, false
}
// insert an item in the table for the provided packet and packet metadata.
func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, buffsIndex int) {
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
item := tcpGROItem{
key: key,
buffsIndex: uint16(buffsIndex),
gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
iphLen: uint8(tcphOffset),
tcphLen: uint8(tcphLen),
sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
}
items, ok := t.itemsByFlow[key]
if !ok {
items = t.newItems()
}
items = append(items, item)
t.itemsByFlow[key] = items
}
func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
items, _ := t.itemsByFlow[item.key]
items[i] = item
}
func (t *tcpGROTable) deleteAt(key flowKey, i int) {
items, _ := t.itemsByFlow[key]
items = append(items[:i], items[i+1:]...)
t.itemsByFlow[key] = items
}
// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
// of a GRO evaluation across a vector of packets.
type tcpGROItem struct {
key flowKey
sentSeq uint32 // the sequence number
buffsIndex uint16 // the index into the original buffs slice
numMerged uint16 // the number of packets merged into this item
gsoSize uint16 // payload size
iphLen uint8 // ip header len
tcphLen uint8 // tcp header len
pshSet bool // psh flag is set
}
func (t *tcpGROTable) newItems() []tcpGROItem {
var items []tcpGROItem
items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
return items
}
func (t *tcpGROTable) reset() {
for k, items := range t.itemsByFlow {
items = items[:0]
t.itemsPool = append(t.itemsPool, items)
delete(t.itemsByFlow, k)
}
}
// canCoalesce represents the outcome of checking if two TCP packets are
// candidates for coalescing.
type canCoalesce int
const (
coalescePrepend canCoalesce = -1
coalesceUnavailable canCoalesce = 0
coalesceAppend canCoalesce = 1
)
// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
// described by item. This function makes considerations that match the kernel's
// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, buffs [][]byte, buffsOffset int) canCoalesce {
pktTarget := buffs[item.buffsIndex][buffsOffset:]
if tcphLen != item.tcphLen {
// cannot coalesce with unequal tcp options len
return coalesceUnavailable
}
if tcphLen > 20 {
if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
// cannot coalesce with unequal tcp options
return coalesceUnavailable
}
}
if pkt[1] != pktTarget[1] {
// cannot coalesce with unequal ToS values
return coalesceUnavailable
}
if pkt[6]>>5 != pktTarget[6]>>5 {
// cannot coalesce with unequal DF or reserved bits. MF is checked
// further up the stack.
return coalesceUnavailable
}
// seq adjacency
lhsLen := item.gsoSize
lhsLen += item.numMerged * item.gsoSize
if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
if item.pshSet {
// We cannot append to a segment that has the PSH flag set, PSH
// can only be set on the final segment in a reassembled group.
return coalesceUnavailable
}
if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
// A smaller than gsoSize packet has been appended previously.
// Nothing can come after a smaller packet on the end.
return coalesceUnavailable
}
if gsoSize > item.gsoSize {
// We cannot have a larger packet following a smaller one.
return coalesceUnavailable
}
return coalesceAppend
} else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
if pshSet {
// We cannot prepend with a segment that has the PSH flag set, PSH
// can only be set on the final segment in a reassembled group.
return coalesceUnavailable
}
if gsoSize < item.gsoSize {
// We cannot have a larger packet following a smaller one.
return coalesceUnavailable
}
if gsoSize > item.gsoSize && item.numMerged > 0 {
// There's at least one previous merge, and we're larger than all
// previous. This would put multiple smaller packets on the end.
return coalesceUnavailable
}
return coalescePrepend
}
return coalesceUnavailable
}
func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
srcAddrAt := ipv4SrcAddrOffset
addrSize := 4
if isV6 {
srcAddrAt = ipv6SrcAddrOffset
addrSize = 16
}
tcpTotalLen := uint16(len(pkt) - int(iphLen))
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen)
return ^checksum(pkt[iphLen:], tcpCSumNoFold) == 0
}
// coalesceResult represents the result of attempting to coalesce two TCP
// packets.
type coalesceResult int
const (
coalesceInsufficientCap coalesceResult = 0
coalescePSHEnding coalesceResult = 1
coalesceItemInvalidCSum coalesceResult = 2
coalescePktInvalidCSum coalesceResult = 3
coalesceSuccess coalesceResult = 4
)
// coalesceTCPPackets attempts to coalesce pkt with the packet described by
// item, returning the outcome. This function may swap buffs elements in the
// event of a prepend as item's buffs index is already being tracked for writing
// to a Device.
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, buffs [][]byte, buffsOffset int, isV6 bool) coalesceResult {
var pktHead []byte // the packet that will end up at the front
headersLen := item.iphLen + item.tcphLen
coalescedLen := len(buffs[item.buffsIndex][buffsOffset:]) + len(pkt) - int(headersLen)
// Copy data
if mode == coalescePrepend {
pktHead = pkt
if cap(pkt)-buffsOffset < coalescedLen {
// We don't want to allocate a new underlying array if capacity is
// too small.
return coalesceInsufficientCap
}
if pshSet {
return coalescePSHEnding
}
if item.numMerged == 0 {
if !tcpChecksumValid(buffs[item.buffsIndex][buffsOffset:], item.iphLen, isV6) {
return coalesceItemInvalidCSum
}
}
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
return coalescePktInvalidCSum
}
item.sentSeq = seq
extendBy := coalescedLen - len(pktHead)
buffs[pktBuffsIndex] = append(buffs[pktBuffsIndex], make([]byte, extendBy)...)
copy(buffs[pktBuffsIndex][buffsOffset+len(pkt):], buffs[item.buffsIndex][buffsOffset+int(headersLen):])
// Flip the slice headers in buffs as part of prepend. The index of item
// is already being tracked for writing.
buffs[item.buffsIndex], buffs[pktBuffsIndex] = buffs[pktBuffsIndex], buffs[item.buffsIndex]
} else {
pktHead = buffs[item.buffsIndex][buffsOffset:]
if cap(pktHead)-buffsOffset < coalescedLen {
// We don't want to allocate a new underlying array if capacity is
// too small.
return coalesceInsufficientCap
}
if item.numMerged == 0 {
if !tcpChecksumValid(buffs[item.buffsIndex][buffsOffset:], item.iphLen, isV6) {
return coalesceItemInvalidCSum
}
}
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
return coalescePktInvalidCSum
}
if pshSet {
// We are appending a segment with PSH set.
item.pshSet = pshSet
pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
}
extendBy := len(pkt) - int(headersLen)
buffs[item.buffsIndex] = append(buffs[item.buffsIndex], make([]byte, extendBy)...)
copy(buffs[item.buffsIndex][buffsOffset+len(pktHead):], pkt[headersLen:])
}
if gsoSize > item.gsoSize {
item.gsoSize = gsoSize
}
hdr := virtioNetHdr{
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
hdrLen: uint16(headersLen),
gsoSize: uint16(item.gsoSize),
csumStart: uint16(item.iphLen),
csumOffset: 16,
}
// Recalculate the total len (IPv4) or payload len (IPv6). Recalculate the
// (IPv4) header checksum.
if isV6 {
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
binary.BigEndian.PutUint16(pktHead[4:], uint16(coalescedLen)-uint16(item.iphLen)) // set new payload len
} else {
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
pktHead[10], pktHead[11] = 0, 0 // clear checksum field
binary.BigEndian.PutUint16(pktHead[2:], uint16(coalescedLen)) // set new total length
iphCSum := ^checksum(pktHead[:item.iphLen], 0) // compute checksum
binary.BigEndian.PutUint16(pktHead[10:], iphCSum) // set checksum field
}
hdr.encode(buffs[item.buffsIndex][buffsOffset-virtioNetHdrLen:])
// Calculate the pseudo header checksum and place it at the TCP checksum
// offset. Downstream checksum offloading will combine this with computation
// of the tcp header and payload checksum.
addrLen := 4
addrOffset := ipv4SrcAddrOffset
if isV6 {
addrLen = 16
addrOffset = ipv6SrcAddrOffset
}
srcAddrAt := buffsOffset + addrOffset
srcAddr := buffs[item.buffsIndex][srcAddrAt : srcAddrAt+addrLen]
dstAddr := buffs[item.buffsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(coalescedLen-int(item.iphLen)))
binary.BigEndian.PutUint16(pktHead[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
item.numMerged++
return coalesceSuccess
}
const (
ipv4FlagMoreFragments = 0x80
)
const (
ipv4SrcAddrOffset = 12
ipv6SrcAddrOffset = 8
maxUint16 = 1<<16 - 1
)
// tcpGRO evaluates the TCP packet at pktI in buffs for coalescing with
// existing packets tracked in table. It will return false when pktI is not
// coalesced, otherwise true. This indicates to the caller if buffs[pktI]
// should be written to the Device.
func tcpGRO(buffs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) (pktCoalesced bool) {
pkt := buffs[pktI][offset:]
if len(pkt) > maxUint16 {
// A valid IPv4 or IPv6 packet will never exceed this.
return false
}
iphLen := int((pkt[0] & 0x0F) * 4)
if isV6 {
iphLen = 40
ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
if ipv6HPayloadLen != len(pkt)-iphLen {
return false
}
} else {
totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
if totalLen != len(pkt) {
return false
}
if iphLen < 20 || iphLen > 60 {
return false
}
}
if len(pkt) < iphLen {
return false
}
tcphLen := int((pkt[iphLen+12] >> 4) * 4)
if tcphLen < 20 || tcphLen > 60 {
return false
}
if len(pkt) < iphLen+tcphLen {
return false
}
if !isV6 {
if pkt[6]&ipv4FlagMoreFragments != 0 || (pkt[6]<<3 != 0 || pkt[7] != 0) {
// no GRO support for fragmented segments for now
return false
}
}
tcpFlags := pkt[iphLen+tcpFlagsOffset]
var pshSet bool
// not a candidate if any non-ACK flags (except PSH+ACK) are set
if tcpFlags != tcpFlagACK {
if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
return false
}
pshSet = true
}
gsoSize := uint16(len(pkt) - tcphLen - iphLen)
// not a candidate if payload len is 0
if gsoSize < 1 {
return false
}
seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
srcAddrOffset := ipv4SrcAddrOffset
addrLen := 4
if isV6 {
srcAddrOffset = ipv6SrcAddrOffset
addrLen = 16
}
items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
if !existing {
return false
}
for i := len(items) - 1; i >= 0; i-- {
// In the best case of packets arriving in order iterating in reverse is
// more efficient if there are multiple items for a given flow. This
// also enables a natural table.deleteAt() in the
// coalesceItemInvalidCSum case without the need for index tracking.
// This algorithm makes a best effort to coalesce in the event of
// unordered packets, where pkt may land anywhere in items from a
// sequence number perspective, however once an item is inserted into
// the table it is never compared across other items later.
item := items[i]
can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, buffs, offset)
if can != coalesceUnavailable {
result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, buffs, offset, isV6)
switch result {
case coalesceSuccess:
table.updateAt(item, i)
return true
case coalesceItemInvalidCSum:
// delete the item with an invalid csum
table.deleteAt(item.key, i)
case coalescePktInvalidCSum:
// no point in inserting an item that we can't coalesce
return false
default:
}
}
}
// failed to coalesce with any other packets; store the item in the flow
table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
return false
}
func isTCP4(b []byte) bool {
if len(b) < 40 {
return false
}
if b[0]>>4 != 4 {
return false
}
if b[9] != unix.IPPROTO_TCP {
return false
}
return true
}
func isTCP6NoEH(b []byte) bool {
if len(b) < 60 {
return false
}
if b[0]>>4 != 6 {
return false
}
if b[6] != unix.IPPROTO_TCP {
return false
}
return true
}
// handleGRO evaluates buffs for GRO, and writes the indices of the resulting
// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
// empty (but non-nil), and are passed in to save allocs as the caller may reset
// and recycle them across vectors of packets.
func handleGRO(buffs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error {
for i := range buffs {
if offset < virtioNetHdrLen || offset > len(buffs[i])-1 {
return errors.New("invalid offset")
}
var coalesced bool
switch {
case isTCP4(buffs[i][offset:]):
coalesced = tcpGRO(buffs, offset, i, tcp4Table, false)
case isTCP6NoEH(buffs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
coalesced = tcpGRO(buffs, offset, i, tcp6Table, true)
}
if !coalesced {
hdr := virtioNetHdr{}
err := hdr.encode(buffs[i][offset-virtioNetHdrLen:])
if err != nil {
return err
}
*toWrite = append(*toWrite, i)
}
}
return nil
}
// tcpTSO splits packets from in into outBuffs, writing the size of each
// element into sizes. It returns the number of buffers populated, and/or an
// error.
func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) {
iphLen := int(hdr.csumStart)
srcAddrOffset := ipv6SrcAddrOffset
addrLen := 16
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
in[10], in[11] = 0, 0 // clear ipv4 header checksum
srcAddrOffset = ipv4SrcAddrOffset
addrLen = 4
}
tcpCSumAt := int(hdr.csumStart + hdr.csumOffset)
in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum
firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:])
nextSegmentDataAt := int(hdr.hdrLen)
i := 0
for ; nextSegmentDataAt < len(in); i++ {
if i == len(outBuffs) {
return i - 1, ErrTooManySegments
}
nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
if nextSegmentEnd > len(in) {
nextSegmentEnd = len(in)
}
segmentDataLen := nextSegmentEnd - nextSegmentDataAt
totalLen := int(hdr.hdrLen) + segmentDataLen
sizes[i] = totalLen
out := outBuffs[i][outOffset:]
copy(out, in[:iphLen])
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
// For IPv4 we are responsible for incrementing the ID field,
// updating the total len field, and recalculating the header
// checksum.
if i > 0 {
id := binary.BigEndian.Uint16(out[4:])
id += uint16(i)
binary.BigEndian.PutUint16(out[4:], id)
}
binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
ipv4CSum := ^checksum(out[:iphLen], 0)
binary.BigEndian.PutUint16(out[10:], ipv4CSum)
} else {
// For IPv6 we are responsible for updating the payload length field.
binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
}
// TCP header
copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
if nextSegmentEnd != len(in) {
// FIN and PSH should only be set on last segment
clearFlags := tcpFlagFIN | tcpFlagPSH
out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
}
// payload
copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
// TCP checksum
tcpHLen := int(hdr.hdrLen - hdr.csumStart)
tcpLenForPseudo := uint16(tcpHLen + segmentDataLen)
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo)
tcpCSum := ^checksum(out[hdr.csumStart:totalLen], tcpCSumNoFold)
binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum)
nextSegmentDataAt += int(hdr.gsoSize)
}
return i, nil
}
func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
cSumAt := cSumStart + cSumOffset
// The initial value at the checksum offset should be summed with the
// checksum we compute. This is typically the pseudo-header checksum.
initial := binary.BigEndian.Uint16(in[cSumAt:])
in[cSumAt], in[cSumAt+1] = 0, 0
binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial)))
return nil
}

View File

@ -0,0 +1,273 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
"net/netip"
"testing"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
const (
offset = virtioNetHdrLen
)
var (
ip4PortA = netip.MustParseAddrPort("192.0.2.1:1")
ip4PortB = netip.MustParseAddrPort("192.0.2.2:1")
ip4PortC = netip.MustParseAddrPort("192.0.2.3:1")
ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1")
ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1")
ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
)
func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
totalLen := 40 + segmentSize
b := make([]byte, offset+int(totalLen), 65535)
ipv4H := header.IPv4(b[offset:])
srcAs4 := srcIPPort.Addr().As4()
dstAs4 := dstIPPort.Addr().As4()
ipv4H.Encode(&header.IPv4Fields{
SrcAddr: tcpip.Address(srcAs4[:]),
DstAddr: tcpip.Address(dstAs4[:]),
Protocol: unix.IPPROTO_TCP,
TTL: 64,
TotalLength: uint16(totalLen),
})
tcpH := header.TCP(b[offset+20:])
tcpH.Encode(&header.TCPFields{
SrcPort: srcIPPort.Port(),
DstPort: dstIPPort.Port(),
SeqNum: seq,
AckNum: 1,
DataOffset: 20,
Flags: flags,
WindowSize: 3000,
})
ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize))
tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
return b
}
func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
totalLen := 60 + segmentSize
b := make([]byte, offset+int(totalLen), 65535)
ipv6H := header.IPv6(b[offset:])
srcAs16 := srcIPPort.Addr().As16()
dstAs16 := dstIPPort.Addr().As16()
ipv6H.Encode(&header.IPv6Fields{
SrcAddr: tcpip.Address(srcAs16[:]),
DstAddr: tcpip.Address(dstAs16[:]),
TransportProtocol: unix.IPPROTO_TCP,
HopLimit: 64,
PayloadLength: uint16(segmentSize + 20),
})
tcpH := header.TCP(b[offset+40:])
tcpH.Encode(&header.TCPFields{
SrcPort: srcIPPort.Port(),
DstPort: dstIPPort.Port(),
SeqNum: seq,
AckNum: 1,
DataOffset: 20,
Flags: flags,
WindowSize: 3000,
})
pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize))
tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
return b
}
func Test_handleVirtioRead(t *testing.T) {
tests := []struct {
name string
hdr virtioNetHdr
pktIn []byte
wantLens []int
wantErr bool
}{
{
"tcp4",
virtioNetHdr{
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
gsoSize: 100,
hdrLen: 40,
csumStart: 20,
csumOffset: 16,
},
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
[]int{140, 140},
false,
},
{
"tcp6",
virtioNetHdr{
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
gsoSize: 100,
hdrLen: 60,
csumStart: 40,
csumOffset: 16,
},
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
[]int{160, 160},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
out := make([][]byte, conn.DefaultBatchSize)
sizes := make([]int, conn.DefaultBatchSize)
for i := range out {
out[i] = make([]byte, 65535)
}
tt.hdr.encode(tt.pktIn)
n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
if err != nil {
if tt.wantErr {
return
}
t.Fatalf("got err: %v", err)
}
if n != len(tt.wantLens) {
t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
}
for i := range tt.wantLens {
if tt.wantLens[i] != sizes[i] {
t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
}
}
})
}
}
func flipTCP4Checksum(b []byte) []byte {
at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16
b[at] ^= 0xFF
b[at+1] ^= 0xFF
return b
}
func Fuzz_handleGRO(f *testing.F) {
pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)
pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101)
pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201)
pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)
pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101)
pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201)
f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, offset)
f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5 []byte, offset int) {
pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5}
toWrite := make([]int, 0, len(pkts))
handleGRO(pkts, offset, newTCPGROTable(), newTCPGROTable(), &toWrite)
if len(toWrite) > len(pkts) {
t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts))
}
seenWriteI := make(map[int]bool)
for _, writeI := range toWrite {
if writeI < 0 || writeI > len(pkts)-1 {
t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts))
}
if seenWriteI[writeI] {
t.Errorf("duplicate toWrite value: %d", writeI)
}
seenWriteI[writeI] = true
}
})
}
func Test_handleGRO(t *testing.T) {
tests := []struct {
name string
pktsIn [][]byte
wantToWrite []int
wantLens []int
wantErr bool
}{
{
"multiple flows",
[][]byte{
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1
tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // v4 flow 2
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // v6 flow 1
tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // v6 flow 2
},
[]int{0, 2, 3, 5},
[]int{240, 140, 260, 160},
false,
},
{
"PSH interleaved",
[][]byte{
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1
tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1
},
[]int{0, 2, 4, 6},
[]int{240, 240, 260, 260},
false,
},
{
"coalesceItemInvalidCSum",
[][]byte{
flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
},
[]int{0, 1},
[]int{140, 240},
false,
},
{
"out of order",
[][]byte{
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100
tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
},
[]int{0},
[]int{340},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
toWrite := make([]int, 0, len(tt.pktsIn))
err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newTCPGROTable(), &toWrite)
if err != nil {
if tt.wantErr {
return
}
t.Fatalf("got err: %v", err)
}
if len(toWrite) != len(tt.wantToWrite) {
t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
}
for i, pktI := range tt.wantToWrite {
if tt.wantToWrite[i] != toWrite[i] {
t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i])
}
if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) {
t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:]))
}
}
})
}
}

View File

@ -0,0 +1,8 @@
go test fuzz v1
[]byte("0")
[]byte("0")
[]byte("0")
[]byte("0")
[]byte("0")
[]byte("0")
int(34)

View File

@ -0,0 +1,8 @@
go test fuzz v1
[]byte("0")
[]byte("0")
[]byte("0")
[]byte("0")
[]byte("0")
[]byte("0")
int(-48)

View File

@ -17,9 +17,8 @@ import (
"time" "time"
"unsafe" "unsafe"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/rwcancel"
) )
@ -33,17 +32,25 @@ type NativeTun struct {
index int32 // if index index int32 // if index
errors chan error // async error handling errors chan error // async error handling
events chan Event // device related events events chan Event // device related events
nopi bool // the device was passed IFF_NO_PI
netlinkSock int netlinkSock int
netlinkCancel *rwcancel.RWCancel netlinkCancel *rwcancel.RWCancel
hackListenerClosed sync.Mutex hackListenerClosed sync.Mutex
statusListenersShutdown chan struct{} statusListenersShutdown chan struct{}
batchSize int
vnetHdr bool
closeOnce sync.Once closeOnce sync.Once
nameOnce sync.Once // guards calling initNameCache, which sets following fields nameOnce sync.Once // guards calling initNameCache, which sets following fields
nameCache string // name of interface nameCache string // name of interface
nameErr error nameErr error
readOpMu sync.Mutex // readOpMu guards readBuff
readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable
toWrite []int
tcp4GROTable, tcp6GROTable *tcpGROTable
} }
func (tun *NativeTun) File() *os.File { func (tun *NativeTun) File() *os.File {
@ -323,60 +330,142 @@ func (tun *NativeTun) nameSlow() (string, error) {
return unix.ByteSliceToString(ifr[:]), nil return unix.ByteSliceToString(ifr[:]), nil
} }
func (tun *NativeTun) Write(buffs [][]byte, offset int) (n int, err error) { func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) {
var buf []byte tun.writeOpMu.Lock()
if tun.nopi { defer func() {
buf = buffs[0][offset:] tun.tcp4GROTable.reset()
tun.tcp6GROTable.reset()
tun.writeOpMu.Unlock()
}()
var (
errs []error
total int
)
tun.toWrite = tun.toWrite[:0]
if tun.vnetHdr {
err := handleGRO(buffs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite)
if err != nil {
return 0, err
}
offset -= virtioNetHdrLen
} else { } else {
// reserve space for header for i := range buffs {
buf = buffs[0][offset-4:] tun.toWrite = append(tun.toWrite, i)
// add packet information header
buf[0] = 0x00
buf[1] = 0x00
if buf[4]>>4 == ipv6.Version {
buf[2] = 0x86
buf[3] = 0xdd
} else {
buf[2] = 0x08
buf[3] = 0x00
} }
} }
for _, buffsI := range tun.toWrite {
_, err = tun.tunFile.Write(buf) n, err := tun.tunFile.Write(buffs[buffsI][offset:])
if errors.Is(err, syscall.EBADFD) { if errors.Is(err, syscall.EBADFD) {
err = os.ErrClosed return total, os.ErrClosed
} else if err == nil { }
n = 1 if err != nil {
errs = append(errs, err)
} else {
total += n
}
} }
return n, err return total, ErrorBatch(errs)
} }
func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (n int, err error) { // handleVirtioRead splits in into buffs, leaving offset bytes at the front of
select { // each buffer. It mutates sizes to reflect the size of each element of buffs,
case err = <-tun.errors: // and returns the number of packets read.
default: func handleVirtioRead(in []byte, buffs [][]byte, sizes []int, offset int) (int, error) {
if tun.nopi { var hdr virtioNetHdr
sizes[0], err = tun.tunFile.Read(buffs[0][offset:]) err := hdr.decode(in)
if err == nil { if err != nil {
n = 1 return 0, err
} }
} else { in = in[virtioNetHdrLen:]
buff := buffs[0][offset-4:] if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
sizes[0], err = tun.tunFile.Read(buff[:]) if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
if errors.Is(err, syscall.EBADFD) { // This means CHECKSUM_PARTIAL in skb context. We are responsible
err = os.ErrClosed // for computing the checksum starting at hdr.csumStart and placing
} else if err == nil { // at hdr.csumOffset.
n = 1 err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset)
} if err != nil {
if sizes[0] < 4 { return 0, err
sizes[0] = 0
} else {
sizes[0] -= 4
} }
} }
if len(in) > len(buffs[0][offset:]) {
return 0, fmt.Errorf("read len %d overflows buffs element len %d", len(in), len(buffs[0][offset:]))
}
n := copy(buffs[0][offset:], in)
sizes[0] = n
return 1, nil
}
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
}
ipVersion := in[0] >> 4
switch ipVersion {
case 4:
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
}
case 6:
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
}
default:
return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
}
if len(in) <= int(hdr.csumStart+12) {
return 0, errors.New("packet is too short")
}
// Don't trust hdr.hdrLen from the kernel as it can be equal to the length
// of the entire first packet when the kernel is handling it as part of a
// FORWARD path. Instead, parse the TCP header length and add it onto
// csumStart, which is synonymous for IP header length.
tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
if tcpHLen < 20 || tcpHLen > 60 {
// A TCP header must be between 20 and 60 bytes in length.
return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
}
hdr.hdrLen = hdr.csumStart + tcpHLen
if len(in) < int(hdr.hdrLen) {
return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
}
if hdr.hdrLen < hdr.csumStart {
return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
}
cSumAt := int(hdr.csumStart + hdr.csumOffset)
if cSumAt+1 >= len(in) {
return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
}
return tcpTSO(in, hdr, buffs, sizes, offset)
}
func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
tun.readOpMu.Lock()
defer tun.readOpMu.Unlock()
select {
case err := <-tun.errors:
return 0, err
default:
readInto := buffs[0][offset:]
if tun.vnetHdr {
readInto = tun.readBuff[:]
}
n, err := tun.tunFile.Read(readInto)
if errors.Is(err, syscall.EBADFD) {
err = os.ErrClosed
}
if err != nil {
return 0, err
}
if tun.vnetHdr {
return handleVirtioRead(readInto[:n], buffs, sizes, offset)
} else {
sizes[0] = n
return 1, nil
}
} }
return
} }
func (tun *NativeTun) Events() <-chan Event { func (tun *NativeTun) Events() <-chan Event {
@ -403,9 +492,49 @@ func (tun *NativeTun) Close() error {
} }
func (tun *NativeTun) BatchSize() int { func (tun *NativeTun) BatchSize() int {
return 1 return tun.batchSize
} }
const (
// TODO: support TSO with ECN bits
tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
)
func (tun *NativeTun) initFromFlags(name string) error {
sc, err := tun.tunFile.SyscallConn()
if err != nil {
return err
}
if e := sc.Control(func(fd uintptr) {
var (
ifr *unix.Ifreq
)
ifr, err = unix.NewIfreq(name)
if err != nil {
return
}
err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr)
if err != nil {
return
}
got := ifr.Uint16()
if got&unix.IFF_VNET_HDR != 0 {
err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads)
if err != nil {
return
}
tun.vnetHdr = true
tun.batchSize = conn.DefaultBatchSize
} else {
tun.batchSize = 1
}
}); e != nil {
return e
}
return err
}
// CreateTUN creates a Device with the provided name and MTU.
func CreateTUN(name string, mtu int) (Device, error) { func CreateTUN(name string, mtu int) (Device, error) {
nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0) nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
if err != nil { if err != nil {
@ -415,25 +544,16 @@ func CreateTUN(name string, mtu int) (Device, error) {
return nil, err return nil, err
} }
var ifr [ifReqSize]byte ifr, err := unix.NewIfreq(name)
var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack) if err != nil {
nameBytes := []byte(name) return nil, err
if len(nameBytes) >= unix.IFNAMSIZ {
unix.Close(nfd)
return nil, fmt.Errorf("interface name too long: %w", unix.ENAMETOOLONG)
} }
copy(ifr[:], nameBytes) // IFF_VNET_HDR enables the "tun status hack" via routineHackListener()
*(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags // where a null write will return EINVAL indicating the TUN is up.
ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
_, _, errno := unix.Syscall( err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr)
unix.SYS_IOCTL, if err != nil {
uintptr(nfd), return nil, err
uintptr(unix.TUNSETIFF),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
unix.Close(nfd)
return nil, errno
} }
err = unix.SetNonblock(nfd, true) err = unix.SetNonblock(nfd, true)
@ -448,13 +568,16 @@ func CreateTUN(name string, mtu int) (Device, error) {
return CreateTUNFromFile(fd, mtu) return CreateTUNFromFile(fd, mtu)
} }
// CreateTUNFromFile creates a Device from an os.File with the provided MTU.
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{ tun := &NativeTun{
tunFile: file, tunFile: file,
events: make(chan Event, 5), events: make(chan Event, 5),
errors: make(chan error, 5), errors: make(chan error, 5),
statusListenersShutdown: make(chan struct{}), statusListenersShutdown: make(chan struct{}),
nopi: false, tcp4GROTable: newTCPGROTable(),
tcp6GROTable: newTCPGROTable(),
toWrite: make([]int, 0, conn.DefaultBatchSize),
} }
name, err := tun.Name() name, err := tun.Name()
@ -462,8 +585,12 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return nil, err return nil, err
} }
// start event listener err = tun.initFromFlags(name)
if err != nil {
return nil, err
}
// start event listener
tun.index, err = getIFIndex(name) tun.index, err = getIFIndex(name)
if err != nil { if err != nil {
return nil, err return nil, err
@ -492,6 +619,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return tun, nil return tun, nil
} }
// CreateUnmonitoredTUNFromFD creates a Device from the provided file
// descriptor.
func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
err := unix.SetNonblock(fd, true) err := unix.SetNonblock(fd, true)
if err != nil { if err != nil {
@ -499,14 +628,20 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
} }
file := os.NewFile(uintptr(fd), "/dev/tun") file := os.NewFile(uintptr(fd), "/dev/tun")
tun := &NativeTun{ tun := &NativeTun{
tunFile: file, tunFile: file,
events: make(chan Event, 5), events: make(chan Event, 5),
errors: make(chan error, 5), errors: make(chan error, 5),
nopi: true, tcp4GROTable: newTCPGROTable(),
tcp6GROTable: newTCPGROTable(),
toWrite: make([]int, 0, conn.DefaultBatchSize),
} }
name, err := tun.Name() name, err := tun.Name()
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
return tun, name, nil err = tun.initFromFlags(name)
if err != nil {
return nil, "", err
}
return tun, name, err
} }