Implement WSL for apps TCP port forwarding (#13299)

* Merge User/oneblue/prototype lsw to a feature/wsl-for-apps (#13278)

* Save state

* Save state

* Save state

* Get the VM booting

* VM booting

* Disk mounting

* CreateLinuxProcess

* Move to a proper API

* Implement env

* Progress on fd

* Redesign fork model

* Add process wait & signal

* Include nuget package

* Format

* Format

* Format

* Cleanup

* Format

* Format

* Format

* Fix nuspec

* Implement VM termination

* Add lsw dll

* Implement termination callbacks

* Save state

* Various fixes in API header

* Save state

* Test coverage

* Don't block all signals by default

* Writeable overlay

* Add struct keyword

* rename WslCreateVirualMachine -> WslCreateVirtualMachine

* rename Environmnent -> Environment

* rename HandleToUlong -> HandleToULong

* ensure correct amount of memory is used to create the LSW VM

* Adjust LSWVirtualMachine::AttachDisk so it does not require caller to have elevated permission

* Add missing struct keyword

* PR feedback

* PR review

---------

Co-authored-by: Ben Hillis <benhill@ntdev.microsoft.com>

* Save state

* Progress

* Save state

* Move tests to socat

* Increase test coverage

* More test coverage

* Cleanup before PR

* Cleanup before PR

* Refresh branch

* Update comment

* Only run LSW tests on wsl2

* Fix missing '.0' in flight-stage.yml

* Revert "Fix missing '.0' in flight-stage.yml"

This reverts commit 9b3e9ae38f5086b201f2100bf4196a7afe14d9eb.

* PR suggestions

---------

Co-authored-by: Ben Hillis <benhill@ntdev.microsoft.com>
This commit is contained in:
Blue 2025-08-04 15:40:40 -07:00 committed by Blue
parent b5769b4f97
commit c07fa29f0d
20 changed files with 897 additions and 200 deletions

View File

@ -14,6 +14,7 @@ Abstract:
#include "util.h"
#include "SocketChannel.h"
#include "message.h"
#include "localhost.h"
#include <utmp.h>
#include <sys/wait.h>
#include <sys/mount.h>
@ -355,6 +356,18 @@ void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const LSW_EXEC& Mess
Channel.SendResultMessage<int32_t>(errno);
}
void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const LSW_PORT_RELAY& Message, const gsl::span<gsl::byte>& Buffer)
{
sockaddr_vm SocketAddress{};
wil::unique_fd ListenSocket{UtilListenVsockAnyPort(&SocketAddress, 10, false)};
THROW_LAST_ERROR_IF(!ListenSocket);
Channel.SendResultMessage<uint32_t>(SocketAddress.svm_port);
Channel.Close();
UtilSetThreadName("PortRelay");
RunLocalHostRelay(SocketAddress, ListenSocket.get());
}
void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const LSW_WAITPID& Message, const gsl::span<gsl::byte>& Buffer)
{
LSW_WAITPID_RESULT response{};
@ -448,7 +461,8 @@ void ProcessMessage(wsl::shared::SocketChannel& Channel, LX_MESSAGE_TYPE Type, c
{
try
{
HandleMessage<LSW_GET_DISK, LSW_MOUNT, LSW_EXEC, LSW_FORK, LSW_CONNECT, LSW_WAITPID, LSW_SIGNAL, LSW_TTY_RELAY>(Channel, Type, Buffer);
HandleMessage<LSW_GET_DISK, LSW_MOUNT, LSW_EXEC, LSW_FORK, LSW_CONNECT, LSW_WAITPID, LSW_SIGNAL, LSW_TTY_RELAY, LSW_PORT_RELAY>(
Channel, Type, Buffer);
}
catch (...)
{

View File

@ -26,125 +26,6 @@
namespace {
void ListenThread(sockaddr_vm hvSocketAddress, int listenSocket)
{
pollfd pollDescriptors[] = {{listenSocket, POLLIN}};
for (;;)
{
int result = poll(pollDescriptors, COUNT_OF(pollDescriptors), -1);
if (result < 0)
{
LOG_ERROR("poll failed {}", errno);
return;
}
if ((pollDescriptors[0].revents & POLLIN) == 0)
{
LOG_ERROR("unexpected revents {:x}", pollDescriptors[0].revents);
return;
}
// Accept a connection and start a relay worker thread.
wil::unique_fd relaySocket{UtilAcceptVsock(listenSocket, hvSocketAddress)};
THROW_LAST_ERROR_IF(!relaySocket);
std::thread([relaySocket = std::move(relaySocket)]() {
try
{
// Read a message to determine which TCP port to connect to.
std::vector<gsl::byte> buffer(sizeof(LX_INIT_START_SOCKET_RELAY));
auto bytesRead = UtilReadBuffer(relaySocket.get(), buffer);
if (bytesRead == 0)
{
return;
}
auto* message = gslhelpers::try_get_struct<LX_INIT_START_SOCKET_RELAY>(gsl::make_span(buffer.data(), bytesRead));
THROW_ERRNO_IF(EINVAL, !message || (message->Header.MessageType != LxInitMessageStartSocketRelay));
// Connect to the actual socket address and set up a relay.
//
// N.B. While the relay was being set up, the server may have
// stopped listening.
sockaddr* socketAddress;
int socketAddressSize;
sockaddr_in sockaddrIn{};
sockaddr_in6 sockaddrIn6{};
if (message->Family == AF_INET)
{
sockaddrIn.sin_family = AF_INET;
sockaddrIn.sin_port = htons(message->Port);
sockaddrIn.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
socketAddress = reinterpret_cast<sockaddr*>(&sockaddrIn);
socketAddressSize = sizeof(sockaddrIn);
}
else if (message->Family == AF_INET6)
{
sockaddrIn6.sin6_family = AF_INET6;
sockaddrIn6.sin6_port = htons(message->Port);
sockaddrIn6.sin6_addr = IN6ADDR_LOOPBACK_INIT;
socketAddress = reinterpret_cast<sockaddr*>(&sockaddrIn6);
socketAddressSize = sizeof(sockaddrIn6);
}
else
{
THROW_ERRNO(EINVAL);
}
wil::unique_fd tcpSocket{socket(socketAddress->sa_family, SOCK_STREAM, IPPROTO_TCP)};
THROW_LAST_ERROR_IF(!tcpSocket);
if (TEMP_FAILURE_RETRY(connect(tcpSocket.get(), socketAddress, socketAddressSize)) < 0)
{
return;
}
// Resize the buffer to be the requested size.
buffer.resize(message->BufferSize);
// Begin relaying data.
int outFd[2] = {tcpSocket.get(), relaySocket.get()};
pollfd pollDescriptors[] = {{relaySocket.get(), POLLIN}, {tcpSocket.get(), POLLIN}};
for (;;)
{
if ((pollDescriptors[0].fd == -1) || (pollDescriptors[1].fd == -1))
{
return;
}
THROW_LAST_ERROR_IF(poll(pollDescriptors, COUNT_OF(pollDescriptors), -1) < 0);
bytesRead = 0;
for (int Index = 0; Index < COUNT_OF(pollDescriptors); Index += 1)
{
if (pollDescriptors[Index].revents & POLLIN)
{
bytesRead = UtilReadBuffer(pollDescriptors[Index].fd, buffer);
if (bytesRead == 0)
{
pollDescriptors[Index].fd = -1;
shutdown(outFd[Index], SHUT_WR);
}
else if (bytesRead < 0)
{
return;
}
else if (UtilWriteBuffer(outFd[Index], buffer.data(), bytesRead) < 0)
{
return;
}
}
}
}
}
CATCH_LOG()
}).detach();
}
return;
}
std::vector<sockaddr_storage> ParseTcpFile(int family, FILE* file)
{
char* line = nullptr;
@ -395,6 +276,126 @@ int ScanProcNetTCP(wsl::shared::SocketChannel& channel)
}
} // namespace
void RunLocalHostRelay(sockaddr_vm hvSocketAddress, int listenSocket)
{
pollfd pollDescriptors[] = {{listenSocket, POLLIN}};
for (;;)
{
int result = poll(pollDescriptors, COUNT_OF(pollDescriptors), -1);
if (result < 0)
{
LOG_ERROR("poll failed {}", errno);
return;
}
if ((pollDescriptors[0].revents & POLLIN) == 0)
{
LOG_ERROR("unexpected revents {:x}", pollDescriptors[0].revents);
return;
}
// Accept a connection and start a relay worker thread.
wil::unique_fd relaySocket{UtilAcceptVsock(listenSocket, hvSocketAddress)};
THROW_LAST_ERROR_IF(!relaySocket);
std::thread([relaySocket = std::move(relaySocket)]() {
try
{
// Read a message to determine which TCP port to connect to.
std::vector<gsl::byte> buffer(sizeof(LX_INIT_START_SOCKET_RELAY));
auto bytesRead = UtilReadBuffer(relaySocket.get(), buffer);
if (bytesRead == 0)
{
return;
}
auto* message = gslhelpers::try_get_struct<LX_INIT_START_SOCKET_RELAY>(gsl::make_span(buffer.data(), bytesRead));
THROW_ERRNO_IF(EINVAL, !message || (message->Header.MessageType != LxInitMessageStartSocketRelay));
// Connect to the actual socket address and set up a relay.
//
// N.B. During the time setting up the relay the server may have
// stopped listening.
sockaddr* socketAddress;
int socketAddressSize;
sockaddr_in sockaddrIn{};
sockaddr_in6 sockaddrIn6{};
if (message->Family == AF_INET)
{
sockaddrIn.sin_family = AF_INET;
sockaddrIn.sin_port = htons(message->Port);
sockaddrIn.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
socketAddress = reinterpret_cast<sockaddr*>(&sockaddrIn);
socketAddressSize = sizeof(sockaddrIn);
}
else if (message->Family == AF_INET6)
{
sockaddrIn6.sin6_family = AF_INET6;
sockaddrIn6.sin6_port = htons(message->Port);
sockaddrIn6.sin6_addr = IN6ADDR_LOOPBACK_INIT;
socketAddress = reinterpret_cast<sockaddr*>(&sockaddrIn6);
socketAddressSize = sizeof(sockaddrIn6);
}
else
{
THROW_ERRNO(EINVAL);
}
wil::unique_fd tcpSocket{socket(socketAddress->sa_family, SOCK_STREAM, IPPROTO_TCP)};
THROW_LAST_ERROR_IF(!tcpSocket);
if (TEMP_FAILURE_RETRY(connect(tcpSocket.get(), socketAddress, socketAddressSize)) < 0)
{
return;
}
// Resize the buffer to be the requested size.
buffer.resize(message->BufferSize);
// Begin relaying data.
int outFd[2] = {tcpSocket.get(), relaySocket.get()};
pollfd pollDescriptors[] = {{relaySocket.get(), POLLIN}, {tcpSocket.get(), POLLIN}};
for (;;)
{
if ((pollDescriptors[0].fd == -1) || (pollDescriptors[1].fd == -1))
{
return;
}
THROW_LAST_ERROR_IF(poll(pollDescriptors, COUNT_OF(pollDescriptors), -1) < 0);
bytesRead = 0;
for (int Index = 0; Index < COUNT_OF(pollDescriptors); Index += 1)
{
if (pollDescriptors[Index].revents & POLLIN)
{
bytesRead = UtilReadBuffer(pollDescriptors[Index].fd, buffer);
if (bytesRead == 0)
{
pollDescriptors[Index].fd = -1;
shutdown(outFd[Index], SHUT_WR);
}
else if (bytesRead < 0)
{
return;
}
else if (UtilWriteBuffer(outFd[Index], buffer.data(), bytesRead) < 0)
{
return;
}
}
}
}
}
CATCH_LOG()
}).detach();
}
return;
}
// Create a thread to monitor for connections to relay.
int StartLocalhostRelay(wsl::shared::SocketChannel& channel, int GuestRelayFd, bool ScanForPorts)
try
@ -419,7 +420,7 @@ try
std::thread([hvSocketAddress, listenSocket = std::move(listenSocket)]() {
try
{
ListenThread(hvSocketAddress, listenSocket.get());
RunLocalHostRelay(hvSocketAddress, listenSocket.get());
}
CATCH_LOG()
}).detach();

View File

@ -2,3 +2,5 @@
#pragma once
int RunPortTracker(int argc, char** argv);
void RunLocalHostRelay(sockaddr_vm hvSocketAddress, int listenSocket);

View File

@ -1666,6 +1666,12 @@ Return Value:
int UtilMountFile(const char* Source, const char* Destination)
try
{
// Is the file is a symlink, delete it since that would break the mount.
if (std::filesystem::is_symlink(Destination))
{
std::filesystem::remove(Destination);
}
wil::unique_fd Fd{open(Destination, (O_CREAT | O_WRONLY), 0755)};
THROW_LAST_ERROR_IF(!Fd);

View File

@ -67,15 +67,14 @@ public:
return *this;
}
// Note: 'name' must be a global string, since SocketChannel doesn't make a copy of it.
SocketChannel(TSocket&& socket, const char* name) : m_socket(std::move(socket)), m_name(name)
SocketChannel(TSocket&& socket, std::string&& name) : m_socket(std::move(socket)), m_name(std::move(name))
{
}
#ifdef WIN32
SocketChannel(TSocket&& socket, const char* name, HANDLE exitEvent) :
m_socket(std::move(socket)), m_exitEvent(exitEvent), m_name(name)
SocketChannel(TSocket&& socket, std::string&& name, HANDLE exitEvent) :
m_socket(std::move(socket)), m_exitEvent(exitEvent), m_name(std::move(name))
{
}
@ -91,7 +90,7 @@ public:
#ifdef WIN32
THROW_HR_MSG(E_UNEXPECTED, "Incorrect channel usage detected on channel: %hs, message type: %hs", m_name, ToString(TMessage::Type));
THROW_HR_MSG(E_UNEXPECTED, "Incorrect channel usage detected on channel: %hs, message type: %hs", m_name.c_str(), ToString(TMessage::Type));
#else
@ -101,7 +100,7 @@ public:
#endif
}
THROW_INVALID_ARG_IF(m_name == nullptr || span.size() < sizeof(TMessage));
THROW_INVALID_ARG_IF(m_name.empty() || span.size() < sizeof(TMessage));
m_sent_messages++;
@ -114,7 +113,7 @@ public:
WSL_LOG(
"SentMessage",
TraceLoggingValue(m_name, "Name"),
TraceLoggingValue(m_name.c_str(), "Name"),
TraceLoggingValue(reinterpret_cast<const TMessage*>(span.data())->PrettyPrint().c_str(), "Content"));
wsl::windows::common::socket::Send(m_socket.get(), span, m_exitEvent);
@ -148,6 +147,13 @@ public:
}
}
template <typename TMessage>
void SendMessage()
{
TMessage message;
SendMessage(message);
}
template <typename TMessage>
void SendMessage(TMessage& message)
{
@ -156,7 +162,7 @@ public:
if (header.MessageSize != sizeof(message))
{
#ifdef WIN32
THROW_HR_MSG(E_INVALIDARG, "Incorrect header size for message type: %u on channel: %hs", header.MessageType, m_name);
THROW_HR_MSG(E_INVALIDARG, "Incorrect header size for message type: %u on channel: %hs", header.MessageType, m_name.c_str());
#else
LOG_ERROR("Incorrect header size for message type: {} on channel: {}", header.MessageType, m_name);
THROW_ERRNO(EINVAL);
@ -180,7 +186,7 @@ public:
template <typename TMessage>
std::pair<TMessage*, gsl::span<gsl::byte>> ReceiveMessageOrClosed(TTimeout timeout = DefaultSocketTimeout)
{
WI_ASSERT(m_name != nullptr);
WI_ASSERT(!m_name.empty());
// Ensure that no other thread is using this channel.
const std::unique_lock<std::mutex> lock{m_receiveMutex, std::try_to_lock};
@ -189,7 +195,7 @@ public:
#ifdef WIN32
THROW_HR_MSG(E_UNEXPECTED, "Incorrect channel usage detected on channel: %hs", m_name);
THROW_HR_MSG(E_UNEXPECTED, "Incorrect channel usage detected on channel: %hs", m_name.c_str());
#else
LOG_ERROR("Incorrect channel usage detected on channel: {}", m_name);
@ -207,7 +213,12 @@ public:
#ifdef WIN32
if (errno == HCS_E_CONNECTION_TIMEOUT)
{
THROW_HR_MSG(HCS_E_CONNECTION_TIMEOUT, "Timeout: %d, expected type: %hs, channel: %hs", timeout, ToString(TMessage::Type), m_name);
THROW_HR_MSG(
HCS_E_CONNECTION_TIMEOUT,
"Timeout: %d, expected type: %hs, channel: %hs",
timeout,
ToString(TMessage::Type),
m_name.c_str());
}
#endif
@ -220,7 +231,11 @@ public:
{
#ifdef WIN32
THROW_HR_MSG(
E_UNEXPECTED, "Message size is too small: %zd, expected type: %hs, channel: %hs", receivedSpan.size(), ToString(TMessage::Type), m_name);
E_UNEXPECTED,
"Message size is too small: %zd, expected type: %hs, channel: %hs",
receivedSpan.size(),
ToString(TMessage::Type),
m_name.c_str());
#else
LOG_ERROR("MessageSize is too small: {}, expected type: {}, channel: {}", receivedSpan.size(), ToString(TMessage::Type), m_name);
THROW_ERRNO(EINVAL);
@ -231,7 +246,9 @@ public:
#ifdef WIN32
WSL_LOG(
"ReceivedMessage", TraceLoggingValue(m_name, "Name"), TraceLoggingValue(message->PrettyPrint().c_str(), "Content"));
"ReceivedMessage",
TraceLoggingValue(m_name.c_str(), "Name"),
TraceLoggingValue(message->PrettyPrint().c_str(), "Content"));
#else
if (LoggingEnabled())
{
@ -248,7 +265,7 @@ public:
if (message == nullptr)
{
#ifdef WIN32
THROW_HR_MSG(E_UNEXPECTED, "Expected message %hs, but socket %hs was closed", ToString(TMessage::Type), m_name);
THROW_HR_MSG(E_UNEXPECTED, "Expected message %hs, but socket %hs was closed", ToString(TMessage::Type), m_name.c_str());
#else
LOG_ERROR("ExpectedMessage {}, but socket {} was closed", ToString(TMessage::Type), m_name);
THROW_ERRNO(EINVAL);
@ -350,7 +367,7 @@ private:
header.SequenceNumber,
expected,
expectedSequence,
m_name);
m_name.c_str());
#else
LOG_ERROR(
@ -408,7 +425,7 @@ private:
uint32_t m_sent_messages = 0;
uint32_t m_received_messages = 0;
bool m_ignore_sequence = false;
const char* m_name{};
std::string m_name{};
std::mutex m_sendMutex;
std::mutex m_receiveMutex;
};

View File

@ -380,7 +380,10 @@ typedef enum _LX_MESSAGE_TYPE
LxMessageLswWaitPidResponse,
LxMessageLswSignal,
LxMessageLswShutdown,
LxMessageLswRelayTty
LxMessageLswRelayTty,
LxMessageLswMapPort,
LxMessageLswConnectRelay,
LxMessageLswPortRelay,
} LX_MESSAGE_TYPE,
*PLX_MESSAGE_TYPE;
@ -482,6 +485,9 @@ inline auto ToString(LX_MESSAGE_TYPE messageType)
X(LxMessageLswSignal)
X(LxMessageLswShutdown)
X(LxMessageLswRelayTty)
X(LxMessageLswMapPort)
X(LxMessageLswConnectRelay)
X(LxMessageLswPortRelay)
default:
return "<unexpected LX_MESSAGE_TYPE>";
}
@ -1073,13 +1079,14 @@ typedef struct _LX_INIT_START_SOCKET_RELAY
{
static inline auto Type = LxInitMessageStartSocketRelay;
DECLARE_MESSAGE_CTOR(_LX_INIT_START_SOCKET_RELAY);
MESSAGE_HEADER Header;
unsigned short Family;
unsigned short Port;
int HvSocketPort;
size_t BufferSize;
PRETTY_PRINT(FIELD(Header), FIELD(Family), FIELD(Port), FIELD(HvSocketPort), FIELD(BufferSize));
PRETTY_PRINT(FIELD(Header), FIELD(Family), FIELD(Port), FIELD(BufferSize));
} LX_INIT_START_SOCKET_RELAY, *PLX_INIT_START_SOCKET_RELAY;
using PCLX_INIT_START_SOCKET_RELAY = const LX_INIT_START_SOCKET_RELAY*;
@ -1715,6 +1722,44 @@ struct LSW_SHUTDOWN
PRETTY_PRINT(FIELD(Header));
};
struct LSW_MAP_PORT
{
static inline auto Type = LxMessageLswMapPort;
using TResponse = RESULT_MESSAGE<uint32_t>;
DECLARE_MESSAGE_CTOR(LSW_MAP_PORT);
MESSAGE_HEADER Header;
uint16_t WindowsPort;
uint16_t LinuxPort;
uint32_t AddressFamily;
bool Stop;
PRETTY_PRINT(FIELD(Header));
};
struct LSW_CONNECT_RELAY
{
static inline auto Type = LxMessageLswConnectRelay;
using TResponse = RESULT_MESSAGE<uint32_t>;
DECLARE_MESSAGE_CTOR(LSW_CONNECT_RELAY);
MESSAGE_HEADER Header;
uint16_t Port;
uint16_t Family;
PRETTY_PRINT(FIELD(Header), FIELD(Port), FIELD(Family));
};
struct LSW_PORT_RELAY
{
static inline auto Type = LxMessageLswPortRelay;
using TResponse = RESULT_MESSAGE<uint32_t>;
DECLARE_MESSAGE_CTOR(LSW_PORT_RELAY);
MESSAGE_HEADER Header;
PRETTY_PRINT(FIELD(Header));
};
typedef struct _LX_MINI_INIT_IMPORT_RESULT
{
static inline auto Type = LxMiniInitMessageImportResult;

View File

@ -137,7 +137,11 @@ static const std::map<HRESULT, LPCWSTR> g_commonErrors{
X_WIN32(STATUS_SHUTDOWN_IN_PROGRESS),
X(WININET_E_TIMEOUT),
X(WSAEADDRNOTAVAIL),
X_WIN32(ERROR_BAD_IMPERSONATION_LEVEL)};
X_WIN32(ERROR_BAD_IMPERSONATION_LEVEL),
X_WIN32(ERROR_NO_DATA),
X_WIN32(WSAETIMEDOUT),
X_WIN32(ERROR_OPERATION_ABORTED),
X_WIN32(WSAECONNREFUSED)};
#undef X

