wsla: do not leak fd's to user processes (#13406)

This commit is contained in:
Blue 2025-08-18 16:34:57 -07:00 committed by Blue
parent 631ee32cc5
commit 5949c06db2
2 changed files with 68 additions and 13 deletions

View File

@ -43,6 +43,8 @@ int MountInit(const char* Target);
extern int EnableInterface(int Socket, const char* Name);
extern int SetCloseOnExec(int Fd, bool Enable);
extern int g_LogFd;
void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const LSW_GET_DISK& Message, const gsl::span<gsl::byte>& Buffer)
@ -67,13 +69,14 @@ void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const LSW_GET_DISK&
void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const LSW_ACCEPT& Message, const gsl::span<gsl::byte>& Buffer)
{
sockaddr_vm SocketAddress{};
wil::unique_fd ListenSocket{UtilListenVsockAnyPort(&SocketAddress, 1, false)};
wil::unique_fd ListenSocket{UtilListenVsockAnyPort(&SocketAddress, 1, true)};
THROW_LAST_ERROR_IF(!ListenSocket);
Channel.SendResultMessage<uint32_t>(SocketAddress.svm_port);
wil::unique_fd Socket{UtilAcceptVsock(ListenSocket.get(), SocketAddress, SESSION_LEADER_ACCEPT_TIMEOUT_MS)};
THROW_LAST_ERROR_IF(!Socket);
LOG_ERROR("Socket fd: {} -> {}", Socket.get(), Message.Fd);
THROW_LAST_ERROR_IF(dup2(Socket.get(), Message.Fd) < 0);
}
@ -644,13 +647,15 @@ int LswEntryPoint(int Argc, char* Argv[])
// Ensure /dev/console is present and set as the controlling terminal.
// If opening /dev/console times out, stdout and stderr to the logging file descriptor.
//
wil::unique_fd ConsoleFd{};
try
{
wsl::shared::retry::RetryWithTimeout<void>(
[&]() {
ConsoleFd = open("/dev/console", O_RDWR);
ConsoleFd = open("/dev/console", O_RDWR | O_CLOEXEC);
THROW_LAST_ERROR_IF(!ConsoleFd);
},
c_defaultRetryPeriod,
@ -660,12 +665,12 @@ int LswEntryPoint(int Argc, char* Argv[])
}
catch (...)
{
if (dup2(g_LogFd, STDOUT_FILENO) < 0)
if (dup3(g_LogFd, STDOUT_FILENO, O_CLOEXEC) < 0)
{
LOG_ERROR("dup2 failed {}", errno);
}
if (dup2(g_LogFd, STDERR_FILENO) < 0)
if (dup3(g_LogFd, STDERR_FILENO, O_CLOEXEC) < 0)
{
LOG_ERROR("dup2 failed {}", errno);
}
@ -701,16 +706,27 @@ int LswEntryPoint(int Argc, char* Argv[])
// Enable the loopback interface.
//
wil::unique_fd Fd{socket(AF_INET, SOCK_DGRAM, IPPROTO_IP)};
if (!Fd)
{
LOG_ERROR("socket failed {}", errno);
return -1;
wil::unique_fd Fd{socket(AF_INET, SOCK_DGRAM, IPPROTO_IP)};
if (!Fd)
{
LOG_ERROR("socket failed {}", errno);
return -1;
}
if (EnableInterface(Fd.get(), "lo") < 0)
{
return -1;
}
}
if (EnableInterface(Fd.get(), "lo") < 0)
//
// Make sure not to leak std fds to user processes.
//
for (int fd : {STDIN_FILENO, STDOUT_FILENO, STDERR_FILENO})
{
return -1;
SetCloseOnExec(fd, true);
}
//

View File

@ -641,7 +641,7 @@ class LSWTests
{
auto [fds, pid] = createProcess({"/bin/cat"}, {{0, LinuxFileOutput, "/tmp/output"}, {2, Default, nullptr}});
VERIFY_ARE_EQUAL(ReadToString((SOCKET)fds[1].get()), "/bin/cat: -: Bad file descriptor\n");
VERIFY_ARE_EQUAL(ReadToString((SOCKET)fds[1].get()), "/bin/cat: standard output: Bad file descriptor\n");
VERIFY_ARE_EQUAL(wait(pid), 1);
}
}
@ -795,7 +795,7 @@ class LSWTests
settings.DisplayName = L"LSW";
settings.Memory.MemoryMb = 2048;
settings.Options.BootTimeoutMs = 30 * 1000;
settings.Networking.Mode = NetworkingModeNAT;
settings.Networking.Mode = NetworkingModeNone;
auto vm = CreateVm(&settings);
@ -824,7 +824,7 @@ class LSWTests
settings.DisplayName = L"LSW";
settings.Memory.MemoryMb = 2048;
settings.Options.BootTimeoutMs = 30 * 1000;
settings.Networking.Mode = NetworkingModeNAT;
settings.Networking.Mode = NetworkingModeNone;
auto vm = CreateVm(&settings);
@ -897,4 +897,43 @@ class LSWTests
VERIFY_SUCCEEDED(WslUnmountWindowsFolder(vm.get(), "/win-path"));
}
}
// This test case validates that no file descriptors are leaked to user processes.
TEST_METHOD(Fd)
{
WSL2_TEST_ONLY();
VirtualMachineSettings settings{};
settings.CPU.CpuCount = 4;
settings.DisplayName = L"LSW";
settings.Memory.MemoryMb = 2048;
settings.Options.BootTimeoutMs = 30 * 1000;
settings.Networking.Mode = NetworkingModeNone;
auto vm = CreateVm(&settings);
std::vector<ProcessFileDescriptorSettings> fds(1);
fds[0].Number = 1;
fds[0].Type = Default;
const char* args[] = {"/bin/bash", "-c", "echo /proc/self/fd/* && readlink /proc/self/fd/*", nullptr};
CreateProcessSettings createProcessSettings{};
createProcessSettings.Executable = "/bin/bash";
createProcessSettings.Arguments = args;
createProcessSettings.FileDescriptors = fds.data();
createProcessSettings.FdCount = 1;
int pid = -1;
VERIFY_SUCCEEDED(WslCreateLinuxProcess(vm.get(), &createProcessSettings, &pid));
wil::unique_socket output{(SOCKET)fds[0].Handle};
auto result = ReadToString(output.get());
// Note: fd/0 is opened readlink to read the actual content of /proc/fd.
if (!PathMatchSpecA(result.c_str(), "/proc/self/fd/0 /proc/self/fd/1\nsocket:[*]\n"))
{
LogInfo("Found additional fds: %hs", result.c_str());
VERIFY_FAIL();
}
}
};