wireguard-go/src/noise_protocol.go

416 lines
9.8 KiB
Go
Raw Normal View History

2017-06-23 13:41:59 +02:00
package main
import (
"errors"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305"
"sync"
)
// Handshake states stored in Handshake.state.
//
// The initiator moves Zeroed -> InitiationCreated -> ResponseConsumed,
// the responder moves Zeroed -> InitiationConsumed -> ResponseCreated.
// NewKeyPair resets the state to HandshakeZeroed once transport keys
// have been derived.
const (
	HandshakeZeroed             = iota // no handshake in progress
	HandshakeInitiationCreated         // set by CreateMessageInitiation
	HandshakeInitiationConsumed        // set by ConsumeMessageInitiation
	HandshakeResponseCreated           // set by CreateMessageResponse
	HandshakeResponseConsumed          // set by ConsumeMessageResponse
)
// Protocol label strings.
const (
// NoiseConstruction is hashed in init() to produce InitalChainKey.
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
// WGIdentifier is appended to InitalChainKey and hashed in init() to
// produce InitalHash.
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
// WGLabelMAC1 is not referenced in this part of the file; presumably the
// label used when deriving the Mac1 key — TODO confirm against the user.
WGLabelMAC1 = "mac1----"
// WGLabelCookie is not referenced in this part of the file; presumably the
// label used when deriving the cookie key — TODO confirm against the user.
WGLabelCookie = "cookie--"
)
// Values carried in the Type field of each message struct.
//
// ConsumeMessageInitiation and ConsumeMessageResponse reject messages whose
// Type does not match the corresponding constant.
const (
	MessageInitiationType     = 1
	MessageResponseType       = 2
	MessageCookieResponseType = 3
	MessageTransportType      = 4
)
2017-06-26 13:14:02 +02:00
/* Type is an 8-bit field, followed by 3 nul bytes,
* by marshalling the messages in little-endian byteorder
* we can treat these as a 32-bit int
*
*/
type MessageInitiation struct {
2017-06-23 13:41:59 +02:00
Type uint32
Sender uint32
Ephemeral NoisePublicKey
Static [NoisePublicKeySize + poly1305.TagSize]byte
Timestamp [TAI64NSize + poly1305.TagSize]byte
Mac1 [blake2s.Size128]byte
Mac2 [blake2s.Size128]byte
}
// MessageResponse is the second handshake message, built by
// CreateMessageResponse and processed by ConsumeMessageResponse.
type MessageResponse struct {
Type uint32 // MessageResponseType
Sender uint32 // responder's local index for this handshake
Reciever uint32 // index taken from the Sender field of the initiation
Ephemeral NoisePublicKey // responder's ephemeral public key
Empty [poly1305.TagSize]byte // AEAD seal of an empty plaintext (tag only)
Mac1 [blake2s.Size128]byte // not populated in this file
Mac2 [blake2s.Size128]byte // not populated in this file
}
// MessageTransport describes a transport message. It is not constructed or
// consumed in this part of the file; the field notes below are inferred from
// the names only — TODO confirm against the code that uses this type.
type MessageTransport struct {
Type uint32 // presumably MessageTransportType
Reciever uint32 // presumably the receiving side's index
Counter uint64 // presumably the per-message nonce counter
Content []byte // presumably the encrypted payload
}
type Handshake struct {
2017-06-24 15:34:17 +02:00
state int
mutex sync.Mutex
hash [blake2s.Size]byte // hash value
chainKey [blake2s.Size]byte // chain key
presharedKey NoiseSymmetricKey // psk
localEphemeral NoisePrivateKey // ephemeral secret key
localIndex uint32 // used to clear hash-table
remoteIndex uint32 // index for sending
remoteStatic NoisePublicKey // long term key
remoteEphemeral NoisePublicKey // ephemeral public key
precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret
lastTimestamp TAI64N
2017-06-23 13:41:59 +02:00
}
var (
InitalChainKey [blake2s.Size]byte
InitalHash [blake2s.Size]byte
2017-06-26 13:14:02 +02:00
ZeroNonce [chacha20poly1305.NonceSize]byte
2017-06-23 13:41:59 +02:00
)
func init() {
InitalChainKey = blake2s.Sum256([]byte(NoiseConstruction))
InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
}
2017-06-26 13:14:02 +02:00
func mixKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
2017-06-24 22:03:52 +02:00
return KDF1(c[:], data)
}
2017-06-26 13:14:02 +02:00
func mixHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
2017-06-24 22:03:52 +02:00
return blake2s.Sum256(append(h[:], data...))
}
2017-06-26 13:14:02 +02:00
func (h *Handshake) mixHash(data []byte) {
h.hash = mixHash(h.hash, data)
2017-06-23 13:41:59 +02:00
}
2017-06-26 13:14:02 +02:00
func (h *Handshake) mixKey(data []byte) {
h.chainKey = mixKey(h.chainKey, data)
2017-06-23 13:41:59 +02:00
}
2017-06-26 13:14:02 +02:00
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
2017-06-24 15:34:17 +02:00
handshake := &peer.handshake
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
2017-06-23 13:41:59 +02:00
2017-06-24 15:34:17 +02:00
// create ephemeral key
2017-06-23 13:41:59 +02:00
var err error
2017-06-24 15:34:17 +02:00
handshake.chainKey = InitalChainKey
2017-06-26 13:14:02 +02:00
handshake.hash = mixHash(InitalHash, handshake.remoteStatic[:])
2017-06-24 15:34:17 +02:00
handshake.localEphemeral, err = newPrivateKey()
2017-06-23 13:41:59 +02:00
if err != nil {
return nil, err
}
device.indices.ClearIndex(handshake.localIndex)
handshake.localIndex, err = device.indices.NewIndex(peer)
2017-06-24 15:34:17 +02:00
// assign index
2017-06-23 13:41:59 +02:00
2017-06-26 13:14:02 +02:00
var msg MessageInitiation
2017-06-24 15:34:17 +02:00
2017-06-26 13:14:02 +02:00
msg.Type = MessageInitiationType
2017-06-24 15:34:17 +02:00
msg.Ephemeral = handshake.localEphemeral.publicKey()
if err != nil {
return nil, err
}
2017-06-24 22:03:52 +02:00
msg.Sender = handshake.localIndex
2017-06-26 13:14:02 +02:00
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
2017-06-23 13:41:59 +02:00
2017-06-26 13:14:02 +02:00
// encrypt static key
2017-06-23 13:41:59 +02:00
func() {
var key [chacha20poly1305.KeySize]byte
2017-06-24 15:34:17 +02:00
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
handshake.chainKey, key = KDF2(handshake.chainKey[:], ss[:])
2017-06-23 13:41:59 +02:00
aead, _ := chacha20poly1305.New(key[:])
2017-06-24 15:34:17 +02:00
aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:])
2017-06-23 13:41:59 +02:00
}()
2017-06-26 13:14:02 +02:00
handshake.mixHash(msg.Static[:])
2017-06-23 13:41:59 +02:00
// encrypt timestamp
timestamp := Timestamp()
func() {
var key [chacha20poly1305.KeySize]byte
2017-06-24 15:34:17 +02:00
handshake.chainKey, key = KDF2(
handshake.chainKey[:],
handshake.precomputedStaticStatic[:],
)
2017-06-23 13:41:59 +02:00
aead, _ := chacha20poly1305.New(key[:])
2017-06-24 15:34:17 +02:00
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
2017-06-23 13:41:59 +02:00
}()
2017-06-24 15:34:17 +02:00
2017-06-26 13:14:02 +02:00
handshake.mixHash(msg.Timestamp[:])
handshake.state = HandshakeInitiationCreated
2017-06-24 15:34:17 +02:00
2017-06-23 13:41:59 +02:00
return &msg, nil
}
2017-06-26 13:14:02 +02:00
func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
if msg.Type != MessageInitiationType {
return nil
2017-06-23 13:41:59 +02:00
}
2017-06-26 13:14:02 +02:00
hash := mixHash(InitalHash, device.publicKey[:])
hash = mixHash(hash, msg.Ephemeral[:])
chainKey := mixKey(InitalChainKey, msg.Ephemeral[:])
2017-06-23 13:41:59 +02:00
2017-06-26 13:14:02 +02:00
// decrypt static key
2017-06-23 13:41:59 +02:00
2017-06-24 15:34:17 +02:00
var err error
var peerPK NoisePublicKey
func() {
var key [chacha20poly1305.KeySize]byte
ss := device.privateKey.sharedSecret(msg.Ephemeral)
chainKey, key = KDF2(chainKey[:], ss[:])
aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
}()
2017-06-23 13:41:59 +02:00
if err != nil {
2017-06-24 15:34:17 +02:00
return nil
2017-06-23 13:41:59 +02:00
}
2017-06-26 13:14:02 +02:00
hash = mixHash(hash, msg.Static[:])
2017-06-23 13:41:59 +02:00
2017-06-24 15:34:17 +02:00
// find peer
peer := device.LookupPeer(peerPK)
if peer == nil {
return nil
}
handshake := &peer.handshake
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
// decrypt timestamp
var timestamp TAI64N
func() {
var key [chacha20poly1305.KeySize]byte
chainKey, key = KDF2(
chainKey[:],
handshake.precomputedStaticStatic[:],
)
aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
}()
if err != nil {
return nil
}
2017-06-26 13:14:02 +02:00
hash = mixHash(hash, msg.Timestamp[:])
2017-06-23 13:41:59 +02:00
2017-06-24 15:34:17 +02:00
// check for replay attack
2017-06-23 13:41:59 +02:00
2017-06-24 15:34:17 +02:00
if !timestamp.After(handshake.lastTimestamp) {
return nil
}
2017-06-26 13:14:02 +02:00
// TODO: check for flood attack
2017-06-23 13:41:59 +02:00
2017-06-24 15:34:17 +02:00
// update handshake state
2017-06-23 13:41:59 +02:00
2017-06-24 15:34:17 +02:00
handshake.hash = hash
handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender
handshake.remoteEphemeral = msg.Ephemeral
2017-06-24 22:03:52 +02:00
handshake.lastTimestamp = timestamp
2017-06-26 13:14:02 +02:00
handshake.state = HandshakeInitiationConsumed
2017-06-24 15:34:17 +02:00
return peer
2017-06-23 13:41:59 +02:00
}
2017-06-24 15:34:17 +02:00
func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) {
handshake := &peer.handshake
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
2017-06-26 13:14:02 +02:00
if handshake.state != HandshakeInitiationConsumed {
return nil, errors.New("handshake initation must be consumed first")
2017-06-24 15:34:17 +02:00
}
// assign index
var err error
device.indices.ClearIndex(handshake.localIndex)
2017-06-24 22:03:52 +02:00
handshake.localIndex, err = device.indices.NewIndex(peer)
2017-06-24 15:34:17 +02:00
if err != nil {
return nil, err
}
2017-06-23 13:41:59 +02:00
2017-06-24 22:03:52 +02:00
var msg MessageResponse
msg.Type = MessageResponseType
msg.Sender = handshake.localIndex
msg.Reciever = handshake.remoteIndex
2017-06-24 15:34:17 +02:00
// create ephemeral key
handshake.localEphemeral, err = newPrivateKey()
if err != nil {
return nil, err
}
msg.Ephemeral = handshake.localEphemeral.publicKey()
2017-06-26 13:14:02 +02:00
handshake.mixHash(msg.Ephemeral[:])
2017-06-24 15:34:17 +02:00
func() {
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
2017-06-26 13:14:02 +02:00
handshake.mixKey(ss[:])
2017-06-24 15:34:17 +02:00
ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
2017-06-26 13:14:02 +02:00
handshake.mixKey(ss[:])
2017-06-24 15:34:17 +02:00
}()
// add preshared key (psk)
var tau [blake2s.Size]byte
var key [chacha20poly1305.KeySize]byte
handshake.chainKey, tau, key = KDF3(handshake.chainKey[:], handshake.presharedKey[:])
2017-06-26 13:14:02 +02:00
handshake.mixHash(tau[:])
2017-06-24 15:34:17 +02:00
func() {
aead, _ := chacha20poly1305.New(key[:])
2017-06-24 22:03:52 +02:00
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
2017-06-26 13:14:02 +02:00
handshake.mixHash(msg.Empty[:])
2017-06-24 15:34:17 +02:00
}()
2017-06-24 22:03:52 +02:00
handshake.state = HandshakeResponseCreated
2017-06-24 15:34:17 +02:00
return &msg, nil
2017-06-23 13:41:59 +02:00
}
2017-06-24 22:03:52 +02:00
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
if msg.Type != MessageResponseType {
2017-06-26 13:14:02 +02:00
return nil
2017-06-24 22:03:52 +02:00
}
// lookup handshake by reciever
lookup := device.indices.Lookup(msg.Reciever)
handshake := lookup.handshake
if handshake == nil {
2017-06-24 22:03:52 +02:00
return nil
}
2017-06-24 22:03:52 +02:00
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
2017-06-26 13:14:02 +02:00
if handshake.state != HandshakeInitiationCreated {
2017-06-24 22:03:52 +02:00
return nil
}
// finish 3-way DH
2017-06-26 13:14:02 +02:00
hash := mixHash(handshake.hash, msg.Ephemeral[:])
2017-06-24 22:03:52 +02:00
chainKey := handshake.chainKey
func() {
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
2017-06-26 13:14:02 +02:00
chainKey = mixKey(chainKey, ss[:])
2017-06-24 22:03:52 +02:00
ss = device.privateKey.sharedSecret(msg.Ephemeral)
2017-06-26 13:14:02 +02:00
chainKey = mixKey(chainKey, ss[:])
2017-06-24 22:03:52 +02:00
}()
// add preshared key (psk)
var tau [blake2s.Size]byte
var key [chacha20poly1305.KeySize]byte
chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
2017-06-26 13:14:02 +02:00
hash = mixHash(hash, tau[:])
2017-06-24 22:03:52 +02:00
// authenticate
aead, _ := chacha20poly1305.New(key[:])
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil {
return nil
}
2017-06-26 13:14:02 +02:00
hash = mixHash(hash, msg.Empty[:])
2017-06-24 22:03:52 +02:00
// update handshake state
handshake.hash = hash
handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender
handshake.state = HandshakeResponseConsumed
return lookup.peer
2017-06-24 22:03:52 +02:00
}
func (peer *Peer) NewKeyPair() *KeyPair {
handshake := &peer.handshake
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
// derive keys
var isInitiator bool
2017-06-24 22:03:52 +02:00
var sendKey [chacha20poly1305.KeySize]byte
var recvKey [chacha20poly1305.KeySize]byte
if handshake.state == HandshakeResponseConsumed {
sendKey, recvKey = KDF2(handshake.chainKey[:], nil)
isInitiator = true
2017-06-24 22:03:52 +02:00
} else if handshake.state == HandshakeResponseCreated {
recvKey, sendKey = KDF2(handshake.chainKey[:], nil)
isInitiator = false
2017-06-24 22:03:52 +02:00
} else {
return nil
}
// create AEAD instances
var keyPair KeyPair
2017-06-24 22:03:52 +02:00
keyPair.send, _ = chacha20poly1305.New(sendKey[:])
keyPair.recv, _ = chacha20poly1305.New(recvKey[:])
keyPair.sendNonce = 0
keyPair.recvNonce = 0
// remap index
peer.device.indices.Insert(handshake.localIndex, IndexTableEntry{
peer: peer,
keyPair: &keyPair,
handshake: nil,
})
handshake.localIndex = 0
// rotate key pairs
func() {
kp := &peer.keyPairs
kp.mutex.Lock()
defer kp.mutex.Unlock()
if isInitiator {
kp.previous = peer.keyPairs.current
kp.current = &keyPair
kp.newKeyPair <- true
} else {
kp.next = &keyPair
}
}()
2017-06-26 13:14:02 +02:00
// zero handshake
handshake.chainKey = [blake2s.Size]byte{}
handshake.localEphemeral = NoisePrivateKey{}
peer.handshake.state = HandshakeZeroed
2017-06-24 22:03:52 +02:00
return &keyPair
}