Signal the VM termination event to unblock any pending call when the service stops

This commit is contained in:
Blue 2025-08-13 16:32:40 -07:00
parent bce2ab38c3
commit 204aefd025
8 changed files with 136 additions and 21 deletions

View File

@ -12,16 +12,52 @@ Abstract:
--*/
#include "LSWUserSession.h"
#include "LSWVirtualMachine.h"
using wsl::windows::service::lsw::LSWUserSessionImpl;
wsl::windows::service::lsw::LSWUserSessionImpl::LSWUserSessionImpl(HANDLE Token, wil::unique_tokeninfo_ptr<TOKEN_USER>&& TokenInfo) :
LSWUserSessionImpl::LSWUserSessionImpl(HANDLE Token, wil::unique_tokeninfo_ptr<TOKEN_USER>&& TokenInfo) :
m_tokenInfo(std::move(TokenInfo))
{
}
PSID wsl::windows::service::lsw::LSWUserSessionImpl::GetUserSid() const
LSWUserSessionImpl::~LSWUserSessionImpl()
{
// Manually signal the VM termination events. This prevents being stuck on an API call that holds the VM lock.
{
std::lock_guard lock(m_virtualMachinesLock);
for (auto* e : m_virtualMachines)
{
e->OnSessionTerminating();
}
}
}
void LSWUserSessionImpl::OnVmTerminated(LSWVirtualMachine* machine)
{
std::lock_guard lock(m_virtualMachinesLock);
auto pred = [machine](const auto* e) { return machine == e; };
// Remove any stale VM reference.
m_virtualMachines.erase(std::remove_if(m_virtualMachines.begin(), m_virtualMachines.end(), pred), m_virtualMachines.end());
}
HRESULT LSWUserSessionImpl::CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS* Settings, ILSWVirtualMachine** VirtualMachine)
{
auto vm = wil::MakeOrThrow<LSWVirtualMachine>(*Settings, GetUserSid(), this);
{
std::lock_guard lock(m_virtualMachinesLock);
m_virtualMachines.emplace_back(vm.Get());
}
vm->Start();
THROW_IF_FAILED(vm.CopyTo(__uuidof(ILSWVirtualMachine), (void**)VirtualMachine));
return S_OK;
}
PSID LSWUserSessionImpl::GetUserSid() const
{
return m_tokenInfo->User.Sid;
}
@ -46,11 +82,6 @@ try
auto session = m_session.lock();
RETURN_HR_IF(RPC_E_DISCONNECTED, !session);
auto vm = wil::MakeOrThrow<LSWVirtualMachine>(*Settings, session->GetUserSid());
THROW_IF_FAILED(vm.CopyTo(__uuidof(ILSWVirtualMachine), (void**)VirtualMachine));
vm->Start();
return S_OK;
return session->CreateVirtualMachine(Settings, VirtualMachine);
}
CATCH_RETURN();

View File

@ -12,8 +12,10 @@ Abstract:
--*/
#pragma once
#include "LSWVirtualMachine.h"
namespace wsl::windows::service::lsw {
class LSWUserSessionImpl
{
public:
@ -21,10 +23,19 @@ public:
LSWUserSessionImpl(LSWUserSessionImpl&&) = default;
LSWUserSessionImpl& operator=(LSWUserSessionImpl&&) = default;
~LSWUserSessionImpl();
PSID GetUserSid() const;
HRESULT CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS* Settings, ILSWVirtualMachine** VirtualMachine);
void OnVmTerminated(LSWVirtualMachine* machine);
private:
wil::unique_tokeninfo_ptr<TOKEN_USER> m_tokenInfo;
std::recursive_mutex m_virtualMachinesLock;
std::vector<LSWVirtualMachine*> m_virtualMachines;
};
class DECLSPEC_UUID("a9b7a1b9-0671-405c-95f1-e0612cb4ce8f") LSWUserSession

View File

