Save state

This commit is contained in:
Blue 2025-12-05 16:44:57 -08:00
parent 403215aad5
commit cdbe1d2204
11 changed files with 409 additions and 56 deletions

View File

@ -936,7 +936,12 @@ void MultiHandleWait::AddHandle(std::unique_ptr<OverlappedIOHandle>&& handle)
m_handles.emplace_back(std::move(handle));
}
void MultiHandleWait::Run(std::optional<std::chrono::milliseconds> Timeout)
void MultiHandleWait::Cancel()
{
m_cancel = true;
}
bool MultiHandleWait::Run(std::optional<std::chrono::milliseconds> Timeout)
{
std::optional<std::chrono::steady_clock::time_point> deadline;
@ -947,7 +952,7 @@ void MultiHandleWait::Run(std::optional<std::chrono::milliseconds> Timeout)
// Run until all handles are completed.
while (!m_handles.empty())
while (!m_handles.empty() && !m_cancel)
{
// Schedule IO on each handle until all are either pending, or completed.
for (auto i = 0; i < m_handles.size(); i++)
@ -997,6 +1002,8 @@ void MultiHandleWait::Run(std::optional<std::chrono::milliseconds> Timeout)
THROW_LAST_ERROR_MSG("Timeout: %lu, Count: %llu", waitTimeout, waitHandles.size());
}
}
return !m_cancel;
}
IOHandleStatus OverlappedIOHandle::GetState() const
@ -1004,8 +1011,13 @@ IOHandleStatus OverlappedIOHandle::GetState() const
return State;
}
EventHandle::EventHandle(HANDLE Handle, std::function<void()>&& OnSignalled) :
Handle(Handle), OnSignalled(std::move(OnSignalled))
{
}
EventHandle::EventHandle(wil::unique_event&& Handle, std::function<void()>&& OnSignalled) :
Handle(std::move(Handle)), OnSignalled(std::move(OnSignalled))
OwnedHandle(std::move(Handle)), Handle(OwnedHandle.get()), OnSignalled(std::move(OnSignalled))
{
}
@ -1022,7 +1034,7 @@ void EventHandle::Collect()
HANDLE EventHandle::GetHandle() const
{
return Handle.get();
return Handle;
}
ReadHandle::ReadHandle(wil::unique_handle&& MovedHandle, std::function<void(const gsl::span<char>& Buffer)>&& OnRead) :

View File

@ -183,12 +183,14 @@ public:
NON_MOVABLE(EventHandle)
EventHandle(wil::unique_event&& EventHandle, std::function<void()>&& OnSignalled);
EventHandle(HANDLE EventHandle, std::function<void()>&& OnSignalled);
void Schedule() override;
void Collect() override;
HANDLE GetHandle() const override;
private:
wil::unique_event Handle;
wil::unique_event OwnedHandle;
HANDLE Handle;
std::function<void()> OnSignalled;
};
@ -238,10 +240,12 @@ public:
MultiHandleWait() = default;
void AddHandle(std::unique_ptr<OverlappedIOHandle>&& handle);
void Run(std::optional<std::chrono::milliseconds> Timeout);
bool Run(std::optional<std::chrono::milliseconds> Timeout);
void Cancel();
private:
std::vector<std::unique_ptr<OverlappedIOHandle>> m_handles;
bool m_cancel = false;
};
} // namespace wsl::windows::common::relay

View File

