diff --git a/api/adapter.c b/api/adapter.c index 30c3afe..608d473 100644 --- a/api/adapter.c +++ b/api/adapter.c @@ -977,6 +977,93 @@ cleanupTcpipAdapterRegKey: return Result; } +static DWORDLONG +VersionOfFile(WCHAR *Filename) +{ + DWORDLONG Version = 0; + DWORD Zero; + DWORD Len = GetFileVersionInfoSizeW(Filename, &Zero); + if (!Len) + return LOG_LAST_ERROR(L"Failed to get version info size"), Version; + VOID *VersionInfo = HeapAlloc(ModuleHeap, 0, Len); + if (!VersionInfo) + { + LOG(WINTUN_LOG_ERR, L"Out of memory"); + return Version; + } + VS_FIXEDFILEINFO *FixedInfo; + UINT FixedInfoLen = sizeof(*FixedInfo); + if (!GetFileVersionInfoW(Filename, 0, Len, VersionInfo)) + { + LOG_LAST_ERROR(L"Failed to get version info"); + goto out; + } + if (!VerQueryValueW(VersionInfo, L"\\", &FixedInfo, &FixedInfoLen)) + { + LOG_LAST_ERROR(L"Failed to query version info root block"); + goto out; + } + Version = (DWORDLONG)FixedInfo->dwFileVersionLS | ((DWORDLONG)FixedInfo->dwFileVersionMS << 32); +out: + HeapFree(ModuleHeap, 0, VersionInfo); + return Version; +} + +static DWORDLONG +RunningWintunVersion(void) +{ + DWORD RequiredSize = 0, CurrentSize = 0; + VOID **Drivers = NULL; + DWORDLONG Version = 0; + for (;;) + { + if (!EnumDeviceDrivers(Drivers, CurrentSize, &RequiredSize)) + { + LOG(WINTUN_LOG_ERR, L"Failed to enumerate drivers"); + return Version; + } + if (CurrentSize == RequiredSize) + break; + if (Drivers) + HeapFree(ModuleHeap, 0, Drivers); + Drivers = HeapAlloc(ModuleHeap, 0, RequiredSize); + if (!Drivers) + { + LOG(WINTUN_LOG_ERR, L"Out of memory"); + return Version; + } + CurrentSize = RequiredSize; + } + WCHAR MaybeWintun[11]; + for (DWORD i = CurrentSize / sizeof(Drivers[0]); i-- > 0;) + { + if (GetDeviceDriverBaseNameW(Drivers[i], MaybeWintun, _countof(MaybeWintun)) == 10 && + !_wcsicmp(MaybeWintun, L"wintun.sys")) + { + WCHAR WintunPath[MAX_PATH + 2]; + DWORD Len = GetDeviceDriverFileNameW(Drivers[i], WintunPath, _countof(WintunPath)); + if (!Len || Len == _countof(WintunPath) - 1) + { + LOG(WINTUN_LOG_ERR, L"Failed to locate driver path"); + goto out; + } + Version = VersionOfFile(WintunPath); + goto out; + } + } +out: + HeapFree(ModuleHeap, 0, Drivers); + return Version; +} + +static BOOL EnsureWintunUnloaded(VOID) +{ + BOOL Loaded; + for (int i = 0; (Loaded = RunningWintunVersion() != 0) != FALSE && i < 300; ++i) + Sleep(50); + return !Loaded; +} + static WINTUN_STATUS CreateAdapter( _In_z_count_c_(MAX_PATH) const WCHAR *InfPath, @@ -1626,14 +1713,46 @@ WintunCreateAdapter( goto cleanupDelete; } + DWORDLONG LoadedDriverVersion = RunningWintunVersion(); + SP_DEVINFO_DATA_LIST *ExistingAdapters = NULL; + HDEVINFO DevInfo = INVALID_HANDLE_VALUE; + if (LoadedDriverVersion) + { + DWORDLONG ProposedDriverVersion = VersionOfFile(SysPath); + if (!ProposedDriverVersion) + { + LOG(WINTUN_LOG_ERR, L"Unable to query version of sys file"); + goto cleanupDelete; + } + if (ProposedDriverVersion > LoadedDriverVersion) + { + DevInfo = SetupDiGetClassDevsExW(&GUID_DEVCLASS_NET, NULL, NULL, DIGCF_PRESENT, NULL, NULL, NULL); + if (DevInfo == INVALID_HANDLE_VALUE) + { + Result = LOG_LAST_ERROR(L"Failed to get present class devices"); + goto cleanupDelete; + } + AdapterDisableAllOurs(DevInfo, &ExistingAdapters); + LOG(WINTUN_LOG_INFO, L"Waiting for existing driver to unload from kernel"); + if (!EnsureWintunUnloaded()) + LOG(WINTUN_LOG_WARN, L"Unable to unload existing driver, which means a reboot will likely be required"); + } + } + LOG(WINTUN_LOG_INFO, L"Installing driver"); WCHAR InfStorePath[MAX_PATH]; WCHAR *InfStoreFilename; if (!SetupCopyOEMInfW(InfPath, NULL, SPOST_PATH, 0, InfStorePath, _countof(InfStorePath), NULL, &InfStoreFilename)) { Result = LOG_LAST_ERROR(L"Could not install driver to store"); - goto cleanupDelete; + goto cleanupCloseDevInfo; } + BOOL UpdateRebootRequired = FALSE; + if (ExistingAdapters && + !UpdateDriverForPlugAndPlayDevicesW( + NULL, WINTUN_HWID, InfPath, INSTALLFLAG_FORCE | INSTALLFLAG_NONINTERACTIVE, &UpdateRebootRequired)) + LOG(WINTUN_LOG_WARN, L"Could not update existing adapters"); + *RebootRequired = *RebootRequired || UpdateRebootRequired; Result = CreateAdapter(InfPath, Pool, Name, RequestedGUID, Adapter, RebootRequired); @@ -1643,6 +1762,19 @@ WintunCreateAdapter( LOG_LAST_ERROR(L"Unable to remove existing driver"); Result = Result != ERROR_SUCCESS ? Result : GetLastError(); } +cleanupCloseDevInfo: + if (ExistingAdapters) + { + AdapterEnableAll(DevInfo, ExistingAdapters); + while (ExistingAdapters) + { + SP_DEVINFO_DATA_LIST *Next = ExistingAdapters->Next; + HeapFree(ModuleHeap, 0, ExistingAdapters); + ExistingAdapters = Next; + } + } + if (DevInfo != INVALID_HANDLE_VALUE) + SetupDiDestroyDeviceInfoList(DevInfo); cleanupDelete: DeleteFileW(CatPath); DeleteFileW(SysPath); diff --git a/api/api.vcxproj b/api/api.vcxproj index cea689d..7fed028 100644 --- a/api/api.vcxproj +++ b/api/api.vcxproj @@ -157,8 +157,8 @@ _M_ARM64=1;%(PreprocessorDefinitions) - bcrypt.dll;iphlpapi.dll - Bcrypt.lib;Crypt32.lib;Cfgmgr32.lib;Iphlpapi.lib;ntdll.lib;Setupapi.lib;shlwapi.lib;%(AdditionalDependencies) + bcrypt.dll;iphlpapi.dll;newdev.dll;version.dll + Bcrypt.lib;Crypt32.lib;Cfgmgr32.lib;Iphlpapi.lib;newdev.lib;ntdll.lib;Setupapi.lib;shlwapi.lib;version.lib;%(AdditionalDependencies) exports.def Windows