cli: allow exec server to listen on a port and require token authentication (#188434)

* cli: allow exec server to listen on a port and require token authentication

For remote ssh on Windows where pipe forwarding doesn't work

* fix linux build
This commit is contained in:
Connor Peet
2023-07-21 09:32:20 -07:00
committed by GitHub
parent fb031d4957
commit b5038f81d1
9 changed files with 121 additions and 40 deletions

View File

@@ -25,7 +25,7 @@ fn apply_build_environment_variables() {
}
let pkg_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
let mut cmd = Command::new("node");
let mut cmd = Command::new(env::var("NODE_PATH").unwrap_or_else(|_| "node".to_string()));
cmd.arg("../build/azure-pipelines/cli/prepare.js");
cmd.current_dir(&pkg_dir);
cmd.env("VSCODE_CLI_PREPARE_OUTPUT", "json");

View File

@@ -4,7 +4,10 @@
*--------------------------------------------------------------------------------------------*/
use crate::{constants::APPLICATION_NAME, util::errors::CodeError};
use async_trait::async_trait;
use std::path::{Path, PathBuf};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpListener;
use uuid::Uuid;
// todo: we could probably abstract this into some crate, if one doesn't already exist
@@ -39,7 +42,7 @@ cfg_if::cfg_if! {
pipe.into_split()
}
} else {
use tokio::{time::sleep, io::{AsyncRead, AsyncWrite, ReadBuf}};
use tokio::{time::sleep, io::ReadBuf};
use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions, NamedPipeClient, NamedPipeServer};
use std::{time::Duration, pin::Pin, task::{Context, Poll}, io};
use pin_project::pin_project;
@@ -181,3 +184,34 @@ pub fn get_socket_name() -> PathBuf {
}
}
}
pub type AcceptedRW = (
Box<dyn AsyncRead + Send + Unpin>,
Box<dyn AsyncWrite + Send + Unpin>,
);
#[async_trait]
pub trait AsyncRWAccepter {
async fn accept_rw(&mut self) -> Result<AcceptedRW, CodeError>;
}
#[async_trait]
impl AsyncRWAccepter for AsyncPipeListener {
async fn accept_rw(&mut self) -> Result<AcceptedRW, CodeError> {
let pipe = self.accept().await?;
let (read, write) = socket_stream_split(pipe);
Ok((Box::new(read), Box::new(write)))
}
}
#[async_trait]
impl AsyncRWAccepter for TcpListener {
async fn accept_rw(&mut self) -> Result<AcceptedRW, CodeError> {
let (stream, _) = self
.accept()
.await
.map_err(CodeError::AsyncPipeListenerFailed)?;
let (read, write) = tokio::io::split(stream);
Ok((Box::new(read), Box::new(write)))
}
}

View File

@@ -182,6 +182,12 @@ pub struct CommandShellArgs {
/// Listen on a socket instead of stdin/stdout.
#[clap(long)]
pub on_socket: bool,
/// Listen on a port instead of stdin/stdout.
#[clap(long)]
pub on_port: bool,
/// Require the given token string to be given in the handshake.
#[clap(long)]
pub require_token: Option<String>,
}
#[derive(Args, Debug, Clone)]

View File

