conn, device, tun: implement vectorized I/O plumbing

Accept packet vectors for reading and writing in the tun.Device and
conn.Bind interfaces, so that the internal plumbing between these
interfaces now passes a vector of packets. Vectors move untouched
between these interfaces, i.e. if 128 packets are received from
conn.Bind.Read(), 128 packets are passed to tun.Device.Write(). There is
no internal buffering.

Currently, existing implementations are only adjusted to have vectors
of length one. Subsequent patches will improve that.

Also, as a related fixup, use the unix and windows packages rather than
the syscall package when possible.

Co-authored-by: James Tucker <james@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jordan Whited 2023-03-02 14:48:02 -08:00 committed by Jason A. Donenfeld
parent 21636207a6
commit 3bb8fec7e4
25 changed files with 1046 additions and 514 deletions

View File

@ -193,6 +193,10 @@ func (bind *LinuxSocketBind) SetMark(value uint32) error {
return nil return nil
} }
func (bind *LinuxSocketBind) BatchSize() int {
return 1
}
func (bind *LinuxSocketBind) Close() error { func (bind *LinuxSocketBind) Close() error {
// Take a readlock to shut down the sockets... // Take a readlock to shut down the sockets...
bind.mu.RLock() bind.mu.RLock()
@ -223,29 +227,39 @@ func (bind *LinuxSocketBind) Close() error {
return err2 return err2
} }
func (bind *LinuxSocketBind) receiveIPv4(buf []byte) (int, Endpoint, error) { func (bind *LinuxSocketBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock() bind.mu.RLock()
defer bind.mu.RUnlock() defer bind.mu.RUnlock()
if bind.sock4 == -1 { if bind.sock4 == -1 {
return 0, nil, net.ErrClosed return 0, net.ErrClosed
} }
var end LinuxSocketEndpoint var end LinuxSocketEndpoint
n, err := receive4(bind.sock4, buf, &end) n, err := receive4(bind.sock4, buffs[0], &end)
return n, &end, err if err != nil {
return 0, err
}
eps[0] = &end
sizes[0] = n
return 1, nil
} }
func (bind *LinuxSocketBind) receiveIPv6(buf []byte) (int, Endpoint, error) { func (bind *LinuxSocketBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock() bind.mu.RLock()
defer bind.mu.RUnlock() defer bind.mu.RUnlock()
if bind.sock6 == -1 { if bind.sock6 == -1 {
return 0, nil, net.ErrClosed return 0, net.ErrClosed
} }
var end LinuxSocketEndpoint var end LinuxSocketEndpoint
n, err := receive6(bind.sock6, buf, &end) n, err := receive6(bind.sock6, buffs[0], &end)
return n, &end, err if err != nil {
return 0, err
}
eps[0] = &end
sizes[0] = n
return 1, nil
} }
func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error { func (bind *LinuxSocketBind) Send(buffs [][]byte, end Endpoint) error {
nend, ok := end.(*LinuxSocketEndpoint) nend, ok := end.(*LinuxSocketEndpoint)
if !ok { if !ok {
return ErrWrongEndpointType return ErrWrongEndpointType
@ -256,14 +270,25 @@ func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
if bind.sock4 == -1 { if bind.sock4 == -1 {
return net.ErrClosed return net.ErrClosed
} }
return send4(bind.sock4, nend, buff) for _, buff := range buffs {
err := send4(bind.sock4, nend, buff)
if err != nil {
return err
}
}
} else { } else {
if bind.sock6 == -1 { if bind.sock6 == -1 {
return net.ErrClosed return net.ErrClosed
} }
return send6(bind.sock6, nend, buff) for _, buff := range buffs {
err := send6(bind.sock6, nend, buff)
if err != nil {
return err
} }
} }
}
return nil
}
func (end *LinuxSocketEndpoint) SrcIP() netip.Addr { func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
if !end.isV6 { if !end.isV6 {

View File

@ -128,6 +128,10 @@ again:
return fns, uint16(port), nil return fns, uint16(port), nil
} }
func (bind *StdNetBind) BatchSize() int {
return 1
}
func (bind *StdNetBind) Close() error { func (bind *StdNetBind) Close() error {
bind.mu.Lock() bind.mu.Lock()
defer bind.mu.Unlock() defer bind.mu.Unlock()
@ -150,20 +154,30 @@ func (bind *StdNetBind) Close() error {
} }
func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc { func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
return func(buff []byte) (int, Endpoint, error) { return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
n, endpoint, err := conn.ReadFromUDPAddrPort(buff) size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0])
return n, asEndpoint(endpoint), err if err == nil {
sizes[0] = size
eps[0] = asEndpoint(endpoint)
return 1, nil
}
return 0, err
} }
} }
func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc { func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
return func(buff []byte) (int, Endpoint, error) { return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
n, endpoint, err := conn.ReadFromUDPAddrPort(buff) size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0])
return n, asEndpoint(endpoint), err if err == nil {
sizes[0] = size
eps[0] = asEndpoint(endpoint)
return 1, nil
}
return 0, err
} }
} }
func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
var err error var err error
nend, ok := endpoint.(StdNetEndpoint) nend, ok := endpoint.(StdNetEndpoint)
if !ok { if !ok {
@ -186,9 +200,14 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
if conn == nil { if conn == nil {
return syscall.EAFNOSUPPORT return syscall.EAFNOSUPPORT
} }
for _, buff := range buffs {
_, err = conn.WriteToUDPAddrPort(buff, addrPort) _, err = conn.WriteToUDPAddrPort(buff, addrPort)
if err != nil {
return err return err
} }
}
return nil
}
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint. // endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
// This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates, // This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates,

View File

@ -321,6 +321,11 @@ func (bind *WinRingBind) Close() error {
return nil return nil
} }
func (bind *WinRingBind) BatchSize() int {
// TODO: implement batching in and out of the ring
return 1
}
func (bind *WinRingBind) SetMark(mark uint32) error { func (bind *WinRingBind) SetMark(mark uint32) error {
return nil return nil
} }
@ -409,16 +414,22 @@ retry:
return n, &ep, nil return n, &ep, nil
} }
func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) { func (bind *WinRingBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock() bind.mu.RLock()
defer bind.mu.RUnlock() defer bind.mu.RUnlock()
return bind.v4.Receive(buf, &bind.isOpen) n, ep, err := bind.v4.Receive(buffs[0], &bind.isOpen)
sizes[0] = n
eps[0] = ep
return 1, err
} }
func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) { func (bind *WinRingBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock() bind.mu.RLock()
defer bind.mu.RUnlock() defer bind.mu.RUnlock()
return bind.v6.Receive(buf, &bind.isOpen) n, ep, err := bind.v6.Receive(buffs[0], &bind.isOpen)
sizes[0] = n
eps[0] = ep
return 1, err
} }
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error { func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
@ -473,32 +484,38 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomi
return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
} }
func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error { func (bind *WinRingBind) Send(buffs [][]byte, endpoint Endpoint) error {
nend, ok := endpoint.(*WinRingEndpoint) nend, ok := endpoint.(*WinRingEndpoint)
if !ok { if !ok {
return ErrWrongEndpointType return ErrWrongEndpointType
} }
bind.mu.RLock() bind.mu.RLock()
defer bind.mu.RUnlock() defer bind.mu.RUnlock()
for _, buf := range buffs {
switch nend.family { switch nend.family {
case windows.AF_INET: case windows.AF_INET:
if bind.v4.blackhole { if bind.v4.blackhole {
return nil continue
}
if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil {
return err
} }
return bind.v4.Send(buf, nend, &bind.isOpen)
case windows.AF_INET6: case windows.AF_INET6:
if bind.v6.blackhole { if bind.v6.blackhole {
return nil continue
}
if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil {
return err
}
} }
return bind.v6.Send(buf, nend, &bind.isOpen)
} }
return nil return nil
} }
func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
bind.mu.Lock() s.mu.Lock()
defer bind.mu.Unlock() defer s.mu.Unlock()
sysconn, err := bind.ipv4.SyscallConn() sysconn, err := s.ipv4.SyscallConn()
if err != nil { if err != nil {
return err return err
} }
@ -511,14 +528,14 @@ func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
if err != nil { if err != nil {
return err return err
} }
bind.blackhole4 = blackhole s.blackhole4 = blackhole
return nil return nil
} }
func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
bind.mu.Lock() s.mu.Lock()
defer bind.mu.Unlock() defer s.mu.Unlock()
sysconn, err := bind.ipv6.SyscallConn() sysconn, err := s.ipv6.SyscallConn()
if err != nil { if err != nil {
return err return err
} }
@ -531,7 +548,7 @@ func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole
if err != nil { if err != nil {
return err return err
} }
bind.blackhole6 = blackhole s.blackhole6 = blackhole
return nil return nil
} }

