From 3c11c0308e4e9fae76e1531f4f49a39f1ae24253 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Mon, 22 Feb 2021 18:47:41 +0100 Subject: [PATCH] conn: implement RIO for fast Windows UDP sockets Signed-off-by: Jason A. Donenfeld --- conn/bind_std.go | 2 + conn/bind_windows.go | 581 +++++++++++++++++++++++++++++++ conn/boundif_windows.go | 59 ---- conn/default.go | 2 +- conn/winrio/rio_windows.go | 243 +++++++++++++ device/queueconstants_default.go | 2 +- device/queueconstants_windows.go | 15 + go.mod | 6 +- go.sum | 13 +- 9 files changed, 852 insertions(+), 71 deletions(-) create mode 100644 conn/bind_windows.go delete mode 100644 conn/boundif_windows.go create mode 100644 conn/winrio/rio_windows.go create mode 100644 device/queueconstants_windows.go diff --git a/conn/bind_std.go b/conn/bind_std.go index 193c4fe..28d1464 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -128,6 +128,8 @@ func (bind *StdNetBind) Close() error { err2 = bind.ipv6.Close() bind.ipv6 = nil } + bind.blackhole4 = false + bind.blackhole6 = false if err1 != nil { return err1 } diff --git a/conn/bind_windows.go b/conn/bind_windows.go new file mode 100644 index 0000000..1e2712e --- /dev/null +++ b/conn/bind_windows.go @@ -0,0 +1,581 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "encoding/binary" + "io" + "net" + "strconv" + "sync" + "sync/atomic" + "unsafe" + + "golang.org/x/sys/windows" + + "golang.zx2c4.com/wireguard/conn/winrio" +) + +const ( + packetsPerRing = 1024 + bytesPerPacket = 2048 - 32 + receiveSpins = 15 +) + +type ringPacket struct { + addr WinRingEndpoint + data [bytesPerPacket]byte +} + +type ringBuffer struct { + packets uintptr + head, tail uint32 + id winrio.BufferId + iocp windows.Handle + isFull bool + cq winrio.Cq + mu sync.Mutex + overlapped windows.Overlapped +} + +func (rb *ringBuffer) Push() *ringPacket { + for rb.isFull { + panic("ring is full") + } + ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{})))) + rb.tail += 1 + if rb.tail == rb.head { + rb.isFull = true + } + return ret +} + +func (rb *ringBuffer) Return(count uint32) { + if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull { + return + } + rb.head += count + rb.isFull = false +} + +type afWinRingBind struct { + sock windows.Handle + rx, tx ringBuffer + rq winrio.Rq + mu sync.Mutex + blackhole bool +} + +// WinRingBind uses Windows registered I/O for fast ring buffered networking. +type WinRingBind struct { + v4, v6 afWinRingBind + mu sync.RWMutex + isOpen uint32 +} + +func NewDefaultBind() Bind { return NewWinRingBind() } + +func NewWinRingBind() Bind { + if !winrio.Initialize() { + return NewStdNetBind() + } + return new(WinRingBind) +} + +type WinRingEndpoint struct { + family uint16 + data [30]byte +} + +var _ Bind = (*WinRingBind)(nil) +var _ Endpoint = (*WinRingEndpoint)(nil) + +func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) { + host, port, err := net.SplitHostPort(s) + if err != nil { + return nil, err + } + host16, err := windows.UTF16PtrFromString(host) + if err != nil { + return nil, err + } + port16, err := windows.UTF16PtrFromString(port) + if err != nil { + return nil, err + } + hints := windows.AddrinfoW{ + Flags: windows.AI_NUMERICHOST, + Family: windows.AF_UNSPEC, + Socktype: windows.SOCK_DGRAM, + Protocol: windows.IPPROTO_UDP, + } + var addrinfo *windows.AddrinfoW + err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo) + if err != nil { + return nil, err + } + defer windows.FreeAddrInfoW(addrinfo) + if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) { + return nil, windows.ERROR_INVALID_ADDRESS + } + var src []byte + var dst [unsafe.Sizeof(WinRingEndpoint{})]byte + unsafeSlice(unsafe.Pointer(&src), unsafe.Pointer(addrinfo.Addr), int(addrinfo.Addrlen)) + copy(dst[:], src) + return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil +} + +func (*WinRingEndpoint) ClearSrc() {} + +func (e *WinRingEndpoint) DstIP() net.IP { + switch e.family { + case windows.AF_INET: + return append([]byte{}, e.data[2:6]...) + case windows.AF_INET6: + return append([]byte{}, e.data[6:22]...) + } + return nil +} + +func (e *WinRingEndpoint) SrcIP() net.IP { + return nil // not supported +} + +func (e *WinRingEndpoint) DstToBytes() []byte { + switch e.family { + case windows.AF_INET: + b := make([]byte, 0, 6) + b = append(b, e.data[2:6]...) + b = append(b, e.data[1], e.data[0]) + return b + case windows.AF_INET6: + b := make([]byte, 0, 18) + b = append(b, e.data[6:22]...) + b = append(b, e.data[1], e.data[0]) + return b + } + return nil +} + +func (e *WinRingEndpoint) DstToString() string { + switch e.family { + case windows.AF_INET: + addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))} + return addr.String() + case windows.AF_INET6: + var zone string + if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 { + zone = strconv.FormatUint(uint64(scope), 10) + } + addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))} + return addr.String() + } + return "" +} + +func (e *WinRingEndpoint) SrcToString() string { + return "" +} + +func (ring *ringBuffer) CloseAndZero() { + if ring.cq != 0 { + winrio.CloseCompletionQueue(ring.cq) + ring.cq = 0 + } + if ring.iocp != 0 { + windows.CloseHandle(ring.iocp) + ring.iocp = 0 + } + if ring.id != 0 { + winrio.DeregisterBuffer(ring.id) + ring.id = 0 + } + if ring.packets != 0 { + windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE) + ring.packets = 0 + } +} + +func (bind *afWinRingBind) CloseAndZero() { + bind.rx.CloseAndZero() + bind.tx.CloseAndZero() + if bind.sock != 0 { + windows.CloseHandle(bind.sock) + bind.sock = 0 + } + bind.blackhole = false +} + +func (bind *WinRingBind) closeAndZero() { + atomic.StoreUint32(&bind.isOpen, 0) + bind.v4.CloseAndZero() + bind.v6.CloseAndZero() +} + +func (ring *ringBuffer) Open() error { + var err error + packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing + ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) + if err != nil { + return err + } + ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen)) + if err != nil { + return err + } + ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) + if err != nil { + return err + } + ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped) + if err != nil { + return err + } + return nil +} + +func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) { + var err error + bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP) + if err != nil { + return nil, err + } + err = bind.rx.Open() + if err != nil { + return nil, err + } + err = bind.tx.Open() + if err != nil { + return nil, err + } + bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0) + if err != nil { + return nil, err + } + err = windows.Bind(bind.sock, sa) + if err != nil { + return nil, err + } + sa, err = windows.Getsockname(bind.sock) + if err != nil { + return nil, err + } + return sa, nil +} + +func (bind *WinRingBind) Open(port uint16) (selectedPort uint16, err error) { + bind.mu.Lock() + defer bind.mu.Unlock() + defer func() { + if err != nil { + bind.closeAndZero() + } + }() + if atomic.LoadUint32(&bind.isOpen) != 0 { + return 0, ErrBindAlreadyOpen + } + var sa windows.Sockaddr + sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)}) + if err != nil { + return 0, err + } + sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port}) + if err != nil { + return 0, err + } + selectedPort = uint16(sa.(*windows.SockaddrInet6).Port) + for i := 0; i < packetsPerRing; i++ { + err = bind.v4.InsertReceiveRequest() + if err != nil { + return 0, err + } + err = bind.v6.InsertReceiveRequest() + if err != nil { + return 0, err + } + } + atomic.StoreUint32(&bind.isOpen, 1) + return +} + +func (bind *WinRingBind) Close() error { + bind.mu.RLock() + if atomic.LoadUint32(&bind.isOpen) != 1 { + bind.mu.RUnlock() + return nil + } + atomic.StoreUint32(&bind.isOpen, 2) + windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil) + windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil) + windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil) + windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil) + bind.mu.RUnlock() + bind.mu.Lock() + defer bind.mu.Unlock() + bind.closeAndZero() + return nil +} + +func (bind *WinRingBind) SetMark(mark uint32) error { + return nil +} + +func (bind *afWinRingBind) InsertReceiveRequest() error { + packet := bind.rx.Push() + dataBuffer := &winrio.Buffer{ + Id: bind.rx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets), + Length: uint32(len(packet.data)), + } + addressBuffer := &winrio.Buffer{ + Id: bind.rx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets), + Length: uint32(unsafe.Sizeof(packet.addr)), + } + bind.mu.Lock() + defer bind.mu.Unlock() + return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet))) +} + +//go:linkname procyield runtime.procyield +func procyield(cycles uint32) + +func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) { + if atomic.LoadUint32(isOpen) != 1 { + return 0, nil, net.ErrClosed + } + bind.rx.mu.Lock() + defer bind.rx.mu.Unlock() + var count uint32 + var results [1]winrio.Result + for tries := 0; count == 0 && tries < receiveSpins; tries++ { + if tries > 0 { + if atomic.LoadUint32(isOpen) != 1 { + return 0, nil, net.ErrClosed + } + procyield(1) + } + count = winrio.DequeueCompletion(bind.rx.cq, results[:]) + } + if count == 0 { + err := winrio.Notify(bind.rx.cq) + if err != nil { + return 0, nil, err + } + var bytes uint32 + var key uintptr + var overlapped *windows.Overlapped + err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE) + if err != nil { + return 0, nil, err + } + if atomic.LoadUint32(isOpen) != 1 { + return 0, nil, net.ErrClosed + } + count = winrio.DequeueCompletion(bind.rx.cq, results[:]) + if count == 0 { + return 0, nil, io.ErrNoProgress + + } + } + bind.rx.Return(1) + err := bind.InsertReceiveRequest() + if err != nil { + return 0, nil, err + } + if results[0].Status != 0 { + return 0, nil, windows.Errno(results[0].Status) + } + packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext))) + ep := packet.addr + n := copy(buf, packet.data[:results[0].BytesTransferred]) + return n, &ep, nil +} + +func (bind *WinRingBind) ReceiveIPv4(buf []byte) (int, Endpoint, error) { + bind.mu.RLock() + defer bind.mu.RUnlock() + return bind.v4.Receive(buf, &bind.isOpen) +} + +func (bind *WinRingBind) ReceiveIPv6(buf []byte) (int, Endpoint, error) { + bind.mu.RLock() + defer bind.mu.RUnlock() + return bind.v6.Receive(buf, &bind.isOpen) +} + +func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error { + if atomic.LoadUint32(isOpen) != 1 { + return net.ErrClosed + } + if len(buf) > bytesPerPacket { + return io.ErrShortBuffer + } + bind.tx.mu.Lock() + defer bind.tx.mu.Unlock() + var results [packetsPerRing]winrio.Result + count := winrio.DequeueCompletion(bind.tx.cq, results[:]) + if count == 0 && bind.tx.isFull { + err := winrio.Notify(bind.tx.cq) + if err != nil { + return err + } + var bytes uint32 + var key uintptr + var overlapped *windows.Overlapped + err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE) + if err != nil { + return err + } + if atomic.LoadUint32(isOpen) != 1 { + return net.ErrClosed + } + count = winrio.DequeueCompletion(bind.tx.cq, results[:]) + if count == 0 { + return io.ErrNoProgress + } + } + if count > 0 { + bind.tx.Return(count) + } + packet := bind.tx.Push() + packet.addr = *nend + copy(packet.data[:], buf) + dataBuffer := &winrio.Buffer{ + Id: bind.tx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets), + Length: uint32(len(buf)), + } + addressBuffer := &winrio.Buffer{ + Id: bind.tx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets), + Length: uint32(unsafe.Sizeof(packet.addr)), + } + bind.mu.Lock() + defer bind.mu.Unlock() + return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) +} + +func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error { + nend, ok := endpoint.(*WinRingEndpoint) + if !ok { + return ErrWrongEndpointType + } + bind.mu.RLock() + defer bind.mu.RUnlock() + switch nend.family { + case windows.AF_INET: + if bind.v4.blackhole { + return nil + } + return bind.v4.Send(buf, nend, &bind.isOpen) + case windows.AF_INET6: + if bind.v6.blackhole { + return nil + } + return bind.v6.Send(buf, nend, &bind.isOpen) + } + return nil +} + +func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { + sysconn, err := bind.ipv4.SyscallConn() + if err != nil { + return err + } + err2 := sysconn.Control(func(fd uintptr) { + err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex) + }) + if err2 != nil { + return err2 + } + if err != nil { + return err + } + bind.blackhole4 = blackhole + return nil +} + +func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { + sysconn, err := bind.ipv6.SyscallConn() + if err != nil { + return err + } + err2 := sysconn.Control(func(fd uintptr) { + err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex) + }) + if err2 != nil { + return err2 + } + if err != nil { + return err + } + bind.blackhole6 = blackhole + return nil +} +func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { + bind.mu.RLock() + defer bind.mu.RUnlock() + if atomic.LoadUint32(&bind.isOpen) != 1 { + return net.ErrClosed + } + err := bindSocketToInterface4(bind.v4.sock, interfaceIndex) + if err != nil { + return err + } + bind.v4.blackhole = blackhole + return nil +} + +func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { + bind.mu.RLock() + defer bind.mu.RUnlock() + if atomic.LoadUint32(&bind.isOpen) != 1 { + return net.ErrClosed + } + err := bindSocketToInterface6(bind.v6.sock, interfaceIndex) + if err != nil { + return err + } + bind.v6.blackhole = blackhole + return nil +} + +func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error { + const IP_UNICAST_IF = 31 + /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ + var bytes [4]byte + binary.BigEndian.PutUint32(bytes[:], interfaceIndex) + interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0])) + err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex)) + if err != nil { + return err + } + return nil +} + +func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error { + const IPV6_UNICAST_IF = 31 + return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex)) +} + +// unsafeSlice updates the slice slicePtr to be a slice +// referencing the provided data with its length & capacity set to +// lenCap. +// +// TODO: when Go 1.16 or Go 1.17 is the minimum supported version, +// update callers to use unsafe.Slice instead of this. +func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) { + type sliceHeader struct { + Data unsafe.Pointer + Len int + Cap int + } + h := (*sliceHeader)(slicePtr) + h.Data = data + h.Len = lenCap + h.Cap = lenCap +} diff --git a/conn/boundif_windows.go b/conn/boundif_windows.go deleted file mode 100644 index 6f6fdd8..0000000 --- a/conn/boundif_windows.go +++ /dev/null @@ -1,59 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package conn - -import ( - "encoding/binary" - "unsafe" - - "golang.org/x/sys/windows" -) - -const ( - sockoptIP_UNICAST_IF = 31 - sockoptIPV6_UNICAST_IF = 31 -) - -func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { - /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ - bytes := make([]byte, 4) - binary.BigEndian.PutUint32(bytes, interfaceIndex) - interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0])) - - sysconn, err := bind.ipv4.SyscallConn() - if err != nil { - return err - } - err2 := sysconn.Control(func(fd uintptr) { - err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, sockoptIP_UNICAST_IF, int(interfaceIndex)) - }) - if err2 != nil { - return err2 - } - if err != nil { - return err - } - bind.blackhole4 = blackhole - return nil -} - -func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { - sysconn, err := bind.ipv6.SyscallConn() - if err != nil { - return err - } - err2 := sysconn.Control(func(fd uintptr) { - err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, sockoptIPV6_UNICAST_IF, int(interfaceIndex)) - }) - if err2 != nil { - return err2 - } - if err != nil { - return err - } - bind.blackhole6 = blackhole - return nil -} diff --git a/conn/default.go b/conn/default.go index cd9bfb0..161454a 100644 --- a/conn/default.go +++ b/conn/default.go @@ -1,4 +1,4 @@ -// +build !linux +// +build !linux,!windows /* SPDX-License-Identifier: MIT * diff --git a/conn/winrio/rio_windows.go b/conn/winrio/rio_windows.go new file mode 100644 index 0000000..1785a02 --- /dev/null +++ b/conn/winrio/rio_windows.go @@ -0,0 +1,243 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2021 WireGuard LLC. All Rights Reserved. + */ + +package winrio + +import ( + "log" + "sync" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +const ( + MsgDontNotify = 1 + MsgDefer = 2 + MsgWaitAll = 4 + MsgCommitOnly = 8 + + MaxCqSize = 0x8000000 + + invalidBufferId = 0xFFFFFFFF + invalidCq = 0 + invalidRq = 0 + corruptCq = 0xFFFFFFFF +) + +var extensionFunctionTable struct { + cbSize uint32 + rioReceive uintptr + rioReceiveEx uintptr + rioSend uintptr + rioSendEx uintptr + rioCloseCompletionQueue uintptr + rioCreateCompletionQueue uintptr + rioCreateRequestQueue uintptr + rioDequeueCompletion uintptr + rioDeregisterBuffer uintptr + rioNotify uintptr + rioRegisterBuffer uintptr + rioResizeCompletionQueue uintptr + rioResizeRequestQueue uintptr +} + +type Cq uintptr + +type Rq uintptr + +type BufferId uintptr + +type Buffer struct { + Id BufferId + Offset uint32 + Length uint32 +} + +type Result struct { + Status int32 + BytesTransferred uint32 + SocketContext uint64 + RequestContext uint64 +} + +type notificationCompletionType uint32 + +const ( + eventCompletion notificationCompletionType = 1 + iocpCompletion notificationCompletionType = 2 +) + +type eventNotificationCompletion struct { + completionType notificationCompletionType + event windows.Handle + notifyReset uint32 +} + +type iocpNotificationCompletion struct { + completionType notificationCompletionType + iocp windows.Handle + key uintptr + overlapped *windows.Overlapped +} + +var initialized sync.Once +var available bool + +func Initialize() bool { + initialized.Do(func() { + var ( + err error + socket windows.Handle + cq Cq + ) + defer func() { + if err == nil { + return + } + if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 7 { + return + } + log.Printf("Registered I/O is unavailable: %v", err) + }() + socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP) + if err != nil { + return + } + defer windows.CloseHandle(socket) + var WSAID_MULTIPLE_RIO = &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}} + const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024 + ob := uint32(0) + err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER, + (*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)), + (*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)), + &ob, nil, 0) + if err != nil { + return + } + // While we should be able to stop here, after getting the function pointers, some anti-virus actually causes + // failures in RIOCreateRequestQueue, so keep going to be certain this is supported. + cq, err = CreatePolledCompletionQueue(2) + if err != nil { + return + } + defer CloseCompletionQueue(cq) + _, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0) + if err != nil { + return + } + available = true + }) + return available +} + +func Socket(af, typ, proto int32) (windows.Handle, error) { + return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO) +} + +func CloseCompletionQueue(cq Cq) { + _, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0) +} + +func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) { + notificationCompletion := &eventNotificationCompletion{ + completionType: eventCompletion, + event: event, + } + if notifyReset { + notificationCompletion.notifyReset = 1 + } + ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0) + if ret == invalidCq { + return 0, err + } + return Cq(ret), nil +} + +func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) { + notificationCompletion := &iocpNotificationCompletion{ + completionType: iocpCompletion, + iocp: iocp, + overlapped: overlapped, + } + ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0) + if ret == invalidCq { + return 0, err + } + return Cq(ret), nil +} + +func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) { + ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0) + if ret == invalidCq { + return 0, err + } + return Cq(ret), nil +} + +func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) { + ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0) + if ret == invalidRq { + return 0, err + } + return Rq(ret), nil +} + +func DequeueCompletion(cq Cq, results []Result) uint32 { + var array uintptr + if len(results) > 0 { + array = uintptr(unsafe.Pointer(&results[0])) + } + ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results))) + if ret == corruptCq { + panic("cq is corrupt") + } + return uint32(ret) +} + +func DeregisterBuffer(id BufferId) { + _, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0) +} + +func RegisterBuffer(buffer []byte) (BufferId, error) { + var buf unsafe.Pointer + if len(buffer) > 0 { + buf = unsafe.Pointer(&buffer[0]) + } + return RegisterPointer(buf, uint32(len(buffer))) +} + +func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) { + ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0) + if ret == invalidBufferId { + return 0, err + } + return BufferId(ret), nil +} + +func SendEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error { + ret, _, err := syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext) + if ret == 0 { + return err + } + return nil +} + +func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error { + ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext) + if ret == 0 { + return err + } + return nil +} + +func Notify(cq Cq) error { + ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0) + if ret != 0 { + return windows.Errno(ret) + } + return nil +} diff --git a/device/queueconstants_default.go b/device/queueconstants_default.go index 773c2ca..d5c6927 100644 --- a/device/queueconstants_default.go +++ b/device/queueconstants_default.go @@ -1,4 +1,4 @@ -// +build !android,!ios +// +build !android,!ios,!windows /* SPDX-License-Identifier: MIT * diff --git a/device/queueconstants_windows.go b/device/queueconstants_windows.go new file mode 100644 index 0000000..e330a6b --- /dev/null +++ b/device/queueconstants_windows.go @@ -0,0 +1,15 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package device + +const ( + QueueStagedSize = 128 + QueueOutboundSize = 1024 + QueueInboundSize = 1024 + QueueHandshakeSize = 1024 + MaxSegmentSize = 2048 - 32 // largest possible UDP datagram + PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth +) diff --git a/go.mod b/go.mod index 11b6c7f..0aa27d8 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module golang.zx2c4.com/wireguard go 1.16 require ( - golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad - golang.org/x/net v0.0.0-20201224014010-6772e930b67b - golang.org/x/sys v0.0.0-20210105210732-16f7687f5001 + golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 + golang.org/x/net v0.0.0-20210224082022-3d97a244fca7 + golang.org/x/sys v0.0.0-20210225014209-683adc9d29d7 ) diff --git a/go.sum b/go.sum index 62a8501..1ccf774 100644 --- a/go.sum +++ b/go.sum @@ -1,17 +1,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad h1:DN0cp81fZ3njFcrLCytUHRSUkqBjfTo4Tx9RJTWs0EY= -golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 h1:/ZScEX8SfEmUGRHs0gxpqteO5nfNW6axyZbBdw9A12g= +golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20201224014010-6772e930b67b h1:iFwSg7t5GZmB/Q5TjiEAsdoLDrdJRC1RiF2WhuV29Qw= -golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210224082022-3d97a244fca7 h1:OgUuv8lsRpBibGNbSizVwKWlysjaNzmC9gYMhPVfqFM= +golang.org/x/net v0.0.0-20210224082022-3d97a244fca7/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210105210732-16f7687f5001 h1:/dSxr6gT0FNI1MO5WLJo8mTmItROeOKTkDn+7OwWBos= -golang.org/x/sys v0.0.0-20210105210732-16f7687f5001/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210225014209-683adc9d29d7 h1:pk3Y+QnSKjMLfO/HIqzn/Zvv3/IHjRPhwblrmUuodzw= +golang.org/x/sys v0.0.0-20210225014209-683adc9d29d7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e h1:FDhOuMEY4JVRztM/gsbk+IKUQ8kj74bxZrgw87eMMVc= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=