mirror of
https://github.com/neondatabase/neon.git
synced 2026-02-08 05:00:38 +00:00
Compare commits
7 Commits
release
...
refactor-w
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
eb96abbff7 | ||
|
|
fa5d907032 | ||
|
|
c4525483ae | ||
|
|
ba714431be | ||
|
|
7c57234de1 | ||
|
|
2acc621c4c | ||
|
|
7de2914dde |
@@ -14,9 +14,11 @@ mod local_conn_pool;
|
|||||||
mod sql_over_http;
|
mod sql_over_http;
|
||||||
mod websocket;
|
mod websocket;
|
||||||
|
|
||||||
|
use std::future::Future;
|
||||||
use std::net::{IpAddr, SocketAddr};
|
use std::net::{IpAddr, SocketAddr};
|
||||||
use std::pin::{pin, Pin};
|
use std::pin::{pin, Pin};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::task::Poll;
|
||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
@@ -24,23 +26,26 @@ use atomic_take::AtomicTake;
|
|||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
pub use conn_pool_lib::GlobalConnPoolOptions;
|
pub use conn_pool_lib::GlobalConnPoolOptions;
|
||||||
use futures::future::{select, Either};
|
use futures::future::{select, Either};
|
||||||
use futures::TryFutureExt;
|
use futures::FutureExt;
|
||||||
use http::{Method, Response, StatusCode};
|
use http::{Method, Response, StatusCode};
|
||||||
use http_body_util::combinators::BoxBody;
|
use http_body_util::combinators::BoxBody;
|
||||||
use http_body_util::{BodyExt, Empty};
|
use http_body_util::{BodyExt, Empty};
|
||||||
use hyper::body::Incoming;
|
use hyper::body::Incoming;
|
||||||
use hyper_util::rt::TokioExecutor;
|
use hyper::upgrade::OnUpgrade;
|
||||||
use hyper_util::server::conn::auto::Builder;
|
use hyper_util::rt::{TokioExecutor, TokioIo};
|
||||||
use rand::rngs::StdRng;
|
use rand::rngs::StdRng;
|
||||||
use rand::SeedableRng;
|
use rand::SeedableRng;
|
||||||
|
use smallvec::SmallVec;
|
||||||
use sql_over_http::{uuid_to_header_value, NEON_REQUEST_ID};
|
use sql_over_http::{uuid_to_header_value, NEON_REQUEST_ID};
|
||||||
use tokio::io::{AsyncRead, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
use tokio::net::{TcpListener, TcpStream};
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
|
use tokio::sync::oneshot;
|
||||||
|
use tokio::task::JoinHandle;
|
||||||
use tokio::time::timeout;
|
use tokio::time::timeout;
|
||||||
use tokio_rustls::TlsAcceptor;
|
use tokio_rustls::TlsAcceptor;
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
use tokio_util::task::TaskTracker;
|
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||||
use tracing::{info, warn, Instrument};
|
use tracing::{debug, info, warn, Instrument};
|
||||||
use utils::http::error::ApiError;
|
use utils::http::error::ApiError;
|
||||||
|
|
||||||
use crate::cancellation::CancellationHandlerMain;
|
use crate::cancellation::CancellationHandlerMain;
|
||||||
@@ -154,16 +159,18 @@ pub async fn task_main(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let conn_token = cancellation_token.child_token();
|
let conn_cancellation_token = cancellation_token.child_token();
|
||||||
let tls_acceptor = tls_acceptor.clone();
|
let tls_acceptor = tls_acceptor.clone();
|
||||||
let backend = backend.clone();
|
let backend = backend.clone();
|
||||||
let connections2 = connections.clone();
|
|
||||||
let cancellation_handler = cancellation_handler.clone();
|
let cancellation_handler = cancellation_handler.clone();
|
||||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||||
connections.spawn(
|
let connection_token = connections.token();
|
||||||
|
tokio::spawn(
|
||||||
async move {
|
async move {
|
||||||
let conn_token2 = conn_token.clone();
|
let _cancel_guard = config
|
||||||
let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);
|
.http_config
|
||||||
|
.cancel_set
|
||||||
|
.insert(conn_id, conn_cancellation_token.clone());
|
||||||
|
|
||||||
let session_id = uuid::Uuid::new_v4();
|
let session_id = uuid::Uuid::new_v4();
|
||||||
|
|
||||||
@@ -172,30 +179,41 @@ pub async fn task_main(
|
|||||||
.client_connections
|
.client_connections
|
||||||
.guard(crate::metrics::Protocol::Http);
|
.guard(crate::metrics::Protocol::Http);
|
||||||
|
|
||||||
let startup_result = Box::pin(connection_startup(
|
let startup_result =
|
||||||
config,
|
connection_startup(config, tls_acceptor, session_id, conn, peer_addr)
|
||||||
tls_acceptor,
|
.boxed()
|
||||||
session_id,
|
.await;
|
||||||
conn,
|
let Some(conn) = startup_result else {
|
||||||
peer_addr,
|
|
||||||
))
|
|
||||||
.await;
|
|
||||||
let Some((conn, peer_addr)) = startup_result else {
|
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
|
||||||
Box::pin(connection_handler(
|
let ws_upgrade = http_connection_handler(
|
||||||
config,
|
config,
|
||||||
backend,
|
backend,
|
||||||
connections2,
|
conn_cancellation_token,
|
||||||
cancellation_handler,
|
connection_token.clone(),
|
||||||
endpoint_rate_limiter,
|
|
||||||
conn_token,
|
|
||||||
conn,
|
conn,
|
||||||
peer_addr,
|
|
||||||
session_id,
|
session_id,
|
||||||
))
|
)
|
||||||
|
.boxed()
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
if let Some((ctx, host, websocket)) = ws_upgrade {
|
||||||
|
let ws = websocket::serve_websocket(
|
||||||
|
config,
|
||||||
|
auth_backend,
|
||||||
|
ctx,
|
||||||
|
websocket,
|
||||||
|
cancellation_handler,
|
||||||
|
endpoint_rate_limiter,
|
||||||
|
host,
|
||||||
|
)
|
||||||
|
.boxed();
|
||||||
|
|
||||||
|
if let Err(e) = ws.await {
|
||||||
|
warn!("error in websocket connection: {e:#}");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
.instrument(http_conn_span),
|
.instrument(http_conn_span),
|
||||||
);
|
);
|
||||||
@@ -212,13 +230,19 @@ pub(crate) type AsyncRW = Pin<Box<dyn AsyncReadWrite>>;
|
|||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
trait MaybeTlsAcceptor: Send + Sync + 'static {
|
trait MaybeTlsAcceptor: Send + Sync + 'static {
|
||||||
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW>;
|
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<(AsyncRW, Alpn)>;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl MaybeTlsAcceptor for rustls::ServerConfig {
|
impl MaybeTlsAcceptor for rustls::ServerConfig {
|
||||||
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW> {
|
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<(AsyncRW, Alpn)> {
|
||||||
Ok(Box::pin(TlsAcceptor::from(self).accept(conn).await?))
|
let conn = TlsAcceptor::from(self).accept(conn).await?;
|
||||||
|
let alpn = conn
|
||||||
|
.get_ref()
|
||||||
|
.1
|
||||||
|
.alpn_protocol()
|
||||||
|
.map_or_else(SmallVec::new, SmallVec::from_slice);
|
||||||
|
Ok((Box::pin(conn), alpn))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -226,11 +250,14 @@ struct NoTls;
|
|||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl MaybeTlsAcceptor for NoTls {
|
impl MaybeTlsAcceptor for NoTls {
|
||||||
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW> {
|
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<(AsyncRW, Alpn)> {
|
||||||
Ok(Box::pin(conn))
|
Ok((Box::pin(conn), SmallVec::new()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Alpn = SmallVec<[u8; 8]>;
|
||||||
|
type ConnWithInfo = (AsyncRW, IpAddr, Alpn);
|
||||||
|
|
||||||
/// Handles the TCP startup lifecycle.
|
/// Handles the TCP startup lifecycle.
|
||||||
/// 1. Parses PROXY protocol V2
|
/// 1. Parses PROXY protocol V2
|
||||||
/// 2. Handles TLS handshake
|
/// 2. Handles TLS handshake
|
||||||
@@ -240,7 +267,7 @@ async fn connection_startup(
|
|||||||
session_id: uuid::Uuid,
|
session_id: uuid::Uuid,
|
||||||
conn: TcpStream,
|
conn: TcpStream,
|
||||||
peer_addr: SocketAddr,
|
peer_addr: SocketAddr,
|
||||||
) -> Option<(AsyncRW, IpAddr)> {
|
) -> Option<ConnWithInfo> {
|
||||||
// handle PROXY protocol
|
// handle PROXY protocol
|
||||||
let (conn, peer) = match read_proxy_protocol(conn).await {
|
let (conn, peer) = match read_proxy_protocol(conn).await {
|
||||||
Ok(c) => c,
|
Ok(c) => c,
|
||||||
@@ -258,7 +285,7 @@ async fn connection_startup(
|
|||||||
info!(?session_id, %peer_addr, "accepted new TCP connection");
|
info!(?session_id, %peer_addr, "accepted new TCP connection");
|
||||||
|
|
||||||
// try upgrade to TLS, but with a timeout.
|
// try upgrade to TLS, but with a timeout.
|
||||||
let conn = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await {
|
let (conn, alpn) = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await {
|
||||||
Ok(Ok(conn)) => {
|
Ok(Ok(conn)) => {
|
||||||
info!(?session_id, %peer_addr, "accepted new TLS connection");
|
info!(?session_id, %peer_addr, "accepted new TLS connection");
|
||||||
conn
|
conn
|
||||||
@@ -281,89 +308,99 @@ async fn connection_startup(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
Some((conn, peer_addr))
|
Some((conn, peer_addr, alpn))
|
||||||
|
}
|
||||||
|
|
||||||
|
trait GracefulShutdown: Future<Output = Result<(), hyper::Error>> + Send {
|
||||||
|
fn graceful_shutdown(self: Pin<&mut Self>);
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GracefulShutdown
|
||||||
|
for hyper::server::conn::http1::UpgradeableConnection<TokioIo<AsyncRW>, ProxyService>
|
||||||
|
{
|
||||||
|
fn graceful_shutdown(self: Pin<&mut Self>) {
|
||||||
|
self.graceful_shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GracefulShutdown for hyper::server::conn::http1::Connection<TokioIo<AsyncRW>, ProxyService> {
|
||||||
|
fn graceful_shutdown(self: Pin<&mut Self>) {
|
||||||
|
self.graceful_shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GracefulShutdown
|
||||||
|
for hyper::server::conn::http2::Connection<TokioIo<AsyncRW>, ProxyService, TokioExecutor>
|
||||||
|
{
|
||||||
|
fn graceful_shutdown(self: Pin<&mut Self>) {
|
||||||
|
self.graceful_shutdown();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handles HTTP connection
|
/// Handles HTTP connection
|
||||||
/// 1. With graceful shutdowns
|
/// 1. With graceful shutdowns
|
||||||
/// 2. With graceful request cancellation with connection failure
|
/// 2. With graceful request cancellation with connection failure
|
||||||
/// 3. With websocket upgrade support.
|
/// 3. With websocket upgrade support.
|
||||||
#[allow(clippy::too_many_arguments)]
|
async fn http_connection_handler(
|
||||||
async fn connection_handler(
|
|
||||||
config: &'static ProxyConfig,
|
config: &'static ProxyConfig,
|
||||||
backend: Arc<PoolingBackend>,
|
backend: Arc<PoolingBackend>,
|
||||||
connections: TaskTracker,
|
|
||||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
|
||||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
|
||||||
cancellation_token: CancellationToken,
|
cancellation_token: CancellationToken,
|
||||||
conn: AsyncRW,
|
connection_token: TaskTrackerToken,
|
||||||
peer_addr: IpAddr,
|
conn: ConnWithInfo,
|
||||||
session_id: uuid::Uuid,
|
session_id: uuid::Uuid,
|
||||||
) {
|
) -> Option<WsUpgrade> {
|
||||||
|
let (conn, peer_addr, alpn) = conn;
|
||||||
let session_id = AtomicTake::new(session_id);
|
let session_id = AtomicTake::new(session_id);
|
||||||
|
|
||||||
// Cancel all current inflight HTTP requests if the HTTP connection is closed.
|
// Cancel all current inflight HTTP requests if the HTTP connection is closed.
|
||||||
let http_cancellation_token = CancellationToken::new();
|
let http_cancellation_token = CancellationToken::new();
|
||||||
let _cancel_connection = http_cancellation_token.clone().drop_guard();
|
let _cancel_connection = http_cancellation_token.clone().drop_guard();
|
||||||
|
|
||||||
let server = Builder::new(TokioExecutor::new());
|
let (ws_tx, ws_rx) = oneshot::channel();
|
||||||
let conn = server.serve_connection_with_upgrades(
|
let ws_tx = Arc::new(AtomicTake::new(ws_tx));
|
||||||
hyper_util::rt::TokioIo::new(conn),
|
|
||||||
hyper::service::service_fn(move |req: hyper::Request<Incoming>| {
|
|
||||||
// First HTTP request shares the same session ID
|
|
||||||
let mut session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4);
|
|
||||||
|
|
||||||
if matches!(backend.auth_backend, crate::auth::Backend::Local(_)) {
|
let http2 = match &*alpn {
|
||||||
// take session_id from request, if given.
|
b"h2" => true,
|
||||||
if let Some(id) = req
|
b"http/1.1" => false,
|
||||||
.headers()
|
_ => {
|
||||||
.get(&NEON_REQUEST_ID)
|
debug!("no alpn negotiated");
|
||||||
.and_then(|id| uuid::Uuid::try_parse_ascii(id.as_bytes()).ok())
|
false
|
||||||
{
|
}
|
||||||
session_id = id;
|
};
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cancel the current inflight HTTP request if the requets stream is closed.
|
let service = ProxyService {
|
||||||
// This is slightly different to `_cancel_connection` in that
|
config,
|
||||||
// h2 can cancel individual requests with a `RST_STREAM`.
|
backend,
|
||||||
let http_request_token = http_cancellation_token.child_token();
|
connection_token,
|
||||||
let cancel_request = http_request_token.clone().drop_guard();
|
|
||||||
|
|
||||||
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
|
http_cancellation_token,
|
||||||
// By spawning the future, we ensure it never gets cancelled until it decides to.
|
ws_tx,
|
||||||
let handler = connections.spawn(
|
peer_addr,
|
||||||
request_handler(
|
session_id,
|
||||||
req,
|
};
|
||||||
config,
|
|
||||||
backend.clone(),
|
|
||||||
connections.clone(),
|
|
||||||
cancellation_handler.clone(),
|
|
||||||
session_id,
|
|
||||||
peer_addr,
|
|
||||||
http_request_token,
|
|
||||||
endpoint_rate_limiter.clone(),
|
|
||||||
)
|
|
||||||
.in_current_span()
|
|
||||||
.map_ok_or_else(api_error_into_response, |r| r),
|
|
||||||
);
|
|
||||||
async move {
|
|
||||||
let mut res = handler.await;
|
|
||||||
cancel_request.disarm();
|
|
||||||
|
|
||||||
// add the session ID to the response
|
let io = hyper_util::rt::TokioIo::new(conn);
|
||||||
if let Ok(resp) = &mut res {
|
let conn: Pin<Box<dyn GracefulShutdown>> = if http2 {
|
||||||
resp.headers_mut()
|
service.ws_tx.take();
|
||||||
.append(&NEON_REQUEST_ID, uuid_to_header_value(session_id));
|
|
||||||
}
|
|
||||||
|
|
||||||
res
|
Box::pin(
|
||||||
}
|
hyper::server::conn::http2::Builder::new(TokioExecutor::new())
|
||||||
}),
|
.serve_connection(io, service),
|
||||||
);
|
)
|
||||||
|
} else if config.http_config.accept_websockets {
|
||||||
|
Box::pin(
|
||||||
|
hyper::server::conn::http1::Builder::new()
|
||||||
|
.serve_connection(io, service)
|
||||||
|
.with_upgrades(),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
service.ws_tx.take();
|
||||||
|
|
||||||
|
Box::pin(hyper::server::conn::http1::Builder::new().serve_connection(io, service))
|
||||||
|
};
|
||||||
|
|
||||||
// On cancellation, trigger the HTTP connection handler to shut down.
|
// On cancellation, trigger the HTTP connection handler to shut down.
|
||||||
let res = match select(pin!(cancellation_token.cancelled()), pin!(conn)).await {
|
let res = match select(pin!(cancellation_token.cancelled()), conn).await {
|
||||||
Either::Left((_cancelled, mut conn)) => {
|
Either::Left((_cancelled, mut conn)) => {
|
||||||
tracing::debug!(%peer_addr, "cancelling connection");
|
tracing::debug!(%peer_addr, "cancelling connection");
|
||||||
conn.as_mut().graceful_shutdown();
|
conn.as_mut().graceful_shutdown();
|
||||||
@@ -373,35 +410,156 @@ async fn connection_handler(
|
|||||||
};
|
};
|
||||||
|
|
||||||
match res {
|
match res {
|
||||||
Ok(()) => tracing::info!(%peer_addr, "HTTP connection closed"),
|
Ok(()) => {
|
||||||
|
if let Ok(ws_upgrade) = ws_rx.await {
|
||||||
|
tracing::info!(%peer_addr, "connection upgraded to websockets");
|
||||||
|
|
||||||
|
return Some(ws_upgrade);
|
||||||
|
}
|
||||||
|
tracing::info!(%peer_addr, "HTTP connection closed");
|
||||||
|
}
|
||||||
Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"),
|
Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ProxyService {
|
||||||
|
// global state
|
||||||
|
config: &'static ProxyConfig,
|
||||||
|
backend: Arc<PoolingBackend>,
|
||||||
|
|
||||||
|
// connection state only
|
||||||
|
connection_token: TaskTrackerToken,
|
||||||
|
http_cancellation_token: CancellationToken,
|
||||||
|
ws_tx: WsSpawner,
|
||||||
|
peer_addr: IpAddr,
|
||||||
|
session_id: AtomicTake<uuid::Uuid>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl hyper::service::Service<hyper::Request<Incoming>> for ProxyService {
|
||||||
|
type Response = Response<BoxBody<Bytes, hyper::Error>>;
|
||||||
|
|
||||||
|
type Error = tokio::task::JoinError;
|
||||||
|
|
||||||
|
type Future = ReqFut;
|
||||||
|
|
||||||
|
fn call(&self, req: hyper::Request<Incoming>) -> Self::Future {
|
||||||
|
// First HTTP request shares the same session ID
|
||||||
|
let mut session_id = self.session_id.take().unwrap_or_else(uuid::Uuid::new_v4);
|
||||||
|
|
||||||
|
if matches!(self.backend.auth_backend, crate::auth::Backend::Local(_)) {
|
||||||
|
// take session_id from request, if given.
|
||||||
|
if let Some(id) = req
|
||||||
|
.headers()
|
||||||
|
.get(&NEON_REQUEST_ID)
|
||||||
|
.and_then(|id| uuid::Uuid::try_parse_ascii(id.as_bytes()).ok())
|
||||||
|
{
|
||||||
|
session_id = id;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cancel the current inflight HTTP request if the requets stream is closed.
|
||||||
|
// This is slightly different to `_cancel_connection` in that
|
||||||
|
// h2 can cancel individual requests with a `RST_STREAM`.
|
||||||
|
let http_req_cancellation_token = self.http_cancellation_token.child_token();
|
||||||
|
let cancel_request = Some(http_req_cancellation_token.clone().drop_guard());
|
||||||
|
|
||||||
|
let handle = request_handler(
|
||||||
|
req,
|
||||||
|
self.config,
|
||||||
|
self.backend.clone(),
|
||||||
|
self.ws_tx.clone(),
|
||||||
|
session_id,
|
||||||
|
self.peer_addr,
|
||||||
|
http_req_cancellation_token,
|
||||||
|
&self.connection_token,
|
||||||
|
);
|
||||||
|
|
||||||
|
ReqFut {
|
||||||
|
session_id,
|
||||||
|
cancel_request,
|
||||||
|
handle,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ReqFut {
|
||||||
|
session_id: uuid::Uuid,
|
||||||
|
cancel_request: Option<tokio_util::sync::DropGuard>,
|
||||||
|
handle: HandleOrResponse,
|
||||||
|
}
|
||||||
|
|
||||||
|
enum HandleOrResponse {
|
||||||
|
Handle(JoinHandle<Response<BoxBody<Bytes, hyper::Error>>>),
|
||||||
|
Response(Option<Response<BoxBody<Bytes, hyper::Error>>>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Future for ReqFut {
|
||||||
|
type Output = Result<Response<BoxBody<Bytes, hyper::Error>>, tokio::task::JoinError>;
|
||||||
|
|
||||||
|
fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
|
||||||
|
let mut res = match &mut self.handle {
|
||||||
|
HandleOrResponse::Handle(join_handle) => std::task::ready!(join_handle.poll_unpin(cx)),
|
||||||
|
HandleOrResponse::Response(response) => {
|
||||||
|
Ok(response.take().expect("polled after completion"))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
self.cancel_request
|
||||||
|
.take()
|
||||||
|
.map(tokio_util::sync::DropGuard::disarm);
|
||||||
|
|
||||||
|
// add the session ID to the response
|
||||||
|
if let Ok(resp) = &mut res {
|
||||||
|
resp.headers_mut()
|
||||||
|
.append(&NEON_REQUEST_ID, uuid_to_header_value(self.session_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
Poll::Ready(res)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type WsUpgrade = (RequestMonitoring, Option<String>, OnUpgrade);
|
||||||
|
type WsSpawner = Arc<AtomicTake<oneshot::Sender<WsUpgrade>>>;
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
async fn request_handler(
|
fn request_handler(
|
||||||
mut request: hyper::Request<Incoming>,
|
mut request: hyper::Request<Incoming>,
|
||||||
config: &'static ProxyConfig,
|
config: &'static ProxyConfig,
|
||||||
backend: Arc<PoolingBackend>,
|
backend: Arc<PoolingBackend>,
|
||||||
ws_connections: TaskTracker,
|
ws_spawner: WsSpawner,
|
||||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
|
||||||
session_id: uuid::Uuid,
|
session_id: uuid::Uuid,
|
||||||
peer_addr: IpAddr,
|
peer_addr: IpAddr,
|
||||||
// used to cancel in-flight HTTP requests. not used to cancel websockets
|
// used to cancel in-flight HTTP requests. not used to cancel websockets
|
||||||
http_cancellation_token: CancellationToken,
|
http_cancellation_token: CancellationToken,
|
||||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
connection_token: &TaskTrackerToken,
|
||||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
|
) -> HandleOrResponse {
|
||||||
let host = request
|
|
||||||
.headers()
|
|
||||||
.get("host")
|
|
||||||
.and_then(|h| h.to_str().ok())
|
|
||||||
.and_then(|h| h.split(':').next())
|
|
||||||
.map(|s| s.to_string());
|
|
||||||
|
|
||||||
// Check if the request is a websocket upgrade request.
|
// Check if the request is a websocket upgrade request.
|
||||||
if config.http_config.accept_websockets
|
if framed_websockets::upgrade::is_upgrade_request(&request) {
|
||||||
&& framed_websockets::upgrade::is_upgrade_request(&request)
|
let Some(spawner) = ws_spawner.take() else {
|
||||||
{
|
return HandleOrResponse::Response(Some(
|
||||||
|
json_response(StatusCode::BAD_REQUEST, "query is not supported")
|
||||||
|
.unwrap_or_else(api_error_into_response),
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
let (response, websocket) = match framed_websockets::upgrade::upgrade(&mut request) {
|
||||||
|
Err(e) => {
|
||||||
|
return HandleOrResponse::Response(Some(api_error_into_response(
|
||||||
|
ApiError::BadRequest(e.into()),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
Ok(upgrade) => upgrade,
|
||||||
|
};
|
||||||
|
|
||||||
|
let host = request
|
||||||
|
.headers()
|
||||||
|
.get("host")
|
||||||
|
.and_then(|h| h.to_str().ok())
|
||||||
|
.and_then(|h| h.split(':').next())
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
|
||||||
let ctx = RequestMonitoring::new(
|
let ctx = RequestMonitoring::new(
|
||||||
session_id,
|
session_id,
|
||||||
peer_addr,
|
peer_addr,
|
||||||
@@ -409,33 +567,14 @@ async fn request_handler(
|
|||||||
&config.region,
|
&config.region,
|
||||||
);
|
);
|
||||||
|
|
||||||
let span = ctx.span();
|
if let Err(_e) = spawner.send((ctx, host, websocket)) {
|
||||||
info!(parent: &span, "performing websocket upgrade");
|
return HandleOrResponse::Response(Some(api_error_into_response(
|
||||||
|
ApiError::InternalServerError(anyhow::anyhow!("could not upgrade WS connection")),
|
||||||
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
|
)));
|
||||||
.map_err(|e| ApiError::BadRequest(e.into()))?;
|
}
|
||||||
|
|
||||||
ws_connections.spawn(
|
|
||||||
async move {
|
|
||||||
if let Err(e) = websocket::serve_websocket(
|
|
||||||
config,
|
|
||||||
backend.auth_backend,
|
|
||||||
ctx,
|
|
||||||
websocket,
|
|
||||||
cancellation_handler,
|
|
||||||
endpoint_rate_limiter,
|
|
||||||
host,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
warn!("error in websocket connection: {e:#}");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
.instrument(span),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Return the response so the spawned future can continue.
|
// Return the response so the spawned future can continue.
|
||||||
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
|
HandleOrResponse::Response(Some(response.map(|b| b.map_err(|x| match x {}).boxed())))
|
||||||
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
|
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
|
||||||
let ctx = RequestMonitoring::new(
|
let ctx = RequestMonitoring::new(
|
||||||
session_id,
|
session_id,
|
||||||
@@ -445,11 +584,19 @@ async fn request_handler(
|
|||||||
);
|
);
|
||||||
let span = ctx.span();
|
let span = ctx.span();
|
||||||
|
|
||||||
sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
|
let token = connection_token.clone();
|
||||||
.instrument(span)
|
|
||||||
.await
|
// `sql_over_http::handle` is not cancel safe. It expects to be cancelled only at specific times.
|
||||||
|
// By spawning the future, we ensure it never gets cancelled until it decides to.
|
||||||
|
HandleOrResponse::Handle(tokio::spawn(async move {
|
||||||
|
let _token = token;
|
||||||
|
sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
|
||||||
|
.instrument(span)
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(api_error_into_response)
|
||||||
|
}))
|
||||||
} else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
|
} else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
|
||||||
Response::builder()
|
HandleOrResponse::Response(Some( Response::builder()
|
||||||
.header("Allow", "OPTIONS, POST")
|
.header("Allow", "OPTIONS, POST")
|
||||||
.header("Access-Control-Allow-Origin", "*")
|
.header("Access-Control-Allow-Origin", "*")
|
||||||
.header(
|
.header(
|
||||||
@@ -459,8 +606,11 @@ async fn request_handler(
|
|||||||
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
|
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
|
||||||
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
|
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
|
||||||
.body(Empty::new().map_err(|x| match x {}).boxed())
|
.body(Empty::new().map_err(|x| match x {}).boxed())
|
||||||
.map_err(|e| ApiError::InternalServerError(e.into()))
|
.map_err(|e| ApiError::InternalServerError(e.into())).unwrap_or_else(api_error_into_response)))
|
||||||
} else {
|
} else {
|
||||||
json_response(StatusCode::BAD_REQUEST, "query is not supported")
|
HandleOrResponse::Response(Some(
|
||||||
|
json_response(StatusCode::BAD_REQUEST, "query is not supported")
|
||||||
|
.unwrap_or_else(api_error_into_response),
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user