Terminate device connection on NDIS pause

If device connection is not terminated on NDIS pause, the driver doesn't
unload when adapter is disabled. This also prevents driver updates
without reboot.

Signed-off-by: Simon Rozman <simon@rozman.si>
This commit is contained in:
Simon Rozman 2019-03-22 13:47:17 +01:00 committed by Jason A. Donenfeld
parent 9fbc8091ab
commit 9704a4d11f

View File

@ -104,6 +104,7 @@ static NDIS_HANDLE NdisMiniportDriverHandle = NULL;
#define InterlockedGet(val) (InterlockedAdd((val), 0)) #define InterlockedGet(val) (InterlockedAdd((val), 0))
#define InterlockedGet64(val) (InterlockedAdd64((val), 0)) #define InterlockedGet64(val) (InterlockedAdd64((val), 0))
#define InterlockedGetPointer(val) (InterlockedCompareExchangePointer((val), NULL, NULL))
#define InterlockedSubtract64(val, n) (InterlockedAdd64((val), -(LONG64)(n))) #define InterlockedSubtract64(val, n) (InterlockedAdd64((val), -(LONG64)(n)))
#define TunPacketAlign(size) (((UINT)(size) + (UINT)(TUN_EXCH_ALIGNMENT - 1)) & ~(UINT)(TUN_EXCH_ALIGNMENT - 1)) #define TunPacketAlign(size) (((UINT)(size) + (UINT)(TUN_EXCH_ALIGNMENT - 1)) & ~(UINT)(TUN_EXCH_ALIGNMENT - 1))
#define TunInitUnicodeString(str, buf) { (str)->Length = 0; (str)->MaximumLength = sizeof(buf); (str)->Buffer = buf; } #define TunInitUnicodeString(str, buf) { (str)->Length = 0; (str)->MaximumLength = sizeof(buf); (str)->Buffer = buf; }
@ -352,11 +353,15 @@ static NTSTATUS TunDispatchCreate(DEVICE_OBJECT *DeviceObject, IRP *Irp)
goto cleanup_complete_req; goto cleanup_complete_req;
} }
if ((status = TunCheckForPause(ctx, 1i64)) != STATUS_SUCCESS)
goto cleanup_TunCompletePause;
ASSERT(InterlockedGet64(&ctx->Device.RefCount) < MAXLONG64); ASSERT(InterlockedGet64(&ctx->Device.RefCount) < MAXLONG64);
InterlockedIncrement64(&ctx->Device.RefCount); InterlockedIncrement64(&ctx->Device.RefCount);
TunIndicateStatus(ctx); TunIndicateStatus(ctx);
status = STATUS_SUCCESS;
cleanup_TunCompletePause:
TunCompletePause(ctx, 1i64);
cleanup_complete_req: cleanup_complete_req:
TunCompleteRequest(Irp, 0, status); TunCompleteRequest(Irp, 0, status);
return status; return status;
@ -374,11 +379,15 @@ static NTSTATUS TunDispatchClose(DEVICE_OBJECT *DeviceObject, IRP *Irp)
goto cleanup_complete_req; goto cleanup_complete_req;
} }
if ((status = TunCheckForPause(ctx, 1i64)) != STATUS_SUCCESS)
goto cleanup_TunCompletePause;
ASSERT(InterlockedGet64(&ctx->Device.RefCount) > 0); ASSERT(InterlockedGet64(&ctx->Device.RefCount) > 0);
InterlockedDecrement64(&ctx->Device.RefCount); InterlockedDecrement64(&ctx->Device.RefCount);
TunIndicateStatus(ctx); TunIndicateStatus(ctx);
status = STATUS_SUCCESS;
cleanup_TunCompletePause:
TunCompletePause(ctx, 1i64);
cleanup_complete_req: cleanup_complete_req:
TunCompleteRequest(Irp, 0, status); TunCompleteRequest(Irp, 0, status);
return status; return status;
@ -396,10 +405,16 @@ static NTSTATUS TunDispatchRead(DEVICE_OBJECT *DeviceObject, IRP *Irp)
goto cleanup_complete_req; goto cleanup_complete_req;
} }
if ((status = TunCheckForPause(ctx, 1i64)) != STATUS_SUCCESS)
goto cleanup_TunCompletePause;
IoMarkIrpPending(Irp); IoMarkIrpPending(Irp);
IoStartPacket(DeviceObject, Irp, NULL, NULL); IoStartPacket(DeviceObject, Irp, NULL, NULL);
TunCompletePause(ctx, 1i64);
return STATUS_PENDING; return STATUS_PENDING;
cleanup_TunCompletePause:
TunCompletePause(ctx, 1i64);
cleanup_complete_req: cleanup_complete_req:
TunCompleteRequest(Irp, 0, status); TunCompleteRequest(Irp, 0, status);
return status; return status;
@ -418,11 +433,14 @@ static NTSTATUS TunDispatchWrite(DEVICE_OBJECT *DeviceObject, IRP *Irp)
goto cleanup_complete_req; goto cleanup_complete_req;
} }
if ((status = TunCheckForPause(ctx, 1i64)) != STATUS_SUCCESS)
goto cleanup_TunCompletePause;
UCHAR *buffer = NULL; UCHAR *buffer = NULL;
ULONG size = 0; ULONG size = 0;
status = TunGetIRPBuffer(Irp, &buffer, &size); status = TunGetIRPBuffer(Irp, &buffer, &size);
if (status != STATUS_SUCCESS) if (status != STATUS_SUCCESS)
goto cleanup_complete_req; goto cleanup_TunCompletePause;
const UCHAR *b = buffer, *b_end = buffer + size; const UCHAR *b = buffer, *b_end = buffer + size;
ULONG nbl_count = 0; ULONG nbl_count = 0;
@ -472,18 +490,6 @@ static NTSTATUS TunDispatchWrite(DEVICE_OBJECT *DeviceObject, IRP *Irp)
b += p_size; b += p_size;
} }
BOOLEAN update_statistics = TRUE;
if ((status = TunCheckForPause(ctx, nbl_count)) != STATUS_SUCCESS) {
update_statistics = FALSE;
goto cleanup_nbl_head;
}
information = b - buffer;
status = STATUS_SUCCESS;
if (!nbl_head)
goto cleanup_statistics;
/* Commentary from Jason: /* Commentary from Jason:
* *
* Problem statement: * Problem statement:
@ -513,34 +519,33 @@ static NTSTATUS TunDispatchWrite(DEVICE_OBJECT *DeviceObject, IRP *Irp)
* to make it work correctly. (2) seems like an acceptable stopgap solution until we're smart enough to reason about * to make it work correctly. (2) seems like an acceptable stopgap solution until we're smart enough to reason about
* (1). So, let's implement (2) now, and we'll let more knowledgeable people advise us on (1) later. * (1). So, let's implement (2) now, and we'll let more knowledgeable people advise us on (1) later.
*/ */
NdisMIndicateReceiveNetBufferLists(ctx->MiniportAdapterHandle, nbl_head, NDIS_DEFAULT_PORT_NUMBER, nbl_count, NDIS_RECEIVE_FLAGS_RESOURCES); if (nbl_head)
NdisMIndicateReceiveNetBufferLists(ctx->MiniportAdapterHandle, nbl_head, NDIS_DEFAULT_PORT_NUMBER, nbl_count, NDIS_RECEIVE_FLAGS_RESOURCES);
cleanup_nbl_head:
for (NET_BUFFER_LIST *nbl = nbl_head, *nbl_next; nbl; nbl = nbl_next) { for (NET_BUFFER_LIST *nbl = nbl_head, *nbl_next; nbl; nbl = nbl_next) {
nbl_next = NET_BUFFER_LIST_NEXT_NBL(nbl); nbl_next = NET_BUFFER_LIST_NEXT_NBL(nbl);
NET_BUFFER_LIST_NEXT_NBL(nbl) = NULL; NET_BUFFER_LIST_NEXT_NBL(nbl) = NULL;
MDL *mdl = NET_BUFFER_FIRST_MDL(NET_BUFFER_LIST_FIRST_NB(nbl)); MDL *mdl = NET_BUFFER_FIRST_MDL(NET_BUFFER_LIST_FIRST_NB(nbl));
if (update_statistics) { if (NET_BUFFER_LIST_STATUS(nbl) == NDIS_STATUS_SUCCESS) {
if (NET_BUFFER_LIST_STATUS(nbl) == NDIS_STATUS_SUCCESS) { ULONG p_size = MmGetMdlByteCount(mdl);
ULONG p_size = MmGetMdlByteCount(mdl); stat_size += p_size;
stat_size += p_size; stat_p_ok++;
stat_p_ok++; } else
} else stat_p_err++;
stat_p_err++;
}
NdisFreeMdl(mdl); NdisFreeMdl(mdl);
NdisFreeNetBufferList(nbl); NdisFreeNetBufferList(nbl);
} }
cleanup_statistics:
TunCompletePause(ctx, nbl_count);
InterlockedAdd64((LONG64 *)&ctx->Statistics.ifHCInOctets, stat_size); InterlockedAdd64((LONG64 *)&ctx->Statistics.ifHCInOctets, stat_size);
InterlockedAdd64((LONG64 *)&ctx->Statistics.ifHCInUcastOctets, stat_size); InterlockedAdd64((LONG64 *)&ctx->Statistics.ifHCInUcastOctets, stat_size);
InterlockedAdd64((LONG64 *)&ctx->Statistics.ifHCInUcastPkts, stat_p_ok); InterlockedAdd64((LONG64 *)&ctx->Statistics.ifHCInUcastPkts, stat_p_ok);
InterlockedAdd64((LONG64 *)&ctx->Statistics.ifInErrors, stat_p_err); InterlockedAdd64((LONG64 *)&ctx->Statistics.ifInErrors, stat_p_err);
information = b - buffer;
cleanup_TunCompletePause:
TunCompletePause(ctx, 1i64);
cleanup_complete_req: cleanup_complete_req:
TunCompleteRequest(Irp, information, status); TunCompleteRequest(Irp, information, status);
return status; return status;
@ -562,6 +567,11 @@ static NDIS_STATUS TunPause(NDIS_HANDLE MiniportAdapterContext, PNDIS_MINIPORT_P
if (InterlockedCompareExchange((LONG *)&ctx->State, TUN_STATE_PAUSING, TUN_STATE_RUNNING) != TUN_STATE_RUNNING) if (InterlockedCompareExchange((LONG *)&ctx->State, TUN_STATE_PAUSING, TUN_STATE_RUNNING) != TUN_STATE_RUNNING)
return NDIS_STATUS_FAILURE; return NDIS_STATUS_FAILURE;
/* Reset adapter context in device object, as Windows keep calling dispatch handlers even after NdisDeregisterDeviceEx(). */
TUN_CTX **control_device_extension = (TUN_CTX **)NdisGetDeviceReservedExtension(ctx->Device.Object);
if (control_device_extension)
InterlockedExchangePointer(control_device_extension, NULL);
ULONG nbl_count = 0; ULONG nbl_count = 0;
NdisAcquireSpinLock(&ctx->PacketQueue.Lock); NdisAcquireSpinLock(&ctx->PacketQueue.Lock);
if (ctx->PacketQueue.Head) { if (ctx->PacketQueue.Head) {
@ -577,6 +587,12 @@ static NDIS_STATUS TunPause(NDIS_HANDLE MiniportAdapterContext, PNDIS_MINIPORT_P
} else } else
NdisReleaseSpinLock(&ctx->PacketQueue.Lock); NdisReleaseSpinLock(&ctx->PacketQueue.Lock);
/* Cancel pending IRP to unblock waiting clients. */
IRP *Irp = InterlockedExchangePointer((PVOID volatile *)&ctx->Device.ActiveIrp, NULL);
if (Irp)
TunCompleteRequest(Irp, 0, STATUS_CANCELLED);
InterlockedExchange64(&ctx->Device.RefCount, 0);
TunIndicateStatus(ctx); TunIndicateStatus(ctx);
if (InterlockedSubtract64(&ctx->ActiveTransactionCount, nbl_count)) if (InterlockedSubtract64(&ctx->ActiveTransactionCount, nbl_count))
@ -594,6 +610,11 @@ static NDIS_STATUS TunRestart(NDIS_HANDLE MiniportAdapterContext, PNDIS_MINIPORT
if (InterlockedCompareExchange((LONG *)&ctx->State, TUN_STATE_RESTARTING, TUN_STATE_PAUSED) != TUN_STATE_PAUSED) if (InterlockedCompareExchange((LONG *)&ctx->State, TUN_STATE_RESTARTING, TUN_STATE_PAUSED) != TUN_STATE_PAUSED)
return NDIS_STATUS_FAILURE; return NDIS_STATUS_FAILURE;
TUN_CTX **control_device_extension = (TUN_CTX **)NdisGetDeviceReservedExtension(ctx->Device.Object);
if (control_device_extension)
InterlockedExchangePointer(control_device_extension, ctx);
ASSERT(!InterlockedGet64(&ctx->Device.RefCount));
TunIndicateStatus(ctx); TunIndicateStatus(ctx);
InterlockedExchange((LONG *)&ctx->State, TUN_STATE_RUNNING); InterlockedExchange((LONG *)&ctx->State, TUN_STATE_RUNNING);
@ -869,18 +890,9 @@ static NDIS_STATUS TunInitializeEx(NDIS_HANDLE MiniportAdapterHandle, NDIS_HANDL
IoSetStartIoAttributes(ctx->Device.Object, TRUE, TRUE); IoSetStartIoAttributes(ctx->Device.Object, TRUE, TRUE);
TUN_CTX **control_device_extension = (TUN_CTX **)NdisGetDeviceReservedExtension(ctx->Device.Object);
if (!control_device_extension) {
status = NDIS_STATUS_FAILURE;
goto cleanup_NdisDeregisterDeviceEx;
}
InterlockedExchangePointer(control_device_extension, ctx);
ctx->State = TUN_STATE_PAUSED; ctx->State = TUN_STATE_PAUSED;
return NDIS_STATUS_SUCCESS; return NDIS_STATUS_SUCCESS;
cleanup_NdisDeregisterDeviceEx:
NdisDeregisterDeviceEx(ctx->Device.Handle);
cleanup_NdisFreeNetBufferListPool: cleanup_NdisFreeNetBufferListPool:
NdisFreeNetBufferListPool(ctx->NBLPool); NdisFreeNetBufferListPool(ctx->NBLPool);
cleanup_NdisFreeSpinLock: cleanup_NdisFreeSpinLock:
@ -906,16 +918,8 @@ static void TunHaltEx(NDIS_HANDLE MiniportAdapterContext, NDIS_HALT_ACTION HaltA
return; return;
ASSERT(!InterlockedGet64(&ctx->ActiveTransactionCount)); ASSERT(!InterlockedGet64(&ctx->ActiveTransactionCount));
ASSERT(!InterlockedGetPointer((PVOID volatile *)&ctx->Device.ActiveIrp));
/* Reset adapter context in device object, as Windows keep calling dispatch handlers even after NdisDeregisterDeviceEx(). */ ASSERT(!InterlockedGet64(&ctx->Device.RefCount));
TUN_CTX **control_device_extension = (TUN_CTX **)NdisGetDeviceReservedExtension(ctx->Device.Object);
if (control_device_extension)
InterlockedExchangePointer(control_device_extension, NULL);
/* Cancel pending IRP to unblock waiting clients. */
IRP *Irp = InterlockedExchangePointer((PVOID volatile *)&ctx->Device.ActiveIrp, NULL);
if (Irp)
TunCompleteRequest(Irp, 0, STATUS_CANCELLED);
/* Release resources. */ /* Release resources. */
NdisDeregisterDeviceEx(ctx->Device.Handle); NdisDeregisterDeviceEx(ctx->Device.Handle);