f9dac7099e
Googlers have a habit of graffiting their name in TODO items that then are never addressed, and other people won't go near those because they're marked territory of another animal. I've been gradually cleaning these up as I see them, but this commit just goes all the way and removes the remaining stragglers. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
626 lines
15 KiB
Go
626 lines
15 KiB
Go
/* SPDX-License-Identifier: MIT
|
|
*
|
|
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
|
*/
|
|
|
|
package device
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/blake2s"
|
|
"golang.org/x/crypto/chacha20poly1305"
|
|
"golang.org/x/crypto/poly1305"
|
|
|
|
"golang.zx2c4.com/wireguard/tai64n"
|
|
)
|
|
|
|
// handshakeState enumerates the stages a Handshake moves through.
type handshakeState int

const (
	handshakeZeroed handshakeState = iota
	handshakeInitiationCreated
	handshakeInitiationConsumed
	handshakeResponseCreated
	handshakeResponseConsumed
)

// handshakeStateNames maps every known handshakeState to its printable name.
var handshakeStateNames = [...]string{
	handshakeZeroed:             "handshakeZeroed",
	handshakeInitiationCreated:  "handshakeInitiationCreated",
	handshakeInitiationConsumed: "handshakeInitiationConsumed",
	handshakeResponseCreated:    "handshakeResponseCreated",
	handshakeResponseConsumed:   "handshakeResponseConsumed",
}

// String returns the constant's name for a known state and
// "Handshake(UNKNOWN:<n>)" for any other value.
func (hs handshakeState) String() string {
	if hs >= 0 && int(hs) < len(handshakeStateNames) {
		return handshakeStateNames[hs]
	}
	return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
}
|
|
|
|
// Protocol identification strings.
const (
	// NoiseConstruction is hashed with BLAKE2s by init to produce InitialChainKey.
	NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
	// WGIdentifier is mixed with InitialChainKey by init to produce InitialHash.
	WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
	// WGLabelMAC1 and WGLabelCookie are not referenced in this part of the
	// file; presumably they are labels used when deriving the MAC1 and cookie
	// keys -- confirm against the cookie code.
	WGLabelMAC1   = "mac1----"
	WGLabelCookie = "cookie--"
)
|
|
|
|
// Values of the Type field that opens every wire message. The constants
// are untyped so they can be assigned directly to the uint32 Type fields.
const (
	MessageInitiationType  = iota + 1 // 1
	MessageResponseType               // 2
	MessageCookieReplyType            // 3
	MessageTransportType              // 4
)
|
|
|
|
// Sizes of the wire messages. The transport sizes are expressed in terms of
// poly1305.TagSize, so all values are byte counts.
const (
	MessageInitiationSize      = 148                                           // size of handshake initiation message
	MessageResponseSize        = 92                                            // size of response message
	MessageCookieReplySize     = 64                                            // size of cookie reply message
	MessageTransportHeaderSize = 16                                            // size of data preceding content in transport message
	MessageTransportSize       = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
	MessageKeepaliveSize       = MessageTransportSize                          // size of keepalive
	MessageHandshakeSize       = MessageInitiationSize                         // size of largest handshake related message
)
|
|
|
|
// Byte offsets of the fields that follow Type in a transport message; each
// offset is the previous one plus the width of the preceding field in
// MessageTransport.
const (
	MessageTransportOffsetReceiver = 4                                  // after the 4-byte Type
	MessageTransportOffsetCounter  = MessageTransportOffsetReceiver + 4 // after the 4-byte Receiver
	MessageTransportOffsetContent  = MessageTransportOffsetCounter + 8  // after the 8-byte Counter
)
|
|
|
|
/* Type is an 8-bit field followed by 3 nul bytes.
 * Because the messages are marshalled in little-endian byte order,
 * these four bytes can be treated as a single 32-bit unsigned
 * integer (for now).
 */
