Compare commits

...

2 Commits

Author SHA1 Message Date
Conrad Ludgate
abdc9bb4ba hakari 2023-11-16 09:41:04 +01:00
Conrad Ludgate
eac2e7498c start transition 2023-11-15 22:41:19 +01:00
5 changed files with 63 additions and 46 deletions

View File

@@ -86,7 +86,7 @@ hostname = "0.3.1"
http-types = { version = "2", default-features = false }
humantime = "2.1"
humantime-serde = "1.1.1"
hyper = "0.14"
hyper = { version = "0.14", features=["backports"] }
hyper-tungstenite = "0.11"
inotify = "0.10.2"
itertools = "0.10"

View File

@@ -16,7 +16,7 @@ aws-sdk-s3.workspace = true
aws-credential-types.workspace = true
bytes.workspace = true
camino.workspace = true
hyper = { workspace = true, features = ["stream"] }
hyper = { workspace = true }
serde.workspace = true
serde_json.workspace = true
tokio = { workspace = true, features = ["sync", "fs", "io-util"] }

View File

@@ -1,7 +1,7 @@
//! Tracing wrapper for Hyper HTTP server
use hyper::HeaderMap;
use hyper::{Body, Request, Response};
use hyper::{body::HttpBody, Request, Response};
use std::future::Future;
use tracing::Instrument;
use tracing_opentelemetry::OpenTelemetrySpanExt;
@@ -35,14 +35,14 @@ pub enum OtelName<'a> {
/// instrumentation libraries at:
/// <https://opentelemetry.io/registry/?language=rust&component=instrumentation>
/// If a Hyper crate appears, consider switching to that.
pub async fn tracing_handler<F, R>(
req: Request<Body>,
pub async fn tracing_handler<F, R, B1: HttpBody, B2: HttpBody>(
req: Request<B1>,
handler: F,
otel_name: OtelName<'_>,
) -> Response<Body>
) -> Response<B2>
where
F: Fn(Request<Body>) -> R,
R: Future<Output = Response<Body>>,
F: Fn(Request<B1>) -> R,
R: Future<Output = Response<B2>>,
{
// Create a tracing span, with context propagated from the incoming
// request if any.

View File

@@ -10,18 +10,13 @@ use anyhow::bail;
use hyper::StatusCode;
pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio::task::JoinSet;
use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
use crate::protocol2::ProxyProtocolAccept;
use crate::proxy::{NUM_CLIENT_CONNECTION_CLOSED_COUNTER, NUM_CLIENT_CONNECTION_OPENED_COUNTER};
use crate::{cancellation::CancelMap, config::ProxyConfig};
use futures::StreamExt;
use hyper::{
server::{
accept,
conn::{AddrIncoming, AddrStream},
},
Body, Method, Request, Response,
};
use hyper::{server::conn::AddrIncoming, Body, Method, Request, Response};
use std::task::Poll;
use std::{future::ready, sync::Arc};
@@ -69,7 +64,7 @@ pub async fn task_main(
incoming: addr_incoming,
};
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
let mut tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
if let Err(err) = conn {
error!("failed to accept TLS connection for websockets: {err:?}");
ready(false)
@@ -78,49 +73,71 @@ pub async fn task_main(
}
});
let make_svc = hyper::service::make_service_fn(
|stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
let (io, tls) = stream.get_ref();
let client_addr = io.client_addr();
let remote_addr = io.inner.remote_addr();
let sni_name = tls.server_name().map(|s| s.to_string());
let conn_pool = conn_pool.clone();
let mut connections = JoinSet::new();
loop {
tokio::select! {
Some(tls_stream) = tls_listener.next() => {
let tls_stream = tls_stream?;
let (io, tls) = tls_stream.get_ref();
let client_addr = io.client_addr();
let remote_addr = io.inner.remote_addr();
let sni_name = tls.server_name().map(|s| s.to_string());
let conn_pool = conn_pool.clone();
async move {
let peer_addr = match client_addr {
Some(addr) => addr,
None if config.require_client_ip => bail!("missing required client ip"),
None => remote_addr,
};
Ok(MetricService::new(hyper::service::service_fn(
move |req: Request<Body>| {
let sni_name = sni_name.clone();
let conn_pool = conn_pool.clone();
async move {
let cancel_map = Arc::new(CancelMap::default());
let session_id = uuid::Uuid::new_v4();
let service = MetricService::new(hyper::service::service_fn(move |req: Request<Body>| {
let sni_name = sni_name.clone();
let conn_pool = conn_pool.clone();
request_handler(
req, config, conn_pool, cancel_map, session_id, sni_name,
)
async move {
let cancel_map = Arc::new(CancelMap::default());
let session_id = uuid::Uuid::new_v4();
request_handler(req, config, conn_pool, cancel_map, session_id, sni_name)
.instrument(info_span!(
"serverless",
session = %session_id,
%peer_addr,
))
.await
}
},
)))
}
},
);
}
}));
hyper::Server::builder(accept::from_stream(tls_listener))
.serve(make_svc)
.with_graceful_shutdown(cancellation_token.cancelled())
.await?;
connections.spawn(async move {
// todo(conrad): http2?
if let Err(err) = hyper::server::conn::http1::Builder::new()
.serve_connection(tls_stream, service)
.await
{
println!("Error serving connection: {:?}", err);
}
});
}
Some(Err(e)) = connections.join_next(), if !connections.is_empty() => {
if !e.is_panic() && !e.is_cancelled() {
warn!("unexpected error from joined connection task: {e:?}");
}
}
_ = cancellation_token.cancelled() => {
drop(tls_listener);
break;
}
}
}
// Drain connections
while let Some(res) = connections.join_next().await {
if let Err(e) = res {
if !e.is_panic() && !e.is_cancelled() {
warn!("unexpected error from joined connection task: {e:?}");
}
}
}
Ok(())
}

View File

@@ -36,7 +36,7 @@ futures-io = { version = "0.3" }
futures-sink = { version = "0.3" }
futures-util = { version = "0.3", features = ["channel", "io", "sink"] }
hex = { version = "0.4", features = ["serde"] }
hyper = { version = "0.14", features = ["full"] }
hyper = { version = "0.14", features = ["backports", "full"] }
itertools = { version = "0.10" }
libc = { version = "0.2", features = ["extra_traits"] }
log = { version = "0.4", default-features = false, features = ["std"] }