WSLAUserSession and WSLASession changes

* Add WSLASession creation and tracking in WSLAUserSession
* Wire up WSLASession to contain and create WSLAVirtualMachine
This commit is contained in:
Pooja Trivedi 2025-11-05 10:26:09 -05:00
parent 46db45c39a
commit ab85d1ae21
6 changed files with 46 additions and 15 deletions

View File

@ -13,15 +13,28 @@ Abstract:
--*/
#include "WSLASession.h"
#include "WSLAUserSession.h"
using wsl::windows::service::wsla::WSLASession;
WSLASession::WSLASession(const WSLA_SESSION_SETTINGS& Settings) : m_displayName(Settings.DisplayName)
WSLASession::WSLASession(const WSLA_SESSION_SETTINGS& Settings, WSLAUserSessionImpl& userSessionImpl, const VIRTUAL_MACHINE_SETTINGS& VmSettings) :
m_sessionSettings(Settings),
m_userSession(userSessionImpl),
m_virtualMachine(wil::MakeOrThrow<WSLAVirtualMachine>(VmSettings, userSessionImpl.GetUserSid(), &userSessionImpl))
{
m_virtualMachine->Start();
}
HRESULT WSLASession::GetDisplayName(LPWSTR* DisplayName)
{
*DisplayName = wil::make_unique_string<wil::unique_cotaskmem_string>(m_displayName.c_str()).release();
*DisplayName = wil::make_unique_string<wil::unique_cotaskmem_string>(m_sessionSettings.DisplayName).release();
return S_OK;
}
HRESULT WSLASession::GetVirtualMachine(IWSLAVirtualMachine** VirtualMachine)
{
THROW_IF_FAILED(m_virtualMachine.CopyTo(__uuidof(IWSLAVirtualMachine), (void**)VirtualMachine));
return S_OK;
}

View File

@ -15,6 +15,7 @@ Abstract:
#pragma once
#include "wslaservice.h"
#include "WSLAVirtualMachine.h"
namespace wsl::windows::service::wsla {
@ -22,11 +23,14 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLASession
: public Microsoft::WRL::RuntimeClass<Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, IWSLASession, IFastRundown>
{
public:
WSLASession(const WSLA_SESSION_SETTINGS& Settings);
WSLASession(const WSLA_SESSION_SETTINGS& Settings, WSLAUserSessionImpl& userSessionImpl, const VIRTUAL_MACHINE_SETTINGS& VmSettings);
IFACEMETHOD(GetDisplayName)(LPWSTR* DisplayName);
IFACEMETHOD(GetVirtualMachine)(IWSLAVirtualMachine** VirtualMachine);
private:
std::wstring m_displayName;
WSLA_SESSION_SETTINGS m_sessionSettings;
WSLAUserSessionImpl& m_userSession;
Microsoft::WRL::ComPtr<WSLAVirtualMachine> m_virtualMachine;
};
} // namespace wsl::windows::service::wsla

View File

@ -64,16 +64,17 @@ PSID WSLAUserSessionImpl::GetUserSid() const
return m_tokenInfo->User.Sid;
}
HRESULT wsl::windows::service::wsla::WSLAUserSessionImpl::CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** Session)
HRESULT wsl::windows::service::wsla::WSLAUserSessionImpl::CreateSession(
const WSLA_SESSION_SETTINGS* Settings, const VIRTUAL_MACHINE_SETTINGS* VmSettings, IWSLASession** WslaSession)
{
auto session = wil::MakeOrThrow<WSLASession>(*Settings);
auto session = wil::MakeOrThrow<WSLASession>(*Settings, *this, *VmSettings);
{
std::lock_guard lock(m_lock);
m_sessions.emplace_back(session.Get());
std::lock_guard lock(m_wslaSessionsLock);
m_wslaSessions.emplace_back(session.Get());
}
THROW_IF_FAILED(session.CopyTo(__uuidof(IWSLASession), (void**)Session));
THROW_IF_FAILED(session.CopyTo(__uuidof(IWSLASession), (void**)WslaSession));
return S_OK;
}
@ -102,12 +103,13 @@ try
}
CATCH_RETURN();
HRESULT wsl::windows::service::wsla::WSLAUserSession::CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** Session)
HRESULT wsl::windows::service::wsla::WSLAUserSession::CreateSession(
const WSLA_SESSION_SETTINGS* Settings, const VIRTUAL_MACHINE_SETTINGS* VmSettings, IWSLASession** WslaSession)
try
{
auto session = m_session.lock();
RETURN_HR_IF(RPC_E_DISCONNECTED, !session);
return session->CreateSession(Settings, Session);
return session->CreateSession(Settings, VmSettings, WslaSession);
}
CATCH_RETURN();

