Compare commits

...

4 Commits

Author SHA1 Message Date
Conrad Ludgate
01475c9e75 fix accidental recursion 2024-12-06 12:19:40 +00:00
Conrad Ludgate
c835bbba1f refactor statements and the type cache to avoid arcs 2024-12-06 12:01:19 +00:00
Conrad Ludgate
f94dde4432 delete some more 2024-12-06 11:33:34 +00:00
Conrad Ludgate
4991a85704 delete some client methods and make client take &mut 2024-12-06 11:22:03 +00:00
15 changed files with 278 additions and 846 deletions

View File

@@ -4,23 +4,18 @@ use crate::config::Host;
use crate::config::SslMode; use crate::config::SslMode;
use crate::connection::{Request, RequestMessages}; use crate::connection::{Request, RequestMessages};
use crate::query::RowStream; use crate::types::{Oid, Type};
use crate::simple_query::SimpleQueryStream;
use crate::types::{Oid, ToSql, Type};
use crate::{ use crate::{
prepare, query, simple_query, slice_iter, CancelToken, Error, ReadyForQueryStatus, Row, simple_query, CancelToken, Error, ReadyForQueryStatus, Statement, Transaction,
SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder, TransactionBuilder,
}; };
use bytes::BytesMut; use bytes::BytesMut;
use fallible_iterator::FallibleIterator; use fallible_iterator::FallibleIterator;
use futures_util::{future, ready, TryStreamExt}; use futures_util::{future, ready};
use parking_lot::Mutex;
use postgres_protocol2::message::{backend::Message, frontend}; use postgres_protocol2::message::{backend::Message, frontend};
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt; use std::fmt;
use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use tokio::sync::mpsc; use tokio::sync::mpsc;
@@ -55,7 +50,7 @@ impl Responses {
/// A cache of type info and prepared statements for fetching type info /// A cache of type info and prepared statements for fetching type info
/// (corresponding to the queries in the [prepare] module). /// (corresponding to the queries in the [prepare] module).
#[derive(Default)] #[derive(Default)]
struct CachedTypeInfo { pub(crate) struct CachedTypeInfo {
/// A statement for basic information for a type from its /// A statement for basic information for a type from its
/// OID. Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_QUERY) (or its /// OID. Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_QUERY) (or its
/// fallback). /// fallback).
@@ -71,13 +66,45 @@ struct CachedTypeInfo {
/// Cache of types already looked up. /// Cache of types already looked up.
types: HashMap<Oid, Type>, types: HashMap<Oid, Type>,
} }
impl CachedTypeInfo {
pub(crate) fn typeinfo(&mut self) -> Option<&Statement> {
self.typeinfo.as_ref()
}
pub(crate) fn set_typeinfo(&mut self, statement: Statement) -> &Statement {
self.typeinfo.insert(statement)
}
pub(crate) fn typeinfo_composite(&mut self) -> Option<&Statement> {
self.typeinfo_composite.as_ref()
}
pub(crate) fn set_typeinfo_composite(&mut self, statement: Statement) -> &Statement {
self.typeinfo_composite.insert(statement)
}
pub(crate) fn typeinfo_enum(&mut self) -> Option<&Statement> {
self.typeinfo_enum.as_ref()
}
pub(crate) fn set_typeinfo_enum(&mut self, statement: Statement) -> &Statement {
self.typeinfo_enum.insert(statement)
}
pub(crate) fn type_(&mut self, oid: Oid) -> Option<Type> {
self.types.get(&oid).cloned()
}
pub(crate) fn set_type(&mut self, oid: Oid, type_: &Type) {
self.types.insert(oid, type_.clone());
}
}
pub struct InnerClient { pub struct InnerClient {
sender: mpsc::UnboundedSender<Request>, sender: mpsc::UnboundedSender<Request>,
cached_typeinfo: Mutex<CachedTypeInfo>,
/// A buffer to use when writing out postgres commands. /// A buffer to use when writing out postgres commands.
buffer: Mutex<BytesMut>, buffer: BytesMut,
} }
impl InnerClient { impl InnerClient {
@@ -92,47 +119,14 @@ impl InnerClient {
}) })
} }
pub fn typeinfo(&self) -> Option<Statement> {
self.cached_typeinfo.lock().typeinfo.clone()
}
pub fn set_typeinfo(&self, statement: &Statement) {
self.cached_typeinfo.lock().typeinfo = Some(statement.clone());
}
pub fn typeinfo_composite(&self) -> Option<Statement> {
self.cached_typeinfo.lock().typeinfo_composite.clone()
}
pub fn set_typeinfo_composite(&self, statement: &Statement) {
self.cached_typeinfo.lock().typeinfo_composite = Some(statement.clone());
}
pub fn typeinfo_enum(&self) -> Option<Statement> {
self.cached_typeinfo.lock().typeinfo_enum.clone()
}
pub fn set_typeinfo_enum(&self, statement: &Statement) {
self.cached_typeinfo.lock().typeinfo_enum = Some(statement.clone());
}
pub fn type_(&self, oid: Oid) -> Option<Type> {
self.cached_typeinfo.lock().types.get(&oid).cloned()
}
pub fn set_type(&self, oid: Oid, type_: &Type) {
self.cached_typeinfo.lock().types.insert(oid, type_.clone());
}
/// Call the given function with a buffer to be used when writing out /// Call the given function with a buffer to be used when writing out
/// postgres commands. /// postgres commands.
pub fn with_buf<F, R>(&self, f: F) -> R pub fn with_buf<F, R>(&mut self, f: F) -> R
where where
F: FnOnce(&mut BytesMut) -> R, F: FnOnce(&mut BytesMut) -> R,
{ {
let mut buffer = self.buffer.lock(); let r = f(&mut self.buffer);
let r = f(&mut buffer); self.buffer.clear();
buffer.clear();
r r
} }
} }
@@ -150,7 +144,8 @@ pub struct SocketConfig {
/// The client is one half of what is returned when a connection is established. Users interact with the database /// The client is one half of what is returned when a connection is established. Users interact with the database
/// through this client object. /// through this client object.
pub struct Client { pub struct Client {
inner: Arc<InnerClient>, pub(crate) inner: InnerClient,
pub(crate) cached_typeinfo: CachedTypeInfo,
socket_config: SocketConfig, socket_config: SocketConfig,
ssl_mode: SslMode, ssl_mode: SslMode,
@@ -167,11 +162,11 @@ impl Client {
secret_key: i32, secret_key: i32,
) -> Client { ) -> Client {
Client { Client {
inner: Arc::new(InnerClient { inner: InnerClient {
sender, sender,
cached_typeinfo: Default::default(),
buffer: Default::default(), buffer: Default::default(),
}), },
cached_typeinfo: Default::default(),
socket_config, socket_config,
ssl_mode, ssl_mode,
@@ -185,161 +180,6 @@ impl Client {
self.process_id self.process_id
} }
pub(crate) fn inner(&self) -> &Arc<InnerClient> {
&self.inner
}
/// Creates a new prepared statement.
///
/// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc),
/// which are set when executed. Prepared statements can only be used with the connection that created them.
pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
self.prepare_typed(query, &[]).await
}
/// Like `prepare`, but allows the types of query parameters to be explicitly specified.
///
/// The list of types may be smaller than the number of parameters - the types of the remaining parameters will be
/// inferred. For example, `client.prepare_typed(query, &[])` is equivalent to `client.prepare(query)`.
pub async fn prepare_typed(
&self,
query: &str,
parameter_types: &[Type],
) -> Result<Statement, Error> {
prepare::prepare(&self.inner, query, parameter_types).await
}
/// Executes a statement, returning a vector of the resulting rows.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
pub async fn query<T>(
&self,
statement: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<Vec<Row>, Error>
where
T: ?Sized + ToStatement,
{
self.query_raw(statement, slice_iter(params))
.await?
.try_collect()
.await
}
/// The maximally flexible version of [`query`].
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
///
/// [`query`]: #method.query
pub async fn query_raw<'a, T, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
{
let statement = statement.__convert().into_statement(self).await?;
query::query(&self.inner, statement, params).await
}
/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
/// to save a roundtrip
pub async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
I::IntoIter: ExactSizeIterator,
{
query::query_txt(&self.inner, statement, params).await
}
/// Executes a statement, returning the number of rows modified.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
pub async fn execute<T>(
&self,
statement: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<u64, Error>
where
T: ?Sized + ToStatement,
{
self.execute_raw(statement, slice_iter(params)).await
}
/// The maximally flexible version of [`execute`].
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
///
/// [`execute`]: #method.execute
pub async fn execute_raw<'a, T, I>(&self, statement: &T, params: I) -> Result<u64, Error>
where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
{
let statement = statement.__convert().into_statement(self).await?;
query::execute(self.inner(), statement, params).await
}
/// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
///
/// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
/// point. The simple query protocol returns the values in rows as strings rather than in their binary encodings,
/// so the associated row type doesn't work with the `FromSql` trait. Rather than simply returning a list of the
/// rows, this method returns a list of an enum which indicates either the completion of one of the commands,
/// or a row of data. This preserves the framing between the separate statements in the request.
///
/// # Warning
///
/// Prepared statements should be use for any query which contains user-specified data, as they provided the
/// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
/// them to this method!
pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
self.simple_query_raw(query).await?.try_collect().await
}
pub(crate) async fn simple_query_raw(&self, query: &str) -> Result<SimpleQueryStream, Error> {
simple_query::simple_query(self.inner(), query).await
}
/// Executes a sequence of SQL statements using the simple query protocol. /// Executes a sequence of SQL statements using the simple query protocol.
/// ///
/// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
@@ -350,8 +190,8 @@ impl Client {
/// Prepared statements should be use for any query which contains user-specified data, as they provided the /// Prepared statements should be use for any query which contains user-specified data, as they provided the
/// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
/// them to this method! /// them to this method!
pub async fn batch_execute(&self, query: &str) -> Result<ReadyForQueryStatus, Error> { pub async fn batch_execute(&mut self, query: &str) -> Result<ReadyForQueryStatus, Error> {
simple_query::batch_execute(self.inner(), query).await simple_query::batch_execute(&mut self.inner, query).await
} }
/// Begins a new database transaction. /// Begins a new database transaction.
@@ -359,7 +199,7 @@ impl Client {
/// The transaction will roll back by default - use the `commit` method to commit it. /// The transaction will roll back by default - use the `commit` method to commit it.
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> { pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
struct RollbackIfNotDone<'me> { struct RollbackIfNotDone<'me> {
client: &'me Client, client: &'me mut Client,
done: bool, done: bool,
} }
@@ -369,13 +209,13 @@ impl Client {
return; return;
} }
let buf = self.client.inner().with_buf(|buf| { let buf = self.client.inner.with_buf(|buf| {
frontend::query("ROLLBACK", buf).unwrap(); frontend::query("ROLLBACK", buf).unwrap();
buf.split().freeze() buf.split().freeze()
}); });
let _ = self let _ = self
.client .client
.inner() .inner
.send(RequestMessages::Single(FrontendMessage::Raw(buf))); .send(RequestMessages::Single(FrontendMessage::Raw(buf)));
} }
} }
@@ -390,7 +230,7 @@ impl Client {
client: self, client: self,
done: false, done: false,
}; };
self.batch_execute("BEGIN").await?; cleaner.client.batch_execute("BEGIN").await?;
cleaner.done = true; cleaner.done = true;
} }
@@ -416,11 +256,6 @@ impl Client {
} }
} }
/// Query for type information
pub async fn get_type(&self, oid: Oid) -> Result<Type, Error> {
crate::prepare::get_type(&self.inner, oid).await
}
/// Determines if the connection to the server has already closed. /// Determines if the connection to the server has already closed.
/// ///
/// In that case, all future queries will fail. /// In that case, all future queries will fail.

