diff --git a/CMakeLists.txt b/CMakeLists.txt index ec44370..9de9576 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -431,6 +431,7 @@ add_subdirectory(src/windows/wslhost) add_subdirectory(src/windows/wslrelay) add_subdirectory(src/windows/wslinstall) add_subdirectory(src/windows/wslaclient) +add_subdirectory(src/windows/wsladiag) if (WSL_BUILD_WSL_SETTINGS) add_subdirectory(src/windows/libwsl) diff --git a/msipackage/CMakeLists.txt b/msipackage/CMakeLists.txt index 7f7bcb1..cc011a9 100644 --- a/msipackage/CMakeLists.txt +++ b/msipackage/CMakeLists.txt @@ -3,7 +3,7 @@ set(OUTPUT_PACKAGE ${BIN}/wsl.msi) set(PACKAGE_WIX_IN ${CMAKE_CURRENT_LIST_DIR}/package.wix.in) set(PACKAGE_WIX ${BIN}/package.wix) set(CAB_CACHE ${BIN}/cab) -set(BINARIES wsl.exe;wslg.exe;wslhost.exe;wslrelay.exe;wslservice.exe;wslserviceproxystub.dll;init;initrd.img;wslinstall.dll;wslaserviceproxystub.dll;wslaservice.exe) +set(BINARIES wsl.exe;wslg.exe;wslhost.exe;wslrelay.exe;wslservice.exe;wslserviceproxystub.dll;init;initrd.img;wslinstall.dll;wslaserviceproxystub.dll;wslaservice.exe;wsladiag.exe) if (WSL_BUILD_WSL_SETTINGS) list(APPEND BINARIES_DEPENDENCIES "wslsettings/wslsettings.dll;wslsettings/wslsettings.exe;libwsl.dll") @@ -39,7 +39,7 @@ add_custom_command( add_custom_target(msipackage DEPENDS ${OUTPUT_PACKAGE}) set_target_properties(msipackage PROPERTIES EXCLUDE_FROM_ALL FALSE SOURCES ${PACKAGE_WIX_IN}) -add_dependencies(msipackage wsl wslg wslservice wslhost wslrelay wslserviceproxystub init initramfs wslinstall msixgluepackage wslaservice wslaserviceproxystub) +add_dependencies(msipackage wsl wslg wslservice wslhost wslrelay wslserviceproxystub init initramfs wslinstall msixgluepackage wslaservice wslaserviceproxystub wsladiag) if (WSL_BUILD_WSL_SETTINGS) add_dependencies(msipackage wslsettings libwsl) @@ -47,7 +47,7 @@ endif() set_source_files_properties(${OUTPUT_PACKAGE} PROPERTIES GENERATED TRUE) -if (DEFINED WSL_POST_BUILD_COMMAND) +if (DEFINED WSL_POST_BUILD_COMMAND AND NOT "${WSL_POST_BUILD_COMMAND}" STREQUAL "") add_custom_command( TARGET msipackage POST_BUILD diff --git a/msipackage/package.wix.in b/msipackage/package.wix.in index 277fd13..1ceca4e 100644 --- a/msipackage/package.wix.in +++ b/msipackage/package.wix.in @@ -27,6 +27,7 @@ + diff --git a/src/windows/common/WslTelemetry.cpp b/src/windows/common/WslTelemetry.cpp index 766513d..cac31bb 100644 --- a/src/windows/common/WslTelemetry.cpp +++ b/src/windows/common/WslTelemetry.cpp @@ -33,7 +33,7 @@ TRACELOGGING_DEFINE_PROVIDER( TraceLoggingOptionMicrosoftTelemetry()); TRACELOGGING_DEFINE_PROVIDER( - WslaServiceTelemetryProvider, + WslaTelemetryProvider, "Microsoft.Windows.Wsla", // {0383CE62-8F86-4766-AFB2-9D66A7FB1E90} (0x383ce62, 0x8f86, 0x4766, 0xaf, 0xb2, 0x9d, 0x66, 0xa7, 0xfb, 0x1e, 0x90), diff --git a/src/windows/common/WslTelemetry.h b/src/windows/common/WslTelemetry.h index 209dfd6..96ee3de 100644 --- a/src/windows/common/WslTelemetry.h +++ b/src/windows/common/WslTelemetry.h @@ -28,7 +28,7 @@ extern "C" { #endif TRACELOGGING_DECLARE_PROVIDER(LxssTelemetryProvider); TRACELOGGING_DECLARE_PROVIDER(WslServiceTelemetryProvider); -TRACELOGGING_DECLARE_PROVIDER(WslaServiceTelemetryProvider); +TRACELOGGING_DECLARE_PROVIDER(WslaTelemetryProvider); #ifdef __cplusplus } #endif diff --git a/src/windows/wsladiag/CMakeLists.txt b/src/windows/wsladiag/CMakeLists.txt new file mode 100644 index 0000000..b135c55 --- /dev/null +++ b/src/windows/wsladiag/CMakeLists.txt @@ -0,0 +1,14 @@ +set(SOURCES + wsladiag.cpp +) + +add_executable(wsladiag ${SOURCES}) + +target_link_libraries(wsladiag + ${COMMON_LINK_LIBRARIES} + common +) + +target_precompile_headers(wsladiag REUSE_FROM common) + +set_target_properties(wsladiag PROPERTIES FOLDER windows) \ No newline at end of file diff --git a/src/windows/wsladiag/wsladiag.cpp b/src/windows/wsladiag/wsladiag.cpp new file mode 100644 index 0000000..d31f382 --- /dev/null +++ b/src/windows/wsladiag/wsladiag.cpp @@ -0,0 +1,134 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + wsladiag.cpp + +Abstract: + + Entry point for the wsladiag tool, performs WSL runtime initialization and parses --list/--help. + +--*/ + +#include "precomp.h" +#include "CommandLine.h" +#include "wslutil.h" +#include "wslaservice.h" +#include "WslSecurity.h" + +using namespace wsl::shared; +namespace wslutil = wsl::windows::common::wslutil; + +int wsladiag_main(std::wstring_view commandLine) +{ + wslutil::ConfigureCrt(); + wslutil::InitializeWil(); + + WslTraceLoggingInitialize(WslaTelemetryProvider, !wsl::shared::OfficialBuild); + auto cleanupTelemetry = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, []() { WslTraceLoggingUninitialize(); }); + + wslutil::SetCrtEncoding(_O_U8TEXT); + + auto coInit = wil::CoInitializeEx(COINIT_MULTITHREADED); + wslutil::CoInitializeSecurity(); + + WSADATA data{}; + THROW_IF_WIN32_ERROR(WSAStartup(MAKEWORD(2, 2), &data)); + auto wsaCleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, []() { WSACleanup(); }); + + // Command-line parsing using ArgumentParser. + ArgumentParser parser(std::wstring{commandLine}, L"wsladiag"); + + bool help = false; + bool list = false; + + parser.AddArgument(list, L"--list"); + parser.AddArgument(help, L"--help", L'h'); // short option is a single wide char + parser.Parse(); + + auto printUsage = []() { + wslutil::PrintMessage( + L"wsladiag - WSLA diagnostics tool\n" + L"Usage:\n" + L" wsladiag --list List WSLA sessions\n" + L" wsladiag --help Show this help", + stderr); + }; + + // If '--help' was requested, print usage and exit. + if (help) + { + printUsage(); + return 0; + } + + if (!list) + { + // No recognized command → show usage + printUsage(); + return 0; + } + + // --list: Call WSLA service COM interface to retrieve and display sessions. + + try + { + wil::com_ptr userSession; + THROW_IF_FAILED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&userSession))); + + wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get()); + + wil::unique_cotaskmem_array_ptr sessions; + + THROW_IF_FAILED(userSession->ListSessions(&sessions, sessions.size_address())); + + if (sessions.size() == 0) + { + wslutil::PrintMessage(L"No WSLA sessions found.\n", stdout); + } + else + { + wslutil::PrintMessage(std::format(L"Found {} WSLA session{}:\n", sessions.size(), sessions.size() > 1 ? L"s" : L""), stdout); + + wslutil::PrintMessage(L"ID\tCreator PID\tDisplay Name\n", stdout); + wslutil::PrintMessage(L"--\t-----------\t------------\n", stdout); + + for (const auto& session : sessions) + { + const auto* displayName = session.DisplayName; + + wslutil::PrintMessage(std::format(L"{}\t{}\t\t{}\n", session.SessionId, session.CreatorPid, displayName), stdout); + } + } + + return 0; + } + catch (...) + { + const auto hr = wil::ResultFromCaughtException(); + const std::wstring hrMessage = wslutil::ErrorCodeToString(hr); + + if (!hrMessage.empty()) + { + wslutil::PrintMessage(std::format(L"Error listing WSLA sessions: 0x{:08x} - {}\n", static_cast(hr), hrMessage), stderr); + } + else + { + wslutil::PrintMessage(std::format(L"Error listing WSLA sessions: 0x{:08x}\n", static_cast(hr)), stderr); + } + + return 1; + } +} + +int wmain(int /*argc*/, wchar_t** /*argv*/) +{ + try + { + // Use raw Unicode command line so ArgumentParser gets original input. + return wsladiag_main(GetCommandLineW()); + } + CATCH_RETURN(); +} \ No newline at end of file diff --git a/src/windows/wslaservice/exe/ServiceMain.cpp b/src/windows/wslaservice/exe/ServiceMain.cpp index 3eeb06c..872ed43 100644 --- a/src/windows/wslaservice/exe/ServiceMain.cpp +++ b/src/windows/wslaservice/exe/ServiceMain.cpp @@ -67,7 +67,7 @@ try // Initialize telemetry. // TODO-WSLA: Create a dedicated WSLA provider - WslTraceLoggingInitialize(WslaServiceTelemetryProvider, !wsl::shared::OfficialBuild); + WslTraceLoggingInitialize(WslaTelemetryProvider, !wsl::shared::OfficialBuild); WSL_LOG("Service starting", TraceLoggingLevel(WINEVENT_LEVEL_INFO)); diff --git a/src/windows/wslaservice/exe/WSLASession.cpp b/src/windows/wslaservice/exe/WSLASession.cpp index 24a8048..19acca0 100644 --- a/src/windows/wslaservice/exe/WSLASession.cpp +++ b/src/windows/wslaservice/exe/WSLASession.cpp @@ -22,8 +22,9 @@ Abstract: using wsl::windows::service::wsla::WSLASession; using wsl::windows::service::wsla::WSLAVirtualMachine; -WSLASession::WSLASession(const WSLA_SESSION_SETTINGS& Settings, WSLAUserSessionImpl& userSessionImpl) : - m_sessionSettings(Settings), m_userSession(&userSessionImpl), m_displayName(Settings.DisplayName) +WSLASession::WSLASession(ULONG id, const WSLA_SESSION_SETTINGS& Settings, WSLAUserSessionImpl& userSessionImpl) : + + m_id(id), m_sessionSettings(Settings), m_userSession(&userSessionImpl), m_displayName(Settings.DisplayName) { WSL_LOG("SessionCreated", TraceLoggingValue(m_displayName.c_str(), "DisplayName")); @@ -192,17 +193,22 @@ void WSLASession::ConfigureStorage(const WSLA_SESSION_SETTINGS& Settings) deleteVhdOnFailure.release(); } -HRESULT WSLASession::GetDisplayName(LPWSTR* DisplayName) -{ - *DisplayName = wil::make_unique_string(m_displayName.c_str()).release(); - return S_OK; -} - const std::wstring& WSLASession::DisplayName() const { return m_displayName; } +ULONG WSLASession::GetId() const noexcept +{ + return m_id; +} + +void WSLASession::CopyDisplayName(_Out_writes_z_(bufferLength) PWSTR buffer, size_t bufferLength) const +{ + THROW_HR_IF(E_BOUNDS, m_displayName.size() + 1 > bufferLength); + wcscpy_s(buffer, bufferLength, m_displayName.c_str()); +} + HRESULT WSLASession::PullImage(LPCWSTR Image, const WSLA_REGISTRY_AUTHENTICATION_INFORMATION* RegistryInformation, IProgressCallback* ProgressCallback) { return E_NOTIMPL; diff --git a/src/windows/wslaservice/exe/WSLASession.h b/src/windows/wslaservice/exe/WSLASession.h index c34a0d2..a4b3dad 100644 --- a/src/windows/wslaservice/exe/WSLASession.h +++ b/src/windows/wslaservice/exe/WSLASession.h @@ -25,12 +25,16 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLASession : public Microsoft::WRL::RuntimeClass, IWSLASession, IFastRundown> { public: - WSLASession(const WSLA_SESSION_SETTINGS& Settings, WSLAUserSessionImpl& userSessionImpl); + WSLASession(ULONG id, const WSLA_SESSION_SETTINGS& Settings, WSLAUserSessionImpl& userSessionImpl); + ~WSLASession(); - IFACEMETHOD(GetDisplayName)(LPWSTR* DisplayName) override; + ULONG GetId() const noexcept; + const std::wstring& DisplayName() const; + void CopyDisplayName(_Out_writes_z_(bufferLength) PWSTR buffer, size_t bufferLength) const; + // Image management. IFACEMETHOD(PullImage)(_In_ LPCWSTR Image, _In_ const WSLA_REGISTRY_AUTHENTICATION_INFORMATION* RegistryInformation, _In_ IProgressCallback* ProgressCallback) override; IFACEMETHOD(ImportImage)(_In_ ULONG Handle, _In_ LPCWSTR Image, _In_ IProgressCallback* ProgressCallback) override; @@ -54,6 +58,8 @@ public: void OnUserSessionTerminating(); private: + ULONG m_id = 0; + static WSLAVirtualMachine::Settings CreateVmSettings(const WSLA_SESSION_SETTINGS& Settings); void ConfigureStorage(const WSLA_SESSION_SETTINGS& Settings); diff --git a/src/windows/wslaservice/exe/WSLAUserSession.cpp b/src/windows/wslaservice/exe/WSLAUserSession.cpp index ac2cd3c..7ab4fed 100644 --- a/src/windows/wslaservice/exe/WSLAUserSession.cpp +++ b/src/windows/wslaservice/exe/WSLAUserSession.cpp @@ -27,7 +27,7 @@ WSLAUserSessionImpl::~WSLAUserSessionImpl() // In case there are still COM references on sessions, signal that the user session is terminating // so the sessions are all in a 'terminated' state. { - std::lock_guard lock(m_lock); + std::lock_guard lock(m_wslaSessionsLock); for (auto& e : m_sessions) { @@ -38,7 +38,7 @@ WSLAUserSessionImpl::~WSLAUserSessionImpl() void WSLAUserSessionImpl::OnSessionTerminated(WSLASession* Session) { - std::lock_guard lock(m_lock); + std::lock_guard lock(m_wslaSessionsLock); WI_VERIFY(m_sessions.erase(Session) == 1); } @@ -49,7 +49,8 @@ PSID WSLAUserSessionImpl::GetUserSid() const HRESULT WSLAUserSessionImpl::CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** WslaSession) { - auto session = wil::MakeOrThrow(*Settings, *this); + ULONG id = m_nextSessionId++; + auto session = wil::MakeOrThrow(id, *Settings, *this); std::lock_guard lock(m_wslaSessionsLock); auto it = m_sessions.emplace(session.Get()); @@ -80,6 +81,26 @@ HRESULT WSLAUserSessionImpl::OpenSessionByName(LPCWSTR DisplayName, IWSLASession return HRESULT_FROM_WIN32(ERROR_NOT_FOUND); } +HRESULT wsl::windows::service::wsla::WSLAUserSessionImpl::ListSessions(_Out_ WSLA_SESSION_INFORMATION** Sessions, _Out_ ULONG* SessionsCount) +{ + std::lock_guard lock(m_wslaSessionsLock); + auto output = wil::make_unique_cotaskmem(m_sessions.size()); + + size_t index = 0; + for (auto* session : m_sessions) + { + output[index].SessionId = session->GetId(); + output[index].CreatorPid = 0; // placeholder until we populate this later + + session->CopyDisplayName(output[index].DisplayName, _countof(output[index].DisplayName)); + + ++index; + } + *Sessions = output.release(); + *SessionsCount = static_cast(m_sessions.size()); + return S_OK; +} + wsl::windows::service::wsla::WSLAUserSession::WSLAUserSession(std::weak_ptr&& Session) : m_session(std::move(Session)) { @@ -105,9 +126,20 @@ try CATCH_RETURN(); HRESULT wsl::windows::service::wsla::WSLAUserSession::ListSessions(WSLA_SESSION_INFORMATION** Sessions, ULONG* SessionsCount) +try { - return E_NOTIMPL; + if (!Sessions || !SessionsCount) + { + return E_INVALIDARG; + } + + auto session = m_session.lock(); + RETURN_HR_IF(RPC_E_DISCONNECTED, !session); + + RETURN_IF_FAILED(session->ListSessions(Sessions, SessionsCount)); + return S_OK; } +CATCH_RETURN(); HRESULT wsl::windows::service::wsla::WSLAUserSession::OpenSession(ULONG Id, IWSLASession** Session) { diff --git a/src/windows/wslaservice/exe/WSLAUserSession.h b/src/windows/wslaservice/exe/WSLAUserSession.h index be8c9a5..1ef6b13 100644 --- a/src/windows/wslaservice/exe/WSLAUserSession.h +++ b/src/windows/wslaservice/exe/WSLAUserSession.h @@ -15,6 +15,10 @@ Abstract: #pragma once #include "WSLAVirtualMachine.h" #include "WSLASession.h" +#include +#include +#include +#include namespace wsl::windows::service::wsla { @@ -31,14 +35,15 @@ public: HRESULT CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** WslaSession); HRESULT OpenSessionByName(_In_ LPCWSTR DisplayName, _Out_ IWSLASession** Session); + HRESULT ListSessions(_Out_ WSLA_SESSION_INFORMATION** Sessions, _Out_ ULONG* SessionsCount); void OnSessionTerminated(WSLASession* Session); private: wil::unique_tokeninfo_ptr m_tokenInfo; + std::atomic m_nextSessionId{1}; std::recursive_mutex m_wslaSessionsLock; - std::recursive_mutex m_lock; // TODO-WSLA: Consider using a weak_ptr to easily destroy when the last client reference is released. std::unordered_set m_sessions; diff --git a/src/windows/wslaservice/inc/wslaservice.idl b/src/windows/wslaservice/inc/wslaservice.idl index 957882c..2e31b20 100644 --- a/src/windows/wslaservice/inc/wslaservice.idl +++ b/src/windows/wslaservice/inc/wslaservice.idl @@ -299,15 +299,14 @@ interface IWSLASession : IUnknown HRESULT Shutdown([in] ULONG TimeoutMs); // To be deleted. - HRESULT GetDisplayName([out] LPWSTR* DisplayName); HRESULT GetVirtualMachine([out] IWSLAVirtualMachine **VirtualMachine); } struct WSLA_SESSION_INFORMATION { - ULONG Id; + ULONG SessionId; DWORD CreatorPid; - LPSTR DisplayName; + wchar_t DisplayName[256]; }; [