diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index 540daa934c..5f2ea773a1 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -18,6 +18,7 @@ use std::future::Future; use std::net::{IpAddr, SocketAddr}; use std::pin::{pin, Pin}; use std::sync::Arc; +use std::task::Poll; use anyhow::Context; use async_trait::async_trait; @@ -25,7 +26,7 @@ use atomic_take::AtomicTake; use bytes::Bytes; pub use conn_pool_lib::GlobalConnPoolOptions; use futures::future::{select, Either}; -use futures::TryFutureExt; +use futures::{FutureExt, TryFutureExt}; use http::{Method, Response, StatusCode}; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Empty}; @@ -39,6 +40,7 @@ use sql_over_http::{uuid_to_header_value, NEON_REQUEST_ID}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::oneshot; +use tokio::task::JoinHandle; use tokio::time::timeout; use tokio_rustls::TlsAcceptor; use tokio_util::sync::CancellationToken; @@ -300,50 +302,25 @@ trait GracefulShutdown: Future> + Send { fn graceful_shutdown(self: Pin<&mut Self>); } -impl GracefulShutdown - for hyper::server::conn::http1::UpgradeableConnection, S> -where - S: hyper::service::HttpService + Send, - S::Future: Send, - B: hyper::body::Body + Send + 'static, - B::Data: Send, - - S::Error: Into>, - B::Error: Into>, +impl GracefulShutdown + for hyper::server::conn::http1::UpgradeableConnection, ProxyService> { fn graceful_shutdown(self: Pin<&mut Self>) { self.graceful_shutdown(); } } -impl GracefulShutdown for hyper::server::conn::http1::Connection, S> -where - S: hyper::service::HttpService + Send, - S::Future: Send, - B: hyper::body::Body + Send + 'static, - B::Data: Send, - - S::Error: Into>, - B::Error: Into>, -{ +impl GracefulShutdown for hyper::server::conn::http1::Connection, ProxyService> { fn graceful_shutdown(self: Pin<&mut Self>) { self.graceful_shutdown(); } } -impl GracefulShutdown - for hyper::server::conn::http2::Connection, S, TokioExecutor> -where - S: hyper::service::HttpService + Send, - S::Future: Send + 'static, - B: hyper::body::Body + Send + 'static, - B::Data: Send, - - S::Error: Into>, - B::Error: Into>, +impl GracefulShutdown + for hyper::server::conn::http2::Connection, ProxyService, TokioExecutor> { fn graceful_shutdown(self: Pin<&mut Self>) { - hyper::server::conn::http2::Connection::graceful_shutdown(self); + self.graceful_shutdown(); } } @@ -383,75 +360,35 @@ async fn connection_handler( } }; - if http2 || !config.http_config.accept_websockets { - // discard the ws spawner - ws_tx.take(); - } + let service = ProxyService { + config, + backend, + connections, - let service = hyper::service::service_fn(move |req: hyper::Request| { - // First HTTP request shares the same session ID - let mut session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4); - - if matches!(backend.auth_backend, crate::auth::Backend::Local(_)) { - // take session_id from request, if given. - if let Some(id) = req - .headers() - .get(&NEON_REQUEST_ID) - .and_then(|id| uuid::Uuid::try_parse_ascii(id.as_bytes()).ok()) - { - session_id = id; - } - } - - // Cancel the current inflight HTTP request if the requets stream is closed. - // This is slightly different to `_cancel_connection` in that - // h2 can cancel individual requests with a `RST_STREAM`. - let http_request_token = http_cancellation_token.child_token(); - let cancel_request = http_request_token.clone().drop_guard(); - - // `request_handler` is not cancel safe. It expects to be cancelled only at specific times. - // By spawning the future, we ensure it never gets cancelled until it decides to. - let handler = connections.spawn( - request_handler( - req, - config, - backend.clone(), - ws_tx.clone(), - session_id, - peer_addr, - http_request_token, - ) - .in_current_span() - .map_ok_or_else(api_error_into_response, |r| r), - ); - async move { - let mut res = handler.await; - cancel_request.disarm(); - - // add the session ID to the response - if let Ok(resp) = &mut res { - resp.headers_mut() - .append(&NEON_REQUEST_ID, uuid_to_header_value(session_id)); - } - - res - } - }); + http_cancellation_token, + ws_tx, + peer_addr, + session_id, + }; let io = hyper_util::rt::TokioIo::new(conn); - let conn = if http2 { - let conn = hyper::server::conn::http2::Builder::new(TokioExecutor::new()) - .serve_connection(io, service); + let conn: Pin> = if http2 { + service.ws_tx.take(); - Box::pin(conn) as Pin> + Box::pin( + hyper::server::conn::http2::Builder::new(TokioExecutor::new()) + .serve_connection(io, service), + ) + } else if config.http_config.accept_websockets { + Box::pin( + hyper::server::conn::http1::Builder::new() + .serve_connection(io, service) + .with_upgrades(), + ) } else { - let serve = hyper::server::conn::http1::Builder::new().serve_connection(io, service); + service.ws_tx.take(); - if config.http_config.accept_websockets { - Box::pin(serve.with_upgrades()) as Pin> - } else { - Box::pin(serve) as Pin> - } + Box::pin(hyper::server::conn::http1::Builder::new().serve_connection(io, service)) }; // On cancellation, trigger the HTTP connection handler to shut down. @@ -497,10 +434,99 @@ async fn connection_handler( } } +struct ProxyService { + // global state + config: &'static ProxyConfig, + backend: Arc, + connections: TaskTracker, + + // connection state only + http_cancellation_token: CancellationToken, + ws_tx: WsSpawner, + peer_addr: IpAddr, + session_id: AtomicTake, +} + +impl hyper::service::Service> for ProxyService { + type Response = Response>; + + type Error = tokio::task::JoinError; + + type Future = ReqFut; + + fn call(&self, req: hyper::Request) -> Self::Future { + // First HTTP request shares the same session ID + let mut session_id = self.session_id.take().unwrap_or_else(uuid::Uuid::new_v4); + + if matches!(self.backend.auth_backend, crate::auth::Backend::Local(_)) { + // take session_id from request, if given. + if let Some(id) = req + .headers() + .get(&NEON_REQUEST_ID) + .and_then(|id| uuid::Uuid::try_parse_ascii(id.as_bytes()).ok()) + { + session_id = id; + } + } + + // Cancel the current inflight HTTP request if the requets stream is closed. + // This is slightly different to `_cancel_connection` in that + // h2 can cancel individual requests with a `RST_STREAM`. + let http_request_token = self.http_cancellation_token.child_token(); + let cancel_request = Some(http_request_token.clone().drop_guard()); + + // `request_handler` is not cancel safe. It expects to be cancelled only at specific times. + // By spawning the future, we ensure it never gets cancelled until it decides to. + let handle = self.connections.spawn( + request_handler( + req, + self.config, + self.backend.clone(), + self.ws_tx.clone(), + session_id, + self.peer_addr, + http_request_token, + ) + .in_current_span() + .map_ok_or_else(api_error_into_response, |r| r), + ); + + ReqFut { + session_id, + cancel_request, + handle, + } + } +} + +struct ReqFut { + session_id: uuid::Uuid, + cancel_request: Option, + handle: JoinHandle>>, +} + +impl Future for ReqFut { + type Output = Result>, tokio::task::JoinError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + let mut res = std::task::ready!(self.handle.poll_unpin(cx)); + self.cancel_request + .take() + .map(tokio_util::sync::DropGuard::disarm); + + // add the session ID to the response + if let Ok(resp) = &mut res { + resp.headers_mut() + .append(&NEON_REQUEST_ID, uuid_to_header_value(self.session_id)); + } + + Poll::Ready(res) + } +} + type WsUpgrade = (uuid::Uuid, Option, OnUpgrade); type WsSpawner = Arc>>; -#[allow(clippy::too_many_arguments)] async fn request_handler( mut request: hyper::Request, config: &'static ProxyConfig,