diff --git a/wintun.c b/wintun.c index 9ba1b1f..ffca16d 100644 --- a/wintun.c +++ b/wintun.c @@ -799,9 +799,6 @@ static NTSTATUS TunDispatchWrite(_Inout_ TUN_CTX *Ctx, _Inout_ IRP *Irp) { NTSTATUS Status; - - InterlockedIncrement64(&Ctx->ActiveNBLCount); - IO_STACK_LOCATION *Stack = IoGetCurrentIrpStackLocation(Irp); ULONG Size = Stack->Parameters.Write.Length; if (Status = STATUS_INVALID_USER_BUFFER, (Size < TUN_EXCH_MIN_BUFFER_SIZE_WRITE || Size > TUN_EXCH_MAX_BUFFER_SIZE)) @@ -823,11 +820,6 @@ TunDispatchWrite(_Inout_ TUN_CTX *Ctx, _Inout_ IRP *Irp) if (Status = STATUS_INSUFFICIENT_RESOURCES, !Mdl) goto cleanup_ExFreePoolWithTag; - KIRQL Irql = ExAcquireSpinLockShared(&Ctx->TransitionLock); - LONG Flags = InterlockedGet(&Ctx->Flags); - if (Status = STATUS_FILE_FORCED_CLOSED, !(Flags & TUN_FLAGS_PRESENT)) - goto cleanup_ExReleaseSpinLockShared; - const UCHAR *BufferPos = BufferStart, *BufferEnd = BufferStart + Size; typedef enum { @@ -852,24 +844,14 @@ TunDispatchWrite(_Inout_ TUN_CTX *Ctx, _Inout_ IRP *Irp) LONG NblCount = 0; while (BufferEnd - BufferPos >= sizeof(TUN_PACKET)) { - if (NblCount >= MAXLONG) - { - Status = STATUS_INVALID_USER_BUFFER; + if (Status = STATUS_INVALID_USER_BUFFER, NblCount >= MAXLONG) goto cleanup_nbl_queues; - } - TUN_PACKET *Packet = (TUN_PACKET *)BufferPos; - if (Packet->Size > TUN_EXCH_MAX_IP_PACKET_SIZE) - { - Status = STATUS_INVALID_USER_BUFFER; + if (Status = STATUS_INVALID_USER_BUFFER, Packet->Size > TUN_EXCH_MAX_IP_PACKET_SIZE) goto cleanup_nbl_queues; - } ULONG AlignedPacketSize = TunPacketAlign(sizeof(TUN_PACKET) + Packet->Size); - if (BufferEnd - BufferPos < (ptrdiff_t)AlignedPacketSize) - { - Status = STATUS_INVALID_USER_BUFFER; + if (Status = STATUS_INVALID_USER_BUFFER, (BufferEnd - BufferPos < (ptrdiff_t)AlignedPacketSize)) goto cleanup_nbl_queues; - } EtherTypeIndex Index; if (Packet->Size >= 20 && Packet->Data[0] >> 4 == 4) @@ -884,11 +866,8 @@ TunDispatchWrite(_Inout_ TUN_CTX *Ctx, _Inout_ IRP *Irp) NET_BUFFER_LIST *Nbl = NdisAllocateNetBufferAndNetBufferList( Ctx->NBLPool, 0, 0, Mdl, (ULONG)(Packet->Data - BufferStart), Packet->Size); - if (!Nbl) - { - Status = STATUS_INSUFFICIENT_RESOURCES; + if (Status = STATUS_INSUFFICIENT_RESOURCES, !Nbl) goto cleanup_nbl_queues; - } Nbl->SourceHandle = Ctx->MiniportAdapterHandle; NdisSetNblFlag(Nbl, EtherTypeConstants[Index].NblFlags); @@ -901,29 +880,25 @@ TunDispatchWrite(_Inout_ TUN_CTX *Ctx, _Inout_ IRP *Irp) BufferPos += AlignedPacketSize; } - if ((ULONG)(BufferPos - BufferStart) != Size) - { - Status = STATUS_INVALID_USER_BUFFER; + if (Status = STATUS_INVALID_USER_BUFFER, (ULONG)(BufferPos - BufferStart) != Size) goto cleanup_nbl_queues; - } Irp->IoStatus.Information = Size; - if (!NblCount) - { - Status = STATUS_SUCCESS; - goto cleanup_ExReleaseSpinLockShared; - } - if (!(Flags & TUN_FLAGS_RUNNING)) + if (Status = STATUS_SUCCESS, !NblCount) + goto cleanup_nbl_queues; + + KIRQL Irql = ExAcquireSpinLockShared(&Ctx->TransitionLock); + LONG Flags = InterlockedGet(&Ctx->Flags); + if ((Status = STATUS_FILE_FORCED_CLOSED, !(Flags & TUN_FLAGS_PRESENT)) || + (Status = STATUS_SUCCESS, !(Flags & TUN_FLAGS_RUNNING))) { InterlockedAdd64((LONG64 *)&Ctx->Statistics.ifInDiscards, NblCount); InterlockedAdd64((LONG64 *)&Ctx->Statistics.ifInErrors, NblCount); - Status = STATUS_SUCCESS; - goto cleanup_nbl_queues; + goto cleanup_ExReleaseSpinLockShared; } - InterlockedAdd64(&Ctx->ActiveNBLCount, NblCount); + InterlockedAdd64(&Ctx->ActiveNBLCount, NblCount + 1); *MdlRefcount = NblCount; - for (EtherTypeIndex Index = EtherTypeIndexStart; Index < EtherTypeIndexEnd; Index++) { if (!NblQueue[Index].Head) @@ -937,10 +912,12 @@ TunDispatchWrite(_Inout_ TUN_CTX *Ctx, _Inout_ IRP *Irp) } ExReleaseSpinLockShared(&Ctx->TransitionLock, Irql); - TunCompleteRequest(Ctx, Irp, STATUS_SUCCESS, IO_NETWORK_INCREMENT); TunCompletePause(Ctx, TRUE); + TunCompleteRequest(Ctx, Irp, STATUS_SUCCESS, IO_NETWORK_INCREMENT); return STATUS_SUCCESS; +cleanup_ExReleaseSpinLockShared: + ExReleaseSpinLockShared(&Ctx->TransitionLock, Irql); cleanup_nbl_queues: for (EtherTypeIndex Index = EtherTypeIndexStart; Index < EtherTypeIndexEnd; Index++) { @@ -951,14 +928,11 @@ cleanup_nbl_queues: NdisFreeNetBufferList(Nbl); } } -cleanup_ExReleaseSpinLockShared: - ExReleaseSpinLockShared(&Ctx->TransitionLock, Irql); NdisFreeMdl(Mdl); cleanup_ExFreePoolWithTag: ExFreePoolWithTag(BufferStart, TUN_MEMORY_TAG); cleanup_CompleteRequest: TunCompleteRequest(Ctx, Irp, Status, IO_NO_INCREMENT); - TunCompletePause(Ctx, TRUE); return Status; }