diff --git a/device/device.go b/device/device.go index 8e55724..3368a93 100644 --- a/device/device.go +++ b/device/device.go @@ -265,7 +265,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) for _, peer := range device.peers.keyMap { handshake := &peer.handshake - handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) + handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) expiredPeers = append(expiredPeers, peer) } diff --git a/device/noise-helpers.go b/device/noise-helpers.go index 729f8b0..c2f356b 100644 --- a/device/noise-helpers.go +++ b/device/noise-helpers.go @@ -9,6 +9,7 @@ import ( "crypto/hmac" "crypto/rand" "crypto/subtle" + "errors" "hash" "golang.org/x/crypto/blake2s" @@ -94,9 +95,14 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) { return } -func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) { +var errInvalidPublicKey = errors.New("invalid public key") + +func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) { apk := (*[NoisePublicKeySize]byte)(&pk) ask := (*[NoisePrivateKeySize]byte)(sk) curve25519.ScalarMult(&ss, ask, apk) - return ss + if isZero(ss[:]) { + return ss, errInvalidPublicKey + } + return ss, nil } diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 117e960..e8f6145 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -175,8 +175,6 @@ func init() { } func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { - errZeroECDHResult := errors.New("ECDH returned all zeros") - device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -204,9 +202,9 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mixHash(msg.Ephemeral[:]) // encrypt static key - ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) - if isZero(ss[:]) { - return nil, errZeroECDHResult + ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + if err != nil { + return nil, err } var key [chacha20poly1305.KeySize]byte KDF2( @@ -221,7 +219,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e // encrypt timestamp if isZero(handshake.precomputedStaticStatic[:]) { - return nil, errZeroECDHResult + return nil, errInvalidPublicKey } KDF2( &handshake.chainKey, @@ -264,11 +262,10 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) // decrypt static key - var err error var peerPK NoisePublicKey var key [chacha20poly1305.KeySize]byte - ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) - if isZero(ss[:]) { + ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) + if err != nil { return nil } KDF2(&chainKey, &key, chainKey[:], ss[:]) @@ -384,12 +381,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error handshake.mixHash(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:]) - func() { - ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) - handshake.mixKey(ss[:]) - ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) - handshake.mixKey(ss[:]) - }() + ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) + if err != nil { + return nil, err + } + handshake.mixKey(ss[:]) + ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + if err != nil { + return nil, err + } + handshake.mixKey(ss[:]) // add preshared key @@ -406,11 +407,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error handshake.mixHash(tau[:]) - func() { - aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) - handshake.mixHash(msg.Empty[:]) - }() + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) + handshake.mixHash(msg.Empty[:]) handshake.state = handshakeResponseCreated @@ -455,17 +454,19 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) - func() { - ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) - mixKey(&chainKey, &chainKey, ss[:]) - setZero(ss[:]) - }() + ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral) + if err != nil { + return false + } + mixKey(&chainKey, &chainKey, ss[:]) + setZero(ss[:]) - func() { - ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) - mixKey(&chainKey, &chainKey, ss[:]) - setZero(ss[:]) - }() + ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) + if err != nil { + return false + } + mixKey(&chainKey, &chainKey, ss[:]) + setZero(ss[:]) // add preshared key (psk) @@ -483,7 +484,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { // authenticate transcript aead, _ := chacha20poly1305.New(key[:]) - _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) + _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) if err != nil { return false } diff --git a/device/noise_test.go b/device/noise_test.go index 587d1e5..2dd5324 100644 --- a/device/noise_test.go +++ b/device/noise_test.go @@ -24,10 +24,10 @@ func TestCurveWrappers(t *testing.T) { pk1 := sk1.publicKey() pk2 := sk2.publicKey() - ss1 := sk1.sharedSecret(pk2) - ss2 := sk2.sharedSecret(pk1) + ss1, err1 := sk1.sharedSecret(pk2) + ss2, err2 := sk2.sharedSecret(pk1) - if ss1 != ss2 { + if ss1 != ss2 || err1 != nil || err2 != nil { t.Fatal("Failed to compute shared secet") } } diff --git a/device/peer.go b/device/peer.go index 8266dac..0e7b669 100644 --- a/device/peer.go +++ b/device/peer.go @@ -92,7 +92,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { // pre-compute DH handshake := &peer.handshake handshake.mutex.Lock() - handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) + handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk) handshake.remoteStatic = pk handshake.mutex.Unlock()