wintun: manage ring memory manually

It's large and Go's garbage collector doesn't deal with it especially
well.
This commit is contained in:
Jason A. Donenfeld 2019-11-21 14:48:21 +01:00
parent 4cdf805b29
commit 2b242f9393
2 changed files with 27 additions and 7 deletions

View File

@ -35,11 +35,11 @@ type NativeTun struct {
wt *wintun.Interface wt *wintun.Interface
handle windows.Handle handle windows.Handle
close bool close bool
rings wintun.RingDescriptor
events chan Event events chan Event
errors chan error errors chan error
forcedMTU int forcedMTU int
rate rateJuggler rate rateJuggler
rings *wintun.RingDescriptor
} }
const WintunPool = wintun.Pool("WireGuard") const WintunPool = wintun.Pool("WireGuard")
@ -93,13 +93,13 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
forcedMTU: forcedMTU, forcedMTU: forcedMTU,
} }
err = tun.rings.Init() tun.rings, err = wintun.NewRingDescriptor()
if err != nil { if err != nil {
tun.Close() tun.Close()
return nil, fmt.Errorf("Error creating events: %v", err) return nil, fmt.Errorf("Error creating events: %v", err)
} }
tun.handle, err = tun.wt.Register(&tun.rings) tun.handle, err = tun.wt.Register(tun.rings)
if err != nil { if err != nil {
tun.Close() tun.Close()
return nil, fmt.Errorf("Error registering rings: %v", err) return nil, fmt.Errorf("Error registering rings: %v", err)

View File

@ -6,6 +6,7 @@
package wintun package wintun
import ( import (
"runtime"
"unsafe" "unsafe"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
@ -53,25 +54,44 @@ func PacketAlign(size uint32) uint32 {
return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1) return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1)
} }
func (descriptor *RingDescriptor) Init() (err error) { func NewRingDescriptor() (descriptor *RingDescriptor, err error) {
descriptor = new(RingDescriptor)
allocatedRegion, err := windows.VirtualAlloc(0, unsafe.Sizeof(Ring{})*2, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
if err != nil {
return
}
defer func() {
if err != nil {
descriptor.free()
descriptor = nil
}
}()
descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{})) descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{}))
descriptor.Send.Ring = &Ring{} descriptor.Send.Ring = (*Ring)(unsafe.Pointer(allocatedRegion))
descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil) descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
if err != nil { if err != nil {
return return
} }
descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{})) descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{}))
descriptor.Receive.Ring = &Ring{} descriptor.Receive.Ring = (*Ring)(unsafe.Pointer(allocatedRegion + unsafe.Sizeof(Ring{})))
descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil) descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
if err != nil { if err != nil {
windows.CloseHandle(descriptor.Send.TailMoved) windows.CloseHandle(descriptor.Send.TailMoved)
return return
} }
runtime.SetFinalizer(descriptor, func(d *RingDescriptor) { d.free() })
return return
} }
func (descriptor *RingDescriptor) free() {
if descriptor.Send.Ring != nil {
windows.VirtualFree(uintptr(unsafe.Pointer(descriptor.Send.Ring)), 0, windows.MEM_RELEASE)
descriptor.Send.Ring = nil
descriptor.Receive.Ring = nil
}
}
func (descriptor *RingDescriptor) Close() { func (descriptor *RingDescriptor) Close() {
if descriptor.Send.TailMoved != 0 { if descriptor.Send.TailMoved != 0 {
windows.CloseHandle(descriptor.Send.TailMoved) windows.CloseHandle(descriptor.Send.TailMoved)