|
|
|
|
// MessageInitiation is the first handshake message. It is built by
// CreateMessageInitiation and processed by ConsumeMessageInitiation.
// MAC1 and MAC2 are not filled in by the code in this part of the file;
// presumably the cookie/MAC code sets them -- confirm.
type MessageInitiation struct {
	Type      uint32                                        // MessageInitiationType
	Sender    uint32                                        // index allocated by the initiator for this handshake
	Ephemeral NoisePublicKey                                // initiator's ephemeral public key
	Static    [NoisePublicKeySize + poly1305.TagSize]byte   // AEAD-sealed static public key of the initiator
	Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte // AEAD-sealed tai64n timestamp
	MAC1      [blake2s.Size128]byte
	MAC2      [blake2s.Size128]byte
}
|
|
|
|
// MessageResponse is the second handshake message. It is built by
// CreateMessageResponse and processed by ConsumeMessageResponse.
// MAC1 and MAC2 are not filled in by the code in this part of the file;
// presumably the cookie/MAC code sets them -- confirm.
type MessageResponse struct {
	Type      uint32                  // MessageResponseType
	Sender    uint32                  // index allocated by the responder for this handshake
	Receiver  uint32                  // Sender index copied from the initiation being answered
	Ephemeral NoisePublicKey          // responder's ephemeral public key
	Empty     [poly1305.TagSize]byte  // AEAD seal of an empty plaintext, i.e. only the tag
	MAC1      [blake2s.Size128]byte
	MAC2      [blake2s.Size128]byte
}
|
|
|
|
// MessageTransport describes a transport (data) message. The field widths
// match the MessageTransportOffset* constants: Receiver at byte 4, Counter
// at byte 8 and Content at byte 16. The type is not used in this part of
// the file, so the meaning of Counter and Content is inferred from the
// names -- confirm against the send/receive code.
type MessageTransport struct {
	Type     uint32
	Receiver uint32
	Counter  uint64
	Content  []byte
}
|
|
|
|
// MessageCookieReply describes a cookie reply message. It is not built or
// parsed in this part of the file. Cookie has room for a blake2s.Size128
// value plus a Poly1305 tag and Nonce is sized for the extended-nonce
// (NonceSizeX) AEAD, so presumably the cookie travels sealed with
// XChaCha20-Poly1305 -- confirm against the cookie code.
type MessageCookieReply struct {
	Type     uint32
	Receiver uint32
	Nonce    [chacha20poly1305.NonceSizeX]byte
	Cookie   [blake2s.Size128 + poly1305.TagSize]byte
}
|
|
|
|
// Handshake holds the per-peer handshake state. The handshake functions in
// this file take mutex before reading or writing the other fields; Clear
// does not lock it itself.
type Handshake struct {
	state                     handshakeState           // current stage of the handshake
	mutex                     sync.RWMutex             // locked by the handshake functions around field access
	hash                      [blake2s.Size]byte       // hash value
	chainKey                  [blake2s.Size]byte       // chain key
	presharedKey              NoisePresharedKey        // psk
	localEphemeral            NoisePrivateKey          // ephemeral secret key
	localIndex                uint32                   // used to clear hash-table
	remoteIndex               uint32                   // index for sending
	remoteStatic              NoisePublicKey           // long term key
	remoteEphemeral           NoisePublicKey           // ephemeral public key
	precomputedStaticStatic   [NoisePublicKeySize]byte // precomputed shared secret
	lastTimestamp             tai64n.Timestamp         // newest initiation timestamp accepted by ConsumeMessageInitiation
	lastInitiationConsumption time.Time                // local time of the last accepted initiation, used for flood limiting
	lastSentHandshake         time.Time                // not touched in this part of the file
}
|
|
|
|
// Values shared by every handshake. InitialChainKey and InitialHash are
// filled in by init; ZeroNonce is never assigned in this part of the file,
// so it stays all zero, and it is passed as the nonce to every AEAD
// Seal/Open call here.
var (
	InitialChainKey [blake2s.Size]byte               // BLAKE2s-256 of NoiseConstruction
	InitialHash     [blake2s.Size]byte               // mixHash of InitialChainKey and WGIdentifier
	ZeroNonce       [chacha20poly1305.NonceSize]byte // all-zero AEAD nonce
)
|
|
|
|
// mixKey runs KDF1 with c as the key and data as the input and stores the
// result in dst. Callers in this file pass the same array for dst and c to
// update a chaining key in place.
func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) {
	KDF1(dst, c[:], data)
}
|
|
|
|
func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) {
|
|
hash, _ := blake2s.New256(nil)
|
|
hash.Write(h[:])
|
|
hash.Write(data)
|
|
hash.Sum(dst[:0])
|
|
hash.Reset()
|
|
}
|
|
|
|
func (h *Handshake) Clear() {
|
|
setZero(h.localEphemeral[:])
|
|
setZero(h.remoteEphemeral[:])
|
|
setZero(h.chainKey[:])
|
|
setZero(h.hash[:])
|
|
h.localIndex = 0
|
|
h.state = handshakeZeroed
|
|
}
|
|
|
|
// mixHash replaces h.hash with the hash of the current h.hash followed by
// data. It does no locking of its own.
func (h *Handshake) mixHash(data []byte) {
	mixHash(&h.hash, &h.hash, data)
}
|
|
|
|
// mixKey replaces h.chainKey with the KDF1 output computed from the current
// h.chainKey and data. It does no locking of its own.
func (h *Handshake) mixKey(data []byte) {
	mixKey(&h.chainKey, &h.chainKey, data)
}
|
|
|
|
/* Precompute the values every handshake starts from:
 * InitialChainKey is the BLAKE2s-256 hash of NoiseConstruction, and
 * InitialHash is InitialChainKey mixed with WGIdentifier.
 */
