Fix various issues with the interactive shell (#13634)

* Save state

* Save state

* Save state

* Save state

* Save state

* Save state

* Save state

* Save state

* Save state

* Cleanup for review

* Update ServiceMain.cpp comment

* Remove duplicated definitions from wslservice.idl

* poc: Prototype interactive shell improvments

* Format

* Merge

* Save state

* Correctly configure terminal

* Format

* PR feedback
This commit is contained in:
Blue 2025-10-31 17:18:34 -07:00 committed by GitHub
parent 8b62d5a662
commit 1a6cbe560b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 539 additions and 401 deletions

View File

@ -159,12 +159,16 @@ void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_TTY_RELAY
{
THROW_LAST_ERROR_IF(fcntl(Message.TtyMaster, F_SETFL, O_NONBLOCK) < 0);
pollfd pollDescriptors[2];
wsl::shared::SocketChannel TerminalControlChannel({Message.TtyControl}, "TerminalControl");
pollfd pollDescriptors[3];
pollDescriptors[0].fd = Message.TtyInput;
pollDescriptors[0].events = POLLIN;
pollDescriptors[1].fd = Message.TtyMaster;
pollDescriptors[1].events = POLLIN;
pollDescriptors[2].fd = Message.TtyControl;
pollDescriptors[2].events = POLLIN;
std::vector<gsl::byte> pendingStdin;
std::vector<gsl::byte> buffer;
@ -269,6 +273,30 @@ void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_TTY_RELAY
pollDescriptors[1].fd = -1;
}
}
// Process message from the terminal control channel.
if (pollDescriptors[2].revents & (POLLIN | POLLHUP | POLLERR))
{
auto [ttyMessage, _] = TerminalControlChannel.ReceiveMessageOrClosed<WSLA_TERMINAL_CHANGED>();
//
// A zero-byte read means that the control channel has been closed
// and that the relay process should exit.
//
if (ttyMessage == nullptr)
{
break;
}
winsize terminal{};
terminal.ws_col = ttyMessage->Columns;
terminal.ws_row = ttyMessage->Rows;
if (ioctl(Message.TtyMaster, TIOCSWINSZ, &terminal))
{
LOG_ERROR("ioctl({}, TIOCSWINSZ) failed {}", Message.TtyMaster, errno);
}
}
}
// Shutdown sockets and tty

View File

@ -162,9 +162,10 @@ struct AbsolutePath
}
};
template <typename THandle = wil::unique_handle>
struct Handle
{
wil::unique_handle& output;
THandle& output;
int operator()(const TChar* input) const
{
@ -173,7 +174,14 @@ struct Handle
return -1;
}
output.reset(ULongToHandle(wcstoul(input, nullptr, 0)));
if constexpr (std::is_same_v<THandle, wil::unique_socket>)
{
output.reset(reinterpret_cast<SOCKET>(ULongToHandle(wcstoul(input, nullptr, 0))));
}
else
{
output.reset(ULongToHandle(wcstoul(input, nullptr, 0)));
}
return 1;
}

View File

