From c835bbba1f809ebe4b87c16aca740facfd2c1a3e Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 6 Dec 2024 12:01:19 +0000 Subject: [PATCH] refactor statements and the type cache to avoid arcs --- libs/proxy/tokio-postgres2/src/client.rs | 88 +++++----- .../tokio-postgres2/src/generic_client.rs | 5 +- libs/proxy/tokio-postgres2/src/prepare.rs | 156 +++++++++++------- libs/proxy/tokio-postgres2/src/query.rs | 61 +++++-- libs/proxy/tokio-postgres2/src/row.rs | 54 +----- libs/proxy/tokio-postgres2/src/statement.rs | 13 +- libs/proxy/tokio-postgres2/src/transaction.rs | 10 +- proxy/src/serverless/json.rs | 10 +- proxy/src/serverless/sql_over_http.rs | 22 ++- 9 files changed, 228 insertions(+), 191 deletions(-) diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs index 90523121b9..fc1852ab19 100644 --- a/libs/proxy/tokio-postgres2/src/client.rs +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -52,7 +52,7 @@ impl Responses { /// A cache of type info and prepared statements for fetching type info /// (corresponding to the queries in the [prepare] module). #[derive(Default)] -struct CachedTypeInfo { +pub(crate) struct CachedTypeInfo { /// A statement for basic information for a type from its /// OID. Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_QUERY) (or its /// fallback). @@ -68,10 +68,42 @@ struct CachedTypeInfo { /// Cache of types already looked up. types: HashMap, } +impl CachedTypeInfo { + pub(crate) fn typeinfo(&mut self) -> Option<&Statement> { + self.typeinfo.as_ref() + } + + pub(crate) fn set_typeinfo(&mut self, statement: Statement) -> &Statement { + self.typeinfo.insert(statement) + } + + pub(crate) fn typeinfo_composite(&mut self) -> Option<&Statement> { + self.typeinfo_composite.as_ref() + } + + pub(crate) fn set_typeinfo_composite(&mut self, statement: Statement) -> &Statement { + self.typeinfo_composite.insert(statement) + } + + pub(crate) fn typeinfo_enum(&mut self) -> Option<&Statement> { + self.typeinfo_enum.as_ref() + } + + pub(crate) fn set_typeinfo_enum(&mut self, statement: Statement) -> &Statement { + self.typeinfo_enum.insert(statement) + } + + pub(crate) fn type_(&mut self, oid: Oid) -> Option { + self.types.get(&oid).cloned() + } + + pub(crate) fn set_type(&mut self, oid: Oid, type_: &Type) { + self.types.insert(oid, type_.clone()); + } +} pub struct InnerClient { sender: mpsc::UnboundedSender, - cached_typeinfo: CachedTypeInfo, /// A buffer to use when writing out postgres commands. buffer: BytesMut, @@ -89,38 +121,6 @@ impl InnerClient { }) } - pub fn typeinfo(&mut self) -> Option { - self.cached_typeinfo.typeinfo.clone() - } - - pub fn set_typeinfo(&mut self, statement: &Statement) { - self.cached_typeinfo.typeinfo = Some(statement.clone()); - } - - pub fn typeinfo_composite(&mut self) -> Option { - self.cached_typeinfo.typeinfo_composite.clone() - } - - pub fn set_typeinfo_composite(&mut self, statement: &Statement) { - self.cached_typeinfo.typeinfo_composite = Some(statement.clone()); - } - - pub fn typeinfo_enum(&mut self) -> Option { - self.cached_typeinfo.typeinfo_enum.clone() - } - - pub fn set_typeinfo_enum(&mut self, statement: &Statement) { - self.cached_typeinfo.typeinfo_enum = Some(statement.clone()); - } - - pub fn type_(&mut self, oid: Oid) -> Option { - self.cached_typeinfo.types.get(&oid).cloned() - } - - pub fn set_type(&mut self, oid: Oid, type_: &Type) { - self.cached_typeinfo.types.insert(oid, type_.clone()); - } - /// Call the given function with a buffer to be used when writing out /// postgres commands. pub fn with_buf(&mut self, f: F) -> R @@ -146,7 +146,8 @@ pub struct SocketConfig { /// The client is one half of what is returned when a connection is established. Users interact with the database /// through this client object. pub struct Client { - inner: InnerClient, + pub(crate) inner: InnerClient, + pub(crate) cached_typeinfo: CachedTypeInfo, socket_config: SocketConfig, ssl_mode: SslMode, @@ -165,9 +166,9 @@ impl Client { Client { inner: InnerClient { sender, - cached_typeinfo: Default::default(), buffer: Default::default(), }, + cached_typeinfo: Default::default(), socket_config, ssl_mode, @@ -181,10 +182,6 @@ impl Client { self.process_id } - pub(crate) fn inner(&mut self) -> &mut InnerClient { - &mut self.inner - } - /// Pass text directly to the Postgres backend to allow it to sort out typing itself and /// to save a roundtrip pub async fn query_raw_txt( @@ -211,7 +208,7 @@ impl Client { /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass /// them to this method! pub async fn batch_execute(&mut self, query: &str) -> Result { - simple_query::batch_execute(self.inner(), query).await + simple_query::batch_execute(&mut self.inner, query).await } /// Begins a new database transaction. @@ -229,13 +226,13 @@ impl Client { return; } - let buf = self.client.inner().with_buf(|buf| { + let buf = self.client.inner.with_buf(|buf| { frontend::query("ROLLBACK", buf).unwrap(); buf.split().freeze() }); let _ = self .client - .inner() + .inner .send(RequestMessages::Single(FrontendMessage::Raw(buf))); } } @@ -276,11 +273,6 @@ impl Client { } } - /// Query for type information - pub async fn get_type(&mut self, oid: Oid) -> Result { - crate::prepare::get_type(&mut self.inner, oid).await - } - /// Determines if the connection to the server has already closed. /// /// In that case, all future queries will fail. diff --git a/libs/proxy/tokio-postgres2/src/generic_client.rs b/libs/proxy/tokio-postgres2/src/generic_client.rs index 8f331fab26..1a5bf04be5 100644 --- a/libs/proxy/tokio-postgres2/src/generic_client.rs +++ b/libs/proxy/tokio-postgres2/src/generic_client.rs @@ -39,7 +39,7 @@ impl GenericClient for Client { /// Query for type information async fn get_type(&mut self, oid: Oid) -> Result { - crate::prepare::get_type(self.inner(), oid).await + crate::prepare::get_type(&mut self.inner, &mut self.cached_typeinfo, oid).await } } @@ -59,6 +59,7 @@ impl GenericClient for Transaction<'_> { /// Query for type information async fn get_type(&mut self, oid: Oid) -> Result { - crate::prepare::get_type(self.client().inner(), oid).await + let client = self.client(); + crate::prepare::get_type(&mut client.inner, &mut client.cached_typeinfo, oid).await } } diff --git a/libs/proxy/tokio-postgres2/src/prepare.rs b/libs/proxy/tokio-postgres2/src/prepare.rs index d5bb45fc45..562e9fa18b 100644 --- a/libs/proxy/tokio-postgres2/src/prepare.rs +++ b/libs/proxy/tokio-postgres2/src/prepare.rs @@ -1,4 +1,4 @@ -use crate::client::InnerClient; +use crate::client::{CachedTypeInfo, InnerClient}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::error::SqlState; @@ -7,12 +7,12 @@ use crate::{query, slice_iter}; use crate::{Column, Error, Statement}; use bytes::Bytes; use fallible_iterator::FallibleIterator; -use futures_util::{pin_mut, TryStreamExt}; +use futures_util::{pin_mut, StreamExt, TryStreamExt}; use log::debug; use postgres_protocol2::message::backend::Message; use postgres_protocol2::message::frontend; use std::future::Future; -use std::pin::Pin; +use std::pin::{pin, Pin}; use std::sync::atomic::{AtomicUsize, Ordering}; pub(crate) const TYPEINFO_QUERY: &str = "\ @@ -59,6 +59,7 @@ static NEXT_ID: AtomicUsize = AtomicUsize::new(0); pub async fn prepare( client: &mut InnerClient, + cache: &mut CachedTypeInfo, query: &str, types: &[Type], ) -> Result { @@ -85,7 +86,7 @@ pub async fn prepare( let mut parameters = vec![]; let mut it = parameter_description.parameters(); while let Some(oid) = it.next().map_err(Error::parse)? { - let type_ = get_type(client, oid).await?; + let type_ = get_type(client, cache, oid).await?; parameters.push(type_); } @@ -93,7 +94,7 @@ pub async fn prepare( if let Some(row_description) = row_description { let mut it = row_description.fields(); while let Some(field) = it.next().map_err(Error::parse)? { - let type_ = get_type(client, field.type_oid()).await?; + let type_ = get_type(client, cache, field.type_oid()).await?; let column = Column::new(field.name().to_string(), type_, field); columns.push(column); } @@ -104,13 +105,19 @@ pub async fn prepare( fn prepare_rec<'a>( client: &'a mut InnerClient, + cache: &'a mut CachedTypeInfo, query: &'a str, types: &'a [Type], ) -> Pin> + 'a + Send>> { - Box::pin(prepare(client, query, types)) + Box::pin(prepare(client, cache, query, types)) } -fn encode(client: &mut InnerClient, name: &str, query: &str, types: &[Type]) -> Result { +fn encode( + client: &mut InnerClient, + name: &str, + query: &str, + types: &[Type], +) -> Result { if types.is_empty() { debug!("preparing query {}: {}", name, query); } else { @@ -125,16 +132,20 @@ fn encode(client: &mut InnerClient, name: &str, query: &str, types: &[Type]) -> }) } -pub async fn get_type(client: &mut InnerClient, oid: Oid) -> Result { +pub async fn get_type( + client: &mut InnerClient, + cache: &mut CachedTypeInfo, + oid: Oid, +) -> Result { if let Some(type_) = Type::from_oid(oid) { return Ok(type_); } - if let Some(type_) = client.type_(oid) { + if let Some(type_) = cache.type_(oid) { return Ok(type_); } - let stmt = typeinfo_statement(client).await?; + let stmt = typeinfo_statement(client, cache).await?; let rows = query::query(client, stmt, slice_iter(&[&oid])).await?; pin_mut!(rows); @@ -144,118 +155,141 @@ pub async fn get_type(client: &mut InnerClient, oid: Oid) -> Result None => return Err(Error::unexpected_message()), }; - let name: String = row.try_get(0)?; - let type_: i8 = row.try_get(1)?; - let elem_oid: Oid = row.try_get(2)?; - let rngsubtype: Option = row.try_get(3)?; - let basetype: Oid = row.try_get(4)?; - let schema: String = row.try_get(5)?; - let relid: Oid = row.try_get(6)?; + let name: String = row.try_get(stmt.columns(), 0)?; + let type_: i8 = row.try_get(stmt.columns(), 1)?; + let elem_oid: Oid = row.try_get(stmt.columns(), 2)?; + let rngsubtype: Option = row.try_get(stmt.columns(), 3)?; + let basetype: Oid = row.try_get(stmt.columns(), 4)?; + let schema: String = row.try_get(stmt.columns(), 5)?; + let relid: Oid = row.try_get(stmt.columns(), 6)?; let kind = if type_ == b'e' as i8 { - let variants = get_enum_variants(client, oid).await?; + let variants = get_enum_variants(client, cache, oid).await?; Kind::Enum(variants) } else if type_ == b'p' as i8 { Kind::Pseudo } else if basetype != 0 { - let type_ = get_type_rec(client, basetype).await?; + let type_ = get_type_rec(client, cache, basetype).await?; Kind::Domain(type_) } else if elem_oid != 0 { - let type_ = get_type_rec(client, elem_oid).await?; + let type_ = get_type_rec(client, cache, elem_oid).await?; Kind::Array(type_) } else if relid != 0 { - let fields = get_composite_fields(client, relid).await?; + let fields = get_composite_fields(client, cache, relid).await?; Kind::Composite(fields) } else if let Some(rngsubtype) = rngsubtype { - let type_ = get_type_rec(client, rngsubtype).await?; + let type_ = get_type_rec(client, cache, rngsubtype).await?; Kind::Range(type_) } else { Kind::Simple }; let type_ = Type::new(name, oid, kind, schema); - client.set_type(oid, &type_); + cache.set_type(oid, &type_); Ok(type_) } fn get_type_rec<'a>( client: &'a mut InnerClient, + cache: &'a mut CachedTypeInfo, oid: Oid, ) -> Pin> + Send + 'a>> { - Box::pin(get_type(client, oid)) + Box::pin(get_type(client, cache, oid)) } -async fn typeinfo_statement(client: &mut InnerClient) -> Result { - if let Some(stmt) = client.typeinfo() { - return Ok(stmt); +async fn typeinfo_statement<'c>( + client: &mut InnerClient, + cache: &'c mut CachedTypeInfo, +) -> Result<&'c Statement, Error> { + if cache.typeinfo().is_some() { + // needed to get around a borrow checker limitation + return Ok(cache.typeinfo().unwrap()); } - let stmt = match prepare_rec(client, TYPEINFO_QUERY, &[]).await { + let stmt = match prepare_rec(client, cache, TYPEINFO_QUERY, &[]).await { Ok(stmt) => stmt, Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_TABLE) => { - prepare_rec(client, TYPEINFO_FALLBACK_QUERY, &[]).await? + prepare_rec(client, cache, TYPEINFO_FALLBACK_QUERY, &[]).await? } Err(e) => return Err(e), }; - client.set_typeinfo(&stmt); - Ok(stmt) + Ok(cache.set_typeinfo(stmt)) } -async fn get_enum_variants(client: &mut InnerClient, oid: Oid) -> Result, Error> { - let stmt = typeinfo_enum_statement(client).await?; +async fn get_enum_variants( + client: &mut InnerClient, + cache: &mut CachedTypeInfo, + oid: Oid, +) -> Result, Error> { + let stmt = typeinfo_enum_statement(client, cache).await?; - query::query(client, stmt, slice_iter(&[&oid])) - .await? - .and_then(|row| async move { row.try_get(0) }) - .try_collect() - .await + let mut out = vec![]; + + let mut rows = pin!(query::query(client, stmt, slice_iter(&[&oid])).await?); + while let Some(row) = rows.next().await { + out.push(row?.try_get(stmt.columns(), 0)?) + } + Ok(out) } -async fn typeinfo_enum_statement(client: &mut InnerClient) -> Result { - if let Some(stmt) = client.typeinfo_enum() { - return Ok(stmt); +async fn typeinfo_enum_statement<'c>( + client: &mut InnerClient, + cache: &'c mut CachedTypeInfo, +) -> Result<&'c Statement, Error> { + if cache.typeinfo_enum().is_some() { + // needed to get around a borrow checker limitation + return Ok(cache.typeinfo_enum().unwrap()); } - let stmt = match prepare_rec(client, TYPEINFO_ENUM_QUERY, &[]).await { + let stmt = match prepare_rec(client, cache, TYPEINFO_ENUM_QUERY, &[]).await { Ok(stmt) => stmt, Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_COLUMN) => { - prepare_rec(client, TYPEINFO_ENUM_FALLBACK_QUERY, &[]).await? + prepare_rec(client, cache, TYPEINFO_ENUM_FALLBACK_QUERY, &[]).await? } Err(e) => return Err(e), }; - client.set_typeinfo_enum(&stmt); - Ok(stmt) + Ok(cache.set_typeinfo_enum(stmt)) } -async fn get_composite_fields(client: &mut InnerClient, oid: Oid) -> Result, Error> { - let stmt = typeinfo_composite_statement(client).await?; +async fn get_composite_fields( + client: &mut InnerClient, + cache: &mut CachedTypeInfo, + oid: Oid, +) -> Result, Error> { + let stmt = typeinfo_composite_statement(client, cache).await?; - let rows = query::query(client, stmt, slice_iter(&[&oid])) - .await? - .try_collect::>() - .await?; + let mut rows = pin!(query::query(client, stmt, slice_iter(&[&oid])).await?); + + let mut oids = vec![]; + while let Some(row) = rows.next().await { + let row = row?; + let name = row.try_get(stmt.columns(), 0)?; + let oid = row.try_get(stmt.columns(), 1)?; + oids.push((name, oid)); + } let mut fields = vec![]; - for row in rows { - let name = row.try_get(0)?; - let oid = row.try_get(1)?; - let type_ = get_type_rec(client, oid).await?; + for (name, oid) in oids { + let type_ = get_type_rec(client, cache, oid).await?; fields.push(Field::new(name, type_)); } Ok(fields) } -async fn typeinfo_composite_statement(client: &mut InnerClient) -> Result { - if let Some(stmt) = client.typeinfo_composite() { - return Ok(stmt); +async fn typeinfo_composite_statement<'c>( + client: &mut InnerClient, + cache: &'c mut CachedTypeInfo, +) -> Result<&'c Statement, Error> { + if cache.typeinfo_composite().is_some() { + // needed to get around a borrow checker limitation + return Ok(cache.typeinfo_composite().unwrap()); } - let stmt = prepare_rec(client, TYPEINFO_COMPOSITE_QUERY, &[]).await?; + let stmt = prepare_rec(client, cache, TYPEINFO_COMPOSITE_QUERY, &[]).await?; - client.set_typeinfo_composite(&stmt); - Ok(stmt) + Ok(cache.set_typeinfo_composite(stmt)) } diff --git a/libs/proxy/tokio-postgres2/src/query.rs b/libs/proxy/tokio-postgres2/src/query.rs index b29e441a2b..72b00e1912 100644 --- a/libs/proxy/tokio-postgres2/src/query.rs +++ b/libs/proxy/tokio-postgres2/src/query.rs @@ -26,9 +26,9 @@ impl fmt::Debug for BorrowToSqlParamsDebug<'_> { pub async fn query<'a, I>( client: &mut InnerClient, - statement: Statement, + statement: &Statement, params: I, -) -> Result +) -> Result where I: IntoIterator, I::IntoIter: ExactSizeIterator, @@ -40,13 +40,12 @@ where statement.name(), BorrowToSqlParamsDebug(params.as_slice()), ); - encode(client, &statement, params)? + encode(client, statement, params)? } else { - encode(client, &statement, params)? + encode(client, statement, params)? }; let responses = start(client, buf).await?; - Ok(RowStream { - statement, + Ok(RawRowStream { responses, command_tag: None, status: ReadyForQueryStatus::Unknown, @@ -167,7 +166,11 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result { Ok(responses) } -pub fn encode<'a, I>(client: &mut InnerClient, statement: &Statement, params: I) -> Result +pub fn encode<'a, I>( + client: &mut InnerClient, + statement: &Statement, + params: I, +) -> Result where I: IntoIterator, I::IntoIter: ExactSizeIterator, @@ -252,11 +255,7 @@ impl Stream for RowStream { loop { match ready!(this.responses.poll_next(cx)?) { Message::DataRow(body) => { - return Poll::Ready(Some(Ok(Row::new( - this.statement.clone(), - body, - *this.output_format, - )?))) + return Poll::Ready(Some(Ok(Row::new(body, *this.output_format)?))) } Message::EmptyQueryResponse | Message::PortalSuspended => {} Message::CommandComplete(body) => { @@ -294,3 +293,41 @@ impl RowStream { self.status } } + +pin_project! { + /// A stream of table rows. + pub struct RawRowStream { + responses: Responses, + command_tag: Option, + output_format: Format, + status: ReadyForQueryStatus, + #[pin] + _p: PhantomPinned, + } +} + +impl Stream for RawRowStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + loop { + match ready!(this.responses.poll_next(cx)?) { + Message::DataRow(body) => { + return Poll::Ready(Some(Ok(Row::new(body, *this.output_format)?))) + } + Message::EmptyQueryResponse | Message::PortalSuspended => {} + Message::CommandComplete(body) => { + if let Ok(tag) = body.tag() { + *this.command_tag = Some(tag.to_string()); + } + } + Message::ReadyForQuery(status) => { + *this.status = status.into(); + return Poll::Ready(None); + } + _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), + } + } + } +} diff --git a/libs/proxy/tokio-postgres2/src/row.rs b/libs/proxy/tokio-postgres2/src/row.rs index b8c5e84b4f..230a3eaaa7 100644 --- a/libs/proxy/tokio-postgres2/src/row.rs +++ b/libs/proxy/tokio-postgres2/src/row.rs @@ -1,7 +1,7 @@ //! Rows. use crate::statement::Column; use crate::types::{FromSql, Type, WrongType}; -use crate::{Error, Statement}; +use crate::Error; use fallible_iterator::FallibleIterator; use postgres_protocol2::message::backend::DataRowBody; use postgres_types2::{Format, WrongFormat}; @@ -11,7 +11,6 @@ use std::str; /// A row of data returned from the database by a query. pub struct Row { - statement: Statement, output_format: Format, body: DataRowBody, ranges: Vec>>, @@ -19,72 +18,29 @@ pub struct Row { impl fmt::Debug for Row { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Row") - .field("columns", &self.columns()) - .finish() + f.debug_struct("Row").finish() } } impl Row { pub(crate) fn new( - statement: Statement, + // statement: Statement, body: DataRowBody, output_format: Format, ) -> Result { let ranges = body.ranges().collect().map_err(Error::parse)?; Ok(Row { - statement, body, ranges, output_format, }) } - /// Returns information about the columns of data in the row. - pub fn columns(&self) -> &[Column] { - self.statement.columns() - } - - /// Determines if the row contains no values. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Returns the number of values in the row. - pub fn len(&self) -> usize { - self.columns().len() - } - - /// Deserializes a value from the row. - /// - /// The value can be specified either by its numeric index in the row, or by its column name. - /// - /// # Panics - /// - /// Panics if the index is out of bounds or if the value cannot be converted to the specified type. - pub fn get<'a, T>(&'a self, idx: usize) -> T + pub(crate) fn try_get<'a, T>(&'a self, columns: &[Column], idx: usize) -> Result where T: FromSql<'a>, { - match self.get_inner(idx) { - Ok(ok) => ok, - Err(err) => panic!("error retrieving column {}: {}", idx, err), - } - } - - /// Like `Row::get`, but returns a `Result` rather than panicking. - pub fn try_get<'a, T>(&'a self, idx: usize) -> Result - where - T: FromSql<'a>, - { - self.get_inner(idx) - } - - fn get_inner<'a, T>(&'a self, idx: usize) -> Result - where - T: FromSql<'a>, - { - let Some(column) = self.columns().get(idx) else { + let Some(column) = columns.get(idx) else { return Err(Error::column(idx.to_string())); }; diff --git a/libs/proxy/tokio-postgres2/src/statement.rs b/libs/proxy/tokio-postgres2/src/statement.rs index 6e5cdfe02e..4673bd029c 100644 --- a/libs/proxy/tokio-postgres2/src/statement.rs +++ b/libs/proxy/tokio-postgres2/src/statement.rs @@ -1,6 +1,6 @@ use crate::types::Type; use postgres_protocol2::{message::backend::Field, Oid}; -use std::{fmt, sync::Arc}; +use std::fmt; struct StatementInner { name: String, @@ -11,24 +11,23 @@ struct StatementInner { /// A prepared statement. /// /// Prepared statements can only be used with the connection that created them. -#[derive(Clone)] -pub struct Statement(Arc); +pub struct Statement(StatementInner); impl Statement { pub(crate) fn new(name: String, params: Vec, columns: Vec) -> Statement { - Statement(Arc::new(StatementInner { + Statement(StatementInner { name, params, columns, - })) + }) } pub(crate) fn new_anonymous(params: Vec, columns: Vec) -> Statement { - Statement(Arc::new(StatementInner { + Statement(StatementInner { name: String::new(), params, columns, - })) + }) } pub(crate) fn name(&self) -> &str { diff --git a/libs/proxy/tokio-postgres2/src/transaction.rs b/libs/proxy/tokio-postgres2/src/transaction.rs index b150f6a371..ed2087c993 100644 --- a/libs/proxy/tokio-postgres2/src/transaction.rs +++ b/libs/proxy/tokio-postgres2/src/transaction.rs @@ -19,13 +19,13 @@ impl Drop for Transaction<'_> { return; } - let buf = self.client.inner().with_buf(|buf| { + let buf = self.client.inner.with_buf(|buf| { frontend::query("ROLLBACK", buf).unwrap(); buf.split().freeze() }); let _ = self .client - .inner() + .inner .send(RequestMessages::Single(FrontendMessage::Raw(buf))); } } @@ -53,7 +53,11 @@ impl<'a> Transaction<'a> { } /// Like `Client::query_raw_txt`. - pub async fn query_raw_txt(&mut self, statement: &str, params: I) -> Result + pub async fn query_raw_txt( + &mut self, + statement: &str, + params: I, + ) -> Result where S: AsRef, I: IntoIterator>, diff --git a/proxy/src/serverless/json.rs b/proxy/src/serverless/json.rs index 25b25c66d3..282e11a078 100644 --- a/proxy/src/serverless/json.rs +++ b/proxy/src/serverless/json.rs @@ -1,5 +1,5 @@ use postgres_client::types::{Kind, Type}; -use postgres_client::Row; +use postgres_client::{Column, Row}; use serde_json::{Map, Value}; // @@ -77,14 +77,14 @@ pub(crate) enum JsonConversionError { // pub(crate) fn pg_text_row_to_json( row: &Row, - columns: &[Type], + columns: &[Column], + c_types: &[Type], raw_output: bool, array_mode: bool, ) -> Result { - let iter = row - .columns() + let iter = columns .iter() - .zip(columns) + .zip(c_types) .enumerate() .map(|(i, (column, typ))| { let name = column.name(); diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 7bc650ba58..85100fd9b7 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -797,7 +797,13 @@ impl QueryData { let cancel_token = inner.cancel_token(); let res = match select( - pin!(query_to_json(config, &mut *inner, self, &mut 0, parsed_headers)), + pin!(query_to_json( + config, + &mut *inner, + self, + &mut 0, + parsed_headers + )), pin!(cancel.cancelled()), ) .await @@ -1027,7 +1033,7 @@ async fn query_to_json( let columns_len = row_stream.columns().len(); let mut fields = Vec::with_capacity(columns_len); - let mut columns = Vec::with_capacity(columns_len); + let mut c_types = Vec::with_capacity(columns_len); for c in row_stream.columns() { fields.push(json!({ @@ -1039,7 +1045,7 @@ async fn query_to_json( "dataTypeModifier": c.type_modifier(), "format": "text", })); - columns.push(client.get_type(c.type_oid()).await?); + c_types.push(client.get_type(c.type_oid()).await?); } let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode); @@ -1047,7 +1053,15 @@ async fn query_to_json( // convert rows to JSON let rows = rows .iter() - .map(|row| pg_text_row_to_json(row, &columns, parsed_headers.raw_output, array_mode)) + .map(|row| { + pg_text_row_to_json( + row, + row_stream.columns(), + &c_types, + parsed_headers.raw_output, + array_mode, + ) + }) .collect::, _>>()?; // Resulting JSON format is based on the format of node-postgres result.