Completed noise handshake
This commit is contained in:
		
							parent
							
								
									25190e4336
								
							
						
					
					
						commit
						cf3a5130d3
					
				
							
								
								
									
										17
									
								
								src/index.go
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								src/index.go
									
									
									
									
									
								
							@ -6,13 +6,15 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Index=0 is reserved for unset indecies
 | 
			
		||||
 *
 | 
			
		||||
 * TODO: Rethink map[id] -> peer VS map[id] -> handshake and handshake <ref> peer
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
type IndexTable struct {
 | 
			
		||||
	mutex      sync.RWMutex
 | 
			
		||||
	keypairs   map[uint32]*KeyPair
 | 
			
		||||
	handshakes map[uint32]*Handshake
 | 
			
		||||
	handshakes map[uint32]*Peer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func randUint32() (uint32, error) {
 | 
			
		||||
@ -32,10 +34,10 @@ func (table *IndexTable) Init() {
 | 
			
		||||
	table.mutex.Lock()
 | 
			
		||||
	defer table.mutex.Unlock()
 | 
			
		||||
	table.keypairs = make(map[uint32]*KeyPair)
 | 
			
		||||
	table.handshakes = make(map[uint32]*Handshake)
 | 
			
		||||
	table.handshakes = make(map[uint32]*Peer)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (table *IndexTable) NewIndex(handshake *Handshake) (uint32, error) {
 | 
			
		||||
func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
 | 
			
		||||
	table.mutex.Lock()
 | 
			
		||||
	defer table.mutex.Unlock()
 | 
			
		||||
	for {
 | 
			
		||||
@ -60,11 +62,10 @@ func (table *IndexTable) NewIndex(handshake *Handshake) (uint32, error) {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// update the index
 | 
			
		||||
		// clean old index
 | 
			
		||||
 | 
			
		||||
		delete(table.handshakes, handshake.localIndex)
 | 
			
		||||
		handshake.localIndex = id
 | 
			
		||||
		table.handshakes[id] = handshake
 | 
			
		||||
		delete(table.handshakes, peer.handshake.localIndex)
 | 
			
		||||
		table.handshakes[id] = peer
 | 
			
		||||
		return id, nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -75,7 +76,7 @@ func (table *IndexTable) LookupKeyPair(id uint32) *KeyPair {
 | 
			
		||||
	return table.keypairs[id]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (table *IndexTable) LookupHandshake(id uint32) *Handshake {
 | 
			
		||||
func (table *IndexTable) LookupHandshake(id uint32) *Peer {
 | 
			
		||||
	table.mutex.RLock()
 | 
			
		||||
	defer table.mutex.RUnlock()
 | 
			
		||||
	return table.handshakes[id]
 | 
			
		||||
 | 
			
		||||
@ -5,8 +5,8 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type KeyPair struct {
 | 
			
		||||
	recieveKey   cipher.AEAD
 | 
			
		||||
	recieveNonce NoiseNonce
 | 
			
		||||
	sendKey      cipher.AEAD
 | 
			
		||||
	sendNonce    NoiseNonce
 | 
			
		||||
	recv      cipher.AEAD
 | 
			
		||||
	recvNonce NoiseNonce
 | 
			
		||||
	send      cipher.AEAD
 | 
			
		||||
	sendNonce NoiseNonce
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -45,22 +45,7 @@ func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/*
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
 | 
			
		||||
	return KDF1(c[:], data)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
 | 
			
		||||
	return blake2s.Sum256(append(h[:], data...))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Curve25519 wrappers
 | 
			
		||||
 *
 | 
			
		||||
 * TODO: Rethink this
 | 
			
		||||
 */
 | 
			
		||||
/* curve25519 wrappers */
 | 
			
		||||
 | 
			
		||||
func newPrivateKey() (sk NoisePrivateKey, err error) {
 | 
			
		||||
	// clamping: https://cr.yp.to/ecdh.html
 | 
			
		||||
 | 
			
		||||
@ -9,9 +9,11 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	HandshakeInitialCreated = iota
 | 
			
		||||
	HandshakeReset = iota
 | 
			
		||||
	HandshakeInitialCreated
 | 
			
		||||
	HandshakeInitialConsumed
 | 
			
		||||
	HandshakeResponseCreated
 | 
			
		||||
	HandshakeResponseConsumed
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
@ -71,7 +73,6 @@ type Handshake struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	EmptyMessage   []byte
 | 
			
		||||
	ZeroNonce      [chacha20poly1305.NonceSize]byte
 | 
			
		||||
	InitalChainKey [blake2s.Size]byte
 | 
			
		||||
	InitalHash     [blake2s.Size]byte
 | 
			
		||||
@ -82,6 +83,14 @@ func init() {
 | 
			
		||||
	InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
 | 
			
		||||
	return KDF1(c[:], data)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
 | 
			
		||||
	return blake2s.Sum256(append(h[:], data...))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *Handshake) addToHash(data []byte) {
 | 
			
		||||
	h.hash = addToHash(h.hash, data)
 | 
			
		||||
}
 | 
			
		||||
@ -90,11 +99,6 @@ func (h *Handshake) addToChainKey(data []byte) {
 | 
			
		||||
	h.chainKey = addToChainKey(h.chainKey, data)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (device *Device) Precompute(peer *Peer) {
 | 
			
		||||
	h := &peer.handshake
 | 
			
		||||
	h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
 | 
			
		||||
	handshake := &peer.handshake
 | 
			
		||||
	handshake.mutex.Lock()
 | 
			
		||||
@ -116,16 +120,17 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
 | 
			
		||||
 | 
			
		||||
	msg.Type = MessageInitalType
 | 
			
		||||
	msg.Ephemeral = handshake.localEphemeral.publicKey()
 | 
			
		||||
	msg.Sender, err = device.indices.NewIndex(handshake)
 | 
			
		||||
	handshake.localIndex, err = device.indices.NewIndex(peer)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	msg.Sender = handshake.localIndex
 | 
			
		||||
	handshake.addToChainKey(msg.Ephemeral[:])
 | 
			
		||||
	handshake.addToHash(msg.Ephemeral[:])
 | 
			
		||||
 | 
			
		||||
	// encrypt long-term "identity key"
 | 
			
		||||
	// encrypt identity key
 | 
			
		||||
 | 
			
		||||
	func() {
 | 
			
		||||
		var key [chacha20poly1305.KeySize]byte
 | 
			
		||||
@ -221,6 +226,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
 | 
			
		||||
	handshake.chainKey = chainKey
 | 
			
		||||
	handshake.remoteIndex = msg.Sender
 | 
			
		||||
	handshake.remoteEphemeral = msg.Ephemeral
 | 
			
		||||
	handshake.lastTimestamp = timestamp
 | 
			
		||||
	handshake.state = HandshakeInitialConsumed
 | 
			
		||||
	return peer
 | 
			
		||||
}
 | 
			
		||||
@ -237,14 +243,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
 | 
			
		||||
	// assign index
 | 
			
		||||
 | 
			
		||||
	var err error
 | 
			
		||||
	var msg MessageResponse
 | 
			
		||||
	msg.Type = MessageResponseType
 | 
			
		||||
	msg.Sender, err = device.indices.NewIndex(handshake)
 | 
			
		||||
	msg.Reciever = handshake.remoteIndex
 | 
			
		||||
	handshake.localIndex, err = device.indices.NewIndex(peer)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var msg MessageResponse
 | 
			
		||||
	msg.Type = MessageResponseType
 | 
			
		||||
	msg.Sender = handshake.localIndex
 | 
			
		||||
	msg.Reciever = handshake.remoteIndex
 | 
			
		||||
 | 
			
		||||
	// create ephemeral key
 | 
			
		||||
 | 
			
		||||
	handshake.localEphemeral, err = newPrivateKey()
 | 
			
		||||
@ -252,6 +260,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	msg.Ephemeral = handshake.localEphemeral.publicKey()
 | 
			
		||||
	handshake.addToHash(msg.Ephemeral[:])
 | 
			
		||||
 | 
			
		||||
	func() {
 | 
			
		||||
		ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
 | 
			
		||||
@ -269,9 +278,97 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
 | 
			
		||||
 | 
			
		||||
	func() {
 | 
			
		||||
		aead, _ := chacha20poly1305.New(key[:])
 | 
			
		||||
		aead.Seal(msg.Empty[:0], ZeroNonce[:], EmptyMessage, handshake.hash[:])
 | 
			
		||||
		aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
 | 
			
		||||
		handshake.addToHash(msg.Empty[:])
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	handshake.state = HandshakeResponseCreated
 | 
			
		||||
	return &msg, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
 | 
			
		||||
	if msg.Type != MessageResponseType {
 | 
			
		||||
		panic(errors.New("bug: invalid message type"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// lookup handshake by reciever
 | 
			
		||||
 | 
			
		||||
	peer := device.indices.LookupHandshake(msg.Reciever)
 | 
			
		||||
	if peer == nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	handshake := &peer.handshake
 | 
			
		||||
	handshake.mutex.Lock()
 | 
			
		||||
	defer handshake.mutex.Unlock()
 | 
			
		||||
	if handshake.state != HandshakeInitialCreated {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// finish 3-way DH
 | 
			
		||||
 | 
			
		||||
	hash := addToHash(handshake.hash, msg.Ephemeral[:])
 | 
			
		||||
	chainKey := handshake.chainKey
 | 
			
		||||
 | 
			
		||||
	func() {
 | 
			
		||||
		ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
 | 
			
		||||
		chainKey = addToChainKey(chainKey, ss[:])
 | 
			
		||||
		ss = device.privateKey.sharedSecret(msg.Ephemeral)
 | 
			
		||||
		chainKey = addToChainKey(chainKey, ss[:])
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	// add preshared key (psk)
 | 
			
		||||
 | 
			
		||||
	var tau [blake2s.Size]byte
 | 
			
		||||
	var key [chacha20poly1305.KeySize]byte
 | 
			
		||||
	chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
 | 
			
		||||
	hash = addToHash(hash, tau[:])
 | 
			
		||||
 | 
			
		||||
	// authenticate
 | 
			
		||||
 | 
			
		||||
	aead, _ := chacha20poly1305.New(key[:])
 | 
			
		||||
	_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	hash = addToHash(hash, msg.Empty[:])
 | 
			
		||||
 | 
			
		||||
	// update handshake state
 | 
			
		||||
 | 
			
		||||
	handshake.hash = hash
 | 
			
		||||
	handshake.chainKey = chainKey
 | 
			
		||||
	handshake.remoteIndex = msg.Sender
 | 
			
		||||
	handshake.state = HandshakeResponseConsumed
 | 
			
		||||
 | 
			
		||||
	return peer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (peer *Peer) NewKeyPair() *KeyPair {
 | 
			
		||||
	handshake := &peer.handshake
 | 
			
		||||
	handshake.mutex.Lock()
 | 
			
		||||
	defer handshake.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	// derive keys
 | 
			
		||||
 | 
			
		||||
	var sendKey [chacha20poly1305.KeySize]byte
 | 
			
		||||
	var recvKey [chacha20poly1305.KeySize]byte
 | 
			
		||||
 | 
			
		||||
	if handshake.state == HandshakeResponseConsumed {
 | 
			
		||||
		sendKey, recvKey = KDF2(handshake.chainKey[:], nil)
 | 
			
		||||
	} else if handshake.state == HandshakeResponseCreated {
 | 
			
		||||
		recvKey, sendKey = KDF2(handshake.chainKey[:], nil)
 | 
			
		||||
	} else {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// create AEAD instances
 | 
			
		||||
 | 
			
		||||
	var keyPair KeyPair
 | 
			
		||||
	keyPair.send, _ = chacha20poly1305.New(sendKey[:])
 | 
			
		||||
	keyPair.recv, _ = chacha20poly1305.New(recvKey[:])
 | 
			
		||||
	keyPair.sendNonce = 0
 | 
			
		||||
	keyPair.recvNonce = 0
 | 
			
		||||
 | 
			
		||||
	peer.handshake.state = HandshakeReset
 | 
			
		||||
 | 
			
		||||
	return &keyPair
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -63,7 +63,9 @@ func TestNoiseHandshake(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	/* simulate handshake */
 | 
			
		||||
 | 
			
		||||
	// Initiation message
 | 
			
		||||
	// initiation message
 | 
			
		||||
 | 
			
		||||
	t.Log("exchange initiation message")
 | 
			
		||||
 | 
			
		||||
	msg1, err := dev1.CreateMessageInitial(peer2)
 | 
			
		||||
	assertNil(t, err)
 | 
			
		||||
@ -88,6 +90,68 @@ func TestNoiseHandshake(t *testing.T) {
 | 
			
		||||
		peer2.handshake.hash[:],
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	// Response message
 | 
			
		||||
	// response message
 | 
			
		||||
 | 
			
		||||
	t.Log("exchange response message")
 | 
			
		||||
 | 
			
		||||
	msg2, err := dev2.CreateMessageResponse(peer1)
 | 
			
		||||
	assertNil(t, err)
 | 
			
		||||
 | 
			
		||||
	peer = dev1.ConsumeMessageResponse(msg2)
 | 
			
		||||
	if peer == nil {
 | 
			
		||||
		t.Fatal("handshake failed at response message")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	assertEqual(
 | 
			
		||||
		t,
 | 
			
		||||
		peer1.handshake.chainKey[:],
 | 
			
		||||
		peer2.handshake.chainKey[:],
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	assertEqual(
 | 
			
		||||
		t,
 | 
			
		||||
		peer1.handshake.hash[:],
 | 
			
		||||
		peer2.handshake.hash[:],
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	// key pairs
 | 
			
		||||
 | 
			
		||||
	t.Log("deriving keys")
 | 
			
		||||
 | 
			
		||||
	key1 := peer1.NewKeyPair()
 | 
			
		||||
	key2 := peer2.NewKeyPair()
 | 
			
		||||
 | 
			
		||||
	if key1 == nil {
 | 
			
		||||
		t.Fatal("failed to dervice key-pair for peer 1")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if key2 == nil {
 | 
			
		||||
		t.Fatal("failed to dervice key-pair for peer 2")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// encrypting / decryption test
 | 
			
		||||
 | 
			
		||||
	t.Log("test key pairs")
 | 
			
		||||
 | 
			
		||||
	func() {
 | 
			
		||||
		testMsg := []byte("wireguard test message 1")
 | 
			
		||||
		var err error
 | 
			
		||||
		var out []byte
 | 
			
		||||
		var nonce [12]byte
 | 
			
		||||
		out = key1.send.Seal(out, nonce[:], testMsg, nil)
 | 
			
		||||
		out, err = key2.recv.Open(out[:0], nonce[:], out, nil)
 | 
			
		||||
		assertNil(t, err)
 | 
			
		||||
		assertEqual(t, out, testMsg)
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	func() {
 | 
			
		||||
		testMsg := []byte("wireguard test message 2")
 | 
			
		||||
		var err error
 | 
			
		||||
		var out []byte
 | 
			
		||||
		var nonce [12]byte
 | 
			
		||||
		out = key2.send.Seal(out, nonce[:], testMsg, nil)
 | 
			
		||||
		out, err = key1.recv.Open(out[:0], nonce[:], out, nil)
 | 
			
		||||
		assertNil(t, err)
 | 
			
		||||
		assertEqual(t, out, testMsg)
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user