Compare commits

...

14 Commits

Author SHA1 Message Date
Conrad Ludgate
e965bd96bb optimise some future sizes 2025-07-23 08:58:09 +01:00
Conrad Ludgate
14daaec98e more compact code and more compact futures 2025-07-23 08:58:09 +01:00
Conrad Ludgate
286ac97a9c remove typesafe transaction type as we already ensure rollback is performed 2025-07-23 08:58:09 +01:00
Conrad Ludgate
20355cb5f0 fix rest.rs 2025-07-23 07:04:36 +01:00
Conrad Ludgate
634dbd29b6 python lints 2025-07-23 07:04:36 +01:00
Conrad Ludgate
a235b241d5 ruff format 2025-07-23 07:04:36 +01:00
Conrad Ludgate
539652fa4e rollback safety 2025-07-23 07:04:36 +01:00
Conrad Ludgate
11294ca322 rename Send to LocalProxyClient, etc 2025-07-23 07:04:36 +01:00
Conrad Ludgate
84020c1328 fix python lints 2025-07-23 07:04:36 +01:00
Conrad Ludgate
0cc7415691 remove explicit discard_all for local_proxy 2025-07-23 07:04:36 +01:00
Conrad Ludgate
38df46b381 fix session state by resetting it 2025-07-23 07:04:36 +01:00
Conrad Ludgate
cdc73ad051 add regression test 2025-07-23 07:04:36 +01:00
Tristan Partin
fc242afcc2 PG ignore PageserverFeedback from unknown shards (#12671)
## Problem
When testing tenant splits, I found that PG can get backpressure
throttled indefinitely if the split is aborted afterwards. It turns out
that each PageServer activates new shard separately even before the
split is committed and they may start sending PageserverFeedback to PG
directly. As a result, if the split is aborted, no one resets the
pageserver feedback in PG, and thus PG will be backpressure throttled
forever unless it's restarted manually.

## Summary of changes
This PR fixes this problem by having
`walprop_pg_process_safekeeper_feedback` simply ignore all pageserver
feedback from unknown shards. The source of truth here is defined by the
shard map, which is guaranteed to be reloaded only after the split is
committed.

Co-authored-by: Chen Luo <chen.luo@databricks.com>
2025-07-22 21:41:56 +00:00
Suhas Thalanki
e275221aef add hadron-specific metrics (#12686) 2025-07-22 21:17:45 +00:00
20 changed files with 467 additions and 537 deletions

View File

@@ -0,0 +1,60 @@
use metrics::{
IntCounter, IntGaugeVec, core::Collector, proto::MetricFamily, register_int_counter,
register_int_gauge_vec,
};
use once_cell::sync::Lazy;
// Counter keeping track of the number of PageStream request errors reported by Postgres.
// An error is registered every time Postgres calls compute_ctl's /refresh_configuration API.
// Postgres will invoke this API if it detected trouble with PageStream requests (get_page@lsn,
// get_base_backup, etc.) it sends to any pageserver. An increase in this counter value typically
// indicates Postgres downtime, as PageStream requests are critical for Postgres to function.
pub static POSTGRES_PAGESTREAM_REQUEST_ERRORS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"pg_cctl_pagestream_request_errors_total",
"Number of PageStream request errors reported by the postgres process"
)
.expect("failed to define a metric")
});
// Counter keeping track of the number of compute configuration errors due to Postgres statement
// timeouts. An error is registered every time `ComputeNode::reconfigure()` fails due to Postgres
// error code 57014 (query cancelled). This statement timeout typically occurs when postgres is
// stuck in a problematic retry loop when the PS is reject its connection requests (usually due
// to PG pointing at the wrong PS). We should investigate the root cause when this counter value
// increases by checking PG and PS logs.
pub static COMPUTE_CONFIGURE_STATEMENT_TIMEOUT_ERRORS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"pg_cctl_configure_statement_timeout_errors_total",
"Number of compute configuration errors due to Postgres statement timeouts."
)
.expect("failed to define a metric")
});
pub static COMPUTE_ATTACHED: Lazy<IntGaugeVec> = Lazy::new(|| {
register_int_gauge_vec!(
"pg_cctl_attached",
"Compute node attached status (1 if attached)",
&[
"pg_compute_id",
"pg_instance_id",
"tenant_id",
"timeline_id"
]
)
.expect("failed to define a metric")
});
pub fn collect() -> Vec<MetricFamily> {
let mut metrics = Vec::new();
metrics.extend(POSTGRES_PAGESTREAM_REQUEST_ERRORS.collect());
metrics.extend(COMPUTE_CONFIGURE_STATEMENT_TIMEOUT_ERRORS.collect());
metrics.extend(COMPUTE_ATTACHED.collect());
metrics
}
pub fn initialize_metrics() {
Lazy::force(&POSTGRES_PAGESTREAM_REQUEST_ERRORS);
Lazy::force(&COMPUTE_CONFIGURE_STATEMENT_TIMEOUT_ERRORS);
Lazy::force(&COMPUTE_ATTACHED);
}

View File

@@ -16,6 +16,7 @@ pub mod compute_prewarm;
pub mod compute_promote;
pub mod disk_quota;
pub mod extension_server;
pub mod hadron_metrics;
pub mod installed_extensions;
pub mod local_proxy;
pub mod lsn_lease;

View File

@@ -600,6 +600,7 @@ impl ParameterStatusBody {
}
}
#[derive(Clone, Copy)]
pub struct ReadyForQueryBody {
status: u8,
}

View File

