From eb96abbff77f946c59db27e1ab858e7347693742 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Wed, 30 Oct 2024 12:18:22 +0000 Subject: [PATCH] moar --- proxy/src/serverless/mod.rs | 141 ++++++++++++++++++++++-------------- 1 file changed, 86 insertions(+), 55 deletions(-) diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index 4e22c5a89c..aed5abda31 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -26,7 +26,7 @@ use atomic_take::AtomicTake; use bytes::Bytes; pub use conn_pool_lib::GlobalConnPoolOptions; use futures::future::{select, Either}; -use futures::{FutureExt, TryFutureExt}; +use futures::FutureExt; use http::{Method, Response, StatusCode}; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Empty}; @@ -44,7 +44,7 @@ use tokio::task::JoinHandle; use tokio::time::timeout; use tokio_rustls::TlsAcceptor; use tokio_util::sync::CancellationToken; -use tokio_util::task::TaskTracker; +use tokio_util::task::task_tracker::TaskTrackerToken; use tracing::{debug, info, warn, Instrument}; use utils::http::error::ApiError; @@ -159,16 +159,18 @@ pub async fn task_main( } } - let conn_token = cancellation_token.child_token(); + let conn_cancellation_token = cancellation_token.child_token(); let tls_acceptor = tls_acceptor.clone(); let backend = backend.clone(); - let connections2 = connections.clone(); let cancellation_handler = cancellation_handler.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); - connections.spawn( + let connection_token = connections.token(); + tokio::spawn( async move { - let conn_token2 = conn_token.clone(); - let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2); + let _cancel_guard = config + .http_config + .cancel_set + .insert(conn_id, conn_cancellation_token.clone()); let session_id = uuid::Uuid::new_v4(); @@ -184,27 +186,19 @@ pub async fn task_main( let Some(conn) = startup_result else { return; }; - let (_, peer_addr, _) = conn; let ws_upgrade = http_connection_handler( config, backend, - connections2, - conn_token, + conn_cancellation_token, + connection_token.clone(), conn, session_id, ) .boxed() .await; - if let Some((session_id, host, websocket)) = ws_upgrade { - let ctx = RequestMonitoring::new( - session_id, - peer_addr, - crate::metrics::Protocol::Ws, - &config.region, - ); - + if let Some((ctx, host, websocket)) = ws_upgrade { let ws = websocket::serve_websocket( config, auth_backend, @@ -350,8 +344,8 @@ impl GracefulShutdown async fn http_connection_handler( config: &'static ProxyConfig, backend: Arc, - connections: TaskTracker, cancellation_token: CancellationToken, + connection_token: TaskTrackerToken, conn: ConnWithInfo, session_id: uuid::Uuid, ) -> Option { @@ -377,7 +371,7 @@ async fn http_connection_handler( let service = ProxyService { config, backend, - connections, + connection_token, http_cancellation_token, ws_tx, @@ -434,9 +428,9 @@ struct ProxyService { // global state config: &'static ProxyConfig, backend: Arc, - connections: TaskTracker, // connection state only + connection_token: TaskTrackerToken, http_cancellation_token: CancellationToken, ws_tx: WsSpawner, peer_addr: IpAddr, @@ -468,23 +462,18 @@ impl hyper::service::Service> for ProxyService { // Cancel the current inflight HTTP request if the requets stream is closed. // This is slightly different to `_cancel_connection` in that // h2 can cancel individual requests with a `RST_STREAM`. - let http_request_token = self.http_cancellation_token.child_token(); - let cancel_request = Some(http_request_token.clone().drop_guard()); + let http_req_cancellation_token = self.http_cancellation_token.child_token(); + let cancel_request = Some(http_req_cancellation_token.clone().drop_guard()); - // `request_handler` is not cancel safe. It expects to be cancelled only at specific times. - // By spawning the future, we ensure it never gets cancelled until it decides to. - let handle = self.connections.spawn( - request_handler( - req, - self.config, - self.backend.clone(), - self.ws_tx.clone(), - session_id, - self.peer_addr, - http_request_token, - ) - .in_current_span() - .map_ok_or_else(api_error_into_response, |r| r), + let handle = request_handler( + req, + self.config, + self.backend.clone(), + self.ws_tx.clone(), + session_id, + self.peer_addr, + http_req_cancellation_token, + &self.connection_token, ); ReqFut { @@ -498,14 +487,25 @@ impl hyper::service::Service> for ProxyService { struct ReqFut { session_id: uuid::Uuid, cancel_request: Option, - handle: JoinHandle>>, + handle: HandleOrResponse, +} + +enum HandleOrResponse { + Handle(JoinHandle>>), + Response(Option>>), } impl Future for ReqFut { type Output = Result>, tokio::task::JoinError>; fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { - let mut res = std::task::ready!(self.handle.poll_unpin(cx)); + let mut res = match &mut self.handle { + HandleOrResponse::Handle(join_handle) => std::task::ready!(join_handle.poll_unpin(cx)), + HandleOrResponse::Response(response) => { + Ok(response.take().expect("polled after completion")) + } + }; + self.cancel_request .take() .map(tokio_util::sync::DropGuard::disarm); @@ -520,10 +520,11 @@ impl Future for ReqFut { } } -type WsUpgrade = (uuid::Uuid, Option, OnUpgrade); +type WsUpgrade = (RequestMonitoring, Option, OnUpgrade); type WsSpawner = Arc>>; -async fn request_handler( +#[allow(clippy::too_many_arguments)] +fn request_handler( mut request: hyper::Request, config: &'static ProxyConfig, backend: Arc, @@ -532,15 +533,25 @@ async fn request_handler( peer_addr: IpAddr, // used to cancel in-flight HTTP requests. not used to cancel websockets http_cancellation_token: CancellationToken, -) -> Result>, ApiError> { + connection_token: &TaskTrackerToken, +) -> HandleOrResponse { // Check if the request is a websocket upgrade request. if framed_websockets::upgrade::is_upgrade_request(&request) { let Some(spawner) = ws_spawner.take() else { - return json_response(StatusCode::BAD_REQUEST, "query is not supported"); + return HandleOrResponse::Response(Some( + json_response(StatusCode::BAD_REQUEST, "query is not supported") + .unwrap_or_else(api_error_into_response), + )); }; - let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request) - .map_err(|e| ApiError::BadRequest(e.into()))?; + let (response, websocket) = match framed_websockets::upgrade::upgrade(&mut request) { + Err(e) => { + return HandleOrResponse::Response(Some(api_error_into_response( + ApiError::BadRequest(e.into()), + ))) + } + Ok(upgrade) => upgrade, + }; let host = request .headers() @@ -549,12 +560,21 @@ async fn request_handler( .and_then(|h| h.split(':').next()) .map(|s| s.to_string()); - spawner.send((session_id, host, websocket)).map_err(|_e| { - ApiError::InternalServerError(anyhow::anyhow!("could not upgrade WS connection")) - })?; + let ctx = RequestMonitoring::new( + session_id, + peer_addr, + crate::metrics::Protocol::Ws, + &config.region, + ); + + if let Err(_e) = spawner.send((ctx, host, websocket)) { + return HandleOrResponse::Response(Some(api_error_into_response( + ApiError::InternalServerError(anyhow::anyhow!("could not upgrade WS connection")), + ))); + } // Return the response so the spawned future can continue. - Ok(response.map(|b| b.map_err(|x| match x {}).boxed())) + HandleOrResponse::Response(Some(response.map(|b| b.map_err(|x| match x {}).boxed()))) } else if request.uri().path() == "/sql" && *request.method() == Method::POST { let ctx = RequestMonitoring::new( session_id, @@ -564,11 +584,19 @@ async fn request_handler( ); let span = ctx.span(); - sql_over_http::handle(config, ctx, request, backend, http_cancellation_token) - .instrument(span) - .await + let token = connection_token.clone(); + + // `sql_over_http::handle` is not cancel safe. It expects to be cancelled only at specific times. + // By spawning the future, we ensure it never gets cancelled until it decides to. + HandleOrResponse::Handle(tokio::spawn(async move { + let _token = token; + sql_over_http::handle(config, ctx, request, backend, http_cancellation_token) + .instrument(span) + .await + .unwrap_or_else(api_error_into_response) + })) } else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS { - Response::builder() + HandleOrResponse::Response(Some( Response::builder() .header("Allow", "OPTIONS, POST") .header("Access-Control-Allow-Origin", "*") .header( @@ -578,8 +606,11 @@ async fn request_handler( .header("Access-Control-Max-Age", "86400" /* 24 hours */) .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code .body(Empty::new().map_err(|x| match x {}).boxed()) - .map_err(|e| ApiError::InternalServerError(e.into())) + .map_err(|e| ApiError::InternalServerError(e.into())).unwrap_or_else(api_error_into_response))) } else { - json_response(StatusCode::BAD_REQUEST, "query is not supported") + HandleOrResponse::Response(Some( + json_response(StatusCode::BAD_REQUEST, "query is not supported") + .unwrap_or_else(api_error_into_response), + )) } }