View File

@ -30,13 +30,16 @@ public:
PSID GetUserSid() const;
HRESULT CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine);
HRESULT CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** Session);
HRESULT CreateSession(const WSLA_SESSION_SETTINGS* Settings, const VIRTUAL_MACHINE_SETTINGS* VmSettings, IWSLASession** WslaSession);
void OnVmTerminated(WSLAVirtualMachine* machine);
private:
wil::unique_tokeninfo_ptr<TOKEN_USER> m_tokenInfo;
std::recursive_mutex m_wslaSessionsLock;
// TODO-WSLA: Consider using a weak_ptr to easily destroy when the last client reference is released.
std::vector<Microsoft::WRL::ComPtr<WSLASession>> m_wslaSessions;
std::recursive_mutex m_lock;
std::vector<WSLAVirtualMachine*> m_virtualMachines;
@ -54,7 +57,7 @@ public:
IFACEMETHOD(GetVersion)(_Out_ WSL_VERSION* Version) override;
IFACEMETHOD(CreateVirtualMachine)(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine) override;
IFACEMETHOD(CreateSession)(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** Session);
IFACEMETHOD(CreateSession)(const WSLA_SESSION_SETTINGS* WslaSessionSettings, const VIRTUAL_MACHINE_SETTINGS* VmSettings, IWSLASession** WslaSession);
private:
std::weak_ptr<WSLAUserSessionImpl> m_session;

View File

@ -117,6 +117,7 @@ struct _WSLA_SESSION_SETTINGS {
interface IWSLASession : IUnknown
{
HRESULT GetDisplayName([out] LPWSTR* DisplayName);
HRESULT GetVirtualMachine([out] IWSLAVirtualMachine **VirtualMachine);
}
[
@ -128,5 +129,5 @@ interface IWSLAUserSession : IUnknown
{
HRESULT GetVersion([out] WSL_VERSION* Error);
HRESULT CreateVirtualMachine([in] const VIRTUAL_MACHINE_SETTINGS* Settings, [out]IWSLAVirtualMachine** VirtualMachine);
HRESULT CreateSession([in] const WSLA_SESSION_SETTINGS* Settings, [out]IWSLASession** Session);
HRESULT CreateSession([in] const WSLA_SESSION_SETTINGS* Settings, [in] const VIRTUAL_MACHINE_SETTINGS* VmSettings, [out] IWSLASession** Session);
}

View File

@ -1116,7 +1116,15 @@ class WSLATests
WSLA_SESSION_SETTINGS settings{L"my-display-name"};
wil::com_ptr<IWSLASession> session;
VERIFY_SUCCEEDED(userSession->CreateSession(&settings, &session));
VIRTUAL_MACHINE_SETTINGS vmSettings{};
vmSettings.BootTimeoutMs = 30 * 1000;
vmSettings.DisplayName = L"WSLA";
vmSettings.MemoryMb = 2048;
vmSettings.CpuCount = 4;
vmSettings.NetworkingMode = WslNetworkingModeNone;
vmSettings.EnableDebugShell = true;
VERIFY_SUCCEEDED(userSession->CreateSession(&settings, &vmSettings, &session));
wil::unique_cotaskmem_string returnedDisplayName;
VERIFY_SUCCEEDED(session->GetDisplayName(&returnedDisplayName));