View File

@ -89,20 +89,26 @@ func (c *ChannelBind) Close() error {
return nil return nil
} }
func (c *ChannelBind) BatchSize() int { return 1 }
func (c *ChannelBind) SetMark(mark uint32) error { return nil } func (c *ChannelBind) SetMark(mark uint32) error { return nil }
func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
return func(b []byte) (n int, ep conn.Endpoint, err error) { return func(buffs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
select { select {
case <-c.closeSignal: case <-c.closeSignal:
return 0, nil, net.ErrClosed return 0, net.ErrClosed
case rx := <-ch: case rx := <-ch:
return copy(b, rx), c.target6, nil copied := copy(buffs[0], rx)
sizes[0] = copied
eps[0] = c.target6
return 1, nil
} }
} }
} }
func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error { func (c *ChannelBind) Send(buffs [][]byte, ep conn.Endpoint) error {
for _, b := range buffs {
select { select {
case <-c.closeSignal: case <-c.closeSignal:
return net.ErrClosed return net.ErrClosed
@ -117,6 +123,7 @@ func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
return os.ErrInvalid return os.ErrInvalid
} }
} }
}
return nil return nil
} }

View File

@ -15,10 +15,17 @@ import (
"strings" "strings"
) )
// A ReceiveFunc receives a single inbound packet from the network. const (
// It writes the data into b. n is the length of the packet. DefaultBatchSize = 1 // maximum number of packets handled per read and write
// ep is the remote endpoint. )
type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error)
// A ReceiveFunc receives at least one packet from the network and writes them
// into packets. On a successful read it returns the number of elements of
// sizes, packets, and endpoints that should be evaluated. Some elements of
// sizes may be zero, and callers should ignore them. Callers must pass a sizes
// and eps slice with a length greater than or equal to the length of packets.
// These lengths must not exceed the length of the associated Bind.BatchSize().
type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic. // A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
// //
@ -38,11 +45,16 @@ type Bind interface {
// This mark is passed to the kernel as the socket option SO_MARK. // This mark is passed to the kernel as the socket option SO_MARK.
SetMark(mark uint32) error SetMark(mark uint32) error
// Send writes a packet b to address ep. // Send writes one or more packets in buffs to address ep. The length of
Send(b []byte, ep Endpoint) error // buffs must not exceed BatchSize().
Send(buffs [][]byte, ep Endpoint) error
// ParseEndpoint creates a new endpoint from a string. // ParseEndpoint creates a new endpoint from a string.
ParseEndpoint(s string) (Endpoint, error) ParseEndpoint(s string) (Endpoint, error)
// BatchSize is the number of buffers expected to be passed to
// the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
BatchSize() int
} }
// BindSocketToInterface is implemented by Bind objects that support being // BindSocketToInterface is implemented by Bind objects that support being

24
conn/conn_test.go Normal file
View File

@ -0,0 +1,24 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"testing"
)
func TestPrettyName(t *testing.T) {
var (
recvFunc ReceiveFunc = func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return }
)
const want = "TestPrettyName"
t.Run("ReceiveFunc.PrettyName", func(t *testing.T) {
if got := recvFunc.PrettyName(); got != want {
t.Errorf("PrettyName() = %v, want %v", got, want)
}
})
}

View File

@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue {
} }
type autodrainingInboundQueue struct { type autodrainingInboundQueue struct {
c chan *QueueInboundElement c chan *[]*QueueInboundElement
} }
// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd. // newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
@ -81,7 +81,7 @@ type autodrainingInboundQueue struct {
// some other means, such as sending a sentinel nil values. // some other means, such as sending a sentinel nil values.
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
q := &autodrainingInboundQueue{ q := &autodrainingInboundQueue{
c: make(chan *QueueInboundElement, QueueInboundSize), c: make(chan *[]*QueueInboundElement, QueueInboundSize),
} }
runtime.SetFinalizer(q, device.flushInboundQueue) runtime.SetFinalizer(q, device.flushInboundQueue)
return q return q
@ -90,10 +90,13 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
for { for {
select { select {
case elem := <-q.c: case elems := <-q.c:
for _, elem := range *elems {
elem.Lock() elem.Lock()
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem) device.PutInboundElement(elem)
}
device.PutInboundElementsSlice(elems)
default: default:
return return
} }
@ -101,7 +104,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
} }
type autodrainingOutboundQueue struct { type autodrainingOutboundQueue struct {
c chan *QueueOutboundElement c chan *[]*QueueOutboundElement
} }
// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd. // newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
@ -111,7 +114,7 @@ type autodrainingOutboundQueue struct {
// All sends to the channel must be best-effort, because there may be no receivers. // All sends to the channel must be best-effort, because there may be no receivers.
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
q := &autodrainingOutboundQueue{ q := &autodrainingOutboundQueue{
c: make(chan *QueueOutboundElement, QueueOutboundSize), c: make(chan *[]*QueueOutboundElement, QueueOutboundSize),
} }
runtime.SetFinalizer(q, device.flushOutboundQueue) runtime.SetFinalizer(q, device.flushOutboundQueue)
return q return q
@ -120,10 +123,13 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) { func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
for { for {
select { select {
case elem := <-q.c: case elems := <-q.c:
for _, elem := range *elems {
elem.Lock() elem.Lock()
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem) device.PutOutboundElement(elem)
}
device.PutOutboundElementsSlice(elems)
default: default:
return return
} }

View File

@ -68,6 +68,8 @@ type Device struct {
cookieChecker CookieChecker cookieChecker CookieChecker
pool struct { pool struct {
outboundElementsSlice *WaitPool
inboundElementsSlice *WaitPool
messageBuffers *WaitPool messageBuffers *WaitPool
inboundElements *WaitPool inboundElements *WaitPool
outboundElements *WaitPool outboundElements *WaitPool
@ -295,6 +297,7 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device.peers.keyMap = make(map[NoisePublicKey]*Peer) device.peers.keyMap = make(map[NoisePublicKey]*Peer)
device.rate.limiter.Init() device.rate.limiter.Init()
device.indexTable.Init() device.indexTable.Init()
device.PopulatePools() device.PopulatePools()
// create queues // create queues
@ -322,6 +325,19 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
return device return device
} }
// BatchSize returns the BatchSize for the device as a whole which is the max of
// the bind batch size and the tun batch size. The batch size reported by device
// is the size used to construct memory pools, and is the allowed batch size for
// the lifetime of the device.
func (device *Device) BatchSize() int {
size := device.net.bind.BatchSize()
dSize := device.tun.device.BatchSize()
if size < dSize {
size = dSize
}
return size
}
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
device.peers.RLock() device.peers.RLock()
defer device.peers.RUnlock() defer device.peers.RUnlock()
@ -472,11 +488,13 @@ func (device *Device) BindUpdate() error {
var err error var err error
var recvFns []conn.ReceiveFunc var recvFns []conn.ReceiveFunc
netc := &device.net netc := &device.net
recvFns, netc.port, err = netc.bind.Open(netc.port) recvFns, netc.port, err = netc.bind.Open(netc.port)
if err != nil { if err != nil {
netc.port = 0 netc.port = 0
return err return err
} }
netc.netlinkCancel, err = device.startRouteListener(netc.bind) netc.netlinkCancel, err = device.startRouteListener(netc.bind)
if err != nil { if err != nil {
netc.bind.Close() netc.bind.Close()
@ -507,8 +525,9 @@ func (device *Device) BindUpdate() error {
device.net.stopping.Add(len(recvFns)) device.net.stopping.Add(len(recvFns))
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
batchSize := netc.bind.BatchSize()
for _, fn := range recvFns { for _, fn := range recvFns {
go device.RoutineReceiveIncoming(fn) go device.RoutineReceiveIncoming(batchSize, fn)
} }
device.log.Verbosef("UDP bind has been updated") device.log.Verbosef("UDP bind has been updated")

View File

@ -12,6 +12,7 @@ import (
"io" "io"
"math/rand" "math/rand"
"net/netip" "net/netip"
"os"
"runtime" "runtime"
"runtime/pprof" "runtime/pprof"
"sync" "sync"
@ -21,6 +22,7 @@ import (
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/conn/bindtest" "golang.zx2c4.com/wireguard/conn/bindtest"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/tuntest" "golang.zx2c4.com/wireguard/tun/tuntest"
) )
@ -307,6 +309,17 @@ func TestConcurrencySafety(t *testing.T) {
} }
}) })
// Perform bind updates and keepalive sends concurrently with tunnel use.
t.Run("bindUpdate and keepalive", func(t *testing.T) {
const iters = 10
for i := 0; i < iters; i++ {
for _, peer := range pair {
peer.dev.BindUpdate()
peer.dev.SendKeepalivesToPeersWithCurrentKeypair()
}
}
})
close(done) close(done)
} }
@ -405,3 +418,59 @@ func goroutineLeakCheck(t *testing.T) {
t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines) t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
}) })
} }
type fakeBindSized struct {
size int
}
func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
return nil, 0, nil
}
func (b *fakeBindSized) Close() error { return nil }
func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
func (b *fakeBindSized) Send(buffs [][]byte, ep conn.Endpoint) error { return nil }
func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
func (b *fakeBindSized) BatchSize() int { return b.size }
type fakeTUNDeviceSized struct {
size int
}
func (t *fakeTUNDeviceSized) File() *os.File { return nil }
func (t *fakeTUNDeviceSized) Read(buffs [][]byte, sizes []int, offset int) (n int, err error) {
return 0, nil
}
func (t *fakeTUNDeviceSized) Write(buffs [][]byte, offset int) (int, error) { return 0, nil }
func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
func (t *fakeTUNDeviceSized) Close() error { return nil }
func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
func TestBatchSize(t *testing.T) {
d := Device{}
d.net.bind = &fakeBindSized{1}
d.tun.device = &fakeTUNDeviceSized{1}
if want, got := 1, d.BatchSize(); got != want {
t.Errorf("expected batch size %d, got %d", want, got)
}
d.net.bind = &fakeBindSized{1}
d.tun.device = &fakeTUNDeviceSized{128}
if want, got := 128, d.BatchSize(); got != want {
t.Errorf("expected batch size %d, got %d", want, got)
}
d.net.bind = &fakeBindSized{128}
d.tun.device = &fakeTUNDeviceSized{1}
if want, got := 128, d.BatchSize(); got != want {
t.Errorf("expected batch size %d, got %d", want, got)
}
d.net.bind = &fakeBindSized{128}
d.tun.device = &fakeTUNDeviceSized{128}
if want, got := 128, d.BatchSize(); got != want {
t.Errorf("expected batch size %d, got %d", want, got)
}
}

