From 2cf85471f5631070417c72366691c7e591e4973f Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 18 Jan 2024 14:11:35 +0000 Subject: [PATCH] proxy --- Cargo.lock | 142 ++++++++++++++++++++------ Cargo.toml | 4 +- libs/utils/src/failpoint_support.rs | 4 +- libs/utils/src/http/json.rs | 8 +- libs/utils/src/http/mod.rs | 2 +- proxy/Cargo.toml | 6 +- proxy/src/http.rs | 39 +------ proxy/src/http/health_server.rs | 38 +++++-- proxy/src/protocol2.rs | 22 ++-- proxy/src/proxy/connect_compute.rs | 2 +- proxy/src/serverless.rs | 132 ++++++++++++++---------- proxy/src/serverless/sql_over_http.rs | 23 +++-- proxy/src/usage_metrics.rs | 68 +++++++----- 13 files changed, 299 insertions(+), 191 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9a534142ce..ea5170623f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -612,7 +612,7 @@ dependencies = [ "once_cell", "pin-project-lite", "pin-utils", - "rustls", + "rustls 0.21.9", "tokio", "tracing", ] @@ -712,7 +712,7 @@ dependencies = [ "sha1", "sync_wrapper", "tokio", - "tokio-tungstenite", + "tokio-tungstenite 0.20.0", "tower", "tower-layer", "tower-service", @@ -2470,10 +2470,10 @@ dependencies = [ "http 0.2.9", "hyper 0.14.26", "log", - "rustls", + "rustls 0.21.9", "rustls-native-certs", "tokio", - "tokio-rustls", + "tokio-rustls 0.24.0", ] [[package]] @@ -2503,15 +2503,17 @@ dependencies = [ [[package]] name = "hyper-tungstenite" -version = "0.11.1" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7cc7dcb1ab67cd336f468a12491765672e61a3b6b148634dbfe2fe8acd3fe7d9" +checksum = "7a343d17fe7885302ed7252767dc7bb83609a874b6ff581142241ec4b73957ad" dependencies = [ - "hyper 0.14.26", + "http-body-util", + "hyper 1.1.0", + "hyper-util", "pin-project-lite", "tokio", - "tokio-tungstenite", - "tungstenite", + "tokio-tungstenite 0.21.0", + "tungstenite 0.21.0", ] [[package]] @@ -3824,14 +3826,14 @@ dependencies = [ "futures", "once_cell", "pq_proto", - "rustls", + "rustls 0.21.9", "rustls-pemfile", "serde", "thiserror", "tokio", "tokio-postgres", "tokio-postgres-rustls", - "tokio-rustls", + "tokio-rustls 0.24.0", "tracing", "workspace_hack", ] @@ -4042,9 +4044,13 @@ dependencies = [ "hex", "hmac", "hostname", + "http 1.0.0", + "http-body 1.0.0", + "http-body-util", "humantime", "hyper 1.1.0", "hyper-tungstenite", + "hyper-util", "ipnet", "itertools", "md5", @@ -4074,7 +4080,7 @@ dependencies = [ "routerify", "rstest", "rustc-hash", - "rustls", + "rustls 0.21.9", "rustls-pemfile", "scopeguard", "serde", @@ -4089,7 +4095,7 @@ dependencies = [ "tokio", "tokio-postgres", "tokio-postgres-rustls", - "tokio-rustls", + "tokio-rustls 0.24.0", "tokio-util", "tracing", "tracing-opentelemetry", @@ -4240,7 +4246,7 @@ dependencies = [ "itoa", "percent-encoding", "pin-project-lite", - "rustls", + "rustls 0.21.9", "rustls-native-certs", "rustls-pemfile", "rustls-webpki 0.101.7", @@ -4248,7 +4254,7 @@ dependencies = [ "sha1_smol", "socket2 0.4.9", "tokio", - "tokio-rustls", + "tokio-rustls 0.24.0", "tokio-util", "url", ] @@ -4390,14 +4396,14 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls", + "rustls 0.21.9", "rustls-pemfile", "serde", "serde_json", "serde_urlencoded", "tokio", "tokio-native-tls", - "tokio-rustls", + "tokio-rustls 0.24.0", "tokio-util", "tower-service", "url", @@ -4652,6 +4658,20 @@ dependencies = [ "sct", ] +[[package]] +name = "rustls" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41" +dependencies = [ + "log", + "ring 0.17.6", + "rustls-pki-types", + "rustls-webpki 0.102.1", + "subtle", + "zeroize", +] + [[package]] name = "rustls-native-certs" version = "0.6.3" @@ -4673,6 +4693,12 @@ dependencies = [ "base64 0.21.1", ] +[[package]] +name = "rustls-pki-types" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e9d979b3ce68192e42760c7810125eb6cf2ea10efae545a156063e61f314e2a" + [[package]] name = "rustls-webpki" version = "0.100.2" @@ -4693,6 +4719,17 @@ dependencies = [ "untrusted 0.9.0", ] +[[package]] +name = "rustls-webpki" +version = "0.102.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef4ca26037c909dedb327b48c3327d0ba91d3dd3c4e05dad328f210ffb68e95b" +dependencies = [ + "ring 0.17.6", + "rustls-pki-types", + "untrusted 0.9.0", +] + [[package]] name = "rustversion" version = "1.0.12" @@ -4735,7 +4772,7 @@ dependencies = [ "serde_with", "thiserror", "tokio", - "tokio-rustls", + "tokio-rustls 0.24.0", "tokio-stream", "tracing", "tracing-appender", @@ -4901,7 +4938,7 @@ checksum = "2e95efd0cefa32028cdb9766c96de71d96671072f9fb494dc9fb84c0ef93e52b" dependencies = [ "httpdate", "reqwest", - "rustls", + "rustls 0.21.9", "sentry-backtrace", "sentry-contexts", "sentry-core", @@ -5600,16 +5637,15 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tls-listener" -version = "0.7.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81294c017957a1a69794f506723519255879e15a870507faf45dfed288b763dd" +checksum = "e516b94ec383622ae4da09b2a9bab471c4c880453c93fd0dd4dff1f6f2ff0849" dependencies = [ "futures-util", - "hyper 0.14.26", "pin-project-lite", "thiserror", "tokio", - "tokio-rustls", + "tokio-rustls 0.25.0", ] [[package]] @@ -5692,10 +5728,10 @@ checksum = "dd5831152cb0d3f79ef5523b357319ba154795d64c7078b2daa95a803b54057f" dependencies = [ "futures", "ring 0.16.20", - "rustls", + "rustls 0.21.9", "tokio", "tokio-postgres", - "tokio-rustls", + "tokio-rustls 0.24.0", ] [[package]] @@ -5704,7 +5740,18 @@ version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e0d409377ff5b1e3ca6437aa86c1eb7d40c134bfec254e44c830defa92669db5" dependencies = [ - "rustls", + "rustls 0.21.9", + "tokio", +] + +[[package]] +name = "tokio-rustls" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f" +dependencies = [ + "rustls 0.22.2", + "rustls-pki-types", "tokio", ] @@ -5743,7 +5790,19 @@ dependencies = [ "futures-util", "log", "tokio", - "tungstenite", + "tungstenite 0.20.1", +] + +[[package]] +name = "tokio-tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite 0.21.0", ] [[package]] @@ -5851,7 +5910,7 @@ dependencies = [ "rustls-native-certs", "rustls-pemfile", "tokio", - "tokio-rustls", + "tokio-rustls 0.24.0", "tokio-stream", "tower", "tower-layer", @@ -6076,6 +6135,25 @@ dependencies = [ "utf-8", ] +[[package]] +name = "tungstenite" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.0.0", + "httparse", + "log", + "rand 0.8.5", + "sha1", + "thiserror", + "url", + "utf-8", +] + [[package]] name = "twox-hash" version = "1.6.3" @@ -6173,7 +6251,7 @@ dependencies = [ "base64 0.21.1", "log", "once_cell", - "rustls", + "rustls 0.21.9", "rustls-webpki 0.100.2", "url", "webpki-roots 0.23.1", @@ -6774,7 +6852,7 @@ dependencies = [ "regex-syntax 0.8.2", "reqwest", "ring 0.16.20", - "rustls", + "rustls 0.21.9", "scopeguard", "serde", "serde_json", @@ -6785,14 +6863,14 @@ dependencies = [ "time", "time-macros", "tokio", - "tokio-rustls", + "tokio-rustls 0.24.0", "tokio-util", "toml_datetime", "toml_edit", "tower", "tracing", "tracing-core", - "tungstenite", + "tungstenite 0.20.1", "url", "uuid", "zstd", diff --git a/Cargo.toml b/Cargo.toml index 146d7eb114..7365a68a9c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -94,7 +94,7 @@ hyper-util = "0.1.0" http = "1" http-body = "1" http-body-util = "0.1" -hyper-tungstenite = "0.11" +hyper-tungstenite = "0.13.0" inotify = "0.10.2" ipnet = "2.9.0" itertools = "0.10" @@ -153,7 +153,7 @@ tar = "0.4" task-local-extensions = "0.1.4" test-context = "0.1" thiserror = "1.0" -tls-listener = { version = "0.7", features = ["rustls", "hyper-h1"] } +tls-listener = { version = "0.9", features = ["rustls", "tokio-net"] } tokio = { version = "1.17", features = ["macros"] } tokio-io-timeout = "1.2.0" tokio-postgres-rustls = "0.10.0" diff --git a/libs/utils/src/failpoint_support.rs b/libs/utils/src/failpoint_support.rs index f83b4ccfe8..6e397f581b 100644 --- a/libs/utils/src/failpoint_support.rs +++ b/libs/utils/src/failpoint_support.rs @@ -4,6 +4,8 @@ use crate::http::{ error::ApiError, json::{json_request, json_response}, }; +use bytes::Bytes; +use http_body_util::Full; use hyper::{Request, Response, StatusCode}; use routerify::Body; use serde::{Deserialize, Serialize}; @@ -152,7 +154,7 @@ pub struct FailpointConfig { pub async fn failpoints_handler( mut request: Request, _cancel: CancellationToken, -) -> Result, ApiError> { +) -> Result>, ApiError> { if !fail::has_failpoints() { return Err(ApiError::BadRequest(anyhow::anyhow!( "Cannot manage failpoints because storage was compiled without failpoints support" diff --git a/libs/utils/src/http/json.rs b/libs/utils/src/http/json.rs index 4afd6a9260..10c43006be 100644 --- a/libs/utils/src/http/json.rs +++ b/libs/utils/src/http/json.rs @@ -1,6 +1,6 @@ use anyhow::{anyhow, Context}; -use bytes::Buf; -use http_body_util::BodyExt; +use bytes::{Buf, Bytes}; +use http_body_util::{BodyExt, Full}; use hyper::{header, Request, Response, StatusCode}; use routerify::Body; use serde::{Deserialize, Serialize}; @@ -44,14 +44,14 @@ pub async fn json_request_or_empty_body Deserialize<'de>>( pub fn json_response( status: StatusCode, data: T, -) -> Result, ApiError> { +) -> Result>, ApiError> { let json = serde_json::to_string(&data) .context("Failed to serialize JSON response") .map_err(ApiError::InternalServerError)?; let response = Response::builder() .status(status) .header(header::CONTENT_TYPE, "application/json") - .body(Body::from(json)) + .body(Full::from(json)) .map_err(|e| ApiError::InternalServerError(e.into()))?; Ok(response) } diff --git a/libs/utils/src/http/mod.rs b/libs/utils/src/http/mod.rs index 15b73009e9..0b315ef029 100644 --- a/libs/utils/src/http/mod.rs +++ b/libs/utils/src/http/mod.rs @@ -5,4 +5,4 @@ pub mod request; /// Current fast way to apply simple http routing in various Neon binaries. /// Re-exported for sake of uniform approach, that could be later replaced with better alternatives, if needed. -pub use routerify::{ext::RequestExt, RouterBuilder}; +pub use routerify::{ext::RequestExt, Body, RequestServiceBuilder, RouterBuilder}; diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 23a9bb178d..834a806b84 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -28,7 +28,11 @@ hmac.workspace = true hostname.workspace = true humantime.workspace = true hyper-tungstenite.workspace = true -hyper.workspace = true +hyper = { workspace = true, features = ["server"] } +hyper-util = { workspace = true, features = ["tokio", "server", "server-auto"] } +http = { workspace = true, features = [] } +http-body = { workspace = true, features = [] } +http-body-util = { workspace = true, features = [] } ipnet.workspace = true itertools.workspace = true md5.workspace = true diff --git a/proxy/src/http.rs b/proxy/src/http.rs index 59e1492ed4..6ef65cc3ba 100644 --- a/proxy/src/http.rs +++ b/proxy/src/http.rs @@ -4,14 +4,12 @@ pub mod health_server; -use std::{sync::Arc, time::Duration}; +use std::time::Duration; -use futures::FutureExt; pub use reqwest::{Request, Response, StatusCode}; pub use reqwest_middleware::{ClientWithMiddleware, Error}; pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; use tokio::time::Instant; -use tracing::trace; use crate::{metrics::CONSOLE_REQUEST_LATENCY, rate_limiter, url::ApiUrl}; use reqwest_middleware::RequestBuilder; @@ -21,8 +19,6 @@ use reqwest_middleware::RequestBuilder; /// We deliberately don't want to replace this with a public static. pub fn new_client(rate_limiter_config: rate_limiter::RateLimiterConfig) -> ClientWithMiddleware { let client = reqwest::ClientBuilder::new() - .dns_resolver(Arc::new(GaiResolver::default())) - .connection_verbose(true) .build() .expect("Failed to create http client"); @@ -34,8 +30,6 @@ pub fn new_client(rate_limiter_config: rate_limiter::RateLimiterConfig) -> Clien pub fn new_client_with_timeout(default_timout: Duration) -> ClientWithMiddleware { let timeout_client = reqwest::ClientBuilder::new() - .dns_resolver(Arc::new(GaiResolver::default())) - .connection_verbose(true) .timeout(default_timout) .build() .expect("Failed to create http client with timeout"); @@ -100,37 +94,6 @@ impl Endpoint { } } -/// https://docs.rs/reqwest/0.11.18/src/reqwest/dns/gai.rs.html -use hyper::{ - client::connect::dns::{GaiResolver as HyperGaiResolver, Name}, - service::Service, -}; -use reqwest::dns::{Addrs, Resolve, Resolving}; -#[derive(Debug)] -pub struct GaiResolver(HyperGaiResolver); - -impl Default for GaiResolver { - fn default() -> Self { - Self(HyperGaiResolver::new()) - } -} - -impl Resolve for GaiResolver { - fn resolve(&self, name: Name) -> Resolving { - let this = &mut self.0.clone(); - let start = Instant::now(); - Box::pin( - Service::::call(this, name.clone()).map(move |result| { - let resolve_duration = start.elapsed(); - trace!(duration = ?resolve_duration, addr = %name, "resolve host complete"); - result - .map(|addrs| -> Addrs { Box::new(addrs) }) - .map_err(|err| -> Box { Box::new(err) }) - }), - ) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/proxy/src/http/health_server.rs b/proxy/src/http/health_server.rs index 6186ddde0d..ec5a2ed6a6 100644 --- a/proxy/src/http/health_server.rs +++ b/proxy/src/http/health_server.rs @@ -1,14 +1,21 @@ -use anyhow::{anyhow, bail}; -use hyper::{Body, Request, Response, StatusCode}; +use anyhow::anyhow; +use http::{Request, Response}; +use hyper::StatusCode; +use hyper_util::{ + rt::{TokioExecutor, TokioIo}, + server::conn, +}; use std::{convert::Infallible, net::TcpListener}; use tracing::info; -use utils::http::{endpoint, error::ApiError, json::json_response, RouterBuilder, RouterService}; +use utils::http::{ + endpoint, error::ApiError, json::json_response, Body, RequestServiceBuilder, RouterBuilder, +}; async fn status_handler(_: Request) -> Result, ApiError> { - json_response(StatusCode::OK, "") + json_response(StatusCode::OK, "").map(|req| req.map(Body::new)) } -fn make_router() -> RouterBuilder { +fn make_router() -> RouterBuilder { endpoint::make_router().get("/v1/status", status_handler) } @@ -17,11 +24,20 @@ pub async fn task_main(http_listener: TcpListener) -> anyhow::Result info!("http has shut down"); } - let service = || RouterService::new(make_router().build()?); + let router = make_router().build().map_err(|e| anyhow!(e))?; + let builder = RequestServiceBuilder::new(router).map_err(|e| anyhow!(e))?; + let listener = tokio::net::TcpListener::from_std(http_listener)?; - hyper::Server::from_tcp(http_listener)? - .serve(service().map_err(|e| anyhow!(e))?) - .await?; - - bail!("hyper server without shutdown handling cannot shutdown successfully"); + loop { + let (stream, remote_addr) = listener.accept().await.unwrap(); + let io = TokioIo::new(stream); + let service = builder.build(remote_addr); + tokio::task::spawn(async move { + let builder = conn::auto::Builder::new(TokioExecutor::new()); + let res = builder.serve_connection(io, service).await; + if let Err(err) = res { + println!("Error serving connection: {:?}", err); + } + }); + } } diff --git a/proxy/src/protocol2.rs b/proxy/src/protocol2.rs index 1d8931be85..20fb37e11b 100644 --- a/proxy/src/protocol2.rs +++ b/proxy/src/protocol2.rs @@ -10,13 +10,15 @@ use std::{ }; use bytes::{Buf, BytesMut}; -use hyper::server::conn::{AddrIncoming, AddrStream}; use pin_project_lite::pin_project; use tls_listener::AsyncAccept; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf}; +use tokio::{ + io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf}, + net::{TcpListener, TcpStream}, +}; pub struct ProxyProtocolAccept { - pub incoming: AddrIncoming, + pub incoming: TcpListener, } pin_project! { @@ -327,20 +329,18 @@ impl AsyncRead for WithClientIp { } impl AsyncAccept for ProxyProtocolAccept { - type Connection = WithClientIp; + type Connection = WithClientIp; + type Address = SocketAddr; type Error = io::Error; fn poll_accept( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { - let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?); - let Some(conn) = conn else { - return Poll::Ready(None); - }; - - Poll::Ready(Some(Ok(WithClientIp::new(conn)))) + ) -> Poll> { + Pin::new(&mut self.incoming) + .poll_accept(cx) + .map_ok(|(c, a)| (WithClientIp::new(c), a)) } } diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 72cab1fe5d..ca4e005468 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -7,8 +7,8 @@ use crate::{ proxy::retry::{retry_after, ShouldRetry}, }; use async_trait::async_trait; -use hyper::StatusCode; use pq_proto::StartupMessageParams; +use reqwest::StatusCode; use std::ops::ControlFlow; use tokio::time; use tracing::{error, info, warn}; diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index 8af008394a..18bc310eaf 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -6,41 +6,54 @@ mod conn_pool; mod sql_over_http; mod websocket; +use bytes::Bytes; pub use conn_pool::GlobalConnPoolOptions; -use anyhow::bail; +use http_body_util::Full; +use hyper::body::Incoming; use hyper::StatusCode; +use hyper_util::rt::{TokioExecutor, TokioIo}; +use hyper_util::server::conn; use metrics::IntCounterPairGuard; use rand::rngs::StdRng; use rand::SeedableRng; pub use reqwest_middleware::{ClientWithMiddleware, Error}; pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::select; +use tokio_rustls::TlsAcceptor; use tokio_util::task::TaskTracker; use crate::config::TlsConfig; use crate::context::RequestMonitoring; use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE; -use crate::protocol2::{ProxyProtocolAccept, WithClientIp}; +use crate::protocol2::ProxyProtocolAccept; use crate::rate_limiter::EndpointRateLimiter; use crate::{cancellation::CancelMap, config::ProxyConfig}; -use futures::StreamExt; -use hyper::{ - server::{ - accept, - conn::{AddrIncoming, AddrStream}, - }, - Body, Method, Request, Response, -}; +use hyper::{Method, Request, Response}; use std::net::IpAddr; -use std::task::Poll; -use std::{future::ready, sync::Arc}; -use tls_listener::TlsListener; +use std::pin::pin; +use std::sync::Arc; +use tls_listener::{AsyncTls, TlsListener}; use tokio::net::TcpListener; use tokio_util::sync::CancellationToken; use tracing::{error, info, info_span, warn, Instrument}; use utils::http::{error::ApiError, json::json_response}; +#[derive(Clone)] +struct Tls(TlsAcceptor); + +impl AsyncTls for Tls { + type Stream = tokio_rustls::server::TlsStream; + type Error = std::io::Error; + type AcceptFuture = tokio_rustls::Accept; + + fn accept(&self, conn: C) -> Self::AcceptFuture { + tokio_rustls::TlsAcceptor::accept(&self.0, conn) + } +} + pub async fn task_main( config: &'static ProxyConfig, ws_listener: TcpListener, @@ -79,42 +92,52 @@ pub async fn task_main( }; let tls_acceptor: tokio_rustls::TlsAcceptor = tls_config.to_server_config().into(); - let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?; - let _ = addr_incoming.set_nodelay(true); + // let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?; + // let _ = addr_incoming.set_nodelay(true); let addr_incoming = ProxyProtocolAccept { - incoming: addr_incoming, + incoming: ws_listener, }; let ws_connections = tokio_util::task::task_tracker::TaskTracker::new(); + let ws_connections2 = ws_connections.clone(); ws_connections.close(); // allows `ws_connections.wait to complete` - let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| { - if let Err(err) = conn { - error!("failed to accept TLS connection for websockets: {err:?}"); - ready(false) - } else { - ready(true) - } - }); + let mut tls_listener = TlsListener::new(Tls(tls_acceptor), addr_incoming); - let make_svc = hyper::service::make_service_fn( - |stream: &tokio_rustls::server::TlsStream>| { + tokio::spawn(async move { + loop { + let (stream, remote_addr) = select! { + res = tls_listener.accept() => { + match res { + Err(err) => + {error!("failed to accept TLS connection for websockets: {err:?}"); continue}, + Ok(s) => s, + } + } + _ = cancellation_token.cancelled() => break, + }; let (io, tls) = stream.get_ref(); let client_addr = io.client_addr(); - let remote_addr = io.inner.remote_addr(); let sni_name = tls.server_name().map(|s| s.to_string()); let conn_pool = conn_pool.clone(); - let ws_connections = ws_connections.clone(); + let ws_connections = ws_connections2.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); - async move { - let peer_addr = match client_addr { - Some(addr) => addr, - None if config.require_client_ip => bail!("missing required client ip"), - None => remote_addr, - }; - Ok(MetricService::new(hyper::service::service_fn( - move |req: Request| { + let peer_addr = match client_addr { + Some(addr) => addr, + None if config.require_client_ip => { + tracing::error!("Error serving connection: missing required client ip"); + continue; + } + None => remote_addr, + }; + + let io = TokioIo::new(stream); + + let cancellation_token = cancellation_token.clone(); + tokio::task::spawn(async move { + let service = MetricService::new(hyper::service::service_fn( + move |req: Request| { let sni_name = sni_name.clone(); let conn_pool = conn_pool.clone(); let ws_connections = ws_connections.clone(); @@ -144,15 +167,22 @@ pub async fn task_main( .await } }, - ))) - } - }, - ); - - hyper::Server::builder(accept::from_stream(tls_listener)) - .serve(make_svc) - .with_graceful_shutdown(cancellation_token.cancelled()) - .await?; + )); + let builder = conn::auto::Builder::new(TokioExecutor::new()); + let mut conn = pin!(builder.serve_connection(io, service)); + let res = select! { + _ = cancellation_token.cancelled() => { + conn.as_mut().graceful_shutdown(); + conn.await + } + res = conn.as_mut() => res, + }; + if let Err(err) = res { + tracing::error!("Error serving connection: {:?}", err); + } + }); + } + }); // await websocket connections ws_connections.wait().await; @@ -184,18 +214,14 @@ where type Error = S::Error; type Future = S::Future; - fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { - self.inner.poll_ready(cx) - } - - fn call(&mut self, req: Request) -> Self::Future { + fn call(&self, req: Request) -> Self::Future { self.inner.call(req) } } #[allow(clippy::too_many_arguments)] async fn request_handler( - mut request: Request, + mut request: Request, config: &'static ProxyConfig, tls: &'static TlsConfig, conn_pool: Arc, @@ -205,7 +231,7 @@ async fn request_handler( sni_hostname: Option, peer_addr: IpAddr, endpoint_rate_limiter: Arc, -) -> Result, ApiError> { +) -> Result>, ApiError> { let host = request .headers() .get("host") @@ -264,7 +290,7 @@ async fn request_handler( ) .header("Access-Control-Max-Age", "86400" /* 24 hours */) .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code - .body(Body::empty()) + .body(Full::new(Bytes::new())) .map_err(|e| ApiError::InternalServerError(e.into())) } else { json_response(StatusCode::BAD_REQUEST, "query is not supported") diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 391fb95e9e..1fd181b188 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -1,15 +1,20 @@ use std::sync::Arc; use anyhow::bail; +use bytes::Buf; +use bytes::Bytes; use futures::pin_mut; use futures::StreamExt; -use hyper::body::HttpBody; +use http_body::Body; +use http_body_util::BodyExt; +use http_body_util::Full; +use hyper::body::Incoming; use hyper::header; use hyper::http::HeaderName; use hyper::http::HeaderValue; use hyper::Response; use hyper::StatusCode; -use hyper::{Body, HeaderMap, Request}; +use hyper::{HeaderMap, Request}; use serde_json::json; use serde_json::Map; use serde_json::Value; @@ -235,10 +240,10 @@ pub async fn handle( tls: &'static TlsConfig, config: &'static HttpConfig, ctx: &mut RequestMonitoring, - request: Request, + request: Request, sni_hostname: Option, conn_pool: Arc, -) -> Result, ApiError> { +) -> Result>, ApiError> { let result = tokio::time::timeout( config.request_timeout, handle_inner(tls, config, ctx, request, sni_hostname, conn_pool), @@ -347,10 +352,10 @@ async fn handle_inner( tls: &'static TlsConfig, config: &'static HttpConfig, ctx: &mut RequestMonitoring, - request: Request, + request: Request, sni_hostname: Option, conn_pool: Arc, -) -> anyhow::Result> { +) -> anyhow::Result>> { let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE .with_label_values(&["http"]) .guard(); @@ -406,8 +411,8 @@ async fn handle_inner( // // Read the query and query params from the request body // - let body = hyper::body::to_bytes(request.into_body()).await?; - let payload: Payload = serde_json::from_slice(&body)?; + let body = request.into_body().collect().await?.aggregate().reader(); + let payload: Payload = serde_json::from_reader(body)?; let mut client = conn_pool.get(ctx, conn_info, !allow_pool).await?; @@ -504,7 +509,7 @@ async fn handle_inner( let body = serde_json::to_string(&result).expect("json serialization should not fail"); let len = body.len(); let response = response - .body(Body::from(body)) + .body(Full::from(body)) // only fails if invalid status code or invalid header/values are given. // these are not user configurable so it cannot fail dynamically .expect("building response payload should not fail"); diff --git a/proxy/src/usage_metrics.rs b/proxy/src/usage_metrics.rs index 789a4c680c..adb349d1b3 100644 --- a/proxy/src/usage_metrics.rs +++ b/proxy/src/usage_metrics.rs @@ -235,18 +235,19 @@ async fn collect_metrics_iteration( #[cfg(test)] mod tests { - use std::{ - net::TcpListener, - sync::{Arc, Mutex}, - }; + use std::sync::{Arc, Mutex}; use anyhow::Error; + use bytes::{Buf, Bytes}; use chrono::Utc; use consumption_metrics::{Event, EventChunk}; - use hyper::{ - service::{make_service_fn, service_fn}, - Body, Response, + use http_body_util::{BodyExt, Empty}; + use hyper::{body::Incoming, service::service_fn, Response}; + use hyper_util::{ + rt::{TokioExecutor, TokioIo}, + server::conn, }; + use tokio::net::TcpListener; use url::Url; use super::{collect_metrics_iteration, Ids, Metrics}; @@ -254,30 +255,43 @@ mod tests { #[tokio::test] async fn metrics() { - let listener = TcpListener::bind("0.0.0.0:0").unwrap(); + let listener = TcpListener::bind("0.0.0.0:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); let reports = Arc::new(Mutex::new(vec![])); let reports2 = reports.clone(); - let server = hyper::server::Server::from_tcp(listener) - .unwrap() - .serve(make_service_fn(move |_| { - let reports = reports.clone(); - async move { - Ok::<_, Error>(service_fn(move |req| { - let reports = reports.clone(); - async move { - let bytes = hyper::body::to_bytes(req.into_body()).await?; - let events: EventChunk<'static, Event> = - serde_json::from_slice(&bytes)?; - reports.lock().unwrap().push(events); - Ok::<_, Error>(Response::new(Body::from(vec![]))) - } - })) - } - })); - let addr = server.local_addr(); - tokio::spawn(server); + let service = service_fn(move |req: hyper::Request| { + let reports = reports.clone(); + async move { + let bytes = req + .into_body() + .collect() + .await + .unwrap() + .aggregate() + .reader(); + let events: EventChunk<'static, Event> = + serde_json::from_reader(bytes)?; + reports.lock().unwrap().push(events); + Ok::<_, Error>(Response::new(Empty::::new())) + } + }); + + tokio::spawn(async move { + loop { + let (stream, _) = listener.accept().await.unwrap(); + let io = TokioIo::new(stream); + let service = service.clone(); + tokio::task::spawn(async move { + let builder = conn::auto::Builder::new(TokioExecutor::new()); + let res = builder.serve_connection(io, service).await; + if let Err(err) = res { + println!("Error serving connection: {:?}", err); + } + }); + } + }); let metrics = Metrics::default(); let client = http::new_client(RateLimiterConfig::default());