View File

@@ -1,4 +1,4 @@
use crate::query::RowStream; use crate::query::{self, RowStream};
use crate::types::Type; use crate::types::Type;
use crate::{Client, Error, Transaction}; use crate::{Client, Error, Transaction};
use async_trait::async_trait; use async_trait::async_trait;
@@ -13,33 +13,32 @@ mod private {
/// This trait is "sealed", and cannot be implemented outside of this crate. /// This trait is "sealed", and cannot be implemented outside of this crate.
#[async_trait] #[async_trait]
pub trait GenericClient: private::Sealed { pub trait GenericClient: private::Sealed {
/// Like `Client::query_raw_txt`. async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
where where
S: AsRef<str> + Sync + Send, S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send, I: IntoIterator<Item = Option<S>> + Sync + Send,
I::IntoIter: ExactSizeIterator + Sync + Send; I::IntoIter: ExactSizeIterator + Sync + Send;
/// Query for type information /// Query for type information
async fn get_type(&self, oid: Oid) -> Result<Type, Error>; async fn get_type(&mut self, oid: Oid) -> Result<Type, Error>;
} }
impl private::Sealed for Client {} impl private::Sealed for Client {}
#[async_trait] #[async_trait]
impl GenericClient for Client { impl GenericClient for Client {
async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error> async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
where where
S: AsRef<str> + Sync + Send, S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send, I: IntoIterator<Item = Option<S>> + Sync + Send,
I::IntoIter: ExactSizeIterator + Sync + Send, I::IntoIter: ExactSizeIterator + Sync + Send,
{ {
self.query_raw_txt(statement, params).await query::query_txt(&mut self.inner, statement, params).await
} }
/// Query for type information /// Query for type information
async fn get_type(&self, oid: Oid) -> Result<Type, Error> { async fn get_type(&mut self, oid: Oid) -> Result<Type, Error> {
self.get_type(oid).await crate::prepare::get_type(&mut self.inner, &mut self.cached_typeinfo, oid).await
} }
} }
@@ -48,17 +47,18 @@ impl private::Sealed for Transaction<'_> {}
#[async_trait] #[async_trait]
#[allow(clippy::needless_lifetimes)] #[allow(clippy::needless_lifetimes)]
impl GenericClient for Transaction<'_> { impl GenericClient for Transaction<'_> {
async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error> async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
where where
S: AsRef<str> + Sync + Send, S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send, I: IntoIterator<Item = Option<S>> + Sync + Send,
I::IntoIter: ExactSizeIterator + Sync + Send, I::IntoIter: ExactSizeIterator + Sync + Send,
{ {
self.query_raw_txt(statement, params).await query::query_txt(&mut self.client().inner, statement, params).await
} }
/// Query for type information /// Query for type information
async fn get_type(&self, oid: Oid) -> Result<Type, Error> { async fn get_type(&mut self, oid: Oid) -> Result<Type, Error> {
self.client().get_type(oid).await let client = self.client();
crate::prepare::get_type(&mut client.inner, &mut client.cached_typeinfo, oid).await
} }
} }

View File

