wsla: Create wslaservice.exe (#13623)

* Save state

* Save state

* Save state

* Save state

* Save state

* Save state

* Save state

* Save state

* Save state

* Cleanup for review

* Update ServiceMain.cpp comment

* Remove duplicated definitions from wslservice.idl
This commit is contained in:
Blue 2025-10-23 17:44:57 -07:00 committed by GitHub
parent 3f24caaf73
commit 451a7e103a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
57 changed files with 4209 additions and 2480 deletions

View File

@ -388,6 +388,7 @@ include_directories(${WSLDEPS_SOURCE_DIR}/include/lxcore)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/shared/inc)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/windows/inc)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/src/windows/service/inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE})
include_directories(${CMAKE_CURRENT_BINARY_DIR}/src/windows/wslaservice/inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE})
include_directories(${CMAKE_CURRENT_BINARY_DIR}/src/windows/wslinstaller/inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/linux/init/inc)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/windows/common)
@ -413,6 +414,7 @@ add_subdirectory(msipackage)
add_subdirectory(msixinstaller)
add_subdirectory(src/windows/common)
add_subdirectory(src/windows/service)
add_subdirectory(src/windows/wslaservice)
add_subdirectory(src/windows/wslinstaller/inc)
add_subdirectory(src/windows/wslinstaller/stub)
add_subdirectory(src/windows/wslinstaller/exe)

View File

@ -3,7 +3,7 @@ set(OUTPUT_PACKAGE ${BIN}/wsl.msi)
set(PACKAGE_WIX_IN ${CMAKE_CURRENT_LIST_DIR}/package.wix.in)
set(PACKAGE_WIX ${BIN}/package.wix)
set(CAB_CACHE ${BIN}/cab)
set(BINARIES wsl.exe;wslg.exe;wslhost.exe;wslrelay.exe;wslservice.exe;wslserviceproxystub.dll;init;initrd.img;wslinstall.dll)
set(BINARIES wsl.exe;wslg.exe;wslhost.exe;wslrelay.exe;wslservice.exe;wslserviceproxystub.dll;init;initrd.img;wslinstall.dll;wslaserviceproxystub.dll;wslaservice.exe)
if (WSL_BUILD_WSL_SETTINGS)
list(APPEND BINARIES_DEPENDENCIES "wslsettings/wslsettings.dll;wslsettings/wslsettings.exe;libwsl.dll")
@ -39,7 +39,7 @@ add_custom_command(
add_custom_target(msipackage DEPENDS ${OUTPUT_PACKAGE})
set_target_properties(msipackage PROPERTIES EXCLUDE_FROM_ALL FALSE SOURCES ${PACKAGE_WIX_IN})
add_dependencies(msipackage wsl wslg wslservice wslhost wslrelay wslserviceproxystub init initramfs wslinstall msixgluepackage)
add_dependencies(msipackage wsl wslg wslservice wslhost wslrelay wslserviceproxystub init initramfs wslinstall msixgluepackage wslaservice wslaserviceproxystub)
if (WSL_BUILD_WSL_SETTINGS)
add_dependencies(msipackage wslsettings libwsl)

View File

@ -50,75 +50,75 @@
<Component Id="explorerplan9shortcut" Guid="{93CBFF23-A04C-4344-A332-238CE5B97AED}" UninstallWhenSuperseded="yes" DisableRegistryReflection="yes" Bitness="always64">
<!-- Explorer extensions -->
<RegistryKey Root="HKLM" Key="SOFTWARE\Classes\CLSID\{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}">
<RegistryValue Value="Linux" Type="string"/>
<RegistryValue Name="SortOrderIndex" Value="119" Type="integer"/>
<!--0x77-->
<RegistryValue Name="System.IsPinnedToNameSpaceTree" Value="1" Type="integer"/>
<RegistryKey Root="HKLM" Key="SOFTWARE\Classes\CLSID\{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}">
<RegistryValue Value="Linux" Type="string"/>
<RegistryValue Name="SortOrderIndex" Value="119" Type="integer"/>
<!--0x77-->
<RegistryValue Name="System.IsPinnedToNameSpaceTree" Value="1" Type="integer"/>
<RegistryKey Key="DefaultIcon">
<RegistryValue Value="[System64Folder]wsl.exe,-1" Type="string"/>
</RegistryKey>
<RegistryKey Key="DefaultIcon">
<RegistryValue Value="[System64Folder]wsl.exe,-1" Type="string"/>
</RegistryKey>
<RegistryKey Key="InProcServer32">
<RegistryValue Value="[System64Folder]windows.storage.dll" Type="string"/>
</RegistryKey>
<RegistryKey Key="InProcServer32">
<RegistryValue Value="[System64Folder]windows.storage.dll" Type="string"/>
</RegistryKey>
<RegistryKey Key="ShellFolder">
<RegistryValue Name="Attributes" Value="2692743245" Type="integer"/>
<!--0xa080004d"-->
<RegistryValue Name="FolderValueFlags" Value="40" Type="integer"/>
<!--0x28-->
</RegistryKey>
<RegistryKey Key="ShellFolder">
<RegistryValue Name="Attributes" Value="2692743245" Type="integer"/>
<!--0xa080004d"-->
<RegistryValue Name="FolderValueFlags" Value="40" Type="integer"/>
<!--0x28-->
</RegistryKey>
<RegistryKey Key="Instance">
<RegistryValue Name="CLSID" Value="{4FE04BFD-85B9-49DD-B914-F4C9556B9DA6}" Type="string"/>
<RegistryKey Key="Instance">
<RegistryValue Name="CLSID" Value="{4FE04BFD-85B9-49DD-B914-F4C9556B9DA6}" Type="string"/>
<RegistryKey Key="InitPropertyBag">
<RegistryValue Name="DisplayType" Value="2" Type="integer"/>
<RegistryValue Name="EnumObjectsTelemetryValue" Value="WSL" Type="string"/>
<RegistryValue Name="Provider" Value="Plan 9 Network Provider" Type="string"/>
<RegistryValue Name="ResName" Value="\\wsl.localhost" Type="string"/>
<RegistryKey Key="InitPropertyBag">
<RegistryValue Name="DisplayType" Value="2" Type="integer"/>
<RegistryValue Name="EnumObjectsTelemetryValue" Value="WSL" Type="string"/>
<RegistryValue Name="Provider" Value="Plan 9 Network Provider" Type="string"/>
<RegistryValue Name="ResName" Value="\\wsl.localhost" Type="string"/>
</RegistryKey>
</RegistryKey>
</RegistryKey>
</RegistryKey>
<RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\HideDesktopIcons\NewStartPanel">
<RegistryValue Name="{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}" Value="1" Type="integer"/>
</RegistryKey>
<RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\HideDesktopIcons\NewStartPanel">
<RegistryValue Name="{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}" Value="1" Type="integer"/>
</RegistryKey>
<RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\Desktop\NameSpace\{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}">
<RegistryValue Value="Linux" Type="string"/>
</RegistryKey>
<RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\Desktop\NameSpace\{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}">
<RegistryValue Value="Linux" Type="string"/>
</RegistryKey>
<RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\IdListAliasTranslations\WSL">
<RegistryValue Name="Target" Value="::{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}" Type="string"/>
<RegistryValue Name="Source" Value="\\wsl.localhost" Type="string"/>
</RegistryKey>
<RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\IdListAliasTranslations\WSL">
<RegistryValue Name="Target" Value="::{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}" Type="string"/>
<RegistryValue Name="Source" Value="\\wsl.localhost" Type="string"/>
</RegistryKey>
<RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\IdListAliasTranslations\WSLLegacy">
<RegistryValue Name="Target" Value="::{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}" Type="string"/>
<RegistryValue Name="Source" Value="\\wsl$" Type="string"/>
</RegistryKey>
<RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\IdListAliasTranslations\WSLLegacy">
<RegistryValue Name="Target" Value="::{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}" Type="string"/>
<RegistryValue Name="Source" Value="\\wsl$" Type="string"/>
</RegistryKey>
</Component>
<Component Id="explorershell" Guid="{93CBFF23-A04C-4344-A332-238CE5B97AEC}" UninstallWhenSuperseded="yes" DisableRegistryReflection="yes" Bitness="always64">
<?foreach PATH in SOFTWARE\Classes\Directory\shell\WSL;SOFTWARE\Classes\Directory\Background\shell\WSL;SOFTWARE\Classes\Drive\shell\WSL?>
<RegistryKey Root="HKLM" Key="$(var.PATH)">
<RegistryValue Value="@wsl.exe,-2" Type="string"/>
<RegistryValue Name="Extended" Value="" Type="string"/>
<RegistryValue Name="NoWorkingDirectory" Value="" Type="string"/>
<RegistryKey Key="command">
<RegistryValue Value='wsl.exe --cd "%V"' Type="string"/>
<RegistryKey Root="HKLM" Key="$(var.PATH)">
<RegistryValue Value="@wsl.exe,-2" Type="string"/>
<RegistryValue Name="Extended" Value="" Type="string"/>
<RegistryValue Name="NoWorkingDirectory" Value="" Type="string"/>
<RegistryKey Key="command">
<RegistryValue Value='wsl.exe --cd "%V"' Type="string"/>
</RegistryKey>
</RegistryKey>
</RegistryKey>
<?endforeach?>
<?endforeach?>
<ProgId Id="WSLDistributionTar" Description="WSL tar distribution" Icon="wsl.exe">
<Extension Id="wsl">
<Verb Id="open" Command="open" TargetFile="wsl.exe" Argument="--install --prompt-before-exit --from-file &quot;%1&quot;" />
</Extension>
</ProgId>
<ProgId Id="WSLDistributionTar" Description="WSL tar distribution" Icon="wsl.exe">
<Extension Id="wsl">
<Verb Id="open" Command="open" TargetFile="wsl.exe" Argument="--install --prompt-before-exit --from-file &quot;%1&quot;" />
</Extension>
</ProgId>
</Component>
<Component Id="wslservice" Guid="F0C8D6BA-1502-41E7-BF72-D93DFA134735" UninstallWhenSuperseded="yes" DisableRegistryReflection="yes" Bitness="always64">
@ -131,33 +131,13 @@
</RegistryKey>
</RegistryKey>
<RegistryKey Root="HKCR" Key="Interface\{82A7ABC8-6B50-43FC-AB96-15FBBE7E8760}">
<RegistryValue Value="IWSLAUserSession" Type="string" />
<RegistryKey Key="ProxyStubClsid32">
<RegistryValue Value="{4EA0C6DD-E9FF-48E7-994E-13A31D10DC60}" Type="string" />
</RegistryKey>
</RegistryKey>
<RegistryKey Root="HKCR" Key="Interface\{82A7ABC8-6B50-43FC-AB96-15FBBE7E8761}">
<RegistryValue Value="IWSLAVirtualMachine" Type="string" />
<RegistryKey Key="ProxyStubClsid32">
<RegistryValue Value="{4EA0C6DD-E9FF-48E7-994E-13A31D10DC60}" Type="string" />
</RegistryKey>
</RegistryKey>
<RegistryKey Root="HKCR" Key="Interface\{7BC4E198-6531-4FA6-ADE2-5EF3D2A04DFE}">
<RegistryValue Value="ITerminationCallback" Type="string" />
<RegistryKey Key="ProxyStubClsid32">
<RegistryValue Value="{4EA0C6DD-E9FF-48E7-994E-13A31D10DC60}" Type="string" />
</RegistryKey>
</RegistryKey>
<!-- WSLServiceProxyStub. -->
<RegistryKey Root="HKCR" Key="CLSID\{4EA0C6DD-E9FF-48E7-994E-13A31D10DC60}">
<RegistryValue Value="PSFactoryBuffer" Type="string" />
</RegistryKey>
<RegistryKey Root="HKCR" Key="CLSID\{4EA0C6DD-E9FF-48E7-994E-13A31D10DC60}\InProcServer32">
<RegistryValue Value="[INSTALLDIR]wslserviceproxystub.dll" Type="string" />
<RegistryValue Name="ThreadingModel" Value="Both" Type="string" />
<RegistryKey Key="InProcServer32">
<RegistryValue Value="[INSTALLDIR]wslserviceproxystub.dll" Type="string" />
<RegistryValue Name="ThreadingModel" Value="Both" Type="string" />
</RegistryKey>
</RegistryKey>
<!-- ILxssUserSession -->
@ -175,18 +155,6 @@
<RegistryValue Value="LxssUserSession" Type="string" />
</RegistryKey>
<!-- WSLAUserSession -->
<RegistryKey Root="HKCR" Key="CLSID\{a9b7a1b9-0671-405c-95f1-e0612cb4ce8f}">
<RegistryValue Name="AppId" Value="{370121D2-AA7E-4608-A86D-0BBAB9DA1A60}" Type="string" />
<RegistryValue Value="WSLAUserSession" Type="string" />
</RegistryKey>
<!-- WSLAVirtualMachine -->
<RegistryKey Root="HKCR" Key="CLSID\{0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30}">
<RegistryValue Name="AppId" Value="{370121D2-AA7E-4608-A86D-0BBAB9DA1A60}" Type="string" />
<RegistryValue Value="WSLAVirtualMachine" Type="string" />
</RegistryKey>
<!-- Notification server -->
<RegistryKey Root="HKCR" Key="CLSID\{2B9C59C3-98F1-45C8-B87B-12AE3C7927E8}\LocalServer32">
<RegistryValue Value='"[INSTALLDIR]wslhost.exe"' Type="string" />
@ -226,7 +194,7 @@
<RegistryKey Root="HKCR" Key="AppID\{7F82AD86-755B-4870-86B1-D2E68DFE8A49}">
<RegistryValue Name="DllSurrogate" Value="" Type="string"/>
<RegistryValue Name="AppIDFlags" Value="2048" Type="integer"/><!--0x800-->
<RegistryValue Name="AppIDFlags" Value="2048" Type="integer"/> <!--0x800-->
<!-- O:BAG:BAD:(A;;CCDCSW;;;AU)(A;;CCDCSW;;;PS)(A;;CCDCSW;;;SY) -->
<RegistryValue Name="AccessPermission" Value="01000480580000006800000000000000140000000200440003000000000014000B00000001010000000000050B000000000014000B00000001010000000000050A000000000014000B0000000101000000000005120000000102000000000005200000002002000001020000000000052000000020020000" Type="binary" />
@ -265,6 +233,68 @@
<ServiceControl Id="StopService" Stop="both" Remove="uninstall" Name="WSLService" Wait="yes" />
</Component>
<Component Id="wslaservice" Guid="DC97520E-BFA5-4559-960F-D580E151629F" UninstallWhenSuperseded="yes" DisableRegistryReflection="yes" Bitness="always64">
<!-- WSLAServiceProxyStub. -->
<RegistryKey Root="HKCR" Key="CLSID\{4EA0C6DD-E9FF-48E7-994E-13A31D10DC61}">
<RegistryValue Value="PSFactoryBuffer" Type="string" />
<RegistryKey Key="InProcServer32">
<RegistryValue Value="[INSTALLDIR]wslaserviceproxystub.dll" Type="string" />
<RegistryValue Name="ThreadingModel" Value="Both" Type="string" />
</RegistryKey>
</RegistryKey>
<!-- WSLAService COM app -->
<RegistryKey Root="HKCR" Key="AppID\{E9B79997-57E3-4201-AECC-6A464E530DD2}">
<!-- O:BAG:BAD:(A;;CCDCSW;;;AU)(A;;CCDCSW;;;PS)(A;;CCDCSW;;;SY) -->
<RegistryValue Name="AccessPermission" Value="01000480580000006800000000000000140000000200440003000000000014000B00000001010000000000050B000000000014000B00000001010000000000050A000000000014000B0000000101000000000005120000000102000000000005200000002002000001020000000000052000000020020000" Type="binary" />
<RegistryValue Name="LaunchPermission" Value="01000480580000006800000000000000140000000200440003000000000014000B00000001010000000000050B000000000014000B00000001010000000000050A000000000014000B0000000101000000000005120000000102000000000005200000002002000001020000000000052000000020020000" Type="binary" />
<RegistryValue Name="LocalService" Value="WSLAService" Type="string" />
</RegistryKey>
<!-- WSLAUserSession -->
<RegistryKey Root="HKCR" Key="CLSID\{a9b7a1b9-0671-405c-95f1-e0612cb4ce8f}">
<RegistryValue Name="AppId" Value="{E9B79997-57E3-4201-AECC-6A464E530DD2}" Type="string" />
<RegistryValue Value="WSLAUserSession" Type="string" />
</RegistryKey>
<!-- WSLAVirtualMachine -->
<RegistryKey Root="HKCR" Key="CLSID\{0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30}">
<RegistryValue Name="AppId" Value="{E9B79997-57E3-4201-AECC-6A464E530DD2}" Type="string" />
<RegistryValue Value="WSLAVirtualMachine" Type="string" />
</RegistryKey>
<!-- IWSLAUserSession-->
<RegistryKey Root="HKCR" Key="Interface\{82A7ABC8-6B50-43FC-AB96-15FBBE7E8760}">
<RegistryValue Value="IWSLAUserSession" Type="string" />
<RegistryKey Key="ProxyStubClsid32">
<RegistryValue Value="{4EA0C6DD-E9FF-48E7-994E-13A31D10DC61}" Type="string" />
</RegistryKey>
</RegistryKey>
<!-- IWSLAVirtualMachine-->
<RegistryKey Root="HKCR" Key="Interface\{82A7ABC8-6B50-43FC-AB96-15FBBE7E8761}">
<RegistryValue Value="IWSLAVirtualMachine" Type="string" />
<RegistryKey Key="ProxyStubClsid32">
<RegistryValue Value="{4EA0C6DD-E9FF-48E7-994E-13A31D10DC61}" Type="string" />
</RegistryKey>
</RegistryKey>
<!-- ITerminationCallback-->
<RegistryKey Root="HKCR" Key="Interface\{7BC4E198-6531-4FA6-ADE2-5EF3D2A04DFE}">
<RegistryValue Value="ITerminationCallback" Type="string" />
<RegistryKey Key="ProxyStubClsid32">
<RegistryValue Value="{4EA0C6DD-E9FF-48E7-994E-13A31D10DC61}" Type="string" />
</RegistryKey>
</RegistryKey>
<File Id="wslaservice.exe" Source="${BIN}/wslaservice.exe" KeyPath="yes" />
<File Id="wslaserviceproxystub.dll" Name="wslaserviceproxystub.dll" Source="${BIN}/wslaserviceproxystub.dll" />
<ServiceInstall Name="WSLAService" DisplayName="WSLA Service" Description="WSLA Service" Start="auto" Type="ownProcess" ErrorControl="normal" Account="LocalSystem" Vital="yes" Interactive="no" />
<ServiceControl Id="StopWSLAService" Stop="both" Remove="uninstall" Name="WSLAService" Wait="yes" />
</Component>
<Component Id="wslg" Guid="F0C8D6BA-1502-41E7-BF72-D93DFA134731" UninstallWhenSuperseded="yes" DisableRegistryReflection="yes" Bitness="always64">
<?if "${WSL_DEV_BINARY_PATH}" = "" ?>
<File Id="msrdc.exe" Source="${MSRDC_SOURCE_DIR}/${TARGET_PLATFORM}/msrdc.exe" />
@ -402,6 +432,7 @@
<Feature Id="WSL" Title="Windows Subsystem for Linux" Level="1">
<ComponentRef Id="wsl" />
<ComponentRef Id="wslservice" />
<ComponentRef Id="wslaservice" />
<ComponentRef Id="wslg" />
<ComponentRef Id="tools" />
<ComponentRef Id="explorershell" />
@ -501,21 +532,21 @@
Return="check"
Execute="deferred"
/>
<CustomAction Id="RemoveRegistryKeyProtections"
Impersonate="no"
BinaryRef="wslinstall.dll"
DllEntry="RemoveRegistryKeyProtections"
Return="check"
Execute="deferred"
<CustomAction Id="RemoveRegistryKeyProtections"
Impersonate="no"
BinaryRef="wslinstall.dll"
DllEntry="RemoveRegistryKeyProtections"
Return="check"
Execute="deferred"
/>
<CustomAction Id="UnregisterLspCategories"
Impersonate="no"
BinaryRef="wslinstall.dll"
DllEntry="UnregisterLspCategories"
Return="check"
Execute="deferred"
<CustomAction Id="UnregisterLspCategories"
Impersonate="no"
BinaryRef="wslinstall.dll"
DllEntry="UnregisterLspCategories"
Return="check"
Execute="deferred"
/>
<CustomAction Id="InstallMsix.SetProperty" Return="check" Property="InstallMsix" Value='[DATABASE]' Execute='immediate' />

View File

@ -1,83 +1,105 @@
set(SOURCES
ConsoleProgressBar.cpp
ConsoleProgressIndicator.cpp
DeviceHostProxy.cpp
DeviceHostProxy.h
disk.cpp
Distribution.cpp
Dmesg.cpp
Dmesg.h
DnsResolver.cpp
DnsTunnelingChannel.cpp
ExecutionContext.cpp
filesystem.cpp
GnsChannel.cpp
HandleConsoleProgressBar.cpp
hcs.cpp
helpers.cpp
interop.cpp
ExecutionContext.cpp
socket.cpp
hvsocket.cpp
interop.cpp
Localization.cpp
lxssbusclient.cpp
lxssclient.cpp
LxssMessagePort.cpp
LxssSecurity.cpp
LxssServerPort.cpp
NatNetworking.cpp
notifications.cpp
Redirector.cpp
registry.cpp
relay.cpp
RingBuffer.cpp
socket.cpp
string.cpp
SubProcess.cpp
svccomm.cpp
svccommio.cpp
WslClient.cpp
WslCoreConfig.cpp
WslCoreFilesystem.cpp
WslCoreFirewallSupport.cpp
WslCoreHostDnsInfo.cpp
WslCoreNetworkEndpointSettings.cpp
WslCoreNetworkingSupport.cpp
WslInstall.cpp
WslSecurity.cpp
WslTelemetry.cpp
wslutil.cpp
notifications.cpp)
)
set(HEADERS
../../../generated/Localization.h
../../shared/inc/CommandLine.h
../../shared/inc/defs.h
../../shared/inc/lxfsshares.h
../../shared/inc/lxinitshared.h
../../shared/inc/SocketChannel.h
../../shared/inc/socketshared.h
../../shared/inc/hns_schema.h
../../shared/inc/JsonUtils.h
../../shared/inc/stringshared.h
../../shared/inc/retryshared.h
../../shared/inc/message.h
../../shared/inc/prettyprintshared.h
../inc/WslPluginApi.h
../inc/wslpolicies.h
../inc/lxssbusclient.h
../inc/lxssclient.h
../inc/LxssDynamicFunction.h
../inc/traceloggingconfig.h
../inc/wdk.h
../inc/wsl.h
../inc/wslconfig.h
../inc/wsl.h
../inc/wslhost.h
../inc/WslPluginApi.h
../inc/wslpolicies.h
../inc/wslrelay.h
../../shared/inc/CommandLine.h
../../shared/inc/defs.h
../../shared/inc/hns_schema.h
../../shared/inc/JsonUtils.h
../../shared/inc/lxfsshares.h
../../shared/inc/lxinitshared.h
../../shared/inc/message.h
../../shared/inc/prettyprintshared.h
../../shared/inc/retryshared.h
../../shared/inc/SocketChannel.h
../../shared/inc/socketshared.h
../../shared/inc/stringshared.h
ConsoleProgressBar.h
ConsoleProgressIndicator.h
disk.hpp
Distribution.h
DnsResolver.h
DnsTunnelingChannel.h
ExecutionContext.h
filesystem.hpp
GnsChannel.h
HandleConsoleProgressBar.h
hcs.hpp
hcs_schema.h
helpers.hpp
HandleConsoleProgressBar.h
interop.hpp
ExecutionContext.h
socket.hpp
hvsocket.hpp
INetworkingEngine.h
interop.hpp
LxssMessagePort.h
LxssPort.h
LxssSecurity.h
LxssServerPort.h
NatNetworking.h
notifications.h
precomp.h
Redirector.h
registry.hpp
relay.hpp
RingBuffer.h
socket.hpp
string.hpp
Stringify.h
SubProcess.h
@ -85,13 +107,17 @@ set(HEADERS
svccommio.hpp
WslClient.h
WslCoreConfig.h
WslCoreFilesystem.h
WslCoreFirewallSupport.h
WslCoreHostDnsInfo.h
WslCoreMessageQueue.h
WslCoreNetworkEndpointSettings.h
WslCoreNetworkingSupport.h
WslInstall.h
WslSecurity.h
WslTelemetry.h
wslutil.h
notifications.h)
wslutil.cpp
)
add_library(common STATIC ${SOURCES} ${HEADERS})
add_dependencies(common wslserviceidl localization wslservicemc wslinstalleridl)

View File

