Fix a major stdin wakeup race condition (#18816)

The conhost v2 rewrite from a decade ago introduced a race condition:
Previously, we would acquire and hold the global console lock while
servicing
a console API call. If the call cannot be completed a wait task is
enqueued,
while the lock is held. The v2 rewrite then split the project up into a
"server" and "host" component (which remain to this day). The "host"
would
hold the console lock, while the "server" was responsible for enqueueing
wait
tasks _outside of the console lock_. Without any form of
synchronization,
any operations on the waiter list would then of course introduce a race
condition. In conhost this primarily meant keyboard/mouse input, because
that
runs on the separate Win32 window thread. For Windows Terminal it
primarily
meant the VT input thread.

I do not know why this issue is so extremely noticeable specifically
when we
respond to DSC CPR requests, but I'm also not surprised: I suspect that
the
overall performance issues that conhost had for a long time, meant that
most
things it did were slower than allocating the wait task.
Now that both conhost and Windows Terminal became orders of magnitudes
faster
over the last few years, it probably just so happens that the DSC CPR
request
takes almost exactly as many cycles to complete as allocating the wait
task
does, hence perfectly reproducing the race condition.

There's also a slight chance that this is actually a regression from my
ConPTY
rewrite #17510, but I fail to see what that would be. Regardless of
that,
I'm 100% certain though, that this is a bug that has existed in v0.1.

Closes #18117
Closes #18800

## Validation Steps Performed
* See repro in #18800. In other words:
  * Continuously emit DSC CPR sequences
  * ...read the response from stdin
  * ...and print the response to stdout
  * Doesn't deadlock randomly anymore 
* Feature & Unit tests 
This commit is contained in:
Leonard Hecker 2025-04-23 22:27:21 +02:00 committed by GitHub
parent a8a47b9367
commit 2992421761
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 172 additions and 325 deletions

View File

@ -56,12 +56,12 @@ public:
const bool IsUnicode,
const bool IsPeek,
const bool IsWaitAllowed,
std::unique_ptr<IWaitRoutine>& waiter) noexcept override;
CONSOLE_API_MSG* pWaitReplyMessage) noexcept override;
[[nodiscard]] HRESULT ReadConsoleImpl(IConsoleInputObject& context,
std::span<char> buffer,
size_t& written,
std::unique_ptr<IWaitRoutine>& waiter,
CONSOLE_API_MSG* pWaitReplyMessage,
const std::wstring_view initialData,
const std::wstring_view exeName,
INPUT_READ_HANDLE_DATA& readHandleState,
@ -73,12 +73,12 @@ public:
[[nodiscard]] HRESULT WriteConsoleAImpl(IConsoleOutputObject& context,
const std::string_view buffer,
size_t& read,
std::unique_ptr<IWaitRoutine>& waiter) noexcept override;
CONSOLE_API_MSG* pWaitReplyMessage) noexcept override;
[[nodiscard]] HRESULT WriteConsoleWImpl(IConsoleOutputObject& context,
const std::wstring_view buffer,
size_t& read,
std::unique_ptr<IWaitRoutine>& waiter) noexcept override;
CONSOLE_API_MSG* pWaitReplyMessage) noexcept override;
#pragma region ThreadCreationInfo
[[nodiscard]] HRESULT GetConsoleLangIdImpl(LANGID& langId) noexcept override;

View File