View File

@ -45,7 +45,7 @@ type Peer struct {
} }
queue struct { queue struct {
staged chan *QueueOutboundElement // staged packets before a handshake is available staged chan *[]*QueueOutboundElement // staged packets before a handshake is available
outbound *autodrainingOutboundQueue // sequential ordering of udp transmission outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
inbound *autodrainingInboundQueue // sequential ordering of tun writing inbound *autodrainingInboundQueue // sequential ordering of tun writing
} }
@ -81,7 +81,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.device = device peer.device = device
peer.queue.outbound = newAutodrainingOutboundQueue(device) peer.queue.outbound = newAutodrainingOutboundQueue(device)
peer.queue.inbound = newAutodrainingInboundQueue(device) peer.queue.inbound = newAutodrainingInboundQueue(device)
peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize) peer.queue.staged = make(chan *[]*QueueOutboundElement, QueueStagedSize)
// map public key // map public key
_, ok := device.peers.keyMap[pk] _, ok := device.peers.keyMap[pk]
@ -108,7 +108,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
return peer, nil return peer, nil
} }
func (peer *Peer) SendBuffer(buffer []byte) error { func (peer *Peer) SendBuffers(buffers [][]byte) error {
peer.device.net.RLock() peer.device.net.RLock()
defer peer.device.net.RUnlock() defer peer.device.net.RUnlock()
@ -123,9 +123,13 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
return errors.New("no known endpoint for peer") return errors.New("no known endpoint for peer")
} }
err := peer.device.net.bind.Send(buffer, peer.endpoint) err := peer.device.net.bind.Send(buffers, peer.endpoint)
if err == nil { if err == nil {
peer.txBytes.Add(uint64(len(buffer))) var totalLen uint64
for _, b := range buffers {
totalLen += uint64(len(b))
}
peer.txBytes.Add(totalLen)
} }
return err return err
} }
@ -187,8 +191,12 @@ func (peer *Peer) Start() {
device.flushInboundQueue(peer.queue.inbound) device.flushInboundQueue(peer.queue.inbound)
device.flushOutboundQueue(peer.queue.outbound) device.flushOutboundQueue(peer.queue.outbound)
go peer.RoutineSequentialSender()
go peer.RoutineSequentialReceiver() // Use the device batch size, not the bind batch size, as the device size is
// the size of the batch pools.
batchSize := peer.device.BatchSize()
go peer.RoutineSequentialSender(batchSize)
go peer.RoutineSequentialReceiver(batchSize)
peer.isRunning.Store(true) peer.isRunning.Store(true)
} }

View File

@ -46,6 +46,14 @@ func (p *WaitPool) Put(x any) {
} }
func (device *Device) PopulatePools() { func (device *Device) PopulatePools() {
device.pool.outboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any {
s := make([]*QueueOutboundElement, 0, device.BatchSize())
return &s
})
device.pool.inboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any {
s := make([]*QueueInboundElement, 0, device.BatchSize())
return &s
})
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new([MaxMessageSize]byte) return new([MaxMessageSize]byte)
}) })
@ -57,6 +65,30 @@ func (device *Device) PopulatePools() {
}) })
} }
func (device *Device) GetOutboundElementsSlice() *[]*QueueOutboundElement {
return device.pool.outboundElementsSlice.Get().(*[]*QueueOutboundElement)
}
func (device *Device) PutOutboundElementsSlice(s *[]*QueueOutboundElement) {
for i := range *s {
(*s)[i] = nil
}
*s = (*s)[:0]
device.pool.outboundElementsSlice.Put(s)
}
func (device *Device) GetInboundElementsSlice() *[]*QueueInboundElement {
return device.pool.inboundElementsSlice.Get().(*[]*QueueInboundElement)
}
func (device *Device) PutInboundElementsSlice(s *[]*QueueInboundElement) {
for i := range *s {
(*s)[i] = nil
}
*s = (*s)[:0]
device.pool.inboundElementsSlice.Put(s)
}
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
} }

View File

@ -89,3 +89,51 @@ func BenchmarkWaitPool(b *testing.B) {
} }
wg.Wait() wg.Wait()
} }
func BenchmarkWaitPoolEmpty(b *testing.B) {
var wg sync.WaitGroup
var trials atomic.Int32
trials.Store(int32(b.N))
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
b.Skip("Not enough cores")
}
p := NewWaitPool(0, func() any { return make([]byte, 16) })
wg.Add(workers)
b.ResetTimer()
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
for trials.Add(-1) > 0 {
x := p.Get()
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
p.Put(x)
}
}()
}
wg.Wait()
}
func BenchmarkSyncPool(b *testing.B) {
var wg sync.WaitGroup
var trials atomic.Int32
trials.Store(int32(b.N))
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
b.Skip("Not enough cores")
}
p := sync.Pool{New: func() any { return make([]byte, 16) }}
wg.Add(workers)
b.ResetTimer()
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
for trials.Add(-1) > 0 {
x := p.Get()
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
p.Put(x)
}
}()
}
wg.Wait()
}

View File

