diff --git a/src/host/ApiRoutines.h b/src/host/ApiRoutines.h index f6fc4ad4d8..9d6affe33b 100644 --- a/src/host/ApiRoutines.h +++ b/src/host/ApiRoutines.h @@ -56,12 +56,12 @@ public: const bool IsUnicode, const bool IsPeek, const bool IsWaitAllowed, - std::unique_ptr& waiter) noexcept override; + CONSOLE_API_MSG* pWaitReplyMessage) noexcept override; [[nodiscard]] HRESULT ReadConsoleImpl(IConsoleInputObject& context, std::span buffer, size_t& written, - std::unique_ptr& waiter, + CONSOLE_API_MSG* pWaitReplyMessage, const std::wstring_view initialData, const std::wstring_view exeName, INPUT_READ_HANDLE_DATA& readHandleState, @@ -73,12 +73,12 @@ public: [[nodiscard]] HRESULT WriteConsoleAImpl(IConsoleOutputObject& context, const std::string_view buffer, size_t& read, - std::unique_ptr& waiter) noexcept override; + CONSOLE_API_MSG* pWaitReplyMessage) noexcept override; [[nodiscard]] HRESULT WriteConsoleWImpl(IConsoleOutputObject& context, const std::wstring_view buffer, size_t& read, - std::unique_ptr& waiter) noexcept override; + CONSOLE_API_MSG* pWaitReplyMessage) noexcept override; #pragma region ThreadCreationInfo [[nodiscard]] HRESULT GetConsoleLangIdImpl(LANGID& langId) noexcept override; diff --git a/src/host/_stream.cpp b/src/host/_stream.cpp index 3090151540..1bb023a861 100644 --- a/src/host/_stream.cpp +++ b/src/host/_stream.cpp @@ -415,30 +415,9 @@ void WriteClearScreen(SCREEN_INFORMATION& screenInfo) // - pwchBuffer - wide character text to be inserted into buffer // - pcbBuffer - byte count of pwchBuffer on the way in, number of bytes consumed on the way out. // - screenInfo - Screen Information class to write the text into at the current cursor position -// - ppWaiter - If writing to the console is blocked for whatever reason, this will be filled with a pointer to context -// that can be used by the server to resume the call at a later time. -// Return Value: -// - STATUS_SUCCESS if OK. -// - CONSOLE_STATUS_WAIT if we couldn't finish now and need to be called back later (see ppWaiter). -// - Or a suitable NTSTATUS format error code for memory/string/math failures. -[[nodiscard]] NTSTATUS DoWriteConsole(_In_reads_bytes_(*pcbBuffer) PCWCHAR pwchBuffer, - _Inout_ size_t* const pcbBuffer, - SCREEN_INFORMATION& screenInfo, - std::unique_ptr& waiter) +[[nodiscard]] HRESULT DoWriteConsole(SCREEN_INFORMATION& screenInfo, std::wstring_view str) try { - auto& gci = ServiceLocator::LocateGlobals().getConsoleInformation(); - if (WI_IsAnyFlagSet(gci.Flags, (CONSOLE_SUSPENDED | CONSOLE_SELECTING | CONSOLE_SCROLLBAR_TRACKING))) - { - waiter = std::make_unique(screenInfo, - pwchBuffer, - *pcbBuffer, - gci.OutputCP); - return CONSOLE_STATUS_WAIT; - } - - const std::wstring_view str{ pwchBuffer, *pcbBuffer / sizeof(WCHAR) }; - if (WI_IsAnyFlagClear(screenInfo.OutputMode, ENABLE_VIRTUAL_TERMINAL_PROCESSING | ENABLE_PROCESSED_OUTPUT)) { WriteCharsLegacy(screenInfo, str, nullptr); @@ -447,55 +426,9 @@ try { WriteCharsVT(screenInfo, str); } - - return STATUS_SUCCESS; -} -NT_CATCH_RETURN() - -// Routine Description: -// - This method performs the actual work of attempting to write to the console, converting data types as necessary -// to adapt from the server types to the legacy internal host types. -// - It operates on Unicode data only. It's assumed the text is translated by this point. -// Arguments: -// - OutContext - the console output object to write the new text into -// - pwsTextBuffer - wide character text buffer provided by client application to insert -// - cchTextBufferLength - text buffer counted in characters -// - pcchTextBufferRead - character count of the number of characters we were able to insert before returning -// - ppWaiter - If we are blocked from writing now and need to wait, this is filled with contextual data for the server to restore the call later -// Return Value: -// - S_OK if successful. -// - S_OK if we need to wait (check if ppWaiter is not nullptr). -// - Or a suitable HRESULT code for math/string/memory failures. -[[nodiscard]] HRESULT WriteConsoleWImplHelper(IConsoleOutputObject& context, - const std::wstring_view buffer, - size_t& read, - std::unique_ptr& waiter) noexcept -{ - try - { - // Set out variables in case we exit early. - read = 0; - waiter.reset(); - - // Convert characters to bytes to give to DoWriteConsole. - size_t cbTextBufferLength; - RETURN_IF_FAILED(SizeTMult(buffer.size(), sizeof(wchar_t), &cbTextBufferLength)); - - auto Status = DoWriteConsole(const_cast(buffer.data()), &cbTextBufferLength, context, waiter); - - // Convert back from bytes to characters for the resulting string length written. - read = cbTextBufferLength / sizeof(wchar_t); - - if (Status == CONSOLE_STATUS_WAIT) - { - FAIL_FAST_IF_NULL(waiter.get()); - Status = STATUS_SUCCESS; - } - - RETURN_NTSTATUS(Status); - } - CATCH_RETURN(); + return S_OK; } +CATCH_RETURN() // Routine Description: // - Writes non-Unicode formatted data into the given console output object. @@ -514,13 +447,12 @@ NT_CATCH_RETURN() [[nodiscard]] HRESULT ApiRoutines::WriteConsoleAImpl(IConsoleOutputObject& context, const std::string_view buffer, size_t& read, - std::unique_ptr& waiter) noexcept + CONSOLE_API_MSG* pWaitReplyMessage) noexcept { try { // Ensure output variables are initialized. read = 0; - waiter.reset(); if (buffer.empty()) { @@ -620,67 +552,63 @@ NT_CATCH_RETURN() wstr.resize((dbcsLength + mbPtrLength) / sizeof(wchar_t)); } - // Hold the specific version of the waiter locally so we can tinker with it if we have to store additional context. - std::unique_ptr writeDataWaiter{}; - - // Make the W version of the call - size_t wcBufferWritten{}; - const auto hr{ WriteConsoleWImplHelper(screenInfo, wstr, wcBufferWritten, writeDataWaiter) }; - - // If there is no waiter, process the byte count now. - if (nullptr == writeDataWaiter.get()) + auto& gci = ServiceLocator::LocateGlobals().getConsoleInformation(); + if (WI_IsAnyFlagSet(gci.Flags, (CONSOLE_SUSPENDED | CONSOLE_SELECTING | CONSOLE_SCROLLBAR_TRACKING))) { - // Calculate how many bytes of the original A buffer were consumed in the W version of the call to satisfy mbBufferRead. - // For UTF-8 conversions, we've already returned this information above. - if (CP_UTF8 != codepage) - { - size_t mbBufferRead{}; + const auto waiter = new WriteData(screenInfo, std::move(wstr), gci.OutputCP); - // Start by counting the number of A bytes we used in printing our W string to the screen. - try - { - mbBufferRead = GetALengthFromW(codepage, { wstr.data(), wcBufferWritten }); - } - CATCH_LOG(); - - // If we captured a byte off the string this time around up above, it means we didn't feed - // it into the WriteConsoleW above, and therefore its consumption isn't accounted for - // in the count we just made. Add +1 to compensate. - if (leadByteCaptured) - { - mbBufferRead++; - } - - // If we consumed an internally-stored lead byte this time around up above, it means that we - // fed a byte into WriteConsoleW that wasn't a part of this particular call's request. - // We need to -1 to compensate and tell the caller the right number of bytes consumed this request. - if (leadByteConsumed) - { - mbBufferRead--; - } - - read = mbBufferRead; - } - } - else - { // If there is a waiter, then we need to stow some additional information in the wait structure so // we can synthesize the correct byte count later when the wait routine is triggered. if (CP_UTF8 != codepage) { // For non-UTF8 codepages, save the lead byte captured/consumed data so we can +1 or -1 the final decoded count // in the WaitData::Notify method later. - writeDataWaiter->SetLeadByteAdjustmentStatus(leadByteCaptured, leadByteConsumed); + waiter->SetLeadByteAdjustmentStatus(leadByteCaptured, leadByteConsumed); } else { // For UTF8 codepages, just remember the consumption count from the UTF-8 parser. - writeDataWaiter->SetUtf8ConsumedCharacters(read); + waiter->SetUtf8ConsumedCharacters(read); } + + std::ignore = ConsoleWaitQueue::s_CreateWait(pWaitReplyMessage, waiter); + return CONSOLE_STATUS_WAIT; } - // Give back the waiter now that we're done with tinkering with it. - waiter.reset(writeDataWaiter.release()); + // Make the W version of the call + const auto hr = DoWriteConsole(screenInfo, wstr); + + // Calculate how many bytes of the original A buffer were consumed in the W version of the call to satisfy mbBufferRead. + // For UTF-8 conversions, we've already returned this information above. + if (CP_UTF8 != codepage) + { + size_t mbBufferRead{}; + + // Start by counting the number of A bytes we used in printing our W string to the screen. + try + { + mbBufferRead = GetALengthFromW(codepage, wstr); + } + CATCH_LOG(); + + // If we captured a byte off the string this time around up above, it means we didn't feed + // it into the WriteConsoleW above, and therefore its consumption isn't accounted for + // in the count we just made. Add +1 to compensate. + if (leadByteCaptured) + { + mbBufferRead++; + } + + // If we consumed an internally-stored lead byte this time around up above, it means that we + // fed a byte into WriteConsoleW that wasn't a part of this particular call's request. + // We need to -1 to compensate and tell the caller the right number of bytes consumed this request. + if (leadByteConsumed) + { + mbBufferRead--; + } + + read = mbBufferRead; + } return hr; } @@ -703,20 +631,24 @@ NT_CATCH_RETURN() [[nodiscard]] HRESULT ApiRoutines::WriteConsoleWImpl(IConsoleOutputObject& context, const std::wstring_view buffer, size_t& read, - std::unique_ptr& waiter) noexcept + CONSOLE_API_MSG* pWaitReplyMessage) noexcept { try { LockConsole(); auto unlock = wil::scope_exit([&] { UnlockConsole(); }); - std::unique_ptr writeDataWaiter; - RETURN_IF_FAILED(WriteConsoleWImplHelper(context.GetActiveBuffer(), buffer, read, writeDataWaiter)); + auto& gci = ServiceLocator::LocateGlobals().getConsoleInformation(); + if (WI_IsAnyFlagSet(gci.Flags, (CONSOLE_SUSPENDED | CONSOLE_SELECTING | CONSOLE_SCROLLBAR_TRACKING))) + { + std::ignore = ConsoleWaitQueue::s_CreateWait(pWaitReplyMessage, new WriteData(context, std::wstring{ buffer }, gci.OutputCP)); + return CONSOLE_STATUS_WAIT; + } - // Transfer specific waiter pointer into the generic interface wrapper. - waiter.reset(writeDataWaiter.release()); - - return S_OK; + read = 0; + auto Status = DoWriteConsole(context, buffer); + read = buffer.size(); + return Status; } CATCH_RETURN(); } diff --git a/src/host/_stream.h b/src/host/_stream.h index 4651deb800..5953367d7a 100644 --- a/src/host/_stream.h +++ b/src/host/_stream.h @@ -25,7 +25,4 @@ void WriteClearScreen(SCREEN_INFORMATION& screenInfo); // NOTE: console lock must be held when calling this routine // String has been translated to unicode at this point. -[[nodiscard]] NTSTATUS DoWriteConsole(_In_reads_bytes_(pcbBuffer) const wchar_t* pwchBuffer, - _Inout_ size_t* const pcbBuffer, - SCREEN_INFORMATION& screenInfo, - std::unique_ptr& waiter); +[[nodiscard]] HRESULT DoWriteConsole(SCREEN_INFORMATION& screenInfo, std::wstring_view str); diff --git a/src/host/directio.cpp b/src/host/directio.cpp index c7e3c292da..cd28ef07b1 100644 --- a/src/host/directio.cpp +++ b/src/host/directio.cpp @@ -58,12 +58,10 @@ using Microsoft::Console::Interactivity::ServiceLocator; const bool IsUnicode, const bool IsPeek, const bool IsWaitAllowed, - std::unique_ptr& waiter) noexcept + CONSOLE_API_MSG* pWaitReplyMessage) noexcept { try { - waiter.reset(); - if (eventReadCount == 0) { return STATUS_SUCCESS; @@ -83,9 +81,7 @@ using Microsoft::Console::Interactivity::ServiceLocator; { // If we're told to wait until later, move all of our context // to the read data object and send it back up to the server. - waiter = std::make_unique(&inputBuffer, - &readHandleState, - eventReadCount); + std::ignore = ConsoleWaitQueue::s_CreateWait(pWaitReplyMessage, new DirectReadData(&inputBuffer, &readHandleState, eventReadCount)); } return Status; } diff --git a/src/host/server.h b/src/host/server.h index 0a0f905777..d62484b3eb 100644 --- a/src/host/server.h +++ b/src/host/server.h @@ -166,9 +166,9 @@ private: MidiAudio _midiAudio; }; -#define CONSOLE_STATUS_WAIT 0xC0030001 -#define CONSOLE_STATUS_READ_COMPLETE 0xC0030002 -#define CONSOLE_STATUS_WAIT_NO_BLOCK 0xC0030003 +#define CONSOLE_STATUS_WAIT ((HRESULT)0xC0030001) +#define CONSOLE_STATUS_READ_COMPLETE ((HRESULT)0xC0030002) +#define CONSOLE_STATUS_WAIT_NO_BLOCK ((HRESULT)0xC0030003) #include "../server/ObjectHandle.h" diff --git a/src/host/stream.cpp b/src/host/stream.cpp index 11a7f6a411..d97d629b0d 100644 --- a/src/host/stream.cpp +++ b/src/host/stream.cpp @@ -340,7 +340,7 @@ NT_CATCH_RETURN() INPUT_READ_HANDLE_DATA& readHandleState, const std::wstring_view exeName, const bool unicode, - std::unique_ptr& waiter) noexcept + CONSOLE_API_MSG* pWaitReplyMessage) noexcept { auto& gci = ServiceLocator::LocateGlobals().getConsoleInformation(); RETURN_HR_IF(E_FAIL, !gci.HasActiveOutputBuffer()); @@ -364,7 +364,8 @@ NT_CATCH_RETURN() if (!cookedReadData->Read(unicode, bytesRead, controlKeyState)) { // memory will be cleaned up by wait queue - waiter.reset(cookedReadData.release()); + std::ignore = ConsoleWaitQueue::s_CreateWait(pWaitReplyMessage, cookedReadData.release()); + return CONSOLE_STATUS_WAIT; } else { @@ -468,25 +469,23 @@ NT_CATCH_RETURN() // populated. // - STATUS_SUCCESS on success // - Other NSTATUS codes as necessary -[[nodiscard]] NTSTATUS DoReadConsole(InputBuffer& inputBuffer, - const HANDLE processData, - std::span buffer, - size_t& bytesRead, - ULONG& controlKeyState, - const std::wstring_view initialData, - const DWORD ctrlWakeupMask, - INPUT_READ_HANDLE_DATA& readHandleState, - const std::wstring_view exeName, - const bool unicode, - std::unique_ptr& waiter) noexcept +[[nodiscard]] HRESULT DoReadConsole(InputBuffer& inputBuffer, + const HANDLE processData, + std::span buffer, + size_t& bytesRead, + ULONG& controlKeyState, + const std::wstring_view initialData, + const DWORD ctrlWakeupMask, + INPUT_READ_HANDLE_DATA& readHandleState, + const std::wstring_view exeName, + const bool unicode, + CONSOLE_API_MSG* pWaitReplyMessage) noexcept { try { LockConsole(); auto Unlock = wil::scope_exit([&] { UnlockConsole(); }); - waiter.reset(); - bytesRead = 0; if (buffer.size() < 1) @@ -504,17 +503,17 @@ NT_CATCH_RETURN() } else if (WI_IsFlagSet(inputBuffer.InputMode, ENABLE_LINE_INPUT)) { - return NTSTATUS_FROM_HRESULT(_ReadLineInput(inputBuffer, - processData, - buffer, - bytesRead, - controlKeyState, - initialData, - ctrlWakeupMask, - readHandleState, - exeName, - unicode, - waiter)); + return _ReadLineInput(inputBuffer, + processData, + buffer, + bytesRead, + controlKeyState, + initialData, + ctrlWakeupMask, + readHandleState, + exeName, + unicode, + pWaitReplyMessage); } else { @@ -525,7 +524,7 @@ NT_CATCH_RETURN() unicode); if (status == CONSOLE_STATUS_WAIT) { - waiter = std::make_unique(&inputBuffer, &readHandleState, gsl::narrow(buffer.size()), reinterpret_cast(buffer.data())); + std::ignore = ConsoleWaitQueue::s_CreateWait(pWaitReplyMessage, new RAW_READ_DATA(&inputBuffer, &readHandleState, gsl::narrow(buffer.size()), reinterpret_cast(buffer.data()))); } return status; } @@ -536,7 +535,7 @@ NT_CATCH_RETURN() [[nodiscard]] HRESULT ApiRoutines::ReadConsoleImpl(IConsoleInputObject& context, std::span buffer, size_t& written, - std::unique_ptr& waiter, + CONSOLE_API_MSG* pWaitReplyMessage, const std::wstring_view initialData, const std::wstring_view exeName, INPUT_READ_HANDLE_DATA& readHandleState, @@ -545,17 +544,17 @@ NT_CATCH_RETURN() const DWORD controlWakeupMask, DWORD& controlKeyState) noexcept { - return HRESULT_FROM_NT(DoReadConsole(context, - clientHandle, - buffer, - written, - controlKeyState, - initialData, - controlWakeupMask, - readHandleState, - exeName, - IsUnicode, - waiter)); + return DoReadConsole(context, + clientHandle, + buffer, + written, + controlKeyState, + initialData, + controlWakeupMask, + readHandleState, + exeName, + IsUnicode, + pWaitReplyMessage); } void UnblockWriteConsole(const DWORD dwReason) diff --git a/src/host/ut_host/ApiRoutinesTests.cpp b/src/host/ut_host/ApiRoutinesTests.cpp index 37e6ec62fd..bb2f95878b 100644 --- a/src/host/ut_host/ApiRoutinesTests.cpp +++ b/src/host/ut_host/ApiRoutinesTests.cpp @@ -372,46 +372,24 @@ class ApiRoutinesTests for (size_t i = 0; i < cchTestText; i += cchIncrement) { Log::Comment(WEX::Common::String().Format(L"Iteration %d of loop with increment %d", i, cchIncrement)); - if (fInduceWait) - { - Log::Comment(L"Blocking global output state to induce waits."); - s_AdjustOutputWait(true); - } + s_AdjustOutputWait(fInduceWait); size_t cchRead = 0; - std::unique_ptr waiter; // The increment is either the specified length or the remaining text in the string (if that is smaller). const auto cchWriteLength = std::min(cchIncrement, cchTestText - i); // Run the test method - const auto hr = _pApiRoutines->WriteConsoleAImpl(si, { pszTestText + i, cchWriteLength }, cchRead, waiter); + const auto hr = _pApiRoutines->WriteConsoleAImpl(si, { pszTestText + i, cchWriteLength }, cchRead, nullptr); - VERIFY_ARE_EQUAL(S_OK, hr, L"Successful result code from writing."); if (!fInduceWait) { - VERIFY_IS_NULL(waiter.get(), L"We should have no waiter for this case."); + VERIFY_ARE_EQUAL(S_OK, hr); VERIFY_ARE_EQUAL(cchWriteLength, cchRead, L"We should have the same character count back as 'written' that we gave in."); } else { - VERIFY_IS_NOT_NULL(waiter.get(), L"We should have a waiter for this case."); - // The cchRead is irrelevant at this point as it's not going to be returned until we're off the wait. - - Log::Comment(L"Unblocking global output state so the wait can be serviced."); - s_AdjustOutputWait(false); - Log::Comment(L"Dispatching the wait."); - auto Status = STATUS_SUCCESS; - size_t dwNumBytes = 0; - DWORD dwControlKeyState = 0; // unused but matches the pattern for read. - void* pOutputData = nullptr; // unused for writes but used for read. - const BOOL bNotifyResult = waiter->Notify(WaitTerminationReason::NoReason, FALSE, &Status, &dwNumBytes, &dwControlKeyState, &pOutputData); - - VERIFY_IS_TRUE(!!bNotifyResult, L"Wait completion on notify should be successful."); - VERIFY_ARE_EQUAL(STATUS_SUCCESS, Status, L"We should have a successful return code to pass to the caller."); - - const auto dwBytesExpected = cchWriteLength; - VERIFY_ARE_EQUAL(dwBytesExpected, dwNumBytes, L"We should have the byte length of the string we put in as the returned value."); + VERIFY_ARE_EQUAL(CONSOLE_STATUS_WAIT, hr); } } } @@ -431,43 +409,21 @@ class ApiRoutinesTests gci.LockConsole(); auto Unlock = wil::scope_exit([&] { gci.UnlockConsole(); }); - const std::wstring testText(L"Test text"); + const std::wstring_view testText(L"Test text"); - if (fInduceWait) - { - Log::Comment(L"Blocking global output state to induce waits."); - s_AdjustOutputWait(true); - } + s_AdjustOutputWait(fInduceWait); size_t cchRead = 0; - std::unique_ptr waiter; - const auto hr = _pApiRoutines->WriteConsoleWImpl(si, testText, cchRead, waiter); + const auto hr = _pApiRoutines->WriteConsoleWImpl(si, testText, cchRead, nullptr); - VERIFY_ARE_EQUAL(S_OK, hr, L"Successful result code from writing."); if (!fInduceWait) { - VERIFY_IS_NULL(waiter.get(), L"We should have no waiter for this case."); + VERIFY_ARE_EQUAL(S_OK, hr); VERIFY_ARE_EQUAL(testText.size(), cchRead, L"We should have the same character count back as 'written' that we gave in."); } else { - VERIFY_IS_NOT_NULL(waiter.get(), L"We should have a waiter for this case."); - // The cchRead is irrelevant at this point as it's not going to be returned until we're off the wait. - - Log::Comment(L"Unblocking global output state so the wait can be serviced."); - s_AdjustOutputWait(false); - Log::Comment(L"Dispatching the wait."); - auto Status = STATUS_SUCCESS; - size_t dwNumBytes = 0; - DWORD dwControlKeyState = 0; // unused but matches the pattern for read. - void* pOutputData = nullptr; // unused for writes but used for read. - const BOOL bNotifyResult = waiter->Notify(WaitTerminationReason::NoReason, TRUE, &Status, &dwNumBytes, &dwControlKeyState, &pOutputData); - - VERIFY_IS_TRUE(!!bNotifyResult, L"Wait completion on notify should be successful."); - VERIFY_ARE_EQUAL(STATUS_SUCCESS, Status, L"We should have a successful return code to pass to the caller."); - - const auto dwBytesExpected = testText.size() * sizeof(wchar_t); - VERIFY_ARE_EQUAL(dwBytesExpected, dwNumBytes, L"We should have the byte length of the string we put in as the returned value."); + VERIFY_ARE_EQUAL(CONSOLE_STATUS_WAIT, hr); } } diff --git a/src/host/ut_host/ScreenBufferTests.cpp b/src/host/ut_host/ScreenBufferTests.cpp index 31ec3a747f..77b0f303cf 100644 --- a/src/host/ut_host/ScreenBufferTests.cpp +++ b/src/host/ut_host/ScreenBufferTests.cpp @@ -2549,11 +2549,7 @@ void ScreenBufferTests::TestAltBufferVtDispatching() // We're going to write some data to either the main buffer or the alt // buffer, as if we were using the API. - std::unique_ptr waiter; - std::wstring seq = L"\x1b[5;6H"; - auto seqCb = 2 * seq.size(); - VERIFY_SUCCEEDED(DoWriteConsole(&seq[0], &seqCb, mainBuffer, waiter)); - + VERIFY_SUCCEEDED(DoWriteConsole(mainBuffer, L"\x1b[5;6H")); VERIFY_ARE_EQUAL(til::point(0, 0), mainCursor.GetPosition()); // recall: vt coordinates are (row, column), 1-indexed VERIFY_ARE_EQUAL(til::point(5, 4), altCursor.GetPosition()); @@ -2565,17 +2561,11 @@ void ScreenBufferTests::TestAltBufferVtDispatching() VERIFY_ARE_EQUAL(expectedDefaults, mainBuffer.GetAttributes()); VERIFY_ARE_EQUAL(expectedDefaults, alternate.GetAttributes()); - seq = L"\x1b[48;2;255;0;255m"; - seqCb = 2 * seq.size(); - VERIFY_SUCCEEDED(DoWriteConsole(&seq[0], &seqCb, mainBuffer, waiter)); - + VERIFY_SUCCEEDED(DoWriteConsole(mainBuffer, L"\x1b[48;2;255;0;255m")); VERIFY_ARE_EQUAL(expectedDefaults, mainBuffer.GetAttributes()); VERIFY_ARE_EQUAL(expectedRgb, alternate.GetAttributes()); - seq = L"X"; - seqCb = 2 * seq.size(); - VERIFY_SUCCEEDED(DoWriteConsole(&seq[0], &seqCb, mainBuffer, waiter)); - + VERIFY_SUCCEEDED(DoWriteConsole(mainBuffer, L"X")); VERIFY_ARE_EQUAL(til::point(0, 0), mainCursor.GetPosition()); VERIFY_ARE_EQUAL(til::point(6, 4), altCursor.GetPosition()); diff --git a/src/host/ut_host/TextBufferTests.cpp b/src/host/ut_host/TextBufferTests.cpp index 4252609e30..1ee380d70e 100644 --- a/src/host/ut_host/TextBufferTests.cpp +++ b/src/host/ut_host/TextBufferTests.cpp @@ -1390,22 +1390,18 @@ void TextBufferTests::TestBackspaceStringsAPI() // backspacing it with "\b \b". // Regardless of how we write those sequences of characters, the end result // should be the same. - std::unique_ptr waiter; Log::Comment(NoThrowString().Format( L"Using WriteCharsLegacy, write \\b \\b as a single string.")); - size_t aCb = 2; - size_t seqCb = 6; - VERIFY_SUCCEEDED(DoWriteConsole(L"a", &aCb, si, waiter)); - VERIFY_SUCCEEDED(DoWriteConsole(L"\b \b", &seqCb, si, waiter)); + VERIFY_SUCCEEDED(DoWriteConsole(si, L"a")); + VERIFY_SUCCEEDED(DoWriteConsole(si, L"\b \b")); VERIFY_ARE_EQUAL(cursor.GetPosition().x, x0); VERIFY_ARE_EQUAL(cursor.GetPosition().y, y0); - seqCb = 2; - VERIFY_SUCCEEDED(DoWriteConsole(L"a", &seqCb, si, waiter)); - VERIFY_SUCCEEDED(DoWriteConsole(L"\b", &seqCb, si, waiter)); - VERIFY_SUCCEEDED(DoWriteConsole(L" ", &seqCb, si, waiter)); - VERIFY_SUCCEEDED(DoWriteConsole(L"\b", &seqCb, si, waiter)); + VERIFY_SUCCEEDED(DoWriteConsole(si, L"a")); + VERIFY_SUCCEEDED(DoWriteConsole(si, L"\b")); + VERIFY_SUCCEEDED(DoWriteConsole(si, L" ")); + VERIFY_SUCCEEDED(DoWriteConsole(si, L"\b")); VERIFY_ARE_EQUAL(cursor.GetPosition().x, x0); VERIFY_ARE_EQUAL(cursor.GetPosition().y, y0); } diff --git a/src/host/ut_host/VtIoTests.cpp b/src/host/ut_host/VtIoTests.cpp index e557ed517d..6de23ae0fa 100644 --- a/src/host/ut_host/VtIoTests.cpp +++ b/src/host/ut_host/VtIoTests.cpp @@ -249,23 +249,22 @@ class ::Microsoft::Console::VirtualTerminal::VtIoTests resetContents(); size_t written; - std::unique_ptr waiter; std::string_view expected; std::string_view actual; - THROW_IF_FAILED(routines.WriteConsoleWImpl(*screenInfo, L"", written, waiter)); + THROW_IF_FAILED(routines.WriteConsoleWImpl(*screenInfo, L"", written, nullptr)); expected = ""; actual = readOutput(); VERIFY_ARE_EQUAL(expected, actual); // Force-wrap because we write up to the last column. - THROW_IF_FAILED(routines.WriteConsoleWImpl(*screenInfo, L"aaaaaaaa", written, waiter)); + THROW_IF_FAILED(routines.WriteConsoleWImpl(*screenInfo, L"aaaaaaaa", written, nullptr)); expected = "aaaaaaaa\r\n"; actual = readOutput(); VERIFY_ARE_EQUAL(expected, actual); // Force-wrap because we write up to the last column, but this time with a tab. - THROW_IF_FAILED(routines.WriteConsoleWImpl(*screenInfo, L"a\t\r\nb", written, waiter)); + THROW_IF_FAILED(routines.WriteConsoleWImpl(*screenInfo, L"a\t\r\nb", written, nullptr)); expected = "a\t\r\n\r\nb"; actual = readOutput(); VERIFY_ARE_EQUAL(expected, actual); diff --git a/src/host/writeData.cpp b/src/host/writeData.cpp index dcf49c54c0..6250155461 100644 --- a/src/host/writeData.cpp +++ b/src/host/writeData.cpp @@ -6,9 +6,10 @@ #include "_stream.h" #include "../types/inc/convert.hpp" - #include "../interactivity/inc/ServiceLocator.hpp" +using Microsoft::Console::Interactivity::ServiceLocator; + // Routine Description: // - Creates a new write data object for used in servicing write console requests // Arguments: @@ -22,31 +23,22 @@ // Return Value: // - THROW: Throws if space cannot be allocated to copy the given string WriteData::WriteData(SCREEN_INFORMATION& siContext, - _In_reads_bytes_(cbContext) PCWCHAR pwchContext, - const size_t cbContext, + std::wstring pwchContext, const UINT uiOutputCodepage) : IWaitRoutine(ReplyDataType::Write), _siContext(siContext), - _pwchContext(THROW_IF_NULL_ALLOC(reinterpret_cast(new byte[cbContext]))), - _cbContext(cbContext), + _pwchContext(std::move(pwchContext)), _uiOutputCodepage(uiOutputCodepage), _fLeadByteCaptured(false), _fLeadByteConsumed(false), _cchUtf8Consumed(0) { - memmove(_pwchContext, pwchContext, _cbContext); } // Routine Description: // - Destroys the write data object // - Frees the string copy we made on creation -WriteData::~WriteData() -{ - if (nullptr != _pwchContext) - { - delete[] _pwchContext; - } -} +WriteData::~WriteData() = default; // Routine Description: // - Stores some additional information about lead byte adjustments from the conversion @@ -102,7 +94,7 @@ bool WriteData::Notify(const WaitTerminationReason TerminationReason, _Out_ DWORD* const pControlKeyState, _Out_ void* const /*pOutputData*/) { - *pNumBytes = _cbContext; + *pNumBytes = 0; *pControlKeyState = 0; if (WI_IsFlagSet(TerminationReason, WaitTerminationReason::ThreadDying)) @@ -111,6 +103,12 @@ bool WriteData::Notify(const WaitTerminationReason TerminationReason, return true; } + auto& gci = ServiceLocator::LocateGlobals().getConsoleInformation(); + if (WI_IsAnyFlagSet(gci.Flags, (CONSOLE_SUSPENDED | CONSOLE_SELECTING | CONSOLE_SCROLLBAR_TRACKING))) + { + return false; + } + // if we get to here, this routine was called by the input // thread, which grabs the current console lock. @@ -119,20 +117,16 @@ bool WriteData::Notify(const WaitTerminationReason TerminationReason, FAIL_FAST_IF(!(Microsoft::Console::Interactivity::ServiceLocator::LocateGlobals().getConsoleInformation().IsConsoleLocked())); - std::unique_ptr waiter; - auto cbContext = _cbContext; - auto Status = DoWriteConsole(_pwchContext, - &cbContext, - _siContext, - waiter); + auto Status = DoWriteConsole(_siContext, _pwchContext); if (Status == CONSOLE_STATUS_WAIT) { // an extra waiter will be created by DoWriteConsole, but we're already a waiter so discard it. - waiter.reset(); return false; } + auto cbContext = _pwchContext.size(); + // There's extra work to do to correct the byte counts if the original call was an A-version call. // We always process and hold text in the waiter as W-version text, but the A call is expecting // a byte value in its own codepage of how much we have written in that codepage. @@ -140,10 +134,6 @@ bool WriteData::Notify(const WaitTerminationReason TerminationReason, { if (CP_UTF8 != _uiOutputCodepage) { - // At this level with WriteConsole, everything is byte counts, so change back to char counts for - // GetALengthFromW to work correctly. - const auto cchContext = cbContext / sizeof(wchar_t); - // For non-UTF-8 codepages, we need to back convert the amount consumed and then // correlate that with any lead bytes we may have kept for later or reintroduced // from previous calls. @@ -152,7 +142,7 @@ bool WriteData::Notify(const WaitTerminationReason TerminationReason, // Start by counting the number of A bytes we used in printing our W string to the screen. try { - cchTextBufferRead = GetALengthFromW(_uiOutputCodepage, { _pwchContext, cchContext }); + cchTextBufferRead = GetALengthFromW(_uiOutputCodepage, _pwchContext); } CATCH_LOG(); diff --git a/src/host/writeData.hpp b/src/host/writeData.hpp index 4609b98d37..30e2a7953a 100644 --- a/src/host/writeData.hpp +++ b/src/host/writeData.hpp @@ -25,8 +25,7 @@ class WriteData : public IWaitRoutine { public: WriteData(SCREEN_INFORMATION& siContext, - _In_reads_bytes_(cbContext) PCWCHAR pwchContext, - const size_t cbContext, + std::wstring pwchContext, const UINT uiOutputCodepage); ~WriteData(); @@ -45,8 +44,7 @@ public: private: SCREEN_INFORMATION& _siContext; - wchar_t* const _pwchContext; - const size_t _cbContext; + std::wstring _pwchContext; UINT const _uiOutputCodepage; bool _fLeadByteCaptured; bool _fLeadByteConsumed; diff --git a/src/server/ApiDispatchers.cpp b/src/server/ApiDispatchers.cpp index 6c99344e80..be9a2fae85 100644 --- a/src/server/ApiDispatchers.cpp +++ b/src/server/ApiDispatchers.cpp @@ -173,7 +173,6 @@ constexpr T saturate(auto val) const auto pInputReadHandleData = pHandleData->GetClientInput(); - std::unique_ptr waiter; InputEventQueue outEvents; auto hr = m->_pApiRoutines->GetConsoleInputImpl( *pInputBuffer, @@ -183,7 +182,7 @@ constexpr T saturate(auto val) a->Unicode, fIsPeek, fIsWaitAllowed, - waiter); + m); // We must return the number of records in the message payload (to alert the client) // as well as in the message headers (below in SetReplyInformation) to alert the driver. @@ -192,14 +191,10 @@ constexpr T saturate(auto val) size_t cbWritten; LOG_IF_FAILED(SizeTMult(outEvents.size(), sizeof(INPUT_RECORD), &cbWritten)); - if (waiter) + if (hr == CONSOLE_STATUS_WAIT) { - hr = ConsoleWaitQueue::s_CreateWait(m, waiter.release()); - if (SUCCEEDED(hr)) - { - *pbReplyPending = TRUE; - hr = CONSOLE_STATUS_WAIT; - } + hr = S_OK; + *pbReplyPending = TRUE; } else { @@ -290,14 +285,13 @@ constexpr T saturate(auto val) // across multiple calls when we are simulating a command prompt input line for the client application. const auto pInputReadHandleData = HandleData->GetClientInput(); - std::unique_ptr waiter; size_t cbWritten; const std::span outputBuffer(reinterpret_cast(pvBuffer), cbBufferSize); auto hr = m->_pApiRoutines->ReadConsoleImpl(*pInputBuffer, outputBuffer, cbWritten, // We must set the reply length in bytes. - waiter, + m, initialData, exeView, *pInputReadHandleData, @@ -308,15 +302,10 @@ constexpr T saturate(auto val) LOG_IF_FAILED(SizeTToULong(cbWritten, &a->NumBytes)); - if (nullptr != waiter.get()) + if (hr == CONSOLE_STATUS_WAIT) { - // If we received a waiter, we need to queue the wait and not reply. - hr = ConsoleWaitQueue::s_CreateWait(m, waiter.release()); - - if (SUCCEEDED(hr)) - { - *pbReplyPending = TRUE; - } + hr = S_OK; + *pbReplyPending = TRUE; } else { @@ -355,7 +344,6 @@ constexpr T saturate(auto val) ULONG cbBufferSize; RETURN_IF_FAILED(m->GetInputBuffer(&pvBuffer, &cbBufferSize)); - std::unique_ptr waiter; size_t cbRead; // We have to hold onto the HR from the call and return it. @@ -373,7 +361,7 @@ constexpr T saturate(auto val) TraceLoggingUInt32(a->NumBytes, "NumBytes"), TraceLoggingCountedWideString(buffer.data(), static_cast(buffer.size()), "Buffer")); - hr = m->_pApiRoutines->WriteConsoleWImpl(*pScreenInfo, buffer, cchInputRead, waiter); + hr = m->_pApiRoutines->WriteConsoleWImpl(*pScreenInfo, buffer, cchInputRead, m); // We must set the reply length in bytes. Convert back from characters. LOG_IF_FAILED(SizeTMult(cchInputRead, sizeof(wchar_t), &cbRead)); @@ -388,7 +376,7 @@ constexpr T saturate(auto val) TraceLoggingUInt32(a->NumBytes, "NumBytes"), TraceLoggingCountedString(buffer.data(), static_cast(buffer.size()), "Buffer")); - hr = m->_pApiRoutines->WriteConsoleAImpl(*pScreenInfo, buffer, cchInputRead, waiter); + hr = m->_pApiRoutines->WriteConsoleAImpl(*pScreenInfo, buffer, cchInputRead, m); // Reply length is already in bytes (chars), don't need to convert. cbRead = cchInputRead; @@ -397,14 +385,10 @@ constexpr T saturate(auto val) // We must return the byte length of the read data in the message. LOG_IF_FAILED(SizeTToULong(cbRead, &a->NumBytes)); - if (nullptr != waiter.get()) + if (hr == CONSOLE_STATUS_WAIT) { - // If we received a waiter, we need to queue the wait and not reply. - hr = ConsoleWaitQueue::s_CreateWait(m, waiter.release()); - if (SUCCEEDED(hr)) - { - *pbReplyPending = TRUE; - } + hr = S_OK; + *pbReplyPending = TRUE; } else { diff --git a/src/server/IApiRoutines.h b/src/server/IApiRoutines.h index 568b150968..eb53f347a3 100644 --- a/src/server/IApiRoutines.h +++ b/src/server/IApiRoutines.h @@ -27,6 +27,8 @@ typedef InputBuffer IConsoleInputObject; class INPUT_READ_HANDLE_DATA; +typedef struct _CONSOLE_API_MSG CONSOLE_API_MSG; + #include "IWaitRoutine.h" #include "../types/inc/IInputEvent.hpp" #include "../types/inc/viewport.hpp" @@ -64,12 +66,12 @@ public: const bool IsUnicode, const bool IsPeek, const bool IsWaitAllowed, - std::unique_ptr& waiter) noexcept = 0; + CONSOLE_API_MSG* pWaitReplyMessage) noexcept = 0; [[nodiscard]] virtual HRESULT ReadConsoleImpl(IConsoleInputObject& context, std::span buffer, size_t& written, - std::unique_ptr& waiter, + CONSOLE_API_MSG* pWaitReplyMessage, const std::wstring_view initialData, const std::wstring_view exeName, INPUT_READ_HANDLE_DATA& readHandleState, @@ -81,12 +83,12 @@ public: [[nodiscard]] virtual HRESULT WriteConsoleAImpl(IConsoleOutputObject& context, const std::string_view buffer, size_t& read, - std::unique_ptr& waiter) noexcept = 0; + CONSOLE_API_MSG* pWaitReplyMessage) noexcept = 0; [[nodiscard]] virtual HRESULT WriteConsoleWImpl(IConsoleOutputObject& context, const std::wstring_view buffer, size_t& read, - std::unique_ptr& waiter) noexcept = 0; + CONSOLE_API_MSG* pWaitReplyMessage) noexcept = 0; #pragma region Thread Creation Info [[nodiscard]] virtual HRESULT GetConsoleLangIdImpl(LANGID& langId) noexcept = 0; diff --git a/src/server/WaitBlock.cpp b/src/server/WaitBlock.cpp index b26dffcb58..a871a52fc6 100644 --- a/src/server/WaitBlock.cpp +++ b/src/server/WaitBlock.cpp @@ -86,6 +86,15 @@ ConsoleWaitBlock::~ConsoleWaitBlock() [[nodiscard]] HRESULT ConsoleWaitBlock::s_CreateWait(_Inout_ CONSOLE_API_MSG* const pWaitReplyMessage, _In_ IWaitRoutine* const pWaiter) { + if (!pWaitReplyMessage || !pWaiter) + { + if (pWaiter) + { + delete pWaiter; + } + return E_INVALIDARG; + } + const auto ProcessData = pWaitReplyMessage->GetProcessHandle(); FAIL_FAST_IF_NULL(ProcessData); diff --git a/src/server/WaitQueue.h b/src/server/WaitQueue.h index 929f364446..be8a871ade 100644 --- a/src/server/WaitQueue.h +++ b/src/server/WaitQueue.h @@ -40,7 +40,6 @@ public: [[nodiscard]] static HRESULT s_CreateWait(_Inout_ CONSOLE_API_MSG* const pWaitReplyMessage, _In_ IWaitRoutine* const pWaiter); -private: bool _NotifyBlock(_In_ ConsoleWaitBlock* pWaitBlock, const WaitTerminationReason TerminationReason);