@ -415,30 +415,9 @@ void WriteClearScreen(SCREEN_INFORMATION& screenInfo)
// - pwchBuffer - wide character text to be inserted into buffer
// - pcbBuffer - byte count of pwchBuffer on the way in, number of bytes consumed on the way out.
// - screenInfo - Screen Information class to write the text into at the current cursor position
// - ppWaiter - If writing to the console is blocked for whatever reason, this will be filled with a pointer to context
// that can be used by the server to resume the call at a later time.
// Return Value:
// - STATUS_SUCCESS if OK.
// - CONSOLE_STATUS_WAIT if we couldn't finish now and need to be called back later (see ppWaiter).
// - Or a suitable NTSTATUS format error code for memory/string/math failures.
[[nodiscard]] NTSTATUS DoWriteConsole(_In_reads_bytes_(*pcbBuffer) PCWCHAR pwchBuffer,
_Inout_ size_t* const pcbBuffer,
SCREEN_INFORMATION& screenInfo,
std::unique_ptr<WriteData>& waiter)
[[nodiscard]] HRESULT DoWriteConsole(SCREEN_INFORMATION& screenInfo, std::wstring_view str)
try
{
auto& gci = ServiceLocator::LocateGlobals().getConsoleInformation();
if (WI_IsAnyFlagSet(gci.Flags, (CONSOLE_SUSPENDED | CONSOLE_SELECTING | CONSOLE_SCROLLBAR_TRACKING)))
{
waiter = std::make_unique<WriteData>(screenInfo,
pwchBuffer,
*pcbBuffer,
gci.OutputCP);
return CONSOLE_STATUS_WAIT;
}
const std::wstring_view str{ pwchBuffer, *pcbBuffer / sizeof(WCHAR) };
if (WI_IsAnyFlagClear(screenInfo.OutputMode, ENABLE_VIRTUAL_TERMINAL_PROCESSING | ENABLE_PROCESSED_OUTPUT))
{
WriteCharsLegacy(screenInfo, str, nullptr);
@ -447,55 +426,9 @@ try
{
WriteCharsVT(screenInfo, str);
}
return STATUS_SUCCESS;
}
NT_CATCH_RETURN()
// Routine Description:
// - This method performs the actual work of attempting to write to the console, converting data types as necessary
// to adapt from the server types to the legacy internal host types.
// - It operates on Unicode data only. It's assumed the text is translated by this point.
// Arguments:
// - OutContext - the console output object to write the new text into
// - pwsTextBuffer - wide character text buffer provided by client application to insert
// - cchTextBufferLength - text buffer counted in characters
// - pcchTextBufferRead - character count of the number of characters we were able to insert before returning
// - ppWaiter - If we are blocked from writing now and need to wait, this is filled with contextual data for the server to restore the call later
// Return Value:
// - S_OK if successful.
// - S_OK if we need to wait (check if ppWaiter is not nullptr).
// - Or a suitable HRESULT code for math/string/memory failures.
[[nodiscard]] HRESULT WriteConsoleWImplHelper(IConsoleOutputObject& context,
const std::wstring_view buffer,
size_t& read,
std::unique_ptr<WriteData>& waiter) noexcept
{
try
{
// Set out variables in case we exit early.
read = 0;
waiter.reset();
// Convert characters to bytes to give to DoWriteConsole.
size_t cbTextBufferLength;
RETURN_IF_FAILED(SizeTMult(buffer.size(), sizeof(wchar_t), &cbTextBufferLength));
auto Status = DoWriteConsole(const_cast<wchar_t*>(buffer.data()), &cbTextBufferLength, context, waiter);
// Convert back from bytes to characters for the resulting string length written.
read = cbTextBufferLength / sizeof(wchar_t);
if (Status == CONSOLE_STATUS_WAIT)
{
FAIL_FAST_IF_NULL(waiter.get());
Status = STATUS_SUCCESS;
}
RETURN_NTSTATUS(Status);
}
CATCH_RETURN();
return S_OK;
}
CATCH_RETURN()
// Routine Description:
// - Writes non-Unicode formatted data into the given console output object.
@ -514,13 +447,12 @@ NT_CATCH_RETURN()
[[nodiscard]] HRESULT ApiRoutines::WriteConsoleAImpl(IConsoleOutputObject& context,
const std::string_view buffer,
size_t& read,
std::unique_ptr<IWaitRoutine>& waiter) noexcept
CONSOLE_API_MSG* pWaitReplyMessage) noexcept
{
try
{
// Ensure output variables are initialized.
read = 0;
waiter.reset();
if (buffer.empty())
{
@ -620,67 +552,63 @@ NT_CATCH_RETURN()
wstr.resize((dbcsLength + mbPtrLength) / sizeof(wchar_t));
}
// Hold the specific version of the waiter locally so we can tinker with it if we have to store additional context.
std::unique_ptr<WriteData> writeDataWaiter{};
// Make the W version of the call
size_t wcBufferWritten{};
const auto hr{ WriteConsoleWImplHelper(screenInfo, wstr, wcBufferWritten, writeDataWaiter) };
// If there is no waiter, process the byte count now.
if (nullptr == writeDataWaiter.get())
auto& gci = ServiceLocator::LocateGlobals().getConsoleInformation();
if (WI_IsAnyFlagSet(gci.Flags, (CONSOLE_SUSPENDED | CONSOLE_SELECTING | CONSOLE_SCROLLBAR_TRACKING)))
{
// Calculate how many bytes of the original A buffer were consumed in the W version of the call to satisfy mbBufferRead.
// For UTF-8 conversions, we've already returned this information above.
if (CP_UTF8 != codepage)
{
size_t mbBufferRead{};
const auto waiter = new WriteData(screenInfo, std::move(wstr), gci.OutputCP);
// Start by counting the number of A bytes we used in printing our W string to the screen.
try
{
mbBufferRead = GetALengthFromW(codepage, { wstr.data(), wcBufferWritten });
}
CATCH_LOG();
// If we captured a byte off the string this time around up above, it means we didn't feed
// it into the WriteConsoleW above, and therefore its consumption isn't accounted for
// in the count we just made. Add +1 to compensate.
if (leadByteCaptured)
{
mbBufferRead++;
}
// If we consumed an internally-stored lead byte this time around up above, it means that we
// fed a byte into WriteConsoleW that wasn't a part of this particular call's request.
// We need to -1 to compensate and tell the caller the right number of bytes consumed this request.
if (leadByteConsumed)
{
mbBufferRead--;
}
read = mbBufferRead;
}
}
else
{
// If there is a waiter, then we need to stow some additional information in the wait structure so
// we can synthesize the correct byte count later when the wait routine is triggered.
if (CP_UTF8 != codepage)
{
// For non-UTF8 codepages, save the lead byte captured/consumed data so we can +1 or -1 the final decoded count
// in the WaitData::Notify method later.
writeDataWaiter->SetLeadByteAdjustmentStatus(leadByteCaptured, leadByteConsumed);
waiter->SetLeadByteAdjustmentStatus(leadByteCaptured, leadByteConsumed);
}
else
{
// For UTF8 codepages, just remember the consumption count from the UTF-8 parser.
writeDataWaiter->SetUtf8ConsumedCharacters(read);
waiter->SetUtf8ConsumedCharacters(read);
}
std::ignore = ConsoleWaitQueue::s_CreateWait(pWaitReplyMessage, waiter);
return CONSOLE_STATUS_WAIT;
}
// Give back the waiter now that we're done with tinkering with it.
waiter.reset(writeDataWaiter.release());
// Make the W version of the call
const auto hr = DoWriteConsole(screenInfo, wstr);
// Calculate how many bytes of the original A buffer were consumed in the W version of the call to satisfy mbBufferRead.
// For UTF-8 conversions, we've already returned this information above.
if (CP_UTF8 != codepage)
{
size_t mbBufferRead{};
// Start by counting the number of A bytes we used in printing our W string to the screen.
try
{
mbBufferRead = GetALengthFromW(codepage, wstr);
}
CATCH_LOG();
// If we captured a byte off the string this time around up above, it means we didn't feed
// it into the WriteConsoleW above, and therefore its consumption isn't accounted for
// in the count we just made. Add +1 to compensate.
if (leadByteCaptured)
{
mbBufferRead++;
}
// If we consumed an internally-stored lead byte this time around up above, it means that we
// fed a byte into WriteConsoleW that wasn't a part of this particular call's request.
// We need to -1 to compensate and tell the caller the right number of bytes consumed this request.
if (leadByteConsumed)
{
mbBufferRead--;
}
read = mbBufferRead;
}
return hr;
}
@ -703,20 +631,24 @@ NT_CATCH_RETURN()
[[nodiscard]] HRESULT ApiRoutines::WriteConsoleWImpl(IConsoleOutputObject& context,
const std::wstring_view buffer,
size_t& read,
std::unique_ptr<IWaitRoutine>& waiter) noexcept
CONSOLE_API_MSG* pWaitReplyMessage) noexcept
{
try
{
LockConsole();
auto unlock = wil::scope_exit([&] { UnlockConsole(); });
std::unique_ptr<WriteData> writeDataWaiter;
RETURN_IF_FAILED(WriteConsoleWImplHelper(context.GetActiveBuffer(), buffer, read, writeDataWaiter));
auto& gci = ServiceLocator::LocateGlobals().getConsoleInformation();
if (WI_IsAnyFlagSet(gci.Flags, (CONSOLE_SUSPENDED | CONSOLE_SELECTING | CONSOLE_SCROLLBAR_TRACKING)))
{
std::ignore = ConsoleWaitQueue::s_CreateWait(pWaitReplyMessage, new WriteData(context, std::wstring{ buffer }, gci.OutputCP));
return CONSOLE_STATUS_WAIT;
}
// Transfer specific waiter pointer into the generic interface wrapper.
waiter.reset(writeDataWaiter.release());
return S_OK;
read = 0;
auto Status = DoWriteConsole(context, buffer);
read = buffer.size();
return Status;
}
CATCH_RETURN();
}

