device: uniformly check ECDH output for zeros
For some reason, this was omitted for response messages.
Reported-by: z <dzm@unexpl0.red>
Fixes: 8c34c4c
("First set of code review patches")
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
parent
1e2c3e5a3c
commit
c7b76d3d9e
@ -265,7 +265,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
|||||||
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
|
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
|
||||||
for _, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
handshake := &peer.handshake
|
handshake := &peer.handshake
|
||||||
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
|
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
|
||||||
expiredPeers = append(expiredPeers, peer)
|
expiredPeers = append(expiredPeers, peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
|
"errors"
|
||||||
"hash"
|
"hash"
|
||||||
|
|
||||||
"golang.org/x/crypto/blake2s"
|
"golang.org/x/crypto/blake2s"
|
||||||
@ -94,9 +95,14 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
|
var errInvalidPublicKey = errors.New("invalid public key")
|
||||||
|
|
||||||
|
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) {
|
||||||
apk := (*[NoisePublicKeySize]byte)(&pk)
|
apk := (*[NoisePublicKeySize]byte)(&pk)
|
||||||
ask := (*[NoisePrivateKeySize]byte)(sk)
|
ask := (*[NoisePrivateKeySize]byte)(sk)
|
||||||
curve25519.ScalarMult(&ss, ask, apk)
|
curve25519.ScalarMult(&ss, ask, apk)
|
||||||
return ss
|
if isZero(ss[:]) {
|
||||||
|
return ss, errInvalidPublicKey
|
||||||
|
}
|
||||||
|
return ss, nil
|
||||||
}
|
}
|
||||||
|
@ -175,8 +175,6 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
|
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
|
||||||
errZeroECDHResult := errors.New("ECDH returned all zeros")
|
|
||||||
|
|
||||||
device.staticIdentity.RLock()
|
device.staticIdentity.RLock()
|
||||||
defer device.staticIdentity.RUnlock()
|
defer device.staticIdentity.RUnlock()
|
||||||
|
|
||||||
@ -204,9 +202,9 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
|
|||||||
handshake.mixHash(msg.Ephemeral[:])
|
handshake.mixHash(msg.Ephemeral[:])
|
||||||
|
|
||||||
// encrypt static key
|
// encrypt static key
|
||||||
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
||||||
if isZero(ss[:]) {
|
if err != nil {
|
||||||
return nil, errZeroECDHResult
|
return nil, err
|
||||||
}
|
}
|
||||||
var key [chacha20poly1305.KeySize]byte
|
var key [chacha20poly1305.KeySize]byte
|
||||||
KDF2(
|
KDF2(
|
||||||
@ -221,7 +219,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
|
|||||||
|
|
||||||
// encrypt timestamp
|
// encrypt timestamp
|
||||||
if isZero(handshake.precomputedStaticStatic[:]) {
|
if isZero(handshake.precomputedStaticStatic[:]) {
|
||||||
return nil, errZeroECDHResult
|
return nil, errInvalidPublicKey
|
||||||
}
|
}
|
||||||
KDF2(
|
KDF2(
|
||||||
&handshake.chainKey,
|
&handshake.chainKey,
|
||||||
@ -264,11 +262,10 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
|||||||
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
|
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
|
||||||
|
|
||||||
// decrypt static key
|
// decrypt static key
|
||||||
var err error
|
|
||||||
var peerPK NoisePublicKey
|
var peerPK NoisePublicKey
|
||||||
var key [chacha20poly1305.KeySize]byte
|
var key [chacha20poly1305.KeySize]byte
|
||||||
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
||||||
if isZero(ss[:]) {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
KDF2(&chainKey, &key, chainKey[:], ss[:])
|
KDF2(&chainKey, &key, chainKey[:], ss[:])
|
||||||
@ -384,12 +381,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||||||
handshake.mixHash(msg.Ephemeral[:])
|
handshake.mixHash(msg.Ephemeral[:])
|
||||||
handshake.mixKey(msg.Ephemeral[:])
|
handshake.mixKey(msg.Ephemeral[:])
|
||||||
|
|
||||||
func() {
|
ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
|
||||||
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
handshake.mixKey(ss[:])
|
handshake.mixKey(ss[:])
|
||||||
ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
handshake.mixKey(ss[:])
|
handshake.mixKey(ss[:])
|
||||||
}()
|
|
||||||
|
|
||||||
// add preshared key
|
// add preshared key
|
||||||
|
|
||||||
@ -406,11 +407,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||||||
|
|
||||||
handshake.mixHash(tau[:])
|
handshake.mixHash(tau[:])
|
||||||
|
|
||||||
func() {
|
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
|
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
|
||||||
handshake.mixHash(msg.Empty[:])
|
handshake.mixHash(msg.Empty[:])
|
||||||
}()
|
|
||||||
|
|
||||||
handshake.state = handshakeResponseCreated
|
handshake.state = handshakeResponseCreated
|
||||||
|
|
||||||
@ -455,17 +454,19 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
|||||||
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
|
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
|
||||||
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
|
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
|
||||||
|
|
||||||
func() {
|
ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
|
||||||
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
mixKey(&chainKey, &chainKey, ss[:])
|
mixKey(&chainKey, &chainKey, ss[:])
|
||||||
setZero(ss[:])
|
setZero(ss[:])
|
||||||
}()
|
|
||||||
|
|
||||||
func() {
|
ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
||||||
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
mixKey(&chainKey, &chainKey, ss[:])
|
mixKey(&chainKey, &chainKey, ss[:])
|
||||||
setZero(ss[:])
|
setZero(ss[:])
|
||||||
}()
|
|
||||||
|
|
||||||
// add preshared key (psk)
|
// add preshared key (psk)
|
||||||
|
|
||||||
@ -483,7 +484,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
|||||||
// authenticate transcript
|
// authenticate transcript
|
||||||
|
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
|
_, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -24,10 +24,10 @@ func TestCurveWrappers(t *testing.T) {
|
|||||||
pk1 := sk1.publicKey()
|
pk1 := sk1.publicKey()
|
||||||
pk2 := sk2.publicKey()
|
pk2 := sk2.publicKey()
|
||||||
|
|
||||||
ss1 := sk1.sharedSecret(pk2)
|
ss1, err1 := sk1.sharedSecret(pk2)
|
||||||
ss2 := sk2.sharedSecret(pk1)
|
ss2, err2 := sk2.sharedSecret(pk1)
|
||||||
|
|
||||||
if ss1 != ss2 {
|
if ss1 != ss2 || err1 != nil || err2 != nil {
|
||||||
t.Fatal("Failed to compute shared secet")
|
t.Fatal("Failed to compute shared secet")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -92,7 +92,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
|||||||
// pre-compute DH
|
// pre-compute DH
|
||||||
handshake := &peer.handshake
|
handshake := &peer.handshake
|
||||||
handshake.mutex.Lock()
|
handshake.mutex.Lock()
|
||||||
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
|
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
|
||||||
handshake.remoteStatic = pk
|
handshake.remoteStatic = pk
|
||||||
handshake.mutex.Unlock()
|
handshake.mutex.Unlock()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user