From 3451e8213d0290d5912839b21552fb9bc61a4f21 Mon Sep 17 00:00:00 2001 From: Blue Date: Fri, 8 Aug 2025 11:42:13 -0700 Subject: [PATCH] Implement WSLA API's to query installed components, and install them when needed (#13363) * Save state * Save state * Add tests * Finalize tests * Remove test logic --- src/windows/common/WslInstall.cpp | 22 ++++-- src/windows/common/WslInstall.h | 5 ++ src/windows/common/wslutil.cpp | 41 +++++++++-- src/windows/common/wslutil.h | 4 ++ src/windows/lswclient/DllMain.cpp | 103 ++++++++++++++++++++++++++++ src/windows/lswclient/LSWApi.h | 17 +++++ src/windows/lswclient/lswclient.def | 5 +- test/windows/InstallerTests.cpp | 75 ++++++++++++++++++++ 8 files changed, 261 insertions(+), 11 deletions(-) diff --git a/src/windows/common/WslInstall.cpp b/src/windows/common/WslInstall.cpp index 1aab748..79b277c 100644 --- a/src/windows/common/WslInstall.cpp +++ b/src/windows/common/WslInstall.cpp @@ -23,8 +23,6 @@ Abstract: extern HINSTANCE g_dllInstance; constexpr LPCWSTR c_optionalFeatureInstallStatus = L"InstallStatus"; -constexpr LPCWSTR c_optionalFeatureNameVmp = L"VirtualMachinePlatform"; -constexpr LPCWSTR c_optionalFeatureNameWsl = L"Microsoft-Windows-Subsystem-Linux"; using wsl::shared::Localization; using namespace wsl::windows::common::distribution; @@ -239,18 +237,32 @@ std::pair> WslInstall::CheckForMissingOptionalCo return {rebootRequired, std::move(missingComponents)}; } -void WslInstall::InstallOptionalComponents(const std::vector& components) +DWORD WslInstall::InstallOptionalComponent(LPCWSTR component, bool consoleOutput) { std::wstring systemDirectory; THROW_IF_FAILED(wil::GetSystemDirectoryW(systemDirectory)); const auto dismPath = std::filesystem::path(std::move(systemDirectory)) / L"dism.exe"; + + auto commandLine = std::format(L"{} /Online /NoRestart /enable-feature /featurename:{}", dismPath.native(), component); + + wsl::windows::common::SubProcess process(nullptr, commandLine.c_str()); + if (!consoleOutput) + { + process.SetFlags(CREATE_NEW_CONSOLE); + process.SetShowWindow(SW_HIDE); + } + + return process.Run(); +} + +void WslInstall::InstallOptionalComponents(const std::vector& components) +{ for (const auto& component : components) { wsl::windows::common::wslutil::PrintMessage(Localization::MessageInstallingWindowsComponent(component)); - auto commandLine = std::format(L"{} /Online /NoRestart /enable-feature /featurename:{}", dismPath.wstring(), component); - const auto exitCode = wsl::windows::common::helpers::RunProcess(commandLine); + const auto exitCode = InstallOptionalComponent(component.c_str(), true); if (exitCode != 0 && exitCode != ERROR_SUCCESS_REBOOT_REQUIRED) { THROW_HR_WITH_USER_ERROR(WSL_E_INSTALL_COMPONENT_FAILED, Localization::MessageOptionalComponentInstallFailed(component, exitCode)); diff --git a/src/windows/common/WslInstall.h b/src/windows/common/WslInstall.h index beee2fa..6e141d1 100644 --- a/src/windows/common/WslInstall.h +++ b/src/windows/common/WslInstall.h @@ -19,6 +19,9 @@ Abstract: class WslInstall { public: + static inline LPCWSTR c_optionalFeatureNameVmp = L"VirtualMachinePlatform"; + static inline LPCWSTR c_optionalFeatureNameWsl = L"Microsoft-Windows-Subsystem-Linux"; + struct InstallResult { std::wstring Name; @@ -44,6 +47,8 @@ public: static void InstallOptionalComponents(const std::vector& components); + static DWORD InstallOptionalComponent(LPCWSTR component, bool consoleOutput); + static std::pair InstallModernDistribution( const wsl::windows::common::distribution::ModernDistributionVersion& distribution, const std::optional& version, diff --git a/src/windows/common/wslutil.cpp b/src/windows/common/wslutil.cpp index 5d3bc41..81a45e3 100644 --- a/src/windows/common/wslutil.cpp +++ b/src/windows/common/wslutil.cpp @@ -531,6 +531,20 @@ GUID wsl::windows::common::wslutil::CreateV5Uuid(const GUID& namespaceGuid, cons } std::wstring wsl::windows::common::wslutil::DownloadFile(std::wstring_view Url, std::wstring Filename) +{ + wsl::windows::common::ConsoleProgressBar progressBar; + auto progress = [&](auto current, auto total) { + progressBar.Print(current, total); + return true; + }; + + auto cleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { progressBar.Clear(); }); + + return DownloadFileImpl(Url, Filename, progress); +} + +std::wstring wsl::windows::common::wslutil::DownloadFileImpl( + std::wstring_view Url, std::wstring Filename, const std::function& Progress) { const auto lastSlash = Url.find_last_of('/'); THROW_HR_IF(E_INVALIDARG, lastSlash == std::wstring::npos); @@ -560,7 +574,6 @@ std::wstring wsl::windows::common::wslutil::DownloadFile(std::wstring_view Url, const auto asyncResponse = client.GetInputStreamAsync(winrt::Windows::Foundation::Uri(Url)); std::atomic totalBytes; - wsl::windows::common::ConsoleProgressBar progressBar; asyncResponse.Progress( [&](const winrt::Windows::Foundation::IAsyncOperationWithProgress&, const winrt::Windows::Web::Http::HttpProgress& progress) { @@ -575,12 +588,11 @@ std::wstring wsl::windows::common::wslutil::DownloadFile(std::wstring_view Url, download.Progress([&](const auto& _, uint64_t progress) { if (totalBytes != 0) { - progressBar.Print(progress, totalBytes); + Progress(progress, totalBytes); } }); download.get(); - progressBar.Clear(); deleteFileOnFailure.release(); return file.Path().c_str(); @@ -1134,6 +1146,25 @@ std::vector wsl::windows::common::wslutil::HashFile(HANDLE file, DWORD Alg return fileHash; } +std::optional> wsl::windows::common::wslutil::GetInstalledPackageVersion() +{ + std::wstring packageVersion; + auto result = wil::ResultFromException([&]() { + auto msiKey = wsl::windows::common::registry::OpenLxssMachineKey(KEY_READ); + + packageVersion = wsl::windows::common::registry::ReadString(msiKey.get(), L"Msi", L"Version"); + }); + + if (result == HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND) || result == HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND)) + { + return {}; + } + + THROW_IF_FAILED(result); + + return ParseWslPackageVersion(packageVersion); +} + void wsl::windows::common::wslutil::InitializeWil() { wil::WilInitialize_CppWinRT(); @@ -1248,8 +1279,8 @@ std::pair wsl::windows::common::wslutil::O bool wsl::windows::common::wslutil::IsVirtualMachinePlatformInstalled() { - // Note for Windows 11 22H2 and above builds: If hyper-v is installed but VMP platform isn't, HNS and vmcompute are available - // but calls to HNS will fail if vfpext isn't installed. + // Note for Windows 11 22H2 and above builds: If hyper-v is installed but VMP platform isn't, HNS and vmcompute are + // available but calls to HNS will fail if vfpext isn't installed. return wsl::windows::common::helpers::IsServicePresent(L"HNS") && wsl::windows::common::helpers::IsServicePresent(L"vmcompute") && (helpers::GetWindowsVersion().BuildNumber < helpers::WindowsBuildNumbers::Nickel || diff --git a/src/windows/common/wslutil.h b/src/windows/common/wslutil.h index 1d05526..bd1fb91 100644 --- a/src/windows/common/wslutil.h +++ b/src/windows/common/wslutil.h @@ -84,6 +84,8 @@ GUID CreateV5Uuid(const GUID& namespaceGuid, const std::span na std::wstring DownloadFile(std::wstring_view Url, std::wstring Filename); +std::wstring DownloadFileImpl(std::wstring_view Url, std::wstring Filename, const std::function& Progress); + [[nodiscard]] HANDLE DuplicateHandleFromCallingProcess(_In_ HANDLE handleInTarget); void EnforceFileLimit(LPCWSTR Folder, size_t limit, const std::function& pred); @@ -116,6 +118,8 @@ std::wstring GetSystemErrorString(_In_ HRESULT result); std::wstring GetDebugShellPipeName(_In_ PSID Sid); +std::optional> GetInstalledPackageVersion(); + std::vector HashFile(HANDLE File, DWORD Algorithm); void InitializeWil(); diff --git a/src/windows/lswclient/DllMain.cpp b/src/windows/lswclient/DllMain.cpp index f01231b..21645d5 100644 --- a/src/windows/lswclient/DllMain.cpp +++ b/src/windows/lswclient/DllMain.cpp @@ -16,6 +16,7 @@ Abstract: #include "wslservice.h" #include "LSWApi.h" #include "wslrelay.h" +#include "wslInstall.h" namespace { @@ -292,6 +293,8 @@ EXTERN_C BOOL STDAPICALLTYPE DllMain(_In_ HINSTANCE Instance, _In_ DWORD Reason, { case DLL_PROCESS_ATTACH: WslTraceLoggingInitialize(LxssTelemetryProvider, false); + wsl::windows::common::wslutil::InitializeWil(); + break; case DLL_PROCESS_DETACH: @@ -301,3 +304,103 @@ EXTERN_C BOOL STDAPICALLTYPE DllMain(_In_ HINSTANCE Instance, _In_ DWORD Reason, return TRUE; } + +DEFINE_ENUM_FLAG_OPERATORS(WslInstallComponent); + +HRESULT WslQueryMissingComponents(enum WslInstallComponent* Components) +try +{ + *Components = WslInstallComponentNone; + + // Check for Windows features + WI_SetFlagIf( + *Components, + WslInstallComponentWslOC, + !wsl::windows::common::helpers::IsWindows11OrAbove() && !wsl::windows::common::helpers::IsServicePresent(L"lxssmanager")); + + WI_SetFlagIf(*Components, WslInstallComponentVMPOC, !wsl::windows::common::wslutil::IsVirtualMachinePlatformInstalled()); + + // Check if the WSL package is installed, and if the version supports WSLA + auto version = wsl::windows::common::wslutil::GetInstalledPackageVersion(); + + constexpr auto minimalPackageVersion = wsl::shared::PackageVersion; // TODO: replace with correct version once WSLA is released. + WI_SetFlagIf(*Components, WslInstallComponentWslPackage, !version.has_value() || version < minimalPackageVersion); + + // TODO: Check if hardware supports virtualization. + + return S_OK; +} +CATCH_RETURN(); + +// Used for debugging. +static LPCWSTR PackageUrl = nullptr; + +HRESULT WslSetPackageUrl(LPCWSTR Url) +{ + PackageUrl = Url; + return S_OK; +} + +HRESULT WslInstallComponents(enum WslInstallComponent Components, WslInstallCallback ProgressCallback, void* Context) +try +{ + // Check for invalid flags. + RETURN_HR_IF_MSG( + E_INVALIDARG, + (Components & ~(WslInstallComponentVMPOC | WslInstallComponentWslOC | WslInstallComponentWslPackage)) != 0, + "Unexpected flag: %i", + Components); + + // Fail if the caller is not elevated. + RETURN_HR_IF( + HRESULT_FROM_WIN32(ERROR_ELEVATION_REQUIRED), + Components != 0 && !wsl::windows::common::security::IsTokenElevated(wil::open_current_access_token().get())); + + if (WI_IsFlagSet(Components, WslInstallComponentWslPackage)) + { + THROW_HR_IF(E_INVALIDARG, PackageUrl == nullptr); + + auto callback = [&](uint64_t progress, uint64_t total) { + if (ProgressCallback != nullptr) + { + ProgressCallback(WslInstallComponentWslPackage, progress, total, Context); + } + }; + + const auto downloadPath = wsl::windows::common::wslutil::DownloadFileImpl(PackageUrl, L"wsl.msi", callback); + + auto cleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { std::filesystem::remove(downloadPath); }); + + auto exitCode = wsl::windows::common::wslutil::UpgradeViaMsi(downloadPath.c_str(), nullptr, nullptr, [](auto, auto) {}); + THROW_HR_IF_MSG( + E_FAIL, exitCode != 0, "MSI installation failed. URL: %ls, DownloadPath: %ls, exitCode: %u", downloadPath.c_str(), PackageUrl, exitCode); + } + + std::vector optionalComponents; + if (WI_IsFlagSet(Components, WslInstallComponentWslOC)) + { + if (ProgressCallback != nullptr) + { + ProgressCallback(WslInstallComponentWslOC, 0, 1, Context); + } + + auto exitCode = WslInstall::InstallOptionalComponent(WslInstall::c_optionalFeatureNameWsl, false); + THROW_HR_IF_MSG(E_FAIL, exitCode != 0 && exitCode != ERROR_SUCCESS_REBOOT_REQUIRED, "Failed to install '%ls', %lu", WslInstall::c_optionalFeatureNameWsl, exitCode); + } + + if (WI_IsFlagSet(Components, WslInstallComponentVMPOC)) + { + if (ProgressCallback != nullptr) + { + ProgressCallback(WslInstallComponentVMPOC, 0, 1, Context); + } + + auto exitCode = WslInstall::InstallOptionalComponent(WslInstall::c_optionalFeatureNameVmp, false); + THROW_HR_IF_MSG(E_FAIL, exitCode != 0 && exitCode != ERROR_SUCCESS_REBOOT_REQUIRED, "Failed to install '%ls', %lu", WslInstall::c_optionalFeatureNameVmp, exitCode); + } + + return WI_IsAnyFlagSet(Components, WslInstallComponentWslOC | WslInstallComponentVMPOC) + ? HRESULT_FROM_WIN32(ERROR_SUCCESS_REBOOT_REQUIRED) + : S_OK; +} +CATCH_RETURN(); \ No newline at end of file diff --git a/src/windows/lswclient/LSWApi.h b/src/windows/lswclient/LSWApi.h index 18d840e..bd48f4e 100644 --- a/src/windows/lswclient/LSWApi.h +++ b/src/windows/lswclient/LSWApi.h @@ -179,6 +179,23 @@ HRESULT WslMapPort(LSWVirtualMachineHandle VirtualMachine, const struct PortMapp HRESULT WslUnmapPort(LSWVirtualMachineHandle VirtualMachine, const struct PortMappingSettings* Settings); +enum WslInstallComponent +{ + WslInstallComponentNone = 0, + WslInstallComponentVMPOC = 1, + WslInstallComponentWslOC = 2, + WslInstallComponentWslPackage = 4, +}; + +HRESULT WslQueryMissingComponents(enum WslInstallComponent* Components); + +typedef void (*WslInstallCallback)(enum WslInstallComponent, uint64_t, uint64_t, void*); + +HRESULT WslInstallComponents(enum WslInstallComponent Components, WslInstallCallback ProgressCallback, void* Context); + +// Used for testing until the package is published. +HRESULT WslSetPackageUrl(LPCWSTR Url); + #ifdef __cplusplus } #endif \ No newline at end of file diff --git a/src/windows/lswclient/lswclient.def b/src/windows/lswclient/lswclient.def index d15872c..726ff1a 100644 --- a/src/windows/lswclient/lswclient.def +++ b/src/windows/lswclient/lswclient.def @@ -13,4 +13,7 @@ EXPORTS WslLaunchInteractiveTerminal WslLaunchDebugShell WslMapPort - WslUnmapPort \ No newline at end of file + WslUnmapPort + WslQueryMissingComponents + WslInstallComponents + WslSetPackageUrl \ No newline at end of file diff --git a/test/windows/InstallerTests.cpp b/test/windows/InstallerTests.cpp index 9b265f6..0f99d9b 100644 --- a/test/windows/InstallerTests.cpp +++ b/test/windows/InstallerTests.cpp @@ -18,12 +18,15 @@ Abstract: #include "Common.h" #include "registry.hpp" #include "PluginTests.h" +#include "lswapi.h" using namespace wsl::windows::common::registry; extern std::wstring g_dumpFolder; static std::wstring g_pipelineBuildId; +DEFINE_ENUM_FLAG_OPERATORS(WslInstallComponent); + class InstallerTests { std::wstring m_msixPackagePath; @@ -1037,4 +1040,76 @@ class InstallerTests SHChangeNotify(SHCNE_ASSOCCHANGED, SHCNF_IDLIST, nullptr, nullptr); VerifyWslSettingsProtocolAssociationExistsWithRetry(); } + + TEST_METHOD(WSLAInstall) + { + auto expectComponents = [](WslInstallComponent expected) { + WslInstallComponent components{}; + VERIFY_SUCCEEDED(WslQueryMissingComponents(&components)); + + VERIFY_ARE_EQUAL(components, expected); + }; + + VERIFY_ARE_EQUAL(WslInstallComponents(WslInstallComponentWslPackage, nullptr, nullptr), E_INVALIDARG); + + expectComponents(WslInstallComponentNone); + UninstallMsi(); + + expectComponents(WslInstallComponentWslPackage); + + { + UniqueWebServer fileServer(L"http://127.0.0.1:12346/", std::filesystem::path(m_msiPath)); + VERIFY_SUCCEEDED(WslSetPackageUrl(L"http://127.0.0.1:12346/")); + + WslInstallComponent progressedComponents{}; + auto callback = [](WslInstallComponent Component, uint64_t progress, uint64_t total, void* Context) { + *reinterpret_cast(Context) |= Component; + }; + + VERIFY_SUCCEEDED(WslInstallComponents(WslInstallComponentWslPackage, callback, &progressedComponents)); + VERIFY_ARE_EQUAL(progressedComponents, WslInstallComponentWslPackage); + + ValidateInstalledVersion(WIDEN(WSL_PACKAGE_VERSION)); + expectComponents(WslInstallComponentNone); + + progressedComponents = WslInstallComponentNone; + VERIFY_ARE_EQUAL(WslInstallComponents(WslInstallComponentVMPOC, callback, &progressedComponents), HRESULT_FROM_WIN32(ERROR_SUCCESS_REBOOT_REQUIRED)); + VERIFY_ARE_EQUAL(progressedComponents, WslInstallComponentVMPOC); + + progressedComponents = WslInstallComponentNone; + VERIFY_ARE_EQUAL(WslInstallComponents(WslInstallComponentWslOC, callback, &progressedComponents), HRESULT_FROM_WIN32(ERROR_SUCCESS_REBOOT_REQUIRED)); + VERIFY_ARE_EQUAL(progressedComponents, WslInstallComponentWslOC); + } + + { + VERIFY_SUCCEEDED(WslSetPackageUrl(L"http://127.0.0.1:12346/")); + VERIFY_ARE_EQUAL(WslInstallComponents(WslInstallComponentWslPackage, nullptr, nullptr), WININET_E_CANNOT_CONNECT); + } + } + + // This test case requires a machine without the OC's enabled. + TEST_METHOD(WSLAInstallManual) + { + WslInstallComponent components{}; + VERIFY_SUCCEEDED(WslQueryMissingComponents(&components)); + + if (!WI_IsAnyFlagSet(components, WslInstallComponentWslOC | WslInstallComponentVMPOC)) + { + LogSkipped("OC are installed, skipping test. Flags: %i", components); + return; + } + + auto expectedComponents = WslInstallComponentVMPOC; + WI_SetFlagIf(expectedComponents, WslInstallComponentWslOC, !wsl::windows::common::helpers::IsWindows11OrAbove()); + + VERIFY_ARE_EQUAL(components, expectedComponents); + + WslInstallComponent progressedComponents{}; + auto callback = [](WslInstallComponent Component, uint64_t progress, uint64_t total, void* Context) { + *reinterpret_cast(Context) |= Component; + }; + + VERIFY_ARE_EQUAL(WslInstallComponents(components, callback, &progressedComponents), HRESULT_FROM_WIN32(ERROR_SUCCESS_REBOOT_REQUIRED)); + VERIFY_ARE_EQUAL(progressedComponents, expectedComponents); + } }; \ No newline at end of file