@@ -10,11 +10,10 @@ use crate::error::DbError;
pub use crate::error::Error; pub use crate::error::Error;
pub use crate::generic_client::GenericClient; pub use crate::generic_client::GenericClient;
pub use crate::query::RowStream; pub use crate::query::RowStream;
pub use crate::row::{Row, SimpleQueryRow}; pub use crate::row::Row;
pub use crate::simple_query::SimpleQueryStream;
pub use crate::statement::{Column, Statement}; pub use crate::statement::{Column, Statement};
pub use crate::tls::NoTls; pub use crate::tls::NoTls;
pub use crate::to_statement::ToStatement; // pub use crate::to_statement::ToStatement;
pub use crate::transaction::Transaction; pub use crate::transaction::Transaction;
pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder}; pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder};
use crate::types::ToSql; use crate::types::ToSql;
@@ -65,7 +64,7 @@ pub mod row;
mod simple_query; mod simple_query;
mod statement; mod statement;
pub mod tls; pub mod tls;
mod to_statement; // mod to_statement;
mod transaction; mod transaction;
mod transaction_builder; mod transaction_builder;
pub mod types; pub mod types;
@@ -98,7 +97,6 @@ impl Notification {
/// An asynchronous message from the server. /// An asynchronous message from the server.
#[allow(clippy::large_enum_variant)] #[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
#[non_exhaustive]
pub enum AsyncMessage { pub enum AsyncMessage {
/// A notice. /// A notice.
/// ///
@@ -110,18 +108,6 @@ pub enum AsyncMessage {
Notification(Notification), Notification(Notification),
} }
/// Message returned by the `SimpleQuery` stream.
#[derive(Debug)]
#[non_exhaustive]
pub enum SimpleQueryMessage {
/// A row of data.
Row(SimpleQueryRow),
/// A statement in the query has completed.
///
/// The number of rows modified or selected is returned.
CommandComplete(u64),
}
fn slice_iter<'a>( fn slice_iter<'a>(
s: &'a [&'a (dyn ToSql + Sync)], s: &'a [&'a (dyn ToSql + Sync)],
) -> impl ExactSizeIterator<Item = &'a (dyn ToSql + Sync)> + 'a { ) -> impl ExactSizeIterator<Item = &'a (dyn ToSql + Sync)> + 'a {

View File

@@ -1,4 +1,4 @@
use crate::client::InnerClient; use crate::client::{CachedTypeInfo, InnerClient};
use crate::codec::FrontendMessage; use crate::codec::FrontendMessage;
use crate::connection::RequestMessages; use crate::connection::RequestMessages;
use crate::error::SqlState; use crate::error::SqlState;
@@ -7,14 +7,13 @@ use crate::{query, slice_iter};
use crate::{Column, Error, Statement}; use crate::{Column, Error, Statement};
use bytes::Bytes; use bytes::Bytes;
use fallible_iterator::FallibleIterator; use fallible_iterator::FallibleIterator;
use futures_util::{pin_mut, TryStreamExt}; use futures_util::{pin_mut, StreamExt, TryStreamExt};
use log::debug; use log::debug;
use postgres_protocol2::message::backend::Message; use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend; use postgres_protocol2::message::frontend;
use std::future::Future; use std::future::Future;
use std::pin::Pin; use std::pin::{pin, Pin};
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
pub(crate) const TYPEINFO_QUERY: &str = "\ pub(crate) const TYPEINFO_QUERY: &str = "\
SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid
@@ -59,7 +58,8 @@ ORDER BY attnum
static NEXT_ID: AtomicUsize = AtomicUsize::new(0); static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
pub async fn prepare( pub async fn prepare(
client: &Arc<InnerClient>, client: &mut InnerClient,
cache: &mut CachedTypeInfo,
query: &str, query: &str,
types: &[Type], types: &[Type],
) -> Result<Statement, Error> { ) -> Result<Statement, Error> {
@@ -86,7 +86,7 @@ pub async fn prepare(
let mut parameters = vec![]; let mut parameters = vec![];
let mut it = parameter_description.parameters(); let mut it = parameter_description.parameters();
while let Some(oid) = it.next().map_err(Error::parse)? { while let Some(oid) = it.next().map_err(Error::parse)? {
let type_ = get_type(client, oid).await?; let type_ = get_type(client, cache, oid).await?;
parameters.push(type_); parameters.push(type_);
} }
@@ -94,24 +94,30 @@ pub async fn prepare(
if let Some(row_description) = row_description { if let Some(row_description) = row_description {
let mut it = row_description.fields(); let mut it = row_description.fields();
while let Some(field) = it.next().map_err(Error::parse)? { while let Some(field) = it.next().map_err(Error::parse)? {
let type_ = get_type(client, field.type_oid()).await?; let type_ = get_type(client, cache, field.type_oid()).await?;
let column = Column::new(field.name().to_string(), type_, field); let column = Column::new(field.name().to_string(), type_, field);
columns.push(column); columns.push(column);
} }
} }
Ok(Statement::new(client, name, parameters, columns)) Ok(Statement::new(name, parameters, columns))
} }
fn prepare_rec<'a>( fn prepare_rec<'a>(
client: &'a Arc<InnerClient>, client: &'a mut InnerClient,
cache: &'a mut CachedTypeInfo,
query: &'a str, query: &'a str,
types: &'a [Type], types: &'a [Type],
) -> Pin<Box<dyn Future<Output = Result<Statement, Error>> + 'a + Send>> { ) -> Pin<Box<dyn Future<Output = Result<Statement, Error>> + 'a + Send>> {
Box::pin(prepare(client, query, types)) Box::pin(prepare(client, cache, query, types))
} }
fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result<Bytes, Error> { fn encode(
client: &mut InnerClient,
name: &str,
query: &str,
types: &[Type],
) -> Result<Bytes, Error> {
if types.is_empty() { if types.is_empty() {
debug!("preparing query {}: {}", name, query); debug!("preparing query {}: {}", name, query);
} else { } else {
@@ -126,16 +132,20 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu
}) })
} }
pub async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> { pub async fn get_type(
client: &mut InnerClient,
cache: &mut CachedTypeInfo,
oid: Oid,
) -> Result<Type, Error> {
if let Some(type_) = Type::from_oid(oid) { if let Some(type_) = Type::from_oid(oid) {
return Ok(type_); return Ok(type_);
} }
if let Some(type_) = client.type_(oid) { if let Some(type_) = cache.type_(oid) {
return Ok(type_); return Ok(type_);
} }
let stmt = typeinfo_statement(client).await?; let stmt = typeinfo_statement(client, cache).await?;
let rows = query::query(client, stmt, slice_iter(&[&oid])).await?; let rows = query::query(client, stmt, slice_iter(&[&oid])).await?;
pin_mut!(rows); pin_mut!(rows);
@@ -145,118 +155,141 @@ pub async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error
None => return Err(Error::unexpected_message()), None => return Err(Error::unexpected_message()),
}; };
let name: String = row.try_get(0)?; let name: String = row.try_get(stmt.columns(), 0)?;
let type_: i8 = row.try_get(1)?; let type_: i8 = row.try_get(stmt.columns(), 1)?;
let elem_oid: Oid = row.try_get(2)?; let elem_oid: Oid = row.try_get(stmt.columns(), 2)?;
let rngsubtype: Option<Oid> = row.try_get(3)?; let rngsubtype: Option<Oid> = row.try_get(stmt.columns(), 3)?;
let basetype: Oid = row.try_get(4)?; let basetype: Oid = row.try_get(stmt.columns(), 4)?;
let schema: String = row.try_get(5)?; let schema: String = row.try_get(stmt.columns(), 5)?;
let relid: Oid = row.try_get(6)?; let relid: Oid = row.try_get(stmt.columns(), 6)?;
let kind = if type_ == b'e' as i8 { let kind = if type_ == b'e' as i8 {
let variants = get_enum_variants(client, oid).await?; let variants = get_enum_variants(client, cache, oid).await?;
Kind::Enum(variants) Kind::Enum(variants)
} else if type_ == b'p' as i8 { } else if type_ == b'p' as i8 {
Kind::Pseudo Kind::Pseudo
} else if basetype != 0 { } else if basetype != 0 {
let type_ = get_type_rec(client, basetype).await?; let type_ = get_type_rec(client, cache, basetype).await?;
Kind::Domain(type_) Kind::Domain(type_)
} else if elem_oid != 0 { } else if elem_oid != 0 {
let type_ = get_type_rec(client, elem_oid).await?; let type_ = get_type_rec(client, cache, elem_oid).await?;
Kind::Array(type_) Kind::Array(type_)
} else if relid != 0 { } else if relid != 0 {
let fields = get_composite_fields(client, relid).await?; let fields = get_composite_fields(client, cache, relid).await?;
Kind::Composite(fields) Kind::Composite(fields)
} else if let Some(rngsubtype) = rngsubtype { } else if let Some(rngsubtype) = rngsubtype {
let type_ = get_type_rec(client, rngsubtype).await?; let type_ = get_type_rec(client, cache, rngsubtype).await?;
Kind::Range(type_) Kind::Range(type_)
} else { } else {
Kind::Simple Kind::Simple
}; };
let type_ = Type::new(name, oid, kind, schema); let type_ = Type::new(name, oid, kind, schema);
client.set_type(oid, &type_); cache.set_type(oid, &type_);
Ok(type_) Ok(type_)
} }
fn get_type_rec<'a>( fn get_type_rec<'a>(
client: &'a Arc<InnerClient>, client: &'a mut InnerClient,
cache: &'a mut CachedTypeInfo,
oid: Oid, oid: Oid,
) -> Pin<Box<dyn Future<Output = Result<Type, Error>> + Send + 'a>> { ) -> Pin<Box<dyn Future<Output = Result<Type, Error>> + Send + 'a>> {
Box::pin(get_type(client, oid)) Box::pin(get_type(client, cache, oid))
} }
async fn typeinfo_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> { async fn typeinfo_statement<'c>(
if let Some(stmt) = client.typeinfo() { client: &mut InnerClient,
return Ok(stmt); cache: &'c mut CachedTypeInfo,
) -> Result<&'c Statement, Error> {
if cache.typeinfo().is_some() {
// needed to get around a borrow checker limitation
return Ok(cache.typeinfo().unwrap());
} }
let stmt = match prepare_rec(client, TYPEINFO_QUERY, &[]).await { let stmt = match prepare_rec(client, cache, TYPEINFO_QUERY, &[]).await {
Ok(stmt) => stmt, Ok(stmt) => stmt,
Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_TABLE) => { Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_TABLE) => {
prepare_rec(client, TYPEINFO_FALLBACK_QUERY, &[]).await? prepare_rec(client, cache, TYPEINFO_FALLBACK_QUERY, &[]).await?
} }
Err(e) => return Err(e), Err(e) => return Err(e),
}; };
client.set_typeinfo(&stmt); Ok(cache.set_typeinfo(stmt))
Ok(stmt)
} }
async fn get_enum_variants(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<String>, Error> { async fn get_enum_variants(
let stmt = typeinfo_enum_statement(client).await?; client: &mut InnerClient,
cache: &mut CachedTypeInfo,
oid: Oid,
) -> Result<Vec<String>, Error> {
let stmt = typeinfo_enum_statement(client, cache).await?;
query::query(client, stmt, slice_iter(&[&oid])) let mut out = vec![];
.await?
.and_then(|row| async move { row.try_get(0) }) let mut rows = pin!(query::query(client, stmt, slice_iter(&[&oid])).await?);
.try_collect() while let Some(row) = rows.next().await {
.await out.push(row?.try_get(stmt.columns(), 0)?)
}
Ok(out)
} }
async fn typeinfo_enum_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> { async fn typeinfo_enum_statement<'c>(
if let Some(stmt) = client.typeinfo_enum() { client: &mut InnerClient,
return Ok(stmt); cache: &'c mut CachedTypeInfo,
) -> Result<&'c Statement, Error> {
if cache.typeinfo_enum().is_some() {
// needed to get around a borrow checker limitation
return Ok(cache.typeinfo_enum().unwrap());
} }
let stmt = match prepare_rec(client, TYPEINFO_ENUM_QUERY, &[]).await { let stmt = match prepare_rec(client, cache, TYPEINFO_ENUM_QUERY, &[]).await {
Ok(stmt) => stmt, Ok(stmt) => stmt,
Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_COLUMN) => { Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_COLUMN) => {
prepare_rec(client, TYPEINFO_ENUM_FALLBACK_QUERY, &[]).await? prepare_rec(client, cache, TYPEINFO_ENUM_FALLBACK_QUERY, &[]).await?
} }
Err(e) => return Err(e), Err(e) => return Err(e),
}; };
client.set_typeinfo_enum(&stmt); Ok(cache.set_typeinfo_enum(stmt))
Ok(stmt)
} }
async fn get_composite_fields(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<Field>, Error> { async fn get_composite_fields(
let stmt = typeinfo_composite_statement(client).await?; client: &mut InnerClient,
cache: &mut CachedTypeInfo,
oid: Oid,
) -> Result<Vec<Field>, Error> {
let stmt = typeinfo_composite_statement(client, cache).await?;
let rows = query::query(client, stmt, slice_iter(&[&oid])) let mut rows = pin!(query::query(client, stmt, slice_iter(&[&oid])).await?);
.await?
.try_collect::<Vec<_>>() let mut oids = vec![];
.await?; while let Some(row) = rows.next().await {
let row = row?;
let name = row.try_get(stmt.columns(), 0)?;
let oid = row.try_get(stmt.columns(), 1)?;
oids.push((name, oid));
}
let mut fields = vec![]; let mut fields = vec![];
for row in rows { for (name, oid) in oids {
let name = row.try_get(0)?; let type_ = get_type_rec(client, cache, oid).await?;
let oid = row.try_get(1)?;
let type_ = get_type_rec(client, oid).await?;
fields.push(Field::new(name, type_)); fields.push(Field::new(name, type_));
} }
Ok(fields) Ok(fields)
} }
async fn typeinfo_composite_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> { async fn typeinfo_composite_statement<'c>(
if let Some(stmt) = client.typeinfo_composite() { client: &mut InnerClient,
return Ok(stmt); cache: &'c mut CachedTypeInfo,
) -> Result<&'c Statement, Error> {
if cache.typeinfo_composite().is_some() {
// needed to get around a borrow checker limitation
return Ok(cache.typeinfo_composite().unwrap());
} }
let stmt = prepare_rec(client, TYPEINFO_COMPOSITE_QUERY, &[]).await?; let stmt = prepare_rec(client, cache, TYPEINFO_COMPOSITE_QUERY, &[]).await?;
client.set_typeinfo_composite(&stmt); Ok(cache.set_typeinfo_composite(stmt))
Ok(stmt)
} }

