From 7de2914ddecb4689d4ff85c9d4bcb5d642166e75 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Wed, 30 Oct 2024 09:45:19 +0000 Subject: [PATCH] re-use the same tokio task for the websocket handling --- proxy/src/serverless/mod.rs | 102 +++++++++++++++++++----------------- 1 file changed, 54 insertions(+), 48 deletions(-) diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index edbb0347d3..f71fa86e57 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -29,6 +29,7 @@ use http::{Method, Response, StatusCode}; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Empty}; use hyper::body::Incoming; +use hyper::upgrade::OnUpgrade; use hyper_util::rt::TokioExecutor; use hyper_util::server::conn::auto::Builder; use rand::rngs::StdRng; @@ -36,6 +37,7 @@ use rand::SeedableRng; use sql_over_http::{uuid_to_header_value, NEON_REQUEST_ID}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::mpsc; use tokio::time::timeout; use tokio_rustls::TlsAcceptor; use tokio_util::sync::CancellationToken; @@ -306,6 +308,9 @@ async fn connection_handler( let http_cancellation_token = CancellationToken::new(); let _cancel_connection = http_cancellation_token.clone().drop_guard(); + let (ws_tx, mut ws_rx) = mpsc::channel(1); + let auth_backend = backend.auth_backend; + let server = Builder::new(TokioExecutor::new()); let conn = server.serve_connection_with_upgrades( hyper_util::rt::TokioIo::new(conn), @@ -337,12 +342,10 @@ async fn connection_handler( req, config, backend.clone(), - connections.clone(), - cancellation_handler.clone(), + ws_tx.clone(), session_id, peer_addr, http_request_token, - endpoint_rate_limiter.clone(), ) .in_current_span() .map_ok_or_else(api_error_into_response, |r| r), @@ -373,53 +376,20 @@ async fn connection_handler( }; match res { - Ok(()) => tracing::info!(%peer_addr, "HTTP connection closed"), - Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"), - } -} + Ok(()) => { + if let Some((session_id, host, websocket)) = ws_rx.recv().await { + tracing::info!(%peer_addr, "connection upgraded to websockets"); -#[allow(clippy::too_many_arguments)] -async fn request_handler( - mut request: hyper::Request, - config: &'static ProxyConfig, - backend: Arc, - ws_connections: TaskTracker, - cancellation_handler: Arc, - session_id: uuid::Uuid, - peer_addr: IpAddr, - // used to cancel in-flight HTTP requests. not used to cancel websockets - http_cancellation_token: CancellationToken, - endpoint_rate_limiter: Arc, -) -> Result>, ApiError> { - let host = request - .headers() - .get("host") - .and_then(|h| h.to_str().ok()) - .and_then(|h| h.split(':').next()) - .map(|s| s.to_string()); + let ctx = RequestMonitoring::new( + session_id, + peer_addr, + crate::metrics::Protocol::Ws, + &config.region, + ); - // Check if the request is a websocket upgrade request. - if config.http_config.accept_websockets - && framed_websockets::upgrade::is_upgrade_request(&request) - { - let ctx = RequestMonitoring::new( - session_id, - peer_addr, - crate::metrics::Protocol::Ws, - &config.region, - ); - - let span = ctx.span(); - info!(parent: &span, "performing websocket upgrade"); - - let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request) - .map_err(|e| ApiError::BadRequest(e.into()))?; - - ws_connections.spawn( - async move { if let Err(e) = websocket::serve_websocket( config, - backend.auth_backend, + auth_backend, ctx, websocket, cancellation_handler, @@ -430,9 +400,45 @@ async fn request_handler( { warn!("error in websocket connection: {e:#}"); } + } else { + tracing::info!(%peer_addr, "HTTP connection closed"); } - .instrument(span), - ); + } + Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"), + } +} + +#[allow(clippy::too_many_arguments)] +async fn request_handler( + mut request: hyper::Request, + config: &'static ProxyConfig, + backend: Arc, + ws_spawner: mpsc::Sender<(uuid::Uuid, Option, OnUpgrade)>, + session_id: uuid::Uuid, + peer_addr: IpAddr, + // used to cancel in-flight HTTP requests. not used to cancel websockets + http_cancellation_token: CancellationToken, +) -> Result>, ApiError> { + // Check if the request is a websocket upgrade request. + if config.http_config.accept_websockets + && framed_websockets::upgrade::is_upgrade_request(&request) + { + let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request) + .map_err(|e| ApiError::BadRequest(e.into()))?; + + let host = request + .headers() + .get("host") + .and_then(|h| h.to_str().ok()) + .and_then(|h| h.split(':').next()) + .map(|s| s.to_string()); + + ws_spawner + .send((session_id, host, websocket)) + .await + .map_err(|_e| { + ApiError::InternalServerError(anyhow::anyhow!("could not upgrade WS connection")) + })?; // Return the response so the spawned future can continue. Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))