View File

@ -21,6 +21,7 @@ enum RelayMode
DebugConsole,
DebugConsoleRelay,
PortRelay,
WSLAPortRelay,
KdRelay,
InteractiveConsoleRelay
};

View File

@ -17,6 +17,24 @@ Abstract:
#include "LSWApi.h"
#include "wslrelay.h"
namespace {
void ConfigureComSecurity(IUnknown* Instance)
{
wil::com_ptr_nothrow<IClientSecurity> clientSecurity;
THROW_IF_FAILED(Instance->QueryInterface(IID_PPV_ARGS(&clientSecurity)));
// Get the current proxy blanket settings.
DWORD authnSvc, authzSvc, authnLvl, capabilites;
THROW_IF_FAILED(clientSecurity->QueryBlanket(Instance, &authnSvc, &authzSvc, NULL, &authnLvl, NULL, NULL, &capabilites));
// Make sure that dynamic cloaking is used.
WI_ClearFlag(capabilites, EOAC_STATIC_CLOAKING);
WI_SetFlag(capabilites, EOAC_DYNAMIC_CLOAKING);
THROW_IF_FAILED(clientSecurity->SetBlanket(Instance, authnSvc, authzSvc, NULL, authnLvl, RPC_C_IMP_LEVEL_IMPERSONATE, NULL, capabilites));
}
} // namespace
class DECLSPEC_UUID("7BC4E198-6531-4FA6-ADE2-5EF3D2A04DFF") CallbackInstance
: public Microsoft::WRL::RuntimeClass<Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>, ITerminationCallback, IFastRundown>
{
@ -55,6 +73,7 @@ try
wil::com_ptr<ILSWUserSession> session;
THROW_IF_FAILED(CoCreateInstance(__uuidof(LSWUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&session)));
ConfigureComSecurity(session.get());
wil::com_ptr<ILSWVirtualMachine> virtualMachineInstance;
@ -69,19 +88,7 @@ try
settings.EnableDnsTunneling = UserSettings->Networking.DnsTunneling;
THROW_IF_FAILED(session->CreateVirtualMachine(&settings, &virtualMachineInstance));
wil::com_ptr_nothrow<IClientSecurity> clientSecurity;
THROW_IF_FAILED(virtualMachineInstance->QueryInterface(IID_PPV_ARGS(&clientSecurity)));
// Get the current proxy blanket settings.
DWORD authnSvc, authzSvc, authnLvl, capabilites;
THROW_IF_FAILED(clientSecurity->QueryBlanket(virtualMachineInstance.get(), &authnSvc, &authzSvc, NULL, &authnLvl, NULL, NULL, &capabilites));
// Make sure that dynamic cloaking is used.
WI_ClearFlag(capabilites, EOAC_STATIC_CLOAKING);
WI_SetFlag(capabilites, EOAC_DYNAMIC_CLOAKING);
THROW_IF_FAILED(clientSecurity->SetBlanket(
virtualMachineInstance.get(), authnSvc, authzSvc, NULL, authnLvl, RPC_C_IMP_LEVEL_IMPERSONATE, NULL, capabilites));
ConfigureComSecurity(virtualMachineInstance.get());
// Register termination callback, if specified
if (UserSettings->Options.TerminationCallback != nullptr)
@ -202,6 +209,18 @@ void WslReleaseVirtualMachine(LSWVirtualMachineHandle VirtualMachine)
reinterpret_cast<ILSWVirtualMachine*>(VirtualMachine)->Release();
}
HRESULT WslMapPort(LSWVirtualMachineHandle VirtualMachine, const PortMappingSettings* UserSettings)
{
return reinterpret_cast<ILSWVirtualMachine*>(VirtualMachine)
->MapPort(UserSettings->AddressFamily, UserSettings->WindowsPort, UserSettings->LinuxPort, false);
}
HRESULT WslUnmapPort(LSWVirtualMachineHandle VirtualMachine, const PortMappingSettings* UserSettings)
{
return reinterpret_cast<ILSWVirtualMachine*>(VirtualMachine)
->MapPort(UserSettings->AddressFamily, UserSettings->WindowsPort, UserSettings->LinuxPort, true);
}
HRESULT WslLaunchInteractiveTerminal(HANDLE Input, HANDLE Output, HANDLE* Process)
try
{
@ -280,4 +299,4 @@ EXTERN_C BOOL STDAPICALLTYPE DllMain(_In_ HINSTANCE Instance, _In_ DWORD Reason,
}
return TRUE;
}
}