@@ -18,10 +18,7 @@ use crate::config::{Host, SslMode};
use crate::query::RowStream;
use crate::simple_query::SimpleQueryStream;
use crate::types::{Oid, Type};
use crate::{
CancelToken, Error, ReadyForQueryStatus, SimpleQueryMessage, Transaction, TransactionBuilder,
query, simple_query,
};
use crate::{CancelToken, Error, ReadyForQueryStatus, SimpleQueryMessage, query, simple_query};
pub struct Responses {
/// new messages from conn
@@ -32,6 +29,9 @@ pub struct Responses {
waiting: usize,
/// number of ReadyForQuery messages received.
received: usize,
/// The last query status we received.
last_status: ReadyForQueryStatus,
}
impl Responses {
@@ -42,7 +42,8 @@ impl Responses {
let received = self.received;
// increase the query head if this is the last message.
if let Message::ReadyForQuery(_) = message {
if let Message::ReadyForQuery(ref status) = message {
self.last_status = (*status).into();
self.received += 1;
}
@@ -71,6 +72,15 @@ impl Responses {
pub async fn next(&mut self) -> Result<Message, Error> {
future::poll_fn(|cx| self.poll_next(cx)).await
}
pub async fn wait_until_ready(&mut self) -> Result<ReadyForQueryStatus, Error> {
while self.received < self.waiting {
if let Message::ReadyForQuery(status) = self.next().await? {
return Ok(status.into());
}
}
Ok(self.last_status)
}
}
/// A cache of type info and prepared statements for fetching type info
@@ -95,13 +105,6 @@ impl InnerClient {
Ok(PartialQuery(Some(self)))
}
// pub fn send_with_sync<F>(&mut self, f: F) -> Result<&mut Responses, Error>
// where
// F: FnOnce(&mut BytesMut) -> Result<(), Error>,
// {
// self.start()?.send_with_sync(f)
// }
pub fn send_simple_query(&mut self, query: &str) -> Result<&mut Responses, Error> {
self.responses.waiting += 1;
@@ -200,6 +203,8 @@ impl Client {
cur: BackendMessages::empty(),
waiting: 0,
received: 0,
// new connections are always idle.
last_status: ReadyForQueryStatus::Idle,
},
buffer: Default::default(),
},
@@ -233,6 +238,11 @@ impl Client {
rx
}
/// Wait until this connection has no more active queries.
pub async fn wait_until_ready(&mut self) -> Result<ReadyForQueryStatus, Error> {
self.inner_mut().responses.wait_until_ready().await
}
/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
/// to save a roundtrip
pub async fn query_raw_txt<S, I>(
@@ -292,52 +302,32 @@ impl Client {
simple_query::batch_execute(self.inner_mut(), query).await
}
pub async fn discard_all(&mut self) -> Result<ReadyForQueryStatus, Error> {
self.batch_execute("discard all").await
}
/// Begins a new database transaction.
/// Similar to `discard_all`, but it does not clear any query plans
///
/// The transaction will roll back by default - use the `commit` method to commit it.
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
struct RollbackIfNotDone<'me> {
client: &'me mut Client,
done: bool,
}
/// This runs in the background, so it can be executed without `await`ing.
pub fn reset_session_background(&mut self) -> Result<(), Error> {
// "CLOSE ALL": closes any cursors
// "SET SESSION AUTHORIZATION DEFAULT": resets the current_user back to the session_user
// "RESET ALL": resets any GUCs back to their session defaults.
// "DEALLOCATE ALL": deallocates any prepared statements
// "UNLISTEN *": stops listening on all channels
// "SELECT pg_advisory_unlock_all();": unlocks all advisory locks
// "DISCARD TEMP;": drops all temporary tables
// "DISCARD SEQUENCES;": deallocates all cached sequence state
impl Drop for RollbackIfNotDone<'_> {
fn drop(&mut self) {
if self.done {
return;
}
let _responses = self.inner_mut().send_simple_query(
"ROLLBACK;
CLOSE ALL;
SET SESSION AUTHORIZATION DEFAULT;
RESET ALL;
DEALLOCATE ALL;
UNLISTEN *;
SELECT pg_advisory_unlock_all();
DISCARD TEMP;
DISCARD SEQUENCES;",
)?;
let _ = self.client.inner.send_simple_query("ROLLBACK");
}
}
// This is done, as `Future` created by this method can be dropped after
// `RequestMessages` is synchronously send to the `Connection` by
// `batch_execute()`, but before `Responses` is asynchronously polled to
// completion. In that case `Transaction` won't be created and thus
// won't be rolled back.
{
let mut cleaner = RollbackIfNotDone {
client: self,
done: false,
};
cleaner.client.batch_execute("BEGIN").await?;
cleaner.done = true;
}
Ok(Transaction::new(self))
}
/// Returns a builder for a transaction with custom settings.
///
/// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other
/// attributes.
pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
TransactionBuilder::new(self)
Ok(())
}
/// Constructs a cancellation token that can later be used to request cancellation of a query running on the

View File

@@ -1,58 +0,0 @@
#![allow(async_fn_in_trait)]
use crate::query::RowStream;
use crate::{Client, Error, Transaction};
mod private {
pub trait Sealed {}
}
/// A trait allowing abstraction over connections and transactions.
///
/// This trait is "sealed", and cannot be implemented outside of this crate.
pub trait GenericClient: private::Sealed {
/// Like `Client::query_raw_txt`.
async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
I::IntoIter: ExactSizeIterator + Sync + Send;
}
impl private::Sealed for Client {}
impl GenericClient for Client {
async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
I::IntoIter: ExactSizeIterator + Sync + Send,
{
self.query_raw_txt(statement, params).await
}
}
impl private::Sealed for Transaction<'_> {}
impl GenericClient for Transaction<'_> {
async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
I::IntoIter: ExactSizeIterator + Sync + Send,
{
self.query_raw_txt(statement, params).await
}
}

View File

@@ -9,13 +9,11 @@ pub use crate::config::Config;
pub use crate::connect_raw::RawConnection;
pub use crate::connection::Connection;
pub use crate::error::Error;
pub use crate::generic_client::GenericClient;
pub use crate::query::RowStream;
pub use crate::row::{Row, SimpleQueryRow};
pub use crate::simple_query::SimpleQueryStream;
pub use crate::statement::{Column, Statement};
pub use crate::tls::NoTls;
pub use crate::transaction::Transaction;
pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder};
/// After executing a query, the connection will be in one of these states
@@ -55,7 +53,6 @@ mod connect_socket;
mod connect_tls;
mod connection;
pub mod error;
mod generic_client;
pub mod maybe_tls_stream;
mod prepare;
mod query;
@@ -63,7 +60,6 @@ pub mod row;
mod simple_query;
mod statement;
pub mod tls;
mod transaction;
mod transaction_builder;
pub mod types;

