wireguard-go/device/receive.go

491 lines
12 KiB
Go
Raw Normal View History

2019-01-02 01:55:51 +01:00
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
2019-03-03 04:04:41 +01:00
package device
2017-07-01 23:29:22 +02:00
import (
"bytes"
"encoding/binary"
"errors"
2017-07-01 23:29:22 +02:00
"net"
"sync"
"sync/atomic"
"time"
2019-05-14 09:09:52 +02:00
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/conn"
2017-07-01 23:29:22 +02:00
)
type QueueHandshakeElement struct {
msgType uint32
packet []byte
endpoint conn.Endpoint
buffer *[MaxMessageSize]byte
2017-07-01 23:29:22 +02:00
}
type QueueInboundElement struct {
sync.Mutex
2017-11-14 16:27:53 +01:00
buffer *[MaxMessageSize]byte
packet []byte
counter uint64
2018-05-13 18:23:40 +02:00
keypair *Keypair
endpoint conn.Endpoint
2017-07-01 23:29:22 +02:00
}
// clearPointers resets every pointer-carrying field of elem to nil.
// Dropping these references lets the garbage collector reclaim the
// objects they point to instead of keeping them alive through elem,
// and limits the collateral damage should elem be used after it has
// been released.
func (elem *QueueInboundElement) clearPointers() {
	elem.buffer, elem.packet = nil, nil
	elem.keypair, elem.endpoint = nil, nil
}
/* Called when a new authenticated message has been received
*
* NOTE: Not thread safe, but called by sequential receiver!
*/
func (peer *Peer) keepKeyFreshReceiving() {
2018-05-20 06:50:07 +02:00
if peer.timers.sentLastMinuteHandshake.Get() {
return
}
2018-05-13 23:14:43 +02:00
keypair := peer.keypairs.Current()
if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
2018-05-20 06:50:07 +02:00
peer.timers.sentLastMinuteHandshake.Set(true)
peer.SendHandshakeInitiation(false)
}
}
2017-12-01 23:37:26 +01:00
/* Receives incoming datagrams for the device
*
* Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately)
*/
func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
defer func() {
device.log.Verbosef("Routine: receive incoming IPv%d - stopped", IP)
device.queue.decryption.wg.Done()
device.queue.handshake.wg.Done()
device.net.stopping.Done()
}()
device.log.Verbosef("Routine: receive incoming IPv%d - started", IP)
2017-07-01 23:29:22 +02:00
2017-12-01 00:03:06 +01:00
// receive datagrams until conn is closed
2017-07-01 23:29:22 +02:00
2017-12-01 00:03:06 +01:00
buffer := device.GetMessageBuffer()
2017-07-01 23:29:22 +02:00
2017-12-01 00:03:06 +01:00
var (
err error
size int
endpoint conn.Endpoint
deathSpiral int
2017-12-01 00:03:06 +01:00
)
2017-12-01 00:03:06 +01:00
for {
switch IP {
case ipv4.Version:
size, endpoint, err = bind.ReceiveIPv4(buffer[:])
case ipv6.Version:
size, endpoint, err = bind.ReceiveIPv6(buffer[:])
default:
panic("invalid IP version")
2017-12-01 00:03:06 +01:00
}
2017-08-04 16:15:53 +02:00
2017-12-01 00:03:06 +01:00
if err != nil {
device.PutMessageBuffer(buffer)
if errors.Is(err, conn.NetErrClosed) {
return
}
device.log.Errorf("Failed to receive packet: %v", err)
if deathSpiral < 10 {
deathSpiral++
time.Sleep(time.Second / 3)
continue
}
2017-12-01 00:03:06 +01:00
return
}
deathSpiral = 0
2017-07-01 23:29:22 +02:00
2017-12-01 00:03:06 +01:00
if size < MinMessageSize {
continue
}
2017-07-01 23:29:22 +02:00
2017-12-01 00:03:06 +01:00
// check size of packet
2017-07-01 23:29:22 +02:00
2017-12-01 00:03:06 +01:00
packet := buffer[:size]
msgType := binary.LittleEndian.Uint32(packet[:4])
2017-08-04 16:15:53 +02:00
2017-12-01 00:03:06 +01:00
var okay bool
2017-07-01 23:29:22 +02:00
2017-12-01 00:03:06 +01:00
switch msgType {
2017-07-01 23:29:22 +02:00
2017-12-01 00:03:06 +01:00
// check if transport
2017-07-01 23:29:22 +02:00
2017-12-01 00:03:06 +01:00
case MessageTransportType:
2017-07-01 23:29:22 +02:00
2017-12-01 00:03:06 +01:00
// check size
2017-07-01 23:29:22 +02:00
if len(packet) < MessageTransportSize {
2017-12-01 00:03:06 +01:00
continue
}
2017-07-01 23:29:22 +02:00
2017-12-01 00:03:06 +01:00
// lookup key pair
2017-12-01 00:03:06 +01:00
receiver := binary.LittleEndian.Uint32(
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
)
2018-05-13 18:23:40 +02:00
value := device.indexTable.Lookup(receiver)
keypair := value.keypair
if keypair == nil {
2017-12-01 00:03:06 +01:00
continue
}
2017-07-01 23:29:22 +02:00
2018-05-13 19:50:58 +02:00
// check keypair expiry
2017-07-01 23:29:22 +02:00
2018-05-13 18:23:40 +02:00
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
2017-12-01 00:03:06 +01:00
continue
}
2017-07-01 23:29:22 +02:00
2017-12-01 00:03:06 +01:00
// create work element
peer := value.peer
2018-09-22 06:29:02 +02:00
elem := device.GetInboundElement()
elem.packet = packet
elem.buffer = buffer
elem.keypair = keypair
elem.endpoint = endpoint
elem.counter = 0
elem.Mutex = sync.Mutex{}
elem.Lock()
2017-07-01 23:29:22 +02:00
2017-12-01 00:03:06 +01:00
// add to decryption queues
if peer.isRunning.Get() {
peer.queue.inbound <- elem
device.queue.decryption.c <- elem
buffer = device.GetMessageBuffer()
} else {
device.PutInboundElement(elem)
}
2017-12-01 00:03:06 +01:00
continue
2017-07-01 23:29:22 +02:00
2017-12-01 00:03:06 +01:00
// otherwise it is a fixed size & handshake related packet
2017-12-01 00:03:06 +01:00
case MessageInitiationType:
okay = len(packet) == MessageInitiationSize
2017-12-01 00:03:06 +01:00
case MessageResponseType:
okay = len(packet) == MessageResponseSize
2017-12-01 00:03:06 +01:00
case MessageCookieReplyType:
okay = len(packet) == MessageCookieReplySize
default:
device.log.Verbosef("Received message with unknown type")
2017-12-01 00:03:06 +01:00
}
2017-12-01 00:03:06 +01:00
if okay {
select {
case device.queue.handshake.c <- QueueHandshakeElement{
msgType: msgType,
buffer: buffer,
packet: packet,
endpoint: endpoint,
}:
buffer = device.GetMessageBuffer()
default:
}
}
2017-07-01 23:29:22 +02:00
}
}
func (device *Device) RoutineDecryption() {
var nonce [chacha20poly1305.NonceSize]byte
defer device.log.Verbosef("Routine: decryption worker - stopped")
device.log.Verbosef("Routine: decryption worker - started")
for elem := range device.queue.decryption.c {
// split message into fields
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
content := elem.packet[MessageTransportOffsetContent:]
// decrypt and release to consumer
var err error
elem.counter = binary.LittleEndian.Uint64(counter)
// copy counter to nonce
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
elem.packet, err = elem.keypair.receive.Open(
content[:0],
nonce[:],
content,
nil,
)
if err != nil {
elem.packet = nil
2017-07-01 23:29:22 +02:00
}
elem.Unlock()
2017-07-01 23:29:22 +02:00
}
}
2017-12-01 23:37:26 +01:00
/* Handles incoming packets related to handshake
2017-07-01 23:29:22 +02:00
*/
func (device *Device) RoutineHandshake() {
defer device.log.Verbosef("Routine: handshake worker - stopped")
device.log.Verbosef("Routine: handshake worker - started")
2017-07-01 23:29:22 +02:00
for elem := range device.queue.handshake.c {
// handle cookie fields and ratelimiting
2017-07-01 23:29:22 +02:00
switch elem.msgType {
2017-07-08 09:23:10 +02:00
case MessageCookieReplyType:
2017-08-14 17:09:25 +02:00
// unmarshal packet
var reply MessageCookieReply
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &reply)
if err != nil {
device.log.Verbosef("Failed to decode cookie reply")
goto skip
2017-07-08 09:23:10 +02:00
}
2017-08-14 17:09:25 +02:00
// lookup peer from index
2017-08-14 17:09:25 +02:00
2018-05-13 18:23:40 +02:00
entry := device.indexTable.Lookup(reply.Receiver)
2017-08-14 17:09:25 +02:00
if entry.peer == nil {
goto skip
2017-08-14 17:09:25 +02:00
}
// consume reply
if peer := entry.peer; peer.isRunning.Get() {
device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
2018-12-19 00:35:53 +01:00
if !peer.cookieGenerator.ConsumeReply(&reply) {
device.log.Verbosef("Could not decrypt invalid cookie response")
2018-12-19 00:35:53 +01:00
}
}
goto skip
2017-07-08 09:23:10 +02:00
case MessageInitiationType, MessageResponseType:
2017-07-08 09:23:10 +02:00
2018-05-13 23:14:43 +02:00
// check mac fields and maybe ratelimit
2017-07-08 09:23:10 +02:00
2018-05-13 23:14:43 +02:00
if !device.cookieChecker.CheckMAC1(elem.packet) {
device.log.Verbosef("Received packet with invalid mac1")
goto skip
2017-07-08 09:23:10 +02:00
}
// endpoints destination address is the source of the datagram
if device.IsUnderLoad() {
2017-10-08 22:03:32 +02:00
// verify MAC2 field
2018-05-13 23:14:43 +02:00
if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
device.SendHandshakeCookie(&elem)
goto skip
}
2017-10-08 22:03:32 +02:00
// check ratelimiter
if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
goto skip
}
}
default:
device.log.Errorf("Invalid packet ended up in the handshake queue")
goto skip
}
2017-07-08 09:23:10 +02:00
2017-12-01 23:37:26 +01:00
// handle handshake initiation/response content
2017-07-01 23:29:22 +02:00
switch elem.msgType {
case MessageInitiationType:
2017-07-01 23:29:22 +02:00
// unmarshal
2017-07-01 23:29:22 +02:00
var msg MessageInitiation
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &msg)
if err != nil {
device.log.Errorf("Failed to decode initiation message")
goto skip
}
2017-07-01 23:29:22 +02:00
// consume initiation
2017-07-01 23:29:22 +02:00
peer := device.ConsumeMessageInitiation(&msg)
if peer == nil {
device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
goto skip
}
2017-08-04 16:15:53 +02:00
// update timers
2017-08-04 16:15:53 +02:00
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
2017-07-07 13:47:09 +02:00
// update endpoint
2018-05-26 02:59:26 +02:00
peer.SetEndpointFromPacket(elem.endpoint)
device.log.Verbosef("%v - Received handshake initiation", peer)
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
2018-04-20 07:13:40 +02:00
2018-05-13 23:14:43 +02:00
peer.SendHandshakeResponse()
case MessageResponseType:
2017-07-01 23:29:22 +02:00
// unmarshal
2017-07-01 23:29:22 +02:00
var msg MessageResponse
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &msg)
if err != nil {
device.log.Errorf("Failed to decode response message")
goto skip
}
2017-07-01 23:29:22 +02:00
// consume response
2017-07-01 23:29:22 +02:00
peer := device.ConsumeMessageResponse(&msg)
if peer == nil {
device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
goto skip
}
2017-07-27 23:45:37 +02:00
2017-11-14 16:27:53 +01:00
// update endpoint
2018-05-26 02:59:26 +02:00
peer.SetEndpointFromPacket(elem.endpoint)
2017-11-14 16:27:53 +01:00
device.log.Verbosef("%v - Received handshake response", peer)
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
// update timers
2017-08-04 16:15:53 +02:00
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
2017-08-04 16:15:53 +02:00
2018-05-13 19:50:58 +02:00
// derive keypair
2017-07-01 23:29:22 +02:00
2018-05-13 23:14:43 +02:00
err = peer.BeginSymmetricSession()
if err != nil {
device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
goto skip
}
2018-05-13 23:14:43 +02:00
peer.timersSessionDerived()
peer.timersHandshakeComplete()
peer.SendKeepalive()
}
skip:
device.PutMessageBuffer(elem.buffer)
2017-07-01 23:29:22 +02:00
}
}
func (peer *Peer) RoutineSequentialReceiver() {
device := peer.device
defer func() {
device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
peer.stopping.Done()
2018-02-04 19:18:44 +01:00
}()
device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
for elem := range peer.queue.inbound {
device: remove mutex from Peer send/receive The immediate motivation for this change is an observed deadlock. 1. A goroutine calls peer.Stop. That calls peer.queue.Lock(). 2. Another goroutine is in RoutineSequentialReceiver. It receives an elem from peer.queue.inbound. 3. The peer.Stop goroutine calls close(peer.queue.inbound), close(peer.queue.outbound), and peer.stopping.Wait(). It blocks waiting for RoutineSequentialReceiver and RoutineSequentialSender to exit. 4. The RoutineSequentialReceiver goroutine calls peer.SendStagedPackets(). SendStagedPackets attempts peer.queue.RLock(). That blocks forever because the peer.Stop goroutine holds a write lock on that mutex. A background motivation for this change is that it can be expensive to have a mutex in the hot code path of RoutineSequential*. The mutex was necessary to avoid attempting to send elems on a closed channel. This commit removes that danger by never closing the channel. Instead, we send a sentinel nil value on the channel to indicate to the receiver that it should exit. The only problem with this is that if the receiver exits, we could write an elem into the channel which would never get received. If it never gets received, it cannot get returned to the device pools. To work around this, we use a finalizer. When the channel can be GC'd, the finalizer drains any remaining elements from the channel and restores them to the device pool. After that change, peer.queue.RWMutex no longer makes sense where it is. It is only used to prevent concurrent calls to Start and Stop. Move it to a more sensible location and make it a plain sync.Mutex. Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-02-08 22:02:52 +01:00
if elem == nil {
return
}
var err error
2019-03-21 21:43:04 +01:00
elem.Lock()
if elem.packet == nil {
// decryption failed
goto skip
2019-03-21 21:43:04 +01:00
}
2017-07-10 12:09:19 +02:00
2019-03-21 21:43:04 +01:00
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
goto skip
2019-03-21 21:43:04 +01:00
}
2017-07-08 23:51:26 +02:00
2019-03-21 21:43:04 +01:00
peer.SetEndpointFromPacket(elem.endpoint)
if peer.ReceivedWithKeypair(elem.keypair) {
peer.timersHandshakeComplete()
peer.SendStagedPackets()
2019-03-21 21:43:04 +01:00
}
2017-07-01 23:29:22 +02:00
2019-03-21 21:43:04 +01:00
peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize))
2017-07-01 23:29:22 +02:00
2019-03-21 21:43:04 +01:00
if len(elem.packet) == 0 {
device.log.Verbosef("%v - Receiving keepalive packet", peer)
goto skip
2019-03-21 21:43:04 +01:00
}
peer.timersDataReceived()
2017-07-08 09:23:10 +02:00
2019-03-21 21:43:04 +01:00
switch elem.packet[0] >> 4 {
case ipv4.Version:
if len(elem.packet) < ipv4.HeaderLen {
goto skip
2019-03-21 21:43:04 +01:00
}
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field)
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
goto skip
2019-03-21 21:43:04 +01:00
}
elem.packet = elem.packet[:length]
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.allowedips.LookupIPv4(src) != peer {
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
goto skip
2019-03-21 21:43:04 +01:00
}
2017-07-07 13:47:09 +02:00
2019-03-21 21:43:04 +01:00
case ipv6.Version:
if len(elem.packet) < ipv6.HeaderLen {
goto skip
2019-03-21 21:43:04 +01:00
}
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field)
length += ipv6.HeaderLen
if int(length) > len(elem.packet) {
goto skip
2019-03-21 21:43:04 +01:00
}
elem.packet = elem.packet[:length]
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.allowedips.LookupIPv6(src) != peer {
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
goto skip
}
2019-03-21 21:43:04 +01:00
default:
device.log.Verbosef("Packet with invalid IP version from %v", peer)
goto skip
2019-03-21 21:43:04 +01:00
}
_, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent)
if err != nil && !device.isClosed() {
device.log.Errorf("Failed to write packet to TUN device: %v", err)
}
2019-07-01 15:23:24 +02:00
if len(peer.queue.inbound) == 0 {
err = device.tun.device.Flush()
2019-07-01 15:23:24 +02:00
if err != nil {
peer.device.log.Errorf("Unable to flush packets: %v", err)
2019-07-01 15:23:24 +02:00
}
2019-03-21 21:43:04 +01:00
}
skip:
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
2017-07-01 23:29:22 +02:00
}
}