mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-15 09:22:55 +00:00
oneshot
This commit is contained in:
@@ -38,7 +38,7 @@ use smallvec::SmallVec;
|
||||
use sql_over_http::{uuid_to_header_value, NEON_REQUEST_ID};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::timeout;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -304,9 +304,9 @@ impl<B, S> GracefulShutdown
|
||||
for hyper::server::conn::http1::UpgradeableConnection<TokioIo<AsyncRW>, S>
|
||||
where
|
||||
S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B> + Send,
|
||||
S::Future: Send + 'static,
|
||||
S::Future: Send,
|
||||
B: hyper::body::Body + Send + 'static,
|
||||
B::Data: Send + 'static,
|
||||
B::Data: Send,
|
||||
|
||||
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||
@@ -319,9 +319,9 @@ where
|
||||
impl<B, S> GracefulShutdown for hyper::server::conn::http1::Connection<TokioIo<AsyncRW>, S>
|
||||
where
|
||||
S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B> + Send,
|
||||
S::Future: Send + 'static,
|
||||
S::Future: Send,
|
||||
B: hyper::body::Body + Send + 'static,
|
||||
B::Data: Send + 'static,
|
||||
B::Data: Send,
|
||||
|
||||
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||
@@ -337,7 +337,7 @@ where
|
||||
S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B> + Send,
|
||||
S::Future: Send + 'static,
|
||||
B: hyper::body::Body + Send + 'static,
|
||||
B::Data: Send + 'static,
|
||||
B::Data: Send,
|
||||
|
||||
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||
@@ -370,7 +370,8 @@ async fn connection_handler(
|
||||
let http_cancellation_token = CancellationToken::new();
|
||||
let _cancel_connection = http_cancellation_token.clone().drop_guard();
|
||||
|
||||
let (ws_tx, mut ws_rx) = mpsc::channel(1);
|
||||
let (ws_tx, ws_rx) = oneshot::channel();
|
||||
let ws_tx = Arc::new(AtomicTake::new(ws_tx));
|
||||
let auth_backend = backend.auth_backend;
|
||||
|
||||
let http2 = match &*alpn {
|
||||
@@ -382,6 +383,11 @@ async fn connection_handler(
|
||||
}
|
||||
};
|
||||
|
||||
if http2 || !config.http_config.accept_websockets {
|
||||
// discard the ws spawner
|
||||
ws_tx.take();
|
||||
}
|
||||
|
||||
let service = hyper::service::service_fn(move |req: hyper::Request<Incoming>| {
|
||||
// First HTTP request shares the same session ID
|
||||
let mut session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4);
|
||||
@@ -460,7 +466,7 @@ async fn connection_handler(
|
||||
|
||||
match res {
|
||||
Ok(()) => {
|
||||
if let Some((session_id, host, websocket)) = ws_rx.recv().await {
|
||||
if let Ok((session_id, host, websocket)) = ws_rx.await {
|
||||
tracing::info!(%peer_addr, "connection upgraded to websockets");
|
||||
|
||||
let ctx = RequestMonitoring::new(
|
||||
@@ -491,21 +497,26 @@ async fn connection_handler(
|
||||
}
|
||||
}
|
||||
|
||||
type WsUpgrade = (uuid::Uuid, Option<String>, OnUpgrade);
|
||||
type WsSpawner = Arc<AtomicTake<oneshot::Sender<WsUpgrade>>>;
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn request_handler(
|
||||
mut request: hyper::Request<Incoming>,
|
||||
config: &'static ProxyConfig,
|
||||
backend: Arc<PoolingBackend>,
|
||||
ws_spawner: mpsc::Sender<(uuid::Uuid, Option<String>, OnUpgrade)>,
|
||||
ws_spawner: WsSpawner,
|
||||
session_id: uuid::Uuid,
|
||||
peer_addr: IpAddr,
|
||||
// used to cancel in-flight HTTP requests. not used to cancel websockets
|
||||
http_cancellation_token: CancellationToken,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
|
||||
// Check if the request is a websocket upgrade request.
|
||||
if config.http_config.accept_websockets
|
||||
&& framed_websockets::upgrade::is_upgrade_request(&request)
|
||||
{
|
||||
if framed_websockets::upgrade::is_upgrade_request(&request) {
|
||||
let Some(spawner) = ws_spawner.take() else {
|
||||
return json_response(StatusCode::BAD_REQUEST, "query is not supported");
|
||||
};
|
||||
|
||||
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
|
||||
.map_err(|e| ApiError::BadRequest(e.into()))?;
|
||||
|
||||
@@ -516,12 +527,9 @@ async fn request_handler(
|
||||
.and_then(|h| h.split(':').next())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
ws_spawner
|
||||
.send((session_id, host, websocket))
|
||||
.await
|
||||
.map_err(|_e| {
|
||||
ApiError::InternalServerError(anyhow::anyhow!("could not upgrade WS connection"))
|
||||
})?;
|
||||
spawner.send((session_id, host, websocket)).map_err(|_e| {
|
||||
ApiError::InternalServerError(anyhow::anyhow!("could not upgrade WS connection"))
|
||||
})?;
|
||||
|
||||
// Return the response so the spawned future can continue.
|
||||
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
|
||||
|
||||
Reference in New Issue
Block a user