wsla: Update the CreateProcess logic to support allocating file descr… (#13655)

* wsla: Update the CreateProcess logic to support allocating file descriptors from the guest

* PR feedback
This commit is contained in:
Blue 2025-10-31 18:49:25 -07:00 committed by GitHub
parent 1a6cbe560b
commit 2abbe8baf3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 87 additions and 31 deletions

View File

@ -85,10 +85,19 @@ void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_ACCEPT& M
Channel.SendResultMessage<uint32_t>(SocketAddress.svm_port);
wil::unique_fd Socket{UtilAcceptVsock(ListenSocket.get(), SocketAddress, SESSION_LEADER_ACCEPT_TIMEOUT_MS)};
wil::unique_fd Socket{
UtilAcceptVsock(ListenSocket.get(), SocketAddress, SESSION_LEADER_ACCEPT_TIMEOUT_MS, Message.Fd != -1 ? SOCK_CLOEXEC : 0)};
THROW_LAST_ERROR_IF(!Socket);
THROW_LAST_ERROR_IF(dup2(Socket.get(), Message.Fd) < 0);
if (Message.Fd != -1)
{
THROW_LAST_ERROR_IF(dup2(Socket.get(), Message.Fd) < 0);
}
else
{
Channel.SendResultMessage<int32_t>(Socket.get());
Socket.release();
}
}
void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_CONNECT& Message, const gsl::span<gsl::byte>& Buffer)

View File

