From ba714431bef3edf21ba2dae8e93376e0cb96ef0e Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Wed, 30 Oct 2024 11:13:56 +0000 Subject: [PATCH] oneshot --- proxy/src/serverless/mod.rs | 44 ++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index 16bdc1983b..540daa934c 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -38,7 +38,7 @@ use smallvec::SmallVec; use sql_over_http::{uuid_to_header_value, NEON_REQUEST_ID}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::mpsc; +use tokio::sync::oneshot; use tokio::time::timeout; use tokio_rustls::TlsAcceptor; use tokio_util::sync::CancellationToken; @@ -304,9 +304,9 @@ impl GracefulShutdown for hyper::server::conn::http1::UpgradeableConnection, S> where S: hyper::service::HttpService + Send, - S::Future: Send + 'static, + S::Future: Send, B: hyper::body::Body + Send + 'static, - B::Data: Send + 'static, + B::Data: Send, S::Error: Into>, B::Error: Into>, @@ -319,9 +319,9 @@ where impl GracefulShutdown for hyper::server::conn::http1::Connection, S> where S: hyper::service::HttpService + Send, - S::Future: Send + 'static, + S::Future: Send, B: hyper::body::Body + Send + 'static, - B::Data: Send + 'static, + B::Data: Send, S::Error: Into>, B::Error: Into>, @@ -337,7 +337,7 @@ where S: hyper::service::HttpService + Send, S::Future: Send + 'static, B: hyper::body::Body + Send + 'static, - B::Data: Send + 'static, + B::Data: Send, S::Error: Into>, B::Error: Into>, @@ -370,7 +370,8 @@ async fn connection_handler( let http_cancellation_token = CancellationToken::new(); let _cancel_connection = http_cancellation_token.clone().drop_guard(); - let (ws_tx, mut ws_rx) = mpsc::channel(1); + let (ws_tx, ws_rx) = oneshot::channel(); + let ws_tx = Arc::new(AtomicTake::new(ws_tx)); let auth_backend = backend.auth_backend; let http2 = match &*alpn { @@ -382,6 +383,11 @@ async fn connection_handler( } }; + if http2 || !config.http_config.accept_websockets { + // discard the ws spawner + ws_tx.take(); + } + let service = hyper::service::service_fn(move |req: hyper::Request| { // First HTTP request shares the same session ID let mut session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4); @@ -460,7 +466,7 @@ async fn connection_handler( match res { Ok(()) => { - if let Some((session_id, host, websocket)) = ws_rx.recv().await { + if let Ok((session_id, host, websocket)) = ws_rx.await { tracing::info!(%peer_addr, "connection upgraded to websockets"); let ctx = RequestMonitoring::new( @@ -491,21 +497,26 @@ async fn connection_handler( } } +type WsUpgrade = (uuid::Uuid, Option, OnUpgrade); +type WsSpawner = Arc>>; + #[allow(clippy::too_many_arguments)] async fn request_handler( mut request: hyper::Request, config: &'static ProxyConfig, backend: Arc, - ws_spawner: mpsc::Sender<(uuid::Uuid, Option, OnUpgrade)>, + ws_spawner: WsSpawner, session_id: uuid::Uuid, peer_addr: IpAddr, // used to cancel in-flight HTTP requests. not used to cancel websockets http_cancellation_token: CancellationToken, ) -> Result>, ApiError> { // Check if the request is a websocket upgrade request. - if config.http_config.accept_websockets - && framed_websockets::upgrade::is_upgrade_request(&request) - { + if framed_websockets::upgrade::is_upgrade_request(&request) { + let Some(spawner) = ws_spawner.take() else { + return json_response(StatusCode::BAD_REQUEST, "query is not supported"); + }; + let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request) .map_err(|e| ApiError::BadRequest(e.into()))?; @@ -516,12 +527,9 @@ async fn request_handler( .and_then(|h| h.split(':').next()) .map(|s| s.to_string()); - ws_spawner - .send((session_id, host, websocket)) - .await - .map_err(|_e| { - ApiError::InternalServerError(anyhow::anyhow!("could not upgrade WS connection")) - })?; + spawner.send((session_id, host, websocket)).map_err(|_e| { + ApiError::InternalServerError(anyhow::anyhow!("could not upgrade WS connection")) + })?; // Return the response so the spawned future can continue. Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))