wsla: Create wslaservice.exe (#13623)

* Save state

* Save state

* Save state

* Save state

* Save state

* Save state

* Save state

* Save state

* Save state

* Cleanup for review

* Update ServiceMain.cpp comment

* Remove duplicated definitions from wslservice.idl
This commit is contained in:
Blue 2025-10-23 17:44:57 -07:00 committed by GitHub
parent 3f24caaf73
commit 451a7e103a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
57 changed files with 4209 additions and 2480 deletions

View File

@ -388,6 +388,7 @@ include_directories(${WSLDEPS_SOURCE_DIR}/include/lxcore)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/shared/inc) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/shared/inc)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/windows/inc) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/windows/inc)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/src/windows/service/inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}) include_directories(${CMAKE_CURRENT_BINARY_DIR}/src/windows/service/inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE})
include_directories(${CMAKE_CURRENT_BINARY_DIR}/src/windows/wslaservice/inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE})
include_directories(${CMAKE_CURRENT_BINARY_DIR}/src/windows/wslinstaller/inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}) include_directories(${CMAKE_CURRENT_BINARY_DIR}/src/windows/wslinstaller/inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/linux/init/inc) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/linux/init/inc)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/windows/common) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/windows/common)
@ -413,6 +414,7 @@ add_subdirectory(msipackage)
add_subdirectory(msixinstaller) add_subdirectory(msixinstaller)
add_subdirectory(src/windows/common) add_subdirectory(src/windows/common)
add_subdirectory(src/windows/service) add_subdirectory(src/windows/service)
add_subdirectory(src/windows/wslaservice)
add_subdirectory(src/windows/wslinstaller/inc) add_subdirectory(src/windows/wslinstaller/inc)
add_subdirectory(src/windows/wslinstaller/stub) add_subdirectory(src/windows/wslinstaller/stub)
add_subdirectory(src/windows/wslinstaller/exe) add_subdirectory(src/windows/wslinstaller/exe)

View File

@ -3,7 +3,7 @@ set(OUTPUT_PACKAGE ${BIN}/wsl.msi)
set(PACKAGE_WIX_IN ${CMAKE_CURRENT_LIST_DIR}/package.wix.in) set(PACKAGE_WIX_IN ${CMAKE_CURRENT_LIST_DIR}/package.wix.in)
set(PACKAGE_WIX ${BIN}/package.wix) set(PACKAGE_WIX ${BIN}/package.wix)
set(CAB_CACHE ${BIN}/cab) set(CAB_CACHE ${BIN}/cab)
set(BINARIES wsl.exe;wslg.exe;wslhost.exe;wslrelay.exe;wslservice.exe;wslserviceproxystub.dll;init;initrd.img;wslinstall.dll) set(BINARIES wsl.exe;wslg.exe;wslhost.exe;wslrelay.exe;wslservice.exe;wslserviceproxystub.dll;init;initrd.img;wslinstall.dll;wslaserviceproxystub.dll;wslaservice.exe)
if (WSL_BUILD_WSL_SETTINGS) if (WSL_BUILD_WSL_SETTINGS)
list(APPEND BINARIES_DEPENDENCIES "wslsettings/wslsettings.dll;wslsettings/wslsettings.exe;libwsl.dll") list(APPEND BINARIES_DEPENDENCIES "wslsettings/wslsettings.dll;wslsettings/wslsettings.exe;libwsl.dll")
@ -39,7 +39,7 @@ add_custom_command(
add_custom_target(msipackage DEPENDS ${OUTPUT_PACKAGE}) add_custom_target(msipackage DEPENDS ${OUTPUT_PACKAGE})
set_target_properties(msipackage PROPERTIES EXCLUDE_FROM_ALL FALSE SOURCES ${PACKAGE_WIX_IN}) set_target_properties(msipackage PROPERTIES EXCLUDE_FROM_ALL FALSE SOURCES ${PACKAGE_WIX_IN})
add_dependencies(msipackage wsl wslg wslservice wslhost wslrelay wslserviceproxystub init initramfs wslinstall msixgluepackage) add_dependencies(msipackage wsl wslg wslservice wslhost wslrelay wslserviceproxystub init initramfs wslinstall msixgluepackage wslaservice wslaserviceproxystub)
if (WSL_BUILD_WSL_SETTINGS) if (WSL_BUILD_WSL_SETTINGS)
add_dependencies(msipackage wslsettings libwsl) add_dependencies(msipackage wslsettings libwsl)

View File

@ -50,75 +50,75 @@
<Component Id="explorerplan9shortcut" Guid="{93CBFF23-A04C-4344-A332-238CE5B97AED}" UninstallWhenSuperseded="yes" DisableRegistryReflection="yes" Bitness="always64"> <Component Id="explorerplan9shortcut" Guid="{93CBFF23-A04C-4344-A332-238CE5B97AED}" UninstallWhenSuperseded="yes" DisableRegistryReflection="yes" Bitness="always64">
<!-- Explorer extensions --> <!-- Explorer extensions -->
<RegistryKey Root="HKLM" Key="SOFTWARE\Classes\CLSID\{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}"> <RegistryKey Root="HKLM" Key="SOFTWARE\Classes\CLSID\{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}">
<RegistryValue Value="Linux" Type="string"/> <RegistryValue Value="Linux" Type="string"/>
<RegistryValue Name="SortOrderIndex" Value="119" Type="integer"/> <RegistryValue Name="SortOrderIndex" Value="119" Type="integer"/>
<!--0x77--> <!--0x77-->
<RegistryValue Name="System.IsPinnedToNameSpaceTree" Value="1" Type="integer"/> <RegistryValue Name="System.IsPinnedToNameSpaceTree" Value="1" Type="integer"/>
<RegistryKey Key="DefaultIcon"> <RegistryKey Key="DefaultIcon">
<RegistryValue Value="[System64Folder]wsl.exe,-1" Type="string"/> <RegistryValue Value="[System64Folder]wsl.exe,-1" Type="string"/>
</RegistryKey> </RegistryKey>
<RegistryKey Key="InProcServer32"> <RegistryKey Key="InProcServer32">
<RegistryValue Value="[System64Folder]windows.storage.dll" Type="string"/> <RegistryValue Value="[System64Folder]windows.storage.dll" Type="string"/>
</RegistryKey> </RegistryKey>
<RegistryKey Key="ShellFolder"> <RegistryKey Key="ShellFolder">
<RegistryValue Name="Attributes" Value="2692743245" Type="integer"/> <RegistryValue Name="Attributes" Value="2692743245" Type="integer"/>
<!--0xa080004d"--> <!--0xa080004d"-->
<RegistryValue Name="FolderValueFlags" Value="40" Type="integer"/> <RegistryValue Name="FolderValueFlags" Value="40" Type="integer"/>
<!--0x28--> <!--0x28-->
</RegistryKey> </RegistryKey>
<RegistryKey Key="Instance"> <RegistryKey Key="Instance">
<RegistryValue Name="CLSID" Value="{4FE04BFD-85B9-49DD-B914-F4C9556B9DA6}" Type="string"/> <RegistryValue Name="CLSID" Value="{4FE04BFD-85B9-49DD-B914-F4C9556B9DA6}" Type="string"/>
<RegistryKey Key="InitPropertyBag"> <RegistryKey Key="InitPropertyBag">
<RegistryValue Name="DisplayType" Value="2" Type="integer"/> <RegistryValue Name="DisplayType" Value="2" Type="integer"/>
<RegistryValue Name="EnumObjectsTelemetryValue" Value="WSL" Type="string"/> <RegistryValue Name="EnumObjectsTelemetryValue" Value="WSL" Type="string"/>
<RegistryValue Name="Provider" Value="Plan 9 Network Provider" Type="string"/> <RegistryValue Name="Provider" Value="Plan 9 Network Provider" Type="string"/>
<RegistryValue Name="ResName" Value="\\wsl.localhost" Type="string"/> <RegistryValue Name="ResName" Value="\\wsl.localhost" Type="string"/>
</RegistryKey>
</RegistryKey> </RegistryKey>
</RegistryKey> </RegistryKey>
</RegistryKey>
<RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\HideDesktopIcons\NewStartPanel"> <RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\HideDesktopIcons\NewStartPanel">
<RegistryValue Name="{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}" Value="1" Type="integer"/> <RegistryValue Name="{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}" Value="1" Type="integer"/>
</RegistryKey> </RegistryKey>
<RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\Desktop\NameSpace\{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}"> <RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\Desktop\NameSpace\{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}">
<RegistryValue Value="Linux" Type="string"/> <RegistryValue Value="Linux" Type="string"/>
</RegistryKey> </RegistryKey>
<RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\IdListAliasTranslations\WSL"> <RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\IdListAliasTranslations\WSL">
<RegistryValue Name="Target" Value="::{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}" Type="string"/> <RegistryValue Name="Target" Value="::{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}" Type="string"/>
<RegistryValue Name="Source" Value="\\wsl.localhost" Type="string"/> <RegistryValue Name="Source" Value="\\wsl.localhost" Type="string"/>
</RegistryKey> </RegistryKey>
<RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\IdListAliasTranslations\WSLLegacy"> <RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\IdListAliasTranslations\WSLLegacy">
<RegistryValue Name="Target" Value="::{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}" Type="string"/> <RegistryValue Name="Target" Value="::{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}" Type="string"/>
<RegistryValue Name="Source" Value="\\wsl$" Type="string"/> <RegistryValue Name="Source" Value="\\wsl$" Type="string"/>
</RegistryKey> </RegistryKey>
</Component> </Component>
<Component Id="explorershell" Guid="{93CBFF23-A04C-4344-A332-238CE5B97AEC}" UninstallWhenSuperseded="yes" DisableRegistryReflection="yes" Bitness="always64"> <Component Id="explorershell" Guid="{93CBFF23-A04C-4344-A332-238CE5B97AEC}" UninstallWhenSuperseded="yes" DisableRegistryReflection="yes" Bitness="always64">
<?foreach PATH in SOFTWARE\Classes\Directory\shell\WSL;SOFTWARE\Classes\Directory\Background\shell\WSL;SOFTWARE\Classes\Drive\shell\WSL?> <?foreach PATH in SOFTWARE\Classes\Directory\shell\WSL;SOFTWARE\Classes\Directory\Background\shell\WSL;SOFTWARE\Classes\Drive\shell\WSL?>
<RegistryKey Root="HKLM" Key="$(var.PATH)"> <RegistryKey Root="HKLM" Key="$(var.PATH)">
<RegistryValue Value="@wsl.exe,-2" Type="string"/> <RegistryValue Value="@wsl.exe,-2" Type="string"/>
<RegistryValue Name="Extended" Value="" Type="string"/> <RegistryValue Name="Extended" Value="" Type="string"/>
<RegistryValue Name="NoWorkingDirectory" Value="" Type="string"/> <RegistryValue Name="NoWorkingDirectory" Value="" Type="string"/>
<RegistryKey Key="command"> <RegistryKey Key="command">
<RegistryValue Value='wsl.exe --cd "%V"' Type="string"/> <RegistryValue Value='wsl.exe --cd "%V"' Type="string"/>
</RegistryKey>
</RegistryKey> </RegistryKey>
</RegistryKey> <?endforeach?>
<?endforeach?>
<ProgId Id="WSLDistributionTar" Description="WSL tar distribution" Icon="wsl.exe"> <ProgId Id="WSLDistributionTar" Description="WSL tar distribution" Icon="wsl.exe">
<Extension Id="wsl"> <Extension Id="wsl">
<Verb Id="open" Command="open" TargetFile="wsl.exe" Argument="--install --prompt-before-exit --from-file &quot;%1&quot;" /> <Verb Id="open" Command="open" TargetFile="wsl.exe" Argument="--install --prompt-before-exit --from-file &quot;%1&quot;" />
</Extension> </Extension>
</ProgId> </ProgId>
</Component> </Component>
<Component Id="wslservice" Guid="F0C8D6BA-1502-41E7-BF72-D93DFA134735" UninstallWhenSuperseded="yes" DisableRegistryReflection="yes" Bitness="always64"> <Component Id="wslservice" Guid="F0C8D6BA-1502-41E7-BF72-D93DFA134735" UninstallWhenSuperseded="yes" DisableRegistryReflection="yes" Bitness="always64">
@ -131,33 +131,13 @@
</RegistryKey> </RegistryKey>
</RegistryKey> </RegistryKey>
<RegistryKey Root="HKCR" Key="Interface\{82A7ABC8-6B50-43FC-AB96-15FBBE7E8760}"> <!-- WSLServiceProxyStub. -->
<RegistryValue Value="IWSLAUserSession" Type="string" />
<RegistryKey Key="ProxyStubClsid32">
<RegistryValue Value="{4EA0C6DD-E9FF-48E7-994E-13A31D10DC60}" Type="string" />
</RegistryKey>
</RegistryKey>
<RegistryKey Root="HKCR" Key="Interface\{82A7ABC8-6B50-43FC-AB96-15FBBE7E8761}">
<RegistryValue Value="IWSLAVirtualMachine" Type="string" />
<RegistryKey Key="ProxyStubClsid32">
<RegistryValue Value="{4EA0C6DD-E9FF-48E7-994E-13A31D10DC60}" Type="string" />
</RegistryKey>
</RegistryKey>
<RegistryKey Root="HKCR" Key="Interface\{7BC4E198-6531-4FA6-ADE2-5EF3D2A04DFE}">
<RegistryValue Value="ITerminationCallback" Type="string" />
<RegistryKey Key="ProxyStubClsid32">
<RegistryValue Value="{4EA0C6DD-E9FF-48E7-994E-13A31D10DC60}" Type="string" />
</RegistryKey>
</RegistryKey>
<RegistryKey Root="HKCR" Key="CLSID\{4EA0C6DD-E9FF-48E7-994E-13A31D10DC60}"> <RegistryKey Root="HKCR" Key="CLSID\{4EA0C6DD-E9FF-48E7-994E-13A31D10DC60}">
<RegistryValue Value="PSFactoryBuffer" Type="string" /> <RegistryValue Value="PSFactoryBuffer" Type="string" />
</RegistryKey> <RegistryKey Key="InProcServer32">
<RegistryKey Root="HKCR" Key="CLSID\{4EA0C6DD-E9FF-48E7-994E-13A31D10DC60}\InProcServer32"> <RegistryValue Value="[INSTALLDIR]wslserviceproxystub.dll" Type="string" />
<RegistryValue Value="[INSTALLDIR]wslserviceproxystub.dll" Type="string" /> <RegistryValue Name="ThreadingModel" Value="Both" Type="string" />
<RegistryValue Name="ThreadingModel" Value="Both" Type="string" /> </RegistryKey>
</RegistryKey> </RegistryKey>
<!-- ILxssUserSession --> <!-- ILxssUserSession -->
@ -175,18 +155,6 @@
<RegistryValue Value="LxssUserSession" Type="string" /> <RegistryValue Value="LxssUserSession" Type="string" />
</RegistryKey> </RegistryKey>
<!-- WSLAUserSession -->
<RegistryKey Root="HKCR" Key="CLSID\{a9b7a1b9-0671-405c-95f1-e0612cb4ce8f}">
<RegistryValue Name="AppId" Value="{370121D2-AA7E-4608-A86D-0BBAB9DA1A60}" Type="string" />
<RegistryValue Value="WSLAUserSession" Type="string" />
</RegistryKey>
<!-- WSLAVirtualMachine -->
<RegistryKey Root="HKCR" Key="CLSID\{0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30}">
<RegistryValue Name="AppId" Value="{370121D2-AA7E-4608-A86D-0BBAB9DA1A60}" Type="string" />
<RegistryValue Value="WSLAVirtualMachine" Type="string" />
</RegistryKey>
<!-- Notification server --> <!-- Notification server -->
<RegistryKey Root="HKCR" Key="CLSID\{2B9C59C3-98F1-45C8-B87B-12AE3C7927E8}\LocalServer32"> <RegistryKey Root="HKCR" Key="CLSID\{2B9C59C3-98F1-45C8-B87B-12AE3C7927E8}\LocalServer32">
<RegistryValue Value='"[INSTALLDIR]wslhost.exe"' Type="string" /> <RegistryValue Value='"[INSTALLDIR]wslhost.exe"' Type="string" />
@ -226,7 +194,7 @@
<RegistryKey Root="HKCR" Key="AppID\{7F82AD86-755B-4870-86B1-D2E68DFE8A49}"> <RegistryKey Root="HKCR" Key="AppID\{7F82AD86-755B-4870-86B1-D2E68DFE8A49}">
<RegistryValue Name="DllSurrogate" Value="" Type="string"/> <RegistryValue Name="DllSurrogate" Value="" Type="string"/>
<RegistryValue Name="AppIDFlags" Value="2048" Type="integer"/><!--0x800--> <RegistryValue Name="AppIDFlags" Value="2048" Type="integer"/> <!--0x800-->
<!-- O:BAG:BAD:(A;;CCDCSW;;;AU)(A;;CCDCSW;;;PS)(A;;CCDCSW;;;SY) --> <!-- O:BAG:BAD:(A;;CCDCSW;;;AU)(A;;CCDCSW;;;PS)(A;;CCDCSW;;;SY) -->
<RegistryValue Name="AccessPermission" Value="01000480580000006800000000000000140000000200440003000000000014000B00000001010000000000050B000000000014000B00000001010000000000050A000000000014000B0000000101000000000005120000000102000000000005200000002002000001020000000000052000000020020000" Type="binary" /> <RegistryValue Name="AccessPermission" Value="01000480580000006800000000000000140000000200440003000000000014000B00000001010000000000050B000000000014000B00000001010000000000050A000000000014000B0000000101000000000005120000000102000000000005200000002002000001020000000000052000000020020000" Type="binary" />
@ -265,6 +233,68 @@
<ServiceControl Id="StopService" Stop="both" Remove="uninstall" Name="WSLService" Wait="yes" /> <ServiceControl Id="StopService" Stop="both" Remove="uninstall" Name="WSLService" Wait="yes" />
</Component> </Component>
<Component Id="wslaservice" Guid="DC97520E-BFA5-4559-960F-D580E151629F" UninstallWhenSuperseded="yes" DisableRegistryReflection="yes" Bitness="always64">
<!-- WSLAServiceProxyStub. -->
<RegistryKey Root="HKCR" Key="CLSID\{4EA0C6DD-E9FF-48E7-994E-13A31D10DC61}">
<RegistryValue Value="PSFactoryBuffer" Type="string" />
<RegistryKey Key="InProcServer32">
<RegistryValue Value="[INSTALLDIR]wslaserviceproxystub.dll" Type="string" />
<RegistryValue Name="ThreadingModel" Value="Both" Type="string" />
</RegistryKey>
</RegistryKey>
<!-- WSLAService COM app -->
<RegistryKey Root="HKCR" Key="AppID\{E9B79997-57E3-4201-AECC-6A464E530DD2}">
<!-- O:BAG:BAD:(A;;CCDCSW;;;AU)(A;;CCDCSW;;;PS)(A;;CCDCSW;;;SY) -->
<RegistryValue Name="AccessPermission" Value="01000480580000006800000000000000140000000200440003000000000014000B00000001010000000000050B000000000014000B00000001010000000000050A000000000014000B0000000101000000000005120000000102000000000005200000002002000001020000000000052000000020020000" Type="binary" />
<RegistryValue Name="LaunchPermission" Value="01000480580000006800000000000000140000000200440003000000000014000B00000001010000000000050B000000000014000B00000001010000000000050A000000000014000B0000000101000000000005120000000102000000000005200000002002000001020000000000052000000020020000" Type="binary" />
<RegistryValue Name="LocalService" Value="WSLAService" Type="string" />
</RegistryKey>
<!-- WSLAUserSession -->
<RegistryKey Root="HKCR" Key="CLSID\{a9b7a1b9-0671-405c-95f1-e0612cb4ce8f}">
<RegistryValue Name="AppId" Value="{E9B79997-57E3-4201-AECC-6A464E530DD2}" Type="string" />
<RegistryValue Value="WSLAUserSession" Type="string" />
</RegistryKey>
<!-- WSLAVirtualMachine -->
<RegistryKey Root="HKCR" Key="CLSID\{0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30}">
<RegistryValue Name="AppId" Value="{E9B79997-57E3-4201-AECC-6A464E530DD2}" Type="string" />
<RegistryValue Value="WSLAVirtualMachine" Type="string" />
</RegistryKey>
<!-- IWSLAUserSession-->
<RegistryKey Root="HKCR" Key="Interface\{82A7ABC8-6B50-43FC-AB96-15FBBE7E8760}">
<RegistryValue Value="IWSLAUserSession" Type="string" />
<RegistryKey Key="ProxyStubClsid32">
<RegistryValue Value="{4EA0C6DD-E9FF-48E7-994E-13A31D10DC61}" Type="string" />
</RegistryKey>
</RegistryKey>
<!-- IWSLAVirtualMachine-->
<RegistryKey Root="HKCR" Key="Interface\{82A7ABC8-6B50-43FC-AB96-15FBBE7E8761}">
<RegistryValue Value="IWSLAVirtualMachine" Type="string" />
<RegistryKey Key="ProxyStubClsid32">
<RegistryValue Value="{4EA0C6DD-E9FF-48E7-994E-13A31D10DC61}" Type="string" />
</RegistryKey>
</RegistryKey>
<!-- ITerminationCallback-->
<RegistryKey Root="HKCR" Key="Interface\{7BC4E198-6531-4FA6-ADE2-5EF3D2A04DFE}">
<RegistryValue Value="ITerminationCallback" Type="string" />
<RegistryKey Key="ProxyStubClsid32">
<RegistryValue Value="{4EA0C6DD-E9FF-48E7-994E-13A31D10DC61}" Type="string" />
</RegistryKey>
</RegistryKey>
<File Id="wslaservice.exe" Source="${BIN}/wslaservice.exe" KeyPath="yes" />
<File Id="wslaserviceproxystub.dll" Name="wslaserviceproxystub.dll" Source="${BIN}/wslaserviceproxystub.dll" />
<ServiceInstall Name="WSLAService" DisplayName="WSLA Service" Description="WSLA Service" Start="auto" Type="ownProcess" ErrorControl="normal" Account="LocalSystem" Vital="yes" Interactive="no" />
<ServiceControl Id="StopWSLAService" Stop="both" Remove="uninstall" Name="WSLAService" Wait="yes" />
</Component>
<Component Id="wslg" Guid="F0C8D6BA-1502-41E7-BF72-D93DFA134731" UninstallWhenSuperseded="yes" DisableRegistryReflection="yes" Bitness="always64"> <Component Id="wslg" Guid="F0C8D6BA-1502-41E7-BF72-D93DFA134731" UninstallWhenSuperseded="yes" DisableRegistryReflection="yes" Bitness="always64">
<?if "${WSL_DEV_BINARY_PATH}" = "" ?> <?if "${WSL_DEV_BINARY_PATH}" = "" ?>
<File Id="msrdc.exe" Source="${MSRDC_SOURCE_DIR}/${TARGET_PLATFORM}/msrdc.exe" /> <File Id="msrdc.exe" Source="${MSRDC_SOURCE_DIR}/${TARGET_PLATFORM}/msrdc.exe" />
@ -402,6 +432,7 @@
<Feature Id="WSL" Title="Windows Subsystem for Linux" Level="1"> <Feature Id="WSL" Title="Windows Subsystem for Linux" Level="1">
<ComponentRef Id="wsl" /> <ComponentRef Id="wsl" />
<ComponentRef Id="wslservice" /> <ComponentRef Id="wslservice" />
<ComponentRef Id="wslaservice" />
<ComponentRef Id="wslg" /> <ComponentRef Id="wslg" />
<ComponentRef Id="tools" /> <ComponentRef Id="tools" />
<ComponentRef Id="explorershell" /> <ComponentRef Id="explorershell" />
@ -501,21 +532,21 @@
Return="check" Return="check"
Execute="deferred" Execute="deferred"
/> />
<CustomAction Id="RemoveRegistryKeyProtections" <CustomAction Id="RemoveRegistryKeyProtections"
Impersonate="no" Impersonate="no"
BinaryRef="wslinstall.dll" BinaryRef="wslinstall.dll"
DllEntry="RemoveRegistryKeyProtections" DllEntry="RemoveRegistryKeyProtections"
Return="check" Return="check"
Execute="deferred" Execute="deferred"
/> />
<CustomAction Id="UnregisterLspCategories" <CustomAction Id="UnregisterLspCategories"
Impersonate="no" Impersonate="no"
BinaryRef="wslinstall.dll" BinaryRef="wslinstall.dll"
DllEntry="UnregisterLspCategories" DllEntry="UnregisterLspCategories"
Return="check" Return="check"
Execute="deferred" Execute="deferred"
/> />
<CustomAction Id="InstallMsix.SetProperty" Return="check" Property="InstallMsix" Value='[DATABASE]' Execute='immediate' /> <CustomAction Id="InstallMsix.SetProperty" Return="check" Property="InstallMsix" Value='[DATABASE]' Execute='immediate' />

View File

@ -1,83 +1,105 @@
set(SOURCES set(SOURCES
ConsoleProgressBar.cpp ConsoleProgressBar.cpp
ConsoleProgressIndicator.cpp ConsoleProgressIndicator.cpp
DeviceHostProxy.cpp
DeviceHostProxy.h
disk.cpp disk.cpp
Distribution.cpp Distribution.cpp
Dmesg.cpp
Dmesg.h
DnsResolver.cpp
DnsTunnelingChannel.cpp
ExecutionContext.cpp
filesystem.cpp filesystem.cpp
GnsChannel.cpp
HandleConsoleProgressBar.cpp HandleConsoleProgressBar.cpp
hcs.cpp hcs.cpp
helpers.cpp helpers.cpp
interop.cpp
ExecutionContext.cpp
socket.cpp
hvsocket.cpp hvsocket.cpp
interop.cpp
Localization.cpp Localization.cpp
lxssbusclient.cpp lxssbusclient.cpp
lxssclient.cpp lxssclient.cpp
LxssMessagePort.cpp LxssMessagePort.cpp
LxssSecurity.cpp
LxssServerPort.cpp LxssServerPort.cpp
NatNetworking.cpp
notifications.cpp
Redirector.cpp Redirector.cpp
registry.cpp registry.cpp
relay.cpp relay.cpp
RingBuffer.cpp
socket.cpp
string.cpp string.cpp
SubProcess.cpp SubProcess.cpp
svccomm.cpp svccomm.cpp
svccommio.cpp svccommio.cpp
WslClient.cpp WslClient.cpp
WslCoreConfig.cpp WslCoreConfig.cpp
WslCoreFilesystem.cpp
WslCoreFirewallSupport.cpp WslCoreFirewallSupport.cpp
WslCoreHostDnsInfo.cpp
WslCoreNetworkEndpointSettings.cpp
WslCoreNetworkingSupport.cpp WslCoreNetworkingSupport.cpp
WslInstall.cpp WslInstall.cpp
WslSecurity.cpp WslSecurity.cpp
WslTelemetry.cpp WslTelemetry.cpp
wslutil.cpp wslutil.cpp
notifications.cpp) )
set(HEADERS set(HEADERS
../../../generated/Localization.h ../../../generated/Localization.h
../../shared/inc/CommandLine.h
../../shared/inc/defs.h
../../shared/inc/lxfsshares.h
../../shared/inc/lxinitshared.h
../../shared/inc/SocketChannel.h
../../shared/inc/socketshared.h
../../shared/inc/hns_schema.h
../../shared/inc/JsonUtils.h
../../shared/inc/stringshared.h
../../shared/inc/retryshared.h
../../shared/inc/message.h
../../shared/inc/prettyprintshared.h
../inc/WslPluginApi.h
../inc/wslpolicies.h
../inc/lxssbusclient.h ../inc/lxssbusclient.h
../inc/lxssclient.h ../inc/lxssclient.h
../inc/LxssDynamicFunction.h ../inc/LxssDynamicFunction.h
../inc/traceloggingconfig.h ../inc/traceloggingconfig.h
../inc/wdk.h ../inc/wdk.h
../inc/wsl.h
../inc/wslconfig.h ../inc/wslconfig.h
../inc/wsl.h
../inc/wslhost.h ../inc/wslhost.h
../inc/WslPluginApi.h
../inc/wslpolicies.h
../inc/wslrelay.h ../inc/wslrelay.h
../../shared/inc/CommandLine.h
../../shared/inc/defs.h
../../shared/inc/hns_schema.h
../../shared/inc/JsonUtils.h
../../shared/inc/lxfsshares.h
../../shared/inc/lxinitshared.h
../../shared/inc/message.h
../../shared/inc/prettyprintshared.h
../../shared/inc/retryshared.h
../../shared/inc/SocketChannel.h
../../shared/inc/socketshared.h
../../shared/inc/stringshared.h
ConsoleProgressBar.h ConsoleProgressBar.h
ConsoleProgressIndicator.h ConsoleProgressIndicator.h
disk.hpp disk.hpp
Distribution.h Distribution.h
DnsResolver.h
DnsTunnelingChannel.h
ExecutionContext.h
filesystem.hpp filesystem.hpp
GnsChannel.h
HandleConsoleProgressBar.h
hcs.hpp hcs.hpp
hcs_schema.h hcs_schema.h
helpers.hpp helpers.hpp
HandleConsoleProgressBar.h
interop.hpp
ExecutionContext.h
socket.hpp
hvsocket.hpp hvsocket.hpp
INetworkingEngine.h
interop.hpp
LxssMessagePort.h LxssMessagePort.h
LxssPort.h LxssPort.h
LxssSecurity.h
LxssServerPort.h LxssServerPort.h
NatNetworking.h
notifications.h
precomp.h precomp.h
Redirector.h Redirector.h
registry.hpp registry.hpp
relay.hpp relay.hpp
RingBuffer.h
socket.hpp
string.hpp string.hpp
Stringify.h Stringify.h
SubProcess.h SubProcess.h
@ -85,13 +107,17 @@ set(HEADERS
svccommio.hpp svccommio.hpp
WslClient.h WslClient.h
WslCoreConfig.h WslCoreConfig.h
WslCoreFilesystem.h
WslCoreFirewallSupport.h WslCoreFirewallSupport.h
WslCoreHostDnsInfo.h
WslCoreMessageQueue.h
WslCoreNetworkEndpointSettings.h
WslCoreNetworkingSupport.h WslCoreNetworkingSupport.h
WslInstall.h WslInstall.h
WslSecurity.h WslSecurity.h
WslTelemetry.h WslTelemetry.h
wslutil.h wslutil.cpp
notifications.h) )
add_library(common STATIC ${SOURCES} ${HEADERS}) add_library(common STATIC ${SOURCES} ${HEADERS})
add_dependencies(common wslserviceidl localization wslservicemc wslinstalleridl) add_dependencies(common wslserviceidl localization wslservicemc wslinstalleridl)

View File

@ -1,261 +1,261 @@
// Copyright (C) Microsoft Corporation. All rights reserved. // Copyright (C) Microsoft Corporation. All rights reserved.
#include "precomp.h" #include "precomp.h"
#include "DeviceHostProxy.h" #include "DeviceHostProxy.h"
// This template works around a limitation with decltype on overloaded functions. It will be able // This template works around a limitation with decltype on overloaded functions. It will be able
// to get the correct version of GetVmWorkerProcess based on the provided type arguments. By // to get the correct version of GetVmWorkerProcess based on the provided type arguments. By
// doing it this way, a compiler error will be generated if someone changes the signature of // doing it this way, a compiler error will be generated if someone changes the signature of
// GetVmWorkerProcess. // GetVmWorkerProcess.
// //
// The way this works: decltype(GetVmWorkerProcess) does not work because it's overloaded. // The way this works: decltype(GetVmWorkerProcess) does not work because it's overloaded.
// decltype(GetVmWorkerProcess(arg1, ...)) works to select an overload if you have values of the // decltype(GetVmWorkerProcess(arg1, ...)) works to select an overload if you have values of the
// correct type (std::declval<T>() generates a value of the specified type), however the result // correct type (std::declval<T>() generates a value of the specified type), however the result
// of that is the function's return type, not the function's type, so the argument types must // of that is the function's return type, not the function's type, so the argument types must
// be repeated to reconstruct the function type. // be repeated to reconstruct the function type.
template <typename... Args> template <typename... Args>
using GetVmWorkerProcessType = decltype(GetVmWorkerProcess(std::declval<Args>()...))(Args...); using GetVmWorkerProcessType = decltype(GetVmWorkerProcess(std::declval<Args>()...))(Args...);
// Limit the number of allowed doorbells registered by an external HDV vdev. Currently virtio-9p only uses // Limit the number of allowed doorbells registered by an external HDV vdev. Currently virtio-9p only uses
// one doorbell and wsldevicehost uses only two. // one doorbell and wsldevicehost uses only two.
#define DEVICE_HOST_PROXY_DOORBELL_LIMIT 8 #define DEVICE_HOST_PROXY_DOORBELL_LIMIT 8
using namespace wsl::windows::common::hcs; using namespace wsl::windows::common::hcs;
DeviceHostProxy::DeviceHostProxy(const std::wstring& VmId, const GUID& RuntimeId) : DeviceHostProxy::DeviceHostProxy(const std::wstring& VmId, const GUID& RuntimeId) :
m_systemId{VmId}, m_runtimeId{RuntimeId}, m_system{wsl::windows::common::hcs::OpenComputeSystem(VmId.c_str(), GENERIC_ALL)}, m_shutdown{false} m_systemId{VmId}, m_runtimeId{RuntimeId}, m_system{wsl::windows::common::hcs::OpenComputeSystem(VmId.c_str(), GENERIC_ALL)}, m_shutdown{false}
{ {
m_devicesShutdown = false; m_devicesShutdown = false;
} }
GUID DeviceHostProxy::AddNewDevice(const GUID& Type, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs, const std::wstring& VirtIoTag) GUID DeviceHostProxy::AddNewDevice(const GUID& Type, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs, const std::wstring& VirtIoTag)
{ {
const wrl::ComPtr<IUnknown> thisUnknown{CastToUnknown()}; const wrl::ComPtr<IUnknown> thisUnknown{CastToUnknown()};
GUID instanceId{}; GUID instanceId{};
THROW_IF_FAILED(UuidCreate(&instanceId)); THROW_IF_FAILED(UuidCreate(&instanceId));
// Tell the device host to create the device. // Tell the device host to create the device.
THROW_IF_FAILED(Plan9Fs->CreateVirtioDevice(m_systemId.c_str(), thisUnknown.Get(), VirtIoTag.c_str(), &instanceId)); THROW_IF_FAILED(Plan9Fs->CreateVirtioDevice(m_systemId.c_str(), thisUnknown.Get(), VirtIoTag.c_str(), &instanceId));
// Add the instance ID to the list of known devices. This must be done before the device is // Add the instance ID to the list of known devices. This must be done before the device is
// added to the system, because doing that can cause the register doorbell function to be // added to the system, because doing that can cause the register doorbell function to be
// called. // called.
// N.B. It will be removed if there is a failure. // N.B. It will be removed if there is a failure.
{ {
auto lock = m_devicesLock.lock_exclusive(); auto lock = m_devicesLock.lock_exclusive();
THROW_HR_IF(E_CHANGED_STATE, m_devicesShutdown); THROW_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
m_devices.emplace(instanceId, DeviceHostProxyEntry{}); m_devices.emplace(instanceId, DeviceHostProxyEntry{});
} }
auto removeOnFailure = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { auto removeOnFailure = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() {
auto lock = m_devicesLock.lock_exclusive(); auto lock = m_devicesLock.lock_exclusive();
m_devices.erase(instanceId); m_devices.erase(instanceId);
}); });
// Add the device to the compute system on behalf of the device host. // Add the device to the compute system on behalf of the device host.
ModifySettingRequest<FlexibleIoDevice> request; ModifySettingRequest<FlexibleIoDevice> request;
request.RequestType = ModifyRequestType::Add; request.RequestType = ModifyRequestType::Add;
request.ResourcePath = L"VirtualMachine/Devices/FlexibleIov/"; request.ResourcePath = L"VirtualMachine/Devices/FlexibleIov/";
request.ResourcePath += wsl::shared::string::GuidToString<wchar_t>(instanceId, wsl::shared::string::GuidToStringFlags::None); request.ResourcePath += wsl::shared::string::GuidToString<wchar_t>(instanceId, wsl::shared::string::GuidToStringFlags::None);
request.Settings.EmulatorId = Type; request.Settings.EmulatorId = Type;
request.Settings.HostingModel = FlexibleIoDeviceHostingModel::ExternalRestricted; request.Settings.HostingModel = FlexibleIoDeviceHostingModel::ExternalRestricted;
wsl::windows::common::hcs::ModifyComputeSystem(m_system.get(), wsl::shared::ToJsonW(request).c_str()); wsl::windows::common::hcs::ModifyComputeSystem(m_system.get(), wsl::shared::ToJsonW(request).c_str());
removeOnFailure.release(); removeOnFailure.release();
return instanceId; return instanceId;
} }
void DeviceHostProxy::AddRemoteFileSystem(const GUID& ImplementationClsid, const std::wstring& Tag, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs) void DeviceHostProxy::AddRemoteFileSystem(const GUID& ImplementationClsid, const std::wstring& Tag, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs)
{ {
auto lock = m_lock.lock_exclusive(); auto lock = m_lock.lock_exclusive();
THROW_HR_IF(E_CHANGED_STATE, m_shutdown); THROW_HR_IF(E_CHANGED_STATE, m_shutdown);
// Make sure there are no duplicate tags. // Make sure there are no duplicate tags.
for (auto& entry : m_fileSystems) for (auto& entry : m_fileSystems)
{ {
THROW_HR_IF(E_INVALIDARG, entry.ImplementationClsid == ImplementationClsid && entry.Tag == Tag); THROW_HR_IF(E_INVALIDARG, entry.ImplementationClsid == ImplementationClsid && entry.Tag == Tag);
} }
m_fileSystems.emplace_back(ImplementationClsid, Tag, Plan9Fs); m_fileSystems.emplace_back(ImplementationClsid, Tag, Plan9Fs);
} }
wil::com_ptr<IPlan9FileSystem> DeviceHostProxy::GetRemoteFileSystem(const GUID& ImplementationClsid, std::wstring_view Tag) wil::com_ptr<IPlan9FileSystem> DeviceHostProxy::GetRemoteFileSystem(const GUID& ImplementationClsid, std::wstring_view Tag)
{ {
auto lock = m_lock.lock_shared(); auto lock = m_lock.lock_shared();
THROW_HR_IF(E_CHANGED_STATE, m_shutdown); THROW_HR_IF(E_CHANGED_STATE, m_shutdown);
for (auto& entry : m_fileSystems) for (auto& entry : m_fileSystems)
{ {
if (entry.ImplementationClsid == ImplementationClsid && entry.Tag == Tag) if (entry.ImplementationClsid == ImplementationClsid && entry.Tag == Tag)
{ {
return entry.Instance; return entry.Instance;
} }
} }
return {}; return {};
} }
void DeviceHostProxy::Shutdown() void DeviceHostProxy::Shutdown()
{ {
{ {
auto lock = m_lock.lock_exclusive(); auto lock = m_lock.lock_exclusive();
m_fileSystems.clear(); m_fileSystems.clear();
m_shutdown = true; m_shutdown = true;
} }
{ {
auto lock = m_devicesLock.lock_exclusive(); auto lock = m_devicesLock.lock_exclusive();
m_devices.clear(); m_devices.clear();
m_devicesShutdown = true; m_devicesShutdown = true;
} }
} }
HRESULT HRESULT
DeviceHostProxy::RegisterDeviceHost(_In_ IVmDeviceHost* DeviceHost, _In_ DWORD ProcessId, _Out_ UINT64* IpcSectionHandle) DeviceHostProxy::RegisterDeviceHost(_In_ IVmDeviceHost* DeviceHost, _In_ DWORD ProcessId, _Out_ UINT64* IpcSectionHandle)
try try
{ {
// //
// Because HdvProxyDeviceHost is not part of the API set, it is loaded here dynamically. // Because HdvProxyDeviceHost is not part of the API set, it is loaded here dynamically.
// //
static LxssDynamicFunction<decltype(HdvProxyDeviceHost)> proxyDeviceHost{c_hdvModuleName, "HdvProxyDeviceHost"}; static LxssDynamicFunction<decltype(HdvProxyDeviceHost)> proxyDeviceHost{c_hdvModuleName, "HdvProxyDeviceHost"};
const wil::com_ptr<IVmDeviceHost> remoteHost = DeviceHost; const wil::com_ptr<IVmDeviceHost> remoteHost = DeviceHost;
const wil::com_ptr<IUnknown> unknown = remoteHost.query<IUnknown>(); const wil::com_ptr<IUnknown> unknown = remoteHost.query<IUnknown>();
THROW_IF_FAILED(proxyDeviceHost(m_system.get(), unknown.get(), ProcessId, IpcSectionHandle)); THROW_IF_FAILED(proxyDeviceHost(m_system.get(), unknown.get(), ProcessId, IpcSectionHandle));
return S_OK; return S_OK;
} }
CATCH_RETURN() CATCH_RETURN()
HRESULT HRESULT
DeviceHostProxy::NotifyAllDevicesInUse(_In_ LPCWSTR Tag) DeviceHostProxy::NotifyAllDevicesInUse(_In_ LPCWSTR Tag)
try try
{ {
// //
// Add another Plan9 virtio device to the guest so additional mount commands will be possible. // Add another Plan9 virtio device to the guest so additional mount commands will be possible.
// This callback should be unused by virtiofs devices because a device is created for every // This callback should be unused by virtiofs devices because a device is created for every
// AddSharePath call. // AddSharePath call.
// //
auto p9fs = GetRemoteFileSystem(__uuidof(p9fs::Plan9FileSystem), Tag); auto p9fs = GetRemoteFileSystem(__uuidof(p9fs::Plan9FileSystem), Tag);
THROW_HR_IF(E_NOT_SET, !p9fs); THROW_HR_IF(E_NOT_SET, !p9fs);
(void)AddNewDevice(VIRTIO_PLAN9_DEVICE_ID, p9fs, Tag); (void)AddNewDevice(VIRTIO_PLAN9_DEVICE_ID, p9fs, Tag);
return S_OK; return S_OK;
} }
CATCH_RETURN() CATCH_RETURN()
HRESULT HRESULT
DeviceHostProxy::RegisterDoorbell(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags, HANDLE Event) DeviceHostProxy::RegisterDoorbell(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags, HANDLE Event)
try try
{ {
auto lock = m_devicesLock.lock_exclusive(); auto lock = m_devicesLock.lock_exclusive();
RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown); RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
// Check if the device is one of the known devices that doorbells can be registered for, and // Check if the device is one of the known devices that doorbells can be registered for, and
// if the device has not already registered a doorbell. // if the device has not already registered a doorbell.
// N.B. For security it is enforced that each device can only register a small number of doorbells. // N.B. For security it is enforced that each device can only register a small number of doorbells.
// Currently virtio-9p only uses one and the external virtio device uses two. // Currently virtio-9p only uses one and the external virtio device uses two.
const auto knownDevice = m_devices.find(InstanceId); const auto knownDevice = m_devices.find(InstanceId);
RETURN_HR_IF(E_ACCESSDENIED, knownDevice == m_devices.end() || knownDevice->second.DoorbellCount == DEVICE_HOST_PROXY_DOORBELL_LIMIT); RETURN_HR_IF(E_ACCESSDENIED, knownDevice == m_devices.end() || knownDevice->second.DoorbellCount == DEVICE_HOST_PROXY_DOORBELL_LIMIT);
if (!knownDevice->second.MemoryNotification) if (!knownDevice->second.MemoryNotification)
{ {
// Get an interface to the worker process to query devices. // Get an interface to the worker process to query devices.
if (!m_deviceAccess) if (!m_deviceAccess)
{ {
static LxssDynamicFunction<GetVmWorkerProcessType<REFGUID, REFIID, IUnknown**>> getVmWorker{ static LxssDynamicFunction<GetVmWorkerProcessType<REFGUID, REFIID, IUnknown**>> getVmWorker{
c_vmwpctrlModuleName, "GetVmWorkerProcess"}; c_vmwpctrlModuleName, "GetVmWorkerProcess"};
RETURN_IF_FAILED(getVmWorker(m_runtimeId, __uuidof(*m_deviceAccess), reinterpret_cast<IUnknown**>(&m_deviceAccess))); RETURN_IF_FAILED(getVmWorker(m_runtimeId, __uuidof(*m_deviceAccess), reinterpret_cast<IUnknown**>(&m_deviceAccess)));
} }
RETURN_HR_IF(E_NOINTERFACE, !m_deviceAccess); RETURN_HR_IF(E_NOINTERFACE, !m_deviceAccess);
// Retrieve the device's memory notification interface to register the doorbell, and store it // Retrieve the device's memory notification interface to register the doorbell, and store it
// to be used during unregistration. // to be used during unregistration.
wil::com_ptr<IUnknown> device; wil::com_ptr<IUnknown> device;
RETURN_IF_FAILED(m_deviceAccess->GetDevice(FLEXIO_DEVICE_ID, InstanceId, &device)); RETURN_IF_FAILED(m_deviceAccess->GetDevice(FLEXIO_DEVICE_ID, InstanceId, &device));
knownDevice->second.MemoryNotification = device.query<IVmFiovGuestMemoryFastNotification>(); knownDevice->second.MemoryNotification = device.query<IVmFiovGuestMemoryFastNotification>();
} }
const auto result = knownDevice->second.MemoryNotification->RegisterDoorbell( const auto result = knownDevice->second.MemoryNotification->RegisterDoorbell(
static_cast<FIOV_BAR_SELECTOR>(BarIndex), Offset, TriggerValue, Flags, Event); static_cast<FIOV_BAR_SELECTOR>(BarIndex), Offset, TriggerValue, Flags, Event);
if (SUCCEEDED(result)) if (SUCCEEDED(result))
{ {
++knownDevice->second.DoorbellCount; ++knownDevice->second.DoorbellCount;
} }
return result; return result;
} }
CATCH_RETURN() CATCH_RETURN()
HRESULT HRESULT
DeviceHostProxy::UnregisterDoorbell(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags) DeviceHostProxy::UnregisterDoorbell(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags)
try try
{ {
auto lock = m_devicesLock.lock_exclusive(); auto lock = m_devicesLock.lock_exclusive();
RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown); RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
// Check if the device is a known device and has registered a doorbell. // Check if the device is a known device and has registered a doorbell.
// N.B. If the device is being removed, the device can't be retrieved from the worker process // N.B. If the device is being removed, the device can't be retrieved from the worker process
// so it's necessary to use the stored COM pointer. // so it's necessary to use the stored COM pointer.
const auto device = m_devices.find(InstanceId); const auto device = m_devices.find(InstanceId);
RETURN_HR_IF(E_ACCESSDENIED, device == m_devices.end() || device->second.DoorbellCount == 0); RETURN_HR_IF(E_ACCESSDENIED, device == m_devices.end() || device->second.DoorbellCount == 0);
RETURN_IF_FAILED(device->second.MemoryNotification->UnregisterDoorbell(static_cast<FIOV_BAR_SELECTOR>(BarIndex), Offset, TriggerValue, Flags)); RETURN_IF_FAILED(device->second.MemoryNotification->UnregisterDoorbell(static_cast<FIOV_BAR_SELECTOR>(BarIndex), Offset, TriggerValue, Flags));
if (--device->second.DoorbellCount == 0) if (--device->second.DoorbellCount == 0)
{ {
device->second.MemoryNotification.reset(); device->second.MemoryNotification.reset();
} }
return S_OK; return S_OK;
} }
CATCH_RETURN() CATCH_RETURN()
HRESULT HRESULT
DeviceHostProxy::CreateSectionBackedMmioRange( DeviceHostProxy::CreateSectionBackedMmioRange(
const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages, UINT64 PageCount, UINT64 MappingFlags, HANDLE SectionHandle, UINT64 SectionOffsetInPages) const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages, UINT64 PageCount, UINT64 MappingFlags, HANDLE SectionHandle, UINT64 SectionOffsetInPages)
try try
{ {
auto lock = m_devicesLock.lock_exclusive(); auto lock = m_devicesLock.lock_exclusive();
RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown); RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
// Check if the device is one of the known devices. // Check if the device is one of the known devices.
const auto knownDevice = m_devices.find(InstanceId); const auto knownDevice = m_devices.find(InstanceId);
THROW_HR_IF(E_ACCESSDENIED, knownDevice == m_devices.end()); THROW_HR_IF(E_ACCESSDENIED, knownDevice == m_devices.end());
if (!knownDevice->second.MemoryMapping) if (!knownDevice->second.MemoryMapping)
{ {
// Get an interface to the worker process to query devices. // Get an interface to the worker process to query devices.
if (!m_deviceAccess) if (!m_deviceAccess)
{ {
static LxssDynamicFunction<GetVmWorkerProcessType<REFGUID, REFIID, IUnknown**>> getVmWorker{ static LxssDynamicFunction<GetVmWorkerProcessType<REFGUID, REFIID, IUnknown**>> getVmWorker{
c_vmwpctrlModuleName, "GetVmWorkerProcess"}; c_vmwpctrlModuleName, "GetVmWorkerProcess"};
THROW_IF_FAILED(getVmWorker(m_runtimeId, __uuidof(*m_deviceAccess), reinterpret_cast<IUnknown**>(&m_deviceAccess))); THROW_IF_FAILED(getVmWorker(m_runtimeId, __uuidof(*m_deviceAccess), reinterpret_cast<IUnknown**>(&m_deviceAccess)));
} }
THROW_HR_IF(E_NOINTERFACE, !m_deviceAccess); THROW_HR_IF(E_NOINTERFACE, !m_deviceAccess);
// Retrieve the device specific interface to manage mapped sections. // Retrieve the device specific interface to manage mapped sections.
wil::com_ptr<IUnknown> device; wil::com_ptr<IUnknown> device;
THROW_IF_FAILED(m_deviceAccess->GetDevice(FLEXIO_DEVICE_ID, InstanceId, &device)); THROW_IF_FAILED(m_deviceAccess->GetDevice(FLEXIO_DEVICE_ID, InstanceId, &device));
knownDevice->second.MemoryMapping = device.query<IVmFiovGuestMmioMappings>(); knownDevice->second.MemoryMapping = device.query<IVmFiovGuestMmioMappings>();
} }
THROW_IF_FAILED(knownDevice->second.MemoryMapping->CreateSectionBackedMmioRange( THROW_IF_FAILED(knownDevice->second.MemoryMapping->CreateSectionBackedMmioRange(
static_cast<FIOV_BAR_SELECTOR>(BarIndex), BarOffsetInPages, PageCount, static_cast<FiovMmioMappingFlags>(MappingFlags), SectionHandle, SectionOffsetInPages)); static_cast<FIOV_BAR_SELECTOR>(BarIndex), BarOffsetInPages, PageCount, static_cast<FiovMmioMappingFlags>(MappingFlags), SectionHandle, SectionOffsetInPages));
return S_OK; return S_OK;
} }
CATCH_RETURN() CATCH_RETURN()
HRESULT HRESULT
DeviceHostProxy::DestroySectionBackedMmioRange(const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages) DeviceHostProxy::DestroySectionBackedMmioRange(const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages)
try try
{ {
auto lock = m_devicesLock.lock_exclusive(); auto lock = m_devicesLock.lock_exclusive();
RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown); RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
const auto device = m_devices.find(InstanceId); const auto device = m_devices.find(InstanceId);
RETURN_HR_IF(E_ACCESSDENIED, device == m_devices.end() || !device->second.MemoryMapping); RETURN_HR_IF(E_ACCESSDENIED, device == m_devices.end() || !device->second.MemoryMapping);
RETURN_IF_FAILED(device->second.MemoryMapping->DestroySectionBackedMmioRange(static_cast<FIOV_BAR_SELECTOR>(BarIndex), BarOffsetInPages)); RETURN_IF_FAILED(device->second.MemoryMapping->DestroySectionBackedMmioRange(static_cast<FIOV_BAR_SELECTOR>(BarIndex), BarOffsetInPages));
return S_OK; return S_OK;
} }
CATCH_RETURN() CATCH_RETURN()

