This commit is contained in:
Conrad Ludgate
2024-10-30 11:13:56 +00:00
parent 7c57234de1
commit ba714431be

View File

@@ -38,7 +38,7 @@ use smallvec::SmallVec;
use sql_over_http::{uuid_to_header_value, NEON_REQUEST_ID};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
@@ -304,9 +304,9 @@ impl<B, S> GracefulShutdown
for hyper::server::conn::http1::UpgradeableConnection<TokioIo<AsyncRW>, S>
where
S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B> + Send,
S::Future: Send + 'static,
S::Future: Send,
B: hyper::body::Body + Send + 'static,
B::Data: Send + 'static,
B::Data: Send,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
@@ -319,9 +319,9 @@ where
impl<B, S> GracefulShutdown for hyper::server::conn::http1::Connection<TokioIo<AsyncRW>, S>
where
S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B> + Send,
S::Future: Send + 'static,
S::Future: Send,
B: hyper::body::Body + Send + 'static,
B::Data: Send + 'static,
B::Data: Send,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
@@ -337,7 +337,7 @@ where
S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B> + Send,
S::Future: Send + 'static,
B: hyper::body::Body + Send + 'static,
B::Data: Send + 'static,
B::Data: Send,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
@@ -370,7 +370,8 @@ async fn connection_handler(
let http_cancellation_token = CancellationToken::new();
let _cancel_connection = http_cancellation_token.clone().drop_guard();
let (ws_tx, mut ws_rx) = mpsc::channel(1);
let (ws_tx, ws_rx) = oneshot::channel();
let ws_tx = Arc::new(AtomicTake::new(ws_tx));
let auth_backend = backend.auth_backend;
let http2 = match &*alpn {
@@ -382,6 +383,11 @@ async fn connection_handler(
}
};
if http2 || !config.http_config.accept_websockets {
// discard the ws spawner
ws_tx.take();
}
let service = hyper::service::service_fn(move |req: hyper::Request<Incoming>| {
// First HTTP request shares the same session ID
let mut session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4);
@@ -460,7 +466,7 @@ async fn connection_handler(
match res {
Ok(()) => {
if let Some((session_id, host, websocket)) = ws_rx.recv().await {
if let Ok((session_id, host, websocket)) = ws_rx.await {
tracing::info!(%peer_addr, "connection upgraded to websockets");
let ctx = RequestMonitoring::new(
@@ -491,21 +497,26 @@ async fn connection_handler(
}
}
type WsUpgrade = (uuid::Uuid, Option<String>, OnUpgrade);
type WsSpawner = Arc<AtomicTake<oneshot::Sender<WsUpgrade>>>;
#[allow(clippy::too_many_arguments)]
async fn request_handler(
mut request: hyper::Request<Incoming>,
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
ws_spawner: mpsc::Sender<(uuid::Uuid, Option<String>, OnUpgrade)>,
ws_spawner: WsSpawner,
session_id: uuid::Uuid,
peer_addr: IpAddr,
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
// Check if the request is a websocket upgrade request.
if config.http_config.accept_websockets
&& framed_websockets::upgrade::is_upgrade_request(&request)
{
if framed_websockets::upgrade::is_upgrade_request(&request) {
let Some(spawner) = ws_spawner.take() else {
return json_response(StatusCode::BAD_REQUEST, "query is not supported");
};
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
.map_err(|e| ApiError::BadRequest(e.into()))?;
@@ -516,12 +527,9 @@ async fn request_handler(
.and_then(|h| h.split(':').next())
.map(|s| s.to_string());
ws_spawner
.send((session_id, host, websocket))
.await
.map_err(|_e| {
ApiError::InternalServerError(anyhow::anyhow!("could not upgrade WS connection"))
})?;
spawner.send((session_id, host, websocket)).map_err(|_e| {
ApiError::InternalServerError(anyhow::anyhow!("could not upgrade WS connection"))
})?;
// Return the response so the spawned future can continue.
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))