[proxy] move read_info from the compute connection to be as late as possible (#12660)

Second attempt at #12130, now with a smaller diff.

This allows us to skip allocating for messages like parameter status and
notices, which we either forward to the client untouched or discard.

LKB-2494
Author: Conrad Ludgate
Date: 2025-07-23 14:33:21 +01:00
Committed by: GitHub
Parent: 94cb9a79d9
Commit: 761e9e0e1d
12 changed files with 276 additions and 227 deletions
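For orientation before the diffs: the change stops collecting the compute's startup messages into intermediate structures (a parameter map, a notice vector) and instead relays or drops each message as it is read, keeping only the cancel key data. A minimal sketch of that pattern, using hypothetical stand-in types (`BackendMsg`, `relay_startup`) rather than the proxy's real `StartupStream`/`PqStream` API:

// Sketch only: a stand-in for the decoded backend messages the proxy
// reads from compute during connection startup.
enum BackendMsg<'a> {
    // Forwarded to the client untouched, so no owned allocation is needed.
    ParameterStatus(&'a [u8]),
    NoticeResponse(&'a [u8]),
    // The only values the proxy actually keeps: they back the cancel key.
    BackendKeyData { process_id: i32, secret_key: i32 },
    ReadyForQuery,
}

/// Relay startup messages until ReadyForQuery, returning the compute's
/// cancel key instead of buffering parameters and notices.
fn relay_startup(msgs: &[BackendMsg<'_>], client_buf: &mut Vec<u8>) -> Option<(i32, i32)> {
    let mut key = None;
    for msg in msgs {
        match msg {
            BackendMsg::ParameterStatus(raw) | BackendMsg::NoticeResponse(raw) => {
                // Copy the raw wire bytes straight into the client's write buffer.
                client_buf.extend_from_slice(raw);
            }
            BackendMsg::BackendKeyData { process_id, secret_key } => {
                key = Some((*process_id, *secret_key));
            }
            BackendMsg::ReadyForQuery => break,
        }
    }
    key
}

The real version of this loop is `forward_compute_params_to_client` in the diff below, which replaces the removed `PostgresSettings` struct.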

View File

@@ -429,26 +429,13 @@ impl CancellationHandler {
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CancelClosure {
socket_addr: SocketAddr,
cancel_token: RawCancelToken,
hostname: String, // for pg_sni router
user_info: ComputeUserInfo,
pub socket_addr: SocketAddr,
pub cancel_token: RawCancelToken,
pub hostname: String, // for pg_sni router
pub user_info: ComputeUserInfo,
}
impl CancelClosure {
pub(crate) fn new(
socket_addr: SocketAddr,
cancel_token: RawCancelToken,
hostname: String,
user_info: ComputeUserInfo,
) -> Self {
Self {
socket_addr,
cancel_token,
hostname,
user_info,
}
}
/// Cancels the query running on user's compute node.
pub(crate) async fn try_cancel_query(
&self,

View File

@@ -7,17 +7,15 @@ use std::net::{IpAddr, SocketAddr};
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use postgres_client::config::{AuthKeys, ChannelBinding, SslMode};
use postgres_client::connect_raw::StartupStream;
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use postgres_client::tls::MakeTlsConnect;
use postgres_client::{NoTls, RawCancelToken, RawConnection};
use postgres_protocol::message::backend::NoticeResponseBody;
use thiserror::Error;
use tokio::net::{TcpStream, lookup_host};
use tracing::{debug, error, info, warn};
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::backend::ComputeCredentialKeys;
use crate::auth::parse_endpoint_param;
use crate::cancellation::CancelClosure;
use crate::compute::tls::TlsError;
use crate::config::ComputeConfig;
use crate::context::RequestContext;
@@ -236,8 +234,7 @@ impl AuthInfo {
&self,
ctx: &RequestContext,
compute: &mut ComputeConnection,
user_info: &ComputeUserInfo,
) -> Result<PostgresSettings, PostgresError> {
) -> Result<(), PostgresError> {
// client config with stubbed connect info.
// TODO(conrad): should we rewrite this to bypass tokio-postgres2 entirely,
// utilising pqproto.rs.
@@ -247,39 +244,10 @@ impl AuthInfo {
let tmp_config = self.enrich(tmp_config);
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let connection = tmp_config
.tls_and_authenticate(&mut compute.stream, NoTls)
.await?;
tmp_config.authenticate(&mut compute.stream).await?;
drop(pause);
let RawConnection {
stream: _,
parameters,
delayed_notice,
process_id,
secret_key,
} = connection;
tracing::Span::current().record("pid", tracing::field::display(process_id));
// NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
// Yet another reason to rework the connection establishing code.
let cancel_closure = CancelClosure::new(
compute.socket_addr,
RawCancelToken {
ssl_mode: compute.ssl_mode,
process_id,
secret_key,
},
compute.hostname.to_string(),
user_info.clone(),
);
Ok(PostgresSettings {
params: parameters,
cancel_closure,
delayed_notice,
})
Ok(())
}
}
@@ -343,21 +311,9 @@ impl ConnectInfo {
pub type RustlsStream = <ComputeConfig as MakeTlsConnect<tokio::net::TcpStream>>::Stream;
pub type MaybeRustlsStream = MaybeTlsStream<tokio::net::TcpStream, RustlsStream>;
// TODO(conrad): we don't need to parse these.
// These are just immediately forwarded back to the client.
// We could instead stream them out instead of reading them into memory.
pub struct PostgresSettings {
/// PostgreSQL connection parameters.
pub params: std::collections::HashMap<String, String>,
/// Query cancellation token.
pub cancel_closure: CancelClosure,
/// Notices received from compute after authenticating
pub delayed_notice: Vec<NoticeResponseBody>,
}
pub struct ComputeConnection {
/// Socket connected to a compute node.
pub stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
pub stream: StartupStream<tokio::net::TcpStream, RustlsStream>,
/// Labels for proxy's metrics.
pub aux: MetricsAuxInfo,
pub hostname: Host,
@@ -390,6 +346,7 @@ impl ConnectInfo {
ctx.get_testodrome_id().unwrap_or_default(),
);
let stream = StartupStream::new(stream);
let connection = ComputeConnection {
stream,
socket_addr,

View File

@@ -1,12 +1,13 @@
use std::sync::Arc;
use futures::{FutureExt, TryFutureExt};
use postgres_client::RawCancelToken;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info};
use crate::auth::backend::ConsoleRedirectBackend;
use crate::cancellation::CancellationHandler;
use crate::cancellation::{CancelClosure, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::context::RequestContext;
use crate::error::ReportableError;
@@ -16,7 +17,7 @@ use crate::pglb::handshake::{HandshakeData, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::{ErrorSource, finish_client_init};
use crate::proxy::{ErrorSource, forward_compute_params_to_client, send_client_greeting};
use crate::util::run_until_cancelled;
pub async fn task_main(
@@ -226,21 +227,19 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
.or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
.await?;
let pg_settings = auth_info
.authenticate(ctx, &mut node, &user_info)
auth_info
.authenticate(ctx, &mut node)
.or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
.await?;
send_client_greeting(ctx, &config.greetings, &mut stream);
let session = cancellation_handler.get_key();
finish_client_init(
ctx,
&pg_settings,
*session.key(),
&mut stream,
&config.greetings,
);
let (process_id, secret_key) =
forward_compute_params_to_client(ctx, *session.key(), &mut stream, &mut node.stream)
.await?;
let stream = stream.flush_and_into_inner().await?;
let hostname = node.hostname.to_string();
let session_id = ctx.session_id();
let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel();
@@ -249,7 +248,16 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
.maintain_cancel_key(
session_id,
cancel,
&pg_settings.cancel_closure,
&CancelClosure {
socket_addr: node.socket_addr,
cancel_token: RawCancelToken {
ssl_mode: node.ssl_mode,
process_id,
secret_key,
},
hostname,
user_info,
},
&config.connect_to_compute,
)
.await;
@@ -257,7 +265,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
Ok(Some(ProxyPassthrough {
client: stream,
compute: node.stream,
compute: node.stream.into_framed().into_inner(),
aux: node.aux,
private_link_id: None,

View File

@@ -319,7 +319,7 @@ pub(crate) async fn handle_connection<S: AsyncRead + AsyncWrite + Unpin + Send>(
Ok(Some(ProxyPassthrough {
client,
compute: node.stream,
compute: node.stream.into_framed().into_inner(),
aux: node.aux,
private_link_id,

View File

@@ -313,6 +313,14 @@ impl WriteBuf {
self.0.set_position(0);
}
/// Shrinks the buffer if efficient to do so, and returns the remaining size.
pub fn occupied_len(&mut self) -> usize {
if self.should_shrink() {
self.shrink();
}
self.0.get_mut().len()
}
/// Write a raw message to the internal buffer.
///
/// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since

View File

@@ -9,18 +9,23 @@ use std::collections::HashSet;
use std::convert::Infallible;
use std::sync::Arc;
use futures::TryStreamExt;
use itertools::Itertools;
use once_cell::sync::OnceCell;
use postgres_client::RawCancelToken;
use postgres_client::connect_raw::StartupStream;
use postgres_protocol::message::backend::Message;
use regex::Regex;
use serde::{Deserialize, Serialize};
use smol_str::{SmolStr, format_smolstr};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::sync::oneshot;
use tracing::Instrument;
use crate::cache::Cache;
use crate::cancellation::CancellationHandler;
use crate::compute::ComputeConnection;
use crate::cancellation::{CancelClosure, CancellationHandler};
use crate::compute::{ComputeConnection, PostgresError, RustlsStream};
use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ControlPlaneClient;
@@ -105,7 +110,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
// the compute was cached, and we connected, but the compute cache was actually stale
// and is associated with the wrong endpoint. We detect this when the **authentication** fails.
// As such, we retry once here if the `authenticate` function fails and the error is valid to retry.
let pg_settings = loop {
loop {
attempt += 1;
// TODO: callback to pglb
@@ -127,9 +132,12 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
unreachable!("ensured above");
};
let res = auth_info.authenticate(ctx, &mut node, user_info).await;
let res = auth_info.authenticate(ctx, &mut node).await;
match res {
Ok(pg_settings) => break pg_settings,
Ok(()) => {
send_client_greeting(ctx, &config.greetings, client);
break;
}
Err(e) if attempt < 2 && e.should_retry_wake_compute() => {
tracing::warn!(error = ?e, "retrying wake compute");
@@ -141,11 +149,17 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
}
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
}
}
let auth::Backend::ControlPlane(_, user_info) = backend else {
unreachable!("ensured above");
};
let session = cancellation_handler.get_key();
finish_client_init(ctx, &pg_settings, *session.key(), client, &config.greetings);
let (process_id, secret_key) =
forward_compute_params_to_client(ctx, *session.key(), client, &mut node.stream).await?;
let hostname = node.hostname.to_string();
let session_id = ctx.session_id();
let (cancel_on_shutdown, cancel) = oneshot::channel();
@@ -154,7 +168,16 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
.maintain_cancel_key(
session_id,
cancel,
&pg_settings.cancel_closure,
&CancelClosure {
socket_addr: node.socket_addr,
cancel_token: RawCancelToken {
ssl_mode: node.ssl_mode,
process_id,
secret_key,
},
hostname,
user_info,
},
&config.connect_to_compute,
)
.await;
@@ -163,35 +186,18 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
Ok((node, cancel_on_shutdown))
}
/// Finish client connection initialization: confirm auth success, send params, etc.
pub(crate) fn finish_client_init(
/// Greet the client with any useful information.
pub(crate) fn send_client_greeting(
ctx: &RequestContext,
settings: &compute::PostgresSettings,
cancel_key_data: CancelKeyData,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
greetings: &String,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) {
// Forward all deferred notices to the client.
for notice in &settings.delayed_notice {
client.write_raw(notice.as_bytes().len(), b'N', |buf| {
buf.extend_from_slice(notice.as_bytes());
});
}
// Expose session_id to clients if we have a greeting message.
if !greetings.is_empty() {
let session_msg = format!("{}, session_id: {}", greetings, ctx.session_id());
client.write_message(BeMessage::NoticeResponse(session_msg.as_str()));
}
// Forward all postgres connection params to the client.
for (name, value) in &settings.params {
client.write_message(BeMessage::ParameterStatus {
name: name.as_bytes(),
value: value.as_bytes(),
});
}
// Forward recorded latencies for probing requests
if let Some(testodrome_id) = ctx.get_testodrome_id() {
client.write_message(BeMessage::ParameterStatus {
@@ -221,9 +227,63 @@ pub(crate) fn finish_client_init(
value: latency_measured.retry.as_micros().to_string().as_bytes(),
});
}
}
client.write_message(BeMessage::BackendKeyData(cancel_key_data));
client.write_message(BeMessage::ReadyForQuery);
pub(crate) async fn forward_compute_params_to_client(
ctx: &RequestContext,
cancel_key_data: CancelKeyData,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
compute: &mut StartupStream<TcpStream, RustlsStream>,
) -> Result<(i32, i32), ClientRequestError> {
let mut process_id = 0;
let mut secret_key = 0;
let err = loop {
// if the client buffer is too large, let's write out some bytes now to save some space
client.write_if_full().await?;
let msg = match compute.try_next().await {
Ok(msg) => msg,
Err(e) => break postgres_client::Error::io(e),
};
match msg {
// Send our cancellation key data instead.
Some(Message::BackendKeyData(body)) => {
client.write_message(BeMessage::BackendKeyData(cancel_key_data));
process_id = body.process_id();
secret_key = body.secret_key();
}
// Forward all postgres connection params to the client.
Some(Message::ParameterStatus(body)) => {
if let Ok(name) = body.name()
&& let Ok(value) = body.value()
{
client.write_message(BeMessage::ParameterStatus {
name: name.as_bytes(),
value: value.as_bytes(),
});
}
}
// Forward all notices to the client.
Some(Message::NoticeResponse(notice)) => {
client.write_raw(notice.as_bytes().len(), b'N', |buf| {
buf.extend_from_slice(notice.as_bytes());
});
}
Some(Message::ReadyForQuery(_)) => {
client.write_message(BeMessage::ReadyForQuery);
return Ok((process_id, secret_key));
}
Some(Message::ErrorResponse(body)) => break postgres_client::Error::db(body),
Some(_) => break postgres_client::Error::unexpected_message(),
None => break postgres_client::Error::closed(),
}
};
Err(client
.throw_error(PostgresError::Postgres(err), Some(ctx))
.await)?
}
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]

View File

@@ -154,6 +154,15 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
message.write_message(&mut self.write);
}
/// Write the buffer to the socket until we have some more space again.
pub async fn write_if_full(&mut self) -> io::Result<()> {
while self.write.occupied_len() > 2048 {
self.stream.write_buf(&mut self.write).await?;
}
Ok(())
}
/// Flush the output buffer into the underlying stream.
///
/// This is cancel safe.
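
A note on the two buffer helpers added in this commit: `PqStream::write_if_full` uses `WriteBuf::occupied_len` to drain the client-side write buffer once it grows past 2048 bytes, so a compute that sends many parameters or notices during startup cannot grow the proxy's buffer without bound. A rough sketch of that backpressure rule over a plain tokio writer (the name `drain_if_full` and the explicit `threshold` parameter are illustrative, not from the codebase):

use tokio::io::{AsyncWrite, AsyncWriteExt};

// Sketch: once the pending output exceeds a threshold, write it out
// before buffering any more forwarded messages.
async fn drain_if_full<S: AsyncWrite + Unpin>(
    socket: &mut S,
    buf: &mut Vec<u8>,
    threshold: usize, // the real helper uses 2048 bytes
) -> std::io::Result<()> {
    if buf.len() > threshold {
        // Flush everything accumulated so far...
        socket.write_all(&buf[..]).await?;
        // ...and reuse the allocation for the next batch of messages.
        buf.clear();
    }
    Ok(())
}

The actual implementation instead loops on `write_buf` until `occupied_len` drops back under the threshold, which also lets `WriteBuf` shrink its allocation when that is efficient to do.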