mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-15 09:22:55 +00:00
re-use the same tokio task for the websocket handling
This commit is contained in:
@@ -29,6 +29,7 @@ use http::{Method, Response, StatusCode};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Empty};
|
||||
use hyper::body::Incoming;
|
||||
use hyper::upgrade::OnUpgrade;
|
||||
use hyper_util::rt::TokioExecutor;
|
||||
use hyper_util::server::conn::auto::Builder;
|
||||
use rand::rngs::StdRng;
|
||||
@@ -36,6 +37,7 @@ use rand::SeedableRng;
|
||||
use sql_over_http::{uuid_to_header_value, NEON_REQUEST_ID};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::timeout;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -306,6 +308,9 @@ async fn connection_handler(
|
||||
let http_cancellation_token = CancellationToken::new();
|
||||
let _cancel_connection = http_cancellation_token.clone().drop_guard();
|
||||
|
||||
let (ws_tx, mut ws_rx) = mpsc::channel(1);
|
||||
let auth_backend = backend.auth_backend;
|
||||
|
||||
let server = Builder::new(TokioExecutor::new());
|
||||
let conn = server.serve_connection_with_upgrades(
|
||||
hyper_util::rt::TokioIo::new(conn),
|
||||
@@ -337,12 +342,10 @@ async fn connection_handler(
|
||||
req,
|
||||
config,
|
||||
backend.clone(),
|
||||
connections.clone(),
|
||||
cancellation_handler.clone(),
|
||||
ws_tx.clone(),
|
||||
session_id,
|
||||
peer_addr,
|
||||
http_request_token,
|
||||
endpoint_rate_limiter.clone(),
|
||||
)
|
||||
.in_current_span()
|
||||
.map_ok_or_else(api_error_into_response, |r| r),
|
||||
@@ -373,53 +376,20 @@ async fn connection_handler(
|
||||
};
|
||||
|
||||
match res {
|
||||
Ok(()) => tracing::info!(%peer_addr, "HTTP connection closed"),
|
||||
Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"),
|
||||
}
|
||||
}
|
||||
Ok(()) => {
|
||||
if let Some((session_id, host, websocket)) = ws_rx.recv().await {
|
||||
tracing::info!(%peer_addr, "connection upgraded to websockets");
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn request_handler(
|
||||
mut request: hyper::Request<Incoming>,
|
||||
config: &'static ProxyConfig,
|
||||
backend: Arc<PoolingBackend>,
|
||||
ws_connections: TaskTracker,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
session_id: uuid::Uuid,
|
||||
peer_addr: IpAddr,
|
||||
// used to cancel in-flight HTTP requests. not used to cancel websockets
|
||||
http_cancellation_token: CancellationToken,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
|
||||
let host = request
|
||||
.headers()
|
||||
.get("host")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| h.split(':').next())
|
||||
.map(|s| s.to_string());
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
peer_addr,
|
||||
crate::metrics::Protocol::Ws,
|
||||
&config.region,
|
||||
);
|
||||
|
||||
// Check if the request is a websocket upgrade request.
|
||||
if config.http_config.accept_websockets
|
||||
&& framed_websockets::upgrade::is_upgrade_request(&request)
|
||||
{
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
peer_addr,
|
||||
crate::metrics::Protocol::Ws,
|
||||
&config.region,
|
||||
);
|
||||
|
||||
let span = ctx.span();
|
||||
info!(parent: &span, "performing websocket upgrade");
|
||||
|
||||
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
|
||||
.map_err(|e| ApiError::BadRequest(e.into()))?;
|
||||
|
||||
ws_connections.spawn(
|
||||
async move {
|
||||
if let Err(e) = websocket::serve_websocket(
|
||||
config,
|
||||
backend.auth_backend,
|
||||
auth_backend,
|
||||
ctx,
|
||||
websocket,
|
||||
cancellation_handler,
|
||||
@@ -430,9 +400,45 @@ async fn request_handler(
|
||||
{
|
||||
warn!("error in websocket connection: {e:#}");
|
||||
}
|
||||
} else {
|
||||
tracing::info!(%peer_addr, "HTTP connection closed");
|
||||
}
|
||||
.instrument(span),
|
||||
);
|
||||
}
|
||||
Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn request_handler(
|
||||
mut request: hyper::Request<Incoming>,
|
||||
config: &'static ProxyConfig,
|
||||
backend: Arc<PoolingBackend>,
|
||||
ws_spawner: mpsc::Sender<(uuid::Uuid, Option<String>, OnUpgrade)>,
|
||||
session_id: uuid::Uuid,
|
||||
peer_addr: IpAddr,
|
||||
// used to cancel in-flight HTTP requests. not used to cancel websockets
|
||||
http_cancellation_token: CancellationToken,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
|
||||
// Check if the request is a websocket upgrade request.
|
||||
if config.http_config.accept_websockets
|
||||
&& framed_websockets::upgrade::is_upgrade_request(&request)
|
||||
{
|
||||
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
|
||||
.map_err(|e| ApiError::BadRequest(e.into()))?;
|
||||
|
||||
let host = request
|
||||
.headers()
|
||||
.get("host")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| h.split(':').next())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
ws_spawner
|
||||
.send((session_id, host, websocket))
|
||||
.await
|
||||
.map_err(|_e| {
|
||||
ApiError::InternalServerError(anyhow::anyhow!("could not upgrade WS connection"))
|
||||
})?;
|
||||
|
||||
// Return the response so the spawned future can continue.
|
||||
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
|
||||
|
||||
Reference in New Issue
Block a user