wsla: Update the CreateProcess logic to support allocating file descr… (#13655)

* wsla: Update the CreateProcess logic to support allocating file descriptors from the guest

* PR feedback
This commit is contained in:
Blue 2025-10-31 18:49:25 -07:00 committed by GitHub
parent 1a6cbe560b
commit 2abbe8baf3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 87 additions and 31 deletions

View File

@ -85,11 +85,20 @@ void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_ACCEPT& M
Channel.SendResultMessage<uint32_t>(SocketAddress.svm_port); Channel.SendResultMessage<uint32_t>(SocketAddress.svm_port);
wil::unique_fd Socket{UtilAcceptVsock(ListenSocket.get(), SocketAddress, SESSION_LEADER_ACCEPT_TIMEOUT_MS)}; wil::unique_fd Socket{
UtilAcceptVsock(ListenSocket.get(), SocketAddress, SESSION_LEADER_ACCEPT_TIMEOUT_MS, Message.Fd != -1 ? SOCK_CLOEXEC : 0)};
THROW_LAST_ERROR_IF(!Socket); THROW_LAST_ERROR_IF(!Socket);
if (Message.Fd != -1)
{
THROW_LAST_ERROR_IF(dup2(Socket.get(), Message.Fd) < 0); THROW_LAST_ERROR_IF(dup2(Socket.get(), Message.Fd) < 0);
} }
else
{
Channel.SendResultMessage<int32_t>(Socket.get());
Socket.release();
}
}
void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_CONNECT& Message, const gsl::span<gsl::byte>& Buffer) void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_CONNECT& Message, const gsl::span<gsl::byte>& Buffer)
{ {

View File

@ -197,7 +197,7 @@ InteropServer::~InteropServer()
Reset(); Reset();
} }
int UtilAcceptVsock(int SocketFd, sockaddr_vm SocketAddress, int Timeout) int UtilAcceptVsock(int SocketFd, sockaddr_vm SocketAddress, int Timeout, int SocketFlags)
/*++ /*++
@ -215,6 +215,8 @@ Arguments:
Timeout - Supplies a timeout. Timeout - Supplies a timeout.
SocketFlags - Supplies the socket flags.
Return Value: Return Value:
A file descriptor representing the socket, -1 on failure. A file descriptor representing the socket, -1 on failure.
@ -263,7 +265,7 @@ Return Value:
if (Result != -1) if (Result != -1)
{ {
socklen_t SocketAddressSize = sizeof(SocketAddress); socklen_t SocketAddressSize = sizeof(SocketAddress);
Result = accept4(SocketFd, reinterpret_cast<sockaddr*>(&SocketAddress), &SocketAddressSize, SOCK_CLOEXEC); Result = accept4(SocketFd, reinterpret_cast<sockaddr*>(&SocketAddress), &SocketAddressSize, SocketFlags);
} }
if (Result < 0) if (Result < 0)

View File

@ -117,7 +117,7 @@ private:
wil::unique_fd m_InteropSocket; wil::unique_fd m_InteropSocket;
}; };
int UtilAcceptVsock(int SocketFd, sockaddr_vm Address, int Timeout = -1); int UtilAcceptVsock(int SocketFd, sockaddr_vm Address, int Timeout = -1, int SocketFlags = SOCK_CLOEXEC);
int UtilBindVsockAnyPort(struct sockaddr_vm* SocketAddress, int Type); int UtilBindVsockAnyPort(struct sockaddr_vm* SocketAddress, int Type);

View File

@ -328,30 +328,45 @@ void WSLAVirtualMachine::ConfigureNetworking()
// WSLA-TODO: Using fd=4 here seems to hang gns. There's probably a hardcoded file descriptor somewhere that's causing // WSLA-TODO: Using fd=4 here seems to hang gns. There's probably a hardcoded file descriptor somewhere that's causing
// so using 1000 for now. // so using 1000 for now.
std::vector<WSLA_PROCESS_FD> fds(1); std::vector<WSLA_PROCESS_FD> fds(1);
fds[0].Fd = 1000; fds[0].Fd = -1;
fds[0].Type = WslFdType::WslFdTypeDefault; fds[0].Type = WslFdType::WslFdTypeDefault;
std::vector<const char*> cmd{"/gns", LX_INIT_GNS_SOCKET_ARG, "1000"}; std::vector<const char*> cmd{"/gns", LX_INIT_GNS_SOCKET_ARG};
// If DNS tunnelling is enabled, use an additional for its channel. // If DNS tunnelling is enabled, use an additional for its channel.
if (m_settings.EnableDnsTunneling) if (m_settings.EnableDnsTunneling)
{ {
fds.emplace_back(WSLA_PROCESS_FD{.Fd = 1001, .Type = WslFdType::WslFdTypeDefault}); fds.emplace_back(WSLA_PROCESS_FD{.Fd = -1, .Type = WslFdType::WslFdTypeDefault});
cmd.emplace_back(LX_INIT_GNS_DNS_SOCKET_ARG);
cmd.emplace_back("1001");
cmd.emplace_back(LX_INIT_GNS_DNS_TUNNELING_IP);
cmd.emplace_back(LX_INIT_DNS_TUNNELING_IP_ADDRESS);
THROW_IF_FAILED(wsl::core::networking::DnsResolver::LoadDnsResolverMethods()); THROW_IF_FAILED(wsl::core::networking::DnsResolver::LoadDnsResolverMethods());
} }
WSLA_CREATE_PROCESS_OPTIONS options{}; WSLA_CREATE_PROCESS_OPTIONS options{};
options.Executable = "/init"; options.Executable = "/init";
// Because the file descriptors numbers aren't known in advance, the command line needs to be generated after the file
// descriptors are allocated.
std::string socketFdArg;
std::string dnsFdArg;
auto prepareCommandLine = [&](const auto& sockets) {
socketFdArg = std::to_string(sockets[0].Fd);
cmd.emplace_back(socketFdArg.c_str());
if (sockets.size() > 1)
{
dnsFdArg = std::to_string(sockets[1].Fd);
cmd.emplace_back(LX_INIT_GNS_DNS_SOCKET_ARG);
cmd.emplace_back(dnsFdArg.c_str());
cmd.emplace_back(LX_INIT_GNS_DNS_TUNNELING_IP);
cmd.emplace_back(LX_INIT_DNS_TUNNELING_IP_ADDRESS);
}
options.CommandLine = cmd.data(); options.CommandLine = cmd.data();
options.CommandLineCount = static_cast<ULONG>(cmd.size()); options.CommandLineCount = static_cast<DWORD>(cmd.size());
};
WSLA_CREATE_PROCESS_RESULT result{}; WSLA_CREATE_PROCESS_RESULT result{};
auto sockets = CreateLinuxProcessImpl(&options, static_cast<DWORD>(fds.size()), fds.data(), &result); auto sockets = CreateLinuxProcessImpl(&options, static_cast<DWORD>(fds.size()), fds.data(), &result, prepareCommandLine);
THROW_HR_IF(E_FAIL, result.Errno != 0); THROW_HR_IF(E_FAIL, result.Errno != 0);
@ -364,13 +379,12 @@ void WSLAVirtualMachine::ConfigureNetworking()
config.FirewallConfig.reset(); config.FirewallConfig.reset();
}*/ }*/
// TODO: DNS Tunneling support
m_networkEngine = std::make_unique<wsl::core::NatNetworking>( m_networkEngine = std::make_unique<wsl::core::NatNetworking>(
m_computeSystem.get(), m_computeSystem.get(),
wsl::core::NatNetworking::CreateNetwork(config), wsl::core::NatNetworking::CreateNetwork(config),
std::move(sockets[0]), std::move(sockets[0].Socket),
config, config,
sockets.size() > 1 ? std::move(sockets[1]) : wil::unique_socket{}); sockets.size() > 1 ? std::move(sockets[1].Socket) : wil::unique_socket{});
m_networkEngine->Initialize(); m_networkEngine->Initialize();
@ -587,13 +601,26 @@ std::tuple<int32_t, int32_t, wsl::shared::SocketChannel> WSLAVirtualMachine::For
return std::make_tuple(pid, ptyMaster, wsl::shared::SocketChannel{std::move(socket), std::to_string(pid), m_vmTerminatingEvent.get()}); return std::make_tuple(pid, ptyMaster, wsl::shared::SocketChannel{std::move(socket), std::to_string(pid), m_vmTerminatingEvent.get()});
} }
wil::unique_socket WSLAVirtualMachine::ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd) WSLAVirtualMachine::ConnectedSocket WSLAVirtualMachine::ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd)
{ {
WSLA_ACCEPT message{}; WSLA_ACCEPT message{};
message.Fd = Fd; message.Fd = Fd;
const auto& response = Channel.Transaction(message); const auto& response = Channel.Transaction(message);
ConnectedSocket socket;
return wsl::windows::common::hvsocket::Connect(m_vmId, response.Result); socket.Socket = wsl::windows::common::hvsocket::Connect(m_vmId, response.Result);
// If the FD was unspecified, read the Linux file descriptor from the guest.
if (Fd == -1)
{
socket.Fd = Channel.ReceiveMessage<RESULT_MESSAGE<int32_t>>().Result;
}
else
{
socket.Fd = Fd;
}
return socket;
} }
void WSLAVirtualMachine::OpenLinuxFile(wsl::shared::SocketChannel& Channel, const char* Path, uint32_t Flags, int32_t Fd) void WSLAVirtualMachine::OpenLinuxFile(wsl::shared::SocketChannel& Channel, const char* Path, uint32_t Flags, int32_t Fd)
@ -621,10 +648,10 @@ try
for (size_t i = 0; i < sockets.size(); i++) for (size_t i = 0; i < sockets.size(); i++)
{ {
if (sockets[i]) if (sockets[i].Socket)
{ {
Handles[i] = Handles[i] = HandleToUlong(
HandleToUlong(wsl::windows::common::wslutil::DuplicateHandleToCallingProcess(reinterpret_cast<HANDLE>(sockets[i].get()))); wsl::windows::common::wslutil::DuplicateHandleToCallingProcess(reinterpret_cast<HANDLE>(sockets[i].Socket.get())));
} }
} }
@ -632,8 +659,12 @@ try
} }
CATCH_RETURN(); CATCH_RETURN();
std::vector<wil::unique_socket> WSLAVirtualMachine::CreateLinuxProcessImpl( std::vector<WSLAVirtualMachine::ConnectedSocket> WSLAVirtualMachine::CreateLinuxProcessImpl(
_In_ const WSLA_CREATE_PROCESS_OPTIONS* Options, _In_ ULONG FdCount, _In_ WSLA_PROCESS_FD* Fds, _Out_ WSLA_CREATE_PROCESS_RESULT* Result) _In_ const WSLA_CREATE_PROCESS_OPTIONS* Options,
_In_ ULONG FdCount,
_In_ WSLA_PROCESS_FD* Fds,
_Out_ WSLA_CREATE_PROCESS_RESULT* Result,
const TPrepareCommandLine& PrepareCommandLine)
{ {
// Check if this is a tty or not // Check if this is a tty or not
const WSLA_PROCESS_FD* ttyInput = nullptr; const WSLA_PROCESS_FD* ttyInput = nullptr;
@ -642,7 +673,7 @@ std::vector<wil::unique_socket> WSLAVirtualMachine::CreateLinuxProcessImpl(
auto interactiveTty = ParseTtyInformation(Fds, FdCount, &ttyInput, &ttyOutput, &ttyControl); auto interactiveTty = ParseTtyInformation(Fds, FdCount, &ttyInput, &ttyOutput, &ttyControl);
auto [pid, _, childChannel] = Fork(WSLA_FORK::Process); auto [pid, _, childChannel] = Fork(WSLA_FORK::Process);
std::vector<wil::unique_socket> sockets(FdCount); std::vector<ConnectedSocket> sockets(FdCount);
for (size_t i = 0; i < FdCount; i++) for (size_t i = 0; i < FdCount; i++)
{ {
if (Fds[i].Type == WslFdTypeDefault || Fds[i].Type == WslFdTypeTerminalInput || Fds[i].Type == WslFdTypeTerminalOutput || if (Fds[i].Type == WslFdTypeDefault || Fds[i].Type == WslFdTypeTerminalInput || Fds[i].Type == WslFdTypeTerminalOutput ||
@ -664,6 +695,8 @@ std::vector<wil::unique_socket> WSLAVirtualMachine::CreateLinuxProcessImpl(
} }
} }
PrepareCommandLine(sockets);
wsl::shared::MessageWriter<WSLA_EXEC> Message; wsl::shared::MessageWriter<WSLA_EXEC> Message;
Message.WriteString(Message->ExecutableIndex, Options->Executable); Message.WriteString(Message->ExecutableIndex, Options->Executable);

View File

@ -50,6 +50,14 @@ public:
IFACEMETHOD(MountGpuLibraries(_In_ LPCSTR LibrariesMountPoint, _In_ LPCSTR DriversMountpoint, _In_ DWORD Flags)) override; IFACEMETHOD(MountGpuLibraries(_In_ LPCSTR LibrariesMountPoint, _In_ LPCSTR DriversMountpoint, _In_ DWORD Flags)) override;
private: private:
struct ConnectedSocket
{
int Fd;
wil::unique_socket Socket;
};
using TPrepareCommandLine = std::function<void(const std::vector<ConnectedSocket>&)>;
static int32_t MountImpl(wsl::shared::SocketChannel& Channel, LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags); static int32_t MountImpl(wsl::shared::SocketChannel& Channel, LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags);
static void CALLBACK s_OnExit(_In_ HCS_EVENT* Event, _In_opt_ void* Context); static void CALLBACK s_OnExit(_In_ HCS_EVENT* Event, _In_opt_ void* Context);
static bool ParseTtyInformation( static bool ParseTtyInformation(
@ -62,12 +70,16 @@ private:
std::tuple<int32_t, int32_t, wsl::shared::SocketChannel> Fork(wsl::shared::SocketChannel& Channel, enum WSLA_FORK::ForkType Type); std::tuple<int32_t, int32_t, wsl::shared::SocketChannel> Fork(wsl::shared::SocketChannel& Channel, enum WSLA_FORK::ForkType Type);
int32_t ExpectClosedChannelOrError(wsl::shared::SocketChannel& Channel); int32_t ExpectClosedChannelOrError(wsl::shared::SocketChannel& Channel);
wil::unique_socket ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd); ConnectedSocket ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd);
void OpenLinuxFile(wsl::shared::SocketChannel& Channel, const char* Path, uint32_t Flags, int32_t Fd); static void OpenLinuxFile(wsl::shared::SocketChannel& Channel, const char* Path, uint32_t Flags, int32_t Fd);
void LaunchPortRelay(); void LaunchPortRelay();
std::vector<wil::unique_socket> CreateLinuxProcessImpl( std::vector<ConnectedSocket> CreateLinuxProcessImpl(
_In_ const WSLA_CREATE_PROCESS_OPTIONS* Options, _In_ ULONG FdCount, _In_ WSLA_PROCESS_FD* Fd, _Out_ WSLA_CREATE_PROCESS_RESULT* Result); _In_ const WSLA_CREATE_PROCESS_OPTIONS* Options,
_In_ ULONG FdCount,
_In_ WSLA_PROCESS_FD* Fd,
_Out_ WSLA_CREATE_PROCESS_RESULT* Result,
const TPrepareCommandLine& PrepareCommandLine = [](const auto&) {});
HRESULT MountWindowsFolderImpl(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly, _In_ WslMountFlags Flags); HRESULT MountWindowsFolderImpl(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly, _In_ WslMountFlags Flags);