View File

@@ -1,73 +0,0 @@
use crate::query::RowStream;
use crate::{CancelToken, Client, Error, ReadyForQueryStatus};
/// A representation of a PostgreSQL database transaction.
///
/// Transactions will implicitly roll back when dropped. Use the `commit` method to commit the changes made in the
/// transaction. Transactions can be nested, with inner transactions implemented via safepoints.
pub struct Transaction<'a> {
client: &'a mut Client,
done: bool,
}
impl Drop for Transaction<'_> {
fn drop(&mut self) {
if self.done {
return;
}
let _ = self.client.inner_mut().send_simple_query("ROLLBACK");
}
}
impl<'a> Transaction<'a> {
pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
Transaction {
client,
done: false,
}
}
/// Consumes the transaction, committing all changes made within it.
pub async fn commit(mut self) -> Result<ReadyForQueryStatus, Error> {
self.done = true;
self.client.batch_execute("COMMIT").await
}
/// Rolls the transaction back, discarding all changes made within it.
///
/// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
pub async fn rollback(mut self) -> Result<ReadyForQueryStatus, Error> {
self.done = true;
self.client.batch_execute("ROLLBACK").await
}
/// Like `Client::query_raw_txt`.
pub async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
I::IntoIter: ExactSizeIterator,
{
self.client.query_raw_txt(statement, params).await
}
/// Like `Client::cancel_token`.
pub fn cancel_token(&self) -> CancelToken {
self.client.cancel_token()
}
/// Returns a reference to the underlying `Client`.
pub fn client(&self) -> &Client {
self.client
}
/// Returns a reference to the underlying `Client`.
pub fn client_mut(&mut self) -> &mut Client {
self.client
}
}

View File