@ -1,5 +1,6 @@
set(SOURCES
application.manifest
ContainerEventTracker.cpp
main.rc
ServiceMain.cpp
ServiceProcessLauncher.cpp
@ -12,6 +13,7 @@ set(SOURCES
)
set(HEADERS
ContainerEventTracker.h
ServiceProcessLauncher.h
WSLAContainer.h
WSLAProcess.h

View File

@ -0,0 +1,185 @@
#include "precomp.h"
#include "ContainerEventTracker.h"
#include "WSLAVirtualMachine.h"
#include <nlohmann/json.hpp>
using wsl::windows::service::wsla::ContainerEventTracker;
using wsl::windows::service::wsla::WSLAVirtualMachine;
ContainerEventTracker::ContainerTrackingReference::ContainerTrackingReference(ContainerEventTracker* tracker, size_t id) :
m_tracker(tracker), m_id(id)
{
}
ContainerEventTracker::ContainerTrackingReference::ContainerTrackingReference(ContainerEventTracker::ContainerTrackingReference&& other)
{
(*this) = std::move(other);
}
ContainerEventTracker::ContainerTrackingReference& ContainerEventTracker::ContainerTrackingReference::operator=(ContainerEventTracker::ContainerTrackingReference&& other)
{
m_id = other.m_id;
m_tracker = other.m_tracker;
other.m_tracker = nullptr;
return *this;
}
void ContainerEventTracker::ContainerTrackingReference::Reset()
{
if (m_tracker != nullptr)
{
m_tracker->UnregisterContainerStateUpdates(m_tracker != nullptr);
m_tracker = nullptr;
}
}
ContainerEventTracker::ContainerTrackingReference::~ContainerTrackingReference()
{
Reset();
}
ContainerEventTracker::ContainerEventTracker(WSLAVirtualMachine& virtualMachine)
{
ServiceProcessLauncher launcher{"/usr/bin/nerdctl", {"/usr/bin/nerdctl", "events", "--format", "{{json .}}"}, {}, common::ProcessFlags::Stdout};
// Redirect stderr to /dev/null to avoid pipe deadlocks.
launcher.AddFd({.Fd = 2, .Type = WSLAFdTypeLinuxFileOutput, .Path = "/dev/null"});
auto process = launcher.Launch(virtualMachine);
m_thread = std::thread(std::bind(&ContainerEventTracker::Run, this, std::move(process)));
}
ContainerEventTracker::~ContainerEventTracker()
{
// N.B. No callback should be left when the tracker is destroyed.
WI_ASSERT(m_callbacks.empty());
m_stopEvent.SetEvent();
if (m_thread.joinable())
{
m_thread.join();
}
}
void ContainerEventTracker::OnEvent(const std::string& event)
{
// TODO: log session ID
WSL_LOG("NerdCtlEvent", TraceLoggingValue(event.c_str(), "Data"));
static std::map<std::string, ContainerEvent> events{
{"/tasks/create", ContainerEvent::Create},
{"/tasks/start", ContainerEvent::Start},
{"/tasks/stop", ContainerEvent::Stop},
{"/tasks/exit", ContainerEvent::Exit},
{"/tasks/destroy", ContainerEvent::Destroy}};
auto parsed = nlohmann::json::parse(event);
auto type = parsed.find("Topic");
auto details = parsed.find("Event");
THROW_HR_IF_MSG(E_INVALIDARG, type == parsed.end() || details == parsed.end(), "Failed to parse json: %hs", event.c_str());
auto it = events.find(type->get<std::string>());
if (it == events.end())
{
return; // Event is not tracked, dropped.
}
// The 'Event' field is a json string,
auto innerEventJson = details->get<std::string>();
auto innerEvent = nlohmann::json::parse(innerEventJson);
auto containerId = innerEvent.find("container_id");
THROW_HR_IF_MSG(E_INVALIDARG, containerId == innerEvent.end(), "Failed to parse json: %hs", innerEventJson.c_str());
WSL_LOG(
"ContainerStateChange",
TraceLoggingValue(containerId->get<std::string>().c_str(), "Id"),
TraceLoggingValue((int)it->second, "State"));
std::lock_guard lock{m_lock};
auto containerEntry = m_callbacks.find(containerId->get<std::string>());
if (containerEntry != m_callbacks.end())
{
for (auto& [id, callback] : containerEntry->second)
{
callback(it->second);
}
}
}
void ContainerEventTracker::Run(ServiceRunningProcess& process)
{
std::string pendingBuffer;
wsl::windows::common::relay::MultiHandleWait io;
auto onStdout = [&](const gsl::span<char>& buffer) {
// nerdctl events' output is line based. Call OnEvent() for each completed line.
auto begin = buffer.begin();
auto end = std::ranges::find(buffer, '\n');
while (end != buffer.end())
{
pendingBuffer.insert(pendingBuffer.end(), begin, end);
if (!pendingBuffer.empty()) // nerdctl inserts empty lines between events, skip those.
{
OnEvent(pendingBuffer);
}
pendingBuffer.clear();
begin = end + 1;
end = std::ranges::find(begin, buffer.end(), '\n');
}
pendingBuffer.insert(pendingBuffer.end(), begin, end);
};
auto onStop = [&]() { io.Cancel(); };
io.AddHandle(std::make_unique<common::relay::ReadHandle>(process.GetStdHandle(1), std::move(onStdout)));
io.AddHandle(std::make_unique<common::relay::EventHandle>(m_stopEvent.get(), std::move(onStop)));
if (io.Run({}))
{
// TODO: Report error to session.
WSL_LOG("Unexpected nerdctl exit");
}
}
ContainerEventTracker::ContainerTrackingReference ContainerEventTracker::RegisterContainerStateUpdates(
const std::string& ContainerId, ContainerStateChangeCallback&& Callback)
{
std::lock_guard lock{m_lock};
auto id = callbackId++;
m_callbacks[ContainerId][id] = std::move(Callback);
return ContainerTrackingReference{this, id};
}
void ContainerEventTracker::UnregisterContainerStateUpdates(size_t Id)
{
std::lock_guard lock{m_lock};
for (auto& [containerId, callbacks] : m_callbacks)
{
auto it = callbacks.find(Id);
if (it != callbacks.end())
{
callbacks.erase(it);
if (callbacks.empty())
{
m_callbacks.erase(containerId);
}
return;
}
}
WI_ASSERT(false);
}

View File

@ -0,0 +1,59 @@
#pragma once
#include "ServiceProcessLauncher.h"
namespace wsl::windows::service::wsla {
class WSLAVirtualMachine;
enum class ContainerEvent
{
Create,
Start,
Stop,
Exit,
Destroy
};
class ContainerEventTracker
{
public:
NON_COPYABLE(ContainerEventTracker);
NON_MOVABLE(ContainerEventTracker);
struct ContainerTrackingReference
{
NON_COPYABLE(ContainerTrackingReference);
ContainerTrackingReference() = default;
ContainerTrackingReference(ContainerEventTracker* tracker, size_t id);
ContainerTrackingReference(ContainerTrackingReference&&);
~ContainerTrackingReference();
ContainerTrackingReference& operator=(ContainerTrackingReference&&);
void Reset();
size_t m_id;
ContainerEventTracker* m_tracker = nullptr;
};
using ContainerStateChangeCallback = std::function<void(ContainerEvent)>;
ContainerEventTracker(WSLAVirtualMachine& virtualMachine);
~ContainerEventTracker();
void OnEvent(const std::string& event);
ContainerTrackingReference RegisterContainerStateUpdates(const std::string& ContainerId, ContainerStateChangeCallback&& Callback);
void UnregisterContainerStateUpdates(size_t Id);
private:
void Run(ServiceRunningProcess& process);
std::map<std::string, std::map<size_t, ContainerStateChangeCallback>> m_callbacks;
std::thread m_thread;
wil::unique_event m_stopEvent{wil::EventOptions::ManualReset};
std::mutex m_lock;
std::atomic<size_t> callbackId;
};
} // namespace wsl::windows::service::wsla

View File

@ -26,36 +26,18 @@ static std::vector<std::string> defaultNerdctlRunArgs{//"--pull=never", // TODO:
"--ulimit",
"nofile=65536:65536"};
WSLAContainer::WSLAContainer(WSLAVirtualMachine* parentVM, ServiceRunningProcess&& containerProcess, const char* name, const char* image) :
m_parentVM(parentVM), m_containerProcess(std::move(containerProcess)), m_name(name), m_image(image)
WSLAContainer::WSLAContainer(WSLAVirtualMachine* parentVM, const WSLA_CONTAINER_OPTIONS& Options, std::string&& Id, ContainerEventTracker& tracker) :
m_parentVM(parentVM), m_name(Options.Name), m_image(Options.Image), m_id(std::move(Id))
{
m_state = WslaContainerStateCreated;
// TODO: Find a better way to wait for the container to be fully started.
auto status = GetNerdctlStatus();
while (status != "running")
{
if (status == "exited" || m_containerProcess.State() != WslaProcessStateRunning)
{
m_state = WslaContainerStateExited;
return;
}
m_trackingReference =
tracker.RegisterContainerStateUpdates(m_id, std::bind(&WSLAContainer::OnStateChange, this, std::placeholders::_1));
}
// TODO: empty string is returned while the container image is still downloading.
// Remove this logic once the image pull is separated from container creation.
if (status.has_value() && status != "created")
{
THROW_HR_MSG(
E_UNEXPECTED, "Unexpected nerdctl status '%hs', for container '%hs'", status.value_or("<empty>").c_str(), m_name.c_str());
}
std::this_thread::sleep_for(std::chrono::milliseconds(100));
status = GetNerdctlStatus();
}
// TODO: move to start() once create() and start() are split to different methods.
m_state = WslaContainerStateRunning;
WSLAContainer::~WSLAContainer()
{
m_trackingReference.Reset();
}
const std::string& WSLAContainer::Image() const noexcept
@ -63,9 +45,89 @@ const std::string& WSLAContainer::Image() const noexcept
return m_image;
}
HRESULT WSLAContainer::Start()
void WSLAContainer::Start(const WSLA_CONTAINER_OPTIONS& Options)
{
return E_NOTIMPL;
std::lock_guard lock{m_lock};
THROW_HR_IF_MSG(
HRESULT_FROM_WIN32(ERROR_INVALID_STATE),
m_state != WslaContainerStateCreated,
"Cannot start container '%hs', state: %i",
m_name.c_str(),
m_state);
ServiceProcessLauncher launcher(nerdctlPath, {nerdctlPath, "start", "-a", m_id}, {}, common::ProcessFlags::None);
for (auto i = 0; i < Options.InitProcessOptions.FdsCount; i++)
{
launcher.AddFd(Options.InitProcessOptions.Fds[i]);
}
m_containerProcess = launcher.Launch(*m_parentVM);
auto newState = WaitForTransition(WslaContainerStateRunning);
THROW_HR_IF_MSG(E_FAIL, newState != WslaContainerStateRunning, "Failed to start container '%hs', state: %i", m_name.c_str(), newState);
}
WSLA_CONTAINER_STATE WSLAContainer::WaitForTransition(WSLA_CONTAINER_STATE expectedState)
{
wil::shared_event transitionEvent;
{
std::lock_guard lock{m_stateLock};
if (m_state == expectedState)
{
return m_state; // Already in expected state, return immediately
}
transitionEvent = m_stateChangeEvent;
}
m_stateChangeEvent.wait();
std::lock_guard lock{m_stateLock};
return m_state;
}
void WSLAContainer::OnStateChange(ContainerEvent event)
{
std::lock_guard lock{m_stateLock};
WSLA_CONTAINER_STATE newState{};
switch (event)
{
case ContainerEvent::Start:
newState = WslaContainerStateRunning;
break;
case ContainerEvent::Stop:
case ContainerEvent::Exit:
newState = WslaContainerStateExited;
break;
case ContainerEvent::Destroy:
newState = WslaContainerStateDeleted;
break;
default:
WI_ASSERT(false);
}
if (newState == m_state)
{
return;
}
WSL_LOG(
"ContainerStateChange",
TraceLoggingValue(m_name.c_str(), "Name"),
TraceLoggingValue(m_id.c_str(), "Id"),
TraceLoggingValue((int)m_state, "OldState"),
TraceLoggingValue((int)m_state, "NewState"));
m_state = newState;
m_stateChangeEvent.SetEvent();
// Reset the event for the next state change.
m_stateChangeEvent = wil::shared_event{wil::EventOptions::None};
}
HRESULT WSLAContainer::Stop(int Signal, ULONG TimeoutMs)
@ -97,28 +159,27 @@ CATCH_RETURN();
WSLA_CONTAINER_STATE WSLAContainer::State() noexcept
{
std::lock_guard lock{m_lock};
// If the container is running, refresh the init process state before returning.
if (m_state == WslaContainerStateRunning && m_containerProcess.State() != WSLAProcessStateRunning)
{
m_state = WslaContainerStateExited;
}
std::lock_guard lock{m_stateLock};
return m_state;
}
HRESULT WSLAContainer::GetState(WSLA_CONTAINER_STATE* Result)
try
{
*Result = State();
return S_OK;
}
CATCH_RETURN();
HRESULT WSLAContainer::GetInitProcess(IWSLAProcess** Process)
try
{
return m_containerProcess.Get().QueryInterface(__uuidof(IWSLAProcess), (void**)Process);
std::lock_guard lock{m_lock};
THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_containerProcess.has_value());
return m_containerProcess->Get().QueryInterface(__uuidof(IWSLAProcess), (void**)Process);
}
CATCH_RETURN();
@ -133,7 +194,8 @@ try
}
CATCH_RETURN();
Microsoft::WRL::ComPtr<WSLAContainer> WSLAContainer::Create(const WSLA_CONTAINER_OPTIONS& containerOptions, WSLAVirtualMachine& parentVM)
Microsoft::WRL::ComPtr<WSLAContainer> WSLAContainer::Create(
const WSLA_CONTAINER_OPTIONS& containerOptions, WSLAVirtualMachine& parentVM, ContainerEventTracker& eventTracker)
{
// TODO: Switch to nerdctl create, and call nerdctl start in Start().
@ -166,19 +228,25 @@ Microsoft::WRL::ComPtr<WSLAContainer> WSLAContainer::Create(const WSLA_CONTAINER
auto args = PrepareNerdctlRunCommand(containerOptions, std::move(inputOptions));
ServiceProcessLauncher launcher(nerdctlPath, args, {}, common::ProcessFlags::None);
for (size_t i = 0; i < containerOptions.InitProcessOptions.FdsCount; i++)
ServiceProcessLauncher launcher(nerdctlPath, args, {});
auto result = launcher.Launch(parentVM).WaitAndCaptureOutput();
// TODO: Have better error codes.
THROW_HR_IF_MSG(E_FAIL, result.Code != 0, "Failed to create container: %hs", launcher.FormatResult(result).c_str());
auto id = result.Output[1];
while (!id.empty() && (id.back() == '\n'))
{
launcher.AddFd(containerOptions.InitProcessOptions.Fds[i]);
id.pop_back();
}
return wil::MakeOrThrow<WSLAContainer>(&parentVM, launcher.Launch(parentVM), containerOptions.Name, containerOptions.Image);
return wil::MakeOrThrow<WSLAContainer>(&parentVM, containerOptions, std::move(id), eventTracker);
}
std::vector<std::string> WSLAContainer::PrepareNerdctlRunCommand(const WSLA_CONTAINER_OPTIONS& options, std::vector<std::string>&& inputOptions)
{
std::vector<std::string> args{nerdctlPath};
args.push_back("run");
args.push_back("create");
args.push_back("--name");
args.push_back(options.Name);
if (options.ShmSize > 0)

View File

@ -17,6 +17,7 @@ Abstract:
#include "ServiceProcessLauncher.h"
#include "wslaservice.h"
#include "WSLAVirtualMachine.h"
#include "ContainerEventTracker.h"
namespace wsl::windows::service::wsla {
@ -26,9 +27,11 @@ class DECLSPEC_UUID("B1F1C4E3-C225-4CAE-AD8A-34C004DE1AE4") WSLAContainer
public:
NON_COPYABLE(WSLAContainer);
WSLAContainer(WSLAVirtualMachine* parentVM, ServiceRunningProcess&& containerProcess, const char* name, const char* image);
WSLAContainer(WSLAVirtualMachine* parentVM, const WSLA_CONTAINER_OPTIONS& Options, std::string&& Id, ContainerEventTracker& tracker);
~WSLAContainer();
void Start(const WSLA_CONTAINER_OPTIONS& Options);
IFACEMETHOD(Start)() override;
IFACEMETHOD(Stop)(_In_ int Signal, _In_ ULONG TimeoutMs) override;
IFACEMETHOD(Delete)() override;
IFACEMETHOD(GetState)(_Out_ WSLA_CONTAINER_STATE* State) override;
@ -38,17 +41,24 @@ public:
const std::string& Image() const noexcept;
WSLA_CONTAINER_STATE State() noexcept;
static Microsoft::WRL::ComPtr<WSLAContainer> Create(const WSLA_CONTAINER_OPTIONS& Options, WSLAVirtualMachine& parentVM);
static Microsoft::WRL::ComPtr<WSLAContainer> Create(const WSLA_CONTAINER_OPTIONS& Options, WSLAVirtualMachine& parentVM, ContainerEventTracker& tracker);
private:
void OnStateChange(ContainerEvent event);
WSLA_CONTAINER_STATE WaitForTransition(WSLA_CONTAINER_STATE expectedState);
std::optional<std::string> GetNerdctlStatus();
ServiceRunningProcess m_containerProcess;
std::mutex m_lock;
std::mutex m_stateLock;
wil::shared_event m_stateChangeEvent{wil::EventOptions::None};
std::optional<ServiceRunningProcess> m_containerProcess;
std::string m_name;
std::string m_image;
std::string m_id;
WSLA_CONTAINER_STATE m_state = WslaContainerStateInvalid;
WSLAVirtualMachine* m_parentVM = nullptr;
std::mutex m_lock;
ContainerEventTracker::ContainerTrackingReference m_trackingReference;
static std::vector<std::string> PrepareNerdctlRunCommand(const WSLA_CONTAINER_OPTIONS& options, std::vector<std::string>&& inputOptions);
};

View File

@ -56,6 +56,9 @@ WSLASession::WSLASession(const WSLA_SESSION_SETTINGS& Settings, WSLAUserSessionI
throw;
}
}
// Start the event tracker.
m_eventTracker.emplace(*m_virtualMachine.Get());
}
WSLAVirtualMachine::Settings WSLASession::CreateVmSettings(const WSLA_SESSION_SETTINGS& Settings)
@ -104,6 +107,11 @@ WSLASession::~WSLASession()
std::lock_guard lock{m_lock};
if (m_eventTracker.has_value())
{
m_eventTracker.reset();
}
if (m_virtualMachine)
{
m_virtualMachine->OnSessionTerminated();
@ -231,11 +239,14 @@ try
RETURN_HR_IF(E_INVALIDARG, strlen(containerOptions->Image) > WSLA_MAX_IMAGE_NAME_LENGTH);
// TODO: Log entrance into the function.
auto container = WSLAContainer::Create(*containerOptions, *m_virtualMachine.Get());
auto container = WSLAContainer::Create(*containerOptions, *m_virtualMachine.Get(), *m_eventTracker);
RETURN_IF_FAILED(container.CopyTo(__uuidof(IWSLAContainer), (void**)Container));
m_containers.emplace(containerOptions->Name, std::move(container));
auto [newElement, inserted] = m_containers.emplace(containerOptions->Name, std::move(container));
WI_ASSERT(inserted);
newElement->second->Start(*containerOptions);
return S_OK;
}

