Redesign Accept() logic to differenciate between cancellation and errors

This commit is contained in:
Blue 2026-01-30 18:11:17 -08:00
parent fc3dfa7c52
commit 78b7720186
10 changed files with 153 additions and 35 deletions

View File

@ -37,11 +37,14 @@ void InitializeWildcardSocketAddress(_Out_ PSOCKADDR_HV Address)
}
} // namespace
wil::unique_socket wsl::windows::common::hvsocket::Accept(
_In_ SOCKET ListenSocket, _In_ int Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location)
std::optional<wil::unique_socket> wsl::windows::common::hvsocket::CancellableAccept(
_In_ SOCKET ListenSocket, _In_ DWORD Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location)
{
wil::unique_socket Socket = Create();
wsl::windows::common::socket::Accept(ListenSocket, Socket.get(), Timeout, ExitHandle, Location);
if (!socket::CancellableAccept(ListenSocket, Socket.get(), Timeout, ExitHandle, Location))
{
return {};
}
return Socket;
}

View File

@ -19,9 +19,9 @@ Abstract:
namespace wsl::windows::common::hvsocket {
wil::unique_socket Accept(
std::optional<wil::unique_socket> CancellableAccept(
_In_ SOCKET ListenSocket,
_In_ int Timeout,
_In_ DWORD Timeout,
_In_opt_ HANDLE ExitHandle = nullptr,
const std::source_location& Location = std::source_location::current());

View File

@ -28,6 +28,7 @@ using wsl::windows::common::relay::ReadHandle;
using wsl::windows::common::relay::RelayHandle;
using wsl::windows::common::relay::ScopedMultiRelay;
using wsl::windows::common::relay::ScopedRelay;
using wsl::windows::common::relay::SingleAcceptHandle;
using wsl::windows::common::relay::WriteHandle;
namespace {
@ -1581,4 +1582,68 @@ void DockerIORelayHandle::OnRead(const gsl::span<char>& Buffer)
// If no handle is active, expect a header.
ProcessNextHeader();
}
}
SingleAcceptHandle::SingleAcceptHandle(HandleWrapper&& ListenSocket, HandleWrapper&& AcceptedSocket, std::function<void()>&& OnAccepted) :
ListenSocket(std::move(ListenSocket)), AcceptedSocket(std::move(AcceptedSocket)), OnAccepted(std::move(OnAccepted))
{
Overlapped.hEvent = Event.get();
}
SingleAcceptHandle::~SingleAcceptHandle()
{
if (State == IOHandleStatus::Pending)
{
if (!CancelIoEx(ListenSocket.Get(), &Overlapped))
{
DWORD bytesProcessed{};
DWORD flagsReturned{};
if (!WSAGetOverlappedResult((SOCKET)ListenSocket.Get(), &Overlapped, &bytesProcessed, TRUE, &flagsReturned))
{
auto error = GetLastError();
LOG_LAST_ERROR_IF(error != ERROR_CONNECTION_ABORTED && error != ERROR_OPERATION_ABORTED);
}
}
}
}
void SingleAcceptHandle::Schedule()
{
WI_ASSERT(State == IOHandleStatus::Standby);
// Schedule the accept.
CHAR acceptBuffer[2 * sizeof(SOCKADDR_STORAGE)]{};
DWORD bytesReturned{};
if (AcceptEx((SOCKET)ListenSocket.Get(), (SOCKET)AcceptedSocket.Get(), acceptBuffer, 0, sizeof(SOCKADDR_STORAGE), sizeof(SOCKADDR_STORAGE), &bytesReturned, &Overlapped))
{
// Accept completed immediately.
State = IOHandleStatus::Completed;
OnAccepted();
}
else
{
auto error = WSAGetLastError();
THROW_HR_IF_MSG(HRESULT_FROM_WIN32(error), error != ERROR_IO_PENDING, "Handle: 0x%p", (void*)ListenSocket.Get());
State = IOHandleStatus::Pending;
}
}
void SingleAcceptHandle::Collect()
{
WI_ASSERT(State == IOHandleStatus::Pending);
DWORD bytesReceived{};
DWORD flagsReturned{};
THROW_IF_WIN32_BOOL_FALSE(WSAGetOverlappedResult((SOCKET)ListenSocket.Get(), &Overlapped, &bytesReceived, false, &flagsReturned));
State = IOHandleStatus::Completed;
OnAccepted();
}
HANDLE SingleAcceptHandle::GetHandle() const
{
return Event.get();
}

View File

@ -466,6 +466,28 @@ private:
size_t RemainingBytes = 0;
};
class SingleAcceptHandle : public OverlappedIOHandle
{
public:
NON_COPYABLE(SingleAcceptHandle);
NON_MOVABLE(SingleAcceptHandle);
SingleAcceptHandle(HandleWrapper&& ListenSocket, HandleWrapper&& AcceptedSocket, std::function<void()>&& OnAccepted);
~SingleAcceptHandle();
void Schedule() override;
void Collect() override;
HANDLE GetHandle() const override;
private:
HandleWrapper ListenSocket;
HandleWrapper AcceptedSocket;
std::function<void(wil::unique_socket AcceptedSocket)> OnAccept;
wil::unique_event Event{wil::EventOptions::ManualReset};
OVERLAPPED Overlapped{};
std::function<void()> OnAccepted;
};
class MultiHandleWait
{
public:

View File

@ -17,20 +17,25 @@ Abstract:
#include "socket.hpp"
#pragma hdrstop
void wsl::windows::common::socket::Accept(
_In_ SOCKET ListenSocket, _In_ SOCKET Socket, _In_ int Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location)
bool wsl::windows::common::socket::CancellableAccept(
_In_ SOCKET ListenSocket, _In_ SOCKET Socket, _In_ DWORD Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location)
{
CHAR AcceptBuffer[2 * sizeof(SOCKADDR_STORAGE)]{};
DWORD BytesReturned;
OVERLAPPED Overlapped{};
const wil::unique_event OverlappedEvent(wil::EventOptions::ManualReset);
Overlapped.hEvent = OverlappedEvent.get();
const BOOL Success =
AcceptEx(ListenSocket, Socket, AcceptBuffer, 0, sizeof(SOCKADDR_STORAGE), sizeof(SOCKADDR_STORAGE), &BytesReturned, &Overlapped);
relay::MultiHandleWait io;
if (!Success)
bool accepted = false;
io.AddHandle(std::make_unique<relay::SingleAcceptHandle>(ListenSocket, Socket, [&]() { accepted = true; }), relay::MultiHandleWait::CancelOnCompleted);
if (ExitHandle != nullptr)
{
GetResult(ListenSocket, Overlapped, Timeout, ExitHandle, Location);
io.AddHandle(std::make_unique<relay::EventHandle>(ExitHandle), relay::MultiHandleWait::CancelOnCompleted);
}
io.Run(std::chrono::milliseconds(Timeout));
if (!accepted)
{
return false; // Accept was cancelled by the exit event.
}
// Set the accept context to mark the socket as connected.
@ -39,7 +44,7 @@ void wsl::windows::common::socket::Accept(
"From: %hs",
std::format("{}", Location).c_str());
return;
return true;
}
std::pair<DWORD, DWORD> wsl::windows::common::socket::GetResult(

View File

@ -18,10 +18,10 @@ Abstract:
namespace wsl::windows::common::socket {
void Accept(
bool CancellableAccept(
_In_ SOCKET ListenSocket,
_In_ SOCKET Socket,
_In_ int Timeout,
_In_ DWORD Timeout,
_In_opt_ HANDLE ExitHandle,
_In_ const std::source_location& Location = std::source_location::current());

View File

@ -813,14 +813,15 @@ WslCoreVm::~WslCoreVm() noexcept
wil::unique_socket WslCoreVm::AcceptConnection(_In_ DWORD ReceiveTimeout, _In_ const std::source_location& Location) const
{
auto socket =
wsl::windows::common::hvsocket::Accept(m_listenSocket.get(), m_vmConfig.KernelBootTimeout, m_terminatingEvent.get(), Location);
auto socket = hvsocket::CancellableAccept(m_listenSocket.get(), m_vmConfig.KernelBootTimeout, m_terminatingEvent.get(), Location);
THROW_HR_IF(E_ABORT, !socket.has_value());
if (ReceiveTimeout != 0)
{
THROW_LAST_ERROR_IF(setsockopt(socket.get(), SOL_SOCKET, SO_RCVTIMEO, (const char*)&ReceiveTimeout, sizeof(ReceiveTimeout)) == SOCKET_ERROR);
THROW_LAST_ERROR_IF(setsockopt(socket->get(), SOL_SOCKET, SO_RCVTIMEO, (const char*)&ReceiveTimeout, sizeof(ReceiveTimeout)) == SOCKET_ERROR);
}
return socket;
return std::move(socket.value());
}
_Requires_lock_held_(m_guestDeviceLock)
@ -1046,13 +1047,17 @@ void WslCoreVm::CollectCrashDumps(wil::unique_socket&& listenSocket) const
{
try
{
auto socket = wsl::windows::common::hvsocket::Accept(listenSocket.get(), INFINITE, m_terminatingEvent.get());
auto socket = hvsocket::CancellableAccept(listenSocket.get(), INFINITE, m_terminatingEvent.get());
if (!socket.has_value())
{
break; // VM is exiting.
}
DWORD receiveTimeout = m_vmConfig.KernelBootTimeout;
THROW_LAST_ERROR_IF(
setsockopt(listenSocket.get(), SOL_SOCKET, SO_RCVTIMEO, (const char*)&receiveTimeout, sizeof(receiveTimeout)) == SOCKET_ERROR);
auto channel = wsl::shared::SocketChannel{std::move(socket), "crash_dump", m_terminatingEvent.get()};
auto channel = wsl::shared::SocketChannel{std::move(socket.value()), "crash_dump", m_terminatingEvent.get()};
const auto& message = channel.ReceiveMessage<LX_PROCESS_CRASH>();
const char* process = reinterpret_cast<const char*>(&message.Buffer);
@ -2513,10 +2518,14 @@ try
for (;;)
{
// Create a worker thread to handle each request.
wsl::shared::SocketChannel channel{
wsl::windows::common::hvsocket::Accept(listenSocket.get(), INFINITE, m_terminatingEvent.get()),
"VirtioFs",
m_terminatingEvent.get()};
auto socket = hvsocket::CancellableAccept(listenSocket.get(), INFINITE, m_terminatingEvent.get());
if (!socket.has_value())
{
break;
}
wsl::shared::SocketChannel channel{std::move(socket.value()), "VirtioFs", m_terminatingEvent.get()};
std::thread([this, channel = std::move(channel)]() mutable {
try
{

View File

@ -375,8 +375,11 @@ void WSLAVirtualMachine::Start()
// Create a socket listening for connections from mini_init.
auto listenSocket = wsl::windows::common::hvsocket::Listen(runtimeId, LX_INIT_UTILITY_VM_INIT_PORT);
auto socket = wsl::windows::common::hvsocket::Accept(listenSocket.get(), m_settings.BootTimeoutMs, m_vmTerminatingEvent.get());
m_initChannel = wsl::shared::SocketChannel{std::move(socket), "mini_init", m_vmTerminatingEvent.get()};
auto socket =
wsl::windows::common::hvsocket::CancellableAccept(listenSocket.get(), m_settings.BootTimeoutMs, m_vmTerminatingEvent.get());
THROW_HR_IF(E_ABORT, !socket.has_value());
m_initChannel = wsl::shared::SocketChannel{std::move(socket.value()), "mini_init", m_vmTerminatingEvent.get()};
// Create a thread to watch for exited processes.
auto [__, ___, childChannel] = Fork(WSLA_FORK::Thread);
@ -1378,12 +1381,16 @@ void WSLAVirtualMachine::CollectCrashDumps(wil::unique_socket&& listenSocket) co
{
try
{
auto socket = wsl::windows::common::hvsocket::Accept(listenSocket.get(), INFINITE, m_vmExitEvent.get());
auto socket = common::hvsocket::CancellableAccept(listenSocket.get(), INFINITE, m_vmExitEvent.get());
if (!socket.has_value())
{
break;
}
THROW_LAST_ERROR_IF(
setsockopt(listenSocket.get(), SOL_SOCKET, SO_RCVTIMEO, (const char*)&RECEIVE_TIMEOUT, sizeof(RECEIVE_TIMEOUT)) == SOCKET_ERROR);
auto channel = wsl::shared::SocketChannel{std::move(socket), "crash_dump", m_vmExitEvent.get()};
auto channel = wsl::shared::SocketChannel{std::move(socket.value()), "crash_dump", m_vmExitEvent.get()};
const auto& message = channel.ReceiveMessage<LX_PROCESS_CRASH>();
const char* process = reinterpret_cast<const char*>(&message.Buffer);

View File

@ -297,7 +297,11 @@ try
wil::unique_socket InetSocket(WSASocket(AddressFamily, SOCK_STREAM, IPPROTO_TCP, nullptr, 0, WSA_FLAG_OVERLAPPED));
THROW_LAST_ERROR_IF(!InetSocket);
wsl::windows::common::socket::Accept(Arguments->ListenSocket.get(), InetSocket.get(), INFINITE, Arguments->ExitEvent.get());
if (!wsl::windows::common::socket::CancellableAccept(
Arguments->ListenSocket.get(), InetSocket.get(), INFINITE, Arguments->ExitEvent.get()))
{
break; // Exit event was signaled, exit.
}
// Establish a relay thread.

View File

@ -126,7 +126,10 @@ try
const wil::unique_socket socket(WSASocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, nullptr, 0, WSA_FLAG_OVERLAPPED));
THROW_LAST_ERROR_IF(!socket);
wsl::windows::common::socket::Accept(listenSocket.get(), socket.get(), INFINITE, exitEvent.get());
if (!wsl::windows::common::socket::CancellableAccept(listenSocket.get(), socket.get(), INFINITE, exitEvent.get()))
{
return 1;
}
// Begin the relay.
wsl::windows::common::relay::BidirectionalRelay(