This commit is contained in:
Conrad Ludgate
2024-01-18 14:11:35 +00:00
parent 665f4ff4b5
commit 2cf85471f5
13 changed files with 299 additions and 191 deletions

View File

@@ -4,14 +4,12 @@
pub mod health_server;
use std::{sync::Arc, time::Duration};
use std::time::Duration;
use futures::FutureExt;
pub use reqwest::{Request, Response, StatusCode};
pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio::time::Instant;
use tracing::trace;
use crate::{metrics::CONSOLE_REQUEST_LATENCY, rate_limiter, url::ApiUrl};
use reqwest_middleware::RequestBuilder;
@@ -21,8 +19,6 @@ use reqwest_middleware::RequestBuilder;
/// We deliberately don't want to replace this with a public static.
pub fn new_client(rate_limiter_config: rate_limiter::RateLimiterConfig) -> ClientWithMiddleware {
let client = reqwest::ClientBuilder::new()
.dns_resolver(Arc::new(GaiResolver::default()))
.connection_verbose(true)
.build()
.expect("Failed to create http client");
@@ -34,8 +30,6 @@ pub fn new_client(rate_limiter_config: rate_limiter::RateLimiterConfig) -> Clien
pub fn new_client_with_timeout(default_timout: Duration) -> ClientWithMiddleware {
let timeout_client = reqwest::ClientBuilder::new()
.dns_resolver(Arc::new(GaiResolver::default()))
.connection_verbose(true)
.timeout(default_timout)
.build()
.expect("Failed to create http client with timeout");
@@ -100,37 +94,6 @@ impl Endpoint {
}
}
/// https://docs.rs/reqwest/0.11.18/src/reqwest/dns/gai.rs.html
use hyper::{
client::connect::dns::{GaiResolver as HyperGaiResolver, Name},
service::Service,
};
use reqwest::dns::{Addrs, Resolve, Resolving};
#[derive(Debug)]
pub struct GaiResolver(HyperGaiResolver);
impl Default for GaiResolver {
fn default() -> Self {
Self(HyperGaiResolver::new())
}
}
impl Resolve for GaiResolver {
fn resolve(&self, name: Name) -> Resolving {
let this = &mut self.0.clone();
let start = Instant::now();
Box::pin(
Service::<Name>::call(this, name.clone()).map(move |result| {
let resolve_duration = start.elapsed();
trace!(duration = ?resolve_duration, addr = %name, "resolve host complete");
result
.map(|addrs| -> Addrs { Box::new(addrs) })
.map_err(|err| -> Box<dyn std::error::Error + Send + Sync> { Box::new(err) })
}),
)
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -1,14 +1,21 @@
use anyhow::{anyhow, bail};
use hyper::{Body, Request, Response, StatusCode};
use anyhow::anyhow;
use http::{Request, Response};
use hyper::StatusCode;
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server::conn,
};
use std::{convert::Infallible, net::TcpListener};
use tracing::info;
use utils::http::{endpoint, error::ApiError, json::json_response, RouterBuilder, RouterService};
use utils::http::{
endpoint, error::ApiError, json::json_response, Body, RequestServiceBuilder, RouterBuilder,
};
async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
json_response(StatusCode::OK, "")
json_response(StatusCode::OK, "").map(|req| req.map(Body::new))
}
fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
fn make_router() -> RouterBuilder<Body, ApiError> {
endpoint::make_router().get("/v1/status", status_handler)
}
@@ -17,11 +24,20 @@ pub async fn task_main(http_listener: TcpListener) -> anyhow::Result<Infallible>
info!("http has shut down");
}
let service = || RouterService::new(make_router().build()?);
let router = make_router().build().map_err(|e| anyhow!(e))?;
let builder = RequestServiceBuilder::new(router).map_err(|e| anyhow!(e))?;
let listener = tokio::net::TcpListener::from_std(http_listener)?;
hyper::Server::from_tcp(http_listener)?
.serve(service().map_err(|e| anyhow!(e))?)
.await?;
bail!("hyper server without shutdown handling cannot shutdown successfully");
loop {
let (stream, remote_addr) = listener.accept().await.unwrap();
let io = TokioIo::new(stream);
let service = builder.build(remote_addr);
tokio::task::spawn(async move {
let builder = conn::auto::Builder::new(TokioExecutor::new());
let res = builder.serve_connection(io, service).await;
if let Err(err) = res {
println!("Error serving connection: {:?}", err);
}
});
}
}

