diff --git a/src/windows/common/svccomm.cpp b/src/windows/common/svccomm.cpp index 0af212e..399b467 100644 --- a/src/windows/common/svccomm.cpp +++ b/src/windows/common/svccomm.cpp @@ -540,16 +540,13 @@ wsl::windows::common::SvcComm::SvcComm() }; wsl::shared::retry::RetryWithTimeout( - [this]() { - THROW_IF_FAILED(CoCreateInstance(__uuidof(LxssUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&m_userSession))); - }, + [this]() { m_userSession = wil::CoCreateInstance(CLSCTX_LOCAL_SERVER); }, std::chrono::seconds(1), std::chrono::minutes(1), retry_pred); // Query client security interface. - wil::com_ptr_nothrow clientSecurity; - THROW_IF_FAILED(m_userSession->QueryInterface(IID_PPV_ARGS(&clientSecurity))); + auto clientSecurity = m_userSession.query(); // Get the current proxy blanket settings. DWORD authnSvc, authzSvc, authnLvl, capabilities; diff --git a/src/windows/inc/comservicehelper.h b/src/windows/inc/comservicehelper.h index 9c412db..227cf3a 100644 --- a/src/windows/inc/comservicehelper.h +++ b/src/windows/inc/comservicehelper.h @@ -77,7 +77,7 @@ namespace Windows { namespace Internal { // Tell COM how to mask fatal exceptions. if (ownProcess) { - Microsoft::WRL::ComPtr pIGLB; + wil::com_ptr pIGLB; RETURN_IF_FAILED(CoCreateInstance(CLSID_GlobalOptions, nullptr, CLSCTX_INPROC_SERVER, IID_PPV_ARGS(&pIGLB))); RETURN_IF_FAILED(pIGLB->Set(COMGLB_EXCEPTION_HANDLING, TExceptionPolicy)); } @@ -294,7 +294,7 @@ namespace Windows { namespace Internal { bool m_addedModuleReference = false; // COM callback object to support unloading shared-process services - Microsoft::WRL::ComPtr m_icc; + wil::com_ptr m_icc; // COM Server descriptor ServerDescriptor m_serverDescriptor{}; diff --git a/src/windows/service/exe/LxssIpTables.cpp b/src/windows/service/exe/LxssIpTables.cpp index afd1cc0..2b7a143 100644 --- a/src/windows/service/exe/LxssIpTables.cpp +++ b/src/windows/service/exe/LxssIpTables.cpp @@ -335,7 +335,7 @@ const std::wstring LxssNetworkingFirewall::s_FriendlyNamePrefix(L"WSLRULE_177744 LxssNetworkingFirewall::LxssNetworkingFirewall() { - THROW_IF_FAILED(::CoCreateInstance(__uuidof(NetFwPolicy2), NULL, CLSCTX_INPROC_SERVER, IID_PPV_ARGS(&m_firewall))); + m_firewall = wil::CoCreateInstance(CLSCTX_INPROC_SERVER); } void LxssNetworkingFirewall::CopyPartialArray(SAFEARRAY* Destination, SAFEARRAY* Source, ULONG DestinationIndexStart, ULONG SourceIndexStart, ULONG ElementsToCopy) @@ -388,8 +388,7 @@ void LxssNetworkingFirewall::CopyPartialArray(SAFEARRAY* Destination, SAFEARRAY* std::wstring LxssNetworkingFirewall::AddPortRule(const IP_ADDRESS_PREFIX& Address) const { - Microsoft::WRL::ComPtr newRule; - THROW_IF_FAILED(::CoCreateInstance(__uuidof(NetFwRule), NULL, CLSCTX_INPROC_SERVER, IID_PPV_ARGS(&newRule))); + auto newRule = wil::CoCreateInstance(CLSCTX_INPROC_SERVER); // Open a port via the firewall by creating a rule that specifies the local // address and the local port to allow. Currently this rule only applies to @@ -412,9 +411,9 @@ std::wstring LxssNetworkingFirewall::AddPortRule(const IP_ADDRESS_PREFIX& Addres THROW_IF_FAILED(newRule->put_Description(s_DefaultRuleDescription.get())); THROW_IF_FAILED(newRule->put_Enabled(VARIANT_TRUE)); // Add the rule to the existing set. - Microsoft::WRL::ComPtr rules; + wil::com_ptr rules; THROW_IF_FAILED(m_firewall->get_Rules(&rules)); - THROW_IF_FAILED(rules->Add(newRule.Get())); + THROW_IF_FAILED(rules->Add(newRule.get())); // Return the unique rule name to the caller. return generatedName; } @@ -423,12 +422,11 @@ void LxssNetworkingFirewall::CleanupRemnants() { auto firewall = std::make_shared(); THROW_HR_IF(E_OUTOFMEMORY, !firewall); - Microsoft::WRL::ComPtr rules; + wil::com_ptr rules; THROW_IF_FAILED(firewall->m_firewall->get_Rules(&rules)); - Microsoft::WRL::ComPtr enumInterface; - THROW_IF_FAILED(rules->get__NewEnum(enumInterface.GetAddressOf())); - Microsoft::WRL::ComPtr rulesEnum; - THROW_IF_FAILED(enumInterface.As(&rulesEnum)); + wil::com_ptr enumInterface; + THROW_IF_FAILED(rules->get__NewEnum(enumInterface.addressof())); + auto rulesEnum = enumInterface.query(); // Find any rules with the unique WSL prefix and destroy them. for (;;) { @@ -440,7 +438,7 @@ void LxssNetworkingFirewall::CleanupRemnants() break; } - Microsoft::WRL::ComPtr nextRule; + wil::com_ptr nextRule; THROW_IF_FAILED(next.pdispVal->QueryInterface(IID_PPV_ARGS(&nextRule))); wil::unique_bstr nextRuleName; THROW_IF_FAILED(nextRule->get_Name(nextRuleName.addressof())); @@ -558,7 +556,7 @@ void LxssNetworkingFirewall::RemoveExcludedAdapter(const std::wstring& AdapterNa void LxssNetworkingFirewall::RemovePortRule(const std::wstring& RuleName) const { - Microsoft::WRL::ComPtr rules; + wil::com_ptr rules; THROW_IF_FAILED(m_firewall->get_Rules(&rules)); THROW_IF_FAILED(rules->Remove(wil::make_bstr_failfast(RuleName.c_str()).get())); } @@ -572,8 +570,7 @@ LxssNetworkingFirewallPort::LxssNetworkingFirewallPort(const std::shared_ptr& Firewall, const Microsoft::WRL::ComPtr& Existing) : +LxssNetworkingFirewallPort::LxssNetworkingFirewallPort(const std::shared_ptr& Firewall, const wil::com_ptr& Existing) : m_firewall(Firewall) { wil::unique_bstr ruleName; diff --git a/src/windows/service/exe/LxssIpTables.h b/src/windows/service/exe/LxssIpTables.h index fca26e3..0429d9a 100644 --- a/src/windows/service/exe/LxssIpTables.h +++ b/src/windows/service/exe/LxssIpTables.h @@ -262,7 +262,7 @@ private: /// /// COM firewall instance. /// - Microsoft::WRL::ComPtr m_firewall; + wil::com_ptr m_firewall; /// /// Lock to protect class members. @@ -295,7 +295,7 @@ public: /// /// Constructor to take ownership of an existing rule. /// - LxssNetworkingFirewallPort(const std::shared_ptr& Firewall, const Microsoft::WRL::ComPtr& Existing); + LxssNetworkingFirewallPort(const std::shared_ptr& Firewall, const wil::com_ptr& Existing); /// /// Destructor. diff --git a/src/windows/service/exe/LxssUserSession.cpp b/src/windows/service/exe/LxssUserSession.cpp index 2e5611c..1471b6b 100644 --- a/src/windows/service/exe/LxssUserSession.cpp +++ b/src/windows/service/exe/LxssUserSession.cpp @@ -2674,8 +2674,7 @@ try THROW_IF_FAILED(shellLink->SetArguments(commandLine.c_str())); THROW_IF_FAILED(shellLink->SetIconLocation(ShortcutIcon, 0)); - Microsoft::WRL::ComPtr storage; - THROW_IF_FAILED(shellLink->QueryInterface(IID_IPersistFile, &storage)); + auto storage = shellLink.query(); THROW_IF_FAILED(storage->Save(shortcutPath.c_str(), true)); registration.Write(Property::ShortcutPath, shortcutPath.c_str()); diff --git a/src/windows/service/exe/WslCoreVm.h b/src/windows/service/exe/WslCoreVm.h index d04aefa..c2e2e37 100644 --- a/src/windows/service/exe/WslCoreVm.h +++ b/src/windows/service/exe/WslCoreVm.h @@ -351,7 +351,7 @@ private: wsl::shared::SocketChannel m_miniInitChannel; wil::unique_socket m_notifyChannel; SE_SID m_userSid; - Microsoft::WRL::ComPtr m_deviceHostSupport; + wil::com_ptr m_deviceHostSupport; std::shared_ptr m_systemDistro; _Guarded_by_(m_lock) std::bitset m_lunBitmap; _Guarded_by_(m_lock) std::map m_attachedDisks; diff --git a/test/windows/PolicyTests.cpp b/test/windows/PolicyTests.cpp index 2fa0be7..d88f6a1 100644 --- a/test/windows/PolicyTests.cpp +++ b/test/windows/PolicyTests.cpp @@ -327,7 +327,7 @@ class PolicyTest const auto stop = std::chrono::steady_clock::now() + std::chrono::seconds{30}; for (;;) { - Microsoft::WRL::ComPtr session; + wil::com_ptr session; result = CoCreateInstance(CLSID_LxssUserSession, nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&session)); if (result == expectedResult || std::chrono::steady_clock::now() > stop) { diff --git a/test/windows/UnitTests.cpp b/test/windows/UnitTests.cpp index 7a2624a..0521431 100644 --- a/test/windows/UnitTests.cpp +++ b/test/windows/UnitTests.cpp @@ -2448,8 +2448,7 @@ Error code: Wsl/InstallDistro/WSL_E_DISTRO_NOT_FOUND // Validate that the shortcut is actually in the start menu VERIFY_IS_TRUE(shortcutPath.find(startMenu) != std::string::npos); - Microsoft::WRL::ComPtr storage; - VERIFY_SUCCEEDED(shellLink->QueryInterface(IID_IPersistFile, &storage)); + auto storage = shellLink.query(); VERIFY_SUCCEEDED(storage->Load(shortcutPath.c_str(), 0));