proxy: hyper1 for only proxy (#7073)

## Problem

hyper1 offers control over the HTTP connection that hyper0_14 does not.
We're blocked on switching all services to hyper1 because of how we use
tonic, but no reason we can't switch proxy over.

## Summary of changes

1. hyper0.14 -> hyper1
    1. self managed server
    2. Remove the `WithConnectionGuard` wrapper from `protocol2`
2. Remove TLS listener as it's no longer necessary
3. include first session ID in connection startup logs
This commit is contained in:
Conrad Ludgate
2024-04-10 09:23:59 +01:00
committed by GitHub
parent fd88d4608c
commit c0ff4f18dc
9 changed files with 458 additions and 445 deletions

View File

@@ -5,19 +5,13 @@ use std::{
io,
net::SocketAddr,
pin::{pin, Pin},
sync::Mutex,
task::{ready, Context, Poll},
};
use bytes::{Buf, BytesMut};
use hyper::server::accept::Accept;
use hyper::server::conn::{AddrIncoming, AddrStream};
use metrics::IntCounterPairGuard;
use hyper::server::conn::AddrIncoming;
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
use uuid::Uuid;
use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
pub struct ProxyProtocolAccept {
pub incoming: AddrIncoming,
@@ -331,103 +325,6 @@ impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
}
}
impl Accept for ProxyProtocolAccept {
type Conn = WithConnectionGuard<WithClientIp<AddrStream>>;
type Error = io::Error;
fn poll_accept(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
let conn_id = uuid::Uuid::new_v4();
let span = tracing::info_span!("http_conn", ?conn_id);
{
let _enter = span.enter();
tracing::info!("accepted new TCP connection");
}
let Some(conn) = conn else {
return Poll::Ready(None);
};
Poll::Ready(Some(Ok(WithConnectionGuard {
inner: WithClientIp::new(conn),
connection_id: Uuid::new_v4(),
gauge: Mutex::new(Some(
NUM_CLIENT_CONNECTION_GAUGE
.with_label_values(&[self.protocol])
.guard(),
)),
span,
})))
}
}
pin_project! {
pub struct WithConnectionGuard<T> {
#[pin]
pub inner: T,
pub connection_id: Uuid,
pub gauge: Mutex<Option<IntCounterPairGuard>>,
pub span: tracing::Span,
}
impl<S> PinnedDrop for WithConnectionGuard<S> {
fn drop(this: Pin<&mut Self>) {
let _enter = this.span.enter();
tracing::info!("HTTP connection closed")
}
}
}
impl<T: AsyncWrite> AsyncWrite for WithConnectionGuard<T> {
#[inline]
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
self.project().inner.poll_write(cx, buf)
}
#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().inner.poll_flush(cx)
}
#[inline]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().inner.poll_shutdown(cx)
}
#[inline]
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
self.project().inner.poll_write_vectored(cx, bufs)
}
#[inline]
fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
}
impl<T: AsyncRead> AsyncRead for WithConnectionGuard<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.project().inner.poll_read(cx, buf)
}
}
#[cfg(test)]
mod tests {
use std::pin::pin;

View File

@@ -4,42 +4,48 @@
mod backend;
mod conn_pool;
mod http_util;
mod json;
mod sql_over_http;
pub mod tls_listener;
mod websocket;
use atomic_take::AtomicTake;
use bytes::Bytes;
pub use conn_pool::GlobalConnPoolOptions;
use anyhow::bail;
use hyper::StatusCode;
use metrics::IntCounterPairGuard;
use anyhow::Context;
use futures::future::{select, Either};
use futures::TryFutureExt;
use http::{Method, Response, StatusCode};
use http_body_util::Full;
use hyper1::body::Incoming;
use hyper_util::rt::TokioExecutor;
use hyper_util::server::conn::auto::Builder;
use rand::rngs::StdRng;
use rand::SeedableRng;
pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
use tokio_util::task::TaskTracker;
use tracing::instrument::Instrumented;
use crate::cancellation::CancellationHandlerMain;
use crate::config::ProxyConfig;
use crate::context::RequestMonitoring;
use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard};
use crate::metrics::{NUM_CLIENT_CONNECTION_GAUGE, TLS_HANDSHAKE_FAILURES};
use crate::protocol2::WithClientIp;
use crate::proxy::run_until_cancelled;
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use hyper::{
server::conn::{AddrIncoming, AddrStream},
Body, Method, Request, Response,
};
use crate::serverless::http_util::{api_error_into_response, json_response};
use std::net::IpAddr;
use std::net::{IpAddr, SocketAddr};
use std::pin::pin;
use std::sync::Arc;
use std::task::Poll;
use tls_listener::TlsListener;
use tokio::net::TcpListener;
use tokio_util::sync::{CancellationToken, DropGuard};
use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn, Instrument};
use utils::http::{error::ApiError, json::json_response};
use utils::http::error::ApiError;
pub const SERVERLESS_DRIVER_SNI: &str = "api";
@@ -91,161 +97,174 @@ pub async fn task_main(
tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
let tls_acceptor: tokio_rustls::TlsAcceptor = Arc::new(tls_server_config).into();
let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
let _ = addr_incoming.set_nodelay(true);
let addr_incoming = ProxyProtocolAccept {
incoming: addr_incoming,
protocol: "http",
};
let connections = tokio_util::task::task_tracker::TaskTracker::new();
connections.close(); // allows `connections.wait to complete`
let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
ws_connections.close(); // allows `ws_connections.wait to complete`
let server = Builder::new(hyper_util::rt::TokioExecutor::new());
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming, config.handshake_timeout);
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
if let Err(e) = conn.set_nodelay(true) {
tracing::error!("could not set nodelay: {e}");
continue;
}
let conn_id = uuid::Uuid::new_v4();
let http_conn_span = tracing::info_span!("http_conn", ?conn_id);
let make_svc = hyper::service::make_service_fn(
|stream: &tokio_rustls::server::TlsStream<
WithConnectionGuard<WithClientIp<AddrStream>>,
>| {
let (conn, _) = stream.get_ref();
connections.spawn(
connection_handler(
config,
backend.clone(),
connections.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
cancellation_token.clone(),
server.clone(),
tls_acceptor.clone(),
conn,
peer_addr,
)
.instrument(http_conn_span),
);
}
// this is jank. should dissapear with hyper 1.0 migration.
let gauge = conn
.gauge
.lock()
.expect("lock should not be poisoned")
.take()
.expect("gauge should be set on connection start");
// Cancel all current inflight HTTP requests if the HTTP connection is closed.
let http_cancellation_token = CancellationToken::new();
let cancel_connection = http_cancellation_token.clone().drop_guard();
let span = conn.span.clone();
let client_addr = conn.inner.client_addr();
let remote_addr = conn.inner.inner.remote_addr();
let backend = backend.clone();
let ws_connections = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellation_handler = cancellation_handler.clone();
async move {
let peer_addr = match client_addr {
Some(addr) => addr,
None if config.require_client_ip => bail!("missing required client ip"),
None => remote_addr,
};
Ok(MetricService::new(
hyper::service::service_fn(move |req: Request<Body>| {
let backend = backend.clone();
let ws_connections2 = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellation_handler = cancellation_handler.clone();
let http_cancellation_token = http_cancellation_token.child_token();
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
// By spawning the future, we ensure it never gets cancelled until it decides to.
ws_connections.spawn(
async move {
// Cancel the current inflight HTTP request if the requets stream is closed.
// This is slightly different to `_cancel_connection` in that
// h2 can cancel individual requests with a `RST_STREAM`.
let _cancel_session = http_cancellation_token.clone().drop_guard();
let res = request_handler(
req,
config,
backend,
ws_connections2,
cancellation_handler,
peer_addr.ip(),
endpoint_rate_limiter,
http_cancellation_token,
)
.await
.map_or_else(|e| e.into_response(), |r| r);
_cancel_session.disarm();
res
}
.in_current_span(),
)
}),
gauge,
cancel_connection,
span,
))
}
},
);
hyper::Server::builder(tls_listener)
.serve(make_svc)
.with_graceful_shutdown(cancellation_token.cancelled())
.await?;
// await websocket connections
ws_connections.wait().await;
connections.wait().await;
Ok(())
}
struct MetricService<S> {
inner: S,
_gauge: IntCounterPairGuard,
_cancel: DropGuard,
span: tracing::Span,
}
/// Handles the TCP lifecycle.
///
/// 1. Parses PROXY protocol V2
/// 2. Handles TLS handshake
/// 3. Handles HTTP connection
/// 1. With graceful shutdowns
/// 2. With graceful request cancellation with connection failure
/// 3. With websocket upgrade support.
#[allow(clippy::too_many_arguments)]
async fn connection_handler(
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
connections: TaskTracker,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
server: Builder<TokioExecutor>,
tls_acceptor: TlsAcceptor,
conn: TcpStream,
peer_addr: SocketAddr,
) {
let session_id = uuid::Uuid::new_v4();
impl<S> MetricService<S> {
fn new(
inner: S,
_gauge: IntCounterPairGuard,
_cancel: DropGuard,
span: tracing::Span,
) -> MetricService<S> {
MetricService {
inner,
_gauge,
_cancel,
span,
let _gauge = NUM_CLIENT_CONNECTION_GAUGE
.with_label_values(&["http"])
.guard();
// handle PROXY protocol
let mut conn = WithClientIp::new(conn);
let peer = match conn.wait_for_addr().await {
Ok(peer) => peer,
Err(e) => {
tracing::error!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
return;
}
}
}
};
impl<S, ReqBody> hyper::service::Service<Request<ReqBody>> for MetricService<S>
where
S: hyper::service::Service<Request<ReqBody>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = Instrumented<S::Future>;
let peer_addr = peer.unwrap_or(peer_addr).ip();
info!(?session_id, %peer_addr, "accepted new TCP connection");
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
// try upgrade to TLS, but with a timeout.
let conn = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await {
Ok(Ok(conn)) => {
info!(?session_id, %peer_addr, "accepted new TLS connection");
conn
}
// The handshake failed
Ok(Err(e)) => {
TLS_HANDSHAKE_FAILURES.inc();
warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
return;
}
// The handshake timed out
Err(e) => {
TLS_HANDSHAKE_FAILURES.inc();
warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
return;
}
};
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
self.span
.in_scope(|| self.inner.call(req))
.instrument(self.span.clone())
let session_id = AtomicTake::new(session_id);
// Cancel all current inflight HTTP requests if the HTTP connection is closed.
let http_cancellation_token = CancellationToken::new();
let _cancel_connection = http_cancellation_token.clone().drop_guard();
let conn = server.serve_connection_with_upgrades(
hyper_util::rt::TokioIo::new(conn),
hyper1::service::service_fn(move |req: hyper1::Request<Incoming>| {
// First HTTP request shares the same session ID
let session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4);
// Cancel the current inflight HTTP request if the requets stream is closed.
// This is slightly different to `_cancel_connection` in that
// h2 can cancel individual requests with a `RST_STREAM`.
let http_request_token = http_cancellation_token.child_token();
let cancel_request = http_request_token.clone().drop_guard();
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
// By spawning the future, we ensure it never gets cancelled until it decides to.
let handler = connections.spawn(
request_handler(
req,
config,
backend.clone(),
connections.clone(),
cancellation_handler.clone(),
session_id,
peer_addr,
endpoint_rate_limiter.clone(),
http_request_token,
)
.in_current_span()
.map_ok_or_else(api_error_into_response, |r| r),
);
async move {
let res = handler.await;
cancel_request.disarm();
res
}
}),
);
// On cancellation, trigger the HTTP connection handler to shut down.
let res = match select(pin!(cancellation_token.cancelled()), pin!(conn)).await {
Either::Left((_cancelled, mut conn)) => {
conn.as_mut().graceful_shutdown();
conn.await
}
Either::Right((res, _)) => res,
};
match res {
Ok(()) => tracing::info!(%peer_addr, "HTTP connection closed"),
Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"),
}
}
#[allow(clippy::too_many_arguments)]
async fn request_handler(
mut request: Request<Body>,
mut request: hyper1::Request<Incoming>,
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
ws_connections: TaskTracker,
cancellation_handler: Arc<CancellationHandlerMain>,
session_id: uuid::Uuid,
peer_addr: IpAddr,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
) -> Result<Response<Body>, ApiError> {
let session_id = uuid::Uuid::new_v4();
) -> Result<Response<Full<Bytes>>, ApiError> {
let host = request
.headers()
.get("host")
@@ -282,14 +301,14 @@ async fn request_handler(
// Return the response so the spawned future can continue.
Ok(response)
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
let span = ctx.span.clone();
sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
.instrument(span)
.await
} else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
} else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
Response::builder()
.header("Allow", "OPTIONS, POST")
.header("Access-Control-Allow-Origin", "*")
@@ -299,7 +318,7 @@ async fn request_handler(
)
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
.body(Body::empty())
.body(Full::new(Bytes::new()))
.map_err(|e| ApiError::InternalServerError(e.into()))
} else {
json_response(StatusCode::BAD_REQUEST, "query is not supported")

View File

@@ -0,0 +1,92 @@
//! Things stolen from `libs/utils/src/http` to add hyper 1.0 compatibility
//! Will merge back in at some point in the future.
use bytes::Bytes;
use anyhow::Context;
use http::{Response, StatusCode};
use http_body_util::Full;
use serde::Serialize;
use utils::http::error::ApiError;
/// Like [`ApiError::into_response`]
pub fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> {
match this {
ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
format!("{err:#?}"), // use debug printing so that we give the cause
StatusCode::BAD_REQUEST,
),
ApiError::Forbidden(_) => {
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::FORBIDDEN)
}
ApiError::Unauthorized(_) => {
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::UNAUTHORIZED)
}
ApiError::NotFound(_) => {
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::NOT_FOUND)
}
ApiError::Conflict(_) => {
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::CONFLICT)
}
ApiError::PreconditionFailed(_) => HttpErrorBody::response_from_msg_and_status(
this.to_string(),
StatusCode::PRECONDITION_FAILED,
),
ApiError::ShuttingDown => HttpErrorBody::response_from_msg_and_status(
"Shutting down".to_string(),
StatusCode::SERVICE_UNAVAILABLE,
),
ApiError::ResourceUnavailable(err) => HttpErrorBody::response_from_msg_and_status(
err.to_string(),
StatusCode::SERVICE_UNAVAILABLE,
),
ApiError::Timeout(err) => HttpErrorBody::response_from_msg_and_status(
err.to_string(),
StatusCode::REQUEST_TIMEOUT,
),
ApiError::InternalServerError(err) => HttpErrorBody::response_from_msg_and_status(
err.to_string(),
StatusCode::INTERNAL_SERVER_ERROR,
),
}
}
/// Same as [`utils::http::error::HttpErrorBody`]
#[derive(Serialize)]
struct HttpErrorBody {
pub msg: String,
}
impl HttpErrorBody {
/// Same as [`utils::http::error::HttpErrorBody::response_from_msg_and_status`]
fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response<Full<Bytes>> {
HttpErrorBody { msg }.to_response(status)
}
/// Same as [`utils::http::error::HttpErrorBody::to_response`]
fn to_response(&self, status: StatusCode) -> Response<Full<Bytes>> {
Response::builder()
.status(status)
.header(http::header::CONTENT_TYPE, "application/json")
// we do not have nested maps with non string keys so serialization shouldn't fail
.body(Full::new(Bytes::from(serde_json::to_string(self).unwrap())))
.unwrap()
}
}
/// Same as [`utils::http::json::json_response`]
pub fn json_response<T: Serialize>(
status: StatusCode,
data: T,
) -> Result<Response<Full<Bytes>>, ApiError> {
let json = serde_json::to_string(&data)
.context("Failed to serialize JSON response")
.map_err(ApiError::InternalServerError)?;
let response = Response::builder()
.status(status)
.header(http::header::CONTENT_TYPE, "application/json")
.body(Full::new(Bytes::from(json)))
.map_err(|e| ApiError::InternalServerError(e.into()))?;
Ok(response)
}

