cleanup: VirtioNetworking refactoring (#13760)

* cleanup: update VirtioNetworking class to not rely on the WslCoreConfig struct

* cleanup: simplify VirtioNetworking construction

* remove old constructor and other cleanup

* more minor cleanup

* string cleanup in HandleVirtioModifyOpenPorts

---------

Co-authored-by: Ben Hillis <benhill@ntdev.microsoft.com>
This commit is contained in:
Ben Hillis 2025-11-21 16:50:26 -08:00 committed by GitHub
parent c3d369df90
commit d9c69a50ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 96 additions and 94 deletions

View File

@ -12,34 +12,23 @@ using wsl::core::VirtioNetworking;
static constexpr auto c_loopbackDeviceName = TEXT(LX_INIT_LOOPBACK_DEVICE_NAME);
VirtioNetworking::VirtioNetworking(GnsChannel&& gnsChannel, const Config& config) :
m_gnsChannel(std::move(gnsChannel)), m_config(config)
VirtioNetworking::VirtioNetworking(
GnsChannel&& gnsChannel,
bool enableLocalhostRelay,
AddGuestDeviceCallback addGuestDeviceCallback,
ModifyOpenPortsCallback modifyOpenPortsCallback,
GuestInterfaceStateChangeCallback guestInterfaceStateChangeCallback) :
m_addGuestDeviceCallback(std::move(addGuestDeviceCallback)),
m_gnsChannel(std::move(gnsChannel)),
m_modifyOpenPortsCallback(std::move(modifyOpenPortsCallback)),
m_guestInterfaceStateChangeCallback(std::move(guestInterfaceStateChangeCallback)),
m_enableLocalhostRelay(enableLocalhostRelay)
{
}
VirtioNetworking& VirtioNetworking::OnAddGuestDevice(const AddGuestDeviceRoutine& addGuestDeviceRoutine)
{
m_addGuestDeviceRoutine = addGuestDeviceRoutine;
return *this;
}
VirtioNetworking& VirtioNetworking::OnModifyOpenPorts(const ModifyOpenPortsCallback& modifyOpenPortsCallback)
{
m_modifyOpenPortsCallback = modifyOpenPortsCallback;
return *this;
}
VirtioNetworking& VirtioNetworking::OnGuestInterfaceStateChanged(const GuestInterfaceStateChangeCallback& guestInterfaceStateChangedCallback)
{
m_guestInterfaceStateChangeCallback = guestInterfaceStateChangedCallback;
return *this;
}
void VirtioNetworking::Initialize()
try
{
THROW_HR_IF(E_NOT_SET, !m_addGuestDeviceRoutine || !m_modifyOpenPortsCallback || !m_guestInterfaceStateChangeCallback);
m_networkSettings = GetHostEndpointSettings();
// TODO: Determine gateway MAC address
@ -84,7 +73,7 @@ try
}
// Add virtio net adapter to guest
m_adapterId = (*m_addGuestDeviceRoutine)(c_virtioNetworkClsid, c_virtioNetworkDeviceId, L"eth0", device_options.str().c_str());
m_adapterId = m_addGuestDeviceCallback(c_virtioNetworkClsid, c_virtioNetworkDeviceId, L"eth0", device_options.str().c_str());
auto lock = m_lock.lock_exclusive();
@ -121,7 +110,7 @@ try
UpdateDns(std::move(dnsSettings));
}
if (m_config.EnableLocalhostRelay)
if (m_enableLocalhostRelay)
{
SetupLoopbackDevice();
}
@ -132,7 +121,7 @@ CATCH_LOG()
void VirtioNetworking::SetupLoopbackDevice()
{
m_localhostAdapterId = (*m_addGuestDeviceRoutine)(
m_localhostAdapterId = m_addGuestDeviceCallback(
c_virtioNetworkClsid, c_virtioNetworkDeviceId, c_loopbackDeviceName, L"client_ip=127.0.0.1;client_mac=00:11:22:33:44:55");
hns::HNSEndpoint endpointProperties;
@ -162,7 +151,7 @@ void VirtioNetworking::StartPortTracker(wil::unique_socket&& socket)
m_gnsPortTrackerChannel.emplace(
std::move(socket),
[&](const SOCKADDR_INET& addr, int protocol, bool allocate) { return HandlePortNotification(addr, protocol, allocate); },
[&](_In_ const std::string& interfaceName, _In_ bool up) { (*m_guestInterfaceStateChangeCallback)(interfaceName, up); });
[&](_In_ const std::string& interfaceName, _In_ bool up) { m_guestInterfaceStateChangeCallback(interfaceName, up); });
}
HRESULT VirtioNetworking::HandlePortNotification(const SOCKADDR_INET& addr, int protocol, bool allocate) const noexcept
@ -181,7 +170,7 @@ HRESULT VirtioNetworking::HandlePortNotification(const SOCKADDR_INET& addr, int
}
}
if (m_config.EnableLocalhostRelay && (unspecified || loopback))
if (m_enableLocalhostRelay && (unspecified || loopback))
{
SOCKADDR_INET localAddr = addr;
if (!loopback)
@ -196,12 +185,12 @@ HRESULT VirtioNetworking::HandlePortNotification(const SOCKADDR_INET& addr, int
localAddr.Ipv6.sin6_port = addr.Ipv6.sin6_port;
}
}
result = (*m_modifyOpenPortsCallback)(c_virtioNetworkClsid, c_loopbackDeviceName, localAddr, protocol, allocate);
result = m_modifyOpenPortsCallback(c_virtioNetworkClsid, c_loopbackDeviceName, localAddr, protocol, allocate);
LOG_HR_IF_MSG(E_FAIL, result != S_OK, "Failure adding localhost relay port %d", localAddr.Ipv4.sin_port);
}
if (!loopback)
{
const int localResult = (*m_modifyOpenPortsCallback)(c_virtioNetworkClsid, L"eth0", addr, protocol, allocate);
const int localResult = m_modifyOpenPortsCallback(c_virtioNetworkClsid, L"eth0", addr, protocol, allocate);
LOG_HR_IF_MSG(E_FAIL, localResult != S_OK, "Failure adding relay port %d", addr.Ipv4.sin_port);
if (result == 0)
{

View File

@ -9,20 +9,21 @@
namespace wsl::core {
using AddGuestDeviceRoutine = std::function<GUID(const GUID& clsid, const GUID& deviceId, PCWSTR tag, PCWSTR options)>;
using AddGuestDeviceCallback = std::function<GUID(const GUID& clsid, const GUID& deviceId, PCWSTR tag, PCWSTR options)>;
using ModifyOpenPortsCallback = std::function<int(const GUID& clsid, PCWSTR tag, const SOCKADDR_INET& addr, int protocol, bool isOpen)>;
using GuestInterfaceStateChangeCallback = std::function<void(const std::string& name, bool isUp)>;
class VirtioNetworking : public INetworkingEngine
{
public:
VirtioNetworking(GnsChannel&& gnsChannel, const Config& config);
VirtioNetworking(
GnsChannel&& gnsChannel,
bool enableLocalhostRelay,
AddGuestDeviceCallback addGuestDeviceCallback,
ModifyOpenPortsCallback modifyOpenPortsCallback,
GuestInterfaceStateChangeCallback guestInterfaceStateChangeCallback);
~VirtioNetworking() = default;
VirtioNetworking& OnAddGuestDevice(const AddGuestDeviceRoutine& addGuestDeviceRoutine);
VirtioNetworking& OnModifyOpenPorts(const ModifyOpenPortsCallback& modifyOpenPortsCallback);
VirtioNetworking& OnGuestInterfaceStateChanged(const GuestInterfaceStateChangeCallback& guestInterfaceStateChangedCallback);
// Note: This class cannot be moved because m_networkNotifyHandle captures a 'this' pointer.
VirtioNetworking(const VirtioNetworking&) = delete;
VirtioNetworking(VirtioNetworking&&) = delete;
@ -49,17 +50,17 @@ private:
mutable wil::srwlock m_lock;
std::optional<AddGuestDeviceRoutine> m_addGuestDeviceRoutine;
AddGuestDeviceCallback m_addGuestDeviceCallback;
GnsChannel m_gnsChannel;
std::optional<GnsPortTrackerChannel> m_gnsPortTrackerChannel;
std::shared_ptr<networking::NetworkSettings> m_networkSettings;
const Config& m_config;
bool m_enableLocalhostRelay;
GUID m_localhostAdapterId;
GUID m_adapterId;
std::optional<NL_NETWORK_CONNECTIVITY_LEVEL_HINT> m_connectivityLevel;
std::optional<NL_NETWORK_CONNECTIVITY_COST_HINT> m_connectivityCost;
std::optional<ModifyOpenPortsCallback> m_modifyOpenPortsCallback;
std::optional<GuestInterfaceStateChangeCallback> m_guestInterfaceStateChangeCallback;
ModifyOpenPortsCallback m_modifyOpenPortsCallback;
GuestInterfaceStateChangeCallback m_guestInterfaceStateChangeCallback;
std::optional<ULONGLONG> m_interfaceLuid;
ULONG m_networkMtu = 0;

View File

@ -607,55 +607,16 @@ void WslCoreVm::Initialize(const GUID& VmId, const wil::shared_handle& UserToken
}
else if (m_vmConfig.NetworkingMode == NetworkingMode::VirtioProxy)
{
auto virtioNetworkingEngine = std::make_unique<wsl::core::VirtioNetworking>(std::move(gnsChannel), m_vmConfig);
virtioNetworkingEngine->OnAddGuestDevice([&](const GUID& Clsid, const GUID& DeviceId, PCWSTR Tag, PCWSTR Options) {
auto guestDeviceLock = m_guestDeviceLock.lock_exclusive();
return AddHdvShareWithOptions(DeviceId, Clsid, Tag, {}, Options, 0, m_userToken.get());
});
virtioNetworkingEngine->OnModifyOpenPorts([&](const GUID& Clsid, PCWSTR Tag, const SOCKADDR_INET& addr, int protocol, bool isOpen) {
if (protocol != IPPROTO_TCP && protocol != IPPROTO_UDP)
{
LOG_HR_MSG(HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED), "Unsupported bind protocol %d", protocol);
return 0;
}
else if (addr.si_family == AF_INET6)
{
// The virtio net adapter does not yet support IPv6 packets, so any traffic would arrive via
// IPv4. If the caller wants IPv4 they will also likely listen on an IPv4 address, which will
// be handled as a separate callback to this same code.
return 0;
}
auto guestDeviceLock = m_guestDeviceLock.lock_exclusive();
const auto server = m_deviceHostSupport->GetRemoteFileSystem(Clsid, c_defaultTag);
if (server)
{
std::wstring portString(L"tag=");
portString += Tag;
portString += L";port_number=";
portString += std::to_wstring(addr.Ipv4.sin_port);
if (protocol == IPPROTO_UDP)
{
portString += L";udp";
}
if (!isOpen)
{
portString += L";allocate=false";
}
else
{
std::wstring addrStr(L"000.000.000.000\0");
RtlIpv4AddressToStringW(&addr.Ipv4.sin_addr, addrStr.data());
portString += L";listen_addr=";
portString += addrStr;
}
LOG_IF_FAILED(server->AddShare(portString.c_str(), nullptr, 0));
}
return 0;
});
virtioNetworkingEngine->OnGuestInterfaceStateChanged([&](const std::string& name, bool isUp) {});
m_networkingEngine.reset(virtioNetworkingEngine.release());
m_networkingEngine = std::make_unique<wsl::core::VirtioNetworking>(
std::move(gnsChannel),
m_vmConfig.EnableLocalhostRelay,
[this](const GUID& Clsid, const GUID& DeviceId, PCWSTR Tag, PCWSTR Options) {
return HandleVirtioAddGuestDevice(Clsid, DeviceId, Tag, Options);
},
[this](const GUID& Clsid, PCWSTR Tag, const SOCKADDR_INET& Addr, int Protocol, bool IsOpen) {
return HandleVirtioModifyOpenPorts(Clsid, Tag, Addr, Protocol, IsOpen);
},
[](const std::string&, bool) {});
}
else if (m_vmConfig.NetworkingMode == NetworkingMode::Bridged)
{
@ -2037,6 +1998,59 @@ bool WslCoreVm::IsDnsTunnelingSupported() const
return SUCCEEDED_LOG(wsl::core::networking::DnsResolver::LoadDnsResolverMethods());
}
bool WslCoreVm::IsVhdAttached(_In_ PCWSTR VhdPath)
{
auto lock = m_lock.lock_exclusive();
return m_attachedDisks.contains({DiskType::VHD, VhdPath});
}
GUID WslCoreVm::HandleVirtioAddGuestDevice(_In_ const GUID& Clsid, _In_ const GUID& DeviceId, _In_ PCWSTR Tag, _In_ PCWSTR Options)
{
auto guestDeviceLock = m_guestDeviceLock.lock_exclusive();
return AddHdvShareWithOptions(DeviceId, Clsid, Tag, {}, Options, 0, m_userToken.get());
}
int WslCoreVm::HandleVirtioModifyOpenPorts(_In_ const GUID& Clsid, _In_ PCWSTR Tag, _In_ const SOCKADDR_INET& Addr, _In_ int Protocol, _In_ bool IsOpen)
{
if (Protocol != IPPROTO_TCP && Protocol != IPPROTO_UDP)
{
LOG_HR_MSG(HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED), "Unsupported bind protocol %d", Protocol);
return 0;
}
else if (Addr.si_family == AF_INET6)
{
// The virtio net adapter does not yet support IPv6 packets, so any traffic would arrive via
// IPv4. If the caller wants IPv4 they will also likely listen on an IPv4 address, which will
// be handled as a separate callback to this same code.
return 0;
}
auto guestDeviceLock = m_guestDeviceLock.lock_exclusive();
const auto server = m_deviceHostSupport->GetRemoteFileSystem(Clsid, c_defaultTag);
if (server)
{
std::wstring portString = std::format(L"tag={};port_number={}", Tag, Addr.Ipv4.sin_port);
if (Protocol == IPPROTO_UDP)
{
portString += L";udp";
}
if (!IsOpen)
{
portString += L";allocate=false";
}
else
{
wchar_t addrStr[16]; // "000.000.000.000" + null terminator
RtlIpv4AddressToStringW(&Addr.Ipv4.sin_addr, addrStr);
portString += std::format(L";listen_addr={}", addrStr);
}
LOG_IF_FAILED(server->AddShare(portString.c_str(), nullptr, 0));
}
return 0;
}
WslCoreVm::DiskMountResult WslCoreVm::MountDisk(
_In_ PCWSTR Disk, _In_ DiskType MountDiskType, _In_ ULONG PartitionIndex, _In_opt_ PCWSTR Name, _In_opt_ PCWSTR Type, _In_opt_ PCWSTR Options)
{
@ -2846,12 +2860,6 @@ LX_INIT_DRVFS_MOUNT WslCoreVm::s_InitializeDrvFs(_Inout_ WslCoreVm* VmContext, _
}
}
bool WslCoreVm::IsVhdAttached(_In_ PCWSTR VhdPath)
{
auto lock = m_lock.lock_exclusive();
return m_attachedDisks.contains({DiskType::VHD, VhdPath});
}
void CALLBACK WslCoreVm::s_OnExit(_In_ HCS_EVENT* Event, _In_opt_ void* Context)
try
{

View File

@ -107,6 +107,10 @@ public:
bool IsVhdAttached(_In_ PCWSTR VhdPath);
GUID HandleVirtioAddGuestDevice(_In_ const GUID& Clsid, _In_ const GUID& DeviceId, _In_ PCWSTR Tag, _In_ PCWSTR Options);
int HandleVirtioModifyOpenPorts(_In_ const GUID& Clsid, _In_ PCWSTR Tag, _In_ const SOCKADDR_INET& Addr, _In_ int Protocol, _In_ bool IsOpen);
DiskMountResult MountDisk(
_In_ PCWSTR Disk, _In_ DiskType MountDiskType, _In_ ULONG PartitionIndex, _In_opt_ PCWSTR Name, _In_opt_ PCWSTR Type, _In_opt_ PCWSTR Options);