mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-08 14:02:55 +00:00
proxy: optimise future layout allocations (#12104)
A smaller version of #12066 that is somewhat easier to review. Now that I've been using https://crates.io/crates/top-type-sizes I've found a lot more of the low hanging fruit that can be tweaks to reduce the memory usage. Some context for the optimisations: Rust's stack allocation in futures is quite naive. Stack variables, even if moved, often still end up taking space in the future. Rearranging the order in which variables are defined, and properly scoping them can go a long way. `async fn` and `async move {}` have a consequence that they always duplicate the "upvars" (aka captures). All captures are permanently allocated in the future, even if moved. We can be mindful when writing futures to only capture as little as possible. TlsStream is massive. Needs boxing so it doesn't contribute to the above issue. ## Measurements from `top-type-sizes`: ### Before ``` 10328 {async block@proxy::proxy::task_main::{closure#0}::{closure#0}} align=8 6120 {async fn body of proxy::proxy::handle_client<proxy::protocol2::ChainRW<tokio::net::TcpStream>>()} align=8 ``` ### After ``` 4040 {async block@proxy::proxy::task_main::{closure#0}::{closure#0}} 4704 {async fn body of proxy::proxy::handle_client<proxy::protocol2::ChainRW<tokio::net::TcpStream>>()} align=8 ```
This commit is contained in:
@@ -25,19 +25,15 @@ pub(super) async fn authenticate(
|
||||
}
|
||||
AuthSecret::Scram(secret) => {
|
||||
debug!("auth endpoint chooses SCRAM");
|
||||
let scram = auth::Scram(&secret, ctx);
|
||||
|
||||
let auth_outcome = tokio::time::timeout(config.scram_protocol_timeout, async {
|
||||
AuthFlow::new(client, scram)
|
||||
.authenticate()
|
||||
.await
|
||||
.inspect_err(|error| {
|
||||
warn!(?error, "error processing scram messages");
|
||||
})
|
||||
})
|
||||
let auth_outcome = tokio::time::timeout(
|
||||
config.scram_protocol_timeout,
|
||||
AuthFlow::new(client, auth::Scram(&secret, ctx)).authenticate(),
|
||||
)
|
||||
.await
|
||||
.inspect_err(|_| warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs()))
|
||||
.map_err(auth::AuthError::user_timeout)??;
|
||||
.map_err(auth::AuthError::user_timeout)?
|
||||
.inspect_err(|error| warn!(?error, "error processing scram messages"))?;
|
||||
|
||||
let client_key = match auth_outcome {
|
||||
sasl::Outcome::Success(key) => key,
|
||||
|
||||
@@ -159,7 +159,7 @@ pub async fn task_main(
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
config: &'static ProxyConfig,
|
||||
backend: &'static ConsoleRedirectBackend,
|
||||
ctx: &RequestContext,
|
||||
|
||||
@@ -7,7 +7,9 @@ use std::time::Duration;
|
||||
|
||||
use ::http::HeaderName;
|
||||
use ::http::header::AUTHORIZATION;
|
||||
use bytes::Bytes;
|
||||
use futures::TryFutureExt;
|
||||
use hyper::StatusCode;
|
||||
use postgres_client::config::SslMode;
|
||||
use tokio::time::Instant;
|
||||
use tracing::{Instrument, debug, info, info_span, warn};
|
||||
@@ -72,28 +74,34 @@ impl NeonControlPlaneClient {
|
||||
role: &RoleName,
|
||||
) -> Result<AuthInfo, GetAuthInfoError> {
|
||||
async {
|
||||
let request = self
|
||||
.endpoint
|
||||
.get_path("get_endpoint_access_control")
|
||||
.header(X_REQUEST_ID, ctx.session_id().to_string())
|
||||
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
|
||||
.query(&[("session_id", ctx.session_id())])
|
||||
.query(&[
|
||||
("application_name", ctx.console_application_name().as_str()),
|
||||
("endpointish", endpoint.as_str()),
|
||||
("role", role.as_str()),
|
||||
])
|
||||
.build()?;
|
||||
|
||||
debug!(url = request.url().as_str(), "sending http request");
|
||||
let start = Instant::now();
|
||||
let response = {
|
||||
let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane);
|
||||
self.endpoint.execute(request).await?
|
||||
};
|
||||
info!(duration = ?start.elapsed(), "received http response");
|
||||
let request = self
|
||||
.endpoint
|
||||
.get_path("get_endpoint_access_control")
|
||||
.header(X_REQUEST_ID, ctx.session_id().to_string())
|
||||
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
|
||||
.query(&[("session_id", ctx.session_id())])
|
||||
.query(&[
|
||||
("application_name", ctx.console_application_name().as_str()),
|
||||
("endpointish", endpoint.as_str()),
|
||||
("role", role.as_str()),
|
||||
])
|
||||
.build()?;
|
||||
|
||||
let body = match parse_body::<GetEndpointAccessControl>(response).await {
|
||||
debug!(url = request.url().as_str(), "sending http request");
|
||||
let start = Instant::now();
|
||||
let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane);
|
||||
let response = self.endpoint.execute(request).await?;
|
||||
|
||||
info!(duration = ?start.elapsed(), "received http response");
|
||||
|
||||
response
|
||||
};
|
||||
|
||||
let body = match parse_body::<GetEndpointAccessControl>(
|
||||
response.status(),
|
||||
response.bytes().await?,
|
||||
) {
|
||||
Ok(body) => body,
|
||||
// Error 404 is special: it's ok not to have a secret.
|
||||
// TODO(anna): retry
|
||||
@@ -184,7 +192,10 @@ impl NeonControlPlaneClient {
|
||||
drop(pause);
|
||||
info!(duration = ?start.elapsed(), "received http response");
|
||||
|
||||
let body = parse_body::<EndpointJwksResponse>(response).await?;
|
||||
let body = parse_body::<EndpointJwksResponse>(
|
||||
response.status(),
|
||||
response.bytes().await.map_err(ControlPlaneError::from)?,
|
||||
)?;
|
||||
|
||||
let rules = body
|
||||
.jwks
|
||||
@@ -236,7 +247,7 @@ impl NeonControlPlaneClient {
|
||||
let response = self.endpoint.execute(request).await?;
|
||||
drop(pause);
|
||||
info!(duration = ?start.elapsed(), "received http response");
|
||||
let body = parse_body::<WakeCompute>(response).await?;
|
||||
let body = parse_body::<WakeCompute>(response.status(), response.bytes().await?)?;
|
||||
|
||||
// Unfortunately, ownership won't let us use `Option::ok_or` here.
|
||||
let (host, port) = match parse_host_port(&body.address) {
|
||||
@@ -487,33 +498,33 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
}
|
||||
|
||||
/// Parse http response body, taking status code into account.
|
||||
async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
|
||||
response: http::Response,
|
||||
fn parse_body<T: for<'a> serde::Deserialize<'a>>(
|
||||
status: StatusCode,
|
||||
body: Bytes,
|
||||
) -> Result<T, ControlPlaneError> {
|
||||
let status = response.status();
|
||||
if status.is_success() {
|
||||
// We shouldn't log raw body because it may contain secrets.
|
||||
info!("request succeeded, processing the body");
|
||||
return Ok(response.json().await?);
|
||||
return Ok(serde_json::from_slice(&body).map_err(std::io::Error::other)?);
|
||||
}
|
||||
let s = response.bytes().await?;
|
||||
|
||||
// Log plaintext to be able to detect, whether there are some cases not covered by the error struct.
|
||||
info!("response_error plaintext: {:?}", s);
|
||||
info!("response_error plaintext: {:?}", body);
|
||||
|
||||
// Don't throw an error here because it's not as important
|
||||
// as the fact that the request itself has failed.
|
||||
let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| {
|
||||
let mut body = serde_json::from_slice(&body).unwrap_or_else(|e| {
|
||||
warn!("failed to parse error body: {e}");
|
||||
ControlPlaneErrorMessage {
|
||||
Box::new(ControlPlaneErrorMessage {
|
||||
error: "reason unclear (malformed error message)".into(),
|
||||
http_status_code: status,
|
||||
status: None,
|
||||
}
|
||||
})
|
||||
});
|
||||
body.http_status_code = status;
|
||||
|
||||
warn!("console responded with an error ({status}): {body:?}");
|
||||
Err(ControlPlaneError::Message(Box::new(body)))
|
||||
Err(ControlPlaneError::Message(body))
|
||||
}
|
||||
|
||||
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
|
||||
|
||||
@@ -4,9 +4,10 @@
|
||||
|
||||
pub mod health_server;
|
||||
|
||||
use std::time::Duration;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures::FutureExt;
|
||||
use http::Method;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::body::Body;
|
||||
@@ -109,15 +110,31 @@ impl Endpoint {
|
||||
}
|
||||
|
||||
/// Execute a [request](reqwest::Request).
|
||||
pub(crate) async fn execute(&self, request: Request) -> Result<Response, Error> {
|
||||
let _timer = Metrics::get()
|
||||
pub(crate) fn execute(
|
||||
&self,
|
||||
request: Request,
|
||||
) -> impl Future<Output = Result<Response, Error>> {
|
||||
let metric = Metrics::get()
|
||||
.proxy
|
||||
.console_request_latency
|
||||
.start_timer(ConsoleRequest {
|
||||
.with_labels(ConsoleRequest {
|
||||
request: request.url().path(),
|
||||
});
|
||||
|
||||
self.client.execute(request).await
|
||||
let req = self.client.execute(request).boxed();
|
||||
|
||||
async move {
|
||||
let start = Instant::now();
|
||||
scopeguard::defer!({
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.console_request_latency
|
||||
.get_metric(metric)
|
||||
.observe_duration_since(start);
|
||||
});
|
||||
|
||||
req.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -186,7 +186,7 @@ where
|
||||
pub async fn read_message<'a, S>(
|
||||
stream: &mut S,
|
||||
buf: &'a mut Vec<u8>,
|
||||
max: usize,
|
||||
max: u32,
|
||||
) -> io::Result<(u8, &'a mut [u8])>
|
||||
where
|
||||
S: AsyncRead + Unpin,
|
||||
@@ -206,7 +206,7 @@ where
|
||||
let header = read!(stream => Header);
|
||||
|
||||
// as described above, the length must be at least 4.
|
||||
let Some(len) = (header.len.get() as usize).checked_sub(4) else {
|
||||
let Some(len) = header.len.get().checked_sub(4) else {
|
||||
return Err(io::Error::other(format!(
|
||||
"invalid startup message length {}, must be at least 4.",
|
||||
header.len,
|
||||
@@ -222,7 +222,7 @@ where
|
||||
}
|
||||
|
||||
// read in our entire message.
|
||||
buf.resize(len, 0);
|
||||
buf.resize(len as usize, 0);
|
||||
stream.read_exact(buf).await?;
|
||||
|
||||
Ok((header.tag, buf))
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{debug, info, warn};
|
||||
@@ -57,7 +58,7 @@ pub(crate) enum HandshakeData<S> {
|
||||
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
|
||||
/// we also take an extra care of propagating only the select handshake errors to client.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
ctx: &RequestContext,
|
||||
stream: S,
|
||||
mut tls: Option<&TlsConfig>,
|
||||
@@ -108,7 +109,9 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
})
|
||||
.map_ok(Box::new)
|
||||
.boxed();
|
||||
|
||||
res?;
|
||||
|
||||
@@ -146,7 +149,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
tls.cert_resolver.resolve(conn_info.server_name());
|
||||
|
||||
let tls = Stream::Tls {
|
||||
tls: Box::new(tls_stream),
|
||||
tls: tls_stream,
|
||||
tls_server_end_point,
|
||||
};
|
||||
(stream, msg) = PqStream::parse_startup(tls).await?;
|
||||
|
||||
@@ -270,7 +270,7 @@ impl ReportableError for ClientRequestError {
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static auth::Backend<'static, ()>,
|
||||
ctx: &RequestContext,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use futures::FutureExt;
|
||||
use smol_str::SmolStr;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::debug;
|
||||
@@ -89,6 +90,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
|
||||
.compute
|
||||
.cancel_closure
|
||||
.try_cancel_query(compute_config)
|
||||
.boxed()
|
||||
.await
|
||||
{
|
||||
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
|
||||
|
||||
@@ -30,52 +30,53 @@ where
|
||||
F: FnOnce(&str) -> super::Result<M>,
|
||||
M: Mechanism,
|
||||
{
|
||||
let sasl = {
|
||||
let (mut mechanism, mut input) = {
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
// Initial client message contains the chosen auth method's name.
|
||||
let msg = stream.read_password_message().await?;
|
||||
super::FirstMessage::parse(msg).ok_or(super::Error::BadClientMessage("bad sasl message"))?
|
||||
|
||||
let sasl = super::FirstMessage::parse(msg)
|
||||
.ok_or(super::Error::BadClientMessage("bad sasl message"))?;
|
||||
|
||||
(mechanism(sasl.method)?, sasl.message)
|
||||
};
|
||||
|
||||
let mut mechanism = mechanism(sasl.method)?;
|
||||
let mut input = sasl.message;
|
||||
loop {
|
||||
let step = mechanism
|
||||
.exchange(input)
|
||||
.inspect_err(|error| tracing::info!(?error, "error during SASL exchange"))?;
|
||||
|
||||
match step {
|
||||
Step::Continue(moved_mechanism, reply) => {
|
||||
match mechanism.exchange(input) {
|
||||
Ok(Step::Continue(moved_mechanism, reply)) => {
|
||||
mechanism = moved_mechanism;
|
||||
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
// write reply
|
||||
let sasl_msg = BeAuthenticationSaslMessage::Continue(reply.as_bytes());
|
||||
stream.write_message(BeMessage::AuthenticationSasl(sasl_msg));
|
||||
|
||||
// get next input
|
||||
stream.flush().await?;
|
||||
let msg = stream.read_password_message().await?;
|
||||
input = std::str::from_utf8(msg)
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
|
||||
drop(reply);
|
||||
}
|
||||
Step::Success(result, reply) => {
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
Ok(Step::Success(result, reply)) => {
|
||||
// write reply
|
||||
let sasl_msg = BeAuthenticationSaslMessage::Final(reply.as_bytes());
|
||||
stream.write_message(BeMessage::AuthenticationSasl(sasl_msg));
|
||||
stream.write_message(BeMessage::AuthenticationOk);
|
||||
|
||||
// exit with success
|
||||
break Ok(Outcome::Success(result));
|
||||
}
|
||||
// exit with failure
|
||||
Step::Failure(reason) => break Ok(Outcome::Failure(reason)),
|
||||
Ok(Step::Failure(reason)) => break Ok(Outcome::Failure(reason)),
|
||||
Err(error) => {
|
||||
tracing::info!(?error, "error during SASL exchange");
|
||||
return Err(error);
|
||||
}
|
||||
}
|
||||
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
// get next input
|
||||
stream.flush().await?;
|
||||
let msg = stream.read_password_message().await?;
|
||||
input = std::str::from_utf8(msg)
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,7 +72,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> PqStream<S> {
|
||||
impl<S: AsyncRead + Unpin> PqStream<S> {
|
||||
/// Read a raw postgres packet, which will respect the max length requested.
|
||||
/// This is not cancel safe.
|
||||
async fn read_raw_expect(&mut self, tag: u8, max: usize) -> io::Result<&mut [u8]> {
|
||||
async fn read_raw_expect(&mut self, tag: u8, max: u32) -> io::Result<&mut [u8]> {
|
||||
let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?;
|
||||
if actual_tag != tag {
|
||||
return Err(io::Error::other(format!(
|
||||
@@ -89,7 +89,7 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
|
||||
// passwords are usually pretty short
|
||||
// and SASL SCRAM messages are no longer than 256 bytes in my testing
|
||||
// (a few hashes and random bytes, encoded into base64).
|
||||
const MAX_PASSWORD_LENGTH: usize = 512;
|
||||
const MAX_PASSWORD_LENGTH: u32 = 512;
|
||||
self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -31,7 +31,9 @@ mod private {
|
||||
type Output = io::Result<RustlsStream<S>>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream)
|
||||
Pin::new(&mut self.inner)
|
||||
.poll(cx)
|
||||
.map_ok(|s| RustlsStream(Box::new(s)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,7 +59,7 @@ mod private {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RustlsStream<S>(TlsStream<S>);
|
||||
pub struct RustlsStream<S>(Box<TlsStream<S>>);
|
||||
|
||||
impl<S> postgres_client::tls::TlsStream for RustlsStream<S>
|
||||
where
|
||||
|
||||
Reference in New Issue
Block a user