View File

@ -1,76 +1,76 @@
// Copyright (C) Microsoft Corporation. All rights reserved. // Copyright (C) Microsoft Corporation. All rights reserved.
#pragma once #pragma once
#include <windowsdefs.h> #include <windowsdefs.h>
#include "hcs.hpp" #include "hcs.hpp"
namespace wrl = Microsoft::WRL; namespace wrl = Microsoft::WRL;
class DeviceHostProxy : public wrl::RuntimeClass<wrl::RuntimeClassFlags<wrl::RuntimeClassType::ClassicCom>, IVmDeviceHostSupport, IPlan9FileSystemHost> class DeviceHostProxy : public wrl::RuntimeClass<wrl::RuntimeClassFlags<wrl::RuntimeClassType::ClassicCom>, IVmDeviceHostSupport, IPlan9FileSystemHost>
{ {
public: public:
DeviceHostProxy(const std::wstring& VmId, const GUID& RuntimeId); DeviceHostProxy(const std::wstring& VmId, const GUID& RuntimeId);
GUID AddNewDevice(const GUID& Type, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs, const std::wstring& VirtIoTag); GUID AddNewDevice(const GUID& Type, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs, const std::wstring& VirtIoTag);
void AddRemoteFileSystem(const GUID& ImplementationClsid, const std::wstring& Tag, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs); void AddRemoteFileSystem(const GUID& ImplementationClsid, const std::wstring& Tag, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs);
wil::com_ptr<IPlan9FileSystem> GetRemoteFileSystem(const GUID& ImplementationClsid, std::wstring_view Tag); wil::com_ptr<IPlan9FileSystem> GetRemoteFileSystem(const GUID& ImplementationClsid, std::wstring_view Tag);
void Shutdown(); void Shutdown();
// //
// IVmDeviceHostSupport // IVmDeviceHostSupport
// //
IFACEMETHOD(RegisterDeviceHost)(_In_ IVmDeviceHost* DeviceHost, _In_ DWORD ProcessId, _Out_ UINT64* IpcSectionHandle) override; IFACEMETHOD(RegisterDeviceHost)(_In_ IVmDeviceHost* DeviceHost, _In_ DWORD ProcessId, _Out_ UINT64* IpcSectionHandle) override;
// //
// IPlan9FileSystemHost // IPlan9FileSystemHost
// //
IFACEMETHOD(NotifyAllDevicesInUse)(_In_ LPCWSTR Tag) override; IFACEMETHOD(NotifyAllDevicesInUse)(_In_ LPCWSTR Tag) override;
IFACEMETHOD(RegisterDoorbell)(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags, HANDLE Event) override; IFACEMETHOD(RegisterDoorbell)(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags, HANDLE Event) override;
IFACEMETHOD(UnregisterDoorbell)(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags) override; IFACEMETHOD(UnregisterDoorbell)(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags) override;
IFACEMETHOD(CreateSectionBackedMmioRange)( IFACEMETHOD(CreateSectionBackedMmioRange)(
const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages, UINT64 PageCount, UINT64 MappingFlags, HANDLE SectionHandle, UINT64 SectionOffsetInPages) override; const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages, UINT64 PageCount, UINT64 MappingFlags, HANDLE SectionHandle, UINT64 SectionOffsetInPages) override;
IFACEMETHOD(DestroySectionBackedMmioRange)(const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages) override; IFACEMETHOD(DestroySectionBackedMmioRange)(const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages) override;
private: private:
struct RemoteFileSystemInfo struct RemoteFileSystemInfo
{ {
RemoteFileSystemInfo(GUID ImplementationClsid, const std::wstring& Tag, const wil::com_ptr<IPlan9FileSystem>& Instance) : RemoteFileSystemInfo(GUID ImplementationClsid, const std::wstring& Tag, const wil::com_ptr<IPlan9FileSystem>& Instance) :
ImplementationClsid{ImplementationClsid}, Tag{Tag}, Instance{Instance} ImplementationClsid{ImplementationClsid}, Tag{Tag}, Instance{Instance}
{ {
} }
GUID ImplementationClsid; GUID ImplementationClsid;
std::wstring Tag; std::wstring Tag;
wil::com_ptr<IPlan9FileSystem> Instance; wil::com_ptr<IPlan9FileSystem> Instance;
}; };
std::wstring m_systemId; std::wstring m_systemId;
GUID m_runtimeId; GUID m_runtimeId;
wsl::windows::common::hcs::unique_hcs_system m_system; wsl::windows::common::hcs::unique_hcs_system m_system;
wil::srwlock m_lock; wil::srwlock m_lock;
std::vector<RemoteFileSystemInfo> m_fileSystems; std::vector<RemoteFileSystemInfo> m_fileSystems;
bool m_shutdown; bool m_shutdown;
struct DeviceHostProxyEntry struct DeviceHostProxyEntry
{ {
wil::com_ptr<IVmFiovGuestMemoryFastNotification> MemoryNotification; wil::com_ptr<IVmFiovGuestMemoryFastNotification> MemoryNotification;
wil::com_ptr<IVmFiovGuestMmioMappings> MemoryMapping; wil::com_ptr<IVmFiovGuestMmioMappings> MemoryMapping;
size_t DoorbellCount = 0; size_t DoorbellCount = 0;
}; };
wil::com_ptr<IVmVirtualDeviceAccess> m_deviceAccess; wil::com_ptr<IVmVirtualDeviceAccess> m_deviceAccess;
wil::srwlock m_devicesLock; wil::srwlock m_devicesLock;
std::map<GUID, DeviceHostProxyEntry, wsl::windows::common::helpers::GuidLess> m_devices; std::map<GUID, DeviceHostProxyEntry, wsl::windows::common::helpers::GuidLess> m_devices;
bool m_devicesShutdown; bool m_devicesShutdown;
static constexpr LPCWSTR c_hdvModuleName = L"vmdevicehost.dll"; static constexpr LPCWSTR c_hdvModuleName = L"vmdevicehost.dll";
static constexpr LPCWSTR c_vmwpctrlModuleName = L"vmwpctrl.dll"; static constexpr LPCWSTR c_vmwpctrlModuleName = L"vmwpctrl.dll";
}; };

