From 1a6cbe560bf1903439f990363d2f69d9758cb7d8 Mon Sep 17 00:00:00 2001 From: Blue Date: Fri, 31 Oct 2025 17:18:34 -0700 Subject: [PATCH 1/3] Fix various issues with the interactive shell (#13634) * Save state * Save state * Save state * Save state * Save state * Save state * Save state * Save state * Save state * Cleanup for review * Update ServiceMain.cpp comment * Remove duplicated definitions from wslservice.idl * poc: Prototype interactive shell improvments * Format * Merge * Save state * Correctly configure terminal * Format * PR feedback --- src/linux/init/WSLAInit.cpp | 30 +- src/shared/inc/CommandLine.h | 12 +- src/shared/inc/lxinitshared.h | 20 +- src/windows/common/WslClient.cpp | 6 +- src/windows/common/relay.cpp | 354 ++++++++++++++++ src/windows/common/relay.hpp | 2 + src/windows/common/svccomm.cpp | 398 ++---------------- src/windows/common/svccomm.hpp | 2 - src/windows/inc/wslrelay.h | 1 + src/windows/wslaclient/DllMain.cpp | 13 +- src/windows/wslaclient/WSLAApi.h | 3 +- .../wslaservice/exe/WSLAVirtualMachine.cpp | 20 +- .../wslaservice/exe/WSLAVirtualMachine.h | 3 +- src/windows/wslrelay/main.cpp | 63 ++- test/windows/WSLATests.cpp | 13 +- 15 files changed, 539 insertions(+), 401 deletions(-) diff --git a/src/linux/init/WSLAInit.cpp b/src/linux/init/WSLAInit.cpp index 11a2c48..f839ae2 100644 --- a/src/linux/init/WSLAInit.cpp +++ b/src/linux/init/WSLAInit.cpp @@ -159,12 +159,16 @@ void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_TTY_RELAY { THROW_LAST_ERROR_IF(fcntl(Message.TtyMaster, F_SETFL, O_NONBLOCK) < 0); - pollfd pollDescriptors[2]; + wsl::shared::SocketChannel TerminalControlChannel({Message.TtyControl}, "TerminalControl"); + + pollfd pollDescriptors[3]; pollDescriptors[0].fd = Message.TtyInput; pollDescriptors[0].events = POLLIN; pollDescriptors[1].fd = Message.TtyMaster; pollDescriptors[1].events = POLLIN; + pollDescriptors[2].fd = Message.TtyControl; + pollDescriptors[2].events = POLLIN; std::vector pendingStdin; std::vector buffer; @@ -269,6 +273,30 @@ void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_TTY_RELAY pollDescriptors[1].fd = -1; } } + + // Process message from the terminal control channel. + if (pollDescriptors[2].revents & (POLLIN | POLLHUP | POLLERR)) + { + auto [ttyMessage, _] = TerminalControlChannel.ReceiveMessageOrClosed(); + + // + // A zero-byte read means that the control channel has been closed + // and that the relay process should exit. + // + + if (ttyMessage == nullptr) + { + break; + } + + winsize terminal{}; + terminal.ws_col = ttyMessage->Columns; + terminal.ws_row = ttyMessage->Rows; + if (ioctl(Message.TtyMaster, TIOCSWINSZ, &terminal)) + { + LOG_ERROR("ioctl({}, TIOCSWINSZ) failed {}", Message.TtyMaster, errno); + } + } } // Shutdown sockets and tty diff --git a/src/shared/inc/CommandLine.h b/src/shared/inc/CommandLine.h index 791d561..dc98122 100644 --- a/src/shared/inc/CommandLine.h +++ b/src/shared/inc/CommandLine.h @@ -162,9 +162,10 @@ struct AbsolutePath } }; +template struct Handle { - wil::unique_handle& output; + THandle& output; int operator()(const TChar* input) const { @@ -173,7 +174,14 @@ struct Handle return -1; } - output.reset(ULongToHandle(wcstoul(input, nullptr, 0))); + if constexpr (std::is_same_v) + { + output.reset(reinterpret_cast(ULongToHandle(wcstoul(input, nullptr, 0)))); + } + else + { + output.reset(ULongToHandle(wcstoul(input, nullptr, 0))); + } return 1; } diff --git a/src/shared/inc/lxinitshared.h b/src/shared/inc/lxinitshared.h index 64f3b8a..624b893 100644 --- a/src/shared/inc/lxinitshared.h +++ b/src/shared/inc/lxinitshared.h @@ -390,6 +390,7 @@ typedef enum _LX_MESSAGE_TYPE LxMessageWSLAOpen, LxMessageWSLAUnmount, LxMessageWSLADetach, + LxMessageWSLATerminalChanged, } LX_MESSAGE_TYPE, *PLX_MESSAGE_TYPE; @@ -498,6 +499,7 @@ inline auto ToString(LX_MESSAGE_TYPE messageType) X(LxMessageWSLAOpen) X(LxMessageWSLAUnmount) X(LxMessageWSLADetach) + X(LxMessageWSLATerminalChanged) default: return ""; } @@ -1658,8 +1660,11 @@ struct WSLA_TTY_RELAY int32_t TtyMaster; int32_t TtyInput; int32_t TtyOutput; + int32_t TtyControl; + uint32_t Rows; + uint32_t Columns; - PRETTY_PRINT(FIELD(Header), FIELD(TtyMaster), FIELD(TtyInput), FIELD(TtyOutput)); + PRETTY_PRINT(FIELD(Header), FIELD(TtyMaster), FIELD(TtyInput), FIELD(TtyOutput), FIELD(TtyControl), FIELD(Rows), FIELD(Columns)); }; struct WSLA_ACCEPT @@ -1830,6 +1835,19 @@ struct WSLA_DETACH PRETTY_PRINT(FIELD(Header), FIELD(Lun)); }; +struct WSLA_TERMINAL_CHANGED +{ + DECLARE_MESSAGE_CTOR(WSLA_TERMINAL_CHANGED); + + static inline auto Type = LxMessageWSLATerminalChanged; + + MESSAGE_HEADER Header; + unsigned short Rows; + unsigned short Columns; + + PRETTY_PRINT(FIELD(Header), FIELD(Rows), FIELD(Columns)); +}; + typedef struct _LX_MINI_INIT_IMPORT_RESULT { static inline auto Type = LxMiniInitMessageImportResult; diff --git a/src/windows/common/WslClient.cpp b/src/windows/common/WslClient.cpp index 04d167a..26cfdfc 100644 --- a/src/windows/common/WslClient.cpp +++ b/src/windows/common/WslClient.cpp @@ -1521,10 +1521,10 @@ int RunDebugShell() THROW_IF_WIN32_BOOL_FALSE(WriteFile(pipe.get(), "\n", 1, nullptr, nullptr)); // Create a thread to relay stdin to the pipe. - wsl::windows::common::SvcCommIo Io; auto exitEvent = wil::unique_event(wil::EventOptions::ManualReset); - std::thread inputThread( - [&]() { wsl::windows::common::RelayStandardInput(GetStdHandle(STD_INPUT_HANDLE), pipe.get(), {}, exitEvent.get(), &Io); }); + std::thread inputThread([&]() { + wsl::windows::common::relay::StandardInputRelay(GetStdHandle(STD_INPUT_HANDLE), pipe.get(), []() {}, exitEvent.get()); + }); auto joinThread = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { exitEvent.SetEvent(); diff --git a/src/windows/common/relay.cpp b/src/windows/common/relay.cpp index 8b0f92e..9628848 100644 --- a/src/windows/common/relay.cpp +++ b/src/windows/common/relay.cpp @@ -377,6 +377,360 @@ void wsl::windows::common::relay::BidirectionalRelay(_In_ HANDLE LeftHandle, _In } } +#define TTY_ALT_NUMPAD_VK_MENU (0x12) +#define TTY_ESCAPE_CHARACTER (L'\x1b') +#define TTY_INPUT_EVENT_BUFFER_SIZE (16) +#define TTY_UTF8_TRANSLATION_BUFFER_SIZE (4 * TTY_INPUT_EVENT_BUFFER_SIZE) + +BOOL IsActionableKey(_In_ PKEY_EVENT_RECORD KeyEvent) +{ + // + // This is a bit complicated to discern. + // + // 1. Our first check is that we only want structures that + // represent at least one key press. If we have 0, then we don't + // need to bother. If we have >1, we'll send the key through + // that many times into the pipe. + // 2. Our second check is where it gets confusing. + // a. Characters that are non-null get an automatic pass. Copy + // them through to the pipe. + // b. Null characters need further scrutiny. We generally do not + // pass nulls through EXCEPT if they're sourced from the + // virtual terminal engine (or another application living + // above our layer). If they're sourced by a non-keyboard + // source, they'll have no scan code (since they didn't come + // from a keyboard). But that rule has an exception too: + // "Enhanced keys" from above the standard range of scan + // codes will return 0 also with a special flag set that says + // they're an enhanced key. That means the desired behavior + // is: + // Scan Code = 0, ENHANCED_KEY = 0 + // -> This came from the VT engine or another app + // above our layer. + // Scan Code = 0, ENHANCED_KEY = 1 + // -> This came from the keyboard, but is a special + // key like 'Volume Up' that wasn't generally a + // part of historic (pre-1990s) keyboards. + // Scan Code = + // -> This came from a keyboard directly. + // + + if ((KeyEvent->wRepeatCount == 0) || ((KeyEvent->uChar.UnicodeChar == UNICODE_NULL) && + ((KeyEvent->wVirtualScanCode != 0) || (WI_IsFlagSet(KeyEvent->dwControlKeyState, ENHANCED_KEY))))) + { + return FALSE; + } + + return TRUE; +} + +BOOL GetNextCharacter(_In_ INPUT_RECORD* InputRecord, _Out_ PWCHAR NextCharacter) +{ + BOOL IsNextCharacterValid = FALSE; + if (InputRecord->EventType == KEY_EVENT) + { + const auto KeyEvent = &InputRecord->Event.KeyEvent; + if ((IsActionableKey(KeyEvent) != FALSE) && ((KeyEvent->bKeyDown != FALSE) || (KeyEvent->wVirtualKeyCode == TTY_ALT_NUMPAD_VK_MENU))) + { + *NextCharacter = KeyEvent->uChar.UnicodeChar; + IsNextCharacterValid = TRUE; + } + } + + return IsNextCharacterValid; +} + +void wsl::windows::common::relay::StandardInputRelay(HANDLE ConsoleHandle, HANDLE OutputHandle, const std::function& UpdateTerminalSize, HANDLE ExitEvent) +try +{ + if (GetFileType(ConsoleHandle) != FILE_TYPE_CHAR) + { + wsl::windows::common::relay::InterruptableRelay(ConsoleHandle, OutputHandle, ExitEvent); + return; + } + + // + // N.B. ReadConsoleInputEx has no associated import library. + // + + static LxssDynamicFunction readConsoleInput(L"Kernel32.dll", "ReadConsoleInputExW"); + + INPUT_RECORD InputRecordBuffer[TTY_INPUT_EVENT_BUFFER_SIZE]; + INPUT_RECORD* InputRecordPeek = &(InputRecordBuffer[1]); + KEY_EVENT_RECORD* KeyEvent; + DWORD RecordsRead; + OVERLAPPED Overlapped = {0}; + const wil::unique_event OverlappedEvent(wil::EventOptions::ManualReset); + Overlapped.hEvent = OverlappedEvent.get(); + const HANDLE WaitHandles[] = {ExitEvent, ConsoleHandle}; + const std::vector ExitHandles = {ExitEvent}; + for (;;) + { + // + // Because some input events generated by the console are encoded with + // more than one input event, we have to be smart about reading the + // events. + // + // First, we peek at the next input event. + // If it's an escape (wch == L'\x1b') event, then the characters that + // follow are part of an input sequence. We can't know for sure + // how long that sequence is, but we can assume it's all sent to + // the input queue at once, and it's less that 16 events. + // Furthermore, we can assume that if there's an Escape in those + // 16 events, that the escape marks the start of a new sequence. + // So, we'll peek at another 15 events looking for escapes. + // If we see an escape, then we'll read one less than that, + // such that the escape remains the next event in the input. + // From those read events, we'll aggregate chars into a single + // string to send to the subsystem. + // If it's not an escape, send the event through one at a time. + // + + // + // Read one input event. + // + + DWORD WaitStatus = (WAIT_OBJECT_0 + 1); + do + { + THROW_IF_WIN32_BOOL_FALSE(readConsoleInput(ConsoleHandle, InputRecordBuffer, 1, &RecordsRead, CONSOLE_READ_NOWAIT)); + + if (RecordsRead == 0) + { + WaitStatus = WaitForMultipleObjects(RTL_NUMBER_OF(WaitHandles), WaitHandles, false, INFINITE); + } + } while ((WaitStatus == (WAIT_OBJECT_0 + 1)) && (RecordsRead == 0)); + + // + // Stop processing if the exit event has been signaled. + // + + if (WaitStatus != (WAIT_OBJECT_0 + 1)) + { + WI_ASSERT(WaitStatus == WAIT_OBJECT_0); + + break; + } + + WI_ASSERT(RecordsRead == 1); + + // + // Don't read additional records if the first entry is a window size + // event, or a repeated character. Handle those events on their own. + // + + DWORD RecordsPeeked = 0; + if ((InputRecordBuffer[0].EventType != WINDOW_BUFFER_SIZE_EVENT) && + ((InputRecordBuffer[0].EventType != KEY_EVENT) || (InputRecordBuffer[0].Event.KeyEvent.wRepeatCount < 2))) + { + // + // Read additional input records into the buffer if available. + // + + THROW_IF_WIN32_BOOL_FALSE(PeekConsoleInputW(ConsoleHandle, InputRecordPeek, (RTL_NUMBER_OF(InputRecordBuffer) - 1), &RecordsPeeked)); + } + + // + // Iterate over peeked records [1, RecordsPeeked]. + // + + DWORD AdditionalRecordsToRead = 0; + WCHAR NextCharacter; + for (DWORD RecordIndex = 1; RecordIndex <= RecordsPeeked; RecordIndex++) + { + if (GetNextCharacter(&InputRecordBuffer[RecordIndex], &NextCharacter) != FALSE) + { + KeyEvent = &InputRecordBuffer[RecordIndex].Event.KeyEvent; + if (NextCharacter == TTY_ESCAPE_CHARACTER) + { + // + // CurrentRecord is an escape event. We will start here + // on the next input loop. + // + + break; + } + else if (KeyEvent->wRepeatCount > 1) + { + // + // Repeated keys are handled on their own. Start with this + // key on the next input loop. + // + + break; + } + else if (IS_HIGH_SURROGATE(NextCharacter) && (RecordIndex >= (RecordsPeeked - 1))) + { + // + // If there is not enough room for the second character of + // a surrogate pair, start with this character on the next + // input loop. + // + // N.B. The test is for at least two remaining records + // because typically a surrogate pair will be entered + // via copy/paste, which will appear as an input + // record with alt-down, alt-up and character. So to + // include the next character of the surrogate pair it + // is likely that the alt-up record will need to be + // read first. + // + + break; + } + } + else if (InputRecordBuffer[RecordIndex].EventType == WINDOW_BUFFER_SIZE_EVENT) + { + // + // A window size event is handled on its own. + // + + break; + } + + // + // Process the additional input record. + // + + AdditionalRecordsToRead += 1; + } + + if (AdditionalRecordsToRead > 0) + { + THROW_IF_WIN32_BOOL_FALSE(readConsoleInput(ConsoleHandle, InputRecordPeek, AdditionalRecordsToRead, &RecordsRead, CONSOLE_READ_NOWAIT)); + + if (RecordsRead == 0) + { + // + // This would be an unexpected case. We've already peeked to see + // that there are AdditionalRecordsToRead # of records in the + // input that need reading, yet we didn't get them when we read. + // In this case, move along and finish this input event. + // + + break; + } + + // + // We already had one input record in the buffer before reading + // additional, So account for that one too + // + + RecordsRead += 1; + } + + // + // Process each input event. Keydowns will get aggregated into + // Utf8String before getting injected into the subsystem. + // + + WCHAR Utf16String[TTY_INPUT_EVENT_BUFFER_SIZE]; + ULONG Utf16StringSize = 0; + COORD WindowSize{}; + for (DWORD RecordIndex = 0; RecordIndex < RecordsRead; RecordIndex++) + { + INPUT_RECORD* CurrentInputRecord = &(InputRecordBuffer[RecordIndex]); + switch (CurrentInputRecord->EventType) + { + case KEY_EVENT: + + // + // Filter out key up events unless they are from an key. + // Key up with an key could contain a Unicode character + // pasted from the clipboard and converted to an + sequence. + // + + KeyEvent = &CurrentInputRecord->Event.KeyEvent; + if ((KeyEvent->bKeyDown == FALSE) && (KeyEvent->wVirtualKeyCode != TTY_ALT_NUMPAD_VK_MENU)) + { + break; + } + + // + // Filter out key presses that are not actionable, such as just + // pressing , , etc. These key presses return + // the character of null but will have a valid scan code off the + // keyboard. Certain other key sequences such as Ctrl+A, + // Ctrl+, and Ctrl+@ will also return the character null + // but have no scan code. + // + sequences will show an but will have + // a scancode and character specified, so they should be actionable. + // + + if (IsActionableKey(KeyEvent) == FALSE) + { + break; + } + + Utf16String[Utf16StringSize] = KeyEvent->uChar.UnicodeChar; + Utf16StringSize += 1; + break; + + case WINDOW_BUFFER_SIZE_EVENT: + + // + // Query the window size and send an update message via the + // control channel. + // + + UpdateTerminalSize(); + break; + } + } + + CHAR Utf8String[TTY_UTF8_TRANSLATION_BUFFER_SIZE]; + DWORD Utf8StringSize = 0; + if (Utf16StringSize > 0) + { + // + // Windows uses UTF-16LE encoding, Linux uses UTF-8 by default. + // Convert each UTF-16LE character into the proper UTF-8 byte + // sequence equivalent. + // + + THROW_LAST_ERROR_IF( + (Utf8StringSize = WideCharToMultiByte( + CP_UTF8, 0, Utf16String, Utf16StringSize, Utf8String, sizeof(Utf8String), nullptr, nullptr)) == 0); + } + + // + // Send the input bytes to the terminal. + // + + DWORD BytesWritten = 0; + const auto Utf8Span = gslhelpers::struct_as_bytes(Utf8String).first(Utf8StringSize); + if ((RecordsRead == 1) && (InputRecordBuffer[0].EventType == KEY_EVENT) && (InputRecordBuffer[0].Event.KeyEvent.wRepeatCount > 1)) + { + WI_ASSERT(Utf16StringSize == 1); + + // + // Handle repeated characters. They aren't part of an input + // sequence, so there's only one event that's generating characters. + // + + WORD RepeatIndex; + for (RepeatIndex = 0; RepeatIndex < InputRecordBuffer[0].Event.KeyEvent.wRepeatCount; RepeatIndex += 1) + { + BytesWritten = wsl::windows::common::relay::InterruptableWrite(OutputHandle, Utf8Span, ExitHandles, &Overlapped); + if (BytesWritten == 0) + { + break; + } + } + } + else if (Utf8StringSize > 0) + { + BytesWritten = wsl::windows::common::relay::InterruptableWrite(OutputHandle, Utf8Span, ExitHandles, &Overlapped); + if (BytesWritten == 0) + { + break; + } + } + } + + return; +} +CATCH_LOG() + void wsl::windows::common::relay::SocketRelay(_In_ SOCKET LeftSocket, _In_ SOCKET RightSocket, _In_ size_t BufferSize) { constexpr RelayFlags flags = RelayFlags::LeftIsSocket | RelayFlags::RightIsSocket; diff --git a/src/windows/common/relay.hpp b/src/windows/common/relay.hpp index b546fa7..d995c63 100644 --- a/src/windows/common/relay.hpp +++ b/src/windows/common/relay.hpp @@ -43,6 +43,8 @@ bool InterruptableWait(_In_ HANDLE WaitObject, _In_ const std::vector& E DWORD InterruptableWrite(_In_ HANDLE OutputHandle, _In_ gsl::span Buffer, _In_ const std::vector& ExitHandles, _In_ LPOVERLAPPED Overlapped); +void StandardInputRelay(HANDLE ConsoleHandle, HANDLE OutputHandle, const std::function& UpdateTerminalSize, HANDLE ExitEvent); + enum class RelayFlags { None = 0, diff --git a/src/windows/common/svccomm.cpp b/src/windows/common/svccomm.cpp index d2f1c09..9190310 100644 --- a/src/windows/common/svccomm.cpp +++ b/src/windows/common/svccomm.cpp @@ -29,11 +29,6 @@ Abstract: #define IS_VALID_HANDLE(_handle) ((_handle != NULL) && (_handle != INVALID_HANDLE_VALUE)) -#define TTY_ALT_NUMPAD_VK_MENU (0x12) -#define TTY_ESCAPE_CHARACTER (L'\x1b') -#define TTY_INPUT_EVENT_BUFFER_SIZE (16) -#define TTY_UTF8_TRANSLATION_BUFFER_SIZE (4 * TTY_INPUT_EVENT_BUFFER_SIZE) - using wsl::windows::common::ClientExecutionContext; namespace { @@ -112,64 +107,6 @@ struct CreateProcessArguments std::wstring NtPath{}; }; -BOOL GetNextCharacter(_In_ INPUT_RECORD* InputRecord, _Out_ PWCHAR NextCharacter) -{ - BOOL IsNextCharacterValid = FALSE; - if (InputRecord->EventType == KEY_EVENT) - { - const auto KeyEvent = &InputRecord->Event.KeyEvent; - if ((IsActionableKey(KeyEvent) != FALSE) && ((KeyEvent->bKeyDown != FALSE) || (KeyEvent->wVirtualKeyCode == TTY_ALT_NUMPAD_VK_MENU))) - { - *NextCharacter = KeyEvent->uChar.UnicodeChar; - IsNextCharacterValid = TRUE; - } - } - - return IsNextCharacterValid; -} - -BOOL IsActionableKey(_In_ PKEY_EVENT_RECORD KeyEvent) -{ - // - // This is a bit complicated to discern. - // - // 1. Our first check is that we only want structures that - // represent at least one key press. If we have 0, then we don't - // need to bother. If we have >1, we'll send the key through - // that many times into the pipe. - // 2. Our second check is where it gets confusing. - // a. Characters that are non-null get an automatic pass. Copy - // them through to the pipe. - // b. Null characters need further scrutiny. We generally do not - // pass nulls through EXCEPT if they're sourced from the - // virtual terminal engine (or another application living - // above our layer). If they're sourced by a non-keyboard - // source, they'll have no scan code (since they didn't come - // from a keyboard). But that rule has an exception too: - // "Enhanced keys" from above the standard range of scan - // codes will return 0 also with a special flag set that says - // they're an enhanced key. That means the desired behavior - // is: - // Scan Code = 0, ENHANCED_KEY = 0 - // -> This came from the VT engine or another app - // above our layer. - // Scan Code = 0, ENHANCED_KEY = 1 - // -> This came from the keyboard, but is a special - // key like 'Volume Up' that wasn't generally a - // part of historic (pre-1990s) keyboards. - // Scan Code = - // -> This came from a keyboard directly. - // - - if ((KeyEvent->wRepeatCount == 0) || ((KeyEvent->uChar.UnicodeChar == UNICODE_NULL) && - ((KeyEvent->wVirtualScanCode != 0) || (WI_IsFlagSet(KeyEvent->dwControlKeyState, ENHANCED_KEY))))) - { - return FALSE; - } - - return TRUE; -} - void InitializeInterop(_In_ HANDLE ServerPort, _In_ const GUID& DistroId) { // @@ -211,316 +148,6 @@ void SpawnWslHost(_In_ HANDLE ServerPort, _In_ const GUID& DistroId, _In_opt_ LP // Exported function definitions. // -void wsl::windows::common::RelayStandardInput( - HANDLE ConsoleHandle, - HANDLE OutputHandle, - const std::shared_ptr& ControlChannel, - HANDLE ExitEvent, - wsl::windows::common::SvcCommIo* Io) -try -{ - if (GetFileType(ConsoleHandle) != FILE_TYPE_CHAR) - { - wsl::windows::common::relay::InterruptableRelay(ConsoleHandle, OutputHandle, ExitEvent); - return; - } - - // - // N.B. ReadConsoleInputEx has no associated import library. - // - - static LxssDynamicFunction readConsoleInput(L"Kernel32.dll", "ReadConsoleInputExW"); - - INPUT_RECORD InputRecordBuffer[TTY_INPUT_EVENT_BUFFER_SIZE]; - INPUT_RECORD* InputRecordPeek = &(InputRecordBuffer[1]); - KEY_EVENT_RECORD* KeyEvent; - DWORD RecordsRead; - OVERLAPPED Overlapped = {0}; - const wil::unique_event OverlappedEvent(wil::EventOptions::ManualReset); - Overlapped.hEvent = OverlappedEvent.get(); - const HANDLE WaitHandles[] = {ExitEvent, ConsoleHandle}; - const std::vector ExitHandles = {ExitEvent}; - for (;;) - { - // - // Because some input events generated by the console are encoded with - // more than one input event, we have to be smart about reading the - // events. - // - // First, we peek at the next input event. - // If it's an escape (wch == L'\x1b') event, then the characters that - // follow are part of an input sequence. We can't know for sure - // how long that sequence is, but we can assume it's all sent to - // the input queue at once, and it's less that 16 events. - // Furthermore, we can assume that if there's an Escape in those - // 16 events, that the escape marks the start of a new sequence. - // So, we'll peek at another 15 events looking for escapes. - // If we see an escape, then we'll read one less than that, - // such that the escape remains the next event in the input. - // From those read events, we'll aggregate chars into a single - // string to send to the subsystem. - // If it's not an escape, send the event through one at a time. - // - - // - // Read one input event. - // - - DWORD WaitStatus = (WAIT_OBJECT_0 + 1); - do - { - THROW_IF_WIN32_BOOL_FALSE(readConsoleInput(ConsoleHandle, InputRecordBuffer, 1, &RecordsRead, CONSOLE_READ_NOWAIT)); - - if (RecordsRead == 0) - { - WaitStatus = WaitForMultipleObjects(RTL_NUMBER_OF(WaitHandles), WaitHandles, false, INFINITE); - } - } while ((WaitStatus == (WAIT_OBJECT_0 + 1)) && (RecordsRead == 0)); - - // - // Stop processing if the exit event has been signaled. - // - - if (WaitStatus != (WAIT_OBJECT_0 + 1)) - { - WI_ASSERT(WaitStatus == WAIT_OBJECT_0); - - break; - } - - WI_ASSERT(RecordsRead == 1); - - // - // Don't read additional records if the first entry is a window size - // event, or a repeated character. Handle those events on their own. - // - - DWORD RecordsPeeked = 0; - if ((InputRecordBuffer[0].EventType != WINDOW_BUFFER_SIZE_EVENT) && - ((InputRecordBuffer[0].EventType != KEY_EVENT) || (InputRecordBuffer[0].Event.KeyEvent.wRepeatCount < 2))) - { - // - // Read additional input records into the buffer if available. - // - - THROW_IF_WIN32_BOOL_FALSE(PeekConsoleInputW(ConsoleHandle, InputRecordPeek, (RTL_NUMBER_OF(InputRecordBuffer) - 1), &RecordsPeeked)); - } - - // - // Iterate over peeked records [1, RecordsPeeked]. - // - - DWORD AdditionalRecordsToRead = 0; - WCHAR NextCharacter; - for (DWORD RecordIndex = 1; RecordIndex <= RecordsPeeked; RecordIndex++) - { - if (GetNextCharacter(&InputRecordBuffer[RecordIndex], &NextCharacter) != FALSE) - { - KeyEvent = &InputRecordBuffer[RecordIndex].Event.KeyEvent; - if (NextCharacter == TTY_ESCAPE_CHARACTER) - { - // - // CurrentRecord is an escape event. We will start here - // on the next input loop. - // - - break; - } - else if (KeyEvent->wRepeatCount > 1) - { - // - // Repeated keys are handled on their own. Start with this - // key on the next input loop. - // - - break; - } - else if (IS_HIGH_SURROGATE(NextCharacter) && (RecordIndex >= (RecordsPeeked - 1))) - { - // - // If there is not enough room for the second character of - // a surrogate pair, start with this character on the next - // input loop. - // - // N.B. The test is for at least two remaining records - // because typically a surrogate pair will be entered - // via copy/paste, which will appear as an input - // record with alt-down, alt-up and character. So to - // include the next character of the surrogate pair it - // is likely that the alt-up record will need to be - // read first. - // - - break; - } - } - else if (InputRecordBuffer[RecordIndex].EventType == WINDOW_BUFFER_SIZE_EVENT) - { - // - // A window size event is handled on its own. - // - - break; - } - - // - // Process the additional input record. - // - - AdditionalRecordsToRead += 1; - } - - if (AdditionalRecordsToRead > 0) - { - THROW_IF_WIN32_BOOL_FALSE(readConsoleInput(ConsoleHandle, InputRecordPeek, AdditionalRecordsToRead, &RecordsRead, CONSOLE_READ_NOWAIT)); - - if (RecordsRead == 0) - { - // - // This would be an unexpected case. We've already peeked to see - // that there are AdditionalRecordsToRead # of records in the - // input that need reading, yet we didn't get them when we read. - // In this case, move along and finish this input event. - // - - break; - } - - // - // We already had one input record in the buffer before reading - // additional, So account for that one too - // - - RecordsRead += 1; - } - - // - // Process each input event. Keydowns will get aggregated into - // Utf8String before getting injected into the subsystem. - // - - WCHAR Utf16String[TTY_INPUT_EVENT_BUFFER_SIZE]; - ULONG Utf16StringSize = 0; - COORD WindowSize{}; - LX_INIT_WINDOW_SIZE_CHANGED WindowSizeMessage{}; - for (DWORD RecordIndex = 0; RecordIndex < RecordsRead; RecordIndex++) - { - INPUT_RECORD* CurrentInputRecord = &(InputRecordBuffer[RecordIndex]); - switch (CurrentInputRecord->EventType) - { - case KEY_EVENT: - - // - // Filter out key up events unless they are from an key. - // Key up with an key could contain a Unicode character - // pasted from the clipboard and converted to an + sequence. - // - - KeyEvent = &CurrentInputRecord->Event.KeyEvent; - if ((KeyEvent->bKeyDown == FALSE) && (KeyEvent->wVirtualKeyCode != TTY_ALT_NUMPAD_VK_MENU)) - { - break; - } - - // - // Filter out key presses that are not actionable, such as just - // pressing , , etc. These key presses return - // the character of null but will have a valid scan code off the - // keyboard. Certain other key sequences such as Ctrl+A, - // Ctrl+, and Ctrl+@ will also return the character null - // but have no scan code. - // + sequences will show an but will have - // a scancode and character specified, so they should be actionable. - // - - if (IsActionableKey(KeyEvent) == FALSE) - { - break; - } - - Utf16String[Utf16StringSize] = KeyEvent->uChar.UnicodeChar; - Utf16StringSize += 1; - break; - - case WINDOW_BUFFER_SIZE_EVENT: - - // - // Query the window size and send an update message via the - // control channel. - // - if (ControlChannel) - { - WindowSize = Io->GetWindowSize(); - WindowSizeMessage.Header.MessageType = LxInitMessageWindowSizeChanged; - WindowSizeMessage.Header.MessageSize = sizeof(WindowSizeMessage); - WindowSizeMessage.Columns = WindowSize.X; - WindowSizeMessage.Rows = WindowSize.Y; - - try - { - ControlChannel->SendMessage(WindowSizeMessage); - } - CATCH_LOG(); - } - - break; - } - } - - CHAR Utf8String[TTY_UTF8_TRANSLATION_BUFFER_SIZE]; - DWORD Utf8StringSize = 0; - if (Utf16StringSize > 0) - { - // - // Windows uses UTF-16LE encoding, Linux uses UTF-8 by default. - // Convert each UTF-16LE character into the proper UTF-8 byte - // sequence equivalent. - // - - THROW_LAST_ERROR_IF( - (Utf8StringSize = WideCharToMultiByte( - CP_UTF8, 0, Utf16String, Utf16StringSize, Utf8String, sizeof(Utf8String), nullptr, nullptr)) == 0); - } - - // - // Send the input bytes to the terminal. - // - - DWORD BytesWritten = 0; - const auto Utf8Span = gslhelpers::struct_as_bytes(Utf8String).first(Utf8StringSize); - if ((RecordsRead == 1) && (InputRecordBuffer[0].EventType == KEY_EVENT) && (InputRecordBuffer[0].Event.KeyEvent.wRepeatCount > 1)) - { - WI_ASSERT(Utf16StringSize == 1); - - // - // Handle repeated characters. They aren't part of an input - // sequence, so there's only one event that's generating characters. - // - - WORD RepeatIndex; - for (RepeatIndex = 0; RepeatIndex < InputRecordBuffer[0].Event.KeyEvent.wRepeatCount; RepeatIndex += 1) - { - BytesWritten = wsl::windows::common::relay::InterruptableWrite(OutputHandle, Utf8Span, ExitHandles, &Overlapped); - if (BytesWritten == 0) - { - break; - } - } - } - else if (Utf8StringSize > 0) - { - BytesWritten = wsl::windows::common::relay::InterruptableWrite(OutputHandle, Utf8Span, ExitHandles, &Overlapped); - if (BytesWritten == 0) - { - break; - } - } - } - - return; -} -CATCH_LOG() - wsl::windows::common::SvcComm::SvcComm() { // Ensure that the OS has support for running lifted WSL. This interface is always present on Windows 11 and later. @@ -808,7 +435,30 @@ wsl::windows::common::SvcComm::LaunchProcess( if (IS_VALID_HANDLE(StdIn)) { std::thread([StdIn, StdInSocket = std::move(StdInSocket), ControlChannel = ControlChannel, ExitHandle = ExitEvent.get(), Io = &Io]() mutable { - RelayStandardInput(StdIn, StdInSocket.get(), ControlChannel, ExitHandle, Io); + auto updateTerminal = [&]() { + // + // Query the window size and send an update message via the + // control channel. + // + if (ControlChannel) + { + auto WindowSize = Io->GetWindowSize(); + + LX_INIT_WINDOW_SIZE_CHANGED WindowSizeMessage{}; + WindowSizeMessage.Header.MessageType = LxInitMessageWindowSizeChanged; + WindowSizeMessage.Header.MessageSize = sizeof(WindowSizeMessage); + WindowSizeMessage.Columns = WindowSize.X; + WindowSizeMessage.Rows = WindowSize.Y; + + try + { + ControlChannel->SendMessage(WindowSizeMessage); + } + CATCH_LOG(); + } + }; + + wsl::windows::common::relay::StandardInputRelay(StdIn, StdInSocket.get(), updateTerminal, ExitHandle); }).detach(); } diff --git a/src/windows/common/svccomm.hpp b/src/windows/common/svccomm.hpp index 35c6018..983134c 100644 --- a/src/windows/common/svccomm.hpp +++ b/src/windows/common/svccomm.hpp @@ -21,8 +21,6 @@ Abstract: namespace wsl::windows::common { -void RelayStandardInput(HANDLE ConsoleHandle, HANDLE OutputHandle, const std::shared_ptr& ControlChannel, HANDLE ExitEvent, SvcCommIo* Io); - class SvcComm { public: diff --git a/src/windows/inc/wslrelay.h b/src/windows/inc/wslrelay.h index 9638a23..557d31c 100644 --- a/src/windows/inc/wslrelay.h +++ b/src/windows/inc/wslrelay.h @@ -36,4 +36,5 @@ LPCWSTR const port_option = L"--port"; LPCWSTR const disable_telemetry_option = L"--disable-telemetry"; LPCWSTR const input_option = L"--input"; LPCWSTR const output_option = L"--output"; +LPCWSTR const control_option = L"--control"; } // namespace wslrelay \ No newline at end of file diff --git a/src/windows/wslaclient/DllMain.cpp b/src/windows/wslaclient/DllMain.cpp index 012246b..ac193f7 100644 --- a/src/windows/wslaclient/DllMain.cpp +++ b/src/windows/wslaclient/DllMain.cpp @@ -209,27 +209,34 @@ HRESULT WslUnmapPort(WslVirtualMachineHandle VirtualMachine, const WslPortMappin ->MapPort(UserSettings->AddressFamily, UserSettings->WindowsPort, UserSettings->LinuxPort, true); } -HRESULT WslLaunchInteractiveTerminal(HANDLE Input, HANDLE Output, HANDLE* Process) +HRESULT WslLaunchInteractiveTerminal(HANDLE Input, HANDLE Output, HANDLE Control, HANDLE* Process) try { wsl::windows::common::helpers::SetHandleInheritable(Input); wsl::windows::common::helpers::SetHandleInheritable(Output); + wsl::windows::common::helpers::SetHandleInheritable(Control); auto basePath = wsl::windows::common::wslutil::GetMsiPackagePath(); THROW_HR_IF(E_UNEXPECTED, !basePath.has_value()); auto commandLine = std::format( - L"{}/wslrelay.exe --mode {} --input {} --output {}", + L"{}/wslrelay.exe {} {} {} {} {} {} {} {}", basePath.value(), + wslrelay::mode_option, static_cast(wslrelay::RelayMode::InteractiveConsoleRelay), + wslrelay::input_option, HandleToULong(Input), - HandleToULong(Output)); + wslrelay::output_option, + HandleToULong(Output), + wslrelay::control_option, + HandleToUlong(Control)); WSL_LOG("LaunchWslRelay", TraceLoggingValue(commandLine.c_str(), "cmd")); wsl::windows::common::SubProcess process{nullptr, commandLine.c_str()}; process.InheritHandle(Input); process.InheritHandle(Output); + process.InheritHandle(Control); process.SetFlags(CREATE_NEW_CONSOLE); process.SetShowWindow(SW_SHOW); *Process = process.Start().release(); diff --git a/src/windows/wslaclient/WSLAApi.h b/src/windows/wslaclient/WSLAApi.h index b4880ab..1352cb0 100644 --- a/src/windows/wslaclient/WSLAApi.h +++ b/src/windows/wslaclient/WSLAApi.h @@ -131,6 +131,7 @@ enum WslFdType WslFdTypeLinuxFileOutput = 8, WslFdTypeLinuxFileAppend = 16, WslFdTypeLinuxFileCreate = 32, + WslFdTypeTerminalControl = 64, }; struct WslProcessFileDescriptorSettings @@ -174,7 +175,7 @@ struct WslPortMappingSettings int AddressFamily; }; -HRESULT WslLaunchInteractiveTerminal(HANDLE Input, HANDLE Output, HANDLE* Process); +HRESULT WslLaunchInteractiveTerminal(HANDLE Input, HANDLE Output, HANDLE TerminalControl, HANDLE* Process); HRESULT WslWaitForLinuxProcess(WslVirtualMachineHandle VirtualMachine, int32_t Pid, uint64_t TimeoutMs, struct WslWaitResult* Result); diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp index 593e448..2800e2d 100644 --- a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp @@ -638,15 +638,16 @@ std::vector WSLAVirtualMachine::CreateLinuxProcessImpl( // Check if this is a tty or not const WSLA_PROCESS_FD* ttyInput = nullptr; const WSLA_PROCESS_FD* ttyOutput = nullptr; - auto interactiveTty = ParseTtyInformation(Fds, FdCount, &ttyInput, &ttyOutput); + const WSLA_PROCESS_FD* ttyControl = nullptr; + auto interactiveTty = ParseTtyInformation(Fds, FdCount, &ttyInput, &ttyOutput, &ttyControl); auto [pid, _, childChannel] = Fork(WSLA_FORK::Process); std::vector sockets(FdCount); for (size_t i = 0; i < FdCount; i++) { - if (Fds[i].Type == WslFdTypeDefault || Fds[i].Type == WslFdTypeTerminalInput || Fds[i].Type == WslFdTypeTerminalOutput) + if (Fds[i].Type == WslFdTypeDefault || Fds[i].Type == WslFdTypeTerminalInput || Fds[i].Type == WslFdTypeTerminalOutput || + Fds[i].Type == WslFdTypeTerminalControl) { - THROW_HR_IF_MSG(E_INVALIDARG, Fds[i].Type > WslFdTypeTerminalOutput, "Invalid flags: %i", Fds[i].Type); THROW_HR_IF_MSG(E_INVALIDARG, Fds[i].Path != nullptr, "Fd[%zu] has a non-null path but flags: %i", i, Fds[i].Type); sockets[i] = ConnectSocket(childChannel, static_cast(Fds[i].Fd)); } @@ -654,7 +655,7 @@ std::vector WSLAVirtualMachine::CreateLinuxProcessImpl( { THROW_HR_IF_MSG( E_INVALIDARG, - WI_IsAnyFlagSet(Fds[i].Type, WslFdTypeTerminalInput | WslFdTypeTerminalOutput), + WI_IsAnyFlagSet(Fds[i].Type, WslFdTypeTerminalInput | WslFdTypeTerminalOutput | WslFdTypeTerminalControl), "Invalid flags: %i", Fds[i].Type); @@ -674,10 +675,11 @@ std::vector WSLAVirtualMachine::CreateLinuxProcessImpl( if (interactiveTty) { auto [grandChildPid, ptyMaster, grandChildChannel] = Fork(childChannel, WSLA_FORK::Pty); - WSLA_TTY_RELAY relayMessage; + WSLA_TTY_RELAY relayMessage{}; relayMessage.TtyMaster = ptyMaster; relayMessage.TtyInput = ttyInput->Fd; relayMessage.TtyOutput = ttyOutput->Fd; + relayMessage.TtyControl = ttyControl == nullptr ? -1 : ttyControl->Fd; childChannel.SendMessage(relayMessage); auto result = ExpectClosedChannelOrError(childChannel); @@ -829,7 +831,8 @@ try } CATCH_RETURN(); -bool WSLAVirtualMachine::ParseTtyInformation(const WSLA_PROCESS_FD* Fds, ULONG FdCount, const WSLA_PROCESS_FD** TtyInput, const WSLA_PROCESS_FD** TtyOutput) +bool WSLAVirtualMachine::ParseTtyInformation( + const WSLA_PROCESS_FD* Fds, ULONG FdCount, const WSLA_PROCESS_FD** TtyInput, const WSLA_PROCESS_FD** TtyOutput, const WSLA_PROCESS_FD** TtyControl) { bool foundNonTtyFd = false; @@ -846,6 +849,11 @@ bool WSLAVirtualMachine::ParseTtyInformation(const WSLA_PROCESS_FD* Fds, ULONG F THROW_HR_IF_MSG(E_INVALIDARG, *TtyOutput != nullptr, "Only one TtyOutput fd can be passed. Index=%lu", i); *TtyOutput = &Fds[i]; } + else if (Fds[i].Type == WslFdTypeTerminalControl) + { + THROW_HR_IF_MSG(E_INVALIDARG, *TtyControl != nullptr, "Only one TtyOutput fd can be passed. Index=%lu", i); + *TtyControl = &Fds[i]; + } else { foundNonTtyFd = true; diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.h b/src/windows/wslaservice/exe/WSLAVirtualMachine.h index 2ae3aa2..33016ef 100644 --- a/src/windows/wslaservice/exe/WSLAVirtualMachine.h +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.h @@ -52,7 +52,8 @@ public: private: static int32_t MountImpl(wsl::shared::SocketChannel& Channel, LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags); static void CALLBACK s_OnExit(_In_ HCS_EVENT* Event, _In_opt_ void* Context); - static bool ParseTtyInformation(const WSLA_PROCESS_FD* Fds, ULONG FdCount, const WSLA_PROCESS_FD** TtyInput, const WSLA_PROCESS_FD** TtyOutput); + static bool ParseTtyInformation( + const WSLA_PROCESS_FD* Fds, ULONG FdCount, const WSLA_PROCESS_FD** TtyInput, const WSLA_PROCESS_FD** TtyOutput, const WSLA_PROCESS_FD** TtyControl); void ConfigureNetworking(); void OnExit(_In_ const HCS_EVENT* Event); diff --git a/src/windows/wslrelay/main.cpp b/src/windows/wslrelay/main.cpp index 564219d..ef069a0 100644 --- a/src/windows/wslrelay/main.cpp +++ b/src/windows/wslrelay/main.cpp @@ -40,6 +40,7 @@ try wil::unique_handle exitEvent{}; wil::unique_handle terminalInputHandle{}; wil::unique_handle terminalOutputHandle{}; + wil::unique_socket terminalControlHandle{}; uint32_t port{}; GUID vmId{}; bool disableTelemetry = !wsl::shared::OfficialBuild; @@ -54,6 +55,7 @@ try parser.AddArgument(disableTelemetry, wslrelay::disable_telemetry_option); parser.AddArgument(Handle{terminalInputHandle}, wslrelay::input_option); parser.AddArgument(Handle{terminalOutputHandle}, wslrelay::output_option); + parser.AddArgument(Handle{terminalControlHandle}, wslrelay::control_option); parser.Parse(); // Initialize logging. @@ -145,15 +147,68 @@ try case wslrelay::RelayMode::InteractiveConsoleRelay: { - AllocConsole(); - THROW_HR_IF(E_INVALIDARG, !terminalInputHandle || !terminalOutputHandle); + AllocConsole(); + auto consoleOutputHandle = wil::unique_handle{CreateFileW( + L"CONOUT$", GENERIC_READ | GENERIC_WRITE, FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, nullptr, OPEN_EXISTING, 0, nullptr)}; + + THROW_LAST_ERROR_IF(!consoleOutputHandle.is_valid()); + + auto consoleInputHandle = wil::unique_handle{CreateFileW( + L"CONIN$", GENERIC_READ | GENERIC_WRITE, FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, nullptr, OPEN_EXISTING, 0, nullptr)}; + + THROW_LAST_ERROR_IF(!consoleInputHandle.is_valid()); + + // Configure console for interactive usage. + + { + DWORD OutputMode{}; + THROW_LAST_ERROR_IF(!::GetConsoleMode(consoleOutputHandle.get(), &OutputMode)); + + WI_SetAllFlags(OutputMode, ENABLE_PROCESSED_OUTPUT | ENABLE_VIRTUAL_TERMINAL_PROCESSING | DISABLE_NEWLINE_AUTO_RETURN); + THROW_IF_WIN32_BOOL_FALSE(SetConsoleMode(consoleOutputHandle.get(), OutputMode)); + } + + { + DWORD InputMode{}; + THROW_LAST_ERROR_IF(!::GetConsoleMode(consoleInputHandle.get(), &InputMode)); + + WI_SetAllFlags(InputMode, (ENABLE_WINDOW_INPUT | ENABLE_VIRTUAL_TERMINAL_INPUT)); + WI_ClearAllFlags(InputMode, (ENABLE_ECHO_INPUT | ENABLE_INSERT_MODE | ENABLE_LINE_INPUT | ENABLE_PROCESSED_INPUT)); + THROW_IF_WIN32_BOOL_FALSE(SetConsoleMode(consoleInputHandle.get(), InputMode)); + } + + THROW_LAST_ERROR_IF(!::SetConsoleOutputCP(CP_UTF8)); + // Create a thread to relay stdin to the pipe. - wsl::windows::common::SvcCommIo Io; auto exitEvent = wil::unique_event(wil::EventOptions::ManualReset); + + std::optional controlChannel; + if (terminalControlHandle) + { + controlChannel.emplace(std::move(terminalControlHandle), "TerminalControl", exitEvent.get()); + } + std::thread inputThread([&]() { - wsl::windows::common::RelayStandardInput(GetStdHandle(STD_INPUT_HANDLE), terminalInputHandle.get(), {}, exitEvent.get(), &Io); + auto updateTerminal = [&controlChannel, &consoleOutputHandle]() { + if (controlChannel.has_value()) + { + CONSOLE_SCREEN_BUFFER_INFOEX info{}; + info.cbSize = sizeof(info); + + THROW_IF_WIN32_BOOL_FALSE(GetConsoleScreenBufferInfoEx(consoleOutputHandle.get(), &info)); + + WSLA_TERMINAL_CHANGED message{}; + message.Columns = info.srWindow.Right - info.srWindow.Left + 1; + message.Rows = info.srWindow.Bottom - info.srWindow.Top + 1; + + controlChannel->SendMessage(message); + } + }; + + wsl::windows::common::relay::StandardInputRelay( + GetStdHandle(STD_INPUT_HANDLE), terminalInputHandle.get(), updateTerminal, exitEvent.get()); }); auto joinThread = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { diff --git a/test/windows/WSLATests.cpp b/test/windows/WSLATests.cpp index 672d1cc..394775c 100644 --- a/test/windows/WSLATests.cpp +++ b/test/windows/WSLATests.cpp @@ -441,17 +441,21 @@ class WSLATests auto vm = CreateVm(&settings); std::vector commandLine{"/bin/sh", nullptr}; - std::vector fds(2); + std::vector fds(3); fds[0].Number = 0; fds[0].Type = WslFdTypeTerminalInput; fds[1].Number = 1; fds[1].Type = WslFdTypeTerminalOutput; + fds[2].Number = 2; + fds[2].Type = WslFdTypeTerminalControl; + const char* env[] = {"TERM=xterm-256color", nullptr}; WslCreateProcessSettings WslCreateProcessSettings{}; WslCreateProcessSettings.Executable = "/bin/sh"; WslCreateProcessSettings.Arguments = commandLine.data(); WslCreateProcessSettings.FileDescriptors = fds.data(); WslCreateProcessSettings.FdCount = static_cast(fds.size()); + WslCreateProcessSettings.Environment = env; int pid = -1; VERIFY_SUCCEEDED(WslCreateLinuxProcess(vm.get(), &WslCreateProcessSettings, &pid)); @@ -487,11 +491,14 @@ class WSLATests // Validate that the interactive process successfully starts wil::unique_handle process; VERIFY_SUCCEEDED(WslLaunchInteractiveTerminal( - WslCreateProcessSettings.FileDescriptors[0].Handle, WslCreateProcessSettings.FileDescriptors[1].Handle, &process)); + WslCreateProcessSettings.FileDescriptors[0].Handle, + WslCreateProcessSettings.FileDescriptors[1].Handle, + WslCreateProcessSettings.FileDescriptors[2].Handle, + &process)); // Exit the shell writeTty("exit\n"); - VERIFY_ARE_EQUAL(WaitForSingleObject(process.get(), 30 * 1000), WAIT_OBJECT_0); + VERIFY_ARE_EQUAL(WaitForSingleObject(process.get(), 30000 * 1000), WAIT_OBJECT_0); } TEST_METHOD(NATNetworking) From 2abbe8baf38b684b99ec864c848a98c81209548a Mon Sep 17 00:00:00 2001 From: Blue Date: Fri, 31 Oct 2025 18:49:25 -0700 Subject: [PATCH 2/3] =?UTF-8?q?wsla:=20Update=20the=20CreateProcess=20logi?= =?UTF-8?q?c=20to=20support=20allocating=20file=20descr=E2=80=A6=20(#13655?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * wsla: Update the CreateProcess logic to support allocating file descriptors from the guest * PR feedback --- src/linux/init/WSLAInit.cpp | 13 +++- src/linux/init/util.cpp | 6 +- src/linux/init/util.h | 2 +- .../wslaservice/exe/WSLAVirtualMachine.cpp | 77 +++++++++++++------ .../wslaservice/exe/WSLAVirtualMachine.h | 20 ++++- 5 files changed, 87 insertions(+), 31 deletions(-) diff --git a/src/linux/init/WSLAInit.cpp b/src/linux/init/WSLAInit.cpp index f839ae2..a990bab 100644 --- a/src/linux/init/WSLAInit.cpp +++ b/src/linux/init/WSLAInit.cpp @@ -85,10 +85,19 @@ void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_ACCEPT& M Channel.SendResultMessage(SocketAddress.svm_port); - wil::unique_fd Socket{UtilAcceptVsock(ListenSocket.get(), SocketAddress, SESSION_LEADER_ACCEPT_TIMEOUT_MS)}; + wil::unique_fd Socket{ + UtilAcceptVsock(ListenSocket.get(), SocketAddress, SESSION_LEADER_ACCEPT_TIMEOUT_MS, Message.Fd != -1 ? SOCK_CLOEXEC : 0)}; THROW_LAST_ERROR_IF(!Socket); - THROW_LAST_ERROR_IF(dup2(Socket.get(), Message.Fd) < 0); + if (Message.Fd != -1) + { + THROW_LAST_ERROR_IF(dup2(Socket.get(), Message.Fd) < 0); + } + else + { + Channel.SendResultMessage(Socket.get()); + Socket.release(); + } } void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_CONNECT& Message, const gsl::span& Buffer) diff --git a/src/linux/init/util.cpp b/src/linux/init/util.cpp index 6e35c39..2eb1780 100644 --- a/src/linux/init/util.cpp +++ b/src/linux/init/util.cpp @@ -197,7 +197,7 @@ InteropServer::~InteropServer() Reset(); } -int UtilAcceptVsock(int SocketFd, sockaddr_vm SocketAddress, int Timeout) +int UtilAcceptVsock(int SocketFd, sockaddr_vm SocketAddress, int Timeout, int SocketFlags) /*++ @@ -215,6 +215,8 @@ Arguments: Timeout - Supplies a timeout. + SocketFlags - Supplies the socket flags. + Return Value: A file descriptor representing the socket, -1 on failure. @@ -263,7 +265,7 @@ Return Value: if (Result != -1) { socklen_t SocketAddressSize = sizeof(SocketAddress); - Result = accept4(SocketFd, reinterpret_cast(&SocketAddress), &SocketAddressSize, SOCK_CLOEXEC); + Result = accept4(SocketFd, reinterpret_cast(&SocketAddress), &SocketAddressSize, SocketFlags); } if (Result < 0) diff --git a/src/linux/init/util.h b/src/linux/init/util.h index 9c68c66..718ea39 100644 --- a/src/linux/init/util.h +++ b/src/linux/init/util.h @@ -117,7 +117,7 @@ private: wil::unique_fd m_InteropSocket; }; -int UtilAcceptVsock(int SocketFd, sockaddr_vm Address, int Timeout = -1); +int UtilAcceptVsock(int SocketFd, sockaddr_vm Address, int Timeout = -1, int SocketFlags = SOCK_CLOEXEC); int UtilBindVsockAnyPort(struct sockaddr_vm* SocketAddress, int Type); diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp index 2800e2d..bce7bad 100644 --- a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp @@ -328,30 +328,45 @@ void WSLAVirtualMachine::ConfigureNetworking() // WSLA-TODO: Using fd=4 here seems to hang gns. There's probably a hardcoded file descriptor somewhere that's causing // so using 1000 for now. std::vector fds(1); - fds[0].Fd = 1000; + fds[0].Fd = -1; fds[0].Type = WslFdType::WslFdTypeDefault; - std::vector cmd{"/gns", LX_INIT_GNS_SOCKET_ARG, "1000"}; + std::vector cmd{"/gns", LX_INIT_GNS_SOCKET_ARG}; // If DNS tunnelling is enabled, use an additional for its channel. if (m_settings.EnableDnsTunneling) { - fds.emplace_back(WSLA_PROCESS_FD{.Fd = 1001, .Type = WslFdType::WslFdTypeDefault}); - cmd.emplace_back(LX_INIT_GNS_DNS_SOCKET_ARG); - cmd.emplace_back("1001"); - cmd.emplace_back(LX_INIT_GNS_DNS_TUNNELING_IP); - cmd.emplace_back(LX_INIT_DNS_TUNNELING_IP_ADDRESS); - + fds.emplace_back(WSLA_PROCESS_FD{.Fd = -1, .Type = WslFdType::WslFdTypeDefault}); THROW_IF_FAILED(wsl::core::networking::DnsResolver::LoadDnsResolverMethods()); } WSLA_CREATE_PROCESS_OPTIONS options{}; options.Executable = "/init"; - options.CommandLine = cmd.data(); - options.CommandLineCount = static_cast(cmd.size()); + + // Because the file descriptors numbers aren't known in advance, the command line needs to be generated after the file + // descriptors are allocated. + + std::string socketFdArg; + std::string dnsFdArg; + auto prepareCommandLine = [&](const auto& sockets) { + socketFdArg = std::to_string(sockets[0].Fd); + cmd.emplace_back(socketFdArg.c_str()); + + if (sockets.size() > 1) + { + dnsFdArg = std::to_string(sockets[1].Fd); + cmd.emplace_back(LX_INIT_GNS_DNS_SOCKET_ARG); + cmd.emplace_back(dnsFdArg.c_str()); + cmd.emplace_back(LX_INIT_GNS_DNS_TUNNELING_IP); + cmd.emplace_back(LX_INIT_DNS_TUNNELING_IP_ADDRESS); + } + + options.CommandLine = cmd.data(); + options.CommandLineCount = static_cast(cmd.size()); + }; WSLA_CREATE_PROCESS_RESULT result{}; - auto sockets = CreateLinuxProcessImpl(&options, static_cast(fds.size()), fds.data(), &result); + auto sockets = CreateLinuxProcessImpl(&options, static_cast(fds.size()), fds.data(), &result, prepareCommandLine); THROW_HR_IF(E_FAIL, result.Errno != 0); @@ -364,13 +379,12 @@ void WSLAVirtualMachine::ConfigureNetworking() config.FirewallConfig.reset(); }*/ - // TODO: DNS Tunneling support m_networkEngine = std::make_unique( m_computeSystem.get(), wsl::core::NatNetworking::CreateNetwork(config), - std::move(sockets[0]), + std::move(sockets[0].Socket), config, - sockets.size() > 1 ? std::move(sockets[1]) : wil::unique_socket{}); + sockets.size() > 1 ? std::move(sockets[1].Socket) : wil::unique_socket{}); m_networkEngine->Initialize(); @@ -587,13 +601,26 @@ std::tuple WSLAVirtualMachine::For return std::make_tuple(pid, ptyMaster, wsl::shared::SocketChannel{std::move(socket), std::to_string(pid), m_vmTerminatingEvent.get()}); } -wil::unique_socket WSLAVirtualMachine::ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd) +WSLAVirtualMachine::ConnectedSocket WSLAVirtualMachine::ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd) { WSLA_ACCEPT message{}; message.Fd = Fd; const auto& response = Channel.Transaction(message); + ConnectedSocket socket; - return wsl::windows::common::hvsocket::Connect(m_vmId, response.Result); + socket.Socket = wsl::windows::common::hvsocket::Connect(m_vmId, response.Result); + + // If the FD was unspecified, read the Linux file descriptor from the guest. + if (Fd == -1) + { + socket.Fd = Channel.ReceiveMessage>().Result; + } + else + { + socket.Fd = Fd; + } + + return socket; } void WSLAVirtualMachine::OpenLinuxFile(wsl::shared::SocketChannel& Channel, const char* Path, uint32_t Flags, int32_t Fd) @@ -621,10 +648,10 @@ try for (size_t i = 0; i < sockets.size(); i++) { - if (sockets[i]) + if (sockets[i].Socket) { - Handles[i] = - HandleToUlong(wsl::windows::common::wslutil::DuplicateHandleToCallingProcess(reinterpret_cast(sockets[i].get()))); + Handles[i] = HandleToUlong( + wsl::windows::common::wslutil::DuplicateHandleToCallingProcess(reinterpret_cast(sockets[i].Socket.get()))); } } @@ -632,8 +659,12 @@ try } CATCH_RETURN(); -std::vector WSLAVirtualMachine::CreateLinuxProcessImpl( - _In_ const WSLA_CREATE_PROCESS_OPTIONS* Options, _In_ ULONG FdCount, _In_ WSLA_PROCESS_FD* Fds, _Out_ WSLA_CREATE_PROCESS_RESULT* Result) +std::vector WSLAVirtualMachine::CreateLinuxProcessImpl( + _In_ const WSLA_CREATE_PROCESS_OPTIONS* Options, + _In_ ULONG FdCount, + _In_ WSLA_PROCESS_FD* Fds, + _Out_ WSLA_CREATE_PROCESS_RESULT* Result, + const TPrepareCommandLine& PrepareCommandLine) { // Check if this is a tty or not const WSLA_PROCESS_FD* ttyInput = nullptr; @@ -642,7 +673,7 @@ std::vector WSLAVirtualMachine::CreateLinuxProcessImpl( auto interactiveTty = ParseTtyInformation(Fds, FdCount, &ttyInput, &ttyOutput, &ttyControl); auto [pid, _, childChannel] = Fork(WSLA_FORK::Process); - std::vector sockets(FdCount); + std::vector sockets(FdCount); for (size_t i = 0; i < FdCount; i++) { if (Fds[i].Type == WslFdTypeDefault || Fds[i].Type == WslFdTypeTerminalInput || Fds[i].Type == WslFdTypeTerminalOutput || @@ -664,6 +695,8 @@ std::vector WSLAVirtualMachine::CreateLinuxProcessImpl( } } + PrepareCommandLine(sockets); + wsl::shared::MessageWriter Message; Message.WriteString(Message->ExecutableIndex, Options->Executable); diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.h b/src/windows/wslaservice/exe/WSLAVirtualMachine.h index 33016ef..9f64afa 100644 --- a/src/windows/wslaservice/exe/WSLAVirtualMachine.h +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.h @@ -50,6 +50,14 @@ public: IFACEMETHOD(MountGpuLibraries(_In_ LPCSTR LibrariesMountPoint, _In_ LPCSTR DriversMountpoint, _In_ DWORD Flags)) override; private: + struct ConnectedSocket + { + int Fd; + wil::unique_socket Socket; + }; + + using TPrepareCommandLine = std::function&)>; + static int32_t MountImpl(wsl::shared::SocketChannel& Channel, LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags); static void CALLBACK s_OnExit(_In_ HCS_EVENT* Event, _In_opt_ void* Context); static bool ParseTtyInformation( @@ -62,12 +70,16 @@ private: std::tuple Fork(wsl::shared::SocketChannel& Channel, enum WSLA_FORK::ForkType Type); int32_t ExpectClosedChannelOrError(wsl::shared::SocketChannel& Channel); - wil::unique_socket ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd); - void OpenLinuxFile(wsl::shared::SocketChannel& Channel, const char* Path, uint32_t Flags, int32_t Fd); + ConnectedSocket ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd); + static void OpenLinuxFile(wsl::shared::SocketChannel& Channel, const char* Path, uint32_t Flags, int32_t Fd); void LaunchPortRelay(); - std::vector CreateLinuxProcessImpl( - _In_ const WSLA_CREATE_PROCESS_OPTIONS* Options, _In_ ULONG FdCount, _In_ WSLA_PROCESS_FD* Fd, _Out_ WSLA_CREATE_PROCESS_RESULT* Result); + std::vector CreateLinuxProcessImpl( + _In_ const WSLA_CREATE_PROCESS_OPTIONS* Options, + _In_ ULONG FdCount, + _In_ WSLA_PROCESS_FD* Fd, + _Out_ WSLA_CREATE_PROCESS_RESULT* Result, + const TPrepareCommandLine& PrepareCommandLine = [](const auto&) {}); HRESULT MountWindowsFolderImpl(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly, _In_ WslMountFlags Flags); From 46db45c39aa791e77cd68ab5e8e7e7f167f79c52 Mon Sep 17 00:00:00 2001 From: Blue Date: Tue, 4 Nov 2025 13:09:20 -0800 Subject: [PATCH 3/3] Add a temporary flag to wsl.exe to create a WSLA shell (#13664) * Add a temporary flag to wsl.exe to create a WSLA shell * Improve some error paths * Add dependency --- src/shared/inc/CommandLine.h | 17 ++ src/windows/common/CMakeLists.txt | 3 +- src/windows/common/WslClient.cpp | 167 ++++++++++++++++++++ src/windows/wslaservice/inc/wslaservice.idl | 1 - 4 files changed, 186 insertions(+), 2 deletions(-) diff --git a/src/shared/inc/CommandLine.h b/src/shared/inc/CommandLine.h index dc98122..9966c5e 100644 --- a/src/shared/inc/CommandLine.h +++ b/src/shared/inc/CommandLine.h @@ -187,6 +187,23 @@ struct Handle } }; +struct Utf8String +{ + std::string& Value; + + int operator()(const TChar* Input) const + { + if (Input == nullptr) + { + return -1; + } + + Value = wsl::shared::string::WideToMultiByte(Input); + + return 1; + } +}; + #else struct UniqueFd diff --git a/src/windows/common/CMakeLists.txt b/src/windows/common/CMakeLists.txt index 150ee09..e451d5e 100644 --- a/src/windows/common/CMakeLists.txt +++ b/src/windows/common/CMakeLists.txt @@ -119,8 +119,9 @@ set(HEADERS wslutil.cpp ) +include_directories(${CMAKE_SOURCE_DIR}/src/windows/wslaclient) add_library(common STATIC ${SOURCES} ${HEADERS}) -add_dependencies(common wslserviceidl localization wslservicemc wslinstalleridl) +add_dependencies(common wslserviceidl localization wslservicemc wslinstalleridl wslaserviceproxystub) target_precompile_headers(common PRIVATE precomp.h) set_target_properties(common PROPERTIES FOLDER windows) diff --git a/src/windows/common/WslClient.cpp b/src/windows/common/WslClient.cpp index 26cfdfc..cbb9823 100644 --- a/src/windows/common/WslClient.cpp +++ b/src/windows/common/WslClient.cpp @@ -18,6 +18,8 @@ Abstract: #include "Distribution.h" #include "CommandLine.h" #include +#include "wslaservice.h" +#include "WSLAApi.h" #define BASH_PATH L"/bin/bash" @@ -1539,6 +1541,167 @@ int RunDebugShell() THROW_HR(HCS_E_CONNECTION_CLOSED); } +// Temporary debugging tool for WSLA +int WslaShell(_In_ std::wstring_view commandLine) +{ +#ifdef WSL_SYSTEM_DISTRO_PATH + + std::wstring vhd = TEXT(WSL_SYSTEM_DISTRO_PATH); + +#else + + std::wstring vhd = wsl::windows::common::wslutil::GetMsiPackagePath().value() + L"/system.vhd"; + +#endif + + VIRTUAL_MACHINE_SETTINGS settings{}; + settings.CpuCount = 4; + settings.DisplayName = L"WSLA"; + settings.MemoryMb = 1024; + settings.BootTimeoutMs = 30000; + settings.NetworkingMode = WslNetworkingModeNAT; + std::string shell = "/bin/bash"; + bool help = false; + + ArgumentParser parser(std::wstring{commandLine}, WSL_BINARY_NAME); + parser.AddArgument(vhd, L"--vhd"); + parser.AddArgument(Utf8String(shell), L"--shell"); + parser.AddArgument(reinterpret_cast(settings.EnableDnsTunneling), L"--dns-tunneling"); + parser.AddArgument(Integer(settings.MemoryMb), L"--memory"); + parser.AddArgument(Integer(settings.CpuCount), L"--cpu"); + parser.AddArgument(help, L"--help"); + + parser.Parse(); + + if (help) + { + const auto usage = std::format( + LR"({} --wsla [--vhd ] [--shell ] [--memory ] [--cpu ] [--dns-tunneling] [--help])", + WSL_BINARY_NAME); + + wprintf(L"%ls\n", usage.c_str()); + return 1; + } + + wil::com_ptr userSession; + THROW_IF_FAILED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&userSession))); + wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get()); + + wil::com_ptr virtualMachine; + THROW_IF_FAILED(userSession->CreateVirtualMachine(&settings, &virtualMachine)); + wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get()); + + wil::unique_cotaskmem_ansistring diskDevice; + ULONG Lun{}; + THROW_IF_FAILED(virtualMachine->AttachDisk(vhd.c_str(), true, &diskDevice, &Lun)); + + THROW_IF_FAILED(virtualMachine->Mount(diskDevice.get(), "/mnt", "ext4", "ro", WslMountFlagsChroot | WslMountFlagsWriteableOverlayFs)); + THROW_IF_FAILED(virtualMachine->Mount(nullptr, "/dev", "devtmpfs", "", 0)); + THROW_IF_FAILED(virtualMachine->Mount(nullptr, "/sys", "sysfs", "", 0)); + THROW_IF_FAILED(virtualMachine->Mount(nullptr, "/proc", "proc", "", 0)); + THROW_IF_FAILED(virtualMachine->Mount(nullptr, "/dev/pts", "devpts", "noatime,nosuid,noexec,gid=5,mode=620", 0)); + + std::vector shellCommandLine{shell.c_str()}; + std::vector env{"TERM=xterm-256color"}; + + std::vector fds(3); + fds[0].Fd = 0; + fds[0].Type = WslFdTypeTerminalInput; + fds[1].Fd = 1; + fds[1].Type = WslFdTypeTerminalOutput; + fds[2].Fd = 2; + fds[2].Type = WslFdTypeTerminalControl; + + WSLA_CREATE_PROCESS_OPTIONS processOptions{}; + processOptions.Executable = shell.c_str(); + processOptions.CommandLine = shellCommandLine.data(); + processOptions.CommandLineCount = static_cast(shellCommandLine.size()); + processOptions.Environment = env.data(); + processOptions.EnvironmentCount = static_cast(env.size()); + processOptions.CurrentDirectory = "/"; + + std::vector handles(fds.size()); + + WSLA_CREATE_PROCESS_RESULT result{}; + auto createProcessResult = + virtualMachine->CreateLinuxProcess(&processOptions, static_cast(fds.size()), fds.data(), handles.data(), &result); + + if (FAILED(createProcessResult)) + { + if (result.Errno != 0) + { + THROW_HR_WITH_USER_ERROR(E_FAIL, std::format(L"Failed to create process {}, errno = {}", shell.c_str(), result.Errno)); + } + else + { + THROW_HR(createProcessResult); + } + } + + // Configure console for interactive usage. + + HANDLE Stdout = GetStdHandle(STD_OUTPUT_HANDLE); + HANDLE Stdin = GetStdHandle(STD_INPUT_HANDLE); + { + DWORD OutputMode{}; + THROW_LAST_ERROR_IF(!::GetConsoleMode(Stdout, &OutputMode)); + + WI_SetAllFlags(OutputMode, ENABLE_PROCESSED_OUTPUT | ENABLE_VIRTUAL_TERMINAL_PROCESSING | DISABLE_NEWLINE_AUTO_RETURN); + THROW_IF_WIN32_BOOL_FALSE(SetConsoleMode(Stdout, OutputMode)); + } + + { + DWORD InputMode{}; + THROW_LAST_ERROR_IF(!::GetConsoleMode(Stdin, &InputMode)); + + WI_SetAllFlags(InputMode, (ENABLE_WINDOW_INPUT | ENABLE_VIRTUAL_TERMINAL_INPUT)); + WI_ClearAllFlags(InputMode, (ENABLE_ECHO_INPUT | ENABLE_INSERT_MODE | ENABLE_LINE_INPUT | ENABLE_PROCESSED_INPUT)); + THROW_IF_WIN32_BOOL_FALSE(SetConsoleMode(Stdin, InputMode)); + } + + THROW_LAST_ERROR_IF(!::SetConsoleOutputCP(CP_UTF8)); + + { + // Create a thread to relay stdin to the pipe. + auto exitEvent = wil::unique_event(wil::EventOptions::ManualReset); + + wsl::shared::SocketChannel controlChannel{wil::unique_socket(handles[2]), "TerminalControl", exitEvent.get()}; + + std::thread inputThread([&]() { + auto updateTerminal = [&controlChannel, &Stdout]() { + CONSOLE_SCREEN_BUFFER_INFOEX info{}; + info.cbSize = sizeof(info); + + THROW_IF_WIN32_BOOL_FALSE(GetConsoleScreenBufferInfoEx(Stdout, &info)); + + WSLA_TERMINAL_CHANGED message{}; + message.Columns = info.srWindow.Right - info.srWindow.Left + 1; + message.Rows = info.srWindow.Bottom - info.srWindow.Top + 1; + + controlChannel.SendMessage(message); + }; + + wsl::windows::common::relay::StandardInputRelay(Stdin, UlongToHandle(handles[0]), updateTerminal, exitEvent.get()); + }); + + auto joinThread = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { + exitEvent.SetEvent(); + inputThread.join(); + }); + + // Relay the contents of the pipe to stdout. + wsl::windows::common::relay::InterruptableRelay(UlongToHandle(handles[1]), Stdout); + } + + ULONG exitState{}; + int exitCode{}; + THROW_IF_FAILED(virtualMachine->WaitPid(result.Pid, 0, &exitState, &exitCode)); + + wprintf(L"%hs exited with: %i", shell.c_str(), exitCode); + + return exitCode; +} + int WslMain(_In_ std::wstring_view commandLine) { // Call the MSI package if we're in an MSIX context @@ -1772,6 +1935,10 @@ int WslMain(_In_ std::wstring_view commandLine) { return Uninstall(); } + else if (argument == L"--wsla") + { + return WslaShell(commandLine); + } else { if ((argument.size() > 0) && (argument[0] == L'-')) diff --git a/src/windows/wslaservice/inc/wslaservice.idl b/src/windows/wslaservice/inc/wslaservice.idl index 177a09e..3698719 100644 --- a/src/windows/wslaservice/inc/wslaservice.idl +++ b/src/windows/wslaservice/inc/wslaservice.idl @@ -99,7 +99,6 @@ struct _VIRTUAL_MACHINE_SETTINGS { BOOL EnableDebugShell; BOOL EnableEarlyBootDmesg; BOOL EnableGPU; - BOOL EnableDnsTunnelling; } VIRTUAL_MACHINE_SETTINGS;