WSL/src/windows/common/VirtioNetworking.cpp
Ben Hillis 0f63354384
Update src/windows/common/VirtioNetworking.cpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-01 17:01:19 -08:00

421 lines
15 KiB
C++

// Copyright (C) Microsoft Corporation. All rights reserved.
#include "precomp.h"
#include "VirtioNetworking.h"
#include "GuestDeviceManager.h"
#include "Stringify.h"
#include "stringshared.h"
using namespace wsl::core::networking;
using namespace wsl::shared;
using namespace wsl::windows::common::stringify;
using wsl::core::VirtioNetworking;
// Wide-string form of LX_INIT_LOOPBACK_DEVICE_NAME. Used in this file as the name of the localhost relay
// adapter (SetupLoopbackDevice) and as the tag for localhost relay port requests (HandlePortNotification).
static constexpr auto c_loopbackDeviceName = TEXT(LX_INIT_LOOPBACK_DEVICE_NAME);
// Stores the dependencies used by the rest of the class; no guest or host configuration happens here
// (that work is done by Initialize()).
//   gnsChannel           - moved into m_gnsChannel; used to send endpoint, route, DNS and interface updates.
//   enableLocalhostRelay - when true, Initialize() also calls SetupLoopbackDevice() and port notifications
//                          for loopback/unspecified addresses are relayed through the loopback device.
//   guestDeviceManager   - used to add guest devices and to look up the device's remote file system.
//   userToken            - passed to AddGuestDevice when the virtio adapters are created.
VirtioNetworking::VirtioNetworking(
GnsChannel&& gnsChannel, bool enableLocalhostRelay, std::shared_ptr<GuestDeviceManager> guestDeviceManager, wil::shared_handle userToken) :
m_guestDeviceManager(std::move(guestDeviceManager)),
m_userToken(std::move(userToken)),
m_gnsChannel(std::move(gnsChannel)),
m_enableLocalhostRelay(enableLocalhostRelay)
{
}
// Tears down in a fixed order: first drop the connectivity-change registration, then stop the GNS channel.
VirtioNetworking::~VirtioNetworking()
{
// Unregister the network notification callback to prevent it from using the GNS channel.
m_networkNotifyHandle.reset();
// Stop the GNS channel to unblock any stuck communications with the guest.
m_gnsChannel.Stop();
}
// Configures virtio networking for the guest:
//   1. snapshots the host endpoint settings and turns them into the virtio device option string,
//   2. adds the virtio network adapter ("eth0") to the guest,
//   3. sends the endpoint address, the default route and the DNS servers over the GNS channel,
//   4. optionally sets up the localhost relay device,
//   5. registers for host network connectivity-change notifications.
// Any exception is logged and swallowed by CATCH_LOG, so the caller is not told about a failed initialization.
void VirtioNetworking::Initialize()
try
{
    // Query into a local first: m_networkSettings is read under m_lock (see TraceLoggingRundown), so it is
    // only published once the lock is held below.
    auto hostSettings = GetHostEndpointSettings();

    // TODO: Determine gateway MAC address

    // The device options are "key=value" pairs separated by ';'. Empty values are skipped entirely.
    std::wstringstream device_options;
    const auto appendOption = [&device_options](const wchar_t* key, const auto& value) {
        if (value.empty())
        {
            return;
        }

        if (device_options.tellp() > 0)
        {
            device_options << L";";
        }

        device_options << key << L"=" << value;
    };

    appendOption(L"client_ip", hostSettings->PreferredIpAddress.AddressString);
    appendOption(L"client_mac", hostSettings->MacAddress);

    const std::wstring default_route = hostSettings->GetBestGatewayAddressString();
    appendOption(L"gateway_ip", default_route);

    const auto dns_servers = hostSettings->DnsServersString();
    appendOption(L"nameservers", dns_servers);

    auto lock = m_lock.lock_exclusive();
    m_networkSettings = std::move(hostSettings);

    // Add virtio net adapter to guest
    m_adapterId = m_guestDeviceManager->AddGuestDevice(
        c_virtioNetworkDeviceId, c_virtioNetworkClsid, L"eth0", nullptr, device_options.str().c_str(), 0, m_userToken.get());

    hns::HNSEndpoint endpointProperties;
    endpointProperties.ID = m_adapterId;
    endpointProperties.IPAddress = m_networkSettings->PreferredIpAddress.AddressString;
    endpointProperties.PrefixLength = m_networkSettings->PreferredIpAddress.PrefixLength;
    m_gnsChannel.SendEndpointState(endpointProperties);

    // N.B. The MAC address is advertised with the virtio device so doesn't need to be explicitly set.

    // Send the default route to gns
    if (!default_route.empty())
    {
        wsl::shared::hns::Route route;
        route.NextHop = default_route;
        route.DestinationPrefix = LX_INIT_DEFAULT_ROUTE_PREFIX;
        route.Family = AF_INET;

        hns::ModifyGuestEndpointSettingRequest<hns::Route> request;
        request.RequestType = hns::ModifyRequestType::Add;
        request.ResourceType = hns::GuestEndpointResourceType::Route;
        request.Settings = route;
        m_gnsChannel.SendHnsNotification(ToJsonW(request).c_str(), m_adapterId);
    }

    // Update DNS information.
    if (!dns_servers.empty())
    {
        // TODO: DNS domain suffixes
        hns::DNS dnsSettings{};
        dnsSettings.Options = LX_INIT_RESOLVCONF_FULL_HEADER;
        dnsSettings.ServerList = dns_servers;
        UpdateDns(std::move(dnsSettings));
    }

    if (m_enableLocalhostRelay)
    {
        SetupLoopbackDevice();
    }

    // NOTE(review): this registers with InitialNotification = true while m_lock is still held, and the
    // callback path (RefreshGuestConnection) also acquires m_lock. Confirm the initial notification is never
    // delivered synchronously on this thread.
    THROW_IF_WIN32_ERROR(NotifyNetworkConnectivityHintChange(&VirtioNetworking::OnNetworkConnectivityChange, this, true, &m_networkNotifyHandle));
}
CATCH_LOG()
void VirtioNetworking::SetupLoopbackDevice()
{
m_localhostAdapterId = m_guestDeviceManager->AddGuestDevice(
c_virtioNetworkDeviceId,
c_virtioNetworkClsid,
c_loopbackDeviceName,
nullptr,
L"client_ip=127.0.0.1;client_mac=00:11:22:33:44:55",
0,
m_userToken.get());
hns::HNSEndpoint endpointProperties;
endpointProperties.ID = m_localhostAdapterId;
// The loopback gateway (see LX_INIT_IPV4_LOOPBACK_GATEWAY_ADDRESS) is 169.254.73.152, so assign loopback0 an
// address of 169.254.73.153 with a netmask of 30 so that the only addresses associated with this adapter are
// itself and the gateway.
endpointProperties.IPAddress = L"169.254.73.153";
endpointProperties.PrefixLength = 30;
endpointProperties.PortFriendlyName = c_loopbackDeviceName;
m_gnsChannel.SendEndpointState(endpointProperties);
// N.B. The MAC address is advertised with the virtio device so doesn't need to be explicitly set.
hns::CreateDeviceRequest createLoopbackDevice;
createLoopbackDevice.deviceName = c_loopbackDeviceName;
createLoopbackDevice.type = hns::DeviceType::Loopback;
createLoopbackDevice.lowerEdgeAdapterId = m_localhostAdapterId;
constexpr auto loopbackType = GnsMessageType(createLoopbackDevice);
m_gnsChannel.SendNetworkDeviceMessage(loopbackType, ToJsonW(createLoopbackDevice).c_str());
}
// Creates the port tracker channel on the given socket. Must only be called once: the channel is asserted
// to be empty beforehand.
void VirtioNetworking::StartPortTracker(wil::unique_socket&& socket)
{
    WI_ASSERT(!m_gnsPortTrackerChannel.has_value());

    // Port requests coming from the channel are forwarded to HandlePortNotification.
    auto onPortRequest = [this](const SOCKADDR_INET& address, int protocolId, bool shouldAllocate) {
        return HandlePortNotification(address, protocolId, shouldAllocate);
    };

    // TODO: reconsider if InterfaceStateCallback is needed.
    auto onInterfaceState = [](const std::string&, bool) {};

    m_gnsPortTrackerChannel.emplace(std::move(socket), std::move(onPortRequest), std::move(onInterfaceState));
}
// Handles a port allocate/release notification for the given address.
//   - IPv4 loopback addresses other than 127.0.0.1 are ignored (S_OK is returned).
//   - When the localhost relay is enabled, loopback and unspecified addresses are relayed through the
//     loopback device, using the loopback address of the request's family.
//   - Every non-loopback address is additionally relayed through "eth0".
// Returns the localhost relay result if it is non-zero, otherwise the "eth0" result.
HRESULT VirtioNetworking::HandlePortNotification(const SOCKADDR_INET& addr, int protocol, bool allocate) const noexcept
{
    const bool isIpv4 = (addr.si_family == AF_INET);
    const void* rawAddress = isIpv4 ? static_cast<const void*>(&addr.Ipv4.sin_addr) : static_cast<const void*>(&addr.Ipv6.sin6_addr);
    const bool isLoopback = INET_IS_ADDR_LOOPBACK(addr.si_family, rawAddress);
    const bool isUnspecified = INET_IS_ADDR_UNSPECIFIED(addr.si_family, rawAddress);

    // Only intercepting 127.0.0.1; any other loopback address will remain on 'lo'.
    if (isIpv4 && isLoopback && addr.Ipv4.sin_addr.s_addr != htonl(INADDR_LOOPBACK))
    {
        return 0;
    }

    int relayStatus = 0;
    if (m_enableLocalhostRelay && (isUnspecified || isLoopback))
    {
        SOCKADDR_INET relayAddr = addr;
        if (!isLoopback)
        {
            // Unspecified address: switch to the loopback address, then copy the requested port back in.
            INETADDR_SETLOOPBACK(reinterpret_cast<PSOCKADDR>(&relayAddr));
            if (isIpv4)
            {
                relayAddr.Ipv4.sin_port = addr.Ipv4.sin_port;
            }
            else
            {
                relayAddr.Ipv6.sin6_port = addr.Ipv6.sin6_port;
            }
        }

        relayStatus = ModifyOpenPorts(c_virtioNetworkClsid, c_loopbackDeviceName, relayAddr, protocol, allocate);
        LOG_HR_IF_MSG(E_FAIL, relayStatus != S_OK, "Failure adding localhost relay port %d", relayAddr.Ipv4.sin_port);
    }

    if (isLoopback)
    {
        return relayStatus;
    }

    const int externalStatus = ModifyOpenPorts(c_virtioNetworkClsid, L"eth0", addr, protocol, allocate);
    LOG_HR_IF_MSG(E_FAIL, externalStatus != S_OK, "Failure adding relay port %d", addr.Ipv4.sin_port);

    // The first failure wins: only fall back to the eth0 status when the localhost relay reported 0.
    return (relayStatus == 0) ? externalStatus : relayStatus;
}
// Sends a port relay request to the virtio network device as an option string passed to AddShare:
//   "tag=<tag>;port_number=<port>[;udp](;listen_addr=<ipv4> | ;allocate=false)"
//   clsid    - device class used to look up the remote file system.
//   tag      - adapter tag the port belongs to.
//   addr     - address and port of the request; only IPv4 is acted upon.
//   protocol - IPPROTO_TCP or IPPROTO_UDP; anything else is logged and ignored.
//   isOpen   - true to open the port (listen address included), false to release it.
// Always returns 0; an AddShare failure is logged but not propagated.
int VirtioNetworking::ModifyOpenPorts(_In_ const GUID& clsid, _In_ PCWSTR tag, _In_ const SOCKADDR_INET& addr, _In_ int protocol, _In_ bool isOpen) const
{
    const bool isUdp = (protocol == IPPROTO_UDP);
    if (!isUdp && protocol != IPPROTO_TCP)
    {
        LOG_HR_MSG(HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED), "Unsupported bind protocol %d", protocol);
        return 0;
    }

    if (addr.si_family == AF_INET6)
    {
        // The virtio net adapter does not yet support IPv6 packets, so any traffic would arrive via
        // IPv4. If the caller wants IPv4 they will also likely listen on an IPv4 address, which will
        // be handled as a separate callback to this same code.
        return 0;
    }

    auto lock = m_lock.lock_exclusive();
    const auto server = m_guestDeviceManager->GetRemoteFileSystem(clsid, c_defaultDeviceTag);
    if (!server)
    {
        return 0;
    }

    std::wstring request = std::format(L"tag={};port_number={}", tag, addr.Ipv4.sin_port);
    if (isUdp)
    {
        request += L";udp";
    }

    if (isOpen)
    {
        wchar_t listenAddress[16]; // "000.000.000.000" + null terminator
        RtlIpv4AddressToStringW(&addr.Ipv4.sin_addr, listenAddress);
        request += L";listen_addr=";
        request += listenAddress;
    }
    else
    {
        request += L";allocate=false";
    }

    LOG_IF_FAILED(server->AddShare(request.c_str(), nullptr, 0));
    return 0;
}
// Callback registered with NotifyNetworkConnectivityHintChange. The context is the VirtioNetworking
// instance that registered it (Initialize passes `this`); the hint is forwarded to RefreshGuestConnection.
void NETIOAPI_API_ VirtioNetworking::OnNetworkConnectivityChange(PVOID context, NL_NETWORK_CONNECTIVITY_HINT hint)
{
    auto* const self = static_cast<VirtioNetworking*>(context);
    self->RefreshGuestConnection(hint);
}
// Reacts to a host connectivity change by re-evaluating the MTU under the exclusive lock.
// The connectivity hint itself is currently not inspected. Exceptions are logged, never propagated.
void VirtioNetworking::RefreshGuestConnection(NL_NETWORK_CONNECTIVITY_HINT connectivityHint) noexcept
try
{
    const auto exclusiveGuard = m_lock.lock_exclusive();
    UpdateMtu();
}
CATCH_LOG();
// Wraps the given DNS settings in an Update request for the DNS resource and sends it to the guest for the
// main adapter (m_adapterId). The settings are consumed (moved from).
void VirtioNetworking::UpdateDns(hns::DNS&& dnsSettings)
{
    hns::ModifyGuestEndpointSettingRequest<hns::DNS> dnsRequest{};
    dnsRequest.ResourceType = hns::GuestEndpointResourceType::DNS;
    dnsRequest.RequestType = hns::ModifyRequestType::Update;
    dnsRequest.Settings = std::move(dnsSettings);

    const auto payload = ToJsonW(dnsRequest);
    m_gnsChannel.SendHnsNotification(payload.c_str(), m_adapterId);
}
void VirtioNetworking::UpdateMtu()
{
const auto minMtu = GetMinimumConnectedInterfaceMtu();
// Only send the update if the MTU changed.
if (minMtu && minMtu.value() != m_networkMtu)
{
m_networkMtu = minMtu.value();
hns::ModifyGuestEndpointSettingRequest<hns::NetworkInterface> notification{};
notification.ResourceType = hns::GuestEndpointResourceType::Interface;
notification.RequestType = hns::ModifyRequestType::Update;
notification.Settings.Connected = true;
notification.Settings.NlMtu = m_networkMtu;
WSL_LOG(
"VirtioNetworking::UpdateMtu",
TraceLoggingValue(m_adapterId, "endpointId"),
TraceLoggingValue(m_networkMtu, "virtioMtu"));
m_gnsChannel.SendHnsNotification(ToJsonW(notification).c_str(), m_adapterId);
}
}
void VirtioNetworking::TraceLoggingRundown() noexcept
{
auto lock = m_lock.lock_exclusive();
WSL_LOG("VirtioNetworking::TraceLoggingRundown", TRACE_NETWORKSETTINGS_OBJECT(m_networkSettings));
}
// Fills in the networking section of the mini-init configuration for virtio proxy mode:
// IPv6 stays enabled, the DHCP client is off and the mirrored port tracker type is selected.
void VirtioNetworking::FillInitialConfiguration(LX_MINI_INIT_NETWORKING_CONFIGURATION& message)
{
    message.PortTrackerType = LX_MINI_INIT_PORT_TRACKER_TYPE::LxMiniInitPortTrackerTypeMirrored;
    message.EnableDhcpClient = false;
    message.DisableIpv6 = false;
    message.NetworkingMode = LxMiniInitNetworkingModeVirtioProxy;
}
std::optional<ULONGLONG> VirtioNetworking::FindVirtioInterfaceLuid(const SOCKADDR_INET& VirtioAddress, const NL_NETWORK_CONNECTIVITY_HINT& currentConnectivityHint)
{
constexpr ULONGLONG maxTimeToWaitMs = 10 * 1000;
constexpr ULONG timeToSleepMs = 100;
const auto startTickCount = GetTickCount64();
NET_LUID VirtioLuid{};
for (;;)
{
unique_address_table addressTable;
THROW_IF_WIN32_ERROR(GetUnicastIpAddressTable(AF_INET, &addressTable));
for (const auto& address : wil::make_range(addressTable.get()->Table, addressTable.get()->NumEntries))
{
if (VirtioAddress == address.Address)
{
VirtioLuid.Value = address.InterfaceLuid.Value;
break;
}
WSL_LOG(
"VirtioNetworking::FindVirtioInterfaceLuid [IP Address comparison mismatch]",
TraceLoggingValue(wsl::windows::common::string::SockAddrInetToString(VirtioAddress).c_str(), "VirtioAddress"),
TraceLoggingValue(
wsl::windows::common::string::SockAddrInetToString(address.Address).c_str(), "enumeratedAddress"));
}
if (VirtioLuid.Value != 0)
{
break;
}
// give up if something is just broken and taking too long
if (GetTickCount64() - startTickCount >= maxTimeToWaitMs)
{
break;
}
// else sleep and try again shortly
Sleep(timeToSleepMs);
// bail if connectivity on the host has completely changed
NL_NETWORK_CONNECTIVITY_HINT latestConnectivityHint{};
GetNetworkConnectivityHint(&latestConnectivityHint);
if (latestConnectivityHint != currentConnectivityHint)
{
WSL_LOG("VirtioNetworking::FindVirtioInterfaceLuid [connectivity changed while waiting for the Virtio interface]");
THROW_WIN32_MSG(ERROR_RETRY, "connectivity changed while waiting for the Virtio interface");
}
}
if (VirtioLuid.Value == 0)
{
WSL_LOG(
"VirtioNetworking::FindVirtioInterfaceLuid [IP address not found]",
TraceLoggingValue(VirtioLuid.Value, "VirtioInterfaceLuid"),
TraceLoggingValue(wsl::windows::common::string::SockAddrInetToString(VirtioAddress).c_str(), "VirtioIPAddress"));
return {};
}
WSL_LOG(
"VirtioNetworking::FindVirtioInterfaceLuid [waiting for Virtio interface to be connected]",
TraceLoggingValue(VirtioLuid.Value, "VirtioInterfaceLuid"),
TraceLoggingValue(wsl::windows::common::string::SockAddrInetToString(VirtioAddress).c_str(), "VirtioIPAddress"));
bool ipv4Connected = false;
for (;;)
{
unique_interface_table interfaceTable{};
THROW_IF_WIN32_ERROR(::GetIpInterfaceTable(AF_UNSPEC, &interfaceTable));
// we only track the IPv4 interface because we only Virtio IPv4 to the container
for (auto index = 0ul; index < interfaceTable.get()->NumEntries; ++index)
{
const auto& ipInterface = interfaceTable.get()->Table[index];
if (ipInterface.Family == AF_INET && !!ipInterface.Connected && ipInterface.InterfaceLuid.Value == VirtioLuid.Value)
{
ipv4Connected = true;
break;
}
}
if (ipv4Connected)
{
break;
}
// give up if something is just broken and taking too long
if (GetTickCount64() - startTickCount >= maxTimeToWaitMs)
{
break;
}
// else sleep and try again shortly
Sleep(timeToSleepMs);
// bail if connectivity on the host has completely changed
NL_NETWORK_CONNECTIVITY_HINT latestConnectivityHint{};
GetNetworkConnectivityHint(&latestConnectivityHint);
if (latestConnectivityHint != currentConnectivityHint)
{
WSL_LOG("VirtioNetworking::FindVirtioInterfaceLuid [connectivity changed while waiting for the Virtio interface]");
THROW_WIN32_MSG(ERROR_RETRY, "connectivity changed while waiting for the Virtio interface");
}
}
// return zero if it's not connected yet so we can retry the next cycle
return ipv4Connected ? VirtioLuid.Value : std::optional<ULONGLONG>();
}