fix(proxy): forward notifications from authentication (#9948)

Fixes https://github.com/neondatabase/cloud/issues/20973. 

This refactors `connect_raw` in order to return direct access to the
delayed notices.

I cannot find a way to test this with psycopg2 unfortunately, although
testing it with psql does return the expected results.
This commit is contained in:
Conrad Ludgate
2024-12-02 12:29:57 +00:00
committed by GitHub
parent bd09369198
commit cd1d2d1996
11 changed files with 117 additions and 58 deletions

View File

@@ -541,6 +541,10 @@ impl NoticeResponseBody {
pub fn fields(&self) -> ErrorFields<'_> {
ErrorFields { buf: &self.storage }
}
pub fn as_bytes(&self) -> &[u8] {
&self.storage
}
}
pub struct NotificationResponseBody {

View File

@@ -10,10 +10,10 @@ use tokio::net::TcpStream;
/// connection.
#[derive(Clone)]
pub struct CancelToken {
pub(crate) socket_config: Option<SocketConfig>,
pub(crate) ssl_mode: SslMode,
pub(crate) process_id: i32,
pub(crate) secret_key: i32,
pub socket_config: Option<SocketConfig>,
pub ssl_mode: SslMode,
pub process_id: i32,
pub secret_key: i32,
}
impl CancelToken {

View File

@@ -138,7 +138,7 @@ impl InnerClient {
}
#[derive(Clone)]
pub(crate) struct SocketConfig {
pub struct SocketConfig {
pub host: Host,
pub port: u16,
pub connect_timeout: Option<Duration>,
@@ -152,7 +152,7 @@ pub(crate) struct SocketConfig {
pub struct Client {
inner: Arc<InnerClient>,
socket_config: Option<SocketConfig>,
socket_config: SocketConfig,
ssl_mode: SslMode,
process_id: i32,
secret_key: i32,
@@ -161,6 +161,7 @@ pub struct Client {
impl Client {
pub(crate) fn new(
sender: mpsc::UnboundedSender<Request>,
socket_config: SocketConfig,
ssl_mode: SslMode,
process_id: i32,
secret_key: i32,
@@ -172,7 +173,7 @@ impl Client {
buffer: Default::default(),
}),
socket_config: None,
socket_config,
ssl_mode,
process_id,
secret_key,
@@ -188,10 +189,6 @@ impl Client {
&self.inner
}
pub(crate) fn set_socket_config(&mut self, socket_config: SocketConfig) {
self.socket_config = Some(socket_config);
}
/// Creates a new prepared statement.
///
/// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc),
@@ -412,7 +409,7 @@ impl Client {
/// connection associated with this client.
pub fn cancel_token(&self) -> CancelToken {
CancelToken {
socket_config: self.socket_config.clone(),
socket_config: Some(self.socket_config.clone()),
ssl_mode: self.ssl_mode,
process_id: self.process_id,
secret_key: self.secret_key,

View File

@@ -2,6 +2,7 @@
use crate::connect::connect;
use crate::connect_raw::connect_raw;
use crate::connect_raw::RawConnection;
use crate::tls::MakeTlsConnect;
use crate::tls::TlsConnect;
use crate::{Client, Connection, Error};
@@ -485,14 +486,11 @@ impl Config {
connect(tls, self).await
}
/// Connects to a PostgreSQL database over an arbitrary stream.
///
/// All of the settings other than `user`, `password`, `dbname`, `options`, and `application_name` name are ignored.
pub async fn connect_raw<S, T>(
&self,
stream: S,
tls: T,
) -> Result<(Client, Connection<S, T::Stream>), Error>
) -> Result<RawConnection<S, T::Stream>, Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,

View File

@@ -1,13 +1,16 @@
use crate::client::SocketConfig;
use crate::codec::BackendMessage;
use crate::config::{Host, TargetSessionAttrs};
use crate::connect_raw::connect_raw;
use crate::connect_socket::connect_socket;
use crate::tls::{MakeTlsConnect, TlsConnect};
use crate::{Client, Config, Connection, Error, SimpleQueryMessage};
use crate::{Client, Config, Connection, Error, RawConnection, SimpleQueryMessage};
use futures_util::{future, pin_mut, Future, FutureExt, Stream};
use postgres_protocol2::message::backend::Message;
use std::io;
use std::task::Poll;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
pub async fn connect<T>(
mut tls: T,
@@ -60,7 +63,36 @@ where
T: TlsConnect<TcpStream>,
{
let socket = connect_socket(host, port, config.connect_timeout).await?;
let (mut client, mut connection) = connect_raw(socket, tls, config).await?;
let RawConnection {
stream,
parameters,
delayed_notice,
process_id,
secret_key,
} = connect_raw(socket, tls, config).await?;
let socket_config = SocketConfig {
host: host.clone(),
port,
connect_timeout: config.connect_timeout,
};
let (sender, receiver) = mpsc::unbounded_channel();
let client = Client::new(
sender,
socket_config,
config.ssl_mode,
process_id,
secret_key,
);
// delayed notices are always sent as "Async" messages.
let delayed = delayed_notice
.into_iter()
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
.collect();
let mut connection = Connection::new(stream, delayed, parameters, receiver);
if let TargetSessionAttrs::ReadWrite = config.target_session_attrs {
let rows = client.simple_query_raw("SHOW transaction_read_only");
@@ -102,11 +134,5 @@ where
}
}
client.set_socket_config(SocketConfig {
host: host.clone(),
port,
connect_timeout: config.connect_timeout,
});
Ok((client, connection))
}

View File

@@ -3,27 +3,26 @@ use crate::config::{self, AuthKeys, Config, ReplicationMode};
use crate::connect_tls::connect_tls;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::{TlsConnect, TlsStream};
use crate::{Client, Connection, Error};
use crate::Error;
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Sink, SinkExt, Stream, TryStreamExt};
use postgres_protocol2::authentication;
use postgres_protocol2::authentication::sasl;
use postgres_protocol2::authentication::sasl::ScramSha256;
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message};
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
use postgres_protocol2::message::frontend;
use std::collections::{HashMap, VecDeque};
use std::collections::HashMap;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::mpsc;
use tokio_util::codec::Framed;
pub struct StartupStream<S, T> {
inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
buf: BackendMessages,
delayed: VecDeque<BackendMessage>,
delayed_notice: Vec<NoticeResponseBody>,
}
impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
@@ -78,11 +77,19 @@ where
}
}
pub struct RawConnection<S, T> {
pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
pub parameters: HashMap<String, String>,
pub delayed_notice: Vec<NoticeResponseBody>,
pub process_id: i32,
pub secret_key: i32,
}
pub async fn connect_raw<S, T>(
stream: S,
tls: T,
config: &Config,
) -> Result<(Client, Connection<S, T::Stream>), Error>
) -> Result<RawConnection<S, T::Stream>, Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
@@ -97,18 +104,20 @@ where
},
),
buf: BackendMessages::empty(),
delayed: VecDeque::new(),
delayed_notice: Vec::new(),
};
startup(&mut stream, config).await?;
authenticate(&mut stream, config).await?;
let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
let (sender, receiver) = mpsc::unbounded_channel();
let client = Client::new(sender, config.ssl_mode, process_id, secret_key);
let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver);
Ok((client, connection))
Ok(RawConnection {
stream: stream.inner,
parameters,
delayed_notice: stream.delayed_notice,
process_id,
secret_key,
})
}
async fn startup<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
@@ -347,9 +356,7 @@ where
body.value().map_err(Error::parse)?.to_string(),
);
}
Some(msg @ Message::NoticeResponse(_)) => {
stream.delayed.push_back(BackendMessage::Async(msg))
}
Some(Message::NoticeResponse(body)) => stream.delayed_notice.push(body),
Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
Some(_) => return Err(Error::unexpected_message()),

View File

@@ -1,9 +1,10 @@
//! An asynchronous, pipelined, PostgreSQL client.
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
#![warn(rust_2018_idioms, clippy::all)]
pub use crate::cancel_token::CancelToken;
pub use crate::client::Client;
pub use crate::client::{Client, SocketConfig};
pub use crate::config::Config;
pub use crate::connect_raw::RawConnection;
pub use crate::connection::Connection;
use crate::error::DbError;
pub use crate::error::Error;