Compare commits

...

2 Commits

Author SHA1 Message Date
Conrad Ludgate
abdc9bb4ba hakari 2023-11-16 09:41:04 +01:00
Conrad Ludgate
eac2e7498c start transition 2023-11-15 22:41:19 +01:00
5 changed files with 63 additions and 46 deletions

View File

@@ -86,7 +86,7 @@ hostname = "0.3.1"
http-types = { version = "2", default-features = false } http-types = { version = "2", default-features = false }
humantime = "2.1" humantime = "2.1"
humantime-serde = "1.1.1" humantime-serde = "1.1.1"
hyper = "0.14" hyper = { version = "0.14", features=["backports"] }
hyper-tungstenite = "0.11" hyper-tungstenite = "0.11"
inotify = "0.10.2" inotify = "0.10.2"
itertools = "0.10" itertools = "0.10"

View File

@@ -16,7 +16,7 @@ aws-sdk-s3.workspace = true
aws-credential-types.workspace = true aws-credential-types.workspace = true
bytes.workspace = true bytes.workspace = true
camino.workspace = true camino.workspace = true
hyper = { workspace = true, features = ["stream"] } hyper = { workspace = true }
serde.workspace = true serde.workspace = true
serde_json.workspace = true serde_json.workspace = true
tokio = { workspace = true, features = ["sync", "fs", "io-util"] } tokio = { workspace = true, features = ["sync", "fs", "io-util"] }

View File

