From 781bf4945d9cb3902de829a187ee0e7ebc71e432 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 2 Jun 2025 17:13:30 +0100 Subject: [PATCH] proxy: optimise future layout allocations (#12104) A smaller version of #12066 that is somewhat easier to review. Now that I've been using https://crates.io/crates/top-type-sizes I've found a lot more of the low hanging fruit that can be tweaks to reduce the memory usage. Some context for the optimisations: Rust's stack allocation in futures is quite naive. Stack variables, even if moved, often still end up taking space in the future. Rearranging the order in which variables are defined, and properly scoping them can go a long way. `async fn` and `async move {}` have a consequence that they always duplicate the "upvars" (aka captures). All captures are permanently allocated in the future, even if moved. We can be mindful when writing futures to only capture as little as possible. TlsStream is massive. Needs boxing so it doesn't contribute to the above issue. ## Measurements from `top-type-sizes`: ### Before ``` 10328 {async block@proxy::proxy::task_main::{closure#0}::{closure#0}} align=8 6120 {async fn body of proxy::proxy::handle_client>()} align=8 ``` ### After ``` 4040 {async block@proxy::proxy::task_main::{closure#0}::{closure#0}} 4704 {async fn body of proxy::proxy::handle_client>()} align=8 ``` --- proxy/src/auth/backend/classic.rs | 16 ++-- proxy/src/console_redirect_proxy.rs | 2 +- .../control_plane/client/cplane_proxy_v1.rs | 75 +++++++++++-------- proxy/src/http/mod.rs | 27 +++++-- proxy/src/pqproto.rs | 6 +- proxy/src/proxy/handshake.rs | 9 ++- proxy/src/proxy/mod.rs | 2 +- proxy/src/proxy/passthrough.rs | 2 + proxy/src/sasl/stream.rs | 49 ++++++------ proxy/src/stream.rs | 4 +- proxy/src/tls/postgres_rustls.rs | 6 +- 11 files changed, 115 insertions(+), 83 deletions(-) diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index dcc500f2c8..8445368740 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -25,19 +25,15 @@ pub(super) async fn authenticate( } AuthSecret::Scram(secret) => { debug!("auth endpoint chooses SCRAM"); - let scram = auth::Scram(&secret, ctx); - let auth_outcome = tokio::time::timeout(config.scram_protocol_timeout, async { - AuthFlow::new(client, scram) - .authenticate() - .await - .inspect_err(|error| { - warn!(?error, "error processing scram messages"); - }) - }) + let auth_outcome = tokio::time::timeout( + config.scram_protocol_timeout, + AuthFlow::new(client, auth::Scram(&secret, ctx)).authenticate(), + ) .await .inspect_err(|_| warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs())) - .map_err(auth::AuthError::user_timeout)??; + .map_err(auth::AuthError::user_timeout)? + .inspect_err(|error| warn!(?error, "error processing scram messages"))?; let client_key = match auth_outcome { sasl::Outcome::Success(key) => key, diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 9499aba61b..7fb84b5ee5 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -159,7 +159,7 @@ pub async fn task_main( } #[allow(clippy::too_many_arguments)] -pub(crate) async fn handle_client( +pub(crate) async fn handle_client( config: &'static ProxyConfig, backend: &'static ConsoleRedirectBackend, ctx: &RequestContext, diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index 93f4ea6cf7..da548d6b2c 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -7,7 +7,9 @@ use std::time::Duration; use ::http::HeaderName; use ::http::header::AUTHORIZATION; +use bytes::Bytes; use futures::TryFutureExt; +use hyper::StatusCode; use postgres_client::config::SslMode; use tokio::time::Instant; use tracing::{Instrument, debug, info, info_span, warn}; @@ -72,28 +74,34 @@ impl NeonControlPlaneClient { role: &RoleName, ) -> Result { async { - let request = self - .endpoint - .get_path("get_endpoint_access_control") - .header(X_REQUEST_ID, ctx.session_id().to_string()) - .header(AUTHORIZATION, format!("Bearer {}", &self.jwt)) - .query(&[("session_id", ctx.session_id())]) - .query(&[ - ("application_name", ctx.console_application_name().as_str()), - ("endpointish", endpoint.as_str()), - ("role", role.as_str()), - ]) - .build()?; - - debug!(url = request.url().as_str(), "sending http request"); - let start = Instant::now(); let response = { - let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane); - self.endpoint.execute(request).await? - }; - info!(duration = ?start.elapsed(), "received http response"); + let request = self + .endpoint + .get_path("get_endpoint_access_control") + .header(X_REQUEST_ID, ctx.session_id().to_string()) + .header(AUTHORIZATION, format!("Bearer {}", &self.jwt)) + .query(&[("session_id", ctx.session_id())]) + .query(&[ + ("application_name", ctx.console_application_name().as_str()), + ("endpointish", endpoint.as_str()), + ("role", role.as_str()), + ]) + .build()?; - let body = match parse_body::(response).await { + debug!(url = request.url().as_str(), "sending http request"); + let start = Instant::now(); + let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane); + let response = self.endpoint.execute(request).await?; + + info!(duration = ?start.elapsed(), "received http response"); + + response + }; + + let body = match parse_body::( + response.status(), + response.bytes().await?, + ) { Ok(body) => body, // Error 404 is special: it's ok not to have a secret. // TODO(anna): retry @@ -184,7 +192,10 @@ impl NeonControlPlaneClient { drop(pause); info!(duration = ?start.elapsed(), "received http response"); - let body = parse_body::(response).await?; + let body = parse_body::( + response.status(), + response.bytes().await.map_err(ControlPlaneError::from)?, + )?; let rules = body .jwks @@ -236,7 +247,7 @@ impl NeonControlPlaneClient { let response = self.endpoint.execute(request).await?; drop(pause); info!(duration = ?start.elapsed(), "received http response"); - let body = parse_body::(response).await?; + let body = parse_body::(response.status(), response.bytes().await?)?; // Unfortunately, ownership won't let us use `Option::ok_or` here. let (host, port) = match parse_host_port(&body.address) { @@ -487,33 +498,33 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { } /// Parse http response body, taking status code into account. -async fn parse_body serde::Deserialize<'a>>( - response: http::Response, +fn parse_body serde::Deserialize<'a>>( + status: StatusCode, + body: Bytes, ) -> Result { - let status = response.status(); if status.is_success() { // We shouldn't log raw body because it may contain secrets. info!("request succeeded, processing the body"); - return Ok(response.json().await?); + return Ok(serde_json::from_slice(&body).map_err(std::io::Error::other)?); } - let s = response.bytes().await?; + // Log plaintext to be able to detect, whether there are some cases not covered by the error struct. - info!("response_error plaintext: {:?}", s); + info!("response_error plaintext: {:?}", body); // Don't throw an error here because it's not as important // as the fact that the request itself has failed. - let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| { + let mut body = serde_json::from_slice(&body).unwrap_or_else(|e| { warn!("failed to parse error body: {e}"); - ControlPlaneErrorMessage { + Box::new(ControlPlaneErrorMessage { error: "reason unclear (malformed error message)".into(), http_status_code: status, status: None, - } + }) }); body.http_status_code = status; warn!("console responded with an error ({status}): {body:?}"); - Err(ControlPlaneError::Message(Box::new(body))) + Err(ControlPlaneError::Message(body)) } fn parse_host_port(input: &str) -> Option<(&str, u16)> { diff --git a/proxy/src/http/mod.rs b/proxy/src/http/mod.rs index 96f600d836..36607e7861 100644 --- a/proxy/src/http/mod.rs +++ b/proxy/src/http/mod.rs @@ -4,9 +4,10 @@ pub mod health_server; -use std::time::Duration; +use std::time::{Duration, Instant}; use bytes::Bytes; +use futures::FutureExt; use http::Method; use http_body_util::BodyExt; use hyper::body::Body; @@ -109,15 +110,31 @@ impl Endpoint { } /// Execute a [request](reqwest::Request). - pub(crate) async fn execute(&self, request: Request) -> Result { - let _timer = Metrics::get() + pub(crate) fn execute( + &self, + request: Request, + ) -> impl Future> { + let metric = Metrics::get() .proxy .console_request_latency - .start_timer(ConsoleRequest { + .with_labels(ConsoleRequest { request: request.url().path(), }); - self.client.execute(request).await + let req = self.client.execute(request).boxed(); + + async move { + let start = Instant::now(); + scopeguard::defer!({ + Metrics::get() + .proxy + .console_request_latency + .get_metric(metric) + .observe_duration_since(start); + }); + + req.await + } } } diff --git a/proxy/src/pqproto.rs b/proxy/src/pqproto.rs index d68d9f9474..43074bf208 100644 --- a/proxy/src/pqproto.rs +++ b/proxy/src/pqproto.rs @@ -186,7 +186,7 @@ where pub async fn read_message<'a, S>( stream: &mut S, buf: &'a mut Vec, - max: usize, + max: u32, ) -> io::Result<(u8, &'a mut [u8])> where S: AsyncRead + Unpin, @@ -206,7 +206,7 @@ where let header = read!(stream => Header); // as described above, the length must be at least 4. - let Some(len) = (header.len.get() as usize).checked_sub(4) else { + let Some(len) = header.len.get().checked_sub(4) else { return Err(io::Error::other(format!( "invalid startup message length {}, must be at least 4.", header.len, @@ -222,7 +222,7 @@ where } // read in our entire message. - buf.resize(len, 0); + buf.resize(len as usize, 0); stream.read_exact(buf).await?; Ok((header.tag, buf)) diff --git a/proxy/src/proxy/handshake.rs b/proxy/src/proxy/handshake.rs index 13ee8c7dd2..6970ab8714 100644 --- a/proxy/src/proxy/handshake.rs +++ b/proxy/src/proxy/handshake.rs @@ -1,3 +1,4 @@ +use futures::{FutureExt, TryFutureExt}; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, info, warn}; @@ -57,7 +58,7 @@ pub(crate) enum HandshakeData { /// It's easier to work with owned `stream` here as we need to upgrade it to TLS; /// we also take an extra care of propagating only the select handshake errors to client. #[tracing::instrument(skip_all)] -pub(crate) async fn handshake( +pub(crate) async fn handshake( ctx: &RequestContext, stream: S, mut tls: Option<&TlsConfig>, @@ -108,7 +109,9 @@ pub(crate) async fn handshake( } } } - }); + }) + .map_ok(Box::new) + .boxed(); res?; @@ -146,7 +149,7 @@ pub(crate) async fn handshake( tls.cert_resolver.resolve(conn_info.server_name()); let tls = Stream::Tls { - tls: Box::new(tls_stream), + tls: tls_stream, tls_server_end_point, }; (stream, msg) = PqStream::parse_startup(tls).await?; diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index ac0aca1176..0ffc54aa88 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -270,7 +270,7 @@ impl ReportableError for ClientRequestError { } #[allow(clippy::too_many_arguments)] -pub(crate) async fn handle_client( +pub(crate) async fn handle_client( config: &'static ProxyConfig, auth_backend: &'static auth::Backend<'static, ()>, ctx: &RequestContext, diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index 8f9bd2de2d..55ab5f4dba 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -1,3 +1,4 @@ +use futures::FutureExt; use smol_str::SmolStr; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::debug; @@ -89,6 +90,7 @@ impl ProxyPassthrough { .compute .cancel_closure .try_cancel_query(compute_config) + .boxed() .await { tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database"); diff --git a/proxy/src/sasl/stream.rs b/proxy/src/sasl/stream.rs index cb15132673..52ccca58d5 100644 --- a/proxy/src/sasl/stream.rs +++ b/proxy/src/sasl/stream.rs @@ -30,52 +30,53 @@ where F: FnOnce(&str) -> super::Result, M: Mechanism, { - let sasl = { + let (mut mechanism, mut input) = { // pause the timer while we communicate with the client let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); // Initial client message contains the chosen auth method's name. let msg = stream.read_password_message().await?; - super::FirstMessage::parse(msg).ok_or(super::Error::BadClientMessage("bad sasl message"))? + + let sasl = super::FirstMessage::parse(msg) + .ok_or(super::Error::BadClientMessage("bad sasl message"))?; + + (mechanism(sasl.method)?, sasl.message) }; - let mut mechanism = mechanism(sasl.method)?; - let mut input = sasl.message; loop { - let step = mechanism - .exchange(input) - .inspect_err(|error| tracing::info!(?error, "error during SASL exchange"))?; - - match step { - Step::Continue(moved_mechanism, reply) => { + match mechanism.exchange(input) { + Ok(Step::Continue(moved_mechanism, reply)) => { mechanism = moved_mechanism; - // pause the timer while we communicate with the client - let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - // write reply let sasl_msg = BeAuthenticationSaslMessage::Continue(reply.as_bytes()); stream.write_message(BeMessage::AuthenticationSasl(sasl_msg)); - - // get next input - stream.flush().await?; - let msg = stream.read_password_message().await?; - input = std::str::from_utf8(msg) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; + drop(reply); } - Step::Success(result, reply) => { - // pause the timer while we communicate with the client - let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); - + Ok(Step::Success(result, reply)) => { // write reply let sasl_msg = BeAuthenticationSaslMessage::Final(reply.as_bytes()); stream.write_message(BeMessage::AuthenticationSasl(sasl_msg)); stream.write_message(BeMessage::AuthenticationOk); + // exit with success break Ok(Outcome::Success(result)); } // exit with failure - Step::Failure(reason) => break Ok(Outcome::Failure(reason)), + Ok(Step::Failure(reason)) => break Ok(Outcome::Failure(reason)), + Err(error) => { + tracing::info!(?error, "error during SASL exchange"); + return Err(error); + } } + + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client); + + // get next input + stream.flush().await?; + let msg = stream.read_password_message().await?; + input = std::str::from_utf8(msg) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; } } diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 7126430a85..c49a431c95 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -72,7 +72,7 @@ impl PqStream { impl PqStream { /// Read a raw postgres packet, which will respect the max length requested. /// This is not cancel safe. - async fn read_raw_expect(&mut self, tag: u8, max: usize) -> io::Result<&mut [u8]> { + async fn read_raw_expect(&mut self, tag: u8, max: u32) -> io::Result<&mut [u8]> { let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?; if actual_tag != tag { return Err(io::Error::other(format!( @@ -89,7 +89,7 @@ impl PqStream { // passwords are usually pretty short // and SASL SCRAM messages are no longer than 256 bytes in my testing // (a few hashes and random bytes, encoded into base64). - const MAX_PASSWORD_LENGTH: usize = 512; + const MAX_PASSWORD_LENGTH: u32 = 512; self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH) .await } diff --git a/proxy/src/tls/postgres_rustls.rs b/proxy/src/tls/postgres_rustls.rs index f09e916a1d..013b307f0b 100644 --- a/proxy/src/tls/postgres_rustls.rs +++ b/proxy/src/tls/postgres_rustls.rs @@ -31,7 +31,9 @@ mod private { type Output = io::Result>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream) + Pin::new(&mut self.inner) + .poll(cx) + .map_ok(|s| RustlsStream(Box::new(s))) } } @@ -57,7 +59,7 @@ mod private { } } - pub struct RustlsStream(TlsStream); + pub struct RustlsStream(Box>); impl postgres_client::tls::TlsStream for RustlsStream where