@@ -1,5 +1,3 @@
use crate::{Client, Error, Transaction};
/// The isolation level of a database transaction.
#[derive(Debug, Copy, Clone)]
#[non_exhaustive]
@@ -20,49 +18,17 @@ pub enum IsolationLevel {
}
/// A builder for database transactions.
pub struct TransactionBuilder<'a> {
client: &'a mut Client,
isolation_level: Option<IsolationLevel>,
read_only: Option<bool>,
deferrable: Option<bool>,
pub struct TransactionBuilder {
pub isolation_level: Option<IsolationLevel>,
pub read_only: Option<bool>,
pub deferrable: Option<bool>,
}
impl<'a> TransactionBuilder<'a> {
pub(crate) fn new(client: &'a mut Client) -> TransactionBuilder<'a> {
TransactionBuilder {
client,
isolation_level: None,
read_only: None,
deferrable: None,
}
}
/// Sets the isolation level of the transaction.
pub fn isolation_level(mut self, isolation_level: IsolationLevel) -> Self {
self.isolation_level = Some(isolation_level);
self
}
/// Sets the access mode of the transaction.
pub fn read_only(mut self, read_only: bool) -> Self {
self.read_only = Some(read_only);
self
}
/// Sets the deferrability of the transaction.
///
/// If the transaction is also serializable and read only, creation of the transaction may block, but when it
/// completes the transaction is able to run with less overhead and a guarantee that it will not be aborted due to
/// serialization failure.
pub fn deferrable(mut self, deferrable: bool) -> Self {
self.deferrable = Some(deferrable);
self
}
impl TransactionBuilder {
/// Begins the transaction.
///
/// The transaction will roll back by default - use the `commit` method to commit it.
pub async fn start(self) -> Result<Transaction<'a>, Error> {
pub fn format(self) -> String {
let mut query = "START TRANSACTION".to_string();
let mut first = true;
@@ -106,8 +72,6 @@ impl<'a> TransactionBuilder<'a> {
query.push_str(s);
}
self.client.batch_execute(&query).await?;
Ok(Transaction::new(self.client))
query
}
}

View File

@@ -178,6 +178,8 @@ static PageServer page_servers[MAX_SHARDS];
static bool pageserver_flush(shardno_t shard_no);
static void pageserver_disconnect(shardno_t shard_no);
static void pageserver_disconnect_shard(shardno_t shard_no);
// HADRON
shardno_t get_num_shards(void);
static bool
PagestoreShmemIsValid(void)
@@ -286,6 +288,22 @@ AssignPageserverConnstring(const char *newval, void *extra)
}
}
/* BEGIN_HADRON */
/**
* Return the total number of shards seen in the shard map.
*/
shardno_t get_num_shards(void)
{
const ShardMap *shard_map;
Assert(pagestore_shared);
shard_map = &pagestore_shared->shard_map;
Assert(shard_map != NULL);
return shard_map->num_shards;
}
/* END_HADRON */
/*
* Get the current number of shards, and/or the connection string for a
* particular shard from the shard map in shared memory.

View File

@@ -110,6 +110,9 @@ static void rm_safekeeper_event_set(Safekeeper *to_remove, bool is_sk);
static void CheckGracefulShutdown(WalProposer *wp);
// HADRON
shardno_t get_num_shards(void);
static void
init_walprop_config(bool syncSafekeepers)
{
@@ -646,18 +649,19 @@ walprop_pg_get_shmem_state(WalProposer *wp)
* Record new ps_feedback in the array with shards and update min_feedback.
*/
static PageserverFeedback
record_pageserver_feedback(PageserverFeedback *ps_feedback)
record_pageserver_feedback(PageserverFeedback *ps_feedback, shardno_t num_shards)
{
PageserverFeedback min_feedback;
Assert(ps_feedback->present);
Assert(ps_feedback->shard_number < MAX_SHARDS);
Assert(ps_feedback->shard_number < num_shards);
SpinLockAcquire(&walprop_shared->mutex);
/* Update the number of shards */
if (ps_feedback->shard_number + 1 > walprop_shared->num_shards)
walprop_shared->num_shards = ps_feedback->shard_number + 1;
// Hadron: Update the num_shards from the source-of-truth (shard map) lazily when we receive
// a new pageserver feedback.
walprop_shared->num_shards = Max(walprop_shared->num_shards, num_shards);
/* Update the feedback */
memcpy(&walprop_shared->shard_ps_feedback[ps_feedback->shard_number], ps_feedback, sizeof(PageserverFeedback));
@@ -2023,19 +2027,43 @@ walprop_pg_process_safekeeper_feedback(WalProposer *wp, Safekeeper *sk)
if (wp->config->syncSafekeepers)
return;
/* handle fresh ps_feedback */
if (sk->appendResponse.ps_feedback.present)
{
PageserverFeedback min_feedback = record_pageserver_feedback(&sk->appendResponse.ps_feedback);
shardno_t num_shards = get_num_shards();
/* Only one main shard sends non-zero currentClusterSize */
if (sk->appendResponse.ps_feedback.currentClusterSize > 0)
SetNeonCurrentClusterSize(sk->appendResponse.ps_feedback.currentClusterSize);
if (min_feedback.disk_consistent_lsn != standby_apply_lsn)
// During shard split, we receive ps_feedback from child shards before
// the split commits and our shard map GUC has been updated. We must
// filter out such feedback here because record_pageserver_feedback()
// doesn't do it.
//
// NB: what we would actually want to happen is that we only receive
// ps_feedback from the parent shards when the split is committed, then
// apply the split to our set of tracked feedback and from here on only
// receive ps_feedback from child shards. This filter condition doesn't
// do that: if we split from N parent to 2N child shards, the first N
// child shards' feedback messages will pass this condition, even before
// the split is committed. That's a bit sloppy, but OK for now.
if (sk->appendResponse.ps_feedback.shard_number < num_shards)
{
standby_apply_lsn = min_feedback.disk_consistent_lsn;
needToAdvanceSlot = true;
PageserverFeedback min_feedback = record_pageserver_feedback(&sk->appendResponse.ps_feedback, num_shards);
/* Only one main shard sends non-zero currentClusterSize */
if (sk->appendResponse.ps_feedback.currentClusterSize > 0)
SetNeonCurrentClusterSize(sk->appendResponse.ps_feedback.currentClusterSize);
if (min_feedback.disk_consistent_lsn != standby_apply_lsn)
{
standby_apply_lsn = min_feedback.disk_consistent_lsn;
needToAdvanceSlot = true;
}
}
else
{
// HADRON
elog(DEBUG2, "Ignoring pageserver feedback for unknown shard %d (current shard number %d)",
sk->appendResponse.ps_feedback.shard_number, num_shards);
}
}

View File

@@ -18,7 +18,7 @@ use tracing::{debug, info};
use super::AsyncRW;
use super::conn_pool::poll_client;
use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool};
use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client};
use super::http_conn_pool::{self, HttpConnPool, LocalProxyClient, poll_http2_client};
use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool};
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo};
@@ -40,7 +40,8 @@ use crate::rate_limiter::EndpointRateLimiter;
use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
pub(crate) http_conn_pool:
Arc<GlobalConnPool<LocalProxyClient, HttpConnPool<LocalProxyClient>>>,
pub(crate) local_pool: Arc<LocalConnPool<postgres_client::Client>>,
pub(crate) pool:
Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
@@ -210,7 +211,7 @@ impl PoolingBackend {
&self,
ctx: &RequestContext,
conn_info: ConnInfo,
) -> Result<http_conn_pool::Client<Send>, HttpConnError> {
) -> Result<http_conn_pool::Client<LocalProxyClient>, HttpConnError> {
debug!("pool: looking for an existing connection");
if let Ok(Some(client)) = self.http_conn_pool.get(ctx, &conn_info) {
return Ok(client);
@@ -568,7 +569,7 @@ impl ConnectMechanism for TokioMechanism {
}
struct HyperMechanism {
pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
pool: Arc<GlobalConnPool<LocalProxyClient, HttpConnPool<LocalProxyClient>>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
@@ -578,7 +579,7 @@ struct HyperMechanism {
#[async_trait]
impl ConnectMechanism for HyperMechanism {
type Connection = http_conn_pool::Client<Send>;
type Connection = http_conn_pool::Client<LocalProxyClient>;
type ConnectError = HttpConnError;
type Error = HttpConnError;
@@ -632,7 +633,13 @@ async fn connect_http2(
port: u16,
timeout: Duration,
tls: Option<&Arc<rustls::ClientConfig>>,
) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> {
) -> Result<
(
http_conn_pool::LocalProxyClient,
http_conn_pool::LocalProxyConnection,
),
LocalProxyConnError,
> {
let addrs = match host_addr {
Some(addr) => vec![SocketAddr::new(addr, port)],
None => lookup_host((host, port))

View File

@@ -190,6 +190,9 @@ mod tests {
fn get_process_id(&self) -> i32 {
0
}
fn reset(&mut self) -> Result<(), postgres_client::Error> {
Ok(())
}
}
fn create_inner() -> ClientInnerCommon<MockClient> {

View File

@@ -7,10 +7,9 @@ use std::time::Duration;
use clashmap::ClashMap;
use parking_lot::RwLock;
use postgres_client::ReadyForQueryStatus;
use rand::Rng;
use smol_str::ToSmolStr;
use tracing::{Span, debug, info};
use tracing::{Span, debug, info, warn};
use super::backend::HttpConnError;
use super::conn_pool::ClientDataRemote;
@@ -188,7 +187,7 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
self.pools.get_mut(&db_user)
}
pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInnerCommon<C>) {
pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, mut client: ClientInnerCommon<C>) {
let conn_id = client.get_conn_id();
let (max_conn, conn_count, pool_name) = {
let pool = pool.read();
@@ -201,12 +200,17 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
};
if client.inner.is_closed() {
info!(%conn_id, "{}: throwing away connection '{conn_info}' because connection is closed", pool_name);
info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because connection is closed");
return;
}
if let Err(error) = client.inner.reset() {
warn!(?error, %conn_id, "{pool_name}: throwing away connection '{conn_info}' because connection could not be reset");
return;
}
if conn_count >= max_conn {
info!(%conn_id, "{}: throwing away connection '{conn_info}' because pool is full", pool_name);
info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because pool is full");
return;
}
@@ -691,6 +695,7 @@ impl<C: ClientInnerExt> Deref for Client<C> {
pub(crate) trait ClientInnerExt: Sync + Send + 'static {
fn is_closed(&self) -> bool;
fn get_process_id(&self) -> i32;
fn reset(&mut self) -> Result<(), postgres_client::Error>;
}
impl ClientInnerExt for postgres_client::Client {
@@ -701,15 +706,13 @@ impl ClientInnerExt for postgres_client::Client {
fn get_process_id(&self) -> i32 {
self.get_process_id()
}
fn reset(&mut self) -> Result<(), postgres_client::Error> {
self.reset_session_background()
}
}
impl<C: ClientInnerExt> Discard<'_, C> {
pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
let conn_info = &self.conn_info;
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is not idle");
}
}
pub(crate) fn discard(&mut self) {
let conn_info = &self.conn_info;
if std::mem::take(self.pool).strong_count() > 0 {

View File

@@ -23,8 +23,8 @@ use crate::protocol2::ConnectionInfoExtra;
use crate::types::EndpointCacheKey;
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
pub(crate) type Send = http2::SendRequest<BoxBody<Bytes, hyper::Error>>;
pub(crate) type Connect =
pub(crate) type LocalProxyClient = http2::SendRequest<BoxBody<Bytes, hyper::Error>>;
pub(crate) type LocalProxyConnection =
http2::Connection<TokioIo<AsyncRW>, BoxBody<Bytes, hyper::Error>, TokioExecutor>;
#[derive(Clone)]
@@ -189,14 +189,14 @@ impl<C: ClientInnerExt + Clone> GlobalConnPool<C, HttpConnPool<C>> {
}
pub(crate) fn poll_http2_client(
global_pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
global_pool: Arc<GlobalConnPool<LocalProxyClient, HttpConnPool<LocalProxyClient>>>,
ctx: &RequestContext,
conn_info: &ConnInfo,
client: Send,
connection: Connect,
client: LocalProxyClient,
connection: LocalProxyConnection,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> Client<Send> {
) -> Client<LocalProxyClient> {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let session_id = ctx.session_id();
@@ -285,7 +285,7 @@ impl<C: ClientInnerExt + Clone> Client<C> {
}
}
impl ClientInnerExt for Send {
impl ClientInnerExt for LocalProxyClient {
fn is_closed(&self) -> bool {
self.is_closed()
}
@@ -294,4 +294,10 @@ impl ClientInnerExt for Send {
// ideally throw something meaningful
-1
}
fn reset(&mut self) -> Result<(), postgres_client::Error> {
// We use HTTP/2.0 to talk to local proxy. HTTP is stateless,
// so there's nothing to reset.
Ok(())
}
}

View File

@@ -269,11 +269,6 @@ impl ClientInnerCommon<postgres_client::Client> {
local_data.jti += 1;
let token = resign_jwt(&local_data.key, payload, local_data.jti)?;
self.inner
.discard_all()
.await
.map_err(SqlOverHttpError::InternalPostgres)?;
// initiates the auth session
// this is safe from query injections as the jwt format free of any escape characters.
let query = format!("select auth.jwt_session_init('{token}')");

View File

@@ -46,7 +46,7 @@ use super::backend::{HttpConnError, LocalProxyConnError, PoolingBackend};
use super::conn_pool::AuthData;
use super::conn_pool_lib::ConnInfo;
use super::error::{ConnInfoError, Credentials, HttpCodeError, ReadPayloadError};
use super::http_conn_pool::{self, Send};
use super::http_conn_pool::{self, LocalProxyClient};
use super::http_util::{
ALLOW_POOL, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_ISOLATION_LEVEL, TXN_READ_ONLY,
get_conn_info, json_response, uuid_to_header_value,
@@ -145,7 +145,7 @@ impl DbSchemaCache {
endpoint_id: &EndpointCacheKey,
auth_header: &HeaderValue,
connection_string: &str,
client: &mut http_conn_pool::Client<Send>,
client: &mut http_conn_pool::Client<LocalProxyClient>,
ctx: &RequestContext,
config: &'static ProxyConfig,
) -> Result<Arc<(ApiConfig, DbSchemaOwned)>, RestError> {
@@ -190,7 +190,7 @@ impl DbSchemaCache {
&self,
auth_header: &HeaderValue,
connection_string: &str,
client: &mut http_conn_pool::Client<Send>,
client: &mut http_conn_pool::Client<LocalProxyClient>,
ctx: &RequestContext,
config: &'static ProxyConfig,
) -> Result<(ApiConfig, DbSchemaOwned), RestError> {
@@ -430,7 +430,7 @@ struct BatchQueryData<'a> {
}
async fn make_local_proxy_request<S: DeserializeOwned>(
client: &mut http_conn_pool::Client<Send>,
client: &mut http_conn_pool::Client<LocalProxyClient>,
headers: impl IntoIterator<Item = (&HeaderName, HeaderValue)>,
body: QueryData<'_>,
max_len: usize,
@@ -461,7 +461,7 @@ async fn make_local_proxy_request<S: DeserializeOwned>(
}
async fn make_raw_local_proxy_request(
client: &mut http_conn_pool::Client<Send>,
client: &mut http_conn_pool::Client<LocalProxyClient>,
headers: impl IntoIterator<Item = (&HeaderName, HeaderValue)>,
body: String,
) -> Result<Response<Incoming>, RestError> {

View File

@@ -1,9 +1,8 @@
use std::pin::pin;
use std::sync::Arc;
use bytes::Bytes;
use futures::future::{Either, select, try_join};
use futures::{StreamExt, TryFutureExt};
use futures::future::try_join;
use futures::{TryFutureExt, TryStreamExt};
use http::Method;
use http::header::AUTHORIZATION;
use http_body_util::combinators::BoxBody;
@@ -14,7 +13,7 @@ use hyper::http::{HeaderName, HeaderValue};
use hyper::{Request, Response, StatusCode, header};
use indexmap::IndexMap;
use postgres_client::error::{DbError, ErrorPosition, SqlState};
use postgres_client::{GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, Transaction};
use postgres_client::{IsolationLevel, NoTls, ReadyForQueryStatus, TransactionBuilder};
use serde_json::Value;
use serde_json::value::RawValue;
use tokio::time::{self, Instant};
@@ -495,7 +494,7 @@ async fn handle_db_inner(
.http_conn_content_length_bytes
.observe(HttpDirection::Request, body.len() as f64);
debug!(length = body.len(), "request payload read");
debug!(length = body.len(), "request payload read ");
let payload: Payload = serde_json::from_slice(&body)?;
Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
}
@@ -530,13 +529,13 @@ async fn handle_db_inner(
let (cli_inner, _dsc) = client.client_inner();
cli_inner.set_jwt_session(&payload).await?;
}
Client::Local(client)
Box::new(Client::Local(client))
}
_ => {
let client = backend
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
.await?;
Client::Remote(client)
Box::new(Client::Remote(client))
}
};
@@ -550,10 +549,7 @@ async fn handle_db_inner(
let (payload, mut client) = match run_until_cancelled(
// Run both operations in parallel
try_join(
pin!(fetch_and_process_request),
pin!(authenticate_and_connect),
),
try_join(fetch_and_process_request, authenticate_and_connect),
&cancel,
)
.await
@@ -562,37 +558,38 @@ async fn handle_db_inner(
None => return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect)),
};
// Now execute the query and return the result.
let json_output = match payload
.process(&config.http_config, cancel, &mut client, parsed_headers)
.await
{
Ok(json_output) => json_output,
Err(error) => {
if let SqlOverHttpError::Cancelled(_) = error {
cancel_query(&mut client).await;
}
return Err(error);
}
};
let mut response = Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "application/json");
// Now execute the query and return the result.
let json_output = match payload {
Payload::Single(stmt) => {
stmt.process(&config.http_config, cancel, &mut client, parsed_headers)
.await?
if let Payload::Batch(_) = payload {
if parsed_headers.txn_read_only {
response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE);
}
Payload::Batch(statements) => {
if parsed_headers.txn_read_only {
response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE);
}
if parsed_headers.txn_deferrable {
response = response.header(TXN_DEFERRABLE.clone(), &HEADER_VALUE_TRUE);
}
if let Some(txn_isolation_level) = parsed_headers
.txn_isolation_level
.and_then(map_isolation_level_to_headers)
{
response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
}
statements
.process(&config.http_config, cancel, &mut client, parsed_headers)
.await?
if parsed_headers.txn_deferrable {
response = response.header(TXN_DEFERRABLE.clone(), &HEADER_VALUE_TRUE);
}
};
let metrics = client.metrics(ctx);
if let Some(txn_isolation_level) = parsed_headers
.txn_isolation_level
.and_then(map_isolation_level_to_headers)
{
response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
}
}
let len = json_output.len();
let response = response
@@ -607,6 +604,7 @@ async fn handle_db_inner(
// count the egress bytes - we miss the TLS and header overhead but oh well...
// moving this later in the stack is going to be a lot of effort and ehhhh
let metrics = client.metrics(ctx);
metrics.record_egress(len as u64);
metrics.record_ingress(request_len as u64);
@@ -673,233 +671,123 @@ async fn handle_auth_broker_inner(
.map(|b| b.boxed()))
}
impl QueryData {
impl Payload {
async fn process(
self,
&self,
config: &'static HttpConfig,
cancel: CancellationToken,
client: &mut Client,
parsed_headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> {
let (inner, mut discard) = client.inner();
let cancel_token = inner.cancel_token();
let mut json_buf = vec![];
let needs_tx = matches!(self, Payload::Batch(_));
let batch_result = match select(
pin!(query_to_json(
config,
&mut *inner,
self,
json::ValueSer::new(&mut json_buf),
parsed_headers
)),
pin!(cancel.cancelled()),
)
.await
{
Either::Left((res, __not_yet_cancelled)) => res,
Either::Right((_cancelled, query)) => {
tracing::info!("cancelling query");
if let Err(err) = cancel_token.cancel_query(NoTls).await {
tracing::warn!(?err, "could not cancel query");
if needs_tx {
info!("starting transaction");
let query = TransactionBuilder {
isolation_level: parsed_headers.txn_isolation_level,
read_only: parsed_headers.txn_read_only.then_some(true),
deferrable: parsed_headers.txn_deferrable.then_some(true),
}
.format();
inner
.batch_execute(&query)
.await
.inspect_err(|_| {
// if we cannot start a transaction, we should return immediately
// and not return to the pool. connection is clearly broken
discard.discard();
})
.map_err(SqlOverHttpError::Postgres)?;
}
let json_output = json::value_to_string!(|value| match &self {
Payload::Single(query) => {
query_to_json(config, &cancel, inner, query, value, parsed_headers).await?;
}
Payload::Batch(batch) => {
let mut obj = value.object();
let mut results = obj.key("results").list();
for query in &batch.queries {
let value = results.entry();
query_to_json(config, &cancel, inner, query, value, parsed_headers).await?;
}
// wait for the query cancellation
match time::timeout(time::Duration::from_millis(100), query).await {
// query successed before it was cancelled.
Ok(Ok(status)) => Ok(status),
// query failed or was cancelled.
Ok(Err(error)) => {
let db_error = match &error {
SqlOverHttpError::ConnectCompute(
HttpConnError::PostgresConnectionError(e),
)
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
_ => None,
};
// if errored for some other reason, it might not be safe to return
if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) {
discard.discard();
}
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
}
Err(_timeout) => {
discard.discard();
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
}
}
results.finish();
obj.finish();
}
};
});
match batch_result {
// The query successfully completed.
Ok(status) => {
discard.check_idle(status);
let json_output = String::from_utf8(json_buf).expect("json should be valid utf8");
Ok(json_output)
}
// The query failed with an error
Err(e) => {
discard.discard();
Err(e)
}
if needs_tx {
inner
.batch_execute("COMMIT")
.await
.inspect_err(|_| {
// if we cannot commit - for now don't return connection to pool
// TODO: get a query status from the error
discard.discard();
})
.map_err(SqlOverHttpError::Postgres)?;
}
}
}
impl BatchQueryData {
async fn process(
self,
config: &'static HttpConfig,
cancel: CancellationToken,
client: &mut Client,
parsed_headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> {
info!("starting transaction");
let (inner, mut discard) = client.inner();
let cancel_token = inner.cancel_token();
let mut builder = inner.build_transaction();
if let Some(isolation_level) = parsed_headers.txn_isolation_level {
builder = builder.isolation_level(isolation_level);
}
if parsed_headers.txn_read_only {
builder = builder.read_only(true);
}
if parsed_headers.txn_deferrable {
builder = builder.deferrable(true);
}
let mut transaction = builder
.start()
.await
.inspect_err(|_| {
// if we cannot start a transaction, we should return immediately
// and not return to the pool. connection is clearly broken
discard.discard();
})
.map_err(SqlOverHttpError::Postgres)?;
let json_output = match query_batch_to_json(
config,
cancel.child_token(),
&mut transaction,
self,
parsed_headers,
)
.await
{
Ok(json_output) => {
info!("commit");
let status = transaction
.commit()
.await
.inspect_err(|_| {
// if we cannot commit - for now don't return connection to pool
// TODO: get a query status from the error
discard.discard();
})
.map_err(SqlOverHttpError::Postgres)?;
discard.check_idle(status);
json_output
}
Err(SqlOverHttpError::Cancelled(_)) => {
if let Err(err) = cancel_token.cancel_query(NoTls).await {
tracing::warn!(?err, "could not cancel query");
}
// TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
discard.discard();
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
}
Err(err) => {
info!("rollback");
let status = transaction
.rollback()
.await
.inspect_err(|_| {
// if we cannot rollback - for now don't return connection to pool
// TODO: get a query status from the error
discard.discard();
})
.map_err(SqlOverHttpError::Postgres)?;
discard.check_idle(status);
return Err(err);
}
};
Ok(json_output)
}
}
async fn query_batch(
config: &'static HttpConfig,
cancel: CancellationToken,
transaction: &mut Transaction<'_>,
queries: BatchQueryData,
parsed_headers: HttpHeaders,
results: &mut json::ListSer<'_>,
) -> Result<(), SqlOverHttpError> {
for stmt in queries.queries {
let query = pin!(query_to_json(
config,
transaction,
stmt,
results.entry(),
parsed_headers,
));
let cancelled = pin!(cancel.cancelled());
let res = select(query, cancelled).await;
match res {
// TODO: maybe we should check that the transaction bit is set here
Either::Left((Ok(_), _cancelled)) => {}
Either::Left((Err(e), _cancelled)) => {
return Err(e);
}
Either::Right((_cancelled, _)) => {
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
}
}
async fn cancel_query(client: &mut Client) {
let (inner, mut discard) = client.inner();
let cancel_token = inner.cancel_token();
if let Err(err) = cancel_token.cancel_query(NoTls).await {
tracing::warn!(?err, "could not cancel query");
// couldn't reach the server. let's just throw away this conn
discard.discard();
return;
}
Ok(())
// wait for the query cancellation
match time::timeout(time::Duration::from_millis(100), inner.wait_until_ready()).await {
// we managed to cancel the query.
Ok(Ok(_)) => {}
// query failed or was cancelled.
Ok(Err(error)) => {
let db_error = error.as_db_error();
// if errored for some other reason, it might not be safe to reuse the connection.
if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) {
discard.discard();
}
}
Err(_timeout) => {
discard.discard();
}
}
}
async fn query_batch_to_json(
async fn query_to_json(
config: &'static HttpConfig,
cancel: CancellationToken,
tx: &mut Transaction<'_>,
queries: BatchQueryData,
headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> {
let json_output = json::value_to_string!(|obj| json::value_as_object!(|obj| {
let results = obj.key("results");
json::value_as_list!(|results| {
query_batch(config, cancel, tx, queries, headers, results).await?;
});
}));
Ok(json_output)
}
async fn query_to_json<T: GenericClient>(
config: &'static HttpConfig,
client: &mut T,
data: QueryData,
cancel: &CancellationToken,
client: &mut postgres_client::Client,
data: &QueryData,
output: json::ValueSer<'_>,
parsed_headers: HttpHeaders,
) -> Result<ReadyForQueryStatus, SqlOverHttpError> {
let query_start = Instant::now();
let mut output = json::ObjectSer::new(output);
let mut row_stream = client
.query_raw_txt(&data.query, data.params)
let params = data.params.iter().map(Option::as_deref);
let mut row_stream = run_until_cancelled(client.query_raw_txt(&data.query, params), cancel)
.await
.ok_or(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))?
.map_err(SqlOverHttpError::Postgres)?;
let query_acknowledged = Instant::now();
let mut output = json::ObjectSer::new(output);
let mut json_fields = output.key("fields").list();
for c in row_stream.statement.columns() {
let json_field = json_fields.entry();
@@ -923,8 +811,13 @@ async fn query_to_json<T: GenericClient>(
// big.
let mut rows = 0;
let mut json_rows = output.key("rows").list();
while let Some(row) = row_stream.next().await {
let row = row.map_err(SqlOverHttpError::Postgres)?;
loop {
let row = run_until_cancelled(row_stream.try_next(), cancel)
.await
.ok_or(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))?
.map_err(SqlOverHttpError::Postgres)?;
let Some(row) = row else { break };
// we don't have a streaming response support yet so this is to prevent OOM
// from a malicious query (eg a cross join)
@@ -1012,12 +905,6 @@ impl Client {
}
impl Discard<'_> {
fn check_idle(&mut self, status: ReadyForQueryStatus) {
match self {
Discard::Remote(discard) => discard.check_idle(status),
Discard::Local(discard) => discard.check_idle(status),
}
}
fn discard(&mut self) {
match self {
Discard::Remote(discard) => discard.discard(),

View File

@@ -1,23 +1,50 @@
use std::pin::pin;
use std::{
pin::Pin,
task::{Context, Poll},
};
use futures::future::{Either, select};
use futures::FutureExt;
use tokio_util::sync::CancellationToken;
pub async fn run_until_cancelled<F: Future>(
pub fn run_until_cancelled<F: Future>(
f: F,
cancellation_token: &CancellationToken,
) -> Option<F::Output> {
run_until(f, cancellation_token.cancelled()).await.ok()
) -> impl Future<Output = Option<F::Output>> {
run_until(f, cancellation_token.cancelled()).map(|r| r.ok())
}
/// Runs the future `f` unless interrupted by future `condition`.
pub async fn run_until<F1: Future, F2: Future>(
pub fn run_until<F1: Future, F2: Future>(
f: F1,
condition: F2,
) -> Result<F1::Output, F2::Output> {
match select(pin!(f), pin!(condition)).await {
Either::Left((f1, _)) => Ok(f1),
Either::Right((f2, _)) => Err(f2),
) -> impl Future<Output = Result<F1::Output, F2::Output>> {
RunUntil { a: f, b: condition }
}
pin_project_lite::pin_project! {
struct RunUntil<A, B> {
#[pin] a: A,
#[pin] b: B,
}
}
impl<A, B> Future for RunUntil<A, B>
where
A: Future,
B: Future,
{
type Output = Result<A::Output, B::Output>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
if let Poll::Ready(a) = this.a.poll(cx) {
return Poll::Ready(Ok(a));
}
if let Poll::Ready(b) = this.b.poll(cx) {
return Poll::Ready(Err(b));
}
Poll::Pending
}
}

View File

@@ -3899,6 +3899,41 @@ class NeonProxy(PgProtocol):
assert response.status_code == expected_code, f"response: {response.json()}"
return response.json()
def http_multiquery(self, *queries, **kwargs):
# TODO maybe use default values if not provided
user = quote(kwargs["user"])
password = quote(kwargs["password"])
expected_code = kwargs.get("expected_code")
timeout = kwargs.get("timeout")
json_queries = []
for query in queries:
if type(query) is str:
json_queries.append({"query": query})
else:
[query, params] = query
json_queries.append({"query": query, "params": params})
queries_str = [j["query"] for j in json_queries]
log.info(f"Executing http queries: {queries_str}")
connstr = f"postgresql://{user}:{password}@{self.domain}:{self.proxy_port}/postgres"
response = requests.post(
f"https://{self.domain}:{self.external_http_port}/sql",
data=json.dumps({"queries": json_queries}),
headers={
"Content-Type": "application/sql",
"Neon-Connection-String": connstr,
"Neon-Pool-Opt-In": "true",
},
verify=str(self.test_output_dir / "proxy.crt"),
timeout=timeout,
)
if expected_code is not None:
assert response.status_code == expected_code, f"response: {response.json()}"
return response.json()
async def http2_query(self, query, args, **kwargs):
# TODO maybe use default values if not provided
user = kwargs["user"]

View File

@@ -17,9 +17,6 @@ if TYPE_CHECKING:
from typing import Any
GET_CONNECTION_PID_QUERY = "SELECT pid FROM pg_stat_activity WHERE state = 'active'"
@pytest.mark.asyncio
async def test_http_pool_begin_1(static_proxy: NeonProxy):
static_proxy.safe_psql("create user http_auth with password 'http' superuser")
@@ -479,7 +476,7 @@ def test_sql_over_http_pool(static_proxy: NeonProxy):
def get_pid(status: int, pw: str, user="http_auth") -> Any:
return static_proxy.http_query(
GET_CONNECTION_PID_QUERY,
"SELECT pg_backend_pid() as pid",
[],
user=user,
password=pw,
@@ -513,6 +510,35 @@ def test_sql_over_http_pool(static_proxy: NeonProxy):
assert "password authentication failed for user" in res["message"]
def test_sql_over_http_pool_settings(static_proxy: NeonProxy):
static_proxy.safe_psql("create user http_auth with password 'http' superuser")
def multiquery(*queries) -> Any:
results = static_proxy.http_multiquery(
*queries,
user="http_auth",
password="http",
expected_code=200,
)
return [result["rows"] for result in results["results"]]
[[intervalstyle]] = static_proxy.safe_psql("SHOW IntervalStyle")
assert intervalstyle == "postgres", "'postgres' is the default IntervalStyle in postgres"
result = multiquery("select '0 seconds'::interval as interval")
assert result[0][0]["interval"] == "00:00:00", "interval is expected in postgres format"
result = multiquery(
"SET IntervalStyle = 'iso_8601'",
"select '0 seconds'::interval as interval",
)
assert result[1][0]["interval"] == "PT0S", "interval is expected in ISO-8601 format"
result = multiquery("select '0 seconds'::interval as interval")
assert result[0][0]["interval"] == "00:00:00", "interval is expected in postgres format"
def test_sql_over_http_urlencoding(static_proxy: NeonProxy):
static_proxy.safe_psql("create user \"http+auth$$\" with password '%+$^&*@!' superuser")
@@ -544,23 +570,37 @@ def test_http_pool_begin(static_proxy: NeonProxy):
query(200, "SELECT 1;") # Query that should succeed regardless of the transaction
def test_sql_over_http_pool_idle(static_proxy: NeonProxy):
def test_sql_over_http_pool_tx_reuse(static_proxy: NeonProxy):
static_proxy.safe_psql("create user http_auth2 with password 'http' superuser")
def query(status: int, query: str) -> Any:
def query(status: int, query: str, *args) -> Any:
return static_proxy.http_query(
query,
[],
args,
user="http_auth2",
password="http",
expected_code=status,
)
pid1 = query(200, GET_CONNECTION_PID_QUERY)["rows"][0]["pid"]
def query_pid_txid() -> Any:
result = query(
200,
"SELECT pg_backend_pid() as pid, pg_current_xact_id() as txid",
)
return result["rows"][0]
res0 = query_pid_txid()
time.sleep(0.02)
query(200, "BEGIN")
pid2 = query(200, GET_CONNECTION_PID_QUERY)["rows"][0]["pid"]
assert pid1 != pid2
res1 = query_pid_txid()
res2 = query_pid_txid()
assert res0["pid"] == res1["pid"], "connection should be reused"
assert res0["pid"] == res2["pid"], "connection should be reused"
assert res1["txid"] != res2["txid"], "txid should be different"
@pytest.mark.timeout(60)