Add logic to handle partial hvsocket writes and additional logging (#13602)

This commit is contained in:
Blue 2025-10-16 10:53:49 -07:00 committed by GitHub
parent 0974c302c0
commit b3a7b7d395
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 64 additions and 15 deletions

View File

@ -54,9 +54,6 @@ Abstract:
#include <cstdarg>
#include "lxinitshared.h"
#include "defs.h"
#include "retryshared.h"
#include "socketshared.h"
#include "stringshared.h"
#define ETC_FOLDER "/etc/"
#define NAME_ENV "NAME"
@ -151,6 +148,11 @@ auto LogImpl(int fd, const std::format_string<Args...>& format, Args&&... args)
#define FATAL_ERROR(str, ...) FATAL_ERROR_EX(1, str, ##__VA_ARGS__)
// Some of these files need the LOG_* macros.
#include "retryshared.h"
#include "socketshared.h"
#include "stringshared.h"
int InitializeLogging(bool SetStderr, wil::LogFunction* ExceptionCallback = nullptr) noexcept;
void LogException(const char* Message, const char* Description) noexcept;

View File

@ -112,12 +112,13 @@ public:
#ifdef WIN32
auto sentBytes = wsl::windows::common::socket::Send(m_socket.get(), span, m_exitEvent);
WSL_LOG(
"SentMessage",
TraceLoggingValue(m_name, "Name"),
TraceLoggingValue(reinterpret_cast<const TMessage*>(span.data())->PrettyPrint().c_str(), "Content"));
wsl::windows::common::socket::Send(m_socket.get(), span, m_exitEvent);
TraceLoggingValue(reinterpret_cast<const TMessage*>(span.data())->PrettyPrint().c_str(), "Content"),
TraceLoggingValue(sentBytes, "SentBytes"));
#else
@ -130,7 +131,7 @@ public:
{
LOG_ERROR("Failed to write message {}. Channel: {}", header->MessageType, m_name);
THROW_LAST_ERROR();
};
}
#endif
}

View File

@ -80,6 +80,27 @@ try
#endif
if (BytesRead <= 0)
{
const auto* Header = reinterpret_cast<const MESSAGE_HEADER*>(Buffer.data());
#if defined(_MSC_VER)
LOG_HR_MSG(
E_UNEXPECTED,
"Socket closed while reading message. Size: %u, type: %i, sequence: %u",
Header->MessageSize,
Header->MessageType,
Header->SequenceNumber);
#elif defined(__GNUC__)
LOG_ERROR(
"Socket closed while reading message. Size: {}, type: {}, sequence: {}",
Header->MessageSize,
Header->MessageType,
Header->SequenceNumber);
#endif
return {};
}

View File

@ -138,18 +138,43 @@ std::vector<gsl::byte> wsl::windows::common::socket::Receive(
int wsl::windows::common::socket::Send(
_In_ SOCKET Socket, _In_ gsl::span<const gsl::byte> Buffer, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location)
{
OVERLAPPED Overlapped{};
const wil::unique_event OverlappedEvent(wil::EventOptions::ManualReset);
WSABUF VectorBuffer = {gsl::narrow_cast<ULONG>(Buffer.size()), const_cast<CHAR*>(reinterpret_cast<const CHAR*>(Buffer.data()))};
OVERLAPPED Overlapped{};
Overlapped.hEvent = OverlappedEvent.get();
DWORD BytesWritten{};
if (WSASend(Socket, &VectorBuffer, 1, &BytesWritten, 0, &Overlapped, nullptr) != 0)
DWORD Offset = 0;
while (Offset < Buffer.size())
{
DWORD Flags;
std::tie(BytesWritten, Flags) = GetResult(Socket, Overlapped, INFINITE, ExitHandle, Location);
OverlappedEvent.ResetEvent();
WSABUF VectorBuffer = {
gsl::narrow_cast<ULONG>(Buffer.size() - Offset), const_cast<CHAR*>(reinterpret_cast<const CHAR*>(Buffer.data() + Offset))};
DWORD BytesWritten{};
if (WSASend(Socket, &VectorBuffer, 1, &BytesWritten, 0, &Overlapped, nullptr) != 0)
{
// If WSASend returns non-zero, expect WSA_IO_PENDING.
if (auto error = WSAGetLastError(); error != WSA_IO_PENDING)
{
THROW_WIN32_MSG(error, "WSASend failed. From: %hs", std::format("{}", Location).c_str());
}
DWORD Flags;
std::tie(BytesWritten, Flags) = GetResult(Socket, Overlapped, INFINITE, ExitHandle, Location);
if (BytesWritten == 0)
{
THROW_WIN32_MSG(ERROR_CONNECTION_ABORTED, "Socket closed during WSASend(). From: %hs", std::format("{}", Location).c_str());
}
}
Offset += BytesWritten;
if (Offset < Buffer.size())
{
WSL_LOG("PartialSocketWrite", TraceLoggingValue(Buffer.size(), "MessagSize"), TraceLoggingValue(Offset, "Offset"));
}
}
WI_ASSERT(BytesWritten == gsl::narrow_cast<DWORD>(Buffer.size()));
WI_ASSERT(Offset == gsl::narrow_cast<DWORD>(Buffer.size()));
return BytesWritten;
return Offset;
}