all: use Go 1.19 and its atomic types
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com> Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
parent
d1d08426b2
commit
b51010ba13
@ -74,7 +74,7 @@ type afWinRingBind struct {
|
|||||||
type WinRingBind struct {
|
type WinRingBind struct {
|
||||||
v4, v6 afWinRingBind
|
v4, v6 afWinRingBind
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
isOpen uint32
|
isOpen atomic.Uint32 // 0, 1, or 2
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDefaultBind() Bind { return NewWinRingBind() }
|
func NewDefaultBind() Bind { return NewWinRingBind() }
|
||||||
@ -212,7 +212,7 @@ func (bind *afWinRingBind) CloseAndZero() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (bind *WinRingBind) closeAndZero() {
|
func (bind *WinRingBind) closeAndZero() {
|
||||||
atomic.StoreUint32(&bind.isOpen, 0)
|
bind.isOpen.Store(0)
|
||||||
bind.v4.CloseAndZero()
|
bind.v4.CloseAndZero()
|
||||||
bind.v6.CloseAndZero()
|
bind.v6.CloseAndZero()
|
||||||
}
|
}
|
||||||
@ -276,7 +276,7 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort
|
|||||||
bind.closeAndZero()
|
bind.closeAndZero()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
if atomic.LoadUint32(&bind.isOpen) != 0 {
|
if bind.isOpen.Load() != 0 {
|
||||||
return nil, 0, ErrBindAlreadyOpen
|
return nil, 0, ErrBindAlreadyOpen
|
||||||
}
|
}
|
||||||
var sa windows.Sockaddr
|
var sa windows.Sockaddr
|
||||||
@ -299,17 +299,17 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort
|
|||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
atomic.StoreUint32(&bind.isOpen, 1)
|
bind.isOpen.Store(1)
|
||||||
return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
|
return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *WinRingBind) Close() error {
|
func (bind *WinRingBind) Close() error {
|
||||||
bind.mu.RLock()
|
bind.mu.RLock()
|
||||||
if atomic.LoadUint32(&bind.isOpen) != 1 {
|
if bind.isOpen.Load() != 1 {
|
||||||
bind.mu.RUnlock()
|
bind.mu.RUnlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
atomic.StoreUint32(&bind.isOpen, 2)
|
bind.isOpen.Store(2)
|
||||||
windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
|
windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
|
||||||
windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
|
windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
|
||||||
windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
|
windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
|
||||||
@ -345,8 +345,8 @@ func (bind *afWinRingBind) InsertReceiveRequest() error {
|
|||||||
//go:linkname procyield runtime.procyield
|
//go:linkname procyield runtime.procyield
|
||||||
func procyield(cycles uint32)
|
func procyield(cycles uint32)
|
||||||
|
|
||||||
func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) {
|
func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
|
||||||
if atomic.LoadUint32(isOpen) != 1 {
|
if isOpen.Load() != 1 {
|
||||||
return 0, nil, net.ErrClosed
|
return 0, nil, net.ErrClosed
|
||||||
}
|
}
|
||||||
bind.rx.mu.Lock()
|
bind.rx.mu.Lock()
|
||||||
@ -359,7 +359,7 @@ retry:
|
|||||||
count = 0
|
count = 0
|
||||||
for tries := 0; count == 0 && tries < receiveSpins; tries++ {
|
for tries := 0; count == 0 && tries < receiveSpins; tries++ {
|
||||||
if tries > 0 {
|
if tries > 0 {
|
||||||
if atomic.LoadUint32(isOpen) != 1 {
|
if isOpen.Load() != 1 {
|
||||||
return 0, nil, net.ErrClosed
|
return 0, nil, net.ErrClosed
|
||||||
}
|
}
|
||||||
procyield(1)
|
procyield(1)
|
||||||
@ -378,7 +378,7 @@ retry:
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, nil, err
|
return 0, nil, err
|
||||||
}
|
}
|
||||||
if atomic.LoadUint32(isOpen) != 1 {
|
if isOpen.Load() != 1 {
|
||||||
return 0, nil, net.ErrClosed
|
return 0, nil, net.ErrClosed
|
||||||
}
|
}
|
||||||
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
|
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
|
||||||
@ -395,7 +395,7 @@ retry:
|
|||||||
// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
|
// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
|
||||||
// attacker bandwidth, just like the rest of the receive path.
|
// attacker bandwidth, just like the rest of the receive path.
|
||||||
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
|
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
|
||||||
if atomic.LoadUint32(isOpen) != 1 {
|
if isOpen.Load() != 1 {
|
||||||
return 0, nil, net.ErrClosed
|
return 0, nil, net.ErrClosed
|
||||||
}
|
}
|
||||||
goto retry
|
goto retry
|
||||||
@ -421,8 +421,8 @@ func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
|
|||||||
return bind.v6.Receive(buf, &bind.isOpen)
|
return bind.v6.Receive(buf, &bind.isOpen)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error {
|
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
|
||||||
if atomic.LoadUint32(isOpen) != 1 {
|
if isOpen.Load() != 1 {
|
||||||
return net.ErrClosed
|
return net.ErrClosed
|
||||||
}
|
}
|
||||||
if len(buf) > bytesPerPacket {
|
if len(buf) > bytesPerPacket {
|
||||||
@ -444,7 +444,7 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint3
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if atomic.LoadUint32(isOpen) != 1 {
|
if isOpen.Load() != 1 {
|
||||||
return net.ErrClosed
|
return net.ErrClosed
|
||||||
}
|
}
|
||||||
count = winrio.DequeueCompletion(bind.tx.cq, results[:])
|
count = winrio.DequeueCompletion(bind.tx.cq, results[:])
|
||||||
@ -538,7 +538,7 @@ func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole
|
|||||||
func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
||||||
bind.mu.RLock()
|
bind.mu.RLock()
|
||||||
defer bind.mu.RUnlock()
|
defer bind.mu.RUnlock()
|
||||||
if atomic.LoadUint32(&bind.isOpen) != 1 {
|
if bind.isOpen.Load() != 1 {
|
||||||
return net.ErrClosed
|
return net.ErrClosed
|
||||||
}
|
}
|
||||||
err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
|
err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
|
||||||
@ -552,7 +552,7 @@ func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
|
|||||||
func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
||||||
bind.mu.RLock()
|
bind.mu.RLock()
|
||||||
defer bind.mu.RUnlock()
|
defer bind.mu.RUnlock()
|
||||||
if atomic.LoadUint32(&bind.isOpen) != 1 {
|
if bind.isOpen.Load() != 1 {
|
||||||
return net.ErrClosed
|
return net.ErrClosed
|
||||||
}
|
}
|
||||||
err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
|
err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
|
||||||
|
@ -1,65 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
"unsafe"
|
|
||||||
)
|
|
||||||
|
|
||||||
func checkAlignment(t *testing.T, name string, offset uintptr) {
|
|
||||||
t.Helper()
|
|
||||||
if offset%8 != 0 {
|
|
||||||
t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPeerAlignment checks that atomically-accessed fields are
|
|
||||||
// aligned to 64-bit boundaries, as required by the atomic package.
|
|
||||||
//
|
|
||||||
// Unfortunately, violating this rule on 32-bit platforms results in a
|
|
||||||
// hard segfault at runtime.
|
|
||||||
func TestPeerAlignment(t *testing.T) {
|
|
||||||
var p Peer
|
|
||||||
|
|
||||||
typ := reflect.TypeOf(&p).Elem()
|
|
||||||
t.Logf("Peer type size: %d, with fields:", typ.Size())
|
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
|
||||||
field := typ.Field(i)
|
|
||||||
t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
|
|
||||||
field.Name,
|
|
||||||
field.Offset,
|
|
||||||
field.Type.Size(),
|
|
||||||
field.Type.Align(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats))
|
|
||||||
checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning))
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestDeviceAlignment checks that atomically-accessed fields are
|
|
||||||
// aligned to 64-bit boundaries, as required by the atomic package.
|
|
||||||
//
|
|
||||||
// Unfortunately, violating this rule on 32-bit platforms results in a
|
|
||||||
// hard segfault at runtime.
|
|
||||||
func TestDeviceAlignment(t *testing.T) {
|
|
||||||
var d Device
|
|
||||||
|
|
||||||
typ := reflect.TypeOf(&d).Elem()
|
|
||||||
t.Logf("Device type size: %d, with fields:", typ.Size())
|
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
|
||||||
field := typ.Field(i)
|
|
||||||
t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
|
|
||||||
field.Name,
|
|
||||||
field.Offset,
|
|
||||||
field.Type.Size(),
|
|
||||||
field.Type.Align(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
checkAlignment(t, "Device.rate.underLoadUntil", unsafe.Offsetof(d.rate)+unsafe.Offsetof(d.rate.underLoadUntil))
|
|
||||||
}
|
|
@ -30,7 +30,7 @@ type Device struct {
|
|||||||
// will become the actual state; Up can fail.
|
// will become the actual state; Up can fail.
|
||||||
// The device can also change state multiple times between time of check and time of use.
|
// The device can also change state multiple times between time of check and time of use.
|
||||||
// Unsynchronized uses of state must therefore be advisory/best-effort only.
|
// Unsynchronized uses of state must therefore be advisory/best-effort only.
|
||||||
state uint32 // actually a deviceState, but typed uint32 for convenience
|
state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
|
||||||
// stopping blocks until all inputs to Device have been closed.
|
// stopping blocks until all inputs to Device have been closed.
|
||||||
stopping sync.WaitGroup
|
stopping sync.WaitGroup
|
||||||
// mu protects state changes.
|
// mu protects state changes.
|
||||||
@ -58,9 +58,8 @@ type Device struct {
|
|||||||
keyMap map[NoisePublicKey]*Peer
|
keyMap map[NoisePublicKey]*Peer
|
||||||
}
|
}
|
||||||
|
|
||||||
// Keep this 8-byte aligned
|
|
||||||
rate struct {
|
rate struct {
|
||||||
underLoadUntil int64
|
underLoadUntil atomic.Int64
|
||||||
limiter ratelimiter.Ratelimiter
|
limiter ratelimiter.Ratelimiter
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -82,7 +81,7 @@ type Device struct {
|
|||||||
|
|
||||||
tun struct {
|
tun struct {
|
||||||
device tun.Device
|
device tun.Device
|
||||||
mtu int32
|
mtu atomic.Int32
|
||||||
}
|
}
|
||||||
|
|
||||||
ipcMutex sync.RWMutex
|
ipcMutex sync.RWMutex
|
||||||
@ -94,10 +93,9 @@ type Device struct {
|
|||||||
// There are three states: down, up, closed.
|
// There are three states: down, up, closed.
|
||||||
// Transitions:
|
// Transitions:
|
||||||
//
|
//
|
||||||
// down -----+
|
// down -----+
|
||||||
// ↑↓ ↓
|
// ↑↓ ↓
|
||||||
// up -> closed
|
// up -> closed
|
||||||
//
|
|
||||||
type deviceState uint32
|
type deviceState uint32
|
||||||
|
|
||||||
//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
|
//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
|
||||||
@ -110,7 +108,7 @@ const (
|
|||||||
// deviceState returns device.state.state as a deviceState
|
// deviceState returns device.state.state as a deviceState
|
||||||
// See those docs for how to interpret this value.
|
// See those docs for how to interpret this value.
|
||||||
func (device *Device) deviceState() deviceState {
|
func (device *Device) deviceState() deviceState {
|
||||||
return deviceState(atomic.LoadUint32(&device.state.state))
|
return deviceState(device.state.state.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// isClosed reports whether the device is closed (or is closing).
|
// isClosed reports whether the device is closed (or is closing).
|
||||||
@ -149,14 +147,14 @@ func (device *Device) changeState(want deviceState) (err error) {
|
|||||||
case old:
|
case old:
|
||||||
return nil
|
return nil
|
||||||
case deviceStateUp:
|
case deviceStateUp:
|
||||||
atomic.StoreUint32(&device.state.state, uint32(deviceStateUp))
|
device.state.state.Store(uint32(deviceStateUp))
|
||||||
err = device.upLocked()
|
err = device.upLocked()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
fallthrough // up failed; bring the device all the way back down
|
fallthrough // up failed; bring the device all the way back down
|
||||||
case deviceStateDown:
|
case deviceStateDown:
|
||||||
atomic.StoreUint32(&device.state.state, uint32(deviceStateDown))
|
device.state.state.Store(uint32(deviceStateDown))
|
||||||
errDown := device.downLocked()
|
errDown := device.downLocked()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = errDown
|
err = errDown
|
||||||
@ -182,7 +180,7 @@ func (device *Device) upLocked() error {
|
|||||||
device.peers.RLock()
|
device.peers.RLock()
|
||||||
for _, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
peer.Start()
|
peer.Start()
|
||||||
if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
|
if peer.persistentKeepaliveInterval.Load() > 0 {
|
||||||
peer.SendKeepalive()
|
peer.SendKeepalive()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -219,11 +217,11 @@ func (device *Device) IsUnderLoad() bool {
|
|||||||
now := time.Now()
|
now := time.Now()
|
||||||
underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
|
underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
|
||||||
if underLoad {
|
if underLoad {
|
||||||
atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano())
|
device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
// check if recently under load
|
// check if recently under load
|
||||||
return atomic.LoadInt64(&device.rate.underLoadUntil) > now.UnixNano()
|
return device.rate.underLoadUntil.Load() > now.UnixNano()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
||||||
@ -283,7 +281,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
|||||||
|
|
||||||
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
|
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
|
||||||
device := new(Device)
|
device := new(Device)
|
||||||
device.state.state = uint32(deviceStateDown)
|
device.state.state.Store(uint32(deviceStateDown))
|
||||||
device.closed = make(chan struct{})
|
device.closed = make(chan struct{})
|
||||||
device.log = logger
|
device.log = logger
|
||||||
device.net.bind = bind
|
device.net.bind = bind
|
||||||
@ -293,7 +291,7 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
|
|||||||
device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
|
device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
|
||||||
mtu = DefaultMTU
|
mtu = DefaultMTU
|
||||||
}
|
}
|
||||||
device.tun.mtu = int32(mtu)
|
device.tun.mtu.Store(int32(mtu))
|
||||||
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
||||||
device.rate.limiter.Init()
|
device.rate.limiter.Init()
|
||||||
device.indexTable.Init()
|
device.indexTable.Init()
|
||||||
@ -359,7 +357,7 @@ func (device *Device) Close() {
|
|||||||
if device.isClosed() {
|
if device.isClosed() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed))
|
device.state.state.Store(uint32(deviceStateClosed))
|
||||||
device.log.Verbosef("Device closing")
|
device.log.Verbosef("Device closing")
|
||||||
|
|
||||||
device.tun.device.Close()
|
device.tun.device.Close()
|
||||||
|
@ -333,7 +333,7 @@ func BenchmarkThroughput(b *testing.B) {
|
|||||||
|
|
||||||
// Measure how long it takes to receive b.N packets,
|
// Measure how long it takes to receive b.N packets,
|
||||||
// starting when we receive the first packet.
|
// starting when we receive the first packet.
|
||||||
var recv uint64
|
var recv atomic.Uint64
|
||||||
var elapsed time.Duration
|
var elapsed time.Duration
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
@ -342,7 +342,7 @@ func BenchmarkThroughput(b *testing.B) {
|
|||||||
var start time.Time
|
var start time.Time
|
||||||
for {
|
for {
|
||||||
<-pair[0].tun.Inbound
|
<-pair[0].tun.Inbound
|
||||||
new := atomic.AddUint64(&recv, 1)
|
new := recv.Add(1)
|
||||||
if new == 1 {
|
if new == 1 {
|
||||||
start = time.Now()
|
start = time.Now()
|
||||||
}
|
}
|
||||||
@ -358,7 +358,7 @@ func BenchmarkThroughput(b *testing.B) {
|
|||||||
ping := tuntest.Ping(pair[0].ip, pair[1].ip)
|
ping := tuntest.Ping(pair[0].ip, pair[1].ip)
|
||||||
pingc := pair[1].tun.Outbound
|
pingc := pair[1].tun.Outbound
|
||||||
var sent uint64
|
var sent uint64
|
||||||
for atomic.LoadUint64(&recv) != uint64(b.N) {
|
for recv.Load() != uint64(b.N) {
|
||||||
sent++
|
sent++
|
||||||
pingc <- ping
|
pingc <- ping
|
||||||
}
|
}
|
||||||
|
@ -10,7 +10,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/replay"
|
"golang.zx2c4.com/wireguard/replay"
|
||||||
)
|
)
|
||||||
@ -23,7 +22,7 @@ import (
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
type Keypair struct {
|
type Keypair struct {
|
||||||
sendNonce uint64 // accessed atomically
|
sendNonce atomic.Uint64
|
||||||
send cipher.AEAD
|
send cipher.AEAD
|
||||||
receive cipher.AEAD
|
receive cipher.AEAD
|
||||||
replayFilter replay.Filter
|
replayFilter replay.Filter
|
||||||
@ -37,15 +36,7 @@ type Keypairs struct {
|
|||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
current *Keypair
|
current *Keypair
|
||||||
previous *Keypair
|
previous *Keypair
|
||||||
next *Keypair
|
next atomic.Pointer[Keypair]
|
||||||
}
|
|
||||||
|
|
||||||
func (kp *Keypairs) storeNext(next *Keypair) {
|
|
||||||
atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (kp *Keypairs) loadNext() *Keypair {
|
|
||||||
return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kp *Keypairs) Current() *Keypair {
|
func (kp *Keypairs) Current() *Keypair {
|
||||||
|
@ -1,41 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync/atomic"
|
|
||||||
)
|
|
||||||
|
|
||||||
/* Atomic Boolean */
|
|
||||||
|
|
||||||
const (
|
|
||||||
AtomicFalse = int32(iota)
|
|
||||||
AtomicTrue
|
|
||||||
)
|
|
||||||
|
|
||||||
type AtomicBool struct {
|
|
||||||
int32
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AtomicBool) Get() bool {
|
|
||||||
return atomic.LoadInt32(&a.int32) == AtomicTrue
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AtomicBool) Swap(val bool) bool {
|
|
||||||
flag := AtomicFalse
|
|
||||||
if val {
|
|
||||||
flag = AtomicTrue
|
|
||||||
}
|
|
||||||
return atomic.SwapInt32(&a.int32, flag) == AtomicTrue
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AtomicBool) Set(val bool) {
|
|
||||||
flag := AtomicFalse
|
|
||||||
if val {
|
|
||||||
flag = AtomicTrue
|
|
||||||
}
|
|
||||||
atomic.StoreInt32(&a.int32, flag)
|
|
||||||
}
|
|
@ -282,7 +282,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
|||||||
// lookup peer
|
// lookup peer
|
||||||
|
|
||||||
peer := device.LookupPeer(peerPK)
|
peer := device.LookupPeer(peerPK)
|
||||||
if peer == nil || !peer.isRunning.Get() {
|
if peer == nil || !peer.isRunning.Load() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -581,12 +581,12 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||||||
defer keypairs.Unlock()
|
defer keypairs.Unlock()
|
||||||
|
|
||||||
previous := keypairs.previous
|
previous := keypairs.previous
|
||||||
next := keypairs.loadNext()
|
next := keypairs.next.Load()
|
||||||
current := keypairs.current
|
current := keypairs.current
|
||||||
|
|
||||||
if isInitiator {
|
if isInitiator {
|
||||||
if next != nil {
|
if next != nil {
|
||||||
keypairs.storeNext(nil)
|
keypairs.next.Store(nil)
|
||||||
keypairs.previous = next
|
keypairs.previous = next
|
||||||
device.DeleteKeypair(current)
|
device.DeleteKeypair(current)
|
||||||
} else {
|
} else {
|
||||||
@ -595,7 +595,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||||||
device.DeleteKeypair(previous)
|
device.DeleteKeypair(previous)
|
||||||
keypairs.current = keypair
|
keypairs.current = keypair
|
||||||
} else {
|
} else {
|
||||||
keypairs.storeNext(keypair)
|
keypairs.next.Store(keypair)
|
||||||
device.DeleteKeypair(next)
|
device.DeleteKeypair(next)
|
||||||
keypairs.previous = nil
|
keypairs.previous = nil
|
||||||
device.DeleteKeypair(previous)
|
device.DeleteKeypair(previous)
|
||||||
@ -607,18 +607,18 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||||||
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
|
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
|
||||||
keypairs := &peer.keypairs
|
keypairs := &peer.keypairs
|
||||||
|
|
||||||
if keypairs.loadNext() != receivedKeypair {
|
if keypairs.next.Load() != receivedKeypair {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
keypairs.Lock()
|
keypairs.Lock()
|
||||||
defer keypairs.Unlock()
|
defer keypairs.Unlock()
|
||||||
if keypairs.loadNext() != receivedKeypair {
|
if keypairs.next.Load() != receivedKeypair {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
old := keypairs.previous
|
old := keypairs.previous
|
||||||
keypairs.previous = keypairs.current
|
keypairs.previous = keypairs.current
|
||||||
peer.device.DeleteKeypair(old)
|
peer.device.DeleteKeypair(old)
|
||||||
keypairs.current = keypairs.loadNext()
|
keypairs.current = keypairs.next.Load()
|
||||||
keypairs.storeNext(nil)
|
keypairs.next.Store(nil)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -148,7 +148,7 @@ func TestNoiseHandshake(t *testing.T) {
|
|||||||
t.Fatal("failed to derive keypair for peer 2", err)
|
t.Fatal("failed to derive keypair for peer 2", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
key1 := peer1.keypairs.loadNext()
|
key1 := peer1.keypairs.next.Load()
|
||||||
key2 := peer2.keypairs.current
|
key2 := peer2.keypairs.current
|
||||||
|
|
||||||
// encrypting / decryption test
|
// encrypting / decryption test
|
||||||
|
@ -16,24 +16,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
isRunning AtomicBool
|
isRunning atomic.Bool
|
||||||
sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
|
sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
|
||||||
keypairs Keypairs
|
keypairs Keypairs
|
||||||
handshake Handshake
|
handshake Handshake
|
||||||
device *Device
|
device *Device
|
||||||
endpoint conn.Endpoint
|
endpoint conn.Endpoint
|
||||||
stopping sync.WaitGroup // routines pending stop
|
stopping sync.WaitGroup // routines pending stop
|
||||||
|
txBytes atomic.Uint64 // bytes send to peer (endpoint)
|
||||||
// These fields are accessed with atomic operations, which must be
|
rxBytes atomic.Uint64 // bytes received from peer
|
||||||
// 64-bit aligned even on 32-bit platforms. Go guarantees that an
|
lastHandshakeNano atomic.Int64 // nano seconds since epoch
|
||||||
// allocated struct will be 64-bit aligned. So we place
|
|
||||||
// atomically-accessed fields up front, so that they can share in
|
|
||||||
// this alignment before smaller fields throw it off.
|
|
||||||
stats struct {
|
|
||||||
txBytes uint64 // bytes send to peer (endpoint)
|
|
||||||
rxBytes uint64 // bytes received from peer
|
|
||||||
lastHandshakeNano int64 // nano seconds since epoch
|
|
||||||
}
|
|
||||||
|
|
||||||
disableRoaming bool
|
disableRoaming bool
|
||||||
|
|
||||||
@ -43,9 +35,9 @@ type Peer struct {
|
|||||||
newHandshake *Timer
|
newHandshake *Timer
|
||||||
zeroKeyMaterial *Timer
|
zeroKeyMaterial *Timer
|
||||||
persistentKeepalive *Timer
|
persistentKeepalive *Timer
|
||||||
handshakeAttempts uint32
|
handshakeAttempts atomic.Uint32
|
||||||
needAnotherKeepalive AtomicBool
|
needAnotherKeepalive atomic.Bool
|
||||||
sentLastMinuteHandshake AtomicBool
|
sentLastMinuteHandshake atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
state struct {
|
state struct {
|
||||||
@ -60,7 +52,7 @@ type Peer struct {
|
|||||||
|
|
||||||
cookieGenerator CookieGenerator
|
cookieGenerator CookieGenerator
|
||||||
trieEntries list.List
|
trieEntries list.List
|
||||||
persistentKeepaliveInterval uint32 // accessed atomically
|
persistentKeepaliveInterval atomic.Uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||||
@ -133,7 +125,7 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
|
|||||||
|
|
||||||
err := peer.device.net.bind.Send(buffer, peer.endpoint)
|
err := peer.device.net.bind.Send(buffer, peer.endpoint)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer)))
|
peer.txBytes.Add(uint64(len(buffer)))
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -174,7 +166,7 @@ func (peer *Peer) Start() {
|
|||||||
peer.state.Lock()
|
peer.state.Lock()
|
||||||
defer peer.state.Unlock()
|
defer peer.state.Unlock()
|
||||||
|
|
||||||
if peer.isRunning.Get() {
|
if peer.isRunning.Load() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -198,7 +190,7 @@ func (peer *Peer) Start() {
|
|||||||
go peer.RoutineSequentialSender()
|
go peer.RoutineSequentialSender()
|
||||||
go peer.RoutineSequentialReceiver()
|
go peer.RoutineSequentialReceiver()
|
||||||
|
|
||||||
peer.isRunning.Set(true)
|
peer.isRunning.Store(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) ZeroAndFlushAll() {
|
func (peer *Peer) ZeroAndFlushAll() {
|
||||||
@ -210,10 +202,10 @@ func (peer *Peer) ZeroAndFlushAll() {
|
|||||||
keypairs.Lock()
|
keypairs.Lock()
|
||||||
device.DeleteKeypair(keypairs.previous)
|
device.DeleteKeypair(keypairs.previous)
|
||||||
device.DeleteKeypair(keypairs.current)
|
device.DeleteKeypair(keypairs.current)
|
||||||
device.DeleteKeypair(keypairs.loadNext())
|
device.DeleteKeypair(keypairs.next.Load())
|
||||||
keypairs.previous = nil
|
keypairs.previous = nil
|
||||||
keypairs.current = nil
|
keypairs.current = nil
|
||||||
keypairs.storeNext(nil)
|
keypairs.next.Store(nil)
|
||||||
keypairs.Unlock()
|
keypairs.Unlock()
|
||||||
|
|
||||||
// clear handshake state
|
// clear handshake state
|
||||||
@ -238,11 +230,10 @@ func (peer *Peer) ExpireCurrentKeypairs() {
|
|||||||
keypairs := &peer.keypairs
|
keypairs := &peer.keypairs
|
||||||
keypairs.Lock()
|
keypairs.Lock()
|
||||||
if keypairs.current != nil {
|
if keypairs.current != nil {
|
||||||
atomic.StoreUint64(&keypairs.current.sendNonce, RejectAfterMessages)
|
keypairs.current.sendNonce.Store(RejectAfterMessages)
|
||||||
}
|
}
|
||||||
if keypairs.next != nil {
|
if next := keypairs.next.Load(); next != nil {
|
||||||
next := keypairs.loadNext()
|
next.sendNonce.Store(RejectAfterMessages)
|
||||||
atomic.StoreUint64(&next.sendNonce, RejectAfterMessages)
|
|
||||||
}
|
}
|
||||||
keypairs.Unlock()
|
keypairs.Unlock()
|
||||||
}
|
}
|
||||||
|
@ -14,7 +14,7 @@ type WaitPool struct {
|
|||||||
pool sync.Pool
|
pool sync.Pool
|
||||||
cond sync.Cond
|
cond sync.Cond
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
count uint32
|
count atomic.Uint32
|
||||||
max uint32
|
max uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -27,10 +27,10 @@ func NewWaitPool(max uint32, new func() any) *WaitPool {
|
|||||||
func (p *WaitPool) Get() any {
|
func (p *WaitPool) Get() any {
|
||||||
if p.max != 0 {
|
if p.max != 0 {
|
||||||
p.lock.Lock()
|
p.lock.Lock()
|
||||||
for atomic.LoadUint32(&p.count) >= p.max {
|
for p.count.Load() >= p.max {
|
||||||
p.cond.Wait()
|
p.cond.Wait()
|
||||||
}
|
}
|
||||||
atomic.AddUint32(&p.count, 1)
|
p.count.Add(1)
|
||||||
p.lock.Unlock()
|
p.lock.Unlock()
|
||||||
}
|
}
|
||||||
return p.pool.Get()
|
return p.pool.Get()
|
||||||
@ -41,7 +41,7 @@ func (p *WaitPool) Put(x any) {
|
|||||||
if p.max == 0 {
|
if p.max == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
atomic.AddUint32(&p.count, ^uint32(0))
|
p.count.Add(^uint32(0))
|
||||||
p.cond.Signal()
|
p.cond.Signal()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -17,29 +17,31 @@ import (
|
|||||||
func TestWaitPool(t *testing.T) {
|
func TestWaitPool(t *testing.T) {
|
||||||
t.Skip("Currently disabled")
|
t.Skip("Currently disabled")
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
trials := int32(100000)
|
var trials atomic.Int32
|
||||||
|
startTrials := int32(100000)
|
||||||
if raceEnabled {
|
if raceEnabled {
|
||||||
// This test can be very slow with -race.
|
// This test can be very slow with -race.
|
||||||
trials /= 10
|
startTrials /= 10
|
||||||
}
|
}
|
||||||
|
trials.Store(startTrials)
|
||||||
workers := runtime.NumCPU() + 2
|
workers := runtime.NumCPU() + 2
|
||||||
if workers-4 <= 0 {
|
if workers-4 <= 0 {
|
||||||
t.Skip("Not enough cores")
|
t.Skip("Not enough cores")
|
||||||
}
|
}
|
||||||
p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
|
p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
|
||||||
wg.Add(workers)
|
wg.Add(workers)
|
||||||
max := uint32(0)
|
var max atomic.Uint32
|
||||||
updateMax := func() {
|
updateMax := func() {
|
||||||
count := atomic.LoadUint32(&p.count)
|
count := p.count.Load()
|
||||||
if count > p.max {
|
if count > p.max {
|
||||||
t.Errorf("count (%d) > max (%d)", count, p.max)
|
t.Errorf("count (%d) > max (%d)", count, p.max)
|
||||||
}
|
}
|
||||||
for {
|
for {
|
||||||
old := atomic.LoadUint32(&max)
|
old := max.Load()
|
||||||
if count <= old {
|
if count <= old {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if atomic.CompareAndSwapUint32(&max, old, count) {
|
if max.CompareAndSwap(old, count) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -47,7 +49,7 @@ func TestWaitPool(t *testing.T) {
|
|||||||
for i := 0; i < workers; i++ {
|
for i := 0; i < workers; i++ {
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
for atomic.AddInt32(&trials, -1) > 0 {
|
for trials.Add(-1) > 0 {
|
||||||
updateMax()
|
updateMax()
|
||||||
x := p.Get()
|
x := p.Get()
|
||||||
updateMax()
|
updateMax()
|
||||||
@ -59,14 +61,15 @@ func TestWaitPool(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
if max != p.max {
|
if max.Load() != p.max {
|
||||||
t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
|
t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkWaitPool(b *testing.B) {
|
func BenchmarkWaitPool(b *testing.B) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
trials := int32(b.N)
|
var trials atomic.Int32
|
||||||
|
trials.Store(int32(b.N))
|
||||||
workers := runtime.NumCPU() + 2
|
workers := runtime.NumCPU() + 2
|
||||||
if workers-4 <= 0 {
|
if workers-4 <= 0 {
|
||||||
b.Skip("Not enough cores")
|
b.Skip("Not enough cores")
|
||||||
@ -77,7 +80,7 @@ func BenchmarkWaitPool(b *testing.B) {
|
|||||||
for i := 0; i < workers; i++ {
|
for i := 0; i < workers; i++ {
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
for atomic.AddInt32(&trials, -1) > 0 {
|
for trials.Add(-1) > 0 {
|
||||||
x := p.Get()
|
x := p.Get()
|
||||||
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
||||||
p.Put(x)
|
p.Put(x)
|
||||||
|
@ -11,7 +11,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
@ -52,12 +51,12 @@ func (elem *QueueInboundElement) clearPointers() {
|
|||||||
* NOTE: Not thread safe, but called by sequential receiver!
|
* NOTE: Not thread safe, but called by sequential receiver!
|
||||||
*/
|
*/
|
||||||
func (peer *Peer) keepKeyFreshReceiving() {
|
func (peer *Peer) keepKeyFreshReceiving() {
|
||||||
if peer.timers.sentLastMinuteHandshake.Get() {
|
if peer.timers.sentLastMinuteHandshake.Load() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
keypair := peer.keypairs.Current()
|
keypair := peer.keypairs.Current()
|
||||||
if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
|
if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
|
||||||
peer.timers.sentLastMinuteHandshake.Set(true)
|
peer.timers.sentLastMinuteHandshake.Store(true)
|
||||||
peer.SendHandshakeInitiation(false)
|
peer.SendHandshakeInitiation(false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -163,7 +162,7 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
|
|||||||
elem.Lock()
|
elem.Lock()
|
||||||
|
|
||||||
// add to decryption queues
|
// add to decryption queues
|
||||||
if peer.isRunning.Get() {
|
if peer.isRunning.Load() {
|
||||||
peer.queue.inbound.c <- elem
|
peer.queue.inbound.c <- elem
|
||||||
device.queue.decryption.c <- elem
|
device.queue.decryption.c <- elem
|
||||||
buffer = device.GetMessageBuffer()
|
buffer = device.GetMessageBuffer()
|
||||||
@ -268,7 +267,7 @@ func (device *Device) RoutineHandshake(id int) {
|
|||||||
|
|
||||||
// consume reply
|
// consume reply
|
||||||
|
|
||||||
if peer := entry.peer; peer.isRunning.Get() {
|
if peer := entry.peer; peer.isRunning.Load() {
|
||||||
device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
|
device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
|
||||||
if !peer.cookieGenerator.ConsumeReply(&reply) {
|
if !peer.cookieGenerator.ConsumeReply(&reply) {
|
||||||
device.log.Verbosef("Could not decrypt invalid cookie response")
|
device.log.Verbosef("Could not decrypt invalid cookie response")
|
||||||
@ -341,7 +340,7 @@ func (device *Device) RoutineHandshake(id int) {
|
|||||||
peer.SetEndpointFromPacket(elem.endpoint)
|
peer.SetEndpointFromPacket(elem.endpoint)
|
||||||
|
|
||||||
device.log.Verbosef("%v - Received handshake initiation", peer)
|
device.log.Verbosef("%v - Received handshake initiation", peer)
|
||||||
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
|
peer.rxBytes.Add(uint64(len(elem.packet)))
|
||||||
|
|
||||||
peer.SendHandshakeResponse()
|
peer.SendHandshakeResponse()
|
||||||
|
|
||||||
@ -369,7 +368,7 @@ func (device *Device) RoutineHandshake(id int) {
|
|||||||
peer.SetEndpointFromPacket(elem.endpoint)
|
peer.SetEndpointFromPacket(elem.endpoint)
|
||||||
|
|
||||||
device.log.Verbosef("%v - Received handshake response", peer)
|
device.log.Verbosef("%v - Received handshake response", peer)
|
||||||
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
|
peer.rxBytes.Add(uint64(len(elem.packet)))
|
||||||
|
|
||||||
// update timers
|
// update timers
|
||||||
|
|
||||||
@ -426,7 +425,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
|||||||
peer.keepKeyFreshReceiving()
|
peer.keepKeyFreshReceiving()
|
||||||
peer.timersAnyAuthenticatedPacketTraversal()
|
peer.timersAnyAuthenticatedPacketTraversal()
|
||||||
peer.timersAnyAuthenticatedPacketReceived()
|
peer.timersAnyAuthenticatedPacketReceived()
|
||||||
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize))
|
peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
|
||||||
|
|
||||||
if len(elem.packet) == 0 {
|
if len(elem.packet) == 0 {
|
||||||
device.log.Verbosef("%v - Receiving keepalive packet", peer)
|
device.log.Verbosef("%v - Receiving keepalive packet", peer)
|
||||||
|
@ -12,7 +12,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
@ -76,7 +75,7 @@ func (elem *QueueOutboundElement) clearPointers() {
|
|||||||
/* Queues a keepalive if no packets are queued for peer
|
/* Queues a keepalive if no packets are queued for peer
|
||||||
*/
|
*/
|
||||||
func (peer *Peer) SendKeepalive() {
|
func (peer *Peer) SendKeepalive() {
|
||||||
if len(peer.queue.staged) == 0 && peer.isRunning.Get() {
|
if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
|
||||||
elem := peer.device.NewOutboundElement()
|
elem := peer.device.NewOutboundElement()
|
||||||
select {
|
select {
|
||||||
case peer.queue.staged <- elem:
|
case peer.queue.staged <- elem:
|
||||||
@ -91,7 +90,7 @@ func (peer *Peer) SendKeepalive() {
|
|||||||
|
|
||||||
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
||||||
if !isRetry {
|
if !isRetry {
|
||||||
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
|
peer.timers.handshakeAttempts.Store(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.handshake.mutex.RLock()
|
peer.handshake.mutex.RLock()
|
||||||
@ -193,7 +192,7 @@ func (peer *Peer) keepKeyFreshSending() {
|
|||||||
if keypair == nil {
|
if keypair == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
nonce := atomic.LoadUint64(&keypair.sendNonce)
|
nonce := keypair.sendNonce.Load()
|
||||||
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
|
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
|
||||||
peer.SendHandshakeInitiation(false)
|
peer.SendHandshakeInitiation(false)
|
||||||
}
|
}
|
||||||
@ -269,7 +268,7 @@ func (device *Device) RoutineReadFromTUN() {
|
|||||||
if peer == nil {
|
if peer == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if peer.isRunning.Get() {
|
if peer.isRunning.Load() {
|
||||||
peer.StagePacket(elem)
|
peer.StagePacket(elem)
|
||||||
elem = nil
|
elem = nil
|
||||||
peer.SendStagedPackets()
|
peer.SendStagedPackets()
|
||||||
@ -300,7 +299,7 @@ top:
|
|||||||
}
|
}
|
||||||
|
|
||||||
keypair := peer.keypairs.Current()
|
keypair := peer.keypairs.Current()
|
||||||
if keypair == nil || atomic.LoadUint64(&keypair.sendNonce) >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
|
if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
|
||||||
peer.SendHandshakeInitiation(false)
|
peer.SendHandshakeInitiation(false)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -309,9 +308,9 @@ top:
|
|||||||
select {
|
select {
|
||||||
case elem := <-peer.queue.staged:
|
case elem := <-peer.queue.staged:
|
||||||
elem.peer = peer
|
elem.peer = peer
|
||||||
elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1
|
elem.nonce = keypair.sendNonce.Add(1) - 1
|
||||||
if elem.nonce >= RejectAfterMessages {
|
if elem.nonce >= RejectAfterMessages {
|
||||||
atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages)
|
keypair.sendNonce.Store(RejectAfterMessages)
|
||||||
peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans
|
peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans
|
||||||
goto top
|
goto top
|
||||||
}
|
}
|
||||||
@ -320,7 +319,7 @@ top:
|
|||||||
elem.Lock()
|
elem.Lock()
|
||||||
|
|
||||||
// add to parallel and sequential queue
|
// add to parallel and sequential queue
|
||||||
if peer.isRunning.Get() {
|
if peer.isRunning.Load() {
|
||||||
peer.queue.outbound.c <- elem
|
peer.queue.outbound.c <- elem
|
||||||
peer.device.queue.encryption.c <- elem
|
peer.device.queue.encryption.c <- elem
|
||||||
} else {
|
} else {
|
||||||
@ -385,7 +384,7 @@ func (device *Device) RoutineEncryption(id int) {
|
|||||||
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
|
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
|
||||||
|
|
||||||
// pad content to multiple of 16
|
// pad content to multiple of 16
|
||||||
paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu)))
|
paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
|
||||||
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
|
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
|
||||||
|
|
||||||
// encrypt content and release to consumer
|
// encrypt content and release to consumer
|
||||||
@ -419,7 +418,7 @@ func (peer *Peer) RoutineSequentialSender() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
elem.Lock()
|
elem.Lock()
|
||||||
if !peer.isRunning.Get() {
|
if !peer.isRunning.Load() {
|
||||||
// peer has been stopped; return re-usable elems to the shared pool.
|
// peer has been stopped; return re-usable elems to the shared pool.
|
||||||
// This is an optimization only. It is possible for the peer to be stopped
|
// This is an optimization only. It is possible for the peer to be stopped
|
||||||
// immediately after this check, in which case, elem will get processed.
|
// immediately after this check, in which case, elem will get processed.
|
||||||
|
@ -9,7 +9,6 @@ package device
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
_ "unsafe"
|
_ "unsafe"
|
||||||
)
|
)
|
||||||
@ -74,11 +73,11 @@ func (timer *Timer) IsPending() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) timersActive() bool {
|
func (peer *Peer) timersActive() bool {
|
||||||
return peer.isRunning.Get() && peer.device != nil && peer.device.isUp()
|
return peer.isRunning.Load() && peer.device != nil && peer.device.isUp()
|
||||||
}
|
}
|
||||||
|
|
||||||
func expiredRetransmitHandshake(peer *Peer) {
|
func expiredRetransmitHandshake(peer *Peer) {
|
||||||
if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes {
|
if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes {
|
||||||
peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)
|
peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)
|
||||||
|
|
||||||
if peer.timersActive() {
|
if peer.timersActive() {
|
||||||
@ -97,8 +96,8 @@ func expiredRetransmitHandshake(peer *Peer) {
|
|||||||
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
|
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
atomic.AddUint32(&peer.timers.handshakeAttempts, 1)
|
peer.timers.handshakeAttempts.Add(1)
|
||||||
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1)
|
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
|
||||||
|
|
||||||
/* We clear the endpoint address src address, in case this is the cause of trouble. */
|
/* We clear the endpoint address src address, in case this is the cause of trouble. */
|
||||||
peer.Lock()
|
peer.Lock()
|
||||||
@ -113,8 +112,8 @@ func expiredRetransmitHandshake(peer *Peer) {
|
|||||||
|
|
||||||
func expiredSendKeepalive(peer *Peer) {
|
func expiredSendKeepalive(peer *Peer) {
|
||||||
peer.SendKeepalive()
|
peer.SendKeepalive()
|
||||||
if peer.timers.needAnotherKeepalive.Get() {
|
if peer.timers.needAnotherKeepalive.Load() {
|
||||||
peer.timers.needAnotherKeepalive.Set(false)
|
peer.timers.needAnotherKeepalive.Store(false)
|
||||||
if peer.timersActive() {
|
if peer.timersActive() {
|
||||||
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
|
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
|
||||||
}
|
}
|
||||||
@ -138,7 +137,7 @@ func expiredZeroKeyMaterial(peer *Peer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func expiredPersistentKeepalive(peer *Peer) {
|
func expiredPersistentKeepalive(peer *Peer) {
|
||||||
if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
|
if peer.persistentKeepaliveInterval.Load() > 0 {
|
||||||
peer.SendKeepalive()
|
peer.SendKeepalive()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -156,7 +155,7 @@ func (peer *Peer) timersDataReceived() {
|
|||||||
if !peer.timers.sendKeepalive.IsPending() {
|
if !peer.timers.sendKeepalive.IsPending() {
|
||||||
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
|
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
|
||||||
} else {
|
} else {
|
||||||
peer.timers.needAnotherKeepalive.Set(true)
|
peer.timers.needAnotherKeepalive.Store(true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -187,9 +186,9 @@ func (peer *Peer) timersHandshakeComplete() {
|
|||||||
if peer.timersActive() {
|
if peer.timersActive() {
|
||||||
peer.timers.retransmitHandshake.Del()
|
peer.timers.retransmitHandshake.Del()
|
||||||
}
|
}
|
||||||
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
|
peer.timers.handshakeAttempts.Store(0)
|
||||||
peer.timers.sentLastMinuteHandshake.Set(false)
|
peer.timers.sentLastMinuteHandshake.Store(false)
|
||||||
atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano())
|
peer.lastHandshakeNano.Store(time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
|
/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
|
||||||
@ -201,7 +200,7 @@ func (peer *Peer) timersSessionDerived() {
|
|||||||
|
|
||||||
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
|
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
|
||||||
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
|
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
|
||||||
keepalive := atomic.LoadUint32(&peer.persistentKeepaliveInterval)
|
keepalive := peer.persistentKeepaliveInterval.Load()
|
||||||
if keepalive > 0 && peer.timersActive() {
|
if keepalive > 0 && peer.timersActive() {
|
||||||
peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
|
peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
|
||||||
}
|
}
|
||||||
@ -216,9 +215,9 @@ func (peer *Peer) timersInit() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) timersStart() {
|
func (peer *Peer) timersStart() {
|
||||||
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
|
peer.timers.handshakeAttempts.Store(0)
|
||||||
peer.timers.sentLastMinuteHandshake.Set(false)
|
peer.timers.sentLastMinuteHandshake.Store(false)
|
||||||
peer.timers.needAnotherKeepalive.Set(false)
|
peer.timers.needAnotherKeepalive.Store(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) timersStop() {
|
func (peer *Peer) timersStop() {
|
||||||
|
@ -7,7 +7,6 @@ package device
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
)
|
)
|
||||||
@ -33,7 +32,7 @@ func (device *Device) RoutineTUNEventReader() {
|
|||||||
tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize)
|
tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize)
|
||||||
mtu = MaxContentSize
|
mtu = MaxContentSize
|
||||||
}
|
}
|
||||||
old := atomic.SwapInt32(&device.tun.mtu, int32(mtu))
|
old := device.tun.mtu.Swap(int32(mtu))
|
||||||
if int(old) != mtu {
|
if int(old) != mtu {
|
||||||
device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge)
|
device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge)
|
||||||
}
|
}
|
||||||
|
@ -16,7 +16,6 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/ipc"
|
"golang.zx2c4.com/wireguard/ipc"
|
||||||
@ -112,15 +111,15 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
|||||||
sendf("endpoint=%s", peer.endpoint.DstToString())
|
sendf("endpoint=%s", peer.endpoint.DstToString())
|
||||||
}
|
}
|
||||||
|
|
||||||
nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
|
nano := peer.lastHandshakeNano.Load()
|
||||||
secs := nano / time.Second.Nanoseconds()
|
secs := nano / time.Second.Nanoseconds()
|
||||||
nano %= time.Second.Nanoseconds()
|
nano %= time.Second.Nanoseconds()
|
||||||
|
|
||||||
sendf("last_handshake_time_sec=%d", secs)
|
sendf("last_handshake_time_sec=%d", secs)
|
||||||
sendf("last_handshake_time_nsec=%d", nano)
|
sendf("last_handshake_time_nsec=%d", nano)
|
||||||
sendf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes))
|
sendf("tx_bytes=%d", peer.txBytes.Load())
|
||||||
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
|
sendf("rx_bytes=%d", peer.rxBytes.Load())
|
||||||
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
|
sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
|
||||||
|
|
||||||
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
|
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
|
||||||
sendf("allowed_ip=%s", prefix.String())
|
sendf("allowed_ip=%s", prefix.String())
|
||||||
@ -358,7 +357,7 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
|
|||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs))
|
old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
|
||||||
|
|
||||||
// Send immediate keepalive if we're turning it on and before it wasn't on.
|
// Send immediate keepalive if we're turning it on and before it wasn't on.
|
||||||
peer.pkaOn = old == 0 && secs != 0
|
peer.pkaOn = old == 0 && secs != 0
|
||||||
|
2
go.mod
2
go.mod
@ -1,6 +1,6 @@
|
|||||||
module golang.zx2c4.com/wireguard
|
module golang.zx2c4.com/wireguard
|
||||||
|
|
||||||
go 1.18
|
go 1.19
|
||||||
|
|
||||||
require (
|
require (
|
||||||
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd
|
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd
|
||||||
|
@ -54,7 +54,7 @@ type file struct {
|
|||||||
handle windows.Handle
|
handle windows.Handle
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
wgLock sync.RWMutex
|
wgLock sync.RWMutex
|
||||||
closing uint32 // used as atomic boolean
|
closing atomic.Bool
|
||||||
socket bool
|
socket bool
|
||||||
readDeadline deadlineHandler
|
readDeadline deadlineHandler
|
||||||
writeDeadline deadlineHandler
|
writeDeadline deadlineHandler
|
||||||
@ -65,7 +65,7 @@ type deadlineHandler struct {
|
|||||||
channel timeoutChan
|
channel timeoutChan
|
||||||
channelLock sync.RWMutex
|
channelLock sync.RWMutex
|
||||||
timer *time.Timer
|
timer *time.Timer
|
||||||
timedout uint32 // used as atomic boolean
|
timedout atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeFile makes a new file from an existing file handle
|
// makeFile makes a new file from an existing file handle
|
||||||
@ -89,7 +89,7 @@ func makeFile(h windows.Handle) (*file, error) {
|
|||||||
func (f *file) closeHandle() {
|
func (f *file) closeHandle() {
|
||||||
f.wgLock.Lock()
|
f.wgLock.Lock()
|
||||||
// Atomically set that we are closing, releasing the resources only once.
|
// Atomically set that we are closing, releasing the resources only once.
|
||||||
if atomic.SwapUint32(&f.closing, 1) == 0 {
|
if f.closing.Swap(true) == false {
|
||||||
f.wgLock.Unlock()
|
f.wgLock.Unlock()
|
||||||
// cancel all IO and wait for it to complete
|
// cancel all IO and wait for it to complete
|
||||||
windows.CancelIoEx(f.handle, nil)
|
windows.CancelIoEx(f.handle, nil)
|
||||||
@ -112,7 +112,7 @@ func (f *file) Close() error {
|
|||||||
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
|
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
|
||||||
func (f *file) prepareIo() (*ioOperation, error) {
|
func (f *file) prepareIo() (*ioOperation, error) {
|
||||||
f.wgLock.RLock()
|
f.wgLock.RLock()
|
||||||
if atomic.LoadUint32(&f.closing) == 1 {
|
if f.closing.Load() {
|
||||||
f.wgLock.RUnlock()
|
f.wgLock.RUnlock()
|
||||||
return nil, os.ErrClosed
|
return nil, os.ErrClosed
|
||||||
}
|
}
|
||||||
@ -144,7 +144,7 @@ func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err err
|
|||||||
return int(bytes), err
|
return int(bytes), err
|
||||||
}
|
}
|
||||||
|
|
||||||
if atomic.LoadUint32(&f.closing) == 1 {
|
if f.closing.Load() {
|
||||||
windows.CancelIoEx(f.handle, &c.o)
|
windows.CancelIoEx(f.handle, &c.o)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -160,7 +160,7 @@ func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err err
|
|||||||
case r = <-c.ch:
|
case r = <-c.ch:
|
||||||
err = r.err
|
err = r.err
|
||||||
if err == windows.ERROR_OPERATION_ABORTED {
|
if err == windows.ERROR_OPERATION_ABORTED {
|
||||||
if atomic.LoadUint32(&f.closing) == 1 {
|
if f.closing.Load() {
|
||||||
err = os.ErrClosed
|
err = os.ErrClosed
|
||||||
}
|
}
|
||||||
} else if err != nil && f.socket {
|
} else if err != nil && f.socket {
|
||||||
@ -192,7 +192,7 @@ func (f *file) Read(b []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
defer f.wg.Done()
|
defer f.wg.Done()
|
||||||
|
|
||||||
if atomic.LoadUint32(&f.readDeadline.timedout) == 1 {
|
if f.readDeadline.timedout.Load() {
|
||||||
return 0, os.ErrDeadlineExceeded
|
return 0, os.ErrDeadlineExceeded
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -219,7 +219,7 @@ func (f *file) Write(b []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
defer f.wg.Done()
|
defer f.wg.Done()
|
||||||
|
|
||||||
if atomic.LoadUint32(&f.writeDeadline.timedout) == 1 {
|
if f.writeDeadline.timedout.Load() {
|
||||||
return 0, os.ErrDeadlineExceeded
|
return 0, os.ErrDeadlineExceeded
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -256,7 +256,7 @@ func (d *deadlineHandler) set(deadline time.Time) error {
|
|||||||
}
|
}
|
||||||
d.timer = nil
|
d.timer = nil
|
||||||
}
|
}
|
||||||
atomic.StoreUint32(&d.timedout, 0)
|
d.timedout.Store(false)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-d.channel:
|
case <-d.channel:
|
||||||
@ -271,7 +271,7 @@ func (d *deadlineHandler) set(deadline time.Time) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
timeoutIO := func() {
|
timeoutIO := func() {
|
||||||
atomic.StoreUint32(&d.timedout, 1)
|
d.timedout.Store(true)
|
||||||
close(d.channel)
|
close(d.channel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ type pipe struct {
|
|||||||
|
|
||||||
type messageBytePipe struct {
|
type messageBytePipe struct {
|
||||||
pipe
|
pipe
|
||||||
writeClosed int32
|
writeClosed atomic.Bool
|
||||||
readEOF bool
|
readEOF bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -51,17 +51,17 @@ func (f *pipe) SetDeadline(t time.Time) error {
|
|||||||
|
|
||||||
// CloseWrite closes the write side of a message pipe in byte mode.
|
// CloseWrite closes the write side of a message pipe in byte mode.
|
||||||
func (f *messageBytePipe) CloseWrite() error {
|
func (f *messageBytePipe) CloseWrite() error {
|
||||||
if !atomic.CompareAndSwapInt32(&f.writeClosed, 0, 1) {
|
if !f.writeClosed.CompareAndSwap(false, true) {
|
||||||
return io.ErrClosedPipe
|
return io.ErrClosedPipe
|
||||||
}
|
}
|
||||||
err := f.file.Flush()
|
err := f.file.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
atomic.StoreInt32(&f.writeClosed, 0)
|
f.writeClosed.Store(false)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = f.file.Write(nil)
|
_, err = f.file.Write(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
atomic.StoreInt32(&f.writeClosed, 0)
|
f.writeClosed.Store(false)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -70,7 +70,7 @@ func (f *messageBytePipe) CloseWrite() error {
|
|||||||
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
|
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
|
||||||
// they are used to implement CloseWrite.
|
// they are used to implement CloseWrite.
|
||||||
func (f *messageBytePipe) Write(b []byte) (int, error) {
|
func (f *messageBytePipe) Write(b []byte) (int, error) {
|
||||||
if atomic.LoadInt32(&f.writeClosed) != 0 {
|
if f.writeClosed.Load() {
|
||||||
return 0, io.ErrClosedPipe
|
return 0, io.ErrClosedPipe
|
||||||
}
|
}
|
||||||
if len(b) == 0 {
|
if len(b) == 0 {
|
||||||
|
@ -26,10 +26,10 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type rateJuggler struct {
|
type rateJuggler struct {
|
||||||
current uint64
|
current atomic.Uint64
|
||||||
nextByteCount uint64
|
nextByteCount atomic.Uint64
|
||||||
nextStartTime int64
|
nextStartTime atomic.Int64
|
||||||
changing int32
|
changing atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type NativeTun struct {
|
type NativeTun struct {
|
||||||
@ -42,7 +42,7 @@ type NativeTun struct {
|
|||||||
events chan Event
|
events chan Event
|
||||||
running sync.WaitGroup
|
running sync.WaitGroup
|
||||||
closeOnce sync.Once
|
closeOnce sync.Once
|
||||||
close int32
|
close atomic.Bool
|
||||||
forcedMTU int
|
forcedMTU int
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -57,18 +57,14 @@ func procyield(cycles uint32)
|
|||||||
//go:linkname nanotime runtime.nanotime
|
//go:linkname nanotime runtime.nanotime
|
||||||
func nanotime() int64
|
func nanotime() int64
|
||||||
|
|
||||||
//
|
|
||||||
// CreateTUN creates a Wintun interface with the given name. Should a Wintun
|
// CreateTUN creates a Wintun interface with the given name. Should a Wintun
|
||||||
// interface with the same name exist, it is reused.
|
// interface with the same name exist, it is reused.
|
||||||
//
|
|
||||||
func CreateTUN(ifname string, mtu int) (Device, error) {
|
func CreateTUN(ifname string, mtu int) (Device, error) {
|
||||||
return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
|
return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
|
||||||
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
|
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
|
||||||
// a requested GUID. Should a Wintun interface with the same name exist, it is reused.
|
// a requested GUID. Should a Wintun interface with the same name exist, it is reused.
|
||||||
//
|
|
||||||
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
|
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
|
||||||
wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
|
wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -113,7 +109,7 @@ func (tun *NativeTun) Events() chan Event {
|
|||||||
func (tun *NativeTun) Close() error {
|
func (tun *NativeTun) Close() error {
|
||||||
var err error
|
var err error
|
||||||
tun.closeOnce.Do(func() {
|
tun.closeOnce.Do(func() {
|
||||||
atomic.StoreInt32(&tun.close, 1)
|
tun.close.Store(true)
|
||||||
windows.SetEvent(tun.readWait)
|
windows.SetEvent(tun.readWait)
|
||||||
tun.running.Wait()
|
tun.running.Wait()
|
||||||
tun.session.End()
|
tun.session.End()
|
||||||
@ -144,13 +140,13 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
|
|||||||
tun.running.Add(1)
|
tun.running.Add(1)
|
||||||
defer tun.running.Done()
|
defer tun.running.Done()
|
||||||
retry:
|
retry:
|
||||||
if atomic.LoadInt32(&tun.close) == 1 {
|
if tun.close.Load() {
|
||||||
return 0, os.ErrClosed
|
return 0, os.ErrClosed
|
||||||
}
|
}
|
||||||
start := nanotime()
|
start := nanotime()
|
||||||
shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
|
shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
|
||||||
for {
|
for {
|
||||||
if atomic.LoadInt32(&tun.close) == 1 {
|
if tun.close.Load() {
|
||||||
return 0, os.ErrClosed
|
return 0, os.ErrClosed
|
||||||
}
|
}
|
||||||
packet, err := tun.session.ReceivePacket()
|
packet, err := tun.session.ReceivePacket()
|
||||||
@ -184,7 +180,7 @@ func (tun *NativeTun) Flush() error {
|
|||||||
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
||||||
tun.running.Add(1)
|
tun.running.Add(1)
|
||||||
defer tun.running.Done()
|
defer tun.running.Done()
|
||||||
if atomic.LoadInt32(&tun.close) == 1 {
|
if tun.close.Load() {
|
||||||
return 0, os.ErrClosed
|
return 0, os.ErrClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -210,7 +206,7 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
|||||||
func (tun *NativeTun) LUID() uint64 {
|
func (tun *NativeTun) LUID() uint64 {
|
||||||
tun.running.Add(1)
|
tun.running.Add(1)
|
||||||
defer tun.running.Done()
|
defer tun.running.Done()
|
||||||
if atomic.LoadInt32(&tun.close) == 1 {
|
if tun.close.Load() {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
return tun.wt.LUID()
|
return tun.wt.LUID()
|
||||||
@ -223,15 +219,15 @@ func (tun *NativeTun) RunningVersion() (version uint32, err error) {
|
|||||||
|
|
||||||
func (rate *rateJuggler) update(packetLen uint64) {
|
func (rate *rateJuggler) update(packetLen uint64) {
|
||||||
now := nanotime()
|
now := nanotime()
|
||||||
total := atomic.AddUint64(&rate.nextByteCount, packetLen)
|
total := rate.nextByteCount.Add(packetLen)
|
||||||
period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
|
period := uint64(now - rate.nextStartTime.Load())
|
||||||
if period >= rateMeasurementGranularity {
|
if period >= rateMeasurementGranularity {
|
||||||
if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
|
if !rate.changing.CompareAndSwap(false, true) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
atomic.StoreInt64(&rate.nextStartTime, now)
|
rate.nextStartTime.Store(now)
|
||||||
atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
|
rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
|
||||||
atomic.StoreUint64(&rate.nextByteCount, 0)
|
rate.nextByteCount.Store(0)
|
||||||
atomic.StoreInt32(&rate.changing, 0)
|
rate.changing.Store(false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user