/*++ Copyright (c) Microsoft. All rights reserved. Module Name: SocketChannel.h Abstract: This file contains the SocketChannel helper class implementation. --*/ #pragma once #include #include "socketshared.h" #include "lxinitshared.h" #ifndef WIN32 #include #include "lxwil.h" #include "../../linux/init/util.h" extern std::optional g_EnableSocketLogging; #endif namespace wsl::shared { #ifdef WIN32 using TSocket = wil::unique_socket; using TTimeout = DWORD; constexpr DWORD DefaultSocketTimeout = INFINITE; #else using TSocket = wil::unique_fd; using TTimeout = const timeval*; constexpr timeval* DefaultSocketTimeout = nullptr; #endif class SocketChannel { public: SocketChannel() = default; SocketChannel(const SocketChannel&) = delete; SocketChannel(SocketChannel&& other) { *this = std::move(other); } SocketChannel& operator=(const SocketChannel&) = delete; SocketChannel& operator=(SocketChannel&& other) { m_name = std::move(other.m_name); m_socket = std::move(other.m_socket); #ifdef WIN32 m_exitEvent = std::move(other.m_exitEvent); #endif m_ignore_sequence = other.m_ignore_sequence; return *this; } SocketChannel(TSocket&& socket, std::string&& name) : m_socket(std::move(socket)), m_name(std::move(name)) { } #ifdef WIN32 SocketChannel(TSocket&& socket, std::string&& name, HANDLE exitEvent) : m_socket(std::move(socket)), m_exitEvent(exitEvent), m_name(std::move(name)) { } #endif template void SendMessage(gsl::span span) { // Ensure that no other thread is using this channel. const std::unique_lock lock{m_sendMutex, std::try_to_lock}; if (!lock.owns_lock()) { #ifdef WIN32 THROW_HR_MSG(E_UNEXPECTED, "Incorrect channel usage detected on channel: %hs, message type: %hs", m_name.c_str(), ToString(TMessage::Type)); #else LOG_ERROR("Incorrect channel usage detected on channel: {}, message type: {}", m_name, ToString(TMessage::Type)); THROW_ERRNO(EINVAL); #endif } THROW_INVALID_ARG_IF(m_name.empty() || span.size() < sizeof(TMessage)); m_sent_messages++; auto* header = gslhelpers::try_get_struct(span); WI_ASSERT(header->MessageSize == span.size()); header->SequenceNumber = m_sent_messages; #ifdef WIN32 auto sentBytes = wsl::windows::common::socket::Send(m_socket.get(), span, m_exitEvent); WSL_LOG( "SentMessage", TraceLoggingValue(m_name.c_str(), "Name"), TraceLoggingValue(reinterpret_cast(span.data())->PrettyPrint().c_str(), "Content"), TraceLoggingValue(sentBytes, "SentBytes")); #else if (LoggingEnabled()) { LOG_INFO("SentMessage on channel: {}: '{}'", m_name, reinterpret_cast(span.data())->PrettyPrint().c_str()); } if (UtilWriteBuffer(m_socket.get(), span.data(), span.size()) < 0) { LOG_ERROR("Failed to write message {}. Channel: {}", header->MessageType, m_name); THROW_LAST_ERROR(); } #endif } template MESSAGE_HEADER& GetMessageHeader(TMessage& message) { if constexpr (std::is_same_v) { return message; } else { return message.Header; } } template void SendMessage() { TMessage message; SendMessage(message); } template void SendMessage(TMessage& message) { // Catch situations where the other SendMessage() method should be used const auto& header = GetMessageHeader(message); if (header.MessageSize != sizeof(message)) { #ifdef WIN32 THROW_HR_MSG(E_INVALIDARG, "Incorrect header size for message type: %u on channel: %hs", header.MessageType, m_name.c_str()); #else LOG_ERROR("Incorrect header size for message type: {} on channel: {}", header.MessageType, m_name); THROW_ERRNO(EINVAL); #endif } SendMessage(gslhelpers::struct_as_writeable_bytes(message)); } template void SendResultMessage(TResult value) { RESULT_MESSAGE Result{}; Result.Header.MessageSize = sizeof(Result); Result.Header.MessageType = RESULT_MESSAGE::Type; Result.Result = value; SendMessage(Result); } template std::pair> ReceiveMessageOrClosed(TTimeout timeout = DefaultSocketTimeout) { WI_ASSERT(!m_name.empty()); // Ensure that no other thread is using this channel. const std::unique_lock lock{m_receiveMutex, std::try_to_lock}; if (!lock.owns_lock()) { #ifdef WIN32 THROW_HR_MSG(E_UNEXPECTED, "Incorrect channel usage detected on channel: %hs", m_name.c_str()); #else LOG_ERROR("Incorrect channel usage detected on channel: {}", m_name); THROW_ERRNO(EINVAL); #endif } m_received_messages++; auto receivedSpan = ReceiveImpl(TMessage::Type, timeout); if (receivedSpan.empty()) { #ifdef WIN32 if (errno == HCS_E_CONNECTION_TIMEOUT) { THROW_HR_MSG( HCS_E_CONNECTION_TIMEOUT, "Timeout: %d, expected type: %hs, channel: %hs", timeout, ToString(TMessage::Type), m_name.c_str()); } #endif return {nullptr, {}}; } auto* message = gslhelpers::try_get_struct(receivedSpan); if (message == nullptr) { #ifdef WIN32 THROW_HR_MSG( E_UNEXPECTED, "Message size is too small: %zd, expected type: %hs, channel: %hs", receivedSpan.size(), ToString(TMessage::Type), m_name.c_str()); #else LOG_ERROR("MessageSize is too small: {}, expected type: {}, channel: {}", receivedSpan.size(), ToString(TMessage::Type), m_name); THROW_ERRNO(EINVAL); #endif } ValidateMessageHeader(GetMessageHeader(*message), TMessage::Type, m_received_messages); #ifdef WIN32 WSL_LOG( "ReceivedMessage", TraceLoggingValue(m_name.c_str(), "Name"), TraceLoggingValue(message->PrettyPrint().c_str(), "Content")); #else if (LoggingEnabled()) { LOG_INFO("ReceivedMessage on channel: {}: '{}'", m_name, message->PrettyPrint().c_str()); } #endif return {message, receivedSpan}; } template TMessage& ReceiveMessage(gsl::span* responseSpan = nullptr, TTimeout timeout = DefaultSocketTimeout) { auto [message, span] = ReceiveMessageOrClosed(timeout); if (message == nullptr) { #ifdef WIN32 THROW_HR_MSG(E_UNEXPECTED, "Expected message %hs, but socket %hs was closed", ToString(TMessage::Type), m_name.c_str()); #else LOG_ERROR("ExpectedMessage {}, but socket {} was closed", ToString(TMessage::Type), m_name); THROW_ERRNO(EINVAL); #endif } if (responseSpan != nullptr) { *responseSpan = span; } return *message; } template TSentMessage::TResponse& Transaction(gsl::span message, gsl::span* responseSpan = nullptr, TTimeout timeout = DefaultSocketTimeout) { SendMessage(message); return ReceiveMessage(responseSpan, timeout); } template TSentMessage::TResponse& Transaction(TSentMessage& message, gsl::span* responseSpan = nullptr, TTimeout timeout = DefaultSocketTimeout) { WI_ASSERT(message.Header.MessageSize == sizeof(message)); return Transaction(gslhelpers::struct_as_writeable_bytes(message), responseSpan, timeout); } template TSentMessage::TResponse& Transaction() { TSentMessage message{}; message.Header.MessageSize = sizeof(message); message.Header.MessageType = TSentMessage::Type; return Transaction(message); } void Close() { m_socket.reset(); } auto Socket() const { return m_socket.get(); } bool Connected() const { return m_socket.get() >= 0; } void IgnoreSequenceNumbers() { m_ignore_sequence = true; } #ifndef WIN32 static void EnableSocketLogging(bool enable) { g_EnableSocketLogging = enable; } #endif private: #ifdef WIN32 gsl::span ReceiveImpl(auto expectedMessage, TTimeout timeout) { return wsl::shared::socket::RecvMessage(m_socket.get(), m_buffer, m_exitEvent, timeout); } #else gsl::span ReceiveImpl(auto expectedMessage, TTimeout timeout) { return wsl::shared::socket::RecvMessage(m_socket.get(), m_buffer, timeout); } #endif void ValidateMessageHeader(const MESSAGE_HEADER& header, LX_MESSAGE_TYPE expected, unsigned int expectedSequence) const { if (header.MessageSize < sizeof(header) || (expected != LxMiniInitMessageAny && header.MessageType != expected) || (!m_ignore_sequence && header.SequenceNumber != expectedSequence)) { #ifdef WIN32 THROW_HR_MSG( E_UNEXPECTED, "Protocol error: Received message size: %u, type: %u, sequence: %u. Expected type: %u, expected sequence: %u, " "channel: %hs", header.MessageSize, header.MessageType, header.SequenceNumber, expected, expectedSequence, m_name.c_str()); #else LOG_ERROR( "Protocol error: Received message size: {}, type: {}, sequence: {}. Expected type: {}, expected sequence: {}, " "channel: {}", header.MessageSize, header.MessageType, header.SequenceNumber, expected, expectedSequence, m_name); THROW_ERRNO(EINVAL); #endif } } #ifndef WIN32 static bool LoggingEnabled() { static std::once_flag flag; std::call_once(flag, [&]() { try { if (g_EnableSocketLogging.has_value()) { return; } auto content = UtilReadFileContent("/proc/cmdline"); g_EnableSocketLogging = content.find("WSL_SOCKET_LOG") != std::string::npos; } catch (...) { LOG_CAUGHT_EXCEPTION(); g_EnableSocketLogging = false; } }); return g_EnableSocketLogging.value(); } #endif TSocket m_socket{}; std::vector m_buffer; #ifdef WIN32 HANDLE m_exitEvent{}; #endif uint32_t m_sent_messages = 0; uint32_t m_received_messages = 0; bool m_ignore_sequence = false; std::string m_name{}; std::mutex m_sendMutex; std::mutex m_receiveMutex; }; } // namespace wsl::shared