diff --git a/CMakeLists.txt b/CMakeLists.txt index e760398..e66a489 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -388,6 +388,7 @@ include_directories(${WSLDEPS_SOURCE_DIR}/include/lxcore) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/shared/inc) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/windows/inc) include_directories(${CMAKE_CURRENT_BINARY_DIR}/src/windows/service/inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/src/windows/wslaservice/inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}) include_directories(${CMAKE_CURRENT_BINARY_DIR}/src/windows/wslinstaller/inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/linux/init/inc) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/windows/common) @@ -413,6 +414,7 @@ add_subdirectory(msipackage) add_subdirectory(msixinstaller) add_subdirectory(src/windows/common) add_subdirectory(src/windows/service) +add_subdirectory(src/windows/wslaservice) add_subdirectory(src/windows/wslinstaller/inc) add_subdirectory(src/windows/wslinstaller/stub) add_subdirectory(src/windows/wslinstaller/exe) diff --git a/msipackage/CMakeLists.txt b/msipackage/CMakeLists.txt index 9f51362..7f7bcb1 100644 --- a/msipackage/CMakeLists.txt +++ b/msipackage/CMakeLists.txt @@ -3,7 +3,7 @@ set(OUTPUT_PACKAGE ${BIN}/wsl.msi) set(PACKAGE_WIX_IN ${CMAKE_CURRENT_LIST_DIR}/package.wix.in) set(PACKAGE_WIX ${BIN}/package.wix) set(CAB_CACHE ${BIN}/cab) -set(BINARIES wsl.exe;wslg.exe;wslhost.exe;wslrelay.exe;wslservice.exe;wslserviceproxystub.dll;init;initrd.img;wslinstall.dll) +set(BINARIES wsl.exe;wslg.exe;wslhost.exe;wslrelay.exe;wslservice.exe;wslserviceproxystub.dll;init;initrd.img;wslinstall.dll;wslaserviceproxystub.dll;wslaservice.exe) if (WSL_BUILD_WSL_SETTINGS) list(APPEND BINARIES_DEPENDENCIES "wslsettings/wslsettings.dll;wslsettings/wslsettings.exe;libwsl.dll") @@ -39,7 +39,7 @@ add_custom_command( add_custom_target(msipackage DEPENDS ${OUTPUT_PACKAGE}) set_target_properties(msipackage PROPERTIES EXCLUDE_FROM_ALL FALSE SOURCES ${PACKAGE_WIX_IN}) -add_dependencies(msipackage wsl wslg wslservice wslhost wslrelay wslserviceproxystub init initramfs wslinstall msixgluepackage) +add_dependencies(msipackage wsl wslg wslservice wslhost wslrelay wslserviceproxystub init initramfs wslinstall msixgluepackage wslaservice wslaserviceproxystub) if (WSL_BUILD_WSL_SETTINGS) add_dependencies(msipackage wslsettings libwsl) diff --git a/msipackage/package.wix.in b/msipackage/package.wix.in index 30837b0..5a7d1fd 100644 --- a/msipackage/package.wix.in +++ b/msipackage/package.wix.in @@ -50,75 +50,75 @@ - - - - - + + + + + - - - + + + - - - + + + - - - - - - + + + + + + - - + + - - - - - + + + + + + - - - - + + + - - - + + + - - - - + + + + - - - - + + + + - - - - - - + + + + + + + - - + - - - - - + + + + + @@ -131,33 +131,13 @@ - - - - - - - - - - - - - - - - - - - - - + - - - - + + + + @@ -175,18 +155,6 @@ - - - - - - - - - - - - @@ -226,7 +194,7 @@ - + @@ -265,6 +233,68 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -402,6 +432,7 @@ + @@ -501,21 +532,21 @@ Return="check" Execute="deferred" /> - - - diff --git a/src/windows/common/CMakeLists.txt b/src/windows/common/CMakeLists.txt index 90c50d0..150ee09 100644 --- a/src/windows/common/CMakeLists.txt +++ b/src/windows/common/CMakeLists.txt @@ -1,83 +1,105 @@ set(SOURCES ConsoleProgressBar.cpp ConsoleProgressIndicator.cpp + DeviceHostProxy.cpp + DeviceHostProxy.h disk.cpp Distribution.cpp + Dmesg.cpp + Dmesg.h + DnsResolver.cpp + DnsTunnelingChannel.cpp + ExecutionContext.cpp filesystem.cpp + GnsChannel.cpp HandleConsoleProgressBar.cpp hcs.cpp helpers.cpp - interop.cpp - ExecutionContext.cpp - socket.cpp hvsocket.cpp + interop.cpp Localization.cpp lxssbusclient.cpp lxssclient.cpp LxssMessagePort.cpp + LxssSecurity.cpp LxssServerPort.cpp + NatNetworking.cpp + notifications.cpp Redirector.cpp registry.cpp relay.cpp + RingBuffer.cpp + socket.cpp string.cpp SubProcess.cpp svccomm.cpp svccommio.cpp WslClient.cpp WslCoreConfig.cpp + WslCoreFilesystem.cpp WslCoreFirewallSupport.cpp + WslCoreHostDnsInfo.cpp + WslCoreNetworkEndpointSettings.cpp WslCoreNetworkingSupport.cpp WslInstall.cpp WslSecurity.cpp WslTelemetry.cpp wslutil.cpp - notifications.cpp) + ) set(HEADERS ../../../generated/Localization.h - ../../shared/inc/CommandLine.h - ../../shared/inc/defs.h - ../../shared/inc/lxfsshares.h - ../../shared/inc/lxinitshared.h - ../../shared/inc/SocketChannel.h - ../../shared/inc/socketshared.h - ../../shared/inc/hns_schema.h - ../../shared/inc/JsonUtils.h - ../../shared/inc/stringshared.h - ../../shared/inc/retryshared.h - ../../shared/inc/message.h - ../../shared/inc/prettyprintshared.h - ../inc/WslPluginApi.h - ../inc/wslpolicies.h ../inc/lxssbusclient.h ../inc/lxssclient.h ../inc/LxssDynamicFunction.h ../inc/traceloggingconfig.h ../inc/wdk.h - ../inc/wsl.h ../inc/wslconfig.h + ../inc/wsl.h ../inc/wslhost.h + ../inc/WslPluginApi.h + ../inc/wslpolicies.h ../inc/wslrelay.h + ../../shared/inc/CommandLine.h + ../../shared/inc/defs.h + ../../shared/inc/hns_schema.h + ../../shared/inc/JsonUtils.h + ../../shared/inc/lxfsshares.h + ../../shared/inc/lxinitshared.h + ../../shared/inc/message.h + ../../shared/inc/prettyprintshared.h + ../../shared/inc/retryshared.h + ../../shared/inc/SocketChannel.h + ../../shared/inc/socketshared.h + ../../shared/inc/stringshared.h ConsoleProgressBar.h ConsoleProgressIndicator.h disk.hpp Distribution.h + DnsResolver.h + DnsTunnelingChannel.h + ExecutionContext.h filesystem.hpp + GnsChannel.h + HandleConsoleProgressBar.h hcs.hpp hcs_schema.h helpers.hpp - HandleConsoleProgressBar.h - interop.hpp - ExecutionContext.h - socket.hpp hvsocket.hpp + INetworkingEngine.h + interop.hpp LxssMessagePort.h LxssPort.h + LxssSecurity.h LxssServerPort.h + NatNetworking.h + notifications.h precomp.h Redirector.h registry.hpp relay.hpp + RingBuffer.h + socket.hpp string.hpp Stringify.h SubProcess.h @@ -85,13 +107,17 @@ set(HEADERS svccommio.hpp WslClient.h WslCoreConfig.h + WslCoreFilesystem.h WslCoreFirewallSupport.h + WslCoreHostDnsInfo.h + WslCoreMessageQueue.h + WslCoreNetworkEndpointSettings.h WslCoreNetworkingSupport.h WslInstall.h WslSecurity.h WslTelemetry.h - wslutil.h - notifications.h) + wslutil.cpp + ) add_library(common STATIC ${SOURCES} ${HEADERS}) add_dependencies(common wslserviceidl localization wslservicemc wslinstalleridl) diff --git a/src/windows/service/exe/DeviceHostProxy.cpp b/src/windows/common/DeviceHostProxy.cpp similarity index 97% rename from src/windows/service/exe/DeviceHostProxy.cpp rename to src/windows/common/DeviceHostProxy.cpp index 17a7dd4..e04f346 100644 --- a/src/windows/service/exe/DeviceHostProxy.cpp +++ b/src/windows/common/DeviceHostProxy.cpp @@ -1,261 +1,261 @@ -// Copyright (C) Microsoft Corporation. All rights reserved. - -#include "precomp.h" -#include "DeviceHostProxy.h" - -// This template works around a limitation with decltype on overloaded functions. It will be able -// to get the correct version of GetVmWorkerProcess based on the provided type arguments. By -// doing it this way, a compiler error will be generated if someone changes the signature of -// GetVmWorkerProcess. -// -// The way this works: decltype(GetVmWorkerProcess) does not work because it's overloaded. -// decltype(GetVmWorkerProcess(arg1, ...)) works to select an overload if you have values of the -// correct type (std::declval() generates a value of the specified type), however the result -// of that is the function's return type, not the function's type, so the argument types must -// be repeated to reconstruct the function type. -template -using GetVmWorkerProcessType = decltype(GetVmWorkerProcess(std::declval()...))(Args...); - -// Limit the number of allowed doorbells registered by an external HDV vdev. Currently virtio-9p only uses -// one doorbell and wsldevicehost uses only two. -#define DEVICE_HOST_PROXY_DOORBELL_LIMIT 8 - -using namespace wsl::windows::common::hcs; - -DeviceHostProxy::DeviceHostProxy(const std::wstring& VmId, const GUID& RuntimeId) : - m_systemId{VmId}, m_runtimeId{RuntimeId}, m_system{wsl::windows::common::hcs::OpenComputeSystem(VmId.c_str(), GENERIC_ALL)}, m_shutdown{false} -{ - m_devicesShutdown = false; -} - -GUID DeviceHostProxy::AddNewDevice(const GUID& Type, const wil::com_ptr& Plan9Fs, const std::wstring& VirtIoTag) -{ - const wrl::ComPtr thisUnknown{CastToUnknown()}; - GUID instanceId{}; - THROW_IF_FAILED(UuidCreate(&instanceId)); - // Tell the device host to create the device. - THROW_IF_FAILED(Plan9Fs->CreateVirtioDevice(m_systemId.c_str(), thisUnknown.Get(), VirtIoTag.c_str(), &instanceId)); - - // Add the instance ID to the list of known devices. This must be done before the device is - // added to the system, because doing that can cause the register doorbell function to be - // called. - // N.B. It will be removed if there is a failure. - { - auto lock = m_devicesLock.lock_exclusive(); - THROW_HR_IF(E_CHANGED_STATE, m_devicesShutdown); - - m_devices.emplace(instanceId, DeviceHostProxyEntry{}); - } - - auto removeOnFailure = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { - auto lock = m_devicesLock.lock_exclusive(); - m_devices.erase(instanceId); - }); - - // Add the device to the compute system on behalf of the device host. - ModifySettingRequest request; - request.RequestType = ModifyRequestType::Add; - request.ResourcePath = L"VirtualMachine/Devices/FlexibleIov/"; - request.ResourcePath += wsl::shared::string::GuidToString(instanceId, wsl::shared::string::GuidToStringFlags::None); - request.Settings.EmulatorId = Type; - request.Settings.HostingModel = FlexibleIoDeviceHostingModel::ExternalRestricted; - wsl::windows::common::hcs::ModifyComputeSystem(m_system.get(), wsl::shared::ToJsonW(request).c_str()); - removeOnFailure.release(); - return instanceId; -} - -void DeviceHostProxy::AddRemoteFileSystem(const GUID& ImplementationClsid, const std::wstring& Tag, const wil::com_ptr& Plan9Fs) -{ - auto lock = m_lock.lock_exclusive(); - THROW_HR_IF(E_CHANGED_STATE, m_shutdown); - - // Make sure there are no duplicate tags. - for (auto& entry : m_fileSystems) - { - THROW_HR_IF(E_INVALIDARG, entry.ImplementationClsid == ImplementationClsid && entry.Tag == Tag); - } - - m_fileSystems.emplace_back(ImplementationClsid, Tag, Plan9Fs); -} - -wil::com_ptr DeviceHostProxy::GetRemoteFileSystem(const GUID& ImplementationClsid, std::wstring_view Tag) -{ - auto lock = m_lock.lock_shared(); - THROW_HR_IF(E_CHANGED_STATE, m_shutdown); - - for (auto& entry : m_fileSystems) - { - if (entry.ImplementationClsid == ImplementationClsid && entry.Tag == Tag) - { - return entry.Instance; - } - } - - return {}; -} - -void DeviceHostProxy::Shutdown() -{ - { - auto lock = m_lock.lock_exclusive(); - m_fileSystems.clear(); - m_shutdown = true; - } - - { - auto lock = m_devicesLock.lock_exclusive(); - m_devices.clear(); - m_devicesShutdown = true; - } -} - -HRESULT -DeviceHostProxy::RegisterDeviceHost(_In_ IVmDeviceHost* DeviceHost, _In_ DWORD ProcessId, _Out_ UINT64* IpcSectionHandle) -try -{ - // - // Because HdvProxyDeviceHost is not part of the API set, it is loaded here dynamically. - // - - static LxssDynamicFunction proxyDeviceHost{c_hdvModuleName, "HdvProxyDeviceHost"}; - const wil::com_ptr remoteHost = DeviceHost; - const wil::com_ptr unknown = remoteHost.query(); - THROW_IF_FAILED(proxyDeviceHost(m_system.get(), unknown.get(), ProcessId, IpcSectionHandle)); - return S_OK; -} -CATCH_RETURN() - -HRESULT -DeviceHostProxy::NotifyAllDevicesInUse(_In_ LPCWSTR Tag) -try -{ - // - // Add another Plan9 virtio device to the guest so additional mount commands will be possible. - // This callback should be unused by virtiofs devices because a device is created for every - // AddSharePath call. - // - auto p9fs = GetRemoteFileSystem(__uuidof(p9fs::Plan9FileSystem), Tag); - THROW_HR_IF(E_NOT_SET, !p9fs); - (void)AddNewDevice(VIRTIO_PLAN9_DEVICE_ID, p9fs, Tag); - return S_OK; -} -CATCH_RETURN() - -HRESULT -DeviceHostProxy::RegisterDoorbell(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags, HANDLE Event) -try -{ - auto lock = m_devicesLock.lock_exclusive(); - RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown); - - // Check if the device is one of the known devices that doorbells can be registered for, and - // if the device has not already registered a doorbell. - // N.B. For security it is enforced that each device can only register a small number of doorbells. - // Currently virtio-9p only uses one and the external virtio device uses two. - const auto knownDevice = m_devices.find(InstanceId); - RETURN_HR_IF(E_ACCESSDENIED, knownDevice == m_devices.end() || knownDevice->second.DoorbellCount == DEVICE_HOST_PROXY_DOORBELL_LIMIT); - - if (!knownDevice->second.MemoryNotification) - { - // Get an interface to the worker process to query devices. - if (!m_deviceAccess) - { - static LxssDynamicFunction> getVmWorker{ - c_vmwpctrlModuleName, "GetVmWorkerProcess"}; - - RETURN_IF_FAILED(getVmWorker(m_runtimeId, __uuidof(*m_deviceAccess), reinterpret_cast(&m_deviceAccess))); - } - - RETURN_HR_IF(E_NOINTERFACE, !m_deviceAccess); - - // Retrieve the device's memory notification interface to register the doorbell, and store it - // to be used during unregistration. - wil::com_ptr device; - RETURN_IF_FAILED(m_deviceAccess->GetDevice(FLEXIO_DEVICE_ID, InstanceId, &device)); - knownDevice->second.MemoryNotification = device.query(); - } - - const auto result = knownDevice->second.MemoryNotification->RegisterDoorbell( - static_cast(BarIndex), Offset, TriggerValue, Flags, Event); - - if (SUCCEEDED(result)) - { - ++knownDevice->second.DoorbellCount; - } - - return result; -} -CATCH_RETURN() - -HRESULT -DeviceHostProxy::UnregisterDoorbell(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags) -try -{ - auto lock = m_devicesLock.lock_exclusive(); - RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown); - - // Check if the device is a known device and has registered a doorbell. - // N.B. If the device is being removed, the device can't be retrieved from the worker process - // so it's necessary to use the stored COM pointer. - const auto device = m_devices.find(InstanceId); - RETURN_HR_IF(E_ACCESSDENIED, device == m_devices.end() || device->second.DoorbellCount == 0); - RETURN_IF_FAILED(device->second.MemoryNotification->UnregisterDoorbell(static_cast(BarIndex), Offset, TriggerValue, Flags)); - - if (--device->second.DoorbellCount == 0) - { - device->second.MemoryNotification.reset(); - } - - return S_OK; -} -CATCH_RETURN() - -HRESULT -DeviceHostProxy::CreateSectionBackedMmioRange( - const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages, UINT64 PageCount, UINT64 MappingFlags, HANDLE SectionHandle, UINT64 SectionOffsetInPages) -try -{ - auto lock = m_devicesLock.lock_exclusive(); - RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown); - - // Check if the device is one of the known devices. - const auto knownDevice = m_devices.find(InstanceId); - THROW_HR_IF(E_ACCESSDENIED, knownDevice == m_devices.end()); - - if (!knownDevice->second.MemoryMapping) - { - // Get an interface to the worker process to query devices. - if (!m_deviceAccess) - { - static LxssDynamicFunction> getVmWorker{ - c_vmwpctrlModuleName, "GetVmWorkerProcess"}; - THROW_IF_FAILED(getVmWorker(m_runtimeId, __uuidof(*m_deviceAccess), reinterpret_cast(&m_deviceAccess))); - } - - THROW_HR_IF(E_NOINTERFACE, !m_deviceAccess); - - // Retrieve the device specific interface to manage mapped sections. - wil::com_ptr device; - THROW_IF_FAILED(m_deviceAccess->GetDevice(FLEXIO_DEVICE_ID, InstanceId, &device)); - knownDevice->second.MemoryMapping = device.query(); - } - - THROW_IF_FAILED(knownDevice->second.MemoryMapping->CreateSectionBackedMmioRange( - static_cast(BarIndex), BarOffsetInPages, PageCount, static_cast(MappingFlags), SectionHandle, SectionOffsetInPages)); - - return S_OK; -} -CATCH_RETURN() - -HRESULT -DeviceHostProxy::DestroySectionBackedMmioRange(const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages) -try -{ - auto lock = m_devicesLock.lock_exclusive(); - RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown); - const auto device = m_devices.find(InstanceId); - RETURN_HR_IF(E_ACCESSDENIED, device == m_devices.end() || !device->second.MemoryMapping); - RETURN_IF_FAILED(device->second.MemoryMapping->DestroySectionBackedMmioRange(static_cast(BarIndex), BarOffsetInPages)); - return S_OK; -} +// Copyright (C) Microsoft Corporation. All rights reserved. + +#include "precomp.h" +#include "DeviceHostProxy.h" + +// This template works around a limitation with decltype on overloaded functions. It will be able +// to get the correct version of GetVmWorkerProcess based on the provided type arguments. By +// doing it this way, a compiler error will be generated if someone changes the signature of +// GetVmWorkerProcess. +// +// The way this works: decltype(GetVmWorkerProcess) does not work because it's overloaded. +// decltype(GetVmWorkerProcess(arg1, ...)) works to select an overload if you have values of the +// correct type (std::declval() generates a value of the specified type), however the result +// of that is the function's return type, not the function's type, so the argument types must +// be repeated to reconstruct the function type. +template +using GetVmWorkerProcessType = decltype(GetVmWorkerProcess(std::declval()...))(Args...); + +// Limit the number of allowed doorbells registered by an external HDV vdev. Currently virtio-9p only uses +// one doorbell and wsldevicehost uses only two. +#define DEVICE_HOST_PROXY_DOORBELL_LIMIT 8 + +using namespace wsl::windows::common::hcs; + +DeviceHostProxy::DeviceHostProxy(const std::wstring& VmId, const GUID& RuntimeId) : + m_systemId{VmId}, m_runtimeId{RuntimeId}, m_system{wsl::windows::common::hcs::OpenComputeSystem(VmId.c_str(), GENERIC_ALL)}, m_shutdown{false} +{ + m_devicesShutdown = false; +} + +GUID DeviceHostProxy::AddNewDevice(const GUID& Type, const wil::com_ptr& Plan9Fs, const std::wstring& VirtIoTag) +{ + const wrl::ComPtr thisUnknown{CastToUnknown()}; + GUID instanceId{}; + THROW_IF_FAILED(UuidCreate(&instanceId)); + // Tell the device host to create the device. + THROW_IF_FAILED(Plan9Fs->CreateVirtioDevice(m_systemId.c_str(), thisUnknown.Get(), VirtIoTag.c_str(), &instanceId)); + + // Add the instance ID to the list of known devices. This must be done before the device is + // added to the system, because doing that can cause the register doorbell function to be + // called. + // N.B. It will be removed if there is a failure. + { + auto lock = m_devicesLock.lock_exclusive(); + THROW_HR_IF(E_CHANGED_STATE, m_devicesShutdown); + + m_devices.emplace(instanceId, DeviceHostProxyEntry{}); + } + + auto removeOnFailure = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { + auto lock = m_devicesLock.lock_exclusive(); + m_devices.erase(instanceId); + }); + + // Add the device to the compute system on behalf of the device host. + ModifySettingRequest request; + request.RequestType = ModifyRequestType::Add; + request.ResourcePath = L"VirtualMachine/Devices/FlexibleIov/"; + request.ResourcePath += wsl::shared::string::GuidToString(instanceId, wsl::shared::string::GuidToStringFlags::None); + request.Settings.EmulatorId = Type; + request.Settings.HostingModel = FlexibleIoDeviceHostingModel::ExternalRestricted; + wsl::windows::common::hcs::ModifyComputeSystem(m_system.get(), wsl::shared::ToJsonW(request).c_str()); + removeOnFailure.release(); + return instanceId; +} + +void DeviceHostProxy::AddRemoteFileSystem(const GUID& ImplementationClsid, const std::wstring& Tag, const wil::com_ptr& Plan9Fs) +{ + auto lock = m_lock.lock_exclusive(); + THROW_HR_IF(E_CHANGED_STATE, m_shutdown); + + // Make sure there are no duplicate tags. + for (auto& entry : m_fileSystems) + { + THROW_HR_IF(E_INVALIDARG, entry.ImplementationClsid == ImplementationClsid && entry.Tag == Tag); + } + + m_fileSystems.emplace_back(ImplementationClsid, Tag, Plan9Fs); +} + +wil::com_ptr DeviceHostProxy::GetRemoteFileSystem(const GUID& ImplementationClsid, std::wstring_view Tag) +{ + auto lock = m_lock.lock_shared(); + THROW_HR_IF(E_CHANGED_STATE, m_shutdown); + + for (auto& entry : m_fileSystems) + { + if (entry.ImplementationClsid == ImplementationClsid && entry.Tag == Tag) + { + return entry.Instance; + } + } + + return {}; +} + +void DeviceHostProxy::Shutdown() +{ + { + auto lock = m_lock.lock_exclusive(); + m_fileSystems.clear(); + m_shutdown = true; + } + + { + auto lock = m_devicesLock.lock_exclusive(); + m_devices.clear(); + m_devicesShutdown = true; + } +} + +HRESULT +DeviceHostProxy::RegisterDeviceHost(_In_ IVmDeviceHost* DeviceHost, _In_ DWORD ProcessId, _Out_ UINT64* IpcSectionHandle) +try +{ + // + // Because HdvProxyDeviceHost is not part of the API set, it is loaded here dynamically. + // + + static LxssDynamicFunction proxyDeviceHost{c_hdvModuleName, "HdvProxyDeviceHost"}; + const wil::com_ptr remoteHost = DeviceHost; + const wil::com_ptr unknown = remoteHost.query(); + THROW_IF_FAILED(proxyDeviceHost(m_system.get(), unknown.get(), ProcessId, IpcSectionHandle)); + return S_OK; +} +CATCH_RETURN() + +HRESULT +DeviceHostProxy::NotifyAllDevicesInUse(_In_ LPCWSTR Tag) +try +{ + // + // Add another Plan9 virtio device to the guest so additional mount commands will be possible. + // This callback should be unused by virtiofs devices because a device is created for every + // AddSharePath call. + // + auto p9fs = GetRemoteFileSystem(__uuidof(p9fs::Plan9FileSystem), Tag); + THROW_HR_IF(E_NOT_SET, !p9fs); + (void)AddNewDevice(VIRTIO_PLAN9_DEVICE_ID, p9fs, Tag); + return S_OK; +} +CATCH_RETURN() + +HRESULT +DeviceHostProxy::RegisterDoorbell(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags, HANDLE Event) +try +{ + auto lock = m_devicesLock.lock_exclusive(); + RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown); + + // Check if the device is one of the known devices that doorbells can be registered for, and + // if the device has not already registered a doorbell. + // N.B. For security it is enforced that each device can only register a small number of doorbells. + // Currently virtio-9p only uses one and the external virtio device uses two. + const auto knownDevice = m_devices.find(InstanceId); + RETURN_HR_IF(E_ACCESSDENIED, knownDevice == m_devices.end() || knownDevice->second.DoorbellCount == DEVICE_HOST_PROXY_DOORBELL_LIMIT); + + if (!knownDevice->second.MemoryNotification) + { + // Get an interface to the worker process to query devices. + if (!m_deviceAccess) + { + static LxssDynamicFunction> getVmWorker{ + c_vmwpctrlModuleName, "GetVmWorkerProcess"}; + + RETURN_IF_FAILED(getVmWorker(m_runtimeId, __uuidof(*m_deviceAccess), reinterpret_cast(&m_deviceAccess))); + } + + RETURN_HR_IF(E_NOINTERFACE, !m_deviceAccess); + + // Retrieve the device's memory notification interface to register the doorbell, and store it + // to be used during unregistration. + wil::com_ptr device; + RETURN_IF_FAILED(m_deviceAccess->GetDevice(FLEXIO_DEVICE_ID, InstanceId, &device)); + knownDevice->second.MemoryNotification = device.query(); + } + + const auto result = knownDevice->second.MemoryNotification->RegisterDoorbell( + static_cast(BarIndex), Offset, TriggerValue, Flags, Event); + + if (SUCCEEDED(result)) + { + ++knownDevice->second.DoorbellCount; + } + + return result; +} +CATCH_RETURN() + +HRESULT +DeviceHostProxy::UnregisterDoorbell(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags) +try +{ + auto lock = m_devicesLock.lock_exclusive(); + RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown); + + // Check if the device is a known device and has registered a doorbell. + // N.B. If the device is being removed, the device can't be retrieved from the worker process + // so it's necessary to use the stored COM pointer. + const auto device = m_devices.find(InstanceId); + RETURN_HR_IF(E_ACCESSDENIED, device == m_devices.end() || device->second.DoorbellCount == 0); + RETURN_IF_FAILED(device->second.MemoryNotification->UnregisterDoorbell(static_cast(BarIndex), Offset, TriggerValue, Flags)); + + if (--device->second.DoorbellCount == 0) + { + device->second.MemoryNotification.reset(); + } + + return S_OK; +} +CATCH_RETURN() + +HRESULT +DeviceHostProxy::CreateSectionBackedMmioRange( + const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages, UINT64 PageCount, UINT64 MappingFlags, HANDLE SectionHandle, UINT64 SectionOffsetInPages) +try +{ + auto lock = m_devicesLock.lock_exclusive(); + RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown); + + // Check if the device is one of the known devices. + const auto knownDevice = m_devices.find(InstanceId); + THROW_HR_IF(E_ACCESSDENIED, knownDevice == m_devices.end()); + + if (!knownDevice->second.MemoryMapping) + { + // Get an interface to the worker process to query devices. + if (!m_deviceAccess) + { + static LxssDynamicFunction> getVmWorker{ + c_vmwpctrlModuleName, "GetVmWorkerProcess"}; + THROW_IF_FAILED(getVmWorker(m_runtimeId, __uuidof(*m_deviceAccess), reinterpret_cast(&m_deviceAccess))); + } + + THROW_HR_IF(E_NOINTERFACE, !m_deviceAccess); + + // Retrieve the device specific interface to manage mapped sections. + wil::com_ptr device; + THROW_IF_FAILED(m_deviceAccess->GetDevice(FLEXIO_DEVICE_ID, InstanceId, &device)); + knownDevice->second.MemoryMapping = device.query(); + } + + THROW_IF_FAILED(knownDevice->second.MemoryMapping->CreateSectionBackedMmioRange( + static_cast(BarIndex), BarOffsetInPages, PageCount, static_cast(MappingFlags), SectionHandle, SectionOffsetInPages)); + + return S_OK; +} +CATCH_RETURN() + +HRESULT +DeviceHostProxy::DestroySectionBackedMmioRange(const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages) +try +{ + auto lock = m_devicesLock.lock_exclusive(); + RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown); + const auto device = m_devices.find(InstanceId); + RETURN_HR_IF(E_ACCESSDENIED, device == m_devices.end() || !device->second.MemoryMapping); + RETURN_IF_FAILED(device->second.MemoryMapping->DestroySectionBackedMmioRange(static_cast(BarIndex), BarOffsetInPages)); + return S_OK; +} CATCH_RETURN() \ No newline at end of file diff --git a/src/windows/service/exe/DeviceHostProxy.h b/src/windows/common/DeviceHostProxy.h similarity index 97% rename from src/windows/service/exe/DeviceHostProxy.h rename to src/windows/common/DeviceHostProxy.h index c5cc165..5f951b3 100644 --- a/src/windows/service/exe/DeviceHostProxy.h +++ b/src/windows/common/DeviceHostProxy.h @@ -1,76 +1,76 @@ -// Copyright (C) Microsoft Corporation. All rights reserved. - -#pragma once - -#include -#include "hcs.hpp" - -namespace wrl = Microsoft::WRL; - -class DeviceHostProxy : public wrl::RuntimeClass, IVmDeviceHostSupport, IPlan9FileSystemHost> -{ -public: - DeviceHostProxy(const std::wstring& VmId, const GUID& RuntimeId); - - GUID AddNewDevice(const GUID& Type, const wil::com_ptr& Plan9Fs, const std::wstring& VirtIoTag); - - void AddRemoteFileSystem(const GUID& ImplementationClsid, const std::wstring& Tag, const wil::com_ptr& Plan9Fs); - - wil::com_ptr GetRemoteFileSystem(const GUID& ImplementationClsid, std::wstring_view Tag); - - void Shutdown(); - - // - // IVmDeviceHostSupport - // - IFACEMETHOD(RegisterDeviceHost)(_In_ IVmDeviceHost* DeviceHost, _In_ DWORD ProcessId, _Out_ UINT64* IpcSectionHandle) override; - - // - // IPlan9FileSystemHost - // - IFACEMETHOD(NotifyAllDevicesInUse)(_In_ LPCWSTR Tag) override; - - IFACEMETHOD(RegisterDoorbell)(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags, HANDLE Event) override; - - IFACEMETHOD(UnregisterDoorbell)(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags) override; - - IFACEMETHOD(CreateSectionBackedMmioRange)( - const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages, UINT64 PageCount, UINT64 MappingFlags, HANDLE SectionHandle, UINT64 SectionOffsetInPages) override; - - IFACEMETHOD(DestroySectionBackedMmioRange)(const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages) override; - -private: - struct RemoteFileSystemInfo - { - RemoteFileSystemInfo(GUID ImplementationClsid, const std::wstring& Tag, const wil::com_ptr& Instance) : - ImplementationClsid{ImplementationClsid}, Tag{Tag}, Instance{Instance} - { - } - - GUID ImplementationClsid; - std::wstring Tag; - wil::com_ptr Instance; - }; - - std::wstring m_systemId; - GUID m_runtimeId; - wsl::windows::common::hcs::unique_hcs_system m_system; - wil::srwlock m_lock; - std::vector m_fileSystems; - bool m_shutdown; - - struct DeviceHostProxyEntry - { - wil::com_ptr MemoryNotification; - wil::com_ptr MemoryMapping; - size_t DoorbellCount = 0; - }; - - wil::com_ptr m_deviceAccess; - wil::srwlock m_devicesLock; - std::map m_devices; - bool m_devicesShutdown; - - static constexpr LPCWSTR c_hdvModuleName = L"vmdevicehost.dll"; - static constexpr LPCWSTR c_vmwpctrlModuleName = L"vmwpctrl.dll"; +// Copyright (C) Microsoft Corporation. All rights reserved. + +#pragma once + +#include +#include "hcs.hpp" + +namespace wrl = Microsoft::WRL; + +class DeviceHostProxy : public wrl::RuntimeClass, IVmDeviceHostSupport, IPlan9FileSystemHost> +{ +public: + DeviceHostProxy(const std::wstring& VmId, const GUID& RuntimeId); + + GUID AddNewDevice(const GUID& Type, const wil::com_ptr& Plan9Fs, const std::wstring& VirtIoTag); + + void AddRemoteFileSystem(const GUID& ImplementationClsid, const std::wstring& Tag, const wil::com_ptr& Plan9Fs); + + wil::com_ptr GetRemoteFileSystem(const GUID& ImplementationClsid, std::wstring_view Tag); + + void Shutdown(); + + // + // IVmDeviceHostSupport + // + IFACEMETHOD(RegisterDeviceHost)(_In_ IVmDeviceHost* DeviceHost, _In_ DWORD ProcessId, _Out_ UINT64* IpcSectionHandle) override; + + // + // IPlan9FileSystemHost + // + IFACEMETHOD(NotifyAllDevicesInUse)(_In_ LPCWSTR Tag) override; + + IFACEMETHOD(RegisterDoorbell)(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags, HANDLE Event) override; + + IFACEMETHOD(UnregisterDoorbell)(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags) override; + + IFACEMETHOD(CreateSectionBackedMmioRange)( + const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages, UINT64 PageCount, UINT64 MappingFlags, HANDLE SectionHandle, UINT64 SectionOffsetInPages) override; + + IFACEMETHOD(DestroySectionBackedMmioRange)(const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages) override; + +private: + struct RemoteFileSystemInfo + { + RemoteFileSystemInfo(GUID ImplementationClsid, const std::wstring& Tag, const wil::com_ptr& Instance) : + ImplementationClsid{ImplementationClsid}, Tag{Tag}, Instance{Instance} + { + } + + GUID ImplementationClsid; + std::wstring Tag; + wil::com_ptr Instance; + }; + + std::wstring m_systemId; + GUID m_runtimeId; + wsl::windows::common::hcs::unique_hcs_system m_system; + wil::srwlock m_lock; + std::vector m_fileSystems; + bool m_shutdown; + + struct DeviceHostProxyEntry + { + wil::com_ptr MemoryNotification; + wil::com_ptr MemoryMapping; + size_t DoorbellCount = 0; + }; + + wil::com_ptr m_deviceAccess; + wil::srwlock m_devicesLock; + std::map m_devices; + bool m_devicesShutdown; + + static constexpr LPCWSTR c_hdvModuleName = L"vmdevicehost.dll"; + static constexpr LPCWSTR c_vmwpctrlModuleName = L"vmwpctrl.dll"; }; \ No newline at end of file diff --git a/src/windows/service/exe/Dmesg.cpp b/src/windows/common/Dmesg.cpp similarity index 100% rename from src/windows/service/exe/Dmesg.cpp rename to src/windows/common/Dmesg.cpp diff --git a/src/windows/service/exe/Dmesg.h b/src/windows/common/Dmesg.h similarity index 100% rename from src/windows/service/exe/Dmesg.h rename to src/windows/common/Dmesg.h diff --git a/src/windows/service/exe/DnsResolver.cpp b/src/windows/common/DnsResolver.cpp similarity index 97% rename from src/windows/service/exe/DnsResolver.cpp rename to src/windows/common/DnsResolver.cpp index 090e9b4..3fc13bf 100644 --- a/src/windows/service/exe/DnsResolver.cpp +++ b/src/windows/common/DnsResolver.cpp @@ -1,417 +1,417 @@ -// Copyright (C) Microsoft Corporation. All rights reserved. - -#include -#include "precomp.h" -#include "DnsResolver.h" - -using wsl::core::networking::DnsResolver; - -static constexpr auto c_dnsModuleName = L"dnsapi.dll"; - -std::optional> DnsResolver::s_dnsQueryRaw; -std::optional> DnsResolver::s_dnsCancelQueryRaw; -std::optional> DnsResolver::s_dnsQueryRawResultFree; - -HRESULT DnsResolver::LoadDnsResolverMethods() noexcept -{ - static wil::shared_hmodule dnsModule; - static DWORD loadError = ERROR_SUCCESS; - static std::once_flag dnsLoadFlag; - - // Load DNS dll only once - std::call_once(dnsLoadFlag, [&]() { - dnsModule.reset(LoadLibraryEx(c_dnsModuleName, nullptr, LOAD_LIBRARY_SEARCH_SYSTEM32)); - if (!dnsModule) - { - loadError = GetLastError(); - } - }); - - RETURN_IF_WIN32_ERROR_MSG(loadError, "LoadLibraryEx %ls", c_dnsModuleName); - - // Initialize dynamic functions for the DNS tunneling Windows APIs. - // using the non-throwing instance of LxssDynamicFunction as to not end up in the Error telemetry - LxssDynamicFunction local_dnsQueryRaw{DynamicFunctionErrorLogs::None}; - RETURN_IF_FAILED_EXPECTED(local_dnsQueryRaw.load(dnsModule, "DnsQueryRaw")); - LxssDynamicFunction local_dnsCancelQueryRaw{DynamicFunctionErrorLogs::None}; - RETURN_IF_FAILED_EXPECTED(local_dnsCancelQueryRaw.load(dnsModule, "DnsCancelQueryRaw")); - LxssDynamicFunction local_dnsQueryRawResultFree{DynamicFunctionErrorLogs::None}; - RETURN_IF_FAILED_EXPECTED(local_dnsQueryRawResultFree.load(dnsModule, "DnsQueryRawResultFree")); - - // Make a dummy call to the DNS APIs to verify if they are working. The APIs are going to be present - // on older Windows versions, where they can be turned on/off. If turned off, the APIs - // will be unusable and will return ERROR_CALL_NOT_IMPLEMENTED. - if (local_dnsQueryRaw(nullptr, nullptr) == ERROR_CALL_NOT_IMPLEMENTED) - { - RETURN_IF_WIN32_ERROR_EXPECTED(ERROR_CALL_NOT_IMPLEMENTED); - } - - s_dnsQueryRaw.emplace(std::move(local_dnsQueryRaw)); - s_dnsCancelQueryRaw.emplace(std::move(local_dnsCancelQueryRaw)); - s_dnsQueryRawResultFree.emplace(std::move(local_dnsQueryRawResultFree)); - return S_OK; -} - -DnsResolver::DnsResolver(wil::unique_socket&& dnsHvsocket, DnsResolverFlags flags) : - m_dnsChannel( - std::move(dnsHvsocket), - [this](const gsl::span dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) { - ProcessDnsRequest(dnsBuffer, dnsClientIdentifier); - }), - m_flags(flags) -{ - // Initialize as signaled, as there are no requests yet - m_allRequestsFinished.SetEvent(); - - // Read external interface constraint regkey - const auto lxssKey = windows::common::registry::OpenLxssMachineKey(KEY_READ); - m_externalInterfaceConstraintName = - windows::common::registry::ReadString(lxssKey.get(), nullptr, c_interfaceConstraintKey, L""); - - if (!m_externalInterfaceConstraintName.empty()) - { - ResolveExternalInterfaceConstraintIndex(); - - WSL_LOG( - "DnsResolver::DnsResolver", - TraceLoggingValue(m_externalInterfaceConstraintName.c_str(), "m_externalInterfaceConstraintName"), - TraceLoggingValue(m_externalInterfaceConstraintIndex, "m_externalInterfaceConstraintIndex")); - - // Register for interface change notifications. Notifications are used to determine if the external interface constraint setting is applicable. - THROW_IF_WIN32_ERROR(NotifyIpInterfaceChange(AF_UNSPEC, &DnsResolver::InterfaceChangeCallback, this, FALSE, &m_interfaceNotificationHandle)); - } -} - -DnsResolver::~DnsResolver() noexcept -{ - Stop(); -} - -void DnsResolver::GenerateTelemetry() noexcept -try -{ - // Find the 3 most common DNS API failures - uint32_t mostCommonDnsStatusError = 0; - uint32_t mostCommonDnsStatusErrorCount = 0; - uint32_t secondCommonDnsStatusError = 0; - uint32_t secondCommonDnsStatusErrorCount = 0; - uint32_t thirdCommonDnsStatusError = 0; - uint32_t thirdCommonDnsStatusErrorCount = 0; - - std::vector> failures(m_dnsApiFailures.size()); - std::copy(m_dnsApiFailures.begin(), m_dnsApiFailures.end(), failures.begin()); - - // Sort in descending order based on failure count - std::sort(failures.begin(), failures.end(), [](const auto& lhs, const auto& rhs) { return lhs.second > rhs.second; }); - - if (failures.size() >= 1) - { - mostCommonDnsStatusError = failures[0].first; - mostCommonDnsStatusErrorCount = failures[0].second; - } - if (failures.size() >= 2) - { - secondCommonDnsStatusError = failures[1].first; - secondCommonDnsStatusErrorCount = failures[1].second; - } - if (failures.size() >= 3) - { - thirdCommonDnsStatusError = failures[2].first; - thirdCommonDnsStatusErrorCount = failures[2].second; - } - - // Add telemetry with DNS tunneling statistics, before shutting down - WSL_LOG( - "DnsTunnelingStatistics", - TraceLoggingValue(m_totalUdpQueries.load(), "totalUdpQueries"), - TraceLoggingValue(m_successfulUdpQueries.load(), "successfulUdpQueries"), - TraceLoggingValue(m_totalTcpQueries.load(), "totalTcpQueries"), - TraceLoggingValue(m_successfulTcpQueries.load(), "successfulTcpQueries"), - TraceLoggingValue(m_queriesWithNullResult.load(), "queriesWithNullResult"), - TraceLoggingValue(m_failedDnsQueryRawCalls.load(), "FailedDnsQueryRawCalls"), - TraceLoggingValue(m_dnsApiFailures.size(), "totalDnsStatusErrorInstances"), - TraceLoggingValue(mostCommonDnsStatusError, "mostCommonDnsStatusError"), - TraceLoggingValue(mostCommonDnsStatusErrorCount, "mostCommonDnsStatusErrorCount"), - TraceLoggingValue(secondCommonDnsStatusError, "secondCommonDnsStatusError"), - TraceLoggingValue(secondCommonDnsStatusErrorCount, "secondCommonDnsStatusErrorCount"), - TraceLoggingValue(thirdCommonDnsStatusError, "thirdCommonDnsStatusError"), - TraceLoggingValue(thirdCommonDnsStatusErrorCount, "thirdCommonDnsStatusErrorCount")); -} -CATCH_LOG() - -void DnsResolver::Stop() noexcept -try -{ - WSL_LOG("DnsResolver::Stop"); - - // Scoped m_dnsLock - { - const std::lock_guard lock(m_dnsLock); - - m_stopped = true; - - // Cancel existing requests. Cancel is complete when DnsQueryRawCallback is - // invoked with status == ERROR_CANCELLED - // N.B. Cancelling can end up calling the DnsQueryRawCallback directly on this same thread. i.e., while this - // lock is held. Which is fine because m_dnsLock is a recursive mutex. - // N.B. Cancelling a query will synchronously remove the query from m_dnsRequests, which invalidates iterators. - - std::vector cancelHandles; - cancelHandles.reserve(m_dnsRequests.size()); - - for (auto& [_, context] : m_dnsRequests) - { - cancelHandles.emplace_back(&context->m_cancelHandle); - } - - for (const auto e : cancelHandles) - { - LOG_IF_WIN32_ERROR(s_dnsCancelQueryRaw.value()(e)); - } - } - - // Wait for all requests to complete. At this point no new requests can be started since the object is stopped. - // We are only waiting for existing requests to finish. - m_allRequestsFinished.wait(); - - // Stop the response queue first as it can make calls in m_dnsChannel - m_dnsResponseQueue.cancel(); - - m_dnsChannel.Stop(); - - // Stop interface change notifications - m_interfaceNotificationHandle.reset(); - - GenerateTelemetry(); -} -CATCH_LOG() - -void DnsResolver::ProcessDnsRequest(const gsl::span dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept -try -{ - const std::lock_guard lock(m_dnsLock); - if (m_stopped) - { - return; - } - - WSL_LOG_DEBUG( - "DnsResolver::ProcessDnsRequest - received new DNS request", - TraceLoggingValue(dnsBuffer.size(), "DNS buffer size"), - TraceLoggingValue(dnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"), - TraceLoggingValue(dnsClientIdentifier.DnsClientId, "DNS client id"), - TraceLoggingValue(!m_externalInterfaceConstraintName.empty(), "Is ExternalInterfaceConstraint configured"), - TraceLoggingValue(m_externalInterfaceConstraintIndex, "m_externalInterfaceConstraintIndex")); - - // If the external interface constraint is configured but it is *not* present/up, WSL should be net-blind, so we avoid making DNS requests. - if (!m_externalInterfaceConstraintName.empty() && m_externalInterfaceConstraintIndex == 0) - { - return; - } - - dnsClientIdentifier.Protocol == IPPROTO_UDP ? m_totalUdpQueries++ : m_totalTcpQueries++; - - // Get next request id. If value reaches UINT_MAX + 1 it will be automatically reset to 0 - const auto requestId = m_currentRequestId++; - - // Create the DNS request context - auto context = std::make_unique( - requestId, dnsClientIdentifier, [this](_Inout_ DnsResolver::DnsQueryContext* context, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) { - HandleDnsQueryCompletion(context, queryResults); - }); - - auto [it, _] = m_dnsRequests.emplace(requestId, std::move(context)); - const auto localContext = it->second.get(); - - auto removeContextOnError = wil::scope_exit([&] { WI_VERIFY(m_dnsRequests.erase(requestId) == 1); }); - - // Fill DNS request structure - DNS_QUERY_RAW_REQUEST request{}; - - request.version = DNS_QUERY_RAW_REQUEST_VERSION1; - request.resultsVersion = DNS_QUERY_RAW_RESULTS_VERSION1; - request.dnsQueryRawSize = static_cast(dnsBuffer.size()); - request.dnsQueryRaw = (PBYTE)dnsBuffer.data(); - request.protocol = (dnsClientIdentifier.Protocol == IPPROTO_TCP) ? DNS_PROTOCOL_TCP : DNS_PROTOCOL_UDP; - request.queryCompletionCallback = DnsResolver::DnsQueryRawCallback; - request.queryContext = localContext; - // Only unicast UDP & TCP queries are tunneled. Pass this flag to tell Windows DNS client to *not* resolve using multicast. - request.queryOptions |= DNS_QUERY_NO_MULTICAST; - - // In a DNS request from Linux there might be DNS records that Windows DNS client does not know how to parse. - // By default in this case Windows will fail the request. When the flag is enabled, Windows will extract the - // question from the DNS request and attempt to resolve it, ignoring the unknown records. - if (WI_IsFlagSet(m_flags, DnsResolverFlags::BestEffortDnsParsing)) - { - request.queryRawOptions |= DNS_QUERY_RAW_OPTION_BEST_EFFORT_PARSE; - } - - // If the external interface constraint is configured and present on the host, only send DNS requests on that interface. - if (m_externalInterfaceConstraintIndex != 0) - { - request.interfaceIndex = m_externalInterfaceConstraintIndex; - } - - // Start the DNS request - // N.B. All DNS requests will bypass the Windows DNS cache - const auto result = s_dnsQueryRaw.value()(&request, &localContext->m_cancelHandle); - if (result != DNS_REQUEST_PENDING) - { - m_failedDnsQueryRawCalls++; - - WSL_LOG( - "ProcessDnsRequestFailed", - TraceLoggingValue(requestId, "requestId"), - TraceLoggingValue(result, "result"), - TraceLoggingValue("DnsQueryRaw", "executionStep")); - return; - } - - removeContextOnError.release(); - - m_allRequestsFinished.ResetEvent(); -} -CATCH_LOG() - -void DnsResolver::HandleDnsQueryCompletion(_Inout_ DnsResolver::DnsQueryContext* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept -try -{ - // Always free the query result structure - const auto freeQueryResults = wil::scope_exit([&] { - if (queryResults != nullptr) - { - s_dnsQueryRawResultFree.value()(queryResults); - } - }); - - const std::lock_guard lock(m_dnsLock); - - if (queryResults != nullptr) - { - WSL_LOG( - "DnsResolver::HandleDnsQueryCompletion", - TraceLoggingValue(queryContext->m_id, "queryContext->m_id"), - TraceLoggingValue(queryResults->queryStatus, "queryResults->queryStatus"), - TraceLoggingValue(queryResults->queryRawResponse != nullptr, "validResponse")); - - // Note: The response may be valid even if queryResults->queryStatus is not 0, for example when the DNS server returns a negative response. - if (queryResults->queryRawResponse != nullptr) - { - queryContext->m_dnsClientIdentifier.Protocol == IPPROTO_UDP ? m_successfulUdpQueries++ : m_successfulTcpQueries++; - } - // the Windows DNS API returned failure - else - { - if (m_dnsApiFailures.find(queryResults->queryStatus) == m_dnsApiFailures.end()) - { - m_dnsApiFailures[queryResults->queryStatus] = 1; - } - else - { - m_dnsApiFailures[queryResults->queryStatus]++; - } - } - } - else - { - WSL_LOG( - "DnsResolver::HandleDnsQueryCompletion - received a NULL queryResults", - TraceLoggingValue(queryContext->m_id, "queryContext->m_id")); - m_queriesWithNullResult++; - } - - if (!m_stopped && queryResults != nullptr && queryResults->queryRawResponse != nullptr) - { - // Copy DNS response buffer - std::vector dnsResponse(queryResults->queryRawResponseSize); - CopyMemory(dnsResponse.data(), queryResults->queryRawResponse, queryResults->queryRawResponseSize); - - WSL_LOG_DEBUG( - "DnsResolver::HandleDnsQueryCompletion - received new DNS response", - TraceLoggingValue(dnsResponse.size(), "DNS buffer size"), - TraceLoggingValue(queryContext->m_dnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"), - TraceLoggingValue(queryContext->m_dnsClientIdentifier.DnsClientId, "DNS client id")); - - // Schedule the DNS response to be sent to Linux - m_dnsResponseQueue.submit([this, dnsResponse = std::move(dnsResponse), dnsClientIdentifier = queryContext->m_dnsClientIdentifier]() mutable { - m_dnsChannel.SendDnsMessage(gsl::make_span(dnsResponse), dnsClientIdentifier); - }); - } - - // Stop tracking this DNS request and delete the request context - WI_VERIFY(m_dnsRequests.erase(queryContext->m_id) == 1); - - // Set event if all tracked requests have finished - if (m_dnsRequests.empty()) - { - m_allRequestsFinished.SetEvent(); - } -} -CATCH_LOG() - -void DnsResolver::ResolveExternalInterfaceConstraintIndex() noexcept -try -{ - const std::lock_guard lock(m_dnsLock); - if (m_stopped) - { - return; - } - - if (m_externalInterfaceConstraintName.empty()) - { - return; - } - - NET_LUID interfaceLuid{}; - ULONG interfaceIndex = 0; - - // Update the interface index on every exit path. - // The calls below to convert interface name to index will fail if the interface does not exist anymore, - // in which case we still need to reset the interface index to its default value of 0. - const auto setInterfaceIndex = wil::scope_exit([&] { - if (interfaceIndex != m_externalInterfaceConstraintIndex) - { - WSL_LOG( - "DnsResolver::ResolveExternalInterfaceConstraintIndex - setting m_externalInterfaceConstraintIndex to new value", - TraceLoggingValue(m_externalInterfaceConstraintIndex, "old interface index"), - TraceLoggingValue(interfaceIndex, "new interface index")); - - m_externalInterfaceConstraintIndex = interfaceIndex; - } - }); - - // If external interface constraint is configured, query to see if it's present on the host. - auto errorCode = ConvertInterfaceAliasToLuid(m_externalInterfaceConstraintName.c_str(), &interfaceLuid); - if (FAILED_WIN32_LOG(errorCode)) - { - return; - } - - errorCode = ConvertInterfaceLuidToIndex(&interfaceLuid, reinterpret_cast(&interfaceIndex)); - if (FAILED_WIN32_LOG(errorCode)) - { - return; - } -} -CATCH_LOG() - -VOID CALLBACK DnsResolver::DnsQueryRawCallback(_In_ VOID* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept -try -{ - assert(queryContext != nullptr); - - const auto context = static_cast(queryContext); - - // Call into DnsResolver parent object to process the query result - context->m_handleQueryCompletion(context, queryResults); -} -CATCH_LOG() - -VOID CALLBACK DnsResolver::InterfaceChangeCallback(_In_ PVOID context, PMIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE) noexcept -try -{ - const auto dnsResolver = static_cast(context); - dnsResolver->ResolveExternalInterfaceConstraintIndex(); -} -CATCH_LOG() +// Copyright (C) Microsoft Corporation. All rights reserved. + +#include +#include "precomp.h" +#include "DnsResolver.h" + +using wsl::core::networking::DnsResolver; + +static constexpr auto c_dnsModuleName = L"dnsapi.dll"; + +std::optional> DnsResolver::s_dnsQueryRaw; +std::optional> DnsResolver::s_dnsCancelQueryRaw; +std::optional> DnsResolver::s_dnsQueryRawResultFree; + +HRESULT DnsResolver::LoadDnsResolverMethods() noexcept +{ + static wil::shared_hmodule dnsModule; + static DWORD loadError = ERROR_SUCCESS; + static std::once_flag dnsLoadFlag; + + // Load DNS dll only once + std::call_once(dnsLoadFlag, [&]() { + dnsModule.reset(LoadLibraryEx(c_dnsModuleName, nullptr, LOAD_LIBRARY_SEARCH_SYSTEM32)); + if (!dnsModule) + { + loadError = GetLastError(); + } + }); + + RETURN_IF_WIN32_ERROR_MSG(loadError, "LoadLibraryEx %ls", c_dnsModuleName); + + // Initialize dynamic functions for the DNS tunneling Windows APIs. + // using the non-throwing instance of LxssDynamicFunction as to not end up in the Error telemetry + LxssDynamicFunction local_dnsQueryRaw{DynamicFunctionErrorLogs::None}; + RETURN_IF_FAILED_EXPECTED(local_dnsQueryRaw.load(dnsModule, "DnsQueryRaw")); + LxssDynamicFunction local_dnsCancelQueryRaw{DynamicFunctionErrorLogs::None}; + RETURN_IF_FAILED_EXPECTED(local_dnsCancelQueryRaw.load(dnsModule, "DnsCancelQueryRaw")); + LxssDynamicFunction local_dnsQueryRawResultFree{DynamicFunctionErrorLogs::None}; + RETURN_IF_FAILED_EXPECTED(local_dnsQueryRawResultFree.load(dnsModule, "DnsQueryRawResultFree")); + + // Make a dummy call to the DNS APIs to verify if they are working. The APIs are going to be present + // on older Windows versions, where they can be turned on/off. If turned off, the APIs + // will be unusable and will return ERROR_CALL_NOT_IMPLEMENTED. + if (local_dnsQueryRaw(nullptr, nullptr) == ERROR_CALL_NOT_IMPLEMENTED) + { + RETURN_IF_WIN32_ERROR_EXPECTED(ERROR_CALL_NOT_IMPLEMENTED); + } + + s_dnsQueryRaw.emplace(std::move(local_dnsQueryRaw)); + s_dnsCancelQueryRaw.emplace(std::move(local_dnsCancelQueryRaw)); + s_dnsQueryRawResultFree.emplace(std::move(local_dnsQueryRawResultFree)); + return S_OK; +} + +DnsResolver::DnsResolver(wil::unique_socket&& dnsHvsocket, DnsResolverFlags flags) : + m_dnsChannel( + std::move(dnsHvsocket), + [this](const gsl::span dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) { + ProcessDnsRequest(dnsBuffer, dnsClientIdentifier); + }), + m_flags(flags) +{ + // Initialize as signaled, as there are no requests yet + m_allRequestsFinished.SetEvent(); + + // Read external interface constraint regkey + const auto lxssKey = windows::common::registry::OpenLxssMachineKey(KEY_READ); + m_externalInterfaceConstraintName = + windows::common::registry::ReadString(lxssKey.get(), nullptr, c_interfaceConstraintKey, L""); + + if (!m_externalInterfaceConstraintName.empty()) + { + ResolveExternalInterfaceConstraintIndex(); + + WSL_LOG( + "DnsResolver::DnsResolver", + TraceLoggingValue(m_externalInterfaceConstraintName.c_str(), "m_externalInterfaceConstraintName"), + TraceLoggingValue(m_externalInterfaceConstraintIndex, "m_externalInterfaceConstraintIndex")); + + // Register for interface change notifications. Notifications are used to determine if the external interface constraint setting is applicable. + THROW_IF_WIN32_ERROR(NotifyIpInterfaceChange(AF_UNSPEC, &DnsResolver::InterfaceChangeCallback, this, FALSE, &m_interfaceNotificationHandle)); + } +} + +DnsResolver::~DnsResolver() noexcept +{ + Stop(); +} + +void DnsResolver::GenerateTelemetry() noexcept +try +{ + // Find the 3 most common DNS API failures + uint32_t mostCommonDnsStatusError = 0; + uint32_t mostCommonDnsStatusErrorCount = 0; + uint32_t secondCommonDnsStatusError = 0; + uint32_t secondCommonDnsStatusErrorCount = 0; + uint32_t thirdCommonDnsStatusError = 0; + uint32_t thirdCommonDnsStatusErrorCount = 0; + + std::vector> failures(m_dnsApiFailures.size()); + std::copy(m_dnsApiFailures.begin(), m_dnsApiFailures.end(), failures.begin()); + + // Sort in descending order based on failure count + std::sort(failures.begin(), failures.end(), [](const auto& lhs, const auto& rhs) { return lhs.second > rhs.second; }); + + if (failures.size() >= 1) + { + mostCommonDnsStatusError = failures[0].first; + mostCommonDnsStatusErrorCount = failures[0].second; + } + if (failures.size() >= 2) + { + secondCommonDnsStatusError = failures[1].first; + secondCommonDnsStatusErrorCount = failures[1].second; + } + if (failures.size() >= 3) + { + thirdCommonDnsStatusError = failures[2].first; + thirdCommonDnsStatusErrorCount = failures[2].second; + } + + // Add telemetry with DNS tunneling statistics, before shutting down + WSL_LOG( + "DnsTunnelingStatistics", + TraceLoggingValue(m_totalUdpQueries.load(), "totalUdpQueries"), + TraceLoggingValue(m_successfulUdpQueries.load(), "successfulUdpQueries"), + TraceLoggingValue(m_totalTcpQueries.load(), "totalTcpQueries"), + TraceLoggingValue(m_successfulTcpQueries.load(), "successfulTcpQueries"), + TraceLoggingValue(m_queriesWithNullResult.load(), "queriesWithNullResult"), + TraceLoggingValue(m_failedDnsQueryRawCalls.load(), "FailedDnsQueryRawCalls"), + TraceLoggingValue(m_dnsApiFailures.size(), "totalDnsStatusErrorInstances"), + TraceLoggingValue(mostCommonDnsStatusError, "mostCommonDnsStatusError"), + TraceLoggingValue(mostCommonDnsStatusErrorCount, "mostCommonDnsStatusErrorCount"), + TraceLoggingValue(secondCommonDnsStatusError, "secondCommonDnsStatusError"), + TraceLoggingValue(secondCommonDnsStatusErrorCount, "secondCommonDnsStatusErrorCount"), + TraceLoggingValue(thirdCommonDnsStatusError, "thirdCommonDnsStatusError"), + TraceLoggingValue(thirdCommonDnsStatusErrorCount, "thirdCommonDnsStatusErrorCount")); +} +CATCH_LOG() + +void DnsResolver::Stop() noexcept +try +{ + WSL_LOG("DnsResolver::Stop"); + + // Scoped m_dnsLock + { + const std::lock_guard lock(m_dnsLock); + + m_stopped = true; + + // Cancel existing requests. Cancel is complete when DnsQueryRawCallback is + // invoked with status == ERROR_CANCELLED + // N.B. Cancelling can end up calling the DnsQueryRawCallback directly on this same thread. i.e., while this + // lock is held. Which is fine because m_dnsLock is a recursive mutex. + // N.B. Cancelling a query will synchronously remove the query from m_dnsRequests, which invalidates iterators. + + std::vector cancelHandles; + cancelHandles.reserve(m_dnsRequests.size()); + + for (auto& [_, context] : m_dnsRequests) + { + cancelHandles.emplace_back(&context->m_cancelHandle); + } + + for (const auto e : cancelHandles) + { + LOG_IF_WIN32_ERROR(s_dnsCancelQueryRaw.value()(e)); + } + } + + // Wait for all requests to complete. At this point no new requests can be started since the object is stopped. + // We are only waiting for existing requests to finish. + m_allRequestsFinished.wait(); + + // Stop the response queue first as it can make calls in m_dnsChannel + m_dnsResponseQueue.cancel(); + + m_dnsChannel.Stop(); + + // Stop interface change notifications + m_interfaceNotificationHandle.reset(); + + GenerateTelemetry(); +} +CATCH_LOG() + +void DnsResolver::ProcessDnsRequest(const gsl::span dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept +try +{ + const std::lock_guard lock(m_dnsLock); + if (m_stopped) + { + return; + } + + WSL_LOG_DEBUG( + "DnsResolver::ProcessDnsRequest - received new DNS request", + TraceLoggingValue(dnsBuffer.size(), "DNS buffer size"), + TraceLoggingValue(dnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"), + TraceLoggingValue(dnsClientIdentifier.DnsClientId, "DNS client id"), + TraceLoggingValue(!m_externalInterfaceConstraintName.empty(), "Is ExternalInterfaceConstraint configured"), + TraceLoggingValue(m_externalInterfaceConstraintIndex, "m_externalInterfaceConstraintIndex")); + + // If the external interface constraint is configured but it is *not* present/up, WSL should be net-blind, so we avoid making DNS requests. + if (!m_externalInterfaceConstraintName.empty() && m_externalInterfaceConstraintIndex == 0) + { + return; + } + + dnsClientIdentifier.Protocol == IPPROTO_UDP ? m_totalUdpQueries++ : m_totalTcpQueries++; + + // Get next request id. If value reaches UINT_MAX + 1 it will be automatically reset to 0 + const auto requestId = m_currentRequestId++; + + // Create the DNS request context + auto context = std::make_unique( + requestId, dnsClientIdentifier, [this](_Inout_ DnsResolver::DnsQueryContext* context, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) { + HandleDnsQueryCompletion(context, queryResults); + }); + + auto [it, _] = m_dnsRequests.emplace(requestId, std::move(context)); + const auto localContext = it->second.get(); + + auto removeContextOnError = wil::scope_exit([&] { WI_VERIFY(m_dnsRequests.erase(requestId) == 1); }); + + // Fill DNS request structure + DNS_QUERY_RAW_REQUEST request{}; + + request.version = DNS_QUERY_RAW_REQUEST_VERSION1; + request.resultsVersion = DNS_QUERY_RAW_RESULTS_VERSION1; + request.dnsQueryRawSize = static_cast(dnsBuffer.size()); + request.dnsQueryRaw = (PBYTE)dnsBuffer.data(); + request.protocol = (dnsClientIdentifier.Protocol == IPPROTO_TCP) ? DNS_PROTOCOL_TCP : DNS_PROTOCOL_UDP; + request.queryCompletionCallback = DnsResolver::DnsQueryRawCallback; + request.queryContext = localContext; + // Only unicast UDP & TCP queries are tunneled. Pass this flag to tell Windows DNS client to *not* resolve using multicast. + request.queryOptions |= DNS_QUERY_NO_MULTICAST; + + // In a DNS request from Linux there might be DNS records that Windows DNS client does not know how to parse. + // By default in this case Windows will fail the request. When the flag is enabled, Windows will extract the + // question from the DNS request and attempt to resolve it, ignoring the unknown records. + if (WI_IsFlagSet(m_flags, DnsResolverFlags::BestEffortDnsParsing)) + { + request.queryRawOptions |= DNS_QUERY_RAW_OPTION_BEST_EFFORT_PARSE; + } + + // If the external interface constraint is configured and present on the host, only send DNS requests on that interface. + if (m_externalInterfaceConstraintIndex != 0) + { + request.interfaceIndex = m_externalInterfaceConstraintIndex; + } + + // Start the DNS request + // N.B. All DNS requests will bypass the Windows DNS cache + const auto result = s_dnsQueryRaw.value()(&request, &localContext->m_cancelHandle); + if (result != DNS_REQUEST_PENDING) + { + m_failedDnsQueryRawCalls++; + + WSL_LOG( + "ProcessDnsRequestFailed", + TraceLoggingValue(requestId, "requestId"), + TraceLoggingValue(result, "result"), + TraceLoggingValue("DnsQueryRaw", "executionStep")); + return; + } + + removeContextOnError.release(); + + m_allRequestsFinished.ResetEvent(); +} +CATCH_LOG() + +void DnsResolver::HandleDnsQueryCompletion(_Inout_ DnsResolver::DnsQueryContext* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept +try +{ + // Always free the query result structure + const auto freeQueryResults = wil::scope_exit([&] { + if (queryResults != nullptr) + { + s_dnsQueryRawResultFree.value()(queryResults); + } + }); + + const std::lock_guard lock(m_dnsLock); + + if (queryResults != nullptr) + { + WSL_LOG( + "DnsResolver::HandleDnsQueryCompletion", + TraceLoggingValue(queryContext->m_id, "queryContext->m_id"), + TraceLoggingValue(queryResults->queryStatus, "queryResults->queryStatus"), + TraceLoggingValue(queryResults->queryRawResponse != nullptr, "validResponse")); + + // Note: The response may be valid even if queryResults->queryStatus is not 0, for example when the DNS server returns a negative response. + if (queryResults->queryRawResponse != nullptr) + { + queryContext->m_dnsClientIdentifier.Protocol == IPPROTO_UDP ? m_successfulUdpQueries++ : m_successfulTcpQueries++; + } + // the Windows DNS API returned failure + else + { + if (m_dnsApiFailures.find(queryResults->queryStatus) == m_dnsApiFailures.end()) + { + m_dnsApiFailures[queryResults->queryStatus] = 1; + } + else + { + m_dnsApiFailures[queryResults->queryStatus]++; + } + } + } + else + { + WSL_LOG( + "DnsResolver::HandleDnsQueryCompletion - received a NULL queryResults", + TraceLoggingValue(queryContext->m_id, "queryContext->m_id")); + m_queriesWithNullResult++; + } + + if (!m_stopped && queryResults != nullptr && queryResults->queryRawResponse != nullptr) + { + // Copy DNS response buffer + std::vector dnsResponse(queryResults->queryRawResponseSize); + CopyMemory(dnsResponse.data(), queryResults->queryRawResponse, queryResults->queryRawResponseSize); + + WSL_LOG_DEBUG( + "DnsResolver::HandleDnsQueryCompletion - received new DNS response", + TraceLoggingValue(dnsResponse.size(), "DNS buffer size"), + TraceLoggingValue(queryContext->m_dnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"), + TraceLoggingValue(queryContext->m_dnsClientIdentifier.DnsClientId, "DNS client id")); + + // Schedule the DNS response to be sent to Linux + m_dnsResponseQueue.submit([this, dnsResponse = std::move(dnsResponse), dnsClientIdentifier = queryContext->m_dnsClientIdentifier]() mutable { + m_dnsChannel.SendDnsMessage(gsl::make_span(dnsResponse), dnsClientIdentifier); + }); + } + + // Stop tracking this DNS request and delete the request context + WI_VERIFY(m_dnsRequests.erase(queryContext->m_id) == 1); + + // Set event if all tracked requests have finished + if (m_dnsRequests.empty()) + { + m_allRequestsFinished.SetEvent(); + } +} +CATCH_LOG() + +void DnsResolver::ResolveExternalInterfaceConstraintIndex() noexcept +try +{ + const std::lock_guard lock(m_dnsLock); + if (m_stopped) + { + return; + } + + if (m_externalInterfaceConstraintName.empty()) + { + return; + } + + NET_LUID interfaceLuid{}; + ULONG interfaceIndex = 0; + + // Update the interface index on every exit path. + // The calls below to convert interface name to index will fail if the interface does not exist anymore, + // in which case we still need to reset the interface index to its default value of 0. + const auto setInterfaceIndex = wil::scope_exit([&] { + if (interfaceIndex != m_externalInterfaceConstraintIndex) + { + WSL_LOG( + "DnsResolver::ResolveExternalInterfaceConstraintIndex - setting m_externalInterfaceConstraintIndex to new value", + TraceLoggingValue(m_externalInterfaceConstraintIndex, "old interface index"), + TraceLoggingValue(interfaceIndex, "new interface index")); + + m_externalInterfaceConstraintIndex = interfaceIndex; + } + }); + + // If external interface constraint is configured, query to see if it's present on the host. + auto errorCode = ConvertInterfaceAliasToLuid(m_externalInterfaceConstraintName.c_str(), &interfaceLuid); + if (FAILED_WIN32_LOG(errorCode)) + { + return; + } + + errorCode = ConvertInterfaceLuidToIndex(&interfaceLuid, reinterpret_cast(&interfaceIndex)); + if (FAILED_WIN32_LOG(errorCode)) + { + return; + } +} +CATCH_LOG() + +VOID CALLBACK DnsResolver::DnsQueryRawCallback(_In_ VOID* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept +try +{ + assert(queryContext != nullptr); + + const auto context = static_cast(queryContext); + + // Call into DnsResolver parent object to process the query result + context->m_handleQueryCompletion(context, queryResults); +} +CATCH_LOG() + +VOID CALLBACK DnsResolver::InterfaceChangeCallback(_In_ PVOID context, PMIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE) noexcept +try +{ + const auto dnsResolver = static_cast(context); + dnsResolver->ResolveExternalInterfaceConstraintIndex(); +} +CATCH_LOG() diff --git a/src/windows/service/exe/DnsResolver.h b/src/windows/common/DnsResolver.h similarity index 97% rename from src/windows/service/exe/DnsResolver.h rename to src/windows/common/DnsResolver.h index a186737..d3ea8f2 100644 --- a/src/windows/service/exe/DnsResolver.h +++ b/src/windows/common/DnsResolver.h @@ -1,141 +1,141 @@ -// Copyright (C) Microsoft Corporation. All rights reserved. - -#pragma once - -#include "DnsTunnelingChannel.h" -#include "WslCoreMessageQueue.h" -#include "WslCoreNetworkingSupport.h" - -namespace wsl::core::networking { - -enum class DnsResolverFlags -{ - None = 0x0, - BestEffortDnsParsing = 0x1 -}; -DEFINE_ENUM_FLAG_OPERATORS(DnsResolverFlags); - -class DnsResolver -{ -public: - DnsResolver(wil::unique_socket&& dnsHvsocket, DnsResolverFlags flags); - ~DnsResolver() noexcept; - - DnsResolver(const DnsResolver&) = delete; - DnsResolver& operator=(const DnsResolver&) = delete; - - DnsResolver(DnsResolver&&) = delete; - DnsResolver& operator=(DnsResolver&&) = delete; - - void Stop() noexcept; - - static HRESULT LoadDnsResolverMethods() noexcept; - -private: - struct DnsQueryContext - { - // Struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request. - LX_GNS_DNS_CLIENT_IDENTIFIER m_dnsClientIdentifier{}; - - // Handle used to cancel the request. - DNS_QUERY_RAW_CANCEL m_cancelHandle{}; - - // Unique query id. - uint32_t m_id{}; - - // Callback to the parent object to notify about the DNS query completion. - std::function m_handleQueryCompletion; - - DnsQueryContext( - uint32_t id, - const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier, - std::function&& handleQueryCompletion) : - m_dnsClientIdentifier(dnsClientIdentifier), m_id(id), m_handleQueryCompletion(std::move(handleQueryCompletion)) - { - } - - ~DnsQueryContext() noexcept = default; - - DnsQueryContext(const DnsQueryContext&) = delete; - DnsQueryContext& operator=(const DnsQueryContext&) = delete; - DnsQueryContext(DnsQueryContext&&) = delete; - DnsQueryContext& operator=(DnsQueryContext&&) = delete; - }; - - void GenerateTelemetry() noexcept; - - // Process DNS request received from Linux. - // - // Arguments: - // dnsBuffer - buffer containing DNS request. - // dnsClientIdentifier - struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request. - void ProcessDnsRequest(const gsl::span dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept; - - // Handle completion of DNS query. - // - // Arguments: - // dnsQueryContext - context structure for the DNS request. - // queryResults - structure containing result of the DNS request. - void HandleDnsQueryCompletion(_Inout_ DnsQueryContext* dnsQueryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept; - - void ResolveExternalInterfaceConstraintIndex() noexcept; - - // Callback that will be invoked by the DNS API whenever a request finishes. The callback is invoked on success, error or when request is cancelled. - // - // Arguments: - // queryContext - pointer to context structure, will be a structure of type DnsQueryContext. - // queryResults - pointer to structure containing the result of the DNS request. - static VOID CALLBACK DnsQueryRawCallback(_In_ VOID* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept; - - static VOID CALLBACK InterfaceChangeCallback(_In_ PVOID context, PMIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE) noexcept; - - std::recursive_mutex m_dnsLock; - - // Flag used when shutting down the object. - _Guarded_by_(m_dnsLock) bool m_stopped = false; - - // Hvsocket channel used to exchange DNS messages with Linux. - DnsTunnelingChannel m_dnsChannel; - - // Queue used to send DNS responses to Linux. - WslCoreMessageQueue m_dnsResponseQueue; - - // Unique id that is incremented for each request. In case the value reaches MAX_UINT and is reset to 0, - // it's assumed previous requests with id's 0, 1, ... finished in the meantime and the id can be reused. - _Guarded_by_(m_dnsLock) uint32_t m_currentRequestId = 0; - - // Mapping request id to the request context structure. - _Guarded_by_(m_dnsLock) std::unordered_map> m_dnsRequests {}; - - // Event that is set when all tracked DNS requests have completed. - wil::unique_event m_allRequestsFinished{wil::EventOptions::ManualReset}; - - // Used for handling of external interface constraint setting. - unique_notify_handle m_interfaceNotificationHandle{}; - - std::wstring m_externalInterfaceConstraintName; - _Guarded_by_(m_dnsLock) ULONG m_externalInterfaceConstraintIndex = 0; - - const DnsResolverFlags m_flags{}; - - // Statistics used for telemetry. - std::atomic m_totalUdpQueries{0}; - std::atomic m_successfulUdpQueries{0}; - std::atomic m_totalTcpQueries{0}; - std::atomic m_successfulTcpQueries{0}; - std::atomic m_queriesWithNullResult{0}; - std::atomic m_failedDnsQueryRawCalls{0}; - - _Guarded_by_(m_dnsLock) std::map m_dnsApiFailures; - - // Dynamic functions used for calling the DNS APIs. - - // Function to start a raw DNS request. - static std::optional> s_dnsQueryRaw; - // Function to cancel a raw DNS request. - static std::optional> s_dnsCancelQueryRaw; - // Function to free the structure containing the result of a raw DNS request. - static std::optional> s_dnsQueryRawResultFree; -}; - -} // namespace wsl::core::networking +// Copyright (C) Microsoft Corporation. All rights reserved. + +#pragma once + +#include "DnsTunnelingChannel.h" +#include "WslCoreMessageQueue.h" +#include "WslCoreNetworkingSupport.h" + +namespace wsl::core::networking { + +enum class DnsResolverFlags +{ + None = 0x0, + BestEffortDnsParsing = 0x1 +}; +DEFINE_ENUM_FLAG_OPERATORS(DnsResolverFlags); + +class DnsResolver +{ +public: + DnsResolver(wil::unique_socket&& dnsHvsocket, DnsResolverFlags flags); + ~DnsResolver() noexcept; + + DnsResolver(const DnsResolver&) = delete; + DnsResolver& operator=(const DnsResolver&) = delete; + + DnsResolver(DnsResolver&&) = delete; + DnsResolver& operator=(DnsResolver&&) = delete; + + void Stop() noexcept; + + static HRESULT LoadDnsResolverMethods() noexcept; + +private: + struct DnsQueryContext + { + // Struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request. + LX_GNS_DNS_CLIENT_IDENTIFIER m_dnsClientIdentifier{}; + + // Handle used to cancel the request. + DNS_QUERY_RAW_CANCEL m_cancelHandle{}; + + // Unique query id. + uint32_t m_id{}; + + // Callback to the parent object to notify about the DNS query completion. + std::function m_handleQueryCompletion; + + DnsQueryContext( + uint32_t id, + const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier, + std::function&& handleQueryCompletion) : + m_dnsClientIdentifier(dnsClientIdentifier), m_id(id), m_handleQueryCompletion(std::move(handleQueryCompletion)) + { + } + + ~DnsQueryContext() noexcept = default; + + DnsQueryContext(const DnsQueryContext&) = delete; + DnsQueryContext& operator=(const DnsQueryContext&) = delete; + DnsQueryContext(DnsQueryContext&&) = delete; + DnsQueryContext& operator=(DnsQueryContext&&) = delete; + }; + + void GenerateTelemetry() noexcept; + + // Process DNS request received from Linux. + // + // Arguments: + // dnsBuffer - buffer containing DNS request. + // dnsClientIdentifier - struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request. + void ProcessDnsRequest(const gsl::span dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept; + + // Handle completion of DNS query. + // + // Arguments: + // dnsQueryContext - context structure for the DNS request. + // queryResults - structure containing result of the DNS request. + void HandleDnsQueryCompletion(_Inout_ DnsQueryContext* dnsQueryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept; + + void ResolveExternalInterfaceConstraintIndex() noexcept; + + // Callback that will be invoked by the DNS API whenever a request finishes. The callback is invoked on success, error or when request is cancelled. + // + // Arguments: + // queryContext - pointer to context structure, will be a structure of type DnsQueryContext. + // queryResults - pointer to structure containing the result of the DNS request. + static VOID CALLBACK DnsQueryRawCallback(_In_ VOID* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept; + + static VOID CALLBACK InterfaceChangeCallback(_In_ PVOID context, PMIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE) noexcept; + + std::recursive_mutex m_dnsLock; + + // Flag used when shutting down the object. + _Guarded_by_(m_dnsLock) bool m_stopped = false; + + // Hvsocket channel used to exchange DNS messages with Linux. + DnsTunnelingChannel m_dnsChannel; + + // Queue used to send DNS responses to Linux. + WslCoreMessageQueue m_dnsResponseQueue; + + // Unique id that is incremented for each request. In case the value reaches MAX_UINT and is reset to 0, + // it's assumed previous requests with id's 0, 1, ... finished in the meantime and the id can be reused. + _Guarded_by_(m_dnsLock) uint32_t m_currentRequestId = 0; + + // Mapping request id to the request context structure. + _Guarded_by_(m_dnsLock) std::unordered_map> m_dnsRequests {}; + + // Event that is set when all tracked DNS requests have completed. + wil::unique_event m_allRequestsFinished{wil::EventOptions::ManualReset}; + + // Used for handling of external interface constraint setting. + unique_notify_handle m_interfaceNotificationHandle{}; + + std::wstring m_externalInterfaceConstraintName; + _Guarded_by_(m_dnsLock) ULONG m_externalInterfaceConstraintIndex = 0; + + const DnsResolverFlags m_flags{}; + + // Statistics used for telemetry. + std::atomic m_totalUdpQueries{0}; + std::atomic m_successfulUdpQueries{0}; + std::atomic m_totalTcpQueries{0}; + std::atomic m_successfulTcpQueries{0}; + std::atomic m_queriesWithNullResult{0}; + std::atomic m_failedDnsQueryRawCalls{0}; + + _Guarded_by_(m_dnsLock) std::map m_dnsApiFailures; + + // Dynamic functions used for calling the DNS APIs. + + // Function to start a raw DNS request. + static std::optional> s_dnsQueryRaw; + // Function to cancel a raw DNS request. + static std::optional> s_dnsCancelQueryRaw; + // Function to free the structure containing the result of a raw DNS request. + static std::optional> s_dnsQueryRawResultFree; +}; + +} // namespace wsl::core::networking diff --git a/src/windows/service/exe/DnsTunnelingChannel.cpp b/src/windows/common/DnsTunnelingChannel.cpp similarity index 97% rename from src/windows/service/exe/DnsTunnelingChannel.cpp rename to src/windows/common/DnsTunnelingChannel.cpp index 6d34a2a..b80bf8f 100644 --- a/src/windows/service/exe/DnsTunnelingChannel.cpp +++ b/src/windows/common/DnsTunnelingChannel.cpp @@ -1,115 +1,115 @@ -// Copyright (C) Microsoft Corporation. All rights reserved. - -#include "precomp.h" -#include "DnsTunnelingChannel.h" - -using wsl::core::networking::DnsTunnelingChannel; - -DnsTunnelingChannel::DnsTunnelingChannel(wil::unique_socket&& socket, DnsTunnelingCallback&& reportDnsRequest) : - m_channel{std::move(socket), "DnsTunneling", m_stopEvent.get()}, m_reportDnsRequest(std::move(reportDnsRequest)) -{ - WSL_LOG("DnsTunnelingChannel::DnsTunnelingChannel [Windows]", TraceLoggingValue(m_channel.Socket(), "socket")); - - // Start thread waiting for incoming messages from Linux side - m_receiveWorkerThread = std::thread([this]() { ReceiveLoop(); }); -} - -DnsTunnelingChannel::~DnsTunnelingChannel() -{ - Stop(); -} - -void DnsTunnelingChannel::SendDnsMessage(const gsl::span dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept -try -{ - // Exit if channel was stopped - if (m_stopEvent.is_signaled()) - { - return; - } - - wsl::shared::MessageWriter message(LxGnsMessageDnsTunneling); - message->DnsClientIdentifier = dnsClientIdentifier; - message.WriteSpan(dnsBuffer); - - m_channel.SendMessage(message.Span()); -} -CATCH_LOG() - -void DnsTunnelingChannel::ReceiveLoop() noexcept -{ - std::vector receiveBuffer; - - for (;;) - { - try - { - if (m_stopEvent.is_signaled()) - { - return; - } - - WSL_LOG_DEBUG("DnsTunnelingChannel::ReceiveLoop [Windows] - waiting for next message from Linux"); - - // Read next message. wsl::shared::socket::RecvMessage() first reads the message header, then uses it to determine the - // total size of the message and read the rest of the message, resizing the buffer if needed. - auto [message, span] = m_channel.ReceiveMessageOrClosed(); - if (message == nullptr) - { - WSL_LOG("DnsTunnelingChannel::ReceiveLoop [Windows] - failed to read message"); - return; - } - - // Get the message type from the message header - switch (message->MessageType) - { - case LxGnsMessageDnsTunneling: - { - // Cast message to a LX_GNS_DNS_TUNNELING_MESSAGE struct - auto* dnsMessage = gslhelpers::try_get_struct(span); - if (!dnsMessage) - { - WSL_LOG( - "DnsTunnelingChannel::ReceiveLoop [Windows] - failed to convert message to LX_GNS_DNS_TUNNELING_MESSAGE"); - return; - } - - // Extract DNS buffer from message - auto dnsBuffer = span.subspan(offsetof(LX_GNS_DNS_TUNNELING_MESSAGE, Buffer)); - - WSL_LOG_DEBUG( - "DnsTunnelingChannel::ReceiveLoop [Windows] - received DNS message", - TraceLoggingValue(dnsBuffer.size(), "DNS buffer size"), - TraceLoggingValue(dnsMessage->DnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"), - TraceLoggingValue(dnsMessage->DnsClientIdentifier.DnsClientId, "DNS client id")); - - // Invoke callback to notify about the new DNS request - m_reportDnsRequest(dnsBuffer, dnsMessage->DnsClientIdentifier); - - break; - } - - default: - { - THROW_HR_MSG(E_UNEXPECTED, "Unexpected LX_MESSAGE_TYPE : %i", message->MessageType); - } - } - } - CATCH_LOG() - } -} - -void DnsTunnelingChannel::Stop() noexcept -try -{ - WSL_LOG("DnsTunnelingChannel::Stop [Windows]"); - - m_stopEvent.SetEvent(); - - // Stop receive loop - if (m_receiveWorkerThread.joinable()) - { - m_receiveWorkerThread.join(); - } -} -CATCH_LOG() +// Copyright (C) Microsoft Corporation. All rights reserved. + +#include "precomp.h" +#include "DnsTunnelingChannel.h" + +using wsl::core::networking::DnsTunnelingChannel; + +DnsTunnelingChannel::DnsTunnelingChannel(wil::unique_socket&& socket, DnsTunnelingCallback&& reportDnsRequest) : + m_channel{std::move(socket), "DnsTunneling", m_stopEvent.get()}, m_reportDnsRequest(std::move(reportDnsRequest)) +{ + WSL_LOG("DnsTunnelingChannel::DnsTunnelingChannel [Windows]", TraceLoggingValue(m_channel.Socket(), "socket")); + + // Start thread waiting for incoming messages from Linux side + m_receiveWorkerThread = std::thread([this]() { ReceiveLoop(); }); +} + +DnsTunnelingChannel::~DnsTunnelingChannel() +{ + Stop(); +} + +void DnsTunnelingChannel::SendDnsMessage(const gsl::span dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept +try +{ + // Exit if channel was stopped + if (m_stopEvent.is_signaled()) + { + return; + } + + wsl::shared::MessageWriter message(LxGnsMessageDnsTunneling); + message->DnsClientIdentifier = dnsClientIdentifier; + message.WriteSpan(dnsBuffer); + + m_channel.SendMessage(message.Span()); +} +CATCH_LOG() + +void DnsTunnelingChannel::ReceiveLoop() noexcept +{ + std::vector receiveBuffer; + + for (;;) + { + try + { + if (m_stopEvent.is_signaled()) + { + return; + } + + WSL_LOG_DEBUG("DnsTunnelingChannel::ReceiveLoop [Windows] - waiting for next message from Linux"); + + // Read next message. wsl::shared::socket::RecvMessage() first reads the message header, then uses it to determine the + // total size of the message and read the rest of the message, resizing the buffer if needed. + auto [message, span] = m_channel.ReceiveMessageOrClosed(); + if (message == nullptr) + { + WSL_LOG("DnsTunnelingChannel::ReceiveLoop [Windows] - failed to read message"); + return; + } + + // Get the message type from the message header + switch (message->MessageType) + { + case LxGnsMessageDnsTunneling: + { + // Cast message to a LX_GNS_DNS_TUNNELING_MESSAGE struct + auto* dnsMessage = gslhelpers::try_get_struct(span); + if (!dnsMessage) + { + WSL_LOG( + "DnsTunnelingChannel::ReceiveLoop [Windows] - failed to convert message to LX_GNS_DNS_TUNNELING_MESSAGE"); + return; + } + + // Extract DNS buffer from message + auto dnsBuffer = span.subspan(offsetof(LX_GNS_DNS_TUNNELING_MESSAGE, Buffer)); + + WSL_LOG_DEBUG( + "DnsTunnelingChannel::ReceiveLoop [Windows] - received DNS message", + TraceLoggingValue(dnsBuffer.size(), "DNS buffer size"), + TraceLoggingValue(dnsMessage->DnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"), + TraceLoggingValue(dnsMessage->DnsClientIdentifier.DnsClientId, "DNS client id")); + + // Invoke callback to notify about the new DNS request + m_reportDnsRequest(dnsBuffer, dnsMessage->DnsClientIdentifier); + + break; + } + + default: + { + THROW_HR_MSG(E_UNEXPECTED, "Unexpected LX_MESSAGE_TYPE : %i", message->MessageType); + } + } + } + CATCH_LOG() + } +} + +void DnsTunnelingChannel::Stop() noexcept +try +{ + WSL_LOG("DnsTunnelingChannel::Stop [Windows]"); + + m_stopEvent.SetEvent(); + + // Stop receive loop + if (m_receiveWorkerThread.joinable()) + { + m_receiveWorkerThread.join(); + } +} +CATCH_LOG() diff --git a/src/windows/service/exe/DnsTunnelingChannel.h b/src/windows/common/DnsTunnelingChannel.h similarity index 97% rename from src/windows/service/exe/DnsTunnelingChannel.h rename to src/windows/common/DnsTunnelingChannel.h index 0214cf6..e05e198 100644 --- a/src/windows/service/exe/DnsTunnelingChannel.h +++ b/src/windows/common/DnsTunnelingChannel.h @@ -1,50 +1,50 @@ -// Copyright (C) Microsoft Corporation. All rights reserved. - -#pragma once - -#include -#include "lxinitshared.h" -#include "SocketChannel.h" - -namespace wsl::core::networking { - -using DnsTunnelingCallback = std::function, const LX_GNS_DNS_CLIENT_IDENTIFIER&)>; - -class DnsTunnelingChannel -{ -public: - DnsTunnelingChannel(wil::unique_socket&& socket, DnsTunnelingCallback&& reportDnsRequest); - ~DnsTunnelingChannel(); - - DnsTunnelingChannel(const DnsTunnelingChannel&) = delete; - DnsTunnelingChannel& operator=(const DnsTunnelingChannel&) = delete; - - DnsTunnelingChannel(DnsTunnelingChannel&&) = delete; - DnsTunnelingChannel& operator=(DnsTunnelingChannel&&) = delete; - - // Construct and send a LX_GNS_DNS_TUNNELING_MESSAGE message on the channel. - // Note: Callers are responsible for sequencing calls to this method. - // - // Arguments: - // dnsBuffer - buffer containing DNS response. - // dnsClientIdentifier - struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request. - void SendDnsMessage(const gsl::span dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept; - - // Stop the channel. - void Stop() noexcept; - -private: - // Wait for messages on the channel from Linux side. - void ReceiveLoop() noexcept; - - wil::unique_event m_stopEvent{wil::EventOptions::ManualReset}; - - wsl::shared::SocketChannel m_channel; - - std::thread m_receiveWorkerThread; - - // Callback used to notify when there is a new DNS request message on the channel. - DnsTunnelingCallback m_reportDnsRequest; -}; - -} // namespace wsl::core::networking +// Copyright (C) Microsoft Corporation. All rights reserved. + +#pragma once + +#include +#include "lxinitshared.h" +#include "SocketChannel.h" + +namespace wsl::core::networking { + +using DnsTunnelingCallback = std::function, const LX_GNS_DNS_CLIENT_IDENTIFIER&)>; + +class DnsTunnelingChannel +{ +public: + DnsTunnelingChannel(wil::unique_socket&& socket, DnsTunnelingCallback&& reportDnsRequest); + ~DnsTunnelingChannel(); + + DnsTunnelingChannel(const DnsTunnelingChannel&) = delete; + DnsTunnelingChannel& operator=(const DnsTunnelingChannel&) = delete; + + DnsTunnelingChannel(DnsTunnelingChannel&&) = delete; + DnsTunnelingChannel& operator=(DnsTunnelingChannel&&) = delete; + + // Construct and send a LX_GNS_DNS_TUNNELING_MESSAGE message on the channel. + // Note: Callers are responsible for sequencing calls to this method. + // + // Arguments: + // dnsBuffer - buffer containing DNS response. + // dnsClientIdentifier - struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request. + void SendDnsMessage(const gsl::span dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept; + + // Stop the channel. + void Stop() noexcept; + +private: + // Wait for messages on the channel from Linux side. + void ReceiveLoop() noexcept; + + wil::unique_event m_stopEvent{wil::EventOptions::ManualReset}; + + wsl::shared::SocketChannel m_channel; + + std::thread m_receiveWorkerThread; + + // Callback used to notify when there is a new DNS request message on the channel. + DnsTunnelingCallback m_reportDnsRequest; +}; + +} // namespace wsl::core::networking diff --git a/src/windows/service/exe/GnsChannel.cpp b/src/windows/common/GnsChannel.cpp similarity index 100% rename from src/windows/service/exe/GnsChannel.cpp rename to src/windows/common/GnsChannel.cpp diff --git a/src/windows/service/exe/GnsChannel.h b/src/windows/common/GnsChannel.h similarity index 100% rename from src/windows/service/exe/GnsChannel.h rename to src/windows/common/GnsChannel.h diff --git a/src/windows/service/exe/INetworkingEngine.h b/src/windows/common/INetworkingEngine.h similarity index 100% rename from src/windows/service/exe/INetworkingEngine.h rename to src/windows/common/INetworkingEngine.h diff --git a/src/windows/service/exe/LxssSecurity.cpp b/src/windows/common/LxssSecurity.cpp similarity index 100% rename from src/windows/service/exe/LxssSecurity.cpp rename to src/windows/common/LxssSecurity.cpp diff --git a/src/windows/service/exe/LxssSecurity.h b/src/windows/common/LxssSecurity.h similarity index 100% rename from src/windows/service/exe/LxssSecurity.h rename to src/windows/common/LxssSecurity.h diff --git a/src/windows/service/exe/NatNetworking.cpp b/src/windows/common/NatNetworking.cpp similarity index 99% rename from src/windows/service/exe/NatNetworking.cpp rename to src/windows/common/NatNetworking.cpp index b7dd080..dbad2a2 100644 --- a/src/windows/service/exe/NatNetworking.cpp +++ b/src/windows/common/NatNetworking.cpp @@ -6,7 +6,6 @@ #include "WslCoreHostDnsInfo.h" #include "Stringify.h" #include "WslCoreFirewallSupport.h" -#include "WslCoreVm.h" #include "hcs.hpp" using namespace wsl::core::networking; @@ -672,7 +671,7 @@ wsl::windows::common::hcs::unique_hcn_network NatNetworking::CreateNetwork(wsl:: wil::ResultFromException(WI_DIAGNOSTICS_INFO, [&] { try { - wsl::core::networking::ConfigureHyperVFirewall(config.FirewallConfig, c_vmOwner); + wsl::core::networking::ConfigureHyperVFirewall(config.FirewallConfig, wsl::windows::common::wslutil::c_vmOwner); natNetwork = CreateNetworkInternal(config); } catch (...) diff --git a/src/windows/service/exe/NatNetworking.h b/src/windows/common/NatNetworking.h similarity index 100% rename from src/windows/service/exe/NatNetworking.h rename to src/windows/common/NatNetworking.h diff --git a/src/windows/service/exe/RingBuffer.cpp b/src/windows/common/RingBuffer.cpp similarity index 96% rename from src/windows/service/exe/RingBuffer.cpp rename to src/windows/common/RingBuffer.cpp index b13b058..efce04e 100644 --- a/src/windows/service/exe/RingBuffer.cpp +++ b/src/windows/common/RingBuffer.cpp @@ -1,154 +1,154 @@ -/*++ - -Copyright (c) Microsoft. All rights reserved. - -Module Name: - - RingBuffer.cpp - -Abstract: - - This file contains definitions for the RingBuffer class. - ---*/ - -#include "precomp.h" -#include "RingBuffer.h" - -RingBuffer::RingBuffer(size_t size) : m_maxSize(size), m_offset(0) -{ - m_buffer.reserve(size); -} - -void RingBuffer::Insert(std::string_view data) -{ - auto lock = m_lock.lock_exclusive(); - auto remainingData = gsl::make_span(data.data(), data.size()); - if (remainingData.size() > m_maxSize) - { - remainingData = remainingData.subspan(remainingData.size() - m_maxSize); - } - - const auto bytesAtEnd = std::min(m_maxSize - m_offset, remainingData.size()); - if (m_offset + bytesAtEnd > m_buffer.size()) - { - m_buffer.resize(m_offset + bytesAtEnd); - WI_ASSERT(m_buffer.size() <= m_maxSize); - } - - const auto allBuffer = gsl::make_span(m_buffer); - const auto beginCopyBuffer = allBuffer.subspan(m_offset, bytesAtEnd); - copy(remainingData.subspan(0, bytesAtEnd), beginCopyBuffer); - remainingData = remainingData.subspan(bytesAtEnd); - if (!remainingData.empty()) - { - copy(remainingData, allBuffer); - m_offset = remainingData.size(); - } - else - { - m_offset += bytesAtEnd; - } -} - -std::vector RingBuffer::GetLastDelimitedStrings(char Delimiter, size_t Count) const -{ - auto lock = m_lock.lock_shared(); - auto [begin, end] = Contents(); - std::vector results; - std::optional endIndex; - for (size_t i = end.size(); i > 0; i--) - { - if (results.size() == Count) - { - break; - } - - if (Delimiter == end[i - 1]) - { - if (endIndex.has_value()) - { - results.emplace(results.begin(), &end[i], endIndex.value() - i); - endIndex.reset(); - } - else - { - endIndex = i - 1; - } - } - } - - if (results.size() == Count) - { - return results; - } - - std::string partial; - if (endIndex.has_value()) - { - partial = std::string{&end[0], endIndex.value()}; - endIndex.reset(); - } - - for (size_t i = begin.size(); i > 0; i--) - { - if (results.size() == Count) - { - break; - } - - if (Delimiter == begin[i - 1]) - { - if (!partial.empty()) - { - // The debug CRT will fastfail if begin[size] is accessed - // But in this case it's not a problem because begin.size() - i would be == 0 - std::string partial_begin{&begin.data()[i], begin.size() - i}; - results.emplace(results.begin(), partial_begin + partial); - partial.clear(); - } - else if (endIndex.has_value()) - { - results.emplace(results.begin(), &begin.data()[i], endIndex.value() - i); - endIndex.reset(); - } - else - { - endIndex = i - 1; - } - } - } - - if (results.size() < Count) - { - // May have lost some data, or this could be the very first line logged. - if (!partial.empty()) - { - results.emplace(results.begin(), partial); - } - else if (endIndex.has_value()) - { - results.emplace(results.begin(), &begin[0], endIndex.value()); - } - } - - return results; -} - -std::string RingBuffer::Get() const -{ - auto lock = m_lock.lock_shared(); - auto [begin, end] = Contents(); - std::string data; - data.reserve(begin.size() + end.size()); - data.append(begin.data(), begin.size()); - data.append(end.data(), end.size()); - return data; -} - -std::pair RingBuffer::Contents() const -{ - std::string_view beginView(m_buffer.data() + m_offset, m_buffer.size() - m_offset); - std::string_view endView(m_buffer.data(), m_offset); - return {beginView, endView}; +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + RingBuffer.cpp + +Abstract: + + This file contains definitions for the RingBuffer class. + +--*/ + +#include "precomp.h" +#include "RingBuffer.h" + +RingBuffer::RingBuffer(size_t size) : m_maxSize(size), m_offset(0) +{ + m_buffer.reserve(size); +} + +void RingBuffer::Insert(std::string_view data) +{ + auto lock = m_lock.lock_exclusive(); + auto remainingData = gsl::make_span(data.data(), data.size()); + if (remainingData.size() > m_maxSize) + { + remainingData = remainingData.subspan(remainingData.size() - m_maxSize); + } + + const auto bytesAtEnd = std::min(m_maxSize - m_offset, remainingData.size()); + if (m_offset + bytesAtEnd > m_buffer.size()) + { + m_buffer.resize(m_offset + bytesAtEnd); + WI_ASSERT(m_buffer.size() <= m_maxSize); + } + + const auto allBuffer = gsl::make_span(m_buffer); + const auto beginCopyBuffer = allBuffer.subspan(m_offset, bytesAtEnd); + copy(remainingData.subspan(0, bytesAtEnd), beginCopyBuffer); + remainingData = remainingData.subspan(bytesAtEnd); + if (!remainingData.empty()) + { + copy(remainingData, allBuffer); + m_offset = remainingData.size(); + } + else + { + m_offset += bytesAtEnd; + } +} + +std::vector RingBuffer::GetLastDelimitedStrings(char Delimiter, size_t Count) const +{ + auto lock = m_lock.lock_shared(); + auto [begin, end] = Contents(); + std::vector results; + std::optional endIndex; + for (size_t i = end.size(); i > 0; i--) + { + if (results.size() == Count) + { + break; + } + + if (Delimiter == end[i - 1]) + { + if (endIndex.has_value()) + { + results.emplace(results.begin(), &end[i], endIndex.value() - i); + endIndex.reset(); + } + else + { + endIndex = i - 1; + } + } + } + + if (results.size() == Count) + { + return results; + } + + std::string partial; + if (endIndex.has_value()) + { + partial = std::string{&end[0], endIndex.value()}; + endIndex.reset(); + } + + for (size_t i = begin.size(); i > 0; i--) + { + if (results.size() == Count) + { + break; + } + + if (Delimiter == begin[i - 1]) + { + if (!partial.empty()) + { + // The debug CRT will fastfail if begin[size] is accessed + // But in this case it's not a problem because begin.size() - i would be == 0 + std::string partial_begin{&begin.data()[i], begin.size() - i}; + results.emplace(results.begin(), partial_begin + partial); + partial.clear(); + } + else if (endIndex.has_value()) + { + results.emplace(results.begin(), &begin.data()[i], endIndex.value() - i); + endIndex.reset(); + } + else + { + endIndex = i - 1; + } + } + } + + if (results.size() < Count) + { + // May have lost some data, or this could be the very first line logged. + if (!partial.empty()) + { + results.emplace(results.begin(), partial); + } + else if (endIndex.has_value()) + { + results.emplace(results.begin(), &begin[0], endIndex.value()); + } + } + + return results; +} + +std::string RingBuffer::Get() const +{ + auto lock = m_lock.lock_shared(); + auto [begin, end] = Contents(); + std::string data; + data.reserve(begin.size() + end.size()); + data.append(begin.data(), begin.size()); + data.append(end.data(), end.size()); + return data; +} + +std::pair RingBuffer::Contents() const +{ + std::string_view beginView(m_buffer.data() + m_offset, m_buffer.size() - m_offset); + std::string_view endView(m_buffer.data(), m_offset); + return {beginView, endView}; } \ No newline at end of file diff --git a/src/windows/service/exe/RingBuffer.h b/src/windows/common/RingBuffer.h similarity index 94% rename from src/windows/service/exe/RingBuffer.h rename to src/windows/common/RingBuffer.h index f8d204d..2786ebf 100644 --- a/src/windows/service/exe/RingBuffer.h +++ b/src/windows/common/RingBuffer.h @@ -1,34 +1,34 @@ -/*++ - -Copyright (c) Microsoft. All rights reserved. - -Module Name: - - RingBuffer.h - -Abstract: - - This file contains declarations for the RingBuffer class. - ---*/ - -#pragma once - -class RingBuffer -{ -public: - RingBuffer() = delete; - RingBuffer(size_t size); - - void Insert(std::string_view data); - std::vector GetLastDelimitedStrings(char Delimiter, size_t Count) const; - std::string Get() const; - -private: - std::pair Contents() const; - - mutable wil::srwlock m_lock; - std::vector m_buffer; - size_t m_maxSize; - size_t m_offset; +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + RingBuffer.h + +Abstract: + + This file contains declarations for the RingBuffer class. + +--*/ + +#pragma once + +class RingBuffer +{ +public: + RingBuffer() = delete; + RingBuffer(size_t size); + + void Insert(std::string_view data); + std::vector GetLastDelimitedStrings(char Delimiter, size_t Count) const; + std::string Get() const; + +private: + std::pair Contents() const; + + mutable wil::srwlock m_lock; + std::vector m_buffer; + size_t m_maxSize; + size_t m_offset; }; \ No newline at end of file diff --git a/src/windows/service/exe/WslCoreFilesystem.cpp b/src/windows/common/WslCoreFilesystem.cpp similarity index 100% rename from src/windows/service/exe/WslCoreFilesystem.cpp rename to src/windows/common/WslCoreFilesystem.cpp diff --git a/src/windows/service/exe/WslCoreFilesystem.h b/src/windows/common/WslCoreFilesystem.h similarity index 100% rename from src/windows/service/exe/WslCoreFilesystem.h rename to src/windows/common/WslCoreFilesystem.h diff --git a/src/windows/service/exe/WslCoreHostDnsInfo.cpp b/src/windows/common/WslCoreHostDnsInfo.cpp similarity index 100% rename from src/windows/service/exe/WslCoreHostDnsInfo.cpp rename to src/windows/common/WslCoreHostDnsInfo.cpp diff --git a/src/windows/service/exe/WslCoreHostDnsInfo.h b/src/windows/common/WslCoreHostDnsInfo.h similarity index 95% rename from src/windows/service/exe/WslCoreHostDnsInfo.h rename to src/windows/common/WslCoreHostDnsInfo.h index 1813574..8fa4894 100644 --- a/src/windows/service/exe/WslCoreHostDnsInfo.h +++ b/src/windows/common/WslCoreHostDnsInfo.h @@ -1,104 +1,104 @@ -// Copyright (C) Microsoft Corporation. All rights reserved. - -#pragma once -#include -#include -#include - -#include -#include - -#include "WslCoreNetworkingSupport.h" -#include "RegistryWatcher.h" - -namespace wsl::core::networking { -struct DnsInfo -{ - std::vector Servers; - std::vector Domains; -}; - -enum class DnsSettingsFlags -{ - None = 0x0, - IncludeVpn = 0x1, - IncludeIpv6Servers = 0x2, - IncludeAllSuffixes = 0x4 -}; -DEFINE_ENUM_FLAG_OPERATORS(DnsSettingsFlags); - -inline bool operator==(const DnsInfo& lhs, const DnsInfo& rhs) noexcept -{ - return lhs.Servers == rhs.Servers && lhs.Domains == rhs.Domains; -} -inline bool operator!=(const DnsInfo& lhs, const DnsInfo& rhs) noexcept -{ - return !(lhs == rhs); -} - -std::string GenerateResolvConf(_In_ const DnsInfo& Info); - -std::vector GetAllDnsSuffixes(const std::vector& AdapterAddresses); - -DWORD GetBestInterface(); - -class HostDnsInfo -{ -public: - DnsInfo GetDnsSettings(_In_ DnsSettingsFlags Flags); - - void UpdateNetworkInformation(); - - static DnsInfo GetDnsTunnelingSettings(const std::wstring& dnsTunnelingNameserver); - - const std::vector& CurrentAddresses() const - { - return m_addresses; - } - -private: - /// - /// Internal function to retrieve the latest copy of interface information. - /// - std::vector GetAdapterAddresses(); - - /// - /// Internal function to retrieve interface DNS servers. - /// - std::vector GetInterfaceDnsServers(const std::vector& AdapterAddresses, _In_ DnsSettingsFlags Flags); - - /// - /// Internal function to retrieve all Windows DNS suffixes. - /// - static std::vector GetInterfaceDnsSuffixes(const std::vector& AdapterAddresses); - - /// - /// Internal function to convert DNS server addresses into strings. - /// - static std::vector GetDnsServerStrings(_In_ const PIP_ADAPTER_DNS_SERVER_ADDRESS& DnsServer, _In_ USHORT IpFamilyFilter, _In_ USHORT MaxValues); - - /// - /// Stores latest copy of interface information. - /// - std::mutex m_lock; - _Guarded_by_(m_lock) std::vector m_addresses; -}; - -using RegistryChangeCallback = std::function; - -/// -/// Class used to get notifications when Windows DNS suffixes are updated in registry. -/// -class DnsSuffixRegistryWatcher -{ -public: - DnsSuffixRegistryWatcher(RegistryChangeCallback&& reportRegistryChange); - ~DnsSuffixRegistryWatcher() noexcept = default; - -private: - RegistryChangeCallback m_reportRegistryChange; - - std::vector> m_registryWatchers; -}; - +// Copyright (C) Microsoft Corporation. All rights reserved. + +#pragma once +#include +#include +#include + +#include +#include + +#include "WslCoreNetworkingSupport.h" +#include "RegistryWatcher.h" + +namespace wsl::core::networking { +struct DnsInfo +{ + std::vector Servers; + std::vector Domains; +}; + +enum class DnsSettingsFlags +{ + None = 0x0, + IncludeVpn = 0x1, + IncludeIpv6Servers = 0x2, + IncludeAllSuffixes = 0x4 +}; +DEFINE_ENUM_FLAG_OPERATORS(DnsSettingsFlags); + +inline bool operator==(const DnsInfo& lhs, const DnsInfo& rhs) noexcept +{ + return lhs.Servers == rhs.Servers && lhs.Domains == rhs.Domains; +} +inline bool operator!=(const DnsInfo& lhs, const DnsInfo& rhs) noexcept +{ + return !(lhs == rhs); +} + +std::string GenerateResolvConf(_In_ const DnsInfo& Info); + +std::vector GetAllDnsSuffixes(const std::vector& AdapterAddresses); + +DWORD GetBestInterface(); + +class HostDnsInfo +{ +public: + DnsInfo GetDnsSettings(_In_ DnsSettingsFlags Flags); + + void UpdateNetworkInformation(); + + static DnsInfo GetDnsTunnelingSettings(const std::wstring& dnsTunnelingNameserver); + + const std::vector& CurrentAddresses() const + { + return m_addresses; + } + +private: + /// + /// Internal function to retrieve the latest copy of interface information. + /// + std::vector GetAdapterAddresses(); + + /// + /// Internal function to retrieve interface DNS servers. + /// + std::vector GetInterfaceDnsServers(const std::vector& AdapterAddresses, _In_ DnsSettingsFlags Flags); + + /// + /// Internal function to retrieve all Windows DNS suffixes. + /// + static std::vector GetInterfaceDnsSuffixes(const std::vector& AdapterAddresses); + + /// + /// Internal function to convert DNS server addresses into strings. + /// + static std::vector GetDnsServerStrings(_In_ const PIP_ADAPTER_DNS_SERVER_ADDRESS& DnsServer, _In_ USHORT IpFamilyFilter, _In_ USHORT MaxValues); + + /// + /// Stores latest copy of interface information. + /// + std::mutex m_lock; + _Guarded_by_(m_lock) std::vector m_addresses; +}; + +using RegistryChangeCallback = std::function; + +/// +/// Class used to get notifications when Windows DNS suffixes are updated in registry. +/// +class DnsSuffixRegistryWatcher +{ +public: + DnsSuffixRegistryWatcher(RegistryChangeCallback&& reportRegistryChange); + ~DnsSuffixRegistryWatcher() noexcept = default; + +private: + RegistryChangeCallback m_reportRegistryChange; + + std::vector> m_registryWatchers; +}; + } // namespace wsl::core::networking \ No newline at end of file diff --git a/src/windows/service/exe/WslCoreMessageQueue.h b/src/windows/common/WslCoreMessageQueue.h similarity index 96% rename from src/windows/service/exe/WslCoreMessageQueue.h rename to src/windows/common/WslCoreMessageQueue.h index d2598ed..d8f3b8d 100644 --- a/src/windows/service/exe/WslCoreMessageQueue.h +++ b/src/windows/common/WslCoreMessageQueue.h @@ -1,359 +1,359 @@ -/*++ - -Copyright (c) Microsoft. All rights reserved. - -Module Name: - - WslCoreMessageQueue.h - -Abstract: - - This file contains a queuing implementation, guaranteeing running function objects - with guaranteed serialization in a threadpool thread - ---*/ - -#pragma once -#include -#include -#include -#include -#include -#include - -namespace wsl::core { -// forward-declare classes that can instantiate a WslThreadPoolWaitableResult object -class WslCoreMessageQueue; - -class WslBaseThreadPoolWaitableResult -{ -public: - virtual ~WslBaseThreadPoolWaitableResult() noexcept = default; - -private: - // limit who can run() and abort() - friend class WslCoreMessageQueue; - - virtual void run() noexcept = 0; - virtual void abort() noexcept = 0; -}; - -template -class WslThreadPoolWaitableResult : public WslBaseThreadPoolWaitableResult -{ -public: - // throws a wil exception on failure - template - explicit WslThreadPoolWaitableResult(FunctorType&& functor) : m_function(std::forward(functor)) - { - } - - ~WslThreadPoolWaitableResult() noexcept override = default; - - // returns ERROR_SUCCESS if the callback ran to completion - // returns ERROR_TIMEOUT if this wait timed out - // - this can be called multiple times if needing to probe - // any other error code resulted from attempting to run the callback - // - meaning it did *not* run to completion - DWORD wait(DWORD timeout) const noexcept - { - if (!m_completionSignal.wait(timeout)) - { - // not setting m_internalError to timeout - // since the caller is allowed to try to wait() again later - return ERROR_TIMEOUT; - } - const auto lock = m_lock.lock_shared(); - return m_internalError; - } - - // waitable event handle, signaled when the callback has run to completion (or failed) - HANDLE notification_event() const noexcept - { - return m_completionSignal.get(); - } - - const TReturn& read_result() const noexcept - { - return result; - } - - // move the result out of the object for move-only types - TReturn move_result() noexcept - { - TReturn move_out(std::move(result)); - return move_out; - } - - // non-copyable - WslThreadPoolWaitableResult(const WslThreadPoolWaitableResult&) = delete; - WslThreadPoolWaitableResult& operator=(const WslThreadPoolWaitableResult&) = delete; - -private: - void run() noexcept override - { - // we are now running in the TP callback - { - const auto lock = m_lock.lock_exclusive(); - if (m_runStatus != RunStatus::NotYetRun) - { - // return early - the caller has already canceled this - return; - } - m_runStatus = RunStatus::Running; - } - - DWORD error = NO_ERROR; - try - { - result = std::move(m_function()); - } - catch (...) - { - const HRESULT hr = wil::ResultFromCaughtException(); - // HRESULT_TO_WIN32 - error = (HRESULT_FACILITY(hr) == FACILITY_WIN32) ? HRESULT_CODE(hr) : hr; - } - - const auto lock = m_lock.lock_exclusive(); - WI_ASSERT(m_runStatus == RunStatus::Running); - m_runStatus = RunStatus::RanToCompletion; - m_internalError = error; - m_completionSignal.SetEvent(); - } - - void abort() noexcept override - { - const auto lock = m_lock.lock_exclusive(); - // only override the error if we know we haven't started running their functor - if (m_runStatus == RunStatus::NotYetRun) - { - m_runStatus = RunStatus::Canceled; - m_internalError = ERROR_CANCELLED; - m_completionSignal.SetEvent(); - } - } - - std::function m_function; - // a notification event - wil::unique_event m_completionSignal{wil::EventOptions::ManualReset}; - mutable wil::srwlock m_lock; - TReturn result{}; - DWORD m_internalError = NO_ERROR; - - enum class RunStatus - { - NotYetRun, - Running, - RanToCompletion, - Canceled - } m_runStatus{RunStatus::NotYetRun}; -}; - -class WslCoreMessageQueue -{ -public: - WslCoreMessageQueue() : m_tpEnvironment(0, 1) - { - // create a single-threaded threadpool - m_tpHandle = m_tpEnvironment.create_tp(WorkCallback, this); - } - - template - std::shared_ptr> submit_with_results(FunctorType&& functor) noexcept - try - { - FAIL_FAST_IF(m_tpHandle.get() == nullptr); - - const auto new_result = std::make_shared>(std::forward(functor)); - // scope to the queue lock - { - const auto queueLock = m_lock.lock_exclusive(); - THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_CANCELLED), m_isCanceled); - m_workItems.emplace_back(new_result); - } - - // always maintain a 1:1 ratio for calls to SubmitWorkWithResults() and ::SubmitThreadpoolWork - SubmitThreadpoolWork(m_tpHandle.get()); - return new_result; - } - catch (...) - { - LOG_CAUGHT_EXCEPTION(); - return nullptr; - } - - template - bool submit(FunctorType&& functor) noexcept - try - { - FAIL_FAST_IF(m_tpHandle.get() == nullptr); - - // scope to the queue lock - { - const auto queueLock = m_lock.lock_exclusive(); - THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_CANCELLED), m_isCanceled); - m_workItems.emplace_back(std::forward(functor)); - } - - // always maintain a 1:1 ratio for calls to SubmitWork() and ::SubmitThreadpoolWork - SubmitThreadpoolWork(m_tpHandle.get()); - return true; - } - catch (...) - { - LOG_CAUGHT_EXCEPTION(); - return false; - } - - // functors must return type HRESULT - template - HRESULT submit_and_wait(FunctorType&& functor) noexcept - try - { - HRESULT hr = HRESULT_FROM_WIN32(ERROR_OUTOFMEMORY); - if (const auto waitableResult = submit_with_results(std::forward(functor))) - { - hr = HRESULT_FROM_WIN32(waitableResult->wait(INFINITE)); - if (SUCCEEDED(hr)) - { - hr = waitableResult->read_result(); - } - } - return hr; - } - CATCH_RETURN() - - // cancels anything queued to the TP - this WslCoreMessageQueue instance can no longer be used - void cancel() noexcept - try - { - if (m_tpHandle) - { - // immediately release anyone waiting for these workitems not yet run - { - const auto queueLock = m_lock.lock_exclusive(); - m_isCanceled = true; - - for (const auto& work : m_workItems) - { - // signal that these are canceled before we shutdown the TP which they could be scheduled - if (const auto* pWaitableWorkitem = std::get_if(&work)) - { - (*pWaitableWorkitem)->abort(); - } - } - - m_workItems.clear(); - } - - // force the m_tpHandle to wait and close the TP - m_tpHandle.reset(); - m_tpEnvironment.reset(); - } - } - CATCH_LOG() - - bool isRunningInQueue() const noexcept - { - const auto currentThreadId = GetThreadId(GetCurrentThread()); - return currentThreadId == static_cast(InterlockedCompareExchange64(&m_threadpoolThreadId, 0ll, 0ll)); - } - - ~WslCoreMessageQueue() noexcept - { - cancel(); - } - - WslCoreMessageQueue(const WslCoreMessageQueue&) = delete; - WslCoreMessageQueue& operator=(const WslCoreMessageQueue&) = delete; - WslCoreMessageQueue(WslCoreMessageQueue&&) = delete; - WslCoreMessageQueue& operator=(WslCoreMessageQueue&&) = delete; - -private: - struct TPEnvironment - { - using unique_tp_env = wil::unique_struct; - unique_tp_env m_tpEnvironment; - - using unique_tp_pool = wil::unique_any; - unique_tp_pool m_threadPool; - - TPEnvironment(DWORD countMinThread, DWORD countMaxThread) - { - InitializeThreadpoolEnvironment(&m_tpEnvironment); - - m_threadPool.reset(CreateThreadpool(nullptr)); - THROW_LAST_ERROR_IF_NULL(m_threadPool.get()); - - // Set min and max thread counts for custom thread pool - THROW_LAST_ERROR_IF(!::SetThreadpoolThreadMinimum(m_threadPool.get(), countMinThread)); - SetThreadpoolThreadMaximum(m_threadPool.get(), countMaxThread); - SetThreadpoolCallbackPool(&m_tpEnvironment, m_threadPool.get()); - } - - wil::unique_threadpool_work create_tp(PTP_WORK_CALLBACK callback, void* pv) - { - wil::unique_threadpool_work newThreadpool(CreateThreadpoolWork(callback, pv, (m_threadPool) ? &m_tpEnvironment : nullptr)); - THROW_LAST_ERROR_IF_NULL(newThreadpool.get()); - return newThreadpool; - } - - void reset() - { - m_threadPool.reset(); - m_tpEnvironment.reset(); - } - }; - - using SimpleFunction_t = std::function; - using WaitableFunction_t = std::shared_ptr; - using FunctionVariant_t = std::variant; - - // the lock must be destroyed *after* the TP object (thus must be declared first) - // since the lock is used in the TP callback - // the lock is mutable to allow us to acquire the lock in const methods - mutable wil::srwlock m_lock; - TPEnvironment m_tpEnvironment; - wil::unique_threadpool_work m_tpHandle; - std::deque m_workItems; - mutable LONG64 m_threadpoolThreadId{0}; // useful for callers to assert they are running within the queue - bool m_isCanceled{false}; - - static void CALLBACK WorkCallback(PTP_CALLBACK_INSTANCE, void* Context, PTP_WORK) noexcept - try - { - auto* pThis = static_cast(Context); - - FunctionVariant_t work; - { - const auto queueLock = pThis->m_lock.lock_exclusive(); - - if (pThis->m_workItems.empty()) - { - // pThis object is being destroyed and the queue was cleared - return; - } - - std::swap(work, pThis->m_workItems.front()); - pThis->m_workItems.pop_front(); - - InterlockedExchange64(&pThis->m_threadpoolThreadId, GetThreadId(GetCurrentThread())); - } - - // run the tasks outside the WslCoreMessageQueue lock - const auto resetThreadIdOnExit = wil::scope_exit([pThis] { InterlockedExchange64(&pThis->m_threadpoolThreadId, 0ll); }); - if (work.index() == 0) - { - const auto& workItem = std::get(work); - workItem(); - } - else - { - const auto& waitableWorkItem = std::get(work); - waitableWorkItem->run(); - } - } - CATCH_LOG() -}; -} // namespace wsl::core +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + WslCoreMessageQueue.h + +Abstract: + + This file contains a queuing implementation, guaranteeing running function objects + with guaranteed serialization in a threadpool thread + +--*/ + +#pragma once +#include +#include +#include +#include +#include +#include + +namespace wsl::core { +// forward-declare classes that can instantiate a WslThreadPoolWaitableResult object +class WslCoreMessageQueue; + +class WslBaseThreadPoolWaitableResult +{ +public: + virtual ~WslBaseThreadPoolWaitableResult() noexcept = default; + +private: + // limit who can run() and abort() + friend class WslCoreMessageQueue; + + virtual void run() noexcept = 0; + virtual void abort() noexcept = 0; +}; + +template +class WslThreadPoolWaitableResult : public WslBaseThreadPoolWaitableResult +{ +public: + // throws a wil exception on failure + template + explicit WslThreadPoolWaitableResult(FunctorType&& functor) : m_function(std::forward(functor)) + { + } + + ~WslThreadPoolWaitableResult() noexcept override = default; + + // returns ERROR_SUCCESS if the callback ran to completion + // returns ERROR_TIMEOUT if this wait timed out + // - this can be called multiple times if needing to probe + // any other error code resulted from attempting to run the callback + // - meaning it did *not* run to completion + DWORD wait(DWORD timeout) const noexcept + { + if (!m_completionSignal.wait(timeout)) + { + // not setting m_internalError to timeout + // since the caller is allowed to try to wait() again later + return ERROR_TIMEOUT; + } + const auto lock = m_lock.lock_shared(); + return m_internalError; + } + + // waitable event handle, signaled when the callback has run to completion (or failed) + HANDLE notification_event() const noexcept + { + return m_completionSignal.get(); + } + + const TReturn& read_result() const noexcept + { + return result; + } + + // move the result out of the object for move-only types + TReturn move_result() noexcept + { + TReturn move_out(std::move(result)); + return move_out; + } + + // non-copyable + WslThreadPoolWaitableResult(const WslThreadPoolWaitableResult&) = delete; + WslThreadPoolWaitableResult& operator=(const WslThreadPoolWaitableResult&) = delete; + +private: + void run() noexcept override + { + // we are now running in the TP callback + { + const auto lock = m_lock.lock_exclusive(); + if (m_runStatus != RunStatus::NotYetRun) + { + // return early - the caller has already canceled this + return; + } + m_runStatus = RunStatus::Running; + } + + DWORD error = NO_ERROR; + try + { + result = std::move(m_function()); + } + catch (...) + { + const HRESULT hr = wil::ResultFromCaughtException(); + // HRESULT_TO_WIN32 + error = (HRESULT_FACILITY(hr) == FACILITY_WIN32) ? HRESULT_CODE(hr) : hr; + } + + const auto lock = m_lock.lock_exclusive(); + WI_ASSERT(m_runStatus == RunStatus::Running); + m_runStatus = RunStatus::RanToCompletion; + m_internalError = error; + m_completionSignal.SetEvent(); + } + + void abort() noexcept override + { + const auto lock = m_lock.lock_exclusive(); + // only override the error if we know we haven't started running their functor + if (m_runStatus == RunStatus::NotYetRun) + { + m_runStatus = RunStatus::Canceled; + m_internalError = ERROR_CANCELLED; + m_completionSignal.SetEvent(); + } + } + + std::function m_function; + // a notification event + wil::unique_event m_completionSignal{wil::EventOptions::ManualReset}; + mutable wil::srwlock m_lock; + TReturn result{}; + DWORD m_internalError = NO_ERROR; + + enum class RunStatus + { + NotYetRun, + Running, + RanToCompletion, + Canceled + } m_runStatus{RunStatus::NotYetRun}; +}; + +class WslCoreMessageQueue +{ +public: + WslCoreMessageQueue() : m_tpEnvironment(0, 1) + { + // create a single-threaded threadpool + m_tpHandle = m_tpEnvironment.create_tp(WorkCallback, this); + } + + template + std::shared_ptr> submit_with_results(FunctorType&& functor) noexcept + try + { + FAIL_FAST_IF(m_tpHandle.get() == nullptr); + + const auto new_result = std::make_shared>(std::forward(functor)); + // scope to the queue lock + { + const auto queueLock = m_lock.lock_exclusive(); + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_CANCELLED), m_isCanceled); + m_workItems.emplace_back(new_result); + } + + // always maintain a 1:1 ratio for calls to SubmitWorkWithResults() and ::SubmitThreadpoolWork + SubmitThreadpoolWork(m_tpHandle.get()); + return new_result; + } + catch (...) + { + LOG_CAUGHT_EXCEPTION(); + return nullptr; + } + + template + bool submit(FunctorType&& functor) noexcept + try + { + FAIL_FAST_IF(m_tpHandle.get() == nullptr); + + // scope to the queue lock + { + const auto queueLock = m_lock.lock_exclusive(); + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_CANCELLED), m_isCanceled); + m_workItems.emplace_back(std::forward(functor)); + } + + // always maintain a 1:1 ratio for calls to SubmitWork() and ::SubmitThreadpoolWork + SubmitThreadpoolWork(m_tpHandle.get()); + return true; + } + catch (...) + { + LOG_CAUGHT_EXCEPTION(); + return false; + } + + // functors must return type HRESULT + template + HRESULT submit_and_wait(FunctorType&& functor) noexcept + try + { + HRESULT hr = HRESULT_FROM_WIN32(ERROR_OUTOFMEMORY); + if (const auto waitableResult = submit_with_results(std::forward(functor))) + { + hr = HRESULT_FROM_WIN32(waitableResult->wait(INFINITE)); + if (SUCCEEDED(hr)) + { + hr = waitableResult->read_result(); + } + } + return hr; + } + CATCH_RETURN() + + // cancels anything queued to the TP - this WslCoreMessageQueue instance can no longer be used + void cancel() noexcept + try + { + if (m_tpHandle) + { + // immediately release anyone waiting for these workitems not yet run + { + const auto queueLock = m_lock.lock_exclusive(); + m_isCanceled = true; + + for (const auto& work : m_workItems) + { + // signal that these are canceled before we shutdown the TP which they could be scheduled + if (const auto* pWaitableWorkitem = std::get_if(&work)) + { + (*pWaitableWorkitem)->abort(); + } + } + + m_workItems.clear(); + } + + // force the m_tpHandle to wait and close the TP + m_tpHandle.reset(); + m_tpEnvironment.reset(); + } + } + CATCH_LOG() + + bool isRunningInQueue() const noexcept + { + const auto currentThreadId = GetThreadId(GetCurrentThread()); + return currentThreadId == static_cast(InterlockedCompareExchange64(&m_threadpoolThreadId, 0ll, 0ll)); + } + + ~WslCoreMessageQueue() noexcept + { + cancel(); + } + + WslCoreMessageQueue(const WslCoreMessageQueue&) = delete; + WslCoreMessageQueue& operator=(const WslCoreMessageQueue&) = delete; + WslCoreMessageQueue(WslCoreMessageQueue&&) = delete; + WslCoreMessageQueue& operator=(WslCoreMessageQueue&&) = delete; + +private: + struct TPEnvironment + { + using unique_tp_env = wil::unique_struct; + unique_tp_env m_tpEnvironment; + + using unique_tp_pool = wil::unique_any; + unique_tp_pool m_threadPool; + + TPEnvironment(DWORD countMinThread, DWORD countMaxThread) + { + InitializeThreadpoolEnvironment(&m_tpEnvironment); + + m_threadPool.reset(CreateThreadpool(nullptr)); + THROW_LAST_ERROR_IF_NULL(m_threadPool.get()); + + // Set min and max thread counts for custom thread pool + THROW_LAST_ERROR_IF(!::SetThreadpoolThreadMinimum(m_threadPool.get(), countMinThread)); + SetThreadpoolThreadMaximum(m_threadPool.get(), countMaxThread); + SetThreadpoolCallbackPool(&m_tpEnvironment, m_threadPool.get()); + } + + wil::unique_threadpool_work create_tp(PTP_WORK_CALLBACK callback, void* pv) + { + wil::unique_threadpool_work newThreadpool(CreateThreadpoolWork(callback, pv, (m_threadPool) ? &m_tpEnvironment : nullptr)); + THROW_LAST_ERROR_IF_NULL(newThreadpool.get()); + return newThreadpool; + } + + void reset() + { + m_threadPool.reset(); + m_tpEnvironment.reset(); + } + }; + + using SimpleFunction_t = std::function; + using WaitableFunction_t = std::shared_ptr; + using FunctionVariant_t = std::variant; + + // the lock must be destroyed *after* the TP object (thus must be declared first) + // since the lock is used in the TP callback + // the lock is mutable to allow us to acquire the lock in const methods + mutable wil::srwlock m_lock; + TPEnvironment m_tpEnvironment; + wil::unique_threadpool_work m_tpHandle; + std::deque m_workItems; + mutable LONG64 m_threadpoolThreadId{0}; // useful for callers to assert they are running within the queue + bool m_isCanceled{false}; + + static void CALLBACK WorkCallback(PTP_CALLBACK_INSTANCE, void* Context, PTP_WORK) noexcept + try + { + auto* pThis = static_cast(Context); + + FunctionVariant_t work; + { + const auto queueLock = pThis->m_lock.lock_exclusive(); + + if (pThis->m_workItems.empty()) + { + // pThis object is being destroyed and the queue was cleared + return; + } + + std::swap(work, pThis->m_workItems.front()); + pThis->m_workItems.pop_front(); + + InterlockedExchange64(&pThis->m_threadpoolThreadId, GetThreadId(GetCurrentThread())); + } + + // run the tasks outside the WslCoreMessageQueue lock + const auto resetThreadIdOnExit = wil::scope_exit([pThis] { InterlockedExchange64(&pThis->m_threadpoolThreadId, 0ll); }); + if (work.index() == 0) + { + const auto& workItem = std::get(work); + workItem(); + } + else + { + const auto& waitableWorkItem = std::get(work); + waitableWorkItem->run(); + } + } + CATCH_LOG() +}; +} // namespace wsl::core diff --git a/src/windows/service/exe/WslCoreNetworkEndpointSettings.cpp b/src/windows/common/WslCoreNetworkEndpointSettings.cpp similarity index 97% rename from src/windows/service/exe/WslCoreNetworkEndpointSettings.cpp rename to src/windows/common/WslCoreNetworkEndpointSettings.cpp index 92108a5..5d5c789 100644 --- a/src/windows/service/exe/WslCoreNetworkEndpointSettings.cpp +++ b/src/windows/common/WslCoreNetworkEndpointSettings.cpp @@ -1,110 +1,110 @@ -// Copyright (C) Microsoft Corporation. All rights reserved. - -#include "precomp.h" -#include "hns_schema.h" -#include "WslCoreNetworkEndpointSettings.h" -#include "WslCoreHostDnsInfo.h" - -using namespace wsl::shared; - -std::shared_ptr wsl::core::networking::GetEndpointSettings(const hns::HNSEndpoint& properties) -{ - EndpointIpAddress address{}; - address.Address = windows::common::string::StringToSockAddrInet(properties.IPAddress); - address.AddressString = properties.IPAddress; - address.PrefixLength = properties.PrefixLength; - - EndpointRoute route{}; - route.DestinationPrefix.PrefixLength = 0; - IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4); - route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS; - route.NextHop = windows::common::string::StringToSockAddrInet(properties.GatewayAddress); - route.NextHopString = properties.GatewayAddress; - - return std::make_shared( - properties.InterfaceConstraint.InterfaceGuid, - address, - route, - properties.MacAddress, - L"unuseddevicename", - properties.InterfaceConstraint.InterfaceIndex, - properties.InterfaceConstraint.InterfaceMediaType, - properties.DNSServerList); -} - -std::shared_ptr wsl::core::networking::GetHostEndpointSettings() -{ - HostDnsInfo dnsInfo; - dnsInfo.UpdateNetworkInformation(); - auto addresses = dnsInfo.CurrentAddresses(); - auto bestIndex = GetBestInterface(); - auto bestInterfacePtr = - std::find_if(addresses.cbegin(), addresses.cend(), [&](const auto& address) { return address->IfIndex == bestIndex; }); - if (bestInterfacePtr == addresses.end()) - { - return std::make_shared(); - } - - const auto& bestInterface = *bestInterfacePtr; - - std::wstring macAddress = wsl::shared::string::FormatMacAddress( - wsl::shared::string::MacAddress{ - bestInterface->PhysicalAddress[0], - bestInterface->PhysicalAddress[1], - bestInterface->PhysicalAddress[2], - bestInterface->PhysicalAddress[3], - bestInterface->PhysicalAddress[4], - bestInterface->PhysicalAddress[5]}, - L'-'); - - EndpointIpAddress address{}; - auto firstIpv4Address = bestInterface->FirstUnicastAddress; - while (firstIpv4Address && firstIpv4Address->Address.lpSockaddr->sa_family != AF_INET) - { - firstIpv4Address = firstIpv4Address->Next; - } - if (firstIpv4Address) - { - address.Address = *reinterpret_cast(firstIpv4Address->Address.lpSockaddr); - address.AddressString = windows::common::string::SockAddrInetToWstring(address.Address); - address.PrefixLength = firstIpv4Address->OnLinkPrefixLength; - } - - EndpointRoute route{}; - PIP_ADAPTER_GATEWAY_ADDRESS nextGatewayAddress = bestInterface->FirstGatewayAddress; - while (nextGatewayAddress && nextGatewayAddress->Address.lpSockaddr->sa_family != AF_INET) - { - nextGatewayAddress = nextGatewayAddress->Next; - } - if (nextGatewayAddress) - { - route.DestinationPrefix.PrefixLength = 0; - IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4); - route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS; - route.NextHop = *reinterpret_cast(nextGatewayAddress->Address.lpSockaddr); - route.NextHopString = windows::common::string::SockAddrInetToWstring(route.NextHop); - } - else if (address.Address.si_family == AF_INET) - { - IN_ADDR default_route{}; - default_route.s_addr = htonl((ntohl(address.Address.Ipv4.sin_addr.s_addr) & ~((1 << (32 - address.PrefixLength)) - 1)) | 1); - route.DestinationPrefix.PrefixLength = 0; - IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4); - route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS; - IN4ADDR_SETSOCKADDR(&route.NextHop.Ipv4, &default_route, 0); - route.NextHopString = windows::common::string::SockAddrInetToWstring(route.NextHop); - } - - std::wstring dnsServerList; - for (const auto& serverAddress : dnsInfo.GetDnsSettings(DnsSettingsFlags::IncludeVpn).Servers) - { - if (!dnsServerList.empty()) - { - dnsServerList += L","; - } - dnsServerList += wsl::shared::string::MultiByteToWide(serverAddress); - } - - return std::shared_ptr(new NetworkSettings( - bestInterface->NetworkGuid, address, route, macAddress, {}, bestInterface->IfIndex, bestInterface->IfType, dnsServerList)); -} +// Copyright (C) Microsoft Corporation. All rights reserved. + +#include "precomp.h" +#include "hns_schema.h" +#include "WslCoreNetworkEndpointSettings.h" +#include "WslCoreHostDnsInfo.h" + +using namespace wsl::shared; + +std::shared_ptr wsl::core::networking::GetEndpointSettings(const hns::HNSEndpoint& properties) +{ + EndpointIpAddress address{}; + address.Address = windows::common::string::StringToSockAddrInet(properties.IPAddress); + address.AddressString = properties.IPAddress; + address.PrefixLength = properties.PrefixLength; + + EndpointRoute route{}; + route.DestinationPrefix.PrefixLength = 0; + IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4); + route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS; + route.NextHop = windows::common::string::StringToSockAddrInet(properties.GatewayAddress); + route.NextHopString = properties.GatewayAddress; + + return std::make_shared( + properties.InterfaceConstraint.InterfaceGuid, + address, + route, + properties.MacAddress, + L"unuseddevicename", + properties.InterfaceConstraint.InterfaceIndex, + properties.InterfaceConstraint.InterfaceMediaType, + properties.DNSServerList); +} + +std::shared_ptr wsl::core::networking::GetHostEndpointSettings() +{ + HostDnsInfo dnsInfo; + dnsInfo.UpdateNetworkInformation(); + auto addresses = dnsInfo.CurrentAddresses(); + auto bestIndex = GetBestInterface(); + auto bestInterfacePtr = + std::find_if(addresses.cbegin(), addresses.cend(), [&](const auto& address) { return address->IfIndex == bestIndex; }); + if (bestInterfacePtr == addresses.end()) + { + return std::make_shared(); + } + + const auto& bestInterface = *bestInterfacePtr; + + std::wstring macAddress = wsl::shared::string::FormatMacAddress( + wsl::shared::string::MacAddress{ + bestInterface->PhysicalAddress[0], + bestInterface->PhysicalAddress[1], + bestInterface->PhysicalAddress[2], + bestInterface->PhysicalAddress[3], + bestInterface->PhysicalAddress[4], + bestInterface->PhysicalAddress[5]}, + L'-'); + + EndpointIpAddress address{}; + auto firstIpv4Address = bestInterface->FirstUnicastAddress; + while (firstIpv4Address && firstIpv4Address->Address.lpSockaddr->sa_family != AF_INET) + { + firstIpv4Address = firstIpv4Address->Next; + } + if (firstIpv4Address) + { + address.Address = *reinterpret_cast(firstIpv4Address->Address.lpSockaddr); + address.AddressString = windows::common::string::SockAddrInetToWstring(address.Address); + address.PrefixLength = firstIpv4Address->OnLinkPrefixLength; + } + + EndpointRoute route{}; + PIP_ADAPTER_GATEWAY_ADDRESS nextGatewayAddress = bestInterface->FirstGatewayAddress; + while (nextGatewayAddress && nextGatewayAddress->Address.lpSockaddr->sa_family != AF_INET) + { + nextGatewayAddress = nextGatewayAddress->Next; + } + if (nextGatewayAddress) + { + route.DestinationPrefix.PrefixLength = 0; + IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4); + route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS; + route.NextHop = *reinterpret_cast(nextGatewayAddress->Address.lpSockaddr); + route.NextHopString = windows::common::string::SockAddrInetToWstring(route.NextHop); + } + else if (address.Address.si_family == AF_INET) + { + IN_ADDR default_route{}; + default_route.s_addr = htonl((ntohl(address.Address.Ipv4.sin_addr.s_addr) & ~((1 << (32 - address.PrefixLength)) - 1)) | 1); + route.DestinationPrefix.PrefixLength = 0; + IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4); + route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS; + IN4ADDR_SETSOCKADDR(&route.NextHop.Ipv4, &default_route, 0); + route.NextHopString = windows::common::string::SockAddrInetToWstring(route.NextHop); + } + + std::wstring dnsServerList; + for (const auto& serverAddress : dnsInfo.GetDnsSettings(DnsSettingsFlags::IncludeVpn).Servers) + { + if (!dnsServerList.empty()) + { + dnsServerList += L","; + } + dnsServerList += wsl::shared::string::MultiByteToWide(serverAddress); + } + + return std::shared_ptr(new NetworkSettings( + bestInterface->NetworkGuid, address, route, macAddress, {}, bestInterface->IfIndex, bestInterface->IfType, dnsServerList)); +} diff --git a/src/windows/service/exe/WslCoreNetworkEndpointSettings.h b/src/windows/common/WslCoreNetworkEndpointSettings.h similarity index 97% rename from src/windows/service/exe/WslCoreNetworkEndpointSettings.h rename to src/windows/common/WslCoreNetworkEndpointSettings.h index 5e165c9..ce91a24 100644 --- a/src/windows/service/exe/WslCoreNetworkEndpointSettings.h +++ b/src/windows/common/WslCoreNetworkEndpointSettings.h @@ -1,398 +1,398 @@ -// Copyright (C) Microsoft Corporation. All rights reserved. - -#pragma once -#include -#include -#include - -#include -#include -#include -#include - -#include "hcs.hpp" -#include "lxinitshared.h" -#include "Stringify.h" -#include "stringshared.h" -#include "WslCoreNetworkingSupport.h" -#include "hns_schema.h" - -namespace wsl::core::networking { - -constexpr auto AddEndpointRetryPeriod = std::chrono::milliseconds(100); -constexpr auto AddEndpointRetryTimeout = std::chrono::seconds(3); -constexpr auto AddEndpointRetryPredicate = [] { - // Don't retry if ModifyComputeSystem fails with: - // HCN_E_ENDPOINT_NOT_FOUND - indicates that the underlying network object was deleted. - // HCN_E_ENDPOINT_ALREADY_ATTACHED - occurs when HNS was restarted before the endpoints were removed. - // VM_E_INVALID_STATE - occurs when the VM has been terminated. - const auto result = wil::ResultFromCaughtException(); - return result != HCN_E_ENDPOINT_NOT_FOUND && result != HCN_E_ENDPOINT_ALREADY_ATTACHED && result != VM_E_INVALID_STATE; -}; - -struct EndpointIpAddress -{ - SOCKADDR_INET Address{}; - std::wstring AddressString{}; - unsigned char PrefixLength = 0; - unsigned int PrefixOrigin = 0; - unsigned int SuffixOrigin = 0; - - // The following field can be changed from a const iterator in SyncIpStateWithLinux - that's why it's marked mutable. - mutable unsigned int PreferredLifetime = 0; - - EndpointIpAddress() = default; - ~EndpointIpAddress() noexcept = default; - - EndpointIpAddress(EndpointIpAddress&&) = default; - EndpointIpAddress& operator=(EndpointIpAddress&&) = default; - EndpointIpAddress(const EndpointIpAddress&) = default; - EndpointIpAddress& operator=(const EndpointIpAddress&) = default; - - explicit EndpointIpAddress(const MIB_UNICASTIPADDRESS_ROW& AddressRow) : - Address(AddressRow.Address), - AddressString(windows::common::string::SockAddrInetToWstring(AddressRow.Address)), - PrefixLength(AddressRow.OnLinkPrefixLength), - PrefixOrigin(AddressRow.PrefixOrigin), - SuffixOrigin(AddressRow.SuffixOrigin), - // We treat the preferred lifetime field as effective DAD state - 0 is not preferred, anything else is preferred. - // We do this for convenience, as we can't directly set the DAD state of an address into the guest, but we - // we can set an address's preferred lifetime (in Linux, at least). - PreferredLifetime(AddressRow.DadState == IpDadStatePreferred ? 0xFFFFFFFF : 0) - { - } - - // operator== is deliberately not comparing PreferredLifetime (DAD state) for equality - only the address portion - bool operator==(const EndpointIpAddress& rhs) const noexcept - { - return Address == rhs.Address && PrefixLength == rhs.PrefixLength; - } - - bool operator<(const EndpointIpAddress& rhs) const noexcept - { - if (Address == rhs.Address) - { - return PrefixLength < rhs.PrefixLength; - } - return Address < rhs.Address; - } - - void Clear() noexcept - { - Address = {}; - AddressString.clear(); - PrefixLength = 0; - PrefixOrigin = 0; - SuffixOrigin = 0; - } - - std::wstring GetPrefix() const - { - SOCKADDR_INET address{Address}; - unsigned char* addressPointer{nullptr}; - - if (Address.si_family == AF_INET) - { - addressPointer = reinterpret_cast(&address.Ipv4.sin_addr); - } - else if (Address.si_family == AF_INET6) - { - addressPointer = address.Ipv6.sin6_addr.u.Byte; - } - else - { - return L""; - } - - constexpr int c_numBitsPerByte = 8; - for (int i = 0, currPrefixLength = PrefixLength; i < INET_ADDR_LENGTH(Address.si_family); i++, currPrefixLength -= c_numBitsPerByte) - { - if (currPrefixLength < c_numBitsPerByte) - { - const int bitShiftAmt = c_numBitsPerByte - std::max(currPrefixLength, 0); - addressPointer[i] &= (0xFF >> bitShiftAmt) << bitShiftAmt; - } - } - - const auto addressString = windows::common::string::SockAddrInetToWstring(address); - WI_ASSERT(!addressString.empty()); - if (addressString.empty()) - { - // just return an empty string if we have a bad address - return addressString; - } - - return std::format(L"{}/{}", addressString, PrefixLength); - } - - std::wstring GetIpv4BroadcastMask() const - { - // start with all bits set, then shift off the prefix - ULONG prefixMask{0xffffffff}; - prefixMask <<= PrefixLength; - prefixMask >>= PrefixLength; - - SOCKADDR_INET address{Address}; - // flip to host-order, then apply the mask - ULONG hostOrder = ntohl(address.Ipv4.sin_addr.S_un.S_addr); - hostOrder |= prefixMask; - address.Ipv4.sin_addr.S_un.S_addr = htonl(hostOrder); - - return windows::common::string::SockAddrInetToWstring(address); - } - - bool IsPreferred() const noexcept - { - return PreferredLifetime > 0; - } - - bool IsLinkLocal() const - { - return (Address.si_family == AF_INET && IN4_IS_ADDR_LINKLOCAL(&Address.Ipv4.sin_addr)) || - (Address.si_family == AF_INET6 && IN6_IS_ADDR_LINKLOCAL(&Address.Ipv6.sin6_addr)); - } -}; - -struct EndpointRoute -{ - ADDRESS_FAMILY Family = AF_INET; - IP_ADDRESS_PREFIX DestinationPrefix{}; - std::wstring DestinationPrefixString{}; - SOCKADDR_INET NextHop{}; - std::wstring NextHopString{}; - unsigned char SitePrefixLength = 0; - unsigned int Metric = 0; - bool IsAutoGeneratedPrefixRoute = false; - - EndpointRoute() = default; - ~EndpointRoute() noexcept = default; - - EndpointRoute(EndpointRoute&&) = default; - EndpointRoute& operator=(EndpointRoute&&) = default; - EndpointRoute(const EndpointRoute&) = default; - EndpointRoute& operator=(const EndpointRoute&) = default; - - EndpointRoute(const MIB_IPFORWARD_ROW2& RouteRow) : - Family(RouteRow.NextHop.si_family), - DestinationPrefix(RouteRow.DestinationPrefix), - DestinationPrefixString(windows::common::string::SockAddrInetToWstring(RouteRow.DestinationPrefix.Prefix)), - NextHop(RouteRow.NextHop), - NextHopString(windows::common::string::SockAddrInetToWstring(RouteRow.NextHop)), - SitePrefixLength(RouteRow.SitePrefixLength), - Metric(RouteRow.Metric) - { - } - - unsigned char GetMaxPrefixLength() const - { - return (Family == AF_INET) ? 32 : 128; - } - - std::wstring GetFullDestinationPrefix() const - { - return std::format(L"{}/{}", DestinationPrefixString, static_cast(DestinationPrefix.PrefixLength)); - } - - bool IsNextHopOnlink() const noexcept - { - return (Family == AF_INET && NextHopString == LX_INIT_UNSPECIFIED_ADDRESS) || - (Family == AF_INET6 && NextHopString == LX_INIT_UNSPECIFIED_V6_ADDRESS); - } - - bool IsDefault() const noexcept - { - return (Family == AF_INET && DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS) || - (Family == AF_INET6 && DestinationPrefixString == LX_INIT_UNSPECIFIED_V6_ADDRESS); - } - - bool IsUnicastAddressRoute() const noexcept - { - return (Family == AF_INET && DestinationPrefix.PrefixLength == 32) || (Family == AF_INET6 && DestinationPrefix.PrefixLength == 128); - } - - std::wstring ToString() const - { - return std::format(L"{}=>{} [metric {}]", GetFullDestinationPrefix(), NextHopString, Metric); - } - - bool operator==(const EndpointRoute& rhs) const noexcept - { - return Family == rhs.Family && DestinationPrefix.PrefixLength == rhs.DestinationPrefix.PrefixLength && - DestinationPrefix.Prefix == rhs.DestinationPrefix.Prefix && NextHop == rhs.NextHop && - SitePrefixLength == rhs.SitePrefixLength && Metric == rhs.Metric; - } - - bool operator!=(const EndpointRoute& other) const - { - return !(*this == other); - } - - // sort by family, then by next-hop (on-link routes first), then by prefix, then by metric - bool operator<(const EndpointRoute& rhs) const noexcept - { - if (Family == rhs.Family) - { - if (NextHop == rhs.NextHop) - { - if (DestinationPrefix.Prefix == rhs.DestinationPrefix.Prefix) - { - if (DestinationPrefix.PrefixLength == rhs.DestinationPrefix.PrefixLength) - { - if (Metric == rhs.Metric) - { - return SitePrefixLength < rhs.SitePrefixLength; - } - return Metric < rhs.Metric; - } - return DestinationPrefix.PrefixLength < rhs.DestinationPrefix.PrefixLength; - } - return DestinationPrefix.Prefix < rhs.DestinationPrefix.Prefix; - } - return NextHop < rhs.NextHop; - } - return Family < rhs.Family; - } -}; - -struct NetworkSettings -{ - NetworkSettings() = default; - - NetworkSettings( - const GUID& interfaceGuid, - EndpointIpAddress preferredIpAddress, - EndpointRoute gateway, - std::wstring macAddress, - std::wstring deviceName, - uint32_t interfaceIndex, - uint32_t mediaType, - const std::wstring& dnsServerList) : - InterfaceGuid(interfaceGuid), - PreferredIpAddress(std::move(preferredIpAddress)), - MacAddress(std::move(macAddress)), - DeviceName(std::move(deviceName)), - InterfaceIndex(interfaceIndex), - InterfaceType(mediaType) - { - Routes.emplace(std::move(gateway)); - DnsServers = wsl::shared::string::Split(dnsServerList, L','); - } - - GUID InterfaceGuid{}; - EndpointIpAddress PreferredIpAddress{}; - std::set IpAddresses{}; // Does not include PreferredIpAddress. - std::set Routes{}; - std::vector DnsServers{}; - std::wstring MacAddress; - std::wstring DeviceName; - IF_INDEX InterfaceIndex = 0; - IFTYPE InterfaceType = 0; - ULONG IPv4InterfaceMtu = 0; - ULONG IPv6InterfaceMtu = 0; - // some interfaces will only have an IPv4 or IPv6 interface - std::optional IPv4InterfaceMetric = 0; - std::optional IPv6InterfaceMetric = 0; - bool IsHidden = false; - bool IsConnected = false; - bool IsMetered = false; - bool DisableIpv4DefaultRoutes = false; - bool DisableIpv6DefaultRoutes = false; - bool PendingUpdateToReconnectForMetered = false; - bool PendingIPInterfaceUpdate = false; - - auto operator<=>(const NetworkSettings&) const = default; - - std::wstring GetBestGatewayAddressString() const - { - // Best is currently defined as simply the first IPv4 gateway. - for (const auto& route : Routes) - { - if (route.Family == AF_INET && route.DestinationPrefix.PrefixLength == 0 && route.DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS) - { - return route.NextHopString; - } - } - - return {}; - } - - SOCKADDR_INET GetBestGatewayAddress() const - { - // Best is currently defined as simply the first IPv4 gateway. - for (const auto& route : Routes) - { - if (route.Family == AF_INET && route.DestinationPrefix.PrefixLength == 0 && route.DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS) - { - return route.NextHop; - } - } - - return {}; - } - - std::wstring IpAddressesString() const - { - return std::accumulate(std::begin(IpAddresses), std::end(IpAddresses), std::wstring{}, [](const std::wstring& prev, const auto& addr) { - return addr.AddressString + (prev.empty() ? L"" : L"," + prev); - }); - } - - std::wstring RoutesString() const - { - return std::accumulate(std::begin(Routes), std::end(Routes), std::wstring{}, [](const std::wstring& prev, const EndpointRoute& route) { - return route.ToString() + (prev.empty() ? L"" : L"," + prev); - }); - } - - std::wstring DnsServersString() const - { - return wsl::shared::string::Join(DnsServers, L','); - } - - // will return ULONG_MAX if there's no configured MTU - ULONG GetEffectiveMtu() const noexcept - { - return std::min(IPv4InterfaceMtu > 0 ? IPv4InterfaceMtu : ULONG_MAX, IPv6InterfaceMtu > 0 ? IPv6InterfaceMtu : ULONG_MAX); - } - - // will return zero if there's no configured metric - ULONG GetMinimumMetric() const noexcept - { - if (!IPv4InterfaceMetric.has_value() && !IPv6InterfaceMetric.has_value()) - { - return 0; - } - if (!IPv4InterfaceMetric.has_value()) - { - return IPv6InterfaceMetric.value(); - } - if (!IPv6InterfaceMetric.has_value()) - { - return IPv4InterfaceMetric.value(); - } - return std::min(IPv4InterfaceMetric.value(), IPv6InterfaceMetric.value()); - } -}; - -std::shared_ptr GetEndpointSettings(const wsl::shared::hns::HNSEndpoint& properties); -std::shared_ptr GetHostEndpointSettings(); - -#define TRACE_NETWORKSETTINGS_OBJECT(settings) \ - TraceLoggingValue((settings)->InterfaceGuid, "interfaceGuid"), TraceLoggingValue((settings)->InterfaceIndex, "interfaceIndex"), \ - TraceLoggingValue((settings)->InterfaceType, "interfaceType"), \ - TraceLoggingValue((settings)->IsConnected, "isConnected"), TraceLoggingValue((settings)->IsMetered, "isMetered"), \ - TraceLoggingValue((settings)->GetBestGatewayAddressString().c_str(), "bestGatewayAddress"), \ - TraceLoggingValue((settings)->PreferredIpAddress.AddressString.c_str(), "preferredIpAddress"), \ - TraceLoggingValue((settings)->PreferredIpAddress.PrefixLength, "preferredIpAddressPrefixLength"), \ - TraceLoggingValue((settings)->IpAddressesString().c_str(), "ipAddresses"), \ - TraceLoggingValue((settings)->RoutesString().c_str(), "routes"), \ - TraceLoggingValue((settings)->DnsServersString().c_str(), "dnsServerList"), \ - TraceLoggingValue((settings)->MacAddress.c_str(), "macAddress"), \ - TraceLoggingValue((settings)->IPv4InterfaceMtu, "IPv4InterfaceMtu"), \ - TraceLoggingValue((settings)->IPv6InterfaceMtu, "IPv6InterfaceMtu"), \ - TraceLoggingValue((settings)->IPv4InterfaceMetric.value_or(0xffffffff), "IPv4InterfaceMetric"), \ - TraceLoggingValue((settings)->IPv6InterfaceMetric.value_or(0xffffffff), "IPv6InterfaceMetric"), \ - TraceLoggingValue((settings)->PendingIPInterfaceUpdate, "PendingIPInterfaceUpdate"), \ - TraceLoggingValue((settings)->PendingUpdateToReconnectForMetered, "PendingUpdateToReconnectForMetered") - -} // namespace wsl::core::networking +// Copyright (C) Microsoft Corporation. All rights reserved. + +#pragma once +#include +#include +#include + +#include +#include +#include +#include + +#include "hcs.hpp" +#include "lxinitshared.h" +#include "Stringify.h" +#include "stringshared.h" +#include "WslCoreNetworkingSupport.h" +#include "hns_schema.h" + +namespace wsl::core::networking { + +constexpr auto AddEndpointRetryPeriod = std::chrono::milliseconds(100); +constexpr auto AddEndpointRetryTimeout = std::chrono::seconds(3); +constexpr auto AddEndpointRetryPredicate = [] { + // Don't retry if ModifyComputeSystem fails with: + // HCN_E_ENDPOINT_NOT_FOUND - indicates that the underlying network object was deleted. + // HCN_E_ENDPOINT_ALREADY_ATTACHED - occurs when HNS was restarted before the endpoints were removed. + // VM_E_INVALID_STATE - occurs when the VM has been terminated. + const auto result = wil::ResultFromCaughtException(); + return result != HCN_E_ENDPOINT_NOT_FOUND && result != HCN_E_ENDPOINT_ALREADY_ATTACHED && result != VM_E_INVALID_STATE; +}; + +struct EndpointIpAddress +{ + SOCKADDR_INET Address{}; + std::wstring AddressString{}; + unsigned char PrefixLength = 0; + unsigned int PrefixOrigin = 0; + unsigned int SuffixOrigin = 0; + + // The following field can be changed from a const iterator in SyncIpStateWithLinux - that's why it's marked mutable. + mutable unsigned int PreferredLifetime = 0; + + EndpointIpAddress() = default; + ~EndpointIpAddress() noexcept = default; + + EndpointIpAddress(EndpointIpAddress&&) = default; + EndpointIpAddress& operator=(EndpointIpAddress&&) = default; + EndpointIpAddress(const EndpointIpAddress&) = default; + EndpointIpAddress& operator=(const EndpointIpAddress&) = default; + + explicit EndpointIpAddress(const MIB_UNICASTIPADDRESS_ROW& AddressRow) : + Address(AddressRow.Address), + AddressString(windows::common::string::SockAddrInetToWstring(AddressRow.Address)), + PrefixLength(AddressRow.OnLinkPrefixLength), + PrefixOrigin(AddressRow.PrefixOrigin), + SuffixOrigin(AddressRow.SuffixOrigin), + // We treat the preferred lifetime field as effective DAD state - 0 is not preferred, anything else is preferred. + // We do this for convenience, as we can't directly set the DAD state of an address into the guest, but we + // we can set an address's preferred lifetime (in Linux, at least). + PreferredLifetime(AddressRow.DadState == IpDadStatePreferred ? 0xFFFFFFFF : 0) + { + } + + // operator== is deliberately not comparing PreferredLifetime (DAD state) for equality - only the address portion + bool operator==(const EndpointIpAddress& rhs) const noexcept + { + return Address == rhs.Address && PrefixLength == rhs.PrefixLength; + } + + bool operator<(const EndpointIpAddress& rhs) const noexcept + { + if (Address == rhs.Address) + { + return PrefixLength < rhs.PrefixLength; + } + return Address < rhs.Address; + } + + void Clear() noexcept + { + Address = {}; + AddressString.clear(); + PrefixLength = 0; + PrefixOrigin = 0; + SuffixOrigin = 0; + } + + std::wstring GetPrefix() const + { + SOCKADDR_INET address{Address}; + unsigned char* addressPointer{nullptr}; + + if (Address.si_family == AF_INET) + { + addressPointer = reinterpret_cast(&address.Ipv4.sin_addr); + } + else if (Address.si_family == AF_INET6) + { + addressPointer = address.Ipv6.sin6_addr.u.Byte; + } + else + { + return L""; + } + + constexpr int c_numBitsPerByte = 8; + for (int i = 0, currPrefixLength = PrefixLength; i < INET_ADDR_LENGTH(Address.si_family); i++, currPrefixLength -= c_numBitsPerByte) + { + if (currPrefixLength < c_numBitsPerByte) + { + const int bitShiftAmt = c_numBitsPerByte - std::max(currPrefixLength, 0); + addressPointer[i] &= (0xFF >> bitShiftAmt) << bitShiftAmt; + } + } + + const auto addressString = windows::common::string::SockAddrInetToWstring(address); + WI_ASSERT(!addressString.empty()); + if (addressString.empty()) + { + // just return an empty string if we have a bad address + return addressString; + } + + return std::format(L"{}/{}", addressString, PrefixLength); + } + + std::wstring GetIpv4BroadcastMask() const + { + // start with all bits set, then shift off the prefix + ULONG prefixMask{0xffffffff}; + prefixMask <<= PrefixLength; + prefixMask >>= PrefixLength; + + SOCKADDR_INET address{Address}; + // flip to host-order, then apply the mask + ULONG hostOrder = ntohl(address.Ipv4.sin_addr.S_un.S_addr); + hostOrder |= prefixMask; + address.Ipv4.sin_addr.S_un.S_addr = htonl(hostOrder); + + return windows::common::string::SockAddrInetToWstring(address); + } + + bool IsPreferred() const noexcept + { + return PreferredLifetime > 0; + } + + bool IsLinkLocal() const + { + return (Address.si_family == AF_INET && IN4_IS_ADDR_LINKLOCAL(&Address.Ipv4.sin_addr)) || + (Address.si_family == AF_INET6 && IN6_IS_ADDR_LINKLOCAL(&Address.Ipv6.sin6_addr)); + } +}; + +struct EndpointRoute +{ + ADDRESS_FAMILY Family = AF_INET; + IP_ADDRESS_PREFIX DestinationPrefix{}; + std::wstring DestinationPrefixString{}; + SOCKADDR_INET NextHop{}; + std::wstring NextHopString{}; + unsigned char SitePrefixLength = 0; + unsigned int Metric = 0; + bool IsAutoGeneratedPrefixRoute = false; + + EndpointRoute() = default; + ~EndpointRoute() noexcept = default; + + EndpointRoute(EndpointRoute&&) = default; + EndpointRoute& operator=(EndpointRoute&&) = default; + EndpointRoute(const EndpointRoute&) = default; + EndpointRoute& operator=(const EndpointRoute&) = default; + + EndpointRoute(const MIB_IPFORWARD_ROW2& RouteRow) : + Family(RouteRow.NextHop.si_family), + DestinationPrefix(RouteRow.DestinationPrefix), + DestinationPrefixString(windows::common::string::SockAddrInetToWstring(RouteRow.DestinationPrefix.Prefix)), + NextHop(RouteRow.NextHop), + NextHopString(windows::common::string::SockAddrInetToWstring(RouteRow.NextHop)), + SitePrefixLength(RouteRow.SitePrefixLength), + Metric(RouteRow.Metric) + { + } + + unsigned char GetMaxPrefixLength() const + { + return (Family == AF_INET) ? 32 : 128; + } + + std::wstring GetFullDestinationPrefix() const + { + return std::format(L"{}/{}", DestinationPrefixString, static_cast(DestinationPrefix.PrefixLength)); + } + + bool IsNextHopOnlink() const noexcept + { + return (Family == AF_INET && NextHopString == LX_INIT_UNSPECIFIED_ADDRESS) || + (Family == AF_INET6 && NextHopString == LX_INIT_UNSPECIFIED_V6_ADDRESS); + } + + bool IsDefault() const noexcept + { + return (Family == AF_INET && DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS) || + (Family == AF_INET6 && DestinationPrefixString == LX_INIT_UNSPECIFIED_V6_ADDRESS); + } + + bool IsUnicastAddressRoute() const noexcept + { + return (Family == AF_INET && DestinationPrefix.PrefixLength == 32) || (Family == AF_INET6 && DestinationPrefix.PrefixLength == 128); + } + + std::wstring ToString() const + { + return std::format(L"{}=>{} [metric {}]", GetFullDestinationPrefix(), NextHopString, Metric); + } + + bool operator==(const EndpointRoute& rhs) const noexcept + { + return Family == rhs.Family && DestinationPrefix.PrefixLength == rhs.DestinationPrefix.PrefixLength && + DestinationPrefix.Prefix == rhs.DestinationPrefix.Prefix && NextHop == rhs.NextHop && + SitePrefixLength == rhs.SitePrefixLength && Metric == rhs.Metric; + } + + bool operator!=(const EndpointRoute& other) const + { + return !(*this == other); + } + + // sort by family, then by next-hop (on-link routes first), then by prefix, then by metric + bool operator<(const EndpointRoute& rhs) const noexcept + { + if (Family == rhs.Family) + { + if (NextHop == rhs.NextHop) + { + if (DestinationPrefix.Prefix == rhs.DestinationPrefix.Prefix) + { + if (DestinationPrefix.PrefixLength == rhs.DestinationPrefix.PrefixLength) + { + if (Metric == rhs.Metric) + { + return SitePrefixLength < rhs.SitePrefixLength; + } + return Metric < rhs.Metric; + } + return DestinationPrefix.PrefixLength < rhs.DestinationPrefix.PrefixLength; + } + return DestinationPrefix.Prefix < rhs.DestinationPrefix.Prefix; + } + return NextHop < rhs.NextHop; + } + return Family < rhs.Family; + } +}; + +struct NetworkSettings +{ + NetworkSettings() = default; + + NetworkSettings( + const GUID& interfaceGuid, + EndpointIpAddress preferredIpAddress, + EndpointRoute gateway, + std::wstring macAddress, + std::wstring deviceName, + uint32_t interfaceIndex, + uint32_t mediaType, + const std::wstring& dnsServerList) : + InterfaceGuid(interfaceGuid), + PreferredIpAddress(std::move(preferredIpAddress)), + MacAddress(std::move(macAddress)), + DeviceName(std::move(deviceName)), + InterfaceIndex(interfaceIndex), + InterfaceType(mediaType) + { + Routes.emplace(std::move(gateway)); + DnsServers = wsl::shared::string::Split(dnsServerList, L','); + } + + GUID InterfaceGuid{}; + EndpointIpAddress PreferredIpAddress{}; + std::set IpAddresses{}; // Does not include PreferredIpAddress. + std::set Routes{}; + std::vector DnsServers{}; + std::wstring MacAddress; + std::wstring DeviceName; + IF_INDEX InterfaceIndex = 0; + IFTYPE InterfaceType = 0; + ULONG IPv4InterfaceMtu = 0; + ULONG IPv6InterfaceMtu = 0; + // some interfaces will only have an IPv4 or IPv6 interface + std::optional IPv4InterfaceMetric = 0; + std::optional IPv6InterfaceMetric = 0; + bool IsHidden = false; + bool IsConnected = false; + bool IsMetered = false; + bool DisableIpv4DefaultRoutes = false; + bool DisableIpv6DefaultRoutes = false; + bool PendingUpdateToReconnectForMetered = false; + bool PendingIPInterfaceUpdate = false; + + auto operator<=>(const NetworkSettings&) const = default; + + std::wstring GetBestGatewayAddressString() const + { + // Best is currently defined as simply the first IPv4 gateway. + for (const auto& route : Routes) + { + if (route.Family == AF_INET && route.DestinationPrefix.PrefixLength == 0 && route.DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS) + { + return route.NextHopString; + } + } + + return {}; + } + + SOCKADDR_INET GetBestGatewayAddress() const + { + // Best is currently defined as simply the first IPv4 gateway. + for (const auto& route : Routes) + { + if (route.Family == AF_INET && route.DestinationPrefix.PrefixLength == 0 && route.DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS) + { + return route.NextHop; + } + } + + return {}; + } + + std::wstring IpAddressesString() const + { + return std::accumulate(std::begin(IpAddresses), std::end(IpAddresses), std::wstring{}, [](const std::wstring& prev, const auto& addr) { + return addr.AddressString + (prev.empty() ? L"" : L"," + prev); + }); + } + + std::wstring RoutesString() const + { + return std::accumulate(std::begin(Routes), std::end(Routes), std::wstring{}, [](const std::wstring& prev, const EndpointRoute& route) { + return route.ToString() + (prev.empty() ? L"" : L"," + prev); + }); + } + + std::wstring DnsServersString() const + { + return wsl::shared::string::Join(DnsServers, L','); + } + + // will return ULONG_MAX if there's no configured MTU + ULONG GetEffectiveMtu() const noexcept + { + return std::min(IPv4InterfaceMtu > 0 ? IPv4InterfaceMtu : ULONG_MAX, IPv6InterfaceMtu > 0 ? IPv6InterfaceMtu : ULONG_MAX); + } + + // will return zero if there's no configured metric + ULONG GetMinimumMetric() const noexcept + { + if (!IPv4InterfaceMetric.has_value() && !IPv6InterfaceMetric.has_value()) + { + return 0; + } + if (!IPv4InterfaceMetric.has_value()) + { + return IPv6InterfaceMetric.value(); + } + if (!IPv6InterfaceMetric.has_value()) + { + return IPv4InterfaceMetric.value(); + } + return std::min(IPv4InterfaceMetric.value(), IPv6InterfaceMetric.value()); + } +}; + +std::shared_ptr GetEndpointSettings(const wsl::shared::hns::HNSEndpoint& properties); +std::shared_ptr GetHostEndpointSettings(); + +#define TRACE_NETWORKSETTINGS_OBJECT(settings) \ + TraceLoggingValue((settings)->InterfaceGuid, "interfaceGuid"), TraceLoggingValue((settings)->InterfaceIndex, "interfaceIndex"), \ + TraceLoggingValue((settings)->InterfaceType, "interfaceType"), \ + TraceLoggingValue((settings)->IsConnected, "isConnected"), TraceLoggingValue((settings)->IsMetered, "isMetered"), \ + TraceLoggingValue((settings)->GetBestGatewayAddressString().c_str(), "bestGatewayAddress"), \ + TraceLoggingValue((settings)->PreferredIpAddress.AddressString.c_str(), "preferredIpAddress"), \ + TraceLoggingValue((settings)->PreferredIpAddress.PrefixLength, "preferredIpAddressPrefixLength"), \ + TraceLoggingValue((settings)->IpAddressesString().c_str(), "ipAddresses"), \ + TraceLoggingValue((settings)->RoutesString().c_str(), "routes"), \ + TraceLoggingValue((settings)->DnsServersString().c_str(), "dnsServerList"), \ + TraceLoggingValue((settings)->MacAddress.c_str(), "macAddress"), \ + TraceLoggingValue((settings)->IPv4InterfaceMtu, "IPv4InterfaceMtu"), \ + TraceLoggingValue((settings)->IPv6InterfaceMtu, "IPv6InterfaceMtu"), \ + TraceLoggingValue((settings)->IPv4InterfaceMetric.value_or(0xffffffff), "IPv4InterfaceMetric"), \ + TraceLoggingValue((settings)->IPv6InterfaceMetric.value_or(0xffffffff), "IPv6InterfaceMetric"), \ + TraceLoggingValue((settings)->PendingIPInterfaceUpdate, "PendingIPInterfaceUpdate"), \ + TraceLoggingValue((settings)->PendingUpdateToReconnectForMetered, "PendingUpdateToReconnectForMetered") + +} // namespace wsl::core::networking diff --git a/src/windows/common/wslutil.h b/src/windows/common/wslutil.h index cb97439..aa026e3 100644 --- a/src/windows/common/wslutil.h +++ b/src/windows/common/wslutil.h @@ -44,6 +44,7 @@ inline auto c_msixPackageFamilyName = L"MicrosoftCorporationII.WindowsSubsystemF inline auto c_githubUrlOverrideRegistryValue = L"GitHubUrlOverride"; inline auto c_vhdFileExtension = L".vhd"; inline auto c_vhdxFileExtension = L".vhdx"; +inline constexpr auto c_vmOwner = L"WSL"; // TODO-WSLA: Does this apply to WSLA ? struct GitHubReleaseAsset { diff --git a/src/windows/service/exe/CMakeLists.txt b/src/windows/service/exe/CMakeLists.txt index a4fcd4c..7ecdaef 100644 --- a/src/windows/service/exe/CMakeLists.txt +++ b/src/windows/service/exe/CMakeLists.txt @@ -1,6 +1,5 @@ set(SOURCES DistributionRegistration.cpp - LxssSecurity.cpp LxssUserCallback.cpp LxssUserSession.cpp LxssUserSessionFactory.cpp @@ -11,10 +10,6 @@ set(SOURCES PluginManager.cpp ServiceMain.cpp BridgedNetworking.cpp - DeviceHostProxy.cpp - Dmesg.cpp - DnsTunnelingChannel.cpp - GnsChannel.cpp GnsPortTrackerChannel.cpp GnsRpcServer.cpp GuestTelemetryLogger.cpp @@ -22,21 +17,12 @@ set(SOURCES LxssConsoleManager.cpp LxssCreateProcess.cpp MirroredNetworking.cpp - NatNetworking.cpp - RingBuffer.cpp - WslCoreFilesystem.cpp WslCoreGuestNetworkService.cpp - WslCoreHostDnsInfo.cpp WslCoreInstance.cpp WslMirroredNetworking.cpp - WslCoreNetworkEndpointSettings.cpp - DnsResolver.cpp WslCoreTcpIpStateTracking.cpp WslCoreVm.cpp VirtioNetworking.cpp - WSLAUserSession.cpp - WSLAUserSessionFactory.cpp - WSLAVirtualMachine.cpp main.rc ${CMAKE_CURRENT_BINARY_DIR}/../mc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wsleventschema.rc application.manifest) @@ -44,7 +30,6 @@ set(SOURCES set(HEADERS ../../inc/comservicehelper.h DistributionRegistration.h - LxssSecurity.h LxssUserCallback.h LxssUserSession.h LxssUserSessionFactory.h @@ -53,35 +38,20 @@ set(HEADERS PluginManager.h LxssInstance.h BridgedNetworking.h - DeviceHostProxy.h - Dmesg.h - DnsTunnelingChannel.h - GnsChannel.h GnsPortTrackerChannel.h GnsRpcServer.h GuestTelemetryLogger.h - INetworkingEngine.h IMirroredNetworkManager.h Lifetime.h LxssConsoleManager.h LxssCreateProcess.h MirroredNetworking.h - NatNetworking.h - RingBuffer.h - WslCoreFilesystem.h WslCoreGuestNetworkService.h - WslCoreHostDnsInfo.h WslCoreInstance.h - WslCoreMessageQueue.h WslMirroredNetworking.h WslCoreNetworkEndpoint.h - WslCoreNetworkEndpointSettings.h - DnsResolver.h WslCoreTcpIpStateTracking.h - WslCoreVm.h - WSLAUserSession.h - WSLAUserSessionFactory.h - WSLAVirtualMachine.h) + WslCoreVm.h) include_directories(${CMAKE_SOURCE_DIR}/src/windows/wslaclient) diff --git a/src/windows/service/exe/MirroredNetworking.cpp b/src/windows/service/exe/MirroredNetworking.cpp index 79b15cb..417006e 100644 --- a/src/windows/service/exe/MirroredNetworking.cpp +++ b/src/windows/service/exe/MirroredNetworking.cpp @@ -256,7 +256,7 @@ void MirroredNetworking::Initialize() m_config.FirewallConfig.Enabled(), m_config.IgnoredPorts, m_runtimeId, m_gnsRpcServer->GetServerUuid(), s_GuestNetworkServiceCallback, this); m_ephemeralPortRange = m_guestNetworkService.AllocateEphemeralPortRange(); - networking::ConfigureHyperVFirewall(m_config.FirewallConfig, c_vmOwner); + networking::ConfigureHyperVFirewall(m_config.FirewallConfig, wsl::windows::common::wslutil::c_vmOwner); // must keep all m_networkManager interactions (including) creation queued // also must queue GNS callbacks to keep them serialized diff --git a/src/windows/service/exe/ServiceMain.cpp b/src/windows/service/exe/ServiceMain.cpp index a41b3c9..ef97b8a 100644 --- a/src/windows/service/exe/ServiceMain.cpp +++ b/src/windows/service/exe/ServiceMain.cpp @@ -31,7 +31,6 @@ wil::unique_event g_networkingReady{wil::EventOptions::ManualReset}; // Declare the LxssUserSession COM class. CoCreatableClassWrlCreatorMapInclude(LxssUserSession); -CoCreatableClassWrlCreatorMapInclude(WSLAUserSession); struct WslServiceSecurityPolicy { @@ -241,7 +240,6 @@ void WslService::ServiceStopped() // Terminate all user sessions. ClearSessionsAndBlockNewInstances(); - wsl::windows::service::wsla::ClearWslaSessionsAndBlockNewInstances(); // Disconnect from the LxCore driver. if (g_lxcoreInitialized) diff --git a/src/windows/service/exe/WslCoreVm.cpp b/src/windows/service/exe/WslCoreVm.cpp index a29b3b5..602131a 100644 --- a/src/windows/service/exe/WslCoreVm.cpp +++ b/src/windows/service/exe/WslCoreVm.cpp @@ -1468,7 +1468,7 @@ void WslCoreVm::FreeLun(_In_ ULONG lun) std::wstring WslCoreVm::GenerateConfigJson() { hcs::ComputeSystem systemSettings{}; - systemSettings.Owner = c_vmOwner; + systemSettings.Owner = wsl::windows::common::wslutil::c_vmOwner; systemSettings.ShouldTerminateOnLastHandleClosed = true; systemSettings.SchemaVersion.Major = 2; systemSettings.SchemaVersion.Minor = 3; @@ -1575,7 +1575,7 @@ std::wstring WslCoreVm::GenerateConfigJson() // Set the vmmem suffix which will change the process name in task manager. if (helpers::IsVmemmSuffixSupported()) { - vmSettings.ComputeTopology.Memory.HostingProcessNameSuffix = c_vmOwner; + vmSettings.ComputeTopology.Memory.HostingProcessNameSuffix = wsl::windows::common::wslutil::c_vmOwner; } // If nested virtualization was requested, ensure the platform supports it. diff --git a/src/windows/service/exe/WslCoreVm.h b/src/windows/service/exe/WslCoreVm.h index 89f4e44..17ed189 100644 --- a/src/windows/service/exe/WslCoreVm.h +++ b/src/windows/service/exe/WslCoreVm.h @@ -39,8 +39,6 @@ inline constexpr auto c_optionsValueName = L"Options"; inline constexpr auto c_typeValueName = L"Type"; inline constexpr auto c_mountNameValueName = L"Name"; -inline constexpr auto c_vmOwner = L"WSL"; - static constexpr GUID c_virtiofsAdminClassId = {0x7e6ad219, 0xd1b3, 0x42d5, {0xb8, 0xee, 0xd9, 0x63, 0x24, 0xe6, 0x4f, 0xf6}}; // {60285AE6-AAF3-4456-B444-A6C2D0DEDA38} diff --git a/src/windows/service/inc/wslservice.idl b/src/windows/service/inc/wslservice.idl index 714003e..4a0b341 100644 --- a/src/windows/service/inc/wslservice.idl +++ b/src/windows/service/inc/wslservice.idl @@ -396,97 +396,3 @@ cpp_quote("#define WSL_E_DISK_CORRUPTED MAKE_HRESULT(SEVERITY_ERROR, FACILITY_IT cpp_quote("#define WSL_E_DISTRIBUTION_NAME_NEEDED MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSL_E_BASE + 0x30) /* 0x80040330 */") cpp_quote("#define WSL_E_INVALID_JSON MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSL_E_BASE + 0x31) /* 0x80040331 */") cpp_quote("#define WSL_E_VM_CRASHED MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSL_E_BASE + 0x32) /* 0x80040332 */") - - -typedef -struct _WSL_VERSION { - ULONG Major; - ULONG Minor; - ULONG Revision; -} WSL_VERSION; - - -typedef [system_handle(sh_socket)] HANDLE HVSOCKET_HANDLE; - - -typedef -struct _WSLA_CREATE_PROCESS_OPTIONS { - [string] LPCSTR Executable; - ULONG CommandLineCount; - [unique, size_is(CommandLineCount)] LPCSTR* CommandLine; - ULONG EnvironmentCount; - [unique, size_is(EnvironmentCount)] LPCSTR* Environment; - [unique] LPCSTR CurrentDirectory; -} WSLA_CREATE_PROCESS_OPTIONS; - -typedef struct _WSLA_PROCESS_FD -{ - LONG Fd; - int Type; - [string, unique] LPCSTR Path; -} WSLA_PROCESS_FD; - -typedef -struct _WSLA_CREATE_PROCESS_RESULT { - int Errno; - int Pid; -} WSLA_CREATE_PROCESS_RESULT; - -[ - uuid(7BC4E198-6531-4FA6-ADE2-5EF3D2A04DFE), - pointer_default(unique), - object -] -interface ITerminationCallback : IUnknown -{ - HRESULT OnTermination(ULONG Reason, LPCWSTR Details); -}; - -[ - uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8761), - pointer_default(unique), - object -] -interface IWSLAVirtualMachine : IUnknown -{ - HRESULT AttachDisk([in] LPCWSTR Path, [in] BOOL ReadOnly, [out] LPSTR* Device, [out] ULONG* Lun); - HRESULT Mount([in, unique] LPCSTR Source, [in] LPCSTR Target, [in] LPCSTR Type, [in] LPCSTR Options, [in] ULONG Flags); - HRESULT CreateLinuxProcess([in] const WSLA_CREATE_PROCESS_OPTIONS* Options, [in] ULONG FdCount, [in, unique, size_is(FdCount)] WSLA_PROCESS_FD* Fds, [out, size_is(FdCount)] ULONG* Handles, [out] WSLA_CREATE_PROCESS_RESULT* Result); - HRESULT WaitPid([in] LONG Pid, [in] ULONGLONG TimeoutMs, [out] ULONG* State, [out] int* Code); - HRESULT Signal([in] LONG Pid, [in] int Signal); - HRESULT Shutdown([in] ULONGLONG TimeoutMs); - HRESULT RegisterCallback([in] ITerminationCallback* terminationCallback); - HRESULT GetDebugShellPipe([out] LPWSTR* pipePath); - HRESULT MapPort([in] int Family, [in] short WindowsPort, [in] short LinuxPort, [in] BOOL Remove); - HRESULT Unmount([in] LPCSTR Path); - HRESULT DetachDisk([in] ULONG Lun); - HRESULT MountWindowsFolder([in] LPCWSTR WindowsPath, [in] LPCSTR LinuxPath, [in] BOOL ReadOnly); - HRESULT UnmountWindowsFolder([in] LPCSTR LinuxPath); - HRESULT MountGpuLibraries([in] LPCSTR LibrariesMountPoint, [in] LPCSTR DriversMountpoint, [in] DWORD Flags); -} - -typedef -struct _VIRTUAL_MACHINE_SETTINGS { - LPCWSTR DisplayName; - ULONGLONG MemoryMb; - ULONG CpuCount; - ULONG BootTimeoutMs; - ULONG DmesgOutput; - ULONG NetworkingMode; - BOOL EnableDnsTunneling; - BOOL EnableDebugShell; - BOOL EnableEarlyBootDmesg; - BOOL EnableGPU; -} VIRTUAL_MACHINE_SETTINGS; - - -[ - uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8760), - pointer_default(unique), - object -] -interface IWSLAUserSession : IUnknown -{ - HRESULT GetVersion([out] WSL_VERSION* Error); - HRESULT CreateVirtualMachine([in] const VIRTUAL_MACHINE_SETTINGS* Settings, [out]IWSLAVirtualMachine** VirtualMachine); -} \ No newline at end of file diff --git a/src/windows/wslaclient/CMakeLists.txt b/src/windows/wslaclient/CMakeLists.txt index 5fbc876..24d5a25 100644 --- a/src/windows/wslaclient/CMakeLists.txt +++ b/src/windows/wslaclient/CMakeLists.txt @@ -3,6 +3,7 @@ set(HEADERS WSLAApi.h) add_library(wslaclient SHARED ${SOURCES} ${HEADERS} wslaclient.def) set_target_properties(wslaclient PROPERTIES EXCLUDE_FROM_ALL FALSE) +add_dependencies(wslaclient wslaserviceidl) target_link_libraries(wslaclient ${COMMON_LINK_LIBRARIES} legacy_stdio_definitions common) target_precompile_headers(wslaclient REUSE_FROM common) set_target_properties(wslaclient PROPERTIES FOLDER windows) \ No newline at end of file diff --git a/src/windows/wslaclient/DllMain.cpp b/src/windows/wslaclient/DllMain.cpp index ed06dfe..65d5044 100644 --- a/src/windows/wslaclient/DllMain.cpp +++ b/src/windows/wslaclient/DllMain.cpp @@ -13,7 +13,7 @@ Abstract: --*/ #include "precomp.h" -#include "wslservice.h" +#include "wslaservice.h" #include "WSLAApi.h" #include "wslrelay.h" #include "wslInstall.h" diff --git a/src/windows/wslaservice/CMakeLists.txt b/src/windows/wslaservice/CMakeLists.txt new file mode 100644 index 0000000..2d69374 --- /dev/null +++ b/src/windows/wslaservice/CMakeLists.txt @@ -0,0 +1,6 @@ +add_compile_definitions("PROXY_CLSID_IS={0x4EA0C6DD,0xE9FF,0x48E7,{0x99,0x4e,0x13,0xa3,0x1d,0x10,0xdc,0x61}}") +add_compile_definitions("REGISTER_PROXY_DLL") + +add_subdirectory(exe) +add_subdirectory(inc) +add_subdirectory(stub) \ No newline at end of file diff --git a/src/windows/wslaservice/exe/CMakeLists.txt b/src/windows/wslaservice/exe/CMakeLists.txt new file mode 100644 index 0000000..8d51923 --- /dev/null +++ b/src/windows/wslaservice/exe/CMakeLists.txt @@ -0,0 +1,34 @@ +set(SOURCES + application.manifest + main.rc + ServiceMain.cpp + WSLAUserSession.cpp + WSLAUserSessionFactory.cpp + WSLAVirtualMachine.cpp + ) + +set(HEADERS + WSLAUserSession.h + WSLAUserSessionFactory.h + WSLAVirtualMachine.h) + +include_directories(${CMAKE_SOURCE_DIR}/src/windows/wslaclient) + +add_executable(wslaservice ${SOURCES} ${HEADERS}) +add_dependencies(wslaservice wslaserviceidl) +add_compile_definitions(__WRL_CLASSIC_COM__) +add_compile_definitions(__WRL_DISABLE_STATIC_INITIALIZE__) +add_compile_definitions(USE_COM_CONTEXT_DEF=1) +set_target_properties(wslaservice PROPERTIES LINK_FLAGS "/merge:minATL=.rdata /include:__minATLObjMap_WSLAUserSession_COM") +target_link_libraries(wslaservice + ${COMMON_LINK_LIBRARIES} + computecore + common + configfile + legacy_stdio_definitions + VirtDisk.lib + Winhttp.lib + Synchronization.lib) + +target_precompile_headers(wslaservice REUSE_FROM common) +set_target_properties(wslaservice PROPERTIES FOLDER windows) \ No newline at end of file diff --git a/src/windows/wslaservice/exe/ServiceMain.cpp b/src/windows/wslaservice/exe/ServiceMain.cpp new file mode 100644 index 0000000..7a78656 --- /dev/null +++ b/src/windows/wslaservice/exe/ServiceMain.cpp @@ -0,0 +1,126 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + ServiceMain.cpp + +Abstract: + + This file contains the entrypoint for the Lxss Manager service. + +--*/ + +#include "precomp.h" +#include "comservicehelper.h" +#include "LxssSecurity.h" +#include "WslCoreFilesystem.h" +#include "WSLAUserSessionFactory.h" +#include + +using namespace wsl::windows::common::registry; +using namespace wsl::windows::common::string; +using namespace wsl::windows::common::wslutil; +using namespace wsl::windows::policies; + +wil::unique_event g_networkingReady{wil::EventOptions::ManualReset}; + +// Declare the WSLAUserSession COM class. +CoCreatableClassWrlCreatorMapInclude(WSLAUserSession); + +struct WslaServiceSecurityPolicy +{ + static LPCWSTR GetSDDLText() + { + // COM Access and Launch permissions allowed for authenticated user, principal self, and system. + // 0xB = (COM_RIGHTS_EXECUTE | COM_RIGHTS_EXECUTE_LOCAL | COM_RIGHTS_ACTIVATE_LOCAL) + // N.B. This should be kept in sync with the security descriptors in the appxmanifest and package.wix. + return L"O:BAG:BAD:(A;;0xB;;;AU)(A;;0xB;;;PS)(A;;0xB;;;SY)"; + } +}; + +class WslaService : public Windows::Internal::Service +{ +public: + static wchar_t* GetName() + { + return const_cast(L"WSLAService"); + } + + static void OnSessionChanged(DWORD eventType, DWORD sessionId); + HRESULT OnServiceStarting(); + HRESULT ServiceStarted(); + void ServiceStopped(); + +private: + wil::unique_couninitialize_call m_coInit{false}; +}; + +HRESULT WslaService::OnServiceStarting() +try +{ + ConfigureCrt(); + + // Enable contextualized errors + wsl::windows::common::EnableContextualizedErrors(true); + + // Initialize telemetry. + // TODO-WSLA: Create a dedicated WSLA provider + WslTraceLoggingInitialize(WslServiceTelemetryProvider, !wsl::shared::OfficialBuild); + + WSL_LOG("Service starting", TraceLoggingLevel(WINEVENT_LEVEL_INFO)); + + // Don't kill the process on unknown C++ exceptions. + wil::g_fResultFailFastUnknownExceptions = false; + + // wsl::windows::common::security::ApplyProcessMitigationPolicies(); + + // Initialize Winsock. + WSADATA Data; + THROW_IF_WIN32_ERROR(WSAStartup(MAKEWORD(2, 2), &Data)); + + return S_OK; +} +CATCH_RETURN() + +HRESULT WslaService::ServiceStarted() +{ + m_coInit = wil::CoInitializeEx(COINIT_MULTITHREADED); + + return S_OK; +} + +void WslaService::OnSessionChanged(DWORD eventType, DWORD sessionId) +{ + if (eventType == WTS_SESSION_LOGOFF) + { + // TODO-WSLA: Implement for WSLA + // TerminateSession(sessionId); + } +} + +void WslaService::ServiceStopped() +{ + WSL_LOG("Service stopping", TraceLoggingLevel(WINEVENT_LEVEL_INFO)); + + // Terminate all user sessions. + wsl::windows::service::wsla::ClearWslaSessionsAndBlockNewInstances(); + + // There is a potential deadlock if CoUninitialize() is called before the LanguageChangeNotifyThread + // isn't done initializing. Clearing the COM objects before calling CoUninitialize() works around the issue. + winrt::clear_factory_cache(); + + // Tear down telemetry. + WslTraceLoggingUninitialize(); + + // uninitialize COM. This must be done here because this call can cause cleanups that will be fail + // if the CRT is shutting down. + m_coInit.reset(); +} + +int __cdecl wmain() +{ + WslaService::ProcessMain(); + return 0; +} diff --git a/src/windows/wslaservice/exe/WSLAUserSession.cpp b/src/windows/wslaservice/exe/WSLAUserSession.cpp new file mode 100644 index 0000000..deb6599 --- /dev/null +++ b/src/windows/wslaservice/exe/WSLAUserSession.cpp @@ -0,0 +1,87 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + WSLAUserSession.cpp + +Abstract: + + TODO + +--*/ +#include "WSLAUserSession.h" + +using wsl::windows::service::wsla::WSLAUserSessionImpl; + +WSLAUserSessionImpl::WSLAUserSessionImpl(HANDLE Token, wil::unique_tokeninfo_ptr&& TokenInfo) : + m_tokenInfo(std::move(TokenInfo)) +{ +} + +WSLAUserSessionImpl::~WSLAUserSessionImpl() +{ + // Manually signal the VM termination events. This prevents being stuck on an API call that holds the VM lock. + { + std::lock_guard lock(m_virtualMachinesLock); + + for (auto* e : m_virtualMachines) + { + e->OnSessionTerminating(); + } + } +} + +void WSLAUserSessionImpl::OnVmTerminated(WSLAVirtualMachine* machine) +{ + std::lock_guard lock(m_virtualMachinesLock); + auto pred = [machine](const auto* e) { return machine == e; }; + + // Remove any stale VM reference. + m_virtualMachines.erase(std::remove_if(m_virtualMachines.begin(), m_virtualMachines.end(), pred), m_virtualMachines.end()); +} + +HRESULT WSLAUserSessionImpl::CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine) +{ + auto vm = wil::MakeOrThrow(*Settings, GetUserSid(), this); + + { + std::lock_guard lock(m_virtualMachinesLock); + m_virtualMachines.emplace_back(vm.Get()); + } + + vm->Start(); + THROW_IF_FAILED(vm.CopyTo(__uuidof(IWSLAVirtualMachine), (void**)VirtualMachine)); + + return S_OK; +} + +PSID WSLAUserSessionImpl::GetUserSid() const +{ + return m_tokenInfo->User.Sid; +} + +wsl::windows::service::wsla::WSLAUserSession::WSLAUserSession(std::weak_ptr&& Session) : + m_session(std::move(Session)) +{ +} + +HRESULT wsl::windows::service::wsla::WSLAUserSession::GetVersion(_Out_ WSL_VERSION* Version) +{ + Version->Major = WSL_PACKAGE_VERSION_MAJOR; + Version->Minor = WSL_PACKAGE_VERSION_MINOR; + Version->Revision = WSL_PACKAGE_VERSION_REVISION; + + return S_OK; +} + +HRESULT wsl::windows::service::wsla::WSLAUserSession::CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine) +try +{ + auto session = m_session.lock(); + RETURN_HR_IF(RPC_E_DISCONNECTED, !session); + + return session->CreateVirtualMachine(Settings, VirtualMachine); +} +CATCH_RETURN(); \ No newline at end of file diff --git a/src/windows/wslaservice/exe/WSLAUserSession.h b/src/windows/wslaservice/exe/WSLAUserSession.h new file mode 100644 index 0000000..e1c1762 --- /dev/null +++ b/src/windows/wslaservice/exe/WSLAUserSession.h @@ -0,0 +1,56 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + WSLAUserSession.h + +Abstract: + + TODO + +--*/ +#pragma once +#include "WSLAVirtualMachine.h" + +namespace wsl::windows::service::wsla { + +class WSLAUserSessionImpl +{ +public: + WSLAUserSessionImpl(HANDLE Token, wil::unique_tokeninfo_ptr&& TokenInfo); + WSLAUserSessionImpl(WSLAUserSessionImpl&&) = default; + WSLAUserSessionImpl& operator=(WSLAUserSessionImpl&&) = default; + + ~WSLAUserSessionImpl(); + + PSID GetUserSid() const; + + HRESULT CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine); + + void OnVmTerminated(WSLAVirtualMachine* machine); + +private: + wil::unique_tokeninfo_ptr m_tokenInfo; + + std::recursive_mutex m_virtualMachinesLock; + std::vector m_virtualMachines; +}; + +class DECLSPEC_UUID("a9b7a1b9-0671-405c-95f1-e0612cb4ce8f") WSLAUserSession + : public Microsoft::WRL::RuntimeClass, IWSLAUserSession, IFastRundown> +{ +public: + WSLAUserSession(std::weak_ptr&& Session); + WSLAUserSession(const WSLAUserSession&) = delete; + WSLAUserSession& operator=(const WSLAUserSession&) = delete; + + IFACEMETHOD(GetVersion)(_Out_ WSL_VERSION* Version) override; + IFACEMETHOD(CreateVirtualMachine)(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine) override; + +private: + std::weak_ptr m_session; +}; + +} // namespace wsl::windows::service::wsla diff --git a/src/windows/wslaservice/exe/WSLAUserSessionFactory.cpp b/src/windows/wslaservice/exe/WSLAUserSessionFactory.cpp new file mode 100644 index 0000000..e13effe --- /dev/null +++ b/src/windows/wslaservice/exe/WSLAUserSessionFactory.cpp @@ -0,0 +1,82 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + WSLAUserSessionFactory.cpp + +Abstract: + + TODO + +--*/ +#include "precomp.h" + +#include "WSLAUserSessionFactory.h" +#include "WSLAUserSession.h" + +using wsl::windows::service::wsla::WSLAUserSessionFactory; +using wsl::windows::service::wsla::WSLAUserSessionImpl; + +CoCreatableClassWithFactory(WSLAUserSession, WSLAUserSessionFactory); + +static std::mutex g_mutex; +static std::optional>> g_sessions = + std::make_optional>>(); + +HRESULT WSLAUserSessionFactory::CreateInstance(_In_ IUnknown* pUnkOuter, _In_ REFIID riid, _Out_ void** ppCreated) +{ + RETURN_HR_IF_NULL(E_POINTER, ppCreated); + *ppCreated = nullptr; + + RETURN_HR_IF(CLASS_E_NOAGGREGATION, pUnkOuter != nullptr); + + WSL_LOG("WSLAUserSessionFactory", TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE)); + + try + { + const wil::unique_handle userToken = wsl::windows::common::security::GetUserToken(TokenImpersonation); + + // Get the session ID and SID of the client process. + DWORD sessionId{}; + DWORD length = 0; + THROW_IF_WIN32_BOOL_FALSE(::GetTokenInformation(userToken.get(), TokenSessionId, &sessionId, sizeof(sessionId), &length)); + + auto tokenInfo = wil::get_token_information(userToken.get()); + + std::lock_guard lock{g_mutex}; + + THROW_HR_IF(CO_E_SERVER_STOPPING, !g_sessions.has_value()); + + auto session = std::find_if(g_sessions->begin(), g_sessions->end(), [&tokenInfo](auto it) { + return EqualSid(it->GetUserSid(), &tokenInfo->User.Sid); + }); + + if (session == g_sessions->end()) + { + session = g_sessions->insert(g_sessions->end(), std::make_shared(userToken.get(), std::move(tokenInfo))); + } + + auto comInstance = wil::MakeOrThrow(std::weak_ptr(*session)); + + THROW_IF_FAILED(comInstance.CopyTo(riid, ppCreated)); + } + catch (...) + { + const auto result = wil::ResultFromCaughtException(); + + // Note: S_FALSE will cause COM to retry if the service is stopping. + return result == CO_E_SERVER_STOPPING ? S_FALSE : result; + } + + WSL_LOG("WSLAUserSessionFactory", TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE)); + + return S_OK; +} + +void wsl::windows::service::wsla::ClearWslaSessionsAndBlockNewInstances() +{ + std::lock_guard lock{g_mutex}; + g_sessions.reset(); +} \ No newline at end of file diff --git a/src/windows/wslaservice/exe/WSLAUserSessionFactory.h b/src/windows/wslaservice/exe/WSLAUserSessionFactory.h new file mode 100644 index 0000000..fbc516a --- /dev/null +++ b/src/windows/wslaservice/exe/WSLAUserSessionFactory.h @@ -0,0 +1,28 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + WSLAUserSessionFactory.h + +Abstract: + + TODO + +--*/ +#pragma once +#include + +namespace wsl::windows::service::wsla { + +class WSLAUserSessionFactory : public Microsoft::WRL::ClassFactory<> +{ +public: + WSLAUserSessionFactory() = default; + + STDMETHODIMP CreateInstance(_In_ IUnknown* pUnkOuter, _In_ REFIID riid, _Out_ void** ppCreated) override; +}; + +void ClearWslaSessionsAndBlockNewInstances(); +} // namespace wsl::windows::service::wsla \ No newline at end of file diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp new file mode 100644 index 0000000..900076d --- /dev/null +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp @@ -0,0 +1,1052 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + WSLAVirtualMachine.cpp + +Abstract: + + Class for the WSLA virtual machine. + +--*/ + +#include "WSLAVirtualMachine.h" +#include "hcs_schema.h" +#include "NatNetworking.h" +#include "WSLAUserSession.h" + +using namespace wsl::windows::common; +using helpers::WindowsBuildNumbers; +using helpers::WindowsVersion; +using wsl::windows::service::wsla::WSLAVirtualMachine; + +WSLAVirtualMachine::WSLAVirtualMachine(const VIRTUAL_MACHINE_SETTINGS& Settings, PSID UserSid, WSLAUserSessionImpl* Session) : + m_settings(Settings), m_userSid(UserSid), m_userSession(Session) +{ + THROW_IF_FAILED(CoCreateGuid(&m_vmId)); + + if (Settings.EnableDebugShell) + { + m_debugShellPipe = wsl::windows::common::wslutil::GetDebugShellPipeName(m_userSid) + m_settings.DisplayName; + } +} + +HRESULT WSLAVirtualMachine::GetDebugShellPipe(LPWSTR* pipePath) +{ + RETURN_HR_IF(E_INVALIDARG, m_debugShellPipe.empty()); + + *pipePath = wil::make_unique_string(m_debugShellPipe.c_str()).release(); + + return S_OK; +} + +void WSLAVirtualMachine::OnSessionTerminating() +{ + m_userSession = nullptr; + std::lock_guard mutex(m_lock); + + if (m_vmTerminatingEvent.is_signaled()) + { + return; + } + + WSL_LOG("WSLASignalTerminating", TraceLoggingValue(m_running, "running")); + + m_vmTerminatingEvent.SetEvent(); +} + +WSLAVirtualMachine::~WSLAVirtualMachine() +{ + { + std::lock_guard mutex(m_lock); + + if (m_userSession != nullptr) + { + m_userSession->OnVmTerminated(this); + } + } + + WSL_LOG("WSLATerminateVmStart", TraceLoggingValue(m_running, "running")); + + m_initChannel.Close(); + + bool forceTerminate = false; + + // Wait up to 5 seconds for the VM to terminate. + if (!m_vmExitEvent.wait(5000)) + { + forceTerminate = true; + try + { + wsl::windows::common::hcs::TerminateComputeSystem(m_computeSystem.get()); + } + CATCH_LOG() + } + + WSL_LOG("WSLATerminateVm", TraceLoggingValue(forceTerminate, "forced"), TraceLoggingValue(m_running, "running")); + + m_computeSystem.reset(); + + for (const auto& e : m_attachedDisks) + { + try + { + if (e.second.AccessGranted) + { + wsl::windows::common::hcs::RevokeVmAccess(m_vmIdString.c_str(), e.second.Path.c_str()); + } + } + CATCH_LOG() + } +} + +void WSLAVirtualMachine::Start() +{ + hcs::ComputeSystem systemSettings{}; + systemSettings.Owner = L"WSL"; + systemSettings.ShouldTerminateOnLastHandleClosed = true; + systemSettings.SchemaVersion.Major = 2; + systemSettings.SchemaVersion.Minor = 3; + hcs::VirtualMachine vmSettings{}; + vmSettings.StopOnReset = true; + vmSettings.Chipset.UseUtc = true; + + // Ensure the 2MB granularity enforced by HCS. + vmSettings.ComputeTopology.Memory.SizeInMB = m_settings.MemoryMb & ~0x1; + vmSettings.ComputeTopology.Memory.AllowOvercommit = true; + vmSettings.ComputeTopology.Memory.EnableDeferredCommit = true; + vmSettings.ComputeTopology.Memory.EnableColdDiscardHint = true; + + // Configure backing page size, fault cluster shift size, and cold discard hint size to favor density (lower vmmem usage). + // + // N.B. Cold discard hint size should be a multiple of the fault cluster shift size. + // + // N.B. This is only done on builds that have the fix for the VID deadlock on partition teardown. + if ((m_windowsVersion.BuildNumber >= WindowsBuildNumbers::Germanium) || + (m_windowsVersion.BuildNumber >= WindowsBuildNumbers::Cobalt && m_windowsVersion.UpdateBuildRevision >= 2360) || + (m_windowsVersion.BuildNumber >= WindowsBuildNumbers::Iron && m_windowsVersion.UpdateBuildRevision >= 1970) || + (m_windowsVersion.BuildNumber >= WindowsBuildNumbers::Vibranium_22H2 && m_windowsVersion.UpdateBuildRevision >= 3393)) + { + vmSettings.ComputeTopology.Memory.BackingPageSize = hcs::MemoryBackingPageSize::Small; + vmSettings.ComputeTopology.Memory.FaultClusterSizeShift = 4; // 64k + vmSettings.ComputeTopology.Memory.DirectMapFaultClusterSizeShift = 4; // 64k + m_coldDiscardShiftSize = 5; // 128k + } + else + { + m_coldDiscardShiftSize = 9; // 2MB + } + + // Configure the number of processors. + vmSettings.ComputeTopology.Processor.Count = m_settings.CpuCount; + + // Set the vmmem suffix which will change the process name in task manager. + if (helpers::IsVmemmSuffixSupported()) + { + vmSettings.ComputeTopology.Memory.HostingProcessNameSuffix = m_settings.DisplayName; + } + +#ifdef _AMD64_ + + HV_X64_HYPERVISOR_HARDWARE_FEATURES hardwareFeatures{}; + __cpuid(reinterpret_cast(&hardwareFeatures), HvCpuIdFunctionMsHvHardwareFeatures); + vmSettings.ComputeTopology.Processor.EnablePerfmonPmu = hardwareFeatures.ChildPerfmonPmuSupported != 0; + vmSettings.ComputeTopology.Processor.EnablePerfmonLbr = hardwareFeatures.ChildPerfmonLbrSupported != 0; + +#endif + + // Initialize kernel command line. + std::wstring kernelCmdLine = L"initrd=\\" LXSS_VM_MODE_INITRD_NAME L" " TEXT(WSLA_ROOT_INIT_ENV) L"=1 panic=-1"; + + // Set number of processors. + kernelCmdLine += std::format(L" nr_cpus={}", m_settings.CpuCount); + + // Enable timesync workaround to sync on resume from sleep in modern standby. + kernelCmdLine += L" hv_utils.timesync_implicit=1"; + + wil::unique_handle dmesgOutput; + if (m_settings.DmesgOutput != 0) + { + dmesgOutput.reset(wsl::windows::common::wslutil::DuplicateHandleFromCallingProcess(ULongToHandle(m_settings.DmesgOutput))); + } + + m_dmesgCollector = DmesgCollector::Create(m_vmId, m_vmExitEvent, true, false, L"", true, std::move(dmesgOutput)); + + if (m_settings.EnableEarlyBootDmesg) + { + kernelCmdLine += L" earlycon=uart8250,io,0x3f8,115200"; + vmSettings.Devices.ComPorts["0"] = hcs::ComPort{m_dmesgCollector->EarlyConsoleName()}; + } + + if (helpers::IsVirtioSerialConsoleSupported()) + { + vmSettings.Devices.VirtioSerial.emplace(); + + // The primary "console" will be a virtio serial device. + + kernelCmdLine += L" console=hvc0 debug"; + hcs::VirtioSerialPort virtioPort{}; + virtioPort.Name = L"hvc0"; + virtioPort.NamedPipe = m_dmesgCollector->VirtioConsoleName(); + virtioPort.ConsoleSupport = true; + vmSettings.Devices.VirtioSerial->Ports["0"] = std::move(virtioPort); + + if (!m_debugShellPipe.empty()) + { + hcs::VirtioSerialPort virtioPort; + virtioPort.Name = L"hvc1"; + virtioPort.NamedPipe = m_debugShellPipe; + virtioPort.ConsoleSupport = true; + vmSettings.Devices.VirtioSerial->Ports["1"] = std::move(virtioPort); + } + } + + // Set up boot params. + // + // N.B. Linux kernel direct boot is not yet supported on ARM64. + + auto basePath = wslutil::GetBasePath(); + +#ifdef WSL_KERNEL_PATH + + auto kernelPath = std::filesystem::path(WSL_KERNEL_PATH); + +#else + auto kernelPath = std::filesystem::path(basePath) / L"tools" / LXSS_VM_MODE_KERNEL_NAME; +#endif + + if constexpr (!wsl::shared::Arm64) + { + vmSettings.Chipset.LinuxKernelDirect.emplace(); + vmSettings.Chipset.LinuxKernelDirect->KernelFilePath = kernelPath.wstring(); + vmSettings.Chipset.LinuxKernelDirect->InitRdPath = (basePath / L"tools" / LXSS_VM_MODE_INITRD_NAME).c_str(); + vmSettings.Chipset.LinuxKernelDirect->KernelCmdLine = kernelCmdLine; + } + else + { + // TODO + THROW_HR(E_NOTIMPL); + auto bootThis = hcs::UefiBootEntry{}; + bootThis.DeviceType = hcs::UefiBootDevice::VmbFs; + // bootThis.VmbFsRootPath = m_rootFsPath.c_str(); + bootThis.DevicePath = L"\\" LXSS_VM_MODE_KERNEL_NAME; + bootThis.OptionalData = kernelCmdLine; + hcs::Uefi uefiSettings{}; + uefiSettings.BootThis = std::move(bootThis); + vmSettings.Chipset.Uefi = std::move(uefiSettings); + } + + // Initialize other devices. + vmSettings.Devices.Scsi["0"] = hcs::Scsi{}; + hcs::HvSocket hvSocketConfig{}; + + // Construct a security descriptor that allows system and the current user. + wil::unique_hlocal_string userSidString; + THROW_LAST_ERROR_IF(!ConvertSidToStringSidW(m_userSid, &userSidString)); + + std::wstring securityDescriptor{L"D:P(A;;FA;;;SY)(A;;FA;;;"}; + securityDescriptor += userSidString.get(); + securityDescriptor += L")"; + hvSocketConfig.HvSocketConfig.DefaultBindSecurityDescriptor = securityDescriptor; + hvSocketConfig.HvSocketConfig.DefaultConnectSecurityDescriptor = securityDescriptor; + vmSettings.Devices.HvSocket = std::move(hvSocketConfig); + + systemSettings.VirtualMachine = std::move(vmSettings); + auto json = wsl::shared::ToJsonW(systemSettings); + + WSL_LOG("CreateWSLAVirtualMachine", TraceLoggingValue(json.c_str(), "json")); + + m_vmIdString = wsl::shared::string::GuidToString(m_vmId, wsl::shared::string::GuidToStringFlags::Uppercase); + m_computeSystem = hcs::CreateComputeSystem(m_vmIdString.c_str(), json.c_str()); + + auto runtimeId = wsl::windows::common::hcs::GetRuntimeId(m_computeSystem.get()); + WI_ASSERT(IsEqualGUID(m_vmId, runtimeId)); + + wsl::windows::common::hcs::RegisterCallback(m_computeSystem.get(), &s_OnExit, this); + + wsl::windows::common::hcs::StartComputeSystem(m_computeSystem.get(), json.c_str()); + + // Create a socket listening for connections from mini_init. + auto listenSocket = wsl::windows::common::hvsocket::Listen(runtimeId, LX_INIT_UTILITY_VM_INIT_PORT); + auto socket = wsl::windows::common::hvsocket::Accept(listenSocket.get(), m_settings.BootTimeoutMs, m_vmTerminatingEvent.get()); + m_initChannel = wsl::shared::SocketChannel{std::move(socket), "mini_init", m_vmTerminatingEvent.get()}; + + ConfigureNetworking(); + + // Mount the kernel modules VHD. + +#ifdef WSL_KERNEL_MODULES_PATH + + auto kernelModulesPath = std::filesystem::path(TEXT(WSL_KERNEL_MODULES_PATH)); + +#else + + auto kernelModulesPath = basePath / L"tools" / L"modules.vhd"; + +#endif + + wil::unique_cotaskmem_ansistring device; + ULONG lun{}; + THROW_IF_FAILED(AttachDisk(kernelModulesPath.c_str(), true, &device, &lun)); + + THROW_HR_IF_MSG( + E_FAIL, + MountImpl(m_initChannel, device.get(), "", "ext4", "ro", WSLA_MOUNT::KernelModules) != 0, + "Failed to mount the kernel modules from: %hs", + device.get()); + + // Configure GPU if requested. + if (m_settings.EnableGPU) + { + hcs::ModifySettingRequest gpuRequest{}; + gpuRequest.ResourcePath = L"VirtualMachine/ComputeTopology/Gpu"; + gpuRequest.RequestType = hcs::ModifyRequestType::Update; + gpuRequest.Settings.AssignmentMode = hcs::GpuAssignmentMode::Mirror; + gpuRequest.Settings.AllowVendorExtension = true; + if (wsl::windows::common::helpers::IsDisableVgpuSettingsSupported()) + { + gpuRequest.Settings.DisableGdiAcceleration = true; + gpuRequest.Settings.DisablePresentation = true; + } + + wsl::windows::common::hcs::ModifyComputeSystem(m_computeSystem.get(), wsl::shared::ToJsonW(gpuRequest).c_str()); + } +} + +void WSLAVirtualMachine::ConfigureNetworking() +{ + if (m_settings.NetworkingMode == WslNetworkingModeNone) + { + return; + } + else if (m_settings.NetworkingMode == WslNetworkingModeNAT) + { + // Launch GNS + + WSLA_PROCESS_FD fd{}; + fd.Fd = 3; + fd.Type = WslFdType::WslFdTypeDefault; + + std::vector cmd{"/gns", LX_INIT_GNS_SOCKET_ARG, "3"}; + WSLA_CREATE_PROCESS_OPTIONS options{}; + options.Executable = "/init"; + options.CommandLine = cmd.data(); + options.CommandLineCount = static_cast(cmd.size()); + + std::vector socketHandles(2); + + WSLA_CREATE_PROCESS_RESULT result{}; + auto sockets = CreateLinuxProcessImpl(&options, 1, &fd, &result); + + THROW_HR_IF(E_FAIL, result.Errno != 0); + + // TODO: refactor this to avoid using wsl config + static wsl::core::Config config(nullptr); + + // TODO-WSLA: Implement firewall logic + /*if (!wsl::core::MirroredNetworking::IsHyperVFirewallSupported(config)) + { + config.FirewallConfig.reset(); + }*/ + + // TODO: DNS Tunneling support + m_networkEngine = std::make_unique( + m_computeSystem.get(), wsl::core::NatNetworking::CreateNetwork(config), std::move(sockets[0]), config, wil::unique_socket{}); + + m_networkEngine->Initialize(); + + LaunchPortRelay(); + } + else + { + THROW_HR_MSG(E_INVALIDARG, "Invalid networking mode: %lu", m_settings.NetworkingMode); + } +} + +void CALLBACK WSLAVirtualMachine::s_OnExit(_In_ HCS_EVENT* Event, _In_opt_ void* Context) +{ + if (Event->Type == HcsEventSystemExited || Event->Type == HcsEventSystemCrashInitiated || Event->Type == HcsEventSystemCrashReport) + { + reinterpret_cast(Context)->OnExit(Event); + } +} + +void WSLAVirtualMachine::OnExit(_In_ const HCS_EVENT* Event) +{ + WSL_LOG( + "WSLAVmExited", TraceLoggingValue(Event->EventData, "details"), TraceLoggingValue(static_cast(Event->Type), "type")); + + m_vmExitEvent.SetEvent(); + + std::lock_guard lock(m_lock); + if (m_terminationCallback) + { + // TODO: parse json and give a better error. + WslVirtualMachineTerminationReason reason = WslVirtualMachineTerminationReasonUnknown; + if (Event->Type == HcsEventSystemExited) + { + reason = WslVirtualMachineTerminationReasonShutdown; + } + else if (Event->Type == HcsEventSystemCrashInitiated || Event->Type == HcsEventSystemCrashReport) + { + reason = WslVirtualMachineTerminationReasonCrashed; + } + + LOG_IF_FAILED(m_terminationCallback->OnTermination(static_cast(reason), Event->EventData)); + } +} + +HRESULT WSLAVirtualMachine::AttachDisk(_In_ PCWSTR Path, _In_ BOOL ReadOnly, _Out_ LPSTR* Device, _Out_ ULONG* Lun) +try +{ + *Device = nullptr; + auto result = wil::ResultFromException([&]() { + std::lock_guard lock{m_lock}; + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), m_running); + + AttachedDisk disk{Path}; + + auto grantDiskAccess = [&]() { + const auto userToken = wsl::windows::common::security::GetUserToken(TokenImpersonation); + auto runAsUser = wil::impersonate_token(userToken.get()); + wsl::windows::common::hcs::GrantVmAccess(m_vmIdString.c_str(), Path); + disk.AccessGranted = true; + }; + + if (!ReadOnly) + { + grantDiskAccess(); + } + + *Lun = 0; + while (m_attachedDisks.find(*Lun) != m_attachedDisks.end()) + { + (*Lun)++; + } + + bool vhdAdded = false; + auto cleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { + if (vhdAdded) + { + wsl::windows::common::hcs::RemoveScsiDisk(m_computeSystem.get(), *Lun); + } + + if (disk.AccessGranted) + { + wsl::windows::common::hcs::RevokeVmAccess(m_vmIdString.c_str(), Path); + } + }); + + auto result = + wil::ResultFromException([&]() { wsl::windows::common::hcs::AddVhd(m_computeSystem.get(), Path, *Lun, ReadOnly); }); + + if (result == HRESULT_FROM_WIN32(ERROR_ACCESS_DENIED) && !disk.AccessGranted) + { + grantDiskAccess(); + wsl::windows::common::hcs::AddVhd(m_computeSystem.get(), Path, *Lun, ReadOnly); + } + else + { + THROW_IF_FAILED(result); + } + + vhdAdded = true; + + WSLA_GET_DISK message{}; + message.Header.MessageSize = sizeof(message); + message.Header.MessageType = WSLA_GET_DISK::Type; + message.ScsiLun = *Lun; + const auto& response = m_initChannel.Transaction(message); + + THROW_HR_IF_MSG(E_FAIL, response.Result != 0, "Failed to attach disk, init returned: %lu", response.Result); + + cleanup.release(); + + disk.Device = response.Buffer; + m_attachedDisks.emplace(*Lun, std::move(disk)); + + *Device = wil::make_unique_ansistring(response.Buffer).release(); + }); + + WSL_LOG( + "WSLAAttachDisk", + TraceLoggingValue(Path, "Path"), + TraceLoggingValue(ReadOnly, "ReadOnly"), + TraceLoggingValue(*Device == nullptr ? "" : *Device, "Device"), + TraceLoggingValue(result, "Result")); + + return result; +} +CATCH_RETURN(); + +HRESULT WSLAVirtualMachine::Mount(_In_ LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags) +try +{ + THROW_HR_IF(E_INVALIDARG, WI_IsAnyFlagSet(Flags, ~(WslMountFlagsChroot | WslMountFlagsWriteableOverlayFs))); + + std::lock_guard lock{m_lock}; + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), m_running); + + THROW_HR_IF(E_FAIL, MountImpl(m_initChannel, Source, Target, Type, Options, Flags) != 0); + + return S_OK; +} +CATCH_RETURN(); + +HRESULT WSLAVirtualMachine::Unmount(_In_ const char* Path) +try +{ + auto [pid, _, subChannel] = Fork(WSLA_FORK::Thread); + + wsl::shared::MessageWriter message; + message.WriteString(Path); + + const auto& response = subChannel.Transaction(message.Span()); + + // TODO: Return errno to caller + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_NOT_FOUND), response.Result == EINVAL); + THROW_HR_IF(E_FAIL, response.Result != 0); + + return S_OK; +} +CATCH_RETURN() + +HRESULT WSLAVirtualMachine::DetachDisk(_In_ ULONG Lun) +try +{ + std::lock_guard lock{m_lock}; + + // Find the disk + auto it = m_attachedDisks.find(Lun); + RETURN_HR_IF(HRESULT_FROM_WIN32(ERROR_NOT_FOUND), it == m_attachedDisks.end()); + + // Detach it from the guest + WSLA_DETACH message; + message.Lun = Lun; + const auto& response = m_initChannel.Transaction(message); + + // TODO: Return errno to caller + THROW_HR_IF(E_FAIL, response.Result != 0); + + // Remove it from the VM + m_attachedDisks.erase(it); + + hcs::RemoveScsiDisk(m_computeSystem.get(), Lun); + + return S_OK; +} +CATCH_RETURN() + +std::tuple WSLAVirtualMachine::Fork(enum WSLA_FORK::ForkType Type) +{ + std::lock_guard lock{m_lock}; + return Fork(m_initChannel, Type); +} + +std::tuple WSLAVirtualMachine::Fork(wsl::shared::SocketChannel& Channel, enum WSLA_FORK::ForkType Type) +{ + uint32_t port{}; + int32_t pid{}; + int32_t ptyMaster{}; + { + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), m_running); + + WSLA_FORK message; + message.ForkType = Type; + message.TtyColumns = 80; + message.TtyRows = 80; + const auto& response = Channel.Transaction(message); + port = response.Port; + pid = response.Pid; + ptyMaster = response.PtyMasterFd; + } + + THROW_HR_IF_MSG(E_FAIL, pid <= 0, "fork() returned %i", pid); + + auto socket = wsl::windows::common::hvsocket::Connect(m_vmId, port, m_vmExitEvent.get(), m_settings.BootTimeoutMs); + + return std::make_tuple(pid, ptyMaster, wsl::shared::SocketChannel{std::move(socket), std::to_string(pid), m_vmTerminatingEvent.get()}); +} + +wil::unique_socket WSLAVirtualMachine::ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd) +{ + WSLA_ACCEPT message{}; + message.Fd = Fd; + const auto& response = Channel.Transaction(message); + + return wsl::windows::common::hvsocket::Connect(m_vmId, response.Result); +} + +void WSLAVirtualMachine::OpenLinuxFile(wsl::shared::SocketChannel& Channel, const char* Path, uint32_t Flags, int32_t Fd) +{ + static_assert(WslFdTypeLinuxFileInput == WslaOpenFlagsRead); + static_assert(WslFdTypeLinuxFileOutput == WslaOpenFlagsWrite); + static_assert(WslFdTypeLinuxFileAppend == WslaOpenFlagsAppend); + static_assert(WslFdTypeLinuxFileCreate == WslaOpenFlagsCreate); + + shared::MessageWriter message; + message->Fd = Fd; + message->Flags = Flags; + message.WriteString(Path); + + auto result = Channel.Transaction(message.Span()).Result; + + THROW_HR_IF_MSG(E_FAIL, result != 0, "Failed to open %hs (flags: %u), %i", Path, Flags, result); +} + +HRESULT WSLAVirtualMachine::CreateLinuxProcess( + _In_ const WSLA_CREATE_PROCESS_OPTIONS* Options, ULONG FdCount, WSLA_PROCESS_FD* Fds, _Out_ ULONG* Handles, _Out_ WSLA_CREATE_PROCESS_RESULT* Result) +try +{ + auto sockets = CreateLinuxProcessImpl(Options, FdCount, Fds, Result); + + for (size_t i = 0; i < sockets.size(); i++) + { + if (sockets[i]) + { + Handles[i] = + HandleToUlong(wsl::windows::common::wslutil::DuplicateHandleToCallingProcess(reinterpret_cast(sockets[i].get()))); + } + } + + return S_OK; +} +CATCH_RETURN(); + +std::vector WSLAVirtualMachine::CreateLinuxProcessImpl( + _In_ const WSLA_CREATE_PROCESS_OPTIONS* Options, _In_ ULONG FdCount, _In_ WSLA_PROCESS_FD* Fds, _Out_ WSLA_CREATE_PROCESS_RESULT* Result) +{ + // Check if this is a tty or not + const WSLA_PROCESS_FD* ttyInput = nullptr; + const WSLA_PROCESS_FD* ttyOutput = nullptr; + auto interactiveTty = ParseTtyInformation(Fds, FdCount, &ttyInput, &ttyOutput); + auto [pid, _, childChannel] = Fork(WSLA_FORK::Process); + + std::vector sockets(FdCount); + for (size_t i = 0; i < FdCount; i++) + { + if (Fds[i].Type == WslFdTypeDefault || Fds[i].Type == WslFdTypeTerminalInput || Fds[i].Type == WslFdTypeTerminalOutput) + { + THROW_HR_IF_MSG(E_INVALIDARG, Fds[i].Type > WslFdTypeTerminalOutput, "Invalid flags: %i", Fds[i].Type); + THROW_HR_IF_MSG(E_INVALIDARG, Fds[i].Path != nullptr, "Fd[%zu] has a non-null path but flags: %i", i, Fds[i].Type); + sockets[i] = ConnectSocket(childChannel, static_cast(Fds[i].Fd)); + } + else + { + THROW_HR_IF_MSG( + E_INVALIDARG, + WI_IsAnyFlagSet(Fds[i].Type, WslFdTypeTerminalInput | WslFdTypeTerminalOutput), + "Invalid flags: %i", + Fds[i].Type); + + THROW_HR_IF_MSG(E_INVALIDARG, Fds[i].Path == nullptr, "Fd[%zu] has a null path but flags: %i", i, Fds[i].Type); + OpenLinuxFile(childChannel, Fds[i].Path, Fds[i].Type, Fds[i].Fd); + } + } + + wsl::shared::MessageWriter Message; + + Message.WriteString(Message->ExecutableIndex, Options->Executable); + Message.WriteString(Message->CurrentDirectoryIndex, Options->CurrentDirectory ? Options->CurrentDirectory : "/"); + Message.WriteStringArray(Message->CommandLineIndex, Options->CommandLine, Options->CommandLineCount); + Message.WriteStringArray(Message->EnvironmentIndex, Options->Environment, Options->EnvironmentCount); + + // If this is an interactive tty, we need a relay process + if (interactiveTty) + { + auto [grandChildPid, ptyMaster, grandChildChannel] = Fork(childChannel, WSLA_FORK::Pty); + WSLA_TTY_RELAY relayMessage; + relayMessage.TtyMaster = ptyMaster; + relayMessage.TtyInput = ttyInput->Fd; + relayMessage.TtyOutput = ttyOutput->Fd; + childChannel.SendMessage(relayMessage); + + auto result = ExpectClosedChannelOrError(childChannel); + if (result != 0) + { + Result->Errno = result; + THROW_HR(E_FAIL); + } + + grandChildChannel.SendMessage(Message.Span()); + result = ExpectClosedChannelOrError(grandChildChannel); + if (result != 0) + { + Result->Errno = result; + THROW_HR(E_FAIL); + } + + pid = grandChildPid; + } + else + { + childChannel.SendMessage(Message.Span()); + auto result = ExpectClosedChannelOrError(childChannel); + if (result != 0) + { + Result->Errno = result; + THROW_HR(E_FAIL); + } + } + + Result->Errno = 0; + Result->Pid = pid; + return sockets; +} + +int32_t WSLAVirtualMachine::MountImpl(shared::SocketChannel& Channel, LPCSTR Source, LPCSTR Target, LPCSTR Type, LPCSTR Options, ULONG Flags) +{ + static_assert(WslMountFlagsNone == WSLA_MOUNT::None); + static_assert(WslMountFlagsChroot == WSLA_MOUNT::Chroot); + static_assert(WslMountFlagsWriteableOverlayFs == WSLA_MOUNT::OverlayFs); + + wsl::shared::MessageWriter message; + + auto optionalAdd = [&](auto value, unsigned int& index) { + if (value != nullptr) + { + message.WriteString(index, value); + } + }; + + optionalAdd(Source, message->SourceIndex); + optionalAdd(Target, message->DestinationIndex); + optionalAdd(Type, message->TypeIndex); + optionalAdd(Options, message->OptionsIndex); + message->Flags = Flags; + + const auto& response = Channel.Transaction(message.Span()); + + WSL_LOG( + "WSLAMount", + TraceLoggingValue(Source == nullptr ? "" : Source, "Source"), + TraceLoggingValue(Target == nullptr ? "" : Target, "Target"), + TraceLoggingValue(Type == nullptr ? "" : Type, "Type"), + TraceLoggingValue(Options == nullptr ? "" : Options, "Options"), + TraceLoggingValue(Flags, "Flags"), + TraceLoggingValue(response.Result, "Result")); + + return response.Result; +} + +int32_t WSLAVirtualMachine::ExpectClosedChannelOrError(wsl::shared::SocketChannel& Channel) +{ + auto [response, span] = Channel.ReceiveMessageOrClosed>(); + if (response != nullptr) + { + return response->Result; + } + else + { + return 0; + } +} + +HRESULT WSLAVirtualMachine::WaitPid(LONG Pid, ULONGLONG TimeoutMs, ULONG* State, int* Code) +try +{ + auto [pid, _, subChannel] = Fork(WSLA_FORK::Thread); + + WSLA_WAITPID message{}; + message.Pid = Pid; + message.TimeoutMs = TimeoutMs; + + const auto& response = subChannel.Transaction(message); + + THROW_HR_IF(E_FAIL, response.State == WSLAOpenFlagsUnknown); + + *State = response.State; + *Code = response.Code; + + return S_OK; +} +CATCH_RETURN(); + +HRESULT WSLAVirtualMachine::Shutdown(ULONGLONG TimeoutMs) +try +{ + std::lock_guard lock(m_lock); + + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), m_running); + + WSLA_SHUTDOWN message{}; + m_initChannel.SendMessage(message); + auto response = m_initChannel.ReceiveMessageOrClosed(static_cast(TimeoutMs)); + + RETURN_HR_IF(E_UNEXPECTED, response.first != nullptr); + + m_running = false; + return S_OK; +} +CATCH_RETURN(); + +HRESULT WSLAVirtualMachine::Signal(_In_ LONG Pid, _In_ int Signal) +try +{ + std::lock_guard lock(m_lock); + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), m_running); + + WSLA_SIGNAL message; + message.Pid = Pid; + message.Signal = Signal; + const auto& response = m_initChannel.Transaction(message); + + RETURN_HR_IF(E_FAIL, response.Result != 0); + return S_OK; +} +CATCH_RETURN(); + +HRESULT WSLAVirtualMachine::RegisterCallback(ITerminationCallback* callback) +try +{ + std::lock_guard lock(m_lock); + + THROW_HR_IF(E_INVALIDARG, m_terminationCallback); + + // N.B. this calls AddRef() on the callback + m_terminationCallback = callback; + + return S_OK; +} +CATCH_RETURN(); + +bool WSLAVirtualMachine::ParseTtyInformation(const WSLA_PROCESS_FD* Fds, ULONG FdCount, const WSLA_PROCESS_FD** TtyInput, const WSLA_PROCESS_FD** TtyOutput) +{ + bool foundNonTtyFd = false; + + for (ULONG i = 0; i < FdCount; i++) + { + if (Fds[i].Type == WslFdTypeTerminalInput) + { + THROW_HR_IF_MSG(E_INVALIDARG, *TtyInput != nullptr, "Only one TtyInput fd can be passed. Index=%lu", i); + + *TtyInput = &Fds[i]; + } + else if (Fds[i].Type == WslFdTypeTerminalOutput) + { + THROW_HR_IF_MSG(E_INVALIDARG, *TtyOutput != nullptr, "Only one TtyOutput fd can be passed. Index=%lu", i); + *TtyOutput = &Fds[i]; + } + else + { + foundNonTtyFd = true; + } + } + + THROW_HR_IF_MSG( + E_INVALIDARG, foundNonTtyFd && (*TtyOutput != nullptr || *TtyInput != nullptr), "Found mixed tty & non tty fds"); + + return !foundNonTtyFd && FdCount > 0; +} + +void WSLAVirtualMachine::LaunchPortRelay() +{ + WI_ASSERT(!m_portRelayChannelRead); + + auto [_, __, channel] = Fork(WSLA_FORK::ForkType::Process); + + std::lock_guard lock(m_portRelaylock); + auto relayPort = channel.Transaction(); + + wil::unique_handle readPipe; + wil::unique_handle writePipe; + THROW_IF_WIN32_BOOL_FALSE(CreatePipe(&readPipe, &m_portRelayChannelWrite, nullptr, 0)); + THROW_IF_WIN32_BOOL_FALSE(CreatePipe(&m_portRelayChannelRead, &writePipe, nullptr, 0)); + + wsl::windows::common::helpers::SetHandleInheritable(readPipe.get()); + wsl::windows::common::helpers::SetHandleInheritable(writePipe.get()); + wsl::windows::common::helpers::SetHandleInheritable(m_vmExitEvent.get()); + + // Get an impersonation token + auto userToken = wsl::windows::common::security::GetUserToken(TokenImpersonation); + auto restrictedToken = wsl::windows::common::security::CreateRestrictedToken(userToken.get()); + + auto path = wsl::windows::common::wslutil::GetBasePath() / L"wslrelay.exe"; + + auto cmd = std::format( + L"\"{}\" {} {} {} {} {} {} {} {}", + path, + wslrelay::mode_option, + static_cast(wslrelay::RelayMode::WSLAPortRelay), + wslrelay::exit_event_option, + HandleToUlong(m_vmExitEvent.get()), + wslrelay::port_option, + relayPort.Result, + wslrelay::vm_id_option, + m_vmId); + + WSL_LOG("LaunchWslRelay", TraceLoggingValue(cmd.c_str(), "cmd")); + + wsl::windows::common::SubProcess process{nullptr, cmd.c_str()}; + process.SetStdHandles(readPipe.get(), writePipe.get(), nullptr); + process.SetToken(restrictedToken.get()); + process.Start(); + + readPipe.release(); + writePipe.release(); +} + +HRESULT WSLAVirtualMachine::MapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ BOOL Remove) +try +{ + std::lock_guard lock(m_portRelaylock); + + RETURN_HR_IF(E_ILLEGAL_STATE_CHANGE, !m_portRelayChannelWrite); + + WSLA_MAP_PORT message; + message.WindowsPort = WindowsPort; + message.LinuxPort = LinuxPort; + message.AddressFamily = Family; + message.Stop = Remove; + + DWORD bytesTransfered{}; + THROW_IF_WIN32_BOOL_FALSE(WriteFile(m_portRelayChannelWrite.get(), &message, sizeof(message), &bytesTransfered, nullptr)); + THROW_HR_IF_MSG(E_UNEXPECTED, bytesTransfered != sizeof(message), "%u bytes transfered", bytesTransfered); + + HRESULT result = E_UNEXPECTED; + THROW_IF_WIN32_BOOL_FALSE(ReadFile(m_portRelayChannelRead.get(), &result, sizeof(result), &bytesTransfered, nullptr)); + + THROW_HR_IF(E_UNEXPECTED, bytesTransfered != sizeof(result)); + + return result; +} +CATCH_RETURN(); + +HRESULT WSLAVirtualMachine::MountWindowsFolder(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly) +{ + return MountWindowsFolderImpl(WindowsPath, LinuxPath, ReadOnly, WslMountFlagsNone); +} + +HRESULT WSLAVirtualMachine::MountWindowsFolderImpl(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly, _In_ WslMountFlags Flags) +try +{ + std::filesystem::path Path(WindowsPath); + THROW_HR_IF_MSG(E_INVALIDARG, !Path.is_absolute(), "Path is not absolute: '%ls'", WindowsPath); + THROW_HR_IF_MSG( + HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND), !std::filesystem::is_directory(Path), "Path is not a directory: '%ls'", WindowsPath); + + GUID shareGuid{}; + THROW_IF_FAILED(CoCreateGuid(&shareGuid)); + + auto shareName = shared::string::GuidToString(shareGuid, shared::string::None); + + { + // Create the plan9 share on the host + std::lock_guard lock(m_lock); + + // Verify that this folder isn't already mounted. + auto it = m_plan9Mounts.find(LinuxPath); + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS), it != m_plan9Mounts.end()); + + hcs::AddPlan9Share( + m_computeSystem.get(), + shareName.c_str(), + shareName.c_str(), + WindowsPath, + LX_INIT_UTILITY_VM_PLAN9_PORT, + hcs::Plan9ShareFlags::AllowOptions | (ReadOnly ? hcs::Plan9ShareFlags::ReadOnly : hcs::Plan9ShareFlags::None), + wsl::windows::common::security::GetUserToken(TokenImpersonation).get()); + + m_plan9Mounts.emplace(LinuxPath, shareName); + } + + auto deleteOnFailure = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { + std::lock_guard lock(m_lock); + + LOG_HR_IF(E_UNEXPECTED, m_plan9Mounts.erase(LinuxPath) != 1); + }); + + // Create the guest mount + auto [_, __, channel] = Fork(WSLA_FORK::Thread); + + WSLA_CONNECT message; + message.HostPort = LX_INIT_UTILITY_VM_PLAN9_PORT; + + auto fd = channel.Transaction(message).Result; + THROW_HR_IF_MSG(E_FAIL, fd < 0, "WSLA_CONNECT failed with %i", fd); + + auto shareNameUtf8 = shared::string::WideToMultiByte(shareName); + auto mountOptions = + std::format("msize={},trans=fd,rfdno={},wfdno={},aname={},cache=mmap", LX_INIT_UTILITY_VM_PLAN9_BUFFER_SIZE, fd, fd, shareNameUtf8); + + THROW_HR_IF(E_FAIL, MountImpl(channel, shareNameUtf8.c_str(), LinuxPath, "9p", mountOptions.c_str(), Flags) != 0); + + deleteOnFailure.release(); + return S_OK; +} +CATCH_RETURN(); + +HRESULT WSLAVirtualMachine::UnmountWindowsFolder(_In_ LPCSTR LinuxPath) +try +{ + std::lock_guard lock(m_lock); + + // Verify that this folder is mounted. + auto it = m_plan9Mounts.find(LinuxPath); + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_NOT_FOUND), it == m_plan9Mounts.end()); + + // Unmount the folder from the guest. If the mount is not found, this most likely means that the guest unmounted it. + auto result = Unmount(LinuxPath); + THROW_HR_IF(result, FAILED(result) && result != HRESULT_FROM_WIN32(ERROR_NOT_FOUND)); + + // Remove the share from the host + hcs::RemovePlan9Share(m_computeSystem.get(), it->second.c_str(), LX_INIT_UTILITY_VM_PLAN9_PORT); + + m_plan9Mounts.erase(it); + + return S_OK; +} +CATCH_RETURN(); + +HRESULT WSLAVirtualMachine::MountGpuLibraries(_In_ LPCSTR LibrariesMountPoint, _In_ LPCSTR DriversMountpoint, _In_ DWORD Flags) +try +{ + RETURN_HR_IF_MSG(E_INVALIDARG, WI_IsAnyFlagSet(Flags, ~WslMountFlagsWriteableOverlayFs), "Unexpected flags: %lu", Flags); + + RETURN_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_CONFIG_VALUE), !m_settings.EnableGPU); + + auto [channel, _, __] = Fork(WSLA_FORK::Thread); + + auto windowsPath = wil::GetWindowsDirectoryW(); + + // Mount drivers. + RETURN_IF_FAILED(MountWindowsFolderImpl( + std::format(L"{}\\System32\\DriverStore\\FileRepository", windowsPath).c_str(), DriversMountpoint, true, static_cast(Flags))); + + // Mount the inbox libraries. + auto inboxLibPath = std::format(L"{}\\System32\\lxss\\lib", windowsPath); + std::optional inboxLibMountPoint; + if (std::filesystem::is_directory(inboxLibPath)) + { + inboxLibMountPoint = std::format("{}/inbox", LibrariesMountPoint); + RETURN_IF_FAILED(MountWindowsFolder(inboxLibPath.c_str(), inboxLibMountPoint->c_str(), true)); + } + + // Mount the packaged libraries. + +#ifdef WSL_GPU_LIB_PATH + + auto packagedLibPath = std::filesystem::path(TEXT(WSL_GPU_LIB_PATH)); + +#else + + auto packagedLibPath = wslutil::GetBasePath() / L"lib"; + +#endif + + auto packagedLibMountPoint = std::format("{}/packaged", LibrariesMountPoint); + RETURN_IF_FAILED(MountWindowsFolder(packagedLibPath.c_str(), packagedLibMountPoint.c_str(), true)); + + // Mount an overlay containing both inbox and packaged libraries (the packaged mount takes precedence). + std::string options = "lowerdir=" + packagedLibMountPoint; + if (inboxLibMountPoint.has_value()) + { + options += ":" + inboxLibMountPoint.value(); + } + + RETURN_IF_FAILED(Mount("none", LibrariesMountPoint, "overlay", options.c_str(), Flags)); + return S_OK; +} +CATCH_RETURN(); \ No newline at end of file diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.h b/src/windows/wslaservice/exe/WSLAVirtualMachine.h new file mode 100644 index 0000000..2ae3aa2 --- /dev/null +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.h @@ -0,0 +1,106 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + WSLAVirtualMachine.h + +Abstract: + + TODO + +--*/ +#pragma once +#include "wslaservice.h" +#include "INetworkingEngine.h" +#include "hcs.hpp" +#include "Dmesg.h" +#include "WSLAApi.h" + +namespace wsl::windows::service::wsla { + +class WSLAUserSessionImpl; + +class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") WSLAVirtualMachine + : public Microsoft::WRL::RuntimeClass, IWSLAVirtualMachine, IFastRundown> + +{ +public: + WSLAVirtualMachine(const VIRTUAL_MACHINE_SETTINGS& Settings, PSID Sid, WSLAUserSessionImpl* UserSession); + ~WSLAVirtualMachine(); + + void Start(); + void OnSessionTerminating(); + + IFACEMETHOD(AttachDisk(_In_ PCWSTR Path, _In_ BOOL ReadOnly, _Out_ LPSTR* Device, _Out_ ULONG* Lun)) override; + IFACEMETHOD(Mount(_In_ LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags)) override; + IFACEMETHOD(CreateLinuxProcess( + _In_ const WSLA_CREATE_PROCESS_OPTIONS* Options, _In_ ULONG FdCount, _In_ WSLA_PROCESS_FD* Fd, _Out_ ULONG* Handles, _Out_ WSLA_CREATE_PROCESS_RESULT* Result)) override; + IFACEMETHOD(WaitPid(_In_ LONG Pid, _In_ ULONGLONG TimeoutMs, _Out_ ULONG* State, _Out_ int* Code)) override; + IFACEMETHOD(Signal(_In_ LONG Pid, _In_ int Signal)) override; + IFACEMETHOD(Shutdown(ULONGLONG _In_ TimeoutMs)) override; + IFACEMETHOD(RegisterCallback(_In_ ITerminationCallback* callback)) override; + IFACEMETHOD(GetDebugShellPipe(_Out_ LPWSTR* pipePath)) override; + IFACEMETHOD(MapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ BOOL Remove)) override; + IFACEMETHOD(Unmount(_In_ const char* Path)) override; + IFACEMETHOD(DetachDisk(_In_ ULONG Lun)) override; + IFACEMETHOD(MountWindowsFolder(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly)) override; + IFACEMETHOD(UnmountWindowsFolder(_In_ LPCSTR LinuxPath)) override; + IFACEMETHOD(MountGpuLibraries(_In_ LPCSTR LibrariesMountPoint, _In_ LPCSTR DriversMountpoint, _In_ DWORD Flags)) override; + +private: + static int32_t MountImpl(wsl::shared::SocketChannel& Channel, LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags); + static void CALLBACK s_OnExit(_In_ HCS_EVENT* Event, _In_opt_ void* Context); + static bool ParseTtyInformation(const WSLA_PROCESS_FD* Fds, ULONG FdCount, const WSLA_PROCESS_FD** TtyInput, const WSLA_PROCESS_FD** TtyOutput); + + void ConfigureNetworking(); + void OnExit(_In_ const HCS_EVENT* Event); + + std::tuple Fork(enum WSLA_FORK::ForkType Type); + std::tuple Fork(wsl::shared::SocketChannel& Channel, enum WSLA_FORK::ForkType Type); + int32_t ExpectClosedChannelOrError(wsl::shared::SocketChannel& Channel); + + wil::unique_socket ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd); + void OpenLinuxFile(wsl::shared::SocketChannel& Channel, const char* Path, uint32_t Flags, int32_t Fd); + void LaunchPortRelay(); + + std::vector CreateLinuxProcessImpl( + _In_ const WSLA_CREATE_PROCESS_OPTIONS* Options, _In_ ULONG FdCount, _In_ WSLA_PROCESS_FD* Fd, _Out_ WSLA_CREATE_PROCESS_RESULT* Result); + + HRESULT MountWindowsFolderImpl(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly, _In_ WslMountFlags Flags); + + struct AttachedDisk + { + std::filesystem::path Path; + std::string Device; + bool AccessGranted = false; + }; + + VIRTUAL_MACHINE_SETTINGS m_settings; + GUID m_vmId{}; + std::wstring m_vmIdString; + wsl::windows::common::helpers::WindowsVersion m_windowsVersion = wsl::windows::common::helpers::GetWindowsVersion(); + int m_coldDiscardShiftSize{}; + bool m_running = false; + PSID m_userSid{}; + std::wstring m_debugShellPipe; + + wsl::windows::common::hcs::unique_hcs_system m_computeSystem; + std::shared_ptr m_dmesgCollector; + wil::unique_event m_vmExitEvent{wil::EventOptions::ManualReset}; + wil::unique_event m_vmTerminatingEvent{wil::EventOptions::ManualReset}; + wil::com_ptr m_terminationCallback; + std::unique_ptr m_networkEngine; + + wsl::shared::SocketChannel m_initChannel; + wil::unique_handle m_portRelayChannelRead; + wil::unique_handle m_portRelayChannelWrite; + + std::map m_attachedDisks; + std::map m_plan9Mounts; + std::recursive_mutex m_lock; + std::mutex m_portRelaylock; + WSLAUserSessionImpl* m_userSession; +}; +} // namespace wsl::windows::service::wsla \ No newline at end of file diff --git a/src/windows/wslaservice/exe/application.manifest b/src/windows/wslaservice/exe/application.manifest new file mode 100644 index 0000000..0c16180 --- /dev/null +++ b/src/windows/wslaservice/exe/application.manifest @@ -0,0 +1,8 @@ + + + + + true + + + \ No newline at end of file diff --git a/src/windows/wslaservice/exe/main.rc b/src/windows/wslaservice/exe/main.rc new file mode 100644 index 0000000..174615e --- /dev/null +++ b/src/windows/wslaservice/exe/main.rc @@ -0,0 +1,28 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + main.rc + +Abstract: + + This file contains resources for wslaservice. + +--*/ + +#include +#include "resource.h" +#include "wslversioninfo.h" + +#define VER_INTERNALNAME_STR "wslaservice.exe" +#define VER_ORIGINALFILENAME_STR "wslaservice.exe" + +#define VER_FILETYPE VFT_APP +#define VER_FILESUBTYPE VFT2_UNKNOWN +#define VER_FILEDESCRIPTION_STR "Windows Subsystem for Linux for Apps Service" +ID_ICON ICON PRELOAD DISCARDABLE "..\..\..\Images\wsl.ico" + + +#include diff --git a/src/windows/wslaservice/exe/resource.h b/src/windows/wslaservice/exe/resource.h new file mode 100644 index 0000000..3d6234e --- /dev/null +++ b/src/windows/wslaservice/exe/resource.h @@ -0,0 +1,15 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + resource.h + +Abstract: + + This file contains resource declarations for wslaservice.exe + +--*/ + +#define ID_ICON 1 diff --git a/src/windows/wslaservice/inc/CMakeLists.txt b/src/windows/wslaservice/inc/CMakeLists.txt new file mode 100644 index 0000000..a498313 --- /dev/null +++ b/src/windows/wslaservice/inc/CMakeLists.txt @@ -0,0 +1,2 @@ +add_idl(wslaserviceidl "wslaservice.idl" "") +set_target_properties(wslaserviceidl PROPERTIES FOLDER windows) \ No newline at end of file diff --git a/src/windows/wslaservice/inc/wslaservice.idl b/src/windows/wslaservice/inc/wslaservice.idl new file mode 100644 index 0000000..831beff --- /dev/null +++ b/src/windows/wslaservice/inc/wslaservice.idl @@ -0,0 +1,114 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Module Name: + + wslaservice.idl + +Abstract: + + This file contains the COM object definitions used to talk with the WSLa + service "WslaService" + +--*/ + +import "unknwn.idl"; +import "wtypes.idl"; + +cpp_quote("#ifdef __cplusplus") +cpp_quote("class DECLSPEC_UUID(\"a9b7a1b9-0671-405c-95f1-e0612cb4ce8f\") WSLAUserSession;") +cpp_quote("#endif") + +typedef +struct _WSL_VERSION { + ULONG Major; + ULONG Minor; + ULONG Revision; +} WSL_VERSION; + + +typedef [system_handle(sh_socket)] HANDLE HVSOCKET_HANDLE; + + +typedef +struct _WSLA_CREATE_PROCESS_OPTIONS { + [string] LPCSTR Executable; + ULONG CommandLineCount; + [unique, size_is(CommandLineCount)] LPCSTR* CommandLine; + ULONG EnvironmentCount; + [unique, size_is(EnvironmentCount)] LPCSTR* Environment; + [unique] LPCSTR CurrentDirectory; +} WSLA_CREATE_PROCESS_OPTIONS; + +typedef struct _WSLA_PROCESS_FD +{ + LONG Fd; + int Type; + [string, unique] LPCSTR Path; +} WSLA_PROCESS_FD; + +typedef +struct _WSLA_CREATE_PROCESS_RESULT { + int Errno; + int Pid; +} WSLA_CREATE_PROCESS_RESULT; + +[ + uuid(7BC4E198-6531-4FA6-ADE2-5EF3D2A04DFE), + pointer_default(unique), + object +] +interface ITerminationCallback : IUnknown +{ + HRESULT OnTermination(ULONG Reason, LPCWSTR Details); +}; + +[ + uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8761), + pointer_default(unique), + object +] +interface IWSLAVirtualMachine : IUnknown +{ + HRESULT AttachDisk([in] LPCWSTR Path, [in] BOOL ReadOnly, [out] LPSTR* Device, [out] ULONG* Lun); + HRESULT Mount([in, unique] LPCSTR Source, [in] LPCSTR Target, [in] LPCSTR Type, [in] LPCSTR Options, [in] ULONG Flags); + HRESULT CreateLinuxProcess([in] const WSLA_CREATE_PROCESS_OPTIONS* Options, [in] ULONG FdCount, [in, unique, size_is(FdCount)] WSLA_PROCESS_FD* Fds, [out, size_is(FdCount)] ULONG* Handles, [out] WSLA_CREATE_PROCESS_RESULT* Result); + HRESULT WaitPid([in] LONG Pid, [in] ULONGLONG TimeoutMs, [out] ULONG* State, [out] int* Code); + HRESULT Signal([in] LONG Pid, [in] int Signal); + HRESULT Shutdown([in] ULONGLONG TimeoutMs); + HRESULT RegisterCallback([in] ITerminationCallback* terminationCallback); + HRESULT GetDebugShellPipe([out] LPWSTR* pipePath); + HRESULT MapPort([in] int Family, [in] short WindowsPort, [in] short LinuxPort, [in] BOOL Remove); + HRESULT Unmount([in] LPCSTR Path); + HRESULT DetachDisk([in] ULONG Lun); + HRESULT MountWindowsFolder([in] LPCWSTR WindowsPath, [in] LPCSTR LinuxPath, [in] BOOL ReadOnly); + HRESULT UnmountWindowsFolder([in] LPCSTR LinuxPath); + HRESULT MountGpuLibraries([in] LPCSTR LibrariesMountPoint, [in] LPCSTR DriversMountpoint, [in] DWORD Flags); +} + +typedef +struct _VIRTUAL_MACHINE_SETTINGS { + LPCWSTR DisplayName; + ULONGLONG MemoryMb; + ULONG CpuCount; + ULONG BootTimeoutMs; + ULONG DmesgOutput; + ULONG NetworkingMode; + BOOL EnableDnsTunneling; + BOOL EnableDebugShell; + BOOL EnableEarlyBootDmesg; + BOOL EnableGPU; +} VIRTUAL_MACHINE_SETTINGS; + + +[ + uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8760), + pointer_default(unique), + object +] +interface IWSLAUserSession : IUnknown +{ + HRESULT GetVersion([out] WSL_VERSION* Error); + HRESULT CreateVirtualMachine([in] const VIRTUAL_MACHINE_SETTINGS* Settings, [out]IWSLAVirtualMachine** VirtualMachine); +} \ No newline at end of file diff --git a/src/windows/wslaservice/stub/CMakeLists.txt b/src/windows/wslaservice/stub/CMakeLists.txt new file mode 100644 index 0000000..feccf07 --- /dev/null +++ b/src/windows/wslaservice/stub/CMakeLists.txt @@ -0,0 +1,13 @@ +set(SOURCES + ${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wslaservice_i_${TARGET_PLATFORM}.c + ${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wslaservice_p_${TARGET_PLATFORM}.c + ${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/dlldata_${TARGET_PLATFORM}.c + ${CMAKE_CURRENT_LIST_DIR}/WslaServiceProxyStub.def + ${CMAKE_CURRENT_LIST_DIR}/WslaServiceProxyStub.rc) + +set_source_files_properties(${SOURCES} PROPERTIES GENERATED TRUE) + +add_library(wslaserviceproxystub SHARED ${SOURCES}) +add_dependencies(wslaserviceproxystub wslaserviceidl) +target_link_libraries(wslaserviceproxystub ${COMMON_LINK_LIBRARIES}) +set_target_properties(wslaserviceproxystub PROPERTIES FOLDER windows) \ No newline at end of file diff --git a/src/windows/wslaservice/stub/WslaServiceProxyStub.def b/src/windows/wslaservice/stub/WslaServiceProxyStub.def new file mode 100644 index 0000000..35f95e7 --- /dev/null +++ b/src/windows/wslaservice/stub/WslaServiceProxyStub.def @@ -0,0 +1,5 @@ +LIBRARY WslaServiceProxyStub.dll + +EXPORTS + DllGetClassObject PRIVATE + DllCanUnloadNow PRIVATE diff --git a/src/windows/wslaservice/stub/WslaServiceProxyStub.rc b/src/windows/wslaservice/stub/WslaServiceProxyStub.rc new file mode 100644 index 0000000..5c9cc05 --- /dev/null +++ b/src/windows/wslaservice/stub/WslaServiceProxyStub.rc @@ -0,0 +1,23 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + WslaServiceProxyStub.rc + +Abstract: + + This file contains resources for wslaserviceproxystub.dll. + +--*/ + +#include +#include "wslversioninfo.h" + +#define VER_INTERNALNAME_STR "wslaserviceproxystub.dll" +#define VER_ORIGINALFILENAME_STR "wslaserviceproxystub.dll" + +#define VER_FILEDESCRIPTION_STR "WSLA Service ProxyStub DLL" + +#include diff --git a/test/windows/Common.cpp b/test/windows/Common.cpp index fdd1bf3..1614d30 100644 --- a/test/windows/Common.cpp +++ b/test/windows/Common.cpp @@ -1316,6 +1316,17 @@ void StopWslService() StopService(service.get()); } +void StopWslaService() +{ + LogInfo("Stopping WSLAService"); + const wil::unique_schandle manager{OpenSCManager(nullptr, nullptr, SC_MANAGER_CONNECT)}; + VERIFY_IS_NOT_NULL(manager); + + const wil::unique_schandle service{OpenService(manager.get(), L"wslaservice", SERVICE_STOP | SERVICE_QUERY_STATUS)}; + VERIFY_IS_NOT_NULL(service); + StopService(service.get()); +} + wil::unique_handle GetNonElevatedToken() { const auto token = wil::open_current_access_token(TOKEN_ALL_ACCESS); diff --git a/test/windows/Common.h b/test/windows/Common.h index 776a98f..b0d4a6b 100644 --- a/test/windows/Common.h +++ b/test/windows/Common.h @@ -523,6 +523,7 @@ inline auto EnableSystemd(const std::string& extraConfig = "") std::wstring EscapePath(std::wstring_view Path); void StopWslService(); +void StopWslaService(); std::optional GetDistributionId(LPCWSTR Name); wil::unique_hkey OpenDistributionKey(LPCWSTR Name); diff --git a/test/windows/WSLATests.cpp b/test/windows/WSLATests.cpp index 382da85..ef1e5b2 100644 --- a/test/windows/WSLATests.cpp +++ b/test/windows/WSLATests.cpp @@ -821,7 +821,7 @@ class WSLATests }); // Stop the service - StopWslService(); + StopWslaService(); // Verify that the thread is unstuck stuckThread.join();