func init() {
	InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction))
	mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier))
}
|
|
|
|
// CreateMessageInitiation builds a handshake initiation addressed to peer.
//
// It holds the device's static identity for reading and the peer's
// handshake mutex for writing for the whole call. The handshake's hash and
// chain key are restarted from InitialHash and InitialChainKey, a fresh
// ephemeral key is generated, and the local static public key and a tai64n
// timestamp are sealed into the message. On success the handshake is left
// in handshakeInitiationCreated and msg.Sender holds a newly allocated
// index.
//
// An error is returned if key generation fails, if the ephemeral-static
// ECDH result or the precomputed static-static secret is all zeros, or if
// no index can be allocated. The handshake fields written before the
// failure are not rolled back.
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
	var errZeroECDHResult = errors.New("ECDH returned all zeros")

	device.staticIdentity.RLock()
	defer device.staticIdentity.RUnlock()

	handshake := &peer.handshake
	handshake.mutex.Lock()
	defer handshake.mutex.Unlock()

	// create ephemeral key
	var err error
	handshake.hash = InitialHash
	handshake.chainKey = InitialChainKey
	handshake.localEphemeral, err = newPrivateKey()
	if err != nil {
		return nil, err
	}

	// the responder's static public key is the first thing mixed into the hash
	handshake.mixHash(handshake.remoteStatic[:])

	msg := MessageInitiation{
		Type:      MessageInitiationType,
		Ephemeral: handshake.localEphemeral.publicKey(),
	}

	handshake.mixKey(msg.Ephemeral[:])
	handshake.mixHash(msg.Ephemeral[:])

	// encrypt static key
	ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
	if isZero(ss[:]) {
		return nil, errZeroECDHResult
	}
	var key [chacha20poly1305.KeySize]byte
	KDF2(
		&handshake.chainKey,
		&key,
		handshake.chainKey[:],
		ss[:],
	)
	// key always has chacha20poly1305.KeySize bytes, so the error from New is ignored
	aead, _ := chacha20poly1305.New(key[:])
	// the current hash is passed as additional data, binding the ciphertext to the transcript
	aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
	handshake.mixHash(msg.Static[:])

	// encrypt timestamp
	if isZero(handshake.precomputedStaticStatic[:]) {
		return nil, errZeroECDHResult
	}
	KDF2(
		&handshake.chainKey,
		&key,
		handshake.chainKey[:],
		handshake.precomputedStaticStatic[:],
	)
	timestamp := tai64n.Now()
	aead, _ = chacha20poly1305.New(key[:])
	aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])

	// assign index
	// drop the index of any earlier handshake attempt before allocating a new one
	device.indexTable.Delete(handshake.localIndex)
	msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
	if err != nil {
		return nil, err
	}
	handshake.localIndex = msg.Sender

	handshake.mixHash(msg.Timestamp[:])
	handshake.state = handshakeInitiationCreated
	return &msg, nil
}
|
|
|
|
// ConsumeMessageInitiation processes a received handshake initiation and
// returns the peer that sent it, or nil if the message is rejected.
//
// The hash and chain key are recomputed in local variables starting from
// InitialHash and InitialChainKey. The sender's static public key is
// decrypted and used to look up the peer; the timestamp is then decrypted
// with a key derived from that peer's precomputed static-static secret.
// Only after every check has passed is the peer's handshake updated and
// moved to handshakeInitiationConsumed.
//
// nil is returned when the message type is wrong, the ECDH result or the
// precomputed static-static secret is all zeros, either AEAD open fails, no
// peer matches the decrypted key, the timestamp is not after
// handshake.lastTimestamp (replay), or the previous initiation was consumed
// no more than HandshakeInitationRate ago (flood).
func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
	var (
		hash     [blake2s.Size]byte
		chainKey [blake2s.Size]byte
	)

	if msg.Type != MessageInitiationType {
		return nil
	}

	device.staticIdentity.RLock()
	defer device.staticIdentity.RUnlock()

	// mirror the initiator: our own static public key is mixed in first
	mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:])
	mixHash(&hash, &hash, msg.Ephemeral[:])
	mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])

	// decrypt static key
	var err error
	var peerPK NoisePublicKey
	var key [chacha20poly1305.KeySize]byte
	ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
	if isZero(ss[:]) {
		return nil
	}
	KDF2(&chainKey, &key, chainKey[:], ss[:])
	aead, _ := chacha20poly1305.New(key[:])
	_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
	if err != nil {
		return nil
	}
	mixHash(&hash, &hash, msg.Static[:])

	// lookup peer

	peer := device.LookupPeer(peerPK)
	if peer == nil {
		return nil
	}

	handshake := &peer.handshake

	// verify identity

	var timestamp tai64n.Timestamp

	// only a read lock is needed while the message is being validated
	handshake.mutex.RLock()

	if isZero(handshake.precomputedStaticStatic[:]) {
		handshake.mutex.RUnlock()
		return nil
	}
	KDF2(
		&chainKey,
		&key,
		chainKey[:],
		handshake.precomputedStaticStatic[:],
	)
	aead, _ = chacha20poly1305.New(key[:])
	_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
	if err != nil {
		handshake.mutex.RUnlock()
		return nil
	}
	mixHash(&hash, &hash, msg.Timestamp[:])

	// protect against replay & flood

	replay := !timestamp.After(handshake.lastTimestamp)
	flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate
	handshake.mutex.RUnlock()
	if replay {
		device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp)
		return nil
	}
	if flood {
		device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer)
		return nil
	}

	// update handshake state

	handshake.mutex.Lock()

	handshake.hash = hash
	handshake.chainKey = chainKey
	handshake.remoteIndex = msg.Sender
	handshake.remoteEphemeral = msg.Ephemeral
	// The read lock was released above, so another goroutine may have
	// advanced these values in the meantime; never move them backwards.
	if timestamp.After(handshake.lastTimestamp) {
		handshake.lastTimestamp = timestamp
	}
	now := time.Now()
	if now.After(handshake.lastInitiationConsumption) {
		handshake.lastInitiationConsumption = now
	}
	handshake.state = handshakeInitiationConsumed

	handshake.mutex.Unlock()

	// wipe the local copies now that they live in the handshake
	setZero(hash[:])
	setZero(chainKey[:])

	return peer
}
|
|
|
|
func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) {
|
|
handshake := &peer.handshake
|
|
handshake.mutex.Lock()
|
|
defer handshake.mutex.Unlock()
|
|
|
|
if handshake.state != handshakeInitiationConsumed {
|
|
return nil, errors.New("handshake initiation must be consumed first")
|
|
}
|
|
|
|
// assign index
|
|
|
|
var err error
|
|
device.indexTable.Delete(handshake.localIndex)
|
|
handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var msg MessageResponse
|
|
msg.Type = MessageResponseType
|
|
msg.Sender = handshake.localIndex
|
|
msg.Receiver = handshake.remoteIndex
|
|
|
|
// create ephemeral key
|
|
|
|
handshake.localEphemeral, err = newPrivateKey()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
msg.Ephemeral = handshake.localEphemeral.publicKey()
|
|
handshake.mixHash(msg.Ephemeral[:])
|
|
handshake.mixKey(msg.Ephemeral[:])
|
|
|
|
func() {
|
|
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
|
|
handshake.mixKey(ss[:])
|
|
ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
|
handshake.mixKey(ss[:])
|
|
}()
|
|
|
|
// add preshared key
|
|
|
|
var tau [blake2s.Size]byte
|
|
var key [chacha20poly1305.KeySize]byte
|
|
|
|
KDF3(
|
|
&handshake.chainKey,
|
|
&tau,
|
|
&key,
|
|
handshake.chainKey[:],
|
|
handshake.presharedKey[:],
|
|
)
|
|
|
|
handshake.mixHash(tau[:])
|
|
|
|
func() {
|
|
aead, _ := chacha20poly1305.New(key[:])
|
|
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
|
|
handshake.mixHash(msg.Empty[:])
|
|
}()
|
|
|
|
handshake.state = handshakeResponseCreated
|
|
|
|
return &msg, nil
|
|
}
|
|
|
|
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
|
if msg.Type != MessageResponseType {
|
|
return nil
|
|
}
|
|
|
|
// lookup handshake by receiver
|
|
|
|
lookup := device.indexTable.Lookup(msg.Receiver)
|
|
handshake := lookup.handshake
|
|
if handshake == nil {
|
|
return nil
|
|
}
|
|
|
|
var (
|
|
hash [blake2s.Size]byte
|
|
chainKey [blake2s.Size]byte
|
|
)
|
|
|
|
ok := func() bool {
|
|
|
|
// lock handshake state
|
|
|
|
handshake.mutex.RLock()
|
|
defer handshake.mutex.RUnlock()
|
|
|
|
if handshake.state != handshakeInitiationCreated {
|
|
return false
|
|
}
|
|
|
|
// lock private key for reading
|
|
|
|
device.staticIdentity.RLock()
|
|
defer device.staticIdentity.RUnlock()
|
|
|
|
// finish 3-way DH
|
|
|
|
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
|
|
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
|
|
|
|
func() {
|
|
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
|
|
mixKey(&chainKey, &chainKey, ss[:])
|
|
setZero(ss[:])
|
|
}()
|
|
|
|
func() {
|
|
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
|
mixKey(&chainKey, &chainKey, ss[:])
|
|
setZero(ss[:])
|
|
}()
|
|
|
|
// add preshared key (psk)
|
|
|
|
var tau [blake2s.Size]byte
|
|
var key [chacha20poly1305.KeySize]byte
|
|
KDF3(
|
|
&chainKey,
|
|
&tau,
|
|
&key,
|
|
chainKey[:],
|
|
handshake.presharedKey[:],
|
|
)
|
|
mixHash(&hash, &hash, tau[:])
|
|
|
|
// authenticate transcript
|
|
|
|
aead, _ := chacha20poly1305.New(key[:])
|
|
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
|
|
if err != nil {
|
|
return false
|
|
}
|
|
mixHash(&hash, &hash, msg.Empty[:])
|
|
return true
|
|
}()
|
|
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
// update handshake state
|
|
|
|
handshake.mutex.Lock()
|
|
|
|
handshake.hash = hash
|
|
handshake.chainKey = chainKey
|
|
handshake.remoteIndex = msg.Sender
|
|
handshake.state = handshakeResponseConsumed
|
|
|
|
handshake.mutex.Unlock()
|
|
|
|
setZero(hash[:])
|
|
setZero(chainKey[:])
|
|
|
|
return lookup.peer
|
|
}
|
|
|
|
/* Derives a new keypair from the current handshake state
|
|
*
|
|
*/
|
|
func (peer *Peer) BeginSymmetricSession() error {
|
|
device := peer.device
|
|
handshake := &peer.handshake
|
|
handshake.mutex.Lock()
|
|
defer handshake.mutex.Unlock()
|
|
|
|
// derive keys
|
|
|
|
var isInitiator bool
|
|
var sendKey [chacha20poly1305.KeySize]byte
|
|
var recvKey [chacha20poly1305.KeySize]byte
|
|
|
|
if handshake.state == handshakeResponseConsumed {
|
|
KDF2(
|
|
&sendKey,
|
|
&recvKey,
|
|
handshake.chainKey[:],
|
|
nil,
|
|
)
|
|
isInitiator = true
|
|
} else if handshake.state == handshakeResponseCreated {
|
|
KDF2(
|
|
&recvKey,
|
|
&sendKey,
|
|
handshake.chainKey[:],
|
|
nil,
|
|
)
|
|
isInitiator = false
|
|
} else {
|
|
return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
|
|
}
|
|
|
|
// zero handshake
|
|
|
|
setZero(handshake.chainKey[:])
|
|
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
|
|
setZero(handshake.localEphemeral[:])
|
|
peer.handshake.state = handshakeZeroed
|
|
|
|
// create AEAD instances
|
|
|
|
keypair := new(Keypair)
|
|
keypair.send, _ = chacha20poly1305.New(sendKey[:])
|
|
keypair.receive, _ = chacha20poly1305.New(recvKey[:])
|
|
|
|
setZero(sendKey[:])
|
|
setZero(recvKey[:])
|
|
|
|
keypair.created = time.Now()
|
|
keypair.replayFilter.Reset()
|
|
keypair.isInitiator = isInitiator
|
|
keypair.localIndex = peer.handshake.localIndex
|
|
keypair.remoteIndex = peer.handshake.remoteIndex
|
|
|
|
// remap index
|
|
|
|
device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair)
|
|
handshake.localIndex = 0
|
|
|
|
// rotate key pairs
|
|
|
|
keypairs := &peer.keypairs
|
|
keypairs.Lock()
|
|
defer keypairs.Unlock()
|
|
|
|
previous := keypairs.previous
|
|
next := keypairs.loadNext()
|
|
current := keypairs.current
|
|
|
|
if isInitiator {
|
|
if next != nil {
|
|
keypairs.storeNext(nil)
|
|
keypairs.previous = next
|
|
device.DeleteKeypair(current)
|
|
} else {
|
|
keypairs.previous = current
|
|
}
|
|
device.DeleteKeypair(previous)
|
|
keypairs.current = keypair
|
|
} else {
|
|
keypairs.storeNext(keypair)
|
|
device.DeleteKeypair(next)
|
|
keypairs.previous = nil
|
|
device.DeleteKeypair(previous)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
|
|
keypairs := &peer.keypairs
|
|
|
|
if keypairs.loadNext() != receivedKeypair {
|
|
return false
|
|
}
|
|
keypairs.Lock()
|
|
defer keypairs.Unlock()
|
|
if keypairs.loadNext() != receivedKeypair {
|
|
return false
|
|
}
|
|
old := keypairs.previous
|
|
keypairs.previous = keypairs.current
|
|
peer.device.DeleteKeypair(old)
|
|
keypairs.current = keypairs.loadNext()
|
|
keypairs.storeNext(nil)
|
|
return true
|
|
}
|