From 01786286c1dbda06cf74aa4b389a9f43fcbfb644 Mon Sep 17 00:00:00 2001
From: "Jason A. Donenfeld"
Date: Sun, 18 Aug 2019 11:49:37 +0200
Subject: [PATCH] tun: windows: don't spin unless we really need it

---
 tun/tun_windows.go | 52 ++++++++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 42 insertions(+), 10 deletions(-)

diff --git a/tun/tun_windows.go b/tun/tun_windows.go
index b0faed8..a66f709 100644
--- a/tun/tun_windows.go
+++ b/tun/tun_windows.go
@@ -19,11 +19,14 @@ import (
 )
 
 const (
-	packetAlignment    uint32 = 4        // Number of bytes packets are aligned to in rings
-	packetSizeMax             = 0xffff   // Maximum packet size
-	packetCapacity            = 0x800000 // Ring capacity, 8MiB
-	packetTrailingSize        = uint32(unsafe.Sizeof(packetHeader{})) + ((packetSizeMax + (packetAlignment - 1)) &^ (packetAlignment - 1)) - packetAlignment
-	ioctlRegisterRings        = (51820 << 16) | (0x970 << 2) | 0 /*METHOD_BUFFERED*/ | (0x3 /*FILE_READ_DATA | FILE_WRITE_DATA*/ << 14)
+	packetAlignment            = 4        // Number of bytes packets are aligned to in rings
+	packetSizeMax              = 0xffff   // Maximum packet size
+	packetCapacity             = 0x800000 // Ring capacity, 8MiB
+	packetTrailingSize         = uint32(unsafe.Sizeof(packetHeader{})) + ((packetSizeMax + (packetAlignment - 1)) &^ (packetAlignment - 1)) - packetAlignment
+	ioctlRegisterRings         = (51820 << 16) | (0x970 << 2) | 0 /*METHOD_BUFFERED*/ | (0x3 /*FILE_READ_DATA | FILE_WRITE_DATA*/ << 14)
+	rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond)
+	spinloopRateThreshold      = 800000000 / 8                                   // 800mbps
+	spinloopDuration           = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s
 )
 
 type packetHeader struct {
@@ -50,6 +53,13 @@ type ringDescriptor struct {
 	}
 }
 
+type rateJuggler struct {
+	current       uint64
+	nextByteCount uint64
+	nextStartTime int64
+	changing      int32
+}
+
 type NativeTun struct {
 	wt        *wintun.Wintun
 	handle    windows.Handle
@@ -58,8 +68,15 @@ type NativeTun struct {
 	events    chan Event
 	errors    chan error
 	forcedMTU int
+	rate      rateJuggler
 }
 
+//go:linkname procyield runtime.procyield
+func procyield(cycles uint32)
+
+//go:linkname nanotime runtime.nanotime
+func nanotime() int64
+
 func packetAlign(size uint32) uint32 {
 	return (size + (packetAlignment - 1)) &^ (packetAlignment - 1)
 }
@@ -184,9 +201,6 @@ func (tun *NativeTun) ForceMTU(mtu int) {
 	tun.forcedMTU = mtu
 }
 
-//go:linkname procyield runtime.procyield
-func procyield(cycles uint32)
-
 // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
 
 func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
@@ -205,7 +219,8 @@ retry:
 		return 0, os.ErrClosed
 	}
 
-	start := time.Now()
+	start := nanotime()
+	shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
 	var buffTail uint32
 	for {
 		buffTail = atomic.LoadUint32(&tun.rings.send.ring.tail)
@@ -215,7 +230,7 @@ retry:
 		if tun.close {
 			return 0, os.ErrClosed
 		}
-		if time.Since(start) >= time.Millisecond/80 /* ~1gbit/s */ {
+		if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
 			windows.WaitForSingleObject(tun.rings.send.tailMoved, windows.INFINITE)
 			goto retry
 		}
@@ -243,6 +258,7 @@ retry:
 	copy(buff[offset:], packet.data[:packet.size])
 	buffHead = tun.rings.send.ring.wrap(buffHead + alignedPacketSize)
 	atomic.StoreUint32(&tun.rings.send.ring.head, buffHead)
+	tun.rate.update(uint64(packet.size))
 	return int(packet.size), nil
 }
 
@@ -256,6 +272,7 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
 	}
 
 	packetSize := uint32(len(buff) - offset)
+	tun.rate.update(uint64(packetSize))
 	alignedPacketSize := packetAlign(uint32(unsafe.Sizeof(packetHeader{})) + packetSize)
 
 	buffHead := atomic.LoadUint32(&tun.rings.receive.ring.head)
@@ -292,3 +309,18 @@ func (tun *NativeTun) LUID() uint64 {
 func (rb *ring) wrap(value uint32) uint32 {
 	return value & (packetCapacity - 1)
 }
+
+func (rate *rateJuggler) update(packetLen uint64) {
+	now := nanotime()
+	total := atomic.AddUint64(&rate.nextByteCount, packetLen)
+	period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
+	if period >= rateMeasurementGranularity {
+		if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
+			return
+		}
+		atomic.StoreInt64(&rate.nextStartTime, now)
+		atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
+		atomic.StoreUint64(&rate.nextByteCount, 0)
+		atomic.StoreInt32(&rate.changing, 0)
+	}
+}
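The heart of the change is the gate at the top of Read(): the spinloop now runs only when measured throughput over the last couple of rate windows was at least spinloopRateThreshold (100,000,000 bytes/sec, i.e. 800mbps), and even then only for spinloopDuration (~12.5us) before falling back to WaitForSingleObject. Here is a minimal, standalone sketch of that spin-vs-park decision. The tail, rate, tailMoved, and waitForTail names are stand-ins invented for illustration, not part of the patch, and it substitutes time.Now().UnixNano() and runtime.Gosched() for the runtime.nanotime and runtime.procyield linknames the patch uses:

package main

import (
	"fmt"
	"runtime"
	"sync/atomic"
	"time"
)

const (
	spinloopRateThreshold = 800000000 / 8                 // 800mbps, in bytes per second
	spinloopDuration      = uint64(time.Millisecond / 80) // ~12.5us spin budget, in ns
)

var (
	tail      uint32                   // stand-in for the send ring tail the driver advances
	rate      uint64                   // stand-in for rateJuggler.current
	tailMoved = make(chan struct{}, 1) // stand-in for the tailMoved event handle
)

// waitForTail returns once tail has advanced past head. It spins only while
// the recently measured rate makes another packet likely to land within the
// spin budget; otherwise it parks immediately instead of burning a CPU.
func waitForTail(head uint32) {
retry:
	start := time.Now().UnixNano() // the patch uses runtime.nanotime via go:linkname
	shouldSpin := atomic.LoadUint64(&rate) >= spinloopRateThreshold
	for atomic.LoadUint32(&tail) == head {
		if !shouldSpin || uint64(time.Now().UnixNano()-start) >= spinloopDuration {
			<-tailMoved // park, like WaitForSingleObject(tun.rings.send.tailMoved, INFINITE)
			goto retry
		}
		runtime.Gosched() // the patch spins with procyield(1) instead
	}
}

func main() {
	go func() {
		time.Sleep(10 * time.Millisecond)
		atomic.AddUint32(&tail, 1) // producer publishes a packet...
		tailMoved <- struct{}{}    // ...and signals the event
	}()
	waitForTail(0)
	fmt.Println("tail moved; with the rate below threshold we parked instead of spinning")
}

The reason for gating on the rate is that spinning only pays for its burned CPU when the next packet is likely to arrive within the spin budget; the pre-patch code spent up to the full ~12.5us spinning on every read, even on an idle tunnel.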
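rateJuggler.update() is what feeds that gate: a lock-free bytes-per-second estimator shared by Read() and Write(). Below is a self-contained approximation of the same technique, again assuming time.Now().UnixNano() in place of the runtime.nanotime linkname; the half-second window matches the patch's rateMeasurementGranularity:

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

const rateMeasurementGranularity = uint64(time.Second / 2) // 500ms window, in ns

// rateJuggler mirrors the struct in the patch: writers race freely on
// nextByteCount, and a CAS on changing elects exactly one caller to publish
// the finished window and start the next one.
type rateJuggler struct {
	current       uint64 // most recently published rate, bytes per second
	nextByteCount uint64 // bytes accumulated in the window in progress
	nextStartTime int64  // window start in ns (wall clock here, runtime.nanotime in the patch)
	changing      int32  // 0/1 guard around the window rollover
}

func (rate *rateJuggler) update(packetLen uint64) {
	now := time.Now().UnixNano()
	total := atomic.AddUint64(&rate.nextByteCount, packetLen)
	period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
	if period < rateMeasurementGranularity {
		return // window still open; just accumulate
	}
	if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
		return // somebody else is already rolling the window over
	}
	atomic.StoreInt64(&rate.nextStartTime, now)
	atomic.StoreUint64(&rate.current, total*uint64(time.Second)/period) // bytes/sec
	atomic.StoreUint64(&rate.nextByteCount, 0)
	atomic.StoreInt32(&rate.changing, 0)
}

func main() {
	var r rateJuggler
	atomic.StoreInt64(&r.nextStartTime, time.Now().UnixNano())
	for i := 0; i < 800; i++ { // simulate ~1s of 1500-byte packets, one per ms
		r.update(1500)
		time.Sleep(time.Millisecond)
	}
	fmt.Printf("estimated rate: %d bytes/sec\n", atomic.LoadUint64(&r.current))
}

The CAS on changing keeps concurrent callers from ever blocking: whoever wins publishes the window and losers simply return, so a few bytes of accounting can slip between computing total and zeroing nextByteCount. That slop is the price of staying lock-free on the packet hot path, and for a heuristic throughput estimate it is acceptable.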