From 2acc621c4c70ba912db51193457045109eb31e1f Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Wed, 30 Oct 2024 09:55:46 +0000 Subject: [PATCH] random changes --- proxy/src/serverless/mod.rs | 131 +++++++++++++++++++----------------- 1 file changed, 71 insertions(+), 60 deletions(-) diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index f71fa86e57..366c7ef1e9 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -34,6 +34,7 @@ use hyper_util::rt::TokioExecutor; use hyper_util::server::conn::auto::Builder; use rand::rngs::StdRng; use rand::SeedableRng; +use smallvec::SmallVec; use sql_over_http::{uuid_to_header_value, NEON_REQUEST_ID}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, TcpStream}; @@ -182,7 +183,7 @@ pub async fn task_main( peer_addr, )) .await; - let Some((conn, peer_addr)) = startup_result else { + let Some((conn, peer_addr, alpn)) = startup_result else { return; }; @@ -195,6 +196,7 @@ pub async fn task_main( conn_token, conn, peer_addr, + alpn, session_id, )) .await; @@ -214,13 +216,19 @@ pub(crate) type AsyncRW = Pin>; #[async_trait] trait MaybeTlsAcceptor: Send + Sync + 'static { - async fn accept(self: Arc, conn: ChainRW) -> std::io::Result; + async fn accept(self: Arc, conn: ChainRW) -> std::io::Result<(AsyncRW, Alpn)>; } #[async_trait] impl MaybeTlsAcceptor for rustls::ServerConfig { - async fn accept(self: Arc, conn: ChainRW) -> std::io::Result { - Ok(Box::pin(TlsAcceptor::from(self).accept(conn).await?)) + async fn accept(self: Arc, conn: ChainRW) -> std::io::Result<(AsyncRW, Alpn)> { + let conn = TlsAcceptor::from(self).accept(conn).await?; + let alpn = conn + .get_ref() + .1 + .alpn_protocol() + .map_or_else(SmallVec::new, SmallVec::from_slice); + Ok((Box::pin(conn), alpn)) } } @@ -228,11 +236,13 @@ struct NoTls; #[async_trait] impl MaybeTlsAcceptor for NoTls { - async fn accept(self: Arc, conn: ChainRW) -> std::io::Result { - Ok(Box::pin(conn)) + async fn accept(self: Arc, conn: ChainRW) -> std::io::Result<(AsyncRW, Alpn)> { + Ok((Box::pin(conn), SmallVec::new())) } } +type Alpn = SmallVec<[u8; 8]>; + /// Handles the TCP startup lifecycle. /// 1. Parses PROXY protocol V2 /// 2. Handles TLS handshake @@ -242,7 +252,7 @@ async fn connection_startup( session_id: uuid::Uuid, conn: TcpStream, peer_addr: SocketAddr, -) -> Option<(AsyncRW, IpAddr)> { +) -> Option<(AsyncRW, IpAddr, Alpn)> { // handle PROXY protocol let (conn, peer) = match read_proxy_protocol(conn).await { Ok(c) => c, @@ -260,7 +270,7 @@ async fn connection_startup( info!(?session_id, %peer_addr, "accepted new TCP connection"); // try upgrade to TLS, but with a timeout. - let conn = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await { + let (conn, alpn) = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await { Ok(Ok(conn)) => { info!(?session_id, %peer_addr, "accepted new TLS connection"); conn @@ -283,7 +293,7 @@ async fn connection_startup( } }; - Some((conn, peer_addr)) + Some((conn, peer_addr, alpn)) } /// Handles HTTP connection @@ -300,6 +310,7 @@ async fn connection_handler( cancellation_token: CancellationToken, conn: AsyncRW, peer_addr: IpAddr, + _alpn: Alpn, session_id: uuid::Uuid, ) { let session_id = AtomicTake::new(session_id); @@ -311,59 +322,59 @@ async fn connection_handler( let (ws_tx, mut ws_rx) = mpsc::channel(1); let auth_backend = backend.auth_backend; + let service = hyper::service::service_fn(move |req: hyper::Request| { + // First HTTP request shares the same session ID + let mut session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4); + + if matches!(backend.auth_backend, crate::auth::Backend::Local(_)) { + // take session_id from request, if given. + if let Some(id) = req + .headers() + .get(&NEON_REQUEST_ID) + .and_then(|id| uuid::Uuid::try_parse_ascii(id.as_bytes()).ok()) + { + session_id = id; + } + } + + // Cancel the current inflight HTTP request if the requets stream is closed. + // This is slightly different to `_cancel_connection` in that + // h2 can cancel individual requests with a `RST_STREAM`. + let http_request_token = http_cancellation_token.child_token(); + let cancel_request = http_request_token.clone().drop_guard(); + + // `request_handler` is not cancel safe. It expects to be cancelled only at specific times. + // By spawning the future, we ensure it never gets cancelled until it decides to. + let handler = connections.spawn( + request_handler( + req, + config, + backend.clone(), + ws_tx.clone(), + session_id, + peer_addr, + http_request_token, + ) + .in_current_span() + .map_ok_or_else(api_error_into_response, |r| r), + ); + async move { + let mut res = handler.await; + cancel_request.disarm(); + + // add the session ID to the response + if let Ok(resp) = &mut res { + resp.headers_mut() + .append(&NEON_REQUEST_ID, uuid_to_header_value(session_id)); + } + + res + } + }); + let server = Builder::new(TokioExecutor::new()); - let conn = server.serve_connection_with_upgrades( - hyper_util::rt::TokioIo::new(conn), - hyper::service::service_fn(move |req: hyper::Request| { - // First HTTP request shares the same session ID - let mut session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4); - if matches!(backend.auth_backend, crate::auth::Backend::Local(_)) { - // take session_id from request, if given. - if let Some(id) = req - .headers() - .get(&NEON_REQUEST_ID) - .and_then(|id| uuid::Uuid::try_parse_ascii(id.as_bytes()).ok()) - { - session_id = id; - } - } - - // Cancel the current inflight HTTP request if the requets stream is closed. - // This is slightly different to `_cancel_connection` in that - // h2 can cancel individual requests with a `RST_STREAM`. - let http_request_token = http_cancellation_token.child_token(); - let cancel_request = http_request_token.clone().drop_guard(); - - // `request_handler` is not cancel safe. It expects to be cancelled only at specific times. - // By spawning the future, we ensure it never gets cancelled until it decides to. - let handler = connections.spawn( - request_handler( - req, - config, - backend.clone(), - ws_tx.clone(), - session_id, - peer_addr, - http_request_token, - ) - .in_current_span() - .map_ok_or_else(api_error_into_response, |r| r), - ); - async move { - let mut res = handler.await; - cancel_request.disarm(); - - // add the session ID to the response - if let Ok(resp) = &mut res { - resp.headers_mut() - .append(&NEON_REQUEST_ID, uuid_to_header_value(session_id)); - } - - res - } - }), - ); + let conn = server.serve_connection_with_upgrades(hyper_util::rt::TokioIo::new(conn), service); // On cancellation, trigger the HTTP connection handler to shut down. let res = match select(pin!(cancellation_token.cancelled()), pin!(conn)).await {