Implement WSLA API's to query installed components, and install them when needed (#13363)

* Save state

* Save state

* Add tests

* Finalize tests

* Remove test logic
This commit is contained in:
Blue 2025-08-08 11:42:13 -07:00 committed by Blue
parent 5438e046f9
commit 3451e8213d
8 changed files with 261 additions and 11 deletions

View File

@ -23,8 +23,6 @@ Abstract:
extern HINSTANCE g_dllInstance;
constexpr LPCWSTR c_optionalFeatureInstallStatus = L"InstallStatus";
constexpr LPCWSTR c_optionalFeatureNameVmp = L"VirtualMachinePlatform";
constexpr LPCWSTR c_optionalFeatureNameWsl = L"Microsoft-Windows-Subsystem-Linux";
using wsl::shared::Localization;
using namespace wsl::windows::common::distribution;
@ -239,18 +237,32 @@ std::pair<bool, std::vector<std::wstring>> WslInstall::CheckForMissingOptionalCo
return {rebootRequired, std::move(missingComponents)};
}
void WslInstall::InstallOptionalComponents(const std::vector<std::wstring>& components)
DWORD WslInstall::InstallOptionalComponent(LPCWSTR component, bool consoleOutput)
{
std::wstring systemDirectory;
THROW_IF_FAILED(wil::GetSystemDirectoryW(systemDirectory));
const auto dismPath = std::filesystem::path(std::move(systemDirectory)) / L"dism.exe";
auto commandLine = std::format(L"{} /Online /NoRestart /enable-feature /featurename:{}", dismPath.native(), component);
wsl::windows::common::SubProcess process(nullptr, commandLine.c_str());
if (!consoleOutput)
{
process.SetFlags(CREATE_NEW_CONSOLE);
process.SetShowWindow(SW_HIDE);
}
return process.Run();
}
void WslInstall::InstallOptionalComponents(const std::vector<std::wstring>& components)
{
for (const auto& component : components)
{
wsl::windows::common::wslutil::PrintMessage(Localization::MessageInstallingWindowsComponent(component));
auto commandLine = std::format(L"{} /Online /NoRestart /enable-feature /featurename:{}", dismPath.wstring(), component);
const auto exitCode = wsl::windows::common::helpers::RunProcess(commandLine);
const auto exitCode = InstallOptionalComponent(component.c_str(), true);
if (exitCode != 0 && exitCode != ERROR_SUCCESS_REBOOT_REQUIRED)
{
THROW_HR_WITH_USER_ERROR(WSL_E_INSTALL_COMPONENT_FAILED, Localization::MessageOptionalComponentInstallFailed(component, exitCode));

View File

@ -19,6 +19,9 @@ Abstract:
class WslInstall
{
public:
static inline LPCWSTR c_optionalFeatureNameVmp = L"VirtualMachinePlatform";
static inline LPCWSTR c_optionalFeatureNameWsl = L"Microsoft-Windows-Subsystem-Linux";
struct InstallResult
{
std::wstring Name;
@ -44,6 +47,8 @@ public:
static void InstallOptionalComponents(const std::vector<std::wstring>& components);
static DWORD InstallOptionalComponent(LPCWSTR component, bool consoleOutput);
static std::pair<std::wstring, GUID> InstallModernDistribution(
const wsl::windows::common::distribution::ModernDistributionVersion& distribution,
const std::optional<ULONG>& version,

View File

@ -531,6 +531,20 @@ GUID wsl::windows::common::wslutil::CreateV5Uuid(const GUID& namespaceGuid, cons
}
std::wstring wsl::windows::common::wslutil::DownloadFile(std::wstring_view Url, std::wstring Filename)
{
wsl::windows::common::ConsoleProgressBar progressBar;
auto progress = [&](auto current, auto total) {
progressBar.Print(current, total);
return true;
};
auto cleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { progressBar.Clear(); });
return DownloadFileImpl(Url, Filename, progress);
}
std::wstring wsl::windows::common::wslutil::DownloadFileImpl(
std::wstring_view Url, std::wstring Filename, const std::function<void(uint64_t, uint64_t)>& Progress)
{
const auto lastSlash = Url.find_last_of('/');
THROW_HR_IF(E_INVALIDARG, lastSlash == std::wstring::npos);
@ -560,7 +574,6 @@ std::wstring wsl::windows::common::wslutil::DownloadFile(std::wstring_view Url,
const auto asyncResponse = client.GetInputStreamAsync(winrt::Windows::Foundation::Uri(Url));
std::atomic<uint64_t> totalBytes;
wsl::windows::common::ConsoleProgressBar progressBar;
asyncResponse.Progress(
[&](const winrt::Windows::Foundation::IAsyncOperationWithProgress<winrt::Windows::Storage::Streams::IInputStream, winrt::Windows::Web::Http::HttpProgress>&,
const winrt::Windows::Web::Http::HttpProgress& progress) {
@ -575,12 +588,11 @@ std::wstring wsl::windows::common::wslutil::DownloadFile(std::wstring_view Url,
download.Progress([&](const auto& _, uint64_t progress) {
if (totalBytes != 0)
{
progressBar.Print(progress, totalBytes);
Progress(progress, totalBytes);
}
});
download.get();
progressBar.Clear();
deleteFileOnFailure.release();
return file.Path().c_str();
@ -1134,6 +1146,25 @@ std::vector<BYTE> wsl::windows::common::wslutil::HashFile(HANDLE file, DWORD Alg
return fileHash;
}
std::optional<std::tuple<uint32_t, uint32_t, uint32_t>> wsl::windows::common::wslutil::GetInstalledPackageVersion()
{
std::wstring packageVersion;
auto result = wil::ResultFromException([&]() {
auto msiKey = wsl::windows::common::registry::OpenLxssMachineKey(KEY_READ);
packageVersion = wsl::windows::common::registry::ReadString(msiKey.get(), L"Msi", L"Version");
});
if (result == HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND) || result == HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND))
{
return {};
}
THROW_IF_FAILED(result);
return ParseWslPackageVersion(packageVersion);
}
void wsl::windows::common::wslutil::InitializeWil()
{
wil::WilInitialize_CppWinRT();
@ -1248,8 +1279,8 @@ std::pair<wil::unique_hfile, wil::unique_hfile> wsl::windows::common::wslutil::O
bool wsl::windows::common::wslutil::IsVirtualMachinePlatformInstalled()
{
// Note for Windows 11 22H2 and above builds: If hyper-v is installed but VMP platform isn't, HNS and vmcompute are available
// but calls to HNS will fail if vfpext isn't installed.
// Note for Windows 11 22H2 and above builds: If hyper-v is installed but VMP platform isn't, HNS and vmcompute are
// available but calls to HNS will fail if vfpext isn't installed.
return wsl::windows::common::helpers::IsServicePresent(L"HNS") &&
wsl::windows::common::helpers::IsServicePresent(L"vmcompute") &&
(helpers::GetWindowsVersion().BuildNumber < helpers::WindowsBuildNumbers::Nickel ||

View File

@ -84,6 +84,8 @@ GUID CreateV5Uuid(const GUID& namespaceGuid, const std::span<const std::byte> na
std::wstring DownloadFile(std::wstring_view Url, std::wstring Filename);
std::wstring DownloadFileImpl(std::wstring_view Url, std::wstring Filename, const std::function<void(uint64_t, uint64_t)>& Progress);
[[nodiscard]] HANDLE DuplicateHandleFromCallingProcess(_In_ HANDLE handleInTarget);
void EnforceFileLimit(LPCWSTR Folder, size_t limit, const std::function<bool(const std::filesystem::directory_entry&)>& pred);
@ -116,6 +118,8 @@ std::wstring GetSystemErrorString(_In_ HRESULT result);
std::wstring GetDebugShellPipeName(_In_ PSID Sid);
std::optional<std::tuple<uint32_t, uint32_t, uint32_t>> GetInstalledPackageVersion();
std::vector<BYTE> HashFile(HANDLE File, DWORD Algorithm);
void InitializeWil();

View File

@ -16,6 +16,7 @@ Abstract:
#include "wslservice.h"
#include "LSWApi.h"
#include "wslrelay.h"
#include "wslInstall.h"
namespace {
@ -292,6 +293,8 @@ EXTERN_C BOOL STDAPICALLTYPE DllMain(_In_ HINSTANCE Instance, _In_ DWORD Reason,
{
case DLL_PROCESS_ATTACH:
WslTraceLoggingInitialize(LxssTelemetryProvider, false);
wsl::windows::common::wslutil::InitializeWil();
break;
case DLL_PROCESS_DETACH:
@ -301,3 +304,103 @@ EXTERN_C BOOL STDAPICALLTYPE DllMain(_In_ HINSTANCE Instance, _In_ DWORD Reason,
return TRUE;
}
DEFINE_ENUM_FLAG_OPERATORS(WslInstallComponent);
HRESULT WslQueryMissingComponents(enum WslInstallComponent* Components)
try
{
*Components = WslInstallComponentNone;
// Check for Windows features
WI_SetFlagIf(
*Components,
WslInstallComponentWslOC,
!wsl::windows::common::helpers::IsWindows11OrAbove() && !wsl::windows::common::helpers::IsServicePresent(L"lxssmanager"));
WI_SetFlagIf(*Components, WslInstallComponentVMPOC, !wsl::windows::common::wslutil::IsVirtualMachinePlatformInstalled());
// Check if the WSL package is installed, and if the version supports WSLA
auto version = wsl::windows::common::wslutil::GetInstalledPackageVersion();
constexpr auto minimalPackageVersion = wsl::shared::PackageVersion; // TODO: replace with correct version once WSLA is released.
WI_SetFlagIf(*Components, WslInstallComponentWslPackage, !version.has_value() || version < minimalPackageVersion);
// TODO: Check if hardware supports virtualization.
return S_OK;
}
CATCH_RETURN();
// Used for debugging.
static LPCWSTR PackageUrl = nullptr;
HRESULT WslSetPackageUrl(LPCWSTR Url)
{
PackageUrl = Url;
return S_OK;
}
HRESULT WslInstallComponents(enum WslInstallComponent Components, WslInstallCallback ProgressCallback, void* Context)
try
{
// Check for invalid flags.
RETURN_HR_IF_MSG(
E_INVALIDARG,
(Components & ~(WslInstallComponentVMPOC | WslInstallComponentWslOC | WslInstallComponentWslPackage)) != 0,
"Unexpected flag: %i",
Components);
// Fail if the caller is not elevated.
RETURN_HR_IF(
HRESULT_FROM_WIN32(ERROR_ELEVATION_REQUIRED),
Components != 0 && !wsl::windows::common::security::IsTokenElevated(wil::open_current_access_token().get()));
if (WI_IsFlagSet(Components, WslInstallComponentWslPackage))
{
THROW_HR_IF(E_INVALIDARG, PackageUrl == nullptr);
auto callback = [&](uint64_t progress, uint64_t total) {
if (ProgressCallback != nullptr)
{
ProgressCallback(WslInstallComponentWslPackage, progress, total, Context);
}
};
const auto downloadPath = wsl::windows::common::wslutil::DownloadFileImpl(PackageUrl, L"wsl.msi", callback);
auto cleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { std::filesystem::remove(downloadPath); });
auto exitCode = wsl::windows::common::wslutil::UpgradeViaMsi(downloadPath.c_str(), nullptr, nullptr, [](auto, auto) {});
THROW_HR_IF_MSG(
E_FAIL, exitCode != 0, "MSI installation failed. URL: %ls, DownloadPath: %ls, exitCode: %u", downloadPath.c_str(), PackageUrl, exitCode);
}
std::vector<std::wstring> optionalComponents;
if (WI_IsFlagSet(Components, WslInstallComponentWslOC))
{
if (ProgressCallback != nullptr)
{
ProgressCallback(WslInstallComponentWslOC, 0, 1, Context);
}
auto exitCode = WslInstall::InstallOptionalComponent(WslInstall::c_optionalFeatureNameWsl, false);
THROW_HR_IF_MSG(E_FAIL, exitCode != 0 && exitCode != ERROR_SUCCESS_REBOOT_REQUIRED, "Failed to install '%ls', %lu", WslInstall::c_optionalFeatureNameWsl, exitCode);
}
if (WI_IsFlagSet(Components, WslInstallComponentVMPOC))
{
if (ProgressCallback != nullptr)
{
ProgressCallback(WslInstallComponentVMPOC, 0, 1, Context);
}
auto exitCode = WslInstall::InstallOptionalComponent(WslInstall::c_optionalFeatureNameVmp, false);
THROW_HR_IF_MSG(E_FAIL, exitCode != 0 && exitCode != ERROR_SUCCESS_REBOOT_REQUIRED, "Failed to install '%ls', %lu", WslInstall::c_optionalFeatureNameVmp, exitCode);
}
return WI_IsAnyFlagSet(Components, WslInstallComponentWslOC | WslInstallComponentVMPOC)
? HRESULT_FROM_WIN32(ERROR_SUCCESS_REBOOT_REQUIRED)
: S_OK;
}
CATCH_RETURN();

View File

@ -179,6 +179,23 @@ HRESULT WslMapPort(LSWVirtualMachineHandle VirtualMachine, const struct PortMapp
HRESULT WslUnmapPort(LSWVirtualMachineHandle VirtualMachine, const struct PortMappingSettings* Settings);
enum WslInstallComponent
{
WslInstallComponentNone = 0,
WslInstallComponentVMPOC = 1,
WslInstallComponentWslOC = 2,
WslInstallComponentWslPackage = 4,
};
HRESULT WslQueryMissingComponents(enum WslInstallComponent* Components);
typedef void (*WslInstallCallback)(enum WslInstallComponent, uint64_t, uint64_t, void*);
HRESULT WslInstallComponents(enum WslInstallComponent Components, WslInstallCallback ProgressCallback, void* Context);
// Used for testing until the package is published.
HRESULT WslSetPackageUrl(LPCWSTR Url);
#ifdef __cplusplus
}
#endif

View File

@ -13,4 +13,7 @@ EXPORTS
WslLaunchInteractiveTerminal
WslLaunchDebugShell
WslMapPort
WslUnmapPort
WslUnmapPort
WslQueryMissingComponents
WslInstallComponents
WslSetPackageUrl

View File

@ -18,12 +18,15 @@ Abstract:
#include "Common.h"
#include "registry.hpp"
#include "PluginTests.h"
#include "lswapi.h"
using namespace wsl::windows::common::registry;
extern std::wstring g_dumpFolder;
static std::wstring g_pipelineBuildId;
DEFINE_ENUM_FLAG_OPERATORS(WslInstallComponent);
class InstallerTests
{
std::wstring m_msixPackagePath;
@ -1037,4 +1040,76 @@ class InstallerTests
SHChangeNotify(SHCNE_ASSOCCHANGED, SHCNF_IDLIST, nullptr, nullptr);
VerifyWslSettingsProtocolAssociationExistsWithRetry();
}
TEST_METHOD(WSLAInstall)
{
auto expectComponents = [](WslInstallComponent expected) {
WslInstallComponent components{};
VERIFY_SUCCEEDED(WslQueryMissingComponents(&components));
VERIFY_ARE_EQUAL(components, expected);
};
VERIFY_ARE_EQUAL(WslInstallComponents(WslInstallComponentWslPackage, nullptr, nullptr), E_INVALIDARG);
expectComponents(WslInstallComponentNone);
UninstallMsi();
expectComponents(WslInstallComponentWslPackage);
{
UniqueWebServer fileServer(L"http://127.0.0.1:12346/", std::filesystem::path(m_msiPath));
VERIFY_SUCCEEDED(WslSetPackageUrl(L"http://127.0.0.1:12346/"));
WslInstallComponent progressedComponents{};
auto callback = [](WslInstallComponent Component, uint64_t progress, uint64_t total, void* Context) {
*reinterpret_cast<WslInstallComponent*>(Context) |= Component;
};
VERIFY_SUCCEEDED(WslInstallComponents(WslInstallComponentWslPackage, callback, &progressedComponents));
VERIFY_ARE_EQUAL(progressedComponents, WslInstallComponentWslPackage);
ValidateInstalledVersion(WIDEN(WSL_PACKAGE_VERSION));
expectComponents(WslInstallComponentNone);
progressedComponents = WslInstallComponentNone;
VERIFY_ARE_EQUAL(WslInstallComponents(WslInstallComponentVMPOC, callback, &progressedComponents), HRESULT_FROM_WIN32(ERROR_SUCCESS_REBOOT_REQUIRED));
VERIFY_ARE_EQUAL(progressedComponents, WslInstallComponentVMPOC);
progressedComponents = WslInstallComponentNone;
VERIFY_ARE_EQUAL(WslInstallComponents(WslInstallComponentWslOC, callback, &progressedComponents), HRESULT_FROM_WIN32(ERROR_SUCCESS_REBOOT_REQUIRED));
VERIFY_ARE_EQUAL(progressedComponents, WslInstallComponentWslOC);
}
{
VERIFY_SUCCEEDED(WslSetPackageUrl(L"http://127.0.0.1:12346/"));
VERIFY_ARE_EQUAL(WslInstallComponents(WslInstallComponentWslPackage, nullptr, nullptr), WININET_E_CANNOT_CONNECT);
}
}
// This test case requires a machine without the OC's enabled.
TEST_METHOD(WSLAInstallManual)
{
WslInstallComponent components{};
VERIFY_SUCCEEDED(WslQueryMissingComponents(&components));
if (!WI_IsAnyFlagSet(components, WslInstallComponentWslOC | WslInstallComponentVMPOC))
{
LogSkipped("OC are installed, skipping test. Flags: %i", components);
return;
}
auto expectedComponents = WslInstallComponentVMPOC;
WI_SetFlagIf(expectedComponents, WslInstallComponentWslOC, !wsl::windows::common::helpers::IsWindows11OrAbove());
VERIFY_ARE_EQUAL(components, expectedComponents);
WslInstallComponent progressedComponents{};
auto callback = [](WslInstallComponent Component, uint64_t progress, uint64_t total, void* Context) {
*reinterpret_cast<WslInstallComponent*>(Context) |= Component;
};
VERIFY_ARE_EQUAL(WslInstallComponents(components, callback, &progressedComponents), HRESULT_FROM_WIN32(ERROR_SUCCESS_REBOOT_REQUIRED));
VERIFY_ARE_EQUAL(progressedComponents, expectedComponents);
}
};