mirror of
https://github.com/microsoft/WSL.git
synced 2026-02-04 02:06:49 -06:00
Redesign Accept() logic to differenciate between cancellation and errors
This commit is contained in:
parent
fc3dfa7c52
commit
78b7720186
@ -37,11 +37,14 @@ void InitializeWildcardSocketAddress(_Out_ PSOCKADDR_HV Address)
|
||||
}
|
||||
} // namespace
|
||||
|
||||
wil::unique_socket wsl::windows::common::hvsocket::Accept(
|
||||
_In_ SOCKET ListenSocket, _In_ int Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location)
|
||||
std::optional<wil::unique_socket> wsl::windows::common::hvsocket::CancellableAccept(
|
||||
_In_ SOCKET ListenSocket, _In_ DWORD Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location)
|
||||
{
|
||||
wil::unique_socket Socket = Create();
|
||||
wsl::windows::common::socket::Accept(ListenSocket, Socket.get(), Timeout, ExitHandle, Location);
|
||||
if (!socket::CancellableAccept(ListenSocket, Socket.get(), Timeout, ExitHandle, Location))
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
return Socket;
|
||||
}
|
||||
|
||||
@ -19,9 +19,9 @@ Abstract:
|
||||
|
||||
namespace wsl::windows::common::hvsocket {
|
||||
|
||||
wil::unique_socket Accept(
|
||||
std::optional<wil::unique_socket> CancellableAccept(
|
||||
_In_ SOCKET ListenSocket,
|
||||
_In_ int Timeout,
|
||||
_In_ DWORD Timeout,
|
||||
_In_opt_ HANDLE ExitHandle = nullptr,
|
||||
const std::source_location& Location = std::source_location::current());
|
||||
|
||||
|
||||
@ -28,6 +28,7 @@ using wsl::windows::common::relay::ReadHandle;
|
||||
using wsl::windows::common::relay::RelayHandle;
|
||||
using wsl::windows::common::relay::ScopedMultiRelay;
|
||||
using wsl::windows::common::relay::ScopedRelay;
|
||||
using wsl::windows::common::relay::SingleAcceptHandle;
|
||||
using wsl::windows::common::relay::WriteHandle;
|
||||
|
||||
namespace {
|
||||
@ -1581,4 +1582,68 @@ void DockerIORelayHandle::OnRead(const gsl::span<char>& Buffer)
|
||||
// If no handle is active, expect a header.
|
||||
ProcessNextHeader();
|
||||
}
|
||||
}
|
||||
|
||||
SingleAcceptHandle::SingleAcceptHandle(HandleWrapper&& ListenSocket, HandleWrapper&& AcceptedSocket, std::function<void()>&& OnAccepted) :
|
||||
ListenSocket(std::move(ListenSocket)), AcceptedSocket(std::move(AcceptedSocket)), OnAccepted(std::move(OnAccepted))
|
||||
{
|
||||
Overlapped.hEvent = Event.get();
|
||||
}
|
||||
|
||||
SingleAcceptHandle::~SingleAcceptHandle()
|
||||
{
|
||||
if (State == IOHandleStatus::Pending)
|
||||
{
|
||||
if (!CancelIoEx(ListenSocket.Get(), &Overlapped))
|
||||
{
|
||||
DWORD bytesProcessed{};
|
||||
DWORD flagsReturned{};
|
||||
if (!WSAGetOverlappedResult((SOCKET)ListenSocket.Get(), &Overlapped, &bytesProcessed, TRUE, &flagsReturned))
|
||||
{
|
||||
auto error = GetLastError();
|
||||
LOG_LAST_ERROR_IF(error != ERROR_CONNECTION_ABORTED && error != ERROR_OPERATION_ABORTED);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SingleAcceptHandle::Schedule()
|
||||
{
|
||||
WI_ASSERT(State == IOHandleStatus::Standby);
|
||||
|
||||
// Schedule the accept.
|
||||
CHAR acceptBuffer[2 * sizeof(SOCKADDR_STORAGE)]{};
|
||||
DWORD bytesReturned{};
|
||||
|
||||
if (AcceptEx((SOCKET)ListenSocket.Get(), (SOCKET)AcceptedSocket.Get(), acceptBuffer, 0, sizeof(SOCKADDR_STORAGE), sizeof(SOCKADDR_STORAGE), &bytesReturned, &Overlapped))
|
||||
{
|
||||
// Accept completed immediately.
|
||||
State = IOHandleStatus::Completed;
|
||||
OnAccepted();
|
||||
}
|
||||
else
|
||||
{
|
||||
auto error = WSAGetLastError();
|
||||
THROW_HR_IF_MSG(HRESULT_FROM_WIN32(error), error != ERROR_IO_PENDING, "Handle: 0x%p", (void*)ListenSocket.Get());
|
||||
|
||||
State = IOHandleStatus::Pending;
|
||||
}
|
||||
}
|
||||
|
||||
void SingleAcceptHandle::Collect()
|
||||
{
|
||||
WI_ASSERT(State == IOHandleStatus::Pending);
|
||||
|
||||
DWORD bytesReceived{};
|
||||
DWORD flagsReturned{};
|
||||
|
||||
THROW_IF_WIN32_BOOL_FALSE(WSAGetOverlappedResult((SOCKET)ListenSocket.Get(), &Overlapped, &bytesReceived, false, &flagsReturned));
|
||||
|
||||
State = IOHandleStatus::Completed;
|
||||
OnAccepted();
|
||||
}
|
||||
|
||||
HANDLE SingleAcceptHandle::GetHandle() const
|
||||
{
|
||||
return Event.get();
|
||||
}
|
||||
@ -466,6 +466,28 @@ private:
|
||||
size_t RemainingBytes = 0;
|
||||
};
|
||||
|
||||
class SingleAcceptHandle : public OverlappedIOHandle
|
||||
{
|
||||
public:
|
||||
NON_COPYABLE(SingleAcceptHandle);
|
||||
NON_MOVABLE(SingleAcceptHandle);
|
||||
|
||||
SingleAcceptHandle(HandleWrapper&& ListenSocket, HandleWrapper&& AcceptedSocket, std::function<void()>&& OnAccepted);
|
||||
~SingleAcceptHandle();
|
||||
|
||||
void Schedule() override;
|
||||
void Collect() override;
|
||||
HANDLE GetHandle() const override;
|
||||
|
||||
private:
|
||||
HandleWrapper ListenSocket;
|
||||
HandleWrapper AcceptedSocket;
|
||||
std::function<void(wil::unique_socket AcceptedSocket)> OnAccept;
|
||||
wil::unique_event Event{wil::EventOptions::ManualReset};
|
||||
OVERLAPPED Overlapped{};
|
||||
std::function<void()> OnAccepted;
|
||||
};
|
||||
|
||||
class MultiHandleWait
|
||||
{
|
||||
public:
|
||||
|
||||
@ -17,20 +17,25 @@ Abstract:
|
||||
#include "socket.hpp"
|
||||
#pragma hdrstop
|
||||
|
||||
void wsl::windows::common::socket::Accept(
|
||||
_In_ SOCKET ListenSocket, _In_ SOCKET Socket, _In_ int Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location)
|
||||
bool wsl::windows::common::socket::CancellableAccept(
|
||||
_In_ SOCKET ListenSocket, _In_ SOCKET Socket, _In_ DWORD Timeout, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location)
|
||||
{
|
||||
CHAR AcceptBuffer[2 * sizeof(SOCKADDR_STORAGE)]{};
|
||||
DWORD BytesReturned;
|
||||
OVERLAPPED Overlapped{};
|
||||
const wil::unique_event OverlappedEvent(wil::EventOptions::ManualReset);
|
||||
Overlapped.hEvent = OverlappedEvent.get();
|
||||
const BOOL Success =
|
||||
AcceptEx(ListenSocket, Socket, AcceptBuffer, 0, sizeof(SOCKADDR_STORAGE), sizeof(SOCKADDR_STORAGE), &BytesReturned, &Overlapped);
|
||||
relay::MultiHandleWait io;
|
||||
|
||||
if (!Success)
|
||||
bool accepted = false;
|
||||
|
||||
io.AddHandle(std::make_unique<relay::SingleAcceptHandle>(ListenSocket, Socket, [&]() { accepted = true; }), relay::MultiHandleWait::CancelOnCompleted);
|
||||
|
||||
if (ExitHandle != nullptr)
|
||||
{
|
||||
GetResult(ListenSocket, Overlapped, Timeout, ExitHandle, Location);
|
||||
io.AddHandle(std::make_unique<relay::EventHandle>(ExitHandle), relay::MultiHandleWait::CancelOnCompleted);
|
||||
}
|
||||
|
||||
io.Run(std::chrono::milliseconds(Timeout));
|
||||
|
||||
if (!accepted)
|
||||
{
|
||||
return false; // Accept was cancelled by the exit event.
|
||||
}
|
||||
|
||||
// Set the accept context to mark the socket as connected.
|
||||
@ -39,7 +44,7 @@ void wsl::windows::common::socket::Accept(
|
||||
"From: %hs",
|
||||
std::format("{}", Location).c_str());
|
||||
|
||||
return;
|
||||
return true;
|
||||
}
|
||||
|
||||
std::pair<DWORD, DWORD> wsl::windows::common::socket::GetResult(
|
||||
|
||||
@ -18,10 +18,10 @@ Abstract:
|
||||
|
||||
namespace wsl::windows::common::socket {
|
||||
|
||||
void Accept(
|
||||
bool CancellableAccept(
|
||||
_In_ SOCKET ListenSocket,
|
||||
_In_ SOCKET Socket,
|
||||
_In_ int Timeout,
|
||||
_In_ DWORD Timeout,
|
||||
_In_opt_ HANDLE ExitHandle,
|
||||
_In_ const std::source_location& Location = std::source_location::current());
|
||||
|
||||
|
||||
@ -813,14 +813,15 @@ WslCoreVm::~WslCoreVm() noexcept
|
||||
|
||||
wil::unique_socket WslCoreVm::AcceptConnection(_In_ DWORD ReceiveTimeout, _In_ const std::source_location& Location) const
|
||||
{
|
||||
auto socket =
|
||||
wsl::windows::common::hvsocket::Accept(m_listenSocket.get(), m_vmConfig.KernelBootTimeout, m_terminatingEvent.get(), Location);
|
||||
auto socket = hvsocket::CancellableAccept(m_listenSocket.get(), m_vmConfig.KernelBootTimeout, m_terminatingEvent.get(), Location);
|
||||
THROW_HR_IF(E_ABORT, !socket.has_value());
|
||||
|
||||
if (ReceiveTimeout != 0)
|
||||
{
|
||||
THROW_LAST_ERROR_IF(setsockopt(socket.get(), SOL_SOCKET, SO_RCVTIMEO, (const char*)&ReceiveTimeout, sizeof(ReceiveTimeout)) == SOCKET_ERROR);
|
||||
THROW_LAST_ERROR_IF(setsockopt(socket->get(), SOL_SOCKET, SO_RCVTIMEO, (const char*)&ReceiveTimeout, sizeof(ReceiveTimeout)) == SOCKET_ERROR);
|
||||
}
|
||||
|
||||
return socket;
|
||||
return std::move(socket.value());
|
||||
}
|
||||
|
||||
_Requires_lock_held_(m_guestDeviceLock)
|
||||
@ -1046,13 +1047,17 @@ void WslCoreVm::CollectCrashDumps(wil::unique_socket&& listenSocket) const
|
||||
{
|
||||
try
|
||||
{
|
||||
auto socket = wsl::windows::common::hvsocket::Accept(listenSocket.get(), INFINITE, m_terminatingEvent.get());
|
||||
auto socket = hvsocket::CancellableAccept(listenSocket.get(), INFINITE, m_terminatingEvent.get());
|
||||
if (!socket.has_value())
|
||||
{
|
||||
break; // VM is exiting.
|
||||
}
|
||||
|
||||
DWORD receiveTimeout = m_vmConfig.KernelBootTimeout;
|
||||
THROW_LAST_ERROR_IF(
|
||||
setsockopt(listenSocket.get(), SOL_SOCKET, SO_RCVTIMEO, (const char*)&receiveTimeout, sizeof(receiveTimeout)) == SOCKET_ERROR);
|
||||
|
||||
auto channel = wsl::shared::SocketChannel{std::move(socket), "crash_dump", m_terminatingEvent.get()};
|
||||
auto channel = wsl::shared::SocketChannel{std::move(socket.value()), "crash_dump", m_terminatingEvent.get()};
|
||||
|
||||
const auto& message = channel.ReceiveMessage<LX_PROCESS_CRASH>();
|
||||
const char* process = reinterpret_cast<const char*>(&message.Buffer);
|
||||
@ -2513,10 +2518,14 @@ try
|
||||
for (;;)
|
||||
{
|
||||
// Create a worker thread to handle each request.
|
||||
wsl::shared::SocketChannel channel{
|
||||
wsl::windows::common::hvsocket::Accept(listenSocket.get(), INFINITE, m_terminatingEvent.get()),
|
||||
"VirtioFs",
|
||||
m_terminatingEvent.get()};
|
||||
|
||||
auto socket = hvsocket::CancellableAccept(listenSocket.get(), INFINITE, m_terminatingEvent.get());
|
||||
if (!socket.has_value())
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
wsl::shared::SocketChannel channel{std::move(socket.value()), "VirtioFs", m_terminatingEvent.get()};
|
||||
std::thread([this, channel = std::move(channel)]() mutable {
|
||||
try
|
||||
{
|
||||
|
||||
@ -375,8 +375,11 @@ void WSLAVirtualMachine::Start()
|
||||
|
||||
// Create a socket listening for connections from mini_init.
|
||||
auto listenSocket = wsl::windows::common::hvsocket::Listen(runtimeId, LX_INIT_UTILITY_VM_INIT_PORT);
|
||||
auto socket = wsl::windows::common::hvsocket::Accept(listenSocket.get(), m_settings.BootTimeoutMs, m_vmTerminatingEvent.get());
|
||||
m_initChannel = wsl::shared::SocketChannel{std::move(socket), "mini_init", m_vmTerminatingEvent.get()};
|
||||
auto socket =
|
||||
wsl::windows::common::hvsocket::CancellableAccept(listenSocket.get(), m_settings.BootTimeoutMs, m_vmTerminatingEvent.get());
|
||||
THROW_HR_IF(E_ABORT, !socket.has_value());
|
||||
|
||||
m_initChannel = wsl::shared::SocketChannel{std::move(socket.value()), "mini_init", m_vmTerminatingEvent.get()};
|
||||
|
||||
// Create a thread to watch for exited processes.
|
||||
auto [__, ___, childChannel] = Fork(WSLA_FORK::Thread);
|
||||
@ -1378,12 +1381,16 @@ void WSLAVirtualMachine::CollectCrashDumps(wil::unique_socket&& listenSocket) co
|
||||
{
|
||||
try
|
||||
{
|
||||
auto socket = wsl::windows::common::hvsocket::Accept(listenSocket.get(), INFINITE, m_vmExitEvent.get());
|
||||
auto socket = common::hvsocket::CancellableAccept(listenSocket.get(), INFINITE, m_vmExitEvent.get());
|
||||
if (!socket.has_value())
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
THROW_LAST_ERROR_IF(
|
||||
setsockopt(listenSocket.get(), SOL_SOCKET, SO_RCVTIMEO, (const char*)&RECEIVE_TIMEOUT, sizeof(RECEIVE_TIMEOUT)) == SOCKET_ERROR);
|
||||
|
||||
auto channel = wsl::shared::SocketChannel{std::move(socket), "crash_dump", m_vmExitEvent.get()};
|
||||
auto channel = wsl::shared::SocketChannel{std::move(socket.value()), "crash_dump", m_vmExitEvent.get()};
|
||||
|
||||
const auto& message = channel.ReceiveMessage<LX_PROCESS_CRASH>();
|
||||
const char* process = reinterpret_cast<const char*>(&message.Buffer);
|
||||
|
||||
@ -297,7 +297,11 @@ try
|
||||
wil::unique_socket InetSocket(WSASocket(AddressFamily, SOCK_STREAM, IPPROTO_TCP, nullptr, 0, WSA_FLAG_OVERLAPPED));
|
||||
THROW_LAST_ERROR_IF(!InetSocket);
|
||||
|
||||
wsl::windows::common::socket::Accept(Arguments->ListenSocket.get(), InetSocket.get(), INFINITE, Arguments->ExitEvent.get());
|
||||
if (!wsl::windows::common::socket::CancellableAccept(
|
||||
Arguments->ListenSocket.get(), InetSocket.get(), INFINITE, Arguments->ExitEvent.get()))
|
||||
{
|
||||
break; // Exit event was signaled, exit.
|
||||
}
|
||||
|
||||
// Establish a relay thread.
|
||||
|
||||
|
||||
@ -126,7 +126,10 @@ try
|
||||
const wil::unique_socket socket(WSASocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, nullptr, 0, WSA_FLAG_OVERLAPPED));
|
||||
THROW_LAST_ERROR_IF(!socket);
|
||||
|
||||
wsl::windows::common::socket::Accept(listenSocket.get(), socket.get(), INFINITE, exitEvent.get());
|
||||
if (!wsl::windows::common::socket::CancellableAccept(listenSocket.get(), socket.get(), INFINITE, exitEvent.get()))
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Begin the relay.
|
||||
wsl::windows::common::relay::BidirectionalRelay(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user