diff --git a/libs/proxy/postgres-types2/src/lib.rs b/libs/proxy/postgres-types2/src/lib.rs index 0ccd8c295f..b6bcabc922 100644 --- a/libs/proxy/postgres-types2/src/lib.rs +++ b/libs/proxy/postgres-types2/src/lib.rs @@ -135,8 +135,8 @@ impl Type { pub enum Kind { /// A simple type like `VARCHAR` or `INTEGER`. Simple, - /// An enumerated type along with its variants. - Enum(Vec), + /// An enumerated type. + Enum, /// A pseudo-type. Pseudo, /// An array type along with the type of its elements. @@ -146,9 +146,9 @@ pub enum Kind { /// A multirange type along with the type of its elements. Multirange(Type), /// A domain type along with its underlying type. - Domain(Type), - /// A composite type along with information about its fields. - Composite(Vec), + Domain(Oid), + /// A composite type. + Composite(Oid), } /// Information about a field of a composite type. diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs index 08a06163e1..186eb07000 100644 --- a/libs/proxy/tokio-postgres2/src/client.rs +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -19,10 +19,10 @@ use crate::config::{Host, SslMode}; use crate::connection::{Request, RequestMessages}; use crate::query::RowStream; use crate::simple_query::SimpleQueryStream; -use crate::types::{Oid, ToSql, Type}; +use crate::types::{Oid, Type}; use crate::{ - CancelToken, Error, ReadyForQueryStatus, Row, SimpleQueryMessage, Statement, Transaction, - TransactionBuilder, query, simple_query, slice_iter, + CancelToken, Error, ReadyForQueryStatus, SimpleQueryMessage, Statement, Transaction, + TransactionBuilder, query, simple_query, }; pub struct Responses { @@ -54,26 +54,18 @@ impl Responses { /// A cache of type info and prepared statements for fetching type info /// (corresponding to the queries in the [crate::prepare] module). #[derive(Default)] -struct CachedTypeInfo { +pub(crate) struct CachedTypeInfo { /// A statement for basic information for a type from its /// OID. Corresponds to [TYPEINFO_QUERY](crate::prepare::TYPEINFO_QUERY) (or its /// fallback). - typeinfo: Option, - /// A statement for getting information for a composite type from its OID. - /// Corresponds to [TYPEINFO_QUERY](crate::prepare::TYPEINFO_COMPOSITE_QUERY). - typeinfo_composite: Option, - /// A statement for getting information for a composite type from its OID. - /// Corresponds to [TYPEINFO_QUERY](crate::prepare::TYPEINFO_COMPOSITE_QUERY) (or - /// its fallback). - typeinfo_enum: Option, + pub(crate) typeinfo: Option, /// Cache of types already looked up. - types: HashMap, + pub(crate) types: HashMap, } pub struct InnerClient { sender: mpsc::UnboundedSender, - cached_typeinfo: Mutex, /// A buffer to use when writing out postgres commands. buffer: Mutex, @@ -91,38 +83,6 @@ impl InnerClient { }) } - pub fn typeinfo(&self) -> Option { - self.cached_typeinfo.lock().typeinfo.clone() - } - - pub fn set_typeinfo(&self, statement: &Statement) { - self.cached_typeinfo.lock().typeinfo = Some(statement.clone()); - } - - pub fn typeinfo_composite(&self) -> Option { - self.cached_typeinfo.lock().typeinfo_composite.clone() - } - - pub fn set_typeinfo_composite(&self, statement: &Statement) { - self.cached_typeinfo.lock().typeinfo_composite = Some(statement.clone()); - } - - pub fn typeinfo_enum(&self) -> Option { - self.cached_typeinfo.lock().typeinfo_enum.clone() - } - - pub fn set_typeinfo_enum(&self, statement: &Statement) { - self.cached_typeinfo.lock().typeinfo_enum = Some(statement.clone()); - } - - pub fn type_(&self, oid: Oid) -> Option { - self.cached_typeinfo.lock().types.get(&oid).cloned() - } - - pub fn set_type(&self, oid: Oid, type_: &Type) { - self.cached_typeinfo.lock().types.insert(oid, type_.clone()); - } - /// Call the given function with a buffer to be used when writing out /// postgres commands. pub fn with_buf(&self, f: F) -> R @@ -142,7 +102,6 @@ pub struct SocketConfig { pub host: Host, pub port: u16, pub connect_timeout: Option, - // pub keepalive: Option, } /// An asynchronous PostgreSQL client. @@ -151,6 +110,7 @@ pub struct SocketConfig { /// through this client object. pub struct Client { inner: Arc, + cached_typeinfo: CachedTypeInfo, socket_config: SocketConfig, ssl_mode: SslMode, @@ -169,9 +129,9 @@ impl Client { Client { inner: Arc::new(InnerClient { sender, - cached_typeinfo: Default::default(), buffer: Default::default(), }), + cached_typeinfo: Default::default(), socket_config, ssl_mode, @@ -189,55 +149,6 @@ impl Client { &self.inner } - /// Executes a statement, returning a vector of the resulting rows. - /// - /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list - /// provided, 1-indexed. - /// - /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be - /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front - /// with the `prepare` method. - /// - /// # Panics - /// - /// Panics if the number of parameters provided does not match the number expected. - pub async fn query( - &self, - statement: Statement, - params: &[&(dyn ToSql + Sync)], - ) -> Result, Error> { - self.query_raw(statement, slice_iter(params)) - .await? - .try_collect() - .await - } - - /// The maximally flexible version of [`query`]. - /// - /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list - /// provided, 1-indexed. - /// - /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be - /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front - /// with the `prepare` method. - /// - /// # Panics - /// - /// Panics if the number of parameters provided does not match the number expected. - /// - /// [`query`]: #method.query - pub async fn query_raw<'a, I>( - &self, - statement: Statement, - params: I, - ) -> Result - where - I: IntoIterator, - I::IntoIter: ExactSizeIterator, - { - query::query(&self.inner, statement, params).await - } - /// Pass text directly to the Postgres backend to allow it to sort out typing itself and /// to save a roundtrip pub async fn query_raw_txt(&self, statement: &str, params: I) -> Result @@ -284,14 +195,10 @@ impl Client { simple_query::batch_execute(self.inner(), query).await } - pub async fn discard_all(&self) -> Result { + pub async fn discard_all(&mut self) -> Result { // clear the prepared statements that are about to be nuked from the postgres session - { - let mut typeinfo = self.inner.cached_typeinfo.lock(); - typeinfo.typeinfo = None; - typeinfo.typeinfo_composite = None; - typeinfo.typeinfo_enum = None; - } + + self.cached_typeinfo.typeinfo = None; self.batch_execute("discard all").await } @@ -359,8 +266,8 @@ impl Client { } /// Query for type information - pub async fn get_type(&self, oid: Oid) -> Result { - crate::prepare::get_type(&self.inner, oid).await + pub(crate) async fn get_type_inner(&mut self, oid: Oid) -> Result { + crate::prepare::get_type(&self.inner, &mut self.cached_typeinfo, oid).await } /// Determines if the connection to the server has already closed. diff --git a/libs/proxy/tokio-postgres2/src/generic_client.rs b/libs/proxy/tokio-postgres2/src/generic_client.rs index 31c3d8fa3e..8e28843347 100644 --- a/libs/proxy/tokio-postgres2/src/generic_client.rs +++ b/libs/proxy/tokio-postgres2/src/generic_client.rs @@ -22,7 +22,7 @@ pub trait GenericClient: private::Sealed { I::IntoIter: ExactSizeIterator + Sync + Send; /// Query for type information - async fn get_type(&self, oid: Oid) -> Result; + async fn get_type(&mut self, oid: Oid) -> Result; } impl private::Sealed for Client {} @@ -38,8 +38,8 @@ impl GenericClient for Client { } /// Query for type information - async fn get_type(&self, oid: Oid) -> Result { - crate::prepare::get_type(self.inner(), oid).await + async fn get_type(&mut self, oid: Oid) -> Result { + self.get_type_inner(oid).await } } @@ -56,7 +56,7 @@ impl GenericClient for Transaction<'_> { } /// Query for type information - async fn get_type(&self, oid: Oid) -> Result { - self.client().get_type(oid).await + async fn get_type(&mut self, oid: Oid) -> Result { + self.client_mut().get_type(oid).await } } diff --git a/libs/proxy/tokio-postgres2/src/prepare.rs b/libs/proxy/tokio-postgres2/src/prepare.rs index b36d2e5f74..ba13a528f6 100644 --- a/libs/proxy/tokio-postgres2/src/prepare.rs +++ b/libs/proxy/tokio-postgres2/src/prepare.rs @@ -9,10 +9,10 @@ use log::debug; use postgres_protocol2::message::backend::Message; use postgres_protocol2::message::frontend; -use crate::client::InnerClient; +use crate::client::{CachedTypeInfo, InnerClient}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; -use crate::types::{Field, Kind, Oid, Type}; +use crate::types::{Kind, Oid, Type}; use crate::{Column, Error, Statement, query, slice_iter}; pub(crate) const TYPEINFO_QUERY: &str = "\ @@ -23,23 +23,7 @@ INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid WHERE t.oid = $1 "; -const TYPEINFO_ENUM_QUERY: &str = "\ -SELECT enumlabel -FROM pg_catalog.pg_enum -WHERE enumtypid = $1 -ORDER BY enumsortorder -"; - -pub(crate) const TYPEINFO_COMPOSITE_QUERY: &str = "\ -SELECT attname, atttypid -FROM pg_catalog.pg_attribute -WHERE attrelid = $1 -AND NOT attisdropped -AND attnum > 0 -ORDER BY attnum -"; - -pub async fn prepare( +async fn prepare_typecheck( client: &Arc, name: &'static str, query: &str, @@ -67,7 +51,7 @@ pub async fn prepare( let mut parameters = vec![]; let mut it = parameter_description.parameters(); while let Some(oid) = it.next().map_err(Error::parse)? { - let type_ = get_type(client, oid).await?; + let type_ = Type::from_oid(oid).ok_or_else(Error::unexpected_message)?; parameters.push(type_); } @@ -75,7 +59,7 @@ pub async fn prepare( if let Some(row_description) = row_description { let mut it = row_description.fields(); while let Some(field) = it.next().map_err(Error::parse)? { - let type_ = get_type(client, field.type_oid()).await?; + let type_ = Type::from_oid(field.type_oid()).ok_or_else(Error::unexpected_message)?; let column = Column::new(field.name().to_string(), type_, field); columns.push(column); } @@ -84,15 +68,6 @@ pub async fn prepare( Ok(Statement::new(client, name, parameters, columns)) } -fn prepare_rec<'a>( - client: &'a Arc, - name: &'static str, - query: &'a str, - types: &'a [Type], -) -> Pin> + 'a + Send>> { - Box::pin(prepare(client, name, query, types)) -} - fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result { if types.is_empty() { debug!("preparing query {}: {}", name, query); @@ -108,16 +83,20 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu }) } -pub async fn get_type(client: &Arc, oid: Oid) -> Result { +pub async fn get_type( + client: &Arc, + typecache: &mut CachedTypeInfo, + oid: Oid, +) -> Result { if let Some(type_) = Type::from_oid(oid) { return Ok(type_); } - if let Some(type_) = client.type_(oid) { - return Ok(type_); - } + if let Some(type_) = typecache.types.get(&oid) { + return Ok(type_.clone()); + }; - let stmt = typeinfo_statement(client).await?; + let stmt = typeinfo_statement(client, typecache).await?; let rows = query::query(client, stmt, slice_iter(&[&oid])).await?; pin_mut!(rows); @@ -136,100 +115,48 @@ pub async fn get_type(client: &Arc, oid: Oid) -> Result( client: &'a Arc, + typecache: &'a mut CachedTypeInfo, oid: Oid, ) -> Pin> + Send + 'a>> { - Box::pin(get_type(client, oid)) + Box::pin(get_type(client, typecache, oid)) } -async fn typeinfo_statement(client: &Arc) -> Result { - if let Some(stmt) = client.typeinfo() { - return Ok(stmt); +async fn typeinfo_statement( + client: &Arc, + typecache: &mut CachedTypeInfo, +) -> Result { + if let Some(stmt) = &typecache.typeinfo { + return Ok(stmt.clone()); } let typeinfo = "neon_proxy_typeinfo"; - let stmt = prepare_rec(client, typeinfo, TYPEINFO_QUERY, &[]).await?; + let stmt = prepare_typecheck(client, typeinfo, TYPEINFO_QUERY, &[]).await?; - client.set_typeinfo(&stmt); - Ok(stmt) -} - -async fn get_enum_variants(client: &Arc, oid: Oid) -> Result, Error> { - let stmt = typeinfo_enum_statement(client).await?; - - query::query(client, stmt, slice_iter(&[&oid])) - .await? - .and_then(|row| async move { row.try_get(0) }) - .try_collect() - .await -} - -async fn typeinfo_enum_statement(client: &Arc) -> Result { - if let Some(stmt) = client.typeinfo_enum() { - return Ok(stmt); - } - - let typeinfo = "neon_proxy_typeinfo_enum"; - let stmt = prepare_rec(client, typeinfo, TYPEINFO_ENUM_QUERY, &[]).await?; - - client.set_typeinfo_enum(&stmt); - Ok(stmt) -} - -async fn get_composite_fields(client: &Arc, oid: Oid) -> Result, Error> { - let stmt = typeinfo_composite_statement(client).await?; - - let rows = query::query(client, stmt, slice_iter(&[&oid])) - .await? - .try_collect::>() - .await?; - - let mut fields = vec![]; - for row in rows { - let name = row.try_get(0)?; - let oid = row.try_get(1)?; - let type_ = get_type_rec(client, oid).await?; - fields.push(Field::new(name, type_)); - } - - Ok(fields) -} - -async fn typeinfo_composite_statement(client: &Arc) -> Result { - if let Some(stmt) = client.typeinfo_composite() { - return Ok(stmt); - } - - let typeinfo = "neon_proxy_typeinfo_composite"; - let stmt = prepare_rec(client, typeinfo, TYPEINFO_COMPOSITE_QUERY, &[]).await?; - - client.set_typeinfo_composite(&stmt); + typecache.typeinfo = Some(stmt.clone()); Ok(stmt) } diff --git a/libs/proxy/tokio-postgres2/src/transaction.rs b/libs/proxy/tokio-postgres2/src/transaction.rs index eecbfc5873..f32603470f 100644 --- a/libs/proxy/tokio-postgres2/src/transaction.rs +++ b/libs/proxy/tokio-postgres2/src/transaction.rs @@ -72,4 +72,9 @@ impl<'a> Transaction<'a> { pub fn client(&self) -> &Client { self.client } + + /// Returns a reference to the underlying `Client`. + pub fn client_mut(&mut self) -> &mut Client { + self.client + } } diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 612702231f..47009086c3 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -860,7 +860,13 @@ impl QueryData { let cancel_token = inner.cancel_token(); let res = match select( - pin!(query_to_json(config, &*inner, self, &mut 0, parsed_headers)), + pin!(query_to_json( + config, + &mut *inner, + self, + &mut 0, + parsed_headers + )), pin!(cancel.cancelled()), ) .await @@ -944,7 +950,7 @@ impl BatchQueryData { builder = builder.deferrable(true); } - let transaction = builder + let mut transaction = builder .start() .await .inspect_err(|_| { @@ -957,7 +963,7 @@ impl BatchQueryData { let json_output = match query_batch( config, cancel.child_token(), - &transaction, + &mut transaction, self, parsed_headers, ) @@ -1009,7 +1015,7 @@ impl BatchQueryData { async fn query_batch( config: &'static HttpConfig, cancel: CancellationToken, - transaction: &Transaction<'_>, + transaction: &mut Transaction<'_>, queries: BatchQueryData, parsed_headers: HttpHeaders, ) -> Result { @@ -1047,7 +1053,7 @@ async fn query_batch( async fn query_to_json( config: &'static HttpConfig, - client: &T, + client: &mut T, data: QueryData, current_size: &mut usize, parsed_headers: HttpHeaders,