diff --git a/wintun.c b/wintun.c index ab61d97..45a7819 100644 --- a/wintun.c +++ b/wintun.c @@ -276,7 +276,11 @@ static NTSTATUS TunGetIrpBuffer(_In_ IRP *Irp, _Out_ UCHAR **buffer, _Out_ ULONG case IRP_MJ_WRITE: *size = stack->Parameters.Write.Length; - priority = NormalPagePriority | MdlMappingNoWrite; + + /* If we use the MdlMappingNoWrite flag and call NdisMIndicateReceiveNetBufferLists without the + * NDIS_RECEIVE_FLAGS_RESOURCES flag, we get an ATTEMPTED_WRITE_TO_READONLY_MEMORY page + * fault. */ + priority = NormalPagePriority /*| MdlMappingNoWrite*/; break; default: @@ -366,29 +370,29 @@ static NTSTATUS TunWriteIntoIrp(_Inout_ IRP *Irp, _Inout_ UCHAR *buffer, _In_ NE return STATUS_SUCCESS; } -#define NET_BUFFER_LIST_MINIPORT_RESERVED_REFCOUNT(nbl) ((volatile LONG64 *)NET_BUFFER_LIST_MINIPORT_RESERVED(nbl)) +#define NET_BUFFER_LIST_REFCOUNT(nbl) ((volatile LONG64 *)NET_BUFFER_LIST_MINIPORT_RESERVED(nbl)) _IRQL_requires_same_ static void TunNBLRefInit(_Inout_ TUN_CTX *ctx, _Inout_ NET_BUFFER_LIST *nbl) { InterlockedIncrement64(&ctx->ActiveTransactionCount); InterlockedIncrement(&ctx->PacketQueue.NumNbl); - InterlockedExchange64(NET_BUFFER_LIST_MINIPORT_RESERVED_REFCOUNT(nbl), 1); + InterlockedExchange64(NET_BUFFER_LIST_REFCOUNT(nbl), 1); } _IRQL_requires_same_ static void TunNBLRefInc(_Inout_ NET_BUFFER_LIST *nbl) { - ASSERT(InterlockedGet64(NET_BUFFER_LIST_MINIPORT_RESERVED_REFCOUNT(nbl))); - InterlockedIncrement64(NET_BUFFER_LIST_MINIPORT_RESERVED_REFCOUNT(nbl)); + ASSERT(InterlockedGet64(NET_BUFFER_LIST_REFCOUNT(nbl))); + InterlockedIncrement64(NET_BUFFER_LIST_REFCOUNT(nbl)); } _When_( (SendCompleteFlags & NDIS_SEND_COMPLETE_FLAGS_DISPATCH_LEVEL), _IRQL_requires_ (DISPATCH_LEVEL)) _When_(!(SendCompleteFlags & NDIS_SEND_COMPLETE_FLAGS_DISPATCH_LEVEL), _IRQL_requires_max_(DISPATCH_LEVEL)) static BOOLEAN TunNBLRefDec(_Inout_ TUN_CTX *ctx, _Inout_ NET_BUFFER_LIST *nbl, _In_ ULONG SendCompleteFlags) { -
ASSERT(InterlockedGet64(NET_BUFFER_LIST_MINIPORT_RESERVED_REFCOUNT(nbl))); - if (!InterlockedDecrement64(NET_BUFFER_LIST_MINIPORT_RESERVED_REFCOUNT(nbl))) { + ASSERT(InterlockedGet64(NET_BUFFER_LIST_REFCOUNT(nbl))); + if (!InterlockedDecrement64(NET_BUFFER_LIST_REFCOUNT(nbl))) { NET_BUFFER_LIST_NEXT_NBL(nbl) = NULL; NdisMSendNetBufferListsComplete(ctx->MiniportAdapterHandle, nbl, SendCompleteFlags); InterlockedDecrement(&ctx->PacketQueue.NumNbl); @@ -579,6 +583,9 @@ static void TunQueueProcess(_Inout_ TUN_CTX *ctx) } } +#define IRP_REFCOUNT(irp) ((volatile LONG64 *)&(irp)->Tail.Overlay.DriverContext[0]) +#define NET_BUFFER_LIST_IRP(nbl) (NET_BUFFER_LIST_MINIPORT_RESERVED(nbl)[0]) + _IRQL_requires_max_(DISPATCH_LEVEL) _Must_inspect_result_ static NTSTATUS TunWriteFromIrp(_Inout_ TUN_CTX *ctx, _Inout_ IRP *Irp) @@ -599,7 +606,7 @@ static NTSTATUS TunWriteFromIrp(_Inout_ TUN_CTX *ctx, _Inout_ IRP *Irp) const UCHAR *b = buffer, *b_end = buffer + size; ULONG nbl_count = 0; NET_BUFFER_LIST *nbl_head = NULL, *nbl_tail = NULL; - LONG64 stat_size = 0, stat_p_ok = 0, stat_p_err = 0; + LONG64 stat_p_err = 0; while (b < b_end) { TUN_PACKET *p = (TUN_PACKET *)b; if (p->Size > TUN_EXCH_MAX_IP_PACKET_SIZE) @@ -632,6 +639,7 @@ static NTSTATUS TunWriteFromIrp(_Inout_ TUN_CTX *ctx, _Inout_ IRP *Irp) NdisSetNblFlag(nbl, nbl_flags); NET_BUFFER_LIST_INFO(nbl, NetBufferListFrameType) = (PVOID)TunHtons(nbl_proto); NET_BUFFER_LIST_STATUS(nbl) = NDIS_STATUS_SUCCESS; + NET_BUFFER_LIST_IRP(nbl) = Irp; TunAppendNBL(&nbl_head, &nbl_tail, nbl); nbl_count++; goto next_packet; @@ -643,60 +651,19 @@ static NTSTATUS TunWriteFromIrp(_Inout_ TUN_CTX *ctx, _Inout_ IRP *Irp) next_packet: b += p_size; } + InterlockedAdd64((LONG64 *)&ctx->Statistics.ifInErrors, stat_p_err); - /* Commentary from Jason: - * - * Problem statement: - * We call IoCompleteRequest(Irp) immediately after NdisMIndicateReceiveNetBufferLists, which frees Irp->MdlAddress. 
- * Since we've just given the same memory to NdisMIndicateReceiveNetBufferLists (in a different MDL), we wind up - * freeing the memory before NDIS finishes processing them. - * - * Fix possibility 1: - * Move IoCompleteRequest(Irp) to TunReturnNetBufferLists. This reqiures reference counting how many NBLs are currently - * in flight that are using an IRP. When that drops to zero, we can call IoCompleteRequest(Irp). - * Problem: - * This means we have to block future wireguard-go Writes until *all* NBLs have completed processing in the networking - * stack. Is that safe to do? Will that introduce latency? Can userspace processes sabotage it by refusing to read from - * a TCP socket buffer? We don't know enough about how NdisMIndicateReceiveNetBufferLists works to assess its - * characteristics here. - * - * Fix possibility 2: - * Use NDIS_RECEIVE_FLAGS_RESOURCES, so that NdisMIndicateReceiveNetBufferLists makes a copy, and then we'll simply - * free everything immediately after. This is slow, and it could potentially lead to wireguard-go making the kernel - * allocate lots of memory in the case that NdisAllocateNetBufferAndNetBufferList doesn't ratelimit its creation in the - * same way Linux's skb_alloc does. However, it does make the lifetime of Irps shorter, which is easier to analyze, and - * it might lead to better latency, since we don't need to wait until userspace sends its next packets, so long as - * Ndis' ingestion queue doesn't become too large. - * - * Choice: - * Both (1) and (2) have pros and cons. Making (1) work is clearly the better long term goal. But we lack the knowledge - * to make it work correctly. (2) seems like an acceptable stopgap solution until we're smart enough to reason about - * (1). So, let's implement (2) now, and we'll let more knowledgeable people advise us on (1) later. 
- */ - if (nbl_head) - NdisMIndicateReceiveNetBufferLists(ctx->MiniportAdapterHandle, nbl_head, NDIS_DEFAULT_PORT_NUMBER, nbl_count, NDIS_RECEIVE_FLAGS_RESOURCES); - - for (NET_BUFFER_LIST *nbl = nbl_head, *nbl_next; nbl; nbl = nbl_next) { - nbl_next = NET_BUFFER_LIST_NEXT_NBL(nbl); - NET_BUFFER_LIST_NEXT_NBL(nbl) = NULL; - - MDL *mdl = NET_BUFFER_FIRST_MDL(NET_BUFFER_LIST_FIRST_NB(nbl)); - if (NT_SUCCESS(NET_BUFFER_LIST_STATUS(nbl))) { - ULONG p_size = MmGetMdlByteCount(mdl); - stat_size += p_size; - stat_p_ok++; - } else - stat_p_err++; - NdisFreeMdl(mdl); - NdisFreeNetBufferList(nbl); + if (!nbl_head) { + status = STATUS_SUCCESS; + goto cleanup_TunCompletePause; } - InterlockedAdd64((LONG64 *)&ctx->Statistics.ifHCInOctets, stat_size); - InterlockedAdd64((LONG64 *)&ctx->Statistics.ifHCInUcastOctets, stat_size); - InterlockedAdd64((LONG64 *)&ctx->Statistics.ifHCInUcastPkts, stat_p_ok); - InterlockedAdd64((LONG64 *)&ctx->Statistics.ifInErrors, stat_p_err); + InterlockedExchange64(IRP_REFCOUNT(Irp), nbl_count); + IoMarkIrpPending(Irp); - Irp->IoStatus.Information = b - buffer; + NdisMIndicateReceiveNetBufferLists(ctx->MiniportAdapterHandle, nbl_head, NDIS_DEFAULT_PORT_NUMBER, nbl_count, 0); + ExReleaseSpinLockShared(&ctx->TransitionLock, irql); + return STATUS_PENDING; cleanup_TunCompletePause: TunCompletePause(ctx, TRUE); @@ -779,7 +746,8 @@ static NTSTATUS TunDispatch(DEVICE_OBJECT *DeviceObject, IRP *Irp) !NT_SUCCESS(status = IoAcquireRemoveLock(&ctx->Device.RemoveLock, Irp))) goto cleanup_complete_req; - status = TunWriteFromIrp(ctx, Irp); + if ((status = TunWriteFromIrp(ctx, Irp)) == STATUS_PENDING) + return STATUS_PENDING; goto cleanup_complete_req_and_release_remove_lock; case IRP_MJ_CREATE: @@ -870,7 +838,36 @@ static MINIPORT_RETURN_NET_BUFFER_LISTS TunReturnNetBufferLists; _Use_decl_annotations_ static void TunReturnNetBufferLists(NDIS_HANDLE MiniportAdapterContext, PNET_BUFFER_LIST NetBufferLists, ULONG ReturnFlags) { - 
ASSERTMSG("TunReturnNetBufferLists() should not be called as NBLs are delivered using NDIS_RECEIVE_FLAGS_RESOURCES flag in NdisMIndicateReceiveNetBufferLists().", 0); + TUN_CTX *ctx = (TUN_CTX *)MiniportAdapterContext; + + LONG64 stat_size = 0, stat_p_ok = 0, stat_p_err = 0; + for (NET_BUFFER_LIST *nbl = NetBufferLists, *nbl_next; nbl; nbl = nbl_next) { + nbl_next = NET_BUFFER_LIST_NEXT_NBL(nbl); + NET_BUFFER_LIST_NEXT_NBL(nbl) = NULL; + + IRP *irp = NET_BUFFER_LIST_IRP(nbl); + MDL *mdl = NET_BUFFER_FIRST_MDL(NET_BUFFER_LIST_FIRST_NB(nbl)); + if (NT_SUCCESS(NET_BUFFER_LIST_STATUS(nbl))) { + ULONG p_size = MmGetMdlByteCount(mdl); + stat_size += p_size; + stat_p_ok++; + } else + stat_p_err++; + + NdisFreeMdl(mdl); + NdisFreeNetBufferList(nbl); + + if (InterlockedDecrement64(IRP_REFCOUNT(irp)) <= 0) { + IO_STACK_LOCATION *stack = IoGetCurrentIrpStackLocation(irp); + TunCompleteRequest(ctx, irp, stack->Parameters.Write.Length, STATUS_SUCCESS); + TunCompletePause(ctx, TRUE); + } + } + + InterlockedAdd64((LONG64 *)&ctx->Statistics.ifHCInOctets, stat_size); + InterlockedAdd64((LONG64 *)&ctx->Statistics.ifHCInUcastOctets, stat_size); + InterlockedAdd64((LONG64 *)&ctx->Statistics.ifHCInUcastPkts, stat_p_ok); + InterlockedAdd64((LONG64 *)&ctx->Statistics.ifInErrors, stat_p_err); } static MINIPORT_CANCEL_SEND TunCancelSend; @@ -1260,6 +1257,8 @@ static void TunHaltEx(NDIS_HANDLE MiniportAdapterContext, NDIS_HALT_ACTION HaltA { TUN_CTX *ctx = (TUN_CTX *)MiniportAdapterContext; + ASSERT(!InterlockedGet64(&ctx->ActiveTransactionCount)); /* Adapter should not be halted if it wasn't fully paused first. */ + InterlockedExchange((LONG *)&ctx->State, TUN_STATE_HALTING); if (ctx->PnPNotifications.Handle) {