diff --git a/src/windows/service/exe/VirtioNetworking.cpp b/src/windows/service/exe/VirtioNetworking.cpp index d6190fb..16ba97a 100644 --- a/src/windows/service/exe/VirtioNetworking.cpp +++ b/src/windows/service/exe/VirtioNetworking.cpp @@ -12,10 +12,12 @@ using wsl::core::VirtioNetworking; static constexpr auto c_loopbackDeviceName = TEXT(LX_INIT_LOOPBACK_DEVICE_NAME); -VirtioNetworking::VirtioNetworking(const std::wstring& vmId, const GUID& runtimeId, GnsChannel&& gnsChannel, bool enableLocalhostRelay) : +VirtioNetworking::VirtioNetworking( + const std::wstring& vmId, const GUID& runtimeId, GnsChannel&& gnsChannel, bool enableLocalhostRelay, const wil::shared_handle& userToken) : m_deviceHostProxy(wil::MakeOrThrow(vmId, runtimeId)), m_gnsChannel(std::move(gnsChannel)), - m_enableLocalhostRelay(enableLocalhostRelay) + m_enableLocalhostRelay(enableLocalhostRelay), + m_userToken(userToken) { } @@ -372,19 +374,28 @@ std::optional VirtioNetworking::FindVirtioInterfaceLuid(const SOCKADD return ipv4Connected ? VirtioLuid.Value : std::optional(); } -GUID VirtioNetworking::AddGuestDevice(const GUID& clsid, const GUID& deviceId, PCWSTR tag, PCWSTR options) +GUID VirtioNetworking::AddGuestDevice(const GUID& clsid, const GUID& deviceId, PCWSTR tag, PCWSTR path) { auto lock = m_guestDeviceLock.lock_exclusive(); - // Get or create the Plan9 file system for this device - auto server = m_deviceHostProxy->GetRemoteFileSystem(clsid, c_defaultTag); - if (!server) + wil::com_ptr server; + + // Impersonate the user token when creating/accessing the Plan9 file system { - server = wil::CoCreateInstance(__uuidof(p9fs::Plan9FileSystem)); - m_deviceHostProxy->AddRemoteFileSystem(clsid, c_defaultTag, server); + auto revert = wil::impersonate_token(m_userToken.get()); + + server = m_deviceHostProxy->GetRemoteFileSystem(clsid, c_defaultTag); + if (!server) + { + server = wil::CoCreateInstance(clsid, (CLSCTX_LOCAL_SERVER | CLSCTX_ENABLE_CLOAKING | CLSCTX_ENABLE_AAA)); + m_deviceHostProxy->AddRemoteFileSystem(clsid, c_defaultTag, server); + } + + THROW_IF_FAILED(server->AddSharePath(tag, path, 0)); } - return m_deviceHostProxy->AddNewDevice(deviceId, server, tag); + const std::wstring virtioTag(tag); + return m_deviceHostProxy->AddNewDevice(deviceId, server, virtioTag); } int VirtioNetworking::ModifyOpenPorts(const GUID& clsid, PCWSTR tag, const SOCKADDR_INET& addr, int protocol, bool isOpen) const diff --git a/src/windows/service/exe/VirtioNetworking.h b/src/windows/service/exe/VirtioNetworking.h index 56cfa01..9a5613c 100644 --- a/src/windows/service/exe/VirtioNetworking.h +++ b/src/windows/service/exe/VirtioNetworking.h @@ -13,7 +13,7 @@ namespace wsl::core { class VirtioNetworking : public INetworkingEngine { public: - VirtioNetworking(const std::wstring& vmId, const GUID& runtimeId, GnsChannel&& gnsChannel, bool enableLocalhostRelay); + VirtioNetworking(const std::wstring& vmId, const GUID& runtimeId, GnsChannel&& gnsChannel, bool enableLocalhostRelay, const wil::shared_handle& userToken); ~VirtioNetworking(); // Note: This class cannot be moved because m_networkNotifyHandle captures a 'this' pointer. @@ -40,12 +40,13 @@ private: void UpdateDns(wsl::shared::hns::DNS&& dnsSettings); void UpdateMtu(); - GUID AddGuestDevice(const GUID& clsid, const GUID& deviceId, PCWSTR tag, PCWSTR options); + GUID AddGuestDevice(const GUID& clsid, const GUID& deviceId, PCWSTR tag, PCWSTR path); int ModifyOpenPorts(const GUID& clsid, PCWSTR tag, const SOCKADDR_INET& addr, int protocol, bool isOpen) const; mutable wil::srwlock m_lock; mutable wil::srwlock m_guestDeviceLock; + wil::shared_handle m_userToken; wil::com_ptr m_deviceHostProxy; GnsChannel m_gnsChannel; std::optional m_gnsPortTrackerChannel; diff --git a/src/windows/service/exe/WslCoreVm.cpp b/src/windows/service/exe/WslCoreVm.cpp index d05acb1..b2304a2 100644 --- a/src/windows/service/exe/WslCoreVm.cpp +++ b/src/windows/service/exe/WslCoreVm.cpp @@ -606,7 +606,7 @@ void WslCoreVm::Initialize(const GUID& VmId, const wil::shared_handle& UserToken else if (m_vmConfig.NetworkingMode == NetworkingMode::VirtioProxy) { m_networkingEngine = std::make_unique( - m_machineId, m_runtimeId, std::move(gnsChannel), m_vmConfig.EnableLocalhostRelay); + m_machineId, m_runtimeId, std::move(gnsChannel), m_vmConfig.EnableLocalhostRelay, m_userToken); } else if (m_vmConfig.NetworkingMode == NetworkingMode::Bridged) {