mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-04 20:12:54 +00:00
fix(proxy): forward notifications from authentication (#9948)
Fixes https://github.com/neondatabase/cloud/issues/20973. This refactors `connect_raw` in order to return direct access to the delayed notices. I cannot find a way to test this with psycopg2 unfortunately, although testing it with psql does return the expected results.
This commit is contained in:
@@ -541,6 +541,10 @@ impl NoticeResponseBody {
|
||||
pub fn fields(&self) -> ErrorFields<'_> {
|
||||
ErrorFields { buf: &self.storage }
|
||||
}
|
||||
|
||||
pub fn as_bytes(&self) -> &[u8] {
|
||||
&self.storage
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NotificationResponseBody {
|
||||
|
||||
@@ -10,10 +10,10 @@ use tokio::net::TcpStream;
|
||||
/// connection.
|
||||
#[derive(Clone)]
|
||||
pub struct CancelToken {
|
||||
pub(crate) socket_config: Option<SocketConfig>,
|
||||
pub(crate) ssl_mode: SslMode,
|
||||
pub(crate) process_id: i32,
|
||||
pub(crate) secret_key: i32,
|
||||
pub socket_config: Option<SocketConfig>,
|
||||
pub ssl_mode: SslMode,
|
||||
pub process_id: i32,
|
||||
pub secret_key: i32,
|
||||
}
|
||||
|
||||
impl CancelToken {
|
||||
|
||||
@@ -138,7 +138,7 @@ impl InnerClient {
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct SocketConfig {
|
||||
pub struct SocketConfig {
|
||||
pub host: Host,
|
||||
pub port: u16,
|
||||
pub connect_timeout: Option<Duration>,
|
||||
@@ -152,7 +152,7 @@ pub(crate) struct SocketConfig {
|
||||
pub struct Client {
|
||||
inner: Arc<InnerClient>,
|
||||
|
||||
socket_config: Option<SocketConfig>,
|
||||
socket_config: SocketConfig,
|
||||
ssl_mode: SslMode,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
@@ -161,6 +161,7 @@ pub struct Client {
|
||||
impl Client {
|
||||
pub(crate) fn new(
|
||||
sender: mpsc::UnboundedSender<Request>,
|
||||
socket_config: SocketConfig,
|
||||
ssl_mode: SslMode,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
@@ -172,7 +173,7 @@ impl Client {
|
||||
buffer: Default::default(),
|
||||
}),
|
||||
|
||||
socket_config: None,
|
||||
socket_config,
|
||||
ssl_mode,
|
||||
process_id,
|
||||
secret_key,
|
||||
@@ -188,10 +189,6 @@ impl Client {
|
||||
&self.inner
|
||||
}
|
||||
|
||||
pub(crate) fn set_socket_config(&mut self, socket_config: SocketConfig) {
|
||||
self.socket_config = Some(socket_config);
|
||||
}
|
||||
|
||||
/// Creates a new prepared statement.
|
||||
///
|
||||
/// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc),
|
||||
@@ -412,7 +409,7 @@ impl Client {
|
||||
/// connection associated with this client.
|
||||
pub fn cancel_token(&self) -> CancelToken {
|
||||
CancelToken {
|
||||
socket_config: self.socket_config.clone(),
|
||||
socket_config: Some(self.socket_config.clone()),
|
||||
ssl_mode: self.ssl_mode,
|
||||
process_id: self.process_id,
|
||||
secret_key: self.secret_key,
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
use crate::connect::connect;
|
||||
use crate::connect_raw::connect_raw;
|
||||
use crate::connect_raw::RawConnection;
|
||||
use crate::tls::MakeTlsConnect;
|
||||
use crate::tls::TlsConnect;
|
||||
use crate::{Client, Connection, Error};
|
||||
@@ -485,14 +486,11 @@ impl Config {
|
||||
connect(tls, self).await
|
||||
}
|
||||
|
||||
/// Connects to a PostgreSQL database over an arbitrary stream.
|
||||
///
|
||||
/// All of the settings other than `user`, `password`, `dbname`, `options`, and `application_name` name are ignored.
|
||||
pub async fn connect_raw<S, T>(
|
||||
&self,
|
||||
stream: S,
|
||||
tls: T,
|
||||
) -> Result<(Client, Connection<S, T::Stream>), Error>
|
||||
) -> Result<RawConnection<S, T::Stream>, Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: TlsConnect<S>,
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
use crate::client::SocketConfig;
|
||||
use crate::codec::BackendMessage;
|
||||
use crate::config::{Host, TargetSessionAttrs};
|
||||
use crate::connect_raw::connect_raw;
|
||||
use crate::connect_socket::connect_socket;
|
||||
use crate::tls::{MakeTlsConnect, TlsConnect};
|
||||
use crate::{Client, Config, Connection, Error, SimpleQueryMessage};
|
||||
use crate::{Client, Config, Connection, Error, RawConnection, SimpleQueryMessage};
|
||||
use futures_util::{future, pin_mut, Future, FutureExt, Stream};
|
||||
use postgres_protocol2::message::backend::Message;
|
||||
use std::io;
|
||||
use std::task::Poll;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
pub async fn connect<T>(
|
||||
mut tls: T,
|
||||
@@ -60,7 +63,36 @@ where
|
||||
T: TlsConnect<TcpStream>,
|
||||
{
|
||||
let socket = connect_socket(host, port, config.connect_timeout).await?;
|
||||
let (mut client, mut connection) = connect_raw(socket, tls, config).await?;
|
||||
let RawConnection {
|
||||
stream,
|
||||
parameters,
|
||||
delayed_notice,
|
||||
process_id,
|
||||
secret_key,
|
||||
} = connect_raw(socket, tls, config).await?;
|
||||
|
||||
let socket_config = SocketConfig {
|
||||
host: host.clone(),
|
||||
port,
|
||||
connect_timeout: config.connect_timeout,
|
||||
};
|
||||
|
||||
let (sender, receiver) = mpsc::unbounded_channel();
|
||||
let client = Client::new(
|
||||
sender,
|
||||
socket_config,
|
||||
config.ssl_mode,
|
||||
process_id,
|
||||
secret_key,
|
||||
);
|
||||
|
||||
// delayed notices are always sent as "Async" messages.
|
||||
let delayed = delayed_notice
|
||||
.into_iter()
|
||||
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
|
||||
.collect();
|
||||
|
||||
let mut connection = Connection::new(stream, delayed, parameters, receiver);
|
||||
|
||||
if let TargetSessionAttrs::ReadWrite = config.target_session_attrs {
|
||||
let rows = client.simple_query_raw("SHOW transaction_read_only");
|
||||
@@ -102,11 +134,5 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
client.set_socket_config(SocketConfig {
|
||||
host: host.clone(),
|
||||
port,
|
||||
connect_timeout: config.connect_timeout,
|
||||
});
|
||||
|
||||
Ok((client, connection))
|
||||
}
|
||||
|
||||
@@ -3,27 +3,26 @@ use crate::config::{self, AuthKeys, Config, ReplicationMode};
|
||||
use crate::connect_tls::connect_tls;
|
||||
use crate::maybe_tls_stream::MaybeTlsStream;
|
||||
use crate::tls::{TlsConnect, TlsStream};
|
||||
use crate::{Client, Connection, Error};
|
||||
use crate::Error;
|
||||
use bytes::BytesMut;
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures_util::{ready, Sink, SinkExt, Stream, TryStreamExt};
|
||||
use postgres_protocol2::authentication;
|
||||
use postgres_protocol2::authentication::sasl;
|
||||
use postgres_protocol2::authentication::sasl::ScramSha256;
|
||||
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message};
|
||||
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
|
||||
use postgres_protocol2::message::frontend;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::codec::Framed;
|
||||
|
||||
pub struct StartupStream<S, T> {
|
||||
inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||
buf: BackendMessages,
|
||||
delayed: VecDeque<BackendMessage>,
|
||||
delayed_notice: Vec<NoticeResponseBody>,
|
||||
}
|
||||
|
||||
impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
|
||||
@@ -78,11 +77,19 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RawConnection<S, T> {
|
||||
pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||
pub parameters: HashMap<String, String>,
|
||||
pub delayed_notice: Vec<NoticeResponseBody>,
|
||||
pub process_id: i32,
|
||||
pub secret_key: i32,
|
||||
}
|
||||
|
||||
pub async fn connect_raw<S, T>(
|
||||
stream: S,
|
||||
tls: T,
|
||||
config: &Config,
|
||||
) -> Result<(Client, Connection<S, T::Stream>), Error>
|
||||
) -> Result<RawConnection<S, T::Stream>, Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: TlsConnect<S>,
|
||||
@@ -97,18 +104,20 @@ where
|
||||
},
|
||||
),
|
||||
buf: BackendMessages::empty(),
|
||||
delayed: VecDeque::new(),
|
||||
delayed_notice: Vec::new(),
|
||||
};
|
||||
|
||||
startup(&mut stream, config).await?;
|
||||
authenticate(&mut stream, config).await?;
|
||||
let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
|
||||
|
||||
let (sender, receiver) = mpsc::unbounded_channel();
|
||||
let client = Client::new(sender, config.ssl_mode, process_id, secret_key);
|
||||
let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver);
|
||||
|
||||
Ok((client, connection))
|
||||
Ok(RawConnection {
|
||||
stream: stream.inner,
|
||||
parameters,
|
||||
delayed_notice: stream.delayed_notice,
|
||||
process_id,
|
||||
secret_key,
|
||||
})
|
||||
}
|
||||
|
||||
async fn startup<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
|
||||
@@ -347,9 +356,7 @@ where
|
||||
body.value().map_err(Error::parse)?.to_string(),
|
||||
);
|
||||
}
|
||||
Some(msg @ Message::NoticeResponse(_)) => {
|
||||
stream.delayed.push_back(BackendMessage::Async(msg))
|
||||
}
|
||||
Some(Message::NoticeResponse(body)) => stream.delayed_notice.push(body),
|
||||
Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
|
||||
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
|
||||
Some(_) => return Err(Error::unexpected_message()),
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
//! An asynchronous, pipelined, PostgreSQL client.
|
||||
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
|
||||
#![warn(rust_2018_idioms, clippy::all)]
|
||||
|
||||
pub use crate::cancel_token::CancelToken;
|
||||
pub use crate::client::Client;
|
||||
pub use crate::client::{Client, SocketConfig};
|
||||
pub use crate::config::Config;
|
||||
pub use crate::connect_raw::RawConnection;
|
||||
pub use crate::connection::Connection;
|
||||
use crate::error::DbError;
|
||||
pub use crate::error::Error;
|
||||
|
||||
Reference in New Issue
Block a user