mirror of
https://github.com/microsoft/vscode.git
synced 2026-04-17 12:10:22 -05:00
cli: allow exec server to listen on a port and require token authentication (#188434)
* cli: allow exec server to listen on a port and require token authentication For remote ssh on Windows where pipe forwarding doesn't work * fix linux build
This commit is contained in:
@@ -25,7 +25,7 @@ fn apply_build_environment_variables() {
|
||||
}
|
||||
|
||||
let pkg_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
|
||||
let mut cmd = Command::new("node");
|
||||
let mut cmd = Command::new(env::var("NODE_PATH").unwrap_or_else(|_| "node".to_string()));
|
||||
cmd.arg("../build/azure-pipelines/cli/prepare.js");
|
||||
cmd.current_dir(&pkg_dir);
|
||||
cmd.env("VSCODE_CLI_PREPARE_OUTPUT", "json");
|
||||
|
||||
@@ -4,7 +4,10 @@
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
use crate::{constants::APPLICATION_NAME, util::errors::CodeError};
|
||||
use async_trait::async_trait;
|
||||
use std::path::{Path, PathBuf};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpListener;
|
||||
use uuid::Uuid;
|
||||
|
||||
// todo: we could probably abstract this into some crate, if one doesn't already exist
|
||||
@@ -39,7 +42,7 @@ cfg_if::cfg_if! {
|
||||
pipe.into_split()
|
||||
}
|
||||
} else {
|
||||
use tokio::{time::sleep, io::{AsyncRead, AsyncWrite, ReadBuf}};
|
||||
use tokio::{time::sleep, io::ReadBuf};
|
||||
use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions, NamedPipeClient, NamedPipeServer};
|
||||
use std::{time::Duration, pin::Pin, task::{Context, Poll}, io};
|
||||
use pin_project::pin_project;
|
||||
@@ -181,3 +184,34 @@ pub fn get_socket_name() -> PathBuf {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type AcceptedRW = (
|
||||
Box<dyn AsyncRead + Send + Unpin>,
|
||||
Box<dyn AsyncWrite + Send + Unpin>,
|
||||
);
|
||||
|
||||
#[async_trait]
|
||||
pub trait AsyncRWAccepter {
|
||||
async fn accept_rw(&mut self) -> Result<AcceptedRW, CodeError>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AsyncRWAccepter for AsyncPipeListener {
|
||||
async fn accept_rw(&mut self) -> Result<AcceptedRW, CodeError> {
|
||||
let pipe = self.accept().await?;
|
||||
let (read, write) = socket_stream_split(pipe);
|
||||
Ok((Box::new(read), Box::new(write)))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AsyncRWAccepter for TcpListener {
|
||||
async fn accept_rw(&mut self) -> Result<AcceptedRW, CodeError> {
|
||||
let (stream, _) = self
|
||||
.accept()
|
||||
.await
|
||||
.map_err(CodeError::AsyncPipeListenerFailed)?;
|
||||
let (read, write) = tokio::io::split(stream);
|
||||
Ok((Box::new(read), Box::new(write)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -182,6 +182,12 @@ pub struct CommandShellArgs {
|
||||
/// Listen on a socket instead of stdin/stdout.
|
||||
#[clap(long)]
|
||||
pub on_socket: bool,
|
||||
/// Listen on a port instead of stdin/stdout.
|
||||
#[clap(long)]
|
||||
pub on_port: bool,
|
||||
/// Require the given token string to be given in the handshake.
|
||||
#[clap(long)]
|
||||
pub require_token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Args, Debug, Clone)]
|
||||
|
||||
@@ -20,7 +20,7 @@ use super::{
|
||||
};
|
||||
|
||||
use crate::{
|
||||
async_pipe::{get_socket_name, listen_socket_rw_stream, socket_stream_split},
|
||||
async_pipe::{get_socket_name, listen_socket_rw_stream, AsyncRWAccepter},
|
||||
auth::Auth,
|
||||
constants::{APPLICATION_NAME, TUNNEL_CLI_LOCK_NAME, TUNNEL_SERVICE_LOCK_NAME},
|
||||
log,
|
||||
@@ -35,7 +35,7 @@ use crate::{
|
||||
singleton_server::{
|
||||
make_singleton_server, start_singleton_server, BroadcastLogSink, SingletonServerArgs,
|
||||
},
|
||||
Next, ServeStreamParams, ServiceContainer, ServiceManager,
|
||||
AuthRequired, Next, ServeStreamParams, ServiceContainer, ServiceManager,
|
||||
},
|
||||
util::{
|
||||
app_lock::AppMutex,
|
||||
@@ -128,36 +128,52 @@ pub async fn command_shell(ctx: CommandContext, args: CommandShellArgs) -> Resul
|
||||
log: ctx.log,
|
||||
launcher_paths: ctx.paths,
|
||||
platform,
|
||||
requires_auth: true,
|
||||
requires_auth: args
|
||||
.require_token
|
||||
.map(AuthRequired::VSDAWithToken)
|
||||
.unwrap_or(AuthRequired::VSDA),
|
||||
exit_barrier: ShutdownRequest::create_rx([ShutdownRequest::CtrlC]),
|
||||
code_server_args: (&ctx.args).into(),
|
||||
};
|
||||
|
||||
if !args.on_socket {
|
||||
serve_stream(tokio::io::stdin(), tokio::io::stderr(), params).await;
|
||||
return Ok(0);
|
||||
}
|
||||
let mut listener: Box<dyn AsyncRWAccepter> = match (args.on_port, args.on_socket) {
|
||||
(_, true) => {
|
||||
let socket = get_socket_name();
|
||||
let listener = listen_socket_rw_stream(&socket)
|
||||
.await
|
||||
.map_err(|e| wrap(e, "error listening on socket"))?;
|
||||
|
||||
let socket = get_socket_name();
|
||||
let mut listener = listen_socket_rw_stream(&socket)
|
||||
.await
|
||||
.map_err(|e| wrap(e, "error listening on socket"))?;
|
||||
params
|
||||
.log
|
||||
.result(format!("Listening on {}", socket.display()));
|
||||
|
||||
params
|
||||
.log
|
||||
.result(format!("Listening on {}", socket.display()));
|
||||
Box::new(listener)
|
||||
}
|
||||
(true, _) => {
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.map_err(|e| wrap(e, "error listening on port"))?;
|
||||
|
||||
params
|
||||
.log
|
||||
.result(format!("Listening on {}", listener.local_addr().unwrap()));
|
||||
|
||||
Box::new(listener)
|
||||
}
|
||||
_ => {
|
||||
serve_stream(tokio::io::stdin(), tokio::io::stderr(), params).await;
|
||||
return Ok(0);
|
||||
}
|
||||
};
|
||||
|
||||
let mut servers = FuturesUnordered::new();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
Some(_) = servers.next() => {},
|
||||
socket = listener.accept() => {
|
||||
socket = listener.accept_rw() => {
|
||||
match socket {
|
||||
Ok(s) => {
|
||||
let (read, write) = socket_stream_split(s);
|
||||
servers.push(serve_stream(read, write, params.clone()));
|
||||
},
|
||||
Ok((read, write)) => servers.push(serve_stream(read, write, params.clone())),
|
||||
Err(e) => {
|
||||
error!(params.log, &format!("Error accepting connection: {}", e));
|
||||
return Ok(1);
|
||||
|
||||
@@ -122,7 +122,7 @@ pub struct MsgPackCodec<T> {
|
||||
impl<T> MsgPackCodec<T> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
_marker: std::marker::PhantomData::default(),
|
||||
_marker: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ mod service_macos;
|
||||
mod service_windows;
|
||||
mod socket_signal;
|
||||
|
||||
pub use control_server::{serve, serve_stream, Next, ServeStreamParams};
|
||||
pub use control_server::{serve, serve_stream, Next, ServeStreamParams, AuthRequired};
|
||||
pub use nosleep::SleepInhibitor;
|
||||
pub use service::{
|
||||
create_service_manager, ServiceContainer, ServiceManager, SERVICE_LOG_FILE_NAME,
|
||||
|
||||
@@ -48,11 +48,11 @@ use super::dev_tunnels::ActiveTunnel;
|
||||
use super::paths::prune_stopped_servers;
|
||||
use super::port_forwarder::{PortForwarding, PortForwardingProcessor};
|
||||
use super::protocol::{
|
||||
AcquireCliParams, CallServerHttpParams, CallServerHttpResult, ChallengeIssueResponse,
|
||||
ChallengeVerifyParams, ClientRequestMethod, EmptyObject, ForwardParams, ForwardResult,
|
||||
FsStatRequest, FsStatResponse, GetEnvResponse, GetHostnameResponse, HttpBodyParams,
|
||||
HttpHeadersParams, ServeParams, ServerLog, ServerMessageParams, SpawnParams, SpawnResult,
|
||||
ToClientRequest, UnforwardParams, UpdateParams, UpdateResult, VersionResponse,
|
||||
AcquireCliParams, CallServerHttpParams, CallServerHttpResult, ChallengeIssueParams,
|
||||
ChallengeIssueResponse, ChallengeVerifyParams, ClientRequestMethod, EmptyObject, ForwardParams,
|
||||
ForwardResult, FsStatRequest, FsStatResponse, GetEnvResponse, GetHostnameResponse,
|
||||
HttpBodyParams, HttpHeadersParams, ServeParams, ServerLog, ServerMessageParams, SpawnParams,
|
||||
SpawnResult, ToClientRequest, UnforwardParams, UpdateParams, UpdateResult, VersionResponse,
|
||||
METHOD_CHALLENGE_VERIFY,
|
||||
};
|
||||
use super::server_bridge::ServerBridge;
|
||||
@@ -94,8 +94,8 @@ struct HandlerContext {
|
||||
|
||||
/// Handler auth state.
|
||||
enum AuthState {
|
||||
/// Auth is required, we're waiting for the client to send its challenge.
|
||||
WaitingForChallenge,
|
||||
/// Auth is required, we're waiting for the client to send its challenge optionally bearing a token.
|
||||
WaitingForChallenge(Option<String>),
|
||||
/// A challenge has been issued. Waiting for a verification.
|
||||
ChallengeIssued(String),
|
||||
/// Auth is no longer required.
|
||||
@@ -215,7 +215,7 @@ pub async fn serve(
|
||||
code_server_args: own_code_server_args,
|
||||
platform,
|
||||
exit_barrier: own_exit,
|
||||
requires_auth: false,
|
||||
requires_auth: AuthRequired::None,
|
||||
}).with_context(cx.clone()).await;
|
||||
|
||||
cx.span().add_event(
|
||||
@@ -233,13 +233,20 @@ pub async fn serve(
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum AuthRequired {
|
||||
None,
|
||||
VSDA,
|
||||
VSDAWithToken(String),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ServeStreamParams {
|
||||
pub log: log::Logger,
|
||||
pub launcher_paths: LauncherPaths,
|
||||
pub code_server_args: CodeServerArgs,
|
||||
pub platform: Platform,
|
||||
pub requires_auth: bool,
|
||||
pub requires_auth: AuthRequired,
|
||||
pub exit_barrier: Barrier<ShutdownSignal>,
|
||||
}
|
||||
|
||||
@@ -269,7 +276,7 @@ fn make_socket_rpc(
|
||||
launcher_paths: LauncherPaths,
|
||||
code_server_args: CodeServerArgs,
|
||||
port_forwarding: Option<PortForwarding>,
|
||||
requires_auth: bool,
|
||||
requires_auth: AuthRequired,
|
||||
platform: Platform,
|
||||
) -> RpcDispatcher<MsgPackSerializer, HandlerContext> {
|
||||
let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new()));
|
||||
@@ -277,8 +284,9 @@ fn make_socket_rpc(
|
||||
let mut rpc = RpcBuilder::new(MsgPackSerializer {}).methods(HandlerContext {
|
||||
did_update: Arc::new(AtomicBool::new(false)),
|
||||
auth_state: Arc::new(std::sync::Mutex::new(match requires_auth {
|
||||
true => AuthState::WaitingForChallenge,
|
||||
false => AuthState::Authenticated,
|
||||
AuthRequired::VSDAWithToken(t) => AuthState::WaitingForChallenge(Some(t)),
|
||||
AuthRequired::VSDA => AuthState::WaitingForChallenge(None),
|
||||
AuthRequired::None => AuthState::Authenticated,
|
||||
})),
|
||||
socket_tx,
|
||||
log: log.clone(),
|
||||
@@ -305,8 +313,8 @@ fn make_socket_rpc(
|
||||
ensure_auth(&c.auth_state)?;
|
||||
handle_get_env()
|
||||
});
|
||||
rpc.register_sync(METHOD_CHALLENGE_ISSUE, |_: EmptyObject, c| {
|
||||
handle_challenge_issue(&c.auth_state)
|
||||
rpc.register_sync(METHOD_CHALLENGE_ISSUE, |p: ChallengeIssueParams, c| {
|
||||
handle_challenge_issue(p, &c.auth_state)
|
||||
});
|
||||
rpc.register_sync(METHOD_CHALLENGE_VERIFY, |p: ChallengeVerifyParams, c| {
|
||||
handle_challenge_verify(p.response, &c.auth_state)
|
||||
@@ -423,6 +431,7 @@ async fn process_socket(
|
||||
let rx_counter = Arc::new(AtomicUsize::new(0));
|
||||
let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new()));
|
||||
|
||||
let already_authed = matches!(requires_auth, AuthRequired::None);
|
||||
let rpc = make_socket_rpc(
|
||||
log.clone(),
|
||||
socket_tx.clone(),
|
||||
@@ -440,7 +449,7 @@ async fn process_socket(
|
||||
let socket_tx = socket_tx.clone();
|
||||
let exit_barrier = exit_barrier.clone();
|
||||
tokio::spawn(async move {
|
||||
if !requires_auth {
|
||||
if already_authed {
|
||||
send_version(&socket_tx).await;
|
||||
}
|
||||
|
||||
@@ -826,13 +835,22 @@ fn handle_get_env() -> Result<GetEnvResponse, AnyError> {
|
||||
}
|
||||
|
||||
fn handle_challenge_issue(
|
||||
params: ChallengeIssueParams,
|
||||
auth_state: &Arc<std::sync::Mutex<AuthState>>,
|
||||
) -> Result<ChallengeIssueResponse, AnyError> {
|
||||
let challenge = create_challenge();
|
||||
|
||||
let mut auth_state = auth_state.lock().unwrap();
|
||||
*auth_state = AuthState::ChallengeIssued(challenge.clone());
|
||||
if let AuthState::WaitingForChallenge(Some(s)) = &*auth_state {
|
||||
println!("looking for token {}, got {:?}", s, params.token);
|
||||
match ¶ms.token {
|
||||
Some(t) if s != t => return Err(CodeError::AuthChallengeBadToken.into()),
|
||||
None => return Err(CodeError::AuthChallengeBadToken.into()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
*auth_state = AuthState::ChallengeIssued(challenge.clone());
|
||||
Ok(ChallengeIssueResponse { challenge })
|
||||
}
|
||||
|
||||
@@ -844,7 +862,7 @@ fn handle_challenge_verify(
|
||||
|
||||
match &*auth_state {
|
||||
AuthState::Authenticated => Ok(EmptyObject {}),
|
||||
AuthState::WaitingForChallenge => Err(CodeError::AuthChallengeNotIssued.into()),
|
||||
AuthState::WaitingForChallenge(_) => Err(CodeError::AuthChallengeNotIssued.into()),
|
||||
AuthState::ChallengeIssued(c) => match verify_challenge(c, &response) {
|
||||
false => Err(CodeError::AuthChallengeNotIssued.into()),
|
||||
true => {
|
||||
|
||||
@@ -199,6 +199,11 @@ pub struct SpawnResult {
|
||||
pub const METHOD_CHALLENGE_ISSUE: &str = "challenge_issue";
|
||||
pub const METHOD_CHALLENGE_VERIFY: &str = "challenge_verify";
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ChallengeIssueParams {
|
||||
pub token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ChallengeIssueResponse {
|
||||
pub challenge: String,
|
||||
|
||||
@@ -509,6 +509,8 @@ pub enum CodeError {
|
||||
ServerAuthRequired,
|
||||
#[error("challenge not yet issued")]
|
||||
AuthChallengeNotIssued,
|
||||
#[error("challenge token is invalid")]
|
||||
AuthChallengeBadToken,
|
||||
#[error("unauthorized client refused")]
|
||||
AuthMismatch,
|
||||
#[error("keyring communication timed out after 5s")]
|
||||
|
||||
Reference in New Issue
Block a user