diff --git a/src/linux/init/common.h b/src/linux/init/common.h index eca78c5..945d0ea 100644 --- a/src/linux/init/common.h +++ b/src/linux/init/common.h @@ -54,9 +54,6 @@ Abstract: #include #include "lxinitshared.h" #include "defs.h" -#include "retryshared.h" -#include "socketshared.h" -#include "stringshared.h" #define ETC_FOLDER "/etc/" #define NAME_ENV "NAME" @@ -151,6 +148,11 @@ auto LogImpl(int fd, const std::format_string& format, Args&&... args) #define FATAL_ERROR(str, ...) FATAL_ERROR_EX(1, str, ##__VA_ARGS__) +// Some of these files need the LOG_* macros. +#include "retryshared.h" +#include "socketshared.h" +#include "stringshared.h" + int InitializeLogging(bool SetStderr, wil::LogFunction* ExceptionCallback = nullptr) noexcept; void LogException(const char* Message, const char* Description) noexcept; diff --git a/src/shared/inc/SocketChannel.h b/src/shared/inc/SocketChannel.h index 8fb47d0..154d314 100644 --- a/src/shared/inc/SocketChannel.h +++ b/src/shared/inc/SocketChannel.h @@ -112,12 +112,13 @@ public: #ifdef WIN32 + auto sentBytes = wsl::windows::common::socket::Send(m_socket.get(), span, m_exitEvent); + WSL_LOG( "SentMessage", TraceLoggingValue(m_name, "Name"), - TraceLoggingValue(reinterpret_cast(span.data())->PrettyPrint().c_str(), "Content")); - - wsl::windows::common::socket::Send(m_socket.get(), span, m_exitEvent); + TraceLoggingValue(reinterpret_cast(span.data())->PrettyPrint().c_str(), "Content"), + TraceLoggingValue(sentBytes, "SentBytes")); #else @@ -130,7 +131,7 @@ public: { LOG_ERROR("Failed to write message {}. Channel: {}", header->MessageType, m_name); THROW_LAST_ERROR(); - }; + } #endif } diff --git a/src/shared/inc/socketshared.h b/src/shared/inc/socketshared.h index d64a539..d07916d 100644 --- a/src/shared/inc/socketshared.h +++ b/src/shared/inc/socketshared.h @@ -80,6 +80,27 @@ try #endif if (BytesRead <= 0) { + const auto* Header = reinterpret_cast(Buffer.data()); + +#if defined(_MSC_VER) + + LOG_HR_MSG( + E_UNEXPECTED, + "Socket closed while reading message. Size: %u, type: %i, sequence: %u", + Header->MessageSize, + Header->MessageType, + Header->SequenceNumber); + +#elif defined(__GNUC__) + + LOG_ERROR( + "Socket closed while reading message. Size: {}, type: {}, sequence: {}", + Header->MessageSize, + Header->MessageType, + Header->SequenceNumber); + +#endif + return {}; } diff --git a/src/windows/common/socket.cpp b/src/windows/common/socket.cpp index 89d6a87..50fe0ae 100644 --- a/src/windows/common/socket.cpp +++ b/src/windows/common/socket.cpp @@ -138,18 +138,43 @@ std::vector wsl::windows::common::socket::Receive( int wsl::windows::common::socket::Send( _In_ SOCKET Socket, _In_ gsl::span Buffer, _In_opt_ HANDLE ExitHandle, _In_ const std::source_location& Location) { - OVERLAPPED Overlapped{}; const wil::unique_event OverlappedEvent(wil::EventOptions::ManualReset); - WSABUF VectorBuffer = {gsl::narrow_cast(Buffer.size()), const_cast(reinterpret_cast(Buffer.data()))}; + OVERLAPPED Overlapped{}; Overlapped.hEvent = OverlappedEvent.get(); - DWORD BytesWritten{}; - if (WSASend(Socket, &VectorBuffer, 1, &BytesWritten, 0, &Overlapped, nullptr) != 0) + + DWORD Offset = 0; + while (Offset < Buffer.size()) { - DWORD Flags; - std::tie(BytesWritten, Flags) = GetResult(Socket, Overlapped, INFINITE, ExitHandle, Location); + OverlappedEvent.ResetEvent(); + + WSABUF VectorBuffer = { + gsl::narrow_cast(Buffer.size() - Offset), const_cast(reinterpret_cast(Buffer.data() + Offset))}; + + DWORD BytesWritten{}; + if (WSASend(Socket, &VectorBuffer, 1, &BytesWritten, 0, &Overlapped, nullptr) != 0) + { + // If WSASend returns non-zero, expect WSA_IO_PENDING. + if (auto error = WSAGetLastError(); error != WSA_IO_PENDING) + { + THROW_WIN32_MSG(error, "WSASend failed. From: %hs", std::format("{}", Location).c_str()); + } + + DWORD Flags; + std::tie(BytesWritten, Flags) = GetResult(Socket, Overlapped, INFINITE, ExitHandle, Location); + if (BytesWritten == 0) + { + THROW_WIN32_MSG(ERROR_CONNECTION_ABORTED, "Socket closed during WSASend(). From: %hs", std::format("{}", Location).c_str()); + } + } + + Offset += BytesWritten; + if (Offset < Buffer.size()) + { + WSL_LOG("PartialSocketWrite", TraceLoggingValue(Buffer.size(), "MessagSize"), TraceLoggingValue(Offset, "Offset")); + } } - WI_ASSERT(BytesWritten == gsl::narrow_cast(Buffer.size())); + WI_ASSERT(Offset == gsl::narrow_cast(Buffer.size())); - return BytesWritten; + return Offset; }