From 9504191ba0ba32460aae358ff7727141605127be Mon Sep 17 00:00:00 2001 From: Simon Rozman Date: Wed, 5 Jun 2019 13:04:22 +0200 Subject: [PATCH] Unify interlocked reference counting - Only positive values are valid. - Add missing overflow assert checks - Fix spacing Signed-off-by: Simon Rozman --- wintun.c | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/wintun.c b/wintun.c index 07aacac..6009711 100644 --- a/wintun.c +++ b/wintun.c @@ -182,7 +182,7 @@ _IRQL_requires_max_(DISPATCH_LEVEL) static NDIS_STATUS TunCompletePause(_Inout_ TUN_CTX *ctx, _In_ BOOLEAN async_completion) { ASSERT(InterlockedGet64(&ctx->ActiveTransactionCount) > 0); - if (!InterlockedDecrement64(&ctx->ActiveTransactionCount) && + if (InterlockedDecrement64(&ctx->ActiveTransactionCount) <= 0 && InterlockedCompareExchange((LONG *)&ctx->State, TUN_STATE_PAUSED, TUN_STATE_PAUSING) == TUN_STATE_PAUSING) { if (async_completion) NdisMPauseComplete(ctx->MiniportAdapterHandle); @@ -379,7 +379,9 @@ static NTSTATUS TunWriteIntoIrp(_Inout_ IRP *Irp, _Inout_ UCHAR *buffer, _In_ NE _IRQL_requires_same_ static void TunNBLRefInit(_Inout_ TUN_CTX *ctx, _Inout_ NET_BUFFER_LIST *nbl) { + ASSERT(InterlockedGet64(&ctx->ActiveTransactionCount) < MAXLONG64); InterlockedIncrement64(&ctx->ActiveTransactionCount); + ASSERT(InterlockedGet(&ctx->PacketQueue.NumNbl) < MAXLONG); InterlockedIncrement(&ctx->PacketQueue.NumNbl); InterlockedExchange64(NET_BUFFER_LIST_REFCOUNT(nbl), 1); } @@ -388,6 +390,7 @@ _IRQL_requires_same_ static void TunNBLRefInc(_Inout_ NET_BUFFER_LIST *nbl) { ASSERT(InterlockedGet64(NET_BUFFER_LIST_REFCOUNT(nbl))); + ASSERT(InterlockedGet64(NET_BUFFER_LIST_REFCOUNT(nbl)) < MAXLONG64); InterlockedIncrement64(NET_BUFFER_LIST_REFCOUNT(nbl)); } @@ -395,10 +398,11 @@ _When_( (SendCompleteFlags & NDIS_SEND_COMPLETE_FLAGS_DISPATCH_LEVEL), _IRQL_req _When_(!(SendCompleteFlags & NDIS_SEND_COMPLETE_FLAGS_DISPATCH_LEVEL), _IRQL_requires_max_(DISPATCH_LEVEL)) static BOOLEAN TunNBLRefDec(_Inout_ TUN_CTX *ctx, _Inout_ NET_BUFFER_LIST *nbl, _In_ ULONG SendCompleteFlags) { - ASSERT(InterlockedGet64(NET_BUFFER_LIST_REFCOUNT(nbl))); - if (!InterlockedDecrement64(NET_BUFFER_LIST_REFCOUNT(nbl))) { + ASSERT(InterlockedGet64(NET_BUFFER_LIST_REFCOUNT(nbl)) > 0); + if (InterlockedDecrement64(NET_BUFFER_LIST_REFCOUNT(nbl)) <= 0) { NET_BUFFER_LIST_NEXT_NBL(nbl) = NULL; NdisMSendNetBufferListsComplete(ctx->MiniportAdapterHandle, nbl, SendCompleteFlags); + ASSERT(InterlockedGet(&ctx->PacketQueue.NumNbl) > 0); InterlockedDecrement(&ctx->PacketQueue.NumNbl); TunCompletePause(ctx, TRUE); return TRUE; @@ -607,10 +611,10 @@ static NTSTATUS TunWriteFromIrp(_Inout_ TUN_CTX *ctx, _Inout_ IRP *Irp) goto cleanup_TunCompletePause; const UCHAR *b = buffer, *b_end = buffer + size; - ULONG nbl_count = 0; + LONG nbl_count = 0; NET_BUFFER_LIST *nbl_head = NULL, *nbl_tail = NULL; while (b + sizeof(TUN_PACKET) <= b_end) { - if (nbl_count >= MAXULONG) { + if (nbl_count >= MAXLONG) { status = STATUS_INVALID_USER_BUFFER; goto cleanup_nbl_head; @@ -874,8 +878,8 @@ static void TunReturnNetBufferLists(NDIS_HANDLE MiniportAdapterContext, PNET_BUF NdisFreeMdl(mdl); NdisFreeNetBufferList(nbl); - ASSERT(InterlockedGet(IRP_REFCOUNT(irp))); - if (!InterlockedDecrement(IRP_REFCOUNT(irp))) { + ASSERT(InterlockedGet(IRP_REFCOUNT(irp)) > 0); + if (InterlockedDecrement(IRP_REFCOUNT(irp)) <= 0) { TunCompleteRequest(ctx, irp, STATUS_SUCCESS, IO_NETWORK_INCREMENT); TunCompletePause(ctx, TRUE); } @@ -1250,6 +1254,7 @@ static NDIS_STATUS TunInitializeEx(NDIS_HANDLE MiniportAdapterHandle, NDIS_HANDL * of the MiniportInitializeEx function. */ TunIndicateStatus(MiniportAdapterHandle, MediaConnectStateDisconnected); + ASSERT(InterlockedGet64(&AdapterCount) < MAXLONG64); InterlockedIncrement64(&AdapterCount); InterlockedExchange((LONG *)&ctx->State, TUN_STATE_PAUSED); return NDIS_STATUS_SUCCESS; @@ -1287,7 +1292,7 @@ static void TunHaltEx(NDIS_HANDLE MiniportAdapterContext, NDIS_HALT_ACTION HaltA ASSERT(!InterlockedGet64(&ctx->ActiveTransactionCount)); /* Adapter should not be halted if it wasn't fully paused first. */ - InterlockedExchange((LONG*)& ctx->State, TUN_STATE_HALTING); + InterlockedExchange((LONG *)&ctx->State, TUN_STATE_HALTING); if (ctx->PnPNotifications.Handle) { PVOID h = ctx->PnPNotifications.Handle; @@ -1325,10 +1330,11 @@ static void TunHaltEx(NDIS_HANDLE MiniportAdapterContext, NDIS_HALT_ACTION HaltA /* MiniportAdapterHandle must not be used in TunDispatch(). After TunHaltEx() returns it is invalidated. */ ctx->MiniportAdapterHandle = NULL; - InterlockedExchange((LONG*)& ctx->PowerState, NdisDeviceStateUnspecified); - InterlockedExchange((LONG*)& ctx->State, TUN_STATE_HALTED); + InterlockedExchange((LONG *)&ctx->PowerState, NdisDeviceStateUnspecified); + InterlockedExchange((LONG *)&ctx->State, TUN_STATE_HALTED); - if (!InterlockedDecrement64(&AdapterCount)) + ASSERT(InterlockedGet64(&AdapterCount) > 0); + if (InterlockedDecrement64(&AdapterCount) <= 0) TunWaitForReferencesToDropToZero(ctx); /* Deregister device _after_ we are done writing to ctx not to risk an UaF. The ctx is hosted by device extension. */