device: zero out allowedip node pointers when removing

This should make it a bit easier for the garbage collector.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2021-06-04 16:33:28 +02:00
parent d0cf96114f
commit f9b48a961c
2 changed files with 22 additions and 1 deletions

View File

@ -96,6 +96,14 @@ func (node *trieEntry) maskSelf() {
} }
} }
func (node *trieEntry) zeroizePointers() {
// Make the garbage collector's life slightly easier
node.peer = nil
node.child[0] = nil
node.child[1] = nil
node.parent.parentBit = nil
}
func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) { func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
parent = node parent = node
@ -257,10 +265,12 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
} }
*node.parent.parentBit = child *node.parent.parentBit = child
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
node.zeroizePointers()
continue continue
} }
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
if parent.peer != nil { if parent.peer != nil {
node.zeroizePointers()
continue continue
} }
child = parent.child[node.parent.parentBitType^1] child = parent.child[node.parent.parentBitType^1]
@ -268,6 +278,8 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
child.parent = parent.parent child.parent = parent.parent
} }
*parent.parent.parentBit = child *parent.parent.parentBit = child
node.zeroizePointers()
parent.zeroizePointers()
} }
} }

View File

@ -159,7 +159,16 @@ func TestTrieIPv4(t *testing.T) {
assertNEQ(a, 192, 0, 0, 0) assertNEQ(a, 192, 0, 0, 0)
assertNEQ(a, 255, 0, 0, 0) assertNEQ(a, 255, 0, 0, 0)
allowedIPs = AllowedIPs{} allowedIPs.RemoveByPeer(a)
allowedIPs.RemoveByPeer(b)
allowedIPs.RemoveByPeer(c)
allowedIPs.RemoveByPeer(d)
allowedIPs.RemoveByPeer(e)
allowedIPs.RemoveByPeer(g)
allowedIPs.RemoveByPeer(h)
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
t.Error("Expected removing all the peers to empty trie, but it did not")
}
insert(a, 192, 168, 0, 0, 16) insert(a, 192, 168, 0, 0, 16)
insert(a, 192, 168, 0, 0, 24) insert(a, 192, 168, 0, 0, 24)