remove arc around inner client

This commit is contained in:
Conrad Ludgate
2025-05-21 21:58:03 +01:00
parent ab61864df1
commit 4b0c7f9530
6 changed files with 56 additions and 38 deletions

View File

@@ -1,14 +1,12 @@
use std::collections::HashMap;
use std::fmt;
use std::net::IpAddr;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{TryStreamExt, future, ready};
use parking_lot::Mutex;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use serde::{Deserialize, Serialize};
@@ -68,7 +66,7 @@ pub struct InnerClient {
sender: mpsc::UnboundedSender<Request>,
/// A buffer to use when writing out postgres commands.
buffer: Mutex<BytesMut>,
buffer: BytesMut,
}
impl InnerClient {
@@ -85,13 +83,12 @@ impl InnerClient {
/// Call the given function with a buffer to be used when writing out
/// postgres commands.
pub fn with_buf<F, R>(&self, f: F) -> R
pub fn with_buf<F, R>(&mut self, f: F) -> R
where
F: FnOnce(&mut BytesMut) -> R,
{
let mut buffer = self.buffer.lock();
let r = f(&mut buffer);
buffer.clear();
let r = f(&mut self.buffer);
self.buffer.clear();
r
}
}
@@ -109,7 +106,7 @@ pub struct SocketConfig {
/// The client is one half of what is returned when a connection is established. Users interact with the database
/// through this client object.
pub struct Client {
inner: Arc<InnerClient>,
inner: InnerClient,
cached_typeinfo: CachedTypeInfo,
socket_config: SocketConfig,
@@ -142,10 +139,10 @@ impl Client {
secret_key: i32,
) -> Client {
Client {
inner: Arc::new(InnerClient {
inner: InnerClient {
sender,
buffer: Default::default(),
}),
},
cached_typeinfo: Default::default(),
socket_config,
@@ -160,19 +157,23 @@ impl Client {
self.process_id
}
pub(crate) fn inner(&self) -> &Arc<InnerClient> {
&self.inner
pub(crate) fn inner(&mut self) -> &mut InnerClient {
&mut self.inner
}
/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
/// to save a roundtrip
pub async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
pub async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
I::IntoIter: ExactSizeIterator,
{
query::query_txt(&self.inner, statement, params).await
query::query_txt(&mut self.inner, statement, params).await
}
/// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
@@ -188,11 +189,14 @@ impl Client {
/// Prepared statements should be use for any query which contains user-specified data, as they provided the
/// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
/// them to this method!
pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
pub async fn simple_query(&mut self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
self.simple_query_raw(query).await?.try_collect().await
}
pub(crate) async fn simple_query_raw(&self, query: &str) -> Result<SimpleQueryStream, Error> {
pub(crate) async fn simple_query_raw(
&mut self,
query: &str,
) -> Result<SimpleQueryStream, Error> {
simple_query::simple_query(self.inner(), query).await
}
@@ -206,7 +210,7 @@ impl Client {
/// Prepared statements should be use for any query which contains user-specified data, as they provided the
/// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
/// them to this method!
pub async fn batch_execute(&self, query: &str) -> Result<ReadyForQueryStatus, Error> {
pub async fn batch_execute(&mut self, query: &str) -> Result<ReadyForQueryStatus, Error> {
simple_query::batch_execute(self.inner(), query).await
}
@@ -223,7 +227,7 @@ impl Client {
/// The transaction will roll back by default - use the `commit` method to commit it.
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
struct RollbackIfNotDone<'me> {
client: &'me Client,
client: &'me mut Client,
done: bool,
}
@@ -254,7 +258,7 @@ impl Client {
client: self,
done: false,
};
self.batch_execute("BEGIN").await?;
cleaner.client.batch_execute("BEGIN").await?;
cleaner.done = true;
}
@@ -282,7 +286,7 @@ impl Client {
/// Query for type information
pub(crate) async fn get_type_inner(&mut self, oid: Oid) -> Result<Type, Error> {
crate::prepare::get_type(&self.inner, &mut self.cached_typeinfo, oid).await
crate::prepare::get_type(&mut self.inner, &mut self.cached_typeinfo, oid).await
}
/// Determines if the connection to the server has already closed.

View File

@@ -15,7 +15,7 @@ mod private {
/// This trait is "sealed", and cannot be implemented outside of this crate.
pub trait GenericClient: private::Sealed {
/// Like `Client::query_raw_txt`.
async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
@@ -28,7 +28,7 @@ pub trait GenericClient: private::Sealed {
impl private::Sealed for Client {}
impl GenericClient for Client {
async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
@@ -46,7 +46,7 @@ impl GenericClient for Client {
impl private::Sealed for Transaction<'_> {}
impl GenericClient for Transaction<'_> {
async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,

View File

@@ -1,6 +1,5 @@
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use bytes::Bytes;
use fallible_iterator::FallibleIterator;
@@ -24,7 +23,7 @@ WHERE t.oid = $1
";
async fn prepare_typecheck(
client: &Arc<InnerClient>,
client: &mut InnerClient,
name: &'static str,
query: &str,
types: &[Type],
@@ -68,7 +67,12 @@ async fn prepare_typecheck(
Ok(Statement::new(name, parameters, columns))
}
fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result<Bytes, Error> {
fn encode(
client: &mut InnerClient,
name: &str,
query: &str,
types: &[Type],
) -> Result<Bytes, Error> {
if types.is_empty() {
debug!("preparing query {}: {}", name, query);
} else {
@@ -84,7 +88,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu
}
pub async fn get_type(
client: &Arc<InnerClient>,
client: &mut InnerClient,
typecache: &mut CachedTypeInfo,
oid: Oid,
) -> Result<Type, Error> {
@@ -139,7 +143,7 @@ pub async fn get_type(
}
fn get_type_rec<'a>(
client: &'a Arc<InnerClient>,
client: &'a mut InnerClient,
typecache: &'a mut CachedTypeInfo,
oid: Oid,
) -> Pin<Box<dyn Future<Output = Result<Type, Error>> + Send + 'a>> {
@@ -147,7 +151,7 @@ fn get_type_rec<'a>(
}
async fn typeinfo_statement(
client: &Arc<InnerClient>,
client: &mut InnerClient,
typecache: &mut CachedTypeInfo,
) -> Result<Statement, Error> {
if let Some(stmt) = &typecache.typeinfo {

View File

@@ -1,7 +1,6 @@
use std::fmt;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::{BufMut, Bytes, BytesMut};
@@ -28,7 +27,7 @@ impl fmt::Debug for BorrowToSqlParamsDebug<'_> {
}
pub async fn query<'a, I>(
client: &InnerClient,
client: &mut InnerClient,
statement: Statement,
params: I,
) -> Result<RowStream, Error>
@@ -59,7 +58,7 @@ where
}
pub async fn query_txt<S, I>(
client: &Arc<InnerClient>,
client: &mut InnerClient,
query: &str,
params: I,
) -> Result<RowStream, Error>
@@ -159,7 +158,7 @@ where
})
}
async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
async fn start(client: &mut InnerClient, buf: Bytes) -> Result<Responses, Error> {
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
match responses.next().await? {
@@ -170,7 +169,11 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
Ok(responses)
}
pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
pub fn encode<'a, I>(
client: &mut InnerClient,
statement: &Statement,
params: I,
) -> Result<Bytes, Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,

View File

@@ -33,7 +33,10 @@ impl SimpleColumn {
}
}
pub async fn simple_query(client: &InnerClient, query: &str) -> Result<SimpleQueryStream, Error> {
pub async fn simple_query(
client: &mut InnerClient,
query: &str,
) -> Result<SimpleQueryStream, Error> {
debug!("executing simple query: {}", query);
let buf = encode(client, query)?;
@@ -48,7 +51,7 @@ pub async fn simple_query(client: &InnerClient, query: &str) -> Result<SimpleQue
}
pub async fn batch_execute(
client: &InnerClient,
client: &mut InnerClient,
query: &str,
) -> Result<ReadyForQueryStatus, Error> {
debug!("executing statement batch: {}", query);
@@ -68,7 +71,7 @@ pub async fn batch_execute(
}
}
pub(crate) fn encode(client: &InnerClient, query: &str) -> Result<Bytes, Error> {
pub(crate) fn encode(client: &mut InnerClient, query: &str) -> Result<Bytes, Error> {
client.with_buf(|buf| {
frontend::query(query, buf).map_err(Error::encode)?;
Ok(buf.split().freeze())

View File

@@ -54,7 +54,11 @@ impl<'a> Transaction<'a> {
}
/// Like `Client::query_raw_txt`.
pub async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
pub async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,