View File

@@ -10,13 +10,15 @@ use std::{
};
use bytes::{Buf, BytesMut};
use hyper::server::conn::{AddrIncoming, AddrStream};
use pin_project_lite::pin_project;
use tls_listener::AsyncAccept;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf},
net::{TcpListener, TcpStream},
};
pub struct ProxyProtocolAccept {
pub incoming: AddrIncoming,
pub incoming: TcpListener,
}
pin_project! {
@@ -327,20 +329,18 @@ impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
}
impl AsyncAccept for ProxyProtocolAccept {
type Connection = WithClientIp<AddrStream>;
type Connection = WithClientIp<TcpStream>;
type Address = SocketAddr;
type Error = io::Error;
fn poll_accept(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
let Some(conn) = conn else {
return Poll::Ready(None);
};
Poll::Ready(Some(Ok(WithClientIp::new(conn))))
) -> Poll<Result<(Self::Connection, Self::Address), Self::Error>> {
Pin::new(&mut self.incoming)
.poll_accept(cx)
.map_ok(|(c, a)| (WithClientIp::new(c), a))
}
}

View File

@@ -7,8 +7,8 @@ use crate::{
proxy::retry::{retry_after, ShouldRetry},
};
use async_trait::async_trait;
use hyper::StatusCode;
use pq_proto::StartupMessageParams;
use reqwest::StatusCode;
use std::ops::ControlFlow;
use tokio::time;
use tracing::{error, info, warn};

View File