@ -66,7 +66,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
* Every time the bind is updated a new routine is started for * Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately) * IPv4 and IPv6 (separately)
*/ */
func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) { func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
recvName := recv.PrettyName() recvName := recv.PrettyName()
defer func() { defer func() {
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName) device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
@ -79,20 +79,33 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
// receive datagrams until conn is closed // receive datagrams until conn is closed
buffer := device.GetMessageBuffer()
var ( var (
buffsArrs = make([]*[MaxMessageSize]byte, maxBatchSize)
buffs = make([][]byte, maxBatchSize)
err error err error
size int sizes = make([]int, maxBatchSize)
endpoint conn.Endpoint count int
endpoints = make([]conn.Endpoint, maxBatchSize)
deathSpiral int deathSpiral int
elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize)
) )
for { for i := range buffsArrs {
size, endpoint, err = recv(buffer[:]) buffsArrs[i] = device.GetMessageBuffer()
buffs[i] = buffsArrs[i][:]
}
defer func() {
for i := 0; i < maxBatchSize; i++ {
if buffsArrs[i] != nil {
device.PutMessageBuffer(buffsArrs[i])
}
}
}()
for {
count, err = recv(buffs, sizes, endpoints)
if err != nil { if err != nil {
device.PutMessageBuffer(buffer)
if errors.Is(err, net.ErrClosed) { if errors.Is(err, net.ErrClosed) {
return return
} }
@ -103,24 +116,23 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
if deathSpiral < 10 { if deathSpiral < 10 {
deathSpiral++ deathSpiral++
time.Sleep(time.Second / 3) time.Sleep(time.Second / 3)
buffer = device.GetMessageBuffer()
continue continue
} }
return return
} }
deathSpiral = 0 deathSpiral = 0
// handle each packet in the batch
for i, size := range sizes[:count] {
if size < MinMessageSize { if size < MinMessageSize {
continue continue
} }
// check size of packet // check size of packet
packet := buffer[:size] packet := buffsArrs[i][:size]
msgType := binary.LittleEndian.Uint32(packet[:4]) msgType := binary.LittleEndian.Uint32(packet[:4])
var okay bool
switch msgType { switch msgType {
// check if transport // check if transport
@ -154,50 +166,72 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
peer := value.peer peer := value.peer
elem := device.GetInboundElement() elem := device.GetInboundElement()
elem.packet = packet elem.packet = packet
elem.buffer = buffer elem.buffer = buffsArrs[i]
elem.keypair = keypair elem.keypair = keypair
elem.endpoint = endpoint elem.endpoint = endpoints[i]
elem.counter = 0 elem.counter = 0
elem.Mutex = sync.Mutex{} elem.Mutex = sync.Mutex{}
elem.Lock() elem.Lock()
// add to decryption queues elemsForPeer, ok := elemsByPeer[peer]
if peer.isRunning.Load() { if !ok {
peer.queue.inbound.c <- elem elemsForPeer = device.GetInboundElementsSlice()
device.queue.decryption.c <- elem elemsByPeer[peer] = elemsForPeer
buffer = device.GetMessageBuffer()
} else {
device.PutInboundElement(elem)
} }
*elemsForPeer = append(*elemsForPeer, elem)
buffsArrs[i] = device.GetMessageBuffer()
buffs[i] = buffsArrs[i][:]
continue continue
// otherwise it is a fixed size & handshake related packet // otherwise it is a fixed size & handshake related packet
case MessageInitiationType: case MessageInitiationType:
okay = len(packet) == MessageInitiationSize if len(packet) != MessageInitiationSize {
continue
}
case MessageResponseType: case MessageResponseType:
okay = len(packet) == MessageResponseSize if len(packet) != MessageResponseSize {
continue
}
case MessageCookieReplyType: case MessageCookieReplyType:
okay = len(packet) == MessageCookieReplySize if len(packet) != MessageCookieReplySize {
continue
}
default: default:
device.log.Verbosef("Received message with unknown type") device.log.Verbosef("Received message with unknown type")
continue
} }
if okay {
select { select {
case device.queue.handshake.c <- QueueHandshakeElement{ case device.queue.handshake.c <- QueueHandshakeElement{
msgType: msgType, msgType: msgType,
buffer: buffer, buffer: buffsArrs[i],
packet: packet, packet: packet,
endpoint: endpoint, endpoint: endpoints[i],
}: }:
buffer = device.GetMessageBuffer() buffsArrs[i] = device.GetMessageBuffer()
buffs[i] = buffsArrs[i][:]
default: default:
} }
} }
for peer, elems := range elemsByPeer {
if peer.isRunning.Load() {
peer.queue.inbound.c <- elems
for _, elem := range *elems {
device.queue.decryption.c <- elem
}
} else {
for _, elem := range *elems {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
device.PutInboundElementsSlice(elems)
}
delete(elemsByPeer, peer)
}
} }
} }
@ -393,7 +427,7 @@ func (device *Device) RoutineHandshake(id int) {
} }
} }
func (peer *Peer) RoutineSequentialReceiver() { func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
device := peer.device device := peer.device
defer func() { defer func() {
device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer) device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
@ -401,19 +435,21 @@ func (peer *Peer) RoutineSequentialReceiver() {
}() }()
device.log.Verbosef("%v - Routine: sequential receiver - started", peer) device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
for elem := range peer.queue.inbound.c { buffs := make([][]byte, 0, maxBatchSize)
if elem == nil {
for elems := range peer.queue.inbound.c {
if elems == nil {
return return
} }
var err error for _, elem := range *elems {
elem.Lock() elem.Lock()
if elem.packet == nil { if elem.packet == nil {
// decryption failed // decryption failed
goto skip continue
} }
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
goto skip continue
} }
peer.SetEndpointFromPacket(elem.endpoint) peer.SetEndpointFromPacket(elem.endpoint)
@ -421,7 +457,6 @@ func (peer *Peer) RoutineSequentialReceiver() {
peer.timersHandshakeComplete() peer.timersHandshakeComplete()
peer.SendStagedPackets() peer.SendStagedPackets()
} }
peer.keepKeyFreshReceiving() peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived() peer.timersAnyAuthenticatedPacketReceived()
@ -429,61 +464,62 @@ func (peer *Peer) RoutineSequentialReceiver() {
if len(elem.packet) == 0 { if len(elem.packet) == 0 {
device.log.Verbosef("%v - Receiving keepalive packet", peer) device.log.Verbosef("%v - Receiving keepalive packet", peer)
goto skip continue
} }
peer.timersDataReceived() peer.timersDataReceived()
switch elem.packet[0] >> 4 { switch elem.packet[0] >> 4 {
case ipv4.Version: case 4:
if len(elem.packet) < ipv4.HeaderLen { if len(elem.packet) < ipv4.HeaderLen {
goto skip continue
} }
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field) length := binary.BigEndian.Uint16(field)
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
goto skip continue
} }
elem.packet = elem.packet[:length] elem.packet = elem.packet[:length]
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.allowedips.Lookup(src) != peer { if device.allowedips.Lookup(src) != peer {
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
goto skip continue
} }
case ipv6.Version: case 6:
if len(elem.packet) < ipv6.HeaderLen { if len(elem.packet) < ipv6.HeaderLen {
goto skip continue
} }
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field) length := binary.BigEndian.Uint16(field)
length += ipv6.HeaderLen length += ipv6.HeaderLen
if int(length) > len(elem.packet) { if int(length) > len(elem.packet) {
goto skip continue
} }
elem.packet = elem.packet[:length] elem.packet = elem.packet[:length]
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.allowedips.Lookup(src) != peer { if device.allowedips.Lookup(src) != peer {
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
goto skip continue
} }
default: default:
device.log.Verbosef("Packet with invalid IP version from %v", peer) device.log.Verbosef("Packet with invalid IP version from %v", peer)
goto skip continue
} }
_, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent) buffs = append(buffs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
}
if len(buffs) > 0 {
_, err := device.tun.device.Write(buffs, MessageTransportOffsetContent)
if err != nil && !device.isClosed() { if err != nil && !device.isClosed() {
device.log.Errorf("Failed to write packet to TUN device: %v", err) device.log.Errorf("Failed to write packets to TUN device: %v", err)
}
if len(peer.queue.inbound.c) == 0 {
err = device.tun.device.Flush()
if err != nil {
peer.device.log.Errorf("Unable to flush packets: %v", err)
} }
} }
skip: for _, elem := range *elems {
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem) device.PutInboundElement(elem)
} }
buffs = buffs[:0]
device.PutInboundElementsSlice(elems)
}
} }

