diff --git a/src/windows/service/exe/VirtioNetworking.cpp b/src/windows/service/exe/VirtioNetworking.cpp index 80766d7..1d48b03 100644 --- a/src/windows/service/exe/VirtioNetworking.cpp +++ b/src/windows/service/exe/VirtioNetworking.cpp @@ -12,34 +12,23 @@ using wsl::core::VirtioNetworking; static constexpr auto c_loopbackDeviceName = TEXT(LX_INIT_LOOPBACK_DEVICE_NAME); -VirtioNetworking::VirtioNetworking(GnsChannel&& gnsChannel, const Config& config) : - m_gnsChannel(std::move(gnsChannel)), m_config(config) +VirtioNetworking::VirtioNetworking( + GnsChannel&& gnsChannel, + bool enableLocalhostRelay, + AddGuestDeviceCallback addGuestDeviceCallback, + ModifyOpenPortsCallback modifyOpenPortsCallback, + GuestInterfaceStateChangeCallback guestInterfaceStateChangeCallback) : + m_addGuestDeviceCallback(std::move(addGuestDeviceCallback)), + m_gnsChannel(std::move(gnsChannel)), + m_modifyOpenPortsCallback(std::move(modifyOpenPortsCallback)), + m_guestInterfaceStateChangeCallback(std::move(guestInterfaceStateChangeCallback)), + m_enableLocalhostRelay(enableLocalhostRelay) { } -VirtioNetworking& VirtioNetworking::OnAddGuestDevice(const AddGuestDeviceRoutine& addGuestDeviceRoutine) -{ - m_addGuestDeviceRoutine = addGuestDeviceRoutine; - return *this; -} - -VirtioNetworking& VirtioNetworking::OnModifyOpenPorts(const ModifyOpenPortsCallback& modifyOpenPortsCallback) -{ - m_modifyOpenPortsCallback = modifyOpenPortsCallback; - return *this; -} - -VirtioNetworking& VirtioNetworking::OnGuestInterfaceStateChanged(const GuestInterfaceStateChangeCallback& guestInterfaceStateChangedCallback) -{ - m_guestInterfaceStateChangeCallback = guestInterfaceStateChangedCallback; - return *this; -} - void VirtioNetworking::Initialize() try { - THROW_HR_IF(E_NOT_SET, !m_addGuestDeviceRoutine || !m_modifyOpenPortsCallback || !m_guestInterfaceStateChangeCallback); - m_networkSettings = GetHostEndpointSettings(); // TODO: Determine gateway MAC address @@ -84,7 +73,7 @@ try } // Add virtio net adapter to guest - m_adapterId = (*m_addGuestDeviceRoutine)(c_virtioNetworkClsid, c_virtioNetworkDeviceId, L"eth0", device_options.str().c_str()); + m_adapterId = m_addGuestDeviceCallback(c_virtioNetworkClsid, c_virtioNetworkDeviceId, L"eth0", device_options.str().c_str()); auto lock = m_lock.lock_exclusive(); @@ -121,7 +110,7 @@ try UpdateDns(std::move(dnsSettings)); } - if (m_config.EnableLocalhostRelay) + if (m_enableLocalhostRelay) { SetupLoopbackDevice(); } @@ -132,7 +121,7 @@ CATCH_LOG() void VirtioNetworking::SetupLoopbackDevice() { - m_localhostAdapterId = (*m_addGuestDeviceRoutine)( + m_localhostAdapterId = m_addGuestDeviceCallback( c_virtioNetworkClsid, c_virtioNetworkDeviceId, c_loopbackDeviceName, L"client_ip=127.0.0.1;client_mac=00:11:22:33:44:55"); hns::HNSEndpoint endpointProperties; @@ -162,7 +151,7 @@ void VirtioNetworking::StartPortTracker(wil::unique_socket&& socket) m_gnsPortTrackerChannel.emplace( std::move(socket), [&](const SOCKADDR_INET& addr, int protocol, bool allocate) { return HandlePortNotification(addr, protocol, allocate); }, - [&](_In_ const std::string& interfaceName, _In_ bool up) { (*m_guestInterfaceStateChangeCallback)(interfaceName, up); }); + [&](_In_ const std::string& interfaceName, _In_ bool up) { m_guestInterfaceStateChangeCallback(interfaceName, up); }); } HRESULT VirtioNetworking::HandlePortNotification(const SOCKADDR_INET& addr, int protocol, bool allocate) const noexcept @@ -181,7 +170,7 @@ HRESULT VirtioNetworking::HandlePortNotification(const SOCKADDR_INET& addr, int } } - if (m_config.EnableLocalhostRelay && (unspecified || loopback)) + if (m_enableLocalhostRelay && (unspecified || loopback)) { SOCKADDR_INET localAddr = addr; if (!loopback) @@ -196,12 +185,12 @@ HRESULT VirtioNetworking::HandlePortNotification(const SOCKADDR_INET& addr, int localAddr.Ipv6.sin6_port = addr.Ipv6.sin6_port; } } - result = (*m_modifyOpenPortsCallback)(c_virtioNetworkClsid, c_loopbackDeviceName, localAddr, protocol, allocate); + result = m_modifyOpenPortsCallback(c_virtioNetworkClsid, c_loopbackDeviceName, localAddr, protocol, allocate); LOG_HR_IF_MSG(E_FAIL, result != S_OK, "Failure adding localhost relay port %d", localAddr.Ipv4.sin_port); } if (!loopback) { - const int localResult = (*m_modifyOpenPortsCallback)(c_virtioNetworkClsid, L"eth0", addr, protocol, allocate); + const int localResult = m_modifyOpenPortsCallback(c_virtioNetworkClsid, L"eth0", addr, protocol, allocate); LOG_HR_IF_MSG(E_FAIL, localResult != S_OK, "Failure adding relay port %d", addr.Ipv4.sin_port); if (result == 0) { diff --git a/src/windows/service/exe/VirtioNetworking.h b/src/windows/service/exe/VirtioNetworking.h index 617195e..decae80 100644 --- a/src/windows/service/exe/VirtioNetworking.h +++ b/src/windows/service/exe/VirtioNetworking.h @@ -9,20 +9,21 @@ namespace wsl::core { -using AddGuestDeviceRoutine = std::function; +using AddGuestDeviceCallback = std::function; using ModifyOpenPortsCallback = std::function; using GuestInterfaceStateChangeCallback = std::function; class VirtioNetworking : public INetworkingEngine { public: - VirtioNetworking(GnsChannel&& gnsChannel, const Config& config); + VirtioNetworking( + GnsChannel&& gnsChannel, + bool enableLocalhostRelay, + AddGuestDeviceCallback addGuestDeviceCallback, + ModifyOpenPortsCallback modifyOpenPortsCallback, + GuestInterfaceStateChangeCallback guestInterfaceStateChangeCallback); ~VirtioNetworking() = default; - VirtioNetworking& OnAddGuestDevice(const AddGuestDeviceRoutine& addGuestDeviceRoutine); - VirtioNetworking& OnModifyOpenPorts(const ModifyOpenPortsCallback& modifyOpenPortsCallback); - VirtioNetworking& OnGuestInterfaceStateChanged(const GuestInterfaceStateChangeCallback& guestInterfaceStateChangedCallback); - // Note: This class cannot be moved because m_networkNotifyHandle captures a 'this' pointer. VirtioNetworking(const VirtioNetworking&) = delete; VirtioNetworking(VirtioNetworking&&) = delete; @@ -49,17 +50,17 @@ private: mutable wil::srwlock m_lock; - std::optional m_addGuestDeviceRoutine; + AddGuestDeviceCallback m_addGuestDeviceCallback; GnsChannel m_gnsChannel; std::optional m_gnsPortTrackerChannel; std::shared_ptr m_networkSettings; - const Config& m_config; + bool m_enableLocalhostRelay; GUID m_localhostAdapterId; GUID m_adapterId; std::optional m_connectivityLevel; std::optional m_connectivityCost; - std::optional m_modifyOpenPortsCallback; - std::optional m_guestInterfaceStateChangeCallback; + ModifyOpenPortsCallback m_modifyOpenPortsCallback; + GuestInterfaceStateChangeCallback m_guestInterfaceStateChangeCallback; std::optional m_interfaceLuid; ULONG m_networkMtu = 0; diff --git a/src/windows/service/exe/WslCoreVm.cpp b/src/windows/service/exe/WslCoreVm.cpp index d235078..f9a401c 100644 --- a/src/windows/service/exe/WslCoreVm.cpp +++ b/src/windows/service/exe/WslCoreVm.cpp @@ -607,55 +607,16 @@ void WslCoreVm::Initialize(const GUID& VmId, const wil::shared_handle& UserToken } else if (m_vmConfig.NetworkingMode == NetworkingMode::VirtioProxy) { - auto virtioNetworkingEngine = std::make_unique(std::move(gnsChannel), m_vmConfig); - virtioNetworkingEngine->OnAddGuestDevice([&](const GUID& Clsid, const GUID& DeviceId, PCWSTR Tag, PCWSTR Options) { - auto guestDeviceLock = m_guestDeviceLock.lock_exclusive(); - return AddHdvShareWithOptions(DeviceId, Clsid, Tag, {}, Options, 0, m_userToken.get()); - }); - - virtioNetworkingEngine->OnModifyOpenPorts([&](const GUID& Clsid, PCWSTR Tag, const SOCKADDR_INET& addr, int protocol, bool isOpen) { - if (protocol != IPPROTO_TCP && protocol != IPPROTO_UDP) - { - LOG_HR_MSG(HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED), "Unsupported bind protocol %d", protocol); - return 0; - } - else if (addr.si_family == AF_INET6) - { - // The virtio net adapter does not yet support IPv6 packets, so any traffic would arrive via - // IPv4. If the caller wants IPv4 they will also likely listen on an IPv4 address, which will - // be handled as a separate callback to this same code. - return 0; - } - - auto guestDeviceLock = m_guestDeviceLock.lock_exclusive(); - const auto server = m_deviceHostSupport->GetRemoteFileSystem(Clsid, c_defaultTag); - if (server) - { - std::wstring portString(L"tag="); - portString += Tag; - portString += L";port_number="; - portString += std::to_wstring(addr.Ipv4.sin_port); - if (protocol == IPPROTO_UDP) - { - portString += L";udp"; - } - if (!isOpen) - { - portString += L";allocate=false"; - } - else - { - std::wstring addrStr(L"000.000.000.000\0"); - RtlIpv4AddressToStringW(&addr.Ipv4.sin_addr, addrStr.data()); - portString += L";listen_addr="; - portString += addrStr; - } - LOG_IF_FAILED(server->AddShare(portString.c_str(), nullptr, 0)); - } - return 0; - }); - virtioNetworkingEngine->OnGuestInterfaceStateChanged([&](const std::string& name, bool isUp) {}); - m_networkingEngine.reset(virtioNetworkingEngine.release()); + m_networkingEngine = std::make_unique( + std::move(gnsChannel), + m_vmConfig.EnableLocalhostRelay, + [this](const GUID& Clsid, const GUID& DeviceId, PCWSTR Tag, PCWSTR Options) { + return HandleVirtioAddGuestDevice(Clsid, DeviceId, Tag, Options); + }, + [this](const GUID& Clsid, PCWSTR Tag, const SOCKADDR_INET& Addr, int Protocol, bool IsOpen) { + return HandleVirtioModifyOpenPorts(Clsid, Tag, Addr, Protocol, IsOpen); + }, + [](const std::string&, bool) {}); } else if (m_vmConfig.NetworkingMode == NetworkingMode::Bridged) { @@ -2037,6 +1998,59 @@ bool WslCoreVm::IsDnsTunnelingSupported() const return SUCCEEDED_LOG(wsl::core::networking::DnsResolver::LoadDnsResolverMethods()); } +bool WslCoreVm::IsVhdAttached(_In_ PCWSTR VhdPath) +{ + auto lock = m_lock.lock_exclusive(); + return m_attachedDisks.contains({DiskType::VHD, VhdPath}); +} + +GUID WslCoreVm::HandleVirtioAddGuestDevice(_In_ const GUID& Clsid, _In_ const GUID& DeviceId, _In_ PCWSTR Tag, _In_ PCWSTR Options) +{ + auto guestDeviceLock = m_guestDeviceLock.lock_exclusive(); + return AddHdvShareWithOptions(DeviceId, Clsid, Tag, {}, Options, 0, m_userToken.get()); +} + +int WslCoreVm::HandleVirtioModifyOpenPorts(_In_ const GUID& Clsid, _In_ PCWSTR Tag, _In_ const SOCKADDR_INET& Addr, _In_ int Protocol, _In_ bool IsOpen) +{ + if (Protocol != IPPROTO_TCP && Protocol != IPPROTO_UDP) + { + LOG_HR_MSG(HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED), "Unsupported bind protocol %d", Protocol); + return 0; + } + else if (Addr.si_family == AF_INET6) + { + // The virtio net adapter does not yet support IPv6 packets, so any traffic would arrive via + // IPv4. If the caller wants IPv4 they will also likely listen on an IPv4 address, which will + // be handled as a separate callback to this same code. + return 0; + } + + auto guestDeviceLock = m_guestDeviceLock.lock_exclusive(); + const auto server = m_deviceHostSupport->GetRemoteFileSystem(Clsid, c_defaultTag); + if (server) + { + std::wstring portString = std::format(L"tag={};port_number={}", Tag, Addr.Ipv4.sin_port); + if (Protocol == IPPROTO_UDP) + { + portString += L";udp"; + } + + if (!IsOpen) + { + portString += L";allocate=false"; + } + else + { + wchar_t addrStr[16]; // "000.000.000.000" + null terminator + RtlIpv4AddressToStringW(&Addr.Ipv4.sin_addr, addrStr); + portString += std::format(L";listen_addr={}", addrStr); + } + + LOG_IF_FAILED(server->AddShare(portString.c_str(), nullptr, 0)); + } + return 0; +} + WslCoreVm::DiskMountResult WslCoreVm::MountDisk( _In_ PCWSTR Disk, _In_ DiskType MountDiskType, _In_ ULONG PartitionIndex, _In_opt_ PCWSTR Name, _In_opt_ PCWSTR Type, _In_opt_ PCWSTR Options) { @@ -2846,12 +2860,6 @@ LX_INIT_DRVFS_MOUNT WslCoreVm::s_InitializeDrvFs(_Inout_ WslCoreVm* VmContext, _ } } -bool WslCoreVm::IsVhdAttached(_In_ PCWSTR VhdPath) -{ - auto lock = m_lock.lock_exclusive(); - return m_attachedDisks.contains({DiskType::VHD, VhdPath}); -} - void CALLBACK WslCoreVm::s_OnExit(_In_ HCS_EVENT* Event, _In_opt_ void* Context) try { diff --git a/src/windows/service/exe/WslCoreVm.h b/src/windows/service/exe/WslCoreVm.h index df664e0..d04aefa 100644 --- a/src/windows/service/exe/WslCoreVm.h +++ b/src/windows/service/exe/WslCoreVm.h @@ -107,6 +107,10 @@ public: bool IsVhdAttached(_In_ PCWSTR VhdPath); + GUID HandleVirtioAddGuestDevice(_In_ const GUID& Clsid, _In_ const GUID& DeviceId, _In_ PCWSTR Tag, _In_ PCWSTR Options); + + int HandleVirtioModifyOpenPorts(_In_ const GUID& Clsid, _In_ PCWSTR Tag, _In_ const SOCKADDR_INET& Addr, _In_ int Protocol, _In_ bool IsOpen); + DiskMountResult MountDisk( _In_ PCWSTR Disk, _In_ DiskType MountDiskType, _In_ ULONG PartitionIndex, _In_opt_ PCWSTR Name, _In_opt_ PCWSTR Type, _In_opt_ PCWSTR Options);