diff --git a/msipackage/package.wix.in b/msipackage/package.wix.in index 5a7d1fd..3bdfd34 100644 --- a/msipackage/package.wix.in +++ b/msipackage/package.wix.in @@ -264,6 +264,12 @@ + + + + + + @@ -288,6 +294,13 @@ + + + + + + + diff --git a/src/windows/common/WslSecurity.cpp b/src/windows/common/WslSecurity.cpp index be1ff0a..85fd8b5 100644 --- a/src/windows/common/WslSecurity.cpp +++ b/src/windows/common/WslSecurity.cpp @@ -99,6 +99,21 @@ wil::unique_handle wsl::windows::common::security::CreateRestrictedToken(_In_ HA return restrictedToken; } +void wsl::windows::common::security::ConfigureForCOMImpersonation(IUnknown* Instance) +{ + wil::com_ptr_nothrow clientSecurity; + THROW_IF_FAILED(Instance->QueryInterface(IID_PPV_ARGS(&clientSecurity))); + + // Get the current proxy blanket settings. + DWORD authnSvc, authzSvc, authnLvl, capabilites; + THROW_IF_FAILED(clientSecurity->QueryBlanket(Instance, &authnSvc, &authzSvc, NULL, &authnLvl, NULL, NULL, &capabilites)); + + // Make sure that dynamic cloaking is used. + WI_ClearFlag(capabilites, EOAC_STATIC_CLOAKING); + WI_SetFlag(capabilites, EOAC_DYNAMIC_CLOAKING); + THROW_IF_FAILED(clientSecurity->SetBlanket(Instance, authnSvc, authzSvc, NULL, authnLvl, RPC_C_IMP_LEVEL_IMPERSONATE, NULL, capabilites)); +} + LUID wsl::windows::common::security::EnableTokenPrivilege(_Inout_ HANDLE token, _In_ LPCWSTR privilegeName) { // Convert privilege name to an LUID. diff --git a/src/windows/common/WslSecurity.h b/src/windows/common/WslSecurity.h index 2b2f214..daefe1f 100644 --- a/src/windows/common/WslSecurity.h +++ b/src/windows/common/WslSecurity.h @@ -87,6 +87,11 @@ std::pair> CreateSid(SID_IDENTIFIER_AUTHORITY Authority, /// wil::unique_handle CreateRestrictedToken(_In_ HANDLE token); +/// +/// Configures a COM object for impersonation. +/// +void ConfigureForCOMImpersonation(IUnknown* instance); + /// /// Enables a privilege on the token. /// diff --git a/src/windows/wslaclient/DllMain.cpp b/src/windows/wslaclient/DllMain.cpp index 65d5044..012246b 100644 --- a/src/windows/wslaclient/DllMain.cpp +++ b/src/windows/wslaclient/DllMain.cpp @@ -18,24 +18,6 @@ Abstract: #include "wslrelay.h" #include "wslInstall.h" -namespace { - -void ConfigureComSecurity(IUnknown* Instance) -{ - wil::com_ptr_nothrow clientSecurity; - THROW_IF_FAILED(Instance->QueryInterface(IID_PPV_ARGS(&clientSecurity))); - - // Get the current proxy blanket settings. - DWORD authnSvc, authzSvc, authnLvl, capabilites; - THROW_IF_FAILED(clientSecurity->QueryBlanket(Instance, &authnSvc, &authzSvc, NULL, &authnLvl, NULL, NULL, &capabilites)); - - // Make sure that dynamic cloaking is used. - WI_ClearFlag(capabilites, EOAC_STATIC_CLOAKING); - WI_SetFlag(capabilites, EOAC_DYNAMIC_CLOAKING); - THROW_IF_FAILED(clientSecurity->SetBlanket(Instance, authnSvc, authzSvc, NULL, authnLvl, RPC_C_IMP_LEVEL_IMPERSONATE, NULL, capabilites)); -} -} // namespace - class DECLSPEC_UUID("7BC4E198-6531-4FA6-ADE2-5EF3D2A04DFF") CallbackInstance : public Microsoft::WRL::RuntimeClass, ITerminationCallback, IFastRundown> { @@ -74,7 +56,7 @@ try wil::com_ptr session; THROW_IF_FAILED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&session))); - ConfigureComSecurity(session.get()); + wsl::windows::common::security::ConfigureForCOMImpersonation(session.get()); wil::com_ptr virtualMachineInstance; @@ -91,7 +73,7 @@ try settings.EnableGPU = UserSettings->GPU.Enable; THROW_IF_FAILED(session->CreateVirtualMachine(&settings, &virtualMachineInstance)); - ConfigureComSecurity(virtualMachineInstance.get()); + wsl::windows::common::security::ConfigureForCOMImpersonation(virtualMachineInstance.get()); // Register termination callback, if specified if (UserSettings->Options.TerminationCallback != nullptr) diff --git a/src/windows/wslaservice/exe/CMakeLists.txt b/src/windows/wslaservice/exe/CMakeLists.txt index 8d51923..46260d1 100644 --- a/src/windows/wslaservice/exe/CMakeLists.txt +++ b/src/windows/wslaservice/exe/CMakeLists.txt @@ -2,12 +2,14 @@ set(SOURCES application.manifest main.rc ServiceMain.cpp + WSLASession.cpp WSLAUserSession.cpp WSLAUserSessionFactory.cpp WSLAVirtualMachine.cpp ) set(HEADERS + WSLASession.h WSLAUserSession.h WSLAUserSessionFactory.h WSLAVirtualMachine.h) diff --git a/src/windows/wslaservice/exe/WSLASession.cpp b/src/windows/wslaservice/exe/WSLASession.cpp new file mode 100644 index 0000000..617ec40 --- /dev/null +++ b/src/windows/wslaservice/exe/WSLASession.cpp @@ -0,0 +1,27 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + WSLASession.cpp + +Abstract: + + This file contains the implementation of the WSLASession COM class. + +--*/ + +#include "WSLASession.h" + +using wsl::windows::service::wsla::WSLASession; + +WSLASession::WSLASession(const WSLA_SESSION_SETTINGS& Settings) : m_displayName(Settings.DisplayName) +{ +} + +HRESULT WSLASession::GetDisplayName(LPWSTR* DisplayName) +{ + *DisplayName = wil::make_unique_string(m_displayName.c_str()).release(); + return S_OK; +} diff --git a/src/windows/wslaservice/exe/WSLASession.h b/src/windows/wslaservice/exe/WSLASession.h new file mode 100644 index 0000000..55fcbc0 --- /dev/null +++ b/src/windows/wslaservice/exe/WSLASession.h @@ -0,0 +1,32 @@ +/*++ + +Copyright (c) Microsoft. All rights reserved. + +Module Name: + + WSLASession.h + +Abstract: + + TODO + +--*/ + +#pragma once + +#include "wslaservice.h" + +namespace wsl::windows::service::wsla { + +class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLASession + : public Microsoft::WRL::RuntimeClass, IWSLASession, IFastRundown> +{ +public: + WSLASession(const WSLA_SESSION_SETTINGS& Settings); + IFACEMETHOD(GetDisplayName)(LPWSTR* DisplayName); + +private: + std::wstring m_displayName; +}; + +} // namespace wsl::windows::service::wsla \ No newline at end of file diff --git a/src/windows/wslaservice/exe/WSLAUserSession.cpp b/src/windows/wslaservice/exe/WSLAUserSession.cpp index deb6599..d75ba00 100644 --- a/src/windows/wslaservice/exe/WSLAUserSession.cpp +++ b/src/windows/wslaservice/exe/WSLAUserSession.cpp @@ -11,7 +11,9 @@ Abstract: TODO --*/ + #include "WSLAUserSession.h" +#include "WSLASession.h" using wsl::windows::service::wsla::WSLAUserSessionImpl; @@ -24,7 +26,7 @@ WSLAUserSessionImpl::~WSLAUserSessionImpl() { // Manually signal the VM termination events. This prevents being stuck on an API call that holds the VM lock. { - std::lock_guard lock(m_virtualMachinesLock); + std::lock_guard lock(m_lock); for (auto* e : m_virtualMachines) { @@ -35,7 +37,7 @@ WSLAUserSessionImpl::~WSLAUserSessionImpl() void WSLAUserSessionImpl::OnVmTerminated(WSLAVirtualMachine* machine) { - std::lock_guard lock(m_virtualMachinesLock); + std::lock_guard lock(m_lock); auto pred = [machine](const auto* e) { return machine == e; }; // Remove any stale VM reference. @@ -47,7 +49,7 @@ HRESULT WSLAUserSessionImpl::CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS auto vm = wil::MakeOrThrow(*Settings, GetUserSid(), this); { - std::lock_guard lock(m_virtualMachinesLock); + std::lock_guard lock(m_lock); m_virtualMachines.emplace_back(vm.Get()); } @@ -62,6 +64,20 @@ PSID WSLAUserSessionImpl::GetUserSid() const return m_tokenInfo->User.Sid; } +HRESULT wsl::windows::service::wsla::WSLAUserSessionImpl::CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** Session) +{ + auto session = wil::MakeOrThrow(*Settings); + + { + std::lock_guard lock(m_lock); + m_sessions.emplace_back(session.Get()); + } + + THROW_IF_FAILED(session.CopyTo(__uuidof(IWSLASession), (void**)Session)); + + return S_OK; +} + wsl::windows::service::wsla::WSLAUserSession::WSLAUserSession(std::weak_ptr&& Session) : m_session(std::move(Session)) { @@ -84,4 +100,14 @@ try return session->CreateVirtualMachine(Settings, VirtualMachine); } -CATCH_RETURN(); \ No newline at end of file +CATCH_RETURN(); + +HRESULT wsl::windows::service::wsla::WSLAUserSession::CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** Session) +try +{ + auto session = m_session.lock(); + RETURN_HR_IF(RPC_E_DISCONNECTED, !session); + + return session->CreateSession(Settings, Session); +} +CATCH_RETURN(); diff --git a/src/windows/wslaservice/exe/WSLAUserSession.h b/src/windows/wslaservice/exe/WSLAUserSession.h index e1c1762..391b023 100644 --- a/src/windows/wslaservice/exe/WSLAUserSession.h +++ b/src/windows/wslaservice/exe/WSLAUserSession.h @@ -11,8 +11,10 @@ Abstract: TODO --*/ + #pragma once #include "WSLAVirtualMachine.h" +#include "WSLASession.h" namespace wsl::windows::service::wsla { @@ -28,14 +30,18 @@ public: PSID GetUserSid() const; HRESULT CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine); + HRESULT CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** Session); void OnVmTerminated(WSLAVirtualMachine* machine); private: wil::unique_tokeninfo_ptr m_tokenInfo; - std::recursive_mutex m_virtualMachinesLock; + std::recursive_mutex m_lock; std::vector m_virtualMachines; + + // TODO-WSLA: Consider using a weak_ptr to easily destroy when the last client reference is released. + std::vector> m_sessions; }; class DECLSPEC_UUID("a9b7a1b9-0671-405c-95f1-e0612cb4ce8f") WSLAUserSession @@ -48,6 +54,7 @@ public: IFACEMETHOD(GetVersion)(_Out_ WSL_VERSION* Version) override; IFACEMETHOD(CreateVirtualMachine)(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine) override; + IFACEMETHOD(CreateSession)(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** Session); private: std::weak_ptr m_session; diff --git a/src/windows/wslaservice/inc/wslaservice.idl b/src/windows/wslaservice/inc/wslaservice.idl index 831beff..3698719 100644 --- a/src/windows/wslaservice/inc/wslaservice.idl +++ b/src/windows/wslaservice/inc/wslaservice.idl @@ -102,6 +102,23 @@ struct _VIRTUAL_MACHINE_SETTINGS { } VIRTUAL_MACHINE_SETTINGS; +typedef +struct _WSLA_SESSION_SETTINGS { + LPCWSTR DisplayName; + // Details TBD. +} WSLA_SESSION_SETTINGS; + + +[ + uuid(EF0661E4-6364-40EA-B433-E2FDF11F3519), + pointer_default(unique), + object +] +interface IWSLASession : IUnknown +{ + HRESULT GetDisplayName([out] LPWSTR* DisplayName); +} + [ uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8760), pointer_default(unique), @@ -111,4 +128,5 @@ interface IWSLAUserSession : IUnknown { HRESULT GetVersion([out] WSL_VERSION* Error); HRESULT CreateVirtualMachine([in] const VIRTUAL_MACHINE_SETTINGS* Settings, [out]IWSLAVirtualMachine** VirtualMachine); + HRESULT CreateSession([in] const WSLA_SESSION_SETTINGS* Settings, [out]IWSLASession** Session); } \ No newline at end of file diff --git a/test/windows/WSLATests.cpp b/test/windows/WSLATests.cpp index ef1e5b2..e27a8b6 100644 --- a/test/windows/WSLATests.cpp +++ b/test/windows/WSLATests.cpp @@ -15,6 +15,7 @@ Abstract: #include "precomp.h" #include "Common.h" #include "WSLAApi.h" +#include "wslaservice.h" using namespace wsl::windows::common::registry; @@ -1067,4 +1068,21 @@ class WSLATests // Validate that xsk_diag is now loaded. VERIFY_ARE_EQUAL(RunCommand(vm.get(), {"/bin/bash", "-c", "lsmod | grep ^xsk_diag"}), 0); } + + TEST_METHOD(CreateSessionSmokeTest) + { + wil::com_ptr userSession; + VERIFY_SUCCEEDED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&userSession))); + wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get()); + + WSLA_SESSION_SETTINGS settings{L"my-display-name"}; + wil::com_ptr session; + + VERIFY_SUCCEEDED(userSession->CreateSession(&settings, &session)); + + wil::unique_cotaskmem_string returnedDisplayName; + VERIFY_SUCCEEDED(session->GetDisplayName(&returnedDisplayName)); + + VERIFY_ARE_EQUAL(returnedDisplayName.get(), std::wstring(L"my-display-name")); + } }; \ No newline at end of file