diff --git a/wintun.c b/wintun.c index 694365d..7529b20 100644 --- a/wintun.c +++ b/wintun.c @@ -66,6 +66,7 @@ typedef struct _TUN_CTX { struct { NDIS_HANDLE Handle; volatile LONG64 RefCount; + IO_REMOVE_LOCK RemoveLock; struct { KSPIN_LOCK Lock; @@ -143,11 +144,12 @@ static void TunIndicateStatus(_In_ NDIS_HANDLE MiniportAdapterHandle, _In_ NDIS_ } _IRQL_requires_max_(DISPATCH_LEVEL) -static void TunCompleteRequest(_Inout_ IRP *Irp, _In_ ULONG_PTR Information, _In_ NTSTATUS Status) +static void TunCompleteRequest(_Inout_ TUN_CTX *ctx, _Inout_ IRP *Irp, _In_ ULONG_PTR Information, _In_ NTSTATUS Status) { Irp->IoStatus.Information = Information; Irp->IoStatus.Status = Status; IoCompleteRequest(Irp, IO_NO_INCREMENT); + IoReleaseRemoveLock(&ctx->Device.RemoveLock, Irp); } _IRQL_requires_same_ @@ -240,7 +242,8 @@ static IO_CSQ_COMPLETE_CANCELED_IRP TunCsqCompleteCanceledIrp; _Use_decl_annotations_ static VOID TunCsqCompleteCanceledIrp(IO_CSQ *Csq, IRP *Irp) { - TunCompleteRequest(Irp, 0, STATUS_CANCELLED); + TUN_CTX *ctx = CONTAINING_RECORD(Csq, TUN_CTX, Device.ReadQueue.Csq); + TunCompleteRequest(ctx, Irp, 0, STATUS_CANCELLED); } _IRQL_requires_same_ @@ -315,6 +318,7 @@ retry: if (!NT_SUCCESS(status)) { irp->IoStatus.Status = status; IoCompleteRequest(irp, IO_NO_INCREMENT); + IoReleaseRemoveLock(&ctx->Device.RemoveLock, irp); goto retry; } @@ -555,6 +559,7 @@ static void TunQueueProcess(_Inout_ TUN_CTX *ctx) } else { irp->IoStatus.Status = STATUS_SUCCESS; IoCompleteRequest(irp, IO_NETWORK_INCREMENT); + IoReleaseRemoveLock(&ctx->Device.RemoveLock, irp); irp = NULL; } @@ -702,59 +707,64 @@ static NTSTATUS TunDispatch(DEVICE_OBJECT *DeviceObject, IRP *Irp) IO_STACK_LOCATION *stack = IoGetCurrentIrpStackLocation(Irp); switch (stack->MajorFunction) { case IRP_MJ_READ: - if (InterlockedGet((LONG *)&ctx->State) < TUN_STATE_PAUSED) { - status = STATUS_FILE_FORCED_CLOSED; - break; - } + if ((status = STATUS_FILE_FORCED_CLOSED, InterlockedGet((LONG *)&ctx->State) < TUN_STATE_PAUSED) || + !NT_SUCCESS(status = IoAcquireRemoveLock(&ctx->Device.RemoveLock, Irp))) + goto cleanup_complete_req; - status = IoCsqInsertIrpEx(&ctx->Device.ReadQueue.Csq, Irp, NULL, TUN_CSQ_INSERT_TAIL); - if (!NT_SUCCESS(status)) - break; + if (!NT_SUCCESS(status = IoCsqInsertIrpEx(&ctx->Device.ReadQueue.Csq, Irp, NULL, TUN_CSQ_INSERT_TAIL))) + goto cleanup_complete_req_and_release_remove_lock; TunQueueProcess(ctx); return STATUS_PENDING; case IRP_MJ_WRITE: - if (InterlockedGet((LONG *)&ctx->State) < TUN_STATE_PAUSED) { - status = STATUS_FILE_FORCED_CLOSED; - break; - } + if ((status = STATUS_FILE_FORCED_CLOSED, InterlockedGet((LONG *)&ctx->State) < TUN_STATE_PAUSED) || + !NT_SUCCESS(status = IoAcquireRemoveLock(&ctx->Device.RemoveLock, Irp))) + goto cleanup_complete_req; status = TunWriteFromIrp(ctx, Irp); - break; + goto cleanup_complete_req_and_release_remove_lock; case IRP_MJ_CREATE: - if (InterlockedGet((LONG *)&ctx->State) < TUN_STATE_PAUSED) { - status = STATUS_DELETE_PENDING; - break; - } + if ((status = STATUS_DELETE_PENDING, InterlockedGet((LONG *)&ctx->State) < TUN_STATE_PAUSED) || + !NT_SUCCESS(status = IoAcquireRemoveLock(&ctx->Device.RemoveLock, Irp))) + goto cleanup_complete_req; ASSERT(InterlockedGet64(&ctx->Device.RefCount) < MAXLONG64); if (InterlockedIncrement64(&ctx->Device.RefCount) > 0) TunIndicateStatus(ctx->MiniportAdapterHandle, MediaConnectStateConnected); + IoAcquireRemoveLock(&ctx->Device.RemoveLock, stack->FileObject); status = STATUS_SUCCESS; - break; + goto cleanup_complete_req_and_release_remove_lock; case IRP_MJ_CLOSE: ASSERT(InterlockedGet64(&ctx->Device.RefCount) > 0); if (InterlockedDecrement64(&ctx->Device.RefCount) <= 0 && ctx->MiniportAdapterHandle) TunIndicateStatus(ctx->MiniportAdapterHandle, MediaConnectStateDisconnected); + IoReleaseRemoveLock(&ctx->Device.RemoveLock, stack->FileObject); status = STATUS_SUCCESS; - break; + goto cleanup_complete_req; case IRP_MJ_CLEANUP: for (IRP *pending_irp; (pending_irp = IoCsqRemoveNextIrp(&ctx->Device.ReadQueue.Csq, stack->FileObject)) != NULL; ) - TunCompleteRequest(pending_irp, 0, STATUS_CANCELLED); + TunCompleteRequest(ctx, pending_irp, 0, STATUS_CANCELLED); status = STATUS_SUCCESS; - break; + goto cleanup_complete_req; default: status = STATUS_INVALID_PARAMETER; + goto cleanup_complete_req; } +cleanup_complete_req_and_release_remove_lock: + Irp->IoStatus.Status = status; + IoCompleteRequest(Irp, IO_NO_INCREMENT); + IoReleaseRemoveLock(&ctx->Device.RemoveLock, Irp); + return status; + cleanup_complete_req: Irp->IoStatus.Status = status; IoCompleteRequest(Irp, IO_NO_INCREMENT); @@ -948,6 +958,7 @@ static NDIS_STATUS TunInitializeEx(NDIS_HANDLE MiniportAdapterHandle, NDIS_HANDL NDIS_STATISTICS_FLAGS_VALID_BROADCAST_BYTES_XMIT; ctx->Device.Handle = handle; + IoInitializeRemoveLock(&ctx->Device.RemoveLock, TUN_MEMORY_TAG, 0, 0); KeInitializeSpinLock(&ctx->Device.ReadQueue.Lock); IoCsqInitializeEx(&ctx->Device.ReadQueue.Csq, TunCsqInsertIrpEx, @@ -1101,15 +1112,21 @@ static void TunHaltEx(NDIS_HANDLE MiniportAdapterContext, NDIS_HALT_ACTION HaltA /* Complete pending IRPs to unblock waiting clients. */ for (IRP *pending_irp; (pending_irp = IoCsqRemoveNextIrp(&ctx->Device.ReadQueue.Csq, NULL)) != NULL;) - TunCompleteRequest(pending_irp, 0, STATUS_FILE_FORCED_CLOSED); + TunCompleteRequest(ctx, pending_irp, 0, STATUS_FILE_FORCED_CLOSED); NdisFreeNetBufferListPool(ctx->NBLPool); ctx->NBLPool = NULL; ctx->MiniportAdapterHandle = NULL; + /* Wait for all device handles to close. */ + /* TODO: Research how to close all handles from within the driver, rather than depending on client to close them. */ + IoAcquireRemoveLock(&ctx->Device.RemoveLock, NULL); + IoReleaseRemoveLockAndWait(&ctx->Device.RemoveLock, NULL); + InterlockedExchange((LONG *)&ctx->State, TUN_STATE_HALTED); + /* Deregister device _after_ we are done writing to ctx not to risk an UaF. The ctx is hosted by device extension. */ NdisDeregisterDeviceEx(ctx->Device.Handle); }