@@ -6,41 +6,54 @@ mod conn_pool;
mod sql_over_http;
mod websocket;
use bytes::Bytes;
pub use conn_pool::GlobalConnPoolOptions;
use anyhow::bail;
use http_body_util::Full;
use hyper::body::Incoming;
use hyper::StatusCode;
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server::conn;
use metrics::IntCounterPairGuard;
use rand::rngs::StdRng;
use rand::SeedableRng;
pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::select;
use tokio_rustls::TlsAcceptor;
use tokio_util::task::TaskTracker;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
use crate::protocol2::ProxyProtocolAccept;
use crate::rate_limiter::EndpointRateLimiter;
use crate::{cancellation::CancelMap, config::ProxyConfig};
use futures::StreamExt;
use hyper::{
server::{
accept,
conn::{AddrIncoming, AddrStream},
},
Body, Method, Request, Response,
};
use hyper::{Method, Request, Response};
use std::net::IpAddr;
use std::task::Poll;
use std::{future::ready, sync::Arc};
use tls_listener::TlsListener;
use std::pin::pin;
use std::sync::Arc;
use tls_listener::{AsyncTls, TlsListener};
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, warn, Instrument};
use utils::http::{error::ApiError, json::json_response};
#[derive(Clone)]
struct Tls(TlsAcceptor);
impl<C: AsyncRead + AsyncWrite + Unpin> AsyncTls<C> for Tls {
type Stream = tokio_rustls::server::TlsStream<C>;
type Error = std::io::Error;
type AcceptFuture = tokio_rustls::Accept<C>;
fn accept(&self, conn: C) -> Self::AcceptFuture {
tokio_rustls::TlsAcceptor::accept(&self.0, conn)
}
}
pub async fn task_main(
config: &'static ProxyConfig,
ws_listener: TcpListener,
@@ -79,42 +92,52 @@ pub async fn task_main(
};
let tls_acceptor: tokio_rustls::TlsAcceptor = tls_config.to_server_config().into();
let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
let _ = addr_incoming.set_nodelay(true);
// let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
// let _ = addr_incoming.set_nodelay(true);
let addr_incoming = ProxyProtocolAccept {
incoming: addr_incoming,
incoming: ws_listener,
};
let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
let ws_connections2 = ws_connections.clone();
ws_connections.close(); // allows `ws_connections.wait to complete`
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
if let Err(err) = conn {
error!("failed to accept TLS connection for websockets: {err:?}");
ready(false)
} else {
ready(true)
}
});
let mut tls_listener = TlsListener::new(Tls(tls_acceptor), addr_incoming);
let make_svc = hyper::service::make_service_fn(
|stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
tokio::spawn(async move {
loop {
let (stream, remote_addr) = select! {
res = tls_listener.accept() => {
match res {
Err(err) =>
{error!("failed to accept TLS connection for websockets: {err:?}"); continue},
Ok(s) => s,
}
}
_ = cancellation_token.cancelled() => break,
};
let (io, tls) = stream.get_ref();
let client_addr = io.client_addr();
let remote_addr = io.inner.remote_addr();
let sni_name = tls.server_name().map(|s| s.to_string());
let conn_pool = conn_pool.clone();
let ws_connections = ws_connections.clone();
let ws_connections = ws_connections2.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
async move {
let peer_addr = match client_addr {
Some(addr) => addr,
None if config.require_client_ip => bail!("missing required client ip"),
None => remote_addr,
};
Ok(MetricService::new(hyper::service::service_fn(
move |req: Request<Body>| {
let peer_addr = match client_addr {
Some(addr) => addr,
None if config.require_client_ip => {
tracing::error!("Error serving connection: missing required client ip");
continue;
}
None => remote_addr,
};
let io = TokioIo::new(stream);
let cancellation_token = cancellation_token.clone();
tokio::task::spawn(async move {
let service = MetricService::new(hyper::service::service_fn(
move |req: Request<Incoming>| {
let sni_name = sni_name.clone();
let conn_pool = conn_pool.clone();
let ws_connections = ws_connections.clone();
@@ -144,15 +167,22 @@ pub async fn task_main(
.await
}
},
)))
}
},
);
hyper::Server::builder(accept::from_stream(tls_listener))
.serve(make_svc)
.with_graceful_shutdown(cancellation_token.cancelled())
.await?;
));
let builder = conn::auto::Builder::new(TokioExecutor::new());
let mut conn = pin!(builder.serve_connection(io, service));
let res = select! {
_ = cancellation_token.cancelled() => {
conn.as_mut().graceful_shutdown();
conn.await
}
res = conn.as_mut() => res,
};
if let Err(err) = res {
tracing::error!("Error serving connection: {:?}", err);
}
});
}
});
// await websocket connections
ws_connections.wait().await;
@@ -184,18 +214,14 @@ where
type Error = S::Error;
type Future = S::Future;
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
fn call(&self, req: Request<ReqBody>) -> Self::Future {
self.inner.call(req)
}
}
#[allow(clippy::too_many_arguments)]
async fn request_handler(
mut request: Request<Body>,
mut request: Request<Incoming>,
config: &'static ProxyConfig,
tls: &'static TlsConfig,
conn_pool: Arc<conn_pool::GlobalConnPool>,
@@ -205,7 +231,7 @@ async fn request_handler(
sni_hostname: Option<String>,
peer_addr: IpAddr,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Response<Body>, ApiError> {
) -> Result<Response<Full<Bytes>>, ApiError> {
let host = request
.headers()
.get("host")
@@ -264,7 +290,7 @@ async fn request_handler(
)
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
.body(Body::empty())
.body(Full::new(Bytes::new()))
.map_err(|e| ApiError::InternalServerError(e.into()))
} else {
json_response(StatusCode::BAD_REQUEST, "query is not supported")

View File

@@ -1,15 +1,20 @@
use std::sync::Arc;
use anyhow::bail;
use bytes::Buf;
use bytes::Bytes;
use futures::pin_mut;
use futures::StreamExt;
use hyper::body::HttpBody;
use http_body::Body;
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper::body::Incoming;
use hyper::header;
use hyper::http::HeaderName;
use hyper::http::HeaderValue;
use hyper::Response;
use hyper::StatusCode;
use hyper::{Body, HeaderMap, Request};
use hyper::{HeaderMap, Request};
use serde_json::json;
use serde_json::Map;
use serde_json::Value;
@@ -235,10 +240,10 @@ pub async fn handle(
tls: &'static TlsConfig,
config: &'static HttpConfig,
ctx: &mut RequestMonitoring,
request: Request<Body>,
request: Request<Incoming>,
sni_hostname: Option<String>,
conn_pool: Arc<GlobalConnPool>,
) -> Result<Response<Body>, ApiError> {
) -> Result<Response<Full<Bytes>>, ApiError> {
let result = tokio::time::timeout(
config.request_timeout,
handle_inner(tls, config, ctx, request, sni_hostname, conn_pool),
@@ -347,10 +352,10 @@ async fn handle_inner(
tls: &'static TlsConfig,
config: &'static HttpConfig,
ctx: &mut RequestMonitoring,
request: Request<Body>,
request: Request<Incoming>,
sni_hostname: Option<String>,
conn_pool: Arc<GlobalConnPool>,
) -> anyhow::Result<Response<Body>> {
) -> anyhow::Result<Response<Full<Bytes>>> {
let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
.with_label_values(&["http"])
.guard();
@@ -406,8 +411,8 @@ async fn handle_inner(
//
// Read the query and query params from the request body
//
let body = hyper::body::to_bytes(request.into_body()).await?;
let payload: Payload = serde_json::from_slice(&body)?;
let body = request.into_body().collect().await?.aggregate().reader();
let payload: Payload = serde_json::from_reader(body)?;
let mut client = conn_pool.get(ctx, conn_info, !allow_pool).await?;
@@ -504,7 +509,7 @@ async fn handle_inner(
let body = serde_json::to_string(&result).expect("json serialization should not fail");
let len = body.len();
let response = response
.body(Body::from(body))
.body(Full::from(body))
// only fails if invalid status code or invalid header/values are given.
// these are not user configurable so it cannot fail dynamically
.expect("building response payload should not fail");

View File

@@ -235,18 +235,19 @@ async fn collect_metrics_iteration(
#[cfg(test)]
mod tests {
use std::{
net::TcpListener,
sync::{Arc, Mutex},
};
use std::sync::{Arc, Mutex};
use anyhow::Error;
use bytes::{Buf, Bytes};
use chrono::Utc;
use consumption_metrics::{Event, EventChunk};
use hyper::{
service::{make_service_fn, service_fn},
Body, Response,
use http_body_util::{BodyExt, Empty};
use hyper::{body::Incoming, service::service_fn, Response};
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server::conn,
};
use tokio::net::TcpListener;
use url::Url;
use super::{collect_metrics_iteration, Ids, Metrics};
@@ -254,30 +255,43 @@ mod tests {
#[tokio::test]
async fn metrics() {
let listener = TcpListener::bind("0.0.0.0:0").unwrap();
let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let reports = Arc::new(Mutex::new(vec![]));
let reports2 = reports.clone();
let server = hyper::server::Server::from_tcp(listener)
.unwrap()
.serve(make_service_fn(move |_| {
let reports = reports.clone();
async move {
Ok::<_, Error>(service_fn(move |req| {
let reports = reports.clone();
async move {
let bytes = hyper::body::to_bytes(req.into_body()).await?;
let events: EventChunk<'static, Event<Ids, String>> =
serde_json::from_slice(&bytes)?;
reports.lock().unwrap().push(events);
Ok::<_, Error>(Response::new(Body::from(vec![])))
}
}))
}
}));
let addr = server.local_addr();
tokio::spawn(server);
let service = service_fn(move |req: hyper::Request<Incoming>| {
let reports = reports.clone();
async move {
let bytes = req
.into_body()
.collect()
.await
.unwrap()
.aggregate()
.reader();
let events: EventChunk<'static, Event<Ids, String>> =
serde_json::from_reader(bytes)?;
reports.lock().unwrap().push(events);
Ok::<_, Error>(Response::new(Empty::<Bytes>::new()))
}
});
tokio::spawn(async move {
loop {
let (stream, _) = listener.accept().await.unwrap();
let io = TokioIo::new(stream);
let service = service.clone();
tokio::task::spawn(async move {
let builder = conn::auto::Builder::new(TokioExecutor::new());
let res = builder.serve_connection(io, service).await;
if let Err(err) = res {
println!("Error serving connection: {:?}", err);
}
});
}
});
let metrics = Metrics::default();
let client = http::new_client(RateLimiterConfig::default());