0bcb822e5b
This commit simplifies device state management. It creates a single unified state variable and documents its semantics. It also makes state changes more atomic. As an example of the sort of bug that occurred due to non-atomic state changes, the following sequence of events used to occur approximately every 2.5 million test runs: * RoutineTUNEventReader received an EventDown event. * It called device.Down, which called device.setUpDown. * That set device.state.changing, but did not yet attempt to lock device.state.Mutex. * Test completion called device.Close. * device.Close locked device.state.Mutex. * device.Close blocked on a call to device.state.stopping.Wait. * device.setUpDown then attempted to lock device.state.Mutex and blocked. Deadlock results. setUpDown cannot progress because device.state.Mutex is locked. Until setUpDown returns, RoutineTUNEventReader cannot call device.state.stopping.Done. Until device.state.stopping.Done gets called, device.state.stopping.Wait is blocked. As long as device.state.stopping.Wait is blocked, device.state.Mutex cannot be unlocked. This commit fixes that deadlock by holding device.state.mu when checking that the device is not closed. Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
591 lines
14 KiB
Go
591 lines
14 KiB
Go
/* SPDX-License-Identifier: MIT
|
|
*
|
|
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
|
*/
|
|
|
|
package device
|
|
|
|
import (
|
|
"runtime"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"golang.org/x/net/ipv4"
|
|
"golang.org/x/net/ipv6"
|
|
|
|
"golang.zx2c4.com/wireguard/conn"
|
|
"golang.zx2c4.com/wireguard/ratelimiter"
|
|
"golang.zx2c4.com/wireguard/rwcancel"
|
|
"golang.zx2c4.com/wireguard/tun"
|
|
)
|
|
|
|
type Device struct {
|
|
log *Logger
|
|
|
|
// synchronized resources (locks acquired in order)
|
|
|
|
state struct {
|
|
// state holds the device's state. It is accessed atomically.
|
|
// Use the device.deviceState method to read it.
|
|
// If state.mu is (r)locked, state is the current state of the device.
|
|
// Without state.mu (r)locked, state is either the current state
|
|
// of the device or the intended future state of the device.
|
|
// For example, while executing a call to Up, state will be deviceStateUp.
|
|
// There is no guarantee that that intended future state of the device
|
|
// will become the actual state; Up can fail.
|
|
// The device can also change state multiple times between time of check and time of use.
|
|
// Unsynchronized uses of state must therefore be advisory/best-effort only.
|
|
state uint32 // actually a deviceState, but typed uint32 for conveniene
|
|
// stopping blocks until all inputs to Device have been closed.
|
|
stopping sync.WaitGroup
|
|
// mu protects state changes.
|
|
mu sync.Mutex
|
|
}
|
|
|
|
net struct {
|
|
stopping sync.WaitGroup
|
|
sync.RWMutex
|
|
bind conn.Bind // bind interface
|
|
netlinkCancel *rwcancel.RWCancel
|
|
port uint16 // listening port
|
|
fwmark uint32 // mark value (0 = disabled)
|
|
}
|
|
|
|
staticIdentity struct {
|
|
sync.RWMutex
|
|
privateKey NoisePrivateKey
|
|
publicKey NoisePublicKey
|
|
}
|
|
|
|
peers struct {
|
|
empty AtomicBool // empty reports whether len(keyMap) == 0
|
|
sync.RWMutex // protects keyMap
|
|
keyMap map[NoisePublicKey]*Peer
|
|
}
|
|
|
|
// unprotected / "self-synchronising resources"
|
|
|
|
allowedips AllowedIPs
|
|
indexTable IndexTable
|
|
cookieChecker CookieChecker
|
|
|
|
rate struct {
|
|
underLoadUntil int64
|
|
limiter ratelimiter.Ratelimiter
|
|
}
|
|
|
|
pool struct {
|
|
messageBuffers *WaitPool
|
|
inboundElements *WaitPool
|
|
outboundElements *WaitPool
|
|
}
|
|
|
|
queue struct {
|
|
encryption *outboundQueue
|
|
decryption *inboundQueue
|
|
handshake *handshakeQueue
|
|
}
|
|
|
|
tun struct {
|
|
device tun.Device
|
|
mtu int32
|
|
}
|
|
|
|
ipcMutex sync.RWMutex
|
|
closed chan struct{}
|
|
}
|
|
|
|
// deviceState enumerates the lifecycle states of a Device.
// The four states are new, down, up and closed,
// although the new state should never be observable.
// Transitions:
//
//	new -> down -----+
//	         ↑↓      ↓
//	         up -> closed
type deviceState uint32

