wsla: do not leak fd's to user processes (#13406)

This commit is contained in:
Blue 2025-08-18 16:34:57 -07:00 committed by Blue
parent 631ee32cc5
commit 5949c06db2
2 changed files with 68 additions and 13 deletions

View File

@ -43,6 +43,8 @@ int MountInit(const char* Target);
extern int EnableInterface(int Socket, const char* Name); extern int EnableInterface(int Socket, const char* Name);
extern int SetCloseOnExec(int Fd, bool Enable);
extern int g_LogFd; extern int g_LogFd;
void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const LSW_GET_DISK& Message, const gsl::span<gsl::byte>& Buffer) void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const LSW_GET_DISK& Message, const gsl::span<gsl::byte>& Buffer)
@ -67,13 +69,14 @@ void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const LSW_GET_DISK&
void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const LSW_ACCEPT& Message, const gsl::span<gsl::byte>& Buffer) void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const LSW_ACCEPT& Message, const gsl::span<gsl::byte>& Buffer)
{ {
sockaddr_vm SocketAddress{}; sockaddr_vm SocketAddress{};
wil::unique_fd ListenSocket{UtilListenVsockAnyPort(&SocketAddress, 1, false)}; wil::unique_fd ListenSocket{UtilListenVsockAnyPort(&SocketAddress, 1, true)};
THROW_LAST_ERROR_IF(!ListenSocket); THROW_LAST_ERROR_IF(!ListenSocket);
Channel.SendResultMessage<uint32_t>(SocketAddress.svm_port); Channel.SendResultMessage<uint32_t>(SocketAddress.svm_port);
wil::unique_fd Socket{UtilAcceptVsock(ListenSocket.get(), SocketAddress, SESSION_LEADER_ACCEPT_TIMEOUT_MS)}; wil::unique_fd Socket{UtilAcceptVsock(ListenSocket.get(), SocketAddress, SESSION_LEADER_ACCEPT_TIMEOUT_MS)};
THROW_LAST_ERROR_IF(!Socket); THROW_LAST_ERROR_IF(!Socket);
LOG_ERROR("Socket fd: {} -> {}", Socket.get(), Message.Fd);
THROW_LAST_ERROR_IF(dup2(Socket.get(), Message.Fd) < 0); THROW_LAST_ERROR_IF(dup2(Socket.get(), Message.Fd) < 0);
} }
@ -644,13 +647,15 @@ int LswEntryPoint(int Argc, char* Argv[])
// Ensure /dev/console is present and set as the controlling terminal. // Ensure /dev/console is present and set as the controlling terminal.
// If opening /dev/console times out, stdout and stderr to the logging file descriptor. // If opening /dev/console times out, stdout and stderr to the logging file descriptor.
// //
wil::unique_fd ConsoleFd{}; wil::unique_fd ConsoleFd{};
try try
{ {
wsl::shared::retry::RetryWithTimeout<void>( wsl::shared::retry::RetryWithTimeout<void>(
[&]() { [&]() {
ConsoleFd = open("/dev/console", O_RDWR); ConsoleFd = open("/dev/console", O_RDWR | O_CLOEXEC);
THROW_LAST_ERROR_IF(!ConsoleFd); THROW_LAST_ERROR_IF(!ConsoleFd);
}, },
c_defaultRetryPeriod, c_defaultRetryPeriod,
@ -660,12 +665,12 @@ int LswEntryPoint(int Argc, char* Argv[])
} }
catch (...) catch (...)
{ {
if (dup2(g_LogFd, STDOUT_FILENO) < 0) if (dup3(g_LogFd, STDOUT_FILENO, O_CLOEXEC) < 0)
{ {
LOG_ERROR("dup2 failed {}", errno); LOG_ERROR("dup2 failed {}", errno);
} }
if (dup2(g_LogFd, STDERR_FILENO) < 0) if (dup3(g_LogFd, STDERR_FILENO, O_CLOEXEC) < 0)
{ {
LOG_ERROR("dup2 failed {}", errno); LOG_ERROR("dup2 failed {}", errno);
} }
@ -701,16 +706,27 @@ int LswEntryPoint(int Argc, char* Argv[])
// Enable the loopback interface. // Enable the loopback interface.
// //
wil::unique_fd Fd{socket(AF_INET, SOCK_DGRAM, IPPROTO_IP)};
if (!Fd)
{ {
LOG_ERROR("socket failed {}", errno); wil::unique_fd Fd{socket(AF_INET, SOCK_DGRAM, IPPROTO_IP)};
return -1; if (!Fd)
{
LOG_ERROR("socket failed {}", errno);
return -1;
}
if (EnableInterface(Fd.get(), "lo") < 0)
{
return -1;
}
} }
if (EnableInterface(Fd.get(), "lo") < 0) //
// Make sure not to leak std fds to user processes.
//
for (int fd : {STDIN_FILENO, STDOUT_FILENO, STDERR_FILENO})
{ {
return -1; SetCloseOnExec(fd, true);
} }
// //

View File

@ -641,7 +641,7 @@ class LSWTests
{ {
auto [fds, pid] = createProcess({"/bin/cat"}, {{0, LinuxFileOutput, "/tmp/output"}, {2, Default, nullptr}}); auto [fds, pid] = createProcess({"/bin/cat"}, {{0, LinuxFileOutput, "/tmp/output"}, {2, Default, nullptr}});
VERIFY_ARE_EQUAL(ReadToString((SOCKET)fds[1].get()), "/bin/cat: -: Bad file descriptor\n"); VERIFY_ARE_EQUAL(ReadToString((SOCKET)fds[1].get()), "/bin/cat: standard output: Bad file descriptor\n");
VERIFY_ARE_EQUAL(wait(pid), 1); VERIFY_ARE_EQUAL(wait(pid), 1);
} }
} }
@ -795,7 +795,7 @@ class LSWTests
settings.DisplayName = L"LSW"; settings.DisplayName = L"LSW";
settings.Memory.MemoryMb = 2048; settings.Memory.MemoryMb = 2048;
settings.Options.BootTimeoutMs = 30 * 1000; settings.Options.BootTimeoutMs = 30 * 1000;
settings.Networking.Mode = NetworkingModeNAT; settings.Networking.Mode = NetworkingModeNone;
auto vm = CreateVm(&settings); auto vm = CreateVm(&settings);
@ -824,7 +824,7 @@ class LSWTests
settings.DisplayName = L"LSW"; settings.DisplayName = L"LSW";
settings.Memory.MemoryMb = 2048; settings.Memory.MemoryMb = 2048;
settings.Options.BootTimeoutMs = 30 * 1000; settings.Options.BootTimeoutMs = 30 * 1000;
settings.Networking.Mode = NetworkingModeNAT; settings.Networking.Mode = NetworkingModeNone;
auto vm = CreateVm(&settings); auto vm = CreateVm(&settings);
@ -897,4 +897,43 @@ class LSWTests
VERIFY_SUCCEEDED(WslUnmountWindowsFolder(vm.get(), "/win-path")); VERIFY_SUCCEEDED(WslUnmountWindowsFolder(vm.get(), "/win-path"));
} }
} }
// This test case validates that no file descriptors are leaked to user processes.
TEST_METHOD(Fd)
{
WSL2_TEST_ONLY();
VirtualMachineSettings settings{};
settings.CPU.CpuCount = 4;
settings.DisplayName = L"LSW";
settings.Memory.MemoryMb = 2048;
settings.Options.BootTimeoutMs = 30 * 1000;
settings.Networking.Mode = NetworkingModeNone;
auto vm = CreateVm(&settings);
std::vector<ProcessFileDescriptorSettings> fds(1);
fds[0].Number = 1;
fds[0].Type = Default;
const char* args[] = {"/bin/bash", "-c", "echo /proc/self/fd/* && readlink /proc/self/fd/*", nullptr};
CreateProcessSettings createProcessSettings{};
createProcessSettings.Executable = "/bin/bash";
createProcessSettings.Arguments = args;
createProcessSettings.FileDescriptors = fds.data();
createProcessSettings.FdCount = 1;
int pid = -1;
VERIFY_SUCCEEDED(WslCreateLinuxProcess(vm.get(), &createProcessSettings, &pid));
wil::unique_socket output{(SOCKET)fds[0].Handle};
auto result = ReadToString(output.get());
// Note: fd/0 is opened readlink to read the actual content of /proc/fd.
if (!PathMatchSpecA(result.c_str(), "/proc/self/fd/0 /proc/self/fd/1\nsocket:[*]\n"))
{
LogInfo("Found additional fds: %hs", result.c_str());
VERIFY_FAIL();
}
}
}; };