From 2e24e7dcae421c45aadf4c052a33b784ad320014 Mon Sep 17 00:00:00 2001 From: Simon Rozman Date: Thu, 11 Jul 2019 10:35:47 +0200 Subject: [PATCH] tun: windows: implement ring buffers Signed-off-by: Simon Rozman --- tun/tun_windows.go | 392 ++++++++++++++++++++++----------------------- 1 file changed, 195 insertions(+), 197 deletions(-) diff --git a/tun/tun_windows.go b/tun/tun_windows.go index 4ae5cf0..49cbdad 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -8,9 +8,9 @@ package tun import ( "errors" "fmt" - "io" "os" "sync" + "sync/atomic" "time" "unsafe" @@ -20,39 +20,54 @@ import ( ) const ( - packetExchangeAlignment uint32 = 4 // Number of bytes packets are aligned to in exchange buffers - packetSizeMax uint32 = 0xf000 - packetExchangeAlignment // Maximum packet size - packetExchangeSize uint32 = 0x100000 // Exchange buffer size (defaults to 1MiB) - retryRate = 4 // Number of retries per second to reopen device pipe - retryTimeout = 30 // Number of seconds to tolerate adapter unavailable + packetAlignment uint32 = 4 // Number of bytes packets are aligned to in rings + packetSizeMax uint32 = 0xffff // Maximum packet size + packetCapacity uint32 = 0x100000 // Ring capacity (defaults to 1MiB, must be a power of 2) + packetTrailingSize uint32 = uint32(unsafe.Sizeof(packetHeader{})) + ((packetSizeMax + (packetAlignment - 1)) &^ (packetAlignment - 1)) - packetAlignment + + ioctlRegisterRings uint32 = (0x22 /*FILE_DEVICE_UNKNOWN*/ << 16) | (0x800 << 2) | 0 /*METHOD_BUFFERED*/ | (0x3 /*FILE_READ_DATA | FILE_WRITE_DATA*/ << 14) + + retryRate = 4 // Number of retries per second to reopen device pipe + retryTimeout = 30 // Number of seconds to tolerate adapter unavailable ) -type exchgBufRead struct { - data [packetExchangeSize]byte - offset uint32 - avail uint32 +type packetHeader struct { + size uint32 } -type exchgBufWrite struct { - data [packetExchangeSize]byte - offset uint32 +type packet struct { + packetHeader + data [packetSizeMax]byte +} + +type ring struct { + head uint32 + tail uint32 + alertable int32 + data [packetCapacity + packetTrailingSize]byte +} + +type ringDescriptor struct { + send, receive struct { + size uint32 + ring *ring + tailMoved windows.Handle + } } type NativeTun struct { - wt *wintun.Wintun - tunFileRead *os.File - tunFileWrite *os.File - tunLock sync.Mutex - close bool - rdBuff *exchgBufRead - wrBuff *exchgBufWrite - events chan Event - errors chan error - forcedMTU int + wt *wintun.Wintun + tunDev windows.Handle + tunLock sync.Mutex + close bool + rings ringDescriptor + events chan Event + errors chan error + forcedMTU int } func packetAlign(size uint32) uint32 { - return (size + (packetExchangeAlignment - 1)) &^ (packetExchangeAlignment - 1) + return (size + (packetAlignment - 1)) &^ (packetAlignment - 1) } var shouldRetryOpen = windows.RtlGetVersion().MajorVersion < 10 @@ -102,14 +117,32 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Dev return nil, fmt.Errorf("Unable to set name of Wintun interface: %v", err) } - return &NativeTun{ + tun := &NativeTun{ wt: wt, - rdBuff: &exchgBufRead{}, - wrBuff: &exchgBufWrite{}, + tunDev: windows.InvalidHandle, events: make(chan Event, 10), errors: make(chan error, 1), forcedMTU: 1500, - }, nil + } + + tun.rings.send.size = uint32(unsafe.Sizeof(ring{})) + tun.rings.send.ring = &ring{} + tun.rings.send.tailMoved, err = windows.CreateEvent(nil, 0, 0, nil) + if err != nil { + wt.DeleteInterface() + return nil, fmt.Errorf("Error creating event: %v", err) + } + + tun.rings.receive.size = uint32(unsafe.Sizeof(ring{})) + tun.rings.receive.ring = &ring{} + tun.rings.receive.tailMoved, err = windows.CreateEvent(nil, 0, 0, nil) + if err != nil { + windows.CloseHandle(tun.rings.send.tailMoved) + wt.DeleteInterface() + return nil, fmt.Errorf("Error creating event: %v", err) + } + + return tun, nil } func (tun *NativeTun) openTUN() error { @@ -119,9 +152,12 @@ func (tun *NativeTun) openTUN() error { } var err error - name := tun.wt.DataFileName() - for tun.tunFileRead == nil { - tun.tunFileRead, err = os.OpenFile(name, os.O_RDONLY, 0) + name, err := windows.UTF16PtrFromString(tun.wt.DataFileName()) + if err != nil { + return err + } + for tun.tunDev == windows.InvalidHandle { + tun.tunDev, err = windows.CreateFile(name, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, 0, 0) if err != nil { if retries > 0 && !tun.close { time.Sleep(time.Second / retryRate) @@ -130,72 +166,51 @@ func (tun *NativeTun) openTUN() error { } return err } - } - for tun.tunFileWrite == nil { - tun.tunFileWrite, err = os.OpenFile(name, os.O_WRONLY, 0) + + atomic.StoreUint32(&tun.rings.send.ring.head, 0) + atomic.StoreUint32(&tun.rings.send.ring.tail, 0) + atomic.StoreInt32(&tun.rings.send.ring.alertable, 0) + atomic.StoreUint32(&tun.rings.receive.ring.head, 0) + atomic.StoreUint32(&tun.rings.receive.ring.tail, 0) + atomic.StoreInt32(&tun.rings.receive.ring.alertable, 0) + + var bytesReturned uint32 + err = windows.DeviceIoControl(tun.tunDev, ioctlRegisterRings, (*byte)(unsafe.Pointer(&tun.rings)), uint32(unsafe.Sizeof(tun.rings)), nil, 0, &bytesReturned, nil) if err != nil { - if retries > 0 && !tun.close { - time.Sleep(time.Second / retryRate) - retries-- - continue - } - return err + return fmt.Errorf("Error registering rings: %v", err) } } return nil } func (tun *NativeTun) closeTUN() (err error) { - for tun.tunFileRead != nil { + for tun.tunDev != windows.InvalidHandle { tun.tunLock.Lock() - if tun.tunFileRead == nil { + if tun.tunDev == windows.InvalidHandle { tun.tunLock.Unlock() break } - t := tun.tunFileRead - tun.tunFileRead = nil - windows.CancelIoEx(windows.Handle(t.Fd()), nil) - err = t.Close() + t := tun.tunDev + tun.tunDev = windows.InvalidHandle + err = windows.CloseHandle(t) tun.tunLock.Unlock() break } - for tun.tunFileWrite != nil { - tun.tunLock.Lock() - if tun.tunFileWrite == nil { - tun.tunLock.Unlock() - break - } - t := tun.tunFileWrite - tun.tunFileWrite = nil - windows.CancelIoEx(windows.Handle(t.Fd()), nil) - err2 := t.Close() - tun.tunLock.Unlock() - if err == nil { - err = err2 - } - break - } return } -func (tun *NativeTun) getTUN() (read *os.File, write *os.File, err error) { - read, write = tun.tunFileRead, tun.tunFileWrite - if read == nil || write == nil { - read, write = nil, nil +func (tun *NativeTun) getTUN() (handle windows.Handle, err error) { + handle = tun.tunDev + if handle == windows.InvalidHandle { tun.tunLock.Lock() - if tun.tunFileRead != nil && tun.tunFileWrite != nil { - read, write = tun.tunFileRead, tun.tunFileWrite - tun.tunLock.Unlock() - return - } - err = tun.closeTUN() - if err != nil { + if tun.tunDev != windows.InvalidHandle { + handle = tun.tunDev tun.tunLock.Unlock() return } err = tun.openTUN() if err == nil { - read, write = tun.tunFileRead, tun.tunFileWrite + handle = tun.tunDev } tun.tunLock.Unlock() return @@ -217,18 +232,30 @@ func (tun *NativeTun) Events() chan Event { func (tun *NativeTun) Close() error { tun.close = true - err1 := tun.closeTUN() + windows.SetEvent(tun.rings.send.tailMoved) // wake the reader if it's sleeping + var err, err2 error + err = tun.closeTUN() if tun.events != nil { close(tun.events) } - _, err2 := tun.wt.DeleteInterface() - if err1 == nil { - err1 = err2 + err2 = windows.CloseHandle(tun.rings.receive.tailMoved) + if err == nil { + err = err2 } - return err1 + err2 = windows.CloseHandle(tun.rings.send.tailMoved) + if err == nil { + err = err2 + } + + _, err2 = tun.wt.DeleteInterface() + if err == nil { + err = err2 + } + + return err } func (tun *NativeTun) MTU() (int, error) { @@ -240,6 +267,8 @@ func (tun *NativeTun) ForceMTU(mtu int) { tun.forcedMTU = mtu } +// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. + func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { select { case err := <-tun.errors: @@ -248,142 +277,111 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { } retries := maybeRetry(1000) - for { - if tun.rdBuff.offset+packetExchangeAlignment <= tun.rdBuff.avail { - // Get packet from the exchange buffer. - packet := tun.rdBuff.data[tun.rdBuff.offset:] - size := *(*uint32)(unsafe.Pointer(&packet[0])) - pSize := packetAlign(size) + packetExchangeAlignment - if packetSizeMax < size || tun.rdBuff.avail < tun.rdBuff.offset+pSize { - // Invalid packet size. - tun.rdBuff.avail = 0 - continue - } - packet = packet[packetExchangeAlignment : packetExchangeAlignment+size] - - // Copy data. - copy(buff[offset:], packet) - tun.rdBuff.offset += pSize - return int(size), nil - } - - // Get TUN data pipe. - file, _, err := tun.getTUN() + for !tun.close { + _, err := tun.getTUN() if err != nil { return 0, err } - n, err := file.Read(tun.rdBuff.data[:]) - if err != nil { - tun.rdBuff.offset = 0 - tun.rdBuff.avail = 0 - pe, ok := err.(*os.PathError) - if tun.close { - return 0, os.ErrClosed - } - if retries > 0 && ok && (pe.Err == windows.ERROR_HANDLE_EOF || pe.Err == windows.ERROR_OPERATION_ABORTED) { - retries-- - tun.closeTUN() - time.Sleep(time.Millisecond * 2) - continue - } - return 0, err + buffHead := atomic.LoadUint32(&tun.rings.send.ring.head) + if buffHead >= packetCapacity { + return 0, errors.New("send ring head out of bounds") } - if n == 0 { - if retries == 0 { - return 0, io.ErrShortBuffer - } - retries-- + + buffTail := atomic.LoadUint32(&tun.rings.send.ring.tail) + if buffHead == buffTail { + windows.WaitForSingleObject(tun.rings.send.tailMoved, windows.INFINITE) continue } - tun.rdBuff.offset = 0 - tun.rdBuff.avail = uint32(n) - } -} + if buffTail >= packetCapacity { + if retries > 0 { + tun.closeTUN() + time.Sleep(time.Millisecond * 2) + retries-- + continue + } + return 0, errors.New("send ring tail out of bounds") + } + retries = maybeRetry(1000) -// Note: flush() and putTunPacket() assume the caller comes only from a single thread; there's no locking. + buffContent := tun.rings.send.ring.wrap(buffTail - buffHead) + if buffContent < uint32(unsafe.Sizeof(packetHeader{})) { + return 0, errors.New("incomplete packet header in send ring") + } + + packet := (*packet)(unsafe.Pointer(&tun.rings.send.ring.data[buffHead])) + if packet.size > packetSizeMax { + return 0, errors.New("packet too big in send ring") + } + + alignedPacketSize := packetAlign(uint32(unsafe.Sizeof(packetHeader{})) + packet.size) + if alignedPacketSize > buffContent { + return 0, errors.New("incomplete packet in send ring") + } + + copy(buff[offset:], packet.data[:packet.size]) + buffHead = tun.rings.send.ring.wrap(buffHead + alignedPacketSize) + atomic.StoreUint32(&tun.rings.send.ring.head, buffHead) + return int(packet.size), nil + } + + return 0, os.ErrClosed +} func (tun *NativeTun) Flush() error { - if tun.wrBuff.offset == 0 { - return nil - } - defer func() { - tun.wrBuff.offset = 0 - }() - retries := maybeRetry(1000) - - for { - // Get TUN data pipe. - _, file, err := tun.getTUN() - if err != nil { - return err - } - - for { - _, err = file.Write(tun.wrBuff.data[:tun.wrBuff.offset]) - if err != nil { - pe, ok := err.(*os.PathError) - if tun.close { - return os.ErrClosed - } - if retries > 0 && ok && pe.Err == windows.ERROR_OPERATION_ABORTED { // Adapter is paused or in low-power state. - retries-- - time.Sleep(time.Millisecond * 2) - continue - } - if retries > 0 && ok && pe.Err == windows.ERROR_HANDLE_EOF { // Adapter is going down. - retries-- - tun.closeTUN() - time.Sleep(time.Millisecond * 2) - break - } - return err - } - return nil - } - } -} - -func (tun *NativeTun) putTunPacket(buff []byte) error { - size := uint32(len(buff)) - if size == 0 { - return errors.New("Empty packet") - } - if size > packetSizeMax { - return errors.New("Packet too big") - } - pSize := packetAlign(size) + packetExchangeAlignment - - if tun.wrBuff.offset+pSize >= packetExchangeSize { - // Exchange buffer is full -> flush first. - err := tun.Flush() - if err != nil { - return err - } - } - - // Write packet to the exchange buffer. - packet := tun.wrBuff.data[tun.wrBuff.offset : tun.wrBuff.offset+pSize] - *(*uint32)(unsafe.Pointer(&packet[0])) = size - packet = packet[packetExchangeAlignment : packetExchangeAlignment+size] - copy(packet, buff) - - tun.wrBuff.offset += pSize - return nil } func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { - err := tun.putTunPacket(buff[offset:]) - if err != nil { - return 0, err + retries := maybeRetry(1000) + for { + _, err := tun.getTUN() + if err != nil { + return 0, err + } + + packetSize := uint32(len(buff) - offset) + alignedPacketSize := packetAlign(uint32(unsafe.Sizeof(packetHeader{})) + packetSize) + + buffHead := atomic.LoadUint32(&tun.rings.receive.ring.head) + if buffHead >= packetCapacity { + if retries > 0 { + tun.closeTUN() + time.Sleep(time.Millisecond * 2) + retries-- + continue + } + return 0, errors.New("receive ring head out of bounds") + } + retries = maybeRetry(1000) + + buffTail := atomic.LoadUint32(&tun.rings.receive.ring.tail) + if buffTail >= packetCapacity { + return 0, errors.New("receive ring tail out of bounds") + } + + buffSpace := tun.rings.receive.ring.wrap(buffHead - buffTail - packetAlignment) + if alignedPacketSize > buffSpace { + return 0, errors.New("receive ring full") + } + + packet := (*packet)(unsafe.Pointer(&tun.rings.receive.ring.data[buffTail])) + packet.size = packetSize + copy(packet.data[:packetSize], buff[offset:]) + atomic.StoreUint32(&tun.rings.receive.ring.tail, tun.rings.receive.ring.wrap(buffTail+alignedPacketSize)) + if atomic.LoadInt32(&tun.rings.receive.ring.alertable) != 0 { + windows.SetEvent(tun.rings.receive.tailMoved) + } + return int(packetSize), nil } - return len(buff) - offset, nil } -// // LUID returns Windows adapter instance ID. -// func (tun *NativeTun) LUID() uint64 { return tun.wt.LUID() } + +// wrap returns value modulo ring capacity +func (rb *ring) wrap(value uint32) uint32 { + return value & (packetCapacity - 1) +}