diff --git a/tun/tun_windows.go b/tun/tun_windows.go index 684d6f0..63eb812 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -9,10 +9,9 @@ import ( "errors" "fmt" "os" - "sync" "sync/atomic" "time" - "unsafe" + _ "unsafe" "golang.org/x/sys/windows" @@ -40,8 +39,8 @@ type NativeTun struct { errors chan error forcedMTU int rate rateJuggler - rings *wintun.RingDescriptor - writeLock sync.Mutex + session wintun.Session + readWait windows.Handle } var WintunPool *wintun.Pool @@ -103,17 +102,13 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu forcedMTU: forcedMTU, } - tun.rings, err = wintun.NewRingDescriptor() + tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB if err != nil { - tun.Close() - return nil, fmt.Errorf("Error creating events: %v", err) - } - - tun.handle, err = tun.wt.Register(tun.rings) - if err != nil { - tun.Close() - return nil, fmt.Errorf("Error registering rings: %v", err) + _, err = tun.wt.Delete(false) + close(tun.events) + return nil, fmt.Errorf("Error starting session: %v", err) } + tun.readWait = tun.session.ReadWaitEvent() return tun, nil } @@ -131,13 +126,7 @@ func (tun *NativeTun) Events() chan Event { func (tun *NativeTun) Close() error { tun.close = true - if tun.rings.Send.TailMoved != 0 { - windows.SetEvent(tun.rings.Send.TailMoved) // wake the reader if it's sleeping - } - if tun.handle != windows.InvalidHandle { - windows.CloseHandle(tun.handle) - } - tun.rings.Close() + tun.session.End() var err error if tun.wt != nil { _, err = tun.wt.Delete(false) @@ -164,56 +153,34 @@ retry: return 0, err default: } - if tun.close { - return 0, os.ErrClosed - } - - buffHead := atomic.LoadUint32(&tun.rings.Send.Ring.Head) - if buffHead >= wintun.PacketCapacity { - return 0, os.ErrClosed - } - start := nanotime() shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2 - var buffTail uint32 for { - buffTail = atomic.LoadUint32(&tun.rings.Send.Ring.Tail) - if buffHead != buffTail { - break - } if tun.close { return 0, os.ErrClosed } - if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { - windows.WaitForSingleObject(tun.rings.Send.TailMoved, windows.INFINITE) - goto retry + packet, err := tun.session.ReceivePacket() + switch err { + case nil: + packetSize := len(packet) + copy(buff[offset:], packet) + tun.session.ReleaseReceivePacket(packet) + tun.rate.update(uint64(packetSize)) + return packetSize, nil + case windows.ERROR_NO_MORE_ITEMS: + if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { + windows.WaitForSingleObject(tun.readWait, windows.INFINITE) + goto retry + } + procyield(1) + continue + case windows.ERROR_HANDLE_EOF: + return 0, os.ErrClosed + case windows.ERROR_INVALID_DATA: + return 0, errors.New("Send ring corrupt") } - procyield(1) + return 0, fmt.Errorf("Read failed: %v", err) } - if buffTail >= wintun.PacketCapacity { - return 0, os.ErrClosed - } - - buffContent := tun.rings.Send.Ring.Wrap(buffTail - buffHead) - if buffContent < uint32(unsafe.Sizeof(wintun.PacketHeader{})) { - return 0, errors.New("incomplete packet header in send ring") - } - - packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Send.Ring.Data[buffHead])) - if packet.Size > wintun.PacketSizeMax { - return 0, errors.New("packet too big in send ring") - } - - alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packet.Size) - if alignedPacketSize > buffContent { - return 0, errors.New("incomplete packet in send ring") - } - - copy(buff[offset:], packet.Data[:packet.Size]) - buffHead = tun.rings.Send.Ring.Wrap(buffHead + alignedPacketSize) - atomic.StoreUint32(&tun.rings.Send.Ring.Head, buffHead) - tun.rate.update(uint64(packet.Size)) - return int(packet.Size), nil } func (tun *NativeTun) Flush() error { @@ -225,36 +192,22 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { return 0, os.ErrClosed } - packetSize := uint32(len(buff) - offset) + packetSize := len(buff) - offset tun.rate.update(uint64(packetSize)) - alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize) - tun.writeLock.Lock() - defer tun.writeLock.Unlock() - - buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head) - if buffHead >= wintun.PacketCapacity { - return 0, os.ErrClosed + packet, err := tun.session.AllocateSendPacket(packetSize) + if err == nil { + copy(packet, buff[offset:]) + tun.session.SendPacket(packet) + return packetSize, nil } - - buffTail := atomic.LoadUint32(&tun.rings.Receive.Ring.Tail) - if buffTail >= wintun.PacketCapacity { + switch err { + case windows.ERROR_HANDLE_EOF: return 0, os.ErrClosed - } - - buffSpace := tun.rings.Receive.Ring.Wrap(buffHead - buffTail - wintun.PacketAlignment) - if alignedPacketSize > buffSpace { + case windows.ERROR_BUFFER_OVERFLOW: return 0, nil // Dropping when ring is full. } - - packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Receive.Ring.Data[buffTail])) - packet.Size = packetSize - copy(packet.Data[:packetSize], buff[offset:]) - atomic.StoreUint32(&tun.rings.Receive.Ring.Tail, tun.rings.Receive.Ring.Wrap(buffTail+alignedPacketSize)) - if atomic.LoadInt32(&tun.rings.Receive.Ring.Alertable) != 0 { - windows.SetEvent(tun.rings.Receive.TailMoved) - } - return int(packetSize), nil + return 0, fmt.Errorf("Write failed: %v", err) } // LUID returns Windows interface instance ID. diff --git a/tun/wintun/ring_windows.go b/tun/wintun/ring_windows.go deleted file mode 100644 index ed460fb..0000000 --- a/tun/wintun/ring_windows.go +++ /dev/null @@ -1,117 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2020 WireGuard LLC. All Rights Reserved. - */ - -package wintun - -import ( - "runtime" - "unsafe" - - "golang.org/x/sys/windows" -) - -const ( - PacketAlignment = 4 // Number of bytes packets are aligned to in rings - PacketSizeMax = 0xffff // Maximum packet size - PacketCapacity = 0x800000 // Ring capacity, 8MiB - PacketTrailingSize = uint32(unsafe.Sizeof(PacketHeader{})) + ((PacketSizeMax + (PacketAlignment - 1)) &^ (PacketAlignment - 1)) - PacketAlignment - ioctlRegisterRings = (51820 << 16) | (0x970 << 2) | 0 /*METHOD_BUFFERED*/ | (0x3 /*FILE_READ_DATA | FILE_WRITE_DATA*/ << 14) -) - -type PacketHeader struct { - Size uint32 -} - -type Packet struct { - PacketHeader - Data [PacketSizeMax]byte -} - -type Ring struct { - Head uint32 - Tail uint32 - Alertable int32 - Data [PacketCapacity + PacketTrailingSize]byte -} - -type RingDescriptor struct { - Send, Receive struct { - Size uint32 - Ring *Ring - TailMoved windows.Handle - } -} - -// Wrap returns value modulo ring capacity -func (rb *Ring) Wrap(value uint32) uint32 { - return value & (PacketCapacity - 1) -} - -// Aligns a packet size to PacketAlignment -func PacketAlign(size uint32) uint32 { - return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1) -} - -func NewRingDescriptor() (descriptor *RingDescriptor, err error) { - descriptor = new(RingDescriptor) - allocatedRegion, err := windows.VirtualAlloc(0, unsafe.Sizeof(Ring{})*2, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) - if err != nil { - return - } - defer func() { - if err != nil { - descriptor.free() - descriptor = nil - } - }() - descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{})) - descriptor.Send.Ring = (*Ring)(unsafe.Pointer(allocatedRegion)) - descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil) - if err != nil { - return - } - - descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{})) - descriptor.Receive.Ring = (*Ring)(unsafe.Pointer(allocatedRegion + unsafe.Sizeof(Ring{}))) - descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil) - if err != nil { - windows.CloseHandle(descriptor.Send.TailMoved) - return - } - runtime.SetFinalizer(descriptor, func(d *RingDescriptor) { d.free() }) - return -} - -func (descriptor *RingDescriptor) free() { - if descriptor.Send.Ring != nil { - windows.VirtualFree(uintptr(unsafe.Pointer(descriptor.Send.Ring)), 0, windows.MEM_RELEASE) - descriptor.Send.Ring = nil - descriptor.Receive.Ring = nil - } -} - -func (descriptor *RingDescriptor) Close() { - if descriptor.Send.TailMoved != 0 { - windows.CloseHandle(descriptor.Send.TailMoved) - descriptor.Send.TailMoved = 0 - } - if descriptor.Send.TailMoved != 0 { - windows.CloseHandle(descriptor.Receive.TailMoved) - descriptor.Receive.TailMoved = 0 - } -} - -func (wintun *Adapter) Register(descriptor *RingDescriptor) (windows.Handle, error) { - handle, err := wintun.OpenAdapterDeviceObject() - if err != nil { - return 0, err - } - var bytesReturned uint32 - err = windows.DeviceIoControl(handle, ioctlRegisterRings, (*byte)(unsafe.Pointer(descriptor)), uint32(unsafe.Sizeof(*descriptor)), nil, 0, &bytesReturned, nil) - if err != nil { - return 0, err - } - return handle, nil -} diff --git a/tun/wintun/session_windows.go b/tun/wintun/session_windows.go new file mode 100644 index 0000000..1619e5a --- /dev/null +++ b/tun/wintun/session_windows.go @@ -0,0 +1,108 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2020 WireGuard LLC. All Rights Reserved. + */ + +package wintun + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +type Session struct { + handle uintptr +} + +const ( + PacketSizeMax = 0xffff // Maximum packet size + RingCapacityMin = 0x20000 // Minimum ring capacity (128 kiB) + RingCapacityMax = 0x4000000 // Maximum ring capacity (64 MiB) +) + +// Packet with data +type Packet struct { + Next *Packet // Pointer to next packet in queue + Size uint32 // Size of packet (max WINTUN_MAX_IP_PACKET_SIZE) + Data *[PacketSizeMax]byte // Pointer to layer 3 IPv4 or IPv6 packet +} + +var ( + procWintunAllocateSendPacket = modwintun.NewProc("WintunAllocateSendPacket").Addr() + procWintunEndSession = modwintun.NewProc("WintunEndSession") + procWintunGetReadWaitEvent = modwintun.NewProc("WintunGetReadWaitEvent") + procWintunReceivePacket = modwintun.NewProc("WintunReceivePacket").Addr() + procWintunReleaseReceivePacket = modwintun.NewProc("WintunReleaseReceivePacket").Addr() + procWintunSendPacket = modwintun.NewProc("WintunSendPacket").Addr() + procWintunStartSession = modwintun.NewProc("WintunStartSession") +) + +func (wintun *Adapter) StartSession(capacity uint32) (session Session, err error) { + r0, _, e1 := syscall.Syscall(procWintunStartSession.Addr(), 2, uintptr(wintun.handle), uintptr(capacity), 0) + if r0 == 0 { + err = e1 + } else { + session = Session{r0} + } + return +} + +func (session Session) End() { + syscall.Syscall(procWintunEndSession.Addr(), 1, session.handle, 0, 0) + session.handle = 0 +} + +func (session Session) ReadWaitEvent() (handle windows.Handle) { + r0, _, _ := syscall.Syscall(procWintunGetReadWaitEvent.Addr(), 1, session.handle, 0, 0) + handle = windows.Handle(r0) + return +} + +func (session Session) ReceivePacket() (packet []byte, err error) { + var packetSize uint32 + r0, _, e1 := syscall.Syscall(procWintunReceivePacket, 2, session.handle, uintptr(unsafe.Pointer(&packetSize)), 0) + if r0 == 0 { + err = e1 + return + } + unsafeSlice(unsafe.Pointer(&packet), unsafe.Pointer(r0), int(packetSize)) + return +} + +func (session Session) ReleaseReceivePacket(packet []byte) { + syscall.Syscall(procWintunReleaseReceivePacket, 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0) +} + +func (session Session) AllocateSendPacket(packetSize int) (packet []byte, err error) { + r0, _, e1 := syscall.Syscall(procWintunAllocateSendPacket, 2, session.handle, uintptr(packetSize), 0) + if r0 == 0 { + err = e1 + return + } + unsafeSlice(unsafe.Pointer(&packet), unsafe.Pointer(r0), int(packetSize)) + return +} + +func (session Session) SendPacket(packet []byte) { + syscall.Syscall(procWintunSendPacket, 2, session.handle, uintptr(unsafe.Pointer(&packet[0])), 0) +} + +// unsafeSlice updates the slice slicePtr to be a slice +// referencing the provided data with its length & capacity set to +// lenCap. +// +// TODO: when Go 1.16 or Go 1.17 is the minimum supported version, +// update callers to use unsafe.Slice instead of this. +func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) { + type sliceHeader struct { + Data unsafe.Pointer + Len int + Cap int + } + h := (*sliceHeader)(slicePtr) + h.Data = data + h.Len = lenCap + h.Cap = lenCap +} diff --git a/tun/wintun/wintun_windows.go b/tun/wintun/wintun_windows.go index ac33579..e7ba8b6 100644 --- a/tun/wintun/wintun_windows.go +++ b/tun/wintun/wintun_windows.go @@ -45,7 +45,6 @@ var ( procWintunGetAdapterLUID = modwintun.NewProc("WintunGetAdapterLUID") procWintunGetAdapterName = modwintun.NewProc("WintunGetAdapterName") procWintunGetRunningDriverVersion = modwintun.NewProc("WintunGetRunningDriverVersion") - procWintunOpenAdapterDeviceObject = modwintun.NewProc("WintunOpenAdapterDeviceObject") procWintunSetAdapterName = modwintun.NewProc("WintunSetAdapterName") procWintunSetLogger = modwintun.NewProc("WintunSetLogger") ) @@ -210,16 +209,6 @@ func RunningVersion() (version uint32, err error) { return } -// handle returns a handle to the adapter device object. Release handle with windows.CloseHandle -func (wintun *Adapter) OpenAdapterDeviceObject() (handle windows.Handle, err error) { - r0, _, e1 := syscall.Syscall(procWintunOpenAdapterDeviceObject.Addr(), 1, uintptr(wintun.handle), 0, 0) - handle = windows.Handle(r0) - if handle == windows.InvalidHandle { - err = e1 - } - return -} - // LUID returns the LUID of the adapter. func (wintun *Adapter) LUID() (luid uint64) { syscall.Syscall(procWintunGetAdapterLUID.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&luid)), 0)