View File

@@ -14,7 +14,6 @@ use postgres_types2::{Format, ToSql, Type};
use std::fmt; use std::fmt;
use std::marker::PhantomPinned; use std::marker::PhantomPinned;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
struct BorrowToSqlParamsDebug<'a>(&'a [&'a (dyn ToSql + Sync)]); struct BorrowToSqlParamsDebug<'a>(&'a [&'a (dyn ToSql + Sync)]);
@@ -26,10 +25,10 @@ impl fmt::Debug for BorrowToSqlParamsDebug<'_> {
} }
pub async fn query<'a, I>( pub async fn query<'a, I>(
client: &InnerClient, client: &mut InnerClient,
statement: Statement, statement: &Statement,
params: I, params: I,
) -> Result<RowStream, Error> ) -> Result<RawRowStream, Error>
where where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>, I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator, I::IntoIter: ExactSizeIterator,
@@ -41,13 +40,12 @@ where
statement.name(), statement.name(),
BorrowToSqlParamsDebug(params.as_slice()), BorrowToSqlParamsDebug(params.as_slice()),
); );
encode(client, &statement, params)? encode(client, statement, params)?
} else { } else {
encode(client, &statement, params)? encode(client, statement, params)?
}; };
let responses = start(client, buf).await?; let responses = start(client, buf).await?;
Ok(RowStream { Ok(RawRowStream {
statement,
responses, responses,
command_tag: None, command_tag: None,
status: ReadyForQueryStatus::Unknown, status: ReadyForQueryStatus::Unknown,
@@ -57,7 +55,7 @@ where
} }
pub async fn query_txt<S, I>( pub async fn query_txt<S, I>(
client: &Arc<InnerClient>, client: &mut InnerClient,
query: &str, query: &str,
params: I, params: I,
) -> Result<RowStream, Error> ) -> Result<RowStream, Error>
@@ -157,49 +155,6 @@ where
}) })
} }
pub async fn execute<'a, I>(
client: &InnerClient,
statement: Statement,
params: I,
) -> Result<u64, Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
{
let buf = if log_enabled!(Level::Debug) {
let params = params.into_iter().collect::<Vec<_>>();
debug!(
"executing statement {} with parameters: {:?}",
statement.name(),
BorrowToSqlParamsDebug(params.as_slice()),
);
encode(client, &statement, params)?
} else {
encode(client, &statement, params)?
};
let mut responses = start(client, buf).await?;
let mut rows = 0;
loop {
match responses.next().await? {
Message::DataRow(_) => {}
Message::CommandComplete(body) => {
rows = body
.tag()
.map_err(Error::parse)?
.rsplit(' ')
.next()
.unwrap()
.parse()
.unwrap_or(0);
}
Message::EmptyQueryResponse => rows = 0,
Message::ReadyForQuery(_) => return Ok(rows),
_ => return Err(Error::unexpected_message()),
}
}
}
async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> { async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
@@ -211,7 +166,11 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
Ok(responses) Ok(responses)
} }
pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error> pub fn encode<'a, I>(
client: &mut InnerClient,
statement: &Statement,
params: I,
) -> Result<Bytes, Error>
where where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>, I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator, I::IntoIter: ExactSizeIterator,
@@ -296,11 +255,7 @@ impl Stream for RowStream {
loop { loop {
match ready!(this.responses.poll_next(cx)?) { match ready!(this.responses.poll_next(cx)?) {
Message::DataRow(body) => { Message::DataRow(body) => {
return Poll::Ready(Some(Ok(Row::new( return Poll::Ready(Some(Ok(Row::new(body, *this.output_format)?)))
this.statement.clone(),
body,
*this.output_format,
)?)))
} }
Message::EmptyQueryResponse | Message::PortalSuspended => {} Message::EmptyQueryResponse | Message::PortalSuspended => {}
Message::CommandComplete(body) => { Message::CommandComplete(body) => {
@@ -338,3 +293,41 @@ impl RowStream {
self.status self.status
} }
} }
pin_project! {
/// A stream of table rows.
pub struct RawRowStream {
responses: Responses,
command_tag: Option<String>,
output_format: Format,
status: ReadyForQueryStatus,
#[pin]
_p: PhantomPinned,
}
}
impl Stream for RawRowStream {
type Item = Result<Row, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
loop {
match ready!(this.responses.poll_next(cx)?) {
Message::DataRow(body) => {
return Poll::Ready(Some(Ok(Row::new(body, *this.output_format)?)))
}
Message::EmptyQueryResponse | Message::PortalSuspended => {}
Message::CommandComplete(body) => {
if let Ok(tag) = body.tag() {
*this.command_tag = Some(tag.to_string());
}
}
Message::ReadyForQuery(status) => {
*this.status = status.into();
return Poll::Ready(None);
}
_ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
}
}
}
}

View File

@@ -1,103 +1,16 @@
//! Rows. //! Rows.
use crate::row::sealed::{AsName, Sealed};
use crate::simple_query::SimpleColumn;
use crate::statement::Column; use crate::statement::Column;
use crate::types::{FromSql, Type, WrongType}; use crate::types::{FromSql, Type, WrongType};
use crate::{Error, Statement}; use crate::Error;
use fallible_iterator::FallibleIterator; use fallible_iterator::FallibleIterator;
use postgres_protocol2::message::backend::DataRowBody; use postgres_protocol2::message::backend::DataRowBody;
use postgres_types2::{Format, WrongFormat}; use postgres_types2::{Format, WrongFormat};
use std::fmt; use std::fmt;
use std::ops::Range; use std::ops::Range;
use std::str; use std::str;
use std::sync::Arc;
mod sealed {
pub trait Sealed {}
pub trait AsName {
fn as_name(&self) -> &str;
}
}
impl AsName for Column {
fn as_name(&self) -> &str {
self.name()
}
}
impl AsName for String {
fn as_name(&self) -> &str {
self
}
}
/// A trait implemented by types that can index into columns of a row.
///
/// This cannot be implemented outside of this crate.
pub trait RowIndex: Sealed {
#[doc(hidden)]
fn __idx<T>(&self, columns: &[T]) -> Option<usize>
where
T: AsName;
}
impl Sealed for usize {}
impl RowIndex for usize {
#[inline]
fn __idx<T>(&self, columns: &[T]) -> Option<usize>
where
T: AsName,
{
if *self >= columns.len() {
None
} else {
Some(*self)
}
}
}
impl Sealed for str {}
impl RowIndex for str {
#[inline]
fn __idx<T>(&self, columns: &[T]) -> Option<usize>
where
T: AsName,
{
if let Some(idx) = columns.iter().position(|d| d.as_name() == self) {
return Some(idx);
};
// FIXME ASCII-only case insensitivity isn't really the right thing to
// do. Postgres itself uses a dubious wrapper around tolower and JDBC
// uses the US locale.
columns
.iter()
.position(|d| d.as_name().eq_ignore_ascii_case(self))
}
}
impl<T> Sealed for &T where T: ?Sized + Sealed {}
impl<T> RowIndex for &T
where
T: ?Sized + RowIndex,
{
#[inline]
fn __idx<U>(&self, columns: &[U]) -> Option<usize>
where
U: AsName,
{
T::__idx(*self, columns)
}
}
/// A row of data returned from the database by a query. /// A row of data returned from the database by a query.
pub struct Row { pub struct Row {
statement: Statement,
output_format: Format, output_format: Format,
body: DataRowBody, body: DataRowBody,
ranges: Vec<Option<Range<usize>>>, ranges: Vec<Option<Range<usize>>>,
@@ -105,80 +18,33 @@ pub struct Row {
impl fmt::Debug for Row { impl fmt::Debug for Row {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Row") f.debug_struct("Row").finish()
.field("columns", &self.columns())
.finish()
} }
} }
impl Row { impl Row {
pub(crate) fn new( pub(crate) fn new(
statement: Statement, // statement: Statement,
body: DataRowBody, body: DataRowBody,
output_format: Format, output_format: Format,
) -> Result<Row, Error> { ) -> Result<Row, Error> {
let ranges = body.ranges().collect().map_err(Error::parse)?; let ranges = body.ranges().collect().map_err(Error::parse)?;
Ok(Row { Ok(Row {
statement,
body, body,
ranges, ranges,
output_format, output_format,
}) })
} }
/// Returns information about the columns of data in the row. pub(crate) fn try_get<'a, T>(&'a self, columns: &[Column], idx: usize) -> Result<T, Error>
pub fn columns(&self) -> &[Column] {
self.statement.columns()
}
/// Determines if the row contains no values.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Returns the number of values in the row.
pub fn len(&self) -> usize {
self.columns().len()
}
/// Deserializes a value from the row.
///
/// The value can be specified either by its numeric index in the row, or by its column name.
///
/// # Panics
///
/// Panics if the index is out of bounds or if the value cannot be converted to the specified type.
pub fn get<'a, I, T>(&'a self, idx: I) -> T
where where
I: RowIndex + fmt::Display,
T: FromSql<'a>, T: FromSql<'a>,
{ {
match self.get_inner(&idx) { let Some(column) = columns.get(idx) else {
Ok(ok) => ok, return Err(Error::column(idx.to_string()));
Err(err) => panic!("error retrieving column {}: {}", idx, err),
}
}
/// Like `Row::get`, but returns a `Result` rather than panicking.
pub fn try_get<'a, I, T>(&'a self, idx: I) -> Result<T, Error>
where
I: RowIndex + fmt::Display,
T: FromSql<'a>,
{
self.get_inner(&idx)
}
fn get_inner<'a, I, T>(&'a self, idx: &I) -> Result<T, Error>
where
I: RowIndex + fmt::Display,
T: FromSql<'a>,
{
let idx = match idx.__idx(self.columns()) {
Some(idx) => idx,
None => return Err(Error::column(idx.to_string())),
}; };
let ty = self.columns()[idx].type_(); let ty = column.type_();
if !T::accepts(ty) { if !T::accepts(ty) {
return Err(Error::from_sql( return Err(Error::from_sql(
Box::new(WrongType::new::<T>(ty.clone())), Box::new(WrongType::new::<T>(ty.clone())),
@@ -216,85 +82,3 @@ impl Row {
self.body.buffer().len() self.body.buffer().len()
} }
} }
impl AsName for SimpleColumn {
fn as_name(&self) -> &str {
self.name()
}
}
/// A row of data returned from the database by a simple query.
#[derive(Debug)]
pub struct SimpleQueryRow {
columns: Arc<[SimpleColumn]>,
body: DataRowBody,
ranges: Vec<Option<Range<usize>>>,
}
impl SimpleQueryRow {
#[allow(clippy::new_ret_no_self)]
pub(crate) fn new(
columns: Arc<[SimpleColumn]>,
body: DataRowBody,
) -> Result<SimpleQueryRow, Error> {
let ranges = body.ranges().collect().map_err(Error::parse)?;
Ok(SimpleQueryRow {
columns,
body,
ranges,
})
}
/// Returns information about the columns of data in the row.
pub fn columns(&self) -> &[SimpleColumn] {
&self.columns
}
/// Determines if the row contains no values.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Returns the number of values in the row.
pub fn len(&self) -> usize {
self.columns.len()
}
/// Returns a value from the row.
///
/// The value can be specified either by its numeric index in the row, or by its column name.
///
/// # Panics
///
/// Panics if the index is out of bounds or if the value cannot be converted to the specified type.
pub fn get<I>(&self, idx: I) -> Option<&str>
where
I: RowIndex + fmt::Display,
{
match self.get_inner(&idx) {
Ok(ok) => ok,
Err(err) => panic!("error retrieving column {}: {}", idx, err),
}
}
/// Like `SimpleQueryRow::get`, but returns a `Result` rather than panicking.
pub fn try_get<I>(&self, idx: I) -> Result<Option<&str>, Error>
where
I: RowIndex + fmt::Display,
{
self.get_inner(&idx)
}
fn get_inner<I>(&self, idx: &I) -> Result<Option<&str>, Error>
where
I: RowIndex + fmt::Display,
{
let idx = match idx.__idx(&self.columns) {
Some(idx) => idx,
None => return Err(Error::column(idx.to_string())),
};
let buf = self.ranges[idx].clone().map(|r| &self.body.buffer()[r]);
FromSql::from_sql_nullable(&Type::TEXT, buf).map_err(|e| Error::from_sql(e, idx))
}
}

View File

@@ -1,52 +1,14 @@
use crate::client::{InnerClient, Responses}; use crate::client::InnerClient;
use crate::codec::FrontendMessage; use crate::codec::FrontendMessage;
use crate::connection::RequestMessages; use crate::connection::RequestMessages;
use crate::{Error, ReadyForQueryStatus, SimpleQueryMessage, SimpleQueryRow}; use crate::{Error, ReadyForQueryStatus};
use bytes::Bytes; use bytes::Bytes;
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Stream};
use log::debug; use log::debug;
use pin_project_lite::pin_project;
use postgres_protocol2::message::backend::Message; use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend; use postgres_protocol2::message::frontend;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
/// Information about a column of a single query row.
#[derive(Debug)]
pub struct SimpleColumn {
name: String,
}
impl SimpleColumn {
pub(crate) fn new(name: String) -> SimpleColumn {
SimpleColumn { name }
}
/// Returns the name of the column.
pub fn name(&self) -> &str {
&self.name
}
}
pub async fn simple_query(client: &InnerClient, query: &str) -> Result<SimpleQueryStream, Error> {
debug!("executing simple query: {}", query);
let buf = encode(client, query)?;
let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
Ok(SimpleQueryStream {
responses,
columns: None,
status: ReadyForQueryStatus::Unknown,
_p: PhantomPinned,
})
}
pub async fn batch_execute( pub async fn batch_execute(
client: &InnerClient, client: &mut InnerClient,
query: &str, query: &str,
) -> Result<ReadyForQueryStatus, Error> { ) -> Result<ReadyForQueryStatus, Error> {
debug!("executing statement batch: {}", query); debug!("executing statement batch: {}", query);
@@ -66,77 +28,9 @@ pub async fn batch_execute(
} }
} }
pub(crate) fn encode(client: &InnerClient, query: &str) -> Result<Bytes, Error> { pub(crate) fn encode(client: &mut InnerClient, query: &str) -> Result<Bytes, Error> {
client.with_buf(|buf| { client.with_buf(|buf| {
frontend::query(query, buf).map_err(Error::encode)?; frontend::query(query, buf).map_err(Error::encode)?;
Ok(buf.split().freeze()) Ok(buf.split().freeze())
}) })
} }
pin_project! {
/// A stream of simple query results.
pub struct SimpleQueryStream {
responses: Responses,
columns: Option<Arc<[SimpleColumn]>>,
status: ReadyForQueryStatus,
#[pin]
_p: PhantomPinned,
}
}
impl SimpleQueryStream {
/// Returns if the connection is ready for querying, with the status of the connection.
///
/// This might be available only after the stream has been exhausted.
pub fn ready_status(&self) -> ReadyForQueryStatus {
self.status
}
}
impl Stream for SimpleQueryStream {
type Item = Result<SimpleQueryMessage, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
loop {
match ready!(this.responses.poll_next(cx)?) {
Message::CommandComplete(body) => {
let rows = body
.tag()
.map_err(Error::parse)?
.rsplit(' ')
.next()
.unwrap()
.parse()
.unwrap_or(0);
return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(rows))));
}
Message::EmptyQueryResponse => {
return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(0))));
}
Message::RowDescription(body) => {
let columns = body
.fields()
.map(|f| Ok(SimpleColumn::new(f.name().to_string())))
.collect::<Vec<_>>()
.map_err(Error::parse)?
.into();
*this.columns = Some(columns);
}
Message::DataRow(body) => {
let row = match &this.columns {
Some(columns) => SimpleQueryRow::new(columns.clone(), body)?,
None => return Poll::Ready(Some(Err(Error::unexpected_message()))),
};
return Poll::Ready(Some(Ok(SimpleQueryMessage::Row(row))));
}
Message::ReadyForQuery(s) => {
*this.status = s.into();
return Poll::Ready(None);
}
_ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
}
}
}
}

