From 219c72c24c26be8259e6b86169bc987925d47ab1 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 29 May 2025 17:52:26 +0100 Subject: [PATCH] optimise proxy_pass memory size a little, also boxing requestcontext since it is large --- proxy/src/console_redirect_proxy.rs | 21 ++----------------- proxy/src/context/mod.rs | 12 +++++------ proxy/src/proxy/mod.rs | 29 +++------------------------ proxy/src/proxy/passthrough.rs | 31 +++++++++++++++++++++++------ proxy/src/serverless/mod.rs | 15 +++++++++----- proxy/src/serverless/websocket.rs | 20 ++++++------------- 6 files changed, 52 insertions(+), 76 deletions(-) diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 971ca273db..e7fd057641 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -16,9 +16,7 @@ use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol}; use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute}; use crate::proxy::handshake::{HandshakeData, handshake}; use crate::proxy::passthrough::ProxyPassthrough; -use crate::proxy::{ - ClientRequestError, ErrorSource, prepare_client_connection, run_until_cancelled, -}; +use crate::proxy::{ClientRequestError, prepare_client_connection, run_until_cancelled}; pub async fn task_main( config: &'static ProxyConfig, @@ -126,22 +124,7 @@ pub async fn task_main( } Ok(Some(p)) => { ctx.set_success(); - let _disconnect = ctx.log_connect(); - match p.proxy_pass(&config.connect_to_compute).await { - Ok(()) => {} - Err(ErrorSource::Client(e)) => { - error!( - ?session_id, - "per-client task finished with an IO error from the client: {e:#}" - ); - } - Err(ErrorSource::Compute(e)) => { - error!( - ?session_id, - "per-client task finished with an IO error from the compute: {e:#}" - ); - } - } + tokio::spawn(p.proxy_pass(ctx, &config.connect_to_compute)); } } }); diff --git a/proxy/src/context/mod.rs b/proxy/src/context/mod.rs index 79aaf22990..14b8e219dd 100644 --- a/proxy/src/context/mod.rs +++ b/proxy/src/context/mod.rs @@ -38,7 +38,7 @@ pub struct RequestContext( /// I would typically use a RefCell but that would break the `Send` requirements /// so we need something with thread-safety. `TryLock` is a cheap alternative /// that offers similar semantics to a `RefCell` but with synchronisation. - TryLock, + TryLock>, ); struct RequestContextInner { @@ -89,7 +89,7 @@ pub(crate) enum AuthMethod { impl Clone for RequestContext { fn clone(&self) -> Self { let inner = self.0.try_lock().expect("should not deadlock"); - let new = RequestContextInner { + let new = Box::new(RequestContextInner { conn_info: inner.conn_info.clone(), session_id: inner.session_id, protocol: inner.protocol, @@ -117,7 +117,7 @@ impl Clone for RequestContext { disconnect_sender: None, latency_timer: LatencyTimer::noop(inner.protocol), disconnect_timestamp: inner.disconnect_timestamp, - }; + }); Self(TryLock::new(new)) } @@ -140,7 +140,7 @@ impl RequestContext { role = tracing::field::Empty, ); - let inner = RequestContextInner { + let inner = Box::new(RequestContextInner { conn_info, session_id, protocol, @@ -168,7 +168,7 @@ impl RequestContext { disconnect_sender: LOG_CHAN_DISCONNECT.get().and_then(|tx| tx.upgrade()), latency_timer: LatencyTimer::new(protocol), disconnect_timestamp: None, - }; + }); Self(TryLock::new(inner)) } @@ -522,7 +522,7 @@ impl Drop for RequestContextInner { } } -pub struct DisconnectLogger(RequestContextInner); +pub struct DisconnectLogger(Box); impl Drop for DisconnectLogger { fn drop(&mut self) { diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 9bb3f2f305..2843823cd9 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -154,42 +154,19 @@ pub async fn task_main( .boxed() .await; - let passthrough = match res { + match res { Err(e) => { ctx.set_error_kind(e.get_error_kind()); warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}"); - return; } Ok(None) => { ctx.set_success(); - return; } Ok(Some(p)) => { ctx.set_success(); - p + tokio::spawn(p.proxy_pass(ctx, &config.connect_to_compute)); } - }; - - // spawn passthrough as a new task. - // holds a task tracker token to prevent the proxy shutting down early. - tokio::spawn(async move { - let _disconnect = ctx.log_connect(); - match passthrough.proxy_pass(&config.connect_to_compute).await { - Ok(()) => {} - Err(ErrorSource::Client(e)) => { - warn!( - ?session_id, - "per-client task finished with an IO error from the client: {e:#}" - ); - } - Err(ErrorSource::Compute(e)) => { - error!( - ?session_id, - "per-client task finished with an IO error from the compute: {e:#}" - ); - } - } - }); + } }); } diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index a1762285ef..b27a3b9034 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -8,6 +8,7 @@ use super::copy_bidirectional::ErrorSource; use crate::cancellation; use crate::compute::PostgresConnection; use crate::config::ComputeConfig; +use crate::context::RequestContext; use crate::control_plane::messages::MetricsAuxInfo; use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard}; use crate::stream::Stream; @@ -75,11 +76,13 @@ pub(crate) struct ProxyPassthrough { pub(crate) _tracker: TaskTrackerToken, } -impl ProxyPassthrough { +impl ProxyPassthrough { pub(crate) async fn proxy_pass( self, - compute_config: &ComputeConfig, - ) -> Result<(), ErrorSource> { + ctx: RequestContext, + compute_config: &'static ComputeConfig, + ) { + let _disconnect = ctx.log_connect(); let res = proxy_pass( self.client, self.compute.stream, @@ -87,6 +90,23 @@ impl ProxyPassthrough { self.private_link_id, ) .await; + + match res { + Ok(()) => {} + Err(ErrorSource::Client(e)) => { + tracing::warn!( + session_id = ?self.session_id, + "per-client task finished with an IO error from the client: {e:#}" + ); + } + Err(ErrorSource::Compute(e)) => { + tracing::error!( + session_id = ?self.session_id, + "per-client task finished with an IO error from the compute: {e:#}" + ); + } + } + if let Err(err) = self .compute .cancel_closure @@ -96,8 +116,7 @@ impl ProxyPassthrough { tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database"); } - drop(self.cancel.remove_cancel_key()); // we don't need a result. If the queue is full, we just log the error - - res + // we don't need a result. If the queue is full, we just log the error + drop(self.cancel.remove_cancel_key()); } } diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index a080d8c4d6..08cc9ddb53 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -437,7 +437,15 @@ async fn request_handler( tokio::spawn( async move { - if let Err(e) = websocket::serve_websocket( + let websocket = match websocket.await { + Err(e) => { + warn!("could not upgrade websocket connection: {e:#}"); + return; + } + Ok(websocket) => websocket, + }; + + websocket::serve_websocket( config, backend.auth_backend, ctx, @@ -447,10 +455,7 @@ async fn request_handler( host, tracker, ) - .await - { - warn!("error in websocket connection: {e:#}"); - } + .await; } .instrument(span), ); diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index d0d207ea40..eb1c90e7f9 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -2,11 +2,10 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll, ready}; -use anyhow::Context as _; use bytes::{Buf, BufMut, Bytes, BytesMut}; use framed_websockets::{Frame, OpCode, WebSocketServer}; use futures::{Sink, Stream}; -use hyper::upgrade::OnUpgrade; +use hyper::upgrade::Upgraded; use hyper_util::rt::TokioIo; use pin_project_lite::pin_project; use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; @@ -18,7 +17,7 @@ use crate::config::ProxyConfig; use crate::context::RequestContext; use crate::error::ReportableError; use crate::metrics::Metrics; -use crate::proxy::{ClientMode, ErrorSource, handle_client}; +use crate::proxy::{ClientMode, handle_client}; use crate::rate_limiter::EndpointRateLimiter; pin_project! { @@ -129,13 +128,12 @@ pub(crate) async fn serve_websocket( config: &'static ProxyConfig, auth_backend: &'static crate::auth::Backend<'static, ()>, ctx: RequestContext, - websocket: OnUpgrade, + websocket: Upgraded, cancellation_handler: Arc, endpoint_rate_limiter: Arc, hostname: Option, tracker: TaskTrackerToken, -) -> anyhow::Result<()> { - let websocket = websocket.await?; +) { let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket)); let conn_gauge = Metrics::get() @@ -159,20 +157,14 @@ pub(crate) async fn serve_websocket( match res { Err(e) => { ctx.set_error_kind(e.get_error_kind()); - Err(e.into()) + tracing::warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}"); } Ok(None) => { ctx.set_success(); - Ok(()) } Ok(Some(p)) => { ctx.set_success(); - ctx.log_connect(); - match p.proxy_pass(&config.connect_to_compute).await { - Ok(()) => Ok(()), - Err(ErrorSource::Client(err)) => Err(err).context("client"), - Err(ErrorSource::Compute(err)) => Err(err).context("compute"), - } + tokio::spawn(p.proxy_pass(ctx, &config.connect_to_compute)); } } }