diff --git a/Cargo.lock b/Cargo.lock index 7fd9053f62..bf75e71113 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -285,7 +285,7 @@ dependencies = [ "futures", "git-version", "humantime", - "hyper", + "hyper 0.14.26", "metrics", "once_cell", "pageserver_api", @@ -331,7 +331,7 @@ dependencies = [ "fastrand 2.0.0", "hex", "http 0.2.9", - "hyper", + "hyper 0.14.26", "ring 0.17.6", "time", "tokio", @@ -368,7 +368,7 @@ dependencies = [ "bytes", "fastrand 2.0.0", "http 0.2.9", - "http-body", + "http-body 0.4.5", "percent-encoding", "pin-project-lite", "tracing", @@ -396,7 +396,7 @@ dependencies = [ "aws-types", "bytes", "http 0.2.9", - "http-body", + "http-body 0.4.5", "once_cell", "percent-encoding", "regex-lite", @@ -547,7 +547,7 @@ dependencies = [ "crc32fast", "hex", "http 0.2.9", - "http-body", + "http-body 0.4.5", "md-5", "pin-project-lite", "sha1", @@ -579,7 +579,7 @@ dependencies = [ "bytes-utils", "futures-core", "http 0.2.9", - "http-body", + "http-body 0.4.5", "once_cell", "percent-encoding", "pin-project-lite", @@ -618,10 +618,10 @@ dependencies = [ "aws-smithy-types", "bytes", "fastrand 2.0.0", - "h2", + "h2 0.3.24", "http 0.2.9", - "http-body", - "hyper", + "http-body 0.4.5", + "hyper 0.14.26", "hyper-rustls", "once_cell", "pin-project-lite", @@ -658,7 +658,7 @@ dependencies = [ "bytes-utils", "futures-core", "http 0.2.9", - "http-body", + "http-body 0.4.5", "itoa", "num-integer", "pin-project-lite", @@ -707,8 +707,8 @@ dependencies = [ "bytes", "futures-util", "http 0.2.9", - "http-body", - "hyper", + "http-body 0.4.5", + "hyper 0.14.26", "itoa", "matchit", "memchr", @@ -723,7 +723,7 @@ dependencies = [ "sha1", "sync_wrapper", "tokio", - "tokio-tungstenite", + "tokio-tungstenite 0.20.0", "tower", "tower-layer", "tower-service", @@ -739,7 +739,7 @@ dependencies = [ "bytes", "futures-util", "http 0.2.9", - "http-body", + "http-body 0.4.5", "mime", "rustversion", "tower-layer", @@ -1228,7 +1228,7 @@ dependencies = [ "compute_api", "flate2", "futures", - "hyper", + "hyper 0.14.26", "nix 0.27.1", "notify", "num_cpus", @@ -1344,7 +1344,7 @@ dependencies = [ "futures", "git-version", "hex", - "hyper", + "hyper 0.14.26", "nix 0.27.1", "once_cell", "pageserver_api", @@ -2244,6 +2244,25 @@ dependencies = [ "tracing", ] +[[package]] +name = "h2" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31d030e59af851932b72ceebadf4a2b5986dba4c3b99dd2493f8273a0f151943" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http 1.0.0", + "indexmap 2.0.1", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "half" version = "1.8.2" @@ -2409,6 +2428,29 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-body" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" +dependencies = [ + "bytes", + "http 1.0.0", +] + +[[package]] +name = "http-body-util" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41cb79eb393015dadd30fc252023adb0b2400a0caee0fa2a077e6e21a551e840" +dependencies = [ + "bytes", + "futures-util", + "http 1.0.0", + "http-body 1.0.0", + "pin-project-lite", +] + [[package]] name = "http-types" version = "2.12.0" @@ -2467,9 +2509,9 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2", + "h2 0.3.24", "http 0.2.9", - "http-body", + "http-body 0.4.5", "httparse", "httpdate", "itoa", @@ -2481,6 +2523,26 @@ dependencies = [ "want", ] +[[package]] +name = "hyper" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "186548d73ac615b32a73aafe38fb4f56c0d340e110e5a200bcadbaf2e199263a" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "h2 0.4.2", + "http 1.0.0", + "http-body 1.0.0", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", +] + [[package]] name = "hyper-rustls" version = "0.24.0" @@ -2488,7 +2550,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0646026eb1b3eea4cd9ba47912ea5ce9cc07713d105b1a14698f4e6433d348b7" dependencies = [ "http 0.2.9", - "hyper", + "hyper 0.14.26", "log", "rustls 0.21.9", "rustls-native-certs", @@ -2502,7 +2564,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" dependencies = [ - "hyper", + "hyper 0.14.26", "pin-project-lite", "tokio", "tokio-io-timeout", @@ -2515,7 +2577,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ "bytes", - "hyper", + "hyper 0.14.26", "native-tls", "tokio", "tokio-native-tls", @@ -2523,15 +2585,33 @@ dependencies = [ [[package]] name = "hyper-tungstenite" -version = "0.11.1" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7cc7dcb1ab67cd336f468a12491765672e61a3b6b148634dbfe2fe8acd3fe7d9" +checksum = "7a343d17fe7885302ed7252767dc7bb83609a874b6ff581142241ec4b73957ad" dependencies = [ - "hyper", + "http-body-util", + "hyper 1.2.0", + "hyper-util", "pin-project-lite", "tokio", - "tokio-tungstenite", - "tungstenite", + "tokio-tungstenite 0.21.0", + "tungstenite 0.21.0", +] + +[[package]] +name = "hyper-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa" +dependencies = [ + "bytes", + "futures-util", + "http 1.0.0", + "http-body 1.0.0", + "hyper 1.2.0", + "pin-project-lite", + "socket2 0.5.5", + "tokio", ] [[package]] @@ -3508,7 +3588,7 @@ dependencies = [ "hex-literal", "humantime", "humantime-serde", - "hyper", + "hyper 0.14.26", "itertools", "leaky-bucket", "md5", @@ -4178,9 +4258,13 @@ dependencies = [ "hex", "hmac", "hostname", + "http 1.0.0", + "http-body-util", "humantime", - "hyper", + "hyper 0.14.26", + "hyper 1.2.0", "hyper-tungstenite", + "hyper-util", "ipnet", "itertools", "lasso", @@ -4512,7 +4596,7 @@ dependencies = [ "futures-util", "http-types", "humantime", - "hyper", + "hyper 0.14.26", "itertools", "metrics", "once_cell", @@ -4542,10 +4626,10 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "h2", + "h2 0.3.24", "http 0.2.9", - "http-body", - "hyper", + "http-body 0.4.5", + "hyper 0.14.26", "hyper-rustls", "hyper-tls", "ipnet", @@ -4603,7 +4687,7 @@ dependencies = [ "futures", "getrandom 0.2.11", "http 0.2.9", - "hyper", + "hyper 0.14.26", "parking_lot 0.11.2", "reqwest", "reqwest-middleware", @@ -4690,7 +4774,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "496c1d3718081c45ba9c31fbfc07417900aa96f4070ff90dc29961836b7a9945" dependencies = [ "http 0.2.9", - "hyper", + "hyper 0.14.26", "lazy_static", "percent-encoding", "regex", @@ -4969,7 +5053,7 @@ dependencies = [ "git-version", "hex", "humantime", - "hyper", + "hyper 0.14.26", "metrics", "once_cell", "parking_lot 0.12.1", @@ -5444,9 +5528,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.0" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" +checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" [[package]] name = "smol_str" @@ -5538,7 +5622,7 @@ dependencies = [ "futures-util", "git-version", "humantime", - "hyper", + "hyper 0.14.26", "metrics", "once_cell", "parking_lot 0.12.1", @@ -6022,7 +6106,19 @@ dependencies = [ "futures-util", "log", "tokio", - "tungstenite", + "tungstenite 0.20.1", +] + +[[package]] +name = "tokio-tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite 0.21.0", ] [[package]] @@ -6089,10 +6185,10 @@ dependencies = [ "bytes", "futures-core", "futures-util", - "h2", + "h2 0.3.24", "http 0.2.9", - "http-body", - "hyper", + "http-body 0.4.5", + "hyper 0.14.26", "hyper-timeout", "percent-encoding", "pin-project", @@ -6278,7 +6374,7 @@ dependencies = [ name = "tracing-utils" version = "0.1.0" dependencies = [ - "hyper", + "hyper 0.14.26", "opentelemetry", "opentelemetry-otlp", "opentelemetry-semantic-conventions", @@ -6315,6 +6411,25 @@ dependencies = [ "utf-8", ] +[[package]] +name = "tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.0.0", + "httparse", + "log", + "rand 0.8.5", + "sha1", + "thiserror", + "url", + "utf-8", +] + [[package]] name = "twox-hash" version = "1.6.3" @@ -6478,7 +6593,7 @@ dependencies = [ "heapless", "hex", "hex-literal", - "hyper", + "hyper 0.14.26", "jsonwebtoken", "leaky-bucket", "metrics", @@ -7003,7 +7118,7 @@ dependencies = [ "hashbrown 0.14.0", "hex", "hmac", - "hyper", + "hyper 0.14.26", "indexmap 1.9.3", "itertools", "libc", @@ -7040,7 +7155,7 @@ dependencies = [ "tower", "tracing", "tracing-core", - "tungstenite", + "tungstenite 0.20.1", "url", "uuid", "zeroize", diff --git a/Cargo.toml b/Cargo.toml index 76f4ff041c..48a557562a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -92,7 +92,7 @@ http-types = { version = "2", default-features = false } humantime = "2.1" humantime-serde = "1.1.1" hyper = "0.14" -hyper-tungstenite = "0.11" +hyper-tungstenite = "0.13.0" inotify = "0.10.2" ipnet = "2.9.0" itertools = "0.10" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index d8112c8bf0..d6c356bf26 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -30,6 +30,10 @@ hostname.workspace = true humantime.workspace = true hyper-tungstenite.workspace = true hyper.workspace = true +hyper1 = { package = "hyper", version = "1.2", features = ["server"] } +hyper-util = { version = "0.1", features = ["server", "http1", "http2", "tokio"] } +http1 = { package = "http", version = "1" } +http-body-util = { version = "0.1" } ipnet.workspace = true itertools.workspace = true lasso = { workspace = true, features = ["multi-threaded"] } diff --git a/proxy/src/protocol2.rs b/proxy/src/protocol2.rs index 3a7aabca32..70f9b4bfab 100644 --- a/proxy/src/protocol2.rs +++ b/proxy/src/protocol2.rs @@ -5,19 +5,13 @@ use std::{ io, net::SocketAddr, pin::{pin, Pin}, - sync::Mutex, task::{ready, Context, Poll}, }; use bytes::{Buf, BytesMut}; -use hyper::server::accept::Accept; -use hyper::server::conn::{AddrIncoming, AddrStream}; -use metrics::IntCounterPairGuard; +use hyper::server::conn::AddrIncoming; use pin_project_lite::pin_project; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf}; -use uuid::Uuid; - -use crate::{metrics::NUM_CLIENT_CONNECTION_GAUGE, serverless::tls_listener::AsyncAccept}; pub struct ProxyProtocolAccept { pub incoming: AddrIncoming, @@ -331,87 +325,6 @@ impl AsyncRead for WithClientIp { } } -impl AsyncAccept for ProxyProtocolAccept { - type Connection = WithConnectionGuard>; - - type Error = io::Error; - - fn poll_accept( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?); - tracing::info!(protocol = self.protocol, "accepted new TCP connection"); - let Some(conn) = conn else { - return Poll::Ready(None); - }; - - Poll::Ready(Some(Ok(WithConnectionGuard { - inner: WithClientIp::new(conn), - connection_id: Uuid::new_v4(), - gauge: Mutex::new(Some( - NUM_CLIENT_CONNECTION_GAUGE - .with_label_values(&[self.protocol]) - .guard(), - )), - }))) - } -} - -pin_project! { - pub struct WithConnectionGuard { - #[pin] - pub inner: T, - pub connection_id: Uuid, - pub gauge: Mutex>, - } -} - -impl AsyncWrite for WithConnectionGuard { - #[inline] - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.project().inner.poll_write(cx, buf) - } - - #[inline] - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_flush(cx) - } - - #[inline] - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_shutdown(cx) - } - - #[inline] - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[io::IoSlice<'_>], - ) -> Poll> { - self.project().inner.poll_write_vectored(cx, bufs) - } - - #[inline] - fn is_write_vectored(&self) -> bool { - self.inner.is_write_vectored() - } -} - -impl AsyncRead for WithConnectionGuard { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - self.project().inner.poll_read(cx, buf) - } -} - #[cfg(test)] mod tests { use std::pin::pin; diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 7848fc2ac2..5b8d654149 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -91,9 +91,8 @@ pub async fn task_main( connections.spawn(async move { let mut socket = WithClientIp::new(socket); - let mut peer_addr = peer_addr.ip(); - match socket.wait_for_addr().await { - Ok(Some(addr)) => peer_addr = addr.ip(), + let peer_addr = match socket.wait_for_addr().await { + Ok(Some(addr)) => addr.ip(), Err(e) => { error!("per-client task finished with an error: {e:#}"); return; @@ -102,8 +101,8 @@ pub async fn task_main( error!("missing required client IP"); return; } - Ok(None) => {} - } + Ok(None) => peer_addr.ip() + }; match socket.inner.set_nodelay(true) { Ok(()) => {}, diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index c81ae03b23..d5387c172c 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -6,44 +6,39 @@ mod backend; mod conn_pool; mod json; mod sql_over_http; -pub mod tls_listener; mod websocket; +use bytes::Bytes; pub use conn_pool::GlobalConnPoolOptions; -use anyhow::bail; -use hyper::StatusCode; -use metrics::IntCounterPairGuard; +use anyhow::Context; +use futures::future::{select, Either}; +use http1::{Method, Response, StatusCode}; +use http_body_util::Full; +use hyper1::body::Incoming; use rand::rngs::StdRng; use rand::SeedableRng; pub use reqwest_middleware::{ClientWithMiddleware, Error}; pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; +use serde::Serialize; use tokio_util::task::TaskTracker; use crate::context::RequestMonitoring; -use crate::metrics::TLS_HANDSHAKE_FAILURES; -use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard}; +use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE; +use crate::protocol2::WithClientIp; +use crate::proxy::run_until_cancelled; use crate::rate_limiter::EndpointRateLimiter; use crate::serverless::backend::PoolingBackend; use crate::{cancellation::CancellationHandler, config::ProxyConfig}; -use futures::StreamExt; -use hyper::{ - server::{ - accept, - conn::{AddrIncoming, AddrStream}, - }, - Body, Method, Request, Response, -}; use std::convert::Infallible; use std::net::IpAddr; -use std::task::Poll; -use std::{future::ready, sync::Arc}; -use tls_listener::TlsListener; +use std::pin::pin; +use std::sync::Arc; use tokio::net::TcpListener; use tokio_util::sync::CancellationToken; use tracing::{error, info, warn, Instrument}; -use utils::http::{error::ApiError, json::json_response}; +use utils::http::error::ApiError; pub const SERVERLESS_DRIVER_SNI: &str = "api"; @@ -95,134 +90,184 @@ pub async fn task_main( tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; let tls_acceptor: tokio_rustls::TlsAcceptor = Arc::new(tls_server_config).into(); - let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?; - let _ = addr_incoming.set_nodelay(true); - let addr_incoming = ProxyProtocolAccept { - incoming: addr_incoming, - protocol: "http", - }; - let ws_connections = tokio_util::task::task_tracker::TaskTracker::new(); ws_connections.close(); // allows `ws_connections.wait to complete` - let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| { - if let Err(err) = conn { - error!( - protocol = "http", - "failed to accept TLS connection: {err:?}" - ); - TLS_HANDSHAKE_FAILURES.inc(); - ready(false) - } else { - info!(protocol = "http", "accepted new TLS connection"); - ready(true) + let http_connections = tokio_util::task::task_tracker::TaskTracker::new(); + http_connections.close(); + + let server = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new()); + + loop { + let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await else { + break; + }; + + let (conn, mut peer_addr) = res.context("could not accept TCP stream")?; + if let Err(e) = conn.set_nodelay(true) { + tracing::error!("could not set nodolay: {e}"); + continue; } - }); + let cancellation_token = cancellation_token.child_token(); - let make_svc = hyper::service::make_service_fn( - |stream: &tokio_rustls::server::TlsStream< - WithConnectionGuard>, - >| { - let (conn, _) = stream.get_ref(); + let tls = tls_acceptor.clone(); - // this is jank. should dissapear with hyper 1.0 migration. - let gauge = conn - .gauge - .lock() - .expect("lock should not be poisoned") - .take() - .expect("gauge should be set on connection start"); + let backend = backend.clone(); + let ws_connections = ws_connections.clone(); + let endpoint_rate_limiter = endpoint_rate_limiter.clone(); + let cancellation_handler = cancellation_handler.clone(); + let server = server.clone(); - let client_addr = conn.inner.client_addr(); - let remote_addr = conn.inner.inner.remote_addr(); - let backend = backend.clone(); - let ws_connections = ws_connections.clone(); - let endpoint_rate_limiter = endpoint_rate_limiter.clone(); - let cancellation_handler = cancellation_handler.clone(); - async move { - let peer_addr = match client_addr { - Some(addr) => addr, - None if config.require_client_ip => bail!("missing required client ip"), - None => remote_addr, - }; - Ok(MetricService::new( - hyper::service::service_fn(move |req: Request| { - let backend = backend.clone(); - let ws_connections = ws_connections.clone(); - let endpoint_rate_limiter = endpoint_rate_limiter.clone(); - let cancellation_handler = cancellation_handler.clone(); + http_connections.spawn(async move { + let _gauge = NUM_CLIENT_CONNECTION_GAUGE + .with_label_values(&["http"]) + .guard(); - async move { - Ok::<_, Infallible>( - request_handler( - req, - config, - backend, - ws_connections, - cancellation_handler, - peer_addr.ip(), - endpoint_rate_limiter, - ) - .await - .map_or_else(|e| e.into_response(), |r| r), - ) - } - }), - gauge, - )) + // handle PROXY protocol + let mut conn = WithClientIp::new(conn); + let peer = match conn.wait_for_addr().await { + Ok(peer) => peer, + Err(e) => { + tracing::error!("could not parse PROXY protocol header: {e}"); + return; + } + }; + + if let Some(peer) = peer { + peer_addr = peer; } - }, - ); - hyper::Server::builder(accept::from_stream(tls_listener)) - .serve(make_svc) - .with_graceful_shutdown(cancellation_token.cancelled()) - .await?; + let conn = match tls.accept(conn).await { + Ok(conn) => conn, + Err(e) => { + tracing::error!(protocol = "http", "TLS failure: {e}"); + return; + } + }; + + let conn = server.serve_connection_with_upgrades( + hyper_util::rt::TokioIo::new(conn), + hyper1::service::service_fn(move |req: hyper1::Request| { + let backend = backend.clone(); + let ws_connections = ws_connections.clone(); + let endpoint_rate_limiter = endpoint_rate_limiter.clone(); + let cancellation_handler = cancellation_handler.clone(); + + async move { + Ok::<_, Infallible>( + request_handler( + req, + config, + backend, + ws_connections, + cancellation_handler, + peer_addr.ip(), + endpoint_rate_limiter, + ) + .await + .map_or_else(api_error_into_response, |r| r), + ) + } + }), + ); + + let cancel = pin!(cancellation_token.cancelled()); + let conn = pin!(conn); + let res = match select(cancel, conn).await { + Either::Left((_cancelled, mut conn)) => { + conn.as_mut().graceful_shutdown(); + conn.await + } + Either::Right((res, _)) => res, + }; + + match res { + Ok(()) => {} + Err(e) => { + tracing::warn!("HTTP connection error {e}") + } + } + }); + } // await websocket connections + http_connections.wait().await; ws_connections.wait().await; Ok(()) } -struct MetricService { - inner: S, - _gauge: IntCounterPairGuard, -} - -impl MetricService { - fn new(inner: S, _gauge: IntCounterPairGuard) -> MetricService { - MetricService { inner, _gauge } +fn api_error_into_response(this: ApiError) -> Response> { + match this { + ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status( + format!("{err:#?}"), // use debug printing so that we give the cause + StatusCode::BAD_REQUEST, + ), + ApiError::Forbidden(_) => { + HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::FORBIDDEN) + } + ApiError::Unauthorized(_) => { + HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::UNAUTHORIZED) + } + ApiError::NotFound(_) => { + HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::NOT_FOUND) + } + ApiError::Conflict(_) => { + HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::CONFLICT) + } + ApiError::PreconditionFailed(_) => HttpErrorBody::response_from_msg_and_status( + this.to_string(), + StatusCode::PRECONDITION_FAILED, + ), + ApiError::ShuttingDown => HttpErrorBody::response_from_msg_and_status( + "Shutting down".to_string(), + StatusCode::SERVICE_UNAVAILABLE, + ), + ApiError::ResourceUnavailable(err) => HttpErrorBody::response_from_msg_and_status( + err.to_string(), + StatusCode::SERVICE_UNAVAILABLE, + ), + ApiError::Timeout(err) => HttpErrorBody::response_from_msg_and_status( + err.to_string(), + StatusCode::REQUEST_TIMEOUT, + ), + ApiError::InternalServerError(err) => HttpErrorBody::response_from_msg_and_status( + err.to_string(), + StatusCode::INTERNAL_SERVER_ERROR, + ), } } -impl hyper::service::Service> for MetricService -where - S: hyper::service::Service>, -{ - type Response = S::Response; - type Error = S::Error; - type Future = S::Future; +#[derive(Serialize)] +struct HttpErrorBody { + pub msg: String, +} - fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { - self.inner.poll_ready(cx) +impl HttpErrorBody { + pub fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response> { + HttpErrorBody { msg }.to_response(status) } - fn call(&mut self, req: Request) -> Self::Future { - self.inner.call(req) + pub fn to_response(&self, status: StatusCode) -> Response> { + Response::builder() + .status(status) + .header(http1::header::CONTENT_TYPE, "application/json") + // we do not have nested maps with non string keys so serialization shouldn't fail + .body(Full::new(Bytes::from(serde_json::to_string(self).unwrap()))) + .unwrap() } } #[allow(clippy::too_many_arguments)] async fn request_handler( - mut request: Request, + mut request: hyper1::Request, config: &'static ProxyConfig, backend: Arc, ws_connections: TaskTracker, cancellation_handler: Arc, peer_addr: IpAddr, endpoint_rate_limiter: Arc, -) -> Result, ApiError> { +) -> Result>, ApiError> { let session_id = uuid::Uuid::new_v4(); let host = request @@ -261,14 +306,14 @@ async fn request_handler( // Return the response so the spawned future can continue. Ok(response) - } else if request.uri().path() == "/sql" && request.method() == Method::POST { + } else if request.uri().path() == "/sql" && *request.method() == Method::POST { let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region); let span = ctx.span.clone(); sql_over_http::handle(config, ctx, request, backend) .instrument(span) .await - } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS { + } else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS { Response::builder() .header("Allow", "OPTIONS, POST") .header("Access-Control-Allow-Origin", "*") @@ -278,9 +323,24 @@ async fn request_handler( ) .header("Access-Control-Max-Age", "86400" /* 24 hours */) .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code - .body(Body::empty()) + .body(Full::new(Bytes::new())) .map_err(|e| ApiError::InternalServerError(e.into())) } else { json_response(StatusCode::BAD_REQUEST, "query is not supported") } } + +fn json_response( + status: StatusCode, + data: T, +) -> Result>, ApiError> { + let json = serde_json::to_string(&data) + .context("Failed to serialize JSON response") + .map_err(ApiError::InternalServerError)?; + let response = Response::builder() + .status(status) + .header(http1::header::CONTENT_TYPE, "application/json") + .body(Full::new(Bytes::from(json))) + .map_err(|e| ApiError::InternalServerError(e.into()))?; + Ok(response) +} diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 74af985211..69159d30a1 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -1,14 +1,19 @@ use std::sync::Arc; +use super::json_response; use anyhow::bail; +use bytes::Bytes; use futures::StreamExt; -use hyper::body::HttpBody; -use hyper::header; -use hyper::http::HeaderName; -use hyper::http::HeaderValue; -use hyper::Response; -use hyper::StatusCode; -use hyper::{Body, HeaderMap, Request}; +use http_body_util::BodyExt; +use http_body_util::Full; +use hyper1::body::Body; +use hyper1::body::Incoming; +use hyper1::header; +use hyper1::http::HeaderName; +use hyper1::http::HeaderValue; +use hyper1::Response; +use hyper1::StatusCode; +use hyper1::{HeaderMap, Request}; use serde_json::json; use serde_json::Value; use tokio::try_join; @@ -22,7 +27,6 @@ use tracing::error; use tracing::info; use url::Url; use utils::http::error::ApiError; -use utils::http::json::json_response; use crate::auth::backend::ComputeUserInfo; use crate::auth::endpoint_sni; @@ -191,9 +195,9 @@ fn get_conn_info( pub async fn handle( config: &'static ProxyConfig, mut ctx: RequestMonitoring, - request: Request, + request: Request, backend: Arc, -) -> Result, ApiError> { +) -> Result>, ApiError> { let result = tokio::time::timeout( config.http_config.request_timeout, handle_inner(config, &mut ctx, request, backend), @@ -300,19 +304,18 @@ pub async fn handle( } }; - response.headers_mut().insert( - "Access-Control-Allow-Origin", - hyper::http::HeaderValue::from_static("*"), - ); + response + .headers_mut() + .insert("Access-Control-Allow-Origin", HeaderValue::from_static("*")); Ok(response) } async fn handle_inner( config: &'static ProxyConfig, ctx: &mut RequestMonitoring, - request: Request, + request: Request, backend: Arc, -) -> anyhow::Result> { +) -> anyhow::Result>> { let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE .with_label_values(&[ctx.protocol]) .guard(); @@ -369,9 +372,12 @@ async fn handle_inner( } let fetch_and_process_request = async { - let body = hyper::body::to_bytes(request.into_body()) + let body = request + .into_body() + .collect() .await - .map_err(anyhow::Error::from)?; + .map_err(anyhow::Error::from)? + .to_bytes(); info!(length = body.len(), "request payload read"); let payload: Payload = serde_json::from_slice(&body)?; Ok::(payload) // Adjust error type accordingly @@ -490,7 +496,7 @@ async fn handle_inner( let body = serde_json::to_string(&result).expect("json serialization should not fail"); let len = body.len(); let response = response - .body(Body::from(body)) + .body(Full::new(Bytes::from(body))) // only fails if invalid status code or invalid header/values are given. // these are not user configurable so it cannot fail dynamically .expect("building response payload should not fail");