View File

@ -155,6 +155,13 @@ struct WaitResult
int32_t Code; // Signal number or exit code
};
struct PortMappingSettings
{
uint16_t WindowsPort;
uint16_t LinuxPort;
int AddressFamily;
};
HRESULT WslLaunchInteractiveTerminal(HANDLE Input, HANDLE Output, HANDLE* Process);
HRESULT WslWaitForLinuxProcess(LSWVirtualMachineHandle VirtualMachine, int32_t Pid, uint64_t TimeoutMs, struct WaitResult* Result);
@ -167,6 +174,10 @@ void WslReleaseVirtualMachine(LSWVirtualMachineHandle VirtualMachine);
HRESULT WslLaunchDebugShell(LSWVirtualMachineHandle VirtualMachine, HANDLE* Process); // Used for development, might remove
HRESULT WslMapPort(LSWVirtualMachineHandle VirtualMachine, const struct PortMappingSettings* Settings);
HRESULT WslUnmapPort(LSWVirtualMachineHandle VirtualMachine, const struct PortMappingSettings* Settings);
#ifdef __cplusplus
}
#endif

View File

@ -11,4 +11,6 @@ EXPORTS
WslReleaseVirtualMachine
WslShutdownVirtualMachine
WslLaunchInteractiveTerminal
WslLaunchDebugShell
WslLaunchDebugShell
WslMapPort
WslUnmapPort

