re-use the same tokio task for the websocket handling

This commit is contained in:
Conrad Ludgate
2024-10-30 09:45:19 +00:00
parent 0c828c57e2
commit 7de2914dde

View File

@@ -29,6 +29,7 @@ use http::{Method, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty};
use hyper::body::Incoming;
use hyper::upgrade::OnUpgrade;
use hyper_util::rt::TokioExecutor;
use hyper_util::server::conn::auto::Builder;
use rand::rngs::StdRng;
@@ -36,6 +37,7 @@ use rand::SeedableRng;
use sql_over_http::{uuid_to_header_value, NEON_REQUEST_ID};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc;
use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
@@ -306,6 +308,9 @@ async fn connection_handler(
let http_cancellation_token = CancellationToken::new();
let _cancel_connection = http_cancellation_token.clone().drop_guard();
let (ws_tx, mut ws_rx) = mpsc::channel(1);
let auth_backend = backend.auth_backend;
let server = Builder::new(TokioExecutor::new());
let conn = server.serve_connection_with_upgrades(
hyper_util::rt::TokioIo::new(conn),
@@ -337,12 +342,10 @@ async fn connection_handler(
req,
config,
backend.clone(),
connections.clone(),
cancellation_handler.clone(),
ws_tx.clone(),
session_id,
peer_addr,
http_request_token,
endpoint_rate_limiter.clone(),
)
.in_current_span()
.map_ok_or_else(api_error_into_response, |r| r),
@@ -373,53 +376,20 @@ async fn connection_handler(
};
match res {
Ok(()) => tracing::info!(%peer_addr, "HTTP connection closed"),
Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"),
}
}
Ok(()) => {
if let Some((session_id, host, websocket)) = ws_rx.recv().await {
tracing::info!(%peer_addr, "connection upgraded to websockets");
#[allow(clippy::too_many_arguments)]
async fn request_handler(
mut request: hyper::Request<Incoming>,
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
ws_connections: TaskTracker,
cancellation_handler: Arc<CancellationHandlerMain>,
session_id: uuid::Uuid,
peer_addr: IpAddr,
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
let host = request
.headers()
.get("host")
.and_then(|h| h.to_str().ok())
.and_then(|h| h.split(':').next())
.map(|s| s.to_string());
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
crate::metrics::Protocol::Ws,
&config.region,
);
// Check if the request is a websocket upgrade request.
if config.http_config.accept_websockets
&& framed_websockets::upgrade::is_upgrade_request(&request)
{
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
crate::metrics::Protocol::Ws,
&config.region,
);
let span = ctx.span();
info!(parent: &span, "performing websocket upgrade");
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
.map_err(|e| ApiError::BadRequest(e.into()))?;
ws_connections.spawn(
async move {
if let Err(e) = websocket::serve_websocket(
config,
backend.auth_backend,
auth_backend,
ctx,
websocket,
cancellation_handler,
@@ -430,9 +400,45 @@ async fn request_handler(
{
warn!("error in websocket connection: {e:#}");
}
} else {
tracing::info!(%peer_addr, "HTTP connection closed");
}
.instrument(span),
);
}
Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"),
}
}
#[allow(clippy::too_many_arguments)]
async fn request_handler(
mut request: hyper::Request<Incoming>,
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
ws_spawner: mpsc::Sender<(uuid::Uuid, Option<String>, OnUpgrade)>,
session_id: uuid::Uuid,
peer_addr: IpAddr,
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
// Check if the request is a websocket upgrade request.
if config.http_config.accept_websockets
&& framed_websockets::upgrade::is_upgrade_request(&request)
{
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
.map_err(|e| ApiError::BadRequest(e.into()))?;
let host = request
.headers()
.get("host")
.and_then(|h| h.to_str().ok())
.and_then(|h| h.split(':').next())
.map(|s| s.to_string());
ws_spawner
.send((session_id, host, websocket))
.await
.map_err(|_e| {
ApiError::InternalServerError(anyhow::anyhow!("could not upgrade WS connection"))
})?;
// Return the response so the spawned future can continue.
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))