@ -17,9 +17,14 @@ Abstract:
#include "LSWUserSession.h"
using wsl::windows::service::lsw::LSWUserSessionFactory;
using wsl::windows::service::lsw::LSWUserSessionImpl;
CoCreatableClassWithFactory(LSWUserSession, LSWUserSessionFactory);
static std::mutex g_mutex;
static std::optional<std::vector<std::shared_ptr<LSWUserSessionImpl>>> g_sessions =
std::make_optional<std::vector<std::shared_ptr<LSWUserSessionImpl>>>();
HRESULT LSWUserSessionFactory::CreateInstance(_In_ IUnknown* pUnkOuter, _In_ REFIID riid, _Out_ void** ppCreated)
{
RETURN_HR_IF_NULL(E_POINTER, ppCreated);
@ -40,17 +45,17 @@ HRESULT LSWUserSessionFactory::CreateInstance(_In_ IUnknown* pUnkOuter, _In_ REF
auto tokenInfo = wil::get_token_information<TOKEN_USER>(userToken.get());
static std::mutex mutex;
static std::vector<std::shared_ptr<LSWUserSessionImpl>> sessions;
std::lock_guard lock{g_mutex};
std::lock_guard lock{mutex};
THROW_HR_IF(CO_E_SERVER_STOPPING, !g_sessions.has_value());
auto session = std::find_if(
sessions.begin(), sessions.end(), [&tokenInfo](auto it) { return EqualSid(it->GetUserSid(), &tokenInfo->User.Sid); });
auto session = std::find_if(g_sessions->begin(), g_sessions->end(), [&tokenInfo](auto it) {
return EqualSid(it->GetUserSid(), &tokenInfo->User.Sid);
});
if (session == sessions.end())
if (session == g_sessions->end())
{
session = sessions.insert(sessions.end(), std::make_shared<LSWUserSessionImpl>(userToken.get(), std::move(tokenInfo)));
session = g_sessions->insert(g_sessions->end(), std::make_shared<LSWUserSessionImpl>(userToken.get(), std::move(tokenInfo)));
}
auto comInstance = wil::MakeOrThrow<LSWUserSession>(std::weak_ptr<LSWUserSessionImpl>(*session));
@ -68,4 +73,10 @@ HRESULT LSWUserSessionFactory::CreateInstance(_In_ IUnknown* pUnkOuter, _In_ REF
WSL_LOG("LSWUserSessionFactory", TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE));
return S_OK;
}
void wsl::windows::service::lsw::ClearLswSessionsAndBlockNewInstances()
{
std::lock_guard lock{g_mutex};
g_sessions.reset();
}

View File

@ -23,4 +23,6 @@ public:
STDMETHODIMP CreateInstance(_In_ IUnknown* pUnkOuter, _In_ REFIID riid, _Out_ void** ppCreated) override;
};
void ClearLswSessionsAndBlockNewInstances();
} // namespace wsl::windows::service::lsw

View File

