diff --git a/src/windows/common/svccommio.cpp b/src/windows/common/svccommio.cpp index a6a012d..d9c0dd7 100644 --- a/src/windows/common/svccommio.cpp +++ b/src/windows/common/svccommio.cpp @@ -18,265 +18,165 @@ Abstract: #pragma hdrstop namespace { -void ChangeConsoleMode(_In_ HANDLE File, _In_ DWORD ConsoleMode) + +bool IsConsoleHandle(_In_ HANDLE Handle) { - // - // Use the invalid parameter error code to detect the v1 console that does - // not support the provided mode. This can be improved in the future when - // a more elegant solution exists. - // - // N.B. Ignore failures setting the mode if the console has already - // disconnected. - // + DWORD Mode; + return GetFileType(Handle) == FILE_TYPE_CHAR && GetConsoleMode(Handle, &Mode); +} - if (!SetConsoleMode(File, ConsoleMode)) +void TrySetConsoleMode(_In_ HANDLE Handle, _In_ DWORD Mode) +{ + if (!SetConsoleMode(Handle, Mode)) { - switch (GetLastError()) + const auto Error = GetLastError(); + if (Error != ERROR_INVALID_PARAMETER && Error != ERROR_PIPE_NOT_CONNECTED) { - case ERROR_INVALID_PARAMETER: - THROW_HR(WSL_E_CONSOLE); - - case ERROR_PIPE_NOT_CONNECTED: - break; - - default: - THROW_LAST_ERROR(); + LOG_IF_WIN32_ERROR(Error); } } } -void ConfigureStdHandles(_Inout_ LXSS_STD_HANDLES_INFO& StdHandlesInfo) -{ - // - // Check stdin to see if it is a console or another device. If it is - // a console, configure it to raw processing mode and VT-100 support. If the - // force console I/O is requested, ignore stdin and get active console input - // handle instead. - // - - UINT NewConsoleInputCP = 0; - DWORD NewConsoleInputMode = 0; - BOOLEAN IsConsoleInput = StdHandlesInfo.IsConsoleInput; - BOOLEAN IsConsoleOutput = StdHandlesInfo.IsConsoleOutput; - BOOLEAN IsConsoleError = StdHandlesInfo.IsConsoleError; - DWORD SavedInputMode = StdHandlesInfo.SavedInputMode; - DWORD SavedOutputMode = StdHandlesInfo.SavedOutputMode; - UINT SavedInputCP = StdHandlesInfo.SavedInputCP; - UINT SavedOutputCP = StdHandlesInfo.SavedOutputCP; - CONSOLE_SCREEN_BUFFER_INFO ScreenBufferInfo; - auto RestoreInputHandle = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] { - if (NewConsoleInputCP != 0) - { - SetConsoleCP(SavedInputCP); - } - - if (NewConsoleInputMode != 0) - { - ChangeConsoleMode(StdHandlesInfo.InputHandle, SavedInputMode); - } - }); - - IsConsoleInput = FALSE; - if ((GetFileType(StdHandlesInfo.InputHandle) == FILE_TYPE_CHAR) && (GetConsoleMode(StdHandlesInfo.InputHandle, &SavedInputMode))) - { - IsConsoleInput = TRUE; - NewConsoleInputMode = SavedInputMode; - WI_SetAllFlags(NewConsoleInputMode, (ENABLE_WINDOW_INPUT | ENABLE_VIRTUAL_TERMINAL_INPUT)); - WI_ClearAllFlags(NewConsoleInputMode, (ENABLE_ECHO_INPUT | ENABLE_INSERT_MODE | ENABLE_LINE_INPUT | ENABLE_PROCESSED_INPUT)); - ChangeConsoleMode(StdHandlesInfo.InputHandle, NewConsoleInputMode); - - // - // Set the console input to the UTF-8 code page. - // - - SavedInputCP = GetConsoleCP(); - NewConsoleInputCP = CP_UTF8; - THROW_LAST_ERROR_IF(!::SetConsoleCP(NewConsoleInputCP)); - } - - bool RestoreMode = false; - bool RestoreCp = false; - auto RestoreOutput = wil::scope_exit([&] { - if (RestoreMode) - { - SetConsoleMode(StdHandlesInfo.ConsoleOutputHandle.get(), SavedOutputMode); - } - - if (RestoreCp) - { - SetConsoleOutputCP(SavedOutputCP); - } - }); - - // - // If there is a console output handle, save the output mode and codepage so - // it can be restored. - // - - if (StdHandlesInfo.ConsoleOutputHandle) - { - THROW_LAST_ERROR_IF(!::GetConsoleMode(StdHandlesInfo.ConsoleOutputHandle.get(), &SavedOutputMode)); - - // - // Temporarily try both with and without the custom flag to disable newline - // auto return. - // - - DWORD NewConsoleOutputMode = SavedOutputMode | ENABLE_PROCESSED_OUTPUT | ENABLE_VIRTUAL_TERMINAL_PROCESSING | DISABLE_NEWLINE_AUTO_RETURN; - if (SetConsoleMode(StdHandlesInfo.ConsoleOutputHandle.get(), NewConsoleOutputMode) == FALSE) - { - WI_ClearFlag(NewConsoleOutputMode, DISABLE_NEWLINE_AUTO_RETURN); - ChangeConsoleMode(StdHandlesInfo.ConsoleOutputHandle.get(), NewConsoleOutputMode); - } - - RestoreMode = true; - - // - // Set the console output to the UTF-8 code page. - // - - SavedOutputCP = GetConsoleOutputCP(); - THROW_LAST_ERROR_IF(!::SetConsoleOutputCP(CP_UTF8)); - - RestoreCp = true; - } - - // - // If the force console I/O is requested, ignore stdout and treat the - // console as the output handle. - // - - IsConsoleOutput = FALSE; - if ((GetFileType(StdHandlesInfo.OutputHandle) == FILE_TYPE_CHAR) && - (GetConsoleScreenBufferInfo(StdHandlesInfo.OutputHandle, &ScreenBufferInfo))) - { - IsConsoleOutput = TRUE; - } - - IsConsoleError = FALSE; - if ((GetFileType(StdHandlesInfo.ErrorHandle) == FILE_TYPE_CHAR) && - (GetConsoleScreenBufferInfo(StdHandlesInfo.ErrorHandle, &ScreenBufferInfo))) - { - IsConsoleError = TRUE; - } - - RestoreInputHandle.release(); - RestoreOutput.release(); - StdHandlesInfo.IsConsoleInput = IsConsoleInput; - StdHandlesInfo.IsConsoleOutput = IsConsoleOutput; - StdHandlesInfo.IsConsoleError = IsConsoleError; - StdHandlesInfo.SavedInputMode = SavedInputMode; - StdHandlesInfo.SavedOutputMode = SavedOutputMode; - StdHandlesInfo.SavedInputCP = SavedInputCP; - StdHandlesInfo.SavedOutputCP = SavedOutputCP; -} } // namespace -wsl::windows::common::SvcCommIo::SvcCommIo() +namespace wsl::windows::common { + +// ConsoleInput implementation +std::unique_ptr ConsoleInput::Create(HANDLE Handle) { - _stdHandlesInfo.InputHandle = GetStdHandle(STD_INPUT_HANDLE); - _stdHandlesInfo.OutputHandle = GetStdHandle(STD_OUTPUT_HANDLE); - _stdHandlesInfo.ErrorHandle = GetStdHandle(STD_ERROR_HANDLE); - _stdHandlesInfo.ConsoleOutputHandle.reset( - CreateFileW(L"CONOUT$", GENERIC_READ | GENERIC_WRITE, FILE_SHARE_READ | FILE_SHARE_WRITE, nullptr, OPEN_EXISTING, 0, nullptr)); - - ConfigureStdHandles(_stdHandlesInfo); - _stdHandles.StdIn.HandleType = LxssHandleInput; - _stdHandles.StdIn.Handle = HandleToUlong(_stdHandlesInfo.InputHandle); - _stdHandles.StdOut.HandleType = LxssHandleOutput; - _stdHandles.StdOut.Handle = HandleToUlong(_stdHandlesInfo.OutputHandle); - _stdHandles.StdErr.HandleType = LxssHandleOutput; - _stdHandles.StdErr.Handle = HandleToUlong(_stdHandlesInfo.ErrorHandle); - - // - // N.B.: The console handle is not supposed to be closed, it is just copied - // from PEB. - // - - if (_stdHandlesInfo.IsConsoleInput) + DWORD Mode; + if (GetFileType(Handle) == FILE_TYPE_CHAR && GetConsoleMode(Handle, &Mode)) { - _stdHandles.StdIn.Handle = LXSS_HANDLE_USE_CONSOLE; - _stdHandles.StdIn.HandleType = LxssHandleConsole; + return std::unique_ptr(new ConsoleInput(Handle, Mode)); } - if (_stdHandlesInfo.IsConsoleOutput) - { - _stdHandles.StdOut.Handle = LXSS_HANDLE_USE_CONSOLE; - _stdHandles.StdOut.HandleType = LxssHandleConsole; - } + return nullptr; +} - if (_stdHandlesInfo.IsConsoleError) +ConsoleInput::ConsoleInput(HANDLE Handle, DWORD SavedMode) : m_Handle(Handle), m_SavedMode(SavedMode) +{ + // Save code page + m_SavedCodePage = GetConsoleCP(); + + // Configure for raw input with VT support + DWORD NewMode = m_SavedMode; + WI_SetAllFlags(NewMode, ENABLE_WINDOW_INPUT | ENABLE_VIRTUAL_TERMINAL_INPUT); + WI_ClearAllFlags(NewMode, ENABLE_ECHO_INPUT | ENABLE_INSERT_MODE | ENABLE_LINE_INPUT | ENABLE_PROCESSED_INPUT); + TrySetConsoleMode(Handle, NewMode); + + // Set UTF-8 code page + SetConsoleCP(CP_UTF8); +} + +ConsoleInput::~ConsoleInput() +{ + if (m_Handle) { - _stdHandles.StdErr.Handle = LXSS_HANDLE_USE_CONSOLE; - _stdHandles.StdErr.HandleType = LxssHandleConsole; + TrySetConsoleMode(m_Handle, m_SavedMode); + SetConsoleCP(m_SavedCodePage); } } -wsl::windows::common::SvcCommIo::~SvcCommIo() +// ConsoleOutput implementation +std::unique_ptr ConsoleOutput::Create() { - try + wil::unique_hfile ConsoleHandle( + CreateFileW(L"CONOUT$", GENERIC_READ | GENERIC_WRITE, FILE_SHARE_READ | FILE_SHARE_WRITE, nullptr, OPEN_EXISTING, 0, nullptr)); + + if (!ConsoleHandle) { - RestoreConsoleMode(); + return nullptr; } - CATCH_LOG() + + DWORD Mode; + if (GetConsoleMode(ConsoleHandle.get(), &Mode)) + { + return std::unique_ptr(new ConsoleOutput(std::move(ConsoleHandle), Mode)); + } + + return nullptr; +} + +ConsoleOutput::ConsoleOutput(wil::unique_hfile ConsoleHandle, DWORD SavedMode) : + m_ConsoleHandle(std::move(ConsoleHandle)), m_SavedMode(SavedMode) +{ + // Save code page + m_SavedCodePage = GetConsoleOutputCP(); + + // Configure for VT output with DISABLE_NEWLINE_AUTO_RETURN + DWORD NewMode = m_SavedMode; + WI_SetAllFlags(NewMode, ENABLE_PROCESSED_OUTPUT | ENABLE_VIRTUAL_TERMINAL_PROCESSING | DISABLE_NEWLINE_AUTO_RETURN); + + // Try with DISABLE_NEWLINE_AUTO_RETURN first, fall back without it if not supported + if (!SetConsoleMode(m_ConsoleHandle.get(), NewMode)) + { + WI_ClearFlag(NewMode, DISABLE_NEWLINE_AUTO_RETURN); + TrySetConsoleMode(m_ConsoleHandle.get(), NewMode); + } + + // Set UTF-8 code page + SetConsoleOutputCP(CP_UTF8); +} + +ConsoleOutput::~ConsoleOutput() +{ + if (m_ConsoleHandle) + { + TrySetConsoleMode(m_ConsoleHandle.get(), m_SavedMode); + SetConsoleOutputCP(m_SavedCodePage); + } +} + +// SvcCommIo implementation +SvcCommIo::SvcCommIo() +{ + const HANDLE InputHandle = GetStdHandle(STD_INPUT_HANDLE); + const HANDLE OutputHandle = GetStdHandle(STD_OUTPUT_HANDLE); + const HANDLE ErrorHandle = GetStdHandle(STD_ERROR_HANDLE); + + // Configure input console + m_ConsoleInput = ConsoleInput::Create(InputHandle); + + // Configure output console + m_ConsoleOutput = ConsoleOutput::Create(); + + // Initialize the standard handles structure + const bool IsConsoleInput = m_ConsoleInput != nullptr; + m_StdHandles.StdIn.HandleType = IsConsoleInput ? LxssHandleConsole : LxssHandleInput; + m_StdHandles.StdIn.Handle = IsConsoleInput ? LXSS_HANDLE_USE_CONSOLE : HandleToUlong(InputHandle); + + const bool IsConsoleOutput = IsConsoleHandle(OutputHandle); + m_StdHandles.StdOut.HandleType = IsConsoleOutput ? LxssHandleConsole : LxssHandleOutput; + m_StdHandles.StdOut.Handle = IsConsoleOutput ? LXSS_HANDLE_USE_CONSOLE : HandleToUlong(OutputHandle); + + const bool IsConsoleError = IsConsoleHandle(ErrorHandle); + m_StdHandles.StdErr.HandleType = IsConsoleError ? LxssHandleConsole : LxssHandleOutput; + m_StdHandles.StdErr.Handle = IsConsoleError ? LXSS_HANDLE_USE_CONSOLE : HandleToUlong(ErrorHandle); + + // Cache a console handle for GetWindowSize + m_WindowSizeHandle = IsConsoleOutput ? OutputHandle : (IsConsoleError ? ErrorHandle : nullptr); } PLXSS_STD_HANDLES -wsl::windows::common::SvcCommIo::GetStdHandles() +SvcCommIo::GetStdHandles() { - return &_stdHandles; + return &m_StdHandles; } COORD -wsl::windows::common::SvcCommIo::GetWindowSize() const +SvcCommIo::GetWindowSize() const { - CONSOLE_SCREEN_BUFFER_INFOEX Info{}; - Info.cbSize = sizeof(Info); - if (_stdHandlesInfo.IsConsoleOutput) + if (m_WindowSizeHandle) { - THROW_IF_WIN32_BOOL_FALSE(::GetConsoleScreenBufferInfoEx(_stdHandlesInfo.OutputHandle, &Info)); - } - else if (_stdHandlesInfo.IsConsoleError) - { - THROW_IF_WIN32_BOOL_FALSE(::GetConsoleScreenBufferInfoEx(_stdHandlesInfo.ErrorHandle, &Info)); + CONSOLE_SCREEN_BUFFER_INFOEX Info{}; + Info.cbSize = sizeof(Info); + THROW_IF_WIN32_BOOL_FALSE(GetConsoleScreenBufferInfoEx(m_WindowSizeHandle, &Info)); + return { + static_cast(Info.srWindow.Right - Info.srWindow.Left + 1), + static_cast(Info.srWindow.Bottom - Info.srWindow.Top + 1)}; } - return { - static_cast(Info.srWindow.Right - Info.srWindow.Left + 1), static_cast(Info.srWindow.Bottom - Info.srWindow.Top + 1)}; + return {80, 24}; // Default size if no console } -void wsl::windows::common::SvcCommIo::RestoreConsoleMode() const - -/*++ - -Routine Description: - - Restores the saved input/output console mode. - -Arguments: - - None. - -Return Value: - - None. - ---*/ - -{ - // - // Restore the console input and output modes. - // - - if (_stdHandlesInfo.ConsoleOutputHandle) - { - ChangeConsoleMode(_stdHandlesInfo.ConsoleOutputHandle.get(), _stdHandlesInfo.SavedOutputMode); - SetConsoleOutputCP(_stdHandlesInfo.SavedOutputCP); - } - - if (_stdHandlesInfo.IsConsoleInput != FALSE) - { - ChangeConsoleMode(_stdHandlesInfo.InputHandle, _stdHandlesInfo.SavedInputMode); - SetConsoleCP(_stdHandlesInfo.SavedInputCP); - } -} +} // namespace wsl::windows::common diff --git a/src/windows/common/svccommio.hpp b/src/windows/common/svccommio.hpp index 95de9b2..a734781 100644 --- a/src/windows/common/svccommio.hpp +++ b/src/windows/common/svccommio.hpp @@ -18,35 +18,57 @@ Abstract: #include #include "wslservice.h" -typedef struct _LXSS_STD_HANDLES_INFO -{ - HANDLE InputHandle; - HANDLE OutputHandle; - HANDLE ErrorHandle; - wil::unique_hfile ConsoleOutputHandle; - BOOLEAN IsConsoleInput; - BOOLEAN IsConsoleOutput; - BOOLEAN IsConsoleError; - DWORD SavedInputMode; - DWORD SavedOutputMode; - UINT SavedInputCP; - UINT SavedOutputCP; -} LXSS_STD_HANDLES_INFO, *PLXSS_STD_HANDLES_INFO; - namespace wsl::windows::common { + +// RAII wrapper for console input configuration and restoration +class ConsoleInput +{ +public: + static std::unique_ptr Create(HANDLE Handle); + ~ConsoleInput(); + ConsoleInput(const ConsoleInput&) = delete; + ConsoleInput& operator=(const ConsoleInput&) = delete; + +private: + ConsoleInput(HANDLE Handle, DWORD SavedMode); + + HANDLE m_Handle = nullptr; + DWORD m_SavedMode = 0; + UINT m_SavedCodePage = 0; +}; + +// RAII wrapper for console output configuration and restoration +class ConsoleOutput +{ +public: + static std::unique_ptr Create(); + ~ConsoleOutput(); + ConsoleOutput(const ConsoleOutput&) = delete; + ConsoleOutput& operator=(const ConsoleOutput&) = delete; + +private: + ConsoleOutput(wil::unique_hfile ConsoleHandle, DWORD SavedMode); + + wil::unique_hfile m_ConsoleHandle; + DWORD m_SavedMode = 0; + UINT m_SavedCodePage = 0; +}; + class SvcCommIo { public: SvcCommIo(); - ~SvcCommIo(); + ~SvcCommIo() = default; PLXSS_STD_HANDLES GetStdHandles(); COORD GetWindowSize() const; private: - void RestoreConsoleMode() const; + LXSS_STD_HANDLES m_StdHandles{}; + HANDLE m_WindowSizeHandle = nullptr; // Cached console handle for GetWindowSize - LXSS_STD_HANDLES _stdHandles{}; - LXSS_STD_HANDLES_INFO _stdHandlesInfo{}; + // RAII members for automatic restoration + std::unique_ptr m_ConsoleInput; + std::unique_ptr m_ConsoleOutput; }; } // namespace wsl::windows::common