@@ -20,7 +20,7 @@ use super::{
};
use crate::{
async_pipe::{get_socket_name, listen_socket_rw_stream, socket_stream_split},
async_pipe::{get_socket_name, listen_socket_rw_stream, AsyncRWAccepter},
auth::Auth,
constants::{APPLICATION_NAME, TUNNEL_CLI_LOCK_NAME, TUNNEL_SERVICE_LOCK_NAME},
log,
@@ -35,7 +35,7 @@ use crate::{
singleton_server::{
make_singleton_server, start_singleton_server, BroadcastLogSink, SingletonServerArgs,
},
Next, ServeStreamParams, ServiceContainer, ServiceManager,
AuthRequired, Next, ServeStreamParams, ServiceContainer, ServiceManager,
},
util::{
app_lock::AppMutex,
@@ -128,36 +128,52 @@ pub async fn command_shell(ctx: CommandContext, args: CommandShellArgs) -> Resul
log: ctx.log,
launcher_paths: ctx.paths,
platform,
requires_auth: true,
requires_auth: args
.require_token
.map(AuthRequired::VSDAWithToken)
.unwrap_or(AuthRequired::VSDA),
exit_barrier: ShutdownRequest::create_rx([ShutdownRequest::CtrlC]),
code_server_args: (&ctx.args).into(),
};
if !args.on_socket {
serve_stream(tokio::io::stdin(), tokio::io::stderr(), params).await;
return Ok(0);
}
let mut listener: Box<dyn AsyncRWAccepter> = match (args.on_port, args.on_socket) {
(_, true) => {
let socket = get_socket_name();
let listener = listen_socket_rw_stream(&socket)
.await
.map_err(|e| wrap(e, "error listening on socket"))?;
let socket = get_socket_name();
let mut listener = listen_socket_rw_stream(&socket)
.await
.map_err(|e| wrap(e, "error listening on socket"))?;
params
.log
.result(format!("Listening on {}", socket.display()));
params
.log
.result(format!("Listening on {}", socket.display()));
Box::new(listener)
}
(true, _) => {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.map_err(|e| wrap(e, "error listening on port"))?;
params
.log
.result(format!("Listening on {}", listener.local_addr().unwrap()));
Box::new(listener)
}
_ => {
serve_stream(tokio::io::stdin(), tokio::io::stderr(), params).await;
return Ok(0);
}
};
let mut servers = FuturesUnordered::new();
loop {
tokio::select! {
Some(_) = servers.next() => {},
socket = listener.accept() => {
socket = listener.accept_rw() => {
match socket {
Ok(s) => {
let (read, write) = socket_stream_split(s);
servers.push(serve_stream(read, write, params.clone()));
},
Ok((read, write)) => servers.push(serve_stream(read, write, params.clone())),
Err(e) => {
error!(params.log, &format!("Error accepting connection: {}", e));
return Ok(1);

View File

@@ -122,7 +122,7 @@ pub struct MsgPackCodec<T> {
impl<T> MsgPackCodec<T> {
pub fn new() -> Self {
Self {
_marker: std::marker::PhantomData::default(),
_marker: std::marker::PhantomData,
}
}
}

View File

@@ -34,7 +34,7 @@ mod service_macos;
mod service_windows;
mod socket_signal;
pub use control_server::{serve, serve_stream, Next, ServeStreamParams};
pub use control_server::{serve, serve_stream, Next, ServeStreamParams, AuthRequired};
pub use nosleep::SleepInhibitor;
pub use service::{
create_service_manager, ServiceContainer, ServiceManager, SERVICE_LOG_FILE_NAME,

View File

@@ -48,11 +48,11 @@ use super::dev_tunnels::ActiveTunnel;
use super::paths::prune_stopped_servers;
use super::port_forwarder::{PortForwarding, PortForwardingProcessor};
use super::protocol::{
AcquireCliParams, CallServerHttpParams, CallServerHttpResult, ChallengeIssueResponse,
ChallengeVerifyParams, ClientRequestMethod, EmptyObject, ForwardParams, ForwardResult,
FsStatRequest, FsStatResponse, GetEnvResponse, GetHostnameResponse, HttpBodyParams,
HttpHeadersParams, ServeParams, ServerLog, ServerMessageParams, SpawnParams, SpawnResult,
ToClientRequest, UnforwardParams, UpdateParams, UpdateResult, VersionResponse,
AcquireCliParams, CallServerHttpParams, CallServerHttpResult, ChallengeIssueParams,
ChallengeIssueResponse, ChallengeVerifyParams, ClientRequestMethod, EmptyObject, ForwardParams,
ForwardResult, FsStatRequest, FsStatResponse, GetEnvResponse, GetHostnameResponse,
HttpBodyParams, HttpHeadersParams, ServeParams, ServerLog, ServerMessageParams, SpawnParams,
SpawnResult, ToClientRequest, UnforwardParams, UpdateParams, UpdateResult, VersionResponse,
METHOD_CHALLENGE_VERIFY,
};
use super::server_bridge::ServerBridge;
@@ -94,8 +94,8 @@ struct HandlerContext {
/// Handler auth state.
enum AuthState {
/// Auth is required, we're waiting for the client to send its challenge.
WaitingForChallenge,
/// Auth is required, we're waiting for the client to send its challenge optionally bearing a token.
WaitingForChallenge(Option<String>),
/// A challenge has been issued. Waiting for a verification.
ChallengeIssued(String),
/// Auth is no longer required.
@@ -215,7 +215,7 @@ pub async fn serve(
code_server_args: own_code_server_args,
platform,
exit_barrier: own_exit,
requires_auth: false,
requires_auth: AuthRequired::None,
}).with_context(cx.clone()).await;
cx.span().add_event(
@@ -233,13 +233,20 @@ pub async fn serve(
}
}
#[derive(Clone)]
pub enum AuthRequired {
None,
VSDA,
VSDAWithToken(String),
}
#[derive(Clone)]
pub struct ServeStreamParams {
pub log: log::Logger,
pub launcher_paths: LauncherPaths,
pub code_server_args: CodeServerArgs,
pub platform: Platform,
pub requires_auth: bool,
pub requires_auth: AuthRequired,
pub exit_barrier: Barrier<ShutdownSignal>,
}
@@ -269,7 +276,7 @@ fn make_socket_rpc(
launcher_paths: LauncherPaths,
code_server_args: CodeServerArgs,
port_forwarding: Option<PortForwarding>,
requires_auth: bool,
requires_auth: AuthRequired,
platform: Platform,
) -> RpcDispatcher<MsgPackSerializer, HandlerContext> {
let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new()));
@@ -277,8 +284,9 @@ fn make_socket_rpc(
let mut rpc = RpcBuilder::new(MsgPackSerializer {}).methods(HandlerContext {
did_update: Arc::new(AtomicBool::new(false)),
auth_state: Arc::new(std::sync::Mutex::new(match requires_auth {
true => AuthState::WaitingForChallenge,
false => AuthState::Authenticated,
AuthRequired::VSDAWithToken(t) => AuthState::WaitingForChallenge(Some(t)),
AuthRequired::VSDA => AuthState::WaitingForChallenge(None),
AuthRequired::None => AuthState::Authenticated,
})),
socket_tx,
log: log.clone(),
@@ -305,8 +313,8 @@ fn make_socket_rpc(
ensure_auth(&c.auth_state)?;
handle_get_env()
});
rpc.register_sync(METHOD_CHALLENGE_ISSUE, |_: EmptyObject, c| {
handle_challenge_issue(&c.auth_state)
rpc.register_sync(METHOD_CHALLENGE_ISSUE, |p: ChallengeIssueParams, c| {
handle_challenge_issue(p, &c.auth_state)
});
rpc.register_sync(METHOD_CHALLENGE_VERIFY, |p: ChallengeVerifyParams, c| {
handle_challenge_verify(p.response, &c.auth_state)
@@ -423,6 +431,7 @@ async fn process_socket(
let rx_counter = Arc::new(AtomicUsize::new(0));
let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new()));
let already_authed = matches!(requires_auth, AuthRequired::None);
let rpc = make_socket_rpc(
log.clone(),
socket_tx.clone(),
@@ -440,7 +449,7 @@ async fn process_socket(
let socket_tx = socket_tx.clone();
let exit_barrier = exit_barrier.clone();
tokio::spawn(async move {
if !requires_auth {
if already_authed {
send_version(&socket_tx).await;
}
@@ -826,13 +835,22 @@ fn handle_get_env() -> Result<GetEnvResponse, AnyError> {
}
fn handle_challenge_issue(
params: ChallengeIssueParams,
auth_state: &Arc<std::sync::Mutex<AuthState>>,
) -> Result<ChallengeIssueResponse, AnyError> {
let challenge = create_challenge();
let mut auth_state = auth_state.lock().unwrap();
*auth_state = AuthState::ChallengeIssued(challenge.clone());
if let AuthState::WaitingForChallenge(Some(s)) = &*auth_state {
println!("looking for token {}, got {:?}", s, params.token);
match &params.token {
Some(t) if s != t => return Err(CodeError::AuthChallengeBadToken.into()),
None => return Err(CodeError::AuthChallengeBadToken.into()),
_ => {}
}
}
*auth_state = AuthState::ChallengeIssued(challenge.clone());
Ok(ChallengeIssueResponse { challenge })
}
@@ -844,7 +862,7 @@ fn handle_challenge_verify(
match &*auth_state {
AuthState::Authenticated => Ok(EmptyObject {}),
AuthState::WaitingForChallenge => Err(CodeError::AuthChallengeNotIssued.into()),
AuthState::WaitingForChallenge(_) => Err(CodeError::AuthChallengeNotIssued.into()),
AuthState::ChallengeIssued(c) => match verify_challenge(c, &response) {
false => Err(CodeError::AuthChallengeNotIssued.into()),
true => {

View File

@@ -199,6 +199,11 @@ pub struct SpawnResult {
pub const METHOD_CHALLENGE_ISSUE: &str = "challenge_issue";
pub const METHOD_CHALLENGE_VERIFY: &str = "challenge_verify";
#[derive(Serialize, Deserialize)]
pub struct ChallengeIssueParams {
pub token: Option<String>,
}
#[derive(Serialize, Deserialize)]
pub struct ChallengeIssueResponse {
pub challenge: String,

View File

@@ -509,6 +509,8 @@ pub enum CodeError {
ServerAuthRequired,
#[error("challenge not yet issued")]
AuthChallengeNotIssued,
#[error("challenge token is invalid")]
AuthChallengeBadToken,
#[error("unauthorized client refused")]
AuthMismatch,
#[error("keyring communication timed out after 5s")]