From eac2e7498c70d88620344b0a7aa13c4998ba6b39 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Wed, 15 Nov 2023 22:26:27 +0100 Subject: [PATCH] start transition --- Cargo.toml | 2 +- libs/remote_storage/Cargo.toml | 2 +- libs/tracing-utils/src/http.rs | 12 ++--- proxy/src/serverless.rs | 91 ++++++++++++++++++++-------------- 4 files changed, 62 insertions(+), 45 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bfdb0442ab..c9e5e79b03 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -86,7 +86,7 @@ hostname = "0.3.1" http-types = { version = "2", default-features = false } humantime = "2.1" humantime-serde = "1.1.1" -hyper = "0.14" +hyper = { version = "0.14", features=["backports"] } hyper-tungstenite = "0.11" inotify = "0.10.2" itertools = "0.10" diff --git a/libs/remote_storage/Cargo.toml b/libs/remote_storage/Cargo.toml index d7bcce28cb..7e92e2632d 100644 --- a/libs/remote_storage/Cargo.toml +++ b/libs/remote_storage/Cargo.toml @@ -16,7 +16,7 @@ aws-sdk-s3.workspace = true aws-credential-types.workspace = true bytes.workspace = true camino.workspace = true -hyper = { workspace = true, features = ["stream"] } +hyper = { workspace = true } serde.workspace = true serde_json.workspace = true tokio = { workspace = true, features = ["sync", "fs", "io-util"] } diff --git a/libs/tracing-utils/src/http.rs b/libs/tracing-utils/src/http.rs index f5ab267ff3..838734b858 100644 --- a/libs/tracing-utils/src/http.rs +++ b/libs/tracing-utils/src/http.rs @@ -1,7 +1,7 @@ //! Tracing wrapper for Hyper HTTP server use hyper::HeaderMap; -use hyper::{Body, Request, Response}; +use hyper::{body::HttpBody, Request, Response}; use std::future::Future; use tracing::Instrument; use tracing_opentelemetry::OpenTelemetrySpanExt; @@ -35,14 +35,14 @@ pub enum OtelName<'a> { /// instrumentation libraries at: /// /// If a Hyper crate appears, consider switching to that. -pub async fn tracing_handler( - req: Request, +pub async fn tracing_handler( + req: Request, handler: F, otel_name: OtelName<'_>, -) -> Response +) -> Response where - F: Fn(Request) -> R, - R: Future>, + F: Fn(Request) -> R, + R: Future>, { // Create a tracing span, with context propagated from the incoming // request if any. diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index 23deda3ae6..35f3bee9e8 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -10,18 +10,13 @@ use anyhow::bail; use hyper::StatusCode; pub use reqwest_middleware::{ClientWithMiddleware, Error}; pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; +use tokio::task::JoinSet; -use crate::protocol2::{ProxyProtocolAccept, WithClientIp}; +use crate::protocol2::ProxyProtocolAccept; use crate::proxy::{NUM_CLIENT_CONNECTION_CLOSED_COUNTER, NUM_CLIENT_CONNECTION_OPENED_COUNTER}; use crate::{cancellation::CancelMap, config::ProxyConfig}; use futures::StreamExt; -use hyper::{ - server::{ - accept, - conn::{AddrIncoming, AddrStream}, - }, - Body, Method, Request, Response, -}; +use hyper::{server::conn::AddrIncoming, Body, Method, Request, Response}; use std::task::Poll; use std::{future::ready, sync::Arc}; @@ -69,7 +64,7 @@ pub async fn task_main( incoming: addr_incoming, }; - let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| { + let mut tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| { if let Err(err) = conn { error!("failed to accept TLS connection for websockets: {err:?}"); ready(false) @@ -78,49 +73,71 @@ pub async fn task_main( } }); - let make_svc = hyper::service::make_service_fn( - |stream: &tokio_rustls::server::TlsStream>| { - let (io, tls) = stream.get_ref(); - let client_addr = io.client_addr(); - let remote_addr = io.inner.remote_addr(); - let sni_name = tls.server_name().map(|s| s.to_string()); - let conn_pool = conn_pool.clone(); + let mut connections = JoinSet::new(); + + loop { + tokio::select! { + Some(tls_stream) = tls_listener.next() => { + let tls_stream = tls_stream?; + let (io, tls) = tls_stream.get_ref(); + let client_addr = io.client_addr(); + let remote_addr = io.inner.remote_addr(); + let sni_name = tls.server_name().map(|s| s.to_string()); + let conn_pool = conn_pool.clone(); - async move { let peer_addr = match client_addr { Some(addr) => addr, None if config.require_client_ip => bail!("missing required client ip"), None => remote_addr, }; - Ok(MetricService::new(hyper::service::service_fn( - move |req: Request| { - let sni_name = sni_name.clone(); - let conn_pool = conn_pool.clone(); - async move { - let cancel_map = Arc::new(CancelMap::default()); - let session_id = uuid::Uuid::new_v4(); + let service = MetricService::new(hyper::service::service_fn(move |req: Request| { + let sni_name = sni_name.clone(); + let conn_pool = conn_pool.clone(); - request_handler( - req, config, conn_pool, cancel_map, session_id, sni_name, - ) + async move { + let cancel_map = Arc::new(CancelMap::default()); + let session_id = uuid::Uuid::new_v4(); + + request_handler(req, config, conn_pool, cancel_map, session_id, sni_name) .instrument(info_span!( "serverless", session = %session_id, %peer_addr, )) .await - } - }, - ))) - } - }, - ); + } + })); - hyper::Server::builder(accept::from_stream(tls_listener)) - .serve(make_svc) - .with_graceful_shutdown(cancellation_token.cancelled()) - .await?; + connections.spawn(async move { + // todo(conrad): http2? + if let Err(err) = hyper::server::conn::http1::Builder::new() + .serve_connection(tls_stream, service) + .await + { + println!("Error serving connection: {:?}", err); + } + }); + } + Some(Err(e)) = connections.join_next(), if !connections.is_empty() => { + if !e.is_panic() && !e.is_cancelled() { + warn!("unexpected error from joined connection task: {e:?}"); + } + } + _ = cancellation_token.cancelled() => { + drop(tls_listener); + break; + } + } + } + // Drain connections + while let Some(res) = connections.join_next().await { + if let Err(e) = res { + if !e.is_panic() && !e.is_cancelled() { + warn!("unexpected error from joined connection task: {e:?}"); + } + } + } Ok(()) }