View File

@ -292,6 +292,8 @@ void LSWVirtualMachine::ConfigureNetworking()
m_computeSystem.get(), wsl::core::NatNetworking::CreateNetwork(config), std::move(gnsSocket), config, wil::unique_socket{});
m_networkEngine->Initialize();
LaunchPortRelay();
}
else
{
@ -460,8 +462,7 @@ std::tuple<int32_t, int32_t, wsl::shared::SocketChannel> LSWVirtualMachine::Fork
auto socket = wsl::windows::common::hvsocket::Connect(m_vmId, port, m_vmExitEvent.get(), m_settings.BootTimeoutMs);
// TODO: pid in channel name
return std::make_tuple(pid, ptyMaster, wsl::shared::SocketChannel{std::move(socket), "ForkedChannel"});
return std::make_tuple(pid, ptyMaster, wsl::shared::SocketChannel{std::move(socket), std::to_string(pid)});
}
wil::unique_socket LSWVirtualMachine::ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd)
@ -657,3 +658,76 @@ bool LSWVirtualMachine::ParseTtyInformation(const LSW_PROCESS_FD* Fds, ULONG FdC
return !foundNonTtyFd && FdCount > 0;
}
void LSWVirtualMachine::LaunchPortRelay()
{
WI_ASSERT(!m_portRelayChannelRead);
auto [_, __, channel] = Fork(LSW_FORK::ForkType::Process);
std::lock_guard lock(m_portRelaylock);
auto relayPort = channel.Transaction<LSW_PORT_RELAY>();
wil::unique_handle readPipe;
wil::unique_handle writePipe;
THROW_IF_WIN32_BOOL_FALSE(CreatePipe(&readPipe, &m_portRelayChannelWrite, nullptr, 0));
THROW_IF_WIN32_BOOL_FALSE(CreatePipe(&m_portRelayChannelRead, &writePipe, nullptr, 0));
wsl::windows::common::helpers::SetHandleInheritable(readPipe.get());
wsl::windows::common::helpers::SetHandleInheritable(writePipe.get());
wsl::windows::common::helpers::SetHandleInheritable(m_vmExitEvent.get());
// Get an impersonation token
auto userToken = wsl::windows::common::security::GetUserToken(TokenImpersonation);
auto restrictedToken = wsl::windows::common::security::CreateRestrictedToken(userToken.get());
auto path = wsl::windows::common::wslutil::GetBasePath() / L"wslrelay.exe";
auto cmd = std::format(
L"\"{}\" {} {} {} {} {} {} {} {}",
path,
wslrelay::mode_option,
static_cast<int>(wslrelay::RelayMode::WSLAPortRelay),
wslrelay::exit_event_option,
HandleToUlong(m_vmExitEvent.get()),
wslrelay::port_option,
relayPort.Result,
wslrelay::vm_id_option,
m_vmId);
WSL_LOG("LaunchWslRelay", TraceLoggingValue(cmd.c_str(), "cmd"));
wsl::windows::common::SubProcess process{nullptr, cmd.c_str()};
process.SetStdHandles(readPipe.get(), writePipe.get(), nullptr);
process.SetToken(restrictedToken.get());
process.Start();
readPipe.release();
writePipe.release();
}
HRESULT LSWVirtualMachine::MapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ BOOL Remove)
try
{
std::lock_guard lock(m_portRelaylock);
RETURN_HR_IF(E_ILLEGAL_STATE_CHANGE, !m_portRelayChannelWrite);
LSW_MAP_PORT message;
message.WindowsPort = WindowsPort;
message.LinuxPort = LinuxPort;
message.AddressFamily = Family;
message.Stop = Remove;
DWORD bytesTransfered{};
THROW_IF_WIN32_BOOL_FALSE(WriteFile(m_portRelayChannelWrite.get(), &message, sizeof(message), &bytesTransfered, nullptr));
THROW_HR_IF_MSG(E_UNEXPECTED, bytesTransfered != sizeof(message), "%u bytes transfered", bytesTransfered);
HRESULT result = E_UNEXPECTED;
THROW_IF_WIN32_BOOL_FALSE(ReadFile(m_portRelayChannelRead.get(), &result, sizeof(result), &bytesTransfered, nullptr));
THROW_HR_IF(E_UNEXPECTED, bytesTransfered != sizeof(result));
return result;
}
CATCH_RETURN();