View File

@ -25,7 +25,4 @@ void WriteClearScreen(SCREEN_INFORMATION& screenInfo);
// NOTE: console lock must be held when calling this routine
// String has been translated to unicode at this point.
[[nodiscard]] NTSTATUS DoWriteConsole(_In_reads_bytes_(pcbBuffer) const wchar_t* pwchBuffer,
_Inout_ size_t* const pcbBuffer,
SCREEN_INFORMATION& screenInfo,
std::unique_ptr<WriteData>& waiter);
[[nodiscard]] HRESULT DoWriteConsole(SCREEN_INFORMATION& screenInfo, std::wstring_view str);

View File

@ -58,12 +58,10 @@ using Microsoft::Console::Interactivity::ServiceLocator;
const bool IsUnicode,
const bool IsPeek,
const bool IsWaitAllowed,
std::unique_ptr<IWaitRoutine>& waiter) noexcept
CONSOLE_API_MSG* pWaitReplyMessage) noexcept
{
try
{
waiter.reset();
if (eventReadCount == 0)
{
return STATUS_SUCCESS;
@ -83,9 +81,7 @@ using Microsoft::Console::Interactivity::ServiceLocator;
{
// If we're told to wait until later, move all of our context
// to the read data object and send it back up to the server.
waiter = std::make_unique<DirectReadData>(&inputBuffer,
&readHandleState,
eventReadCount);
std::ignore = ConsoleWaitQueue::s_CreateWait(pWaitReplyMessage, new DirectReadData(&inputBuffer, &readHandleState, eventReadCount));
}
return Status;
}

View File

@ -166,9 +166,9 @@ private:
MidiAudio _midiAudio;
};
#define CONSOLE_STATUS_WAIT 0xC0030001
#define CONSOLE_STATUS_READ_COMPLETE 0xC0030002
#define CONSOLE_STATUS_WAIT_NO_BLOCK 0xC0030003
#define CONSOLE_STATUS_WAIT ((HRESULT)0xC0030001)
#define CONSOLE_STATUS_READ_COMPLETE ((HRESULT)0xC0030002)
#define CONSOLE_STATUS_WAIT_NO_BLOCK ((HRESULT)0xC0030003)
#include "../server/ObjectHandle.h"

View File

@ -340,7 +340,7 @@ NT_CATCH_RETURN()
INPUT_READ_HANDLE_DATA& readHandleState,
const std::wstring_view exeName,
const bool unicode,
std::unique_ptr<IWaitRoutine>& waiter) noexcept
CONSOLE_API_MSG* pWaitReplyMessage) noexcept
{
auto& gci = ServiceLocator::LocateGlobals().getConsoleInformation();
RETURN_HR_IF(E_FAIL, !gci.HasActiveOutputBuffer());
@ -364,7 +364,8 @@ NT_CATCH_RETURN()
if (!cookedReadData->Read(unicode, bytesRead, controlKeyState))
{
// memory will be cleaned up by wait queue
waiter.reset(cookedReadData.release());
std::ignore = ConsoleWaitQueue::s_CreateWait(pWaitReplyMessage, cookedReadData.release());
return CONSOLE_STATUS_WAIT;
}
else
{
@ -468,25 +469,23 @@ NT_CATCH_RETURN()
// populated.
// - STATUS_SUCCESS on success
// - Other NSTATUS codes as necessary
[[nodiscard]] NTSTATUS DoReadConsole(InputBuffer& inputBuffer,
const HANDLE processData,
std::span<char> buffer,
size_t& bytesRead,
ULONG& controlKeyState,
const std::wstring_view initialData,
const DWORD ctrlWakeupMask,
INPUT_READ_HANDLE_DATA& readHandleState,
const std::wstring_view exeName,
const bool unicode,
std::unique_ptr<IWaitRoutine>& waiter) noexcept
[[nodiscard]] HRESULT DoReadConsole(InputBuffer& inputBuffer,
const HANDLE processData,
std::span<char> buffer,
size_t& bytesRead,
ULONG& controlKeyState,
const std::wstring_view initialData,
const DWORD ctrlWakeupMask,
INPUT_READ_HANDLE_DATA& readHandleState,
const std::wstring_view exeName,
const bool unicode,
CONSOLE_API_MSG* pWaitReplyMessage) noexcept
{
try
{
LockConsole();
auto Unlock = wil::scope_exit([&] { UnlockConsole(); });
waiter.reset();
bytesRead = 0;
if (buffer.size() < 1)
@ -504,17 +503,17 @@ NT_CATCH_RETURN()
}
else if (WI_IsFlagSet(inputBuffer.InputMode, ENABLE_LINE_INPUT))
{
return NTSTATUS_FROM_HRESULT(_ReadLineInput(inputBuffer,
processData,
buffer,
bytesRead,
controlKeyState,
initialData,
ctrlWakeupMask,
readHandleState,
exeName,
unicode,
waiter));
return _ReadLineInput(inputBuffer,
processData,
buffer,
bytesRead,
controlKeyState,
initialData,
ctrlWakeupMask,
readHandleState,
exeName,
unicode,
pWaitReplyMessage);
}
else
{
@ -525,7 +524,7 @@ NT_CATCH_RETURN()
unicode);
if (status == CONSOLE_STATUS_WAIT)
{
waiter = std::make_unique<RAW_READ_DATA>(&inputBuffer, &readHandleState, gsl::narrow<ULONG>(buffer.size()), reinterpret_cast<wchar_t*>(buffer.data()));
std::ignore = ConsoleWaitQueue::s_CreateWait(pWaitReplyMessage, new RAW_READ_DATA(&inputBuffer, &readHandleState, gsl::narrow<ULONG>(buffer.size()), reinterpret_cast<wchar_t*>(buffer.data())));
}
return status;
}
@ -536,7 +535,7 @@ NT_CATCH_RETURN()
[[nodiscard]] HRESULT ApiRoutines::ReadConsoleImpl(IConsoleInputObject& context,
std::span<char> buffer,
size_t& written,
std::unique_ptr<IWaitRoutine>& waiter,
CONSOLE_API_MSG* pWaitReplyMessage,
const std::wstring_view initialData,
const std::wstring_view exeName,
INPUT_READ_HANDLE_DATA& readHandleState,
@ -545,17 +544,17 @@ NT_CATCH_RETURN()
const DWORD controlWakeupMask,
DWORD& controlKeyState) noexcept
{
return HRESULT_FROM_NT(DoReadConsole(context,
clientHandle,
buffer,
written,
controlKeyState,
initialData,
controlWakeupMask,
readHandleState,
exeName,
IsUnicode,
waiter));
return DoReadConsole(context,
clientHandle,
buffer,
written,
controlKeyState,
initialData,
controlWakeupMask,
readHandleState,
exeName,
IsUnicode,
pWaitReplyMessage);
}
void UnblockWriteConsole(const DWORD dwReason)

View File

@ -372,46 +372,24 @@ class ApiRoutinesTests
for (size_t i = 0; i < cchTestText; i += cchIncrement)
{
Log::Comment(WEX::Common::String().Format(L"Iteration %d of loop with increment %d", i, cchIncrement));
if (fInduceWait)
{
Log::Comment(L"Blocking global output state to induce waits.");
s_AdjustOutputWait(true);
}
s_AdjustOutputWait(fInduceWait);
size_t cchRead = 0;
std::unique_ptr<IWaitRoutine> waiter;
// The increment is either the specified length or the remaining text in the string (if that is smaller).
const auto cchWriteLength = std::min(cchIncrement, cchTestText - i);
// Run the test method
const auto hr = _pApiRoutines->WriteConsoleAImpl(si, { pszTestText + i, cchWriteLength }, cchRead, waiter);
const auto hr = _pApiRoutines->WriteConsoleAImpl(si, { pszTestText + i, cchWriteLength }, cchRead, nullptr);
VERIFY_ARE_EQUAL(S_OK, hr, L"Successful result code from writing.");
if (!fInduceWait)
{
VERIFY_IS_NULL(waiter.get(), L"We should have no waiter for this case.");
VERIFY_ARE_EQUAL(S_OK, hr);
VERIFY_ARE_EQUAL(cchWriteLength, cchRead, L"We should have the same character count back as 'written' that we gave in.");
}
else
{
VERIFY_IS_NOT_NULL(waiter.get(), L"We should have a waiter for this case.");
// The cchRead is irrelevant at this point as it's not going to be returned until we're off the wait.
Log::Comment(L"Unblocking global output state so the wait can be serviced.");
s_AdjustOutputWait(false);
Log::Comment(L"Dispatching the wait.");
auto Status = STATUS_SUCCESS;
size_t dwNumBytes = 0;
DWORD dwControlKeyState = 0; // unused but matches the pattern for read.
void* pOutputData = nullptr; // unused for writes but used for read.
const BOOL bNotifyResult = waiter->Notify(WaitTerminationReason::NoReason, FALSE, &Status, &dwNumBytes, &dwControlKeyState, &pOutputData);
VERIFY_IS_TRUE(!!bNotifyResult, L"Wait completion on notify should be successful.");
VERIFY_ARE_EQUAL(STATUS_SUCCESS, Status, L"We should have a successful return code to pass to the caller.");
const auto dwBytesExpected = cchWriteLength;
VERIFY_ARE_EQUAL(dwBytesExpected, dwNumBytes, L"We should have the byte length of the string we put in as the returned value.");
VERIFY_ARE_EQUAL(CONSOLE_STATUS_WAIT, hr);
}
}
}
@ -431,43 +409,21 @@ class ApiRoutinesTests
gci.LockConsole();
auto Unlock = wil::scope_exit([&] { gci.UnlockConsole(); });
const std::wstring testText(L"Test text");
const std::wstring_view testText(L"Test text");
if (fInduceWait)
{
Log::Comment(L"Blocking global output state to induce waits.");
s_AdjustOutputWait(true);
}
s_AdjustOutputWait(fInduceWait);
size_t cchRead = 0;
std::unique_ptr<IWaitRoutine> waiter;
const auto hr = _pApiRoutines->WriteConsoleWImpl(si, testText, cchRead, waiter);
const auto hr = _pApiRoutines->WriteConsoleWImpl(si, testText, cchRead, nullptr);
VERIFY_ARE_EQUAL(S_OK, hr, L"Successful result code from writing.");
if (!fInduceWait)
{
VERIFY_IS_NULL(waiter.get(), L"We should have no waiter for this case.");
VERIFY_ARE_EQUAL(S_OK, hr);
VERIFY_ARE_EQUAL(testText.size(), cchRead, L"We should have the same character count back as 'written' that we gave in.");
}
else
{
VERIFY_IS_NOT_NULL(waiter.get(), L"We should have a waiter for this case.");
// The cchRead is irrelevant at this point as it's not going to be returned until we're off the wait.
Log::Comment(L"Unblocking global output state so the wait can be serviced.");
s_AdjustOutputWait(false);
Log::Comment(L"Dispatching the wait.");
auto Status = STATUS_SUCCESS;
size_t dwNumBytes = 0;
DWORD dwControlKeyState = 0; // unused but matches the pattern for read.
void* pOutputData = nullptr; // unused for writes but used for read.
const BOOL bNotifyResult = waiter->Notify(WaitTerminationReason::NoReason, TRUE, &Status, &dwNumBytes, &dwControlKeyState, &pOutputData);
VERIFY_IS_TRUE(!!bNotifyResult, L"Wait completion on notify should be successful.");
VERIFY_ARE_EQUAL(STATUS_SUCCESS, Status, L"We should have a successful return code to pass to the caller.");
const auto dwBytesExpected = testText.size() * sizeof(wchar_t);
VERIFY_ARE_EQUAL(dwBytesExpected, dwNumBytes, L"We should have the byte length of the string we put in as the returned value.");
VERIFY_ARE_EQUAL(CONSOLE_STATUS_WAIT, hr);
}
}