View File

@ -17,6 +17,7 @@ import (
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/tun"
) )
/* Outbound flow /* Outbound flow
@ -77,12 +78,15 @@ func (elem *QueueOutboundElement) clearPointers() {
func (peer *Peer) SendKeepalive() { func (peer *Peer) SendKeepalive() {
if len(peer.queue.staged) == 0 && peer.isRunning.Load() { if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
elem := peer.device.NewOutboundElement() elem := peer.device.NewOutboundElement()
elems := peer.device.GetOutboundElementsSlice()
*elems = append(*elems, elem)
select { select {
case peer.queue.staged <- elem: case peer.queue.staged <- elems:
peer.device.log.Verbosef("%v - Sending keepalive packet", peer) peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
default: default:
peer.device.PutMessageBuffer(elem.buffer) peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem) peer.device.PutOutboundElement(elem)
peer.device.PutOutboundElementsSlice(elems)
} }
} }
peer.SendStagedPackets() peer.SendStagedPackets()
@ -125,7 +129,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent() peer.timersAnyAuthenticatedPacketSent()
err = peer.SendBuffer(packet) err = peer.SendBuffers([][]byte{packet})
if err != nil { if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
} }
@ -163,7 +167,8 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent() peer.timersAnyAuthenticatedPacketSent()
err = peer.SendBuffer(packet) // TODO: allocation could be avoided
err = peer.SendBuffers([][]byte{packet})
if err != nil { if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
} }
@ -183,7 +188,8 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement)
var buff [MessageCookieReplySize]byte var buff [MessageCookieReplySize]byte
writer := bytes.NewBuffer(buff[:0]) writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, reply) binary.Write(writer, binary.LittleEndian, reply)
device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint) // TODO: allocation could be avoided
device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
return nil return nil
} }
@ -198,11 +204,6 @@ func (peer *Peer) keepKeyFreshSending() {
} }
} }
/* Reads packets from the TUN and inserts
* into staged queue for peer
*
* Obs. Single instance per TUN device
*/
func (device *Device) RoutineReadFromTUN() { func (device *Device) RoutineReadFromTUN() {
defer func() { defer func() {
device.log.Verbosef("Routine: TUN reader - stopped") device.log.Verbosef("Routine: TUN reader - stopped")
@ -212,49 +213,53 @@ func (device *Device) RoutineReadFromTUN() {
device.log.Verbosef("Routine: TUN reader - started") device.log.Verbosef("Routine: TUN reader - started")
var elem *QueueOutboundElement var (
batchSize = device.BatchSize()
readErr error
elems = make([]*QueueOutboundElement, batchSize)
buffs = make([][]byte, batchSize)
elemsByPeer = make(map[*Peer]*[]*QueueOutboundElement, batchSize)
count = 0
sizes = make([]int, batchSize)
offset = MessageTransportHeaderSize
)
for { for i := range elems {
elems[i] = device.NewOutboundElement()
buffs[i] = elems[i].buffer[:]
}
defer func() {
for _, elem := range elems {
if elem != nil { if elem != nil {
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem) device.PutOutboundElement(elem)
} }
elem = device.NewOutboundElement()
// read packet
offset := MessageTransportHeaderSize
size, err := device.tun.device.Read(elem.buffer[:], offset)
if err != nil {
if !device.isClosed() {
if !errors.Is(err, os.ErrClosed) {
device.log.Errorf("Failed to read packet from TUN device: %v", err)
}
go device.Close()
}
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
return
} }
}()
if size == 0 || size > MaxContentSize { for {
// read packets
count, readErr = device.tun.device.Read(buffs, sizes, offset)
for i := 0; i < count; i++ {
if sizes[i] < 1 {
continue continue
} }
elem.packet = elem.buffer[offset : offset+size] elem := elems[i]
elem.packet = buffs[i][offset : offset+sizes[i]]
// lookup peer // lookup peer
var peer *Peer var peer *Peer
switch elem.packet[0] >> 4 { switch elem.packet[0] >> 4 {
case ipv4.Version: case 4:
if len(elem.packet) < ipv4.HeaderLen { if len(elem.packet) < ipv4.HeaderLen {
continue continue
} }
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.allowedips.Lookup(dst) peer = device.allowedips.Lookup(dst)
case ipv6.Version: case 6:
if len(elem.packet) < ipv6.HeaderLen { if len(elem.packet) < ipv6.HeaderLen {
continue continue
} }
@ -268,25 +273,63 @@ func (device *Device) RoutineReadFromTUN() {
if peer == nil { if peer == nil {
continue continue
} }
elemsForPeer, ok := elemsByPeer[peer]
if !ok {
elemsForPeer = device.GetOutboundElementsSlice()
elemsByPeer[peer] = elemsForPeer
}
*elemsForPeer = append(*elemsForPeer, elem)
elems[i] = device.NewOutboundElement()
buffs[i] = elems[i].buffer[:]
}
for peer, elemsForPeer := range elemsByPeer {
if peer.isRunning.Load() { if peer.isRunning.Load() {
peer.StagePacket(elem) peer.StagePackets(elemsForPeer)
elem = nil
peer.SendStagedPackets() peer.SendStagedPackets()
} else {
for _, elem := range *elemsForPeer {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
device.PutOutboundElementsSlice(elemsForPeer)
}
delete(elemsByPeer, peer)
}
if readErr != nil {
if errors.Is(readErr, tun.ErrTooManySegments) {
// TODO: record stat for this
// This will happen if MSS is surprisingly small (< 576)
// coincident with reasonably high throughput.
device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
continue
}
if !device.isClosed() {
if !errors.Is(readErr, os.ErrClosed) {
device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
}
go device.Close()
}
return
} }
} }
} }
func (peer *Peer) StagePacket(elem *QueueOutboundElement) { func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) {
for { for {
select { select {
case peer.queue.staged <- elem: case peer.queue.staged <- elems:
return return
default: default:
} }
select { select {
case tooOld := <-peer.queue.staged: case tooOld := <-peer.queue.staged:
peer.device.PutMessageBuffer(tooOld.buffer) for _, elem := range *tooOld {
peer.device.PutOutboundElement(tooOld) peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
peer.device.PutOutboundElementsSlice(tooOld)
default: default:
} }
} }
@ -305,27 +348,56 @@ top:
} }
for { for {
var elemsOOO *[]*QueueOutboundElement
select { select {
case elem := <-peer.queue.staged: case elems := <-peer.queue.staged:
i := 0
for _, elem := range *elems {
elem.peer = peer elem.peer = peer
elem.nonce = keypair.sendNonce.Add(1) - 1 elem.nonce = keypair.sendNonce.Add(1) - 1
if elem.nonce >= RejectAfterMessages { if elem.nonce >= RejectAfterMessages {
keypair.sendNonce.Store(RejectAfterMessages) keypair.sendNonce.Store(RejectAfterMessages)
peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans if elemsOOO == nil {
goto top elemsOOO = peer.device.GetOutboundElementsSlice()
}
*elemsOOO = append(*elemsOOO, elem)
continue
} else {
(*elems)[i] = elem
i++
} }
elem.keypair = keypair elem.keypair = keypair
elem.Lock() elem.Lock()
}
*elems = (*elems)[:i]
if elemsOOO != nil {
peer.StagePackets(elemsOOO) // XXX: Out of order, but we can't front-load go chans
}
if len(*elems) == 0 {
peer.device.PutOutboundElementsSlice(elems)
goto top
}
// add to parallel and sequential queue // add to parallel and sequential queue
if peer.isRunning.Load() { if peer.isRunning.Load() {
peer.queue.outbound.c <- elem peer.queue.outbound.c <- elems
for _, elem := range *elems {
peer.device.queue.encryption.c <- elem peer.device.queue.encryption.c <- elem
}
} else { } else {
for _, elem := range *elems {
peer.device.PutMessageBuffer(elem.buffer) peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem) peer.device.PutOutboundElement(elem)
} }
peer.device.PutOutboundElementsSlice(elems)
}
if elemsOOO != nil {
goto top
}
default: default:
return return
} }
@ -335,9 +407,12 @@ top:
func (peer *Peer) FlushStagedPackets() { func (peer *Peer) FlushStagedPackets() {
for { for {
select { select {
case elem := <-peer.queue.staged: case elems := <-peer.queue.staged:
for _, elem := range *elems {
peer.device.PutMessageBuffer(elem.buffer) peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem) peer.device.PutOutboundElement(elem)
}
peer.device.PutOutboundElementsSlice(elems)
default: default:
return return
} }
@ -400,12 +475,7 @@ func (device *Device) RoutineEncryption(id int) {
} }
} }
/* Sequentially reads packets from queue and sends to endpoint func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
*
* Obs. Single instance per peer.
* The routine terminates then the outbound queue is closed.
*/
func (peer *Peer) RoutineSequentialSender() {
device := peer.device device := peer.device
defer func() { defer func() {
defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer) defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
@ -413,36 +483,50 @@ func (peer *Peer) RoutineSequentialSender() {
}() }()
device.log.Verbosef("%v - Routine: sequential sender - started", peer) device.log.Verbosef("%v - Routine: sequential sender - started", peer)
for elem := range peer.queue.outbound.c { buffs := make([][]byte, 0, maxBatchSize)
if elem == nil {
for elems := range peer.queue.outbound.c {
buffs = buffs[:0]
if elems == nil {
return return
} }
elem.Lock()
if !peer.isRunning.Load() { if !peer.isRunning.Load() {
// peer has been stopped; return re-usable elems to the shared pool. // peer has been stopped; return re-usable elems to the shared pool.
// This is an optimization only. It is possible for the peer to be stopped // This is an optimization only. It is possible for the peer to be stopped
// immediately after this check, in which case, elem will get processed. // immediately after this check, in which case, elem will get processed.
// The timers and SendBuffer code are resilient to a few stragglers. // The timers and SendBuffers code are resilient to a few stragglers.
// TODO: rework peer shutdown order to ensure // TODO: rework peer shutdown order to ensure
// that we never accidentally keep timers alive longer than necessary. // that we never accidentally keep timers alive longer than necessary.
for _, elem := range *elems {
elem.Lock()
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem) device.PutOutboundElement(elem)
}
continue continue
} }
dataSent := false
for _, elem := range *elems {
elem.Lock()
if len(elem.packet) != MessageKeepaliveSize {
dataSent = true
}
buffs = append(buffs, elem.packet)
}
peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent() peer.timersAnyAuthenticatedPacketSent()
// send message and return buffer to pool err := peer.SendBuffers(buffs)
if dataSent {
err := peer.SendBuffer(elem.packet)
if len(elem.packet) != MessageKeepaliveSize {
peer.timersDataSent() peer.timersDataSent()
} }
for _, elem := range *elems {
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem) device.PutOutboundElement(elem)
}
device.PutOutboundElementsSlice(elems)
if err != nil { if err != nil {
device.log.Errorf("%v - Failed to send data packet: %v", peer, err) device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
continue continue
} }

14
main.go
View File

@ -13,8 +13,8 @@ import (
"os/signal" "os/signal"
"runtime" "runtime"
"strconv" "strconv"
"syscall"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/ipc"
@ -111,7 +111,7 @@ func main() {
// open TUN device (or use supplied fd) // open TUN device (or use supplied fd)
tun, err := func() (tun.Device, error) { tdev, err := func() (tun.Device, error) {
tunFdStr := os.Getenv(ENV_WG_TUN_FD) tunFdStr := os.Getenv(ENV_WG_TUN_FD)
if tunFdStr == "" { if tunFdStr == "" {
return tun.CreateTUN(interfaceName, device.DefaultMTU) return tun.CreateTUN(interfaceName, device.DefaultMTU)
@ -124,7 +124,7 @@ func main() {
return nil, err return nil, err
} }
err = syscall.SetNonblock(int(fd), true) err = unix.SetNonblock(int(fd), true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -134,7 +134,7 @@ func main() {
}() }()
if err == nil { if err == nil {
realInterfaceName, err2 := tun.Name() realInterfaceName, err2 := tdev.Name()
if err2 == nil { if err2 == nil {
interfaceName = realInterfaceName interfaceName = realInterfaceName
} }
@ -196,7 +196,7 @@ func main() {
files[0], // stdin files[0], // stdin
files[1], // stdout files[1], // stdout
files[2], // stderr files[2], // stderr
tun.File(), tdev.File(),
fileUAPI, fileUAPI,
}, },
Dir: ".", Dir: ".",
@ -222,7 +222,7 @@ func main() {
return return
} }
device := device.NewDevice(tun, conn.NewDefaultBind(), logger) device := device.NewDevice(tdev, conn.NewDefaultBind(), logger)
logger.Verbosef("Device started") logger.Verbosef("Device started")
@ -250,7 +250,7 @@ func main() {
// wait for program to terminate // wait for program to terminate
signal.Notify(term, syscall.SIGTERM) signal.Notify(term, unix.SIGTERM)
signal.Notify(term, os.Interrupt) signal.Notify(term, os.Interrupt)
select { select {

View File

@ -9,7 +9,8 @@ import (
"fmt" "fmt"
"os" "os"
"os/signal" "os/signal"
"syscall"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
@ -81,7 +82,7 @@ func main() {
signal.Notify(term, os.Interrupt) signal.Notify(term, os.Interrupt)
signal.Notify(term, os.Kill) signal.Notify(term, os.Kill)
signal.Notify(term, syscall.SIGTERM) signal.Notify(term, windows.SIGTERM)
select { select {
case <-term: case <-term:

60
tun/errors.go Normal file
View File

@ -0,0 +1,60 @@
package tun
import (
"errors"
"fmt"
)
var (
// ErrTooManySegments is returned by Device.Read() when segmentation
// overflows the length of supplied buffers. This error should not cause
// reads to cease.
ErrTooManySegments = errors.New("too many segments")
)
type errorBatch []error
// ErrorBatch takes a possibly nil or empty list of errors, and if the list is
// non-nil returns an error type that wraps all of the errors. Expected usage is
// to append to an []errors and coerce the set to an error using this method.
func ErrorBatch(errs []error) error {
if len(errs) == 0 {
return nil
}
return errorBatch(errs)
}
func (e errorBatch) Error() string {
if len(e) == 0 {
return ""
}
if len(e) == 1 {
return e[0].Error()
}
return fmt.Sprintf("batch operation: %v (and %d more errors)", e[0], len(e)-1)
}
func (e errorBatch) Is(target error) bool {
for _, err := range e {
if errors.Is(err, target) {
return true
}
}
return false
}
func (e errorBatch) As(target interface{}) bool {
for _, err := range e {
if errors.As(err, target) {
return true
}
}
return false
}
func (e errorBatch) Unwrap() error {
if len(e) == 0 {
return nil
}
return e[0]
}

View File

@ -19,6 +19,7 @@ import (
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"syscall"
"time" "time"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
@ -113,19 +114,25 @@ func (tun *netTun) Events() <-chan tun.Event {
return tun.events return tun.events
} }
func (tun *netTun) Read(buf []byte, offset int) (int, error) { func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
view, ok := <-tun.incomingPacket view, ok := <-tun.incomingPacket
if !ok { if !ok {
return 0, os.ErrClosed return 0, os.ErrClosed
} }
return view.Read(buf[offset:]) n, err := view.Read(buf[0][offset:])
if err != nil {
return 0, err
}
sizes[0] = n
return 1, nil
} }
func (tun *netTun) Write(buf []byte, offset int) (int, error) { func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
for _, buf := range buf {
packet := buf[offset:] packet := buf[offset:]
if len(packet) == 0 { if len(packet) == 0 {
return 0, nil continue
} }
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)}) pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)})
@ -134,8 +141,10 @@ func (tun *netTun) Write(buf []byte, offset int) (int, error) {
tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
case 6: case 6:
tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
default:
return 0, syscall.EAFNOSUPPORT
}
} }
return len(buf), nil return len(buf), nil
} }
@ -151,10 +160,6 @@ func (tun *netTun) WriteNotify() {
tun.incomingPacket <- view tun.incomingPacket <- view
} }
func (tun *netTun) Flush() error {
return nil
}
func (tun *netTun) Close() error { func (tun *netTun) Close() error {
tun.stack.RemoveNIC(1) tun.stack.RemoveNIC(1)
@ -175,6 +180,10 @@ func (tun *netTun) MTU() (int, error) {
return tun.mtu, nil return tun.mtu, nil
} }
func (tun *netTun) BatchSize() int {
return 1
}
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
var protoNumber tcpip.NetworkProtocolNumber var protoNumber tcpip.NetworkProtocolNumber
if endpoint.Addr().Is4() { if endpoint.Addr().Is4() {

View File

@ -18,12 +18,36 @@ const (
) )
type Device interface { type Device interface {
File() *os.File // returns the file descriptor of the device // File returns the file descriptor of the device.
Read([]byte, int) (int, error) // read a packet from the device (without any additional headers) File() *os.File
Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers)
Flush() error // flush all previous writes to the device // Read one or more packets from the Device (without any additional headers).
MTU() (int, error) // returns the MTU of the device // On a successful read it returns the number of packets read, and sets
Name() (string, error) // fetches and returns the current name // packet lengths within the sizes slice. len(sizes) must be >= len(buffs).
Events() <-chan Event // returns a constant channel of events related to the device // A nonzero offset can be used to instruct the Device on where to begin
Close() error // stops the device and closes the event channel // reading into each element of the buffs slice.
Read(buffs [][]byte, sizes []int, offset int) (n int, err error)
// Write one or more packets to the device (without any additional headers).
// On a successful write it returns the number of packets written. A nonzero
// offset can be used to instruct the Device on where to begin writing from
// each packet contained within the buffs slice.
Write(buffs [][]byte, offset int) (int, error)
// MTU returns the MTU of the Device.
MTU() (int, error)
// Name returns the current name of the Device.
Name() (string, error)
// Events returns a channel of type Event, which is fed Device events.
Events() <-chan Event
// Close stops the Device and closes the Event channel.
Close() error
// BatchSize returns the preferred/max number of packets that can be read or
// written in a single read/write call. BatchSize must not change over the
// lifetime of a Device.
BatchSize() int
} }

View File

@ -8,6 +8,7 @@ package tun
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"os" "os"
"sync" "sync"
@ -15,7 +16,6 @@ import (
"time" "time"
"unsafe" "unsafe"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@ -33,7 +33,7 @@ type NativeTun struct {
func retryInterfaceByIndex(index int) (iface *net.Interface, err error) { func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
for i := 0; i < 20; i++ { for i := 0; i < 20; i++ {
iface, err = net.InterfaceByIndex(index) iface, err = net.InterfaceByIndex(index)
if err != nil && errors.Is(err, syscall.ENOMEM) { if err != nil && errors.Is(err, unix.ENOMEM) {
time.Sleep(time.Duration(i) * time.Second / 3) time.Sleep(time.Duration(i) * time.Second / 3)
continue continue
} }
@ -55,7 +55,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
retry: retry:
n, err := unix.Read(tun.routeSocket, data) n, err := unix.Read(tun.routeSocket, data)
if err != nil { if err != nil {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR { if errno, ok := err.(unix.Errno); ok && errno == unix.EINTR {
goto retry goto retry
} }
tun.errors <- err tun.errors <- err
@ -217,45 +217,46 @@ func (tun *NativeTun) Events() <-chan Event {
return tun.events return tun.events
} }
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
// TODO: the BSDs look very similar in Read() and Write(). They should be
// collapsed, with platform-specific files containing the varying parts of
// their implementations.
select { select {
case err := <-tun.errors: case err := <-tun.errors:
return 0, err return 0, err
default: default:
buff := buff[offset-4:] buff := buffs[0][offset-4:]
n, err := tun.tunFile.Read(buff[:]) n, err := tun.tunFile.Read(buff[:])
if n < 4 { if n < 4 {
return 0, err return 0, err
} }
return n - 4, err sizes[0] = n - 4
return 1, err
} }
} }
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) {
// reserve space for header if offset < 4 {
return 0, io.ErrShortBuffer
buff = buff[offset-4:]
// add packet information header
buff[0] = 0x00
buff[1] = 0x00
buff[2] = 0x00
if buff[4]>>4 == ipv6.Version {
buff[3] = unix.AF_INET6
} else {
buff[3] = unix.AF_INET
} }
for i, buf := range buffs {
// write buf = buf[offset-4:]
buf[0] = 0x00
return tun.tunFile.Write(buff) buf[1] = 0x00
buf[2] = 0x00
switch buf[4] >> 4 {
case 4:
buf[3] = unix.AF_INET
case 6:
buf[3] = unix.AF_INET6
default:
return i, unix.EAFNOSUPPORT
} }
if _, err := tun.tunFile.Write(buf); err != nil {
func (tun *NativeTun) Flush() error { return i, err
// TODO: can flushing be implemented by buffering and using sendmmsg? }
return nil }
return len(buffs), nil
} }
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
@ -318,6 +319,10 @@ func (tun *NativeTun) MTU() (int, error) {
return int(ifr.MTU), nil return int(ifr.MTU), nil
} }
func (tun *NativeTun) BatchSize() int {
return 1
}
func socketCloexec(family, sotype, proto int) (fd int, err error) { func socketCloexec(family, sotype, proto int) (fd int, err error) {
// See go/src/net/sys_cloexec.go for background. // See go/src/net/sys_cloexec.go for background.
syscall.ForkLock.RLock() syscall.ForkLock.RLock()

View File

@ -333,27 +333,29 @@ func (tun *NativeTun) Events() <-chan Event {
return tun.events return tun.events
} }
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
select { select {
case err := <-tun.errors: case err := <-tun.errors:
return 0, err return 0, err
default: default:
buff := buff[offset-4:] buff := buffs[0][offset-4:]
n, err := tun.tunFile.Read(buff[:]) n, err := tun.tunFile.Read(buff[:])
if n < 4 { if n < 4 {
return 0, err return 0, err
} }
return n - 4, err sizes[0] = n - 4
return 1, err
} }
} }
func (tun *NativeTun) Write(buf []byte, offset int) (int, error) { func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) {
if offset < 4 { if offset < 4 {
return 0, io.ErrShortBuffer return 0, io.ErrShortBuffer
} }
for i, buf := range buffs {
buf = buf[offset-4:] buf = buf[offset-4:]
if len(buf) < 5 { if len(buf) < 5 {
return 0, io.ErrShortBuffer return i, io.ErrShortBuffer
} }
buf[0] = 0x00 buf[0] = 0x00
buf[1] = 0x00 buf[1] = 0x00
@ -364,14 +366,13 @@ func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
case 6: case 6:
buf[3] = unix.AF_INET6 buf[3] = unix.AF_INET6
default: default:
return 0, unix.EAFNOSUPPORT return i, unix.EAFNOSUPPORT
} }
return tun.tunFile.Write(buf) if _, err := tun.tunFile.Write(buf); err != nil {
return i, err
} }
}
func (tun *NativeTun) Flush() error { return len(buffs), nil
// TODO: can flushing be implemented by buffering and using sendmmsg?
return nil
} }
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
@ -428,3 +429,7 @@ func (tun *NativeTun) MTU() (int, error) {
} }
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
} }
func (tun *NativeTun) BatchSize() int {
return 1
}

View File

@ -323,12 +323,13 @@ func (tun *NativeTun) nameSlow() (string, error) {
return unix.ByteSliceToString(ifr[:]), nil return unix.ByteSliceToString(ifr[:]), nil
} }
func (tun *NativeTun) Write(buf []byte, offset int) (int, error) { func (tun *NativeTun) Write(buffs [][]byte, offset int) (n int, err error) {
var buf []byte
if tun.nopi { if tun.nopi {
buf = buf[offset:] buf = buffs[0][offset:]
} else { } else {
// reserve space for header // reserve space for header
buf = buf[offset-4:] buf = buffs[0][offset-4:]
// add packet information header // add packet information header
buf[0] = 0x00 buf[0] = 0x00
@ -342,34 +343,36 @@ func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
} }
} }
n, err := tun.tunFile.Write(buf) _, err = tun.tunFile.Write(buf)
if errors.Is(err, syscall.EBADFD) { if errors.Is(err, syscall.EBADFD) {
err = os.ErrClosed err = os.ErrClosed
} else if err == nil {
n = 1
} }
return n, err return n, err
} }
func (tun *NativeTun) Flush() error { func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (n int, err error) {
// TODO: can flushing be implemented by buffering and using sendmmsg?
return nil
}
func (tun *NativeTun) Read(buf []byte, offset int) (n int, err error) {
select { select {
case err = <-tun.errors: case err = <-tun.errors:
default: default:
if tun.nopi { if tun.nopi {
n, err = tun.tunFile.Read(buf[offset:]) sizes[0], err = tun.tunFile.Read(buffs[0][offset:])
if err == nil {
n = 1
}
} else { } else {
buff := buf[offset-4:] buff := buffs[0][offset-4:]
n, err = tun.tunFile.Read(buff[:]) sizes[0], err = tun.tunFile.Read(buff[:])
if errors.Is(err, syscall.EBADFD) { if errors.Is(err, syscall.EBADFD) {
err = os.ErrClosed err = os.ErrClosed
} else if err == nil {
n = 1
} }
if n < 4 { if sizes[0] < 4 {
n = 0 sizes[0] = 0
} else { } else {
n -= 4 sizes[0] -= 4
} }
} }
} }
@ -399,6 +402,10 @@ func (tun *NativeTun) Close() error {
return err2 return err2
} }
func (tun *NativeTun) BatchSize() int {
return 1
}
func CreateTUN(name string, mtu int) (Device, error) { func CreateTUN(name string, mtu int) (Device, error) {
nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0) nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
if err != nil { if err != nil {

View File

@ -8,13 +8,13 @@ package tun
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"os" "os"
"sync" "sync"
"syscall" "syscall"
"unsafe" "unsafe"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@ -204,45 +204,43 @@ func (tun *NativeTun) Events() <-chan Event {
return tun.events return tun.events
} }
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
select { select {
case err := <-tun.errors: case err := <-tun.errors:
return 0, err return 0, err
default: default:
buff := buff[offset-4:] buff := buffs[0][offset-4:]
n, err := tun.tunFile.Read(buff[:]) n, err := tun.tunFile.Read(buff[:])
if n < 4 { if n < 4 {
return 0, err return 0, err
} }
return n - 4, err sizes[0] = n - 4
return 1, err
} }
} }
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) {
// reserve space for header if offset < 4 {
return 0, io.ErrShortBuffer
buff = buff[offset-4:]
// add packet information header
buff[0] = 0x00
buff[1] = 0x00
buff[2] = 0x00
if buff[4]>>4 == ipv6.Version {
buff[3] = unix.AF_INET6
} else {
buff[3] = unix.AF_INET
} }
for i, buf := range buffs {
// write buf = buf[offset-4:]
buf[0] = 0x00
return tun.tunFile.Write(buff) buf[1] = 0x00
buf[2] = 0x00
switch buf[4] >> 4 {
case 4:
buf[3] = unix.AF_INET
case 6:
buf[3] = unix.AF_INET6
default:
return i, unix.EAFNOSUPPORT
} }
if _, err := tun.tunFile.Write(buf); err != nil {
func (tun *NativeTun) Flush() error { return i, err
// TODO: can flushing be implemented by buffering and using sendmmsg? }
return nil }
return len(buffs), nil
} }
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
@ -329,3 +327,7 @@ func (tun *NativeTun) MTU() (int, error) {
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
} }
func (tun *NativeTun) BatchSize() int {
return 1
}

