2019-01-02 01:55:51 +01:00
|
|
|
/* SPDX-License-Identifier: MIT
|
2018-05-03 15:04:00 +02:00
|
|
|
*
|
2021-01-28 17:52:15 +01:00
|
|
|
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
2018-05-03 15:04:00 +02:00
|
|
|
*/
|
|
|
|
|
2019-03-03 04:04:41 +01:00
|
|
|
package device
|
2017-05-30 22:36:49 +02:00
|
|
|
|
2017-06-04 21:48:15 +02:00
|
|
|
import (
	"encoding/binary"
	"errors"
	"math/bits"
	"net"
	"sync"
	"unsafe"
)
|
|
|
|
|
2018-05-13 19:33:41 +02:00
|
|
|
type trieEntry struct {
|
2021-01-26 23:44:37 +01:00
|
|
|
child [2]*trieEntry
|
|
|
|
peer *Peer
|
|
|
|
bits net.IP
|
|
|
|
cidr uint
|
|
|
|
bit_at_byte uint
|
|
|
|
bit_at_shift uint
|
|
|
|
nextEntryForPeer *trieEntry
|
|
|
|
pprevEntryForPeer **trieEntry
|
2017-05-30 22:36:49 +02:00
|
|
|
}
|
|
|
|
|
2018-05-14 15:49:20 +02:00
|
|
|
// isLittleEndian reports whether the host stores multi-byte integers with
// the least-significant byte at the lowest address. It stores 1 in a uint32
// and inspects the byte at that variable's own address.
func isLittleEndian() bool {
	var probe uint32 = 1
	lowest := *(*byte)(unsafe.Pointer(&probe))
	return lowest == 1
}
|
|
|
|
|
|
|
|
func swapU32(i uint32) uint32 {
|
|
|
|
if !isLittleEndian() {
|
|
|
|
return i
|
|
|
|
}
|
|
|
|
|
|
|
|
return bits.ReverseBytes32(i)
|
|
|
|
}
|
|
|
|
|
|
|
|
func swapU64(i uint64) uint64 {
|
|
|
|
if !isLittleEndian() {
|
|
|
|
return i
|
|
|
|
}
|
|
|
|
|
|
|
|
return bits.ReverseBytes64(i)
|
|
|
|
}
|
|
|
|
|
|
|
|
// commonBits returns the length, in bits, of the longest common prefix of
// ip1 and ip2, reading both as big-endian bit strings.
//
// len(ip1) selects the width: 4 bytes (IPv4, result 0..32) or 16 bytes
// (IPv6, result 0..128). Any other length of ip1 panics. Only the first
// len(ip1) bytes of ip2 are compared. An ip2 shorter than ip1 panics with an
// index-out-of-range error.
//
// The words are decoded with encoding/binary instead of reinterpreting the
// slices through unsafe.Pointer. That version only bounds-checked ip2[0], so
// a short ip2 read past the end of its backing array. It also performed
// possibly unaligned word loads. Big-endian decoding also removes the need
// for a host byte-order swap.
func commonBits(ip1 net.IP, ip2 net.IP) uint {
	switch len(ip1) {
	case net.IPv4len:
		x := binary.BigEndian.Uint32(ip1) ^ binary.BigEndian.Uint32(ip2)
		return uint(bits.LeadingZeros32(x))
	case net.IPv6len:
		x := binary.BigEndian.Uint64(ip1) ^ binary.BigEndian.Uint64(ip2)
		if x != 0 {
			return uint(bits.LeadingZeros64(x))
		}
		// The high halves are equal; the result is 64 plus the common
		// prefix of the low halves.
		x = binary.BigEndian.Uint64(ip1[8:]) ^ binary.BigEndian.Uint64(ip2[8:])
		return 64 + uint(bits.LeadingZeros64(x))
	default:
		panic("Wrong size bit string")
	}
}
|
|
|
|
|
2021-01-26 23:44:37 +01:00
|
|
|
func (node *trieEntry) addToPeerEntries() {
|
|
|
|
p := node.peer
|
|
|
|
first := p.firstTrieEntry
|
|
|
|
node.nextEntryForPeer = first
|
|
|
|
if first != nil {
|
|
|
|
first.pprevEntryForPeer = &node.nextEntryForPeer
|
|
|
|
}
|
|
|
|
p.firstTrieEntry = node
|
|
|
|
node.pprevEntryForPeer = &p.firstTrieEntry
|
|
|
|
}
|
|
|
|
|
|
|
|
func (node *trieEntry) removeFromPeerEntries() {
|
|
|
|
if node.pprevEntryForPeer == nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
next := node.nextEntryForPeer
|
|
|
|
pprev := node.pprevEntryForPeer
|
|
|
|
*pprev = next
|
|
|
|
if next != nil {
|
|
|
|
next.pprevEntryForPeer = pprev
|
|
|
|
}
|
|
|
|
node.nextEntryForPeer = nil
|
|
|
|
node.pprevEntryForPeer = nil
|
|
|
|
}
|
|
|
|
|
2018-05-13 19:33:41 +02:00
|
|
|
func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
|
2017-05-30 22:36:49 +02:00
|
|
|
if node == nil {
|
|
|
|
return node
|
|
|
|
}
|
|
|
|
|
2017-12-01 23:37:26 +01:00
|
|
|
// walk recursively
|
2017-05-30 22:36:49 +02:00
|
|
|
|
2018-05-13 19:33:41 +02:00
|
|
|
node.child[0] = node.child[0].removeByPeer(p)
|
|
|
|
node.child[1] = node.child[1].removeByPeer(p)
|
2017-05-30 22:36:49 +02:00
|
|
|
|
|
|
|
if node.peer != p {
|
|
|
|
return node
|
|
|
|
}
|
|
|
|
|
2017-07-13 14:32:40 +02:00
|
|
|
// remove peer & merge
|
2017-05-30 22:36:49 +02:00
|
|
|
|
2021-01-26 23:44:37 +01:00
|
|
|
node.removeFromPeerEntries()
|
2017-05-30 22:36:49 +02:00
|
|
|
node.peer = nil
|
|
|
|
if node.child[0] == nil {
|
|
|
|
return node.child[1]
|
|
|
|
}
|
|
|
|
return node.child[0]
|
|
|
|
}
|
|
|
|
|
2018-05-13 19:33:41 +02:00
|
|
|
func (node *trieEntry) choose(ip net.IP) byte {
|
2017-06-04 21:48:15 +02:00
|
|
|
return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
|
2017-06-01 21:31:30 +02:00
|
|
|
}
|
|
|
|
|
2021-01-26 23:44:37 +01:00
|
|
|
func (node *trieEntry) maskSelf() {
|
|
|
|
mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
|
|
|
|
for i := 0; i < len(mask); i++ {
|
|
|
|
node.bits[i] &= mask[i]
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-05-13 19:33:41 +02:00
|
|
|
func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
|
2017-06-01 21:31:30 +02:00
|
|
|
|
2017-07-13 14:32:40 +02:00
|
|
|
// at leaf
|
2017-06-01 21:31:30 +02:00
|
|
|
|
2017-05-30 22:36:49 +02:00
|
|
|
if node == nil {
|
2021-01-26 23:44:37 +01:00
|
|
|
node := &trieEntry{
|
2017-06-04 21:48:15 +02:00
|
|
|
bits: ip,
|
2017-05-30 22:36:49 +02:00
|
|
|
peer: peer,
|
|
|
|
cidr: cidr,
|
|
|
|
bit_at_byte: cidr / 8,
|
|
|
|
bit_at_shift: 7 - (cidr % 8),
|
|
|
|
}
|
2021-01-26 23:44:37 +01:00
|
|
|
node.maskSelf()
|
|
|
|
node.addToPeerEntries()
|
|
|
|
return node
|
2017-05-30 22:36:49 +02:00
|
|
|
}
|
|
|
|
|
2017-07-13 14:32:40 +02:00
|
|
|
// traverse deeper
|
2017-05-30 22:36:49 +02:00
|
|
|
|
2017-06-04 21:48:15 +02:00
|
|
|
common := commonBits(node.bits, ip)
|
2017-05-30 22:36:49 +02:00
|
|
|
if node.cidr <= cidr && common >= node.cidr {
|
|
|
|
if node.cidr == cidr {
|
2021-01-26 23:44:37 +01:00
|
|
|
node.removeFromPeerEntries()
|
2017-05-30 22:36:49 +02:00
|
|
|
node.peer = peer
|
2021-01-26 23:44:37 +01:00
|
|
|
node.addToPeerEntries()
|
2017-05-30 22:36:49 +02:00
|
|
|
return node
|
|
|
|
}
|
2017-06-04 21:48:15 +02:00
|
|
|
bit := node.choose(ip)
|
2018-05-13 19:33:41 +02:00
|
|
|
node.child[bit] = node.child[bit].insert(ip, cidr, peer)
|
2017-05-30 22:36:49 +02:00
|
|
|
return node
|
|
|
|
}
|
|
|
|
|
2017-07-13 14:32:40 +02:00
|
|
|
// split node
|
2017-05-30 22:36:49 +02:00
|
|
|
|
2018-05-13 19:33:41 +02:00
|
|
|
newNode := &trieEntry{
|
2017-06-04 21:48:15 +02:00
|
|
|
bits: ip,
|
2017-05-30 22:36:49 +02:00
|
|
|
peer: peer,
|
|
|
|
cidr: cidr,
|
|
|
|
bit_at_byte: cidr / 8,
|
|
|
|
bit_at_shift: 7 - (cidr % 8),
|
|
|
|
}
|
2021-01-26 23:44:37 +01:00
|
|
|
newNode.maskSelf()
|
|
|
|
newNode.addToPeerEntries()
|
2017-05-30 22:36:49 +02:00
|
|
|
|
|
|
|
cidr = min(cidr, common)
|
|
|
|
|
2017-07-13 14:32:40 +02:00
|
|
|
// check for shorter prefix
|
2017-05-30 22:36:49 +02:00
|
|
|
|
2017-06-01 21:31:30 +02:00
|
|
|
if newNode.cidr == cidr {
|
|
|
|
bit := newNode.choose(node.bits)
|
|
|
|
newNode.child[bit] = node
|
|
|
|
return newNode
|
|
|
|
}
|
|
|
|
|
2017-07-13 14:32:40 +02:00
|
|
|
// create new parent for node & newNode
|
2017-05-30 22:36:49 +02:00
|
|
|
|
2018-05-13 19:33:41 +02:00
|
|
|
parent := &trieEntry{
|
2021-01-26 23:44:37 +01:00
|
|
|
bits: append([]byte{}, ip...),
|
2017-06-01 21:31:30 +02:00
|
|
|
peer: nil,
|
|
|
|
cidr: cidr,
|
|
|
|
bit_at_byte: cidr / 8,
|
|
|
|
bit_at_shift: 7 - (cidr % 8),
|
2017-05-30 22:36:49 +02:00
|
|
|
}
|
2021-01-26 23:44:37 +01:00
|
|
|
parent.maskSelf()
|
2017-05-30 22:36:49 +02:00
|
|
|
|
2017-06-04 21:48:15 +02:00
|
|
|
bit := parent.choose(ip)
|
2017-06-01 21:31:30 +02:00
|
|
|
parent.child[bit] = newNode
|
|
|
|
parent.child[bit^1] = node
|
|
|
|
|
|
|
|
return parent
|
|
|
|
}
|
|
|
|
|
2018-05-13 19:33:41 +02:00
|
|
|
func (node *trieEntry) lookup(ip net.IP) *Peer {
|
2017-06-01 21:31:30 +02:00
|
|
|
var found *Peer
|
2017-06-04 21:48:15 +02:00
|
|
|
size := uint(len(ip))
|
|
|
|
for node != nil && commonBits(node.bits, ip) >= node.cidr {
|
2017-06-01 21:31:30 +02:00
|
|
|
if node.peer != nil {
|
|
|
|
found = node.peer
|
|
|
|
}
|
|
|
|
if node.bit_at_byte == size {
|
|
|
|
break
|
|
|
|
}
|
2017-06-04 21:48:15 +02:00
|
|
|
bit := node.choose(ip)
|
2017-06-01 21:31:30 +02:00
|
|
|
node = node.child[bit]
|
|
|
|
}
|
|
|
|
return found
|
|
|
|
}
|
2017-05-30 22:36:49 +02:00
|
|
|
|
2018-05-13 19:33:41 +02:00
|
|
|
type AllowedIPs struct {
|
|
|
|
IPv4 *trieEntry
|
|
|
|
IPv6 *trieEntry
|
|
|
|
mutex sync.RWMutex
|
|
|
|
}
|
|
|
|
|
2021-01-26 23:44:37 +01:00
|
|
|
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint) bool) {
|
2018-05-13 19:33:41 +02:00
|
|
|
table.mutex.RLock()
|
|
|
|
defer table.mutex.RUnlock()
|
|
|
|
|
2021-01-26 23:44:37 +01:00
|
|
|
for node := peer.firstTrieEntry; node != nil; node = node.nextEntryForPeer {
|
|
|
|
if !cb(node.bits, node.cidr) {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
2018-05-13 19:33:41 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
|
|
|
table.mutex.Lock()
|
|
|
|
defer table.mutex.Unlock()
|
|
|
|
|
|
|
|
table.IPv4 = table.IPv4.removeByPeer(peer)
|
|
|
|
table.IPv6 = table.IPv6.removeByPeer(peer)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
|
|
|
|
table.mutex.Lock()
|
|
|
|
defer table.mutex.Unlock()
|
|
|
|
|
|
|
|
switch len(ip) {
|
|
|
|
case net.IPv6len:
|
|
|
|
table.IPv6 = table.IPv6.insert(ip, cidr, peer)
|
|
|
|
case net.IPv4len:
|
|
|
|
table.IPv4 = table.IPv4.insert(ip, cidr, peer)
|
|
|
|
default:
|
|
|
|
panic(errors.New("inserting unknown address type"))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (table *AllowedIPs) LookupIPv4(address []byte) *Peer {
|
|
|
|
table.mutex.RLock()
|
|
|
|
defer table.mutex.RUnlock()
|
|
|
|
return table.IPv4.lookup(address)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (table *AllowedIPs) LookupIPv6(address []byte) *Peer {
|
|
|
|
table.mutex.RLock()
|
|
|
|
defer table.mutex.RUnlock()
|
|
|
|
return table.IPv6.lookup(address)
|
|
|
|
}
|