View File

@ -2549,11 +2549,7 @@ void ScreenBufferTests::TestAltBufferVtDispatching()
// We're going to write some data to either the main buffer or the alt
// buffer, as if we were using the API.
std::unique_ptr<WriteData> waiter;
std::wstring seq = L"\x1b[5;6H";
auto seqCb = 2 * seq.size();
VERIFY_SUCCEEDED(DoWriteConsole(&seq[0], &seqCb, mainBuffer, waiter));
VERIFY_SUCCEEDED(DoWriteConsole(mainBuffer, L"\x1b[5;6H"));
VERIFY_ARE_EQUAL(til::point(0, 0), mainCursor.GetPosition());
// recall: vt coordinates are (row, column), 1-indexed
VERIFY_ARE_EQUAL(til::point(5, 4), altCursor.GetPosition());
@ -2565,17 +2561,11 @@ void ScreenBufferTests::TestAltBufferVtDispatching()
VERIFY_ARE_EQUAL(expectedDefaults, mainBuffer.GetAttributes());
VERIFY_ARE_EQUAL(expectedDefaults, alternate.GetAttributes());
seq = L"\x1b[48;2;255;0;255m";
seqCb = 2 * seq.size();
VERIFY_SUCCEEDED(DoWriteConsole(&seq[0], &seqCb, mainBuffer, waiter));
VERIFY_SUCCEEDED(DoWriteConsole(mainBuffer, L"\x1b[48;2;255;0;255m"));
VERIFY_ARE_EQUAL(expectedDefaults, mainBuffer.GetAttributes());
VERIFY_ARE_EQUAL(expectedRgb, alternate.GetAttributes());
seq = L"X";
seqCb = 2 * seq.size();
VERIFY_SUCCEEDED(DoWriteConsole(&seq[0], &seqCb, mainBuffer, waiter));
VERIFY_SUCCEEDED(DoWriteConsole(mainBuffer, L"X"));
VERIFY_ARE_EQUAL(til::point(0, 0), mainCursor.GetPosition());
VERIFY_ARE_EQUAL(til::point(6, 4), altCursor.GetPosition());

