wireguard-go/device/receive.go
Jordan Whited 3bb8fec7e4 conn, device, tun: implement vectorized I/O plumbing
Accept packet vectors for reading and writing in the tun.Device and
conn.Bind interfaces, so that the internal plumbing between these
interfaces now passes a vector of packets. Vectors move untouched
between these interfaces, i.e. if 128 packets are received from
conn.Bind.Read(), 128 packets are passed to tun.Device.Write(). There is
no internal buffering.

Currently, existing implementations are only adjusted to have vectors
of length one. Subsequent patches will improve that.

Also, as a related fixup, use the unix and windows packages rather than
the syscall package when possible.

Co-authored-by: James Tucker <james@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-03-10 14:52:13 +01:00

526 lines
13 KiB
Go

/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"bytes"
"encoding/binary"
"errors"
"net"
"sync"
"time"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/conn"
)
// QueueHandshakeElement is a received handshake-related message (initiation,
// response or cookie reply) that RoutineReceiveIncoming places on the
// device handshake queue for RoutineHandshake to process.
type QueueHandshakeElement struct {
	// msgType is the message type, decoded little-endian from the first
	// four bytes of packet by RoutineReceiveIncoming.
	msgType uint32
	// packet is the received message; it is a slice of buffer.
	packet []byte
	// endpoint is the endpoint reported by the receive function for this
	// datagram.
	endpoint conn.Endpoint
	// buffer is the backing array of packet. RoutineHandshake hands it back
	// via PutMessageBuffer once the element has been processed.
	buffer *[MaxMessageSize]byte
}
// QueueInboundElement is a received transport message on its way through
// decryption and sequential delivery.
//
// The embedded mutex is used as a completion signal rather than for mutual
// exclusion: RoutineReceiveIncoming locks it when it fills in the element,
// RoutineDecryption unlocks it once the decryption attempt has finished, and
// RoutineSequentialReceiver locks it again, which blocks until that unlock.
type QueueInboundElement struct {
	sync.Mutex
	// buffer is the backing array that packet was sliced from.
	buffer *[MaxMessageSize]byte
	// packet initially holds the whole received transport message.
	// RoutineDecryption replaces it with the decrypted content, or with nil
	// if decryption failed.
	packet []byte
	// counter is the message counter; RoutineReceiveIncoming zeroes it and
	// RoutineDecryption fills it in from the message.
	counter uint64
	// keypair is the keypair found via the receiver index of the message.
	keypair *Keypair
	// endpoint is the endpoint reported by the receive function for this
	// datagram.
	endpoint conn.Endpoint
}
// clearPointers nils out every field of elem that holds a pointer
// (buffer, packet, keypair and endpoint).
// Dropping these references lets the garbage collector reclaim the objects
// they pointed to instead of keeping them alive through a stale element,
// and limits what a use-after-free bug could reach.
// The counter and the embedded mutex are left untouched.
func (elem *QueueInboundElement) clearPointers() {
	elem.buffer, elem.packet = nil, nil
	elem.keypair, elem.endpoint = nil, nil
}
/* Called after a new authenticated message has been received.
 *
 * If this side was the initiator of the current keypair and that keypair
 * is older than RejectAfterTime-KeepaliveTimeout-RekeyTimeout, a new
 * handshake initiation is sent. The sentLastMinuteHandshake flag ensures
 * this routine triggers it only once while the flag stays set.
 *
 * NOTE: Not thread safe, but called by sequential receiver!
 */
func (peer *Peer) keepKeyFreshReceiving() {
	// a last-minute handshake has already been sent
	if peer.timers.sentLastMinuteHandshake.Load() {
		return
	}

	// only the initiator of an existing keypair renews it from here
	current := peer.keypairs.Current()
	if current == nil || !current.isInitiator {
		return
	}

	// keypair is still young enough
	threshold := RejectAfterTime - KeepaliveTimeout - RekeyTimeout
	if time.Since(current.created) <= threshold {
		return
	}

	peer.timers.sentLastMinuteHandshake.Store(true)
	peer.SendHandshakeInitiation(false)
}
/* Receives incoming datagrams for the device
 *
 * Every time the bind is updated a new routine is started for
 * IPv4 and IPv6 (separately)
 *
 * maxBatchSize is the number of message buffers handed to recv per call;
 * recv is the receive function to read from.
 *
 * Transport messages are grouped per peer and queued for decryption and
 * for the peer's sequential receiver; handshake-related messages are
 * queued for the handshake workers. Anything else is dropped.
 *
 * The routine returns when recv reports net.ErrClosed, a net.Error that is
 * not temporary, or more than 10 consecutive other failures.
 */