View File

@ -37,6 +37,7 @@ public:
IFACEMETHOD(Shutdown(ULONGLONG _In_ TimeoutMs)) override;
IFACEMETHOD(RegisterCallback(_In_ ITerminationCallback* callback)) override;
IFACEMETHOD(GetDebugShellPipe(_Out_ LPWSTR* pipePath)) override;
IFACEMETHOD(MapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ BOOL Remove)) override;
private:
static void CALLBACK s_OnExit(_In_ HCS_EVENT* Event, _In_opt_ void* Context);
@ -50,6 +51,7 @@ private:
int32_t ExpectClosedChannelOrError(wsl::shared::SocketChannel& Channel);
wil::unique_socket ConnectSocket(wsl::shared::SocketChannel& Channel, int32_t Fd);
void LaunchPortRelay();
struct AttachedDisk
{
@ -74,8 +76,11 @@ private:
std::unique_ptr<wsl::core::INetworkingEngine> m_networkEngine;
wsl::shared::SocketChannel m_initChannel;
wil::unique_handle m_portRelayChannelRead;
wil::unique_handle m_portRelayChannelWrite;
std::map<ULONG, AttachedDisk> m_attachedDisks;
std::mutex m_lock;
std::mutex m_portRelaylock;
};
} // namespace wsl::windows::service::lsw

View File

@ -457,6 +457,7 @@ interface ILSWVirtualMachine : IUnknown
HRESULT Shutdown([in] ULONGLONG TimeoutMs);
HRESULT RegisterCallback([in] ITerminationCallback* terminationCallback);
HRESULT GetDebugShellPipe([out] LPWSTR* pipePath);
HRESULT MapPort([in] int Family, [in] short WindowsPort, [in] short LinuxPort, [in] BOOL Remove);
}
typedef

View File

@ -324,3 +324,294 @@ try
}
}
CATCH_LOG()
struct PortRelay
{
wil::unique_socket ListenSocket;
uint32_t LinuxPort;
uint32_t RelayPort;
wil::unique_event AcceptEvent{wil::EventOptions::None};
wil::unique_event StopRelayEvent{wil::EventOptions::None};
OVERLAPPED Overlapped{};
bool Pending = false;
wil::unique_socket PendingSocket;
int Family;
PortRelay(wil::unique_socket&& ListenSocket, uint32_t LinuxPort, uint32_t RelayPort, int Family) :
ListenSocket(std::move(ListenSocket)), LinuxPort(LinuxPort), RelayPort(RelayPort), Family(Family)
{
Overlapped.hEvent = AcceptEvent.get();
}
~PortRelay()
{
StopRelayEvent.SetEvent();
if (Pending) // Cancel pending accept(), if any.
{
DWORD bytesProcessed;
DWORD flagsReturned;
CancelIoEx(reinterpret_cast<HANDLE>(ListenSocket.get()), &Overlapped);
WSAGetOverlappedResult(ListenSocket.get(), &Overlapped, &bytesProcessed, TRUE, &flagsReturned);
}
}
void LaunchRelay(const GUID& VmId)
{
WI_VERIFY(PendingSocket);
std::thread thread{[WindowsSocket = std::move(PendingSocket), LinuxPort = LinuxPort, RelayPort = RelayPort, Family = Family, VmId = VmId]() {
try
{
WSL_LOG(
"StartPortRelay", TraceLoggingValue(LinuxPort, "LinuxPort"), TraceLoggingValue(WindowsSocket.get(), "Socket"));
RunRelay(WindowsSocket.get(), VmId, LinuxPort, RelayPort, Family);
}
CATCH_LOG();
WSL_LOG("StopPortRelay", TraceLoggingValue(LinuxPort, "LinuxPort"), TraceLoggingValue(WindowsSocket.get(), "Socket"));
}};
thread.detach();
}
static void RunRelay(SOCKET WindowsSocket, const GUID& VmId, uint32_t LinuxPort, uint32_t RelayPort, uint32_t Family)
{
wsl::shared::SocketChannel channel(wsl::windows::common::hvsocket::Connect(VmId, RelayPort), "SocketRelay");
WI_VERIFY(Family == AF_INET || Family == AF_INET6);
LX_INIT_START_SOCKET_RELAY message;
message.Port = LinuxPort;
message.Family = Family == AF_INET ? LX_AF_INET : LX_AF_INET6;
message.BufferSize = 4096;
channel.SendMessage(message);
wsl::windows::common::relay::SocketRelay(WindowsSocket, channel.Socket());
}
void CompleteAccept()
{
Pending = false;
DWORD bytes{};
DWORD flags{};
if (!WSAGetOverlappedResult(ListenSocket.get(), &Overlapped, &bytes, false, &flags))
{
THROW_WIN32(WSAGetLastError());
}
}
bool ScheduleAccept()
{
WI_VERIFY(!Pending);
PendingSocket.reset(WSASocket(Family, SOCK_STREAM, IPPROTO_TCP, nullptr, 0, WSA_FLAG_OVERLAPPED));
CHAR AcceptBuffer[2 * sizeof(SOCKADDR_STORAGE)]{};
DWORD BytesReturned{};
if (!AcceptEx(ListenSocket.get(), PendingSocket.get(), AcceptBuffer, 0, sizeof(SOCKADDR_STORAGE), sizeof(SOCKADDR_STORAGE), &BytesReturned, &Overlapped))
{
const int error = WSAGetLastError();
THROW_HR_IF(HRESULT_FROM_WIN32(error), error != WSA_IO_PENDING);
Pending = true;
return false;
}
return true;
}
};
std::shared_ptr<PortRelay> CreatePortListener(uint16_t WindowsPort, uint16_t LinuxPort, uint32_t RelayPort, int Family)
{
wil::unique_socket ListenSocket(WSASocket(Family, SOCK_STREAM, IPPROTO_TCP, nullptr, 0, WSA_FLAG_OVERLAPPED));
THROW_LAST_ERROR_IF(!ListenSocket);
constexpr BOOLEAN On = true;
THROW_LAST_ERROR_IF(setsockopt(ListenSocket.get(), SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char*>(&On), sizeof(On)) == SOCKET_ERROR);
sockaddr* Address{};
sockaddr_in InetAddress{};
sockaddr_in6 Inet6Address{};
DWORD AddressSize{};
if (Family == AF_INET)
{
InetAddress.sin_family = AF_INET;
InetAddress.sin_port = htons(WindowsPort);
InetAddress.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
Address = reinterpret_cast<sockaddr*>(&InetAddress);
AddressSize = sizeof(InetAddress);
}
else
{
Inet6Address.sin6_family = AF_INET6;
Inet6Address.sin6_port = htons(WindowsPort);
Inet6Address.sin6_addr = IN6ADDR_LOOPBACK_INIT;
Address = reinterpret_cast<sockaddr*>(&Inet6Address);
AddressSize = sizeof(Inet6Address);
}
THROW_LAST_ERROR_IF(bind(ListenSocket.get(), Address, AddressSize) == SOCKET_ERROR);
THROW_LAST_ERROR_IF(listen(ListenSocket.get(), -1) == SOCKET_ERROR);
return std::make_shared<PortRelay>(std::move(ListenSocket), LinuxPort, RelayPort, Family);
}
void AcceptThread(std::vector<std::shared_ptr<PortRelay>>& ports, const GUID& VmId, HANDLE ExitEvent)
{
while (true)
{
// First make sure that all the accept() are scheduled
std::vector<HANDLE> events{ExitEvent};
for (auto& e : ports)
{
if (!e->Pending)
{
while (e->ScheduleAccept())
{
e->LaunchRelay(VmId); // Start the relay if accept completes immediately.
}
}
events.push_back(e->AcceptEvent.get());
}
// Then wait for IO, or exit event.
auto result = WaitForMultipleObjects(static_cast<DWORD>(events.size()), events.data(), false, INFINITE);
THROW_LAST_ERROR_IF(result == WAIT_FAILED);
if (result == 0) // If the exit event is signaled, leave the loop
{
break;
}
// Otherwise complete the accept and start a relay
try
{
ports[result - 1]->CompleteAccept();
ports[result - 1]->LaunchRelay(VmId);
}
CATCH_LOG();
}
}
std::optional<LSW_MAP_PORT> ReceiveServiceMessage()
{
LSW_MAP_PORT message{};
DWORD bytesRead{};
if (!ReadFile(GetStdHandle(STD_INPUT_HANDLE), &message, sizeof(message), &bytesRead, nullptr))
{
LOG_LAST_ERROR();
return {};
}
else if (bytesRead == 0)
{
return {};
}
WI_ASSERT(message.Header.MessageSize == sizeof(message));
WI_ASSERT(message.Header.MessageType == LxMessageLswMapPort);
return message;
}
void wsl::windows::wslrelay::localhost::RunWSLAPortRelay(const GUID& VmId, uint32_t RelayPort, HANDLE ExitEvent)
{
std::map<std::tuple<uint16_t, uint16_t, uint32_t>, std::shared_ptr<PortRelay>> ports;
std::thread acceptThread;
wil::unique_event acceptThreadEvent{wil::EventOptions::ManualReset};
auto stopAcceptThread = [&]() {
if (acceptThread.joinable())
{
acceptThreadEvent.SetEvent();
acceptThread.join();
acceptThread = {};
acceptThreadEvent.ResetEvent();
}
};
auto cleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { stopAcceptThread(); });
while (true)
{
// Receive a message
auto message = ReceiveServiceMessage();
if (!message.has_value())
{
return;
}
std::tuple<uint16_t, uint16_t, uint16_t> key{message->WindowsPort, message->LinuxPort, message->AddressFamily};
HRESULT result = E_UNEXPECTED;
auto sendResponse = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() {
WSL_LOG(
"PortMapping",
TraceLoggingValue(result, "Result"),
TraceLoggingValue(message->WindowsPort, "WindowsPort"),
TraceLoggingValue(message->LinuxPort, "LinuxPort"),
TraceLoggingValue(message->Stop, "Remove"));
THROW_LAST_ERROR_IF(!WriteFile(GetStdHandle(STD_OUTPUT_HANDLE), &result, sizeof(result), nullptr, nullptr));
});
// Check if the binding is valid.
bool update = false;
auto it = ports.find(key);
if (message->Stop)
{
if (it == ports.end())
{
result = HRESULT_FROM_WIN32(ERROR_NOT_FOUND);
continue;
}
else
{
ports.erase(it);
update = true;
}
}
else
{
if (it != ports.end())
{
result = HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS);
continue;
}
else
{
ports.emplace(key, CreatePortListener(message->WindowsPort, message->LinuxPort, RelayPort, message->AddressFamily));
update = true;
}
}
// Update the ports list
if (update)
{
stopAcceptThread();
}
// Start the accept thread, if needed
if (!acceptThread.joinable())
{
std::vector<std::shared_ptr<PortRelay>> relays;
for (auto& e : ports)
{
relays.emplace_back(e.second);
}
acceptThread = std::thread([&, relays = std::move(relays)]() mutable {
try
{
AcceptThread(relays, VmId, acceptThreadEvent.get());
}
CATCH_LOG();
});
}
result = S_OK;
}
}

