proxy(tokio-postgres): refactor typeinfo query to occur earlier (#11993)

## Problem

For #11992 I realised we need to get the type info before executing the
query. This is important to know how to decode rows with custom types,
eg the following query:

```sql
CREATE TYPE foo AS ENUM ('foo','bar','baz');
SELECT ARRAY['foo'::foo, 'bar'::foo, 'baz'::foo] AS data;
```

Getting that to work was harder that it seems. The original
tokio-postgres setup has a split between `Client` and `Connection`,
where messages are passed between. Because multiple clients were
supported, each client message included a dedicated response channel.
Each request would be terminated by the `ReadyForQuery` message.

The flow I opted to use for parsing types early would not trigger a
`ReadyForQuery`. The flow is as follows:

```
PARSE ""    // parse the user provided query
DESCRIBE "" // describe the query, returning param/result type oids
FLUSH       // force postgres to flush the responses early

// wait for descriptions

  // check if we know the types, if we don't then
  // setup the typeinfo query and execute it against each OID:

  PARSE typeinfo    // prepare our typeinfo query
  DESCRIBE typeinfo
  FLUSH // force postgres to flush the responses early

  // wait for typeinfo statement

    // for each OID we don't know:
    BIND typeinfo
    EXECUTE
    FLUSH

    // wait for type info, might reveal more OIDs to inspect

  // close the typeinfo query, we cache the OID->type map and this is kinder to pgbouncer.
  CLOSE typeinfo 

// finally once we know all the OIDs:
BIND ""   // bind the user provided query - already parsed - to the user provided params
EXECUTE   // run the user provided query
SYNC      // commit the transaction
```

## Summary of changes

Please review commit by commit. The main challenge was allowing one
query to issue multiple sub-queries. To do this I first made sure that
the client could fully own the connection, which required removing any
shared client state. I then had to replace the way responses are sent to
the client, by using only a single permanent channel. This required some
additional effort to track which query is being processed. Lastly I had
to modify the query/typeinfo functions to not issue `sync` commands, so
it would fit into the desired flow above.

To note: the flow above does force an extra roundtrip into each query. I
don't know yet if this has a measurable latency overhead.
This commit is contained in:
Conrad Ludgate
2025-05-23 20:41:12 +01:00
committed by GitHub
parent 87fc0a0374
commit 6768a71c86
15 changed files with 500 additions and 745 deletions

View File

@@ -1,76 +1,43 @@
use std::fmt;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::{BufMut, Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use bytes::BufMut;
use futures_util::{Stream, ready};
use pin_project_lite::pin_project;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use postgres_types2::{Format, ToSql, Type};
use tracing::debug;
use postgres_types2::Format;
use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::types::IsNull;
use crate::{Column, Error, ReadyForQueryStatus, Row, Statement};
use crate::client::{CachedTypeInfo, InnerClient, Responses};
use crate::{Error, ReadyForQueryStatus, Row, Statement};
struct BorrowToSqlParamsDebug<'a>(&'a [&'a (dyn ToSql + Sync)]);
impl fmt::Debug for BorrowToSqlParamsDebug<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_list().entries(self.0.iter()).finish()
}
}
pub async fn query<'a, I>(
client: &InnerClient,
statement: Statement,
params: I,
) -> Result<RowStream, Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
{
let buf = if tracing::enabled!(tracing::Level::DEBUG) {
let params = params.into_iter().collect::<Vec<_>>();
debug!(
"executing statement {} with parameters: {:?}",
statement.name(),
BorrowToSqlParamsDebug(params.as_slice()),
);
encode(client, &statement, params)?
} else {
encode(client, &statement, params)?
};
let responses = start(client, buf).await?;
Ok(RowStream {
statement,
responses,
command_tag: None,
status: ReadyForQueryStatus::Unknown,
output_format: Format::Binary,
_p: PhantomPinned,
})
}
pub async fn query_txt<S, I>(
client: &Arc<InnerClient>,
pub async fn query_txt<'a, S, I>(
client: &'a mut InnerClient,
typecache: &mut CachedTypeInfo,
query: &str,
params: I,
) -> Result<RowStream, Error>
) -> Result<RowStream<'a>, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
I::IntoIter: ExactSizeIterator,
{
let params = params.into_iter();
let mut client = client.start()?;
let buf = client.with_buf(|buf| {
// Flow:
// 1. Parse the query
// 2. Inspect the row description for OIDs
// 3. If there's any OIDs we don't already know about, perform the typeinfo routine
// 4. Execute the query
// 5. Sync.
//
// The typeinfo routine:
// 1. Parse the typeinfo query
// 2. Execute the query on each OID
// 3. If the result does not match an OID we know, repeat 2.
// parse the query and get type info
let responses = client.send_with_flush(|buf| {
frontend::parse(
"", // unnamed prepared statement
query, // query to parse
@@ -79,7 +46,30 @@ where
)
.map_err(Error::encode)?;
frontend::describe(b'S', "", buf).map_err(Error::encode)?;
// Bind, pass params as text, retrieve as binary
Ok(())
})?;
match responses.next().await? {
Message::ParseComplete => {}
_ => return Err(Error::unexpected_message()),
}
match responses.next().await? {
Message::ParameterDescription(_) => {}
_ => return Err(Error::unexpected_message()),
};
let row_description = match responses.next().await? {
Message::RowDescription(body) => Some(body),
Message::NoData => None,
_ => return Err(Error::unexpected_message()),
};
let columns =
crate::prepare::parse_row_description(&mut client, typecache, row_description).await?;
let responses = client.send_with_sync(|buf| {
// Bind, pass params as text, retrieve as text
match frontend::bind(
"", // empty string selects the unnamed portal
"", // unnamed prepared statement
@@ -102,173 +92,55 @@ where
// Execute
frontend::execute("", 0, buf).map_err(Error::encode)?;
// Sync
frontend::sync(buf);
Ok(buf.split().freeze())
Ok(())
})?;
// now read the responses
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
match responses.next().await? {
Message::ParseComplete => {}
_ => return Err(Error::unexpected_message()),
}
let parameter_description = match responses.next().await? {
Message::ParameterDescription(body) => body,
_ => return Err(Error::unexpected_message()),
};
let row_description = match responses.next().await? {
Message::RowDescription(body) => Some(body),
Message::NoData => None,
_ => return Err(Error::unexpected_message()),
};
match responses.next().await? {
Message::BindComplete => {}
_ => return Err(Error::unexpected_message()),
}
let mut parameters = vec![];
let mut it = parameter_description.parameters();
while let Some(oid) = it.next().map_err(Error::parse)? {
let type_ = Type::from_oid(oid).unwrap_or(Type::UNKNOWN);
parameters.push(type_);
}
let mut columns = vec![];
if let Some(row_description) = row_description {
let mut it = row_description.fields();
while let Some(field) = it.next().map_err(Error::parse)? {
let type_ = Type::from_oid(field.type_oid()).unwrap_or(Type::UNKNOWN);
let column = Column::new(field.name().to_string(), type_, field);
columns.push(column);
}
}
Ok(RowStream {
statement: Statement::new_anonymous(parameters, columns),
responses,
statement: Statement::new("", columns),
command_tag: None,
status: ReadyForQueryStatus::Unknown,
output_format: Format::Text,
_p: PhantomPinned,
})
}
async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
match responses.next().await? {
Message::BindComplete => {}
_ => return Err(Error::unexpected_message()),
}
Ok(responses)
/// A stream of table rows.
pub struct RowStream<'a> {
responses: &'a mut Responses,
output_format: Format,
pub statement: Statement,
pub command_tag: Option<String>,
pub status: ReadyForQueryStatus,
}
pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
{
client.with_buf(|buf| {
encode_bind(statement, params, "", buf)?;
frontend::execute("", 0, buf).map_err(Error::encode)?;
frontend::sync(buf);
Ok(buf.split().freeze())
})
}
pub fn encode_bind<'a, I>(
statement: &Statement,
params: I,
portal: &str,
buf: &mut BytesMut,
) -> Result<(), Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
{
let param_types = statement.params();
let params = params.into_iter();
assert!(
param_types.len() == params.len(),
"expected {} parameters but got {}",
param_types.len(),
params.len()
);
let (param_formats, params): (Vec<_>, Vec<_>) = params
.zip(param_types.iter())
.map(|(p, ty)| (p.encode_format(ty) as i16, p))
.unzip();
let params = params.into_iter();
let mut error_idx = 0;
let r = frontend::bind(
portal,
statement.name(),
param_formats,
params.zip(param_types).enumerate(),
|(idx, (param, ty)), buf| match param.to_sql_checked(ty, buf) {
Ok(IsNull::No) => Ok(postgres_protocol2::IsNull::No),
Ok(IsNull::Yes) => Ok(postgres_protocol2::IsNull::Yes),
Err(e) => {
error_idx = idx;
Err(e)
}
},
Some(1),
buf,
);
match r {
Ok(()) => Ok(()),
Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
}
}
pin_project! {
/// A stream of table rows.
pub struct RowStream {
statement: Statement,
responses: Responses,
command_tag: Option<String>,
output_format: Format,
status: ReadyForQueryStatus,
#[pin]
_p: PhantomPinned,
}
}
impl Stream for RowStream {
impl Stream for RowStream<'_> {
type Item = Result<Row, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
let this = self.get_mut();
loop {
match ready!(this.responses.poll_next(cx)?) {
Message::DataRow(body) => {
return Poll::Ready(Some(Ok(Row::new(
this.statement.clone(),
body,
*this.output_format,
this.output_format,
)?)));
}
Message::EmptyQueryResponse | Message::PortalSuspended => {}
Message::CommandComplete(body) => {
if let Ok(tag) = body.tag() {
*this.command_tag = Some(tag.to_string());
this.command_tag = Some(tag.to_string());
}
}
Message::ReadyForQuery(status) => {
*this.status = status.into();
this.status = status.into();
return Poll::Ready(None);
}
_ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
@@ -276,24 +148,3 @@ impl Stream for RowStream {
}
}
}
impl RowStream {
/// Returns information about the columns of data in the row.
pub fn columns(&self) -> &[Column] {
self.statement.columns()
}
/// Returns the command tag of this query.
///
/// This is only available after the stream has been exhausted.
pub fn command_tag(&self) -> Option<String> {
self.command_tag.clone()
}
/// Returns if the connection is ready for querying, with the status of the connection.
///
/// This might be available only after the stream has been exhausted.
pub fn ready_status(&self) -> ReadyForQueryStatus {
self.status
}
}