api: refuse to load DLL on init failure

Signed-off-by: Simon Rozman <simon@rozman.si>
This commit is contained in:
Simon Rozman 2020-10-31 08:15:52 +01:00
parent e7a85b7b28
commit 08440580c3
3 changed files with 13 additions and 4 deletions

View File

@ -64,11 +64,14 @@ DllMain(_In_ HINSTANCE hinstDLL, _In_ DWORD fdwReason, _In_ LPVOID lpvReserved)
case DLL_PROCESS_ATTACH: case DLL_PROCESS_ATTACH:
ResourceModule = hinstDLL; ResourceModule = hinstDLL;
ModuleHeap = HeapCreate(0, 0, 0); ModuleHeap = HeapCreate(0, 0, 0);
if (!ModuleHeap)
return FALSE;
ConvertStringSecurityDescriptorToSecurityDescriptorW( ConvertStringSecurityDescriptorToSecurityDescriptorW(
L"O:SYD:P(A;;GA;;;SY)", SDDL_REVISION_1, &SecurityAttributes.lpSecurityDescriptor, NULL); L"O:SYD:P(A;;GA;;;SY)", SDDL_REVISION_1, &SecurityAttributes.lpSecurityDescriptor, NULL);
AdapterInit(); AdapterInit();
NamespaceInit(); NamespaceInit();
NciInit(); if (NciInit() != ERROR_SUCCESS)
return FALSE;
break; break;
case DLL_PROCESS_DETACH: case DLL_PROCESS_DETACH:

View File

@ -15,16 +15,21 @@ DWORD(WINAPI *NciGetConnectionName)
_In_ DWORD InDestNameBytes, _In_ DWORD InDestNameBytes,
_Out_opt_ DWORD *OutDestNameBytes); _Out_opt_ DWORD *OutDestNameBytes);
void WINTUN_STATUS
NciInit(void) NciInit(void)
{ {
NciModule = LoadLibraryExW(L"nci.dll", NULL, LOAD_LIBRARY_SEARCH_SYSTEM32); NciModule = LoadLibraryExW(L"nci.dll", NULL, LOAD_LIBRARY_SEARCH_SYSTEM32);
if (!NciModule) if (!NciModule)
abort(); return GetLastError();
NciSetConnectionName = NciSetConnectionName =
(DWORD(WINAPI *)(const GUID *, const WCHAR *))GetProcAddress(NciModule, "NciSetConnectionName"); (DWORD(WINAPI *)(const GUID *, const WCHAR *))GetProcAddress(NciModule, "NciSetConnectionName");
if (!NciSetConnectionName)
return GetLastError();
NciGetConnectionName = NciGetConnectionName =
(DWORD(WINAPI *)(const GUID *, WCHAR *, DWORD, DWORD *))GetProcAddress(NciModule, "NciGetConnectionName"); (DWORD(WINAPI *)(const GUID *, WCHAR *, DWORD, DWORD *))GetProcAddress(NciModule, "NciGetConnectionName");
if (!NciGetConnectionName)
return GetLastError();
return ERROR_SUCCESS;
} }
void void

View File

@ -5,6 +5,7 @@
#pragma once #pragma once
#include "wintun.h"
#include <Windows.h> #include <Windows.h>
extern DWORD(WINAPI *NciSetConnectionName)(_In_ const GUID *Guid, _In_z_ const WCHAR *NewName); extern DWORD(WINAPI *NciSetConnectionName)(_In_ const GUID *Guid, _In_z_ const WCHAR *NewName);
@ -15,7 +16,7 @@ extern DWORD(WINAPI *NciGetConnectionName)(
_In_ DWORD InDestNameBytes, _In_ DWORD InDestNameBytes,
_Out_opt_ DWORD *OutDestNameBytes); _Out_opt_ DWORD *OutDestNameBytes);
void WINTUN_STATUS
NciInit(void); NciInit(void);
void void