View File

@ -1390,22 +1390,18 @@ void TextBufferTests::TestBackspaceStringsAPI()
// backspacing it with "\b \b".
// Regardless of how we write those sequences of characters, the end result
// should be the same.
std::unique_ptr<WriteData> waiter;
Log::Comment(NoThrowString().Format(
L"Using WriteCharsLegacy, write \\b \\b as a single string."));
size_t aCb = 2;
size_t seqCb = 6;
VERIFY_SUCCEEDED(DoWriteConsole(L"a", &aCb, si, waiter));
VERIFY_SUCCEEDED(DoWriteConsole(L"\b \b", &seqCb, si, waiter));
VERIFY_SUCCEEDED(DoWriteConsole(si, L"a"));
VERIFY_SUCCEEDED(DoWriteConsole(si, L"\b \b"));
VERIFY_ARE_EQUAL(cursor.GetPosition().x, x0);
VERIFY_ARE_EQUAL(cursor.GetPosition().y, y0);
seqCb = 2;
VERIFY_SUCCEEDED(DoWriteConsole(L"a", &seqCb, si, waiter));
VERIFY_SUCCEEDED(DoWriteConsole(L"\b", &seqCb, si, waiter));
VERIFY_SUCCEEDED(DoWriteConsole(L" ", &seqCb, si, waiter));
VERIFY_SUCCEEDED(DoWriteConsole(L"\b", &seqCb, si, waiter));
VERIFY_SUCCEEDED(DoWriteConsole(si, L"a"));
VERIFY_SUCCEEDED(DoWriteConsole(si, L"\b"));
VERIFY_SUCCEEDED(DoWriteConsole(si, L" "));
VERIFY_SUCCEEDED(DoWriteConsole(si, L"\b"));
VERIFY_ARE_EQUAL(cursor.GetPosition().x, x0);
VERIFY_ARE_EQUAL(cursor.GetPosition().y, y0);
}

