From 2abbe8baf38b684b99ec864c848a98c81209548a Mon Sep 17 00:00:00 2001 From: Blue Date: Fri, 31 Oct 2025 18:49:25 -0700 Subject: [PATCH] =?UTF-8?q?wsla:=20Update=20the=20CreateProcess=20logic=20?= =?UTF-8?q?to=20support=20allocating=20file=20descr=E2=80=A6=20(#13655)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * wsla: Update the CreateProcess logic to support allocating file descriptors from the guest * PR feedback --- src/linux/init/WSLAInit.cpp | 13 +++- src/linux/init/util.cpp | 6 +- src/linux/init/util.h | 2 +- .../wslaservice/exe/WSLAVirtualMachine.cpp | 77 +++++++++++++------ .../wslaservice/exe/WSLAVirtualMachine.h | 20 ++++- 5 files changed, 87 insertions(+), 31 deletions(-) diff --git a/src/linux/init/WSLAInit.cpp b/src/linux/init/WSLAInit.cpp index f839ae2..a990bab 100644 --- a/src/linux/init/WSLAInit.cpp +++ b/src/linux/init/WSLAInit.cpp @@ -85,10 +85,19 @@ void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_ACCEPT& M Channel.SendResultMessage(SocketAddress.svm_port); - wil::unique_fd Socket{UtilAcceptVsock(ListenSocket.get(), SocketAddress, SESSION_LEADER_ACCEPT_TIMEOUT_MS)}; + wil::unique_fd Socket{ + UtilAcceptVsock(ListenSocket.get(), SocketAddress, SESSION_LEADER_ACCEPT_TIMEOUT_MS, Message.Fd != -1 ? SOCK_CLOEXEC : 0)}; THROW_LAST_ERROR_IF(!Socket); - THROW_LAST_ERROR_IF(dup2(Socket.get(), Message.Fd) < 0); + if (Message.Fd != -1) + { + THROW_LAST_ERROR_IF(dup2(Socket.get(), Message.Fd) < 0); + } + else + { + Channel.SendResultMessage(Socket.get()); + Socket.release(); + } } void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_CONNECT& Message, const gsl::span& Buffer) diff --git a/src/linux/init/util.cpp b/src/linux/init/util.cpp index 6e35c39..2eb1780 100644 --- a/src/linux/init/util.cpp +++ b/src/linux/init/util.cpp @@ -197,7 +197,7 @@ InteropServer::~InteropServer() Reset(); } -int UtilAcceptVsock(int SocketFd, sockaddr_vm SocketAddress, int Timeout) +int UtilAcceptVsock(int SocketFd, sockaddr_vm SocketAddress, int Timeout, int SocketFlags) /*++ @@ -215,6 +215,8 @@ Arguments: Timeout - Supplies a timeout. + SocketFlags - Supplies the socket flags. + Return Value: A file descriptor representing the socket, -1 on failure. @@ -263,7 +265,7 @@ Return Value: if (Result != -1) { socklen_t SocketAddressSize = sizeof(SocketAddress); - Result = accept4(SocketFd, reinterpret_cast(&SocketAddress), &SocketAddressSize, SOCK_CLOEXEC); + Result = accept4(SocketFd, reinterpret_cast(&SocketAddress), &SocketAddressSize, SocketFlags); } if (Result < 0) diff --git a/src/linux/init/util.h b/src/linux/init/util.h index 9c68c66..718ea39 100644 --- a/src/linux/init/util.h +++ b/src/linux/init/util.h @@ -117,7 +117,7 @@ private: wil::unique_fd m_InteropSocket; }; -int UtilAcceptVsock(int SocketFd, sockaddr_vm Address, int Timeout = -1); +int UtilAcceptVsock(int SocketFd, sockaddr_vm Address, int Timeout = -1, int SocketFlags = SOCK_CLOEXEC); int UtilBindVsockAnyPort(struct sockaddr_vm* SocketAddress, int Type); diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp index 2800e2d..bce7bad 100644 --- a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp @@ -328,30 +328,45 @@ void WSLAVirtualMachine::ConfigureNetworking() // WSLA-TODO: Using fd=4 here seems to hang gns. There's probably a hardcoded file descriptor somewhere that's causing // so using 1000 for now. std::vector fds(1); - fds[0].Fd = 1000; + fds[0].Fd = -1; fds[0].Type = WslFdType::WslFdTypeDefault; - std::vector cmd{"/gns", LX_INIT_GNS_SOCKET_ARG, "1000"}; + std::vector cmd{"/gns", LX_INIT_GNS_SOCKET_ARG}; // If DNS tunnelling is enabled, use an additional for its channel. if (m_settings.EnableDnsTunneling) { - fds.emplace_back(WSLA_PROCESS_FD{.Fd = 1001, .Type = WslFdType::WslFdTypeDefault}); - cmd.emplace_back(LX_INIT_GNS_DNS_SOCKET_ARG); - cmd.emplace_back("1001"); - cmd.emplace_back(LX_INIT_GNS_DNS_TUNNELING_IP); - cmd.emplace_back(LX_INIT_DNS_TUNNELING_IP_ADDRESS); - + fds.emplace_back(WSLA_PROCESS_FD{.Fd = -1, .Type = WslFdType::WslFdTypeDefault}); THROW_IF_FAILED(wsl::core::networking::DnsResolver::LoadDnsResolverMethods()); } WSLA_CREATE_PROCESS_OPTIONS options{}; options.Executable = "/init"; - options.CommandLine = cmd.data(); - options.CommandLineCount = static_cast(cmd.size()); + + // Because the file descriptors numbers aren't known in advance, the command line needs to be generated after the file + // descriptors are allocated. + + std::string socketFdArg; + std::string dnsFdArg; + auto prepareCommandLine = [&](const auto& sockets) { + socketFdArg = std::to_string(sockets[0].Fd); + cmd.emplace_back(socketFdArg.c_str()); + + if (sockets.size() > 1) + { + dnsFdArg = std::to_string(sockets[1].Fd); + cmd.emplace_back(LX_INIT_GNS_DNS_SOCKET_ARG); + cmd.emplace_back(dnsFdArg.c_str()); + cmd.emplace_back(LX_INIT_GNS_DNS_TUNNELING_IP); + cmd.emplace_back(LX_INIT_DNS_TUNNELING_IP_ADDRESS); + } + + options.CommandLine = cmd.data(); + options.CommandLineCount = static_cast(cmd.size()); + }; WSLA_CREATE_PROCESS_RESULT result{}; - auto sockets = CreateLinuxProcessImpl(&options, static_cast(fds.size()), fds.data(), &result); + auto sockets = CreateLinuxProcessImpl(&options, static_cast(fds.size()), fds.data(), &result, prepareCommandLine); THROW_HR_IF(E_FAIL, result.Errno != 0); @@ -364,13 +379,12 @@ void WSLAVirtualMachine::ConfigureNetworking() config.FirewallConfig.reset(); }*/ - // TODO: DNS Tunneling support m_networkEngine = std::make_unique( m_computeSystem.get(), wsl::core::NatNetworking::CreateNetwork(config), - std::move(sockets[0]), + std::move(sockets[0].Socket), config, - sockets.size() > 1 ? std::move(sockets[1]) : wil::unique_socket{}); + sockets.size() > 1 ? std::move(sockets[1].Socket) : wil::unique_socket{}); m_networkEngine->Initialize(); @@ -587,13 +601,26 @@ std::tuple WSLAVirtualMachine::For return std::make_tuple(pid, ptyMaster, wsl::shared::SocketChannel{std::move(socket), std::to_string(pid), m_vmTerminatingEvent.get()}); } -wil::unique_socket WSLAVirtualMachine::ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd) +WSLAVirtualMachine::ConnectedSocket WSLAVirtualMachine::ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd) { WSLA_ACCEPT message{}; message.Fd = Fd; const auto& response = Channel.Transaction(message); + ConnectedSocket socket; - return wsl::windows::common::hvsocket::Connect(m_vmId, response.Result); + socket.Socket = wsl::windows::common::hvsocket::Connect(m_vmId, response.Result); + + // If the FD was unspecified, read the Linux file descriptor from the guest. + if (Fd == -1) + { + socket.Fd = Channel.ReceiveMessage>().Result; + } + else + { + socket.Fd = Fd; + } + + return socket; } void WSLAVirtualMachine::OpenLinuxFile(wsl::shared::SocketChannel& Channel, const char* Path, uint32_t Flags, int32_t Fd) @@ -621,10 +648,10 @@ try for (size_t i = 0; i < sockets.size(); i++) { - if (sockets[i]) + if (sockets[i].Socket) { - Handles[i] = - HandleToUlong(wsl::windows::common::wslutil::DuplicateHandleToCallingProcess(reinterpret_cast(sockets[i].get()))); + Handles[i] = HandleToUlong( + wsl::windows::common::wslutil::DuplicateHandleToCallingProcess(reinterpret_cast(sockets[i].Socket.get()))); } } @@ -632,8 +659,12 @@ try } CATCH_RETURN(); -std::vector WSLAVirtualMachine::CreateLinuxProcessImpl( - _In_ const WSLA_CREATE_PROCESS_OPTIONS* Options, _In_ ULONG FdCount, _In_ WSLA_PROCESS_FD* Fds, _Out_ WSLA_CREATE_PROCESS_RESULT* Result) +std::vector WSLAVirtualMachine::CreateLinuxProcessImpl( + _In_ const WSLA_CREATE_PROCESS_OPTIONS* Options, + _In_ ULONG FdCount, + _In_ WSLA_PROCESS_FD* Fds, + _Out_ WSLA_CREATE_PROCESS_RESULT* Result, + const TPrepareCommandLine& PrepareCommandLine) { // Check if this is a tty or not const WSLA_PROCESS_FD* ttyInput = nullptr; @@ -642,7 +673,7 @@ std::vector WSLAVirtualMachine::CreateLinuxProcessImpl( auto interactiveTty = ParseTtyInformation(Fds, FdCount, &ttyInput, &ttyOutput, &ttyControl); auto [pid, _, childChannel] = Fork(WSLA_FORK::Process); - std::vector sockets(FdCount); + std::vector sockets(FdCount); for (size_t i = 0; i < FdCount; i++) { if (Fds[i].Type == WslFdTypeDefault || Fds[i].Type == WslFdTypeTerminalInput || Fds[i].Type == WslFdTypeTerminalOutput || @@ -664,6 +695,8 @@ std::vector WSLAVirtualMachine::CreateLinuxProcessImpl( } } + PrepareCommandLine(sockets); + wsl::shared::MessageWriter Message; Message.WriteString(Message->ExecutableIndex, Options->Executable); diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.h b/src/windows/wslaservice/exe/WSLAVirtualMachine.h index 33016ef..9f64afa 100644 --- a/src/windows/wslaservice/exe/WSLAVirtualMachine.h +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.h @@ -50,6 +50,14 @@ public: IFACEMETHOD(MountGpuLibraries(_In_ LPCSTR LibrariesMountPoint, _In_ LPCSTR DriversMountpoint, _In_ DWORD Flags)) override; private: + struct ConnectedSocket + { + int Fd; + wil::unique_socket Socket; + }; + + using TPrepareCommandLine = std::function&)>; + static int32_t MountImpl(wsl::shared::SocketChannel& Channel, LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags); static void CALLBACK s_OnExit(_In_ HCS_EVENT* Event, _In_opt_ void* Context); static bool ParseTtyInformation( @@ -62,12 +70,16 @@ private: std::tuple Fork(wsl::shared::SocketChannel& Channel, enum WSLA_FORK::ForkType Type); int32_t ExpectClosedChannelOrError(wsl::shared::SocketChannel& Channel); - wil::unique_socket ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd); - void OpenLinuxFile(wsl::shared::SocketChannel& Channel, const char* Path, uint32_t Flags, int32_t Fd); + ConnectedSocket ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd); + static void OpenLinuxFile(wsl::shared::SocketChannel& Channel, const char* Path, uint32_t Flags, int32_t Fd); void LaunchPortRelay(); - std::vector CreateLinuxProcessImpl( - _In_ const WSLA_CREATE_PROCESS_OPTIONS* Options, _In_ ULONG FdCount, _In_ WSLA_PROCESS_FD* Fd, _Out_ WSLA_CREATE_PROCESS_RESULT* Result); + std::vector CreateLinuxProcessImpl( + _In_ const WSLA_CREATE_PROCESS_OPTIONS* Options, + _In_ ULONG FdCount, + _In_ WSLA_PROCESS_FD* Fd, + _Out_ WSLA_CREATE_PROCESS_RESULT* Result, + const TPrepareCommandLine& PrepareCommandLine = [](const auto&) {}); HRESULT MountWindowsFolderImpl(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly, _In_ WslMountFlags Flags);