Compare commits

...

12 Commits

Author SHA1 Message Date
Conrad Ludgate
e965bd96bb optimise some future sizes 2025-07-23 08:58:09 +01:00
Conrad Ludgate
14daaec98e more compact code and more compact futures 2025-07-23 08:58:09 +01:00
Conrad Ludgate
286ac97a9c remove typesafe transaction type as we already ensure rollback is performed 2025-07-23 08:58:09 +01:00
Conrad Ludgate
20355cb5f0 fix rest.rs 2025-07-23 07:04:36 +01:00
Conrad Ludgate
634dbd29b6 python lints 2025-07-23 07:04:36 +01:00
Conrad Ludgate
a235b241d5 ruff format 2025-07-23 07:04:36 +01:00
Conrad Ludgate
539652fa4e rollback safety 2025-07-23 07:04:36 +01:00
Conrad Ludgate
11294ca322 rename Send to LocalProxyClient, etc 2025-07-23 07:04:36 +01:00
Conrad Ludgate
84020c1328 fix python lints 2025-07-23 07:04:36 +01:00
Conrad Ludgate
0cc7415691 remove explicit discard_all for local_proxy 2025-07-23 07:04:36 +01:00
Conrad Ludgate
38df46b381 fix session state by resetting it 2025-07-23 07:04:36 +01:00
Conrad Ludgate
cdc73ad051 add regression test 2025-07-23 07:04:36 +01:00
16 changed files with 348 additions and 525 deletions

View File

@@ -600,6 +600,7 @@ impl ParameterStatusBody {
}
}
#[derive(Clone, Copy)]
pub struct ReadyForQueryBody {
status: u8,
}

View File