//go:generate stringer -type deviceState -trimprefix=deviceState
const (
	// deviceStateNew is the zero value.
	deviceStateNew deviceState = iota
	// deviceStateDown is the state NewDevice assigns to a fresh Device.
	deviceStateDown
	// deviceStateUp is the state requested by Device.Up.
	deviceStateUp
	// deviceStateClosed is terminal: once closed, always closed.
	deviceStateClosed
)
|
|
|
|
// deviceState returns device.state.state as a deviceState
|
|
// See those docs for how to interpret this value.
|
|
func (device *Device) deviceState() deviceState {
|
|
return deviceState(atomic.LoadUint32(&device.state.state))
|
|
}
|
|
|
|
// isClosed reports whether the device is closed (or is closing).
|
|
// See device.state.state comments for how to interpret this value.
|
|
func (device *Device) isClosed() bool {
|
|
return device.deviceState() == deviceStateClosed
|
|
}
|
|
|
|
// isUp reports whether the device is up (or is attempting to come up).
|
|
// See device.state.state comments for how to interpret this value.
|
|
func (device *Device) isUp() bool {
|
|
return device.deviceState() == deviceStateUp
|
|
}
|
|
|
|
// An outboundQueue is a channel of QueueOutboundElements awaiting encryption.
|
|
// An outboundQueue is ref-counted using its wg field.
|
|
// An outboundQueue created with newOutboundQueue has one reference.
|
|
// Every additional writer must call wg.Add(1).
|
|
// Every completed writer must call wg.Done().
|
|
// When no further writers will be added,
|
|
// call wg.Done to remove the initial reference.
|
|
// When the refcount hits 0, the queue's channel is closed.
|
|
type outboundQueue struct {
|
|
c chan *QueueOutboundElement
|
|
wg sync.WaitGroup
|
|
}
|
|
|
|
func newOutboundQueue() *outboundQueue {
|
|
q := &outboundQueue{
|
|
c: make(chan *QueueOutboundElement, QueueOutboundSize),
|
|
}
|
|
q.wg.Add(1)
|
|
go func() {
|
|
q.wg.Wait()
|
|
close(q.c)
|
|
}()
|
|
return q
|
|
}
|
|
|
|
// A inboundQueue is similar to an outboundQueue; see those docs.
|
|
type inboundQueue struct {
|
|
c chan *QueueInboundElement
|
|
wg sync.WaitGroup
|
|
}
|
|
|
|
func newInboundQueue() *inboundQueue {
|
|
q := &inboundQueue{
|
|
c: make(chan *QueueInboundElement, QueueInboundSize),
|
|
}
|
|
q.wg.Add(1)
|
|
go func() {
|
|
q.wg.Wait()
|
|
close(q.c)
|
|
}()
|
|
return q
|
|
}
|
|
|
|
// A handshakeQueue is similar to an outboundQueue; see those docs.
|
|
type handshakeQueue struct {
|
|
c chan QueueHandshakeElement
|
|
wg sync.WaitGroup
|
|
}
|
|
|
|
func newHandshakeQueue() *handshakeQueue {
|
|
q := &handshakeQueue{
|
|
c: make(chan QueueHandshakeElement, QueueHandshakeSize),
|
|
}
|
|
q.wg.Add(1)
|
|
go func() {
|
|
q.wg.Wait()
|
|
close(q.c)
|
|
}()
|
|
return q
|
|
}
|
|
|
|
/* Converts the peer into a "zombie", which remains in the peer map,
|
|
* but processes no packets and does not exists in the routing table.
|
|
*
|
|
* Must hold device.peers.Mutex
|
|
*/
|
|
func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
|
|
// stop routing and processing of packets
|
|
device.allowedips.RemoveByPeer(peer)
|
|
peer.Stop()
|
|
|
|
// remove from peer map
|
|
delete(device.peers.keyMap, key)
|
|
device.peers.empty.Set(len(device.peers.keyMap) == 0)
|
|
}
|
|
|
|
// changeState attempts to change the device state to match want.
|
|
func (device *Device) changeState(want deviceState) {
|
|
device.state.mu.Lock()
|
|
defer device.state.mu.Unlock()
|
|
old := device.deviceState()
|
|
if old == deviceStateClosed {
|
|
// once closed, always closed
|
|
device.log.Verbosef("Interface closed, ignored requested state %s", want)
|
|
return
|
|
}
|
|
switch want {
|
|
case old:
|
|
device.log.Verbosef("Interface already in state %s", want)
|
|
return
|
|
case deviceStateUp:
|
|
atomic.StoreUint32(&device.state.state, uint32(deviceStateUp))
|
|
if ok := device.upLocked(); ok {
|
|
break
|
|
}
|
|
fallthrough // up failed; bring the device all the way back down
|
|
case deviceStateDown:
|
|
atomic.StoreUint32(&device.state.state, uint32(deviceStateDown))
|
|
device.downLocked()
|
|
}
|
|
device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState())
|
|
}
|
|
|
|
// upLocked attempts to bring the device up and reports whether it succeeded.
|
|
// The caller must hold device.state.mu and is responsible for updating device.state.state.
|
|
func (device *Device) upLocked() bool {
|
|
if err := device.BindUpdate(); err != nil {
|
|
device.log.Errorf("Unable to update bind: %v", err)
|
|
return false
|
|
}
|
|
|
|
device.peers.RLock()
|
|
for _, peer := range device.peers.keyMap {
|
|
peer.Start()
|
|
if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
|
|
peer.SendKeepalive()
|
|
}
|
|
}
|
|
device.peers.RUnlock()
|
|
return true
|
|
}
|
|
|
|
// downLocked attempts to bring the device down.
|
|
// The caller must hold device.state.mu and is responsible for updating device.state.state.
|
|
func (device *Device) downLocked() {
|
|
err := device.BindClose()
|
|
if err != nil {
|
|
device.log.Errorf("Bind close failed: %v", err)
|
|
}
|
|
|
|
device.peers.RLock()
|
|
for _, peer := range device.peers.keyMap {
|
|
peer.Stop()
|
|
}
|
|
device.peers.RUnlock()
|
|
}
|
|
|
|
func (device *Device) Up() {
|
|
device.changeState(deviceStateUp)
|
|
}
|
|
|
|
func (device *Device) Down() {
|
|
device.changeState(deviceStateDown)
|
|
}
|
|
|
|
func (device *Device) IsUnderLoad() bool {
|
|
// check if currently under load
|
|
now := time.Now()
|
|
underLoad := len(device.queue.handshake.c) >= UnderLoadQueueSize
|
|
if underLoad {
|
|
atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano())
|
|
return true
|
|
}
|
|
// check if recently under load
|
|
return atomic.LoadInt64(&device.rate.underLoadUntil) > now.UnixNano()
|
|
}
|
|
|
|
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
|
// lock required resources
|
|
|
|
device.staticIdentity.Lock()
|
|
defer device.staticIdentity.Unlock()
|
|
|
|
if sk.Equals(device.staticIdentity.privateKey) {
|
|
return nil
|
|
}
|
|
|
|
device.peers.Lock()
|
|
defer device.peers.Unlock()
|
|
|
|
lockedPeers := make([]*Peer, 0, len(device.peers.keyMap))
|
|
for _, peer := range device.peers.keyMap {
|
|
peer.handshake.mutex.RLock()
|
|
lockedPeers = append(lockedPeers, peer)
|
|
}
|
|
|
|
// remove peers with matching public keys
|
|
|
|
publicKey := sk.publicKey()
|
|
for key, peer := range device.peers.keyMap {
|
|
if peer.handshake.remoteStatic.Equals(publicKey) {
|
|
peer.handshake.mutex.RUnlock()
|
|
unsafeRemovePeer(device, peer, key)
|
|
peer.handshake.mutex.RLock()
|
|
}
|
|
}
|
|
|
|
// update key material
|
|
|
|
device.staticIdentity.privateKey = sk
|
|
device.staticIdentity.publicKey = publicKey
|
|
device.cookieChecker.Init(publicKey)
|
|
|
|
// do static-static DH pre-computations
|
|
|
|
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
|
|
for _, peer := range device.peers.keyMap {
|
|
handshake := &peer.handshake
|
|
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
|
|
expiredPeers = append(expiredPeers, peer)
|
|
}
|
|
|
|
for _, peer := range lockedPeers {
|
|
peer.handshake.mutex.RUnlock()
|
|
}
|
|
for _, peer := range expiredPeers {
|
|
peer.ExpireCurrentKeypairs()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
|
|
device := new(Device)
|
|
device.state.state = uint32(deviceStateDown)
|
|
device.closed = make(chan struct{})
|
|
device.log = logger
|
|
device.tun.device = tunDevice
|
|
mtu, err := device.tun.device.MTU()
|
|
if err != nil {
|
|
device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
|
|
mtu = DefaultMTU
|
|
}
|
|
device.tun.mtu = int32(mtu)
|
|
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
|
device.rate.limiter.Init()
|
|
device.indexTable.Init()
|
|
device.PopulatePools()
|
|
|
|
// create queues
|
|
|
|
device.queue.handshake = newHandshakeQueue()
|
|
device.queue.encryption = newOutboundQueue()
|
|
device.queue.decryption = newInboundQueue()
|
|
|
|
// prepare net
|
|
|
|
device.net.port = 0
|
|
device.net.bind = nil
|
|
|
|
// start workers
|
|
|
|
cpus := runtime.NumCPU()
|
|
device.state.stopping.Wait()
|
|
for i := 0; i < cpus; i++ {
|
|
go device.RoutineEncryption()
|
|
go device.RoutineDecryption()
|
|
go device.RoutineHandshake()
|
|
}
|
|
|
|
device.state.stopping.Add(2)
|
|
go device.RoutineReadFromTUN()
|
|
go device.RoutineTUNEventReader()
|
|
|
|
return device
|
|
}
|
|
|
|
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
|
|
device.peers.RLock()
|
|
defer device.peers.RUnlock()
|
|
|
|
return device.peers.keyMap[pk]
|
|
}
|
|
|
|
func (device *Device) RemovePeer(key NoisePublicKey) {
|
|
device.peers.Lock()
|
|
defer device.peers.Unlock()
|
|
// stop peer and remove from routing
|
|
|
|
peer, ok := device.peers.keyMap[key]
|
|
if ok {
|
|
unsafeRemovePeer(device, peer, key)
|
|
}
|
|
}
|
|
|
|
func (device *Device) RemoveAllPeers() {
|
|
device.peers.Lock()
|
|
defer device.peers.Unlock()
|
|
|
|
for key, peer := range device.peers.keyMap {
|
|
unsafeRemovePeer(device, peer, key)
|
|
}
|
|
|
|
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
|
}
|
|
|
|
func (device *Device) Close() {
|
|
device.state.mu.Lock()
|
|
defer device.state.mu.Unlock()
|
|
if device.isClosed() {
|
|
return
|
|
}
|
|
atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed))
|
|
device.log.Verbosef("Device closing")
|
|
|
|
device.tun.device.Close()
|
|
device.downLocked()
|
|
|
|
// Remove peers before closing queues,
|
|
// because peers assume that queues are active.
|
|
device.RemoveAllPeers()
|
|
|
|
// We kept a reference to the encryption and decryption queues,
|
|
// in case we started any new peers that might write to them.
|
|
// No new peers are coming; we are done with these queues.
|
|
device.queue.encryption.wg.Done()
|
|
device.queue.decryption.wg.Done()
|
|
device.queue.handshake.wg.Done()
|
|
device.state.stopping.Wait()
|
|
|
|
device.rate.limiter.Close()
|
|
|
|
device.log.Verbosef("Device closed")
|
|
close(device.closed)
|
|
}
|
|
|
|
func (device *Device) Wait() chan struct{} {
|
|
return device.closed
|
|
}
|
|
|
|
func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
|
|
if !device.isUp() {
|
|
return
|
|
}
|
|
|
|
device.peers.RLock()
|
|
for _, peer := range device.peers.keyMap {
|
|
peer.keypairs.RLock()
|
|
sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now())
|
|
peer.keypairs.RUnlock()
|
|
if sendKeepalive {
|
|
peer.SendKeepalive()
|
|
}
|
|
}
|
|
device.peers.RUnlock()
|
|
}
|
|
|
|
func unsafeCloseBind(device *Device) error {
|
|
var err error
|
|
netc := &device.net
|
|
if netc.netlinkCancel != nil {
|
|
netc.netlinkCancel.Cancel()
|
|
}
|
|
if netc.bind != nil {
|
|
err = netc.bind.Close()
|
|
netc.bind = nil
|
|
}
|
|
netc.stopping.Wait()
|
|
return err
|
|
}
|
|
|
|
func (device *Device) Bind() conn.Bind {
|
|
device.net.Lock()
|
|
defer device.net.Unlock()
|
|
return device.net.bind
|
|
}
|
|
|
|
func (device *Device) BindSetMark(mark uint32) error {
|
|
device.net.Lock()
|
|
defer device.net.Unlock()
|
|
|
|
// check if modified
|
|
if device.net.fwmark == mark {
|
|
return nil
|
|
}
|
|
|
|
// update fwmark on existing bind
|
|
device.net.fwmark = mark
|
|
if device.isUp() && device.net.bind != nil {
|
|
if err := device.net.bind.SetMark(mark); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// clear cached source addresses
|
|
device.peers.RLock()
|
|
for _, peer := range device.peers.keyMap {
|
|
peer.Lock()
|
|
defer peer.Unlock()
|
|
if peer.endpoint != nil {
|
|
peer.endpoint.ClearSrc()
|
|
}
|
|
}
|
|
device.peers.RUnlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (device *Device) BindUpdate() error {
|
|
device.net.Lock()
|
|
defer device.net.Unlock()
|
|
|
|
// close existing sockets
|
|
if err := unsafeCloseBind(device); err != nil {
|
|
return err
|
|
}
|
|
|
|
// open new sockets
|
|
if !device.isUp() {
|
|
return nil
|
|
}
|
|
|
|
// bind to new port
|
|
var err error
|
|
netc := &device.net
|
|
netc.bind, netc.port, err = conn.CreateBind(netc.port)
|
|
if err != nil {
|
|
netc.bind = nil
|
|
netc.port = 0
|
|
return err
|
|
}
|
|
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
|
|
if err != nil {
|
|
netc.bind.Close()
|
|
netc.bind = nil
|
|
netc.port = 0
|
|
return err
|
|
}
|
|
|
|
// set fwmark
|
|
if netc.fwmark != 0 {
|
|
err = netc.bind.SetMark(netc.fwmark)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// clear cached source addresses
|
|
device.peers.RLock()
|
|
for _, peer := range device.peers.keyMap {
|
|
peer.Lock()
|
|
defer peer.Unlock()
|
|
if peer.endpoint != nil {
|
|
peer.endpoint.ClearSrc()
|
|
}
|
|
}
|
|
device.peers.RUnlock()
|
|
|
|
// start receiving routines
|
|
device.net.stopping.Add(2)
|
|
device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
|
|
device.queue.handshake.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
|
|
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
|
|
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
|
|
|
|
device.log.Verbosef("UDP bind has been updated")
|
|
return nil
|
|
}
|
|
|
|
func (device *Device) BindClose() error {
|
|
device.net.Lock()
|
|
err := unsafeCloseBind(device)
|
|
device.net.Unlock()
|
|
return err
|
|
}
|