@ -390,6 +390,7 @@ typedef enum _LX_MESSAGE_TYPE
LxMessageWSLAOpen,
LxMessageWSLAUnmount,
LxMessageWSLADetach,
LxMessageWSLATerminalChanged,
} LX_MESSAGE_TYPE,
*PLX_MESSAGE_TYPE;
@ -498,6 +499,7 @@ inline auto ToString(LX_MESSAGE_TYPE messageType)
X(LxMessageWSLAOpen)
X(LxMessageWSLAUnmount)
X(LxMessageWSLADetach)
X(LxMessageWSLATerminalChanged)
default:
return "<unexpected LX_MESSAGE_TYPE>";
}
@ -1658,8 +1660,11 @@ struct WSLA_TTY_RELAY
int32_t TtyMaster;
int32_t TtyInput;
int32_t TtyOutput;
int32_t TtyControl;
uint32_t Rows;
uint32_t Columns;
PRETTY_PRINT(FIELD(Header), FIELD(TtyMaster), FIELD(TtyInput), FIELD(TtyOutput));
PRETTY_PRINT(FIELD(Header), FIELD(TtyMaster), FIELD(TtyInput), FIELD(TtyOutput), FIELD(TtyControl), FIELD(Rows), FIELD(Columns));
};
struct WSLA_ACCEPT
@ -1830,6 +1835,19 @@ struct WSLA_DETACH
PRETTY_PRINT(FIELD(Header), FIELD(Lun));
};
struct WSLA_TERMINAL_CHANGED
{
DECLARE_MESSAGE_CTOR(WSLA_TERMINAL_CHANGED);
static inline auto Type = LxMessageWSLATerminalChanged;
MESSAGE_HEADER Header;
unsigned short Rows;
unsigned short Columns;
PRETTY_PRINT(FIELD(Header), FIELD(Rows), FIELD(Columns));
};
typedef struct _LX_MINI_INIT_IMPORT_RESULT
{
static inline auto Type = LxMiniInitMessageImportResult;

View File

@ -1521,10 +1521,10 @@ int RunDebugShell()
THROW_IF_WIN32_BOOL_FALSE(WriteFile(pipe.get(), "\n", 1, nullptr, nullptr));
// Create a thread to relay stdin to the pipe.
wsl::windows::common::SvcCommIo Io;
auto exitEvent = wil::unique_event(wil::EventOptions::ManualReset);
std::thread inputThread(
[&]() { wsl::windows::common::RelayStandardInput(GetStdHandle(STD_INPUT_HANDLE), pipe.get(), {}, exitEvent.get(), &Io); });
std::thread inputThread([&]() {
wsl::windows::common::relay::StandardInputRelay(GetStdHandle(STD_INPUT_HANDLE), pipe.get(), []() {}, exitEvent.get());
});
auto joinThread = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() {
exitEvent.SetEvent();

View File

@ -377,6 +377,360 @@ void wsl::windows::common::relay::BidirectionalRelay(_In_ HANDLE LeftHandle, _In
}
}
#define TTY_ALT_NUMPAD_VK_MENU (0x12)
#define TTY_ESCAPE_CHARACTER (L'\x1b')
#define TTY_INPUT_EVENT_BUFFER_SIZE (16)
#define TTY_UTF8_TRANSLATION_BUFFER_SIZE (4 * TTY_INPUT_EVENT_BUFFER_SIZE)
BOOL IsActionableKey(_In_ PKEY_EVENT_RECORD KeyEvent)
{
//
// This is a bit complicated to discern.
//
// 1. Our first check is that we only want structures that
// represent at least one key press. If we have 0, then we don't
// need to bother. If we have >1, we'll send the key through
// that many times into the pipe.
// 2. Our second check is where it gets confusing.
// a. Characters that are non-null get an automatic pass. Copy
// them through to the pipe.
// b. Null characters need further scrutiny. We generally do not
// pass nulls through EXCEPT if they're sourced from the
// virtual terminal engine (or another application living
// above our layer). If they're sourced by a non-keyboard
// source, they'll have no scan code (since they didn't come
// from a keyboard). But that rule has an exception too:
// "Enhanced keys" from above the standard range of scan
// codes will return 0 also with a special flag set that says
// they're an enhanced key. That means the desired behavior
// is:
// Scan Code = 0, ENHANCED_KEY = 0
// -> This came from the VT engine or another app
// above our layer.
// Scan Code = 0, ENHANCED_KEY = 1
// -> This came from the keyboard, but is a special
// key like 'Volume Up' that wasn't generally a
// part of historic (pre-1990s) keyboards.
// Scan Code = <anything else>
// -> This came from a keyboard directly.
//
if ((KeyEvent->wRepeatCount == 0) || ((KeyEvent->uChar.UnicodeChar == UNICODE_NULL) &&
((KeyEvent->wVirtualScanCode != 0) || (WI_IsFlagSet(KeyEvent->dwControlKeyState, ENHANCED_KEY)))))
{
return FALSE;
}
return TRUE;
}
BOOL GetNextCharacter(_In_ INPUT_RECORD* InputRecord, _Out_ PWCHAR NextCharacter)
{
BOOL IsNextCharacterValid = FALSE;
if (InputRecord->EventType == KEY_EVENT)
{
const auto KeyEvent = &InputRecord->Event.KeyEvent;
if ((IsActionableKey(KeyEvent) != FALSE) && ((KeyEvent->bKeyDown != FALSE) || (KeyEvent->wVirtualKeyCode == TTY_ALT_NUMPAD_VK_MENU)))
{
*NextCharacter = KeyEvent->uChar.UnicodeChar;
IsNextCharacterValid = TRUE;
}
}
return IsNextCharacterValid;
}
void wsl::windows::common::relay::StandardInputRelay(HANDLE ConsoleHandle, HANDLE OutputHandle, const std::function<void()>& UpdateTerminalSize, HANDLE ExitEvent)
try
{
if (GetFileType(ConsoleHandle) != FILE_TYPE_CHAR)
{
wsl::windows::common::relay::InterruptableRelay(ConsoleHandle, OutputHandle, ExitEvent);
return;
}
//
// N.B. ReadConsoleInputEx has no associated import library.
//
static LxssDynamicFunction<decltype(ReadConsoleInputExW)> readConsoleInput(L"Kernel32.dll", "ReadConsoleInputExW");
INPUT_RECORD InputRecordBuffer[TTY_INPUT_EVENT_BUFFER_SIZE];
INPUT_RECORD* InputRecordPeek = &(InputRecordBuffer[1]);
KEY_EVENT_RECORD* KeyEvent;
DWORD RecordsRead;
OVERLAPPED Overlapped = {0};
const wil::unique_event OverlappedEvent(wil::EventOptions::ManualReset);
Overlapped.hEvent = OverlappedEvent.get();
const HANDLE WaitHandles[] = {ExitEvent, ConsoleHandle};
const std::vector<HANDLE> ExitHandles = {ExitEvent};
for (;;)
{
//
// Because some input events generated by the console are encoded with
// more than one input event, we have to be smart about reading the
// events.
//
// First, we peek at the next input event.
// If it's an escape (wch == L'\x1b') event, then the characters that
// follow are part of an input sequence. We can't know for sure
// how long that sequence is, but we can assume it's all sent to
// the input queue at once, and it's less that 16 events.
// Furthermore, we can assume that if there's an Escape in those
// 16 events, that the escape marks the start of a new sequence.
// So, we'll peek at another 15 events looking for escapes.
// If we see an escape, then we'll read one less than that,
// such that the escape remains the next event in the input.
// From those read events, we'll aggregate chars into a single
// string to send to the subsystem.
// If it's not an escape, send the event through one at a time.
//
//
// Read one input event.
//
DWORD WaitStatus = (WAIT_OBJECT_0 + 1);
do
{
THROW_IF_WIN32_BOOL_FALSE(readConsoleInput(ConsoleHandle, InputRecordBuffer, 1, &RecordsRead, CONSOLE_READ_NOWAIT));
if (RecordsRead == 0)
{
WaitStatus = WaitForMultipleObjects(RTL_NUMBER_OF(WaitHandles), WaitHandles, false, INFINITE);
}
} while ((WaitStatus == (WAIT_OBJECT_0 + 1)) && (RecordsRead == 0));
//
// Stop processing if the exit event has been signaled.
//
if (WaitStatus != (WAIT_OBJECT_0 + 1))
{
WI_ASSERT(WaitStatus == WAIT_OBJECT_0);
break;
}
WI_ASSERT(RecordsRead == 1);
//
// Don't read additional records if the first entry is a window size
// event, or a repeated character. Handle those events on their own.
//
DWORD RecordsPeeked = 0;
if ((InputRecordBuffer[0].EventType != WINDOW_BUFFER_SIZE_EVENT) &&
((InputRecordBuffer[0].EventType != KEY_EVENT) || (InputRecordBuffer[0].Event.KeyEvent.wRepeatCount < 2)))
{
//
// Read additional input records into the buffer if available.
//
THROW_IF_WIN32_BOOL_FALSE(PeekConsoleInputW(ConsoleHandle, InputRecordPeek, (RTL_NUMBER_OF(InputRecordBuffer) - 1), &RecordsPeeked));
}
//
// Iterate over peeked records [1, RecordsPeeked].
//
DWORD AdditionalRecordsToRead = 0;
WCHAR NextCharacter;
for (DWORD RecordIndex = 1; RecordIndex <= RecordsPeeked; RecordIndex++)
{
if (GetNextCharacter(&InputRecordBuffer[RecordIndex], &NextCharacter) != FALSE)
{
KeyEvent = &InputRecordBuffer[RecordIndex].Event.KeyEvent;
if (NextCharacter == TTY_ESCAPE_CHARACTER)
{
//
// CurrentRecord is an escape event. We will start here
// on the next input loop.
//
break;
}
else if (KeyEvent->wRepeatCount > 1)
{
//
// Repeated keys are handled on their own. Start with this
// key on the next input loop.
//
break;
}
else if (IS_HIGH_SURROGATE(NextCharacter) && (RecordIndex >= (RecordsPeeked - 1)))
{
//
// If there is not enough room for the second character of
// a surrogate pair, start with this character on the next
// input loop.
//
// N.B. The test is for at least two remaining records
// because typically a surrogate pair will be entered
// via copy/paste, which will appear as an input
// record with alt-down, alt-up and character. So to
// include the next character of the surrogate pair it
// is likely that the alt-up record will need to be
// read first.
//
break;
}
}
else if (InputRecordBuffer[RecordIndex].EventType == WINDOW_BUFFER_SIZE_EVENT)
{
//
// A window size event is handled on its own.
//
break;
}
//
// Process the additional input record.
//
AdditionalRecordsToRead += 1;
}
if (AdditionalRecordsToRead > 0)
{
THROW_IF_WIN32_BOOL_FALSE(readConsoleInput(ConsoleHandle, InputRecordPeek, AdditionalRecordsToRead, &RecordsRead, CONSOLE_READ_NOWAIT));
if (RecordsRead == 0)
{
//
// This would be an unexpected case. We've already peeked to see
// that there are AdditionalRecordsToRead # of records in the
// input that need reading, yet we didn't get them when we read.
// In this case, move along and finish this input event.
//
break;
}
//
// We already had one input record in the buffer before reading
// additional, So account for that one too
//
RecordsRead += 1;
}
//
// Process each input event. Keydowns will get aggregated into
// Utf8String before getting injected into the subsystem.
//
WCHAR Utf16String[TTY_INPUT_EVENT_BUFFER_SIZE];
ULONG Utf16StringSize = 0;
COORD WindowSize{};
for (DWORD RecordIndex = 0; RecordIndex < RecordsRead; RecordIndex++)
{
INPUT_RECORD* CurrentInputRecord = &(InputRecordBuffer[RecordIndex]);
switch (CurrentInputRecord->EventType)
{
case KEY_EVENT:
//
// Filter out key up events unless they are from an <Alt> key.
// Key up with an <Alt> key could contain a Unicode character
// pasted from the clipboard and converted to an <Alt>+<Numpad> sequence.
//
KeyEvent = &CurrentInputRecord->Event.KeyEvent;
if ((KeyEvent->bKeyDown == FALSE) && (KeyEvent->wVirtualKeyCode != TTY_ALT_NUMPAD_VK_MENU))
{
break;
}
//
// Filter out key presses that are not actionable, such as just
// pressing <Ctrl>, <Alt>, <Shift> etc. These key presses return
// the character of null but will have a valid scan code off the
// keyboard. Certain other key sequences such as Ctrl+A,
// Ctrl+<space>, and Ctrl+@ will also return the character null
// but have no scan code.
// <Alt> + <NumPad> sequences will show an <Alt> but will have
// a scancode and character specified, so they should be actionable.
//
if (IsActionableKey(KeyEvent) == FALSE)
{
break;
}
Utf16String[Utf16StringSize] = KeyEvent->uChar.UnicodeChar;
Utf16StringSize += 1;
break;
case WINDOW_BUFFER_SIZE_EVENT:
//
// Query the window size and send an update message via the
// control channel.
//
UpdateTerminalSize();
break;
}
}
CHAR Utf8String[TTY_UTF8_TRANSLATION_BUFFER_SIZE];
DWORD Utf8StringSize = 0;
if (Utf16StringSize > 0)
{
//
// Windows uses UTF-16LE encoding, Linux uses UTF-8 by default.
// Convert each UTF-16LE character into the proper UTF-8 byte
// sequence equivalent.
//
THROW_LAST_ERROR_IF(
(Utf8StringSize = WideCharToMultiByte(
CP_UTF8, 0, Utf16String, Utf16StringSize, Utf8String, sizeof(Utf8String), nullptr, nullptr)) == 0);
}
//
// Send the input bytes to the terminal.
//
DWORD BytesWritten = 0;
const auto Utf8Span = gslhelpers::struct_as_bytes(Utf8String).first(Utf8StringSize);
if ((RecordsRead == 1) && (InputRecordBuffer[0].EventType == KEY_EVENT) && (InputRecordBuffer[0].Event.KeyEvent.wRepeatCount > 1))
{
WI_ASSERT(Utf16StringSize == 1);
//
// Handle repeated characters. They aren't part of an input
// sequence, so there's only one event that's generating characters.
//
WORD RepeatIndex;
for (RepeatIndex = 0; RepeatIndex < InputRecordBuffer[0].Event.KeyEvent.wRepeatCount; RepeatIndex += 1)
{
BytesWritten = wsl::windows::common::relay::InterruptableWrite(OutputHandle, Utf8Span, ExitHandles, &Overlapped);
if (BytesWritten == 0)
{
break;
}
}
}
else if (Utf8StringSize > 0)
{
BytesWritten = wsl::windows::common::relay::InterruptableWrite(OutputHandle, Utf8Span, ExitHandles, &Overlapped);
if (BytesWritten == 0)
{
break;
}
}
}
return;
}
CATCH_LOG()
void wsl::windows::common::relay::SocketRelay(_In_ SOCKET LeftSocket, _In_ SOCKET RightSocket, _In_ size_t BufferSize)
{
constexpr RelayFlags flags = RelayFlags::LeftIsSocket | RelayFlags::RightIsSocket;

View File

@ -43,6 +43,8 @@ bool InterruptableWait(_In_ HANDLE WaitObject, _In_ const std::vector<HANDLE>& E
DWORD
InterruptableWrite(_In_ HANDLE OutputHandle, _In_ gsl::span<const gsl::byte> Buffer, _In_ const std::vector<HANDLE>& ExitHandles, _In_ LPOVERLAPPED Overlapped);
void StandardInputRelay(HANDLE ConsoleHandle, HANDLE OutputHandle, const std::function<void()>& UpdateTerminalSize, HANDLE ExitEvent);
enum class RelayFlags
{
None = 0,

View File

@ -29,11 +29,6 @@ Abstract:
#define IS_VALID_HANDLE(_handle) ((_handle != NULL) && (_handle != INVALID_HANDLE_VALUE))
#define TTY_ALT_NUMPAD_VK_MENU (0x12)
#define TTY_ESCAPE_CHARACTER (L'\x1b')
#define TTY_INPUT_EVENT_BUFFER_SIZE (16)
#define TTY_UTF8_TRANSLATION_BUFFER_SIZE (4 * TTY_INPUT_EVENT_BUFFER_SIZE)
using wsl::windows::common::ClientExecutionContext;
namespace {
@ -112,64 +107,6 @@ struct CreateProcessArguments
std::wstring NtPath{};
};
BOOL GetNextCharacter(_In_ INPUT_RECORD* InputRecord, _Out_ PWCHAR NextCharacter)
{
BOOL IsNextCharacterValid = FALSE;
if (InputRecord->EventType == KEY_EVENT)
{
const auto KeyEvent = &InputRecord->Event.KeyEvent;
if ((IsActionableKey(KeyEvent) != FALSE) && ((KeyEvent->bKeyDown != FALSE) || (KeyEvent->wVirtualKeyCode == TTY_ALT_NUMPAD_VK_MENU)))
{
*NextCharacter = KeyEvent->uChar.UnicodeChar;
IsNextCharacterValid = TRUE;
}
}
return IsNextCharacterValid;
}
BOOL IsActionableKey(_In_ PKEY_EVENT_RECORD KeyEvent)
{
//
// This is a bit complicated to discern.
//
// 1. Our first check is that we only want structures that
// represent at least one key press. If we have 0, then we don't
// need to bother. If we have >1, we'll send the key through
// that many times into the pipe.
// 2. Our second check is where it gets confusing.
// a. Characters that are non-null get an automatic pass. Copy
// them through to the pipe.
// b. Null characters need further scrutiny. We generally do not
// pass nulls through EXCEPT if they're sourced from the
// virtual terminal engine (or another application living
// above our layer). If they're sourced by a non-keyboard
// source, they'll have no scan code (since they didn't come
// from a keyboard). But that rule has an exception too:
// "Enhanced keys" from above the standard range of scan
// codes will return 0 also with a special flag set that says
// they're an enhanced key. That means the desired behavior
// is:
// Scan Code = 0, ENHANCED_KEY = 0
// -> This came from the VT engine or another app
// above our layer.
// Scan Code = 0, ENHANCED_KEY = 1
// -> This came from the keyboard, but is a special
// key like 'Volume Up' that wasn't generally a
// part of historic (pre-1990s) keyboards.
// Scan Code = <anything else>
// -> This came from a keyboard directly.
//
if ((KeyEvent->wRepeatCount == 0) || ((KeyEvent->uChar.UnicodeChar == UNICODE_NULL) &&
((KeyEvent->wVirtualScanCode != 0) || (WI_IsFlagSet(KeyEvent->dwControlKeyState, ENHANCED_KEY)))))
{
return FALSE;
}
return TRUE;
}
void InitializeInterop(_In_ HANDLE ServerPort, _In_ const GUID& DistroId)
{
//
@ -211,316 +148,6 @@ void SpawnWslHost(_In_ HANDLE ServerPort, _In_ const GUID& DistroId, _In_opt_ LP
// Exported function definitions.
//
void wsl::windows::common::RelayStandardInput(
HANDLE ConsoleHandle,
HANDLE OutputHandle,
const std::shared_ptr<wsl::shared::SocketChannel>& ControlChannel,
HANDLE ExitEvent,
wsl::windows::common::SvcCommIo* Io)
try
{
if (GetFileType(ConsoleHandle) != FILE_TYPE_CHAR)
{
wsl::windows::common::relay::InterruptableRelay(ConsoleHandle, OutputHandle, ExitEvent);
return;
}
//
// N.B. ReadConsoleInputEx has no associated import library.
//
static LxssDynamicFunction<decltype(ReadConsoleInputExW)> readConsoleInput(L"Kernel32.dll", "ReadConsoleInputExW");
INPUT_RECORD InputRecordBuffer[TTY_INPUT_EVENT_BUFFER_SIZE];
INPUT_RECORD* InputRecordPeek = &(InputRecordBuffer[1]);
KEY_EVENT_RECORD* KeyEvent;
DWORD RecordsRead;
OVERLAPPED Overlapped = {0};
const wil::unique_event OverlappedEvent(wil::EventOptions::ManualReset);
Overlapped.hEvent = OverlappedEvent.get();
const HANDLE WaitHandles[] = {ExitEvent, ConsoleHandle};
const std::vector<HANDLE> ExitHandles = {ExitEvent};
for (;;)
{
//
// Because some input events generated by the console are encoded with
// more than one input event, we have to be smart about reading the
// events.
//
// First, we peek at the next input event.
// If it's an escape (wch == L'\x1b') event, then the characters that
// follow are part of an input sequence. We can't know for sure
// how long that sequence is, but we can assume it's all sent to
// the input queue at once, and it's less that 16 events.
// Furthermore, we can assume that if there's an Escape in those
// 16 events, that the escape marks the start of a new sequence.
// So, we'll peek at another 15 events looking for escapes.
// If we see an escape, then we'll read one less than that,
// such that the escape remains the next event in the input.
// From those read events, we'll aggregate chars into a single
// string to send to the subsystem.
// If it's not an escape, send the event through one at a time.
//
//
// Read one input event.
//
DWORD WaitStatus = (WAIT_OBJECT_0 + 1);
do
{
THROW_IF_WIN32_BOOL_FALSE(readConsoleInput(ConsoleHandle, InputRecordBuffer, 1, &RecordsRead, CONSOLE_READ_NOWAIT));
if (RecordsRead == 0)
{
WaitStatus = WaitForMultipleObjects(RTL_NUMBER_OF(WaitHandles), WaitHandles, false, INFINITE);
}
} while ((WaitStatus == (WAIT_OBJECT_0 + 1)) && (RecordsRead == 0));
//
// Stop processing if the exit event has been signaled.
//
if (WaitStatus != (WAIT_OBJECT_0 + 1))
{
WI_ASSERT(WaitStatus == WAIT_OBJECT_0);
break;
}
WI_ASSERT(RecordsRead == 1);
//
// Don't read additional records if the first entry is a window size
// event, or a repeated character. Handle those events on their own.
//
DWORD RecordsPeeked = 0;
if ((InputRecordBuffer[0].EventType != WINDOW_BUFFER_SIZE_EVENT) &&
((InputRecordBuffer[0].EventType != KEY_EVENT) || (InputRecordBuffer[0].Event.KeyEvent.wRepeatCount < 2)))
{
//
// Read additional input records into the buffer if available.
//
THROW_IF_WIN32_BOOL_FALSE(PeekConsoleInputW(ConsoleHandle, InputRecordPeek, (RTL_NUMBER_OF(InputRecordBuffer) - 1), &RecordsPeeked));
}
//
// Iterate over peeked records [1, RecordsPeeked].
//
DWORD AdditionalRecordsToRead = 0;
WCHAR NextCharacter;
for (DWORD RecordIndex = 1; RecordIndex <= RecordsPeeked; RecordIndex++)
{
if (GetNextCharacter(&InputRecordBuffer[RecordIndex], &NextCharacter) != FALSE)
{
KeyEvent = &InputRecordBuffer[RecordIndex].Event.KeyEvent;
if (NextCharacter == TTY_ESCAPE_CHARACTER)
{
//
// CurrentRecord is an escape event. We will start here
// on the next input loop.
//
break;
}
else if (KeyEvent->wRepeatCount > 1)
{
//
// Repeated keys are handled on their own. Start with this
// key on the next input loop.
//
break;
}
else if (IS_HIGH_SURROGATE(NextCharacter) && (RecordIndex >= (RecordsPeeked - 1)))
{
//
// If there is not enough room for the second character of
// a surrogate pair, start with this character on the next
// input loop.
//
// N.B. The test is for at least two remaining records
// because typically a surrogate pair will be entered
// via copy/paste, which will appear as an input
// record with alt-down, alt-up and character. So to
// include the next character of the surrogate pair it
// is likely that the alt-up record will need to be
// read first.
//
break;
}
}
else if (InputRecordBuffer[RecordIndex].EventType == WINDOW_BUFFER_SIZE_EVENT)
{
//
// A window size event is handled on its own.
//
break;
}
//
// Process the additional input record.
//
AdditionalRecordsToRead += 1;
}
if (AdditionalRecordsToRead > 0)
{
THROW_IF_WIN32_BOOL_FALSE(readConsoleInput(ConsoleHandle, InputRecordPeek, AdditionalRecordsToRead, &RecordsRead, CONSOLE_READ_NOWAIT));
if (RecordsRead == 0)
{
//
// This would be an unexpected case. We've already peeked to see
// that there are AdditionalRecordsToRead # of records in the
// input that need reading, yet we didn't get them when we read.
// In this case, move along and finish this input event.
//
break;
}
//
// We already had one input record in the buffer before reading
// additional, So account for that one too
//
RecordsRead += 1;
}
//
// Process each input event. Keydowns will get aggregated into
// Utf8String before getting injected into the subsystem.
//
WCHAR Utf16String[TTY_INPUT_EVENT_BUFFER_SIZE];
ULONG Utf16StringSize = 0;
COORD WindowSize{};
LX_INIT_WINDOW_SIZE_CHANGED WindowSizeMessage{};
for (DWORD RecordIndex = 0; RecordIndex < RecordsRead; RecordIndex++)
{
INPUT_RECORD* CurrentInputRecord = &(InputRecordBuffer[RecordIndex]);
switch (CurrentInputRecord->EventType)
{
case KEY_EVENT:
//
// Filter out key up events unless they are from an <Alt> key.
// Key up with an <Alt> key could contain a Unicode character
// pasted from the clipboard and converted to an <Alt>+<Numpad> sequence.
//
KeyEvent = &CurrentInputRecord->Event.KeyEvent;
if ((KeyEvent->bKeyDown == FALSE) && (KeyEvent->wVirtualKeyCode != TTY_ALT_NUMPAD_VK_MENU))
{
break;
}
//
// Filter out key presses that are not actionable, such as just
// pressing <Ctrl>, <Alt>, <Shift> etc. These key presses return
// the character of null but will have a valid scan code off the
// keyboard. Certain other key sequences such as Ctrl+A,
// Ctrl+<space>, and Ctrl+@ will also return the character null
// but have no scan code.
// <Alt> + <NumPad> sequences will show an <Alt> but will have
// a scancode and character specified, so they should be actionable.
//
if (IsActionableKey(KeyEvent) == FALSE)
{
break;
}
Utf16String[Utf16StringSize] = KeyEvent->uChar.UnicodeChar;
Utf16StringSize += 1;
break;
case WINDOW_BUFFER_SIZE_EVENT:
//
// Query the window size and send an update message via the
// control channel.
//
if (ControlChannel)
{
WindowSize = Io->GetWindowSize();
WindowSizeMessage.Header.MessageType = LxInitMessageWindowSizeChanged;
WindowSizeMessage.Header.MessageSize = sizeof(WindowSizeMessage);
WindowSizeMessage.Columns = WindowSize.X;
WindowSizeMessage.Rows = WindowSize.Y;
try
{
ControlChannel->SendMessage(WindowSizeMessage);
}
CATCH_LOG();
}
break;
}
}
CHAR Utf8String[TTY_UTF8_TRANSLATION_BUFFER_SIZE];
DWORD Utf8StringSize = 0;
if (Utf16StringSize > 0)
{
//
// Windows uses UTF-16LE encoding, Linux uses UTF-8 by default.
// Convert each UTF-16LE character into the proper UTF-8 byte
// sequence equivalent.
//
THROW_LAST_ERROR_IF(
(Utf8StringSize = WideCharToMultiByte(
CP_UTF8, 0, Utf16String, Utf16StringSize, Utf8String, sizeof(Utf8String), nullptr, nullptr)) == 0);
}
//
// Send the input bytes to the terminal.
//
DWORD BytesWritten = 0;
const auto Utf8Span = gslhelpers::struct_as_bytes(Utf8String).first(Utf8StringSize);
if ((RecordsRead == 1) && (InputRecordBuffer[0].EventType == KEY_EVENT) && (InputRecordBuffer[0].Event.KeyEvent.wRepeatCount > 1))
{
WI_ASSERT(Utf16StringSize == 1);
//
// Handle repeated characters. They aren't part of an input
// sequence, so there's only one event that's generating characters.
//
WORD RepeatIndex;
for (RepeatIndex = 0; RepeatIndex < InputRecordBuffer[0].Event.KeyEvent.wRepeatCount; RepeatIndex += 1)
{
BytesWritten = wsl::windows::common::relay::InterruptableWrite(OutputHandle, Utf8Span, ExitHandles, &Overlapped);
if (BytesWritten == 0)
{
break;
}
}
}
else if (Utf8StringSize > 0)
{
BytesWritten = wsl::windows::common::relay::InterruptableWrite(OutputHandle, Utf8Span, ExitHandles, &Overlapped);
if (BytesWritten == 0)
{
break;
}
}
}
return;
}
CATCH_LOG()
wsl::windows::common::SvcComm::SvcComm()
{
// Ensure that the OS has support for running lifted WSL. This interface is always present on Windows 11 and later.
@ -808,7 +435,30 @@ wsl::windows::common::SvcComm::LaunchProcess(
if (IS_VALID_HANDLE(StdIn))
{
std::thread([StdIn, StdInSocket = std::move(StdInSocket), ControlChannel = ControlChannel, ExitHandle = ExitEvent.get(), Io = &Io]() mutable {
RelayStandardInput(StdIn, StdInSocket.get(), ControlChannel, ExitHandle, Io);
auto updateTerminal = [&]() {
//
// Query the window size and send an update message via the
// control channel.
//
if (ControlChannel)
{
auto WindowSize = Io->GetWindowSize();
LX_INIT_WINDOW_SIZE_CHANGED WindowSizeMessage{};
WindowSizeMessage.Header.MessageType = LxInitMessageWindowSizeChanged;
WindowSizeMessage.Header.MessageSize = sizeof(WindowSizeMessage);
WindowSizeMessage.Columns = WindowSize.X;
WindowSizeMessage.Rows = WindowSize.Y;
try
{
ControlChannel->SendMessage(WindowSizeMessage);
}
CATCH_LOG();
}
};
wsl::windows::common::relay::StandardInputRelay(StdIn, StdInSocket.get(), updateTerminal, ExitHandle);
}).detach();
}

View File

@ -21,8 +21,6 @@ Abstract:
namespace wsl::windows::common {
void RelayStandardInput(HANDLE ConsoleHandle, HANDLE OutputHandle, const std::shared_ptr<wsl::shared::SocketChannel>& ControlChannel, HANDLE ExitEvent, SvcCommIo* Io);
class SvcComm
{
public:

View File

@ -36,4 +36,5 @@ LPCWSTR const port_option = L"--port";
LPCWSTR const disable_telemetry_option = L"--disable-telemetry";
LPCWSTR const input_option = L"--input";
LPCWSTR const output_option = L"--output";
LPCWSTR const control_option = L"--control";
} // namespace wslrelay

View File

@ -209,27 +209,34 @@ HRESULT WslUnmapPort(WslVirtualMachineHandle VirtualMachine, const WslPortMappin
->MapPort(UserSettings->AddressFamily, UserSettings->WindowsPort, UserSettings->LinuxPort, true);
}
HRESULT WslLaunchInteractiveTerminal(HANDLE Input, HANDLE Output, HANDLE* Process)
HRESULT WslLaunchInteractiveTerminal(HANDLE Input, HANDLE Output, HANDLE Control, HANDLE* Process)
try
{
wsl::windows::common::helpers::SetHandleInheritable(Input);
wsl::windows::common::helpers::SetHandleInheritable(Output);
wsl::windows::common::helpers::SetHandleInheritable(Control);
auto basePath = wsl::windows::common::wslutil::GetMsiPackagePath();
THROW_HR_IF(E_UNEXPECTED, !basePath.has_value());
auto commandLine = std::format(
L"{}/wslrelay.exe --mode {} --input {} --output {}",
L"{}/wslrelay.exe {} {} {} {} {} {} {} {}",
basePath.value(),
wslrelay::mode_option,
static_cast<int>(wslrelay::RelayMode::InteractiveConsoleRelay),
wslrelay::input_option,
HandleToULong(Input),
HandleToULong(Output));
wslrelay::output_option,
HandleToULong(Output),
wslrelay::control_option,
HandleToUlong(Control));
WSL_LOG("LaunchWslRelay", TraceLoggingValue(commandLine.c_str(), "cmd"));
wsl::windows::common::SubProcess process{nullptr, commandLine.c_str()};
process.InheritHandle(Input);
process.InheritHandle(Output);
process.InheritHandle(Control);
process.SetFlags(CREATE_NEW_CONSOLE);
process.SetShowWindow(SW_SHOW);
*Process = process.Start().release();

View File

@ -131,6 +131,7 @@ enum WslFdType
WslFdTypeLinuxFileOutput = 8,
WslFdTypeLinuxFileAppend = 16,
WslFdTypeLinuxFileCreate = 32,
WslFdTypeTerminalControl = 64,
};
struct WslProcessFileDescriptorSettings
@ -174,7 +175,7 @@ struct WslPortMappingSettings
int AddressFamily;
};
HRESULT WslLaunchInteractiveTerminal(HANDLE Input, HANDLE Output, HANDLE* Process);
HRESULT WslLaunchInteractiveTerminal(HANDLE Input, HANDLE Output, HANDLE TerminalControl, HANDLE* Process);
HRESULT WslWaitForLinuxProcess(WslVirtualMachineHandle VirtualMachine, int32_t Pid, uint64_t TimeoutMs, struct WslWaitResult* Result);

View File

@ -638,15 +638,16 @@ std::vector<wil::unique_socket> WSLAVirtualMachine::CreateLinuxProcessImpl(
// Check if this is a tty or not
const WSLA_PROCESS_FD* ttyInput = nullptr;
const WSLA_PROCESS_FD* ttyOutput = nullptr;
auto interactiveTty = ParseTtyInformation(Fds, FdCount, &ttyInput, &ttyOutput);
const WSLA_PROCESS_FD* ttyControl = nullptr;
auto interactiveTty = ParseTtyInformation(Fds, FdCount, &ttyInput, &ttyOutput, &ttyControl);
auto [pid, _, childChannel] = Fork(WSLA_FORK::Process);
std::vector<wil::unique_socket> sockets(FdCount);
for (size_t i = 0; i < FdCount; i++)
{
if (Fds[i].Type == WslFdTypeDefault || Fds[i].Type == WslFdTypeTerminalInput || Fds[i].Type == WslFdTypeTerminalOutput)
if (Fds[i].Type == WslFdTypeDefault || Fds[i].Type == WslFdTypeTerminalInput || Fds[i].Type == WslFdTypeTerminalOutput ||
Fds[i].Type == WslFdTypeTerminalControl)
{
THROW_HR_IF_MSG(E_INVALIDARG, Fds[i].Type > WslFdTypeTerminalOutput, "Invalid flags: %i", Fds[i].Type);
THROW_HR_IF_MSG(E_INVALIDARG, Fds[i].Path != nullptr, "Fd[%zu] has a non-null path but flags: %i", i, Fds[i].Type);
sockets[i] = ConnectSocket(childChannel, static_cast<int32_t>(Fds[i].Fd));
}
@ -654,7 +655,7 @@ std::vector<wil::unique_socket> WSLAVirtualMachine::CreateLinuxProcessImpl(
{
THROW_HR_IF_MSG(
E_INVALIDARG,
WI_IsAnyFlagSet(Fds[i].Type, WslFdTypeTerminalInput | WslFdTypeTerminalOutput),
WI_IsAnyFlagSet(Fds[i].Type, WslFdTypeTerminalInput | WslFdTypeTerminalOutput | WslFdTypeTerminalControl),
"Invalid flags: %i",
Fds[i].Type);
@ -674,10 +675,11 @@ std::vector<wil::unique_socket> WSLAVirtualMachine::CreateLinuxProcessImpl(
if (interactiveTty)
{
auto [grandChildPid, ptyMaster, grandChildChannel] = Fork(childChannel, WSLA_FORK::Pty);
WSLA_TTY_RELAY relayMessage;
WSLA_TTY_RELAY relayMessage{};
relayMessage.TtyMaster = ptyMaster;
relayMessage.TtyInput = ttyInput->Fd;
relayMessage.TtyOutput = ttyOutput->Fd;
relayMessage.TtyControl = ttyControl == nullptr ? -1 : ttyControl->Fd;
childChannel.SendMessage(relayMessage);
auto result = ExpectClosedChannelOrError(childChannel);
@ -829,7 +831,8 @@ try
}
CATCH_RETURN();
bool WSLAVirtualMachine::ParseTtyInformation(const WSLA_PROCESS_FD* Fds, ULONG FdCount, const WSLA_PROCESS_FD** TtyInput, const WSLA_PROCESS_FD** TtyOutput)
bool WSLAVirtualMachine::ParseTtyInformation(
const WSLA_PROCESS_FD* Fds, ULONG FdCount, const WSLA_PROCESS_FD** TtyInput, const WSLA_PROCESS_FD** TtyOutput, const WSLA_PROCESS_FD** TtyControl)
{
bool foundNonTtyFd = false;
@ -846,6 +849,11 @@ bool WSLAVirtualMachine::ParseTtyInformation(const WSLA_PROCESS_FD* Fds, ULONG F
THROW_HR_IF_MSG(E_INVALIDARG, *TtyOutput != nullptr, "Only one TtyOutput fd can be passed. Index=%lu", i);
*TtyOutput = &Fds[i];
}
else if (Fds[i].Type == WslFdTypeTerminalControl)
{
THROW_HR_IF_MSG(E_INVALIDARG, *TtyControl != nullptr, "Only one TtyOutput fd can be passed. Index=%lu", i);
*TtyControl = &Fds[i];
}
else
{
foundNonTtyFd = true;

View File

@ -52,7 +52,8 @@ public:
private:
static int32_t MountImpl(wsl::shared::SocketChannel& Channel, LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags);
static void CALLBACK s_OnExit(_In_ HCS_EVENT* Event, _In_opt_ void* Context);
static bool ParseTtyInformation(const WSLA_PROCESS_FD* Fds, ULONG FdCount, const WSLA_PROCESS_FD** TtyInput, const WSLA_PROCESS_FD** TtyOutput);
static bool ParseTtyInformation(
const WSLA_PROCESS_FD* Fds, ULONG FdCount, const WSLA_PROCESS_FD** TtyInput, const WSLA_PROCESS_FD** TtyOutput, const WSLA_PROCESS_FD** TtyControl);
void ConfigureNetworking();
void OnExit(_In_ const HCS_EVENT* Event);

View File

@ -40,6 +40,7 @@ try
wil::unique_handle exitEvent{};
wil::unique_handle terminalInputHandle{};
wil::unique_handle terminalOutputHandle{};
wil::unique_socket terminalControlHandle{};
uint32_t port{};
GUID vmId{};
bool disableTelemetry = !wsl::shared::OfficialBuild;
@ -54,6 +55,7 @@ try
parser.AddArgument(disableTelemetry, wslrelay::disable_telemetry_option);
parser.AddArgument(Handle{terminalInputHandle}, wslrelay::input_option);
parser.AddArgument(Handle{terminalOutputHandle}, wslrelay::output_option);
parser.AddArgument(Handle<wil::unique_socket>{terminalControlHandle}, wslrelay::control_option);
parser.Parse();
// Initialize logging.
@ -145,15 +147,68 @@ try
case wslrelay::RelayMode::InteractiveConsoleRelay:
{
AllocConsole();
THROW_HR_IF(E_INVALIDARG, !terminalInputHandle || !terminalOutputHandle);
AllocConsole();
auto consoleOutputHandle = wil::unique_handle{CreateFileW(
L"CONOUT$", GENERIC_READ | GENERIC_WRITE, FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, nullptr, OPEN_EXISTING, 0, nullptr)};
THROW_LAST_ERROR_IF(!consoleOutputHandle.is_valid());
auto consoleInputHandle = wil::unique_handle{CreateFileW(
L"CONIN$", GENERIC_READ | GENERIC_WRITE, FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, nullptr, OPEN_EXISTING, 0, nullptr)};
THROW_LAST_ERROR_IF(!consoleInputHandle.is_valid());
// Configure console for interactive usage.
{
DWORD OutputMode{};
THROW_LAST_ERROR_IF(!::GetConsoleMode(consoleOutputHandle.get(), &OutputMode));
WI_SetAllFlags(OutputMode, ENABLE_PROCESSED_OUTPUT | ENABLE_VIRTUAL_TERMINAL_PROCESSING | DISABLE_NEWLINE_AUTO_RETURN);
THROW_IF_WIN32_BOOL_FALSE(SetConsoleMode(consoleOutputHandle.get(), OutputMode));
}
{
DWORD InputMode{};
THROW_LAST_ERROR_IF(!::GetConsoleMode(consoleInputHandle.get(), &InputMode));
WI_SetAllFlags(InputMode, (ENABLE_WINDOW_INPUT | ENABLE_VIRTUAL_TERMINAL_INPUT));
WI_ClearAllFlags(InputMode, (ENABLE_ECHO_INPUT | ENABLE_INSERT_MODE | ENABLE_LINE_INPUT | ENABLE_PROCESSED_INPUT));
THROW_IF_WIN32_BOOL_FALSE(SetConsoleMode(consoleInputHandle.get(), InputMode));
}
THROW_LAST_ERROR_IF(!::SetConsoleOutputCP(CP_UTF8));
// Create a thread to relay stdin to the pipe.
wsl::windows::common::SvcCommIo Io;
auto exitEvent = wil::unique_event(wil::EventOptions::ManualReset);
std::optional<wsl::shared::SocketChannel> controlChannel;
if (terminalControlHandle)
{
controlChannel.emplace(std::move(terminalControlHandle), "TerminalControl", exitEvent.get());
}
std::thread inputThread([&]() {
wsl::windows::common::RelayStandardInput(GetStdHandle(STD_INPUT_HANDLE), terminalInputHandle.get(), {}, exitEvent.get(), &Io);
auto updateTerminal = [&controlChannel, &consoleOutputHandle]() {
if (controlChannel.has_value())
{
CONSOLE_SCREEN_BUFFER_INFOEX info{};
info.cbSize = sizeof(info);
THROW_IF_WIN32_BOOL_FALSE(GetConsoleScreenBufferInfoEx(consoleOutputHandle.get(), &info));
WSLA_TERMINAL_CHANGED message{};
message.Columns = info.srWindow.Right - info.srWindow.Left + 1;
message.Rows = info.srWindow.Bottom - info.srWindow.Top + 1;
controlChannel->SendMessage(message);
}
};
wsl::windows::common::relay::StandardInputRelay(
GetStdHandle(STD_INPUT_HANDLE), terminalInputHandle.get(), updateTerminal, exitEvent.get());
});
auto joinThread = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() {

View File

@ -441,17 +441,21 @@ class WSLATests
auto vm = CreateVm(&settings);
std::vector<const char*> commandLine{"/bin/sh", nullptr};
std::vector<WslProcessFileDescriptorSettings> fds(2);
std::vector<WslProcessFileDescriptorSettings> fds(3);
fds[0].Number = 0;
fds[0].Type = WslFdTypeTerminalInput;
fds[1].Number = 1;
fds[1].Type = WslFdTypeTerminalOutput;
fds[2].Number = 2;
fds[2].Type = WslFdTypeTerminalControl;
const char* env[] = {"TERM=xterm-256color", nullptr};
WslCreateProcessSettings WslCreateProcessSettings{};
WslCreateProcessSettings.Executable = "/bin/sh";
WslCreateProcessSettings.Arguments = commandLine.data();
WslCreateProcessSettings.FileDescriptors = fds.data();
WslCreateProcessSettings.FdCount = static_cast<ULONG>(fds.size());
WslCreateProcessSettings.Environment = env;
int pid = -1;
VERIFY_SUCCEEDED(WslCreateLinuxProcess(vm.get(), &WslCreateProcessSettings, &pid));
@ -487,11 +491,14 @@ class WSLATests
// Validate that the interactive process successfully starts
wil::unique_handle process;
VERIFY_SUCCEEDED(WslLaunchInteractiveTerminal(
WslCreateProcessSettings.FileDescriptors[0].Handle, WslCreateProcessSettings.FileDescriptors[1].Handle, &process));
WslCreateProcessSettings.FileDescriptors[0].Handle,
WslCreateProcessSettings.FileDescriptors[1].Handle,
WslCreateProcessSettings.FileDescriptors[2].Handle,
&process));
// Exit the shell
writeTty("exit\n");
VERIFY_ARE_EQUAL(WaitForSingleObject(process.get(), 30 * 1000), WAIT_OBJECT_0);
VERIFY_ARE_EQUAL(WaitForSingleObject(process.get(), 30000 * 1000), WAIT_OBJECT_0);
}
TEST_METHOD(NATNetworking)