View File

@ -15,7 +15,6 @@ import (
_ "unsafe" _ "unsafe"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.zx2c4.com/wintun" "golang.zx2c4.com/wintun"
) )
@ -44,6 +43,7 @@ type NativeTun struct {
closeOnce sync.Once closeOnce sync.Once
close atomic.Bool close atomic.Bool
forcedMTU int forcedMTU int
outSizes []int
} }
var ( var (
@ -134,9 +134,14 @@ func (tun *NativeTun) ForceMTU(mtu int) {
} }
} }
func (tun *NativeTun) BatchSize() int {
// TODO: implement batching with wintun
return 1
}
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
tun.running.Add(1) tun.running.Add(1)
defer tun.running.Done() defer tun.running.Done()
retry: retry:
@ -153,10 +158,11 @@ retry:
switch err { switch err {
case nil: case nil:
packetSize := len(packet) packetSize := len(packet)
copy(buff[offset:], packet) copy(buffs[0][offset:], packet)
sizes[0] = packetSize
tun.session.ReleaseReceivePacket(packet) tun.session.ReleaseReceivePacket(packet)
tun.rate.update(uint64(packetSize)) tun.rate.update(uint64(packetSize))
return packetSize, nil return 1, nil
case windows.ERROR_NO_MORE_ITEMS: case windows.ERROR_NO_MORE_ITEMS:
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
windows.WaitForSingleObject(tun.readWait, windows.INFINITE) windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
@ -173,33 +179,33 @@ retry:
} }
} }
func (tun *NativeTun) Flush() error { func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) {
return nil
}
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
tun.running.Add(1) tun.running.Add(1)
defer tun.running.Done() defer tun.running.Done()
if tun.close.Load() { if tun.close.Load() {
return 0, os.ErrClosed return 0, os.ErrClosed
} }
for i, buff := range buffs {
packetSize := len(buff) - offset packetSize := len(buff) - offset
tun.rate.update(uint64(packetSize)) tun.rate.update(uint64(packetSize))
packet, err := tun.session.AllocateSendPacket(packetSize) packet, err := tun.session.AllocateSendPacket(packetSize)
if err == nil { switch err {
case nil:
// TODO: Explore options to eliminate this copy.
copy(packet, buff[offset:]) copy(packet, buff[offset:])
tun.session.SendPacket(packet) tun.session.SendPacket(packet)
return packetSize, nil continue
}
switch err {
case windows.ERROR_HANDLE_EOF: case windows.ERROR_HANDLE_EOF:
return 0, os.ErrClosed return i, os.ErrClosed
case windows.ERROR_BUFFER_OVERFLOW: case windows.ERROR_BUFFER_OVERFLOW:
return 0, nil // Dropping when ring is full. continue // Dropping when ring is full.
default:
return i, fmt.Errorf("Write failed: %w", err)
} }
return 0, fmt.Errorf("Write failed: %w", err) }
return len(buffs), nil
} }
// LUID returns Windows interface instance ID. // LUID returns Windows interface instance ID.

