Implement WSLA API's to query installed components, and install them when needed (#13363)

* Save state

* Save state

* Add tests

* Finalize tests

* Remove test logic
This commit is contained in:
Blue 2025-08-08 11:42:13 -07:00 committed by Blue
parent 5438e046f9
commit 3451e8213d
8 changed files with 261 additions and 11 deletions

View File

@ -23,8 +23,6 @@ Abstract:
extern HINSTANCE g_dllInstance; extern HINSTANCE g_dllInstance;
constexpr LPCWSTR c_optionalFeatureInstallStatus = L"InstallStatus"; constexpr LPCWSTR c_optionalFeatureInstallStatus = L"InstallStatus";
constexpr LPCWSTR c_optionalFeatureNameVmp = L"VirtualMachinePlatform";
constexpr LPCWSTR c_optionalFeatureNameWsl = L"Microsoft-Windows-Subsystem-Linux";
using wsl::shared::Localization; using wsl::shared::Localization;
using namespace wsl::windows::common::distribution; using namespace wsl::windows::common::distribution;
@ -239,18 +237,32 @@ std::pair<bool, std::vector<std::wstring>> WslInstall::CheckForMissingOptionalCo
return {rebootRequired, std::move(missingComponents)}; return {rebootRequired, std::move(missingComponents)};
} }
void WslInstall::InstallOptionalComponents(const std::vector<std::wstring>& components) DWORD WslInstall::InstallOptionalComponent(LPCWSTR component, bool consoleOutput)
{ {
std::wstring systemDirectory; std::wstring systemDirectory;
THROW_IF_FAILED(wil::GetSystemDirectoryW(systemDirectory)); THROW_IF_FAILED(wil::GetSystemDirectoryW(systemDirectory));
const auto dismPath = std::filesystem::path(std::move(systemDirectory)) / L"dism.exe"; const auto dismPath = std::filesystem::path(std::move(systemDirectory)) / L"dism.exe";
auto commandLine = std::format(L"{} /Online /NoRestart /enable-feature /featurename:{}", dismPath.native(), component);
wsl::windows::common::SubProcess process(nullptr, commandLine.c_str());
if (!consoleOutput)
{
process.SetFlags(CREATE_NEW_CONSOLE);
process.SetShowWindow(SW_HIDE);
}
return process.Run();
}
void WslInstall::InstallOptionalComponents(const std::vector<std::wstring>& components)
{
for (const auto& component : components) for (const auto& component : components)
{ {
wsl::windows::common::wslutil::PrintMessage(Localization::MessageInstallingWindowsComponent(component)); wsl::windows::common::wslutil::PrintMessage(Localization::MessageInstallingWindowsComponent(component));
auto commandLine = std::format(L"{} /Online /NoRestart /enable-feature /featurename:{}", dismPath.wstring(), component); const auto exitCode = InstallOptionalComponent(component.c_str(), true);
const auto exitCode = wsl::windows::common::helpers::RunProcess(commandLine);
if (exitCode != 0 && exitCode != ERROR_SUCCESS_REBOOT_REQUIRED) if (exitCode != 0 && exitCode != ERROR_SUCCESS_REBOOT_REQUIRED)
{ {
THROW_HR_WITH_USER_ERROR(WSL_E_INSTALL_COMPONENT_FAILED, Localization::MessageOptionalComponentInstallFailed(component, exitCode)); THROW_HR_WITH_USER_ERROR(WSL_E_INSTALL_COMPONENT_FAILED, Localization::MessageOptionalComponentInstallFailed(component, exitCode));

View File

@ -19,6 +19,9 @@ Abstract:
class WslInstall class WslInstall
{ {
public: public:
static inline LPCWSTR c_optionalFeatureNameVmp = L"VirtualMachinePlatform";
static inline LPCWSTR c_optionalFeatureNameWsl = L"Microsoft-Windows-Subsystem-Linux";
struct InstallResult struct InstallResult
{ {
std::wstring Name; std::wstring Name;
@ -44,6 +47,8 @@ public:
static void InstallOptionalComponents(const std::vector<std::wstring>& components); static void InstallOptionalComponents(const std::vector<std::wstring>& components);
static DWORD InstallOptionalComponent(LPCWSTR component, bool consoleOutput);
static std::pair<std::wstring, GUID> InstallModernDistribution( static std::pair<std::wstring, GUID> InstallModernDistribution(
const wsl::windows::common::distribution::ModernDistributionVersion& distribution, const wsl::windows::common::distribution::ModernDistributionVersion& distribution,
const std::optional<ULONG>& version, const std::optional<ULONG>& version,

View File

@ -531,6 +531,20 @@ GUID wsl::windows::common::wslutil::CreateV5Uuid(const GUID& namespaceGuid, cons
} }
std::wstring wsl::windows::common::wslutil::DownloadFile(std::wstring_view Url, std::wstring Filename) std::wstring wsl::windows::common::wslutil::DownloadFile(std::wstring_view Url, std::wstring Filename)
{
wsl::windows::common::ConsoleProgressBar progressBar;
auto progress = [&](auto current, auto total) {
progressBar.Print(current, total);
return true;
};
auto cleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { progressBar.Clear(); });
return DownloadFileImpl(Url, Filename, progress);
}
std::wstring wsl::windows::common::wslutil::DownloadFileImpl(
std::wstring_view Url, std::wstring Filename, const std::function<void(uint64_t, uint64_t)>& Progress)
{ {
const auto lastSlash = Url.find_last_of('/'); const auto lastSlash = Url.find_last_of('/');
THROW_HR_IF(E_INVALIDARG, lastSlash == std::wstring::npos); THROW_HR_IF(E_INVALIDARG, lastSlash == std::wstring::npos);
@ -560,7 +574,6 @@ std::wstring wsl::windows::common::wslutil::DownloadFile(std::wstring_view Url,
const auto asyncResponse = client.GetInputStreamAsync(winrt::Windows::Foundation::Uri(Url)); const auto asyncResponse = client.GetInputStreamAsync(winrt::Windows::Foundation::Uri(Url));
std::atomic<uint64_t> totalBytes; std::atomic<uint64_t> totalBytes;
wsl::windows::common::ConsoleProgressBar progressBar;
asyncResponse.Progress( asyncResponse.Progress(
[&](const winrt::Windows::Foundation::IAsyncOperationWithProgress<winrt::Windows::Storage::Streams::IInputStream, winrt::Windows::Web::Http::HttpProgress>&, [&](const winrt::Windows::Foundation::IAsyncOperationWithProgress<winrt::Windows::Storage::Streams::IInputStream, winrt::Windows::Web::Http::HttpProgress>&,
const winrt::Windows::Web::Http::HttpProgress& progress) { const winrt::Windows::Web::Http::HttpProgress& progress) {
@ -575,12 +588,11 @@ std::wstring wsl::windows::common::wslutil::DownloadFile(std::wstring_view Url,
download.Progress([&](const auto& _, uint64_t progress) { download.Progress([&](const auto& _, uint64_t progress) {
if (totalBytes != 0) if (totalBytes != 0)
{ {
progressBar.Print(progress, totalBytes); Progress(progress, totalBytes);
} }
}); });
download.get(); download.get();
progressBar.Clear();
deleteFileOnFailure.release(); deleteFileOnFailure.release();
return file.Path().c_str(); return file.Path().c_str();
@ -1134,6 +1146,25 @@ std::vector<BYTE> wsl::windows::common::wslutil::HashFile(HANDLE file, DWORD Alg
return fileHash; return fileHash;
} }
std::optional<std::tuple<uint32_t, uint32_t, uint32_t>> wsl::windows::common::wslutil::GetInstalledPackageVersion()
{
std::wstring packageVersion;
auto result = wil::ResultFromException([&]() {
auto msiKey = wsl::windows::common::registry::OpenLxssMachineKey(KEY_READ);
packageVersion = wsl::windows::common::registry::ReadString(msiKey.get(), L"Msi", L"Version");
});
if (result == HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND) || result == HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND))
{
return {};
}
THROW_IF_FAILED(result);
return ParseWslPackageVersion(packageVersion);
}
void wsl::windows::common::wslutil::InitializeWil() void wsl::windows::common::wslutil::InitializeWil()
{ {
wil::WilInitialize_CppWinRT(); wil::WilInitialize_CppWinRT();
@ -1248,8 +1279,8 @@ std::pair<wil::unique_hfile, wil::unique_hfile> wsl::windows::common::wslutil::O
bool wsl::windows::common::wslutil::IsVirtualMachinePlatformInstalled() bool wsl::windows::common::wslutil::IsVirtualMachinePlatformInstalled()
{ {
// Note for Windows 11 22H2 and above builds: If hyper-v is installed but VMP platform isn't, HNS and vmcompute are available // Note for Windows 11 22H2 and above builds: If hyper-v is installed but VMP platform isn't, HNS and vmcompute are
// but calls to HNS will fail if vfpext isn't installed. // available but calls to HNS will fail if vfpext isn't installed.
return wsl::windows::common::helpers::IsServicePresent(L"HNS") && return wsl::windows::common::helpers::IsServicePresent(L"HNS") &&
wsl::windows::common::helpers::IsServicePresent(L"vmcompute") && wsl::windows::common::helpers::IsServicePresent(L"vmcompute") &&
(helpers::GetWindowsVersion().BuildNumber < helpers::WindowsBuildNumbers::Nickel || (helpers::GetWindowsVersion().BuildNumber < helpers::WindowsBuildNumbers::Nickel ||

View File

@ -84,6 +84,8 @@ GUID CreateV5Uuid(const GUID& namespaceGuid, const std::span<const std::byte> na
std::wstring DownloadFile(std::wstring_view Url, std::wstring Filename); std::wstring DownloadFile(std::wstring_view Url, std::wstring Filename);
std::wstring DownloadFileImpl(std::wstring_view Url, std::wstring Filename, const std::function<void(uint64_t, uint64_t)>& Progress);
[[nodiscard]] HANDLE DuplicateHandleFromCallingProcess(_In_ HANDLE handleInTarget); [[nodiscard]] HANDLE DuplicateHandleFromCallingProcess(_In_ HANDLE handleInTarget);
void EnforceFileLimit(LPCWSTR Folder, size_t limit, const std::function<bool(const std::filesystem::directory_entry&)>& pred); void EnforceFileLimit(LPCWSTR Folder, size_t limit, const std::function<bool(const std::filesystem::directory_entry&)>& pred);
@ -116,6 +118,8 @@ std::wstring GetSystemErrorString(_In_ HRESULT result);
std::wstring GetDebugShellPipeName(_In_ PSID Sid); std::wstring GetDebugShellPipeName(_In_ PSID Sid);
std::optional<std::tuple<uint32_t, uint32_t, uint32_t>> GetInstalledPackageVersion();
std::vector<BYTE> HashFile(HANDLE File, DWORD Algorithm); std::vector<BYTE> HashFile(HANDLE File, DWORD Algorithm);
void InitializeWil(); void InitializeWil();

View File

@ -16,6 +16,7 @@ Abstract:
#include "wslservice.h" #include "wslservice.h"
#include "LSWApi.h" #include "LSWApi.h"
#include "wslrelay.h" #include "wslrelay.h"
#include "wslInstall.h"
namespace { namespace {
@ -292,6 +293,8 @@ EXTERN_C BOOL STDAPICALLTYPE DllMain(_In_ HINSTANCE Instance, _In_ DWORD Reason,
{ {
case DLL_PROCESS_ATTACH: case DLL_PROCESS_ATTACH:
WslTraceLoggingInitialize(LxssTelemetryProvider, false); WslTraceLoggingInitialize(LxssTelemetryProvider, false);
wsl::windows::common::wslutil::InitializeWil();
break; break;
case DLL_PROCESS_DETACH: case DLL_PROCESS_DETACH:
@ -301,3 +304,103 @@ EXTERN_C BOOL STDAPICALLTYPE DllMain(_In_ HINSTANCE Instance, _In_ DWORD Reason,
return TRUE; return TRUE;
} }
DEFINE_ENUM_FLAG_OPERATORS(WslInstallComponent);
HRESULT WslQueryMissingComponents(enum WslInstallComponent* Components)
try
{
*Components = WslInstallComponentNone;
// Check for Windows features
WI_SetFlagIf(
*Components,
WslInstallComponentWslOC,
!wsl::windows::common::helpers::IsWindows11OrAbove() && !wsl::windows::common::helpers::IsServicePresent(L"lxssmanager"));
WI_SetFlagIf(*Components, WslInstallComponentVMPOC, !wsl::windows::common::wslutil::IsVirtualMachinePlatformInstalled());
// Check if the WSL package is installed, and if the version supports WSLA
auto version = wsl::windows::common::wslutil::GetInstalledPackageVersion();
constexpr auto minimalPackageVersion = wsl::shared::PackageVersion; // TODO: replace with correct version once WSLA is released.
WI_SetFlagIf(*Components, WslInstallComponentWslPackage, !version.has_value() || version < minimalPackageVersion);
// TODO: Check if hardware supports virtualization.
return S_OK;
}
CATCH_RETURN();
// Used for debugging.
static LPCWSTR PackageUrl = nullptr;
HRESULT WslSetPackageUrl(LPCWSTR Url)
{
PackageUrl = Url;
return S_OK;
}
HRESULT WslInstallComponents(enum WslInstallComponent Components, WslInstallCallback ProgressCallback, void* Context)
try
{
// Check for invalid flags.
RETURN_HR_IF_MSG(
E_INVALIDARG,
(Components & ~(WslInstallComponentVMPOC | WslInstallComponentWslOC | WslInstallComponentWslPackage)) != 0,
"Unexpected flag: %i",
Components);
// Fail if the caller is not elevated.
RETURN_HR_IF(
HRESULT_FROM_WIN32(ERROR_ELEVATION_REQUIRED),
Components != 0 && !wsl::windows::common::security::IsTokenElevated(wil::open_current_access_token().get()));
if (WI_IsFlagSet(Components, WslInstallComponentWslPackage))
{
THROW_HR_IF(E_INVALIDARG, PackageUrl == nullptr);
auto callback = [&](uint64_t progress, uint64_t total) {
if (ProgressCallback != nullptr)
{
ProgressCallback(WslInstallComponentWslPackage, progress, total, Context);
}
};
const auto downloadPath = wsl::windows::common::wslutil::DownloadFileImpl(PackageUrl, L"wsl.msi", callback);
auto cleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { std::filesystem::remove(downloadPath); });
auto exitCode = wsl::windows::common::wslutil::UpgradeViaMsi(downloadPath.c_str(), nullptr, nullptr, [](auto, auto) {});
THROW_HR_IF_MSG(
E_FAIL, exitCode != 0, "MSI installation failed. URL: %ls, DownloadPath: %ls, exitCode: %u", downloadPath.c_str(), PackageUrl, exitCode);
}
std::vector<std::wstring> optionalComponents;
if (WI_IsFlagSet(Components, WslInstallComponentWslOC))
{
if (ProgressCallback != nullptr)
{
ProgressCallback(WslInstallComponentWslOC, 0, 1, Context);
}
auto exitCode = WslInstall::InstallOptionalComponent(WslInstall::c_optionalFeatureNameWsl, false);
THROW_HR_IF_MSG(E_FAIL, exitCode != 0 && exitCode != ERROR_SUCCESS_REBOOT_REQUIRED, "Failed to install '%ls', %lu", WslInstall::c_optionalFeatureNameWsl, exitCode);
}
if (WI_IsFlagSet(Components, WslInstallComponentVMPOC))
{
if (ProgressCallback != nullptr)
{
ProgressCallback(WslInstallComponentVMPOC, 0, 1, Context);
}
auto exitCode = WslInstall::InstallOptionalComponent(WslInstall::c_optionalFeatureNameVmp, false);
THROW_HR_IF_MSG(E_FAIL, exitCode != 0 && exitCode != ERROR_SUCCESS_REBOOT_REQUIRED, "Failed to install '%ls', %lu", WslInstall::c_optionalFeatureNameVmp, exitCode);
}
return WI_IsAnyFlagSet(Components, WslInstallComponentWslOC | WslInstallComponentVMPOC)
? HRESULT_FROM_WIN32(ERROR_SUCCESS_REBOOT_REQUIRED)
: S_OK;
}
CATCH_RETURN();

View File

@ -179,6 +179,23 @@ HRESULT WslMapPort(LSWVirtualMachineHandle VirtualMachine, const struct PortMapp
HRESULT WslUnmapPort(LSWVirtualMachineHandle VirtualMachine, const struct PortMappingSettings* Settings); HRESULT WslUnmapPort(LSWVirtualMachineHandle VirtualMachine, const struct PortMappingSettings* Settings);
enum WslInstallComponent
{
WslInstallComponentNone = 0,
WslInstallComponentVMPOC = 1,
WslInstallComponentWslOC = 2,
WslInstallComponentWslPackage = 4,
};
HRESULT WslQueryMissingComponents(enum WslInstallComponent* Components);
typedef void (*WslInstallCallback)(enum WslInstallComponent, uint64_t, uint64_t, void*);
HRESULT WslInstallComponents(enum WslInstallComponent Components, WslInstallCallback ProgressCallback, void* Context);
// Used for testing until the package is published.
HRESULT WslSetPackageUrl(LPCWSTR Url);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -13,4 +13,7 @@ EXPORTS
WslLaunchInteractiveTerminal WslLaunchInteractiveTerminal
WslLaunchDebugShell WslLaunchDebugShell
WslMapPort WslMapPort
WslUnmapPort WslUnmapPort
WslQueryMissingComponents
WslInstallComponents
WslSetPackageUrl

View File

@ -18,12 +18,15 @@ Abstract:
#include "Common.h" #include "Common.h"
#include "registry.hpp" #include "registry.hpp"
#include "PluginTests.h" #include "PluginTests.h"
#include "lswapi.h"
using namespace wsl::windows::common::registry; using namespace wsl::windows::common::registry;
extern std::wstring g_dumpFolder; extern std::wstring g_dumpFolder;
static std::wstring g_pipelineBuildId; static std::wstring g_pipelineBuildId;
DEFINE_ENUM_FLAG_OPERATORS(WslInstallComponent);
class InstallerTests class InstallerTests
{ {
std::wstring m_msixPackagePath; std::wstring m_msixPackagePath;
@ -1037,4 +1040,76 @@ class InstallerTests
SHChangeNotify(SHCNE_ASSOCCHANGED, SHCNF_IDLIST, nullptr, nullptr); SHChangeNotify(SHCNE_ASSOCCHANGED, SHCNF_IDLIST, nullptr, nullptr);
VerifyWslSettingsProtocolAssociationExistsWithRetry(); VerifyWslSettingsProtocolAssociationExistsWithRetry();
} }
TEST_METHOD(WSLAInstall)
{
auto expectComponents = [](WslInstallComponent expected) {
WslInstallComponent components{};
VERIFY_SUCCEEDED(WslQueryMissingComponents(&components));
VERIFY_ARE_EQUAL(components, expected);
};
VERIFY_ARE_EQUAL(WslInstallComponents(WslInstallComponentWslPackage, nullptr, nullptr), E_INVALIDARG);
expectComponents(WslInstallComponentNone);
UninstallMsi();
expectComponents(WslInstallComponentWslPackage);
{
UniqueWebServer fileServer(L"http://127.0.0.1:12346/", std::filesystem::path(m_msiPath));
VERIFY_SUCCEEDED(WslSetPackageUrl(L"http://127.0.0.1:12346/"));
WslInstallComponent progressedComponents{};
auto callback = [](WslInstallComponent Component, uint64_t progress, uint64_t total, void* Context) {
*reinterpret_cast<WslInstallComponent*>(Context) |= Component;
};
VERIFY_SUCCEEDED(WslInstallComponents(WslInstallComponentWslPackage, callback, &progressedComponents));
VERIFY_ARE_EQUAL(progressedComponents, WslInstallComponentWslPackage);
ValidateInstalledVersion(WIDEN(WSL_PACKAGE_VERSION));
expectComponents(WslInstallComponentNone);
progressedComponents = WslInstallComponentNone;
VERIFY_ARE_EQUAL(WslInstallComponents(WslInstallComponentVMPOC, callback, &progressedComponents), HRESULT_FROM_WIN32(ERROR_SUCCESS_REBOOT_REQUIRED));
VERIFY_ARE_EQUAL(progressedComponents, WslInstallComponentVMPOC);
progressedComponents = WslInstallComponentNone;
VERIFY_ARE_EQUAL(WslInstallComponents(WslInstallComponentWslOC, callback, &progressedComponents), HRESULT_FROM_WIN32(ERROR_SUCCESS_REBOOT_REQUIRED));
VERIFY_ARE_EQUAL(progressedComponents, WslInstallComponentWslOC);
}
{
VERIFY_SUCCEEDED(WslSetPackageUrl(L"http://127.0.0.1:12346/"));
VERIFY_ARE_EQUAL(WslInstallComponents(WslInstallComponentWslPackage, nullptr, nullptr), WININET_E_CANNOT_CONNECT);
}
}
// This test case requires a machine without the OC's enabled.
TEST_METHOD(WSLAInstallManual)
{
WslInstallComponent components{};
VERIFY_SUCCEEDED(WslQueryMissingComponents(&components));
if (!WI_IsAnyFlagSet(components, WslInstallComponentWslOC | WslInstallComponentVMPOC))
{
LogSkipped("OC are installed, skipping test. Flags: %i", components);
return;
}
auto expectedComponents = WslInstallComponentVMPOC;
WI_SetFlagIf(expectedComponents, WslInstallComponentWslOC, !wsl::windows::common::helpers::IsWindows11OrAbove());
VERIFY_ARE_EQUAL(components, expectedComponents);
WslInstallComponent progressedComponents{};
auto callback = [](WslInstallComponent Component, uint64_t progress, uint64_t total, void* Context) {
*reinterpret_cast<WslInstallComponent*>(Context) |= Component;
};
VERIFY_ARE_EQUAL(WslInstallComponents(components, callback, &progressedComponents), HRESULT_FROM_WIN32(ERROR_SUCCESS_REBOOT_REQUIRED));
VERIFY_ARE_EQUAL(progressedComponents, expectedComponents);
}
}; };