View File

@ -249,23 +249,22 @@ class ::Microsoft::Console::VirtualTerminal::VtIoTests
resetContents();
size_t written;
std::unique_ptr<IWaitRoutine> waiter;
std::string_view expected;
std::string_view actual;
THROW_IF_FAILED(routines.WriteConsoleWImpl(*screenInfo, L"", written, waiter));
THROW_IF_FAILED(routines.WriteConsoleWImpl(*screenInfo, L"", written, nullptr));
expected = "";
actual = readOutput();
VERIFY_ARE_EQUAL(expected, actual);
// Force-wrap because we write up to the last column.
THROW_IF_FAILED(routines.WriteConsoleWImpl(*screenInfo, L"aaaaaaaa", written, waiter));
THROW_IF_FAILED(routines.WriteConsoleWImpl(*screenInfo, L"aaaaaaaa", written, nullptr));
expected = "aaaaaaaa\r\n";
actual = readOutput();
VERIFY_ARE_EQUAL(expected, actual);
// Force-wrap because we write up to the last column, but this time with a tab.
THROW_IF_FAILED(routines.WriteConsoleWImpl(*screenInfo, L"a\t\r\nb", written, waiter));
THROW_IF_FAILED(routines.WriteConsoleWImpl(*screenInfo, L"a\t\r\nb", written, nullptr));
expected = "a\t\r\n\r\nb";
actual = readOutput();
VERIFY_ARE_EQUAL(expected, actual);

View File