@@ -18,10 +18,7 @@ use crate::config::{Host, SslMode};
use crate::query::RowStream;
use crate::simple_query::SimpleQueryStream;
use crate::types::{Oid, Type};
use crate::{
CancelToken, Error, ReadyForQueryStatus, SimpleQueryMessage, Transaction, TransactionBuilder,
query, simple_query,
};
use crate::{CancelToken, Error, ReadyForQueryStatus, SimpleQueryMessage, query, simple_query};
pub struct Responses {
/// new messages from conn
@@ -32,6 +29,9 @@ pub struct Responses {
waiting: usize,
/// number of ReadyForQuery messages received.
received: usize,
/// The last query status we received.
last_status: ReadyForQueryStatus,
}
impl Responses {
@@ -42,7 +42,8 @@ impl Responses {
let received = self.received;
// increase the query head if this is the last message.
if let Message::ReadyForQuery(_) = message {
if let Message::ReadyForQuery(ref status) = message {
self.last_status = (*status).into();
self.received += 1;
}
@@ -71,6 +72,15 @@ impl Responses {
pub async fn next(&mut self) -> Result<Message, Error> {
future::poll_fn(|cx| self.poll_next(cx)).await
}
pub async fn wait_until_ready(&mut self) -> Result<ReadyForQueryStatus, Error> {
while self.received < self.waiting {
if let Message::ReadyForQuery(status) = self.next().await? {
return Ok(status.into());
}
}
Ok(self.last_status)
}
}
/// A cache of type info and prepared statements for fetching type info
@@ -95,13 +105,6 @@ impl InnerClient {
Ok(PartialQuery(Some(self)))
}
// pub fn send_with_sync<F>(&mut self, f: F) -> Result<&mut Responses, Error>
// where
// F: FnOnce(&mut BytesMut) -> Result<(), Error>,
// {
// self.start()?.send_with_sync(f)
// }
pub fn send_simple_query(&mut self, query: &str) -> Result<&mut Responses, Error> {
self.responses.waiting += 1;
@@ -200,6 +203,8 @@ impl Client {
cur: BackendMessages::empty(),
waiting: 0,
received: 0,
// new connections are always idle.
last_status: ReadyForQueryStatus::Idle,
},
buffer: Default::default(),
},
@@ -233,6 +238,11 @@ impl Client {
rx
}
/// Wait until this connection has no more active queries.
pub async fn wait_until_ready(&mut self) -> Result<ReadyForQueryStatus, Error> {
self.inner_mut().responses.wait_until_ready().await
}
/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
/// to save a roundtrip
pub async fn query_raw_txt<S, I>(
@@ -292,52 +302,32 @@ impl Client {
simple_query::batch_execute(self.inner_mut(), query).await
}
pub async fn discard_all(&mut self) -> Result<ReadyForQueryStatus, Error> {
self.batch_execute("discard all").await
}
/// Begins a new database transaction.
/// Similar to `discard_all`, but it does not clear any query plans
///
/// The transaction will roll back by default - use the `commit` method to commit it.
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
struct RollbackIfNotDone<'me> {
client: &'me mut Client,
done: bool,
}
/// This runs in the background, so it can be executed without `await`ing.
pub fn reset_session_background(&mut self) -> Result<(), Error> {
// "CLOSE ALL": closes any cursors
// "SET SESSION AUTHORIZATION DEFAULT": resets the current_user back to the session_user
// "RESET ALL": resets any GUCs back to their session defaults.
// "DEALLOCATE ALL": deallocates any prepared statements
// "UNLISTEN *": stops listening on all channels
// "SELECT pg_advisory_unlock_all();": unlocks all advisory locks
// "DISCARD TEMP;": drops all temporary tables
// "DISCARD SEQUENCES;": deallocates all cached sequence state
impl Drop for RollbackIfNotDone<'_> {
fn drop(&mut self) {
if self.done {
return;
}
let _responses = self.inner_mut().send_simple_query(
"ROLLBACK;
CLOSE ALL;
SET SESSION AUTHORIZATION DEFAULT;
RESET ALL;
DEALLOCATE ALL;
UNLISTEN *;
SELECT pg_advisory_unlock_all();
DISCARD TEMP;
DISCARD SEQUENCES;",
)?;
let _ = self.client.inner.send_simple_query("ROLLBACK");
}
}
// This is done, as `Future` created by this method can be dropped after
// `RequestMessages` is synchronously send to the `Connection` by
// `batch_execute()`, but before `Responses` is asynchronously polled to
// completion. In that case `Transaction` won't be created and thus
// won't be rolled back.
{
let mut cleaner = RollbackIfNotDone {
client: self,
done: false,
};
cleaner.client.batch_execute("BEGIN").await?;
cleaner.done = true;
}
Ok(Transaction::new(self))
}
/// Returns a builder for a transaction with custom settings.
///
/// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other
/// attributes.
pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
TransactionBuilder::new(self)
Ok(())
}
/// Constructs a cancellation token that can later be used to request cancellation of a query running on the

View File

@@ -1,58 +0,0 @@
#![allow(async_fn_in_trait)]
use crate::query::RowStream;
use crate::{Client, Error, Transaction};
mod private {
pub trait Sealed {}
}
/// A trait allowing abstraction over connections and transactions.
///
/// This trait is "sealed", and cannot be implemented outside of this crate.
pub trait GenericClient: private::Sealed {
/// Like `Client::query_raw_txt`.
async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
I::IntoIter: ExactSizeIterator + Sync + Send;
}
impl private::Sealed for Client {}
impl GenericClient for Client {
async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
I::IntoIter: ExactSizeIterator + Sync + Send,
{
self.query_raw_txt(statement, params).await
}
}
impl private::Sealed for Transaction<'_> {}
impl GenericClient for Transaction<'_> {
async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
I::IntoIter: ExactSizeIterator + Sync + Send,
{
self.query_raw_txt(statement, params).await
}
}

View File

@@ -9,13 +9,11 @@ pub use crate::config::Config;
pub use crate::connect_raw::RawConnection;
pub use crate::connection::Connection;
pub use crate::error::Error;
pub use crate::generic_client::GenericClient;
pub use crate::query::RowStream;
pub use crate::row::{Row, SimpleQueryRow};
pub use crate::simple_query::SimpleQueryStream;
pub use crate::statement::{Column, Statement};
pub use crate::tls::NoTls;
pub use crate::transaction::Transaction;
pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder};
/// After executing a query, the connection will be in one of these states
@@ -55,7 +53,6 @@ mod connect_socket;
mod connect_tls;
mod connection;
pub mod error;
mod generic_client;
pub mod maybe_tls_stream;
mod prepare;
mod query;
@@ -63,7 +60,6 @@ pub mod row;
mod simple_query;
mod statement;
pub mod tls;
mod transaction;
mod transaction_builder;
pub mod types;

View File

@@ -1,73 +0,0 @@
use crate::query::RowStream;
use crate::{CancelToken, Client, Error, ReadyForQueryStatus};
/// A representation of a PostgreSQL database transaction.
///
/// Transactions will implicitly roll back when dropped. Use the `commit` method to commit the changes made in the
/// transaction. Transactions can be nested, with inner transactions implemented via safepoints.
pub struct Transaction<'a> {
client: &'a mut Client,
done: bool,
}
impl Drop for Transaction<'_> {
fn drop(&mut self) {
if self.done {
return;
}
let _ = self.client.inner_mut().send_simple_query("ROLLBACK");
}
}
impl<'a> Transaction<'a> {
pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
Transaction {
client,
done: false,
}
}
/// Consumes the transaction, committing all changes made within it.
pub async fn commit(mut self) -> Result<ReadyForQueryStatus, Error> {
self.done = true;
self.client.batch_execute("COMMIT").await
}
/// Rolls the transaction back, discarding all changes made within it.
///
/// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
pub async fn rollback(mut self) -> Result<ReadyForQueryStatus, Error> {
self.done = true;
self.client.batch_execute("ROLLBACK").await
}
/// Like `Client::query_raw_txt`.
pub async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
I::IntoIter: ExactSizeIterator,
{
self.client.query_raw_txt(statement, params).await
}
/// Like `Client::cancel_token`.
pub fn cancel_token(&self) -> CancelToken {
self.client.cancel_token()
}
/// Returns a reference to the underlying `Client`.
pub fn client(&self) -> &Client {
self.client
}
/// Returns a reference to the underlying `Client`.
pub fn client_mut(&mut self) -> &mut Client {
self.client
}
}

