[proxy] rework handling of notices in sql-over-http (#12659)

A replacement for #10254 which allows us to introduce notice messages
for sql-over-http in the future if we want to. This also removes the
`ParameterStatus` and `Notification` handling as there's nothing we
could/should do for those.
This commit is contained in:
Conrad Ludgate
2025-07-21 13:50:13 +01:00
committed by GitHub
parent 5a48365fb9
commit b2ecb10f91
9 changed files with 150 additions and 252 deletions

View File

@@ -74,7 +74,6 @@ impl Header {
}
/// An enum representing Postgres backend messages.
#[non_exhaustive]
pub enum Message {
AuthenticationCleartextPassword,
AuthenticationGss,
@@ -145,16 +144,7 @@ impl Message {
PARSE_COMPLETE_TAG => Message::ParseComplete,
BIND_COMPLETE_TAG => Message::BindComplete,
CLOSE_COMPLETE_TAG => Message::CloseComplete,
NOTIFICATION_RESPONSE_TAG => {
let process_id = buf.read_i32::<BigEndian>()?;
let channel = buf.read_cstr()?;
let message = buf.read_cstr()?;
Message::NotificationResponse(NotificationResponseBody {
process_id,
channel,
message,
})
}
NOTIFICATION_RESPONSE_TAG => Message::NotificationResponse(NotificationResponseBody {}),
COPY_DONE_TAG => Message::CopyDone,
COMMAND_COMPLETE_TAG => {
let tag = buf.read_cstr()?;
@@ -543,28 +533,7 @@ impl NoticeResponseBody {
}
}
pub struct NotificationResponseBody {
process_id: i32,
channel: Bytes,
message: Bytes,
}
impl NotificationResponseBody {
#[inline]
pub fn process_id(&self) -> i32 {
self.process_id
}
#[inline]
pub fn channel(&self) -> io::Result<&str> {
get_str(&self.channel)
}
#[inline]
pub fn message(&self) -> io::Result<&str> {
get_str(&self.message)
}
}
pub struct NotificationResponseBody {}
pub struct ParameterDescriptionBody {
storage: Bytes,

View File

@@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use crate::cancel_token::RawCancelToken;
use crate::codec::{BackendMessages, FrontendMessage};
use crate::codec::{BackendMessages, FrontendMessage, RecordNotices};
use crate::config::{Host, SslMode};
use crate::query::RowStream;
use crate::simple_query::SimpleQueryStream;
@@ -221,6 +221,18 @@ impl Client {
&mut self.inner
}
pub fn record_notices(&mut self, limit: usize) -> mpsc::UnboundedReceiver<Box<str>> {
let (tx, rx) = mpsc::unbounded_channel();
let notices = RecordNotices { sender: tx, limit };
self.inner
.sender
.send(FrontendMessage::RecordNotices(notices))
.ok();
rx
}
/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
/// to save a roundtrip
pub async fn query_raw_txt<S, I>(

View File

@@ -3,10 +3,17 @@ use std::io;
use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use postgres_protocol2::message::backend;
use tokio::sync::mpsc::UnboundedSender;
use tokio_util::codec::{Decoder, Encoder};
pub enum FrontendMessage {
Raw(Bytes),
RecordNotices(RecordNotices),
}
pub struct RecordNotices {
pub sender: UnboundedSender<Box<str>>,
pub limit: usize,
}
pub enum BackendMessage {
@@ -33,14 +40,11 @@ impl FallibleIterator for BackendMessages {
pub struct PostgresCodec;
impl Encoder<FrontendMessage> for PostgresCodec {
impl Encoder<Bytes> for PostgresCodec {
type Error = io::Error;
fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> {
match item {
FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf),
}
fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> io::Result<()> {
dst.extend_from_slice(&item);
Ok(())
}
}

View File

@@ -1,11 +1,9 @@
use std::net::IpAddr;
use postgres_protocol2::message::backend::Message;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use crate::client::SocketConfig;
use crate::codec::BackendMessage;
use crate::config::Host;
use crate::connect_raw::connect_raw;
use crate::connect_socket::connect_socket;
@@ -48,8 +46,8 @@ where
let stream = connect_tls(socket, config.ssl_mode, tls).await?;
let RawConnection {
stream,
parameters,
delayed_notice,
parameters: _,
delayed_notice: _,
process_id,
secret_key,
} = connect_raw(stream, config).await?;
@@ -72,13 +70,7 @@ where
secret_key,
);
// delayed notices are always sent as "Async" messages.
let delayed = delayed_notice
.into_iter()
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
.collect();
let connection = Connection::new(stream, delayed, parameters, conn_tx, conn_rx);
let connection = Connection::new(stream, conn_tx, conn_rx);
Ok((client, connection))
}

View File

@@ -3,7 +3,7 @@ use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::BytesMut;
use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready};
use postgres_protocol2::authentication::sasl;
@@ -14,7 +14,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::Framed;
use crate::Error;
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::codec::{BackendMessage, BackendMessages, PostgresCodec};
use crate::config::{self, AuthKeys, Config};
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::TlsStream;
@@ -25,7 +25,7 @@ pub struct StartupStream<S, T> {
delayed_notice: Vec<NoticeResponseBody>,
}
impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
impl<S, T> Sink<Bytes> for StartupStream<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
@@ -36,7 +36,7 @@ where
Pin::new(&mut self.inner).poll_ready(cx)
}
fn start_send(mut self: Pin<&mut Self>, item: FrontendMessage) -> io::Result<()> {
fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> io::Result<()> {
Pin::new(&mut self.inner).start_send(item)
}
@@ -120,10 +120,7 @@ where
let mut buf = BytesMut::new();
frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)
stream.send(buf.freeze()).await.map_err(Error::io)
}
async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
@@ -191,10 +188,7 @@ where
let mut buf = BytesMut::new();
frontend::password_message(password, &mut buf).map_err(Error::encode)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)
stream.send(buf.freeze()).await.map_err(Error::io)
}
async fn authenticate_sasl<S, T>(
@@ -253,10 +247,7 @@ where
let mut buf = BytesMut::new();
frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)?;
stream.send(buf.freeze()).await.map_err(Error::io)?;
let body = match stream.try_next().await.map_err(Error::io)? {
Some(Message::AuthenticationSaslContinue(body)) => body,
@@ -272,10 +263,7 @@ where
let mut buf = BytesMut::new();
frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)?;
stream.send(buf.freeze()).await.map_err(Error::io)?;
let body = match stream.try_next().await.map_err(Error::io)? {
Some(Message::AuthenticationSaslFinal(body)) => body,

View File

@@ -1,22 +1,23 @@
use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::BytesMut;
use futures_util::{Sink, Stream, ready};
use postgres_protocol2::message::backend::Message;
use fallible_iterator::FallibleIterator;
use futures_util::{Sink, StreamExt, ready};
use postgres_protocol2::message::backend::{Message, NoticeResponseBody};
use postgres_protocol2::message::frontend;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::mpsc;
use tokio_util::codec::Framed;
use tokio_util::sync::PollSender;
use tracing::{info, trace};
use tracing::trace;
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::error::DbError;
use crate::Error;
use crate::codec::{
BackendMessage, BackendMessages, FrontendMessage, PostgresCodec, RecordNotices,
};
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::{AsyncMessage, Error, Notification};
#[derive(PartialEq, Debug)]
enum State {
@@ -33,18 +34,18 @@ enum State {
/// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
#[must_use = "futures do nothing unless polled"]
pub struct Connection<S, T> {
/// HACK: we need this in the Neon Proxy.
pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
/// HACK: we need this in the Neon Proxy to forward params.
pub parameters: HashMap<String, String>,
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
sender: PollSender<BackendMessages>,
receiver: mpsc::UnboundedReceiver<FrontendMessage>,
notices: Option<RecordNotices>,
pending_responses: VecDeque<BackendMessage>,
pending_response: Option<BackendMessages>,
state: State,
}
pub enum Never {}
impl<S, T> Connection<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
@@ -52,70 +53,42 @@ where
{
pub(crate) fn new(
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
pending_responses: VecDeque<BackendMessage>,
parameters: HashMap<String, String>,
sender: mpsc::Sender<BackendMessages>,
receiver: mpsc::UnboundedReceiver<FrontendMessage>,
) -> Connection<S, T> {
Connection {
stream,
parameters,
sender: PollSender::new(sender),
receiver,
pending_responses,
notices: None,
pending_response: None,
state: State::Active,
}
}
fn poll_response(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<BackendMessage, Error>>> {
if let Some(message) = self.pending_responses.pop_front() {
trace!("retrying pending response");
return Poll::Ready(Some(Ok(message)));
}
Pin::new(&mut self.stream)
.poll_next(cx)
.map(|o| o.map(|r| r.map_err(Error::io)))
}
/// Read and process messages from the connection to postgres.
/// client <- postgres
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<AsyncMessage, Error>> {
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<Never, Error>> {
loop {
let message = match self.poll_response(cx)? {
Poll::Ready(Some(message)) => message,
Poll::Ready(None) => return Poll::Ready(Err(Error::closed())),
Poll::Pending => {
trace!("poll_read: waiting on response");
return Poll::Pending;
}
};
let messages = match message {
BackendMessage::Async(Message::NoticeResponse(body)) => {
let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
return Poll::Ready(Ok(AsyncMessage::Notice(error)));
}
BackendMessage::Async(Message::NotificationResponse(body)) => {
let notification = Notification {
process_id: body.process_id(),
channel: body.channel().map_err(Error::parse)?.to_string(),
payload: body.message().map_err(Error::parse)?.to_string(),
let messages = match self.pending_response.take() {
Some(messages) => messages,
None => {
let message = match self.stream.poll_next_unpin(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => return Poll::Ready(Err(Error::closed())),
Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(Error::io(e))),
Poll::Ready(Some(Ok(message))) => message,
};
return Poll::Ready(Ok(AsyncMessage::Notification(notification)));
match message {
BackendMessage::Async(Message::NoticeResponse(body)) => {
self.handle_notice(body)?;
continue;
}
BackendMessage::Async(_) => continue,
BackendMessage::Normal { messages } => messages,
}
}
BackendMessage::Async(Message::ParameterStatus(body)) => {
self.parameters.insert(
body.name().map_err(Error::parse)?.to_string(),
body.value().map_err(Error::parse)?.to_string(),
);
continue;
}
BackendMessage::Async(_) => unreachable!(),
BackendMessage::Normal { messages } => messages,
};
match self.sender.poll_reserve(cx) {
@@ -126,8 +99,7 @@ where
return Poll::Ready(Err(Error::closed()));
}
Poll::Pending => {
self.pending_responses
.push_back(BackendMessage::Normal { messages });
self.pending_response = Some(messages);
trace!("poll_read: waiting on sender");
return Poll::Pending;
}
@@ -135,6 +107,31 @@ where
}
}
fn handle_notice(&mut self, body: NoticeResponseBody) -> Result<(), Error> {
let Some(notices) = &mut self.notices else {
return Ok(());
};
let mut fields = body.fields();
while let Some(field) = fields.next().map_err(Error::parse)? {
// loop until we find the message field
if field.type_() == b'M' {
// if the message field is within the limit, send it.
if let Some(new_limit) = notices.limit.checked_sub(field.value().len()) {
match notices.sender.send(field.value().into()) {
// set the new limit.
Ok(()) => notices.limit = new_limit,
// closed.
Err(_) => self.notices = None,
}
}
break;
}
}
Ok(())
}
/// Fetch the next client request and enqueue the response sender.
fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<FrontendMessage>> {
if self.receiver.is_closed() {
@@ -168,21 +165,23 @@ where
match self.poll_request(cx) {
// send the message to postgres
Poll::Ready(Some(request)) => {
Poll::Ready(Some(FrontendMessage::Raw(request))) => {
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
}
Poll::Ready(Some(FrontendMessage::RecordNotices(notices))) => {
self.notices = Some(notices)
}
// No more messages from the client, and no more responses to wait for.
// Send a terminate message to postgres
Poll::Ready(None) => {
trace!("poll_write: at eof, terminating");
let mut request = BytesMut::new();
frontend::terminate(&mut request);
let request = FrontendMessage::Raw(request.freeze());
Pin::new(&mut self.stream)
.start_send(request)
.start_send(request.freeze())
.map_err(Error::io)?;
trace!("poll_write: sent eof, closing");
@@ -231,34 +230,17 @@ where
}
}
/// Returns the value of a runtime parameter for this connection.
pub fn parameter(&self, name: &str) -> Option<&str> {
self.parameters.get(name).map(|s| &**s)
}
/// Polls for asynchronous messages from the server.
///
/// The server can send notices as well as notifications asynchronously to the client. Applications that wish to
/// examine those messages should use this method to drive the connection rather than its `Future` implementation.
pub fn poll_message(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<AsyncMessage, Error>>> {
fn poll_message(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Never, Error>>> {
if self.state != State::Closing {
// if the state is still active, try read from and write to postgres.
let message = self.poll_read(cx)?;
let closing = self.poll_write(cx)?;
if let Poll::Ready(()) = closing {
let Poll::Pending = self.poll_read(cx)?;
if self.poll_write(cx)?.is_ready() {
self.state = State::Closing;
}
if let Poll::Ready(message) = message {
return Poll::Ready(Some(Ok(message)));
}
// poll_read returned Pending.
// poll_write returned Pending or Ready(WriteReady::WaitingOnRead).
// if poll_write returned Ready(WriteReady::WaitingOnRead), then we are waiting to read more data from postgres.
// poll_write returned Pending or Ready(()).
// if poll_write returned Ready(()), then we are waiting to read more data from postgres.
if self.state != State::Closing {
return Poll::Pending;
}
@@ -280,11 +262,9 @@ where
type Output = Result<(), Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
while let Some(message) = ready!(self.poll_message(cx)?) {
if let AsyncMessage::Notice(notice) = message {
info!("{}: {}", notice.severity(), notice.message());
}
match self.poll_message(cx)? {
Poll::Ready(None) => Poll::Ready(Ok(())),
Poll::Pending => Poll::Pending,
}
Poll::Ready(Ok(()))
}
}

View File

@@ -8,7 +8,6 @@ pub use crate::client::{Client, SocketConfig};
pub use crate::config::Config;
pub use crate::connect_raw::RawConnection;
pub use crate::connection::Connection;
use crate::error::DbError;
pub use crate::error::Error;
pub use crate::generic_client::GenericClient;
pub use crate::query::RowStream;
@@ -93,21 +92,6 @@ impl Notification {
}
}
/// An asynchronous message from the server.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum AsyncMessage {
/// A notice.
///
/// Notices use the same format as errors, but aren't "errors" per-se.
Notice(DbError),
/// A notification.
///
/// Connections can subscribe to notifications with the `LISTEN` command.
Notification(Notification),
}
/// Message returned by the `SimpleQuery` stream.
#[derive(Debug)]
#[non_exhaustive]