Cleanup atomic getters

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2019-07-05 13:38:35 +00:00
parent 2e7809f0d1
commit 6fe055d0e8

View File

@ -105,7 +105,7 @@ typedef struct _TUN_CTX
KSPIN_LOCK Lock; KSPIN_LOCK Lock;
NET_BUFFER_LIST *FirstNbl, *LastNbl; NET_BUFFER_LIST *FirstNbl, *LastNbl;
NET_BUFFER *NextNb; NET_BUFFER *NextNb;
LONG NumNbl; volatile LONG NumNbl;
} PacketQueue; } PacketQueue;
NDIS_HANDLE NBLPool; NDIS_HANDLE NBLPool;
@ -131,10 +131,29 @@ static NDIS_HANDLE NdisMiniportDriverHandle;
static DRIVER_DISPATCH *NdisDispatchPnP; static DRIVER_DISPATCH *NdisDispatchPnP;
static volatile LONG64 TunAdapterCount; static volatile LONG64 TunAdapterCount;
#define InterlockedGet(val) (InterlockedAdd((val), 0)) static __forceinline LONG
#define InterlockedGet64(val) (InterlockedAdd64((val), 0)) InterlockedGet(_In_ _Interlocked_operand_ LONG volatile *Value)
#define InterlockedGetPointer(val) (InterlockedCompareExchangePointer((val), NULL, NULL)) {
#define TunPacketAlign(size) (((UINT)(size) + (UINT)(TUN_EXCH_ALIGNMENT - 1)) & ~(UINT)(TUN_EXCH_ALIGNMENT - 1)) return *Value;
}
static __forceinline PVOID
InterlockedGetPointer(_In_ _Interlocked_operand_ PVOID volatile *Value)
{
return *Value;
}
static __forceinline LONG64
InterlockedGet64(_In_ _Interlocked_operand_ LONG64 volatile *Value)
{
#ifdef _WIN64
return *Value;
#else
return InterlockedCompareExchange64(Value, 0, 0);
#endif
}
#define TunPacketAlign(size) (((ULONG)(size) + (ULONG)(TUN_EXCH_ALIGNMENT - 1)) & ~(ULONG)(TUN_EXCH_ALIGNMENT - 1))
#define TunInitUnicodeString(str, buf) \ #define TunInitUnicodeString(str, buf) \
{ \ { \
(str)->Length = 0; \ (str)->Length = 0; \
@ -460,7 +479,7 @@ TunAppendNBL(_Inout_ NET_BUFFER_LIST **Head, _Inout_ NET_BUFFER_LIST **Tail, __d
_Requires_lock_not_held_(Ctx->PacketQueue.Lock) _Requires_lock_not_held_(Ctx->PacketQueue.Lock)
_IRQL_requires_max_(DISPATCH_LEVEL) _IRQL_requires_max_(DISPATCH_LEVEL)
static void static void
TunQueueAppend(_Inout_ TUN_CTX *Ctx, _In_ NET_BUFFER_LIST *Nbl, _In_ UINT MaxNbls) TunQueueAppend(_Inout_ TUN_CTX *Ctx, _In_ NET_BUFFER_LIST *Nbl, _In_ ULONG MaxNbls)
{ {
for (NET_BUFFER_LIST *NextNbl; Nbl; Nbl = NextNbl) for (NET_BUFFER_LIST *NextNbl; Nbl; Nbl = NextNbl)
{ {
@ -477,7 +496,7 @@ TunQueueAppend(_Inout_ TUN_CTX *Ctx, _In_ NET_BUFFER_LIST *Nbl, _In_ UINT MaxNbl
TunNBLRefInit(Ctx, Nbl); TunNBLRefInit(Ctx, Nbl);
TunAppendNBL(&Ctx->PacketQueue.FirstNbl, &Ctx->PacketQueue.LastNbl, Nbl); TunAppendNBL(&Ctx->PacketQueue.FirstNbl, &Ctx->PacketQueue.LastNbl, Nbl);
while ((UINT)InterlockedGet(&Ctx->PacketQueue.NumNbl) > MaxNbls && Ctx->PacketQueue.FirstNbl) while ((ULONG)InterlockedGet(&Ctx->PacketQueue.NumNbl) > MaxNbls && Ctx->PacketQueue.FirstNbl)
{ {
NET_BUFFER_LIST *SecondNbl = NET_BUFFER_LIST_NEXT_NBL(Ctx->PacketQueue.FirstNbl); NET_BUFFER_LIST *SecondNbl = NET_BUFFER_LIST_NEXT_NBL(Ctx->PacketQueue.FirstNbl);
@ -773,12 +792,12 @@ TunReturnNetBufferLists(NDIS_HANDLE MiniportAdapterContext, PNET_BUFFER_LIST Net
TunCompletePause(Ctx, TRUE); TunCompletePause(Ctx, TRUE);
volatile LONG *MdlRefcount = NET_BUFFER_LIST_MDL_REFCOUNT(Nbl); LONG volatile *MdlRefCount = NET_BUFFER_LIST_MDL_REFCOUNT(Nbl);
ASSERT(InterlockedGet(MdlRefcount) > 0); ASSERT(InterlockedGet(MdlRefCount) > 0);
if (InterlockedDecrement(MdlRefcount) <= 0) if (InterlockedDecrement(MdlRefCount) <= 0)
{ {
/* MdlRefcount is also the first pointer in the allocation. */ /* MdlRefCount is also the first pointer in the allocation. */
ExFreePoolWithTag((PVOID)MdlRefcount, TUN_MEMORY_TAG); ExFreePoolWithTag((PVOID)MdlRefCount, TUN_MEMORY_TAG);
NdisFreeMdl(NET_BUFFER_LIST_FIRST_NB(Nbl)->MdlChain); NdisFreeMdl(NET_BUFFER_LIST_FIRST_NB(Nbl)->MdlChain);
} }
NdisFreeNetBufferList(Nbl); NdisFreeNetBufferList(Nbl);
@ -804,7 +823,7 @@ TunDispatchWrite(_Inout_ TUN_CTX *Ctx, _Inout_ IRP *Irp)
if (Status = STATUS_INSUFFICIENT_RESOURCES, !BufferStart) if (Status = STATUS_INSUFFICIENT_RESOURCES, !BufferStart)
goto cleanup_CompleteRequest; goto cleanup_CompleteRequest;
/* We don't write to this until we're totally finished using Packet->Size. */ /* We don't write to this until we're totally finished using Packet->Size. */
LONG *MdlRefcount = (LONG *)BufferStart; LONG *MdlRefCount = (LONG *)BufferStart;
try try
{ {
Status = STATUS_INVALID_USER_BUFFER; Status = STATUS_INVALID_USER_BUFFER;
@ -870,7 +889,7 @@ TunDispatchWrite(_Inout_ TUN_CTX *Ctx, _Inout_ IRP *Irp)
NdisSetNblFlag(Nbl, EtherTypeConstants[Index].NblFlags); NdisSetNblFlag(Nbl, EtherTypeConstants[Index].NblFlags);
NET_BUFFER_LIST_INFO(Nbl, NetBufferListFrameType) = (PVOID)EtherTypeConstants[Index].NblProto; NET_BUFFER_LIST_INFO(Nbl, NetBufferListFrameType) = (PVOID)EtherTypeConstants[Index].NblProto;
NET_BUFFER_LIST_STATUS(Nbl) = NDIS_STATUS_SUCCESS; NET_BUFFER_LIST_STATUS(Nbl) = NDIS_STATUS_SUCCESS;
NET_BUFFER_LIST_MDL_REFCOUNT(Nbl) = MdlRefcount; NET_BUFFER_LIST_MDL_REFCOUNT(Nbl) = MdlRefCount;
TunAppendNBL(&NblQueue[Index].Head, &NblQueue[Index].Tail, Nbl); TunAppendNBL(&NblQueue[Index].Head, &NblQueue[Index].Tail, Nbl);
NblQueue[Index].Count++; NblQueue[Index].Count++;
NblCount++; NblCount++;
@ -895,7 +914,7 @@ TunDispatchWrite(_Inout_ TUN_CTX *Ctx, _Inout_ IRP *Irp)
} }
InterlockedAdd64(&Ctx->ActiveNBLCount, NblCount); InterlockedAdd64(&Ctx->ActiveNBLCount, NblCount);
*MdlRefcount = NblCount; *MdlRefCount = NblCount;
for (EtherTypeIndex Index = EtherTypeIndexStart; Index < EtherTypeIndexEnd; Index++) for (EtherTypeIndex Index = EtherTypeIndexStart; Index < EtherTypeIndexEnd; Index++)
{ {
if (!NblQueue[Index].Head) if (!NblQueue[Index].Head)
@ -1544,7 +1563,7 @@ TunOidQueryWrite32or64(_Inout_ NDIS_OID_REQUEST *OidRequest, _In_ ULONG64 Value)
_IRQL_requires_max_(APC_LEVEL) _IRQL_requires_max_(APC_LEVEL)
_Must_inspect_result_ _Must_inspect_result_
static NDIS_STATUS static NDIS_STATUS
TunOidQueryWriteBuf(_Inout_ NDIS_OID_REQUEST *OidRequest, _In_bytecount_(Size) const void *Buf, _In_ UINT Size) TunOidQueryWriteBuf(_Inout_ NDIS_OID_REQUEST *OidRequest, _In_bytecount_(Size) const void *Buf, _In_ ULONG Size)
{ {
if (OidRequest->DATA.QUERY_INFORMATION.InformationBufferLength < Size) if (OidRequest->DATA.QUERY_INFORMATION.InformationBufferLength < Size)
{ {
@ -1584,7 +1603,7 @@ TunOidQuery(_Inout_ TUN_CTX *ctx, _Inout_ NDIS_OID_REQUEST *OidRequest)
return TunOidQueryWrite(OidRequest, TUN_HTONL(TUN_VENDOR_ID)); return TunOidQueryWrite(OidRequest, TUN_HTONL(TUN_VENDOR_ID));
case OID_GEN_VENDOR_DESCRIPTION: case OID_GEN_VENDOR_DESCRIPTION:
return TunOidQueryWriteBuf(OidRequest, TUN_VENDOR_NAME, (UINT)sizeof(TUN_VENDOR_NAME)); return TunOidQueryWriteBuf(OidRequest, TUN_VENDOR_NAME, (ULONG)sizeof(TUN_VENDOR_NAME));
case OID_GEN_VENDOR_DRIVER_VERSION: case OID_GEN_VENDOR_DRIVER_VERSION:
return TunOidQueryWrite(OidRequest, (WINTUN_VERSION_MAJ << 16) | WINTUN_VERSION_MIN); return TunOidQueryWrite(OidRequest, (WINTUN_VERSION_MAJ << 16) | WINTUN_VERSION_MIN);
@ -1604,16 +1623,16 @@ TunOidQuery(_Inout_ TUN_CTX *ctx, _Inout_ NDIS_OID_REQUEST *OidRequest)
InterlockedGet64((LONG64 *)&ctx->Statistics.ifHCInBroadcastPkts)); InterlockedGet64((LONG64 *)&ctx->Statistics.ifHCInBroadcastPkts));
case OID_GEN_STATISTICS: case OID_GEN_STATISTICS:
return TunOidQueryWriteBuf(OidRequest, &ctx->Statistics, (UINT)sizeof(ctx->Statistics)); return TunOidQueryWriteBuf(OidRequest, &ctx->Statistics, (ULONG)sizeof(ctx->Statistics));
case OID_GEN_INTERRUPT_MODERATION: { case OID_GEN_INTERRUPT_MODERATION: {
static const NDIS_INTERRUPT_MODERATION_PARAMETERS intp = { static const NDIS_INTERRUPT_MODERATION_PARAMETERS InterruptParameters = {
.Header = { .Type = NDIS_OBJECT_TYPE_DEFAULT, .Header = { .Type = NDIS_OBJECT_TYPE_DEFAULT,
.Revision = NDIS_INTERRUPT_MODERATION_PARAMETERS_REVISION_1, .Revision = NDIS_INTERRUPT_MODERATION_PARAMETERS_REVISION_1,
.Size = NDIS_SIZEOF_INTERRUPT_MODERATION_PARAMETERS_REVISION_1 }, .Size = NDIS_SIZEOF_INTERRUPT_MODERATION_PARAMETERS_REVISION_1 },
.InterruptModeration = NdisInterruptModerationNotSupported .InterruptModeration = NdisInterruptModerationNotSupported
}; };
return TunOidQueryWriteBuf(OidRequest, &intp, (UINT)sizeof(intp)); return TunOidQueryWriteBuf(OidRequest, &InterruptParameters, (ULONG)sizeof(InterruptParameters));
} }
case OID_PNP_QUERY_POWER: case OID_PNP_QUERY_POWER: