Unify interlocked reference counting

- Only positive values are valid.
- Add missing overflow assert checks
- Fix spacing

Signed-off-by: Simon Rozman <simon@rozman.si>
This commit is contained in:
Simon Rozman 2019-06-05 13:04:22 +02:00 committed by Jason A. Donenfeld
parent 6c405efc42
commit 9504191ba0

View File

@ -182,7 +182,7 @@ _IRQL_requires_max_(DISPATCH_LEVEL)
static NDIS_STATUS TunCompletePause(_Inout_ TUN_CTX *ctx, _In_ BOOLEAN async_completion) static NDIS_STATUS TunCompletePause(_Inout_ TUN_CTX *ctx, _In_ BOOLEAN async_completion)
{ {
ASSERT(InterlockedGet64(&ctx->ActiveTransactionCount) > 0); ASSERT(InterlockedGet64(&ctx->ActiveTransactionCount) > 0);
if (!InterlockedDecrement64(&ctx->ActiveTransactionCount) && if (InterlockedDecrement64(&ctx->ActiveTransactionCount) <= 0 &&
InterlockedCompareExchange((LONG *)&ctx->State, TUN_STATE_PAUSED, TUN_STATE_PAUSING) == TUN_STATE_PAUSING) { InterlockedCompareExchange((LONG *)&ctx->State, TUN_STATE_PAUSED, TUN_STATE_PAUSING) == TUN_STATE_PAUSING) {
if (async_completion) if (async_completion)
NdisMPauseComplete(ctx->MiniportAdapterHandle); NdisMPauseComplete(ctx->MiniportAdapterHandle);
@ -379,7 +379,9 @@ static NTSTATUS TunWriteIntoIrp(_Inout_ IRP *Irp, _Inout_ UCHAR *buffer, _In_ NE
_IRQL_requires_same_ _IRQL_requires_same_
static void TunNBLRefInit(_Inout_ TUN_CTX *ctx, _Inout_ NET_BUFFER_LIST *nbl) static void TunNBLRefInit(_Inout_ TUN_CTX *ctx, _Inout_ NET_BUFFER_LIST *nbl)
{ {
ASSERT(InterlockedGet64(&ctx->ActiveTransactionCount) < MAXLONG64);
InterlockedIncrement64(&ctx->ActiveTransactionCount); InterlockedIncrement64(&ctx->ActiveTransactionCount);
ASSERT(InterlockedGet(&ctx->PacketQueue.NumNbl) < MAXLONG);
InterlockedIncrement(&ctx->PacketQueue.NumNbl); InterlockedIncrement(&ctx->PacketQueue.NumNbl);
InterlockedExchange64(NET_BUFFER_LIST_REFCOUNT(nbl), 1); InterlockedExchange64(NET_BUFFER_LIST_REFCOUNT(nbl), 1);
} }
@ -388,6 +390,7 @@ _IRQL_requires_same_
static void TunNBLRefInc(_Inout_ NET_BUFFER_LIST *nbl) static void TunNBLRefInc(_Inout_ NET_BUFFER_LIST *nbl)
{ {
ASSERT(InterlockedGet64(NET_BUFFER_LIST_REFCOUNT(nbl))); ASSERT(InterlockedGet64(NET_BUFFER_LIST_REFCOUNT(nbl)));
ASSERT(InterlockedGet64(NET_BUFFER_LIST_REFCOUNT(nbl)) < MAXLONG64);
InterlockedIncrement64(NET_BUFFER_LIST_REFCOUNT(nbl)); InterlockedIncrement64(NET_BUFFER_LIST_REFCOUNT(nbl));
} }
@ -395,10 +398,11 @@ _When_( (SendCompleteFlags & NDIS_SEND_COMPLETE_FLAGS_DISPATCH_LEVEL), _IRQL_req
_When_(!(SendCompleteFlags & NDIS_SEND_COMPLETE_FLAGS_DISPATCH_LEVEL), _IRQL_requires_max_(DISPATCH_LEVEL)) _When_(!(SendCompleteFlags & NDIS_SEND_COMPLETE_FLAGS_DISPATCH_LEVEL), _IRQL_requires_max_(DISPATCH_LEVEL))
static BOOLEAN TunNBLRefDec(_Inout_ TUN_CTX *ctx, _Inout_ NET_BUFFER_LIST *nbl, _In_ ULONG SendCompleteFlags) static BOOLEAN TunNBLRefDec(_Inout_ TUN_CTX *ctx, _Inout_ NET_BUFFER_LIST *nbl, _In_ ULONG SendCompleteFlags)
{ {
ASSERT(InterlockedGet64(NET_BUFFER_LIST_REFCOUNT(nbl))); ASSERT(InterlockedGet64(NET_BUFFER_LIST_REFCOUNT(nbl)) > 0);
if (!InterlockedDecrement64(NET_BUFFER_LIST_REFCOUNT(nbl))) { if (InterlockedDecrement64(NET_BUFFER_LIST_REFCOUNT(nbl)) <= 0) {
NET_BUFFER_LIST_NEXT_NBL(nbl) = NULL; NET_BUFFER_LIST_NEXT_NBL(nbl) = NULL;
NdisMSendNetBufferListsComplete(ctx->MiniportAdapterHandle, nbl, SendCompleteFlags); NdisMSendNetBufferListsComplete(ctx->MiniportAdapterHandle, nbl, SendCompleteFlags);
ASSERT(InterlockedGet(&ctx->PacketQueue.NumNbl) > 0);
InterlockedDecrement(&ctx->PacketQueue.NumNbl); InterlockedDecrement(&ctx->PacketQueue.NumNbl);
TunCompletePause(ctx, TRUE); TunCompletePause(ctx, TRUE);
return TRUE; return TRUE;
@ -607,10 +611,10 @@ static NTSTATUS TunWriteFromIrp(_Inout_ TUN_CTX *ctx, _Inout_ IRP *Irp)
goto cleanup_TunCompletePause; goto cleanup_TunCompletePause;
const UCHAR *b = buffer, *b_end = buffer + size; const UCHAR *b = buffer, *b_end = buffer + size;
ULONG nbl_count = 0; LONG nbl_count = 0;
NET_BUFFER_LIST *nbl_head = NULL, *nbl_tail = NULL; NET_BUFFER_LIST *nbl_head = NULL, *nbl_tail = NULL;
while (b + sizeof(TUN_PACKET) <= b_end) { while (b + sizeof(TUN_PACKET) <= b_end) {
if (nbl_count >= MAXULONG) { if (nbl_count >= MAXLONG) {
status = STATUS_INVALID_USER_BUFFER; status = STATUS_INVALID_USER_BUFFER;
goto cleanup_nbl_head; goto cleanup_nbl_head;
@ -874,8 +878,8 @@ static void TunReturnNetBufferLists(NDIS_HANDLE MiniportAdapterContext, PNET_BUF
NdisFreeMdl(mdl); NdisFreeMdl(mdl);
NdisFreeNetBufferList(nbl); NdisFreeNetBufferList(nbl);
ASSERT(InterlockedGet(IRP_REFCOUNT(irp))); ASSERT(InterlockedGet(IRP_REFCOUNT(irp)) > 0);
if (!InterlockedDecrement(IRP_REFCOUNT(irp))) { if (InterlockedDecrement(IRP_REFCOUNT(irp)) <= 0) {
TunCompleteRequest(ctx, irp, STATUS_SUCCESS, IO_NETWORK_INCREMENT); TunCompleteRequest(ctx, irp, STATUS_SUCCESS, IO_NETWORK_INCREMENT);
TunCompletePause(ctx, TRUE); TunCompletePause(ctx, TRUE);
} }
@ -1250,6 +1254,7 @@ static NDIS_STATUS TunInitializeEx(NDIS_HANDLE MiniportAdapterHandle, NDIS_HANDL
* of the MiniportInitializeEx function. * of the MiniportInitializeEx function.
*/ */
TunIndicateStatus(MiniportAdapterHandle, MediaConnectStateDisconnected); TunIndicateStatus(MiniportAdapterHandle, MediaConnectStateDisconnected);
ASSERT(InterlockedGet64(&AdapterCount) < MAXLONG64);
InterlockedIncrement64(&AdapterCount); InterlockedIncrement64(&AdapterCount);
InterlockedExchange((LONG *)&ctx->State, TUN_STATE_PAUSED); InterlockedExchange((LONG *)&ctx->State, TUN_STATE_PAUSED);
return NDIS_STATUS_SUCCESS; return NDIS_STATUS_SUCCESS;
@ -1328,7 +1333,8 @@ static void TunHaltEx(NDIS_HANDLE MiniportAdapterContext, NDIS_HALT_ACTION HaltA
InterlockedExchange((LONG *)&ctx->PowerState, NdisDeviceStateUnspecified); InterlockedExchange((LONG *)&ctx->PowerState, NdisDeviceStateUnspecified);
InterlockedExchange((LONG *)&ctx->State, TUN_STATE_HALTED); InterlockedExchange((LONG *)&ctx->State, TUN_STATE_HALTED);
if (!InterlockedDecrement64(&AdapterCount)) ASSERT(InterlockedGet64(&AdapterCount) > 0);
if (InterlockedDecrement64(&AdapterCount) <= 0)
TunWaitForReferencesToDropToZero(ctx); TunWaitForReferencesToDropToZero(ctx);
/* Deregister device _after_ we are done writing to ctx not to risk an UaF. The ctx is hosted by device extension. */ /* Deregister device _after_ we are done writing to ctx not to risk an UaF. The ctx is hosted by device extension. */