bb719d3a6e
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
524 lines
13 KiB
Go
524 lines
13 KiB
Go
/* SPDX-License-Identifier: MIT
|
|
*
|
|
* Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved.
|
|
*/
|
|
|
|
package device
|
|
|
|
import (
|
|
"runtime"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"golang.zx2c4.com/wireguard/conn"
|
|
"golang.zx2c4.com/wireguard/ratelimiter"
|
|
"golang.zx2c4.com/wireguard/rwcancel"
|
|
"golang.zx2c4.com/wireguard/tun"
|
|
)
|
|
|
|
type Device struct {
|
|
state struct {
|
|
// state holds the device's state. It is accessed atomically.
|
|
// Use the device.deviceState method to read it.
|
|
// device.deviceState does not acquire the mutex, so it captures only a snapshot.
|
|
// During state transitions, the state variable is updated before the device itself.
|
|
// The state is thus either the current state of the device or
|
|
// the intended future state of the device.
|
|
// For example, while executing a call to Up, state will be deviceStateUp.
|
|
// There is no guarantee that that intended future state of the device
|
|
// will become the actual state; Up can fail.
|
|
// The device can also change state multiple times between time of check and time of use.
|
|
// Unsynchronized uses of state must therefore be advisory/best-effort only.
|
|
state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
|
|
// stopping blocks until all inputs to Device have been closed.
|
|
stopping sync.WaitGroup
|
|
// mu protects state changes.
|
|
sync.Mutex
|
|
}
|
|
|
|
net struct {
|
|
stopping sync.WaitGroup
|
|
sync.RWMutex
|
|
bind conn.Bind // bind interface
|
|
netlinkCancel *rwcancel.RWCancel
|
|
port uint16 // listening port
|
|
fwmark uint32 // mark value (0 = disabled)
|
|
brokenRoaming bool
|
|
}
|
|
|
|
staticIdentity struct {
|
|
sync.RWMutex
|
|
privateKey NoisePrivateKey
|
|
publicKey NoisePublicKey
|
|
}
|
|
|
|
peers struct {
|
|
sync.RWMutex // protects keyMap
|
|
keyMap map[NoisePublicKey]*Peer
|
|
}
|
|
|
|
rate struct {
|
|
underLoadUntil atomic.Int64
|
|
limiter ratelimiter.Ratelimiter
|
|
}
|
|
|
|
allowedips AllowedIPs
|
|
indexTable IndexTable
|
|
cookieChecker CookieChecker
|
|
|
|
pool struct {
|
|
messageBuffers *WaitPool
|
|
inboundElements *WaitPool
|
|
outboundElements *WaitPool
|
|
}
|
|
|
|
queue struct {
|
|
encryption *outboundQueue
|
|
decryption *inboundQueue
|
|
handshake *handshakeQueue
|
|
}
|
|
|
|
tun struct {
|
|
device tun.Device
|
|
mtu atomic.Int32
|
|
}
|
|
|
|
ipcMutex sync.RWMutex
|
|
closed chan struct{}
|
|
log *Logger
|
|
}
|
|
|
|
// deviceState represents the state of a Device.
|
|
// There are three states: down, up, closed.
|
|
// Transitions:
|
|
//
|
|
// down -----+
|
|
// ↑↓ ↓
|
|
// up -> closed
|
|
type deviceState uint32
|
|
|
|
//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
|
|
const (
|
|
deviceStateDown deviceState = iota
|
|
deviceStateUp
|
|
deviceStateClosed
|
|
)
|
|
|
|
// deviceState returns device.state.state as a deviceState
|
|
// See those docs for how to interpret this value.
|
|
func (device *Device) deviceState() deviceState {
|
|
return deviceState(device.state.state.Load())
|
|
}
|
|
|
|
// isClosed reports whether the device is closed (or is closing).
|
|
// See device.state.state comments for how to interpret this value.
|
|
func (device *Device) isClosed() bool {
|
|
return device.deviceState() == deviceStateClosed
|
|
}
|
|
|
|
// isUp reports whether the device is up (or is attempting to come up).
|
|
// See device.state.state comments for how to interpret this value.
|
|
func (device *Device) isUp() bool {
|
|
return device.deviceState() == deviceStateUp
|
|
}
|
|
|
|
// Must hold device.peers.Lock()
|
|
func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) {
|
|
// stop routing and processing of packets
|
|
device.allowedips.RemoveByPeer(peer)
|
|
peer.Stop()
|
|
|
|
// remove from peer map
|
|
delete(device.peers.keyMap, key)
|
|
}
|
|
|
|
// changeState attempts to change the device state to match want.
|
|
func (device *Device) changeState(want deviceState) (err error) {
|
|
device.state.Lock()
|
|
defer device.state.Unlock()
|
|
old := device.deviceState()
|
|
if old == deviceStateClosed {
|
|
// once closed, always closed
|
|
device.log.Verbosef("Interface closed, ignored requested state %s", want)
|
|
return nil
|
|
}
|
|
switch want {
|
|
case old:
|
|
return nil
|
|
case deviceStateUp:
|
|
device.state.state.Store(uint32(deviceStateUp))
|
|
err = device.upLocked()
|
|
if err == nil {
|
|
break
|
|
}
|
|
fallthrough // up failed; bring the device all the way back down
|
|
case deviceStateDown:
|
|
device.state.state.Store(uint32(deviceStateDown))
|
|
errDown := device.downLocked()
|
|
if err == nil {
|
|
err = errDown
|
|
}
|
|
}
|
|
device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState())
|
|
return
|
|
}
|
|
|
|
// upLocked attempts to bring the device up and reports whether it succeeded.
|
|
// The caller must hold device.state.mu and is responsible for updating device.state.state.
|
|
func (device *Device) upLocked() error {
|
|
if err := device.BindUpdate(); err != nil {
|
|
device.log.Errorf("Unable to update bind: %v", err)
|
|
return err
|
|
}
|
|
|
|
// The IPC set operation waits for peers to be created before calling Start() on them,
|
|
// so if there's a concurrent IPC set request happening, we should wait for it to complete.
|
|
device.ipcMutex.Lock()
|
|
defer device.ipcMutex.Unlock()
|
|
|
|
device.peers.RLock()
|
|
for _, peer := range device.peers.keyMap {
|
|
peer.Start()
|
|
if peer.persistentKeepaliveInterval.Load() > 0 {
|
|
peer.SendKeepalive()
|
|
}
|
|
}
|
|
device.peers.RUnlock()
|
|
return nil
|
|
}
|
|
|
|
// downLocked attempts to bring the device down.
|
|
// The caller must hold device.state.mu and is responsible for updating device.state.state.
|
|
func (device *Device) downLocked() error {
|
|
err := device.BindClose()
|
|
if err != nil {
|
|
device.log.Errorf("Bind close failed: %v", err)
|
|
}
|
|
|
|
device.peers.RLock()
|
|
for _, peer := range device.peers.keyMap {
|
|
peer.Stop()
|
|
}
|
|
device.peers.RUnlock()
|
|
return err
|
|
}
|
|
|
|
func (device *Device) Up() error {
|
|
return device.changeState(deviceStateUp)
|
|
}
|
|
|
|
func (device *Device) Down() error {
|
|
return device.changeState(deviceStateDown)
|
|
}
|
|
|
|
func (device *Device) IsUnderLoad() bool {
|
|
// check if currently under load
|
|
now := time.Now()
|
|
underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
|
|
if underLoad {
|
|
device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
|
|
return true
|
|
}
|
|
// check if recently under load
|
|
return device.rate.underLoadUntil.Load() > now.UnixNano()
|
|
}
|
|
|
|
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
|
// lock required resources
|
|
|
|
device.staticIdentity.Lock()
|
|
defer device.staticIdentity.Unlock()
|
|
|
|
if sk.Equals(device.staticIdentity.privateKey) {
|
|
return nil
|
|
}
|
|
|
|
device.peers.Lock()
|
|
defer device.peers.Unlock()
|
|
|
|
lockedPeers := make([]*Peer, 0, len(device.peers.keyMap))
|
|
for _, peer := range device.peers.keyMap {
|
|
peer.handshake.mutex.RLock()
|
|
lockedPeers = append(lockedPeers, peer)
|
|
}
|
|
|
|
// remove peers with matching public keys
|
|
|
|
publicKey := sk.publicKey()
|
|
for key, peer := range device.peers.keyMap {
|
|
if peer.handshake.remoteStatic.Equals(publicKey) {
|
|
peer.handshake.mutex.RUnlock()
|
|
removePeerLocked(device, peer, key)
|
|
peer.handshake.mutex.RLock()
|
|
}
|
|
}
|
|
|
|
// update key material
|
|
|
|
device.staticIdentity.privateKey = sk
|
|
device.staticIdentity.publicKey = publicKey
|
|
device.cookieChecker.Init(publicKey)
|
|
|
|
// do static-static DH pre-computations
|
|
|
|
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
|
|
for _, peer := range device.peers.keyMap {
|
|
handshake := &peer.handshake
|
|
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
|
|
expiredPeers = append(expiredPeers, peer)
|
|
}
|
|
|
|
for _, peer := range lockedPeers {
|
|
peer.handshake.mutex.RUnlock()
|
|
}
|
|
for _, peer := range expiredPeers {
|
|
peer.ExpireCurrentKeypairs()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
|
|
device := new(Device)
|
|
device.state.state.Store(uint32(deviceStateDown))
|
|
device.closed = make(chan struct{})
|
|
device.log = logger
|
|
device.net.bind = bind
|
|
device.tun.device = tunDevice
|
|
mtu, err := device.tun.device.MTU()
|
|
if err != nil {
|
|
device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
|
|
mtu = DefaultMTU
|
|
}
|
|
device.tun.mtu.Store(int32(mtu))
|
|
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
|
device.rate.limiter.Init()
|
|
device.indexTable.Init()
|
|
device.PopulatePools()
|
|
|
|
// create queues
|
|
|
|
device.queue.handshake = newHandshakeQueue()
|
|
device.queue.encryption = newOutboundQueue()
|
|
device.queue.decryption = newInboundQueue()
|
|
|
|
// start workers
|
|
|
|
cpus := runtime.NumCPU()
|
|
device.state.stopping.Wait()
|
|
device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake
|
|
for i := 0; i < cpus; i++ {
|
|
go device.RoutineEncryption(i + 1)
|
|
go device.RoutineDecryption(i + 1)
|
|
go device.RoutineHandshake(i + 1)
|
|
}
|
|
|
|
device.state.stopping.Add(1) // RoutineReadFromTUN
|
|
device.queue.encryption.wg.Add(1) // RoutineReadFromTUN
|
|
go device.RoutineReadFromTUN()
|
|
go device.RoutineTUNEventReader()
|
|
|
|
return device
|
|
}
|
|
|
|
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
|
|
device.peers.RLock()
|
|
defer device.peers.RUnlock()
|
|
|
|
return device.peers.keyMap[pk]
|
|
}
|
|
|
|
func (device *Device) RemovePeer(key NoisePublicKey) {
|
|
device.peers.Lock()
|
|
defer device.peers.Unlock()
|
|
// stop peer and remove from routing
|
|
|
|
peer, ok := device.peers.keyMap[key]
|
|
if ok {
|
|
removePeerLocked(device, peer, key)
|
|
}
|
|
}
|
|
|
|
func (device *Device) RemoveAllPeers() {
|
|
device.peers.Lock()
|
|
defer device.peers.Unlock()
|
|
|
|
for key, peer := range device.peers.keyMap {
|
|
removePeerLocked(device, peer, key)
|
|
}
|
|
|
|
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
|
}
|
|
|
|
func (device *Device) Close() {
|
|
device.state.Lock()
|
|
defer device.state.Unlock()
|
|
if device.isClosed() {
|
|
return
|
|
}
|
|
device.state.state.Store(uint32(deviceStateClosed))
|
|
device.log.Verbosef("Device closing")
|
|
|
|
device.tun.device.Close()
|
|
device.downLocked()
|
|
|
|
// Remove peers before closing queues,
|
|
// because peers assume that queues are active.
|
|
device.RemoveAllPeers()
|
|
|
|
// We kept a reference to the encryption and decryption queues,
|
|
// in case we started any new peers that might write to them.
|
|
// No new peers are coming; we are done with these queues.
|
|
device.queue.encryption.wg.Done()
|
|
device.queue.decryption.wg.Done()
|
|
device.queue.handshake.wg.Done()
|
|
device.state.stopping.Wait()
|
|
|
|
device.rate.limiter.Close()
|
|
|
|
device.log.Verbosef("Device closed")
|
|
close(device.closed)
|
|
}
|
|
|
|
func (device *Device) Wait() chan struct{} {
|
|
return device.closed
|
|
}
|
|
|
|
func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
|
|
if !device.isUp() {
|
|
return
|
|
}
|
|
|
|
device.peers.RLock()
|
|
for _, peer := range device.peers.keyMap {
|
|
peer.keypairs.RLock()
|
|
sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now())
|
|
peer.keypairs.RUnlock()
|
|
if sendKeepalive {
|
|
peer.SendKeepalive()
|
|
}
|
|
}
|
|
device.peers.RUnlock()
|
|
}
|
|
|
|
// closeBindLocked closes the device's net.bind.
|
|
// The caller must hold the net mutex.
|
|
func closeBindLocked(device *Device) error {
|
|
var err error
|
|
netc := &device.net
|
|
if netc.netlinkCancel != nil {
|
|
netc.netlinkCancel.Cancel()
|
|
}
|
|
if netc.bind != nil {
|
|
err = netc.bind.Close()
|
|
}
|
|
netc.stopping.Wait()
|
|
return err
|
|
}
|
|
|
|
func (device *Device) Bind() conn.Bind {
|
|
device.net.Lock()
|
|
defer device.net.Unlock()
|
|
return device.net.bind
|
|
}
|
|
|
|
func (device *Device) BindSetMark(mark uint32) error {
|
|
device.net.Lock()
|
|
defer device.net.Unlock()
|
|
|
|
// check if modified
|
|
if device.net.fwmark == mark {
|
|
return nil
|
|
}
|
|
|
|
// update fwmark on existing bind
|
|
device.net.fwmark = mark
|
|
if device.isUp() && device.net.bind != nil {
|
|
if err := device.net.bind.SetMark(mark); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// clear cached source addresses
|
|
device.peers.RLock()
|
|
for _, peer := range device.peers.keyMap {
|
|
peer.Lock()
|
|
defer peer.Unlock()
|
|
if peer.endpoint != nil {
|
|
peer.endpoint.ClearSrc()
|
|
}
|
|
}
|
|
device.peers.RUnlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (device *Device) BindUpdate() error {
|
|
device.net.Lock()
|
|
defer device.net.Unlock()
|
|
|
|
// close existing sockets
|
|
if err := closeBindLocked(device); err != nil {
|
|
return err
|
|
}
|
|
|
|
// open new sockets
|
|
if !device.isUp() {
|
|
return nil
|
|
}
|
|
|
|
// bind to new port
|
|
var err error
|
|
var recvFns []conn.ReceiveFunc
|
|
netc := &device.net
|
|
recvFns, netc.port, err = netc.bind.Open(netc.port)
|
|
if err != nil {
|
|
netc.port = 0
|
|
return err
|
|
}
|
|
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
|
|
if err != nil {
|
|
netc.bind.Close()
|
|
netc.port = 0
|
|
return err
|
|
}
|
|
|
|
// set fwmark
|
|
if netc.fwmark != 0 {
|
|
err = netc.bind.SetMark(netc.fwmark)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// clear cached source addresses
|
|
device.peers.RLock()
|
|
for _, peer := range device.peers.keyMap {
|
|
peer.Lock()
|
|
defer peer.Unlock()
|
|
if peer.endpoint != nil {
|
|
peer.endpoint.ClearSrc()
|
|
}
|
|
}
|
|
device.peers.RUnlock()
|
|
|
|
// start receiving routines
|
|
device.net.stopping.Add(len(recvFns))
|
|
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
|
|
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
|
|
for _, fn := range recvFns {
|
|
go device.RoutineReceiveIncoming(fn)
|
|
}
|
|
|
|
device.log.Verbosef("UDP bind has been updated")
|
|
return nil
|
|
}
|
|
|
|
func (device *Device) BindClose() error {
|
|
device.net.Lock()
|
|
err := closeBindLocked(device)
|
|
device.net.Unlock()
|
|
return err
|
|
}
|