@ -6,9 +6,10 @@
#include "_stream.h"
#include "../types/inc/convert.hpp"
#include "../interactivity/inc/ServiceLocator.hpp"
using Microsoft::Console::Interactivity::ServiceLocator;
// Routine Description:
// - Creates a new write data object for used in servicing write console requests
// Arguments:
@ -22,31 +23,22 @@
// Return Value:
// - THROW: Throws if space cannot be allocated to copy the given string
WriteData::WriteData(SCREEN_INFORMATION& siContext,
_In_reads_bytes_(cbContext) PCWCHAR pwchContext,
const size_t cbContext,
std::wstring pwchContext,
const UINT uiOutputCodepage) :
IWaitRoutine(ReplyDataType::Write),
_siContext(siContext),
_pwchContext(THROW_IF_NULL_ALLOC(reinterpret_cast<wchar_t*>(new byte[cbContext]))),
_cbContext(cbContext),
_pwchContext(std::move(pwchContext)),
_uiOutputCodepage(uiOutputCodepage),
_fLeadByteCaptured(false),
_fLeadByteConsumed(false),
_cchUtf8Consumed(0)
{
memmove(_pwchContext, pwchContext, _cbContext);
}
// Routine Description:
// - Destroys the write data object
// - Frees the string copy we made on creation
WriteData::~WriteData()
{
if (nullptr != _pwchContext)
{
delete[] _pwchContext;
}
}
WriteData::~WriteData() = default;
// Routine Description:
// - Stores some additional information about lead byte adjustments from the conversion
@ -102,7 +94,7 @@ bool WriteData::Notify(const WaitTerminationReason TerminationReason,
_Out_ DWORD* const pControlKeyState,
_Out_ void* const /*pOutputData*/)
{
*pNumBytes = _cbContext;
*pNumBytes = 0;
*pControlKeyState = 0;
if (WI_IsFlagSet(TerminationReason, WaitTerminationReason::ThreadDying))
@ -111,6 +103,12 @@ bool WriteData::Notify(const WaitTerminationReason TerminationReason,
return true;
}
auto& gci = ServiceLocator::LocateGlobals().getConsoleInformation();
if (WI_IsAnyFlagSet(gci.Flags, (CONSOLE_SUSPENDED | CONSOLE_SELECTING | CONSOLE_SCROLLBAR_TRACKING)))
{
return false;
}
// if we get to here, this routine was called by the input
// thread, which grabs the current console lock.
@ -119,20 +117,16 @@ bool WriteData::Notify(const WaitTerminationReason TerminationReason,
FAIL_FAST_IF(!(Microsoft::Console::Interactivity::ServiceLocator::LocateGlobals().getConsoleInformation().IsConsoleLocked()));
std::unique_ptr<WriteData> waiter;
auto cbContext = _cbContext;
auto Status = DoWriteConsole(_pwchContext,
&cbContext,
_siContext,
waiter);
auto Status = DoWriteConsole(_siContext, _pwchContext);
if (Status == CONSOLE_STATUS_WAIT)
{
// an extra waiter will be created by DoWriteConsole, but we're already a waiter so discard it.
waiter.reset();
return false;
}
auto cbContext = _pwchContext.size();
// There's extra work to do to correct the byte counts if the original call was an A-version call.
// We always process and hold text in the waiter as W-version text, but the A call is expecting
// a byte value in its own codepage of how much we have written in that codepage.
@ -140,10 +134,6 @@ bool WriteData::Notify(const WaitTerminationReason TerminationReason,
{
if (CP_UTF8 != _uiOutputCodepage)
{
// At this level with WriteConsole, everything is byte counts, so change back to char counts for
// GetALengthFromW to work correctly.
const auto cchContext = cbContext / sizeof(wchar_t);
// For non-UTF-8 codepages, we need to back convert the amount consumed and then
// correlate that with any lead bytes we may have kept for later or reintroduced
// from previous calls.
@ -152,7 +142,7 @@ bool WriteData::Notify(const WaitTerminationReason TerminationReason,
// Start by counting the number of A bytes we used in printing our W string to the screen.
try
{
cchTextBufferRead = GetALengthFromW(_uiOutputCodepage, { _pwchContext, cchContext });
cchTextBufferRead = GetALengthFromW(_uiOutputCodepage, _pwchContext);
}
CATCH_LOG();

View File

@ -25,8 +25,7 @@ class WriteData : public IWaitRoutine
{
public:
WriteData(SCREEN_INFORMATION& siContext,
_In_reads_bytes_(cbContext) PCWCHAR pwchContext,
const size_t cbContext,
std::wstring pwchContext,
const UINT uiOutputCodepage);
~WriteData();
@ -45,8 +44,7 @@ public:
private:
SCREEN_INFORMATION& _siContext;
wchar_t* const _pwchContext;
const size_t _cbContext;
std::wstring _pwchContext;
UINT const _uiOutputCodepage;
bool _fLeadByteCaptured;
bool _fLeadByteConsumed;

View File

@ -173,7 +173,6 @@ constexpr T saturate(auto val)
const auto pInputReadHandleData = pHandleData->GetClientInput();
std::unique_ptr<IWaitRoutine> waiter;
InputEventQueue outEvents;
auto hr = m->_pApiRoutines->GetConsoleInputImpl(
*pInputBuffer,
@ -183,7 +182,7 @@ constexpr T saturate(auto val)
a->Unicode,
fIsPeek,
fIsWaitAllowed,
waiter);
m);
// We must return the number of records in the message payload (to alert the client)
// as well as in the message headers (below in SetReplyInformation) to alert the driver.
@ -192,14 +191,10 @@ constexpr T saturate(auto val)
size_t cbWritten;
LOG_IF_FAILED(SizeTMult(outEvents.size(), sizeof(INPUT_RECORD), &cbWritten));
if (waiter)
if (hr == CONSOLE_STATUS_WAIT)
{
hr = ConsoleWaitQueue::s_CreateWait(m, waiter.release());
if (SUCCEEDED(hr))
{
*pbReplyPending = TRUE;
hr = CONSOLE_STATUS_WAIT;
}
hr = S_OK;
*pbReplyPending = TRUE;
}
else
{
@ -290,14 +285,13 @@ constexpr T saturate(auto val)
// across multiple calls when we are simulating a command prompt input line for the client application.
const auto pInputReadHandleData = HandleData->GetClientInput();
std::unique_ptr<IWaitRoutine> waiter;
size_t cbWritten;
const std::span<char> outputBuffer(reinterpret_cast<char*>(pvBuffer), cbBufferSize);
auto hr = m->_pApiRoutines->ReadConsoleImpl(*pInputBuffer,
outputBuffer,
cbWritten, // We must set the reply length in bytes.
waiter,
m,
initialData,
exeView,
*pInputReadHandleData,
@ -308,15 +302,10 @@ constexpr T saturate(auto val)
LOG_IF_FAILED(SizeTToULong(cbWritten, &a->NumBytes));
if (nullptr != waiter.get())
if (hr == CONSOLE_STATUS_WAIT)
{
// If we received a waiter, we need to queue the wait and not reply.
hr = ConsoleWaitQueue::s_CreateWait(m, waiter.release());
if (SUCCEEDED(hr))
{
*pbReplyPending = TRUE;
}
hr = S_OK;
*pbReplyPending = TRUE;
}
else
{
@ -355,7 +344,6 @@ constexpr T saturate(auto val)
ULONG cbBufferSize;
RETURN_IF_FAILED(m->GetInputBuffer(&pvBuffer, &cbBufferSize));
std::unique_ptr<IWaitRoutine> waiter;
size_t cbRead;
// We have to hold onto the HR from the call and return it.
@ -373,7 +361,7 @@ constexpr T saturate(auto val)
TraceLoggingUInt32(a->NumBytes, "NumBytes"),
TraceLoggingCountedWideString(buffer.data(), static_cast<ULONG>(buffer.size()), "Buffer"));
hr = m->_pApiRoutines->WriteConsoleWImpl(*pScreenInfo, buffer, cchInputRead, waiter);
hr = m->_pApiRoutines->WriteConsoleWImpl(*pScreenInfo, buffer, cchInputRead, m);
// We must set the reply length in bytes. Convert back from characters.
LOG_IF_FAILED(SizeTMult(cchInputRead, sizeof(wchar_t), &cbRead));
@ -388,7 +376,7 @@ constexpr T saturate(auto val)
TraceLoggingUInt32(a->NumBytes, "NumBytes"),
TraceLoggingCountedString(buffer.data(), static_cast<ULONG>(buffer.size()), "Buffer"));
hr = m->_pApiRoutines->WriteConsoleAImpl(*pScreenInfo, buffer, cchInputRead, waiter);
hr = m->_pApiRoutines->WriteConsoleAImpl(*pScreenInfo, buffer, cchInputRead, m);
// Reply length is already in bytes (chars), don't need to convert.
cbRead = cchInputRead;
@ -397,14 +385,10 @@ constexpr T saturate(auto val)
// We must return the byte length of the read data in the message.
LOG_IF_FAILED(SizeTToULong(cbRead, &a->NumBytes));
if (nullptr != waiter.get())
if (hr == CONSOLE_STATUS_WAIT)
{
// If we received a waiter, we need to queue the wait and not reply.
hr = ConsoleWaitQueue::s_CreateWait(m, waiter.release());
if (SUCCEEDED(hr))
{
*pbReplyPending = TRUE;
}
hr = S_OK;
*pbReplyPending = TRUE;
}
else
{

View File

@ -27,6 +27,8 @@ typedef InputBuffer IConsoleInputObject;
class INPUT_READ_HANDLE_DATA;
typedef struct _CONSOLE_API_MSG CONSOLE_API_MSG;
#include "IWaitRoutine.h"
#include "../types/inc/IInputEvent.hpp"
#include "../types/inc/viewport.hpp"
@ -64,12 +66,12 @@ public:
const bool IsUnicode,
const bool IsPeek,
const bool IsWaitAllowed,
std::unique_ptr<IWaitRoutine>& waiter) noexcept = 0;
CONSOLE_API_MSG* pWaitReplyMessage) noexcept = 0;
[[nodiscard]] virtual HRESULT ReadConsoleImpl(IConsoleInputObject& context,
std::span<char> buffer,
size_t& written,
std::unique_ptr<IWaitRoutine>& waiter,
CONSOLE_API_MSG* pWaitReplyMessage,
const std::wstring_view initialData,
const std::wstring_view exeName,
INPUT_READ_HANDLE_DATA& readHandleState,
@ -81,12 +83,12 @@ public:
[[nodiscard]] virtual HRESULT WriteConsoleAImpl(IConsoleOutputObject& context,
const std::string_view buffer,
size_t& read,
std::unique_ptr<IWaitRoutine>& waiter) noexcept = 0;
CONSOLE_API_MSG* pWaitReplyMessage) noexcept = 0;
[[nodiscard]] virtual HRESULT WriteConsoleWImpl(IConsoleOutputObject& context,
const std::wstring_view buffer,
size_t& read,
std::unique_ptr<IWaitRoutine>& waiter) noexcept = 0;
CONSOLE_API_MSG* pWaitReplyMessage) noexcept = 0;
#pragma region Thread Creation Info
[[nodiscard]] virtual HRESULT GetConsoleLangIdImpl(LANGID& langId) noexcept = 0;

View File

@ -86,6 +86,15 @@ ConsoleWaitBlock::~ConsoleWaitBlock()
[[nodiscard]] HRESULT ConsoleWaitBlock::s_CreateWait(_Inout_ CONSOLE_API_MSG* const pWaitReplyMessage,
_In_ IWaitRoutine* const pWaiter)
{
if (!pWaitReplyMessage || !pWaiter)
{
if (pWaiter)
{
delete pWaiter;
}
return E_INVALIDARG;
}
const auto ProcessData = pWaitReplyMessage->GetProcessHandle();
FAIL_FAST_IF_NULL(ProcessData);

View File

@ -40,7 +40,6 @@ public:
[[nodiscard]] static HRESULT s_CreateWait(_Inout_ CONSOLE_API_MSG* const pWaitReplyMessage,
_In_ IWaitRoutine* const pWaiter);
private:
bool _NotifyBlock(_In_ ConsoleWaitBlock* pWaitBlock,
const WaitTerminationReason TerminationReason);