remove typesafe transaction type as we already ensure rollback is performed

This commit is contained in:
Conrad Ludgate
2025-07-22 16:35:07 +01:00
parent 20355cb5f0
commit 286ac97a9c
6 changed files with 41 additions and 260 deletions

View File

@@ -18,10 +18,7 @@ use crate::config::{Host, SslMode};
use crate::query::RowStream;
use crate::simple_query::SimpleQueryStream;
use crate::types::{Oid, Type};
use crate::{
CancelToken, Error, ReadyForQueryStatus, SimpleQueryMessage, Transaction, TransactionBuilder,
query, simple_query,
};
use crate::{CancelToken, Error, ReadyForQueryStatus, SimpleQueryMessage, query, simple_query};
pub struct Responses {
/// new messages from conn
@@ -320,48 +317,9 @@ impl Client {
Ok(())
}
/// Begins a new database transaction.
///
/// The transaction will roll back by default - use the `commit` method to commit it.
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
struct RollbackIfNotDone<'me> {
client: &'me mut Client,
done: bool,
}
impl Drop for RollbackIfNotDone<'_> {
fn drop(&mut self) {
if self.done {
return;
}
let _ = self.client.inner.send_simple_query("ROLLBACK");
}
}
// This is done, as `Future` created by this method can be dropped after
// `RequestMessages` is synchronously send to the `Connection` by
// `batch_execute()`, but before `Responses` is asynchronously polled to
// completion. In that case `Transaction` won't be created and thus
// won't be rolled back.
{
let mut cleaner = RollbackIfNotDone {
client: self,
done: false,
};
cleaner.client.batch_execute("BEGIN").await?;
cleaner.done = true;
}
Ok(Transaction::new(self))
}
/// Returns a builder for a transaction with custom settings.
///
/// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other
/// attributes.
pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
TransactionBuilder::new(self)
/// Commit the transaction.
pub async fn commit(&mut self) -> Result<ReadyForQueryStatus, Error> {
self.batch_execute("COMMIT").await
}
/// Constructs a cancellation token that can later be used to request cancellation of a query running on the

View File

@@ -1,58 +0,0 @@
#![allow(async_fn_in_trait)]
use crate::query::RowStream;
use crate::{Client, Error, Transaction};
mod private {
pub trait Sealed {}
}
/// A trait allowing abstraction over connections and transactions.
///
/// This trait is "sealed", and cannot be implemented outside of this crate.
pub trait GenericClient: private::Sealed {
/// Like `Client::query_raw_txt`.
async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
I::IntoIter: ExactSizeIterator + Sync + Send;
}
impl private::Sealed for Client {}
impl GenericClient for Client {
async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
I::IntoIter: ExactSizeIterator + Sync + Send,
{
self.query_raw_txt(statement, params).await
}
}
impl private::Sealed for Transaction<'_> {}
impl GenericClient for Transaction<'_> {
async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
I::IntoIter: ExactSizeIterator + Sync + Send,
{
self.query_raw_txt(statement, params).await
}
}

View File

@@ -9,13 +9,11 @@ pub use crate::config::Config;
pub use crate::connect_raw::RawConnection;
pub use crate::connection::Connection;
pub use crate::error::Error;
pub use crate::generic_client::GenericClient;
pub use crate::query::RowStream;
pub use crate::row::{Row, SimpleQueryRow};
pub use crate::simple_query::SimpleQueryStream;
pub use crate::statement::{Column, Statement};
pub use crate::tls::NoTls;
pub use crate::transaction::Transaction;
pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder};
/// After executing a query, the connection will be in one of these states
@@ -55,7 +53,6 @@ mod connect_socket;
mod connect_tls;
mod connection;
pub mod error;
mod generic_client;
pub mod maybe_tls_stream;
mod prepare;
mod query;
@@ -63,7 +60,6 @@ pub mod row;
mod simple_query;
mod statement;
pub mod tls;
mod transaction;
mod transaction_builder;
pub mod types;

View File

@@ -1,73 +0,0 @@
use crate::query::RowStream;
use crate::{CancelToken, Client, Error, ReadyForQueryStatus};
/// A representation of a PostgreSQL database transaction.
///
/// Transactions will implicitly roll back when dropped. Use the `commit` method to commit the changes made in the
/// transaction. Transactions can be nested, with inner transactions implemented via safepoints.
pub struct Transaction<'a> {
client: &'a mut Client,
done: bool,
}
impl Drop for Transaction<'_> {
fn drop(&mut self) {
if self.done {
return;
}
let _ = self.client.inner_mut().send_simple_query("ROLLBACK");
}
}
impl<'a> Transaction<'a> {
pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
Transaction {
client,
done: false,
}
}
/// Consumes the transaction, committing all changes made within it.
pub async fn commit(mut self) -> Result<ReadyForQueryStatus, Error> {
self.done = true;
self.client.batch_execute("COMMIT").await
}
/// Rolls the transaction back, discarding all changes made within it.
///
/// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
pub async fn rollback(mut self) -> Result<ReadyForQueryStatus, Error> {
self.done = true;
self.client.batch_execute("ROLLBACK").await
}
/// Like `Client::query_raw_txt`.
pub async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
I::IntoIter: ExactSizeIterator,
{
self.client.query_raw_txt(statement, params).await
}
/// Like `Client::cancel_token`.
pub fn cancel_token(&self) -> CancelToken {
self.client.cancel_token()
}
/// Returns a reference to the underlying `Client`.
pub fn client(&self) -> &Client {
self.client
}
/// Returns a reference to the underlying `Client`.
pub fn client_mut(&mut self) -> &mut Client {
self.client
}
}