View File

@ -47,6 +47,8 @@ typedef struct _LX_PORT_LISTENER_CONTEXT
namespace wsl::windows::wslrelay::localhost {
void RelayWorker(_In_ wsl::shared::SocketChannel& SocketChannel, _In_ const GUID& VmId);
void RunWSLAPortRelay(const GUID& VmId, uint32_t RelayPort, HANDLE ExitEvent);
class Relay
{
public:

View File

@ -40,7 +40,7 @@ try
wil::unique_handle exitEvent{};
wil::unique_handle terminalInputHandle{};
wil::unique_handle terminalOutputHandle{};
int port{};
uint32_t port{};
GUID vmId{};
bool disableTelemetry = !wsl::shared::OfficialBuild;
@ -98,6 +98,20 @@ try
break;
}
case wslrelay::RelayMode::WSLAPortRelay:
{
try
{
wsl::windows::wslrelay::localhost::RunWSLAPortRelay(vmId, port, exitEvent.get());
}
catch (...)
{
LOG_CAUGHT_EXCEPTION();
}
break;
}
case wslrelay::RelayMode::KdRelay:
{
THROW_HR_IF(E_INVALIDARG, port == 0);

View File

@ -2529,3 +2529,33 @@ void DistroFileChange::Delete()
{
VERIFY_ARE_EQUAL(LxsstuLaunchWsl(std::format(L"-u root rm -f '{}'", m_path).c_str()), 0L);
}
std::string ReadToString(SOCKET Handle)
{
std::string output;
DWORD offset = 0;
while (true) // TODO: timeout
{
constexpr auto bufferSize = 512;
output.resize(output.size() + bufferSize);
int bytesRead = 0;
if ((bytesRead = recv(Handle, &output[offset], bufferSize, 0)) < 0)
{
LogError("recv failed with %lu", GetLastError());
VERIFY_FAIL();
}
if (bytesRead == 0)
{
output.resize(offset);
break;
}
output.resize(offset + bytesRead);
offset += bytesRead;
}
return output;
}

View File

@ -512,4 +512,6 @@ void StopWslService();
std::optional<GUID> GetDistributionId(LPCWSTR Name);
wil::unique_hkey OpenDistributionKey(LPCWSTR Name);
void ValidateOutput(LPCWSTR CommandLine, const std::wstring& ExpectedOutput, const std::wstring& ExpectedWarnings = L"", int ExitCode = -1);
void ValidateOutput(LPCWSTR CommandLine, const std::wstring& ExpectedOutput, const std::wstring& ExpectedWarnings = L"", int ExitCode = -1);
std::string ReadToString(SOCKET Handle);

View File

@ -18,16 +18,25 @@ Abstract:
using namespace wsl::windows::common::registry;
using unique_vm = wil::unique_any<LSWVirtualMachineHandle, decltype(WslReleaseVirtualMachine), &WslReleaseVirtualMachine>;
class LSWTests
{
WSL_TEST_CLASS(LSWTests)
wil::unique_couninitialize_call coinit = wil::CoInitializeEx();
WSADATA Data;
std::filesystem::path testVhd;
TEST_CLASS_SETUP(TestClassSetup)
{
THROW_IF_WIN32_ERROR(WSAStartup(MAKEWORD(2, 2), &Data));
auto distroKey = OpenDistributionKey(LXSS_DISTRO_NAME_TEST_L);
auto vhdPath = wsl::windows::common::registry::ReadString(distroKey.get(), nullptr, L"BasePath");
testVhd = std::filesystem::path{vhdPath} / "ext4.vhdx";
WslShutdown();
return true;
}
@ -48,7 +57,8 @@ class LSWTests
VERIFY_ARE_EQUAL(version.Revision, WSL_PACKAGE_VERSION_REVISION);
}
int RunCommand(LSWVirtualMachineHandle vm, const std::vector<const char*>& command)
std::tuple<int, wil::unique_handle, wil::unique_handle, wil::unique_handle> LaunchCommand(
LSWVirtualMachineHandle vm, const std::vector<const char*>& command)
{
auto copiedCommand = command;
if (copiedCommand.back() != nullptr)
@ -70,50 +80,52 @@ class LSWTests
int pid = -1;
VERIFY_SUCCEEDED(WslCreateLinuxProcess(vm, &createProcessSettings, &pid));
return std::make_tuple(
pid, wil::unique_handle{fds[0].Handle}, wil::unique_handle(fds[1].Handle), wil::unique_handle{fds[2].Handle});
}
int RunCommand(LSWVirtualMachineHandle vm, const std::vector<const char*>& command, int timeout = 600000)
{
auto [pid, _, __, ___] = LaunchCommand(vm, command);
WaitResult result{};
VERIFY_SUCCEEDED(WslWaitForLinuxProcess(vm, pid, 1000, &result));
VERIFY_SUCCEEDED(WslWaitForLinuxProcess(vm, pid, timeout, &result));
VERIFY_ARE_EQUAL(result.State, ProcessStateExited);
return result.Code;
}
LSWVirtualMachineHandle CreateVm(const VirtualMachineSettings* settings)
unique_vm CreateVm(const VirtualMachineSettings* settings)
{
LSWVirtualMachineHandle vm{};
VERIFY_SUCCEEDED(WslCreateVirtualMachine(settings, (LSWVirtualMachineHandle*)&vm));
unique_vm vm{};
VERIFY_SUCCEEDED(WslCreateVirtualMachine(settings, &vm));
#ifdef WSL_SYSTEM_DISTRO_PATH
std::wstring systemdDistroDiskPath = TEXT(WSL_SYSTEM_DISTRO_PATH);
#else
auto systemdDistroDiskPath = std::format(L"{}/system.vhd", wsl::windows::common::wslutil::GetMsiPackagePath().value());
#endif
DiskAttachSettings attachSettings{systemdDistroDiskPath.c_str(), true};
DiskAttachSettings attachSettings{testVhd.c_str(), true};
AttachedDiskInformation attachedDisk;
VERIFY_SUCCEEDED(WslAttachDisk(vm, &attachSettings, &attachedDisk));
VERIFY_SUCCEEDED(WslAttachDisk(vm.get(), &attachSettings, &attachedDisk));
MountSettings mountSettings{attachedDisk.Device, "/mnt", "ext4", "ro", MountFlagsChroot | MountFlagsWriteableOverlayFs};
VERIFY_SUCCEEDED(WslMount(vm, &mountSettings));
VERIFY_SUCCEEDED(WslMount(vm.get(), &mountSettings));
MountSettings devmountSettings{nullptr, "/dev", "devtmpfs", "", false};
VERIFY_SUCCEEDED(WslMount(vm, &devmountSettings));
VERIFY_SUCCEEDED(WslMount(vm.get(), &devmountSettings));
MountSettings sysmountSettings{nullptr, "/sys", "sysfs", "", false};
VERIFY_SUCCEEDED(WslMount(vm, &sysmountSettings));
VERIFY_SUCCEEDED(WslMount(vm.get(), &sysmountSettings));
MountSettings procmountSettings{nullptr, "/proc", "proc", "", false};
VERIFY_SUCCEEDED(WslMount(vm, &procmountSettings));
VERIFY_SUCCEEDED(WslMount(vm.get(), &procmountSettings));
MountSettings ptsMountSettings{nullptr, "/dev/pts", "devpts", "noatime,nosuid,noexec,gid=5,mode=620", false};
VERIFY_SUCCEEDED(WslMount(vm, &ptsMountSettings));
VERIFY_SUCCEEDED(WslMount(vm.get(), &ptsMountSettings));
return vm;
}
TEST_METHOD(CustomDmesgOutput)
{
WSL2_TEST_ONLY();
auto [read, write] = CreateSubprocessPipe(false, false);
VirtualMachineSettings settings{};
@ -153,7 +165,7 @@ class LSWTests
write.reset();
auto detach = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() {
WslReleaseVirtualMachine(vm);
vm.reset();
if (thread.joinable())
{
thread.join();
@ -161,9 +173,9 @@ class LSWTests
});
std::vector<const char*> cmd = {"/bin/bash", "-c", "echo DmesgTest > /dev/kmsg"};
VERIFY_ARE_EQUAL(RunCommand(vm, cmd), 0);
VERIFY_ARE_EQUAL(RunCommand(vm.get(), cmd), 0);
VERIFY_ARE_EQUAL(WslShutdownVirtualMachine(vm, 30 * 1000), S_OK);
VERIFY_ARE_EQUAL(WslShutdownVirtualMachine(vm.get(), 30 * 1000), S_OK);
detach.reset();
auto contentString = std::string(dmesgContent.begin(), dmesgContent.end());
@ -174,6 +186,8 @@ class LSWTests
TEST_METHOD(TerminationCallback)
{
WSL2_TEST_ONLY();
std::promise<std::pair<VirtualMachineTerminationReason, std::wstring>> callbackInfo;
auto callback = [](void* context, VirtualMachineTerminationReason reason, LPCWSTR details) -> HRESULT {
@ -194,19 +208,19 @@ class LSWTests
auto vm = CreateVm(&settings);
VERIFY_SUCCEEDED(WslShutdownVirtualMachine(vm, 30 * 1000));
VERIFY_SUCCEEDED(WslShutdownVirtualMachine(vm.get(), 30 * 1000));
auto future = callbackInfo.get_future();
auto result = future.wait_for(std::chrono::seconds(10));
auto [reason, details] = future.get();
VERIFY_ARE_EQUAL(reason, VirtualMachineTerminationReasonShutdown);
VERIFY_ARE_NOT_EQUAL(details, L"");
WslReleaseVirtualMachine(vm);
}
TEST_METHOD(CreateVmSmokeTest)
{
WSL2_TEST_ONLY();
VirtualMachineSettings settings{};
settings.CPU.CpuCount = 4;
settings.DisplayName = L"LSW";
@ -233,7 +247,7 @@ class LSWTests
createProcessSettings.FdCount = 3;
int pid = -1;
VERIFY_SUCCEEDED(WslCreateLinuxProcess(vm, &createProcessSettings, &pid));
VERIFY_SUCCEEDED(WslCreateLinuxProcess(vm.get(), &createProcessSettings, &pid));
LogInfo("pid: %lu", pid);
@ -249,7 +263,7 @@ class LSWTests
VERIFY_ARE_EQUAL(buffer.data(), std::string("foo\n"));
WaitResult result{};
VERIFY_SUCCEEDED(WslWaitForLinuxProcess(vm, pid, 1000, &result));
VERIFY_SUCCEEDED(WslWaitForLinuxProcess(vm.get(), pid, 1000, &result));
VERIFY_ARE_EQUAL(result.State, ProcessStateExited);
VERIFY_ARE_EQUAL(result.Code, 0);
}
@ -271,19 +285,19 @@ class LSWTests
createProcessSettings.FdCount = 3;
int pid = -1;
VERIFY_SUCCEEDED(WslCreateLinuxProcess(vm, &createProcessSettings, &pid));
VERIFY_SUCCEEDED(WslCreateLinuxProcess(vm.get(), &createProcessSettings, &pid));
// Verify that the process is in a running state
WaitResult result{};
VERIFY_SUCCEEDED(WslWaitForLinuxProcess(vm, pid, 1000, &result));
VERIFY_SUCCEEDED(WslWaitForLinuxProcess(vm.get(), pid, 1000, &result));
VERIFY_ARE_EQUAL(result.State, ProcessStateRunning);
// Verify that it can be killed.
VERIFY_SUCCEEDED(WslSignalLinuxProcess(vm, pid, 9));
VERIFY_SUCCEEDED(WslSignalLinuxProcess(vm.get(), pid, 9));
// Verify that the process is in a running state
VERIFY_SUCCEEDED(WslWaitForLinuxProcess(vm, pid, 1000, &result));
VERIFY_SUCCEEDED(WslWaitForLinuxProcess(vm.get(), pid, 1000, &result));
VERIFY_ARE_EQUAL(result.State, ProcessStateSignaled);
VERIFY_ARE_EQUAL(result.Code, 9);
}
@ -305,18 +319,18 @@ class LSWTests
createProcessSettings.FdCount = 3;
int pid = -1;
VERIFY_ARE_EQUAL(WslCreateLinuxProcess(vm, &createProcessSettings, &pid), E_FAIL);
VERIFY_ARE_EQUAL(WslCreateLinuxProcess(vm.get(), &createProcessSettings, &pid), E_FAIL);
WaitResult result{};
VERIFY_ARE_EQUAL(WslWaitForLinuxProcess(vm, 1234, 1000, &result), E_FAIL);
VERIFY_ARE_EQUAL(WslWaitForLinuxProcess(vm.get(), 1234, 1000, &result), E_FAIL);
VERIFY_ARE_EQUAL(result.State, ProcessStateUnknown);
}
WslReleaseVirtualMachine(vm);
}
TEST_METHOD(InteractiveShell)
{
WSL2_TEST_ONLY();
VirtualMachineSettings settings{};
settings.CPU.CpuCount = 4;
settings.DisplayName = L"LSW";
@ -342,7 +356,7 @@ class LSWTests
createProcessSettings.FdCount = static_cast<ULONG>(fds.size());
int pid = -1;
VERIFY_SUCCEEDED(WslCreateLinuxProcess(vm, &createProcessSettings, &pid));
VERIFY_SUCCEEDED(WslCreateLinuxProcess(vm.get(), &createProcessSettings, &pid));
auto validateTtyOutput = [&](const std::string& expected) {
std::string buffer(expected.size(), '\0');
@ -368,7 +382,7 @@ class LSWTests
};
// Expect the shell prompt to be displayed
validateTtyOutput("sh-5.1#");
validateTtyOutput("#");
writeTty("echo OK\n");
validateTtyOutput(" echo OK\r\nOK");
@ -384,6 +398,8 @@ class LSWTests
TEST_METHOD(NATNetworking)
{
WSL2_TEST_ONLY();
VirtualMachineSettings settings{};
settings.CPU.CpuCount = 4;
settings.DisplayName = L"LSW";
@ -396,13 +412,153 @@ class LSWTests
// Validate that eth0 has an ip address
VERIFY_ARE_EQUAL(
RunCommand(
vm,
vm.get(),
{"/bin/bash",
"-c",
"ip a show dev eth0 | grep -iF 'inet ' | grep -E '[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}'"}),
0);
// Verify that /etc/resolv.conf is configured
VERIFY_ARE_EQUAL(RunCommand(vm, {"/bin/grep", "-iF", "nameserver", "/etc/resolv.conf"}), 0);
VERIFY_ARE_EQUAL(RunCommand(vm.get(), {"/bin/grep", "-iF", "nameserver", "/etc/resolv.conf"}), 0);
}
TEST_METHOD(NATPortMapping)
{
WSL2_TEST_ONLY();
VirtualMachineSettings settings{};
settings.CPU.CpuCount = 4;
settings.DisplayName = L"LSW";
settings.Memory.MemoryMb = 2048;
settings.Options.BootTimeoutMs = 30 * 1000;
settings.Networking.Mode = NetworkingModeNAT;
auto vm = CreateVm(&settings);
auto waitForOutput = [](HANDLE Handle, const char* Content) {
std::string output;
DWORD index = 0;
while (true) // TODO: timeout
{
constexpr auto bufferSize = 100;
output.resize(output.size() + bufferSize);
DWORD bytesRead = 0;
if (!ReadFile(Handle, &output[index], bufferSize, &bytesRead, nullptr))
{
LogError("ReadFile failed with %lu", GetLastError());
VERIFY_FAIL();
}
output.resize(index + bytesRead);
if (bytesRead == 0)
{
LogError("Process exited, output: %hs", output.c_str());
VERIFY_FAIL();
}
index += bytesRead;
if (output.find(Content) != std::string::npos)
{
break;
}
}
};
auto listen = [&](short port, const char* content, bool ipv6) {
auto cmd = std::format("echo -n '{}' | /usr/bin/socat -dd TCP{}-LISTEN:{},reuseaddr -", content, ipv6 ? "6" : "", port);
auto [pid, in, out, err] = LaunchCommand(vm.get(), {"/bin/bash", "-c", cmd.c_str()});
waitForOutput(err.get(), "listening on");
return pid;
};
auto connectAndRead = [&](short port, int family) -> std::string {
SOCKADDR_INET addr{};
addr.si_family = family;
INETADDR_SETLOOPBACK((PSOCKADDR)&addr);
SS_PORT(&addr) = htons(port);
wil::unique_socket hostSocket{socket(family, SOCK_STREAM, IPPROTO_TCP)};
THROW_LAST_ERROR_IF(!hostSocket);
THROW_LAST_ERROR_IF(connect(hostSocket.get(), reinterpret_cast<SOCKADDR*>(&addr), sizeof(addr)) == SOCKET_ERROR);
return ReadToString(hostSocket.get());
};
auto expectContent = [&](short port, int family, const char* expected) {
auto content = connectAndRead(port, family);
VERIFY_ARE_EQUAL(content, expected);
};
auto expectNotBound = [&](short port, int family) {
auto result = wil::ResultFromException([&]() { connectAndRead(port, family); });
VERIFY_ARE_EQUAL(result, HRESULT_FROM_WIN32(WSAECONNREFUSED));
};
// Map port
PortMappingSettings port{1234, 80, AF_INET};
VERIFY_SUCCEEDED(WslMapPort(vm.get(), &port));
// Validate that the same port can't be bound twice
VERIFY_ARE_EQUAL(WslMapPort(vm.get(), &port), HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS));
// Check simple case
listen(80, "port80", false);
expectContent(1234, AF_INET, "port80");
// Validate that same port mapping can be reused
listen(80, "port80", false);
expectContent(1234, AF_INET, "port80");
// Validate that the connection is immediately reset if the port is not bound on the linux side
expectContent(1234, AF_INET, "");
// Add a ipv6 binding
PortMappingSettings portv6{1234, 80, AF_INET6};
VERIFY_SUCCEEDED(WslMapPort(vm.get(), &portv6));
// Validate that ipv6 bindings work as well.
listen(80, "port80ipv6", true);
expectContent(1234, AF_INET6, "port80ipv6");
// Unmap the ipv4 port
VERIFY_SUCCEEDED(WslUnmapPort(vm.get(), &port));
expectNotBound(1234, AF_INET);
// Verify that a proper error is returned if the mapping doesn't exist
VERIFY_ARE_EQUAL(WslUnmapPort(vm.get(), &port), HRESULT_FROM_WIN32(ERROR_NOT_FOUND));
// Unmap the v6 port
VERIFY_SUCCEEDED(WslUnmapPort(vm.get(), &portv6));
expectNotBound(1234, AF_INET6);
// Map another port as v6 only
PortMappingSettings portv6Only{1235, 81, AF_INET6};
VERIFY_SUCCEEDED(WslMapPort(vm.get(), &portv6Only));
listen(81, "port81ipv6", true);
expectContent(1235, AF_INET6, "port81ipv6");
expectNotBound(1235, AF_INET);
VERIFY_SUCCEEDED(WslUnmapPort(vm.get(), &portv6Only));
VERIFY_ARE_EQUAL(WslUnmapPort(vm.get(), &portv6Only), HRESULT_FROM_WIN32(ERROR_NOT_FOUND));
expectNotBound(1235, AF_INET6);
// Create a forking relay and stress test
VERIFY_SUCCEEDED(WslMapPort(vm.get(), &port));
auto [pid, in, out, err] =
LaunchCommand(vm.get(), {"/usr/bin/socat", "-dd", "TCP-LISTEN:80,fork,reuseaddr", "system:'echo -n OK'"});
waitForOutput(err.get(), "listening on");
for (auto i = 0; i < 100; i++)
{
expectContent(1234, AF_INET, "OK");
}
VERIFY_SUCCEEDED(WslUnmapPort(vm.get(), &port));
}
};