View File

@ -110,35 +110,42 @@ type chTun struct {
func (t *chTun) File() *os.File { return nil } func (t *chTun) File() *os.File { return nil }
func (t *chTun) Read(data []byte, offset int) (int, error) { func (t *chTun) Read(packets [][]byte, sizes []int, offset int) (int, error) {
select { select {
case <-t.c.closed: case <-t.c.closed:
return 0, os.ErrClosed return 0, os.ErrClosed
case msg := <-t.c.Outbound: case msg := <-t.c.Outbound:
return copy(data[offset:], msg), nil n := copy(packets[0][offset:], msg)
sizes[0] = n
return 1, nil
} }
} }
// Write is called by the wireguard device to deliver a packet for routing. // Write is called by the wireguard device to deliver a packet for routing.
func (t *chTun) Write(data []byte, offset int) (int, error) { func (t *chTun) Write(packets [][]byte, offset int) (int, error) {
if offset == -1 { if offset == -1 {
close(t.c.closed) close(t.c.closed)
close(t.c.events) close(t.c.events)
return 0, io.EOF return 0, io.EOF
} }
for i, data := range packets {
msg := make([]byte, len(data)-offset) msg := make([]byte, len(data)-offset)
copy(msg, data[offset:]) copy(msg, data[offset:])
select { select {
case <-t.c.closed: case <-t.c.closed:
return 0, os.ErrClosed return i, os.ErrClosed
case t.c.Inbound <- msg: case t.c.Inbound <- msg:
return len(data) - offset, nil
} }
} }
return len(packets), nil
}
func (t *chTun) BatchSize() int {
return 1
}
const DefaultMTU = 1420 const DefaultMTU = 1420
func (t *chTun) Flush() error { return nil }
func (t *chTun) MTU() (int, error) { return DefaultMTU, nil } func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
func (t *chTun) Name() (string, error) { return "loopbackTun1", nil } func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
func (t *chTun) Events() <-chan tun.Event { return t.c.events } func (t *chTun) Events() <-chan tun.Event { return t.c.events }