@ -1,261 +1,261 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
#include "precomp.h"
#include "DeviceHostProxy.h"
// This template works around a limitation with decltype on overloaded functions. It will be able
// to get the correct version of GetVmWorkerProcess based on the provided type arguments. By
// doing it this way, a compiler error will be generated if someone changes the signature of
// GetVmWorkerProcess.
//
// The way this works: decltype(GetVmWorkerProcess) does not work because it's overloaded.
// decltype(GetVmWorkerProcess(arg1, ...)) works to select an overload if you have values of the
// correct type (std::declval<T>() generates a value of the specified type), however the result
// of that is the function's return type, not the function's type, so the argument types must
// be repeated to reconstruct the function type.
template <typename... Args>
using GetVmWorkerProcessType = decltype(GetVmWorkerProcess(std::declval<Args>()...))(Args...);
// Limit the number of allowed doorbells registered by an external HDV vdev. Currently virtio-9p only uses
// one doorbell and wsldevicehost uses only two.
#define DEVICE_HOST_PROXY_DOORBELL_LIMIT 8
using namespace wsl::windows::common::hcs;
DeviceHostProxy::DeviceHostProxy(const std::wstring& VmId, const GUID& RuntimeId) :
m_systemId{VmId}, m_runtimeId{RuntimeId}, m_system{wsl::windows::common::hcs::OpenComputeSystem(VmId.c_str(), GENERIC_ALL)}, m_shutdown{false}
{
m_devicesShutdown = false;
}
GUID DeviceHostProxy::AddNewDevice(const GUID& Type, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs, const std::wstring& VirtIoTag)
{
const wrl::ComPtr<IUnknown> thisUnknown{CastToUnknown()};
GUID instanceId{};
THROW_IF_FAILED(UuidCreate(&instanceId));
// Tell the device host to create the device.
THROW_IF_FAILED(Plan9Fs->CreateVirtioDevice(m_systemId.c_str(), thisUnknown.Get(), VirtIoTag.c_str(), &instanceId));
// Add the instance ID to the list of known devices. This must be done before the device is
// added to the system, because doing that can cause the register doorbell function to be
// called.
// N.B. It will be removed if there is a failure.
{
auto lock = m_devicesLock.lock_exclusive();
THROW_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
m_devices.emplace(instanceId, DeviceHostProxyEntry{});
}
auto removeOnFailure = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() {
auto lock = m_devicesLock.lock_exclusive();
m_devices.erase(instanceId);
});
// Add the device to the compute system on behalf of the device host.
ModifySettingRequest<FlexibleIoDevice> request;
request.RequestType = ModifyRequestType::Add;
request.ResourcePath = L"VirtualMachine/Devices/FlexibleIov/";
request.ResourcePath += wsl::shared::string::GuidToString<wchar_t>(instanceId, wsl::shared::string::GuidToStringFlags::None);
request.Settings.EmulatorId = Type;
request.Settings.HostingModel = FlexibleIoDeviceHostingModel::ExternalRestricted;
wsl::windows::common::hcs::ModifyComputeSystem(m_system.get(), wsl::shared::ToJsonW(request).c_str());
removeOnFailure.release();
return instanceId;
}
void DeviceHostProxy::AddRemoteFileSystem(const GUID& ImplementationClsid, const std::wstring& Tag, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs)
{
auto lock = m_lock.lock_exclusive();
THROW_HR_IF(E_CHANGED_STATE, m_shutdown);
// Make sure there are no duplicate tags.
for (auto& entry : m_fileSystems)
{
THROW_HR_IF(E_INVALIDARG, entry.ImplementationClsid == ImplementationClsid && entry.Tag == Tag);
}
m_fileSystems.emplace_back(ImplementationClsid, Tag, Plan9Fs);
}
wil::com_ptr<IPlan9FileSystem> DeviceHostProxy::GetRemoteFileSystem(const GUID& ImplementationClsid, std::wstring_view Tag)
{
auto lock = m_lock.lock_shared();
THROW_HR_IF(E_CHANGED_STATE, m_shutdown);
for (auto& entry : m_fileSystems)
{
if (entry.ImplementationClsid == ImplementationClsid && entry.Tag == Tag)
{
return entry.Instance;
}
}
return {};
}
void DeviceHostProxy::Shutdown()
{
{
auto lock = m_lock.lock_exclusive();
m_fileSystems.clear();
m_shutdown = true;
}
{
auto lock = m_devicesLock.lock_exclusive();
m_devices.clear();
m_devicesShutdown = true;
}
}
HRESULT
DeviceHostProxy::RegisterDeviceHost(_In_ IVmDeviceHost* DeviceHost, _In_ DWORD ProcessId, _Out_ UINT64* IpcSectionHandle)
try
{
//
// Because HdvProxyDeviceHost is not part of the API set, it is loaded here dynamically.
//
static LxssDynamicFunction<decltype(HdvProxyDeviceHost)> proxyDeviceHost{c_hdvModuleName, "HdvProxyDeviceHost"};
const wil::com_ptr<IVmDeviceHost> remoteHost = DeviceHost;
const wil::com_ptr<IUnknown> unknown = remoteHost.query<IUnknown>();
THROW_IF_FAILED(proxyDeviceHost(m_system.get(), unknown.get(), ProcessId, IpcSectionHandle));
return S_OK;
}
CATCH_RETURN()
HRESULT
DeviceHostProxy::NotifyAllDevicesInUse(_In_ LPCWSTR Tag)
try
{
//
// Add another Plan9 virtio device to the guest so additional mount commands will be possible.
// This callback should be unused by virtiofs devices because a device is created for every
// AddSharePath call.
//
auto p9fs = GetRemoteFileSystem(__uuidof(p9fs::Plan9FileSystem), Tag);
THROW_HR_IF(E_NOT_SET, !p9fs);
(void)AddNewDevice(VIRTIO_PLAN9_DEVICE_ID, p9fs, Tag);
return S_OK;
}
CATCH_RETURN()
HRESULT
DeviceHostProxy::RegisterDoorbell(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags, HANDLE Event)
try
{
auto lock = m_devicesLock.lock_exclusive();
RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
// Check if the device is one of the known devices that doorbells can be registered for, and
// if the device has not already registered a doorbell.
// N.B. For security it is enforced that each device can only register a small number of doorbells.
// Currently virtio-9p only uses one and the external virtio device uses two.
const auto knownDevice = m_devices.find(InstanceId);
RETURN_HR_IF(E_ACCESSDENIED, knownDevice == m_devices.end() || knownDevice->second.DoorbellCount == DEVICE_HOST_PROXY_DOORBELL_LIMIT);
if (!knownDevice->second.MemoryNotification)
{
// Get an interface to the worker process to query devices.
if (!m_deviceAccess)
{
static LxssDynamicFunction<GetVmWorkerProcessType<REFGUID, REFIID, IUnknown**>> getVmWorker{
c_vmwpctrlModuleName, "GetVmWorkerProcess"};
RETURN_IF_FAILED(getVmWorker(m_runtimeId, __uuidof(*m_deviceAccess), reinterpret_cast<IUnknown**>(&m_deviceAccess)));
}
RETURN_HR_IF(E_NOINTERFACE, !m_deviceAccess);
// Retrieve the device's memory notification interface to register the doorbell, and store it
// to be used during unregistration.
wil::com_ptr<IUnknown> device;
RETURN_IF_FAILED(m_deviceAccess->GetDevice(FLEXIO_DEVICE_ID, InstanceId, &device));
knownDevice->second.MemoryNotification = device.query<IVmFiovGuestMemoryFastNotification>();
}
const auto result = knownDevice->second.MemoryNotification->RegisterDoorbell(
static_cast<FIOV_BAR_SELECTOR>(BarIndex), Offset, TriggerValue, Flags, Event);
if (SUCCEEDED(result))
{
++knownDevice->second.DoorbellCount;
}
return result;
}
CATCH_RETURN()
HRESULT
DeviceHostProxy::UnregisterDoorbell(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags)
try
{
auto lock = m_devicesLock.lock_exclusive();
RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
// Check if the device is a known device and has registered a doorbell.
// N.B. If the device is being removed, the device can't be retrieved from the worker process
// so it's necessary to use the stored COM pointer.
const auto device = m_devices.find(InstanceId);
RETURN_HR_IF(E_ACCESSDENIED, device == m_devices.end() || device->second.DoorbellCount == 0);
RETURN_IF_FAILED(device->second.MemoryNotification->UnregisterDoorbell(static_cast<FIOV_BAR_SELECTOR>(BarIndex), Offset, TriggerValue, Flags));
if (--device->second.DoorbellCount == 0)
{
device->second.MemoryNotification.reset();
}
return S_OK;
}
CATCH_RETURN()
HRESULT
DeviceHostProxy::CreateSectionBackedMmioRange(
const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages, UINT64 PageCount, UINT64 MappingFlags, HANDLE SectionHandle, UINT64 SectionOffsetInPages)
try
{
auto lock = m_devicesLock.lock_exclusive();
RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
// Check if the device is one of the known devices.
const auto knownDevice = m_devices.find(InstanceId);
THROW_HR_IF(E_ACCESSDENIED, knownDevice == m_devices.end());
if (!knownDevice->second.MemoryMapping)
{
// Get an interface to the worker process to query devices.
if (!m_deviceAccess)
{
static LxssDynamicFunction<GetVmWorkerProcessType<REFGUID, REFIID, IUnknown**>> getVmWorker{
c_vmwpctrlModuleName, "GetVmWorkerProcess"};
THROW_IF_FAILED(getVmWorker(m_runtimeId, __uuidof(*m_deviceAccess), reinterpret_cast<IUnknown**>(&m_deviceAccess)));
}
THROW_HR_IF(E_NOINTERFACE, !m_deviceAccess);
// Retrieve the device specific interface to manage mapped sections.
wil::com_ptr<IUnknown> device;
THROW_IF_FAILED(m_deviceAccess->GetDevice(FLEXIO_DEVICE_ID, InstanceId, &device));
knownDevice->second.MemoryMapping = device.query<IVmFiovGuestMmioMappings>();
}
THROW_IF_FAILED(knownDevice->second.MemoryMapping->CreateSectionBackedMmioRange(
static_cast<FIOV_BAR_SELECTOR>(BarIndex), BarOffsetInPages, PageCount, static_cast<FiovMmioMappingFlags>(MappingFlags), SectionHandle, SectionOffsetInPages));
return S_OK;
}
CATCH_RETURN()
HRESULT
DeviceHostProxy::DestroySectionBackedMmioRange(const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages)
try
{
auto lock = m_devicesLock.lock_exclusive();
RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
const auto device = m_devices.find(InstanceId);
RETURN_HR_IF(E_ACCESSDENIED, device == m_devices.end() || !device->second.MemoryMapping);
RETURN_IF_FAILED(device->second.MemoryMapping->DestroySectionBackedMmioRange(static_cast<FIOV_BAR_SELECTOR>(BarIndex), BarOffsetInPages));
return S_OK;
}
// Copyright (C) Microsoft Corporation. All rights reserved.
#include "precomp.h"
#include "DeviceHostProxy.h"
// This template works around a limitation with decltype on overloaded functions. It will be able
// to get the correct version of GetVmWorkerProcess based on the provided type arguments. By
// doing it this way, a compiler error will be generated if someone changes the signature of
// GetVmWorkerProcess.
//
// The way this works: decltype(GetVmWorkerProcess) does not work because it's overloaded.
// decltype(GetVmWorkerProcess(arg1, ...)) works to select an overload if you have values of the
// correct type (std::declval<T>() generates a value of the specified type), however the result
// of that is the function's return type, not the function's type, so the argument types must
// be repeated to reconstruct the function type.
template <typename... Args>
using GetVmWorkerProcessType = decltype(GetVmWorkerProcess(std::declval<Args>()...))(Args...);
// Limit the number of allowed doorbells registered by an external HDV vdev. Currently virtio-9p only uses
// one doorbell and wsldevicehost uses only two.
#define DEVICE_HOST_PROXY_DOORBELL_LIMIT 8
using namespace wsl::windows::common::hcs;
DeviceHostProxy::DeviceHostProxy(const std::wstring& VmId, const GUID& RuntimeId) :
m_systemId{VmId}, m_runtimeId{RuntimeId}, m_system{wsl::windows::common::hcs::OpenComputeSystem(VmId.c_str(), GENERIC_ALL)}, m_shutdown{false}
{
m_devicesShutdown = false;
}
GUID DeviceHostProxy::AddNewDevice(const GUID& Type, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs, const std::wstring& VirtIoTag)
{
const wrl::ComPtr<IUnknown> thisUnknown{CastToUnknown()};
GUID instanceId{};
THROW_IF_FAILED(UuidCreate(&instanceId));
// Tell the device host to create the device.
THROW_IF_FAILED(Plan9Fs->CreateVirtioDevice(m_systemId.c_str(), thisUnknown.Get(), VirtIoTag.c_str(), &instanceId));
// Add the instance ID to the list of known devices. This must be done before the device is
// added to the system, because doing that can cause the register doorbell function to be
// called.
// N.B. It will be removed if there is a failure.
{
auto lock = m_devicesLock.lock_exclusive();
THROW_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
m_devices.emplace(instanceId, DeviceHostProxyEntry{});
}
auto removeOnFailure = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() {
auto lock = m_devicesLock.lock_exclusive();
m_devices.erase(instanceId);
});
// Add the device to the compute system on behalf of the device host.
ModifySettingRequest<FlexibleIoDevice> request;
request.RequestType = ModifyRequestType::Add;
request.ResourcePath = L"VirtualMachine/Devices/FlexibleIov/";
request.ResourcePath += wsl::shared::string::GuidToString<wchar_t>(instanceId, wsl::shared::string::GuidToStringFlags::None);
request.Settings.EmulatorId = Type;
request.Settings.HostingModel = FlexibleIoDeviceHostingModel::ExternalRestricted;
wsl::windows::common::hcs::ModifyComputeSystem(m_system.get(), wsl::shared::ToJsonW(request).c_str());
removeOnFailure.release();
return instanceId;
}
void DeviceHostProxy::AddRemoteFileSystem(const GUID& ImplementationClsid, const std::wstring& Tag, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs)
{
auto lock = m_lock.lock_exclusive();
THROW_HR_IF(E_CHANGED_STATE, m_shutdown);
// Make sure there are no duplicate tags.
for (auto& entry : m_fileSystems)
{
THROW_HR_IF(E_INVALIDARG, entry.ImplementationClsid == ImplementationClsid && entry.Tag == Tag);
}
m_fileSystems.emplace_back(ImplementationClsid, Tag, Plan9Fs);
}
wil::com_ptr<IPlan9FileSystem> DeviceHostProxy::GetRemoteFileSystem(const GUID& ImplementationClsid, std::wstring_view Tag)
{
auto lock = m_lock.lock_shared();
THROW_HR_IF(E_CHANGED_STATE, m_shutdown);
for (auto& entry : m_fileSystems)
{
if (entry.ImplementationClsid == ImplementationClsid && entry.Tag == Tag)
{
return entry.Instance;
}
}
return {};
}
void DeviceHostProxy::Shutdown()
{
{
auto lock = m_lock.lock_exclusive();
m_fileSystems.clear();
m_shutdown = true;
}
{
auto lock = m_devicesLock.lock_exclusive();
m_devices.clear();
m_devicesShutdown = true;
}
}
HRESULT
DeviceHostProxy::RegisterDeviceHost(_In_ IVmDeviceHost* DeviceHost, _In_ DWORD ProcessId, _Out_ UINT64* IpcSectionHandle)
try
{
//
// Because HdvProxyDeviceHost is not part of the API set, it is loaded here dynamically.
//
static LxssDynamicFunction<decltype(HdvProxyDeviceHost)> proxyDeviceHost{c_hdvModuleName, "HdvProxyDeviceHost"};
const wil::com_ptr<IVmDeviceHost> remoteHost = DeviceHost;
const wil::com_ptr<IUnknown> unknown = remoteHost.query<IUnknown>();
THROW_IF_FAILED(proxyDeviceHost(m_system.get(), unknown.get(), ProcessId, IpcSectionHandle));
return S_OK;
}
CATCH_RETURN()
HRESULT
DeviceHostProxy::NotifyAllDevicesInUse(_In_ LPCWSTR Tag)
try
{
//
// Add another Plan9 virtio device to the guest so additional mount commands will be possible.
// This callback should be unused by virtiofs devices because a device is created for every
// AddSharePath call.
//
auto p9fs = GetRemoteFileSystem(__uuidof(p9fs::Plan9FileSystem), Tag);
THROW_HR_IF(E_NOT_SET, !p9fs);
(void)AddNewDevice(VIRTIO_PLAN9_DEVICE_ID, p9fs, Tag);
return S_OK;
}
CATCH_RETURN()
HRESULT
DeviceHostProxy::RegisterDoorbell(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags, HANDLE Event)
try
{
auto lock = m_devicesLock.lock_exclusive();
RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
// Check if the device is one of the known devices that doorbells can be registered for, and
// if the device has not already registered a doorbell.
// N.B. For security it is enforced that each device can only register a small number of doorbells.
// Currently virtio-9p only uses one and the external virtio device uses two.
const auto knownDevice = m_devices.find(InstanceId);
RETURN_HR_IF(E_ACCESSDENIED, knownDevice == m_devices.end() || knownDevice->second.DoorbellCount == DEVICE_HOST_PROXY_DOORBELL_LIMIT);
if (!knownDevice->second.MemoryNotification)
{
// Get an interface to the worker process to query devices.
if (!m_deviceAccess)
{
static LxssDynamicFunction<GetVmWorkerProcessType<REFGUID, REFIID, IUnknown**>> getVmWorker{
c_vmwpctrlModuleName, "GetVmWorkerProcess"};
RETURN_IF_FAILED(getVmWorker(m_runtimeId, __uuidof(*m_deviceAccess), reinterpret_cast<IUnknown**>(&m_deviceAccess)));
}
RETURN_HR_IF(E_NOINTERFACE, !m_deviceAccess);
// Retrieve the device's memory notification interface to register the doorbell, and store it
// to be used during unregistration.
wil::com_ptr<IUnknown> device;
RETURN_IF_FAILED(m_deviceAccess->GetDevice(FLEXIO_DEVICE_ID, InstanceId, &device));
knownDevice->second.MemoryNotification = device.query<IVmFiovGuestMemoryFastNotification>();
}
const auto result = knownDevice->second.MemoryNotification->RegisterDoorbell(
static_cast<FIOV_BAR_SELECTOR>(BarIndex), Offset, TriggerValue, Flags, Event);
if (SUCCEEDED(result))
{
++knownDevice->second.DoorbellCount;
}
return result;
}
CATCH_RETURN()
HRESULT
DeviceHostProxy::UnregisterDoorbell(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags)
try
{
auto lock = m_devicesLock.lock_exclusive();
RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
// Check if the device is a known device and has registered a doorbell.
// N.B. If the device is being removed, the device can't be retrieved from the worker process
// so it's necessary to use the stored COM pointer.
const auto device = m_devices.find(InstanceId);
RETURN_HR_IF(E_ACCESSDENIED, device == m_devices.end() || device->second.DoorbellCount == 0);
RETURN_IF_FAILED(device->second.MemoryNotification->UnregisterDoorbell(static_cast<FIOV_BAR_SELECTOR>(BarIndex), Offset, TriggerValue, Flags));
if (--device->second.DoorbellCount == 0)
{
device->second.MemoryNotification.reset();
}
return S_OK;
}
CATCH_RETURN()
HRESULT
DeviceHostProxy::CreateSectionBackedMmioRange(
const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages, UINT64 PageCount, UINT64 MappingFlags, HANDLE SectionHandle, UINT64 SectionOffsetInPages)
try
{
auto lock = m_devicesLock.lock_exclusive();
RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
// Check if the device is one of the known devices.
const auto knownDevice = m_devices.find(InstanceId);
THROW_HR_IF(E_ACCESSDENIED, knownDevice == m_devices.end());
if (!knownDevice->second.MemoryMapping)
{
// Get an interface to the worker process to query devices.
if (!m_deviceAccess)
{
static LxssDynamicFunction<GetVmWorkerProcessType<REFGUID, REFIID, IUnknown**>> getVmWorker{
c_vmwpctrlModuleName, "GetVmWorkerProcess"};
THROW_IF_FAILED(getVmWorker(m_runtimeId, __uuidof(*m_deviceAccess), reinterpret_cast<IUnknown**>(&m_deviceAccess)));
}
THROW_HR_IF(E_NOINTERFACE, !m_deviceAccess);
// Retrieve the device specific interface to manage mapped sections.
wil::com_ptr<IUnknown> device;
THROW_IF_FAILED(m_deviceAccess->GetDevice(FLEXIO_DEVICE_ID, InstanceId, &device));
knownDevice->second.MemoryMapping = device.query<IVmFiovGuestMmioMappings>();
}
THROW_IF_FAILED(knownDevice->second.MemoryMapping->CreateSectionBackedMmioRange(
static_cast<FIOV_BAR_SELECTOR>(BarIndex), BarOffsetInPages, PageCount, static_cast<FiovMmioMappingFlags>(MappingFlags), SectionHandle, SectionOffsetInPages));
return S_OK;
}
CATCH_RETURN()
HRESULT
DeviceHostProxy::DestroySectionBackedMmioRange(const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages)
try
{
auto lock = m_devicesLock.lock_exclusive();
RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
const auto device = m_devices.find(InstanceId);
RETURN_HR_IF(E_ACCESSDENIED, device == m_devices.end() || !device->second.MemoryMapping);
RETURN_IF_FAILED(device->second.MemoryMapping->DestroySectionBackedMmioRange(static_cast<FIOV_BAR_SELECTOR>(BarIndex), BarOffsetInPages));
return S_OK;
}
CATCH_RETURN()

View File

