mirror of
https://github.com/microsoft/WSL.git
synced 2025-12-10 00:44:55 -06:00
wsla: Create wslaservice.exe (#13623)
* Save state * Save state * Save state * Save state * Save state * Save state * Save state * Save state * Save state * Cleanup for review * Update ServiceMain.cpp comment * Remove duplicated definitions from wslservice.idl
This commit is contained in:
parent
3f24caaf73
commit
451a7e103a
@ -388,6 +388,7 @@ include_directories(${WSLDEPS_SOURCE_DIR}/include/lxcore)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/shared/inc)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/windows/inc)
|
||||
include_directories(${CMAKE_CURRENT_BINARY_DIR}/src/windows/service/inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE})
|
||||
include_directories(${CMAKE_CURRENT_BINARY_DIR}/src/windows/wslaservice/inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE})
|
||||
include_directories(${CMAKE_CURRENT_BINARY_DIR}/src/windows/wslinstaller/inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE})
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/linux/init/inc)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/windows/common)
|
||||
@ -413,6 +414,7 @@ add_subdirectory(msipackage)
|
||||
add_subdirectory(msixinstaller)
|
||||
add_subdirectory(src/windows/common)
|
||||
add_subdirectory(src/windows/service)
|
||||
add_subdirectory(src/windows/wslaservice)
|
||||
add_subdirectory(src/windows/wslinstaller/inc)
|
||||
add_subdirectory(src/windows/wslinstaller/stub)
|
||||
add_subdirectory(src/windows/wslinstaller/exe)
|
||||
|
||||
@ -3,7 +3,7 @@ set(OUTPUT_PACKAGE ${BIN}/wsl.msi)
|
||||
set(PACKAGE_WIX_IN ${CMAKE_CURRENT_LIST_DIR}/package.wix.in)
|
||||
set(PACKAGE_WIX ${BIN}/package.wix)
|
||||
set(CAB_CACHE ${BIN}/cab)
|
||||
set(BINARIES wsl.exe;wslg.exe;wslhost.exe;wslrelay.exe;wslservice.exe;wslserviceproxystub.dll;init;initrd.img;wslinstall.dll)
|
||||
set(BINARIES wsl.exe;wslg.exe;wslhost.exe;wslrelay.exe;wslservice.exe;wslserviceproxystub.dll;init;initrd.img;wslinstall.dll;wslaserviceproxystub.dll;wslaservice.exe)
|
||||
|
||||
if (WSL_BUILD_WSL_SETTINGS)
|
||||
list(APPEND BINARIES_DEPENDENCIES "wslsettings/wslsettings.dll;wslsettings/wslsettings.exe;libwsl.dll")
|
||||
@ -39,7 +39,7 @@ add_custom_command(
|
||||
|
||||
add_custom_target(msipackage DEPENDS ${OUTPUT_PACKAGE})
|
||||
set_target_properties(msipackage PROPERTIES EXCLUDE_FROM_ALL FALSE SOURCES ${PACKAGE_WIX_IN})
|
||||
add_dependencies(msipackage wsl wslg wslservice wslhost wslrelay wslserviceproxystub init initramfs wslinstall msixgluepackage)
|
||||
add_dependencies(msipackage wsl wslg wslservice wslhost wslrelay wslserviceproxystub init initramfs wslinstall msixgluepackage wslaservice wslaserviceproxystub)
|
||||
|
||||
if (WSL_BUILD_WSL_SETTINGS)
|
||||
add_dependencies(msipackage wslsettings libwsl)
|
||||
|
||||
@ -50,75 +50,75 @@
|
||||
|
||||
<Component Id="explorerplan9shortcut" Guid="{93CBFF23-A04C-4344-A332-238CE5B97AED}" UninstallWhenSuperseded="yes" DisableRegistryReflection="yes" Bitness="always64">
|
||||
<!-- Explorer extensions -->
|
||||
<RegistryKey Root="HKLM" Key="SOFTWARE\Classes\CLSID\{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}">
|
||||
<RegistryValue Value="Linux" Type="string"/>
|
||||
<RegistryValue Name="SortOrderIndex" Value="119" Type="integer"/>
|
||||
<!--0x77-->
|
||||
<RegistryValue Name="System.IsPinnedToNameSpaceTree" Value="1" Type="integer"/>
|
||||
<RegistryKey Root="HKLM" Key="SOFTWARE\Classes\CLSID\{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}">
|
||||
<RegistryValue Value="Linux" Type="string"/>
|
||||
<RegistryValue Name="SortOrderIndex" Value="119" Type="integer"/>
|
||||
<!--0x77-->
|
||||
<RegistryValue Name="System.IsPinnedToNameSpaceTree" Value="1" Type="integer"/>
|
||||
|
||||
<RegistryKey Key="DefaultIcon">
|
||||
<RegistryValue Value="[System64Folder]wsl.exe,-1" Type="string"/>
|
||||
</RegistryKey>
|
||||
<RegistryKey Key="DefaultIcon">
|
||||
<RegistryValue Value="[System64Folder]wsl.exe,-1" Type="string"/>
|
||||
</RegistryKey>
|
||||
|
||||
<RegistryKey Key="InProcServer32">
|
||||
<RegistryValue Value="[System64Folder]windows.storage.dll" Type="string"/>
|
||||
</RegistryKey>
|
||||
<RegistryKey Key="InProcServer32">
|
||||
<RegistryValue Value="[System64Folder]windows.storage.dll" Type="string"/>
|
||||
</RegistryKey>
|
||||
|
||||
<RegistryKey Key="ShellFolder">
|
||||
<RegistryValue Name="Attributes" Value="2692743245" Type="integer"/>
|
||||
<!--0xa080004d"-->
|
||||
<RegistryValue Name="FolderValueFlags" Value="40" Type="integer"/>
|
||||
<!--0x28-->
|
||||
</RegistryKey>
|
||||
<RegistryKey Key="ShellFolder">
|
||||
<RegistryValue Name="Attributes" Value="2692743245" Type="integer"/>
|
||||
<!--0xa080004d"-->
|
||||
<RegistryValue Name="FolderValueFlags" Value="40" Type="integer"/>
|
||||
<!--0x28-->
|
||||
</RegistryKey>
|
||||
|
||||
<RegistryKey Key="Instance">
|
||||
<RegistryValue Name="CLSID" Value="{4FE04BFD-85B9-49DD-B914-F4C9556B9DA6}" Type="string"/>
|
||||
<RegistryKey Key="Instance">
|
||||
<RegistryValue Name="CLSID" Value="{4FE04BFD-85B9-49DD-B914-F4C9556B9DA6}" Type="string"/>
|
||||
|
||||
<RegistryKey Key="InitPropertyBag">
|
||||
<RegistryValue Name="DisplayType" Value="2" Type="integer"/>
|
||||
<RegistryValue Name="EnumObjectsTelemetryValue" Value="WSL" Type="string"/>
|
||||
<RegistryValue Name="Provider" Value="Plan 9 Network Provider" Type="string"/>
|
||||
<RegistryValue Name="ResName" Value="\\wsl.localhost" Type="string"/>
|
||||
<RegistryKey Key="InitPropertyBag">
|
||||
<RegistryValue Name="DisplayType" Value="2" Type="integer"/>
|
||||
<RegistryValue Name="EnumObjectsTelemetryValue" Value="WSL" Type="string"/>
|
||||
<RegistryValue Name="Provider" Value="Plan 9 Network Provider" Type="string"/>
|
||||
<RegistryValue Name="ResName" Value="\\wsl.localhost" Type="string"/>
|
||||
</RegistryKey>
|
||||
</RegistryKey>
|
||||
</RegistryKey>
|
||||
</RegistryKey>
|
||||
|
||||
<RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\HideDesktopIcons\NewStartPanel">
|
||||
<RegistryValue Name="{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}" Value="1" Type="integer"/>
|
||||
</RegistryKey>
|
||||
<RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\HideDesktopIcons\NewStartPanel">
|
||||
<RegistryValue Name="{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}" Value="1" Type="integer"/>
|
||||
</RegistryKey>
|
||||
|
||||
<RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\Desktop\NameSpace\{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}">
|
||||
<RegistryValue Value="Linux" Type="string"/>
|
||||
</RegistryKey>
|
||||
<RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\Desktop\NameSpace\{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}">
|
||||
<RegistryValue Value="Linux" Type="string"/>
|
||||
</RegistryKey>
|
||||
|
||||
<RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\IdListAliasTranslations\WSL">
|
||||
<RegistryValue Name="Target" Value="::{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}" Type="string"/>
|
||||
<RegistryValue Name="Source" Value="\\wsl.localhost" Type="string"/>
|
||||
</RegistryKey>
|
||||
<RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\IdListAliasTranslations\WSL">
|
||||
<RegistryValue Name="Target" Value="::{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}" Type="string"/>
|
||||
<RegistryValue Name="Source" Value="\\wsl.localhost" Type="string"/>
|
||||
</RegistryKey>
|
||||
|
||||
<RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\IdListAliasTranslations\WSLLegacy">
|
||||
<RegistryValue Name="Target" Value="::{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}" Type="string"/>
|
||||
<RegistryValue Name="Source" Value="\\wsl$" Type="string"/>
|
||||
</RegistryKey>
|
||||
<RegistryKey Root="HKLM" Key="SOFTWARE\Microsoft\Windows\CurrentVersion\Explorer\IdListAliasTranslations\WSLLegacy">
|
||||
<RegistryValue Name="Target" Value="::{B2B4A4D1-2754-4140-A2EB-9A76D9D7CDC6}" Type="string"/>
|
||||
<RegistryValue Name="Source" Value="\\wsl$" Type="string"/>
|
||||
</RegistryKey>
|
||||
</Component>
|
||||
|
||||
<Component Id="explorershell" Guid="{93CBFF23-A04C-4344-A332-238CE5B97AEC}" UninstallWhenSuperseded="yes" DisableRegistryReflection="yes" Bitness="always64">
|
||||
<?foreach PATH in SOFTWARE\Classes\Directory\shell\WSL;SOFTWARE\Classes\Directory\Background\shell\WSL;SOFTWARE\Classes\Drive\shell\WSL?>
|
||||
<RegistryKey Root="HKLM" Key="$(var.PATH)">
|
||||
<RegistryValue Value="@wsl.exe,-2" Type="string"/>
|
||||
<RegistryValue Name="Extended" Value="" Type="string"/>
|
||||
<RegistryValue Name="NoWorkingDirectory" Value="" Type="string"/>
|
||||
<RegistryKey Key="command">
|
||||
<RegistryValue Value='wsl.exe --cd "%V"' Type="string"/>
|
||||
<RegistryKey Root="HKLM" Key="$(var.PATH)">
|
||||
<RegistryValue Value="@wsl.exe,-2" Type="string"/>
|
||||
<RegistryValue Name="Extended" Value="" Type="string"/>
|
||||
<RegistryValue Name="NoWorkingDirectory" Value="" Type="string"/>
|
||||
<RegistryKey Key="command">
|
||||
<RegistryValue Value='wsl.exe --cd "%V"' Type="string"/>
|
||||
</RegistryKey>
|
||||
</RegistryKey>
|
||||
</RegistryKey>
|
||||
<?endforeach?>
|
||||
<?endforeach?>
|
||||
|
||||
<ProgId Id="WSLDistributionTar" Description="WSL tar distribution" Icon="wsl.exe">
|
||||
<Extension Id="wsl">
|
||||
<Verb Id="open" Command="open" TargetFile="wsl.exe" Argument="--install --prompt-before-exit --from-file "%1"" />
|
||||
</Extension>
|
||||
</ProgId>
|
||||
<ProgId Id="WSLDistributionTar" Description="WSL tar distribution" Icon="wsl.exe">
|
||||
<Extension Id="wsl">
|
||||
<Verb Id="open" Command="open" TargetFile="wsl.exe" Argument="--install --prompt-before-exit --from-file "%1"" />
|
||||
</Extension>
|
||||
</ProgId>
|
||||
</Component>
|
||||
|
||||
<Component Id="wslservice" Guid="F0C8D6BA-1502-41E7-BF72-D93DFA134735" UninstallWhenSuperseded="yes" DisableRegistryReflection="yes" Bitness="always64">
|
||||
@ -131,33 +131,13 @@
|
||||
</RegistryKey>
|
||||
</RegistryKey>
|
||||
|
||||
<RegistryKey Root="HKCR" Key="Interface\{82A7ABC8-6B50-43FC-AB96-15FBBE7E8760}">
|
||||
<RegistryValue Value="IWSLAUserSession" Type="string" />
|
||||
<RegistryKey Key="ProxyStubClsid32">
|
||||
<RegistryValue Value="{4EA0C6DD-E9FF-48E7-994E-13A31D10DC60}" Type="string" />
|
||||
</RegistryKey>
|
||||
</RegistryKey>
|
||||
|
||||
<RegistryKey Root="HKCR" Key="Interface\{82A7ABC8-6B50-43FC-AB96-15FBBE7E8761}">
|
||||
<RegistryValue Value="IWSLAVirtualMachine" Type="string" />
|
||||
<RegistryKey Key="ProxyStubClsid32">
|
||||
<RegistryValue Value="{4EA0C6DD-E9FF-48E7-994E-13A31D10DC60}" Type="string" />
|
||||
</RegistryKey>
|
||||
</RegistryKey>
|
||||
|
||||
<RegistryKey Root="HKCR" Key="Interface\{7BC4E198-6531-4FA6-ADE2-5EF3D2A04DFE}">
|
||||
<RegistryValue Value="ITerminationCallback" Type="string" />
|
||||
<RegistryKey Key="ProxyStubClsid32">
|
||||
<RegistryValue Value="{4EA0C6DD-E9FF-48E7-994E-13A31D10DC60}" Type="string" />
|
||||
</RegistryKey>
|
||||
</RegistryKey>
|
||||
|
||||
<!-- WSLServiceProxyStub. -->
|
||||
<RegistryKey Root="HKCR" Key="CLSID\{4EA0C6DD-E9FF-48E7-994E-13A31D10DC60}">
|
||||
<RegistryValue Value="PSFactoryBuffer" Type="string" />
|
||||
</RegistryKey>
|
||||
<RegistryKey Root="HKCR" Key="CLSID\{4EA0C6DD-E9FF-48E7-994E-13A31D10DC60}\InProcServer32">
|
||||
<RegistryValue Value="[INSTALLDIR]wslserviceproxystub.dll" Type="string" />
|
||||
<RegistryValue Name="ThreadingModel" Value="Both" Type="string" />
|
||||
<RegistryKey Key="InProcServer32">
|
||||
<RegistryValue Value="[INSTALLDIR]wslserviceproxystub.dll" Type="string" />
|
||||
<RegistryValue Name="ThreadingModel" Value="Both" Type="string" />
|
||||
</RegistryKey>
|
||||
</RegistryKey>
|
||||
|
||||
<!-- ILxssUserSession -->
|
||||
@ -175,18 +155,6 @@
|
||||
<RegistryValue Value="LxssUserSession" Type="string" />
|
||||
</RegistryKey>
|
||||
|
||||
<!-- WSLAUserSession -->
|
||||
<RegistryKey Root="HKCR" Key="CLSID\{a9b7a1b9-0671-405c-95f1-e0612cb4ce8f}">
|
||||
<RegistryValue Name="AppId" Value="{370121D2-AA7E-4608-A86D-0BBAB9DA1A60}" Type="string" />
|
||||
<RegistryValue Value="WSLAUserSession" Type="string" />
|
||||
</RegistryKey>
|
||||
|
||||
<!-- WSLAVirtualMachine -->
|
||||
<RegistryKey Root="HKCR" Key="CLSID\{0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30}">
|
||||
<RegistryValue Name="AppId" Value="{370121D2-AA7E-4608-A86D-0BBAB9DA1A60}" Type="string" />
|
||||
<RegistryValue Value="WSLAVirtualMachine" Type="string" />
|
||||
</RegistryKey>
|
||||
|
||||
<!-- Notification server -->
|
||||
<RegistryKey Root="HKCR" Key="CLSID\{2B9C59C3-98F1-45C8-B87B-12AE3C7927E8}\LocalServer32">
|
||||
<RegistryValue Value='"[INSTALLDIR]wslhost.exe"' Type="string" />
|
||||
@ -226,7 +194,7 @@
|
||||
|
||||
<RegistryKey Root="HKCR" Key="AppID\{7F82AD86-755B-4870-86B1-D2E68DFE8A49}">
|
||||
<RegistryValue Name="DllSurrogate" Value="" Type="string"/>
|
||||
<RegistryValue Name="AppIDFlags" Value="2048" Type="integer"/><!--0x800-->
|
||||
<RegistryValue Name="AppIDFlags" Value="2048" Type="integer"/> <!--0x800-->
|
||||
|
||||
<!-- O:BAG:BAD:(A;;CCDCSW;;;AU)(A;;CCDCSW;;;PS)(A;;CCDCSW;;;SY) -->
|
||||
<RegistryValue Name="AccessPermission" Value="01000480580000006800000000000000140000000200440003000000000014000B00000001010000000000050B000000000014000B00000001010000000000050A000000000014000B0000000101000000000005120000000102000000000005200000002002000001020000000000052000000020020000" Type="binary" />
|
||||
@ -265,6 +233,68 @@
|
||||
<ServiceControl Id="StopService" Stop="both" Remove="uninstall" Name="WSLService" Wait="yes" />
|
||||
</Component>
|
||||
|
||||
<Component Id="wslaservice" Guid="DC97520E-BFA5-4559-960F-D580E151629F" UninstallWhenSuperseded="yes" DisableRegistryReflection="yes" Bitness="always64">
|
||||
<!-- WSLAServiceProxyStub. -->
|
||||
<RegistryKey Root="HKCR" Key="CLSID\{4EA0C6DD-E9FF-48E7-994E-13A31D10DC61}">
|
||||
<RegistryValue Value="PSFactoryBuffer" Type="string" />
|
||||
<RegistryKey Key="InProcServer32">
|
||||
<RegistryValue Value="[INSTALLDIR]wslaserviceproxystub.dll" Type="string" />
|
||||
<RegistryValue Name="ThreadingModel" Value="Both" Type="string" />
|
||||
</RegistryKey>
|
||||
</RegistryKey>
|
||||
|
||||
<!-- WSLAService COM app -->
|
||||
<RegistryKey Root="HKCR" Key="AppID\{E9B79997-57E3-4201-AECC-6A464E530DD2}">
|
||||
|
||||
<!-- O:BAG:BAD:(A;;CCDCSW;;;AU)(A;;CCDCSW;;;PS)(A;;CCDCSW;;;SY) -->
|
||||
<RegistryValue Name="AccessPermission" Value="01000480580000006800000000000000140000000200440003000000000014000B00000001010000000000050B000000000014000B00000001010000000000050A000000000014000B0000000101000000000005120000000102000000000005200000002002000001020000000000052000000020020000" Type="binary" />
|
||||
<RegistryValue Name="LaunchPermission" Value="01000480580000006800000000000000140000000200440003000000000014000B00000001010000000000050B000000000014000B00000001010000000000050A000000000014000B0000000101000000000005120000000102000000000005200000002002000001020000000000052000000020020000" Type="binary" />
|
||||
<RegistryValue Name="LocalService" Value="WSLAService" Type="string" />
|
||||
</RegistryKey>
|
||||
|
||||
<!-- WSLAUserSession -->
|
||||
<RegistryKey Root="HKCR" Key="CLSID\{a9b7a1b9-0671-405c-95f1-e0612cb4ce8f}">
|
||||
<RegistryValue Name="AppId" Value="{E9B79997-57E3-4201-AECC-6A464E530DD2}" Type="string" />
|
||||
<RegistryValue Value="WSLAUserSession" Type="string" />
|
||||
</RegistryKey>
|
||||
|
||||
<!-- WSLAVirtualMachine -->
|
||||
<RegistryKey Root="HKCR" Key="CLSID\{0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30}">
|
||||
<RegistryValue Name="AppId" Value="{E9B79997-57E3-4201-AECC-6A464E530DD2}" Type="string" />
|
||||
<RegistryValue Value="WSLAVirtualMachine" Type="string" />
|
||||
</RegistryKey>
|
||||
|
||||
<!-- IWSLAUserSession-->
|
||||
<RegistryKey Root="HKCR" Key="Interface\{82A7ABC8-6B50-43FC-AB96-15FBBE7E8760}">
|
||||
<RegistryValue Value="IWSLAUserSession" Type="string" />
|
||||
<RegistryKey Key="ProxyStubClsid32">
|
||||
<RegistryValue Value="{4EA0C6DD-E9FF-48E7-994E-13A31D10DC61}" Type="string" />
|
||||
</RegistryKey>
|
||||
</RegistryKey>
|
||||
|
||||
<!-- IWSLAVirtualMachine-->
|
||||
<RegistryKey Root="HKCR" Key="Interface\{82A7ABC8-6B50-43FC-AB96-15FBBE7E8761}">
|
||||
<RegistryValue Value="IWSLAVirtualMachine" Type="string" />
|
||||
<RegistryKey Key="ProxyStubClsid32">
|
||||
<RegistryValue Value="{4EA0C6DD-E9FF-48E7-994E-13A31D10DC61}" Type="string" />
|
||||
</RegistryKey>
|
||||
</RegistryKey>
|
||||
|
||||
<!-- ITerminationCallback-->
|
||||
<RegistryKey Root="HKCR" Key="Interface\{7BC4E198-6531-4FA6-ADE2-5EF3D2A04DFE}">
|
||||
<RegistryValue Value="ITerminationCallback" Type="string" />
|
||||
<RegistryKey Key="ProxyStubClsid32">
|
||||
<RegistryValue Value="{4EA0C6DD-E9FF-48E7-994E-13A31D10DC61}" Type="string" />
|
||||
</RegistryKey>
|
||||
</RegistryKey>
|
||||
|
||||
|
||||
<File Id="wslaservice.exe" Source="${BIN}/wslaservice.exe" KeyPath="yes" />
|
||||
<File Id="wslaserviceproxystub.dll" Name="wslaserviceproxystub.dll" Source="${BIN}/wslaserviceproxystub.dll" />
|
||||
|
||||
<ServiceInstall Name="WSLAService" DisplayName="WSLA Service" Description="WSLA Service" Start="auto" Type="ownProcess" ErrorControl="normal" Account="LocalSystem" Vital="yes" Interactive="no" />
|
||||
<ServiceControl Id="StopWSLAService" Stop="both" Remove="uninstall" Name="WSLAService" Wait="yes" />
|
||||
</Component>
|
||||
<Component Id="wslg" Guid="F0C8D6BA-1502-41E7-BF72-D93DFA134731" UninstallWhenSuperseded="yes" DisableRegistryReflection="yes" Bitness="always64">
|
||||
<?if "${WSL_DEV_BINARY_PATH}" = "" ?>
|
||||
<File Id="msrdc.exe" Source="${MSRDC_SOURCE_DIR}/${TARGET_PLATFORM}/msrdc.exe" />
|
||||
@ -402,6 +432,7 @@
|
||||
<Feature Id="WSL" Title="Windows Subsystem for Linux" Level="1">
|
||||
<ComponentRef Id="wsl" />
|
||||
<ComponentRef Id="wslservice" />
|
||||
<ComponentRef Id="wslaservice" />
|
||||
<ComponentRef Id="wslg" />
|
||||
<ComponentRef Id="tools" />
|
||||
<ComponentRef Id="explorershell" />
|
||||
@ -501,21 +532,21 @@
|
||||
Return="check"
|
||||
Execute="deferred"
|
||||
/>
|
||||
|
||||
<CustomAction Id="RemoveRegistryKeyProtections"
|
||||
Impersonate="no"
|
||||
BinaryRef="wslinstall.dll"
|
||||
DllEntry="RemoveRegistryKeyProtections"
|
||||
Return="check"
|
||||
Execute="deferred"
|
||||
|
||||
<CustomAction Id="RemoveRegistryKeyProtections"
|
||||
Impersonate="no"
|
||||
BinaryRef="wslinstall.dll"
|
||||
DllEntry="RemoveRegistryKeyProtections"
|
||||
Return="check"
|
||||
Execute="deferred"
|
||||
/>
|
||||
|
||||
<CustomAction Id="UnregisterLspCategories"
|
||||
Impersonate="no"
|
||||
BinaryRef="wslinstall.dll"
|
||||
DllEntry="UnregisterLspCategories"
|
||||
Return="check"
|
||||
Execute="deferred"
|
||||
<CustomAction Id="UnregisterLspCategories"
|
||||
Impersonate="no"
|
||||
BinaryRef="wslinstall.dll"
|
||||
DllEntry="UnregisterLspCategories"
|
||||
Return="check"
|
||||
Execute="deferred"
|
||||
/>
|
||||
|
||||
<CustomAction Id="InstallMsix.SetProperty" Return="check" Property="InstallMsix" Value='[DATABASE]' Execute='immediate' />
|
||||
|
||||
@ -1,83 +1,105 @@
|
||||
set(SOURCES
|
||||
ConsoleProgressBar.cpp
|
||||
ConsoleProgressIndicator.cpp
|
||||
DeviceHostProxy.cpp
|
||||
DeviceHostProxy.h
|
||||
disk.cpp
|
||||
Distribution.cpp
|
||||
Dmesg.cpp
|
||||
Dmesg.h
|
||||
DnsResolver.cpp
|
||||
DnsTunnelingChannel.cpp
|
||||
ExecutionContext.cpp
|
||||
filesystem.cpp
|
||||
GnsChannel.cpp
|
||||
HandleConsoleProgressBar.cpp
|
||||
hcs.cpp
|
||||
helpers.cpp
|
||||
interop.cpp
|
||||
ExecutionContext.cpp
|
||||
socket.cpp
|
||||
hvsocket.cpp
|
||||
interop.cpp
|
||||
Localization.cpp
|
||||
lxssbusclient.cpp
|
||||
lxssclient.cpp
|
||||
LxssMessagePort.cpp
|
||||
LxssSecurity.cpp
|
||||
LxssServerPort.cpp
|
||||
NatNetworking.cpp
|
||||
notifications.cpp
|
||||
Redirector.cpp
|
||||
registry.cpp
|
||||
relay.cpp
|
||||
RingBuffer.cpp
|
||||
socket.cpp
|
||||
string.cpp
|
||||
SubProcess.cpp
|
||||
svccomm.cpp
|
||||
svccommio.cpp
|
||||
WslClient.cpp
|
||||
WslCoreConfig.cpp
|
||||
WslCoreFilesystem.cpp
|
||||
WslCoreFirewallSupport.cpp
|
||||
WslCoreHostDnsInfo.cpp
|
||||
WslCoreNetworkEndpointSettings.cpp
|
||||
WslCoreNetworkingSupport.cpp
|
||||
WslInstall.cpp
|
||||
WslSecurity.cpp
|
||||
WslTelemetry.cpp
|
||||
wslutil.cpp
|
||||
notifications.cpp)
|
||||
)
|
||||
|
||||
set(HEADERS
|
||||
../../../generated/Localization.h
|
||||
../../shared/inc/CommandLine.h
|
||||
../../shared/inc/defs.h
|
||||
../../shared/inc/lxfsshares.h
|
||||
../../shared/inc/lxinitshared.h
|
||||
../../shared/inc/SocketChannel.h
|
||||
../../shared/inc/socketshared.h
|
||||
../../shared/inc/hns_schema.h
|
||||
../../shared/inc/JsonUtils.h
|
||||
../../shared/inc/stringshared.h
|
||||
../../shared/inc/retryshared.h
|
||||
../../shared/inc/message.h
|
||||
../../shared/inc/prettyprintshared.h
|
||||
../inc/WslPluginApi.h
|
||||
../inc/wslpolicies.h
|
||||
../inc/lxssbusclient.h
|
||||
../inc/lxssclient.h
|
||||
../inc/LxssDynamicFunction.h
|
||||
../inc/traceloggingconfig.h
|
||||
../inc/wdk.h
|
||||
../inc/wsl.h
|
||||
../inc/wslconfig.h
|
||||
../inc/wsl.h
|
||||
../inc/wslhost.h
|
||||
../inc/WslPluginApi.h
|
||||
../inc/wslpolicies.h
|
||||
../inc/wslrelay.h
|
||||
../../shared/inc/CommandLine.h
|
||||
../../shared/inc/defs.h
|
||||
../../shared/inc/hns_schema.h
|
||||
../../shared/inc/JsonUtils.h
|
||||
../../shared/inc/lxfsshares.h
|
||||
../../shared/inc/lxinitshared.h
|
||||
../../shared/inc/message.h
|
||||
../../shared/inc/prettyprintshared.h
|
||||
../../shared/inc/retryshared.h
|
||||
../../shared/inc/SocketChannel.h
|
||||
../../shared/inc/socketshared.h
|
||||
../../shared/inc/stringshared.h
|
||||
ConsoleProgressBar.h
|
||||
ConsoleProgressIndicator.h
|
||||
disk.hpp
|
||||
Distribution.h
|
||||
DnsResolver.h
|
||||
DnsTunnelingChannel.h
|
||||
ExecutionContext.h
|
||||
filesystem.hpp
|
||||
GnsChannel.h
|
||||
HandleConsoleProgressBar.h
|
||||
hcs.hpp
|
||||
hcs_schema.h
|
||||
helpers.hpp
|
||||
HandleConsoleProgressBar.h
|
||||
interop.hpp
|
||||
ExecutionContext.h
|
||||
socket.hpp
|
||||
hvsocket.hpp
|
||||
INetworkingEngine.h
|
||||
interop.hpp
|
||||
LxssMessagePort.h
|
||||
LxssPort.h
|
||||
LxssSecurity.h
|
||||
LxssServerPort.h
|
||||
NatNetworking.h
|
||||
notifications.h
|
||||
precomp.h
|
||||
Redirector.h
|
||||
registry.hpp
|
||||
relay.hpp
|
||||
RingBuffer.h
|
||||
socket.hpp
|
||||
string.hpp
|
||||
Stringify.h
|
||||
SubProcess.h
|
||||
@ -85,13 +107,17 @@ set(HEADERS
|
||||
svccommio.hpp
|
||||
WslClient.h
|
||||
WslCoreConfig.h
|
||||
WslCoreFilesystem.h
|
||||
WslCoreFirewallSupport.h
|
||||
WslCoreHostDnsInfo.h
|
||||
WslCoreMessageQueue.h
|
||||
WslCoreNetworkEndpointSettings.h
|
||||
WslCoreNetworkingSupport.h
|
||||
WslInstall.h
|
||||
WslSecurity.h
|
||||
WslTelemetry.h
|
||||
wslutil.h
|
||||
notifications.h)
|
||||
wslutil.cpp
|
||||
)
|
||||
|
||||
add_library(common STATIC ${SOURCES} ${HEADERS})
|
||||
add_dependencies(common wslserviceidl localization wslservicemc wslinstalleridl)
|
||||
|
||||
@ -1,261 +1,261 @@
|
||||
// Copyright (C) Microsoft Corporation. All rights reserved.
|
||||
|
||||
#include "precomp.h"
|
||||
#include "DeviceHostProxy.h"
|
||||
|
||||
// This template works around a limitation with decltype on overloaded functions. It will be able
|
||||
// to get the correct version of GetVmWorkerProcess based on the provided type arguments. By
|
||||
// doing it this way, a compiler error will be generated if someone changes the signature of
|
||||
// GetVmWorkerProcess.
|
||||
//
|
||||
// The way this works: decltype(GetVmWorkerProcess) does not work because it's overloaded.
|
||||
// decltype(GetVmWorkerProcess(arg1, ...)) works to select an overload if you have values of the
|
||||
// correct type (std::declval<T>() generates a value of the specified type), however the result
|
||||
// of that is the function's return type, not the function's type, so the argument types must
|
||||
// be repeated to reconstruct the function type.
|
||||
template <typename... Args>
|
||||
using GetVmWorkerProcessType = decltype(GetVmWorkerProcess(std::declval<Args>()...))(Args...);
|
||||
|
||||
// Limit the number of allowed doorbells registered by an external HDV vdev. Currently virtio-9p only uses
|
||||
// one doorbell and wsldevicehost uses only two.
|
||||
#define DEVICE_HOST_PROXY_DOORBELL_LIMIT 8
|
||||
|
||||
using namespace wsl::windows::common::hcs;
|
||||
|
||||
DeviceHostProxy::DeviceHostProxy(const std::wstring& VmId, const GUID& RuntimeId) :
|
||||
m_systemId{VmId}, m_runtimeId{RuntimeId}, m_system{wsl::windows::common::hcs::OpenComputeSystem(VmId.c_str(), GENERIC_ALL)}, m_shutdown{false}
|
||||
{
|
||||
m_devicesShutdown = false;
|
||||
}
|
||||
|
||||
GUID DeviceHostProxy::AddNewDevice(const GUID& Type, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs, const std::wstring& VirtIoTag)
|
||||
{
|
||||
const wrl::ComPtr<IUnknown> thisUnknown{CastToUnknown()};
|
||||
GUID instanceId{};
|
||||
THROW_IF_FAILED(UuidCreate(&instanceId));
|
||||
// Tell the device host to create the device.
|
||||
THROW_IF_FAILED(Plan9Fs->CreateVirtioDevice(m_systemId.c_str(), thisUnknown.Get(), VirtIoTag.c_str(), &instanceId));
|
||||
|
||||
// Add the instance ID to the list of known devices. This must be done before the device is
|
||||
// added to the system, because doing that can cause the register doorbell function to be
|
||||
// called.
|
||||
// N.B. It will be removed if there is a failure.
|
||||
{
|
||||
auto lock = m_devicesLock.lock_exclusive();
|
||||
THROW_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
|
||||
|
||||
m_devices.emplace(instanceId, DeviceHostProxyEntry{});
|
||||
}
|
||||
|
||||
auto removeOnFailure = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() {
|
||||
auto lock = m_devicesLock.lock_exclusive();
|
||||
m_devices.erase(instanceId);
|
||||
});
|
||||
|
||||
// Add the device to the compute system on behalf of the device host.
|
||||
ModifySettingRequest<FlexibleIoDevice> request;
|
||||
request.RequestType = ModifyRequestType::Add;
|
||||
request.ResourcePath = L"VirtualMachine/Devices/FlexibleIov/";
|
||||
request.ResourcePath += wsl::shared::string::GuidToString<wchar_t>(instanceId, wsl::shared::string::GuidToStringFlags::None);
|
||||
request.Settings.EmulatorId = Type;
|
||||
request.Settings.HostingModel = FlexibleIoDeviceHostingModel::ExternalRestricted;
|
||||
wsl::windows::common::hcs::ModifyComputeSystem(m_system.get(), wsl::shared::ToJsonW(request).c_str());
|
||||
removeOnFailure.release();
|
||||
return instanceId;
|
||||
}
|
||||
|
||||
void DeviceHostProxy::AddRemoteFileSystem(const GUID& ImplementationClsid, const std::wstring& Tag, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs)
|
||||
{
|
||||
auto lock = m_lock.lock_exclusive();
|
||||
THROW_HR_IF(E_CHANGED_STATE, m_shutdown);
|
||||
|
||||
// Make sure there are no duplicate tags.
|
||||
for (auto& entry : m_fileSystems)
|
||||
{
|
||||
THROW_HR_IF(E_INVALIDARG, entry.ImplementationClsid == ImplementationClsid && entry.Tag == Tag);
|
||||
}
|
||||
|
||||
m_fileSystems.emplace_back(ImplementationClsid, Tag, Plan9Fs);
|
||||
}
|
||||
|
||||
wil::com_ptr<IPlan9FileSystem> DeviceHostProxy::GetRemoteFileSystem(const GUID& ImplementationClsid, std::wstring_view Tag)
|
||||
{
|
||||
auto lock = m_lock.lock_shared();
|
||||
THROW_HR_IF(E_CHANGED_STATE, m_shutdown);
|
||||
|
||||
for (auto& entry : m_fileSystems)
|
||||
{
|
||||
if (entry.ImplementationClsid == ImplementationClsid && entry.Tag == Tag)
|
||||
{
|
||||
return entry.Instance;
|
||||
}
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
void DeviceHostProxy::Shutdown()
|
||||
{
|
||||
{
|
||||
auto lock = m_lock.lock_exclusive();
|
||||
m_fileSystems.clear();
|
||||
m_shutdown = true;
|
||||
}
|
||||
|
||||
{
|
||||
auto lock = m_devicesLock.lock_exclusive();
|
||||
m_devices.clear();
|
||||
m_devicesShutdown = true;
|
||||
}
|
||||
}
|
||||
|
||||
HRESULT
|
||||
DeviceHostProxy::RegisterDeviceHost(_In_ IVmDeviceHost* DeviceHost, _In_ DWORD ProcessId, _Out_ UINT64* IpcSectionHandle)
|
||||
try
|
||||
{
|
||||
//
|
||||
// Because HdvProxyDeviceHost is not part of the API set, it is loaded here dynamically.
|
||||
//
|
||||
|
||||
static LxssDynamicFunction<decltype(HdvProxyDeviceHost)> proxyDeviceHost{c_hdvModuleName, "HdvProxyDeviceHost"};
|
||||
const wil::com_ptr<IVmDeviceHost> remoteHost = DeviceHost;
|
||||
const wil::com_ptr<IUnknown> unknown = remoteHost.query<IUnknown>();
|
||||
THROW_IF_FAILED(proxyDeviceHost(m_system.get(), unknown.get(), ProcessId, IpcSectionHandle));
|
||||
return S_OK;
|
||||
}
|
||||
CATCH_RETURN()
|
||||
|
||||
HRESULT
|
||||
DeviceHostProxy::NotifyAllDevicesInUse(_In_ LPCWSTR Tag)
|
||||
try
|
||||
{
|
||||
//
|
||||
// Add another Plan9 virtio device to the guest so additional mount commands will be possible.
|
||||
// This callback should be unused by virtiofs devices because a device is created for every
|
||||
// AddSharePath call.
|
||||
//
|
||||
auto p9fs = GetRemoteFileSystem(__uuidof(p9fs::Plan9FileSystem), Tag);
|
||||
THROW_HR_IF(E_NOT_SET, !p9fs);
|
||||
(void)AddNewDevice(VIRTIO_PLAN9_DEVICE_ID, p9fs, Tag);
|
||||
return S_OK;
|
||||
}
|
||||
CATCH_RETURN()
|
||||
|
||||
HRESULT
|
||||
DeviceHostProxy::RegisterDoorbell(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags, HANDLE Event)
|
||||
try
|
||||
{
|
||||
auto lock = m_devicesLock.lock_exclusive();
|
||||
RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
|
||||
|
||||
// Check if the device is one of the known devices that doorbells can be registered for, and
|
||||
// if the device has not already registered a doorbell.
|
||||
// N.B. For security it is enforced that each device can only register a small number of doorbells.
|
||||
// Currently virtio-9p only uses one and the external virtio device uses two.
|
||||
const auto knownDevice = m_devices.find(InstanceId);
|
||||
RETURN_HR_IF(E_ACCESSDENIED, knownDevice == m_devices.end() || knownDevice->second.DoorbellCount == DEVICE_HOST_PROXY_DOORBELL_LIMIT);
|
||||
|
||||
if (!knownDevice->second.MemoryNotification)
|
||||
{
|
||||
// Get an interface to the worker process to query devices.
|
||||
if (!m_deviceAccess)
|
||||
{
|
||||
static LxssDynamicFunction<GetVmWorkerProcessType<REFGUID, REFIID, IUnknown**>> getVmWorker{
|
||||
c_vmwpctrlModuleName, "GetVmWorkerProcess"};
|
||||
|
||||
RETURN_IF_FAILED(getVmWorker(m_runtimeId, __uuidof(*m_deviceAccess), reinterpret_cast<IUnknown**>(&m_deviceAccess)));
|
||||
}
|
||||
|
||||
RETURN_HR_IF(E_NOINTERFACE, !m_deviceAccess);
|
||||
|
||||
// Retrieve the device's memory notification interface to register the doorbell, and store it
|
||||
// to be used during unregistration.
|
||||
wil::com_ptr<IUnknown> device;
|
||||
RETURN_IF_FAILED(m_deviceAccess->GetDevice(FLEXIO_DEVICE_ID, InstanceId, &device));
|
||||
knownDevice->second.MemoryNotification = device.query<IVmFiovGuestMemoryFastNotification>();
|
||||
}
|
||||
|
||||
const auto result = knownDevice->second.MemoryNotification->RegisterDoorbell(
|
||||
static_cast<FIOV_BAR_SELECTOR>(BarIndex), Offset, TriggerValue, Flags, Event);
|
||||
|
||||
if (SUCCEEDED(result))
|
||||
{
|
||||
++knownDevice->second.DoorbellCount;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
CATCH_RETURN()
|
||||
|
||||
HRESULT
|
||||
DeviceHostProxy::UnregisterDoorbell(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags)
|
||||
try
|
||||
{
|
||||
auto lock = m_devicesLock.lock_exclusive();
|
||||
RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
|
||||
|
||||
// Check if the device is a known device and has registered a doorbell.
|
||||
// N.B. If the device is being removed, the device can't be retrieved from the worker process
|
||||
// so it's necessary to use the stored COM pointer.
|
||||
const auto device = m_devices.find(InstanceId);
|
||||
RETURN_HR_IF(E_ACCESSDENIED, device == m_devices.end() || device->second.DoorbellCount == 0);
|
||||
RETURN_IF_FAILED(device->second.MemoryNotification->UnregisterDoorbell(static_cast<FIOV_BAR_SELECTOR>(BarIndex), Offset, TriggerValue, Flags));
|
||||
|
||||
if (--device->second.DoorbellCount == 0)
|
||||
{
|
||||
device->second.MemoryNotification.reset();
|
||||
}
|
||||
|
||||
return S_OK;
|
||||
}
|
||||
CATCH_RETURN()
|
||||
|
||||
HRESULT
|
||||
DeviceHostProxy::CreateSectionBackedMmioRange(
|
||||
const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages, UINT64 PageCount, UINT64 MappingFlags, HANDLE SectionHandle, UINT64 SectionOffsetInPages)
|
||||
try
|
||||
{
|
||||
auto lock = m_devicesLock.lock_exclusive();
|
||||
RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
|
||||
|
||||
// Check if the device is one of the known devices.
|
||||
const auto knownDevice = m_devices.find(InstanceId);
|
||||
THROW_HR_IF(E_ACCESSDENIED, knownDevice == m_devices.end());
|
||||
|
||||
if (!knownDevice->second.MemoryMapping)
|
||||
{
|
||||
// Get an interface to the worker process to query devices.
|
||||
if (!m_deviceAccess)
|
||||
{
|
||||
static LxssDynamicFunction<GetVmWorkerProcessType<REFGUID, REFIID, IUnknown**>> getVmWorker{
|
||||
c_vmwpctrlModuleName, "GetVmWorkerProcess"};
|
||||
THROW_IF_FAILED(getVmWorker(m_runtimeId, __uuidof(*m_deviceAccess), reinterpret_cast<IUnknown**>(&m_deviceAccess)));
|
||||
}
|
||||
|
||||
THROW_HR_IF(E_NOINTERFACE, !m_deviceAccess);
|
||||
|
||||
// Retrieve the device specific interface to manage mapped sections.
|
||||
wil::com_ptr<IUnknown> device;
|
||||
THROW_IF_FAILED(m_deviceAccess->GetDevice(FLEXIO_DEVICE_ID, InstanceId, &device));
|
||||
knownDevice->second.MemoryMapping = device.query<IVmFiovGuestMmioMappings>();
|
||||
}
|
||||
|
||||
THROW_IF_FAILED(knownDevice->second.MemoryMapping->CreateSectionBackedMmioRange(
|
||||
static_cast<FIOV_BAR_SELECTOR>(BarIndex), BarOffsetInPages, PageCount, static_cast<FiovMmioMappingFlags>(MappingFlags), SectionHandle, SectionOffsetInPages));
|
||||
|
||||
return S_OK;
|
||||
}
|
||||
CATCH_RETURN()
|
||||
|
||||
HRESULT
|
||||
DeviceHostProxy::DestroySectionBackedMmioRange(const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages)
|
||||
try
|
||||
{
|
||||
auto lock = m_devicesLock.lock_exclusive();
|
||||
RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
|
||||
const auto device = m_devices.find(InstanceId);
|
||||
RETURN_HR_IF(E_ACCESSDENIED, device == m_devices.end() || !device->second.MemoryMapping);
|
||||
RETURN_IF_FAILED(device->second.MemoryMapping->DestroySectionBackedMmioRange(static_cast<FIOV_BAR_SELECTOR>(BarIndex), BarOffsetInPages));
|
||||
return S_OK;
|
||||
}
|
||||
// Copyright (C) Microsoft Corporation. All rights reserved.
|
||||
|
||||
#include "precomp.h"
|
||||
#include "DeviceHostProxy.h"
|
||||
|
||||
// This template works around a limitation with decltype on overloaded functions. It will be able
|
||||
// to get the correct version of GetVmWorkerProcess based on the provided type arguments. By
|
||||
// doing it this way, a compiler error will be generated if someone changes the signature of
|
||||
// GetVmWorkerProcess.
|
||||
//
|
||||
// The way this works: decltype(GetVmWorkerProcess) does not work because it's overloaded.
|
||||
// decltype(GetVmWorkerProcess(arg1, ...)) works to select an overload if you have values of the
|
||||
// correct type (std::declval<T>() generates a value of the specified type), however the result
|
||||
// of that is the function's return type, not the function's type, so the argument types must
|
||||
// be repeated to reconstruct the function type.
|
||||
template <typename... Args>
|
||||
using GetVmWorkerProcessType = decltype(GetVmWorkerProcess(std::declval<Args>()...))(Args...);
|
||||
|
||||
// Limit the number of allowed doorbells registered by an external HDV vdev. Currently virtio-9p only uses
|
||||
// one doorbell and wsldevicehost uses only two.
|
||||
#define DEVICE_HOST_PROXY_DOORBELL_LIMIT 8
|
||||
|
||||
using namespace wsl::windows::common::hcs;
|
||||
|
||||
DeviceHostProxy::DeviceHostProxy(const std::wstring& VmId, const GUID& RuntimeId) :
|
||||
m_systemId{VmId}, m_runtimeId{RuntimeId}, m_system{wsl::windows::common::hcs::OpenComputeSystem(VmId.c_str(), GENERIC_ALL)}, m_shutdown{false}
|
||||
{
|
||||
m_devicesShutdown = false;
|
||||
}
|
||||
|
||||
GUID DeviceHostProxy::AddNewDevice(const GUID& Type, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs, const std::wstring& VirtIoTag)
|
||||
{
|
||||
const wrl::ComPtr<IUnknown> thisUnknown{CastToUnknown()};
|
||||
GUID instanceId{};
|
||||
THROW_IF_FAILED(UuidCreate(&instanceId));
|
||||
// Tell the device host to create the device.
|
||||
THROW_IF_FAILED(Plan9Fs->CreateVirtioDevice(m_systemId.c_str(), thisUnknown.Get(), VirtIoTag.c_str(), &instanceId));
|
||||
|
||||
// Add the instance ID to the list of known devices. This must be done before the device is
|
||||
// added to the system, because doing that can cause the register doorbell function to be
|
||||
// called.
|
||||
// N.B. It will be removed if there is a failure.
|
||||
{
|
||||
auto lock = m_devicesLock.lock_exclusive();
|
||||
THROW_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
|
||||
|
||||
m_devices.emplace(instanceId, DeviceHostProxyEntry{});
|
||||
}
|
||||
|
||||
auto removeOnFailure = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() {
|
||||
auto lock = m_devicesLock.lock_exclusive();
|
||||
m_devices.erase(instanceId);
|
||||
});
|
||||
|
||||
// Add the device to the compute system on behalf of the device host.
|
||||
ModifySettingRequest<FlexibleIoDevice> request;
|
||||
request.RequestType = ModifyRequestType::Add;
|
||||
request.ResourcePath = L"VirtualMachine/Devices/FlexibleIov/";
|
||||
request.ResourcePath += wsl::shared::string::GuidToString<wchar_t>(instanceId, wsl::shared::string::GuidToStringFlags::None);
|
||||
request.Settings.EmulatorId = Type;
|
||||
request.Settings.HostingModel = FlexibleIoDeviceHostingModel::ExternalRestricted;
|
||||
wsl::windows::common::hcs::ModifyComputeSystem(m_system.get(), wsl::shared::ToJsonW(request).c_str());
|
||||
removeOnFailure.release();
|
||||
return instanceId;
|
||||
}
|
||||
|
||||
void DeviceHostProxy::AddRemoteFileSystem(const GUID& ImplementationClsid, const std::wstring& Tag, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs)
|
||||
{
|
||||
auto lock = m_lock.lock_exclusive();
|
||||
THROW_HR_IF(E_CHANGED_STATE, m_shutdown);
|
||||
|
||||
// Make sure there are no duplicate tags.
|
||||
for (auto& entry : m_fileSystems)
|
||||
{
|
||||
THROW_HR_IF(E_INVALIDARG, entry.ImplementationClsid == ImplementationClsid && entry.Tag == Tag);
|
||||
}
|
||||
|
||||
m_fileSystems.emplace_back(ImplementationClsid, Tag, Plan9Fs);
|
||||
}
|
||||
|
||||
wil::com_ptr<IPlan9FileSystem> DeviceHostProxy::GetRemoteFileSystem(const GUID& ImplementationClsid, std::wstring_view Tag)
|
||||
{
|
||||
auto lock = m_lock.lock_shared();
|
||||
THROW_HR_IF(E_CHANGED_STATE, m_shutdown);
|
||||
|
||||
for (auto& entry : m_fileSystems)
|
||||
{
|
||||
if (entry.ImplementationClsid == ImplementationClsid && entry.Tag == Tag)
|
||||
{
|
||||
return entry.Instance;
|
||||
}
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
void DeviceHostProxy::Shutdown()
|
||||
{
|
||||
{
|
||||
auto lock = m_lock.lock_exclusive();
|
||||
m_fileSystems.clear();
|
||||
m_shutdown = true;
|
||||
}
|
||||
|
||||
{
|
||||
auto lock = m_devicesLock.lock_exclusive();
|
||||
m_devices.clear();
|
||||
m_devicesShutdown = true;
|
||||
}
|
||||
}
|
||||
|
||||
HRESULT
|
||||
DeviceHostProxy::RegisterDeviceHost(_In_ IVmDeviceHost* DeviceHost, _In_ DWORD ProcessId, _Out_ UINT64* IpcSectionHandle)
|
||||
try
|
||||
{
|
||||
//
|
||||
// Because HdvProxyDeviceHost is not part of the API set, it is loaded here dynamically.
|
||||
//
|
||||
|
||||
static LxssDynamicFunction<decltype(HdvProxyDeviceHost)> proxyDeviceHost{c_hdvModuleName, "HdvProxyDeviceHost"};
|
||||
const wil::com_ptr<IVmDeviceHost> remoteHost = DeviceHost;
|
||||
const wil::com_ptr<IUnknown> unknown = remoteHost.query<IUnknown>();
|
||||
THROW_IF_FAILED(proxyDeviceHost(m_system.get(), unknown.get(), ProcessId, IpcSectionHandle));
|
||||
return S_OK;
|
||||
}
|
||||
CATCH_RETURN()
|
||||
|
||||
HRESULT
|
||||
DeviceHostProxy::NotifyAllDevicesInUse(_In_ LPCWSTR Tag)
|
||||
try
|
||||
{
|
||||
//
|
||||
// Add another Plan9 virtio device to the guest so additional mount commands will be possible.
|
||||
// This callback should be unused by virtiofs devices because a device is created for every
|
||||
// AddSharePath call.
|
||||
//
|
||||
auto p9fs = GetRemoteFileSystem(__uuidof(p9fs::Plan9FileSystem), Tag);
|
||||
THROW_HR_IF(E_NOT_SET, !p9fs);
|
||||
(void)AddNewDevice(VIRTIO_PLAN9_DEVICE_ID, p9fs, Tag);
|
||||
return S_OK;
|
||||
}
|
||||
CATCH_RETURN()
|
||||
|
||||
HRESULT
|
||||
DeviceHostProxy::RegisterDoorbell(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags, HANDLE Event)
|
||||
try
|
||||
{
|
||||
auto lock = m_devicesLock.lock_exclusive();
|
||||
RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
|
||||
|
||||
// Check if the device is one of the known devices that doorbells can be registered for, and
|
||||
// if the device has not already registered a doorbell.
|
||||
// N.B. For security it is enforced that each device can only register a small number of doorbells.
|
||||
// Currently virtio-9p only uses one and the external virtio device uses two.
|
||||
const auto knownDevice = m_devices.find(InstanceId);
|
||||
RETURN_HR_IF(E_ACCESSDENIED, knownDevice == m_devices.end() || knownDevice->second.DoorbellCount == DEVICE_HOST_PROXY_DOORBELL_LIMIT);
|
||||
|
||||
if (!knownDevice->second.MemoryNotification)
|
||||
{
|
||||
// Get an interface to the worker process to query devices.
|
||||
if (!m_deviceAccess)
|
||||
{
|
||||
static LxssDynamicFunction<GetVmWorkerProcessType<REFGUID, REFIID, IUnknown**>> getVmWorker{
|
||||
c_vmwpctrlModuleName, "GetVmWorkerProcess"};
|
||||
|
||||
RETURN_IF_FAILED(getVmWorker(m_runtimeId, __uuidof(*m_deviceAccess), reinterpret_cast<IUnknown**>(&m_deviceAccess)));
|
||||
}
|
||||
|
||||
RETURN_HR_IF(E_NOINTERFACE, !m_deviceAccess);
|
||||
|
||||
// Retrieve the device's memory notification interface to register the doorbell, and store it
|
||||
// to be used during unregistration.
|
||||
wil::com_ptr<IUnknown> device;
|
||||
RETURN_IF_FAILED(m_deviceAccess->GetDevice(FLEXIO_DEVICE_ID, InstanceId, &device));
|
||||
knownDevice->second.MemoryNotification = device.query<IVmFiovGuestMemoryFastNotification>();
|
||||
}
|
||||
|
||||
const auto result = knownDevice->second.MemoryNotification->RegisterDoorbell(
|
||||
static_cast<FIOV_BAR_SELECTOR>(BarIndex), Offset, TriggerValue, Flags, Event);
|
||||
|
||||
if (SUCCEEDED(result))
|
||||
{
|
||||
++knownDevice->second.DoorbellCount;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
CATCH_RETURN()
|
||||
|
||||
HRESULT
|
||||
DeviceHostProxy::UnregisterDoorbell(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags)
|
||||
try
|
||||
{
|
||||
auto lock = m_devicesLock.lock_exclusive();
|
||||
RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
|
||||
|
||||
// Check if the device is a known device and has registered a doorbell.
|
||||
// N.B. If the device is being removed, the device can't be retrieved from the worker process
|
||||
// so it's necessary to use the stored COM pointer.
|
||||
const auto device = m_devices.find(InstanceId);
|
||||
RETURN_HR_IF(E_ACCESSDENIED, device == m_devices.end() || device->second.DoorbellCount == 0);
|
||||
RETURN_IF_FAILED(device->second.MemoryNotification->UnregisterDoorbell(static_cast<FIOV_BAR_SELECTOR>(BarIndex), Offset, TriggerValue, Flags));
|
||||
|
||||
if (--device->second.DoorbellCount == 0)
|
||||
{
|
||||
device->second.MemoryNotification.reset();
|
||||
}
|
||||
|
||||
return S_OK;
|
||||
}
|
||||
CATCH_RETURN()
|
||||
|
||||
HRESULT
|
||||
DeviceHostProxy::CreateSectionBackedMmioRange(
|
||||
const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages, UINT64 PageCount, UINT64 MappingFlags, HANDLE SectionHandle, UINT64 SectionOffsetInPages)
|
||||
try
|
||||
{
|
||||
auto lock = m_devicesLock.lock_exclusive();
|
||||
RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
|
||||
|
||||
// Check if the device is one of the known devices.
|
||||
const auto knownDevice = m_devices.find(InstanceId);
|
||||
THROW_HR_IF(E_ACCESSDENIED, knownDevice == m_devices.end());
|
||||
|
||||
if (!knownDevice->second.MemoryMapping)
|
||||
{
|
||||
// Get an interface to the worker process to query devices.
|
||||
if (!m_deviceAccess)
|
||||
{
|
||||
static LxssDynamicFunction<GetVmWorkerProcessType<REFGUID, REFIID, IUnknown**>> getVmWorker{
|
||||
c_vmwpctrlModuleName, "GetVmWorkerProcess"};
|
||||
THROW_IF_FAILED(getVmWorker(m_runtimeId, __uuidof(*m_deviceAccess), reinterpret_cast<IUnknown**>(&m_deviceAccess)));
|
||||
}
|
||||
|
||||
THROW_HR_IF(E_NOINTERFACE, !m_deviceAccess);
|
||||
|
||||
// Retrieve the device specific interface to manage mapped sections.
|
||||
wil::com_ptr<IUnknown> device;
|
||||
THROW_IF_FAILED(m_deviceAccess->GetDevice(FLEXIO_DEVICE_ID, InstanceId, &device));
|
||||
knownDevice->second.MemoryMapping = device.query<IVmFiovGuestMmioMappings>();
|
||||
}
|
||||
|
||||
THROW_IF_FAILED(knownDevice->second.MemoryMapping->CreateSectionBackedMmioRange(
|
||||
static_cast<FIOV_BAR_SELECTOR>(BarIndex), BarOffsetInPages, PageCount, static_cast<FiovMmioMappingFlags>(MappingFlags), SectionHandle, SectionOffsetInPages));
|
||||
|
||||
return S_OK;
|
||||
}
|
||||
CATCH_RETURN()
|
||||
|
||||
HRESULT
|
||||
DeviceHostProxy::DestroySectionBackedMmioRange(const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages)
|
||||
try
|
||||
{
|
||||
auto lock = m_devicesLock.lock_exclusive();
|
||||
RETURN_HR_IF(E_CHANGED_STATE, m_devicesShutdown);
|
||||
const auto device = m_devices.find(InstanceId);
|
||||
RETURN_HR_IF(E_ACCESSDENIED, device == m_devices.end() || !device->second.MemoryMapping);
|
||||
RETURN_IF_FAILED(device->second.MemoryMapping->DestroySectionBackedMmioRange(static_cast<FIOV_BAR_SELECTOR>(BarIndex), BarOffsetInPages));
|
||||
return S_OK;
|
||||
}
|
||||
CATCH_RETURN()
|
||||
@ -1,76 +1,76 @@
|
||||
// Copyright (C) Microsoft Corporation. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <windowsdefs.h>
|
||||
#include "hcs.hpp"
|
||||
|
||||
namespace wrl = Microsoft::WRL;
|
||||
|
||||
class DeviceHostProxy : public wrl::RuntimeClass<wrl::RuntimeClassFlags<wrl::RuntimeClassType::ClassicCom>, IVmDeviceHostSupport, IPlan9FileSystemHost>
|
||||
{
|
||||
public:
|
||||
DeviceHostProxy(const std::wstring& VmId, const GUID& RuntimeId);
|
||||
|
||||
GUID AddNewDevice(const GUID& Type, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs, const std::wstring& VirtIoTag);
|
||||
|
||||
void AddRemoteFileSystem(const GUID& ImplementationClsid, const std::wstring& Tag, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs);
|
||||
|
||||
wil::com_ptr<IPlan9FileSystem> GetRemoteFileSystem(const GUID& ImplementationClsid, std::wstring_view Tag);
|
||||
|
||||
void Shutdown();
|
||||
|
||||
//
|
||||
// IVmDeviceHostSupport
|
||||
//
|
||||
IFACEMETHOD(RegisterDeviceHost)(_In_ IVmDeviceHost* DeviceHost, _In_ DWORD ProcessId, _Out_ UINT64* IpcSectionHandle) override;
|
||||
|
||||
//
|
||||
// IPlan9FileSystemHost
|
||||
//
|
||||
IFACEMETHOD(NotifyAllDevicesInUse)(_In_ LPCWSTR Tag) override;
|
||||
|
||||
IFACEMETHOD(RegisterDoorbell)(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags, HANDLE Event) override;
|
||||
|
||||
IFACEMETHOD(UnregisterDoorbell)(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags) override;
|
||||
|
||||
IFACEMETHOD(CreateSectionBackedMmioRange)(
|
||||
const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages, UINT64 PageCount, UINT64 MappingFlags, HANDLE SectionHandle, UINT64 SectionOffsetInPages) override;
|
||||
|
||||
IFACEMETHOD(DestroySectionBackedMmioRange)(const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages) override;
|
||||
|
||||
private:
|
||||
struct RemoteFileSystemInfo
|
||||
{
|
||||
RemoteFileSystemInfo(GUID ImplementationClsid, const std::wstring& Tag, const wil::com_ptr<IPlan9FileSystem>& Instance) :
|
||||
ImplementationClsid{ImplementationClsid}, Tag{Tag}, Instance{Instance}
|
||||
{
|
||||
}
|
||||
|
||||
GUID ImplementationClsid;
|
||||
std::wstring Tag;
|
||||
wil::com_ptr<IPlan9FileSystem> Instance;
|
||||
};
|
||||
|
||||
std::wstring m_systemId;
|
||||
GUID m_runtimeId;
|
||||
wsl::windows::common::hcs::unique_hcs_system m_system;
|
||||
wil::srwlock m_lock;
|
||||
std::vector<RemoteFileSystemInfo> m_fileSystems;
|
||||
bool m_shutdown;
|
||||
|
||||
struct DeviceHostProxyEntry
|
||||
{
|
||||
wil::com_ptr<IVmFiovGuestMemoryFastNotification> MemoryNotification;
|
||||
wil::com_ptr<IVmFiovGuestMmioMappings> MemoryMapping;
|
||||
size_t DoorbellCount = 0;
|
||||
};
|
||||
|
||||
wil::com_ptr<IVmVirtualDeviceAccess> m_deviceAccess;
|
||||
wil::srwlock m_devicesLock;
|
||||
std::map<GUID, DeviceHostProxyEntry, wsl::windows::common::helpers::GuidLess> m_devices;
|
||||
bool m_devicesShutdown;
|
||||
|
||||
static constexpr LPCWSTR c_hdvModuleName = L"vmdevicehost.dll";
|
||||
static constexpr LPCWSTR c_vmwpctrlModuleName = L"vmwpctrl.dll";
|
||||
// Copyright (C) Microsoft Corporation. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <windowsdefs.h>
|
||||
#include "hcs.hpp"
|
||||
|
||||
namespace wrl = Microsoft::WRL;
|
||||
|
||||
class DeviceHostProxy : public wrl::RuntimeClass<wrl::RuntimeClassFlags<wrl::RuntimeClassType::ClassicCom>, IVmDeviceHostSupport, IPlan9FileSystemHost>
|
||||
{
|
||||
public:
|
||||
DeviceHostProxy(const std::wstring& VmId, const GUID& RuntimeId);
|
||||
|
||||
GUID AddNewDevice(const GUID& Type, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs, const std::wstring& VirtIoTag);
|
||||
|
||||
void AddRemoteFileSystem(const GUID& ImplementationClsid, const std::wstring& Tag, const wil::com_ptr<IPlan9FileSystem>& Plan9Fs);
|
||||
|
||||
wil::com_ptr<IPlan9FileSystem> GetRemoteFileSystem(const GUID& ImplementationClsid, std::wstring_view Tag);
|
||||
|
||||
void Shutdown();
|
||||
|
||||
//
|
||||
// IVmDeviceHostSupport
|
||||
//
|
||||
IFACEMETHOD(RegisterDeviceHost)(_In_ IVmDeviceHost* DeviceHost, _In_ DWORD ProcessId, _Out_ UINT64* IpcSectionHandle) override;
|
||||
|
||||
//
|
||||
// IPlan9FileSystemHost
|
||||
//
|
||||
IFACEMETHOD(NotifyAllDevicesInUse)(_In_ LPCWSTR Tag) override;
|
||||
|
||||
IFACEMETHOD(RegisterDoorbell)(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags, HANDLE Event) override;
|
||||
|
||||
IFACEMETHOD(UnregisterDoorbell)(const GUID& InstanceId, UINT8 BarIndex, UINT64 Offset, UINT64 TriggerValue, UINT64 Flags) override;
|
||||
|
||||
IFACEMETHOD(CreateSectionBackedMmioRange)(
|
||||
const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages, UINT64 PageCount, UINT64 MappingFlags, HANDLE SectionHandle, UINT64 SectionOffsetInPages) override;
|
||||
|
||||
IFACEMETHOD(DestroySectionBackedMmioRange)(const GUID& InstanceId, UINT8 BarIndex, UINT64 BarOffsetInPages) override;
|
||||
|
||||
private:
|
||||
struct RemoteFileSystemInfo
|
||||
{
|
||||
RemoteFileSystemInfo(GUID ImplementationClsid, const std::wstring& Tag, const wil::com_ptr<IPlan9FileSystem>& Instance) :
|
||||
ImplementationClsid{ImplementationClsid}, Tag{Tag}, Instance{Instance}
|
||||
{
|
||||
}
|
||||
|
||||
GUID ImplementationClsid;
|
||||
std::wstring Tag;
|
||||
wil::com_ptr<IPlan9FileSystem> Instance;
|
||||
};
|
||||
|
||||
std::wstring m_systemId;
|
||||
GUID m_runtimeId;
|
||||
wsl::windows::common::hcs::unique_hcs_system m_system;
|
||||
wil::srwlock m_lock;
|
||||
std::vector<RemoteFileSystemInfo> m_fileSystems;
|
||||
bool m_shutdown;
|
||||
|
||||
struct DeviceHostProxyEntry
|
||||
{
|
||||
wil::com_ptr<IVmFiovGuestMemoryFastNotification> MemoryNotification;
|
||||
wil::com_ptr<IVmFiovGuestMmioMappings> MemoryMapping;
|
||||
size_t DoorbellCount = 0;
|
||||
};
|
||||
|
||||
wil::com_ptr<IVmVirtualDeviceAccess> m_deviceAccess;
|
||||
wil::srwlock m_devicesLock;
|
||||
std::map<GUID, DeviceHostProxyEntry, wsl::windows::common::helpers::GuidLess> m_devices;
|
||||
bool m_devicesShutdown;
|
||||
|
||||
static constexpr LPCWSTR c_hdvModuleName = L"vmdevicehost.dll";
|
||||
static constexpr LPCWSTR c_vmwpctrlModuleName = L"vmwpctrl.dll";
|
||||
};
|
||||
@ -1,417 +1,417 @@
|
||||
// Copyright (C) Microsoft Corporation. All rights reserved.
|
||||
|
||||
#include <LxssDynamicFunction.h>
|
||||
#include "precomp.h"
|
||||
#include "DnsResolver.h"
|
||||
|
||||
using wsl::core::networking::DnsResolver;
|
||||
|
||||
static constexpr auto c_dnsModuleName = L"dnsapi.dll";
|
||||
|
||||
std::optional<LxssDynamicFunction<decltype(DnsQueryRaw)>> DnsResolver::s_dnsQueryRaw;
|
||||
std::optional<LxssDynamicFunction<decltype(DnsCancelQueryRaw)>> DnsResolver::s_dnsCancelQueryRaw;
|
||||
std::optional<LxssDynamicFunction<decltype(DnsQueryRawResultFree)>> DnsResolver::s_dnsQueryRawResultFree;
|
||||
|
||||
HRESULT DnsResolver::LoadDnsResolverMethods() noexcept
|
||||
{
|
||||
static wil::shared_hmodule dnsModule;
|
||||
static DWORD loadError = ERROR_SUCCESS;
|
||||
static std::once_flag dnsLoadFlag;
|
||||
|
||||
// Load DNS dll only once
|
||||
std::call_once(dnsLoadFlag, [&]() {
|
||||
dnsModule.reset(LoadLibraryEx(c_dnsModuleName, nullptr, LOAD_LIBRARY_SEARCH_SYSTEM32));
|
||||
if (!dnsModule)
|
||||
{
|
||||
loadError = GetLastError();
|
||||
}
|
||||
});
|
||||
|
||||
RETURN_IF_WIN32_ERROR_MSG(loadError, "LoadLibraryEx %ls", c_dnsModuleName);
|
||||
|
||||
// Initialize dynamic functions for the DNS tunneling Windows APIs.
|
||||
// using the non-throwing instance of LxssDynamicFunction as to not end up in the Error telemetry
|
||||
LxssDynamicFunction<decltype(DnsQueryRaw)> local_dnsQueryRaw{DynamicFunctionErrorLogs::None};
|
||||
RETURN_IF_FAILED_EXPECTED(local_dnsQueryRaw.load(dnsModule, "DnsQueryRaw"));
|
||||
LxssDynamicFunction<decltype(DnsCancelQueryRaw)> local_dnsCancelQueryRaw{DynamicFunctionErrorLogs::None};
|
||||
RETURN_IF_FAILED_EXPECTED(local_dnsCancelQueryRaw.load(dnsModule, "DnsCancelQueryRaw"));
|
||||
LxssDynamicFunction<decltype(DnsQueryRawResultFree)> local_dnsQueryRawResultFree{DynamicFunctionErrorLogs::None};
|
||||
RETURN_IF_FAILED_EXPECTED(local_dnsQueryRawResultFree.load(dnsModule, "DnsQueryRawResultFree"));
|
||||
|
||||
// Make a dummy call to the DNS APIs to verify if they are working. The APIs are going to be present
|
||||
// on older Windows versions, where they can be turned on/off. If turned off, the APIs
|
||||
// will be unusable and will return ERROR_CALL_NOT_IMPLEMENTED.
|
||||
if (local_dnsQueryRaw(nullptr, nullptr) == ERROR_CALL_NOT_IMPLEMENTED)
|
||||
{
|
||||
RETURN_IF_WIN32_ERROR_EXPECTED(ERROR_CALL_NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
s_dnsQueryRaw.emplace(std::move(local_dnsQueryRaw));
|
||||
s_dnsCancelQueryRaw.emplace(std::move(local_dnsCancelQueryRaw));
|
||||
s_dnsQueryRawResultFree.emplace(std::move(local_dnsQueryRawResultFree));
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
DnsResolver::DnsResolver(wil::unique_socket&& dnsHvsocket, DnsResolverFlags flags) :
|
||||
m_dnsChannel(
|
||||
std::move(dnsHvsocket),
|
||||
[this](const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) {
|
||||
ProcessDnsRequest(dnsBuffer, dnsClientIdentifier);
|
||||
}),
|
||||
m_flags(flags)
|
||||
{
|
||||
// Initialize as signaled, as there are no requests yet
|
||||
m_allRequestsFinished.SetEvent();
|
||||
|
||||
// Read external interface constraint regkey
|
||||
const auto lxssKey = windows::common::registry::OpenLxssMachineKey(KEY_READ);
|
||||
m_externalInterfaceConstraintName =
|
||||
windows::common::registry::ReadString(lxssKey.get(), nullptr, c_interfaceConstraintKey, L"");
|
||||
|
||||
if (!m_externalInterfaceConstraintName.empty())
|
||||
{
|
||||
ResolveExternalInterfaceConstraintIndex();
|
||||
|
||||
WSL_LOG(
|
||||
"DnsResolver::DnsResolver",
|
||||
TraceLoggingValue(m_externalInterfaceConstraintName.c_str(), "m_externalInterfaceConstraintName"),
|
||||
TraceLoggingValue(m_externalInterfaceConstraintIndex, "m_externalInterfaceConstraintIndex"));
|
||||
|
||||
// Register for interface change notifications. Notifications are used to determine if the external interface constraint setting is applicable.
|
||||
THROW_IF_WIN32_ERROR(NotifyIpInterfaceChange(AF_UNSPEC, &DnsResolver::InterfaceChangeCallback, this, FALSE, &m_interfaceNotificationHandle));
|
||||
}
|
||||
}
|
||||
|
||||
DnsResolver::~DnsResolver() noexcept
|
||||
{
|
||||
Stop();
|
||||
}
|
||||
|
||||
void DnsResolver::GenerateTelemetry() noexcept
|
||||
try
|
||||
{
|
||||
// Find the 3 most common DNS API failures
|
||||
uint32_t mostCommonDnsStatusError = 0;
|
||||
uint32_t mostCommonDnsStatusErrorCount = 0;
|
||||
uint32_t secondCommonDnsStatusError = 0;
|
||||
uint32_t secondCommonDnsStatusErrorCount = 0;
|
||||
uint32_t thirdCommonDnsStatusError = 0;
|
||||
uint32_t thirdCommonDnsStatusErrorCount = 0;
|
||||
|
||||
std::vector<std::pair<uint32_t, uint32_t>> failures(m_dnsApiFailures.size());
|
||||
std::copy(m_dnsApiFailures.begin(), m_dnsApiFailures.end(), failures.begin());
|
||||
|
||||
// Sort in descending order based on failure count
|
||||
std::sort(failures.begin(), failures.end(), [](const auto& lhs, const auto& rhs) { return lhs.second > rhs.second; });
|
||||
|
||||
if (failures.size() >= 1)
|
||||
{
|
||||
mostCommonDnsStatusError = failures[0].first;
|
||||
mostCommonDnsStatusErrorCount = failures[0].second;
|
||||
}
|
||||
if (failures.size() >= 2)
|
||||
{
|
||||
secondCommonDnsStatusError = failures[1].first;
|
||||
secondCommonDnsStatusErrorCount = failures[1].second;
|
||||
}
|
||||
if (failures.size() >= 3)
|
||||
{
|
||||
thirdCommonDnsStatusError = failures[2].first;
|
||||
thirdCommonDnsStatusErrorCount = failures[2].second;
|
||||
}
|
||||
|
||||
// Add telemetry with DNS tunneling statistics, before shutting down
|
||||
WSL_LOG(
|
||||
"DnsTunnelingStatistics",
|
||||
TraceLoggingValue(m_totalUdpQueries.load(), "totalUdpQueries"),
|
||||
TraceLoggingValue(m_successfulUdpQueries.load(), "successfulUdpQueries"),
|
||||
TraceLoggingValue(m_totalTcpQueries.load(), "totalTcpQueries"),
|
||||
TraceLoggingValue(m_successfulTcpQueries.load(), "successfulTcpQueries"),
|
||||
TraceLoggingValue(m_queriesWithNullResult.load(), "queriesWithNullResult"),
|
||||
TraceLoggingValue(m_failedDnsQueryRawCalls.load(), "FailedDnsQueryRawCalls"),
|
||||
TraceLoggingValue(m_dnsApiFailures.size(), "totalDnsStatusErrorInstances"),
|
||||
TraceLoggingValue(mostCommonDnsStatusError, "mostCommonDnsStatusError"),
|
||||
TraceLoggingValue(mostCommonDnsStatusErrorCount, "mostCommonDnsStatusErrorCount"),
|
||||
TraceLoggingValue(secondCommonDnsStatusError, "secondCommonDnsStatusError"),
|
||||
TraceLoggingValue(secondCommonDnsStatusErrorCount, "secondCommonDnsStatusErrorCount"),
|
||||
TraceLoggingValue(thirdCommonDnsStatusError, "thirdCommonDnsStatusError"),
|
||||
TraceLoggingValue(thirdCommonDnsStatusErrorCount, "thirdCommonDnsStatusErrorCount"));
|
||||
}
|
||||
CATCH_LOG()
|
||||
|
||||
void DnsResolver::Stop() noexcept
|
||||
try
|
||||
{
|
||||
WSL_LOG("DnsResolver::Stop");
|
||||
|
||||
// Scoped m_dnsLock
|
||||
{
|
||||
const std::lock_guard lock(m_dnsLock);
|
||||
|
||||
m_stopped = true;
|
||||
|
||||
// Cancel existing requests. Cancel is complete when DnsQueryRawCallback is
|
||||
// invoked with status == ERROR_CANCELLED
|
||||
// N.B. Cancelling can end up calling the DnsQueryRawCallback directly on this same thread. i.e., while this
|
||||
// lock is held. Which is fine because m_dnsLock is a recursive mutex.
|
||||
// N.B. Cancelling a query will synchronously remove the query from m_dnsRequests, which invalidates iterators.
|
||||
|
||||
std::vector<DNS_QUERY_RAW_CANCEL*> cancelHandles;
|
||||
cancelHandles.reserve(m_dnsRequests.size());
|
||||
|
||||
for (auto& [_, context] : m_dnsRequests)
|
||||
{
|
||||
cancelHandles.emplace_back(&context->m_cancelHandle);
|
||||
}
|
||||
|
||||
for (const auto e : cancelHandles)
|
||||
{
|
||||
LOG_IF_WIN32_ERROR(s_dnsCancelQueryRaw.value()(e));
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for all requests to complete. At this point no new requests can be started since the object is stopped.
|
||||
// We are only waiting for existing requests to finish.
|
||||
m_allRequestsFinished.wait();
|
||||
|
||||
// Stop the response queue first as it can make calls in m_dnsChannel
|
||||
m_dnsResponseQueue.cancel();
|
||||
|
||||
m_dnsChannel.Stop();
|
||||
|
||||
// Stop interface change notifications
|
||||
m_interfaceNotificationHandle.reset();
|
||||
|
||||
GenerateTelemetry();
|
||||
}
|
||||
CATCH_LOG()
|
||||
|
||||
void DnsResolver::ProcessDnsRequest(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept
|
||||
try
|
||||
{
|
||||
const std::lock_guard lock(m_dnsLock);
|
||||
if (m_stopped)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
WSL_LOG_DEBUG(
|
||||
"DnsResolver::ProcessDnsRequest - received new DNS request",
|
||||
TraceLoggingValue(dnsBuffer.size(), "DNS buffer size"),
|
||||
TraceLoggingValue(dnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"),
|
||||
TraceLoggingValue(dnsClientIdentifier.DnsClientId, "DNS client id"),
|
||||
TraceLoggingValue(!m_externalInterfaceConstraintName.empty(), "Is ExternalInterfaceConstraint configured"),
|
||||
TraceLoggingValue(m_externalInterfaceConstraintIndex, "m_externalInterfaceConstraintIndex"));
|
||||
|
||||
// If the external interface constraint is configured but it is *not* present/up, WSL should be net-blind, so we avoid making DNS requests.
|
||||
if (!m_externalInterfaceConstraintName.empty() && m_externalInterfaceConstraintIndex == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
dnsClientIdentifier.Protocol == IPPROTO_UDP ? m_totalUdpQueries++ : m_totalTcpQueries++;
|
||||
|
||||
// Get next request id. If value reaches UINT_MAX + 1 it will be automatically reset to 0
|
||||
const auto requestId = m_currentRequestId++;
|
||||
|
||||
// Create the DNS request context
|
||||
auto context = std::make_unique<DnsResolver::DnsQueryContext>(
|
||||
requestId, dnsClientIdentifier, [this](_Inout_ DnsResolver::DnsQueryContext* context, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) {
|
||||
HandleDnsQueryCompletion(context, queryResults);
|
||||
});
|
||||
|
||||
auto [it, _] = m_dnsRequests.emplace(requestId, std::move(context));
|
||||
const auto localContext = it->second.get();
|
||||
|
||||
auto removeContextOnError = wil::scope_exit([&] { WI_VERIFY(m_dnsRequests.erase(requestId) == 1); });
|
||||
|
||||
// Fill DNS request structure
|
||||
DNS_QUERY_RAW_REQUEST request{};
|
||||
|
||||
request.version = DNS_QUERY_RAW_REQUEST_VERSION1;
|
||||
request.resultsVersion = DNS_QUERY_RAW_RESULTS_VERSION1;
|
||||
request.dnsQueryRawSize = static_cast<ULONG>(dnsBuffer.size());
|
||||
request.dnsQueryRaw = (PBYTE)dnsBuffer.data();
|
||||
request.protocol = (dnsClientIdentifier.Protocol == IPPROTO_TCP) ? DNS_PROTOCOL_TCP : DNS_PROTOCOL_UDP;
|
||||
request.queryCompletionCallback = DnsResolver::DnsQueryRawCallback;
|
||||
request.queryContext = localContext;
|
||||
// Only unicast UDP & TCP queries are tunneled. Pass this flag to tell Windows DNS client to *not* resolve using multicast.
|
||||
request.queryOptions |= DNS_QUERY_NO_MULTICAST;
|
||||
|
||||
// In a DNS request from Linux there might be DNS records that Windows DNS client does not know how to parse.
|
||||
// By default in this case Windows will fail the request. When the flag is enabled, Windows will extract the
|
||||
// question from the DNS request and attempt to resolve it, ignoring the unknown records.
|
||||
if (WI_IsFlagSet(m_flags, DnsResolverFlags::BestEffortDnsParsing))
|
||||
{
|
||||
request.queryRawOptions |= DNS_QUERY_RAW_OPTION_BEST_EFFORT_PARSE;
|
||||
}
|
||||
|
||||
// If the external interface constraint is configured and present on the host, only send DNS requests on that interface.
|
||||
if (m_externalInterfaceConstraintIndex != 0)
|
||||
{
|
||||
request.interfaceIndex = m_externalInterfaceConstraintIndex;
|
||||
}
|
||||
|
||||
// Start the DNS request
|
||||
// N.B. All DNS requests will bypass the Windows DNS cache
|
||||
const auto result = s_dnsQueryRaw.value()(&request, &localContext->m_cancelHandle);
|
||||
if (result != DNS_REQUEST_PENDING)
|
||||
{
|
||||
m_failedDnsQueryRawCalls++;
|
||||
|
||||
WSL_LOG(
|
||||
"ProcessDnsRequestFailed",
|
||||
TraceLoggingValue(requestId, "requestId"),
|
||||
TraceLoggingValue(result, "result"),
|
||||
TraceLoggingValue("DnsQueryRaw", "executionStep"));
|
||||
return;
|
||||
}
|
||||
|
||||
removeContextOnError.release();
|
||||
|
||||
m_allRequestsFinished.ResetEvent();
|
||||
}
|
||||
CATCH_LOG()
|
||||
|
||||
void DnsResolver::HandleDnsQueryCompletion(_Inout_ DnsResolver::DnsQueryContext* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept
|
||||
try
|
||||
{
|
||||
// Always free the query result structure
|
||||
const auto freeQueryResults = wil::scope_exit([&] {
|
||||
if (queryResults != nullptr)
|
||||
{
|
||||
s_dnsQueryRawResultFree.value()(queryResults);
|
||||
}
|
||||
});
|
||||
|
||||
const std::lock_guard lock(m_dnsLock);
|
||||
|
||||
if (queryResults != nullptr)
|
||||
{
|
||||
WSL_LOG(
|
||||
"DnsResolver::HandleDnsQueryCompletion",
|
||||
TraceLoggingValue(queryContext->m_id, "queryContext->m_id"),
|
||||
TraceLoggingValue(queryResults->queryStatus, "queryResults->queryStatus"),
|
||||
TraceLoggingValue(queryResults->queryRawResponse != nullptr, "validResponse"));
|
||||
|
||||
// Note: The response may be valid even if queryResults->queryStatus is not 0, for example when the DNS server returns a negative response.
|
||||
if (queryResults->queryRawResponse != nullptr)
|
||||
{
|
||||
queryContext->m_dnsClientIdentifier.Protocol == IPPROTO_UDP ? m_successfulUdpQueries++ : m_successfulTcpQueries++;
|
||||
}
|
||||
// the Windows DNS API returned failure
|
||||
else
|
||||
{
|
||||
if (m_dnsApiFailures.find(queryResults->queryStatus) == m_dnsApiFailures.end())
|
||||
{
|
||||
m_dnsApiFailures[queryResults->queryStatus] = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
m_dnsApiFailures[queryResults->queryStatus]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
WSL_LOG(
|
||||
"DnsResolver::HandleDnsQueryCompletion - received a NULL queryResults",
|
||||
TraceLoggingValue(queryContext->m_id, "queryContext->m_id"));
|
||||
m_queriesWithNullResult++;
|
||||
}
|
||||
|
||||
if (!m_stopped && queryResults != nullptr && queryResults->queryRawResponse != nullptr)
|
||||
{
|
||||
// Copy DNS response buffer
|
||||
std::vector<gsl::byte> dnsResponse(queryResults->queryRawResponseSize);
|
||||
CopyMemory(dnsResponse.data(), queryResults->queryRawResponse, queryResults->queryRawResponseSize);
|
||||
|
||||
WSL_LOG_DEBUG(
|
||||
"DnsResolver::HandleDnsQueryCompletion - received new DNS response",
|
||||
TraceLoggingValue(dnsResponse.size(), "DNS buffer size"),
|
||||
TraceLoggingValue(queryContext->m_dnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"),
|
||||
TraceLoggingValue(queryContext->m_dnsClientIdentifier.DnsClientId, "DNS client id"));
|
||||
|
||||
// Schedule the DNS response to be sent to Linux
|
||||
m_dnsResponseQueue.submit([this, dnsResponse = std::move(dnsResponse), dnsClientIdentifier = queryContext->m_dnsClientIdentifier]() mutable {
|
||||
m_dnsChannel.SendDnsMessage(gsl::make_span(dnsResponse), dnsClientIdentifier);
|
||||
});
|
||||
}
|
||||
|
||||
// Stop tracking this DNS request and delete the request context
|
||||
WI_VERIFY(m_dnsRequests.erase(queryContext->m_id) == 1);
|
||||
|
||||
// Set event if all tracked requests have finished
|
||||
if (m_dnsRequests.empty())
|
||||
{
|
||||
m_allRequestsFinished.SetEvent();
|
||||
}
|
||||
}
|
||||
CATCH_LOG()
|
||||
|
||||
void DnsResolver::ResolveExternalInterfaceConstraintIndex() noexcept
|
||||
try
|
||||
{
|
||||
const std::lock_guard lock(m_dnsLock);
|
||||
if (m_stopped)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if (m_externalInterfaceConstraintName.empty())
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
NET_LUID interfaceLuid{};
|
||||
ULONG interfaceIndex = 0;
|
||||
|
||||
// Update the interface index on every exit path.
|
||||
// The calls below to convert interface name to index will fail if the interface does not exist anymore,
|
||||
// in which case we still need to reset the interface index to its default value of 0.
|
||||
const auto setInterfaceIndex = wil::scope_exit([&] {
|
||||
if (interfaceIndex != m_externalInterfaceConstraintIndex)
|
||||
{
|
||||
WSL_LOG(
|
||||
"DnsResolver::ResolveExternalInterfaceConstraintIndex - setting m_externalInterfaceConstraintIndex to new value",
|
||||
TraceLoggingValue(m_externalInterfaceConstraintIndex, "old interface index"),
|
||||
TraceLoggingValue(interfaceIndex, "new interface index"));
|
||||
|
||||
m_externalInterfaceConstraintIndex = interfaceIndex;
|
||||
}
|
||||
});
|
||||
|
||||
// If external interface constraint is configured, query to see if it's present on the host.
|
||||
auto errorCode = ConvertInterfaceAliasToLuid(m_externalInterfaceConstraintName.c_str(), &interfaceLuid);
|
||||
if (FAILED_WIN32_LOG(errorCode))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
errorCode = ConvertInterfaceLuidToIndex(&interfaceLuid, reinterpret_cast<PNET_IFINDEX>(&interfaceIndex));
|
||||
if (FAILED_WIN32_LOG(errorCode))
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
CATCH_LOG()
|
||||
|
||||
VOID CALLBACK DnsResolver::DnsQueryRawCallback(_In_ VOID* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept
|
||||
try
|
||||
{
|
||||
assert(queryContext != nullptr);
|
||||
|
||||
const auto context = static_cast<DnsQueryContext*>(queryContext);
|
||||
|
||||
// Call into DnsResolver parent object to process the query result
|
||||
context->m_handleQueryCompletion(context, queryResults);
|
||||
}
|
||||
CATCH_LOG()
|
||||
|
||||
VOID CALLBACK DnsResolver::InterfaceChangeCallback(_In_ PVOID context, PMIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE) noexcept
|
||||
try
|
||||
{
|
||||
const auto dnsResolver = static_cast<DnsResolver*>(context);
|
||||
dnsResolver->ResolveExternalInterfaceConstraintIndex();
|
||||
}
|
||||
CATCH_LOG()
|
||||
// Copyright (C) Microsoft Corporation. All rights reserved.
|
||||
|
||||
#include <LxssDynamicFunction.h>
|
||||
#include "precomp.h"
|
||||
#include "DnsResolver.h"
|
||||
|
||||
using wsl::core::networking::DnsResolver;
|
||||
|
||||
static constexpr auto c_dnsModuleName = L"dnsapi.dll";
|
||||
|
||||
std::optional<LxssDynamicFunction<decltype(DnsQueryRaw)>> DnsResolver::s_dnsQueryRaw;
|
||||
std::optional<LxssDynamicFunction<decltype(DnsCancelQueryRaw)>> DnsResolver::s_dnsCancelQueryRaw;
|
||||
std::optional<LxssDynamicFunction<decltype(DnsQueryRawResultFree)>> DnsResolver::s_dnsQueryRawResultFree;
|
||||
|
||||
HRESULT DnsResolver::LoadDnsResolverMethods() noexcept
|
||||
{
|
||||
static wil::shared_hmodule dnsModule;
|
||||
static DWORD loadError = ERROR_SUCCESS;
|
||||
static std::once_flag dnsLoadFlag;
|
||||
|
||||
// Load DNS dll only once
|
||||
std::call_once(dnsLoadFlag, [&]() {
|
||||
dnsModule.reset(LoadLibraryEx(c_dnsModuleName, nullptr, LOAD_LIBRARY_SEARCH_SYSTEM32));
|
||||
if (!dnsModule)
|
||||
{
|
||||
loadError = GetLastError();
|
||||
}
|
||||
});
|
||||
|
||||
RETURN_IF_WIN32_ERROR_MSG(loadError, "LoadLibraryEx %ls", c_dnsModuleName);
|
||||
|
||||
// Initialize dynamic functions for the DNS tunneling Windows APIs.
|
||||
// using the non-throwing instance of LxssDynamicFunction as to not end up in the Error telemetry
|
||||
LxssDynamicFunction<decltype(DnsQueryRaw)> local_dnsQueryRaw{DynamicFunctionErrorLogs::None};
|
||||
RETURN_IF_FAILED_EXPECTED(local_dnsQueryRaw.load(dnsModule, "DnsQueryRaw"));
|
||||
LxssDynamicFunction<decltype(DnsCancelQueryRaw)> local_dnsCancelQueryRaw{DynamicFunctionErrorLogs::None};
|
||||
RETURN_IF_FAILED_EXPECTED(local_dnsCancelQueryRaw.load(dnsModule, "DnsCancelQueryRaw"));
|
||||
LxssDynamicFunction<decltype(DnsQueryRawResultFree)> local_dnsQueryRawResultFree{DynamicFunctionErrorLogs::None};
|
||||
RETURN_IF_FAILED_EXPECTED(local_dnsQueryRawResultFree.load(dnsModule, "DnsQueryRawResultFree"));
|
||||
|
||||
// Make a dummy call to the DNS APIs to verify if they are working. The APIs are going to be present
|
||||
// on older Windows versions, where they can be turned on/off. If turned off, the APIs
|
||||
// will be unusable and will return ERROR_CALL_NOT_IMPLEMENTED.
|
||||
if (local_dnsQueryRaw(nullptr, nullptr) == ERROR_CALL_NOT_IMPLEMENTED)
|
||||
{
|
||||
RETURN_IF_WIN32_ERROR_EXPECTED(ERROR_CALL_NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
s_dnsQueryRaw.emplace(std::move(local_dnsQueryRaw));
|
||||
s_dnsCancelQueryRaw.emplace(std::move(local_dnsCancelQueryRaw));
|
||||
s_dnsQueryRawResultFree.emplace(std::move(local_dnsQueryRawResultFree));
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
DnsResolver::DnsResolver(wil::unique_socket&& dnsHvsocket, DnsResolverFlags flags) :
|
||||
m_dnsChannel(
|
||||
std::move(dnsHvsocket),
|
||||
[this](const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) {
|
||||
ProcessDnsRequest(dnsBuffer, dnsClientIdentifier);
|
||||
}),
|
||||
m_flags(flags)
|
||||
{
|
||||
// Initialize as signaled, as there are no requests yet
|
||||
m_allRequestsFinished.SetEvent();
|
||||
|
||||
// Read external interface constraint regkey
|
||||
const auto lxssKey = windows::common::registry::OpenLxssMachineKey(KEY_READ);
|
||||
m_externalInterfaceConstraintName =
|
||||
windows::common::registry::ReadString(lxssKey.get(), nullptr, c_interfaceConstraintKey, L"");
|
||||
|
||||
if (!m_externalInterfaceConstraintName.empty())
|
||||
{
|
||||
ResolveExternalInterfaceConstraintIndex();
|
||||
|
||||
WSL_LOG(
|
||||
"DnsResolver::DnsResolver",
|
||||
TraceLoggingValue(m_externalInterfaceConstraintName.c_str(), "m_externalInterfaceConstraintName"),
|
||||
TraceLoggingValue(m_externalInterfaceConstraintIndex, "m_externalInterfaceConstraintIndex"));
|
||||
|
||||
// Register for interface change notifications. Notifications are used to determine if the external interface constraint setting is applicable.
|
||||
THROW_IF_WIN32_ERROR(NotifyIpInterfaceChange(AF_UNSPEC, &DnsResolver::InterfaceChangeCallback, this, FALSE, &m_interfaceNotificationHandle));
|
||||
}
|
||||
}
|
||||
|
||||
DnsResolver::~DnsResolver() noexcept
|
||||
{
|
||||
Stop();
|
||||
}
|
||||
|
||||
void DnsResolver::GenerateTelemetry() noexcept
|
||||
try
|
||||
{
|
||||
// Find the 3 most common DNS API failures
|
||||
uint32_t mostCommonDnsStatusError = 0;
|
||||
uint32_t mostCommonDnsStatusErrorCount = 0;
|
||||
uint32_t secondCommonDnsStatusError = 0;
|
||||
uint32_t secondCommonDnsStatusErrorCount = 0;
|
||||
uint32_t thirdCommonDnsStatusError = 0;
|
||||
uint32_t thirdCommonDnsStatusErrorCount = 0;
|
||||
|
||||
std::vector<std::pair<uint32_t, uint32_t>> failures(m_dnsApiFailures.size());
|
||||
std::copy(m_dnsApiFailures.begin(), m_dnsApiFailures.end(), failures.begin());
|
||||
|
||||
// Sort in descending order based on failure count
|
||||
std::sort(failures.begin(), failures.end(), [](const auto& lhs, const auto& rhs) { return lhs.second > rhs.second; });
|
||||
|
||||
if (failures.size() >= 1)
|
||||
{
|
||||
mostCommonDnsStatusError = failures[0].first;
|
||||
mostCommonDnsStatusErrorCount = failures[0].second;
|
||||
}
|
||||
if (failures.size() >= 2)
|
||||
{
|
||||
secondCommonDnsStatusError = failures[1].first;
|
||||
secondCommonDnsStatusErrorCount = failures[1].second;
|
||||
}
|
||||
if (failures.size() >= 3)
|
||||
{
|
||||
thirdCommonDnsStatusError = failures[2].first;
|
||||
thirdCommonDnsStatusErrorCount = failures[2].second;
|
||||
}
|
||||
|
||||
// Add telemetry with DNS tunneling statistics, before shutting down
|
||||
WSL_LOG(
|
||||
"DnsTunnelingStatistics",
|
||||
TraceLoggingValue(m_totalUdpQueries.load(), "totalUdpQueries"),
|
||||
TraceLoggingValue(m_successfulUdpQueries.load(), "successfulUdpQueries"),
|
||||
TraceLoggingValue(m_totalTcpQueries.load(), "totalTcpQueries"),
|
||||
TraceLoggingValue(m_successfulTcpQueries.load(), "successfulTcpQueries"),
|
||||
TraceLoggingValue(m_queriesWithNullResult.load(), "queriesWithNullResult"),
|
||||
TraceLoggingValue(m_failedDnsQueryRawCalls.load(), "FailedDnsQueryRawCalls"),
|
||||
TraceLoggingValue(m_dnsApiFailures.size(), "totalDnsStatusErrorInstances"),
|
||||
TraceLoggingValue(mostCommonDnsStatusError, "mostCommonDnsStatusError"),
|
||||
TraceLoggingValue(mostCommonDnsStatusErrorCount, "mostCommonDnsStatusErrorCount"),
|
||||
TraceLoggingValue(secondCommonDnsStatusError, "secondCommonDnsStatusError"),
|
||||
TraceLoggingValue(secondCommonDnsStatusErrorCount, "secondCommonDnsStatusErrorCount"),
|
||||
TraceLoggingValue(thirdCommonDnsStatusError, "thirdCommonDnsStatusError"),
|
||||
TraceLoggingValue(thirdCommonDnsStatusErrorCount, "thirdCommonDnsStatusErrorCount"));
|
||||
}
|
||||
CATCH_LOG()
|
||||
|
||||
void DnsResolver::Stop() noexcept
|
||||
try
|
||||
{
|
||||
WSL_LOG("DnsResolver::Stop");
|
||||
|
||||
// Scoped m_dnsLock
|
||||
{
|
||||
const std::lock_guard lock(m_dnsLock);
|
||||
|
||||
m_stopped = true;
|
||||
|
||||
// Cancel existing requests. Cancel is complete when DnsQueryRawCallback is
|
||||
// invoked with status == ERROR_CANCELLED
|
||||
// N.B. Cancelling can end up calling the DnsQueryRawCallback directly on this same thread. i.e., while this
|
||||
// lock is held. Which is fine because m_dnsLock is a recursive mutex.
|
||||
// N.B. Cancelling a query will synchronously remove the query from m_dnsRequests, which invalidates iterators.
|
||||
|
||||
std::vector<DNS_QUERY_RAW_CANCEL*> cancelHandles;
|
||||
cancelHandles.reserve(m_dnsRequests.size());
|
||||
|
||||
for (auto& [_, context] : m_dnsRequests)
|
||||
{
|
||||
cancelHandles.emplace_back(&context->m_cancelHandle);
|
||||
}
|
||||
|
||||
for (const auto e : cancelHandles)
|
||||
{
|
||||
LOG_IF_WIN32_ERROR(s_dnsCancelQueryRaw.value()(e));
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for all requests to complete. At this point no new requests can be started since the object is stopped.
|
||||
// We are only waiting for existing requests to finish.
|
||||
m_allRequestsFinished.wait();
|
||||
|
||||
// Stop the response queue first as it can make calls in m_dnsChannel
|
||||
m_dnsResponseQueue.cancel();
|
||||
|
||||
m_dnsChannel.Stop();
|
||||
|
||||
// Stop interface change notifications
|
||||
m_interfaceNotificationHandle.reset();
|
||||
|
||||
GenerateTelemetry();
|
||||
}
|
||||
CATCH_LOG()
|
||||
|
||||
void DnsResolver::ProcessDnsRequest(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept
|
||||
try
|
||||
{
|
||||
const std::lock_guard lock(m_dnsLock);
|
||||
if (m_stopped)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
WSL_LOG_DEBUG(
|
||||
"DnsResolver::ProcessDnsRequest - received new DNS request",
|
||||
TraceLoggingValue(dnsBuffer.size(), "DNS buffer size"),
|
||||
TraceLoggingValue(dnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"),
|
||||
TraceLoggingValue(dnsClientIdentifier.DnsClientId, "DNS client id"),
|
||||
TraceLoggingValue(!m_externalInterfaceConstraintName.empty(), "Is ExternalInterfaceConstraint configured"),
|
||||
TraceLoggingValue(m_externalInterfaceConstraintIndex, "m_externalInterfaceConstraintIndex"));
|
||||
|
||||
// If the external interface constraint is configured but it is *not* present/up, WSL should be net-blind, so we avoid making DNS requests.
|
||||
if (!m_externalInterfaceConstraintName.empty() && m_externalInterfaceConstraintIndex == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
dnsClientIdentifier.Protocol == IPPROTO_UDP ? m_totalUdpQueries++ : m_totalTcpQueries++;
|
||||
|
||||
// Get next request id. If value reaches UINT_MAX + 1 it will be automatically reset to 0
|
||||
const auto requestId = m_currentRequestId++;
|
||||
|
||||
// Create the DNS request context
|
||||
auto context = std::make_unique<DnsResolver::DnsQueryContext>(
|
||||
requestId, dnsClientIdentifier, [this](_Inout_ DnsResolver::DnsQueryContext* context, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) {
|
||||
HandleDnsQueryCompletion(context, queryResults);
|
||||
});
|
||||
|
||||
auto [it, _] = m_dnsRequests.emplace(requestId, std::move(context));
|
||||
const auto localContext = it->second.get();
|
||||
|
||||
auto removeContextOnError = wil::scope_exit([&] { WI_VERIFY(m_dnsRequests.erase(requestId) == 1); });
|
||||
|
||||
// Fill DNS request structure
|
||||
DNS_QUERY_RAW_REQUEST request{};
|
||||
|
||||
request.version = DNS_QUERY_RAW_REQUEST_VERSION1;
|
||||
request.resultsVersion = DNS_QUERY_RAW_RESULTS_VERSION1;
|
||||
request.dnsQueryRawSize = static_cast<ULONG>(dnsBuffer.size());
|
||||
request.dnsQueryRaw = (PBYTE)dnsBuffer.data();
|
||||
request.protocol = (dnsClientIdentifier.Protocol == IPPROTO_TCP) ? DNS_PROTOCOL_TCP : DNS_PROTOCOL_UDP;
|
||||
request.queryCompletionCallback = DnsResolver::DnsQueryRawCallback;
|
||||
request.queryContext = localContext;
|
||||
// Only unicast UDP & TCP queries are tunneled. Pass this flag to tell Windows DNS client to *not* resolve using multicast.
|
||||
request.queryOptions |= DNS_QUERY_NO_MULTICAST;
|
||||
|
||||
// In a DNS request from Linux there might be DNS records that Windows DNS client does not know how to parse.
|
||||
// By default in this case Windows will fail the request. When the flag is enabled, Windows will extract the
|
||||
// question from the DNS request and attempt to resolve it, ignoring the unknown records.
|
||||
if (WI_IsFlagSet(m_flags, DnsResolverFlags::BestEffortDnsParsing))
|
||||
{
|
||||
request.queryRawOptions |= DNS_QUERY_RAW_OPTION_BEST_EFFORT_PARSE;
|
||||
}
|
||||
|
||||
// If the external interface constraint is configured and present on the host, only send DNS requests on that interface.
|
||||
if (m_externalInterfaceConstraintIndex != 0)
|
||||
{
|
||||
request.interfaceIndex = m_externalInterfaceConstraintIndex;
|
||||
}
|
||||
|
||||
// Start the DNS request
|
||||
// N.B. All DNS requests will bypass the Windows DNS cache
|
||||
const auto result = s_dnsQueryRaw.value()(&request, &localContext->m_cancelHandle);
|
||||
if (result != DNS_REQUEST_PENDING)
|
||||
{
|
||||
m_failedDnsQueryRawCalls++;
|
||||
|
||||
WSL_LOG(
|
||||
"ProcessDnsRequestFailed",
|
||||
TraceLoggingValue(requestId, "requestId"),
|
||||
TraceLoggingValue(result, "result"),
|
||||
TraceLoggingValue("DnsQueryRaw", "executionStep"));
|
||||
return;
|
||||
}
|
||||
|
||||
removeContextOnError.release();
|
||||
|
||||
m_allRequestsFinished.ResetEvent();
|
||||
}
|
||||
CATCH_LOG()
|
||||
|
||||
void DnsResolver::HandleDnsQueryCompletion(_Inout_ DnsResolver::DnsQueryContext* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept
|
||||
try
|
||||
{
|
||||
// Always free the query result structure
|
||||
const auto freeQueryResults = wil::scope_exit([&] {
|
||||
if (queryResults != nullptr)
|
||||
{
|
||||
s_dnsQueryRawResultFree.value()(queryResults);
|
||||
}
|
||||
});
|
||||
|
||||
const std::lock_guard lock(m_dnsLock);
|
||||
|
||||
if (queryResults != nullptr)
|
||||
{
|
||||
WSL_LOG(
|
||||
"DnsResolver::HandleDnsQueryCompletion",
|
||||
TraceLoggingValue(queryContext->m_id, "queryContext->m_id"),
|
||||
TraceLoggingValue(queryResults->queryStatus, "queryResults->queryStatus"),
|
||||
TraceLoggingValue(queryResults->queryRawResponse != nullptr, "validResponse"));
|
||||
|
||||
// Note: The response may be valid even if queryResults->queryStatus is not 0, for example when the DNS server returns a negative response.
|
||||
if (queryResults->queryRawResponse != nullptr)
|
||||
{
|
||||
queryContext->m_dnsClientIdentifier.Protocol == IPPROTO_UDP ? m_successfulUdpQueries++ : m_successfulTcpQueries++;
|
||||
}
|
||||
// the Windows DNS API returned failure
|
||||
else
|
||||
{
|
||||
if (m_dnsApiFailures.find(queryResults->queryStatus) == m_dnsApiFailures.end())
|
||||
{
|
||||
m_dnsApiFailures[queryResults->queryStatus] = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
m_dnsApiFailures[queryResults->queryStatus]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
WSL_LOG(
|
||||
"DnsResolver::HandleDnsQueryCompletion - received a NULL queryResults",
|
||||
TraceLoggingValue(queryContext->m_id, "queryContext->m_id"));
|
||||
m_queriesWithNullResult++;
|
||||
}
|
||||
|
||||
if (!m_stopped && queryResults != nullptr && queryResults->queryRawResponse != nullptr)
|
||||
{
|
||||
// Copy DNS response buffer
|
||||
std::vector<gsl::byte> dnsResponse(queryResults->queryRawResponseSize);
|
||||
CopyMemory(dnsResponse.data(), queryResults->queryRawResponse, queryResults->queryRawResponseSize);
|
||||
|
||||
WSL_LOG_DEBUG(
|
||||
"DnsResolver::HandleDnsQueryCompletion - received new DNS response",
|
||||
TraceLoggingValue(dnsResponse.size(), "DNS buffer size"),
|
||||
TraceLoggingValue(queryContext->m_dnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"),
|
||||
TraceLoggingValue(queryContext->m_dnsClientIdentifier.DnsClientId, "DNS client id"));
|
||||
|
||||
// Schedule the DNS response to be sent to Linux
|
||||
m_dnsResponseQueue.submit([this, dnsResponse = std::move(dnsResponse), dnsClientIdentifier = queryContext->m_dnsClientIdentifier]() mutable {
|
||||
m_dnsChannel.SendDnsMessage(gsl::make_span(dnsResponse), dnsClientIdentifier);
|
||||
});
|
||||
}
|
||||
|
||||
// Stop tracking this DNS request and delete the request context
|
||||
WI_VERIFY(m_dnsRequests.erase(queryContext->m_id) == 1);
|
||||
|
||||
// Set event if all tracked requests have finished
|
||||
if (m_dnsRequests.empty())
|
||||
{
|
||||
m_allRequestsFinished.SetEvent();
|
||||
}
|
||||
}
|
||||
CATCH_LOG()
|
||||
|
||||
void DnsResolver::ResolveExternalInterfaceConstraintIndex() noexcept
|
||||
try
|
||||
{
|
||||
const std::lock_guard lock(m_dnsLock);
|
||||
if (m_stopped)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if (m_externalInterfaceConstraintName.empty())
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
NET_LUID interfaceLuid{};
|
||||
ULONG interfaceIndex = 0;
|
||||
|
||||
// Update the interface index on every exit path.
|
||||
// The calls below to convert interface name to index will fail if the interface does not exist anymore,
|
||||
// in which case we still need to reset the interface index to its default value of 0.
|
||||
const auto setInterfaceIndex = wil::scope_exit([&] {
|
||||
if (interfaceIndex != m_externalInterfaceConstraintIndex)
|
||||
{
|
||||
WSL_LOG(
|
||||
"DnsResolver::ResolveExternalInterfaceConstraintIndex - setting m_externalInterfaceConstraintIndex to new value",
|
||||
TraceLoggingValue(m_externalInterfaceConstraintIndex, "old interface index"),
|
||||
TraceLoggingValue(interfaceIndex, "new interface index"));
|
||||
|
||||
m_externalInterfaceConstraintIndex = interfaceIndex;
|
||||
}
|
||||
});
|
||||
|
||||
// If external interface constraint is configured, query to see if it's present on the host.
|
||||
auto errorCode = ConvertInterfaceAliasToLuid(m_externalInterfaceConstraintName.c_str(), &interfaceLuid);
|
||||
if (FAILED_WIN32_LOG(errorCode))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
errorCode = ConvertInterfaceLuidToIndex(&interfaceLuid, reinterpret_cast<PNET_IFINDEX>(&interfaceIndex));
|
||||
if (FAILED_WIN32_LOG(errorCode))
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
CATCH_LOG()
|
||||
|
||||
VOID CALLBACK DnsResolver::DnsQueryRawCallback(_In_ VOID* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept
|
||||
try
|
||||
{
|
||||
assert(queryContext != nullptr);
|
||||
|
||||
const auto context = static_cast<DnsQueryContext*>(queryContext);
|
||||
|
||||
// Call into DnsResolver parent object to process the query result
|
||||
context->m_handleQueryCompletion(context, queryResults);
|
||||
}
|
||||
CATCH_LOG()
|
||||
|
||||
VOID CALLBACK DnsResolver::InterfaceChangeCallback(_In_ PVOID context, PMIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE) noexcept
|
||||
try
|
||||
{
|
||||
const auto dnsResolver = static_cast<DnsResolver*>(context);
|
||||
dnsResolver->ResolveExternalInterfaceConstraintIndex();
|
||||
}
|
||||
CATCH_LOG()
|
||||
@ -1,141 +1,141 @@
|
||||
// Copyright (C) Microsoft Corporation. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "DnsTunnelingChannel.h"
|
||||
#include "WslCoreMessageQueue.h"
|
||||
#include "WslCoreNetworkingSupport.h"
|
||||
|
||||
namespace wsl::core::networking {
|
||||
|
||||
enum class DnsResolverFlags
|
||||
{
|
||||
None = 0x0,
|
||||
BestEffortDnsParsing = 0x1
|
||||
};
|
||||
DEFINE_ENUM_FLAG_OPERATORS(DnsResolverFlags);
|
||||
|
||||
class DnsResolver
|
||||
{
|
||||
public:
|
||||
DnsResolver(wil::unique_socket&& dnsHvsocket, DnsResolverFlags flags);
|
||||
~DnsResolver() noexcept;
|
||||
|
||||
DnsResolver(const DnsResolver&) = delete;
|
||||
DnsResolver& operator=(const DnsResolver&) = delete;
|
||||
|
||||
DnsResolver(DnsResolver&&) = delete;
|
||||
DnsResolver& operator=(DnsResolver&&) = delete;
|
||||
|
||||
void Stop() noexcept;
|
||||
|
||||
static HRESULT LoadDnsResolverMethods() noexcept;
|
||||
|
||||
private:
|
||||
struct DnsQueryContext
|
||||
{
|
||||
// Struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request.
|
||||
LX_GNS_DNS_CLIENT_IDENTIFIER m_dnsClientIdentifier{};
|
||||
|
||||
// Handle used to cancel the request.
|
||||
DNS_QUERY_RAW_CANCEL m_cancelHandle{};
|
||||
|
||||
// Unique query id.
|
||||
uint32_t m_id{};
|
||||
|
||||
// Callback to the parent object to notify about the DNS query completion.
|
||||
std::function<void(DnsQueryContext*, DNS_QUERY_RAW_RESULT*)> m_handleQueryCompletion;
|
||||
|
||||
DnsQueryContext(
|
||||
uint32_t id,
|
||||
const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier,
|
||||
std::function<void(DnsQueryContext*, DNS_QUERY_RAW_RESULT*)>&& handleQueryCompletion) :
|
||||
m_dnsClientIdentifier(dnsClientIdentifier), m_id(id), m_handleQueryCompletion(std::move(handleQueryCompletion))
|
||||
{
|
||||
}
|
||||
|
||||
~DnsQueryContext() noexcept = default;
|
||||
|
||||
DnsQueryContext(const DnsQueryContext&) = delete;
|
||||
DnsQueryContext& operator=(const DnsQueryContext&) = delete;
|
||||
DnsQueryContext(DnsQueryContext&&) = delete;
|
||||
DnsQueryContext& operator=(DnsQueryContext&&) = delete;
|
||||
};
|
||||
|
||||
void GenerateTelemetry() noexcept;
|
||||
|
||||
// Process DNS request received from Linux.
|
||||
//
|
||||
// Arguments:
|
||||
// dnsBuffer - buffer containing DNS request.
|
||||
// dnsClientIdentifier - struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request.
|
||||
void ProcessDnsRequest(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept;
|
||||
|
||||
// Handle completion of DNS query.
|
||||
//
|
||||
// Arguments:
|
||||
// dnsQueryContext - context structure for the DNS request.
|
||||
// queryResults - structure containing result of the DNS request.
|
||||
void HandleDnsQueryCompletion(_Inout_ DnsQueryContext* dnsQueryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept;
|
||||
|
||||
void ResolveExternalInterfaceConstraintIndex() noexcept;
|
||||
|
||||
// Callback that will be invoked by the DNS API whenever a request finishes. The callback is invoked on success, error or when request is cancelled.
|
||||
//
|
||||
// Arguments:
|
||||
// queryContext - pointer to context structure, will be a structure of type DnsQueryContext.
|
||||
// queryResults - pointer to structure containing the result of the DNS request.
|
||||
static VOID CALLBACK DnsQueryRawCallback(_In_ VOID* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept;
|
||||
|
||||
static VOID CALLBACK InterfaceChangeCallback(_In_ PVOID context, PMIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE) noexcept;
|
||||
|
||||
std::recursive_mutex m_dnsLock;
|
||||
|
||||
// Flag used when shutting down the object.
|
||||
_Guarded_by_(m_dnsLock) bool m_stopped = false;
|
||||
|
||||
// Hvsocket channel used to exchange DNS messages with Linux.
|
||||
DnsTunnelingChannel m_dnsChannel;
|
||||
|
||||
// Queue used to send DNS responses to Linux.
|
||||
WslCoreMessageQueue m_dnsResponseQueue;
|
||||
|
||||
// Unique id that is incremented for each request. In case the value reaches MAX_UINT and is reset to 0,
|
||||
// it's assumed previous requests with id's 0, 1, ... finished in the meantime and the id can be reused.
|
||||
_Guarded_by_(m_dnsLock) uint32_t m_currentRequestId = 0;
|
||||
|
||||
// Mapping request id to the request context structure.
|
||||
_Guarded_by_(m_dnsLock) std::unordered_map<uint32_t, std::unique_ptr<DnsQueryContext>> m_dnsRequests {};
|
||||
|
||||
// Event that is set when all tracked DNS requests have completed.
|
||||
wil::unique_event m_allRequestsFinished{wil::EventOptions::ManualReset};
|
||||
|
||||
// Used for handling of external interface constraint setting.
|
||||
unique_notify_handle m_interfaceNotificationHandle{};
|
||||
|
||||
std::wstring m_externalInterfaceConstraintName;
|
||||
_Guarded_by_(m_dnsLock) ULONG m_externalInterfaceConstraintIndex = 0;
|
||||
|
||||
const DnsResolverFlags m_flags{};
|
||||
|
||||
// Statistics used for telemetry.
|
||||
std::atomic<uint32_t> m_totalUdpQueries{0};
|
||||
std::atomic<uint32_t> m_successfulUdpQueries{0};
|
||||
std::atomic<uint32_t> m_totalTcpQueries{0};
|
||||
std::atomic<uint32_t> m_successfulTcpQueries{0};
|
||||
std::atomic<uint32_t> m_queriesWithNullResult{0};
|
||||
std::atomic<uint32_t> m_failedDnsQueryRawCalls{0};
|
||||
|
||||
_Guarded_by_(m_dnsLock) std::map<uint32_t, uint32_t> m_dnsApiFailures;
|
||||
|
||||
// Dynamic functions used for calling the DNS APIs.
|
||||
|
||||
// Function to start a raw DNS request.
|
||||
static std::optional<LxssDynamicFunction<decltype(DnsQueryRaw)>> s_dnsQueryRaw;
|
||||
// Function to cancel a raw DNS request.
|
||||
static std::optional<LxssDynamicFunction<decltype(DnsCancelQueryRaw)>> s_dnsCancelQueryRaw;
|
||||
// Function to free the structure containing the result of a raw DNS request.
|
||||
static std::optional<LxssDynamicFunction<decltype(DnsQueryRawResultFree)>> s_dnsQueryRawResultFree;
|
||||
};
|
||||
|
||||
} // namespace wsl::core::networking
|
||||
// Copyright (C) Microsoft Corporation. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "DnsTunnelingChannel.h"
|
||||
#include "WslCoreMessageQueue.h"
|
||||
#include "WslCoreNetworkingSupport.h"
|
||||
|
||||
namespace wsl::core::networking {
|
||||
|
||||
enum class DnsResolverFlags
|
||||
{
|
||||
None = 0x0,
|
||||
BestEffortDnsParsing = 0x1
|
||||
};
|
||||
DEFINE_ENUM_FLAG_OPERATORS(DnsResolverFlags);
|
||||
|
||||
class DnsResolver
|
||||
{
|
||||
public:
|
||||
DnsResolver(wil::unique_socket&& dnsHvsocket, DnsResolverFlags flags);
|
||||
~DnsResolver() noexcept;
|
||||
|
||||
DnsResolver(const DnsResolver&) = delete;
|
||||
DnsResolver& operator=(const DnsResolver&) = delete;
|
||||
|
||||
DnsResolver(DnsResolver&&) = delete;
|
||||
DnsResolver& operator=(DnsResolver&&) = delete;
|
||||
|
||||
void Stop() noexcept;
|
||||
|
||||
static HRESULT LoadDnsResolverMethods() noexcept;
|
||||
|
||||
private:
|
||||
struct DnsQueryContext
|
||||
{
|
||||
// Struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request.
|
||||
LX_GNS_DNS_CLIENT_IDENTIFIER m_dnsClientIdentifier{};
|
||||
|
||||
// Handle used to cancel the request.
|
||||
DNS_QUERY_RAW_CANCEL m_cancelHandle{};
|
||||
|
||||
// Unique query id.
|
||||
uint32_t m_id{};
|
||||
|
||||
// Callback to the parent object to notify about the DNS query completion.
|
||||
std::function<void(DnsQueryContext*, DNS_QUERY_RAW_RESULT*)> m_handleQueryCompletion;
|
||||
|
||||
DnsQueryContext(
|
||||
uint32_t id,
|
||||
const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier,
|
||||
std::function<void(DnsQueryContext*, DNS_QUERY_RAW_RESULT*)>&& handleQueryCompletion) :
|
||||
m_dnsClientIdentifier(dnsClientIdentifier), m_id(id), m_handleQueryCompletion(std::move(handleQueryCompletion))
|
||||
{
|
||||
}
|
||||
|
||||
~DnsQueryContext() noexcept = default;
|
||||
|
||||
DnsQueryContext(const DnsQueryContext&) = delete;
|
||||
DnsQueryContext& operator=(const DnsQueryContext&) = delete;
|
||||
DnsQueryContext(DnsQueryContext&&) = delete;
|
||||
DnsQueryContext& operator=(DnsQueryContext&&) = delete;
|
||||
};
|
||||
|
||||
void GenerateTelemetry() noexcept;
|
||||
|
||||
// Process DNS request received from Linux.
|
||||
//
|
||||
// Arguments:
|
||||
// dnsBuffer - buffer containing DNS request.
|
||||
// dnsClientIdentifier - struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request.
|
||||
void ProcessDnsRequest(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept;
|
||||
|
||||
// Handle completion of DNS query.
|
||||
//
|
||||
// Arguments:
|
||||
// dnsQueryContext - context structure for the DNS request.
|
||||
// queryResults - structure containing result of the DNS request.
|
||||
void HandleDnsQueryCompletion(_Inout_ DnsQueryContext* dnsQueryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept;
|
||||
|
||||
void ResolveExternalInterfaceConstraintIndex() noexcept;
|
||||
|
||||
// Callback that will be invoked by the DNS API whenever a request finishes. The callback is invoked on success, error or when request is cancelled.
|
||||
//
|
||||
// Arguments:
|
||||
// queryContext - pointer to context structure, will be a structure of type DnsQueryContext.
|
||||
// queryResults - pointer to structure containing the result of the DNS request.
|
||||
static VOID CALLBACK DnsQueryRawCallback(_In_ VOID* queryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept;
|
||||
|
||||
static VOID CALLBACK InterfaceChangeCallback(_In_ PVOID context, PMIB_IPINTERFACE_ROW, MIB_NOTIFICATION_TYPE) noexcept;
|
||||
|
||||
std::recursive_mutex m_dnsLock;
|
||||
|
||||
// Flag used when shutting down the object.
|
||||
_Guarded_by_(m_dnsLock) bool m_stopped = false;
|
||||
|
||||
// Hvsocket channel used to exchange DNS messages with Linux.
|
||||
DnsTunnelingChannel m_dnsChannel;
|
||||
|
||||
// Queue used to send DNS responses to Linux.
|
||||
WslCoreMessageQueue m_dnsResponseQueue;
|
||||
|
||||
// Unique id that is incremented for each request. In case the value reaches MAX_UINT and is reset to 0,
|
||||
// it's assumed previous requests with id's 0, 1, ... finished in the meantime and the id can be reused.
|
||||
_Guarded_by_(m_dnsLock) uint32_t m_currentRequestId = 0;
|
||||
|
||||
// Mapping request id to the request context structure.
|
||||
_Guarded_by_(m_dnsLock) std::unordered_map<uint32_t, std::unique_ptr<DnsQueryContext>> m_dnsRequests {};
|
||||
|
||||
// Event that is set when all tracked DNS requests have completed.
|
||||
wil::unique_event m_allRequestsFinished{wil::EventOptions::ManualReset};
|
||||
|
||||
// Used for handling of external interface constraint setting.
|
||||
unique_notify_handle m_interfaceNotificationHandle{};
|
||||
|
||||
std::wstring m_externalInterfaceConstraintName;
|
||||
_Guarded_by_(m_dnsLock) ULONG m_externalInterfaceConstraintIndex = 0;
|
||||
|
||||
const DnsResolverFlags m_flags{};
|
||||
|
||||
// Statistics used for telemetry.
|
||||
std::atomic<uint32_t> m_totalUdpQueries{0};
|
||||
std::atomic<uint32_t> m_successfulUdpQueries{0};
|
||||
std::atomic<uint32_t> m_totalTcpQueries{0};
|
||||
std::atomic<uint32_t> m_successfulTcpQueries{0};
|
||||
std::atomic<uint32_t> m_queriesWithNullResult{0};
|
||||
std::atomic<uint32_t> m_failedDnsQueryRawCalls{0};
|
||||
|
||||
_Guarded_by_(m_dnsLock) std::map<uint32_t, uint32_t> m_dnsApiFailures;
|
||||
|
||||
// Dynamic functions used for calling the DNS APIs.
|
||||
|
||||
// Function to start a raw DNS request.
|
||||
static std::optional<LxssDynamicFunction<decltype(DnsQueryRaw)>> s_dnsQueryRaw;
|
||||
// Function to cancel a raw DNS request.
|
||||
static std::optional<LxssDynamicFunction<decltype(DnsCancelQueryRaw)>> s_dnsCancelQueryRaw;
|
||||
// Function to free the structure containing the result of a raw DNS request.
|
||||
static std::optional<LxssDynamicFunction<decltype(DnsQueryRawResultFree)>> s_dnsQueryRawResultFree;
|
||||
};
|
||||
|
||||
} // namespace wsl::core::networking
|
||||
@ -1,115 +1,115 @@
|
||||
// Copyright (C) Microsoft Corporation. All rights reserved.
|
||||
|
||||
#include "precomp.h"
|
||||
#include "DnsTunnelingChannel.h"
|
||||
|
||||
using wsl::core::networking::DnsTunnelingChannel;
|
||||
|
||||
DnsTunnelingChannel::DnsTunnelingChannel(wil::unique_socket&& socket, DnsTunnelingCallback&& reportDnsRequest) :
|
||||
m_channel{std::move(socket), "DnsTunneling", m_stopEvent.get()}, m_reportDnsRequest(std::move(reportDnsRequest))
|
||||
{
|
||||
WSL_LOG("DnsTunnelingChannel::DnsTunnelingChannel [Windows]", TraceLoggingValue(m_channel.Socket(), "socket"));
|
||||
|
||||
// Start thread waiting for incoming messages from Linux side
|
||||
m_receiveWorkerThread = std::thread([this]() { ReceiveLoop(); });
|
||||
}
|
||||
|
||||
DnsTunnelingChannel::~DnsTunnelingChannel()
|
||||
{
|
||||
Stop();
|
||||
}
|
||||
|
||||
void DnsTunnelingChannel::SendDnsMessage(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept
|
||||
try
|
||||
{
|
||||
// Exit if channel was stopped
|
||||
if (m_stopEvent.is_signaled())
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
wsl::shared::MessageWriter<LX_GNS_DNS_TUNNELING_MESSAGE> message(LxGnsMessageDnsTunneling);
|
||||
message->DnsClientIdentifier = dnsClientIdentifier;
|
||||
message.WriteSpan(dnsBuffer);
|
||||
|
||||
m_channel.SendMessage<LX_GNS_DNS_TUNNELING_MESSAGE>(message.Span());
|
||||
}
|
||||
CATCH_LOG()
|
||||
|
||||
void DnsTunnelingChannel::ReceiveLoop() noexcept
|
||||
{
|
||||
std::vector<gsl::byte> receiveBuffer;
|
||||
|
||||
for (;;)
|
||||
{
|
||||
try
|
||||
{
|
||||
if (m_stopEvent.is_signaled())
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
WSL_LOG_DEBUG("DnsTunnelingChannel::ReceiveLoop [Windows] - waiting for next message from Linux");
|
||||
|
||||
// Read next message. wsl::shared::socket::RecvMessage() first reads the message header, then uses it to determine the
|
||||
// total size of the message and read the rest of the message, resizing the buffer if needed.
|
||||
auto [message, span] = m_channel.ReceiveMessageOrClosed<MESSAGE_HEADER>();
|
||||
if (message == nullptr)
|
||||
{
|
||||
WSL_LOG("DnsTunnelingChannel::ReceiveLoop [Windows] - failed to read message");
|
||||
return;
|
||||
}
|
||||
|
||||
// Get the message type from the message header
|
||||
switch (message->MessageType)
|
||||
{
|
||||
case LxGnsMessageDnsTunneling:
|
||||
{
|
||||
// Cast message to a LX_GNS_DNS_TUNNELING_MESSAGE struct
|
||||
auto* dnsMessage = gslhelpers::try_get_struct<LX_GNS_DNS_TUNNELING_MESSAGE>(span);
|
||||
if (!dnsMessage)
|
||||
{
|
||||
WSL_LOG(
|
||||
"DnsTunnelingChannel::ReceiveLoop [Windows] - failed to convert message to LX_GNS_DNS_TUNNELING_MESSAGE");
|
||||
return;
|
||||
}
|
||||
|
||||
// Extract DNS buffer from message
|
||||
auto dnsBuffer = span.subspan(offsetof(LX_GNS_DNS_TUNNELING_MESSAGE, Buffer));
|
||||
|
||||
WSL_LOG_DEBUG(
|
||||
"DnsTunnelingChannel::ReceiveLoop [Windows] - received DNS message",
|
||||
TraceLoggingValue(dnsBuffer.size(), "DNS buffer size"),
|
||||
TraceLoggingValue(dnsMessage->DnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"),
|
||||
TraceLoggingValue(dnsMessage->DnsClientIdentifier.DnsClientId, "DNS client id"));
|
||||
|
||||
// Invoke callback to notify about the new DNS request
|
||||
m_reportDnsRequest(dnsBuffer, dnsMessage->DnsClientIdentifier);
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
{
|
||||
THROW_HR_MSG(E_UNEXPECTED, "Unexpected LX_MESSAGE_TYPE : %i", message->MessageType);
|
||||
}
|
||||
}
|
||||
}
|
||||
CATCH_LOG()
|
||||
}
|
||||
}
|
||||
|
||||
void DnsTunnelingChannel::Stop() noexcept
|
||||
try
|
||||
{
|
||||
WSL_LOG("DnsTunnelingChannel::Stop [Windows]");
|
||||
|
||||
m_stopEvent.SetEvent();
|
||||
|
||||
// Stop receive loop
|
||||
if (m_receiveWorkerThread.joinable())
|
||||
{
|
||||
m_receiveWorkerThread.join();
|
||||
}
|
||||
}
|
||||
CATCH_LOG()
|
||||
// Copyright (C) Microsoft Corporation. All rights reserved.
|
||||
|
||||
#include "precomp.h"
|
||||
#include "DnsTunnelingChannel.h"
|
||||
|
||||
using wsl::core::networking::DnsTunnelingChannel;
|
||||
|
||||
DnsTunnelingChannel::DnsTunnelingChannel(wil::unique_socket&& socket, DnsTunnelingCallback&& reportDnsRequest) :
|
||||
m_channel{std::move(socket), "DnsTunneling", m_stopEvent.get()}, m_reportDnsRequest(std::move(reportDnsRequest))
|
||||
{
|
||||
WSL_LOG("DnsTunnelingChannel::DnsTunnelingChannel [Windows]", TraceLoggingValue(m_channel.Socket(), "socket"));
|
||||
|
||||
// Start thread waiting for incoming messages from Linux side
|
||||
m_receiveWorkerThread = std::thread([this]() { ReceiveLoop(); });
|
||||
}
|
||||
|
||||
DnsTunnelingChannel::~DnsTunnelingChannel()
|
||||
{
|
||||
Stop();
|
||||
}
|
||||
|
||||
void DnsTunnelingChannel::SendDnsMessage(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept
|
||||
try
|
||||
{
|
||||
// Exit if channel was stopped
|
||||
if (m_stopEvent.is_signaled())
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
wsl::shared::MessageWriter<LX_GNS_DNS_TUNNELING_MESSAGE> message(LxGnsMessageDnsTunneling);
|
||||
message->DnsClientIdentifier = dnsClientIdentifier;
|
||||
message.WriteSpan(dnsBuffer);
|
||||
|
||||
m_channel.SendMessage<LX_GNS_DNS_TUNNELING_MESSAGE>(message.Span());
|
||||
}
|
||||
CATCH_LOG()
|
||||
|
||||
void DnsTunnelingChannel::ReceiveLoop() noexcept
|
||||
{
|
||||
std::vector<gsl::byte> receiveBuffer;
|
||||
|
||||
for (;;)
|
||||
{
|
||||
try
|
||||
{
|
||||
if (m_stopEvent.is_signaled())
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
WSL_LOG_DEBUG("DnsTunnelingChannel::ReceiveLoop [Windows] - waiting for next message from Linux");
|
||||
|
||||
// Read next message. wsl::shared::socket::RecvMessage() first reads the message header, then uses it to determine the
|
||||
// total size of the message and read the rest of the message, resizing the buffer if needed.
|
||||
auto [message, span] = m_channel.ReceiveMessageOrClosed<MESSAGE_HEADER>();
|
||||
if (message == nullptr)
|
||||
{
|
||||
WSL_LOG("DnsTunnelingChannel::ReceiveLoop [Windows] - failed to read message");
|
||||
return;
|
||||
}
|
||||
|
||||
// Get the message type from the message header
|
||||
switch (message->MessageType)
|
||||
{
|
||||
case LxGnsMessageDnsTunneling:
|
||||
{
|
||||
// Cast message to a LX_GNS_DNS_TUNNELING_MESSAGE struct
|
||||
auto* dnsMessage = gslhelpers::try_get_struct<LX_GNS_DNS_TUNNELING_MESSAGE>(span);
|
||||
if (!dnsMessage)
|
||||
{
|
||||
WSL_LOG(
|
||||
"DnsTunnelingChannel::ReceiveLoop [Windows] - failed to convert message to LX_GNS_DNS_TUNNELING_MESSAGE");
|
||||
return;
|
||||
}
|
||||
|
||||
// Extract DNS buffer from message
|
||||
auto dnsBuffer = span.subspan(offsetof(LX_GNS_DNS_TUNNELING_MESSAGE, Buffer));
|
||||
|
||||
WSL_LOG_DEBUG(
|
||||
"DnsTunnelingChannel::ReceiveLoop [Windows] - received DNS message",
|
||||
TraceLoggingValue(dnsBuffer.size(), "DNS buffer size"),
|
||||
TraceLoggingValue(dnsMessage->DnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"),
|
||||
TraceLoggingValue(dnsMessage->DnsClientIdentifier.DnsClientId, "DNS client id"));
|
||||
|
||||
// Invoke callback to notify about the new DNS request
|
||||
m_reportDnsRequest(dnsBuffer, dnsMessage->DnsClientIdentifier);
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
{
|
||||
THROW_HR_MSG(E_UNEXPECTED, "Unexpected LX_MESSAGE_TYPE : %i", message->MessageType);
|
||||
}
|
||||
}
|
||||
}
|
||||
CATCH_LOG()
|
||||
}
|
||||
}
|
||||
|
||||
void DnsTunnelingChannel::Stop() noexcept
|
||||
try
|
||||
{
|
||||
WSL_LOG("DnsTunnelingChannel::Stop [Windows]");
|
||||
|
||||
m_stopEvent.SetEvent();
|
||||
|
||||
// Stop receive loop
|
||||
if (m_receiveWorkerThread.joinable())
|
||||
{
|
||||
m_receiveWorkerThread.join();
|
||||
}
|
||||
}
|
||||
CATCH_LOG()
|
||||
@ -1,50 +1,50 @@
|
||||
// Copyright (C) Microsoft Corporation. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <wil/resource.h>
|
||||
#include "lxinitshared.h"
|
||||
#include "SocketChannel.h"
|
||||
|
||||
namespace wsl::core::networking {
|
||||
|
||||
using DnsTunnelingCallback = std::function<void(const gsl::span<gsl::byte>, const LX_GNS_DNS_CLIENT_IDENTIFIER&)>;
|
||||
|
||||
class DnsTunnelingChannel
|
||||
{
|
||||
public:
|
||||
DnsTunnelingChannel(wil::unique_socket&& socket, DnsTunnelingCallback&& reportDnsRequest);
|
||||
~DnsTunnelingChannel();
|
||||
|
||||
DnsTunnelingChannel(const DnsTunnelingChannel&) = delete;
|
||||
DnsTunnelingChannel& operator=(const DnsTunnelingChannel&) = delete;
|
||||
|
||||
DnsTunnelingChannel(DnsTunnelingChannel&&) = delete;
|
||||
DnsTunnelingChannel& operator=(DnsTunnelingChannel&&) = delete;
|
||||
|
||||
// Construct and send a LX_GNS_DNS_TUNNELING_MESSAGE message on the channel.
|
||||
// Note: Callers are responsible for sequencing calls to this method.
|
||||
//
|
||||
// Arguments:
|
||||
// dnsBuffer - buffer containing DNS response.
|
||||
// dnsClientIdentifier - struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request.
|
||||
void SendDnsMessage(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept;
|
||||
|
||||
// Stop the channel.
|
||||
void Stop() noexcept;
|
||||
|
||||
private:
|
||||
// Wait for messages on the channel from Linux side.
|
||||
void ReceiveLoop() noexcept;
|
||||
|
||||
wil::unique_event m_stopEvent{wil::EventOptions::ManualReset};
|
||||
|
||||
wsl::shared::SocketChannel m_channel;
|
||||
|
||||
std::thread m_receiveWorkerThread;
|
||||
|
||||
// Callback used to notify when there is a new DNS request message on the channel.
|
||||
DnsTunnelingCallback m_reportDnsRequest;
|
||||
};
|
||||
|
||||
} // namespace wsl::core::networking
|
||||
// Copyright (C) Microsoft Corporation. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <wil/resource.h>
|
||||
#include "lxinitshared.h"
|
||||
#include "SocketChannel.h"
|
||||
|
||||
namespace wsl::core::networking {
|
||||
|
||||
using DnsTunnelingCallback = std::function<void(const gsl::span<gsl::byte>, const LX_GNS_DNS_CLIENT_IDENTIFIER&)>;
|
||||
|
||||
class DnsTunnelingChannel
|
||||
{
|
||||
public:
|
||||
DnsTunnelingChannel(wil::unique_socket&& socket, DnsTunnelingCallback&& reportDnsRequest);
|
||||
~DnsTunnelingChannel();
|
||||
|
||||
DnsTunnelingChannel(const DnsTunnelingChannel&) = delete;
|
||||
DnsTunnelingChannel& operator=(const DnsTunnelingChannel&) = delete;
|
||||
|
||||
DnsTunnelingChannel(DnsTunnelingChannel&&) = delete;
|
||||
DnsTunnelingChannel& operator=(DnsTunnelingChannel&&) = delete;
|
||||
|
||||
// Construct and send a LX_GNS_DNS_TUNNELING_MESSAGE message on the channel.
|
||||
// Note: Callers are responsible for sequencing calls to this method.
|
||||
//
|
||||
// Arguments:
|
||||
// dnsBuffer - buffer containing DNS response.
|
||||
// dnsClientIdentifier - struct containing protocol (TCP/UDP) and unique id of the Linux DNS client making the request.
|
||||
void SendDnsMessage(const gsl::span<gsl::byte> dnsBuffer, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) noexcept;
|
||||
|
||||
// Stop the channel.
|
||||
void Stop() noexcept;
|
||||
|
||||
private:
|
||||
// Wait for messages on the channel from Linux side.
|
||||
void ReceiveLoop() noexcept;
|
||||
|
||||
wil::unique_event m_stopEvent{wil::EventOptions::ManualReset};
|
||||
|
||||
wsl::shared::SocketChannel m_channel;
|
||||
|
||||
std::thread m_receiveWorkerThread;
|
||||
|
||||
// Callback used to notify when there is a new DNS request message on the channel.
|
||||
DnsTunnelingCallback m_reportDnsRequest;
|
||||
};
|
||||
|
||||
} // namespace wsl::core::networking
|
||||
@ -6,7 +6,6 @@
|
||||
#include "WslCoreHostDnsInfo.h"
|
||||
#include "Stringify.h"
|
||||
#include "WslCoreFirewallSupport.h"
|
||||
#include "WslCoreVm.h"
|
||||
#include "hcs.hpp"
|
||||
|
||||
using namespace wsl::core::networking;
|
||||
@ -672,7 +671,7 @@ wsl::windows::common::hcs::unique_hcn_network NatNetworking::CreateNetwork(wsl::
|
||||
wil::ResultFromException(WI_DIAGNOSTICS_INFO, [&] {
|
||||
try
|
||||
{
|
||||
wsl::core::networking::ConfigureHyperVFirewall(config.FirewallConfig, c_vmOwner);
|
||||
wsl::core::networking::ConfigureHyperVFirewall(config.FirewallConfig, wsl::windows::common::wslutil::c_vmOwner);
|
||||
natNetwork = CreateNetworkInternal(config);
|
||||
}
|
||||
catch (...)
|
||||
@ -1,154 +1,154 @@
|
||||
/*++
|
||||
|
||||
Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
Module Name:
|
||||
|
||||
RingBuffer.cpp
|
||||
|
||||
Abstract:
|
||||
|
||||
This file contains definitions for the RingBuffer class.
|
||||
|
||||
--*/
|
||||
|
||||
#include "precomp.h"
|
||||
#include "RingBuffer.h"
|
||||
|
||||
RingBuffer::RingBuffer(size_t size) : m_maxSize(size), m_offset(0)
|
||||
{
|
||||
m_buffer.reserve(size);
|
||||
}
|
||||
|
||||
void RingBuffer::Insert(std::string_view data)
|
||||
{
|
||||
auto lock = m_lock.lock_exclusive();
|
||||
auto remainingData = gsl::make_span(data.data(), data.size());
|
||||
if (remainingData.size() > m_maxSize)
|
||||
{
|
||||
remainingData = remainingData.subspan(remainingData.size() - m_maxSize);
|
||||
}
|
||||
|
||||
const auto bytesAtEnd = std::min(m_maxSize - m_offset, remainingData.size());
|
||||
if (m_offset + bytesAtEnd > m_buffer.size())
|
||||
{
|
||||
m_buffer.resize(m_offset + bytesAtEnd);
|
||||
WI_ASSERT(m_buffer.size() <= m_maxSize);
|
||||
}
|
||||
|
||||
const auto allBuffer = gsl::make_span(m_buffer);
|
||||
const auto beginCopyBuffer = allBuffer.subspan(m_offset, bytesAtEnd);
|
||||
copy(remainingData.subspan(0, bytesAtEnd), beginCopyBuffer);
|
||||
remainingData = remainingData.subspan(bytesAtEnd);
|
||||
if (!remainingData.empty())
|
||||
{
|
||||
copy(remainingData, allBuffer);
|
||||
m_offset = remainingData.size();
|
||||
}
|
||||
else
|
||||
{
|
||||
m_offset += bytesAtEnd;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> RingBuffer::GetLastDelimitedStrings(char Delimiter, size_t Count) const
|
||||
{
|
||||
auto lock = m_lock.lock_shared();
|
||||
auto [begin, end] = Contents();
|
||||
std::vector<std::string> results;
|
||||
std::optional<size_t> endIndex;
|
||||
for (size_t i = end.size(); i > 0; i--)
|
||||
{
|
||||
if (results.size() == Count)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
if (Delimiter == end[i - 1])
|
||||
{
|
||||
if (endIndex.has_value())
|
||||
{
|
||||
results.emplace(results.begin(), &end[i], endIndex.value() - i);
|
||||
endIndex.reset();
|
||||
}
|
||||
else
|
||||
{
|
||||
endIndex = i - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (results.size() == Count)
|
||||
{
|
||||
return results;
|
||||
}
|
||||
|
||||
std::string partial;
|
||||
if (endIndex.has_value())
|
||||
{
|
||||
partial = std::string{&end[0], endIndex.value()};
|
||||
endIndex.reset();
|
||||
}
|
||||
|
||||
for (size_t i = begin.size(); i > 0; i--)
|
||||
{
|
||||
if (results.size() == Count)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
if (Delimiter == begin[i - 1])
|
||||
{
|
||||
if (!partial.empty())
|
||||
{
|
||||
// The debug CRT will fastfail if begin[size] is accessed
|
||||
// But in this case it's not a problem because begin.size() - i would be == 0
|
||||
std::string partial_begin{&begin.data()[i], begin.size() - i};
|
||||
results.emplace(results.begin(), partial_begin + partial);
|
||||
partial.clear();
|
||||
}
|
||||
else if (endIndex.has_value())
|
||||
{
|
||||
results.emplace(results.begin(), &begin.data()[i], endIndex.value() - i);
|
||||
endIndex.reset();
|
||||
}
|
||||
else
|
||||
{
|
||||
endIndex = i - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (results.size() < Count)
|
||||
{
|
||||
// May have lost some data, or this could be the very first line logged.
|
||||
if (!partial.empty())
|
||||
{
|
||||
results.emplace(results.begin(), partial);
|
||||
}
|
||||
else if (endIndex.has_value())
|
||||
{
|
||||
results.emplace(results.begin(), &begin[0], endIndex.value());
|
||||
}
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
std::string RingBuffer::Get() const
|
||||
{
|
||||
auto lock = m_lock.lock_shared();
|
||||
auto [begin, end] = Contents();
|
||||
std::string data;
|
||||
data.reserve(begin.size() + end.size());
|
||||
data.append(begin.data(), begin.size());
|
||||
data.append(end.data(), end.size());
|
||||
return data;
|
||||
}
|
||||
|
||||
std::pair<std::string_view, std::string_view> RingBuffer::Contents() const
|
||||
{
|
||||
std::string_view beginView(m_buffer.data() + m_offset, m_buffer.size() - m_offset);
|
||||
std::string_view endView(m_buffer.data(), m_offset);
|
||||
return {beginView, endView};
|
||||
/*++
|
||||
|
||||
Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
Module Name:
|
||||
|
||||
RingBuffer.cpp
|
||||
|
||||
Abstract:
|
||||
|
||||
This file contains definitions for the RingBuffer class.
|
||||
|
||||
--*/
|
||||
|
||||
#include "precomp.h"
|
||||
#include "RingBuffer.h"
|
||||
|
||||
RingBuffer::RingBuffer(size_t size) : m_maxSize(size), m_offset(0)
|
||||
{
|
||||
m_buffer.reserve(size);
|
||||
}
|
||||
|
||||
void RingBuffer::Insert(std::string_view data)
|
||||
{
|
||||
auto lock = m_lock.lock_exclusive();
|
||||
auto remainingData = gsl::make_span(data.data(), data.size());
|
||||
if (remainingData.size() > m_maxSize)
|
||||
{
|
||||
remainingData = remainingData.subspan(remainingData.size() - m_maxSize);
|
||||
}
|
||||
|
||||
const auto bytesAtEnd = std::min(m_maxSize - m_offset, remainingData.size());
|
||||
if (m_offset + bytesAtEnd > m_buffer.size())
|
||||
{
|
||||
m_buffer.resize(m_offset + bytesAtEnd);
|
||||
WI_ASSERT(m_buffer.size() <= m_maxSize);
|
||||
}
|
||||
|
||||
const auto allBuffer = gsl::make_span(m_buffer);
|
||||
const auto beginCopyBuffer = allBuffer.subspan(m_offset, bytesAtEnd);
|
||||
copy(remainingData.subspan(0, bytesAtEnd), beginCopyBuffer);
|
||||
remainingData = remainingData.subspan(bytesAtEnd);
|
||||
if (!remainingData.empty())
|
||||
{
|
||||
copy(remainingData, allBuffer);
|
||||
m_offset = remainingData.size();
|
||||
}
|
||||
else
|
||||
{
|
||||
m_offset += bytesAtEnd;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> RingBuffer::GetLastDelimitedStrings(char Delimiter, size_t Count) const
|
||||
{
|
||||
auto lock = m_lock.lock_shared();
|
||||
auto [begin, end] = Contents();
|
||||
std::vector<std::string> results;
|
||||
std::optional<size_t> endIndex;
|
||||
for (size_t i = end.size(); i > 0; i--)
|
||||
{
|
||||
if (results.size() == Count)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
if (Delimiter == end[i - 1])
|
||||
{
|
||||
if (endIndex.has_value())
|
||||
{
|
||||
results.emplace(results.begin(), &end[i], endIndex.value() - i);
|
||||
endIndex.reset();
|
||||
}
|
||||
else
|
||||
{
|
||||
endIndex = i - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (results.size() == Count)
|
||||
{
|
||||
return results;
|
||||
}
|
||||
|
||||
std::string partial;
|
||||
if (endIndex.has_value())
|
||||
{
|
||||
partial = std::string{&end[0], endIndex.value()};
|
||||
endIndex.reset();
|
||||
}
|
||||
|
||||
for (size_t i = begin.size(); i > 0; i--)
|
||||
{
|
||||
if (results.size() == Count)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
if (Delimiter == begin[i - 1])
|
||||
{
|
||||
if (!partial.empty())
|
||||
{
|
||||
// The debug CRT will fastfail if begin[size] is accessed
|
||||
// But in this case it's not a problem because begin.size() - i would be == 0
|
||||
std::string partial_begin{&begin.data()[i], begin.size() - i};
|
||||
results.emplace(results.begin(), partial_begin + partial);
|
||||
partial.clear();
|
||||
}
|
||||
else if (endIndex.has_value())
|
||||
{
|
||||
results.emplace(results.begin(), &begin.data()[i], endIndex.value() - i);
|
||||
endIndex.reset();
|
||||
}
|
||||
else
|
||||
{
|
||||
endIndex = i - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (results.size() < Count)
|
||||
{
|
||||
// May have lost some data, or this could be the very first line logged.
|
||||
if (!partial.empty())
|
||||
{
|
||||
results.emplace(results.begin(), partial);
|
||||
}
|
||||
else if (endIndex.has_value())
|
||||
{
|
||||
results.emplace(results.begin(), &begin[0], endIndex.value());
|
||||
}
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
std::string RingBuffer::Get() const
|
||||
{
|
||||
auto lock = m_lock.lock_shared();
|
||||
auto [begin, end] = Contents();
|
||||
std::string data;
|
||||
data.reserve(begin.size() + end.size());
|
||||
data.append(begin.data(), begin.size());
|
||||
data.append(end.data(), end.size());
|
||||
return data;
|
||||
}
|
||||
|
||||
std::pair<std::string_view, std::string_view> RingBuffer::Contents() const
|
||||
{
|
||||
std::string_view beginView(m_buffer.data() + m_offset, m_buffer.size() - m_offset);
|
||||
std::string_view endView(m_buffer.data(), m_offset);
|
||||
return {beginView, endView};
|
||||
}
|
||||
@ -1,34 +1,34 @@
|
||||
/*++
|
||||
|
||||
Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
Module Name:
|
||||
|
||||
RingBuffer.h
|
||||
|
||||
Abstract:
|
||||
|
||||
This file contains declarations for the RingBuffer class.
|
||||
|
||||
--*/
|
||||
|
||||
#pragma once
|
||||
|
||||
class RingBuffer
|
||||
{
|
||||
public:
|
||||
RingBuffer() = delete;
|
||||
RingBuffer(size_t size);
|
||||
|
||||
void Insert(std::string_view data);
|
||||
std::vector<std::string> GetLastDelimitedStrings(char Delimiter, size_t Count) const;
|
||||
std::string Get() const;
|
||||
|
||||
private:
|
||||
std::pair<std::string_view, std::string_view> Contents() const;
|
||||
|
||||
mutable wil::srwlock m_lock;
|
||||
std::vector<char> m_buffer;
|
||||
size_t m_maxSize;
|
||||
size_t m_offset;
|
||||
/*++
|
||||
|
||||
Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
Module Name:
|
||||
|
||||
RingBuffer.h
|
||||
|
||||
Abstract:
|
||||
|
||||
This file contains declarations for the RingBuffer class.
|
||||
|
||||
--*/
|
||||
|
||||
#pragma once
|
||||
|
||||
class RingBuffer
|
||||
{
|
||||
public:
|
||||
RingBuffer() = delete;
|
||||
RingBuffer(size_t size);
|
||||
|
||||
void Insert(std::string_view data);
|
||||
std::vector<std::string> GetLastDelimitedStrings(char Delimiter, size_t Count) const;
|
||||
std::string Get() const;
|
||||
|
||||
private:
|
||||
std::pair<std::string_view, std::string_view> Contents() const;
|
||||
|
||||
mutable wil::srwlock m_lock;
|
||||
std::vector<char> m_buffer;
|
||||
size_t m_maxSize;
|
||||
size_t m_offset;
|
||||
};
|
||||
@ -1,104 +1,104 @@
|
||||
// Copyright (C) Microsoft Corporation. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <iptypes.h>
|
||||
#include <wil/registry.h>
|
||||
|
||||
#include "WslCoreNetworkingSupport.h"
|
||||
#include "RegistryWatcher.h"
|
||||
|
||||
namespace wsl::core::networking {
|
||||
struct DnsInfo
|
||||
{
|
||||
std::vector<std::string> Servers;
|
||||
std::vector<std::string> Domains;
|
||||
};
|
||||
|
||||
enum class DnsSettingsFlags
|
||||
{
|
||||
None = 0x0,
|
||||
IncludeVpn = 0x1,
|
||||
IncludeIpv6Servers = 0x2,
|
||||
IncludeAllSuffixes = 0x4
|
||||
};
|
||||
DEFINE_ENUM_FLAG_OPERATORS(DnsSettingsFlags);
|
||||
|
||||
inline bool operator==(const DnsInfo& lhs, const DnsInfo& rhs) noexcept
|
||||
{
|
||||
return lhs.Servers == rhs.Servers && lhs.Domains == rhs.Domains;
|
||||
}
|
||||
inline bool operator!=(const DnsInfo& lhs, const DnsInfo& rhs) noexcept
|
||||
{
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
std::string GenerateResolvConf(_In_ const DnsInfo& Info);
|
||||
|
||||
std::vector<std::string> GetAllDnsSuffixes(const std::vector<IpAdapterAddress>& AdapterAddresses);
|
||||
|
||||
DWORD GetBestInterface();
|
||||
|
||||
class HostDnsInfo
|
||||
{
|
||||
public:
|
||||
DnsInfo GetDnsSettings(_In_ DnsSettingsFlags Flags);
|
||||
|
||||
void UpdateNetworkInformation();
|
||||
|
||||
static DnsInfo GetDnsTunnelingSettings(const std::wstring& dnsTunnelingNameserver);
|
||||
|
||||
const std::vector<IpAdapterAddress>& CurrentAddresses() const
|
||||
{
|
||||
return m_addresses;
|
||||
}
|
||||
|
||||
private:
|
||||
/// <summary>
|
||||
/// Internal function to retrieve the latest copy of interface information.
|
||||
/// </summary>
|
||||
std::vector<IpAdapterAddress> GetAdapterAddresses();
|
||||
|
||||
/// <summary>
|
||||
/// Internal function to retrieve interface DNS servers.
|
||||
/// </summary>
|
||||
std::vector<std::string> GetInterfaceDnsServers(const std::vector<IpAdapterAddress>& AdapterAddresses, _In_ DnsSettingsFlags Flags);
|
||||
|
||||
/// <summary>
|
||||
/// Internal function to retrieve all Windows DNS suffixes.
|
||||
/// </summary>
|
||||
static std::vector<std::string> GetInterfaceDnsSuffixes(const std::vector<IpAdapterAddress>& AdapterAddresses);
|
||||
|
||||
/// <summary>
|
||||
/// Internal function to convert DNS server addresses into strings.
|
||||
/// </summary>
|
||||
static std::vector<std::string> GetDnsServerStrings(_In_ const PIP_ADAPTER_DNS_SERVER_ADDRESS& DnsServer, _In_ USHORT IpFamilyFilter, _In_ USHORT MaxValues);
|
||||
|
||||
/// <summary>
|
||||
/// Stores latest copy of interface information.
|
||||
/// </summary>
|
||||
std::mutex m_lock;
|
||||
_Guarded_by_(m_lock) std::vector<IpAdapterAddress> m_addresses;
|
||||
};
|
||||
|
||||
using RegistryChangeCallback = std::function<void()>;
|
||||
|
||||
/// <summary>
|
||||
/// Class used to get notifications when Windows DNS suffixes are updated in registry.
|
||||
/// </summary>
|
||||
class DnsSuffixRegistryWatcher
|
||||
{
|
||||
public:
|
||||
DnsSuffixRegistryWatcher(RegistryChangeCallback&& reportRegistryChange);
|
||||
~DnsSuffixRegistryWatcher() noexcept = default;
|
||||
|
||||
private:
|
||||
RegistryChangeCallback m_reportRegistryChange;
|
||||
|
||||
std::vector<wistd::unique_ptr<wsl::windows::common::slim_registry_watcher>> m_registryWatchers;
|
||||
};
|
||||
|
||||
// Copyright (C) Microsoft Corporation. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <iptypes.h>
|
||||
#include <wil/registry.h>
|
||||
|
||||
#include "WslCoreNetworkingSupport.h"
|
||||
#include "RegistryWatcher.h"
|
||||
|
||||
namespace wsl::core::networking {
|
||||
struct DnsInfo
|
||||
{
|
||||
std::vector<std::string> Servers;
|
||||
std::vector<std::string> Domains;
|
||||
};
|
||||
|
||||
enum class DnsSettingsFlags
|
||||
{
|
||||
None = 0x0,
|
||||
IncludeVpn = 0x1,
|
||||
IncludeIpv6Servers = 0x2,
|
||||
IncludeAllSuffixes = 0x4
|
||||
};
|
||||
DEFINE_ENUM_FLAG_OPERATORS(DnsSettingsFlags);
|
||||
|
||||
inline bool operator==(const DnsInfo& lhs, const DnsInfo& rhs) noexcept
|
||||
{
|
||||
return lhs.Servers == rhs.Servers && lhs.Domains == rhs.Domains;
|
||||
}
|
||||
inline bool operator!=(const DnsInfo& lhs, const DnsInfo& rhs) noexcept
|
||||
{
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
std::string GenerateResolvConf(_In_ const DnsInfo& Info);
|
||||
|
||||
std::vector<std::string> GetAllDnsSuffixes(const std::vector<IpAdapterAddress>& AdapterAddresses);
|
||||
|
||||
DWORD GetBestInterface();
|
||||
|
||||
class HostDnsInfo
|
||||
{
|
||||
public:
|
||||
DnsInfo GetDnsSettings(_In_ DnsSettingsFlags Flags);
|
||||
|
||||
void UpdateNetworkInformation();
|
||||
|
||||
static DnsInfo GetDnsTunnelingSettings(const std::wstring& dnsTunnelingNameserver);
|
||||
|
||||
const std::vector<IpAdapterAddress>& CurrentAddresses() const
|
||||
{
|
||||
return m_addresses;
|
||||
}
|
||||
|
||||
private:
|
||||
/// <summary>
|
||||
/// Internal function to retrieve the latest copy of interface information.
|
||||
/// </summary>
|
||||
std::vector<IpAdapterAddress> GetAdapterAddresses();
|
||||
|
||||
/// <summary>
|
||||
/// Internal function to retrieve interface DNS servers.
|
||||
/// </summary>
|
||||
std::vector<std::string> GetInterfaceDnsServers(const std::vector<IpAdapterAddress>& AdapterAddresses, _In_ DnsSettingsFlags Flags);
|
||||
|
||||
/// <summary>
|
||||
/// Internal function to retrieve all Windows DNS suffixes.
|
||||
/// </summary>
|
||||
static std::vector<std::string> GetInterfaceDnsSuffixes(const std::vector<IpAdapterAddress>& AdapterAddresses);
|
||||
|
||||
/// <summary>
|
||||
/// Internal function to convert DNS server addresses into strings.
|
||||
/// </summary>
|
||||
static std::vector<std::string> GetDnsServerStrings(_In_ const PIP_ADAPTER_DNS_SERVER_ADDRESS& DnsServer, _In_ USHORT IpFamilyFilter, _In_ USHORT MaxValues);
|
||||
|
||||
/// <summary>
|
||||
/// Stores latest copy of interface information.
|
||||
/// </summary>
|
||||
std::mutex m_lock;
|
||||
_Guarded_by_(m_lock) std::vector<IpAdapterAddress> m_addresses;
|
||||
};
|
||||
|
||||
using RegistryChangeCallback = std::function<void()>;
|
||||
|
||||
/// <summary>
|
||||
/// Class used to get notifications when Windows DNS suffixes are updated in registry.
|
||||
/// </summary>
|
||||
class DnsSuffixRegistryWatcher
|
||||
{
|
||||
public:
|
||||
DnsSuffixRegistryWatcher(RegistryChangeCallback&& reportRegistryChange);
|
||||
~DnsSuffixRegistryWatcher() noexcept = default;
|
||||
|
||||
private:
|
||||
RegistryChangeCallback m_reportRegistryChange;
|
||||
|
||||
std::vector<wistd::unique_ptr<wsl::windows::common::slim_registry_watcher>> m_registryWatchers;
|
||||
};
|
||||
|
||||
} // namespace wsl::core::networking
|
||||
@ -1,359 +1,359 @@
|
||||
/*++
|
||||
|
||||
Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
Module Name:
|
||||
|
||||
WslCoreMessageQueue.h
|
||||
|
||||
Abstract:
|
||||
|
||||
This file contains a queuing implementation, guaranteeing running function objects
|
||||
with guaranteed serialization in a threadpool thread
|
||||
|
||||
--*/
|
||||
|
||||
#pragma once
|
||||
#include <deque>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <variant>
|
||||
#include <windows.h>
|
||||
#include <wil/resource.h>
|
||||
|
||||
namespace wsl::core {
|
||||
// forward-declare classes that can instantiate a WslThreadPoolWaitableResult object
|
||||
class WslCoreMessageQueue;
|
||||
|
||||
class WslBaseThreadPoolWaitableResult
|
||||
{
|
||||
public:
|
||||
virtual ~WslBaseThreadPoolWaitableResult() noexcept = default;
|
||||
|
||||
private:
|
||||
// limit who can run() and abort()
|
||||
friend class WslCoreMessageQueue;
|
||||
|
||||
virtual void run() noexcept = 0;
|
||||
virtual void abort() noexcept = 0;
|
||||
};
|
||||
|
||||
template <typename TReturn>
|
||||
class WslThreadPoolWaitableResult : public WslBaseThreadPoolWaitableResult
|
||||
{
|
||||
public:
|
||||
// throws a wil exception on failure
|
||||
template <typename FunctorType>
|
||||
explicit WslThreadPoolWaitableResult(FunctorType&& functor) : m_function(std::forward<FunctorType>(functor))
|
||||
{
|
||||
}
|
||||
|
||||
~WslThreadPoolWaitableResult() noexcept override = default;
|
||||
|
||||
// returns ERROR_SUCCESS if the callback ran to completion
|
||||
// returns ERROR_TIMEOUT if this wait timed out
|
||||
// - this can be called multiple times if needing to probe
|
||||
// any other error code resulted from attempting to run the callback
|
||||
// - meaning it did *not* run to completion
|
||||
DWORD wait(DWORD timeout) const noexcept
|
||||
{
|
||||
if (!m_completionSignal.wait(timeout))
|
||||
{
|
||||
// not setting m_internalError to timeout
|
||||
// since the caller is allowed to try to wait() again later
|
||||
return ERROR_TIMEOUT;
|
||||
}
|
||||
const auto lock = m_lock.lock_shared();
|
||||
return m_internalError;
|
||||
}
|
||||
|
||||
// waitable event handle, signaled when the callback has run to completion (or failed)
|
||||
HANDLE notification_event() const noexcept
|
||||
{
|
||||
return m_completionSignal.get();
|
||||
}
|
||||
|
||||
const TReturn& read_result() const noexcept
|
||||
{
|
||||
return result;
|
||||
}
|
||||
|
||||
// move the result out of the object for move-only types
|
||||
TReturn move_result() noexcept
|
||||
{
|
||||
TReturn move_out(std::move(result));
|
||||
return move_out;
|
||||
}
|
||||
|
||||
// non-copyable
|
||||
WslThreadPoolWaitableResult(const WslThreadPoolWaitableResult&) = delete;
|
||||
WslThreadPoolWaitableResult& operator=(const WslThreadPoolWaitableResult&) = delete;
|
||||
|
||||
private:
|
||||
void run() noexcept override
|
||||
{
|
||||
// we are now running in the TP callback
|
||||
{
|
||||
const auto lock = m_lock.lock_exclusive();
|
||||
if (m_runStatus != RunStatus::NotYetRun)
|
||||
{
|
||||
// return early - the caller has already canceled this
|
||||
return;
|
||||
}
|
||||
m_runStatus = RunStatus::Running;
|
||||
}
|
||||
|
||||
DWORD error = NO_ERROR;
|
||||
try
|
||||
{
|
||||
result = std::move(m_function());
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
const HRESULT hr = wil::ResultFromCaughtException();
|
||||
// HRESULT_TO_WIN32
|
||||
error = (HRESULT_FACILITY(hr) == FACILITY_WIN32) ? HRESULT_CODE(hr) : hr;
|
||||
}
|
||||
|
||||
const auto lock = m_lock.lock_exclusive();
|
||||
WI_ASSERT(m_runStatus == RunStatus::Running);
|
||||
m_runStatus = RunStatus::RanToCompletion;
|
||||
m_internalError = error;
|
||||
m_completionSignal.SetEvent();
|
||||
}
|
||||
|
||||
void abort() noexcept override
|
||||
{
|
||||
const auto lock = m_lock.lock_exclusive();
|
||||
// only override the error if we know we haven't started running their functor
|
||||
if (m_runStatus == RunStatus::NotYetRun)
|
||||
{
|
||||
m_runStatus = RunStatus::Canceled;
|
||||
m_internalError = ERROR_CANCELLED;
|
||||
m_completionSignal.SetEvent();
|
||||
}
|
||||
}
|
||||
|
||||
std::function<TReturn(void)> m_function;
|
||||
// a notification event
|
||||
wil::unique_event m_completionSignal{wil::EventOptions::ManualReset};
|
||||
mutable wil::srwlock m_lock;
|
||||
TReturn result{};
|
||||
DWORD m_internalError = NO_ERROR;
|
||||
|
||||
enum class RunStatus
|
||||
{
|
||||
NotYetRun,
|
||||
Running,
|
||||
RanToCompletion,
|
||||
Canceled
|
||||
} m_runStatus{RunStatus::NotYetRun};
|
||||
};
|
||||
|
||||
class WslCoreMessageQueue
|
||||
{
|
||||
public:
|
||||
WslCoreMessageQueue() : m_tpEnvironment(0, 1)
|
||||
{
|
||||
// create a single-threaded threadpool
|
||||
m_tpHandle = m_tpEnvironment.create_tp(WorkCallback, this);
|
||||
}
|
||||
|
||||
template <typename TReturn, typename FunctorType>
|
||||
std::shared_ptr<WslThreadPoolWaitableResult<TReturn>> submit_with_results(FunctorType&& functor) noexcept
|
||||
try
|
||||
{
|
||||
FAIL_FAST_IF(m_tpHandle.get() == nullptr);
|
||||
|
||||
const auto new_result = std::make_shared<WslThreadPoolWaitableResult<TReturn>>(std::forward<FunctorType>(functor));
|
||||
// scope to the queue lock
|
||||
{
|
||||
const auto queueLock = m_lock.lock_exclusive();
|
||||
THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_CANCELLED), m_isCanceled);
|
||||
m_workItems.emplace_back(new_result);
|
||||
}
|
||||
|
||||
// always maintain a 1:1 ratio for calls to SubmitWorkWithResults() and ::SubmitThreadpoolWork
|
||||
SubmitThreadpoolWork(m_tpHandle.get());
|
||||
return new_result;
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
LOG_CAUGHT_EXCEPTION();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <typename FunctorType>
|
||||
bool submit(FunctorType&& functor) noexcept
|
||||
try
|
||||
{
|
||||
FAIL_FAST_IF(m_tpHandle.get() == nullptr);
|
||||
|
||||
// scope to the queue lock
|
||||
{
|
||||
const auto queueLock = m_lock.lock_exclusive();
|
||||
THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_CANCELLED), m_isCanceled);
|
||||
m_workItems.emplace_back(std::forward<SimpleFunction_t>(functor));
|
||||
}
|
||||
|
||||
// always maintain a 1:1 ratio for calls to SubmitWork() and ::SubmitThreadpoolWork
|
||||
SubmitThreadpoolWork(m_tpHandle.get());
|
||||
return true;
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
LOG_CAUGHT_EXCEPTION();
|
||||
return false;
|
||||
}
|
||||
|
||||
// functors must return type HRESULT
|
||||
template <typename FunctorType>
|
||||
HRESULT submit_and_wait(FunctorType&& functor) noexcept
|
||||
try
|
||||
{
|
||||
HRESULT hr = HRESULT_FROM_WIN32(ERROR_OUTOFMEMORY);
|
||||
if (const auto waitableResult = submit_with_results<HRESULT>(std::forward<FunctorType>(functor)))
|
||||
{
|
||||
hr = HRESULT_FROM_WIN32(waitableResult->wait(INFINITE));
|
||||
if (SUCCEEDED(hr))
|
||||
{
|
||||
hr = waitableResult->read_result();
|
||||
}
|
||||
}
|
||||
return hr;
|
||||
}
|
||||
CATCH_RETURN()
|
||||
|
||||
// cancels anything queued to the TP - this WslCoreMessageQueue instance can no longer be used
|
||||
void cancel() noexcept
|
||||
try
|
||||
{
|
||||
if (m_tpHandle)
|
||||
{
|
||||
// immediately release anyone waiting for these workitems not yet run
|
||||
{
|
||||
const auto queueLock = m_lock.lock_exclusive();
|
||||
m_isCanceled = true;
|
||||
|
||||
for (const auto& work : m_workItems)
|
||||
{
|
||||
// signal that these are canceled before we shutdown the TP which they could be scheduled
|
||||
if (const auto* pWaitableWorkitem = std::get_if<WaitableFunction_t>(&work))
|
||||
{
|
||||
(*pWaitableWorkitem)->abort();
|
||||
}
|
||||
}
|
||||
|
||||
m_workItems.clear();
|
||||
}
|
||||
|
||||
// force the m_tpHandle to wait and close the TP
|
||||
m_tpHandle.reset();
|
||||
m_tpEnvironment.reset();
|
||||
}
|
||||
}
|
||||
CATCH_LOG()
|
||||
|
||||
bool isRunningInQueue() const noexcept
|
||||
{
|
||||
const auto currentThreadId = GetThreadId(GetCurrentThread());
|
||||
return currentThreadId == static_cast<DWORD>(InterlockedCompareExchange64(&m_threadpoolThreadId, 0ll, 0ll));
|
||||
}
|
||||
|
||||
~WslCoreMessageQueue() noexcept
|
||||
{
|
||||
cancel();
|
||||
}
|
||||
|
||||
WslCoreMessageQueue(const WslCoreMessageQueue&) = delete;
|
||||
WslCoreMessageQueue& operator=(const WslCoreMessageQueue&) = delete;
|
||||
WslCoreMessageQueue(WslCoreMessageQueue&&) = delete;
|
||||
WslCoreMessageQueue& operator=(WslCoreMessageQueue&&) = delete;
|
||||
|
||||
private:
|
||||
struct TPEnvironment
|
||||
{
|
||||
using unique_tp_env = wil::unique_struct<TP_CALLBACK_ENVIRON, decltype(&DestroyThreadpoolEnvironment), DestroyThreadpoolEnvironment>;
|
||||
unique_tp_env m_tpEnvironment;
|
||||
|
||||
using unique_tp_pool = wil::unique_any<PTP_POOL, decltype(&CloseThreadpool), CloseThreadpool>;
|
||||
unique_tp_pool m_threadPool;
|
||||
|
||||
TPEnvironment(DWORD countMinThread, DWORD countMaxThread)
|
||||
{
|
||||
InitializeThreadpoolEnvironment(&m_tpEnvironment);
|
||||
|
||||
m_threadPool.reset(CreateThreadpool(nullptr));
|
||||
THROW_LAST_ERROR_IF_NULL(m_threadPool.get());
|
||||
|
||||
// Set min and max thread counts for custom thread pool
|
||||
THROW_LAST_ERROR_IF(!::SetThreadpoolThreadMinimum(m_threadPool.get(), countMinThread));
|
||||
SetThreadpoolThreadMaximum(m_threadPool.get(), countMaxThread);
|
||||
SetThreadpoolCallbackPool(&m_tpEnvironment, m_threadPool.get());
|
||||
}
|
||||
|
||||
wil::unique_threadpool_work create_tp(PTP_WORK_CALLBACK callback, void* pv)
|
||||
{
|
||||
wil::unique_threadpool_work newThreadpool(CreateThreadpoolWork(callback, pv, (m_threadPool) ? &m_tpEnvironment : nullptr));
|
||||
THROW_LAST_ERROR_IF_NULL(newThreadpool.get());
|
||||
return newThreadpool;
|
||||
}
|
||||
|
||||
void reset()
|
||||
{
|
||||
m_threadPool.reset();
|
||||
m_tpEnvironment.reset();
|
||||
}
|
||||
};
|
||||
|
||||
using SimpleFunction_t = std::function<void()>;
|
||||
using WaitableFunction_t = std::shared_ptr<WslBaseThreadPoolWaitableResult>;
|
||||
using FunctionVariant_t = std::variant<SimpleFunction_t, WaitableFunction_t>;
|
||||
|
||||
// the lock must be destroyed *after* the TP object (thus must be declared first)
|
||||
// since the lock is used in the TP callback
|
||||
// the lock is mutable to allow us to acquire the lock in const methods
|
||||
mutable wil::srwlock m_lock;
|
||||
TPEnvironment m_tpEnvironment;
|
||||
wil::unique_threadpool_work m_tpHandle;
|
||||
std::deque<FunctionVariant_t> m_workItems;
|
||||
mutable LONG64 m_threadpoolThreadId{0}; // useful for callers to assert they are running within the queue
|
||||
bool m_isCanceled{false};
|
||||
|
||||
static void CALLBACK WorkCallback(PTP_CALLBACK_INSTANCE, void* Context, PTP_WORK) noexcept
|
||||
try
|
||||
{
|
||||
auto* pThis = static_cast<WslCoreMessageQueue*>(Context);
|
||||
|
||||
FunctionVariant_t work;
|
||||
{
|
||||
const auto queueLock = pThis->m_lock.lock_exclusive();
|
||||
|
||||
if (pThis->m_workItems.empty())
|
||||
{
|
||||
// pThis object is being destroyed and the queue was cleared
|
||||
return;
|
||||
}
|
||||
|
||||
std::swap(work, pThis->m_workItems.front());
|
||||
pThis->m_workItems.pop_front();
|
||||
|
||||
InterlockedExchange64(&pThis->m_threadpoolThreadId, GetThreadId(GetCurrentThread()));
|
||||
}
|
||||
|
||||
// run the tasks outside the WslCoreMessageQueue lock
|
||||
const auto resetThreadIdOnExit = wil::scope_exit([pThis] { InterlockedExchange64(&pThis->m_threadpoolThreadId, 0ll); });
|
||||
if (work.index() == 0)
|
||||
{
|
||||
const auto& workItem = std::get<SimpleFunction_t>(work);
|
||||
workItem();
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto& waitableWorkItem = std::get<WaitableFunction_t>(work);
|
||||
waitableWorkItem->run();
|
||||
}
|
||||
}
|
||||
CATCH_LOG()
|
||||
};
|
||||
} // namespace wsl::core
|
||||
/*++
|
||||
|
||||
Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
Module Name:
|
||||
|
||||
WslCoreMessageQueue.h
|
||||
|
||||
Abstract:
|
||||
|
||||
This file contains a queuing implementation, guaranteeing running function objects
|
||||
with guaranteed serialization in a threadpool thread
|
||||
|
||||
--*/
|
||||
|
||||
#pragma once
|
||||
#include <deque>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <variant>
|
||||
#include <windows.h>
|
||||
#include <wil/resource.h>
|
||||
|
||||
namespace wsl::core {
|
||||
// forward-declare classes that can instantiate a WslThreadPoolWaitableResult object
|
||||
class WslCoreMessageQueue;
|
||||
|
||||
class WslBaseThreadPoolWaitableResult
|
||||
{
|
||||
public:
|
||||
virtual ~WslBaseThreadPoolWaitableResult() noexcept = default;
|
||||
|
||||
private:
|
||||
// limit who can run() and abort()
|
||||
friend class WslCoreMessageQueue;
|
||||
|
||||
virtual void run() noexcept = 0;
|
||||
virtual void abort() noexcept = 0;
|
||||
};
|
||||
|
||||
template <typename TReturn>
|
||||
class WslThreadPoolWaitableResult : public WslBaseThreadPoolWaitableResult
|
||||
{
|
||||
public:
|
||||
// throws a wil exception on failure
|
||||
template <typename FunctorType>
|
||||
explicit WslThreadPoolWaitableResult(FunctorType&& functor) : m_function(std::forward<FunctorType>(functor))
|
||||
{
|
||||
}
|
||||
|
||||
~WslThreadPoolWaitableResult() noexcept override = default;
|
||||
|
||||
// returns ERROR_SUCCESS if the callback ran to completion
|
||||
// returns ERROR_TIMEOUT if this wait timed out
|
||||
// - this can be called multiple times if needing to probe
|
||||
// any other error code resulted from attempting to run the callback
|
||||
// - meaning it did *not* run to completion
|
||||
DWORD wait(DWORD timeout) const noexcept
|
||||
{
|
||||
if (!m_completionSignal.wait(timeout))
|
||||
{
|
||||
// not setting m_internalError to timeout
|
||||
// since the caller is allowed to try to wait() again later
|
||||
return ERROR_TIMEOUT;
|
||||
}
|
||||
const auto lock = m_lock.lock_shared();
|
||||
return m_internalError;
|
||||
}
|
||||
|
||||
// waitable event handle, signaled when the callback has run to completion (or failed)
|
||||
HANDLE notification_event() const noexcept
|
||||
{
|
||||
return m_completionSignal.get();
|
||||
}
|
||||
|
||||
const TReturn& read_result() const noexcept
|
||||
{
|
||||
return result;
|
||||
}
|
||||
|
||||
// move the result out of the object for move-only types
|
||||
TReturn move_result() noexcept
|
||||
{
|
||||
TReturn move_out(std::move(result));
|
||||
return move_out;
|
||||
}
|
||||
|
||||
// non-copyable
|
||||
WslThreadPoolWaitableResult(const WslThreadPoolWaitableResult&) = delete;
|
||||
WslThreadPoolWaitableResult& operator=(const WslThreadPoolWaitableResult&) = delete;
|
||||
|
||||
private:
|
||||
void run() noexcept override
|
||||
{
|
||||
// we are now running in the TP callback
|
||||
{
|
||||
const auto lock = m_lock.lock_exclusive();
|
||||
if (m_runStatus != RunStatus::NotYetRun)
|
||||
{
|
||||
// return early - the caller has already canceled this
|
||||
return;
|
||||
}
|
||||
m_runStatus = RunStatus::Running;
|
||||
}
|
||||
|
||||
DWORD error = NO_ERROR;
|
||||
try
|
||||
{
|
||||
result = std::move(m_function());
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
const HRESULT hr = wil::ResultFromCaughtException();
|
||||
// HRESULT_TO_WIN32
|
||||
error = (HRESULT_FACILITY(hr) == FACILITY_WIN32) ? HRESULT_CODE(hr) : hr;
|
||||
}
|
||||
|
||||
const auto lock = m_lock.lock_exclusive();
|
||||
WI_ASSERT(m_runStatus == RunStatus::Running);
|
||||
m_runStatus = RunStatus::RanToCompletion;
|
||||
m_internalError = error;
|
||||
m_completionSignal.SetEvent();
|
||||
}
|
||||
|
||||
void abort() noexcept override
|
||||
{
|
||||
const auto lock = m_lock.lock_exclusive();
|
||||
// only override the error if we know we haven't started running their functor
|
||||
if (m_runStatus == RunStatus::NotYetRun)
|
||||
{
|
||||
m_runStatus = RunStatus::Canceled;
|
||||
m_internalError = ERROR_CANCELLED;
|
||||
m_completionSignal.SetEvent();
|
||||
}
|
||||
}
|
||||
|
||||
std::function<TReturn(void)> m_function;
|
||||
// a notification event
|
||||
wil::unique_event m_completionSignal{wil::EventOptions::ManualReset};
|
||||
mutable wil::srwlock m_lock;
|
||||
TReturn result{};
|
||||
DWORD m_internalError = NO_ERROR;
|
||||
|
||||
enum class RunStatus
|
||||
{
|
||||
NotYetRun,
|
||||
Running,
|
||||
RanToCompletion,
|
||||
Canceled
|
||||
} m_runStatus{RunStatus::NotYetRun};
|
||||
};
|
||||
|
||||
class WslCoreMessageQueue
|
||||
{
|
||||
public:
|
||||
WslCoreMessageQueue() : m_tpEnvironment(0, 1)
|
||||
{
|
||||
// create a single-threaded threadpool
|
||||
m_tpHandle = m_tpEnvironment.create_tp(WorkCallback, this);
|
||||
}
|
||||
|
||||
template <typename TReturn, typename FunctorType>
|
||||
std::shared_ptr<WslThreadPoolWaitableResult<TReturn>> submit_with_results(FunctorType&& functor) noexcept
|
||||
try
|
||||
{
|
||||
FAIL_FAST_IF(m_tpHandle.get() == nullptr);
|
||||
|
||||
const auto new_result = std::make_shared<WslThreadPoolWaitableResult<TReturn>>(std::forward<FunctorType>(functor));
|
||||
// scope to the queue lock
|
||||
{
|
||||
const auto queueLock = m_lock.lock_exclusive();
|
||||
THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_CANCELLED), m_isCanceled);
|
||||
m_workItems.emplace_back(new_result);
|
||||
}
|
||||
|
||||
// always maintain a 1:1 ratio for calls to SubmitWorkWithResults() and ::SubmitThreadpoolWork
|
||||
SubmitThreadpoolWork(m_tpHandle.get());
|
||||
return new_result;
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
LOG_CAUGHT_EXCEPTION();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <typename FunctorType>
|
||||
bool submit(FunctorType&& functor) noexcept
|
||||
try
|
||||
{
|
||||
FAIL_FAST_IF(m_tpHandle.get() == nullptr);
|
||||
|
||||
// scope to the queue lock
|
||||
{
|
||||
const auto queueLock = m_lock.lock_exclusive();
|
||||
THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_CANCELLED), m_isCanceled);
|
||||
m_workItems.emplace_back(std::forward<SimpleFunction_t>(functor));
|
||||
}
|
||||
|
||||
// always maintain a 1:1 ratio for calls to SubmitWork() and ::SubmitThreadpoolWork
|
||||
SubmitThreadpoolWork(m_tpHandle.get());
|
||||
return true;
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
LOG_CAUGHT_EXCEPTION();
|
||||
return false;
|
||||
}
|
||||
|
||||
// functors must return type HRESULT
|
||||
template <typename FunctorType>
|
||||
HRESULT submit_and_wait(FunctorType&& functor) noexcept
|
||||
try
|
||||
{
|
||||
HRESULT hr = HRESULT_FROM_WIN32(ERROR_OUTOFMEMORY);
|
||||
if (const auto waitableResult = submit_with_results<HRESULT>(std::forward<FunctorType>(functor)))
|
||||
{
|
||||
hr = HRESULT_FROM_WIN32(waitableResult->wait(INFINITE));
|
||||
if (SUCCEEDED(hr))
|
||||
{
|
||||
hr = waitableResult->read_result();
|
||||
}
|
||||
}
|
||||
return hr;
|
||||
}
|
||||
CATCH_RETURN()
|
||||
|
||||
// cancels anything queued to the TP - this WslCoreMessageQueue instance can no longer be used
|
||||
void cancel() noexcept
|
||||
try
|
||||
{
|
||||
if (m_tpHandle)
|
||||
{
|
||||
// immediately release anyone waiting for these workitems not yet run
|
||||
{
|
||||
const auto queueLock = m_lock.lock_exclusive();
|
||||
m_isCanceled = true;
|
||||
|
||||
for (const auto& work : m_workItems)
|
||||
{
|
||||
// signal that these are canceled before we shutdown the TP which they could be scheduled
|
||||
if (const auto* pWaitableWorkitem = std::get_if<WaitableFunction_t>(&work))
|
||||
{
|
||||
(*pWaitableWorkitem)->abort();
|
||||
}
|
||||
}
|
||||
|
||||
m_workItems.clear();
|
||||
}
|
||||
|
||||
// force the m_tpHandle to wait and close the TP
|
||||
m_tpHandle.reset();
|
||||
m_tpEnvironment.reset();
|
||||
}
|
||||
}
|
||||
CATCH_LOG()
|
||||
|
||||
bool isRunningInQueue() const noexcept
|
||||
{
|
||||
const auto currentThreadId = GetThreadId(GetCurrentThread());
|
||||
return currentThreadId == static_cast<DWORD>(InterlockedCompareExchange64(&m_threadpoolThreadId, 0ll, 0ll));
|
||||
}
|
||||
|
||||
~WslCoreMessageQueue() noexcept
|
||||
{
|
||||
cancel();
|
||||
}
|
||||
|
||||
WslCoreMessageQueue(const WslCoreMessageQueue&) = delete;
|
||||
WslCoreMessageQueue& operator=(const WslCoreMessageQueue&) = delete;
|
||||
WslCoreMessageQueue(WslCoreMessageQueue&&) = delete;
|
||||
WslCoreMessageQueue& operator=(WslCoreMessageQueue&&) = delete;
|
||||
|
||||
private:
|
||||
struct TPEnvironment
|
||||
{
|
||||
using unique_tp_env = wil::unique_struct<TP_CALLBACK_ENVIRON, decltype(&DestroyThreadpoolEnvironment), DestroyThreadpoolEnvironment>;
|
||||
unique_tp_env m_tpEnvironment;
|
||||
|
||||
using unique_tp_pool = wil::unique_any<PTP_POOL, decltype(&CloseThreadpool), CloseThreadpool>;
|
||||
unique_tp_pool m_threadPool;
|
||||
|
||||
TPEnvironment(DWORD countMinThread, DWORD countMaxThread)
|
||||
{
|
||||
InitializeThreadpoolEnvironment(&m_tpEnvironment);
|
||||
|
||||
m_threadPool.reset(CreateThreadpool(nullptr));
|
||||
THROW_LAST_ERROR_IF_NULL(m_threadPool.get());
|
||||
|
||||
// Set min and max thread counts for custom thread pool
|
||||
THROW_LAST_ERROR_IF(!::SetThreadpoolThreadMinimum(m_threadPool.get(), countMinThread));
|
||||
SetThreadpoolThreadMaximum(m_threadPool.get(), countMaxThread);
|
||||
SetThreadpoolCallbackPool(&m_tpEnvironment, m_threadPool.get());
|
||||
}
|
||||
|
||||
wil::unique_threadpool_work create_tp(PTP_WORK_CALLBACK callback, void* pv)
|
||||
{
|
||||
wil::unique_threadpool_work newThreadpool(CreateThreadpoolWork(callback, pv, (m_threadPool) ? &m_tpEnvironment : nullptr));
|
||||
THROW_LAST_ERROR_IF_NULL(newThreadpool.get());
|
||||
return newThreadpool;
|
||||
}
|
||||
|
||||
void reset()
|
||||
{
|
||||
m_threadPool.reset();
|
||||
m_tpEnvironment.reset();
|
||||
}
|
||||
};
|
||||
|
||||
using SimpleFunction_t = std::function<void()>;
|
||||
using WaitableFunction_t = std::shared_ptr<WslBaseThreadPoolWaitableResult>;
|
||||
using FunctionVariant_t = std::variant<SimpleFunction_t, WaitableFunction_t>;
|
||||
|
||||
// the lock must be destroyed *after* the TP object (thus must be declared first)
|
||||
// since the lock is used in the TP callback
|
||||
// the lock is mutable to allow us to acquire the lock in const methods
|
||||
mutable wil::srwlock m_lock;
|
||||
TPEnvironment m_tpEnvironment;
|
||||
wil::unique_threadpool_work m_tpHandle;
|
||||
std::deque<FunctionVariant_t> m_workItems;
|
||||
mutable LONG64 m_threadpoolThreadId{0}; // useful for callers to assert they are running within the queue
|
||||
bool m_isCanceled{false};
|
||||
|
||||
static void CALLBACK WorkCallback(PTP_CALLBACK_INSTANCE, void* Context, PTP_WORK) noexcept
|
||||
try
|
||||
{
|
||||
auto* pThis = static_cast<WslCoreMessageQueue*>(Context);
|
||||
|
||||
FunctionVariant_t work;
|
||||
{
|
||||
const auto queueLock = pThis->m_lock.lock_exclusive();
|
||||
|
||||
if (pThis->m_workItems.empty())
|
||||
{
|
||||
// pThis object is being destroyed and the queue was cleared
|
||||
return;
|
||||
}
|
||||
|
||||
std::swap(work, pThis->m_workItems.front());
|
||||
pThis->m_workItems.pop_front();
|
||||
|
||||
InterlockedExchange64(&pThis->m_threadpoolThreadId, GetThreadId(GetCurrentThread()));
|
||||
}
|
||||
|
||||
// run the tasks outside the WslCoreMessageQueue lock
|
||||
const auto resetThreadIdOnExit = wil::scope_exit([pThis] { InterlockedExchange64(&pThis->m_threadpoolThreadId, 0ll); });
|
||||
if (work.index() == 0)
|
||||
{
|
||||
const auto& workItem = std::get<SimpleFunction_t>(work);
|
||||
workItem();
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto& waitableWorkItem = std::get<WaitableFunction_t>(work);
|
||||
waitableWorkItem->run();
|
||||
}
|
||||
}
|
||||
CATCH_LOG()
|
||||
};
|
||||
} // namespace wsl::core
|
||||
@ -1,110 +1,110 @@
|
||||
// Copyright (C) Microsoft Corporation. All rights reserved.
|
||||
|
||||
#include "precomp.h"
|
||||
#include "hns_schema.h"
|
||||
#include "WslCoreNetworkEndpointSettings.h"
|
||||
#include "WslCoreHostDnsInfo.h"
|
||||
|
||||
using namespace wsl::shared;
|
||||
|
||||
std::shared_ptr<wsl::core::networking::NetworkSettings> wsl::core::networking::GetEndpointSettings(const hns::HNSEndpoint& properties)
|
||||
{
|
||||
EndpointIpAddress address{};
|
||||
address.Address = windows::common::string::StringToSockAddrInet(properties.IPAddress);
|
||||
address.AddressString = properties.IPAddress;
|
||||
address.PrefixLength = properties.PrefixLength;
|
||||
|
||||
EndpointRoute route{};
|
||||
route.DestinationPrefix.PrefixLength = 0;
|
||||
IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4);
|
||||
route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS;
|
||||
route.NextHop = windows::common::string::StringToSockAddrInet(properties.GatewayAddress);
|
||||
route.NextHopString = properties.GatewayAddress;
|
||||
|
||||
return std::make_shared<wsl::core::networking::NetworkSettings>(
|
||||
properties.InterfaceConstraint.InterfaceGuid,
|
||||
address,
|
||||
route,
|
||||
properties.MacAddress,
|
||||
L"unuseddevicename",
|
||||
properties.InterfaceConstraint.InterfaceIndex,
|
||||
properties.InterfaceConstraint.InterfaceMediaType,
|
||||
properties.DNSServerList);
|
||||
}
|
||||
|
||||
std::shared_ptr<wsl::core::networking::NetworkSettings> wsl::core::networking::GetHostEndpointSettings()
|
||||
{
|
||||
HostDnsInfo dnsInfo;
|
||||
dnsInfo.UpdateNetworkInformation();
|
||||
auto addresses = dnsInfo.CurrentAddresses();
|
||||
auto bestIndex = GetBestInterface();
|
||||
auto bestInterfacePtr =
|
||||
std::find_if(addresses.cbegin(), addresses.cend(), [&](const auto& address) { return address->IfIndex == bestIndex; });
|
||||
if (bestInterfacePtr == addresses.end())
|
||||
{
|
||||
return std::make_shared<NetworkSettings>();
|
||||
}
|
||||
|
||||
const auto& bestInterface = *bestInterfacePtr;
|
||||
|
||||
std::wstring macAddress = wsl::shared::string::FormatMacAddress(
|
||||
wsl::shared::string::MacAddress{
|
||||
bestInterface->PhysicalAddress[0],
|
||||
bestInterface->PhysicalAddress[1],
|
||||
bestInterface->PhysicalAddress[2],
|
||||
bestInterface->PhysicalAddress[3],
|
||||
bestInterface->PhysicalAddress[4],
|
||||
bestInterface->PhysicalAddress[5]},
|
||||
L'-');
|
||||
|
||||
EndpointIpAddress address{};
|
||||
auto firstIpv4Address = bestInterface->FirstUnicastAddress;
|
||||
while (firstIpv4Address && firstIpv4Address->Address.lpSockaddr->sa_family != AF_INET)
|
||||
{
|
||||
firstIpv4Address = firstIpv4Address->Next;
|
||||
}
|
||||
if (firstIpv4Address)
|
||||
{
|
||||
address.Address = *reinterpret_cast<SOCKADDR_INET*>(firstIpv4Address->Address.lpSockaddr);
|
||||
address.AddressString = windows::common::string::SockAddrInetToWstring(address.Address);
|
||||
address.PrefixLength = firstIpv4Address->OnLinkPrefixLength;
|
||||
}
|
||||
|
||||
EndpointRoute route{};
|
||||
PIP_ADAPTER_GATEWAY_ADDRESS nextGatewayAddress = bestInterface->FirstGatewayAddress;
|
||||
while (nextGatewayAddress && nextGatewayAddress->Address.lpSockaddr->sa_family != AF_INET)
|
||||
{
|
||||
nextGatewayAddress = nextGatewayAddress->Next;
|
||||
}
|
||||
if (nextGatewayAddress)
|
||||
{
|
||||
route.DestinationPrefix.PrefixLength = 0;
|
||||
IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4);
|
||||
route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS;
|
||||
route.NextHop = *reinterpret_cast<SOCKADDR_INET*>(nextGatewayAddress->Address.lpSockaddr);
|
||||
route.NextHopString = windows::common::string::SockAddrInetToWstring(route.NextHop);
|
||||
}
|
||||
else if (address.Address.si_family == AF_INET)
|
||||
{
|
||||
IN_ADDR default_route{};
|
||||
default_route.s_addr = htonl((ntohl(address.Address.Ipv4.sin_addr.s_addr) & ~((1 << (32 - address.PrefixLength)) - 1)) | 1);
|
||||
route.DestinationPrefix.PrefixLength = 0;
|
||||
IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4);
|
||||
route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS;
|
||||
IN4ADDR_SETSOCKADDR(&route.NextHop.Ipv4, &default_route, 0);
|
||||
route.NextHopString = windows::common::string::SockAddrInetToWstring(route.NextHop);
|
||||
}
|
||||
|
||||
std::wstring dnsServerList;
|
||||
for (const auto& serverAddress : dnsInfo.GetDnsSettings(DnsSettingsFlags::IncludeVpn).Servers)
|
||||
{
|
||||
if (!dnsServerList.empty())
|
||||
{
|
||||
dnsServerList += L",";
|
||||
}
|
||||
dnsServerList += wsl::shared::string::MultiByteToWide(serverAddress);
|
||||
}
|
||||
|
||||
return std::shared_ptr<NetworkSettings>(new NetworkSettings(
|
||||
bestInterface->NetworkGuid, address, route, macAddress, {}, bestInterface->IfIndex, bestInterface->IfType, dnsServerList));
|
||||
}
|
||||
// Copyright (C) Microsoft Corporation. All rights reserved.
|
||||
|
||||
#include "precomp.h"
|
||||
#include "hns_schema.h"
|
||||
#include "WslCoreNetworkEndpointSettings.h"
|
||||
#include "WslCoreHostDnsInfo.h"
|
||||
|
||||
using namespace wsl::shared;
|
||||
|
||||
std::shared_ptr<wsl::core::networking::NetworkSettings> wsl::core::networking::GetEndpointSettings(const hns::HNSEndpoint& properties)
|
||||
{
|
||||
EndpointIpAddress address{};
|
||||
address.Address = windows::common::string::StringToSockAddrInet(properties.IPAddress);
|
||||
address.AddressString = properties.IPAddress;
|
||||
address.PrefixLength = properties.PrefixLength;
|
||||
|
||||
EndpointRoute route{};
|
||||
route.DestinationPrefix.PrefixLength = 0;
|
||||
IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4);
|
||||
route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS;
|
||||
route.NextHop = windows::common::string::StringToSockAddrInet(properties.GatewayAddress);
|
||||
route.NextHopString = properties.GatewayAddress;
|
||||
|
||||
return std::make_shared<wsl::core::networking::NetworkSettings>(
|
||||
properties.InterfaceConstraint.InterfaceGuid,
|
||||
address,
|
||||
route,
|
||||
properties.MacAddress,
|
||||
L"unuseddevicename",
|
||||
properties.InterfaceConstraint.InterfaceIndex,
|
||||
properties.InterfaceConstraint.InterfaceMediaType,
|
||||
properties.DNSServerList);
|
||||
}
|
||||
|
||||
std::shared_ptr<wsl::core::networking::NetworkSettings> wsl::core::networking::GetHostEndpointSettings()
|
||||
{
|
||||
HostDnsInfo dnsInfo;
|
||||
dnsInfo.UpdateNetworkInformation();
|
||||
auto addresses = dnsInfo.CurrentAddresses();
|
||||
auto bestIndex = GetBestInterface();
|
||||
auto bestInterfacePtr =
|
||||
std::find_if(addresses.cbegin(), addresses.cend(), [&](const auto& address) { return address->IfIndex == bestIndex; });
|
||||
if (bestInterfacePtr == addresses.end())
|
||||
{
|
||||
return std::make_shared<NetworkSettings>();
|
||||
}
|
||||
|
||||
const auto& bestInterface = *bestInterfacePtr;
|
||||
|
||||
std::wstring macAddress = wsl::shared::string::FormatMacAddress(
|
||||
wsl::shared::string::MacAddress{
|
||||
bestInterface->PhysicalAddress[0],
|
||||
bestInterface->PhysicalAddress[1],
|
||||
bestInterface->PhysicalAddress[2],
|
||||
bestInterface->PhysicalAddress[3],
|
||||
bestInterface->PhysicalAddress[4],
|
||||
bestInterface->PhysicalAddress[5]},
|
||||
L'-');
|
||||
|
||||
EndpointIpAddress address{};
|
||||
auto firstIpv4Address = bestInterface->FirstUnicastAddress;
|
||||
while (firstIpv4Address && firstIpv4Address->Address.lpSockaddr->sa_family != AF_INET)
|
||||
{
|
||||
firstIpv4Address = firstIpv4Address->Next;
|
||||
}
|
||||
if (firstIpv4Address)
|
||||
{
|
||||
address.Address = *reinterpret_cast<SOCKADDR_INET*>(firstIpv4Address->Address.lpSockaddr);
|
||||
address.AddressString = windows::common::string::SockAddrInetToWstring(address.Address);
|
||||
address.PrefixLength = firstIpv4Address->OnLinkPrefixLength;
|
||||
}
|
||||
|
||||
EndpointRoute route{};
|
||||
PIP_ADAPTER_GATEWAY_ADDRESS nextGatewayAddress = bestInterface->FirstGatewayAddress;
|
||||
while (nextGatewayAddress && nextGatewayAddress->Address.lpSockaddr->sa_family != AF_INET)
|
||||
{
|
||||
nextGatewayAddress = nextGatewayAddress->Next;
|
||||
}
|
||||
if (nextGatewayAddress)
|
||||
{
|
||||
route.DestinationPrefix.PrefixLength = 0;
|
||||
IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4);
|
||||
route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS;
|
||||
route.NextHop = *reinterpret_cast<SOCKADDR_INET*>(nextGatewayAddress->Address.lpSockaddr);
|
||||
route.NextHopString = windows::common::string::SockAddrInetToWstring(route.NextHop);
|
||||
}
|
||||
else if (address.Address.si_family == AF_INET)
|
||||
{
|
||||
IN_ADDR default_route{};
|
||||
default_route.s_addr = htonl((ntohl(address.Address.Ipv4.sin_addr.s_addr) & ~((1 << (32 - address.PrefixLength)) - 1)) | 1);
|
||||
route.DestinationPrefix.PrefixLength = 0;
|
||||
IN4ADDR_SETANY(&route.DestinationPrefix.Prefix.Ipv4);
|
||||
route.DestinationPrefixString = LX_INIT_UNSPECIFIED_ADDRESS;
|
||||
IN4ADDR_SETSOCKADDR(&route.NextHop.Ipv4, &default_route, 0);
|
||||
route.NextHopString = windows::common::string::SockAddrInetToWstring(route.NextHop);
|
||||
}
|
||||
|
||||
std::wstring dnsServerList;
|
||||
for (const auto& serverAddress : dnsInfo.GetDnsSettings(DnsSettingsFlags::IncludeVpn).Servers)
|
||||
{
|
||||
if (!dnsServerList.empty())
|
||||
{
|
||||
dnsServerList += L",";
|
||||
}
|
||||
dnsServerList += wsl::shared::string::MultiByteToWide(serverAddress);
|
||||
}
|
||||
|
||||
return std::shared_ptr<NetworkSettings>(new NetworkSettings(
|
||||
bestInterface->NetworkGuid, address, route, macAddress, {}, bestInterface->IfIndex, bestInterface->IfType, dnsServerList));
|
||||
}
|
||||
@ -1,398 +1,398 @@
|
||||
// Copyright (C) Microsoft Corporation. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
#include <windows.h>
|
||||
#include <mstcpip.h>
|
||||
#include <ws2ipdef.h>
|
||||
#include <netioapi.h>
|
||||
|
||||
#include "hcs.hpp"
|
||||
#include "lxinitshared.h"
|
||||
#include "Stringify.h"
|
||||
#include "stringshared.h"
|
||||
#include "WslCoreNetworkingSupport.h"
|
||||
#include "hns_schema.h"
|
||||
|
||||
namespace wsl::core::networking {
|
||||
|
||||
constexpr auto AddEndpointRetryPeriod = std::chrono::milliseconds(100);
|
||||
constexpr auto AddEndpointRetryTimeout = std::chrono::seconds(3);
|
||||
constexpr auto AddEndpointRetryPredicate = [] {
|
||||
// Don't retry if ModifyComputeSystem fails with:
|
||||
// HCN_E_ENDPOINT_NOT_FOUND - indicates that the underlying network object was deleted.
|
||||
// HCN_E_ENDPOINT_ALREADY_ATTACHED - occurs when HNS was restarted before the endpoints were removed.
|
||||
// VM_E_INVALID_STATE - occurs when the VM has been terminated.
|
||||
const auto result = wil::ResultFromCaughtException();
|
||||
return result != HCN_E_ENDPOINT_NOT_FOUND && result != HCN_E_ENDPOINT_ALREADY_ATTACHED && result != VM_E_INVALID_STATE;
|
||||
};
|
||||
|
||||
struct EndpointIpAddress
|
||||
{
|
||||
SOCKADDR_INET Address{};
|
||||
std::wstring AddressString{};
|
||||
unsigned char PrefixLength = 0;
|
||||
unsigned int PrefixOrigin = 0;
|
||||
unsigned int SuffixOrigin = 0;
|
||||
|
||||
// The following field can be changed from a const iterator in SyncIpStateWithLinux - that's why it's marked mutable.
|
||||
mutable unsigned int PreferredLifetime = 0;
|
||||
|
||||
EndpointIpAddress() = default;
|
||||
~EndpointIpAddress() noexcept = default;
|
||||
|
||||
EndpointIpAddress(EndpointIpAddress&&) = default;
|
||||
EndpointIpAddress& operator=(EndpointIpAddress&&) = default;
|
||||
EndpointIpAddress(const EndpointIpAddress&) = default;
|
||||
EndpointIpAddress& operator=(const EndpointIpAddress&) = default;
|
||||
|
||||
explicit EndpointIpAddress(const MIB_UNICASTIPADDRESS_ROW& AddressRow) :
|
||||
Address(AddressRow.Address),
|
||||
AddressString(windows::common::string::SockAddrInetToWstring(AddressRow.Address)),
|
||||
PrefixLength(AddressRow.OnLinkPrefixLength),
|
||||
PrefixOrigin(AddressRow.PrefixOrigin),
|
||||
SuffixOrigin(AddressRow.SuffixOrigin),
|
||||
// We treat the preferred lifetime field as effective DAD state - 0 is not preferred, anything else is preferred.
|
||||
// We do this for convenience, as we can't directly set the DAD state of an address into the guest, but we
|
||||
// we can set an address's preferred lifetime (in Linux, at least).
|
||||
PreferredLifetime(AddressRow.DadState == IpDadStatePreferred ? 0xFFFFFFFF : 0)
|
||||
{
|
||||
}
|
||||
|
||||
// operator== is deliberately not comparing PreferredLifetime (DAD state) for equality - only the address portion
|
||||
bool operator==(const EndpointIpAddress& rhs) const noexcept
|
||||
{
|
||||
return Address == rhs.Address && PrefixLength == rhs.PrefixLength;
|
||||
}
|
||||
|
||||
bool operator<(const EndpointIpAddress& rhs) const noexcept
|
||||
{
|
||||
if (Address == rhs.Address)
|
||||
{
|
||||
return PrefixLength < rhs.PrefixLength;
|
||||
}
|
||||
return Address < rhs.Address;
|
||||
}
|
||||
|
||||
void Clear() noexcept
|
||||
{
|
||||
Address = {};
|
||||
AddressString.clear();
|
||||
PrefixLength = 0;
|
||||
PrefixOrigin = 0;
|
||||
SuffixOrigin = 0;
|
||||
}
|
||||
|
||||
std::wstring GetPrefix() const
|
||||
{
|
||||
SOCKADDR_INET address{Address};
|
||||
unsigned char* addressPointer{nullptr};
|
||||
|
||||
if (Address.si_family == AF_INET)
|
||||
{
|
||||
addressPointer = reinterpret_cast<unsigned char*>(&address.Ipv4.sin_addr);
|
||||
}
|
||||
else if (Address.si_family == AF_INET6)
|
||||
{
|
||||
addressPointer = address.Ipv6.sin6_addr.u.Byte;
|
||||
}
|
||||
else
|
||||
{
|
||||
return L"";
|
||||
}
|
||||
|
||||
constexpr int c_numBitsPerByte = 8;
|
||||
for (int i = 0, currPrefixLength = PrefixLength; i < INET_ADDR_LENGTH(Address.si_family); i++, currPrefixLength -= c_numBitsPerByte)
|
||||
{
|
||||
if (currPrefixLength < c_numBitsPerByte)
|
||||
{
|
||||
const int bitShiftAmt = c_numBitsPerByte - std::max(currPrefixLength, 0);
|
||||
addressPointer[i] &= (0xFF >> bitShiftAmt) << bitShiftAmt;
|
||||
}
|
||||
}
|
||||
|
||||
const auto addressString = windows::common::string::SockAddrInetToWstring(address);
|
||||
WI_ASSERT(!addressString.empty());
|
||||
if (addressString.empty())
|
||||
{
|
||||
// just return an empty string if we have a bad address
|
||||
return addressString;
|
||||
}
|
||||
|
||||
return std::format(L"{}/{}", addressString, PrefixLength);
|
||||
}
|
||||
|
||||
std::wstring GetIpv4BroadcastMask() const
|
||||
{
|
||||
// start with all bits set, then shift off the prefix
|
||||
ULONG prefixMask{0xffffffff};
|
||||
prefixMask <<= PrefixLength;
|
||||
prefixMask >>= PrefixLength;
|
||||
|
||||
SOCKADDR_INET address{Address};
|
||||
// flip to host-order, then apply the mask
|
||||
ULONG hostOrder = ntohl(address.Ipv4.sin_addr.S_un.S_addr);
|
||||
hostOrder |= prefixMask;
|
||||
address.Ipv4.sin_addr.S_un.S_addr = htonl(hostOrder);
|
||||
|
||||
return windows::common::string::SockAddrInetToWstring(address);
|
||||
}
|
||||
|
||||
bool IsPreferred() const noexcept
|
||||
{
|
||||
return PreferredLifetime > 0;
|
||||
}
|
||||
|
||||
bool IsLinkLocal() const
|
||||
{
|
||||
return (Address.si_family == AF_INET && IN4_IS_ADDR_LINKLOCAL(&Address.Ipv4.sin_addr)) ||
|
||||
(Address.si_family == AF_INET6 && IN6_IS_ADDR_LINKLOCAL(&Address.Ipv6.sin6_addr));
|
||||
}
|
||||
};
|
||||
|
||||
struct EndpointRoute
|
||||
{
|
||||
ADDRESS_FAMILY Family = AF_INET;
|
||||
IP_ADDRESS_PREFIX DestinationPrefix{};
|
||||
std::wstring DestinationPrefixString{};
|
||||
SOCKADDR_INET NextHop{};
|
||||
std::wstring NextHopString{};
|
||||
unsigned char SitePrefixLength = 0;
|
||||
unsigned int Metric = 0;
|
||||
bool IsAutoGeneratedPrefixRoute = false;
|
||||
|
||||
EndpointRoute() = default;
|
||||
~EndpointRoute() noexcept = default;
|
||||
|
||||
EndpointRoute(EndpointRoute&&) = default;
|
||||
EndpointRoute& operator=(EndpointRoute&&) = default;
|
||||
EndpointRoute(const EndpointRoute&) = default;
|
||||
EndpointRoute& operator=(const EndpointRoute&) = default;
|
||||
|
||||
EndpointRoute(const MIB_IPFORWARD_ROW2& RouteRow) :
|
||||
Family(RouteRow.NextHop.si_family),
|
||||
DestinationPrefix(RouteRow.DestinationPrefix),
|
||||
DestinationPrefixString(windows::common::string::SockAddrInetToWstring(RouteRow.DestinationPrefix.Prefix)),
|
||||
NextHop(RouteRow.NextHop),
|
||||
NextHopString(windows::common::string::SockAddrInetToWstring(RouteRow.NextHop)),
|
||||
SitePrefixLength(RouteRow.SitePrefixLength),
|
||||
Metric(RouteRow.Metric)
|
||||
{
|
||||
}
|
||||
|
||||
unsigned char GetMaxPrefixLength() const
|
||||
{
|
||||
return (Family == AF_INET) ? 32 : 128;
|
||||
}
|
||||
|
||||
std::wstring GetFullDestinationPrefix() const
|
||||
{
|
||||
return std::format(L"{}/{}", DestinationPrefixString, static_cast<unsigned int>(DestinationPrefix.PrefixLength));
|
||||
}
|
||||
|
||||
bool IsNextHopOnlink() const noexcept
|
||||
{
|
||||
return (Family == AF_INET && NextHopString == LX_INIT_UNSPECIFIED_ADDRESS) ||
|
||||
(Family == AF_INET6 && NextHopString == LX_INIT_UNSPECIFIED_V6_ADDRESS);
|
||||
}
|
||||
|
||||
bool IsDefault() const noexcept
|
||||
{
|
||||
return (Family == AF_INET && DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS) ||
|
||||
(Family == AF_INET6 && DestinationPrefixString == LX_INIT_UNSPECIFIED_V6_ADDRESS);
|
||||
}
|
||||
|
||||
bool IsUnicastAddressRoute() const noexcept
|
||||
{
|
||||
return (Family == AF_INET && DestinationPrefix.PrefixLength == 32) || (Family == AF_INET6 && DestinationPrefix.PrefixLength == 128);
|
||||
}
|
||||
|
||||
std::wstring ToString() const
|
||||
{
|
||||
return std::format(L"{}=>{} [metric {}]", GetFullDestinationPrefix(), NextHopString, Metric);
|
||||
}
|
||||
|
||||
bool operator==(const EndpointRoute& rhs) const noexcept
|
||||
{
|
||||
return Family == rhs.Family && DestinationPrefix.PrefixLength == rhs.DestinationPrefix.PrefixLength &&
|
||||
DestinationPrefix.Prefix == rhs.DestinationPrefix.Prefix && NextHop == rhs.NextHop &&
|
||||
SitePrefixLength == rhs.SitePrefixLength && Metric == rhs.Metric;
|
||||
}
|
||||
|
||||
bool operator!=(const EndpointRoute& other) const
|
||||
{
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
// sort by family, then by next-hop (on-link routes first), then by prefix, then by metric
|
||||
bool operator<(const EndpointRoute& rhs) const noexcept
|
||||
{
|
||||
if (Family == rhs.Family)
|
||||
{
|
||||
if (NextHop == rhs.NextHop)
|
||||
{
|
||||
if (DestinationPrefix.Prefix == rhs.DestinationPrefix.Prefix)
|
||||
{
|
||||
if (DestinationPrefix.PrefixLength == rhs.DestinationPrefix.PrefixLength)
|
||||
{
|
||||
if (Metric == rhs.Metric)
|
||||
{
|
||||
return SitePrefixLength < rhs.SitePrefixLength;
|
||||
}
|
||||
return Metric < rhs.Metric;
|
||||
}
|
||||
return DestinationPrefix.PrefixLength < rhs.DestinationPrefix.PrefixLength;
|
||||
}
|
||||
return DestinationPrefix.Prefix < rhs.DestinationPrefix.Prefix;
|
||||
}
|
||||
return NextHop < rhs.NextHop;
|
||||
}
|
||||
return Family < rhs.Family;
|
||||
}
|
||||
};
|
||||
|
||||
struct NetworkSettings
|
||||
{
|
||||
NetworkSettings() = default;
|
||||
|
||||
NetworkSettings(
|
||||
const GUID& interfaceGuid,
|
||||
EndpointIpAddress preferredIpAddress,
|
||||
EndpointRoute gateway,
|
||||
std::wstring macAddress,
|
||||
std::wstring deviceName,
|
||||
uint32_t interfaceIndex,
|
||||
uint32_t mediaType,
|
||||
const std::wstring& dnsServerList) :
|
||||
InterfaceGuid(interfaceGuid),
|
||||
PreferredIpAddress(std::move(preferredIpAddress)),
|
||||
MacAddress(std::move(macAddress)),
|
||||
DeviceName(std::move(deviceName)),
|
||||
InterfaceIndex(interfaceIndex),
|
||||
InterfaceType(mediaType)
|
||||
{
|
||||
Routes.emplace(std::move(gateway));
|
||||
DnsServers = wsl::shared::string::Split(dnsServerList, L',');
|
||||
}
|
||||
|
||||
GUID InterfaceGuid{};
|
||||
EndpointIpAddress PreferredIpAddress{};
|
||||
std::set<EndpointIpAddress> IpAddresses{}; // Does not include PreferredIpAddress.
|
||||
std::set<EndpointRoute> Routes{};
|
||||
std::vector<std::wstring> DnsServers{};
|
||||
std::wstring MacAddress;
|
||||
std::wstring DeviceName;
|
||||
IF_INDEX InterfaceIndex = 0;
|
||||
IFTYPE InterfaceType = 0;
|
||||
ULONG IPv4InterfaceMtu = 0;
|
||||
ULONG IPv6InterfaceMtu = 0;
|
||||
// some interfaces will only have an IPv4 or IPv6 interface
|
||||
std::optional<ULONG> IPv4InterfaceMetric = 0;
|
||||
std::optional<ULONG> IPv6InterfaceMetric = 0;
|
||||
bool IsHidden = false;
|
||||
bool IsConnected = false;
|
||||
bool IsMetered = false;
|
||||
bool DisableIpv4DefaultRoutes = false;
|
||||
bool DisableIpv6DefaultRoutes = false;
|
||||
bool PendingUpdateToReconnectForMetered = false;
|
||||
bool PendingIPInterfaceUpdate = false;
|
||||
|
||||
auto operator<=>(const NetworkSettings&) const = default;
|
||||
|
||||
std::wstring GetBestGatewayAddressString() const
|
||||
{
|
||||
// Best is currently defined as simply the first IPv4 gateway.
|
||||
for (const auto& route : Routes)
|
||||
{
|
||||
if (route.Family == AF_INET && route.DestinationPrefix.PrefixLength == 0 && route.DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS)
|
||||
{
|
||||
return route.NextHopString;
|
||||
}
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
SOCKADDR_INET GetBestGatewayAddress() const
|
||||
{
|
||||
// Best is currently defined as simply the first IPv4 gateway.
|
||||
for (const auto& route : Routes)
|
||||
{
|
||||
if (route.Family == AF_INET && route.DestinationPrefix.PrefixLength == 0 && route.DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS)
|
||||
{
|
||||
return route.NextHop;
|
||||
}
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
std::wstring IpAddressesString() const
|
||||
{
|
||||
return std::accumulate(std::begin(IpAddresses), std::end(IpAddresses), std::wstring{}, [](const std::wstring& prev, const auto& addr) {
|
||||
return addr.AddressString + (prev.empty() ? L"" : L"," + prev);
|
||||
});
|
||||
}
|
||||
|
||||
std::wstring RoutesString() const
|
||||
{
|
||||
return std::accumulate(std::begin(Routes), std::end(Routes), std::wstring{}, [](const std::wstring& prev, const EndpointRoute& route) {
|
||||
return route.ToString() + (prev.empty() ? L"" : L"," + prev);
|
||||
});
|
||||
}
|
||||
|
||||
std::wstring DnsServersString() const
|
||||
{
|
||||
return wsl::shared::string::Join(DnsServers, L',');
|
||||
}
|
||||
|
||||
// will return ULONG_MAX if there's no configured MTU
|
||||
ULONG GetEffectiveMtu() const noexcept
|
||||
{
|
||||
return std::min(IPv4InterfaceMtu > 0 ? IPv4InterfaceMtu : ULONG_MAX, IPv6InterfaceMtu > 0 ? IPv6InterfaceMtu : ULONG_MAX);
|
||||
}
|
||||
|
||||
// will return zero if there's no configured metric
|
||||
ULONG GetMinimumMetric() const noexcept
|
||||
{
|
||||
if (!IPv4InterfaceMetric.has_value() && !IPv6InterfaceMetric.has_value())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
if (!IPv4InterfaceMetric.has_value())
|
||||
{
|
||||
return IPv6InterfaceMetric.value();
|
||||
}
|
||||
if (!IPv6InterfaceMetric.has_value())
|
||||
{
|
||||
return IPv4InterfaceMetric.value();
|
||||
}
|
||||
return std::min(IPv4InterfaceMetric.value(), IPv6InterfaceMetric.value());
|
||||
}
|
||||
};
|
||||
|
||||
std::shared_ptr<NetworkSettings> GetEndpointSettings(const wsl::shared::hns::HNSEndpoint& properties);
|
||||
std::shared_ptr<NetworkSettings> GetHostEndpointSettings();
|
||||
|
||||
#define TRACE_NETWORKSETTINGS_OBJECT(settings) \
|
||||
TraceLoggingValue((settings)->InterfaceGuid, "interfaceGuid"), TraceLoggingValue((settings)->InterfaceIndex, "interfaceIndex"), \
|
||||
TraceLoggingValue((settings)->InterfaceType, "interfaceType"), \
|
||||
TraceLoggingValue((settings)->IsConnected, "isConnected"), TraceLoggingValue((settings)->IsMetered, "isMetered"), \
|
||||
TraceLoggingValue((settings)->GetBestGatewayAddressString().c_str(), "bestGatewayAddress"), \
|
||||
TraceLoggingValue((settings)->PreferredIpAddress.AddressString.c_str(), "preferredIpAddress"), \
|
||||
TraceLoggingValue((settings)->PreferredIpAddress.PrefixLength, "preferredIpAddressPrefixLength"), \
|
||||
TraceLoggingValue((settings)->IpAddressesString().c_str(), "ipAddresses"), \
|
||||
TraceLoggingValue((settings)->RoutesString().c_str(), "routes"), \
|
||||
TraceLoggingValue((settings)->DnsServersString().c_str(), "dnsServerList"), \
|
||||
TraceLoggingValue((settings)->MacAddress.c_str(), "macAddress"), \
|
||||
TraceLoggingValue((settings)->IPv4InterfaceMtu, "IPv4InterfaceMtu"), \
|
||||
TraceLoggingValue((settings)->IPv6InterfaceMtu, "IPv6InterfaceMtu"), \
|
||||
TraceLoggingValue((settings)->IPv4InterfaceMetric.value_or(0xffffffff), "IPv4InterfaceMetric"), \
|
||||
TraceLoggingValue((settings)->IPv6InterfaceMetric.value_or(0xffffffff), "IPv6InterfaceMetric"), \
|
||||
TraceLoggingValue((settings)->PendingIPInterfaceUpdate, "PendingIPInterfaceUpdate"), \
|
||||
TraceLoggingValue((settings)->PendingUpdateToReconnectForMetered, "PendingUpdateToReconnectForMetered")
|
||||
|
||||
} // namespace wsl::core::networking
|
||||
// Copyright (C) Microsoft Corporation. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
#include <windows.h>
|
||||
#include <mstcpip.h>
|
||||
#include <ws2ipdef.h>
|
||||
#include <netioapi.h>
|
||||
|
||||
#include "hcs.hpp"
|
||||
#include "lxinitshared.h"
|
||||
#include "Stringify.h"
|
||||
#include "stringshared.h"
|
||||
#include "WslCoreNetworkingSupport.h"
|
||||
#include "hns_schema.h"
|
||||
|
||||
namespace wsl::core::networking {
|
||||
|
||||
constexpr auto AddEndpointRetryPeriod = std::chrono::milliseconds(100);
|
||||
constexpr auto AddEndpointRetryTimeout = std::chrono::seconds(3);
|
||||
constexpr auto AddEndpointRetryPredicate = [] {
|
||||
// Don't retry if ModifyComputeSystem fails with:
|
||||
// HCN_E_ENDPOINT_NOT_FOUND - indicates that the underlying network object was deleted.
|
||||
// HCN_E_ENDPOINT_ALREADY_ATTACHED - occurs when HNS was restarted before the endpoints were removed.
|
||||
// VM_E_INVALID_STATE - occurs when the VM has been terminated.
|
||||
const auto result = wil::ResultFromCaughtException();
|
||||
return result != HCN_E_ENDPOINT_NOT_FOUND && result != HCN_E_ENDPOINT_ALREADY_ATTACHED && result != VM_E_INVALID_STATE;
|
||||
};
|
||||
|
||||
struct EndpointIpAddress
|
||||
{
|
||||
SOCKADDR_INET Address{};
|
||||
std::wstring AddressString{};
|
||||
unsigned char PrefixLength = 0;
|
||||
unsigned int PrefixOrigin = 0;
|
||||
unsigned int SuffixOrigin = 0;
|
||||
|
||||
// The following field can be changed from a const iterator in SyncIpStateWithLinux - that's why it's marked mutable.
|
||||
mutable unsigned int PreferredLifetime = 0;
|
||||
|
||||
EndpointIpAddress() = default;
|
||||
~EndpointIpAddress() noexcept = default;
|
||||
|
||||
EndpointIpAddress(EndpointIpAddress&&) = default;
|
||||
EndpointIpAddress& operator=(EndpointIpAddress&&) = default;
|
||||
EndpointIpAddress(const EndpointIpAddress&) = default;
|
||||
EndpointIpAddress& operator=(const EndpointIpAddress&) = default;
|
||||
|
||||
explicit EndpointIpAddress(const MIB_UNICASTIPADDRESS_ROW& AddressRow) :
|
||||
Address(AddressRow.Address),
|
||||
AddressString(windows::common::string::SockAddrInetToWstring(AddressRow.Address)),
|
||||
PrefixLength(AddressRow.OnLinkPrefixLength),
|
||||
PrefixOrigin(AddressRow.PrefixOrigin),
|
||||
SuffixOrigin(AddressRow.SuffixOrigin),
|
||||
// We treat the preferred lifetime field as effective DAD state - 0 is not preferred, anything else is preferred.
|
||||
// We do this for convenience, as we can't directly set the DAD state of an address into the guest, but we
|
||||
// we can set an address's preferred lifetime (in Linux, at least).
|
||||
PreferredLifetime(AddressRow.DadState == IpDadStatePreferred ? 0xFFFFFFFF : 0)
|
||||
{
|
||||
}
|
||||
|
||||
// operator== is deliberately not comparing PreferredLifetime (DAD state) for equality - only the address portion
|
||||
bool operator==(const EndpointIpAddress& rhs) const noexcept
|
||||
{
|
||||
return Address == rhs.Address && PrefixLength == rhs.PrefixLength;
|
||||
}
|
||||
|
||||
bool operator<(const EndpointIpAddress& rhs) const noexcept
|
||||
{
|
||||
if (Address == rhs.Address)
|
||||
{
|
||||
return PrefixLength < rhs.PrefixLength;
|
||||
}
|
||||
return Address < rhs.Address;
|
||||
}
|
||||
|
||||
void Clear() noexcept
|
||||
{
|
||||
Address = {};
|
||||
AddressString.clear();
|
||||
PrefixLength = 0;
|
||||
PrefixOrigin = 0;
|
||||
SuffixOrigin = 0;
|
||||
}
|
||||
|
||||
std::wstring GetPrefix() const
|
||||
{
|
||||
SOCKADDR_INET address{Address};
|
||||
unsigned char* addressPointer{nullptr};
|
||||
|
||||
if (Address.si_family == AF_INET)
|
||||
{
|
||||
addressPointer = reinterpret_cast<unsigned char*>(&address.Ipv4.sin_addr);
|
||||
}
|
||||
else if (Address.si_family == AF_INET6)
|
||||
{
|
||||
addressPointer = address.Ipv6.sin6_addr.u.Byte;
|
||||
}
|
||||
else
|
||||
{
|
||||
return L"";
|
||||
}
|
||||
|
||||
constexpr int c_numBitsPerByte = 8;
|
||||
for (int i = 0, currPrefixLength = PrefixLength; i < INET_ADDR_LENGTH(Address.si_family); i++, currPrefixLength -= c_numBitsPerByte)
|
||||
{
|
||||
if (currPrefixLength < c_numBitsPerByte)
|
||||
{
|
||||
const int bitShiftAmt = c_numBitsPerByte - std::max(currPrefixLength, 0);
|
||||
addressPointer[i] &= (0xFF >> bitShiftAmt) << bitShiftAmt;
|
||||
}
|
||||
}
|
||||
|
||||
const auto addressString = windows::common::string::SockAddrInetToWstring(address);
|
||||
WI_ASSERT(!addressString.empty());
|
||||
if (addressString.empty())
|
||||
{
|
||||
// just return an empty string if we have a bad address
|
||||
return addressString;
|
||||
}
|
||||
|
||||
return std::format(L"{}/{}", addressString, PrefixLength);
|
||||
}
|
||||
|
||||
std::wstring GetIpv4BroadcastMask() const
|
||||
{
|
||||
// start with all bits set, then shift off the prefix
|
||||
ULONG prefixMask{0xffffffff};
|
||||
prefixMask <<= PrefixLength;
|
||||
prefixMask >>= PrefixLength;
|
||||
|
||||
SOCKADDR_INET address{Address};
|
||||
// flip to host-order, then apply the mask
|
||||
ULONG hostOrder = ntohl(address.Ipv4.sin_addr.S_un.S_addr);
|
||||
hostOrder |= prefixMask;
|
||||
address.Ipv4.sin_addr.S_un.S_addr = htonl(hostOrder);
|
||||
|
||||
return windows::common::string::SockAddrInetToWstring(address);
|
||||
}
|
||||
|
||||
bool IsPreferred() const noexcept
|
||||
{
|
||||
return PreferredLifetime > 0;
|
||||
}
|
||||
|
||||
bool IsLinkLocal() const
|
||||
{
|
||||
return (Address.si_family == AF_INET && IN4_IS_ADDR_LINKLOCAL(&Address.Ipv4.sin_addr)) ||
|
||||
(Address.si_family == AF_INET6 && IN6_IS_ADDR_LINKLOCAL(&Address.Ipv6.sin6_addr));
|
||||
}
|
||||
};
|
||||
|
||||
struct EndpointRoute
|
||||
{
|
||||
ADDRESS_FAMILY Family = AF_INET;
|
||||
IP_ADDRESS_PREFIX DestinationPrefix{};
|
||||
std::wstring DestinationPrefixString{};
|
||||
SOCKADDR_INET NextHop{};
|
||||
std::wstring NextHopString{};
|
||||
unsigned char SitePrefixLength = 0;
|
||||
unsigned int Metric = 0;
|
||||
bool IsAutoGeneratedPrefixRoute = false;
|
||||
|
||||
EndpointRoute() = default;
|
||||
~EndpointRoute() noexcept = default;
|
||||
|
||||
EndpointRoute(EndpointRoute&&) = default;
|
||||
EndpointRoute& operator=(EndpointRoute&&) = default;
|
||||
EndpointRoute(const EndpointRoute&) = default;
|
||||
EndpointRoute& operator=(const EndpointRoute&) = default;
|
||||
|
||||
EndpointRoute(const MIB_IPFORWARD_ROW2& RouteRow) :
|
||||
Family(RouteRow.NextHop.si_family),
|
||||
DestinationPrefix(RouteRow.DestinationPrefix),
|
||||
DestinationPrefixString(windows::common::string::SockAddrInetToWstring(RouteRow.DestinationPrefix.Prefix)),
|
||||
NextHop(RouteRow.NextHop),
|
||||
NextHopString(windows::common::string::SockAddrInetToWstring(RouteRow.NextHop)),
|
||||
SitePrefixLength(RouteRow.SitePrefixLength),
|
||||
Metric(RouteRow.Metric)
|
||||
{
|
||||
}
|
||||
|
||||
unsigned char GetMaxPrefixLength() const
|
||||
{
|
||||
return (Family == AF_INET) ? 32 : 128;
|
||||
}
|
||||
|
||||
std::wstring GetFullDestinationPrefix() const
|
||||
{
|
||||
return std::format(L"{}/{}", DestinationPrefixString, static_cast<unsigned int>(DestinationPrefix.PrefixLength));
|
||||
}
|
||||
|
||||
bool IsNextHopOnlink() const noexcept
|
||||
{
|
||||
return (Family == AF_INET && NextHopString == LX_INIT_UNSPECIFIED_ADDRESS) ||
|
||||
(Family == AF_INET6 && NextHopString == LX_INIT_UNSPECIFIED_V6_ADDRESS);
|
||||
}
|
||||
|
||||
bool IsDefault() const noexcept
|
||||
{
|
||||
return (Family == AF_INET && DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS) ||
|
||||
(Family == AF_INET6 && DestinationPrefixString == LX_INIT_UNSPECIFIED_V6_ADDRESS);
|
||||
}
|
||||
|
||||
bool IsUnicastAddressRoute() const noexcept
|
||||
{
|
||||
return (Family == AF_INET && DestinationPrefix.PrefixLength == 32) || (Family == AF_INET6 && DestinationPrefix.PrefixLength == 128);
|
||||
}
|
||||
|
||||
std::wstring ToString() const
|
||||
{
|
||||
return std::format(L"{}=>{} [metric {}]", GetFullDestinationPrefix(), NextHopString, Metric);
|
||||
}
|
||||
|
||||
bool operator==(const EndpointRoute& rhs) const noexcept
|
||||
{
|
||||
return Family == rhs.Family && DestinationPrefix.PrefixLength == rhs.DestinationPrefix.PrefixLength &&
|
||||
DestinationPrefix.Prefix == rhs.DestinationPrefix.Prefix && NextHop == rhs.NextHop &&
|
||||
SitePrefixLength == rhs.SitePrefixLength && Metric == rhs.Metric;
|
||||
}
|
||||
|
||||
bool operator!=(const EndpointRoute& other) const
|
||||
{
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
// sort by family, then by next-hop (on-link routes first), then by prefix, then by metric
|
||||
bool operator<(const EndpointRoute& rhs) const noexcept
|
||||
{
|
||||
if (Family == rhs.Family)
|
||||
{
|
||||
if (NextHop == rhs.NextHop)
|
||||
{
|
||||
if (DestinationPrefix.Prefix == rhs.DestinationPrefix.Prefix)
|
||||
{
|
||||
if (DestinationPrefix.PrefixLength == rhs.DestinationPrefix.PrefixLength)
|
||||
{
|
||||
if (Metric == rhs.Metric)
|
||||
{
|
||||
return SitePrefixLength < rhs.SitePrefixLength;
|
||||
}
|
||||
return Metric < rhs.Metric;
|
||||
}
|
||||
return DestinationPrefix.PrefixLength < rhs.DestinationPrefix.PrefixLength;
|
||||
}
|
||||
return DestinationPrefix.Prefix < rhs.DestinationPrefix.Prefix;
|
||||
}
|
||||
return NextHop < rhs.NextHop;
|
||||
}
|
||||
return Family < rhs.Family;
|
||||
}
|
||||
};
|
||||
|
||||
struct NetworkSettings
|
||||
{
|
||||
NetworkSettings() = default;
|
||||
|
||||
NetworkSettings(
|
||||
const GUID& interfaceGuid,
|
||||
EndpointIpAddress preferredIpAddress,
|
||||
EndpointRoute gateway,
|
||||
std::wstring macAddress,
|
||||
std::wstring deviceName,
|
||||
uint32_t interfaceIndex,
|
||||
uint32_t mediaType,
|
||||
const std::wstring& dnsServerList) :
|
||||
InterfaceGuid(interfaceGuid),
|
||||
PreferredIpAddress(std::move(preferredIpAddress)),
|
||||
MacAddress(std::move(macAddress)),
|
||||
DeviceName(std::move(deviceName)),
|
||||
InterfaceIndex(interfaceIndex),
|
||||
InterfaceType(mediaType)
|
||||
{
|
||||
Routes.emplace(std::move(gateway));
|
||||
DnsServers = wsl::shared::string::Split(dnsServerList, L',');
|
||||
}
|
||||
|
||||
GUID InterfaceGuid{};
|
||||
EndpointIpAddress PreferredIpAddress{};
|
||||
std::set<EndpointIpAddress> IpAddresses{}; // Does not include PreferredIpAddress.
|
||||
std::set<EndpointRoute> Routes{};
|
||||
std::vector<std::wstring> DnsServers{};
|
||||
std::wstring MacAddress;
|
||||
std::wstring DeviceName;
|
||||
IF_INDEX InterfaceIndex = 0;
|
||||
IFTYPE InterfaceType = 0;
|
||||
ULONG IPv4InterfaceMtu = 0;
|
||||
ULONG IPv6InterfaceMtu = 0;
|
||||
// some interfaces will only have an IPv4 or IPv6 interface
|
||||
std::optional<ULONG> IPv4InterfaceMetric = 0;
|
||||
std::optional<ULONG> IPv6InterfaceMetric = 0;
|
||||
bool IsHidden = false;
|
||||
bool IsConnected = false;
|
||||
bool IsMetered = false;
|
||||
bool DisableIpv4DefaultRoutes = false;
|
||||
bool DisableIpv6DefaultRoutes = false;
|
||||
bool PendingUpdateToReconnectForMetered = false;
|
||||
bool PendingIPInterfaceUpdate = false;
|
||||
|
||||
auto operator<=>(const NetworkSettings&) const = default;
|
||||
|
||||
std::wstring GetBestGatewayAddressString() const
|
||||
{
|
||||
// Best is currently defined as simply the first IPv4 gateway.
|
||||
for (const auto& route : Routes)
|
||||
{
|
||||
if (route.Family == AF_INET && route.DestinationPrefix.PrefixLength == 0 && route.DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS)
|
||||
{
|
||||
return route.NextHopString;
|
||||
}
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
SOCKADDR_INET GetBestGatewayAddress() const
|
||||
{
|
||||
// Best is currently defined as simply the first IPv4 gateway.
|
||||
for (const auto& route : Routes)
|
||||
{
|
||||
if (route.Family == AF_INET && route.DestinationPrefix.PrefixLength == 0 && route.DestinationPrefixString == LX_INIT_UNSPECIFIED_ADDRESS)
|
||||
{
|
||||
return route.NextHop;
|
||||
}
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
std::wstring IpAddressesString() const
|
||||
{
|
||||
return std::accumulate(std::begin(IpAddresses), std::end(IpAddresses), std::wstring{}, [](const std::wstring& prev, const auto& addr) {
|
||||
return addr.AddressString + (prev.empty() ? L"" : L"," + prev);
|
||||
});
|
||||
}
|
||||
|
||||
std::wstring RoutesString() const
|
||||
{
|
||||
return std::accumulate(std::begin(Routes), std::end(Routes), std::wstring{}, [](const std::wstring& prev, const EndpointRoute& route) {
|
||||
return route.ToString() + (prev.empty() ? L"" : L"," + prev);
|
||||
});
|
||||
}
|
||||
|
||||
std::wstring DnsServersString() const
|
||||
{
|
||||
return wsl::shared::string::Join(DnsServers, L',');
|
||||
}
|
||||
|
||||
// will return ULONG_MAX if there's no configured MTU
|
||||
ULONG GetEffectiveMtu() const noexcept
|
||||
{
|
||||
return std::min(IPv4InterfaceMtu > 0 ? IPv4InterfaceMtu : ULONG_MAX, IPv6InterfaceMtu > 0 ? IPv6InterfaceMtu : ULONG_MAX);
|
||||
}
|
||||
|
||||
// will return zero if there's no configured metric
|
||||
ULONG GetMinimumMetric() const noexcept
|
||||
{
|
||||
if (!IPv4InterfaceMetric.has_value() && !IPv6InterfaceMetric.has_value())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
if (!IPv4InterfaceMetric.has_value())
|
||||
{
|
||||
return IPv6InterfaceMetric.value();
|
||||
}
|
||||
if (!IPv6InterfaceMetric.has_value())
|
||||
{
|
||||
return IPv4InterfaceMetric.value();
|
||||
}
|
||||
return std::min(IPv4InterfaceMetric.value(), IPv6InterfaceMetric.value());
|
||||
}
|
||||
};
|
||||
|
||||
std::shared_ptr<NetworkSettings> GetEndpointSettings(const wsl::shared::hns::HNSEndpoint& properties);
|
||||
std::shared_ptr<NetworkSettings> GetHostEndpointSettings();
|
||||
|
||||
#define TRACE_NETWORKSETTINGS_OBJECT(settings) \
|
||||
TraceLoggingValue((settings)->InterfaceGuid, "interfaceGuid"), TraceLoggingValue((settings)->InterfaceIndex, "interfaceIndex"), \
|
||||
TraceLoggingValue((settings)->InterfaceType, "interfaceType"), \
|
||||
TraceLoggingValue((settings)->IsConnected, "isConnected"), TraceLoggingValue((settings)->IsMetered, "isMetered"), \
|
||||
TraceLoggingValue((settings)->GetBestGatewayAddressString().c_str(), "bestGatewayAddress"), \
|
||||
TraceLoggingValue((settings)->PreferredIpAddress.AddressString.c_str(), "preferredIpAddress"), \
|
||||
TraceLoggingValue((settings)->PreferredIpAddress.PrefixLength, "preferredIpAddressPrefixLength"), \
|
||||
TraceLoggingValue((settings)->IpAddressesString().c_str(), "ipAddresses"), \
|
||||
TraceLoggingValue((settings)->RoutesString().c_str(), "routes"), \
|
||||
TraceLoggingValue((settings)->DnsServersString().c_str(), "dnsServerList"), \
|
||||
TraceLoggingValue((settings)->MacAddress.c_str(), "macAddress"), \
|
||||
TraceLoggingValue((settings)->IPv4InterfaceMtu, "IPv4InterfaceMtu"), \
|
||||
TraceLoggingValue((settings)->IPv6InterfaceMtu, "IPv6InterfaceMtu"), \
|
||||
TraceLoggingValue((settings)->IPv4InterfaceMetric.value_or(0xffffffff), "IPv4InterfaceMetric"), \
|
||||
TraceLoggingValue((settings)->IPv6InterfaceMetric.value_or(0xffffffff), "IPv6InterfaceMetric"), \
|
||||
TraceLoggingValue((settings)->PendingIPInterfaceUpdate, "PendingIPInterfaceUpdate"), \
|
||||
TraceLoggingValue((settings)->PendingUpdateToReconnectForMetered, "PendingUpdateToReconnectForMetered")
|
||||
|
||||
} // namespace wsl::core::networking
|
||||
@ -44,6 +44,7 @@ inline auto c_msixPackageFamilyName = L"MicrosoftCorporationII.WindowsSubsystemF
|
||||
inline auto c_githubUrlOverrideRegistryValue = L"GitHubUrlOverride";
|
||||
inline auto c_vhdFileExtension = L".vhd";
|
||||
inline auto c_vhdxFileExtension = L".vhdx";
|
||||
inline constexpr auto c_vmOwner = L"WSL"; // TODO-WSLA: Does this apply to WSLA ?
|
||||
|
||||
struct GitHubReleaseAsset
|
||||
{
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
set(SOURCES
|
||||
DistributionRegistration.cpp
|
||||
LxssSecurity.cpp
|
||||
LxssUserCallback.cpp
|
||||
LxssUserSession.cpp
|
||||
LxssUserSessionFactory.cpp
|
||||
@ -11,10 +10,6 @@ set(SOURCES
|
||||
PluginManager.cpp
|
||||
ServiceMain.cpp
|
||||
BridgedNetworking.cpp
|
||||
DeviceHostProxy.cpp
|
||||
Dmesg.cpp
|
||||
DnsTunnelingChannel.cpp
|
||||
GnsChannel.cpp
|
||||
GnsPortTrackerChannel.cpp
|
||||
GnsRpcServer.cpp
|
||||
GuestTelemetryLogger.cpp
|
||||
@ -22,21 +17,12 @@ set(SOURCES
|
||||
LxssConsoleManager.cpp
|
||||
LxssCreateProcess.cpp
|
||||
MirroredNetworking.cpp
|
||||
NatNetworking.cpp
|
||||
RingBuffer.cpp
|
||||
WslCoreFilesystem.cpp
|
||||
WslCoreGuestNetworkService.cpp
|
||||
WslCoreHostDnsInfo.cpp
|
||||
WslCoreInstance.cpp
|
||||
WslMirroredNetworking.cpp
|
||||
WslCoreNetworkEndpointSettings.cpp
|
||||
DnsResolver.cpp
|
||||
WslCoreTcpIpStateTracking.cpp
|
||||
WslCoreVm.cpp
|
||||
VirtioNetworking.cpp
|
||||
WSLAUserSession.cpp
|
||||
WSLAUserSessionFactory.cpp
|
||||
WSLAVirtualMachine.cpp
|
||||
main.rc
|
||||
${CMAKE_CURRENT_BINARY_DIR}/../mc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wsleventschema.rc
|
||||
application.manifest)
|
||||
@ -44,7 +30,6 @@ set(SOURCES
|
||||
set(HEADERS
|
||||
../../inc/comservicehelper.h
|
||||
DistributionRegistration.h
|
||||
LxssSecurity.h
|
||||
LxssUserCallback.h
|
||||
LxssUserSession.h
|
||||
LxssUserSessionFactory.h
|
||||
@ -53,35 +38,20 @@ set(HEADERS
|
||||
PluginManager.h
|
||||
LxssInstance.h
|
||||
BridgedNetworking.h
|
||||
DeviceHostProxy.h
|
||||
Dmesg.h
|
||||
DnsTunnelingChannel.h
|
||||
GnsChannel.h
|
||||
GnsPortTrackerChannel.h
|
||||
GnsRpcServer.h
|
||||
GuestTelemetryLogger.h
|
||||
INetworkingEngine.h
|
||||
IMirroredNetworkManager.h
|
||||
Lifetime.h
|
||||
LxssConsoleManager.h
|
||||
LxssCreateProcess.h
|
||||
MirroredNetworking.h
|
||||
NatNetworking.h
|
||||
RingBuffer.h
|
||||
WslCoreFilesystem.h
|
||||
WslCoreGuestNetworkService.h
|
||||
WslCoreHostDnsInfo.h
|
||||
WslCoreInstance.h
|
||||
WslCoreMessageQueue.h
|
||||
WslMirroredNetworking.h
|
||||
WslCoreNetworkEndpoint.h
|
||||
WslCoreNetworkEndpointSettings.h
|
||||
DnsResolver.h
|
||||
WslCoreTcpIpStateTracking.h
|
||||
WslCoreVm.h
|
||||
WSLAUserSession.h
|
||||
WSLAUserSessionFactory.h
|
||||
WSLAVirtualMachine.h)
|
||||
WslCoreVm.h)
|
||||
|
||||
include_directories(${CMAKE_SOURCE_DIR}/src/windows/wslaclient)
|
||||
|
||||
|
||||
@ -256,7 +256,7 @@ void MirroredNetworking::Initialize()
|
||||
m_config.FirewallConfig.Enabled(), m_config.IgnoredPorts, m_runtimeId, m_gnsRpcServer->GetServerUuid(), s_GuestNetworkServiceCallback, this);
|
||||
m_ephemeralPortRange = m_guestNetworkService.AllocateEphemeralPortRange();
|
||||
|
||||
networking::ConfigureHyperVFirewall(m_config.FirewallConfig, c_vmOwner);
|
||||
networking::ConfigureHyperVFirewall(m_config.FirewallConfig, wsl::windows::common::wslutil::c_vmOwner);
|
||||
|
||||
// must keep all m_networkManager interactions (including) creation queued
|
||||
// also must queue GNS callbacks to keep them serialized
|
||||
|
||||
@ -31,7 +31,6 @@ wil::unique_event g_networkingReady{wil::EventOptions::ManualReset};
|
||||
|
||||
// Declare the LxssUserSession COM class.
|
||||
CoCreatableClassWrlCreatorMapInclude(LxssUserSession);
|
||||
CoCreatableClassWrlCreatorMapInclude(WSLAUserSession);
|
||||
|
||||
struct WslServiceSecurityPolicy
|
||||
{
|
||||
@ -241,7 +240,6 @@ void WslService::ServiceStopped()
|
||||
|
||||
// Terminate all user sessions.
|
||||
ClearSessionsAndBlockNewInstances();
|
||||
wsl::windows::service::wsla::ClearWslaSessionsAndBlockNewInstances();
|
||||
|
||||
// Disconnect from the LxCore driver.
|
||||
if (g_lxcoreInitialized)
|
||||
|
||||
@ -1468,7 +1468,7 @@ void WslCoreVm::FreeLun(_In_ ULONG lun)
|
||||
std::wstring WslCoreVm::GenerateConfigJson()
|
||||
{
|
||||
hcs::ComputeSystem systemSettings{};
|
||||
systemSettings.Owner = c_vmOwner;
|
||||
systemSettings.Owner = wsl::windows::common::wslutil::c_vmOwner;
|
||||
systemSettings.ShouldTerminateOnLastHandleClosed = true;
|
||||
systemSettings.SchemaVersion.Major = 2;
|
||||
systemSettings.SchemaVersion.Minor = 3;
|
||||
@ -1575,7 +1575,7 @@ std::wstring WslCoreVm::GenerateConfigJson()
|
||||
// Set the vmmem suffix which will change the process name in task manager.
|
||||
if (helpers::IsVmemmSuffixSupported())
|
||||
{
|
||||
vmSettings.ComputeTopology.Memory.HostingProcessNameSuffix = c_vmOwner;
|
||||
vmSettings.ComputeTopology.Memory.HostingProcessNameSuffix = wsl::windows::common::wslutil::c_vmOwner;
|
||||
}
|
||||
|
||||
// If nested virtualization was requested, ensure the platform supports it.
|
||||
|
||||
@ -39,8 +39,6 @@ inline constexpr auto c_optionsValueName = L"Options";
|
||||
inline constexpr auto c_typeValueName = L"Type";
|
||||
inline constexpr auto c_mountNameValueName = L"Name";
|
||||
|
||||
inline constexpr auto c_vmOwner = L"WSL";
|
||||
|
||||
static constexpr GUID c_virtiofsAdminClassId = {0x7e6ad219, 0xd1b3, 0x42d5, {0xb8, 0xee, 0xd9, 0x63, 0x24, 0xe6, 0x4f, 0xf6}};
|
||||
|
||||
// {60285AE6-AAF3-4456-B444-A6C2D0DEDA38}
|
||||
|
||||
@ -396,97 +396,3 @@ cpp_quote("#define WSL_E_DISK_CORRUPTED MAKE_HRESULT(SEVERITY_ERROR, FACILITY_IT
|
||||
cpp_quote("#define WSL_E_DISTRIBUTION_NAME_NEEDED MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSL_E_BASE + 0x30) /* 0x80040330 */")
|
||||
cpp_quote("#define WSL_E_INVALID_JSON MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSL_E_BASE + 0x31) /* 0x80040331 */")
|
||||
cpp_quote("#define WSL_E_VM_CRASHED MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSL_E_BASE + 0x32) /* 0x80040332 */")
|
||||
|
||||
|
||||
typedef
|
||||
struct _WSL_VERSION {
|
||||
ULONG Major;
|
||||
ULONG Minor;
|
||||
ULONG Revision;
|
||||
} WSL_VERSION;
|
||||
|
||||
|
||||
typedef [system_handle(sh_socket)] HANDLE HVSOCKET_HANDLE;
|
||||
|
||||
|
||||
typedef
|
||||
struct _WSLA_CREATE_PROCESS_OPTIONS {
|
||||
[string] LPCSTR Executable;
|
||||
ULONG CommandLineCount;
|
||||
[unique, size_is(CommandLineCount)] LPCSTR* CommandLine;
|
||||
ULONG EnvironmentCount;
|
||||
[unique, size_is(EnvironmentCount)] LPCSTR* Environment;
|
||||
[unique] LPCSTR CurrentDirectory;
|
||||
} WSLA_CREATE_PROCESS_OPTIONS;
|
||||
|
||||
typedef struct _WSLA_PROCESS_FD
|
||||
{
|
||||
LONG Fd;
|
||||
int Type;
|
||||
[string, unique] LPCSTR Path;
|
||||
} WSLA_PROCESS_FD;
|
||||
|
||||
typedef
|
||||
struct _WSLA_CREATE_PROCESS_RESULT {
|
||||
int Errno;
|
||||
int Pid;
|
||||
} WSLA_CREATE_PROCESS_RESULT;
|
||||
|
||||
[
|
||||
uuid(7BC4E198-6531-4FA6-ADE2-5EF3D2A04DFE),
|
||||
pointer_default(unique),
|
||||
object
|
||||
]
|
||||
interface ITerminationCallback : IUnknown
|
||||
{
|
||||
HRESULT OnTermination(ULONG Reason, LPCWSTR Details);
|
||||
};
|
||||
|
||||
[
|
||||
uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8761),
|
||||
pointer_default(unique),
|
||||
object
|
||||
]
|
||||
interface IWSLAVirtualMachine : IUnknown
|
||||
{
|
||||
HRESULT AttachDisk([in] LPCWSTR Path, [in] BOOL ReadOnly, [out] LPSTR* Device, [out] ULONG* Lun);
|
||||
HRESULT Mount([in, unique] LPCSTR Source, [in] LPCSTR Target, [in] LPCSTR Type, [in] LPCSTR Options, [in] ULONG Flags);
|
||||
HRESULT CreateLinuxProcess([in] const WSLA_CREATE_PROCESS_OPTIONS* Options, [in] ULONG FdCount, [in, unique, size_is(FdCount)] WSLA_PROCESS_FD* Fds, [out, size_is(FdCount)] ULONG* Handles, [out] WSLA_CREATE_PROCESS_RESULT* Result);
|
||||
HRESULT WaitPid([in] LONG Pid, [in] ULONGLONG TimeoutMs, [out] ULONG* State, [out] int* Code);
|
||||
HRESULT Signal([in] LONG Pid, [in] int Signal);
|
||||
HRESULT Shutdown([in] ULONGLONG TimeoutMs);
|
||||
HRESULT RegisterCallback([in] ITerminationCallback* terminationCallback);
|
||||
HRESULT GetDebugShellPipe([out] LPWSTR* pipePath);
|
||||
HRESULT MapPort([in] int Family, [in] short WindowsPort, [in] short LinuxPort, [in] BOOL Remove);
|
||||
HRESULT Unmount([in] LPCSTR Path);
|
||||
HRESULT DetachDisk([in] ULONG Lun);
|
||||
HRESULT MountWindowsFolder([in] LPCWSTR WindowsPath, [in] LPCSTR LinuxPath, [in] BOOL ReadOnly);
|
||||
HRESULT UnmountWindowsFolder([in] LPCSTR LinuxPath);
|
||||
HRESULT MountGpuLibraries([in] LPCSTR LibrariesMountPoint, [in] LPCSTR DriversMountpoint, [in] DWORD Flags);
|
||||
}
|
||||
|
||||
typedef
|
||||
struct _VIRTUAL_MACHINE_SETTINGS {
|
||||
LPCWSTR DisplayName;
|
||||
ULONGLONG MemoryMb;
|
||||
ULONG CpuCount;
|
||||
ULONG BootTimeoutMs;
|
||||
ULONG DmesgOutput;
|
||||
ULONG NetworkingMode;
|
||||
BOOL EnableDnsTunneling;
|
||||
BOOL EnableDebugShell;
|
||||
BOOL EnableEarlyBootDmesg;
|
||||
BOOL EnableGPU;
|
||||
} VIRTUAL_MACHINE_SETTINGS;
|
||||
|
||||
|
||||
[
|
||||
uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8760),
|
||||
pointer_default(unique),
|
||||
object
|
||||
]
|
||||
interface IWSLAUserSession : IUnknown
|
||||
{
|
||||
HRESULT GetVersion([out] WSL_VERSION* Error);
|
||||
HRESULT CreateVirtualMachine([in] const VIRTUAL_MACHINE_SETTINGS* Settings, [out]IWSLAVirtualMachine** VirtualMachine);
|
||||
}
|
||||
@ -3,6 +3,7 @@ set(HEADERS WSLAApi.h)
|
||||
|
||||
add_library(wslaclient SHARED ${SOURCES} ${HEADERS} wslaclient.def)
|
||||
set_target_properties(wslaclient PROPERTIES EXCLUDE_FROM_ALL FALSE)
|
||||
add_dependencies(wslaclient wslaserviceidl)
|
||||
target_link_libraries(wslaclient ${COMMON_LINK_LIBRARIES} legacy_stdio_definitions common)
|
||||
target_precompile_headers(wslaclient REUSE_FROM common)
|
||||
set_target_properties(wslaclient PROPERTIES FOLDER windows)
|
||||
@ -13,7 +13,7 @@ Abstract:
|
||||
--*/
|
||||
|
||||
#include "precomp.h"
|
||||
#include "wslservice.h"
|
||||
#include "wslaservice.h"
|
||||
#include "WSLAApi.h"
|
||||
#include "wslrelay.h"
|
||||
#include "wslInstall.h"
|
||||
|
||||
6
src/windows/wslaservice/CMakeLists.txt
Normal file
6
src/windows/wslaservice/CMakeLists.txt
Normal file
@ -0,0 +1,6 @@
|
||||
add_compile_definitions("PROXY_CLSID_IS={0x4EA0C6DD,0xE9FF,0x48E7,{0x99,0x4e,0x13,0xa3,0x1d,0x10,0xdc,0x61}}")
|
||||
add_compile_definitions("REGISTER_PROXY_DLL")
|
||||
|
||||
add_subdirectory(exe)
|
||||
add_subdirectory(inc)
|
||||
add_subdirectory(stub)
|
||||
34
src/windows/wslaservice/exe/CMakeLists.txt
Normal file
34
src/windows/wslaservice/exe/CMakeLists.txt
Normal file
@ -0,0 +1,34 @@
|
||||
set(SOURCES
|
||||
application.manifest
|
||||
main.rc
|
||||
ServiceMain.cpp
|
||||
WSLAUserSession.cpp
|
||||
WSLAUserSessionFactory.cpp
|
||||
WSLAVirtualMachine.cpp
|
||||
)
|
||||
|
||||
set(HEADERS
|
||||
WSLAUserSession.h
|
||||
WSLAUserSessionFactory.h
|
||||
WSLAVirtualMachine.h)
|
||||
|
||||
include_directories(${CMAKE_SOURCE_DIR}/src/windows/wslaclient)
|
||||
|
||||
add_executable(wslaservice ${SOURCES} ${HEADERS})
|
||||
add_dependencies(wslaservice wslaserviceidl)
|
||||
add_compile_definitions(__WRL_CLASSIC_COM__)
|
||||
add_compile_definitions(__WRL_DISABLE_STATIC_INITIALIZE__)
|
||||
add_compile_definitions(USE_COM_CONTEXT_DEF=1)
|
||||
set_target_properties(wslaservice PROPERTIES LINK_FLAGS "/merge:minATL=.rdata /include:__minATLObjMap_WSLAUserSession_COM")
|
||||
target_link_libraries(wslaservice
|
||||
${COMMON_LINK_LIBRARIES}
|
||||
computecore
|
||||
common
|
||||
configfile
|
||||
legacy_stdio_definitions
|
||||
VirtDisk.lib
|
||||
Winhttp.lib
|
||||
Synchronization.lib)
|
||||
|
||||
target_precompile_headers(wslaservice REUSE_FROM common)
|
||||
set_target_properties(wslaservice PROPERTIES FOLDER windows)
|
||||
126
src/windows/wslaservice/exe/ServiceMain.cpp
Normal file
126
src/windows/wslaservice/exe/ServiceMain.cpp
Normal file
@ -0,0 +1,126 @@
|
||||
/*++
|
||||
|
||||
Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
Module Name:
|
||||
|
||||
ServiceMain.cpp
|
||||
|
||||
Abstract:
|
||||
|
||||
This file contains the entrypoint for the Lxss Manager service.
|
||||
|
||||
--*/
|
||||
|
||||
#include "precomp.h"
|
||||
#include "comservicehelper.h"
|
||||
#include "LxssSecurity.h"
|
||||
#include "WslCoreFilesystem.h"
|
||||
#include "WSLAUserSessionFactory.h"
|
||||
#include <ctime>
|
||||
|
||||
using namespace wsl::windows::common::registry;
|
||||
using namespace wsl::windows::common::string;
|
||||
using namespace wsl::windows::common::wslutil;
|
||||
using namespace wsl::windows::policies;
|
||||
|
||||
wil::unique_event g_networkingReady{wil::EventOptions::ManualReset};
|
||||
|
||||
// Declare the WSLAUserSession COM class.
|
||||
CoCreatableClassWrlCreatorMapInclude(WSLAUserSession);
|
||||
|
||||
struct WslaServiceSecurityPolicy
|
||||
{
|
||||
static LPCWSTR GetSDDLText()
|
||||
{
|
||||
// COM Access and Launch permissions allowed for authenticated user, principal self, and system.
|
||||
// 0xB = (COM_RIGHTS_EXECUTE | COM_RIGHTS_EXECUTE_LOCAL | COM_RIGHTS_ACTIVATE_LOCAL)
|
||||
// N.B. This should be kept in sync with the security descriptors in the appxmanifest and package.wix.
|
||||
return L"O:BAG:BAD:(A;;0xB;;;AU)(A;;0xB;;;PS)(A;;0xB;;;SY)";
|
||||
}
|
||||
};
|
||||
|
||||
class WslaService : public Windows::Internal::Service<WslaService, Windows::Internal::ContinueRunningWithNoObjects, WslaServiceSecurityPolicy>
|
||||
{
|
||||
public:
|
||||
static wchar_t* GetName()
|
||||
{
|
||||
return const_cast<LPWSTR>(L"WSLAService");
|
||||
}
|
||||
|
||||
static void OnSessionChanged(DWORD eventType, DWORD sessionId);
|
||||
HRESULT OnServiceStarting();
|
||||
HRESULT ServiceStarted();
|
||||
void ServiceStopped();
|
||||
|
||||
private:
|
||||
wil::unique_couninitialize_call m_coInit{false};
|
||||
};
|
||||
|
||||
HRESULT WslaService::OnServiceStarting()
|
||||
try
|
||||
{
|
||||
ConfigureCrt();
|
||||
|
||||
// Enable contextualized errors
|
||||
wsl::windows::common::EnableContextualizedErrors(true);
|
||||
|
||||
// Initialize telemetry.
|
||||
// TODO-WSLA: Create a dedicated WSLA provider
|
||||
WslTraceLoggingInitialize(WslServiceTelemetryProvider, !wsl::shared::OfficialBuild);
|
||||
|
||||
WSL_LOG("Service starting", TraceLoggingLevel(WINEVENT_LEVEL_INFO));
|
||||
|
||||
// Don't kill the process on unknown C++ exceptions.
|
||||
wil::g_fResultFailFastUnknownExceptions = false;
|
||||
|
||||
// wsl::windows::common::security::ApplyProcessMitigationPolicies();
|
||||
|
||||
// Initialize Winsock.
|
||||
WSADATA Data;
|
||||
THROW_IF_WIN32_ERROR(WSAStartup(MAKEWORD(2, 2), &Data));
|
||||
|
||||
return S_OK;
|
||||
}
|
||||
CATCH_RETURN()
|
||||
|
||||
HRESULT WslaService::ServiceStarted()
|
||||
{
|
||||
m_coInit = wil::CoInitializeEx(COINIT_MULTITHREADED);
|
||||
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
void WslaService::OnSessionChanged(DWORD eventType, DWORD sessionId)
|
||||
{
|
||||
if (eventType == WTS_SESSION_LOGOFF)
|
||||
{
|
||||
// TODO-WSLA: Implement for WSLA
|
||||
// TerminateSession(sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
void WslaService::ServiceStopped()
|
||||
{
|
||||
WSL_LOG("Service stopping", TraceLoggingLevel(WINEVENT_LEVEL_INFO));
|
||||
|
||||
// Terminate all user sessions.
|
||||
wsl::windows::service::wsla::ClearWslaSessionsAndBlockNewInstances();
|
||||
|
||||
// There is a potential deadlock if CoUninitialize() is called before the LanguageChangeNotifyThread
|
||||
// isn't done initializing. Clearing the COM objects before calling CoUninitialize() works around the issue.
|
||||
winrt::clear_factory_cache();
|
||||
|
||||
// Tear down telemetry.
|
||||
WslTraceLoggingUninitialize();
|
||||
|
||||
// uninitialize COM. This must be done here because this call can cause cleanups that will be fail
|
||||
// if the CRT is shutting down.
|
||||
m_coInit.reset();
|
||||
}
|
||||
|
||||
int __cdecl wmain()
|
||||
{
|
||||
WslaService::ProcessMain();
|
||||
return 0;
|
||||
}
|
||||
87
src/windows/wslaservice/exe/WSLAUserSession.cpp
Normal file
87
src/windows/wslaservice/exe/WSLAUserSession.cpp
Normal file
@ -0,0 +1,87 @@
|
||||
/*++
|
||||
|
||||
Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
Module Name:
|
||||
|
||||
WSLAUserSession.cpp
|
||||
|
||||
Abstract:
|
||||
|
||||
TODO
|
||||
|
||||
--*/
|
||||
#include "WSLAUserSession.h"
|
||||
|
||||
using wsl::windows::service::wsla::WSLAUserSessionImpl;
|
||||
|
||||
WSLAUserSessionImpl::WSLAUserSessionImpl(HANDLE Token, wil::unique_tokeninfo_ptr<TOKEN_USER>&& TokenInfo) :
|
||||
m_tokenInfo(std::move(TokenInfo))
|
||||
{
|
||||
}
|
||||
|
||||
WSLAUserSessionImpl::~WSLAUserSessionImpl()
|
||||
{
|
||||
// Manually signal the VM termination events. This prevents being stuck on an API call that holds the VM lock.
|
||||
{
|
||||
std::lock_guard lock(m_virtualMachinesLock);
|
||||
|
||||
for (auto* e : m_virtualMachines)
|
||||
{
|
||||
e->OnSessionTerminating();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void WSLAUserSessionImpl::OnVmTerminated(WSLAVirtualMachine* machine)
|
||||
{
|
||||
std::lock_guard lock(m_virtualMachinesLock);
|
||||
auto pred = [machine](const auto* e) { return machine == e; };
|
||||
|
||||
// Remove any stale VM reference.
|
||||
m_virtualMachines.erase(std::remove_if(m_virtualMachines.begin(), m_virtualMachines.end(), pred), m_virtualMachines.end());
|
||||
}
|
||||
|
||||
HRESULT WSLAUserSessionImpl::CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine)
|
||||
{
|
||||
auto vm = wil::MakeOrThrow<WSLAVirtualMachine>(*Settings, GetUserSid(), this);
|
||||
|
||||
{
|
||||
std::lock_guard lock(m_virtualMachinesLock);
|
||||
m_virtualMachines.emplace_back(vm.Get());
|
||||
}
|
||||
|
||||
vm->Start();
|
||||
THROW_IF_FAILED(vm.CopyTo(__uuidof(IWSLAVirtualMachine), (void**)VirtualMachine));
|
||||
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
PSID WSLAUserSessionImpl::GetUserSid() const
|
||||
{
|
||||
return m_tokenInfo->User.Sid;
|
||||
}
|
||||
|
||||
wsl::windows::service::wsla::WSLAUserSession::WSLAUserSession(std::weak_ptr<WSLAUserSessionImpl>&& Session) :
|
||||
m_session(std::move(Session))
|
||||
{
|
||||
}
|
||||
|
||||
HRESULT wsl::windows::service::wsla::WSLAUserSession::GetVersion(_Out_ WSL_VERSION* Version)
|
||||
{
|
||||
Version->Major = WSL_PACKAGE_VERSION_MAJOR;
|
||||
Version->Minor = WSL_PACKAGE_VERSION_MINOR;
|
||||
Version->Revision = WSL_PACKAGE_VERSION_REVISION;
|
||||
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
HRESULT wsl::windows::service::wsla::WSLAUserSession::CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine)
|
||||
try
|
||||
{
|
||||
auto session = m_session.lock();
|
||||
RETURN_HR_IF(RPC_E_DISCONNECTED, !session);
|
||||
|
||||
return session->CreateVirtualMachine(Settings, VirtualMachine);
|
||||
}
|
||||
CATCH_RETURN();
|
||||
56
src/windows/wslaservice/exe/WSLAUserSession.h
Normal file
56
src/windows/wslaservice/exe/WSLAUserSession.h
Normal file
@ -0,0 +1,56 @@
|
||||
/*++
|
||||
|
||||
Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
Module Name:
|
||||
|
||||
WSLAUserSession.h
|
||||
|
||||
Abstract:
|
||||
|
||||
TODO
|
||||
|
||||
--*/
|
||||
#pragma once
|
||||
#include "WSLAVirtualMachine.h"
|
||||
|
||||
namespace wsl::windows::service::wsla {
|
||||
|
||||
class WSLAUserSessionImpl
|
||||
{
|
||||
public:
|
||||
WSLAUserSessionImpl(HANDLE Token, wil::unique_tokeninfo_ptr<TOKEN_USER>&& TokenInfo);
|
||||
WSLAUserSessionImpl(WSLAUserSessionImpl&&) = default;
|
||||
WSLAUserSessionImpl& operator=(WSLAUserSessionImpl&&) = default;
|
||||
|
||||
~WSLAUserSessionImpl();
|
||||
|
||||
PSID GetUserSid() const;
|
||||
|
||||
HRESULT CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine);
|
||||
|
||||
void OnVmTerminated(WSLAVirtualMachine* machine);
|
||||
|
||||
private:
|
||||
wil::unique_tokeninfo_ptr<TOKEN_USER> m_tokenInfo;
|
||||
|
||||
std::recursive_mutex m_virtualMachinesLock;
|
||||
std::vector<WSLAVirtualMachine*> m_virtualMachines;
|
||||
};
|
||||
|
||||
class DECLSPEC_UUID("a9b7a1b9-0671-405c-95f1-e0612cb4ce8f") WSLAUserSession
|
||||
: public Microsoft::WRL::RuntimeClass<Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, IWSLAUserSession, IFastRundown>
|
||||
{
|
||||
public:
|
||||
WSLAUserSession(std::weak_ptr<WSLAUserSessionImpl>&& Session);
|
||||
WSLAUserSession(const WSLAUserSession&) = delete;
|
||||
WSLAUserSession& operator=(const WSLAUserSession&) = delete;
|
||||
|
||||
IFACEMETHOD(GetVersion)(_Out_ WSL_VERSION* Version) override;
|
||||
IFACEMETHOD(CreateVirtualMachine)(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine) override;
|
||||
|
||||
private:
|
||||
std::weak_ptr<WSLAUserSessionImpl> m_session;
|
||||
};
|
||||
|
||||
} // namespace wsl::windows::service::wsla
|
||||
82
src/windows/wslaservice/exe/WSLAUserSessionFactory.cpp
Normal file
82
src/windows/wslaservice/exe/WSLAUserSessionFactory.cpp
Normal file
@ -0,0 +1,82 @@
|
||||
/*++
|
||||
|
||||
Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
Module Name:
|
||||
|
||||
WSLAUserSessionFactory.cpp
|
||||
|
||||
Abstract:
|
||||
|
||||
TODO
|
||||
|
||||
--*/
|
||||
#include "precomp.h"
|
||||
|
||||
#include "WSLAUserSessionFactory.h"
|
||||
#include "WSLAUserSession.h"
|
||||
|
||||
using wsl::windows::service::wsla::WSLAUserSessionFactory;
|
||||
using wsl::windows::service::wsla::WSLAUserSessionImpl;
|
||||
|
||||
CoCreatableClassWithFactory(WSLAUserSession, WSLAUserSessionFactory);
|
||||
|
||||
static std::mutex g_mutex;
|
||||
static std::optional<std::vector<std::shared_ptr<WSLAUserSessionImpl>>> g_sessions =
|
||||
std::make_optional<std::vector<std::shared_ptr<WSLAUserSessionImpl>>>();
|
||||
|
||||
HRESULT WSLAUserSessionFactory::CreateInstance(_In_ IUnknown* pUnkOuter, _In_ REFIID riid, _Out_ void** ppCreated)
|
||||
{
|
||||
RETURN_HR_IF_NULL(E_POINTER, ppCreated);
|
||||
*ppCreated = nullptr;
|
||||
|
||||
RETURN_HR_IF(CLASS_E_NOAGGREGATION, pUnkOuter != nullptr);
|
||||
|
||||
WSL_LOG("WSLAUserSessionFactory", TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE));
|
||||
|
||||
try
|
||||
{
|
||||
const wil::unique_handle userToken = wsl::windows::common::security::GetUserToken(TokenImpersonation);
|
||||
|
||||
// Get the session ID and SID of the client process.
|
||||
DWORD sessionId{};
|
||||
DWORD length = 0;
|
||||
THROW_IF_WIN32_BOOL_FALSE(::GetTokenInformation(userToken.get(), TokenSessionId, &sessionId, sizeof(sessionId), &length));
|
||||
|
||||
auto tokenInfo = wil::get_token_information<TOKEN_USER>(userToken.get());
|
||||
|
||||
std::lock_guard lock{g_mutex};
|
||||
|
||||
THROW_HR_IF(CO_E_SERVER_STOPPING, !g_sessions.has_value());
|
||||
|
||||
auto session = std::find_if(g_sessions->begin(), g_sessions->end(), [&tokenInfo](auto it) {
|
||||
return EqualSid(it->GetUserSid(), &tokenInfo->User.Sid);
|
||||
});
|
||||
|
||||
if (session == g_sessions->end())
|
||||
{
|
||||
session = g_sessions->insert(g_sessions->end(), std::make_shared<WSLAUserSessionImpl>(userToken.get(), std::move(tokenInfo)));
|
||||
}
|
||||
|
||||
auto comInstance = wil::MakeOrThrow<WSLAUserSession>(std::weak_ptr<WSLAUserSessionImpl>(*session));
|
||||
|
||||
THROW_IF_FAILED(comInstance.CopyTo(riid, ppCreated));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
const auto result = wil::ResultFromCaughtException();
|
||||
|
||||
// Note: S_FALSE will cause COM to retry if the service is stopping.
|
||||
return result == CO_E_SERVER_STOPPING ? S_FALSE : result;
|
||||
}
|
||||
|
||||
WSL_LOG("WSLAUserSessionFactory", TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE));
|
||||
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
void wsl::windows::service::wsla::ClearWslaSessionsAndBlockNewInstances()
|
||||
{
|
||||
std::lock_guard lock{g_mutex};
|
||||
g_sessions.reset();
|
||||
}
|
||||
28
src/windows/wslaservice/exe/WSLAUserSessionFactory.h
Normal file
28
src/windows/wslaservice/exe/WSLAUserSessionFactory.h
Normal file
@ -0,0 +1,28 @@
|
||||
/*++
|
||||
|
||||
Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
Module Name:
|
||||
|
||||
WSLAUserSessionFactory.h
|
||||
|
||||
Abstract:
|
||||
|
||||
TODO
|
||||
|
||||
--*/
|
||||
#pragma once
|
||||
#include <wil/resource.h>
|
||||
|
||||
namespace wsl::windows::service::wsla {
|
||||
|
||||
class WSLAUserSessionFactory : public Microsoft::WRL::ClassFactory<>
|
||||
{
|
||||
public:
|
||||
WSLAUserSessionFactory() = default;
|
||||
|
||||
STDMETHODIMP CreateInstance(_In_ IUnknown* pUnkOuter, _In_ REFIID riid, _Out_ void** ppCreated) override;
|
||||
};
|
||||
|
||||
void ClearWslaSessionsAndBlockNewInstances();
|
||||
} // namespace wsl::windows::service::wsla
|
||||
1052
src/windows/wslaservice/exe/WSLAVirtualMachine.cpp
Normal file
1052
src/windows/wslaservice/exe/WSLAVirtualMachine.cpp
Normal file
File diff suppressed because it is too large
Load Diff
106
src/windows/wslaservice/exe/WSLAVirtualMachine.h
Normal file
106
src/windows/wslaservice/exe/WSLAVirtualMachine.h
Normal file
@ -0,0 +1,106 @@
|
||||
/*++
|
||||
|
||||
Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
Module Name:
|
||||
|
||||
WSLAVirtualMachine.h
|
||||
|
||||
Abstract:
|
||||
|
||||
TODO
|
||||
|
||||
--*/
|
||||
#pragma once
|
||||
#include "wslaservice.h"
|
||||
#include "INetworkingEngine.h"
|
||||
#include "hcs.hpp"
|
||||
#include "Dmesg.h"
|
||||
#include "WSLAApi.h"
|
||||
|
||||
namespace wsl::windows::service::wsla {
|
||||
|
||||
class WSLAUserSessionImpl;
|
||||
|
||||
class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") WSLAVirtualMachine
|
||||
: public Microsoft::WRL::RuntimeClass<Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, IWSLAVirtualMachine, IFastRundown>
|
||||
|
||||
{
|
||||
public:
|
||||
WSLAVirtualMachine(const VIRTUAL_MACHINE_SETTINGS& Settings, PSID Sid, WSLAUserSessionImpl* UserSession);
|
||||
~WSLAVirtualMachine();
|
||||
|
||||
void Start();
|
||||
void OnSessionTerminating();
|
||||
|
||||
IFACEMETHOD(AttachDisk(_In_ PCWSTR Path, _In_ BOOL ReadOnly, _Out_ LPSTR* Device, _Out_ ULONG* Lun)) override;
|
||||
IFACEMETHOD(Mount(_In_ LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags)) override;
|
||||
IFACEMETHOD(CreateLinuxProcess(
|
||||
_In_ const WSLA_CREATE_PROCESS_OPTIONS* Options, _In_ ULONG FdCount, _In_ WSLA_PROCESS_FD* Fd, _Out_ ULONG* Handles, _Out_ WSLA_CREATE_PROCESS_RESULT* Result)) override;
|
||||
IFACEMETHOD(WaitPid(_In_ LONG Pid, _In_ ULONGLONG TimeoutMs, _Out_ ULONG* State, _Out_ int* Code)) override;
|
||||
IFACEMETHOD(Signal(_In_ LONG Pid, _In_ int Signal)) override;
|
||||
IFACEMETHOD(Shutdown(ULONGLONG _In_ TimeoutMs)) override;
|
||||
IFACEMETHOD(RegisterCallback(_In_ ITerminationCallback* callback)) override;
|
||||
IFACEMETHOD(GetDebugShellPipe(_Out_ LPWSTR* pipePath)) override;
|
||||
IFACEMETHOD(MapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ BOOL Remove)) override;
|
||||
IFACEMETHOD(Unmount(_In_ const char* Path)) override;
|
||||
IFACEMETHOD(DetachDisk(_In_ ULONG Lun)) override;
|
||||
IFACEMETHOD(MountWindowsFolder(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly)) override;
|
||||
IFACEMETHOD(UnmountWindowsFolder(_In_ LPCSTR LinuxPath)) override;
|
||||
IFACEMETHOD(MountGpuLibraries(_In_ LPCSTR LibrariesMountPoint, _In_ LPCSTR DriversMountpoint, _In_ DWORD Flags)) override;
|
||||
|
||||
private:
|
||||
static int32_t MountImpl(wsl::shared::SocketChannel& Channel, LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags);
|
||||
static void CALLBACK s_OnExit(_In_ HCS_EVENT* Event, _In_opt_ void* Context);
|
||||
static bool ParseTtyInformation(const WSLA_PROCESS_FD* Fds, ULONG FdCount, const WSLA_PROCESS_FD** TtyInput, const WSLA_PROCESS_FD** TtyOutput);
|
||||
|
||||
void ConfigureNetworking();
|
||||
void OnExit(_In_ const HCS_EVENT* Event);
|
||||
|
||||
std::tuple<int32_t, int32_t, wsl::shared::SocketChannel> Fork(enum WSLA_FORK::ForkType Type);
|
||||
std::tuple<int32_t, int32_t, wsl::shared::SocketChannel> Fork(wsl::shared::SocketChannel& Channel, enum WSLA_FORK::ForkType Type);
|
||||
int32_t ExpectClosedChannelOrError(wsl::shared::SocketChannel& Channel);
|
||||
|
||||
wil::unique_socket ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd);
|
||||
void OpenLinuxFile(wsl::shared::SocketChannel& Channel, const char* Path, uint32_t Flags, int32_t Fd);
|
||||
void LaunchPortRelay();
|
||||
|
||||
std::vector<wil::unique_socket> CreateLinuxProcessImpl(
|
||||
_In_ const WSLA_CREATE_PROCESS_OPTIONS* Options, _In_ ULONG FdCount, _In_ WSLA_PROCESS_FD* Fd, _Out_ WSLA_CREATE_PROCESS_RESULT* Result);
|
||||
|
||||
HRESULT MountWindowsFolderImpl(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly, _In_ WslMountFlags Flags);
|
||||
|
||||
struct AttachedDisk
|
||||
{
|
||||
std::filesystem::path Path;
|
||||
std::string Device;
|
||||
bool AccessGranted = false;
|
||||
};
|
||||
|
||||
VIRTUAL_MACHINE_SETTINGS m_settings;
|
||||
GUID m_vmId{};
|
||||
std::wstring m_vmIdString;
|
||||
wsl::windows::common::helpers::WindowsVersion m_windowsVersion = wsl::windows::common::helpers::GetWindowsVersion();
|
||||
int m_coldDiscardShiftSize{};
|
||||
bool m_running = false;
|
||||
PSID m_userSid{};
|
||||
std::wstring m_debugShellPipe;
|
||||
|
||||
wsl::windows::common::hcs::unique_hcs_system m_computeSystem;
|
||||
std::shared_ptr<DmesgCollector> m_dmesgCollector;
|
||||
wil::unique_event m_vmExitEvent{wil::EventOptions::ManualReset};
|
||||
wil::unique_event m_vmTerminatingEvent{wil::EventOptions::ManualReset};
|
||||
wil::com_ptr<ITerminationCallback> m_terminationCallback;
|
||||
std::unique_ptr<wsl::core::INetworkingEngine> m_networkEngine;
|
||||
|
||||
wsl::shared::SocketChannel m_initChannel;
|
||||
wil::unique_handle m_portRelayChannelRead;
|
||||
wil::unique_handle m_portRelayChannelWrite;
|
||||
|
||||
std::map<ULONG, AttachedDisk> m_attachedDisks;
|
||||
std::map<std::string, std::wstring> m_plan9Mounts;
|
||||
std::recursive_mutex m_lock;
|
||||
std::mutex m_portRelaylock;
|
||||
WSLAUserSessionImpl* m_userSession;
|
||||
};
|
||||
} // namespace wsl::windows::service::wsla
|
||||
8
src/windows/wslaservice/exe/application.manifest
Normal file
8
src/windows/wslaservice/exe/application.manifest
Normal file
@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="utf-8" standalone="yes"?>
|
||||
<assembly xmlns="urn:schemas-microsoft-com:asm.v1" manifestVersion="1.0" xmlns:asmv3="urn:schemas-microsoft-com:asm.v3" >
|
||||
<application xmlns="urn:schemas-microsoft-com:asm.v3">
|
||||
<windowsSettings xmlns:ws2="http://schemas.microsoft.com/SMI/2016/WindowsSettings">
|
||||
<ws2:longPathAware>true</ws2:longPathAware>
|
||||
</windowsSettings>
|
||||
</application>
|
||||
</assembly>
|
||||
28
src/windows/wslaservice/exe/main.rc
Normal file
28
src/windows/wslaservice/exe/main.rc
Normal file
@ -0,0 +1,28 @@
|
||||
/*++
|
||||
|
||||
Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
Module Name:
|
||||
|
||||
main.rc
|
||||
|
||||
Abstract:
|
||||
|
||||
This file contains resources for wslaservice.
|
||||
|
||||
--*/
|
||||
|
||||
#include <windows.h>
|
||||
#include "resource.h"
|
||||
#include "wslversioninfo.h"
|
||||
|
||||
#define VER_INTERNALNAME_STR "wslaservice.exe"
|
||||
#define VER_ORIGINALFILENAME_STR "wslaservice.exe"
|
||||
|
||||
#define VER_FILETYPE VFT_APP
|
||||
#define VER_FILESUBTYPE VFT2_UNKNOWN
|
||||
#define VER_FILEDESCRIPTION_STR "Windows Subsystem for Linux for Apps Service"
|
||||
ID_ICON ICON PRELOAD DISCARDABLE "..\..\..\Images\wsl.ico"
|
||||
|
||||
|
||||
#include <common.ver>
|
||||
15
src/windows/wslaservice/exe/resource.h
Normal file
15
src/windows/wslaservice/exe/resource.h
Normal file
@ -0,0 +1,15 @@
|
||||
/*++
|
||||
|
||||
Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
Module Name:
|
||||
|
||||
resource.h
|
||||
|
||||
Abstract:
|
||||
|
||||
This file contains resource declarations for wslaservice.exe
|
||||
|
||||
--*/
|
||||
|
||||
#define ID_ICON 1
|
||||
2
src/windows/wslaservice/inc/CMakeLists.txt
Normal file
2
src/windows/wslaservice/inc/CMakeLists.txt
Normal file
@ -0,0 +1,2 @@
|
||||
add_idl(wslaserviceidl "wslaservice.idl" "")
|
||||
set_target_properties(wslaserviceidl PROPERTIES FOLDER windows)
|
||||
114
src/windows/wslaservice/inc/wslaservice.idl
Normal file
114
src/windows/wslaservice/inc/wslaservice.idl
Normal file
@ -0,0 +1,114 @@
|
||||
/*++
|
||||
|
||||
Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
||||
Module Name:
|
||||
|
||||
wslaservice.idl
|
||||
|
||||
Abstract:
|
||||
|
||||
This file contains the COM object definitions used to talk with the WSLa
|
||||
service "WslaService"
|
||||
|
||||
--*/
|
||||
|
||||
import "unknwn.idl";
|
||||
import "wtypes.idl";
|
||||
|
||||
cpp_quote("#ifdef __cplusplus")
|
||||
cpp_quote("class DECLSPEC_UUID(\"a9b7a1b9-0671-405c-95f1-e0612cb4ce8f\") WSLAUserSession;")
|
||||
cpp_quote("#endif")
|
||||
|
||||
typedef
|
||||
struct _WSL_VERSION {
|
||||
ULONG Major;
|
||||
ULONG Minor;
|
||||
ULONG Revision;
|
||||
} WSL_VERSION;
|
||||
|
||||
|
||||
typedef [system_handle(sh_socket)] HANDLE HVSOCKET_HANDLE;
|
||||
|
||||
|
||||
typedef
|
||||
struct _WSLA_CREATE_PROCESS_OPTIONS {
|
||||
[string] LPCSTR Executable;
|
||||
ULONG CommandLineCount;
|
||||
[unique, size_is(CommandLineCount)] LPCSTR* CommandLine;
|
||||
ULONG EnvironmentCount;
|
||||
[unique, size_is(EnvironmentCount)] LPCSTR* Environment;
|
||||
[unique] LPCSTR CurrentDirectory;
|
||||
} WSLA_CREATE_PROCESS_OPTIONS;
|
||||
|
||||
typedef struct _WSLA_PROCESS_FD
|
||||
{
|
||||
LONG Fd;
|
||||
int Type;
|
||||
[string, unique] LPCSTR Path;
|
||||
} WSLA_PROCESS_FD;
|
||||
|
||||
typedef
|
||||
struct _WSLA_CREATE_PROCESS_RESULT {
|
||||
int Errno;
|
||||
int Pid;
|
||||
} WSLA_CREATE_PROCESS_RESULT;
|
||||
|
||||
[
|
||||
uuid(7BC4E198-6531-4FA6-ADE2-5EF3D2A04DFE),
|
||||
pointer_default(unique),
|
||||
object
|
||||
]
|
||||
interface ITerminationCallback : IUnknown
|
||||
{
|
||||
HRESULT OnTermination(ULONG Reason, LPCWSTR Details);
|
||||
};
|
||||
|
||||
[
|
||||
uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8761),
|
||||
pointer_default(unique),
|
||||
object
|
||||
]
|
||||
interface IWSLAVirtualMachine : IUnknown
|
||||
{
|
||||
HRESULT AttachDisk([in] LPCWSTR Path, [in] BOOL ReadOnly, [out] LPSTR* Device, [out] ULONG* Lun);
|
||||
HRESULT Mount([in, unique] LPCSTR Source, [in] LPCSTR Target, [in] LPCSTR Type, [in] LPCSTR Options, [in] ULONG Flags);
|
||||
HRESULT CreateLinuxProcess([in] const WSLA_CREATE_PROCESS_OPTIONS* Options, [in] ULONG FdCount, [in, unique, size_is(FdCount)] WSLA_PROCESS_FD* Fds, [out, size_is(FdCount)] ULONG* Handles, [out] WSLA_CREATE_PROCESS_RESULT* Result);
|
||||
HRESULT WaitPid([in] LONG Pid, [in] ULONGLONG TimeoutMs, [out] ULONG* State, [out] int* Code);
|
||||
HRESULT Signal([in] LONG Pid, [in] int Signal);
|
||||
HRESULT Shutdown([in] ULONGLONG TimeoutMs);
|
||||
HRESULT RegisterCallback([in] ITerminationCallback* terminationCallback);
|
||||
HRESULT GetDebugShellPipe([out] LPWSTR* pipePath);
|
||||
HRESULT MapPort([in] int Family, [in] short WindowsPort, [in] short LinuxPort, [in] BOOL Remove);
|
||||
HRESULT Unmount([in] LPCSTR Path);
|
||||
HRESULT DetachDisk([in] ULONG Lun);
|
||||
HRESULT MountWindowsFolder([in] LPCWSTR WindowsPath, [in] LPCSTR LinuxPath, [in] BOOL ReadOnly);
|
||||
HRESULT UnmountWindowsFolder([in] LPCSTR LinuxPath);
|
||||
HRESULT MountGpuLibraries([in] LPCSTR LibrariesMountPoint, [in] LPCSTR DriversMountpoint, [in] DWORD Flags);
|
||||
}
|
||||
|
||||
typedef
|
||||
struct _VIRTUAL_MACHINE_SETTINGS {
|
||||
LPCWSTR DisplayName;
|
||||
ULONGLONG MemoryMb;
|
||||
ULONG CpuCount;
|
||||
ULONG BootTimeoutMs;
|
||||
ULONG DmesgOutput;
|
||||
ULONG NetworkingMode;
|
||||
BOOL EnableDnsTunneling;
|
||||
BOOL EnableDebugShell;
|
||||
BOOL EnableEarlyBootDmesg;
|
||||
BOOL EnableGPU;
|
||||
} VIRTUAL_MACHINE_SETTINGS;
|
||||
|
||||
|
||||
[
|
||||
uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8760),
|
||||
pointer_default(unique),
|
||||
object
|
||||
]
|
||||
interface IWSLAUserSession : IUnknown
|
||||
{
|
||||
HRESULT GetVersion([out] WSL_VERSION* Error);
|
||||
HRESULT CreateVirtualMachine([in] const VIRTUAL_MACHINE_SETTINGS* Settings, [out]IWSLAVirtualMachine** VirtualMachine);
|
||||
}
|
||||
13
src/windows/wslaservice/stub/CMakeLists.txt
Normal file
13
src/windows/wslaservice/stub/CMakeLists.txt
Normal file
@ -0,0 +1,13 @@
|
||||
set(SOURCES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wslaservice_i_${TARGET_PLATFORM}.c
|
||||
${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/wslaservice_p_${TARGET_PLATFORM}.c
|
||||
${CMAKE_CURRENT_BINARY_DIR}/../inc/${TARGET_PLATFORM}/${CMAKE_BUILD_TYPE}/dlldata_${TARGET_PLATFORM}.c
|
||||
${CMAKE_CURRENT_LIST_DIR}/WslaServiceProxyStub.def
|
||||
${CMAKE_CURRENT_LIST_DIR}/WslaServiceProxyStub.rc)
|
||||
|
||||
set_source_files_properties(${SOURCES} PROPERTIES GENERATED TRUE)
|
||||
|
||||
add_library(wslaserviceproxystub SHARED ${SOURCES})
|
||||
add_dependencies(wslaserviceproxystub wslaserviceidl)
|
||||
target_link_libraries(wslaserviceproxystub ${COMMON_LINK_LIBRARIES})
|
||||
set_target_properties(wslaserviceproxystub PROPERTIES FOLDER windows)
|
||||
5
src/windows/wslaservice/stub/WslaServiceProxyStub.def
Normal file
5
src/windows/wslaservice/stub/WslaServiceProxyStub.def
Normal file
@ -0,0 +1,5 @@
|
||||
LIBRARY WslaServiceProxyStub.dll
|
||||
|
||||
EXPORTS
|
||||
DllGetClassObject PRIVATE
|
||||
DllCanUnloadNow PRIVATE
|
||||
23
src/windows/wslaservice/stub/WslaServiceProxyStub.rc
Normal file
23
src/windows/wslaservice/stub/WslaServiceProxyStub.rc
Normal file
@ -0,0 +1,23 @@
|
||||
/*++
|
||||
|
||||
Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
Module Name:
|
||||
|
||||
WslaServiceProxyStub.rc
|
||||
|
||||
Abstract:
|
||||
|
||||
This file contains resources for wslaserviceproxystub.dll.
|
||||
|
||||
--*/
|
||||
|
||||
#include <windows.h>
|
||||
#include "wslversioninfo.h"
|
||||
|
||||
#define VER_INTERNALNAME_STR "wslaserviceproxystub.dll"
|
||||
#define VER_ORIGINALFILENAME_STR "wslaserviceproxystub.dll"
|
||||
|
||||
#define VER_FILEDESCRIPTION_STR "WSLA Service ProxyStub DLL"
|
||||
|
||||
#include <common.ver>
|
||||
@ -1316,6 +1316,17 @@ void StopWslService()
|
||||
StopService(service.get());
|
||||
}
|
||||
|
||||
void StopWslaService()
|
||||
{
|
||||
LogInfo("Stopping WSLAService");
|
||||
const wil::unique_schandle manager{OpenSCManager(nullptr, nullptr, SC_MANAGER_CONNECT)};
|
||||
VERIFY_IS_NOT_NULL(manager);
|
||||
|
||||
const wil::unique_schandle service{OpenService(manager.get(), L"wslaservice", SERVICE_STOP | SERVICE_QUERY_STATUS)};
|
||||
VERIFY_IS_NOT_NULL(service);
|
||||
StopService(service.get());
|
||||
}
|
||||
|
||||
wil::unique_handle GetNonElevatedToken()
|
||||
{
|
||||
const auto token = wil::open_current_access_token(TOKEN_ALL_ACCESS);
|
||||
|
||||
@ -523,6 +523,7 @@ inline auto EnableSystemd(const std::string& extraConfig = "")
|
||||
std::wstring EscapePath(std::wstring_view Path);
|
||||
|
||||
void StopWslService();
|
||||
void StopWslaService();
|
||||
|
||||
std::optional<GUID> GetDistributionId(LPCWSTR Name);
|
||||
wil::unique_hkey OpenDistributionKey(LPCWSTR Name);
|
||||
|
||||
@ -821,7 +821,7 @@ class WSLATests
|
||||
});
|
||||
|
||||
// Stop the service
|
||||
StopWslService();
|
||||
StopWslaService();
|
||||
|
||||
// Verify that the thread is unstuck
|
||||
stuckThread.join();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user