diff --git a/src/cascadia/TerminalConnection/CTerminalHandoff.cpp b/src/cascadia/TerminalConnection/CTerminalHandoff.cpp index a6edb23243..01b25c9ffe 100644 --- a/src/cascadia/TerminalConnection/CTerminalHandoff.cpp +++ b/src/cascadia/TerminalConnection/CTerminalHandoff.cpp @@ -14,6 +14,13 @@ static DWORD g_cTerminalHandoffRegistration = 0; // Mutex so we only do start/stop/establish one at a time. static std::shared_mutex _mtx; +// This is the callback that will be called when a connection is received. +// Call this once during startup and don't ever change it again (race condition). +void CTerminalHandoff::s_setCallback(NewHandoffFunction callback) noexcept +{ + _pfnHandoff = callback; +} + // Routine Description: // - Starts listening for TerminalHandoff requests by registering // our class and interface with COM. @@ -21,24 +28,19 @@ static std::shared_mutex _mtx; // - pfnHandoff - Function to callback when a handoff is received // Return Value: // - S_OK, E_NOT_VALID_STATE (start called when already started) or relevant COM registration error. -HRESULT CTerminalHandoff::s_StartListening(NewHandoffFunction pfnHandoff) +HRESULT CTerminalHandoff::s_StartListening() try { std::unique_lock lock{ _mtx }; - RETURN_HR_IF(E_NOT_VALID_STATE, _pfnHandoff != nullptr); - const auto classFactory = Make>(); - - RETURN_IF_NULL_ALLOC(classFactory); + RETURN_LAST_ERROR_IF_NULL(classFactory); ComPtr unk; RETURN_IF_FAILED(classFactory.As(&unk)); RETURN_IF_FAILED(CoRegisterClassObject(__uuidof(CTerminalHandoff), unk.Get(), CLSCTX_LOCAL_SERVER, REGCLS_SINGLEUSE, &g_cTerminalHandoffRegistration)); - _pfnHandoff = pfnHandoff; - return S_OK; } CATCH_RETURN() @@ -53,15 +55,6 @@ CATCH_RETURN() HRESULT CTerminalHandoff::s_StopListening() { std::unique_lock lock{ _mtx }; - return s_StopListeningLocked(); -} - -// See s_StopListening() -HRESULT CTerminalHandoff::s_StopListeningLocked() -{ - RETURN_HR_IF_NULL(E_NOT_VALID_STATE, _pfnHandoff); - - _pfnHandoff = nullptr; if (g_cTerminalHandoffRegistration) { @@ -92,22 +85,15 @@ HRESULT CTerminalHandoff::EstablishPtyHandoff(HANDLE* in, HANDLE* out, HANDLE si { try { - std::unique_lock lock{ _mtx }; - - // s_StopListeningLocked sets _pfnHandoff to nullptr. - // localPfnHandoff is tested for nullness below. -#pragma warning(suppress : 26429) // Symbol '...' is never tested for nullness, it can be marked as not_null (f.23). - auto localPfnHandoff = _pfnHandoff; - // Because we are REGCLS_SINGLEUSE... we need to `CoRevokeClassObject` after we handle this ONE call. // COM does not automatically clean that up for us. We must do it. - LOG_IF_FAILED(s_StopListeningLocked()); + LOG_IF_FAILED(s_StopListening()); // Report an error if no one registered a handoff function before calling this. - THROW_HR_IF_NULL(E_NOT_VALID_STATE, localPfnHandoff); + THROW_HR_IF_NULL(E_NOT_VALID_STATE, _pfnHandoff); // Call registered handler from when we started listening. - THROW_IF_FAILED(localPfnHandoff(in, out, signal, reference, server, client, startupInfo)); + THROW_IF_FAILED(_pfnHandoff(in, out, signal, reference, server, client, startupInfo)); #pragma warning(suppress : 26477) TraceLoggingWrite( diff --git a/src/cascadia/TerminalConnection/CTerminalHandoff.h b/src/cascadia/TerminalConnection/CTerminalHandoff.h index 440a2636f2..004b3f5274 100644 --- a/src/cascadia/TerminalConnection/CTerminalHandoff.h +++ b/src/cascadia/TerminalConnection/CTerminalHandoff.h @@ -38,11 +38,11 @@ struct __declspec(uuid(__CLSID_CTerminalHandoff)) #pragma endregion - static HRESULT s_StartListening(NewHandoffFunction pfnHandoff); - static HRESULT s_StopListening(); + static void s_setCallback(NewHandoffFunction callback) noexcept; + static HRESULT s_StartListening(); private: - static HRESULT s_StopListeningLocked(); + static HRESULT s_StopListening(); }; // Disable warnings from the CoCreatableClass macro as the value it provides for diff --git a/src/cascadia/TerminalConnection/ConptyConnection.cpp b/src/cascadia/TerminalConnection/ConptyConnection.cpp index 0a0b26aa93..95517373c7 100644 --- a/src/cascadia/TerminalConnection/ConptyConnection.cpp +++ b/src/cascadia/TerminalConnection/ConptyConnection.cpp @@ -780,12 +780,12 @@ namespace winrt::Microsoft::Terminal::TerminalConnection::implementation void ConptyConnection::StartInboundListener() { - THROW_IF_FAILED(CTerminalHandoff::s_StartListening(&ConptyConnection::NewHandoff)); - } + static const auto init = []() noexcept { + CTerminalHandoff::s_setCallback(&ConptyConnection::NewHandoff); + return true; + }(); - void ConptyConnection::StopInboundListener() - { - THROW_IF_FAILED(CTerminalHandoff::s_StopListening()); + CTerminalHandoff::s_StartListening(); } // Function Description: diff --git a/src/cascadia/TerminalConnection/ConptyConnection.h b/src/cascadia/TerminalConnection/ConptyConnection.h index ca67588f64..69edd7b7a8 100644 --- a/src/cascadia/TerminalConnection/ConptyConnection.h +++ b/src/cascadia/TerminalConnection/ConptyConnection.h @@ -36,7 +36,6 @@ namespace winrt::Microsoft::Terminal::TerminalConnection::implementation WORD ShowWindow() const noexcept; static void StartInboundListener(); - static void StopInboundListener(); static winrt::event_token NewConnection(const NewConnectionHandler& handler); static void NewConnection(const winrt::event_token& token); diff --git a/src/cascadia/TerminalConnection/ConptyConnection.idl b/src/cascadia/TerminalConnection/ConptyConnection.idl index 09466b9d11..e1ed83cc0c 100644 --- a/src/cascadia/TerminalConnection/ConptyConnection.idl +++ b/src/cascadia/TerminalConnection/ConptyConnection.idl @@ -23,7 +23,6 @@ namespace Microsoft.Terminal.TerminalConnection static event NewConnectionHandler NewConnection; static void StartInboundListener(); - static void StopInboundListener(); static Windows.Foundation.Collections.ValueSet CreateSettings(String cmdline, String startingDirectory,