func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
	recvName := recv.PrettyName()
	defer func() {
		device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
		device.queue.decryption.wg.Done()
		device.queue.handshake.wg.Done()
		device.net.stopping.Done()
	}()

	device.log.Verbosef("Routine: receive incoming %s - started", recvName)

	// receive datagrams until conn is closed

	var (
		// buffsArrs[i] is the backing array for buffs[i]; a slot is replaced
		// by a fresh buffer whenever its current one is handed to a queue.
		buffsArrs = make([]*[MaxMessageSize]byte, maxBatchSize)
		buffs = make([][]byte, maxBatchSize)
		err error
		sizes = make([]int, maxBatchSize)
		count int
		endpoints = make([]conn.Endpoint, maxBatchSize)
		// deathSpiral counts consecutive failed calls to recv.
		deathSpiral int
		// elemsByPeer collects the transport elements of the current batch
		// per peer; it is emptied again at the end of every batch.
		elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize)
	)

	for i := range buffsArrs {
		buffsArrs[i] = device.GetMessageBuffer()
		buffs[i] = buffsArrs[i][:]
	}

	// on exit, return the buffers this routine still holds
	defer func() {
		for i := 0; i < maxBatchSize; i++ {
			if buffsArrs[i] != nil {
				device.PutMessageBuffer(buffsArrs[i])
			}
		}
	}()

	for {
		count, err = recv(buffs, sizes, endpoints)
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				return
			}
			device.log.Verbosef("Failed to receive %s packet: %v", recvName, err)
			if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
				return
			}
			// back off and retry, giving up after 10 consecutive failures
			if deathSpiral < 10 {
				deathSpiral++
				time.Sleep(time.Second / 3)
				continue
			}
			return
		}
		deathSpiral = 0

		// handle each packet in the batch
		for i, size := range sizes[:count] {
			// drop anything shorter than the minimum message size
			if size < MinMessageSize {
				continue
			}

			// the message type is the first four bytes, little-endian
			packet := buffsArrs[i][:size]
			msgType := binary.LittleEndian.Uint32(packet[:4])

			switch msgType {

			// check if transport
			case MessageTransportType:

				// check size
				if len(packet) < MessageTransportSize {
					continue
				}

				// lookup key pair
				receiver := binary.LittleEndian.Uint32(
					packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
				)
				value := device.indexTable.Lookup(receiver)
				keypair := value.keypair
				if keypair == nil {
					continue
				}

				// check keypair expiry
				if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
					continue
				}

				// create work element
				peer := value.peer
				elem := device.GetInboundElement()
				elem.packet = packet
				elem.buffer = buffsArrs[i]
				elem.keypair = keypair
				elem.endpoint = endpoints[i]
				elem.counter = 0
				// start from a fresh mutex and hold it; RoutineDecryption
				// unlocks it when the decryption attempt is done
				elem.Mutex = sync.Mutex{}
				elem.Lock()

				// append to this peer's slice for the current batch
				elemsForPeer, ok := elemsByPeer[peer]
				if !ok {
					elemsForPeer = device.GetInboundElementsSlice()
					elemsByPeer[peer] = elemsForPeer
				}
				*elemsForPeer = append(*elemsForPeer, elem)

				// the element now owns the old buffer; use a new one for
				// this slot in the next recv call
				buffsArrs[i] = device.GetMessageBuffer()
				buffs[i] = buffsArrs[i][:]
				continue

			// otherwise it is a fixed size & handshake related packet

			case MessageInitiationType:
				if len(packet) != MessageInitiationSize {
					continue
				}

			case MessageResponseType:
				if len(packet) != MessageResponseSize {
					continue
				}

			case MessageCookieReplyType:
				if len(packet) != MessageCookieReplySize {
					continue
				}

			default:
				device.log.Verbosef("Received message with unknown type")
				continue
			}

			// non-blocking hand-off to the handshake queue: if the send
			// would block, the message is dropped and the buffer stays
			// in this slot for reuse
			select {
			case device.queue.handshake.c <- QueueHandshakeElement{
				msgType: msgType,
				buffer: buffsArrs[i],
				packet: packet,
				endpoint: endpoints[i],
			}:
				buffsArrs[i] = device.GetMessageBuffer()
				buffs[i] = buffsArrs[i][:]
			default:
			}
		}

		// dispatch the transport elements collected for each peer
		for peer, elems := range elemsByPeer {
			if peer.isRunning.Load() {
				// the slice goes to the peer's inbound queue, and each
				// element additionally to the decryption queue
				peer.queue.inbound.c <- elems
				for _, elem := range *elems {
					device.queue.decryption.c <- elem
				}
			} else {
				// peer is not running: give everything back
				for _, elem := range *elems {
					device.PutMessageBuffer(elem.buffer)
					device.PutInboundElement(elem)
				}
				device.PutInboundElementsSlice(elems)
			}
			delete(elemsByPeer, peer)
		}
	}
}
// RoutineDecryption is a decryption worker. It takes elements from the
// device decryption queue until that channel is closed, decrypts each
// transport message in place and unlocks the element afterwards.
//
// On success elem.packet holds the decrypted content; on failure it is nil.
// In both cases elem.counter is set to the counter carried by the message.
// id only identifies the worker in log output.
func (device *Device) RoutineDecryption(id int) {
	var nonce [chacha20poly1305.NonceSize]byte

	defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
	device.log.Verbosef("Routine: decryption worker %d - started", id)

	for elem := range device.queue.decryption.c {
		// locate the counter and the encrypted content inside the message
		counterField := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
		ciphertext := elem.packet[MessageTransportOffsetContent:]

		// the nonce carries the little-endian counter in bytes 4..11
		elem.counter = binary.LittleEndian.Uint64(counterField)
		binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)

		// decrypt over the ciphertext's own storage
		plaintext, err := elem.keypair.receive.Open(ciphertext[:0], nonce[:], ciphertext, nil)
		if err != nil {
			plaintext = nil
		}
		elem.packet = plaintext

		// release to consumer
		elem.Unlock()
	}
}
/* Handshake worker: processes handshake-related messages from the device
 * handshake queue until that channel is closed.
 *
 * Every element's buffer is returned with PutMessageBuffer once the element
 * has been handled, whatever the outcome. id only identifies the worker in
 * log output.
 */
