diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs index 4b0331999d..43dfbc22a4 100644 --- a/libs/pq_proto/src/lib.rs +++ b/libs/pq_proto/src/lib.rs @@ -565,6 +565,8 @@ pub enum BeMessage<'a> { /// Batch of interpreted, shard filtered WAL records, /// ready for the pageserver to ingest InterpretedWalRecords(InterpretedWalRecordsBody<'a>), + + Raw(u8, &'a [u8]), } /// Common shorthands. @@ -754,6 +756,10 @@ impl BeMessage<'_> { /// one more buffer. pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> { match message { + BeMessage::Raw(code, data) => { + buf.put_u8(*code); + write_body(buf, |b| b.put_slice(data)) + } BeMessage::AuthenticationOk => { buf.put_u8(b'R'); write_body(buf, |buf| { diff --git a/libs/proxy/postgres-protocol2/src/message/backend.rs b/libs/proxy/postgres-protocol2/src/message/backend.rs index 356d142f3f..33d77fc252 100644 --- a/libs/proxy/postgres-protocol2/src/message/backend.rs +++ b/libs/proxy/postgres-protocol2/src/message/backend.rs @@ -541,6 +541,10 @@ impl NoticeResponseBody { pub fn fields(&self) -> ErrorFields<'_> { ErrorFields { buf: &self.storage } } + + pub fn as_bytes(&self) -> &[u8] { + &self.storage + } } pub struct NotificationResponseBody { diff --git a/libs/proxy/tokio-postgres2/src/cancel_token.rs b/libs/proxy/tokio-postgres2/src/cancel_token.rs index b949bf358f..a10e8bf5c3 100644 --- a/libs/proxy/tokio-postgres2/src/cancel_token.rs +++ b/libs/proxy/tokio-postgres2/src/cancel_token.rs @@ -10,10 +10,10 @@ use tokio::net::TcpStream; /// connection. #[derive(Clone)] pub struct CancelToken { - pub(crate) socket_config: Option, - pub(crate) ssl_mode: SslMode, - pub(crate) process_id: i32, - pub(crate) secret_key: i32, + pub socket_config: Option, + pub ssl_mode: SslMode, + pub process_id: i32, + pub secret_key: i32, } impl CancelToken { diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs index 96200b71e7..a7cd53afc3 100644 --- a/libs/proxy/tokio-postgres2/src/client.rs +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -138,7 +138,7 @@ impl InnerClient { } #[derive(Clone)] -pub(crate) struct SocketConfig { +pub struct SocketConfig { pub host: Host, pub port: u16, pub connect_timeout: Option, @@ -152,7 +152,7 @@ pub(crate) struct SocketConfig { pub struct Client { inner: Arc, - socket_config: Option, + socket_config: SocketConfig, ssl_mode: SslMode, process_id: i32, secret_key: i32, @@ -161,6 +161,7 @@ pub struct Client { impl Client { pub(crate) fn new( sender: mpsc::UnboundedSender, + socket_config: SocketConfig, ssl_mode: SslMode, process_id: i32, secret_key: i32, @@ -172,7 +173,7 @@ impl Client { buffer: Default::default(), }), - socket_config: None, + socket_config, ssl_mode, process_id, secret_key, @@ -188,10 +189,6 @@ impl Client { &self.inner } - pub(crate) fn set_socket_config(&mut self, socket_config: SocketConfig) { - self.socket_config = Some(socket_config); - } - /// Creates a new prepared statement. /// /// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc), @@ -412,7 +409,7 @@ impl Client { /// connection associated with this client. pub fn cancel_token(&self) -> CancelToken { CancelToken { - socket_config: self.socket_config.clone(), + socket_config: Some(self.socket_config.clone()), ssl_mode: self.ssl_mode, process_id: self.process_id, secret_key: self.secret_key, diff --git a/libs/proxy/tokio-postgres2/src/config.rs b/libs/proxy/tokio-postgres2/src/config.rs index 969c20ba47..26124b38ef 100644 --- a/libs/proxy/tokio-postgres2/src/config.rs +++ b/libs/proxy/tokio-postgres2/src/config.rs @@ -2,6 +2,7 @@ use crate::connect::connect; use crate::connect_raw::connect_raw; +use crate::connect_raw::RawConnection; use crate::tls::MakeTlsConnect; use crate::tls::TlsConnect; use crate::{Client, Connection, Error}; @@ -485,14 +486,11 @@ impl Config { connect(tls, self).await } - /// Connects to a PostgreSQL database over an arbitrary stream. - /// - /// All of the settings other than `user`, `password`, `dbname`, `options`, and `application_name` name are ignored. pub async fn connect_raw( &self, stream: S, tls: T, - ) -> Result<(Client, Connection), Error> + ) -> Result, Error> where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, diff --git a/libs/proxy/tokio-postgres2/src/connect.rs b/libs/proxy/tokio-postgres2/src/connect.rs index 7517fe0cde..98067d91f9 100644 --- a/libs/proxy/tokio-postgres2/src/connect.rs +++ b/libs/proxy/tokio-postgres2/src/connect.rs @@ -1,13 +1,16 @@ use crate::client::SocketConfig; +use crate::codec::BackendMessage; use crate::config::{Host, TargetSessionAttrs}; use crate::connect_raw::connect_raw; use crate::connect_socket::connect_socket; use crate::tls::{MakeTlsConnect, TlsConnect}; -use crate::{Client, Config, Connection, Error, SimpleQueryMessage}; +use crate::{Client, Config, Connection, Error, RawConnection, SimpleQueryMessage}; use futures_util::{future, pin_mut, Future, FutureExt, Stream}; +use postgres_protocol2::message::backend::Message; use std::io; use std::task::Poll; use tokio::net::TcpStream; +use tokio::sync::mpsc; pub async fn connect( mut tls: T, @@ -60,7 +63,36 @@ where T: TlsConnect, { let socket = connect_socket(host, port, config.connect_timeout).await?; - let (mut client, mut connection) = connect_raw(socket, tls, config).await?; + let RawConnection { + stream, + parameters, + delayed_notice, + process_id, + secret_key, + } = connect_raw(socket, tls, config).await?; + + let socket_config = SocketConfig { + host: host.clone(), + port, + connect_timeout: config.connect_timeout, + }; + + let (sender, receiver) = mpsc::unbounded_channel(); + let client = Client::new( + sender, + socket_config, + config.ssl_mode, + process_id, + secret_key, + ); + + // delayed notices are always sent as "Async" messages. + let delayed = delayed_notice + .into_iter() + .map(|m| BackendMessage::Async(Message::NoticeResponse(m))) + .collect(); + + let mut connection = Connection::new(stream, delayed, parameters, receiver); if let TargetSessionAttrs::ReadWrite = config.target_session_attrs { let rows = client.simple_query_raw("SHOW transaction_read_only"); @@ -102,11 +134,5 @@ where } } - client.set_socket_config(SocketConfig { - host: host.clone(), - port, - connect_timeout: config.connect_timeout, - }); - Ok((client, connection)) } diff --git a/libs/proxy/tokio-postgres2/src/connect_raw.rs b/libs/proxy/tokio-postgres2/src/connect_raw.rs index 80677af969..9c6f1a2552 100644 --- a/libs/proxy/tokio-postgres2/src/connect_raw.rs +++ b/libs/proxy/tokio-postgres2/src/connect_raw.rs @@ -3,27 +3,26 @@ use crate::config::{self, AuthKeys, Config, ReplicationMode}; use crate::connect_tls::connect_tls; use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::{TlsConnect, TlsStream}; -use crate::{Client, Connection, Error}; +use crate::Error; use bytes::BytesMut; use fallible_iterator::FallibleIterator; use futures_util::{ready, Sink, SinkExt, Stream, TryStreamExt}; use postgres_protocol2::authentication; use postgres_protocol2::authentication::sasl; use postgres_protocol2::authentication::sasl::ScramSha256; -use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message}; +use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody}; use postgres_protocol2::message::frontend; -use std::collections::{HashMap, VecDeque}; +use std::collections::HashMap; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::sync::mpsc; use tokio_util::codec::Framed; pub struct StartupStream { inner: Framed, PostgresCodec>, buf: BackendMessages, - delayed: VecDeque, + delayed_notice: Vec, } impl Sink for StartupStream @@ -78,11 +77,19 @@ where } } +pub struct RawConnection { + pub stream: Framed, PostgresCodec>, + pub parameters: HashMap, + pub delayed_notice: Vec, + pub process_id: i32, + pub secret_key: i32, +} + pub async fn connect_raw( stream: S, tls: T, config: &Config, -) -> Result<(Client, Connection), Error> +) -> Result, Error> where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, @@ -97,18 +104,20 @@ where }, ), buf: BackendMessages::empty(), - delayed: VecDeque::new(), + delayed_notice: Vec::new(), }; startup(&mut stream, config).await?; authenticate(&mut stream, config).await?; let (process_id, secret_key, parameters) = read_info(&mut stream).await?; - let (sender, receiver) = mpsc::unbounded_channel(); - let client = Client::new(sender, config.ssl_mode, process_id, secret_key); - let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver); - - Ok((client, connection)) + Ok(RawConnection { + stream: stream.inner, + parameters, + delayed_notice: stream.delayed_notice, + process_id, + secret_key, + }) } async fn startup(stream: &mut StartupStream, config: &Config) -> Result<(), Error> @@ -347,9 +356,7 @@ where body.value().map_err(Error::parse)?.to_string(), ); } - Some(msg @ Message::NoticeResponse(_)) => { - stream.delayed.push_back(BackendMessage::Async(msg)) - } + Some(Message::NoticeResponse(body)) => stream.delayed_notice.push(body), Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)), Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), Some(_) => return Err(Error::unexpected_message()), diff --git a/libs/proxy/tokio-postgres2/src/lib.rs b/libs/proxy/tokio-postgres2/src/lib.rs index 72ba8172b2..57c639a7de 100644 --- a/libs/proxy/tokio-postgres2/src/lib.rs +++ b/libs/proxy/tokio-postgres2/src/lib.rs @@ -1,9 +1,10 @@ //! An asynchronous, pipelined, PostgreSQL client. -#![warn(rust_2018_idioms, clippy::all, missing_docs)] +#![warn(rust_2018_idioms, clippy::all)] pub use crate::cancel_token::CancelToken; -pub use crate::client::Client; +pub use crate::client::{Client, SocketConfig}; pub use crate::config::Config; +pub use crate::connect_raw::RawConnection; pub use crate::connection::Connection; use crate::error::DbError; pub use crate::error::Error; diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 2abe88ac88..b689b97a21 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -6,6 +6,7 @@ use std::time::Duration; use futures::{FutureExt, TryFutureExt}; use itertools::Itertools; use once_cell::sync::OnceCell; +use postgres_protocol::message::backend::NoticeResponseBody; use pq_proto::StartupMessageParams; use rustls::client::danger::ServerCertVerifier; use rustls::crypto::ring; @@ -13,6 +14,7 @@ use rustls::pki_types::InvalidDnsNameError; use thiserror::Error; use tokio::net::TcpStream; use tokio_postgres::tls::MakeTlsConnect; +use tokio_postgres::{CancelToken, RawConnection}; use tracing::{debug, error, info, warn}; use crate::auth::parse_endpoint_param; @@ -277,6 +279,8 @@ pub(crate) struct PostgresConnection { pub(crate) cancel_closure: CancelClosure, /// Labels for proxy's metrics. pub(crate) aux: MetricsAuxInfo, + /// Notices received from compute after authenticating + pub(crate) delayed_notice: Vec, _guage: NumDbConnectionsGuard<'static>, } @@ -322,10 +326,19 @@ impl ConnCfg { // connect_raw() will not use TLS if sslmode is "disable" let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let (client, connection) = self.0.connect_raw(stream, tls).await?; + let connection = self.0.connect_raw(stream, tls).await?; drop(pause); - tracing::Span::current().record("pid", tracing::field::display(client.get_process_id())); - let stream = connection.stream.into_inner(); + + let RawConnection { + stream, + parameters, + delayed_notice, + process_id, + secret_key, + } = connection; + + tracing::Span::current().record("pid", tracing::field::display(process_id)); + let stream = stream.into_inner(); // TODO: lots of useful info but maybe we can move it elsewhere (eg traces?) info!( @@ -334,18 +347,23 @@ impl ConnCfg { self.0.get_ssl_mode() ); - // This is very ugly but as of now there's no better way to - // extract the connection parameters from tokio-postgres' connection. - // TODO: solve this problem in a more elegant manner (e.g. the new library). - let params = connection.parameters; - // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw. // Yet another reason to rework the connection establishing code. - let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token(), vec![]); + let cancel_closure = CancelClosure::new( + socket_addr, + CancelToken { + socket_config: None, + ssl_mode: self.0.get_ssl_mode(), + process_id, + secret_key, + }, + vec![], + ); let connection = PostgresConnection { stream, - params, + params: parameters, + delayed_notice, cancel_closure, aux, _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()), diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 956036d29d..af97fb3d71 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -384,11 +384,13 @@ pub(crate) async fn prepare_client_connection

( // The new token (cancel_key_data) will be sent to the client. let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone()); + // Forward all deferred notices to the client. + for notice in &node.delayed_notice { + stream.write_message_noflush(&Be::Raw(b'N', notice.as_bytes()))?; + } + // Forward all postgres connection params to the client. - // Right now the implementation is very hacky and inefficent (ideally, - // we don't need an intermediate hashmap), but at least it should be correct. for (name, value) in &node.params { - // TODO: Theoretically, this could result in a big pile of params... stream.write_message_noflush(&Be::ParameterStatus { name: name.as_bytes(), value: value.as_bytes(), diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 2c2c2964b6..15be6c9724 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -233,7 +233,7 @@ async fn handshake_tls() -> anyhow::Result<()> { generate_tls_config("generic-project-name.localhost", "localhost")?; let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth)); - let (_client, _conn) = tokio_postgres::Config::new() + let _conn = tokio_postgres::Config::new() .user("john_doe") .dbname("earth") .ssl_mode(SslMode::Require) @@ -249,7 +249,7 @@ async fn handshake_raw() -> anyhow::Result<()> { let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth)); - let (_client, _conn) = tokio_postgres::Config::new() + let _conn = tokio_postgres::Config::new() .user("john_doe") .dbname("earth") .options("project=generic-project-name") @@ -296,7 +296,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> { Scram::new(password).await?, )); - let (_client, _conn) = tokio_postgres::Config::new() + let _conn = tokio_postgres::Config::new() .channel_binding(tokio_postgres::config::ChannelBinding::Require) .user("user") .dbname("db") @@ -320,7 +320,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> { Scram::new("password").await?, )); - let (_client, _conn) = tokio_postgres::Config::new() + let _conn = tokio_postgres::Config::new() .channel_binding(tokio_postgres::config::ChannelBinding::Disable) .user("user") .dbname("db")