diff --git a/device/keypair.go b/device/keypair.go index 9c78fa9..d70c7f4 100644 --- a/device/keypair.go +++ b/device/keypair.go @@ -8,7 +8,9 @@ package device import ( "crypto/cipher" "sync" + "sync/atomic" "time" + "unsafe" "golang.zx2c4.com/wireguard/replay" ) @@ -38,6 +40,14 @@ type Keypairs struct { next *Keypair } +func (kp *Keypairs) storeNext(next *Keypair) { + atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next)) +} + +func (kp *Keypairs) loadNext() *Keypair { + return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)))) +} + func (kp *Keypairs) Current() *Keypair { kp.RLock() defer kp.RUnlock() diff --git a/device/noise-protocol.go b/device/noise-protocol.go index a848c47..e6f676c 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -14,6 +14,7 @@ import ( "golang.org/x/crypto/blake2s" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/poly1305" + "golang.zx2c4.com/wireguard/tai64n" ) @@ -583,12 +584,12 @@ func (peer *Peer) BeginSymmetricSession() error { defer keypairs.Unlock() previous := keypairs.previous - next := keypairs.next + next := keypairs.loadNext() current := keypairs.current if isInitiator { if next != nil { - keypairs.next = nil + keypairs.storeNext(nil) keypairs.previous = next device.DeleteKeypair(current) } else { @@ -597,7 +598,7 @@ func (peer *Peer) BeginSymmetricSession() error { device.DeleteKeypair(previous) keypairs.current = keypair } else { - keypairs.next = keypair + keypairs.storeNext(keypair) device.DeleteKeypair(next) keypairs.previous = nil device.DeleteKeypair(previous) @@ -608,18 +609,19 @@ func (peer *Peer) BeginSymmetricSession() error { func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { keypairs := &peer.keypairs - if keypairs.next != receivedKeypair { + + if keypairs.loadNext() != receivedKeypair { return false } keypairs.Lock() defer keypairs.Unlock() - if keypairs.next != receivedKeypair { + if keypairs.loadNext() != receivedKeypair { return false } old := keypairs.previous keypairs.previous = keypairs.current peer.device.DeleteKeypair(old) - keypairs.current = keypairs.next - keypairs.next = nil + keypairs.current = keypairs.loadNext() + keypairs.storeNext(nil) return true } diff --git a/device/noise_test.go b/device/noise_test.go index 6ba3f2e..b5d5845 100644 --- a/device/noise_test.go +++ b/device/noise_test.go @@ -113,7 +113,7 @@ func TestNoiseHandshake(t *testing.T) { t.Fatal("failed to derive keypair for peer 2", err) } - key1 := peer1.keypairs.next + key1 := peer1.keypairs.loadNext() key2 := peer2.keypairs.current // encrypting / decryption test diff --git a/device/peer.go b/device/peer.go index 79d4981..899591b 100644 --- a/device/peer.go +++ b/device/peer.go @@ -223,10 +223,10 @@ func (peer *Peer) ZeroAndFlushAll() { keypairs.Lock() device.DeleteKeypair(keypairs.previous) device.DeleteKeypair(keypairs.current) - device.DeleteKeypair(keypairs.next) + device.DeleteKeypair(keypairs.loadNext()) keypairs.previous = nil keypairs.current = nil - keypairs.next = nil + keypairs.storeNext(nil) keypairs.Unlock() // clear handshake state @@ -254,7 +254,7 @@ func (peer *Peer) ExpireCurrentKeypairs() { keypairs.current.sendNonce = RejectAfterMessages } if keypairs.next != nil { - keypairs.next.sendNonce = RejectAfterMessages + keypairs.loadNext().sendNonce = RejectAfterMessages } keypairs.Unlock() }