diff --git a/src/shared/inc/CommandLine.h b/src/shared/inc/CommandLine.h index da4e2056..cde97e62 100644 --- a/src/shared/inc/CommandLine.h +++ b/src/shared/inc/CommandLine.h @@ -308,7 +308,7 @@ public: #ifdef WIN32 ArgumentParser(const std::wstring& CommandLine, LPCWSTR Name, int StartIndex = 1, bool ignoreUnknownArgs = false) : - m_startIndex(StartIndex), m_name(Name), m_ignoreUnknownArgs(ignoreUnknownArgs) + m_parseIndex(StartIndex), m_name(Name), m_ignoreUnknownArgs(ignoreUnknownArgs) { m_argv.reset(CommandLineToArgvW(std::wstring(CommandLine).c_str(), &m_argc)); THROW_LAST_ERROR_IF(!m_argv); @@ -317,7 +317,7 @@ public: #else ArgumentParser(int argc, const char* const* argv, bool ignoreUnknownArgs = false) : - m_argc(argc), m_argv(argv), m_startIndex(1), m_ignoreUnknownArgs(ignoreUnknownArgs) + m_argc(argc), m_argv(argv), m_parseIndex(1), m_ignoreUnknownArgs(ignoreUnknownArgs) { } @@ -354,13 +354,13 @@ public: m_arguments.emplace_back(std::move(match), BuildParseMethod(std::forward(Output)), true); } - void Parse() const + void Parse() { int argumentPosition = 0; bool stopParameters = false; - for (size_t i = m_startIndex; i < m_argc; i++) + for (; m_parseIndex < m_argc; m_parseIndex++) { - if (!stopParameters && wsl::shared::string::IsEqual(m_argv[i], TEXT("--"))) + if (!stopParameters && wsl::shared::string::IsEqual(m_argv[m_parseIndex], TEXT("--"))) { stopParameters = true; continue; @@ -370,9 +370,10 @@ public: int offset = 0; // Special case for short argument with multiple values like -abc - if (!stopParameters && m_argv[i][0] == '-' && m_argv[i][1] != '-' && m_argv[i][1] != '\0' && m_argv[i][2] != '\0') + if (!stopParameters && m_argv[m_parseIndex][0] == '-' && m_argv[m_parseIndex][1] != '-' && + m_argv[m_parseIndex][1] != '\0' && m_argv[m_parseIndex][2] != '\0') { - for (const auto* arg = &m_argv[i][1]; *arg != '\0'; arg++) + for (const auto* arg = &m_argv[m_parseIndex][1]; *arg != '\0'; arg++) { foundMatch = false; for (const auto& e : m_arguments) @@ -397,23 +398,26 @@ public: { for (const auto& e : m_arguments) { - if (e.Matches(stopParameters ? nullptr : m_argv[i], m_argv[i][0] == '-' && m_argv[i][1] != '\0' && !stopParameters ? -1 : argumentPosition)) + if (e.Matches( + stopParameters ? nullptr : m_argv[m_parseIndex], + m_argv[m_parseIndex][0] == '-' && m_argv[m_parseIndex][1] != '\0' && !stopParameters ? -1 : argumentPosition)) { const TChar* value = nullptr; if (e.Positional) { - value = m_argv[i]; // Positional arguments directly receive argv[i] + value = m_argv[m_parseIndex]; // Positional arguments directly receive argv[i] } - else if (i + 1 < m_argc) + else if (m_parseIndex + 1 < m_argc) { - value = m_argv[i + 1]; + value = m_argv[m_parseIndex + 1]; } offset = e.Consume(value); if (offset < 0) { WI_ASSERT(value == nullptr); - THROW_USER_ERROR(wsl::shared::Localization::MessageMissingArgument(m_argv[i], m_name ? m_name : m_argv[0])); + THROW_USER_ERROR( + wsl::shared::Localization::MessageMissingArgument(m_argv[m_parseIndex], m_name ? m_name : m_argv[0])); } if (e.Positional) // Positional arguments can't consume extra arguments. @@ -421,7 +425,7 @@ public: offset = 0; } - i += offset; + m_parseIndex += offset; foundMatch = true; break; @@ -436,16 +440,32 @@ public: break; } - THROW_USER_ERROR(wsl::shared::Localization::MessageInvalidCommandLine(m_argv[i], m_name ? m_name : m_argv[0])); + THROW_USER_ERROR(wsl::shared::Localization::MessageInvalidCommandLine(m_argv[m_parseIndex], m_name ? m_name : m_argv[0])); } - if (i < m_argc && m_argv[i - offset][0] != '-') + if (m_parseIndex < m_argc && m_argv[m_parseIndex - offset][0] != '-') { argumentPosition++; } } } + size_t ParseIndex() const noexcept + { + return m_parseIndex; + } + + size_t Argc() const noexcept + { + return m_argc; + } + + const auto* Argv(size_t Index) const noexcept + { + WI_ASSERT(Index < static_cast(m_argc)); + return m_argv[Index]; + } + private: template static std::function BuildParseMethod(T&& Output) @@ -540,7 +560,7 @@ private: #endif - int m_startIndex{}; + int m_parseIndex{}; const TChar* m_name{}; bool m_ignoreUnknownArgs{false}; }; diff --git a/src/windows/common/ExecutionContext.cpp b/src/windows/common/ExecutionContext.cpp index c8680522..ef6ac366 100644 --- a/src/windows/common/ExecutionContext.cpp +++ b/src/windows/common/ExecutionContext.cpp @@ -367,6 +367,11 @@ LXSS_ERROR_INFO* ClientExecutionContext::OutError() noexcept return &m_outError; } +void wsl::windows::common::SetErrorMessage(std::string&& message) +{ + return SetErrorMessage(wsl::shared::string::MultiByteToWide(message)); +} + void wsl::windows::common::SetErrorMessage(std::wstring&& message) { if (g_currentContext == nullptr || message.empty()) diff --git a/src/windows/common/ExecutionContext.h b/src/windows/common/ExecutionContext.h index 9c12ff64..75a936ae 100644 --- a/src/windows/common/ExecutionContext.h +++ b/src/windows/common/ExecutionContext.h @@ -182,6 +182,7 @@ private: void EnableContextualizedErrors(bool service); void SetErrorMessage(std::wstring&& message); +void SetErrorMessage(std::string&& message); void SetEventLog(HANDLE eventLog); diff --git a/src/windows/common/WSLAContainerLauncher.cpp b/src/windows/common/WSLAContainerLauncher.cpp index c919895b..97d4ceb6 100644 --- a/src/windows/common/WSLAContainerLauncher.cpp +++ b/src/windows/common/WSLAContainerLauncher.cpp @@ -112,7 +112,7 @@ std::pair> WSLAContainerLauncher::C // TODO: Support volumes, ports, flags, shm size, container networking mode, etc. wil::com_ptr container; - auto result = Session.CreateContainer(&options, &container); + auto result = Session.CreateContainer(&options, &container, nullptr); if (FAILED(result)) { return std::pair>(result, std::optional{}); diff --git a/src/windows/common/WslClient.cpp b/src/windows/common/WslClient.cpp index 7f713854..12460543 100644 --- a/src/windows/common/WslClient.cpp +++ b/src/windows/common/WslClient.cpp @@ -112,17 +112,9 @@ struct ShellExecOptions } }; -bool IsInteractiveConsole() -{ - const HANDLE stdinHandle = GetStdHandle(STD_INPUT_HANDLE); - DWORD mode{}; - - return GetFileType(stdinHandle) == FILE_TYPE_CHAR && GetConsoleMode(stdinHandle, &mode); -} - void PromptForKeyPress() { - if (IsInteractiveConsole()) + if (wsl::windows::common::wslutil::IsInteractiveConsole()) { wsl::windows::common::wslutil::PrintMessage(wsl::shared::Localization::MessagePressAnyKeyToExit()); LOG_IF_WIN32_BOOL_FALSE(FlushConsoleInputBuffer(GetStdHandle(STD_INPUT_HANDLE))); @@ -1649,7 +1641,7 @@ int WslaShell(_In_ std::wstring_view commandLine) } else { - THROW_IF_FAILED(session->PullImage(containerImage.c_str(), nullptr, nullptr)); + THROW_IF_FAILED(session->PullImage(containerImage.c_str(), nullptr, nullptr, nullptr)); std::vector fds; @@ -1692,7 +1684,7 @@ int WslaShell(_In_ std::wstring_view commandLine) } container.emplace(); - THROW_IF_FAILED(session->CreateContainer(&containerOptions, &container.value())); + THROW_IF_FAILED(session->CreateContainer(&containerOptions, &container.value(), nullptr)); THROW_IF_FAILED((*container)->Start()); wil::com_ptr createdProcess; @@ -1729,7 +1721,7 @@ int WslaShell(_In_ std::wstring_view commandLine) }); // Required because ReadFile() blocks if stdin is a tty. - if (IsInteractiveConsole()) + if (wsl::windows::common::wslutil::IsInteractiveConsole()) { inputThread = std::thread{[&]() { wsl::windows::common::relay::StandardInputRelay(Stdin, process->GetStdHandle(0).get(), []() {}, exitEvent.get()); diff --git a/src/windows/common/relay.cpp b/src/windows/common/relay.cpp index e805459e..9755a52a 100644 --- a/src/windows/common/relay.cpp +++ b/src/windows/common/relay.cpp @@ -1016,12 +1016,8 @@ IOHandleStatus OverlappedIOHandle::GetState() const return State; } -EventHandle::EventHandle(HANDLE Handle, std::function&& OnSignalled) : Handle(Handle), OnSignalled(std::move(OnSignalled)) -{ -} - -EventHandle::EventHandle(wil::unique_event&& Handle, std::function&& OnSignalled) : - OwnedHandle(std::move(Handle)), Handle(OwnedHandle.get()), OnSignalled(std::move(OnSignalled)) +EventHandle::EventHandle(HandleWrapper&& Handle, std::function&& OnSignalled) : + Handle(std::move(Handle)), OnSignalled(std::move(OnSignalled)) { } @@ -1038,7 +1034,7 @@ void EventHandle::Collect() HANDLE EventHandle::GetHandle() const { - return Handle; + return Handle.Get(); } ReadHandle::ReadHandle(HandleWrapper&& MovedHandle, std::function& Buffer)>&& OnRead) : diff --git a/src/windows/common/relay.hpp b/src/windows/common/relay.hpp index b2290439..d17cb0ff 100644 --- a/src/windows/common/relay.hpp +++ b/src/windows/common/relay.hpp @@ -170,6 +170,18 @@ struct HandleWrapper { } + HandleWrapper( + wil::unique_socket&& handle, std::function&& OnClose = []() {}) : + OwnedHandle((HANDLE)handle.release()), Handle(OwnedHandle.get()), OnClose(std::move(OnClose)) + { + } + + HandleWrapper( + wil::unique_event&& handle, std::function&& OnClose = []() {}) : + OwnedHandle(handle.release()), Handle(OwnedHandle.get()), OnClose(std::move(OnClose)) + { + } + HandleWrapper( SOCKET handle, std::function&& OnClose = []() {}) : Handle(reinterpret_cast(handle)), OnClose(std::move(OnClose)) @@ -237,15 +249,13 @@ public: NON_COPYABLE(EventHandle) NON_MOVABLE(EventHandle) - EventHandle(wil::unique_event&& EventHandle, std::function&& OnSignalled); - EventHandle(HANDLE EventHandle, std::function&& OnSignalled); + EventHandle(HandleWrapper&& EventHandle, std::function&& OnSignalled = []() {}); void Schedule() override; void Collect() override; HANDLE GetHandle() const override; private: - wil::unique_event OwnedHandle; - HANDLE Handle; + HandleWrapper Handle; std::function OnSignalled; }; diff --git a/src/windows/common/wslutil.cpp b/src/windows/common/wslutil.cpp index dc6b03c9..677995d7 100644 --- a/src/windows/common/wslutil.cpp +++ b/src/windows/common/wslutil.cpp @@ -16,6 +16,7 @@ Abstract: #include "wslutil.h" #include "WslPluginApi.h" #include "wslinstallerservice.h" +#include "wslaservice.h" #include "ConsoleProgressBar.h" #include "ExecutionContext.h" @@ -145,7 +146,8 @@ static const std::map g_commonErrors{ X_WIN32(ERROR_BAD_PATHNAME), X(WININET_E_TIMEOUT), X_WIN32(ERROR_INVALID_SID), - X_WIN32(ERROR_INVALID_STATE)}; + X_WIN32(ERROR_INVALID_STATE), + X(WSLA_E_IMAGE_NOT_FOUND)}; #undef X @@ -1197,6 +1199,14 @@ void wsl::windows::common::wslutil::InitializeWil() } } +bool wsl::windows::common::wslutil::IsInteractiveConsole() +{ + const HANDLE stdinHandle = GetStdHandle(STD_INPUT_HANDLE); + DWORD mode{}; + + return GetFileType(stdinHandle) == FILE_TYPE_CHAR && GetConsoleMode(stdinHandle, &mode); +} + bool wsl::windows::common::wslutil::IsRunningInMsix() { UINT32 dummy{}; @@ -1590,4 +1600,32 @@ catch (...) { LOG_CAUGHT_EXCEPTION(); return nullptr; +} + +wsl::windows::common::wslutil::WSLAErrorDetails::~WSLAErrorDetails() +{ + Reset(); +} + +void wsl::windows::common::wslutil::WSLAErrorDetails::Reset() +{ + CoTaskMemFree(Error.UserErrorMessage); + Error = {}; +} + +void wsl::windows::common::wslutil::WSLAErrorDetails::ThrowIfFailed(HRESULT Result) +{ + if (SUCCEEDED(Result)) + { + return; + } + + if (Error.UserErrorMessage != nullptr) + { + THROW_HR_WITH_USER_ERROR(Result, Error.UserErrorMessage); + } + else + { + THROW_HR(Result); + } } \ No newline at end of file diff --git a/src/windows/common/wslutil.h b/src/windows/common/wslutil.h index e931ae6a..df052650 100644 --- a/src/windows/common/wslutil.h +++ b/src/windows/common/wslutil.h @@ -19,6 +19,7 @@ Abstract: #include "SubProcess.h" #include #include "JsonUtils.h" +#include "wslaservice.h" namespace wsl::windows::common { struct Error; @@ -64,6 +65,17 @@ struct GitHubRelease NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(GitHubRelease, name, assets, created_at); }; +struct WSLAErrorDetails +{ + ~WSLAErrorDetails(); + + void Reset(); + + void ThrowIfFailed(HRESULT Result); + + WSLA_ERROR_INFO Error{}; +}; + template void AssertValidPrintfArg() { @@ -143,6 +155,8 @@ std::vector HashFile(HANDLE File, DWORD Algorithm); void InitializeWil(); +bool IsInteractiveConsole(); + bool IsRunningInMsix(); bool IsVhdFile(_In_ const std::filesystem::path& path); diff --git a/src/windows/inc/docker_schema.h b/src/windows/inc/docker_schema.h index 51c63492..a34dd512 100644 --- a/src/windows/inc/docker_schema.h +++ b/src/windows/inc/docker_schema.h @@ -151,4 +151,23 @@ struct StartExec NLOHMANN_DEFINE_TYPE_INTRUSIVE_ONLY_SERIALIZE(StartExec, Tty, Detach, ConsoleSize); }; +struct CreateImageProgressDetails +{ + uint64_t current{}; + uint64_t total{}; + std::string unit; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(CreateImageProgressDetails, current, total, unit); +}; + +struct CreateImageProgress +{ + std::string status; + std::string id; + + CreateImageProgressDetails progressDetail; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(CreateImageProgress, status, id, progressDetail); +}; + } // namespace wsl::windows::common::docker_schema \ No newline at end of file diff --git a/src/windows/wsladiag/wsladiag.cpp b/src/windows/wsladiag/wsladiag.cpp index dd24d06b..8c11c002 100644 --- a/src/windows/wsladiag/wsladiag.cpp +++ b/src/windows/wsladiag/wsladiag.cpp @@ -24,11 +24,40 @@ Abstract: using namespace wsl::shared; namespace wslutil = wsl::windows::common::wslutil; +using wsl::windows::common::ClientRunningWSLAProcess; using wsl::windows::common::Context; using wsl::windows::common::ExecutionContext; using wsl::windows::common::WSLAProcessLauncher; +using wsl::windows::common::relay::EventHandle; +using wsl::windows::common::relay::MultiHandleWait; +using wsl::windows::common::relay::RelayHandle; +using wsl::windows::common::wslutil::WSLAErrorDetails; + +class ChangeTerminalMode +{ +public: + NON_COPYABLE(ChangeTerminalMode); + NON_MOVABLE(ChangeTerminalMode); + + ChangeTerminalMode(HANDLE Console, bool CursorVisible) : m_console(Console) + { + THROW_IF_WIN32_BOOL_FALSE(GetConsoleCursorInfo(Console, &m_originalCursorInfo)); + CONSOLE_CURSOR_INFO newCursorInfo = m_originalCursorInfo; + newCursorInfo.bVisible = CursorVisible; + + THROW_IF_WIN32_BOOL_FALSE(SetConsoleCursorInfo(Console, &newCursorInfo)); + } + + ~ChangeTerminalMode() + { + LOG_IF_WIN32_BOOL_FALSE(SetConsoleCursorInfo(m_console, &m_originalCursorInfo)); + } + +private: + HANDLE m_console{}; + CONSOLE_CURSOR_INFO m_originalCursorInfo{}; +}; -// Report an operation failure with localized context and HRESULT details. static int ReportError(const std::wstring& context, HRESULT hr) { auto errorString = wsl::windows::common::wslutil::ErrorCodeToString(hr); @@ -284,7 +313,350 @@ static int RunListCommand(std::wstring_view commandLine) return 0; } -// Print localized usage message to stderr. +DEFINE_ENUM_FLAG_OPERATORS(WSLASessionFlags); + +static wil::com_ptr OpenCLISession() +{ + wil::com_ptr userSession; + THROW_IF_FAILED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&userSession))); + wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get()); + + auto dataFolder = std::filesystem::path(wsl::windows::common::filesystem::GetLocalAppDataPath(nullptr)) / "wsla"; + + // TODO: Have a configuration file for those. + WSLA_SESSION_SETTINGS settings{}; + settings.DisplayName = L"wsla-cli"; + settings.CpuCount = 4; + settings.MemoryMb = 2024; + settings.BootTimeoutMs = 30 * 1000; + settings.StoragePath = dataFolder.c_str(); + settings.MaximumStorageSizeMb = 10000; // 10GB. + settings.NetworkingMode = WSLANetworkingModeNAT; + + wil::com_ptr session; + THROW_IF_FAILED(userSession->CreateSession(&settings, WSLASessionFlagsPersistent | WSLASessionFlagsOpenExisting, &session)); + wsl::windows::common::security::ConfigureForCOMImpersonation(session.get()); + + return session; +} + +static void PullImpl(IWSLASession& Session, const std::string& Image) +{ + HANDLE Stdout = GetStdHandle(STD_OUTPUT_HANDLE); + // Configure console for interactive usage. + DWORD OriginalOutputMode{}; + UINT OriginalOutputCP = GetConsoleOutputCP(); + THROW_LAST_ERROR_IF(!::GetConsoleMode(Stdout, &OriginalOutputMode)); + + DWORD OutputMode = OriginalOutputMode; + WI_SetAllFlags(OutputMode, ENABLE_PROCESSED_OUTPUT | ENABLE_VIRTUAL_TERMINAL_PROCESSING | DISABLE_NEWLINE_AUTO_RETURN); + THROW_IF_WIN32_BOOL_FALSE(::SetConsoleMode(Stdout, OutputMode)); + + THROW_LAST_ERROR_IF(!::SetConsoleOutputCP(CP_UTF8)); + + // TODO: Handle terminal resizes. + class DECLSPEC_UUID("7A1D3376-835A-471A-8DC9-23653D9962D0") Callback + : public Microsoft::WRL::RuntimeClass, IProgressCallback, IFastRundown> + { + public: + auto MoveToLine(SHORT Line, bool Revert = true) + { + if (Line > 0) + { + wprintf(L"\033[%iA", Line); + } + + return wil::scope_exit([Line = Line]() { + if (Line > 1) + { + wprintf(L"\033[%iB", Line - 1); + } + }); + } + + HRESULT OnProgress(LPCSTR Status, LPCSTR Id, ULONGLONG Current, ULONGLONG Total) override + try + { + if (Id == nullptr || *Id == '\0') // Print all 'global' statuses on their own line + { + wprintf(L"%hs\n", Status); + m_currentLine++; + return S_OK; + } + + auto info = Info(); + + auto it = m_statuses.find(Id); + if (it == m_statuses.end()) + { + // If this is the first time we see this ID, create a new line for it. + m_statuses.emplace(Id, m_currentLine); + wprintf(L"%ls\n", GenerateStatusLine(Status, Id, Current, Total, info).c_str()); + m_currentLine++; + } + else + { + auto revert = MoveToLine(m_currentLine - it->second); + wprintf(L"%ls\n", GenerateStatusLine(Status, Id, Current, Total, info).c_str()); + } + + return S_OK; + } + CATCH_RETURN(); + + private: + static CONSOLE_SCREEN_BUFFER_INFO Info() + { + CONSOLE_SCREEN_BUFFER_INFO info{}; + THROW_IF_WIN32_BOOL_FALSE(GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &info)); + + return info; + } + + std::wstring GenerateStatusLine(LPCSTR Status, LPCSTR Id, ULONGLONG Current, ULONGLONG Total, const CONSOLE_SCREEN_BUFFER_INFO& Info) + { + std::wstring line; + if (Total != 0) + { + line = std::format(L"{} '{}': {}%", Status, Id, Current * 100 / Total); + } + else if (Current != 0) + { + line = std::format(L"{} '{}': {}s", Status, Id, Current); + } + else + { + line = std::format(L"{} '{}'", Status, Id); + } + + // Erase any previously written char on that line. + while (line.size() < Info.dwSize.X) + { + line += L' '; + } + + return line; + } + + std::map m_statuses; + SHORT m_currentLine = 0; + ChangeTerminalMode m_terminalMode{GetStdHandle(STD_OUTPUT_HANDLE), false}; + }; + + wil::com_ptr session = OpenCLISession(); + + Callback callback; + WSLAErrorDetails error{}; + auto result = session->PullImage(Image.c_str(), nullptr, &callback, &error.Error); + error.ThrowIfFailed(result); +} + +static int Pull(std::wstring_view commandLine) +{ + ArgumentParser parser(std::wstring{commandLine}, L"wsladiag", 2); + + std::string image; + parser.AddPositionalArgument(Utf8String{image}, 0); + + parser.Parse(); + THROW_HR_IF(E_INVALIDARG, image.empty()); + + PullImpl(*OpenCLISession(), image); + + return 0; +} + +static int InteractiveShell(ClientRunningWSLAProcess&& Process, bool Tty) +{ + HANDLE Stdout = GetStdHandle(STD_OUTPUT_HANDLE); + HANDLE Stdin = GetStdHandle(STD_INPUT_HANDLE); + auto exitEvent = Process.GetExitEvent(); + + if (Tty) + { + // Save original console modes so they can be restored on exit. + DWORD OriginalInputMode{}; + DWORD OriginalOutputMode{}; + UINT OriginalOutputCP = GetConsoleOutputCP(); + THROW_LAST_ERROR_IF(!::GetConsoleMode(Stdin, &OriginalInputMode)); + THROW_LAST_ERROR_IF(!::GetConsoleMode(Stdout, &OriginalOutputMode)); + + auto restoreConsoleMode = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] { + SetConsoleMode(Stdin, OriginalInputMode); + SetConsoleMode(Stdout, OriginalOutputMode); + SetConsoleOutputCP(OriginalOutputCP); + }); + + // Configure console for interactive usage. + DWORD InputMode = OriginalInputMode; + WI_SetAllFlags(InputMode, (ENABLE_WINDOW_INPUT | ENABLE_VIRTUAL_TERMINAL_INPUT)); + WI_ClearAllFlags(InputMode, (ENABLE_ECHO_INPUT | ENABLE_INSERT_MODE | ENABLE_LINE_INPUT | ENABLE_PROCESSED_INPUT)); + THROW_IF_WIN32_BOOL_FALSE(::SetConsoleMode(Stdin, InputMode)); + + DWORD OutputMode = OriginalOutputMode; + WI_SetAllFlags(OutputMode, ENABLE_PROCESSED_OUTPUT | ENABLE_VIRTUAL_TERMINAL_PROCESSING | DISABLE_NEWLINE_AUTO_RETURN); + THROW_IF_WIN32_BOOL_FALSE(::SetConsoleMode(Stdout, OutputMode)); + + THROW_LAST_ERROR_IF(!::SetConsoleOutputCP(CP_UTF8)); + + auto processTty = Process.GetStdHandle(WSLAFDTty); + + // TODO: Study a single thread for both handles. + + // Create a thread to relay stdin to the pipe. + std::thread inputThread([&]() { + auto updateTerminal = [&Stdout, &Process, &processTty]() { + CONSOLE_SCREEN_BUFFER_INFOEX info{}; + info.cbSize = sizeof(info); + + THROW_IF_WIN32_BOOL_FALSE(GetConsoleScreenBufferInfoEx(Stdout, &info)); + + LOG_IF_FAILED(Process.Get().ResizeTty( + info.srWindow.Bottom - info.srWindow.Top + 1, info.srWindow.Right - info.srWindow.Left + 1)); + }; + + wsl::windows::common::relay::StandardInputRelay(Stdin, processTty.get(), updateTerminal, exitEvent.get()); + }); + + auto joinThread = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { + exitEvent.SetEvent(); + inputThread.join(); + }); + + // Relay the contents of the pipe to stdout. + wsl::windows::common::relay::InterruptableRelay(processTty.get(), Stdout); + + // Wait for the process to exit. + THROW_LAST_ERROR_IF(WaitForSingleObject(exitEvent.get(), INFINITE) != WAIT_OBJECT_0); + } + else + { + wsl::windows::common::relay::MultiHandleWait io; + + // Create a thread to relay stdin to the pipe. + + std::thread inputThread; + + auto joinThread = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { + if (inputThread.joinable()) + { + exitEvent.SetEvent(); + inputThread.join(); + } + }); + + // Required because ReadFile() blocks if stdin is a tty. + if (wsl::windows::common::wslutil::IsInteractiveConsole()) + { + // TODO: Will output CR instead of LF's which can confuse the linux app. + // Consider a custom relay logic to fix this. + inputThread = std::thread{ + [&]() { wsl::windows::common::relay::InterruptableRelay(Stdin, Process.GetStdHandle(0).get(), exitEvent.get()); }}; + } + else + { + io.AddHandle(std::make_unique(GetStdHandle(STD_INPUT_HANDLE), Process.GetStdHandle(0))); + } + + io.AddHandle(std::make_unique(Process.GetStdHandle(1), GetStdHandle(STD_OUTPUT_HANDLE))); + io.AddHandle(std::make_unique(Process.GetStdHandle(2), GetStdHandle(STD_ERROR_HANDLE))); + io.AddHandle(std::make_unique(exitEvent.get())); + + io.Run({}); + } + + int exitCode = Process.GetExitCode(); + + return exitCode; +} + +static int Run(std::wstring_view commandLine) +{ + ArgumentParser parser(std::wstring{commandLine}, L"wsladiag", 2, true); + + bool interactive{}; + bool tty{}; + std::string image; + parser.AddPositionalArgument(Utf8String{image}, 0); + parser.AddArgument(interactive, L"--interactive", 'i'); + parser.AddArgument(tty, L"--tty", 't'); + + parser.Parse(); + THROW_HR_IF(E_INVALIDARG, image.empty()); + + auto session = OpenCLISession(); + + WSLA_CONTAINER_OPTIONS options{}; + options.Image = image.c_str(); + + std::vector fds; + HANDLE Stdout = GetStdHandle(STD_OUTPUT_HANDLE); + HANDLE Stdin = GetStdHandle(STD_INPUT_HANDLE); + + if (tty) + { + CONSOLE_SCREEN_BUFFER_INFOEX Info{}; + Info.cbSize = sizeof(Info); + THROW_IF_WIN32_BOOL_FALSE(::GetConsoleScreenBufferInfoEx(Stdout, &Info)); + + fds.emplace_back(WSLA_PROCESS_FD{.Fd = 0, .Type = WSLAFdTypeTerminalInput}); + fds.emplace_back(WSLA_PROCESS_FD{.Fd = 1, .Type = WSLAFdTypeTerminalOutput}); + fds.emplace_back(WSLA_PROCESS_FD{.Fd = 2, .Type = WSLAFdTypeTerminalControl}); + + options.InitProcessOptions.TtyColumns = Info.srWindow.Right - Info.srWindow.Left + 1; + options.InitProcessOptions.TtyRows = Info.srWindow.Bottom - Info.srWindow.Top + 1; + } + else + { + if (interactive) + { + fds.emplace_back(WSLA_PROCESS_FD{.Fd = 0, .Type = WSLAFdTypeDefault}); + } + + fds.emplace_back(WSLA_PROCESS_FD{.Fd = 1, .Type = WSLAFdTypeDefault}); + fds.emplace_back(WSLA_PROCESS_FD{.Fd = 2, .Type = WSLAFdTypeDefault}); + } + + std::vector argsStorage; + std::vector args; + for (size_t i = parser.ParseIndex(); i < parser.Argc(); i++) + { + argsStorage.emplace_back(wsl::shared::string::WideToMultiByte(parser.Argv(i))); + } + + for (const auto& e : argsStorage) + { + args.emplace_back(e.c_str()); + } + + options.InitProcessOptions.CommandLine = args.data(); + options.InitProcessOptions.CommandLineCount = static_cast(args.size()); + options.InitProcessOptions.Fds = fds.data(); + options.InitProcessOptions.FdsCount = static_cast(fds.size()); + + wil::com_ptr container; + WSLAErrorDetails error{}; + auto result = session->CreateContainer(&options, &container, &error.Error); + if (result == WSLA_E_IMAGE_NOT_FOUND) + { + wslutil::PrintMessage(std::format(L"Image '{}' not found, pulling", image), stderr); + + PullImpl(*session.get(), image); + + error.Reset(); + result = session->CreateContainer(&options, &container, &error.Error); + } + + error.ThrowIfFailed(result); + + THROW_IF_FAILED(container->Start()); // TODO: Error message + + wil::com_ptr process; + THROW_IF_FAILED(container->GetInitProcess(&process)); + + return InteractiveShell(ClientRunningWSLAProcess(std::move(process), std::move(fds)), tty); +} + static void PrintUsage() { wslutil::PrintMessage(Localization::MessageWsladiagUsage(), stderr); @@ -324,21 +696,30 @@ int wsladiag_main(std::wstring_view commandLine) PrintUsage(); return 0; } - - if (verb == L"list") + else if (verb == L"list") { return RunListCommand(commandLine); } - - if (verb == L"shell") + else if (verb == L"shell") { return RunShellCommand(commandLine); } + else if (verb == L"pull") + { + return Pull(commandLine); + } + else if (verb == L"run") + { + return Run(commandLine); + } + else + { + wslutil::PrintMessage(Localization::MessageWslaUnknownCommand(verb.c_str()), stderr); + PrintUsage(); - // Unknown verb - show usage and fail. - wslutil::PrintMessage(Localization::MessageWslaUnknownCommand(verb.c_str()), stderr); - PrintUsage(); - return 1; + // Unknown verb - show usage and fail. + return 1; + } } int wmain(int, wchar_t**) @@ -360,15 +741,16 @@ int wmain(int, wchar_t**) if (FAILED(result)) { - if (auto reported = context.ReportedError()) + if (const auto& reported = context.ReportedError()) { auto strings = wsl::windows::common::wslutil::ErrorToString(*reported); - wslutil::PrintMessage(strings.Message.empty() ? strings.Code : strings.Message, stderr); + auto errorMessage = strings.Message.empty() ? strings.Code : strings.Message; + wslutil::PrintMessage(Localization::MessageErrorCode(errorMessage, wslutil::ErrorCodeToString(result)), stderr); } else { // Fallback for errors without context - wslutil::PrintMessage(wslutil::GetErrorString(result), stderr); + wslutil::PrintMessage(Localization::MessageErrorCode("", wslutil::ErrorCodeToString(result)), stderr); } } diff --git a/src/windows/wslaservice/exe/DockerHTTPClient.cpp b/src/windows/wslaservice/exe/DockerHTTPClient.cpp index 59e91fc4..12152719 100644 --- a/src/windows/wslaservice/exe/DockerHTTPClient.cpp +++ b/src/windows/wslaservice/exe/DockerHTTPClient.cpp @@ -38,15 +38,10 @@ DockerHTTPClient::DockerHTTPClient(wsl::shared::SocketChannel&& Channel, HANDLE { } -uint32_t DockerHTTPClient::PullImage(const char* Name, const char* Tag, const OnImageProgress& Callback) +std::unique_ptr DockerHTTPClient::PullImage(const char* Name, const char* Tag) { - auto [code, _] = SendRequest( - verb::post, - std::format("http://localhost/images/create?fromImage=library/{}&tag={}", Name, Tag), - {}, - [Callback](const gsl::span& span) { Callback(std::string{span.data(), span.size()}); }); - - return code; + auto url = std::format("http://localhost/images/create?fromImage=library/{}&tag={}", Name, Tag); + return SendRequestImpl(verb::post, url, {}, {}); } std::unique_ptr DockerHTTPClient::LoadImage(uint64_t ContentLength) @@ -77,10 +72,15 @@ std::vector DockerHTTPClient::ListImages() return Transaction>(verb::get, "http://localhost/images/json"); } -docker_schema::CreatedContainer DockerHTTPClient::CreateContainer(const docker_schema::CreateContainer& Request) +docker_schema::CreatedContainer DockerHTTPClient::CreateContainer(const docker_schema::CreateContainer& Request, const std::optional& Name) { - // TODO: Url escaping. - return Transaction(verb::post, "http://localhost/containers/create", Request); + std::string url = "http://localhost/containers/create"; + if (Name.has_value()) + { + url += std::format("?name={}", Name.value()); + } + + return Transaction(verb::post, url, Request); } void DockerHTTPClient::ResizeContainerTty(const std::string& Id, ULONG Rows, ULONG Columns) diff --git a/src/windows/wslaservice/exe/DockerHTTPClient.h b/src/windows/wslaservice/exe/DockerHTTPClient.h index cd4bf6d3..ed86cafc 100644 --- a/src/windows/wslaservice/exe/DockerHTTPClient.h +++ b/src/windows/wslaservice/exe/DockerHTTPClient.h @@ -37,7 +37,7 @@ public: } template - T DockerMessage() + T DockerMessage() const { return wsl::shared::FromJson(m_response.c_str()); } @@ -60,7 +60,6 @@ class DockerHTTPClient public: using OnResponseBytes = std::function)>; - using OnImageProgress = std::function; struct HTTPRequestContext { @@ -80,7 +79,7 @@ public: DockerHTTPClient(wsl::shared::SocketChannel&& Channel, HANDLE ExitingEvent, GUID VmId, ULONG ConnectTimeoutMs); // Container management. - common::docker_schema::CreatedContainer CreateContainer(const common::docker_schema::CreateContainer& Request); + common::docker_schema::CreatedContainer CreateContainer(const common::docker_schema::CreateContainer& Request, const std::optional& Name); void StartContainer(const std::string& Id); void StopContainer(const std::string& Id, int Signal, ULONG TimeoutSeconds); void DeleteContainer(const std::string& Id); @@ -90,7 +89,7 @@ public: void ResizeContainerTty(const std::string& Id, ULONG Rows, ULONG Columns); // Image management. - uint32_t PullImage(const char* Name, const char* Tag, const OnImageProgress& Callback); + std::unique_ptr PullImage(const char* Name, const char* Tag); std::unique_ptr ImportImage(const std::string& Repo, const std::string& Tag, uint64_t ContentLength); std::unique_ptr LoadImage(uint64_t ContentLength); void TagImage(const std::string& Id, const std::string& Repo, const std::string& Tag); diff --git a/src/windows/wslaservice/exe/WSLAContainer.cpp b/src/windows/wslaservice/exe/WSLAContainer.cpp index dc802d94..7d765a70 100644 --- a/src/windows/wslaservice/exe/WSLAContainer.cpp +++ b/src/windows/wslaservice/exe/WSLAContainer.cpp @@ -136,7 +136,7 @@ WSLAContainerImpl::WSLAContainerImpl( ContainerEventTracker& EventTracker, DockerHTTPClient& DockerClient) : m_parentVM(parentVM), - m_name(Options.Name), + m_name(Options.Name == nullptr ? "" : Options.Name), // TODO: get name from docker. m_image(Options.Image), m_id(std::move(Id)), m_mountedVolumes(std::move(volumes)), @@ -227,6 +227,11 @@ const std::string& WSLAContainerImpl::Image() const noexcept return m_image; } +const std::string& WSLAContainerImpl::Name() const noexcept +{ + return m_name; +} + IWSLAContainer& WSLAContainerImpl::ComWrapper() { return *m_comWrapper.Get(); @@ -555,13 +560,6 @@ std::unique_ptr WSLAContainerImpl::Create( for (DWORD i = 0; i < containerOptions.InitProcessOptions.EnvironmentCount; i++) { - THROW_HR_IF_MSG( - E_INVALIDARG, - containerOptions.InitProcessOptions.Environment[i][0] == '-', - "Invalid environment string at index: %i: %hs", - i, - containerOptions.InitProcessOptions.Environment[i]); - request.Env.push_back(containerOptions.InitProcessOptions.Environment[i]); } @@ -604,23 +602,16 @@ std::unique_ptr WSLAContainerImpl::Create( } // Send the request to docker. - try - { - auto result = DockerClient.CreateContainer(request); + auto result = + DockerClient.CreateContainer(request, containerOptions.Name != nullptr ? containerOptions.Name : std::optional{}); - // N.B. mappedPorts is explicitly copied because it's referenced in errorCleanup, so it can't be moved. - auto container = std::make_unique( - &parentVM, containerOptions, std::move(result.Id), std::move(volumes), std::vector(*mappedPorts), std::move(OnDeleted), EventTracker, DockerClient); + // N.B. mappedPorts is explicitly copied because it's referenced in errorCleanup, so it can't be moved. + auto container = std::make_unique( + &parentVM, containerOptions, std::move(result.Id), std::move(volumes), std::vector(*mappedPorts), std::move(OnDeleted), EventTracker, DockerClient); - errorCleanup.release(); + errorCleanup.release(); - return container; - } - catch (const DockerHTTPException& e) - { - // TODO: propagate error message to caller. - THROW_HR_MSG(E_FAIL, "Failed to create container: %hs ", e.what()); - } + return container; } const std::string& WSLAContainerImpl::ID() const noexcept diff --git a/src/windows/wslaservice/exe/WSLAContainer.h b/src/windows/wslaservice/exe/WSLAContainer.h index 126d425c..816ab7c0 100644 --- a/src/windows/wslaservice/exe/WSLAContainer.h +++ b/src/windows/wslaservice/exe/WSLAContainer.h @@ -71,6 +71,7 @@ public: IWSLAContainer& ComWrapper(); const std::string& Image() const noexcept; + const std::string& Name() const noexcept; WSLA_CONTAINER_STATE State() noexcept; void OnProcessReleased(DockerExecProcessControl* process); diff --git a/src/windows/wslaservice/exe/WSLASession.cpp b/src/windows/wslaservice/exe/WSLASession.cpp index ef1e2318..e9ef7b6d 100644 --- a/src/windows/wslaservice/exe/WSLASession.cpp +++ b/src/windows/wslaservice/exe/WSLASession.cpp @@ -275,11 +275,14 @@ try } CATCH_LOG(); -HRESULT WSLASession::PullImage(LPCSTR ImageUri, const WSLA_REGISTRY_AUTHENTICATION_INFORMATION* RegistryAuthenticationInformation, IProgressCallback* ProgressCallback) +HRESULT WSLASession::PullImage( + LPCSTR ImageUri, + const WSLA_REGISTRY_AUTHENTICATION_INFORMATION* RegistryAuthenticationInformation, + IProgressCallback* ProgressCallback, + WSLA_ERROR_INFO* Error) try { UNREFERENCED_PARAMETER(RegistryAuthenticationInformation); - UNREFERENCED_PARAMETER(ProgressCallback); RETURN_HR_IF_NULL(E_POINTER, ImageUri); @@ -287,13 +290,73 @@ try std::lock_guard lock{m_lock}; - auto callback = [&](const std::string& content) { - WSL_LOG("ImagePullProgress", TraceLoggingValue(ImageUri, "Image"), TraceLoggingValue(content.c_str(), "Content")); + auto requestContext = m_dockerClient->PullImage(repo.c_str(), tag.c_str()); + + relay::MultiHandleWait io; + + std::optional pullResult; + + auto onHttpResponse = [&](const boost::beast::http::message& response) { + WSL_LOG("PullHttpResponse", TraceLoggingValue(static_cast(response.result()), "StatusCode")); + + pullResult = response.result(); }; - auto code = m_dockerClient->PullImage(repo.c_str(), tag.c_str(), callback); + std::string errorJson; + auto onChunk = [&](const gsl::span& Content) { + if (pullResult.has_value() && pullResult.value() != boost::beast::http::status::ok) + { + // If the status code is an error, then this is an error message, not a progress update. + errorJson.append(Content.data(), Content.size()); + return; + } - THROW_HR_IF_MSG(E_FAIL, code != 200, "Failed to pull image: %hs", ImageUri); + std::string contentString{Content.begin(), Content.end()}; + WSL_LOG("ImagePullProgress", TraceLoggingValue(ImageUri, "Image"), TraceLoggingValue(contentString.c_str(), "Content")); + + if (ProgressCallback == nullptr) + { + return; + } + + auto parsed = wsl::shared::FromJson(contentString.c_str()); + + THROW_IF_FAILED(ProgressCallback->OnProgress( + parsed.status.c_str(), parsed.id.c_str(), parsed.progressDetail.current, parsed.progressDetail.total)); + }; + + auto onCompleted = [&]() { io.Cancel(); }; + + io.AddHandle(std::make_unique(m_sessionTerminatingEvent.get(), [&]() { THROW_HR(E_ABORT); })); + io.AddHandle(std::make_unique(m_sessionTerminatingEvent.get(), [&]() { THROW_HR(E_ABORT); })); + io.AddHandle(std::make_unique( + *requestContext, std::move(onHttpResponse), std::move(onChunk), std::move(onCompleted))); + + io.Run({}); + + THROW_HR_IF(E_ABORT, m_sessionTerminatingEvent.is_signaled()); + THROW_HR_IF(E_UNEXPECTED, !pullResult.has_value()); + + if (pullResult.value() != boost::beast::http::status::ok) + { + std::string errorMessage; + if (static_cast(pullResult.value()) >= 400 && static_cast(pullResult.value()) < 500) + { + // pull failed, parse the error message. + errorMessage = wsl::shared::FromJson(errorJson.c_str()).message; + if (Error != nullptr) + { + Error->UserErrorMessage = wil::make_unique_ansistring(errorMessage.c_str()).release(); + } + } + + if (pullResult.value() == boost::beast::http::status::not_found) + { + THROW_HR_MSG(WSLA_E_IMAGE_NOT_FOUND, "%hs", errorMessage.c_str()); + } + + THROW_HR_MSG(E_FAIL, "Image import failed: %hs", errorMessage.c_str()); + } return S_OK; } @@ -432,52 +495,72 @@ HRESULT WSLASession::DeleteImage(LPCWSTR Image) return E_NOTIMPL; } -HRESULT WSLASession::CreateContainer(const WSLA_CONTAINER_OPTIONS* containerOptions, IWSLAContainer** Container) +HRESULT WSLASession::CreateContainer(const WSLA_CONTAINER_OPTIONS* containerOptions, IWSLAContainer** Container, WSLA_ERROR_INFO* Error) try { RETURN_HR_IF_NULL(E_POINTER, containerOptions); - // Validate that Image and Name are not null. + // Validate that Image is not null. RETURN_HR_IF(E_INVALIDARG, containerOptions->Image == nullptr); - RETURN_HR_IF(E_INVALIDARG, containerOptions->Name == nullptr); std::lock_guard lock{m_lock}; RETURN_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_virtualMachine); - // Validate that no container with the same name already exists. - auto it = m_containers.find(containerOptions->Name); - RETURN_HR_IF(HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS), it != m_containers.end()); - // Validate that name & images are within length limits. - RETURN_HR_IF(E_INVALIDARG, strlen(containerOptions->Name) > WSLA_MAX_CONTAINER_NAME_LENGTH); + RETURN_HR_IF(E_INVALIDARG, containerOptions->Name != nullptr && strlen(containerOptions->Name) > WSLA_MAX_CONTAINER_NAME_LENGTH); RETURN_HR_IF(E_INVALIDARG, strlen(containerOptions->Image) > WSLA_MAX_IMAGE_NAME_LENGTH); // TODO: Log entrance into the function. - auto [container, inserted] = m_containers.emplace( - containerOptions->Name, - WSLAContainerImpl::Create( + + try + { + auto& it = m_containers.emplace_back(WSLAContainerImpl::Create( *containerOptions, *m_virtualMachine, std::bind(&WSLASession::OnContainerDeleted, this, std::placeholders::_1), m_eventTracker.value(), m_dockerClient.value())); - WI_ASSERT(inserted); + THROW_IF_FAILED(it->ComWrapper().QueryInterface(__uuidof(IWSLAContainer), (void**)Container)); - THROW_IF_FAILED(container->second->ComWrapper().QueryInterface(__uuidof(IWSLAContainer), (void**)Container)); + return S_OK; + } + catch (const DockerHTTPException& e) + { + std::string errorMessage; + if ((e.StatusCode() >= 400 && e.StatusCode() < 500)) + { + errorMessage = e.DockerMessage().message; + } - return S_OK; + if (Error != nullptr) + { + Error->UserErrorMessage = wil::make_unique_ansistring(errorMessage.c_str()).release(); + } + + if (e.StatusCode() == 404) + { + THROW_HR_MSG(WSLA_E_IMAGE_NOT_FOUND, "%hs", errorMessage.c_str()); + } + else if (e.StatusCode() == 409) + { + THROW_WIN32_MSG(ERROR_ALREADY_EXISTS, "%hs", errorMessage.c_str()); + } + + return E_FAIL; + } } CATCH_RETURN(); HRESULT WSLASession::OpenContainer(LPCSTR Name, IWSLAContainer** Container) try { + // TODO: Rethink name / id usage here. std::lock_guard lock{m_lock}; - auto it = m_containers.find(Name); + auto it = std::ranges::find_if(m_containers, [Name](const auto& e) { return e->Name() == Name; }); RETURN_HR_IF_MSG(HRESULT_FROM_WIN32(ERROR_NOT_FOUND), it == m_containers.end(), "Container not found: '%hs'", Name); - THROW_IF_FAILED(it->second->ComWrapper().QueryInterface(__uuidof(IWSLAContainer), (void**)Container)); + THROW_IF_FAILED((*it)->ComWrapper().QueryInterface(__uuidof(IWSLAContainer), (void**)Container)); return S_OK; } @@ -494,11 +577,11 @@ try auto output = wil::make_unique_cotaskmem(m_containers.size()); size_t index = 0; - for (const auto& [name, container] : m_containers) + for (const auto& e : m_containers) { - THROW_HR_IF(E_UNEXPECTED, strcpy_s(output[index].Image, container->Image().c_str()) != 0); - THROW_HR_IF(E_UNEXPECTED, strcpy_s(output[index].Name, name.c_str()) != 0); - container->GetState(&output[index].State); + THROW_HR_IF(E_UNEXPECTED, strcpy_s(output[index].Image, e->Image().c_str()) != 0); + THROW_HR_IF(E_UNEXPECTED, strcpy_s(output[index].Name, e->Name().c_str()) != 0); + e->GetState(&output[index].State); index++; } @@ -638,7 +721,7 @@ CATCH_RETURN(); void WSLASession::OnContainerDeleted(const WSLAContainerImpl* Container) { std::lock_guard lock{m_lock}; - WI_VERIFY(std::erase_if(m_containers, [Container](const auto& e) { return e.second.get() == Container; }) == 1); + WI_VERIFY(std::erase_if(m_containers, [Container](const auto& e) { return e.get() == Container; }) == 1); } HRESULT WSLASession::GetImplNoRef(_Out_ WSLASession** Session) diff --git a/src/windows/wslaservice/exe/WSLASession.h b/src/windows/wslaservice/exe/WSLASession.h index 8cb5f5a5..a40ba397 100644 --- a/src/windows/wslaservice/exe/WSLASession.h +++ b/src/windows/wslaservice/exe/WSLASession.h @@ -50,14 +50,15 @@ public: IFACEMETHOD(PullImage)( _In_ LPCSTR ImageUri, _In_ const WSLA_REGISTRY_AUTHENTICATION_INFORMATION* RegistryAuthenticationInformation, - _In_ IProgressCallback* ProgressCallback) override; + _In_ IProgressCallback* ProgressCallback, + _Inout_opt_ WSLA_ERROR_INFO* ErrorInfo) override; IFACEMETHOD(LoadImage)(_In_ ULONG ImageHandle, _In_ IProgressCallback* ProgressCallback, _In_ ULONGLONG ContentLength) override; IFACEMETHOD(ImportImage)(_In_ ULONG ImageHandle, _In_ LPCSTR ImageName, _In_ IProgressCallback* ProgressCallback, _In_ ULONGLONG ContentLength) override; IFACEMETHOD(ListImages)(_Out_ WSLA_IMAGE_INFORMATION** Images, _Out_ ULONG* Count) override; IFACEMETHOD(DeleteImage)(_In_ LPCWSTR Image) override; // Container management. - IFACEMETHOD(CreateContainer)(_In_ const WSLA_CONTAINER_OPTIONS* Options, _Out_ IWSLAContainer** Container) override; + IFACEMETHOD(CreateContainer)(_In_ const WSLA_CONTAINER_OPTIONS* Options, _Out_ IWSLAContainer** Container, _Inout_opt_ WSLA_ERROR_INFO* Error) override; IFACEMETHOD(OpenContainer)(_In_ LPCSTR Name, _In_ IWSLAContainer** Container) override; IFACEMETHOD(ListContainers)(_Out_ WSLA_CONTAINER** Images, _Out_ ULONG* Count) override; @@ -100,7 +101,7 @@ private: std::thread m_containerdThread; std::wstring m_displayName; std::filesystem::path m_storageVhdPath; - std::map> m_containers; + std::vector> m_containers; wil::unique_event m_sessionTerminatingEvent{wil::EventOptions::ManualReset}; std::recursive_mutex m_lock; }; diff --git a/src/windows/wslaservice/inc/wslaservice.idl b/src/windows/wslaservice/inc/wslaservice.idl index 7722b43a..4ebfd24b 100644 --- a/src/windows/wslaservice/inc/wslaservice.idl +++ b/src/windows/wslaservice/inc/wslaservice.idl @@ -111,7 +111,7 @@ interface ITerminationCallback : IUnknown ] interface IProgressCallback : IUnknown { - HRESULT OnProgress(ULONG Progress, ULONG Total); + HRESULT OnProgress(LPCSTR Status, LPCSTR Id, ULONGLONG Current, ULONGLONG Total); }; struct WSLA_REGISTRY_AUTHENTICATION_INFORMATION @@ -180,7 +180,7 @@ enum WSLA_CONTAINER_FLAGS struct WSLA_CONTAINER_OPTIONS { LPCSTR Image; - LPCSTR Name; + [unique] LPCSTR Name; struct WSLA_PROCESS_OPTIONS InitProcessOptions; [unique, size_is(VolumesCount)] struct WSLA_VOLUME* Volumes; ULONG VolumesCount; @@ -219,6 +219,12 @@ enum WSLA_PROCESS_STATE WslaProcessStateSignalled = 3 }; +// TODO: Design for localization. +typedef struct _WSLA_ERROR_INFO{ + [string] LPSTR UserErrorMessage; + ULONG WarningsPipe; +} WSLA_ERROR_INFO; + [ uuid(1AD163CD-393D-4B33-83A2-8A3F3F23E608), pointer_default(unique), @@ -299,14 +305,14 @@ interface IWSLAContainer : IUnknown interface IWSLASession : IUnknown { // Image management. - HRESULT PullImage([in] LPCSTR ImageUri, [in, unique] const struct WSLA_REGISTRY_AUTHENTICATION_INFORMATION* RegistryAuthenticationInformation, [in, unique] IProgressCallback* ProgressCallback); + HRESULT PullImage([in] LPCSTR ImageUri, [in, unique] const struct WSLA_REGISTRY_AUTHENTICATION_INFORMATION* RegistryAuthenticationInformation, [in, unique] IProgressCallback* ProgressCallback, [in, out, unique, optional] WSLA_ERROR_INFO* ErrorInfo); HRESULT LoadImage([in] ULONG ImageHandle, [in, unique] IProgressCallback* ProgressCallback, [in] ULONGLONG ContentLength); HRESULT ImportImage([in] ULONG ImageHandle, [in] LPCSTR ImageName, [in, unique] IProgressCallback* ProgressCallback, [in] ULONGLONG ContentLength); HRESULT ListImages([out, size_is(, *Count)] struct WSLA_IMAGE_INFORMATION** Images, [out] ULONG* Count); HRESULT DeleteImage([in] LPCWSTR Image); // Container management. - HRESULT CreateContainer([in] const struct WSLA_CONTAINER_OPTIONS* Options, [out] IWSLAContainer** Container); + HRESULT CreateContainer([in] const struct WSLA_CONTAINER_OPTIONS* Options, [out] IWSLAContainer** Container, [in, out, unique, optional] WSLA_ERROR_INFO* ErrorInfo); HRESULT OpenContainer([in] LPCSTR Name, [out] IWSLAContainer** Container); HRESULT ListContainers([out, size_is(, *Count)] struct WSLA_CONTAINER** Images, [out] ULONG* Count); @@ -357,4 +363,7 @@ interface IWSLAUserSession : IUnknown HRESULT OpenSessionByName([in] LPCWSTR DisplayName, [out] IWSLASession** Session); // TODO: Do we need 'TerminateSession()' ? -} \ No newline at end of file +} + +cpp_quote("#define WSLA_E_BASE (0x0600)") +cpp_quote("#define WSLA_E_IMAGE_NOT_FOUND MAKE_HRESULT(SEVERITY_ERROR, FACILITY_ITF, WSLA_E_BASE + 1) /* 0x80040601 */") diff --git a/test/windows/WSLATests.cpp b/test/windows/WSLATests.cpp index 3ed0470e..cebb04d9 100644 --- a/test/windows/WSLATests.cpp +++ b/test/windows/WSLATests.cpp @@ -28,6 +28,7 @@ using wsl::windows::common::WSLAContainerLauncher; using wsl::windows::common::WSLAProcessLauncher; using wsl::windows::common::relay::OverlappedIOHandle; using wsl::windows::common::relay::WriteHandle; +using wsl::windows::common::wslutil::WSLAErrorDetails; DEFINE_ENUM_FLAG_OPERATORS(WSLAFeatureFlags); @@ -53,8 +54,8 @@ class WSLATests storagePath = std::filesystem::current_path() / "test-storage"; auto session = CreateSession(); - VERIFY_SUCCEEDED(session->PullImage("debian:latest", nullptr, nullptr)); - VERIFY_SUCCEEDED(session->PullImage("python:3.12-alpine", nullptr, nullptr)); + VERIFY_SUCCEEDED(session->PullImage("debian:latest", nullptr, nullptr, nullptr)); + VERIFY_SUCCEEDED(session->PullImage("python:3.12-alpine", nullptr, nullptr, nullptr)); return true; } @@ -296,12 +297,28 @@ class WSLATests auto session = CreateSession(settings); - VERIFY_SUCCEEDED(session->PullImage("hello-world:latest", nullptr, nullptr)); + { + VERIFY_SUCCEEDED(session->PullImage("hello-world:linux", nullptr, nullptr, nullptr)); - // Verify that the image is in the list of images. - ExpectImagePresent(*session, "hello-world:latest"); + // Verify that the image is in the list of images. + ExpectImagePresent(*session, "hello-world:linux"); + WSLAContainerLauncher launcher("hello-world:linux", "wsla-pull-image-container"); - // TODO: Check that the image can actually be used to start a container. + auto container = launcher.Launch(*session); + auto result = container.GetInitProcess().WaitAndCaptureOutput(); + + VERIFY_ARE_EQUAL(0, result.Code); + VERIFY_IS_TRUE(result.Output[1].find("Hello from Docker!") != std::string::npos); + } + + { + std::string expectedError = + "pull access denied for does-not, repository does not exist or may require 'docker login'"; + + WSLAErrorDetails error; + VERIFY_ARE_EQUAL(session->PullImage("does-not:exist", nullptr, nullptr, &error.Error), WSLA_E_IMAGE_NOT_FOUND); + VERIFY_ARE_EQUAL(expectedError, error.Error.UserErrorMessage); + } } // TODO: Test that invalid tars are correctly handled. @@ -326,7 +343,7 @@ class WSLATests // Verify that the image is in the list of images. ExpectImagePresent(*session, "hello-world:latest"); - WSLAContainerLauncher launcher("hello-world:latest", "wsla-import-image-container"); + WSLAContainerLauncher launcher("hello-world:latest", "wsla-load-image-container"); auto container = launcher.Launch(*session); auto result = container.GetInitProcess().WaitAndCaptureOutput(); @@ -1336,11 +1353,10 @@ class WSLATests VERIFY_ARE_EQUAL(hresult, E_INVALIDARG); } - // TODO: Add logic to detect when starting the container fails, and enable this test case. { WSLAContainerLauncher launcher("invalid-image-name", "dummy", "/bin/cat"); auto [hresult, container] = launcher.LaunchNoThrow(*session); - VERIFY_ARE_EQUAL(hresult, E_FAIL); // TODO: Have a nicer error code when the image is not found. + VERIFY_ARE_EQUAL(hresult, WSLA_E_IMAGE_NOT_FOUND); } // Test null image name @@ -1352,7 +1368,7 @@ class WSLATests options.InitProcessOptions.CommandLineCount = 0; wil::com_ptr container; - auto hr = session->CreateContainer(&options, &container); + auto hr = session->CreateContainer(&options, &container, nullptr); VERIFY_ARE_EQUAL(hr, E_INVALIDARG); } @@ -1365,8 +1381,7 @@ class WSLATests options.InitProcessOptions.CommandLineCount = 0; wil::com_ptr container; - auto hr = session->CreateContainer(&options, &container); - VERIFY_ARE_EQUAL(hr, E_INVALIDARG); + VERIFY_SUCCEEDED(session->CreateContainer(&options, &container, nullptr)); } }