diff --git a/wintun.c b/wintun.c index 45a94b2..20aad05 100644 --- a/wintun.c +++ b/wintun.c @@ -112,9 +112,11 @@ typedef struct _TUN_CTX typedef struct _TUN_MAPPED_UBUFFER { - VOID *UserAddress, *KernelAddress; + VOID *volatile UserAddress; + VOID *KernelAddress; MDL *Mdl; ULONG Size; + FAST_MUTEX InitializationComplete; // TODO: ThreadID for checking } TUN_MAPPED_UBUFFER; @@ -295,38 +297,56 @@ _Must_inspect_result_ static NTSTATUS TunMapUbuffer(_Inout_ TUN_MAPPED_UBUFFER *MappedBuffer, _In_ VOID *UserAddress, _In_ ULONG Size) { - if (MappedBuffer->UserAddress) + VOID *current_uaddr = InterlockedGetPointer(&MappedBuffer->UserAddress); + if (current_uaddr) { - if (UserAddress == MappedBuffer->UserAddress) // TODO: Check ThreadID + if (UserAddress == current_uaddr) // TODO: Check ThreadID return STATUS_SUCCESS; return STATUS_ALREADY_INITIALIZED; } + NTSTATUS status = STATUS_SUCCESS; + ExAcquireFastMutex(&MappedBuffer->InitializationComplete); + + // Recheck the same thing as above, but locked this time. + current_uaddr = InterlockedGetPointer(&MappedBuffer->UserAddress); + if (current_uaddr) + { + if (UserAddress != current_uaddr) // TODO: Check ThreadID + status = STATUS_ALREADY_INITIALIZED; + goto err_releasemutex; + } + MappedBuffer->Mdl = IoAllocateMdl(UserAddress, Size, FALSE, FALSE, NULL); + status = STATUS_INSUFFICIENT_RESOURCES; if (!MappedBuffer->Mdl) - return STATUS_INSUFFICIENT_RESOURCES; + goto err_releasemutex; + + status = STATUS_INVALID_USER_BUFFER; try { MmProbeAndLockPages(MappedBuffer->Mdl, UserMode, IoWriteAccess); } - except(EXCEPTION_EXECUTE_HANDLER) - { - IoFreeMdl(MappedBuffer->Mdl); - MappedBuffer->Mdl = NULL; - return STATUS_INVALID_USER_BUFFER; - } + except(EXCEPTION_EXECUTE_HANDLER) { goto err_freemdl; } + MappedBuffer->KernelAddress = MmGetSystemAddressForMdlSafe(MappedBuffer->Mdl, NormalPagePriority | MdlMappingNoExecute); + status = STATUS_INSUFFICIENT_RESOURCES; if (!MappedBuffer->KernelAddress) - { - MmUnlockPages(MappedBuffer->Mdl); - IoFreeMdl(MappedBuffer->Mdl); - MappedBuffer->Mdl = NULL; - return STATUS_INSUFFICIENT_RESOURCES; - } - MappedBuffer->UserAddress = UserAddress; + goto err_unlockmdl; MappedBuffer->Size = Size; + InterlockedExchangePointer(&MappedBuffer->UserAddress, UserAddress); + ExReleaseFastMutex(&MappedBuffer->InitializationComplete); return STATUS_SUCCESS; + +err_unlockmdl: + MmUnlockPages(MappedBuffer->Mdl); +err_freemdl: + IoFreeMdl(MappedBuffer->Mdl); + MappedBuffer->Mdl = NULL; +err_releasemutex: + ExReleaseFastMutex(&MappedBuffer->InitializationComplete); + return status; } _IRQL_requires_max_(DISPATCH_LEVEL) @@ -974,6 +994,8 @@ TunDispatchCreate(_Inout_ TUN_CTX *Ctx, _Inout_ IRP *Irp) if (!file_ctx) return STATUS_INSUFFICIENT_RESOURCES; RtlZeroMemory(file_ctx, sizeof(*file_ctx)); + ExInitializeFastMutex(&file_ctx->ReadBuffer.InitializationComplete); + ExInitializeFastMutex(&file_ctx->WriteBuffer.InitializationComplete); KIRQL irql = ExAcquireSpinLockShared(&Ctx->TransitionLock); LONG flags = InterlockedGet(&Ctx->Flags);