From 6768a71c8656dd9bb28bcd57f042c7306cb23c9e Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 23 May 2025 20:41:12 +0100 Subject: [PATCH] proxy(tokio-postgres): refactor typeinfo query to occur earlier (#11993) ## Problem For #11992 I realised we need to get the type info before executing the query. This is important to know how to decode rows with custom types, e.g. the following query: ```sql CREATE TYPE foo AS ENUM ('foo','bar','baz'); SELECT ARRAY['foo'::foo, 'bar'::foo, 'baz'::foo] AS data; ``` Getting that to work was harder than it seems. The original tokio-postgres setup has a split between `Client` and `Connection`, where messages are passed between them. Because multiple clients were supported, each client message included a dedicated response channel. Each request would be terminated by the `ReadyForQuery` message. The flow I opted to use for parsing types early would not trigger a `ReadyForQuery`. The flow is as follows: ``` PARSE "" // parse the user provided query DESCRIBE "" // describe the query, returning param/result type oids FLUSH // force postgres to flush the responses early // wait for descriptions // check if we know the types, if we don't then // setup the typeinfo query and execute it against each OID: PARSE typeinfo // prepare our typeinfo query DESCRIBE typeinfo FLUSH // force postgres to flush the responses early // wait for typeinfo statement // for each OID we don't know: BIND typeinfo EXECUTE FLUSH // wait for type info, might reveal more OIDs to inspect // close the typeinfo query, we cache the OID->type map and this is kinder to pgbouncer. CLOSE typeinfo // finally once we know all the OIDs: BIND "" // bind the user provided query - already parsed - to the user provided params EXECUTE // run the user provided query SYNC // commit the transaction ``` ## Summary of changes Please review commit by commit. The main challenge was allowing one query to issue multiple sub-queries. 
To do this I first made sure that the client could fully own the connection, which required removing any shared client state. I then had to replace the way responses are sent to the client, by using only a single permanent channel. This required some additional effort to track which query is being processed. Lastly I had to modify the query/typeinfo functions to not issue `sync` commands, so it would fit into the desired flow above. To note: the flow above does force an extra roundtrip into each query. I don't know yet if this has a measurable latency overhead. --- .../src/message/frontend.rs | 7 + libs/proxy/postgres-types2/src/lib.rs | 124 +------- libs/proxy/postgres-types2/src/type_gen.rs | 21 +- libs/proxy/tokio-postgres2/src/client.rs | 198 ++++++++---- libs/proxy/tokio-postgres2/src/codec.rs | 13 +- libs/proxy/tokio-postgres2/src/connect.rs | 8 +- libs/proxy/tokio-postgres2/src/connection.rs | 92 ++---- .../tokio-postgres2/src/generic_client.rs | 22 +- libs/proxy/tokio-postgres2/src/lib.rs | 7 - libs/proxy/tokio-postgres2/src/prepare.rs | 285 ++++++++++++------ libs/proxy/tokio-postgres2/src/query.rs | 271 ++++------------- .../proxy/tokio-postgres2/src/simple_query.rs | 36 +-- libs/proxy/tokio-postgres2/src/statement.rs | 54 +--- libs/proxy/tokio-postgres2/src/transaction.rs | 19 +- proxy/src/serverless/sql_over_http.rs | 88 +++--- 15 files changed, 500 insertions(+), 745 deletions(-) diff --git a/libs/proxy/postgres-protocol2/src/message/frontend.rs b/libs/proxy/postgres-protocol2/src/message/frontend.rs index b447290ea8..9faed2c065 100644 --- a/libs/proxy/postgres-protocol2/src/message/frontend.rs +++ b/libs/proxy/postgres-protocol2/src/message/frontend.rs @@ -25,6 +25,7 @@ where Ok(()) } +#[derive(Debug)] pub enum BindError { Conversion(Box), Serialization(io::Error), @@ -288,6 +289,12 @@ pub fn sync(buf: &mut BytesMut) { write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); } +#[inline] +pub fn flush(buf: &mut BytesMut) { + buf.put_u8(b'H'); + 
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); +} + #[inline] pub fn terminate(buf: &mut BytesMut) { buf.put_u8(b'X'); diff --git a/libs/proxy/postgres-types2/src/lib.rs b/libs/proxy/postgres-types2/src/lib.rs index b6bcabc922..7c9874bda3 100644 --- a/libs/proxy/postgres-types2/src/lib.rs +++ b/libs/proxy/postgres-types2/src/lib.rs @@ -9,7 +9,6 @@ use std::error::Error; use std::fmt; use std::sync::Arc; -use bytes::BytesMut; use fallible_iterator::FallibleIterator; #[doc(inline)] pub use postgres_protocol2::Oid; @@ -27,41 +26,6 @@ macro_rules! accepts { ) } -/// Generates an implementation of `ToSql::to_sql_checked`. -/// -/// All `ToSql` implementations should use this macro. -macro_rules! to_sql_checked { - () => { - fn to_sql_checked( - &self, - ty: &$crate::Type, - out: &mut $crate::private::BytesMut, - ) -> ::std::result::Result< - $crate::IsNull, - Box, - > { - $crate::__to_sql_checked(self, ty, out) - } - }; -} - -// WARNING: this function is not considered part of this crate's public API. -// It is subject to change at any time. -#[doc(hidden)] -pub fn __to_sql_checked( - v: &T, - ty: &Type, - out: &mut BytesMut, -) -> Result> -where - T: ToSql, -{ - if !T::accepts(ty) { - return Err(Box::new(WrongType::new::(ty.clone()))); - } - v.to_sql(ty, out) -} - // mod pg_lsn; #[doc(hidden)] pub mod private; @@ -142,7 +106,7 @@ pub enum Kind { /// An array type along with the type of its elements. Array(Type), /// A range type along with the type of its elements. - Range(Type), + Range(Oid), /// A multirange type along with the type of its elements. Multirange(Type), /// A domain type along with its underlying type. @@ -377,43 +341,6 @@ pub enum IsNull { No, } -/// A trait for types that can be converted into Postgres values. -pub trait ToSql: fmt::Debug { - /// Converts the value of `self` into the binary format of the specified - /// Postgres `Type`, appending it to `out`. 
- /// - /// The caller of this method is responsible for ensuring that this type - /// is compatible with the Postgres `Type`. - /// - /// The return value indicates if this value should be represented as - /// `NULL`. If this is the case, implementations **must not** write - /// anything to `out`. - fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result> - where - Self: Sized; - - /// Determines if a value of this type can be converted to the specified - /// Postgres `Type`. - fn accepts(ty: &Type) -> bool - where - Self: Sized; - - /// An adaptor method used internally by Rust-Postgres. - /// - /// *All* implementations of this method should be generated by the - /// `to_sql_checked!()` macro. - fn to_sql_checked( - &self, - ty: &Type, - out: &mut BytesMut, - ) -> Result>; - - /// Specify the encode format - fn encode_format(&self, _ty: &Type) -> Format { - Format::Binary - } -} - /// Supported Postgres message format types /// /// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8` @@ -424,52 +351,3 @@ pub enum Format { /// Compact, typed binary format Binary, } - -impl ToSql for &str { - fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { - match *ty { - ref ty if ty.name() == "ltree" => types::ltree_to_sql(self, w), - ref ty if ty.name() == "lquery" => types::lquery_to_sql(self, w), - ref ty if ty.name() == "ltxtquery" => types::ltxtquery_to_sql(self, w), - _ => types::text_to_sql(self, w), - } - Ok(IsNull::No) - } - - fn accepts(ty: &Type) -> bool { - match *ty { - Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true, - ref ty - if (ty.name() == "citext" - || ty.name() == "ltree" - || ty.name() == "lquery" - || ty.name() == "ltxtquery") => - { - true - } - _ => false, - } - } - - to_sql_checked!(); -} - -macro_rules! 
simple_to { - ($t:ty, $f:ident, $($expected:ident),+) => { - impl ToSql for $t { - fn to_sql(&self, - _: &Type, - w: &mut BytesMut) - -> Result> { - types::$f(*self, w); - Ok(IsNull::No) - } - - accepts!($($expected),+); - - to_sql_checked!(); - } - } -} - -simple_to!(u32, oid_to_sql, OID); diff --git a/libs/proxy/postgres-types2/src/type_gen.rs b/libs/proxy/postgres-types2/src/type_gen.rs index a1bc3f85c0..6e6163e343 100644 --- a/libs/proxy/postgres-types2/src/type_gen.rs +++ b/libs/proxy/postgres-types2/src/type_gen.rs @@ -393,7 +393,7 @@ impl Inner { } } - pub fn oid(&self) -> Oid { + pub const fn const_oid(&self) -> Oid { match *self { Inner::Bool => 16, Inner::Bytea => 17, @@ -580,7 +580,14 @@ impl Inner { Inner::TstzmultiRangeArray => 6153, Inner::DatemultiRangeArray => 6155, Inner::Int8multiRangeArray => 6157, + Inner::Other(_) => u32::MAX, + } + } + + pub fn oid(&self) -> Oid { + match *self { Inner::Other(ref u) => u.oid, + _ => self.const_oid(), } } @@ -727,17 +734,17 @@ impl Inner { Inner::JsonbArray => &Kind::Array(Type(Inner::Jsonb)), Inner::AnyRange => &Kind::Pseudo, Inner::EventTrigger => &Kind::Pseudo, - Inner::Int4Range => &Kind::Range(Type(Inner::Int4)), + Inner::Int4Range => &const { Kind::Range(Inner::Int4.const_oid()) }, Inner::Int4RangeArray => &Kind::Array(Type(Inner::Int4Range)), - Inner::NumRange => &Kind::Range(Type(Inner::Numeric)), + Inner::NumRange => &const { Kind::Range(Inner::Numeric.const_oid()) }, Inner::NumRangeArray => &Kind::Array(Type(Inner::NumRange)), - Inner::TsRange => &Kind::Range(Type(Inner::Timestamp)), + Inner::TsRange => &const { Kind::Range(Inner::Timestamp.const_oid()) }, Inner::TsRangeArray => &Kind::Array(Type(Inner::TsRange)), - Inner::TstzRange => &Kind::Range(Type(Inner::Timestamptz)), + Inner::TstzRange => &const { Kind::Range(Inner::Timestamptz.const_oid()) }, Inner::TstzRangeArray => &Kind::Array(Type(Inner::TstzRange)), - Inner::DateRange => &Kind::Range(Type(Inner::Date)), + Inner::DateRange => &const { 
Kind::Range(Inner::Date.const_oid()) }, Inner::DateRangeArray => &Kind::Array(Type(Inner::DateRange)), - Inner::Int8Range => &Kind::Range(Type(Inner::Int8)), + Inner::Int8Range => &const { Kind::Range(Inner::Int8.const_oid()) }, Inner::Int8RangeArray => &Kind::Array(Type(Inner::Int8Range)), Inner::Jsonpath => &Kind::Simple, Inner::JsonpathArray => &Kind::Array(Type(Inner::Jsonpath)), diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs index 186eb07000..a7edfc076a 100644 --- a/libs/proxy/tokio-postgres2/src/client.rs +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -1,14 +1,12 @@ use std::collections::HashMap; use std::fmt; use std::net::IpAddr; -use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; use bytes::BytesMut; use fallible_iterator::FallibleIterator; use futures_util::{TryStreamExt, future, ready}; -use parking_lot::Mutex; use postgres_protocol2::message::backend::Message; use postgres_protocol2::message::frontend; use serde::{Deserialize, Serialize}; @@ -16,29 +14,52 @@ use tokio::sync::mpsc; use crate::codec::{BackendMessages, FrontendMessage}; use crate::config::{Host, SslMode}; -use crate::connection::{Request, RequestMessages}; use crate::query::RowStream; use crate::simple_query::SimpleQueryStream; use crate::types::{Oid, Type}; use crate::{ - CancelToken, Error, ReadyForQueryStatus, SimpleQueryMessage, Statement, Transaction, - TransactionBuilder, query, simple_query, + CancelToken, Error, ReadyForQueryStatus, SimpleQueryMessage, Transaction, TransactionBuilder, + query, simple_query, }; pub struct Responses { + /// new messages from conn receiver: mpsc::Receiver, + /// current batch of messages cur: BackendMessages, + /// number of total queries sent. + waiting: usize, + /// number of ReadyForQuery messages received. + received: usize, } impl Responses { pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll> { loop { - match self.cur.next().map_err(Error::parse)? 
{ - Some(Message::ErrorResponse(body)) => return Poll::Ready(Err(Error::db(body))), - Some(message) => return Poll::Ready(Ok(message)), - None => {} + // get the next saved message + if let Some(message) = self.cur.next().map_err(Error::parse)? { + let received = self.received; + + // increase the query head if this is the last message. + if let Message::ReadyForQuery(_) = message { + self.received += 1; + } + + // check if the client has skipped this query. + if received + 1 < self.waiting { + // grab the next message. + continue; + } + + // convenience: turn the error messaage into a proper error. + let res = match message { + Message::ErrorResponse(body) => Err(Error::db(body)), + message => Ok(message), + }; + return Poll::Ready(res); } + // get the next batch of messages. match ready!(self.receiver.poll_recv(cx)) { Some(messages) => self.cur = messages, None => return Poll::Ready(Err(Error::closed())), @@ -55,44 +76,87 @@ impl Responses { /// (corresponding to the queries in the [crate::prepare] module). #[derive(Default)] pub(crate) struct CachedTypeInfo { - /// A statement for basic information for a type from its - /// OID. Corresponds to [TYPEINFO_QUERY](crate::prepare::TYPEINFO_QUERY) (or its - /// fallback). - pub(crate) typeinfo: Option, - /// Cache of types already looked up. pub(crate) types: HashMap, } pub struct InnerClient { - sender: mpsc::UnboundedSender, + sender: mpsc::UnboundedSender, + responses: Responses, /// A buffer to use when writing out postgres commands. 
- buffer: Mutex, + buffer: BytesMut, } impl InnerClient { - pub fn send(&self, messages: RequestMessages) -> Result { - let (sender, receiver) = mpsc::channel(1); - let request = Request { messages, sender }; - self.sender.send(request).map_err(|_| Error::closed())?; - - Ok(Responses { - receiver, - cur: BackendMessages::empty(), - }) + pub fn start(&mut self) -> Result { + self.responses.waiting += 1; + Ok(PartialQuery(Some(self))) } - /// Call the given function with a buffer to be used when writing out - /// postgres commands. - pub fn with_buf(&self, f: F) -> R + // pub fn send_with_sync(&mut self, f: F) -> Result<&mut Responses, Error> + // where + // F: FnOnce(&mut BytesMut) -> Result<(), Error>, + // { + // self.start()?.send_with_sync(f) + // } + + pub fn send_simple_query(&mut self, query: &str) -> Result<&mut Responses, Error> { + self.responses.waiting += 1; + + self.buffer.clear(); + // simple queries do not need sync. + frontend::query(query, &mut self.buffer).map_err(Error::encode)?; + let buf = self.buffer.split().freeze(); + self.send_message(FrontendMessage::Raw(buf)) + } + + fn send_message(&mut self, messages: FrontendMessage) -> Result<&mut Responses, Error> { + self.sender.send(messages).map_err(|_| Error::closed())?; + Ok(&mut self.responses) + } +} + +pub struct PartialQuery<'a>(Option<&'a mut InnerClient>); + +impl Drop for PartialQuery<'_> { + fn drop(&mut self) { + if let Some(client) = self.0.take() { + client.buffer.clear(); + frontend::sync(&mut client.buffer); + let buf = client.buffer.split().freeze(); + let _ = client.send_message(FrontendMessage::Raw(buf)); + } + } +} + +impl<'a> PartialQuery<'a> { + pub fn send_with_flush(&mut self, f: F) -> Result<&mut Responses, Error> where - F: FnOnce(&mut BytesMut) -> R, + F: FnOnce(&mut BytesMut) -> Result<(), Error>, { - let mut buffer = self.buffer.lock(); - let r = f(&mut buffer); - buffer.clear(); - r + let client = self.0.as_deref_mut().unwrap(); + + client.buffer.clear(); + f(&mut 
client.buffer)?; + frontend::flush(&mut client.buffer); + let buf = client.buffer.split().freeze(); + client.send_message(FrontendMessage::Raw(buf)) + } + + pub fn send_with_sync(mut self, f: F) -> Result<&'a mut Responses, Error> + where + F: FnOnce(&mut BytesMut) -> Result<(), Error>, + { + let client = self.0.as_deref_mut().unwrap(); + + client.buffer.clear(); + f(&mut client.buffer)?; + frontend::sync(&mut client.buffer); + let buf = client.buffer.split().freeze(); + let _ = client.send_message(FrontendMessage::Raw(buf)); + + Ok(&mut self.0.take().unwrap().responses) } } @@ -109,7 +173,7 @@ pub struct SocketConfig { /// The client is one half of what is returned when a connection is established. Users interact with the database /// through this client object. pub struct Client { - inner: Arc, + inner: InnerClient, cached_typeinfo: CachedTypeInfo, socket_config: SocketConfig, @@ -120,17 +184,24 @@ pub struct Client { impl Client { pub(crate) fn new( - sender: mpsc::UnboundedSender, + sender: mpsc::UnboundedSender, + receiver: mpsc::Receiver, socket_config: SocketConfig, ssl_mode: SslMode, process_id: i32, secret_key: i32, ) -> Client { Client { - inner: Arc::new(InnerClient { + inner: InnerClient { sender, + responses: Responses { + receiver, + cur: BackendMessages::empty(), + waiting: 0, + received: 0, + }, buffer: Default::default(), - }), + }, cached_typeinfo: Default::default(), socket_config, @@ -145,19 +216,29 @@ impl Client { self.process_id } - pub(crate) fn inner(&self) -> &Arc { - &self.inner + pub(crate) fn inner_mut(&mut self) -> &mut InnerClient { + &mut self.inner } /// Pass text directly to the Postgres backend to allow it to sort out typing itself and /// to save a roundtrip - pub async fn query_raw_txt(&self, statement: &str, params: I) -> Result + pub async fn query_raw_txt( + &mut self, + statement: &str, + params: I, + ) -> Result where S: AsRef, I: IntoIterator>, I::IntoIter: ExactSizeIterator, { - query::query_txt(&self.inner, statement, 
params).await + query::query_txt( + &mut self.inner, + &mut self.cached_typeinfo, + statement, + params, + ) + .await } /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows. @@ -173,12 +254,15 @@ impl Client { /// Prepared statements should be use for any query which contains user-specified data, as they provided the /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass /// them to this method! - pub async fn simple_query(&self, query: &str) -> Result, Error> { + pub async fn simple_query(&mut self, query: &str) -> Result, Error> { self.simple_query_raw(query).await?.try_collect().await } - pub(crate) async fn simple_query_raw(&self, query: &str) -> Result { - simple_query::simple_query(self.inner(), query).await + pub(crate) async fn simple_query_raw( + &mut self, + query: &str, + ) -> Result { + simple_query::simple_query(self.inner_mut(), query).await } /// Executes a sequence of SQL statements using the simple query protocol. @@ -191,15 +275,11 @@ impl Client { /// Prepared statements should be use for any query which contains user-specified data, as they provided the /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass /// them to this method! - pub async fn batch_execute(&self, query: &str) -> Result { - simple_query::batch_execute(self.inner(), query).await + pub async fn batch_execute(&mut self, query: &str) -> Result { + simple_query::batch_execute(self.inner_mut(), query).await } pub async fn discard_all(&mut self) -> Result { - // clear the prepared statements that are about to be nuked from the postgres session - - self.cached_typeinfo.typeinfo = None; - self.batch_execute("discard all").await } @@ -208,7 +288,7 @@ impl Client { /// The transaction will roll back by default - use the `commit` method to commit it. 
pub async fn transaction(&mut self) -> Result, Error> { struct RollbackIfNotDone<'me> { - client: &'me Client, + client: &'me mut Client, done: bool, } @@ -218,14 +298,7 @@ impl Client { return; } - let buf = self.client.inner().with_buf(|buf| { - frontend::query("ROLLBACK", buf).unwrap(); - buf.split().freeze() - }); - let _ = self - .client - .inner() - .send(RequestMessages::Single(FrontendMessage::Raw(buf))); + let _ = self.client.inner.send_simple_query("ROLLBACK"); } } @@ -239,7 +312,7 @@ impl Client { client: self, done: false, }; - self.batch_execute("BEGIN").await?; + cleaner.client.batch_execute("BEGIN").await?; cleaner.done = true; } @@ -265,11 +338,6 @@ impl Client { } } - /// Query for type information - pub(crate) async fn get_type_inner(&mut self, oid: Oid) -> Result { - crate::prepare::get_type(&self.inner, &mut self.cached_typeinfo, oid).await - } - /// Determines if the connection to the server has already closed. /// /// In that case, all future queries will fail. diff --git a/libs/proxy/tokio-postgres2/src/codec.rs b/libs/proxy/tokio-postgres2/src/codec.rs index f1fd9b47b3..daa5371426 100644 --- a/libs/proxy/tokio-postgres2/src/codec.rs +++ b/libs/proxy/tokio-postgres2/src/codec.rs @@ -1,21 +1,16 @@ use std::io; -use bytes::{Buf, Bytes, BytesMut}; +use bytes::{Bytes, BytesMut}; use fallible_iterator::FallibleIterator; use postgres_protocol2::message::backend; -use postgres_protocol2::message::frontend::CopyData; use tokio_util::codec::{Decoder, Encoder}; pub enum FrontendMessage { Raw(Bytes), - CopyData(CopyData>), } pub enum BackendMessage { - Normal { - messages: BackendMessages, - request_complete: bool, - }, + Normal { messages: BackendMessages }, Async(backend::Message), } @@ -44,7 +39,6 @@ impl Encoder for PostgresCodec { fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> { match item { FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf), - FrontendMessage::CopyData(data) => data.write(dst), } Ok(()) @@ 
-57,7 +51,6 @@ impl Decoder for PostgresCodec { fn decode(&mut self, src: &mut BytesMut) -> Result, io::Error> { let mut idx = 0; - let mut request_complete = false; while let Some(header) = backend::Header::parse(&src[idx..])? { let len = header.len() as usize + 1; @@ -82,7 +75,6 @@ impl Decoder for PostgresCodec { idx += len; if header.tag() == backend::READY_FOR_QUERY_TAG { - request_complete = true; break; } } @@ -92,7 +84,6 @@ impl Decoder for PostgresCodec { } else { Ok(Some(BackendMessage::Normal { messages: BackendMessages(src.split_to(idx)), - request_complete, })) } } diff --git a/libs/proxy/tokio-postgres2/src/connect.rs b/libs/proxy/tokio-postgres2/src/connect.rs index 7c3a358bba..39a0a87c74 100644 --- a/libs/proxy/tokio-postgres2/src/connect.rs +++ b/libs/proxy/tokio-postgres2/src/connect.rs @@ -59,9 +59,11 @@ where connect_timeout: config.connect_timeout, }; - let (sender, receiver) = mpsc::unbounded_channel(); + let (client_tx, conn_rx) = mpsc::unbounded_channel(); + let (conn_tx, client_rx) = mpsc::channel(4); let client = Client::new( - sender, + client_tx, + client_rx, socket_config, config.ssl_mode, process_id, @@ -74,7 +76,7 @@ where .map(|m| BackendMessage::Async(Message::NoticeResponse(m))) .collect(); - let connection = Connection::new(stream, delayed, parameters, receiver); + let connection = Connection::new(stream, delayed, parameters, conn_tx, conn_rx); Ok((client, connection)) } diff --git a/libs/proxy/tokio-postgres2/src/connection.rs b/libs/proxy/tokio-postgres2/src/connection.rs index 99d6f3f8e2..fe0372b266 100644 --- a/libs/proxy/tokio-postgres2/src/connection.rs +++ b/libs/proxy/tokio-postgres2/src/connection.rs @@ -4,7 +4,6 @@ use std::pin::Pin; use std::task::{Context, Poll}; use bytes::BytesMut; -use fallible_iterator::FallibleIterator; use futures_util::{Sink, Stream, ready}; use postgres_protocol2::message::backend::Message; use postgres_protocol2::message::frontend; @@ -19,30 +18,12 @@ use crate::error::DbError; use 
crate::maybe_tls_stream::MaybeTlsStream; use crate::{AsyncMessage, Error, Notification}; -pub enum RequestMessages { - Single(FrontendMessage), -} - -pub struct Request { - pub messages: RequestMessages, - pub sender: mpsc::Sender, -} - -pub struct Response { - sender: PollSender, -} - #[derive(PartialEq, Debug)] enum State { Active, Closing, } -enum WriteReady { - Terminating, - WaitingOnRead, -} - /// A connection to a PostgreSQL database. /// /// This is one half of what is returned when a new connection is established. It performs the actual IO with the @@ -56,9 +37,11 @@ pub struct Connection { pub stream: Framed, PostgresCodec>, /// HACK: we need this in the Neon Proxy to forward params. pub parameters: HashMap, - receiver: mpsc::UnboundedReceiver, + + sender: PollSender, + receiver: mpsc::UnboundedReceiver, + pending_responses: VecDeque, - responses: VecDeque, state: State, } @@ -71,14 +54,15 @@ where stream: Framed, PostgresCodec>, pending_responses: VecDeque, parameters: HashMap, - receiver: mpsc::UnboundedReceiver, + sender: mpsc::Sender, + receiver: mpsc::UnboundedReceiver, ) -> Connection { Connection { stream, parameters, + sender: PollSender::new(sender), receiver, pending_responses, - responses: VecDeque::new(), state: State::Active, } } @@ -110,7 +94,7 @@ where } }; - let (mut messages, request_complete) = match message { + let messages = match message { BackendMessage::Async(Message::NoticeResponse(body)) => { let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?; return Poll::Ready(Ok(AsyncMessage::Notice(error))); @@ -131,41 +115,19 @@ where continue; } BackendMessage::Async(_) => unreachable!(), - BackendMessage::Normal { - messages, - request_complete, - } => (messages, request_complete), + BackendMessage::Normal { messages } => messages, }; - let mut response = match self.responses.pop_front() { - Some(response) => response, - None => match messages.next().map_err(Error::parse)? 
{ - Some(Message::ErrorResponse(error)) => { - return Poll::Ready(Err(Error::db(error))); - } - _ => return Poll::Ready(Err(Error::unexpected_message())), - }, - }; - - match response.sender.poll_reserve(cx) { + match self.sender.poll_reserve(cx) { Poll::Ready(Ok(())) => { - let _ = response.sender.send_item(messages); - if !request_complete { - self.responses.push_front(response); - } + let _ = self.sender.send_item(messages); } Poll::Ready(Err(_)) => { - // we need to keep paging through the rest of the messages even if the receiver's hung up - if !request_complete { - self.responses.push_front(response); - } + return Poll::Ready(Err(Error::closed())); } Poll::Pending => { - self.responses.push_front(response); - self.pending_responses.push_back(BackendMessage::Normal { - messages, - request_complete, - }); + self.pending_responses + .push_back(BackendMessage::Normal { messages }); trace!("poll_read: waiting on sender"); return Poll::Pending; } @@ -174,7 +136,7 @@ where } /// Fetch the next client request and enqueue the response sender. - fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll> { if self.receiver.is_closed() { return Poll::Ready(None); } @@ -182,10 +144,7 @@ where match self.receiver.poll_recv(cx) { Poll::Ready(Some(request)) => { trace!("polled new request"); - self.responses.push_back(Response { - sender: PollSender::new(request.sender), - }); - Poll::Ready(Some(request.messages)) + Poll::Ready(Some(request)) } Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, @@ -194,7 +153,7 @@ where /// Process client requests and write them to the postgres connection, flushing if necessary. 
/// client -> postgres - fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll> { loop { if Pin::new(&mut self.stream) .poll_ready(cx) @@ -209,14 +168,14 @@ where match self.poll_request(cx) { // send the message to postgres - Poll::Ready(Some(RequestMessages::Single(request))) => { + Poll::Ready(Some(request)) => { Pin::new(&mut self.stream) .start_send(request) .map_err(Error::io)?; } // No more messages from the client, and no more responses to wait for. // Send a terminate message to postgres - Poll::Ready(None) if self.responses.is_empty() => { + Poll::Ready(None) => { trace!("poll_write: at eof, terminating"); let mut request = BytesMut::new(); frontend::terminate(&mut request); @@ -228,16 +187,7 @@ where trace!("poll_write: sent eof, closing"); trace!("poll_write: done"); - return Poll::Ready(Ok(WriteReady::Terminating)); - } - // No more messages from the client, but there are still some responses to wait for. - Poll::Ready(None) => { - trace!( - "poll_write: at eof, pending responses {}", - self.responses.len() - ); - ready!(self.poll_flush(cx))?; - return Poll::Ready(Ok(WriteReady::WaitingOnRead)); + return Poll::Ready(Ok(())); } // Still waiting for a message from the client. Poll::Pending => { @@ -298,7 +248,7 @@ where // if the state is still active, try read from and write to postgres. 
let message = self.poll_read(cx)?; let closing = self.poll_write(cx)?; - if let Poll::Ready(WriteReady::Terminating) = closing { + if let Poll::Ready(()) = closing { self.state = State::Closing; } diff --git a/libs/proxy/tokio-postgres2/src/generic_client.rs b/libs/proxy/tokio-postgres2/src/generic_client.rs index 8e28843347..eeefb45d26 100644 --- a/libs/proxy/tokio-postgres2/src/generic_client.rs +++ b/libs/proxy/tokio-postgres2/src/generic_client.rs @@ -1,9 +1,6 @@ #![allow(async_fn_in_trait)] -use postgres_protocol2::Oid; - use crate::query::RowStream; -use crate::types::Type; use crate::{Client, Error, Transaction}; mod private { @@ -15,20 +12,17 @@ mod private { /// This trait is "sealed", and cannot be implemented outside of this crate. pub trait GenericClient: private::Sealed { /// Like `Client::query_raw_txt`. - async fn query_raw_txt(&self, statement: &str, params: I) -> Result + async fn query_raw_txt(&mut self, statement: &str, params: I) -> Result where S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, I::IntoIter: ExactSizeIterator + Sync + Send; - - /// Query for type information - async fn get_type(&mut self, oid: Oid) -> Result; } impl private::Sealed for Client {} impl GenericClient for Client { - async fn query_raw_txt(&self, statement: &str, params: I) -> Result + async fn query_raw_txt(&mut self, statement: &str, params: I) -> Result where S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, @@ -36,17 +30,12 @@ impl GenericClient for Client { { self.query_raw_txt(statement, params).await } - - /// Query for type information - async fn get_type(&mut self, oid: Oid) -> Result { - self.get_type_inner(oid).await - } } impl private::Sealed for Transaction<'_> {} impl GenericClient for Transaction<'_> { - async fn query_raw_txt(&self, statement: &str, params: I) -> Result + async fn query_raw_txt(&mut self, statement: &str, params: I) -> Result where S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, @@ -54,9 +43,4 @@ impl GenericClient 
for Transaction<'_> { { self.query_raw_txt(statement, params).await } - - /// Query for type information - async fn get_type(&mut self, oid: Oid) -> Result { - self.client_mut().get_type(oid).await - } } diff --git a/libs/proxy/tokio-postgres2/src/lib.rs b/libs/proxy/tokio-postgres2/src/lib.rs index c8ebba5487..9556070ed5 100644 --- a/libs/proxy/tokio-postgres2/src/lib.rs +++ b/libs/proxy/tokio-postgres2/src/lib.rs @@ -18,7 +18,6 @@ pub use crate::statement::{Column, Statement}; pub use crate::tls::NoTls; pub use crate::transaction::Transaction; pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder}; -use crate::types::ToSql; /// After executing a query, the connection will be in one of these states #[derive(Clone, Copy, Debug, PartialEq)] @@ -120,9 +119,3 @@ pub enum SimpleQueryMessage { /// The number of rows modified or selected is returned. CommandComplete(u64), } - -fn slice_iter<'a>( - s: &'a [&'a (dyn ToSql + Sync)], -) -> impl ExactSizeIterator + 'a { - s.iter().map(|s| *s as _) -} diff --git a/libs/proxy/tokio-postgres2/src/prepare.rs b/libs/proxy/tokio-postgres2/src/prepare.rs index b27eabcb0e..16b9cf66f4 100644 --- a/libs/proxy/tokio-postgres2/src/prepare.rs +++ b/libs/proxy/tokio-postgres2/src/prepare.rs @@ -1,19 +1,14 @@ -use std::future::Future; -use std::pin::Pin; -use std::sync::Arc; - -use bytes::Bytes; +use bytes::BytesMut; use fallible_iterator::FallibleIterator; -use futures_util::{TryStreamExt, pin_mut}; -use postgres_protocol2::message::backend::Message; +use postgres_protocol2::IsNull; +use postgres_protocol2::message::backend::{Message, RowDescriptionBody}; use postgres_protocol2::message::frontend; -use tracing::debug; +use postgres_protocol2::types::oid_to_sql; +use postgres_types2::Format; -use crate::client::{CachedTypeInfo, InnerClient}; -use crate::codec::FrontendMessage; -use crate::connection::RequestMessages; +use crate::client::{CachedTypeInfo, PartialQuery, Responses}; use crate::types::{Kind, Oid, Type}; -use 
crate::{Column, Error, Statement, query, slice_iter}; +use crate::{Column, Error, Row, Statement}; pub(crate) const TYPEINFO_QUERY: &str = "\ SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid @@ -23,22 +18,51 @@ INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid WHERE t.oid = $1 "; +/// we need to make sure we close this prepared statement. +struct CloseStmt<'a, 'b> { + client: Option<&'a mut PartialQuery<'b>>, + name: &'static str, +} + +impl<'a> CloseStmt<'a, '_> { + fn close(mut self) -> Result<&'a mut Responses, Error> { + let client = self.client.take().unwrap(); + client.send_with_flush(|buf| { + frontend::close(b'S', self.name, buf).map_err(Error::encode)?; + Ok(()) + }) + } +} + +impl Drop for CloseStmt<'_, '_> { + fn drop(&mut self) { + if let Some(client) = self.client.take() { + let _ = client.send_with_flush(|buf| { + frontend::close(b'S', self.name, buf).map_err(Error::encode)?; + Ok(()) + }); + } + } +} + async fn prepare_typecheck( - client: &Arc, + client: &mut PartialQuery<'_>, name: &'static str, query: &str, - types: &[Type], ) -> Result { - let buf = encode(client, name, query, types)?; - let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + let responses = client.send_with_flush(|buf| { + frontend::parse(name, query, [], buf).map_err(Error::encode)?; + frontend::describe(b'S', name, buf).map_err(Error::encode)?; + Ok(()) + })?; match responses.next().await? { Message::ParseComplete => {} _ => return Err(Error::unexpected_message()), } - let parameter_description = match responses.next().await? { - Message::ParameterDescription(body) => body, + match responses.next().await? 
{ + Message::ParameterDescription(_) => {} _ => return Err(Error::unexpected_message()), }; @@ -48,13 +72,6 @@ async fn prepare_typecheck( _ => return Err(Error::unexpected_message()), }; - let mut parameters = vec![]; - let mut it = parameter_description.parameters(); - while let Some(oid) = it.next().map_err(Error::parse)? { - let type_ = Type::from_oid(oid).ok_or_else(Error::unexpected_message)?; - parameters.push(type_); - } - let mut columns = vec![]; if let Some(row_description) = row_description { let mut it = row_description.fields(); @@ -65,98 +82,168 @@ async fn prepare_typecheck( } } - Ok(Statement::new(client, name, parameters, columns)) + Ok(Statement::new(name, columns)) } -fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result { - if types.is_empty() { - debug!("preparing query {}: {}", name, query); - } else { - debug!("preparing query {} with types {:?}: {}", name, types, query); - } - - client.with_buf(|buf| { - frontend::parse(name, query, types.iter().map(Type::oid), buf).map_err(Error::encode)?; - frontend::describe(b'S', name, buf).map_err(Error::encode)?; - frontend::sync(buf); - Ok(buf.split().freeze()) - }) -} - -pub async fn get_type( - client: &Arc, - typecache: &mut CachedTypeInfo, - oid: Oid, -) -> Result { +fn try_from_cache(typecache: &CachedTypeInfo, oid: Oid) -> Option { if let Some(type_) = Type::from_oid(oid) { - return Ok(type_); + return Some(type_); } if let Some(type_) = typecache.types.get(&oid) { - return Ok(type_.clone()); + return Some(type_.clone()); }; - let stmt = typeinfo_statement(client, typecache).await?; + None +} - let rows = query::query(client, stmt, slice_iter(&[&oid])).await?; - pin_mut!(rows); +pub async fn parse_row_description( + client: &mut PartialQuery<'_>, + typecache: &mut CachedTypeInfo, + row_description: Option, +) -> Result, Error> { + let mut columns = vec![]; - let row = match rows.try_next().await? 
{ - Some(row) => row, - None => return Err(Error::unexpected_message()), + if let Some(row_description) = row_description { + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? { + let type_ = try_from_cache(typecache, field.type_oid()).unwrap_or(Type::UNKNOWN); + let column = Column::new(field.name().to_string(), type_, field); + columns.push(column); + } + } + + let all_known = columns.iter().all(|c| c.type_ != Type::UNKNOWN); + if all_known { + // all known, return early. + return Ok(columns); + } + + let typeinfo = "neon_proxy_typeinfo"; + + // make sure to close the typeinfo statement before exiting. + let mut guard = CloseStmt { + name: typeinfo, + client: None, + }; + let client = guard.client.insert(client); + + // get the typeinfo statement. + let stmt = prepare_typecheck(client, typeinfo, TYPEINFO_QUERY).await?; + + for column in &mut columns { + column.type_ = get_type(client, typecache, &stmt, column.type_oid()).await?; + } + + // cancel the close guard. + let responses = guard.close()?; + + match responses.next().await? 
{ + Message::CloseComplete => {} + _ => return Err(Error::unexpected_message()), + } + + Ok(columns) +} + +async fn get_type( + client: &mut PartialQuery<'_>, + typecache: &mut CachedTypeInfo, + stmt: &Statement, + mut oid: Oid, +) -> Result { + let mut stack = vec![]; + let mut type_ = loop { + if let Some(type_) = try_from_cache(typecache, oid) { + break type_; + } + + let row = exec(client, stmt, oid).await?; + if stack.len() > 8 { + return Err(Error::unexpected_message()); + } + + let name: String = row.try_get(0)?; + let type_: i8 = row.try_get(1)?; + let elem_oid: Oid = row.try_get(2)?; + let rngsubtype: Option = row.try_get(3)?; + let basetype: Oid = row.try_get(4)?; + let schema: String = row.try_get(5)?; + let relid: Oid = row.try_get(6)?; + + let kind = if type_ == b'e' as i8 { + Kind::Enum + } else if type_ == b'p' as i8 { + Kind::Pseudo + } else if basetype != 0 { + Kind::Domain(basetype) + } else if elem_oid != 0 { + stack.push((name, oid, schema)); + oid = elem_oid; + continue; + } else if relid != 0 { + Kind::Composite(relid) + } else if let Some(rngsubtype) = rngsubtype { + Kind::Range(rngsubtype) + } else { + Kind::Simple + }; + + let type_ = Type::new(name, oid, kind, schema); + typecache.types.insert(oid, type_.clone()); + break type_; }; - let name: String = row.try_get(0)?; - let type_: i8 = row.try_get(1)?; - let elem_oid: Oid = row.try_get(2)?; - let rngsubtype: Option = row.try_get(3)?; - let basetype: Oid = row.try_get(4)?; - let schema: String = row.try_get(5)?; - let relid: Oid = row.try_get(6)?; - - let kind = if type_ == b'e' as i8 { - Kind::Enum - } else if type_ == b'p' as i8 { - Kind::Pseudo - } else if basetype != 0 { - Kind::Domain(basetype) - } else if elem_oid != 0 { - let type_ = get_type_rec(client, typecache, elem_oid).await?; - Kind::Array(type_) - } else if relid != 0 { - Kind::Composite(relid) - } else if let Some(rngsubtype) = rngsubtype { - let type_ = get_type_rec(client, typecache, rngsubtype).await?; - 
Kind::Range(type_) - } else { - Kind::Simple - }; - - let type_ = Type::new(name, oid, kind, schema); - typecache.types.insert(oid, type_.clone()); + while let Some((name, oid, schema)) = stack.pop() { + type_ = Type::new(name, oid, Kind::Array(type_), schema); + typecache.types.insert(oid, type_.clone()); + } Ok(type_) } -fn get_type_rec<'a>( - client: &'a Arc, - typecache: &'a mut CachedTypeInfo, - oid: Oid, -) -> Pin> + Send + 'a>> { - Box::pin(get_type(client, typecache, oid)) -} +/// exec the typeinfo statement returning one row. +async fn exec( + client: &mut PartialQuery<'_>, + statement: &Statement, + param: Oid, +) -> Result { + let responses = client.send_with_flush(|buf| { + encode_bind(statement, param, "", buf); + frontend::execute("", 0, buf).map_err(Error::encode)?; + Ok(()) + })?; -async fn typeinfo_statement( - client: &Arc, - typecache: &mut CachedTypeInfo, -) -> Result { - if let Some(stmt) = &typecache.typeinfo { - return Ok(stmt.clone()); + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), } - let typeinfo = "neon_proxy_typeinfo"; - let stmt = prepare_typecheck(client, typeinfo, TYPEINFO_QUERY, &[]).await?; + let row = match responses.next().await? { + Message::DataRow(body) => Row::new(statement.clone(), body, Format::Binary)?, + _ => return Err(Error::unexpected_message()), + }; - typecache.typeinfo = Some(stmt.clone()); - Ok(stmt) + match responses.next().await? 
{ + Message::CommandComplete(_) => {} + _ => return Err(Error::unexpected_message()), + }; + + Ok(row) +} + +fn encode_bind(statement: &Statement, param: Oid, portal: &str, buf: &mut BytesMut) { + frontend::bind( + portal, + statement.name(), + [Format::Binary as i16], + [param], + |param, buf| { + oid_to_sql(param, buf); + Ok(IsNull::No) + }, + [Format::Binary as i16], + buf, + ) + .unwrap(); } diff --git a/libs/proxy/tokio-postgres2/src/query.rs b/libs/proxy/tokio-postgres2/src/query.rs index 106bc69d49..5f3ed8ef5a 100644 --- a/libs/proxy/tokio-postgres2/src/query.rs +++ b/libs/proxy/tokio-postgres2/src/query.rs @@ -1,76 +1,43 @@ -use std::fmt; -use std::marker::PhantomPinned; use std::pin::Pin; -use std::sync::Arc; use std::task::{Context, Poll}; -use bytes::{BufMut, Bytes, BytesMut}; -use fallible_iterator::FallibleIterator; +use bytes::BufMut; use futures_util::{Stream, ready}; -use pin_project_lite::pin_project; use postgres_protocol2::message::backend::Message; use postgres_protocol2::message::frontend; -use postgres_types2::{Format, ToSql, Type}; -use tracing::debug; +use postgres_types2::Format; -use crate::client::{InnerClient, Responses}; -use crate::codec::FrontendMessage; -use crate::connection::RequestMessages; -use crate::types::IsNull; -use crate::{Column, Error, ReadyForQueryStatus, Row, Statement}; +use crate::client::{CachedTypeInfo, InnerClient, Responses}; +use crate::{Error, ReadyForQueryStatus, Row, Statement}; -struct BorrowToSqlParamsDebug<'a>(&'a [&'a (dyn ToSql + Sync)]); - -impl fmt::Debug for BorrowToSqlParamsDebug<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_list().entries(self.0.iter()).finish() - } -} - -pub async fn query<'a, I>( - client: &InnerClient, - statement: Statement, - params: I, -) -> Result -where - I: IntoIterator, - I::IntoIter: ExactSizeIterator, -{ - let buf = if tracing::enabled!(tracing::Level::DEBUG) { - let params = params.into_iter().collect::>(); - debug!( - "executing statement 
{} with parameters: {:?}", - statement.name(), - BorrowToSqlParamsDebug(params.as_slice()), - ); - encode(client, &statement, params)? - } else { - encode(client, &statement, params)? - }; - let responses = start(client, buf).await?; - Ok(RowStream { - statement, - responses, - command_tag: None, - status: ReadyForQueryStatus::Unknown, - output_format: Format::Binary, - _p: PhantomPinned, - }) -} - -pub async fn query_txt( - client: &Arc, +pub async fn query_txt<'a, S, I>( + client: &'a mut InnerClient, + typecache: &mut CachedTypeInfo, query: &str, params: I, -) -> Result +) -> Result, Error> where S: AsRef, I: IntoIterator>, I::IntoIter: ExactSizeIterator, { let params = params.into_iter(); + let mut client = client.start()?; - let buf = client.with_buf(|buf| { + // Flow: + // 1. Parse the query + // 2. Inspect the row description for OIDs + // 3. If there's any OIDs we don't already know about, perform the typeinfo routine + // 4. Execute the query + // 5. Sync. + // + // The typeinfo routine: + // 1. Parse the typeinfo query + // 2. Execute the query on each OID + // 3. If the result does not match an OID we know, repeat 2. + + // parse the query and get type info + let responses = client.send_with_flush(|buf| { frontend::parse( "", // unnamed prepared statement query, // query to parse @@ -79,7 +46,30 @@ where ) .map_err(Error::encode)?; frontend::describe(b'S', "", buf).map_err(Error::encode)?; - // Bind, pass params as text, retrieve as binary + Ok(()) + })?; + + match responses.next().await? { + Message::ParseComplete => {} + _ => return Err(Error::unexpected_message()), + } + + match responses.next().await? { + Message::ParameterDescription(_) => {} + _ => return Err(Error::unexpected_message()), + }; + + let row_description = match responses.next().await? 
{ + Message::RowDescription(body) => Some(body), + Message::NoData => None, + _ => return Err(Error::unexpected_message()), + }; + + let columns = + crate::prepare::parse_row_description(&mut client, typecache, row_description).await?; + + let responses = client.send_with_sync(|buf| { + // Bind, pass params as text, retrieve as text match frontend::bind( "", // empty string selects the unnamed portal "", // unnamed prepared statement @@ -102,173 +92,55 @@ where // Execute frontend::execute("", 0, buf).map_err(Error::encode)?; - // Sync - frontend::sync(buf); - Ok(buf.split().freeze()) + Ok(()) })?; - // now read the responses - let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; - - match responses.next().await? { - Message::ParseComplete => {} - _ => return Err(Error::unexpected_message()), - } - - let parameter_description = match responses.next().await? { - Message::ParameterDescription(body) => body, - _ => return Err(Error::unexpected_message()), - }; - - let row_description = match responses.next().await? { - Message::RowDescription(body) => Some(body), - Message::NoData => None, - _ => return Err(Error::unexpected_message()), - }; - match responses.next().await? { Message::BindComplete => {} _ => return Err(Error::unexpected_message()), } - let mut parameters = vec![]; - let mut it = parameter_description.parameters(); - while let Some(oid) = it.next().map_err(Error::parse)? { - let type_ = Type::from_oid(oid).unwrap_or(Type::UNKNOWN); - parameters.push(type_); - } - - let mut columns = vec![]; - if let Some(row_description) = row_description { - let mut it = row_description.fields(); - while let Some(field) = it.next().map_err(Error::parse)? 
{ - let type_ = Type::from_oid(field.type_oid()).unwrap_or(Type::UNKNOWN); - let column = Column::new(field.name().to_string(), type_, field); - columns.push(column); - } - } - Ok(RowStream { - statement: Statement::new_anonymous(parameters, columns), responses, + statement: Statement::new("", columns), command_tag: None, status: ReadyForQueryStatus::Unknown, output_format: Format::Text, - _p: PhantomPinned, }) } -async fn start(client: &InnerClient, buf: Bytes) -> Result { - let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; - - match responses.next().await? { - Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), - } - - Ok(responses) +/// A stream of table rows. +pub struct RowStream<'a> { + responses: &'a mut Responses, + output_format: Format, + pub statement: Statement, + pub command_tag: Option, + pub status: ReadyForQueryStatus, } -pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result -where - I: IntoIterator, - I::IntoIter: ExactSizeIterator, -{ - client.with_buf(|buf| { - encode_bind(statement, params, "", buf)?; - frontend::execute("", 0, buf).map_err(Error::encode)?; - frontend::sync(buf); - Ok(buf.split().freeze()) - }) -} - -pub fn encode_bind<'a, I>( - statement: &Statement, - params: I, - portal: &str, - buf: &mut BytesMut, -) -> Result<(), Error> -where - I: IntoIterator, - I::IntoIter: ExactSizeIterator, -{ - let param_types = statement.params(); - let params = params.into_iter(); - - assert!( - param_types.len() == params.len(), - "expected {} parameters but got {}", - param_types.len(), - params.len() - ); - - let (param_formats, params): (Vec<_>, Vec<_>) = params - .zip(param_types.iter()) - .map(|(p, ty)| (p.encode_format(ty) as i16, p)) - .unzip(); - - let params = params.into_iter(); - - let mut error_idx = 0; - let r = frontend::bind( - portal, - statement.name(), - param_formats, - params.zip(param_types).enumerate(), - |(idx, (param, ty)), buf| 
match param.to_sql_checked(ty, buf) { - Ok(IsNull::No) => Ok(postgres_protocol2::IsNull::No), - Ok(IsNull::Yes) => Ok(postgres_protocol2::IsNull::Yes), - Err(e) => { - error_idx = idx; - Err(e) - } - }, - Some(1), - buf, - ); - match r { - Ok(()) => Ok(()), - Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)), - Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), - } -} - -pin_project! { - /// A stream of table rows. - pub struct RowStream { - statement: Statement, - responses: Responses, - command_tag: Option, - output_format: Format, - status: ReadyForQueryStatus, - #[pin] - _p: PhantomPinned, - } -} - -impl Stream for RowStream { +impl Stream for RowStream<'_> { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); + let this = self.get_mut(); loop { match ready!(this.responses.poll_next(cx)?) { Message::DataRow(body) => { return Poll::Ready(Some(Ok(Row::new( this.statement.clone(), body, - *this.output_format, + this.output_format, )?))); } Message::EmptyQueryResponse | Message::PortalSuspended => {} Message::CommandComplete(body) => { if let Ok(tag) = body.tag() { - *this.command_tag = Some(tag.to_string()); + this.command_tag = Some(tag.to_string()); } } Message::ReadyForQuery(status) => { - *this.status = status.into(); + this.status = status.into(); return Poll::Ready(None); } _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), @@ -276,24 +148,3 @@ impl Stream for RowStream { } } } - -impl RowStream { - /// Returns information about the columns of data in the row. - pub fn columns(&self) -> &[Column] { - self.statement.columns() - } - - /// Returns the command tag of this query. - /// - /// This is only available after the stream has been exhausted. - pub fn command_tag(&self) -> Option { - self.command_tag.clone() - } - - /// Returns if the connection is ready for querying, with the status of the connection. 
- /// - /// This might be available only after the stream has been exhausted. - pub fn ready_status(&self) -> ReadyForQueryStatus { - self.status - } -} diff --git a/libs/proxy/tokio-postgres2/src/simple_query.rs b/libs/proxy/tokio-postgres2/src/simple_query.rs index 2cf17188cf..e1ed48cdaf 100644 --- a/libs/proxy/tokio-postgres2/src/simple_query.rs +++ b/libs/proxy/tokio-postgres2/src/simple_query.rs @@ -1,19 +1,14 @@ -use std::marker::PhantomPinned; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use bytes::Bytes; use fallible_iterator::FallibleIterator; use futures_util::{Stream, ready}; use pin_project_lite::pin_project; use postgres_protocol2::message::backend::Message; -use postgres_protocol2::message::frontend; use tracing::debug; use crate::client::{InnerClient, Responses}; -use crate::codec::FrontendMessage; -use crate::connection::RequestMessages; use crate::{Error, ReadyForQueryStatus, SimpleQueryMessage, SimpleQueryRow}; /// Information about a column of a single query row. @@ -33,28 +28,28 @@ impl SimpleColumn { } } -pub async fn simple_query(client: &InnerClient, query: &str) -> Result { +pub async fn simple_query<'a>( + client: &'a mut InnerClient, + query: &str, +) -> Result, Error> { debug!("executing simple query: {}", query); - let buf = encode(client, query)?; - let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + let responses = client.send_simple_query(query)?; Ok(SimpleQueryStream { responses, columns: None, status: ReadyForQueryStatus::Unknown, - _p: PhantomPinned, }) } pub async fn batch_execute( - client: &InnerClient, + client: &mut InnerClient, query: &str, ) -> Result { debug!("executing statement batch: {}", query); - let buf = encode(client, query)?; - let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + let responses = client.send_simple_query(query)?; loop { match responses.next().await? 
{ @@ -68,25 +63,16 @@ pub async fn batch_execute( } } -pub(crate) fn encode(client: &InnerClient, query: &str) -> Result { - client.with_buf(|buf| { - frontend::query(query, buf).map_err(Error::encode)?; - Ok(buf.split().freeze()) - }) -} - pin_project! { /// A stream of simple query results. - pub struct SimpleQueryStream { - responses: Responses, + pub struct SimpleQueryStream<'a> { + responses: &'a mut Responses, columns: Option>, status: ReadyForQueryStatus, - #[pin] - _p: PhantomPinned, } } -impl SimpleQueryStream { +impl SimpleQueryStream<'_> { /// Returns if the connection is ready for querying, with the status of the connection. /// /// This might be available only after the stream has been exhausted. @@ -95,7 +81,7 @@ impl SimpleQueryStream { } } -impl Stream for SimpleQueryStream { +impl Stream for SimpleQueryStream<'_> { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { diff --git a/libs/proxy/tokio-postgres2/src/statement.rs b/libs/proxy/tokio-postgres2/src/statement.rs index e4828db712..1f22d87fd7 100644 --- a/libs/proxy/tokio-postgres2/src/statement.rs +++ b/libs/proxy/tokio-postgres2/src/statement.rs @@ -1,35 +1,15 @@ use std::fmt; -use std::sync::{Arc, Weak}; +use std::sync::Arc; +use crate::types::Type; use postgres_protocol2::Oid; use postgres_protocol2::message::backend::Field; -use postgres_protocol2::message::frontend; - -use crate::client::InnerClient; -use crate::codec::FrontendMessage; -use crate::connection::RequestMessages; -use crate::types::Type; struct StatementInner { - client: Weak, name: &'static str, - params: Vec, columns: Vec, } -impl Drop for StatementInner { - fn drop(&mut self) { - if let Some(client) = self.client.upgrade() { - let buf = client.with_buf(|buf| { - frontend::close(b'S', self.name, buf).unwrap(); - frontend::sync(buf); - buf.split().freeze() - }); - let _ = client.send(RequestMessages::Single(FrontendMessage::Raw(buf))); - } - } -} - /// A prepared statement. 
/// /// Prepared statements can only be used with the connection that created them. @@ -37,38 +17,14 @@ impl Drop for StatementInner { pub struct Statement(Arc); impl Statement { - pub(crate) fn new( - inner: &Arc, - name: &'static str, - params: Vec, - columns: Vec, - ) -> Statement { - Statement(Arc::new(StatementInner { - client: Arc::downgrade(inner), - name, - params, - columns, - })) - } - - pub(crate) fn new_anonymous(params: Vec, columns: Vec) -> Statement { - Statement(Arc::new(StatementInner { - client: Weak::new(), - name: "", - params, - columns, - })) + pub(crate) fn new(name: &'static str, columns: Vec) -> Statement { + Statement(Arc::new(StatementInner { name, columns })) } pub(crate) fn name(&self) -> &str { self.0.name } - /// Returns the expected types of the statement's parameters. - pub fn params(&self) -> &[Type] { - &self.0.params - } - /// Returns information about the columns returned when the statement is queried. pub fn columns(&self) -> &[Column] { &self.0.columns @@ -78,7 +34,7 @@ impl Statement { /// Information about a column of a query. 
pub struct Column { name: String, - type_: Type, + pub(crate) type_: Type, // raw fields from RowDescription table_oid: Oid, diff --git a/libs/proxy/tokio-postgres2/src/transaction.rs b/libs/proxy/tokio-postgres2/src/transaction.rs index f32603470f..12fe0737d4 100644 --- a/libs/proxy/tokio-postgres2/src/transaction.rs +++ b/libs/proxy/tokio-postgres2/src/transaction.rs @@ -1,7 +1,3 @@ -use postgres_protocol2::message::frontend; - -use crate::codec::FrontendMessage; -use crate::connection::RequestMessages; use crate::query::RowStream; use crate::{CancelToken, Client, Error, ReadyForQueryStatus}; @@ -20,14 +16,7 @@ impl Drop for Transaction<'_> { return; } - let buf = self.client.inner().with_buf(|buf| { - frontend::query("ROLLBACK", buf).unwrap(); - buf.split().freeze() - }); - let _ = self - .client - .inner() - .send(RequestMessages::Single(FrontendMessage::Raw(buf))); + let _ = self.client.inner_mut().send_simple_query("ROLLBACK"); } } @@ -54,7 +43,11 @@ impl<'a> Transaction<'a> { } /// Like `Client::query_raw_txt`. 
- pub async fn query_raw_txt(&self, statement: &str, params: I) -> Result + pub async fn query_raw_txt( + &mut self, + statement: &str, + params: I, + ) -> Result where S: AsRef, I: IntoIterator>, diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index dfaeedaeae..1c5bb64480 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -14,7 +14,9 @@ use hyper::http::{HeaderName, HeaderValue}; use hyper::{HeaderMap, Request, Response, StatusCode, header}; use indexmap::IndexMap; use postgres_client::error::{DbError, ErrorPosition, SqlState}; -use postgres_client::{GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, Transaction}; +use postgres_client::{ + GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction, +}; use pq_proto::StartupMessageParamsBuilder; use serde::Serialize; use serde_json::Value; @@ -1092,22 +1094,41 @@ async fn query_to_json( let query_start = Instant::now(); let query_params = data.params; - let mut row_stream = std::pin::pin!( - client - .query_raw_txt(&data.query, query_params) - .await - .map_err(SqlOverHttpError::Postgres)? 
- ); + let mut row_stream = client + .query_raw_txt(&data.query, query_params) + .await + .map_err(SqlOverHttpError::Postgres)?; let query_acknowledged = Instant::now(); + let columns_len = row_stream.statement.columns().len(); + let mut fields = Vec::with_capacity(columns_len); + let mut types = Vec::with_capacity(columns_len); + + for c in row_stream.statement.columns() { + fields.push(json!({ + "name": c.name().to_owned(), + "dataTypeID": c.type_().oid(), + "tableID": c.table_oid(), + "columnID": c.column_id(), + "dataTypeSize": c.type_size(), + "dataTypeModifier": c.type_modifier(), + "format": "text", + })); + + types.push(c.type_().clone()); + } + + let raw_output = parsed_headers.raw_output; + let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode); + // Manually drain the stream into a vector to leave row_stream hanging // around to get a command tag. Also check that the response is not too // big. - let mut rows: Vec = Vec::new(); + let mut rows = Vec::new(); while let Some(row) = row_stream.next().await { let row = row.map_err(SqlOverHttpError::Postgres)?; *current_size += row.body_len(); - rows.push(row); + // we don't have a streaming response support yet so this is to prevent OOM // from a malicious query (eg a cross join) if *current_size > config.max_response_size_bytes { @@ -1115,13 +1136,26 @@ async fn query_to_json( config.max_response_size_bytes, )); } + + let row = pg_text_row_to_json(&row, &types, raw_output, array_mode)?; + rows.push(row); + + // assumption: parsing pg text and converting to json takes CPU time. + // let's assume it is slightly expensive, so we should consume some cooperative budget. + // Especially considering that `RowStream::next` might be pulling from a batch + // of rows and never hit the tokio mpsc for a long time (although unlikely). 
+ tokio::task::consume_budget().await; } let query_resp_end = Instant::now(); - let ready = row_stream.ready_status(); + let RowStream { + command_tag, + status: ready, + .. + } = row_stream; // grab the command tag and number of rows affected - let command_tag = row_stream.command_tag().unwrap_or_default(); + let command_tag = command_tag.unwrap_or_default(); let mut command_tag_split = command_tag.split(' '); let command_tag_name = command_tag_split.next().unwrap_or_default(); let command_tag_count = if command_tag_name == "INSERT" { @@ -1142,38 +1176,6 @@ async fn query_to_json( "finished executing query" ); - let columns_len = row_stream.columns().len(); - let mut fields = Vec::with_capacity(columns_len); - let mut columns = Vec::with_capacity(columns_len); - - for c in row_stream.columns() { - fields.push(json!({ - "name": c.name().to_owned(), - "dataTypeID": c.type_().oid(), - "tableID": c.table_oid(), - "columnID": c.column_id(), - "dataTypeSize": c.type_size(), - "dataTypeModifier": c.type_modifier(), - "format": "text", - })); - - match client.get_type(c.type_oid()).await { - Ok(t) => columns.push(t), - Err(err) => { - tracing::warn!(?err, "unable to query type information"); - return Err(SqlOverHttpError::InternalPostgres(err)); - } - } - } - - let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode); - - // convert rows to JSON - let rows = rows - .iter() - .map(|row| pg_text_row_to_json(row, &columns, parsed_headers.raw_output, array_mode)) - .collect::, _>>()?; - // Resulting JSON format is based on the format of node-postgres result. let results = json!({ "command": command_tag_name.to_string(),