From 78b7720186a926f3344d014c5ae64ea97719fe61 Mon Sep 17 00:00:00 2001 From: Blue Date: Fri, 30 Jan 2026 18:11:17 -0800 Subject: [PATCH] Redesign Accept() logic to differenciate between cancellation and errors --- src/windows/common/hvsocket.cpp | 9 ++- src/windows/common/hvsocket.hpp | 4 +- src/windows/common/relay.cpp | 65 +++++++++++++++++++ src/windows/common/relay.hpp | 22 +++++++ src/windows/common/socket.cpp | 29 +++++---- src/windows/common/socket.hpp | 4 +- src/windows/service/exe/WslCoreVm.cpp | 29 ++++++--- .../wslaservice/exe/WSLAVirtualMachine.cpp | 15 +++-- src/windows/wslrelay/localhost.cpp | 6 +- src/windows/wslrelay/main.cpp | 5 +- 10 files changed, 153 insertions(+), 35 deletions(-) diff --git a/src/windows/common/hvsocket.cpp b/src/windows/common/hvsocket.cpp index a7f1ee7c..204de0c8 100644 --- a/src/windows/common/hvsocket.cpp +++ b/src/windows/common/hvsocket.cpp @@ -37,11 +37,14 @@ void InitializeWildcardSocketAddress(_Out_ PSOCKADDR_HV Address) } } // namespace -wil::unique_socket wsl::windows::common::hvsocket::Accept( - _In_ SOCKET ListenSocket, _In_ int Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location) +std::optional wsl::windows::common::hvsocket::CancellableAccept( + _In_ SOCKET ListenSocket, _In_ DWORD Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location) { wil::unique_socket Socket = Create(); - wsl::windows::common::socket::Accept(ListenSocket, Socket.get(), Timeout, ExitHandle, Location); + if (!socket::CancellableAccept(ListenSocket, Socket.get(), Timeout, ExitHandle, Location)) + { + return {}; + } return Socket; } diff --git a/src/windows/common/hvsocket.hpp b/src/windows/common/hvsocket.hpp index 7bbd618f..9aa67253 100644 --- a/src/windows/common/hvsocket.hpp +++ b/src/windows/common/hvsocket.hpp @@ -19,9 +19,9 @@ Abstract: namespace wsl::windows::common::hvsocket { -wil::unique_socket Accept( +std::optional CancellableAccept( _In_ SOCKET ListenSocket, - _In_ int Timeout, + _In_ DWORD Timeout, _In_opt_ HANDLE ExitHandle = nullptr, const std::source_location& Location = std::source_location::current()); diff --git a/src/windows/common/relay.cpp b/src/windows/common/relay.cpp index 5f992a90..3e45589e 100644 --- a/src/windows/common/relay.cpp +++ b/src/windows/common/relay.cpp @@ -28,6 +28,7 @@ using wsl::windows::common::relay::ReadHandle; using wsl::windows::common::relay::RelayHandle; using wsl::windows::common::relay::ScopedMultiRelay; using wsl::windows::common::relay::ScopedRelay; +using wsl::windows::common::relay::SingleAcceptHandle; using wsl::windows::common::relay::WriteHandle; namespace { @@ -1581,4 +1582,68 @@ void DockerIORelayHandle::OnRead(const gsl::span& Buffer) // If no handle is active, expect a header. ProcessNextHeader(); } +} + +SingleAcceptHandle::SingleAcceptHandle(HandleWrapper&& ListenSocket, HandleWrapper&& AcceptedSocket, std::function&& OnAccepted) : + ListenSocket(std::move(ListenSocket)), AcceptedSocket(std::move(AcceptedSocket)), OnAccepted(std::move(OnAccepted)) +{ + Overlapped.hEvent = Event.get(); +} + +SingleAcceptHandle::~SingleAcceptHandle() +{ + if (State == IOHandleStatus::Pending) + { + if (!CancelIoEx(ListenSocket.Get(), &Overlapped)) + { + DWORD bytesProcessed{}; + DWORD flagsReturned{}; + if (!WSAGetOverlappedResult((SOCKET)ListenSocket.Get(), &Overlapped, &bytesProcessed, TRUE, &flagsReturned)) + { + auto error = GetLastError(); + LOG_LAST_ERROR_IF(error != ERROR_CONNECTION_ABORTED && error != ERROR_OPERATION_ABORTED); + } + } + } +} + +void SingleAcceptHandle::Schedule() +{ + WI_ASSERT(State == IOHandleStatus::Standby); + + // Schedule the accept. + CHAR acceptBuffer[2 * sizeof(SOCKADDR_STORAGE)]{}; + DWORD bytesReturned{}; + + if (AcceptEx((SOCKET)ListenSocket.Get(), (SOCKET)AcceptedSocket.Get(), acceptBuffer, 0, sizeof(SOCKADDR_STORAGE), sizeof(SOCKADDR_STORAGE), &bytesReturned, &Overlapped)) + { + // Accept completed immediately. + State = IOHandleStatus::Completed; + OnAccepted(); + } + else + { + auto error = WSAGetLastError(); + THROW_HR_IF_MSG(HRESULT_FROM_WIN32(error), error != ERROR_IO_PENDING, "Handle: 0x%p", (void*)ListenSocket.Get()); + + State = IOHandleStatus::Pending; + } +} + +void SingleAcceptHandle::Collect() +{ + WI_ASSERT(State == IOHandleStatus::Pending); + + DWORD bytesReceived{}; + DWORD flagsReturned{}; + + THROW_IF_WIN32_BOOL_FALSE(WSAGetOverlappedResult((SOCKET)ListenSocket.Get(), &Overlapped, &bytesReceived, false, &flagsReturned)); + + State = IOHandleStatus::Completed; + OnAccepted(); +} + +HANDLE SingleAcceptHandle::GetHandle() const +{ + return Event.get(); } \ No newline at end of file diff --git a/src/windows/common/relay.hpp b/src/windows/common/relay.hpp index dc7aa3a0..ae4935be 100644 --- a/src/windows/common/relay.hpp +++ b/src/windows/common/relay.hpp @@ -466,6 +466,28 @@ private: size_t RemainingBytes = 0; }; +class SingleAcceptHandle : public OverlappedIOHandle +{ +public: + NON_COPYABLE(SingleAcceptHandle); + NON_MOVABLE(SingleAcceptHandle); + + SingleAcceptHandle(HandleWrapper&& ListenSocket, HandleWrapper&& AcceptedSocket, std::function&& OnAccepted); + ~SingleAcceptHandle(); + + void Schedule() override; + void Collect() override; + HANDLE GetHandle() const override; + +private: + HandleWrapper ListenSocket; + HandleWrapper AcceptedSocket; + std::function OnAccept; + wil::unique_event Event{wil::EventOptions::ManualReset}; + OVERLAPPED Overlapped{}; + std::function OnAccepted; +}; + class MultiHandleWait { public: diff --git a/src/windows/common/socket.cpp b/src/windows/common/socket.cpp index edc49aae..83e24c74 100644 --- a/src/windows/common/socket.cpp +++ b/src/windows/common/socket.cpp @@ -17,20 +17,25 @@ Abstract: #include "socket.hpp" #pragma hdrstop -void wsl::windows::common::socket::Accept( - _In_ SOCKET ListenSocket, _In_ SOCKET Socket, _In_ int Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location) +bool wsl::windows::common::socket::CancellableAccept( + _In_ SOCKET ListenSocket, _In_ SOCKET Socket, _In_ DWORD Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location) { - CHAR AcceptBuffer[2 * sizeof(SOCKADDR_STORAGE)]{}; - DWORD BytesReturned; - OVERLAPPED Overlapped{}; - const wil::unique_event OverlappedEvent(wil::EventOptions::ManualReset); - Overlapped.hEvent = OverlappedEvent.get(); - const BOOL Success = - AcceptEx(ListenSocket, Socket, AcceptBuffer, 0, sizeof(SOCKADDR_STORAGE), sizeof(SOCKADDR_STORAGE), &BytesReturned, &Overlapped); + relay::MultiHandleWait io; - if (!Success) + bool accepted = false; + + io.AddHandle(std::make_unique(ListenSocket, Socket, [&]() { accepted = true; }), relay::MultiHandleWait::CancelOnCompleted); + + if (ExitHandle != nullptr) { - GetResult(ListenSocket, Overlapped, Timeout, ExitHandle, Location); + io.AddHandle(std::make_unique(ExitHandle), relay::MultiHandleWait::CancelOnCompleted); + } + + io.Run(std::chrono::milliseconds(Timeout)); + + if (!accepted) + { + return false; // Accept was cancelled by the exit event. } // Set the accept context to mark the socket as connected. @@ -39,7 +44,7 @@ void wsl::windows::common::socket::Accept( "From: %hs", std::format("{}", Location).c_str()); - return; + return true; } std::pair wsl::windows::common::socket::GetResult( diff --git a/src/windows/common/socket.hpp b/src/windows/common/socket.hpp index a72e8ce5..f3f668ae 100644 --- a/src/windows/common/socket.hpp +++ b/src/windows/common/socket.hpp @@ -18,10 +18,10 @@ Abstract: namespace wsl::windows::common::socket { -void Accept( +bool CancellableAccept( _In_ SOCKET ListenSocket, _In_ SOCKET Socket, - _In_ int Timeout, + _In_ DWORD Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location = std::source_location::current()); diff --git a/src/windows/service/exe/WslCoreVm.cpp b/src/windows/service/exe/WslCoreVm.cpp index d904cde9..24d36d8e 100644 --- a/src/windows/service/exe/WslCoreVm.cpp +++ b/src/windows/service/exe/WslCoreVm.cpp @@ -813,14 +813,15 @@ WslCoreVm::~WslCoreVm() noexcept wil::unique_socket WslCoreVm::AcceptConnection(_In_ DWORD ReceiveTimeout, _In_ const std::source_location& Location) const { - auto socket = - wsl::windows::common::hvsocket::Accept(m_listenSocket.get(), m_vmConfig.KernelBootTimeout, m_terminatingEvent.get(), Location); + auto socket = hvsocket::CancellableAccept(m_listenSocket.get(), m_vmConfig.KernelBootTimeout, m_terminatingEvent.get(), Location); + THROW_HR_IF(E_ABORT, !socket.has_value()); + if (ReceiveTimeout != 0) { - THROW_LAST_ERROR_IF(setsockopt(socket.get(), SOL_SOCKET, SO_RCVTIMEO, (const char*)&ReceiveTimeout, sizeof(ReceiveTimeout)) == SOCKET_ERROR); + THROW_LAST_ERROR_IF(setsockopt(socket->get(), SOL_SOCKET, SO_RCVTIMEO, (const char*)&ReceiveTimeout, sizeof(ReceiveTimeout)) == SOCKET_ERROR); } - return socket; + return std::move(socket.value()); } _Requires_lock_held_(m_guestDeviceLock) @@ -1046,13 +1047,17 @@ void WslCoreVm::CollectCrashDumps(wil::unique_socket&& listenSocket) const { try { - auto socket = wsl::windows::common::hvsocket::Accept(listenSocket.get(), INFINITE, m_terminatingEvent.get()); + auto socket = hvsocket::CancellableAccept(listenSocket.get(), INFINITE, m_terminatingEvent.get()); + if (!socket.has_value()) + { + break; // VM is exiting. + } DWORD receiveTimeout = m_vmConfig.KernelBootTimeout; THROW_LAST_ERROR_IF( setsockopt(listenSocket.get(), SOL_SOCKET, SO_RCVTIMEO, (const char*)&receiveTimeout, sizeof(receiveTimeout)) == SOCKET_ERROR); - auto channel = wsl::shared::SocketChannel{std::move(socket), "crash_dump", m_terminatingEvent.get()}; + auto channel = wsl::shared::SocketChannel{std::move(socket.value()), "crash_dump", m_terminatingEvent.get()}; const auto& message = channel.ReceiveMessage(); const char* process = reinterpret_cast(&message.Buffer); @@ -2513,10 +2518,14 @@ try for (;;) { // Create a worker thread to handle each request. - wsl::shared::SocketChannel channel{ - wsl::windows::common::hvsocket::Accept(listenSocket.get(), INFINITE, m_terminatingEvent.get()), - "VirtioFs", - m_terminatingEvent.get()}; + + auto socket = hvsocket::CancellableAccept(listenSocket.get(), INFINITE, m_terminatingEvent.get()); + if (!socket.has_value()) + { + break; + } + + wsl::shared::SocketChannel channel{std::move(socket.value()), "VirtioFs", m_terminatingEvent.get()}; std::thread([this, channel = std::move(channel)]() mutable { try { diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp index 2e814fe0..2fe5d45c 100644 --- a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp @@ -375,8 +375,11 @@ void WSLAVirtualMachine::Start() // Create a socket listening for connections from mini_init. auto listenSocket = wsl::windows::common::hvsocket::Listen(runtimeId, LX_INIT_UTILITY_VM_INIT_PORT); - auto socket = wsl::windows::common::hvsocket::Accept(listenSocket.get(), m_settings.BootTimeoutMs, m_vmTerminatingEvent.get()); - m_initChannel = wsl::shared::SocketChannel{std::move(socket), "mini_init", m_vmTerminatingEvent.get()}; + auto socket = + wsl::windows::common::hvsocket::CancellableAccept(listenSocket.get(), m_settings.BootTimeoutMs, m_vmTerminatingEvent.get()); + THROW_HR_IF(E_ABORT, !socket.has_value()); + + m_initChannel = wsl::shared::SocketChannel{std::move(socket.value()), "mini_init", m_vmTerminatingEvent.get()}; // Create a thread to watch for exited processes. auto [__, ___, childChannel] = Fork(WSLA_FORK::Thread); @@ -1378,12 +1381,16 @@ void WSLAVirtualMachine::CollectCrashDumps(wil::unique_socket&& listenSocket) co { try { - auto socket = wsl::windows::common::hvsocket::Accept(listenSocket.get(), INFINITE, m_vmExitEvent.get()); + auto socket = common::hvsocket::CancellableAccept(listenSocket.get(), INFINITE, m_vmExitEvent.get()); + if (!socket.has_value()) + { + break; + } THROW_LAST_ERROR_IF( setsockopt(listenSocket.get(), SOL_SOCKET, SO_RCVTIMEO, (const char*)&RECEIVE_TIMEOUT, sizeof(RECEIVE_TIMEOUT)) == SOCKET_ERROR); - auto channel = wsl::shared::SocketChannel{std::move(socket), "crash_dump", m_vmExitEvent.get()}; + auto channel = wsl::shared::SocketChannel{std::move(socket.value()), "crash_dump", m_vmExitEvent.get()}; const auto& message = channel.ReceiveMessage(); const char* process = reinterpret_cast(&message.Buffer); diff --git a/src/windows/wslrelay/localhost.cpp b/src/windows/wslrelay/localhost.cpp index 6bae9923..6614d6f9 100644 --- a/src/windows/wslrelay/localhost.cpp +++ b/src/windows/wslrelay/localhost.cpp @@ -297,7 +297,11 @@ try wil::unique_socket InetSocket(WSASocket(AddressFamily, SOCK_STREAM, IPPROTO_TCP, nullptr, 0, WSA_FLAG_OVERLAPPED)); THROW_LAST_ERROR_IF(!InetSocket); - wsl::windows::common::socket::Accept(Arguments->ListenSocket.get(), InetSocket.get(), INFINITE, Arguments->ExitEvent.get()); + if (!wsl::windows::common::socket::CancellableAccept( + Arguments->ListenSocket.get(), InetSocket.get(), INFINITE, Arguments->ExitEvent.get())) + { + break; // Exit event was signaled, exit. + } // Establish a relay thread. diff --git a/src/windows/wslrelay/main.cpp b/src/windows/wslrelay/main.cpp index 2c01cc31..93a83e44 100644 --- a/src/windows/wslrelay/main.cpp +++ b/src/windows/wslrelay/main.cpp @@ -126,7 +126,10 @@ try const wil::unique_socket socket(WSASocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, nullptr, 0, WSA_FLAG_OVERLAPPED)); THROW_LAST_ERROR_IF(!socket); - wsl::windows::common::socket::Accept(listenSocket.get(), socket.get(), INFINITE, exitEvent.get()); + if (!wsl::windows::common::socket::CancellableAccept(listenSocket.get(), socket.get(), INFINITE, exitEvent.get())) + { + return 1; + } // Begin the relay. wsl::windows::common::relay::BidirectionalRelay(