diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index e7fd057641..3a8a22ed55 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use futures::{FutureExt, TryFutureExt}; +use futures::TryFutureExt; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio_util::sync::CancellationToken; use tokio_util::task::task_tracker::TaskTrackerToken; @@ -15,7 +15,7 @@ use crate::metrics::{Metrics, NumClientConnectionsGuard}; use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol}; use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute}; use crate::proxy::handshake::{HandshakeData, handshake}; -use crate::proxy::passthrough::ProxyPassthrough; +use crate::proxy::passthrough::passthrough; use crate::proxy::{ClientRequestError, prepare_client_connection, run_until_cancelled}; pub async fn task_main( @@ -101,30 +101,28 @@ pub async fn task_main( &config.region, ); + let span = ctx.span(); + let mut slot = Some(ctx); let res = handle_client( config, backend, - &ctx, + &mut slot, cancellation_handler, socket, conn_gauge, tracker, ) - .instrument(ctx.span()) - .boxed() + .instrument(span) .await; - match res { - Err(e) => { + match (slot, res) { + (None, _) => {} + (Some(ctx), Ok(())) => { + ctx.success(); + } + (Some(ctx), Err(e)) => { ctx.set_error_kind(e.get_error_kind()); - error!(parent: &ctx.span(), "per-client task finished with an error: {e:#}"); - } - Ok(None) => { - ctx.set_success(); - } - Ok(Some(p)) => { - ctx.set_success(); - tokio::spawn(p.proxy_pass(ctx, &config.connect_to_compute)); + tracing::warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}"); } } }); @@ -140,40 +138,39 @@ pub async fn task_main( } #[allow(clippy::too_many_arguments)] -pub(crate) async fn handle_client( +pub(crate) async fn handle_client( config: &'static ProxyConfig, backend: &'static ConsoleRedirectBackend, - ctx: &RequestContext, + ctx_slot: &mut Option, cancellation_handler: Arc, stream: S, conn_gauge: NumClientConnectionsGuard<'static>, tracker: TaskTrackerToken, -) -> Result>, ClientRequestError> { - debug!( - protocol = %ctx.protocol(), - "handling interactive connection from client" - ); +) -> Result<(), ClientRequestError> { + let protocol = ctx_slot.as_ref().expect("context must be set").protocol(); + debug!(%protocol, "handling interactive connection from client"); let metrics = &Metrics::get().proxy; - let proto = ctx.protocol(); - let request_gauge = metrics.connection_requests.guard(proto); + let request_gauge = metrics.connection_requests.guard(protocol); let tls = config.tls_config.load(); let tls = tls.as_deref(); - let record_handshake_error = !ctx.has_private_peer_addr(); - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - let do_handshake = handshake(ctx, stream, tracker, tls, record_handshake_error); + let data = { + let ctx = ctx_slot.as_ref().expect("context must be set"); + let record_handshake_error = !ctx.has_private_peer_addr(); + let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + let do_handshake = handshake(ctx, stream, tracker, tls, record_handshake_error); + tokio::time::timeout(config.handshake_timeout, do_handshake).await?? + }; - let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake) - .await?? - { + let (mut stream, params) = match data { HandshakeData::Startup(stream, params) => (stream, params), HandshakeData::Cancel(cancel_key_data, tracker) => { // spawn a task to cancel the session, but don't wait for it tokio::spawn({ let cancellation_handler_clone = Arc::clone(&cancellation_handler); - let ctx = ctx.clone(); + let ctx = ctx_slot.take().expect("context must be set"); let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id()); cancel_span.follows_from(tracing::Span::current()); async move { @@ -193,11 +190,11 @@ pub(crate) async fn handle_client( .instrument(cancel_span) }); - return Ok(None); + return Ok(()); } }; - drop(pause); + let ctx = ctx_slot.as_ref().expect("context must be set"); ctx.set_db_options(params.clone()); let (node_info, user_info, _ip_allowlist) = match backend @@ -239,14 +236,19 @@ pub(crate) async fn handle_client( let (stream, read_buf, tracker) = stream.into_inner(); node.stream.write_all(&read_buf).await?; - Ok(Some(ProxyPassthrough { - client: stream, - private_link_id: None, - compute: node, - session_id: ctx.session_id(), - cancel: session, - _req: request_gauge, - _conn: conn_gauge, - _tracker: tracker, - })) + let ctx = ctx_slot.take().expect("context must be set"); + ctx.set_success(); + + tokio::spawn(passthrough( + ctx, + &config.connect_to_compute, + stream, + node, + session, + request_gauge, + conn_gauge, + tracker, + )); + + Ok(()) } diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 50d443f53d..7be27c19b0 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -13,10 +13,11 @@ pub use copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute}; use futures::{FutureExt, TryFutureExt}; use itertools::Itertools; use once_cell::sync::OnceCell; +use passthrough::passthrough; use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams}; use regex::Regex; use serde::{Deserialize, Serialize}; -use smol_str::{SmolStr, ToSmolStr, format_smolstr}; +use smol_str::{SmolStr, format_smolstr}; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio_util::sync::CancellationToken; @@ -24,13 +25,12 @@ use tokio_util::task::task_tracker::TaskTrackerToken; use tracing::{Instrument, debug, error, info, warn}; use self::connect_compute::{TcpMechanism, connect_to_compute}; -use self::passthrough::ProxyPassthrough; use crate::cancellation::{self, CancellationHandler}; use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig}; use crate::context::RequestContext; use crate::error::ReportableError; use crate::metrics::{Metrics, NumClientConnectionsGuard}; -use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol}; +use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol}; use crate::proxy::handshake::{HandshakeData, handshake}; use crate::rate_limiter::EndpointRateLimiter; use crate::stream::{PqStream, Stream}; @@ -412,23 +412,16 @@ pub(crate) async fn handle_client Some(vpce_id.clone()), - Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()), - None => None, - }; - - let p = ProxyPassthrough { - client: stream, - private_link_id, - compute: node, - session_id: ctx.session_id(), - cancel: session, - _req: request_gauge, - _conn: conn_gauge, - _tracker: tracker, - }; - tokio::spawn(p.proxy_pass(ctx, &config.connect_to_compute)); + tokio::spawn(passthrough( + ctx, + &config.connect_to_compute, + stream, + node, + session, + request_gauge, + conn_gauge, + tracker, + )); Ok(()) } diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index e9c4d0e2f4..052c3238ad 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -1,4 +1,4 @@ -use smol_str::SmolStr; +use smol_str::{SmolStr, ToSmolStr}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::task::task_tracker::TaskTrackerToken; use tracing::debug; @@ -11,6 +11,7 @@ use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::messages::MetricsAuxInfo; use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard}; +use crate::protocol2::ConnectionInfoExtra; use crate::stream::Stream; use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS}; @@ -62,60 +63,53 @@ pub(crate) async fn proxy_pass( Ok(()) } -pub(crate) struct ProxyPassthrough { - pub(crate) client: Stream, - pub(crate) compute: PostgresConnection, - pub(crate) session_id: uuid::Uuid, - pub(crate) private_link_id: Option, - pub(crate) cancel: cancellation::Session, +#[allow(clippy::too_many_arguments)] +pub(crate) async fn passthrough( + ctx: RequestContext, + compute_config: &'static ComputeConfig, - pub(crate) _req: NumConnectionRequestsGuard<'static>, - pub(crate) _conn: NumClientConnectionsGuard<'static>, - /// ensures proxy stays online while this is set. - pub(crate) _tracker: TaskTrackerToken, -} + client: Stream, + compute: PostgresConnection, + cancel: cancellation::Session, -impl ProxyPassthrough { - pub(crate) async fn proxy_pass( - self, - ctx: RequestContext, - compute_config: &'static ComputeConfig, - ) { - let _disconnect = ctx.log_connect(); - let res = proxy_pass( - self.client, - self.compute.stream, - self.compute.aux, - self.private_link_id, - ) - .await; + _req: NumConnectionRequestsGuard<'static>, + _conn: NumClientConnectionsGuard<'static>, + _tracker: TaskTrackerToken, +) { + let session_id = ctx.session_id(); + let private_link_id = match ctx.extra() { + Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()), + Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()), + None => None, + }; - match res { - Ok(()) => {} - Err(ErrorSource::Client(e)) => { - tracing::warn!( - session_id = ?self.session_id, - "per-client task finished with an IO error from the client: {e:#}" - ); - } - Err(ErrorSource::Compute(e)) => { - tracing::error!( - session_id = ?self.session_id, - "per-client task finished with an IO error from the compute: {e:#}" - ); - } + let _disconnect = ctx.log_connect(); + let res = proxy_pass(client, compute.stream, compute.aux, private_link_id).await; + + match res { + Ok(()) => {} + Err(ErrorSource::Client(e)) => { + tracing::warn!( + session_id = ?session_id, + "per-client task finished with an IO error from the client: {e:#}" + ); } - - if let Err(err) = self - .compute - .cancel_closure - .try_cancel_query(compute_config) - .await - { - tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database"); + Err(ErrorSource::Compute(e)) => { + tracing::error!( + session_id = ?session_id, + "per-client task finished with an IO error from the compute: {e:#}" + ); } - - // we don't need a result. If the queue is full, we just log the error - drop(self.cancel.remove_cancel_key()); } + + if let Err(err) = compute + .cancel_closure + .try_cancel_query(compute_config) + .await + { + tracing::warn!(session_id = ?session_id, ?err, "could not cancel the query in the database"); + } + + // we don't need a result. If the queue is full, we just log the error + drop(cancel.remove_cancel_key()); }