Compare commits

...

7 Commits

Author SHA1 Message Date
Conrad Ludgate
eb96abbff7 moar 2024-10-30 12:18:22 +00:00
Conrad Ludgate
fa5d907032 move up ws handling even more 2024-10-30 11:51:47 +00:00
Conrad Ludgate
c4525483ae keep going now 2024-10-30 11:39:33 +00:00
Conrad Ludgate
ba714431be oneshot 2024-10-30 11:13:56 +00:00
Conrad Ludgate
7c57234de1 more refactoring 2024-10-30 10:19:52 +00:00
Conrad Ludgate
2acc621c4c random changes 2024-10-30 09:55:46 +00:00
Conrad Ludgate
7de2914dde re-use the same tokio task for the websocket handling 2024-10-30 09:45:19 +00:00

View File

@@ -14,9 +14,11 @@ mod local_conn_pool;
mod sql_over_http;
mod websocket;
use std::future::Future;
use std::net::{IpAddr, SocketAddr};
use std::pin::{pin, Pin};
use std::sync::Arc;
use std::task::Poll;
use anyhow::Context;
use async_trait::async_trait;
@@ -24,23 +26,26 @@ use atomic_take::AtomicTake;
use bytes::Bytes;
pub use conn_pool_lib::GlobalConnPoolOptions;
use futures::future::{select, Either};
use futures::TryFutureExt;
use futures::FutureExt;
use http::{Method, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty};
use hyper::body::Incoming;
use hyper_util::rt::TokioExecutor;
use hyper_util::server::conn::auto::Builder;
use hyper::upgrade::OnUpgrade;
use hyper_util::rt::{TokioExecutor, TokioIo};
use rand::rngs::StdRng;
use rand::SeedableRng;
use smallvec::SmallVec;
use sql_over_http::{uuid_to_header_value, NEON_REQUEST_ID};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use tracing::{info, warn, Instrument};
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{debug, info, warn, Instrument};
use utils::http::error::ApiError;
use crate::cancellation::CancellationHandlerMain;
@@ -154,16 +159,18 @@ pub async fn task_main(
}
}
let conn_token = cancellation_token.child_token();
let conn_cancellation_token = cancellation_token.child_token();
let tls_acceptor = tls_acceptor.clone();
let backend = backend.clone();
let connections2 = connections.clone();
let cancellation_handler = cancellation_handler.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
connections.spawn(
let connection_token = connections.token();
tokio::spawn(
async move {
let conn_token2 = conn_token.clone();
let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);
let _cancel_guard = config
.http_config
.cancel_set
.insert(conn_id, conn_cancellation_token.clone());
let session_id = uuid::Uuid::new_v4();
@@ -172,30 +179,41 @@ pub async fn task_main(
.client_connections
.guard(crate::metrics::Protocol::Http);
let startup_result = Box::pin(connection_startup(
config,
tls_acceptor,
session_id,
conn,
peer_addr,
))
.await;
let Some((conn, peer_addr)) = startup_result else {
let startup_result =
connection_startup(config, tls_acceptor, session_id, conn, peer_addr)
.boxed()
.await;
let Some(conn) = startup_result else {
return;
};
Box::pin(connection_handler(
let ws_upgrade = http_connection_handler(
config,
backend,
connections2,
cancellation_handler,
endpoint_rate_limiter,
conn_token,
conn_cancellation_token,
connection_token.clone(),
conn,
peer_addr,
session_id,
))
)
.boxed()
.await;
if let Some((ctx, host, websocket)) = ws_upgrade {
let ws = websocket::serve_websocket(
config,
auth_backend,
ctx,
websocket,
cancellation_handler,
endpoint_rate_limiter,
host,
)
.boxed();
if let Err(e) = ws.await {
warn!("error in websocket connection: {e:#}");
}
}
}
.instrument(http_conn_span),
);
@@ -212,13 +230,19 @@ pub(crate) type AsyncRW = Pin<Box<dyn AsyncReadWrite>>;
#[async_trait]
trait MaybeTlsAcceptor: Send + Sync + 'static {
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW>;
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<(AsyncRW, Alpn)>;
}
#[async_trait]
impl MaybeTlsAcceptor for rustls::ServerConfig {
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW> {
Ok(Box::pin(TlsAcceptor::from(self).accept(conn).await?))
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<(AsyncRW, Alpn)> {
let conn = TlsAcceptor::from(self).accept(conn).await?;
let alpn = conn
.get_ref()
.1
.alpn_protocol()
.map_or_else(SmallVec::new, SmallVec::from_slice);
Ok((Box::pin(conn), alpn))
}
}
@@ -226,11 +250,14 @@ struct NoTls;
#[async_trait]
impl MaybeTlsAcceptor for NoTls {
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<AsyncRW> {
Ok(Box::pin(conn))
async fn accept(self: Arc<Self>, conn: ChainRW<TcpStream>) -> std::io::Result<(AsyncRW, Alpn)> {
Ok((Box::pin(conn), SmallVec::new()))
}
}
/// ALPN protocol identifier bytes reported by the TLS layer.
/// Empty when no protocol was negotiated or when the connection is not using TLS.
type Alpn = SmallVec<[u8; 8]>;
/// An accepted connection stream bundled with the peer IP address and the ALPN value.
type ConnWithInfo = (AsyncRW, IpAddr, Alpn);
/// Handles the TCP startup lifecycle.
/// 1. Parses PROXY protocol V2
/// 2. Handles TLS handshake
@@ -240,7 +267,7 @@ async fn connection_startup(
session_id: uuid::Uuid,
conn: TcpStream,
peer_addr: SocketAddr,
) -> Option<(AsyncRW, IpAddr)> {
) -> Option<ConnWithInfo> {
// handle PROXY protocol
let (conn, peer) = match read_proxy_protocol(conn).await {
Ok(c) => c,
@@ -258,7 +285,7 @@ async fn connection_startup(
info!(?session_id, %peer_addr, "accepted new TCP connection");
// try upgrade to TLS, but with a timeout.
let conn = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await {
let (conn, alpn) = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await {
Ok(Ok(conn)) => {
info!(?session_id, %peer_addr, "accepted new TLS connection");
conn
@@ -281,89 +308,99 @@ async fn connection_startup(
}
};
Some((conn, peer_addr))
Some((conn, peer_addr, alpn))
}
/// Object-safe facade over the hyper server connection futures used in this module.
///
/// Every implementor is a `Send` future resolving to `Result<(), hyper::Error>`, so the
/// different connection types can be stored behind a single
/// `Pin<Box<dyn GracefulShutdown>>` and still be asked to shut down.
trait GracefulShutdown: Future<Output = Result<(), hyper::Error>> + Send {
    /// Requests a graceful shutdown of the underlying connection.
    fn graceful_shutdown(self: Pin<&mut Self>);
}

// NOTE: in each impl below, `Self::graceful_shutdown` is expected to resolve to the
// inherent method that hyper defines on the connection type (inherent associated
// functions take precedence over trait methods), so these bodies forward to hyper
// rather than calling themselves.

/// HTTP/1 connection with upgrade support.
impl GracefulShutdown
    for hyper::server::conn::http1::UpgradeableConnection<TokioIo<AsyncRW>, ProxyService>
{
    fn graceful_shutdown(self: Pin<&mut Self>) {
        Self::graceful_shutdown(self);
    }
}

/// Plain HTTP/1 connection.
impl GracefulShutdown for hyper::server::conn::http1::Connection<TokioIo<AsyncRW>, ProxyService> {
    fn graceful_shutdown(self: Pin<&mut Self>) {
        Self::graceful_shutdown(self);
    }
}

/// HTTP/2 connection.
impl GracefulShutdown
    for hyper::server::conn::http2::Connection<TokioIo<AsyncRW>, ProxyService, TokioExecutor>
{
    fn graceful_shutdown(self: Pin<&mut Self>) {
        Self::graceful_shutdown(self);
    }
}
/// Handles HTTP connection
/// 1. With graceful shutdowns
/// 2. With graceful request cancellation with connection failure
/// 3. With websocket upgrade support.
#[allow(clippy::too_many_arguments)]
async fn connection_handler(
async fn http_connection_handler(
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
connections: TaskTracker,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
conn: AsyncRW,
peer_addr: IpAddr,
connection_token: TaskTrackerToken,
conn: ConnWithInfo,
session_id: uuid::Uuid,
) {
) -> Option<WsUpgrade> {
let (conn, peer_addr, alpn) = conn;
let session_id = AtomicTake::new(session_id);
// Cancel all current inflight HTTP requests if the HTTP connection is closed.
let http_cancellation_token = CancellationToken::new();
let _cancel_connection = http_cancellation_token.clone().drop_guard();
let server = Builder::new(TokioExecutor::new());
let conn = server.serve_connection_with_upgrades(
hyper_util::rt::TokioIo::new(conn),
hyper::service::service_fn(move |req: hyper::Request<Incoming>| {
// First HTTP request shares the same session ID
let mut session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4);
let (ws_tx, ws_rx) = oneshot::channel();
let ws_tx = Arc::new(AtomicTake::new(ws_tx));
if matches!(backend.auth_backend, crate::auth::Backend::Local(_)) {
// take session_id from request, if given.
if let Some(id) = req
.headers()
.get(&NEON_REQUEST_ID)
.and_then(|id| uuid::Uuid::try_parse_ascii(id.as_bytes()).ok())
{
session_id = id;
}
}
let http2 = match &*alpn {
b"h2" => true,
b"http/1.1" => false,
_ => {
debug!("no alpn negotiated");
false
}
};
// Cancel the current inflight HTTP request if the request stream is closed.
// This is slightly different to `_cancel_connection` in that
// h2 can cancel individual requests with a `RST_STREAM`.
let http_request_token = http_cancellation_token.child_token();
let cancel_request = http_request_token.clone().drop_guard();
let service = ProxyService {
config,
backend,
connection_token,
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
// By spawning the future, we ensure it never gets cancelled until it decides to.
let handler = connections.spawn(
request_handler(
req,
config,
backend.clone(),
connections.clone(),
cancellation_handler.clone(),
session_id,
peer_addr,
http_request_token,
endpoint_rate_limiter.clone(),
)
.in_current_span()
.map_ok_or_else(api_error_into_response, |r| r),
);
async move {
let mut res = handler.await;
cancel_request.disarm();
http_cancellation_token,
ws_tx,
peer_addr,
session_id,
};
// add the session ID to the response
if let Ok(resp) = &mut res {
resp.headers_mut()
.append(&NEON_REQUEST_ID, uuid_to_header_value(session_id));
}
let io = hyper_util::rt::TokioIo::new(conn);
let conn: Pin<Box<dyn GracefulShutdown>> = if http2 {
service.ws_tx.take();
res
}
}),
);
Box::pin(
hyper::server::conn::http2::Builder::new(TokioExecutor::new())
.serve_connection(io, service),
)
} else if config.http_config.accept_websockets {
Box::pin(
hyper::server::conn::http1::Builder::new()
.serve_connection(io, service)
.with_upgrades(),
)
} else {
service.ws_tx.take();
Box::pin(hyper::server::conn::http1::Builder::new().serve_connection(io, service))
};
// On cancellation, trigger the HTTP connection handler to shut down.
let res = match select(pin!(cancellation_token.cancelled()), pin!(conn)).await {
let res = match select(pin!(cancellation_token.cancelled()), conn).await {
Either::Left((_cancelled, mut conn)) => {
tracing::debug!(%peer_addr, "cancelling connection");
conn.as_mut().graceful_shutdown();
@@ -373,35 +410,156 @@ async fn connection_handler(
};
match res {
Ok(()) => tracing::info!(%peer_addr, "HTTP connection closed"),
Ok(()) => {
if let Ok(ws_upgrade) = ws_rx.await {
tracing::info!(%peer_addr, "connection upgraded to websockets");
return Some(ws_upgrade);
}
tracing::info!(%peer_addr, "HTTP connection closed");
}
Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"),
}
None
}
/// hyper service that is handed to the connection builder and answers the requests
/// arriving on a single HTTP connection.
struct ProxyService {
// global state
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
// connection state only
/// Task-tracker token for this connection; `call` lends it to `request_handler`.
connection_token: TaskTrackerToken,
/// Parent token from which `call` derives one child token per request.
http_cancellation_token: CancellationToken,
/// Shared sender slot that `call` clones into `request_handler` for websocket upgrades.
ws_tx: WsSpawner,
peer_addr: IpAddr,
/// Taken by the first request on the connection; once empty, `call` generates a new v4 UUID.
session_id: AtomicTake<uuid::Uuid>,
}
impl hyper::service::Service<hyper::Request<Incoming>> for ProxyService {
    type Response = Response<BoxBody<Bytes, hyper::Error>>;
    type Error = tokio::task::JoinError;
    type Future = ReqFut;

    /// Starts handling one request and returns the future that yields its response.
    ///
    /// Picks the session ID for the request, derives a per-request cancellation
    /// token from the connection-wide one, and delegates to `request_handler`.
    fn call(&self, req: hyper::Request<Incoming>) -> Self::Future {
        // The first HTTP request on the connection shares the connection's session ID.
        // The slot is always drained here, so later requests get a fresh ID.
        let default_id = self.session_id.take().unwrap_or_else(uuid::Uuid::new_v4);

        // With the local auth backend, a parsable ID supplied in the request header wins.
        let is_local_backend = matches!(self.backend.auth_backend, crate::auth::Backend::Local(_));
        let header_id = if is_local_backend {
            req.headers()
                .get(&NEON_REQUEST_ID)
                .and_then(|raw| uuid::Uuid::try_parse_ascii(raw.as_bytes()).ok())
        } else {
            None
        };
        let session_id = header_id.unwrap_or(default_id);

        // Cancel the current inflight HTTP request if the request stream is closed.
        // This is slightly different to the connection-wide token in that
        // h2 can cancel individual requests with a `RST_STREAM`.
        let request_token = self.http_cancellation_token.child_token();
        let cancel_request = Some(request_token.clone().drop_guard());

        let handle = request_handler(
            req,
            self.config,
            self.backend.clone(),
            self.ws_tx.clone(),
            session_id,
            self.peer_addr,
            request_token,
            &self.connection_token,
        );

        ReqFut {
            session_id,
            cancel_request,
            handle,
        }
    }
}
/// Future type of `ProxyService`; resolves to the response for a single request.
struct ReqFut {
/// Session ID that is appended to a successful response as the `NEON_REQUEST_ID` header.
session_id: uuid::Uuid,
/// Drop guard on the per-request cancellation token; taken and disarmed once a result is ready.
cancel_request: Option<tokio_util::sync::DropGuard>,
/// Either a spawned task to wait for, or a response that is already built.
handle: HandleOrResponse,
}
/// What `request_handler` hands back: a task still producing the response, or the response itself.
enum HandleOrResponse {
/// Join handle of a spawned task whose output is the response.
Handle(JoinHandle<Response<BoxBody<Bytes, hyper::Error>>>),
/// A ready response. The `Option` lets `ReqFut::poll` move it out; polling again afterwards panics.
Response(Option<Response<BoxBody<Bytes, hyper::Error>>>),
}
impl Future for ReqFut {
    type Output = Result<Response<BoxBody<Bytes, hyper::Error>>, tokio::task::JoinError>;

    /// Drives the request to completion.
    ///
    /// Waits on the spawned task (or hands out the pre-built response), disarms the
    /// per-request cancellation guard, and tags a successful response with the session ID.
    ///
    /// # Panics
    ///
    /// Panics if polled again after a pre-built response has already been returned.
    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        let mut outcome = match &mut this.handle {
            HandleOrResponse::Handle(task) => match Pin::new(task).poll(cx) {
                Poll::Ready(joined) => joined,
                Poll::Pending => return Poll::Pending,
            },
            HandleOrResponse::Response(slot) => {
                Ok(slot.take().expect("polled after completion"))
            }
        };

        // A result is available, so dropping this future must no longer cancel the request.
        if let Some(guard) = this.cancel_request.take() {
            guard.disarm();
        }

        // add the session ID to the response
        if let Ok(response) = outcome.as_mut() {
            response
                .headers_mut()
                .append(&NEON_REQUEST_ID, uuid_to_header_value(this.session_id));
        }

        Poll::Ready(outcome)
    }
}
/// Data sent from `request_handler` when a websocket upgrade is accepted:
/// the request context, the optional `host` header value with any port stripped,
/// and the pending upgrade handle.
type WsUpgrade = (RequestMonitoring, Option<String>, OnUpgrade);
/// Shared, take-once slot holding the one-shot sender for a `WsUpgrade`.
type WsSpawner = Arc<AtomicTake<oneshot::Sender<WsUpgrade>>>;
#[allow(clippy::too_many_arguments)]
async fn request_handler(
fn request_handler(
mut request: hyper::Request<Incoming>,
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
ws_connections: TaskTracker,
cancellation_handler: Arc<CancellationHandlerMain>,
ws_spawner: WsSpawner,
session_id: uuid::Uuid,
peer_addr: IpAddr,
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
let host = request
.headers()
.get("host")
.and_then(|h| h.to_str().ok())
.and_then(|h| h.split(':').next())
.map(|s| s.to_string());
connection_token: &TaskTrackerToken,
) -> HandleOrResponse {
// Check if the request is a websocket upgrade request.
if config.http_config.accept_websockets
&& framed_websockets::upgrade::is_upgrade_request(&request)
{
if framed_websockets::upgrade::is_upgrade_request(&request) {
let Some(spawner) = ws_spawner.take() else {
return HandleOrResponse::Response(Some(
json_response(StatusCode::BAD_REQUEST, "query is not supported")
.unwrap_or_else(api_error_into_response),
));
};
let (response, websocket) = match framed_websockets::upgrade::upgrade(&mut request) {
Err(e) => {
return HandleOrResponse::Response(Some(api_error_into_response(
ApiError::BadRequest(e.into()),
)))
}
Ok(upgrade) => upgrade,
};
let host = request
.headers()
.get("host")
.and_then(|h| h.to_str().ok())
.and_then(|h| h.split(':').next())
.map(|s| s.to_string());
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
@@ -409,33 +567,14 @@ async fn request_handler(
&config.region,
);
let span = ctx.span();
info!(parent: &span, "performing websocket upgrade");
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
.map_err(|e| ApiError::BadRequest(e.into()))?;
ws_connections.spawn(
async move {
if let Err(e) = websocket::serve_websocket(
config,
backend.auth_backend,
ctx,
websocket,
cancellation_handler,
endpoint_rate_limiter,
host,
)
.await
{
warn!("error in websocket connection: {e:#}");
}
}
.instrument(span),
);
if let Err(_e) = spawner.send((ctx, host, websocket)) {
return HandleOrResponse::Response(Some(api_error_into_response(
ApiError::InternalServerError(anyhow::anyhow!("could not upgrade WS connection")),
)));
}
// Return the response so the spawned future can continue.
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
HandleOrResponse::Response(Some(response.map(|b| b.map_err(|x| match x {}).boxed())))
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestMonitoring::new(
session_id,
@@ -445,11 +584,19 @@ async fn request_handler(
);
let span = ctx.span();
sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
.instrument(span)
.await
let token = connection_token.clone();
// `sql_over_http::handle` is not cancel safe. It expects to be cancelled only at specific times.
// By spawning the future, we ensure it never gets cancelled until it decides to.
HandleOrResponse::Handle(tokio::spawn(async move {
let _token = token;
sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
.instrument(span)
.await
.unwrap_or_else(api_error_into_response)
}))
} else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
Response::builder()
HandleOrResponse::Response(Some( Response::builder()
.header("Allow", "OPTIONS, POST")
.header("Access-Control-Allow-Origin", "*")
.header(
@@ -459,8 +606,11 @@ async fn request_handler(
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
.body(Empty::new().map_err(|x| match x {}).boxed())
.map_err(|e| ApiError::InternalServerError(e.into()))
.map_err(|e| ApiError::InternalServerError(e.into())).unwrap_or_else(api_error_into_response)))
} else {
json_response(StatusCode::BAD_REQUEST, "query is not supported")
HandleOrResponse::Response(Some(
json_response(StatusCode::BAD_REQUEST, "query is not supported")
.unwrap_or_else(api_error_into_response),
))
}
}