device: use atomic access for unlocked keypair.next

Go's GC semantics might not always guarantee the safety of this, and the
race detector gets upset too, so instead we wrap this all in atomic
accessors.

Reported-by: David Anderson <danderson@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2020-05-02 01:30:23 -06:00
parent fdba6c183a
commit 28c4d04304
4 changed files with 23 additions and 11 deletions

View File

@ -8,7 +8,9 @@ package device
import (
"crypto/cipher"
"sync"
"sync/atomic"
"time"
"unsafe"
"golang.zx2c4.com/wireguard/replay"
)
@ -38,6 +40,14 @@ type Keypairs struct {
next *Keypair
}
func (kp *Keypairs) storeNext(next *Keypair) {
atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
}
func (kp *Keypairs) loadNext() *Keypair {
return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
}
func (kp *Keypairs) Current() *Keypair {
kp.RLock()
defer kp.RUnlock()

View File

@ -14,6 +14,7 @@ import (
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305"
"golang.zx2c4.com/wireguard/tai64n"
)
@ -583,12 +584,12 @@ func (peer *Peer) BeginSymmetricSession() error {
defer keypairs.Unlock()
previous := keypairs.previous
next := keypairs.next
next := keypairs.loadNext()
current := keypairs.current
if isInitiator {
if next != nil {
keypairs.next = nil
keypairs.storeNext(nil)
keypairs.previous = next
device.DeleteKeypair(current)
} else {
@ -597,7 +598,7 @@ func (peer *Peer) BeginSymmetricSession() error {
device.DeleteKeypair(previous)
keypairs.current = keypair
} else {
keypairs.next = keypair
keypairs.storeNext(keypair)
device.DeleteKeypair(next)
keypairs.previous = nil
device.DeleteKeypair(previous)
@ -608,18 +609,19 @@ func (peer *Peer) BeginSymmetricSession() error {
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
keypairs := &peer.keypairs
if keypairs.next != receivedKeypair {
if keypairs.loadNext() != receivedKeypair {
return false
}
keypairs.Lock()
defer keypairs.Unlock()
if keypairs.next != receivedKeypair {
if keypairs.loadNext() != receivedKeypair {
return false
}
old := keypairs.previous
keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old)
keypairs.current = keypairs.next
keypairs.next = nil
keypairs.current = keypairs.loadNext()
keypairs.storeNext(nil)
return true
}

View File

@ -113,7 +113,7 @@ func TestNoiseHandshake(t *testing.T) {
t.Fatal("failed to derive keypair for peer 2", err)
}
key1 := peer1.keypairs.next
key1 := peer1.keypairs.loadNext()
key2 := peer2.keypairs.current
// encrypting / decryption test

View File

@ -223,10 +223,10 @@ func (peer *Peer) ZeroAndFlushAll() {
keypairs.Lock()
device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current)
device.DeleteKeypair(keypairs.next)
device.DeleteKeypair(keypairs.loadNext())
keypairs.previous = nil
keypairs.current = nil
keypairs.next = nil
keypairs.storeNext(nil)
keypairs.Unlock()
// clear handshake state
@ -254,7 +254,7 @@ func (peer *Peer) ExpireCurrentKeypairs() {
keypairs.current.sendNonce = RejectAfterMessages
}
if keypairs.next != nil {
keypairs.next.sendNonce = RejectAfterMessages
keypairs.loadNext().sendNonce = RejectAfterMessages
}
keypairs.Unlock()
}