View File

@ -17,6 +17,7 @@ Abstract:
#include "wslaservice.h"
#include "WSLAVirtualMachine.h"
#include "WSLAContainer.h"
#include "ContainerEventTracker.h"
namespace wsl::windows::service::wsla {
@ -62,6 +63,7 @@ private:
WSLA_SESSION_SETTINGS m_sessionSettings; // TODO: Revisit to see if we should have session settings as a member or not
WSLAUserSessionImpl* m_userSession = nullptr;
Microsoft::WRL::ComPtr<WSLAVirtualMachine> m_virtualMachine;
std::optional<ContainerEventTracker> m_eventTracker;
std::wstring m_displayName;
std::filesystem::path m_storageVhdPath;
std::map<std::string, Microsoft::WRL::ComPtr<WSLAContainer>> m_containers;

View File

@ -20,6 +20,7 @@ Abstract:
#include "GuestDeviceManager.h"
#include "WSLAApi.h"
#include "WSLAProcess.h"
#include "ContainerEventTracker.h"
namespace wsl::windows::service::wsla {

View File

@ -258,7 +258,6 @@ struct WSLA_SESSION_SETTINGS {
]
interface IWSLAContainer : IUnknown
{
HRESULT Start();
HRESULT Stop([in] int Signal, [in] ULONG TimeoutMs);
HRESULT Delete(); // TODO: Look into lifetime logic.
HRESULT GetState([out] enum WSLA_CONTAINER_STATE* State);