diff --git a/src/windows/common/NatNetworking.cpp b/src/windows/common/NatNetworking.cpp index dbad2a2..1be0049 100644 --- a/src/windows/common/NatNetworking.cpp +++ b/src/windows/common/NatNetworking.cpp @@ -506,23 +506,12 @@ CATCH_LOG() void NatNetworking::UpdateMtu() { - unique_interface_table interfaceTable{}; - THROW_IF_WIN32_ERROR(::GetIpInterfaceTable(AF_UNSPEC, &interfaceTable)); - - ULONG minMtu = ULONG_MAX; - for (ULONG index = 0; index < interfaceTable.get()->NumEntries; index++) - { - const auto& ipInterface = interfaceTable.get()->Table[index]; - if (ipInterface.Connected) - { - minMtu = std::min(ipInterface.NlMtu, minMtu); - } - } + const auto minMtu = GetMinimumConnectedInterfaceMtu(); // Only send the update if the MTU changed. - if (minMtu != ULONG_MAX && minMtu != m_networkMtu) + if (minMtu && minMtu.value() != m_networkMtu) { - m_networkMtu = minMtu; + m_networkMtu = minMtu.value(); hns::ModifyGuestEndpointSettingRequest notification{}; notification.ResourceType = hns::GuestEndpointResourceType::Interface; diff --git a/src/windows/common/VirtioNetworking.cpp b/src/windows/common/VirtioNetworking.cpp index cb69244..a633696 100644 --- a/src/windows/common/VirtioNetworking.cpp +++ b/src/windows/common/VirtioNetworking.cpp @@ -22,6 +22,14 @@ VirtioNetworking::VirtioNetworking( { } +VirtioNetworking::~VirtioNetworking() +{ + // Unregister the network notification callback to prevent it from using the GNS channel. + m_networkNotifyHandle.reset(); + // Stop the GNS channel to unblock any stuck communications with the guest. + m_gnsChannel.Stop(); +} + void VirtioNetworking::Initialize() try { @@ -271,32 +279,25 @@ void VirtioNetworking::UpdateDns(hns::DNS&& dnsSettings) void VirtioNetworking::UpdateMtu() { - unique_interface_table interfaceTable{}; - THROW_IF_WIN32_ERROR(::GetIpInterfaceTable(AF_UNSPEC, &interfaceTable)); - - ULONG minMtu = ULONG_MAX; - for (ULONG index = 0; index < interfaceTable.get()->NumEntries; index++) - { - const auto& ipInterface = interfaceTable.get()->Table[index]; - if (ipInterface.Connected) - { - minMtu = std::min(ipInterface.NlMtu, minMtu); - } - } + const auto minMtu = GetMinimumConnectedInterfaceMtu(); // Only send the update if the MTU changed. - if (minMtu != ULONG_MAX && minMtu != m_networkMtu) + if (minMtu && minMtu.value() != m_networkMtu) { + m_networkMtu = minMtu.value(); + hns::ModifyGuestEndpointSettingRequest notification{}; notification.ResourceType = hns::GuestEndpointResourceType::Interface; notification.RequestType = hns::ModifyRequestType::Update; - notification.Settings.NlMtu = m_networkMtu; notification.Settings.Connected = true; + notification.Settings.NlMtu = m_networkMtu; - WSL_LOG("VirtioNetworking::UpdateMtu", TraceLoggingValue(m_networkMtu, "VirtioMtu")); + WSL_LOG( + "VirtioNetworking::UpdateMtu", + TraceLoggingValue(m_adapterId, "endpointId"), + TraceLoggingValue(m_networkMtu, "virtioMtu")); - // TODO: Why was this commented ? - // m_gnsChannel.SendHnsNotification(ToJsonW(notification).c_str(), m_endpointId); + m_gnsChannel.SendHnsNotification(ToJsonW(notification).c_str(), m_adapterId); } } diff --git a/src/windows/common/VirtioNetworking.h b/src/windows/common/VirtioNetworking.h index 629afb2..b32b545 100644 --- a/src/windows/common/VirtioNetworking.h +++ b/src/windows/common/VirtioNetworking.h @@ -14,7 +14,7 @@ class VirtioNetworking : public INetworkingEngine { public: VirtioNetworking(GnsChannel&& gnsChannel, bool enableLocalhostRelay, std::shared_ptr guestDeviceManager, wil::shared_handle userToken); - ~VirtioNetworking() = default; + ~VirtioNetworking(); // Note: This class cannot be moved because m_networkNotifyHandle captures a 'this' pointer. VirtioNetworking(const VirtioNetworking&) = delete; diff --git a/src/windows/common/WslCoreNetworkingSupport.cpp b/src/windows/common/WslCoreNetworkingSupport.cpp index a884579..d4fd0d9 100644 --- a/src/windows/common/WslCoreNetworkingSupport.cpp +++ b/src/windows/common/WslCoreNetworkingSupport.cpp @@ -248,3 +248,25 @@ wsl::core::networking::EphemeralHcnEndpoint wsl::core::networking::CreateEphemer return endpoint; } + +std::optional wsl::core::networking::GetMinimumConnectedInterfaceMtu() noexcept +{ + std::optional minMtu{}; + try + { + unique_interface_table interfaceTable{}; + THROW_IF_WIN32_ERROR(::GetIpInterfaceTable(AF_UNSPEC, &interfaceTable)); + + for (ULONG index = 0; index < interfaceTable.get()->NumEntries; index++) + { + const auto& ipInterface = interfaceTable.get()->Table[index]; + if (ipInterface.Connected) + { + minMtu = std::min(minMtu.value_or(ipInterface.NlMtu), ipInterface.NlMtu); + } + } + } + CATCH_LOG() + + return minMtu; +} diff --git a/src/windows/common/WslCoreNetworkingSupport.h b/src/windows/common/WslCoreNetworkingSupport.h index ff6ac76..f66a742 100644 --- a/src/windows/common/WslCoreNetworkingSupport.h +++ b/src/windows/common/WslCoreNetworkingSupport.h @@ -456,6 +456,11 @@ std::vector EnumerateConnect bool IsMetered(ABI::Windows::Networking::Connectivity::NetworkCostType cost) noexcept; +/// +/// Gets the minimum MTU across all connected network interfaces. +/// +std::optional GetMinimumConnectedInterfaceMtu() noexcept; + /// /// This instance acts as an IP_ADAPTER_ADDRESS pointer. ///