@ -197,7 +197,7 @@ InteropServer::~InteropServer()
Reset();
}
int UtilAcceptVsock(int SocketFd, sockaddr_vm SocketAddress, int Timeout)
int UtilAcceptVsock(int SocketFd, sockaddr_vm SocketAddress, int Timeout, int SocketFlags)
/*++
@ -215,6 +215,8 @@ Arguments:
Timeout - Supplies a timeout.
SocketFlags - Supplies the socket flags.
Return Value:
A file descriptor representing the socket, -1 on failure.
@ -263,7 +265,7 @@ Return Value:
if (Result != -1)
{
socklen_t SocketAddressSize = sizeof(SocketAddress);
Result = accept4(SocketFd, reinterpret_cast<sockaddr*>(&SocketAddress), &SocketAddressSize, SOCK_CLOEXEC);
Result = accept4(SocketFd, reinterpret_cast<sockaddr*>(&SocketAddress), &SocketAddressSize, SocketFlags);
}
if (Result < 0)

View File

@ -117,7 +117,7 @@ private:
wil::unique_fd m_InteropSocket;
};
int UtilAcceptVsock(int SocketFd, sockaddr_vm Address, int Timeout = -1);
int UtilAcceptVsock(int SocketFd, sockaddr_vm Address, int Timeout = -1, int SocketFlags = SOCK_CLOEXEC);
int UtilBindVsockAnyPort(struct sockaddr_vm* SocketAddress, int Type);

View File

@ -328,30 +328,45 @@ void WSLAVirtualMachine::ConfigureNetworking()
// WSLA-TODO: Using fd=4 here seems to hang gns. There's probably a hardcoded file descriptor somewhere that's causing
// so using 1000 for now.
std::vector<WSLA_PROCESS_FD> fds(1);
fds[0].Fd = 1000;
fds[0].Fd = -1;
fds[0].Type = WslFdType::WslFdTypeDefault;
std::vector<const char*> cmd{"/gns", LX_INIT_GNS_SOCKET_ARG, "1000"};
std::vector<const char*> cmd{"/gns", LX_INIT_GNS_SOCKET_ARG};
// If DNS tunnelling is enabled, use an additional for its channel.
if (m_settings.EnableDnsTunneling)
{
fds.emplace_back(WSLA_PROCESS_FD{.Fd = 1001, .Type = WslFdType::WslFdTypeDefault});
cmd.emplace_back(LX_INIT_GNS_DNS_SOCKET_ARG);
cmd.emplace_back("1001");
cmd.emplace_back(LX_INIT_GNS_DNS_TUNNELING_IP);
cmd.emplace_back(LX_INIT_DNS_TUNNELING_IP_ADDRESS);
fds.emplace_back(WSLA_PROCESS_FD{.Fd = -1, .Type = WslFdType::WslFdTypeDefault});
THROW_IF_FAILED(wsl::core::networking::DnsResolver::LoadDnsResolverMethods());
}
WSLA_CREATE_PROCESS_OPTIONS options{};
options.Executable = "/init";
options.CommandLine = cmd.data();
options.CommandLineCount = static_cast<ULONG>(cmd.size());
// Because the file descriptors numbers aren't known in advance, the command line needs to be generated after the file
// descriptors are allocated.
std::string socketFdArg;
std::string dnsFdArg;
auto prepareCommandLine = [&](const auto& sockets) {
socketFdArg = std::to_string(sockets[0].Fd);
cmd.emplace_back(socketFdArg.c_str());
if (sockets.size() > 1)
{
dnsFdArg = std::to_string(sockets[1].Fd);
cmd.emplace_back(LX_INIT_GNS_DNS_SOCKET_ARG);
cmd.emplace_back(dnsFdArg.c_str());
cmd.emplace_back(LX_INIT_GNS_DNS_TUNNELING_IP);
cmd.emplace_back(LX_INIT_DNS_TUNNELING_IP_ADDRESS);
}
options.CommandLine = cmd.data();
options.CommandLineCount = static_cast<DWORD>(cmd.size());
};
WSLA_CREATE_PROCESS_RESULT result{};
auto sockets = CreateLinuxProcessImpl(&options, static_cast<DWORD>(fds.size()), fds.data(), &result);
auto sockets = CreateLinuxProcessImpl(&options, static_cast<DWORD>(fds.size()), fds.data(), &result, prepareCommandLine);
THROW_HR_IF(E_FAIL, result.Errno != 0);
@ -364,13 +379,12 @@ void WSLAVirtualMachine::ConfigureNetworking()
config.FirewallConfig.reset();
}*/
// TODO: DNS Tunneling support
m_networkEngine = std::make_unique<wsl::core::NatNetworking>(
m_computeSystem.get(),
wsl::core::NatNetworking::CreateNetwork(config),
std::move(sockets[0]),
std::move(sockets[0].Socket),
config,
sockets.size() > 1 ? std::move(sockets[1]) : wil::unique_socket{});
sockets.size() > 1 ? std::move(sockets[1].Socket) : wil::unique_socket{});
m_networkEngine->Initialize();
@ -587,13 +601,26 @@ std::tuple<int32_t, int32_t, wsl::shared::SocketChannel> WSLAVirtualMachine::For
return std::make_tuple(pid, ptyMaster, wsl::shared::SocketChannel{std::move(socket), std::to_string(pid), m_vmTerminatingEvent.get()});
}
wil::unique_socket WSLAVirtualMachine::ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd)
WSLAVirtualMachine::ConnectedSocket WSLAVirtualMachine::ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd)
{
WSLA_ACCEPT message{};
message.Fd = Fd;
const auto& response = Channel.Transaction(message);
ConnectedSocket socket;
return wsl::windows::common::hvsocket::Connect(m_vmId, response.Result);
socket.Socket = wsl::windows::common::hvsocket::Connect(m_vmId, response.Result);
// If the FD was unspecified, read the Linux file descriptor from the guest.
if (Fd == -1)
{
socket.Fd = Channel.ReceiveMessage<RESULT_MESSAGE<int32_t>>().Result;
}
else
{
socket.Fd = Fd;
}
return socket;
}
void WSLAVirtualMachine::OpenLinuxFile(wsl::shared::SocketChannel& Channel, const char* Path, uint32_t Flags, int32_t Fd)
@ -621,10 +648,10 @@ try
for (size_t i = 0; i < sockets.size(); i++)
{
if (sockets[i])
if (sockets[i].Socket)
{
Handles[i] =
HandleToUlong(wsl::windows::common::wslutil::DuplicateHandleToCallingProcess(reinterpret_cast<HANDLE>(sockets[i].get())));
Handles[i] = HandleToUlong(
wsl::windows::common::wslutil::DuplicateHandleToCallingProcess(reinterpret_cast<HANDLE>(sockets[i].Socket.get())));
}
}
@ -632,8 +659,12 @@ try
}
CATCH_RETURN();
std::vector<wil::unique_socket> WSLAVirtualMachine::CreateLinuxProcessImpl(
_In_ const WSLA_CREATE_PROCESS_OPTIONS* Options, _In_ ULONG FdCount, _In_ WSLA_PROCESS_FD* Fds, _Out_ WSLA_CREATE_PROCESS_RESULT* Result)
std::vector<WSLAVirtualMachine::ConnectedSocket> WSLAVirtualMachine::CreateLinuxProcessImpl(
_In_ const WSLA_CREATE_PROCESS_OPTIONS* Options,
_In_ ULONG FdCount,
_In_ WSLA_PROCESS_FD* Fds,
_Out_ WSLA_CREATE_PROCESS_RESULT* Result,
const TPrepareCommandLine& PrepareCommandLine)
{
// Check if this is a tty or not
const WSLA_PROCESS_FD* ttyInput = nullptr;
@ -642,7 +673,7 @@ std::vector<wil::unique_socket> WSLAVirtualMachine::CreateLinuxProcessImpl(
auto interactiveTty = ParseTtyInformation(Fds, FdCount, &ttyInput, &ttyOutput, &ttyControl);
auto [pid, _, childChannel] = Fork(WSLA_FORK::Process);
std::vector<wil::unique_socket> sockets(FdCount);
std::vector<ConnectedSocket> sockets(FdCount);
for (size_t i = 0; i < FdCount; i++)
{
if (Fds[i].Type == WslFdTypeDefault || Fds[i].Type == WslFdTypeTerminalInput || Fds[i].Type == WslFdTypeTerminalOutput ||
@ -664,6 +695,8 @@ std::vector<wil::unique_socket> WSLAVirtualMachine::CreateLinuxProcessImpl(
}
}
PrepareCommandLine(sockets);
wsl::shared::MessageWriter<WSLA_EXEC> Message;
Message.WriteString(Message->ExecutableIndex, Options->Executable);

View File

@ -50,6 +50,14 @@ public:
IFACEMETHOD(MountGpuLibraries(_In_ LPCSTR LibrariesMountPoint, _In_ LPCSTR DriversMountpoint, _In_ DWORD Flags)) override;
private:
struct ConnectedSocket
{
int Fd;
wil::unique_socket Socket;
};
using TPrepareCommandLine = std::function<void(const std::vector<ConnectedSocket>&)>;
static int32_t MountImpl(wsl::shared::SocketChannel& Channel, LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags);
static void CALLBACK s_OnExit(_In_ HCS_EVENT* Event, _In_opt_ void* Context);
static bool ParseTtyInformation(
@ -62,12 +70,16 @@ private:
std::tuple<int32_t, int32_t, wsl::shared::SocketChannel> Fork(wsl::shared::SocketChannel& Channel, enum WSLA_FORK::ForkType Type);
int32_t ExpectClosedChannelOrError(wsl::shared::SocketChannel& Channel);
wil::unique_socket ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd);
void OpenLinuxFile(wsl::shared::SocketChannel& Channel, const char* Path, uint32_t Flags, int32_t Fd);
ConnectedSocket ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd);
static void OpenLinuxFile(wsl::shared::SocketChannel& Channel, const char* Path, uint32_t Flags, int32_t Fd);
void LaunchPortRelay();
std::vector<wil::unique_socket> CreateLinuxProcessImpl(
_In_ const WSLA_CREATE_PROCESS_OPTIONS* Options, _In_ ULONG FdCount, _In_ WSLA_PROCESS_FD* Fd, _Out_ WSLA_CREATE_PROCESS_RESULT* Result);
std::vector<ConnectedSocket> CreateLinuxProcessImpl(
_In_ const WSLA_CREATE_PROCESS_OPTIONS* Options,
_In_ ULONG FdCount,
_In_ WSLA_PROCESS_FD* Fd,
_Out_ WSLA_CREATE_PROCESS_RESULT* Result,
const TPrepareCommandLine& PrepareCommandLine = [](const auto&) {});
HRESULT MountWindowsFolderImpl(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly, _In_ WslMountFlags Flags);