diff --git a/msipackage/package.wix.in b/msipackage/package.wix.in
index 5a7d1fd..3bdfd34 100644
--- a/msipackage/package.wix.in
+++ b/msipackage/package.wix.in
@@ -264,6 +264,12 @@
+
+
+
+
+
+
@@ -288,6 +294,13 @@
+
+
+
+
+
+
+
diff --git a/src/windows/common/WslSecurity.cpp b/src/windows/common/WslSecurity.cpp
index be1ff0a..85fd8b5 100644
--- a/src/windows/common/WslSecurity.cpp
+++ b/src/windows/common/WslSecurity.cpp
@@ -99,6 +99,21 @@ wil::unique_handle wsl::windows::common::security::CreateRestrictedToken(_In_ HA
return restrictedToken;
}
+void wsl::windows::common::security::ConfigureForCOMImpersonation(IUnknown* Instance)
+{
+ wil::com_ptr_nothrow clientSecurity;
+ THROW_IF_FAILED(Instance->QueryInterface(IID_PPV_ARGS(&clientSecurity)));
+
+ // Get the current proxy blanket settings.
+ DWORD authnSvc, authzSvc, authnLvl, capabilites;
+ THROW_IF_FAILED(clientSecurity->QueryBlanket(Instance, &authnSvc, &authzSvc, NULL, &authnLvl, NULL, NULL, &capabilites));
+
+ // Make sure that dynamic cloaking is used.
+ WI_ClearFlag(capabilites, EOAC_STATIC_CLOAKING);
+ WI_SetFlag(capabilites, EOAC_DYNAMIC_CLOAKING);
+ THROW_IF_FAILED(clientSecurity->SetBlanket(Instance, authnSvc, authzSvc, NULL, authnLvl, RPC_C_IMP_LEVEL_IMPERSONATE, NULL, capabilites));
+}
+
LUID wsl::windows::common::security::EnableTokenPrivilege(_Inout_ HANDLE token, _In_ LPCWSTR privilegeName)
{
// Convert privilege name to an LUID.
diff --git a/src/windows/common/WslSecurity.h b/src/windows/common/WslSecurity.h
index 2b2f214..daefe1f 100644
--- a/src/windows/common/WslSecurity.h
+++ b/src/windows/common/WslSecurity.h
@@ -87,6 +87,11 @@ std::pair> CreateSid(SID_IDENTIFIER_AUTHORITY Authority,
///
wil::unique_handle CreateRestrictedToken(_In_ HANDLE token);
+///
+/// Configures a COM object for impersonation.
+///
+void ConfigureForCOMImpersonation(IUnknown* instance);
+
///
/// Enables a privilege on the token.
///
diff --git a/src/windows/wslaclient/DllMain.cpp b/src/windows/wslaclient/DllMain.cpp
index 65d5044..012246b 100644
--- a/src/windows/wslaclient/DllMain.cpp
+++ b/src/windows/wslaclient/DllMain.cpp
@@ -18,24 +18,6 @@ Abstract:
#include "wslrelay.h"
#include "wslInstall.h"
-namespace {
-
-void ConfigureComSecurity(IUnknown* Instance)
-{
- wil::com_ptr_nothrow clientSecurity;
- THROW_IF_FAILED(Instance->QueryInterface(IID_PPV_ARGS(&clientSecurity)));
-
- // Get the current proxy blanket settings.
- DWORD authnSvc, authzSvc, authnLvl, capabilites;
- THROW_IF_FAILED(clientSecurity->QueryBlanket(Instance, &authnSvc, &authzSvc, NULL, &authnLvl, NULL, NULL, &capabilites));
-
- // Make sure that dynamic cloaking is used.
- WI_ClearFlag(capabilites, EOAC_STATIC_CLOAKING);
- WI_SetFlag(capabilites, EOAC_DYNAMIC_CLOAKING);
- THROW_IF_FAILED(clientSecurity->SetBlanket(Instance, authnSvc, authzSvc, NULL, authnLvl, RPC_C_IMP_LEVEL_IMPERSONATE, NULL, capabilites));
-}
-} // namespace
-
class DECLSPEC_UUID("7BC4E198-6531-4FA6-ADE2-5EF3D2A04DFF") CallbackInstance
: public Microsoft::WRL::RuntimeClass, ITerminationCallback, IFastRundown>
{
@@ -74,7 +56,7 @@ try
wil::com_ptr session;
THROW_IF_FAILED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&session)));
- ConfigureComSecurity(session.get());
+ wsl::windows::common::security::ConfigureForCOMImpersonation(session.get());
wil::com_ptr virtualMachineInstance;
@@ -91,7 +73,7 @@ try
settings.EnableGPU = UserSettings->GPU.Enable;
THROW_IF_FAILED(session->CreateVirtualMachine(&settings, &virtualMachineInstance));
- ConfigureComSecurity(virtualMachineInstance.get());
+ wsl::windows::common::security::ConfigureForCOMImpersonation(virtualMachineInstance.get());
// Register termination callback, if specified
if (UserSettings->Options.TerminationCallback != nullptr)
diff --git a/src/windows/wslaservice/exe/CMakeLists.txt b/src/windows/wslaservice/exe/CMakeLists.txt
index 8d51923..46260d1 100644
--- a/src/windows/wslaservice/exe/CMakeLists.txt
+++ b/src/windows/wslaservice/exe/CMakeLists.txt
@@ -2,12 +2,14 @@ set(SOURCES
application.manifest
main.rc
ServiceMain.cpp
+ WSLASession.cpp
WSLAUserSession.cpp
WSLAUserSessionFactory.cpp
WSLAVirtualMachine.cpp
)
set(HEADERS
+ WSLASession.h
WSLAUserSession.h
WSLAUserSessionFactory.h
WSLAVirtualMachine.h)
diff --git a/src/windows/wslaservice/exe/WSLASession.cpp b/src/windows/wslaservice/exe/WSLASession.cpp
new file mode 100644
index 0000000..617ec40
--- /dev/null
+++ b/src/windows/wslaservice/exe/WSLASession.cpp
@@ -0,0 +1,27 @@
+/*++
+
+Copyright (c) Microsoft. All rights reserved.
+
+Module Name:
+
+ WSLASession.cpp
+
+Abstract:
+
+ This file contains the implementation of the WSLASession COM class.
+
+--*/
+
+#include "WSLASession.h"
+
+using wsl::windows::service::wsla::WSLASession;
+
+WSLASession::WSLASession(const WSLA_SESSION_SETTINGS& Settings) : m_displayName(Settings.DisplayName)
+{
+}
+
+HRESULT WSLASession::GetDisplayName(LPWSTR* DisplayName)
+{
+ *DisplayName = wil::make_unique_string(m_displayName.c_str()).release();
+ return S_OK;
+}
diff --git a/src/windows/wslaservice/exe/WSLASession.h b/src/windows/wslaservice/exe/WSLASession.h
new file mode 100644
index 0000000..55fcbc0
--- /dev/null
+++ b/src/windows/wslaservice/exe/WSLASession.h
@@ -0,0 +1,32 @@
+/*++
+
+Copyright (c) Microsoft. All rights reserved.
+
+Module Name:
+
+ WSLASession.h
+
+Abstract:
+
+ TODO
+
+--*/
+
+#pragma once
+
+#include "wslaservice.h"
+
+namespace wsl::windows::service::wsla {
+
+class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLASession
+ : public Microsoft::WRL::RuntimeClass, IWSLASession, IFastRundown>
+{
+public:
+ WSLASession(const WSLA_SESSION_SETTINGS& Settings);
+ IFACEMETHOD(GetDisplayName)(LPWSTR* DisplayName);
+
+private:
+ std::wstring m_displayName;
+};
+
+} // namespace wsl::windows::service::wsla
\ No newline at end of file
diff --git a/src/windows/wslaservice/exe/WSLAUserSession.cpp b/src/windows/wslaservice/exe/WSLAUserSession.cpp
index deb6599..d75ba00 100644
--- a/src/windows/wslaservice/exe/WSLAUserSession.cpp
+++ b/src/windows/wslaservice/exe/WSLAUserSession.cpp
@@ -11,7 +11,9 @@ Abstract:
TODO
--*/
+
#include "WSLAUserSession.h"
+#include "WSLASession.h"
using wsl::windows::service::wsla::WSLAUserSessionImpl;
@@ -24,7 +26,7 @@ WSLAUserSessionImpl::~WSLAUserSessionImpl()
{
// Manually signal the VM termination events. This prevents being stuck on an API call that holds the VM lock.
{
- std::lock_guard lock(m_virtualMachinesLock);
+ std::lock_guard lock(m_lock);
for (auto* e : m_virtualMachines)
{
@@ -35,7 +37,7 @@ WSLAUserSessionImpl::~WSLAUserSessionImpl()
void WSLAUserSessionImpl::OnVmTerminated(WSLAVirtualMachine* machine)
{
- std::lock_guard lock(m_virtualMachinesLock);
+ std::lock_guard lock(m_lock);
auto pred = [machine](const auto* e) { return machine == e; };
// Remove any stale VM reference.
@@ -47,7 +49,7 @@ HRESULT WSLAUserSessionImpl::CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS
auto vm = wil::MakeOrThrow(*Settings, GetUserSid(), this);
{
- std::lock_guard lock(m_virtualMachinesLock);
+ std::lock_guard lock(m_lock);
m_virtualMachines.emplace_back(vm.Get());
}
@@ -62,6 +64,20 @@ PSID WSLAUserSessionImpl::GetUserSid() const
return m_tokenInfo->User.Sid;
}
+HRESULT wsl::windows::service::wsla::WSLAUserSessionImpl::CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** Session)
+{
+ auto session = wil::MakeOrThrow(*Settings);
+
+ {
+ std::lock_guard lock(m_lock);
+ m_sessions.emplace_back(session.Get());
+ }
+
+ THROW_IF_FAILED(session.CopyTo(__uuidof(IWSLASession), (void**)Session));
+
+ return S_OK;
+}
+
wsl::windows::service::wsla::WSLAUserSession::WSLAUserSession(std::weak_ptr&& Session) :
m_session(std::move(Session))
{
@@ -84,4 +100,14 @@ try
return session->CreateVirtualMachine(Settings, VirtualMachine);
}
-CATCH_RETURN();
\ No newline at end of file
+CATCH_RETURN();
+
+HRESULT wsl::windows::service::wsla::WSLAUserSession::CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** Session)
+try
+{
+ auto session = m_session.lock();
+ RETURN_HR_IF(RPC_E_DISCONNECTED, !session);
+
+ return session->CreateSession(Settings, Session);
+}
+CATCH_RETURN();
diff --git a/src/windows/wslaservice/exe/WSLAUserSession.h b/src/windows/wslaservice/exe/WSLAUserSession.h
index e1c1762..391b023 100644
--- a/src/windows/wslaservice/exe/WSLAUserSession.h
+++ b/src/windows/wslaservice/exe/WSLAUserSession.h
@@ -11,8 +11,10 @@ Abstract:
TODO
--*/
+
#pragma once
#include "WSLAVirtualMachine.h"
+#include "WSLASession.h"
namespace wsl::windows::service::wsla {
@@ -28,14 +30,18 @@ public:
PSID GetUserSid() const;
HRESULT CreateVirtualMachine(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine);
+ HRESULT CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** Session);
void OnVmTerminated(WSLAVirtualMachine* machine);
private:
wil::unique_tokeninfo_ptr m_tokenInfo;
- std::recursive_mutex m_virtualMachinesLock;
+ std::recursive_mutex m_lock;
std::vector m_virtualMachines;
+
+ // TODO-WSLA: Consider using a weak_ptr to easily destroy when the last client reference is released.
+ std::vector> m_sessions;
};
class DECLSPEC_UUID("a9b7a1b9-0671-405c-95f1-e0612cb4ce8f") WSLAUserSession
@@ -48,6 +54,7 @@ public:
IFACEMETHOD(GetVersion)(_Out_ WSL_VERSION* Version) override;
IFACEMETHOD(CreateVirtualMachine)(const VIRTUAL_MACHINE_SETTINGS* Settings, IWSLAVirtualMachine** VirtualMachine) override;
+ IFACEMETHOD(CreateSession)(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** Session);
private:
std::weak_ptr m_session;
diff --git a/src/windows/wslaservice/inc/wslaservice.idl b/src/windows/wslaservice/inc/wslaservice.idl
index 831beff..3698719 100644
--- a/src/windows/wslaservice/inc/wslaservice.idl
+++ b/src/windows/wslaservice/inc/wslaservice.idl
@@ -102,6 +102,23 @@ struct _VIRTUAL_MACHINE_SETTINGS {
} VIRTUAL_MACHINE_SETTINGS;
+typedef
+struct _WSLA_SESSION_SETTINGS {
+ LPCWSTR DisplayName;
+ // Details TBD.
+} WSLA_SESSION_SETTINGS;
+
+
+[
+ uuid(EF0661E4-6364-40EA-B433-E2FDF11F3519),
+ pointer_default(unique),
+ object
+]
+interface IWSLASession : IUnknown
+{
+ HRESULT GetDisplayName([out] LPWSTR* DisplayName);
+}
+
[
uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8760),
pointer_default(unique),
@@ -111,4 +128,5 @@ interface IWSLAUserSession : IUnknown
{
HRESULT GetVersion([out] WSL_VERSION* Error);
HRESULT CreateVirtualMachine([in] const VIRTUAL_MACHINE_SETTINGS* Settings, [out]IWSLAVirtualMachine** VirtualMachine);
+ HRESULT CreateSession([in] const WSLA_SESSION_SETTINGS* Settings, [out]IWSLASession** Session);
}
\ No newline at end of file
diff --git a/test/windows/WSLATests.cpp b/test/windows/WSLATests.cpp
index ef1e5b2..e27a8b6 100644
--- a/test/windows/WSLATests.cpp
+++ b/test/windows/WSLATests.cpp
@@ -15,6 +15,7 @@ Abstract:
#include "precomp.h"
#include "Common.h"
#include "WSLAApi.h"
+#include "wslaservice.h"
using namespace wsl::windows::common::registry;
@@ -1067,4 +1068,21 @@ class WSLATests
// Validate that xsk_diag is now loaded.
VERIFY_ARE_EQUAL(RunCommand(vm.get(), {"/bin/bash", "-c", "lsmod | grep ^xsk_diag"}), 0);
}
+
+ TEST_METHOD(CreateSessionSmokeTest)
+ {
+ wil::com_ptr userSession;
+ VERIFY_SUCCEEDED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&userSession)));
+ wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get());
+
+ WSLA_SESSION_SETTINGS settings{L"my-display-name"};
+ wil::com_ptr session;
+
+ VERIFY_SUCCEEDED(userSession->CreateSession(&settings, &session));
+
+ wil::unique_cotaskmem_string returnedDisplayName;
+ VERIFY_SUCCEEDED(session->GetDisplayName(&returnedDisplayName));
+
+ VERIFY_ARE_EQUAL(returnedDisplayName.get(), std::wstring(L"my-display-name"));
+ }
};
\ No newline at end of file