Create structure for WSLASession COM class & implementation (#13638)

* Create structure for WSLASession COM class & implementation

* Fix COM security pointer

* Add missing files

* Add copyright header

* Update src/windows/wslaservice/exe/WSLASession.cpp

Co-authored-by: Pooja Trivedi <poojatrivedi@gmail.com>

---------

Co-authored-by: Pooja Trivedi <poojatrivedi@gmail.com>
This commit is contained in:
Blue 2025-10-28 14:44:33 -07:00 committed by GitHub
parent 451a7e103a
commit cedbb94d36
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 170 additions and 25 deletions

View File

@ -264,6 +264,12 @@
<RegistryValue Value="WSLAVirtualMachine" Type="string" />
</RegistryKey>
<!-- WSLASession -->
<RegistryKey Root="HKCR" Key="CLSID\{4877FEFC-4977-4929-A958-9F36AA1892A4}">
<RegistryValue Name="AppId" Value="{E9B79997-57E3-4201-AECC-6A464E530DD2}" Type="string" />
<RegistryValue Value="WSLASession" Type="string" />
</RegistryKey>
<!-- IWSLAUserSession-->
<RegistryKey Root="HKCR" Key="Interface\{82A7ABC8-6B50-43FC-AB96-15FBBE7E8760}">
<RegistryValue Value="IWSLAUserSession" Type="string" />
@ -288,6 +294,13 @@
</RegistryKey>
</RegistryKey>
<!-- IWSLASession-->
<RegistryKey Root="HKCR" Key="Interface\{EF0661E4-6364-40EA-B433-E2FDF11F3519}">
<RegistryValue Value="IWSLASession" Type="string" />
<RegistryKey Key="ProxyStubClsid32">
<RegistryValue Value="{4EA0C6DD-E9FF-48E7-994E-13A31D10DC61}" Type="string" />
</RegistryKey>
</RegistryKey>
<File Id="wslaservice.exe" Source="${BIN}/wslaservice.exe" KeyPath="yes" />
<File Id="wslaserviceproxystub.dll" Name="wslaserviceproxystub.dll" Source="${BIN}/wslaserviceproxystub.dll" />

View File

@ -99,6 +99,21 @@ wil::unique_handle wsl::windows::common::security::CreateRestrictedToken(_In_ HA
return restrictedToken;
}
void wsl::windows::common::security::ConfigureForCOMImpersonation(IUnknown* Instance)
{
wil::com_ptr_nothrow<IClientSecurity> clientSecurity;
THROW_IF_FAILED(Instance->QueryInterface(IID_PPV_ARGS(&clientSecurity)));
// Get the current proxy blanket settings.
DWORD authnSvc, authzSvc, authnLvl, capabilites;
THROW_IF_FAILED(clientSecurity->QueryBlanket(Instance, &authnSvc, &authzSvc, NULL, &authnLvl, NULL, NULL, &capabilites));
// Make sure that dynamic cloaking is used.
WI_ClearFlag(capabilites, EOAC_STATIC_CLOAKING);
WI_SetFlag(capabilites, EOAC_DYNAMIC_CLOAKING);
THROW_IF_FAILED(clientSecurity->SetBlanket(Instance, authnSvc, authzSvc, NULL, authnLvl, RPC_C_IMP_LEVEL_IMPERSONATE, NULL, capabilites));
}
LUID wsl::windows::common::security::EnableTokenPrivilege(_Inout_ HANDLE token, _In_ LPCWSTR privilegeName)
{
// Convert privilege name to an LUID.

View File

@ -87,6 +87,11 @@ std::pair<PSID, std::vector<char>> CreateSid(SID_IDENTIFIER_AUTHORITY Authority,
/// </summary>
wil::unique_handle CreateRestrictedToken(_In_ HANDLE token);
/// <summary>
/// Configures a COM object for impersonation.
/// <summary>
void ConfigureForCOMImpersonation(IUnknown* instance);
/// <summary>
/// Enables a privilege on the token.
/// </summary>

View File

@ -18,24 +18,6 @@ Abstract:
#include "wslrelay.h"
#include "wslInstall.h"
namespace {
void ConfigureComSecurity(IUnknown* Instance)
{
wil::com_ptr_nothrow<IClientSecurity> clientSecurity;
THROW_IF_FAILED(Instance->QueryInterface(IID_PPV_ARGS(&clientSecurity)));
// Get the current proxy blanket settings.
DWORD authnSvc, authzSvc, authnLvl, capabilites;
THROW_IF_FAILED(clientSecurity->QueryBlanket(Instance, &authnSvc, &authzSvc, NULL, &authnLvl, NULL, NULL, &capabilites));
// Make sure that dynamic cloaking is used.
WI_ClearFlag(capabilites, EOAC_STATIC_CLOAKING);
WI_SetFlag(capabilites, EOAC_DYNAMIC_CLOAKING);
THROW_IF_FAILED(clientSecurity->SetBlanket(Instance, authnSvc, authzSvc, NULL, authnLvl, RPC_C_IMP_LEVEL_IMPERSONATE, NULL, capabilites));
}
} // namespace
class DECLSPEC_UUID("7BC4E198-6531-4FA6-ADE2-5EF3D2A04DFF") CallbackInstance
: public Microsoft::WRL::RuntimeClass<Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, ITerminationCallback, IFastRundown>
{
@ -74,7 +56,7 @@ try
wil::com_ptr<IWSLAUserSession> session;
THROW_IF_FAILED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&session)));
ConfigureComSecurity(session.get());
wsl::windows::common::security::ConfigureForCOMImpersonation(session.get());
wil::com_ptr<IWSLAVirtualMachine> virtualMachineInstance;
@ -91,7 +73,7 @@ try
settings.EnableGPU = UserSettings->GPU.Enable;
THROW_IF_FAILED(session->CreateVirtualMachine(&settings, &virtualMachineInstance));
ConfigureComSecurity(virtualMachineInstance.get());
wsl::windows::common::security::ConfigureForCOMImpersonation(virtualMachineInstance.get());
// Register termination callback, if specified
if (UserSettings->Options.TerminationCallback != nullptr)

View File

@ -2,12 +2,14 @@ set(SOURCES
application.manifest
main.rc
ServiceMain.cpp
WSLASession.cpp
WSLAUserSession.cpp
WSLAUserSessionFactory.cpp
WSLAVirtualMachine.cpp
)
set(HEADERS
WSLASession.h
WSLAUserSession.h
WSLAUserSessionFactory.h
WSLAVirtualMachine.h)

View File

@ -0,0 +1,27 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
WSLASession.cpp
Abstract:
This file contains the implementation of the WSLASession COM class.
--*/
#include "WSLASession.h"
using wsl::windows::service::wsla::WSLASession;
WSLASession::WSLASession(const WSLA_SESSION_SETTINGS& Settings) : m_displayName(Settings.DisplayName)
{
}
HRESULT WSLASession::GetDisplayName(LPWSTR* DisplayName)
{
*DisplayName = wil::make_unique_string<wil::unique_cotaskmem_string>(m_displayName.c_str()).release();
return S_OK;
}

View File

@ -0,0 +1,32 @@
/*++
Copyright (c) Microsoft. All rights reserved.
Module Name:
WSLASession.h
Abstract:
TODO
--*/
#pragma once
#include "wslaservice.h"
namespace wsl::windows::service::wsla {
class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLASession
: public Microsoft::WRL::RuntimeClass<Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, IWSLASession, IFastRundown>
{
public:
WSLASession(const WSLA_SESSION_SETTINGS& Settings);
IFACEMETHOD(GetDisplayName)(LPWSTR* DisplayName);
private:
std::wstring m_displayName;
};
} // namespace wsl::windows::service::wsla

View File

@ -11,7 +11,9 @@ Abstract:
TODO
--*/
#include "WSLAUserSession.h"
#include "WSLASession.h"
using wsl::windows::service::wsla::WSLAUserSessionImpl;
@ -24,7 +26,7 @@ WSLAUserSessionImpl::~WSLAUserSessionImpl()
{
// Manually signal the VM termination events. This prevents being stuck on an API call that holds the VM lock.
{
std::lock_guard lock(m_virtualMachinesLock);
std::lock_guard lock(m_lock);
for (auto* e : m_virtualMachines)
{
@ -35,7 +37,7 @@ WSLAUserSessionImpl::~WSLAUserSessionImpl()
void WSLAUserSessionImpl::OnVmTerminated(WSLAVirtualMachine* machine)
{
std::lock_guard lock(m_virtualMachinesLock);
std::lock_guard lock(m_lock);
auto pred = [machine](const auto* e) { return machine == e; };
// Remove any stale VM reference.
@ -47,7 +49,7 @@ HRESULT WSLAUserSessionImpl::CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS
auto vm = wil::MakeOrThrow<WSLAVirtualMachine>(*Settings, GetUserSid(), this);
{
std::lock_guard lock(m_virtualMachinesLock);
std::lock_guard lock(m_lock);
m_virtualMachines.emplace_back(vm.Get());
}
@ -62,6 +64,20 @@ PSID WSLAUserSessionImpl::GetUserSid() const
return m_tokenInfo->User.Sid;
}
HRESULT wsl::windows::service::wsla::WSLAUserSessionImpl::CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** Session)
{
auto session = wil::MakeOrThrow<WSLASession>(*Settings);
{
std::lock_guard lock(m_lock);
m_sessions.emplace_back(session.Get());
}
THROW_IF_FAILED(session.CopyTo(__uuidof(IWSLASession), (void**)Session));
return S_OK;
}
wsl::windows::service::wsla::WSLAUserSession::WSLAUserSession(std::weak_ptr<WSLAUserSessionImpl>&& Session) :
m_session(std::move(Session))
{
@ -84,4 +100,14 @@ try
return session->CreateVirtualMachine(Settings, VirtualMachine);
}
CATCH_RETURN();
CATCH_RETURN();
HRESULT wsl::windows::service::wsla::WSLAUserSession::CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** Session)
try
{
auto session = m_session.lock();
RETURN_HR_IF(RPC_E_DISCONNECTED, !session);
return session->CreateSession(Settings, Session);
}
CATCH_RETURN();

View File

@ -11,8 +11,10 @@ Abstract:
TODO
--*/
#pragma once
#include "WSLAVirtualMachine.h"
#include "WSLASession.h"
namespace wsl::windows::service::wsla {
@ -28,14 +30,18 @@ public:
PSID GetUserSid() const;
HRESULT CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine);
HRESULT CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** Session);
void OnVmTerminated(WSLAVirtualMachine* machine);
private:
wil::unique_tokeninfo_ptr<TOKEN_USER> m_tokenInfo;
std::recursive_mutex m_virtualMachinesLock;
std::recursive_mutex m_lock;
std::vector<WSLAVirtualMachine*> m_virtualMachines;
// TODO-WSLA: Consider using a weak_ptr to easily destroy when the last client reference is released.
std::vector<Microsoft::WRL::ComPtr<WSLASession>> m_sessions;
};
class DECLSPEC_UUID("a9b7a1b9-0671-405c-95f1-e0612cb4ce8f") WSLAUserSession
@ -48,6 +54,7 @@ public:
IFACEMETHOD(GetVersion)(_Out_ WSL_VERSION* Version) override;
IFACEMETHOD(CreateVirtualMachine)(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine) override;
IFACEMETHOD(CreateSession)(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** Session);
private:
std::weak_ptr<WSLAUserSessionImpl> m_session;

