From 204aefd025b0fb6813d548b1b4dd891eadee63ca Mon Sep 17 00:00:00 2001 From: Blue Date: Wed, 13 Aug 2025 16:32:40 -0700 Subject: [PATCH] Signal the VM termination event to unblock any pending call when the service stops --- src/windows/service/exe/LSWUserSession.cpp | 49 +++++++++++++++---- src/windows/service/exe/LSWUserSession.h | 11 +++++ .../service/exe/LSWUserSessionFactory.cpp | 25 +++++++--- .../service/exe/LSWUserSessionFactory.h | 2 + src/windows/service/exe/LSWVirtualMachine.cpp | 33 +++++++++++-- src/windows/service/exe/LSWVirtualMachine.h | 7 ++- src/windows/service/exe/ServiceMain.cpp | 1 + test/windows/LSWTests.cpp | 29 +++++++++++ 8 files changed, 136 insertions(+), 21 deletions(-) diff --git a/src/windows/service/exe/LSWUserSession.cpp b/src/windows/service/exe/LSWUserSession.cpp index 11efd9d1..98140290 100644 --- a/src/windows/service/exe/LSWUserSession.cpp +++ b/src/windows/service/exe/LSWUserSession.cpp @@ -12,16 +12,52 @@ Abstract: --*/ #include "LSWUserSession.h" -#include "LSWVirtualMachine.h" using wsl::windows::service::lsw::LSWUserSessionImpl; -wsl::windows::service::lsw::LSWUserSessionImpl::LSWUserSessionImpl(HANDLE Token, wil::unique_tokeninfo_ptr&& TokenInfo) : +LSWUserSessionImpl::LSWUserSessionImpl(HANDLE Token, wil::unique_tokeninfo_ptr&& TokenInfo) : m_tokenInfo(std::move(TokenInfo)) { } -PSID wsl::windows::service::lsw::LSWUserSessionImpl::GetUserSid() const +LSWUserSessionImpl::~LSWUserSessionImpl() +{ + // Manually signal the VM termination events. This prevents being stuck on an API call that holds the VM lock. + { + std::lock_guard lock(m_virtualMachinesLock); + + for (auto* e : m_virtualMachines) + { + e->OnSessionTerminating(); + } + } +} + +void LSWUserSessionImpl::OnVmTerminated(LSWVirtualMachine* machine) +{ + std::lock_guard lock(m_virtualMachinesLock); + auto pred = [machine](const auto* e) { return machine == e; }; + + // Remove any stale VM reference. + m_virtualMachines.erase(std::remove_if(m_virtualMachines.begin(), m_virtualMachines.end(), pred), m_virtualMachines.end()); +} + +HRESULT LSWUserSessionImpl::CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS* Settings, ILSWVirtualMachine** VirtualMachine) +{ + auto vm = wil::MakeOrThrow(*Settings, GetUserSid(), this); + + { + std::lock_guard lock(m_virtualMachinesLock); + m_virtualMachines.emplace_back(vm.Get()); + } + + vm->Start(); + THROW_IF_FAILED(vm.CopyTo(__uuidof(ILSWVirtualMachine), (void**)VirtualMachine)); + + return S_OK; +} + +PSID LSWUserSessionImpl::GetUserSid() const { return m_tokenInfo->User.Sid; } @@ -46,11 +82,6 @@ try auto session = m_session.lock(); RETURN_HR_IF(RPC_E_DISCONNECTED, !session); - auto vm = wil::MakeOrThrow(*Settings, session->GetUserSid()); - - THROW_IF_FAILED(vm.CopyTo(__uuidof(ILSWVirtualMachine), (void**)VirtualMachine)); - - vm->Start(); - return S_OK; + return session->CreateVirtualMachine(Settings, VirtualMachine); } CATCH_RETURN(); \ No newline at end of file diff --git a/src/windows/service/exe/LSWUserSession.h b/src/windows/service/exe/LSWUserSession.h index 6c4369bf..f8995710 100644 --- a/src/windows/service/exe/LSWUserSession.h +++ b/src/windows/service/exe/LSWUserSession.h @@ -12,8 +12,10 @@ Abstract: --*/ #pragma once +#include "LSWVirtualMachine.h" namespace wsl::windows::service::lsw { + class LSWUserSessionImpl { public: @@ -21,10 +23,19 @@ public: LSWUserSessionImpl(LSWUserSessionImpl&&) = default; LSWUserSessionImpl& operator=(LSWUserSessionImpl&&) = default; + ~LSWUserSessionImpl(); + PSID GetUserSid() const; + HRESULT CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS* Settings, ILSWVirtualMachine** VirtualMachine); + + void OnVmTerminated(LSWVirtualMachine* machine); + private: wil::unique_tokeninfo_ptr m_tokenInfo; + + std::recursive_mutex m_virtualMachinesLock; + std::vector m_virtualMachines; }; class DECLSPEC_UUID("a9b7a1b9-0671-405c-95f1-e0612cb4ce8f") LSWUserSession diff --git a/src/windows/service/exe/LSWUserSessionFactory.cpp b/src/windows/service/exe/LSWUserSessionFactory.cpp index 83ee5d00..c9d7d477 100644 --- a/src/windows/service/exe/LSWUserSessionFactory.cpp +++ b/src/windows/service/exe/LSWUserSessionFactory.cpp @@ -17,9 +17,14 @@ Abstract: #include "LSWUserSession.h" using wsl::windows::service::lsw::LSWUserSessionFactory; +using wsl::windows::service::lsw::LSWUserSessionImpl; CoCreatableClassWithFactory(LSWUserSession, LSWUserSessionFactory); +static std::mutex g_mutex; +static std::optional>> g_sessions = + std::make_optional>>(); + HRESULT LSWUserSessionFactory::CreateInstance(_In_ IUnknown* pUnkOuter, _In_ REFIID riid, _Out_ void** ppCreated) { RETURN_HR_IF_NULL(E_POINTER, ppCreated); @@ -40,17 +45,17 @@ HRESULT LSWUserSessionFactory::CreateInstance(_In_ IUnknown* pUnkOuter, _In_ REF auto tokenInfo = wil::get_token_information(userToken.get()); - static std::mutex mutex; - static std::vector> sessions; + std::lock_guard lock{g_mutex}; - std::lock_guard lock{mutex}; + THROW_HR_IF(CO_E_SERVER_STOPPING, !g_sessions.has_value()); - auto session = std::find_if( - sessions.begin(), sessions.end(), [&tokenInfo](auto it) { return EqualSid(it->GetUserSid(), &tokenInfo->User.Sid); }); + auto session = std::find_if(g_sessions->begin(), g_sessions->end(), [&tokenInfo](auto it) { + return EqualSid(it->GetUserSid(), &tokenInfo->User.Sid); + }); - if (session == sessions.end()) + if (session == g_sessions->end()) { - session = sessions.insert(sessions.end(), std::make_shared(userToken.get(), std::move(tokenInfo))); + session = g_sessions->insert(g_sessions->end(), std::make_shared(userToken.get(), std::move(tokenInfo))); } auto comInstance = wil::MakeOrThrow(std::weak_ptr(*session)); @@ -68,4 +73,10 @@ HRESULT LSWUserSessionFactory::CreateInstance(_In_ IUnknown* pUnkOuter, _In_ REF WSL_LOG("LSWUserSessionFactory", TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE)); return S_OK; +} + +void wsl::windows::service::lsw::ClearLswSessionsAndBlockNewInstances() +{ + std::lock_guard lock{g_mutex}; + g_sessions.reset(); } \ No newline at end of file diff --git a/src/windows/service/exe/LSWUserSessionFactory.h b/src/windows/service/exe/LSWUserSessionFactory.h index b35ad701..a2fed4eb 100644 --- a/src/windows/service/exe/LSWUserSessionFactory.h +++ b/src/windows/service/exe/LSWUserSessionFactory.h @@ -23,4 +23,6 @@ public: STDMETHODIMP CreateInstance(_In_ IUnknown* pUnkOuter, _In_ REFIID riid, _Out_ void** ppCreated) override; }; + +void ClearLswSessionsAndBlockNewInstances(); } // namespace wsl::windows::service::lsw \ No newline at end of file diff --git a/src/windows/service/exe/LSWVirtualMachine.cpp b/src/windows/service/exe/LSWVirtualMachine.cpp index 33baf0fa..9abe1928 100644 --- a/src/windows/service/exe/LSWVirtualMachine.cpp +++ b/src/windows/service/exe/LSWVirtualMachine.cpp @@ -16,14 +16,15 @@ Abstract: #include "LSWApi.h" #include "NatNetworking.h" #include "MirroredNetworking.h" +#include "LSWUserSession.h" using namespace wsl::windows::common; using helpers::WindowsBuildNumbers; using helpers::WindowsVersion; using wsl::windows::service::lsw::LSWVirtualMachine; -LSWVirtualMachine::LSWVirtualMachine(const VIRTUAL_MACHINE_SETTINGS& Settings, PSID UserSid) : - m_settings(Settings), m_userSid(UserSid) +LSWVirtualMachine::LSWVirtualMachine(const VIRTUAL_MACHINE_SETTINGS& Settings, PSID UserSid, LSWUserSessionImpl* Session) : + m_settings(Settings), m_userSid(UserSid), m_userSession(Session) { THROW_IF_FAILED(CoCreateGuid(&m_vmId)); @@ -42,14 +43,38 @@ HRESULT LSWVirtualMachine::GetDebugShellPipe(LPWSTR* pipePath) return S_OK; } +void LSWVirtualMachine::OnSessionTerminating() +{ + m_userSession = nullptr; + std::lock_guard mutex(m_lock); + + if (m_vmTerminatingEvent.is_signaled()) + { + return; + } + + WSL_LOG("LswSignalTerminating", TraceLoggingValue(m_running, "running")); + + m_vmTerminatingEvent.SetEvent(); +} + LSWVirtualMachine::~LSWVirtualMachine() { + { + std::lock_guard mutex(m_lock); + + if (m_userSession != nullptr) + { + m_userSession->OnVmTerminated(this); + } + } + WSL_LOG("LswTerminateVmStart", TraceLoggingValue(m_running, "running")); - m_vmTerminatingEvent.SetEvent(); m_initChannel.Close(); bool forceTerminate = false; + // Wait up to 5 seconds for the VM to terminate. if (!m_vmExitEvent.wait(5000)) { @@ -502,7 +527,7 @@ std::tuple LSWVirtualMachine::Fork auto socket = wsl::windows::common::hvsocket::Connect(m_vmId, port, m_vmExitEvent.get(), m_settings.BootTimeoutMs); - return std::make_tuple(pid, ptyMaster, wsl::shared::SocketChannel{std::move(socket), std::to_string(pid)}); + return std::make_tuple(pid, ptyMaster, wsl::shared::SocketChannel{std::move(socket), std::to_string(pid), m_vmTerminatingEvent.get()}); } wil::unique_socket LSWVirtualMachine::ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd) diff --git a/src/windows/service/exe/LSWVirtualMachine.h b/src/windows/service/exe/LSWVirtualMachine.h index 78cd4b20..19433b66 100644 --- a/src/windows/service/exe/LSWVirtualMachine.h +++ b/src/windows/service/exe/LSWVirtualMachine.h @@ -18,15 +18,19 @@ Abstract: #include "Dmesg.h" namespace wsl::windows::service::lsw { + +class LSWUserSessionImpl; + class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") LSWVirtualMachine : public Microsoft::WRL::RuntimeClass, ILSWVirtualMachine, IFastRundown> { public: - LSWVirtualMachine(const VIRTUAL_MACHINE_SETTINGS& Settings, PSID Sid); + LSWVirtualMachine(const VIRTUAL_MACHINE_SETTINGS& Settings, PSID Sid, LSWUserSessionImpl* UserSession); ~LSWVirtualMachine(); void Start(); + void OnSessionTerminating(); IFACEMETHOD(AttachDisk(_In_ PCWSTR Path, _In_ BOOL ReadOnly, _Out_ LPSTR* Device, _Out_ ULONG* Lun)) override; IFACEMETHOD(Mount(_In_ LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags)) override; @@ -88,5 +92,6 @@ private: std::map m_attachedDisks; std::mutex m_lock; std::mutex m_portRelaylock; + LSWUserSessionImpl* m_userSession; }; } // namespace wsl::windows::service::lsw \ No newline at end of file diff --git a/src/windows/service/exe/ServiceMain.cpp b/src/windows/service/exe/ServiceMain.cpp index 569cb245..55883fa1 100644 --- a/src/windows/service/exe/ServiceMain.cpp +++ b/src/windows/service/exe/ServiceMain.cpp @@ -241,6 +241,7 @@ void WslService::ServiceStopped() // Terminate all user sessions. ClearSessionsAndBlockNewInstances(); + wsl::windows::service::lsw::ClearLswSessionsAndBlockNewInstances(); // Disconnect from the LxCore driver. if (g_lxcoreInitialized) diff --git a/test/windows/LSWTests.cpp b/test/windows/LSWTests.cpp index 044251d6..2daf7dec 100644 --- a/test/windows/LSWTests.cpp +++ b/test/windows/LSWTests.cpp @@ -785,4 +785,33 @@ class LSWTests VERIFY_SUCCEEDED(WslUnmapPort(vm.get(), &port)); } + + TEST_METHOD(StuckVmTermination) + { + WSL2_TEST_ONLY(); + + VirtualMachineSettings settings{}; + settings.CPU.CpuCount = 4; + settings.DisplayName = L"LSW"; + settings.Memory.MemoryMb = 2048; + settings.Options.BootTimeoutMs = 30 * 1000; + settings.Networking.Mode = NetworkingModeNAT; + + auto vm = CreateVm(&settings); + + auto [pid, stdinFd, _, __] = LaunchCommand(vm.get(), {"/bin/cat"}); + + // Create a 'stuck' thread, waiting for cat to exit + + std::thread stuckThread([&]() { + WaitResult result{}; + WslWaitForLinuxProcess(vm.get(), pid, INFINITE, &result); + }); + + // Stop the service + StopWslService(); + + // Verify that the thread is unstuck + stuckThread.join(); + } }; \ No newline at end of file