Begin work on outbound packet flow
This commit is contained in:
		
							parent
							
								
									cf3a5130d3
								
							
						
					
					
						commit
						9d806d3853
					
				
							
								
								
									
										39
									
								
								src/cookie.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								src/cookie.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,39 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"golang.org/x/crypto/blake2s"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func CalculateCookie(peer *Peer, msg []byte) {
 | 
			
		||||
	size := len(msg)
 | 
			
		||||
 | 
			
		||||
	if size < blake2s.Size128*2 {
 | 
			
		||||
		panic(errors.New("bug: message too short"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	startMac1 := size - (blake2s.Size128 * 2)
 | 
			
		||||
	startMac2 := size - blake2s.Size128
 | 
			
		||||
 | 
			
		||||
	mac1 := msg[startMac1 : startMac1+blake2s.Size128]
 | 
			
		||||
	mac2 := msg[startMac2 : startMac2+blake2s.Size128]
 | 
			
		||||
 | 
			
		||||
	peer.mutex.RLock()
 | 
			
		||||
	defer peer.mutex.RUnlock()
 | 
			
		||||
 | 
			
		||||
	// set mac1
 | 
			
		||||
 | 
			
		||||
	func() {
 | 
			
		||||
		mac, _ := blake2s.New128(peer.macKey[:])
 | 
			
		||||
		mac.Write(msg[:startMac1])
 | 
			
		||||
		mac.Sum(mac1[:0])
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	// set mac2
 | 
			
		||||
 | 
			
		||||
	if peer.cookie != nil {
 | 
			
		||||
		mac, _ := blake2s.New128(peer.cookie)
 | 
			
		||||
		mac.Write(msg[:startMac2])
 | 
			
		||||
		mac.Sum(mac2[:0])
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -1,10 +1,12 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"log"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Device struct {
 | 
			
		||||
	mtu               int
 | 
			
		||||
	mutex             sync.RWMutex
 | 
			
		||||
	peers             map[NoisePublicKey]*Peer
 | 
			
		||||
	indices           IndexTable
 | 
			
		||||
@ -13,6 +15,8 @@ type Device struct {
 | 
			
		||||
	fwMark            uint32
 | 
			
		||||
	listenPort        uint16
 | 
			
		||||
	routingTable      RoutingTable
 | 
			
		||||
	logger            log.Logger
 | 
			
		||||
	queueWorkOutbound chan *OutboundWorkQueueElement
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
 | 
			
		||||
 | 
			
		||||
@ -2,11 +2,20 @@ package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/cipher"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type KeyPair struct {
 | 
			
		||||
	recv      cipher.AEAD
 | 
			
		||||
	recvNonce NoiseNonce
 | 
			
		||||
	recvNonce uint64
 | 
			
		||||
	send      cipher.AEAD
 | 
			
		||||
	sendNonce NoiseNonce
 | 
			
		||||
	sendNonce uint64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type KeyPairs struct {
 | 
			
		||||
	mutex      sync.RWMutex
 | 
			
		||||
	current    *KeyPair
 | 
			
		||||
	previous   *KeyPair
 | 
			
		||||
	next       *KeyPair
 | 
			
		||||
	newKeyPair chan bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,8 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import "fmt"
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
	fd, err := CreateTUN("test0")
 | 
			
		||||
@ -8,9 +10,9 @@ func main() {
 | 
			
		||||
 | 
			
		||||
	queue := make(chan []byte, 1000)
 | 
			
		||||
 | 
			
		||||
	var device Device
 | 
			
		||||
	// var device Device
 | 
			
		||||
 | 
			
		||||
	go OutgoingRoutingWorker(&device, queue)
 | 
			
		||||
	// go OutgoingRoutingWorker(&device, queue)
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		tmp := make([]byte, 1<<16)
 | 
			
		||||
 | 
			
		||||
@ -9,9 +9,9 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	HandshakeReset = iota
 | 
			
		||||
	HandshakeInitialCreated
 | 
			
		||||
	HandshakeInitialConsumed
 | 
			
		||||
	HandshakeZeroed = iota
 | 
			
		||||
	HandshakeInitiationCreated
 | 
			
		||||
	HandshakeInitiationConsumed
 | 
			
		||||
	HandshakeResponseCreated
 | 
			
		||||
	HandshakeResponseConsumed
 | 
			
		||||
)
 | 
			
		||||
@ -24,13 +24,19 @@ const (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	MessageInitalType         = 1
 | 
			
		||||
	MessageInitiationType     = 1
 | 
			
		||||
	MessageResponseType       = 2
 | 
			
		||||
	MessageCookieResponseType = 3
 | 
			
		||||
	MessageTransportType      = 4
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type MessageInital struct {
 | 
			
		||||
/* Type is an 8-bit field, followed by 3 nul bytes,
 | 
			
		||||
 * by marshalling the messages in little-endian byteorder
 | 
			
		||||
 * we can treat these as a 32-bit int
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
type MessageInitiation struct {
 | 
			
		||||
	Type      uint32
 | 
			
		||||
	Sender    uint32
 | 
			
		||||
	Ephemeral NoisePublicKey
 | 
			
		||||
@ -73,9 +79,9 @@ type Handshake struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	ZeroNonce      [chacha20poly1305.NonceSize]byte
 | 
			
		||||
	InitalChainKey [blake2s.Size]byte
 | 
			
		||||
	InitalHash     [blake2s.Size]byte
 | 
			
		||||
	ZeroNonce      [chacha20poly1305.NonceSize]byte
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
@ -83,23 +89,23 @@ func init() {
 | 
			
		||||
	InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
 | 
			
		||||
func mixKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
 | 
			
		||||
	return KDF1(c[:], data)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
 | 
			
		||||
func mixHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
 | 
			
		||||
	return blake2s.Sum256(append(h[:], data...))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *Handshake) addToHash(data []byte) {
 | 
			
		||||
	h.hash = addToHash(h.hash, data)
 | 
			
		||||
func (h *Handshake) mixHash(data []byte) {
 | 
			
		||||
	h.hash = mixHash(h.hash, data)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *Handshake) addToChainKey(data []byte) {
 | 
			
		||||
	h.chainKey = addToChainKey(h.chainKey, data)
 | 
			
		||||
func (h *Handshake) mixKey(data []byte) {
 | 
			
		||||
	h.chainKey = mixKey(h.chainKey, data)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
 | 
			
		||||
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
 | 
			
		||||
	handshake := &peer.handshake
 | 
			
		||||
	handshake.mutex.Lock()
 | 
			
		||||
	defer handshake.mutex.Unlock()
 | 
			
		||||
@ -108,7 +114,7 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
 | 
			
		||||
 | 
			
		||||
	var err error
 | 
			
		||||
	handshake.chainKey = InitalChainKey
 | 
			
		||||
	handshake.hash = addToHash(InitalHash, handshake.remoteStatic[:])
 | 
			
		||||
	handshake.hash = mixHash(InitalHash, handshake.remoteStatic[:])
 | 
			
		||||
	handshake.localEphemeral, err = newPrivateKey()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
@ -116,9 +122,9 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
 | 
			
		||||
 | 
			
		||||
	// assign index
 | 
			
		||||
 | 
			
		||||
	var msg MessageInital
 | 
			
		||||
	var msg MessageInitiation
 | 
			
		||||
 | 
			
		||||
	msg.Type = MessageInitalType
 | 
			
		||||
	msg.Type = MessageInitiationType
 | 
			
		||||
	msg.Ephemeral = handshake.localEphemeral.publicKey()
 | 
			
		||||
	handshake.localIndex, err = device.indices.NewIndex(peer)
 | 
			
		||||
 | 
			
		||||
@ -127,10 +133,10 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	msg.Sender = handshake.localIndex
 | 
			
		||||
	handshake.addToChainKey(msg.Ephemeral[:])
 | 
			
		||||
	handshake.addToHash(msg.Ephemeral[:])
 | 
			
		||||
	handshake.mixKey(msg.Ephemeral[:])
 | 
			
		||||
	handshake.mixHash(msg.Ephemeral[:])
 | 
			
		||||
 | 
			
		||||
	// encrypt identity key
 | 
			
		||||
	// encrypt static key
 | 
			
		||||
 | 
			
		||||
	func() {
 | 
			
		||||
		var key [chacha20poly1305.KeySize]byte
 | 
			
		||||
@ -139,7 +145,7 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
 | 
			
		||||
		aead, _ := chacha20poly1305.New(key[:])
 | 
			
		||||
		aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:])
 | 
			
		||||
	}()
 | 
			
		||||
	handshake.addToHash(msg.Static[:])
 | 
			
		||||
	handshake.mixHash(msg.Static[:])
 | 
			
		||||
 | 
			
		||||
	// encrypt timestamp
 | 
			
		||||
 | 
			
		||||
@ -154,22 +160,22 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
 | 
			
		||||
		aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	handshake.addToHash(msg.Timestamp[:])
 | 
			
		||||
	handshake.state = HandshakeInitialCreated
 | 
			
		||||
	handshake.mixHash(msg.Timestamp[:])
 | 
			
		||||
	handshake.state = HandshakeInitiationCreated
 | 
			
		||||
 | 
			
		||||
	return &msg, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
 | 
			
		||||
	if msg.Type != MessageInitalType {
 | 
			
		||||
		panic(errors.New("bug: invalid inital message type"))
 | 
			
		||||
func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
 | 
			
		||||
	if msg.Type != MessageInitiationType {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	hash := addToHash(InitalHash, device.publicKey[:])
 | 
			
		||||
	hash = addToHash(hash, msg.Ephemeral[:])
 | 
			
		||||
	chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:])
 | 
			
		||||
	hash := mixHash(InitalHash, device.publicKey[:])
 | 
			
		||||
	hash = mixHash(hash, msg.Ephemeral[:])
 | 
			
		||||
	chainKey := mixKey(InitalChainKey, msg.Ephemeral[:])
 | 
			
		||||
 | 
			
		||||
	// decrypt identity key
 | 
			
		||||
	// decrypt static key
 | 
			
		||||
 | 
			
		||||
	var err error
 | 
			
		||||
	var peerPK NoisePublicKey
 | 
			
		||||
@ -183,7 +189,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	hash = addToHash(hash, msg.Static[:])
 | 
			
		||||
	hash = mixHash(hash, msg.Static[:])
 | 
			
		||||
 | 
			
		||||
	// find peer
 | 
			
		||||
 | 
			
		||||
@ -210,7 +216,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	hash = addToHash(hash, msg.Timestamp[:])
 | 
			
		||||
	hash = mixHash(hash, msg.Timestamp[:])
 | 
			
		||||
 | 
			
		||||
	// check for replay attack
 | 
			
		||||
 | 
			
		||||
@ -218,7 +224,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// check for flood attack
 | 
			
		||||
	// TODO: check for flood attack
 | 
			
		||||
 | 
			
		||||
	// update handshake state
 | 
			
		||||
 | 
			
		||||
@ -227,7 +233,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
 | 
			
		||||
	handshake.remoteIndex = msg.Sender
 | 
			
		||||
	handshake.remoteEphemeral = msg.Ephemeral
 | 
			
		||||
	handshake.lastTimestamp = timestamp
 | 
			
		||||
	handshake.state = HandshakeInitialConsumed
 | 
			
		||||
	handshake.state = HandshakeInitiationConsumed
 | 
			
		||||
	return peer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -236,8 +242,8 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
 | 
			
		||||
	handshake.mutex.Lock()
 | 
			
		||||
	defer handshake.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	if handshake.state != HandshakeInitialConsumed {
 | 
			
		||||
		panic(errors.New("bug: handshake initation must be consumed first"))
 | 
			
		||||
	if handshake.state != HandshakeInitiationConsumed {
 | 
			
		||||
		return nil, errors.New("handshake initation must be consumed first")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// assign index
 | 
			
		||||
@ -260,13 +266,13 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	msg.Ephemeral = handshake.localEphemeral.publicKey()
 | 
			
		||||
	handshake.addToHash(msg.Ephemeral[:])
 | 
			
		||||
	handshake.mixHash(msg.Ephemeral[:])
 | 
			
		||||
 | 
			
		||||
	func() {
 | 
			
		||||
		ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
 | 
			
		||||
		handshake.addToChainKey(ss[:])
 | 
			
		||||
		handshake.mixKey(ss[:])
 | 
			
		||||
		ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
 | 
			
		||||
		handshake.addToChainKey(ss[:])
 | 
			
		||||
		handshake.mixKey(ss[:])
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	// add preshared key (psk)
 | 
			
		||||
@ -274,12 +280,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
 | 
			
		||||
	var tau [blake2s.Size]byte
 | 
			
		||||
	var key [chacha20poly1305.KeySize]byte
 | 
			
		||||
	handshake.chainKey, tau, key = KDF3(handshake.chainKey[:], handshake.presharedKey[:])
 | 
			
		||||
	handshake.addToHash(tau[:])
 | 
			
		||||
	handshake.mixHash(tau[:])
 | 
			
		||||
 | 
			
		||||
	func() {
 | 
			
		||||
		aead, _ := chacha20poly1305.New(key[:])
 | 
			
		||||
		aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
 | 
			
		||||
		handshake.addToHash(msg.Empty[:])
 | 
			
		||||
		handshake.mixHash(msg.Empty[:])
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	handshake.state = HandshakeResponseCreated
 | 
			
		||||
@ -288,7 +294,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
 | 
			
		||||
 | 
			
		||||
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
 | 
			
		||||
	if msg.Type != MessageResponseType {
 | 
			
		||||
		panic(errors.New("bug: invalid message type"))
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// lookup handshake by reciever
 | 
			
		||||
@ -300,20 +306,20 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
 | 
			
		||||
	handshake := &peer.handshake
 | 
			
		||||
	handshake.mutex.Lock()
 | 
			
		||||
	defer handshake.mutex.Unlock()
 | 
			
		||||
	if handshake.state != HandshakeInitialCreated {
 | 
			
		||||
	if handshake.state != HandshakeInitiationCreated {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// finish 3-way DH
 | 
			
		||||
 | 
			
		||||
	hash := addToHash(handshake.hash, msg.Ephemeral[:])
 | 
			
		||||
	hash := mixHash(handshake.hash, msg.Ephemeral[:])
 | 
			
		||||
	chainKey := handshake.chainKey
 | 
			
		||||
 | 
			
		||||
	func() {
 | 
			
		||||
		ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
 | 
			
		||||
		chainKey = addToChainKey(chainKey, ss[:])
 | 
			
		||||
		chainKey = mixKey(chainKey, ss[:])
 | 
			
		||||
		ss = device.privateKey.sharedSecret(msg.Ephemeral)
 | 
			
		||||
		chainKey = addToChainKey(chainKey, ss[:])
 | 
			
		||||
		chainKey = mixKey(chainKey, ss[:])
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	// add preshared key (psk)
 | 
			
		||||
@ -321,7 +327,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
 | 
			
		||||
	var tau [blake2s.Size]byte
 | 
			
		||||
	var key [chacha20poly1305.KeySize]byte
 | 
			
		||||
	chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
 | 
			
		||||
	hash = addToHash(hash, tau[:])
 | 
			
		||||
	hash = mixHash(hash, tau[:])
 | 
			
		||||
 | 
			
		||||
	// authenticate
 | 
			
		||||
 | 
			
		||||
@ -330,7 +336,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	hash = addToHash(hash, msg.Empty[:])
 | 
			
		||||
	hash = mixHash(hash, msg.Empty[:])
 | 
			
		||||
 | 
			
		||||
	// update handshake state
 | 
			
		||||
 | 
			
		||||
@ -368,7 +374,11 @@ func (peer *Peer) NewKeyPair() *KeyPair {
 | 
			
		||||
	keyPair.sendNonce = 0
 | 
			
		||||
	keyPair.recvNonce = 0
 | 
			
		||||
 | 
			
		||||
	peer.handshake.state = HandshakeReset
 | 
			
		||||
	// zero handshake
 | 
			
		||||
 | 
			
		||||
	handshake.chainKey = [blake2s.Size]byte{}
 | 
			
		||||
	handshake.localEphemeral = NoisePrivateKey{}
 | 
			
		||||
	peer.handshake.state = HandshakeZeroed
 | 
			
		||||
 | 
			
		||||
	return &keyPair
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -67,13 +67,13 @@ func TestNoiseHandshake(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	t.Log("exchange initiation message")
 | 
			
		||||
 | 
			
		||||
	msg1, err := dev1.CreateMessageInitial(peer2)
 | 
			
		||||
	msg1, err := dev1.CreateMessageInitiation(peer2)
 | 
			
		||||
	assertNil(t, err)
 | 
			
		||||
 | 
			
		||||
	packet := make([]byte, 0, 256)
 | 
			
		||||
	writer := bytes.NewBuffer(packet)
 | 
			
		||||
	err = binary.Write(writer, binary.LittleEndian, msg1)
 | 
			
		||||
	peer := dev2.ConsumeMessageInitial(msg1)
 | 
			
		||||
	peer := dev2.ConsumeMessageInitiation(msg1)
 | 
			
		||||
	if peer == nil {
 | 
			
		||||
		t.Fatal("handshake failed at initiation message")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										51
									
								
								src/peer.go
									
									
									
									
									
								
							
							
						
						
									
										51
									
								
								src/peer.go
									
									
									
									
									
								
							@ -1,39 +1,64 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"golang.org/x/crypto/blake2s"
 | 
			
		||||
	"net"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	OutboundQueueSize = 64
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Peer struct {
 | 
			
		||||
	mutex                       sync.RWMutex
 | 
			
		||||
	endpointIP                  net.IP        //
 | 
			
		||||
	endpointPort                uint16        //
 | 
			
		||||
	persistentKeepaliveInterval time.Duration // 0 = disabled
 | 
			
		||||
	keyPairs                    KeyPairs
 | 
			
		||||
	handshake                   Handshake
 | 
			
		||||
	device                      *Device
 | 
			
		||||
	macKey                      [blake2s.Size]byte // Hash(Label-Mac1 || publicKey)
 | 
			
		||||
	cookie                      []byte             // cookie
 | 
			
		||||
	cookieExpire                time.Time
 | 
			
		||||
	queueInbound                chan []byte
 | 
			
		||||
	queueOutbound               chan *OutboundWorkQueueElement
 | 
			
		||||
	queueOutboundRouting        chan []byte
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
 | 
			
		||||
	var peer Peer
 | 
			
		||||
 | 
			
		||||
	// map public key
 | 
			
		||||
 | 
			
		||||
	device.mutex.Lock()
 | 
			
		||||
	device.peers[pk] = &peer
 | 
			
		||||
	device.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	// precompute
 | 
			
		||||
	// create peer
 | 
			
		||||
 | 
			
		||||
	peer.mutex.Lock()
 | 
			
		||||
	peer.device = device
 | 
			
		||||
	func(h *Handshake) {
 | 
			
		||||
		h.mutex.Lock()
 | 
			
		||||
		h.remoteStatic = pk
 | 
			
		||||
		h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
 | 
			
		||||
		h.mutex.Unlock()
 | 
			
		||||
	}(&peer.handshake)
 | 
			
		||||
	peer.queueOutbound = make(chan *OutboundWorkQueueElement, OutboundQueueSize)
 | 
			
		||||
 | 
			
		||||
	// map public key
 | 
			
		||||
 | 
			
		||||
	device.mutex.Lock()
 | 
			
		||||
	_, ok := device.peers[pk]
 | 
			
		||||
	if ok {
 | 
			
		||||
		panic(errors.New("bug: adding existing peer"))
 | 
			
		||||
	}
 | 
			
		||||
	device.peers[pk] = &peer
 | 
			
		||||
	device.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	// precompute DH
 | 
			
		||||
 | 
			
		||||
	handshake := &peer.handshake
 | 
			
		||||
	handshake.mutex.Lock()
 | 
			
		||||
	handshake.remoteStatic = pk
 | 
			
		||||
	handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
 | 
			
		||||
 | 
			
		||||
	// compute mac key
 | 
			
		||||
 | 
			
		||||
	peer.macKey = blake2s.Sum256(append([]byte(WGLabelMAC1[:]), handshake.remoteStatic[:]...))
 | 
			
		||||
 | 
			
		||||
	handshake.mutex.Unlock()
 | 
			
		||||
	peer.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	return &peer
 | 
			
		||||
 | 
			
		||||
@ -2,7 +2,6 @@ package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
@ -52,25 +51,3 @@ func (table *RoutingTable) LookupIPv6(address []byte) *Peer {
 | 
			
		||||
	defer table.mutex.RUnlock()
 | 
			
		||||
	return table.IPv6.Lookup(address)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func OutgoingRoutingWorker(device *Device, queue chan []byte) {
 | 
			
		||||
	for {
 | 
			
		||||
		packet := <-queue
 | 
			
		||||
		switch packet[0] >> 4 {
 | 
			
		||||
 | 
			
		||||
		case IPv4version:
 | 
			
		||||
			dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
 | 
			
		||||
			peer := device.routingTable.LookupIPv4(dst)
 | 
			
		||||
			fmt.Println("IPv4", peer)
 | 
			
		||||
 | 
			
		||||
		case IPv6version:
 | 
			
		||||
			dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
 | 
			
		||||
			peer := device.routingTable.LookupIPv6(dst)
 | 
			
		||||
			fmt.Println("IPv6", peer)
 | 
			
		||||
 | 
			
		||||
		default:
 | 
			
		||||
			// todo: log
 | 
			
		||||
			fmt.Println("Unknown IP version")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										154
									
								
								src/send.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										154
									
								
								src/send.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,154 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Handles outbound flow
 | 
			
		||||
 *
 | 
			
		||||
 * 1. TUN queue
 | 
			
		||||
 * 2. Routing
 | 
			
		||||
 * 3. Per peer queuing
 | 
			
		||||
 * 4. (work queuing)
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
type OutboundWorkQueueElement struct {
 | 
			
		||||
	wg      sync.WaitGroup
 | 
			
		||||
	packet  []byte
 | 
			
		||||
	nonce   uint64
 | 
			
		||||
	keyPair *KeyPair
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (device *Device) SendPacket(packet []byte) {
 | 
			
		||||
 | 
			
		||||
	// lookup peer
 | 
			
		||||
 | 
			
		||||
	var peer *Peer
 | 
			
		||||
	switch packet[0] >> 4 {
 | 
			
		||||
	case IPv4version:
 | 
			
		||||
		dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
 | 
			
		||||
		peer = device.routingTable.LookupIPv4(dst)
 | 
			
		||||
 | 
			
		||||
	case IPv6version:
 | 
			
		||||
		dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
 | 
			
		||||
		peer = device.routingTable.LookupIPv6(dst)
 | 
			
		||||
 | 
			
		||||
	default:
 | 
			
		||||
		device.logger.Println("unknown IP version")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if peer == nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// insert into peer queue
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case peer.queueOutboundRouting <- packet:
 | 
			
		||||
		default:
 | 
			
		||||
			select {
 | 
			
		||||
			case <-peer.queueOutboundRouting:
 | 
			
		||||
			default:
 | 
			
		||||
			}
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		break
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Go routine
 | 
			
		||||
 *
 | 
			
		||||
 *
 | 
			
		||||
 * 1. waits for handshake.
 | 
			
		||||
 * 2. assigns key pair & nonce
 | 
			
		||||
 * 3. inserts to working queue
 | 
			
		||||
 *
 | 
			
		||||
 * TODO: avoid dynamic allocation of work queue elements
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) ConsumeOutboundPackets() {
 | 
			
		||||
	for {
 | 
			
		||||
		// wait for key pair
 | 
			
		||||
		keyPair := func() *KeyPair {
 | 
			
		||||
			peer.keyPairs.mutex.RLock()
 | 
			
		||||
			defer peer.keyPairs.mutex.RUnlock()
 | 
			
		||||
			return peer.keyPairs.current
 | 
			
		||||
		}()
 | 
			
		||||
		if keyPair == nil {
 | 
			
		||||
			if len(peer.queueOutboundRouting) > 0 {
 | 
			
		||||
				// TODO: start handshake
 | 
			
		||||
				<-peer.keyPairs.newKeyPair
 | 
			
		||||
			}
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// assign packets key pair
 | 
			
		||||
		for {
 | 
			
		||||
			select {
 | 
			
		||||
			case <-peer.keyPairs.newKeyPair:
 | 
			
		||||
			default:
 | 
			
		||||
			case <-peer.keyPairs.newKeyPair:
 | 
			
		||||
			case packet := <-peer.queueOutboundRouting:
 | 
			
		||||
 | 
			
		||||
				// create new work element
 | 
			
		||||
 | 
			
		||||
				work := new(OutboundWorkQueueElement)
 | 
			
		||||
				work.wg.Add(1)
 | 
			
		||||
				work.keyPair = keyPair
 | 
			
		||||
				work.packet = packet
 | 
			
		||||
				work.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1
 | 
			
		||||
 | 
			
		||||
				peer.queueOutbound <- work
 | 
			
		||||
 | 
			
		||||
				// drop packets until there is room
 | 
			
		||||
 | 
			
		||||
				for {
 | 
			
		||||
					select {
 | 
			
		||||
					case peer.device.queueWorkOutbound <- work:
 | 
			
		||||
						break
 | 
			
		||||
					default:
 | 
			
		||||
						drop := <-peer.device.queueWorkOutbound
 | 
			
		||||
						drop.packet = nil
 | 
			
		||||
						drop.wg.Done()
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (peer *Peer) RoutineSequential() {
 | 
			
		||||
	for work := range peer.queueOutbound {
 | 
			
		||||
		work.wg.Wait()
 | 
			
		||||
		if work.packet == nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (device *Device) EncryptionWorker() {
 | 
			
		||||
	for {
 | 
			
		||||
		work := <-device.queueWorkOutbound
 | 
			
		||||
 | 
			
		||||
		func() {
 | 
			
		||||
			defer work.wg.Done()
 | 
			
		||||
 | 
			
		||||
			// pad packet
 | 
			
		||||
			padding := device.mtu - len(work.packet)
 | 
			
		||||
			if padding < 0 {
 | 
			
		||||
				work.packet = nil
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			for n := 0; n < padding; n += 1 {
 | 
			
		||||
				work.packet = append(work.packet, 0) // TODO: gotta be a faster way
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			//
 | 
			
		||||
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user