Completed noise handshake
This commit is contained in:
parent
25190e4336
commit
cf3a5130d3
17
src/index.go
17
src/index.go
@ -6,13 +6,15 @@ import (
|
||||
)
|
||||
|
||||
/* Index=0 is reserved for unset indecies
|
||||
*
|
||||
* TODO: Rethink map[id] -> peer VS map[id] -> handshake and handshake <ref> peer
|
||||
*
|
||||
*/
|
||||
|
||||
type IndexTable struct {
|
||||
mutex sync.RWMutex
|
||||
keypairs map[uint32]*KeyPair
|
||||
handshakes map[uint32]*Handshake
|
||||
handshakes map[uint32]*Peer
|
||||
}
|
||||
|
||||
func randUint32() (uint32, error) {
|
||||
@ -32,10 +34,10 @@ func (table *IndexTable) Init() {
|
||||
table.mutex.Lock()
|
||||
defer table.mutex.Unlock()
|
||||
table.keypairs = make(map[uint32]*KeyPair)
|
||||
table.handshakes = make(map[uint32]*Handshake)
|
||||
table.handshakes = make(map[uint32]*Peer)
|
||||
}
|
||||
|
||||
func (table *IndexTable) NewIndex(handshake *Handshake) (uint32, error) {
|
||||
func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
|
||||
table.mutex.Lock()
|
||||
defer table.mutex.Unlock()
|
||||
for {
|
||||
@ -60,11 +62,10 @@ func (table *IndexTable) NewIndex(handshake *Handshake) (uint32, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
// update the index
|
||||
// clean old index
|
||||
|
||||
delete(table.handshakes, handshake.localIndex)
|
||||
handshake.localIndex = id
|
||||
table.handshakes[id] = handshake
|
||||
delete(table.handshakes, peer.handshake.localIndex)
|
||||
table.handshakes[id] = peer
|
||||
return id, nil
|
||||
}
|
||||
}
|
||||
@ -75,7 +76,7 @@ func (table *IndexTable) LookupKeyPair(id uint32) *KeyPair {
|
||||
return table.keypairs[id]
|
||||
}
|
||||
|
||||
func (table *IndexTable) LookupHandshake(id uint32) *Handshake {
|
||||
func (table *IndexTable) LookupHandshake(id uint32) *Peer {
|
||||
table.mutex.RLock()
|
||||
defer table.mutex.RUnlock()
|
||||
return table.handshakes[id]
|
||||
|
@ -5,8 +5,8 @@ import (
|
||||
)
|
||||
|
||||
type KeyPair struct {
|
||||
recieveKey cipher.AEAD
|
||||
recieveNonce NoiseNonce
|
||||
sendKey cipher.AEAD
|
||||
sendNonce NoiseNonce
|
||||
recv cipher.AEAD
|
||||
recvNonce NoiseNonce
|
||||
send cipher.AEAD
|
||||
sendNonce NoiseNonce
|
||||
}
|
||||
|
@ -45,22 +45,7 @@ func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt
|
||||
return
|
||||
}
|
||||
|
||||
/*
|
||||
*
|
||||
*/
|
||||
|
||||
func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
|
||||
return KDF1(c[:], data)
|
||||
}
|
||||
|
||||
func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
|
||||
return blake2s.Sum256(append(h[:], data...))
|
||||
}
|
||||
|
||||
/* Curve25519 wrappers
|
||||
*
|
||||
* TODO: Rethink this
|
||||
*/
|
||||
/* curve25519 wrappers */
|
||||
|
||||
func newPrivateKey() (sk NoisePrivateKey, err error) {
|
||||
// clamping: https://cr.yp.to/ecdh.html
|
||||
|
@ -9,9 +9,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
HandshakeInitialCreated = iota
|
||||
HandshakeReset = iota
|
||||
HandshakeInitialCreated
|
||||
HandshakeInitialConsumed
|
||||
HandshakeResponseCreated
|
||||
HandshakeResponseConsumed
|
||||
)
|
||||
|
||||
const (
|
||||
@ -71,7 +73,6 @@ type Handshake struct {
|
||||
}
|
||||
|
||||
var (
|
||||
EmptyMessage []byte
|
||||
ZeroNonce [chacha20poly1305.NonceSize]byte
|
||||
InitalChainKey [blake2s.Size]byte
|
||||
InitalHash [blake2s.Size]byte
|
||||
@ -82,6 +83,14 @@ func init() {
|
||||
InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
|
||||
}
|
||||
|
||||
func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
|
||||
return KDF1(c[:], data)
|
||||
}
|
||||
|
||||
func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
|
||||
return blake2s.Sum256(append(h[:], data...))
|
||||
}
|
||||
|
||||
func (h *Handshake) addToHash(data []byte) {
|
||||
h.hash = addToHash(h.hash, data)
|
||||
}
|
||||
@ -90,11 +99,6 @@ func (h *Handshake) addToChainKey(data []byte) {
|
||||
h.chainKey = addToChainKey(h.chainKey, data)
|
||||
}
|
||||
|
||||
func (device *Device) Precompute(peer *Peer) {
|
||||
h := &peer.handshake
|
||||
h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
|
||||
}
|
||||
|
||||
func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
|
||||
handshake := &peer.handshake
|
||||
handshake.mutex.Lock()
|
||||
@ -116,16 +120,17 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
|
||||
|
||||
msg.Type = MessageInitalType
|
||||
msg.Ephemeral = handshake.localEphemeral.publicKey()
|
||||
msg.Sender, err = device.indices.NewIndex(handshake)
|
||||
handshake.localIndex, err = device.indices.NewIndex(peer)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg.Sender = handshake.localIndex
|
||||
handshake.addToChainKey(msg.Ephemeral[:])
|
||||
handshake.addToHash(msg.Ephemeral[:])
|
||||
|
||||
// encrypt long-term "identity key"
|
||||
// encrypt identity key
|
||||
|
||||
func() {
|
||||
var key [chacha20poly1305.KeySize]byte
|
||||
@ -221,6 +226,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
|
||||
handshake.chainKey = chainKey
|
||||
handshake.remoteIndex = msg.Sender
|
||||
handshake.remoteEphemeral = msg.Ephemeral
|
||||
handshake.lastTimestamp = timestamp
|
||||
handshake.state = HandshakeInitialConsumed
|
||||
return peer
|
||||
}
|
||||
@ -237,14 +243,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
||||
// assign index
|
||||
|
||||
var err error
|
||||
var msg MessageResponse
|
||||
msg.Type = MessageResponseType
|
||||
msg.Sender, err = device.indices.NewIndex(handshake)
|
||||
msg.Reciever = handshake.remoteIndex
|
||||
handshake.localIndex, err = device.indices.NewIndex(peer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var msg MessageResponse
|
||||
msg.Type = MessageResponseType
|
||||
msg.Sender = handshake.localIndex
|
||||
msg.Reciever = handshake.remoteIndex
|
||||
|
||||
// create ephemeral key
|
||||
|
||||
handshake.localEphemeral, err = newPrivateKey()
|
||||
@ -252,6 +260,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
||||
return nil, err
|
||||
}
|
||||
msg.Ephemeral = handshake.localEphemeral.publicKey()
|
||||
handshake.addToHash(msg.Ephemeral[:])
|
||||
|
||||
func() {
|
||||
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
|
||||
@ -269,9 +278,97 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
||||
|
||||
func() {
|
||||
aead, _ := chacha20poly1305.New(key[:])
|
||||
aead.Seal(msg.Empty[:0], ZeroNonce[:], EmptyMessage, handshake.hash[:])
|
||||
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
|
||||
handshake.addToHash(msg.Empty[:])
|
||||
}()
|
||||
|
||||
handshake.state = HandshakeResponseCreated
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||
if msg.Type != MessageResponseType {
|
||||
panic(errors.New("bug: invalid message type"))
|
||||
}
|
||||
|
||||
// lookup handshake by reciever
|
||||
|
||||
peer := device.indices.LookupHandshake(msg.Reciever)
|
||||
if peer == nil {
|
||||
return nil
|
||||
}
|
||||
handshake := &peer.handshake
|
||||
handshake.mutex.Lock()
|
||||
defer handshake.mutex.Unlock()
|
||||
if handshake.state != HandshakeInitialCreated {
|
||||
return nil
|
||||
}
|
||||
|
||||
// finish 3-way DH
|
||||
|
||||
hash := addToHash(handshake.hash, msg.Ephemeral[:])
|
||||
chainKey := handshake.chainKey
|
||||
|
||||
func() {
|
||||
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
|
||||
chainKey = addToChainKey(chainKey, ss[:])
|
||||
ss = device.privateKey.sharedSecret(msg.Ephemeral)
|
||||
chainKey = addToChainKey(chainKey, ss[:])
|
||||
}()
|
||||
|
||||
// add preshared key (psk)
|
||||
|
||||
var tau [blake2s.Size]byte
|
||||
var key [chacha20poly1305.KeySize]byte
|
||||
chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
|
||||
hash = addToHash(hash, tau[:])
|
||||
|
||||
// authenticate
|
||||
|
||||
aead, _ := chacha20poly1305.New(key[:])
|
||||
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
hash = addToHash(hash, msg.Empty[:])
|
||||
|
||||
// update handshake state
|
||||
|
||||
handshake.hash = hash
|
||||
handshake.chainKey = chainKey
|
||||
handshake.remoteIndex = msg.Sender
|
||||
handshake.state = HandshakeResponseConsumed
|
||||
|
||||
return peer
|
||||
}
|
||||
|
||||
func (peer *Peer) NewKeyPair() *KeyPair {
|
||||
handshake := &peer.handshake
|
||||
handshake.mutex.Lock()
|
||||
defer handshake.mutex.Unlock()
|
||||
|
||||
// derive keys
|
||||
|
||||
var sendKey [chacha20poly1305.KeySize]byte
|
||||
var recvKey [chacha20poly1305.KeySize]byte
|
||||
|
||||
if handshake.state == HandshakeResponseConsumed {
|
||||
sendKey, recvKey = KDF2(handshake.chainKey[:], nil)
|
||||
} else if handshake.state == HandshakeResponseCreated {
|
||||
recvKey, sendKey = KDF2(handshake.chainKey[:], nil)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
// create AEAD instances
|
||||
|
||||
var keyPair KeyPair
|
||||
keyPair.send, _ = chacha20poly1305.New(sendKey[:])
|
||||
keyPair.recv, _ = chacha20poly1305.New(recvKey[:])
|
||||
keyPair.sendNonce = 0
|
||||
keyPair.recvNonce = 0
|
||||
|
||||
peer.handshake.state = HandshakeReset
|
||||
|
||||
return &keyPair
|
||||
}
|
||||
|
@ -63,7 +63,9 @@ func TestNoiseHandshake(t *testing.T) {
|
||||
|
||||
/* simulate handshake */
|
||||
|
||||
// Initiation message
|
||||
// initiation message
|
||||
|
||||
t.Log("exchange initiation message")
|
||||
|
||||
msg1, err := dev1.CreateMessageInitial(peer2)
|
||||
assertNil(t, err)
|
||||
@ -88,6 +90,68 @@ func TestNoiseHandshake(t *testing.T) {
|
||||
peer2.handshake.hash[:],
|
||||
)
|
||||
|
||||
// Response message
|
||||
// response message
|
||||
|
||||
t.Log("exchange response message")
|
||||
|
||||
msg2, err := dev2.CreateMessageResponse(peer1)
|
||||
assertNil(t, err)
|
||||
|
||||
peer = dev1.ConsumeMessageResponse(msg2)
|
||||
if peer == nil {
|
||||
t.Fatal("handshake failed at response message")
|
||||
}
|
||||
|
||||
assertEqual(
|
||||
t,
|
||||
peer1.handshake.chainKey[:],
|
||||
peer2.handshake.chainKey[:],
|
||||
)
|
||||
|
||||
assertEqual(
|
||||
t,
|
||||
peer1.handshake.hash[:],
|
||||
peer2.handshake.hash[:],
|
||||
)
|
||||
|
||||
// key pairs
|
||||
|
||||
t.Log("deriving keys")
|
||||
|
||||
key1 := peer1.NewKeyPair()
|
||||
key2 := peer2.NewKeyPair()
|
||||
|
||||
if key1 == nil {
|
||||
t.Fatal("failed to dervice key-pair for peer 1")
|
||||
}
|
||||
|
||||
if key2 == nil {
|
||||
t.Fatal("failed to dervice key-pair for peer 2")
|
||||
}
|
||||
|
||||
// encrypting / decryption test
|
||||
|
||||
t.Log("test key pairs")
|
||||
|
||||
func() {
|
||||
testMsg := []byte("wireguard test message 1")
|
||||
var err error
|
||||
var out []byte
|
||||
var nonce [12]byte
|
||||
out = key1.send.Seal(out, nonce[:], testMsg, nil)
|
||||
out, err = key2.recv.Open(out[:0], nonce[:], out, nil)
|
||||
assertNil(t, err)
|
||||
assertEqual(t, out, testMsg)
|
||||
}()
|
||||
|
||||
func() {
|
||||
testMsg := []byte("wireguard test message 2")
|
||||
var err error
|
||||
var out []byte
|
||||
var nonce [12]byte
|
||||
out = key2.send.Seal(out, nonce[:], testMsg, nil)
|
||||
out, err = key1.recv.Open(out[:0], nonce[:], out, nil)
|
||||
assertNil(t, err)
|
||||
assertEqual(t, out, testMsg)
|
||||
}()
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user