diff --git a/src/windows/wslaservice/exe/CMakeLists.txt b/src/windows/wslaservice/exe/CMakeLists.txt index 027bc50..0617573 100644 --- a/src/windows/wslaservice/exe/CMakeLists.txt +++ b/src/windows/wslaservice/exe/CMakeLists.txt @@ -2,7 +2,7 @@ set(SOURCES application.manifest main.rc ServiceMain.cpp - ServiceProcessLauncher.h + ServiceProcessLauncher.cpp WSLAContainer.cpp WSLAProcess.cpp WSLASession.cpp @@ -12,7 +12,8 @@ set(SOURCES ) set(HEADERS - ServiceProcessLauncher.cpp + ServiceProcessLauncher.h + WeakRefContainer.h WSLAContainer.h WSLAProcess.h WSLASession.h diff --git a/src/windows/wslaservice/exe/WSLAContainer.cpp b/src/windows/wslaservice/exe/WSLAContainer.cpp index 5c241c8..ebad94d 100644 --- a/src/windows/wslaservice/exe/WSLAContainer.cpp +++ b/src/windows/wslaservice/exe/WSLAContainer.cpp @@ -16,15 +16,37 @@ Abstract: #include "WSLAContainer.h" #include "WSLAProcess.h" +using wsl::windows::service::wsla::WeakReference; using wsl::windows::service::wsla::WSLAContainer; constexpr const char* nerdctlPath = "/usr/bin/nerdctl"; // Constants for required default arguments for "nerdctl run..." -static std::vector defaultNerdctlRunArgs{//"--pull=never", // TODO: Uncomment once PullImage() is implemented. - "--net=host", // TODO: default for now, change later - "--ulimit", - "nofile=65536:65536"}; +static std::vector defaultNerdctlRunArgs{ + //"--pull=never", // TODO: Uncomment once PullImage() is implemented. + "--net=host", // TODO: default for now, change later + "--ulimit", + "nofile=65536:65536"}; + +WSLAContainer::WSLAContainer(WSLAVirtualMachine* parentVM, ServiceRunningProcess&& containerProcess, const char* name, const char* image) : + WeakReference(), m_parentVM(parentVM), m_containerProcess(std::move(containerProcess)), m_name(name), m_image(image) +{ +} + +WSLAContainer::~WSLAContainer() +{ + OnDestroy(); +} + +void WSLAContainer::GetName(char Name[WSLA_MAX_CONTAINER_NAME_LENGTH + 1]) const noexcept +{ + WI_VERIFY(strcpy_s(Name, sizeof(Name), m_name.c_str()) == 0); +} + +void WSLAContainer::GetImage(char Image[WSLA_MAX_IMAGE_NAME_LENGTH + 1]) const noexcept +{ + WI_VERIFY(strcpy_s(Image, WSLA_MAX_IMAGE_NAME_LENGTH + 1, m_image.c_str()) == 0); +} HRESULT WSLAContainer::Start() { diff --git a/src/windows/wslaservice/exe/WSLAContainer.h b/src/windows/wslaservice/exe/WSLAContainer.h index d84eaac..927cb2a 100644 --- a/src/windows/wslaservice/exe/WSLAContainer.h +++ b/src/windows/wslaservice/exe/WSLAContainer.h @@ -24,13 +24,10 @@ class DECLSPEC_UUID("B1F1C4E3-C225-4CAE-AD8A-34C004DE1AE4") WSLAContainer : public Microsoft::WRL::RuntimeClass, IWSLAContainer, IFastRundown> { public: - WSLAContainer() = default; // TODO - WSLAContainer(WSLAVirtualMachine* parentVM, ServiceRunningProcess&& containerProcess) : - m_parentVM(parentVM), m_containerProcess(std::move(containerProcess)) - { - } - WSLAContainer(const WSLAContainer&) = delete; - WSLAContainer& operator=(const WSLAContainer&) = delete; + WSLAContainer(WSLAVirtualMachine* parentVM, ServiceRunningProcess&& containerProcess, const char* name, const char* image); + ~WSLAContainer(); + + NON_COPYABLE(WSLAContainer); IFACEMETHOD(Start)() override; IFACEMETHOD(Stop)(_In_ int Signal, _In_ ULONG TimeoutMs) override; @@ -39,10 +36,15 @@ public: IFACEMETHOD(GetInitProcess)(_Out_ IWSLAProcess** process) override; IFACEMETHOD(Exec)(_In_ const WSLA_PROCESS_OPTIONS* Options, _Out_ IWSLAProcess** Process, _Out_ int* Errno) override; + void GetName(char Name[WSLA_MAX_IMAGE_NAME_LENGTH + 1]) const noexcept; + void GetImage(char Name[WSLA_MAX_CONTAINER_NAME_LENGTH + 1]) const noexcept; + static Microsoft::WRL::ComPtr Create(const WSLA_CONTAINER_OPTIONS& Options, WSLAVirtualMachine& parentVM); private: ServiceRunningProcess m_containerProcess; + std::string m_name; + std::string m_image; WSLAVirtualMachine* m_parentVM = nullptr; static std::vector PrepareNerdctlRunCommand(const WSLA_CONTAINER_OPTIONS& options, std::vector&& inputOptions); diff --git a/src/windows/wslaservice/exe/WSLASession.cpp b/src/windows/wslaservice/exe/WSLASession.cpp index acf20fc..29bf344 100644 --- a/src/windows/wslaservice/exe/WSLASession.cpp +++ b/src/windows/wslaservice/exe/WSLASession.cpp @@ -220,8 +220,14 @@ try std::lock_guard lock{m_lock}; THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_virtualMachine); + THROW_HR_IF(E_INVALIDARG, strlen(containerOptions->Name) > WSLA_MAX_CONTAINER_NAME_LENGTH); + THROW_HR_IF(E_INVALIDARG, strlen(containerOptions->Image) > WSLA_MAX_IMAGE_NAME_LENGTH); + // TODO: Log entrance into the function. auto container = WSLAContainer::Create(*containerOptions, *m_virtualMachine.Get()); + + m_containers.Add(container.Get()); + THROW_IF_FAILED(container.CopyTo(__uuidof(IWSLAContainer), (void**)Container)); return S_OK; @@ -235,7 +241,16 @@ HRESULT WSLASession::OpenContainer(LPCWSTR Name, IWSLAContainer** Container) HRESULT WSLASession::ListContainers(WSLA_CONTAINER** Images, ULONG* Count) { - return E_NOTIMPL; + auto lockedElements = m_containers.Get(); + + auto output = wil::make_unique_cotaskmem(lockedElements.elements.size()); + size_t index = 0; + for (const auto &e: lockedElements.elements) + { + e->GetImage(output[index].Image); + e->GetName(output[index].Name); + index++; + } } HRESULT WSLASession::GetVirtualMachine(IWSLAVirtualMachine** VirtualMachine) diff --git a/src/windows/wslaservice/exe/WSLASession.h b/src/windows/wslaservice/exe/WSLASession.h index 2573238..f143003 100644 --- a/src/windows/wslaservice/exe/WSLASession.h +++ b/src/windows/wslaservice/exe/WSLASession.h @@ -16,6 +16,8 @@ Abstract: #include "wslaservice.h" #include "WSLAVirtualMachine.h" +#include "WeakRefContainer.h" +#include "WSLAContainer.h" namespace wsl::windows::service::wsla { @@ -62,6 +64,7 @@ private: Microsoft::WRL::ComPtr m_virtualMachine; std::wstring m_displayName; std::filesystem::path m_storageVhdPath; + WeakRefContainer m_containers; std::mutex m_lock; // TODO: Add container tracking here. Could reuse m_lock for that. diff --git a/src/windows/wslaservice/exe/WeakRefContainer.h b/src/windows/wslaservice/exe/WeakRefContainer.h new file mode 100644 index 0000000..73500c9 --- /dev/null +++ b/src/windows/wslaservice/exe/WeakRefContainer.h @@ -0,0 +1,93 @@ +#pragma once + +#include "defs.h" +#include +#include +#include + +namespace wsl::windows::service::wsla { + +template +class WeakRefContainer +{ +public: + struct LockedElements + { + std::lock_guard lock; + std::unordered_set& elements; + }; + + WeakRefContainer(const WeakRefContainer&) = delete; + WeakRefContainer(WeakRefContainer&&) = delete; + + WeakRefContainer& operator=(const WeakRefContainer&); + WeakRefContainer& operator=(WeakRefContainer&&); + + WeakRefContainer() = default; + ~WeakRefContainer() + { + std::lock_guard guard(m_lock); + + for (const auto& e : m_elements) + { + e->SetContainer(nullptr); + } + } + + void Add(T* element) + { + std::lock_guard guard(m_lock); + + element->SetContainer(this); + m_elements.insert(element); + } + + void Remove(T* element) + { + element->SetContainer(nullptr); + + std::lock_guard guard(m_lock); + m_elements.erase(element); + } + + LockedElements Get() + { + return {std::lock_guard(m_lock), m_elements}; + } + +private: + std::unordered_set m_elements; + std::mutex m_lock; +}; + +template +class WeakReference +{ +public: + NON_COPYABLE(WeakReference); + WeakReference() = default; + + void SetContainer(WeakRefContainer* container) noexcept + { + std::lock_guard guard(m_lock); + m_container = container; + } + +protected: + WeakRefContainer* m_container = nullptr; + + void OnDestroy() + { + std::lock_guard guard(m_lock); + if (m_container != nullptr) + { + m_container->Remove(static_cast(this)); + m_container = nullptr; + } + } + +private: + std::mutex m_lock; +}; + +} // namespace wsl::windows::service::wsla \ No newline at end of file diff --git a/src/windows/wslaservice/inc/wslaservice.idl b/src/windows/wslaservice/inc/wslaservice.idl index 782cc27..d016e25 100644 --- a/src/windows/wslaservice/inc/wslaservice.idl +++ b/src/windows/wslaservice/inc/wslaservice.idl @@ -20,6 +20,12 @@ cpp_quote("#ifdef __cplusplus") cpp_quote("class DECLSPEC_UUID(\"a9b7a1b9-0671-405c-95f1-e0612cb4ce8f\") WSLAUserSession;") cpp_quote("#endif") +#define WSLA_MAX_CONTAINER_NAME_LENGTH 255 +#define WSLA_MAX_IMAGE_NAME_LENGTH 255 + +cpp_quote("#define WSLA_MAX_CONTAINER_NAME_LENGTH 255") +cpp_quote("#define WSLA_MAX_IMAGE_NAME_LENGTH 255") + typedef struct _WSLA_VERSION { ULONG Major; @@ -96,7 +102,7 @@ struct WSLA_REGISTRY_AUTHENTICATION_INFORMATION struct WSLA_IMAGE_INFORMATION { - LPWSTR Name; + char Image[WSLA_MAX_IMAGE_NAME_LENGTH + 1]; LPWSTR Hash; ULONGLONG Size; ULONGLONG DownloadTimestamp; @@ -162,8 +168,8 @@ enum WSLA_CONTAINER_STATE struct WSLA_CONTAINER { - LPWSTR Name; - LPWSTR Image; + char Name[WSLA_MAX_CONTAINER_NAME_LENGTH + 1]; + char Image[WSLA_MAX_IMAGE_NAME_LENGTH + 1]; enum WSLA_CONTAINER_STATE State; // TODO: Add creation timestamp and other fields that the command line tool might want to display.