18b6627f33
Simplification found by staticcheck: $ staticcheck ./... | grep S1012 device/cookie.go:90:5: should use time.Since instead of time.Now().Sub (S1012) device/cookie.go:127:5: should use time.Since instead of time.Now().Sub (S1012) device/cookie.go:242:5: should use time.Since instead of time.Now().Sub (S1012) device/noise-protocol.go:304:13: should use time.Since instead of time.Now().Sub (S1012) device/receive.go:82:46: should use time.Since instead of time.Now().Sub (S1012) device/send.go:132:5: should use time.Since instead of time.Now().Sub (S1012) device/send.go:139:5: should use time.Since instead of time.Now().Sub (S1012) device/send.go:235:59: should use time.Since instead of time.Now().Sub (S1012) device/send.go:393:9: should use time.Since instead of time.Now().Sub (S1012) ratelimiter/ratelimiter.go:79:10: should use time.Since instead of time.Now().Sub (S1012) ratelimiter/ratelimiter.go:87:10: should use time.Since instead of time.Now().Sub (S1012) Change applied using: $ find . -type f -name "*.go" -exec sed -i "s/Now().Sub(/Since(/g" {} \; Signed-off-by: Matt Layher <mdlayher@gmail.com>
166 lines
3.1 KiB
Go
166 lines
3.1 KiB
Go
/* SPDX-License-Identifier: MIT
|
|
*
|
|
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
*/
|
|
|
|
package ratelimiter
|
|
|
|
import (
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Tunables of the per-address token bucket.
const (
	// packetsPerSecond is the sustained packet rate granted to one address.
	packetsPerSecond = 20
	// packetsBurstable is the number of packet costs a full bucket holds.
	packetsBurstable = 5
)

// garbageCollectTime is how long an entry may sit untouched before the
// cleanup routine removes it from its table.
const garbageCollectTime = time.Second

// Tokens are measured in nanoseconds of elapsed time (Allow refills a bucket
// with the nanoseconds since its last use), so the price of a packet is the
// share of one second that a single packet is entitled to.
const (
	// packetCost is the number of tokens one accepted packet consumes.
	packetCost = 1000 * 1000 * 1000 / packetsPerSecond
	// maxTokens is the cap on the tokens a bucket can accumulate.
	maxTokens = packetsBurstable * packetCost
)
|
|
|
|
// RatelimiterEntry is the token bucket kept for a single source address.
//
// The embedded mutex guards both fields; Allow and the garbage collection
// routine take it before reading or writing them.
type RatelimiterEntry struct {
	sync.Mutex

	// lastTime is the moment the bucket was last created or refilled.
	lastTime time.Time

	// tokens is the current balance, in nanoseconds of earned time.
	tokens int64
}
|
|
|
|
type Ratelimiter struct {
|
|
sync.RWMutex
|
|
stopReset chan struct{}
|
|
tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
|
|
tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
|
|
}
|
|
|
|
func (rate *Ratelimiter) Close() {
|
|
rate.Lock()
|
|
defer rate.Unlock()
|
|
|
|
if rate.stopReset != nil {
|
|
close(rate.stopReset)
|
|
}
|
|
}
|
|
|
|
func (rate *Ratelimiter) Init() {
|
|
rate.Lock()
|
|
defer rate.Unlock()
|
|
|
|
// stop any ongoing garbage collection routine
|
|
|
|
if rate.stopReset != nil {
|
|
close(rate.stopReset)
|
|
}
|
|
|
|
rate.stopReset = make(chan struct{})
|
|
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
|
|
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
|
|
|
|
// start garbage collection routine
|
|
|
|
go func() {
|
|
ticker := time.NewTicker(time.Second)
|
|
ticker.Stop()
|
|
for {
|
|
select {
|
|
case _, ok := <-rate.stopReset:
|
|
ticker.Stop()
|
|
if ok {
|
|
ticker = time.NewTicker(time.Second)
|
|
} else {
|
|
return
|
|
}
|
|
case <-ticker.C:
|
|
func() {
|
|
rate.Lock()
|
|
defer rate.Unlock()
|
|
|
|
for key, entry := range rate.tableIPv4 {
|
|
entry.Lock()
|
|
if time.Since(entry.lastTime) > garbageCollectTime {
|
|
delete(rate.tableIPv4, key)
|
|
}
|
|
entry.Unlock()
|
|
}
|
|
|
|
for key, entry := range rate.tableIPv6 {
|
|
entry.Lock()
|
|
if time.Since(entry.lastTime) > garbageCollectTime {
|
|
delete(rate.tableIPv6, key)
|
|
}
|
|
entry.Unlock()
|
|
}
|
|
|
|
if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 {
|
|
ticker.Stop()
|
|
}
|
|
}()
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (rate *Ratelimiter) Allow(ip net.IP) bool {
|
|
var entry *RatelimiterEntry
|
|
var keyIPv4 [net.IPv4len]byte
|
|
var keyIPv6 [net.IPv6len]byte
|
|
|
|
// lookup entry
|
|
|
|
IPv4 := ip.To4()
|
|
IPv6 := ip.To16()
|
|
|
|
rate.RLock()
|
|
|
|
if IPv4 != nil {
|
|
copy(keyIPv4[:], IPv4)
|
|
entry = rate.tableIPv4[keyIPv4]
|
|
} else {
|
|
copy(keyIPv6[:], IPv6)
|
|
entry = rate.tableIPv6[keyIPv6]
|
|
}
|
|
|
|
rate.RUnlock()
|
|
|
|
// make new entry if not found
|
|
|
|
if entry == nil {
|
|
entry = new(RatelimiterEntry)
|
|
entry.tokens = maxTokens - packetCost
|
|
entry.lastTime = time.Now()
|
|
rate.Lock()
|
|
if IPv4 != nil {
|
|
rate.tableIPv4[keyIPv4] = entry
|
|
if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
|
|
rate.stopReset <- struct{}{}
|
|
}
|
|
} else {
|
|
rate.tableIPv6[keyIPv6] = entry
|
|
if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
|
|
rate.stopReset <- struct{}{}
|
|
}
|
|
}
|
|
rate.Unlock()
|
|
return true
|
|
}
|
|
|
|
// add tokens to entry
|
|
|
|
entry.Lock()
|
|
now := time.Now()
|
|
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
|
|
entry.lastTime = now
|
|
if entry.tokens > maxTokens {
|
|
entry.tokens = maxTokens
|
|
}
|
|
|
|
// subtract cost of packet
|
|
|
|
if entry.tokens > packetCost {
|
|
entry.tokens -= packetCost
|
|
entry.Unlock()
|
|
return true
|
|
}
|
|
entry.Unlock()
|
|
return false
|
|
}
|