@@ -1,7 +1,7 @@
//! Tracing wrapper for Hyper HTTP server //! Tracing wrapper for Hyper HTTP server
use hyper::HeaderMap; use hyper::HeaderMap;
use hyper::{Body, Request, Response}; use hyper::{body::HttpBody, Request, Response};
use std::future::Future; use std::future::Future;
use tracing::Instrument; use tracing::Instrument;
use tracing_opentelemetry::OpenTelemetrySpanExt; use tracing_opentelemetry::OpenTelemetrySpanExt;
@@ -35,14 +35,14 @@ pub enum OtelName<'a> {
/// instrumentation libraries at: /// instrumentation libraries at:
/// <https://opentelemetry.io/registry/?language=rust&component=instrumentation> /// <https://opentelemetry.io/registry/?language=rust&component=instrumentation>
/// If a Hyper crate appears, consider switching to that. /// If a Hyper crate appears, consider switching to that.
pub async fn tracing_handler<F, R>( pub async fn tracing_handler<F, R, B1: HttpBody, B2: HttpBody>(
req: Request<Body>, req: Request<B1>,
handler: F, handler: F,
otel_name: OtelName<'_>, otel_name: OtelName<'_>,
) -> Response<Body> ) -> Response<B2>
where where
F: Fn(Request<Body>) -> R, F: Fn(Request<B1>) -> R,
R: Future<Output = Response<Body>>, R: Future<Output = Response<B2>>,
{ {
// Create a tracing span, with context propagated from the incoming // Create a tracing span, with context propagated from the incoming
// request if any. // request if any.

View File

@@ -10,18 +10,13 @@ use anyhow::bail;
use hyper::StatusCode; use hyper::StatusCode;
pub use reqwest_middleware::{ClientWithMiddleware, Error}; pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio::task::JoinSet;
use crate::protocol2::{ProxyProtocolAccept, WithClientIp}; use crate::protocol2::ProxyProtocolAccept;
use crate::proxy::{NUM_CLIENT_CONNECTION_CLOSED_COUNTER, NUM_CLIENT_CONNECTION_OPENED_COUNTER}; use crate::proxy::{NUM_CLIENT_CONNECTION_CLOSED_COUNTER, NUM_CLIENT_CONNECTION_OPENED_COUNTER};
use crate::{cancellation::CancelMap, config::ProxyConfig}; use crate::{cancellation::CancelMap, config::ProxyConfig};
use futures::StreamExt; use futures::StreamExt;
use hyper::{ use hyper::{server::conn::AddrIncoming, Body, Method, Request, Response};
server::{
accept,
conn::{AddrIncoming, AddrStream},
},
Body, Method, Request, Response,
};
use std::task::Poll; use std::task::Poll;
use std::{future::ready, sync::Arc}; use std::{future::ready, sync::Arc};
@@ -69,7 +64,7 @@ pub async fn task_main(
incoming: addr_incoming, incoming: addr_incoming,
}; };
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| { let mut tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
if let Err(err) = conn { if let Err(err) = conn {
error!("failed to accept TLS connection for websockets: {err:?}"); error!("failed to accept TLS connection for websockets: {err:?}");
ready(false) ready(false)
@@ -78,49 +73,71 @@ pub async fn task_main(
} }
}); });
let make_svc = hyper::service::make_service_fn( let mut connections = JoinSet::new();
|stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
let (io, tls) = stream.get_ref(); loop {
let client_addr = io.client_addr(); tokio::select! {
let remote_addr = io.inner.remote_addr(); Some(tls_stream) = tls_listener.next() => {
let sni_name = tls.server_name().map(|s| s.to_string()); let tls_stream = tls_stream?;
let conn_pool = conn_pool.clone(); let (io, tls) = tls_stream.get_ref();
let client_addr = io.client_addr();
let remote_addr = io.inner.remote_addr();
let sni_name = tls.server_name().map(|s| s.to_string());
let conn_pool = conn_pool.clone();
async move {
let peer_addr = match client_addr { let peer_addr = match client_addr {
Some(addr) => addr, Some(addr) => addr,
None if config.require_client_ip => bail!("missing required client ip"), None if config.require_client_ip => bail!("missing required client ip"),
None => remote_addr, None => remote_addr,
}; };
Ok(MetricService::new(hyper::service::service_fn(
move |req: Request<Body>| {
let sni_name = sni_name.clone();
let conn_pool = conn_pool.clone();
async move { let service = MetricService::new(hyper::service::service_fn(move |req: Request<Body>| {
let cancel_map = Arc::new(CancelMap::default()); let sni_name = sni_name.clone();
let session_id = uuid::Uuid::new_v4(); let conn_pool = conn_pool.clone();
request_handler( async move {
req, config, conn_pool, cancel_map, session_id, sni_name, let cancel_map = Arc::new(CancelMap::default());
) let session_id = uuid::Uuid::new_v4();
request_handler(req, config, conn_pool, cancel_map, session_id, sni_name)
.instrument(info_span!( .instrument(info_span!(
"serverless", "serverless",
session = %session_id, session = %session_id,
%peer_addr, %peer_addr,
)) ))
.await .await
} }
}, }));
)))
}
},
);
hyper::Server::builder(accept::from_stream(tls_listener)) connections.spawn(async move {
.serve(make_svc) // todo(conrad): http2?
.with_graceful_shutdown(cancellation_token.cancelled()) if let Err(err) = hyper::server::conn::http1::Builder::new()
.await?; .serve_connection(tls_stream, service)
.await
{
println!("Error serving connection: {:?}", err);
}
});
}
Some(Err(e)) = connections.join_next(), if !connections.is_empty() => {
if !e.is_panic() && !e.is_cancelled() {
warn!("unexpected error from joined connection task: {e:?}");
}
}
_ = cancellation_token.cancelled() => {
drop(tls_listener);
break;
}
}
}
// Drain connections
while let Some(res) = connections.join_next().await {
if let Err(e) = res {
if !e.is_panic() && !e.is_cancelled() {
warn!("unexpected error from joined connection task: {e:?}");
}
}
}
Ok(()) Ok(())
} }

View File

@@ -36,7 +36,7 @@ futures-io = { version = "0.3" }
futures-sink = { version = "0.3" } futures-sink = { version = "0.3" }
futures-util = { version = "0.3", features = ["channel", "io", "sink"] } futures-util = { version = "0.3", features = ["channel", "io", "sink"] }
hex = { version = "0.4", features = ["serde"] } hex = { version = "0.4", features = ["serde"] }
hyper = { version = "0.14", features = ["full"] } hyper = { version = "0.14", features = ["backports", "full"] }
itertools = { version = "0.10" } itertools = { version = "0.10" }
libc = { version = "0.2", features = ["extra_traits"] } libc = { version = "0.2", features = ["extra_traits"] }
log = { version = "0.4", default-features = false, features = ["std"] } log = { version = "0.4", default-features = false, features = ["std"] }