From 233f079a9479279d2aab68f4accb139ee87ad664 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Mon, 7 May 2018 22:27:03 +0200 Subject: [PATCH] Rewrite timers and related state machines --- constants.go | 27 ++- device.go | 10 +- event.go | 43 ---- index.go | 2 +- keypair.go | 14 +- main.go | 15 ++ noise-protocol.go | 33 +-- noise_test.go | 4 +- peer.go | 78 +++---- receive.go | 101 +++++---- send.go | 134 +++++++++--- signal.go | 71 ------- timers.go | 512 +++++++++++++++++----------------------------- uapi.go | 11 +- 14 files changed, 453 insertions(+), 602 deletions(-) delete mode 100644 event.go delete mode 100644 signal.go diff --git a/constants.go b/constants.go index 04b75d7..01af1bb 100644 --- a/constants.go +++ b/constants.go @@ -12,21 +12,18 @@ import ( /* Specification constants */ const ( - RekeyAfterMessages = (1 << 64) - (1 << 16) - 1 - RejectAfterMessages = (1 << 64) - (1 << 4) - 1 - RekeyAfterTime = time.Second * 120 - RekeyAttemptTime = time.Second * 90 - RekeyTimeout = time.Second * 5 - RejectAfterTime = time.Second * 180 - KeepaliveTimeout = time.Second * 10 - CookieRefreshTime = time.Second * 120 - HandshakeInitationRate = time.Second / 20 - PaddingMultiple = 16 -) - -const ( - RekeyAfterTimeReceiving = RejectAfterTime - KeepaliveTimeout - RekeyTimeout - NewHandshakeTime = KeepaliveTimeout + RekeyTimeout // upon failure to acknowledge transport message + RekeyAfterMessages = (1 << 64) - (1 << 16) - 1 + RejectAfterMessages = (1 << 64) - (1 << 4) - 1 + RekeyAfterTime = time.Second * 120 + RekeyAttemptTime = time.Second * 90 + RekeyTimeout = time.Second * 5 + MaxTimerHandshakes = 90 / 5 /* RekeyAttemptTime / RekeyTimeout */ + RekeyTimeoutJitterMaxMs = 334 + RejectAfterTime = time.Second * 180 + KeepaliveTimeout = time.Second * 10 + CookieRefreshTime = time.Second * 120 + HandshakeInitationRate = time.Second / 20 + PaddingMultiple = 16 ) /* Implementation specific constants */ diff --git a/device.go b/device.go index c714b21..e127b5b 100644 
--- a/device.go +++ b/device.go @@ -74,8 +74,8 @@ type Device struct { handshake chan QueueHandshakeElement } - signal struct { - stop Signal + signals struct { + stop chan struct{} } tun struct { @@ -302,7 +302,7 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device { // prepare signals - device.signal.stop = NewSignal() + device.signals.stop = make(chan struct{}, 1) // prepare net @@ -400,7 +400,7 @@ func (device *Device) Close() { device.isUp.Set(false) - device.signal.stop.Broadcast() + close(device.signals.stop) device.state.stopping.Wait() device.FlushPacketQueues() @@ -413,5 +413,5 @@ func (device *Device) Close() { } func (device *Device) Wait() chan struct{} { - return device.signal.stop.Wait() + return device.signals.stop } diff --git a/event.go b/event.go deleted file mode 100644 index 6235ba4..0000000 --- a/event.go +++ /dev/null @@ -1,43 +0,0 @@ -package main - -import ( - "sync/atomic" - "time" -) - -type Event struct { - guard int32 - next time.Time - interval time.Duration - C chan struct{} -} - -func newEvent(interval time.Duration) *Event { - return &Event{ - guard: 0, - next: time.Now(), - interval: interval, - C: make(chan struct{}, 1), - } -} - -func (e *Event) Clear() { - select { - case <-e.C: - default: - } -} - -func (e *Event) Fire() { - if e == nil || atomic.SwapInt32(&e.guard, 1) != 0 { - return - } - if now := time.Now(); now.After(e.next) { - select { - case e.C <- struct{}{}: - default: - } - e.next = now.Add(e.interval) - } - atomic.StoreInt32(&e.guard, 0) -} diff --git a/index.go b/index.go index c309f23..4a78d55 100644 --- a/index.go +++ b/index.go @@ -18,7 +18,7 @@ import ( type IndexTableEntry struct { peer *Peer handshake *Handshake - keyPair *KeyPair + keyPair *Keypair } type IndexTable struct { diff --git a/keypair.go b/keypair.go index eaf30b2..07a183d 100644 --- a/keypair.go +++ b/keypair.go @@ -18,7 +18,7 @@ import ( * we plan to resolve this issue; whenever Go allows us to do so. 
*/ -type KeyPair struct { +type Keypair struct { sendNonce uint64 send cipher.AEAD receive cipher.AEAD @@ -29,20 +29,20 @@ type KeyPair struct { remoteIndex uint32 } -type KeyPairs struct { +type Keypairs struct { mutex sync.RWMutex - current *KeyPair - previous *KeyPair - next *KeyPair // not yet "confirmed by transport" + current *Keypair + previous *Keypair + next *Keypair // not yet "confirmed by transport" } -func (kp *KeyPairs) Current() *KeyPair { +func (kp *Keypairs) Current() *Keypair { kp.mutex.RLock() defer kp.mutex.RUnlock() return kp.current } -func (device *Device) DeleteKeyPair(key *KeyPair) { +func (device *Device) DeleteKeypair(key *Keypair) { if key != nil { device.indices.Delete(key.localIndex) } diff --git a/main.go b/main.go index ecfbc50..5001bc4 100644 --- a/main.go +++ b/main.go @@ -30,6 +30,8 @@ func printUsage() { } func warning() { + shouldQuit := false + fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING") fmt.Fprintln(os.Stderr, "W G") fmt.Fprintln(os.Stderr, "W This is alpha software. It will very likely not G") @@ -37,6 +39,8 @@ func warning() { fmt.Fprintln(os.Stderr, "W horribly wrong. You have been warned. Proceed G") fmt.Fprintln(os.Stderr, "W at your own risk. G") if runtime.GOOS == "linux" { + shouldQuit = os.Getenv("WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD") != "1" + fmt.Fprintln(os.Stderr, "W G") fmt.Fprintln(os.Stderr, "W Furthermore, you are running this software on a G") fmt.Fprintln(os.Stderr, "W Linux kernel, which is probably unnecessary and G") @@ -46,9 +50,20 @@ func warning() { fmt.Fprintln(os.Stderr, "W program. 
For more information on installing the G") fmt.Fprintln(os.Stderr, "W kernel module, please visit: G") fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G") + if shouldQuit { + fmt.Fprintln(os.Stderr, "W G") + fmt.Fprintln(os.Stderr, "W If you still want to use this program, against G") + fmt.Fprintln(os.Stderr, "W the sage advice here, please first export this G") + fmt.Fprintln(os.Stderr, "W environment variable: G") + fmt.Fprintln(os.Stderr, "W WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD=1 G") + } } fmt.Fprintln(os.Stderr, "W G") fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING") + + if shouldQuit { + os.Exit(1) + } } func main() { diff --git a/noise-protocol.go b/noise-protocol.go index 35e95ef..3abbe4b 100644 --- a/noise-protocol.go +++ b/noise-protocol.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: GPL-2.0 * - * Copyright (C) 2017-2018 Jason A. Donenfeld . All Rights Reserved. + * Copyright (C) 2015-2018 Jason A. Donenfeld . All Rights Reserved. 
*/ package main @@ -488,7 +488,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { /* Derives a new key-pair from the current handshake state * */ -func (peer *Peer) NewKeyPair() *KeyPair { +func (peer *Peer) NewKeypair() *Keypair { device := peer.device handshake := &peer.handshake handshake.mutex.Lock() @@ -528,7 +528,7 @@ func (peer *Peer) NewKeyPair() *KeyPair { // create AEAD instances - keyPair := new(KeyPair) + keyPair := new(Keypair) keyPair.send, _ = chacha20poly1305.New(sendKey[:]) keyPair.receive, _ = chacha20poly1305.New(recvKey[:]) @@ -559,24 +559,27 @@ func (peer *Peer) NewKeyPair() *KeyPair { kp := &peer.keyPairs kp.mutex.Lock() + peer.timersSessionDerived() + + previous := kp.previous + next := kp.next + current := kp.current + if isInitiator { - if kp.previous != nil { - device.DeleteKeyPair(kp.previous) - kp.previous = nil - } - - if kp.next != nil { - kp.previous = kp.next - kp.next = keyPair + if next != nil { + kp.next = nil + kp.previous = next + device.DeleteKeypair(current) } else { - kp.previous = kp.current - kp.current = keyPair - peer.event.newKeyPair.Fire() + kp.previous = current } - + device.DeleteKeypair(previous) + kp.current = keyPair } else { kp.next = keyPair + device.DeleteKeypair(next) kp.previous = nil + device.DeleteKeypair(previous) } kp.mutex.Unlock() diff --git a/noise_test.go b/noise_test.go index 958a4ef..37bfb94 100644 --- a/noise_test.go +++ b/noise_test.go @@ -102,8 +102,8 @@ func TestNoiseHandshake(t *testing.T) { t.Log("deriving keys") - key1 := peer1.NewKeyPair() - key2 := peer2.NewKeyPair() + key1 := peer1.NewKeypair() + key2 := peer2.NewKeypair() if key1 == nil { t.Fatal("failed to dervice key-pair for peer 1") diff --git a/peer.go b/peer.go index 739c8fb..242729e 100644 --- a/peer.go +++ b/peer.go @@ -14,14 +14,13 @@ import ( ) const ( - PeerRoutineNumber = 4 - EventInterval = 10 * time.Millisecond + PeerRoutineNumber = 3 ) type Peer struct { isRunning AtomicBool mutex sync.RWMutex - 
keyPairs KeyPairs + keyPairs Keypairs handshake Handshake device *Device endpoint Endpoint @@ -34,34 +33,28 @@ type Peer struct { lastHandshakeNano int64 // nano seconds since epoch } - time struct { - mutex sync.RWMutex - lastSend time.Time // last send message - lastHandshake time.Time // last completed handshake - nextKeepalive time.Time + timers struct { + retransmitHandshake *Timer + sendKeepalive *Timer + newHandshake *Timer + zeroKeyMaterial *Timer + persistentKeepalive *Timer + handshakeAttempts uint + needAnotherKeepalive bool + sentLastMinuteHandshake bool + lastSentHandshake time.Time } - event struct { - dataSent *Event - dataReceived *Event - anyAuthenticatedPacketReceived *Event - anyAuthenticatedPacketTraversal *Event - handshakeCompleted *Event - handshakePushDeadline *Event - handshakeBegin *Event - ephemeralKeyCreated *Event - newKeyPair *Event - flushNonceQueue *Event - } - - timer struct { - sendLastMinuteHandshake AtomicBool + signals struct { + newKeypairArrived chan struct{} + flushNonceQueue chan struct{} } queue struct { - nonce chan *QueueOutboundElement // nonce / pre-handshake queue - outbound chan *QueueOutboundElement // sequential ordering of work - inbound chan *QueueInboundElement // sequential ordering of work + nonce chan *QueueOutboundElement // nonce / pre-handshake queue + outbound chan *QueueOutboundElement // sequential ordering of work + inbound chan *QueueInboundElement // sequential ordering of work + packetInNonceQueueIsAwaitingKey bool } routines struct { @@ -188,6 +181,8 @@ func (peer *Peer) Start() { peer.routines.starting.Wait() peer.routines.stopping.Wait() peer.routines.stop = make(chan struct{}) + peer.routines.starting.Add(PeerRoutineNumber) + peer.routines.stopping.Add(PeerRoutineNumber) // prepare queues @@ -195,28 +190,13 @@ func (peer *Peer) Start() { peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) - // events - - 
peer.event.dataSent = newEvent(EventInterval) - peer.event.dataReceived = newEvent(EventInterval) - peer.event.anyAuthenticatedPacketReceived = newEvent(EventInterval) - peer.event.anyAuthenticatedPacketTraversal = newEvent(EventInterval) - peer.event.handshakeCompleted = newEvent(EventInterval) - peer.event.handshakePushDeadline = newEvent(EventInterval) - peer.event.handshakeBegin = newEvent(EventInterval) - peer.event.ephemeralKeyCreated = newEvent(EventInterval) - peer.event.newKeyPair = newEvent(EventInterval) - peer.event.flushNonceQueue = newEvent(EventInterval) - - peer.isRunning.Set(true) + peer.timersInit() + peer.signals.newKeypairArrived = make(chan struct{}, 1) + peer.signals.flushNonceQueue = make(chan struct{}, 1) // wait for routines to start - peer.routines.starting.Add(PeerRoutineNumber) - peer.routines.stopping.Add(PeerRoutineNumber) - go peer.RoutineNonce() - go peer.RoutineTimerHandler() go peer.RoutineSequentialSender() go peer.RoutineSequentialReceiver() @@ -238,6 +218,8 @@ func (peer *Peer) Stop() { device := peer.device device.log.Debug.Println(peer, ": Stopping...") + peer.timersStop() + // stop & wait for ongoing peer routines peer.routines.starting.Wait() @@ -255,9 +237,9 @@ func (peer *Peer) Stop() { kp := &peer.keyPairs kp.mutex.Lock() - device.DeleteKeyPair(kp.previous) - device.DeleteKeyPair(kp.current) - device.DeleteKeyPair(kp.next) + device.DeleteKeypair(kp.previous) + device.DeleteKeypair(kp.current) + device.DeleteKeypair(kp.next) kp.previous = nil kp.current = nil @@ -271,4 +253,6 @@ func (peer *Peer) Stop() { device.indices.Delete(hs.localIndex) hs.Clear() hs.mutex.Unlock() + + peer.FlushNonceQueue() } diff --git a/receive.go b/receive.go index 1cf77b2..0f22a3f 100644 --- a/receive.go +++ b/receive.go @@ -31,7 +31,7 @@ type QueueInboundElement struct { buffer *[MaxMessageSize]byte packet []byte counter uint64 - keyPair *KeyPair + keyPair *Keypair endpoint Endpoint } @@ -99,6 +99,21 @@ func (device *Device) addToHandshakeQueue( 
} } +/* Called when a new authenticated message has been received + * + * NOTE: Not thread safe, but called by sequential receiver! + */ +func (peer *Peer) keepKeyFreshReceiving() { + if peer.timers.sentLastMinuteHandshake { + return + } + kp := peer.keyPairs.Current() + if kp != nil && kp.isInitiator && time.Now().Sub(kp.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { + peer.timers.sentLastMinuteHandshake = true + peer.SendHandshakeInitiation(false) + } +} + /* Receives incoming datagrams for the device * * Every time the bind is updated a new routine is started for @@ -245,7 +260,7 @@ func (device *Device) RoutineDecryption() { for { select { - case <-device.signal.stop.Wait(): + case <-device.signals.stop: return case elem, ok := <-device.queue.decryption: @@ -317,7 +332,7 @@ func (device *Device) RoutineHandshake() { for { select { case elem, ok = <-device.queue.handshake: - case <-device.signal.stop.Wait(): + case <-device.signals.stop: return } @@ -441,8 +456,8 @@ func (device *Device) RoutineHandshake() { // update timers - peer.event.anyAuthenticatedPacketTraversal.Fire() - peer.event.anyAuthenticatedPacketReceived.Fire() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() // update endpoint @@ -460,10 +475,11 @@ func (device *Device) RoutineHandshake() { continue } - peer.TimerEphemeralKeyCreated() - peer.NewKeyPair() + if peer.NewKeypair() == nil { + continue + } - logDebug.Println(peer, ": Creating handshake response") + logDebug.Println(peer, ": Sending handshake response") writer := bytes.NewBuffer(temp[:0]) binary.Write(writer, binary.LittleEndian, response) @@ -472,9 +488,10 @@ func (device *Device) RoutineHandshake() { // send response + peer.timers.lastSentHandshake = time.Now() err = peer.SendBuffer(packet) if err == nil { - peer.event.anyAuthenticatedPacketTraversal.Fire() + peer.timersAnyAuthenticatedPacketTraversal() } else { logError.Println(peer, ": Failed to send handshake response", err) 
} @@ -510,18 +527,23 @@ func (device *Device) RoutineHandshake() { logDebug.Println(peer, ": Received handshake response") - peer.TimerEphemeralKeyCreated() - // update timers - peer.event.anyAuthenticatedPacketTraversal.Fire() - peer.event.anyAuthenticatedPacketReceived.Fire() - peer.event.handshakeCompleted.Fire() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() // derive key-pair - peer.NewKeyPair() - peer.SendKeepAlive() + if peer.NewKeypair() == nil { + continue + } + + peer.timersHandshakeComplete() + peer.SendKeepalive() + select { + case peer.signals.newKeypairArrived <- struct{}{}: + default: + } } } } @@ -569,38 +591,41 @@ func (peer *Peer) RoutineSequentialReceiver() { continue } - peer.event.anyAuthenticatedPacketTraversal.Fire() - peer.event.anyAuthenticatedPacketReceived.Fire() - peer.KeepKeyFreshReceiving() - - // check if using new key-pair - - kp := &peer.keyPairs - kp.mutex.Lock() - if kp.next == elem.keyPair { - peer.event.handshakeCompleted.Fire() - if kp.previous != nil { - device.DeleteKeyPair(kp.previous) - } - kp.previous = kp.current - kp.current = kp.next - kp.next = nil - } - kp.mutex.Unlock() - // update endpoint peer.mutex.Lock() peer.endpoint = elem.endpoint peer.mutex.Unlock() - // check for keep-alive + // check if using new key-pair + + kp := &peer.keyPairs + kp.mutex.Lock() //TODO: make this into an RW lock to reduce contention here for the equality check which is rarely true + if kp.next == elem.keyPair { + old := kp.previous + kp.previous = kp.current + device.DeleteKeypair(old) + kp.current = kp.next + kp.next = nil + peer.timersHandshakeComplete() + select { + case peer.signals.newKeypairArrived <- struct{}{}: + default: + } + } + kp.mutex.Unlock() + + peer.keepKeyFreshReceiving() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() + + // check for keepalive if len(elem.packet) == 0 { - logDebug.Println(peer, ": Received keep-alive") + 
logDebug.Println(peer, ": Receiving keepalive packet") continue } - peer.event.dataReceived.Fire() + peer.timersDataReceived() // verify source and strip padding diff --git a/send.go b/send.go index ddebb99..1b35e27 100644 --- a/send.go +++ b/send.go @@ -6,6 +6,7 @@ package main import ( + "bytes" "encoding/binary" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" @@ -46,21 +47,10 @@ type QueueOutboundElement struct { buffer *[MaxMessageSize]byte // slice holding the packet data packet []byte // slice of "buffer" (always!) nonce uint64 // nonce for encryption - keyPair *KeyPair // key-pair for encryption + keyPair *Keypair // key-pair for encryption peer *Peer // related peer } -func (peer *Peer) flushNonceQueue() { - elems := len(peer.queue.nonce) - for i := 0; i < elems; i++ { - select { - case <-peer.queue.nonce: - default: - return - } - } -} - func (device *Device) NewOutboundElement() *QueueOutboundElement { return &QueueOutboundElement{ dropped: AtomicFalse, @@ -114,6 +104,73 @@ func addToEncryptionQueue( } } +/* Queues a keepalive if no packets are queued for peer + */ +func (peer *Peer) SendKeepalive() bool { + if len(peer.queue.nonce) != 0 || peer.queue.packetInNonceQueueIsAwaitingKey { + return false + } + elem := peer.device.NewOutboundElement() + elem.packet = nil + select { + case peer.queue.nonce <- elem: + peer.device.log.Debug.Println(peer, ": Sending keepalive packet") + return true + default: + return false + } +} + +/* Sends a new handshake initiation message to the peer (endpoint) + */ +func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { + if !isRetry { + peer.timers.handshakeAttempts = 0 + } + + if time.Now().Sub(peer.timers.lastSentHandshake) < RekeyTimeout { + return nil + } + peer.timers.lastSentHandshake = time.Now() //TODO: locking for this variable? 
+ + // create initiation message + + msg, err := peer.device.CreateMessageInitiation(peer) + if err != nil { + return err + } + + peer.device.log.Debug.Println(peer, ": Sending handshake initiation") + + // marshal handshake message + + var buff [MessageInitiationSize]byte + writer := bytes.NewBuffer(buff[:0]) + binary.Write(writer, binary.LittleEndian, msg) + packet := writer.Bytes() + peer.mac.AddMacs(packet) + + // send to endpoint + + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersHandshakeInitiated() + return peer.SendBuffer(packet) +} + +/* Called when a new authenticated message has been send + * + */ +func (peer *Peer) keepKeyFreshSending() { + kp := peer.keyPairs.Current() + if kp == nil { + return + } + nonce := atomic.LoadUint64(&kp.sendNonce) + if nonce > RekeyAfterMessages || (kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime) { + peer.SendHandshakeInitiation(false) + } +} + /* Reads packets from the TUN and inserts * into nonce queue for peer * @@ -180,13 +237,22 @@ func (device *Device) RoutineReadFromTUN() { // insert into nonce/pre-handshake queue if peer.isRunning.Get() { - peer.event.handshakePushDeadline.Fire() + if peer.queue.packetInNonceQueueIsAwaitingKey { + peer.SendHandshakeInitiation(false) + } addToOutboundQueue(peer.queue.nonce, elem) elem = device.NewOutboundElement() } } } +func (peer *Peer) FlushNonceQueue() { + select { + case peer.signals.flushNonceQueue <- struct{}{}: + default: + } +} + /* Queues packets when there is no handshake. * Then assigns nonces to packets sequentially * and creates "work" structs for workers @@ -194,13 +260,14 @@ func (device *Device) RoutineReadFromTUN() { * Obs. 
A single instance per peer */ func (peer *Peer) RoutineNonce() { - var keyPair *KeyPair + var keyPair *Keypair device := peer.device logDebug := device.log.Debug defer func() { logDebug.Println(peer, ": Routine: nonce worker - stopped") + peer.queue.packetInNonceQueueIsAwaitingKey = false peer.routines.stopping.Done() }() @@ -209,8 +276,7 @@ func (peer *Peer) RoutineNonce() { for { NextPacket: - - peer.event.flushNonceQueue.Clear() + peer.queue.packetInNonceQueueIsAwaitingKey = false select { case <-peer.routines.stop: @@ -225,34 +291,48 @@ func (peer *Peer) RoutineNonce() { // wait for key pair for { - - peer.event.newKeyPair.Clear() - keyPair = peer.keyPairs.Current() if keyPair != nil && keyPair.sendNonce < RejectAfterMessages { if time.Now().Sub(keyPair.created) < RejectAfterTime { break } } + peer.queue.packetInNonceQueueIsAwaitingKey = true - peer.event.handshakeBegin.Fire() + select { + case <-peer.signals.newKeypairArrived: + default: + } + + peer.SendHandshakeInitiation(false) logDebug.Println(peer, ": Awaiting key-pair") select { - case <-peer.event.newKeyPair.C: + case <-peer.signals.newKeypairArrived: logDebug.Println(peer, ": Obtained awaited key-pair") - case <-peer.event.flushNonceQueue.C: - goto NextPacket + case <-peer.signals.flushNonceQueue: + for { + select { + case <-peer.queue.nonce: + default: + goto NextPacket + } + } case <-peer.routines.stop: return } } + peer.queue.packetInNonceQueueIsAwaitingKey = false // populate work element elem.peer = peer elem.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1 + // double check in case of race condition added by future code + if elem.nonce >= RejectAfterMessages { + goto NextPacket + } elem.keyPair = keyPair elem.dropped = AtomicFalse elem.mutex.Lock() @@ -288,7 +368,7 @@ func (device *Device) RoutineEncryption() { // fetch next element select { - case <-device.signal.stop.Wait(): + case <-device.signals.stop: return case elem, ok := <-device.queue.encryption: @@ -389,11 +469,11 @@ func (peer 
*Peer) RoutineSequentialSender() { // update timers - peer.event.anyAuthenticatedPacketTraversal.Fire() + peer.timersAnyAuthenticatedPacketTraversal() if len(elem.packet) != MessageKeepaliveSize { - peer.event.dataSent.Fire() + peer.timersDataSent() } - peer.KeepKeyFreshSending() + peer.keepKeyFreshSending() } } } diff --git a/signal.go b/signal.go deleted file mode 100644 index 606da52..0000000 --- a/signal.go +++ /dev/null @@ -1,71 +0,0 @@ -/* SPDX-License-Identifier: GPL-2.0 - * - * Copyright (C) 2017-2018 Jason A. Donenfeld . All Rights Reserved. - */ - -package main - -func signalSend(s chan<- struct{}) { - select { - case s <- struct{}{}: - default: - } -} - -type Signal struct { - enabled AtomicBool - C chan struct{} -} - -func NewSignal() (s Signal) { - s.C = make(chan struct{}, 1) - s.Enable() - return -} - -func (s *Signal) Close() { - close(s.C) -} - -func (s *Signal) Disable() { - s.enabled.Set(false) - s.Clear() -} - -func (s *Signal) Enable() { - s.enabled.Set(true) -} - -/* Unblock exactly one listener - */ -func (s *Signal) Send() { - if s.enabled.Get() { - select { - case s.C <- struct{}{}: - default: - } - } -} - -/* Clear the signal if already fired - */ -func (s Signal) Clear() { - select { - case <-s.C: - default: - } -} - -/* Unblocks all listeners (forever) - */ -func (s Signal) Broadcast() { - if s.enabled.Get() { - close(s.C) - } -} - -/* Wait for the signal - */ -func (s Signal) Wait() chan struct{} { - return s.C -} diff --git a/timers.go b/timers.go index 38c9b46..5c72efd 100644 --- a/timers.go +++ b/timers.go @@ -1,355 +1,221 @@ /* SPDX-License-Identifier: GPL-2.0 * - * Copyright (C) 2017-2018 Jason A. Donenfeld . All Rights Reserved. + * Copyright (C) 2015-2018 Jason A. Donenfeld . All Rights Reserved. + * + * This is based heavily on timers.c from the kernel implementation. 
*/ package main import ( - "bytes" - "encoding/binary" "math/rand" "sync/atomic" "time" ) -/* NOTE: - * Notion of validity +/* This Timer structure and related functions should roughly copy the interface of + * the Linux kernel's struct timer_list. */ -/* Called when a new authenticated message has been send - * - */ -func (peer *Peer) KeepKeyFreshSending() { - kp := peer.keyPairs.Current() - if kp == nil { - return - } - nonce := atomic.LoadUint64(&kp.sendNonce) - if nonce > RekeyAfterMessages { - peer.event.handshakeBegin.Fire() - } - if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime { - peer.event.handshakeBegin.Fire() - } +type Timer struct { + timer *time.Timer + isPending bool } -/* Called when a new authenticated message has been received - * - * NOTE: Not thread safe, but called by sequential receiver! - */ -func (peer *Peer) KeepKeyFreshReceiving() { - if peer.timer.sendLastMinuteHandshake.Get() { - return - } - kp := peer.keyPairs.Current() - if kp == nil { - return - } - if !kp.isInitiator { - return - } - nonce := atomic.LoadUint64(&kp.sendNonce) - send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving - if send { - // do a last minute attempt at initiating a new handshake - peer.timer.sendLastMinuteHandshake.Set(true) - peer.event.handshakeBegin.Fire() - } -} - -/* Queues a keep-alive if no packets are queued for peer - */ -func (peer *Peer) SendKeepAlive() bool { - if len(peer.queue.nonce) != 0 { - return false - } - elem := peer.device.NewOutboundElement() - elem.packet = nil - select { - case peer.queue.nonce <- elem: - return true - default: - return false - } -} - -/* Called after successfully completing a handshake. - * i.e. after: - * - * - Valid handshake response - * - First transport message under the "next" key - */ -// peer.device.log.Info.Println(peer, ": New handshake completed") - -/* Event: - * An ephemeral key is generated - * - * i.e. 
after: - * - * CreateMessageInitiation - * CreateMessageResponse - * - * Action: - * Schedule the deletion of all key material - * upon failure to complete a handshake - */ -func (peer *Peer) TimerEphemeralKeyCreated() { - peer.event.ephemeralKeyCreated.Fire() - // peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3) -} - -/* Sends a new handshake initiation message to the peer (endpoint) - */ -func (peer *Peer) sendNewHandshake() error { - - // create initiation message - - msg, err := peer.device.CreateMessageInitiation(peer) - if err != nil { - return err - } - - // marshal handshake message - - var buff [MessageInitiationSize]byte - writer := bytes.NewBuffer(buff[:0]) - binary.Write(writer, binary.LittleEndian, msg) - packet := writer.Bytes() - peer.mac.AddMacs(packet) - - // send to endpoint - - peer.event.anyAuthenticatedPacketTraversal.Fire() - - return peer.SendBuffer(packet) -} - -func newTimer() *time.Timer { - timer := time.NewTimer(time.Hour) - timer.Stop() +func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer { + timer := &Timer{} + timer.timer = time.AfterFunc(time.Hour, func() { + timer.isPending = false + expirationFunction(peer) + }) + timer.timer.Stop() return timer } -func (peer *Peer) RoutineTimerHandler() { +func (timer *Timer) Mod(d time.Duration) { + timer.isPending = true + timer.timer.Reset(d) +} - device := peer.device +func (timer *Timer) Del() { + timer.isPending = false + timer.timer.Stop() +} - logInfo := device.log.Info - logDebug := device.log.Debug +func (peer *Peer) timersActive() bool { + return peer.isRunning.Get() && peer.device != nil && peer.device.isUp.Get() && len(peer.device.peers.keyMap) > 0 +} - defer func() { - logDebug.Println(peer, ": Routine: timer handler - stopped") - peer.routines.stopping.Done() - }() +func expiredRetransmitHandshake(peer *Peer) { + if peer.timers.handshakeAttempts > MaxTimerHandshakes { + peer.device.log.Debug.Printf("%s: Handshake did not complete after %d attempts, giving up\n", peer, 
MaxTimerHandshakes+2) - logDebug.Println(peer, ": Routine: timer handler - started") + if peer.timersActive() { + peer.timers.sendKeepalive.Del() + } - // reset all timers + /* We drop all packets without a keypair and don't try again, + * if we try unsuccessfully for too long to make a handshake. + */ + peer.FlushNonceQueue() - enableHandshake := true - pendingHandshakeNew := false - pendingKeepalivePassive := false - needAnotherKeepalive := false + /* We set a timer for destroying any residue that might be left + * of a partial exchange. + */ + if peer.timersActive() && !peer.timers.zeroKeyMaterial.isPending { + peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) + } + } else { + peer.timers.handshakeAttempts++ + peer.device.log.Debug.Printf("%s: Handshake did not complete after %d seconds, retrying (try %d)\n", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts+1) - timerKeepalivePassive := newTimer() - timerHandshakeDeadline := newTimer() - timerHandshakeTimeout := newTimer() - timerHandshakeNew := newTimer() - timerZeroAllKeys := newTimer() - timerKeepalivePersistent := newTimer() + /* We clear the endpoint address src address, in case this is the cause of trouble. 
*/ + peer.mutex.Lock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + peer.mutex.Unlock() - interval := peer.persistentKeepaliveInterval - if interval > 0 { - duration := time.Duration(interval) * time.Second - timerKeepalivePersistent.Reset(duration) + peer.SendHandshakeInitiation(true) } +} - // signal synchronised setup complete - - peer.routines.starting.Done() - - // handle timer events - - for { - select { - - /* stopping */ - - case <-peer.routines.stop: - return - - /* events */ - - case <-peer.event.dataSent.C: - timerKeepalivePassive.Stop() - if !pendingHandshakeNew { - timerHandshakeNew.Reset(NewHandshakeTime) - } - - case <-peer.event.dataReceived.C: - if pendingKeepalivePassive { - needAnotherKeepalive = true - } else { - timerKeepalivePassive.Reset(KeepaliveTimeout) - } - - case <-peer.event.anyAuthenticatedPacketTraversal.C: - interval := peer.persistentKeepaliveInterval - if interval > 0 { - duration := time.Duration(interval) * time.Second - timerKeepalivePersistent.Reset(duration) - } - - case <-peer.event.handshakeBegin.C: - - if !enableHandshake { - continue - } - - logDebug.Println(peer, ": Event, Handshake Begin") - - err := peer.sendNewHandshake() - - // set timeout - - jitter := time.Millisecond * time.Duration(rand.Int31n(334)) - timerKeepalivePassive.Stop() - timerHandshakeTimeout.Reset(RekeyTimeout + jitter) - - if err != nil { - logInfo.Println(peer, ": Failed to send handshake initiation", err) - } else { - logDebug.Println(peer, ": Send handshake initiation (initial)") - } - - timerHandshakeDeadline.Reset(RekeyAttemptTime) - - // disable further handshakes - - peer.event.handshakeBegin.Clear() - enableHandshake = false - - case <-peer.event.handshakeCompleted.C: - - logInfo.Println(peer, ": Handshake completed") - - atomic.StoreInt64( - &peer.stats.lastHandshakeNano, - time.Now().UnixNano(), - ) - - timerHandshakeTimeout.Stop() - timerHandshakeDeadline.Stop() - peer.timer.sendLastMinuteHandshake.Set(false) - - // allow 
further handshakes - - peer.event.handshakeBegin.Clear() - enableHandshake = true - - /* timers */ - - case <-timerKeepalivePersistent.C: - - interval := peer.persistentKeepaliveInterval - if interval > 0 { - logDebug.Println(peer, ": Send keep-alive (persistent)") - timerKeepalivePassive.Stop() - peer.SendKeepAlive() - } - - case <-timerKeepalivePassive.C: - - logDebug.Println(peer, ": Send keep-alive (passive)") - - peer.SendKeepAlive() - - if needAnotherKeepalive { - timerKeepalivePassive.Reset(KeepaliveTimeout) - needAnotherKeepalive = false - } - - case <-timerZeroAllKeys.C: - - logDebug.Println(peer, ": Clear all key-material (timer event)") - - hs := &peer.handshake - hs.mutex.Lock() - - kp := &peer.keyPairs - kp.mutex.Lock() - - // remove key-pairs - - if kp.previous != nil { - device.DeleteKeyPair(kp.previous) - kp.previous = nil - } - if kp.current != nil { - device.DeleteKeyPair(kp.current) - kp.current = nil - } - if kp.next != nil { - device.DeleteKeyPair(kp.next) - kp.next = nil - } - kp.mutex.Unlock() - - // zero out handshake - - device.indices.Delete(hs.localIndex) - hs.Clear() - hs.mutex.Unlock() - - case <-timerHandshakeTimeout.C: - - // allow new handshake to be send - - enableHandshake = true - - // clear source (in case this is causing problems) - - peer.mutex.Lock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - peer.mutex.Unlock() - - // send new handshake - - err := peer.sendNewHandshake() - - // set timeout - - jitter := time.Millisecond * time.Duration(rand.Int31n(334)) - timerKeepalivePassive.Stop() - timerHandshakeTimeout.Reset(RekeyTimeout + jitter) - - if err != nil { - logInfo.Println(peer, ": Failed to send handshake initiation", err) - } else { - logDebug.Println(peer, ": Send handshake initiation (subsequent)") - } - - // disable further handshakes - - peer.event.handshakeBegin.Clear() - enableHandshake = false - - case <-timerHandshakeDeadline.C: - - // clear all queued packets and stop keep-alive - - 
logInfo.Println(peer, ": Handshake negotiation timed-out") - - peer.flushNonceQueue() - peer.event.flushNonceQueue.Fire() - - // renable further handshakes - - peer.event.handshakeBegin.Clear() - enableHandshake = true +func expiredSendKeepalive(peer *Peer) { + peer.SendKeepalive() + if peer.timers.needAnotherKeepalive { + peer.timers.needAnotherKeepalive = false + if peer.timersActive() { + peer.timers.sendKeepalive.Mod(KeepaliveTimeout) } } } + +func expiredNewHandshake(peer *Peer) { + peer.device.log.Debug.Printf("%s: Retrying handshake because we stopped hearing back after %d seconds\n", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) + /* Clear the endpoint's source address, in case it is the cause of the trouble, before retrying the handshake. */ + peer.mutex.Lock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + peer.mutex.Unlock() + peer.SendHandshakeInitiation(false) + +} + +func expiredZeroKeyMaterial(peer *Peer) { + peer.device.log.Debug.Printf("%s: Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds())) + + hs := &peer.handshake + hs.mutex.Lock() + + kp := &peer.keyPairs + kp.mutex.Lock() + + if kp.previous != nil { + peer.device.DeleteKeypair(kp.previous) + kp.previous = nil + } + if kp.current != nil { + peer.device.DeleteKeypair(kp.current) + kp.current = nil + } + if kp.next != nil { + peer.device.DeleteKeypair(kp.next) + kp.next = nil + } + kp.mutex.Unlock() + + peer.device.indices.Delete(hs.localIndex) + hs.Clear() + hs.mutex.Unlock() +} + +func expiredPersistentKeepalive(peer *Peer) { + if peer.persistentKeepaliveInterval > 0 { + if peer.timersActive() { + peer.timers.sendKeepalive.Del() + } + peer.SendKeepalive() + } +} + +/* Should be called after an authenticated data packet is sent. 
*/ +func (peer *Peer) timersDataSent() { + if peer.timersActive() { + peer.timers.sendKeepalive.Del() + } + + if peer.timersActive() && !peer.timers.newHandshake.isPending { + peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout) + } +} + +/* Should be called after an authenticated data packet is received. If a keepalive is already pending, needAnotherKeepalive is set instead, so that expiredSendKeepalive can schedule one more. */ +func (peer *Peer) timersDataReceived() { + if peer.timersActive() { + if !peer.timers.sendKeepalive.isPending { + peer.timers.sendKeepalive.Mod(KeepaliveTimeout) + } else { + peer.timers.needAnotherKeepalive = true + } + } +} + +/* Should be called after any type of authenticated packet is received -- keepalive or data. Removes the newHandshake timer via Del(). */ +func (peer *Peer) timersAnyAuthenticatedPacketReceived() { + if peer.timersActive() { + peer.timers.newHandshake.Del() + } +} + +/* Should be called after a handshake initiation message is sent. The retransmitHandshake timer is set via Mod() to RekeyTimeout plus a random jitter of less than RekeyTimeoutJitterMaxMs milliseconds. */ +func (peer *Peer) timersHandshakeInitiated() { + if peer.timersActive() { + peer.timers.sendKeepalive.Del() + peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs))) + } +} + +/* Should be called after a handshake response message is received and processed, or when getting key confirmation via the first data message. handshakeAttempts, sentLastMinuteHandshake and stats.lastHandshakeNano are updated even when the timers are not active. */ +func (peer *Peer) timersHandshakeComplete() { + if peer.timersActive() { + peer.timers.retransmitHandshake.Del() + } + peer.timers.handshakeAttempts = 0 + peer.timers.sentLastMinuteHandshake = false + atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano()) +} + +/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. Sets the zeroKeyMaterial timer via Mod() to RejectAfterTime * 3. */ +func (peer *Peer) timersSessionDerived() { + if peer.timersActive() { + peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) + } +} + +/* Should be called before a packet with authentication -- data, keepalive, either handshake -- is sent, or after one is received. 
*/ +func (peer *Peer) timersAnyAuthenticatedPacketTraversal() { + if peer.persistentKeepaliveInterval > 0 && peer.timersActive() { + peer.timers.persistentKeepalive.Mod(time.Duration(peer.persistentKeepaliveInterval) * time.Second) + } +} + +func (peer *Peer) timersInit() { + peer.timers.retransmitHandshake = peer.NewTimer(expiredRetransmitHandshake) + peer.timers.sendKeepalive = peer.NewTimer(expiredSendKeepalive) + peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake) + peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial) + peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive) + peer.timers.handshakeAttempts = 0 + peer.timers.sentLastMinuteHandshake = false + peer.timers.needAnotherKeepalive = false + peer.timers.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) +} + +func (peer *Peer) timersStop() { + peer.timers.retransmitHandshake.Del() + peer.timers.sendKeepalive.Del() + peer.timers.newHandshake.Del() + peer.timers.zeroKeyMaterial.Del() + peer.timers.persistentKeepalive.Del() +} diff --git a/uapi.go b/uapi.go index 54d9bae..4b2038b 100644 --- a/uapi.go +++ b/uapi.go @@ -256,8 +256,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { logDebug.Println("UAPI: Created new peer:", peer) } - peer.event.handshakePushDeadline.Fire() - case "remove": // remove currently selected peer from device @@ -288,8 +286,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { return &IPCError{Code: ipcErrorInvalid} } - peer.event.handshakePushDeadline.Fire() - case "endpoint": // set endpoint destination @@ -304,7 +300,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { return err } peer.endpoint = endpoint - peer.event.handshakePushDeadline.Fire() return nil }() @@ -315,7 +310,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { case "persistent_keepalive_interval": - // update keep-alive interval + // update 
persistent keepalive interval logDebug.Println("UAPI: Updating persistent_keepalive_interval for peer:", peer) @@ -328,7 +323,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { old := peer.persistentKeepaliveInterval peer.persistentKeepaliveInterval = uint16(secs) - // send immediate keep-alive + // send immediate keepalive if we're turning it on and it wasn't on before if old == 0 && secs != 0 { if err != nil { @@ -336,7 +331,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { return &IPCError{Code: ipcErrorIO} } if device.isUp.Get() && !dummy { - peer.SendKeepAlive() + peer.SendKeepalive() } }