diff --git a/src/windows/common/NatNetworking.cpp b/src/windows/common/NatNetworking.cpp index 1be00498..4c39fe74 100644 --- a/src/windows/common/NatNetworking.cpp +++ b/src/windows/common/NatNetworking.cpp @@ -438,9 +438,6 @@ try return; } - hns::ModifyGuestEndpointSettingRequest notification{}; - notification.Settings.Options = LX_INIT_RESOLVCONF_FULL_HEADER; - networking::DnsInfo latestDnsSettings{}; // true if the "domain" entry of /etc/resolv.conf should be configured @@ -475,28 +472,19 @@ try if (latestDnsSettings != m_trackedDnsSettings) { - notification.Settings.ServerList = wsl::shared::string::MultiByteToWide(wsl::shared::string::Join(latestDnsSettings.Servers, ',')); - - if (configureLinuxDomain) - { - WI_ASSERT(!latestDnsSettings.Domains.empty()); - notification.Settings.Domain = wsl::shared::string::MultiByteToWide(latestDnsSettings.Domains.front()); - } - else - { - notification.Settings.Search = wsl::shared::string::MultiByteToWide(wsl::shared::string::Join(latestDnsSettings.Domains, ',')); - } + auto dnsNotification = BuildDnsNotification(latestDnsSettings, configureLinuxDomain); WSL_LOG( "NatNetworking::UpdateDns", - TraceLoggingValue(notification.Settings.Domain.c_str(), "domain"), - TraceLoggingValue(notification.Settings.Options.c_str(), "options"), - TraceLoggingValue(notification.Settings.Search.c_str(), "search"), - TraceLoggingValue(notification.Settings.ServerList.c_str(), "serverList")); + TraceLoggingValue(dnsNotification.Domain.c_str(), "domain"), + TraceLoggingValue(dnsNotification.Options.c_str(), "options"), + TraceLoggingValue(dnsNotification.Search.c_str(), "search"), + TraceLoggingValue(dnsNotification.ServerList.c_str(), "serverList")); + hns::ModifyGuestEndpointSettingRequest notification{}; notification.RequestType = hns::ModifyRequestType::Update; notification.ResourceType = hns::GuestEndpointResourceType::DNS; - notification.Settings = notification.Settings; + notification.Settings = std::move(dnsNotification); m_gnsChannel.SendHnsNotification(ToJsonW(notification).c_str(), m_endpoint.Id); m_trackedDnsSettings = std::move(latestDnsSettings); diff --git a/src/windows/common/VirtioNetworking.cpp b/src/windows/common/VirtioNetworking.cpp index 6cfcd3fb..64d55ef1 100644 --- a/src/windows/common/VirtioNetworking.cpp +++ b/src/windows/common/VirtioNetworking.cpp @@ -65,14 +65,15 @@ void VirtioNetworking::Initialize() device_options << L"gateway_ip=" << default_route; } - auto dns_servers = m_networkSettings->DnsServersString(); - if (!dns_servers.empty()) + // Get initial DNS settings for device options. + auto initialDns = m_dnsUpdateHelper.GetCurrentDnsSettings(networking::DnsSettingsFlags::IncludeVpn); + if (!initialDns.Servers.empty()) { if (device_options.tellp() > 0) { device_options << L";"; } - device_options << L"nameservers=" << dns_servers; + device_options << L"nameservers=" << wsl::shared::string::MultiByteToWide(wsl::shared::string::Join(initialDns.Servers, ',')); } auto lock = m_lock.lock_exclusive(); @@ -104,15 +105,9 @@ void VirtioNetworking::Initialize() m_gnsChannel.SendHnsNotification(ToJsonW(request).c_str(), m_adapterId); } - // Update DNS information. - if (!dns_servers.empty()) - { - // TODO: DNS domain suffixes - hns::DNS dnsSettings{}; - dnsSettings.Options = LX_INIT_RESOLVCONF_FULL_HEADER; - dnsSettings.ServerList = dns_servers; - UpdateDns(std::move(dnsSettings)); - } + // Send the initial DNS configuration to GNS and track it. + m_trackedDnsSettings = initialDns; + SendDnsUpdate(initialDns); if (m_enableLocalhostRelay) { @@ -263,15 +258,23 @@ try { auto lock = m_lock.lock_exclusive(); UpdateMtu(); + + // Check for DNS changes and send update if needed. + auto currentDns = m_dnsUpdateHelper.GetCurrentDnsSettings(networking::DnsSettingsFlags::IncludeVpn); + if (currentDns != m_trackedDnsSettings) + { + m_trackedDnsSettings = currentDns; + SendDnsUpdate(currentDns); + } } CATCH_LOG(); -void VirtioNetworking::UpdateDns(hns::DNS&& dnsSettings) +void VirtioNetworking::SendDnsUpdate(const networking::DnsInfo& dnsSettings) { hns::ModifyGuestEndpointSettingRequest notification{}; notification.RequestType = hns::ModifyRequestType::Update; notification.ResourceType = hns::GuestEndpointResourceType::DNS; - notification.Settings = std::move(dnsSettings); + notification.Settings = networking::BuildDnsNotification(dnsSettings); m_gnsChannel.SendHnsNotification(ToJsonW(notification).c_str(), m_adapterId); } diff --git a/src/windows/common/VirtioNetworking.h b/src/windows/common/VirtioNetworking.h index 2a679f33..dc7dedf7 100644 --- a/src/windows/common/VirtioNetworking.h +++ b/src/windows/common/VirtioNetworking.h @@ -38,7 +38,7 @@ private: int ModifyOpenPorts(_In_ PCWSTR tag, _In_ const SOCKADDR_INET& addr, _In_ int protocol, _In_ bool isOpen) const; void RefreshGuestConnection(NL_NETWORK_CONNECTIVITY_HINT hint) noexcept; void SetupLoopbackDevice(); - void UpdateDns(wsl::shared::hns::DNS&& dnsSettings); + void SendDnsUpdate(const networking::DnsInfo& dnsSettings); void UpdateMtu(); mutable wil::srwlock m_lock; @@ -54,7 +54,8 @@ private: std::optional m_interfaceLuid; ULONG m_networkMtu = 0; - std::optional m_dnsInfo; + networking::DnsUpdateHelper m_dnsUpdateHelper; + networking::DnsInfo m_trackedDnsSettings; // Note: this field must be destroyed first to stop the callbacks before any other field is destroyed. networking::unique_notify_handle m_networkNotifyHandle; diff --git a/src/windows/common/WslCoreHostDnsInfo.cpp b/src/windows/common/WslCoreHostDnsInfo.cpp index cdc094c1..64de9a93 100644 --- a/src/windows/common/WslCoreHostDnsInfo.cpp +++ b/src/windows/common/WslCoreHostDnsInfo.cpp @@ -496,3 +496,29 @@ wsl::core::networking::DnsSuffixRegistryWatcher::DnsSuffixRegistryWatcher(Regist m_registryWatchers.swap(localRegistryWatchers); } + +wsl::shared::hns::DNS wsl::core::networking::BuildDnsNotification(const DnsInfo& settings, bool useLinuxDomainEntry) +{ + wsl::shared::hns::DNS dnsNotification{}; + dnsNotification.Options = LX_INIT_RESOLVCONF_FULL_HEADER; + dnsNotification.ServerList = wsl::shared::string::MultiByteToWide(wsl::shared::string::Join(settings.Servers, ',')); + + if (useLinuxDomainEntry && !settings.Domains.empty()) + { + // Use 'domain' entry for single DNS suffix (typically used when mirroring host DNS without tunneling) + dnsNotification.Domain = wsl::shared::string::MultiByteToWide(settings.Domains.front()); + } + else + { + // Use 'search' entry for DNS suffix list + dnsNotification.Search = wsl::shared::string::MultiByteToWide(wsl::shared::string::Join(settings.Domains, ',')); + } + + return dnsNotification; +} + +wsl::core::networking::DnsInfo wsl::core::networking::DnsUpdateHelper::GetCurrentDnsSettings(DnsSettingsFlags flags) +{ + m_hostDnsInfo.UpdateNetworkInformation(); + return m_hostDnsInfo.GetDnsSettings(flags); +} diff --git a/src/windows/common/WslCoreHostDnsInfo.h b/src/windows/common/WslCoreHostDnsInfo.h index 1813574f..fe2cab07 100644 --- a/src/windows/common/WslCoreHostDnsInfo.h +++ b/src/windows/common/WslCoreHostDnsInfo.h @@ -38,6 +38,14 @@ inline bool operator!=(const DnsInfo& lhs, const DnsInfo& rhs) noexcept std::string GenerateResolvConf(_In_ const DnsInfo& Info); +/// +/// Builds an hns::DNS notification from DnsInfo settings. +/// +/// The DNS settings to convert +/// If true, uses 'domain' entry for single suffix; otherwise uses 'search' for all +/// suffixes The hns::DNS notification ready to send via GNS channel +wsl::shared::hns::DNS BuildDnsNotification(const DnsInfo& settings, bool useLinuxDomainEntry = false); + std::vector GetAllDnsSuffixes(const std::vector& AdapterAddresses); DWORD GetBestInterface(); @@ -84,6 +92,24 @@ private: _Guarded_by_(m_lock) std::vector m_addresses; }; +/// +/// Helper class that fetches current DNS settings from the host. +/// Callers are responsible for tracking changes if needed. +/// +class DnsUpdateHelper +{ +public: + /// + /// Fetches current DNS settings from the host. + /// + /// Flags controlling which DNS settings to include + /// Current DNS settings + DnsInfo GetCurrentDnsSettings(DnsSettingsFlags flags); + +private: + HostDnsInfo m_hostDnsInfo; +}; + using RegistryChangeCallback = std::function; /// diff --git a/src/windows/common/WslCoreNetworkEndpointSettings.cpp b/src/windows/common/WslCoreNetworkEndpointSettings.cpp index 92108a57..d6600f39 100644 --- a/src/windows/common/WslCoreNetworkEndpointSettings.cpp +++ b/src/windows/common/WslCoreNetworkEndpointSettings.cpp @@ -26,17 +26,13 @@ std::shared_ptr wsl::core::networking::G address, route, properties.MacAddress, - L"unuseddevicename", properties.InterfaceConstraint.InterfaceIndex, - properties.InterfaceConstraint.InterfaceMediaType, - properties.DNSServerList); + properties.InterfaceConstraint.InterfaceMediaType); } std::shared_ptr wsl::core::networking::GetHostEndpointSettings() { - HostDnsInfo dnsInfo; - dnsInfo.UpdateNetworkInformation(); - auto addresses = dnsInfo.CurrentAddresses(); + auto addresses = AdapterAddresses::GetCurrent(); auto bestIndex = GetBestInterface(); auto bestInterfacePtr = std::find_if(addresses.cbegin(), addresses.cend(), [&](const auto& address) { return address->IfIndex == bestIndex; }); @@ -95,16 +91,6 @@ std::shared_ptr wsl::core::networking::G route.NextHopString = windows::common::string::SockAddrInetToWstring(route.NextHop); } - std::wstring dnsServerList; - for (const auto& serverAddress : dnsInfo.GetDnsSettings(DnsSettingsFlags::IncludeVpn).Servers) - { - if (!dnsServerList.empty()) - { - dnsServerList += L","; - } - dnsServerList += wsl::shared::string::MultiByteToWide(serverAddress); - } - - return std::shared_ptr(new NetworkSettings( - bestInterface->NetworkGuid, address, route, macAddress, {}, bestInterface->IfIndex, bestInterface->IfType, dnsServerList)); + return std::make_shared( + bestInterface->NetworkGuid, address, route, macAddress, bestInterface->IfIndex, bestInterface->IfType); } diff --git a/src/windows/common/WslCoreNetworkEndpointSettings.h b/src/windows/common/WslCoreNetworkEndpointSettings.h index 5e165c9e..405aa16c 100644 --- a/src/windows/common/WslCoreNetworkEndpointSettings.h +++ b/src/windows/common/WslCoreNetworkEndpointSettings.h @@ -258,33 +258,21 @@ struct NetworkSettings { NetworkSettings() = default; - NetworkSettings( - const GUID& interfaceGuid, - EndpointIpAddress preferredIpAddress, - EndpointRoute gateway, - std::wstring macAddress, - std::wstring deviceName, - uint32_t interfaceIndex, - uint32_t mediaType, - const std::wstring& dnsServerList) : + NetworkSettings(const GUID& interfaceGuid, EndpointIpAddress preferredIpAddress, EndpointRoute gateway, std::wstring macAddress, uint32_t interfaceIndex, uint32_t mediaType) : InterfaceGuid(interfaceGuid), PreferredIpAddress(std::move(preferredIpAddress)), MacAddress(std::move(macAddress)), - DeviceName(std::move(deviceName)), InterfaceIndex(interfaceIndex), InterfaceType(mediaType) { Routes.emplace(std::move(gateway)); - DnsServers = wsl::shared::string::Split(dnsServerList, L','); } GUID InterfaceGuid{}; EndpointIpAddress PreferredIpAddress{}; std::set IpAddresses{}; // Does not include PreferredIpAddress. std::set Routes{}; - std::vector DnsServers{}; std::wstring MacAddress; - std::wstring DeviceName; IF_INDEX InterfaceIndex = 0; IFTYPE InterfaceType = 0; ULONG IPv4InterfaceMtu = 0; @@ -344,11 +332,6 @@ struct NetworkSettings }); } - std::wstring DnsServersString() const - { - return wsl::shared::string::Join(DnsServers, L','); - } - // will return ULONG_MAX if there's no configured MTU ULONG GetEffectiveMtu() const noexcept { @@ -386,7 +369,6 @@ std::shared_ptr GetHostEndpointSettings(); TraceLoggingValue((settings)->PreferredIpAddress.PrefixLength, "preferredIpAddressPrefixLength"), \ TraceLoggingValue((settings)->IpAddressesString().c_str(), "ipAddresses"), \ TraceLoggingValue((settings)->RoutesString().c_str(), "routes"), \ - TraceLoggingValue((settings)->DnsServersString().c_str(), "dnsServerList"), \ TraceLoggingValue((settings)->MacAddress.c_str(), "macAddress"), \ TraceLoggingValue((settings)->IPv4InterfaceMtu, "IPv4InterfaceMtu"), \ TraceLoggingValue((settings)->IPv6InterfaceMtu, "IPv6InterfaceMtu"), \ diff --git a/src/windows/service/exe/WslMirroredNetworking.cpp b/src/windows/service/exe/WslMirroredNetworking.cpp index f7bd0e17..0d3bb361 100644 --- a/src/windows/service/exe/WslMirroredNetworking.cpp +++ b/src/windows/service/exe/WslMirroredNetworking.cpp @@ -18,6 +18,7 @@ Abstract: #include "Stringify.h" #include "WslCoreNetworkingSupport.h" #include "WslCoreNetworkEndpointSettings.h" +#include "WslCoreHostDnsInfo.h" #include "hcs.hpp" #include "hns_schema.h" @@ -895,17 +896,6 @@ try } CATCH_RETURN() -static hns::DNS ConvertDnsInfoToHnsSettingsMsg(const wsl::core::networking::DnsInfo& dnsInfo) -{ - hns::DNS dnsSettings{}; - dnsSettings.Options = LX_INIT_RESOLVCONF_FULL_HEADER; - - dnsSettings.ServerList = wsl::shared::string::MultiByteToWide(wsl::shared::string::Join(dnsInfo.Servers, ',')); - dnsSettings.Search = wsl::shared::string::MultiByteToWide(wsl::shared::string::Join(dnsInfo.Domains, ',')); - - return dnsSettings; -} - _Requires_lock_held_(m_networkLock) _Check_return_ HRESULT wsl::core::networking::WslMirroredNetworkManager::SendDnsRequestToGns( const NetworkEndpoint& endpoint, const DnsInfo& dnsInfo, hns::ModifyRequestType requestType) noexcept @@ -915,7 +905,7 @@ try modifyRequest.ResourceType = hns::GuestEndpointResourceType::DNS; modifyRequest.RequestType = requestType; modifyRequest.targetDeviceName = wsl::shared::string::GuidToString(endpoint.InterfaceGuid); - modifyRequest.Settings = ConvertDnsInfoToHnsSettingsMsg(dnsInfo); + modifyRequest.Settings = BuildDnsNotification(dnsInfo); WSL_LOG( "WslMirroredNetworkManager::SendDnsRequestToGns", @@ -1982,7 +1972,6 @@ void wsl::core::networking::WslMirroredNetworkManager::AddEndpointImpl(EndpointT THROW_IF_FAILED(hr); endpointTrackingObject.m_networkEndpoint.Network->MacAddress = endpointTrackingObject.m_hnsEndpoint.MacAddress; - endpointTrackingObject.m_networkEndpoint.Network->DeviceName = endpointTrackingObject.m_hnsEndpoint.PortFriendlyName; if (IsInterfaceIndexOfGelnic(endpointTrackingObject.m_networkEndpoint.Network->InterfaceIndex)) {