View File

@ -102,6 +102,23 @@ struct _VIRTUAL_MACHINE_SETTINGS {
} VIRTUAL_MACHINE_SETTINGS;
typedef
struct _WSLA_SESSION_SETTINGS {
LPCWSTR DisplayName;
// Details TBD.
} WSLA_SESSION_SETTINGS;
[
uuid(EF0661E4-6364-40EA-B433-E2FDF11F3519),
pointer_default(unique),
object
]
interface IWSLASession : IUnknown
{
HRESULT GetDisplayName([out] LPWSTR* DisplayName);
}
[
uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8760),
pointer_default(unique),
@ -111,4 +128,5 @@ interface IWSLAUserSession : IUnknown
{
HRESULT GetVersion([out] WSL_VERSION* Error);
HRESULT CreateVirtualMachine([in] const VIRTUAL_MACHINE_SETTINGS* Settings, [out]IWSLAVirtualMachine** VirtualMachine);
HRESULT CreateSession([in] const WSLA_SESSION_SETTINGS* Settings, [out]IWSLASession** Session);
}

View File

@ -15,6 +15,7 @@ Abstract:
#include "precomp.h"
#include "Common.h"
#include "WSLAApi.h"
#include "wslaservice.h"
using namespace wsl::windows::common::registry;
@ -1067,4 +1068,21 @@ class WSLATests
// Validate that xsk_diag is now loaded.
VERIFY_ARE_EQUAL(RunCommand(vm.get(), {"/bin/bash", "-c", "lsmod | grep ^xsk_diag"}), 0);
}
TEST_METHOD(CreateSessionSmokeTest)
{
wil::com_ptr<IWSLAUserSession> userSession;
VERIFY_SUCCEEDED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&userSession)));
wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get());
WSLA_SESSION_SETTINGS settings{L"my-display-name"};
wil::com_ptr<IWSLASession> session;
VERIFY_SUCCEEDED(userSession->CreateSession(&settings, &session));
wil::unique_cotaskmem_string returnedDisplayName;
VERIFY_SUCCEEDED(session->GetDisplayName(&returnedDisplayName));
VERIFY_ARE_EQUAL(returnedDisplayName.get(), std::wstring(L"my-display-name"));
}
};