View File

@@ -1,64 +1,33 @@
use crate::client::InnerClient;
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::types::Type; use crate::types::Type;
use postgres_protocol2::{ use postgres_protocol2::{message::backend::Field, Oid};
message::{backend::Field, frontend}, use std::fmt;
Oid,
};
use std::{
fmt,
sync::{Arc, Weak},
};
struct StatementInner { struct StatementInner {
client: Weak<InnerClient>,
name: String, name: String,
params: Vec<Type>, params: Vec<Type>,
columns: Vec<Column>, columns: Vec<Column>,
} }
impl Drop for StatementInner {
fn drop(&mut self) {
if let Some(client) = self.client.upgrade() {
let buf = client.with_buf(|buf| {
frontend::close(b'S', &self.name, buf).unwrap();
frontend::sync(buf);
buf.split().freeze()
});
let _ = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)));
}
}
}
/// A prepared statement. /// A prepared statement.
/// ///
/// Prepared statements can only be used with the connection that created them. /// Prepared statements can only be used with the connection that created them.
#[derive(Clone)] pub struct Statement(StatementInner);
pub struct Statement(Arc<StatementInner>);
impl Statement { impl Statement {
pub(crate) fn new( pub(crate) fn new(name: String, params: Vec<Type>, columns: Vec<Column>) -> Statement {
inner: &Arc<InnerClient>, Statement(StatementInner {
name: String,
params: Vec<Type>,
columns: Vec<Column>,
) -> Statement {
Statement(Arc::new(StatementInner {
client: Arc::downgrade(inner),
name, name,
params, params,
columns, columns,
})) })
} }
pub(crate) fn new_anonymous(params: Vec<Type>, columns: Vec<Column>) -> Statement { pub(crate) fn new_anonymous(params: Vec<Type>, columns: Vec<Column>) -> Statement {
Statement(Arc::new(StatementInner { Statement(StatementInner {
client: Weak::new(),
name: String::new(), name: String::new(),
params, params,
columns, columns,
})) })
} }
pub(crate) fn name(&self) -> &str { pub(crate) fn name(&self) -> &str {

View File

@@ -1,57 +0,0 @@
use crate::to_statement::private::{Sealed, ToStatementType};
use crate::Statement;
mod private {
use crate::{Client, Error, Statement};
pub trait Sealed {}
pub enum ToStatementType<'a> {
Statement(&'a Statement),
Query(&'a str),
}
impl<'a> ToStatementType<'a> {
pub async fn into_statement(self, client: &Client) -> Result<Statement, Error> {
match self {
ToStatementType::Statement(s) => Ok(s.clone()),
ToStatementType::Query(s) => client.prepare(s).await,
}
}
}
}
/// A trait abstracting over prepared and unprepared statements.
///
/// Many methods are generic over this bound, so that they support both a raw query string as well as a statement which
/// was prepared previously.
///
/// This trait is "sealed" and cannot be implemented by anything outside this crate.
pub trait ToStatement: Sealed {
#[doc(hidden)]
fn __convert(&self) -> ToStatementType<'_>;
}
impl ToStatement for Statement {
fn __convert(&self) -> ToStatementType<'_> {
ToStatementType::Statement(self)
}
}
impl Sealed for Statement {}
impl ToStatement for str {
fn __convert(&self) -> ToStatementType<'_> {
ToStatementType::Query(self)
}
}
impl Sealed for str {}
impl ToStatement for String {
fn __convert(&self) -> ToStatementType<'_> {
ToStatementType::Query(self)
}
}
impl Sealed for String {}

View File

@@ -1,6 +1,5 @@
use crate::codec::FrontendMessage; use crate::codec::FrontendMessage;
use crate::connection::RequestMessages; use crate::connection::RequestMessages;
use crate::query::RowStream;
use crate::{CancelToken, Client, Error, ReadyForQueryStatus}; use crate::{CancelToken, Client, Error, ReadyForQueryStatus};
use postgres_protocol2::message::frontend; use postgres_protocol2::message::frontend;
@@ -19,13 +18,13 @@ impl Drop for Transaction<'_> {
return; return;
} }
let buf = self.client.inner().with_buf(|buf| { let buf = self.client.inner.with_buf(|buf| {
frontend::query("ROLLBACK", buf).unwrap(); frontend::query("ROLLBACK", buf).unwrap();
buf.split().freeze() buf.split().freeze()
}); });
let _ = self let _ = self
.client .client
.inner() .inner
.send(RequestMessages::Single(FrontendMessage::Raw(buf))); .send(RequestMessages::Single(FrontendMessage::Raw(buf)));
} }
} }
@@ -52,23 +51,13 @@ impl<'a> Transaction<'a> {
self.client.batch_execute("ROLLBACK").await self.client.batch_execute("ROLLBACK").await
} }
/// Like `Client::query_raw_txt`.
pub async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
I::IntoIter: ExactSizeIterator,
{
self.client.query_raw_txt(statement, params).await
}
/// Like `Client::cancel_token`. /// Like `Client::cancel_token`.
pub fn cancel_token(&self) -> CancelToken { pub fn cancel_token(&self) -> CancelToken {
self.client.cancel_token() self.client.cancel_token()
} }
/// Returns a reference to the underlying `Client`. /// Returns a reference to the underlying `Client`.
pub fn client(&self) -> &Client { pub fn client(&mut self) -> &mut Client {
self.client self.client
} }
} }

