diff --git a/Cargo.lock b/Cargo.lock index 4c2bcf250e..bdf2b08c5c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -270,6 +270,12 @@ dependencies = [ "critical-section", ] +[[package]] +name = "atomic-take" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8ab6b55fe97976e46f91ddbed8d147d966475dc29b2032757ba47e02376fbc3" + [[package]] name = "autocfg" version = "1.1.0" @@ -298,7 +304,7 @@ dependencies = [ "fastrand 2.0.0", "hex", "http 0.2.9", - "hyper", + "hyper 0.14.26", "ring 0.17.6", "time", "tokio", @@ -335,7 +341,7 @@ dependencies = [ "bytes", "fastrand 2.0.0", "http 0.2.9", - "http-body", + "http-body 0.4.5", "percent-encoding", "pin-project-lite", "tracing", @@ -386,7 +392,7 @@ dependencies = [ "aws-types", "bytes", "http 0.2.9", - "http-body", + "http-body 0.4.5", "once_cell", "percent-encoding", "regex-lite", @@ -514,7 +520,7 @@ dependencies = [ "crc32fast", "hex", "http 0.2.9", - "http-body", + "http-body 0.4.5", "md-5", "pin-project-lite", "sha1", @@ -546,7 +552,7 @@ dependencies = [ "bytes-utils", "futures-core", "http 0.2.9", - "http-body", + "http-body 0.4.5", "once_cell", "percent-encoding", "pin-project-lite", @@ -585,10 +591,10 @@ dependencies = [ "aws-smithy-types", "bytes", "fastrand 2.0.0", - "h2", + "h2 0.3.26", "http 0.2.9", - "http-body", - "hyper", + "http-body 0.4.5", + "hyper 0.14.26", "hyper-rustls", "once_cell", "pin-project-lite", @@ -626,7 +632,7 @@ dependencies = [ "bytes-utils", "futures-core", "http 0.2.9", - "http-body", + "http-body 0.4.5", "itoa", "num-integer", "pin-project-lite", @@ -675,8 +681,8 @@ dependencies = [ "bytes", "futures-util", "http 0.2.9", - "http-body", - "hyper", + "http-body 0.4.5", + "hyper 0.14.26", "itoa", "matchit", "memchr", @@ -691,7 +697,7 @@ dependencies = [ "sha1", "sync_wrapper", "tokio", - "tokio-tungstenite", + "tokio-tungstenite 0.20.0", "tower", "tower-layer", "tower-service", @@ -707,7 +713,7 @@ dependencies = [ "bytes", "futures-util", "http 0.2.9", - "http-body", + "http-body 0.4.5", "mime", "rustversion", "tower-layer", @@ -1196,7 +1202,7 @@ dependencies = [ "compute_api", "flate2", "futures", - "hyper", + "hyper 0.14.26", "nix 0.27.1", "notify", "num_cpus", @@ -1313,7 +1319,7 @@ dependencies = [ "git-version", "hex", "humantime", - "hyper", + "hyper 0.14.26", "nix 0.27.1", "once_cell", "pageserver_api", @@ -2199,6 +2205,25 @@ dependencies = [ "tracing", ] +[[package]] +name = "h2" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "816ec7294445779408f36fe57bc5b7fc1cf59664059096c65f905c1c61f58069" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http 1.1.0", + "indexmap 2.0.1", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "half" version = "1.8.2" @@ -2370,6 +2395,29 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-body" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" +dependencies = [ + "bytes", + "http 1.1.0", +] + +[[package]] +name = "http-body-util" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41cb79eb393015dadd30fc252023adb0b2400a0caee0fa2a077e6e21a551e840" +dependencies = [ + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.0", + "pin-project-lite", +] + [[package]] name = "http-types" version = "2.12.0" @@ -2428,9 +2476,9 @@ dependencies = [ 
"futures-channel", "futures-core", "futures-util", - "h2", + "h2 0.3.26", "http 0.2.9", - "http-body", + "http-body 0.4.5", "httparse", "httpdate", "itoa", @@ -2442,6 +2490,26 @@ dependencies = [ "want", ] +[[package]] +name = "hyper" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "186548d73ac615b32a73aafe38fb4f56c0d340e110e5a200bcadbaf2e199263a" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "h2 0.4.4", + "http 1.1.0", + "http-body 1.0.0", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", +] + [[package]] name = "hyper-rustls" version = "0.24.0" @@ -2449,7 +2517,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0646026eb1b3eea4cd9ba47912ea5ce9cc07713d105b1a14698f4e6433d348b7" dependencies = [ "http 0.2.9", - "hyper", + "hyper 0.14.26", "log", "rustls 0.21.9", "rustls-native-certs 0.6.2", @@ -2463,7 +2531,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" dependencies = [ - "hyper", + "hyper 0.14.26", "pin-project-lite", "tokio", "tokio-io-timeout", @@ -2476,7 +2544,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ "bytes", - "hyper", + "hyper 0.14.26", "native-tls", "tokio", "tokio-native-tls", @@ -2484,15 +2552,33 @@ dependencies = [ [[package]] name = "hyper-tungstenite" -version = "0.11.1" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7cc7dcb1ab67cd336f468a12491765672e61a3b6b148634dbfe2fe8acd3fe7d9" +checksum = "7a343d17fe7885302ed7252767dc7bb83609a874b6ff581142241ec4b73957ad" dependencies = [ - "hyper", + "http-body-util", + "hyper 1.2.0", + "hyper-util", "pin-project-lite", "tokio", - "tokio-tungstenite", - "tungstenite", + "tokio-tungstenite 0.21.0", + "tungstenite 0.21.0", +] + +[[package]] +name = "hyper-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa" +dependencies = [ + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.0", + "hyper 1.2.0", + "pin-project-lite", + "socket2 0.5.5", + "tokio", ] [[package]] @@ -3523,7 +3609,7 @@ dependencies = [ "hex-literal", "humantime", "humantime-serde", - "hyper", + "hyper 0.14.26", "itertools", "leaky-bucket", "md5", @@ -4202,6 +4288,7 @@ dependencies = [ "anyhow", "async-compression", "async-trait", + "atomic-take", "aws-config", "aws-sdk-iam", "aws-sigv4", @@ -4225,9 +4312,12 @@ dependencies = [ "hmac", "hostname", "http 1.1.0", + "http-body-util", "humantime", - "hyper", + "hyper 0.14.26", + "hyper 1.2.0", "hyper-tungstenite", + "hyper-util", "ipnet", "itertools", "lasso", @@ -4560,7 +4650,7 @@ dependencies = [ "futures-util", "http-types", "humantime", - "hyper", + "hyper 0.14.26", "itertools", "metrics", "once_cell", @@ -4590,10 +4680,10 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "h2", + "h2 0.3.26", "http 0.2.9", - "http-body", - "hyper", + "http-body 0.4.5", + "hyper 0.14.26", "hyper-rustls", "hyper-tls", "ipnet", @@ -4651,7 +4741,7 @@ dependencies = [ "futures", "getrandom 0.2.11", "http 0.2.9", - "hyper", + "hyper 0.14.26", "parking_lot 0.11.2", "reqwest", "reqwest-middleware", @@ -4738,7 +4828,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "496c1d3718081c45ba9c31fbfc07417900aa96f4070ff90dc29961836b7a9945" dependencies = [ "http 0.2.9", - "hyper", + "hyper 0.14.26", "lazy_static", "percent-encoding", "regex", @@ -5043,7 +5133,7 @@ dependencies = [ "git-version", "hex", "humantime", - "hyper", + "hyper 0.14.26", "metrics", "once_cell", "parking_lot 0.12.1", @@ -5528,9 +5618,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.0" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" +checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" [[package]] name = "smol_str" @@ -5622,7 +5712,7 @@ dependencies = [ "futures-util", "git-version", "humantime", - "hyper", + "hyper 0.14.26", "metrics", "once_cell", "parking_lot 0.12.1", @@ -5653,7 +5743,7 @@ dependencies = [ "git-version", "hex", "humantime", - "hyper", + "hyper 0.14.26", "itertools", "lasso", "measured", @@ -5682,7 +5772,7 @@ dependencies = [ "anyhow", "clap", "comfy-table", - "hyper", + "hyper 0.14.26", "pageserver_api", "pageserver_client", "reqwest", @@ -6165,7 +6255,19 @@ dependencies = [ "futures-util", "log", "tokio", - "tungstenite", + "tungstenite 0.20.1", +] + +[[package]] +name = "tokio-tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite 0.21.0", ] [[package]] @@ -6232,10 +6334,10 @@ dependencies = [ "bytes", "futures-core", "futures-util", - "h2", + "h2 0.3.26", "http 0.2.9", - "http-body", - "hyper", + "http-body 0.4.5", + "hyper 0.14.26", "hyper-timeout", "percent-encoding", "pin-project", @@ -6421,7 +6523,7 @@ dependencies = [ name = "tracing-utils" version = "0.1.0" dependencies = [ - "hyper", + "hyper 0.14.26", "opentelemetry", "opentelemetry-otlp", "opentelemetry-semantic-conventions", @@ -6458,6 +6560,25 @@ dependencies = [ "utf-8", ] +[[package]] +name = "tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.1.0", + "httparse", + "log", + "rand 0.8.5", + "sha1", + "thiserror", + "url", + "utf-8", +] + [[package]] name = "twox-hash" version = "1.6.3" @@ -6623,7 +6744,7 @@ dependencies = [ "hex", "hex-literal", "humantime", - "hyper", + "hyper 0.14.26", "jsonwebtoken", "leaky-bucket", "metrics", @@ -7214,7 +7335,7 @@ dependencies = [ "hashbrown 0.14.0", "hex", "hmac", - "hyper", + "hyper 0.14.26", "indexmap 1.9.3", "itertools", "libc", @@ -7252,7 +7373,6 @@ dependencies = [ "tower", "tracing", "tracing-core", - "tungstenite", "url", "uuid", "zeroize", diff --git a/Cargo.toml b/Cargo.toml index 5db6b7016a..feea17ab05 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ license = "Apache-2.0" anyhow = { version = "1.0", features = ["backtrace"] } arc-swap = "1.6" async-compression = { version = "0.4.0", features = ["tokio", "gzip", "zstd"] } +atomic-take = "1.1.0" azure_core = "0.18" azure_identity = "0.18" azure_storage = "0.18" @@ -97,7 +98,7 @@ http-types = { version = "2", default-features = false } humantime = "2.1" humantime-serde = "1.1.1" hyper = "0.14" -hyper-tungstenite = "0.11" +hyper-tungstenite = "0.13.0" inotify = "0.10.2" ipnet = "2.9.0" itertools = "0.10" diff --git a/proxy/Cargo.toml 
b/proxy/Cargo.toml index b327890be2..12bd67ea36 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -12,6 +12,7 @@ testing = [] anyhow.workspace = true async-compression.workspace = true async-trait.workspace = true +atomic-take.workspace = true aws-config.workspace = true aws-sdk-iam.workspace = true aws-sigv4.workspace = true @@ -36,6 +37,9 @@ http.workspace = true humantime.workspace = true hyper-tungstenite.workspace = true hyper.workspace = true +hyper1 = { package = "hyper", version = "1.2", features = ["server"] } +hyper-util = { version = "0.1", features = ["server", "http1", "http2", "tokio"] } +http-body-util = { version = "0.1" } ipnet.workspace = true itertools.workspace = true lasso = { workspace = true, features = ["multi-threaded"] } diff --git a/proxy/src/protocol2.rs b/proxy/src/protocol2.rs index 700c8c8681..70f9b4bfab 100644 --- a/proxy/src/protocol2.rs +++ b/proxy/src/protocol2.rs @@ -5,19 +5,13 @@ use std::{ io, net::SocketAddr, pin::{pin, Pin}, - sync::Mutex, task::{ready, Context, Poll}, }; use bytes::{Buf, BytesMut}; -use hyper::server::accept::Accept; -use hyper::server::conn::{AddrIncoming, AddrStream}; -use metrics::IntCounterPairGuard; +use hyper::server::conn::AddrIncoming; use pin_project_lite::pin_project; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf}; -use uuid::Uuid; - -use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE; pub struct ProxyProtocolAccept { pub incoming: AddrIncoming, @@ -331,103 +325,6 @@ impl AsyncRead for WithClientIp { } } -impl Accept for ProxyProtocolAccept { - type Conn = WithConnectionGuard>; - - type Error = io::Error; - - fn poll_accept( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?); - - let conn_id = uuid::Uuid::new_v4(); - let span = tracing::info_span!("http_conn", ?conn_id); - { - let _enter = span.enter(); - tracing::info!("accepted new TCP connection"); - } - - let Some(conn) = conn else { - return Poll::Ready(None); - }; - - Poll::Ready(Some(Ok(WithConnectionGuard { - inner: WithClientIp::new(conn), - connection_id: Uuid::new_v4(), - gauge: Mutex::new(Some( - NUM_CLIENT_CONNECTION_GAUGE - .with_label_values(&[self.protocol]) - .guard(), - )), - span, - }))) - } -} - -pin_project! 
{ - pub struct WithConnectionGuard { - #[pin] - pub inner: T, - pub connection_id: Uuid, - pub gauge: Mutex>, - pub span: tracing::Span, - } - - impl PinnedDrop for WithConnectionGuard { - fn drop(this: Pin<&mut Self>) { - let _enter = this.span.enter(); - tracing::info!("HTTP connection closed") - } - } -} - -impl AsyncWrite for WithConnectionGuard { - #[inline] - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.project().inner.poll_write(cx, buf) - } - - #[inline] - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_flush(cx) - } - - #[inline] - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_shutdown(cx) - } - - #[inline] - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[io::IoSlice<'_>], - ) -> Poll> { - self.project().inner.poll_write_vectored(cx, bufs) - } - - #[inline] - fn is_write_vectored(&self) -> bool { - self.inner.is_write_vectored() - } -} - -impl AsyncRead for WithConnectionGuard { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - self.project().inner.poll_read(cx, buf) - } -} - #[cfg(test)] mod tests { use std::pin::pin; diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index a2010fd613..f275caa7eb 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -4,42 +4,48 @@ mod backend; mod conn_pool; +mod http_util; mod json; mod sql_over_http; -pub mod tls_listener; mod websocket; +use atomic_take::AtomicTake; +use bytes::Bytes; pub use conn_pool::GlobalConnPoolOptions; -use anyhow::bail; -use hyper::StatusCode; -use metrics::IntCounterPairGuard; +use anyhow::Context; +use futures::future::{select, Either}; +use futures::TryFutureExt; +use http::{Method, Response, StatusCode}; +use http_body_util::Full; +use hyper1::body::Incoming; +use hyper_util::rt::TokioExecutor; +use hyper_util::server::conn::auto::Builder; use rand::rngs::StdRng; use rand::SeedableRng; pub use reqwest_middleware::{ClientWithMiddleware, Error}; pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; +use tokio::time::timeout; +use tokio_rustls::TlsAcceptor; use tokio_util::task::TaskTracker; -use tracing::instrument::Instrumented; use crate::cancellation::CancellationHandlerMain; use crate::config::ProxyConfig; use crate::context::RequestMonitoring; -use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard}; +use crate::metrics::{NUM_CLIENT_CONNECTION_GAUGE, TLS_HANDSHAKE_FAILURES}; +use crate::protocol2::WithClientIp; +use crate::proxy::run_until_cancelled; use crate::rate_limiter::EndpointRateLimiter; use crate::serverless::backend::PoolingBackend; -use hyper::{ - server::conn::{AddrIncoming, AddrStream}, - Body, Method, Request, Response, -}; +use crate::serverless::http_util::{api_error_into_response, json_response}; -use std::net::IpAddr; +use std::net::{IpAddr, SocketAddr}; +use std::pin::pin; use std::sync::Arc; -use std::task::Poll; -use tls_listener::TlsListener; -use tokio::net::TcpListener; -use tokio_util::sync::{CancellationToken, DropGuard}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_util::sync::CancellationToken; use tracing::{error, info, warn, Instrument}; -use utils::http::{error::ApiError, json::json_response}; +use utils::http::error::ApiError; pub const SERVERLESS_DRIVER_SNI: &str = "api"; @@ -91,161 +97,174 @@ pub async fn task_main( tls_server_config.alpn_protocols = 
vec![b"h2".to_vec(), b"http/1.1".to_vec()]; let tls_acceptor: tokio_rustls::TlsAcceptor = Arc::new(tls_server_config).into(); - let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?; - let _ = addr_incoming.set_nodelay(true); - let addr_incoming = ProxyProtocolAccept { - incoming: addr_incoming, - protocol: "http", - }; + let connections = tokio_util::task::task_tracker::TaskTracker::new(); + connections.close(); // allows `connections.wait to complete` - let ws_connections = tokio_util::task::task_tracker::TaskTracker::new(); - ws_connections.close(); // allows `ws_connections.wait to complete` + let server = Builder::new(hyper_util::rt::TokioExecutor::new()); - let tls_listener = TlsListener::new(tls_acceptor, addr_incoming, config.handshake_timeout); + while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await { + let (conn, peer_addr) = res.context("could not accept TCP stream")?; + if let Err(e) = conn.set_nodelay(true) { + tracing::error!("could not set nodelay: {e}"); + continue; + } + let conn_id = uuid::Uuid::new_v4(); + let http_conn_span = tracing::info_span!("http_conn", ?conn_id); - let make_svc = hyper::service::make_service_fn( - |stream: &tokio_rustls::server::TlsStream< - WithConnectionGuard>, - >| { - let (conn, _) = stream.get_ref(); + connections.spawn( + connection_handler( + config, + backend.clone(), + connections.clone(), + cancellation_handler.clone(), + endpoint_rate_limiter.clone(), + cancellation_token.clone(), + server.clone(), + tls_acceptor.clone(), + conn, + peer_addr, + ) + .instrument(http_conn_span), + ); + } - // this is jank. should dissapear with hyper 1.0 migration. - let gauge = conn - .gauge - .lock() - .expect("lock should not be poisoned") - .take() - .expect("gauge should be set on connection start"); - - // Cancel all current inflight HTTP requests if the HTTP connection is closed. - let http_cancellation_token = CancellationToken::new(); - let cancel_connection = http_cancellation_token.clone().drop_guard(); - - let span = conn.span.clone(); - let client_addr = conn.inner.client_addr(); - let remote_addr = conn.inner.inner.remote_addr(); - let backend = backend.clone(); - let ws_connections = ws_connections.clone(); - let endpoint_rate_limiter = endpoint_rate_limiter.clone(); - let cancellation_handler = cancellation_handler.clone(); - async move { - let peer_addr = match client_addr { - Some(addr) => addr, - None if config.require_client_ip => bail!("missing required client ip"), - None => remote_addr, - }; - Ok(MetricService::new( - hyper::service::service_fn(move |req: Request| { - let backend = backend.clone(); - let ws_connections2 = ws_connections.clone(); - let endpoint_rate_limiter = endpoint_rate_limiter.clone(); - let cancellation_handler = cancellation_handler.clone(); - let http_cancellation_token = http_cancellation_token.child_token(); - - // `request_handler` is not cancel safe. It expects to be cancelled only at specific times. - // By spawning the future, we ensure it never gets cancelled until it decides to. - ws_connections.spawn( - async move { - // Cancel the current inflight HTTP request if the requets stream is closed. - // This is slightly different to `_cancel_connection` in that - // h2 can cancel individual requests with a `RST_STREAM`. 
- let _cancel_session = http_cancellation_token.clone().drop_guard(); - - let res = request_handler( - req, - config, - backend, - ws_connections2, - cancellation_handler, - peer_addr.ip(), - endpoint_rate_limiter, - http_cancellation_token, - ) - .await - .map_or_else(|e| e.into_response(), |r| r); - - _cancel_session.disarm(); - - res - } - .in_current_span(), - ) - }), - gauge, - cancel_connection, - span, - )) - } - }, - ); - - hyper::Server::builder(tls_listener) - .serve(make_svc) - .with_graceful_shutdown(cancellation_token.cancelled()) - .await?; - - // await websocket connections - ws_connections.wait().await; + connections.wait().await; Ok(()) } -struct MetricService { - inner: S, - _gauge: IntCounterPairGuard, - _cancel: DropGuard, - span: tracing::Span, -} +/// Handles the TCP lifecycle. +/// +/// 1. Parses PROXY protocol V2 +/// 2. Handles TLS handshake +/// 3. Handles HTTP connection +/// 1. With graceful shutdowns +/// 2. With graceful request cancellation with connection failure +/// 3. With websocket upgrade support. +#[allow(clippy::too_many_arguments)] +async fn connection_handler( + config: &'static ProxyConfig, + backend: Arc, + connections: TaskTracker, + cancellation_handler: Arc, + endpoint_rate_limiter: Arc, + cancellation_token: CancellationToken, + server: Builder, + tls_acceptor: TlsAcceptor, + conn: TcpStream, + peer_addr: SocketAddr, +) { + let session_id = uuid::Uuid::new_v4(); -impl MetricService { - fn new( - inner: S, - _gauge: IntCounterPairGuard, - _cancel: DropGuard, - span: tracing::Span, - ) -> MetricService { - MetricService { - inner, - _gauge, - _cancel, - span, + let _gauge = NUM_CLIENT_CONNECTION_GAUGE + .with_label_values(&["http"]) + .guard(); + + // handle PROXY protocol + let mut conn = WithClientIp::new(conn); + let peer = match conn.wait_for_addr().await { + Ok(peer) => peer, + Err(e) => { + tracing::error!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}"); + return; } - } -} + }; -impl hyper::service::Service> for MetricService -where - S: hyper::service::Service>, -{ - type Response = S::Response; - type Error = S::Error; - type Future = Instrumented; + let peer_addr = peer.unwrap_or(peer_addr).ip(); + info!(?session_id, %peer_addr, "accepted new TCP connection"); - fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { - self.inner.poll_ready(cx) - } + // try upgrade to TLS, but with a timeout. + let conn = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await { + Ok(Ok(conn)) => { + info!(?session_id, %peer_addr, "accepted new TLS connection"); + conn + } + // The handshake failed + Ok(Err(e)) => { + TLS_HANDSHAKE_FAILURES.inc(); + warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}"); + return; + } + // The handshake timed out + Err(e) => { + TLS_HANDSHAKE_FAILURES.inc(); + warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}"); + return; + } + }; - fn call(&mut self, req: Request) -> Self::Future { - self.span - .in_scope(|| self.inner.call(req)) - .instrument(self.span.clone()) + let session_id = AtomicTake::new(session_id); + + // Cancel all current inflight HTTP requests if the HTTP connection is closed. 
+ let http_cancellation_token = CancellationToken::new(); + let _cancel_connection = http_cancellation_token.clone().drop_guard(); + + let conn = server.serve_connection_with_upgrades( + hyper_util::rt::TokioIo::new(conn), + hyper1::service::service_fn(move |req: hyper1::Request| { + // First HTTP request shares the same session ID + let session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4); + + // Cancel the current inflight HTTP request if the requets stream is closed. + // This is slightly different to `_cancel_connection` in that + // h2 can cancel individual requests with a `RST_STREAM`. + let http_request_token = http_cancellation_token.child_token(); + let cancel_request = http_request_token.clone().drop_guard(); + + // `request_handler` is not cancel safe. It expects to be cancelled only at specific times. + // By spawning the future, we ensure it never gets cancelled until it decides to. + let handler = connections.spawn( + request_handler( + req, + config, + backend.clone(), + connections.clone(), + cancellation_handler.clone(), + session_id, + peer_addr, + endpoint_rate_limiter.clone(), + http_request_token, + ) + .in_current_span() + .map_ok_or_else(api_error_into_response, |r| r), + ); + + async move { + let res = handler.await; + cancel_request.disarm(); + res + } + }), + ); + + // On cancellation, trigger the HTTP connection handler to shut down. + let res = match select(pin!(cancellation_token.cancelled()), pin!(conn)).await { + Either::Left((_cancelled, mut conn)) => { + conn.as_mut().graceful_shutdown(); + conn.await + } + Either::Right((res, _)) => res, + }; + + match res { + Ok(()) => tracing::info!(%peer_addr, "HTTP connection closed"), + Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"), } } #[allow(clippy::too_many_arguments)] async fn request_handler( - mut request: Request, + mut request: hyper1::Request, config: &'static ProxyConfig, backend: Arc, ws_connections: TaskTracker, cancellation_handler: Arc, + session_id: uuid::Uuid, peer_addr: IpAddr, endpoint_rate_limiter: Arc, // used to cancel in-flight HTTP requests. not used to cancel websockets http_cancellation_token: CancellationToken, -) -> Result, ApiError> { - let session_id = uuid::Uuid::new_v4(); - +) -> Result>, ApiError> { let host = request .headers() .get("host") @@ -282,14 +301,14 @@ async fn request_handler( // Return the response so the spawned future can continue. 
Ok(response) - } else if request.uri().path() == "/sql" && request.method() == Method::POST { + } else if request.uri().path() == "/sql" && *request.method() == Method::POST { let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region); let span = ctx.span.clone(); sql_over_http::handle(config, ctx, request, backend, http_cancellation_token) .instrument(span) .await - } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS { + } else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS { Response::builder() .header("Allow", "OPTIONS, POST") .header("Access-Control-Allow-Origin", "*") @@ -299,7 +318,7 @@ async fn request_handler( ) .header("Access-Control-Max-Age", "86400" /* 24 hours */) .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code - .body(Body::empty()) + .body(Full::new(Bytes::new())) .map_err(|e| ApiError::InternalServerError(e.into())) } else { json_response(StatusCode::BAD_REQUEST, "query is not supported") diff --git a/proxy/src/serverless/http_util.rs b/proxy/src/serverless/http_util.rs new file mode 100644 index 0000000000..ab9127b13e --- /dev/null +++ b/proxy/src/serverless/http_util.rs @@ -0,0 +1,92 @@ +//! Things stolen from `libs/utils/src/http` to add hyper 1.0 compatibility +//! Will merge back in at some point in the future. + +use bytes::Bytes; + +use anyhow::Context; +use http::{Response, StatusCode}; +use http_body_util::Full; + +use serde::Serialize; +use utils::http::error::ApiError; + +/// Like [`ApiError::into_response`] +pub fn api_error_into_response(this: ApiError) -> Response> { + match this { + ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status( + format!("{err:#?}"), // use debug printing so that we give the cause + StatusCode::BAD_REQUEST, + ), + ApiError::Forbidden(_) => { + HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::FORBIDDEN) + } + ApiError::Unauthorized(_) => { + HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::UNAUTHORIZED) + } + ApiError::NotFound(_) => { + HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::NOT_FOUND) + } + ApiError::Conflict(_) => { + HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::CONFLICT) + } + ApiError::PreconditionFailed(_) => HttpErrorBody::response_from_msg_and_status( + this.to_string(), + StatusCode::PRECONDITION_FAILED, + ), + ApiError::ShuttingDown => HttpErrorBody::response_from_msg_and_status( + "Shutting down".to_string(), + StatusCode::SERVICE_UNAVAILABLE, + ), + ApiError::ResourceUnavailable(err) => HttpErrorBody::response_from_msg_and_status( + err.to_string(), + StatusCode::SERVICE_UNAVAILABLE, + ), + ApiError::Timeout(err) => HttpErrorBody::response_from_msg_and_status( + err.to_string(), + StatusCode::REQUEST_TIMEOUT, + ), + ApiError::InternalServerError(err) => HttpErrorBody::response_from_msg_and_status( + err.to_string(), + StatusCode::INTERNAL_SERVER_ERROR, + ), + } +} + +/// Same as [`utils::http::error::HttpErrorBody`] +#[derive(Serialize)] +struct HttpErrorBody { + pub msg: String, +} + +impl HttpErrorBody { + /// Same as [`utils::http::error::HttpErrorBody::response_from_msg_and_status`] + fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response> { + HttpErrorBody { msg }.to_response(status) + } + + /// Same as [`utils::http::error::HttpErrorBody::to_response`] + fn to_response(&self, status: StatusCode) -> 
Response> { + Response::builder() + .status(status) + .header(http::header::CONTENT_TYPE, "application/json") + // we do not have nested maps with non string keys so serialization shouldn't fail + .body(Full::new(Bytes::from(serde_json::to_string(self).unwrap()))) + .unwrap() + } +} + +/// Same as [`utils::http::json::json_response`] +pub fn json_response( + status: StatusCode, + data: T, +) -> Result>, ApiError> { + let json = serde_json::to_string(&data) + .context("Failed to serialize JSON response") + .map_err(ApiError::InternalServerError)?; + let response = Response::builder() + .status(status) + .header(http::header::CONTENT_TYPE, "application/json") + .body(Full::new(Bytes::from(json))) + .map_err(|e| ApiError::InternalServerError(e.into()))?; + Ok(response) +} diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 00dffd5784..7f7f93988c 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -1,18 +1,22 @@ use std::pin::pin; use std::sync::Arc; +use bytes::Bytes; use futures::future::select; use futures::future::try_join; use futures::future::Either; use futures::StreamExt; use futures::TryFutureExt; -use hyper::body::HttpBody; -use hyper::header; -use hyper::http::HeaderName; -use hyper::http::HeaderValue; -use hyper::Response; -use hyper::StatusCode; -use hyper::{Body, HeaderMap, Request}; +use http_body_util::BodyExt; +use http_body_util::Full; +use hyper1::body::Body; +use hyper1::body::Incoming; +use hyper1::header; +use hyper1::http::HeaderName; +use hyper1::http::HeaderValue; +use hyper1::Response; +use hyper1::StatusCode; +use hyper1::{HeaderMap, Request}; use serde_json::json; use serde_json::Value; use tokio::time; @@ -29,7 +33,6 @@ use tracing::error; use tracing::info; use url::Url; use utils::http::error::ApiError; -use utils::http::json::json_response; use crate::auth::backend::ComputeUserInfo; use crate::auth::endpoint_sni; @@ -52,6 +55,7 @@ use crate::RoleName; use super::backend::PoolingBackend; use super::conn_pool::Client; use super::conn_pool::ConnInfo; +use super::http_util::json_response; use super::json::json_to_pg_text; use super::json::pg_text_row_to_json; use super::json::JsonConversionError; @@ -218,10 +222,10 @@ fn get_conn_info( pub async fn handle( config: &'static ProxyConfig, mut ctx: RequestMonitoring, - request: Request, + request: Request, backend: Arc, cancel: CancellationToken, -) -> Result, ApiError> { +) -> Result>, ApiError> { let result = handle_inner(cancel, config, &mut ctx, request, backend).await; let mut response = match result { @@ -332,10 +336,9 @@ pub async fn handle( } }; - response.headers_mut().insert( - "Access-Control-Allow-Origin", - hyper::http::HeaderValue::from_static("*"), - ); + response + .headers_mut() + .insert("Access-Control-Allow-Origin", HeaderValue::from_static("*")); Ok(response) } @@ -396,7 +399,7 @@ impl UserFacingError for SqlOverHttpError { #[derive(Debug, thiserror::Error)] pub enum ReadPayloadError { #[error("could not read the HTTP request body: {0}")] - Read(#[from] hyper::Error), + Read(#[from] hyper1::Error), #[error("could not parse the HTTP request body: {0}")] Parse(#[from] serde_json::Error), } @@ -437,7 +440,7 @@ struct HttpHeaders { } impl HttpHeaders { - fn try_parse(headers: &hyper::http::HeaderMap) -> Result { + fn try_parse(headers: &hyper1::http::HeaderMap) -> Result { // Determine the output options. Default behaviour is 'false'. Anything that is not // strictly 'true' assumed to be false. 
let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE); @@ -488,9 +491,9 @@ async fn handle_inner( cancel: CancellationToken, config: &'static ProxyConfig, ctx: &mut RequestMonitoring, - request: Request, + request: Request, backend: Arc, -) -> Result, SqlOverHttpError> { +) -> Result>, SqlOverHttpError> { let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE .with_label_values(&[ctx.protocol]) .guard(); @@ -528,7 +531,7 @@ async fn handle_inner( } let fetch_and_process_request = async { - let body = hyper::body::to_bytes(request.into_body()).await?; + let body = request.into_body().collect().await?.to_bytes(); info!(length = body.len(), "request payload read"); let payload: Payload = serde_json::from_slice(&body)?; Ok::(payload) // Adjust error type accordingly @@ -596,7 +599,7 @@ async fn handle_inner( let body = serde_json::to_string(&result).expect("json serialization should not fail"); let len = body.len(); let response = response - .body(Body::from(body)) + .body(Full::new(Bytes::from(body))) // only fails if invalid status code or invalid header/values are given. // these are not user configurable so it cannot fail dynamically .expect("building response payload should not fail"); @@ -639,6 +642,7 @@ impl QueryData { } // The query was cancelled. Either::Right((_cancelled, query)) => { + tracing::info!("cancelling query"); if let Err(err) = cancel_token.cancel_query(NoTls).await { tracing::error!(?err, "could not cancel query"); } diff --git a/proxy/src/serverless/tls_listener.rs b/proxy/src/serverless/tls_listener.rs deleted file mode 100644 index 33f194dd59..0000000000 --- a/proxy/src/serverless/tls_listener.rs +++ /dev/null @@ -1,123 +0,0 @@ -use std::{ - convert::Infallible, - pin::Pin, - task::{Context, Poll}, - time::Duration, -}; - -use hyper::server::{accept::Accept, conn::AddrStream}; -use pin_project_lite::pin_project; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - task::JoinSet, - time::timeout, -}; -use tokio_rustls::{server::TlsStream, TlsAcceptor}; -use tracing::{info, warn, Instrument}; - -use crate::{ - metrics::TLS_HANDSHAKE_FAILURES, - protocol2::{WithClientIp, WithConnectionGuard}, -}; - -pin_project! { - /// Wraps a `Stream` of connections (such as a TCP listener) so that each connection is itself - /// encrypted using TLS. - pub(crate) struct TlsListener { - #[pin] - listener: A, - tls: TlsAcceptor, - waiting: JoinSet>>, - timeout: Duration, - } -} - -impl TlsListener { - /// Create a `TlsListener` with default options. 
- pub(crate) fn new(tls: TlsAcceptor, listener: A, timeout: Duration) -> Self { - TlsListener { - listener, - tls, - waiting: JoinSet::new(), - timeout, - } - } -} - -impl Accept for TlsListener -where - A: Accept>>, - A::Error: std::error::Error, - A::Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static, -{ - type Conn = TlsStream; - - type Error = Infallible; - - fn poll_accept( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - let mut this = self.project(); - - loop { - match this.listener.as_mut().poll_accept(cx) { - Poll::Pending => break, - Poll::Ready(Some(Ok(mut conn))) => { - let t = *this.timeout; - let tls = this.tls.clone(); - let span = conn.span.clone(); - this.waiting.spawn(async move { - let peer_addr = match conn.inner.wait_for_addr().await { - Ok(Some(addr)) => addr, - Err(e) => { - tracing::error!("failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}"); - return None; - } - Ok(None) => conn.inner.inner.remote_addr() - }; - - let accept = tls.accept(conn); - match timeout(t, accept).await { - Ok(Ok(conn)) => { - info!(%peer_addr, "accepted new TLS connection"); - Some(conn) - }, - // The handshake failed, try getting another connection from the queue - Ok(Err(e)) => { - TLS_HANDSHAKE_FAILURES.inc(); - warn!(%peer_addr, "failed to accept TLS connection: {e:?}"); - None - } - // The handshake timed out, try getting another connection from the queue - Err(_) => { - TLS_HANDSHAKE_FAILURES.inc(); - warn!(%peer_addr, "failed to accept TLS connection: timeout"); - None - } - } - }.instrument(span)); - } - Poll::Ready(Some(Err(e))) => { - tracing::error!("error accepting TCP connection: {e}"); - continue; - } - Poll::Ready(None) => return Poll::Ready(None), - } - } - - loop { - return match this.waiting.poll_join_next(cx) { - Poll::Ready(Some(Ok(Some(conn)))) => Poll::Ready(Some(Ok(conn))), - // The handshake failed to complete, try getting another connection from the queue - Poll::Ready(Some(Ok(None))) => continue, - // The handshake panicked or was cancelled. ignore and get another connection - Poll::Ready(Some(Err(e))) => { - tracing::warn!("handshake aborted: {e}"); - continue; - } - _ => Poll::Pending, - }; - } - } -} diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index bcbd4daa7e..d6e2cc2996 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -63,7 +63,7 @@ scopeguard = { version = "1" } serde = { version = "1", features = ["alloc", "derive"] } serde_json = { version = "1", features = ["raw_value"] } sha2 = { version = "0.10", features = ["asm"] } -smallvec = { version = "1", default-features = false, features = ["write"] } +smallvec = { version = "1", default-features = false, features = ["const_new", "write"] } subtle = { version = "2" } time = { version = "0.3", features = ["local-offset", "macros", "serde-well-known"] } tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] } @@ -75,7 +75,6 @@ tonic = { version = "0.9", features = ["tls-roots"] } tower = { version = "0.4", default-features = false, features = ["balance", "buffer", "limit", "log", "timeout", "util"] } tracing = { version = "0.1", features = ["log"] } tracing-core = { version = "0.1" } -tungstenite = { version = "0.20" } url = { version = "2", features = ["serde"] } uuid = { version = "1", features = ["serde", "v4", "v7"] } zeroize = { version = "1", features = ["derive"] }
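
Reference sketch of the new serving path: the proxy/src/serverless.rs changes above drop hyper 0.14's Accept/make_service_fn machinery (and the hand-rolled TlsListener) in favour of a plain accept loop that hands each connection to hyper-util's auto::Builder. The example below shows that pattern in isolation, assuming hyper 1.x and hyper-util with roughly the features added to proxy/Cargo.toml ("server", "http1", "http2", "tokio"); the listener address, handler body, and error logging are placeholders rather than the proxy's real logic, which additionally parses PROXY protocol V2, performs the TLS handshake with a timeout, records metrics, and wires up per-request cancellation. The proxy imports hyper 1.x under the alias hyper1 because hyper 0.14 is still a dependency; the sketch uses the plain hyper name.

// Minimal sketch of the hyper 1.x + hyper-util connection-serving pattern
// adopted by connection_handler. Address and handler are placeholders.
use std::convert::Infallible;
use std::pin::pin;

use bytes::Bytes;
use futures::future::{select, Either};
use http_body_util::Full;
use hyper::{body::Incoming, service::service_fn, Request, Response};
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server::conn::auto::Builder;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;

async fn hello(_req: Request<Incoming>) -> Result<Response<Full<Bytes>>, Infallible> {
    // hyper 1.x splits the body types: `Incoming` for requests and
    // `http_body_util::Full` for fixed responses (hyper 0.14's `Body` is gone).
    Ok(Response::new(Full::new(Bytes::from_static(b"hello"))))
}

#[tokio::main]
async fn main() -> std::io::Result<()> {
    let listener = TcpListener::bind("127.0.0.1:3000").await?;
    let cancel = CancellationToken::new();
    // One shared builder, cloned per connection, as task_main does.
    let server = Builder::new(TokioExecutor::new());

    // task_main additionally wraps accept() in run_until_cancelled and
    // tracks the spawned handlers in a TaskTracker.
    loop {
        let (stream, _peer) = listener.accept().await?;
        let server = server.clone();
        let cancel = cancel.clone();

        tokio::spawn(async move {
            // serve_connection_with_upgrades keeps HTTP upgrades working,
            // which the proxy needs for its websocket path.
            let conn = server
                .serve_connection_with_upgrades(TokioIo::new(stream), service_fn(hello));

            // Same shutdown shape as connection_handler: once the token fires,
            // ask the connection to finish in-flight requests, then await it.
            let res = match select(pin!(cancel.cancelled()), pin!(conn)).await {
                Either::Left((_cancelled, mut conn)) => {
                    conn.as_mut().graceful_shutdown();
                    conn.await
                }
                Either::Right((res, _)) => res,
            };
            if let Err(e) = res {
                eprintln!("connection error: {e}");
            }
        });
    }
}

Pinning the connection future and calling graceful_shutdown on cancellation is what lets in-flight requests finish before task_main's connections.wait() returns.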