View File

@@ -1,5 +1,3 @@
use crate::{Client, Error, Transaction};
/// The isolation level of a database transaction.
#[derive(Debug, Copy, Clone)]
#[non_exhaustive]
@@ -20,49 +18,17 @@ pub enum IsolationLevel {
}
/// A builder for database transactions.
pub struct TransactionBuilder<'a> {
client: &'a mut Client,
isolation_level: Option<IsolationLevel>,
read_only: Option<bool>,
deferrable: Option<bool>,
pub struct TransactionBuilder {
pub isolation_level: Option<IsolationLevel>,
pub read_only: Option<bool>,
pub deferrable: Option<bool>,
}
impl<'a> TransactionBuilder<'a> {
pub(crate) fn new(client: &'a mut Client) -> TransactionBuilder<'a> {
TransactionBuilder {
client,
isolation_level: None,
read_only: None,
deferrable: None,
}
}
/// Sets the isolation level of the transaction.
pub fn isolation_level(mut self, isolation_level: IsolationLevel) -> Self {
self.isolation_level = Some(isolation_level);
self
}
/// Sets the access mode of the transaction.
pub fn read_only(mut self, read_only: bool) -> Self {
self.read_only = Some(read_only);
self
}
/// Sets the deferrability of the transaction.
///
/// If the transaction is also serializable and read only, creation of the transaction may block, but when it
/// completes the transaction is able to run with less overhead and a guarantee that it will not be aborted due to
/// serialization failure.
pub fn deferrable(mut self, deferrable: bool) -> Self {
self.deferrable = Some(deferrable);
self
}
impl TransactionBuilder {
/// Begins the transaction.
///
/// The transaction will roll back by default - use the `commit` method to commit it.
pub async fn start(self) -> Result<Transaction<'a>, Error> {
pub fn format(self) -> String {
let mut query = "START TRANSACTION".to_string();
let mut first = true;
@@ -106,8 +72,6 @@ impl<'a> TransactionBuilder<'a> {
query.push_str(s);
}
self.client.batch_execute(&query).await?;
Ok(Transaction::new(self.client))
query
}
}

View File

@@ -18,7 +18,7 @@ use tracing::{debug, info};
use super::AsyncRW;
use super::conn_pool::poll_client;
use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool};
use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client};
use super::http_conn_pool::{self, HttpConnPool, LocalProxyClient, poll_http2_client};
use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool};
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo};
@@ -40,7 +40,8 @@ use crate::rate_limiter::EndpointRateLimiter;
use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
pub(crate) http_conn_pool:
Arc<GlobalConnPool<LocalProxyClient, HttpConnPool<LocalProxyClient>>>,
pub(crate) local_pool: Arc<LocalConnPool<postgres_client::Client>>,
pub(crate) pool:
Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
@@ -210,7 +211,7 @@ impl PoolingBackend {
&self,
ctx: &RequestContext,
conn_info: ConnInfo,
) -> Result<http_conn_pool::Client<Send>, HttpConnError> {
) -> Result<http_conn_pool::Client<LocalProxyClient>, HttpConnError> {
debug!("pool: looking for an existing connection");
if let Ok(Some(client)) = self.http_conn_pool.get(ctx, &conn_info) {
return Ok(client);
@@ -568,7 +569,7 @@ impl ConnectMechanism for TokioMechanism {
}
struct HyperMechanism {
pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
pool: Arc<GlobalConnPool<LocalProxyClient, HttpConnPool<LocalProxyClient>>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
@@ -578,7 +579,7 @@ struct HyperMechanism {
#[async_trait]
impl ConnectMechanism for HyperMechanism {
type Connection = http_conn_pool::Client<Send>;
type Connection = http_conn_pool::Client<LocalProxyClient>;
type ConnectError = HttpConnError;
type Error = HttpConnError;
@@ -632,7 +633,13 @@ async fn connect_http2(
port: u16,
timeout: Duration,
tls: Option<&Arc<rustls::ClientConfig>>,
) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> {
) -> Result<
(
http_conn_pool::LocalProxyClient,
http_conn_pool::LocalProxyConnection,
),
LocalProxyConnError,
> {
let addrs = match host_addr {
Some(addr) => vec![SocketAddr::new(addr, port)],
None => lookup_host((host, port))

View File

@@ -190,6 +190,9 @@ mod tests {
fn get_process_id(&self) -> i32 {
0
}
fn reset(&mut self) -> Result<(), postgres_client::Error> {
Ok(())
}
}
fn create_inner() -> ClientInnerCommon<MockClient> {

View File

@@ -7,10 +7,9 @@ use std::time::Duration;
use clashmap::ClashMap;
use parking_lot::RwLock;
use postgres_client::ReadyForQueryStatus;
use rand::Rng;
use smol_str::ToSmolStr;
use tracing::{Span, debug, info};
use tracing::{Span, debug, info, warn};
use super::backend::HttpConnError;
use super::conn_pool::ClientDataRemote;
@@ -188,7 +187,7 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
self.pools.get_mut(&db_user)
}
pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInnerCommon<C>) {
pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, mut client: ClientInnerCommon<C>) {
let conn_id = client.get_conn_id();
let (max_conn, conn_count, pool_name) = {
let pool = pool.read();
@@ -201,12 +200,17 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
};
if client.inner.is_closed() {
info!(%conn_id, "{}: throwing away connection '{conn_info}' because connection is closed", pool_name);
info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because connection is closed");
return;
}
if let Err(error) = client.inner.reset() {
warn!(?error, %conn_id, "{pool_name}: throwing away connection '{conn_info}' because connection could not be reset");
return;
}
if conn_count >= max_conn {
info!(%conn_id, "{}: throwing away connection '{conn_info}' because pool is full", pool_name);
info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because pool is full");
return;
}
@@ -691,6 +695,7 @@ impl<C: ClientInnerExt> Deref for Client<C> {
pub(crate) trait ClientInnerExt: Sync + Send + 'static {
fn is_closed(&self) -> bool;
fn get_process_id(&self) -> i32;
fn reset(&mut self) -> Result<(), postgres_client::Error>;
}
impl ClientInnerExt for postgres_client::Client {
@@ -701,15 +706,13 @@ impl ClientInnerExt for postgres_client::Client {
fn get_process_id(&self) -> i32 {
self.get_process_id()
}
fn reset(&mut self) -> Result<(), postgres_client::Error> {
self.reset_session_background()
}
}
impl<C: ClientInnerExt> Discard<'_, C> {
pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
let conn_info = &self.conn_info;
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is not idle");
}
}
pub(crate) fn discard(&mut self) {
let conn_info = &self.conn_info;
if std::mem::take(self.pool).strong_count() > 0 {

View File

@@ -23,8 +23,8 @@ use crate::protocol2::ConnectionInfoExtra;
use crate::types::EndpointCacheKey;
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
pub(crate) type Send = http2::SendRequest<BoxBody<Bytes, hyper::Error>>;
pub(crate) type Connect =
pub(crate) type LocalProxyClient = http2::SendRequest<BoxBody<Bytes, hyper::Error>>;
pub(crate) type LocalProxyConnection =
http2::Connection<TokioIo<AsyncRW>, BoxBody<Bytes, hyper::Error>, TokioExecutor>;
#[derive(Clone)]
@@ -189,14 +189,14 @@ impl<C: ClientInnerExt + Clone> GlobalConnPool<C, HttpConnPool<C>> {
}
pub(crate) fn poll_http2_client(
global_pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
global_pool: Arc<GlobalConnPool<LocalProxyClient, HttpConnPool<LocalProxyClient>>>,
ctx: &RequestContext,
conn_info: &ConnInfo,
client: Send,
connection: Connect,
client: LocalProxyClient,
connection: LocalProxyConnection,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> Client<Send> {
) -> Client<LocalProxyClient> {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let session_id = ctx.session_id();
@@ -285,7 +285,7 @@ impl<C: ClientInnerExt + Clone> Client<C> {
}
}
impl ClientInnerExt for Send {
impl ClientInnerExt for LocalProxyClient {
fn is_closed(&self) -> bool {
self.is_closed()
}
@@ -294,4 +294,10 @@ impl ClientInnerExt for Send {
// ideally throw something meaningful
-1
}
fn reset(&mut self) -> Result<(), postgres_client::Error> {
// We use HTTP/2.0 to talk to local proxy. HTTP is stateless,
// so there's nothing to reset.
Ok(())
}
}

View File

@@ -269,11 +269,6 @@ impl ClientInnerCommon<postgres_client::Client> {
local_data.jti += 1;
let token = resign_jwt(&local_data.key, payload, local_data.jti)?;
self.inner
.discard_all()
.await
.map_err(SqlOverHttpError::InternalPostgres)?;
// initiates the auth session
// this is safe from query injections as the jwt format free of any escape characters.
let query = format!("select auth.jwt_session_init('{token}')");

View File

@@ -46,7 +46,7 @@ use super::backend::{HttpConnError, LocalProxyConnError, PoolingBackend};
use super::conn_pool::AuthData;
use super::conn_pool_lib::ConnInfo;
use super::error::{ConnInfoError, Credentials, HttpCodeError, ReadPayloadError};
use super::http_conn_pool::{self, Send};
use super::http_conn_pool::{self, LocalProxyClient};
use super::http_util::{
ALLOW_POOL, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_ISOLATION_LEVEL, TXN_READ_ONLY,
get_conn_info, json_response, uuid_to_header_value,
@@ -145,7 +145,7 @@ impl DbSchemaCache {
endpoint_id: &EndpointCacheKey,
auth_header: &HeaderValue,
connection_string: &str,
client: &mut http_conn_pool::Client<Send>,
client: &mut http_conn_pool::Client<LocalProxyClient>,
ctx: &RequestContext,
config: &'static ProxyConfig,
) -> Result<Arc<(ApiConfig, DbSchemaOwned)>, RestError> {
@@ -190,7 +190,7 @@ impl DbSchemaCache {
&self,
auth_header: &HeaderValue,
connection_string: &str,
client: &mut http_conn_pool::Client<Send>,
client: &mut http_conn_pool::Client<LocalProxyClient>,
ctx: &RequestContext,
config: &'static ProxyConfig,
) -> Result<(ApiConfig, DbSchemaOwned), RestError> {
@@ -430,7 +430,7 @@ struct BatchQueryData<'a> {
}
async fn make_local_proxy_request<S: DeserializeOwned>(
client: &mut http_conn_pool::Client<Send>,
client: &mut http_conn_pool::Client<LocalProxyClient>,
headers: impl IntoIterator<Item = (&HeaderName, HeaderValue)>,
body: QueryData<'_>,
max_len: usize,
@@ -461,7 +461,7 @@ async fn make_local_proxy_request<S: DeserializeOwned>(
}
async fn make_raw_local_proxy_request(
client: &mut http_conn_pool::Client<Send>,
client: &mut http_conn_pool::Client<LocalProxyClient>,
headers: impl IntoIterator<Item = (&HeaderName, HeaderValue)>,
body: String,
) -> Result<Response<Incoming>, RestError> {

View File

@@ -1,9 +1,8 @@
use std::pin::pin;
use std::sync::Arc;
use bytes::Bytes;
use futures::future::{Either, select, try_join};
use futures::{StreamExt, TryFutureExt};
use futures::future::try_join;
use futures::{TryFutureExt, TryStreamExt};
use http::Method;
use http::header::AUTHORIZATION;
use http_body_util::combinators::BoxBody;
@@ -14,7 +13,7 @@ use hyper::http::{HeaderName, HeaderValue};
use hyper::{Request, Response, StatusCode, header};
use indexmap::IndexMap;
use postgres_client::error::{DbError, ErrorPosition, SqlState};
use postgres_client::{GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, Transaction};
use postgres_client::{IsolationLevel, NoTls, ReadyForQueryStatus, TransactionBuilder};
use serde_json::Value;
use serde_json::value::RawValue;
use tokio::time::{self, Instant};
@@ -495,7 +494,7 @@ async fn handle_db_inner(
.http_conn_content_length_bytes
.observe(HttpDirection::Request, body.len() as f64);
debug!(length = body.len(), "request payload read");
debug!(length = body.len(), "request payload read ");
let payload: Payload = serde_json::from_slice(&body)?;
Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
}
@@ -530,13 +529,13 @@ async fn handle_db_inner(
let (cli_inner, _dsc) = client.client_inner();
cli_inner.set_jwt_session(&payload).await?;
}
Client::Local(client)
Box::new(Client::Local(client))
}
_ => {
let client = backend
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
.await?;
Client::Remote(client)
Box::new(Client::Remote(client))
}
};
@@ -550,10 +549,7 @@ async fn handle_db_inner(
let (payload, mut client) = match run_until_cancelled(
// Run both operations in parallel
try_join(
pin!(fetch_and_process_request),
pin!(authenticate_and_connect),
),
try_join(fetch_and_process_request, authenticate_and_connect),
&cancel,
)
.await
@@ -562,37 +558,38 @@ async fn handle_db_inner(
None => return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect)),
};
// Now execute the query and return the result.
let json_output = match payload
.process(&config.http_config, cancel, &mut client, parsed_headers)
.await
{
Ok(json_output) => json_output,
Err(error) => {
if let SqlOverHttpError::Cancelled(_) = error {
cancel_query(&mut client).await;
}
return Err(error);
}
};
let mut response = Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "application/json");
// Now execute the query and return the result.
let json_output = match payload {
Payload::Single(stmt) => {
stmt.process(&config.http_config, cancel, &mut client, parsed_headers)
.await?
if let Payload::Batch(_) = payload {
if parsed_headers.txn_read_only {
response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE);
}
Payload::Batch(statements) => {
if parsed_headers.txn_read_only {
response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE);
}
if parsed_headers.txn_deferrable {
response = response.header(TXN_DEFERRABLE.clone(), &HEADER_VALUE_TRUE);
}
if let Some(txn_isolation_level) = parsed_headers
.txn_isolation_level
.and_then(map_isolation_level_to_headers)
{
response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
}
statements
.process(&config.http_config, cancel, &mut client, parsed_headers)
.await?
if parsed_headers.txn_deferrable {
response = response.header(TXN_DEFERRABLE.clone(), &HEADER_VALUE_TRUE);
}
};
let metrics = client.metrics(ctx);
if let Some(txn_isolation_level) = parsed_headers
.txn_isolation_level
.and_then(map_isolation_level_to_headers)
{
response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
}
}
let len = json_output.len();
let response = response
@@ -607,6 +604,7 @@ async fn handle_db_inner(
// count the egress bytes - we miss the TLS and header overhead but oh well...
// moving this later in the stack is going to be a lot of effort and ehhhh
let metrics = client.metrics(ctx);
metrics.record_egress(len as u64);
metrics.record_ingress(request_len as u64);
@@ -673,233 +671,123 @@ async fn handle_auth_broker_inner(
.map(|b| b.boxed()))
}
impl QueryData {
impl Payload {
async fn process(
self,
&self,
config: &'static HttpConfig,
cancel: CancellationToken,
client: &mut Client,
parsed_headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> {
let (inner, mut discard) = client.inner();
let cancel_token = inner.cancel_token();
let mut json_buf = vec![];
let needs_tx = matches!(self, Payload::Batch(_));
let batch_result = match select(
pin!(query_to_json(
config,
&mut *inner,
self,
json::ValueSer::new(&mut json_buf),
parsed_headers
)),
pin!(cancel.cancelled()),
)
.await
{
Either::Left((res, __not_yet_cancelled)) => res,
Either::Right((_cancelled, query)) => {
tracing::info!("cancelling query");
if let Err(err) = cancel_token.cancel_query(NoTls).await {
tracing::warn!(?err, "could not cancel query");
if needs_tx {
info!("starting transaction");
let query = TransactionBuilder {
isolation_level: parsed_headers.txn_isolation_level,
read_only: parsed_headers.txn_read_only.then_some(true),
deferrable: parsed_headers.txn_deferrable.then_some(true),
}
.format();
inner
.batch_execute(&query)
.await
.inspect_err(|_| {
// if we cannot start a transaction, we should return immediately
// and not return to the pool. connection is clearly broken
discard.discard();
})
.map_err(SqlOverHttpError::Postgres)?;
}
let json_output = json::value_to_string!(|value| match &self {
Payload::Single(query) => {
query_to_json(config, &cancel, inner, query, value, parsed_headers).await?;
}
Payload::Batch(batch) => {
let mut obj = value.object();
let mut results = obj.key("results").list();
for query in &batch.queries {
let value = results.entry();
query_to_json(config, &cancel, inner, query, value, parsed_headers).await?;
}
// wait for the query cancellation
match time::timeout(time::Duration::from_millis(100), query).await {
// query successed before it was cancelled.
Ok(Ok(status)) => Ok(status),
// query failed or was cancelled.
Ok(Err(error)) => {
let db_error = match &error {
SqlOverHttpError::ConnectCompute(
HttpConnError::PostgresConnectionError(e),
)
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
_ => None,
};
// if errored for some other reason, it might not be safe to return
if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) {
discard.discard();
}
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
}
Err(_timeout) => {
discard.discard();
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
}
}
results.finish();
obj.finish();
}
};
});
match batch_result {
// The query successfully completed.
Ok(status) => {
discard.check_idle(status);
let json_output = String::from_utf8(json_buf).expect("json should be valid utf8");
Ok(json_output)
}
// The query failed with an error
Err(e) => {
discard.discard();
Err(e)
}
if needs_tx {
inner
.batch_execute("COMMIT")
.await
.inspect_err(|_| {
// if we cannot commit - for now don't return connection to pool
// TODO: get a query status from the error
discard.discard();
})
.map_err(SqlOverHttpError::Postgres)?;
}
}
}
impl BatchQueryData {
async fn process(
self,
config: &'static HttpConfig,
cancel: CancellationToken,
client: &mut Client,
parsed_headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> {
info!("starting transaction");
let (inner, mut discard) = client.inner();
let cancel_token = inner.cancel_token();
let mut builder = inner.build_transaction();
if let Some(isolation_level) = parsed_headers.txn_isolation_level {
builder = builder.isolation_level(isolation_level);
}
if parsed_headers.txn_read_only {
builder = builder.read_only(true);
}
if parsed_headers.txn_deferrable {
builder = builder.deferrable(true);
}
let mut transaction = builder
.start()
.await
.inspect_err(|_| {
// if we cannot start a transaction, we should return immediately
// and not return to the pool. connection is clearly broken
discard.discard();
})
.map_err(SqlOverHttpError::Postgres)?;
let json_output = match query_batch_to_json(
config,
cancel.child_token(),
&mut transaction,
self,
parsed_headers,
)
.await
{
Ok(json_output) => {
info!("commit");
let status = transaction
.commit()
.await
.inspect_err(|_| {
// if we cannot commit - for now don't return connection to pool
// TODO: get a query status from the error
discard.discard();
})
.map_err(SqlOverHttpError::Postgres)?;
discard.check_idle(status);
json_output
}
Err(SqlOverHttpError::Cancelled(_)) => {
if let Err(err) = cancel_token.cancel_query(NoTls).await {
tracing::warn!(?err, "could not cancel query");
}
// TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
discard.discard();
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
}
Err(err) => {
info!("rollback");
let status = transaction
.rollback()
.await
.inspect_err(|_| {
// if we cannot rollback - for now don't return connection to pool
// TODO: get a query status from the error
discard.discard();
})
.map_err(SqlOverHttpError::Postgres)?;
discard.check_idle(status);
return Err(err);
}
};
Ok(json_output)
}
}
async fn query_batch(
config: &'static HttpConfig,
cancel: CancellationToken,
transaction: &mut Transaction<'_>,
queries: BatchQueryData,
parsed_headers: HttpHeaders,
results: &mut json::ListSer<'_>,
) -> Result<(), SqlOverHttpError> {
for stmt in queries.queries {
let query = pin!(query_to_json(
config,
transaction,
stmt,
results.entry(),
parsed_headers,
));
let cancelled = pin!(cancel.cancelled());
let res = select(query, cancelled).await;
match res {
// TODO: maybe we should check that the transaction bit is set here
Either::Left((Ok(_), _cancelled)) => {}
Either::Left((Err(e), _cancelled)) => {
return Err(e);
}
Either::Right((_cancelled, _)) => {
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
}
}
async fn cancel_query(client: &mut Client) {
let (inner, mut discard) = client.inner();
let cancel_token = inner.cancel_token();
if let Err(err) = cancel_token.cancel_query(NoTls).await {
tracing::warn!(?err, "could not cancel query");
// couldn't reach the server. let's just throw away this conn
discard.discard();
return;
}
Ok(())
// wait for the query cancellation
match time::timeout(time::Duration::from_millis(100), inner.wait_until_ready()).await {
// we managed to cancel the query.
Ok(Ok(_)) => {}
// query failed or was cancelled.
Ok(Err(error)) => {
let db_error = error.as_db_error();
// if errored for some other reason, it might not be safe to reuse the connection.
if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) {
discard.discard();
}
}
Err(_timeout) => {
discard.discard();
}
}
}
async fn query_batch_to_json(
async fn query_to_json(
config: &'static HttpConfig,
cancel: CancellationToken,
tx: &mut Transaction<'_>,
queries: BatchQueryData,
headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> {
let json_output = json::value_to_string!(|obj| json::value_as_object!(|obj| {
let results = obj.key("results");
json::value_as_list!(|results| {
query_batch(config, cancel, tx, queries, headers, results).await?;
});
}));
Ok(json_output)
}
async fn query_to_json<T: GenericClient>(
config: &'static HttpConfig,
client: &mut T,
data: QueryData,
cancel: &CancellationToken,
client: &mut postgres_client::Client,
data: &QueryData,
output: json::ValueSer<'_>,
parsed_headers: HttpHeaders,
) -> Result<ReadyForQueryStatus, SqlOverHttpError> {
let query_start = Instant::now();
let mut output = json::ObjectSer::new(output);
let mut row_stream = client
.query_raw_txt(&data.query, data.params)
let params = data.params.iter().map(Option::as_deref);
let mut row_stream = run_until_cancelled(client.query_raw_txt(&data.query, params), cancel)
.await
.ok_or(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))?
.map_err(SqlOverHttpError::Postgres)?;
let query_acknowledged = Instant::now();
let mut output = json::ObjectSer::new(output);
let mut json_fields = output.key("fields").list();
for c in row_stream.statement.columns() {
let json_field = json_fields.entry();
@@ -923,8 +811,13 @@ async fn query_to_json<T: GenericClient>(
// big.
let mut rows = 0;
let mut json_rows = output.key("rows").list();
while let Some(row) = row_stream.next().await {
let row = row.map_err(SqlOverHttpError::Postgres)?;
loop {
let row = run_until_cancelled(row_stream.try_next(), cancel)
.await
.ok_or(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))?
.map_err(SqlOverHttpError::Postgres)?;
let Some(row) = row else { break };
// we don't have a streaming response support yet so this is to prevent OOM
// from a malicious query (eg a cross join)
@@ -1012,12 +905,6 @@ impl Client {
}
impl Discard<'_> {
fn check_idle(&mut self, status: ReadyForQueryStatus) {
match self {
Discard::Remote(discard) => discard.check_idle(status),
Discard::Local(discard) => discard.check_idle(status),
}
}
fn discard(&mut self) {
match self {
Discard::Remote(discard) => discard.discard(),

View File

@@ -1,23 +1,50 @@
use std::pin::pin;
use std::{
pin::Pin,
task::{Context, Poll},
};
use futures::future::{Either, select};
use futures::FutureExt;
use tokio_util::sync::CancellationToken;
pub async fn run_until_cancelled<F: Future>(
pub fn run_until_cancelled<F: Future>(
f: F,
cancellation_token: &CancellationToken,
) -> Option<F::Output> {
run_until(f, cancellation_token.cancelled()).await.ok()
) -> impl Future<Output = Option<F::Output>> {
run_until(f, cancellation_token.cancelled()).map(|r| r.ok())
}
/// Runs the future `f` unless interrupted by future `condition`.
pub async fn run_until<F1: Future, F2: Future>(
pub fn run_until<F1: Future, F2: Future>(
f: F1,
condition: F2,
) -> Result<F1::Output, F2::Output> {
match select(pin!(f), pin!(condition)).await {
Either::Left((f1, _)) => Ok(f1),
Either::Right((f2, _)) => Err(f2),
) -> impl Future<Output = Result<F1::Output, F2::Output>> {
RunUntil { a: f, b: condition }
}
pin_project_lite::pin_project! {
struct RunUntil<A, B> {
#[pin] a: A,
#[pin] b: B,
}
}
impl<A, B> Future for RunUntil<A, B>
where
A: Future,
B: Future,
{
type Output = Result<A::Output, B::Output>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
if let Poll::Ready(a) = this.a.poll(cx) {
return Poll::Ready(Ok(a));
}
if let Poll::Ready(b) = this.b.poll(cx) {
return Poll::Ready(Err(b));
}
Poll::Pending
}
}

View File

@@ -3899,6 +3899,41 @@ class NeonProxy(PgProtocol):
assert response.status_code == expected_code, f"response: {response.json()}"
return response.json()
def http_multiquery(self, *queries, **kwargs):
# TODO maybe use default values if not provided
user = quote(kwargs["user"])
password = quote(kwargs["password"])
expected_code = kwargs.get("expected_code")
timeout = kwargs.get("timeout")
json_queries = []
for query in queries:
if type(query) is str:
json_queries.append({"query": query})
else:
[query, params] = query
json_queries.append({"query": query, "params": params})
queries_str = [j["query"] for j in json_queries]
log.info(f"Executing http queries: {queries_str}")
connstr = f"postgresql://{user}:{password}@{self.domain}:{self.proxy_port}/postgres"
response = requests.post(
f"https://{self.domain}:{self.external_http_port}/sql",
data=json.dumps({"queries": json_queries}),
headers={
"Content-Type": "application/sql",
"Neon-Connection-String": connstr,
"Neon-Pool-Opt-In": "true",
},
verify=str(self.test_output_dir / "proxy.crt"),
timeout=timeout,
)
if expected_code is not None:
assert response.status_code == expected_code, f"response: {response.json()}"
return response.json()
async def http2_query(self, query, args, **kwargs):
# TODO maybe use default values if not provided
user = kwargs["user"]

View File

@@ -17,9 +17,6 @@ if TYPE_CHECKING:
from typing import Any
GET_CONNECTION_PID_QUERY = "SELECT pid FROM pg_stat_activity WHERE state = 'active'"
@pytest.mark.asyncio
async def test_http_pool_begin_1(static_proxy: NeonProxy):
static_proxy.safe_psql("create user http_auth with password 'http' superuser")
@@ -479,7 +476,7 @@ def test_sql_over_http_pool(static_proxy: NeonProxy):
def get_pid(status: int, pw: str, user="http_auth") -> Any:
return static_proxy.http_query(
GET_CONNECTION_PID_QUERY,
"SELECT pg_backend_pid() as pid",
[],
user=user,
password=pw,
@@ -513,6 +510,35 @@ def test_sql_over_http_pool(static_proxy: NeonProxy):
assert "password authentication failed for user" in res["message"]
def test_sql_over_http_pool_settings(static_proxy: NeonProxy):
static_proxy.safe_psql("create user http_auth with password 'http' superuser")
def multiquery(*queries) -> Any:
results = static_proxy.http_multiquery(
*queries,
user="http_auth",
password="http",
expected_code=200,
)
return [result["rows"] for result in results["results"]]
[[intervalstyle]] = static_proxy.safe_psql("SHOW IntervalStyle")
assert intervalstyle == "postgres", "'postgres' is the default IntervalStyle in postgres"
result = multiquery("select '0 seconds'::interval as interval")
assert result[0][0]["interval"] == "00:00:00", "interval is expected in postgres format"
result = multiquery(
"SET IntervalStyle = 'iso_8601'",
"select '0 seconds'::interval as interval",
)
assert result[1][0]["interval"] == "PT0S", "interval is expected in ISO-8601 format"
result = multiquery("select '0 seconds'::interval as interval")
assert result[0][0]["interval"] == "00:00:00", "interval is expected in postgres format"
def test_sql_over_http_urlencoding(static_proxy: NeonProxy):
static_proxy.safe_psql("create user \"http+auth$$\" with password '%+$^&*@!' superuser")
@@ -544,23 +570,37 @@ def test_http_pool_begin(static_proxy: NeonProxy):
query(200, "SELECT 1;") # Query that should succeed regardless of the transaction
def test_sql_over_http_pool_idle(static_proxy: NeonProxy):
def test_sql_over_http_pool_tx_reuse(static_proxy: NeonProxy):
static_proxy.safe_psql("create user http_auth2 with password 'http' superuser")
def query(status: int, query: str) -> Any:
def query(status: int, query: str, *args) -> Any:
return static_proxy.http_query(
query,
[],
args,
user="http_auth2",
password="http",
expected_code=status,
)
pid1 = query(200, GET_CONNECTION_PID_QUERY)["rows"][0]["pid"]
def query_pid_txid() -> Any:
result = query(
200,
"SELECT pg_backend_pid() as pid, pg_current_xact_id() as txid",
)
return result["rows"][0]
res0 = query_pid_txid()
time.sleep(0.02)
query(200, "BEGIN")
pid2 = query(200, GET_CONNECTION_PID_QUERY)["rows"][0]["pid"]
assert pid1 != pid2
res1 = query_pid_txid()
res2 = query_pid_txid()
assert res0["pid"] == res1["pid"], "connection should be reused"
assert res0["pid"] == res2["pid"], "connection should be reused"
assert res1["txid"] != res2["txid"], "txid should be different"
@pytest.mark.timeout(60)