func (device *Device) RoutineHandshake(id int) {
	defer func() {
		device.log.Verbosef("Routine: handshake worker %d - stopped", id)
		device.queue.encryption.wg.Done()
	}()
	device.log.Verbosef("Routine: handshake worker %d - started", id)

	for elem := range device.queue.handshake.c {
		device.consumeHandshakeElement(&elem)
		device.PutMessageBuffer(elem.buffer)
	}
}

/* Processes a single handshake queue element.
 *
 * Cookie replies are consumed directly. Initiations and responses must pass
 * the MAC1 check and, while the device is under load, the MAC2 check and
 * the rate limiter, before their content is handled. The caller remains
 * responsible for releasing elem.buffer.
 */
func (device *Device) consumeHandshakeElement(elem *QueueHandshakeElement) {
	// stage one: cookie handling, MAC checks and rate limiting
	switch elem.msgType {
	case MessageCookieReplyType:
		// unmarshal packet
		var reply MessageCookieReply
		if err := binary.Read(bytes.NewReader(elem.packet), binary.LittleEndian, &reply); err != nil {
			device.log.Verbosef("Failed to decode cookie reply")
			return
		}

		// find the peer the reply is addressed to
		entry := device.indexTable.Lookup(reply.Receiver)
		if entry.peer == nil {
			return
		}

		// consume reply, but only for a running peer
		peer := entry.peer
		if peer.isRunning.Load() {
			device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
			if !peer.cookieGenerator.ConsumeReply(&reply) {
				device.log.Verbosef("Could not decrypt invalid cookie response")
			}
		}
		return

	case MessageInitiationType, MessageResponseType:
		// MAC1 must always be valid
		if !device.cookieChecker.CheckMAC1(elem.packet) {
			device.log.Verbosef("Received packet with invalid mac1")
			return
		}

		// endpoints destination address is the source of the datagram
		if device.IsUnderLoad() {
			// without a valid MAC2, answer with a cookie and stop
			if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
				device.SendHandshakeCookie(elem)
				return
			}

			// apply the per-source rate limit
			if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
				return
			}
		}

	default:
		device.log.Errorf("Invalid packet ended up in the handshake queue")
		return
	}

	// stage two: initiation / response content
	switch elem.msgType {
	case MessageInitiationType:
		// unmarshal
		var msg MessageInitiation
		if err := binary.Read(bytes.NewReader(elem.packet), binary.LittleEndian, &msg); err != nil {
			device.log.Errorf("Failed to decode initiation message")
			return
		}

		// consume initiation
		peer := device.ConsumeMessageInitiation(&msg)
		if peer == nil {
			device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
			return
		}

		// update timers
		peer.timersAnyAuthenticatedPacketTraversal()
		peer.timersAnyAuthenticatedPacketReceived()

		// update endpoint
		peer.SetEndpointFromPacket(elem.endpoint)

		device.log.Verbosef("%v - Received handshake initiation", peer)
		peer.rxBytes.Add(uint64(len(elem.packet)))

		peer.SendHandshakeResponse()

	case MessageResponseType:
		// unmarshal
		var msg MessageResponse
		if err := binary.Read(bytes.NewReader(elem.packet), binary.LittleEndian, &msg); err != nil {
			device.log.Errorf("Failed to decode response message")
			return
		}

		// consume response
		peer := device.ConsumeMessageResponse(&msg)
		if peer == nil {
			device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
			return
		}

		// update endpoint
		peer.SetEndpointFromPacket(elem.endpoint)

		device.log.Verbosef("%v - Received handshake response", peer)
		peer.rxBytes.Add(uint64(len(elem.packet)))

		// update timers
		peer.timersAnyAuthenticatedPacketTraversal()
		peer.timersAnyAuthenticatedPacketReceived()

		// derive keypair
		if err := peer.BeginSymmetricSession(); err != nil {
			device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
			return
		}

		peer.timersSessionDerived()
		peer.timersHandshakeComplete()
		peer.SendKeepalive()
	}
}
// RoutineSequentialReceiver consumes batches of inbound elements for peer
// in the order they were queued, waits for each element's decryption to
// finish, validates it, and writes the accepted packets of a batch to the
// TUN device in a single call.
//
// maxBatchSize is used as the initial capacity of the slice of buffers
// passed to the TUN device. The routine returns when the peer's inbound
// channel is closed or a nil batch is received.
//
// An element is dropped (not written) when decryption failed, its counter
// is rejected by the replay filter, it is a keepalive (empty packet), its
// IP header is malformed or inconsistent with the packet length, or its
// source address is not routed to this peer by the allowed-IPs table.
// Regardless of outcome, every element, its buffer and the batch slice are
// handed back via the device's Put* functions after the batch is handled.
func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
	device := peer.device
	defer func() {
		device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
		peer.stopping.Done()
	}()
	device.log.Verbosef("%v - Routine: sequential receiver - started", peer)

	buffs := make([][]byte, 0, maxBatchSize)

	for elems := range peer.queue.inbound.c {
		if elems == nil {
			return
		}
		for _, elem := range *elems {
			// blocks until RoutineDecryption has unlocked the element
			elem.Lock()
			if elem.packet == nil {
				// decryption failed
				continue
			}

			if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
				continue
			}

			peer.SetEndpointFromPacket(elem.endpoint)
			if peer.ReceivedWithKeypair(elem.keypair) {
				peer.timersHandshakeComplete()
				peer.SendStagedPackets()
			}
			peer.keepKeyFreshReceiving()
			peer.timersAnyAuthenticatedPacketTraversal()
			peer.timersAnyAuthenticatedPacketReceived()
			peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))

			if len(elem.packet) == 0 {
				device.log.Verbosef("%v - Receiving keepalive packet", peer)
				continue
			}
			peer.timersDataReceived()

			// the IP version is the high nibble of the first byte
			switch elem.packet[0] >> 4 {
			case 4:
				if len(elem.packet) < ipv4.HeaderLen {
					continue
				}
				field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
				length := binary.BigEndian.Uint16(field)
				if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
					continue
				}
				// trim to the length the header claims
				elem.packet = elem.packet[:length]
				src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
				if device.allowedips.Lookup(src) != peer {
					device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
					continue
				}

			case 6:
				if len(elem.packet) < ipv6.HeaderLen {
					continue
				}
				field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
				// Widen to int before adding the header length. Done in
				// uint16, a payload length above 65495 would wrap around to
				// a small value, slip past the bounds check below and
				// truncate the packet to a bogus length.
				length := int(binary.BigEndian.Uint16(field)) + ipv6.HeaderLen
				if length > len(elem.packet) {
					continue
				}
				// trim to the length the header claims
				elem.packet = elem.packet[:length]
				src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
				if device.allowedips.Lookup(src) != peer {
					device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
					continue
				}

			default:
				device.log.Verbosef("Packet with invalid IP version from %v", peer)
				continue
			}

			// Pass the buffer from its start; the TUN write below is told
			// that the packet begins at MessageTransportOffsetContent.
			buffs = append(buffs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
		}

		if len(buffs) > 0 {
			_, err := device.tun.device.Write(buffs, MessageTransportOffsetContent)
			// write errors are only reported while the device is open
			if err != nil && !device.isClosed() {
				device.log.Errorf("Failed to write packets to TUN device: %v", err)
			}
		}

		// the batch is done: hand back buffers, elements and the slice
		for _, elem := range *elems {
			device.PutMessageBuffer(elem.buffer)
			device.PutInboundElement(elem)
		}
		buffs = buffs[:0]
		device.PutInboundElementsSlice(elems)
	}
}