Odds and ends
This commit is contained in:
parent
e94185681f
commit
2326d6a4d7
3
Makefile
3
Makefile
@ -6,7 +6,4 @@ wireguard-go: $(wildcard *.go)
|
|||||||
clean:
|
clean:
|
||||||
rm -f wireguard-go
|
rm -f wireguard-go
|
||||||
|
|
||||||
cloc:
|
|
||||||
cloc $(filter-out xchacha20.go $(wildcard *_test.go), $(wildcard *.go))
|
|
||||||
|
|
||||||
.PHONY: clean cloc
|
.PHONY: clean cloc
|
||||||
|
@ -8,21 +8,12 @@ package main
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* Binary trie
|
type trieEntry struct {
|
||||||
*
|
|
||||||
* The net.IPs used here are not formatted the
|
|
||||||
* same way as those created by the "net" functions.
|
|
||||||
* Here the IPs are slices of either 4 or 16 byte (not always 16)
|
|
||||||
*
|
|
||||||
* Synchronization done separately
|
|
||||||
* See: routing.go
|
|
||||||
*/
|
|
||||||
|
|
||||||
type Trie struct {
|
|
||||||
cidr uint
|
cidr uint
|
||||||
child [2]*Trie
|
child [2]*trieEntry
|
||||||
bits []byte
|
bits []byte
|
||||||
peer *Peer
|
peer *Peer
|
||||||
|
|
||||||
@ -90,15 +81,15 @@ func commonBits(ip1 []byte, ip2 []byte) uint {
|
|||||||
return i * 8
|
return i * 8
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *Trie) RemovePeer(p *Peer) *Trie {
|
func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
|
||||||
if node == nil {
|
if node == nil {
|
||||||
return node
|
return node
|
||||||
}
|
}
|
||||||
|
|
||||||
// walk recursively
|
// walk recursively
|
||||||
|
|
||||||
node.child[0] = node.child[0].RemovePeer(p)
|
node.child[0] = node.child[0].removeByPeer(p)
|
||||||
node.child[1] = node.child[1].RemovePeer(p)
|
node.child[1] = node.child[1].removeByPeer(p)
|
||||||
|
|
||||||
if node.peer != p {
|
if node.peer != p {
|
||||||
return node
|
return node
|
||||||
@ -113,16 +104,16 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
|
|||||||
return node.child[0]
|
return node.child[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *Trie) choose(ip net.IP) byte {
|
func (node *trieEntry) choose(ip net.IP) byte {
|
||||||
return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
|
return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
|
func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
|
||||||
|
|
||||||
// at leaf
|
// at leaf
|
||||||
|
|
||||||
if node == nil {
|
if node == nil {
|
||||||
return &Trie{
|
return &trieEntry{
|
||||||
bits: ip,
|
bits: ip,
|
||||||
peer: peer,
|
peer: peer,
|
||||||
cidr: cidr,
|
cidr: cidr,
|
||||||
@ -140,13 +131,13 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
|
|||||||
return node
|
return node
|
||||||
}
|
}
|
||||||
bit := node.choose(ip)
|
bit := node.choose(ip)
|
||||||
node.child[bit] = node.child[bit].Insert(ip, cidr, peer)
|
node.child[bit] = node.child[bit].insert(ip, cidr, peer)
|
||||||
return node
|
return node
|
||||||
}
|
}
|
||||||
|
|
||||||
// split node
|
// split node
|
||||||
|
|
||||||
newNode := &Trie{
|
newNode := &trieEntry{
|
||||||
bits: ip,
|
bits: ip,
|
||||||
peer: peer,
|
peer: peer,
|
||||||
cidr: cidr,
|
cidr: cidr,
|
||||||
@ -166,7 +157,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
|
|||||||
|
|
||||||
// create new parent for node & newNode
|
// create new parent for node & newNode
|
||||||
|
|
||||||
parent := &Trie{
|
parent := &trieEntry{
|
||||||
bits: ip,
|
bits: ip,
|
||||||
peer: nil,
|
peer: nil,
|
||||||
cidr: cidr,
|
cidr: cidr,
|
||||||
@ -181,7 +172,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
|
|||||||
return parent
|
return parent
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *Trie) Lookup(ip net.IP) *Peer {
|
func (node *trieEntry) lookup(ip net.IP) *Peer {
|
||||||
var found *Peer
|
var found *Peer
|
||||||
size := uint(len(ip))
|
size := uint(len(ip))
|
||||||
for node != nil && commonBits(node.bits, ip) >= node.cidr {
|
for node != nil && commonBits(node.bits, ip) >= node.cidr {
|
||||||
@ -197,16 +188,7 @@ func (node *Trie) Lookup(ip net.IP) *Peer {
|
|||||||
return found
|
return found
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *Trie) Count() uint {
|
func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet {
|
||||||
if node == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
l := node.child[0].Count()
|
|
||||||
r := node.child[1].Count()
|
|
||||||
return l + r
|
|
||||||
}
|
|
||||||
|
|
||||||
func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet {
|
|
||||||
if node == nil {
|
if node == nil {
|
||||||
return results
|
return results
|
||||||
}
|
}
|
||||||
@ -223,11 +205,69 @@ func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet {
|
|||||||
} else if len(node.bits) == net.IPv6len {
|
} else if len(node.bits) == net.IPv6len {
|
||||||
mask.IP = node.bits
|
mask.IP = node.bits
|
||||||
} else {
|
} else {
|
||||||
panic(errors.New("bug: unexpected address length"))
|
panic(errors.New("unexpected address length"))
|
||||||
}
|
}
|
||||||
results = append(results, mask)
|
results = append(results, mask)
|
||||||
}
|
}
|
||||||
results = node.child[0].AllowedIPs(p, results)
|
results = node.child[0].entriesForPeer(p, results)
|
||||||
results = node.child[1].AllowedIPs(p, results)
|
results = node.child[1].entriesForPeer(p, results)
|
||||||
return results
|
return results
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AllowedIPs struct {
|
||||||
|
IPv4 *trieEntry
|
||||||
|
IPv6 *trieEntry
|
||||||
|
mutex sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet {
|
||||||
|
table.mutex.RLock()
|
||||||
|
defer table.mutex.RUnlock()
|
||||||
|
|
||||||
|
allowed := make([]net.IPNet, 0, 10)
|
||||||
|
allowed = table.IPv4.entriesForPeer(peer, allowed)
|
||||||
|
allowed = table.IPv6.entriesForPeer(peer, allowed)
|
||||||
|
return allowed
|
||||||
|
}
|
||||||
|
|
||||||
|
func (table *AllowedIPs) Reset() {
|
||||||
|
table.mutex.Lock()
|
||||||
|
defer table.mutex.Unlock()
|
||||||
|
|
||||||
|
table.IPv4 = nil
|
||||||
|
table.IPv6 = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
||||||
|
table.mutex.Lock()
|
||||||
|
defer table.mutex.Unlock()
|
||||||
|
|
||||||
|
table.IPv4 = table.IPv4.removeByPeer(peer)
|
||||||
|
table.IPv6 = table.IPv6.removeByPeer(peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
|
||||||
|
table.mutex.Lock()
|
||||||
|
defer table.mutex.Unlock()
|
||||||
|
|
||||||
|
switch len(ip) {
|
||||||
|
case net.IPv6len:
|
||||||
|
table.IPv6 = table.IPv6.insert(ip, cidr, peer)
|
||||||
|
case net.IPv4len:
|
||||||
|
table.IPv4 = table.IPv4.insert(ip, cidr, peer)
|
||||||
|
default:
|
||||||
|
panic(errors.New("inserting unknown address type"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (table *AllowedIPs) LookupIPv4(address []byte) *Peer {
|
||||||
|
table.mutex.RLock()
|
||||||
|
defer table.mutex.RUnlock()
|
||||||
|
return table.IPv4.lookup(address)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (table *AllowedIPs) LookupIPv6(address []byte) *Peer {
|
||||||
|
table.mutex.RLock()
|
||||||
|
defer table.mutex.RUnlock()
|
||||||
|
return table.IPv6.lookup(address)
|
||||||
|
}
|
@ -65,7 +65,7 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTrieRandomIPv4(t *testing.T) {
|
func TestTrieRandomIPv4(t *testing.T) {
|
||||||
var trie *Trie
|
var trie *trieEntry
|
||||||
var slow SlowRouter
|
var slow SlowRouter
|
||||||
var peers []*Peer
|
var peers []*Peer
|
||||||
|
|
||||||
@ -82,7 +82,7 @@ func TestTrieRandomIPv4(t *testing.T) {
|
|||||||
rand.Read(addr[:])
|
rand.Read(addr[:])
|
||||||
cidr := uint(rand.Uint32() % (AddressLength * 8))
|
cidr := uint(rand.Uint32() % (AddressLength * 8))
|
||||||
index := rand.Int() % NumberOfPeers
|
index := rand.Int() % NumberOfPeers
|
||||||
trie = trie.Insert(addr[:], cidr, peers[index])
|
trie = trie.insert(addr[:], cidr, peers[index])
|
||||||
slow = slow.Insert(addr[:], cidr, peers[index])
|
slow = slow.Insert(addr[:], cidr, peers[index])
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -90,15 +90,15 @@ func TestTrieRandomIPv4(t *testing.T) {
|
|||||||
var addr [AddressLength]byte
|
var addr [AddressLength]byte
|
||||||
rand.Read(addr[:])
|
rand.Read(addr[:])
|
||||||
peer1 := slow.Lookup(addr[:])
|
peer1 := slow.Lookup(addr[:])
|
||||||
peer2 := trie.Lookup(addr[:])
|
peer2 := trie.lookup(addr[:])
|
||||||
if peer1 != peer2 {
|
if peer1 != peer2 {
|
||||||
t.Error("Trie did not match naive implementation, for:", addr)
|
t.Error("trieEntry did not match naive implementation, for:", addr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTrieRandomIPv6(t *testing.T) {
|
func TestTrieRandomIPv6(t *testing.T) {
|
||||||
var trie *Trie
|
var trie *trieEntry
|
||||||
var slow SlowRouter
|
var slow SlowRouter
|
||||||
var peers []*Peer
|
var peers []*Peer
|
||||||
|
|
||||||
@ -115,7 +115,7 @@ func TestTrieRandomIPv6(t *testing.T) {
|
|||||||
rand.Read(addr[:])
|
rand.Read(addr[:])
|
||||||
cidr := uint(rand.Uint32() % (AddressLength * 8))
|
cidr := uint(rand.Uint32() % (AddressLength * 8))
|
||||||
index := rand.Int() % NumberOfPeers
|
index := rand.Int() % NumberOfPeers
|
||||||
trie = trie.Insert(addr[:], cidr, peers[index])
|
trie = trie.insert(addr[:], cidr, peers[index])
|
||||||
slow = slow.Insert(addr[:], cidr, peers[index])
|
slow = slow.Insert(addr[:], cidr, peers[index])
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -123,9 +123,9 @@ func TestTrieRandomIPv6(t *testing.T) {
|
|||||||
var addr [AddressLength]byte
|
var addr [AddressLength]byte
|
||||||
rand.Read(addr[:])
|
rand.Read(addr[:])
|
||||||
peer1 := slow.Lookup(addr[:])
|
peer1 := slow.Lookup(addr[:])
|
||||||
peer2 := trie.Lookup(addr[:])
|
peer2 := trie.lookup(addr[:])
|
||||||
if peer1 != peer2 {
|
if peer1 != peer2 {
|
||||||
t.Error("Trie did not match naive implementation, for:", addr)
|
t.Error("trieEntry did not match naive implementation, for:", addr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -31,7 +31,7 @@ type testPairTrieLookup struct {
|
|||||||
peer *Peer
|
peer *Peer
|
||||||
}
|
}
|
||||||
|
|
||||||
func printTrie(t *testing.T, p *Trie) {
|
func printTrie(t *testing.T, p *trieEntry) {
|
||||||
if p == nil {
|
if p == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -63,7 +63,7 @@ func TestCommonBits(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
|
func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
|
||||||
var trie *Trie
|
var trie *trieEntry
|
||||||
var peers []*Peer
|
var peers []*Peer
|
||||||
|
|
||||||
rand.Seed(1)
|
rand.Seed(1)
|
||||||
@ -79,13 +79,13 @@ func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *test
|
|||||||
rand.Read(addr[:])
|
rand.Read(addr[:])
|
||||||
cidr := uint(rand.Uint32() % (AddressLength * 8))
|
cidr := uint(rand.Uint32() % (AddressLength * 8))
|
||||||
index := rand.Int() % peerNumber
|
index := rand.Int() % peerNumber
|
||||||
trie = trie.Insert(addr[:], cidr, peers[index])
|
trie = trie.insert(addr[:], cidr, peers[index])
|
||||||
}
|
}
|
||||||
|
|
||||||
for n := 0; n < b.N; n += 1 {
|
for n := 0; n < b.N; n += 1 {
|
||||||
var addr [AddressLength]byte
|
var addr [AddressLength]byte
|
||||||
rand.Read(addr[:])
|
rand.Read(addr[:])
|
||||||
trie.Lookup(addr[:])
|
trie.lookup(addr[:])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -117,21 +117,21 @@ func TestTrieIPv4(t *testing.T) {
|
|||||||
g := &Peer{}
|
g := &Peer{}
|
||||||
h := &Peer{}
|
h := &Peer{}
|
||||||
|
|
||||||
var trie *Trie
|
var trie *trieEntry
|
||||||
|
|
||||||
insert := func(peer *Peer, a, b, c, d byte, cidr uint) {
|
insert := func(peer *Peer, a, b, c, d byte, cidr uint) {
|
||||||
trie = trie.Insert([]byte{a, b, c, d}, cidr, peer)
|
trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEQ := func(peer *Peer, a, b, c, d byte) {
|
assertEQ := func(peer *Peer, a, b, c, d byte) {
|
||||||
p := trie.Lookup([]byte{a, b, c, d})
|
p := trie.lookup([]byte{a, b, c, d})
|
||||||
if p != peer {
|
if p != peer {
|
||||||
t.Error("Assert EQ failed")
|
t.Error("Assert EQ failed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
assertNEQ := func(peer *Peer, a, b, c, d byte) {
|
assertNEQ := func(peer *Peer, a, b, c, d byte) {
|
||||||
p := trie.Lookup([]byte{a, b, c, d})
|
p := trie.lookup([]byte{a, b, c, d})
|
||||||
if p == peer {
|
if p == peer {
|
||||||
t.Error("Assert NEQ failed")
|
t.Error("Assert NEQ failed")
|
||||||
}
|
}
|
||||||
@ -173,7 +173,7 @@ func TestTrieIPv4(t *testing.T) {
|
|||||||
assertEQ(a, 192, 0, 0, 0)
|
assertEQ(a, 192, 0, 0, 0)
|
||||||
assertEQ(a, 255, 0, 0, 0)
|
assertEQ(a, 255, 0, 0, 0)
|
||||||
|
|
||||||
trie = trie.RemovePeer(a)
|
trie = trie.removeByPeer(a)
|
||||||
|
|
||||||
assertNEQ(a, 1, 0, 0, 0)
|
assertNEQ(a, 1, 0, 0, 0)
|
||||||
assertNEQ(a, 64, 0, 0, 0)
|
assertNEQ(a, 64, 0, 0, 0)
|
||||||
@ -186,7 +186,7 @@ func TestTrieIPv4(t *testing.T) {
|
|||||||
insert(a, 192, 168, 0, 0, 16)
|
insert(a, 192, 168, 0, 0, 16)
|
||||||
insert(a, 192, 168, 0, 0, 24)
|
insert(a, 192, 168, 0, 0, 24)
|
||||||
|
|
||||||
trie = trie.RemovePeer(a)
|
trie = trie.removeByPeer(a)
|
||||||
|
|
||||||
assertNEQ(a, 192, 168, 0, 1)
|
assertNEQ(a, 192, 168, 0, 1)
|
||||||
}
|
}
|
||||||
@ -204,7 +204,7 @@ func TestTrieIPv6(t *testing.T) {
|
|||||||
g := &Peer{}
|
g := &Peer{}
|
||||||
h := &Peer{}
|
h := &Peer{}
|
||||||
|
|
||||||
var trie *Trie
|
var trie *trieEntry
|
||||||
|
|
||||||
expand := func(a uint32) []byte {
|
expand := func(a uint32) []byte {
|
||||||
var out [4]byte
|
var out [4]byte
|
||||||
@ -221,7 +221,7 @@ func TestTrieIPv6(t *testing.T) {
|
|||||||
addr = append(addr, expand(b)...)
|
addr = append(addr, expand(b)...)
|
||||||
addr = append(addr, expand(c)...)
|
addr = append(addr, expand(c)...)
|
||||||
addr = append(addr, expand(d)...)
|
addr = append(addr, expand(d)...)
|
||||||
trie = trie.Insert(addr, cidr, peer)
|
trie = trie.insert(addr, cidr, peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEQ := func(peer *Peer, a, b, c, d uint32) {
|
assertEQ := func(peer *Peer, a, b, c, d uint32) {
|
||||||
@ -230,7 +230,7 @@ func TestTrieIPv6(t *testing.T) {
|
|||||||
addr = append(addr, expand(b)...)
|
addr = append(addr, expand(b)...)
|
||||||
addr = append(addr, expand(c)...)
|
addr = append(addr, expand(c)...)
|
||||||
addr = append(addr, expand(d)...)
|
addr = append(addr, expand(d)...)
|
||||||
p := trie.Lookup(addr)
|
p := trie.lookup(addr)
|
||||||
if p != peer {
|
if p != peer {
|
||||||
t.Error("Assert EQ failed")
|
t.Error("Assert EQ failed")
|
||||||
}
|
}
|
@ -46,7 +46,7 @@ type Device struct {
|
|||||||
|
|
||||||
routing struct {
|
routing struct {
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
table RoutingTable
|
table AllowedIPs
|
||||||
}
|
}
|
||||||
|
|
||||||
peers struct {
|
peers struct {
|
||||||
@ -95,7 +95,7 @@ func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
|
|||||||
|
|
||||||
// stop routing and processing of packets
|
// stop routing and processing of packets
|
||||||
|
|
||||||
device.routing.table.RemovePeer(peer)
|
device.routing.table.RemoveByPeer(peer)
|
||||||
peer.Stop()
|
peer.Stop()
|
||||||
|
|
||||||
// remove from peer map
|
// remove from peer map
|
||||||
|
@ -33,7 +33,7 @@ type Keypairs struct {
|
|||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
current *Keypair
|
current *Keypair
|
||||||
previous *Keypair
|
previous *Keypair
|
||||||
next *Keypair // not yet "confirmed by transport"
|
next *Keypair
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kp *Keypairs) Current() *Keypair {
|
func (kp *Keypairs) Current() *Keypair {
|
||||||
|
@ -40,7 +40,7 @@ func NewLogger(level int, prepend string) *Logger {
|
|||||||
|
|
||||||
logger.Debug = log.New(logDebug,
|
logger.Debug = log.New(logDebug,
|
||||||
"DEBUG: "+prepend,
|
"DEBUG: "+prepend,
|
||||||
log.Ldate|log.Ltime|log.Lshortfile,
|
log.Ldate|log.Ltime,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.Info = log.New(logInfo,
|
logger.Info = log.New(logInfo,
|
||||||
|
@ -71,14 +71,13 @@ func isZero(val []byte) bool {
|
|||||||
return acc == 1
|
return acc == 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* This function is not used as pervasively as it should because this is mostly impossible in Go at the moment */
|
||||||
func setZero(arr []byte) {
|
func setZero(arr []byte) {
|
||||||
for i := range arr {
|
for i := range arr {
|
||||||
arr[i] = 0
|
arr[i] = 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* curve25519 wrappers */
|
|
||||||
|
|
||||||
func newPrivateKey() (sk NoisePrivateKey, err error) {
|
func newPrivateKey() (sk NoisePrivateKey, err error) {
|
||||||
// clamping: https://cr.yp.to/ecdh.html
|
// clamping: https://cr.yp.to/ecdh.html
|
||||||
_, err = rand.Read(sk[:])
|
_, err = rand.Read(sk[:])
|
||||||
|
@ -30,7 +30,7 @@ func loadExactHex(dst []byte, src string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(slice) != len(dst) {
|
if len(slice) != len(dst) {
|
||||||
return errors.New("Hex string does not fit the slice")
|
return errors.New("hex string does not fit the slice")
|
||||||
}
|
}
|
||||||
copy(dst, slice)
|
copy(dst, slice)
|
||||||
return nil
|
return nil
|
||||||
|
24
peer.go
24
peer.go
@ -61,7 +61,7 @@ type Peer struct {
|
|||||||
mutex sync.Mutex // held when stopping / starting routines
|
mutex sync.Mutex // held when stopping / starting routines
|
||||||
starting sync.WaitGroup // routines pending start
|
starting sync.WaitGroup // routines pending start
|
||||||
stopping sync.WaitGroup // routines pending stop
|
stopping sync.WaitGroup // routines pending stop
|
||||||
stop chan struct{} // size 0, stop all go-routines in peer
|
stop chan struct{} // size 0, stop all go routines in peer
|
||||||
}
|
}
|
||||||
|
|
||||||
mac CookieGenerator
|
mac CookieGenerator
|
||||||
@ -70,7 +70,7 @@ type Peer struct {
|
|||||||
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||||
|
|
||||||
if device.isClosed.Get() {
|
if device.isClosed.Get() {
|
||||||
return nil, errors.New("Device closed")
|
return nil, errors.New("device closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
// lock resources
|
// lock resources
|
||||||
@ -87,7 +87,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
|||||||
// check if over limit
|
// check if over limit
|
||||||
|
|
||||||
if len(device.peers.keyMap) >= MaxPeers {
|
if len(device.peers.keyMap) >= MaxPeers {
|
||||||
return nil, errors.New("Too many peers")
|
return nil, errors.New("too many peers")
|
||||||
}
|
}
|
||||||
|
|
||||||
// create peer
|
// create peer
|
||||||
@ -104,7 +104,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
|||||||
|
|
||||||
_, ok := device.peers.keyMap[pk]
|
_, ok := device.peers.keyMap[pk]
|
||||||
if ok {
|
if ok {
|
||||||
return nil, errors.New("Adding existing peer")
|
return nil, errors.New("adding existing peer")
|
||||||
}
|
}
|
||||||
device.peers.keyMap[pk] = peer
|
device.peers.keyMap[pk] = peer
|
||||||
|
|
||||||
@ -134,26 +134,26 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
|
|||||||
defer peer.device.net.mutex.RUnlock()
|
defer peer.device.net.mutex.RUnlock()
|
||||||
|
|
||||||
if peer.device.net.bind == nil {
|
if peer.device.net.bind == nil {
|
||||||
return errors.New("No bind")
|
return errors.New("no bind")
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.mutex.RLock()
|
peer.mutex.RLock()
|
||||||
defer peer.mutex.RUnlock()
|
defer peer.mutex.RUnlock()
|
||||||
|
|
||||||
if peer.endpoint == nil {
|
if peer.endpoint == nil {
|
||||||
return errors.New("No known endpoint for peer")
|
return errors.New("no known endpoint for peer")
|
||||||
}
|
}
|
||||||
|
|
||||||
return peer.device.net.bind.Send(buffer, peer.endpoint)
|
return peer.device.net.bind.Send(buffer, peer.endpoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Returns a short string identifier for logging
|
|
||||||
*/
|
|
||||||
func (peer *Peer) String() string {
|
func (peer *Peer) String() string {
|
||||||
return fmt.Sprintf(
|
base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
|
||||||
"peer(%s)",
|
abbreviatedKey := "invalid"
|
||||||
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
|
if len(base64Key) == 44 {
|
||||||
)
|
abbreviatedKey = base64Key[0:4] + "..." + base64Key[40:44]
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("peer(%s)", abbreviatedKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) Start() {
|
func (peer *Peer) Start() {
|
||||||
|
26
receive.go
26
receive.go
@ -600,20 +600,24 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
|||||||
// check if using new key-pair
|
// check if using new key-pair
|
||||||
|
|
||||||
kp := &peer.keypairs
|
kp := &peer.keypairs
|
||||||
kp.mutex.Lock() //TODO: make this into an RW lock to reduce contention here for the equality check which is rarely true
|
|
||||||
if kp.next == elem.keypair {
|
if kp.next == elem.keypair {
|
||||||
old := kp.previous
|
kp.mutex.Lock()
|
||||||
kp.previous = kp.current
|
if kp.next != elem.keypair {
|
||||||
device.DeleteKeypair(old)
|
kp.mutex.Unlock()
|
||||||
kp.current = kp.next
|
} else {
|
||||||
kp.next = nil
|
old := kp.previous
|
||||||
peer.timersHandshakeComplete()
|
kp.previous = kp.current
|
||||||
select {
|
device.DeleteKeypair(old)
|
||||||
case peer.signals.newKeypairArrived <- struct{}{}:
|
kp.current = kp.next
|
||||||
default:
|
kp.next = nil
|
||||||
|
kp.mutex.Unlock()
|
||||||
|
peer.timersHandshakeComplete()
|
||||||
|
select {
|
||||||
|
case peer.signals.newKeypairArrived <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
kp.mutex.Unlock()
|
|
||||||
|
|
||||||
peer.keepKeyFreshReceiving()
|
peer.keepKeyFreshReceiving()
|
||||||
peer.timersAnyAuthenticatedPacketTraversal()
|
peer.timersAnyAuthenticatedPacketTraversal()
|
||||||
|
70
routing.go
70
routing.go
@ -1,70 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: GPL-2.0
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
type RoutingTable struct {
|
|
||||||
IPv4 *Trie
|
|
||||||
IPv6 *Trie
|
|
||||||
mutex sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func (table *RoutingTable) AllowedIPs(peer *Peer) []net.IPNet {
|
|
||||||
table.mutex.RLock()
|
|
||||||
defer table.mutex.RUnlock()
|
|
||||||
|
|
||||||
allowed := make([]net.IPNet, 0, 10)
|
|
||||||
allowed = table.IPv4.AllowedIPs(peer, allowed)
|
|
||||||
allowed = table.IPv6.AllowedIPs(peer, allowed)
|
|
||||||
return allowed
|
|
||||||
}
|
|
||||||
|
|
||||||
func (table *RoutingTable) Reset() {
|
|
||||||
table.mutex.Lock()
|
|
||||||
defer table.mutex.Unlock()
|
|
||||||
|
|
||||||
table.IPv4 = nil
|
|
||||||
table.IPv6 = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (table *RoutingTable) RemovePeer(peer *Peer) {
|
|
||||||
table.mutex.Lock()
|
|
||||||
defer table.mutex.Unlock()
|
|
||||||
|
|
||||||
table.IPv4 = table.IPv4.RemovePeer(peer)
|
|
||||||
table.IPv6 = table.IPv6.RemovePeer(peer)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (table *RoutingTable) Insert(ip net.IP, cidr uint, peer *Peer) {
|
|
||||||
table.mutex.Lock()
|
|
||||||
defer table.mutex.Unlock()
|
|
||||||
|
|
||||||
switch len(ip) {
|
|
||||||
case net.IPv6len:
|
|
||||||
table.IPv6 = table.IPv6.Insert(ip, cidr, peer)
|
|
||||||
case net.IPv4len:
|
|
||||||
table.IPv4 = table.IPv4.Insert(ip, cidr, peer)
|
|
||||||
default:
|
|
||||||
panic(errors.New("Inserting unknown address type"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (table *RoutingTable) LookupIPv4(address []byte) *Peer {
|
|
||||||
table.mutex.RLock()
|
|
||||||
defer table.mutex.RUnlock()
|
|
||||||
return table.IPv4.Lookup(address)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (table *RoutingTable) LookupIPv6(address []byte) *Peer {
|
|
||||||
table.mutex.RLock()
|
|
||||||
defer table.mutex.RUnlock()
|
|
||||||
return table.IPv6.Lookup(address)
|
|
||||||
}
|
|
@ -224,7 +224,9 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Close() error {
|
func (tun *NativeTun) Close() error {
|
||||||
return tun.fd.Close()
|
err := tun.fd.Close()
|
||||||
|
close(tun.events)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) setMTU(n int) error {
|
func (tun *NativeTun) setMTU(n int) error {
|
||||||
|
@ -392,6 +392,7 @@ func (tun *NativeTun) Close() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
tun.closingWriter.Write([]byte{0})
|
tun.closingWriter.Write([]byte{0})
|
||||||
|
close(tun.events)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -125,7 +125,9 @@ func (f *NativeTUN) Events() chan TUNEvent {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *NativeTUN) Close() error {
|
func (f *NativeTUN) Close() error {
|
||||||
return windows.Close(f.fd)
|
close(f.events)
|
||||||
|
err := windows.Close(f.fd)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *NativeTUN) Write(b []byte) (int, error) {
|
func (f *NativeTUN) Write(b []byte) (int, error) {
|
||||||
|
6
uapi.go
6
uapi.go
@ -91,7 +91,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
|||||||
send(fmt.Sprintf("rx_bytes=%d", peer.stats.rxBytes))
|
send(fmt.Sprintf("rx_bytes=%d", peer.stats.rxBytes))
|
||||||
send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
|
send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
|
||||||
|
|
||||||
for _, ip := range device.routing.table.AllowedIPs(peer) {
|
for _, ip := range device.routing.table.EntriesForPeer(peer) {
|
||||||
send("allowed_ip=" + ip.String())
|
send("allowed_ip=" + ip.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -337,7 +337,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
|||||||
|
|
||||||
case "replace_allowed_ips":
|
case "replace_allowed_ips":
|
||||||
|
|
||||||
logDebug.Println("UAPI: Removing all allowed IPs for peer:", peer)
|
logDebug.Println("UAPI: Removing all allowed EntriesForPeer for peer:", peer)
|
||||||
|
|
||||||
if value != "true" {
|
if value != "true" {
|
||||||
logError.Println("Failed to set replace_allowed_ips, invalid value:", value)
|
logError.Println("Failed to set replace_allowed_ips, invalid value:", value)
|
||||||
@ -349,7 +349,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
device.routing.mutex.Lock()
|
device.routing.mutex.Lock()
|
||||||
device.routing.table.RemovePeer(peer)
|
device.routing.table.RemoveByPeer(peer)
|
||||||
device.routing.mutex.Unlock()
|
device.routing.mutex.Unlock()
|
||||||
|
|
||||||
case "allowed_ip":
|
case "allowed_ip":
|
||||||
|
Loading…
Reference in New Issue
Block a user