View File

@@ -1,18 +1,22 @@
use std::pin::pin;
use std::sync::Arc;
use bytes::Bytes;
use futures::future::select;
use futures::future::try_join;
use futures::future::Either;
use futures::StreamExt;
use futures::TryFutureExt;
use hyper::body::HttpBody;
use hyper::header;
use hyper::http::HeaderName;
use hyper::http::HeaderValue;
use hyper::Response;
use hyper::StatusCode;
use hyper::{Body, HeaderMap, Request};
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper1::body::Body;
use hyper1::body::Incoming;
use hyper1::header;
use hyper1::http::HeaderName;
use hyper1::http::HeaderValue;
use hyper1::Response;
use hyper1::StatusCode;
use hyper1::{HeaderMap, Request};
use serde_json::json;
use serde_json::Value;
use tokio::time;
@@ -29,7 +33,6 @@ use tracing::error;
use tracing::info;
use url::Url;
use utils::http::error::ApiError;
use utils::http::json::json_response;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
@@ -52,6 +55,7 @@ use crate::RoleName;
use super::backend::PoolingBackend;
use super::conn_pool::Client;
use super::conn_pool::ConnInfo;
use super::http_util::json_response;
use super::json::json_to_pg_text;
use super::json::pg_text_row_to_json;
use super::json::JsonConversionError;
@@ -218,10 +222,10 @@ fn get_conn_info(
pub async fn handle(
config: &'static ProxyConfig,
mut ctx: RequestMonitoring,
request: Request<Body>,
request: Request<Incoming>,
backend: Arc<PoolingBackend>,
cancel: CancellationToken,
) -> Result<Response<Body>, ApiError> {
) -> Result<Response<Full<Bytes>>, ApiError> {
let result = handle_inner(cancel, config, &mut ctx, request, backend).await;
let mut response = match result {
@@ -332,10 +336,9 @@ pub async fn handle(
}
};
response.headers_mut().insert(
"Access-Control-Allow-Origin",
hyper::http::HeaderValue::from_static("*"),
);
response
.headers_mut()
.insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
Ok(response)
}
@@ -396,7 +399,7 @@ impl UserFacingError for SqlOverHttpError {
#[derive(Debug, thiserror::Error)]
pub enum ReadPayloadError {
#[error("could not read the HTTP request body: {0}")]
Read(#[from] hyper::Error),
Read(#[from] hyper1::Error),
#[error("could not parse the HTTP request body: {0}")]
Parse(#[from] serde_json::Error),
}
@@ -437,7 +440,7 @@ struct HttpHeaders {
}
impl HttpHeaders {
fn try_parse(headers: &hyper::http::HeaderMap) -> Result<Self, SqlOverHttpError> {
fn try_parse(headers: &hyper1::http::HeaderMap) -> Result<Self, SqlOverHttpError> {
// Determine the output options. Default behaviour is 'false'. Anything that is not
// strictly 'true' assumed to be false.
let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
@@ -488,9 +491,9 @@ async fn handle_inner(
cancel: CancellationToken,
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
request: Request<Body>,
request: Request<Incoming>,
backend: Arc<PoolingBackend>,
) -> Result<Response<Body>, SqlOverHttpError> {
) -> Result<Response<Full<Bytes>>, SqlOverHttpError> {
let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
.with_label_values(&[ctx.protocol])
.guard();
@@ -528,7 +531,7 @@ async fn handle_inner(
}
let fetch_and_process_request = async {
let body = hyper::body::to_bytes(request.into_body()).await?;
let body = request.into_body().collect().await?.to_bytes();
info!(length = body.len(), "request payload read");
let payload: Payload = serde_json::from_slice(&body)?;
Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
@@ -596,7 +599,7 @@ async fn handle_inner(
let body = serde_json::to_string(&result).expect("json serialization should not fail");
let len = body.len();
let response = response
.body(Body::from(body))
.body(Full::new(Bytes::from(body)))
// only fails if invalid status code or invalid header/values are given.
// these are not user configurable so it cannot fail dynamically
.expect("building response payload should not fail");
@@ -639,6 +642,7 @@ impl QueryData {
}
// The query was cancelled.
Either::Right((_cancelled, query)) => {
tracing::info!("cancelling query");
if let Err(err) = cancel_token.cancel_query(NoTls).await {
tracing::error!(?err, "could not cancel query");
}

View File

@@ -1,123 +0,0 @@
use std::{
convert::Infallible,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use hyper::server::{accept::Accept, conn::AddrStream};
use pin_project_lite::pin_project;
use tokio::{
io::{AsyncRead, AsyncWrite},
task::JoinSet,
time::timeout,
};
use tokio_rustls::{server::TlsStream, TlsAcceptor};
use tracing::{info, warn, Instrument};
use crate::{
metrics::TLS_HANDSHAKE_FAILURES,
protocol2::{WithClientIp, WithConnectionGuard},
};
pin_project! {
/// Wraps a `Stream` of connections (such as a TCP listener) so that each connection is itself
/// encrypted using TLS.
pub(crate) struct TlsListener<A: Accept> {
#[pin]
listener: A,
tls: TlsAcceptor,
waiting: JoinSet<Option<TlsStream<A::Conn>>>,
timeout: Duration,
}
}
impl<A: Accept> TlsListener<A> {
/// Create a `TlsListener` with default options.
pub(crate) fn new(tls: TlsAcceptor, listener: A, timeout: Duration) -> Self {
TlsListener {
listener,
tls,
waiting: JoinSet::new(),
timeout,
}
}
}
impl<A> Accept for TlsListener<A>
where
A: Accept<Conn = WithConnectionGuard<WithClientIp<AddrStream>>>,
A::Error: std::error::Error,
A::Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Conn = TlsStream<A::Conn>;
type Error = Infallible;
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let mut this = self.project();
loop {
match this.listener.as_mut().poll_accept(cx) {
Poll::Pending => break,
Poll::Ready(Some(Ok(mut conn))) => {
let t = *this.timeout;
let tls = this.tls.clone();
let span = conn.span.clone();
this.waiting.spawn(async move {
let peer_addr = match conn.inner.wait_for_addr().await {
Ok(Some(addr)) => addr,
Err(e) => {
tracing::error!("failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
return None;
}
Ok(None) => conn.inner.inner.remote_addr()
};
let accept = tls.accept(conn);
match timeout(t, accept).await {
Ok(Ok(conn)) => {
info!(%peer_addr, "accepted new TLS connection");
Some(conn)
},
// The handshake failed, try getting another connection from the queue
Ok(Err(e)) => {
TLS_HANDSHAKE_FAILURES.inc();
warn!(%peer_addr, "failed to accept TLS connection: {e:?}");
None
}
// The handshake timed out, try getting another connection from the queue
Err(_) => {
TLS_HANDSHAKE_FAILURES.inc();
warn!(%peer_addr, "failed to accept TLS connection: timeout");
None
}
}
}.instrument(span));
}
Poll::Ready(Some(Err(e))) => {
tracing::error!("error accepting TCP connection: {e}");
continue;
}
Poll::Ready(None) => return Poll::Ready(None),
}
}
loop {
return match this.waiting.poll_join_next(cx) {
Poll::Ready(Some(Ok(Some(conn)))) => Poll::Ready(Some(Ok(conn))),
// The handshake failed to complete, try getting another connection from the queue
Poll::Ready(Some(Ok(None))) => continue,
// The handshake panicked or was cancelled. ignore and get another connection
Poll::Ready(Some(Err(e))) => {
tracing::warn!("handshake aborted: {e}");
continue;
}
_ => Poll::Pending,
};
}
}
}