From a695713727f77891cf3cc85077c44dee4d7c84fd Mon Sep 17 00:00:00 2001
From: Conrad Ludgate
Date: Wed, 23 Jul 2025 18:43:43 +0100
Subject: [PATCH] [sql-over-http] Reset session state between pooled connection re-use (#12681)

Session variables can be set during one sql-over-http query and observed by another when that pooled connection is re-used. To address this, we can run `RESET ALL;` before re-using the connection.

LKB-2495

To be on the safe side, we can opt for a full `DISCARD ALL;`, but that might cause performance regressions since it also clears any cached query plans. See the pgbouncer docs: https://www.pgbouncer.org/config.html#server_reset_query.

`DISCARD ALL` is currently defined as:

```
CLOSE ALL;
SET SESSION AUTHORIZATION DEFAULT;
RESET ALL;
DEALLOCATE ALL;
UNLISTEN *;
SELECT pg_advisory_unlock_all();
DISCARD PLANS;
DISCARD TEMP;
DISCARD SEQUENCES;
```

I've opted to keep everything here except `DISCARD PLANS`.

I've modified the code so that this query is executed in the background when a connection is returned to the pool, rather than when it is taken from the pool. This should marginally improve performance for Neon RLS by removing one (localhost) round trip.

I don't believe that keeping query plans is a security concern. It's a potential side channel, but I can't imagine what you could extract from it.

---

Thanks to https://github.com/neondatabase/neon/pull/12659#discussion_r2219016205 for planting the idea in my head.
---
 libs/proxy/tokio-postgres2/src/client.rs | 28 ++++++++++-
 proxy/src/serverless/backend.rs          | 19 +++++---
 proxy/src/serverless/conn_pool.rs        |  3 ++
 proxy/src/serverless/conn_pool_lib.rs    | 25 ++++++----
 proxy/src/serverless/http_conn_pool.rs   | 20 +++++---
 proxy/src/serverless/local_conn_pool.rs  |  5 --
 proxy/src/serverless/rest.rs             | 10 ++--
 proxy/src/serverless/sql_over_http.rs    | 24 +---------
 test_runner/fixtures/neon_fixtures.py    | 35 ++++++++++++++
 test_runner/regress/test_proxy.py        | 60 ++++++++++++++++++----
 10 files changed, 161 insertions(+), 68 deletions(-)

diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs
index 068566e955..f8aceb5263 100644
--- a/libs/proxy/tokio-postgres2/src/client.rs
+++ b/libs/proxy/tokio-postgres2/src/client.rs
@@ -292,8 +292,32 @@ impl Client {
         simple_query::batch_execute(self.inner_mut(), query).await
     }
 
-    pub async fn discard_all(&mut self) -> Result<ReadyForQueryStatus, Error> {
-        self.batch_execute("discard all").await
+    /// Similar to `discard_all`, but it does not clear any query plans.
+    ///
+    /// This runs in the background, so it can be executed without `await`ing.
+    pub fn reset_session_background(&mut self) -> Result<(), Error> {
+        // "CLOSE ALL": closes any cursors
+        // "SET SESSION AUTHORIZATION DEFAULT": resets the current_user back to the session_user
+        // "RESET ALL": resets any GUCs back to their session defaults
+        // "DEALLOCATE ALL": deallocates any prepared statements
+        // "UNLISTEN *": stops listening on all channels
+        // "SELECT pg_advisory_unlock_all();": unlocks all advisory locks
+        // "DISCARD TEMP;": drops all temporary tables
+        // "DISCARD SEQUENCES;": deallocates all cached sequence state
+
+        let _responses = self.inner_mut().send_simple_query(
+            "ROLLBACK;
+            CLOSE ALL;
+            SET SESSION AUTHORIZATION DEFAULT;
+            RESET ALL;
+            DEALLOCATE ALL;
+            UNLISTEN *;
+            SELECT pg_advisory_unlock_all();
+            DISCARD TEMP;
+            DISCARD SEQUENCES;",
+        )?;
+
+        Ok(())
     }
 
     /// Begins a new database transaction.
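For reviewers who want to see the failure mode end-to-end, here is a minimal repro sketch against the sql-over-http endpoint. It is not part of the patch: the endpoint URL, user and connection string are placeholders, while the header names, request shape and response shape mirror the test fixtures further down. With `Neon-Pool-Opt-In` enabled, a GUC set by one request used to remain visible to a later request that picked up the same pooled connection; with the reset performed when the connection is returned to the pool, the later request sees the session default again.

```python
import json

import requests

URL = "https://proxy.example.com/sql"  # placeholder sql-over-http endpoint
HEADERS = {
    "Content-Type": "application/sql",
    "Neon-Connection-String": "postgresql://user:password@ep.example.com/postgres",  # placeholder
    "Neon-Pool-Opt-In": "true",  # opt into connection pooling
}


def run(query: str) -> dict:
    resp = requests.post(
        URL,
        data=json.dumps({"query": query, "params": []}),
        headers=HEADERS,
        timeout=10,
    )
    resp.raise_for_status()
    return resp.json()


# First request mutates session state on whichever pooled connection serves it.
run("SET IntervalStyle = 'iso_8601'")

# A later request may re-use that pooled connection. Before this patch it could
# observe the leaked setting and print "PT0S"; with "RESET ALL" run when the
# connection goes back to the pool, it prints the default "00:00:00".
out = run("SELECT '0 seconds'::interval AS i")
print(out["rows"][0]["i"])
```

The same behaviour is asserted by `test_sql_over_http_pool_settings` added in this patch.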
diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 59e4b09bc9..31df7eb9f1 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -18,7 +18,7 @@ use tracing::{debug, info}; use super::AsyncRW; use super::conn_pool::poll_client; use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool}; -use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client}; +use super::http_conn_pool::{self, HttpConnPool, LocalProxyClient, poll_http2_client}; use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool}; use crate::auth::backend::local::StaticAuthRules; use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo}; @@ -40,7 +40,8 @@ use crate::rate_limiter::EndpointRateLimiter; use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX}; pub(crate) struct PoolingBackend { - pub(crate) http_conn_pool: Arc>>, + pub(crate) http_conn_pool: + Arc>>, pub(crate) local_pool: Arc>, pub(crate) pool: Arc>>, @@ -210,7 +211,7 @@ impl PoolingBackend { &self, ctx: &RequestContext, conn_info: ConnInfo, - ) -> Result, HttpConnError> { + ) -> Result, HttpConnError> { debug!("pool: looking for an existing connection"); if let Ok(Some(client)) = self.http_conn_pool.get(ctx, &conn_info) { return Ok(client); @@ -568,7 +569,7 @@ impl ConnectMechanism for TokioMechanism { } struct HyperMechanism { - pool: Arc>>, + pool: Arc>>, conn_info: ConnInfo, conn_id: uuid::Uuid, @@ -578,7 +579,7 @@ struct HyperMechanism { #[async_trait] impl ConnectMechanism for HyperMechanism { - type Connection = http_conn_pool::Client; + type Connection = http_conn_pool::Client; type ConnectError = HttpConnError; type Error = HttpConnError; @@ -632,7 +633,13 @@ async fn connect_http2( port: u16, timeout: Duration, tls: Option<&Arc>, -) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> { +) -> Result< + ( + http_conn_pool::LocalProxyClient, + http_conn_pool::LocalProxyConnection, + ), + LocalProxyConnError, +> { let addrs = match host_addr { Some(addr) => vec![SocketAddr::new(addr, port)], None => lookup_host((host, port)) diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 015c46f787..17305e30f1 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -190,6 +190,9 @@ mod tests { fn get_process_id(&self) -> i32 { 0 } + fn reset(&mut self) -> Result<(), postgres_client::Error> { + Ok(()) + } } fn create_inner() -> ClientInnerCommon { diff --git a/proxy/src/serverless/conn_pool_lib.rs b/proxy/src/serverless/conn_pool_lib.rs index ed5cc0ea03..6adca49723 100644 --- a/proxy/src/serverless/conn_pool_lib.rs +++ b/proxy/src/serverless/conn_pool_lib.rs @@ -7,10 +7,9 @@ use std::time::Duration; use clashmap::ClashMap; use parking_lot::RwLock; -use postgres_client::ReadyForQueryStatus; use rand::Rng; use smol_str::ToSmolStr; -use tracing::{Span, debug, info}; +use tracing::{Span, debug, info, warn}; use super::backend::HttpConnError; use super::conn_pool::ClientDataRemote; @@ -188,7 +187,7 @@ impl EndpointConnPool { self.pools.get_mut(&db_user) } - pub(crate) fn put(pool: &RwLock, conn_info: &ConnInfo, client: ClientInnerCommon) { + pub(crate) fn put(pool: &RwLock, conn_info: &ConnInfo, mut client: ClientInnerCommon) { let conn_id = client.get_conn_id(); let (max_conn, conn_count, pool_name) = { let pool = pool.read(); @@ -201,12 +200,17 @@ impl EndpointConnPool { }; if client.inner.is_closed() { - info!(%conn_id, "{}: 
throwing away connection '{conn_info}' because connection is closed", pool_name); + info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because connection is closed"); + return; + } + + if let Err(error) = client.inner.reset() { + warn!(?error, %conn_id, "{pool_name}: throwing away connection '{conn_info}' because connection could not be reset"); return; } if conn_count >= max_conn { - info!(%conn_id, "{}: throwing away connection '{conn_info}' because pool is full", pool_name); + info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because pool is full"); return; } @@ -691,6 +695,7 @@ impl Deref for Client { pub(crate) trait ClientInnerExt: Sync + Send + 'static { fn is_closed(&self) -> bool; fn get_process_id(&self) -> i32; + fn reset(&mut self) -> Result<(), postgres_client::Error>; } impl ClientInnerExt for postgres_client::Client { @@ -701,15 +706,13 @@ impl ClientInnerExt for postgres_client::Client { fn get_process_id(&self) -> i32 { self.get_process_id() } + + fn reset(&mut self) -> Result<(), postgres_client::Error> { + self.reset_session_background() + } } impl Discard<'_, C> { - pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) { - let conn_info = &self.conn_info; - if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 { - info!("pool: throwing away connection '{conn_info}' because connection is not idle"); - } - } pub(crate) fn discard(&mut self) { let conn_info = &self.conn_info; if std::mem::take(self.pool).strong_count() > 0 { diff --git a/proxy/src/serverless/http_conn_pool.rs b/proxy/src/serverless/http_conn_pool.rs index 7acd816026..bf6b934d20 100644 --- a/proxy/src/serverless/http_conn_pool.rs +++ b/proxy/src/serverless/http_conn_pool.rs @@ -23,8 +23,8 @@ use crate::protocol2::ConnectionInfoExtra; use crate::types::EndpointCacheKey; use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; -pub(crate) type Send = http2::SendRequest>; -pub(crate) type Connect = +pub(crate) type LocalProxyClient = http2::SendRequest>; +pub(crate) type LocalProxyConnection = http2::Connection, BoxBody, TokioExecutor>; #[derive(Clone)] @@ -189,14 +189,14 @@ impl GlobalConnPool> { } pub(crate) fn poll_http2_client( - global_pool: Arc>>, + global_pool: Arc>>, ctx: &RequestContext, conn_info: &ConnInfo, - client: Send, - connection: Connect, + client: LocalProxyClient, + connection: LocalProxyConnection, conn_id: uuid::Uuid, aux: MetricsAuxInfo, -) -> Client { +) -> Client { let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol()); let session_id = ctx.session_id(); @@ -285,7 +285,7 @@ impl Client { } } -impl ClientInnerExt for Send { +impl ClientInnerExt for LocalProxyClient { fn is_closed(&self) -> bool { self.is_closed() } @@ -294,4 +294,10 @@ impl ClientInnerExt for Send { // ideally throw something meaningful -1 } + + fn reset(&mut self) -> Result<(), postgres_client::Error> { + // We use HTTP/2.0 to talk to local proxy. HTTP is stateless, + // so there's nothing to reset. 
+ Ok(()) + } } diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs index f63d84d66b..b8a502c37e 100644 --- a/proxy/src/serverless/local_conn_pool.rs +++ b/proxy/src/serverless/local_conn_pool.rs @@ -269,11 +269,6 @@ impl ClientInnerCommon { local_data.jti += 1; let token = resign_jwt(&local_data.key, payload, local_data.jti)?; - self.inner - .discard_all() - .await - .map_err(SqlOverHttpError::InternalPostgres)?; - // initiates the auth session // this is safe from query injections as the jwt format free of any escape characters. let query = format!("select auth.jwt_session_init('{token}')"); diff --git a/proxy/src/serverless/rest.rs b/proxy/src/serverless/rest.rs index 173c2629f7..c9b5e99747 100644 --- a/proxy/src/serverless/rest.rs +++ b/proxy/src/serverless/rest.rs @@ -46,7 +46,7 @@ use super::backend::{HttpConnError, LocalProxyConnError, PoolingBackend}; use super::conn_pool::AuthData; use super::conn_pool_lib::ConnInfo; use super::error::{ConnInfoError, Credentials, HttpCodeError, ReadPayloadError}; -use super::http_conn_pool::{self, Send}; +use super::http_conn_pool::{self, LocalProxyClient}; use super::http_util::{ ALLOW_POOL, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_ISOLATION_LEVEL, TXN_READ_ONLY, get_conn_info, json_response, uuid_to_header_value, @@ -145,7 +145,7 @@ impl DbSchemaCache { endpoint_id: &EndpointCacheKey, auth_header: &HeaderValue, connection_string: &str, - client: &mut http_conn_pool::Client, + client: &mut http_conn_pool::Client, ctx: &RequestContext, config: &'static ProxyConfig, ) -> Result, RestError> { @@ -190,7 +190,7 @@ impl DbSchemaCache { &self, auth_header: &HeaderValue, connection_string: &str, - client: &mut http_conn_pool::Client, + client: &mut http_conn_pool::Client, ctx: &RequestContext, config: &'static ProxyConfig, ) -> Result<(ApiConfig, DbSchemaOwned), RestError> { @@ -430,7 +430,7 @@ struct BatchQueryData<'a> { } async fn make_local_proxy_request( - client: &mut http_conn_pool::Client, + client: &mut http_conn_pool::Client, headers: impl IntoIterator, body: QueryData<'_>, max_len: usize, @@ -461,7 +461,7 @@ async fn make_local_proxy_request( } async fn make_raw_local_proxy_request( - client: &mut http_conn_pool::Client, + client: &mut http_conn_pool::Client, headers: impl IntoIterator, body: String, ) -> Result, RestError> { diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index f254b41b5b..26f65379e7 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -735,9 +735,7 @@ impl QueryData { match batch_result { // The query successfully completed. 
- Ok(status) => { - discard.check_idle(status); - + Ok(_) => { let json_output = String::from_utf8(json_buf).expect("json should be valid utf8"); Ok(json_output) } @@ -793,7 +791,7 @@ impl BatchQueryData { { Ok(json_output) => { info!("commit"); - let status = transaction + transaction .commit() .await .inspect_err(|_| { @@ -802,7 +800,6 @@ impl BatchQueryData { discard.discard(); }) .map_err(SqlOverHttpError::Postgres)?; - discard.check_idle(status); json_output } Err(SqlOverHttpError::Cancelled(_)) => { @@ -815,17 +812,6 @@ impl BatchQueryData { return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres)); } Err(err) => { - info!("rollback"); - let status = transaction - .rollback() - .await - .inspect_err(|_| { - // if we cannot rollback - for now don't return connection to pool - // TODO: get a query status from the error - discard.discard(); - }) - .map_err(SqlOverHttpError::Postgres)?; - discard.check_idle(status); return Err(err); } }; @@ -1012,12 +998,6 @@ impl Client { } impl Discard<'_> { - fn check_idle(&mut self, status: ReadyForQueryStatus) { - match self { - Discard::Remote(discard) => discard.check_idle(status), - Discard::Local(discard) => discard.check_idle(status), - } - } fn discard(&mut self) { match self { Discard::Remote(discard) => discard.discard(), diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index f7917f214a..33a18e4394 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -3910,6 +3910,41 @@ class NeonProxy(PgProtocol): assert response.status_code == expected_code, f"response: {response.json()}" return response.json() + def http_multiquery(self, *queries, **kwargs): + # TODO maybe use default values if not provided + user = quote(kwargs["user"]) + password = quote(kwargs["password"]) + expected_code = kwargs.get("expected_code") + timeout = kwargs.get("timeout") + + json_queries = [] + for query in queries: + if type(query) is str: + json_queries.append({"query": query}) + else: + [query, params] = query + json_queries.append({"query": query, "params": params}) + + queries_str = [j["query"] for j in json_queries] + log.info(f"Executing http queries: {queries_str}") + + connstr = f"postgresql://{user}:{password}@{self.domain}:{self.proxy_port}/postgres" + response = requests.post( + f"https://{self.domain}:{self.external_http_port}/sql", + data=json.dumps({"queries": json_queries}), + headers={ + "Content-Type": "application/sql", + "Neon-Connection-String": connstr, + "Neon-Pool-Opt-In": "true", + }, + verify=str(self.test_output_dir / "proxy.crt"), + timeout=timeout, + ) + + if expected_code is not None: + assert response.status_code == expected_code, f"response: {response.json()}" + return response.json() + async def http2_query(self, query, args, **kwargs): # TODO maybe use default values if not provided user = kwargs["user"] diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index 9860658ba5..dadaf8a1cf 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -17,9 +17,6 @@ if TYPE_CHECKING: from typing import Any -GET_CONNECTION_PID_QUERY = "SELECT pid FROM pg_stat_activity WHERE state = 'active'" - - @pytest.mark.asyncio async def test_http_pool_begin_1(static_proxy: NeonProxy): static_proxy.safe_psql("create user http_auth with password 'http' superuser") @@ -479,7 +476,7 @@ def test_sql_over_http_pool(static_proxy: NeonProxy): def get_pid(status: int, pw: str, user="http_auth") -> Any: return 
static_proxy.http_query( - GET_CONNECTION_PID_QUERY, + "SELECT pg_backend_pid() as pid", [], user=user, password=pw, @@ -513,6 +510,35 @@ def test_sql_over_http_pool(static_proxy: NeonProxy): assert "password authentication failed for user" in res["message"] +def test_sql_over_http_pool_settings(static_proxy: NeonProxy): + static_proxy.safe_psql("create user http_auth with password 'http' superuser") + + def multiquery(*queries) -> Any: + results = static_proxy.http_multiquery( + *queries, + user="http_auth", + password="http", + expected_code=200, + ) + + return [result["rows"] for result in results["results"]] + + [[intervalstyle]] = static_proxy.safe_psql("SHOW IntervalStyle") + assert intervalstyle == "postgres", "'postgres' is the default IntervalStyle in postgres" + + result = multiquery("select '0 seconds'::interval as interval") + assert result[0][0]["interval"] == "00:00:00", "interval is expected in postgres format" + + result = multiquery( + "SET IntervalStyle = 'iso_8601'", + "select '0 seconds'::interval as interval", + ) + assert result[1][0]["interval"] == "PT0S", "interval is expected in ISO-8601 format" + + result = multiquery("select '0 seconds'::interval as interval") + assert result[0][0]["interval"] == "00:00:00", "interval is expected in postgres format" + + def test_sql_over_http_urlencoding(static_proxy: NeonProxy): static_proxy.safe_psql("create user \"http+auth$$\" with password '%+$^&*@!' superuser") @@ -544,23 +570,37 @@ def test_http_pool_begin(static_proxy: NeonProxy): query(200, "SELECT 1;") # Query that should succeed regardless of the transaction -def test_sql_over_http_pool_idle(static_proxy: NeonProxy): +def test_sql_over_http_pool_tx_reuse(static_proxy: NeonProxy): static_proxy.safe_psql("create user http_auth2 with password 'http' superuser") - def query(status: int, query: str) -> Any: + def query(status: int, query: str, *args) -> Any: return static_proxy.http_query( query, - [], + args, user="http_auth2", password="http", expected_code=status, ) - pid1 = query(200, GET_CONNECTION_PID_QUERY)["rows"][0]["pid"] + def query_pid_txid() -> Any: + result = query( + 200, + "SELECT pg_backend_pid() as pid, pg_current_xact_id() as txid", + ) + + return result["rows"][0] + + res0 = query_pid_txid() + time.sleep(0.02) query(200, "BEGIN") - pid2 = query(200, GET_CONNECTION_PID_QUERY)["rows"][0]["pid"] - assert pid1 != pid2 + + res1 = query_pid_txid() + res2 = query_pid_txid() + + assert res0["pid"] == res1["pid"], "connection should be reused" + assert res0["pid"] == res2["pid"], "connection should be reused" + assert res1["txid"] != res2["txid"], "txid should be different" @pytest.mark.timeout(60)