diff --git a/CMakeLists.txt b/CMakeLists.txt
index e760398..e66a489 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -388,6 +388,7 @@ include_directories(${WSLDEPS_SOURCE_DIR}/include/lxcore)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/shared/inc)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/windows/inc)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/src/windows/service/inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE})
+include_directories(${CMAKE_CURRENT_BINARY_DIR}/src/windows/wslaservice/inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE})
include_directories(${CMAKE_CURRENT_BINARY_DIR}/src/windows/wslinstaller/inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/linux/init/inc)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/windows/common)
@@ -413,6 +414,7 @@ add_subdirectory(msipackage)
add_subdirectory(msixinstaller)
add_subdirectory(src/windows/common)
add_subdirectory(src/windows/service)
+add_subdirectory(src/windows/wslaservice)
add_subdirectory(src/windows/wslinstaller/inc)
add_subdirectory(src/windows/wslinstaller/stub)
add_subdirectory(src/windows/wslinstaller/exe)
diff --git a/msipackage/CMakeLists.txt b/msipackage/CMakeLists.txt
index 9f51362..7f7bcb1 100644
--- a/msipackage/CMakeLists.txt
+++ b/msipackage/CMakeLists.txt
@@ -3,7 +3,7 @@ set(OUTPUT_PACKAGE ${BIN}/wsl.msi)
set(PACKAGE_WIX_IN ${CMAKE_CURRENT_LIST_DIR}/package.wix.in)
set(PACKAGE_WIX ${BIN}/package.wix)
set(CAB_CACHE ${BIN}/cab)
-set(BINARIES wsl.exe;wslg.exe;wslhost.exe;wslrelay.exe;wslservice.exe;wslserviceproxystub.dll;init;initrd.img;wslinstall.dll)
+set(BINARIES wsl.exe;wslg.exe;wslhost.exe;wslrelay.exe;wslservice.exe;wslserviceproxystub.dll;init;initrd.img;wslinstall.dll;wslaserviceproxystub.dll;wslaservice.exe)
if (WSL_BUILD_WSL_SETTINGS)
list(APPEND BINARIES_DEPENDENCIES "wslsettings/wslsettings.dll;wslsettings/wslsettings.exe;libwsl.dll")
@@ -39,7 +39,7 @@ add_custom_command(
add_custom_target(msipackage DEPENDS ${OUTPUT_PACKAGE})
set_target_properties(msipackage PROPERTIES EXCLUDE_FROM_ALL FALSE SOURCES ${PACKAGE_WIX_IN})
-add_dependencies(msipackage wsl wslg wslservice wslhost wslrelay wslserviceproxystub init initramfs wslinstall msixgluepackage)
+add_dependencies(msipackage wsl wslg wslservice wslhost wslrelay wslserviceproxystub init initramfs wslinstall msixgluepackage wslaservice wslaserviceproxystub)
if (WSL_BUILD_WSL_SETTINGS)
add_dependencies(msipackage wslsettings libwsl)
diff --git a/msipackage/package.wix.in b/msipackage/package.wix.in
index 30837b0..5a7d1fd 100644
--- a/msipackage/package.wix.in
+++ b/msipackage/package.wix.in
@@ -50,75 +50,75 @@
-
-
-
-
-
+
+
+
+
+
-
-
-
+
+
+
-
-
-
+
+
+
-
-
-
-
-
-
+
+
+
+
+
+
-
-
+
+
-
-
-
-
-
+
+
+
+
+
+
-
-
-
-
+
+
+
-
-
-
+
+
+
-
-
-
-
+
+
+
+
-
-
-
-
+
+
+
+
-
-
-
-
-
-
+
+
+
+
+
+
+
-
-
+
-
-
-
-
-
+
+
+
+
+
@@ -131,33 +131,13 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
-
-
-
-
+
+
+
+
@@ -175,18 +155,6 @@
-
-
-
-
-
-
-
-
-
-
-
-
@@ -226,7 +194,7 @@
-
+
@@ -265,6 +233,68 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -402,6 +432,7 @@
+
@@ -501,21 +532,21 @@
Return="check"
Execute="deferred"
/>
-
-
-
diff --git a/src/windows/common/CMakeLists.txt b/src/windows/common/CMakeLists.txt
index 90c50d0..150ee09 100644
--- a/src/windows/common/CMakeLists.txt
+++ b/src/windows/common/CMakeLists.txt
@@ -1,83 +1,105 @@
set(SOURCES
ConsoleProgressBar.cpp
ConsoleProgressIndicator.cpp
+ DeviceHostProxy.cpp
+ DeviceHostProxy.h
disk.cpp
Distribution.cpp
+ Dmesg.cpp
+ Dmesg.h
+ DnsResolver.cpp
+ DnsTunnelingChannel.cpp
+ ExecutionContext.cpp
filesystem.cpp
+ GnsChannel.cpp
HandleConsoleProgressBar.cpp
hcs.cpp
helpers.cpp
- interop.cpp
- ExecutionContext.cpp
- socket.cpp
hvsocket.cpp
+ interop.cpp
Localization.cpp
lxssbusclient.cpp
lxssclient.cpp
LxssMessagePort.cpp
+ LxssSecurity.cpp
LxssServerPort.cpp
+ NatNetworking.cpp
+ notifications.cpp
Redirector.cpp
registry.cpp
relay.cpp
+ RingBuffer.cpp
+ socket.cpp
string.cpp
SubProcess.cpp
svccomm.cpp
svccommio.cpp
WslClient.cpp
WslCoreConfig.cpp
+ WslCoreFilesystem.cpp
WslCoreFirewallSupport.cpp
+ WslCoreHostDnsInfo.cpp
+ WslCoreNetworkEndpointSettings.cpp
WslCoreNetworkingSupport.cpp
WslInstall.cpp
WslSecurity.cpp
WslTelemetry.cpp
wslutil.cpp
- notifications.cpp)
+ )
set(HEADERS
../../../generated/Localization.h
- ../../shared/inc/CommandLine.h
- ../../shared/inc/defs.h
- ../../shared/inc/lxfsshares.h
- ../../shared/inc/lxinitshared.h
- ../../shared/inc/SocketChannel.h
- ../../shared/inc/socketshared.h
- ../../shared/inc/hns_schema.h
- ../../shared/inc/JsonUtils.h
- ../../shared/inc/stringshared.h
- ../../shared/inc/retryshared.h
- ../../shared/inc/message.h
- ../../shared/inc/prettyprintshared.h
- ../inc/WslPluginApi.h
- ../inc/wslpolicies.h
../inc/lxssbusclient.h
../inc/lxssclient.h
../inc/LxssDynamicFunction.h
../inc/traceloggingconfig.h
../inc/wdk.h
- ../inc/wsl.h
../inc/wslconfig.h
+ ../inc/wsl.h
../inc/wslhost.h
+ ../inc/WslPluginApi.h
+ ../inc/wslpolicies.h
../inc/wslrelay.h
+ ../../shared/inc/CommandLine.h
+ ../../shared/inc/defs.h
+ ../../shared/inc/hns_schema.h
+ ../../shared/inc/JsonUtils.h
+ ../../shared/inc/lxfsshares.h
+ ../../shared/inc/lxinitshared.h
+ ../../shared/inc/message.h
+ ../../shared/inc/prettyprintshared.h
+ ../../shared/inc/retryshared.h
+ ../../shared/inc/SocketChannel.h
+ ../../shared/inc/socketshared.h
+ ../../shared/inc/stringshared.h
ConsoleProgressBar.h
ConsoleProgressIndicator.h
disk.hpp
Distribution.h
+ DnsResolver.h
+ DnsTunnelingChannel.h
+ ExecutionContext.h
filesystem.hpp
+ GnsChannel.h
+ HandleConsoleProgressBar.h
hcs.hpp
hcs_schema.h
helpers.hpp
- HandleConsoleProgressBar.h
- interop.hpp
- ExecutionContext.h
- socket.hpp
hvsocket.hpp
+ INetworkingEngine.h
+ interop.hpp
LxssMessagePort.h
LxssPort.h
+ LxssSecurity.h
LxssServerPort.h
+ NatNetworking.h
+ notifications.h
precomp.h
Redirector.h
registry.hpp
relay.hpp
+ RingBuffer.h
+ socket.hpp
string.hpp
Stringify.h
SubProcess.h
@@ -85,13 +107,17 @@ set(HEADERS
svccommio.hpp
WslClient.h
WslCoreConfig.h
+ WslCoreFilesystem.h
WslCoreFirewallSupport.h
+ WslCoreHostDnsInfo.h
+ WslCoreMessageQueue.h
+ WslCoreNetworkEndpointSettings.h
WslCoreNetworkingSupport.h
WslInstall.h
WslSecurity.h
WslTelemetry.h
- wslutil.h
- notifications.h)
+ wslutil.cpp
+ )
add_library(common STATIC ${SOURCES} ${HEADERS})
add_dependencies(common wslserviceidl localization wslservicemc wslinstalleridl)
diff --git a/src/windows/service/exe/DeviceHostProxy.cpp b/src/windows/common/DeviceHostProxy.cpp
similarity index 97%
rename from src/windows/service/exe/DeviceHostProxy.cpp
rename to src/windows/common/DeviceHostProxy.cpp
index 17a7dd4..e04f346 100644
--- a/src/windows/service/exe/DeviceHostProxy.cpp
+++ b/src/windows/common/DeviceHostProxy.cpp
@@ -1,261 +1,261 @@
-// Copyright (C) Microsoft Corporation. All rights reserved.
-
-#include "precomp.h"
-#include "DeviceHostProxy.h"
-
-// This template works around a limitation with decltype on overloaded functions. It will be able
-// to get the correct version of GetVmWorkerProcess based on the provided type arguments. By
-// doing it this way, a compiler error will be generated if someone changes the signature of
-// GetVmWorkerProcess.
-//
-// The way this works: decltype(GetVmWorkerProcess) does not work because it's overloaded.
-// decltype(GetVmWorkerProcess(arg1, ...)) works to select an overload if you have values of the
-// correct type (std::declval() generates a value of the specified type), however the result
-// of that is the function's return type, not the function's type, so the argument types must
-// be repeated to reconstruct the function type.
-template
-using GetVmWorkerProcessType = decltype(GetVmWorkerProcess(std::declval()...))(Args...);
-
-// Limit the number of allowed doorbells registered by an external HDV vdev. Currently virtio-9p only uses
-// one doorbell and wsldevicehost uses only two.
-#define DEVICE_HOST_PROXY_DOORBELL_LIMIT 8
-
-using namespace wsl::windows::common::hcs;
-
-DeviceHostProxy::DeviceHostProxy(const std::wstring& VmId, const GUID& RuntimeId) :
- m_systemId{VmId}, m_runtimeId{RuntimeId}, m_system{wsl::windows::common::hcs::OpenComputeSystem(VmId.c_str(), GENERIC_ALL)}, m_shutdown{false}
-{
- m_devicesShutdown = false;
-}
-
-GUID DeviceHostProxy::AddNewDevice(const GUID& Type, const wil::com_ptr& Plan9Fs, const std::wstring& VirtIoTag)
-{
- const wrl::ComPtr thisUnknown{CastToUnknown()};
- GUID instanceId{};
- THROW_IF_FAILED(UuidCreate(&instanceId));
- // Tell the device host to create the device.
- THROW_IF_FAILED(Plan9Fs->CreateVirtioDevice(m_systemId.c_str(), thisUnknown.Get(), VirtIoTag.c_str(), &instanceId));
-
- // Add the instance ID to the list of known devices. This must be done before the device is
- // added to the system, because doing that can cause the register doorbell function to be
- // called.
- // N.B. It will be removed if there is a failure.
- {
- auto lock = m_devicesLock.lock_exclusive();
- THROW_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
-
- m_devices.emplace(instanceId, DeviceHostProxyEntry{});
- }
-
- auto removeOnFailure = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() {
- auto lock = m_devicesLock.lock_exclusive();
- m_devices.erase(instanceId);
- });
-
- // Add the device to the compute system on behalf of the device host.
- ModifySettingRequest request;
- request.RequestType = ModifyRequestType::Add;
- request.ResourcePath = L"VirtualMachine/Devices/FlexibleIov/";
- request.ResourcePath += wsl::shared::string::GuidToString(instanceId, wsl::shared::string::GuidToStringFlags::None);
- request.Settings.EmulatorId = Type;
- request.Settings.HostingModel = FlexibleIoDeviceHostingModel::ExternalRestricted;
- wsl::windows::common::hcs::ModifyComputeSystem(m_system.get(), wsl::shared::ToJsonW(request).c_str());
- removeOnFailure.release();
- return instanceId;
-}
-
-void DeviceHostProxy::AddRemoteFileSystem(const GUID& ImplementationClsid, const std::wstring& Tag, const wil::com_ptr& Plan9Fs)
-{
- auto lock = m_lock.lock_exclusive();
- THROW_HR_IF(E_CHANGED_STATE, m_shutdown);
-
- // Make sure there are no duplicate tags.
- for (auto& entry : m_fileSystems)
- {
- THROW_HR_IF(E_INVALIDARG, entry.ImplementationClsid == ImplementationClsid && entry.Tag == Tag);
- }
-
- m_fileSystems.emplace_back(ImplementationClsid, Tag, Plan9Fs);
-}
-
-wil::com_ptr DeviceHostProxy::GetRemoteFileSystem(const GUID& ImplementationClsid, std::wstring_view Tag)
-{
- auto lock = m_lock.lock_shared();
- THROW_HR_IF(E_CHANGED_STATE, m_shutdown);
-
- for (auto& entry : m_fileSystems)
- {
- if (entry.ImplementationClsid == ImplementationClsid && entry.Tag == Tag)
- {
- return entry.Instance;
- }
- }
-
- return {};
-}
-
-void DeviceHostProxy::Shutdown()
-{
- {
- auto lock = m_lock.lock_exclusive();
- m_fileSystems.clear();
- m_shutdown = true;
- }
-
- {
- auto lock = m_devicesLock.lock_exclusive();
- m_devices.clear();
- m_devicesShutdown = true;
- }
-}
-
-HRESULT
-DeviceHostProxy::RegisterDeviceHost(_In_ IVmDeviceHost* DeviceHost, _In_ DWORD ProcessId, _Out_ UINT64* IpcSectionHandle)
-try
-{
- //
- // Because HdvProxyDeviceHost is not part of the API set, it is loaded here dynamically.
- //
-
- static LxssDynamicFunction proxyDeviceHost{c_hdvModuleName, "HdvProxyDeviceHost"};
- const wil::com_ptr remoteHost = DeviceHost;
- const wil::com_ptr unknown = remoteHost.query();
- THROW_IF_FAILED(proxyDeviceHost(m_system.get(), unknown.get(), ProcessId, IpcSectionHandle));
- return S_OK;
-}
-CATCH_RETURN()
-
-HRESULT
-DeviceHostProxy::NotifyAllDevicesInUse(_In_ LPCWSTR Tag)
-try
-{
- //
- // Add another Plan9 virtio device to the guest so additional mount commands will be possible.
- // This callback should be unused by virtiofs devices because a device is created for every
- // AddSharePath call.
- //
- auto p9fs = GetRemoteFileSystem(__uuidof(p9fs::Plan9FileSystem), Tag);
- THROW_HR_IF(E_NOT_SET, !p9fs);
- (void)AddNewDevice(VIRTIO_PLAN9_DEVICE_ID, p9fs, Tag);
- return S_OK;
-}
-CATCH_RETURN()
-
-HRESULT
-DeviceHostProxy::RegisterDoorbell(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags, HANDLE Event)
-try
-{
- auto lock = m_devicesLock.lock_exclusive();
- RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
-
- // Check if the device is one of the known devices that doorbells can be registered for, and
- // if the device has not already registered a doorbell.
- // N.B. For security it is enforced that each device can only register a small number of doorbells.
- // Currently virtio-9p only uses one and the external virtio device uses two.
- const auto knownDevice = m_devices.find(InstanceId);
- RETURN_HR_IF(E_ACCESSDENIED, knownDevice == m_devices.end() || knownDevice->second.DoorbellCount == DEVICE_HOST_PROXY_DOORBELL_LIMIT);
-
- if (!knownDevice->second.MemoryNotification)
- {
- // Get an interface to the worker process to query devices.
- if (!m_deviceAccess)
- {
- static LxssDynamicFunction> getVmWorker{
- c_vmwpctrlModuleName, "GetVmWorkerProcess"};
-
- RETURN_IF_FAILED(getVmWorker(m_runtimeId, __uuidof(*m_deviceAccess), reinterpret_cast(&m_deviceAccess)));
- }
-
- RETURN_HR_IF(E_NOINTERFACE, !m_deviceAccess);
-
- // Retrieve the device's memory notification interface to register the doorbell, and store it
- // to be used during unregistration.
- wil::com_ptr device;
- RETURN_IF_FAILED(m_deviceAccess->GetDevice(FLEXIO_DEVICE_ID, InstanceId, &device));
- knownDevice->second.MemoryNotification = device.query();
- }
-
- const auto result = knownDevice->second.MemoryNotification->RegisterDoorbell(
- static_cast(BarIndex), Offset, TriggerValue, Flags, Event);
-
- if (SUCCEEDED(result))
- {
- ++knownDevice->second.DoorbellCount;
- }
-
- return result;
-}
-CATCH_RETURN()
-
-HRESULT
-DeviceHostProxy::UnregisterDoorbell(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags)
-try
-{
- auto lock = m_devicesLock.lock_exclusive();
- RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
-
- // Check if the device is a known device and has registered a doorbell.
- // N.B. If the device is being removed, the device can't be retrieved from the worker process
- // so it's necessary to use the stored COM pointer.
- const auto device = m_devices.find(InstanceId);
- RETURN_HR_IF(E_ACCESSDENIED, device == m_devices.end() || device->second.DoorbellCount == 0);
- RETURN_IF_FAILED(device->second.MemoryNotification->UnregisterDoorbell(static_cast(BarIndex), Offset, TriggerValue, Flags));
-
- if (--device->second.DoorbellCount == 0)
- {
- device->second.MemoryNotification.reset();
- }
-
- return S_OK;
-}
-CATCH_RETURN()
-
-HRESULT
-DeviceHostProxy::CreateSectionBackedMmioRange(
- const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages, UINT64 PageCount, UINT64 MappingFlags, HANDLE SectionHandle, UINT64 SectionOffsetInPages)
-try
-{
- auto lock = m_devicesLock.lock_exclusive();
- RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
-
- // Check if the device is one of the known devices.
- const auto knownDevice = m_devices.find(InstanceId);
- THROW_HR_IF(E_ACCESSDENIED, knownDevice == m_devices.end());
-
- if (!knownDevice->second.MemoryMapping)
- {
- // Get an interface to the worker process to query devices.
- if (!m_deviceAccess)
- {
- static LxssDynamicFunction> getVmWorker{
- c_vmwpctrlModuleName, "GetVmWorkerProcess"};
- THROW_IF_FAILED(getVmWorker(m_runtimeId, __uuidof(*m_deviceAccess), reinterpret_cast(&m_deviceAccess)));
- }
-
- THROW_HR_IF(E_NOINTERFACE, !m_deviceAccess);
-
- // Retrieve the device specific interface to manage mapped sections.
- wil::com_ptr device;
- THROW_IF_FAILED(m_deviceAccess->GetDevice(FLEXIO_DEVICE_ID, InstanceId, &device));
- knownDevice->second.MemoryMapping = device.query();
- }
-
- THROW_IF_FAILED(knownDevice->second.MemoryMapping->CreateSectionBackedMmioRange(
- static_cast(BarIndex), BarOffsetInPages, PageCount, static_cast(MappingFlags), SectionHandle, SectionOffsetInPages));
-
- return S_OK;
-}
-CATCH_RETURN()
-
-HRESULT
-DeviceHostProxy::DestroySectionBackedMmioRange(const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages)
-try
-{
- auto lock = m_devicesLock.lock_exclusive();
- RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
- const auto device = m_devices.find(InstanceId);
- RETURN_HR_IF(E_ACCESSDENIED, device == m_devices.end() || !device->second.MemoryMapping);
- RETURN_IF_FAILED(device->second.MemoryMapping->DestroySectionBackedMmioRange(static_cast(BarIndex), BarOffsetInPages));
- return S_OK;
-}
+// Copyright (C) Microsoft Corporation. All rights reserved.
+
+#include "precomp.h"
+#include "DeviceHostProxy.h"
+
+// This template works around a limitation with decltype on overloaded functions. It will be able
+// to get the correct version of GetVmWorkerProcess based on the provided type arguments. By
+// doing it this way, a compiler error will be generated if someone changes the signature of
+// GetVmWorkerProcess.
+//
+// The way this works: decltype(GetVmWorkerProcess) does not work because it's overloaded.
+// decltype(GetVmWorkerProcess(arg1, ...)) works to select an overload if you have values of the
+// correct type (std::declval() generates a value of the specified type), however the result
+// of that is the function's return type, not the function's type, so the argument types must
+// be repeated to reconstruct the function type.
+template
+using GetVmWorkerProcessType = decltype(GetVmWorkerProcess(std::declval()...))(Args...);
+
+// Limit the number of allowed doorbells registered by an external HDV vdev. Currently virtio-9p only uses
+// one doorbell and wsldevicehost uses only two.
+#define DEVICE_HOST_PROXY_DOORBELL_LIMIT 8
+
+using namespace wsl::windows::common::hcs;
+
+DeviceHostProxy::DeviceHostProxy(const std::wstring& VmId, const GUID& RuntimeId) :
+ m_systemId{VmId}, m_runtimeId{RuntimeId}, m_system{wsl::windows::common::hcs::OpenComputeSystem(VmId.c_str(), GENERIC_ALL)}, m_shutdown{false}
+{
+ m_devicesShutdown = false;
+}
+
+GUID DeviceHostProxy::AddNewDevice(const GUID& Type, const wil::com_ptr& Plan9Fs, const std::wstring& VirtIoTag)
+{
+ const wrl::ComPtr thisUnknown{CastToUnknown()};
+ GUID instanceId{};
+ THROW_IF_FAILED(UuidCreate(&instanceId));
+ // Tell the device host to create the device.
+ THROW_IF_FAILED(Plan9Fs->CreateVirtioDevice(m_systemId.c_str(), thisUnknown.Get(), VirtIoTag.c_str(), &instanceId));
+
+ // Add the instance ID to the list of known devices. This must be done before the device is
+ // added to the system, because doing that can cause the register doorbell function to be
+ // called.
+ // N.B. It will be removed if there is a failure.
+ {
+ auto lock = m_devicesLock.lock_exclusive();
+ THROW_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
+
+ m_devices.emplace(instanceId, DeviceHostProxyEntry{});
+ }
+
+ auto removeOnFailure = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() {
+ auto lock = m_devicesLock.lock_exclusive();
+ m_devices.erase(instanceId);
+ });
+
+ // Add the device to the compute system on behalf of the device host.
+ ModifySettingRequest request;
+ request.RequestType = ModifyRequestType::Add;
+ request.ResourcePath = L"VirtualMachine/Devices/FlexibleIov/";
+ request.ResourcePath += wsl::shared::string::GuidToString(instanceId, wsl::shared::string::GuidToStringFlags::None);
+ request.Settings.EmulatorId = Type;
+ request.Settings.HostingModel = FlexibleIoDeviceHostingModel::ExternalRestricted;
+ wsl::windows::common::hcs::ModifyComputeSystem(m_system.get(), wsl::shared::ToJsonW(request).c_str());
+ removeOnFailure.release();
+ return instanceId;
+}
+
+void DeviceHostProxy::AddRemoteFileSystem(const GUID& ImplementationClsid, const std::wstring& Tag, const wil::com_ptr& Plan9Fs)
+{
+ auto lock = m_lock.lock_exclusive();
+ THROW_HR_IF(E_CHANGED_STATE, m_shutdown);
+
+ // Make sure there are no duplicate tags.
+ for (auto& entry : m_fileSystems)
+ {
+ THROW_HR_IF(E_INVALIDARG, entry.ImplementationClsid == ImplementationClsid && entry.Tag == Tag);
+ }
+
+ m_fileSystems.emplace_back(ImplementationClsid, Tag, Plan9Fs);
+}
+
+wil::com_ptr DeviceHostProxy::GetRemoteFileSystem(const GUID& ImplementationClsid, std::wstring_view Tag)
+{
+ auto lock = m_lock.lock_shared();
+ THROW_HR_IF(E_CHANGED_STATE, m_shutdown);
+
+ for (auto& entry : m_fileSystems)
+ {
+ if (entry.ImplementationClsid == ImplementationClsid && entry.Tag == Tag)
+ {
+ return entry.Instance;
+ }
+ }
+
+ return {};
+}
+
+void DeviceHostProxy::Shutdown()
+{
+ {
+ auto lock = m_lock.lock_exclusive();
+ m_fileSystems.clear();
+ m_shutdown = true;
+ }
+
+ {
+ auto lock = m_devicesLock.lock_exclusive();
+ m_devices.clear();
+ m_devicesShutdown = true;
+ }
+}
+
+HRESULT
+DeviceHostProxy::RegisterDeviceHost(_In_ IVmDeviceHost* DeviceHost, _In_ DWORD ProcessId, _Out_ UINT64* IpcSectionHandle)
+try
+{
+ //
+ // Because HdvProxyDeviceHost is not part of the API set, it is loaded here dynamically.
+ //
+
+ static LxssDynamicFunction proxyDeviceHost{c_hdvModuleName, "HdvProxyDeviceHost"};
+ const wil::com_ptr remoteHost = DeviceHost;
+ const wil::com_ptr unknown = remoteHost.query();
+ THROW_IF_FAILED(proxyDeviceHost(m_system.get(), unknown.get(), ProcessId, IpcSectionHandle));
+ return S_OK;
+}
+CATCH_RETURN()
+
+HRESULT
+DeviceHostProxy::NotifyAllDevicesInUse(_In_ LPCWSTR Tag)
+try
+{
+ //
+ // Add another Plan9 virtio device to the guest so additional mount commands will be possible.
+ // This callback should be unused by virtiofs devices because a device is created for every
+ // AddSharePath call.
+ //
+ auto p9fs = GetRemoteFileSystem(__uuidof(p9fs::Plan9FileSystem), Tag);
+ THROW_HR_IF(E_NOT_SET, !p9fs);
+ (void)AddNewDevice(VIRTIO_PLAN9_DEVICE_ID, p9fs, Tag);
+ return S_OK;
+}
+CATCH_RETURN()
+
+HRESULT
+DeviceHostProxy::RegisterDoorbell(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags, HANDLE Event)
+try
+{
+ auto lock = m_devicesLock.lock_exclusive();
+ RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
+
+ // Check if the device is one of the known devices that doorbells can be registered for, and
+ // if the device has not already registered a doorbell.
+ // N.B. For security it is enforced that each device can only register a small number of doorbells.
+ // Currently virtio-9p only uses one and the external virtio device uses two.
+ const auto knownDevice = m_devices.find(InstanceId);
+ RETURN_HR_IF(E_ACCESSDENIED, knownDevice == m_devices.end() || knownDevice->second.DoorbellCount == DEVICE_HOST_PROXY_DOORBELL_LIMIT);
+
+ if (!knownDevice->second.MemoryNotification)
+ {
+ // Get an interface to the worker process to query devices.
+ if (!m_deviceAccess)
+ {
+ static LxssDynamicFunction> getVmWorker{
+ c_vmwpctrlModuleName, "GetVmWorkerProcess"};
+
+ RETURN_IF_FAILED(getVmWorker(m_runtimeId, __uuidof(*m_deviceAccess), reinterpret_cast(&m_deviceAccess)));
+ }
+
+ RETURN_HR_IF(E_NOINTERFACE, !m_deviceAccess);
+
+ // Retrieve the device's memory notification interface to register the doorbell, and store it
+ // to be used during unregistration.
+ wil::com_ptr device;
+ RETURN_IF_FAILED(m_deviceAccess->GetDevice(FLEXIO_DEVICE_ID, InstanceId, &device));
+ knownDevice->second.MemoryNotification = device.query();
+ }
+
+ const auto result = knownDevice->second.MemoryNotification->RegisterDoorbell(
+ static_cast(BarIndex), Offset, TriggerValue, Flags, Event);
+
+ if (SUCCEEDED(result))
+ {
+ ++knownDevice->second.DoorbellCount;
+ }
+
+ return result;
+}
+CATCH_RETURN()
+
+HRESULT
+DeviceHostProxy::UnregisterDoorbell(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags)
+try
+{
+ auto lock = m_devicesLock.lock_exclusive();
+ RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
+
+ // Check if the device is a known device and has registered a doorbell.
+ // N.B. If the device is being removed, the device can't be retrieved from the worker process
+ // so it's necessary to use the stored COM pointer.
+ const auto device = m_devices.find(InstanceId);
+ RETURN_HR_IF(E_ACCESSDENIED, device == m_devices.end() || device->second.DoorbellCount == 0);
+ RETURN_IF_FAILED(device->second.MemoryNotification->UnregisterDoorbell(static_cast(BarIndex), Offset, TriggerValue, Flags));
+
+ if (--device->second.DoorbellCount == 0)
+ {
+ device->second.MemoryNotification.reset();
+ }
+
+ return S_OK;
+}
+CATCH_RETURN()
+
+HRESULT
+DeviceHostProxy::CreateSectionBackedMmioRange(
+ const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages, UINT64 PageCount, UINT64 MappingFlags, HANDLE SectionHandle, UINT64 SectionOffsetInPages)
+try
+{
+ auto lock = m_devicesLock.lock_exclusive();
+ RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
+
+ // Check if the device is one of the known devices.
+ const auto knownDevice = m_devices.find(InstanceId);
+ THROW_HR_IF(E_ACCESSDENIED, knownDevice == m_devices.end());
+
+ if (!knownDevice->second.MemoryMapping)
+ {
+ // Get an interface to the worker process to query devices.
+ if (!m_deviceAccess)
+ {
+ static LxssDynamicFunction> getVmWorker{
+ c_vmwpctrlModuleName, "GetVmWorkerProcess"};
+ THROW_IF_FAILED(getVmWorker(m_runtimeId, __uuidof(*m_deviceAccess), reinterpret_cast(&m_deviceAccess)));
+ }
+
+ THROW_HR_IF(E_NOINTERFACE, !m_deviceAccess);
+
+ // Retrieve the device specific interface to manage mapped sections.
+ wil::com_ptr device;
+ THROW_IF_FAILED(m_deviceAccess->GetDevice(FLEXIO_DEVICE_ID, InstanceId, &device));
+ knownDevice->second.MemoryMapping = device.query();
+ }
+
+ THROW_IF_FAILED(knownDevice->second.MemoryMapping->CreateSectionBackedMmioRange(
+ static_cast(BarIndex), BarOffsetInPages, PageCount, static_cast(MappingFlags), SectionHandle, SectionOffsetInPages));
+
+ return S_OK;
+}
+CATCH_RETURN()
+
+HRESULT
+DeviceHostProxy::DestroySectionBackedMmioRange(const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages)
+try
+{
+ auto lock = m_devicesLock.lock_exclusive();
+ RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
+ const auto device = m_devices.find(InstanceId);
+ RETURN_HR_IF(E_ACCESSDENIED, device == m_devices.end() || !device->second.MemoryMapping);
+ RETURN_IF_FAILED(device->second.MemoryMapping->DestroySectionBackedMmioRange(static_cast(BarIndex), BarOffsetInPages));
+ return S_OK;
+}
CATCH_RETURN()
\ No newline at end of file
diff --git a/src/windows/service/exe/DeviceHostProxy.h b/src/windows/common/DeviceHostProxy.h
similarity index 97%
rename from src/windows/service/exe/DeviceHostProxy.h
rename to src/windows/common/DeviceHostProxy.h
index c5cc165..5f951b3 100644
--- a/src/windows/service/exe/DeviceHostProxy.h
+++ b/src/windows/common/DeviceHostProxy.h
@@ -1,76 +1,76 @@
-// Copyright (C) Microsoft Corporation. All rights reserved.
-
-#pragma once
-
-#include
-#include "hcs.hpp"
-
-namespace wrl = Microsoft::WRL;
-
-class DeviceHostProxy : public wrl::RuntimeClass, IVmDeviceHostSupport, IPlan9FileSystemHost>
-{
-public:
- DeviceHostProxy(const std::wstring& VmId, const GUID& RuntimeId);
-
- GUID AddNewDevice(const GUID& Type, const wil::com_ptr& Plan9Fs, const std::wstring& VirtIoTag);
-
- void AddRemoteFileSystem(const GUID& ImplementationClsid, const std::wstring& Tag, const wil::com_ptr& Plan9Fs);
-
- wil::com_ptr GetRemoteFileSystem(const GUID& ImplementationClsid, std::wstring_view Tag);
-
- void Shutdown();
-
- //
- // IVmDeviceHostSupport
- //
- IFACEMETHOD(RegisterDeviceHost)(_In_ IVmDeviceHost* DeviceHost, _In_ DWORD ProcessId, _Out_ UINT64* IpcSectionHandle) override;
-
- //
- // IPlan9FileSystemHost
- //
- IFACEMETHOD(NotifyAllDevicesInUse)(_In_ LPCWSTR Tag) override;
-
- IFACEMETHOD(RegisterDoorbell)(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags, HANDLE Event) override;
-
- IFACEMETHOD(UnregisterDoorbell)(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags) override;
-
- IFACEMETHOD(CreateSectionBackedMmioRange)(
- const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages, UINT64 PageCount, UINT64 MappingFlags, HANDLE SectionHandle, UINT64 SectionOffsetInPages) override;
-
- IFACEMETHOD(DestroySectionBackedMmioRange)(const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages) override;
-
-private:
- struct RemoteFileSystemInfo
- {
- RemoteFileSystemInfo(GUID ImplementationClsid, const std::wstring& Tag, const wil::com_ptr& Instance) :
- ImplementationClsid{ImplementationClsid}, Tag{Tag}, Instance{Instance}
- {
- }
-
- GUID ImplementationClsid;
- std::wstring Tag;
- wil::com_ptr Instance;
- };
-
- std::wstring m_systemId;
- GUID m_runtimeId;
- wsl::windows::common::hcs::unique_hcs_system m_system;
- wil::srwlock m_lock;
- std::vector m_fileSystems;
- bool m_shutdown;
-
- struct DeviceHostProxyEntry
- {
- wil::com_ptr MemoryNotification;
- wil::com_ptr MemoryMapping;
- size_t DoorbellCount = 0;
- };
-
- wil::com_ptr m_deviceAccess;
- wil::srwlock m_devicesLock;
- std::map m_devices;
- bool m_devicesShutdown;
-
- static constexpr LPCWSTR c_hdvModuleName = L"vmdevicehost.dll";
- static constexpr LPCWSTR c_vmwpctrlModuleName = L"vmwpctrl.dll";
+// Copyright (C) Microsoft Corporation. All rights reserved.
+
+#pragma once
+
+#include
+#include "hcs.hpp"
+
+namespace wrl = Microsoft::WRL;
+
+class DeviceHostProxy : public wrl::RuntimeClass, IVmDeviceHostSupport, IPlan9FileSystemHost>
+{
+public:
+ DeviceHostProxy(const std::wstring& VmId, const GUID& RuntimeId);
+
+ GUID AddNewDevice(const GUID& Type, const wil::com_ptr& Plan9Fs, const std::wstring& VirtIoTag);
+
+ void AddRemoteFileSystem(const GUID& ImplementationClsid, const std::wstring& Tag, const wil::com_ptr& Plan9Fs);
+
+ wil::com_ptr GetRemoteFileSystem(const GUID& ImplementationClsid, std::wstring_view Tag);
+
+ void Shutdown();
+
+ //
+ // IVmDeviceHostSupport
+ //
+ IFACEMETHOD(RegisterDeviceHost)(_In_ IVmDeviceHost* DeviceHost, _In_ DWORD ProcessId, _Out_ UINT64* IpcSectionHandle) override;
+
+ //
+ // IPlan9FileSystemHost
+ //
+ IFACEMETHOD(NotifyAllDevicesInUse)(_In_ LPCWSTR Tag) override;
+
+ IFACEMETHOD(RegisterDoorbell)(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags, HANDLE Event) override;
+
+ IFACEMETHOD(UnregisterDoorbell)(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags) override;
+
+ IFACEMETHOD(CreateSectionBackedMmioRange)(
+ const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages, UINT64 PageCount, UINT64 MappingFlags, HANDLE SectionHandle, UINT64 SectionOffsetInPages) override;
+
+ IFACEMETHOD(DestroySectionBackedMmioRange)(const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages) override;
+
+private:
+ struct RemoteFileSystemInfo
+ {
+ RemoteFileSystemInfo(GUID ImplementationClsid, const std::wstring& Tag, const wil::com_ptr& Instance) :
+ ImplementationClsid{ImplementationClsid}, Tag{Tag}, Instance{Instance}
+ {
+ }
+
+ GUID ImplementationClsid;
+ std::wstring Tag;
+ wil::com_ptr Instance;
+ };
+
+ std::wstring m_systemId;
+ GUID m_runtimeId;
+ wsl::windows::common::hcs::unique_hcs_system m_system;
+ wil::srwlock m_lock;
+ std::vector m_fileSystems;
+ bool m_shutdown;
+
+ struct DeviceHostProxyEntry
+ {
+ wil::com_ptr MemoryNotification;
+ wil::com_ptr MemoryMapping;
+ size_t DoorbellCount = 0;
+ };
+
+ wil::com_ptr m_deviceAccess;
+ wil::srwlock m_devicesLock;
+ std::map m_devices;
+ bool m_devicesShutdown;
+
+ static constexpr LPCWSTR c_hdvModuleName = L"vmdevicehost.dll";
+ static constexpr LPCWSTR c_vmwpctrlModuleName = L"vmwpctrl.dll";
};
\ No newline at end of file
diff --git a/src/windows/service/exe/Dmesg.cpp b/src/windows/common/Dmesg.cpp
similarity index 100%
rename from src/windows/service/exe/Dmesg.cpp
rename to src/windows/common/Dmesg.cpp
diff --git a/src/windows/service/exe/Dmesg.h b/src/windows/common/Dmesg.h
similarity index 100%
rename from src/windows/service/exe/Dmesg.h
rename to src/windows/common/Dmesg.h
diff --git a/src/windows/service/exe/DnsResolver.cpp b/src/windows/common/DnsResolver.cpp
similarity index 97%
rename from src/windows/service/exe/DnsResolver.cpp
rename to src/windows/common/DnsResolver.cpp
index 090e9b4..3fc13bf 100644
--- a/src/windows/service/exe/DnsResolver.cpp
+++ b/src/windows/common/DnsResolver.cpp
@@ -1,417 +1,417 @@
-// Copyright (C) Microsoft Corporation. All rights reserved.
-
-#include
-#include "precomp.h"
-#include "DnsResolver.h"
-
-using wsl::core::networking::DnsResolver;
-
-static constexpr auto c_dnsModuleName = L"dnsapi.dll";
-
-std::optional> DnsResolver::s_dnsQueryRaw;
-std::optional> DnsResolver::s_dnsCancelQueryRaw;
-std::optional> DnsResolver::s_dnsQueryRawResultFree;
-
-HRESULT DnsResolver::LoadDnsResolverMethods() noexcept
-{
- static wil::shared_hmodule dnsModule;
- static DWORD loadError = ERROR_SUCCESS;
- static std::once_flag dnsLoadFlag;
-
- // Load DNS dll only once
- std::call_once(dnsLoadFlag, [&]() {
- dnsModule.reset(LoadLibraryEx(c_dnsModuleName, nullptr, LOAD_LIBRARY_SEARCH_SYSTEM32));
- if (!dnsModule)
- {
- loadError = GetLastError();
- }
- });
-
- RETURN_IF_WIN32_ERROR_MSG(loadError, "LoadLibraryEx %ls", c_dnsModuleName);
-
- // Initialize dynamic functions for the DNS tunneling Windows APIs.
- // using the non-throwing instance of LxssDynamicFunction as to not end up in the Error telemetry
- LxssDynamicFunction local_dnsQueryRaw{DynamicFunctionErrorLogs::None};
- RETURN_IF_FAILED_EXPECTED(local_dnsQueryRaw.load(dnsModule, "DnsQueryRaw"));
- LxssDynamicFunction local_dnsCancelQueryRaw{DynamicFunctionErrorLogs::None};
- RETURN_IF_FAILED_EXPECTED(local_dnsCancelQueryRaw.load(dnsModule, "DnsCancelQueryRaw"));
- LxssDynamicFunction local_dnsQueryRawResultFree{DynamicFunctionErrorLogs::None};
- RETURN_IF_FAILED_EXPECTED(local_dnsQueryRawResultFree.load(dnsModule, "DnsQueryRawResultFree"));
-
- // Make a dummy call to the DNS APIs to verify if they are working. The APIs are going to be present
- // on older Windows versions, where they can be turned on/off. If turned off, the APIs
- // will be unusable and will return ERROR_CALL_NOT_IMPLEMENTED.
- if (local_dnsQueryRaw(nullptr, nullptr) == ERROR_CALL_NOT_IMPLEMENTED)
- {
- RETURN_IF_WIN32_ERROR_EXPECTED(ERROR_CALL_NOT_IMPLEMENTED);
- }
-
- s_dnsQueryRaw.emplace(std::move(local_dnsQueryRaw));
- s_dnsCancelQueryRaw.emplace(std::move(local_dnsCancelQueryRaw));
- s_dnsQueryRawResultFree.emplace(std::move(local_dnsQueryRawResultFree));
- return S_OK;
-}
-
-DnsResolver::DnsResolver(wil::unique_socket&& dnsHvsocket, DnsResolverFlags flags) :
- m_dnsChannel(
- std::move(dnsHvsocket),
- [this](const gsl::span dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) {
- ProcessDnsRequest(dnsBuffer, dnsClientIdentifier);
- }),
- m_flags(flags)
-{
- // Initialize as signaled, as there are no requests yet
- m_allRequestsFinished.SetEvent();
-
- // Read external interface constraint regkey
- const auto lxssKey = windows::common::registry::OpenLxssMachineKey(KEY_READ);
- m_externalInterfaceConstraintName =
- windows::common::registry::ReadString(lxssKey.get(), nullptr, c_interfaceConstraintKey, L"");
-
- if (!m_externalInterfaceConstraintName.empty())
- {
- ResolveExternalInterfaceConstraintIndex();
-
- WSL_LOG(
- "DnsResolver::DnsResolver",
- TraceLoggingValue(m_externalInterfaceConstraintName.c_str(), "m_externalInterfaceConstraintName"),
- TraceLoggingValue(m_externalInterfaceConstraintIndex, "m_externalInterfaceConstraintIndex"));
-
- // Register for interface change notifications. Notifications are used to determine if the external interface constraint setting is applicable.
- THROW_IF_WIN32_ERROR(NotifyIpInterfaceChange(AF_UNSPEC, &DnsResolver::InterfaceChangeCallback, this, FALSE, &m_interfaceNotificationHandle));
- }
-}
-
-DnsResolver::~DnsResolver() noexcept
-{
- Stop();
-}
-
-void DnsResolver::GenerateTelemetry() noexcept
-try
-{
- // Find the 3 most common DNS API failures
- uint32_t mostCommonDnsStatusError = 0;
- uint32_t mostCommonDnsStatusErrorCount = 0;
- uint32_t secondCommonDnsStatusError = 0;
- uint32_t secondCommonDnsStatusErrorCount = 0;
- uint32_t thirdCommonDnsStatusError = 0;
- uint32_t thirdCommonDnsStatusErrorCount = 0;
-
- std::vector> failures(m_dnsApiFailures.size());
- std::copy(m_dnsApiFailures.begin(), m_dnsApiFailures.end(), failures.begin());
-
- // Sort in descending order based on failure count
- std::sort(failures.begin(), failures.end(), [](const auto& lhs, const auto& rhs) { return lhs.second > rhs.second; });
-
- if (failures.size() >= 1)
- {
- mostCommonDnsStatusError = failures[0].first;
- mostCommonDnsStatusErrorCount = failures[0].second;
- }
- if (failures.size() >= 2)
- {
- secondCommonDnsStatusError = failures[1].first;
- secondCommonDnsStatusErrorCount = failures[1].second;
- }
- if (failures.size() >= 3)
- {
- thirdCommonDnsStatusError = failures[2].first;
- thirdCommonDnsStatusErrorCount = failures[2].second;
- }
-
- // Add telemetry with DNS tunneling statistics, before shutting down
- WSL_LOG(
- "DnsTunnelingStatistics",
- TraceLoggingValue(m_totalUdpQueries.load(), "totalUdpQueries"),
- TraceLoggingValue(m_successfulUdpQueries.load(), "successfulUdpQueries"),
- TraceLoggingValue(m_totalTcpQueries.load(), "totalTcpQueries"),
- TraceLoggingValue(m_successfulTcpQueries.load(), "successfulTcpQueries"),
- TraceLoggingValue(m_queriesWithNullResult.load(), "queriesWithNullResult"),
- TraceLoggingValue(m_failedDnsQueryRawCalls.load(), "FailedDnsQueryRawCalls"),
- TraceLoggingValue(m_dnsApiFailures.size(), "totalDnsStatusErrorInstances"),
- TraceLoggingValue(mostCommonDnsStatusError, "mostCommonDnsStatusError"),
- TraceLoggingValue(mostCommonDnsStatusErrorCount, "mostCommonDnsStatusErrorCount"),
- TraceLoggingValue(secondCommonDnsStatusError, "secondCommonDnsStatusError"),
- TraceLoggingValue(secondCommonDnsStatusErrorCount, "secondCommonDnsStatusErrorCount"),
- TraceLoggingValue(thirdCommonDnsStatusError, "thirdCommonDnsStatusError"),
- TraceLoggingValue(thirdCommonDnsStatusErrorCount, "thirdCommonDnsStatusErrorCount"));
-}
-CATCH_LOG()
-
-void DnsResolver::Stop() noexcept
-try
-{
- WSL_LOG("DnsResolver::Stop");
-
- // Scoped m_dnsLock
- {
- const std::lock_guard lock(m_dnsLock);
-
- m_stopped = true;
-
- // Cancel existing requests. Cancel is complete when DnsQueryRawCallback is
- // invoked with status == ERROR_CANCELLED
- // N.B. Cancelling can end up calling the DnsQueryRawCallback directly on this same thread. i.e., while this
- // lock is held. Which is fine because m_dnsLock is a recursive mutex.
- // N.B. Cancelling a query will synchronously remove the query from m_dnsRequests, which invalidates iterators.
-
- std::vector cancelHandles;
- cancelHandles.reserve(m_dnsRequests.size());
-
- for (auto& [_, context] : m_dnsRequests)
- {
- cancelHandles.emplace_back(&context->m_cancelHandle);
- }
-
- for (const auto e : cancelHandles)
- {
- LOG_IF_WIN32_ERROR(s_dnsCancelQueryRaw.value()(e));
- }
- }
-
- // Wait for all requests to complete. At this point no new requests can be started since the object is stopped.
- // We are only waiting for existing requests to finish.
- m_allRequestsFinished.wait();
-
- // Stop the response queue first as it can make calls in m_dnsChannel
- m_dnsResponseQueue.cancel();
-
- m_dnsChannel.Stop();
-
- // Stop interface change notifications
- m_interfaceNotificationHandle.reset();
-
- GenerateTelemetry();
-}
-CATCH_LOG()
-
-void DnsResolver::ProcessDnsRequest(const gsl::span dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept
-try
-{
- const std::lock_guard lock(m_dnsLock);
- if (m_stopped)
- {
- return;
- }
-
- WSL_LOG_DEBUG(
- "DnsResolver::ProcessDnsRequest - received new DNS request",
- TraceLoggingValue(dnsBuffer.size(), "DNS buffer size"),
- TraceLoggingValue(dnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"),
- TraceLoggingValue(dnsClientIdentifier.DnsClientId, "DNS client id"),
- TraceLoggingValue(!m_externalInterfaceConstraintName.empty(), "Is ExternalInterfaceConstraint configured"),
- TraceLoggingValue(m_externalInterfaceConstraintIndex, "m_externalInterfaceConstraintIndex"));
-
- // If the external interface constraint is configured but it is *not* present/up, WSL should be net-blind, so we avoid making DNS requests.
- if (!m_externalInterfaceConstraintName.empty() && m_externalInterfaceConstraintIndex == 0)
- {
- return;
- }
-
- dnsClientIdentifier.Protocol == IPPROTO_UDP ? m_totalUdpQueries++ : m_totalTcpQueries++;
-
- // Get next request id. If value reaches UINT_MAX + 1 it will be automatically reset to 0
- const auto requestId = m_currentRequestId++;
-
- // Create the DNS request context
- auto context = std::make_unique(
- requestId, dnsClientIdentifier, [this](_Inout_ DnsResolver::DnsQueryContext* context, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) {
- HandleDnsQueryCompletion(context, queryResults);
- });
-
- auto [it, _] = m_dnsRequests.emplace(requestId, std::move(context));
- const auto localContext = it->second.get();
-
- auto removeContextOnError = wil::scope_exit([&] { WI_VERIFY(m_dnsRequests.erase(requestId) == 1); });
-
- // Fill DNS request structure
- DNS_QUERY_RAW_REQUEST request{};
-
- request.version = DNS_QUERY_RAW_REQUEST_VERSION1;
- request.resultsVersion = DNS_QUERY_RAW_RESULTS_VERSION1;
- request.dnsQueryRawSize = static_cast(dnsBuffer.size());
- request.dnsQueryRaw = (PBYTE)dnsBuffer.data();
- request.protocol = (dnsClientIdentifier.Protocol == IPPROTO_TCP) ? DNS_PROTOCOL_TCP : DNS_PROTOCOL_UDP;
- request.queryCompletionCallback = DnsResolver::DnsQueryRawCallback;
- request.queryContext = localContext;
- // Only unicast UDP & TCP queries are tunneled. Pass this flag to tell Windows DNS client to *not* resolve using multicast.
- request.queryOptions |= DNS_QUERY_NO_MULTICAST;
-
- // In a DNS request from Linux there might be DNS records that Windows DNS client does not know how to parse.
- // By default in this case Windows will fail the request. When the flag is enabled, Windows will extract the
- // question from the DNS request and attempt to resolve it, ignoring the unknown records.
- if (WI_IsFlagSet(m_flags, DnsResolverFlags::BestEffortDnsParsing))
- {
- request.queryRawOptions |= DNS_QUERY_RAW_OPTION_BEST_EFFORT_PARSE;
- }
-
- // If the external interface constraint is configured and present on the host, only send DNS requests on that interface.
- if (m_externalInterfaceConstraintIndex != 0)
- {
- request.interfaceIndex = m_externalInterfaceConstraintIndex;
- }
-
- // Start the DNS request
- // N.B. All DNS requests will bypass the Windows DNS cache
- const auto result = s_dnsQueryRaw.value()(&request, &localContext->m_cancelHandle);
- if (result != DNS_REQUEST_PENDING)
- {
- m_failedDnsQueryRawCalls++;
-
- WSL_LOG(
- "ProcessDnsRequestFailed",
- TraceLoggingValue(requestId, "requestId"),
- TraceLoggingValue(result, "result"),
- TraceLoggingValue("DnsQueryRaw", "executionStep"));
- return;
- }
-
- removeContextOnError.release();
-
- m_allRequestsFinished.ResetEvent();
-}
-CATCH_LOG()
-
-void DnsResolver::HandleDnsQueryCompletion(_Inout_ DnsResolver::DnsQueryContext* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept
-try
-{
- // Always free the query result structure
- const auto freeQueryResults = wil::scope_exit([&] {
- if (queryResults != nullptr)
- {
- s_dnsQueryRawResultFree.value()(queryResults);
- }
- });
-
- const std::lock_guard lock(m_dnsLock);
-
- if (queryResults != nullptr)
- {
- WSL_LOG(
- "DnsResolver::HandleDnsQueryCompletion",
- TraceLoggingValue(queryContext->m_id, "queryContext->m_id"),
- TraceLoggingValue(queryResults->queryStatus, "queryResults->queryStatus"),
- TraceLoggingValue(queryResults->queryRawResponse != nullptr, "validResponse"));
-
- // Note: The response may be valid even if queryResults->queryStatus is not 0, for example when the DNS server returns a negative response.
- if (queryResults->queryRawResponse != nullptr)
- {
- queryContext->m_dnsClientIdentifier.Protocol == IPPROTO_UDP ? m_successfulUdpQueries++ : m_successfulTcpQueries++;
- }
- // the Windows DNS API returned failure
- else
- {
- if (m_dnsApiFailures.find(queryResults->queryStatus) == m_dnsApiFailures.end())
- {
- m_dnsApiFailures[queryResults->queryStatus] = 1;
- }
- else
- {
- m_dnsApiFailures[queryResults->queryStatus]++;
- }
- }
- }
- else
- {
- WSL_LOG(
- "DnsResolver::HandleDnsQueryCompletion - received a NULL queryResults",
- TraceLoggingValue(queryContext->m_id, "queryContext->m_id"));
- m_queriesWithNullResult++;
- }
-
- if (!m_stopped && queryResults != nullptr && queryResults->queryRawResponse != nullptr)
- {
- // Copy DNS response buffer
- std::vector dnsResponse(queryResults->queryRawResponseSize);
- CopyMemory(dnsResponse.data(), queryResults->queryRawResponse, queryResults->queryRawResponseSize);
-
- WSL_LOG_DEBUG(
- "DnsResolver::HandleDnsQueryCompletion - received new DNS response",
- TraceLoggingValue(dnsResponse.size(), "DNS buffer size"),
- TraceLoggingValue(queryContext->m_dnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"),
- TraceLoggingValue(queryContext->m_dnsClientIdentifier.DnsClientId, "DNS client id"));
-
- // Schedule the DNS response to be sent to Linux
- m_dnsResponseQueue.submit([this, dnsResponse = std::move(dnsResponse), dnsClientIdentifier = queryContext->m_dnsClientIdentifier]() mutable {
- m_dnsChannel.SendDnsMessage(gsl::make_span(dnsResponse), dnsClientIdentifier);
- });
- }
-
- // Stop tracking this DNS request and delete the request context
- WI_VERIFY(m_dnsRequests.erase(queryContext->m_id) == 1);
-
- // Set event if all tracked requests have finished
- if (m_dnsRequests.empty())
- {
- m_allRequestsFinished.SetEvent();
- }
-}
-CATCH_LOG()
-
-void DnsResolver::ResolveExternalInterfaceConstraintIndex() noexcept
-try
-{
- const std::lock_guard lock(m_dnsLock);
- if (m_stopped)
- {
- return;
- }
-
- if (m_externalInterfaceConstraintName.empty())
- {
- return;
- }
-
- NET_LUID interfaceLuid{};
- ULONG interfaceIndex = 0;
-
- // Update the interface index on every exit path.
- // The calls below to convert interface name to index will fail if the interface does not exist anymore,
- // in which case we still need to reset the interface index to its default value of 0.
- const auto setInterfaceIndex = wil::scope_exit([&] {
- if (interfaceIndex != m_externalInterfaceConstraintIndex)
- {
- WSL_LOG(
- "DnsResolver::ResolveExternalInterfaceConstraintIndex - setting m_externalInterfaceConstraintIndex to new value",
- TraceLoggingValue(m_externalInterfaceConstraintIndex, "old interface index"),
- TraceLoggingValue(interfaceIndex, "new interface index"));
-
- m_externalInterfaceConstraintIndex = interfaceIndex;
- }
- });
-
- // If external interface constraint is configured, query to see if it's present on the host.
- auto errorCode = ConvertInterfaceAliasToLuid(m_externalInterfaceConstraintName.c_str(), &interfaceLuid);
- if (FAILED_WIN32_LOG(errorCode))
- {
- return;
- }
-
- errorCode = ConvertInterfaceLuidToIndex(&interfaceLuid, reinterpret_cast(&interfaceIndex));
- if (FAILED_WIN32_LOG(errorCode))
- {
- return;
- }
-}
-CATCH_LOG()
-
-VOID CALLBACK DnsResolver::DnsQueryRawCallback(_In_ VOID* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept
-try
-{
- assert(queryContext != nullptr);
-
- const auto context = static_cast(queryContext);
-
- // Call into DnsResolver parent object to process the query result
- context->m_handleQueryCompletion(context, queryResults);
-}
-CATCH_LOG()
-
-VOID CALLBACK DnsResolver::InterfaceChangeCallback(_In_ PVOID context, PMIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE) noexcept
-try
-{
- const auto dnsResolver = static_cast(context);
- dnsResolver->ResolveExternalInterfaceConstraintIndex();
-}
-CATCH_LOG()
+// Copyright (C) Microsoft Corporation. All rights reserved.
+
+#include
+#include "precomp.h"
+#include "DnsResolver.h"
+
+using wsl::core::networking::DnsResolver;
+
+static constexpr auto c_dnsModuleName = L"dnsapi.dll";
+
+std::optional> DnsResolver::s_dnsQueryRaw;
+std::optional> DnsResolver::s_dnsCancelQueryRaw;
+std::optional> DnsResolver::s_dnsQueryRawResultFree;
+
+HRESULT DnsResolver::LoadDnsResolverMethods() noexcept
+{
+ static wil::shared_hmodule dnsModule;
+ static DWORD loadError = ERROR_SUCCESS;
+ static std::once_flag dnsLoadFlag;
+
+ // Load DNS dll only once
+ std::call_once(dnsLoadFlag, [&]() {
+ dnsModule.reset(LoadLibraryEx(c_dnsModuleName, nullptr, LOAD_LIBRARY_SEARCH_SYSTEM32));
+ if (!dnsModule)
+ {
+ loadError = GetLastError();
+ }
+ });
+
+ RETURN_IF_WIN32_ERROR_MSG(loadError, "LoadLibraryEx %ls", c_dnsModuleName);
+
+ // Initialize dynamic functions for the DNS tunneling Windows APIs.
+ // using the non-throwing instance of LxssDynamicFunction as to not end up in the Error telemetry
+ LxssDynamicFunction local_dnsQueryRaw{DynamicFunctionErrorLogs::None};
+ RETURN_IF_FAILED_EXPECTED(local_dnsQueryRaw.load(dnsModule, "DnsQueryRaw"));
+ LxssDynamicFunction local_dnsCancelQueryRaw{DynamicFunctionErrorLogs::None};
+ RETURN_IF_FAILED_EXPECTED(local_dnsCancelQueryRaw.load(dnsModule, "DnsCancelQueryRaw"));
+ LxssDynamicFunction local_dnsQueryRawResultFree{DynamicFunctionErrorLogs::None};
+ RETURN_IF_FAILED_EXPECTED(local_dnsQueryRawResultFree.load(dnsModule, "DnsQueryRawResultFree"));
+
+ // Make a dummy call to the DNS APIs to verify if they are working. The APIs are going to be present
+ // on older Windows versions, where they can be turned on/off. If turned off, the APIs
+ // will be unusable and will return ERROR_CALL_NOT_IMPLEMENTED.
+ if (local_dnsQueryRaw(nullptr, nullptr) == ERROR_CALL_NOT_IMPLEMENTED)
+ {
+ RETURN_IF_WIN32_ERROR_EXPECTED(ERROR_CALL_NOT_IMPLEMENTED);
+ }
+
+ s_dnsQueryRaw.emplace(std::move(local_dnsQueryRaw));
+ s_dnsCancelQueryRaw.emplace(std::move(local_dnsCancelQueryRaw));
+ s_dnsQueryRawResultFree.emplace(std::move(local_dnsQueryRawResultFree));
+ return S_OK;
+}
+
+DnsResolver::DnsResolver(wil::unique_socket&& dnsHvsocket, DnsResolverFlags flags) :
+ m_dnsChannel(
+ std::move(dnsHvsocket),
+ [this](const gsl::span dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) {
+ ProcessDnsRequest(dnsBuffer, dnsClientIdentifier);
+ }),
+ m_flags(flags)
+{
+ // Initialize as signaled, as there are no requests yet
+ m_allRequestsFinished.SetEvent();
+
+ // Read external interface constraint regkey
+ const auto lxssKey = windows::common::registry::OpenLxssMachineKey(KEY_READ);
+ m_externalInterfaceConstraintName =
+ windows::common::registry::ReadString(lxssKey.get(), nullptr, c_interfaceConstraintKey, L"");
+
+ if (!m_externalInterfaceConstraintName.empty())
+ {
+ ResolveExternalInterfaceConstraintIndex();
+
+ WSL_LOG(
+ "DnsResolver::DnsResolver",
+ TraceLoggingValue(m_externalInterfaceConstraintName.c_str(), "m_externalInterfaceConstraintName"),
+ TraceLoggingValue(m_externalInterfaceConstraintIndex, "m_externalInterfaceConstraintIndex"));
+
+ // Register for interface change notifications. Notifications are used to determine if the external interface constraint setting is applicable.
+ THROW_IF_WIN32_ERROR(NotifyIpInterfaceChange(AF_UNSPEC, &DnsResolver::InterfaceChangeCallback, this, FALSE, &m_interfaceNotificationHandle));
+ }
+}
+
+DnsResolver::~DnsResolver() noexcept
+{
+ Stop();
+}
+
+void DnsResolver::GenerateTelemetry() noexcept
+try
+{
+ // Find the 3 most common DNS API failures
+ uint32_t mostCommonDnsStatusError = 0;
+ uint32_t mostCommonDnsStatusErrorCount = 0;
+ uint32_t secondCommonDnsStatusError = 0;
+ uint32_t secondCommonDnsStatusErrorCount = 0;
+ uint32_t thirdCommonDnsStatusError = 0;
+ uint32_t thirdCommonDnsStatusErrorCount = 0;
+
+ std::vector> failures(m_dnsApiFailures.size());
+ std::copy(m_dnsApiFailures.begin(), m_dnsApiFailures.end(), failures.begin());
+
+ // Sort in descending order based on failure count
+ std::sort(failures.begin(), failures.end(), [](const auto& lhs, const auto& rhs) { return lhs.second > rhs.second; });
+
+ if (failures.size() >= 1)
+ {
+ mostCommonDnsStatusError = failures[0].first;
+ mostCommonDnsStatusErrorCount = failures[0].second;
+ }
+ if (failures.size() >= 2)
+ {
+ secondCommonDnsStatusError = failures[1].first;
+ secondCommonDnsStatusErrorCount = failures[1].second;
+ }
+ if (failures.size() >= 3)
+ {
+ thirdCommonDnsStatusError = failures[2].first;
+ thirdCommonDnsStatusErrorCount = failures[2].second;
+ }
+
+ // Add telemetry with DNS tunneling statistics, before shutting down
+ WSL_LOG(
+ "DnsTunnelingStatistics",
+ TraceLoggingValue(m_totalUdpQueries.load(), "totalUdpQueries"),
+ TraceLoggingValue(m_successfulUdpQueries.load(), "successfulUdpQueries"),
+ TraceLoggingValue(m_totalTcpQueries.load(), "totalTcpQueries"),
+ TraceLoggingValue(m_successfulTcpQueries.load(), "successfulTcpQueries"),
+ TraceLoggingValue(m_queriesWithNullResult.load(), "queriesWithNullResult"),
+ TraceLoggingValue(m_failedDnsQueryRawCalls.load(), "FailedDnsQueryRawCalls"),
+ TraceLoggingValue(m_dnsApiFailures.size(), "totalDnsStatusErrorInstances"),
+ TraceLoggingValue(mostCommonDnsStatusError, "mostCommonDnsStatusError"),
+ TraceLoggingValue(mostCommonDnsStatusErrorCount, "mostCommonDnsStatusErrorCount"),
+ TraceLoggingValue(secondCommonDnsStatusError, "secondCommonDnsStatusError"),
+ TraceLoggingValue(secondCommonDnsStatusErrorCount, "secondCommonDnsStatusErrorCount"),
+ TraceLoggingValue(thirdCommonDnsStatusError, "thirdCommonDnsStatusError"),
+ TraceLoggingValue(thirdCommonDnsStatusErrorCount, "thirdCommonDnsStatusErrorCount"));
+}
+CATCH_LOG()
+
+void DnsResolver::Stop() noexcept
+try
+{
+ WSL_LOG("DnsResolver::Stop");
+
+ // Scoped m_dnsLock
+ {
+ const std::lock_guard lock(m_dnsLock);
+
+ m_stopped = true;
+
+ // Cancel existing requests. Cancel is complete when DnsQueryRawCallback is
+ // invoked with status == ERROR_CANCELLED
+ // N.B. Cancelling can end up calling the DnsQueryRawCallback directly on this same thread. i.e., while this
+ // lock is held. Which is fine because m_dnsLock is a recursive mutex.
+ // N.B. Cancelling a query will synchronously remove the query from m_dnsRequests, which invalidates iterators.
+
+ std::vector cancelHandles;
+ cancelHandles.reserve(m_dnsRequests.size());
+
+ for (auto& [_, context] : m_dnsRequests)
+ {
+ cancelHandles.emplace_back(&context->m_cancelHandle);
+ }
+
+ for (const auto e : cancelHandles)
+ {
+ LOG_IF_WIN32_ERROR(s_dnsCancelQueryRaw.value()(e));
+ }
+ }
+
+ // Wait for all requests to complete. At this point no new requests can be started since the object is stopped.
+ // We are only waiting for existing requests to finish.
+ m_allRequestsFinished.wait();
+
+ // Stop the response queue first as it can make calls in m_dnsChannel
+ m_dnsResponseQueue.cancel();
+
+ m_dnsChannel.Stop();
+
+ // Stop interface change notifications
+ m_interfaceNotificationHandle.reset();
+
+ GenerateTelemetry();
+}
+CATCH_LOG()
+
+void DnsResolver::ProcessDnsRequest(const gsl::span dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept
+try
+{
+ const std::lock_guard lock(m_dnsLock);
+ if (m_stopped)
+ {
+ return;
+ }
+
+ WSL_LOG_DEBUG(
+ "DnsResolver::ProcessDnsRequest - received new DNS request",
+ TraceLoggingValue(dnsBuffer.size(), "DNS buffer size"),
+ TraceLoggingValue(dnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"),
+ TraceLoggingValue(dnsClientIdentifier.DnsClientId, "DNS client id"),
+ TraceLoggingValue(!m_externalInterfaceConstraintName.empty(), "Is ExternalInterfaceConstraint configured"),
+ TraceLoggingValue(m_externalInterfaceConstraintIndex, "m_externalInterfaceConstraintIndex"));
+
+ // If the external interface constraint is configured but it is *not* present/up, WSL should be net-blind, so we avoid making DNS requests.
+ if (!m_externalInterfaceConstraintName.empty() && m_externalInterfaceConstraintIndex == 0)
+ {
+ return;
+ }
+
+ dnsClientIdentifier.Protocol == IPPROTO_UDP ? m_totalUdpQueries++ : m_totalTcpQueries++;
+
+ // Get next request id. If value reaches UINT_MAX + 1 it will be automatically reset to 0
+ const auto requestId = m_currentRequestId++;
+
+ // Create the DNS request context
+ auto context = std::make_unique(
+ requestId, dnsClientIdentifier, [this](_Inout_ DnsResolver::DnsQueryContext* context, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) {
+ HandleDnsQueryCompletion(context, queryResults);
+ });
+
+ auto [it, _] = m_dnsRequests.emplace(requestId, std::move(context));
+ const auto localContext = it->second.get();
+
+ auto removeContextOnError = wil::scope_exit([&] { WI_VERIFY(m_dnsRequests.erase(requestId) == 1); });
+
+ // Fill DNS request structure
+ DNS_QUERY_RAW_REQUEST request{};
+
+ request.version = DNS_QUERY_RAW_REQUEST_VERSION1;
+ request.resultsVersion = DNS_QUERY_RAW_RESULTS_VERSION1;
+ request.dnsQueryRawSize = static_cast(dnsBuffer.size());
+ request.dnsQueryRaw = (PBYTE)dnsBuffer.data();
+ request.protocol = (dnsClientIdentifier.Protocol == IPPROTO_TCP) ? DNS_PROTOCOL_TCP : DNS_PROTOCOL_UDP;
+ request.queryCompletionCallback = DnsResolver::DnsQueryRawCallback;
+ request.queryContext = localContext;
+ // Only unicast UDP & TCP queries are tunneled. Pass this flag to tell Windows DNS client to *not* resolve using multicast.
+ request.queryOptions |= DNS_QUERY_NO_MULTICAST;
+
+ // In a DNS request from Linux there might be DNS records that Windows DNS client does not know how to parse.
+ // By default in this case Windows will fail the request. When the flag is enabled, Windows will extract the
+ // question from the DNS request and attempt to resolve it, ignoring the unknown records.
+ if (WI_IsFlagSet(m_flags, DnsResolverFlags::BestEffortDnsParsing))
+ {
+ request.queryRawOptions |= DNS_QUERY_RAW_OPTION_BEST_EFFORT_PARSE;
+ }
+
+ // If the external interface constraint is configured and present on the host, only send DNS requests on that interface.
+ if (m_externalInterfaceConstraintIndex != 0)
+ {
+ request.interfaceIndex = m_externalInterfaceConstraintIndex;
+ }
+
+ // Start the DNS request
+ // N.B. All DNS requests will bypass the Windows DNS cache
+ const auto result = s_dnsQueryRaw.value()(&request, &localContext->m_cancelHandle);
+ if (result != DNS_REQUEST_PENDING)
+ {
+ m_failedDnsQueryRawCalls++;
+
+ WSL_LOG(
+ "ProcessDnsRequestFailed",
+ TraceLoggingValue(requestId, "requestId"),
+ TraceLoggingValue(result, "result"),
+ TraceLoggingValue("DnsQueryRaw", "executionStep"));
+ return;
+ }
+
+ removeContextOnError.release();
+
+ m_allRequestsFinished.ResetEvent();
+}
+CATCH_LOG()
+
+void DnsResolver::HandleDnsQueryCompletion(_Inout_ DnsResolver::DnsQueryContext* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept
+try
+{
+ // Always free the query result structure
+ const auto freeQueryResults = wil::scope_exit([&] {
+ if (queryResults != nullptr)
+ {
+ s_dnsQueryRawResultFree.value()(queryResults);
+ }
+ });
+
+ const std::lock_guard lock(m_dnsLock);
+
+ if (queryResults != nullptr)
+ {
+ WSL_LOG(
+ "DnsResolver::HandleDnsQueryCompletion",
+ TraceLoggingValue(queryContext->m_id, "queryContext->m_id"),
+ TraceLoggingValue(queryResults->queryStatus, "queryResults->queryStatus"),
+ TraceLoggingValue(queryResults->queryRawResponse != nullptr, "validResponse"));
+
+ // Note: The response may be valid even if queryResults->queryStatus is not 0, for example when the DNS server returns a negative response.
+ if (queryResults->queryRawResponse != nullptr)
+ {
+ queryContext->m_dnsClientIdentifier.Protocol == IPPROTO_UDP ? m_successfulUdpQueries++ : m_successfulTcpQueries++;
+ }
+ // the Windows DNS API returned failure
+ else
+ {
+ if (m_dnsApiFailures.find(queryResults->queryStatus) == m_dnsApiFailures.end())
+ {
+ m_dnsApiFailures[queryResults->queryStatus] = 1;
+ }
+ else
+ {
+ m_dnsApiFailures[queryResults->queryStatus]++;
+ }
+ }
+ }
+ else
+ {
+ WSL_LOG(
+ "DnsResolver::HandleDnsQueryCompletion - received a NULL queryResults",
+ TraceLoggingValue(queryContext->m_id, "queryContext->m_id"));
+ m_queriesWithNullResult++;
+ }
+
+ if (!m_stopped && queryResults != nullptr && queryResults->queryRawResponse != nullptr)
+ {
+ // Copy DNS response buffer
+ std::vector dnsResponse(queryResults->queryRawResponseSize);
+ CopyMemory(dnsResponse.data(), queryResults->queryRawResponse, queryResults->queryRawResponseSize);
+
+ WSL_LOG_DEBUG(
+ "DnsResolver::HandleDnsQueryCompletion - received new DNS response",
+ TraceLoggingValue(dnsResponse.size(), "DNS buffer size"),
+ TraceLoggingValue(queryContext->m_dnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"),
+ TraceLoggingValue(queryContext->m_dnsClientIdentifier.DnsClientId, "DNS client id"));
+
+ // Schedule the DNS response to be sent to Linux
+ m_dnsResponseQueue.submit([this, dnsResponse = std::move(dnsResponse), dnsClientIdentifier = queryContext->m_dnsClientIdentifier]() mutable {
+ m_dnsChannel.SendDnsMessage(gsl::make_span(dnsResponse), dnsClientIdentifier);
+ });
+ }
+
+ // Stop tracking this DNS request and delete the request context
+ WI_VERIFY(m_dnsRequests.erase(queryContext->m_id) == 1);
+
+ // Set event if all tracked requests have finished
+ if (m_dnsRequests.empty())
+ {
+ m_allRequestsFinished.SetEvent();
+ }
+}
+CATCH_LOG()
+
+void DnsResolver::ResolveExternalInterfaceConstraintIndex() noexcept
+try
+{
+ const std::lock_guard lock(m_dnsLock);
+ if (m_stopped)
+ {
+ return;
+ }
+
+ if (m_externalInterfaceConstraintName.empty())
+ {
+ return;
+ }
+
+ NET_LUID interfaceLuid{};
+ ULONG interfaceIndex = 0;
+
+ // Update the interface index on every exit path.
+ // The calls below to convert interface name to index will fail if the interface does not exist anymore,
+ // in which case we still need to reset the interface index to its default value of 0.
+ const auto setInterfaceIndex = wil::scope_exit([&] {
+ if (interfaceIndex != m_externalInterfaceConstraintIndex)
+ {
+ WSL_LOG(
+ "DnsResolver::ResolveExternalInterfaceConstraintIndex - setting m_externalInterfaceConstraintIndex to new value",
+ TraceLoggingValue(m_externalInterfaceConstraintIndex, "old interface index"),
+ TraceLoggingValue(interfaceIndex, "new interface index"));
+
+ m_externalInterfaceConstraintIndex = interfaceIndex;
+ }
+ });
+
+ // If external interface constraint is configured, query to see if it's present on the host.
+ auto errorCode = ConvertInterfaceAliasToLuid(m_externalInterfaceConstraintName.c_str(), &interfaceLuid);
+ if (FAILED_WIN32_LOG(errorCode))
+ {
+ return;
+ }
+
+ errorCode = ConvertInterfaceLuidToIndex(&interfaceLuid, reinterpret_cast(&interfaceIndex));
+ if (FAILED_WIN32_LOG(errorCode))
+ {
+ return;
+ }
+}
+CATCH_LOG()
+
+VOID CALLBACK DnsResolver::DnsQueryRawCallback(_In_ VOID* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept
+try
+{
+ assert(queryContext != nullptr);
+
+ const auto context = static_cast(queryContext);
+
+ // Call into DnsResolver parent object to process the query result
+ context->m_handleQueryCompletion(context, queryResults);
+}
+CATCH_LOG()
+
+VOID CALLBACK DnsResolver::InterfaceChangeCallback(_In_ PVOID context, PMIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE) noexcept
+try
+{
+ const auto dnsResolver = static_cast(context);
+ dnsResolver->ResolveExternalInterfaceConstraintIndex();
+}
+CATCH_LOG()
diff --git a/src/windows/service/exe/DnsResolver.h b/src/windows/common/DnsResolver.h
similarity index 97%
rename from src/windows/service/exe/DnsResolver.h
rename to src/windows/common/DnsResolver.h
index a186737..d3ea8f2 100644
--- a/src/windows/service/exe/DnsResolver.h
+++ b/src/windows/common/DnsResolver.h
@@ -1,141 +1,141 @@
-// Copyright (C) Microsoft Corporation. All rights reserved.
-
-#pragma once
-
-#include "DnsTunnelingChannel.h"
-#include "WslCoreMessageQueue.h"
-#include "WslCoreNetworkingSupport.h"
-
-namespace wsl::core::networking {
-
-enum class DnsResolverFlags
-{
- None = 0x0,
- BestEffortDnsParsing = 0x1
-};
-DEFINE_ENUM_FLAG_OPERATORS(DnsResolverFlags);
-
-class DnsResolver
-{
-public:
- DnsResolver(wil::unique_socket&& dnsHvsocket, DnsResolverFlags flags);
- ~DnsResolver() noexcept;
-
- DnsResolver(const DnsResolver&) = delete;
- DnsResolver& operator=(const DnsResolver&) = delete;
-
- DnsResolver(DnsResolver&&) = delete;
- DnsResolver& operator=(DnsResolver&&) = delete;
-
- void Stop() noexcept;
-
- static HRESULT LoadDnsResolverMethods() noexcept;
-
-private:
- struct DnsQueryContext
- {
- // Struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request.
- LX_GNS_DNS_CLIENT_IDENTIFIER m_dnsClientIdentifier{};
-
- // Handle used to cancel the request.
- DNS_QUERY_RAW_CANCEL m_cancelHandle{};
-
- // Unique query id.
- uint32_t m_id{};
-
- // Callback to the parent object to notify about the DNS query completion.
- std::function m_handleQueryCompletion;
-
- DnsQueryContext(
- uint32_t id,
- const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier,
- std::function&& handleQueryCompletion) :
- m_dnsClientIdentifier(dnsClientIdentifier), m_id(id), m_handleQueryCompletion(std::move(handleQueryCompletion))
- {
- }
-
- ~DnsQueryContext() noexcept = default;
-
- DnsQueryContext(const DnsQueryContext&) = delete;
- DnsQueryContext& operator=(const DnsQueryContext&) = delete;
- DnsQueryContext(DnsQueryContext&&) = delete;
- DnsQueryContext& operator=(DnsQueryContext&&) = delete;
- };
-
- void GenerateTelemetry() noexcept;
-
- // Process DNS request received from Linux.
- //
- // Arguments:
- // dnsBuffer - buffer containing DNS request.
- // dnsClientIdentifier - struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request.
- void ProcessDnsRequest(const gsl::span dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept;
-
- // Handle completion of DNS query.
- //
- // Arguments:
- // dnsQueryContext - context structure for the DNS request.
- // queryResults - structure containing result of the DNS request.
- void HandleDnsQueryCompletion(_Inout_ DnsQueryContext* dnsQueryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept;
-
- void ResolveExternalInterfaceConstraintIndex() noexcept;
-
- // Callback that will be invoked by the DNS API whenever a request finishes. The callback is invoked on success, error or when request is cancelled.
- //
- // Arguments:
- // queryContext - pointer to context structure, will be a structure of type DnsQueryContext.
- // queryResults - pointer to structure containing the result of the DNS request.
- static VOID CALLBACK DnsQueryRawCallback(_In_ VOID* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept;
-
- static VOID CALLBACK InterfaceChangeCallback(_In_ PVOID context, PMIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE) noexcept;
-
- std::recursive_mutex m_dnsLock;
-
- // Flag used when shutting down the object.
- _Guarded_by_(m_dnsLock) bool m_stopped = false;
-
- // Hvsocket channel used to exchange DNS messages with Linux.
- DnsTunnelingChannel m_dnsChannel;
-
- // Queue used to send DNS responses to Linux.
- WslCoreMessageQueue m_dnsResponseQueue;
-
- // Unique id that is incremented for each request. In case the value reaches MAX_UINT and is reset to 0,
- // it's assumed previous requests with id's 0, 1, ... finished in the meantime and the id can be reused.
- _Guarded_by_(m_dnsLock) uint32_t m_currentRequestId = 0;
-
- // Mapping request id to the request context structure.
- _Guarded_by_(m_dnsLock) std::unordered_map> m_dnsRequests {};
-
- // Event that is set when all tracked DNS requests have completed.
- wil::unique_event m_allRequestsFinished{wil::EventOptions::ManualReset};
-
- // Used for handling of external interface constraint setting.
- unique_notify_handle m_interfaceNotificationHandle{};
-
- std::wstring m_externalInterfaceConstraintName;
- _Guarded_by_(m_dnsLock) ULONG m_externalInterfaceConstraintIndex = 0;
-
- const DnsResolverFlags m_flags{};
-
- // Statistics used for telemetry.
- std::atomic m_totalUdpQueries{0};
- std::atomic m_successfulUdpQueries{0};
- std::atomic m_totalTcpQueries{0};
- std::atomic m_successfulTcpQueries{0};
- std::atomic m_queriesWithNullResult{0};
- std::atomic m_failedDnsQueryRawCalls{0};
-
- _Guarded_by_(m_dnsLock) std::map m_dnsApiFailures;
-
- // Dynamic functions used for calling the DNS APIs.
-
- // Function to start a raw DNS request.
- static std::optional> s_dnsQueryRaw;
- // Function to cancel a raw DNS request.
- static std::optional> s_dnsCancelQueryRaw;
- // Function to free the structure containing the result of a raw DNS request.
- static std::optional> s_dnsQueryRawResultFree;
-};
-
-} // namespace wsl::core::networking
+// Copyright (C) Microsoft Corporation. All rights reserved.
+
+#pragma once
+
+#include "DnsTunnelingChannel.h"
+#include "WslCoreMessageQueue.h"
+#include "WslCoreNetworkingSupport.h"
+
+namespace wsl::core::networking {
+
+enum class DnsResolverFlags
+{
+ None = 0x0,
+ BestEffortDnsParsing = 0x1
+};
+DEFINE_ENUM_FLAG_OPERATORS(DnsResolverFlags);
+
+class DnsResolver
+{
+public:
+ DnsResolver(wil::unique_socket&& dnsHvsocket, DnsResolverFlags flags);
+ ~DnsResolver() noexcept;
+
+ DnsResolver(const DnsResolver&) = delete;
+ DnsResolver& operator=(const DnsResolver&) = delete;
+
+ DnsResolver(DnsResolver&&) = delete;
+ DnsResolver& operator=(DnsResolver&&) = delete;
+
+ void Stop() noexcept;
+
+ static HRESULT LoadDnsResolverMethods() noexcept;
+
+private:
+ struct DnsQueryContext
+ {
+ // Struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request.
+ LX_GNS_DNS_CLIENT_IDENTIFIER m_dnsClientIdentifier{};
+
+ // Handle used to cancel the request.
+ DNS_QUERY_RAW_CANCEL m_cancelHandle{};
+
+ // Unique query id.
+ uint32_t m_id{};
+
+ // Callback to the parent object to notify about the DNS query completion.
+ std::function m_handleQueryCompletion;
+
+ DnsQueryContext(
+ uint32_t id,
+ const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier,
+ std::function&& handleQueryCompletion) :
+ m_dnsClientIdentifier(dnsClientIdentifier), m_id(id), m_handleQueryCompletion(std::move(handleQueryCompletion))
+ {
+ }
+
+ ~DnsQueryContext() noexcept = default;
+
+ DnsQueryContext(const DnsQueryContext&) = delete;
+ DnsQueryContext& operator=(const DnsQueryContext&) = delete;
+ DnsQueryContext(DnsQueryContext&&) = delete;
+ DnsQueryContext& operator=(DnsQueryContext&&) = delete;
+ };
+
+ void GenerateTelemetry() noexcept;
+
+ // Process DNS request received from Linux.
+ //
+ // Arguments:
+ // dnsBuffer - buffer containing DNS request.
+ // dnsClientIdentifier - struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request.
+ void ProcessDnsRequest(const gsl::span dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept;
+
+ // Handle completion of DNS query.
+ //
+ // Arguments:
+ // dnsQueryContext - context structure for the DNS request.
+ // queryResults - structure containing result of the DNS request.
+ void HandleDnsQueryCompletion(_Inout_ DnsQueryContext* dnsQueryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept;
+
+ void ResolveExternalInterfaceConstraintIndex() noexcept;
+
+ // Callback that will be invoked by the DNS API whenever a request finishes. The callback is invoked on success, error or when request is cancelled.
+ //
+ // Arguments:
+ // queryContext - pointer to context structure, will be a structure of type DnsQueryContext.
+ // queryResults - pointer to structure containing the result of the DNS request.
+ static VOID CALLBACK DnsQueryRawCallback(_In_ VOID* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept;
+
+ static VOID CALLBACK InterfaceChangeCallback(_In_ PVOID context, PMIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE) noexcept;
+
+ std::recursive_mutex m_dnsLock;
+
+ // Flag used when shutting down the object.
+ _Guarded_by_(m_dnsLock) bool m_stopped = false;
+
+ // Hvsocket channel used to exchange DNS messages with Linux.
+ DnsTunnelingChannel m_dnsChannel;
+
+ // Queue used to send DNS responses to Linux.
+ WslCoreMessageQueue m_dnsResponseQueue;
+
+ // Unique id that is incremented for each request. In case the value reaches MAX_UINT and is reset to 0,
+ // it's assumed previous requests with id's 0, 1, ... finished in the meantime and the id can be reused.
+ _Guarded_by_(m_dnsLock) uint32_t m_currentRequestId = 0;
+
+ // Mapping request id to the request context structure.
+ _Guarded_by_(m_dnsLock) std::unordered_map> m_dnsRequests {};
+
+ // Event that is set when all tracked DNS requests have completed.
+ wil::unique_event m_allRequestsFinished{wil::EventOptions::ManualReset};
+
+ // Used for handling of external interface constraint setting.
+ unique_notify_handle m_interfaceNotificationHandle{};
+
+ std::wstring m_externalInterfaceConstraintName;
+ _Guarded_by_(m_dnsLock) ULONG m_externalInterfaceConstraintIndex = 0;
+
+ const DnsResolverFlags m_flags{};
+
+ // Statistics used for telemetry.
+ std::atomic m_totalUdpQueries{0};
+ std::atomic m_successfulUdpQueries{0};
+ std::atomic m_totalTcpQueries{0};
+ std::atomic m_successfulTcpQueries{0};
+ std::atomic m_queriesWithNullResult{0};
+ std::atomic m_failedDnsQueryRawCalls{0};
+
+ _Guarded_by_(m_dnsLock) std::map m_dnsApiFailures;
+
+ // Dynamic functions used for calling the DNS APIs.
+
+ // Function to start a raw DNS request.
+ static std::optional> s_dnsQueryRaw;
+ // Function to cancel a raw DNS request.
+ static std::optional> s_dnsCancelQueryRaw;
+ // Function to free the structure containing the result of a raw DNS request.
+ static std::optional> s_dnsQueryRawResultFree;
+};
+
+} // namespace wsl::core::networking
diff --git a/src/windows/service/exe/DnsTunnelingChannel.cpp b/src/windows/common/DnsTunnelingChannel.cpp
similarity index 97%
rename from src/windows/service/exe/DnsTunnelingChannel.cpp
rename to src/windows/common/DnsTunnelingChannel.cpp
index 6d34a2a..b80bf8f 100644
--- a/src/windows/service/exe/DnsTunnelingChannel.cpp
+++ b/src/windows/common/DnsTunnelingChannel.cpp
@@ -1,115 +1,115 @@
-// Copyright (C) Microsoft Corporation. All rights reserved.
-
-#include "precomp.h"
-#include "DnsTunnelingChannel.h"
-
-using wsl::core::networking::DnsTunnelingChannel;
-
-DnsTunnelingChannel::DnsTunnelingChannel(wil::unique_socket&& socket, DnsTunnelingCallback&& reportDnsRequest) :
- m_channel{std::move(socket), "DnsTunneling", m_stopEvent.get()}, m_reportDnsRequest(std::move(reportDnsRequest))
-{
- WSL_LOG("DnsTunnelingChannel::DnsTunnelingChannel [Windows]", TraceLoggingValue(m_channel.Socket(), "socket"));
-
- // Start thread waiting for incoming messages from Linux side
- m_receiveWorkerThread = std::thread([this]() { ReceiveLoop(); });
-}
-
-DnsTunnelingChannel::~DnsTunnelingChannel()
-{
- Stop();
-}
-
-void DnsTunnelingChannel::SendDnsMessage(const gsl::span dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept
-try
-{
- // Exit if channel was stopped
- if (m_stopEvent.is_signaled())
- {
- return;
- }
-
- wsl::shared::MessageWriter message(LxGnsMessageDnsTunneling);
- message->DnsClientIdentifier = dnsClientIdentifier;
- message.WriteSpan(dnsBuffer);
-
- m_channel.SendMessage(message.Span());
-}
-CATCH_LOG()
-
-void DnsTunnelingChannel::ReceiveLoop() noexcept
-{
- std::vector receiveBuffer;
-
- for (;;)
- {
- try
- {
- if (m_stopEvent.is_signaled())
- {
- return;
- }
-
- WSL_LOG_DEBUG("DnsTunnelingChannel::ReceiveLoop [Windows] - waiting for next message from Linux");
-
- // Read next message. wsl::shared::socket::RecvMessage() first reads the message header, then uses it to determine the
- // total size of the message and read the rest of the message, resizing the buffer if needed.
- auto [message, span] = m_channel.ReceiveMessageOrClosed();
- if (message == nullptr)
- {
- WSL_LOG("DnsTunnelingChannel::ReceiveLoop [Windows] - failed to read message");
- return;
- }
-
- // Get the message type from the message header
- switch (message->MessageType)
- {
- case LxGnsMessageDnsTunneling:
- {
- // Cast message to a LX_GNS_DNS_TUNNELING_MESSAGE struct
- auto* dnsMessage = gslhelpers::try_get_struct(span);
- if (!dnsMessage)
- {
- WSL_LOG(
- "DnsTunnelingChannel::ReceiveLoop [Windows] - failed to convert message to LX_GNS_DNS_TUNNELING_MESSAGE");
- return;
- }
-
- // Extract DNS buffer from message
- auto dnsBuffer = span.subspan(offsetof(LX_GNS_DNS_TUNNELING_MESSAGE, Buffer));
-
- WSL_LOG_DEBUG(
- "DnsTunnelingChannel::ReceiveLoop [Windows] - received DNS message",
- TraceLoggingValue(dnsBuffer.size(), "DNS buffer size"),
- TraceLoggingValue(dnsMessage->DnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"),
- TraceLoggingValue(dnsMessage->DnsClientIdentifier.DnsClientId, "DNS client id"));
-
- // Invoke callback to notify about the new DNS request
- m_reportDnsRequest(dnsBuffer, dnsMessage->DnsClientIdentifier);
-
- break;
- }
-
- default:
- {
- THROW_HR_MSG(E_UNEXPECTED, "Unexpected LX_MESSAGE_TYPE : %i", message->MessageType);
- }
- }
- }
- CATCH_LOG()
- }
-}
-
-void DnsTunnelingChannel::Stop() noexcept
-try
-{
- WSL_LOG("DnsTunnelingChannel::Stop [Windows]");
-
- m_stopEvent.SetEvent();
-
- // Stop receive loop
- if (m_receiveWorkerThread.joinable())
- {
- m_receiveWorkerThread.join();
- }
-}
-CATCH_LOG()
+// Copyright (C) Microsoft Corporation. All rights reserved.
+
+#include "precomp.h"
+#include "DnsTunnelingChannel.h"
+
+using wsl::core::networking::DnsTunnelingChannel;
+
+DnsTunnelingChannel::DnsTunnelingChannel(wil::unique_socket&& socket, DnsTunnelingCallback&& reportDnsRequest) :
+ m_channel{std::move(socket), "DnsTunneling", m_stopEvent.get()}, m_reportDnsRequest(std::move(reportDnsRequest))
+{
+ WSL_LOG("DnsTunnelingChannel::DnsTunnelingChannel [Windows]", TraceLoggingValue(m_channel.Socket(), "socket"));
+
+ // Start thread waiting for incoming messages from Linux side
+ m_receiveWorkerThread = std::thread([this]() { ReceiveLoop(); });
+}
+
+DnsTunnelingChannel::~DnsTunnelingChannel()
+{
+ Stop();
+}
+
+void DnsTunnelingChannel::SendDnsMessage(const gsl::span dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept
+try
+{
+ // Exit if channel was stopped
+ if (m_stopEvent.is_signaled())
+ {
+ return;
+ }
+
+ wsl::shared::MessageWriter message(LxGnsMessageDnsTunneling);
+ message->DnsClientIdentifier = dnsClientIdentifier;
+ message.WriteSpan(dnsBuffer);
+
+ m_channel.SendMessage(message.Span());
+}
+CATCH_LOG()
+
+void DnsTunnelingChannel::ReceiveLoop() noexcept
+{
+ std::vector receiveBuffer;
+
+ for (;;)
+ {
+ try
+ {
+ if (m_stopEvent.is_signaled())
+ {
+ return;
+ }
+
+ WSL_LOG_DEBUG("DnsTunnelingChannel::ReceiveLoop [Windows] - waiting for next message from Linux");
+
+ // Read next message. wsl::shared::socket::RecvMessage() first reads the message header, then uses it to determine the
+ // total size of the message and read the rest of the message, resizing the buffer if needed.
+ auto [message, span] = m_channel.ReceiveMessageOrClosed();
+ if (message == nullptr)
+ {
+ WSL_LOG("DnsTunnelingChannel::ReceiveLoop [Windows] - failed to read message");
+ return;
+ }
+
+ // Get the message type from the message header
+ switch (message->MessageType)
+ {
+ case LxGnsMessageDnsTunneling:
+ {
+ // Cast message to a LX_GNS_DNS_TUNNELING_MESSAGE struct
+ auto* dnsMessage = gslhelpers::try_get_struct(span);
+ if (!dnsMessage)
+ {
+ WSL_LOG(
+ "DnsTunnelingChannel::ReceiveLoop [Windows] - failed to convert message to LX_GNS_DNS_TUNNELING_MESSAGE");
+ return;
+ }
+
+ // Extract DNS buffer from message
+ auto dnsBuffer = span.subspan(offsetof(LX_GNS_DNS_TUNNELING_MESSAGE, Buffer));
+
+ WSL_LOG_DEBUG(
+ "DnsTunnelingChannel::ReceiveLoop [Windows] - received DNS message",
+ TraceLoggingValue(dnsBuffer.size(), "DNS buffer size"),
+ TraceLoggingValue(dnsMessage->DnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"),
+ TraceLoggingValue(dnsMessage->DnsClientIdentifier.DnsClientId, "DNS client id"));
+
+ // Invoke callback to notify about the new DNS request
+ m_reportDnsRequest(dnsBuffer, dnsMessage->DnsClientIdentifier);
+
+ break;
+ }
+
+ default:
+ {
+ THROW_HR_MSG(E_UNEXPECTED, "Unexpected LX_MESSAGE_TYPE : %i", message->MessageType);
+ }
+ }
+ }
+ CATCH_LOG()
+ }
+}
+
+void DnsTunnelingChannel::Stop() noexcept
+try
+{
+ WSL_LOG("DnsTunnelingChannel::Stop [Windows]");
+
+ m_stopEvent.SetEvent();
+
+ // Stop receive loop
+ if (m_receiveWorkerThread.joinable())
+ {
+ m_receiveWorkerThread.join();
+ }
+}
+CATCH_LOG()
diff --git a/src/windows/service/exe/DnsTunnelingChannel.h b/src/windows/common/DnsTunnelingChannel.h
similarity index 97%
rename from src/windows/service/exe/DnsTunnelingChannel.h
rename to src/windows/common/DnsTunnelingChannel.h
index 0214cf6..e05e198 100644
--- a/src/windows/service/exe/DnsTunnelingChannel.h
+++ b/src/windows/common/DnsTunnelingChannel.h
@@ -1,50 +1,50 @@
-// Copyright (C) Microsoft Corporation. All rights reserved.
-
-#pragma once
-
-#include
-#include "lxinitshared.h"
-#include "SocketChannel.h"
-
-namespace wsl::core::networking {
-
-using DnsTunnelingCallback = std::function, const LX_GNS_DNS_CLIENT_IDENTIFIER&)>;
-
-class DnsTunnelingChannel
-{
-public:
- DnsTunnelingChannel(wil::unique_socket&& socket, DnsTunnelingCallback&& reportDnsRequest);
- ~DnsTunnelingChannel();
-
- DnsTunnelingChannel(const DnsTunnelingChannel&) = delete;
- DnsTunnelingChannel& operator=(const DnsTunnelingChannel&) = delete;
-
- DnsTunnelingChannel(DnsTunnelingChannel&&) = delete;
- DnsTunnelingChannel& operator=(DnsTunnelingChannel&&) = delete;
-
- // Construct and send a LX_GNS_DNS_TUNNELING_MESSAGE message on the channel.
- // Note: Callers are responsible for sequencing calls to this method.
- //
- // Arguments:
- // dnsBuffer - buffer containing DNS response.
- // dnsClientIdentifier - struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request.
- void SendDnsMessage(const gsl::span dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept;
-
- // Stop the channel.
- void Stop() noexcept;
-
-private:
- // Wait for messages on the channel from Linux side.
- void ReceiveLoop() noexcept;
-
- wil::unique_event m_stopEvent{wil::EventOptions::ManualReset};
-
- wsl::shared::SocketChannel m_channel;
-
- std::thread m_receiveWorkerThread;
-
- // Callback used to notify when there is a new DNS request message on the channel.
- DnsTunnelingCallback m_reportDnsRequest;
-};
-
-} // namespace wsl::core::networking
+// Copyright (C) Microsoft Corporation. All rights reserved.
+
+#pragma once
+
+#include
+#include "lxinitshared.h"
+#include "SocketChannel.h"
+
+namespace wsl::core::networking {
+
+using DnsTunnelingCallback = std::function, const LX_GNS_DNS_CLIENT_IDENTIFIER&)>;
+
+class DnsTunnelingChannel
+{
+public:
+ DnsTunnelingChannel(wil::unique_socket&& socket, DnsTunnelingCallback&& reportDnsRequest);
+ ~DnsTunnelingChannel();
+
+ DnsTunnelingChannel(const DnsTunnelingChannel&) = delete;
+ DnsTunnelingChannel& operator=(const DnsTunnelingChannel&) = delete;
+
+ DnsTunnelingChannel(DnsTunnelingChannel&&) = delete;
+ DnsTunnelingChannel& operator=(DnsTunnelingChannel&&) = delete;
+
+ // Construct and send a LX_GNS_DNS_TUNNELING_MESSAGE message on the channel.
+ // Note: Callers are responsible for sequencing calls to this method.
+ //
+ // Arguments:
+ // dnsBuffer - buffer containing DNS response.
+ // dnsClientIdentifier - struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request.
+ void SendDnsMessage(const gsl::span dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept;
+
+ // Stop the channel.
+ void Stop() noexcept;
+
+private:
+ // Wait for messages on the channel from Linux side.
+ void ReceiveLoop() noexcept;
+
+ wil::unique_event m_stopEvent{wil::EventOptions::ManualReset};
+
+ wsl::shared::SocketChannel m_channel;
+
+ std::thread m_receiveWorkerThread;
+
+ // Callback used to notify when there is a new DNS request message on the channel.
+ DnsTunnelingCallback m_reportDnsRequest;
+};
+
+} // namespace wsl::core::networking
diff --git a/src/windows/service/exe/GnsChannel.cpp b/src/windows/common/GnsChannel.cpp
similarity index 100%
rename from src/windows/service/exe/GnsChannel.cpp
rename to src/windows/common/GnsChannel.cpp
diff --git a/src/windows/service/exe/GnsChannel.h b/src/windows/common/GnsChannel.h
similarity index 100%
rename from src/windows/service/exe/GnsChannel.h
rename to src/windows/common/GnsChannel.h
diff --git a/src/windows/service/exe/INetworkingEngine.h b/src/windows/common/INetworkingEngine.h
similarity index 100%
rename from src/windows/service/exe/INetworkingEngine.h
rename to src/windows/common/INetworkingEngine.h
diff --git a/src/windows/service/exe/LxssSecurity.cpp b/src/windows/common/LxssSecurity.cpp
similarity index 100%
rename from src/windows/service/exe/LxssSecurity.cpp
rename to src/windows/common/LxssSecurity.cpp
diff --git a/src/windows/service/exe/LxssSecurity.h b/src/windows/common/LxssSecurity.h
similarity index 100%
rename from src/windows/service/exe/LxssSecurity.h
rename to src/windows/common/LxssSecurity.h
diff --git a/src/windows/service/exe/NatNetworking.cpp b/src/windows/common/NatNetworking.cpp
similarity index 99%
rename from src/windows/service/exe/NatNetworking.cpp
rename to src/windows/common/NatNetworking.cpp
index b7dd080..dbad2a2 100644
--- a/src/windows/service/exe/NatNetworking.cpp
+++ b/src/windows/common/NatNetworking.cpp
@@ -6,7 +6,6 @@
#include "WslCoreHostDnsInfo.h"
#include "Stringify.h"
#include "WslCoreFirewallSupport.h"
-#include "WslCoreVm.h"
#include "hcs.hpp"
using namespace wsl::core::networking;
@@ -672,7 +671,7 @@ wsl::windows::common::hcs::unique_hcn_network NatNetworking::CreateNetwork(wsl::
wil::ResultFromException(WI_DIAGNOSTICS_INFO, [&] {
try
{
- wsl::core::networking::ConfigureHyperVFirewall(config.FirewallConfig, c_vmOwner);
+ wsl::core::networking::ConfigureHyperVFirewall(config.FirewallConfig, wsl::windows::common::wslutil::c_vmOwner);
natNetwork = CreateNetworkInternal(config);
}
catch (...)
diff --git a/src/windows/service/exe/NatNetworking.h b/src/windows/common/NatNetworking.h
similarity index 100%
rename from src/windows/service/exe/NatNetworking.h
rename to src/windows/common/NatNetworking.h
diff --git a/src/windows/service/exe/RingBuffer.cpp b/src/windows/common/RingBuffer.cpp
similarity index 96%
rename from src/windows/service/exe/RingBuffer.cpp
rename to src/windows/common/RingBuffer.cpp
index b13b058..efce04e 100644
--- a/src/windows/service/exe/RingBuffer.cpp
+++ b/src/windows/common/RingBuffer.cpp
@@ -1,154 +1,154 @@
-/*++
-
-Copyright (c) Microsoft. All rights reserved.
-
-Module Name:
-
- RingBuffer.cpp
-
-Abstract:
-
- This file contains definitions for the RingBuffer class.
-
---*/
-
-#include "precomp.h"
-#include "RingBuffer.h"
-
-RingBuffer::RingBuffer(size_t size) : m_maxSize(size), m_offset(0)
-{
- m_buffer.reserve(size);
-}
-
-void RingBuffer::Insert(std::string_view data)
-{
- auto lock = m_lock.lock_exclusive();
- auto remainingData = gsl::make_span(data.data(), data.size());
- if (remainingData.size() > m_maxSize)
- {
- remainingData = remainingData.subspan(remainingData.size() - m_maxSize);
- }
-
- const auto bytesAtEnd = std::min(m_maxSize - m_offset, remainingData.size());
- if (m_offset + bytesAtEnd > m_buffer.size())
- {
- m_buffer.resize(m_offset + bytesAtEnd);
- WI_ASSERT(m_buffer.size() <= m_maxSize);
- }
-
- const auto allBuffer = gsl::make_span(m_buffer);
- const auto beginCopyBuffer = allBuffer.subspan(m_offset, bytesAtEnd);
- copy(remainingData.subspan(0, bytesAtEnd), beginCopyBuffer);
- remainingData = remainingData.subspan(bytesAtEnd);
- if (!remainingData.empty())
- {
- copy(remainingData, allBuffer);
- m_offset = remainingData.size();
- }
- else
- {
- m_offset += bytesAtEnd;
- }
-}
-
-std::vector RingBuffer::GetLastDelimitedStrings(char Delimiter, size_t Count) const
-{
- auto lock = m_lock.lock_shared();
- auto [begin, end] = Contents();
- std::vector results;
- std::optional endIndex;
- for (size_t i = end.size(); i > 0; i--)
- {
- if (results.size() == Count)
- {
- break;
- }
-
- if (Delimiter == end[i - 1])
- {
- if (endIndex.has_value())
- {
- results.emplace(results.begin(), &end[i], endIndex.value() - i);
- endIndex.reset();
- }
- else
- {
- endIndex = i - 1;
- }
- }
- }
-
- if (results.size() == Count)
- {
- return results;
- }
-
- std::string partial;
- if (endIndex.has_value())
- {
- partial = std::string{&end[0], endIndex.value()};
- endIndex.reset();
- }
-
- for (size_t i = begin.size(); i > 0; i--)
- {
- if (results.size() == Count)
- {
- break;
- }
-
- if (Delimiter == begin[i - 1])
- {
- if (!partial.empty())
- {
- // The debug CRT will fastfail if begin[size] is accessed
- // But in this case it's not a problem because begin.size() - i would be == 0
- std::string partial_begin{&begin.data()[i], begin.size() - i};
- results.emplace(results.begin(), partial_begin + partial);
- partial.clear();
- }
- else if (endIndex.has_value())
- {
- results.emplace(results.begin(), &begin.data()[i], endIndex.value() - i);
- endIndex.reset();
- }
- else
- {
- endIndex = i - 1;
- }
- }
- }
-
- if (results.size() < Count)
- {
- // May have lost some data, or this could be the very first line logged.
- if (!partial.empty())
- {
- results.emplace(results.begin(), partial);
- }
- else if (endIndex.has_value())
- {
- results.emplace(results.begin(), &begin[0], endIndex.value());
- }
- }
-
- return results;
-}
-
-std::string RingBuffer::Get() const
-{
- auto lock = m_lock.lock_shared();
- auto [begin, end] = Contents();
- std::string data;
- data.reserve(begin.size() + end.size());
- data.append(begin.data(), begin.size());
- data.append(end.data(), end.size());
- return data;
-}
-
-std::pair RingBuffer::Contents() const
-{
- std::string_view beginView(m_buffer.data() + m_offset, m_buffer.size() - m_offset);
- std::string_view endView(m_buffer.data(), m_offset);
- return {beginView, endView};
+/*++
+
+Copyright (c) Microsoft. All rights reserved.
+
+Module Name:
+
+ RingBuffer.cpp
+
+Abstract:
+
+ This file contains definitions for the RingBuffer class.
+
+--*/
+
+#include "precomp.h"
+#include "RingBuffer.h"
+
+RingBuffer::RingBuffer(size_t size) : m_maxSize(size), m_offset(0)
+{
+ m_buffer.reserve(size);
+}
+
+void RingBuffer::Insert(std::string_view data)
+{
+ auto lock = m_lock.lock_exclusive();
+ auto remainingData = gsl::make_span(data.data(), data.size());
+ if (remainingData.size() > m_maxSize)
+ {
+ remainingData = remainingData.subspan(remainingData.size() - m_maxSize);
+ }
+
+ const auto bytesAtEnd = std::min(m_maxSize - m_offset, remainingData.size());
+ if (m_offset + bytesAtEnd > m_buffer.size())
+ {
+ m_buffer.resize(m_offset + bytesAtEnd);
+ WI_ASSERT(m_buffer.size() <= m_maxSize);
+ }
+
+ const auto allBuffer = gsl::make_span(m_buffer);
+ const auto beginCopyBuffer = allBuffer.subspan(m_offset, bytesAtEnd);
+ copy(remainingData.subspan(0, bytesAtEnd), beginCopyBuffer);
+ remainingData = remainingData.subspan(bytesAtEnd);
+ if (!remainingData.empty())
+ {
+ copy(remainingData, allBuffer);
+ m_offset = remainingData.size();
+ }
+ else
+ {
+ m_offset += bytesAtEnd;
+ }
+}
+
+std::vector RingBuffer::GetLastDelimitedStrings(char Delimiter, size_t Count) const
+{
+ auto lock = m_lock.lock_shared();
+ auto [begin, end] = Contents();
+ std::vector results;
+ std::optional endIndex;
+ for (size_t i = end.size(); i > 0; i--)
+ {
+ if (results.size() == Count)
+ {
+ break;
+ }
+
+ if (Delimiter == end[i - 1])
+ {
+ if (endIndex.has_value())
+ {
+ results.emplace(results.begin(), &end[i], endIndex.value() - i);
+ endIndex.reset();
+ }
+ else
+ {
+ endIndex = i - 1;
+ }
+ }
+ }
+
+ if (results.size() == Count)
+ {
+ return results;
+ }
+
+ std::string partial;
+ if (endIndex.has_value())
+ {
+ partial = std::string{&end[0], endIndex.value()};
+ endIndex.reset();
+ }
+
+ for (size_t i = begin.size(); i > 0; i--)
+ {
+ if (results.size() == Count)
+ {
+ break;
+ }
+
+ if (Delimiter == begin[i - 1])
+ {
+ if (!partial.empty())
+ {
+ // The debug CRT will fastfail if begin[size] is accessed
+ // But in this case it's not a problem because begin.size() - i would be == 0
+ std::string partial_begin{&begin.data()[i], begin.size() - i};
+ results.emplace(results.begin(), partial_begin + partial);
+ partial.clear();
+ }
+ else if (endIndex.has_value())
+ {
+ results.emplace(results.begin(), &begin.data()[i], endIndex.value() - i);
+ endIndex.reset();
+ }
+ else
+ {
+ endIndex = i - 1;
+ }
+ }
+ }
+
+ if (results.size() < Count)
+ {
+ // May have lost some data, or this could be the very first line logged.
+ if (!partial.empty())
+ {
+ results.emplace(results.begin(), partial);
+ }
+ else if (endIndex.has_value())
+ {
+ results.emplace(results.begin(), &begin[0], endIndex.value());
+ }
+ }
+
+ return results;
+}
+
+std::string RingBuffer::Get() const
+{
+ auto lock = m_lock.lock_shared();
+ auto [begin, end] = Contents();
+ std::string data;
+ data.reserve(begin.size() + end.size());
+ data.append(begin.data(), begin.size());
+ data.append(end.data(), end.size());
+ return data;
+}
+
+std::pair RingBuffer::Contents() const
+{
+ std::string_view beginView(m_buffer.data() + m_offset, m_buffer.size() - m_offset);
+ std::string_view endView(m_buffer.data(), m_offset);
+ return {beginView, endView};
}
\ No newline at end of file
diff --git a/src/windows/service/exe/RingBuffer.h b/src/windows/common/RingBuffer.h
similarity index 94%
rename from src/windows/service/exe/RingBuffer.h
rename to src/windows/common/RingBuffer.h
index f8d204d..2786ebf 100644
--- a/src/windows/service/exe/RingBuffer.h
+++ b/src/windows/common/RingBuffer.h
@@ -1,34 +1,34 @@
-/*++
-
-Copyright (c) Microsoft. All rights reserved.
-
-Module Name:
-
- RingBuffer.h
-
-Abstract:
-
- This file contains declarations for the RingBuffer class.
-
---*/
-
-#pragma once
-
-class RingBuffer
-{
-public:
- RingBuffer() = delete;
- RingBuffer(size_t size);
-
- void Insert(std::string_view data);
- std::vector GetLastDelimitedStrings(char Delimiter, size_t Count) const;
- std::string Get() const;
-
-private:
- std::pair Contents() const;
-
- mutable wil::srwlock m_lock;
- std::vector m_buffer;
- size_t m_maxSize;
- size_t m_offset;
+/*++
+
+Copyright (c) Microsoft. All rights reserved.
+
+Module Name:
+
+ RingBuffer.h
+
+Abstract:
+
+ This file contains declarations for the RingBuffer class.
+
+--*/
+
+#pragma once
+
+class RingBuffer
+{
+public:
+ RingBuffer() = delete;
+ RingBuffer(size_t size);
+
+ void Insert(std::string_view data);
+ std::vector GetLastDelimitedStrings(char Delimiter, size_t Count) const;
+ std::string Get() const;
+
+private:
+ std::pair Contents() const;
+
+ mutable wil::srwlock m_lock;
+ std::vector m_buffer;
+ size_t m_maxSize;
+ size_t m_offset;
};
\ No newline at end of file
diff --git a/src/windows/service/exe/WslCoreFilesystem.cpp b/src/windows/common/WslCoreFilesystem.cpp
similarity index 100%
rename from src/windows/service/exe/WslCoreFilesystem.cpp
rename to src/windows/common/WslCoreFilesystem.cpp
diff --git a/src/windows/service/exe/WslCoreFilesystem.h b/src/windows/common/WslCoreFilesystem.h
similarity index 100%
rename from src/windows/service/exe/WslCoreFilesystem.h
rename to src/windows/common/WslCoreFilesystem.h
diff --git a/src/windows/service/exe/WslCoreHostDnsInfo.cpp b/src/windows/common/WslCoreHostDnsInfo.cpp
similarity index 100%
rename from src/windows/service/exe/WslCoreHostDnsInfo.cpp
rename to src/windows/common/WslCoreHostDnsInfo.cpp
diff --git a/src/windows/service/exe/WslCoreHostDnsInfo.h b/src/windows/common/WslCoreHostDnsInfo.h
similarity index 95%
rename from src/windows/service/exe/WslCoreHostDnsInfo.h
rename to src/windows/common/WslCoreHostDnsInfo.h
index 1813574..8fa4894 100644
--- a/src/windows/service/exe/WslCoreHostDnsInfo.h
+++ b/src/windows/common/WslCoreHostDnsInfo.h
@@ -1,104 +1,104 @@
-// Copyright (C) Microsoft Corporation. All rights reserved.
-
-#pragma once
-#include
-#include
-#include
-
-#include
-#include
-
-#include "WslCoreNetworkingSupport.h"
-#include "RegistryWatcher.h"
-
-namespace wsl::core::networking {
-struct DnsInfo
-{
- std::vector Servers;
- std::vector Domains;
-};
-
-enum class DnsSettingsFlags
-{
- None = 0x0,
- IncludeVpn = 0x1,
- IncludeIpv6Servers = 0x2,
- IncludeAllSuffixes = 0x4
-};
-DEFINE_ENUM_FLAG_OPERATORS(DnsSettingsFlags);
-
-inline bool operator==(const DnsInfo& lhs, const DnsInfo& rhs) noexcept
-{
- return lhs.Servers == rhs.Servers && lhs.Domains == rhs.Domains;
-}
-inline bool operator!=(const DnsInfo& lhs, const DnsInfo& rhs) noexcept
-{
- return !(lhs == rhs);
-}
-
-std::string GenerateResolvConf(_In_ const DnsInfo& Info);
-
-std::vector GetAllDnsSuffixes(const std::vector& AdapterAddresses);
-
-DWORD GetBestInterface();
-
-class HostDnsInfo
-{
-public:
- DnsInfo GetDnsSettings(_In_ DnsSettingsFlags Flags);
-
- void UpdateNetworkInformation();
-
- static DnsInfo GetDnsTunnelingSettings(const std::wstring& dnsTunnelingNameserver);
-
- const std::vector& CurrentAddresses() const
- {
- return m_addresses;
- }
-
-private:
- ///
- /// Internal function to retrieve the latest copy of interface information.
- ///
- std::vector GetAdapterAddresses();
-
- ///
- /// Internal function to retrieve interface DNS servers.
- ///
- std::vector GetInterfaceDnsServers(const std::vector& AdapterAddresses, _In_ DnsSettingsFlags Flags);
-
- ///
- /// Internal function to retrieve all Windows DNS suffixes.
- ///
- static std::vector GetInterfaceDnsSuffixes(const std::vector& AdapterAddresses);
-
- ///
- /// Internal function to convert DNS server addresses into strings.
- ///
- static std::vector GetDnsServerStrings(_In_ const PIP_ADAPTER_DNS_SERVER_ADDRESS& DnsServer, _In_ USHORT IpFamilyFilter, _In_ USHORT MaxValues);
-
- ///
- /// Stores latest copy of interface information.
- ///
- std::mutex m_lock;
- _Guarded_by_(m_lock) std::vector m_addresses;
-};
-
-using RegistryChangeCallback = std::function;
-
-///
-/// Class used to get notifications when Windows DNS suffixes are updated in registry.
-///
-class DnsSuffixRegistryWatcher
-{
-public:
- DnsSuffixRegistryWatcher(RegistryChangeCallback&& reportRegistryChange);
- ~DnsSuffixRegistryWatcher() noexcept = default;
-
-private:
- RegistryChangeCallback m_reportRegistryChange;
-
- std::vector> m_registryWatchers;
-};
-
+// Copyright (C) Microsoft Corporation. All rights reserved.
+
+#pragma once
+#include
+#include
+#include
+
+#include
+#include
+
+#include "WslCoreNetworkingSupport.h"
+#include "RegistryWatcher.h"
+
+namespace wsl::core::networking {
+struct DnsInfo
+{
+ std::vector Servers;
+ std::vector Domains;
+};
+
+enum class DnsSettingsFlags
+{
+ None = 0x0,
+ IncludeVpn = 0x1,
+ IncludeIpv6Servers = 0x2,
+ IncludeAllSuffixes = 0x4
+};
+DEFINE_ENUM_FLAG_OPERATORS(DnsSettingsFlags);
+
+inline bool operator==(const DnsInfo& lhs, const DnsInfo& rhs) noexcept
+{
+ return lhs.Servers == rhs.Servers && lhs.Domains == rhs.Domains;
+}
+inline bool operator!=(const DnsInfo& lhs, const DnsInfo& rhs) noexcept
+{
+ return !(lhs == rhs);
+}
+
+std::string GenerateResolvConf(_In_ const DnsInfo& Info);
+
+std::vector GetAllDnsSuffixes(const std::vector& AdapterAddresses);
+
+DWORD GetBestInterface();
+
+class HostDnsInfo
+{
+public:
+ DnsInfo GetDnsSettings(_In_ DnsSettingsFlags Flags);
+
+ void UpdateNetworkInformation();
+
+ static DnsInfo GetDnsTunnelingSettings(const std::wstring& dnsTunnelingNameserver);
+
+ const std::vector& CurrentAddresses() const
+ {
+ return m_addresses;
+ }
+
+private:
+ ///
+ /// Internal function to retrieve the latest copy of interface information.
+ ///
+ std::vector GetAdapterAddresses();
+
+ ///
+ /// Internal function to retrieve interface DNS servers.
+ ///
+ std::vector GetInterfaceDnsServers(const std::vector& AdapterAddresses, _In_ DnsSettingsFlags Flags);
+
+ ///
+ /// Internal function to retrieve all Windows DNS suffixes.
+ ///
+ static std::vector GetInterfaceDnsSuffixes(const std::vector& AdapterAddresses);
+
+ ///
+ /// Internal function to convert DNS server addresses into strings.
+ ///
+ static std::vector GetDnsServerStrings(_In_ const PIP_ADAPTER_DNS_SERVER_ADDRESS& DnsServer, _In_ USHORT IpFamilyFilter, _In_ USHORT MaxValues);
+
+ ///
+ /// Stores latest copy of interface information.
+ ///
+ std::mutex m_lock;
+ _Guarded_by_(m_lock) std::vector m_addresses;
+};
+
+using RegistryChangeCallback = std::function;
+
+///
+/// Class used to get notifications when Windows DNS suffixes are updated in registry.
+///
+class DnsSuffixRegistryWatcher
+{
+public:
+ DnsSuffixRegistryWatcher(RegistryChangeCallback&& reportRegistryChange);
+ ~DnsSuffixRegistryWatcher() noexcept = default;
+
+private:
+ RegistryChangeCallback m_reportRegistryChange;
+
+ std::vector> m_registryWatchers;
+};
+
} // namespace wsl::core::networking
\ No newline at end of file
diff --git a/src/windows/service/exe/WslCoreMessageQueue.h b/src/windows/common/WslCoreMessageQueue.h
similarity index 96%
rename from src/windows/service/exe/WslCoreMessageQueue.h
rename to src/windows/common/WslCoreMessageQueue.h
index d2598ed..d8f3b8d 100644
--- a/src/windows/service/exe/WslCoreMessageQueue.h
+++ b/src/windows/common/WslCoreMessageQueue.h
@@ -1,359 +1,359 @@
-/*++
-
-Copyright (c) Microsoft. All rights reserved.
-
-Module Name:
-
- WslCoreMessageQueue.h
-
-Abstract:
-
- This file contains a queuing implementation, guaranteeing running function objects
- with guaranteed serialization in a threadpool thread
-
---*/
-
-#pragma once
-#include
-#include
-#include
-#include
-#include
-#include
-
-namespace wsl::core {
-// forward-declare classes that can instantiate a WslThreadPoolWaitableResult object
-class WslCoreMessageQueue;
-
-class WslBaseThreadPoolWaitableResult
-{
-public:
- virtual ~WslBaseThreadPoolWaitableResult() noexcept = default;
-
-private:
- // limit who can run() and abort()
- friend class WslCoreMessageQueue;
-
- virtual void run() noexcept = 0;
- virtual void abort() noexcept = 0;
-};
-
-template
-class WslThreadPoolWaitableResult : public WslBaseThreadPoolWaitableResult
-{
-public:
- // throws a wil exception on failure
- template
- explicit WslThreadPoolWaitableResult(FunctorType&& functor) : m_function(std::forward(functor))
- {
- }
-
- ~WslThreadPoolWaitableResult() noexcept override = default;
-
- // returns ERROR_SUCCESS if the callback ran to completion
- // returns ERROR_TIMEOUT if this wait timed out
- // - this can be called multiple times if needing to probe
- // any other error code resulted from attempting to run the callback
- // - meaning it did *not* run to completion
- DWORD wait(DWORD timeout) const noexcept
- {
- if (!m_completionSignal.wait(timeout))
- {
- // not setting m_internalError to timeout
- // since the caller is allowed to try to wait() again later
- return ERROR_TIMEOUT;
- }
- const auto lock = m_lock.lock_shared();
- return m_internalError;
- }
-
- // waitable event handle, signaled when the callback has run to completion (or failed)
- HANDLE notification_event() const noexcept
- {
- return m_completionSignal.get();
- }
-
- const TReturn& read_result() const noexcept
- {
- return result;
- }
-
- // move the result out of the object for move-only types
- TReturn move_result() noexcept
- {
- TReturn move_out(std::move(result));
- return move_out;
- }
-
- // non-copyable
- WslThreadPoolWaitableResult(const WslThreadPoolWaitableResult&) = delete;
- WslThreadPoolWaitableResult& operator=(const WslThreadPoolWaitableResult&) = delete;
-
-private:
- void run() noexcept override
- {
- // we are now running in the TP callback
- {
- const auto lock = m_lock.lock_exclusive();
- if (m_runStatus != RunStatus::NotYetRun)
- {
- // return early - the caller has already canceled this
- return;
- }
- m_runStatus = RunStatus::Running;
- }
-
- DWORD error = NO_ERROR;
- try
- {
- result = std::move(m_function());
- }
- catch (...)
- {
- const HRESULT hr = wil::ResultFromCaughtException();
- // HRESULT_TO_WIN32
- error = (HRESULT_FACILITY(hr) == FACILITY_WIN32) ? HRESULT_CODE(hr) : hr;
- }
-
- const auto lock = m_lock.lock_exclusive();
- WI_ASSERT(m_runStatus == RunStatus::Running);
- m_runStatus = RunStatus::RanToCompletion;
- m_internalError = error;
- m_completionSignal.SetEvent();
- }
-
- void abort() noexcept override
- {
- const auto lock = m_lock.lock_exclusive();
- // only override the error if we know we haven't started running their functor
- if (m_runStatus == RunStatus::NotYetRun)
- {
- m_runStatus = RunStatus::Canceled;
- m_internalError = ERROR_CANCELLED;
- m_completionSignal.SetEvent();
- }
- }
-
- std::function m_function;
- // a notification event
- wil::unique_event m_completionSignal{wil::EventOptions::ManualReset};
- mutable wil::srwlock m_lock;
- TReturn result{};
- DWORD m_internalError = NO_ERROR;
-
- enum class RunStatus
- {
- NotYetRun,
- Running,
- RanToCompletion,
- Canceled
- } m_runStatus{RunStatus::NotYetRun};
-};
-
-class WslCoreMessageQueue
-{
-public:
- WslCoreMessageQueue() : m_tpEnvironment(0, 1)
- {
- // create a single-threaded threadpool
- m_tpHandle = m_tpEnvironment.create_tp(WorkCallback, this);
- }
-
- template
- std::shared_ptr> submit_with_results(FunctorType&& functor) noexcept
- try
- {
- FAIL_FAST_IF(m_tpHandle.get() == nullptr);
-
- const auto new_result = std::make_shared>(std::forward(functor));
- // scope to the queue lock
- {
- const auto queueLock = m_lock.lock_exclusive();
- THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_CANCELLED), m_isCanceled);
- m_workItems.emplace_back(new_result);
- }
-
- // always maintain a 1:1 ratio for calls to SubmitWorkWithResults() and ::SubmitThreadpoolWork
- SubmitThreadpoolWork(m_tpHandle.get());
- return new_result;
- }
- catch (...)
- {
- LOG_CAUGHT_EXCEPTION();
- return nullptr;
- }
-
- template
- bool submit(FunctorType&& functor) noexcept
- try
- {
- FAIL_FAST_IF(m_tpHandle.get() == nullptr);
-
- // scope to the queue lock
- {
- const auto queueLock = m_lock.lock_exclusive();
- THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_CANCELLED), m_isCanceled);
- m_workItems.emplace_back(std::forward(functor));
- }
-
- // always maintain a 1:1 ratio for calls to SubmitWork() and ::SubmitThreadpoolWork
- SubmitThreadpoolWork(m_tpHandle.get());
- return true;
- }
- catch (...)
- {
- LOG_CAUGHT_EXCEPTION();
- return false;
- }
-
- // functors must return type HRESULT
- template
- HRESULT submit_and_wait(FunctorType&& functor) noexcept
- try
- {
- HRESULT hr = HRESULT_FROM_WIN32(ERROR_OUTOFMEMORY);
- if (const auto waitableResult = submit_with_results(std::forward(functor)))
- {
- hr = HRESULT_FROM_WIN32(waitableResult->wait(INFINITE));
- if (SUCCEEDED(hr))
- {
- hr = waitableResult->read_result();
- }
- }
- return hr;
- }
- CATCH_RETURN()
-
- // cancels anything queued to the TP - this WslCoreMessageQueue instance can no longer be used
- void cancel() noexcept
- try
- {
- if (m_tpHandle)
- {
- // immediately release anyone waiting for these workitems not yet run
- {
- const auto queueLock = m_lock.lock_exclusive();
- m_isCanceled = true;
-
- for (const auto& work : m_workItems)
- {
- // signal that these are canceled before we shutdown the TP which they could be scheduled
- if (const auto* pWaitableWorkitem = std::get_if(&work))
- {
- (*pWaitableWorkitem)->abort();
- }
- }
-
- m_workItems.clear();
- }
-
- // force the m_tpHandle to wait and close the TP
- m_tpHandle.reset();
- m_tpEnvironment.reset();
- }
- }
- CATCH_LOG()
-
- bool isRunningInQueue() const noexcept
- {
- const auto currentThreadId = GetThreadId(GetCurrentThread());
- return currentThreadId == static_cast(InterlockedCompareExchange64(&m_threadpoolThreadId, 0ll, 0ll));
- }
-
- ~WslCoreMessageQueue() noexcept
- {
- cancel();
- }
-
- WslCoreMessageQueue(const WslCoreMessageQueue&) = delete;
- WslCoreMessageQueue& operator=(const WslCoreMessageQueue&) = delete;
- WslCoreMessageQueue(WslCoreMessageQueue&&) = delete;
- WslCoreMessageQueue& operator=(WslCoreMessageQueue&&) = delete;
-
-private:
- struct TPEnvironment
- {
- using unique_tp_env = wil::unique_struct;
- unique_tp_env m_tpEnvironment;
-
- using unique_tp_pool = wil::unique_any;
- unique_tp_pool m_threadPool;
-
- TPEnvironment(DWORD countMinThread, DWORD countMaxThread)
- {
- InitializeThreadpoolEnvironment(&m_tpEnvironment);
-
- m_threadPool.reset(CreateThreadpool(nullptr));
- THROW_LAST_ERROR_IF_NULL(m_threadPool.get());
-
- // Set min and max thread counts for custom thread pool
- THROW_LAST_ERROR_IF(!::SetThreadpoolThreadMinimum(m_threadPool.get(), countMinThread));
- SetThreadpoolThreadMaximum(m_threadPool.get(), countMaxThread);
- SetThreadpoolCallbackPool(&m_tpEnvironment, m_threadPool.get());
- }
-
- wil::unique_threadpool_work create_tp(PTP_WORK_CALLBACK callback, void* pv)
- {
- wil::unique_threadpool_work newThreadpool(CreateThreadpoolWork(callback, pv, (m_threadPool) ? &m_tpEnvironment : nullptr));
- THROW_LAST_ERROR_IF_NULL(newThreadpool.get());
- return newThreadpool;
- }
-
- void reset()
- {
- m_threadPool.reset();
- m_tpEnvironment.reset();
- }
- };
-
- using SimpleFunction_t = std::function;
- using WaitableFunction_t = std::shared_ptr;
- using FunctionVariant_t = std::variant;
-
- // the lock must be destroyed *after* the TP object (thus must be declared first)
- // since the lock is used in the TP callback
- // the lock is mutable to allow us to acquire the lock in const methods
- mutable wil::srwlock m_lock;
- TPEnvironment m_tpEnvironment;
- wil::unique_threadpool_work m_tpHandle;
- std::deque m_workItems;
- mutable LONG64 m_threadpoolThreadId{0}; // useful for callers to assert they are running within the queue
- bool m_isCanceled{false};
-
- static void CALLBACK WorkCallback(PTP_CALLBACK_INSTANCE, void* Context, PTP_WORK) noexcept
- try
- {
- auto* pThis = static_cast(Context);
-
- FunctionVariant_t work;
- {
- const auto queueLock = pThis->m_lock.lock_exclusive();
-
- if (pThis->m_workItems.empty())
- {
- // pThis object is being destroyed and the queue was cleared
- return;
- }
-
- std::swap(work, pThis->m_workItems.front());
- pThis->m_workItems.pop_front();
-
- InterlockedExchange64(&pThis->m_threadpoolThreadId, GetThreadId(GetCurrentThread()));
- }
-
- // run the tasks outside the WslCoreMessageQueue lock
- const auto resetThreadIdOnExit = wil::scope_exit([pThis] { InterlockedExchange64(&pThis->m_threadpoolThreadId, 0ll); });
- if (work.index() == 0)
- {
- const auto& workItem = std::get(work);
- workItem();
- }
- else
- {
- const auto& waitableWorkItem = std::get(work);
- waitableWorkItem->run();
- }
- }
- CATCH_LOG()
-};
-} // namespace wsl::core
+/*++
+
+Copyright (c) Microsoft. All rights reserved.
+
+Module Name:
+
+ WslCoreMessageQueue.h
+
+Abstract:
+
+ This file contains a queuing implementation, guaranteeing running function objects
+ with guaranteed serialization in a threadpool thread
+
+--*/
+
+#pragma once
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace wsl::core {
+// forward-declare classes that can instantiate a WslThreadPoolWaitableResult object
+class WslCoreMessageQueue;
+
+class WslBaseThreadPoolWaitableResult
+{
+public:
+ virtual ~WslBaseThreadPoolWaitableResult() noexcept = default;
+
+private:
+ // limit who can run() and abort()
+ friend class WslCoreMessageQueue;
+
+ virtual void run() noexcept = 0;
+ virtual void abort() noexcept = 0;
+};
+
+template
+class WslThreadPoolWaitableResult : public WslBaseThreadPoolWaitableResult
+{
+public:
+ // throws a wil exception on failure
+ template
+ explicit WslThreadPoolWaitableResult(FunctorType&& functor) : m_function(std::forward(functor))
+ {
+ }
+
+ ~WslThreadPoolWaitableResult() noexcept override = default;
+
+ // returns ERROR_SUCCESS if the callback ran to completion
+ // returns ERROR_TIMEOUT if this wait timed out
+ // - this can be called multiple times if needing to probe
+ // any other error code resulted from attempting to run the callback
+ // - meaning it did *not* run to completion
+ DWORD wait(DWORD timeout) const noexcept
+ {
+ if (!m_completionSignal.wait(timeout))
+ {
+ // not setting m_internalError to timeout
+ // since the caller is allowed to try to wait() again later
+ return ERROR_TIMEOUT;
+ }
+ const auto lock = m_lock.lock_shared();
+ return m_internalError;
+ }
+
+ // waitable event handle, signaled when the callback has run to completion (or failed)
+ HANDLE notification_event() const noexcept
+ {
+ return m_completionSignal.get();
+ }
+
+ const TReturn& read_result() const noexcept
+ {
+ return result;
+ }
+
+ // move the result out of the object for move-only types
+ TReturn move_result() noexcept
+ {
+ TReturn move_out(std::move(result));
+ return move_out;
+ }
+
+ // non-copyable
+ WslThreadPoolWaitableResult(const WslThreadPoolWaitableResult&) = delete;
+ WslThreadPoolWaitableResult& operator=(const WslThreadPoolWaitableResult&) = delete;
+
+private:
+ void run() noexcept override
+ {
+ // we are now running in the TP callback
+ {
+ const auto lock = m_lock.lock_exclusive();
+ if (m_runStatus != RunStatus::NotYetRun)
+ {
+ // return early - the caller has already canceled this
+ return;
+ }
+ m_runStatus = RunStatus::Running;
+ }
+
+ DWORD error = NO_ERROR;
+ try
+ {
+ result = std::move(m_function());
+ }
+ catch (...)
+ {
+ const HRESULT hr = wil::ResultFromCaughtException();
+ // HRESULT_TO_WIN32
+ error = (HRESULT_FACILITY(hr) == FACILITY_WIN32) ? HRESULT_CODE(hr) : hr;
+ }
+
+ const auto lock = m_lock.lock_exclusive();
+ WI_ASSERT(m_runStatus == RunStatus::Running);
+ m_runStatus = RunStatus::RanToCompletion;
+ m_internalError = error;
+ m_completionSignal.SetEvent();
+ }
+
+ void abort() noexcept override
+ {
+ const auto lock = m_lock.lock_exclusive();
+ // only override the error if we know we haven't started running their functor
+ if (m_runStatus == RunStatus::NotYetRun)
+ {
+ m_runStatus = RunStatus::Canceled;
+ m_internalError = ERROR_CANCELLED;
+ m_completionSignal.SetEvent();
+ }
+ }
+
+ std::function m_function;
+ // a notification event
+ wil::unique_event m_completionSignal{wil::EventOptions::ManualReset};
+ mutable wil::srwlock m_lock;
+ TReturn result{};
+ DWORD m_internalError = NO_ERROR;
+
+ enum class RunStatus
+ {
+ NotYetRun,
+ Running,
+ RanToCompletion,
+ Canceled
+ } m_runStatus{RunStatus::NotYetRun};
+};
+
+class WslCoreMessageQueue
+{
+public:
+ WslCoreMessageQueue() : m_tpEnvironment(0, 1)
+ {
+ // create a single-threaded threadpool
+ m_tpHandle = m_tpEnvironment.create_tp(WorkCallback, this);
+ }
+
+ template
+ std::shared_ptr> submit_with_results(FunctorType&& functor) noexcept
+ try
+ {
+ FAIL_FAST_IF(m_tpHandle.get() == nullptr);
+
+ const auto new_result = std::make_shared>(std::forward(functor));
+ // scope to the queue lock
+ {
+ const auto queueLock = m_lock.lock_exclusive();
+ THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_CANCELLED), m_isCanceled);
+ m_workItems.emplace_back(new_result);
+ }
+
+ // always maintain a 1:1 ratio for calls to SubmitWorkWithResults() and ::SubmitThreadpoolWork
+ SubmitThreadpoolWork(m_tpHandle.get());
+ return new_result;
+ }
+ catch (...)
+ {
+ LOG_CAUGHT_EXCEPTION();
+ return nullptr;
+ }
+
+ template