From 761e9e0e1d5d9eed826acf66c1c03be9cb5ddbd4 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Wed, 23 Jul 2025 14:33:21 +0100 Subject: [PATCH] [proxy] move `read_info` from the compute connection to be as late as possible (#12660) Second attempt at #12130, now with a smaller diff. This allows us to skip allocating for things like parameter status and notices that we will either just forward untouched, or discard. LKB-2494 --- libs/proxy/tokio-postgres2/src/config.rs | 19 +- libs/proxy/tokio-postgres2/src/connect.rs | 43 +++-- libs/proxy/tokio-postgres2/src/connect_raw.rs | 181 +++++++++--------- libs/proxy/tokio-postgres2/src/error/mod.rs | 8 +- libs/proxy/tokio-postgres2/src/lib.rs | 3 +- proxy/src/cancellation.rs | 21 +- proxy/src/compute/mod.rs | 57 +----- proxy/src/console_redirect_proxy.rs | 34 ++-- proxy/src/pglb/mod.rs | 2 +- proxy/src/pqproto.rs | 8 + proxy/src/proxy/mod.rs | 118 +++++++++--- proxy/src/stream.rs | 9 + 12 files changed, 276 insertions(+), 227 deletions(-) diff --git a/libs/proxy/tokio-postgres2/src/config.rs b/libs/proxy/tokio-postgres2/src/config.rs index 961cbc923e..c619f92d13 100644 --- a/libs/proxy/tokio-postgres2/src/config.rs +++ b/libs/proxy/tokio-postgres2/src/config.rs @@ -11,9 +11,8 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; use crate::connect::connect; -use crate::connect_raw::{RawConnection, connect_raw}; +use crate::connect_raw::{self, StartupStream}; use crate::connect_tls::connect_tls; -use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::{MakeTlsConnect, TlsConnect, TlsStream}; use crate::{Client, Connection, Error}; @@ -244,24 +243,26 @@ impl Config { &self, stream: S, tls: T, - ) -> Result, Error> + ) -> Result, Error> where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { let stream = connect_tls(stream, self.ssl_mode, tls).await?; - connect_raw(stream, self).await + let mut stream = StartupStream::new(stream); + connect_raw::startup(&mut stream, self).await?; + connect_raw::authenticate(&mut stream, self).await?; + + Ok(stream) } - pub async fn authenticate( - &self, - stream: MaybeTlsStream, - ) -> Result, Error> + pub async fn authenticate(&self, stream: &mut StartupStream) -> Result<(), Error> where S: AsyncRead + AsyncWrite + Unpin, T: TlsStream + Unpin, { - connect_raw(stream, self).await + connect_raw::startup(stream, self).await?; + connect_raw::authenticate(stream, self).await } } diff --git a/libs/proxy/tokio-postgres2/src/connect.rs b/libs/proxy/tokio-postgres2/src/connect.rs index 2f718e1e7d..41d95c5f84 100644 --- a/libs/proxy/tokio-postgres2/src/connect.rs +++ b/libs/proxy/tokio-postgres2/src/connect.rs @@ -1,15 +1,17 @@ use std::net::IpAddr; +use futures_util::TryStreamExt; +use postgres_protocol2::message::backend::Message; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; use tokio::sync::mpsc; use crate::client::SocketConfig; use crate::config::Host; -use crate::connect_raw::connect_raw; +use crate::connect_raw::StartupStream; use crate::connect_socket::connect_socket; -use crate::connect_tls::connect_tls; use crate::tls::{MakeTlsConnect, TlsConnect}; -use crate::{Client, Config, Connection, Error, RawConnection}; +use crate::{Client, Config, Connection, Error}; pub async fn connect( tls: &T, @@ -43,14 +45,8 @@ where T: TlsConnect, { let socket = connect_socket(host_addr, host, port, config.connect_timeout).await?; - let stream = connect_tls(socket, config.ssl_mode, tls).await?; - let RawConnection { - stream, - parameters: _, - delayed_notice: _, - process_id, - secret_key, - } = connect_raw(stream, config).await?; + let mut stream = config.tls_and_authenticate(socket, tls).await?; + let (process_id, secret_key) = wait_until_ready(&mut stream).await?; let socket_config = SocketConfig { host_addr, @@ -70,7 +66,32 @@ where secret_key, ); + let stream = stream.into_framed(); let connection = Connection::new(stream, conn_tx, conn_rx); Ok((client, connection)) } + +async fn wait_until_ready(stream: &mut StartupStream) -> Result<(i32, i32), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + let mut process_id = 0; + let mut secret_key = 0; + + loop { + match stream.try_next().await.map_err(Error::io)? { + Some(Message::BackendKeyData(body)) => { + process_id = body.process_id(); + secret_key = body.secret_key(); + } + // These values are currently not used by `Client`/`Connection`. Ignore them. + Some(Message::ParameterStatus(_)) | Some(Message::NoticeResponse(_)) => {} + Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key)), + Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), + Some(_) => return Err(Error::unexpected_message()), + None => return Err(Error::closed()), + } + } +} diff --git a/libs/proxy/tokio-postgres2/src/connect_raw.rs b/libs/proxy/tokio-postgres2/src/connect_raw.rs index 462e1be1aa..bc35cef339 100644 --- a/libs/proxy/tokio-postgres2/src/connect_raw.rs +++ b/libs/proxy/tokio-postgres2/src/connect_raw.rs @@ -1,28 +1,26 @@ -use std::collections::HashMap; use std::io; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::task::{Context, Poll, ready}; use bytes::{Bytes, BytesMut}; use fallible_iterator::FallibleIterator; -use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready}; +use futures_util::{Sink, SinkExt, Stream, TryStreamExt}; use postgres_protocol2::authentication::sasl; use postgres_protocol2::authentication::sasl::ScramSha256; -use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody}; +use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message}; use postgres_protocol2::message::frontend; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_util::codec::Framed; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_util::codec::{Framed, FramedParts, FramedWrite}; use crate::Error; -use crate::codec::{BackendMessage, BackendMessages, PostgresCodec}; +use crate::codec::PostgresCodec; use crate::config::{self, AuthKeys, Config}; use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::TlsStream; pub struct StartupStream { - inner: Framed, PostgresCodec>, - buf: BackendMessages, - delayed_notice: Vec, + inner: FramedWrite, PostgresCodec>, + read_buf: BytesMut, } impl Sink for StartupStream @@ -56,63 +54,93 @@ where { type Item = io::Result; - fn poll_next( - mut self: Pin<&mut Self>, + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // read 1 byte tag, 4 bytes length. + let header = ready!(self.as_mut().poll_fill_buf_exact(cx, 5)?); + + let len = u32::from_be_bytes(header[1..5].try_into().unwrap()); + if len < 4 { + return Poll::Ready(Some(Err(std::io::Error::other( + "postgres message too small", + )))); + } + if len >= 65536 { + return Poll::Ready(Some(Err(std::io::Error::other( + "postgres message too large", + )))); + } + + // the tag is an additional byte. + let _message = ready!(self.as_mut().poll_fill_buf_exact(cx, len as usize + 1)?); + + // Message::parse will remove the all the bytes from the buffer. + Poll::Ready(Message::parse(&mut self.read_buf).transpose()) + } +} + +impl StartupStream +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + /// Fill the buffer until it's the exact length provided. No additional data will be read from the socket. + /// + /// If the current buffer length is greater, nothing happens. + fn poll_fill_buf_exact( + self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { - loop { - match self.buf.next() { - Ok(Some(message)) => return Poll::Ready(Some(Ok(message))), - Ok(None) => {} - Err(e) => return Poll::Ready(Some(Err(e))), + len: usize, + ) -> Poll> { + let this = self.get_mut(); + let mut stream = Pin::new(this.inner.get_mut()); + + let mut n = this.read_buf.len(); + while n < len { + this.read_buf.resize(len, 0); + + let mut buf = ReadBuf::new(&mut this.read_buf[..]); + buf.set_filled(n); + + if stream.as_mut().poll_read(cx, &mut buf)?.is_pending() { + this.read_buf.truncate(n); + return Poll::Pending; } - match ready!(Pin::new(&mut self.inner).poll_next(cx)) { - Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages, - Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))), - Some(Err(e)) => return Poll::Ready(Some(Err(e))), - None => return Poll::Ready(None), + if buf.filled().len() == n { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "early eof", + ))); } + n = buf.filled().len(); + + this.read_buf.truncate(n); + } + + Poll::Ready(Ok(&this.read_buf[..len])) + } + + pub fn into_framed(mut self) -> Framed, PostgresCodec> { + let write_buf = std::mem::take(self.inner.write_buffer_mut()); + let io = self.inner.into_inner(); + let mut parts = FramedParts::new(io, PostgresCodec); + parts.read_buf = self.read_buf; + parts.write_buf = write_buf; + Framed::from_parts(parts) + } + + pub fn new(io: MaybeTlsStream) -> Self { + Self { + inner: FramedWrite::new(io, PostgresCodec), + read_buf: BytesMut::new(), } } } -pub struct RawConnection { - pub stream: Framed, PostgresCodec>, - pub parameters: HashMap, - pub delayed_notice: Vec, - pub process_id: i32, - pub secret_key: i32, -} - -pub async fn connect_raw( - stream: MaybeTlsStream, +pub(crate) async fn startup( + stream: &mut StartupStream, config: &Config, -) -> Result, Error> -where - S: AsyncRead + AsyncWrite + Unpin, - T: TlsStream + Unpin, -{ - let mut stream = StartupStream { - inner: Framed::new(stream, PostgresCodec), - buf: BackendMessages::empty(), - delayed_notice: Vec::new(), - }; - - startup(&mut stream, config).await?; - authenticate(&mut stream, config).await?; - let (process_id, secret_key, parameters) = read_info(&mut stream).await?; - - Ok(RawConnection { - stream: stream.inner, - parameters, - delayed_notice: stream.delayed_notice, - process_id, - secret_key, - }) -} - -async fn startup(stream: &mut StartupStream, config: &Config) -> Result<(), Error> +) -> Result<(), Error> where S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin, @@ -123,7 +151,10 @@ where stream.send(buf.freeze()).await.map_err(Error::io) } -async fn authenticate(stream: &mut StartupStream, config: &Config) -> Result<(), Error> +pub(crate) async fn authenticate( + stream: &mut StartupStream, + config: &Config, +) -> Result<(), Error> where S: AsyncRead + AsyncWrite + Unpin, T: TlsStream + Unpin, @@ -278,35 +309,3 @@ where Ok(()) } - -async fn read_info( - stream: &mut StartupStream, -) -> Result<(i32, i32, HashMap), Error> -where - S: AsyncRead + AsyncWrite + Unpin, - T: AsyncRead + AsyncWrite + Unpin, -{ - let mut process_id = 0; - let mut secret_key = 0; - let mut parameters = HashMap::new(); - - loop { - match stream.try_next().await.map_err(Error::io)? { - Some(Message::BackendKeyData(body)) => { - process_id = body.process_id(); - secret_key = body.secret_key(); - } - Some(Message::ParameterStatus(body)) => { - parameters.insert( - body.name().map_err(Error::parse)?.to_string(), - body.value().map_err(Error::parse)?.to_string(), - ); - } - Some(Message::NoticeResponse(body)) => stream.delayed_notice.push(body), - Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)), - Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), - Some(_) => return Err(Error::unexpected_message()), - None => return Err(Error::closed()), - } - } -} diff --git a/libs/proxy/tokio-postgres2/src/error/mod.rs b/libs/proxy/tokio-postgres2/src/error/mod.rs index 5309bce17e..6e68b1e595 100644 --- a/libs/proxy/tokio-postgres2/src/error/mod.rs +++ b/libs/proxy/tokio-postgres2/src/error/mod.rs @@ -452,16 +452,16 @@ impl Error { Error(Box::new(ErrorInner { kind, cause })) } - pub(crate) fn closed() -> Error { + pub fn closed() -> Error { Error::new(Kind::Closed, None) } - pub(crate) fn unexpected_message() -> Error { + pub fn unexpected_message() -> Error { Error::new(Kind::UnexpectedMessage, None) } #[allow(clippy::needless_pass_by_value)] - pub(crate) fn db(error: ErrorResponseBody) -> Error { + pub fn db(error: ErrorResponseBody) -> Error { match DbError::parse(&mut error.fields()) { Ok(e) => Error::new(Kind::Db, Some(Box::new(e))), Err(e) => Error::new(Kind::Parse, Some(Box::new(e))), @@ -493,7 +493,7 @@ impl Error { Error::new(Kind::Tls, Some(e)) } - pub(crate) fn io(e: io::Error) -> Error { + pub fn io(e: io::Error) -> Error { Error::new(Kind::Io, Some(Box::new(e))) } diff --git a/libs/proxy/tokio-postgres2/src/lib.rs b/libs/proxy/tokio-postgres2/src/lib.rs index e3dd6d9261..a858ddca39 100644 --- a/libs/proxy/tokio-postgres2/src/lib.rs +++ b/libs/proxy/tokio-postgres2/src/lib.rs @@ -6,7 +6,6 @@ use postgres_protocol2::message::backend::ReadyForQueryBody; pub use crate::cancel_token::{CancelToken, RawCancelToken}; pub use crate::client::{Client, SocketConfig}; pub use crate::config::Config; -pub use crate::connect_raw::RawConnection; pub use crate::connection::Connection; pub use crate::error::Error; pub use crate::generic_client::GenericClient; @@ -50,7 +49,7 @@ mod client; mod codec; pub mod config; mod connect; -mod connect_raw; +pub mod connect_raw; mod connect_socket; mod connect_tls; mod connection; diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index f25121331f..13c6f0f6d7 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -429,26 +429,13 @@ impl CancellationHandler { /// (we'd need something like `#![feature(type_alias_impl_trait)]`). #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CancelClosure { - socket_addr: SocketAddr, - cancel_token: RawCancelToken, - hostname: String, // for pg_sni router - user_info: ComputeUserInfo, + pub socket_addr: SocketAddr, + pub cancel_token: RawCancelToken, + pub hostname: String, // for pg_sni router + pub user_info: ComputeUserInfo, } impl CancelClosure { - pub(crate) fn new( - socket_addr: SocketAddr, - cancel_token: RawCancelToken, - hostname: String, - user_info: ComputeUserInfo, - ) -> Self { - Self { - socket_addr, - cancel_token, - hostname, - user_info, - } - } /// Cancels the query running on user's compute node. pub(crate) async fn try_cancel_query( &self, diff --git a/proxy/src/compute/mod.rs b/proxy/src/compute/mod.rs index 7b9183b05e..1e3631363e 100644 --- a/proxy/src/compute/mod.rs +++ b/proxy/src/compute/mod.rs @@ -7,17 +7,15 @@ use std::net::{IpAddr, SocketAddr}; use futures::{FutureExt, TryFutureExt}; use itertools::Itertools; use postgres_client::config::{AuthKeys, ChannelBinding, SslMode}; +use postgres_client::connect_raw::StartupStream; use postgres_client::maybe_tls_stream::MaybeTlsStream; use postgres_client::tls::MakeTlsConnect; -use postgres_client::{NoTls, RawCancelToken, RawConnection}; -use postgres_protocol::message::backend::NoticeResponseBody; use thiserror::Error; use tokio::net::{TcpStream, lookup_host}; use tracing::{debug, error, info, warn}; -use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; +use crate::auth::backend::ComputeCredentialKeys; use crate::auth::parse_endpoint_param; -use crate::cancellation::CancelClosure; use crate::compute::tls::TlsError; use crate::config::ComputeConfig; use crate::context::RequestContext; @@ -236,8 +234,7 @@ impl AuthInfo { &self, ctx: &RequestContext, compute: &mut ComputeConnection, - user_info: &ComputeUserInfo, - ) -> Result { + ) -> Result<(), PostgresError> { // client config with stubbed connect info. // TODO(conrad): should we rewrite this to bypass tokio-postgres2 entirely, // utilising pqproto.rs. @@ -247,39 +244,10 @@ impl AuthInfo { let tmp_config = self.enrich(tmp_config); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let connection = tmp_config - .tls_and_authenticate(&mut compute.stream, NoTls) - .await?; + tmp_config.authenticate(&mut compute.stream).await?; drop(pause); - let RawConnection { - stream: _, - parameters, - delayed_notice, - process_id, - secret_key, - } = connection; - - tracing::Span::current().record("pid", tracing::field::display(process_id)); - - // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw. - // Yet another reason to rework the connection establishing code. - let cancel_closure = CancelClosure::new( - compute.socket_addr, - RawCancelToken { - ssl_mode: compute.ssl_mode, - process_id, - secret_key, - }, - compute.hostname.to_string(), - user_info.clone(), - ); - - Ok(PostgresSettings { - params: parameters, - cancel_closure, - delayed_notice, - }) + Ok(()) } } @@ -343,21 +311,9 @@ impl ConnectInfo { pub type RustlsStream = >::Stream; pub type MaybeRustlsStream = MaybeTlsStream; -// TODO(conrad): we don't need to parse these. -// These are just immediately forwarded back to the client. -// We could instead stream them out instead of reading them into memory. -pub struct PostgresSettings { - /// PostgreSQL connection parameters. - pub params: std::collections::HashMap, - /// Query cancellation token. - pub cancel_closure: CancelClosure, - /// Notices received from compute after authenticating - pub delayed_notice: Vec, -} - pub struct ComputeConnection { /// Socket connected to a compute node. - pub stream: MaybeTlsStream, + pub stream: StartupStream, /// Labels for proxy's metrics. pub aux: MetricsAuxInfo, pub hostname: Host, @@ -390,6 +346,7 @@ impl ConnectInfo { ctx.get_testodrome_id().unwrap_or_default(), ); + let stream = StartupStream::new(stream); let connection = ComputeConnection { stream, socket_addr, diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 014317d823..639cd123e1 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -1,12 +1,13 @@ use std::sync::Arc; use futures::{FutureExt, TryFutureExt}; +use postgres_client::RawCancelToken; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info}; use crate::auth::backend::ConsoleRedirectBackend; -use crate::cancellation::CancellationHandler; +use crate::cancellation::{CancelClosure, CancellationHandler}; use crate::config::{ProxyConfig, ProxyProtocolV2}; use crate::context::RequestContext; use crate::error::ReportableError; @@ -16,7 +17,7 @@ use crate::pglb::handshake::{HandshakeData, handshake}; use crate::pglb::passthrough::ProxyPassthrough; use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol}; use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute}; -use crate::proxy::{ErrorSource, finish_client_init}; +use crate::proxy::{ErrorSource, forward_compute_params_to_client, send_client_greeting}; use crate::util::run_until_cancelled; pub async fn task_main( @@ -226,21 +227,19 @@ pub(crate) async fn handle_client( .or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) }) .await?; - let pg_settings = auth_info - .authenticate(ctx, &mut node, &user_info) + auth_info + .authenticate(ctx, &mut node) .or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) }) .await?; + send_client_greeting(ctx, &config.greetings, &mut stream); let session = cancellation_handler.get_key(); - finish_client_init( - ctx, - &pg_settings, - *session.key(), - &mut stream, - &config.greetings, - ); + let (process_id, secret_key) = + forward_compute_params_to_client(ctx, *session.key(), &mut stream, &mut node.stream) + .await?; let stream = stream.flush_and_into_inner().await?; + let hostname = node.hostname.to_string(); let session_id = ctx.session_id(); let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel(); @@ -249,7 +248,16 @@ pub(crate) async fn handle_client( .maintain_cancel_key( session_id, cancel, - &pg_settings.cancel_closure, + &CancelClosure { + socket_addr: node.socket_addr, + cancel_token: RawCancelToken { + ssl_mode: node.ssl_mode, + process_id, + secret_key, + }, + hostname, + user_info, + }, &config.connect_to_compute, ) .await; @@ -257,7 +265,7 @@ pub(crate) async fn handle_client( Ok(Some(ProxyPassthrough { client: stream, - compute: node.stream, + compute: node.stream.into_framed().into_inner(), aux: node.aux, private_link_id: None, diff --git a/proxy/src/pglb/mod.rs b/proxy/src/pglb/mod.rs index c4cab155c5..999fa6eb32 100644 --- a/proxy/src/pglb/mod.rs +++ b/proxy/src/pglb/mod.rs @@ -319,7 +319,7 @@ pub(crate) async fn handle_connection( Ok(Some(ProxyPassthrough { client, - compute: node.stream, + compute: node.stream.into_framed().into_inner(), aux: node.aux, private_link_id, diff --git a/proxy/src/pqproto.rs b/proxy/src/pqproto.rs index 680a23c435..7a68d430db 100644 --- a/proxy/src/pqproto.rs +++ b/proxy/src/pqproto.rs @@ -313,6 +313,14 @@ impl WriteBuf { self.0.set_position(0); } + /// Shrinks the buffer if efficient to do so, and returns the remaining size. + pub fn occupied_len(&mut self) -> usize { + if self.should_shrink() { + self.shrink(); + } + self.0.get_mut().len() + } + /// Write a raw message to the internal buffer. /// /// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 8b7c4ff55d..053726505d 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -9,18 +9,23 @@ use std::collections::HashSet; use std::convert::Infallible; use std::sync::Arc; +use futures::TryStreamExt; use itertools::Itertools; use once_cell::sync::OnceCell; +use postgres_client::RawCancelToken; +use postgres_client::connect_raw::StartupStream; +use postgres_protocol::message::backend::Message; use regex::Regex; use serde::{Deserialize, Serialize}; use smol_str::{SmolStr, format_smolstr}; use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpStream; use tokio::sync::oneshot; use tracing::Instrument; use crate::cache::Cache; -use crate::cancellation::CancellationHandler; -use crate::compute::ComputeConnection; +use crate::cancellation::{CancelClosure, CancellationHandler}; +use crate::compute::{ComputeConnection, PostgresError, RustlsStream}; use crate::config::ProxyConfig; use crate::context::RequestContext; use crate::control_plane::client::ControlPlaneClient; @@ -105,7 +110,7 @@ pub(crate) async fn handle_client( // the compute was cached, and we connected, but the compute cache was actually stale // and is associated with the wrong endpoint. We detect this when the **authentication** fails. // As such, we retry once here if the `authenticate` function fails and the error is valid to retry. - let pg_settings = loop { + loop { attempt += 1; // TODO: callback to pglb @@ -127,9 +132,12 @@ pub(crate) async fn handle_client( unreachable!("ensured above"); }; - let res = auth_info.authenticate(ctx, &mut node, user_info).await; + let res = auth_info.authenticate(ctx, &mut node).await; match res { - Ok(pg_settings) => break pg_settings, + Ok(()) => { + send_client_greeting(ctx, &config.greetings, client); + break; + } Err(e) if attempt < 2 && e.should_retry_wake_compute() => { tracing::warn!(error = ?e, "retrying wake compute"); @@ -141,11 +149,17 @@ pub(crate) async fn handle_client( } Err(e) => Err(client.throw_error(e, Some(ctx)).await)?, } + } + + let auth::Backend::ControlPlane(_, user_info) = backend else { + unreachable!("ensured above"); }; let session = cancellation_handler.get_key(); - finish_client_init(ctx, &pg_settings, *session.key(), client, &config.greetings); + let (process_id, secret_key) = + forward_compute_params_to_client(ctx, *session.key(), client, &mut node.stream).await?; + let hostname = node.hostname.to_string(); let session_id = ctx.session_id(); let (cancel_on_shutdown, cancel) = oneshot::channel(); @@ -154,7 +168,16 @@ pub(crate) async fn handle_client( .maintain_cancel_key( session_id, cancel, - &pg_settings.cancel_closure, + &CancelClosure { + socket_addr: node.socket_addr, + cancel_token: RawCancelToken { + ssl_mode: node.ssl_mode, + process_id, + secret_key, + }, + hostname, + user_info, + }, &config.connect_to_compute, ) .await; @@ -163,35 +186,18 @@ pub(crate) async fn handle_client( Ok((node, cancel_on_shutdown)) } -/// Finish client connection initialization: confirm auth success, send params, etc. -pub(crate) fn finish_client_init( +/// Greet the client with any useful information. +pub(crate) fn send_client_greeting( ctx: &RequestContext, - settings: &compute::PostgresSettings, - cancel_key_data: CancelKeyData, - client: &mut PqStream, greetings: &String, + client: &mut PqStream, ) { - // Forward all deferred notices to the client. - for notice in &settings.delayed_notice { - client.write_raw(notice.as_bytes().len(), b'N', |buf| { - buf.extend_from_slice(notice.as_bytes()); - }); - } - // Expose session_id to clients if we have a greeting message. if !greetings.is_empty() { let session_msg = format!("{}, session_id: {}", greetings, ctx.session_id()); client.write_message(BeMessage::NoticeResponse(session_msg.as_str())); } - // Forward all postgres connection params to the client. - for (name, value) in &settings.params { - client.write_message(BeMessage::ParameterStatus { - name: name.as_bytes(), - value: value.as_bytes(), - }); - } - // Forward recorded latencies for probing requests if let Some(testodrome_id) = ctx.get_testodrome_id() { client.write_message(BeMessage::ParameterStatus { @@ -221,9 +227,63 @@ pub(crate) fn finish_client_init( value: latency_measured.retry.as_micros().to_string().as_bytes(), }); } +} - client.write_message(BeMessage::BackendKeyData(cancel_key_data)); - client.write_message(BeMessage::ReadyForQuery); +pub(crate) async fn forward_compute_params_to_client( + ctx: &RequestContext, + cancel_key_data: CancelKeyData, + client: &mut PqStream, + compute: &mut StartupStream, +) -> Result<(i32, i32), ClientRequestError> { + let mut process_id = 0; + let mut secret_key = 0; + + let err = loop { + // if the client buffer is too large, let's write out some bytes now to save some space + client.write_if_full().await?; + + let msg = match compute.try_next().await { + Ok(msg) => msg, + Err(e) => break postgres_client::Error::io(e), + }; + + match msg { + // Send our cancellation key data instead. + Some(Message::BackendKeyData(body)) => { + client.write_message(BeMessage::BackendKeyData(cancel_key_data)); + process_id = body.process_id(); + secret_key = body.secret_key(); + } + // Forward all postgres connection params to the client. + Some(Message::ParameterStatus(body)) => { + if let Ok(name) = body.name() + && let Ok(value) = body.value() + { + client.write_message(BeMessage::ParameterStatus { + name: name.as_bytes(), + value: value.as_bytes(), + }); + } + } + // Forward all notices to the client. + Some(Message::NoticeResponse(notice)) => { + client.write_raw(notice.as_bytes().len(), b'N', |buf| { + buf.extend_from_slice(notice.as_bytes()); + }); + } + Some(Message::ReadyForQuery(_)) => { + client.write_message(BeMessage::ReadyForQuery); + return Ok((process_id, secret_key)); + } + Some(Message::ErrorResponse(body)) => break postgres_client::Error::db(body), + Some(_) => break postgres_client::Error::unexpected_message(), + None => break postgres_client::Error::closed(), + } + }; + + Err(client + .throw_error(PostgresError::Postgres(err), Some(ctx)) + .await)? } #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 4e55654515..d6a43df188 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -154,6 +154,15 @@ impl PqStream { message.write_message(&mut self.write); } + /// Write the buffer to the socket until we have some more space again. + pub async fn write_if_full(&mut self) -> io::Result<()> { + while self.write.occupied_len() > 2048 { + self.stream.write_buf(&mut self.write).await?; + } + + Ok(()) + } + /// Flush the output buffer into the underlying stream. /// /// This is cancel safe.