diff --git a/CMakeLists.txt b/CMakeLists.txt index d69c521..9de9576 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -379,10 +379,10 @@ if (DEFINED WSL_DEV_BINARY_PATH) # Development shortcut to make the package smal WSL_KERNEL_MODULES_PATH="${WSL_DEV_BINARY_PATH}/modules.vhd" WSL_DEV_INSTALL_PATH="${WSL_DEV_BINARY_PATH}" WSL_GPU_LIB_PATH="${WSL_DEV_BINARY_PATH}/lib") -endif() -if (NOT OFFICIAL_BUILD AND ${TARGET_PLATFORM} STREQUAL "x64") - add_compile_definitions(WSLA_TEST_DISTRO_PATH="${WSLA_TEST_DISTRO_SOURCE_DIR}/wslatestrootfs.vhd") + if (NOT OFFICIAL_BUILD AND ${TARGET_PLATFORM} STREQUAL "x64") + add_compile_definitions(WSLA_TEST_DISTRO_PATH="${WSLA_TEST_DISTRO_SOURCE_DIR}/wslatestrootfs.vhd") + endif() endif() # Common include paths diff --git a/msipackage/package.wix.in b/msipackage/package.wix.in index 8314068..1ceca4e 100644 --- a/msipackage/package.wix.in +++ b/msipackage/package.wix.in @@ -344,6 +344,27 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/packages.config b/packages.config index 2b14616..36372a9 100644 --- a/packages.config +++ b/packages.config @@ -18,7 +18,7 @@ - + diff --git a/src/windows/common/GuestDeviceManager.cpp b/src/windows/common/GuestDeviceManager.cpp index 475e7fc..6b23a3b 100644 --- a/src/windows/common/GuestDeviceManager.cpp +++ b/src/windows/common/GuestDeviceManager.cpp @@ -70,7 +70,7 @@ void GuestDeviceManager::AddSharedMemoryDevice(_In_ const GUID& ImplementationCl static constexpr auto VIRTIO_FS_FLAGS_SHMEM_SIZE_SHIFT = 16; UINT32 flags = (SizeMb << VIRTIO_FS_FLAGS_SHMEM_SIZE_SHIFT); WI_SetFlag(flags, VIRTIO_FS_FLAGS_TYPE_SECTIONS); - (void)AddHdvShareWithOptions(VIRTIO_VIRTIOFS_DEVICE_ID, ImplementationClsid, Tag, {}, objectLifetime.Path.c_str(), flags, UserToken); + (void)AddHdvShareWithOptions(VIRTIO_FS_DEVICE_ID, ImplementationClsid, Tag, {}, objectLifetime.Path.c_str(), flags, UserToken); m_objectDirectories.emplace_back(std::move(objectLifetime)); } diff --git a/src/windows/common/GuestDeviceManager.h b/src/windows/common/GuestDeviceManager.h index f1afa99..53de3c6 100644 --- a/src/windows/common/GuestDeviceManager.h +++ b/src/windows/common/GuestDeviceManager.h @@ -8,14 +8,13 @@ #define VIRTIO_FS_FLAGS_TYPE_FILES 0x8000 #define VIRTIO_FS_FLAGS_TYPE_SECTIONS 0x4000 -// {872270E1-A899-4AF6-B454-7193634435AD} -DEFINE_GUID(VIRTIO_VIRTIOFS_DEVICE_ID, 0x872270E1, 0xA899, 0x4AF6, 0xB4, 0x54, 0x71, 0x93, 0x63, 0x44, 0x35, 0xAD); - -// {ABB755FC-1B86-4255-83E2-E5787ABCF6C2} -DEFINE_GUID(VIRTIO_PMEM_CLASS_ID, 0xABB755FC, 0x1B86, 0x4255, 0x83, 0xe2, 0xe5, 0x78, 0x7a, 0xbc, 0xf6, 0xc2); - inline const std::wstring c_defaultDeviceTag = L"default"; +// These device types are implemented by the external wsldevicehost vdev. +DEFINE_GUID(VIRTIO_FS_DEVICE_ID, 0x872270E1, 0xA899, 0x4AF6, 0xB4, 0x54, 0x71, 0x93, 0x63, 0x44, 0x35, 0xAD); // {872270E1-A899-4AF6-B454-7193634435AD} +DEFINE_GUID(VIRTIO_NET_DEVICE_ID, 0xF07010D0, 0x0EA9, 0x447F, 0x88, 0xEF, 0xBD, 0x95, 0x2A, 0x4D, 0x2F, 0x14); // {F07010D0-0EA9-447F-88EF-BD952A4D2F14} +DEFINE_GUID(VIRTIO_PMEM_DEVICE_ID, 0xEDBB24BB, 0x5E19, 0x40F4, 0x8A, 0x0F, 0x82, 0x24, 0x31, 0x30, 0x64, 0xFD); // {EDBB24BB-5E19-40F4-8A0F-8224313064FD} + // // Provides synchronized access to guest device operations. // diff --git a/src/windows/common/VirtioNetworking.cpp b/src/windows/common/VirtioNetworking.cpp index cb69244..aeb3722 100644 --- a/src/windows/common/VirtioNetworking.cpp +++ b/src/windows/common/VirtioNetworking.cpp @@ -14,16 +14,16 @@ using wsl::core::VirtioNetworking; static constexpr auto c_loopbackDeviceName = TEXT(LX_INIT_LOOPBACK_DEVICE_NAME); VirtioNetworking::VirtioNetworking( - GnsChannel&& gnsChannel, bool enableLocalhostRelay, std::shared_ptr guestDeviceManager, wil::shared_handle userToken) : + GnsChannel&& gnsChannel, bool enableLocalhostRelay, std::shared_ptr guestDeviceManager, GUID classId, wil::shared_handle userToken) : m_guestDeviceManager(std::move(guestDeviceManager)), m_userToken(std::move(userToken)), m_gnsChannel(std::move(gnsChannel)), - m_enableLocalhostRelay(enableLocalhostRelay) + m_enableLocalhostRelay(enableLocalhostRelay), + m_virtioNetworkClsid(classId) { } void VirtioNetworking::Initialize() -try { m_networkSettings = GetHostEndpointSettings(); @@ -72,7 +72,7 @@ try // Add virtio net adapter to guest m_adapterId = m_guestDeviceManager->AddGuestDevice( - c_virtioNetworkDeviceId, c_virtioNetworkClsid, L"eth0", nullptr, device_options.str().c_str(), 0, m_userToken.get()); + VIRTIO_NET_DEVICE_ID, m_virtioNetworkClsid, L"eth0", nullptr, device_options.str().c_str(), 0, m_userToken.get()); hns::HNSEndpoint endpointProperties; endpointProperties.ID = m_adapterId; @@ -114,13 +114,12 @@ try THROW_IF_WIN32_ERROR(NotifyNetworkConnectivityHintChange(&VirtioNetworking::OnNetworkConnectivityChange, this, true, &m_networkNotifyHandle)); } -CATCH_LOG() void VirtioNetworking::SetupLoopbackDevice() { m_localhostAdapterId = m_guestDeviceManager->AddGuestDevice( - c_virtioNetworkDeviceId, - c_virtioNetworkClsid, + VIRTIO_NET_DEVICE_ID, + m_virtioNetworkClsid, c_loopbackDeviceName, nullptr, L"client_ip=127.0.0.1;client_mac=00:11:22:33:44:55", @@ -188,13 +187,13 @@ HRESULT VirtioNetworking::HandlePortNotification(const SOCKADDR_INET& addr, int localAddr.Ipv6.sin6_port = addr.Ipv6.sin6_port; } } - result = ModifyOpenPorts(c_virtioNetworkClsid, c_loopbackDeviceName, localAddr, protocol, allocate); + result = ModifyOpenPorts(c_loopbackDeviceName, localAddr, protocol, allocate); LOG_HR_IF_MSG(E_FAIL, result != S_OK, "Failure adding localhost relay port %d", localAddr.Ipv4.sin_port); } if (!loopback) { - const int localResult = ModifyOpenPorts(c_virtioNetworkClsid, L"eth0", addr, protocol, allocate); + const int localResult = ModifyOpenPorts(L"eth0", addr, protocol, allocate); LOG_HR_IF_MSG(E_FAIL, localResult != S_OK, "Failure adding relay port %d", addr.Ipv4.sin_port); if (result == 0) { @@ -205,7 +204,7 @@ HRESULT VirtioNetworking::HandlePortNotification(const SOCKADDR_INET& addr, int return result; } -int VirtioNetworking::ModifyOpenPorts(_In_ const GUID& clsid, _In_ PCWSTR tag, _In_ const SOCKADDR_INET& addr, _In_ int protocol, _In_ bool isOpen) const +int VirtioNetworking::ModifyOpenPorts(_In_ PCWSTR tag, _In_ const SOCKADDR_INET& addr, _In_ int protocol, _In_ bool isOpen) const { if (protocol != IPPROTO_TCP && protocol != IPPROTO_UDP) { @@ -221,7 +220,7 @@ int VirtioNetworking::ModifyOpenPorts(_In_ const GUID& clsid, _In_ PCWSTR tag, _ } auto lock = m_lock.lock_exclusive(); - const auto server = m_guestDeviceManager->GetRemoteFileSystem(clsid, c_defaultDeviceTag); + const auto server = m_guestDeviceManager->GetRemoteFileSystem(m_virtioNetworkClsid, c_defaultDeviceTag); if (server) { std::wstring portString = std::format(L"tag={};port_number={}", tag, addr.Ipv4.sin_port); diff --git a/src/windows/common/VirtioNetworking.h b/src/windows/common/VirtioNetworking.h index 629afb2..2c78515 100644 --- a/src/windows/common/VirtioNetworking.h +++ b/src/windows/common/VirtioNetworking.h @@ -13,7 +13,7 @@ namespace wsl::core { class VirtioNetworking : public INetworkingEngine { public: - VirtioNetworking(GnsChannel&& gnsChannel, bool enableLocalhostRelay, std::shared_ptr guestDeviceManager, wil::shared_handle userToken); + VirtioNetworking(GnsChannel&& gnsChannel, bool enableLocalhostRelay, std::shared_ptr guestDeviceManager, GUID classId, wil::shared_handle userToken); ~VirtioNetworking() = default; // Note: This class cannot be moved because m_networkNotifyHandle captures a 'this' pointer. @@ -35,7 +35,7 @@ private: static std::optional FindVirtioInterfaceLuid(const SOCKADDR_INET& virtioAddress, const NL_NETWORK_CONNECTIVITY_HINT& currentConnectivityHint); HRESULT HandlePortNotification(const SOCKADDR_INET& addr, int protocol, bool allocate) const noexcept; - int ModifyOpenPorts(_In_ const GUID& clsid, _In_ PCWSTR tag, _In_ const SOCKADDR_INET& addr, _In_ int protocol, _In_ bool isOpen) const; + int ModifyOpenPorts(_In_ PCWSTR tag, _In_ const SOCKADDR_INET& addr, _In_ int protocol, _In_ bool isOpen) const; void RefreshGuestConnection(NL_NETWORK_CONNECTIVITY_HINT hint) noexcept; void SetupLoopbackDevice(); void UpdateDns(wsl::shared::hns::DNS&& dnsSettings); @@ -51,6 +51,7 @@ private: bool m_enableLocalhostRelay; GUID m_localhostAdapterId; GUID m_adapterId; + GUID m_virtioNetworkClsid; std::optional m_interfaceLuid; ULONG m_networkMtu = 0; @@ -58,11 +59,6 @@ private: // Note: this field must be destroyed first to stop the callbacks before any other field is destroyed. networking::unique_notify_handle m_networkNotifyHandle; - - // 16479D2E-F0C3-4DBA-BF7A-04FFF0892B07 - static constexpr GUID c_virtioNetworkClsid = {0x16479D2E, 0xF0C3, 0x4DBA, {0xBF, 0x7A, 0x04, 0xFF, 0xF0, 0x89, 0x2B, 0x07}}; - // F07010D0-0EA9-447F-88EF-BD952A4D2F14 - static constexpr GUID c_virtioNetworkDeviceId = {0xF07010D0, 0x0EA9, 0x447F, {0x88, 0xEF, 0xBD, 0x95, 0x2A, 0x4D, 0x2F, 0x14}}; }; } // namespace wsl::core diff --git a/src/windows/common/WSLAProcessLauncher.cpp b/src/windows/common/WSLAProcessLauncher.cpp index 84a3cf4..e95e110 100644 --- a/src/windows/common/WSLAProcessLauncher.cpp +++ b/src/windows/common/WSLAProcessLauncher.cpp @@ -109,6 +109,12 @@ std::string WSLAProcessLauncher::FormatResult(const RunningWSLAProcess::ProcessR stdErr != result.Output.end() ? stdErr->second : ""); } +std::pair RunningWSLAProcess::Wait(DWORD TimeoutMs) +{ + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_TIMEOUT), !GetExitEvent().wait(TimeoutMs)); + return GetExitState(); +} + RunningWSLAProcess::ProcessResult RunningWSLAProcess::WaitAndCaptureOutput(DWORD TimeoutMs, std::vector>&& ExtraHandles) { RunningWSLAProcess::ProcessResult result; diff --git a/src/windows/common/WSLAProcessLauncher.h b/src/windows/common/WSLAProcessLauncher.h index e86b068..c895662 100644 --- a/src/windows/common/WSLAProcessLauncher.h +++ b/src/windows/common/WSLAProcessLauncher.h @@ -47,6 +47,7 @@ public: DEFAULT_MOVABLE(RunningWSLAProcess); ProcessResult WaitAndCaptureOutput(DWORD TimeoutMs = INFINITE, std::vector>&& ExtraHandles = {}); + std::pair Wait(DWORD TimeoutMs = INFINITE); virtual wil::unique_handle GetStdHandle(int Index) = 0; virtual wil::unique_event GetExitEvent() = 0; std::pair GetExitState(); diff --git a/src/windows/common/WslClient.cpp b/src/windows/common/WslClient.cpp index 8010189..f17e0a2 100644 --- a/src/windows/common/WslClient.cpp +++ b/src/windows/common/WslClient.cpp @@ -1525,43 +1525,38 @@ int RunDebugShell() THROW_HR(HCS_E_CONNECTION_CLOSED); } +DEFINE_ENUM_FLAG_OPERATORS(WSLAFeatureFlags); + // Temporary debugging tool for WSLA int WslaShell(_In_ std::wstring_view commandLine) { -#ifdef WSLA_TEST_DISTRO_PATH + WSLA_SESSION_SETTINGS sessionSettings{}; + sessionSettings.DisplayName = L"WSLAShell"; + sessionSettings.CpuCount = 4; + sessionSettings.MemoryMb = 4096; + sessionSettings.NetworkingMode = WSLANetworkingModeNAT; + sessionSettings.BootTimeoutMs = 30 * 1000; + sessionSettings.MaximumStorageSizeMb = 4096; - std::wstring vhd = TEXT(WSLA_TEST_DISTRO_PATH); std::string shell = "/bin/sh"; - std::string fsType = "squashfs"; -#else - - std::wstring vhd = wsl::windows::common::wslutil::GetMsiPackagePath().value() + L"/system.vhd"; - std::string shell = "/bin/bash"; - std::string fsType = "ext4"; - -#endif - - VIRTUAL_MACHINE_SETTINGS settings{}; - settings.CpuCount = 4; - settings.DisplayName = L"WSLA"; - settings.MemoryMb = 1024; - settings.BootTimeoutMs = 30000; - settings.NetworkingMode = WSLANetworkingModeNAT; - std::wstring containerRootVhd; std::string containerImage; bool help = false; std::wstring debugShell; + std::wstring storagePath; + std::wstring rootVhdOverride; + std::string rootVhdTypeOverride; ArgumentParser parser(std::wstring{commandLine}, WSL_BINARY_NAME); - parser.AddArgument(vhd, L"--vhd"); + parser.AddArgument(rootVhdOverride, L"--vhd"); parser.AddArgument(Utf8String(shell), L"--shell"); - parser.AddArgument(reinterpret_cast(settings.EnableDnsTunneling), L"--dns-tunneling"); - parser.AddArgument(Integer(settings.MemoryMb), L"--memory"); - parser.AddArgument(Integer(settings.CpuCount), L"--cpu"); - parser.AddArgument(Integer(reinterpret_cast(settings.NetworkingMode)), L"--networking-mode"); - parser.AddArgument(Utf8String(fsType), L"--fstype"); - parser.AddArgument(containerRootVhd, L"--container-vhd"); + parser.AddArgument( + SetFlag(reinterpret_cast(sessionSettings.FeatureFlags)), L"--dns-tunneling"); + parser.AddArgument(Integer(sessionSettings.MemoryMb), L"--memory"); + parser.AddArgument(Integer(sessionSettings.CpuCount), L"--cpu"); + parser.AddArgument(Utf8String(rootVhdTypeOverride), L"--fstype"); + parser.AddArgument(storagePath, L"--storage"); + parser.AddArgument(Integer(reinterpret_cast(sessionSettings.NetworkingMode)), L"--networking-mode"); parser.AddArgument(Utf8String(containerImage), L"--image"); parser.AddArgument(debugShell, L"--debug-shell"); parser.AddArgument(help, L"--help"); @@ -1577,7 +1572,7 @@ int WslaShell(_In_ std::wstring_view commandLine) return 1; } - switch (settings.NetworkingMode) + switch (sessionSettings.NetworkingMode) { case WSLANetworkingMode::WSLANetworkingModeNone: case WSLANetworkingMode::WSLANetworkingModeNAT: @@ -1587,28 +1582,30 @@ int WslaShell(_In_ std::wstring_view commandLine) THROW_HR(E_INVALIDARG); } - if (!containerRootVhd.empty()) - { - settings.ContainerRootVhd = containerRootVhd.c_str(); - - if (!std::filesystem::exists(containerRootVhd)) - { - auto token = wil::open_current_access_token(); - auto tokenInfo = wil::get_token_information(token.get()); - wsl::core::filesystem::CreateVhd(containerRootVhd.c_str(), 5368709120 /* 5 GB */, tokenInfo->User.Sid, FALSE, FALSE); - settings.FormatContainerRootVhd = TRUE; - } - } - wil::com_ptr userSession; THROW_IF_FAILED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&userSession))); wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get()); wil::com_ptr virtualMachine; - WSLA_SESSION_SETTINGS sessionSettings{L"WSLA Test Session"}; wil::com_ptr session; - settings.RootVhd = vhd.c_str(); - settings.RootVhdType = fsType.c_str(); + + if (!rootVhdOverride.empty()) + { + if (rootVhdTypeOverride.empty()) + { + wprintf(L"--fstype required when --vhd is passed\n"); + return 1; + } + + sessionSettings.RootVhdOverride = rootVhdOverride.c_str(); + sessionSettings.RootVhdTypeOverride = rootVhdTypeOverride.c_str(); + } + + if (!storagePath.empty()) + { + storagePath = std::filesystem::weakly_canonical(storagePath).wstring(); + sessionSettings.StoragePath = storagePath.c_str(); + } if (!debugShell.empty()) { @@ -1616,17 +1613,10 @@ int WslaShell(_In_ std::wstring_view commandLine) } else { - THROW_IF_FAILED(userSession->CreateSession(&sessionSettings, &settings, &session)); + THROW_IF_FAILED(userSession->CreateSession(&sessionSettings, &session)); THROW_IF_FAILED(session->GetVirtualMachine(&virtualMachine)); wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get()); - - if (!containerRootVhd.empty()) - { - wsl::windows::common::WSLAProcessLauncher initProcessLauncher{shell, {shell, "/etc/lsw-init.sh"}}; - auto initProcess = initProcessLauncher.Launch(*session); - THROW_HR_IF(E_FAIL, initProcess.WaitAndCaptureOutput().Code != 0); - } } std::optional> container; diff --git a/src/windows/service/exe/WslCoreVm.cpp b/src/windows/service/exe/WslCoreVm.cpp index c2e61c4..b1907df 100644 --- a/src/windows/service/exe/WslCoreVm.cpp +++ b/src/windows/service/exe/WslCoreVm.cpp @@ -39,13 +39,15 @@ using namespace std::string_literals; // Start of unaddressable memory if guest only supports the minimum 36-bit addressing. #define MAX_36_BIT_PAGE_IN_MB (0x1000000000 / _1MB) -// This device type is implemented by the external virtio-pmem vdev. -// {EDBB24BB-5E19-40F4-8A0F-8224313064FD} -DEFINE_GUID(VIRTIO_PMEM_DEVICE_ID, 0xEDBB24BB, 0x5E19, 0x40F4, 0x8A, 0x0F, 0x82, 0x24, 0x31, 0x30, 0x64, 0xFD); - #define WSLG_SHARED_MEMORY_SIZE_MB 8192 #define PAGE_SIZE 0x1000 +// WSL-specific virtio device class IDs. +DEFINE_GUID(VIRTIO_FS_ADMIN_CLASS_ID, 0x7E6AD219, 0xD1B3, 0x42D5, 0xB8, 0xEE, 0xD9, 0x63, 0x24, 0xE6, 0x4F, 0xF6); // {7E6AD219-D1B3-42D5-B8EE-D96324E64FF6} +DEFINE_GUID(VIRTIO_FS_CLASS_ID, 0x60285AE6, 0xAAF3, 0x4456, 0xB4, 0x44, 0xA6, 0xC2, 0xD0, 0xDE, 0xDA, 0x38); // {60285AE6-AAF3-4456-B444-A6C2D0DEDA38} +DEFINE_GUID(VIRTIO_NET_CLASS_ID, 0x16479D2E, 0xF0C3, 0x4DBA, 0xBF, 0x7A, 0x04, 0xFF, 0xF0, 0x89, 0x2B, 0x07); // {16479D2E-F0C3-4DBA-BF7A-04FFF0892B07} +DEFINE_GUID(VIRTIO_PMEM_CLASS_ID, 0xABB755FC, 0x1B86, 0x4255, 0x83, 0xE2, 0xE5, 0x78, 0x7A, 0xBC, 0xF6, 0xC2); // {ABB755FC-1B86-4255-83E2-E5787ABCF6C2} + static constexpr size_t c_bootEntropy = 0x1000; static constexpr auto c_localDevicesKey = L"SOFTWARE\\Microsoft\\Terminal Server Client\\LocalDevices"; @@ -589,7 +591,7 @@ void WslCoreVm::Initialize(const GUID& VmId, const wil::shared_handle& UserToken else if (m_vmConfig.NetworkingMode == NetworkingMode::VirtioProxy) { m_networkingEngine = std::make_unique( - std::move(gnsChannel), m_vmConfig.EnableLocalhostRelay, m_guestDeviceManager, m_userToken); + std::move(gnsChannel), m_vmConfig.EnableLocalhostRelay, m_guestDeviceManager, VIRTIO_NET_CLASS_ID, m_userToken); } else if (m_vmConfig.NetworkingMode == NetworkingMode::Bridged) { @@ -1754,7 +1756,7 @@ void WslCoreVm::InitializeGuest() try { m_guestDeviceManager->AddSharedMemoryDevice( - c_virtiofsClassId, L"wslg", L"wslg", WSLG_SHARED_MEMORY_SIZE_MB, m_userToken.get()); + VIRTIO_FS_CLASS_ID, L"wslg", L"wslg", WSLG_SHARED_MEMORY_SIZE_MB, m_userToken.get()); m_sharedMemoryRoot = std::format(L"WSL\\{}\\wslg", m_machineId); } CATCH_LOG() @@ -2107,8 +2109,8 @@ std::wstring WslCoreVm::AddVirtioFsShare(_In_ bool Admin, _In_ PCWSTR Path, _In_ WI_ASSERT(!FindVirtioFsShare(tag.c_str(), Admin)); (void)m_guestDeviceManager->AddGuestDevice( - VIRTIO_VIRTIOFS_DEVICE_ID, - Admin ? c_virtiofsAdminClassId : c_virtiofsClassId, + VIRTIO_FS_DEVICE_ID, + Admin ? VIRTIO_FS_ADMIN_CLASS_ID : VIRTIO_FS_CLASS_ID, tag.c_str(), key.OptionsString().c_str(), sharePath.c_str(), diff --git a/src/windows/service/exe/WslCoreVm.h b/src/windows/service/exe/WslCoreVm.h index f807e91..431bb6d 100644 --- a/src/windows/service/exe/WslCoreVm.h +++ b/src/windows/service/exe/WslCoreVm.h @@ -40,11 +40,6 @@ inline constexpr auto c_optionsValueName = L"Options"; inline constexpr auto c_typeValueName = L"Type"; inline constexpr auto c_mountNameValueName = L"Name"; -static constexpr GUID c_virtiofsAdminClassId = {0x7e6ad219, 0xd1b3, 0x42d5, {0xb8, 0xee, 0xd9, 0x63, 0x24, 0xe6, 0x4f, 0xf6}}; - -// {60285AE6-AAF3-4456-B444-A6C2D0DEDA38} -static constexpr GUID c_virtiofsClassId = {0x60285ae6, 0xaaf3, 0x4456, {0xb4, 0x44, 0xa6, 0xc2, 0xd0, 0xde, 0xda, 0x38}}; - namespace wrl = Microsoft::WRL; /// diff --git a/src/windows/wslaservice/exe/WSLAContainer.cpp b/src/windows/wslaservice/exe/WSLAContainer.cpp index 5c241c8..f603d5f 100644 --- a/src/windows/wslaservice/exe/WSLAContainer.cpp +++ b/src/windows/wslaservice/exe/WSLAContainer.cpp @@ -87,8 +87,6 @@ Microsoft::WRL::ComPtr WSLAContainer::Create(const WSLA_CONTAINER std::vector inputOptions; if (hasStdin) { - // For now return a proper error if the caller tries to pass stdin without a TTY to prevent hangs. - THROW_WIN32_IF(ERROR_NOT_SUPPORTED, hasTty == false); inputOptions.push_back("-i"); } diff --git a/src/windows/wslaservice/exe/WSLASession.cpp b/src/windows/wslaservice/exe/WSLASession.cpp index bf43226..baf52d8 100644 --- a/src/windows/wslaservice/exe/WSLASession.cpp +++ b/src/windows/wslaservice/exe/WSLASession.cpp @@ -17,25 +17,89 @@ Abstract: #include "WSLAUserSession.h" #include "WSLAContainer.h" #include "ServiceProcessLauncher.h" +#include "WslCoreFilesystem.h" using wsl::windows::service::wsla::WSLASession; +using wsl::windows::service::wsla::WSLAVirtualMachine; -WSLASession::WSLASession(ULONG id, const WSLA_SESSION_SETTINGS& Settings, WSLAUserSessionImpl& userSessionImpl, const VIRTUAL_MACHINE_SETTINGS& VmSettings) : +WSLASession::WSLASession(ULONG id, const WSLA_SESSION_SETTINGS& Settings, WSLAUserSessionImpl& userSessionImpl) : m_id(id), m_sessionSettings(Settings), m_userSession(&userSessionImpl), - m_virtualMachine(wil::MakeOrThrow(VmSettings, userSessionImpl.GetUserSid(), &userSessionImpl)), m_displayName(Settings.DisplayName) { WSL_LOG("SessionCreated", TraceLoggingValue(m_displayName.c_str(), "DisplayName")); + m_virtualMachine = wil::MakeOrThrow(CreateVmSettings(Settings), userSessionImpl.GetUserSid()); + if (Settings.TerminationCallback != nullptr) { m_virtualMachine->RegisterCallback(Settings.TerminationCallback); } m_virtualMachine->Start(); + + ConfigureStorage(Settings); + + // Launch the init script. + // TODO: Replace with something more robust once the final VHD is ready. + try + { + ServiceProcessLauncher launcher{"/bin/sh", {"/bin/sh", "-c", "/etc/lsw-init.sh"}}; + auto result = launcher.Launch(*m_virtualMachine.Get()).WaitAndCaptureOutput(); + + THROW_HR_IF_MSG(E_FAIL, result.Code != 0, "Init script failed: %hs", launcher.FormatResult(result).c_str()); + } + catch (...) + { + // Ignore issues launching the init script with custom root VHD's, for convenience. + // TODO: Remove once the final VHD is ready. + if (Settings.RootVhdOverride == nullptr) + { + throw; + } + } +} + +WSLAVirtualMachine::Settings WSLASession::CreateVmSettings(const WSLA_SESSION_SETTINGS& Settings) +{ + WSLAVirtualMachine::Settings vmSettings{}; + vmSettings.CpuCount = Settings.CpuCount; + vmSettings.MemoryMb = Settings.MemoryMb; + vmSettings.NetworkingMode = Settings.NetworkingMode; + vmSettings.BootTimeoutMs = Settings.BootTimeoutMs; + vmSettings.FeatureFlags = static_cast(Settings.FeatureFlags); + vmSettings.DisplayName = Settings.DisplayName; + + if (Settings.RootVhdOverride != nullptr) + { + THROW_HR_IF(E_INVALIDARG, Settings.RootVhdTypeOverride == nullptr); + + vmSettings.RootVhd = Settings.RootVhdOverride; + vmSettings.RootVhdType = Settings.RootVhdTypeOverride; + } + else + { + +#ifdef WSLA_TEST_DISTRO_PATH + + vmSettings.RootVhd = TEXT(WSLA_TEST_DISTRO_PATH); + +#else + vmSettings.RootVhd = std::filesystem::path(common::wslutil::GetMsiPackagePath().value()) / L"wslarootfs.vhd"; + +#endif + + vmSettings.RootVhdType = "squashfs"; + } + + if (Settings.DmesgOutput != 0) + { + vmSettings.DmesgHandle.reset(wsl::windows::common::wslutil::DuplicateHandleFromCallingProcess(ULongToHandle(Settings.DmesgOutput))); + } + + return vmSettings; } WSLASession::~WSLASession() @@ -56,17 +120,83 @@ WSLASession::~WSLASession() } } +void WSLASession::ConfigureStorage(const WSLA_SESSION_SETTINGS& Settings) +{ + if (Settings.StoragePath == nullptr) + { + // If no storage path is specified, use a tmpfs for convenience. + m_virtualMachine->Mount("", "/root", "tmpfs", "", 0); + return; + } + + std::filesystem::path storagePath{Settings.StoragePath}; + THROW_HR_IF_MSG(E_INVALIDARG, !storagePath.is_absolute(), "Storage path is not absolute: %ls", storagePath.c_str()); + + m_storageVhdPath = storagePath / "storage.vhdx"; + + std::string diskDevice; + std::optional diskLun{}; + bool vhdCreated = false; + + auto deleteVhdOnFailure = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { + if (vhdCreated) + { + if (diskLun.has_value()) + { + m_virtualMachine->DetachDisk(diskLun.value()); + } + + auto runAsUser = wil::CoImpersonateClient(); + LOG_IF_WIN32_BOOL_FALSE(DeleteFileW(m_storageVhdPath.c_str())); + } + }); + + auto result = + wil::ResultFromException([&]() { diskDevice = m_virtualMachine->AttachDisk(m_storageVhdPath.c_str(), false).second; }); + + if (FAILED(result)) + { + THROW_HR_IF_MSG( + result, + result != HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND) && result != HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND), + "Failed to attach vhd: %ls", + m_storageVhdPath.c_str()); + + // If the VHD wasn't found, create it. + WSL_LOG("CreateStorageVhd", TraceLoggingValue(m_storageVhdPath.c_str(), "StorageVhdPath")); + + auto runAsUser = wil::CoImpersonateClient(); + + std::filesystem::create_directories(storagePath); + wsl::core::filesystem::CreateVhd( + m_storageVhdPath.c_str(), Settings.MaximumStorageSizeMb * _1MB, m_userSession->GetUserSid(), false, false); + vhdCreated = true; + + // Then attach the new disk. + std::tie(diskLun, diskDevice) = m_virtualMachine->AttachDisk(m_storageVhdPath.c_str(), false); + + // Then format it. + Ext4Format(diskDevice); + } + + // Mount the device to /root. + m_virtualMachine->Mount(diskDevice.c_str(), "/root", "ext4", "", 0); + + deleteVhdOnFailure.release(); +} + +HRESULT WSLASession::GetDisplayName(LPWSTR* DisplayName) +{ + RETURN_HR_IF_NULL(E_POINTER, DisplayName); + return wil::make_cotaskmem_string_nothrow(m_displayName.c_str(), DisplayName); +} + void WSLASession::CopyDisplayName(_Out_writes_z_(bufferLength) PWSTR buffer, size_t bufferLength) const { THROW_HR_IF(E_BOUNDS, m_displayName.size() + 1 > bufferLength); wcscpy_s(buffer, bufferLength, m_displayName.c_str()); } -/** const std::wstring& WSLASession::DisplayName() const -{ - return m_displayName; -}*/ - HRESULT WSLASession::PullImage(LPCWSTR Image, const WSLA_REGISTRY_AUTHENTICATION_INFORMATION* RegistryInformation, IProgressCallback* ProgressCallback) { return E_NOTIMPL; @@ -137,6 +267,15 @@ try } CATCH_RETURN(); +void WSLASession::Ext4Format(const std::string& Device) +{ + constexpr auto mkfsPath = "/usr/sbin/mkfs.ext4"; + ServiceProcessLauncher launcher(mkfsPath, {mkfsPath, Device}); + auto result = launcher.Launch(*m_virtualMachine.Get()).WaitAndCaptureOutput(); + + THROW_HR_IF_MSG(E_FAIL, result.Code != 0, "%hs", launcher.FormatResult(result).c_str()); +} + HRESULT WSLASession::FormatVirtualDisk(LPCWSTR Path) try { @@ -152,11 +291,7 @@ try auto detachDisk = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [this, lun]() { m_virtualMachine->DetachDisk(lun); }); // Format it to ext4. - constexpr auto mkfsPath = "/usr/sbin/mkfs.ext4"; - ServiceProcessLauncher launcher(mkfsPath, {mkfsPath, device}); - auto result = launcher.Launch(*m_virtualMachine.Get()).WaitAndCaptureOutput(); - - THROW_HR_IF_MSG(E_FAIL, result.Code != 0, "%hs", launcher.FormatResult(result).c_str()); + Ext4Format(device); return S_OK; } diff --git a/src/windows/wslaservice/exe/WSLASession.h b/src/windows/wslaservice/exe/WSLASession.h index f34faa4..a07bca5 100644 --- a/src/windows/wslaservice/exe/WSLASession.h +++ b/src/windows/wslaservice/exe/WSLASession.h @@ -23,7 +23,9 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLASession : public Microsoft::WRL::RuntimeClass, IWSLASession, IFastRundown> { public: - WSLASession(ULONG id, const WSLA_SESSION_SETTINGS& Settings, WSLAUserSessionImpl& userSessionImpl, const VIRTUAL_MACHINE_SETTINGS& VmSettings); + + WSLASession(ULONG id, const WSLA_SESSION_SETTINGS& Settings, WSLAUserSessionImpl& userSessionImpl); + ~WSLASession(); ULONG GetId() const noexcept @@ -61,11 +63,19 @@ public: void OnUserSessionTerminating(); private: + ULONG m_id = 0; + + static WSLAVirtualMachine::Settings CreateVmSettings(const WSLA_SESSION_SETTINGS& Settings); + + void ConfigureStorage(const WSLA_SESSION_SETTINGS& Settings); + void Ext4Format(const std::string& Device); + WSLA_SESSION_SETTINGS m_sessionSettings; // TODO: Revisit to see if we should have session settings as a member or not WSLAUserSessionImpl* m_userSession = nullptr; Microsoft::WRL::ComPtr m_virtualMachine; std::wstring m_displayName; + std::filesystem::path m_storageVhdPath; std::mutex m_lock; // TODO: Add container tracking here. Could reuse m_lock for that. diff --git a/src/windows/wslaservice/exe/WSLAUserSession.cpp b/src/windows/wslaservice/exe/WSLAUserSession.cpp index 093fd7f..c558485 100644 --- a/src/windows/wslaservice/exe/WSLAUserSession.cpp +++ b/src/windows/wslaservice/exe/WSLAUserSession.cpp @@ -47,14 +47,13 @@ PSID WSLAUserSessionImpl::GetUserSid() const return m_tokenInfo->User.Sid; } -HRESULT WSLAUserSessionImpl::CreateSession(const WSLA_SESSION_SETTINGS* Settings, const VIRTUAL_MACHINE_SETTINGS* VmSettings, IWSLASession** WslaSession) +HRESULT WSLAUserSessionImpl::CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** WslaSession) { ULONG id = m_nextSessionId++; - auto session = wil::MakeOrThrow(id, *Settings, *this, *VmSettings); - { - std::lock_guard lock(m_lock); - m_sessions.emplace(session.Get()); - } + auto session = wil::MakeOrThrow(id, *Settings, *this); + + std::lock_guard lock(m_wslaSessionsLock); + auto it = m_sessions.emplace(session.Get()); // Client now owns the session. // TODO: Add a flag for the client to specify that the session should outlive its process. @@ -116,14 +115,13 @@ HRESULT wsl::windows::service::wsla::WSLAUserSession::GetVersion(_Out_ WSLA_VERS return S_OK; } -HRESULT wsl::windows::service::wsla::WSLAUserSession::CreateSession( - const WSLA_SESSION_SETTINGS* Settings, const VIRTUAL_MACHINE_SETTINGS* VmSettings, IWSLASession** WslaSession) +HRESULT wsl::windows::service::wsla::WSLAUserSession::CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** WslaSession) try { auto session = m_session.lock(); RETURN_HR_IF(RPC_E_DISCONNECTED, !session); - return session->CreateSession(Settings, VmSettings, WslaSession); + return session->CreateSession(Settings, WslaSession); } CATCH_RETURN(); diff --git a/src/windows/wslaservice/exe/WSLAUserSession.h b/src/windows/wslaservice/exe/WSLAUserSession.h index ce9894a..4c0df92 100644 --- a/src/windows/wslaservice/exe/WSLAUserSession.h +++ b/src/windows/wslaservice/exe/WSLAUserSession.h @@ -33,7 +33,7 @@ public: PSID GetUserSid() const; - HRESULT CreateSession(const WSLA_SESSION_SETTINGS* Settings, const VIRTUAL_MACHINE_SETTINGS* VmSettings, IWSLASession** WslaSession); + HRESULT CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** WslaSession); HRESULT OpenSessionByName(_In_ LPCWSTR DisplayName, _Out_ IWSLASession** Session); HRESULT ListSessions(_Out_ WSLA_SESSION_INFORMATION** Sessions, _Out_ ULONG* SessionsCount); @@ -58,7 +58,7 @@ public: WSLAUserSession& operator=(const WSLAUserSession&) = delete; IFACEMETHOD(GetVersion)(_Out_ WSLA_VERSION* Version) override; - IFACEMETHOD(CreateSession)(const WSLA_SESSION_SETTINGS* WslaSessionSettings, const VIRTUAL_MACHINE_SETTINGS* VmSettings, IWSLASession** WslaSession) override; + IFACEMETHOD(CreateSession)(const WSLA_SESSION_SETTINGS* WslaSessionSettings, IWSLASession** WslaSession) override; IFACEMETHOD(ListSessions)(_Out_ WSLA_SESSION_INFORMATION** Sessions, _Out_ ULONG* SessionsCount) override; IFACEMETHOD(OpenSession)(_In_ ULONG Id, _Out_ IWSLASession** Session) override; IFACEMETHOD(OpenSessionByName)(_In_ LPCWSTR DisplayName, _Out_ IWSLASession** Session) override; diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp index aa56004..ee3de2c 100644 --- a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp @@ -33,28 +33,17 @@ constexpr auto SAVED_STATE_FILE_EXTENSION = L".vmrs"; constexpr auto SAVED_STATE_FILE_PREFIX = L"saved-state-"; constexpr auto RECEIVE_TIMEOUT = 30 * 1000; -WSLAVirtualMachine::WSLAVirtualMachine(const VIRTUAL_MACHINE_SETTINGS& Settings, PSID UserSid, WSLAUserSessionImpl* Session) : - m_settings(Settings), m_userSid(UserSid) +// WSLA-specific virtio device class IDs. +DEFINE_GUID(WSLA_VIRTIO_NET_CLASS_ID, 0x7B3C9A42, 0x8E1F, 0x4D5A, 0x9F, 0x2E, 0xC4, 0xA7, 0xB8, 0xD3, 0xE6, 0xF1); // {7B3C9A42-8E1F-4D5A-9F2E-C4A7B8D3E6F1} + +WSLAVirtualMachine::WSLAVirtualMachine(WSLAVirtualMachine::Settings&& Settings, PSID UserSid) : + m_settings(std::move(Settings)), m_userSid(UserSid) { THROW_IF_FAILED(CoCreateGuid(&m_vmId)); m_vmIdString = wsl::shared::string::GuidToString(m_vmId, wsl::shared::string::GuidToStringFlags::Uppercase); m_userToken = wsl::windows::common::security::GetUserToken(TokenImpersonation); m_crashDumpFolder = GetCrashDumpFolder(); - - if (Settings.EnableDebugShell) - { - m_debugShellPipe = wsl::windows::common::wslutil::GetDebugShellPipeName(m_userSid) + m_settings.DisplayName; - } -} - -HRESULT WSLAVirtualMachine::GetDebugShellPipe(LPWSTR* pipePath) -{ - RETURN_HR_IF(E_INVALIDARG, m_debugShellPipe.empty()); - - *pipePath = wil::make_unique_string(m_debugShellPipe.c_str()).release(); - - return S_OK; } void WSLAVirtualMachine::OnSessionTerminated() @@ -76,6 +65,11 @@ void WSLAVirtualMachine::OnSessionTerminated() WSLAVirtualMachine::~WSLAVirtualMachine() { WSL_LOG("WSLATerminateVmStart", TraceLoggingValue(m_running, "running")); + if (!m_computeSystem) + { + // If m_computeSystem is null, don't try to stop the VM since it never started. + return; + } m_initChannel.Close(); @@ -209,14 +203,11 @@ void WSLAVirtualMachine::Start() kernelCmdLine += L" hv_utils.timesync_implicit=1"; wil::unique_handle dmesgOutput; - if (m_settings.DmesgOutput != 0) - { - dmesgOutput.reset(wsl::windows::common::wslutil::DuplicateHandleFromCallingProcess(ULongToHandle(m_settings.DmesgOutput))); - } + dmesgOutput = std::move(m_settings.DmesgHandle); m_dmesgCollector = DmesgCollector::Create(m_vmId, m_vmExitEvent, true, false, L"", true, std::move(dmesgOutput)); - if (m_settings.EnableEarlyBootDmesg) + if (FeatureEnabled(WslaFeatureFlagsEarlyBootDmesg)) { kernelCmdLine += L" earlycon=uart8250,io,0x3f8,115200"; vmSettings.Devices.ComPorts["0"] = hcs::ComPort{m_dmesgCollector->EarlyConsoleName()}; @@ -324,6 +315,7 @@ void WSLAVirtualMachine::Start() wsl::windows::common::hcs::RegisterCallback(m_computeSystem.get(), &s_OnExit, this); wsl::windows::common::hcs::StartComputeSystem(m_computeSystem.get(), json.c_str()); + m_running = true; // Create a socket listening for crash dumps. auto crashDumpSocket = wsl::windows::common::hvsocket::Listen(runtimeId, LX_INIT_UTILITY_VM_CRASH_DUMP_PORT); @@ -364,7 +356,7 @@ void WSLAVirtualMachine::Start() Mount(m_initChannel, device.c_str(), "", "ext4", "ro", WSLA_MOUNT::KernelModules); // Configure GPU if requested. - if (m_settings.EnableGPU) + if (FeatureEnabled(WslaFeatureFlagsGPU)) { hcs::ModifySettingRequest gpuRequest{}; gpuRequest.ResourcePath = L"VirtualMachine/ComputeTopology/Gpu"; @@ -385,32 +377,23 @@ void WSLAVirtualMachine::Start() void WSLAVirtualMachine::ConfigureMounts() { - auto [_, device] = AttachDisk(m_settings.RootVhd, true); + auto [_, device] = AttachDisk(m_settings.RootVhd.c_str(), true); - Mount(m_initChannel, device.c_str(), "/mnt", m_settings.RootVhdType, "ro", WSLAMountFlagsChroot | WSLAMountFlagsWriteableOverlayFs); + Mount(m_initChannel, device.c_str(), "/mnt", m_settings.RootVhdType.c_str(), "ro", WSLAMountFlagsChroot | WSLAMountFlagsWriteableOverlayFs); Mount(m_initChannel, nullptr, "/dev", "devtmpfs", "", 0); Mount(m_initChannel, nullptr, "/sys", "sysfs", "", 0); Mount(m_initChannel, nullptr, "/proc", "proc", "", 0); Mount(m_initChannel, nullptr, "/dev/pts", "devpts", "noatime,nosuid,noexec,gid=5,mode=620", 0); - if (m_settings.EnableGPU) // TODO: re-think how GPU settings should work at the session level API. + if (FeatureEnabled(WslaFeatureFlagsGPU)) // TODO: re-think how GPU settings should work at the session level API. { MountGpuLibraries("/usr/lib/wsl/lib", "/usr/lib/wsl/drivers", WSLAMountFlagsNone); } +} - if (m_settings.ContainerRootVhd) // TODO: re-think how container root settings should work at the session level API. - { - auto [_, containerRootDevice] = AttachDisk(m_settings.ContainerRootVhd, false); - - if (m_settings.FormatContainerRootVhd) - { - ServiceProcessLauncher formatProcessLauncher{"/usr/sbin/mkfs.ext4", {"/usr/sbin/mkfs.ext4", containerRootDevice}}; - auto formatProcess = formatProcessLauncher.Launch(*this); - THROW_HR_IF(E_FAIL, formatProcess.WaitAndCaptureOutput().Code != 0); - } - - Mount(m_initChannel, containerRootDevice.c_str(), "/root", "ext4", "rw", 0); - } +bool WSLAVirtualMachine::FeatureEnabled(WSLAFeatureFlags Value) const +{ + return static_cast(m_settings.FeatureFlags) & static_cast(Value); } void WSLAVirtualMachine::WatchForExitedProcesses(wsl::shared::SocketChannel& Channel) @@ -477,7 +460,7 @@ void WSLAVirtualMachine::ConfigureNetworking() std::vector cmd{"/gns", LX_INIT_GNS_SOCKET_ARG}; // If DNS tunnelling is enabled, use an additional for its channel. - if (m_settings.EnableDnsTunneling) + if (FeatureEnabled(WslaFeatureFlagsDnsTunneling)) { THROW_HR_IF_MSG( E_NOTIMPL, @@ -541,7 +524,8 @@ void WSLAVirtualMachine::ConfigureNetworking() } else { - m_networkEngine = std::make_unique(std::move(gnsChannel), true, m_guestDeviceManager, m_userToken); + m_networkEngine = std::make_unique( + std::move(gnsChannel), true, m_guestDeviceManager, WSLA_VIRTIO_NET_CLASS_ID, m_userToken); } m_networkEngine->Initialize(); @@ -630,7 +614,7 @@ std::pair WSLAVirtualMachine::AttachDisk(_In_ PCWSTR Path, _ auto result = wil::ResultFromException([&]() { std::lock_guard lock{m_lock}; - THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), m_running); + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_running); AttachedDisk disk{Path}; @@ -759,7 +743,7 @@ std::tuple WSLAVirtualMachine::For int32_t pid{}; int32_t ptyMaster{}; { - THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), m_running); + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_running); WSLA_FORK message; message.ForkType = Type; @@ -937,6 +921,13 @@ Microsoft::WRL::ComPtr WSLAVirtualMachine::CreateLinuxProcess(_In_ return process; } +void WSLAVirtualMachine::Mount(LPCSTR Source, LPCSTR Target, LPCSTR Type, LPCSTR Options, ULONG Flags) +{ + std::lock_guard lock{m_lock}; + + Mount(m_initChannel, Source, Target, Type, Options, Flags); +} + void WSLAVirtualMachine::Mount(shared::SocketChannel& Channel, LPCSTR Source, LPCSTR Target, LPCSTR Type, LPCSTR Options, ULONG Flags) { static_assert(WSLAMountFlagsNone == WSLA_MOUNT::None); @@ -1010,7 +1001,7 @@ try { std::lock_guard lock(m_lock); - THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), m_running); + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_running); WSLA_SHUTDOWN message{}; m_initChannel.SendMessage(message); @@ -1027,7 +1018,7 @@ HRESULT WSLAVirtualMachine::Signal(_In_ LONG Pid, _In_ int Signal) try { std::lock_guard lock(m_lock); - THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), m_running); + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_running); WSLA_SIGNAL message; message.Pid = Pid; @@ -1246,7 +1237,7 @@ CATCH_RETURN(); void WSLAVirtualMachine::MountGpuLibraries(_In_ LPCSTR LibrariesMountPoint, _In_ LPCSTR DriversMountpoint, _In_ DWORD Flags) { - THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_CONFIG_VALUE), !m_settings.EnableGPU); + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_CONFIG_VALUE), !FeatureEnabled(WslaFeatureFlagsGPU)); auto [channel, _, __] = Fork(WSLA_FORK::Thread); diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.h b/src/windows/wslaservice/exe/WSLAVirtualMachine.h index 5620b0c..e436cd1 100644 --- a/src/windows/wslaservice/exe/WSLAVirtualMachine.h +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.h @@ -43,9 +43,22 @@ public: wil::unique_socket Socket; }; + struct Settings + { + std::wstring DisplayName; + ULONGLONG MemoryMb{}; + ULONG CpuCount; + ULONG BootTimeoutMs{}; + WSLANetworkingMode NetworkingMode{}; + WSLAFeatureFlags FeatureFlags{}; + wil::unique_handle DmesgHandle; + std::filesystem::path RootVhd; + std::string RootVhdType; + }; + using TPrepareCommandLine = std::function&)>; - WSLAVirtualMachine(const VIRTUAL_MACHINE_SETTINGS& Settings, PSID Sid, WSLAUserSessionImpl* UserSession); + WSLAVirtualMachine(Settings&& Settings, PSID Sid); ~WSLAVirtualMachine(); @@ -56,7 +69,6 @@ public: IFACEMETHOD(WaitPid(_In_ LONG Pid, _In_ ULONGLONG TimeoutMs, _Out_ ULONG* State, _Out_ int* Code)) override; IFACEMETHOD(Signal(_In_ LONG Pid, _In_ int Signal)) override; IFACEMETHOD(Shutdown(ULONGLONG _In_ TimeoutMs)) override; - IFACEMETHOD(GetDebugShellPipe(_Out_ LPWSTR* pipePath)) override; IFACEMETHOD(MapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ BOOL Remove)) override; IFACEMETHOD(Unmount(_In_ const char* Path)) override; IFACEMETHOD(MountWindowsFolder(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly)) override; @@ -71,6 +83,7 @@ public: std::pair AttachDisk(_In_ PCWSTR Path, _In_ BOOL ReadOnly); void DetachDisk(_In_ ULONG Lun); + void Mount(_In_ LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags); private: static void Mount(wsl::shared::SocketChannel& Channel, LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags); @@ -82,6 +95,7 @@ private: void ConfigureMounts(); void OnExit(_In_ const HCS_EVENT* Event); void OnCrash(_In_ const HCS_EVENT* Event); + bool FeatureEnabled(WSLAFeatureFlags Flag) const; std::tuple Fork(enum WSLA_FORK::ForkType Type); std::tuple Fork( @@ -112,7 +126,7 @@ private: bool AccessGranted = false; }; - VIRTUAL_MACHINE_SETTINGS m_settings; + Settings m_settings; std::thread m_processExitThread; std::thread m_crashDumpCollectionThread; diff --git a/src/windows/wslaservice/inc/wslaservice.idl b/src/windows/wslaservice/inc/wslaservice.idl index 0ace498..5731c05 100644 --- a/src/windows/wslaservice/inc/wslaservice.idl +++ b/src/windows/wslaservice/inc/wslaservice.idl @@ -206,7 +206,6 @@ interface IWSLAVirtualMachine : IUnknown HRESULT WaitPid([in] LONG Pid, [in] ULONGLONG TimeoutMs, [out] ULONG* State, [out] int* Code); HRESULT Signal([in] LONG Pid, [in] int Signal); HRESULT Shutdown([in] ULONGLONG TimeoutMs); - HRESULT GetDebugShellPipe([out] LPWSTR* pipePath); HRESULT MapPort([in] int Family, [in] short WindowsPort, [in] short LinuxPort, [in] BOOL Remove); HRESULT Unmount([in] LPCSTR Path); HRESULT MountWindowsFolder([in] LPCWSTR WindowsPath, [in] LPCSTR LinuxPath, [in] BOOL ReadOnly); @@ -220,32 +219,29 @@ typedef enum _WSLANetworkingMode WSLANetworkingModeVirtioProxy } WSLANetworkingMode; -typedef -struct _VIRTUAL_MACHINE_SETTINGS { // TODO: Delete once the new API is wired. - LPCWSTR DisplayName; - ULONGLONG MemoryMb; - ULONG CpuCount; - ULONG BootTimeoutMs; - ULONG DmesgOutput; - WSLANetworkingMode NetworkingMode; - BOOL EnableDnsTunneling; - BOOL EnableDebugShell; - BOOL EnableEarlyBootDmesg; - BOOL EnableGPU; - LPCWSTR RootVhd; // Temporary option to provide the root VHD. TODO: Remove once runtime VHD is available. - LPCSTR RootVhdType; // Temporary option to provide the root VHD. TODO: Remove once runtime VHD is available. - LPCWSTR ContainerRootVhd; - BOOL FormatContainerRootVhd; -} VIRTUAL_MACHINE_SETTINGS; - - +typedef enum _WSLAFeatureFlags +{ + WslaFeatureFlagsNone = 0, + WslaFeatureFlagsDnsTunneling = 1, + WslaFeatureFlagsEarlyBootDmesg = 2, + WslaFeatureFlagsGPU = 4, +} WSLAFeatureFlags; struct WSLA_SESSION_SETTINGS { LPCWSTR DisplayName; LPCWSTR StoragePath; + ULONGLONG MaximumStorageSizeMb; + ULONG CpuCount; + ULONG MemoryMb; + ULONG BootTimeoutMs; + WSLANetworkingMode NetworkingMode; [unique] ITerminationCallback* TerminationCallback; + ULONG FeatureFlags; + ULONG DmesgOutput; - // TODO: Termination callback, flags + // Below options are used for debugging purposes only. + [unique] LPCWSTR RootVhdOverride; + [unique] LPCSTR RootVhdTypeOverride; }; @@ -318,7 +314,7 @@ interface IWSLAUserSession : IUnknown HRESULT GetVersion([out] WSLA_VERSION* Version); // Session managment. - HRESULT CreateSession([in] const struct WSLA_SESSION_SETTINGS* Settings, [in] const VIRTUAL_MACHINE_SETTINGS* VmSettings, [out]IWSLASession** Session); + HRESULT CreateSession([in] const struct WSLA_SESSION_SETTINGS* Settings, [out]IWSLASession** Session); HRESULT ListSessions([out, size_is(, *SessionsCount)] struct WSLA_SESSION_INFORMATION** Sessions, [out] ULONG* SessionsCount); HRESULT OpenSession([in] ULONG Id, [out]IWSLASession** Session); HRESULT OpenSessionByName([in] LPCWSTR DisplayName, [out] IWSLASession** Session); diff --git a/test/windows/WSLATests.cpp b/test/windows/WSLATests.cpp index 0115736..32e7faf 100644 --- a/test/windows/WSLATests.cpp +++ b/test/windows/WSLATests.cpp @@ -29,6 +29,8 @@ using wsl::windows::common::WSLAProcessLauncher; using wsl::windows::common::relay::OverlappedIOHandle; using wsl::windows::common::relay::WriteHandle; +DEFINE_ENUM_FLAG_OPERATORS(WSLAFeatureFlags); + class WSLATests { WSL_TEST_CLASS(WSLATests) @@ -54,20 +56,25 @@ class WSLATests return true; } - wil::com_ptr CreateSession(VIRTUAL_MACHINE_SETTINGS& vmSettings, const WSLA_SESSION_SETTINGS& sessionSettings = {L"wsla-test"}) + static WSLA_SESSION_SETTINGS GetDefaultSessionSettings() { - if (vmSettings.RootVhdType == nullptr) - { - vmSettings.RootVhdType = "ext4"; - } + WSLA_SESSION_SETTINGS settings{}; + settings.DisplayName = L"wsla-test"; + settings.CpuCount = 4; + settings.MemoryMb = 2024; + settings.BootTimeoutMs = 30 * 1000; + return settings; + } + wil::com_ptr CreateSession(const WSLA_SESSION_SETTINGS& sessionSettings = GetDefaultSessionSettings()) + { wil::com_ptr userSession; VERIFY_SUCCEEDED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&userSession))); wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get()); wil::com_ptr session; - VERIFY_SUCCEEDED(userSession->CreateSession(&sessionSettings, &vmSettings, &session)); + VERIFY_SUCCEEDED(userSession->CreateSession(&sessionSettings, &session)); wsl::windows::common::security::ConfigureForCOMImpersonation(session.get()); return session; @@ -145,7 +152,7 @@ class WSLATests if (result.Code != expectedResult) { LogError( - "Command didn't return expected code (%i). ExitCode: %i, Stdout: '%hs', Stderr: '%hs'", + "Comman didn't return expected code (%i). ExitCode: %i, Stdout: '%hs', Stderr: '%hs'", expectedResult, result.Code, result.Output[1].c_str(), @@ -177,14 +184,9 @@ class WSLATests auto createVmWithDmesg = [this](bool earlyBootLogging) { auto [read, write] = CreateSubprocessPipe(false, false); - VIRTUAL_MACHINE_SETTINGS settings{}; - settings.CpuCount = 4; - settings.DisplayName = L"WSLA"; - settings.MemoryMb = 2048; - settings.BootTimeoutMs = 30 * 1000; + auto settings = GetDefaultSessionSettings(); settings.DmesgOutput = (ULONG) reinterpret_cast(write.get()); - settings.EnableEarlyBootDmesg = earlyBootLogging; - settings.RootVhd = testVhd.c_str(); + WI_SetFlagIf(settings.FeatureFlags, WslaFeatureFlagsEarlyBootDmesg, earlyBootLogging); std::vector dmesgContent; auto readDmesg = [read = read.get(), &dmesgContent]() mutable { @@ -223,7 +225,7 @@ class WSLATests write.reset(); - ExpectCommandResult(session.get(), {"/bin/bash", "-c", "echo DmesgTest > /dev/kmsg"}, 0); + ExpectCommandResult(session.get(), {"/bin/sh", "-c", "echo DmesgTest > /dev/kmsg"}, 0); VERIFY_ARE_EQUAL(session->Shutdown(30 * 1000), S_OK); detach.reset(); @@ -281,23 +283,16 @@ class WSLATests std::function m_callback; }; - VIRTUAL_MACHINE_SETTINGS settings{}; - settings.CpuCount = 4; - settings.DisplayName = L"WSLA"; - settings.MemoryMb = 2048; - settings.BootTimeoutMs = 30 * 1000; - settings.RootVhd = testVhd.c_str(); - std::promise> promise; CallbackInstance callback{[&](WSLAVirtualMachineTerminationReason reason, LPCWSTR details) { promise.set_value(std::make_pair(reason, details)); }}; - WSLA_SESSION_SETTINGS sessionSettings{L"wsla-test"}; + WSLA_SESSION_SETTINGS sessionSettings = GetDefaultSessionSettings(); sessionSettings.TerminationCallback = &callback; - auto session = CreateSession(settings, sessionSettings); + auto session = CreateSession(sessionSettings); wil::com_ptr vm; VERIFY_SUCCEEDED(session->GetVirtualMachine(&vm)); @@ -313,14 +308,7 @@ class WSLATests { WSL2_TEST_ONLY(); - VIRTUAL_MACHINE_SETTINGS settings{}; - settings.CpuCount = 4; - settings.DisplayName = L"WSLA"; - settings.MemoryMb = 2048; - settings.BootTimeoutMs = 30 * 1000; - settings.RootVhd = testVhd.c_str(); - - auto session = CreateSession(settings); + auto session = CreateSession(); WSLAProcessLauncher launcher("/bin/sh", {"/bin/sh"}, {"TERM=xterm-256color"}, ProcessFlags::None); launcher.AddFd(WSLA_PROCESS_FD{.Fd = 0, .Type = WSLAFdTypeTerminalInput}); @@ -354,7 +342,7 @@ class WSLATests }; // Expect the shell prompt to be displayed - validateTtyOutput("#"); + validateTtyOutput("/ #"); writeTty("echo OK\n"); validateTtyOutput(" echo OK\r\nOK"); @@ -368,20 +356,15 @@ class WSLATests { WSL2_TEST_ONLY(); - VIRTUAL_MACHINE_SETTINGS settings{}; - settings.CpuCount = 4; - settings.DisplayName = L"WSLA"; - settings.MemoryMb = 2048; - settings.BootTimeoutMs = 30 * 1000; + auto settings = GetDefaultSessionSettings(); settings.NetworkingMode = WSLANetworkingModeNAT; - settings.RootVhd = testVhd.c_str(); auto session = CreateSession(settings); // Validate that eth0 has an ip address ExpectCommandResult( session.get(), - {"/bin/bash", + {"/bin/sh", "-c", "ip a show dev eth0 | grep -iF 'inet ' | grep -E '[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}'"}, 0); @@ -393,21 +376,16 @@ class WSLATests { WSL2_TEST_ONLY(); - VIRTUAL_MACHINE_SETTINGS settings{}; - settings.CpuCount = 4; - settings.DisplayName = L"WSLA"; - settings.MemoryMb = 2048; - settings.BootTimeoutMs = 30 * 1000; + auto settings = GetDefaultSessionSettings(); settings.NetworkingMode = WSLANetworkingModeNAT; - settings.EnableDnsTunneling = true; - settings.RootVhd = testVhd.c_str(); + WI_SetFlag(settings.FeatureFlags, WslaFeatureFlagsDnsTunneling); auto session = CreateSession(settings); // Validate that eth0 has an ip address ExpectCommandResult( session.get(), - {"/bin/bash", + {"/bin/sh", "-c", "ip a show dev eth0 | grep -iF 'inet ' | grep -E '[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}'"}, 0); @@ -418,43 +396,11 @@ class WSLATests VERIFY_ARE_EQUAL(result.Output[1], std::format("nameserver {}\n", LX_INIT_DNS_TUNNELING_IP_ADDRESS)); } - TEST_METHOD(VirtioProxyNetworking) - { - WSL2_TEST_ONLY(); - - VIRTUAL_MACHINE_SETTINGS settings{}; - settings.CpuCount = 4; - settings.DisplayName = L"WSLA"; - settings.MemoryMb = 2048; - settings.BootTimeoutMs = 30 * 1000; - settings.NetworkingMode = WSLANetworkingModeVirtioProxy; - settings.RootVhd = testVhd.c_str(); - - auto session = CreateSession(settings); - - // Validate that eth0 has an ip address - ExpectCommandResult( - session.get(), - {"/bin/bash", - "-c", - "ip a show dev eth0 | grep -iF 'inet ' | grep -E '[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}'"}, - 0); - - ExpectCommandResult(session.get(), {"/bin/grep", "-iF", "nameserver", "/etc/resolv.conf"}, 0); - } - TEST_METHOD(OpenFiles) { WSL2_TEST_ONLY(); - VIRTUAL_MACHINE_SETTINGS settings{}; - settings.CpuCount = 4; - settings.DisplayName = L"WSLA"; - settings.MemoryMb = 2048; - settings.BootTimeoutMs = 30 * 1000; - settings.RootVhd = testVhd.c_str(); - - auto session = CreateSession(settings); + auto session = CreateSession(); struct FileFd { @@ -544,7 +490,7 @@ class WSLATests {{0, WSLAFdTypeLinuxFileInput, "/proc/self/comm"}, {1, WSLAFdTypeLinuxFileInput, "/tmp/output"}, {2, WSLAFdTypeDefault, nullptr}}); auto result = process->WaitAndCaptureOutput(); - VERIFY_ARE_EQUAL(result.Output[2], "/bin/cat: write error: Bad file descriptor\n"); + VERIFY_ARE_EQUAL(result.Output[2], "cat: write error: Bad file descriptor\n"); VERIFY_ARE_EQUAL(result.Code, 1); } @@ -552,7 +498,7 @@ class WSLATests auto process = createProcess({"/bin/cat"}, {{0, WSLAFdTypeLinuxFileOutput, "/tmp/output"}, {2, WSLAFdTypeDefault, nullptr}}); auto result = process->WaitAndCaptureOutput(); - VERIFY_ARE_EQUAL(result.Output[2], "/bin/cat: standard output: Bad file descriptor\n"); + VERIFY_ARE_EQUAL(result.Output[2], "cat: read error: Bad file descriptor\n"); VERIFY_ARE_EQUAL(result.Code, 1); } } @@ -561,13 +507,10 @@ class WSLATests { WSL2_TEST_ONLY(); - VIRTUAL_MACHINE_SETTINGS settings{}; - settings.CpuCount = 4; - settings.DisplayName = L"WSLA"; - settings.MemoryMb = 2048; - settings.BootTimeoutMs = 30 * 1000; + auto settings = GetDefaultSessionSettings(); + settings.RootVhdOverride = testVhd.c_str(); // socat is required to run this test case. + settings.RootVhdTypeOverride = "ext4"; settings.NetworkingMode = WSLANetworkingModeNAT; - settings.RootVhd = testVhd.c_str(); auto session = CreateSession(settings); @@ -607,7 +550,7 @@ class WSLATests auto listen = [&](short port, const char* content, bool ipv6) { auto cmd = std::format("echo -n '{}' | /usr/bin/socat -dd TCP{}-LISTEN:{},reuseaddr -", content, ipv6 ? "6" : "", port); - auto process = WSLAProcessLauncher("/bin/bash", {"/bin/bash", "-c", cmd}).Launch(*session); + auto process = WSLAProcessLauncher("/bin/sh", {"/bin/sh", "-c", cmd}).Launch(*session); waitForOutput(process.GetStdHandle(2).get(), "listening on"); return process; @@ -702,14 +645,7 @@ class WSLATests { WSL2_TEST_ONLY(); - VIRTUAL_MACHINE_SETTINGS settings{}; - settings.CpuCount = 4; - settings.DisplayName = L"WSLA"; - settings.MemoryMb = 2048; - settings.BootTimeoutMs = 30 * 1000; - settings.RootVhd = testVhd.c_str(); - - auto session = CreateSession(settings); + auto session = CreateSession(); // Create a 'stuck' process auto process = WSLAProcessLauncher{"/bin/cat", {"/bin/cat"}, {}, ProcessFlags::Stdin | ProcessFlags::Stdout}.Launch(*session); @@ -722,14 +658,7 @@ class WSLATests { WSL2_TEST_ONLY(); - VIRTUAL_MACHINE_SETTINGS settings{}; - settings.CpuCount = 4; - settings.DisplayName = L"WSLA"; - settings.MemoryMb = 2048; - settings.BootTimeoutMs = 30 * 1000; - settings.RootVhd = testVhd.c_str(); - - auto session = CreateSession(settings); + auto session = CreateSession(); wil::com_ptr vm; VERIFY_SUCCEEDED(session->GetVirtualMachine(&vm)); @@ -738,7 +667,7 @@ class WSLATests auto expectMount = [&](const std::string& target, const std::optional& options) { auto cmd = std::format("set -o pipefail ; findmnt '{}' | tail -n 1", target); - auto result = ExpectCommandResult(session.get(), {"/bin/bash", "-c", cmd}, options.has_value() ? 0 : 1); + auto result = ExpectCommandResult(session.get(), {"/bin/sh", "-c", cmd}, options.has_value() ? 0 : 1); const auto& output = result.Output[1]; const auto& error = result.Output[2]; @@ -769,7 +698,7 @@ class WSLATests VERIFY_ARE_EQUAL(vm->MountWindowsFolder(testFolder.c_str(), "/win-path", false), HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS)); // Validate that folder is writeable from linux - ExpectCommandResult(session.get(), {"/bin/bash", "-c", "echo -n content > /win-path/file.txt && sync"}, 0); + ExpectCommandResult(session.get(), {"/bin/sh", "-c", "echo -n content > /win-path/file.txt && sync"}, 0); VERIFY_ARE_EQUAL(ReadFileContent(testFolder / "file.txt"), L"content"); VERIFY_SUCCEEDED(vm->UnmountWindowsFolder("/win-path")); @@ -782,7 +711,7 @@ class WSLATests expectMount("/win-path", "/win-path*9p*rw,relatime,aname=*,cache=5,access=client,msize=65536,trans=fd,rfd=*,wfd=*"); // Validate that folder is not writeable from linux - ExpectCommandResult(session.get(), {"/bin/bash", "-c", "echo -n content > /win-path/file.txt"}, 1); + ExpectCommandResult(session.get(), {"/bin/sh", "-c", "echo -n content > /win-path/file.txt"}, 1); VERIFY_SUCCEEDED(vm->UnmountWindowsFolder("/win-path")); expectMount("/win-path", {}); @@ -809,19 +738,12 @@ class WSLATests { WSL2_TEST_ONLY(); - VIRTUAL_MACHINE_SETTINGS settings{}; - settings.CpuCount = 4; - settings.DisplayName = L"WSLA"; - settings.MemoryMb = 2048; - settings.BootTimeoutMs = 30 * 1000; - settings.RootVhd = testVhd.c_str(); - - auto session = CreateSession(settings); - auto result = ExpectCommandResult( - session.get(), {"/bin/bash", "-c", "echo /proc/self/fd/* && (readlink -v /proc/self/fd/* || true)"}, 0); + auto session = CreateSession(); + auto result = + ExpectCommandResult(session.get(), {"/bin/sh", "-c", "echo /proc/self/fd/* && (readlink -v /proc/self/fd/* || true)"}, 0); // Note: fd/0 is opened by readlink to read the actual content of /proc/self/fd. - if (!PathMatchSpecA(result.Output[1].c_str(), "/proc/self/fd/0 /proc/self/fd/1 /proc/self/fd/2\nsocket:[*]\nsocket:[*]\n")) + if (!PathMatchSpecA(result.Output[1].c_str(), "/proc/self/fd/0 /proc/self/fd/1 /proc/self/fd/2\n")) { LogInfo("Found additional fds: %hs", result.Output[1].c_str()); VERIFY_FAIL(); @@ -832,13 +754,8 @@ class WSLATests { WSL2_TEST_ONLY(); - VIRTUAL_MACHINE_SETTINGS settings{}; - settings.CpuCount = 4; - settings.DisplayName = L"WSLA"; - settings.MemoryMb = 2048; - settings.BootTimeoutMs = 30 * 1000; - settings.EnableGPU = true; - settings.RootVhd = testVhd.c_str(); + auto settings = GetDefaultSessionSettings(); + WI_SetFlag(settings.FeatureFlags, WslaFeatureFlagsGPU); auto session = CreateSession(settings); @@ -846,10 +763,10 @@ class WSLATests VERIFY_SUCCEEDED(session->GetVirtualMachine(&vm)); // Validate that the GPU device is available. - ExpectCommandResult(session.get(), {"/bin/bash", "-c", "test -c /dev/dxg"}, 0); + ExpectCommandResult(session.get(), {"/bin/sh", "-c", "test -c /dev/dxg"}, 0); auto expectMount = [&](const std::string& target, const std::optional& options) { auto cmd = std::format("set -o pipefail ; findmnt '{}' | tail -n 1", target); - WSLAProcessLauncher launcher{"/bin/bash", {"/bin/bash", "-c", cmd}}; + WSLAProcessLauncher launcher{"/bin/sh", {"/bin/sh", "-c", cmd}}; auto result = launcher.Launch(*session).WaitAndCaptureOutput(); const auto& output = result.Output[1]; @@ -878,7 +795,7 @@ class WSLATests // Validate that trying to mount the shares without GPU support disabled fails. { - settings.EnableGPU = false; + WI_ClearFlag(settings.FeatureFlags, WslaFeatureFlagsGPU); session = CreateSession(settings); wil::com_ptr vm; @@ -894,49 +811,23 @@ class WSLATests { WSL2_TEST_ONLY(); - VIRTUAL_MACHINE_SETTINGS settings{}; - settings.CpuCount = 4; - settings.DisplayName = L"WSLA"; - settings.MemoryMb = 2048; - settings.BootTimeoutMs = 30 * 1000; - settings.RootVhd = testVhd.c_str(); - - // Use the system distro vhd for modprobe & lsmod. - -#ifdef WSL_SYSTEM_DISTRO_PATH - - auto rootfs = std::filesystem::path(TEXT(WSL_SYSTEM_DISTRO_PATH)); - -#else - auto rootfs = std::filesystem::path(wsl::windows::common::wslutil::GetMsiPackagePath().value()) / L"system.vhd"; - -#endif - settings.RootVhd = rootfs.c_str(); - - auto session = CreateSession(settings); + auto session = CreateSession(); // Sanity check. - ExpectCommandResult(session.get(), {"/bin/bash", "-c", "lsmod | grep ^xsk_diag"}, 1); + ExpectCommandResult(session.get(), {"/bin/sh", "-c", "lsmod | grep ^xsk_diag"}, 1); // Validate that modules can be loaded. ExpectCommandResult(session.get(), {"/usr/sbin/modprobe", "xsk_diag"}, 0); // Validate that xsk_diag is now loaded. - ExpectCommandResult(session.get(), {"/bin/bash", "-c", "lsmod | grep ^xsk_diag"}, 0); + ExpectCommandResult(session.get(), {"/bin/sh", "-c", "lsmod | grep ^xsk_diag"}, 0); } TEST_METHOD(CreateRootNamespaceProcess) { WSL2_TEST_ONLY(); - VIRTUAL_MACHINE_SETTINGS settings{}; - settings.CpuCount = 4; - settings.DisplayName = L"WSLA"; - settings.MemoryMb = 2048; - settings.BootTimeoutMs = 30 * 1000; - settings.RootVhd = testVhd.c_str(); - - auto session = CreateSession(settings); + auto session = CreateSession(); // Simple case { @@ -1056,14 +947,7 @@ class WSLATests { WSL2_TEST_ONLY(); - VIRTUAL_MACHINE_SETTINGS settings{}; - settings.CpuCount = 4; - settings.DisplayName = L"WSLA"; - settings.MemoryMb = 2048; - settings.BootTimeoutMs = 30 * 1000; - settings.RootVhd = testVhd.c_str(); - - auto session = CreateSession(settings); + auto session = CreateSession(); int processId = 0; // Cache the existing crash dumps so we can check that a new one is created. @@ -1098,7 +982,7 @@ class WSLATests // Dumps files are named with the format: wsl-crash----.dmp // Check if a new file was added in crashDumpsDir matching the pattern and not in existingDumps. - std::string expectedPattern = std::format("wsl-crash-*-{}-_usr_bin_cat-11.dmp", processId); + std::string expectedPattern = std::format("wsl-crash-*-{}-_usr_bin_busybox-11.dmp", processId); auto dumpFile = wsl::shared::retry::RetryWithTimeout( [crashDumpsDir, expectedPattern, existingDumps]() { @@ -1133,14 +1017,7 @@ class WSLATests { WSL2_TEST_ONLY(); - VIRTUAL_MACHINE_SETTINGS settings{}; - settings.CpuCount = 4; - settings.DisplayName = L"WSLA"; - settings.MemoryMb = 2048; - settings.BootTimeoutMs = 30 * 1000; - settings.RootVhd = testVhd.c_str(); - - auto session = CreateSession(settings); + auto session = CreateSession(); constexpr auto formatedVhd = L"test-format-vhd.vhdx"; @@ -1173,45 +1050,28 @@ class WSLATests #else - auto storageVhd = std::filesystem::current_path() / "storage.vhdx"; + auto storagePath = std::filesystem::current_path() / "test-storage"; - // Create a 1G temporary VHD. - if (!std::filesystem::exists(storageVhd)) - { - wsl::core::filesystem::CreateVhd(storageVhd.native().c_str(), 1024 * 1024 * 1024, nullptr, true, false); - } + auto cleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { + std::error_code error; - auto cleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { LOG_IF_WIN32_BOOL_FALSE(DeleteFileW(storageVhd.c_str())); }); - - VIRTUAL_MACHINE_SETTINGS settings{}; - settings.CpuCount = 4; - settings.DisplayName = L"WSLA"; - settings.MemoryMb = 2048; - settings.BootTimeoutMs = 30 * 1000; + std::filesystem::remove_all(storagePath, error); + if (error) + { + LogError("Failed to cleanup storage path %ws: %s", storagePath.c_str(), error.message().c_str()); + } + }); auto installedVhdPath = std::filesystem::path(wsl::windows::common::wslutil::GetMsiPackagePath().value()) / L"wslarootfs.vhd"; -#ifdef WSL_DEV_INSTALL_PATH - - settings.RootVhd = TEXT(WSLA_TEST_DISTRO_PATH); - -#else - - settings.RootVhd = installedVhdPath.c_str(); - -#endif - - settings.RootVhdType = "squashfs"; + auto settings = GetDefaultSessionSettings(); settings.NetworkingMode = WSLANetworkingModeNAT; - settings.ContainerRootVhd = storageVhd.c_str(); - settings.FormatContainerRootVhd = true; + settings.StoragePath = storagePath.c_str(); + settings.MaximumStorageSizeMb = 1024; auto session = CreateSession(settings); - // TODO: Remove once the proper rootfs VHD is available. - ExpectCommandResult(session.get(), {"/etc/lsw-init.sh"}, 0); - // Test a simple container start. { WSLAContainerLauncher launcher("debian:latest", "test-simple", "echo", {"OK"}); @@ -1223,39 +1083,38 @@ class WSLATests // Validate that env is correctly wired. { - WSLAContainerLauncher launcher("debian:latest", "test-env", "/bin/bash", {"-c", "echo $testenv"}, {{"testenv=testvalue"}}); + WSLAContainerLauncher launcher("debian:latest", "test-env", "/bin/sh", {"-c", "echo $testenv"}, {{"testenv=testvalue"}}); auto container = launcher.Launch(*session); auto process = container.GetInitProcess(); ValidateProcessOutput(process, {{1, "testvalue\n"}}); } - // Validate that starting containers works with the default entrypoint. + // Validate that starting containers works with the default entrypoint and content on stdin { WSLAContainerLauncher launcher( "debian:latest", "test-default-entrypoint", "/bin/cat", {}, {}, ProcessFlags::Stdin | ProcessFlags::Stdout | ProcessFlags::Stderr); // For now, validate that trying to use stdin without a tty returns the appropriate error. - auto result = wil::ResultFromException([&]() { auto container = launcher.Launch(*session); }); + auto container = launcher.Launch(*session); - VERIFY_ARE_EQUAL(result, HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED)); - - // This is hanging. nerdctl run seems to hang with -i is passed outside of a TTY context. - // TODO: Restore the test case once this is fixed. - - /* + // TODO: nerdctl hangs if stdin is closed without writing to it. + // Add test coverage for that usecase once the hang is fixed. auto process = container.GetInitProcess(); + auto input = process.GetStdHandle(0); + + std::string shellInput = "foo"; + std::vector inputBuffer{shellInput.begin(), shellInput.end()}; + + std::unique_ptr writeStdin(new WriteHandle(std::move(input), inputBuffer)); - std::string shellInput = "echo $SHELL\n exit"; - std::unique_ptr writeStdin( - new WriteHandle(process.GetStdHandle(0), {shellInput.begin(), shellInput.end()})); std::vector> extraHandles; extraHandles.emplace_back(std::move(writeStdin)); auto result = process.WaitAndCaptureOutput(INFINITE, std::move(extraHandles)); - VERIFY_ARE_EQUAL(result.Output[1], "bash\n"); - */ + VERIFY_ARE_EQUAL(result.Output[2], ""); + VERIFY_ARE_EQUAL(result.Output[1], "foo"); } // Validate that stdin is empty if ProcessFlags::Stdin is not passed.