fix(proxy): forward notifications from authentication (#9948)

Fixes https://github.com/neondatabase/cloud/issues/20973. 

This refactors `connect_raw` in order to return direct access to the
delayed notices.

I cannot find a way to test this with psycopg2 unfortunately, although
testing it with psql does return the expected results.
This commit is contained in:
Conrad Ludgate
2024-12-02 12:29:57 +00:00
committed by GitHub
parent bd09369198
commit cd1d2d1996
11 changed files with 117 additions and 58 deletions

View File

@@ -565,6 +565,8 @@ pub enum BeMessage<'a> {
/// Batch of interpreted, shard filtered WAL records,
/// ready for the pageserver to ingest
InterpretedWalRecords(InterpretedWalRecordsBody<'a>),
Raw(u8, &'a [u8]),
}
/// Common shorthands.
@@ -754,6 +756,10 @@ impl BeMessage<'_> {
/// one more buffer.
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
match message {
BeMessage::Raw(code, data) => {
buf.put_u8(*code);
write_body(buf, |b| b.put_slice(data))
}
BeMessage::AuthenticationOk => {
buf.put_u8(b'R');
write_body(buf, |buf| {

View File

@@ -541,6 +541,10 @@ impl NoticeResponseBody {
pub fn fields(&self) -> ErrorFields<'_> {
ErrorFields { buf: &self.storage }
}
pub fn as_bytes(&self) -> &[u8] {
&self.storage
}
}
pub struct NotificationResponseBody {

View File

@@ -10,10 +10,10 @@ use tokio::net::TcpStream;
/// connection.
#[derive(Clone)]
pub struct CancelToken {
pub(crate) socket_config: Option<SocketConfig>,
pub(crate) ssl_mode: SslMode,
pub(crate) process_id: i32,
pub(crate) secret_key: i32,
pub socket_config: Option<SocketConfig>,
pub ssl_mode: SslMode,
pub process_id: i32,
pub secret_key: i32,
}
impl CancelToken {

View File

@@ -138,7 +138,7 @@ impl InnerClient {
}
#[derive(Clone)]
pub(crate) struct SocketConfig {
pub struct SocketConfig {
pub host: Host,
pub port: u16,
pub connect_timeout: Option<Duration>,
@@ -152,7 +152,7 @@ pub(crate) struct SocketConfig {
pub struct Client {
inner: Arc<InnerClient>,
socket_config: Option<SocketConfig>,
socket_config: SocketConfig,
ssl_mode: SslMode,
process_id: i32,
secret_key: i32,
@@ -161,6 +161,7 @@ pub struct Client {
impl Client {
pub(crate) fn new(
sender: mpsc::UnboundedSender<Request>,
socket_config: SocketConfig,
ssl_mode: SslMode,
process_id: i32,
secret_key: i32,
@@ -172,7 +173,7 @@ impl Client {
buffer: Default::default(),
}),
socket_config: None,
socket_config,
ssl_mode,
process_id,
secret_key,
@@ -188,10 +189,6 @@ impl Client {
&self.inner
}
pub(crate) fn set_socket_config(&mut self, socket_config: SocketConfig) {
self.socket_config = Some(socket_config);
}
/// Creates a new prepared statement.
///
/// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc),
@@ -412,7 +409,7 @@ impl Client {
/// connection associated with this client.
pub fn cancel_token(&self) -> CancelToken {
CancelToken {
socket_config: self.socket_config.clone(),
socket_config: Some(self.socket_config.clone()),
ssl_mode: self.ssl_mode,
process_id: self.process_id,
secret_key: self.secret_key,

View File

@@ -2,6 +2,7 @@
use crate::connect::connect;
use crate::connect_raw::connect_raw;
use crate::connect_raw::RawConnection;
use crate::tls::MakeTlsConnect;
use crate::tls::TlsConnect;
use crate::{Client, Connection, Error};
@@ -485,14 +486,11 @@ impl Config {
connect(tls, self).await
}
/// Connects to a PostgreSQL database over an arbitrary stream.
///
/// All of the settings other than `user`, `password`, `dbname`, `options`, and `application_name` name are ignored.
pub async fn connect_raw<S, T>(
&self,
stream: S,
tls: T,
) -> Result<(Client, Connection<S, T::Stream>), Error>
) -> Result<RawConnection<S, T::Stream>, Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,

View File

@@ -1,13 +1,16 @@
use crate::client::SocketConfig;
use crate::codec::BackendMessage;
use crate::config::{Host, TargetSessionAttrs};
use crate::connect_raw::connect_raw;
use crate::connect_socket::connect_socket;
use crate::tls::{MakeTlsConnect, TlsConnect};
use crate::{Client, Config, Connection, Error, SimpleQueryMessage};
use crate::{Client, Config, Connection, Error, RawConnection, SimpleQueryMessage};
use futures_util::{future, pin_mut, Future, FutureExt, Stream};
use postgres_protocol2::message::backend::Message;
use std::io;
use std::task::Poll;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
pub async fn connect<T>(
mut tls: T,
@@ -60,7 +63,36 @@ where
T: TlsConnect<TcpStream>,
{
let socket = connect_socket(host, port, config.connect_timeout).await?;
let (mut client, mut connection) = connect_raw(socket, tls, config).await?;
let RawConnection {
stream,
parameters,
delayed_notice,
process_id,
secret_key,
} = connect_raw(socket, tls, config).await?;
let socket_config = SocketConfig {
host: host.clone(),
port,
connect_timeout: config.connect_timeout,
};
let (sender, receiver) = mpsc::unbounded_channel();
let client = Client::new(
sender,
socket_config,
config.ssl_mode,
process_id,
secret_key,
);
// delayed notices are always sent as "Async" messages.
let delayed = delayed_notice
.into_iter()
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
.collect();
let mut connection = Connection::new(stream, delayed, parameters, receiver);
if let TargetSessionAttrs::ReadWrite = config.target_session_attrs {
let rows = client.simple_query_raw("SHOW transaction_read_only");
@@ -102,11 +134,5 @@ where
}
}
client.set_socket_config(SocketConfig {
host: host.clone(),
port,
connect_timeout: config.connect_timeout,
});
Ok((client, connection))
}

View File

@@ -3,27 +3,26 @@ use crate::config::{self, AuthKeys, Config, ReplicationMode};
use crate::connect_tls::connect_tls;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::{TlsConnect, TlsStream};
use crate::{Client, Connection, Error};
use crate::Error;
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Sink, SinkExt, Stream, TryStreamExt};
use postgres_protocol2::authentication;
use postgres_protocol2::authentication::sasl;
use postgres_protocol2::authentication::sasl::ScramSha256;
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message};
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
use postgres_protocol2::message::frontend;
use std::collections::{HashMap, VecDeque};
use std::collections::HashMap;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::mpsc;
use tokio_util::codec::Framed;
pub struct StartupStream<S, T> {
inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
buf: BackendMessages,
delayed: VecDeque<BackendMessage>,
delayed_notice: Vec<NoticeResponseBody>,
}
impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
@@ -78,11 +77,19 @@ where
}
}
pub struct RawConnection<S, T> {
pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
pub parameters: HashMap<String, String>,
pub delayed_notice: Vec<NoticeResponseBody>,
pub process_id: i32,
pub secret_key: i32,
}
pub async fn connect_raw<S, T>(
stream: S,
tls: T,
config: &Config,
) -> Result<(Client, Connection<S, T::Stream>), Error>
) -> Result<RawConnection<S, T::Stream>, Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
@@ -97,18 +104,20 @@ where
},
),
buf: BackendMessages::empty(),
delayed: VecDeque::new(),
delayed_notice: Vec::new(),
};
startup(&mut stream, config).await?;
authenticate(&mut stream, config).await?;
let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
let (sender, receiver) = mpsc::unbounded_channel();
let client = Client::new(sender, config.ssl_mode, process_id, secret_key);
let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver);
Ok((client, connection))
Ok(RawConnection {
stream: stream.inner,
parameters,
delayed_notice: stream.delayed_notice,
process_id,
secret_key,
})
}
async fn startup<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
@@ -347,9 +356,7 @@ where
body.value().map_err(Error::parse)?.to_string(),
);
}
Some(msg @ Message::NoticeResponse(_)) => {
stream.delayed.push_back(BackendMessage::Async(msg))
}
Some(Message::NoticeResponse(body)) => stream.delayed_notice.push(body),
Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
Some(_) => return Err(Error::unexpected_message()),

View File

@@ -1,9 +1,10 @@
//! An asynchronous, pipelined, PostgreSQL client.
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
#![warn(rust_2018_idioms, clippy::all)]
pub use crate::cancel_token::CancelToken;
pub use crate::client::Client;
pub use crate::client::{Client, SocketConfig};
pub use crate::config::Config;
pub use crate::connect_raw::RawConnection;
pub use crate::connection::Connection;
use crate::error::DbError;
pub use crate::error::Error;

View File

@@ -6,6 +6,7 @@ use std::time::Duration;
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use once_cell::sync::OnceCell;
use postgres_protocol::message::backend::NoticeResponseBody;
use pq_proto::StartupMessageParams;
use rustls::client::danger::ServerCertVerifier;
use rustls::crypto::ring;
@@ -13,6 +14,7 @@ use rustls::pki_types::InvalidDnsNameError;
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::{CancelToken, RawConnection};
use tracing::{debug, error, info, warn};
use crate::auth::parse_endpoint_param;
@@ -277,6 +279,8 @@ pub(crate) struct PostgresConnection {
pub(crate) cancel_closure: CancelClosure,
/// Labels for proxy's metrics.
pub(crate) aux: MetricsAuxInfo,
/// Notices received from compute after authenticating
pub(crate) delayed_notice: Vec<NoticeResponseBody>,
_guage: NumDbConnectionsGuard<'static>,
}
@@ -322,10 +326,19 @@ impl ConnCfg {
// connect_raw() will not use TLS if sslmode is "disable"
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = self.0.connect_raw(stream, tls).await?;
let connection = self.0.connect_raw(stream, tls).await?;
drop(pause);
tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
let stream = connection.stream.into_inner();
let RawConnection {
stream,
parameters,
delayed_notice,
process_id,
secret_key,
} = connection;
tracing::Span::current().record("pid", tracing::field::display(process_id));
let stream = stream.into_inner();
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
info!(
@@ -334,18 +347,23 @@ impl ConnCfg {
self.0.get_ssl_mode()
);
// This is very ugly but as of now there's no better way to
// extract the connection parameters from tokio-postgres' connection.
// TODO: solve this problem in a more elegant manner (e.g. the new library).
let params = connection.parameters;
// NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
// Yet another reason to rework the connection establishing code.
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token(), vec![]);
let cancel_closure = CancelClosure::new(
socket_addr,
CancelToken {
socket_config: None,
ssl_mode: self.0.get_ssl_mode(),
process_id,
secret_key,
},
vec![],
);
let connection = PostgresConnection {
stream,
params,
params: parameters,
delayed_notice,
cancel_closure,
aux,
_guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()),

View File

@@ -384,11 +384,13 @@ pub(crate) async fn prepare_client_connection<P>(
// The new token (cancel_key_data) will be sent to the client.
let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
// Forward all deferred notices to the client.
for notice in &node.delayed_notice {
stream.write_message_noflush(&Be::Raw(b'N', notice.as_bytes()))?;
}
// Forward all postgres connection params to the client.
// Right now the implementation is very hacky and inefficent (ideally,
// we don't need an intermediate hashmap), but at least it should be correct.
for (name, value) in &node.params {
// TODO: Theoretically, this could result in a big pile of params...
stream.write_message_noflush(&Be::ParameterStatus {
name: name.as_bytes(),
value: value.as_bytes(),

View File

@@ -233,7 +233,7 @@ async fn handshake_tls() -> anyhow::Result<()> {
generate_tls_config("generic-project-name.localhost", "localhost")?;
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
let (_client, _conn) = tokio_postgres::Config::new()
let _conn = tokio_postgres::Config::new()
.user("john_doe")
.dbname("earth")
.ssl_mode(SslMode::Require)
@@ -249,7 +249,7 @@ async fn handshake_raw() -> anyhow::Result<()> {
let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));
let (_client, _conn) = tokio_postgres::Config::new()
let _conn = tokio_postgres::Config::new()
.user("john_doe")
.dbname("earth")
.options("project=generic-project-name")
@@ -296,7 +296,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
Scram::new(password).await?,
));
let (_client, _conn) = tokio_postgres::Config::new()
let _conn = tokio_postgres::Config::new()
.channel_binding(tokio_postgres::config::ChannelBinding::Require)
.user("user")
.dbname("db")
@@ -320,7 +320,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> {
Scram::new("password").await?,
));
let (_client, _conn) = tokio_postgres::Config::new()
let _conn = tokio_postgres::Config::new()
.channel_binding(tokio_postgres::config::ChannelBinding::Disable)
.user("user")
.dbname("db")