all: use Go 1.19 and its atomic types

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Author: Brad Fitzpatrick, 2022-08-30 07:43:11 -07:00 (committed by Jason A. Donenfeld)
commit b51010ba13, parent d1d08426b2
20 changed files with 156 additions and 288 deletions
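Every file below applies the same mechanical pattern: a field declared as a bare integer and accessed through sync/atomic's function API becomes one of the method-based atomic types that Go 1.19 added to sync/atomic. A minimal before/after sketch (the names are illustrative, not from the patch):

package main

import (
	"fmt"
	"sync/atomic"
)

func main() {
	// Before Go 1.19: a plain field plus free functions; nothing stops
	// an accidental non-atomic read or write of isOpen.
	//	var isOpen uint32
	//	atomic.StoreUint32(&isOpen, 1)
	//	_ = atomic.LoadUint32(&isOpen)

	// With Go 1.19 atomic types, the value is only reachable through
	// Load/Store/Add/Swap/CompareAndSwap.
	var isOpen atomic.Uint32
	isOpen.Store(1)
	fmt.Println(isOpen.Load()) // prints 1
}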


@@ -74,7 +74,7 @@ type afWinRingBind struct {
 type WinRingBind struct {
 	v4, v6 afWinRingBind
 	mu     sync.RWMutex
-	isOpen uint32
+	isOpen atomic.Uint32 // 0, 1, or 2
 }
 
 func NewDefaultBind() Bind { return NewWinRingBind() }
@@ -212,7 +212,7 @@ func (bind *afWinRingBind) CloseAndZero() {
 }
 
 func (bind *WinRingBind) closeAndZero() {
-	atomic.StoreUint32(&bind.isOpen, 0)
+	bind.isOpen.Store(0)
 	bind.v4.CloseAndZero()
 	bind.v6.CloseAndZero()
 }
@@ -276,7 +276,7 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort
 			bind.closeAndZero()
 		}
 	}()
-	if atomic.LoadUint32(&bind.isOpen) != 0 {
+	if bind.isOpen.Load() != 0 {
		return nil, 0, ErrBindAlreadyOpen
 	}
 	var sa windows.Sockaddr
@@ -299,17 +299,17 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort
 			return nil, 0, err
 		}
 	}
-	atomic.StoreUint32(&bind.isOpen, 1)
+	bind.isOpen.Store(1)
 	return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
 }
 
 func (bind *WinRingBind) Close() error {
 	bind.mu.RLock()
-	if atomic.LoadUint32(&bind.isOpen) != 1 {
+	if bind.isOpen.Load() != 1 {
 		bind.mu.RUnlock()
 		return nil
 	}
-	atomic.StoreUint32(&bind.isOpen, 2)
+	bind.isOpen.Store(2)
 	windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
 	windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
 	windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
@@ -345,8 +345,8 @@ func (bind *afWinRingBind) InsertReceiveRequest() error {
 //go:linkname procyield runtime.procyield
 func procyield(cycles uint32)
 
-func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) {
-	if atomic.LoadUint32(isOpen) != 1 {
+func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
+	if isOpen.Load() != 1 {
 		return 0, nil, net.ErrClosed
 	}
 	bind.rx.mu.Lock()
@@ -359,7 +359,7 @@ retry:
 	count = 0
 	for tries := 0; count == 0 && tries < receiveSpins; tries++ {
 		if tries > 0 {
-			if atomic.LoadUint32(isOpen) != 1 {
+			if isOpen.Load() != 1 {
 				return 0, nil, net.ErrClosed
 			}
 			procyield(1)
@@ -378,7 +378,7 @@ retry:
 		if err != nil {
 			return 0, nil, err
 		}
-		if atomic.LoadUint32(isOpen) != 1 {
+		if isOpen.Load() != 1 {
 			return 0, nil, net.ErrClosed
 		}
 		count = winrio.DequeueCompletion(bind.rx.cq, results[:])
@@ -395,7 +395,7 @@ retry:
 	// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
 	// attacker bandwidth, just like the rest of the receive path.
 	if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
-		if atomic.LoadUint32(isOpen) != 1 {
+		if isOpen.Load() != 1 {
 			return 0, nil, net.ErrClosed
 		}
 		goto retry
@@ -421,8 +421,8 @@ func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
 	return bind.v6.Receive(buf, &bind.isOpen)
 }
 
-func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error {
-	if atomic.LoadUint32(isOpen) != 1 {
+func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
+	if isOpen.Load() != 1 {
 		return net.ErrClosed
 	}
 	if len(buf) > bytesPerPacket {
@@ -444,7 +444,7 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint3
 	if err != nil {
 		return err
 	}
-	if atomic.LoadUint32(isOpen) != 1 {
+	if isOpen.Load() != 1 {
 		return net.ErrClosed
 	}
 	count = winrio.DequeueCompletion(bind.tx.cq, results[:])
@@ -538,7 +538,7 @@ func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole
 func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
 	bind.mu.RLock()
 	defer bind.mu.RUnlock()
-	if atomic.LoadUint32(&bind.isOpen) != 1 {
+	if bind.isOpen.Load() != 1 {
 		return net.ErrClosed
 	}
 	err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
@@ -552,7 +552,7 @@ func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
 func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
 	bind.mu.RLock()
 	defer bind.mu.RUnlock()
-	if atomic.LoadUint32(&bind.isOpen) != 1 {
+	if bind.isOpen.Load() != 1 {
 		return net.ErrClosed
 	}
 	err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
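The new "// 0, 1, or 2" comment on isOpen is terse. Reading the uses above, the three values appear to encode closed, open, and closing, though the source never names them; a hypothetical spelling-out:

// Inferred from the hunks above, not named anywhere in the source:
const (
	bindClosed  = 0 // stored by closeAndZero once cleanup finishes
	bindOpen    = 1 // stored by Open; Receive, Send, and Bind* require it
	bindClosing = 2 // stored by Close before waking the IOCP waiters
)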


@@ -1,65 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
-	"reflect"
-	"testing"
-	"unsafe"
-)
-
-func checkAlignment(t *testing.T, name string, offset uintptr) {
-	t.Helper()
-	if offset%8 != 0 {
-		t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8))
-	}
-}
-
-// TestPeerAlignment checks that atomically-accessed fields are
-// aligned to 64-bit boundaries, as required by the atomic package.
-//
-// Unfortunately, violating this rule on 32-bit platforms results in a
-// hard segfault at runtime.
-func TestPeerAlignment(t *testing.T) {
-	var p Peer
-
-	typ := reflect.TypeOf(&p).Elem()
-	t.Logf("Peer type size: %d, with fields:", typ.Size())
-	for i := 0; i < typ.NumField(); i++ {
-		field := typ.Field(i)
-		t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
-			field.Name,
-			field.Offset,
-			field.Type.Size(),
-			field.Type.Align(),
-		)
-	}
-
-	checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats))
-	checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning))
-}
-
-// TestDeviceAlignment checks that atomically-accessed fields are
-// aligned to 64-bit boundaries, as required by the atomic package.
-//
-// Unfortunately, violating this rule on 32-bit platforms results in a
-// hard segfault at runtime.
-func TestDeviceAlignment(t *testing.T) {
-	var d Device
-
-	typ := reflect.TypeOf(&d).Elem()
-	t.Logf("Device type size: %d, with fields:", typ.Size())
-	for i := 0; i < typ.NumField(); i++ {
-		field := typ.Field(i)
-		t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
-			field.Name,
-			field.Offset,
-			field.Type.Size(),
-			field.Type.Align(),
-		)
-	}
-	checkAlignment(t, "Device.rate.underLoadUntil", unsafe.Offsetof(d.rate)+unsafe.Offsetof(d.rate.underLoadUntil))
-}


@@ -30,7 +30,7 @@ type Device struct {
 	// will become the actual state; Up can fail.
 	// The device can also change state multiple times between time of check and time of use.
 	// Unsynchronized uses of state must therefore be advisory/best-effort only.
-	state uint32 // actually a deviceState, but typed uint32 for convenience
+	state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
 	// stopping blocks until all inputs to Device have been closed.
 	stopping sync.WaitGroup
 	// mu protects state changes.
@@ -58,9 +58,8 @@ type Device struct {
 		keyMap map[NoisePublicKey]*Peer
 	}
 
-	// Keep this 8-byte aligned
 	rate struct {
-		underLoadUntil int64
+		underLoadUntil atomic.Int64
 		limiter        ratelimiter.Ratelimiter
 	}
 
@@ -82,7 +81,7 @@ type Device struct {
 	tun struct {
 		device tun.Device
-		mtu    int32
+		mtu    atomic.Int32
 	}
 
 	ipcMutex sync.RWMutex
@@ -94,10 +93,9 @@ type Device struct {
 // There are three states: down, up, closed.
 // Transitions:
 //
 //	down -----+
 //	  ↑↓      ↓
 //	  up -> closed
-//
 type deviceState uint32
 
 //go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
@@ -110,7 +108,7 @@ const (
 // deviceState returns device.state.state as a deviceState
 // See those docs for how to interpret this value.
 func (device *Device) deviceState() deviceState {
-	return deviceState(atomic.LoadUint32(&device.state.state))
+	return deviceState(device.state.state.Load())
 }
 
 // isClosed reports whether the device is closed (or is closing).
@@ -149,14 +147,14 @@ func (device *Device) changeState(want deviceState) (err error) {
 	case old:
 		return nil
 	case deviceStateUp:
-		atomic.StoreUint32(&device.state.state, uint32(deviceStateUp))
+		device.state.state.Store(uint32(deviceStateUp))
 		err = device.upLocked()
 		if err == nil {
 			break
 		}
 		fallthrough // up failed; bring the device all the way back down
 	case deviceStateDown:
-		atomic.StoreUint32(&device.state.state, uint32(deviceStateDown))
+		device.state.state.Store(uint32(deviceStateDown))
 		errDown := device.downLocked()
 		if err == nil {
 			err = errDown
@@ -182,7 +180,7 @@ func (device *Device) upLocked() error {
 	device.peers.RLock()
 	for _, peer := range device.peers.keyMap {
 		peer.Start()
-		if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
+		if peer.persistentKeepaliveInterval.Load() > 0 {
 			peer.SendKeepalive()
 		}
 	}
@@ -219,11 +217,11 @@ func (device *Device) IsUnderLoad() bool {
 	now := time.Now()
 	underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
 	if underLoad {
-		atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano())
+		device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
 		return true
 	}
 	// check if recently under load
-	return atomic.LoadInt64(&device.rate.underLoadUntil) > now.UnixNano()
+	return device.rate.underLoadUntil.Load() > now.UnixNano()
 }
 
 func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
@@ -283,7 +281,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
 func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
 	device := new(Device)
-	device.state.state = uint32(deviceStateDown)
+	device.state.state.Store(uint32(deviceStateDown))
 	device.closed = make(chan struct{})
 	device.log = logger
 	device.net.bind = bind
@@ -293,7 +291,7 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
 		device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
 		mtu = DefaultMTU
 	}
-	device.tun.mtu = int32(mtu)
+	device.tun.mtu.Store(int32(mtu))
 	device.peers.keyMap = make(map[NoisePublicKey]*Peer)
 	device.rate.limiter.Init()
 	device.indexTable.Init()
@@ -359,7 +357,7 @@ func (device *Device) Close() {
 	if device.isClosed() {
 		return
 	}
-	atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed))
+	device.state.state.Store(uint32(deviceStateClosed))
 	device.log.Verbosef("Device closing")
 	device.tun.device.Close()


@@ -333,7 +333,7 @@ func BenchmarkThroughput(b *testing.B) {
 	// Measure how long it takes to receive b.N packets,
 	// starting when we receive the first packet.
-	var recv uint64
+	var recv atomic.Uint64
 	var elapsed time.Duration
 	var wg sync.WaitGroup
 	wg.Add(1)
@@ -342,7 +342,7 @@ func BenchmarkThroughput(b *testing.B) {
 		var start time.Time
 		for {
 			<-pair[0].tun.Inbound
-			new := atomic.AddUint64(&recv, 1)
+			new := recv.Add(1)
 			if new == 1 {
 				start = time.Now()
 			}
@@ -358,7 +358,7 @@ func BenchmarkThroughput(b *testing.B) {
 	ping := tuntest.Ping(pair[0].ip, pair[1].ip)
 	pingc := pair[1].tun.Outbound
 	var sent uint64
-	for atomic.LoadUint64(&recv) != uint64(b.N) {
+	for recv.Load() != uint64(b.N) {
 		sent++
 		pingc <- ping
 	}


@@ -10,7 +10,6 @@ import (
 	"sync"
 	"sync/atomic"
 	"time"
-	"unsafe"
 
 	"golang.zx2c4.com/wireguard/replay"
 )
@@ -23,7 +22,7 @@ import (
  */
 type Keypair struct {
-	sendNonce    uint64 // accessed atomically
+	sendNonce    atomic.Uint64
 	send         cipher.AEAD
 	receive      cipher.AEAD
 	replayFilter replay.Filter
@@ -37,15 +36,7 @@ type Keypairs struct {
 	sync.RWMutex
 	current  *Keypair
 	previous *Keypair
-	next     *Keypair
-}
-
-func (kp *Keypairs) storeNext(next *Keypair) {
-	atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
-}
-
-func (kp *Keypairs) loadNext() *Keypair {
-	return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
+	next     atomic.Pointer[Keypair]
 }
 
 func (kp *Keypairs) Current() *Keypair {
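The deleted storeNext/loadNext helpers existed only because sync/atomic previously offered no typed pointer operations, forcing a round-trip through unsafe.Pointer. Go 1.19's generic atomic.Pointer[T] gives the same lock-free behavior with type safety. A minimal sketch using a stand-in type:

package main

import (
	"fmt"
	"sync/atomic"
)

type keypair struct{ id int }

func main() {
	var next atomic.Pointer[keypair] // zero value is ready to use and holds nil

	next.Store(&keypair{id: 1}) // replaces the StorePointer/unsafe.Pointer dance
	if kp := next.Load(); kp != nil {
		fmt.Println(kp.id) // prints 1
	}
	next.Store(nil) // clearing works the same way
}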


@@ -1,41 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
-	"sync/atomic"
-)
-
-/* Atomic Boolean */
-
-const (
-	AtomicFalse = int32(iota)
-	AtomicTrue
-)
-
-type AtomicBool struct {
-	int32
-}
-
-func (a *AtomicBool) Get() bool {
-	return atomic.LoadInt32(&a.int32) == AtomicTrue
-}
-
-func (a *AtomicBool) Swap(val bool) bool {
-	flag := AtomicFalse
-	if val {
-		flag = AtomicTrue
-	}
-	return atomic.SwapInt32(&a.int32, flag) == AtomicTrue
-}
-
-func (a *AtomicBool) Set(val bool) {
-	flag := AtomicFalse
-	if val {
-		flag = AtomicTrue
-	}
-	atomic.StoreInt32(&a.int32, flag)
-}
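The hand-rolled AtomicBool above is a strict subset of Go 1.19's atomic.Bool, which is why the whole file can go. The method mapping used throughout the rest of this commit: Get becomes Load, Set becomes Store, and Swap keeps its name but deals in bool directly:

package main

import (
	"fmt"
	"sync/atomic"
)

func main() {
	var running atomic.Bool // zero value is false, just like AtomicBool

	running.Store(true)              // was: running.Set(true)
	fmt.Println(running.Load())      // was: running.Get(); prints true
	fmt.Println(running.Swap(false)) // returns the old value; prints true
}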


@@ -282,7 +282,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
 
 	// lookup peer
 	peer := device.LookupPeer(peerPK)
-	if peer == nil || !peer.isRunning.Get() {
+	if peer == nil || !peer.isRunning.Load() {
 		return nil
 	}
@@ -581,12 +581,12 @@ func (peer *Peer) BeginSymmetricSession() error {
 	defer keypairs.Unlock()
 
 	previous := keypairs.previous
-	next := keypairs.loadNext()
+	next := keypairs.next.Load()
 	current := keypairs.current
 
 	if isInitiator {
 		if next != nil {
-			keypairs.storeNext(nil)
+			keypairs.next.Store(nil)
 			keypairs.previous = next
 			device.DeleteKeypair(current)
 		} else {
@@ -595,7 +595,7 @@ func (peer *Peer) BeginSymmetricSession() error {
 			device.DeleteKeypair(previous)
 			keypairs.current = keypair
 		} else {
-			keypairs.storeNext(keypair)
+			keypairs.next.Store(keypair)
 			device.DeleteKeypair(next)
 			keypairs.previous = nil
 			device.DeleteKeypair(previous)
@@ -607,18 +607,18 @@ func (peer *Peer) BeginSymmetricSession() error {
 
 func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
 	keypairs := &peer.keypairs
 
-	if keypairs.loadNext() != receivedKeypair {
+	if keypairs.next.Load() != receivedKeypair {
 		return false
 	}
 	keypairs.Lock()
 	defer keypairs.Unlock()
-	if keypairs.loadNext() != receivedKeypair {
+	if keypairs.next.Load() != receivedKeypair {
 		return false
 	}
 	old := keypairs.previous
 	keypairs.previous = keypairs.current
 	peer.device.DeleteKeypair(old)
-	keypairs.current = keypairs.loadNext()
-	keypairs.storeNext(nil)
+	keypairs.current = keypairs.next.Load()
+	keypairs.next.Store(nil)
 	return true
 }


@@ -148,7 +148,7 @@ func TestNoiseHandshake(t *testing.T) {
 		t.Fatal("failed to derive keypair for peer 2", err)
 	}
 
-	key1 := peer1.keypairs.loadNext()
+	key1 := peer1.keypairs.next.Load()
 	key2 := peer2.keypairs.current
 
 	// encrypting / decryption test


@@ -16,24 +16,16 @@ import (
 )
 
 type Peer struct {
-	isRunning    AtomicBool
+	isRunning    atomic.Bool
 	sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
 	keypairs     Keypairs
 	handshake    Handshake
 	device       *Device
 	endpoint     conn.Endpoint
 	stopping     sync.WaitGroup // routines pending stop
-
-	// These fields are accessed with atomic operations, which must be
-	// 64-bit aligned even on 32-bit platforms. Go guarantees that an
-	// allocated struct will be 64-bit aligned. So we place
-	// atomically-accessed fields up front, so that they can share in
-	// this alignment before smaller fields throw it off.
-	stats struct {
-		txBytes           uint64 // bytes send to peer (endpoint)
-		rxBytes           uint64 // bytes received from peer
-		lastHandshakeNano int64  // nano seconds since epoch
-	}
+	txBytes           atomic.Uint64 // bytes send to peer (endpoint)
+	rxBytes           atomic.Uint64 // bytes received from peer
+	lastHandshakeNano atomic.Int64  // nano seconds since epoch
 
 	disableRoaming bool
@@ -43,9 +35,9 @@ type Peer struct {
 		newHandshake            *Timer
 		zeroKeyMaterial         *Timer
 		persistentKeepalive     *Timer
-		handshakeAttempts       uint32
-		needAnotherKeepalive    AtomicBool
-		sentLastMinuteHandshake AtomicBool
+		handshakeAttempts       atomic.Uint32
+		needAnotherKeepalive    atomic.Bool
+		sentLastMinuteHandshake atomic.Bool
 	}
 
 	state struct {
@@ -60,7 +52,7 @@ type Peer struct {
 	cookieGenerator             CookieGenerator
 	trieEntries                 list.List
-	persistentKeepaliveInterval uint32 // accessed atomically
+	persistentKeepaliveInterval atomic.Uint32
 }
 
 func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
@@ -133,7 +125,7 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
 
 	err := peer.device.net.bind.Send(buffer, peer.endpoint)
 	if err == nil {
-		atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer)))
+		peer.txBytes.Add(uint64(len(buffer)))
 	}
 	return err
 }
@@ -174,7 +166,7 @@ func (peer *Peer) Start() {
 	peer.state.Lock()
 	defer peer.state.Unlock()
 
-	if peer.isRunning.Get() {
+	if peer.isRunning.Load() {
 		return
 	}
@@ -198,7 +190,7 @@ func (peer *Peer) Start() {
 	go peer.RoutineSequentialSender()
 	go peer.RoutineSequentialReceiver()
 
-	peer.isRunning.Set(true)
+	peer.isRunning.Store(true)
 }
 
 func (peer *Peer) ZeroAndFlushAll() {
@@ -210,10 +202,10 @@ func (peer *Peer) ZeroAndFlushAll() {
 	keypairs.Lock()
 	device.DeleteKeypair(keypairs.previous)
 	device.DeleteKeypair(keypairs.current)
-	device.DeleteKeypair(keypairs.loadNext())
+	device.DeleteKeypair(keypairs.next.Load())
 	keypairs.previous = nil
 	keypairs.current = nil
-	keypairs.storeNext(nil)
+	keypairs.next.Store(nil)
 	keypairs.Unlock()
 
 	// clear handshake state
@@ -238,11 +230,10 @@ func (peer *Peer) ExpireCurrentKeypairs() {
 	keypairs := &peer.keypairs
 	keypairs.Lock()
 	if keypairs.current != nil {
-		atomic.StoreUint64(&keypairs.current.sendNonce, RejectAfterMessages)
+		keypairs.current.sendNonce.Store(RejectAfterMessages)
 	}
-	if keypairs.next != nil {
-		next := keypairs.loadNext()
-		atomic.StoreUint64(&next.sendNonce, RejectAfterMessages)
+	if next := keypairs.next.Load(); next != nil {
+		next.sendNonce.Store(RejectAfterMessages)
 	}
 	keypairs.Unlock()
 }


@@ -14,7 +14,7 @@ type WaitPool struct {
 	pool  sync.Pool
 	cond  sync.Cond
 	lock  sync.Mutex
-	count uint32
+	count atomic.Uint32
 	max   uint32
 }
@@ -27,10 +27,10 @@ func NewWaitPool(max uint32, new func() any) *WaitPool {
 func (p *WaitPool) Get() any {
 	if p.max != 0 {
 		p.lock.Lock()
-		for atomic.LoadUint32(&p.count) >= p.max {
+		for p.count.Load() >= p.max {
 			p.cond.Wait()
 		}
-		atomic.AddUint32(&p.count, 1)
+		p.count.Add(1)
 		p.lock.Unlock()
 	}
 	return p.pool.Get()
@@ -41,7 +41,7 @@ func (p *WaitPool) Put(x any) {
 	if p.max == 0 {
 		return
 	}
-	atomic.AddUint32(&p.count, ^uint32(0))
+	p.count.Add(^uint32(0))
 	p.cond.Signal()
 }
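Note that Put's decrement survives the migration unchanged in spirit: sync/atomic has no subtract operation, so count.Add(^uint32(0)) adds the two's-complement bit pattern of -1, which wraps around to count minus one. A quick check:

package main

import (
	"fmt"
	"sync/atomic"
)

func main() {
	var count atomic.Uint32
	count.Add(3)
	count.Add(^uint32(0))     // ^uint32(0) is all ones, i.e. -1 mod 2^32
	fmt.Println(count.Load()) // prints 2
}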


@@ -17,29 +17,31 @@ import (
 )
 
 func TestWaitPool(t *testing.T) {
 	t.Skip("Currently disabled")
 	var wg sync.WaitGroup
-	trials := int32(100000)
+	var trials atomic.Int32
+	startTrials := int32(100000)
 	if raceEnabled {
 		// This test can be very slow with -race.
-		trials /= 10
+		startTrials /= 10
 	}
+	trials.Store(startTrials)
 	workers := runtime.NumCPU() + 2
 	if workers-4 <= 0 {
 		t.Skip("Not enough cores")
 	}
 	p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
 	wg.Add(workers)
-	max := uint32(0)
+	var max atomic.Uint32
 	updateMax := func() {
-		count := atomic.LoadUint32(&p.count)
+		count := p.count.Load()
 		if count > p.max {
 			t.Errorf("count (%d) > max (%d)", count, p.max)
 		}
 		for {
-			old := atomic.LoadUint32(&max)
+			old := max.Load()
 			if count <= old {
 				break
 			}
-			if atomic.CompareAndSwapUint32(&max, old, count) {
+			if max.CompareAndSwap(old, count) {
 				break
 			}
 		}
@@ -47,7 +49,7 @@ func TestWaitPool(t *testing.T) {
 	for i := 0; i < workers; i++ {
 		go func() {
 			defer wg.Done()
-			for atomic.AddInt32(&trials, -1) > 0 {
+			for trials.Add(-1) > 0 {
 				updateMax()
 				x := p.Get()
 				updateMax()
@@ -59,14 +61,15 @@ func TestWaitPool(t *testing.T) {
 		}()
 	}
 	wg.Wait()
-	if max != p.max {
+	if max.Load() != p.max {
 		t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
 	}
 }
 
 func BenchmarkWaitPool(b *testing.B) {
 	var wg sync.WaitGroup
-	trials := int32(b.N)
+	var trials atomic.Int32
+	trials.Store(int32(b.N))
 	workers := runtime.NumCPU() + 2
 	if workers-4 <= 0 {
 		b.Skip("Not enough cores")
@@ -77,7 +80,7 @@ func BenchmarkWaitPool(b *testing.B) {
 	for i := 0; i < workers; i++ {
 		go func() {
 			defer wg.Done()
-			for atomic.AddInt32(&trials, -1) > 0 {
+			for trials.Add(-1) > 0 {
 				x := p.Get()
 				time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
 				p.Put(x)


@@ -11,7 +11,6 @@ import (
 	"errors"
 	"net"
 	"sync"
-	"sync/atomic"
 	"time"
 
 	"golang.org/x/crypto/chacha20poly1305"
@@ -52,12 +51,12 @@ func (elem *QueueInboundElement) clearPointers() {
  * NOTE: Not thread safe, but called by sequential receiver!
  */
 func (peer *Peer) keepKeyFreshReceiving() {
-	if peer.timers.sentLastMinuteHandshake.Get() {
+	if peer.timers.sentLastMinuteHandshake.Load() {
 		return
 	}
 	keypair := peer.keypairs.Current()
 	if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
-		peer.timers.sentLastMinuteHandshake.Set(true)
+		peer.timers.sentLastMinuteHandshake.Store(true)
 		peer.SendHandshakeInitiation(false)
 	}
 }
@@ -163,7 +162,7 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
 			elem.Lock()
 
 			// add to decryption queues
-			if peer.isRunning.Get() {
+			if peer.isRunning.Load() {
 				peer.queue.inbound.c <- elem
 				device.queue.decryption.c <- elem
 				buffer = device.GetMessageBuffer()
@@ -268,7 +267,7 @@ func (device *Device) RoutineHandshake(id int) {
 
 			// consume reply
 
-			if peer := entry.peer; peer.isRunning.Get() {
+			if peer := entry.peer; peer.isRunning.Load() {
 				device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
 				if !peer.cookieGenerator.ConsumeReply(&reply) {
 					device.log.Verbosef("Could not decrypt invalid cookie response")
@@ -341,7 +340,7 @@ func (device *Device) RoutineHandshake(id int) {
 			peer.SetEndpointFromPacket(elem.endpoint)
 
 			device.log.Verbosef("%v - Received handshake initiation", peer)
-			atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
+			peer.rxBytes.Add(uint64(len(elem.packet)))
 
 			peer.SendHandshakeResponse()
@@ -369,7 +368,7 @@ func (device *Device) RoutineHandshake(id int) {
 			peer.SetEndpointFromPacket(elem.endpoint)
 
 			device.log.Verbosef("%v - Received handshake response", peer)
-			atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
+			peer.rxBytes.Add(uint64(len(elem.packet)))
 
 			// update timers
@@ -426,7 +425,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
 		peer.keepKeyFreshReceiving()
 		peer.timersAnyAuthenticatedPacketTraversal()
 		peer.timersAnyAuthenticatedPacketReceived()
-		atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize))
+		peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
 
 		if len(elem.packet) == 0 {
 			device.log.Verbosef("%v - Receiving keepalive packet", peer)


@@ -12,7 +12,6 @@ import (
 	"net"
 	"os"
 	"sync"
-	"sync/atomic"
 	"time"
 
 	"golang.org/x/crypto/chacha20poly1305"
@@ -76,7 +75,7 @@ func (elem *QueueOutboundElement) clearPointers() {
 
 /* Queues a keepalive if no packets are queued for peer
  */
 func (peer *Peer) SendKeepalive() {
-	if len(peer.queue.staged) == 0 && peer.isRunning.Get() {
+	if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
 		elem := peer.device.NewOutboundElement()
 		select {
 		case peer.queue.staged <- elem:
@@ -91,7 +90,7 @@ func (peer *Peer) SendKeepalive() {
 
 func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
 	if !isRetry {
-		atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
+		peer.timers.handshakeAttempts.Store(0)
 	}
 
 	peer.handshake.mutex.RLock()
@@ -193,7 +192,7 @@ func (peer *Peer) keepKeyFreshSending() {
 	if keypair == nil {
 		return
 	}
-	nonce := atomic.LoadUint64(&keypair.sendNonce)
+	nonce := keypair.sendNonce.Load()
 	if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
 		peer.SendHandshakeInitiation(false)
 	}
@@ -269,7 +268,7 @@ func (device *Device) RoutineReadFromTUN() {
 		if peer == nil {
 			continue
 		}
-		if peer.isRunning.Get() {
+		if peer.isRunning.Load() {
 			peer.StagePacket(elem)
 			elem = nil
 			peer.SendStagedPackets()
@@ -300,7 +299,7 @@ top:
 	}
 
 	keypair := peer.keypairs.Current()
-	if keypair == nil || atomic.LoadUint64(&keypair.sendNonce) >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
+	if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
 		peer.SendHandshakeInitiation(false)
 		return
 	}
@@ -309,9 +308,9 @@ top:
 		select {
 		case elem := <-peer.queue.staged:
 			elem.peer = peer
-			elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1
+			elem.nonce = keypair.sendNonce.Add(1) - 1
 			if elem.nonce >= RejectAfterMessages {
-				atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages)
+				keypair.sendNonce.Store(RejectAfterMessages)
 				peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans
 				goto top
 			}
@@ -320,7 +319,7 @@ top:
 			elem.Lock()
 
 			// add to parallel and sequential queue
-			if peer.isRunning.Get() {
+			if peer.isRunning.Load() {
 				peer.queue.outbound.c <- elem
 				peer.device.queue.encryption.c <- elem
 			} else {
@@ -385,7 +384,7 @@ func (device *Device) RoutineEncryption(id int) {
 			binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
 
 			// pad content to multiple of 16
-			paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu)))
+			paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
 			elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
 
 			// encrypt content and release to consumer
@@ -419,7 +418,7 @@ func (peer *Peer) RoutineSequentialSender() {
 			return
 		}
 		elem.Lock()
-		if !peer.isRunning.Get() {
+		if !peer.isRunning.Load() {
 			// peer has been stopped; return re-usable elems to the shared pool.
 			// This is an optimization only. It is possible for the peer to be stopped
 			// immediately after this check, in which case, elem will get processed.
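The nonce reservation above relies on Add returning the value after the increment, so keypair.sendNonce.Add(1) - 1 is a fetch-and-increment: each concurrent sender reserves a distinct nonce without locking. In isolation:

package main

import (
	"fmt"
	"sync/atomic"
)

func main() {
	var sendNonce atomic.Uint64

	// Add returns the new value, so subtracting 1 yields the slot this
	// caller reserved; a racing goroutine is handed the next slot.
	fmt.Println(sendNonce.Add(1) - 1) // prints 0
	fmt.Println(sendNonce.Add(1) - 1) // prints 1
}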


@@ -9,7 +9,6 @@ package device
 
 import (
 	"sync"
-	"sync/atomic"
 	"time"
 	_ "unsafe"
 )
@@ -74,11 +73,11 @@ func (timer *Timer) IsPending() bool {
 }
 
 func (peer *Peer) timersActive() bool {
-	return peer.isRunning.Get() && peer.device != nil && peer.device.isUp()
+	return peer.isRunning.Load() && peer.device != nil && peer.device.isUp()
 }
 
 func expiredRetransmitHandshake(peer *Peer) {
-	if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes {
+	if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes {
 		peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)
 
 		if peer.timersActive() {
@@ -97,8 +96,8 @@ func expiredRetransmitHandshake(peer *Peer) {
 			peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
 		}
 	} else {
-		atomic.AddUint32(&peer.timers.handshakeAttempts, 1)
-		peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1)
+		peer.timers.handshakeAttempts.Add(1)
+		peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
 
 		/* We clear the endpoint address src address, in case this is the cause of trouble. */
 		peer.Lock()
@@ -113,8 +112,8 @@ func expiredRetransmitHandshake(peer *Peer) {
 
 func expiredSendKeepalive(peer *Peer) {
 	peer.SendKeepalive()
-	if peer.timers.needAnotherKeepalive.Get() {
-		peer.timers.needAnotherKeepalive.Set(false)
+	if peer.timers.needAnotherKeepalive.Load() {
+		peer.timers.needAnotherKeepalive.Store(false)
 		if peer.timersActive() {
 			peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
 		}
@@ -138,7 +137,7 @@ func expiredZeroKeyMaterial(peer *Peer) {
 }
 
 func expiredPersistentKeepalive(peer *Peer) {
-	if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
+	if peer.persistentKeepaliveInterval.Load() > 0 {
 		peer.SendKeepalive()
 	}
 }
@@ -156,7 +155,7 @@ func (peer *Peer) timersDataReceived() {
 		if !peer.timers.sendKeepalive.IsPending() {
 			peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
 		} else {
-			peer.timers.needAnotherKeepalive.Set(true)
+			peer.timers.needAnotherKeepalive.Store(true)
 		}
 	}
 }
@@ -187,9 +186,9 @@ func (peer *Peer) timersHandshakeComplete() {
 	if peer.timersActive() {
 		peer.timers.retransmitHandshake.Del()
 	}
-	atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
-	peer.timers.sentLastMinuteHandshake.Set(false)
-	atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano())
+	peer.timers.handshakeAttempts.Store(0)
+	peer.timers.sentLastMinuteHandshake.Store(false)
+	peer.lastHandshakeNano.Store(time.Now().UnixNano())
 }
 
 /* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
@@ -201,7 +200,7 @@ func (peer *Peer) timersSessionDerived() {
 
 /* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
 func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
-	keepalive := atomic.LoadUint32(&peer.persistentKeepaliveInterval)
+	keepalive := peer.persistentKeepaliveInterval.Load()
 	if keepalive > 0 && peer.timersActive() {
 		peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
 	}
@@ -216,9 +215,9 @@ func (peer *Peer) timersInit() {
 }
 
 func (peer *Peer) timersStart() {
-	atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
-	peer.timers.sentLastMinuteHandshake.Set(false)
-	peer.timers.needAnotherKeepalive.Set(false)
+	peer.timers.handshakeAttempts.Store(0)
+	peer.timers.sentLastMinuteHandshake.Store(false)
+	peer.timers.needAnotherKeepalive.Store(false)
 }
 
 func (peer *Peer) timersStop() {


@@ -7,7 +7,6 @@ package device
 
 import (
 	"fmt"
-	"sync/atomic"
 
 	"golang.zx2c4.com/wireguard/tun"
 )
@@ -33,7 +32,7 @@ func (device *Device) RoutineTUNEventReader() {
 				tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize)
 				mtu = MaxContentSize
 			}
-			old := atomic.SwapInt32(&device.tun.mtu, int32(mtu))
+			old := device.tun.mtu.Swap(int32(mtu))
 			if int(old) != mtu {
 				device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge)
 			}


@@ -16,7 +16,6 @@ import (
 	"strconv"
 	"strings"
 	"sync"
-	"sync/atomic"
 	"time"
 
 	"golang.zx2c4.com/wireguard/ipc"
@@ -112,15 +111,15 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
 				sendf("endpoint=%s", peer.endpoint.DstToString())
 			}
 
-			nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
+			nano := peer.lastHandshakeNano.Load()
 			secs := nano / time.Second.Nanoseconds()
 			nano %= time.Second.Nanoseconds()
 
 			sendf("last_handshake_time_sec=%d", secs)
 			sendf("last_handshake_time_nsec=%d", nano)
-			sendf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes))
-			sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
-			sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
+			sendf("tx_bytes=%d", peer.txBytes.Load())
+			sendf("rx_bytes=%d", peer.rxBytes.Load())
+			sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
 
 			device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
 				sendf("allowed_ip=%s", prefix.String())
@@ -358,7 +357,7 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
 			return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
 		}
 
-		old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs))
+		old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
 
 		// Send immediate keepalive if we're turning it on and before it wasn't on.
 		peer.pkaOn = old == 0 && secs != 0

go.mod

@@ -1,6 +1,6 @@
 module golang.zx2c4.com/wireguard
 
-go 1.18
+go 1.19
 
 require (
 	golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd
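This one-line bump is what the rest of the commit depends on: every type used above first shipped in Go 1.19's sync/atomic. A compile-time check (builds with go >= 1.19, fails with 1.18):

package main

import "sync/atomic"

var (
	_ atomic.Bool
	_ atomic.Int32
	_ atomic.Int64
	_ atomic.Uint32
	_ atomic.Uint64
	_ atomic.Pointer[struct{}]
)

func main() {}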


@@ -54,7 +54,7 @@ type file struct {
 	handle        windows.Handle
 	wg            sync.WaitGroup
 	wgLock        sync.RWMutex
-	closing       uint32 // used as atomic boolean
+	closing       atomic.Bool
 	socket        bool
 	readDeadline  deadlineHandler
 	writeDeadline deadlineHandler
@@ -65,7 +65,7 @@ type deadlineHandler struct {
 	channel     timeoutChan
 	channelLock sync.RWMutex
 	timer       *time.Timer
-	timedout    uint32 // used as atomic boolean
+	timedout    atomic.Bool
 }
 
 // makeFile makes a new file from an existing file handle
@@ -89,7 +89,7 @@ func makeFile(h windows.Handle) (*file, error) {
 func (f *file) closeHandle() {
 	f.wgLock.Lock()
 	// Atomically set that we are closing, releasing the resources only once.
-	if atomic.SwapUint32(&f.closing, 1) == 0 {
+	if f.closing.Swap(true) == false {
 		f.wgLock.Unlock()
 		// cancel all IO and wait for it to complete
 		windows.CancelIoEx(f.handle, nil)
@@ -112,7 +112,7 @@ func (f *file) Close() error {
 // The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
 func (f *file) prepareIo() (*ioOperation, error) {
 	f.wgLock.RLock()
-	if atomic.LoadUint32(&f.closing) == 1 {
+	if f.closing.Load() {
 		f.wgLock.RUnlock()
 		return nil, os.ErrClosed
 	}
@@ -144,7 +144,7 @@ func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err err
 		return int(bytes), err
 	}
 
-	if atomic.LoadUint32(&f.closing) == 1 {
+	if f.closing.Load() {
 		windows.CancelIoEx(f.handle, &c.o)
 	}
@@ -160,7 +160,7 @@ func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err err
 	case r = <-c.ch:
 		err = r.err
 		if err == windows.ERROR_OPERATION_ABORTED {
-			if atomic.LoadUint32(&f.closing) == 1 {
+			if f.closing.Load() {
 				err = os.ErrClosed
 			}
 		} else if err != nil && f.socket {
@@ -192,7 +192,7 @@ func (f *file) Read(b []byte) (int, error) {
 	}
 	defer f.wg.Done()
 
-	if atomic.LoadUint32(&f.readDeadline.timedout) == 1 {
+	if f.readDeadline.timedout.Load() {
 		return 0, os.ErrDeadlineExceeded
 	}
@@ -219,7 +219,7 @@ func (f *file) Write(b []byte) (int, error) {
 	}
 	defer f.wg.Done()
 
-	if atomic.LoadUint32(&f.writeDeadline.timedout) == 1 {
+	if f.writeDeadline.timedout.Load() {
 		return 0, os.ErrDeadlineExceeded
 	}
@@ -256,7 +256,7 @@ func (d *deadlineHandler) set(deadline time.Time) error {
 		}
 		d.timer = nil
 	}
-	atomic.StoreUint32(&d.timedout, 0)
+	d.timedout.Store(false)
 
 	select {
 	case <-d.channel:
@@ -271,7 +271,7 @@ func (d *deadlineHandler) set(deadline time.Time) error {
 	}
 
 	timeoutIO := func() {
-		atomic.StoreUint32(&d.timedout, 1)
+		d.timedout.Store(true)
 		close(d.channel)
 	}


@@ -29,7 +29,7 @@ type pipe struct {
 
 type messageBytePipe struct {
 	pipe
-	writeClosed int32
+	writeClosed atomic.Bool
 	readEOF     bool
 }
@@ -51,17 +51,17 @@ func (f *pipe) SetDeadline(t time.Time) error {
 
 // CloseWrite closes the write side of a message pipe in byte mode.
 func (f *messageBytePipe) CloseWrite() error {
-	if !atomic.CompareAndSwapInt32(&f.writeClosed, 0, 1) {
+	if !f.writeClosed.CompareAndSwap(false, true) {
 		return io.ErrClosedPipe
 	}
 	err := f.file.Flush()
 	if err != nil {
-		atomic.StoreInt32(&f.writeClosed, 0)
+		f.writeClosed.Store(false)
 		return err
 	}
 	_, err = f.file.Write(nil)
 	if err != nil {
-		atomic.StoreInt32(&f.writeClosed, 0)
+		f.writeClosed.Store(false)
 		return err
 	}
 	return nil
@@ -70,7 +70,7 @@ func (f *messageBytePipe) CloseWrite() error {
 // Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
 // they are used to implement CloseWrite.
 func (f *messageBytePipe) Write(b []byte) (int, error) {
-	if atomic.LoadInt32(&f.writeClosed) != 0 {
+	if f.writeClosed.Load() {
 		return 0, io.ErrClosedPipe
 	}
 	if len(b) == 0 {


@@ -26,10 +26,10 @@ const (
 )
 
 type rateJuggler struct {
-	current       uint64
-	nextByteCount uint64
-	nextStartTime int64
-	changing      int32
+	current       atomic.Uint64
+	nextByteCount atomic.Uint64
+	nextStartTime atomic.Int64
+	changing      atomic.Bool
 }
 
 type NativeTun struct {
@@ -42,7 +42,7 @@ type NativeTun struct {
 	events    chan Event
 	running   sync.WaitGroup
 	closeOnce sync.Once
-	close     int32
+	close     atomic.Bool
 	forcedMTU int
 }
@@ -57,18 +57,14 @@ func procyield(cycles uint32)
 //go:linkname nanotime runtime.nanotime
 func nanotime() int64
 
-//
 // CreateTUN creates a Wintun interface with the given name. Should a Wintun
 // interface with the same name exist, it is reused.
-//
 func CreateTUN(ifname string, mtu int) (Device, error) {
 	return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
 }
 
-//
 // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
 // a requested GUID. Should a Wintun interface with the same name exist, it is reused.
-//
 func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
 	wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
 	if err != nil {
@@ -113,7 +109,7 @@ func (tun *NativeTun) Events() chan Event {
 func (tun *NativeTun) Close() error {
 	var err error
 	tun.closeOnce.Do(func() {
-		atomic.StoreInt32(&tun.close, 1)
+		tun.close.Store(true)
 		windows.SetEvent(tun.readWait)
 		tun.running.Wait()
 		tun.session.End()
@@ -144,13 +140,13 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
 	tun.running.Add(1)
 	defer tun.running.Done()
 retry:
-	if atomic.LoadInt32(&tun.close) == 1 {
+	if tun.close.Load() {
 		return 0, os.ErrClosed
 	}
 	start := nanotime()
-	shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
+	shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
 	for {
-		if atomic.LoadInt32(&tun.close) == 1 {
+		if tun.close.Load() {
 			return 0, os.ErrClosed
 		}
 		packet, err := tun.session.ReceivePacket()
@@ -184,7 +180,7 @@ func (tun *NativeTun) Flush() error {
 func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
 	tun.running.Add(1)
 	defer tun.running.Done()
-	if atomic.LoadInt32(&tun.close) == 1 {
+	if tun.close.Load() {
 		return 0, os.ErrClosed
 	}
@@ -210,7 +206,7 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
 func (tun *NativeTun) LUID() uint64 {
 	tun.running.Add(1)
 	defer tun.running.Done()
-	if atomic.LoadInt32(&tun.close) == 1 {
+	if tun.close.Load() {
 		return 0
 	}
 	return tun.wt.LUID()
@@ -223,15 +219,15 @@ func (tun *NativeTun) RunningVersion() (version uint32, err error) {
 
 func (rate *rateJuggler) update(packetLen uint64) {
 	now := nanotime()
-	total := atomic.AddUint64(&rate.nextByteCount, packetLen)
-	period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
+	total := rate.nextByteCount.Add(packetLen)
+	period := uint64(now - rate.nextStartTime.Load())
 	if period >= rateMeasurementGranularity {
-		if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
+		if !rate.changing.CompareAndSwap(false, true) {
 			return
 		}
-		atomic.StoreInt64(&rate.nextStartTime, now)
-		atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
-		atomic.StoreUint64(&rate.nextByteCount, 0)
-		atomic.StoreInt32(&rate.changing, 0)
+		rate.nextStartTime.Store(now)
+		rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
+		rate.nextByteCount.Store(0)
+		rate.changing.Store(false)
 	}
 }
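rateJuggler.update keeps its lock-free shape after the migration: the CompareAndSwap(false, true) on changing acts as a try-lock, so at most one caller rolls the measurement window while the rest return immediately. The same idiom in isolation (a sketch, not the tun code):

package main

import (
	"fmt"
	"sync/atomic"
)

var changing atomic.Bool

// tryUpdate runs the critical section only if no other goroutine is
// already inside it; losers skip, which is fine for periodic bookkeeping.
func tryUpdate() bool {
	if !changing.CompareAndSwap(false, true) {
		return false // another goroutine won the race
	}
	defer changing.Store(false)
	// ... roll the measurement window here ...
	return true
}

func main() {
	fmt.Println(tryUpdate()) // prints true
}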