From 91b4e909bb5fe6980ec56983247e2bfb9fb70ee6 Mon Sep 17 00:00:00 2001 From: Simon Rozman Date: Wed, 20 Mar 2019 21:45:40 +0100 Subject: [PATCH] wintun: Use native Win32 API for I/O Signed-off-by: Simon Rozman --- tun/mksyscall.go | 8 ++++ tun/tun.go | 13 ------ tun/tun_default.go | 24 ++++++++++ tun/tun_windows.go | 107 +++++++++++++++++++++++++++++++++----------- tun/ztun_windows.go | 61 +++++++++++++++++++++++++ 5 files changed, 174 insertions(+), 39 deletions(-) create mode 100644 tun/mksyscall.go create mode 100644 tun/tun_default.go create mode 100644 tun/ztun_windows.go diff --git a/tun/mksyscall.go b/tun/mksyscall.go new file mode 100644 index 0000000..06bb41e --- /dev/null +++ b/tun/mksyscall.go @@ -0,0 +1,8 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package tun + +//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output ztun_windows.go tun_windows.go diff --git a/tun/tun.go b/tun/tun.go index f38ee31..c4b6cac 100644 --- a/tun/tun.go +++ b/tun/tun.go @@ -6,7 +6,6 @@ package tun import ( - "fmt" "os" ) @@ -27,15 +26,3 @@ type TUNDevice interface { Events() chan TUNEvent // returns a constant channel of events related to the device Close() error // stops the device and closes the event channel } - -func (tun *NativeTun) operateOnFd(fn func(fd uintptr)) { - sysconn, err := tun.tunFile.SyscallConn() - if err != nil { - tun.errors <- fmt.Errorf("unable to find sysconn for tunfile: %s", err.Error()) - return - } - err = sysconn.Control(fn) - if err != nil { - tun.errors <- fmt.Errorf("unable to control sysconn for tunfile: %s", err.Error()) - } -} diff --git a/tun/tun_default.go b/tun/tun_default.go new file mode 100644 index 0000000..31747a2 --- /dev/null +++ b/tun/tun_default.go @@ -0,0 +1,24 @@ +// +build !windows + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package tun + +import ( + "fmt" +) + +func (tun *NativeTun) operateOnFd(fn func(fd uintptr)) { + sysconn, err := tun.tunFile.SyscallConn() + if err != nil { + tun.errors <- fmt.Errorf("unable to find sysconn for tunfile: %s", err.Error()) + return + } + err = sysconn.Control(fn) + if err != nil { + tun.errors <- fmt.Errorf("unable to control sysconn for tunfile: %s", err.Error()) + } +} diff --git a/tun/tun_windows.go b/tun/tun_windows.go index 81ba18e..15d9ae2 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -9,6 +9,8 @@ import ( "errors" "os" "sync" + "syscall" + "time" "unsafe" "golang.org/x/sys/windows" @@ -20,6 +22,8 @@ const ( packetExchangeAlignment uint32 = 16 // Number of bytes packets are aligned to in exchange buffers packetSizeMax uint32 = 0xf000 - packetExchangeAlignment // Maximum packet size packetExchangeSize uint32 = 0x100000 // Exchange buffer size (defaults to 1MiB) + retryRate = 4 // Number of retries per second to reopen device pipe + retryTimeout = 5 // Number of seconds to tolerate adapter unavailable ) type exchgBufRead struct { @@ -36,9 +40,10 @@ type exchgBufWrite struct { type NativeTun struct { wt *wintun.Wintun - tunName string - tunFile *os.File + tunName *uint16 + tunFile windows.Handle tunLock sync.Mutex + close bool rdBuff *exchgBufRead wrBuff *exchgBufWrite events chan TUNEvent @@ -46,6 +51,8 @@ type NativeTun struct { forcedMtu int } +//sys getOverlappedResult(handle windows.Handle, overlapped *windows.Overlapped, done *uint32, wait bool) (err error) = kernel32.GetOverlappedResult + func packetAlign(size uint32) uint32 { return (size + (packetExchangeAlignment - 1)) &^ (packetExchangeAlignment - 1) } @@ -83,9 +90,16 @@ func CreateTUN(ifname string) (TUNDevice, error) { return nil, errors.New("Flushing interface failed: " + err.Error()) } + tunNameUTF16, err := windows.UTF16PtrFromString(wt.DataFileName()) + if err != nil { + wt.DeleteInterface(0) + return nil, err + } + return &NativeTun{ wt: wt, - tunName: wt.DataFileName(), + tunName: tunNameUTF16, + tunFile: windows.InvalidHandle, rdBuff: &exchgBufRead{}, wrBuff: &exchgBufWrite{}, events: make(chan TUNEvent, 10), @@ -94,42 +108,67 @@ func CreateTUN(ifname string) (TUNDevice, error) { }, nil } -func (tun *NativeTun) openTUN() { +func (tun *NativeTun) openTUN() error { + retries := retryTimeout * retryRate for { - file, err := os.OpenFile(tun.tunName, os.O_RDWR, 0) - if err != nil { - continue + if tun.close { + return errors.New("Cancelled") } + + file, err := windows.CreateFile(tun.tunName, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_ATTRIBUTE_NORMAL|windows.FILE_FLAG_OVERLAPPED|0x20000000 /*windows.FILE_FLAG_NO_BUFFERING*/, 0) + if err != nil { + if retries > 0 { + time.Sleep(time.Second / retryRate) + retries-- + continue + } + return err + } + tun.tunFile = file + return nil } } func (tun *NativeTun) closeTUN() (err error) { - if tun.tunFile != nil { + if tun.tunFile != windows.InvalidHandle { tun.tunLock.Lock() defer tun.tunLock.Unlock() - if tun.tunFile == nil { + if tun.tunFile == windows.InvalidHandle { return } t := tun.tunFile - tun.tunFile = nil - err = t.Close() + tun.tunFile = windows.InvalidHandle + err = windows.CloseHandle(t) } return } -func (tun *NativeTun) getTUN() (*os.File, error) { - if tun.tunFile == nil { +func (tun *NativeTun) getTUN() (windows.Handle, error) { + if tun.tunFile == windows.InvalidHandle { tun.tunLock.Lock() defer tun.tunLock.Unlock() - if tun.tunFile != nil { + if tun.tunFile != windows.InvalidHandle { return tun.tunFile, nil } - tun.openTUN() + err := tun.openTUN() + if err != nil { + return windows.InvalidHandle, err + } } return tun.tunFile, nil } +func (tun *NativeTun) isIOCancelled(err error) bool { + // Read&WriteFile() return the same ERROR_OPERATION_ABORTED if we close the handle + // or the TUN device is put down. We need a "close" flag to distinguish. + en, ok := err.(syscall.Errno) + if tun.close && ok && en == windows.ERROR_OPERATION_ABORTED { + return true + } + return false +} + func (tun *NativeTun) Name() (string, error) { return tun.wt.GetInterfaceName() } @@ -143,6 +182,7 @@ func (tun *NativeTun) Events() chan TUNEvent { } func (tun *NativeTun) Close() error { + tun.close = true err1 := tun.closeTUN() if tun.events != nil { @@ -199,15 +239,21 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { } // Fill queue. - n, err := file.Read(tun.rdBuff.data[:]) + var n uint32 + overlapped := &windows.Overlapped{} + err = windows.ReadFile(file, tun.rdBuff.data[:], &n, overlapped) if err != nil { - if pe, ok := err.(*os.PathError); ok && pe.Err == os.ErrClosed { - return 0, err + if en, ok := err.(syscall.Errno); ok && en == windows.ERROR_IO_PENDING { + err = getOverlappedResult(file, overlapped, &n, true) + } + if err != nil { + tun.rdBuff.avail = 0 + if tun.isIOCancelled(err) { + return 0, err + } + tun.closeTUN() + continue } - // TUN interface stopped, failed, etc. Retry. - tun.rdBuff.avail = 0 - tun.closeTUN() - continue } tun.rdBuff.offset = 0 tun.rdBuff.avail = uint32(n) @@ -224,13 +270,22 @@ func (tun *NativeTun) flush() error { } // Flush write buffer. - _, err = file.Write(tun.wrBuff.data[:tun.wrBuff.offset]) + var n uint32 + overlapped := &windows.Overlapped{} + err = windows.WriteFile(file, tun.wrBuff.data[:tun.wrBuff.offset], &n, overlapped) tun.wrBuff.packetNum = 0 tun.wrBuff.offset = 0 if err != nil { - // TUN interface stopped, failed, etc. Drop. - tun.closeTUN() - return err + if en, ok := err.(syscall.Errno); ok && en == windows.ERROR_IO_PENDING { + err = getOverlappedResult(file, overlapped, &n, true) + } + if err != nil { + if tun.isIOCancelled(err) { + return err + } + tun.closeTUN() + return nil + } } return nil diff --git a/tun/ztun_windows.go b/tun/ztun_windows.go new file mode 100644 index 0000000..ed779c1 --- /dev/null +++ b/tun/ztun_windows.go @@ -0,0 +1,61 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package tun + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return nil + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + + procGetOverlappedResult = modkernel32.NewProc("GetOverlappedResult") +) + +func getOverlappedResult(handle windows.Handle, overlapped *windows.Overlapped, done *uint32, wait bool) (err error) { + var _p0 uint32 + if wait { + _p0 = 1 + } else { + _p0 = 0 + } + r1, _, e1 := syscall.Syscall6(procGetOverlappedResult.Addr(), 4, uintptr(handle), uintptr(unsafe.Pointer(overlapped)), uintptr(unsafe.Pointer(done)), uintptr(_p0), 0, 0) + if r1 == 0 { + if e1 != 0 { + err = errnoErr(e1) + } else { + err = syscall.EINVAL + } + } + return +}