View File

@@ -340,7 +340,7 @@ impl PoolingBackend {
debug!("setting up backend session state"); debug!("setting up backend session state");
// initiates the auth session // initiates the auth session
if let Err(e) = client.execute("select auth.init()", &[]).await { if let Err(e) = client.batch_execute("select auth.init();").await {
discard.discard(); discard.discard();
return Err(e.into()); return Err(e.into());
} }

View File

@@ -11,7 +11,7 @@ use smallvec::SmallVec;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::time::Instant; use tokio::time::Instant;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, warn, Instrument}; use tracing::{debug, error, info, info_span, Instrument};
#[cfg(test)] #[cfg(test)]
use { use {
super::conn_pool_lib::GlobalConnPoolOptions, super::conn_pool_lib::GlobalConnPoolOptions,
@@ -125,13 +125,10 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
match message { match message {
Some(Ok(AsyncMessage::Notice(notice))) => { Some(Ok(AsyncMessage::Notice(notice))) => {
info!(%session_id, "notice: {}", notice); debug!(%session_id, "notice: {}", notice);
} }
Some(Ok(AsyncMessage::Notification(notif))) => { Some(Ok(AsyncMessage::Notification(notif))) => {
warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received"); debug!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
}
Some(Ok(_)) => {
warn!(%session_id, "unknown message");
} }
Some(Err(e)) => { Some(Err(e)) => {
error!(%session_id, "connection error: {}", e); error!(%session_id, "connection error: {}", e);

View File

@@ -1,5 +1,5 @@
use postgres_client::types::{Kind, Type}; use postgres_client::types::{Kind, Type};
use postgres_client::Row; use postgres_client::{Column, Row};
use serde_json::{Map, Value}; use serde_json::{Map, Value};
// //
@@ -77,14 +77,14 @@ pub(crate) enum JsonConversionError {
// //
pub(crate) fn pg_text_row_to_json( pub(crate) fn pg_text_row_to_json(
row: &Row, row: &Row,
columns: &[Type], columns: &[Column],
c_types: &[Type],
raw_output: bool, raw_output: bool,
array_mode: bool, array_mode: bool,
) -> Result<Value, JsonConversionError> { ) -> Result<Value, JsonConversionError> {
let iter = row let iter = columns
.columns()
.iter() .iter()
.zip(columns) .zip(c_types)
.enumerate() .enumerate()
.map(|(i, (column, typ))| { .map(|(i, (column, typ))| {
let name = column.name(); let name = column.name();

View File

@@ -23,14 +23,13 @@ use jose_jwk::jose_b64::base64ct::{Base64UrlUnpadded, Encoding};
use p256::ecdsa::{Signature, SigningKey}; use p256::ecdsa::{Signature, SigningKey};
use parking_lot::RwLock; use parking_lot::RwLock;
use postgres_client::tls::NoTlsStream; use postgres_client::tls::NoTlsStream;
use postgres_client::types::ToSql;
use postgres_client::AsyncMessage; use postgres_client::AsyncMessage;
use serde_json::value::RawValue; use serde_json::value::RawValue;
use signature::Signer; use signature::Signer;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::time::Instant; use tokio::time::Instant;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, info_span, warn, Instrument}; use tracing::{debug, error, info, info_span, Instrument};
use super::backend::HttpConnError; use super::backend::HttpConnError;
use super::conn_pool_lib::{ use super::conn_pool_lib::{
@@ -229,13 +228,10 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
match message { match message {
Some(Ok(AsyncMessage::Notice(notice))) => { Some(Ok(AsyncMessage::Notice(notice))) => {
info!(%session_id, "notice: {}", notice); debug!(%session_id, "notice: {}", notice);
} }
Some(Ok(AsyncMessage::Notification(notif))) => { Some(Ok(AsyncMessage::Notification(notif))) => {
warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received"); debug!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
}
Some(Ok(_)) => {
warn!(%session_id, "unknown message");
} }
Some(Err(e)) => { Some(Err(e)) => {
error!(%session_id, "connection error: {}", e); error!(%session_id, "connection error: {}", e);
@@ -287,12 +283,11 @@ impl ClientInnerCommon<postgres_client::Client> {
let token = resign_jwt(&local_data.key, payload, local_data.jti)?; let token = resign_jwt(&local_data.key, payload, local_data.jti)?;
// initiates the auth session // initiates the auth session
self.inner.batch_execute("discard all").await?; // the token contains only `[a-zA-Z1-9_-\.]+` so it cannot escape the string literal formatting.
self.inner self.inner
.execute( .batch_execute(&format!(
"select auth.jwt_session_init($1)", "discard all; select auth.jwt_session_init('{token}');"
&[&&*token as &(dyn ToSql + Sync)], ))
)
.await?; .await?;
let pid = self.inner.get_process_id(); let pid = self.inner.get_process_id();

View File

@@ -797,7 +797,13 @@ impl QueryData {
let cancel_token = inner.cancel_token(); let cancel_token = inner.cancel_token();
let res = match select( let res = match select(
pin!(query_to_json(config, &*inner, self, &mut 0, parsed_headers)), pin!(query_to_json(
config,
&mut *inner,
self,
&mut 0,
parsed_headers
)),
pin!(cancel.cancelled()), pin!(cancel.cancelled()),
) )
.await .await
@@ -881,7 +887,7 @@ impl BatchQueryData {
builder = builder.deferrable(true); builder = builder.deferrable(true);
} }
let transaction = builder.start().await.inspect_err(|_| { let mut transaction = builder.start().await.inspect_err(|_| {
// if we cannot start a transaction, we should return immediately // if we cannot start a transaction, we should return immediately
// and not return to the pool. connection is clearly broken // and not return to the pool. connection is clearly broken
discard.discard(); discard.discard();
@@ -890,7 +896,7 @@ impl BatchQueryData {
let json_output = match query_batch( let json_output = match query_batch(
config, config,
cancel.child_token(), cancel.child_token(),
&transaction, &mut transaction,
self, self,
parsed_headers, parsed_headers,
) )
@@ -934,7 +940,7 @@ impl BatchQueryData {
async fn query_batch( async fn query_batch(
config: &'static HttpConfig, config: &'static HttpConfig,
cancel: CancellationToken, cancel: CancellationToken,
transaction: &Transaction<'_>, transaction: &mut Transaction<'_>,
queries: BatchQueryData, queries: BatchQueryData,
parsed_headers: HttpHeaders, parsed_headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> { ) -> Result<String, SqlOverHttpError> {
@@ -972,7 +978,7 @@ async fn query_batch(
async fn query_to_json<T: GenericClient>( async fn query_to_json<T: GenericClient>(
config: &'static HttpConfig, config: &'static HttpConfig,
client: &T, client: &mut T,
data: QueryData, data: QueryData,
current_size: &mut usize, current_size: &mut usize,
parsed_headers: HttpHeaders, parsed_headers: HttpHeaders,
@@ -1027,7 +1033,7 @@ async fn query_to_json<T: GenericClient>(
let columns_len = row_stream.columns().len(); let columns_len = row_stream.columns().len();
let mut fields = Vec::with_capacity(columns_len); let mut fields = Vec::with_capacity(columns_len);
let mut columns = Vec::with_capacity(columns_len); let mut c_types = Vec::with_capacity(columns_len);
for c in row_stream.columns() { for c in row_stream.columns() {
fields.push(json!({ fields.push(json!({
@@ -1039,7 +1045,7 @@ async fn query_to_json<T: GenericClient>(
"dataTypeModifier": c.type_modifier(), "dataTypeModifier": c.type_modifier(),
"format": "text", "format": "text",
})); }));
columns.push(client.get_type(c.type_oid()).await?); c_types.push(client.get_type(c.type_oid()).await?);
} }
let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode); let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode);
@@ -1047,7 +1053,15 @@ async fn query_to_json<T: GenericClient>(
// convert rows to JSON // convert rows to JSON
let rows = rows let rows = rows
.iter() .iter()
.map(|row| pg_text_row_to_json(row, &columns, parsed_headers.raw_output, array_mode)) .map(|row| {
pg_text_row_to_json(
row,
row_stream.columns(),
&c_types,
parsed_headers.raw_output,
array_mode,
)
})
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
// Resulting JSON format is based on the format of node-postgres result. // Resulting JSON format is based on the format of node-postgres result.