diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs index d1d87306a6..733f685fa9 100644 --- a/libs/proxy/tokio-postgres2/src/client.rs +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -1,14 +1,12 @@ use std::collections::HashMap; use std::fmt; use std::net::IpAddr; -use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; use bytes::BytesMut; use fallible_iterator::FallibleIterator; use futures_util::{TryStreamExt, future, ready}; -use parking_lot::Mutex; use postgres_protocol2::message::backend::Message; use postgres_protocol2::message::frontend; use serde::{Deserialize, Serialize}; @@ -68,7 +66,7 @@ pub struct InnerClient { sender: mpsc::UnboundedSender, /// A buffer to use when writing out postgres commands. - buffer: Mutex, + buffer: BytesMut, } impl InnerClient { @@ -85,13 +83,12 @@ impl InnerClient { /// Call the given function with a buffer to be used when writing out /// postgres commands. - pub fn with_buf(&self, f: F) -> R + pub fn with_buf(&mut self, f: F) -> R where F: FnOnce(&mut BytesMut) -> R, { - let mut buffer = self.buffer.lock(); - let r = f(&mut buffer); - buffer.clear(); + let r = f(&mut self.buffer); + self.buffer.clear(); r } } @@ -109,7 +106,7 @@ pub struct SocketConfig { /// The client is one half of what is returned when a connection is established. Users interact with the database /// through this client object. pub struct Client { - inner: Arc, + inner: InnerClient, cached_typeinfo: CachedTypeInfo, socket_config: SocketConfig, @@ -142,10 +139,10 @@ impl Client { secret_key: i32, ) -> Client { Client { - inner: Arc::new(InnerClient { + inner: InnerClient { sender, buffer: Default::default(), - }), + }, cached_typeinfo: Default::default(), socket_config, @@ -160,19 +157,23 @@ impl Client { self.process_id } - pub(crate) fn inner(&self) -> &Arc { - &self.inner + pub(crate) fn inner(&mut self) -> &mut InnerClient { + &mut self.inner } /// Pass text directly to the Postgres backend to allow it to sort out typing itself and /// to save a roundtrip - pub async fn query_raw_txt(&self, statement: &str, params: I) -> Result + pub async fn query_raw_txt( + &mut self, + statement: &str, + params: I, + ) -> Result where S: AsRef, I: IntoIterator>, I::IntoIter: ExactSizeIterator, { - query::query_txt(&self.inner, statement, params).await + query::query_txt(&mut self.inner, statement, params).await } /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows. @@ -188,11 +189,14 @@ impl Client { /// Prepared statements should be use for any query which contains user-specified data, as they provided the /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass /// them to this method! - pub async fn simple_query(&self, query: &str) -> Result, Error> { + pub async fn simple_query(&mut self, query: &str) -> Result, Error> { self.simple_query_raw(query).await?.try_collect().await } - pub(crate) async fn simple_query_raw(&self, query: &str) -> Result { + pub(crate) async fn simple_query_raw( + &mut self, + query: &str, + ) -> Result { simple_query::simple_query(self.inner(), query).await } @@ -206,7 +210,7 @@ impl Client { /// Prepared statements should be use for any query which contains user-specified data, as they provided the /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass /// them to this method! - pub async fn batch_execute(&self, query: &str) -> Result { + pub async fn batch_execute(&mut self, query: &str) -> Result { simple_query::batch_execute(self.inner(), query).await } @@ -223,7 +227,7 @@ impl Client { /// The transaction will roll back by default - use the `commit` method to commit it. pub async fn transaction(&mut self) -> Result, Error> { struct RollbackIfNotDone<'me> { - client: &'me Client, + client: &'me mut Client, done: bool, } @@ -254,7 +258,7 @@ impl Client { client: self, done: false, }; - self.batch_execute("BEGIN").await?; + cleaner.client.batch_execute("BEGIN").await?; cleaner.done = true; } @@ -282,7 +286,7 @@ impl Client { /// Query for type information pub(crate) async fn get_type_inner(&mut self, oid: Oid) -> Result { - crate::prepare::get_type(&self.inner, &mut self.cached_typeinfo, oid).await + crate::prepare::get_type(&mut self.inner, &mut self.cached_typeinfo, oid).await } /// Determines if the connection to the server has already closed. diff --git a/libs/proxy/tokio-postgres2/src/generic_client.rs b/libs/proxy/tokio-postgres2/src/generic_client.rs index 8e28843347..c426bf4676 100644 --- a/libs/proxy/tokio-postgres2/src/generic_client.rs +++ b/libs/proxy/tokio-postgres2/src/generic_client.rs @@ -15,7 +15,7 @@ mod private { /// This trait is "sealed", and cannot be implemented outside of this crate. pub trait GenericClient: private::Sealed { /// Like `Client::query_raw_txt`. - async fn query_raw_txt(&self, statement: &str, params: I) -> Result + async fn query_raw_txt(&mut self, statement: &str, params: I) -> Result where S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, @@ -28,7 +28,7 @@ pub trait GenericClient: private::Sealed { impl private::Sealed for Client {} impl GenericClient for Client { - async fn query_raw_txt(&self, statement: &str, params: I) -> Result + async fn query_raw_txt(&mut self, statement: &str, params: I) -> Result where S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, @@ -46,7 +46,7 @@ impl GenericClient for Client { impl private::Sealed for Transaction<'_> {} impl GenericClient for Transaction<'_> { - async fn query_raw_txt(&self, statement: &str, params: I) -> Result + async fn query_raw_txt(&mut self, statement: &str, params: I) -> Result where S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, diff --git a/libs/proxy/tokio-postgres2/src/prepare.rs b/libs/proxy/tokio-postgres2/src/prepare.rs index 8b6b363889..f7789632e3 100644 --- a/libs/proxy/tokio-postgres2/src/prepare.rs +++ b/libs/proxy/tokio-postgres2/src/prepare.rs @@ -1,6 +1,5 @@ use std::future::Future; use std::pin::Pin; -use std::sync::Arc; use bytes::Bytes; use fallible_iterator::FallibleIterator; @@ -24,7 +23,7 @@ WHERE t.oid = $1 "; async fn prepare_typecheck( - client: &Arc, + client: &mut InnerClient, name: &'static str, query: &str, types: &[Type], @@ -68,7 +67,12 @@ async fn prepare_typecheck( Ok(Statement::new(name, parameters, columns)) } -fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result { +fn encode( + client: &mut InnerClient, + name: &str, + query: &str, + types: &[Type], +) -> Result { if types.is_empty() { debug!("preparing query {}: {}", name, query); } else { @@ -84,7 +88,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu } pub async fn get_type( - client: &Arc, + client: &mut InnerClient, typecache: &mut CachedTypeInfo, oid: Oid, ) -> Result { @@ -139,7 +143,7 @@ pub async fn get_type( } fn get_type_rec<'a>( - client: &'a Arc, + client: &'a mut InnerClient, typecache: &'a mut CachedTypeInfo, oid: Oid, ) -> Pin> + Send + 'a>> { @@ -147,7 +151,7 @@ fn get_type_rec<'a>( } async fn typeinfo_statement( - client: &Arc, + client: &mut InnerClient, typecache: &mut CachedTypeInfo, ) -> Result { if let Some(stmt) = &typecache.typeinfo { diff --git a/libs/proxy/tokio-postgres2/src/query.rs b/libs/proxy/tokio-postgres2/src/query.rs index 106bc69d49..bda8e74d7d 100644 --- a/libs/proxy/tokio-postgres2/src/query.rs +++ b/libs/proxy/tokio-postgres2/src/query.rs @@ -1,7 +1,6 @@ use std::fmt; use std::marker::PhantomPinned; use std::pin::Pin; -use std::sync::Arc; use std::task::{Context, Poll}; use bytes::{BufMut, Bytes, BytesMut}; @@ -28,7 +27,7 @@ impl fmt::Debug for BorrowToSqlParamsDebug<'_> { } pub async fn query<'a, I>( - client: &InnerClient, + client: &mut InnerClient, statement: Statement, params: I, ) -> Result @@ -59,7 +58,7 @@ where } pub async fn query_txt( - client: &Arc, + client: &mut InnerClient, query: &str, params: I, ) -> Result @@ -159,7 +158,7 @@ where }) } -async fn start(client: &InnerClient, buf: Bytes) -> Result { +async fn start(client: &mut InnerClient, buf: Bytes) -> Result { let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; match responses.next().await? { @@ -170,7 +169,11 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result { Ok(responses) } -pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result +pub fn encode<'a, I>( + client: &mut InnerClient, + statement: &Statement, + params: I, +) -> Result where I: IntoIterator, I::IntoIter: ExactSizeIterator, diff --git a/libs/proxy/tokio-postgres2/src/simple_query.rs b/libs/proxy/tokio-postgres2/src/simple_query.rs index 2cf17188cf..f4d85d4100 100644 --- a/libs/proxy/tokio-postgres2/src/simple_query.rs +++ b/libs/proxy/tokio-postgres2/src/simple_query.rs @@ -33,7 +33,10 @@ impl SimpleColumn { } } -pub async fn simple_query(client: &InnerClient, query: &str) -> Result { +pub async fn simple_query( + client: &mut InnerClient, + query: &str, +) -> Result { debug!("executing simple query: {}", query); let buf = encode(client, query)?; @@ -48,7 +51,7 @@ pub async fn simple_query(client: &InnerClient, query: &str) -> Result Result { debug!("executing statement batch: {}", query); @@ -68,7 +71,7 @@ pub async fn batch_execute( } } -pub(crate) fn encode(client: &InnerClient, query: &str) -> Result { +pub(crate) fn encode(client: &mut InnerClient, query: &str) -> Result { client.with_buf(|buf| { frontend::query(query, buf).map_err(Error::encode)?; Ok(buf.split().freeze()) diff --git a/libs/proxy/tokio-postgres2/src/transaction.rs b/libs/proxy/tokio-postgres2/src/transaction.rs index f32603470f..ecb35cf60f 100644 --- a/libs/proxy/tokio-postgres2/src/transaction.rs +++ b/libs/proxy/tokio-postgres2/src/transaction.rs @@ -54,7 +54,11 @@ impl<'a> Transaction<'a> { } /// Like `Client::query_raw_txt`. - pub async fn query_raw_txt(&self, statement: &str, params: I) -> Result + pub async fn query_raw_txt( + &mut self, + statement: &str, + params: I, + ) -> Result where S: AsRef, I: IntoIterator>,