This commit is contained in:
Conrad Ludgate
2024-10-30 12:18:22 +00:00
parent fa5d907032
commit eb96abbff7

View File

@@ -26,7 +26,7 @@ use atomic_take::AtomicTake;
use bytes::Bytes;
pub use conn_pool_lib::GlobalConnPoolOptions;
use futures::future::{select, Either};
use futures::{FutureExt, TryFutureExt};
use futures::FutureExt;
use http::{Method, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty};
@@ -44,7 +44,7 @@ use tokio::task::JoinHandle;
use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{debug, info, warn, Instrument};
use utils::http::error::ApiError;
@@ -159,16 +159,18 @@ pub async fn task_main(
}
}
let conn_token = cancellation_token.child_token();
let conn_cancellation_token = cancellation_token.child_token();
let tls_acceptor = tls_acceptor.clone();
let backend = backend.clone();
let connections2 = connections.clone();
let cancellation_handler = cancellation_handler.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
connections.spawn(
let connection_token = connections.token();
tokio::spawn(
async move {
let conn_token2 = conn_token.clone();
let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);
let _cancel_guard = config
.http_config
.cancel_set
.insert(conn_id, conn_cancellation_token.clone());
let session_id = uuid::Uuid::new_v4();
@@ -184,27 +186,19 @@ pub async fn task_main(
let Some(conn) = startup_result else {
return;
};
let (_, peer_addr, _) = conn;
let ws_upgrade = http_connection_handler(
config,
backend,
connections2,
conn_token,
conn_cancellation_token,
connection_token.clone(),
conn,
session_id,
)
.boxed()
.await;
if let Some((session_id, host, websocket)) = ws_upgrade {
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
crate::metrics::Protocol::Ws,
&config.region,
);
if let Some((ctx, host, websocket)) = ws_upgrade {
let ws = websocket::serve_websocket(
config,
auth_backend,
@@ -350,8 +344,8 @@ impl GracefulShutdown
async fn http_connection_handler(
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
connections: TaskTracker,
cancellation_token: CancellationToken,
connection_token: TaskTrackerToken,
conn: ConnWithInfo,
session_id: uuid::Uuid,
) -> Option<WsUpgrade> {
@@ -377,7 +371,7 @@ async fn http_connection_handler(
let service = ProxyService {
config,
backend,
connections,
connection_token,
http_cancellation_token,
ws_tx,
@@ -434,9 +428,9 @@ struct ProxyService {
// global state
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
connections: TaskTracker,
// connection state only
connection_token: TaskTrackerToken,
http_cancellation_token: CancellationToken,
ws_tx: WsSpawner,
peer_addr: IpAddr,
@@ -468,23 +462,18 @@ impl hyper::service::Service<hyper::Request<Incoming>> for ProxyService {
// Cancel the current inflight HTTP request if the requets stream is closed.
// This is slightly different to `_cancel_connection` in that
// h2 can cancel individual requests with a `RST_STREAM`.
let http_request_token = self.http_cancellation_token.child_token();
let cancel_request = Some(http_request_token.clone().drop_guard());
let http_req_cancellation_token = self.http_cancellation_token.child_token();
let cancel_request = Some(http_req_cancellation_token.clone().drop_guard());
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
// By spawning the future, we ensure it never gets cancelled until it decides to.
let handle = self.connections.spawn(
request_handler(
req,
self.config,
self.backend.clone(),
self.ws_tx.clone(),
session_id,
self.peer_addr,
http_request_token,
)
.in_current_span()
.map_ok_or_else(api_error_into_response, |r| r),
let handle = request_handler(
req,
self.config,
self.backend.clone(),
self.ws_tx.clone(),
session_id,
self.peer_addr,
http_req_cancellation_token,
&self.connection_token,
);
ReqFut {
@@ -498,14 +487,25 @@ impl hyper::service::Service<hyper::Request<Incoming>> for ProxyService {
struct ReqFut {
session_id: uuid::Uuid,
cancel_request: Option<tokio_util::sync::DropGuard>,
handle: JoinHandle<Response<BoxBody<Bytes, hyper::Error>>>,
handle: HandleOrResponse,
}
enum HandleOrResponse {
Handle(JoinHandle<Response<BoxBody<Bytes, hyper::Error>>>),
Response(Option<Response<BoxBody<Bytes, hyper::Error>>>),
}
impl Future for ReqFut {
type Output = Result<Response<BoxBody<Bytes, hyper::Error>>, tokio::task::JoinError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
let mut res = std::task::ready!(self.handle.poll_unpin(cx));
let mut res = match &mut self.handle {
HandleOrResponse::Handle(join_handle) => std::task::ready!(join_handle.poll_unpin(cx)),
HandleOrResponse::Response(response) => {
Ok(response.take().expect("polled after completion"))
}
};
self.cancel_request
.take()
.map(tokio_util::sync::DropGuard::disarm);
@@ -520,10 +520,11 @@ impl Future for ReqFut {
}
}
type WsUpgrade = (uuid::Uuid, Option<String>, OnUpgrade);
type WsUpgrade = (RequestMonitoring, Option<String>, OnUpgrade);
type WsSpawner = Arc<AtomicTake<oneshot::Sender<WsUpgrade>>>;
async fn request_handler(
#[allow(clippy::too_many_arguments)]
fn request_handler(
mut request: hyper::Request<Incoming>,
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
@@ -532,15 +533,25 @@ async fn request_handler(
peer_addr: IpAddr,
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
connection_token: &TaskTrackerToken,
) -> HandleOrResponse {
// Check if the request is a websocket upgrade request.
if framed_websockets::upgrade::is_upgrade_request(&request) {
let Some(spawner) = ws_spawner.take() else {
return json_response(StatusCode::BAD_REQUEST, "query is not supported");
return HandleOrResponse::Response(Some(
json_response(StatusCode::BAD_REQUEST, "query is not supported")
.unwrap_or_else(api_error_into_response),
));
};
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
.map_err(|e| ApiError::BadRequest(e.into()))?;
let (response, websocket) = match framed_websockets::upgrade::upgrade(&mut request) {
Err(e) => {
return HandleOrResponse::Response(Some(api_error_into_response(
ApiError::BadRequest(e.into()),
)))
}
Ok(upgrade) => upgrade,
};
let host = request
.headers()
@@ -549,12 +560,21 @@ async fn request_handler(
.and_then(|h| h.split(':').next())
.map(|s| s.to_string());
spawner.send((session_id, host, websocket)).map_err(|_e| {
ApiError::InternalServerError(anyhow::anyhow!("could not upgrade WS connection"))
})?;
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
crate::metrics::Protocol::Ws,
&config.region,
);
if let Err(_e) = spawner.send((ctx, host, websocket)) {
return HandleOrResponse::Response(Some(api_error_into_response(
ApiError::InternalServerError(anyhow::anyhow!("could not upgrade WS connection")),
)));
}
// Return the response so the spawned future can continue.
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
HandleOrResponse::Response(Some(response.map(|b| b.map_err(|x| match x {}).boxed())))
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestMonitoring::new(
session_id,
@@ -564,11 +584,19 @@ async fn request_handler(
);
let span = ctx.span();
sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
.instrument(span)
.await
let token = connection_token.clone();
// `sql_over_http::handle` is not cancel safe. It expects to be cancelled only at specific times.
// By spawning the future, we ensure it never gets cancelled until it decides to.
HandleOrResponse::Handle(tokio::spawn(async move {
let _token = token;
sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
.instrument(span)
.await
.unwrap_or_else(api_error_into_response)
}))
} else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
Response::builder()
HandleOrResponse::Response(Some( Response::builder()
.header("Allow", "OPTIONS, POST")
.header("Access-Control-Allow-Origin", "*")
.header(
@@ -578,8 +606,11 @@ async fn request_handler(
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
.body(Empty::new().map_err(|x| match x {}).boxed())
.map_err(|e| ApiError::InternalServerError(e.into()))
.map_err(|e| ApiError::InternalServerError(e.into())).unwrap_or_else(api_error_into_response)))
} else {
json_response(StatusCode::BAD_REQUEST, "query is not supported")
HandleOrResponse::Response(Some(
json_response(StatusCode::BAD_REQUEST, "query is not supported")
.unwrap_or_else(api_error_into_response),
))
}
}