diff --git a/tun/tun_windows.go b/tun/tun_windows.go index 9428373..948f08d 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -75,18 +75,11 @@ func CreateTUN(ifname string) (TUNDevice, error) { return nil, err } - go func() { - retries := retryTimeout * retryRate - for { - err := wt.SetInterfaceName(ifname) - if err != nil && retries > 0 { - time.Sleep(time.Second / retryRate) - retries-- - continue - } - return - } - }() + err = wt.SetInterfaceName(ifname) + if err != nil { + wt.DeleteInterface(0) + return nil, err + } err = wt.FlushInterface() if err != nil { diff --git a/tun/wintun/registryhacks_windows.go b/tun/wintun/registryhacks_windows.go new file mode 100644 index 0000000..62a629a --- /dev/null +++ b/tun/wintun/registryhacks_windows.go @@ -0,0 +1,42 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package wintun + +import ( + "golang.org/x/sys/windows/registry" + "time" +) + +const ( + numRetries = 25 + retryTimeout = 100 * time.Millisecond +) + +func registryOpenKeyRetry(k registry.Key, path string, access uint32) (key registry.Key, err error) { + for i := 0; i < numRetries; i++ { + key, err = registry.OpenKey(k, path, access) + if err == nil { + break + } + if i != numRetries - 1 { + time.Sleep(retryTimeout) + } + } + return +} + +func keyGetStringValueRetry(k registry.Key, name string) (val string, valtype uint32, err error) { + for i := 0; i < numRetries; i++ { + val, valtype, err = k.GetStringValue(name) + if err == nil { + break + } + if i != numRetries - 1 { + time.Sleep(retryTimeout) + } + } + return +} diff --git a/tun/wintun/wintun_windows.go b/tun/wintun/wintun_windows.go index ba94b11..77e83a0 100644 --- a/tun/wintun/wintun_windows.go +++ b/tun/wintun/wintun_windows.go @@ -48,22 +48,14 @@ func MakeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfo var valueStr string var valueType uint32 - //TODO: Figure out a way to not need to loop like this. - for i := 0; i < 30; i++ { - // Read the NetCfgInstanceId value. - valueStr, valueType, err = key.GetStringValue("NetCfgInstanceId") - if err != nil { - time.Sleep(time.Millisecond * 100) - continue - } - if valueType != registry.SZ { - return nil, fmt.Errorf("NetCfgInstanceId registry value is not REG_SZ (expected: %v, provided: %v)", registry.SZ, valueType) - } - break - } + // Read the NetCfgInstanceId value. + valueStr, valueType, err = keyGetStringValueRetry(key, "NetCfgInstanceId") if err != nil { return nil, errors.New("RegQueryStringValue(\"NetCfgInstanceId\") failed: " + err.Error()) } + if valueType != registry.SZ { + return nil, fmt.Errorf("NetCfgInstanceId registry value is not REG_SZ (expected: %v, provided: %v)", registry.SZ, valueType) + } // Convert to windows.GUID. ifid, err := guid.FromString(valueStr) @@ -117,7 +109,6 @@ func GetInterface(ifname string, hwndParent uintptr) (*Wintun, error) { // "foobar" would cause conflict with "FooBar". ifname = strings.ToLower(ifname) - // Iterate. for index := 0; ; index++ { // Get the device from the list. Should anything be wrong with this device, continue with next. deviceData, err := devInfoList.EnumDeviceInfo(index) @@ -174,7 +165,7 @@ func GetInterface(ifname string, hwndParent uintptr) (*Wintun, error) { } // This interface is not using Wintun driver. - return wintun, errors.New("Foreign network interface with the same name exists") + return nil, errors.New("Foreign network interface with the same name exists") } } @@ -444,7 +435,7 @@ func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInf // GetInterfaceName returns network interface name. // func (wintun *Wintun) GetInterfaceName() (string, error) { - key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.QUERY_VALUE) + key, err := registryOpenKeyRetry(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.QUERY_VALUE) if err != nil { return "", errors.New("Network-specific registry key open failed: " + err.Error()) } @@ -458,7 +449,7 @@ func (wintun *Wintun) GetInterfaceName() (string, error) { // SetInterfaceName sets network interface name. // func (wintun *Wintun) SetInterfaceName(ifname string) error { - key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.SET_VALUE) + key, err := registryOpenKeyRetry(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.SET_VALUE) if err != nil { return errors.New("Network-specific registry key open failed: " + err.Error()) } @@ -483,7 +474,7 @@ func (wintun *Wintun) GetNetRegKeyName() string { // func getRegStringValue(key registry.Key, name string) (string, error) { // Read string value. - value, valueType, err := key.GetStringValue(name) + value, valueType, err := keyGetStringValueRetry(key, name) if err != nil { return "", err }