Merge pull request #13764 from microsoft/user/benhill/wsla_virtio_proxy

WSLA: Implement virtioproxy networking mode
This commit is contained in:
Ben Hillis 2025-12-01 16:10:09 -08:00 committed by GitHub
commit 2ab516edf4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 117 additions and 52 deletions

View File

@ -1557,6 +1557,7 @@ int WslaShell(_In_ std::wstring_view commandLine)
parser.AddArgument(reinterpret_cast<bool&>(settings.EnableDnsTunneling), L"--dns-tunneling"); parser.AddArgument(reinterpret_cast<bool&>(settings.EnableDnsTunneling), L"--dns-tunneling");
parser.AddArgument(Integer(settings.MemoryMb), L"--memory"); parser.AddArgument(Integer(settings.MemoryMb), L"--memory");
parser.AddArgument(Integer(settings.CpuCount), L"--cpu"); parser.AddArgument(Integer(settings.CpuCount), L"--cpu");
parser.AddArgument(Integer(reinterpret_cast<int&>(settings.NetworkingMode)), L"--networking-mode");
parser.AddArgument(Utf8String(fsType), L"--fstype"); parser.AddArgument(Utf8String(fsType), L"--fstype");
parser.AddArgument(containerRootVhd, L"--container-vhd"); parser.AddArgument(containerRootVhd, L"--container-vhd");
parser.AddArgument(help, L"--help"); parser.AddArgument(help, L"--help");
@ -1565,13 +1566,23 @@ int WslaShell(_In_ std::wstring_view commandLine)
if (help) if (help)
{ {
const auto usage = std::format( const auto usage = std::format(
LR"({} --wsla [--vhd </path/to/vhd>] [--shell </path/to/shell>] [--memory <memory-mb>] [--cpu <cpus>] [--dns-tunneling] [--fstype <fstype>] [--container-vhd </path/to/vhd>] [--help])", LR"({} --wsla [--vhd </path/to/vhd>] [--shell </path/to/shell>] [--memory <memory-mb>] [--cpu <cpus>] [--dns-tunneling] [--networking-mode <mode>] [--fstype <fstype>] [--container-vhd </path/to/vhd>] [--help])",
WSL_BINARY_NAME); WSL_BINARY_NAME);
wprintf(L"%ls\n", usage.c_str()); wprintf(L"%ls\n", usage.c_str());
return 1; return 1;
} }
switch (settings.NetworkingMode)
{
case WSLANetworkingMode::WSLANetworkingModeNone:
case WSLANetworkingMode::WSLANetworkingModeNAT:
case WSLANetworkingMode::WSLANetworkingModeVirtioProxy:
break;
default:
THROW_HR(E_INVALIDARG);
}
if (!containerRootVhd.empty()) if (!containerRootVhd.empty())
{ {
settings.ContainerRootVhd = containerRootVhd.c_str(); settings.ContainerRootVhd = containerRootVhd.c_str();

View File

@ -16,9 +16,9 @@ Abstract:
#include <format> #include <format>
#include <filesystem> #include <filesystem>
#include "hcs_schema.h" #include "hcs_schema.h"
#include "VirtioNetworking.h"
#include "NatNetworking.h" #include "NatNetworking.h"
#include "WSLAUserSession.h" #include "WSLAUserSession.h"
#include "DnsResolver.h"
#include "ServiceProcessLauncher.h" #include "ServiceProcessLauncher.h"
using namespace wsl::windows::common; using namespace wsl::windows::common;
@ -94,6 +94,12 @@ WSLAVirtualMachine::~WSLAVirtualMachine()
WSL_LOG("WSLATerminateVm", TraceLoggingValue(forceTerminate, "forced"), TraceLoggingValue(m_running, "running")); WSL_LOG("WSLATerminateVm", TraceLoggingValue(forceTerminate, "forced"), TraceLoggingValue(m_running, "running"));
// Shutdown DeviceHostProxy before resetting compute system
if (m_guestDeviceManager)
{
m_guestDeviceManager->Shutdown();
}
m_computeSystem.reset(); m_computeSystem.reset();
for (const auto& e : m_attachedDisks) for (const auto& e : m_attachedDisks)
@ -308,6 +314,13 @@ void WSLAVirtualMachine::Start()
auto runtimeId = wsl::windows::common::hcs::GetRuntimeId(m_computeSystem.get()); auto runtimeId = wsl::windows::common::hcs::GetRuntimeId(m_computeSystem.get());
WI_ASSERT(IsEqualGUID(m_vmId, runtimeId)); WI_ASSERT(IsEqualGUID(m_vmId, runtimeId));
// Initialize DeviceHostProxy for virtio device support.
// N.B. This is currently only needed for VirtioProxy networking mode but would also be needed for virtiofs.
if (m_settings.NetworkingMode == WSLANetworkingModeVirtioProxy)
{
m_guestDeviceManager = std::make_shared<GuestDeviceManager>(m_vmIdString, m_vmId);
}
wsl::windows::common::hcs::RegisterCallback(m_computeSystem.get(), &s_OnExit, this); wsl::windows::common::hcs::RegisterCallback(m_computeSystem.get(), &s_OnExit, this);
wsl::windows::common::hcs::StartComputeSystem(m_computeSystem.get(), json.c_str()); wsl::windows::common::hcs::StartComputeSystem(m_computeSystem.get(), json.c_str());
@ -445,59 +458,71 @@ CATCH_LOG();
void WSLAVirtualMachine::ConfigureNetworking() void WSLAVirtualMachine::ConfigureNetworking()
{ {
if (m_settings.NetworkingMode == WSLANetworkingModeNone) switch (m_settings.NetworkingMode)
{ {
case WSLANetworkingModeNone:
return; return;
case WSLANetworkingModeNAT:
case WSLANetworkingModeVirtioProxy:
break;
default:
THROW_HR_MSG(E_INVALIDARG, "Invalid networking mode: %lu", m_settings.NetworkingMode);
} }
else if (m_settings.NetworkingMode == WSLANetworkingModeNAT)
// Launch GNS
std::vector<WSLA_PROCESS_FD> fds(1);
fds[0].Fd = -1;
fds[0].Type = WSLAFdType::WSLAFdTypeDefault;
std::vector<const char*> cmd{"/gns", LX_INIT_GNS_SOCKET_ARG};
// If DNS tunnelling is enabled, use an additional for its channel.
if (m_settings.EnableDnsTunneling)
{ {
// Launch GNS THROW_HR_IF_MSG(
std::vector<WSLA_PROCESS_FD> fds(1); E_NOTIMPL,
fds[0].Fd = -1; m_settings.NetworkingMode == WSLANetworkingModeVirtioProxy,
fds[0].Type = WSLAFdType::WSLAFdTypeDefault; "DNS tunneling not currently supported for VirtioProxy");
std::vector<const char*> cmd{"/gns", LX_INIT_GNS_SOCKET_ARG}; fds.emplace_back(WSLA_PROCESS_FD{.Fd = -1, .Type = WSLAFdType::WSLAFdTypeDefault});
THROW_IF_FAILED(wsl::core::networking::DnsResolver::LoadDnsResolverMethods());
}
// If DNS tunnelling is enabled, use an additional for its channel. WSLA_PROCESS_OPTIONS options{};
if (m_settings.EnableDnsTunneling) options.Executable = "/init";
options.Fds = fds.data();
options.FdsCount = static_cast<DWORD>(fds.size());
// Because the file descriptors numbers aren't known in advance, the command line needs to be generated after the file
// descriptors are allocated.
std::string socketFdArg;
std::string dnsFdArg;
int gnsChannelFd = -1;
int dnsChannelFd = -1;
auto prepareCommandLine = [&](const auto& sockets) {
gnsChannelFd = sockets[0].Fd;
socketFdArg = std::to_string(gnsChannelFd);
cmd.emplace_back(socketFdArg.c_str());
if (sockets.size() > 1)
{ {
fds.emplace_back(WSLA_PROCESS_FD{.Fd = -1, .Type = WSLAFdType::WSLAFdTypeDefault}); dnsChannelFd = sockets[1].Fd;
THROW_IF_FAILED(wsl::core::networking::DnsResolver::LoadDnsResolverMethods()); dnsFdArg = std::to_string(dnsChannelFd);
cmd.emplace_back(LX_INIT_GNS_DNS_SOCKET_ARG);
cmd.emplace_back(dnsFdArg.c_str());
cmd.emplace_back(LX_INIT_GNS_DNS_TUNNELING_IP);
cmd.emplace_back(LX_INIT_DNS_TUNNELING_IP_ADDRESS);
} }
WSLA_PROCESS_OPTIONS options{}; options.CommandLine = cmd.data();
options.Executable = "/init"; options.CommandLineCount = static_cast<DWORD>(cmd.size());
options.Fds = fds.data(); };
options.FdsCount = static_cast<DWORD>(fds.size());
// Because the file descriptors numbers aren't known in advance, the command line needs to be generated after the file auto process = CreateLinuxProcess(options, nullptr, prepareCommandLine);
// descriptors are allocated. auto gnsChannel = wsl::core::GnsChannel(wil::unique_socket{(SOCKET)process->GetStdHandle(gnsChannelFd).release()});
std::string socketFdArg;
std::string dnsFdArg;
int gnsChannelFd = -1;
int dnsChannelFd = -1;
auto prepareCommandLine = [&](const auto& sockets) {
gnsChannelFd = sockets[0].Fd;
socketFdArg = std::to_string(gnsChannelFd);
cmd.emplace_back(socketFdArg.c_str());
if (sockets.size() > 1)
{
dnsChannelFd = sockets[1].Fd;
dnsFdArg = std::to_string(dnsChannelFd);
cmd.emplace_back(LX_INIT_GNS_DNS_SOCKET_ARG);
cmd.emplace_back(dnsFdArg.c_str());
cmd.emplace_back(LX_INIT_GNS_DNS_TUNNELING_IP);
cmd.emplace_back(LX_INIT_DNS_TUNNELING_IP_ADDRESS);
}
options.CommandLine = cmd.data();
options.CommandLineCount = static_cast<DWORD>(cmd.size());
};
auto process = CreateLinuxProcess(options, nullptr, prepareCommandLine);
if (m_settings.NetworkingMode == WSLANetworkingModeNAT)
{
// TODO: refactor this to avoid using wsl config // TODO: refactor this to avoid using wsl config
static wsl::core::Config config(nullptr); static wsl::core::Config config(nullptr);
@ -510,18 +535,18 @@ void WSLAVirtualMachine::ConfigureNetworking()
m_networkEngine = std::make_unique<wsl::core::NatNetworking>( m_networkEngine = std::make_unique<wsl::core::NatNetworking>(
m_computeSystem.get(), m_computeSystem.get(),
wsl::core::NatNetworking::CreateNetwork(config), wsl::core::NatNetworking::CreateNetwork(config),
wil::unique_socket{(SOCKET)process->GetStdHandle(gnsChannelFd).release()}, std::move(gnsChannel),
config, config,
dnsChannelFd != -1 ? wil::unique_socket{(SOCKET)process->GetStdHandle(dnsChannelFd).release()} : wil::unique_socket{}); dnsChannelFd != -1 ? wil::unique_socket{(SOCKET)process->GetStdHandle(dnsChannelFd).release()} : wil::unique_socket{});
m_networkEngine->Initialize();
LaunchPortRelay();
} }
else else
{ {
THROW_HR_MSG(E_INVALIDARG, "Invalid networking mode: %lu", m_settings.NetworkingMode); m_networkEngine = std::make_unique<wsl::core::VirtioNetworking>(std::move(gnsChannel), true, m_guestDeviceManager, m_userToken);
} }
m_networkEngine->Initialize();
LaunchPortRelay();
} }
void CALLBACK WSLAVirtualMachine::s_OnExit(_In_ HCS_EVENT* Event, _In_opt_ void* Context) void CALLBACK WSLAVirtualMachine::s_OnExit(_In_ HCS_EVENT* Event, _In_opt_ void* Context)

View File

@ -16,6 +16,8 @@ Abstract:
#include "INetworkingEngine.h" #include "INetworkingEngine.h"
#include "hcs.hpp" #include "hcs.hpp"
#include "Dmesg.h" #include "Dmesg.h"
#include "DnsResolver.h"
#include "GuestDeviceManager.h"
#include "WSLAApi.h" #include "WSLAApi.h"
#include "WSLAProcess.h" #include "WSLAProcess.h"
@ -120,7 +122,7 @@ private:
int m_coldDiscardShiftSize{}; int m_coldDiscardShiftSize{};
bool m_running = false; bool m_running = false;
PSID m_userSid{}; PSID m_userSid{};
wil::unique_handle m_userToken; wil::shared_handle m_userToken;
std::wstring m_debugShellPipe; std::wstring m_debugShellPipe;
std::mutex m_trackedProcessesLock; std::mutex m_trackedProcessesLock;
@ -133,6 +135,7 @@ private:
bool m_vmSavedStateCaptured = false; bool m_vmSavedStateCaptured = false;
bool m_crashLogCaptured = false; bool m_crashLogCaptured = false;
std::shared_ptr<GuestDeviceManager> m_guestDeviceManager;
std::shared_ptr<DmesgCollector> m_dmesgCollector; std::shared_ptr<DmesgCollector> m_dmesgCollector;
wil::unique_event m_vmExitEvent{wil::EventOptions::ManualReset}; wil::unique_event m_vmExitEvent{wil::EventOptions::ManualReset};
wil::unique_event m_vmTerminatingEvent{wil::EventOptions::ManualReset}; wil::unique_event m_vmTerminatingEvent{wil::EventOptions::ManualReset};

View File

@ -209,7 +209,8 @@ interface IWSLAVirtualMachine : IUnknown
typedef enum _WSLANetworkingMode typedef enum _WSLANetworkingMode
{ {
WSLANetworkingModeNone, WSLANetworkingModeNone,
WSLANetworkingModeNAT WSLANetworkingModeNAT,
WSLANetworkingModeVirtioProxy
} WSLANetworkingMode; } WSLANetworkingMode;
typedef typedef

View File

@ -380,6 +380,31 @@ class WSLATests
VERIFY_ARE_EQUAL(result.Output[1], std::format("nameserver {}\n", LX_INIT_DNS_TUNNELING_IP_ADDRESS)); VERIFY_ARE_EQUAL(result.Output[1], std::format("nameserver {}\n", LX_INIT_DNS_TUNNELING_IP_ADDRESS));
} }
TEST_METHOD(VirtioProxyNetworking)
{
WSL2_TEST_ONLY();
VIRTUAL_MACHINE_SETTINGS settings{};
settings.CpuCount = 4;
settings.DisplayName = L"WSLA";
settings.MemoryMb = 2048;
settings.BootTimeoutMs = 30 * 1000;
settings.NetworkingMode = WSLANetworkingModeVirtioProxy;
settings.RootVhd = testVhd.c_str();
auto session = CreateSession(settings);
// Validate that eth0 has an ip address
ExpectCommandResult(
session.get(),
{"/bin/bash",
"-c",
"ip a show dev eth0 | grep -iF 'inet ' | grep -E '[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}'"},
0);
ExpectCommandResult(session.get(), {"/bin/grep", "-iF", "nameserver", "/etc/resolv.conf"}, 0);
}
TEST_METHOD(OpenFiles) TEST_METHOD(OpenFiles)
{ {
WSL2_TEST_ONLY(); WSL2_TEST_ONLY();