@ -1,76 +1,76 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
#pragma once
#include <windowsdefs.h>
#include "hcs.hpp"
namespace wrl = Microsoft::WRL;
class DeviceHostProxy : public wrl::RuntimeClass<wrl::RuntimeClassFlags<wrl::RuntimeClassType::ClassicCom>, IVmDeviceHostSupport, IPlan9FileSystemHost>
{
public:
DeviceHostProxy(const std::wstring& VmId, const GUID& RuntimeId);
GUID AddNewDevice(const GUID& Type, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs, const std::wstring& VirtIoTag);
void AddRemoteFileSystem(const GUID& ImplementationClsid, const std::wstring& Tag, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs);
wil::com_ptr<IPlan9FileSystem> GetRemoteFileSystem(const GUID& ImplementationClsid, std::wstring_view Tag);
void Shutdown();
//
// IVmDeviceHostSupport
//
IFACEMETHOD(RegisterDeviceHost)(_In_ IVmDeviceHost* DeviceHost, _In_ DWORD ProcessId, _Out_ UINT64* IpcSectionHandle) override;
//
// IPlan9FileSystemHost
//
IFACEMETHOD(NotifyAllDevicesInUse)(_In_ LPCWSTR Tag) override;
IFACEMETHOD(RegisterDoorbell)(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags, HANDLE Event) override;
IFACEMETHOD(UnregisterDoorbell)(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags) override;
IFACEMETHOD(CreateSectionBackedMmioRange)(
const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages, UINT64 PageCount, UINT64 MappingFlags, HANDLE SectionHandle, UINT64 SectionOffsetInPages) override;
IFACEMETHOD(DestroySectionBackedMmioRange)(const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages) override;
private:
struct RemoteFileSystemInfo
{
RemoteFileSystemInfo(GUID ImplementationClsid, const std::wstring& Tag, const wil::com_ptr<IPlan9FileSystem>& Instance) :
ImplementationClsid{ImplementationClsid}, Tag{Tag}, Instance{Instance}
{
}
GUID ImplementationClsid;
std::wstring Tag;
wil::com_ptr<IPlan9FileSystem> Instance;
};
std::wstring m_systemId;
GUID m_runtimeId;
wsl::windows::common::hcs::unique_hcs_system m_system;
wil::srwlock m_lock;
std::vector<RemoteFileSystemInfo> m_fileSystems;
bool m_shutdown;
struct DeviceHostProxyEntry
{
wil::com_ptr<IVmFiovGuestMemoryFastNotification> MemoryNotification;
wil::com_ptr<IVmFiovGuestMmioMappings> MemoryMapping;
size_t DoorbellCount = 0;
};
wil::com_ptr<IVmVirtualDeviceAccess> m_deviceAccess;
wil::srwlock m_devicesLock;
std::map<GUID, DeviceHostProxyEntry, wsl::windows::common::helpers::GuidLess> m_devices;
bool m_devicesShutdown;
static constexpr LPCWSTR c_hdvModuleName = L"vmdevicehost.dll";
static constexpr LPCWSTR c_vmwpctrlModuleName = L"vmwpctrl.dll";
// Copyright (C) Microsoft Corporation. All rights reserved.
#pragma once
#include <windowsdefs.h>
#include "hcs.hpp"
namespace wrl = Microsoft::WRL;
class DeviceHostProxy : public wrl::RuntimeClass<wrl::RuntimeClassFlags<wrl::RuntimeClassType::ClassicCom>, IVmDeviceHostSupport, IPlan9FileSystemHost>
{
public:
DeviceHostProxy(const std::wstring& VmId, const GUID& RuntimeId);
GUID AddNewDevice(const GUID& Type, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs, const std::wstring& VirtIoTag);
void AddRemoteFileSystem(const GUID& ImplementationClsid, const std::wstring& Tag, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs);
wil::com_ptr<IPlan9FileSystem> GetRemoteFileSystem(const GUID& ImplementationClsid, std::wstring_view Tag);
void Shutdown();
//
// IVmDeviceHostSupport
//
IFACEMETHOD(RegisterDeviceHost)(_In_ IVmDeviceHost* DeviceHost, _In_ DWORD ProcessId, _Out_ UINT64* IpcSectionHandle) override;
//
// IPlan9FileSystemHost
//
IFACEMETHOD(NotifyAllDevicesInUse)(_In_ LPCWSTR Tag) override;
IFACEMETHOD(RegisterDoorbell)(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags, HANDLE Event) override;
IFACEMETHOD(UnregisterDoorbell)(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags) override;
IFACEMETHOD(CreateSectionBackedMmioRange)(
const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages, UINT64 PageCount, UINT64 MappingFlags, HANDLE SectionHandle, UINT64 SectionOffsetInPages) override;
IFACEMETHOD(DestroySectionBackedMmioRange)(const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages) override;
private:
struct RemoteFileSystemInfo
{
RemoteFileSystemInfo(GUID ImplementationClsid, const std::wstring& Tag, const wil::com_ptr<IPlan9FileSystem>& Instance) :
ImplementationClsid{ImplementationClsid}, Tag{Tag}, Instance{Instance}
{
}
GUID ImplementationClsid;
std::wstring Tag;
wil::com_ptr<IPlan9FileSystem> Instance;
};
std::wstring m_systemId;
GUID m_runtimeId;
wsl::windows::common::hcs::unique_hcs_system m_system;
wil::srwlock m_lock;
std::vector<RemoteFileSystemInfo> m_fileSystems;
bool m_shutdown;
struct DeviceHostProxyEntry
{
wil::com_ptr<IVmFiovGuestMemoryFastNotification> MemoryNotification;
wil::com_ptr<IVmFiovGuestMmioMappings> MemoryMapping;
size_t DoorbellCount = 0;
};
wil::com_ptr<IVmVirtualDeviceAccess> m_deviceAccess;
wil::srwlock m_devicesLock;
std::map<GUID, DeviceHostProxyEntry, wsl::windows::common::helpers::GuidLess> m_devices;
bool m_devicesShutdown;
static constexpr LPCWSTR c_hdvModuleName = L"vmdevicehost.dll";
static constexpr LPCWSTR c_vmwpctrlModuleName = L"vmwpctrl.dll";
};

View File

@ -1,417 +1,417 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
#include <LxssDynamicFunction.h>
#include "precomp.h"
#include "DnsResolver.h"
using wsl::core::networking::DnsResolver;
static constexpr auto c_dnsModuleName = L"dnsapi.dll";
std::optional<LxssDynamicFunction<decltype(DnsQueryRaw)>> DnsResolver::s_dnsQueryRaw;
std::optional<LxssDynamicFunction<decltype(DnsCancelQueryRaw)>> DnsResolver::s_dnsCancelQueryRaw;
std::optional<LxssDynamicFunction<decltype(DnsQueryRawResultFree)>> DnsResolver::s_dnsQueryRawResultFree;
HRESULT DnsResolver::LoadDnsResolverMethods() noexcept
{
static wil::shared_hmodule dnsModule;
static DWORD loadError = ERROR_SUCCESS;
static std::once_flag dnsLoadFlag;
// Load DNS dll only once
std::call_once(dnsLoadFlag, [&]() {
dnsModule.reset(LoadLibraryEx(c_dnsModuleName, nullptr, LOAD_LIBRARY_SEARCH_SYSTEM32));
if (!dnsModule)
{
loadError = GetLastError();
}
});
RETURN_IF_WIN32_ERROR_MSG(loadError, "LoadLibraryEx %ls", c_dnsModuleName);
// Initialize dynamic functions for the DNS tunneling Windows APIs.
// using the non-throwing instance of LxssDynamicFunction as to not end up in the Error telemetry
LxssDynamicFunction<decltype(DnsQueryRaw)> local_dnsQueryRaw{DynamicFunctionErrorLogs::None};
RETURN_IF_FAILED_EXPECTED(local_dnsQueryRaw.load(dnsModule, "DnsQueryRaw"));
LxssDynamicFunction<decltype(DnsCancelQueryRaw)> local_dnsCancelQueryRaw{DynamicFunctionErrorLogs::None};
RETURN_IF_FAILED_EXPECTED(local_dnsCancelQueryRaw.load(dnsModule, "DnsCancelQueryRaw"));
LxssDynamicFunction<decltype(DnsQueryRawResultFree)> local_dnsQueryRawResultFree{DynamicFunctionErrorLogs::None};
RETURN_IF_FAILED_EXPECTED(local_dnsQueryRawResultFree.load(dnsModule, "DnsQueryRawResultFree"));
// Make a dummy call to the DNS APIs to verify if they are working. The APIs are going to be present
// on older Windows versions, where they can be turned on/off. If turned off, the APIs
// will be unusable and will return ERROR_CALL_NOT_IMPLEMENTED.
if (local_dnsQueryRaw(nullptr, nullptr) == ERROR_CALL_NOT_IMPLEMENTED)
{
RETURN_IF_WIN32_ERROR_EXPECTED(ERROR_CALL_NOT_IMPLEMENTED);
}
s_dnsQueryRaw.emplace(std::move(local_dnsQueryRaw));
s_dnsCancelQueryRaw.emplace(std::move(local_dnsCancelQueryRaw));
s_dnsQueryRawResultFree.emplace(std::move(local_dnsQueryRawResultFree));
return S_OK;
}
DnsResolver::DnsResolver(wil::unique_socket&& dnsHvsocket, DnsResolverFlags flags) :
m_dnsChannel(
std::move(dnsHvsocket),
[this](const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) {
ProcessDnsRequest(dnsBuffer, dnsClientIdentifier);
}),
m_flags(flags)
{
// Initialize as signaled, as there are no requests yet
m_allRequestsFinished.SetEvent();
// Read external interface constraint regkey
const auto lxssKey = windows::common::registry::OpenLxssMachineKey(KEY_READ);
m_externalInterfaceConstraintName =
windows::common::registry::ReadString(lxssKey.get(), nullptr, c_interfaceConstraintKey, L"");
if (!m_externalInterfaceConstraintName.empty())
{
ResolveExternalInterfaceConstraintIndex();
WSL_LOG(
"DnsResolver::DnsResolver",
TraceLoggingValue(m_externalInterfaceConstraintName.c_str(), "m_externalInterfaceConstraintName"),
TraceLoggingValue(m_externalInterfaceConstraintIndex, "m_externalInterfaceConstraintIndex"));
// Register for interface change notifications. Notifications are used to determine if the external interface constraint setting is applicable.
THROW_IF_WIN32_ERROR(NotifyIpInterfaceChange(AF_UNSPEC, &DnsResolver::InterfaceChangeCallback, this, FALSE, &m_interfaceNotificationHandle));
}
}
DnsResolver::~DnsResolver() noexcept
{
Stop();
}
void DnsResolver::GenerateTelemetry() noexcept
try
{
// Find the 3 most common DNS API failures
uint32_t mostCommonDnsStatusError = 0;
uint32_t mostCommonDnsStatusErrorCount = 0;
uint32_t secondCommonDnsStatusError = 0;
uint32_t secondCommonDnsStatusErrorCount = 0;
uint32_t thirdCommonDnsStatusError = 0;
uint32_t thirdCommonDnsStatusErrorCount = 0;
std::vector<std::pair<uint32_t, uint32_t>> failures(m_dnsApiFailures.size());
std::copy(m_dnsApiFailures.begin(), m_dnsApiFailures.end(), failures.begin());
// Sort in descending order based on failure count
std::sort(failures.begin(), failures.end(), [](const auto& lhs, const auto& rhs) { return lhs.second > rhs.second; });
if (failures.size() >= 1)
{
mostCommonDnsStatusError = failures[0].first;
mostCommonDnsStatusErrorCount = failures[0].second;
}
if (failures.size() >= 2)
{
secondCommonDnsStatusError = failures[1].first;
secondCommonDnsStatusErrorCount = failures[1].second;
}
if (failures.size() >= 3)
{
thirdCommonDnsStatusError = failures[2].first;
thirdCommonDnsStatusErrorCount = failures[2].second;
}
// Add telemetry with DNS tunneling statistics, before shutting down
WSL_LOG(
"DnsTunnelingStatistics",
TraceLoggingValue(m_totalUdpQueries.load(), "totalUdpQueries"),
TraceLoggingValue(m_successfulUdpQueries.load(), "successfulUdpQueries"),
TraceLoggingValue(m_totalTcpQueries.load(), "totalTcpQueries"),
TraceLoggingValue(m_successfulTcpQueries.load(), "successfulTcpQueries"),
TraceLoggingValue(m_queriesWithNullResult.load(), "queriesWithNullResult"),
TraceLoggingValue(m_failedDnsQueryRawCalls.load(), "FailedDnsQueryRawCalls"),
TraceLoggingValue(m_dnsApiFailures.size(), "totalDnsStatusErrorInstances"),
TraceLoggingValue(mostCommonDnsStatusError, "mostCommonDnsStatusError"),
TraceLoggingValue(mostCommonDnsStatusErrorCount, "mostCommonDnsStatusErrorCount"),
TraceLoggingValue(secondCommonDnsStatusError, "secondCommonDnsStatusError"),
TraceLoggingValue(secondCommonDnsStatusErrorCount, "secondCommonDnsStatusErrorCount"),
TraceLoggingValue(thirdCommonDnsStatusError, "thirdCommonDnsStatusError"),
TraceLoggingValue(thirdCommonDnsStatusErrorCount, "thirdCommonDnsStatusErrorCount"));
}
CATCH_LOG()
void DnsResolver::Stop() noexcept
try
{
WSL_LOG("DnsResolver::Stop");
// Scoped m_dnsLock
{
const std::lock_guard lock(m_dnsLock);
m_stopped = true;
// Cancel existing requests. Cancel is complete when DnsQueryRawCallback is
// invoked with status == ERROR_CANCELLED
// N.B. Cancelling can end up calling the DnsQueryRawCallback directly on this same thread. i.e., while this
// lock is held. Which is fine because m_dnsLock is a recursive mutex.
// N.B. Cancelling a query will synchronously remove the query from m_dnsRequests, which invalidates iterators.
std::vector<DNS_QUERY_RAW_CANCEL*> cancelHandles;
cancelHandles.reserve(m_dnsRequests.size());
for (auto& [_, context] : m_dnsRequests)
{
cancelHandles.emplace_back(&context->m_cancelHandle);
}
for (const auto e : cancelHandles)
{
LOG_IF_WIN32_ERROR(s_dnsCancelQueryRaw.value()(e));
}
}
// Wait for all requests to complete. At this point no new requests can be started since the object is stopped.
// We are only waiting for existing requests to finish.
m_allRequestsFinished.wait();
// Stop the response queue first as it can make calls in m_dnsChannel
m_dnsResponseQueue.cancel();
m_dnsChannel.Stop();
// Stop interface change notifications
m_interfaceNotificationHandle.reset();
GenerateTelemetry();
}
CATCH_LOG()
void DnsResolver::ProcessDnsRequest(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept
try
{
const std::lock_guard lock(m_dnsLock);
if (m_stopped)
{
return;
}
WSL_LOG_DEBUG(
"DnsResolver::ProcessDnsRequest - received new DNS request",
TraceLoggingValue(dnsBuffer.size(), "DNS buffer size"),
TraceLoggingValue(dnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"),
TraceLoggingValue(dnsClientIdentifier.DnsClientId, "DNS client id"),
TraceLoggingValue(!m_externalInterfaceConstraintName.empty(), "Is ExternalInterfaceConstraint configured"),
TraceLoggingValue(m_externalInterfaceConstraintIndex, "m_externalInterfaceConstraintIndex"));
// If the external interface constraint is configured but it is *not* present/up, WSL should be net-blind, so we avoid making DNS requests.
if (!m_externalInterfaceConstraintName.empty() && m_externalInterfaceConstraintIndex == 0)
{
return;
}
dnsClientIdentifier.Protocol == IPPROTO_UDP ? m_totalUdpQueries++ : m_totalTcpQueries++;
// Get next request id. If value reaches UINT_MAX + 1 it will be automatically reset to 0
const auto requestId = m_currentRequestId++;
// Create the DNS request context
auto context = std::make_unique<DnsResolver::DnsQueryContext>(
requestId, dnsClientIdentifier, [this](_Inout_ DnsResolver::DnsQueryContext* context, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) {
HandleDnsQueryCompletion(context, queryResults);
});
auto [it, _] = m_dnsRequests.emplace(requestId, std::move(context));
const auto localContext = it->second.get();
auto removeContextOnError = wil::scope_exit([&] { WI_VERIFY(m_dnsRequests.erase(requestId) == 1); });
// Fill DNS request structure
DNS_QUERY_RAW_REQUEST request{};
request.version = DNS_QUERY_RAW_REQUEST_VERSION1;
request.resultsVersion = DNS_QUERY_RAW_RESULTS_VERSION1;
request.dnsQueryRawSize = static_cast<ULONG>(dnsBuffer.size());
request.dnsQueryRaw = (PBYTE)dnsBuffer.data();
request.protocol = (dnsClientIdentifier.Protocol == IPPROTO_TCP) ? DNS_PROTOCOL_TCP : DNS_PROTOCOL_UDP;
request.queryCompletionCallback = DnsResolver::DnsQueryRawCallback;
request.queryContext = localContext;
// Only unicast UDP & TCP queries are tunneled. Pass this flag to tell Windows DNS client to *not* resolve using multicast.
request.queryOptions |= DNS_QUERY_NO_MULTICAST;
// In a DNS request from Linux there might be DNS records that Windows DNS client does not know how to parse.
// By default in this case Windows will fail the request. When the flag is enabled, Windows will extract the
// question from the DNS request and attempt to resolve it, ignoring the unknown records.
if (WI_IsFlagSet(m_flags, DnsResolverFlags::BestEffortDnsParsing))
{
request.queryRawOptions |= DNS_QUERY_RAW_OPTION_BEST_EFFORT_PARSE;
}
// If the external interface constraint is configured and present on the host, only send DNS requests on that interface.
if (m_externalInterfaceConstraintIndex != 0)
{
request.interfaceIndex = m_externalInterfaceConstraintIndex;
}
// Start the DNS request
// N.B. All DNS requests will bypass the Windows DNS cache
const auto result = s_dnsQueryRaw.value()(&request, &localContext->m_cancelHandle);
if (result != DNS_REQUEST_PENDING)
{
m_failedDnsQueryRawCalls++;
WSL_LOG(
"ProcessDnsRequestFailed",
TraceLoggingValue(requestId, "requestId"),
TraceLoggingValue(result, "result"),
TraceLoggingValue("DnsQueryRaw", "executionStep"));
return;
}
removeContextOnError.release();
m_allRequestsFinished.ResetEvent();
}
CATCH_LOG()
void DnsResolver::HandleDnsQueryCompletion(_Inout_ DnsResolver::DnsQueryContext* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept
try
{
// Always free the query result structure
const auto freeQueryResults = wil::scope_exit([&] {
if (queryResults != nullptr)
{
s_dnsQueryRawResultFree.value()(queryResults);
}
});
const std::lock_guard lock(m_dnsLock);
if (queryResults != nullptr)
{
WSL_LOG(
"DnsResolver::HandleDnsQueryCompletion",
TraceLoggingValue(queryContext->m_id, "queryContext->m_id"),
TraceLoggingValue(queryResults->queryStatus, "queryResults->queryStatus"),
TraceLoggingValue(queryResults->queryRawResponse != nullptr, "validResponse"));
// Note: The response may be valid even if queryResults->queryStatus is not 0, for example when the DNS server returns a negative response.
if (queryResults->queryRawResponse != nullptr)
{
queryContext->m_dnsClientIdentifier.Protocol == IPPROTO_UDP ? m_successfulUdpQueries++ : m_successfulTcpQueries++;
}
// the Windows DNS API returned failure
else
{
if (m_dnsApiFailures.find(queryResults->queryStatus) == m_dnsApiFailures.end())
{
m_dnsApiFailures[queryResults->queryStatus] = 1;
}
else
{
m_dnsApiFailures[queryResults->queryStatus]++;
}
}
}
else
{
WSL_LOG(
"DnsResolver::HandleDnsQueryCompletion - received a NULL queryResults",
TraceLoggingValue(queryContext->m_id, "queryContext->m_id"));
m_queriesWithNullResult++;
}
if (!m_stopped && queryResults != nullptr && queryResults->queryRawResponse != nullptr)
{
// Copy DNS response buffer
std::vector<gsl::byte> dnsResponse(queryResults->queryRawResponseSize);
CopyMemory(dnsResponse.data(), queryResults->queryRawResponse, queryResults->queryRawResponseSize);
WSL_LOG_DEBUG(
"DnsResolver::HandleDnsQueryCompletion - received new DNS response",
TraceLoggingValue(dnsResponse.size(), "DNS buffer size"),
TraceLoggingValue(queryContext->m_dnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"),
TraceLoggingValue(queryContext->m_dnsClientIdentifier.DnsClientId, "DNS client id"));
// Schedule the DNS response to be sent to Linux
m_dnsResponseQueue.submit([this, dnsResponse = std::move(dnsResponse), dnsClientIdentifier = queryContext->m_dnsClientIdentifier]() mutable {
m_dnsChannel.SendDnsMessage(gsl::make_span(dnsResponse), dnsClientIdentifier);
});
}
// Stop tracking this DNS request and delete the request context
WI_VERIFY(m_dnsRequests.erase(queryContext->m_id) == 1);
// Set event if all tracked requests have finished
if (m_dnsRequests.empty())
{
m_allRequestsFinished.SetEvent();
}
}
CATCH_LOG()
void DnsResolver::ResolveExternalInterfaceConstraintIndex() noexcept
try
{
const std::lock_guard lock(m_dnsLock);
if (m_stopped)
{
return;
}
if (m_externalInterfaceConstraintName.empty())
{
return;
}
NET_LUID interfaceLuid{};
ULONG interfaceIndex = 0;
// Update the interface index on every exit path.
// The calls below to convert interface name to index will fail if the interface does not exist anymore,
// in which case we still need to reset the interface index to its default value of 0.
const auto setInterfaceIndex = wil::scope_exit([&] {
if (interfaceIndex != m_externalInterfaceConstraintIndex)
{
WSL_LOG(
"DnsResolver::ResolveExternalInterfaceConstraintIndex - setting m_externalInterfaceConstraintIndex to new value",
TraceLoggingValue(m_externalInterfaceConstraintIndex, "old interface index"),
TraceLoggingValue(interfaceIndex, "new interface index"));
m_externalInterfaceConstraintIndex = interfaceIndex;
}
});
// If external interface constraint is configured, query to see if it's present on the host.
auto errorCode = ConvertInterfaceAliasToLuid(m_externalInterfaceConstraintName.c_str(), &interfaceLuid);
if (FAILED_WIN32_LOG(errorCode))
{
return;
}
errorCode = ConvertInterfaceLuidToIndex(&interfaceLuid, reinterpret_cast<PNET_IFINDEX>(&interfaceIndex));
if (FAILED_WIN32_LOG(errorCode))
{
return;
}
}
CATCH_LOG()
VOID CALLBACK DnsResolver::DnsQueryRawCallback(_In_ VOID* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept
try
{
assert(queryContext != nullptr);
const auto context = static_cast<DnsQueryContext*>(queryContext);
// Call into DnsResolver parent object to process the query result
context->m_handleQueryCompletion(context, queryResults);
}
CATCH_LOG()
VOID CALLBACK DnsResolver::InterfaceChangeCallback(_In_ PVOID context, PMIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE) noexcept
try
{
const auto dnsResolver = static_cast<DnsResolver*>(context);
dnsResolver->ResolveExternalInterfaceConstraintIndex();
}
CATCH_LOG()
// Copyright (C) Microsoft Corporation. All rights reserved.
#include <LxssDynamicFunction.h>
#include "precomp.h"
#include "DnsResolver.h"
using wsl::core::networking::DnsResolver;
static constexpr auto c_dnsModuleName = L"dnsapi.dll";
std::optional<LxssDynamicFunction<decltype(DnsQueryRaw)>> DnsResolver::s_dnsQueryRaw;
std::optional<LxssDynamicFunction<decltype(DnsCancelQueryRaw)>> DnsResolver::s_dnsCancelQueryRaw;
std::optional<LxssDynamicFunction<decltype(DnsQueryRawResultFree)>> DnsResolver::s_dnsQueryRawResultFree;
HRESULT DnsResolver::LoadDnsResolverMethods() noexcept
{
static wil::shared_hmodule dnsModule;
static DWORD loadError = ERROR_SUCCESS;
static std::once_flag dnsLoadFlag;
// Load DNS dll only once
std::call_once(dnsLoadFlag, [&]() {
dnsModule.reset(LoadLibraryEx(c_dnsModuleName, nullptr, LOAD_LIBRARY_SEARCH_SYSTEM32));
if (!dnsModule)
{
loadError = GetLastError();
}
});
RETURN_IF_WIN32_ERROR_MSG(loadError, "LoadLibraryEx %ls", c_dnsModuleName);
// Initialize dynamic functions for the DNS tunneling Windows APIs.
// using the non-throwing instance of LxssDynamicFunction as to not end up in the Error telemetry
LxssDynamicFunction<decltype(DnsQueryRaw)> local_dnsQueryRaw{DynamicFunctionErrorLogs::None};
RETURN_IF_FAILED_EXPECTED(local_dnsQueryRaw.load(dnsModule, "DnsQueryRaw"));
LxssDynamicFunction<decltype(DnsCancelQueryRaw)> local_dnsCancelQueryRaw{DynamicFunctionErrorLogs::None};
RETURN_IF_FAILED_EXPECTED(local_dnsCancelQueryRaw.load(dnsModule, "DnsCancelQueryRaw"));
LxssDynamicFunction<decltype(DnsQueryRawResultFree)> local_dnsQueryRawResultFree{DynamicFunctionErrorLogs::None};
RETURN_IF_FAILED_EXPECTED(local_dnsQueryRawResultFree.load(dnsModule, "DnsQueryRawResultFree"));
// Make a dummy call to the DNS APIs to verify if they are working. The APIs are going to be present
// on older Windows versions, where they can be turned on/off. If turned off, the APIs
// will be unusable and will return ERROR_CALL_NOT_IMPLEMENTED.
if (local_dnsQueryRaw(nullptr, nullptr) == ERROR_CALL_NOT_IMPLEMENTED)
{
RETURN_IF_WIN32_ERROR_EXPECTED(ERROR_CALL_NOT_IMPLEMENTED);
}
s_dnsQueryRaw.emplace(std::move(local_dnsQueryRaw));
s_dnsCancelQueryRaw.emplace(std::move(local_dnsCancelQueryRaw));
s_dnsQueryRawResultFree.emplace(std::move(local_dnsQueryRawResultFree));
return S_OK;
}
DnsResolver::DnsResolver(wil::unique_socket&& dnsHvsocket, DnsResolverFlags flags) :
m_dnsChannel(
std::move(dnsHvsocket),
[this](const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) {
ProcessDnsRequest(dnsBuffer, dnsClientIdentifier);
}),
m_flags(flags)
{
// Initialize as signaled, as there are no requests yet
m_allRequestsFinished.SetEvent();
// Read external interface constraint regkey
const auto lxssKey = windows::common::registry::OpenLxssMachineKey(KEY_READ);
m_externalInterfaceConstraintName =
windows::common::registry::ReadString(lxssKey.get(), nullptr, c_interfaceConstraintKey, L"");
if (!m_externalInterfaceConstraintName.empty())
{
ResolveExternalInterfaceConstraintIndex();
WSL_LOG(
"DnsResolver::DnsResolver",
TraceLoggingValue(m_externalInterfaceConstraintName.c_str(), "m_externalInterfaceConstraintName"),
TraceLoggingValue(m_externalInterfaceConstraintIndex, "m_externalInterfaceConstraintIndex"));
// Register for interface change notifications. Notifications are used to determine if the external interface constraint setting is applicable.
THROW_IF_WIN32_ERROR(NotifyIpInterfaceChange(AF_UNSPEC, &DnsResolver::InterfaceChangeCallback, this, FALSE, &m_interfaceNotificationHandle));
}
}
DnsResolver::~DnsResolver() noexcept
{
Stop();
}
void DnsResolver::GenerateTelemetry() noexcept
try
{
// Find the 3 most common DNS API failures
uint32_t mostCommonDnsStatusError = 0;
uint32_t mostCommonDnsStatusErrorCount = 0;
uint32_t secondCommonDnsStatusError = 0;
uint32_t secondCommonDnsStatusErrorCount = 0;
uint32_t thirdCommonDnsStatusError = 0;
uint32_t thirdCommonDnsStatusErrorCount = 0;
std::vector<std::pair<uint32_t, uint32_t>> failures(m_dnsApiFailures.size());
std::copy(m_dnsApiFailures.begin(), m_dnsApiFailures.end(), failures.begin());
// Sort in descending order based on failure count
std::sort(failures.begin(), failures.end(), [](const auto& lhs, const auto& rhs) { return lhs.second > rhs.second; });
if (failures.size() >= 1)
{
mostCommonDnsStatusError = failures[0].first;
mostCommonDnsStatusErrorCount = failures[0].second;
}
if (failures.size() >= 2)
{
secondCommonDnsStatusError = failures[1].first;
secondCommonDnsStatusErrorCount = failures[1].second;
}
if (failures.size() >= 3)
{
thirdCommonDnsStatusError = failures[2].first;
thirdCommonDnsStatusErrorCount = failures[2].second;
}
// Add telemetry with DNS tunneling statistics, before shutting down
WSL_LOG(
"DnsTunnelingStatistics",
TraceLoggingValue(m_totalUdpQueries.load(), "totalUdpQueries"),
TraceLoggingValue(m_successfulUdpQueries.load(), "successfulUdpQueries"),
TraceLoggingValue(m_totalTcpQueries.load(), "totalTcpQueries"),
TraceLoggingValue(m_successfulTcpQueries.load(), "successfulTcpQueries"),
TraceLoggingValue(m_queriesWithNullResult.load(), "queriesWithNullResult"),
TraceLoggingValue(m_failedDnsQueryRawCalls.load(), "FailedDnsQueryRawCalls"),
TraceLoggingValue(m_dnsApiFailures.size(), "totalDnsStatusErrorInstances"),
TraceLoggingValue(mostCommonDnsStatusError, "mostCommonDnsStatusError"),
TraceLoggingValue(mostCommonDnsStatusErrorCount, "mostCommonDnsStatusErrorCount"),
TraceLoggingValue(secondCommonDnsStatusError, "secondCommonDnsStatusError"),
TraceLoggingValue(secondCommonDnsStatusErrorCount, "secondCommonDnsStatusErrorCount"),
TraceLoggingValue(thirdCommonDnsStatusError, "thirdCommonDnsStatusError"),
TraceLoggingValue(thirdCommonDnsStatusErrorCount, "thirdCommonDnsStatusErrorCount"));
}
CATCH_LOG()
void DnsResolver::Stop() noexcept
try
{
WSL_LOG("DnsResolver::Stop");
// Scoped m_dnsLock
{
const std::lock_guard lock(m_dnsLock);
m_stopped = true;
// Cancel existing requests. Cancel is complete when DnsQueryRawCallback is
// invoked with status == ERROR_CANCELLED
// N.B. Cancelling can end up calling the DnsQueryRawCallback directly on this same thread. i.e., while this
// lock is held. Which is fine because m_dnsLock is a recursive mutex.
// N.B. Cancelling a query will synchronously remove the query from m_dnsRequests, which invalidates iterators.
std::vector<DNS_QUERY_RAW_CANCEL*> cancelHandles;
cancelHandles.reserve(m_dnsRequests.size());
for (auto& [_, context] : m_dnsRequests)
{
cancelHandles.emplace_back(&context->m_cancelHandle);
}
for (const auto e : cancelHandles)
{
LOG_IF_WIN32_ERROR(s_dnsCancelQueryRaw.value()(e));
}
}
// Wait for all requests to complete. At this point no new requests can be started since the object is stopped.
// We are only waiting for existing requests to finish.
m_allRequestsFinished.wait();
// Stop the response queue first as it can make calls in m_dnsChannel
m_dnsResponseQueue.cancel();
m_dnsChannel.Stop();
// Stop interface change notifications
m_interfaceNotificationHandle.reset();
GenerateTelemetry();
}
CATCH_LOG()
void DnsResolver::ProcessDnsRequest(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept
try
{
const std::lock_guard lock(m_dnsLock);
if (m_stopped)
{
return;
}
WSL_LOG_DEBUG(
"DnsResolver::ProcessDnsRequest - received new DNS request",
TraceLoggingValue(dnsBuffer.size(), "DNS buffer size"),
TraceLoggingValue(dnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"),
TraceLoggingValue(dnsClientIdentifier.DnsClientId, "DNS client id"),
TraceLoggingValue(!m_externalInterfaceConstraintName.empty(), "Is ExternalInterfaceConstraint configured"),
TraceLoggingValue(m_externalInterfaceConstraintIndex, "m_externalInterfaceConstraintIndex"));
// If the external interface constraint is configured but it is *not* present/up, WSL should be net-blind, so we avoid making DNS requests.
if (!m_externalInterfaceConstraintName.empty() && m_externalInterfaceConstraintIndex == 0)
{
return;
}
dnsClientIdentifier.Protocol == IPPROTO_UDP ? m_totalUdpQueries++ : m_totalTcpQueries++;
// Get next request id. If value reaches UINT_MAX + 1 it will be automatically reset to 0
const auto requestId = m_currentRequestId++;
// Create the DNS request context
auto context = std::make_unique<DnsResolver::DnsQueryContext>(
requestId, dnsClientIdentifier, [this](_Inout_ DnsResolver::DnsQueryContext* context, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) {
HandleDnsQueryCompletion(context, queryResults);
});
auto [it, _] = m_dnsRequests.emplace(requestId, std::move(context));
const auto localContext = it->second.get();
auto removeContextOnError = wil::scope_exit([&] { WI_VERIFY(m_dnsRequests.erase(requestId) == 1); });
// Fill DNS request structure
DNS_QUERY_RAW_REQUEST request{};
request.version = DNS_QUERY_RAW_REQUEST_VERSION1;
request.resultsVersion = DNS_QUERY_RAW_RESULTS_VERSION1;
request.dnsQueryRawSize = static_cast<ULONG>(dnsBuffer.size());
request.dnsQueryRaw = (PBYTE)dnsBuffer.data();
request.protocol = (dnsClientIdentifier.Protocol == IPPROTO_TCP) ? DNS_PROTOCOL_TCP : DNS_PROTOCOL_UDP;
request.queryCompletionCallback = DnsResolver::DnsQueryRawCallback;
request.queryContext = localContext;
// Only unicast UDP & TCP queries are tunneled. Pass this flag to tell Windows DNS client to *not* resolve using multicast.
request.queryOptions |= DNS_QUERY_NO_MULTICAST;
// In a DNS request from Linux there might be DNS records that Windows DNS client does not know how to parse.
// By default in this case Windows will fail the request. When the flag is enabled, Windows will extract the
// question from the DNS request and attempt to resolve it, ignoring the unknown records.
if (WI_IsFlagSet(m_flags, DnsResolverFlags::BestEffortDnsParsing))
{
request.queryRawOptions |= DNS_QUERY_RAW_OPTION_BEST_EFFORT_PARSE;
}
// If the external interface constraint is configured and present on the host, only send DNS requests on that interface.
if (m_externalInterfaceConstraintIndex != 0)
{
request.interfaceIndex = m_externalInterfaceConstraintIndex;
}
// Start the DNS request
// N.B. All DNS requests will bypass the Windows DNS cache
const auto result = s_dnsQueryRaw.value()(&request, &localContext->m_cancelHandle);
if (result != DNS_REQUEST_PENDING)
{
m_failedDnsQueryRawCalls++;
WSL_LOG(
"ProcessDnsRequestFailed",
TraceLoggingValue(requestId, "requestId"),
TraceLoggingValue(result, "result"),
TraceLoggingValue("DnsQueryRaw", "executionStep"));
return;
}
removeContextOnError.release();
m_allRequestsFinished.ResetEvent();
}
CATCH_LOG()
void DnsResolver::HandleDnsQueryCompletion(_Inout_ DnsResolver::DnsQueryContext* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept
try
{
// Always free the query result structure
const auto freeQueryResults = wil::scope_exit([&] {
if (queryResults != nullptr)
{
s_dnsQueryRawResultFree.value()(queryResults);
}
});
const std::lock_guard lock(m_dnsLock);
if (queryResults != nullptr)
{
WSL_LOG(
"DnsResolver::HandleDnsQueryCompletion",
TraceLoggingValue(queryContext->m_id, "queryContext->m_id"),
TraceLoggingValue(queryResults->queryStatus, "queryResults->queryStatus"),
TraceLoggingValue(queryResults->queryRawResponse != nullptr, "validResponse"));
// Note: The response may be valid even if queryResults->queryStatus is not 0, for example when the DNS server returns a negative response.
if (queryResults->queryRawResponse != nullptr)
{
queryContext->m_dnsClientIdentifier.Protocol == IPPROTO_UDP ? m_successfulUdpQueries++ : m_successfulTcpQueries++;
}
// the Windows DNS API returned failure
else
{
if (m_dnsApiFailures.find(queryResults->queryStatus) == m_dnsApiFailures.end())
{
m_dnsApiFailures[queryResults->queryStatus] = 1;
}
else
{
m_dnsApiFailures[queryResults->queryStatus]++;
}
}
}
else
{
WSL_LOG(
"DnsResolver::HandleDnsQueryCompletion - received a NULL queryResults",
TraceLoggingValue(queryContext->m_id, "queryContext->m_id"));
m_queriesWithNullResult++;
}
if (!m_stopped && queryResults != nullptr && queryResults->queryRawResponse != nullptr)
{
// Copy DNS response buffer
std::vector<gsl::byte> dnsResponse(queryResults->queryRawResponseSize);
CopyMemory(dnsResponse.data(), queryResults->queryRawResponse, queryResults->queryRawResponseSize);
WSL_LOG_DEBUG(
"DnsResolver::HandleDnsQueryCompletion - received new DNS response",
TraceLoggingValue(dnsResponse.size(), "DNS buffer size"),
TraceLoggingValue(queryContext->m_dnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"),
TraceLoggingValue(queryContext->m_dnsClientIdentifier.DnsClientId, "DNS client id"));
// Schedule the DNS response to be sent to Linux
m_dnsResponseQueue.submit([this, dnsResponse = std::move(dnsResponse), dnsClientIdentifier = queryContext->m_dnsClientIdentifier]() mutable {
m_dnsChannel.SendDnsMessage(gsl::make_span(dnsResponse), dnsClientIdentifier);
});
}
// Stop tracking this DNS request and delete the request context
WI_VERIFY(m_dnsRequests.erase(queryContext->m_id) == 1);
// Set event if all tracked requests have finished
if (m_dnsRequests.empty())
{
m_allRequestsFinished.SetEvent();
}
}
CATCH_LOG()
void DnsResolver::ResolveExternalInterfaceConstraintIndex() noexcept
try
{
const std::lock_guard lock(m_dnsLock);
if (m_stopped)
{
return;
}
if (m_externalInterfaceConstraintName.empty())
{
return;
}
NET_LUID interfaceLuid{};
ULONG interfaceIndex = 0;
// Update the interface index on every exit path.
// The calls below to convert interface name to index will fail if the interface does not exist anymore,
// in which case we still need to reset the interface index to its default value of 0.
const auto setInterfaceIndex = wil::scope_exit([&] {
if (interfaceIndex != m_externalInterfaceConstraintIndex)
{
WSL_LOG(
"DnsResolver::ResolveExternalInterfaceConstraintIndex - setting m_externalInterfaceConstraintIndex to new value",
TraceLoggingValue(m_externalInterfaceConstraintIndex, "old interface index"),
TraceLoggingValue(interfaceIndex, "new interface index"));
m_externalInterfaceConstraintIndex = interfaceIndex;
}
});
// If external interface constraint is configured, query to see if it's present on the host.
auto errorCode = ConvertInterfaceAliasToLuid(m_externalInterfaceConstraintName.c_str(), &interfaceLuid);
if (FAILED_WIN32_LOG(errorCode))
{
return;
}
errorCode = ConvertInterfaceLuidToIndex(&interfaceLuid, reinterpret_cast<PNET_IFINDEX>(&interfaceIndex));
if (FAILED_WIN32_LOG(errorCode))
{
return;
}
}
CATCH_LOG()
VOID CALLBACK DnsResolver::DnsQueryRawCallback(_In_ VOID* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept
try
{
assert(queryContext != nullptr);
const auto context = static_cast<DnsQueryContext*>(queryContext);
// Call into DnsResolver parent object to process the query result
context->m_handleQueryCompletion(context, queryResults);
}
CATCH_LOG()
VOID CALLBACK DnsResolver::InterfaceChangeCallback(_In_ PVOID context, PMIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE) noexcept
try
{
const auto dnsResolver = static_cast<DnsResolver*>(context);
dnsResolver->ResolveExternalInterfaceConstraintIndex();
}
CATCH_LOG()

View File

@ -1,141 +1,141 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
#pragma once
#include "DnsTunnelingChannel.h"
#include "WslCoreMessageQueue.h"
#include "WslCoreNetworkingSupport.h"
namespace wsl::core::networking {
enum class DnsResolverFlags
{
None = 0x0,
BestEffortDnsParsing = 0x1
};
DEFINE_ENUM_FLAG_OPERATORS(DnsResolverFlags);
class DnsResolver
{
public:
DnsResolver(wil::unique_socket&& dnsHvsocket, DnsResolverFlags flags);
~DnsResolver() noexcept;
DnsResolver(const DnsResolver&) = delete;
DnsResolver& operator=(const DnsResolver&) = delete;
DnsResolver(DnsResolver&&) = delete;
DnsResolver& operator=(DnsResolver&&) = delete;
void Stop() noexcept;
static HRESULT LoadDnsResolverMethods() noexcept;
private:
struct DnsQueryContext
{
// Struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request.
LX_GNS_DNS_CLIENT_IDENTIFIER m_dnsClientIdentifier{};
// Handle used to cancel the request.
DNS_QUERY_RAW_CANCEL m_cancelHandle{};
// Unique query id.
uint32_t m_id{};
// Callback to the parent object to notify about the DNS query completion.
std::function<void(DnsQueryContext*, DNS_QUERY_RAW_RESULT*)> m_handleQueryCompletion;
DnsQueryContext(
uint32_t id,
const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier,
std::function<void(DnsQueryContext*, DNS_QUERY_RAW_RESULT*)>&& handleQueryCompletion) :
m_dnsClientIdentifier(dnsClientIdentifier), m_id(id), m_handleQueryCompletion(std::move(handleQueryCompletion))
{
}
~DnsQueryContext() noexcept = default;
DnsQueryContext(const DnsQueryContext&) = delete;
DnsQueryContext& operator=(const DnsQueryContext&) = delete;
DnsQueryContext(DnsQueryContext&&) = delete;
DnsQueryContext& operator=(DnsQueryContext&&) = delete;
};
void GenerateTelemetry() noexcept;
// Process DNS request received from Linux.
//
// Arguments:
// dnsBuffer - buffer containing DNS request.
// dnsClientIdentifier - struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request.
void ProcessDnsRequest(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept;
// Handle completion of DNS query.
//
// Arguments:
// dnsQueryContext - context structure for the DNS request.
// queryResults - structure containing result of the DNS request.
void HandleDnsQueryCompletion(_Inout_ DnsQueryContext* dnsQueryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept;
void ResolveExternalInterfaceConstraintIndex() noexcept;
// Callback that will be invoked by the DNS API whenever a request finishes. The callback is invoked on success, error or when request is cancelled.
//
// Arguments:
// queryContext - pointer to context structure, will be a structure of type DnsQueryContext.
// queryResults - pointer to structure containing the result of the DNS request.
static VOID CALLBACK DnsQueryRawCallback(_In_ VOID* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept;
static VOID CALLBACK InterfaceChangeCallback(_In_ PVOID context, PMIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE) noexcept;
std::recursive_mutex m_dnsLock;
// Flag used when shutting down the object.
_Guarded_by_(m_dnsLock) bool m_stopped = false;
// Hvsocket channel used to exchange DNS messages with Linux.
DnsTunnelingChannel m_dnsChannel;
// Queue used to send DNS responses to Linux.
WslCoreMessageQueue m_dnsResponseQueue;
// Unique id that is incremented for each request. In case the value reaches MAX_UINT and is reset to 0,
// it's assumed previous requests with id's 0, 1, ... finished in the meantime and the id can be reused.
_Guarded_by_(m_dnsLock) uint32_t m_currentRequestId = 0;
// Mapping request id to the request context structure.
_Guarded_by_(m_dnsLock) std::unordered_map<uint32_t, std::unique_ptr<DnsQueryContext>> m_dnsRequests {};
// Event that is set when all tracked DNS requests have completed.
wil::unique_event m_allRequestsFinished{wil::EventOptions::ManualReset};
// Used for handling of external interface constraint setting.
unique_notify_handle m_interfaceNotificationHandle{};
std::wstring m_externalInterfaceConstraintName;
_Guarded_by_(m_dnsLock) ULONG m_externalInterfaceConstraintIndex = 0;
const DnsResolverFlags m_flags{};
// Statistics used for telemetry.
std::atomic<uint32_t> m_totalUdpQueries{0};
std::atomic<uint32_t> m_successfulUdpQueries{0};
std::atomic<uint32_t> m_totalTcpQueries{0};
std::atomic<uint32_t> m_successfulTcpQueries{0};
std::atomic<uint32_t> m_queriesWithNullResult{0};
std::atomic<uint32_t> m_failedDnsQueryRawCalls{0};
_Guarded_by_(m_dnsLock) std::map<uint32_t, uint32_t> m_dnsApiFailures;
// Dynamic functions used for calling the DNS APIs.
// Function to start a raw DNS request.
static std::optional<LxssDynamicFunction<decltype(DnsQueryRaw)>> s_dnsQueryRaw;
// Function to cancel a raw DNS request.
static std::optional<LxssDynamicFunction<decltype(DnsCancelQueryRaw)>> s_dnsCancelQueryRaw;
// Function to free the structure containing the result of a raw DNS request.
static std::optional<LxssDynamicFunction<decltype(DnsQueryRawResultFree)>> s_dnsQueryRawResultFree;
};
} // namespace wsl::core::networking
// Copyright (C) Microsoft Corporation. All rights reserved.
#pragma once
#include "DnsTunnelingChannel.h"
#include "WslCoreMessageQueue.h"
#include "WslCoreNetworkingSupport.h"
namespace wsl::core::networking {
enum class DnsResolverFlags
{
None = 0x0,
BestEffortDnsParsing = 0x1
};
DEFINE_ENUM_FLAG_OPERATORS(DnsResolverFlags);
class DnsResolver
{
public:
DnsResolver(wil::unique_socket&& dnsHvsocket, DnsResolverFlags flags);
~DnsResolver() noexcept;
DnsResolver(const DnsResolver&) = delete;
DnsResolver& operator=(const DnsResolver&) = delete;
DnsResolver(DnsResolver&&) = delete;
DnsResolver& operator=(DnsResolver&&) = delete;
void Stop() noexcept;
static HRESULT LoadDnsResolverMethods() noexcept;
private:
struct DnsQueryContext
{
// Struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request.
LX_GNS_DNS_CLIENT_IDENTIFIER m_dnsClientIdentifier{};
// Handle used to cancel the request.
DNS_QUERY_RAW_CANCEL m_cancelHandle{};
// Unique query id.
uint32_t m_id{};
// Callback to the parent object to notify about the DNS query completion.
std::function<void(DnsQueryContext*, DNS_QUERY_RAW_RESULT*)> m_handleQueryCompletion;
DnsQueryContext(
uint32_t id,
const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier,
std::function<void(DnsQueryContext*, DNS_QUERY_RAW_RESULT*)>&& handleQueryCompletion) :
m_dnsClientIdentifier(dnsClientIdentifier), m_id(id), m_handleQueryCompletion(std::move(handleQueryCompletion))
{
}
~DnsQueryContext() noexcept = default;
DnsQueryContext(const DnsQueryContext&) = delete;
DnsQueryContext& operator=(const DnsQueryContext&) = delete;
DnsQueryContext(DnsQueryContext&&) = delete;
DnsQueryContext& operator=(DnsQueryContext&&) = delete;
};
void GenerateTelemetry() noexcept;
// Process DNS request received from Linux.
//
// Arguments:
// dnsBuffer - buffer containing DNS request.
// dnsClientIdentifier - struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request.
void ProcessDnsRequest(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept;
// Handle completion of DNS query.
//
// Arguments:
// dnsQueryContext - context structure for the DNS request.
// queryResults - structure containing result of the DNS request.
void HandleDnsQueryCompletion(_Inout_ DnsQueryContext* dnsQueryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept;
void ResolveExternalInterfaceConstraintIndex() noexcept;
// Callback that will be invoked by the DNS API whenever a request finishes. The callback is invoked on success, error or when request is cancelled.
//
// Arguments:
// queryContext - pointer to context structure, will be a structure of type DnsQueryContext.
// queryResults - pointer to structure containing the result of the DNS request.
static VOID CALLBACK DnsQueryRawCallback(_In_ VOID* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept;
static VOID CALLBACK InterfaceChangeCallback(_In_ PVOID context, PMIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE) noexcept;
std::recursive_mutex m_dnsLock;
// Flag used when shutting down the object.
_Guarded_by_(m_dnsLock) bool m_stopped = false;
// Hvsocket channel used to exchange DNS messages with Linux.
DnsTunnelingChannel m_dnsChannel;
// Queue used to send DNS responses to Linux.
WslCoreMessageQueue m_dnsResponseQueue;
// Unique id that is incremented for each request. In case the value reaches MAX_UINT and is reset to 0,
// it's assumed previous requests with id's 0, 1, ... finished in the meantime and the id can be reused.
_Guarded_by_(m_dnsLock) uint32_t m_currentRequestId = 0;
// Mapping request id to the request context structure.
_Guarded_by_(m_dnsLock) std::unordered_map<uint32_t, std::unique_ptr<DnsQueryContext>> m_dnsRequests {};
// Event that is set when all tracked DNS requests have completed.
wil::unique_event m_allRequestsFinished{wil::EventOptions::ManualReset};
// Used for handling of external interface constraint setting.
unique_notify_handle m_interfaceNotificationHandle{};
std::wstring m_externalInterfaceConstraintName;
_Guarded_by_(m_dnsLock) ULONG m_externalInterfaceConstraintIndex = 0;
const DnsResolverFlags m_flags{};
// Statistics used for telemetry.
std::atomic<uint32_t> m_totalUdpQueries{0};
std::atomic<uint32_t> m_successfulUdpQueries{0};
std::atomic<uint32_t> m_totalTcpQueries{0};
std::atomic<uint32_t> m_successfulTcpQueries{0};
std::atomic<uint32_t> m_queriesWithNullResult{0};
std::atomic<uint32_t> m_failedDnsQueryRawCalls{0};
_Guarded_by_(m_dnsLock) std::map<uint32_t, uint32_t> m_dnsApiFailures;
// Dynamic functions used for calling the DNS APIs.
// Function to start a raw DNS request.
static std::optional<LxssDynamicFunction<decltype(DnsQueryRaw)>> s_dnsQueryRaw;
// Function to cancel a raw DNS request.
static std::optional<LxssDynamicFunction<decltype(DnsCancelQueryRaw)>> s_dnsCancelQueryRaw;
// Function to free the structure containing the result of a raw DNS request.
static std::optional<LxssDynamicFunction<decltype(DnsQueryRawResultFree)>> s_dnsQueryRawResultFree;
};
} // namespace wsl::core::networking

View File

@ -1,115 +1,115 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
#include "precomp.h"
#include "DnsTunnelingChannel.h"
using wsl::core::networking::DnsTunnelingChannel;
DnsTunnelingChannel::DnsTunnelingChannel(wil::unique_socket&& socket, DnsTunnelingCallback&& reportDnsRequest) :
m_channel{std::move(socket), "DnsTunneling", m_stopEvent.get()}, m_reportDnsRequest(std::move(reportDnsRequest))
{
WSL_LOG("DnsTunnelingChannel::DnsTunnelingChannel [Windows]", TraceLoggingValue(m_channel.Socket(), "socket"));
// Start thread waiting for incoming messages from Linux side
m_receiveWorkerThread = std::thread([this]() { ReceiveLoop(); });
}
DnsTunnelingChannel::~DnsTunnelingChannel()
{
Stop();
}
void DnsTunnelingChannel::SendDnsMessage(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept
try
{
// Exit if channel was stopped
if (m_stopEvent.is_signaled())
{
return;
}
wsl::shared::MessageWriter<LX_GNS_DNS_TUNNELING_MESSAGE> message(LxGnsMessageDnsTunneling);
message->DnsClientIdentifier = dnsClientIdentifier;
message.WriteSpan(dnsBuffer);
m_channel.SendMessage<LX_GNS_DNS_TUNNELING_MESSAGE>(message.Span());
}
CATCH_LOG()
void DnsTunnelingChannel::ReceiveLoop() noexcept
{
std::vector<gsl::byte> receiveBuffer;
for (;;)
{
try
{
if (m_stopEvent.is_signaled())
{
return;
}
WSL_LOG_DEBUG("DnsTunnelingChannel::ReceiveLoop [Windows] - waiting for next message from Linux");
// Read next message. wsl::shared::socket::RecvMessage() first reads the message header, then uses it to determine the
// total size of the message and read the rest of the message, resizing the buffer if needed.
auto [message, span] = m_channel.ReceiveMessageOrClosed<MESSAGE_HEADER>();
if (message == nullptr)
{
WSL_LOG("DnsTunnelingChannel::ReceiveLoop [Windows] - failed to read message");
return;
}
// Get the message type from the message header
switch (message->MessageType)
{
case LxGnsMessageDnsTunneling:
{
// Cast message to a LX_GNS_DNS_TUNNELING_MESSAGE struct
auto* dnsMessage = gslhelpers::try_get_struct<LX_GNS_DNS_TUNNELING_MESSAGE>(span);
if (!dnsMessage)
{
WSL_LOG(
"DnsTunnelingChannel::ReceiveLoop [Windows] - failed to convert message to LX_GNS_DNS_TUNNELING_MESSAGE");
return;
}
// Extract DNS buffer from message
auto dnsBuffer = span.subspan(offsetof(LX_GNS_DNS_TUNNELING_MESSAGE, Buffer));
WSL_LOG_DEBUG(
"DnsTunnelingChannel::ReceiveLoop [Windows] - received DNS message",
TraceLoggingValue(dnsBuffer.size(), "DNS buffer size"),
TraceLoggingValue(dnsMessage->DnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"),
TraceLoggingValue(dnsMessage->DnsClientIdentifier.DnsClientId, "DNS client id"));
// Invoke callback to notify about the new DNS request
m_reportDnsRequest(dnsBuffer, dnsMessage->DnsClientIdentifier);
break;
}
default:
{
THROW_HR_MSG(E_UNEXPECTED, "Unexpected LX_MESSAGE_TYPE : %i", message->MessageType);
}
}
}
CATCH_LOG()
}
}
void DnsTunnelingChannel::Stop() noexcept
try
{
WSL_LOG("DnsTunnelingChannel::Stop [Windows]");
m_stopEvent.SetEvent();
// Stop receive loop
if (m_receiveWorkerThread.joinable())
{
m_receiveWorkerThread.join();
}
}
CATCH_LOG()
// Copyright (C) Microsoft Corporation. All rights reserved.
#include "precomp.h"
#include "DnsTunnelingChannel.h"
using wsl::core::networking::DnsTunnelingChannel;
DnsTunnelingChannel::DnsTunnelingChannel(wil::unique_socket&& socket, DnsTunnelingCallback&& reportDnsRequest) :
m_channel{std::move(socket), "DnsTunneling", m_stopEvent.get()}, m_reportDnsRequest(std::move(reportDnsRequest))
{
WSL_LOG("DnsTunnelingChannel::DnsTunnelingChannel [Windows]", TraceLoggingValue(m_channel.Socket(), "socket"));
// Start thread waiting for incoming messages from Linux side
m_receiveWorkerThread = std::thread([this]() { ReceiveLoop(); });
}
DnsTunnelingChannel::~DnsTunnelingChannel()
{
Stop();
}
void DnsTunnelingChannel::SendDnsMessage(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept
try
{
// Exit if channel was stopped
if (m_stopEvent.is_signaled())
{
return;
}
wsl::shared::MessageWriter<LX_GNS_DNS_TUNNELING_MESSAGE> message(LxGnsMessageDnsTunneling);
message->DnsClientIdentifier = dnsClientIdentifier;
message.WriteSpan(dnsBuffer);
m_channel.SendMessage<LX_GNS_DNS_TUNNELING_MESSAGE>(message.Span());
}
CATCH_LOG()
void DnsTunnelingChannel::ReceiveLoop() noexcept
{
std::vector<gsl::byte> receiveBuffer;
for (;;)
{
try
{
if (m_stopEvent.is_signaled())
{
return;
}
WSL_LOG_DEBUG("DnsTunnelingChannel::ReceiveLoop [Windows] - waiting for next message from Linux");
// Read next message. wsl::shared::socket::RecvMessage() first reads the message header, then uses it to determine the
// total size of the message and read the rest of the message, resizing the buffer if needed.
auto [message, span] = m_channel.ReceiveMessageOrClosed<MESSAGE_HEADER>();
if (message == nullptr)
{
WSL_LOG("DnsTunnelingChannel::ReceiveLoop [Windows] - failed to read message");
return;
}
// Get the message type from the message header
switch (message->MessageType)
{
case LxGnsMessageDnsTunneling:
{
// Cast message to a LX_GNS_DNS_TUNNELING_MESSAGE struct
auto* dnsMessage = gslhelpers::try_get_struct<LX_GNS_DNS_TUNNELING_MESSAGE>(span);
if (!dnsMessage)
{
WSL_LOG(
"DnsTunnelingChannel::ReceiveLoop [Windows] - failed to convert message to LX_GNS_DNS_TUNNELING_MESSAGE");
return;
}
// Extract DNS buffer from message
auto dnsBuffer = span.subspan(offsetof(LX_GNS_DNS_TUNNELING_MESSAGE, Buffer));
WSL_LOG_DEBUG(
"DnsTunnelingChannel::ReceiveLoop [Windows] - received DNS message",
TraceLoggingValue(dnsBuffer.size(), "DNS buffer size"),
TraceLoggingValue(dnsMessage->DnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"),
TraceLoggingValue(dnsMessage->DnsClientIdentifier.DnsClientId, "DNS client id"));
// Invoke callback to notify about the new DNS request
m_reportDnsRequest(dnsBuffer, dnsMessage->DnsClientIdentifier);
break;
}
default:
{
THROW_HR_MSG(E_UNEXPECTED, "Unexpected LX_MESSAGE_TYPE : %i", message->MessageType);
}
}
}
CATCH_LOG()
}
}
void DnsTunnelingChannel::Stop() noexcept
try
{
WSL_LOG("DnsTunnelingChannel::Stop [Windows]");
m_stopEvent.SetEvent();
// Stop receive loop
if (m_receiveWorkerThread.joinable())
{
m_receiveWorkerThread.join();
}
}
CATCH_LOG()

View File

@ -1,50 +1,50 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
#pragma once
#include <wil/resource.h>
#include "lxinitshared.h"
#include "SocketChannel.h"
namespace wsl::core::networking {
using DnsTunnelingCallback = std::function<void(const gsl::span<gsl::byte>, const LX_GNS_DNS_CLIENT_IDENTIFIER&)>;
class DnsTunnelingChannel
{
public:
DnsTunnelingChannel(wil::unique_socket&& socket, DnsTunnelingCallback&& reportDnsRequest);
~DnsTunnelingChannel();
DnsTunnelingChannel(const DnsTunnelingChannel&) = delete;
DnsTunnelingChannel& operator=(const DnsTunnelingChannel&) = delete;
DnsTunnelingChannel(DnsTunnelingChannel&&) = delete;
DnsTunnelingChannel& operator=(DnsTunnelingChannel&&) = delete;
// Construct and send a LX_GNS_DNS_TUNNELING_MESSAGE message on the channel.
// Note: Callers are responsible for sequencing calls to this method.
//
// Arguments:
// dnsBuffer - buffer containing DNS response.
// dnsClientIdentifier - struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request.
void SendDnsMessage(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept;
// Stop the channel.
void Stop() noexcept;
private:
// Wait for messages on the channel from Linux side.
void ReceiveLoop() noexcept;
wil::unique_event m_stopEvent{wil::EventOptions::ManualReset};
wsl::shared::SocketChannel m_channel;
std::thread m_receiveWorkerThread;
// Callback used to notify when there is a new DNS request message on the channel.
DnsTunnelingCallback m_reportDnsRequest;
};
} // namespace wsl::core::networking
// Copyright (C) Microsoft Corporation. All rights reserved.
#pragma once
#include <wil/resource.h>
#include "lxinitshared.h"
#include "SocketChannel.h"
namespace wsl::core::networking {
using DnsTunnelingCallback = std::function<void(const gsl::span<gsl::byte>, const LX_GNS_DNS_CLIENT_IDENTIFIER&)>;
class DnsTunnelingChannel
{
public:
DnsTunnelingChannel(wil::unique_socket&& socket, DnsTunnelingCallback&& reportDnsRequest);
~DnsTunnelingChannel();
DnsTunnelingChannel(const DnsTunnelingChannel&) = delete;
DnsTunnelingChannel& operator=(const DnsTunnelingChannel&) = delete;
DnsTunnelingChannel(DnsTunnelingChannel&&) = delete;
DnsTunnelingChannel& operator=(DnsTunnelingChannel&&) = delete;
// Construct and send a LX_GNS_DNS_TUNNELING_MESSAGE message on the channel.
// Note: Callers are responsible for sequencing calls to this method.
//
// Arguments:
// dnsBuffer - buffer containing DNS response.
// dnsClientIdentifier - struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request.
void SendDnsMessage(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept;
// Stop the channel.
void Stop() noexcept;
private:
// Wait for messages on the channel from Linux side.
void ReceiveLoop() noexcept;
wil::unique_event m_stopEvent{wil::EventOptions::ManualReset};
wsl::shared::SocketChannel m_channel;
std::thread m_receiveWorkerThread;
// Callback used to notify when there is a new DNS request message on the channel.
DnsTunnelingCallback m_reportDnsRequest;
};
} // namespace wsl::core::networking

View File

@ -6,7 +6,6 @@
#include "WslCoreHostDnsInfo.h"
#include "Stringify.h"
#include "WslCoreFirewallSupport.h"
#include "WslCoreVm.h"
#include "hcs.hpp"
using namespace wsl::core::networking;
@ -672,7 +671,7 @@ wsl::windows::common::hcs::unique_hcn_network NatNetworking::CreateNetwork(wsl::
wil::ResultFromException(WI_DIAGNOSTICS_INFO, [&] {
try
{
wsl::core::networking::ConfigureHyperVFirewall(config.FirewallConfig, c_vmOwner);
wsl::core::networking::ConfigureHyperVFirewall(config.FirewallConfig, wsl::windows::common::wslutil::c_vmOwner);
natNetwork = CreateNetworkInternal(config);
}
catch (...)

View File

@ -1,154 +1,154 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
RingBuffer.cpp
Abstract:
This file contains definitions for the RingBuffer class.
--*/
#include "precomp.h"
#include "RingBuffer.h"
RingBuffer::RingBuffer(size_t size) : m_maxSize(size), m_offset(0)
{
m_buffer.reserve(size);
}
void RingBuffer::Insert(std::string_view data)
{
auto lock = m_lock.lock_exclusive();
auto remainingData = gsl::make_span(data.data(), data.size());
if (remainingData.size() > m_maxSize)
{
remainingData = remainingData.subspan(remainingData.size() - m_maxSize);
}
const auto bytesAtEnd = std::min(m_maxSize - m_offset, remainingData.size());
if (m_offset + bytesAtEnd > m_buffer.size())
{
m_buffer.resize(m_offset + bytesAtEnd);
WI_ASSERT(m_buffer.size() <= m_maxSize);
}
const auto allBuffer = gsl::make_span(m_buffer);
const auto beginCopyBuffer = allBuffer.subspan(m_offset, bytesAtEnd);
copy(remainingData.subspan(0, bytesAtEnd), beginCopyBuffer);
remainingData = remainingData.subspan(bytesAtEnd);
if (!remainingData.empty())
{
copy(remainingData, allBuffer);
m_offset = remainingData.size();
}
else
{
m_offset += bytesAtEnd;
}
}
std::vector<std::string> RingBuffer::GetLastDelimitedStrings(char Delimiter, size_t Count) const
{
auto lock = m_lock.lock_shared();
auto [begin, end] = Contents();
std::vector<std::string> results;
std::optional<size_t> endIndex;
for (size_t i = end.size(); i > 0; i--)
{
if (results.size() == Count)
{
break;
}
if (Delimiter == end[i - 1])
{
if (endIndex.has_value())
{
results.emplace(results.begin(), &end[i], endIndex.value() - i);
endIndex.reset();
}
else
{
endIndex = i - 1;
}
}
}
if (results.size() == Count)
{
return results;
}
std::string partial;
if (endIndex.has_value())
{
partial = std::string{&end[0], endIndex.value()};
endIndex.reset();
}
for (size_t i = begin.size(); i > 0; i--)
{
if (results.size() == Count)
{
break;
}
if (Delimiter == begin[i - 1])
{
if (!partial.empty())
{
// The debug CRT will fastfail if begin[size] is accessed
// But in this case it's not a problem because begin.size() - i would be == 0
std::string partial_begin{&begin.data()[i], begin.size() - i};
results.emplace(results.begin(), partial_begin + partial);
partial.clear();
}
else if (endIndex.has_value())
{
results.emplace(results.begin(), &begin.data()[i], endIndex.value() - i);
endIndex.reset();
}
else
{
endIndex = i - 1;
}
}
}
if (results.size() < Count)
{
// May have lost some data, or this could be the very first line logged.
if (!partial.empty())
{
results.emplace(results.begin(), partial);
}
else if (endIndex.has_value())
{
results.emplace(results.begin(), &begin[0], endIndex.value());
}
}
return results;
}
std::string RingBuffer::Get() const
{
auto lock = m_lock.lock_shared();
auto [begin, end] = Contents();
std::string data;
data.reserve(begin.size() + end.size());
data.append(begin.data(), begin.size());
data.append(end.data(), end.size());
return data;
}
std::pair<std::string_view, std::string_view> RingBuffer::Contents() const
{
std::string_view beginView(m_buffer.data() + m_offset, m_buffer.size() - m_offset);
std::string_view endView(m_buffer.data(), m_offset);
return {beginView, endView};
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
RingBuffer.cpp
Abstract:
This file contains definitions for the RingBuffer class.
--*/
#include "precomp.h"
#include "RingBuffer.h"
RingBuffer::RingBuffer(size_t size) : m_maxSize(size), m_offset(0)
{
m_buffer.reserve(size);
}
void RingBuffer::Insert(std::string_view data)
{
auto lock = m_lock.lock_exclusive();
auto remainingData = gsl::make_span(data.data(), data.size());
if (remainingData.size() > m_maxSize)
{
remainingData = remainingData.subspan(remainingData.size() - m_maxSize);
}
const auto bytesAtEnd = std::min(m_maxSize - m_offset, remainingData.size());
if (m_offset + bytesAtEnd > m_buffer.size())
{
m_buffer.resize(m_offset + bytesAtEnd);
WI_ASSERT(m_buffer.size() <= m_maxSize);
}
const auto allBuffer = gsl::make_span(m_buffer);
const auto beginCopyBuffer = allBuffer.subspan(m_offset, bytesAtEnd);
copy(remainingData.subspan(0, bytesAtEnd), beginCopyBuffer);
remainingData = remainingData.subspan(bytesAtEnd);
if (!remainingData.empty())
{
copy(remainingData, allBuffer);
m_offset = remainingData.size();
}
else
{
m_offset += bytesAtEnd;
}
}
std::vector<std::string> RingBuffer::GetLastDelimitedStrings(char Delimiter, size_t Count) const
{
auto lock = m_lock.lock_shared();
auto [begin, end] = Contents();
std::vector<std::string> results;
std::optional<size_t> endIndex;
for (size_t i = end.size(); i > 0; i--)
{
if (results.size() == Count)
{
break;
}
if (Delimiter == end[i - 1])
{
if (endIndex.has_value())
{
results.emplace(results.begin(), &end[i], endIndex.value() - i);
endIndex.reset();
}
else
{
endIndex = i - 1;
}
}
}
if (results.size() == Count)
{
return results;
}
std::string partial;
if (endIndex.has_value())
{
partial = std::string{&end[0], endIndex.value()};
endIndex.reset();
}
for (size_t i = begin.size(); i > 0; i--)
{
if (results.size() == Count)
{
break;
}
if (Delimiter == begin[i - 1])
{
if (!partial.empty())
{
// The debug CRT will fastfail if begin[size] is accessed
// But in this case it's not a problem because begin.size() - i would be == 0
std::string partial_begin{&begin.data()[i], begin.size() - i};
results.emplace(results.begin(), partial_begin + partial);
partial.clear();
}
else if (endIndex.has_value())
{
results.emplace(results.begin(), &begin.data()[i], endIndex.value() - i);
endIndex.reset();
}
else
{
endIndex = i - 1;
}
}
}
if (results.size() < Count)
{
// May have lost some data, or this could be the very first line logged.
if (!partial.empty())
{
results.emplace(results.begin(), partial);
}
else if (endIndex.has_value())
{
results.emplace(results.begin(), &begin[0], endIndex.value());
}
}
return results;
}
std::string RingBuffer::Get() const
{
auto lock = m_lock.lock_shared();
auto [begin, end] = Contents();
std::string data;
data.reserve(begin.size() + end.size());
data.append(begin.data(), begin.size());
data.append(end.data(), end.size());
return data;
}
std::pair<std::string_view, std::string_view> RingBuffer::Contents() const
{
std::string_view beginView(m_buffer.data() + m_offset, m_buffer.size() - m_offset);
std::string_view endView(m_buffer.data(), m_offset);
return {beginView, endView};
}

View File

@ -1,34 +1,34 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
RingBuffer.h
Abstract:
This file contains declarations for the RingBuffer class.
--*/
#pragma once
class RingBuffer
{
public:
RingBuffer() = delete;
RingBuffer(size_t size);
void Insert(std::string_view data);
std::vector<std::string> GetLastDelimitedStrings(char Delimiter, size_t Count) const;
std::string Get() const;
private:
std::pair<std::string_view, std::string_view> Contents() const;
mutable wil::srwlock m_lock;
std::vector<char> m_buffer;
size_t m_maxSize;
size_t m_offset;
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
RingBuffer.h
Abstract:
This file contains declarations for the RingBuffer class.
--*/
#pragma once
class RingBuffer
{
public:
RingBuffer() = delete;
RingBuffer(size_t size);
void Insert(std::string_view data);
std::vector<std::string> GetLastDelimitedStrings(char Delimiter, size_t Count) const;
std::string Get() const;
private:
std::pair<std::string_view, std::string_view> Contents() const;
mutable wil::srwlock m_lock;
std::vector<char> m_buffer;
size_t m_maxSize;
size_t m_offset;
};

View File

@ -1,104 +1,104 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
#pragma once
#include <mutex>
#include <string>
#include <vector>
#include <iptypes.h>
#include <wil/registry.h>
#include "WslCoreNetworkingSupport.h"
#include "RegistryWatcher.h"
namespace wsl::core::networking {
struct DnsInfo
{
std::vector<std::string> Servers;
std::vector<std::string> Domains;
};
enum class DnsSettingsFlags
{
None = 0x0,
IncludeVpn = 0x1,
IncludeIpv6Servers = 0x2,
IncludeAllSuffixes = 0x4
};
DEFINE_ENUM_FLAG_OPERATORS(DnsSettingsFlags);
inline bool operator==(const DnsInfo& lhs, const DnsInfo& rhs) noexcept
{
return lhs.Servers == rhs.Servers && lhs.Domains == rhs.Domains;
}
inline bool operator!=(const DnsInfo& lhs, const DnsInfo& rhs) noexcept
{
return !(lhs == rhs);
}
std::string GenerateResolvConf(_In_ const DnsInfo& Info);
std::vector<std::string> GetAllDnsSuffixes(const std::vector<IpAdapterAddress>& AdapterAddresses);
DWORD GetBestInterface();
class HostDnsInfo
{
public:
DnsInfo GetDnsSettings(_In_ DnsSettingsFlags Flags);
void UpdateNetworkInformation();
static DnsInfo GetDnsTunnelingSettings(const std::wstring& dnsTunnelingNameserver);
const std::vector<IpAdapterAddress>& CurrentAddresses() const
{
return m_addresses;
}
private:
/// <summary>
/// Internal function to retrieve the latest copy of interface information.
/// </summary>
std::vector<IpAdapterAddress> GetAdapterAddresses();
/// <summary>
/// Internal function to retrieve interface DNS servers.
/// </summary>
std::vector<std::string> GetInterfaceDnsServers(const std::vector<IpAdapterAddress>& AdapterAddresses, _In_ DnsSettingsFlags Flags);
/// <summary>
/// Internal function to retrieve all Windows DNS suffixes.
/// </summary>
static std::vector<std::string> GetInterfaceDnsSuffixes(const std::vector<IpAdapterAddress>& AdapterAddresses);
/// <summary>
/// Internal function to convert DNS server addresses into strings.
/// </summary>
static std::vector<std::string> GetDnsServerStrings(_In_ const PIP_ADAPTER_DNS_SERVER_ADDRESS& DnsServer, _In_ USHORT IpFamilyFilter, _In_ USHORT MaxValues);
/// <summary>
/// Stores latest copy of interface information.
/// </summary>
std::mutex m_lock;
_Guarded_by_(m_lock) std::vector<IpAdapterAddress> m_addresses;
};
using RegistryChangeCallback = std::function<void()>;
/// <summary>
/// Class used to get notifications when Windows DNS suffixes are updated in registry.
/// </summary>
class DnsSuffixRegistryWatcher
{
public:
DnsSuffixRegistryWatcher(RegistryChangeCallback&& reportRegistryChange);
~DnsSuffixRegistryWatcher() noexcept = default;
private:
RegistryChangeCallback m_reportRegistryChange;
std::vector<wistd::unique_ptr<wsl::windows::common::slim_registry_watcher>> m_registryWatchers;
};
// Copyright (C) Microsoft Corporation. All rights reserved.
#pragma once
#include <mutex>
#include <string>
#include <vector>
#include <iptypes.h>
#include <wil/registry.h>
#include "WslCoreNetworkingSupport.h"
#include "RegistryWatcher.h"
namespace wsl::core::networking {
struct DnsInfo
{
std::vector<std::string> Servers;
std::vector<std::string> Domains;
};
enum class DnsSettingsFlags
{
None = 0x0,
IncludeVpn = 0x1,
IncludeIpv6Servers = 0x2,
IncludeAllSuffixes = 0x4
};
DEFINE_ENUM_FLAG_OPERATORS(DnsSettingsFlags);
inline bool operator==(const DnsInfo& lhs, const DnsInfo& rhs) noexcept
{
return lhs.Servers == rhs.Servers && lhs.Domains == rhs.Domains;
}
inline bool operator!=(const DnsInfo& lhs, const DnsInfo& rhs) noexcept
{
return !(lhs == rhs);
}
std::string GenerateResolvConf(_In_ const DnsInfo& Info);
std::vector<std::string> GetAllDnsSuffixes(const std::vector<IpAdapterAddress>& AdapterAddresses);
DWORD GetBestInterface();
class HostDnsInfo
{
public:
DnsInfo GetDnsSettings(_In_ DnsSettingsFlags Flags);
void UpdateNetworkInformation();
static DnsInfo GetDnsTunnelingSettings(const std::wstring& dnsTunnelingNameserver);
const std::vector<IpAdapterAddress>& CurrentAddresses() const
{
return m_addresses;
}
private:
/// <summary>
/// Internal function to retrieve the latest copy of interface information.
/// </summary>
std::vector<IpAdapterAddress> GetAdapterAddresses();
/// <summary>
/// Internal function to retrieve interface DNS servers.
/// </summary>
std::vector<std::string> GetInterfaceDnsServers(const std::vector<IpAdapterAddress>& AdapterAddresses, _In_ DnsSettingsFlags Flags);
/// <summary>
/// Internal function to retrieve all Windows DNS suffixes.
/// </summary>
static std::vector<std::string> GetInterfaceDnsSuffixes(const std::vector<IpAdapterAddress>& AdapterAddresses);
/// <summary>
/// Internal function to convert DNS server addresses into strings.
/// </summary>
static std::vector<std::string> GetDnsServerStrings(_In_ const PIP_ADAPTER_DNS_SERVER_ADDRESS& DnsServer, _In_ USHORT IpFamilyFilter, _In_ USHORT MaxValues);
/// <summary>
/// Stores latest copy of interface information.
/// </summary>
std::mutex m_lock;
_Guarded_by_(m_lock) std::vector<IpAdapterAddress> m_addresses;
};
using RegistryChangeCallback = std::function<void()>;
/// <summary>
/// Class used to get notifications when Windows DNS suffixes are updated in registry.
/// </summary>
class DnsSuffixRegistryWatcher
{
public:
DnsSuffixRegistryWatcher(RegistryChangeCallback&& reportRegistryChange);
~DnsSuffixRegistryWatcher() noexcept = default;
private:
RegistryChangeCallback m_reportRegistryChange;
std::vector<wistd::unique_ptr<wsl::windows::common::slim_registry_watcher>> m_registryWatchers;
};
} // namespace wsl::core::networking

View File

@ -1,359 +1,359 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
WslCoreMessageQueue.h
Abstract:
This file contains a queuing implementation, guaranteeing running function objects
with guaranteed serialization in a threadpool thread
--*/
#pragma once
#include <deque>
#include <functional>
#include <memory>
#include <variant>
#include <windows.h>
#include <wil/resource.h>
namespace wsl::core {
// forward-declare classes that can instantiate a WslThreadPoolWaitableResult object
class WslCoreMessageQueue;
class WslBaseThreadPoolWaitableResult
{
public:
virtual ~WslBaseThreadPoolWaitableResult() noexcept = default;
private:
// limit who can run() and abort()
friend class WslCoreMessageQueue;
virtual void run() noexcept = 0;
virtual void abort() noexcept = 0;
};
template <typename TReturn>
class WslThreadPoolWaitableResult : public WslBaseThreadPoolWaitableResult
{
public:
// throws a wil exception on failure
template <typename FunctorType>
explicit WslThreadPoolWaitableResult(FunctorType&& functor) : m_function(std::forward<FunctorType>(functor))
{
}
~WslThreadPoolWaitableResult() noexcept override = default;
// returns ERROR_SUCCESS if the callback ran to completion
// returns ERROR_TIMEOUT if this wait timed out
// - this can be called multiple times if needing to probe
// any other error code resulted from attempting to run the callback
// - meaning it did *not* run to completion
DWORD wait(DWORD timeout) const noexcept
{
if (!m_completionSignal.wait(timeout))
{
// not setting m_internalError to timeout
// since the caller is allowed to try to wait() again later
return ERROR_TIMEOUT;
}
const auto lock = m_lock.lock_shared();
return m_internalError;
}
// waitable event handle, signaled when the callback has run to completion (or failed)
HANDLE notification_event() const noexcept
{
return m_completionSignal.get();
}
const TReturn& read_result() const noexcept
{
return result;
}
// move the result out of the object for move-only types
TReturn move_result() noexcept
{
TReturn move_out(std::move(result));
return move_out;
}
// non-copyable
WslThreadPoolWaitableResult(const WslThreadPoolWaitableResult&) = delete;
WslThreadPoolWaitableResult& operator=(const WslThreadPoolWaitableResult&) = delete;
private:
void run() noexcept override
{
// we are now running in the TP callback
{
const auto lock = m_lock.lock_exclusive();
if (m_runStatus != RunStatus::NotYetRun)
{
// return early - the caller has already canceled this
return;
}
m_runStatus = RunStatus::Running;
}
DWORD error = NO_ERROR;
try
{
result = std::move(m_function());
}
catch (...)
{
const HRESULT hr = wil::ResultFromCaughtException();
// HRESULT_TO_WIN32
error = (HRESULT_FACILITY(hr) == FACILITY_WIN32) ? HRESULT_CODE(hr) : hr;
}
const auto lock = m_lock.lock_exclusive();
WI_ASSERT(m_runStatus == RunStatus::Running);
m_runStatus = RunStatus::RanToCompletion;
m_internalError = error;
m_completionSignal.SetEvent();
}
void abort() noexcept override
{
const auto lock = m_lock.lock_exclusive();
// only override the error if we know we haven't started running their functor
if (m_runStatus == RunStatus::NotYetRun)
{
m_runStatus = RunStatus::Canceled;
m_internalError = ERROR_CANCELLED;
m_completionSignal.SetEvent();
}
}
std::function<TReturn(void)> m_function;
// a notification event
wil::unique_event m_completionSignal{wil::EventOptions::ManualReset};
mutable wil::srwlock m_lock;
TReturn result{};
DWORD m_internalError = NO_ERROR;
enum class RunStatus
{
NotYetRun,
Running,
RanToCompletion,
Canceled
} m_runStatus{RunStatus::NotYetRun};
};
class WslCoreMessageQueue
{
public:
WslCoreMessageQueue() : m_tpEnvironment(0, 1)
{
// create a single-threaded threadpool
m_tpHandle = m_tpEnvironment.create_tp(WorkCallback, this);
}
template <typename TReturn, typename FunctorType>
std::shared_ptr<WslThreadPoolWaitableResult<TReturn>> submit_with_results(FunctorType&& functor) noexcept
try
{
FAIL_FAST_IF(m_tpHandle.get() == nullptr);
const auto new_result = std::make_shared<WslThreadPoolWaitableResult<TReturn>>(std::forward<FunctorType>(functor));
// scope to the queue lock
{
const auto queueLock = m_lock.lock_exclusive();
THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_CANCELLED), m_isCanceled);
m_workItems.emplace_back(new_result);
}
// always maintain a 1:1 ratio for calls to SubmitWorkWithResults() and ::SubmitThreadpoolWork
SubmitThreadpoolWork(m_tpHandle.get());
return new_result;
}
catch (...)
{
LOG_CAUGHT_EXCEPTION();
return nullptr;
}
template <typename FunctorType>
bool submit(FunctorType&& functor) noexcept
try
{
FAIL_FAST_IF(m_tpHandle.get() == nullptr);
// scope to the queue lock
{
const auto queueLock = m_lock.lock_exclusive();
THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_CANCELLED), m_isCanceled);
m_workItems.emplace_back(std::forward<SimpleFunction_t>(functor));
}
// always maintain a 1:1 ratio for calls to SubmitWork() and ::SubmitThreadpoolWork
SubmitThreadpoolWork(m_tpHandle.get());
return true;
}
catch (...)
{
LOG_CAUGHT_EXCEPTION();
return false;
}
// functors must return type HRESULT
template <typename FunctorType>
HRESULT submit_and_wait(FunctorType&& functor) noexcept
try
{
HRESULT hr = HRESULT_FROM_WIN32(ERROR_OUTOFMEMORY);
if (const auto waitableResult = submit_with_results<HRESULT>(std::forward<FunctorType>(functor)))
{
hr = HRESULT_FROM_WIN32(waitableResult->wait(INFINITE));
if (SUCCEEDED(hr))
{
hr = waitableResult->read_result();
}
}
return hr;
}
CATCH_RETURN()
// cancels anything queued to the TP - this WslCoreMessageQueue instance can no longer be used
void cancel() noexcept
try
{
if (m_tpHandle)
{
// immediately release anyone waiting for these workitems not yet run
{
const auto queueLock = m_lock.lock_exclusive();
m_isCanceled = true;
for (const auto& work : m_workItems)
{
// signal that these are canceled before we shutdown the TP which they could be scheduled
if (const auto* pWaitableWorkitem = std::get_if<WaitableFunction_t>(&work))
{
(*pWaitableWorkitem)->abort();
}
}
m_workItems.clear();
}
// force the m_tpHandle to wait and close the TP
m_tpHandle.reset();
m_tpEnvironment.reset();
}
}
CATCH_LOG()
bool isRunningInQueue() const noexcept
{
const auto currentThreadId = GetThreadId(GetCurrentThread());
return currentThreadId == static_cast<DWORD>(InterlockedCompareExchange64(&m_threadpoolThreadId, 0ll, 0ll));
}
~WslCoreMessageQueue() noexcept
{
cancel();
}
WslCoreMessageQueue(const WslCoreMessageQueue&) = delete;
WslCoreMessageQueue& operator=(const WslCoreMessageQueue&) = delete;
WslCoreMessageQueue(WslCoreMessageQueue&&) = delete;
WslCoreMessageQueue& operator=(WslCoreMessageQueue&&) = delete;
private:
struct TPEnvironment
{
using unique_tp_env = wil::unique_struct<TP_CALLBACK_ENVIRON, decltype(&DestroyThreadpoolEnvironment), DestroyThreadpoolEnvironment>;
unique_tp_env m_tpEnvironment;
using unique_tp_pool = wil::unique_any<PTP_POOL, decltype(&CloseThreadpool), CloseThreadpool>;
unique_tp_pool m_threadPool;
TPEnvironment(DWORD countMinThread, DWORD countMaxThread)
{
InitializeThreadpoolEnvironment(&m_tpEnvironment);
m_threadPool.reset(CreateThreadpool(nullptr));
THROW_LAST_ERROR_IF_NULL(m_threadPool.get());
// Set min and max thread counts for custom thread pool
THROW_LAST_ERROR_IF(!::SetThreadpoolThreadMinimum(m_threadPool.get(), countMinThread));
SetThreadpoolThreadMaximum(m_threadPool.get(), countMaxThread);
SetThreadpoolCallbackPool(&m_tpEnvironment, m_threadPool.get());
}
wil::unique_threadpool_work create_tp(PTP_WORK_CALLBACK callback, void* pv)
{
wil::unique_threadpool_work newThreadpool(CreateThreadpoolWork(callback, pv, (m_threadPool) ? &m_tpEnvironment : nullptr));
THROW_LAST_ERROR_IF_NULL(newThreadpool.get());
return newThreadpool;
}
void reset()
{
m_threadPool.reset();
m_tpEnvironment.reset();
}
};
using SimpleFunction_t = std::function<void()>;
using WaitableFunction_t = std::shared_ptr<WslBaseThreadPoolWaitableResult>;
using FunctionVariant_t = std::variant<SimpleFunction_t, WaitableFunction_t>;
// the lock must be destroyed *after* the TP object (thus must be declared first)
// since the lock is used in the TP callback
// the lock is mutable to allow us to acquire the lock in const methods
mutable wil::srwlock m_lock;
TPEnvironment m_tpEnvironment;
wil::unique_threadpool_work m_tpHandle;
std::deque<FunctionVariant_t> m_workItems;
mutable LONG64 m_threadpoolThreadId{0}; // useful for callers to assert they are running within the queue
bool m_isCanceled{false};
static void CALLBACK WorkCallback(PTP_CALLBACK_INSTANCE, void* Context, PTP_WORK) noexcept
try
{
auto* pThis = static_cast<WslCoreMessageQueue*>(Context);
FunctionVariant_t work;
{
const auto queueLock = pThis->m_lock.lock_exclusive();
if (pThis->m_workItems.empty())
{
// pThis object is being destroyed and the queue was cleared
return;
}
std::swap(work, pThis->m_workItems.front());
pThis->m_workItems.pop_front();
InterlockedExchange64(&pThis->m_threadpoolThreadId, GetThreadId(GetCurrentThread()));
}
// run the tasks outside the WslCoreMessageQueue lock
const auto resetThreadIdOnExit = wil::scope_exit([pThis] { InterlockedExchange64(&pThis->m_threadpoolThreadId, 0ll); });
if (work.index() == 0)
{
const auto& workItem = std::get<SimpleFunction_t>(work);
workItem();
}
else
{
const auto& waitableWorkItem = std::get<WaitableFunction_t>(work);
waitableWorkItem->run();
}
}
CATCH_LOG()
};
} // namespace wsl::core
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
WslCoreMessageQueue.h
Abstract:
This file contains a queuing implementation, guaranteeing running function objects
with guaranteed serialization in a threadpool thread
--*/
#pragma once
#include <deque>
#include <functional>
#include <memory>
#include <variant>
#include <windows.h>
#include <wil/resource.h>
namespace wsl::core {
// forward-declare classes that can instantiate a WslThreadPoolWaitableResult object
class WslCoreMessageQueue;
class WslBaseThreadPoolWaitableResult
{
public:
virtual ~WslBaseThreadPoolWaitableResult() noexcept = default;
private:
// limit who can run() and abort()
friend class WslCoreMessageQueue;
virtual void run() noexcept = 0;
virtual void abort() noexcept = 0;
};
template <typename TReturn>
class WslThreadPoolWaitableResult : public WslBaseThreadPoolWaitableResult
{
public:
// throws a wil exception on failure
template <typename FunctorType>
explicit WslThreadPoolWaitableResult(FunctorType&& functor) : m_function(std::forward<FunctorType>(functor))
{
}
~WslThreadPoolWaitableResult() noexcept override = default;
// returns ERROR_SUCCESS if the callback ran to completion
// returns ERROR_TIMEOUT if this wait timed out
// - this can be called multiple times if needing to probe
// any other error code resulted from attempting to run the callback
// - meaning it did *not* run to completion
DWORD wait(DWORD timeout) const noexcept
{
if (!m_completionSignal.wait(timeout))
{
// not setting m_internalError to timeout
// since the caller is allowed to try to wait() again later
return ERROR_TIMEOUT;
}
const auto lock = m_lock.lock_shared();
return m_internalError;
}
// waitable event handle, signaled when the callback has run to completion (or failed)
HANDLE notification_event() const noexcept
{
return m_completionSignal.get();
}
const TReturn& read_result() const noexcept
{
return result;
}
// move the result out of the object for move-only types
TReturn move_result() noexcept
{
TReturn move_out(std::move(result));
return move_out;
}
// non-copyable
WslThreadPoolWaitableResult(const WslThreadPoolWaitableResult&) = delete;
WslThreadPoolWaitableResult& operator=(const WslThreadPoolWaitableResult&) = delete;
private:
void run() noexcept override
{
// we are now running in the TP callback
{
const auto lock = m_lock.lock_exclusive();
if (m_runStatus != RunStatus::NotYetRun)
{
// return early - the caller has already canceled this
return;
}
m_runStatus = RunStatus::Running;
}
DWORD error = NO_ERROR;
try
{
result = std::move(m_function());
}
catch (...)
{
const HRESULT hr = wil::ResultFromCaughtException();
// HRESULT_TO_WIN32
error = (HRESULT_FACILITY(hr) == FACILITY_WIN32) ? HRESULT_CODE(hr) : hr;
}
const auto lock = m_lock.lock_exclusive();
WI_ASSERT(m_runStatus == RunStatus::Running);
m_runStatus = RunStatus::RanToCompletion;
m_internalError = error;
m_completionSignal.SetEvent();
}
void abort() noexcept override
{
const auto lock = m_lock.lock_exclusive();
// only override the error if we know we haven't started running their functor
if (m_runStatus == RunStatus::NotYetRun)
{
m_runStatus = RunStatus::Canceled;
m_internalError = ERROR_CANCELLED;
m_completionSignal.SetEvent();
}
}
std::function<TReturn(void)> m_function;
// a notification event
wil::unique_event m_completionSignal{wil::EventOptions::ManualReset};
mutable wil::srwlock m_lock;
TReturn result{};
DWORD m_internalError = NO_ERROR;
enum class RunStatus
{
NotYetRun,
Running,
RanToCompletion,
Canceled
} m_runStatus{RunStatus::NotYetRun};
};
class WslCoreMessageQueue
{
public:
WslCoreMessageQueue() : m_tpEnvironment(0, 1)
{
// create a single-threaded threadpool
m_tpHandle = m_tpEnvironment.create_tp(WorkCallback, this);
}
template <typename TReturn, typename FunctorType>
std::shared_ptr<WslThreadPoolWaitableResult<TReturn>> submit_with_results(FunctorType&& functor) noexcept
try
{
FAIL_FAST_IF(m_tpHandle.get() == nullptr);
const auto new_result = std::make_shared<WslThreadPoolWaitableResult<TReturn>>(std::forward<FunctorType>(functor));
// scope to the queue lock
{
const auto queueLock = m_lock.lock_exclusive();
THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_CANCELLED), m_isCanceled);
m_workItems.emplace_back(new_result);
}
// always maintain a 1:1 ratio for calls to SubmitWorkWithResults() and ::SubmitThreadpoolWork
SubmitThreadpoolWork(m_tpHandle.get());
return new_result;
}
catch (...)
{
LOG_CAUGHT_EXCEPTION();
return nullptr;
}
template <typename FunctorType>
bool submit(FunctorType&& functor) noexcept
try
{
FAIL_FAST_IF(m_tpHandle.get() == nullptr);
// scope to the queue lock
{
const auto queueLock = m_lock.lock_exclusive();
THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_CANCELLED), m_isCanceled);
m_workItems.emplace_back(std::forward<SimpleFunction_t>(functor));
}
// always maintain a 1:1 ratio for calls to SubmitWork() and ::SubmitThreadpoolWork
SubmitThreadpoolWork(m_tpHandle.get());
return true;
}
catch (...)
{
LOG_CAUGHT_EXCEPTION();
return false;
}
// functors must return type HRESULT
template <typename FunctorType>
HRESULT submit_and_wait(FunctorType&& functor) noexcept
try
{
HRESULT hr = HRESULT_FROM_WIN32(ERROR_OUTOFMEMORY);
if (const auto waitableResult = submit_with_results<HRESULT>(std::forward<FunctorType>(functor)))
{
hr = HRESULT_FROM_WIN32(waitableResult->wait(INFINITE));
if (SUCCEEDED(hr))
{
hr = waitableResult->read_result();
}
}
return hr;
}
CATCH_RETURN()
// cancels anything queued to the TP - this WslCoreMessageQueue instance can no longer be used
void cancel() noexcept
try
{
if (m_tpHandle)
{
// immediately release anyone waiting for these workitems not yet run
{
const auto queueLock = m_lock.lock_exclusive();
m_isCanceled = true;
for (const auto& work : m_workItems)
{
// signal that these are canceled before we shutdown the TP which they could be scheduled
if (const auto* pWaitableWorkitem = std::get_if<WaitableFunction_t>(&work))
{
(*pWaitableWorkitem)->abort();
}
}
m_workItems.clear();
}
// force the m_tpHandle to wait and close the TP
m_tpHandle.reset();
m_tpEnvironment.reset();
}
}
CATCH_LOG()
bool isRunningInQueue() const noexcept
{
const auto currentThreadId = GetThreadId(GetCurrentThread());
return currentThreadId == static_cast<DWORD>(InterlockedCompareExchange64(&m_threadpoolThreadId, 0ll, 0ll));
}
~WslCoreMessageQueue() noexcept
{
cancel();
}
WslCoreMessageQueue(const WslCoreMessageQueue&) = delete;
WslCoreMessageQueue& operator=(const WslCoreMessageQueue&) = delete;
WslCoreMessageQueue(WslCoreMessageQueue&&) = delete;
WslCoreMessageQueue& operator=(WslCoreMessageQueue&&) = delete;
private:
struct TPEnvironment
{
using unique_tp_env = wil::unique_struct<TP_CALLBACK_ENVIRON, decltype(&DestroyThreadpoolEnvironment), DestroyThreadpoolEnvironment>;
unique_tp_env m_tpEnvironment;
using unique_tp_pool = wil::unique_any<PTP_POOL, decltype(&CloseThreadpool), CloseThreadpool>;
unique_tp_pool m_threadPool;
TPEnvironment(DWORD countMinThread, DWORD countMaxThread)
{
InitializeThreadpoolEnvironment(&m_tpEnvironment);
m_threadPool.reset(CreateThreadpool(nullptr));
THROW_LAST_ERROR_IF_NULL(m_threadPool.get());
// Set min and max thread counts for custom thread pool
THROW_LAST_ERROR_IF(!::SetThreadpoolThreadMinimum(m_threadPool.get(), countMinThread));
SetThreadpoolThreadMaximum(m_threadPool.get(), countMaxThread);
SetThreadpoolCallbackPool(&m_tpEnvironment, m_threadPool.get());
}
wil::unique_threadpool_work create_tp(PTP_WORK_CALLBACK callback, void* pv)
{
wil::unique_threadpool_work newThreadpool(CreateThreadpoolWork(callback, pv, (m_threadPool) ? &m_tpEnvironment : nullptr));
THROW_LAST_ERROR_IF_NULL(newThreadpool.get());
return newThreadpool;
}
void reset()
{
m_threadPool.reset();
m_tpEnvironment.reset();
}
};
using SimpleFunction_t = std::function<void()>;
using WaitableFunction_t = std::shared_ptr<WslBaseThreadPoolWaitableResult>;
using FunctionVariant_t = std::variant<SimpleFunction_t, WaitableFunction_t>;
// the lock must be destroyed *after* the TP object (thus must be declared first)
// since the lock is used in the TP callback
// the lock is mutable to allow us to acquire the lock in const methods
mutable wil::srwlock m_lock;
TPEnvironment m_tpEnvironment;
wil::unique_threadpool_work m_tpHandle;
std::deque<FunctionVariant_t> m_workItems;
mutable LONG64 m_threadpoolThreadId{0}; // useful for callers to assert they are running within the queue
bool m_isCanceled{false};
static void CALLBACK WorkCallback(PTP_CALLBACK_INSTANCE, void* Context, PTP_WORK) noexcept
try
{
auto* pThis = static_cast<WslCoreMessageQueue*>(Context);
FunctionVariant_t work;
{
const auto queueLock = pThis->m_lock.lock_exclusive();
if (pThis->m_workItems.empty())
{
// pThis object is being destroyed and the queue was cleared
return;
}
std::swap(work, pThis->m_workItems.front());
pThis->m_workItems.pop_front();
InterlockedExchange64(&pThis->m_threadpoolThreadId, GetThreadId(GetCurrentThread()));
}
// run the tasks outside the WslCoreMessageQueue lock
const auto resetThreadIdOnExit = wil::scope_exit([pThis] { InterlockedExchange64(&pThis->m_threadpoolThreadId, 0ll); });
if (work.index() == 0)
{
const auto& workItem = std::get<SimpleFunction_t>(work);
workItem();
}
else
{
const auto& waitableWorkItem = std::get<WaitableFunction_t>(work);
waitableWorkItem->run();
}
}
CATCH_LOG()
};
} // namespace wsl::core

View File

@ -1,110 +1,110 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
#include "precomp.h"
#include "hns_schema.h"
#include "WslCoreNetworkEndpointSettings.h"
#include "WslCoreHostDnsInfo.h"
using namespace wsl::shared;
std::shared_ptr<wsl::core::networking::NetworkSettings> wsl::core::networking::GetEndpointSettings(const hns::HNSEndpoint& properties)
{
EndpointIpAddress address{};
address.Address = windows::common::string::StringToSockAddrInet(properties.IPAddress);
address.AddressString = properties.IPAddress;
address.PrefixLength = properties.PrefixLength;
EndpointRoute route{};
route.DestinationPrefix.PrefixLength = 0;
IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4);
route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS;
route.NextHop = windows::common::string::StringToSockAddrInet(properties.GatewayAddress);
route.NextHopString = properties.GatewayAddress;
return std::make_shared<wsl::core::networking::NetworkSettings>(
properties.InterfaceConstraint.InterfaceGuid,
address,
route,
properties.MacAddress,
L"unuseddevicename",
properties.InterfaceConstraint.InterfaceIndex,
properties.InterfaceConstraint.InterfaceMediaType,
properties.DNSServerList);
}
std::shared_ptr<wsl::core::networking::NetworkSettings> wsl::core::networking::GetHostEndpointSettings()
{
HostDnsInfo dnsInfo;
dnsInfo.UpdateNetworkInformation();
auto addresses = dnsInfo.CurrentAddresses();
auto bestIndex = GetBestInterface();
auto bestInterfacePtr =
std::find_if(addresses.cbegin(), addresses.cend(), [&](const auto& address) { return address->IfIndex == bestIndex; });
if (bestInterfacePtr == addresses.end())
{
return std::make_shared<NetworkSettings>();
}
const auto& bestInterface = *bestInterfacePtr;
std::wstring macAddress = wsl::shared::string::FormatMacAddress(
wsl::shared::string::MacAddress{
bestInterface->PhysicalAddress[0],
bestInterface->PhysicalAddress[1],
bestInterface->PhysicalAddress[2],
bestInterface->PhysicalAddress[3],
bestInterface->PhysicalAddress[4],
bestInterface->PhysicalAddress[5]},
L'-');
EndpointIpAddress address{};
auto firstIpv4Address = bestInterface->FirstUnicastAddress;
while (firstIpv4Address && firstIpv4Address->Address.lpSockaddr->sa_family != AF_INET)
{
firstIpv4Address = firstIpv4Address->Next;
}
if (firstIpv4Address)
{
address.Address = *reinterpret_cast<SOCKADDR_INET*>(firstIpv4Address->Address.lpSockaddr);
address.AddressString = windows::common::string::SockAddrInetToWstring(address.Address);
address.PrefixLength = firstIpv4Address->OnLinkPrefixLength;
}
EndpointRoute route{};
PIP_ADAPTER_GATEWAY_ADDRESS nextGatewayAddress = bestInterface->FirstGatewayAddress;
while (nextGatewayAddress && nextGatewayAddress->Address.lpSockaddr->sa_family != AF_INET)
{
nextGatewayAddress = nextGatewayAddress->Next;
}
if (nextGatewayAddress)
{
route.DestinationPrefix.PrefixLength = 0;
IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4);
route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS;
route.NextHop = *reinterpret_cast<SOCKADDR_INET*>(nextGatewayAddress->Address.lpSockaddr);
route.NextHopString = windows::common::string::SockAddrInetToWstring(route.NextHop);
}
else if (address.Address.si_family == AF_INET)
{
IN_ADDR default_route{};
default_route.s_addr = htonl((ntohl(address.Address.Ipv4.sin_addr.s_addr) & ~((1 << (32 - address.PrefixLength)) - 1)) | 1);
route.DestinationPrefix.PrefixLength = 0;
IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4);
route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS;
IN4ADDR_SETSOCKADDR(&route.NextHop.Ipv4, &default_route, 0);
route.NextHopString = windows::common::string::SockAddrInetToWstring(route.NextHop);
}
std::wstring dnsServerList;
for (const auto& serverAddress : dnsInfo.GetDnsSettings(DnsSettingsFlags::IncludeVpn).Servers)
{
if (!dnsServerList.empty())
{
dnsServerList += L",";
}
dnsServerList += wsl::shared::string::MultiByteToWide(serverAddress);
}
return std::shared_ptr<NetworkSettings>(new NetworkSettings(
bestInterface->NetworkGuid, address, route, macAddress, {}, bestInterface->IfIndex, bestInterface->IfType, dnsServerList));
}
// Copyright (C) Microsoft Corporation. All rights reserved.
#include "precomp.h"
#include "hns_schema.h"
#include "WslCoreNetworkEndpointSettings.h"
#include "WslCoreHostDnsInfo.h"
using namespace wsl::shared;
std::shared_ptr<wsl::core::networking::NetworkSettings> wsl::core::networking::GetEndpointSettings(const hns::HNSEndpoint& properties)
{
EndpointIpAddress address{};
address.Address = windows::common::string::StringToSockAddrInet(properties.IPAddress);
address.AddressString = properties.IPAddress;
address.PrefixLength = properties.PrefixLength;
EndpointRoute route{};
route.DestinationPrefix.PrefixLength = 0;
IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4);
route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS;
route.NextHop = windows::common::string::StringToSockAddrInet(properties.GatewayAddress);
route.NextHopString = properties.GatewayAddress;
return std::make_shared<wsl::core::networking::NetworkSettings>(
properties.InterfaceConstraint.InterfaceGuid,
address,
route,
properties.MacAddress,
L"unuseddevicename",
properties.InterfaceConstraint.InterfaceIndex,
properties.InterfaceConstraint.InterfaceMediaType,
properties.DNSServerList);
}
std::shared_ptr<wsl::core::networking::NetworkSettings> wsl::core::networking::GetHostEndpointSettings()
{
HostDnsInfo dnsInfo;
dnsInfo.UpdateNetworkInformation();
auto addresses = dnsInfo.CurrentAddresses();
auto bestIndex = GetBestInterface();
auto bestInterfacePtr =
std::find_if(addresses.cbegin(), addresses.cend(), [&](const auto& address) { return address->IfIndex == bestIndex; });
if (bestInterfacePtr == addresses.end())
{
return std::make_shared<NetworkSettings>();
}
const auto& bestInterface = *bestInterfacePtr;
std::wstring macAddress = wsl::shared::string::FormatMacAddress(
wsl::shared::string::MacAddress{
bestInterface->PhysicalAddress[0],
bestInterface->PhysicalAddress[1],
bestInterface->PhysicalAddress[2],
bestInterface->PhysicalAddress[3],
bestInterface->PhysicalAddress[4],
bestInterface->PhysicalAddress[5]},
L'-');
EndpointIpAddress address{};
auto firstIpv4Address = bestInterface->FirstUnicastAddress;
while (firstIpv4Address && firstIpv4Address->Address.lpSockaddr->sa_family != AF_INET)
{
firstIpv4Address = firstIpv4Address->Next;
}
if (firstIpv4Address)
{
address.Address = *reinterpret_cast<SOCKADDR_INET*>(firstIpv4Address->Address.lpSockaddr);
address.AddressString = windows::common::string::SockAddrInetToWstring(address.Address);
address.PrefixLength = firstIpv4Address->OnLinkPrefixLength;
}
EndpointRoute route{};
PIP_ADAPTER_GATEWAY_ADDRESS nextGatewayAddress = bestInterface->FirstGatewayAddress;
while (nextGatewayAddress && nextGatewayAddress->Address.lpSockaddr->sa_family != AF_INET)
{
nextGatewayAddress = nextGatewayAddress->Next;
}
if (nextGatewayAddress)
{
route.DestinationPrefix.PrefixLength = 0;
IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4);
route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS;
route.NextHop = *reinterpret_cast<SOCKADDR_INET*>(nextGatewayAddress->Address.lpSockaddr);
route.NextHopString = windows::common::string::SockAddrInetToWstring(route.NextHop);
}
else if (address.Address.si_family == AF_INET)
{
IN_ADDR default_route{};
default_route.s_addr = htonl((ntohl(address.Address.Ipv4.sin_addr.s_addr) & ~((1 << (32 - address.PrefixLength)) - 1)) | 1);
route.DestinationPrefix.PrefixLength = 0;
IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4);
route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS;
IN4ADDR_SETSOCKADDR(&route.NextHop.Ipv4, &default_route, 0);
route.NextHopString = windows::common::string::SockAddrInetToWstring(route.NextHop);
}
std::wstring dnsServerList;
for (const auto& serverAddress : dnsInfo.GetDnsSettings(DnsSettingsFlags::IncludeVpn).Servers)
{
if (!dnsServerList.empty())
{
dnsServerList += L",";
}
dnsServerList += wsl::shared::string::MultiByteToWide(serverAddress);
}
return std::shared_ptr<NetworkSettings>(new NetworkSettings(
bestInterface->NetworkGuid, address, route, macAddress, {}, bestInterface->IfIndex, bestInterface->IfType, dnsServerList));
}

View File

@ -1,398 +1,398 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
#pragma once
#include <algorithm>
#include <set>
#include <string>
#include <windows.h>
#include <mstcpip.h>
#include <ws2ipdef.h>
#include <netioapi.h>
#include "hcs.hpp"
#include "lxinitshared.h"
#include "Stringify.h"
#include "stringshared.h"
#include "WslCoreNetworkingSupport.h"
#include "hns_schema.h"
namespace wsl::core::networking {
constexpr auto AddEndpointRetryPeriod = std::chrono::milliseconds(100);
constexpr auto AddEndpointRetryTimeout = std::chrono::seconds(3);
constexpr auto AddEndpointRetryPredicate = [] {
// Don't retry if ModifyComputeSystem fails with:
// HCN_E_ENDPOINT_NOT_FOUND - indicates that the underlying network object was deleted.
// HCN_E_ENDPOINT_ALREADY_ATTACHED - occurs when HNS was restarted before the endpoints were removed.
// VM_E_INVALID_STATE - occurs when the VM has been terminated.
const auto result = wil::ResultFromCaughtException();
return result != HCN_E_ENDPOINT_NOT_FOUND && result != HCN_E_ENDPOINT_ALREADY_ATTACHED && result != VM_E_INVALID_STATE;
};
struct EndpointIpAddress
{
SOCKADDR_INET Address{};
std::wstring AddressString{};
unsigned char PrefixLength = 0;
unsigned int PrefixOrigin = 0;
unsigned int SuffixOrigin = 0;
// The following field can be changed from a const iterator in SyncIpStateWithLinux - that's why it's marked mutable.
mutable unsigned int PreferredLifetime = 0;
EndpointIpAddress() = default;
~EndpointIpAddress() noexcept = default;
EndpointIpAddress(EndpointIpAddress&&) = default;
EndpointIpAddress& operator=(EndpointIpAddress&&) = default;
EndpointIpAddress(const EndpointIpAddress&) = default;
EndpointIpAddress& operator=(const EndpointIpAddress&) = default;
explicit EndpointIpAddress(const MIB_UNICASTIPADDRESS_ROW& AddressRow) :
Address(AddressRow.Address),
AddressString(windows::common::string::SockAddrInetToWstring(AddressRow.Address)),
PrefixLength(AddressRow.OnLinkPrefixLength),
PrefixOrigin(AddressRow.PrefixOrigin),
SuffixOrigin(AddressRow.SuffixOrigin),
// We treat the preferred lifetime field as effective DAD state - 0 is not preferred, anything else is preferred.
// We do this for convenience, as we can't directly set the DAD state of an address into the guest, but we
// we can set an address's preferred lifetime (in Linux, at least).
PreferredLifetime(AddressRow.DadState == IpDadStatePreferred ? 0xFFFFFFFF : 0)
{
}
// operator== is deliberately not comparing PreferredLifetime (DAD state) for equality - only the address portion
bool operator==(const EndpointIpAddress& rhs) const noexcept
{
return Address == rhs.Address && PrefixLength == rhs.PrefixLength;
}
bool operator<(const EndpointIpAddress& rhs) const noexcept
{
if (Address == rhs.Address)
{
return PrefixLength < rhs.PrefixLength;
}
return Address < rhs.Address;
}
void Clear() noexcept
{
Address = {};
AddressString.clear();
PrefixLength = 0;
PrefixOrigin = 0;
SuffixOrigin = 0;
}
std::wstring GetPrefix() const
{
SOCKADDR_INET address{Address};
unsigned char* addressPointer{nullptr};
if (Address.si_family == AF_INET)
{
addressPointer = reinterpret_cast<unsigned char*>(&address.Ipv4.sin_addr);
}
else if (Address.si_family == AF_INET6)
{
addressPointer = address.Ipv6.sin6_addr.u.Byte;
}
else
{
return L"";
}
constexpr int c_numBitsPerByte = 8;
for (int i = 0, currPrefixLength = PrefixLength; i < INET_ADDR_LENGTH(Address.si_family); i++, currPrefixLength -= c_numBitsPerByte)
{
if (currPrefixLength < c_numBitsPerByte)
{
const int bitShiftAmt = c_numBitsPerByte - std::max(currPrefixLength, 0);
addressPointer[i] &= (0xFF >> bitShiftAmt) << bitShiftAmt;
}
}
const auto addressString = windows::common::string::SockAddrInetToWstring(address);
WI_ASSERT(!addressString.empty());
if (addressString.empty())
{
// just return an empty string if we have a bad address
return addressString;
}
return std::format(L"{}/{}", addressString, PrefixLength);
}
std::wstring GetIpv4BroadcastMask() const
{
// start with all bits set, then shift off the prefix
ULONG prefixMask{0xffffffff};
prefixMask <<= PrefixLength;
prefixMask >>= PrefixLength;
SOCKADDR_INET address{Address};
// flip to host-order, then apply the mask
ULONG hostOrder = ntohl(address.Ipv4.sin_addr.S_un.S_addr);
hostOrder |= prefixMask;
address.Ipv4.sin_addr.S_un.S_addr = htonl(hostOrder);
return windows::common::string::SockAddrInetToWstring(address);
}
bool IsPreferred() const noexcept
{
return PreferredLifetime > 0;
}
bool IsLinkLocal() const
{
return (Address.si_family == AF_INET && IN4_IS_ADDR_LINKLOCAL(&Address.Ipv4.sin_addr)) ||
(Address.si_family == AF_INET6 && IN6_IS_ADDR_LINKLOCAL(&Address.Ipv6.sin6_addr));
}
};
struct EndpointRoute
{
ADDRESS_FAMILY Family = AF_INET;
IP_ADDRESS_PREFIX DestinationPrefix{};
std::wstring DestinationPrefixString{};
SOCKADDR_INET NextHop{};
std::wstring NextHopString{};
unsigned char SitePrefixLength = 0;
unsigned int Metric = 0;
bool IsAutoGeneratedPrefixRoute = false;
EndpointRoute() = default;
~EndpointRoute() noexcept = default;
EndpointRoute(EndpointRoute&&) = default;
EndpointRoute& operator=(EndpointRoute&&) = default;
EndpointRoute(const EndpointRoute&) = default;
EndpointRoute& operator=(const EndpointRoute&) = default;
EndpointRoute(const MIB_IPFORWARD_ROW2& RouteRow) :
Family(RouteRow.NextHop.si_family),
DestinationPrefix(RouteRow.DestinationPrefix),
DestinationPrefixString(windows::common::string::SockAddrInetToWstring(RouteRow.DestinationPrefix.Prefix)),
NextHop(RouteRow.NextHop),
NextHopString(windows::common::string::SockAddrInetToWstring(RouteRow.NextHop)),
SitePrefixLength(RouteRow.SitePrefixLength),
Metric(RouteRow.Metric)
{
}
unsigned char GetMaxPrefixLength() const
{
return (Family == AF_INET) ? 32 : 128;
}
std::wstring GetFullDestinationPrefix() const
{
return std::format(L"{}/{}", DestinationPrefixString, static_cast<unsigned int>(DestinationPrefix.PrefixLength));
}
bool IsNextHopOnlink() const noexcept
{
return (Family == AF_INET && NextHopString == LX_INIT_UNSPECIFIED_ADDRESS) ||
(Family == AF_INET6 && NextHopString == LX_INIT_UNSPECIFIED_V6_ADDRESS);
}
bool IsDefault() const noexcept
{
return (Family == AF_INET && DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS) ||
(Family == AF_INET6 && DestinationPrefixString == LX_INIT_UNSPECIFIED_V6_ADDRESS);
}
bool IsUnicastAddressRoute() const noexcept
{
return (Family == AF_INET && DestinationPrefix.PrefixLength == 32) || (Family == AF_INET6 && DestinationPrefix.PrefixLength == 128);
}
std::wstring ToString() const
{
return std::format(L"{}=>{} [metric {}]", GetFullDestinationPrefix(), NextHopString, Metric);
}
bool operator==(const EndpointRoute& rhs) const noexcept
{
return Family == rhs.Family && DestinationPrefix.PrefixLength == rhs.DestinationPrefix.PrefixLength &&
DestinationPrefix.Prefix == rhs.DestinationPrefix.Prefix && NextHop == rhs.NextHop &&
SitePrefixLength == rhs.SitePrefixLength && Metric == rhs.Metric;
}
bool operator!=(const EndpointRoute& other) const
{
return !(*this == other);
}
// sort by family, then by next-hop (on-link routes first), then by prefix, then by metric
bool operator<(const EndpointRoute& rhs) const noexcept
{
if (Family == rhs.Family)
{
if (NextHop == rhs.NextHop)
{
if (DestinationPrefix.Prefix == rhs.DestinationPrefix.Prefix)
{
if (DestinationPrefix.PrefixLength == rhs.DestinationPrefix.PrefixLength)
{
if (Metric == rhs.Metric)
{
return SitePrefixLength < rhs.SitePrefixLength;
}
return Metric < rhs.Metric;
}
return DestinationPrefix.PrefixLength < rhs.DestinationPrefix.PrefixLength;
}
return DestinationPrefix.Prefix < rhs.DestinationPrefix.Prefix;
}
return NextHop < rhs.NextHop;
}
return Family < rhs.Family;
}
};
struct NetworkSettings
{
NetworkSettings() = default;
NetworkSettings(
const GUID& interfaceGuid,
EndpointIpAddress preferredIpAddress,
EndpointRoute gateway,
std::wstring macAddress,
std::wstring deviceName,
uint32_t interfaceIndex,
uint32_t mediaType,
const std::wstring& dnsServerList) :
InterfaceGuid(interfaceGuid),
PreferredIpAddress(std::move(preferredIpAddress)),
MacAddress(std::move(macAddress)),
DeviceName(std::move(deviceName)),
InterfaceIndex(interfaceIndex),
InterfaceType(mediaType)
{
Routes.emplace(std::move(gateway));
DnsServers = wsl::shared::string::Split(dnsServerList, L',');
}
GUID InterfaceGuid{};
EndpointIpAddress PreferredIpAddress{};
std::set<EndpointIpAddress> IpAddresses{}; // Does not include PreferredIpAddress.
std::set<EndpointRoute> Routes{};
std::vector<std::wstring> DnsServers{};
std::wstring MacAddress;
std::wstring DeviceName;
IF_INDEX InterfaceIndex = 0;
IFTYPE InterfaceType = 0;
ULONG IPv4InterfaceMtu = 0;
ULONG IPv6InterfaceMtu = 0;
// some interfaces will only have an IPv4 or IPv6 interface
std::optional<ULONG> IPv4InterfaceMetric = 0;
std::optional<ULONG> IPv6InterfaceMetric = 0;
bool IsHidden = false;
bool IsConnected = false;
bool IsMetered = false;
bool DisableIpv4DefaultRoutes = false;
bool DisableIpv6DefaultRoutes = false;
bool PendingUpdateToReconnectForMetered = false;
bool PendingIPInterfaceUpdate = false;
auto operator<=>(const NetworkSettings&) const = default;
std::wstring GetBestGatewayAddressString() const
{
// Best is currently defined as simply the first IPv4 gateway.
for (const auto& route : Routes)
{
if (route.Family == AF_INET && route.DestinationPrefix.PrefixLength == 0 && route.DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS)
{
return route.NextHopString;
}
}
return {};
}
SOCKADDR_INET GetBestGatewayAddress() const
{
// Best is currently defined as simply the first IPv4 gateway.
for (const auto& route : Routes)
{
if (route.Family == AF_INET && route.DestinationPrefix.PrefixLength == 0 && route.DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS)
{
return route.NextHop;
}
}
return {};
}
std::wstring IpAddressesString() const
{
return std::accumulate(std::begin(IpAddresses), std::end(IpAddresses), std::wstring{}, [](const std::wstring& prev, const auto& addr) {
return addr.AddressString + (prev.empty() ? L"" : L"," + prev);
});
}
std::wstring RoutesString() const
{
return std::accumulate(std::begin(Routes), std::end(Routes), std::wstring{}, [](const std::wstring& prev, const EndpointRoute& route) {
return route.ToString() + (prev.empty() ? L"" : L"," + prev);
});
}
std::wstring DnsServersString() const
{
return wsl::shared::string::Join(DnsServers, L',');
}
// will return ULONG_MAX if there's no configured MTU
ULONG GetEffectiveMtu() const noexcept
{
return std::min(IPv4InterfaceMtu > 0 ? IPv4InterfaceMtu : ULONG_MAX, IPv6InterfaceMtu > 0 ? IPv6InterfaceMtu : ULONG_MAX);
}
// will return zero if there's no configured metric
ULONG GetMinimumMetric() const noexcept
{
if (!IPv4InterfaceMetric.has_value() && !IPv6InterfaceMetric.has_value())
{
return 0;
}
if (!IPv4InterfaceMetric.has_value())
{
return IPv6InterfaceMetric.value();
}
if (!IPv6InterfaceMetric.has_value())
{
return IPv4InterfaceMetric.value();
}
return std::min(IPv4InterfaceMetric.value(), IPv6InterfaceMetric.value());
}
};
std::shared_ptr<NetworkSettings> GetEndpointSettings(const wsl::shared::hns::HNSEndpoint& properties);
std::shared_ptr<NetworkSettings> GetHostEndpointSettings();
#define TRACE_NETWORKSETTINGS_OBJECT(settings) \
TraceLoggingValue((settings)->InterfaceGuid, "interfaceGuid"), TraceLoggingValue((settings)->InterfaceIndex, "interfaceIndex"), \
TraceLoggingValue((settings)->InterfaceType, "interfaceType"), \
TraceLoggingValue((settings)->IsConnected, "isConnected"), TraceLoggingValue((settings)->IsMetered, "isMetered"), \
TraceLoggingValue((settings)->GetBestGatewayAddressString().c_str(), "bestGatewayAddress"), \
TraceLoggingValue((settings)->PreferredIpAddress.AddressString.c_str(), "preferredIpAddress"), \
TraceLoggingValue((settings)->PreferredIpAddress.PrefixLength, "preferredIpAddressPrefixLength"), \
TraceLoggingValue((settings)->IpAddressesString().c_str(), "ipAddresses"), \
TraceLoggingValue((settings)->RoutesString().c_str(), "routes"), \
TraceLoggingValue((settings)->DnsServersString().c_str(), "dnsServerList"), \
TraceLoggingValue((settings)->MacAddress.c_str(), "macAddress"), \
TraceLoggingValue((settings)->IPv4InterfaceMtu, "IPv4InterfaceMtu"), \
TraceLoggingValue((settings)->IPv6InterfaceMtu, "IPv6InterfaceMtu"), \
TraceLoggingValue((settings)->IPv4InterfaceMetric.value_or(0xffffffff), "IPv4InterfaceMetric"), \
TraceLoggingValue((settings)->IPv6InterfaceMetric.value_or(0xffffffff), "IPv6InterfaceMetric"), \
TraceLoggingValue((settings)->PendingIPInterfaceUpdate, "PendingIPInterfaceUpdate"), \
TraceLoggingValue((settings)->PendingUpdateToReconnectForMetered, "PendingUpdateToReconnectForMetered")
} // namespace wsl::core::networking
// Copyright (C) Microsoft Corporation. All rights reserved.
#pragma once
#include <algorithm>
#include <set>
#include <string>
#include <windows.h>
#include <mstcpip.h>
#include <ws2ipdef.h>
#include <netioapi.h>
#include "hcs.hpp"
#include "lxinitshared.h"
#include "Stringify.h"
#include "stringshared.h"
#include "WslCoreNetworkingSupport.h"
#include "hns_schema.h"
namespace wsl::core::networking {
constexpr auto AddEndpointRetryPeriod = std::chrono::milliseconds(100);
constexpr auto AddEndpointRetryTimeout = std::chrono::seconds(3);
constexpr auto AddEndpointRetryPredicate = [] {
// Don't retry if ModifyComputeSystem fails with:
// HCN_E_ENDPOINT_NOT_FOUND - indicates that the underlying network object was deleted.
// HCN_E_ENDPOINT_ALREADY_ATTACHED - occurs when HNS was restarted before the endpoints were removed.
// VM_E_INVALID_STATE - occurs when the VM has been terminated.
const auto result = wil::ResultFromCaughtException();
return result != HCN_E_ENDPOINT_NOT_FOUND && result != HCN_E_ENDPOINT_ALREADY_ATTACHED && result != VM_E_INVALID_STATE;
};
struct EndpointIpAddress
{
SOCKADDR_INET Address{};
std::wstring AddressString{};
unsigned char PrefixLength = 0;
unsigned int PrefixOrigin = 0;
unsigned int SuffixOrigin = 0;
// The following field can be changed from a const iterator in SyncIpStateWithLinux - that's why it's marked mutable.
mutable unsigned int PreferredLifetime = 0;
EndpointIpAddress() = default;
~EndpointIpAddress() noexcept = default;
EndpointIpAddress(EndpointIpAddress&&) = default;
EndpointIpAddress& operator=(EndpointIpAddress&&) = default;
EndpointIpAddress(const EndpointIpAddress&) = default;
EndpointIpAddress& operator=(const EndpointIpAddress&) = default;
explicit EndpointIpAddress(const MIB_UNICASTIPADDRESS_ROW& AddressRow) :
Address(AddressRow.Address),
AddressString(windows::common::string::SockAddrInetToWstring(AddressRow.Address)),
PrefixLength(AddressRow.OnLinkPrefixLength),
PrefixOrigin(AddressRow.PrefixOrigin),
SuffixOrigin(AddressRow.SuffixOrigin),
// We treat the preferred lifetime field as effective DAD state - 0 is not preferred, anything else is preferred.
// We do this for convenience, as we can't directly set the DAD state of an address into the guest, but we
// we can set an address's preferred lifetime (in Linux, at least).
PreferredLifetime(AddressRow.DadState == IpDadStatePreferred ? 0xFFFFFFFF : 0)
{
}
// operator== is deliberately not comparing PreferredLifetime (DAD state) for equality - only the address portion
bool operator==(const EndpointIpAddress& rhs) const noexcept
{
return Address == rhs.Address && PrefixLength == rhs.PrefixLength;
}
bool operator<(const EndpointIpAddress& rhs) const noexcept
{
if (Address == rhs.Address)
{
return PrefixLength < rhs.PrefixLength;
}
return Address < rhs.Address;
}
void Clear() noexcept
{
Address = {};
AddressString.clear();
PrefixLength = 0;
PrefixOrigin = 0;
SuffixOrigin = 0;
}
std::wstring GetPrefix() const
{
SOCKADDR_INET address{Address};
unsigned char* addressPointer{nullptr};
if (Address.si_family == AF_INET)
{
addressPointer = reinterpret_cast<unsigned char*>(&address.Ipv4.sin_addr);
}
else if (Address.si_family == AF_INET6)
{
addressPointer = address.Ipv6.sin6_addr.u.Byte;
}
else
{
return L"";
}
constexpr int c_numBitsPerByte = 8;
for (int i = 0, currPrefixLength = PrefixLength; i < INET_ADDR_LENGTH(Address.si_family); i++, currPrefixLength -= c_numBitsPerByte)
{
if (currPrefixLength < c_numBitsPerByte)
{
const int bitShiftAmt = c_numBitsPerByte - std::max(currPrefixLength, 0);
addressPointer[i] &= (0xFF >> bitShiftAmt) << bitShiftAmt;
}
}
const auto addressString = windows::common::string::SockAddrInetToWstring(address);
WI_ASSERT(!addressString.empty());
if (addressString.empty())
{
// just return an empty string if we have a bad address
return addressString;
}
return std::format(L"{}/{}", addressString, PrefixLength);
}
std::wstring GetIpv4BroadcastMask() const
{
// start with all bits set, then shift off the prefix
ULONG prefixMask{0xffffffff};
prefixMask <<= PrefixLength;
prefixMask >>= PrefixLength;
SOCKADDR_INET address{Address};
// flip to host-order, then apply the mask
ULONG hostOrder = ntohl(address.Ipv4.sin_addr.S_un.S_addr);
hostOrder |= prefixMask;
address.Ipv4.sin_addr.S_un.S_addr = htonl(hostOrder);
return windows::common::string::SockAddrInetToWstring(address);
}
bool IsPreferred() const noexcept
{
return PreferredLifetime > 0;
}
bool IsLinkLocal() const
{
return (Address.si_family == AF_INET && IN4_IS_ADDR_LINKLOCAL(&Address.Ipv4.sin_addr)) ||
(Address.si_family == AF_INET6 && IN6_IS_ADDR_LINKLOCAL(&Address.Ipv6.sin6_addr));
}
};
struct EndpointRoute
{
ADDRESS_FAMILY Family = AF_INET;
IP_ADDRESS_PREFIX DestinationPrefix{};
std::wstring DestinationPrefixString{};
SOCKADDR_INET NextHop{};
std::wstring NextHopString{};
unsigned char SitePrefixLength = 0;
unsigned int Metric = 0;
bool IsAutoGeneratedPrefixRoute = false;
EndpointRoute() = default;
~EndpointRoute() noexcept = default;
EndpointRoute(EndpointRoute&&) = default;
EndpointRoute& operator=(EndpointRoute&&) = default;
EndpointRoute(const EndpointRoute&) = default;
EndpointRoute& operator=(const EndpointRoute&) = default;
EndpointRoute(const MIB_IPFORWARD_ROW2& RouteRow) :
Family(RouteRow.NextHop.si_family),
DestinationPrefix(RouteRow.DestinationPrefix),
DestinationPrefixString(windows::common::string::SockAddrInetToWstring(RouteRow.DestinationPrefix.Prefix)),
NextHop(RouteRow.NextHop),
NextHopString(windows::common::string::SockAddrInetToWstring(RouteRow.NextHop)),
SitePrefixLength(RouteRow.SitePrefixLength),
Metric(RouteRow.Metric)
{
}
unsigned char GetMaxPrefixLength() const
{
return (Family == AF_INET) ? 32 : 128;
}
std::wstring GetFullDestinationPrefix() const
{
return std::format(L"{}/{}", DestinationPrefixString, static_cast<unsigned int>(DestinationPrefix.PrefixLength));
}
bool IsNextHopOnlink() const noexcept
{
return (Family == AF_INET && NextHopString == LX_INIT_UNSPECIFIED_ADDRESS) ||
(Family == AF_INET6 && NextHopString == LX_INIT_UNSPECIFIED_V6_ADDRESS);
}
bool IsDefault() const noexcept
{
return (Family == AF_INET && DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS) ||
(Family == AF_INET6 && DestinationPrefixString == LX_INIT_UNSPECIFIED_V6_ADDRESS);
}
bool IsUnicastAddressRoute() const noexcept
{
return (Family == AF_INET && DestinationPrefix.PrefixLength == 32) || (Family == AF_INET6 && DestinationPrefix.PrefixLength == 128);
}
std::wstring ToString() const
{
return std::format(L"{}=>{} [metric {}]", GetFullDestinationPrefix(), NextHopString, Metric);
}
bool operator==(const EndpointRoute& rhs) const noexcept
{
return Family == rhs.Family && DestinationPrefix.PrefixLength == rhs.DestinationPrefix.PrefixLength &&
DestinationPrefix.Prefix == rhs.DestinationPrefix.Prefix && NextHop == rhs.NextHop &&
SitePrefixLength == rhs.SitePrefixLength && Metric == rhs.Metric;
}
bool operator!=(const EndpointRoute& other) const
{
return !(*this == other);
}
// sort by family, then by next-hop (on-link routes first), then by prefix, then by metric
bool operator<(const EndpointRoute& rhs) const noexcept
{
if (Family == rhs.Family)
{
if (NextHop == rhs.NextHop)
{
if (DestinationPrefix.Prefix == rhs.DestinationPrefix.Prefix)
{
if (DestinationPrefix.PrefixLength == rhs.DestinationPrefix.PrefixLength)
{
if (Metric == rhs.Metric)
{
return SitePrefixLength < rhs.SitePrefixLength;
}
return Metric < rhs.Metric;
}
return DestinationPrefix.PrefixLength < rhs.DestinationPrefix.PrefixLength;
}
return DestinationPrefix.Prefix < rhs.DestinationPrefix.Prefix;
}
return NextHop < rhs.NextHop;
}
return Family < rhs.Family;
}
};
struct NetworkSettings
{
NetworkSettings() = default;
NetworkSettings(
const GUID& interfaceGuid,
EndpointIpAddress preferredIpAddress,
EndpointRoute gateway,
std::wstring macAddress,
std::wstring deviceName,
uint32_t interfaceIndex,
uint32_t mediaType,
const std::wstring& dnsServerList) :
InterfaceGuid(interfaceGuid),
PreferredIpAddress(std::move(preferredIpAddress)),
MacAddress(std::move(macAddress)),
DeviceName(std::move(deviceName)),
InterfaceIndex(interfaceIndex),
InterfaceType(mediaType)
{
Routes.emplace(std::move(gateway));
DnsServers = wsl::shared::string::Split(dnsServerList, L',');
}
GUID InterfaceGuid{};
EndpointIpAddress PreferredIpAddress{};
std::set<EndpointIpAddress> IpAddresses{}; // Does not include PreferredIpAddress.
std::set<EndpointRoute> Routes{};
std::vector<std::wstring> DnsServers{};
std::wstring MacAddress;
std::wstring DeviceName;
IF_INDEX InterfaceIndex = 0;
IFTYPE InterfaceType = 0;
ULONG IPv4InterfaceMtu = 0;
ULONG IPv6InterfaceMtu = 0;
// some interfaces will only have an IPv4 or IPv6 interface
std::optional<ULONG> IPv4InterfaceMetric = 0;
std::optional<ULONG> IPv6InterfaceMetric = 0;
bool IsHidden = false;
bool IsConnected = false;
bool IsMetered = false;
bool DisableIpv4DefaultRoutes = false;
bool DisableIpv6DefaultRoutes = false;
bool PendingUpdateToReconnectForMetered = false;
bool PendingIPInterfaceUpdate = false;
auto operator<=>(const NetworkSettings&) const = default;
std::wstring GetBestGatewayAddressString() const
{
// Best is currently defined as simply the first IPv4 gateway.
for (const auto& route : Routes)
{
if (route.Family == AF_INET && route.DestinationPrefix.PrefixLength == 0 && route.DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS)
{
return route.NextHopString;
}
}
return {};
}
SOCKADDR_INET GetBestGatewayAddress() const
{
// Best is currently defined as simply the first IPv4 gateway.
for (const auto& route : Routes)
{
if (route.Family == AF_INET && route.DestinationPrefix.PrefixLength == 0 && route.DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS)
{
return route.NextHop;
}
}
return {};
}
std::wstring IpAddressesString() const
{
return std::accumulate(std::begin(IpAddresses), std::end(IpAddresses), std::wstring{}, [](const std::wstring& prev, const auto& addr) {
return addr.AddressString + (prev.empty() ? L"" : L"," + prev);
});
}
std::wstring RoutesString() const
{
return std::accumulate(std::begin(Routes), std::end(Routes), std::wstring{}, [](const std::wstring& prev, const EndpointRoute& route) {
return route.ToString() + (prev.empty() ? L"" : L"," + prev);
});
}
std::wstring DnsServersString() const
{
return wsl::shared::string::Join(DnsServers, L',');
}
// will return ULONG_MAX if there's no configured MTU
ULONG GetEffectiveMtu() const noexcept
{
return std::min(IPv4InterfaceMtu > 0 ? IPv4InterfaceMtu : ULONG_MAX, IPv6InterfaceMtu > 0 ? IPv6InterfaceMtu : ULONG_MAX);
}
// will return zero if there's no configured metric
ULONG GetMinimumMetric() const noexcept
{
if (!IPv4InterfaceMetric.has_value() && !IPv6InterfaceMetric.has_value())
{
return 0;
}
if (!IPv4InterfaceMetric.has_value())
{
return IPv6InterfaceMetric.value();
}
if (!IPv6InterfaceMetric.has_value())
{
return IPv4InterfaceMetric.value();
}
return std::min(IPv4InterfaceMetric.value(), IPv6InterfaceMetric.value());
}
};
std::shared_ptr<NetworkSettings> GetEndpointSettings(const wsl::shared::hns::HNSEndpoint& properties);
std::shared_ptr<NetworkSettings> GetHostEndpointSettings();
#define TRACE_NETWORKSETTINGS_OBJECT(settings) \
TraceLoggingValue((settings)->InterfaceGuid, "interfaceGuid"), TraceLoggingValue((settings)->InterfaceIndex, "interfaceIndex"), \
TraceLoggingValue((settings)->InterfaceType, "interfaceType"), \
TraceLoggingValue((settings)->IsConnected, "isConnected"), TraceLoggingValue((settings)->IsMetered, "isMetered"), \
TraceLoggingValue((settings)->GetBestGatewayAddressString().c_str(), "bestGatewayAddress"), \
TraceLoggingValue((settings)->PreferredIpAddress.AddressString.c_str(), "preferredIpAddress"), \
TraceLoggingValue((settings)->PreferredIpAddress.PrefixLength, "preferredIpAddressPrefixLength"), \
TraceLoggingValue((settings)->IpAddressesString().c_str(), "ipAddresses"), \
TraceLoggingValue((settings)->RoutesString().c_str(), "routes"), \
TraceLoggingValue((settings)->DnsServersString().c_str(), "dnsServerList"), \
TraceLoggingValue((settings)->MacAddress.c_str(), "macAddress"), \
TraceLoggingValue((settings)->IPv4InterfaceMtu, "IPv4InterfaceMtu"), \
TraceLoggingValue((settings)->IPv6InterfaceMtu, "IPv6InterfaceMtu"), \
TraceLoggingValue((settings)->IPv4InterfaceMetric.value_or(0xffffffff), "IPv4InterfaceMetric"), \
TraceLoggingValue((settings)->IPv6InterfaceMetric.value_or(0xffffffff), "IPv6InterfaceMetric"), \
TraceLoggingValue((settings)->PendingIPInterfaceUpdate, "PendingIPInterfaceUpdate"), \
TraceLoggingValue((settings)->PendingUpdateToReconnectForMetered, "PendingUpdateToReconnectForMetered")
} // namespace wsl::core::networking

View File

@ -44,6 +44,7 @@ inline auto c_msixPackageFamilyName = L"MicrosoftCorporationII.WindowsSubsystemF
inline auto c_githubUrlOverrideRegistryValue = L"GitHubUrlOverride";
inline auto c_vhdFileExtension = L".vhd";
inline auto c_vhdxFileExtension = L".vhdx";
inline constexpr auto c_vmOwner = L"WSL"; // TODO-WSLA: Does this apply to WSLA ?
struct GitHubReleaseAsset
{

View File

@ -1,6 +1,5 @@
set(SOURCES
DistributionRegistration.cpp
LxssSecurity.cpp
LxssUserCallback.cpp
LxssUserSession.cpp
LxssUserSessionFactory.cpp
@ -11,10 +10,6 @@ set(SOURCES
PluginManager.cpp
ServiceMain.cpp
BridgedNetworking.cpp
DeviceHostProxy.cpp
Dmesg.cpp
DnsTunnelingChannel.cpp
GnsChannel.cpp
GnsPortTrackerChannel.cpp
GnsRpcServer.cpp
GuestTelemetryLogger.cpp
@ -22,21 +17,12 @@ set(SOURCES
LxssConsoleManager.cpp
LxssCreateProcess.cpp
MirroredNetworking.cpp
NatNetworking.cpp
RingBuffer.cpp
WslCoreFilesystem.cpp
WslCoreGuestNetworkService.cpp
WslCoreHostDnsInfo.cpp
WslCoreInstance.cpp
WslMirroredNetworking.cpp
WslCoreNetworkEndpointSettings.cpp
DnsResolver.cpp
WslCoreTcpIpStateTracking.cpp
WslCoreVm.cpp
VirtioNetworking.cpp
WSLAUserSession.cpp
WSLAUserSessionFactory.cpp
WSLAVirtualMachine.cpp
main.rc
${CMAKE_CURRENT_BINARY_DIR}/../mc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wsleventschema.rc
application.manifest)
@ -44,7 +30,6 @@ set(SOURCES
set(HEADERS
../../inc/comservicehelper.h
DistributionRegistration.h
LxssSecurity.h
LxssUserCallback.h
LxssUserSession.h
LxssUserSessionFactory.h
@ -53,35 +38,20 @@ set(HEADERS
PluginManager.h
LxssInstance.h
BridgedNetworking.h
DeviceHostProxy.h
Dmesg.h
DnsTunnelingChannel.h
GnsChannel.h
GnsPortTrackerChannel.h
GnsRpcServer.h
GuestTelemetryLogger.h
INetworkingEngine.h
IMirroredNetworkManager.h
Lifetime.h
LxssConsoleManager.h
LxssCreateProcess.h
MirroredNetworking.h
NatNetworking.h
RingBuffer.h
WslCoreFilesystem.h
WslCoreGuestNetworkService.h
WslCoreHostDnsInfo.h
WslCoreInstance.h
WslCoreMessageQueue.h
WslMirroredNetworking.h
WslCoreNetworkEndpoint.h
WslCoreNetworkEndpointSettings.h
DnsResolver.h
WslCoreTcpIpStateTracking.h
WslCoreVm.h
WSLAUserSession.h
WSLAUserSessionFactory.h
WSLAVirtualMachine.h)
WslCoreVm.h)
include_directories(${CMAKE_SOURCE_DIR}/src/windows/wslaclient)

View File

@ -256,7 +256,7 @@ void MirroredNetworking::Initialize()
m_config.FirewallConfig.Enabled(), m_config.IgnoredPorts, m_runtimeId, m_gnsRpcServer->GetServerUuid(), s_GuestNetworkServiceCallback, this);
m_ephemeralPortRange = m_guestNetworkService.AllocateEphemeralPortRange();
networking::ConfigureHyperVFirewall(m_config.FirewallConfig, c_vmOwner);
networking::ConfigureHyperVFirewall(m_config.FirewallConfig, wsl::windows::common::wslutil::c_vmOwner);
// must keep all m_networkManager interactions (including) creation queued
// also must queue GNS callbacks to keep them serialized

View File

@ -31,7 +31,6 @@ wil::unique_event g_networkingReady{wil::EventOptions::ManualReset};
// Declare the LxssUserSession COM class.
CoCreatableClassWrlCreatorMapInclude(LxssUserSession);
CoCreatableClassWrlCreatorMapInclude(WSLAUserSession);
struct WslServiceSecurityPolicy
{
@ -241,7 +240,6 @@ void WslService::ServiceStopped()
// Terminate all user sessions.
ClearSessionsAndBlockNewInstances();
wsl::windows::service::wsla::ClearWslaSessionsAndBlockNewInstances();
// Disconnect from the LxCore driver.
if (g_lxcoreInitialized)

View File

@ -1468,7 +1468,7 @@ void WslCoreVm::FreeLun(_In_ ULONG lun)
std::wstring WslCoreVm::GenerateConfigJson()
{
hcs::ComputeSystem systemSettings{};
systemSettings.Owner = c_vmOwner;
systemSettings.Owner = wsl::windows::common::wslutil::c_vmOwner;
systemSettings.ShouldTerminateOnLastHandleClosed = true;
systemSettings.SchemaVersion.Major = 2;
systemSettings.SchemaVersion.Minor = 3;
@ -1575,7 +1575,7 @@ std::wstring WslCoreVm::GenerateConfigJson()
// Set the vmmem suffix which will change the process name in task manager.
if (helpers::IsVmemmSuffixSupported())
{
vmSettings.ComputeTopology.Memory.HostingProcessNameSuffix = c_vmOwner;
vmSettings.ComputeTopology.Memory.HostingProcessNameSuffix = wsl::windows::common::wslutil::c_vmOwner;
}
// If nested virtualization was requested, ensure the platform supports it.

View File

@ -39,8 +39,6 @@ inline constexpr auto c_optionsValueName = L"Options";
inline constexpr auto c_typeValueName = L"Type";
inline constexpr auto c_mountNameValueName = L"Name";
inline constexpr auto c_vmOwner = L"WSL";
static constexpr GUID c_virtiofsAdminClassId = {0x7e6ad219, 0xd1b3, 0x42d5, {0xb8, 0xee, 0xd9, 0x63, 0x24, 0xe6, 0x4f, 0xf6}};
// {60285AE6-AAF3-4456-B444-A6C2D0DEDA38}

View File

@ -396,97 +396,3 @@ cpp_quote("#define WSL_E_DISK_CORRUPTED MAKE_HRESULT(SEVERITY_ERROR, FACILITY_IT
cpp_quote("#define WSL_E_DISTRIBUTION_NAME_NEEDED MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSL_E_BASE + 0x30) /* 0x80040330 */")
cpp_quote("#define WSL_E_INVALID_JSON MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSL_E_BASE + 0x31) /* 0x80040331 */")
cpp_quote("#define WSL_E_VM_CRASHED MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSL_E_BASE + 0x32) /* 0x80040332 */")
typedef
struct _WSL_VERSION {
ULONG Major;
ULONG Minor;
ULONG Revision;
} WSL_VERSION;
typedef [system_handle(sh_socket)] HANDLE HVSOCKET_HANDLE;
typedef
struct _WSLA_CREATE_PROCESS_OPTIONS {
[string] LPCSTR Executable;
ULONG CommandLineCount;
[unique, size_is(CommandLineCount)] LPCSTR* CommandLine;
ULONG EnvironmentCount;
[unique, size_is(EnvironmentCount)] LPCSTR* Environment;
[unique] LPCSTR CurrentDirectory;
} WSLA_CREATE_PROCESS_OPTIONS;
typedef struct _WSLA_PROCESS_FD
{
LONG Fd;
int Type;
[string, unique] LPCSTR Path;
} WSLA_PROCESS_FD;
typedef
struct _WSLA_CREATE_PROCESS_RESULT {
int Errno;
int Pid;
} WSLA_CREATE_PROCESS_RESULT;
[
uuid(7BC4E198-6531-4FA6-ADE2-5EF3D2A04DFE),
pointer_default(unique),
object
]
interface ITerminationCallback : IUnknown
{
HRESULT OnTermination(ULONG Reason, LPCWSTR Details);
};
[
uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8761),
pointer_default(unique),
object
]
interface IWSLAVirtualMachine : IUnknown
{
HRESULT AttachDisk([in] LPCWSTR Path, [in] BOOL ReadOnly, [out] LPSTR* Device, [out] ULONG* Lun);
HRESULT Mount([in, unique] LPCSTR Source, [in] LPCSTR Target, [in] LPCSTR Type, [in] LPCSTR Options, [in] ULONG Flags);
HRESULT CreateLinuxProcess([in] const WSLA_CREATE_PROCESS_OPTIONS* Options, [in] ULONG FdCount, [in, unique, size_is(FdCount)] WSLA_PROCESS_FD* Fds, [out, size_is(FdCount)] ULONG* Handles, [out] WSLA_CREATE_PROCESS_RESULT* Result);
HRESULT WaitPid([in] LONG Pid, [in] ULONGLONG TimeoutMs, [out] ULONG* State, [out] int* Code);
HRESULT Signal([in] LONG Pid, [in] int Signal);
HRESULT Shutdown([in] ULONGLONG TimeoutMs);
HRESULT RegisterCallback([in] ITerminationCallback* terminationCallback);
HRESULT GetDebugShellPipe([out] LPWSTR* pipePath);
HRESULT MapPort([in] int Family, [in] short WindowsPort, [in] short LinuxPort, [in] BOOL Remove);
HRESULT Unmount([in] LPCSTR Path);
HRESULT DetachDisk([in] ULONG Lun);
HRESULT MountWindowsFolder([in] LPCWSTR WindowsPath, [in] LPCSTR LinuxPath, [in] BOOL ReadOnly);
HRESULT UnmountWindowsFolder([in] LPCSTR LinuxPath);
HRESULT MountGpuLibraries([in] LPCSTR LibrariesMountPoint, [in] LPCSTR DriversMountpoint, [in] DWORD Flags);
}
typedef
struct _VIRTUAL_MACHINE_SETTINGS {
LPCWSTR DisplayName;
ULONGLONG MemoryMb;
ULONG CpuCount;
ULONG BootTimeoutMs;
ULONG DmesgOutput;
ULONG NetworkingMode;
BOOL EnableDnsTunneling;
BOOL EnableDebugShell;
BOOL EnableEarlyBootDmesg;
BOOL EnableGPU;
} VIRTUAL_MACHINE_SETTINGS;
[
uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8760),
pointer_default(unique),
object
]
interface IWSLAUserSession : IUnknown
{
HRESULT GetVersion([out] WSL_VERSION* Error);
HRESULT CreateVirtualMachine([in] const VIRTUAL_MACHINE_SETTINGS* Settings, [out]IWSLAVirtualMachine** VirtualMachine);
}

View File

@ -3,6 +3,7 @@ set(HEADERS WSLAApi.h)
add_library(wslaclient SHARED ${SOURCES} ${HEADERS} wslaclient.def)
set_target_properties(wslaclient PROPERTIES EXCLUDE_FROM_ALL FALSE)
add_dependencies(wslaclient wslaserviceidl)
target_link_libraries(wslaclient ${COMMON_LINK_LIBRARIES} legacy_stdio_definitions common)
target_precompile_headers(wslaclient REUSE_FROM common)
set_target_properties(wslaclient PROPERTIES FOLDER windows)

View File

@ -13,7 +13,7 @@ Abstract:
--*/
#include "precomp.h"
#include "wslservice.h"
#include "wslaservice.h"
#include "WSLAApi.h"
#include "wslrelay.h"
#include "wslInstall.h"

View File

@ -0,0 +1,6 @@
add_compile_definitions("PROXY_CLSID_IS={0x4EA0C6DD,0xE9FF,0x48E7,{0x99,0x4e,0x13,0xa3,0x1d,0x10,0xdc,0x61}}")
add_compile_definitions("REGISTER_PROXY_DLL")
add_subdirectory(exe)
add_subdirectory(inc)
add_subdirectory(stub)

View File

@ -0,0 +1,34 @@
set(SOURCES
application.manifest
main.rc
ServiceMain.cpp
WSLAUserSession.cpp
WSLAUserSessionFactory.cpp
WSLAVirtualMachine.cpp
)
set(HEADERS
WSLAUserSession.h
WSLAUserSessionFactory.h
WSLAVirtualMachine.h)
include_directories(${CMAKE_SOURCE_DIR}/src/windows/wslaclient)
add_executable(wslaservice ${SOURCES} ${HEADERS})
add_dependencies(wslaservice wslaserviceidl)
add_compile_definitions(__WRL_CLASSIC_COM__)
add_compile_definitions(__WRL_DISABLE_STATIC_INITIALIZE__)
add_compile_definitions(USE_COM_CONTEXT_DEF=1)
set_target_properties(wslaservice PROPERTIES LINK_FLAGS "/merge:minATL=.rdata /include:__minATLObjMap_WSLAUserSession_COM")
target_link_libraries(wslaservice
${COMMON_LINK_LIBRARIES}
computecore
common
configfile
legacy_stdio_definitions
VirtDisk.lib
Winhttp.lib
Synchronization.lib)
target_precompile_headers(wslaservice REUSE_FROM common)
set_target_properties(wslaservice PROPERTIES FOLDER windows)

View File

@ -0,0 +1,126 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
ServiceMain.cpp
Abstract:
This file contains the entrypoint for the Lxss Manager service.
--*/
#include "precomp.h"
#include "comservicehelper.h"
#include "LxssSecurity.h"
#include "WslCoreFilesystem.h"
#include "WSLAUserSessionFactory.h"
#include <ctime>
using namespace wsl::windows::common::registry;
using namespace wsl::windows::common::string;
using namespace wsl::windows::common::wslutil;
using namespace wsl::windows::policies;
wil::unique_event g_networkingReady{wil::EventOptions::ManualReset};
// Declare the WSLAUserSession COM class.
CoCreatableClassWrlCreatorMapInclude(WSLAUserSession);
struct WslaServiceSecurityPolicy
{
static LPCWSTR GetSDDLText()
{
// COM Access and Launch permissions allowed for authenticated user, principal self, and system.
// 0xB = (COM_RIGHTS_EXECUTE | COM_RIGHTS_EXECUTE_LOCAL | COM_RIGHTS_ACTIVATE_LOCAL)
// N.B. This should be kept in sync with the security descriptors in the appxmanifest and package.wix.
return L"O:BAG:BAD:(A;;0xB;;;AU)(A;;0xB;;;PS)(A;;0xB;;;SY)";
}
};
class WslaService : public Windows::Internal::Service<WslaService, Windows::Internal::ContinueRunningWithNoObjects, WslaServiceSecurityPolicy>
{
public:
static wchar_t* GetName()
{
return const_cast<LPWSTR>(L"WSLAService");
}
static void OnSessionChanged(DWORD eventType, DWORD sessionId);
HRESULT OnServiceStarting();
HRESULT ServiceStarted();
void ServiceStopped();
private:
wil::unique_couninitialize_call m_coInit{false};
};
HRESULT WslaService::OnServiceStarting()
try
{
ConfigureCrt();
// Enable contextualized errors
wsl::windows::common::EnableContextualizedErrors(true);
// Initialize telemetry.
// TODO-WSLA: Create a dedicated WSLA provider
WslTraceLoggingInitialize(WslServiceTelemetryProvider, !wsl::shared::OfficialBuild);
WSL_LOG("Service starting", TraceLoggingLevel(WINEVENT_LEVEL_INFO));
// Don't kill the process on unknown C++ exceptions.
wil::g_fResultFailFastUnknownExceptions = false;
// wsl::windows::common::security::ApplyProcessMitigationPolicies();
// Initialize Winsock.
WSADATA Data;
THROW_IF_WIN32_ERROR(WSAStartup(MAKEWORD(2, 2), &Data));
return S_OK;
}
CATCH_RETURN()
HRESULT WslaService::ServiceStarted()
{
m_coInit = wil::CoInitializeEx(COINIT_MULTITHREADED);
return S_OK;
}
void WslaService::OnSessionChanged(DWORD eventType, DWORD sessionId)
{
if (eventType == WTS_SESSION_LOGOFF)
{
// TODO-WSLA: Implement for WSLA
// TerminateSession(sessionId);
}
}
void WslaService::ServiceStopped()
{
WSL_LOG("Service stopping", TraceLoggingLevel(WINEVENT_LEVEL_INFO));
// Terminate all user sessions.
wsl::windows::service::wsla::ClearWslaSessionsAndBlockNewInstances();
// There is a potential deadlock if CoUninitialize() is called before the LanguageChangeNotifyThread
// isn't done initializing. Clearing the COM objects before calling CoUninitialize() works around the issue.
winrt::clear_factory_cache();
// Tear down telemetry.
WslTraceLoggingUninitialize();
// uninitialize COM. This must be done here because this call can cause cleanups that will be fail
// if the CRT is shutting down.
m_coInit.reset();
}
int __cdecl wmain()
{
WslaService::ProcessMain();
return 0;
}

View File

@ -0,0 +1,87 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
WSLAUserSession.cpp
Abstract:
TODO
--*/
#include "WSLAUserSession.h"
using wsl::windows::service::wsla::WSLAUserSessionImpl;
WSLAUserSessionImpl::WSLAUserSessionImpl(HANDLE Token, wil::unique_tokeninfo_ptr<TOKEN_USER>&& TokenInfo) :
m_tokenInfo(std::move(TokenInfo))
{
}
WSLAUserSessionImpl::~WSLAUserSessionImpl()
{
// Manually signal the VM termination events. This prevents being stuck on an API call that holds the VM lock.
{
std::lock_guard lock(m_virtualMachinesLock);
for (auto* e : m_virtualMachines)
{
e->OnSessionTerminating();
}
}
}
void WSLAUserSessionImpl::OnVmTerminated(WSLAVirtualMachine* machine)
{
std::lock_guard lock(m_virtualMachinesLock);
auto pred = [machine](const auto* e) { return machine == e; };
// Remove any stale VM reference.
m_virtualMachines.erase(std::remove_if(m_virtualMachines.begin(), m_virtualMachines.end(), pred), m_virtualMachines.end());
}
HRESULT WSLAUserSessionImpl::CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine)
{
auto vm = wil::MakeOrThrow<WSLAVirtualMachine>(*Settings, GetUserSid(), this);
{
std::lock_guard lock(m_virtualMachinesLock);
m_virtualMachines.emplace_back(vm.Get());
}
vm->Start();
THROW_IF_FAILED(vm.CopyTo(__uuidof(IWSLAVirtualMachine), (void**)VirtualMachine));
return S_OK;
}
PSID WSLAUserSessionImpl::GetUserSid() const
{
return m_tokenInfo->User.Sid;
}
wsl::windows::service::wsla::WSLAUserSession::WSLAUserSession(std::weak_ptr<WSLAUserSessionImpl>&& Session) :
m_session(std::move(Session))
{
}
HRESULT wsl::windows::service::wsla::WSLAUserSession::GetVersion(_Out_ WSL_VERSION* Version)
{
Version->Major = WSL_PACKAGE_VERSION_MAJOR;
Version->Minor = WSL_PACKAGE_VERSION_MINOR;
Version->Revision = WSL_PACKAGE_VERSION_REVISION;
return S_OK;
}
HRESULT wsl::windows::service::wsla::WSLAUserSession::CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine)
try
{
auto session = m_session.lock();
RETURN_HR_IF(RPC_E_DISCONNECTED, !session);
return session->CreateVirtualMachine(Settings, VirtualMachine);
}
CATCH_RETURN();

View File

@ -0,0 +1,56 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
WSLAUserSession.h
Abstract:
TODO
--*/
#pragma once
#include "WSLAVirtualMachine.h"
namespace wsl::windows::service::wsla {
class WSLAUserSessionImpl
{
public:
WSLAUserSessionImpl(HANDLE Token, wil::unique_tokeninfo_ptr<TOKEN_USER>&& TokenInfo);
WSLAUserSessionImpl(WSLAUserSessionImpl&&) = default;
WSLAUserSessionImpl& operator=(WSLAUserSessionImpl&&) = default;
~WSLAUserSessionImpl();
PSID GetUserSid() const;
HRESULT CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine);
void OnVmTerminated(WSLAVirtualMachine* machine);
private:
wil::unique_tokeninfo_ptr<TOKEN_USER> m_tokenInfo;
std::recursive_mutex m_virtualMachinesLock;
std::vector<WSLAVirtualMachine*> m_virtualMachines;
};
class DECLSPEC_UUID("a9b7a1b9-0671-405c-95f1-e0612cb4ce8f") WSLAUserSession
: public Microsoft::WRL::RuntimeClass<Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, IWSLAUserSession, IFastRundown>
{
public:
WSLAUserSession(std::weak_ptr<WSLAUserSessionImpl>&& Session);
WSLAUserSession(const WSLAUserSession&) = delete;
WSLAUserSession& operator=(const WSLAUserSession&) = delete;
IFACEMETHOD(GetVersion)(_Out_ WSL_VERSION* Version) override;
IFACEMETHOD(CreateVirtualMachine)(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine) override;
private:
std::weak_ptr<WSLAUserSessionImpl> m_session;
};
} // namespace wsl::windows::service::wsla

View File

@ -0,0 +1,82 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
WSLAUserSessionFactory.cpp
Abstract:
TODO
--*/
#include "precomp.h"
#include "WSLAUserSessionFactory.h"
#include "WSLAUserSession.h"
using wsl::windows::service::wsla::WSLAUserSessionFactory;
using wsl::windows::service::wsla::WSLAUserSessionImpl;
CoCreatableClassWithFactory(WSLAUserSession, WSLAUserSessionFactory);
static std::mutex g_mutex;
static std::optional<std::vector<std::shared_ptr<WSLAUserSessionImpl>>> g_sessions =
std::make_optional<std::vector<std::shared_ptr<WSLAUserSessionImpl>>>();
HRESULT WSLAUserSessionFactory::CreateInstance(_In_ IUnknown* pUnkOuter, _In_ REFIID riid, _Out_ void** ppCreated)
{
RETURN_HR_IF_NULL(E_POINTER, ppCreated);
*ppCreated = nullptr;
RETURN_HR_IF(CLASS_E_NOAGGREGATION, pUnkOuter != nullptr);
WSL_LOG("WSLAUserSessionFactory", TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE));
try
{
const wil::unique_handle userToken = wsl::windows::common::security::GetUserToken(TokenImpersonation);
// Get the session ID and SID of the client process.
DWORD sessionId{};
DWORD length = 0;
THROW_IF_WIN32_BOOL_FALSE(::GetTokenInformation(userToken.get(), TokenSessionId, &sessionId, sizeof(sessionId), &length));
auto tokenInfo = wil::get_token_information<TOKEN_USER>(userToken.get());
std::lock_guard lock{g_mutex};
THROW_HR_IF(CO_E_SERVER_STOPPING, !g_sessions.has_value());
auto session = std::find_if(g_sessions->begin(), g_sessions->end(), [&tokenInfo](auto it) {
return EqualSid(it->GetUserSid(), &tokenInfo->User.Sid);
});
if (session == g_sessions->end())
{
session = g_sessions->insert(g_sessions->end(), std::make_shared<WSLAUserSessionImpl>(userToken.get(), std::move(tokenInfo)));
}
auto comInstance = wil::MakeOrThrow<WSLAUserSession>(std::weak_ptr<WSLAUserSessionImpl>(*session));
THROW_IF_FAILED(comInstance.CopyTo(riid, ppCreated));
}
catch (...)
{
const auto result = wil::ResultFromCaughtException();
// Note: S_FALSE will cause COM to retry if the service is stopping.
return result == CO_E_SERVER_STOPPING ? S_FALSE : result;
}
WSL_LOG("WSLAUserSessionFactory", TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE));
return S_OK;
}
void wsl::windows::service::wsla::ClearWslaSessionsAndBlockNewInstances()
{
std::lock_guard lock{g_mutex};
g_sessions.reset();
}

View File

@ -0,0 +1,28 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
WSLAUserSessionFactory.h
Abstract:
TODO
--*/
#pragma once
#include <wil/resource.h>
namespace wsl::windows::service::wsla {
class WSLAUserSessionFactory : public Microsoft::WRL::ClassFactory<>
{
public:
WSLAUserSessionFactory() = default;
STDMETHODIMP CreateInstance(_In_ IUnknown* pUnkOuter, _In_ REFIID riid, _Out_ void** ppCreated) override;
};
void ClearWslaSessionsAndBlockNewInstances();
} // namespace wsl::windows::service::wsla

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,106 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
WSLAVirtualMachine.h
Abstract:
TODO
--*/
#pragma once
#include "wslaservice.h"
#include "INetworkingEngine.h"
#include "hcs.hpp"
#include "Dmesg.h"
#include "WSLAApi.h"
namespace wsl::windows::service::wsla {
class WSLAUserSessionImpl;
class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") WSLAVirtualMachine
: public Microsoft::WRL::RuntimeClass<Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, IWSLAVirtualMachine, IFastRundown>
{
public:
WSLAVirtualMachine(const VIRTUAL_MACHINE_SETTINGS& Settings, PSID Sid, WSLAUserSessionImpl* UserSession);
~WSLAVirtualMachine();
void Start();
void OnSessionTerminating();
IFACEMETHOD(AttachDisk(_In_ PCWSTR Path, _In_ BOOL ReadOnly, _Out_ LPSTR* Device, _Out_ ULONG* Lun)) override;
IFACEMETHOD(Mount(_In_ LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags)) override;
IFACEMETHOD(CreateLinuxProcess(
_In_ const WSLA_CREATE_PROCESS_OPTIONS* Options, _In_ ULONG FdCount, _In_ WSLA_PROCESS_FD* Fd, _Out_ ULONG* Handles, _Out_ WSLA_CREATE_PROCESS_RESULT* Result)) override;
IFACEMETHOD(WaitPid(_In_ LONG Pid, _In_ ULONGLONG TimeoutMs, _Out_ ULONG* State, _Out_ int* Code)) override;
IFACEMETHOD(Signal(_In_ LONG Pid, _In_ int Signal)) override;
IFACEMETHOD(Shutdown(ULONGLONG _In_ TimeoutMs)) override;
IFACEMETHOD(RegisterCallback(_In_ ITerminationCallback* callback)) override;
IFACEMETHOD(GetDebugShellPipe(_Out_ LPWSTR* pipePath)) override;
IFACEMETHOD(MapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ BOOL Remove)) override;
IFACEMETHOD(Unmount(_In_ const char* Path)) override;
IFACEMETHOD(DetachDisk(_In_ ULONG Lun)) override;
IFACEMETHOD(MountWindowsFolder(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly)) override;
IFACEMETHOD(UnmountWindowsFolder(_In_ LPCSTR LinuxPath)) override;
IFACEMETHOD(MountGpuLibraries(_In_ LPCSTR LibrariesMountPoint, _In_ LPCSTR DriversMountpoint, _In_ DWORD Flags)) override;
private:
static int32_t MountImpl(wsl::shared::SocketChannel& Channel, LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags);
static void CALLBACK s_OnExit(_In_ HCS_EVENT* Event, _In_opt_ void* Context);
static bool ParseTtyInformation(const WSLA_PROCESS_FD* Fds, ULONG FdCount, const WSLA_PROCESS_FD** TtyInput, const WSLA_PROCESS_FD** TtyOutput);
void ConfigureNetworking();
void OnExit(_In_ const HCS_EVENT* Event);
std::tuple<int32_t, int32_t, wsl::shared::SocketChannel> Fork(enum WSLA_FORK::ForkType Type);
std::tuple<int32_t, int32_t, wsl::shared::SocketChannel> Fork(wsl::shared::SocketChannel& Channel, enum WSLA_FORK::ForkType Type);
int32_t ExpectClosedChannelOrError(wsl::shared::SocketChannel& Channel);
wil::unique_socket ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd);
void OpenLinuxFile(wsl::shared::SocketChannel& Channel, const char* Path, uint32_t Flags, int32_t Fd);
void LaunchPortRelay();
std::vector<wil::unique_socket> CreateLinuxProcessImpl(
_In_ const WSLA_CREATE_PROCESS_OPTIONS* Options, _In_ ULONG FdCount, _In_ WSLA_PROCESS_FD* Fd, _Out_ WSLA_CREATE_PROCESS_RESULT* Result);
HRESULT MountWindowsFolderImpl(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly, _In_ WslMountFlags Flags);
struct AttachedDisk
{
std::filesystem::path Path;
std::string Device;
bool AccessGranted = false;
};
VIRTUAL_MACHINE_SETTINGS m_settings;
GUID m_vmId{};
std::wstring m_vmIdString;
wsl::windows::common::helpers::WindowsVersion m_windowsVersion = wsl::windows::common::helpers::GetWindowsVersion();
int m_coldDiscardShiftSize{};
bool m_running = false;
PSID m_userSid{};
std::wstring m_debugShellPipe;
wsl::windows::common::hcs::unique_hcs_system m_computeSystem;
std::shared_ptr<DmesgCollector> m_dmesgCollector;
wil::unique_event m_vmExitEvent{wil::EventOptions::ManualReset};
wil::unique_event m_vmTerminatingEvent{wil::EventOptions::ManualReset};
wil::com_ptr<ITerminationCallback> m_terminationCallback;
std::unique_ptr<wsl::core::INetworkingEngine> m_networkEngine;
wsl::shared::SocketChannel m_initChannel;
wil::unique_handle m_portRelayChannelRead;
wil::unique_handle m_portRelayChannelWrite;
std::map<ULONG, AttachedDisk> m_attachedDisks;
std::map<std::string, std::wstring> m_plan9Mounts;
std::recursive_mutex m_lock;
std::mutex m_portRelaylock;
WSLAUserSessionImpl* m_userSession;
};
} // namespace wsl::windows::service::wsla

View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="utf-8" standalone="yes"?>
<assembly xmlns="urn:schemas-microsoft-com:asm.v1" manifestVersion="1.0" xmlns:asmv3="urn:schemas-microsoft-com:asm.v3" >
<application xmlns="urn:schemas-microsoft-com:asm.v3">
<windowsSettings xmlns:ws2="http://schemas.microsoft.com/SMI/2016/WindowsSettings">
<ws2:longPathAware>true</ws2:longPathAware>
</windowsSettings>
</application>
</assembly>

View File

@ -0,0 +1,28 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
main.rc
Abstract:
This file contains resources for wslaservice.
--*/
#include <windows.h>
#include "resource.h"
#include "wslversioninfo.h"
#define VER_INTERNALNAME_STR "wslaservice.exe"
#define VER_ORIGINALFILENAME_STR "wslaservice.exe"
#define VER_FILETYPE VFT_APP
#define VER_FILESUBTYPE VFT2_UNKNOWN
#define VER_FILEDESCRIPTION_STR "Windows Subsystem for Linux for Apps Service"
ID_ICON ICON PRELOAD DISCARDABLE "..\..\..\Images\wsl.ico"
#include <common.ver>

View File

@ -0,0 +1,15 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
resource.h
Abstract:
This file contains resource declarations for wslaservice.exe
--*/
#define ID_ICON 1

View File

@ -0,0 +1,2 @@
add_idl(wslaserviceidl "wslaservice.idl" "")
set_target_properties(wslaserviceidl PROPERTIES FOLDER windows)

View File

@ -0,0 +1,114 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Module Name:
wslaservice.idl
Abstract:
This file contains the COM object definitions used to talk with the WSLa
service "WslaService"
--*/
import "unknwn.idl";
import "wtypes.idl";
cpp_quote("#ifdef __cplusplus")
cpp_quote("class DECLSPEC_UUID(\"a9b7a1b9-0671-405c-95f1-e0612cb4ce8f\") WSLAUserSession;")
cpp_quote("#endif")
typedef
struct _WSL_VERSION {
ULONG Major;
ULONG Minor;
ULONG Revision;
} WSL_VERSION;
typedef [system_handle(sh_socket)] HANDLE HVSOCKET_HANDLE;
typedef
struct _WSLA_CREATE_PROCESS_OPTIONS {
[string] LPCSTR Executable;
ULONG CommandLineCount;
[unique, size_is(CommandLineCount)] LPCSTR* CommandLine;
ULONG EnvironmentCount;
[unique, size_is(EnvironmentCount)] LPCSTR* Environment;
[unique] LPCSTR CurrentDirectory;
} WSLA_CREATE_PROCESS_OPTIONS;
typedef struct _WSLA_PROCESS_FD
{
LONG Fd;
int Type;
[string, unique] LPCSTR Path;
} WSLA_PROCESS_FD;
typedef
struct _WSLA_CREATE_PROCESS_RESULT {
int Errno;
int Pid;
} WSLA_CREATE_PROCESS_RESULT;
[
uuid(7BC4E198-6531-4FA6-ADE2-5EF3D2A04DFE),
pointer_default(unique),
object
]
interface ITerminationCallback : IUnknown
{
HRESULT OnTermination(ULONG Reason, LPCWSTR Details);
};
[
uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8761),
pointer_default(unique),
object
]
interface IWSLAVirtualMachine : IUnknown
{
HRESULT AttachDisk([in] LPCWSTR Path, [in] BOOL ReadOnly, [out] LPSTR* Device, [out] ULONG* Lun);
HRESULT Mount([in, unique] LPCSTR Source, [in] LPCSTR Target, [in] LPCSTR Type, [in] LPCSTR Options, [in] ULONG Flags);
HRESULT CreateLinuxProcess([in] const WSLA_CREATE_PROCESS_OPTIONS* Options, [in] ULONG FdCount, [in, unique, size_is(FdCount)] WSLA_PROCESS_FD* Fds, [out, size_is(FdCount)] ULONG* Handles, [out] WSLA_CREATE_PROCESS_RESULT* Result);
HRESULT WaitPid([in] LONG Pid, [in] ULONGLONG TimeoutMs, [out] ULONG* State, [out] int* Code);
HRESULT Signal([in] LONG Pid, [in] int Signal);
HRESULT Shutdown([in] ULONGLONG TimeoutMs);
HRESULT RegisterCallback([in] ITerminationCallback* terminationCallback);
HRESULT GetDebugShellPipe([out] LPWSTR* pipePath);
HRESULT MapPort([in] int Family, [in] short WindowsPort, [in] short LinuxPort, [in] BOOL Remove);
HRESULT Unmount([in] LPCSTR Path);
HRESULT DetachDisk([in] ULONG Lun);
HRESULT MountWindowsFolder([in] LPCWSTR WindowsPath, [in] LPCSTR LinuxPath, [in] BOOL ReadOnly);
HRESULT UnmountWindowsFolder([in] LPCSTR LinuxPath);
HRESULT MountGpuLibraries([in] LPCSTR LibrariesMountPoint, [in] LPCSTR DriversMountpoint, [in] DWORD Flags);
}
typedef
struct _VIRTUAL_MACHINE_SETTINGS {
LPCWSTR DisplayName;
ULONGLONG MemoryMb;
ULONG CpuCount;
ULONG BootTimeoutMs;
ULONG DmesgOutput;
ULONG NetworkingMode;
BOOL EnableDnsTunneling;
BOOL EnableDebugShell;
BOOL EnableEarlyBootDmesg;
BOOL EnableGPU;
} VIRTUAL_MACHINE_SETTINGS;
[
uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8760),
pointer_default(unique),
object
]
interface IWSLAUserSession : IUnknown
{
HRESULT GetVersion([out] WSL_VERSION* Error);
HRESULT CreateVirtualMachine([in] const VIRTUAL_MACHINE_SETTINGS* Settings, [out]IWSLAVirtualMachine** VirtualMachine);
}

View File

@ -0,0 +1,13 @@
set(SOURCES
${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wslaservice_i_${TARGET_PLATFORM}.c
${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wslaservice_p_${TARGET_PLATFORM}.c
${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/dlldata_${TARGET_PLATFORM}.c
${CMAKE_CURRENT_LIST_DIR}/WslaServiceProxyStub.def
${CMAKE_CURRENT_LIST_DIR}/WslaServiceProxyStub.rc)
set_source_files_properties(${SOURCES} PROPERTIES GENERATED TRUE)
add_library(wslaserviceproxystub SHARED ${SOURCES})
add_dependencies(wslaserviceproxystub wslaserviceidl)
target_link_libraries(wslaserviceproxystub ${COMMON_LINK_LIBRARIES})
set_target_properties(wslaserviceproxystub PROPERTIES FOLDER windows)

View File

@ -0,0 +1,5 @@
LIBRARY WslaServiceProxyStub.dll
EXPORTS
DllGetClassObject PRIVATE
DllCanUnloadNow PRIVATE

View File

@ -0,0 +1,23 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
WslaServiceProxyStub.rc
Abstract:
This file contains resources for wslaserviceproxystub.dll.
--*/
#include <windows.h>
#include "wslversioninfo.h"
#define VER_INTERNALNAME_STR "wslaserviceproxystub.dll"
#define VER_ORIGINALFILENAME_STR "wslaserviceproxystub.dll"
#define VER_FILEDESCRIPTION_STR "WSLA Service ProxyStub DLL"
#include <common.ver>

View File

@ -1316,6 +1316,17 @@ void StopWslService()
StopService(service.get());
}
void StopWslaService()
{
LogInfo("Stopping WSLAService");
const wil::unique_schandle manager{OpenSCManager(nullptr, nullptr, SC_MANAGER_CONNECT)};
VERIFY_IS_NOT_NULL(manager);
const wil::unique_schandle service{OpenService(manager.get(), L"wslaservice", SERVICE_STOP | SERVICE_QUERY_STATUS)};
VERIFY_IS_NOT_NULL(service);
StopService(service.get());
}
wil::unique_handle GetNonElevatedToken()
{
const auto token = wil::open_current_access_token(TOKEN_ALL_ACCESS);

View File

@ -523,6 +523,7 @@ inline auto EnableSystemd(const std::string& extraConfig = "")
std::wstring EscapePath(std::wstring_view Path);
void StopWslService();
void StopWslaService();
std::optional<GUID> GetDistributionId(LPCWSTR Name);
wil::unique_hkey OpenDistributionKey(LPCWSTR Name);

View File

@ -821,7 +821,7 @@ class WSLATests
});
// Stop the service
StopWslService();
StopWslaService();
// Verify that the thread is unstuck
stuckThread.join();