@ -16,14 +16,15 @@ Abstract:
#include "LSWApi.h"
#include "NatNetworking.h"
#include "MirroredNetworking.h"
#include "LSWUserSession.h"
using namespace wsl::windows::common;
using helpers::WindowsBuildNumbers;
using helpers::WindowsVersion;
using wsl::windows::service::lsw::LSWVirtualMachine;
LSWVirtualMachine::LSWVirtualMachine(const VIRTUAL_MACHINE_SETTINGS& Settings, PSID UserSid) :
m_settings(Settings), m_userSid(UserSid)
LSWVirtualMachine::LSWVirtualMachine(const VIRTUAL_MACHINE_SETTINGS& Settings, PSID UserSid, LSWUserSessionImpl* Session) :
m_settings(Settings), m_userSid(UserSid), m_userSession(Session)
{
THROW_IF_FAILED(CoCreateGuid(&m_vmId));
@ -42,14 +43,38 @@ HRESULT LSWVirtualMachine::GetDebugShellPipe(LPWSTR* pipePath)
return S_OK;
}
void LSWVirtualMachine::OnSessionTerminating()
{
m_userSession = nullptr;
std::lock_guard mutex(m_lock);
if (m_vmTerminatingEvent.is_signaled())
{
return;
}
WSL_LOG("LswSignalTerminating", TraceLoggingValue(m_running, "running"));
m_vmTerminatingEvent.SetEvent();
}
LSWVirtualMachine::~LSWVirtualMachine()
{
{
std::lock_guard mutex(m_lock);
if (m_userSession != nullptr)
{
m_userSession->OnVmTerminated(this);
}
}
WSL_LOG("LswTerminateVmStart", TraceLoggingValue(m_running, "running"));
m_vmTerminatingEvent.SetEvent();
m_initChannel.Close();
bool forceTerminate = false;
// Wait up to 5 seconds for the VM to terminate.
if (!m_vmExitEvent.wait(5000))
{
@ -502,7 +527,7 @@ std::tuple<int32_t, int32_t, wsl::shared::SocketChannel> LSWVirtualMachine::Fork
auto socket = wsl::windows::common::hvsocket::Connect(m_vmId, port, m_vmExitEvent.get(), m_settings.BootTimeoutMs);
return std::make_tuple(pid, ptyMaster, wsl::shared::SocketChannel{std::move(socket), std::to_string(pid)});
return std::make_tuple(pid, ptyMaster, wsl::shared::SocketChannel{std::move(socket), std::to_string(pid), m_vmTerminatingEvent.get()});
}
wil::unique_socket LSWVirtualMachine::ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd)

View File

@ -18,15 +18,19 @@ Abstract:
#include "Dmesg.h"
namespace wsl::windows::service::lsw {
class LSWUserSessionImpl;
class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") LSWVirtualMachine
: public Microsoft::WRL::RuntimeClass<Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, ILSWVirtualMachine, IFastRundown>
{
public:
LSWVirtualMachine(const VIRTUAL_MACHINE_SETTINGS& Settings, PSID Sid);
LSWVirtualMachine(const VIRTUAL_MACHINE_SETTINGS& Settings, PSID Sid, LSWUserSessionImpl* UserSession);
~LSWVirtualMachine();
void Start();
void OnSessionTerminating();
IFACEMETHOD(AttachDisk(_In_ PCWSTR Path, _In_ BOOL ReadOnly, _Out_ LPSTR* Device, _Out_ ULONG* Lun)) override;
IFACEMETHOD(Mount(_In_ LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags)) override;
@ -88,5 +92,6 @@ private:
std::map<ULONG, AttachedDisk> m_attachedDisks;
std::mutex m_lock;
std::mutex m_portRelaylock;
LSWUserSessionImpl* m_userSession;
};
} // namespace wsl::windows::service::lsw

View File

@ -241,6 +241,7 @@ void WslService::ServiceStopped()
// Terminate all user sessions.
ClearSessionsAndBlockNewInstances();
wsl::windows::service::lsw::ClearLswSessionsAndBlockNewInstances();
// Disconnect from the LxCore driver.
if (g_lxcoreInitialized)

View File

@ -785,4 +785,33 @@ class LSWTests
VERIFY_SUCCEEDED(WslUnmapPort(vm.get(), &port));
}
TEST_METHOD(StuckVmTermination)
{
WSL2_TEST_ONLY();
VirtualMachineSettings settings{};
settings.CPU.CpuCount = 4;
settings.DisplayName = L"LSW";
settings.Memory.MemoryMb = 2048;
settings.Options.BootTimeoutMs = 30 * 1000;
settings.Networking.Mode = NetworkingModeNAT;
auto vm = CreateVm(&settings);
auto [pid, stdinFd, _, __] = LaunchCommand(vm.get(), {"/bin/cat"});
// Create a 'stuck' thread, waiting for cat to exit
std::thread stuckThread([&]() {
WaitResult result{};
WslWaitForLinuxProcess(vm.get(), pid, INFINITE, &result);
});
// Stop the service
StopWslService();
// Verify that the thread is unstuck
stuckThread.join();
}
};