Restructuring of noise impl.
This commit is contained in:
		
							parent
							
								
									521e77fd54
								
							
						
					
					
						commit
						25190e4336
					
				@ -99,11 +99,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
			if ok {
 | 
			
		||||
				peer = found
 | 
			
		||||
			} else {
 | 
			
		||||
				newPeer := &Peer{
 | 
			
		||||
					publicKey: pubKey,
 | 
			
		||||
				}
 | 
			
		||||
				peer = newPeer
 | 
			
		||||
				device.peers[pubKey] = newPeer
 | 
			
		||||
				peer = device.NewPeer(pubKey)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		case "replace_peers":
 | 
			
		||||
@ -125,14 +121,14 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
 | 
			
		||||
			case "remove":
 | 
			
		||||
				peer.mutex.Lock()
 | 
			
		||||
				device.RemovePeer(peer.publicKey)
 | 
			
		||||
				// device.RemovePeer(peer.publicKey)
 | 
			
		||||
				peer = nil
 | 
			
		||||
 | 
			
		||||
			case "preshared_key":
 | 
			
		||||
				err := func() error {
 | 
			
		||||
					peer.mutex.Lock()
 | 
			
		||||
					defer peer.mutex.Unlock()
 | 
			
		||||
					return peer.presharedKey.FromHex(value)
 | 
			
		||||
					return peer.handshake.presharedKey.FromHex(value)
 | 
			
		||||
				}()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return &IPCError{Code: ipcErrorInvalidPublicKey}
 | 
			
		||||
@ -144,7 +140,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
					return &IPCError{Code: ipcErrorInvalidIPAddress}
 | 
			
		||||
				}
 | 
			
		||||
				peer.mutex.Lock()
 | 
			
		||||
				peer.endpoint = ip
 | 
			
		||||
				// peer.endpoint = ip FIX
 | 
			
		||||
				peer.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
			case "persistent_keepalive_interval":
 | 
			
		||||
 | 
			
		||||