View File

@@ -1,5 +1,3 @@
use crate::{Client, Error, Transaction};
/// The isolation level of a database transaction.
#[derive(Debug, Copy, Clone)]
#[non_exhaustive]
@@ -20,49 +18,17 @@ pub enum IsolationLevel {
}
/// A builder for database transactions.
pub struct TransactionBuilder<'a> {
client: &'a mut Client,
isolation_level: Option<IsolationLevel>,
read_only: Option<bool>,
deferrable: Option<bool>,
pub struct TransactionBuilder {
pub isolation_level: Option<IsolationLevel>,
pub read_only: Option<bool>,
pub deferrable: Option<bool>,
}
impl<'a> TransactionBuilder<'a> {
pub(crate) fn new(client: &'a mut Client) -> TransactionBuilder<'a> {
TransactionBuilder {
client,
isolation_level: None,
read_only: None,
deferrable: None,
}
}
/// Sets the isolation level of the transaction.
pub fn isolation_level(mut self, isolation_level: IsolationLevel) -> Self {
self.isolation_level = Some(isolation_level);
self
}
/// Sets the access mode of the transaction.
pub fn read_only(mut self, read_only: bool) -> Self {
self.read_only = Some(read_only);
self
}
/// Sets the deferrability of the transaction.
///
/// If the transaction is also serializable and read only, creation of the transaction may block, but when it
/// completes the transaction is able to run with less overhead and a guarantee that it will not be aborted due to
/// serialization failure.
pub fn deferrable(mut self, deferrable: bool) -> Self {
self.deferrable = Some(deferrable);
self
}
impl TransactionBuilder {
/// Begins the transaction.
///
/// The transaction will roll back by default - use the `commit` method to commit it.
pub async fn start(self) -> Result<Transaction<'a>, Error> {
pub fn format(self) -> String {
let mut query = "START TRANSACTION".to_string();
let mut first = true;
@@ -106,8 +72,6 @@ impl<'a> TransactionBuilder<'a> {
query.push_str(s);
}
self.client.batch_execute(&query).await?;
Ok(Transaction::new(self.client))
query
}
}

View File

@@ -14,7 +14,7 @@ use hyper::http::{HeaderName, HeaderValue};
use hyper::{Request, Response, StatusCode, header};
use indexmap::IndexMap;
use postgres_client::error::{DbError, ErrorPosition, SqlState};
use postgres_client::{GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, Transaction};
use postgres_client::{IsolationLevel, NoTls, ReadyForQueryStatus, TransactionBuilder};
use serde_json::Value;
use serde_json::value::RawValue;
use tokio::time::{self, Instant};
@@ -759,39 +759,33 @@ impl BatchQueryData {
info!("starting transaction");
let (inner, mut discard) = client.inner();
let cancel_token = inner.cancel_token();
let mut builder = inner.build_transaction();
if let Some(isolation_level) = parsed_headers.txn_isolation_level {
builder = builder.isolation_level(isolation_level);
}
if parsed_headers.txn_read_only {
builder = builder.read_only(true);
}
if parsed_headers.txn_deferrable {
builder = builder.deferrable(true);
}
let mut transaction = builder
.start()
.await
.inspect_err(|_| {
// if we cannot start a transaction, we should return immediately
// and not return to the pool. connection is clearly broken
discard.discard();
})
.map_err(SqlOverHttpError::Postgres)?;
let json_output = match query_batch_to_json(
config,
cancel.child_token(),
&mut transaction,
self,
parsed_headers,
)
.await
{
let query = TransactionBuilder {
isolation_level: parsed_headers.txn_isolation_level,
read_only: parsed_headers.txn_read_only.then_some(true),
deferrable: parsed_headers.txn_deferrable.then_some(true),
}
.format();
inner
.batch_execute(&query)
.await
.inspect_err(|_| {
// if we cannot start a transaction, we should return immediately
// and not return to the pool. connection is clearly broken
discard.discard();
})
.map_err(SqlOverHttpError::Postgres)?;
}
let res =
query_batch_to_json(config, cancel.child_token(), inner, self, parsed_headers).await;
let json_output = match res {
Ok(json_output) => {
info!("commit");
transaction
inner
.commit()
.await
.inspect_err(|_| {
@@ -823,7 +817,7 @@ impl BatchQueryData {
async fn query_batch(
config: &'static HttpConfig,
cancel: CancellationToken,
transaction: &mut Transaction<'_>,
client: &mut postgres_client::Client,
queries: BatchQueryData,
parsed_headers: HttpHeaders,
results: &mut json::ListSer<'_>,
@@ -831,7 +825,7 @@ async fn query_batch(
for stmt in queries.queries {
let query = pin!(query_to_json(
config,
transaction,
client,
stmt,
results.entry(),
parsed_headers,
@@ -856,23 +850,23 @@ async fn query_batch(
async fn query_batch_to_json(
config: &'static HttpConfig,
cancel: CancellationToken,
tx: &mut Transaction<'_>,
client: &mut postgres_client::Client,
queries: BatchQueryData,
headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> {
let json_output = json::value_to_string!(|obj| json::value_as_object!(|obj| {
let results = obj.key("results");
json::value_as_list!(|results| {
query_batch(config, cancel, tx, queries, headers, results).await?;
query_batch(config, cancel, client, queries, headers, results).await?;
});
}));
Ok(json_output)
}
async fn query_to_json<T: GenericClient>(
async fn query_to_json(
config: &'static HttpConfig,
client: &mut T,
client: &mut postgres_client::Client,
data: QueryData,
output: json::ValueSer<'_>,
parsed_headers: HttpHeaders,