Map user buffer only once

This avoids needless page table modifications and also lets us enforce
having writable pages.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2019-06-26 10:30:52 +00:00
parent d50cab5732
commit 007ea09d1b
2 changed files with 115 additions and 34 deletions

View File

@ -119,6 +119,6 @@ After loading the driver and creating a network interface the typical way using
~ ~ ~ ~
``` ```
Each packet segment should contain a layer 3 IPv4 or IPv6 packet. Up to 15728640 bytes may be read or written during each call to `ReadFile` or `WriteFile`. Each packet segment should contain a layer 3 IPv4 or IPv6 packet. Up to 15728640 bytes may be read or written during each call to `ReadFile` or `WriteFile`. All calls to `ReadFile` must be called with the same virtual address, and all calls to `WriteFile` must be called with the same virtual address. These virtual addresses must reference pages that are readable and writable for the same length as passed to the first calls of `ReadFile` and `WriteFile`.
It is advisable to use [overlapped I/O](https://docs.microsoft.com/en-us/windows/desktop/sync/synchronization-and-overlapped-input-and-output) for this. If using blocking I/O instead, it may be desirable to open separate handles for reading and writing. It is advisable to use [overlapped I/O](https://docs.microsoft.com/en-us/windows/desktop/sync/synchronization-and-overlapped-input-and-output) for this. If using blocking I/O instead, it may be desirable to open separate handles for reading and writing.

147
wintun.c
View File

@ -99,6 +99,18 @@ typedef struct _TUN_CTX {
NDIS_HANDLE NBLPool; NDIS_HANDLE NBLPool;
} TUN_CTX; } TUN_CTX;
typedef struct _TUN_MAPPED_UBUFFER {
VOID *UserAddress, *KernelAddress;
MDL *Mdl;
ULONG Size;
//TODO: ThreadID for checking
} TUN_MAPPED_UBUFFER;
typedef struct _TUN_FILE_CTX {
TUN_MAPPED_UBUFFER ReadBuffer;
TUN_MAPPED_UBUFFER WriteBuffer;
} TUN_FILE_CTX;
static UINT NdisVersion; static UINT NdisVersion;
static NDIS_HANDLE NdisMiniportDriverHandle; static NDIS_HANDLE NdisMiniportDriverHandle;
static DRIVER_DISPATCH *NdisDispatchPnP; static DRIVER_DISPATCH *NdisDispatchPnP;
@ -235,51 +247,103 @@ _IRQL_requires_max_(DISPATCH_LEVEL)
_Must_inspect_result_ _Must_inspect_result_
static NTSTATUS TunGetIrpBuffer(_In_ IRP *Irp, _Out_ UCHAR **buffer, _Out_ ULONG *size) static NTSTATUS TunGetIrpBuffer(_In_ IRP *Irp, _Out_ UCHAR **buffer, _Out_ ULONG *size)
{ {
/* Get and validate request parameters. */ TUN_MAPPED_UBUFFER *ubuffer = NULL;
ULONG priority;
IO_STACK_LOCATION *stack = IoGetCurrentIrpStackLocation(Irp); IO_STACK_LOCATION *stack = IoGetCurrentIrpStackLocation(Irp);
TUN_FILE_CTX *file_ctx = (TUN_FILE_CTX *)stack->FileObject->FsContext;
switch (stack->MajorFunction) { switch (stack->MajorFunction) {
case IRP_MJ_READ: case IRP_MJ_READ:
*size = stack->Parameters.Read.Length; *size = stack->Parameters.Read.Length;
priority = NormalPagePriority; ubuffer = &file_ctx->ReadBuffer;
break; break;
case IRP_MJ_WRITE: case IRP_MJ_WRITE:
*size = stack->Parameters.Write.Length; *size = stack->Parameters.Write.Length;
priority = NormalPagePriority | MdlMappingNoWrite; ubuffer = &file_ctx->WriteBuffer;
break; break;
default: default:
return STATUS_INVALID_PARAMETER; ASSERT(FALSE);
} }
_Analysis_assume_(ubuffer != NULL);
/* Get buffer size and address. */ if (*size > ubuffer->Size)
if (!Irp->MdlAddress)
return STATUS_INVALID_PARAMETER;
ULONG size_mdl;
*buffer = NULL;
NdisQueryMdl(Irp->MdlAddress, buffer, &size_mdl, priority);
if (!*buffer)
return STATUS_INSUFFICIENT_RESOURCES;
if (size_mdl < *size)
*size = size_mdl;
if (*size > TUN_EXCH_MAX_BUFFER_SIZE)
return STATUS_INVALID_USER_BUFFER; return STATUS_INVALID_USER_BUFFER;
ASSERT(ubuffer->KernelAddress != NULL);
*buffer = ubuffer->KernelAddress;
return STATUS_SUCCESS;
}
_IRQL_requires_max_(APC_LEVEL)
_Must_inspect_result_
static NTSTATUS TunMapUbuffer(_Inout_ TUN_MAPPED_UBUFFER *MappedBuffer, _In_ VOID *UserAddress, _In_ ULONG Size)
{
if (MappedBuffer->UserAddress) {
if (UserAddress == MappedBuffer->UserAddress) //TODO: Check ThreadID
return STATUS_SUCCESS;
return STATUS_ALREADY_INITIALIZED;
}
__try {
ProbeForWrite(UserAddress, Size, 1);
ProbeForRead(UserAddress, Size, 1);
MappedBuffer->Mdl = IoAllocateMdl(UserAddress, Size, FALSE, FALSE, NULL);
if (!MappedBuffer->Mdl)
return STATUS_INSUFFICIENT_RESOURCES;
MmProbeAndLockPages(MappedBuffer->Mdl, KernelMode, IoWriteAccess);
MappedBuffer->KernelAddress = MmGetSystemAddressForMdlSafe(MappedBuffer->Mdl, NormalPagePriority | MdlMappingNoExecute);
if (!MappedBuffer->KernelAddress) {
IoFreeMdl(MappedBuffer->Mdl);
MappedBuffer->Mdl = NULL;
return STATUS_INSUFFICIENT_RESOURCES;
}
MappedBuffer->UserAddress = UserAddress;
MappedBuffer->Size = Size;
} __except (EXCEPTION_EXECUTE_HANDLER) {
if (MappedBuffer->Mdl) {
IoFreeMdl(MappedBuffer->Mdl);
MappedBuffer->Mdl = NULL;
}
return STATUS_INVALID_USER_BUFFER;
}
return STATUS_SUCCESS;
}
_IRQL_requires_max_(DISPATCH_LEVEL)
static void TunUnmapUbuffer(_Inout_ TUN_MAPPED_UBUFFER *MappedBuffer)
{
if (MappedBuffer->Mdl) {
MmUnlockPages(MappedBuffer->Mdl);
IoFreeMdl(MappedBuffer->Mdl);
MappedBuffer->UserAddress = MappedBuffer->KernelAddress = MappedBuffer->Mdl = NULL;
}
}
_IRQL_requires_max_(APC_LEVEL)
_Must_inspect_result_
static NTSTATUS TunMapIrp(_In_ IRP *Irp)
{
ULONG size;
TUN_MAPPED_UBUFFER *ubuffer;
IO_STACK_LOCATION *stack = IoGetCurrentIrpStackLocation(Irp);
TUN_FILE_CTX *file_ctx = (TUN_FILE_CTX *)stack->FileObject->FsContext;
switch (stack->MajorFunction) { switch (stack->MajorFunction) {
case IRP_MJ_READ: case IRP_MJ_READ:
if (*size < TUN_EXCH_MIN_BUFFER_SIZE_READ) size = stack->Parameters.Read.Length;
if (size < TUN_EXCH_MIN_BUFFER_SIZE_READ)
return STATUS_INVALID_USER_BUFFER; return STATUS_INVALID_USER_BUFFER;
ubuffer = &file_ctx->ReadBuffer;
break; break;
case IRP_MJ_WRITE: case IRP_MJ_WRITE:
if (*size < TUN_EXCH_MIN_BUFFER_SIZE_WRITE) size = stack->Parameters.Write.Length;
if (size < TUN_EXCH_MIN_BUFFER_SIZE_WRITE)
return STATUS_INVALID_USER_BUFFER; return STATUS_INVALID_USER_BUFFER;
ubuffer = &file_ctx->WriteBuffer;
break; break;
default:
return STATUS_INVALID_PARAMETER;
} }
if (size > TUN_EXCH_MAX_BUFFER_SIZE)
return STATUS_SUCCESS; return STATUS_INVALID_USER_BUFFER;
return TunMapUbuffer(ubuffer, Irp->UserBuffer, size);
} }
_IRQL_requires_max_(DISPATCH_LEVEL) _IRQL_requires_max_(DISPATCH_LEVEL)
@ -609,11 +673,13 @@ static void TunCancelSend(NDIS_HANDLE MiniportAdapterContext, PVOID CancelId)
KeReleaseInStackQueuedSpinLock(&lqh); KeReleaseInStackQueuedSpinLock(&lqh);
} }
_IRQL_requires_max_(DISPATCH_LEVEL) _IRQL_requires_max_(APC_LEVEL)
_Must_inspect_result_ _Must_inspect_result_
static NTSTATUS TunDispatchRead(_Inout_ TUN_CTX *ctx, _Inout_ IRP *Irp) static NTSTATUS TunDispatchRead(_Inout_ TUN_CTX *ctx, _Inout_ IRP *Irp)
{ {
NTSTATUS status; NTSTATUS status = TunMapIrp(Irp);
if (!NT_SUCCESS(status))
goto cleanup_CompleteRequest;
KIRQL irql = ExAcquireSpinLockShared(&ctx->TransitionLock); KIRQL irql = ExAcquireSpinLockShared(&ctx->TransitionLock);
LONG flags = InterlockedGet(&ctx->Flags); LONG flags = InterlockedGet(&ctx->Flags);
@ -627,6 +693,7 @@ static NTSTATUS TunDispatchRead(_Inout_ TUN_CTX *ctx, _Inout_ IRP *Irp)
cleanup_ExReleaseSpinLockShared: cleanup_ExReleaseSpinLockShared:
ExReleaseSpinLockShared(&ctx->TransitionLock, irql); ExReleaseSpinLockShared(&ctx->TransitionLock, irql);
cleanup_CompleteRequest:
TunCompleteRequest(ctx, Irp, status, IO_NO_INCREMENT); TunCompleteRequest(ctx, Irp, status, IO_NO_INCREMENT);
return status; return status;
} }
@ -634,7 +701,7 @@ cleanup_ExReleaseSpinLockShared:
#define IRP_REFCOUNT(irp) ((volatile LONG *)&(irp)->Tail.Overlay.DriverContext[0]) #define IRP_REFCOUNT(irp) ((volatile LONG *)&(irp)->Tail.Overlay.DriverContext[0])
#define NET_BUFFER_LIST_IRP(nbl) (NET_BUFFER_LIST_MINIPORT_RESERVED(nbl)[0]) #define NET_BUFFER_LIST_IRP(nbl) (NET_BUFFER_LIST_MINIPORT_RESERVED(nbl)[0])
_IRQL_requires_max_(DISPATCH_LEVEL) _IRQL_requires_max_(APC_LEVEL)
_Must_inspect_result_ _Must_inspect_result_
static NTSTATUS TunDispatchWrite(_Inout_ TUN_CTX *ctx, _Inout_ IRP *Irp) static NTSTATUS TunDispatchWrite(_Inout_ TUN_CTX *ctx, _Inout_ IRP *Irp)
{ {
@ -642,6 +709,9 @@ static NTSTATUS TunDispatchWrite(_Inout_ TUN_CTX *ctx, _Inout_ IRP *Irp)
InterlockedIncrement64(&ctx->ActiveNBLCount); InterlockedIncrement64(&ctx->ActiveNBLCount);
if (!NT_SUCCESS(status = TunMapIrp(Irp)))
goto cleanup_CompleteRequest;
KIRQL irql = ExAcquireSpinLockShared(&ctx->TransitionLock); KIRQL irql = ExAcquireSpinLockShared(&ctx->TransitionLock);
LONG flags = InterlockedGet(&ctx->Flags); LONG flags = InterlockedGet(&ctx->Flags);
if (status = STATUS_FILE_FORCED_CLOSED, !(flags & TUN_FLAGS_PRESENT)) if (status = STATUS_FILE_FORCED_CLOSED, !(flags & TUN_FLAGS_PRESENT))
@ -651,7 +721,8 @@ static NTSTATUS TunDispatchWrite(_Inout_ TUN_CTX *ctx, _Inout_ IRP *Irp)
ULONG size; ULONG size;
if (!NT_SUCCESS(status = TunGetIrpBuffer(Irp, &buffer, &size))) if (!NT_SUCCESS(status = TunGetIrpBuffer(Irp, &buffer, &size)))
goto cleanup_ExReleaseSpinLockShared; goto cleanup_ExReleaseSpinLockShared;
IO_STACK_LOCATION *stack = IoGetCurrentIrpStackLocation(Irp);
MDL *mdl = ((TUN_FILE_CTX *)stack->FileObject->FsContext)->WriteBuffer.Mdl;
const UCHAR *b = buffer, *b_end = buffer + size; const UCHAR *b = buffer, *b_end = buffer + size;
typedef enum _ethtypeidx_t { typedef enum _ethtypeidx_t {
ethtypeidx_ipv4 = 0, ethtypeidx_start = 0, ethtypeidx_ipv4 = 0, ethtypeidx_start = 0,
@ -700,7 +771,7 @@ static NTSTATUS TunDispatchWrite(_Inout_ TUN_CTX *ctx, _Inout_ IRP *Irp)
goto cleanup_nbl_queues; goto cleanup_nbl_queues;
} }
NET_BUFFER_LIST *nbl = NdisAllocateNetBufferAndNetBufferList(ctx->NBLPool, 0, 0, Irp->MdlAddress, (ULONG)(p->Data - buffer), p->Size); NET_BUFFER_LIST *nbl = NdisAllocateNetBufferAndNetBufferList(ctx->NBLPool, 0, 0, mdl, (ULONG)(p->Data - buffer), p->Size);
if (!nbl) { if (!nbl) {
status = STATUS_INSUFFICIENT_RESOURCES; status = STATUS_INSUFFICIENT_RESOURCES;
goto cleanup_nbl_queues; goto cleanup_nbl_queues;
@ -708,7 +779,6 @@ static NTSTATUS TunDispatchWrite(_Inout_ TUN_CTX *ctx, _Inout_ IRP *Irp)
nbl->SourceHandle = ctx->MiniportAdapterHandle; nbl->SourceHandle = ctx->MiniportAdapterHandle;
NdisSetNblFlag(nbl, ether_const[idx].nbl_flags); NdisSetNblFlag(nbl, ether_const[idx].nbl_flags);
NdisSetNblFlag(nbl, NDIS_NBL_FLAGS_RECV_READ_ONLY);
NET_BUFFER_LIST_INFO(nbl, NetBufferListFrameType) = (PVOID)ether_const[idx].nbl_proto; NET_BUFFER_LIST_INFO(nbl, NetBufferListFrameType) = (PVOID)ether_const[idx].nbl_proto;
NET_BUFFER_LIST_STATUS(nbl) = NDIS_STATUS_SUCCESS; NET_BUFFER_LIST_STATUS(nbl) = NDIS_STATUS_SUCCESS;
NET_BUFFER_LIST_IRP(nbl) = Irp; NET_BUFFER_LIST_IRP(nbl) = Irp;
@ -758,6 +828,7 @@ cleanup_nbl_queues:
} }
cleanup_ExReleaseSpinLockShared: cleanup_ExReleaseSpinLockShared:
ExReleaseSpinLockShared(&ctx->TransitionLock, irql); ExReleaseSpinLockShared(&ctx->TransitionLock, irql);
cleanup_CompleteRequest:
TunCompleteRequest(ctx, Irp, status, IO_NO_INCREMENT); TunCompleteRequest(ctx, Irp, status, IO_NO_INCREMENT);
TunCompletePause(ctx, TRUE); TunCompletePause(ctx, TRUE);
return status; return status;
@ -801,6 +872,10 @@ _Must_inspect_result_
static NTSTATUS TunDispatchCreate(_Inout_ TUN_CTX *ctx, _Inout_ IRP *Irp) static NTSTATUS TunDispatchCreate(_Inout_ TUN_CTX *ctx, _Inout_ IRP *Irp)
{ {
NTSTATUS status; NTSTATUS status;
TUN_FILE_CTX *file_ctx = ExAllocatePoolWithTag(NonPagedPoolNx, sizeof(*file_ctx), TUN_HTONL(TUN_MEMORY_TAG));
if (!file_ctx)
return STATUS_INSUFFICIENT_RESOURCES;
RtlZeroMemory(file_ctx, sizeof(*file_ctx));
KIRQL irql = ExAcquireSpinLockShared(&ctx->TransitionLock); KIRQL irql = ExAcquireSpinLockShared(&ctx->TransitionLock);
LONG flags = InterlockedGet(&ctx->Flags); LONG flags = InterlockedGet(&ctx->Flags);
@ -810,6 +885,7 @@ static NTSTATUS TunDispatchCreate(_Inout_ TUN_CTX *ctx, _Inout_ IRP *Irp)
IO_STACK_LOCATION *stack = IoGetCurrentIrpStackLocation(Irp); IO_STACK_LOCATION *stack = IoGetCurrentIrpStackLocation(Irp);
if (!NT_SUCCESS(status = IoAcquireRemoveLock(&ctx->Device.RemoveLock, stack->FileObject))) if (!NT_SUCCESS(status = IoAcquireRemoveLock(&ctx->Device.RemoveLock, stack->FileObject)))
goto cleanup_ExReleaseSpinLockShared; goto cleanup_ExReleaseSpinLockShared;
stack->FileObject->FsContext = file_ctx;
if (InterlockedIncrement64(&ctx->Device.RefCount) == 1) if (InterlockedIncrement64(&ctx->Device.RefCount) == 1)
TunIndicateStatus(ctx->MiniportAdapterHandle, MediaConnectStateConnected); TunIndicateStatus(ctx->MiniportAdapterHandle, MediaConnectStateConnected);
@ -819,6 +895,8 @@ static NTSTATUS TunDispatchCreate(_Inout_ TUN_CTX *ctx, _Inout_ IRP *Irp)
cleanup_ExReleaseSpinLockShared: cleanup_ExReleaseSpinLockShared:
ExReleaseSpinLockShared(&ctx->TransitionLock, irql); ExReleaseSpinLockShared(&ctx->TransitionLock, irql);
TunCompleteRequest(ctx, Irp, status, IO_NO_INCREMENT); TunCompleteRequest(ctx, Irp, status, IO_NO_INCREMENT);
if (!NT_SUCCESS(status))
ExFreePoolWithTag(file_ctx, TUN_HTONL(TUN_MEMORY_TAG));
return status; return status;
} }
@ -865,6 +943,10 @@ static NTSTATUS TunDispatch(DEVICE_OBJECT *DeviceObject, IRP *Irp)
TunIndicateStatus(handle, MediaConnectStateDisconnected); TunIndicateStatus(handle, MediaConnectStateDisconnected);
TunQueueClear(ctx, NDIS_STATUS_MEDIA_DISCONNECTED); TunQueueClear(ctx, NDIS_STATUS_MEDIA_DISCONNECTED);
} }
TUN_FILE_CTX *file_ctx = (TUN_FILE_CTX *)stack->FileObject->FsContext;
TunUnmapUbuffer(&file_ctx->ReadBuffer);
TunUnmapUbuffer(&file_ctx->WriteBuffer);
ExFreePoolWithTag(file_ctx, TUN_HTONL(TUN_MEMORY_TAG));
IoReleaseRemoveLock(&ctx->Device.RemoveLock, stack->FileObject); IoReleaseRemoveLock(&ctx->Device.RemoveLock, stack->FileObject);
status = STATUS_SUCCESS; status = STATUS_SUCCESS;
@ -1007,8 +1089,7 @@ static NDIS_STATUS TunInitializeEx(NDIS_HANDLE MiniportAdapterHandle, NDIS_HANDL
if (!NT_SUCCESS(status = NdisRegisterDeviceEx(NdisMiniportDriverHandle, &t, &object, &handle))) if (!NT_SUCCESS(status = NdisRegisterDeviceEx(NdisMiniportDriverHandle, &t, &object, &handle)))
return NDIS_STATUS_FAILURE; return NDIS_STATUS_FAILURE;
object->Flags &= ~DO_BUFFERED_IO; object->Flags &= ~(DO_BUFFERED_IO | DO_DIRECT_IO);
object->Flags |= DO_DIRECT_IO;
TUN_CTX *ctx = NdisGetDeviceReservedExtension(object); TUN_CTX *ctx = NdisGetDeviceReservedExtension(object);
if (!ctx) { if (!ctx) {