device: uniformly check ECDH output for zeros

For some reason, this was omitted for response messages.

Reported-by: z <dzm@unexpl0.red>
Fixes: 8c34c4c ("First set of code review patches")
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2023-02-16 15:51:30 +01:00
parent 1e2c3e5a3c
commit c7b76d3d9e
5 changed files with 45 additions and 38 deletions

View File

@ -265,7 +265,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
handshake := &peer.handshake handshake := &peer.handshake
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
expiredPeers = append(expiredPeers, peer) expiredPeers = append(expiredPeers, peer)
} }

View File

@ -9,6 +9,7 @@ import (
"crypto/hmac" "crypto/hmac"
"crypto/rand" "crypto/rand"
"crypto/subtle" "crypto/subtle"
"errors"
"hash" "hash"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
@ -94,9 +95,14 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
return return
} }
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) { var errInvalidPublicKey = errors.New("invalid public key")
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) {
apk := (*[NoisePublicKeySize]byte)(&pk) apk := (*[NoisePublicKeySize]byte)(&pk)
ask := (*[NoisePrivateKeySize]byte)(sk) ask := (*[NoisePrivateKeySize]byte)(sk)
curve25519.ScalarMult(&ss, ask, apk) curve25519.ScalarMult(&ss, ask, apk)
return ss if isZero(ss[:]) {
return ss, errInvalidPublicKey
}
return ss, nil
} }

View File

@ -175,8 +175,6 @@ func init() {
} }
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
errZeroECDHResult := errors.New("ECDH returned all zeros")
device.staticIdentity.RLock() device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock() defer device.staticIdentity.RUnlock()
@ -204,9 +202,9 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:])
// encrypt static key // encrypt static key
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
if isZero(ss[:]) { if err != nil {
return nil, errZeroECDHResult return nil, err
} }
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
KDF2( KDF2(
@ -221,7 +219,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
// encrypt timestamp // encrypt timestamp
if isZero(handshake.precomputedStaticStatic[:]) { if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errZeroECDHResult return nil, errInvalidPublicKey
} }
KDF2( KDF2(
&handshake.chainKey, &handshake.chainKey,
@ -264,11 +262,10 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key // decrypt static key
var err error
var peerPK NoisePublicKey var peerPK NoisePublicKey
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
if isZero(ss[:]) { if err != nil {
return nil return nil
} }
KDF2(&chainKey, &key, chainKey[:], ss[:]) KDF2(&chainKey, &key, chainKey[:], ss[:])
@ -384,12 +381,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:])
handshake.mixKey(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:])
func() { ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) if err != nil {
handshake.mixKey(ss[:]) return nil, err
ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) }
handshake.mixKey(ss[:]) handshake.mixKey(ss[:])
}() ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
if err != nil {
return nil, err
}
handshake.mixKey(ss[:])
// add preshared key // add preshared key
@ -406,11 +407,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(tau[:]) handshake.mixHash(tau[:])
func() { aead, _ := chacha20poly1305.New(key[:])
aead, _ := chacha20poly1305.New(key[:]) aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) handshake.mixHash(msg.Empty[:])
handshake.mixHash(msg.Empty[:])
}()
handshake.state = handshakeResponseCreated handshake.state = handshakeResponseCreated
@ -455,17 +454,19 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
func() { ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) if err != nil {
mixKey(&chainKey, &chainKey, ss[:]) return false
setZero(ss[:]) }
}() mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])
func() { ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) if err != nil {
mixKey(&chainKey, &chainKey, ss[:]) return false
setZero(ss[:]) }
}() mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])
// add preshared key (psk) // add preshared key (psk)
@ -483,7 +484,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
// authenticate transcript // authenticate transcript
aead, _ := chacha20poly1305.New(key[:]) aead, _ := chacha20poly1305.New(key[:])
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil { if err != nil {
return false return false
} }

View File

@ -24,10 +24,10 @@ func TestCurveWrappers(t *testing.T) {
pk1 := sk1.publicKey() pk1 := sk1.publicKey()
pk2 := sk2.publicKey() pk2 := sk2.publicKey()
ss1 := sk1.sharedSecret(pk2) ss1, err1 := sk1.sharedSecret(pk2)
ss2 := sk2.sharedSecret(pk1) ss2, err2 := sk2.sharedSecret(pk1)
if ss1 != ss2 { if ss1 != ss2 || err1 != nil || err2 != nil {
t.Fatal("Failed to compute shared secet") t.Fatal("Failed to compute shared secet")
} }
} }

View File

@ -92,7 +92,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// pre-compute DH // pre-compute DH
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
handshake.remoteStatic = pk handshake.remoteStatic = pk
handshake.mutex.Unlock() handshake.mutex.Unlock()