Compare commits

...

7 Commits

Author SHA1 Message Date
Conrad Ludgate
eb96abbff7 moar 2024-10-30 12:18:22 +00:00
Conrad Ludgate
fa5d907032 move up ws handling even more 2024-10-30 11:51:47 +00:00
Conrad Ludgate
c4525483ae keep going now 2024-10-30 11:39:33 +00:00
Conrad Ludgate
ba714431be oneshot 2024-10-30 11:13:56 +00:00
Conrad Ludgate
7c57234de1 more refactoring 2024-10-30 10:19:52 +00:00
Conrad Ludgate
2acc621c4c random changes 2024-10-30 09:55:46 +00:00
Conrad Ludgate
7de2914dde re-use the same tokio task for the websocket handling 2024-10-30 09:45:19 +00:00

View File

@@ -14,9 +14,11 @@ mod local_conn_pool;
mod sql_over_http; mod sql_over_http;
mod websocket; mod websocket;
use std::future::Future;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::pin::{pin, Pin}; use std::pin::{pin, Pin};
use std::sync::Arc; use std::sync::Arc;
use std::task::Poll;
use anyhow::Context; use anyhow::Context;
use async_trait::async_trait; use async_trait::async_trait;
@@ -24,23 +26,26 @@ use atomic_take::AtomicTake;
use bytes::Bytes; use bytes::Bytes;
pub use conn_pool_lib::GlobalConnPoolOptions; pub use conn_pool_lib::GlobalConnPoolOptions;
use futures::future::{select, Either}; use futures::future::{select, Either};
use futures::TryFutureExt; use futures::FutureExt;
use http::{Method, Response, StatusCode}; use http::{Method, Response, StatusCode};
use http_body_util::combinators::BoxBody; use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty}; use http_body_util::{BodyExt, Empty};
use hyper::body::Incoming; use hyper::body::Incoming;
use hyper_util::rt::TokioExecutor; use hyper::upgrade::OnUpgrade;
use hyper_util::server::conn::auto::Builder; use hyper_util::rt::{TokioExecutor, TokioIo};
use rand::rngs::StdRng; use rand::rngs::StdRng;
use rand::SeedableRng; use rand::SeedableRng;
use smallvec::SmallVec;
use sql_over_http::{uuid_to_header_value, NEON_REQUEST_ID}; use sql_over_http::{uuid_to_header_value, NEON_REQUEST_ID};
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tokio::time::timeout; use tokio::time::timeout;
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker; use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{info, warn, Instrument}; use tracing::{debug, info, warn, Instrument};
use utils::http::error::ApiError; use utils::http::error::ApiError;
use crate::cancellation::CancellationHandlerMain; use crate::cancellation::CancellationHandlerMain;
@@ -154,16 +159,18 @@ pub async fn task_main(
} }
} }
let conn_token = cancellation_token.child_token(); let conn_cancellation_token = cancellation_token.child_token();
let tls_acceptor = tls_acceptor.clone(); let tls_acceptor = tls_acceptor.clone();
let backend = backend.clone(); let backend = backend.clone();
let connections2 = connections.clone();
let cancellation_handler = cancellation_handler.clone(); let cancellation_handler = cancellation_handler.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone();
connections.spawn( let connection_token = connections.token();
tokio::spawn(
async move { async move {
let conn_token2 = conn_token.clone(); let _cancel_guard = config
let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2); .http_config
.cancel_set
.insert(conn_id, conn_cancellation_token.clone());
let session_id = uuid::Uuid::new_v4(); let session_id = uuid::Uuid::new_v4();
@@ -172,30 +179,41 @@ pub async fn task_main(
.client_connections .client_connections
.guard(crate::metrics::Protocol::Http); .guard(crate::metrics::Protocol::Http);
let startup_result = Box::pin(connection_startup( let startup_result =
config, connection_startup(config, tls_acceptor, session_id, conn, peer_addr)
tls_acceptor, .boxed()
session_id, .await;
conn, let Some(conn) = startup_result else {
peer_addr,
))
.await;
let Some((conn, peer_addr)) = startup_result else {
return; return;
}; };
Box::pin(connection_handler( let ws_upgrade = http_connection_handler(
config, config,
backend, backend,
connections2, conn_cancellation_token,
cancellation_handler, connection_token.clone(),
endpoint_rate_limiter,
conn_token,
conn, conn,
peer_addr,
session_id, session_id,
)) )
.boxed()
.await; .await;
if let Some((ctx, host, websocket)) = ws_upgrade {
let ws = websocket::serve_websocket(
config,
auth_backend,
ctx,
websocket,
cancellation_handler,
endpoint_rate_limiter,
host,
)
.boxed();
if let Err(e) = ws.await {
warn!("error in websocket connection: {e:#}");
}
}
} }
.instrument(http_conn_span), .instrument(http_conn_span),
); );
@@ -212,13 +230,19 @@ pub(crate) type AsyncRW = Pin<Box<dyn AsyncReadWrite>>;
#[async_trait] #[async_trait]
trait MaybeTlsAcceptor: Send + Sync + 'static { trait MaybeTlsAcceptor: Send + Sync + 'static {
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW>; async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<(AsyncRW, Alpn)>;
} }
#[async_trait] #[async_trait]
impl MaybeTlsAcceptor for rustls::ServerConfig { impl MaybeTlsAcceptor for rustls::ServerConfig {
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW> { async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<(AsyncRW, Alpn)> {
Ok(Box::pin(TlsAcceptor::from(self).accept(conn).await?)) let conn = TlsAcceptor::from(self).accept(conn).await?;
let alpn = conn
.get_ref()
.1
.alpn_protocol()
.map_or_else(SmallVec::new, SmallVec::from_slice);
Ok((Box::pin(conn), alpn))
} }
} }
@@ -226,11 +250,14 @@ struct NoTls;
#[async_trait] #[async_trait]
impl MaybeTlsAcceptor for NoTls { impl MaybeTlsAcceptor for NoTls {
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW> { async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<(AsyncRW, Alpn)> {
Ok(Box::pin(conn)) Ok((Box::pin(conn), SmallVec::new()))
} }
} }
/// ALPN protocol identifier reported by a `MaybeTlsAcceptor`; empty when the
/// connection is plaintext or when TLS negotiated no protocol.
/// The `SmallVec` has an inline capacity of 8 bytes.
type Alpn = SmallVec<[u8; 8]>;
/// What `connection_startup` yields on success: the stream, the peer IP address
/// and the negotiated ALPN value.
type ConnWithInfo = (AsyncRW, IpAddr, Alpn);
/// Handles the TCP startup lifecycle. /// Handles the TCP startup lifecycle.
/// 1. Parses PROXY protocol V2 /// 1. Parses PROXY protocol V2
/// 2. Handles TLS handshake /// 2. Handles TLS handshake
@@ -240,7 +267,7 @@ async fn connection_startup(
session_id: uuid::Uuid, session_id: uuid::Uuid,
conn: TcpStream, conn: TcpStream,
peer_addr: SocketAddr, peer_addr: SocketAddr,
) -> Option<(AsyncRW, IpAddr)> { ) -> Option<ConnWithInfo> {
// handle PROXY protocol // handle PROXY protocol
let (conn, peer) = match read_proxy_protocol(conn).await { let (conn, peer) = match read_proxy_protocol(conn).await {
Ok(c) => c, Ok(c) => c,
@@ -258,7 +285,7 @@ async fn connection_startup(
info!(?session_id, %peer_addr, "accepted new TCP connection"); info!(?session_id, %peer_addr, "accepted new TCP connection");
// try upgrade to TLS, but with a timeout. // try upgrade to TLS, but with a timeout.
let conn = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await { let (conn, alpn) = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await {
Ok(Ok(conn)) => { Ok(Ok(conn)) => {
info!(?session_id, %peer_addr, "accepted new TLS connection"); info!(?session_id, %peer_addr, "accepted new TLS connection");
conn conn
@@ -281,89 +308,99 @@ async fn connection_startup(
} }
}; };
Some((conn, peer_addr)) Some((conn, peer_addr, alpn))
}
/// Common interface over hyper's server connection futures, so that the
/// HTTP/1.1 and HTTP/2 variants can be stored as one `Pin<Box<dyn GracefulShutdown>>`
/// and both awaited and asked to shut down.
trait GracefulShutdown: Future<Output = Result<(), hyper::Error>> + Send {
/// Asks the connection to start shutting down; the connection future must
/// still be awaited afterwards to drive it to completion.
fn graceful_shutdown(self: Pin<&mut Self>);
}
/// Forwards to hyper for HTTP/1.1 connections that were built with upgrade support.
impl GracefulShutdown
for hyper::server::conn::http1::UpgradeableConnection<TokioIo<AsyncRW>, ProxyService>
{
fn graceful_shutdown(self: Pin<&mut Self>) {
// NOTE(review): this call is expected to resolve to hyper's inherent
// `graceful_shutdown` (inherent methods take precedence over trait methods).
// If it ever resolved to this trait method instead, it would recurse forever.
self.graceful_shutdown();
}
}
/// Forwards to hyper for plain HTTP/1.1 connections (no upgrade support).
impl GracefulShutdown for hyper::server::conn::http1::Connection<TokioIo<AsyncRW>, ProxyService> {
fn graceful_shutdown(self: Pin<&mut Self>) {
// NOTE(review): this call is expected to resolve to hyper's inherent
// `graceful_shutdown` (inherent methods take precedence over trait methods).
// If it ever resolved to this trait method instead, it would recurse forever.
self.graceful_shutdown();
}
}
/// Forwards to hyper for HTTP/2 connections.
impl GracefulShutdown
for hyper::server::conn::http2::Connection<TokioIo<AsyncRW>, ProxyService, TokioExecutor>
{
fn graceful_shutdown(self: Pin<&mut Self>) {
// NOTE(review): this call is expected to resolve to hyper's inherent
// `graceful_shutdown` (inherent methods take precedence over trait methods).
// If it ever resolved to this trait method instead, it would recurse forever.
self.graceful_shutdown();
}
} }
/// Handles HTTP connection /// Handles HTTP connection
/// 1. With graceful shutdowns /// 1. With graceful shutdowns
/// 2. With graceful request cancellation with connection failure /// 2. With graceful request cancellation with connection failure
/// 3. With websocket upgrade support. /// 3. With websocket upgrade support.
#[allow(clippy::too_many_arguments)] async fn http_connection_handler(
async fn connection_handler(
config: &'static ProxyConfig, config: &'static ProxyConfig,
backend: Arc<PoolingBackend>, backend: Arc<PoolingBackend>,
connections: TaskTracker,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
conn: AsyncRW, connection_token: TaskTrackerToken,
peer_addr: IpAddr, conn: ConnWithInfo,
session_id: uuid::Uuid, session_id: uuid::Uuid,
) { ) -> Option<WsUpgrade> {
let (conn, peer_addr, alpn) = conn;
let session_id = AtomicTake::new(session_id); let session_id = AtomicTake::new(session_id);
// Cancel all current inflight HTTP requests if the HTTP connection is closed. // Cancel all current inflight HTTP requests if the HTTP connection is closed.
let http_cancellation_token = CancellationToken::new(); let http_cancellation_token = CancellationToken::new();
let _cancel_connection = http_cancellation_token.clone().drop_guard(); let _cancel_connection = http_cancellation_token.clone().drop_guard();
let server = Builder::new(TokioExecutor::new()); let (ws_tx, ws_rx) = oneshot::channel();
let conn = server.serve_connection_with_upgrades( let ws_tx = Arc::new(AtomicTake::new(ws_tx));
hyper_util::rt::TokioIo::new(conn),
hyper::service::service_fn(move |req: hyper::Request<Incoming>| {
// First HTTP request shares the same session ID
let mut session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4);
if matches!(backend.auth_backend, crate::auth::Backend::Local(_)) { let http2 = match &*alpn {
// take session_id from request, if given. b"h2" => true,
if let Some(id) = req b"http/1.1" => false,
.headers() _ => {
.get(&NEON_REQUEST_ID) debug!("no alpn negotiated");
.and_then(|id| uuid::Uuid::try_parse_ascii(id.as_bytes()).ok()) false
{ }
session_id = id; };
}
}
// Cancel the current inflight HTTP request if the requets stream is closed. let service = ProxyService {
// This is slightly different to `_cancel_connection` in that config,
// h2 can cancel individual requests with a `RST_STREAM`. backend,
let http_request_token = http_cancellation_token.child_token(); connection_token,
let cancel_request = http_request_token.clone().drop_guard();
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times. http_cancellation_token,
// By spawning the future, we ensure it never gets cancelled until it decides to. ws_tx,
let handler = connections.spawn( peer_addr,
request_handler( session_id,
req, };
config,
backend.clone(),
connections.clone(),
cancellation_handler.clone(),
session_id,
peer_addr,
http_request_token,
endpoint_rate_limiter.clone(),
)
.in_current_span()
.map_ok_or_else(api_error_into_response, |r| r),
);
async move {
let mut res = handler.await;
cancel_request.disarm();
// add the session ID to the response let io = hyper_util::rt::TokioIo::new(conn);
if let Ok(resp) = &mut res { let conn: Pin<Box<dyn GracefulShutdown>> = if http2 {
resp.headers_mut() service.ws_tx.take();
.append(&NEON_REQUEST_ID, uuid_to_header_value(session_id));
}
res Box::pin(
} hyper::server::conn::http2::Builder::new(TokioExecutor::new())
}), .serve_connection(io, service),
); )
} else if config.http_config.accept_websockets {
Box::pin(
hyper::server::conn::http1::Builder::new()
.serve_connection(io, service)
.with_upgrades(),
)
} else {
service.ws_tx.take();
Box::pin(hyper::server::conn::http1::Builder::new().serve_connection(io, service))
};
// On cancellation, trigger the HTTP connection handler to shut down. // On cancellation, trigger the HTTP connection handler to shut down.
let res = match select(pin!(cancellation_token.cancelled()), pin!(conn)).await { let res = match select(pin!(cancellation_token.cancelled()), conn).await {
Either::Left((_cancelled, mut conn)) => { Either::Left((_cancelled, mut conn)) => {
tracing::debug!(%peer_addr, "cancelling connection"); tracing::debug!(%peer_addr, "cancelling connection");
conn.as_mut().graceful_shutdown(); conn.as_mut().graceful_shutdown();
@@ -373,35 +410,156 @@ async fn connection_handler(
}; };
match res { match res {
Ok(()) => tracing::info!(%peer_addr, "HTTP connection closed"), Ok(()) => {
if let Ok(ws_upgrade) = ws_rx.await {
tracing::info!(%peer_addr, "connection upgraded to websockets");
return Some(ws_upgrade);
}
tracing::info!(%peer_addr, "HTTP connection closed");
}
Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"), Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"),
} }
None
}
/// hyper `Service` for a single client connection. Each incoming request is
/// dispatched through `request_handler` using the state held here.
struct ProxyService {
// global state
/// Proxy-wide configuration, forwarded to `request_handler`.
config: &'static ProxyConfig,
/// Shared pooling backend; cloned for every request.
backend: Arc<PoolingBackend>,
// connection state only
/// Task-tracker token for this connection; lent to `request_handler` on every request.
connection_token: TaskTrackerToken,
/// Parent token from which a child cancellation token is derived per request.
http_cancellation_token: CancellationToken,
/// Shared, take-once sender for handing a websocket upgrade back to the
/// connection task; a clone of the `Arc` is passed to every request.
ws_tx: WsSpawner,
/// Peer IP address, forwarded to `request_handler`.
peer_addr: IpAddr,
/// Session ID of the connection. The first request takes it; later
/// requests find it empty and generate a fresh UUID instead.
session_id: AtomicTake<uuid::Uuid>,
}
/// Turns each incoming HTTP request into a `ReqFut`.
///
/// The error type is `JoinError` because a request may be served by a spawned
/// task whose `JoinHandle` is polled by `ReqFut`.
impl hyper::service::Service<hyper::Request<Incoming>> for ProxyService {
type Response = Response<BoxBody<Bytes, hyper::Error>>;
type Error = tokio::task::JoinError;
type Future = ReqFut;
/// Picks the session ID for the request, derives a per-request cancellation
/// token, and hands the request to `request_handler`.
fn call(&self, req: hyper::Request<Incoming>) -> Self::Future {
// First HTTP request shares the same session ID
let mut session_id = self.session_id.take().unwrap_or_else(uuid::Uuid::new_v4);
if matches!(self.backend.auth_backend, crate::auth::Backend::Local(_)) {
// take session_id from request, if given.
// A header that is missing or does not parse as a UUID is ignored.
if let Some(id) = req
.headers()
.get(&NEON_REQUEST_ID)
.and_then(|id| uuid::Uuid::try_parse_ascii(id.as_bytes()).ok())
{
session_id = id;
}
}
// Cancel the current inflight HTTP request if the request stream is closed.
// This is slightly different to `_cancel_connection` in that
// h2 can cancel individual requests with a `RST_STREAM`.
let http_req_cancellation_token = self.http_cancellation_token.child_token();
// The guard cancels the per-request token if the returned future is dropped
// before it completes; `ReqFut::poll` disarms it once a result is available.
let cancel_request = Some(http_req_cancellation_token.clone().drop_guard());
let handle = request_handler(
req,
self.config,
self.backend.clone(),
self.ws_tx.clone(),
session_id,
self.peer_addr,
http_req_cancellation_token,
&self.connection_token,
);
ReqFut {
session_id,
cancel_request,
handle,
}
}
} }
/// Future returned by `ProxyService::call` for a single HTTP request.
struct ReqFut {
/// Session ID appended to a successful response as the `NEON_REQUEST_ID` header.
session_id: uuid::Uuid,
/// Drop guard for the per-request cancellation token; taken and disarmed
/// in `poll` once a result is available.
cancel_request: Option<tokio_util::sync::DropGuard>,
/// Source of the response: a spawned task or an already-built value.
handle: HandleOrResponse,
}
/// Where the response for a request comes from.
enum HandleOrResponse {
/// The response is produced by a spawned tokio task that must be joined.
Handle(JoinHandle<Response<BoxBody<Bytes, hyper::Error>>>),
/// The response is already available. It is `Some` until `ReqFut::poll`
/// takes it; polling again afterwards panics.
Response(Option<Response<BoxBody<Bytes, hyper::Error>>>),
}
/// Resolves to the response for one request, either by joining the spawned
/// handler task or by yielding the response that was built up front.
impl Future for ReqFut {
    type Output = Result<Response<BoxBody<Bytes, hyper::Error>>, tokio::task::JoinError>;

    /// Polls the underlying source. Once a result is available, the
    /// per-request cancel guard is disarmed and, on success, the session ID
    /// header is appended to the response.
    ///
    /// # Panics
    /// Panics if polled again after an immediate response has been yielded.
    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let session_id = self.session_id;

        let mut outcome = match &mut self.handle {
            HandleOrResponse::Response(ready) => Ok(ready.take().expect("polled after completion")),
            HandleOrResponse::Handle(task) => match task.poll_unpin(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(joined) => joined,
            },
        };

        // A result exists (success or join error), so dropping this future
        // from here on must no longer cancel the request's token.
        if let Some(guard) = self.cancel_request.take() {
            guard.disarm();
        }

        // add the session ID to the response
        if let Ok(response) = outcome.as_mut() {
            response
                .headers_mut()
                .append(&NEON_REQUEST_ID, uuid_to_header_value(session_id));
        }

        Poll::Ready(outcome)
    }
}
/// Everything needed to serve a websocket after the HTTP upgrade:
/// the request context, the host taken from the `host` header (if any),
/// and hyper's pending upgrade handle.
type WsUpgrade = (RequestMonitoring, Option<String>, OnUpgrade);
/// Shared, take-once sender used by a request handler to pass a `WsUpgrade`
/// back to the connection task. Once taken, later upgrade attempts find it empty.
type WsSpawner = Arc<AtomicTake<oneshot::Sender<WsUpgrade>>>;
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
async fn request_handler( fn request_handler(
mut request: hyper::Request<Incoming>, mut request: hyper::Request<Incoming>,
config: &'static ProxyConfig, config: &'static ProxyConfig,
backend: Arc<PoolingBackend>, backend: Arc<PoolingBackend>,
ws_connections: TaskTracker, ws_spawner: WsSpawner,
cancellation_handler: Arc<CancellationHandlerMain>,
session_id: uuid::Uuid, session_id: uuid::Uuid,
peer_addr: IpAddr, peer_addr: IpAddr,
// used to cancel in-flight HTTP requests. not used to cancel websockets // used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken, http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>, connection_token: &TaskTrackerToken,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> { ) -> HandleOrResponse {
let host = request
.headers()
.get("host")
.and_then(|h| h.to_str().ok())
.and_then(|h| h.split(':').next())
.map(|s| s.to_string());
// Check if the request is a websocket upgrade request. // Check if the request is a websocket upgrade request.
if config.http_config.accept_websockets if framed_websockets::upgrade::is_upgrade_request(&request) {
&& framed_websockets::upgrade::is_upgrade_request(&request) let Some(spawner) = ws_spawner.take() else {
{ return HandleOrResponse::Response(Some(
json_response(StatusCode::BAD_REQUEST, "query is not supported")
.unwrap_or_else(api_error_into_response),
));
};
let (response, websocket) = match framed_websockets::upgrade::upgrade(&mut request) {
Err(e) => {
return HandleOrResponse::Response(Some(api_error_into_response(
ApiError::BadRequest(e.into()),
)))
}
Ok(upgrade) => upgrade,
};
let host = request
.headers()
.get("host")
.and_then(|h| h.to_str().ok())
.and_then(|h| h.split(':').next())
.map(|s| s.to_string());
let ctx = RequestMonitoring::new( let ctx = RequestMonitoring::new(
session_id, session_id,
peer_addr, peer_addr,
@@ -409,33 +567,14 @@ async fn request_handler(
&config.region, &config.region,
); );
let span = ctx.span(); if let Err(_e) = spawner.send((ctx, host, websocket)) {
info!(parent: &span, "performing websocket upgrade"); return HandleOrResponse::Response(Some(api_error_into_response(
ApiError::InternalServerError(anyhow::anyhow!("could not upgrade WS connection")),
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request) )));
.map_err(|e| ApiError::BadRequest(e.into()))?; }
ws_connections.spawn(
async move {
if let Err(e) = websocket::serve_websocket(
config,
backend.auth_backend,
ctx,
websocket,
cancellation_handler,
endpoint_rate_limiter,
host,
)
.await
{
warn!("error in websocket connection: {e:#}");
}
}
.instrument(span),
);
// Return the response so the spawned future can continue. // Return the response so the spawned future can continue.
Ok(response.map(|b| b.map_err(|x| match x {}).boxed())) HandleOrResponse::Response(Some(response.map(|b| b.map_err(|x| match x {}).boxed())))
} else if request.uri().path() == "/sql" && *request.method() == Method::POST { } else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestMonitoring::new( let ctx = RequestMonitoring::new(
session_id, session_id,
@@ -445,11 +584,19 @@ async fn request_handler(
); );
let span = ctx.span(); let span = ctx.span();
sql_over_http::handle(config, ctx, request, backend, http_cancellation_token) let token = connection_token.clone();
.instrument(span)
.await // `sql_over_http::handle` is not cancel safe. It expects to be cancelled only at specific times.
// By spawning the future, we ensure it never gets cancelled until it decides to.
HandleOrResponse::Handle(tokio::spawn(async move {
let _token = token;
sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
.instrument(span)
.await
.unwrap_or_else(api_error_into_response)
}))
} else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS { } else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
Response::builder() HandleOrResponse::Response(Some( Response::builder()
.header("Allow", "OPTIONS, POST") .header("Allow", "OPTIONS, POST")
.header("Access-Control-Allow-Origin", "*") .header("Access-Control-Allow-Origin", "*")
.header( .header(
@@ -459,8 +606,11 @@ async fn request_handler(
.header("Access-Control-Max-Age", "86400" /* 24 hours */) .header("Access-Control-Max-Age", "86400" /* 24 hours */)
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
.body(Empty::new().map_err(|x| match x {}).boxed()) .body(Empty::new().map_err(|x| match x {}).boxed())
.map_err(|e| ApiError::InternalServerError(e.into())) .map_err(|e| ApiError::InternalServerError(e.into())).unwrap_or_else(api_error_into_response)))
} else { } else {
json_response(StatusCode::BAD_REQUEST, "query is not supported") HandleOrResponse::Response(Some(
json_response(StatusCode::BAD_REQUEST, "query is not supported")
.unwrap_or_else(api_error_into_response),
))
} }
} }