Odds and ends
This commit is contained in:
		
							parent
							
								
									e94185681f
								
							
						
					
					
						commit
						2326d6a4d7
					
				
							
								
								
									
										3
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								Makefile
									
									
									
									
									
								
							@ -6,7 +6,4 @@ wireguard-go: $(wildcard *.go)
 | 
			
		||||
clean:
 | 
			
		||||
	rm -f wireguard-go
 | 
			
		||||
 | 
			
		||||
cloc:
 | 
			
		||||
	cloc $(filter-out xchacha20.go $(wildcard *_test.go), $(wildcard *.go))
 | 
			
		||||
 | 
			
		||||
.PHONY: clean cloc
 | 
			
		||||
 | 
			
		||||
@ -8,21 +8,12 @@ package main
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"net"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Binary trie
 | 
			
		||||
 *
 | 
			
		||||
 * The net.IPs used here are not formatted the
 | 
			
		||||
 * same way as those created by the "net" functions.
 | 
			
		||||
 * Here the IPs are slices of either 4 or 16 byte (not always 16)
 | 
			
		||||
 *
 | 
			
		||||
 * Synchronization done separately
 | 
			
		||||
 * See: routing.go
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
type Trie struct {
 | 
			
		||||
type trieEntry struct {
 | 
			
		||||
	cidr  uint
 | 
			
		||||
	child [2]*Trie
 | 
			
		||||
	child [2]*trieEntry
 | 
			
		||||
	bits  []byte
 | 
			
		||||
	peer  *Peer
 | 
			
		||||
 | 
			
		||||
@ -90,15 +81,15 @@ func commonBits(ip1 []byte, ip2 []byte) uint {
 | 
			
		||||
	return i * 8
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (node *Trie) RemovePeer(p *Peer) *Trie {
 | 
			
		||||
func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
 | 
			
		||||
	if node == nil {
 | 
			
		||||
		return node
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// walk recursively
 | 
			
		||||
 | 
			
		||||
	node.child[0] = node.child[0].RemovePeer(p)
 | 
			
		||||
	node.child[1] = node.child[1].RemovePeer(p)
 | 
			
		||||
	node.child[0] = node.child[0].removeByPeer(p)
 | 
			
		||||
	node.child[1] = node.child[1].removeByPeer(p)
 | 
			
		||||
 | 
			
		||||
	if node.peer != p {
 | 
			
		||||
		return node
 | 
			
		||||
@ -113,16 +104,16 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
 | 
			
		||||
	return node.child[0]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (node *Trie) choose(ip net.IP) byte {
 | 
			
		||||
func (node *trieEntry) choose(ip net.IP) byte {
 | 
			
		||||
	return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
 | 
			
		||||
func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
 | 
			
		||||
 | 
			
		||||
	// at leaf
 | 
			
		||||
 | 
			
		||||
	if node == nil {
 | 
			
		||||
		return &Trie{
 | 
			
		||||
		return &trieEntry{
 | 
			
		||||
			bits:         ip,
 | 
			
		||||
			peer:         peer,
 | 
			
		||||
			cidr:         cidr,
 | 
			
		||||
@ -140,13 +131,13 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
 | 
			
		||||
			return node
 | 
			
		||||
		}
 | 
			
		||||
		bit := node.choose(ip)
 | 
			
		||||
		node.child[bit] = node.child[bit].Insert(ip, cidr, peer)
 | 
			
		||||
		node.child[bit] = node.child[bit].insert(ip, cidr, peer)
 | 
			
		||||
		return node
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// split node
 | 
			
		||||
 | 
			
		||||
	newNode := &Trie{
 | 
			
		||||
	newNode := &trieEntry{
 | 
			
		||||
		bits:         ip,
 | 
			
		||||
		peer:         peer,
 | 
			
		||||
		cidr:         cidr,
 | 
			
		||||
@ -166,7 +157,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
 | 
			
		||||
 | 
			
		||||
	// create new parent for node & newNode
 | 
			
		||||
 | 
			
		||||
	parent := &Trie{
 | 
			
		||||
	parent := &trieEntry{
 | 
			
		||||
		bits:         ip,
 | 
			
		||||
		peer:         nil,
 | 
			
		||||
		cidr:         cidr,
 | 
			
		||||
@ -181,7 +172,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
 | 
			
		||||
	return parent
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (node *Trie) Lookup(ip net.IP) *Peer {
 | 
			
		||||
func (node *trieEntry) lookup(ip net.IP) *Peer {
 | 
			
		||||
	var found *Peer
 | 
			
		||||
	size := uint(len(ip))
 | 
			
		||||
	for node != nil && commonBits(node.bits, ip) >= node.cidr {
 | 
			
		||||
@ -197,16 +188,7 @@ func (node *Trie) Lookup(ip net.IP) *Peer {
 | 
			
		||||
	return found
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (node *Trie) Count() uint {
 | 
			
		||||
	if node == nil {
 | 
			
		||||
		return 0
 | 
			
		||||
	}
 | 
			
		||||
	l := node.child[0].Count()
 | 
			
		||||
	r := node.child[1].Count()
 | 
			
		||||
	return l + r
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet {
 | 
			
		||||
func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet {
 | 
			
		||||
	if node == nil {
 | 
			
		||||
		return results
 | 
			
		||||
	}
 | 
			
		||||
@ -223,11 +205,69 @@ func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet {
 | 
			
		||||
		} else if len(node.bits) == net.IPv6len {
 | 
			
		||||
			mask.IP = node.bits
 | 
			
		||||
		} else {
 | 
			
		||||
			panic(errors.New("bug: unexpected address length"))
 | 
			
		||||
			panic(errors.New("unexpected address length"))
 | 
			
		||||
		}
 | 
			
		||||
		results = append(results, mask)
 | 
			
		||||
	}
 | 
			
		||||
	results = node.child[0].AllowedIPs(p, results)
 | 
			
		||||
	results = node.child[1].AllowedIPs(p, results)
 | 
			
		||||
	results = node.child[0].entriesForPeer(p, results)
 | 
			
		||||
	results = node.child[1].entriesForPeer(p, results)
 | 
			
		||||
	return results
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AllowedIPs struct {
 | 
			
		||||
	IPv4  *trieEntry
 | 
			
		||||
	IPv6  *trieEntry
 | 
			
		||||
	mutex sync.RWMutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet {
 | 
			
		||||
	table.mutex.RLock()
 | 
			
		||||
	defer table.mutex.RUnlock()
 | 
			
		||||
 | 
			
		||||
	allowed := make([]net.IPNet, 0, 10)
 | 
			
		||||
	allowed = table.IPv4.entriesForPeer(peer, allowed)
 | 
			
		||||
	allowed = table.IPv6.entriesForPeer(peer, allowed)
 | 
			
		||||
	return allowed
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (table *AllowedIPs) Reset() {
 | 
			
		||||
	table.mutex.Lock()
 | 
			
		||||
	defer table.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	table.IPv4 = nil
 | 
			
		||||
	table.IPv6 = nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
 | 
			
		||||
	table.mutex.Lock()
 | 
			
		||||
	defer table.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	table.IPv4 = table.IPv4.removeByPeer(peer)
 | 
			
		||||
	table.IPv6 = table.IPv6.removeByPeer(peer)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
 | 
			
		||||
	table.mutex.Lock()
 | 
			
		||||
	defer table.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	switch len(ip) {
 | 
			
		||||
	case net.IPv6len:
 | 
			
		||||
		table.IPv6 = table.IPv6.insert(ip, cidr, peer)
 | 
			
		||||
	case net.IPv4len:
 | 
			
		||||
		table.IPv4 = table.IPv4.insert(ip, cidr, peer)
 | 
			
		||||
	default:
 | 
			
		||||
		panic(errors.New("inserting unknown address type"))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (table *AllowedIPs) LookupIPv4(address []byte) *Peer {
 | 
			
		||||
	table.mutex.RLock()
 | 
			
		||||
	defer table.mutex.RUnlock()
 | 
			
		||||
	return table.IPv4.lookup(address)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (table *AllowedIPs) LookupIPv6(address []byte) *Peer {
 | 
			
		||||
	table.mutex.RLock()
 | 
			
		||||
	defer table.mutex.RUnlock()
 | 
			
		||||
	return table.IPv6.lookup(address)
 | 
			
		||||
}
 | 
			
		||||
@ -65,7 +65,7 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestTrieRandomIPv4(t *testing.T) {
 | 
			
		||||
	var trie *Trie
 | 
			
		||||
	var trie *trieEntry
 | 
			
		||||
	var slow SlowRouter
 | 
			
		||||
	var peers []*Peer
 | 
			
		||||
 | 
			
		||||
@ -82,7 +82,7 @@ func TestTrieRandomIPv4(t *testing.T) {
 | 
			
		||||
		rand.Read(addr[:])
 | 
			
		||||
		cidr := uint(rand.Uint32() % (AddressLength * 8))
 | 
			
		||||
		index := rand.Int() % NumberOfPeers
 | 
			
		||||
		trie = trie.Insert(addr[:], cidr, peers[index])
 | 
			
		||||
		trie = trie.insert(addr[:], cidr, peers[index])
 | 
			
		||||
		slow = slow.Insert(addr[:], cidr, peers[index])
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -90,15 +90,15 @@ func TestTrieRandomIPv4(t *testing.T) {
 | 
			
		||||
		var addr [AddressLength]byte
 | 
			
		||||
		rand.Read(addr[:])
 | 
			
		||||
		peer1 := slow.Lookup(addr[:])
 | 
			
		||||
		peer2 := trie.Lookup(addr[:])
 | 
			
		||||
		peer2 := trie.lookup(addr[:])
 | 
			
		||||
		if peer1 != peer2 {
 | 
			
		||||
			t.Error("Trie did not match naive implementation, for:", addr)
 | 
			
		||||
			t.Error("trieEntry did not match naive implementation, for:", addr)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestTrieRandomIPv6(t *testing.T) {
 | 
			
		||||
	var trie *Trie
 | 
			
		||||
	var trie *trieEntry
 | 
			
		||||
	var slow SlowRouter
 | 
			
		||||
	var peers []*Peer
 | 
			
		||||
 | 
			
		||||
@ -115,7 +115,7 @@ func TestTrieRandomIPv6(t *testing.T) {
 | 
			
		||||
		rand.Read(addr[:])
 | 
			
		||||
		cidr := uint(rand.Uint32() % (AddressLength * 8))
 | 
			
		||||
		index := rand.Int() % NumberOfPeers
 | 
			
		||||
		trie = trie.Insert(addr[:], cidr, peers[index])
 | 
			
		||||
		trie = trie.insert(addr[:], cidr, peers[index])
 | 
			
		||||
		slow = slow.Insert(addr[:], cidr, peers[index])
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -123,9 +123,9 @@ func TestTrieRandomIPv6(t *testing.T) {
 | 
			
		||||
		var addr [AddressLength]byte
 | 
			
		||||
		rand.Read(addr[:])
 | 
			
		||||
		peer1 := slow.Lookup(addr[:])
 | 
			
		||||
		peer2 := trie.Lookup(addr[:])
 | 
			
		||||
		peer2 := trie.lookup(addr[:])
 | 
			
		||||
		if peer1 != peer2 {
 | 
			
		||||
			t.Error("Trie did not match naive implementation, for:", addr)
 | 
			
		||||
			t.Error("trieEntry did not match naive implementation, for:", addr)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -31,7 +31,7 @@ type testPairTrieLookup struct {
 | 
			
		||||
	peer *Peer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func printTrie(t *testing.T, p *Trie) {
 | 
			
		||||
func printTrie(t *testing.T, p *trieEntry) {
 | 
			
		||||
	if p == nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
@ -63,7 +63,7 @@ func TestCommonBits(t *testing.T) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
 | 
			
		||||
	var trie *Trie
 | 
			
		||||
	var trie *trieEntry
 | 
			
		||||
	var peers []*Peer
 | 
			
		||||
 | 
			
		||||
	rand.Seed(1)
 | 
			
		||||
@ -79,13 +79,13 @@ func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *test
 | 
			
		||||
		rand.Read(addr[:])
 | 
			
		||||
		cidr := uint(rand.Uint32() % (AddressLength * 8))
 | 
			
		||||
		index := rand.Int() % peerNumber
 | 
			
		||||
		trie = trie.Insert(addr[:], cidr, peers[index])
 | 
			
		||||
		trie = trie.insert(addr[:], cidr, peers[index])
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for n := 0; n < b.N; n += 1 {
 | 
			
		||||
		var addr [AddressLength]byte
 | 
			
		||||
		rand.Read(addr[:])
 | 
			
		||||
		trie.Lookup(addr[:])
 | 
			
		||||
		trie.lookup(addr[:])
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -117,21 +117,21 @@ func TestTrieIPv4(t *testing.T) {
 | 
			
		||||
	g := &Peer{}
 | 
			
		||||
	h := &Peer{}
 | 
			
		||||
 | 
			
		||||
	var trie *Trie
 | 
			
		||||
	var trie *trieEntry
 | 
			
		||||
 | 
			
		||||
	insert := func(peer *Peer, a, b, c, d byte, cidr uint) {
 | 
			
		||||
		trie = trie.Insert([]byte{a, b, c, d}, cidr, peer)
 | 
			
		||||
		trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	assertEQ := func(peer *Peer, a, b, c, d byte) {
 | 
			
		||||
		p := trie.Lookup([]byte{a, b, c, d})
 | 
			
		||||
		p := trie.lookup([]byte{a, b, c, d})
 | 
			
		||||
		if p != peer {
 | 
			
		||||
			t.Error("Assert EQ failed")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	assertNEQ := func(peer *Peer, a, b, c, d byte) {
 | 
			
		||||
		p := trie.Lookup([]byte{a, b, c, d})
 | 
			
		||||
		p := trie.lookup([]byte{a, b, c, d})
 | 
			
		||||
		if p == peer {
 | 
			
		||||
			t.Error("Assert NEQ failed")
 | 
			
		||||
		}
 | 
			
		||||
@ -173,7 +173,7 @@ func TestTrieIPv4(t *testing.T) {
 | 
			
		||||
	assertEQ(a, 192, 0, 0, 0)
 | 
			
		||||
	assertEQ(a, 255, 0, 0, 0)
 | 
			
		||||
 | 
			
		||||
	trie = trie.RemovePeer(a)
 | 
			
		||||
	trie = trie.removeByPeer(a)
 | 
			
		||||
 | 
			
		||||
	assertNEQ(a, 1, 0, 0, 0)
 | 
			
		||||
	assertNEQ(a, 64, 0, 0, 0)
 | 
			
		||||
@ -186,7 +186,7 @@ func TestTrieIPv4(t *testing.T) {
 | 
			
		||||
	insert(a, 192, 168, 0, 0, 16)
 | 
			
		||||
	insert(a, 192, 168, 0, 0, 24)
 | 
			
		||||
 | 
			
		||||
	trie = trie.RemovePeer(a)
 | 
			
		||||
	trie = trie.removeByPeer(a)
 | 
			
		||||
 | 
			
		||||
	assertNEQ(a, 192, 168, 0, 1)
 | 
			
		||||
}
 | 
			
		||||
@ -204,7 +204,7 @@ func TestTrieIPv6(t *testing.T) {
 | 
			
		||||
	g := &Peer{}
 | 
			
		||||
	h := &Peer{}
 | 
			
		||||
 | 
			
		||||
	var trie *Trie
 | 
			
		||||
	var trie *trieEntry
 | 
			
		||||
 | 
			
		||||
	expand := func(a uint32) []byte {
 | 
			
		||||
		var out [4]byte
 | 
			
		||||
@ -221,7 +221,7 @@ func TestTrieIPv6(t *testing.T) {
 | 
			
		||||
		addr = append(addr, expand(b)...)
 | 
			
		||||
		addr = append(addr, expand(c)...)
 | 
			
		||||
		addr = append(addr, expand(d)...)
 | 
			
		||||
		trie = trie.Insert(addr, cidr, peer)
 | 
			
		||||
		trie = trie.insert(addr, cidr, peer)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	assertEQ := func(peer *Peer, a, b, c, d uint32) {
 | 
			
		||||
@ -230,7 +230,7 @@ func TestTrieIPv6(t *testing.T) {
 | 
			
		||||
		addr = append(addr, expand(b)...)
 | 
			
		||||
		addr = append(addr, expand(c)...)
 | 
			
		||||
		addr = append(addr, expand(d)...)
 | 
			
		||||
		p := trie.Lookup(addr)
 | 
			
		||||
		p := trie.lookup(addr)
 | 
			
		||||
		if p != peer {
 | 
			
		||||
			t.Error("Assert EQ failed")
 | 
			
		||||
		}
 | 
			
		||||
@ -46,7 +46,7 @@ type Device struct {
 | 
			
		||||
 | 
			
		||||
	routing struct {
 | 
			
		||||
		mutex sync.RWMutex
 | 
			
		||||
		table RoutingTable
 | 
			
		||||
		table AllowedIPs
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	peers struct {
 | 
			
		||||
@ -95,7 +95,7 @@ func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
 | 
			
		||||
 | 
			
		||||
	// stop routing and processing of packets
 | 
			
		||||
 | 
			
		||||
	device.routing.table.RemovePeer(peer)
 | 
			
		||||
	device.routing.table.RemoveByPeer(peer)
 | 
			
		||||
	peer.Stop()
 | 
			
		||||
 | 
			
		||||
	// remove from peer map
 | 
			
		||||
 | 
			
		||||
@ -33,7 +33,7 @@ type Keypairs struct {
 | 
			
		||||
	mutex    sync.RWMutex
 | 
			
		||||
	current  *Keypair
 | 
			
		||||
	previous *Keypair
 | 
			
		||||
	next     *Keypair // not yet "confirmed by transport"
 | 
			
		||||
	next     *Keypair
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (kp *Keypairs) Current() *Keypair {
 | 
			
		||||
 | 
			
		||||
@ -40,7 +40,7 @@ func NewLogger(level int, prepend string) *Logger {
 | 
			
		||||
 | 
			
		||||
	logger.Debug = log.New(logDebug,
 | 
			
		||||
		"DEBUG: "+prepend,
 | 
			
		||||
		log.Ldate|log.Ltime|log.Lshortfile,
 | 
			
		||||
		log.Ldate|log.Ltime,
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	logger.Info = log.New(logInfo,
 | 
			
		||||
 | 
			
		||||
@ -71,14 +71,13 @@ func isZero(val []byte) bool {
 | 
			
		||||
	return acc == 1
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* This function is not used as pervasively as it should because this is mostly impossible in Go at the moment */
 | 
			
		||||
func setZero(arr []byte) {
 | 
			
		||||
	for i := range arr {
 | 
			
		||||
		arr[i] = 0
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* curve25519 wrappers */
 | 
			
		||||
 | 
			
		||||
func newPrivateKey() (sk NoisePrivateKey, err error) {
 | 
			
		||||
	// clamping: https://cr.yp.to/ecdh.html
 | 
			
		||||
	_, err = rand.Read(sk[:])
 | 
			
		||||
 | 
			
		||||
@ -30,7 +30,7 @@ func loadExactHex(dst []byte, src string) error {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if len(slice) != len(dst) {
 | 
			
		||||
		return errors.New("Hex string does not fit the slice")
 | 
			
		||||
		return errors.New("hex string does not fit the slice")
 | 
			
		||||
	}
 | 
			
		||||
	copy(dst, slice)
 | 
			
		||||
	return nil
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										24
									
								
								peer.go
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								peer.go
									
									
									
									
									
								
							@ -61,7 +61,7 @@ type Peer struct {
 | 
			
		||||
		mutex    sync.Mutex     // held when stopping / starting routines
 | 
			
		||||
		starting sync.WaitGroup // routines pending start
 | 
			
		||||
		stopping sync.WaitGroup // routines pending stop
 | 
			
		||||
		stop     chan struct{}  // size 0, stop all go-routines in peer
 | 
			
		||||
		stop     chan struct{}  // size 0, stop all go routines in peer
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	mac CookieGenerator
 | 
			
		||||
@ -70,7 +70,7 @@ type Peer struct {
 | 
			
		||||
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
 | 
			
		||||
 | 
			
		||||
	if device.isClosed.Get() {
 | 
			
		||||
		return nil, errors.New("Device closed")
 | 
			
		||||
		return nil, errors.New("device closed")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// lock resources
 | 
			
		||||
@ -87,7 +87,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
 | 
			
		||||
	// check if over limit
 | 
			
		||||
 | 
			
		||||
	if len(device.peers.keyMap) >= MaxPeers {
 | 
			
		||||
		return nil, errors.New("Too many peers")
 | 
			
		||||
		return nil, errors.New("too many peers")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// create peer
 | 
			
		||||
@ -104,7 +104,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
 | 
			
		||||
 | 
			
		||||
	_, ok := device.peers.keyMap[pk]
 | 
			
		||||
	if ok {
 | 
			
		||||
		return nil, errors.New("Adding existing peer")
 | 
			
		||||
		return nil, errors.New("adding existing peer")
 | 
			
		||||
	}
 | 
			
		||||
	device.peers.keyMap[pk] = peer
 | 
			
		||||
 | 
			
		||||
@ -134,26 +134,26 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
 | 
			
		||||
	defer peer.device.net.mutex.RUnlock()
 | 
			
		||||
 | 
			
		||||
	if peer.device.net.bind == nil {
 | 
			
		||||
		return errors.New("No bind")
 | 
			
		||||
		return errors.New("no bind")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	peer.mutex.RLock()
 | 
			
		||||
	defer peer.mutex.RUnlock()
 | 
			
		||||
 | 
			
		||||
	if peer.endpoint == nil {
 | 
			
		||||
		return errors.New("No known endpoint for peer")
 | 
			
		||||
		return errors.New("no known endpoint for peer")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return peer.device.net.bind.Send(buffer, peer.endpoint)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Returns a short string identifier for logging
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) String() string {
 | 
			
		||||
	return fmt.Sprintf(
 | 
			
		||||
		"peer(%s)",
 | 
			
		||||
		base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
 | 
			
		||||
	)
 | 
			
		||||
	base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
 | 
			
		||||
	abbreviatedKey := "invalid"
 | 
			
		||||
	if len(base64Key) == 44 {
 | 
			
		||||
		abbreviatedKey = base64Key[0:4] + "..." + base64Key[40:44]
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Sprintf("peer(%s)", abbreviatedKey)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (peer *Peer) Start() {
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										26
									
								
								receive.go
									
									
									
									
									
								
							
							
						
						
									
										26
									
								
								receive.go
									
									
									
									
									
								
							@ -600,20 +600,24 @@ func (peer *Peer) RoutineSequentialReceiver() {
 | 
			
		||||
			// check if using new key-pair
 | 
			
		||||
 | 
			
		||||
			kp := &peer.keypairs
 | 
			
		||||
			kp.mutex.Lock() //TODO: make this into an RW lock to reduce contention here for the equality check which is rarely true
 | 
			
		||||
			if kp.next == elem.keypair {
 | 
			
		||||
				old := kp.previous
 | 
			
		||||
				kp.previous = kp.current
 | 
			
		||||
				device.DeleteKeypair(old)
 | 
			
		||||
				kp.current = kp.next
 | 
			
		||||
				kp.next = nil
 | 
			
		||||
				peer.timersHandshakeComplete()
 | 
			
		||||
				select {
 | 
			
		||||
				case peer.signals.newKeypairArrived <- struct{}{}:
 | 
			
		||||
				default:
 | 
			
		||||
				kp.mutex.Lock()
 | 
			
		||||
				if kp.next != elem.keypair {
 | 
			
		||||
					kp.mutex.Unlock()
 | 
			
		||||
				} else {
 | 
			
		||||
					old := kp.previous
 | 
			
		||||
					kp.previous = kp.current
 | 
			
		||||
					device.DeleteKeypair(old)
 | 
			
		||||
					kp.current = kp.next
 | 
			
		||||
					kp.next = nil
 | 
			
		||||
					kp.mutex.Unlock()
 | 
			
		||||
					peer.timersHandshakeComplete()
 | 
			
		||||
					select {
 | 
			
		||||
					case peer.signals.newKeypairArrived <- struct{}{}:
 | 
			
		||||
					default:
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			kp.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
			peer.keepKeyFreshReceiving()
 | 
			
		||||
			peer.timersAnyAuthenticatedPacketTraversal()
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										70
									
								
								routing.go
									
									
									
									
									
								
							
							
						
						
									
										70
									
								
								routing.go
									
									
									
									
									
								
							@ -1,70 +0,0 @@
 | 
			
		||||
/* SPDX-License-Identifier: GPL-2.0
 | 
			
		||||
 *
 | 
			
		||||
 * Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"net"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type RoutingTable struct {
 | 
			
		||||
	IPv4  *Trie
 | 
			
		||||
	IPv6  *Trie
 | 
			
		||||
	mutex sync.RWMutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (table *RoutingTable) AllowedIPs(peer *Peer) []net.IPNet {
 | 
			
		||||
	table.mutex.RLock()
 | 
			
		||||
	defer table.mutex.RUnlock()
 | 
			
		||||
 | 
			
		||||
	allowed := make([]net.IPNet, 0, 10)
 | 
			
		||||
	allowed = table.IPv4.AllowedIPs(peer, allowed)
 | 
			
		||||
	allowed = table.IPv6.AllowedIPs(peer, allowed)
 | 
			
		||||
	return allowed
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (table *RoutingTable) Reset() {
 | 
			
		||||
	table.mutex.Lock()
 | 
			
		||||
	defer table.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	table.IPv4 = nil
 | 
			
		||||
	table.IPv6 = nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (table *RoutingTable) RemovePeer(peer *Peer) {
 | 
			
		||||
	table.mutex.Lock()
 | 
			
		||||
	defer table.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	table.IPv4 = table.IPv4.RemovePeer(peer)
 | 
			
		||||
	table.IPv6 = table.IPv6.RemovePeer(peer)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (table *RoutingTable) Insert(ip net.IP, cidr uint, peer *Peer) {
 | 
			
		||||
	table.mutex.Lock()
 | 
			
		||||
	defer table.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	switch len(ip) {
 | 
			
		||||
	case net.IPv6len:
 | 
			
		||||
		table.IPv6 = table.IPv6.Insert(ip, cidr, peer)
 | 
			
		||||
	case net.IPv4len:
 | 
			
		||||
		table.IPv4 = table.IPv4.Insert(ip, cidr, peer)
 | 
			
		||||
	default:
 | 
			
		||||
		panic(errors.New("Inserting unknown address type"))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (table *RoutingTable) LookupIPv4(address []byte) *Peer {
 | 
			
		||||
	table.mutex.RLock()
 | 
			
		||||
	defer table.mutex.RUnlock()
 | 
			
		||||
	return table.IPv4.Lookup(address)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (table *RoutingTable) LookupIPv6(address []byte) *Peer {
 | 
			
		||||
	table.mutex.RLock()
 | 
			
		||||
	defer table.mutex.RUnlock()
 | 
			
		||||
	return table.IPv6.Lookup(address)
 | 
			
		||||
}
 | 
			
		||||
@ -224,7 +224,9 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tun *NativeTun) Close() error {
 | 
			
		||||
	return tun.fd.Close()
 | 
			
		||||
	err := tun.fd.Close()
 | 
			
		||||
	close(tun.events)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tun *NativeTun) setMTU(n int) error {
 | 
			
		||||
 | 
			
		||||
@ -392,6 +392,7 @@ func (tun *NativeTun) Close() error {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	tun.closingWriter.Write([]byte{0})
 | 
			
		||||
	close(tun.events)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -125,7 +125,9 @@ func (f *NativeTUN) Events() chan TUNEvent {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *NativeTUN) Close() error {
 | 
			
		||||
	return windows.Close(f.fd)
 | 
			
		||||
	close(f.events)
 | 
			
		||||
	err := windows.Close(f.fd)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *NativeTUN) Write(b []byte) (int, error) {
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										6
									
								
								uapi.go
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								uapi.go
									
									
									
									
									
								
							@ -91,7 +91,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
			send(fmt.Sprintf("rx_bytes=%d", peer.stats.rxBytes))
 | 
			
		||||
			send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
 | 
			
		||||
 | 
			
		||||
			for _, ip := range device.routing.table.AllowedIPs(peer) {
 | 
			
		||||
			for _, ip := range device.routing.table.EntriesForPeer(peer) {
 | 
			
		||||
				send("allowed_ip=" + ip.String())
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
@ -337,7 +337,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
 | 
			
		||||
			case "replace_allowed_ips":
 | 
			
		||||
 | 
			
		||||
				logDebug.Println("UAPI: Removing all allowed IPs for peer:", peer)
 | 
			
		||||
				logDebug.Println("UAPI: Removing all allowed EntriesForPeer for peer:", peer)
 | 
			
		||||
 | 
			
		||||
				if value != "true" {
 | 
			
		||||
					logError.Println("Failed to set replace_allowed_ips, invalid value:", value)
 | 
			
		||||
@ -349,7 +349,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				device.routing.mutex.Lock()
 | 
			
		||||
				device.routing.table.RemovePeer(peer)
 | 
			
		||||
				device.routing.table.RemoveByPeer(peer)
 | 
			
		||||
				device.routing.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
			case "allowed_ip":
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user