diff --git a/src/linux/init/WslDistributionConfig.h b/src/linux/init/WslDistributionConfig.h index 1d517ae..892240e 100644 --- a/src/linux/init/WslDistributionConfig.h +++ b/src/linux/init/WslDistributionConfig.h @@ -79,6 +79,7 @@ struct WslDistributionConfig bool GuiAppsEnabled = false; std::optional FeatureFlags = 0; std::optional NetworkingMode = LxMiniInitNetworkingModeNone; + std::optional VmId; // // Global state for boot state. The socket is used to delay-start the distro init process diff --git a/src/linux/init/config.cpp b/src/linux/init/config.cpp index e786a34..e205127 100644 --- a/src/linux/init/config.cpp +++ b/src/linux/init/config.cpp @@ -436,6 +436,18 @@ try ResponseChannel.SendResultMessage(static_cast(Config.NetworkingMode.value())); break; + case LxInitMessageQueryVmId: + { + wsl::shared::MessageWriter Response(LxInitMessageQueryVmId); + if (Config.VmId.has_value()) + { + Response.WriteString(Config.VmId.value()); + } + + ResponseChannel.SendMessage(Response.Span()); + break; + } + default: LOG_ERROR("unexpected message {}", Header->MessageType); break; @@ -1662,17 +1674,6 @@ Return Value: ConfigAppendNtPath(Environment, Buffer); } - // - // If the VM ID environment variable is present, add it to the environment - // block. - // - - auto VmId = getenv(LX_WSL2_VM_ID_ENV); - if (VmId) - { - Environment.AddVariable(LX_WSL2_VM_ID_ENV, VmId); - } - return Environment; } diff --git a/src/linux/init/init.cpp b/src/linux/init/init.cpp index 27450be..73f7e51 100644 --- a/src/linux/init/init.cpp +++ b/src/linux/init/init.cpp @@ -2239,13 +2239,27 @@ Return Value: unsetenv(LX_WSL2_DISTRO_READ_ONLY_ENV); } - const auto Value = getenv(LX_WSL2_NETWORKING_MODE_ENV); + auto Value = getenv(LX_WSL2_NETWORKING_MODE_ENV); if (Value != nullptr) { Config.NetworkingMode = static_cast(std::atoi(Value)); unsetenv(LX_WSL2_NETWORKING_MODE_ENV); } + Value = getenv(LX_WSL2_VM_ID_ENV); + if (Value != nullptr) + { + Config.VmId = Value; + + // Unset the environment variable for user distros. + // TODO: this can be removed when WSLg is updated to use `wslinfo --vm-id` instead of the environment variable. + Value = getenv(LX_WSL2_SYSTEM_DISTRO); + if (!Value || strcmp(Value, "1") != 0) + { + unsetenv(LX_WSL2_VM_ID_ENV); + } + } + // // If the boot.systemd option is specified in /etc/wsl.conf, launch the distro init process as pid 1. // WSL init and session leaders continue as children of the distro init process. diff --git a/src/linux/init/main.cpp b/src/linux/init/main.cpp index 0f9c7c0..2da5cd2 100644 --- a/src/linux/init/main.cpp +++ b/src/linux/init/main.cpp @@ -1827,6 +1827,11 @@ try AddEnvironmentVariable(LX_WSL2_USER_PROFILE, UserProfile); AddEnvironmentVariable(LX_WSL2_NETWORKING_MODE_ENV, std::to_string(static_cast(Config.NetworkingMode)).c_str()); + if (UserProfile) + { + AddEnvironmentVariable(LX_WSL2_SYSTEM_DISTRO, "1"); + } + if (DistroInitPid.has_value()) { AddEnvironmentVariable(LX_WSL2_DISTRO_INIT_PID, std::to_string(static_cast(DistroInitPid.value())).c_str()); diff --git a/src/linux/init/util.cpp b/src/linux/init/util.cpp index 04e9f5a..ed214f3 100644 --- a/src/linux/init/util.cpp +++ b/src/linux/init/util.cpp @@ -1201,8 +1201,7 @@ std::optional UtilGetNetworkingMode(void) Routine Description: - This routine gets the feature flags, either directly, from an environment - variable, or by querying it from the init process. + This routine queries the networking mode from the init process. Arguments: @@ -1210,16 +1209,12 @@ Arguments: Return Value: - The feature flags. + The networking mode if successful, std::nullopt otherwise. --*/ try { - // - // Query init for the value. - // - wsl::shared::SocketChannel channel{UtilConnectUnix(WSL_INIT_INTEROP_SOCKET), "wslinfo"}; THROW_LAST_ERROR_IF(channel.Socket() < 0); @@ -1297,6 +1292,40 @@ Return Value: return Result; } +std::string UtilGetVmId(void) + +/*++ + +Routine Description: + + This routine queries the VM ID from the init process. + +Arguments: + + None. + +Return Value: + + The VM ID if successful, an empty string otherwise. + +--*/ + +try +{ + wsl::shared::SocketChannel channel{UtilConnectUnix(WSL_INIT_INTEROP_SOCKET), "wslinfo"}; + THROW_LAST_ERROR_IF(channel.Socket() < 0); + + wsl::shared::MessageWriter Message(LxInitMessageQueryVmId); + channel.SendMessage(Message.Span()); + + return channel.ReceiveMessage().Buffer; +} +catch (...) +{ + LOG_CAUGHT_EXCEPTION(); + return {}; +} + void UtilInitGroups(const char* User, gid_t Gid) /*++ diff --git a/src/linux/init/util.h b/src/linux/init/util.h index c97bfdc..b90fa6c 100644 --- a/src/linux/init/util.h +++ b/src/linux/init/util.h @@ -227,6 +227,8 @@ std::optional UtilGetNetworkingMode(void); pid_t UtilGetPpid(pid_t Pid); +std::string UtilGetVmId(void); + void UtilInitGroups(const char* User, gid_t Gid); void UtilInitializeMessageBuffer(std::vector& Buffer); diff --git a/src/linux/init/wslinfo.cpp b/src/linux/init/wslinfo.cpp index 5605b6c..13538db 100644 --- a/src/linux/init/wslinfo.cpp +++ b/src/linux/init/wslinfo.cpp @@ -149,14 +149,21 @@ Return Value: } else if (Mode.value() == WslInfoMode::VMId) { - auto value = UtilGetEnvironmentVariable(LX_WSL2_VM_ID_ENV); - if (value.empty()) + if (UtilIsUtilityVm()) { - std::cerr << Localization::MessageNoValueFound() << "\n"; - return 1; - } + auto vmId = UtilGetVmId(); + if (vmId.empty()) + { + std::cerr << Localization::MessageNoValueFound() << "\n"; + return 1; + } - std::cout << value; + std::cout << vmId; + } + else + { + std::cout << "wsl1"; + } } else { diff --git a/src/shared/inc/lxinitshared.h b/src/shared/inc/lxinitshared.h index 145afdf..454ac08 100644 --- a/src/shared/inc/lxinitshared.h +++ b/src/shared/inc/lxinitshared.h @@ -255,6 +255,7 @@ Abstract: #define LX_WSL2_SHARED_MEMORY_OB_DIRECTORY "WSL2_SHARED_MEMORY_OB_DIRECTORY" #define LX_WSL2_INSTALL_PATH "WSL2_INSTALL_PATH" #define LX_WSL2_SAFE_MODE "WSL2_SAFE_MODE" +#define LX_WSL2_SYSTEM_DISTRO "WSL2_SYSTEM_DISTRO" #define LX_WSL2_USER_PROFILE "WSL2_USER_PROFILE" #define LX_WSL2_VM_ID_ENV "WSL2_VM_ID" #define LX_WSL_PID_ENV "WSL2_PID" @@ -304,6 +305,7 @@ typedef enum _LX_MESSAGE_TYPE LxInitMessageCreateLoginSession, LxInitMessageStopPlan9Server, LxInitMessageQueryNetworkingMode, + LxInitMessageQueryVmId, LxInitCreateProcess, LxInitOobeResult, LxMiniInitMessageLaunchInit, @@ -1496,6 +1498,16 @@ typedef struct _LX_INIT_OOBE_RESULT PRETTY_PRINT(FIELD(Header), FIELD(Result), FIELD(DefaultUid)); } LX_INIT_OOBE_RESULT, *PLX_INIT_OOBE_RESULT; +typedef struct _LX_INIT_QUERY_VM_ID +{ + static inline auto Type = LxInitMessageQueryVmId; + + MESSAGE_HEADER Header; + char Buffer[]; + + PRETTY_PRINT(FIELD(Header), FIELD(Buffer)); +} LX_INIT_QUERY_VM_ID, *PLX_INIT_QUERY_VM_ID; + template <> struct std::formatter { diff --git a/test/windows/UnitTests.cpp b/test/windows/UnitTests.cpp index 8233bfe..aabb72a 100644 --- a/test/windows/UnitTests.cpp +++ b/test/windows/UnitTests.cpp @@ -843,16 +843,25 @@ class UnitTests L"arguments.\n"); } - if (LxsstuVmMode()) { - // Get the VM ID from the distro and validate that it not null. - auto [vmId, vmIdErr] = LxsstuLaunchWslAndCaptureOutput(L"env | grep 'WSL2_VM_ID' | awk -F= '{print $2}'"); - VERIFY_ARE_NOT_EQUAL(vmId, L""); - - // Ensure that the response from wslinfo matches the VM id from the distros environment - auto [out, err] = LxsstuLaunchWslAndCaptureOutput(L"wslinfo --vm-id"); - VERIFY_ARE_EQUAL(out, std::format(L"{}", vmId)); + auto [out, err] = LxsstuLaunchWslAndCaptureOutput(L"wslinfo --vm-id -n"); VERIFY_ARE_EQUAL(err, L""); + if (LxsstuVmMode()) + { + // Ensure that the response from wslinfo has the VM ID. + auto guid = wsl::shared::string::ToGuid(out); + VERIFY_IS_TRUE(guid.has_value()); + VERIFY_IS_FALSE(IsEqualGUID(guid.value(), GUID_NULL)); + + // Validate that the VM ID is not propagated to user commands. + std::tie(out, err) = LxsstuLaunchWslAndCaptureOutput(L"echo -n \"$WSL2_VM_ID\""); + VERIFY_ARE_EQUAL(out, L""); + VERIFY_ARE_EQUAL(err, L""); + } + else + { + VERIFY_ARE_EQUAL(out, L"wsl1"); + } } }