mirror of
https://github.com/neondatabase/neon.git
synced 2026-06-01 20:40:37 +00:00
Compare commits
12 Commits
conrad/fix
...
conrad/rem
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e965bd96bb | ||
|
|
14daaec98e | ||
|
|
286ac97a9c | ||
|
|
20355cb5f0 | ||
|
|
634dbd29b6 | ||
|
|
a235b241d5 | ||
|
|
539652fa4e | ||
|
|
11294ca322 | ||
|
|
84020c1328 | ||
|
|
0cc7415691 | ||
|
|
38df46b381 | ||
|
|
cdc73ad051 |
@@ -600,6 +600,7 @@ impl ParameterStatusBody {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct ReadyForQueryBody {
|
||||
status: u8,
|
||||
}
|
||||
|
||||
@@ -18,10 +18,7 @@ use crate::config::{Host, SslMode};
|
||||
use crate::query::RowStream;
|
||||
use crate::simple_query::SimpleQueryStream;
|
||||
use crate::types::{Oid, Type};
|
||||
use crate::{
|
||||
CancelToken, Error, ReadyForQueryStatus, SimpleQueryMessage, Transaction, TransactionBuilder,
|
||||
query, simple_query,
|
||||
};
|
||||
use crate::{CancelToken, Error, ReadyForQueryStatus, SimpleQueryMessage, query, simple_query};
|
||||
|
||||
pub struct Responses {
|
||||
/// new messages from conn
|
||||
@@ -32,6 +29,9 @@ pub struct Responses {
|
||||
waiting: usize,
|
||||
/// number of ReadyForQuery messages received.
|
||||
received: usize,
|
||||
|
||||
/// The last query status we received.
|
||||
last_status: ReadyForQueryStatus,
|
||||
}
|
||||
|
||||
impl Responses {
|
||||
@@ -42,7 +42,8 @@ impl Responses {
|
||||
let received = self.received;
|
||||
|
||||
// increase the query head if this is the last message.
|
||||
if let Message::ReadyForQuery(_) = message {
|
||||
if let Message::ReadyForQuery(ref status) = message {
|
||||
self.last_status = (*status).into();
|
||||
self.received += 1;
|
||||
}
|
||||
|
||||
@@ -71,6 +72,15 @@ impl Responses {
|
||||
pub async fn next(&mut self) -> Result<Message, Error> {
|
||||
future::poll_fn(|cx| self.poll_next(cx)).await
|
||||
}
|
||||
|
||||
pub async fn wait_until_ready(&mut self) -> Result<ReadyForQueryStatus, Error> {
|
||||
while self.received < self.waiting {
|
||||
if let Message::ReadyForQuery(status) = self.next().await? {
|
||||
return Ok(status.into());
|
||||
}
|
||||
}
|
||||
Ok(self.last_status)
|
||||
}
|
||||
}
|
||||
|
||||
/// A cache of type info and prepared statements for fetching type info
|
||||
@@ -95,13 +105,6 @@ impl InnerClient {
|
||||
Ok(PartialQuery(Some(self)))
|
||||
}
|
||||
|
||||
// pub fn send_with_sync<F>(&mut self, f: F) -> Result<&mut Responses, Error>
|
||||
// where
|
||||
// F: FnOnce(&mut BytesMut) -> Result<(), Error>,
|
||||
// {
|
||||
// self.start()?.send_with_sync(f)
|
||||
// }
|
||||
|
||||
pub fn send_simple_query(&mut self, query: &str) -> Result<&mut Responses, Error> {
|
||||
self.responses.waiting += 1;
|
||||
|
||||
@@ -200,6 +203,8 @@ impl Client {
|
||||
cur: BackendMessages::empty(),
|
||||
waiting: 0,
|
||||
received: 0,
|
||||
// new connections are always idle.
|
||||
last_status: ReadyForQueryStatus::Idle,
|
||||
},
|
||||
buffer: Default::default(),
|
||||
},
|
||||
@@ -233,6 +238,11 @@ impl Client {
|
||||
rx
|
||||
}
|
||||
|
||||
/// Wait until this connection has no more active queries.
|
||||
pub async fn wait_until_ready(&mut self) -> Result<ReadyForQueryStatus, Error> {
|
||||
self.inner_mut().responses.wait_until_ready().await
|
||||
}
|
||||
|
||||
/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
|
||||
/// to save a roundtrip
|
||||
pub async fn query_raw_txt<S, I>(
|
||||
@@ -292,52 +302,32 @@ impl Client {
|
||||
simple_query::batch_execute(self.inner_mut(), query).await
|
||||
}
|
||||
|
||||
pub async fn discard_all(&mut self) -> Result<ReadyForQueryStatus, Error> {
|
||||
self.batch_execute("discard all").await
|
||||
}
|
||||
|
||||
/// Begins a new database transaction.
|
||||
/// Similar to `discard_all`, but it does not clear any query plans
|
||||
///
|
||||
/// The transaction will roll back by default - use the `commit` method to commit it.
|
||||
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
|
||||
struct RollbackIfNotDone<'me> {
|
||||
client: &'me mut Client,
|
||||
done: bool,
|
||||
}
|
||||
/// This runs in the background, so it can be executed without `await`ing.
|
||||
pub fn reset_session_background(&mut self) -> Result<(), Error> {
|
||||
// "CLOSE ALL": closes any cursors
|
||||
// "SET SESSION AUTHORIZATION DEFAULT": resets the current_user back to the session_user
|
||||
// "RESET ALL": resets any GUCs back to their session defaults.
|
||||
// "DEALLOCATE ALL": deallocates any prepared statements
|
||||
// "UNLISTEN *": stops listening on all channels
|
||||
// "SELECT pg_advisory_unlock_all();": unlocks all advisory locks
|
||||
// "DISCARD TEMP;": drops all temporary tables
|
||||
// "DISCARD SEQUENCES;": deallocates all cached sequence state
|
||||
|
||||
impl Drop for RollbackIfNotDone<'_> {
|
||||
fn drop(&mut self) {
|
||||
if self.done {
|
||||
return;
|
||||
}
|
||||
let _responses = self.inner_mut().send_simple_query(
|
||||
"ROLLBACK;
|
||||
CLOSE ALL;
|
||||
SET SESSION AUTHORIZATION DEFAULT;
|
||||
RESET ALL;
|
||||
DEALLOCATE ALL;
|
||||
UNLISTEN *;
|
||||
SELECT pg_advisory_unlock_all();
|
||||
DISCARD TEMP;
|
||||
DISCARD SEQUENCES;",
|
||||
)?;
|
||||
|
||||
let _ = self.client.inner.send_simple_query("ROLLBACK");
|
||||
}
|
||||
}
|
||||
|
||||
// This is done, as `Future` created by this method can be dropped after
|
||||
// `RequestMessages` is synchronously send to the `Connection` by
|
||||
// `batch_execute()`, but before `Responses` is asynchronously polled to
|
||||
// completion. In that case `Transaction` won't be created and thus
|
||||
// won't be rolled back.
|
||||
{
|
||||
let mut cleaner = RollbackIfNotDone {
|
||||
client: self,
|
||||
done: false,
|
||||
};
|
||||
cleaner.client.batch_execute("BEGIN").await?;
|
||||
cleaner.done = true;
|
||||
}
|
||||
|
||||
Ok(Transaction::new(self))
|
||||
}
|
||||
|
||||
/// Returns a builder for a transaction with custom settings.
|
||||
///
|
||||
/// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other
|
||||
/// attributes.
|
||||
pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
|
||||
TransactionBuilder::new(self)
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Constructs a cancellation token that can later be used to request cancellation of a query running on the
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
#![allow(async_fn_in_trait)]
|
||||
|
||||
use crate::query::RowStream;
|
||||
use crate::{Client, Error, Transaction};
|
||||
|
||||
mod private {
|
||||
pub trait Sealed {}
|
||||
}
|
||||
|
||||
/// A trait allowing abstraction over connections and transactions.
|
||||
///
|
||||
/// This trait is "sealed", and cannot be implemented outside of this crate.
|
||||
pub trait GenericClient: private::Sealed {
|
||||
/// Like `Client::query_raw_txt`.
|
||||
async fn query_raw_txt<S, I>(
|
||||
&mut self,
|
||||
statement: &str,
|
||||
params: I,
|
||||
) -> Result<RowStream<'_>, Error>
|
||||
where
|
||||
S: AsRef<str> + Sync + Send,
|
||||
I: IntoIterator<Item = Option<S>> + Sync + Send,
|
||||
I::IntoIter: ExactSizeIterator + Sync + Send;
|
||||
}
|
||||
|
||||
impl private::Sealed for Client {}
|
||||
|
||||
impl GenericClient for Client {
|
||||
async fn query_raw_txt<S, I>(
|
||||
&mut self,
|
||||
statement: &str,
|
||||
params: I,
|
||||
) -> Result<RowStream<'_>, Error>
|
||||
where
|
||||
S: AsRef<str> + Sync + Send,
|
||||
I: IntoIterator<Item = Option<S>> + Sync + Send,
|
||||
I::IntoIter: ExactSizeIterator + Sync + Send,
|
||||
{
|
||||
self.query_raw_txt(statement, params).await
|
||||
}
|
||||
}
|
||||
|
||||
impl private::Sealed for Transaction<'_> {}
|
||||
|
||||
impl GenericClient for Transaction<'_> {
|
||||
async fn query_raw_txt<S, I>(
|
||||
&mut self,
|
||||
statement: &str,
|
||||
params: I,
|
||||
) -> Result<RowStream<'_>, Error>
|
||||
where
|
||||
S: AsRef<str> + Sync + Send,
|
||||
I: IntoIterator<Item = Option<S>> + Sync + Send,
|
||||
I::IntoIter: ExactSizeIterator + Sync + Send,
|
||||
{
|
||||
self.query_raw_txt(statement, params).await
|
||||
}
|
||||
}
|
||||
@@ -9,13 +9,11 @@ pub use crate::config::Config;
|
||||
pub use crate::connect_raw::RawConnection;
|
||||
pub use crate::connection::Connection;
|
||||
pub use crate::error::Error;
|
||||
pub use crate::generic_client::GenericClient;
|
||||
pub use crate::query::RowStream;
|
||||
pub use crate::row::{Row, SimpleQueryRow};
|
||||
pub use crate::simple_query::SimpleQueryStream;
|
||||
pub use crate::statement::{Column, Statement};
|
||||
pub use crate::tls::NoTls;
|
||||
pub use crate::transaction::Transaction;
|
||||
pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder};
|
||||
|
||||
/// After executing a query, the connection will be in one of these states
|
||||
@@ -55,7 +53,6 @@ mod connect_socket;
|
||||
mod connect_tls;
|
||||
mod connection;
|
||||
pub mod error;
|
||||
mod generic_client;
|
||||
pub mod maybe_tls_stream;
|
||||
mod prepare;
|
||||
mod query;
|
||||
@@ -63,7 +60,6 @@ pub mod row;
|
||||
mod simple_query;
|
||||
mod statement;
|
||||
pub mod tls;
|
||||
mod transaction;
|
||||
mod transaction_builder;
|
||||
pub mod types;
|
||||
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
use crate::query::RowStream;
|
||||
use crate::{CancelToken, Client, Error, ReadyForQueryStatus};
|
||||
|
||||
/// A representation of a PostgreSQL database transaction.
|
||||
///
|
||||
/// Transactions will implicitly roll back when dropped. Use the `commit` method to commit the changes made in the
|
||||
/// transaction. Transactions can be nested, with inner transactions implemented via safepoints.
|
||||
pub struct Transaction<'a> {
|
||||
client: &'a mut Client,
|
||||
done: bool,
|
||||
}
|
||||
|
||||
impl Drop for Transaction<'_> {
|
||||
fn drop(&mut self) {
|
||||
if self.done {
|
||||
return;
|
||||
}
|
||||
|
||||
let _ = self.client.inner_mut().send_simple_query("ROLLBACK");
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Transaction<'a> {
|
||||
pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
|
||||
Transaction {
|
||||
client,
|
||||
done: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Consumes the transaction, committing all changes made within it.
|
||||
pub async fn commit(mut self) -> Result<ReadyForQueryStatus, Error> {
|
||||
self.done = true;
|
||||
self.client.batch_execute("COMMIT").await
|
||||
}
|
||||
|
||||
/// Rolls the transaction back, discarding all changes made within it.
|
||||
///
|
||||
/// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
|
||||
pub async fn rollback(mut self) -> Result<ReadyForQueryStatus, Error> {
|
||||
self.done = true;
|
||||
self.client.batch_execute("ROLLBACK").await
|
||||
}
|
||||
|
||||
/// Like `Client::query_raw_txt`.
|
||||
pub async fn query_raw_txt<S, I>(
|
||||
&mut self,
|
||||
statement: &str,
|
||||
params: I,
|
||||
) -> Result<RowStream<'_>, Error>
|
||||
where
|
||||
S: AsRef<str>,
|
||||
I: IntoIterator<Item = Option<S>>,
|
||||
I::IntoIter: ExactSizeIterator,
|
||||
{
|
||||
self.client.query_raw_txt(statement, params).await
|
||||
}
|
||||
|
||||
/// Like `Client::cancel_token`.
|
||||
pub fn cancel_token(&self) -> CancelToken {
|
||||
self.client.cancel_token()
|
||||
}
|
||||
|
||||
/// Returns a reference to the underlying `Client`.
|
||||
pub fn client(&self) -> &Client {
|
||||
self.client
|
||||
}
|
||||
|
||||
/// Returns a reference to the underlying `Client`.
|
||||
pub fn client_mut(&mut self) -> &mut Client {
|
||||
self.client
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,3 @@
|
||||
use crate::{Client, Error, Transaction};
|
||||
|
||||
/// The isolation level of a database transaction.
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
#[non_exhaustive]
|
||||
@@ -20,49 +18,17 @@ pub enum IsolationLevel {
|
||||
}
|
||||
|
||||
/// A builder for database transactions.
|
||||
pub struct TransactionBuilder<'a> {
|
||||
client: &'a mut Client,
|
||||
isolation_level: Option<IsolationLevel>,
|
||||
read_only: Option<bool>,
|
||||
deferrable: Option<bool>,
|
||||
pub struct TransactionBuilder {
|
||||
pub isolation_level: Option<IsolationLevel>,
|
||||
pub read_only: Option<bool>,
|
||||
pub deferrable: Option<bool>,
|
||||
}
|
||||
|
||||
impl<'a> TransactionBuilder<'a> {
|
||||
pub(crate) fn new(client: &'a mut Client) -> TransactionBuilder<'a> {
|
||||
TransactionBuilder {
|
||||
client,
|
||||
isolation_level: None,
|
||||
read_only: None,
|
||||
deferrable: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the isolation level of the transaction.
|
||||
pub fn isolation_level(mut self, isolation_level: IsolationLevel) -> Self {
|
||||
self.isolation_level = Some(isolation_level);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the access mode of the transaction.
|
||||
pub fn read_only(mut self, read_only: bool) -> Self {
|
||||
self.read_only = Some(read_only);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the deferrability of the transaction.
|
||||
///
|
||||
/// If the transaction is also serializable and read only, creation of the transaction may block, but when it
|
||||
/// completes the transaction is able to run with less overhead and a guarantee that it will not be aborted due to
|
||||
/// serialization failure.
|
||||
pub fn deferrable(mut self, deferrable: bool) -> Self {
|
||||
self.deferrable = Some(deferrable);
|
||||
self
|
||||
}
|
||||
|
||||
impl TransactionBuilder {
|
||||
/// Begins the transaction.
|
||||
///
|
||||
/// The transaction will roll back by default - use the `commit` method to commit it.
|
||||
pub async fn start(self) -> Result<Transaction<'a>, Error> {
|
||||
pub fn format(self) -> String {
|
||||
let mut query = "START TRANSACTION".to_string();
|
||||
let mut first = true;
|
||||
|
||||
@@ -106,8 +72,6 @@ impl<'a> TransactionBuilder<'a> {
|
||||
query.push_str(s);
|
||||
}
|
||||
|
||||
self.client.batch_execute(&query).await?;
|
||||
|
||||
Ok(Transaction::new(self.client))
|
||||
query
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ use tracing::{debug, info};
|
||||
use super::AsyncRW;
|
||||
use super::conn_pool::poll_client;
|
||||
use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool};
|
||||
use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client};
|
||||
use super::http_conn_pool::{self, HttpConnPool, LocalProxyClient, poll_http2_client};
|
||||
use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool};
|
||||
use crate::auth::backend::local::StaticAuthRules;
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo};
|
||||
@@ -40,7 +40,8 @@ use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
|
||||
|
||||
pub(crate) struct PoolingBackend {
|
||||
pub(crate) http_conn_pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
|
||||
pub(crate) http_conn_pool:
|
||||
Arc<GlobalConnPool<LocalProxyClient, HttpConnPool<LocalProxyClient>>>,
|
||||
pub(crate) local_pool: Arc<LocalConnPool<postgres_client::Client>>,
|
||||
pub(crate) pool:
|
||||
Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
|
||||
@@ -210,7 +211,7 @@ impl PoolingBackend {
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
conn_info: ConnInfo,
|
||||
) -> Result<http_conn_pool::Client<Send>, HttpConnError> {
|
||||
) -> Result<http_conn_pool::Client<LocalProxyClient>, HttpConnError> {
|
||||
debug!("pool: looking for an existing connection");
|
||||
if let Ok(Some(client)) = self.http_conn_pool.get(ctx, &conn_info) {
|
||||
return Ok(client);
|
||||
@@ -568,7 +569,7 @@ impl ConnectMechanism for TokioMechanism {
|
||||
}
|
||||
|
||||
struct HyperMechanism {
|
||||
pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
|
||||
pool: Arc<GlobalConnPool<LocalProxyClient, HttpConnPool<LocalProxyClient>>>,
|
||||
conn_info: ConnInfo,
|
||||
conn_id: uuid::Uuid,
|
||||
|
||||
@@ -578,7 +579,7 @@ struct HyperMechanism {
|
||||
|
||||
#[async_trait]
|
||||
impl ConnectMechanism for HyperMechanism {
|
||||
type Connection = http_conn_pool::Client<Send>;
|
||||
type Connection = http_conn_pool::Client<LocalProxyClient>;
|
||||
type ConnectError = HttpConnError;
|
||||
type Error = HttpConnError;
|
||||
|
||||
@@ -632,7 +633,13 @@ async fn connect_http2(
|
||||
port: u16,
|
||||
timeout: Duration,
|
||||
tls: Option<&Arc<rustls::ClientConfig>>,
|
||||
) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> {
|
||||
) -> Result<
|
||||
(
|
||||
http_conn_pool::LocalProxyClient,
|
||||
http_conn_pool::LocalProxyConnection,
|
||||
),
|
||||
LocalProxyConnError,
|
||||
> {
|
||||
let addrs = match host_addr {
|
||||
Some(addr) => vec![SocketAddr::new(addr, port)],
|
||||
None => lookup_host((host, port))
|
||||
|
||||
@@ -190,6 +190,9 @@ mod tests {
|
||||
fn get_process_id(&self) -> i32 {
|
||||
0
|
||||
}
|
||||
fn reset(&mut self) -> Result<(), postgres_client::Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn create_inner() -> ClientInnerCommon<MockClient> {
|
||||
|
||||
@@ -7,10 +7,9 @@ use std::time::Duration;
|
||||
|
||||
use clashmap::ClashMap;
|
||||
use parking_lot::RwLock;
|
||||
use postgres_client::ReadyForQueryStatus;
|
||||
use rand::Rng;
|
||||
use smol_str::ToSmolStr;
|
||||
use tracing::{Span, debug, info};
|
||||
use tracing::{Span, debug, info, warn};
|
||||
|
||||
use super::backend::HttpConnError;
|
||||
use super::conn_pool::ClientDataRemote;
|
||||
@@ -188,7 +187,7 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
|
||||
self.pools.get_mut(&db_user)
|
||||
}
|
||||
|
||||
pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInnerCommon<C>) {
|
||||
pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, mut client: ClientInnerCommon<C>) {
|
||||
let conn_id = client.get_conn_id();
|
||||
let (max_conn, conn_count, pool_name) = {
|
||||
let pool = pool.read();
|
||||
@@ -201,12 +200,17 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
|
||||
};
|
||||
|
||||
if client.inner.is_closed() {
|
||||
info!(%conn_id, "{}: throwing away connection '{conn_info}' because connection is closed", pool_name);
|
||||
info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because connection is closed");
|
||||
return;
|
||||
}
|
||||
|
||||
if let Err(error) = client.inner.reset() {
|
||||
warn!(?error, %conn_id, "{pool_name}: throwing away connection '{conn_info}' because connection could not be reset");
|
||||
return;
|
||||
}
|
||||
|
||||
if conn_count >= max_conn {
|
||||
info!(%conn_id, "{}: throwing away connection '{conn_info}' because pool is full", pool_name);
|
||||
info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because pool is full");
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -691,6 +695,7 @@ impl<C: ClientInnerExt> Deref for Client<C> {
|
||||
pub(crate) trait ClientInnerExt: Sync + Send + 'static {
|
||||
fn is_closed(&self) -> bool;
|
||||
fn get_process_id(&self) -> i32;
|
||||
fn reset(&mut self) -> Result<(), postgres_client::Error>;
|
||||
}
|
||||
|
||||
impl ClientInnerExt for postgres_client::Client {
|
||||
@@ -701,15 +706,13 @@ impl ClientInnerExt for postgres_client::Client {
|
||||
fn get_process_id(&self) -> i32 {
|
||||
self.get_process_id()
|
||||
}
|
||||
|
||||
fn reset(&mut self) -> Result<(), postgres_client::Error> {
|
||||
self.reset_session_background()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Discard<'_, C> {
|
||||
pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
|
||||
let conn_info = &self.conn_info;
|
||||
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
|
||||
info!("pool: throwing away connection '{conn_info}' because connection is not idle");
|
||||
}
|
||||
}
|
||||
pub(crate) fn discard(&mut self) {
|
||||
let conn_info = &self.conn_info;
|
||||
if std::mem::take(self.pool).strong_count() > 0 {
|
||||
|
||||
@@ -23,8 +23,8 @@ use crate::protocol2::ConnectionInfoExtra;
|
||||
use crate::types::EndpointCacheKey;
|
||||
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
|
||||
|
||||
pub(crate) type Send = http2::SendRequest<BoxBody<Bytes, hyper::Error>>;
|
||||
pub(crate) type Connect =
|
||||
pub(crate) type LocalProxyClient = http2::SendRequest<BoxBody<Bytes, hyper::Error>>;
|
||||
pub(crate) type LocalProxyConnection =
|
||||
http2::Connection<TokioIo<AsyncRW>, BoxBody<Bytes, hyper::Error>, TokioExecutor>;
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -189,14 +189,14 @@ impl<C: ClientInnerExt + Clone> GlobalConnPool<C, HttpConnPool<C>> {
|
||||
}
|
||||
|
||||
pub(crate) fn poll_http2_client(
|
||||
global_pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
|
||||
global_pool: Arc<GlobalConnPool<LocalProxyClient, HttpConnPool<LocalProxyClient>>>,
|
||||
ctx: &RequestContext,
|
||||
conn_info: &ConnInfo,
|
||||
client: Send,
|
||||
connection: Connect,
|
||||
client: LocalProxyClient,
|
||||
connection: LocalProxyConnection,
|
||||
conn_id: uuid::Uuid,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> Client<Send> {
|
||||
) -> Client<LocalProxyClient> {
|
||||
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
|
||||
let session_id = ctx.session_id();
|
||||
|
||||
@@ -285,7 +285,7 @@ impl<C: ClientInnerExt + Clone> Client<C> {
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientInnerExt for Send {
|
||||
impl ClientInnerExt for LocalProxyClient {
|
||||
fn is_closed(&self) -> bool {
|
||||
self.is_closed()
|
||||
}
|
||||
@@ -294,4 +294,10 @@ impl ClientInnerExt for Send {
|
||||
// ideally throw something meaningful
|
||||
-1
|
||||
}
|
||||
|
||||
fn reset(&mut self) -> Result<(), postgres_client::Error> {
|
||||
// We use HTTP/2.0 to talk to local proxy. HTTP is stateless,
|
||||
// so there's nothing to reset.
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -269,11 +269,6 @@ impl ClientInnerCommon<postgres_client::Client> {
|
||||
local_data.jti += 1;
|
||||
let token = resign_jwt(&local_data.key, payload, local_data.jti)?;
|
||||
|
||||
self.inner
|
||||
.discard_all()
|
||||
.await
|
||||
.map_err(SqlOverHttpError::InternalPostgres)?;
|
||||
|
||||
// initiates the auth session
|
||||
// this is safe from query injections as the jwt format free of any escape characters.
|
||||
let query = format!("select auth.jwt_session_init('{token}')");
|
||||
|
||||
@@ -46,7 +46,7 @@ use super::backend::{HttpConnError, LocalProxyConnError, PoolingBackend};
|
||||
use super::conn_pool::AuthData;
|
||||
use super::conn_pool_lib::ConnInfo;
|
||||
use super::error::{ConnInfoError, Credentials, HttpCodeError, ReadPayloadError};
|
||||
use super::http_conn_pool::{self, Send};
|
||||
use super::http_conn_pool::{self, LocalProxyClient};
|
||||
use super::http_util::{
|
||||
ALLOW_POOL, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_ISOLATION_LEVEL, TXN_READ_ONLY,
|
||||
get_conn_info, json_response, uuid_to_header_value,
|
||||
@@ -145,7 +145,7 @@ impl DbSchemaCache {
|
||||
endpoint_id: &EndpointCacheKey,
|
||||
auth_header: &HeaderValue,
|
||||
connection_string: &str,
|
||||
client: &mut http_conn_pool::Client<Send>,
|
||||
client: &mut http_conn_pool::Client<LocalProxyClient>,
|
||||
ctx: &RequestContext,
|
||||
config: &'static ProxyConfig,
|
||||
) -> Result<Arc<(ApiConfig, DbSchemaOwned)>, RestError> {
|
||||
@@ -190,7 +190,7 @@ impl DbSchemaCache {
|
||||
&self,
|
||||
auth_header: &HeaderValue,
|
||||
connection_string: &str,
|
||||
client: &mut http_conn_pool::Client<Send>,
|
||||
client: &mut http_conn_pool::Client<LocalProxyClient>,
|
||||
ctx: &RequestContext,
|
||||
config: &'static ProxyConfig,
|
||||
) -> Result<(ApiConfig, DbSchemaOwned), RestError> {
|
||||
@@ -430,7 +430,7 @@ struct BatchQueryData<'a> {
|
||||
}
|
||||
|
||||
async fn make_local_proxy_request<S: DeserializeOwned>(
|
||||
client: &mut http_conn_pool::Client<Send>,
|
||||
client: &mut http_conn_pool::Client<LocalProxyClient>,
|
||||
headers: impl IntoIterator<Item = (&HeaderName, HeaderValue)>,
|
||||
body: QueryData<'_>,
|
||||
max_len: usize,
|
||||
@@ -461,7 +461,7 @@ async fn make_local_proxy_request<S: DeserializeOwned>(
|
||||
}
|
||||
|
||||
async fn make_raw_local_proxy_request(
|
||||
client: &mut http_conn_pool::Client<Send>,
|
||||
client: &mut http_conn_pool::Client<LocalProxyClient>,
|
||||
headers: impl IntoIterator<Item = (&HeaderName, HeaderValue)>,
|
||||
body: String,
|
||||
) -> Result<Response<Incoming>, RestError> {
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures::future::{Either, select, try_join};
|
||||
use futures::{StreamExt, TryFutureExt};
|
||||
use futures::future::try_join;
|
||||
use futures::{TryFutureExt, TryStreamExt};
|
||||
use http::Method;
|
||||
use http::header::AUTHORIZATION;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
@@ -14,7 +13,7 @@ use hyper::http::{HeaderName, HeaderValue};
|
||||
use hyper::{Request, Response, StatusCode, header};
|
||||
use indexmap::IndexMap;
|
||||
use postgres_client::error::{DbError, ErrorPosition, SqlState};
|
||||
use postgres_client::{GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, Transaction};
|
||||
use postgres_client::{IsolationLevel, NoTls, ReadyForQueryStatus, TransactionBuilder};
|
||||
use serde_json::Value;
|
||||
use serde_json::value::RawValue;
|
||||
use tokio::time::{self, Instant};
|
||||
@@ -495,7 +494,7 @@ async fn handle_db_inner(
|
||||
.http_conn_content_length_bytes
|
||||
.observe(HttpDirection::Request, body.len() as f64);
|
||||
|
||||
debug!(length = body.len(), "request payload read");
|
||||
debug!(length = body.len(), "request payload read ");
|
||||
let payload: Payload = serde_json::from_slice(&body)?;
|
||||
Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
|
||||
}
|
||||
@@ -530,13 +529,13 @@ async fn handle_db_inner(
|
||||
let (cli_inner, _dsc) = client.client_inner();
|
||||
cli_inner.set_jwt_session(&payload).await?;
|
||||
}
|
||||
Client::Local(client)
|
||||
Box::new(Client::Local(client))
|
||||
}
|
||||
_ => {
|
||||
let client = backend
|
||||
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
|
||||
.await?;
|
||||
Client::Remote(client)
|
||||
Box::new(Client::Remote(client))
|
||||
}
|
||||
};
|
||||
|
||||
@@ -550,10 +549,7 @@ async fn handle_db_inner(
|
||||
|
||||
let (payload, mut client) = match run_until_cancelled(
|
||||
// Run both operations in parallel
|
||||
try_join(
|
||||
pin!(fetch_and_process_request),
|
||||
pin!(authenticate_and_connect),
|
||||
),
|
||||
try_join(fetch_and_process_request, authenticate_and_connect),
|
||||
&cancel,
|
||||
)
|
||||
.await
|
||||
@@ -562,37 +558,38 @@ async fn handle_db_inner(
|
||||
None => return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect)),
|
||||
};
|
||||
|
||||
// Now execute the query and return the result.
|
||||
let json_output = match payload
|
||||
.process(&config.http_config, cancel, &mut client, parsed_headers)
|
||||
.await
|
||||
{
|
||||
Ok(json_output) => json_output,
|
||||
Err(error) => {
|
||||
if let SqlOverHttpError::Cancelled(_) = error {
|
||||
cancel_query(&mut client).await;
|
||||
}
|
||||
return Err(error);
|
||||
}
|
||||
};
|
||||
|
||||
let mut response = Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(header::CONTENT_TYPE, "application/json");
|
||||
|
||||
// Now execute the query and return the result.
|
||||
let json_output = match payload {
|
||||
Payload::Single(stmt) => {
|
||||
stmt.process(&config.http_config, cancel, &mut client, parsed_headers)
|
||||
.await?
|
||||
if let Payload::Batch(_) = payload {
|
||||
if parsed_headers.txn_read_only {
|
||||
response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE);
|
||||
}
|
||||
Payload::Batch(statements) => {
|
||||
if parsed_headers.txn_read_only {
|
||||
response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE);
|
||||
}
|
||||
if parsed_headers.txn_deferrable {
|
||||
response = response.header(TXN_DEFERRABLE.clone(), &HEADER_VALUE_TRUE);
|
||||
}
|
||||
if let Some(txn_isolation_level) = parsed_headers
|
||||
.txn_isolation_level
|
||||
.and_then(map_isolation_level_to_headers)
|
||||
{
|
||||
response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
|
||||
}
|
||||
|
||||
statements
|
||||
.process(&config.http_config, cancel, &mut client, parsed_headers)
|
||||
.await?
|
||||
if parsed_headers.txn_deferrable {
|
||||
response = response.header(TXN_DEFERRABLE.clone(), &HEADER_VALUE_TRUE);
|
||||
}
|
||||
};
|
||||
|
||||
let metrics = client.metrics(ctx);
|
||||
if let Some(txn_isolation_level) = parsed_headers
|
||||
.txn_isolation_level
|
||||
.and_then(map_isolation_level_to_headers)
|
||||
{
|
||||
response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
|
||||
}
|
||||
}
|
||||
|
||||
let len = json_output.len();
|
||||
let response = response
|
||||
@@ -607,6 +604,7 @@ async fn handle_db_inner(
|
||||
|
||||
// count the egress bytes - we miss the TLS and header overhead but oh well...
|
||||
// moving this later in the stack is going to be a lot of effort and ehhhh
|
||||
let metrics = client.metrics(ctx);
|
||||
metrics.record_egress(len as u64);
|
||||
metrics.record_ingress(request_len as u64);
|
||||
|
||||
@@ -673,233 +671,123 @@ async fn handle_auth_broker_inner(
|
||||
.map(|b| b.boxed()))
|
||||
}
|
||||
|
||||
impl QueryData {
|
||||
impl Payload {
|
||||
async fn process(
|
||||
self,
|
||||
&self,
|
||||
config: &'static HttpConfig,
|
||||
cancel: CancellationToken,
|
||||
client: &mut Client,
|
||||
parsed_headers: HttpHeaders,
|
||||
) -> Result<String, SqlOverHttpError> {
|
||||
let (inner, mut discard) = client.inner();
|
||||
let cancel_token = inner.cancel_token();
|
||||
|
||||
let mut json_buf = vec![];
|
||||
let needs_tx = matches!(self, Payload::Batch(_));
|
||||
|
||||
let batch_result = match select(
|
||||
pin!(query_to_json(
|
||||
config,
|
||||
&mut *inner,
|
||||
self,
|
||||
json::ValueSer::new(&mut json_buf),
|
||||
parsed_headers
|
||||
)),
|
||||
pin!(cancel.cancelled()),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Either::Left((res, __not_yet_cancelled)) => res,
|
||||
Either::Right((_cancelled, query)) => {
|
||||
tracing::info!("cancelling query");
|
||||
if let Err(err) = cancel_token.cancel_query(NoTls).await {
|
||||
tracing::warn!(?err, "could not cancel query");
|
||||
if needs_tx {
|
||||
info!("starting transaction");
|
||||
let query = TransactionBuilder {
|
||||
isolation_level: parsed_headers.txn_isolation_level,
|
||||
read_only: parsed_headers.txn_read_only.then_some(true),
|
||||
deferrable: parsed_headers.txn_deferrable.then_some(true),
|
||||
}
|
||||
.format();
|
||||
|
||||
inner
|
||||
.batch_execute(&query)
|
||||
.await
|
||||
.inspect_err(|_| {
|
||||
// if we cannot start a transaction, we should return immediately
|
||||
// and not return to the pool. connection is clearly broken
|
||||
discard.discard();
|
||||
})
|
||||
.map_err(SqlOverHttpError::Postgres)?;
|
||||
}
|
||||
|
||||
let json_output = json::value_to_string!(|value| match &self {
|
||||
Payload::Single(query) => {
|
||||
query_to_json(config, &cancel, inner, query, value, parsed_headers).await?;
|
||||
}
|
||||
Payload::Batch(batch) => {
|
||||
let mut obj = value.object();
|
||||
let mut results = obj.key("results").list();
|
||||
|
||||
for query in &batch.queries {
|
||||
let value = results.entry();
|
||||
query_to_json(config, &cancel, inner, query, value, parsed_headers).await?;
|
||||
}
|
||||
// wait for the query cancellation
|
||||
match time::timeout(time::Duration::from_millis(100), query).await {
|
||||
// query successed before it was cancelled.
|
||||
Ok(Ok(status)) => Ok(status),
|
||||
// query failed or was cancelled.
|
||||
Ok(Err(error)) => {
|
||||
let db_error = match &error {
|
||||
SqlOverHttpError::ConnectCompute(
|
||||
HttpConnError::PostgresConnectionError(e),
|
||||
)
|
||||
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
// if errored for some other reason, it might not be safe to return
|
||||
if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) {
|
||||
discard.discard();
|
||||
}
|
||||
|
||||
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
|
||||
}
|
||||
Err(_timeout) => {
|
||||
discard.discard();
|
||||
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
|
||||
}
|
||||
}
|
||||
results.finish();
|
||||
obj.finish();
|
||||
}
|
||||
};
|
||||
});
|
||||
|
||||
match batch_result {
|
||||
// The query successfully completed.
|
||||
Ok(status) => {
|
||||
discard.check_idle(status);
|
||||
|
||||
let json_output = String::from_utf8(json_buf).expect("json should be valid utf8");
|
||||
Ok(json_output)
|
||||
}
|
||||
// The query failed with an error
|
||||
Err(e) => {
|
||||
discard.discard();
|
||||
Err(e)
|
||||
}
|
||||
if needs_tx {
|
||||
inner
|
||||
.batch_execute("COMMIT")
|
||||
.await
|
||||
.inspect_err(|_| {
|
||||
// if we cannot commit - for now don't return connection to pool
|
||||
// TODO: get a query status from the error
|
||||
discard.discard();
|
||||
})
|
||||
.map_err(SqlOverHttpError::Postgres)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BatchQueryData {
|
||||
async fn process(
|
||||
self,
|
||||
config: &'static HttpConfig,
|
||||
cancel: CancellationToken,
|
||||
client: &mut Client,
|
||||
parsed_headers: HttpHeaders,
|
||||
) -> Result<String, SqlOverHttpError> {
|
||||
info!("starting transaction");
|
||||
let (inner, mut discard) = client.inner();
|
||||
let cancel_token = inner.cancel_token();
|
||||
let mut builder = inner.build_transaction();
|
||||
if let Some(isolation_level) = parsed_headers.txn_isolation_level {
|
||||
builder = builder.isolation_level(isolation_level);
|
||||
}
|
||||
if parsed_headers.txn_read_only {
|
||||
builder = builder.read_only(true);
|
||||
}
|
||||
if parsed_headers.txn_deferrable {
|
||||
builder = builder.deferrable(true);
|
||||
}
|
||||
|
||||
let mut transaction = builder
|
||||
.start()
|
||||
.await
|
||||
.inspect_err(|_| {
|
||||
// if we cannot start a transaction, we should return immediately
|
||||
// and not return to the pool. connection is clearly broken
|
||||
discard.discard();
|
||||
})
|
||||
.map_err(SqlOverHttpError::Postgres)?;
|
||||
|
||||
let json_output = match query_batch_to_json(
|
||||
config,
|
||||
cancel.child_token(),
|
||||
&mut transaction,
|
||||
self,
|
||||
parsed_headers,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(json_output) => {
|
||||
info!("commit");
|
||||
let status = transaction
|
||||
.commit()
|
||||
.await
|
||||
.inspect_err(|_| {
|
||||
// if we cannot commit - for now don't return connection to pool
|
||||
// TODO: get a query status from the error
|
||||
discard.discard();
|
||||
})
|
||||
.map_err(SqlOverHttpError::Postgres)?;
|
||||
discard.check_idle(status);
|
||||
json_output
|
||||
}
|
||||
Err(SqlOverHttpError::Cancelled(_)) => {
|
||||
if let Err(err) = cancel_token.cancel_query(NoTls).await {
|
||||
tracing::warn!(?err, "could not cancel query");
|
||||
}
|
||||
// TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
|
||||
discard.discard();
|
||||
|
||||
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
|
||||
}
|
||||
Err(err) => {
|
||||
info!("rollback");
|
||||
let status = transaction
|
||||
.rollback()
|
||||
.await
|
||||
.inspect_err(|_| {
|
||||
// if we cannot rollback - for now don't return connection to pool
|
||||
// TODO: get a query status from the error
|
||||
discard.discard();
|
||||
})
|
||||
.map_err(SqlOverHttpError::Postgres)?;
|
||||
discard.check_idle(status);
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
Ok(json_output)
|
||||
}
|
||||
}
|
||||
|
||||
async fn query_batch(
|
||||
config: &'static HttpConfig,
|
||||
cancel: CancellationToken,
|
||||
transaction: &mut Transaction<'_>,
|
||||
queries: BatchQueryData,
|
||||
parsed_headers: HttpHeaders,
|
||||
results: &mut json::ListSer<'_>,
|
||||
) -> Result<(), SqlOverHttpError> {
|
||||
for stmt in queries.queries {
|
||||
let query = pin!(query_to_json(
|
||||
config,
|
||||
transaction,
|
||||
stmt,
|
||||
results.entry(),
|
||||
parsed_headers,
|
||||
));
|
||||
let cancelled = pin!(cancel.cancelled());
|
||||
let res = select(query, cancelled).await;
|
||||
match res {
|
||||
// TODO: maybe we should check that the transaction bit is set here
|
||||
Either::Left((Ok(_), _cancelled)) => {}
|
||||
Either::Left((Err(e), _cancelled)) => {
|
||||
return Err(e);
|
||||
}
|
||||
Either::Right((_cancelled, _)) => {
|
||||
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
|
||||
}
|
||||
}
|
||||
async fn cancel_query(client: &mut Client) {
|
||||
let (inner, mut discard) = client.inner();
|
||||
let cancel_token = inner.cancel_token();
|
||||
|
||||
if let Err(err) = cancel_token.cancel_query(NoTls).await {
|
||||
tracing::warn!(?err, "could not cancel query");
|
||||
|
||||
// couldn't reach the server. let's just throw away this conn
|
||||
discard.discard();
|
||||
return;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
// wait for the query cancellation
|
||||
match time::timeout(time::Duration::from_millis(100), inner.wait_until_ready()).await {
|
||||
// we managed to cancel the query.
|
||||
Ok(Ok(_)) => {}
|
||||
// query failed or was cancelled.
|
||||
Ok(Err(error)) => {
|
||||
let db_error = error.as_db_error();
|
||||
|
||||
// if errored for some other reason, it might not be safe to reuse the connection.
|
||||
if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) {
|
||||
discard.discard();
|
||||
}
|
||||
}
|
||||
Err(_timeout) => {
|
||||
discard.discard();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn query_batch_to_json(
|
||||
async fn query_to_json(
|
||||
config: &'static HttpConfig,
|
||||
cancel: CancellationToken,
|
||||
tx: &mut Transaction<'_>,
|
||||
queries: BatchQueryData,
|
||||
headers: HttpHeaders,
|
||||
) -> Result<String, SqlOverHttpError> {
|
||||
let json_output = json::value_to_string!(|obj| json::value_as_object!(|obj| {
|
||||
let results = obj.key("results");
|
||||
json::value_as_list!(|results| {
|
||||
query_batch(config, cancel, tx, queries, headers, results).await?;
|
||||
});
|
||||
}));
|
||||
|
||||
Ok(json_output)
|
||||
}
|
||||
|
||||
async fn query_to_json<T: GenericClient>(
|
||||
config: &'static HttpConfig,
|
||||
client: &mut T,
|
||||
data: QueryData,
|
||||
cancel: &CancellationToken,
|
||||
client: &mut postgres_client::Client,
|
||||
data: &QueryData,
|
||||
output: json::ValueSer<'_>,
|
||||
parsed_headers: HttpHeaders,
|
||||
) -> Result<ReadyForQueryStatus, SqlOverHttpError> {
|
||||
let query_start = Instant::now();
|
||||
|
||||
let mut output = json::ObjectSer::new(output);
|
||||
let mut row_stream = client
|
||||
.query_raw_txt(&data.query, data.params)
|
||||
let params = data.params.iter().map(Option::as_deref);
|
||||
let mut row_stream = run_until_cancelled(client.query_raw_txt(&data.query, params), cancel)
|
||||
.await
|
||||
.ok_or(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))?
|
||||
.map_err(SqlOverHttpError::Postgres)?;
|
||||
|
||||
let query_acknowledged = Instant::now();
|
||||
|
||||
let mut output = json::ObjectSer::new(output);
|
||||
|
||||
let mut json_fields = output.key("fields").list();
|
||||
for c in row_stream.statement.columns() {
|
||||
let json_field = json_fields.entry();
|
||||
@@ -923,8 +811,13 @@ async fn query_to_json<T: GenericClient>(
|
||||
// big.
|
||||
let mut rows = 0;
|
||||
let mut json_rows = output.key("rows").list();
|
||||
while let Some(row) = row_stream.next().await {
|
||||
let row = row.map_err(SqlOverHttpError::Postgres)?;
|
||||
loop {
|
||||
let row = run_until_cancelled(row_stream.try_next(), cancel)
|
||||
.await
|
||||
.ok_or(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))?
|
||||
.map_err(SqlOverHttpError::Postgres)?;
|
||||
|
||||
let Some(row) = row else { break };
|
||||
|
||||
// we don't have a streaming response support yet so this is to prevent OOM
|
||||
// from a malicious query (eg a cross join)
|
||||
@@ -1012,12 +905,6 @@ impl Client {
|
||||
}
|
||||
|
||||
impl Discard<'_> {
|
||||
fn check_idle(&mut self, status: ReadyForQueryStatus) {
|
||||
match self {
|
||||
Discard::Remote(discard) => discard.check_idle(status),
|
||||
Discard::Local(discard) => discard.check_idle(status),
|
||||
}
|
||||
}
|
||||
fn discard(&mut self) {
|
||||
match self {
|
||||
Discard::Remote(discard) => discard.discard(),
|
||||
|
||||
@@ -1,23 +1,50 @@
|
||||
use std::pin::pin;
|
||||
use std::{
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use futures::future::{Either, select};
|
||||
use futures::FutureExt;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
pub async fn run_until_cancelled<F: Future>(
|
||||
pub fn run_until_cancelled<F: Future>(
|
||||
f: F,
|
||||
cancellation_token: &CancellationToken,
|
||||
) -> Option<F::Output> {
|
||||
run_until(f, cancellation_token.cancelled()).await.ok()
|
||||
) -> impl Future<Output = Option<F::Output>> {
|
||||
run_until(f, cancellation_token.cancelled()).map(|r| r.ok())
|
||||
}
|
||||
|
||||
/// Runs the future `f` unless interrupted by future `condition`.
|
||||
pub async fn run_until<F1: Future, F2: Future>(
|
||||
pub fn run_until<F1: Future, F2: Future>(
|
||||
f: F1,
|
||||
condition: F2,
|
||||
) -> Result<F1::Output, F2::Output> {
|
||||
match select(pin!(f), pin!(condition)).await {
|
||||
Either::Left((f1, _)) => Ok(f1),
|
||||
Either::Right((f2, _)) => Err(f2),
|
||||
) -> impl Future<Output = Result<F1::Output, F2::Output>> {
|
||||
RunUntil { a: f, b: condition }
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
struct RunUntil<A, B> {
|
||||
#[pin] a: A,
|
||||
#[pin] b: B,
|
||||
}
|
||||
}
|
||||
|
||||
impl<A, B> Future for RunUntil<A, B>
|
||||
where
|
||||
A: Future,
|
||||
B: Future,
|
||||
{
|
||||
type Output = Result<A::Output, B::Output>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.project();
|
||||
|
||||
if let Poll::Ready(a) = this.a.poll(cx) {
|
||||
return Poll::Ready(Ok(a));
|
||||
}
|
||||
if let Poll::Ready(b) = this.b.poll(cx) {
|
||||
return Poll::Ready(Err(b));
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3899,6 +3899,41 @@ class NeonProxy(PgProtocol):
|
||||
assert response.status_code == expected_code, f"response: {response.json()}"
|
||||
return response.json()
|
||||
|
||||
def http_multiquery(self, *queries, **kwargs):
|
||||
# TODO maybe use default values if not provided
|
||||
user = quote(kwargs["user"])
|
||||
password = quote(kwargs["password"])
|
||||
expected_code = kwargs.get("expected_code")
|
||||
timeout = kwargs.get("timeout")
|
||||
|
||||
json_queries = []
|
||||
for query in queries:
|
||||
if type(query) is str:
|
||||
json_queries.append({"query": query})
|
||||
else:
|
||||
[query, params] = query
|
||||
json_queries.append({"query": query, "params": params})
|
||||
|
||||
queries_str = [j["query"] for j in json_queries]
|
||||
log.info(f"Executing http queries: {queries_str}")
|
||||
|
||||
connstr = f"postgresql://{user}:{password}@{self.domain}:{self.proxy_port}/postgres"
|
||||
response = requests.post(
|
||||
f"https://{self.domain}:{self.external_http_port}/sql",
|
||||
data=json.dumps({"queries": json_queries}),
|
||||
headers={
|
||||
"Content-Type": "application/sql",
|
||||
"Neon-Connection-String": connstr,
|
||||
"Neon-Pool-Opt-In": "true",
|
||||
},
|
||||
verify=str(self.test_output_dir / "proxy.crt"),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
if expected_code is not None:
|
||||
assert response.status_code == expected_code, f"response: {response.json()}"
|
||||
return response.json()
|
||||
|
||||
async def http2_query(self, query, args, **kwargs):
|
||||
# TODO maybe use default values if not provided
|
||||
user = kwargs["user"]
|
||||
|
||||
@@ -17,9 +17,6 @@ if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
|
||||
GET_CONNECTION_PID_QUERY = "SELECT pid FROM pg_stat_activity WHERE state = 'active'"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_pool_begin_1(static_proxy: NeonProxy):
|
||||
static_proxy.safe_psql("create user http_auth with password 'http' superuser")
|
||||
@@ -479,7 +476,7 @@ def test_sql_over_http_pool(static_proxy: NeonProxy):
|
||||
|
||||
def get_pid(status: int, pw: str, user="http_auth") -> Any:
|
||||
return static_proxy.http_query(
|
||||
GET_CONNECTION_PID_QUERY,
|
||||
"SELECT pg_backend_pid() as pid",
|
||||
[],
|
||||
user=user,
|
||||
password=pw,
|
||||
@@ -513,6 +510,35 @@ def test_sql_over_http_pool(static_proxy: NeonProxy):
|
||||
assert "password authentication failed for user" in res["message"]
|
||||
|
||||
|
||||
def test_sql_over_http_pool_settings(static_proxy: NeonProxy):
|
||||
static_proxy.safe_psql("create user http_auth with password 'http' superuser")
|
||||
|
||||
def multiquery(*queries) -> Any:
|
||||
results = static_proxy.http_multiquery(
|
||||
*queries,
|
||||
user="http_auth",
|
||||
password="http",
|
||||
expected_code=200,
|
||||
)
|
||||
|
||||
return [result["rows"] for result in results["results"]]
|
||||
|
||||
[[intervalstyle]] = static_proxy.safe_psql("SHOW IntervalStyle")
|
||||
assert intervalstyle == "postgres", "'postgres' is the default IntervalStyle in postgres"
|
||||
|
||||
result = multiquery("select '0 seconds'::interval as interval")
|
||||
assert result[0][0]["interval"] == "00:00:00", "interval is expected in postgres format"
|
||||
|
||||
result = multiquery(
|
||||
"SET IntervalStyle = 'iso_8601'",
|
||||
"select '0 seconds'::interval as interval",
|
||||
)
|
||||
assert result[1][0]["interval"] == "PT0S", "interval is expected in ISO-8601 format"
|
||||
|
||||
result = multiquery("select '0 seconds'::interval as interval")
|
||||
assert result[0][0]["interval"] == "00:00:00", "interval is expected in postgres format"
|
||||
|
||||
|
||||
def test_sql_over_http_urlencoding(static_proxy: NeonProxy):
|
||||
static_proxy.safe_psql("create user \"http+auth$$\" with password '%+$^&*@!' superuser")
|
||||
|
||||
@@ -544,23 +570,37 @@ def test_http_pool_begin(static_proxy: NeonProxy):
|
||||
query(200, "SELECT 1;") # Query that should succeed regardless of the transaction
|
||||
|
||||
|
||||
def test_sql_over_http_pool_idle(static_proxy: NeonProxy):
|
||||
def test_sql_over_http_pool_tx_reuse(static_proxy: NeonProxy):
|
||||
static_proxy.safe_psql("create user http_auth2 with password 'http' superuser")
|
||||
|
||||
def query(status: int, query: str) -> Any:
|
||||
def query(status: int, query: str, *args) -> Any:
|
||||
return static_proxy.http_query(
|
||||
query,
|
||||
[],
|
||||
args,
|
||||
user="http_auth2",
|
||||
password="http",
|
||||
expected_code=status,
|
||||
)
|
||||
|
||||
pid1 = query(200, GET_CONNECTION_PID_QUERY)["rows"][0]["pid"]
|
||||
def query_pid_txid() -> Any:
|
||||
result = query(
|
||||
200,
|
||||
"SELECT pg_backend_pid() as pid, pg_current_xact_id() as txid",
|
||||
)
|
||||
|
||||
return result["rows"][0]
|
||||
|
||||
res0 = query_pid_txid()
|
||||
|
||||
time.sleep(0.02)
|
||||
query(200, "BEGIN")
|
||||
pid2 = query(200, GET_CONNECTION_PID_QUERY)["rows"][0]["pid"]
|
||||
assert pid1 != pid2
|
||||
|
||||
res1 = query_pid_txid()
|
||||
res2 = query_pid_txid()
|
||||
|
||||
assert res0["pid"] == res1["pid"], "connection should be reused"
|
||||
assert res0["pid"] == res2["pid"], "connection should be reused"
|
||||
assert res1["txid"] != res2["txid"], "txid should be different"
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
|
||||
Reference in New Issue
Block a user