From 895d6c23cd60bd0522c5b6598a69ad6c5f1ab3a7 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Mon, 2 Oct 2023 14:43:56 -0700 Subject: [PATCH] tun: unwind summing loop in checksumNoFold() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit $ benchstat old.txt new.txt goos: linux goarch: amd64 pkg: golang.zx2c4.com/wireguard/tun cpu: 12th Gen Intel(R) Core(TM) i5-12400 │ old.txt │ new.txt │ │ sec/op │ sec/op vs base │ Checksum/64-12 10.670n ± 2% 4.769n ± 0% -55.30% (p=0.000 n=10) Checksum/128-12 19.665n ± 2% 8.032n ± 0% -59.16% (p=0.000 n=10) Checksum/256-12 37.68n ± 1% 16.06n ± 0% -57.37% (p=0.000 n=10) Checksum/512-12 76.61n ± 3% 32.13n ± 0% -58.06% (p=0.000 n=10) Checksum/1024-12 160.55n ± 4% 64.25n ± 0% -59.98% (p=0.000 n=10) Checksum/1500-12 231.05n ± 7% 94.12n ± 0% -59.26% (p=0.000 n=10) Checksum/2048-12 309.5n ± 3% 128.5n ± 0% -58.48% (p=0.000 n=10) Checksum/4096-12 603.8n ± 4% 257.2n ± 0% -57.41% (p=0.000 n=10) Checksum/8192-12 1185.0n ± 3% 515.5n ± 0% -56.50% (p=0.000 n=10) Checksum/9000-12 1328.5n ± 5% 564.8n ± 0% -57.49% (p=0.000 n=10) Checksum/9001-12 1340.5n ± 3% 564.8n ± 0% -57.87% (p=0.000 n=10) geomean 185.3n 77.99n -57.92% Reviewed-by: Adrian Dewhurst Signed-off-by: Jordan Whited Signed-off-by: Jason A. Donenfeld --- tun/checksum.go | 100 +++++++++++++++++++++++++++++++++++++------ tun/checksum_test.go | 35 +++++++++++++++ 2 files changed, 123 insertions(+), 12 deletions(-) create mode 100644 tun/checksum_test.go diff --git a/tun/checksum.go b/tun/checksum.go index f4f8471..29a8fc8 100644 --- a/tun/checksum.go +++ b/tun/checksum.go @@ -3,23 +3,99 @@ package tun import "encoding/binary" // TODO: Explore SIMD and/or other assembly optimizations. +// TODO: Test native endian loads. See RFC 1071 section 2 part B. func checksumNoFold(b []byte, initial uint64) uint64 { ac := initial - i := 0 - n := len(b) - for n >= 4 { - ac += uint64(binary.BigEndian.Uint32(b[i : i+4])) - n -= 4 - i += 4 + + for len(b) >= 128 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac += uint64(binary.BigEndian.Uint32(b[32:36])) + ac += uint64(binary.BigEndian.Uint32(b[36:40])) + ac += uint64(binary.BigEndian.Uint32(b[40:44])) + ac += uint64(binary.BigEndian.Uint32(b[44:48])) + ac += uint64(binary.BigEndian.Uint32(b[48:52])) + ac += uint64(binary.BigEndian.Uint32(b[52:56])) + ac += uint64(binary.BigEndian.Uint32(b[56:60])) + ac += uint64(binary.BigEndian.Uint32(b[60:64])) + ac += uint64(binary.BigEndian.Uint32(b[64:68])) + ac += uint64(binary.BigEndian.Uint32(b[68:72])) + ac += uint64(binary.BigEndian.Uint32(b[72:76])) + ac += uint64(binary.BigEndian.Uint32(b[76:80])) + ac += uint64(binary.BigEndian.Uint32(b[80:84])) + ac += uint64(binary.BigEndian.Uint32(b[84:88])) + ac += uint64(binary.BigEndian.Uint32(b[88:92])) + ac += uint64(binary.BigEndian.Uint32(b[92:96])) + ac += uint64(binary.BigEndian.Uint32(b[96:100])) + ac += uint64(binary.BigEndian.Uint32(b[100:104])) + ac += uint64(binary.BigEndian.Uint32(b[104:108])) + ac += uint64(binary.BigEndian.Uint32(b[108:112])) + ac += uint64(binary.BigEndian.Uint32(b[112:116])) + ac += uint64(binary.BigEndian.Uint32(b[116:120])) + ac += uint64(binary.BigEndian.Uint32(b[120:124])) + ac += uint64(binary.BigEndian.Uint32(b[124:128])) + b = b[128:] } - for n >= 2 { - ac += uint64(binary.BigEndian.Uint16(b[i : i+2])) - n -= 2 - i += 2 + if len(b) >= 64 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac += uint64(binary.BigEndian.Uint32(b[32:36])) + ac += uint64(binary.BigEndian.Uint32(b[36:40])) + ac += uint64(binary.BigEndian.Uint32(b[40:44])) + ac += uint64(binary.BigEndian.Uint32(b[44:48])) + ac += uint64(binary.BigEndian.Uint32(b[48:52])) + ac += uint64(binary.BigEndian.Uint32(b[52:56])) + ac += uint64(binary.BigEndian.Uint32(b[56:60])) + ac += uint64(binary.BigEndian.Uint32(b[60:64])) + b = b[64:] } - if n == 1 { - ac += uint64(b[i]) << 8 + if len(b) >= 32 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + b = b[32:] } + if len(b) >= 16 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + b = b[16:] + } + if len(b) >= 8 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + b = b[8:] + } + if len(b) >= 4 { + ac += uint64(binary.BigEndian.Uint32(b)) + b = b[4:] + } + if len(b) >= 2 { + ac += uint64(binary.BigEndian.Uint16(b)) + b = b[2:] + } + if len(b) == 1 { + ac += uint64(b[0]) << 8 + } + return ac } diff --git a/tun/checksum_test.go b/tun/checksum_test.go new file mode 100644 index 0000000..c1ccff5 --- /dev/null +++ b/tun/checksum_test.go @@ -0,0 +1,35 @@ +package tun + +import ( + "fmt" + "math/rand" + "testing" +) + +func BenchmarkChecksum(b *testing.B) { + lengths := []int{ + 64, + 128, + 256, + 512, + 1024, + 1500, + 2048, + 4096, + 8192, + 9000, + 9001, + } + + for _, length := range lengths { + b.Run(fmt.Sprintf("%d", length), func(b *testing.B) { + buf := make([]byte, length) + rng := rand.New(rand.NewSource(1)) + rng.Read(buf) + b.ResetTimer() + for i := 0; i < b.N; i++ { + checksum(buf, 0) + } + }) + } +}