View File

@ -1,417 +1,417 @@
// Copyright (C) Microsoft Corporation. All rights reserved. // Copyright (C) Microsoft Corporation. All rights reserved.
#include <LxssDynamicFunction.h> #include <LxssDynamicFunction.h>
#include "precomp.h" #include "precomp.h"
#include "DnsResolver.h" #include "DnsResolver.h"
using wsl::core::networking::DnsResolver; using wsl::core::networking::DnsResolver;
static constexpr auto c_dnsModuleName = L"dnsapi.dll"; static constexpr auto c_dnsModuleName = L"dnsapi.dll";
std::optional<LxssDynamicFunction<decltype(DnsQueryRaw)>> DnsResolver::s_dnsQueryRaw; std::optional<LxssDynamicFunction<decltype(DnsQueryRaw)>> DnsResolver::s_dnsQueryRaw;
std::optional<LxssDynamicFunction<decltype(DnsCancelQueryRaw)>> DnsResolver::s_dnsCancelQueryRaw; std::optional<LxssDynamicFunction<decltype(DnsCancelQueryRaw)>> DnsResolver::s_dnsCancelQueryRaw;
std::optional<LxssDynamicFunction<decltype(DnsQueryRawResultFree)>> DnsResolver::s_dnsQueryRawResultFree; std::optional<LxssDynamicFunction<decltype(DnsQueryRawResultFree)>> DnsResolver::s_dnsQueryRawResultFree;
HRESULT DnsResolver::LoadDnsResolverMethods() noexcept HRESULT DnsResolver::LoadDnsResolverMethods() noexcept
{ {
static wil::shared_hmodule dnsModule; static wil::shared_hmodule dnsModule;
static DWORD loadError = ERROR_SUCCESS; static DWORD loadError = ERROR_SUCCESS;
static std::once_flag dnsLoadFlag; static std::once_flag dnsLoadFlag;
// Load DNS dll only once // Load DNS dll only once
std::call_once(dnsLoadFlag, [&]() { std::call_once(dnsLoadFlag, [&]() {
dnsModule.reset(LoadLibraryEx(c_dnsModuleName, nullptr, LOAD_LIBRARY_SEARCH_SYSTEM32)); dnsModule.reset(LoadLibraryEx(c_dnsModuleName, nullptr, LOAD_LIBRARY_SEARCH_SYSTEM32));
if (!dnsModule) if (!dnsModule)
{ {
loadError = GetLastError(); loadError = GetLastError();
} }
}); });
RETURN_IF_WIN32_ERROR_MSG(loadError, "LoadLibraryEx %ls", c_dnsModuleName); RETURN_IF_WIN32_ERROR_MSG(loadError, "LoadLibraryEx %ls", c_dnsModuleName);
// Initialize dynamic functions for the DNS tunneling Windows APIs. // Initialize dynamic functions for the DNS tunneling Windows APIs.
// using the non-throwing instance of LxssDynamicFunction as to not end up in the Error telemetry // using the non-throwing instance of LxssDynamicFunction as to not end up in the Error telemetry
LxssDynamicFunction<decltype(DnsQueryRaw)> local_dnsQueryRaw{DynamicFunctionErrorLogs::None}; LxssDynamicFunction<decltype(DnsQueryRaw)> local_dnsQueryRaw{DynamicFunctionErrorLogs::None};
RETURN_IF_FAILED_EXPECTED(local_dnsQueryRaw.load(dnsModule, "DnsQueryRaw")); RETURN_IF_FAILED_EXPECTED(local_dnsQueryRaw.load(dnsModule, "DnsQueryRaw"));
LxssDynamicFunction<decltype(DnsCancelQueryRaw)> local_dnsCancelQueryRaw{DynamicFunctionErrorLogs::None}; LxssDynamicFunction<decltype(DnsCancelQueryRaw)> local_dnsCancelQueryRaw{DynamicFunctionErrorLogs::None};
RETURN_IF_FAILED_EXPECTED(local_dnsCancelQueryRaw.load(dnsModule, "DnsCancelQueryRaw")); RETURN_IF_FAILED_EXPECTED(local_dnsCancelQueryRaw.load(dnsModule, "DnsCancelQueryRaw"));
LxssDynamicFunction<decltype(DnsQueryRawResultFree)> local_dnsQueryRawResultFree{DynamicFunctionErrorLogs::None}; LxssDynamicFunction<decltype(DnsQueryRawResultFree)> local_dnsQueryRawResultFree{DynamicFunctionErrorLogs::None};
RETURN_IF_FAILED_EXPECTED(local_dnsQueryRawResultFree.load(dnsModule, "DnsQueryRawResultFree")); RETURN_IF_FAILED_EXPECTED(local_dnsQueryRawResultFree.load(dnsModule, "DnsQueryRawResultFree"));
// Make a dummy call to the DNS APIs to verify if they are working. The APIs are going to be present // Make a dummy call to the DNS APIs to verify if they are working. The APIs are going to be present
// on older Windows versions, where they can be turned on/off. If turned off, the APIs // on older Windows versions, where they can be turned on/off. If turned off, the APIs
// will be unusable and will return ERROR_CALL_NOT_IMPLEMENTED. // will be unusable and will return ERROR_CALL_NOT_IMPLEMENTED.
if (local_dnsQueryRaw(nullptr, nullptr) == ERROR_CALL_NOT_IMPLEMENTED) if (local_dnsQueryRaw(nullptr, nullptr) == ERROR_CALL_NOT_IMPLEMENTED)
{ {
RETURN_IF_WIN32_ERROR_EXPECTED(ERROR_CALL_NOT_IMPLEMENTED); RETURN_IF_WIN32_ERROR_EXPECTED(ERROR_CALL_NOT_IMPLEMENTED);
} }
s_dnsQueryRaw.emplace(std::move(local_dnsQueryRaw)); s_dnsQueryRaw.emplace(std::move(local_dnsQueryRaw));
s_dnsCancelQueryRaw.emplace(std::move(local_dnsCancelQueryRaw)); s_dnsCancelQueryRaw.emplace(std::move(local_dnsCancelQueryRaw));
s_dnsQueryRawResultFree.emplace(std::move(local_dnsQueryRawResultFree)); s_dnsQueryRawResultFree.emplace(std::move(local_dnsQueryRawResultFree));
return S_OK; return S_OK;
} }
DnsResolver::DnsResolver(wil::unique_socket&& dnsHvsocket, DnsResolverFlags flags) : DnsResolver::DnsResolver(wil::unique_socket&& dnsHvsocket, DnsResolverFlags flags) :
m_dnsChannel( m_dnsChannel(
std::move(dnsHvsocket), std::move(dnsHvsocket),
[this](const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) { [this](const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) {
ProcessDnsRequest(dnsBuffer, dnsClientIdentifier); ProcessDnsRequest(dnsBuffer, dnsClientIdentifier);
}), }),
m_flags(flags) m_flags(flags)
{ {
// Initialize as signaled, as there are no requests yet // Initialize as signaled, as there are no requests yet
m_allRequestsFinished.SetEvent(); m_allRequestsFinished.SetEvent();
// Read external interface constraint regkey // Read external interface constraint regkey
const auto lxssKey = windows::common::registry::OpenLxssMachineKey(KEY_READ); const auto lxssKey = windows::common::registry::OpenLxssMachineKey(KEY_READ);
m_externalInterfaceConstraintName = m_externalInterfaceConstraintName =
windows::common::registry::ReadString(lxssKey.get(), nullptr, c_interfaceConstraintKey, L""); windows::common::registry::ReadString(lxssKey.get(), nullptr, c_interfaceConstraintKey, L"");
if (!m_externalInterfaceConstraintName.empty()) if (!m_externalInterfaceConstraintName.empty())
{ {
ResolveExternalInterfaceConstraintIndex(); ResolveExternalInterfaceConstraintIndex();
WSL_LOG( WSL_LOG(
"DnsResolver::DnsResolver", "DnsResolver::DnsResolver",
TraceLoggingValue(m_externalInterfaceConstraintName.c_str(), "m_externalInterfaceConstraintName"), TraceLoggingValue(m_externalInterfaceConstraintName.c_str(), "m_externalInterfaceConstraintName"),
TraceLoggingValue(m_externalInterfaceConstraintIndex, "m_externalInterfaceConstraintIndex")); TraceLoggingValue(m_externalInterfaceConstraintIndex, "m_externalInterfaceConstraintIndex"));
// Register for interface change notifications. Notifications are used to determine if the external interface constraint setting is applicable. // Register for interface change notifications. Notifications are used to determine if the external interface constraint setting is applicable.
THROW_IF_WIN32_ERROR(NotifyIpInterfaceChange(AF_UNSPEC, &DnsResolver::InterfaceChangeCallback, this, FALSE, &m_interfaceNotificationHandle)); THROW_IF_WIN32_ERROR(NotifyIpInterfaceChange(AF_UNSPEC, &DnsResolver::InterfaceChangeCallback, this, FALSE, &m_interfaceNotificationHandle));
} }
} }
DnsResolver::~DnsResolver() noexcept DnsResolver::~DnsResolver() noexcept
{ {
Stop(); Stop();
} }
void DnsResolver::GenerateTelemetry() noexcept void DnsResolver::GenerateTelemetry() noexcept
try try
{ {
// Find the 3 most common DNS API failures // Find the 3 most common DNS API failures
uint32_t mostCommonDnsStatusError = 0; uint32_t mostCommonDnsStatusError = 0;
uint32_t mostCommonDnsStatusErrorCount = 0; uint32_t mostCommonDnsStatusErrorCount = 0;
uint32_t secondCommonDnsStatusError = 0; uint32_t secondCommonDnsStatusError = 0;
uint32_t secondCommonDnsStatusErrorCount = 0; uint32_t secondCommonDnsStatusErrorCount = 0;
uint32_t thirdCommonDnsStatusError = 0; uint32_t thirdCommonDnsStatusError = 0;
uint32_t thirdCommonDnsStatusErrorCount = 0; uint32_t thirdCommonDnsStatusErrorCount = 0;
std::vector<std::pair<uint32_t, uint32_t>> failures(m_dnsApiFailures.size()); std::vector<std::pair<uint32_t, uint32_t>> failures(m_dnsApiFailures.size());
std::copy(m_dnsApiFailures.begin(), m_dnsApiFailures.end(), failures.begin()); std::copy(m_dnsApiFailures.begin(), m_dnsApiFailures.end(), failures.begin());
// Sort in descending order based on failure count // Sort in descending order based on failure count
std::sort(failures.begin(), failures.end(), [](const auto& lhs, const auto& rhs) { return lhs.second > rhs.second; }); std::sort(failures.begin(), failures.end(), [](const auto& lhs, const auto& rhs) { return lhs.second > rhs.second; });
if (failures.size() >= 1) if (failures.size() >= 1)
{ {
mostCommonDnsStatusError = failures[0].first; mostCommonDnsStatusError = failures[0].first;
mostCommonDnsStatusErrorCount = failures[0].second; mostCommonDnsStatusErrorCount = failures[0].second;
} }
if (failures.size() >= 2) if (failures.size() >= 2)
{ {
secondCommonDnsStatusError = failures[1].first; secondCommonDnsStatusError = failures[1].first;
secondCommonDnsStatusErrorCount = failures[1].second; secondCommonDnsStatusErrorCount = failures[1].second;
} }
if (failures.size() >= 3) if (failures.size() >= 3)
{ {
thirdCommonDnsStatusError = failures[2].first; thirdCommonDnsStatusError = failures[2].first;
thirdCommonDnsStatusErrorCount = failures[2].second; thirdCommonDnsStatusErrorCount = failures[2].second;
} }
// Add telemetry with DNS tunneling statistics, before shutting down // Add telemetry with DNS tunneling statistics, before shutting down
WSL_LOG( WSL_LOG(
"DnsTunnelingStatistics", "DnsTunnelingStatistics",
TraceLoggingValue(m_totalUdpQueries.load(), "totalUdpQueries"), TraceLoggingValue(m_totalUdpQueries.load(), "totalUdpQueries"),
TraceLoggingValue(m_successfulUdpQueries.load(), "successfulUdpQueries"), TraceLoggingValue(m_successfulUdpQueries.load(), "successfulUdpQueries"),
TraceLoggingValue(m_totalTcpQueries.load(), "totalTcpQueries"), TraceLoggingValue(m_totalTcpQueries.load(), "totalTcpQueries"),
TraceLoggingValue(m_successfulTcpQueries.load(), "successfulTcpQueries"), TraceLoggingValue(m_successfulTcpQueries.load(), "successfulTcpQueries"),
TraceLoggingValue(m_queriesWithNullResult.load(), "queriesWithNullResult"), TraceLoggingValue(m_queriesWithNullResult.load(), "queriesWithNullResult"),
TraceLoggingValue(m_failedDnsQueryRawCalls.load(), "FailedDnsQueryRawCalls"), TraceLoggingValue(m_failedDnsQueryRawCalls.load(), "FailedDnsQueryRawCalls"),
TraceLoggingValue(m_dnsApiFailures.size(), "totalDnsStatusErrorInstances"), TraceLoggingValue(m_dnsApiFailures.size(), "totalDnsStatusErrorInstances"),
TraceLoggingValue(mostCommonDnsStatusError, "mostCommonDnsStatusError"), TraceLoggingValue(mostCommonDnsStatusError, "mostCommonDnsStatusError"),
TraceLoggingValue(mostCommonDnsStatusErrorCount, "mostCommonDnsStatusErrorCount"), TraceLoggingValue(mostCommonDnsStatusErrorCount, "mostCommonDnsStatusErrorCount"),
TraceLoggingValue(secondCommonDnsStatusError, "secondCommonDnsStatusError"), TraceLoggingValue(secondCommonDnsStatusError, "secondCommonDnsStatusError"),
TraceLoggingValue(secondCommonDnsStatusErrorCount, "secondCommonDnsStatusErrorCount"), TraceLoggingValue(secondCommonDnsStatusErrorCount, "secondCommonDnsStatusErrorCount"),
TraceLoggingValue(thirdCommonDnsStatusError, "thirdCommonDnsStatusError"), TraceLoggingValue(thirdCommonDnsStatusError, "thirdCommonDnsStatusError"),
TraceLoggingValue(thirdCommonDnsStatusErrorCount, "thirdCommonDnsStatusErrorCount")); TraceLoggingValue(thirdCommonDnsStatusErrorCount, "thirdCommonDnsStatusErrorCount"));
} }
CATCH_LOG() CATCH_LOG()
void DnsResolver::Stop() noexcept void DnsResolver::Stop() noexcept
try try
{ {
WSL_LOG("DnsResolver::Stop"); WSL_LOG("DnsResolver::Stop");
// Scoped m_dnsLock // Scoped m_dnsLock
{ {
const std::lock_guard lock(m_dnsLock); const std::lock_guard lock(m_dnsLock);
m_stopped = true; m_stopped = true;
// Cancel existing requests. Cancel is complete when DnsQueryRawCallback is // Cancel existing requests. Cancel is complete when DnsQueryRawCallback is
// invoked with status == ERROR_CANCELLED // invoked with status == ERROR_CANCELLED
// N.B. Cancelling can end up calling the DnsQueryRawCallback directly on this same thread. i.e., while this // N.B. Cancelling can end up calling the DnsQueryRawCallback directly on this same thread. i.e., while this
// lock is held. Which is fine because m_dnsLock is a recursive mutex. // lock is held. Which is fine because m_dnsLock is a recursive mutex.
// N.B. Cancelling a query will synchronously remove the query from m_dnsRequests, which invalidates iterators. // N.B. Cancelling a query will synchronously remove the query from m_dnsRequests, which invalidates iterators.
std::vector<DNS_QUERY_RAW_CANCEL*> cancelHandles; std::vector<DNS_QUERY_RAW_CANCEL*> cancelHandles;
cancelHandles.reserve(m_dnsRequests.size()); cancelHandles.reserve(m_dnsRequests.size());
for (auto& [_, context] : m_dnsRequests) for (auto& [_, context] : m_dnsRequests)
{ {
cancelHandles.emplace_back(&context->m_cancelHandle); cancelHandles.emplace_back(&context->m_cancelHandle);
} }
for (const auto e : cancelHandles) for (const auto e : cancelHandles)
{ {
LOG_IF_WIN32_ERROR(s_dnsCancelQueryRaw.value()(e)); LOG_IF_WIN32_ERROR(s_dnsCancelQueryRaw.value()(e));
} }
} }
// Wait for all requests to complete. At this point no new requests can be started since the object is stopped. // Wait for all requests to complete. At this point no new requests can be started since the object is stopped.
// We are only waiting for existing requests to finish. // We are only waiting for existing requests to finish.
m_allRequestsFinished.wait(); m_allRequestsFinished.wait();
// Stop the response queue first as it can make calls in m_dnsChannel // Stop the response queue first as it can make calls in m_dnsChannel
m_dnsResponseQueue.cancel(); m_dnsResponseQueue.cancel();
m_dnsChannel.Stop(); m_dnsChannel.Stop();
// Stop interface change notifications // Stop interface change notifications
m_interfaceNotificationHandle.reset(); m_interfaceNotificationHandle.reset();
GenerateTelemetry(); GenerateTelemetry();
} }
CATCH_LOG() CATCH_LOG()
void DnsResolver::ProcessDnsRequest(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept void DnsResolver::ProcessDnsRequest(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept
try try
{ {
const std::lock_guard lock(m_dnsLock); const std::lock_guard lock(m_dnsLock);
if (m_stopped) if (m_stopped)
{ {
return; return;
} }
WSL_LOG_DEBUG( WSL_LOG_DEBUG(
"DnsResolver::ProcessDnsRequest - received new DNS request", "DnsResolver::ProcessDnsRequest - received new DNS request",
TraceLoggingValue(dnsBuffer.size(), "DNS buffer size"), TraceLoggingValue(dnsBuffer.size(), "DNS buffer size"),
TraceLoggingValue(dnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"), TraceLoggingValue(dnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"),
TraceLoggingValue(dnsClientIdentifier.DnsClientId, "DNS client id"), TraceLoggingValue(dnsClientIdentifier.DnsClientId, "DNS client id"),
TraceLoggingValue(!m_externalInterfaceConstraintName.empty(), "Is ExternalInterfaceConstraint configured"), TraceLoggingValue(!m_externalInterfaceConstraintName.empty(), "Is ExternalInterfaceConstraint configured"),
TraceLoggingValue(m_externalInterfaceConstraintIndex, "m_externalInterfaceConstraintIndex")); TraceLoggingValue(m_externalInterfaceConstraintIndex, "m_externalInterfaceConstraintIndex"));
// If the external interface constraint is configured but it is *not* present/up, WSL should be net-blind, so we avoid making DNS requests. // If the external interface constraint is configured but it is *not* present/up, WSL should be net-blind, so we avoid making DNS requests.
if (!m_externalInterfaceConstraintName.empty() && m_externalInterfaceConstraintIndex == 0) if (!m_externalInterfaceConstraintName.empty() && m_externalInterfaceConstraintIndex == 0)
{ {
return; return;
} }
dnsClientIdentifier.Protocol == IPPROTO_UDP ? m_totalUdpQueries++ : m_totalTcpQueries++; dnsClientIdentifier.Protocol == IPPROTO_UDP ? m_totalUdpQueries++ : m_totalTcpQueries++;
// Get next request id. If value reaches UINT_MAX + 1 it will be automatically reset to 0 // Get next request id. If value reaches UINT_MAX + 1 it will be automatically reset to 0
const auto requestId = m_currentRequestId++; const auto requestId = m_currentRequestId++;
// Create the DNS request context // Create the DNS request context
auto context = std::make_unique<DnsResolver::DnsQueryContext>( auto context = std::make_unique<DnsResolver::DnsQueryContext>(
requestId, dnsClientIdentifier, [this](_Inout_ DnsResolver::DnsQueryContext* context, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) { requestId, dnsClientIdentifier, [this](_Inout_ DnsResolver::DnsQueryContext* context, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) {
HandleDnsQueryCompletion(context, queryResults); HandleDnsQueryCompletion(context, queryResults);
}); });
auto [it, _] = m_dnsRequests.emplace(requestId, std::move(context)); auto [it, _] = m_dnsRequests.emplace(requestId, std::move(context));
const auto localContext = it->second.get(); const auto localContext = it->second.get();
auto removeContextOnError = wil::scope_exit([&] { WI_VERIFY(m_dnsRequests.erase(requestId) == 1); }); auto removeContextOnError = wil::scope_exit([&] { WI_VERIFY(m_dnsRequests.erase(requestId) == 1); });
// Fill DNS request structure // Fill DNS request structure
DNS_QUERY_RAW_REQUEST request{}; DNS_QUERY_RAW_REQUEST request{};
request.version = DNS_QUERY_RAW_REQUEST_VERSION1; request.version = DNS_QUERY_RAW_REQUEST_VERSION1;
request.resultsVersion = DNS_QUERY_RAW_RESULTS_VERSION1; request.resultsVersion = DNS_QUERY_RAW_RESULTS_VERSION1;
request.dnsQueryRawSize = static_cast<ULONG>(dnsBuffer.size()); request.dnsQueryRawSize = static_cast<ULONG>(dnsBuffer.size());
request.dnsQueryRaw = (PBYTE)dnsBuffer.data(); request.dnsQueryRaw = (PBYTE)dnsBuffer.data();
request.protocol = (dnsClientIdentifier.Protocol == IPPROTO_TCP) ? DNS_PROTOCOL_TCP : DNS_PROTOCOL_UDP; request.protocol = (dnsClientIdentifier.Protocol == IPPROTO_TCP) ? DNS_PROTOCOL_TCP : DNS_PROTOCOL_UDP;
request.queryCompletionCallback = DnsResolver::DnsQueryRawCallback; request.queryCompletionCallback = DnsResolver::DnsQueryRawCallback;
request.queryContext = localContext; request.queryContext = localContext;
// Only unicast UDP & TCP queries are tunneled. Pass this flag to tell Windows DNS client to *not* resolve using multicast. // Only unicast UDP & TCP queries are tunneled. Pass this flag to tell Windows DNS client to *not* resolve using multicast.
request.queryOptions |= DNS_QUERY_NO_MULTICAST; request.queryOptions |= DNS_QUERY_NO_MULTICAST;
// In a DNS request from Linux there might be DNS records that Windows DNS client does not know how to parse. // In a DNS request from Linux there might be DNS records that Windows DNS client does not know how to parse.
// By default in this case Windows will fail the request. When the flag is enabled, Windows will extract the // By default in this case Windows will fail the request. When the flag is enabled, Windows will extract the
// question from the DNS request and attempt to resolve it, ignoring the unknown records. // question from the DNS request and attempt to resolve it, ignoring the unknown records.
if (WI_IsFlagSet(m_flags, DnsResolverFlags::BestEffortDnsParsing)) if (WI_IsFlagSet(m_flags, DnsResolverFlags::BestEffortDnsParsing))
{ {
request.queryRawOptions |= DNS_QUERY_RAW_OPTION_BEST_EFFORT_PARSE; request.queryRawOptions |= DNS_QUERY_RAW_OPTION_BEST_EFFORT_PARSE;
} }
// If the external interface constraint is configured and present on the host, only send DNS requests on that interface. // If the external interface constraint is configured and present on the host, only send DNS requests on that interface.
if (m_externalInterfaceConstraintIndex != 0) if (m_externalInterfaceConstraintIndex != 0)
{ {
request.interfaceIndex = m_externalInterfaceConstraintIndex; request.interfaceIndex = m_externalInterfaceConstraintIndex;
} }
// Start the DNS request // Start the DNS request
// N.B. All DNS requests will bypass the Windows DNS cache // N.B. All DNS requests will bypass the Windows DNS cache
const auto result = s_dnsQueryRaw.value()(&request, &localContext->m_cancelHandle); const auto result = s_dnsQueryRaw.value()(&request, &localContext->m_cancelHandle);
if (result != DNS_REQUEST_PENDING) if (result != DNS_REQUEST_PENDING)
{ {
m_failedDnsQueryRawCalls++; m_failedDnsQueryRawCalls++;
WSL_LOG( WSL_LOG(
"ProcessDnsRequestFailed", "ProcessDnsRequestFailed",
TraceLoggingValue(requestId, "requestId"), TraceLoggingValue(requestId, "requestId"),
TraceLoggingValue(result, "result"), TraceLoggingValue(result, "result"),
TraceLoggingValue("DnsQueryRaw", "executionStep")); TraceLoggingValue("DnsQueryRaw", "executionStep"));
return; return;
} }
removeContextOnError.release(); removeContextOnError.release();
m_allRequestsFinished.ResetEvent(); m_allRequestsFinished.ResetEvent();
} }
CATCH_LOG() CATCH_LOG()
void DnsResolver::HandleDnsQueryCompletion(_Inout_ DnsResolver::DnsQueryContext* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept void DnsResolver::HandleDnsQueryCompletion(_Inout_ DnsResolver::DnsQueryContext* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept
try try
{ {
// Always free the query result structure // Always free the query result structure
const auto freeQueryResults = wil::scope_exit([&] { const auto freeQueryResults = wil::scope_exit([&] {
if (queryResults != nullptr) if (queryResults != nullptr)
{ {
s_dnsQueryRawResultFree.value()(queryResults); s_dnsQueryRawResultFree.value()(queryResults);
} }
}); });
const std::lock_guard lock(m_dnsLock); const std::lock_guard lock(m_dnsLock);
if (queryResults != nullptr) if (queryResults != nullptr)
{ {
WSL_LOG( WSL_LOG(
"DnsResolver::HandleDnsQueryCompletion", "DnsResolver::HandleDnsQueryCompletion",
TraceLoggingValue(queryContext->m_id, "queryContext->m_id"), TraceLoggingValue(queryContext->m_id, "queryContext->m_id"),
TraceLoggingValue(queryResults->queryStatus, "queryResults->queryStatus"), TraceLoggingValue(queryResults->queryStatus, "queryResults->queryStatus"),
TraceLoggingValue(queryResults->queryRawResponse != nullptr, "validResponse")); TraceLoggingValue(queryResults->queryRawResponse != nullptr, "validResponse"));
// Note: The response may be valid even if queryResults->queryStatus is not 0, for example when the DNS server returns a negative response. // Note: The response may be valid even if queryResults->queryStatus is not 0, for example when the DNS server returns a negative response.
if (queryResults->queryRawResponse != nullptr) if (queryResults->queryRawResponse != nullptr)
{ {
queryContext->m_dnsClientIdentifier.Protocol == IPPROTO_UDP ? m_successfulUdpQueries++ : m_successfulTcpQueries++; queryContext->m_dnsClientIdentifier.Protocol == IPPROTO_UDP ? m_successfulUdpQueries++ : m_successfulTcpQueries++;
} }
// the Windows DNS API returned failure // the Windows DNS API returned failure
else else
{ {
if (m_dnsApiFailures.find(queryResults->queryStatus) == m_dnsApiFailures.end()) if (m_dnsApiFailures.find(queryResults->queryStatus) == m_dnsApiFailures.end())
{ {
m_dnsApiFailures[queryResults->queryStatus] = 1; m_dnsApiFailures[queryResults->queryStatus] = 1;
} }
else else
{ {
m_dnsApiFailures[queryResults->queryStatus]++; m_dnsApiFailures[queryResults->queryStatus]++;
} }
} }
} }
else else
{ {
WSL_LOG( WSL_LOG(
"DnsResolver::HandleDnsQueryCompletion - received a NULL queryResults", "DnsResolver::HandleDnsQueryCompletion - received a NULL queryResults",
TraceLoggingValue(queryContext->m_id, "queryContext->m_id")); TraceLoggingValue(queryContext->m_id, "queryContext->m_id"));
m_queriesWithNullResult++; m_queriesWithNullResult++;
} }
if (!m_stopped && queryResults != nullptr && queryResults->queryRawResponse != nullptr) if (!m_stopped && queryResults != nullptr && queryResults->queryRawResponse != nullptr)
{ {
// Copy DNS response buffer // Copy DNS response buffer
std::vector<gsl::byte> dnsResponse(queryResults->queryRawResponseSize); std::vector<gsl::byte> dnsResponse(queryResults->queryRawResponseSize);
CopyMemory(dnsResponse.data(), queryResults->queryRawResponse, queryResults->queryRawResponseSize); CopyMemory(dnsResponse.data(), queryResults->queryRawResponse, queryResults->queryRawResponseSize);
WSL_LOG_DEBUG( WSL_LOG_DEBUG(
"DnsResolver::HandleDnsQueryCompletion - received new DNS response", "DnsResolver::HandleDnsQueryCompletion - received new DNS response",
TraceLoggingValue(dnsResponse.size(), "DNS buffer size"), TraceLoggingValue(dnsResponse.size(), "DNS buffer size"),
TraceLoggingValue(queryContext->m_dnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"), TraceLoggingValue(queryContext->m_dnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"),
TraceLoggingValue(queryContext->m_dnsClientIdentifier.DnsClientId, "DNS client id")); TraceLoggingValue(queryContext->m_dnsClientIdentifier.DnsClientId, "DNS client id"));
// Schedule the DNS response to be sent to Linux // Schedule the DNS response to be sent to Linux
m_dnsResponseQueue.submit([this, dnsResponse = std::move(dnsResponse), dnsClientIdentifier = queryContext->m_dnsClientIdentifier]() mutable { m_dnsResponseQueue.submit([this, dnsResponse = std::move(dnsResponse), dnsClientIdentifier = queryContext->m_dnsClientIdentifier]() mutable {
m_dnsChannel.SendDnsMessage(gsl::make_span(dnsResponse), dnsClientIdentifier); m_dnsChannel.SendDnsMessage(gsl::make_span(dnsResponse), dnsClientIdentifier);
}); });
} }
// Stop tracking this DNS request and delete the request context // Stop tracking this DNS request and delete the request context
WI_VERIFY(m_dnsRequests.erase(queryContext->m_id) == 1); WI_VERIFY(m_dnsRequests.erase(queryContext->m_id) == 1);
// Set event if all tracked requests have finished // Set event if all tracked requests have finished
if (m_dnsRequests.empty()) if (m_dnsRequests.empty())
{ {
m_allRequestsFinished.SetEvent(); m_allRequestsFinished.SetEvent();
} }
} }
CATCH_LOG() CATCH_LOG()
void DnsResolver::ResolveExternalInterfaceConstraintIndex() noexcept void DnsResolver::ResolveExternalInterfaceConstraintIndex() noexcept
try try
{ {
const std::lock_guard lock(m_dnsLock); const std::lock_guard lock(m_dnsLock);
if (m_stopped) if (m_stopped)
{ {
return; return;
} }
if (m_externalInterfaceConstraintName.empty()) if (m_externalInterfaceConstraintName.empty())
{ {
return; return;
} }
NET_LUID interfaceLuid{}; NET_LUID interfaceLuid{};
ULONG interfaceIndex = 0; ULONG interfaceIndex = 0;
// Update the interface index on every exit path. // Update the interface index on every exit path.
// The calls below to convert interface name to index will fail if the interface does not exist anymore, // The calls below to convert interface name to index will fail if the interface does not exist anymore,
// in which case we still need to reset the interface index to its default value of 0. // in which case we still need to reset the interface index to its default value of 0.
const auto setInterfaceIndex = wil::scope_exit([&] { const auto setInterfaceIndex = wil::scope_exit([&] {
if (interfaceIndex != m_externalInterfaceConstraintIndex) if (interfaceIndex != m_externalInterfaceConstraintIndex)
{ {
WSL_LOG( WSL_LOG(
"DnsResolver::ResolveExternalInterfaceConstraintIndex - setting m_externalInterfaceConstraintIndex to new value", "DnsResolver::ResolveExternalInterfaceConstraintIndex - setting m_externalInterfaceConstraintIndex to new value",
TraceLoggingValue(m_externalInterfaceConstraintIndex, "old interface index"), TraceLoggingValue(m_externalInterfaceConstraintIndex, "old interface index"),
TraceLoggingValue(interfaceIndex, "new interface index")); TraceLoggingValue(interfaceIndex, "new interface index"));
m_externalInterfaceConstraintIndex = interfaceIndex; m_externalInterfaceConstraintIndex = interfaceIndex;
} }
}); });
// If external interface constraint is configured, query to see if it's present on the host. // If external interface constraint is configured, query to see if it's present on the host.
auto errorCode = ConvertInterfaceAliasToLuid(m_externalInterfaceConstraintName.c_str(), &interfaceLuid); auto errorCode = ConvertInterfaceAliasToLuid(m_externalInterfaceConstraintName.c_str(), &interfaceLuid);
if (FAILED_WIN32_LOG(errorCode)) if (FAILED_WIN32_LOG(errorCode))
{ {
return; return;
} }
errorCode = ConvertInterfaceLuidToIndex(&interfaceLuid, reinterpret_cast<PNET_IFINDEX>(&interfaceIndex)); errorCode = ConvertInterfaceLuidToIndex(&interfaceLuid, reinterpret_cast<PNET_IFINDEX>(&interfaceIndex));
if (FAILED_WIN32_LOG(errorCode)) if (FAILED_WIN32_LOG(errorCode))
{ {
return; return;
} }
} }
CATCH_LOG() CATCH_LOG()
VOID CALLBACK DnsResolver::DnsQueryRawCallback(_In_ VOID* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept VOID CALLBACK DnsResolver::DnsQueryRawCallback(_In_ VOID* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept
try try
{ {
assert(queryContext != nullptr); assert(queryContext != nullptr);
const auto context = static_cast<DnsQueryContext*>(queryContext); const auto context = static_cast<DnsQueryContext*>(queryContext);
// Call into DnsResolver parent object to process the query result // Call into DnsResolver parent object to process the query result
context->m_handleQueryCompletion(context, queryResults); context->m_handleQueryCompletion(context, queryResults);
} }
CATCH_LOG() CATCH_LOG()
VOID CALLBACK DnsResolver::InterfaceChangeCallback(_In_ PVOID context, PMIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE) noexcept VOID CALLBACK DnsResolver::InterfaceChangeCallback(_In_ PVOID context, PMIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE) noexcept
try try
{ {
const auto dnsResolver = static_cast<DnsResolver*>(context); const auto dnsResolver = static_cast<DnsResolver*>(context);
dnsResolver->ResolveExternalInterfaceConstraintIndex(); dnsResolver->ResolveExternalInterfaceConstraintIndex();
} }
CATCH_LOG() CATCH_LOG()

View File

@ -1,141 +1,141 @@
// Copyright (C) Microsoft Corporation. All rights reserved. // Copyright (C) Microsoft Corporation. All rights reserved.
#pragma once #pragma once
#include "DnsTunnelingChannel.h" #include "DnsTunnelingChannel.h"
#include "WslCoreMessageQueue.h" #include "WslCoreMessageQueue.h"
#include "WslCoreNetworkingSupport.h" #include "WslCoreNetworkingSupport.h"
namespace wsl::core::networking { namespace wsl::core::networking {
enum class DnsResolverFlags enum class DnsResolverFlags
{ {
None = 0x0, None = 0x0,
BestEffortDnsParsing = 0x1 BestEffortDnsParsing = 0x1
}; };
DEFINE_ENUM_FLAG_OPERATORS(DnsResolverFlags); DEFINE_ENUM_FLAG_OPERATORS(DnsResolverFlags);
class DnsResolver class DnsResolver
{ {
public: public:
DnsResolver(wil::unique_socket&& dnsHvsocket, DnsResolverFlags flags); DnsResolver(wil::unique_socket&& dnsHvsocket, DnsResolverFlags flags);
~DnsResolver() noexcept; ~DnsResolver() noexcept;
DnsResolver(const DnsResolver&) = delete; DnsResolver(const DnsResolver&) = delete;
DnsResolver& operator=(const DnsResolver&) = delete; DnsResolver& operator=(const DnsResolver&) = delete;
DnsResolver(DnsResolver&&) = delete; DnsResolver(DnsResolver&&) = delete;
DnsResolver& operator=(DnsResolver&&) = delete; DnsResolver& operator=(DnsResolver&&) = delete;
void Stop() noexcept; void Stop() noexcept;
static HRESULT LoadDnsResolverMethods() noexcept; static HRESULT LoadDnsResolverMethods() noexcept;
private: private:
struct DnsQueryContext struct DnsQueryContext
{ {
// Struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request. // Struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request.
LX_GNS_DNS_CLIENT_IDENTIFIER m_dnsClientIdentifier{}; LX_GNS_DNS_CLIENT_IDENTIFIER m_dnsClientIdentifier{};
// Handle used to cancel the request. // Handle used to cancel the request.
DNS_QUERY_RAW_CANCEL m_cancelHandle{}; DNS_QUERY_RAW_CANCEL m_cancelHandle{};
// Unique query id. // Unique query id.
uint32_t m_id{}; uint32_t m_id{};
// Callback to the parent object to notify about the DNS query completion. // Callback to the parent object to notify about the DNS query completion.
std::function<void(DnsQueryContext*, DNS_QUERY_RAW_RESULT*)> m_handleQueryCompletion; std::function<void(DnsQueryContext*, DNS_QUERY_RAW_RESULT*)> m_handleQueryCompletion;
DnsQueryContext( DnsQueryContext(
uint32_t id, uint32_t id,
const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier,
std::function<void(DnsQueryContext*, DNS_QUERY_RAW_RESULT*)>&& handleQueryCompletion) : std::function<void(DnsQueryContext*, DNS_QUERY_RAW_RESULT*)>&& handleQueryCompletion) :
m_dnsClientIdentifier(dnsClientIdentifier), m_id(id), m_handleQueryCompletion(std::move(handleQueryCompletion)) m_dnsClientIdentifier(dnsClientIdentifier), m_id(id), m_handleQueryCompletion(std::move(handleQueryCompletion))
{ {
} }
~DnsQueryContext() noexcept = default; ~DnsQueryContext() noexcept = default;
DnsQueryContext(const DnsQueryContext&) = delete; DnsQueryContext(const DnsQueryContext&) = delete;
DnsQueryContext& operator=(const DnsQueryContext&) = delete; DnsQueryContext& operator=(const DnsQueryContext&) = delete;
DnsQueryContext(DnsQueryContext&&) = delete; DnsQueryContext(DnsQueryContext&&) = delete;
DnsQueryContext& operator=(DnsQueryContext&&) = delete; DnsQueryContext& operator=(DnsQueryContext&&) = delete;
}; };
void GenerateTelemetry() noexcept; void GenerateTelemetry() noexcept;
// Process DNS request received from Linux. // Process DNS request received from Linux.
// //
// Arguments: // Arguments:
// dnsBuffer - buffer containing DNS request. // dnsBuffer - buffer containing DNS request.
// dnsClientIdentifier - struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request. // dnsClientIdentifier - struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request.
void ProcessDnsRequest(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept; void ProcessDnsRequest(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept;
// Handle completion of DNS query. // Handle completion of DNS query.
// //
// Arguments: // Arguments:
// dnsQueryContext - context structure for the DNS request. // dnsQueryContext - context structure for the DNS request.
// queryResults - structure containing result of the DNS request. // queryResults - structure containing result of the DNS request.
void HandleDnsQueryCompletion(_Inout_ DnsQueryContext* dnsQueryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept; void HandleDnsQueryCompletion(_Inout_ DnsQueryContext* dnsQueryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept;
void ResolveExternalInterfaceConstraintIndex() noexcept; void ResolveExternalInterfaceConstraintIndex() noexcept;
// Callback that will be invoked by the DNS API whenever a request finishes. The callback is invoked on success, error or when request is cancelled. // Callback that will be invoked by the DNS API whenever a request finishes. The callback is invoked on success, error or when request is cancelled.
// //
// Arguments: // Arguments:
// queryContext - pointer to context structure, will be a structure of type DnsQueryContext. // queryContext - pointer to context structure, will be a structure of type DnsQueryContext.
// queryResults - pointer to structure containing the result of the DNS request. // queryResults - pointer to structure containing the result of the DNS request.
static VOID CALLBACK DnsQueryRawCallback(_In_ VOID* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept; static VOID CALLBACK DnsQueryRawCallback(_In_ VOID* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept;
static VOID CALLBACK InterfaceChangeCallback(_In_ PVOID context, PMIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE) noexcept; static VOID CALLBACK InterfaceChangeCallback(_In_ PVOID context, PMIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE) noexcept;
std::recursive_mutex m_dnsLock; std::recursive_mutex m_dnsLock;
// Flag used when shutting down the object. // Flag used when shutting down the object.
_Guarded_by_(m_dnsLock) bool m_stopped = false; _Guarded_by_(m_dnsLock) bool m_stopped = false;
// Hvsocket channel used to exchange DNS messages with Linux. // Hvsocket channel used to exchange DNS messages with Linux.
DnsTunnelingChannel m_dnsChannel; DnsTunnelingChannel m_dnsChannel;
// Queue used to send DNS responses to Linux. // Queue used to send DNS responses to Linux.
WslCoreMessageQueue m_dnsResponseQueue; WslCoreMessageQueue m_dnsResponseQueue;
// Unique id that is incremented for each request. In case the value reaches MAX_UINT and is reset to 0, // Unique id that is incremented for each request. In case the value reaches MAX_UINT and is reset to 0,
// it's assumed previous requests with id's 0, 1, ... finished in the meantime and the id can be reused. // it's assumed previous requests with id's 0, 1, ... finished in the meantime and the id can be reused.
_Guarded_by_(m_dnsLock) uint32_t m_currentRequestId = 0; _Guarded_by_(m_dnsLock) uint32_t m_currentRequestId = 0;
// Mapping request id to the request context structure. // Mapping request id to the request context structure.
_Guarded_by_(m_dnsLock) std::unordered_map<uint32_t, std::unique_ptr<DnsQueryContext>> m_dnsRequests {}; _Guarded_by_(m_dnsLock) std::unordered_map<uint32_t, std::unique_ptr<DnsQueryContext>> m_dnsRequests {};
// Event that is set when all tracked DNS requests have completed. // Event that is set when all tracked DNS requests have completed.
wil::unique_event m_allRequestsFinished{wil::EventOptions::ManualReset}; wil::unique_event m_allRequestsFinished{wil::EventOptions::ManualReset};
// Used for handling of external interface constraint setting. // Used for handling of external interface constraint setting.
unique_notify_handle m_interfaceNotificationHandle{}; unique_notify_handle m_interfaceNotificationHandle{};
std::wstring m_externalInterfaceConstraintName; std::wstring m_externalInterfaceConstraintName;
_Guarded_by_(m_dnsLock) ULONG m_externalInterfaceConstraintIndex = 0; _Guarded_by_(m_dnsLock) ULONG m_externalInterfaceConstraintIndex = 0;
const DnsResolverFlags m_flags{}; const DnsResolverFlags m_flags{};
// Statistics used for telemetry. // Statistics used for telemetry.
std::atomic<uint32_t> m_totalUdpQueries{0}; std::atomic<uint32_t> m_totalUdpQueries{0};
std::atomic<uint32_t> m_successfulUdpQueries{0}; std::atomic<uint32_t> m_successfulUdpQueries{0};
std::atomic<uint32_t> m_totalTcpQueries{0}; std::atomic<uint32_t> m_totalTcpQueries{0};
std::atomic<uint32_t> m_successfulTcpQueries{0}; std::atomic<uint32_t> m_successfulTcpQueries{0};
std::atomic<uint32_t> m_queriesWithNullResult{0}; std::atomic<uint32_t> m_queriesWithNullResult{0};
std::atomic<uint32_t> m_failedDnsQueryRawCalls{0}; std::atomic<uint32_t> m_failedDnsQueryRawCalls{0};
_Guarded_by_(m_dnsLock) std::map<uint32_t, uint32_t> m_dnsApiFailures; _Guarded_by_(m_dnsLock) std::map<uint32_t, uint32_t> m_dnsApiFailures;
// Dynamic functions used for calling the DNS APIs. // Dynamic functions used for calling the DNS APIs.
// Function to start a raw DNS request. // Function to start a raw DNS request.
static std::optional<LxssDynamicFunction<decltype(DnsQueryRaw)>> s_dnsQueryRaw; static std::optional<LxssDynamicFunction<decltype(DnsQueryRaw)>> s_dnsQueryRaw;
// Function to cancel a raw DNS request. // Function to cancel a raw DNS request.
static std::optional<LxssDynamicFunction<decltype(DnsCancelQueryRaw)>> s_dnsCancelQueryRaw; static std::optional<LxssDynamicFunction<decltype(DnsCancelQueryRaw)>> s_dnsCancelQueryRaw;
// Function to free the structure containing the result of a raw DNS request. // Function to free the structure containing the result of a raw DNS request.
static std::optional<LxssDynamicFunction<decltype(DnsQueryRawResultFree)>> s_dnsQueryRawResultFree; static std::optional<LxssDynamicFunction<decltype(DnsQueryRawResultFree)>> s_dnsQueryRawResultFree;
}; };
} // namespace wsl::core::networking } // namespace wsl::core::networking

View File

@ -1,115 +1,115 @@
// Copyright (C) Microsoft Corporation. All rights reserved. // Copyright (C) Microsoft Corporation. All rights reserved.
#include "precomp.h" #include "precomp.h"
#include "DnsTunnelingChannel.h" #include "DnsTunnelingChannel.h"
using wsl::core::networking::DnsTunnelingChannel; using wsl::core::networking::DnsTunnelingChannel;
DnsTunnelingChannel::DnsTunnelingChannel(wil::unique_socket&& socket, DnsTunnelingCallback&& reportDnsRequest) : DnsTunnelingChannel::DnsTunnelingChannel(wil::unique_socket&& socket, DnsTunnelingCallback&& reportDnsRequest) :
m_channel{std::move(socket), "DnsTunneling", m_stopEvent.get()}, m_reportDnsRequest(std::move(reportDnsRequest)) m_channel{std::move(socket), "DnsTunneling", m_stopEvent.get()}, m_reportDnsRequest(std::move(reportDnsRequest))
{ {
WSL_LOG("DnsTunnelingChannel::DnsTunnelingChannel [Windows]", TraceLoggingValue(m_channel.Socket(), "socket")); WSL_LOG("DnsTunnelingChannel::DnsTunnelingChannel [Windows]", TraceLoggingValue(m_channel.Socket(), "socket"));
// Start thread waiting for incoming messages from Linux side // Start thread waiting for incoming messages from Linux side
m_receiveWorkerThread = std::thread([this]() { ReceiveLoop(); }); m_receiveWorkerThread = std::thread([this]() { ReceiveLoop(); });
} }
DnsTunnelingChannel::~DnsTunnelingChannel() DnsTunnelingChannel::~DnsTunnelingChannel()
{ {
Stop(); Stop();
} }
void DnsTunnelingChannel::SendDnsMessage(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept void DnsTunnelingChannel::SendDnsMessage(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept
try try
{ {
// Exit if channel was stopped // Exit if channel was stopped
if (m_stopEvent.is_signaled()) if (m_stopEvent.is_signaled())
{ {
return; return;
} }
wsl::shared::MessageWriter<LX_GNS_DNS_TUNNELING_MESSAGE> message(LxGnsMessageDnsTunneling); wsl::shared::MessageWriter<LX_GNS_DNS_TUNNELING_MESSAGE> message(LxGnsMessageDnsTunneling);
message->DnsClientIdentifier = dnsClientIdentifier; message->DnsClientIdentifier = dnsClientIdentifier;
message.WriteSpan(dnsBuffer); message.WriteSpan(dnsBuffer);
m_channel.SendMessage<LX_GNS_DNS_TUNNELING_MESSAGE>(message.Span()); m_channel.SendMessage<LX_GNS_DNS_TUNNELING_MESSAGE>(message.Span());
} }
CATCH_LOG() CATCH_LOG()
void DnsTunnelingChannel::ReceiveLoop() noexcept void DnsTunnelingChannel::ReceiveLoop() noexcept
{ {
std::vector<gsl::byte> receiveBuffer; std::vector<gsl::byte> receiveBuffer;
for (;;) for (;;)
{ {
try try
{ {
if (m_stopEvent.is_signaled()) if (m_stopEvent.is_signaled())
{ {
return; return;
} }
WSL_LOG_DEBUG("DnsTunnelingChannel::ReceiveLoop [Windows] - waiting for next message from Linux"); WSL_LOG_DEBUG("DnsTunnelingChannel::ReceiveLoop [Windows] - waiting for next message from Linux");
// Read next message. wsl::shared::socket::RecvMessage() first reads the message header, then uses it to determine the // Read next message. wsl::shared::socket::RecvMessage() first reads the message header, then uses it to determine the
// total size of the message and read the rest of the message, resizing the buffer if needed. // total size of the message and read the rest of the message, resizing the buffer if needed.
auto [message, span] = m_channel.ReceiveMessageOrClosed<MESSAGE_HEADER>(); auto [message, span] = m_channel.ReceiveMessageOrClosed<MESSAGE_HEADER>();
if (message == nullptr) if (message == nullptr)
{ {
WSL_LOG("DnsTunnelingChannel::ReceiveLoop [Windows] - failed to read message"); WSL_LOG("DnsTunnelingChannel::ReceiveLoop [Windows] - failed to read message");
return; return;
} }
// Get the message type from the message header // Get the message type from the message header
switch (message->MessageType) switch (message->MessageType)
{ {
case LxGnsMessageDnsTunneling: case LxGnsMessageDnsTunneling:
{ {
// Cast message to a LX_GNS_DNS_TUNNELING_MESSAGE struct // Cast message to a LX_GNS_DNS_TUNNELING_MESSAGE struct
auto* dnsMessage = gslhelpers::try_get_struct<LX_GNS_DNS_TUNNELING_MESSAGE>(span); auto* dnsMessage = gslhelpers::try_get_struct<LX_GNS_DNS_TUNNELING_MESSAGE>(span);
if (!dnsMessage) if (!dnsMessage)
{ {
WSL_LOG( WSL_LOG(
"DnsTunnelingChannel::ReceiveLoop [Windows] - failed to convert message to LX_GNS_DNS_TUNNELING_MESSAGE"); "DnsTunnelingChannel::ReceiveLoop [Windows] - failed to convert message to LX_GNS_DNS_TUNNELING_MESSAGE");
return; return;
} }
// Extract DNS buffer from message // Extract DNS buffer from message
auto dnsBuffer = span.subspan(offsetof(LX_GNS_DNS_TUNNELING_MESSAGE, Buffer)); auto dnsBuffer = span.subspan(offsetof(LX_GNS_DNS_TUNNELING_MESSAGE, Buffer));
WSL_LOG_DEBUG( WSL_LOG_DEBUG(
"DnsTunnelingChannel::ReceiveLoop [Windows] - received DNS message", "DnsTunnelingChannel::ReceiveLoop [Windows] - received DNS message",
TraceLoggingValue(dnsBuffer.size(), "DNS buffer size"), TraceLoggingValue(dnsBuffer.size(), "DNS buffer size"),
TraceLoggingValue(dnsMessage->DnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"), TraceLoggingValue(dnsMessage->DnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"),
TraceLoggingValue(dnsMessage->DnsClientIdentifier.DnsClientId, "DNS client id")); TraceLoggingValue(dnsMessage->DnsClientIdentifier.DnsClientId, "DNS client id"));
// Invoke callback to notify about the new DNS request // Invoke callback to notify about the new DNS request
m_reportDnsRequest(dnsBuffer, dnsMessage->DnsClientIdentifier); m_reportDnsRequest(dnsBuffer, dnsMessage->DnsClientIdentifier);
break; break;
} }
default: default:
{ {
THROW_HR_MSG(E_UNEXPECTED, "Unexpected LX_MESSAGE_TYPE : %i", message->MessageType); THROW_HR_MSG(E_UNEXPECTED, "Unexpected LX_MESSAGE_TYPE : %i", message->MessageType);
} }
} }
} }
CATCH_LOG() CATCH_LOG()
} }
} }
void DnsTunnelingChannel::Stop() noexcept void DnsTunnelingChannel::Stop() noexcept
try try
{ {
WSL_LOG("DnsTunnelingChannel::Stop [Windows]"); WSL_LOG("DnsTunnelingChannel::Stop [Windows]");
m_stopEvent.SetEvent(); m_stopEvent.SetEvent();
// Stop receive loop // Stop receive loop
if (m_receiveWorkerThread.joinable()) if (m_receiveWorkerThread.joinable())
{ {
m_receiveWorkerThread.join(); m_receiveWorkerThread.join();
} }
} }
CATCH_LOG() CATCH_LOG()

View File

@ -1,50 +1,50 @@
// Copyright (C) Microsoft Corporation. All rights reserved. // Copyright (C) Microsoft Corporation. All rights reserved.
#pragma once #pragma once
#include <wil/resource.h> #include <wil/resource.h>
#include "lxinitshared.h" #include "lxinitshared.h"
#include "SocketChannel.h" #include "SocketChannel.h"
namespace wsl::core::networking { namespace wsl::core::networking {
using DnsTunnelingCallback = std::function<void(const gsl::span<gsl::byte>, const LX_GNS_DNS_CLIENT_IDENTIFIER&)>; using DnsTunnelingCallback = std::function<void(const gsl::span<gsl::byte>, const LX_GNS_DNS_CLIENT_IDENTIFIER&)>;
class DnsTunnelingChannel class DnsTunnelingChannel
{ {
public: public:
DnsTunnelingChannel(wil::unique_socket&& socket, DnsTunnelingCallback&& reportDnsRequest); DnsTunnelingChannel(wil::unique_socket&& socket, DnsTunnelingCallback&& reportDnsRequest);
~DnsTunnelingChannel(); ~DnsTunnelingChannel();
DnsTunnelingChannel(const DnsTunnelingChannel&) = delete; DnsTunnelingChannel(const DnsTunnelingChannel&) = delete;
DnsTunnelingChannel& operator=(const DnsTunnelingChannel&) = delete; DnsTunnelingChannel& operator=(const DnsTunnelingChannel&) = delete;
DnsTunnelingChannel(DnsTunnelingChannel&&) = delete; DnsTunnelingChannel(DnsTunnelingChannel&&) = delete;
DnsTunnelingChannel& operator=(DnsTunnelingChannel&&) = delete; DnsTunnelingChannel& operator=(DnsTunnelingChannel&&) = delete;
// Construct and send a LX_GNS_DNS_TUNNELING_MESSAGE message on the channel. // Construct and send a LX_GNS_DNS_TUNNELING_MESSAGE message on the channel.
// Note: Callers are responsible for sequencing calls to this method. // Note: Callers are responsible for sequencing calls to this method.
// //
// Arguments: // Arguments:
// dnsBuffer - buffer containing DNS response. // dnsBuffer - buffer containing DNS response.
// dnsClientIdentifier - struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request. // dnsClientIdentifier - struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request.
void SendDnsMessage(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept; void SendDnsMessage(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept;
// Stop the channel. // Stop the channel.
void Stop() noexcept; void Stop() noexcept;
private: private:
// Wait for messages on the channel from Linux side. // Wait for messages on the channel from Linux side.
void ReceiveLoop() noexcept; void ReceiveLoop() noexcept;
wil::unique_event m_stopEvent{wil::EventOptions::ManualReset}; wil::unique_event m_stopEvent{wil::EventOptions::ManualReset};
wsl::shared::SocketChannel m_channel; wsl::shared::SocketChannel m_channel;
std::thread m_receiveWorkerThread; std::thread m_receiveWorkerThread;
// Callback used to notify when there is a new DNS request message on the channel. // Callback used to notify when there is a new DNS request message on the channel.
DnsTunnelingCallback m_reportDnsRequest; DnsTunnelingCallback m_reportDnsRequest;
}; };
} // namespace wsl::core::networking } // namespace wsl::core::networking

View File

@ -6,7 +6,6 @@
#include "WslCoreHostDnsInfo.h" #include "WslCoreHostDnsInfo.h"
#include "Stringify.h" #include "Stringify.h"
#include "WslCoreFirewallSupport.h" #include "WslCoreFirewallSupport.h"
#include "WslCoreVm.h"
#include "hcs.hpp" #include "hcs.hpp"
using namespace wsl::core::networking; using namespace wsl::core::networking;
@ -672,7 +671,7 @@ wsl::windows::common::hcs::unique_hcn_network NatNetworking::CreateNetwork(wsl::
wil::ResultFromException(WI_DIAGNOSTICS_INFO, [&] { wil::ResultFromException(WI_DIAGNOSTICS_INFO, [&] {
try try
{ {
wsl::core::networking::ConfigureHyperVFirewall(config.FirewallConfig, c_vmOwner); wsl::core::networking::ConfigureHyperVFirewall(config.FirewallConfig, wsl::windows::common::wslutil::c_vmOwner);
natNetwork = CreateNetworkInternal(config); natNetwork = CreateNetworkInternal(config);
} }
catch (...) catch (...)

View File

@ -1,154 +1,154 @@
/*++ /*++
Copyright (c) Microsoft. All rights reserved. Copyright (c) Microsoft. All rights reserved.
Module Name: Module Name:
RingBuffer.cpp RingBuffer.cpp
Abstract: Abstract:
This file contains definitions for the RingBuffer class. This file contains definitions for the RingBuffer class.
--*/ --*/
#include "precomp.h" #include "precomp.h"
#include "RingBuffer.h" #include "RingBuffer.h"
RingBuffer::RingBuffer(size_t size) : m_maxSize(size), m_offset(0) RingBuffer::RingBuffer(size_t size) : m_maxSize(size), m_offset(0)
{ {
m_buffer.reserve(size); m_buffer.reserve(size);
} }
void RingBuffer::Insert(std::string_view data) void RingBuffer::Insert(std::string_view data)
{ {
auto lock = m_lock.lock_exclusive(); auto lock = m_lock.lock_exclusive();
auto remainingData = gsl::make_span(data.data(), data.size()); auto remainingData = gsl::make_span(data.data(), data.size());
if (remainingData.size() > m_maxSize) if (remainingData.size() > m_maxSize)
{ {
remainingData = remainingData.subspan(remainingData.size() - m_maxSize); remainingData = remainingData.subspan(remainingData.size() - m_maxSize);
} }
const auto bytesAtEnd = std::min(m_maxSize - m_offset, remainingData.size()); const auto bytesAtEnd = std::min(m_maxSize - m_offset, remainingData.size());
if (m_offset + bytesAtEnd > m_buffer.size()) if (m_offset + bytesAtEnd > m_buffer.size())
{ {
m_buffer.resize(m_offset + bytesAtEnd); m_buffer.resize(m_offset + bytesAtEnd);
WI_ASSERT(m_buffer.size() <= m_maxSize); WI_ASSERT(m_buffer.size() <= m_maxSize);
} }
const auto allBuffer = gsl::make_span(m_buffer); const auto allBuffer = gsl::make_span(m_buffer);
const auto beginCopyBuffer = allBuffer.subspan(m_offset, bytesAtEnd); const auto beginCopyBuffer = allBuffer.subspan(m_offset, bytesAtEnd);
copy(remainingData.subspan(0, bytesAtEnd), beginCopyBuffer); copy(remainingData.subspan(0, bytesAtEnd), beginCopyBuffer);
remainingData = remainingData.subspan(bytesAtEnd); remainingData = remainingData.subspan(bytesAtEnd);
if (!remainingData.empty()) if (!remainingData.empty())
{ {
copy(remainingData, allBuffer); copy(remainingData, allBuffer);
m_offset = remainingData.size(); m_offset = remainingData.size();
} }
else else
{ {
m_offset += bytesAtEnd; m_offset += bytesAtEnd;
} }
} }
std::vector<std::string> RingBuffer::GetLastDelimitedStrings(char Delimiter, size_t Count) const std::vector<std::string> RingBuffer::GetLastDelimitedStrings(char Delimiter, size_t Count) const
{ {
auto lock = m_lock.lock_shared(); auto lock = m_lock.lock_shared();
auto [begin, end] = Contents(); auto [begin, end] = Contents();
std::vector<std::string> results; std::vector<std::string> results;
std::optional<size_t> endIndex; std::optional<size_t> endIndex;
for (size_t i = end.size(); i > 0; i--) for (size_t i = end.size(); i > 0; i--)
{ {
if (results.size() == Count) if (results.size() == Count)
{ {
break; break;
} }
if (Delimiter == end[i - 1]) if (Delimiter == end[i - 1])
{ {
if (endIndex.has_value()) if (endIndex.has_value())
{ {
results.emplace(results.begin(), &end[i], endIndex.value() - i); results.emplace(results.begin(), &end[i], endIndex.value() - i);
endIndex.reset(); endIndex.reset();
} }
else else
{ {
endIndex = i - 1; endIndex = i - 1;
} }
} }
} }
if (results.size() == Count) if (results.size() == Count)
{ {
return results; return results;
} }
std::string partial; std::string partial;
if (endIndex.has_value()) if (endIndex.has_value())
{ {
partial = std::string{&end[0], endIndex.value()}; partial = std::string{&end[0], endIndex.value()};
endIndex.reset(); endIndex.reset();
} }
for (size_t i = begin.size(); i > 0; i--) for (size_t i = begin.size(); i > 0; i--)
{ {
if (results.size() == Count) if (results.size() == Count)
{ {
break; break;
} }
if (Delimiter == begin[i - 1]) if (Delimiter == begin[i - 1])
{ {
if (!partial.empty()) if (!partial.empty())
{ {
// The debug CRT will fastfail if begin[size] is accessed // The debug CRT will fastfail if begin[size] is accessed
// But in this case it's not a problem because begin.size() - i would be == 0 // But in this case it's not a problem because begin.size() - i would be == 0
std::string partial_begin{&begin.data()[i], begin.size() - i}; std::string partial_begin{&begin.data()[i], begin.size() - i};
results.emplace(results.begin(), partial_begin + partial); results.emplace(results.begin(), partial_begin + partial);
partial.clear(); partial.clear();
} }
else if (endIndex.has_value()) else if (endIndex.has_value())
{ {
results.emplace(results.begin(), &begin.data()[i], endIndex.value() - i); results.emplace(results.begin(), &begin.data()[i], endIndex.value() - i);
endIndex.reset(); endIndex.reset();
} }
else else
{ {
endIndex = i - 1; endIndex = i - 1;
} }
} }
} }
if (results.size() < Count) if (results.size() < Count)
{ {
// May have lost some data, or this could be the very first line logged. // May have lost some data, or this could be the very first line logged.
if (!partial.empty()) if (!partial.empty())
{ {
results.emplace(results.begin(), partial); results.emplace(results.begin(), partial);
} }
else if (endIndex.has_value()) else if (endIndex.has_value())
{ {
results.emplace(results.begin(), &begin[0], endIndex.value()); results.emplace(results.begin(), &begin[0], endIndex.value());
} }
} }
return results; return results;
} }
std::string RingBuffer::Get() const std::string RingBuffer::Get() const
{ {
auto lock = m_lock.lock_shared(); auto lock = m_lock.lock_shared();
auto [begin, end] = Contents(); auto [begin, end] = Contents();
std::string data; std::string data;
data.reserve(begin.size() + end.size()); data.reserve(begin.size() + end.size());
data.append(begin.data(), begin.size()); data.append(begin.data(), begin.size());
data.append(end.data(), end.size()); data.append(end.data(), end.size());
return data; return data;
} }
std::pair<std::string_view, std::string_view> RingBuffer::Contents() const std::pair<std::string_view, std::string_view> RingBuffer::Contents() const
{ {
std::string_view beginView(m_buffer.data() + m_offset, m_buffer.size() - m_offset); std::string_view beginView(m_buffer.data() + m_offset, m_buffer.size() - m_offset);
std::string_view endView(m_buffer.data(), m_offset); std::string_view endView(m_buffer.data(), m_offset);
return {beginView, endView}; return {beginView, endView};
} }

View File

@ -1,34 +1,34 @@
/*++ /*++
Copyright (c) Microsoft. All rights reserved. Copyright (c) Microsoft. All rights reserved.
Module Name: Module Name:
RingBuffer.h RingBuffer.h
Abstract: Abstract:
This file contains declarations for the RingBuffer class. This file contains declarations for the RingBuffer class.
--*/ --*/
#pragma once #pragma once
class RingBuffer class RingBuffer
{ {
public: public:
RingBuffer() = delete; RingBuffer() = delete;
RingBuffer(size_t size); RingBuffer(size_t size);
void Insert(std::string_view data); void Insert(std::string_view data);
std::vector<std::string> GetLastDelimitedStrings(char Delimiter, size_t Count) const; std::vector<std::string> GetLastDelimitedStrings(char Delimiter, size_t Count) const;
std::string Get() const; std::string Get() const;
private: private:
std::pair<std::string_view, std::string_view> Contents() const; std::pair<std::string_view, std::string_view> Contents() const;
mutable wil::srwlock m_lock; mutable wil::srwlock m_lock;
std::vector<char> m_buffer; std::vector<char> m_buffer;
size_t m_maxSize; size_t m_maxSize;
size_t m_offset; size_t m_offset;
}; };

View File

@ -1,104 +1,104 @@
// Copyright (C) Microsoft Corporation. All rights reserved. // Copyright (C) Microsoft Corporation. All rights reserved.
#pragma once #pragma once
#include <mutex> #include <mutex>
#include <string> #include <string>
#include <vector> #include <vector>
#include <iptypes.h> #include <iptypes.h>
#include <wil/registry.h> #include <wil/registry.h>
#include "WslCoreNetworkingSupport.h" #include "WslCoreNetworkingSupport.h"
#include "RegistryWatcher.h" #include "RegistryWatcher.h"
namespace wsl::core::networking { namespace wsl::core::networking {
struct DnsInfo struct DnsInfo
{ {
std::vector<std::string> Servers; std::vector<std::string> Servers;
std::vector<std::string> Domains; std::vector<std::string> Domains;
}; };
enum class DnsSettingsFlags enum class DnsSettingsFlags
{ {
None = 0x0, None = 0x0,
IncludeVpn = 0x1, IncludeVpn = 0x1,
IncludeIpv6Servers = 0x2, IncludeIpv6Servers = 0x2,
IncludeAllSuffixes = 0x4 IncludeAllSuffixes = 0x4
}; };
DEFINE_ENUM_FLAG_OPERATORS(DnsSettingsFlags); DEFINE_ENUM_FLAG_OPERATORS(DnsSettingsFlags);
inline bool operator==(const DnsInfo& lhs, const DnsInfo& rhs) noexcept inline bool operator==(const DnsInfo& lhs, const DnsInfo& rhs) noexcept
{ {
return lhs.Servers == rhs.Servers && lhs.Domains == rhs.Domains; return lhs.Servers == rhs.Servers && lhs.Domains == rhs.Domains;
} }
inline bool operator!=(const DnsInfo& lhs, const DnsInfo& rhs) noexcept inline bool operator!=(const DnsInfo& lhs, const DnsInfo& rhs) noexcept
{ {
return !(lhs == rhs); return !(lhs == rhs);
} }
std::string GenerateResolvConf(_In_ const DnsInfo& Info); std::string GenerateResolvConf(_In_ const DnsInfo& Info);
std::vector<std::string> GetAllDnsSuffixes(const std::vector<IpAdapterAddress>& AdapterAddresses); std::vector<std::string> GetAllDnsSuffixes(const std::vector<IpAdapterAddress>& AdapterAddresses);
DWORD GetBestInterface(); DWORD GetBestInterface();
class HostDnsInfo class HostDnsInfo
{ {
public: public:
DnsInfo GetDnsSettings(_In_ DnsSettingsFlags Flags); DnsInfo GetDnsSettings(_In_ DnsSettingsFlags Flags);
void UpdateNetworkInformation(); void UpdateNetworkInformation();
static DnsInfo GetDnsTunnelingSettings(const std::wstring& dnsTunnelingNameserver); static DnsInfo GetDnsTunnelingSettings(const std::wstring& dnsTunnelingNameserver);
const std::vector<IpAdapterAddress>& CurrentAddresses() const const std::vector<IpAdapterAddress>& CurrentAddresses() const
{ {
return m_addresses; return m_addresses;
} }
private: private:
/// <summary> /// <summary>
/// Internal function to retrieve the latest copy of interface information. /// Internal function to retrieve the latest copy of interface information.
/// </summary> /// </summary>
std::vector<IpAdapterAddress> GetAdapterAddresses(); std::vector<IpAdapterAddress> GetAdapterAddresses();
/// <summary> /// <summary>
/// Internal function to retrieve interface DNS servers. /// Internal function to retrieve interface DNS servers.
/// </summary> /// </summary>
std::vector<std::string> GetInterfaceDnsServers(const std::vector<IpAdapterAddress>& AdapterAddresses, _In_ DnsSettingsFlags Flags); std::vector<std::string> GetInterfaceDnsServers(const std::vector<IpAdapterAddress>& AdapterAddresses, _In_ DnsSettingsFlags Flags);
/// <summary> /// <summary>
/// Internal function to retrieve all Windows DNS suffixes. /// Internal function to retrieve all Windows DNS suffixes.
/// </summary> /// </summary>
static std::vector<std::string> GetInterfaceDnsSuffixes(const std::vector<IpAdapterAddress>& AdapterAddresses); static std::vector<std::string> GetInterfaceDnsSuffixes(const std::vector<IpAdapterAddress>& AdapterAddresses);
/// <summary> /// <summary>
/// Internal function to convert DNS server addresses into strings. /// Internal function to convert DNS server addresses into strings.
/// </summary> /// </summary>
static std::vector<std::string> GetDnsServerStrings(_In_ const PIP_ADAPTER_DNS_SERVER_ADDRESS& DnsServer, _In_ USHORT IpFamilyFilter, _In_ USHORT MaxValues); static std::vector<std::string> GetDnsServerStrings(_In_ const PIP_ADAPTER_DNS_SERVER_ADDRESS& DnsServer, _In_ USHORT IpFamilyFilter, _In_ USHORT MaxValues);
/// <summary> /// <summary>
/// Stores latest copy of interface information. /// Stores latest copy of interface information.
/// </summary> /// </summary>
std::mutex m_lock; std::mutex m_lock;
_Guarded_by_(m_lock) std::vector<IpAdapterAddress> m_addresses; _Guarded_by_(m_lock) std::vector<IpAdapterAddress> m_addresses;
}; };
using RegistryChangeCallback = std::function<void()>; using RegistryChangeCallback = std::function<void()>;
/// <summary> /// <summary>
/// Class used to get notifications when Windows DNS suffixes are updated in registry. /// Class used to get notifications when Windows DNS suffixes are updated in registry.
/// </summary> /// </summary>
class DnsSuffixRegistryWatcher class DnsSuffixRegistryWatcher
{ {
public: public:
DnsSuffixRegistryWatcher(RegistryChangeCallback&& reportRegistryChange); DnsSuffixRegistryWatcher(RegistryChangeCallback&& reportRegistryChange);
~DnsSuffixRegistryWatcher() noexcept = default; ~DnsSuffixRegistryWatcher() noexcept = default;
private: private:
RegistryChangeCallback m_reportRegistryChange; RegistryChangeCallback m_reportRegistryChange;
std::vector<wistd::unique_ptr<wsl::windows::common::slim_registry_watcher>> m_registryWatchers; std::vector<wistd::unique_ptr<wsl::windows::common::slim_registry_watcher>> m_registryWatchers;
}; };
} // namespace wsl::core::networking } // namespace wsl::core::networking

View File

@ -1,359 +1,359 @@
/*++ /*++
Copyright (c) Microsoft. All rights reserved. Copyright (c) Microsoft. All rights reserved.
Module Name: Module Name:
WslCoreMessageQueue.h WslCoreMessageQueue.h
Abstract: Abstract:
This file contains a queuing implementation, guaranteeing running function objects This file contains a queuing implementation, guaranteeing running function objects
with guaranteed serialization in a threadpool thread with guaranteed serialization in a threadpool thread
--*/ --*/
#pragma once #pragma once
#include <deque> #include <deque>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <variant> #include <variant>
#include <windows.h> #include <windows.h>
#include <wil/resource.h> #include <wil/resource.h>
namespace wsl::core { namespace wsl::core {
// forward-declare classes that can instantiate a WslThreadPoolWaitableResult object // forward-declare classes that can instantiate a WslThreadPoolWaitableResult object
class WslCoreMessageQueue; class WslCoreMessageQueue;
class WslBaseThreadPoolWaitableResult class WslBaseThreadPoolWaitableResult
{ {
public: public:
virtual ~WslBaseThreadPoolWaitableResult() noexcept = default; virtual ~WslBaseThreadPoolWaitableResult() noexcept = default;
private: private:
// limit who can run() and abort() // limit who can run() and abort()
friend class WslCoreMessageQueue; friend class WslCoreMessageQueue;
virtual void run() noexcept = 0; virtual void run() noexcept = 0;
virtual void abort() noexcept = 0; virtual void abort() noexcept = 0;
}; };
template <typename TReturn> template <typename TReturn>
class WslThreadPoolWaitableResult : public WslBaseThreadPoolWaitableResult class WslThreadPoolWaitableResult : public WslBaseThreadPoolWaitableResult
{ {
public: public:
// throws a wil exception on failure // throws a wil exception on failure
template <typename FunctorType> template <typename FunctorType>
explicit WslThreadPoolWaitableResult(FunctorType&& functor) : m_function(std::forward<FunctorType>(functor)) explicit WslThreadPoolWaitableResult(FunctorType&& functor) : m_function(std::forward<FunctorType>(functor))
{ {
} }
~WslThreadPoolWaitableResult() noexcept override = default; ~WslThreadPoolWaitableResult() noexcept override = default;
// returns ERROR_SUCCESS if the callback ran to completion // returns ERROR_SUCCESS if the callback ran to completion
// returns ERROR_TIMEOUT if this wait timed out // returns ERROR_TIMEOUT if this wait timed out
// - this can be called multiple times if needing to probe // - this can be called multiple times if needing to probe
// any other error code resulted from attempting to run the callback // any other error code resulted from attempting to run the callback
// - meaning it did *not* run to completion // - meaning it did *not* run to completion
DWORD wait(DWORD timeout) const noexcept DWORD wait(DWORD timeout) const noexcept
{ {
if (!m_completionSignal.wait(timeout)) if (!m_completionSignal.wait(timeout))
{ {
// not setting m_internalError to timeout // not setting m_internalError to timeout
// since the caller is allowed to try to wait() again later // since the caller is allowed to try to wait() again later
return ERROR_TIMEOUT; return ERROR_TIMEOUT;
} }
const auto lock = m_lock.lock_shared(); const auto lock = m_lock.lock_shared();
return m_internalError; return m_internalError;
} }
// waitable event handle, signaled when the callback has run to completion (or failed) // waitable event handle, signaled when the callback has run to completion (or failed)
HANDLE notification_event() const noexcept HANDLE notification_event() const noexcept
{ {
return m_completionSignal.get(); return m_completionSignal.get();
} }
const TReturn& read_result() const noexcept const TReturn& read_result() const noexcept
{ {
return result; return result;
} }
// move the result out of the object for move-only types // move the result out of the object for move-only types
TReturn move_result() noexcept TReturn move_result() noexcept
{ {
TReturn move_out(std::move(result)); TReturn move_out(std::move(result));
return move_out; return move_out;
} }
// non-copyable // non-copyable
WslThreadPoolWaitableResult(const WslThreadPoolWaitableResult&) = delete; WslThreadPoolWaitableResult(const WslThreadPoolWaitableResult&) = delete;
WslThreadPoolWaitableResult& operator=(const WslThreadPoolWaitableResult&) = delete; WslThreadPoolWaitableResult& operator=(const WslThreadPoolWaitableResult&) = delete;
private: private:
void run() noexcept override void run() noexcept override
{ {
// we are now running in the TP callback // we are now running in the TP callback
{ {
const auto lock = m_lock.lock_exclusive(); const auto lock = m_lock.lock_exclusive();
if (m_runStatus != RunStatus::NotYetRun) if (m_runStatus != RunStatus::NotYetRun)
{ {
// return early - the caller has already canceled this // return early - the caller has already canceled this
return; return;
} }
m_runStatus = RunStatus::Running; m_runStatus = RunStatus::Running;
} }
DWORD error = NO_ERROR; DWORD error = NO_ERROR;
try try
{ {
result = std::move(m_function()); result = std::move(m_function());
} }
catch (...) catch (...)
{ {
const HRESULT hr = wil::ResultFromCaughtException(); const HRESULT hr = wil::ResultFromCaughtException();
// HRESULT_TO_WIN32 // HRESULT_TO_WIN32
error = (HRESULT_FACILITY(hr) == FACILITY_WIN32) ? HRESULT_CODE(hr) : hr; error = (HRESULT_FACILITY(hr) == FACILITY_WIN32) ? HRESULT_CODE(hr) : hr;
} }
const auto lock = m_lock.lock_exclusive(); const auto lock = m_lock.lock_exclusive();
WI_ASSERT(m_runStatus == RunStatus::Running); WI_ASSERT(m_runStatus == RunStatus::Running);
m_runStatus = RunStatus::RanToCompletion; m_runStatus = RunStatus::RanToCompletion;
m_internalError = error; m_internalError = error;
m_completionSignal.SetEvent(); m_completionSignal.SetEvent();
} }
void abort() noexcept override void abort() noexcept override
{ {
const auto lock = m_lock.lock_exclusive(); const auto lock = m_lock.lock_exclusive();
// only override the error if we know we haven't started running their functor // only override the error if we know we haven't started running their functor
if (m_runStatus == RunStatus::NotYetRun) if (m_runStatus == RunStatus::NotYetRun)
{ {
m_runStatus = RunStatus::Canceled; m_runStatus = RunStatus::Canceled;
m_internalError = ERROR_CANCELLED; m_internalError = ERROR_CANCELLED;
m_completionSignal.SetEvent(); m_completionSignal.SetEvent();
} }
} }
std::function<TReturn(void)> m_function; std::function<TReturn(void)> m_function;
// a notification event // a notification event
wil::unique_event m_completionSignal{wil::EventOptions::ManualReset}; wil::unique_event m_completionSignal{wil::EventOptions::ManualReset};
mutable wil::srwlock m_lock; mutable wil::srwlock m_lock;
TReturn result{}; TReturn result{};
DWORD m_internalError = NO_ERROR; DWORD m_internalError = NO_ERROR;
enum class RunStatus enum class RunStatus
{ {
NotYetRun, NotYetRun,
Running, Running,
RanToCompletion, RanToCompletion,
Canceled Canceled
} m_runStatus{RunStatus::NotYetRun}; } m_runStatus{RunStatus::NotYetRun};
}; };
class WslCoreMessageQueue class WslCoreMessageQueue
{ {
public: public:
WslCoreMessageQueue() : m_tpEnvironment(0, 1) WslCoreMessageQueue() : m_tpEnvironment(0, 1)
{ {
// create a single-threaded threadpool // create a single-threaded threadpool
m_tpHandle = m_tpEnvironment.create_tp(WorkCallback, this); m_tpHandle = m_tpEnvironment.create_tp(WorkCallback, this);
} }
template <typename TReturn, typename FunctorType> template <typename TReturn, typename FunctorType>
std::shared_ptr<WslThreadPoolWaitableResult<TReturn>> submit_with_results(FunctorType&& functor) noexcept std::shared_ptr<WslThreadPoolWaitableResult<TReturn>> submit_with_results(FunctorType&& functor) noexcept
try try
{ {
FAIL_FAST_IF(m_tpHandle.get() == nullptr); FAIL_FAST_IF(m_tpHandle.get() == nullptr);
const auto new_result = std::make_shared<WslThreadPoolWaitableResult<TReturn>>(std::forward<FunctorType>(functor)); const auto new_result = std::make_shared<WslThreadPoolWaitableResult<TReturn>>(std::forward<FunctorType>(functor));
// scope to the queue lock // scope to the queue lock
{ {
const auto queueLock = m_lock.lock_exclusive(); const auto queueLock = m_lock.lock_exclusive();
THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_CANCELLED), m_isCanceled); THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_CANCELLED), m_isCanceled);
m_workItems.emplace_back(new_result); m_workItems.emplace_back(new_result);
} }
// always maintain a 1:1 ratio for calls to SubmitWorkWithResults() and ::SubmitThreadpoolWork // always maintain a 1:1 ratio for calls to SubmitWorkWithResults() and ::SubmitThreadpoolWork
SubmitThreadpoolWork(m_tpHandle.get()); SubmitThreadpoolWork(m_tpHandle.get());
return new_result; return new_result;
} }
catch (...) catch (...)
{ {
LOG_CAUGHT_EXCEPTION(); LOG_CAUGHT_EXCEPTION();
return nullptr; return nullptr;
} }
template <typename FunctorType> template <typename FunctorType>
bool submit(FunctorType&& functor) noexcept bool submit(FunctorType&& functor) noexcept
try try
{ {
FAIL_FAST_IF(m_tpHandle.get() == nullptr); FAIL_FAST_IF(m_tpHandle.get() == nullptr);
// scope to the queue lock // scope to the queue lock
{ {
const auto queueLock = m_lock.lock_exclusive(); const auto queueLock = m_lock.lock_exclusive();
THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_CANCELLED), m_isCanceled); THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_CANCELLED), m_isCanceled);
m_workItems.emplace_back(std::forward<SimpleFunction_t>(functor)); m_workItems.emplace_back(std::forward<SimpleFunction_t>(functor));
} }
// always maintain a 1:1 ratio for calls to SubmitWork() and ::SubmitThreadpoolWork // always maintain a 1:1 ratio for calls to SubmitWork() and ::SubmitThreadpoolWork
SubmitThreadpoolWork(m_tpHandle.get()); SubmitThreadpoolWork(m_tpHandle.get());
return true; return true;
} }
catch (...) catch (...)
{ {
LOG_CAUGHT_EXCEPTION(); LOG_CAUGHT_EXCEPTION();
return false; return false;
} }
// functors must return type HRESULT // functors must return type HRESULT
template <typename FunctorType> template <typename FunctorType>
HRESULT submit_and_wait(FunctorType&& functor) noexcept HRESULT submit_and_wait(FunctorType&& functor) noexcept
try try
{ {
HRESULT hr = HRESULT_FROM_WIN32(ERROR_OUTOFMEMORY); HRESULT hr = HRESULT_FROM_WIN32(ERROR_OUTOFMEMORY);
if (const auto waitableResult = submit_with_results<HRESULT>(std::forward<FunctorType>(functor))) if (const auto waitableResult = submit_with_results<HRESULT>(std::forward<FunctorType>(functor)))
{ {
hr = HRESULT_FROM_WIN32(waitableResult->wait(INFINITE)); hr = HRESULT_FROM_WIN32(waitableResult->wait(INFINITE));
if (SUCCEEDED(hr)) if (SUCCEEDED(hr))
{ {
hr = waitableResult->read_result(); hr = waitableResult->read_result();
} }
} }
return hr; return hr;
} }
CATCH_RETURN() CATCH_RETURN()
// cancels anything queued to the TP - this WslCoreMessageQueue instance can no longer be used // cancels anything queued to the TP - this WslCoreMessageQueue instance can no longer be used
void cancel() noexcept void cancel() noexcept
try try
{ {
if (m_tpHandle) if (m_tpHandle)
{ {
// immediately release anyone waiting for these workitems not yet run // immediately release anyone waiting for these workitems not yet run
{ {
const auto queueLock = m_lock.lock_exclusive(); const auto queueLock = m_lock.lock_exclusive();
m_isCanceled = true; m_isCanceled = true;
for (const auto& work : m_workItems) for (const auto& work : m_workItems)
{ {
// signal that these are canceled before we shutdown the TP which they could be scheduled // signal that these are canceled before we shutdown the TP which they could be scheduled
if (const auto* pWaitableWorkitem = std::get_if<WaitableFunction_t>(&work)) if (const auto* pWaitableWorkitem = std::get_if<WaitableFunction_t>(&work))
{ {
(*pWaitableWorkitem)->abort(); (*pWaitableWorkitem)->abort();
} }
} }
m_workItems.clear(); m_workItems.clear();
} }
// force the m_tpHandle to wait and close the TP // force the m_tpHandle to wait and close the TP
m_tpHandle.reset(); m_tpHandle.reset();
m_tpEnvironment.reset(); m_tpEnvironment.reset();
} }
} }
CATCH_LOG() CATCH_LOG()
bool isRunningInQueue() const noexcept bool isRunningInQueue() const noexcept
{ {
const auto currentThreadId = GetThreadId(GetCurrentThread()); const auto currentThreadId = GetThreadId(GetCurrentThread());
return currentThreadId == static_cast<DWORD>(InterlockedCompareExchange64(&m_threadpoolThreadId, 0ll, 0ll)); return currentThreadId == static_cast<DWORD>(InterlockedCompareExchange64(&m_threadpoolThreadId, 0ll, 0ll));
} }
~WslCoreMessageQueue() noexcept ~WslCoreMessageQueue() noexcept
{ {
cancel(); cancel();
} }
WslCoreMessageQueue(const WslCoreMessageQueue&) = delete; WslCoreMessageQueue(const WslCoreMessageQueue&) = delete;
WslCoreMessageQueue& operator=(const WslCoreMessageQueue&) = delete; WslCoreMessageQueue& operator=(const WslCoreMessageQueue&) = delete;
WslCoreMessageQueue(WslCoreMessageQueue&&) = delete; WslCoreMessageQueue(WslCoreMessageQueue&&) = delete;
WslCoreMessageQueue& operator=(WslCoreMessageQueue&&) = delete; WslCoreMessageQueue& operator=(WslCoreMessageQueue&&) = delete;
private: private:
struct TPEnvironment struct TPEnvironment
{ {
using unique_tp_env = wil::unique_struct<TP_CALLBACK_ENVIRON, decltype(&DestroyThreadpoolEnvironment), DestroyThreadpoolEnvironment>; using unique_tp_env = wil::unique_struct<TP_CALLBACK_ENVIRON, decltype(&DestroyThreadpoolEnvironment), DestroyThreadpoolEnvironment>;
unique_tp_env m_tpEnvironment; unique_tp_env m_tpEnvironment;
using unique_tp_pool = wil::unique_any<PTP_POOL, decltype(&CloseThreadpool), CloseThreadpool>; using unique_tp_pool = wil::unique_any<PTP_POOL, decltype(&CloseThreadpool), CloseThreadpool>;
unique_tp_pool m_threadPool; unique_tp_pool m_threadPool;
TPEnvironment(DWORD countMinThread, DWORD countMaxThread) TPEnvironment(DWORD countMinThread, DWORD countMaxThread)
{ {
InitializeThreadpoolEnvironment(&m_tpEnvironment); InitializeThreadpoolEnvironment(&m_tpEnvironment);
m_threadPool.reset(CreateThreadpool(nullptr)); m_threadPool.reset(CreateThreadpool(nullptr));
THROW_LAST_ERROR_IF_NULL(m_threadPool.get()); THROW_LAST_ERROR_IF_NULL(m_threadPool.get());
// Set min and max thread counts for custom thread pool // Set min and max thread counts for custom thread pool
THROW_LAST_ERROR_IF(!::SetThreadpoolThreadMinimum(m_threadPool.get(), countMinThread)); THROW_LAST_ERROR_IF(!::SetThreadpoolThreadMinimum(m_threadPool.get(), countMinThread));
SetThreadpoolThreadMaximum(m_threadPool.get(), countMaxThread); SetThreadpoolThreadMaximum(m_threadPool.get(), countMaxThread);
SetThreadpoolCallbackPool(&m_tpEnvironment, m_threadPool.get()); SetThreadpoolCallbackPool(&m_tpEnvironment, m_threadPool.get());
} }
wil::unique_threadpool_work create_tp(PTP_WORK_CALLBACK callback, void* pv) wil::unique_threadpool_work create_tp(PTP_WORK_CALLBACK callback, void* pv)
{ {
wil::unique_threadpool_work newThreadpool(CreateThreadpoolWork(callback, pv, (m_threadPool) ? &m_tpEnvironment : nullptr)); wil::unique_threadpool_work newThreadpool(CreateThreadpoolWork(callback, pv, (m_threadPool) ? &m_tpEnvironment : nullptr));
THROW_LAST_ERROR_IF_NULL(newThreadpool.get()); THROW_LAST_ERROR_IF_NULL(newThreadpool.get());
return newThreadpool; return newThreadpool;
} }
void reset() void reset()
{ {
m_threadPool.reset(); m_threadPool.reset();
m_tpEnvironment.reset(); m_tpEnvironment.reset();
} }
}; };
using SimpleFunction_t = std::function<void()>; using SimpleFunction_t = std::function<void()>;
using WaitableFunction_t = std::shared_ptr<WslBaseThreadPoolWaitableResult>; using WaitableFunction_t = std::shared_ptr<WslBaseThreadPoolWaitableResult>;
using FunctionVariant_t = std::variant<SimpleFunction_t, WaitableFunction_t>; using FunctionVariant_t = std::variant<SimpleFunction_t, WaitableFunction_t>;
// the lock must be destroyed *after* the TP object (thus must be declared first) // the lock must be destroyed *after* the TP object (thus must be declared first)
// since the lock is used in the TP callback // since the lock is used in the TP callback
// the lock is mutable to allow us to acquire the lock in const methods // the lock is mutable to allow us to acquire the lock in const methods
mutable wil::srwlock m_lock; mutable wil::srwlock m_lock;
TPEnvironment m_tpEnvironment; TPEnvironment m_tpEnvironment;
wil::unique_threadpool_work m_tpHandle; wil::unique_threadpool_work m_tpHandle;
std::deque<FunctionVariant_t> m_workItems; std::deque<FunctionVariant_t> m_workItems;
mutable LONG64 m_threadpoolThreadId{0}; // useful for callers to assert they are running within the queue mutable LONG64 m_threadpoolThreadId{0}; // useful for callers to assert they are running within the queue
bool m_isCanceled{false}; bool m_isCanceled{false};
static void CALLBACK WorkCallback(PTP_CALLBACK_INSTANCE, void* Context, PTP_WORK) noexcept static void CALLBACK WorkCallback(PTP_CALLBACK_INSTANCE, void* Context, PTP_WORK) noexcept
try try
{ {
auto* pThis = static_cast<WslCoreMessageQueue*>(Context); auto* pThis = static_cast<WslCoreMessageQueue*>(Context);
FunctionVariant_t work; FunctionVariant_t work;
{ {
const auto queueLock = pThis->m_lock.lock_exclusive(); const auto queueLock = pThis->m_lock.lock_exclusive();
if (pThis->m_workItems.empty()) if (pThis->m_workItems.empty())
{ {
// pThis object is being destroyed and the queue was cleared // pThis object is being destroyed and the queue was cleared
return; return;
} }
std::swap(work, pThis->m_workItems.front()); std::swap(work, pThis->m_workItems.front());
pThis->m_workItems.pop_front(); pThis->m_workItems.pop_front();
InterlockedExchange64(&pThis->m_threadpoolThreadId, GetThreadId(GetCurrentThread())); InterlockedExchange64(&pThis->m_threadpoolThreadId, GetThreadId(GetCurrentThread()));
} }
// run the tasks outside the WslCoreMessageQueue lock // run the tasks outside the WslCoreMessageQueue lock
const auto resetThreadIdOnExit = wil::scope_exit([pThis] { InterlockedExchange64(&pThis->m_threadpoolThreadId, 0ll); }); const auto resetThreadIdOnExit = wil::scope_exit([pThis] { InterlockedExchange64(&pThis->m_threadpoolThreadId, 0ll); });
if (work.index() == 0) if (work.index() == 0)
{ {
const auto& workItem = std::get<SimpleFunction_t>(work); const auto& workItem = std::get<SimpleFunction_t>(work);
workItem(); workItem();
} }
else else
{ {
const auto& waitableWorkItem = std::get<WaitableFunction_t>(work); const auto& waitableWorkItem = std::get<WaitableFunction_t>(work);
waitableWorkItem->run(); waitableWorkItem->run();
} }
} }
CATCH_LOG() CATCH_LOG()
}; };
} // namespace wsl::core } // namespace wsl::core

View File

@ -1,110 +1,110 @@
// Copyright (C) Microsoft Corporation. All rights reserved. // Copyright (C) Microsoft Corporation. All rights reserved.
#include "precomp.h" #include "precomp.h"
#include "hns_schema.h" #include "hns_schema.h"
#include "WslCoreNetworkEndpointSettings.h" #include "WslCoreNetworkEndpointSettings.h"
#include "WslCoreHostDnsInfo.h" #include "WslCoreHostDnsInfo.h"
using namespace wsl::shared; using namespace wsl::shared;
std::shared_ptr<wsl::core::networking::NetworkSettings> wsl::core::networking::GetEndpointSettings(const hns::HNSEndpoint& properties) std::shared_ptr<wsl::core::networking::NetworkSettings> wsl::core::networking::GetEndpointSettings(const hns::HNSEndpoint& properties)
{ {
EndpointIpAddress address{}; EndpointIpAddress address{};
address.Address = windows::common::string::StringToSockAddrInet(properties.IPAddress); address.Address = windows::common::string::StringToSockAddrInet(properties.IPAddress);
address.AddressString = properties.IPAddress; address.AddressString = properties.IPAddress;
address.PrefixLength = properties.PrefixLength; address.PrefixLength = properties.PrefixLength;
EndpointRoute route{}; EndpointRoute route{};
route.DestinationPrefix.PrefixLength = 0; route.DestinationPrefix.PrefixLength = 0;
IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4); IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4);
route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS; route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS;
route.NextHop = windows::common::string::StringToSockAddrInet(properties.GatewayAddress); route.NextHop = windows::common::string::StringToSockAddrInet(properties.GatewayAddress);
route.NextHopString = properties.GatewayAddress; route.NextHopString = properties.GatewayAddress;
return std::make_shared<wsl::core::networking::NetworkSettings>( return std::make_shared<wsl::core::networking::NetworkSettings>(
properties.InterfaceConstraint.InterfaceGuid, properties.InterfaceConstraint.InterfaceGuid,
address, address,
route, route,
properties.MacAddress, properties.MacAddress,
L"unuseddevicename", L"unuseddevicename",
properties.InterfaceConstraint.InterfaceIndex, properties.InterfaceConstraint.InterfaceIndex,
properties.InterfaceConstraint.InterfaceMediaType, properties.InterfaceConstraint.InterfaceMediaType,
properties.DNSServerList); properties.DNSServerList);
} }
std::shared_ptr<wsl::core::networking::NetworkSettings> wsl::core::networking::GetHostEndpointSettings() std::shared_ptr<wsl::core::networking::NetworkSettings> wsl::core::networking::GetHostEndpointSettings()
{ {
HostDnsInfo dnsInfo; HostDnsInfo dnsInfo;
dnsInfo.UpdateNetworkInformation(); dnsInfo.UpdateNetworkInformation();
auto addresses = dnsInfo.CurrentAddresses(); auto addresses = dnsInfo.CurrentAddresses();
auto bestIndex = GetBestInterface(); auto bestIndex = GetBestInterface();
auto bestInterfacePtr = auto bestInterfacePtr =
std::find_if(addresses.cbegin(), addresses.cend(), [&](const auto& address) { return address->IfIndex == bestIndex; }); std::find_if(addresses.cbegin(), addresses.cend(), [&](const auto& address) { return address->IfIndex == bestIndex; });
if (bestInterfacePtr == addresses.end()) if (bestInterfacePtr == addresses.end())
{ {
return std::make_shared<NetworkSettings>(); return std::make_shared<NetworkSettings>();
} }
const auto& bestInterface = *bestInterfacePtr; const auto& bestInterface = *bestInterfacePtr;
std::wstring macAddress = wsl::shared::string::FormatMacAddress( std::wstring macAddress = wsl::shared::string::FormatMacAddress(
wsl::shared::string::MacAddress{ wsl::shared::string::MacAddress{
bestInterface->PhysicalAddress[0], bestInterface->PhysicalAddress[0],
bestInterface->PhysicalAddress[1], bestInterface->PhysicalAddress[1],
bestInterface->PhysicalAddress[2], bestInterface->PhysicalAddress[2],
bestInterface->PhysicalAddress[3], bestInterface->PhysicalAddress[3],
bestInterface->PhysicalAddress[4], bestInterface->PhysicalAddress[4],
bestInterface->PhysicalAddress[5]}, bestInterface->PhysicalAddress[5]},
L'-'); L'-');
EndpointIpAddress address{}; EndpointIpAddress address{};
auto firstIpv4Address = bestInterface->FirstUnicastAddress; auto firstIpv4Address = bestInterface->FirstUnicastAddress;
while (firstIpv4Address && firstIpv4Address->Address.lpSockaddr->sa_family != AF_INET) while (firstIpv4Address && firstIpv4Address->Address.lpSockaddr->sa_family != AF_INET)
{ {
firstIpv4Address = firstIpv4Address->Next; firstIpv4Address = firstIpv4Address->Next;
} }
if (firstIpv4Address) if (firstIpv4Address)
{ {
address.Address = *reinterpret_cast<SOCKADDR_INET*>(firstIpv4Address->Address.lpSockaddr); address.Address = *reinterpret_cast<SOCKADDR_INET*>(firstIpv4Address->Address.lpSockaddr);
address.AddressString = windows::common::string::SockAddrInetToWstring(address.Address); address.AddressString = windows::common::string::SockAddrInetToWstring(address.Address);
address.PrefixLength = firstIpv4Address->OnLinkPrefixLength; address.PrefixLength = firstIpv4Address->OnLinkPrefixLength;
} }
EndpointRoute route{}; EndpointRoute route{};
PIP_ADAPTER_GATEWAY_ADDRESS nextGatewayAddress = bestInterface->FirstGatewayAddress; PIP_ADAPTER_GATEWAY_ADDRESS nextGatewayAddress = bestInterface->FirstGatewayAddress;
while (nextGatewayAddress && nextGatewayAddress->Address.lpSockaddr->sa_family != AF_INET) while (nextGatewayAddress && nextGatewayAddress->Address.lpSockaddr->sa_family != AF_INET)
{ {
nextGatewayAddress = nextGatewayAddress->Next; nextGatewayAddress = nextGatewayAddress->Next;
} }
if (nextGatewayAddress) if (nextGatewayAddress)
{ {
route.DestinationPrefix.PrefixLength = 0; route.DestinationPrefix.PrefixLength = 0;
IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4); IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4);
route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS; route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS;
route.NextHop = *reinterpret_cast<SOCKADDR_INET*>(nextGatewayAddress->Address.lpSockaddr); route.NextHop = *reinterpret_cast<SOCKADDR_INET*>(nextGatewayAddress->Address.lpSockaddr);
route.NextHopString = windows::common::string::SockAddrInetToWstring(route.NextHop); route.NextHopString = windows::common::string::SockAddrInetToWstring(route.NextHop);
} }
else if (address.Address.si_family == AF_INET) else if (address.Address.si_family == AF_INET)
{ {
IN_ADDR default_route{}; IN_ADDR default_route{};
default_route.s_addr = htonl((ntohl(address.Address.Ipv4.sin_addr.s_addr) & ~((1 << (32 - address.PrefixLength)) - 1)) | 1); default_route.s_addr = htonl((ntohl(address.Address.Ipv4.sin_addr.s_addr) & ~((1 << (32 - address.PrefixLength)) - 1)) | 1);
route.DestinationPrefix.PrefixLength = 0; route.DestinationPrefix.PrefixLength = 0;
IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4); IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4);
route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS; route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS;
IN4ADDR_SETSOCKADDR(&route.NextHop.Ipv4, &default_route, 0); IN4ADDR_SETSOCKADDR(&route.NextHop.Ipv4, &default_route, 0);
route.NextHopString = windows::common::string::SockAddrInetToWstring(route.NextHop); route.NextHopString = windows::common::string::SockAddrInetToWstring(route.NextHop);
} }
std::wstring dnsServerList; std::wstring dnsServerList;
for (const auto& serverAddress : dnsInfo.GetDnsSettings(DnsSettingsFlags::IncludeVpn).Servers) for (const auto& serverAddress : dnsInfo.GetDnsSettings(DnsSettingsFlags::IncludeVpn).Servers)
{ {
if (!dnsServerList.empty()) if (!dnsServerList.empty())
{ {
dnsServerList += L","; dnsServerList += L",";
} }
dnsServerList += wsl::shared::string::MultiByteToWide(serverAddress); dnsServerList += wsl::shared::string::MultiByteToWide(serverAddress);
} }
return std::shared_ptr<NetworkSettings>(new NetworkSettings( return std::shared_ptr<NetworkSettings>(new NetworkSettings(
bestInterface->NetworkGuid, address, route, macAddress, {}, bestInterface->IfIndex, bestInterface->IfType, dnsServerList)); bestInterface->NetworkGuid, address, route, macAddress, {}, bestInterface->IfIndex, bestInterface->IfType, dnsServerList));
} }

View File

@ -1,398 +1,398 @@
// Copyright (C) Microsoft Corporation. All rights reserved. // Copyright (C) Microsoft Corporation. All rights reserved.
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include <set> #include <set>
#include <string> #include <string>
#include <windows.h> #include <windows.h>
#include <mstcpip.h> #include <mstcpip.h>
#include <ws2ipdef.h> #include <ws2ipdef.h>
#include <netioapi.h> #include <netioapi.h>
#include "hcs.hpp" #include "hcs.hpp"
#include "lxinitshared.h" #include "lxinitshared.h"
#include "Stringify.h" #include "Stringify.h"
#include "stringshared.h" #include "stringshared.h"
#include "WslCoreNetworkingSupport.h" #include "WslCoreNetworkingSupport.h"
#include "hns_schema.h" #include "hns_schema.h"
namespace wsl::core::networking { namespace wsl::core::networking {
constexpr auto AddEndpointRetryPeriod = std::chrono::milliseconds(100); constexpr auto AddEndpointRetryPeriod = std::chrono::milliseconds(100);
constexpr auto AddEndpointRetryTimeout = std::chrono::seconds(3); constexpr auto AddEndpointRetryTimeout = std::chrono::seconds(3);
constexpr auto AddEndpointRetryPredicate = [] { constexpr auto AddEndpointRetryPredicate = [] {
// Don't retry if ModifyComputeSystem fails with: // Don't retry if ModifyComputeSystem fails with:
// HCN_E_ENDPOINT_NOT_FOUND - indicates that the underlying network object was deleted. // HCN_E_ENDPOINT_NOT_FOUND - indicates that the underlying network object was deleted.
// HCN_E_ENDPOINT_ALREADY_ATTACHED - occurs when HNS was restarted before the endpoints were removed. // HCN_E_ENDPOINT_ALREADY_ATTACHED - occurs when HNS was restarted before the endpoints were removed.
// VM_E_INVALID_STATE - occurs when the VM has been terminated. // VM_E_INVALID_STATE - occurs when the VM has been terminated.
const auto result = wil::ResultFromCaughtException(); const auto result = wil::ResultFromCaughtException();
return result != HCN_E_ENDPOINT_NOT_FOUND && result != HCN_E_ENDPOINT_ALREADY_ATTACHED && result != VM_E_INVALID_STATE; return result != HCN_E_ENDPOINT_NOT_FOUND && result != HCN_E_ENDPOINT_ALREADY_ATTACHED && result != VM_E_INVALID_STATE;
}; };
struct EndpointIpAddress struct EndpointIpAddress
{ {
SOCKADDR_INET Address{}; SOCKADDR_INET Address{};
std::wstring AddressString{}; std::wstring AddressString{};
unsigned char PrefixLength = 0; unsigned char PrefixLength = 0;
unsigned int PrefixOrigin = 0; unsigned int PrefixOrigin = 0;
unsigned int SuffixOrigin = 0; unsigned int SuffixOrigin = 0;
// The following field can be changed from a const iterator in SyncIpStateWithLinux - that's why it's marked mutable. // The following field can be changed from a const iterator in SyncIpStateWithLinux - that's why it's marked mutable.
mutable unsigned int PreferredLifetime = 0; mutable unsigned int PreferredLifetime = 0;
EndpointIpAddress() = default; EndpointIpAddress() = default;
~EndpointIpAddress() noexcept = default; ~EndpointIpAddress() noexcept = default;
EndpointIpAddress(EndpointIpAddress&&) = default; EndpointIpAddress(EndpointIpAddress&&) = default;
EndpointIpAddress& operator=(EndpointIpAddress&&) = default; EndpointIpAddress& operator=(EndpointIpAddress&&) = default;
EndpointIpAddress(const EndpointIpAddress&) = default; EndpointIpAddress(const EndpointIpAddress&) = default;
EndpointIpAddress& operator=(const EndpointIpAddress&) = default; EndpointIpAddress& operator=(const EndpointIpAddress&) = default;
explicit EndpointIpAddress(const MIB_UNICASTIPADDRESS_ROW& AddressRow) : explicit EndpointIpAddress(const MIB_UNICASTIPADDRESS_ROW& AddressRow) :
Address(AddressRow.Address), Address(AddressRow.Address),
AddressString(windows::common::string::SockAddrInetToWstring(AddressRow.Address)), AddressString(windows::common::string::SockAddrInetToWstring(AddressRow.Address)),
PrefixLength(AddressRow.OnLinkPrefixLength), PrefixLength(AddressRow.OnLinkPrefixLength),
PrefixOrigin(AddressRow.PrefixOrigin), PrefixOrigin(AddressRow.PrefixOrigin),
SuffixOrigin(AddressRow.SuffixOrigin), SuffixOrigin(AddressRow.SuffixOrigin),
// We treat the preferred lifetime field as effective DAD state - 0 is not preferred, anything else is preferred. // We treat the preferred lifetime field as effective DAD state - 0 is not preferred, anything else is preferred.
// We do this for convenience, as we can't directly set the DAD state of an address into the guest, but we // We do this for convenience, as we can't directly set the DAD state of an address into the guest, but we
// we can set an address's preferred lifetime (in Linux, at least). // we can set an address's preferred lifetime (in Linux, at least).
PreferredLifetime(AddressRow.DadState == IpDadStatePreferred ? 0xFFFFFFFF : 0) PreferredLifetime(AddressRow.DadState == IpDadStatePreferred ? 0xFFFFFFFF : 0)
{ {
} }
// operator== is deliberately not comparing PreferredLifetime (DAD state) for equality - only the address portion // operator== is deliberately not comparing PreferredLifetime (DAD state) for equality - only the address portion
bool operator==(const EndpointIpAddress& rhs) const noexcept bool operator==(const EndpointIpAddress& rhs) const noexcept
{ {
return Address == rhs.Address && PrefixLength == rhs.PrefixLength; return Address == rhs.Address && PrefixLength == rhs.PrefixLength;
} }
bool operator<(const EndpointIpAddress& rhs) const noexcept bool operator<(const EndpointIpAddress& rhs) const noexcept
{ {
if (Address == rhs.Address) if (Address == rhs.Address)
{ {
return PrefixLength < rhs.PrefixLength; return PrefixLength < rhs.PrefixLength;
} }
return Address < rhs.Address; return Address < rhs.Address;
} }
void Clear() noexcept void Clear() noexcept
{ {
Address = {}; Address = {};
AddressString.clear(); AddressString.clear();
PrefixLength = 0; PrefixLength = 0;
PrefixOrigin = 0; PrefixOrigin = 0;
SuffixOrigin = 0; SuffixOrigin = 0;
} }
std::wstring GetPrefix() const std::wstring GetPrefix() const
{ {
SOCKADDR_INET address{Address}; SOCKADDR_INET address{Address};
unsigned char* addressPointer{nullptr}; unsigned char* addressPointer{nullptr};
if (Address.si_family == AF_INET) if (Address.si_family == AF_INET)
{ {
addressPointer = reinterpret_cast<unsigned char*>(&address.Ipv4.sin_addr); addressPointer = reinterpret_cast<unsigned char*>(&address.Ipv4.sin_addr);
} }
else if (Address.si_family == AF_INET6) else if (Address.si_family == AF_INET6)
{ {
addressPointer = address.Ipv6.sin6_addr.u.Byte; addressPointer = address.Ipv6.sin6_addr.u.Byte;
} }
else else
{ {
return L""; return L"";
} }
constexpr int c_numBitsPerByte = 8; constexpr int c_numBitsPerByte = 8;
for (int i = 0, currPrefixLength = PrefixLength; i < INET_ADDR_LENGTH(Address.si_family); i++, currPrefixLength -= c_numBitsPerByte) for (int i = 0, currPrefixLength = PrefixLength; i < INET_ADDR_LENGTH(Address.si_family); i++, currPrefixLength -= c_numBitsPerByte)
{ {
if (currPrefixLength < c_numBitsPerByte) if (currPrefixLength < c_numBitsPerByte)
{ {
const int bitShiftAmt = c_numBitsPerByte - std::max(currPrefixLength, 0); const int bitShiftAmt = c_numBitsPerByte - std::max(currPrefixLength, 0);
addressPointer[i] &= (0xFF >> bitShiftAmt) << bitShiftAmt; addressPointer[i] &= (0xFF >> bitShiftAmt) << bitShiftAmt;
} }
} }
const auto addressString = windows::common::string::SockAddrInetToWstring(address); const auto addressString = windows::common::string::SockAddrInetToWstring(address);
WI_ASSERT(!addressString.empty()); WI_ASSERT(!addressString.empty());
if (addressString.empty()) if (addressString.empty())
{ {
// just return an empty string if we have a bad address // just return an empty string if we have a bad address
return addressString; return addressString;
} }
return std::format(L"{}/{}", addressString, PrefixLength); return std::format(L"{}/{}", addressString, PrefixLength);
} }
std::wstring GetIpv4BroadcastMask() const std::wstring GetIpv4BroadcastMask() const
{ {
// start with all bits set, then shift off the prefix // start with all bits set, then shift off the prefix
ULONG prefixMask{0xffffffff}; ULONG prefixMask{0xffffffff};
prefixMask <<= PrefixLength; prefixMask <<= PrefixLength;
prefixMask >>= PrefixLength; prefixMask >>= PrefixLength;
SOCKADDR_INET address{Address}; SOCKADDR_INET address{Address};
// flip to host-order, then apply the mask // flip to host-order, then apply the mask
ULONG hostOrder = ntohl(address.Ipv4.sin_addr.S_un.S_addr); ULONG hostOrder = ntohl(address.Ipv4.sin_addr.S_un.S_addr);
hostOrder |= prefixMask; hostOrder |= prefixMask;
address.Ipv4.sin_addr.S_un.S_addr = htonl(hostOrder); address.Ipv4.sin_addr.S_un.S_addr = htonl(hostOrder);
return windows::common::string::SockAddrInetToWstring(address); return windows::common::string::SockAddrInetToWstring(address);
} }
bool IsPreferred() const noexcept bool IsPreferred() const noexcept
{ {
return PreferredLifetime > 0; return PreferredLifetime > 0;
} }
bool IsLinkLocal() const bool IsLinkLocal() const
{ {
return (Address.si_family == AF_INET && IN4_IS_ADDR_LINKLOCAL(&Address.Ipv4.sin_addr)) || return (Address.si_family == AF_INET && IN4_IS_ADDR_LINKLOCAL(&Address.Ipv4.sin_addr)) ||
(Address.si_family == AF_INET6 && IN6_IS_ADDR_LINKLOCAL(&Address.Ipv6.sin6_addr)); (Address.si_family == AF_INET6 && IN6_IS_ADDR_LINKLOCAL(&Address.Ipv6.sin6_addr));
} }
}; };
struct EndpointRoute struct EndpointRoute
{ {
ADDRESS_FAMILY Family = AF_INET; ADDRESS_FAMILY Family = AF_INET;
IP_ADDRESS_PREFIX DestinationPrefix{}; IP_ADDRESS_PREFIX DestinationPrefix{};
std::wstring DestinationPrefixString{}; std::wstring DestinationPrefixString{};
SOCKADDR_INET NextHop{}; SOCKADDR_INET NextHop{};
std::wstring NextHopString{}; std::wstring NextHopString{};
unsigned char SitePrefixLength = 0; unsigned char SitePrefixLength = 0;
unsigned int Metric = 0; unsigned int Metric = 0;
bool IsAutoGeneratedPrefixRoute = false; bool IsAutoGeneratedPrefixRoute = false;
EndpointRoute() = default; EndpointRoute() = default;
~EndpointRoute() noexcept = default; ~EndpointRoute() noexcept = default;
EndpointRoute(EndpointRoute&&) = default; EndpointRoute(EndpointRoute&&) = default;
EndpointRoute& operator=(EndpointRoute&&) = default; EndpointRoute& operator=(EndpointRoute&&) = default;
EndpointRoute(const EndpointRoute&) = default; EndpointRoute(const EndpointRoute&) = default;
EndpointRoute& operator=(const EndpointRoute&) = default; EndpointRoute& operator=(const EndpointRoute&) = default;
EndpointRoute(const MIB_IPFORWARD_ROW2& RouteRow) : EndpointRoute(const MIB_IPFORWARD_ROW2& RouteRow) :
Family(RouteRow.NextHop.si_family), Family(RouteRow.NextHop.si_family),
DestinationPrefix(RouteRow.DestinationPrefix), DestinationPrefix(RouteRow.DestinationPrefix),
DestinationPrefixString(windows::common::string::SockAddrInetToWstring(RouteRow.DestinationPrefix.Prefix)), DestinationPrefixString(windows::common::string::SockAddrInetToWstring(RouteRow.DestinationPrefix.Prefix)),
NextHop(RouteRow.NextHop), NextHop(RouteRow.NextHop),
NextHopString(windows::common::string::SockAddrInetToWstring(RouteRow.NextHop)), NextHopString(windows::common::string::SockAddrInetToWstring(RouteRow.NextHop)),
SitePrefixLength(RouteRow.SitePrefixLength), SitePrefixLength(RouteRow.SitePrefixLength),
Metric(RouteRow.Metric) Metric(RouteRow.Metric)
{ {
} }
unsigned char GetMaxPrefixLength() const unsigned char GetMaxPrefixLength() const
{ {
return (Family == AF_INET) ? 32 : 128; return (Family == AF_INET) ? 32 : 128;
} }
std::wstring GetFullDestinationPrefix() const std::wstring GetFullDestinationPrefix() const
{ {
return std::format(L"{}/{}", DestinationPrefixString, static_cast<unsigned int>(DestinationPrefix.PrefixLength)); return std::format(L"{}/{}", DestinationPrefixString, static_cast<unsigned int>(DestinationPrefix.PrefixLength));
} }
bool IsNextHopOnlink() const noexcept bool IsNextHopOnlink() const noexcept
{ {
return (Family == AF_INET && NextHopString == LX_INIT_UNSPECIFIED_ADDRESS) || return (Family == AF_INET && NextHopString == LX_INIT_UNSPECIFIED_ADDRESS) ||
(Family == AF_INET6 && NextHopString == LX_INIT_UNSPECIFIED_V6_ADDRESS); (Family == AF_INET6 && NextHopString == LX_INIT_UNSPECIFIED_V6_ADDRESS);
} }
bool IsDefault() const noexcept bool IsDefault() const noexcept
{ {
return (Family == AF_INET && DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS) || return (Family == AF_INET && DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS) ||
(Family == AF_INET6 && DestinationPrefixString == LX_INIT_UNSPECIFIED_V6_ADDRESS); (Family == AF_INET6 && DestinationPrefixString == LX_INIT_UNSPECIFIED_V6_ADDRESS);
} }
bool IsUnicastAddressRoute() const noexcept bool IsUnicastAddressRoute() const noexcept
{ {
return (Family == AF_INET && DestinationPrefix.PrefixLength == 32) || (Family == AF_INET6 && DestinationPrefix.PrefixLength == 128); return (Family == AF_INET && DestinationPrefix.PrefixLength == 32) || (Family == AF_INET6 && DestinationPrefix.PrefixLength == 128);
} }
std::wstring ToString() const std::wstring ToString() const
{ {
return std::format(L"{}=>{} [metric {}]", GetFullDestinationPrefix(), NextHopString, Metric); return std::format(L"{}=>{} [metric {}]", GetFullDestinationPrefix(), NextHopString, Metric);
} }
bool operator==(const EndpointRoute& rhs) const noexcept bool operator==(const EndpointRoute& rhs) const noexcept
{ {
return Family == rhs.Family && DestinationPrefix.PrefixLength == rhs.DestinationPrefix.PrefixLength && return Family == rhs.Family && DestinationPrefix.PrefixLength == rhs.DestinationPrefix.PrefixLength &&
DestinationPrefix.Prefix == rhs.DestinationPrefix.Prefix && NextHop == rhs.NextHop && DestinationPrefix.Prefix == rhs.DestinationPrefix.Prefix && NextHop == rhs.NextHop &&
SitePrefixLength == rhs.SitePrefixLength && Metric == rhs.Metric; SitePrefixLength == rhs.SitePrefixLength && Metric == rhs.Metric;
} }
bool operator!=(const EndpointRoute& other) const bool operator!=(const EndpointRoute& other) const
{ {
return !(*this == other); return !(*this == other);
} }
// sort by family, then by next-hop (on-link routes first), then by prefix, then by metric // sort by family, then by next-hop (on-link routes first), then by prefix, then by metric
bool operator<(const EndpointRoute& rhs) const noexcept bool operator<(const EndpointRoute& rhs) const noexcept
{ {
if (Family == rhs.Family) if (Family == rhs.Family)
{ {
if (NextHop == rhs.NextHop) if (NextHop == rhs.NextHop)
{ {
if (DestinationPrefix.Prefix == rhs.DestinationPrefix.Prefix) if (DestinationPrefix.Prefix == rhs.DestinationPrefix.Prefix)
{ {
if (DestinationPrefix.PrefixLength == rhs.DestinationPrefix.PrefixLength) if (DestinationPrefix.PrefixLength == rhs.DestinationPrefix.PrefixLength)
{ {
if (Metric == rhs.Metric) if (Metric == rhs.Metric)
{ {
return SitePrefixLength < rhs.SitePrefixLength; return SitePrefixLength < rhs.SitePrefixLength;
} }
return Metric < rhs.Metric; return Metric < rhs.Metric;
} }
return DestinationPrefix.PrefixLength < rhs.DestinationPrefix.PrefixLength; return DestinationPrefix.PrefixLength < rhs.DestinationPrefix.PrefixLength;
} }
return DestinationPrefix.Prefix < rhs.DestinationPrefix.Prefix; return DestinationPrefix.Prefix < rhs.DestinationPrefix.Prefix;
} }
return NextHop < rhs.NextHop; return NextHop < rhs.NextHop;
} }
return Family < rhs.Family; return Family < rhs.Family;
} }
}; };
struct NetworkSettings struct NetworkSettings
{ {
NetworkSettings() = default; NetworkSettings() = default;
NetworkSettings( NetworkSettings(
const GUID& interfaceGuid, const GUID& interfaceGuid,
EndpointIpAddress preferredIpAddress, EndpointIpAddress preferredIpAddress,
EndpointRoute gateway, EndpointRoute gateway,
std::wstring macAddress, std::wstring macAddress,
std::wstring deviceName, std::wstring deviceName,
uint32_t interfaceIndex, uint32_t interfaceIndex,
uint32_t mediaType, uint32_t mediaType,
const std::wstring& dnsServerList) : const std::wstring& dnsServerList) :
InterfaceGuid(interfaceGuid), InterfaceGuid(interfaceGuid),
PreferredIpAddress(std::move(preferredIpAddress)), PreferredIpAddress(std::move(preferredIpAddress)),
MacAddress(std::move(macAddress)), MacAddress(std::move(macAddress)),
DeviceName(std::move(deviceName)), DeviceName(std::move(deviceName)),
InterfaceIndex(interfaceIndex), InterfaceIndex(interfaceIndex),
InterfaceType(mediaType) InterfaceType(mediaType)
{ {
Routes.emplace(std::move(gateway)); Routes.emplace(std::move(gateway));
DnsServers = wsl::shared::string::Split(dnsServerList, L','); DnsServers = wsl::shared::string::Split(dnsServerList, L',');
} }
GUID InterfaceGuid{}; GUID InterfaceGuid{};
EndpointIpAddress PreferredIpAddress{}; EndpointIpAddress PreferredIpAddress{};
std::set<EndpointIpAddress> IpAddresses{}; // Does not include PreferredIpAddress. std::set<EndpointIpAddress> IpAddresses{}; // Does not include PreferredIpAddress.
std::set<EndpointRoute> Routes{}; std::set<EndpointRoute> Routes{};
std::vector<std::wstring> DnsServers{}; std::vector<std::wstring> DnsServers{};
std::wstring MacAddress; std::wstring MacAddress;
std::wstring DeviceName; std::wstring DeviceName;
IF_INDEX InterfaceIndex = 0; IF_INDEX InterfaceIndex = 0;
IFTYPE InterfaceType = 0; IFTYPE InterfaceType = 0;
ULONG IPv4InterfaceMtu = 0; ULONG IPv4InterfaceMtu = 0;
ULONG IPv6InterfaceMtu = 0; ULONG IPv6InterfaceMtu = 0;
// some interfaces will only have an IPv4 or IPv6 interface // some interfaces will only have an IPv4 or IPv6 interface
std::optional<ULONG> IPv4InterfaceMetric = 0; std::optional<ULONG> IPv4InterfaceMetric = 0;
std::optional<ULONG> IPv6InterfaceMetric = 0; std::optional<ULONG> IPv6InterfaceMetric = 0;
bool IsHidden = false; bool IsHidden = false;
bool IsConnected = false; bool IsConnected = false;
bool IsMetered = false; bool IsMetered = false;
bool DisableIpv4DefaultRoutes = false; bool DisableIpv4DefaultRoutes = false;
bool DisableIpv6DefaultRoutes = false; bool DisableIpv6DefaultRoutes = false;
bool PendingUpdateToReconnectForMetered = false; bool PendingUpdateToReconnectForMetered = false;
bool PendingIPInterfaceUpdate = false; bool PendingIPInterfaceUpdate = false;
auto operator<=>(const NetworkSettings&) const = default; auto operator<=>(const NetworkSettings&) const = default;
std::wstring GetBestGatewayAddressString() const std::wstring GetBestGatewayAddressString() const
{ {
// Best is currently defined as simply the first IPv4 gateway. // Best is currently defined as simply the first IPv4 gateway.
for (const auto& route : Routes) for (const auto& route : Routes)
{ {
if (route.Family == AF_INET && route.DestinationPrefix.PrefixLength == 0 && route.DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS) if (route.Family == AF_INET && route.DestinationPrefix.PrefixLength == 0 && route.DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS)
{ {
return route.NextHopString; return route.NextHopString;
} }
} }
return {}; return {};
} }
SOCKADDR_INET GetBestGatewayAddress() const SOCKADDR_INET GetBestGatewayAddress() const
{ {
// Best is currently defined as simply the first IPv4 gateway. // Best is currently defined as simply the first IPv4 gateway.
for (const auto& route : Routes) for (const auto& route : Routes)
{ {
if (route.Family == AF_INET && route.DestinationPrefix.PrefixLength == 0 && route.DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS) if (route.Family == AF_INET && route.DestinationPrefix.PrefixLength == 0 && route.DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS)
{ {
return route.NextHop; return route.NextHop;
} }
} }
return {}; return {};
} }
std::wstring IpAddressesString() const std::wstring IpAddressesString() const
{ {
return std::accumulate(std::begin(IpAddresses), std::end(IpAddresses), std::wstring{}, [](const std::wstring& prev, const auto& addr) { return std::accumulate(std::begin(IpAddresses), std::end(IpAddresses), std::wstring{}, [](const std::wstring& prev, const auto& addr) {
return addr.AddressString + (prev.empty() ? L"" : L"," + prev); return addr.AddressString + (prev.empty() ? L"" : L"," + prev);
}); });
} }
std::wstring RoutesString() const std::wstring RoutesString() const
{ {
return std::accumulate(std::begin(Routes), std::end(Routes), std::wstring{}, [](const std::wstring& prev, const EndpointRoute& route) { return std::accumulate(std::begin(Routes), std::end(Routes), std::wstring{}, [](const std::wstring& prev, const EndpointRoute& route) {
return route.ToString() + (prev.empty() ? L"" : L"," + prev); return route.ToString() + (prev.empty() ? L"" : L"," + prev);
}); });
} }
std::wstring DnsServersString() const std::wstring DnsServersString() const
{ {
return wsl::shared::string::Join(DnsServers, L','); return wsl::shared::string::Join(DnsServers, L',');
} }
// will return ULONG_MAX if there's no configured MTU // will return ULONG_MAX if there's no configured MTU
ULONG GetEffectiveMtu() const noexcept ULONG GetEffectiveMtu() const noexcept
{ {
return std::min(IPv4InterfaceMtu > 0 ? IPv4InterfaceMtu : ULONG_MAX, IPv6InterfaceMtu > 0 ? IPv6InterfaceMtu : ULONG_MAX); return std::min(IPv4InterfaceMtu > 0 ? IPv4InterfaceMtu : ULONG_MAX, IPv6InterfaceMtu > 0 ? IPv6InterfaceMtu : ULONG_MAX);
} }
// will return zero if there's no configured metric // will return zero if there's no configured metric
ULONG GetMinimumMetric() const noexcept ULONG GetMinimumMetric() const noexcept
{ {
if (!IPv4InterfaceMetric.has_value() && !IPv6InterfaceMetric.has_value()) if (!IPv4InterfaceMetric.has_value() && !IPv6InterfaceMetric.has_value())
{ {
return 0; return 0;
} }
if (!IPv4InterfaceMetric.has_value()) if (!IPv4InterfaceMetric.has_value())
{ {
return IPv6InterfaceMetric.value(); return IPv6InterfaceMetric.value();
} }
if (!IPv6InterfaceMetric.has_value()) if (!IPv6InterfaceMetric.has_value())
{ {
return IPv4InterfaceMetric.value(); return IPv4InterfaceMetric.value();
} }
return std::min(IPv4InterfaceMetric.value(), IPv6InterfaceMetric.value()); return std::min(IPv4InterfaceMetric.value(), IPv6InterfaceMetric.value());
} }
}; };
std::shared_ptr<NetworkSettings> GetEndpointSettings(const wsl::shared::hns::HNSEndpoint& properties); std::shared_ptr<NetworkSettings> GetEndpointSettings(const wsl::shared::hns::HNSEndpoint& properties);
std::shared_ptr<NetworkSettings> GetHostEndpointSettings(); std::shared_ptr<NetworkSettings> GetHostEndpointSettings();
#define TRACE_NETWORKSETTINGS_OBJECT(settings) \ #define TRACE_NETWORKSETTINGS_OBJECT(settings) \
TraceLoggingValue((settings)->InterfaceGuid, "interfaceGuid"), TraceLoggingValue((settings)->InterfaceIndex, "interfaceIndex"), \ TraceLoggingValue((settings)->InterfaceGuid, "interfaceGuid"), TraceLoggingValue((settings)->InterfaceIndex, "interfaceIndex"), \
TraceLoggingValue((settings)->InterfaceType, "interfaceType"), \ TraceLoggingValue((settings)->InterfaceType, "interfaceType"), \
TraceLoggingValue((settings)->IsConnected, "isConnected"), TraceLoggingValue((settings)->IsMetered, "isMetered"), \ TraceLoggingValue((settings)->IsConnected, "isConnected"), TraceLoggingValue((settings)->IsMetered, "isMetered"), \
TraceLoggingValue((settings)->GetBestGatewayAddressString().c_str(), "bestGatewayAddress"), \ TraceLoggingValue((settings)->GetBestGatewayAddressString().c_str(), "bestGatewayAddress"), \
TraceLoggingValue((settings)->PreferredIpAddress.AddressString.c_str(), "preferredIpAddress"), \ TraceLoggingValue((settings)->PreferredIpAddress.AddressString.c_str(), "preferredIpAddress"), \
TraceLoggingValue((settings)->PreferredIpAddress.PrefixLength, "preferredIpAddressPrefixLength"), \ TraceLoggingValue((settings)->PreferredIpAddress.PrefixLength, "preferredIpAddressPrefixLength"), \
TraceLoggingValue((settings)->IpAddressesString().c_str(), "ipAddresses"), \ TraceLoggingValue((settings)->IpAddressesString().c_str(), "ipAddresses"), \
TraceLoggingValue((settings)->RoutesString().c_str(), "routes"), \ TraceLoggingValue((settings)->RoutesString().c_str(), "routes"), \
TraceLoggingValue((settings)->DnsServersString().c_str(), "dnsServerList"), \ TraceLoggingValue((settings)->DnsServersString().c_str(), "dnsServerList"), \
TraceLoggingValue((settings)->MacAddress.c_str(), "macAddress"), \ TraceLoggingValue((settings)->MacAddress.c_str(), "macAddress"), \
TraceLoggingValue((settings)->IPv4InterfaceMtu, "IPv4InterfaceMtu"), \ TraceLoggingValue((settings)->IPv4InterfaceMtu, "IPv4InterfaceMtu"), \
TraceLoggingValue((settings)->IPv6InterfaceMtu, "IPv6InterfaceMtu"), \ TraceLoggingValue((settings)->IPv6InterfaceMtu, "IPv6InterfaceMtu"), \
TraceLoggingValue((settings)->IPv4InterfaceMetric.value_or(0xffffffff), "IPv4InterfaceMetric"), \ TraceLoggingValue((settings)->IPv4InterfaceMetric.value_or(0xffffffff), "IPv4InterfaceMetric"), \
TraceLoggingValue((settings)->IPv6InterfaceMetric.value_or(0xffffffff), "IPv6InterfaceMetric"), \ TraceLoggingValue((settings)->IPv6InterfaceMetric.value_or(0xffffffff), "IPv6InterfaceMetric"), \
TraceLoggingValue((settings)->PendingIPInterfaceUpdate, "PendingIPInterfaceUpdate"), \ TraceLoggingValue((settings)->PendingIPInterfaceUpdate, "PendingIPInterfaceUpdate"), \
TraceLoggingValue((settings)->PendingUpdateToReconnectForMetered, "PendingUpdateToReconnectForMetered") TraceLoggingValue((settings)->PendingUpdateToReconnectForMetered, "PendingUpdateToReconnectForMetered")
} // namespace wsl::core::networking } // namespace wsl::core::networking

View File

@ -44,6 +44,7 @@ inline auto c_msixPackageFamilyName = L"MicrosoftCorporationII.WindowsSubsystemF
inline auto c_githubUrlOverrideRegistryValue = L"GitHubUrlOverride"; inline auto c_githubUrlOverrideRegistryValue = L"GitHubUrlOverride";
inline auto c_vhdFileExtension = L".vhd"; inline auto c_vhdFileExtension = L".vhd";
inline auto c_vhdxFileExtension = L".vhdx"; inline auto c_vhdxFileExtension = L".vhdx";
inline constexpr auto c_vmOwner = L"WSL"; // TODO-WSLA: Does this apply to WSLA ?
struct GitHubReleaseAsset struct GitHubReleaseAsset
{ {

View File

@ -1,6 +1,5 @@
set(SOURCES set(SOURCES
DistributionRegistration.cpp DistributionRegistration.cpp
LxssSecurity.cpp
LxssUserCallback.cpp LxssUserCallback.cpp
LxssUserSession.cpp LxssUserSession.cpp
LxssUserSessionFactory.cpp LxssUserSessionFactory.cpp
@ -11,10 +10,6 @@ set(SOURCES
PluginManager.cpp PluginManager.cpp
ServiceMain.cpp ServiceMain.cpp
BridgedNetworking.cpp BridgedNetworking.cpp
DeviceHostProxy.cpp
Dmesg.cpp
DnsTunnelingChannel.cpp
GnsChannel.cpp
GnsPortTrackerChannel.cpp GnsPortTrackerChannel.cpp
GnsRpcServer.cpp GnsRpcServer.cpp
GuestTelemetryLogger.cpp GuestTelemetryLogger.cpp
@ -22,21 +17,12 @@ set(SOURCES
LxssConsoleManager.cpp LxssConsoleManager.cpp
LxssCreateProcess.cpp LxssCreateProcess.cpp
MirroredNetworking.cpp MirroredNetworking.cpp
NatNetworking.cpp
RingBuffer.cpp
WslCoreFilesystem.cpp
WslCoreGuestNetworkService.cpp WslCoreGuestNetworkService.cpp
WslCoreHostDnsInfo.cpp
WslCoreInstance.cpp WslCoreInstance.cpp
WslMirroredNetworking.cpp WslMirroredNetworking.cpp
WslCoreNetworkEndpointSettings.cpp
DnsResolver.cpp
WslCoreTcpIpStateTracking.cpp WslCoreTcpIpStateTracking.cpp
WslCoreVm.cpp WslCoreVm.cpp
VirtioNetworking.cpp VirtioNetworking.cpp
WSLAUserSession.cpp
WSLAUserSessionFactory.cpp
WSLAVirtualMachine.cpp
main.rc main.rc
${CMAKE_CURRENT_BINARY_DIR}/../mc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wsleventschema.rc ${CMAKE_CURRENT_BINARY_DIR}/../mc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wsleventschema.rc
application.manifest) application.manifest)
@ -44,7 +30,6 @@ set(SOURCES
set(HEADERS set(HEADERS
../../inc/comservicehelper.h ../../inc/comservicehelper.h
DistributionRegistration.h DistributionRegistration.h
LxssSecurity.h
LxssUserCallback.h LxssUserCallback.h
LxssUserSession.h LxssUserSession.h
LxssUserSessionFactory.h LxssUserSessionFactory.h
@ -53,35 +38,20 @@ set(HEADERS
PluginManager.h PluginManager.h
LxssInstance.h LxssInstance.h
BridgedNetworking.h BridgedNetworking.h
DeviceHostProxy.h
Dmesg.h
DnsTunnelingChannel.h
GnsChannel.h
GnsPortTrackerChannel.h GnsPortTrackerChannel.h
GnsRpcServer.h GnsRpcServer.h
GuestTelemetryLogger.h GuestTelemetryLogger.h
INetworkingEngine.h
IMirroredNetworkManager.h IMirroredNetworkManager.h
Lifetime.h Lifetime.h
LxssConsoleManager.h LxssConsoleManager.h
LxssCreateProcess.h LxssCreateProcess.h
MirroredNetworking.h MirroredNetworking.h
NatNetworking.h
RingBuffer.h
WslCoreFilesystem.h
WslCoreGuestNetworkService.h WslCoreGuestNetworkService.h
WslCoreHostDnsInfo.h
WslCoreInstance.h WslCoreInstance.h
WslCoreMessageQueue.h
WslMirroredNetworking.h WslMirroredNetworking.h
WslCoreNetworkEndpoint.h WslCoreNetworkEndpoint.h
WslCoreNetworkEndpointSettings.h
DnsResolver.h
WslCoreTcpIpStateTracking.h WslCoreTcpIpStateTracking.h
WslCoreVm.h WslCoreVm.h)
WSLAUserSession.h
WSLAUserSessionFactory.h
WSLAVirtualMachine.h)
include_directories(${CMAKE_SOURCE_DIR}/src/windows/wslaclient) include_directories(${CMAKE_SOURCE_DIR}/src/windows/wslaclient)

View File

@ -256,7 +256,7 @@ void MirroredNetworking::Initialize()
m_config.FirewallConfig.Enabled(), m_config.IgnoredPorts, m_runtimeId, m_gnsRpcServer->GetServerUuid(), s_GuestNetworkServiceCallback, this); m_config.FirewallConfig.Enabled(), m_config.IgnoredPorts, m_runtimeId, m_gnsRpcServer->GetServerUuid(), s_GuestNetworkServiceCallback, this);
m_ephemeralPortRange = m_guestNetworkService.AllocateEphemeralPortRange(); m_ephemeralPortRange = m_guestNetworkService.AllocateEphemeralPortRange();
networking::ConfigureHyperVFirewall(m_config.FirewallConfig, c_vmOwner); networking::ConfigureHyperVFirewall(m_config.FirewallConfig, wsl::windows::common::wslutil::c_vmOwner);
// must keep all m_networkManager interactions (including) creation queued // must keep all m_networkManager interactions (including) creation queued
// also must queue GNS callbacks to keep them serialized // also must queue GNS callbacks to keep them serialized

View File

@ -31,7 +31,6 @@ wil::unique_event g_networkingReady{wil::EventOptions::ManualReset};
// Declare the LxssUserSession COM class. // Declare the LxssUserSession COM class.
CoCreatableClassWrlCreatorMapInclude(LxssUserSession); CoCreatableClassWrlCreatorMapInclude(LxssUserSession);
CoCreatableClassWrlCreatorMapInclude(WSLAUserSession);
struct WslServiceSecurityPolicy struct WslServiceSecurityPolicy
{ {
@ -241,7 +240,6 @@ void WslService::ServiceStopped()
// Terminate all user sessions. // Terminate all user sessions.
ClearSessionsAndBlockNewInstances(); ClearSessionsAndBlockNewInstances();
wsl::windows::service::wsla::ClearWslaSessionsAndBlockNewInstances();
// Disconnect from the LxCore driver. // Disconnect from the LxCore driver.
if (g_lxcoreInitialized) if (g_lxcoreInitialized)

View File

@ -1468,7 +1468,7 @@ void WslCoreVm::FreeLun(_In_ ULONG lun)
std::wstring WslCoreVm::GenerateConfigJson() std::wstring WslCoreVm::GenerateConfigJson()
{ {
hcs::ComputeSystem systemSettings{}; hcs::ComputeSystem systemSettings{};
systemSettings.Owner = c_vmOwner; systemSettings.Owner = wsl::windows::common::wslutil::c_vmOwner;
systemSettings.ShouldTerminateOnLastHandleClosed = true; systemSettings.ShouldTerminateOnLastHandleClosed = true;
systemSettings.SchemaVersion.Major = 2; systemSettings.SchemaVersion.Major = 2;
systemSettings.SchemaVersion.Minor = 3; systemSettings.SchemaVersion.Minor = 3;
@ -1575,7 +1575,7 @@ std::wstring WslCoreVm::GenerateConfigJson()
// Set the vmmem suffix which will change the process name in task manager. // Set the vmmem suffix which will change the process name in task manager.
if (helpers::IsVmemmSuffixSupported()) if (helpers::IsVmemmSuffixSupported())
{ {
vmSettings.ComputeTopology.Memory.HostingProcessNameSuffix = c_vmOwner; vmSettings.ComputeTopology.Memory.HostingProcessNameSuffix = wsl::windows::common::wslutil::c_vmOwner;
} }
// If nested virtualization was requested, ensure the platform supports it. // If nested virtualization was requested, ensure the platform supports it.

View File

@ -39,8 +39,6 @@ inline constexpr auto c_optionsValueName = L"Options";
inline constexpr auto c_typeValueName = L"Type"; inline constexpr auto c_typeValueName = L"Type";
inline constexpr auto c_mountNameValueName = L"Name"; inline constexpr auto c_mountNameValueName = L"Name";
inline constexpr auto c_vmOwner = L"WSL";
static constexpr GUID c_virtiofsAdminClassId = {0x7e6ad219, 0xd1b3, 0x42d5, {0xb8, 0xee, 0xd9, 0x63, 0x24, 0xe6, 0x4f, 0xf6}}; static constexpr GUID c_virtiofsAdminClassId = {0x7e6ad219, 0xd1b3, 0x42d5, {0xb8, 0xee, 0xd9, 0x63, 0x24, 0xe6, 0x4f, 0xf6}};
// {60285AE6-AAF3-4456-B444-A6C2D0DEDA38} // {60285AE6-AAF3-4456-B444-A6C2D0DEDA38}

View File

@ -396,97 +396,3 @@ cpp_quote("#define WSL_E_DISK_CORRUPTED MAKE_HRESULT(SEVERITY_ERROR, FACILITY_IT
cpp_quote("#define WSL_E_DISTRIBUTION_NAME_NEEDED MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSL_E_BASE + 0x30) /* 0x80040330 */") cpp_quote("#define WSL_E_DISTRIBUTION_NAME_NEEDED MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSL_E_BASE + 0x30) /* 0x80040330 */")
cpp_quote("#define WSL_E_INVALID_JSON MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSL_E_BASE + 0x31) /* 0x80040331 */") cpp_quote("#define WSL_E_INVALID_JSON MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSL_E_BASE + 0x31) /* 0x80040331 */")
cpp_quote("#define WSL_E_VM_CRASHED MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSL_E_BASE + 0x32) /* 0x80040332 */") cpp_quote("#define WSL_E_VM_CRASHED MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSL_E_BASE + 0x32) /* 0x80040332 */")
typedef
struct _WSL_VERSION {
ULONG Major;
ULONG Minor;
ULONG Revision;
} WSL_VERSION;
typedef [system_handle(sh_socket)] HANDLE HVSOCKET_HANDLE;
typedef
struct _WSLA_CREATE_PROCESS_OPTIONS {
[string] LPCSTR Executable;
ULONG CommandLineCount;
[unique, size_is(CommandLineCount)] LPCSTR* CommandLine;
ULONG EnvironmentCount;
[unique, size_is(EnvironmentCount)] LPCSTR* Environment;
[unique] LPCSTR CurrentDirectory;
} WSLA_CREATE_PROCESS_OPTIONS;
typedef struct _WSLA_PROCESS_FD
{
LONG Fd;
int Type;
[string, unique] LPCSTR Path;
} WSLA_PROCESS_FD;
typedef
struct _WSLA_CREATE_PROCESS_RESULT {
int Errno;
int Pid;
} WSLA_CREATE_PROCESS_RESULT;
[
uuid(7BC4E198-6531-4FA6-ADE2-5EF3D2A04DFE),
pointer_default(unique),
object
]
interface ITerminationCallback : IUnknown
{
HRESULT OnTermination(ULONG Reason, LPCWSTR Details);
};
[
uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8761),
pointer_default(unique),
object
]
interface IWSLAVirtualMachine : IUnknown
{
HRESULT AttachDisk([in] LPCWSTR Path, [in] BOOL ReadOnly, [out] LPSTR* Device, [out] ULONG* Lun);
HRESULT Mount([in, unique] LPCSTR Source, [in] LPCSTR Target, [in] LPCSTR Type, [in] LPCSTR Options, [in] ULONG Flags);
HRESULT CreateLinuxProcess([in] const WSLA_CREATE_PROCESS_OPTIONS* Options, [in] ULONG FdCount, [in, unique, size_is(FdCount)] WSLA_PROCESS_FD* Fds, [out, size_is(FdCount)] ULONG* Handles, [out] WSLA_CREATE_PROCESS_RESULT* Result);
HRESULT WaitPid([in] LONG Pid, [in] ULONGLONG TimeoutMs, [out] ULONG* State, [out] int* Code);
HRESULT Signal([in] LONG Pid, [in] int Signal);
HRESULT Shutdown([in] ULONGLONG TimeoutMs);
HRESULT RegisterCallback([in] ITerminationCallback* terminationCallback);
HRESULT GetDebugShellPipe([out] LPWSTR* pipePath);
HRESULT MapPort([in] int Family, [in] short WindowsPort, [in] short LinuxPort, [in] BOOL Remove);
HRESULT Unmount([in] LPCSTR Path);
HRESULT DetachDisk([in] ULONG Lun);
HRESULT MountWindowsFolder([in] LPCWSTR WindowsPath, [in] LPCSTR LinuxPath, [in] BOOL ReadOnly);
HRESULT UnmountWindowsFolder([in] LPCSTR LinuxPath);
HRESULT MountGpuLibraries([in] LPCSTR LibrariesMountPoint, [in] LPCSTR DriversMountpoint, [in] DWORD Flags);
}
typedef
struct _VIRTUAL_MACHINE_SETTINGS {
LPCWSTR DisplayName;
ULONGLONG MemoryMb;
ULONG CpuCount;
ULONG BootTimeoutMs;
ULONG DmesgOutput;
ULONG NetworkingMode;
BOOL EnableDnsTunneling;
BOOL EnableDebugShell;
BOOL EnableEarlyBootDmesg;
BOOL EnableGPU;
} VIRTUAL_MACHINE_SETTINGS;
[
uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8760),
pointer_default(unique),
object
]
interface IWSLAUserSession : IUnknown
{
HRESULT GetVersion([out] WSL_VERSION* Error);
HRESULT CreateVirtualMachine([in] const VIRTUAL_MACHINE_SETTINGS* Settings, [out]IWSLAVirtualMachine** VirtualMachine);
}

View File

@ -3,6 +3,7 @@ set(HEADERS WSLAApi.h)
add_library(wslaclient SHARED ${SOURCES} ${HEADERS} wslaclient.def) add_library(wslaclient SHARED ${SOURCES} ${HEADERS} wslaclient.def)
set_target_properties(wslaclient PROPERTIES EXCLUDE_FROM_ALL FALSE) set_target_properties(wslaclient PROPERTIES EXCLUDE_FROM_ALL FALSE)
add_dependencies(wslaclient wslaserviceidl)
target_link_libraries(wslaclient ${COMMON_LINK_LIBRARIES} legacy_stdio_definitions common) target_link_libraries(wslaclient ${COMMON_LINK_LIBRARIES} legacy_stdio_definitions common)
target_precompile_headers(wslaclient REUSE_FROM common) target_precompile_headers(wslaclient REUSE_FROM common)
set_target_properties(wslaclient PROPERTIES FOLDER windows) set_target_properties(wslaclient PROPERTIES FOLDER windows)

View File

@ -13,7 +13,7 @@ Abstract:
--*/ --*/
#include "precomp.h" #include "precomp.h"
#include "wslservice.h" #include "wslaservice.h"
#include "WSLAApi.h" #include "WSLAApi.h"
#include "wslrelay.h" #include "wslrelay.h"
#include "wslInstall.h" #include "wslInstall.h"

View File

@ -0,0 +1,6 @@
add_compile_definitions("PROXY_CLSID_IS={0x4EA0C6DD,0xE9FF,0x48E7,{0x99,0x4e,0x13,0xa3,0x1d,0x10,0xdc,0x61}}")
add_compile_definitions("REGISTER_PROXY_DLL")
add_subdirectory(exe)
add_subdirectory(inc)
add_subdirectory(stub)

View File

@ -0,0 +1,34 @@
set(SOURCES
application.manifest
main.rc
ServiceMain.cpp
WSLAUserSession.cpp
WSLAUserSessionFactory.cpp
WSLAVirtualMachine.cpp
)
set(HEADERS
WSLAUserSession.h
WSLAUserSessionFactory.h
WSLAVirtualMachine.h)
include_directories(${CMAKE_SOURCE_DIR}/src/windows/wslaclient)
add_executable(wslaservice ${SOURCES} ${HEADERS})
add_dependencies(wslaservice wslaserviceidl)
add_compile_definitions(__WRL_CLASSIC_COM__)
add_compile_definitions(__WRL_DISABLE_STATIC_INITIALIZE__)
add_compile_definitions(USE_COM_CONTEXT_DEF=1)
set_target_properties(wslaservice PROPERTIES LINK_FLAGS "/merge:minATL=.rdata /include:__minATLObjMap_WSLAUserSession_COM")
target_link_libraries(wslaservice
${COMMON_LINK_LIBRARIES}
computecore
common
configfile
legacy_stdio_definitions
VirtDisk.lib
Winhttp.lib
Synchronization.lib)
target_precompile_headers(wslaservice REUSE_FROM common)
set_target_properties(wslaservice PROPERTIES FOLDER windows)

View File

@ -0,0 +1,126 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
ServiceMain.cpp
Abstract:
This file contains the entrypoint for the Lxss Manager service.
--*/
#include "precomp.h"
#include "comservicehelper.h"
#include "LxssSecurity.h"
#include "WslCoreFilesystem.h"
#include "WSLAUserSessionFactory.h"
#include <ctime>
using namespace wsl::windows::common::registry;
using namespace wsl::windows::common::string;
using namespace wsl::windows::common::wslutil;
using namespace wsl::windows::policies;
wil::unique_event g_networkingReady{wil::EventOptions::ManualReset};
// Declare the WSLAUserSession COM class.
CoCreatableClassWrlCreatorMapInclude(WSLAUserSession);
struct WslaServiceSecurityPolicy
{
static LPCWSTR GetSDDLText()
{
// COM Access and Launch permissions allowed for authenticated user, principal self, and system.
// 0xB = (COM_RIGHTS_EXECUTE | COM_RIGHTS_EXECUTE_LOCAL | COM_RIGHTS_ACTIVATE_LOCAL)
// N.B. This should be kept in sync with the security descriptors in the appxmanifest and package.wix.
return L"O:BAG:BAD:(A;;0xB;;;AU)(A;;0xB;;;PS)(A;;0xB;;;SY)";
}
};
class WslaService : public Windows::Internal::Service<WslaService, Windows::Internal::ContinueRunningWithNoObjects, WslaServiceSecurityPolicy>
{
public:
static wchar_t* GetName()
{
return const_cast<LPWSTR>(L"WSLAService");
}
static void OnSessionChanged(DWORD eventType, DWORD sessionId);
HRESULT OnServiceStarting();
HRESULT ServiceStarted();
void ServiceStopped();
private:
wil::unique_couninitialize_call m_coInit{false};
};
HRESULT WslaService::OnServiceStarting()
try
{
ConfigureCrt();
// Enable contextualized errors
wsl::windows::common::EnableContextualizedErrors(true);
// Initialize telemetry.
// TODO-WSLA: Create a dedicated WSLA provider
WslTraceLoggingInitialize(WslServiceTelemetryProvider, !wsl::shared::OfficialBuild);
WSL_LOG("Service starting", TraceLoggingLevel(WINEVENT_LEVEL_INFO));
// Don't kill the process on unknown C++ exceptions.
wil::g_fResultFailFastUnknownExceptions = false;
// wsl::windows::common::security::ApplyProcessMitigationPolicies();
// Initialize Winsock.
WSADATA Data;
THROW_IF_WIN32_ERROR(WSAStartup(MAKEWORD(2, 2), &Data));
return S_OK;
}
CATCH_RETURN()
HRESULT WslaService::ServiceStarted()
{
m_coInit = wil::CoInitializeEx(COINIT_MULTITHREADED);
return S_OK;
}
void WslaService::OnSessionChanged(DWORD eventType, DWORD sessionId)
{
if (eventType == WTS_SESSION_LOGOFF)
{
// TODO-WSLA: Implement for WSLA
// TerminateSession(sessionId);
}
}
void WslaService::ServiceStopped()
{
WSL_LOG("Service stopping", TraceLoggingLevel(WINEVENT_LEVEL_INFO));
// Terminate all user sessions.
wsl::windows::service::wsla::ClearWslaSessionsAndBlockNewInstances();
// There is a potential deadlock if CoUninitialize() is called before the LanguageChangeNotifyThread
// isn't done initializing. Clearing the COM objects before calling CoUninitialize() works around the issue.
winrt::clear_factory_cache();
// Tear down telemetry.
WslTraceLoggingUninitialize();
// uninitialize COM. This must be done here because this call can cause cleanups that will be fail
// if the CRT is shutting down.
m_coInit.reset();
}
int __cdecl wmain()
{
WslaService::ProcessMain();
return 0;
}

View File

@ -0,0 +1,87 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
WSLAUserSession.cpp
Abstract:
TODO
--*/
#include "WSLAUserSession.h"
using wsl::windows::service::wsla::WSLAUserSessionImpl;
WSLAUserSessionImpl::WSLAUserSessionImpl(HANDLE Token, wil::unique_tokeninfo_ptr<TOKEN_USER>&& TokenInfo) :
m_tokenInfo(std::move(TokenInfo))
{
}
WSLAUserSessionImpl::~WSLAUserSessionImpl()
{
// Manually signal the VM termination events. This prevents being stuck on an API call that holds the VM lock.
{
std::lock_guard lock(m_virtualMachinesLock);
for (auto* e : m_virtualMachines)
{
e->OnSessionTerminating();
}
}
}
void WSLAUserSessionImpl::OnVmTerminated(WSLAVirtualMachine* machine)
{
std::lock_guard lock(m_virtualMachinesLock);
auto pred = [machine](const auto* e) { return machine == e; };
// Remove any stale VM reference.
m_virtualMachines.erase(std::remove_if(m_virtualMachines.begin(), m_virtualMachines.end(), pred), m_virtualMachines.end());
}
HRESULT WSLAUserSessionImpl::CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine)
{
auto vm = wil::MakeOrThrow<WSLAVirtualMachine>(*Settings, GetUserSid(), this);
{
std::lock_guard lock(m_virtualMachinesLock);
m_virtualMachines.emplace_back(vm.Get());
}
vm->Start();
THROW_IF_FAILED(vm.CopyTo(__uuidof(IWSLAVirtualMachine), (void**)VirtualMachine));
return S_OK;
}
PSID WSLAUserSessionImpl::GetUserSid() const
{
return m_tokenInfo->User.Sid;
}
wsl::windows::service::wsla::WSLAUserSession::WSLAUserSession(std::weak_ptr<WSLAUserSessionImpl>&& Session) :
m_session(std::move(Session))
{
}
HRESULT wsl::windows::service::wsla::WSLAUserSession::GetVersion(_Out_ WSL_VERSION* Version)
{
Version->Major = WSL_PACKAGE_VERSION_MAJOR;
Version->Minor = WSL_PACKAGE_VERSION_MINOR;
Version->Revision = WSL_PACKAGE_VERSION_REVISION;
return S_OK;
}
HRESULT wsl::windows::service::wsla::WSLAUserSession::CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine)
try
{
auto session = m_session.lock();
RETURN_HR_IF(RPC_E_DISCONNECTED, !session);
return session->CreateVirtualMachine(Settings, VirtualMachine);
}
CATCH_RETURN();

View File

@ -0,0 +1,56 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
WSLAUserSession.h
Abstract:
TODO
--*/
#pragma once
#include "WSLAVirtualMachine.h"
namespace wsl::windows::service::wsla {
class WSLAUserSessionImpl
{
public:
WSLAUserSessionImpl(HANDLE Token, wil::unique_tokeninfo_ptr<TOKEN_USER>&& TokenInfo);
WSLAUserSessionImpl(WSLAUserSessionImpl&&) = default;
WSLAUserSessionImpl& operator=(WSLAUserSessionImpl&&) = default;
~WSLAUserSessionImpl();
PSID GetUserSid() const;
HRESULT CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine);
void OnVmTerminated(WSLAVirtualMachine* machine);
private:
wil::unique_tokeninfo_ptr<TOKEN_USER> m_tokenInfo;
std::recursive_mutex m_virtualMachinesLock;
std::vector<WSLAVirtualMachine*> m_virtualMachines;
};
class DECLSPEC_UUID("a9b7a1b9-0671-405c-95f1-e0612cb4ce8f") WSLAUserSession
: public Microsoft::WRL::RuntimeClass<Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, IWSLAUserSession, IFastRundown>
{
public:
WSLAUserSession(std::weak_ptr<WSLAUserSessionImpl>&& Session);
WSLAUserSession(const WSLAUserSession&) = delete;
WSLAUserSession& operator=(const WSLAUserSession&) = delete;
IFACEMETHOD(GetVersion)(_Out_ WSL_VERSION* Version) override;
IFACEMETHOD(CreateVirtualMachine)(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine) override;
private:
std::weak_ptr<WSLAUserSessionImpl> m_session;
};
} // namespace wsl::windows::service::wsla

View File

@ -0,0 +1,82 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
WSLAUserSessionFactory.cpp
Abstract:
TODO
--*/
#include "precomp.h"
#include "WSLAUserSessionFactory.h"
#include "WSLAUserSession.h"
using wsl::windows::service::wsla::WSLAUserSessionFactory;
using wsl::windows::service::wsla::WSLAUserSessionImpl;
CoCreatableClassWithFactory(WSLAUserSession, WSLAUserSessionFactory);
static std::mutex g_mutex;
static std::optional<std::vector<std::shared_ptr<WSLAUserSessionImpl>>> g_sessions =
std::make_optional<std::vector<std::shared_ptr<WSLAUserSessionImpl>>>();
HRESULT WSLAUserSessionFactory::CreateInstance(_In_ IUnknown* pUnkOuter, _In_ REFIID riid, _Out_ void** ppCreated)
{
RETURN_HR_IF_NULL(E_POINTER, ppCreated);
*ppCreated = nullptr;
RETURN_HR_IF(CLASS_E_NOAGGREGATION, pUnkOuter != nullptr);
WSL_LOG("WSLAUserSessionFactory", TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE));
try
{
const wil::unique_handle userToken = wsl::windows::common::security::GetUserToken(TokenImpersonation);
// Get the session ID and SID of the client process.
DWORD sessionId{};
DWORD length = 0;
THROW_IF_WIN32_BOOL_FALSE(::GetTokenInformation(userToken.get(), TokenSessionId, &sessionId, sizeof(sessionId), &length));
auto tokenInfo = wil::get_token_information<TOKEN_USER>(userToken.get());
std::lock_guard lock{g_mutex};
THROW_HR_IF(CO_E_SERVER_STOPPING, !g_sessions.has_value());
auto session = std::find_if(g_sessions->begin(), g_sessions->end(), [&tokenInfo](auto it) {
return EqualSid(it->GetUserSid(), &tokenInfo->User.Sid);
});
if (session == g_sessions->end())
{
session = g_sessions->insert(g_sessions->end(), std::make_shared<WSLAUserSessionImpl>(userToken.get(), std::move(tokenInfo)));
}
auto comInstance = wil::MakeOrThrow<WSLAUserSession>(std::weak_ptr<WSLAUserSessionImpl>(*session));
THROW_IF_FAILED(comInstance.CopyTo(riid, ppCreated));
}
catch (...)
{
const auto result = wil::ResultFromCaughtException();
// Note: S_FALSE will cause COM to retry if the service is stopping.
return result == CO_E_SERVER_STOPPING ? S_FALSE : result;
}
WSL_LOG("WSLAUserSessionFactory", TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE));
return S_OK;
}
void wsl::windows::service::wsla::ClearWslaSessionsAndBlockNewInstances()
{
std::lock_guard lock{g_mutex};
g_sessions.reset();
}

View File

@ -0,0 +1,28 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
WSLAUserSessionFactory.h
Abstract:
TODO
--*/
#pragma once
#include <wil/resource.h>
namespace wsl::windows::service::wsla {
class WSLAUserSessionFactory : public Microsoft::WRL::ClassFactory<>
{
public:
WSLAUserSessionFactory() = default;
STDMETHODIMP CreateInstance(_In_ IUnknown* pUnkOuter, _In_ REFIID riid, _Out_ void** ppCreated) override;
};
void ClearWslaSessionsAndBlockNewInstances();
} // namespace wsl::windows::service::wsla

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,106 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
WSLAVirtualMachine.h
Abstract:
TODO
--*/
#pragma once
#include "wslaservice.h"
#include "INetworkingEngine.h"
#include "hcs.hpp"
#include "Dmesg.h"
#include "WSLAApi.h"
namespace wsl::windows::service::wsla {
class WSLAUserSessionImpl;
class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") WSLAVirtualMachine
: public Microsoft::WRL::RuntimeClass<Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, IWSLAVirtualMachine, IFastRundown>
{
public:
WSLAVirtualMachine(const VIRTUAL_MACHINE_SETTINGS& Settings, PSID Sid, WSLAUserSessionImpl* UserSession);
~WSLAVirtualMachine();
void Start();
void OnSessionTerminating();
IFACEMETHOD(AttachDisk(_In_ PCWSTR Path, _In_ BOOL ReadOnly, _Out_ LPSTR* Device, _Out_ ULONG* Lun)) override;
IFACEMETHOD(Mount(_In_ LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags)) override;
IFACEMETHOD(CreateLinuxProcess(
_In_ const WSLA_CREATE_PROCESS_OPTIONS* Options, _In_ ULONG FdCount, _In_ WSLA_PROCESS_FD* Fd, _Out_ ULONG* Handles, _Out_ WSLA_CREATE_PROCESS_RESULT* Result)) override;
IFACEMETHOD(WaitPid(_In_ LONG Pid, _In_ ULONGLONG TimeoutMs, _Out_ ULONG* State, _Out_ int* Code)) override;
IFACEMETHOD(Signal(_In_ LONG Pid, _In_ int Signal)) override;
IFACEMETHOD(Shutdown(ULONGLONG _In_ TimeoutMs)) override;
IFACEMETHOD(RegisterCallback(_In_ ITerminationCallback* callback)) override;
IFACEMETHOD(GetDebugShellPipe(_Out_ LPWSTR* pipePath)) override;
IFACEMETHOD(MapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ BOOL Remove)) override;
IFACEMETHOD(Unmount(_In_ const char* Path)) override;
IFACEMETHOD(DetachDisk(_In_ ULONG Lun)) override;
IFACEMETHOD(MountWindowsFolder(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly)) override;
IFACEMETHOD(UnmountWindowsFolder(_In_ LPCSTR LinuxPath)) override;
IFACEMETHOD(MountGpuLibraries(_In_ LPCSTR LibrariesMountPoint, _In_ LPCSTR DriversMountpoint, _In_ DWORD Flags)) override;
private:
static int32_t MountImpl(wsl::shared::SocketChannel& Channel, LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags);
static void CALLBACK s_OnExit(_In_ HCS_EVENT* Event, _In_opt_ void* Context);
static bool ParseTtyInformation(const WSLA_PROCESS_FD* Fds, ULONG FdCount, const WSLA_PROCESS_FD** TtyInput, const WSLA_PROCESS_FD** TtyOutput);
void ConfigureNetworking();
void OnExit(_In_ const HCS_EVENT* Event);
std::tuple<int32_t, int32_t, wsl::shared::SocketChannel> Fork(enum WSLA_FORK::ForkType Type);
std::tuple<int32_t, int32_t, wsl::shared::SocketChannel> Fork(wsl::shared::SocketChannel& Channel, enum WSLA_FORK::ForkType Type);
int32_t ExpectClosedChannelOrError(wsl::shared::SocketChannel& Channel);
wil::unique_socket ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd);
void OpenLinuxFile(wsl::shared::SocketChannel& Channel, const char* Path, uint32_t Flags, int32_t Fd);
void LaunchPortRelay();
std::vector<wil::unique_socket> CreateLinuxProcessImpl(
_In_ const WSLA_CREATE_PROCESS_OPTIONS* Options, _In_ ULONG FdCount, _In_ WSLA_PROCESS_FD* Fd, _Out_ WSLA_CREATE_PROCESS_RESULT* Result);
HRESULT MountWindowsFolderImpl(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly, _In_ WslMountFlags Flags);
struct AttachedDisk
{
std::filesystem::path Path;
std::string Device;
bool AccessGranted = false;
};
VIRTUAL_MACHINE_SETTINGS m_settings;
GUID m_vmId{};
std::wstring m_vmIdString;
wsl::windows::common::helpers::WindowsVersion m_windowsVersion = wsl::windows::common::helpers::GetWindowsVersion();
int m_coldDiscardShiftSize{};
bool m_running = false;
PSID m_userSid{};
std::wstring m_debugShellPipe;
wsl::windows::common::hcs::unique_hcs_system m_computeSystem;
std::shared_ptr<DmesgCollector> m_dmesgCollector;
wil::unique_event m_vmExitEvent{wil::EventOptions::ManualReset};
wil::unique_event m_vmTerminatingEvent{wil::EventOptions::ManualReset};
wil::com_ptr<ITerminationCallback> m_terminationCallback;
std::unique_ptr<wsl::core::INetworkingEngine> m_networkEngine;
wsl::shared::SocketChannel m_initChannel;
wil::unique_handle m_portRelayChannelRead;
wil::unique_handle m_portRelayChannelWrite;
std::map<ULONG, AttachedDisk> m_attachedDisks;
std::map<std::string, std::wstring> m_plan9Mounts;
std::recursive_mutex m_lock;
std::mutex m_portRelaylock;
WSLAUserSessionImpl* m_userSession;
};
} // namespace wsl::windows::service::wsla

View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="utf-8" standalone="yes"?>
<assembly xmlns="urn:schemas-microsoft-com:asm.v1" manifestVersion="1.0" xmlns:asmv3="urn:schemas-microsoft-com:asm.v3" >
<application xmlns="urn:schemas-microsoft-com:asm.v3">
<windowsSettings xmlns:ws2="http://schemas.microsoft.com/SMI/2016/WindowsSettings">
<ws2:longPathAware>true</ws2:longPathAware>
</windowsSettings>
</application>
</assembly>

View File

@ -0,0 +1,28 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
main.rc
Abstract:
This file contains resources for wslaservice.
--*/
#include <windows.h>
#include "resource.h"
#include "wslversioninfo.h"
#define VER_INTERNALNAME_STR "wslaservice.exe"
#define VER_ORIGINALFILENAME_STR "wslaservice.exe"
#define VER_FILETYPE VFT_APP
#define VER_FILESUBTYPE VFT2_UNKNOWN
#define VER_FILEDESCRIPTION_STR "Windows Subsystem for Linux for Apps Service"
ID_ICON ICON PRELOAD DISCARDABLE "..\..\..\Images\wsl.ico"
#include <common.ver>

View File

@ -0,0 +1,15 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
resource.h
Abstract:
This file contains resource declarations for wslaservice.exe
--*/
#define ID_ICON 1

View File

@ -0,0 +1,2 @@
add_idl(wslaserviceidl "wslaservice.idl" "")
set_target_properties(wslaserviceidl PROPERTIES FOLDER windows)

View File

@ -0,0 +1,114 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Module Name:
wslaservice.idl
Abstract:
This file contains the COM object definitions used to talk with the WSLa
service "WslaService"
--*/
import "unknwn.idl";
import "wtypes.idl";
cpp_quote("#ifdef __cplusplus")
cpp_quote("class DECLSPEC_UUID(\"a9b7a1b9-0671-405c-95f1-e0612cb4ce8f\") WSLAUserSession;")
cpp_quote("#endif")
typedef
struct _WSL_VERSION {
ULONG Major;
ULONG Minor;
ULONG Revision;
} WSL_VERSION;
typedef [system_handle(sh_socket)] HANDLE HVSOCKET_HANDLE;
typedef
struct _WSLA_CREATE_PROCESS_OPTIONS {
[string] LPCSTR Executable;
ULONG CommandLineCount;
[unique, size_is(CommandLineCount)] LPCSTR* CommandLine;
ULONG EnvironmentCount;
[unique, size_is(EnvironmentCount)] LPCSTR* Environment;
[unique] LPCSTR CurrentDirectory;
} WSLA_CREATE_PROCESS_OPTIONS;
typedef struct _WSLA_PROCESS_FD
{
LONG Fd;
int Type;
[string, unique] LPCSTR Path;
} WSLA_PROCESS_FD;
typedef
struct _WSLA_CREATE_PROCESS_RESULT {
int Errno;
int Pid;
} WSLA_CREATE_PROCESS_RESULT;
[
uuid(7BC4E198-6531-4FA6-ADE2-5EF3D2A04DFE),
pointer_default(unique),
object
]
interface ITerminationCallback : IUnknown
{
HRESULT OnTermination(ULONG Reason, LPCWSTR Details);
};
[
uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8761),
pointer_default(unique),
object
]
interface IWSLAVirtualMachine : IUnknown
{
HRESULT AttachDisk([in] LPCWSTR Path, [in] BOOL ReadOnly, [out] LPSTR* Device, [out] ULONG* Lun);
HRESULT Mount([in, unique] LPCSTR Source, [in] LPCSTR Target, [in] LPCSTR Type, [in] LPCSTR Options, [in] ULONG Flags);
HRESULT CreateLinuxProcess([in] const WSLA_CREATE_PROCESS_OPTIONS* Options, [in] ULONG FdCount, [in, unique, size_is(FdCount)] WSLA_PROCESS_FD* Fds, [out, size_is(FdCount)] ULONG* Handles, [out] WSLA_CREATE_PROCESS_RESULT* Result);
HRESULT WaitPid([in] LONG Pid, [in] ULONGLONG TimeoutMs, [out] ULONG* State, [out] int* Code);
HRESULT Signal([in] LONG Pid, [in] int Signal);
HRESULT Shutdown([in] ULONGLONG TimeoutMs);
HRESULT RegisterCallback([in] ITerminationCallback* terminationCallback);
HRESULT GetDebugShellPipe([out] LPWSTR* pipePath);
HRESULT MapPort([in] int Family, [in] short WindowsPort, [in] short LinuxPort, [in] BOOL Remove);
HRESULT Unmount([in] LPCSTR Path);
HRESULT DetachDisk([in] ULONG Lun);
HRESULT MountWindowsFolder([in] LPCWSTR WindowsPath, [in] LPCSTR LinuxPath, [in] BOOL ReadOnly);
HRESULT UnmountWindowsFolder([in] LPCSTR LinuxPath);
HRESULT MountGpuLibraries([in] LPCSTR LibrariesMountPoint, [in] LPCSTR DriversMountpoint, [in] DWORD Flags);
}
typedef
struct _VIRTUAL_MACHINE_SETTINGS {
LPCWSTR DisplayName;
ULONGLONG MemoryMb;
ULONG CpuCount;
ULONG BootTimeoutMs;
ULONG DmesgOutput;
ULONG NetworkingMode;
BOOL EnableDnsTunneling;
BOOL EnableDebugShell;
BOOL EnableEarlyBootDmesg;
BOOL EnableGPU;
} VIRTUAL_MACHINE_SETTINGS;
[
uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8760),
pointer_default(unique),
object
]
interface IWSLAUserSession : IUnknown
{
HRESULT GetVersion([out] WSL_VERSION* Error);
HRESULT CreateVirtualMachine([in] const VIRTUAL_MACHINE_SETTINGS* Settings, [out]IWSLAVirtualMachine** VirtualMachine);
}

View File

@ -0,0 +1,13 @@
set(SOURCES
${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wslaservice_i_${TARGET_PLATFORM}.c
${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wslaservice_p_${TARGET_PLATFORM}.c
${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/dlldata_${TARGET_PLATFORM}.c
${CMAKE_CURRENT_LIST_DIR}/WslaServiceProxyStub.def
${CMAKE_CURRENT_LIST_DIR}/WslaServiceProxyStub.rc)
set_source_files_properties(${SOURCES} PROPERTIES GENERATED TRUE)
add_library(wslaserviceproxystub SHARED ${SOURCES})
add_dependencies(wslaserviceproxystub wslaserviceidl)
target_link_libraries(wslaserviceproxystub ${COMMON_LINK_LIBRARIES})
set_target_properties(wslaserviceproxystub PROPERTIES FOLDER windows)

View File

@ -0,0 +1,5 @@
LIBRARY WslaServiceProxyStub.dll
EXPORTS
DllGetClassObject PRIVATE
DllCanUnloadNow PRIVATE

View File

@ -0,0 +1,23 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
WslaServiceProxyStub.rc
Abstract:
This file contains resources for wslaserviceproxystub.dll.
--*/
#include <windows.h>
#include "wslversioninfo.h"
#define VER_INTERNALNAME_STR "wslaserviceproxystub.dll"
#define VER_ORIGINALFILENAME_STR "wslaserviceproxystub.dll"
#define VER_FILEDESCRIPTION_STR "WSLA Service ProxyStub DLL"
#include <common.ver>

View File

@ -1316,6 +1316,17 @@ void StopWslService()
StopService(service.get()); StopService(service.get());
} }
void StopWslaService()
{
LogInfo("Stopping WSLAService");
const wil::unique_schandle manager{OpenSCManager(nullptr, nullptr, SC_MANAGER_CONNECT)};
VERIFY_IS_NOT_NULL(manager);
const wil::unique_schandle service{OpenService(manager.get(), L"wslaservice", SERVICE_STOP | SERVICE_QUERY_STATUS)};
VERIFY_IS_NOT_NULL(service);
StopService(service.get());
}
wil::unique_handle GetNonElevatedToken() wil::unique_handle GetNonElevatedToken()
{ {
const auto token = wil::open_current_access_token(TOKEN_ALL_ACCESS); const auto token = wil::open_current_access_token(TOKEN_ALL_ACCESS);

View File

@ -523,6 +523,7 @@ inline auto EnableSystemd(const std::string& extraConfig = "")
std::wstring EscapePath(std::wstring_view Path); std::wstring EscapePath(std::wstring_view Path);
void StopWslService(); void StopWslService();
void StopWslaService();
std::optional<GUID> GetDistributionId(LPCWSTR Name); std::optional<GUID> GetDistributionId(LPCWSTR Name);
wil::unique_hkey OpenDistributionKey(LPCWSTR Name); wil::unique_hkey OpenDistributionKey(LPCWSTR Name);

View File

@ -821,7 +821,7 @@ class WSLATests
}); });
// Stop the service // Stop the service
StopWslService(); StopWslaService();
// Verify that the thread is unstuck // Verify that the thread is unstuck
stuckThread.join(); stuckThread.join();