@ -1,17 +1,13 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* TODO: Locking may be a little broad here
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
type Device struct {
 | 
			
		||||
	mutex        sync.RWMutex
 | 
			
		||||
	peers        map[NoisePublicKey]*Peer
 | 
			
		||||
	sessions     map[uint32]*Handshake
 | 
			
		||||
	indices      IndexTable
 | 
			
		||||
	privateKey   NoisePrivateKey
 | 
			
		||||
	publicKey    NoisePublicKey
 | 
			
		||||
	fwMark       uint32
 | 
			
		||||
@ -19,43 +15,66 @@ type Device struct {
 | 
			
		||||
	routingTable RoutingTable
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (dev *Device) NewID(h *Handshake) uint32 {
 | 
			
		||||
	dev.mutex.Lock()
 | 
			
		||||
	defer dev.mutex.Unlock()
 | 
			
		||||
	for {
 | 
			
		||||
		id := rand.Uint32()
 | 
			
		||||
		_, ok := dev.sessions[id]
 | 
			
		||||
		if !ok {
 | 
			
		||||
			dev.sessions[id] = h
 | 
			
		||||
			return id
 | 
			
		||||
		}
 | 
			
		||||
func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
 | 
			
		||||
	device.mutex.Lock()
 | 
			
		||||
	defer device.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	// update key material
 | 
			
		||||
 | 
			
		||||
	device.privateKey = sk
 | 
			
		||||
	device.publicKey = sk.publicKey()
 | 
			
		||||
 | 
			
		||||
	// do precomputations
 | 
			
		||||
 | 
			
		||||
	for _, peer := range device.peers {
 | 
			
		||||
		h := &peer.handshake
 | 
			
		||||
		h.mutex.Lock()
 | 
			
		||||
		h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
 | 
			
		||||
		h.mutex.Unlock()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (dev *Device) RemovePeer(key NoisePublicKey) {
 | 
			
		||||
	dev.mutex.Lock()
 | 
			
		||||
	defer dev.mutex.Unlock()
 | 
			
		||||
	peer, ok := dev.peers[key]
 | 
			
		||||
func (device *Device) Init() {
 | 
			
		||||
	device.mutex.Lock()
 | 
			
		||||
	defer device.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	device.peers = make(map[NoisePublicKey]*Peer)
 | 
			
		||||
	device.indices.Init()
 | 
			
		||||
	device.listenPort = 0
 | 
			
		||||
	device.routingTable.Reset()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
 | 
			
		||||
	device.mutex.RLock()
 | 
			
		||||
	defer device.mutex.RUnlock()
 | 
			
		||||
	return device.peers[pk]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (device *Device) RemovePeer(key NoisePublicKey) {
 | 
			
		||||
	device.mutex.Lock()
 | 
			
		||||
	defer device.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	peer, ok := device.peers[key]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	peer.mutex.Lock()
 | 
			
		||||
	dev.routingTable.RemovePeer(peer)
 | 
			
		||||
	delete(dev.peers, key)
 | 
			
		||||
	device.routingTable.RemovePeer(peer)
 | 
			
		||||
	delete(device.peers, key)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (dev *Device) RemoveAllAllowedIps(peer *Peer) {
 | 
			
		||||
func (device *Device) RemoveAllAllowedIps(peer *Peer) {
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (dev *Device) RemoveAllPeers() {
 | 
			
		||||
	dev.mutex.Lock()
 | 
			
		||||
	defer dev.mutex.Unlock()
 | 
			
		||||
func (device *Device) RemoveAllPeers() {
 | 
			
		||||
	device.mutex.Lock()
 | 
			
		||||
	defer device.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	for key, peer := range dev.peers {
 | 
			
		||||
	for key, peer := range device.peers {
 | 
			
		||||
		peer.mutex.Lock()
 | 
			
		||||
		dev.routingTable.RemovePeer(peer)
 | 
			
		||||
		delete(dev.peers, key)
 | 
			
		||||
		device.routingTable.RemovePeer(peer)
 | 
			
		||||
		delete(device.peers, key)
 | 
			
		||||
		peer.mutex.Unlock()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										82
									
								
								src/index.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										82
									
								
								src/index.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,82 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Index=0 is reserved for unset indecies
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
type IndexTable struct {
 | 
			
		||||
	mutex      sync.RWMutex
 | 
			
		||||
	keypairs   map[uint32]*KeyPair
 | 
			
		||||
	handshakes map[uint32]*Handshake
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func randUint32() (uint32, error) {
 | 
			
		||||
	var buff [4]byte
 | 
			
		||||
	_, err := rand.Read(buff[:])
 | 
			
		||||
	id := uint32(buff[0])
 | 
			
		||||
	id <<= 8
 | 
			
		||||
	id |= uint32(buff[1])
 | 
			
		||||
	id <<= 8
 | 
			
		||||
	id |= uint32(buff[2])
 | 
			
		||||
	id <<= 8
 | 
			
		||||
	id |= uint32(buff[3])
 | 
			
		||||
	return id, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (table *IndexTable) Init() {
 | 
			
		||||
	table.mutex.Lock()
 | 
			
		||||
	defer table.mutex.Unlock()
 | 
			
		||||
	table.keypairs = make(map[uint32]*KeyPair)
 | 
			
		||||
	table.handshakes = make(map[uint32]*Handshake)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (table *IndexTable) NewIndex(handshake *Handshake) (uint32, error) {
 | 
			
		||||
	table.mutex.Lock()
 | 
			
		||||
	defer table.mutex.Unlock()
 | 
			
		||||
	for {
 | 
			
		||||
		// generate random index
 | 
			
		||||
 | 
			
		||||
		id, err := randUint32()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return id, err
 | 
			
		||||
		}
 | 
			
		||||
		if id == 0 {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// check if index used
 | 
			
		||||
 | 
			
		||||
		_, ok := table.keypairs[id]
 | 
			
		||||
		if ok {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		_, ok = table.handshakes[id]
 | 
			
		||||
		if ok {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// update the index
 | 
			
		||||
 | 
			
		||||
		delete(table.handshakes, handshake.localIndex)
 | 
			
		||||
		handshake.localIndex = id
 | 
			
		||||
		table.handshakes[id] = handshake
 | 
			
		||||
		return id, nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (table *IndexTable) LookupKeyPair(id uint32) *KeyPair {
 | 
			
		||||
	table.mutex.RLock()
 | 
			
		||||
	defer table.mutex.RUnlock()
 | 
			
		||||
	return table.keypairs[id]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (table *IndexTable) LookupHandshake(id uint32) *Handshake {
 | 
			
		||||
	table.mutex.RLock()
 | 
			
		||||
	defer table.mutex.RUnlock()
 | 
			
		||||
	return table.handshakes[id]
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										12
									
								
								src/keypair.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								src/keypair.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,12 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/cipher"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type KeyPair struct {
 | 
			
		||||
	recieveKey   cipher.AEAD
 | 
			
		||||
	recieveNonce NoiseNonce
 | 
			
		||||
	sendKey      cipher.AEAD
 | 
			
		||||
	sendNonce    NoiseNonce
 | 
			
		||||
}
 | 
			
		||||
@ -81,6 +81,6 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
 | 
			
		||||
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
 | 
			
		||||
	apk := (*[NoisePublicKeySize]byte)(&pk)
 | 
			
		||||
	ask := (*[NoisePrivateKeySize]byte)(sk)
 | 
			
		||||
	curve25519.ScalarMult(&ss, apk, ask)
 | 
			
		||||
	curve25519.ScalarMult(&ss, ask, apk)
 | 
			
		||||
	return ss
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -56,18 +56,22 @@ type MessageTransport struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Handshake struct {
 | 
			
		||||
	lock         sync.Mutex
 | 
			
		||||
	state        int
 | 
			
		||||
	chainKey     [blake2s.Size]byte // chain key
 | 
			
		||||
	hash         [blake2s.Size]byte // hash value
 | 
			
		||||
	staticStatic NoisePublicKey     // precomputed DH(S_i, S_r)
 | 
			
		||||
	ephemeral    NoisePrivateKey    // ephemeral secret key
 | 
			
		||||
	remoteIndex  uint32             // index for sending
 | 
			
		||||
	device       *Device
 | 
			
		||||
	peer         *Peer
 | 
			
		||||
	state                   int
 | 
			
		||||
	mutex                   sync.Mutex
 | 
			
		||||
	hash                    [blake2s.Size]byte       // hash value
 | 
			
		||||
	chainKey                [blake2s.Size]byte       // chain key
 | 
			
		||||
	presharedKey            NoiseSymmetricKey        // psk
 | 
			
		||||
	localEphemeral          NoisePrivateKey          // ephemeral secret key
 | 
			
		||||
	localIndex              uint32                   // used to clear hash-table
 | 
			
		||||
	remoteIndex             uint32                   // index for sending
 | 
			
		||||
	remoteStatic            NoisePublicKey           // long term key
 | 
			
		||||
	remoteEphemeral         NoisePublicKey           // ephemeral public key
 | 
			
		||||
	precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret
 | 
			
		||||
	lastTimestamp           TAI64N
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	EmptyMessage   []byte
 | 
			
		||||
	ZeroNonce      [chacha20poly1305.NonceSize]byte
 | 
			
		||||
	InitalChainKey [blake2s.Size]byte
 | 
			
		||||
	InitalHash     [blake2s.Size]byte
 | 
			
		||||
@ -78,102 +82,196 @@ func init() {
 | 
			
		||||
	InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *Handshake) Precompute() {
 | 
			
		||||
	h.staticStatic = h.device.privateKey.sharedSecret(h.peer.publicKey)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *Handshake) ConsumeMessageResponse(msg *MessageResponse) {
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *Handshake) addHash(data []byte) {
 | 
			
		||||
func (h *Handshake) addToHash(data []byte) {
 | 
			
		||||
	h.hash = addToHash(h.hash, data)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *Handshake) addChain(data []byte) {
 | 
			
		||||
func (h *Handshake) addToChainKey(data []byte) {
 | 
			
		||||
	h.chainKey = addToChainKey(h.chainKey, data)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *Handshake) CreateMessageInital() (*MessageInital, error) {
 | 
			
		||||
	h.lock.Lock()
 | 
			
		||||
	defer h.lock.Unlock()
 | 
			
		||||
func (device *Device) Precompute(peer *Peer) {
 | 
			
		||||
	h := &peer.handshake
 | 
			
		||||
	h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
	// reset handshake
 | 
			
		||||
 | 
			
		||||
	var err error
 | 
			
		||||
	h.ephemeral, err = newPrivateKey()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	h.chainKey = InitalChainKey
 | 
			
		||||
	h.hash = addToHash(InitalHash, h.device.publicKey[:])
 | 
			
		||||
func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
 | 
			
		||||
	handshake := &peer.handshake
 | 
			
		||||
	handshake.mutex.Lock()
 | 
			
		||||
	defer handshake.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	// create ephemeral key
 | 
			
		||||
 | 
			
		||||
	var err error
 | 
			
		||||
	handshake.chainKey = InitalChainKey
 | 
			
		||||
	handshake.hash = addToHash(InitalHash, handshake.remoteStatic[:])
 | 
			
		||||
	handshake.localEphemeral, err = newPrivateKey()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// assign index
 | 
			
		||||
 | 
			
		||||
	var msg MessageInital
 | 
			
		||||
 | 
			
		||||
	msg.Type = MessageInitalType
 | 
			
		||||
	msg.Sender = h.device.NewID(h)
 | 
			
		||||
	msg.Ephemeral = h.ephemeral.publicKey()
 | 
			
		||||
	h.chainKey = addToChainKey(h.chainKey, msg.Ephemeral[:])
 | 
			
		||||
	h.hash = addToHash(h.hash, msg.Ephemeral[:])
 | 
			
		||||
	msg.Ephemeral = handshake.localEphemeral.publicKey()
 | 
			
		||||
	msg.Sender, err = device.indices.NewIndex(handshake)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	handshake.addToChainKey(msg.Ephemeral[:])
 | 
			
		||||
	handshake.addToHash(msg.Ephemeral[:])
 | 
			
		||||
 | 
			
		||||
	// encrypt long-term "identity key"
 | 
			
		||||
 | 
			
		||||
	func() {
 | 
			
		||||
		var key [chacha20poly1305.KeySize]byte
 | 
			
		||||
		ss := h.ephemeral.sharedSecret(h.peer.publicKey)
 | 
			
		||||
		h.chainKey, key = KDF2(h.chainKey[:], ss[:])
 | 
			
		||||
		ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
 | 
			
		||||
		handshake.chainKey, key = KDF2(handshake.chainKey[:], ss[:])
 | 
			
		||||
		aead, _ := chacha20poly1305.New(key[:])
 | 
			
		||||
		aead.Seal(msg.Static[:0], ZeroNonce[:], h.device.publicKey[:], nil)
 | 
			
		||||
		aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:])
 | 
			
		||||
	}()
 | 
			
		||||
	h.addHash(msg.Static[:])
 | 
			
		||||
	handshake.addToHash(msg.Static[:])
 | 
			
		||||
 | 
			
		||||
	// encrypt timestamp
 | 
			
		||||
 | 
			
		||||
	timestamp := Timestamp()
 | 
			
		||||
	func() {
 | 
			
		||||
		var key [chacha20poly1305.KeySize]byte
 | 
			
		||||
		h.chainKey, key = KDF2(h.chainKey[:], h.staticStatic[:])
 | 
			
		||||
		handshake.chainKey, key = KDF2(
 | 
			
		||||
			handshake.chainKey[:],
 | 
			
		||||
			handshake.precomputedStaticStatic[:],
 | 
			
		||||
		)
 | 
			
		||||
		aead, _ := chacha20poly1305.New(key[:])
 | 
			
		||||
		aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], nil)
 | 
			
		||||
		aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
 | 
			
		||||
	}()
 | 
			
		||||
	h.addHash(msg.Timestamp[:])
 | 
			
		||||
	h.state = HandshakeInitialCreated
 | 
			
		||||
 | 
			
		||||
	handshake.addToHash(msg.Timestamp[:])
 | 
			
		||||
	handshake.state = HandshakeInitialCreated
 | 
			
		||||
 | 
			
		||||
	return &msg, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *Handshake) ConsumeMessageInitial(msg *MessageInital) error {
 | 
			
		||||
func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
 | 
			
		||||
	if msg.Type != MessageInitalType {
 | 
			
		||||
		panic(errors.New("bug: invalid inital message type"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	hash := addToHash(InitalHash, h.device.publicKey[:])
 | 
			
		||||
	chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:])
 | 
			
		||||
	hash := addToHash(InitalHash, device.publicKey[:])
 | 
			
		||||
	hash = addToHash(hash, msg.Ephemeral[:])
 | 
			
		||||
	chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:])
 | 
			
		||||
 | 
			
		||||
	//
 | 
			
		||||
	// decrypt identity key
 | 
			
		||||
 | 
			
		||||
	ephemeral, err := newPrivateKey()
 | 
			
		||||
	var err error
 | 
			
		||||
	var peerPK NoisePublicKey
 | 
			
		||||
	func() {
 | 
			
		||||
		var key [chacha20poly1305.KeySize]byte
 | 
			
		||||
		ss := device.privateKey.sharedSecret(msg.Ephemeral)
 | 
			
		||||
		chainKey, key = KDF2(chainKey[:], ss[:])
 | 
			
		||||
		aead, _ := chacha20poly1305.New(key[:])
 | 
			
		||||
		_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
 | 
			
		||||
	}()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	hash = addToHash(hash, msg.Static[:])
 | 
			
		||||
 | 
			
		||||
	// find peer
 | 
			
		||||
 | 
			
		||||
	peer := device.LookupPeer(peerPK)
 | 
			
		||||
	if peer == nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	handshake := &peer.handshake
 | 
			
		||||
	handshake.mutex.Lock()
 | 
			
		||||
	defer handshake.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	// decrypt timestamp
 | 
			
		||||
 | 
			
		||||
	var timestamp TAI64N
 | 
			
		||||
	func() {
 | 
			
		||||
		var key [chacha20poly1305.KeySize]byte
 | 
			
		||||
		chainKey, key = KDF2(
 | 
			
		||||
			chainKey[:],
 | 
			
		||||
			handshake.precomputedStaticStatic[:],
 | 
			
		||||
		)
 | 
			
		||||
		aead, _ := chacha20poly1305.New(key[:])
 | 
			
		||||
		_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
 | 
			
		||||
	}()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	hash = addToHash(hash, msg.Timestamp[:])
 | 
			
		||||
 | 
			
		||||
	// check for replay attack
 | 
			
		||||
 | 
			
		||||
	if !timestamp.After(handshake.lastTimestamp) {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// check for flood attack
 | 
			
		||||
 | 
			
		||||
	// update handshake state
 | 
			
		||||
 | 
			
		||||
	h.lock.Lock()
 | 
			
		||||
	defer h.lock.Unlock()
 | 
			
		||||
 | 
			
		||||
	h.hash = hash
 | 
			
		||||
	h.chainKey = chainKey
 | 
			
		||||
	h.remoteIndex = msg.Sender
 | 
			
		||||
	h.ephemeral = ephemeral
 | 
			
		||||
	h.state = HandshakeInitialConsumed
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
 | 
			
		||||
	handshake.hash = hash
 | 
			
		||||
	handshake.chainKey = chainKey
 | 
			
		||||
	handshake.remoteIndex = msg.Sender
 | 
			
		||||
	handshake.remoteEphemeral = msg.Ephemeral
 | 
			
		||||
	handshake.state = HandshakeInitialConsumed
 | 
			
		||||
	return peer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *Handshake) CreateMessageResponse() []byte {
 | 
			
		||||
func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) {
 | 
			
		||||
	handshake := &peer.handshake
 | 
			
		||||
	handshake.mutex.Lock()
 | 
			
		||||
	defer handshake.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
	if handshake.state != HandshakeInitialConsumed {
 | 
			
		||||
		panic(errors.New("bug: handshake initation must be consumed first"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// assign index
 | 
			
		||||
 | 
			
		||||
	var err error
 | 
			
		||||
	var msg MessageResponse
 | 
			
		||||
	msg.Type = MessageResponseType
 | 
			
		||||
	msg.Sender, err = device.indices.NewIndex(handshake)
 | 
			
		||||
	msg.Reciever = handshake.remoteIndex
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// create ephemeral key
 | 
			
		||||
 | 
			
		||||
	handshake.localEphemeral, err = newPrivateKey()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	msg.Ephemeral = handshake.localEphemeral.publicKey()
 | 
			
		||||
 | 
			
		||||
	func() {
 | 
			
		||||
		ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
 | 
			
		||||
		handshake.addToChainKey(ss[:])
 | 
			
		||||
		ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
 | 
			
		||||
		handshake.addToChainKey(ss[:])
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	// add preshared key (psk)
 | 
			
		||||
 | 
			
		||||
	var tau [blake2s.Size]byte
 | 
			
		||||
	var key [chacha20poly1305.KeySize]byte
 | 
			
		||||
	handshake.chainKey, tau, key = KDF3(handshake.chainKey[:], handshake.presharedKey[:])
 | 
			
		||||
	handshake.addToHash(tau[:])
 | 
			
		||||
 | 
			
		||||
	func() {
 | 
			
		||||
		aead, _ := chacha20poly1305.New(key[:])
 | 
			
		||||
		aead.Seal(msg.Empty[:0], ZeroNonce[:], EmptyMessage, handshake.hash[:])
 | 
			
		||||
		handshake.addToHash(msg.Empty[:])
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	return &msg, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1,38 +1,93 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestHandshake(t *testing.T) {
 | 
			
		||||
	var dev1 Device
 | 
			
		||||
	var dev2 Device
 | 
			
		||||
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	dev1.privateKey, err = newPrivateKey()
 | 
			
		||||
func assertNil(t *testing.T, err error) {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
	dev2.privateKey, err = newPrivateKey()
 | 
			
		||||
func assertEqual(t *testing.T, a []byte, b []byte) {
 | 
			
		||||
	if bytes.Compare(a, b) != 0 {
 | 
			
		||||
		t.Fatal(a, "!=", b)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCurveWrappers(t *testing.T) {
 | 
			
		||||
	sk1, err := newPrivateKey()
 | 
			
		||||
	assertNil(t, err)
 | 
			
		||||
 | 
			
		||||
	sk2, err := newPrivateKey()
 | 
			
		||||
	assertNil(t, err)
 | 
			
		||||
 | 
			
		||||
	pk1 := sk1.publicKey()
 | 
			
		||||
	pk2 := sk2.publicKey()
 | 
			
		||||
 | 
			
		||||
	ss1 := sk1.sharedSecret(pk2)
 | 
			
		||||
	ss2 := sk2.sharedSecret(pk1)
 | 
			
		||||
 | 
			
		||||
	if ss1 != ss2 {
 | 
			
		||||
		t.Fatal("Failed to compute shared secet")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newDevice(t *testing.T) *Device {
 | 
			
		||||
	var device Device
 | 
			
		||||
	sk, err := newPrivateKey()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	device.Init()
 | 
			
		||||
	device.SetPrivateKey(sk)
 | 
			
		||||
	return &device
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
	var peer1 Peer
 | 
			
		||||
	var peer2 Peer
 | 
			
		||||
func TestNoiseHandshake(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	peer1.publicKey = dev1.privateKey.publicKey()
 | 
			
		||||
	peer2.publicKey = dev2.privateKey.publicKey()
 | 
			
		||||
	dev1 := newDevice(t)
 | 
			
		||||
	dev2 := newDevice(t)
 | 
			
		||||
 | 
			
		||||
	var handshake1 Handshake
 | 
			
		||||
	var handshake2 Handshake
 | 
			
		||||
	peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
 | 
			
		||||
	peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
 | 
			
		||||
 | 
			
		||||
	handshake1.device = &dev1
 | 
			
		||||
	handshake2.device = &dev2
 | 
			
		||||
	assertEqual(
 | 
			
		||||
		t,
 | 
			
		||||
		peer1.handshake.precomputedStaticStatic[:],
 | 
			
		||||
		peer2.handshake.precomputedStaticStatic[:],
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	handshake1.peer = &peer2
 | 
			
		||||
	handshake2.peer = &peer1
 | 
			
		||||
	/* simulate handshake */
 | 
			
		||||
 | 
			
		||||
	// Initiation message
 | 
			
		||||
 | 
			
		||||
	msg1, err := dev1.CreateMessageInitial(peer2)
 | 
			
		||||
	assertNil(t, err)
 | 
			
		||||
 | 
			
		||||
	packet := make([]byte, 0, 256)
 | 
			
		||||
	writer := bytes.NewBuffer(packet)
 | 
			
		||||
	err = binary.Write(writer, binary.LittleEndian, msg1)
 | 
			
		||||
	peer := dev2.ConsumeMessageInitial(msg1)
 | 
			
		||||
	if peer == nil {
 | 
			
		||||
		t.Fatal("handshake failed at initiation message")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	assertEqual(
 | 
			
		||||
		t,
 | 
			
		||||
		peer1.handshake.chainKey[:],
 | 
			
		||||
		peer2.handshake.chainKey[:],
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	assertEqual(
 | 
			
		||||
		t,
 | 
			
		||||
		peer1.handshake.hash[:],
 | 
			
		||||
		peer2.handshake.hash[:],
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	// Response message
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										40
									
								
								src/peer.go
									
									
									
									
									
								
							
							
						
						
									
										40
									
								
								src/peer.go
									
									
									
									
									
								
							@ -6,17 +6,35 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type KeyPair struct {
 | 
			
		||||
	recieveKey   NoiseSymmetricKey
 | 
			
		||||
	recieveNonce NoiseNonce
 | 
			
		||||
	sendKey      NoiseSymmetricKey
 | 
			
		||||
	sendNonce    NoiseNonce
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Peer struct {
 | 
			
		||||
	mutex                       sync.RWMutex
 | 
			
		||||
	publicKey                   NoisePublicKey
 | 
			
		||||
	presharedKey                NoiseSymmetricKey
 | 
			
		||||
	endpoint                    net.IP
 | 
			
		||||
	persistentKeepaliveInterval time.Duration
 | 
			
		||||
	endpointIP                  net.IP        //
 | 
			
		||||
	endpointPort                uint16        //
 | 
			
		||||
	persistentKeepaliveInterval time.Duration // 0 = disabled
 | 
			
		||||
	handshake                   Handshake
 | 
			
		||||
	device                      *Device
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
 | 
			
		||||
	var peer Peer
 | 
			
		||||
 | 
			
		||||
	// map public key
 | 
			
		||||
 | 
			
		||||
	device.mutex.Lock()
 | 
			
		||||
	device.peers[pk] = &peer
 | 
			
		||||
	device.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	// precompute
 | 
			
		||||
 | 
			
		||||
	peer.mutex.Lock()
 | 
			
		||||
	peer.device = device
 | 
			
		||||
	func(h *Handshake) {
 | 
			
		||||
		h.mutex.Lock()
 | 
			
		||||
		h.remoteStatic = pk
 | 
			
		||||
		h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
 | 
			
		||||
		h.mutex.Unlock()
 | 
			
		||||
	}(&peer.handshake)
 | 
			
		||||
	peer.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	return &peer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -13,6 +13,13 @@ type RoutingTable struct {
 | 
			
		||||
	mutex sync.RWMutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (table *RoutingTable) Reset() {
 | 
			
		||||
	table.mutex.Lock()
 | 
			
		||||
	defer table.mutex.Unlock()
 | 
			
		||||
	table.IPv4 = nil
 | 
			
		||||
	table.IPv6 = nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (table *RoutingTable) RemovePeer(peer *Peer) {
 | 
			
		||||
	table.mutex.Lock()
 | 
			
		||||
	defer table.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,7 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
@ -21,3 +22,7 @@ func Timestamp() TAI64N {
 | 
			
		||||
	binary.BigEndian.PutUint32(tai64n[8:], nano)
 | 
			
		||||
	return tai64n
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t1 *TAI64N) After(t2 TAI64N) bool {
 | 
			
		||||
	return bytes.Compare(t1[:], t2[:]) > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user