mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-07 21:42:56 +00:00
[sql-over-http] Reset session state between pooled connection re-use (#12681)
Session variables can be set during one sql-over-http query and observed on another when that pooled connection is re-used. To address this we can use `RESET ALL;` before re-using the connection. LKB-2495 To be on the safe side, we can opt for a full `DISCARD ALL;`, but that might have performance regressions since it also clears any query plans. See pgbouncer docs https://www.pgbouncer.org/config.html#server_reset_query. `DISCARD ALL` is currently defined as: ``` CLOSE ALL; SET SESSION AUTHORIZATION DEFAULT; RESET ALL; DEALLOCATE ALL; UNLISTEN *; SELECT pg_advisory_unlock_all(); DISCARD PLANS; DISCARD TEMP; DISCARD SEQUENCES; ``` I've opted to keep everything here except the `DISCARD PLANS`. I've modified the code so that this query is executed in the background when a connection is returned to the pool, rather than when taken from the pool. This should marginally improve performance for Neon RLS by removing 1 (localhost) round trip. I don't believe that keeping query plans could be a security concern. It's a potential side channel, but I can't imagine what you could extract from it. --- Thanks to https://github.com/neondatabase/neon/pull/12659#discussion_r2219016205 for probing the idea in my head.
This commit is contained in:
@@ -18,7 +18,7 @@ use tracing::{debug, info};
|
||||
use super::AsyncRW;
|
||||
use super::conn_pool::poll_client;
|
||||
use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool};
|
||||
use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client};
|
||||
use super::http_conn_pool::{self, HttpConnPool, LocalProxyClient, poll_http2_client};
|
||||
use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool};
|
||||
use crate::auth::backend::local::StaticAuthRules;
|
||||
use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo};
|
||||
@@ -40,7 +40,8 @@ use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
|
||||
|
||||
pub(crate) struct PoolingBackend {
|
||||
pub(crate) http_conn_pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
|
||||
pub(crate) http_conn_pool:
|
||||
Arc<GlobalConnPool<LocalProxyClient, HttpConnPool<LocalProxyClient>>>,
|
||||
pub(crate) local_pool: Arc<LocalConnPool<postgres_client::Client>>,
|
||||
pub(crate) pool:
|
||||
Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
|
||||
@@ -210,7 +211,7 @@ impl PoolingBackend {
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
conn_info: ConnInfo,
|
||||
) -> Result<http_conn_pool::Client<Send>, HttpConnError> {
|
||||
) -> Result<http_conn_pool::Client<LocalProxyClient>, HttpConnError> {
|
||||
debug!("pool: looking for an existing connection");
|
||||
if let Ok(Some(client)) = self.http_conn_pool.get(ctx, &conn_info) {
|
||||
return Ok(client);
|
||||
@@ -568,7 +569,7 @@ impl ConnectMechanism for TokioMechanism {
|
||||
}
|
||||
|
||||
struct HyperMechanism {
|
||||
pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
|
||||
pool: Arc<GlobalConnPool<LocalProxyClient, HttpConnPool<LocalProxyClient>>>,
|
||||
conn_info: ConnInfo,
|
||||
conn_id: uuid::Uuid,
|
||||
|
||||
@@ -578,7 +579,7 @@ struct HyperMechanism {
|
||||
|
||||
#[async_trait]
|
||||
impl ConnectMechanism for HyperMechanism {
|
||||
type Connection = http_conn_pool::Client<Send>;
|
||||
type Connection = http_conn_pool::Client<LocalProxyClient>;
|
||||
type ConnectError = HttpConnError;
|
||||
type Error = HttpConnError;
|
||||
|
||||
@@ -632,7 +633,13 @@ async fn connect_http2(
|
||||
port: u16,
|
||||
timeout: Duration,
|
||||
tls: Option<&Arc<rustls::ClientConfig>>,
|
||||
) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> {
|
||||
) -> Result<
|
||||
(
|
||||
http_conn_pool::LocalProxyClient,
|
||||
http_conn_pool::LocalProxyConnection,
|
||||
),
|
||||
LocalProxyConnError,
|
||||
> {
|
||||
let addrs = match host_addr {
|
||||
Some(addr) => vec![SocketAddr::new(addr, port)],
|
||||
None => lookup_host((host, port))
|
||||
|
||||
@@ -190,6 +190,9 @@ mod tests {
|
||||
fn get_process_id(&self) -> i32 {
|
||||
0
|
||||
}
|
||||
fn reset(&mut self) -> Result<(), postgres_client::Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn create_inner() -> ClientInnerCommon<MockClient> {
|
||||
|
||||
@@ -7,10 +7,9 @@ use std::time::Duration;
|
||||
|
||||
use clashmap::ClashMap;
|
||||
use parking_lot::RwLock;
|
||||
use postgres_client::ReadyForQueryStatus;
|
||||
use rand::Rng;
|
||||
use smol_str::ToSmolStr;
|
||||
use tracing::{Span, debug, info};
|
||||
use tracing::{Span, debug, info, warn};
|
||||
|
||||
use super::backend::HttpConnError;
|
||||
use super::conn_pool::ClientDataRemote;
|
||||
@@ -188,7 +187,7 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
|
||||
self.pools.get_mut(&db_user)
|
||||
}
|
||||
|
||||
pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInnerCommon<C>) {
|
||||
pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, mut client: ClientInnerCommon<C>) {
|
||||
let conn_id = client.get_conn_id();
|
||||
let (max_conn, conn_count, pool_name) = {
|
||||
let pool = pool.read();
|
||||
@@ -201,12 +200,17 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
|
||||
};
|
||||
|
||||
if client.inner.is_closed() {
|
||||
info!(%conn_id, "{}: throwing away connection '{conn_info}' because connection is closed", pool_name);
|
||||
info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because connection is closed");
|
||||
return;
|
||||
}
|
||||
|
||||
if let Err(error) = client.inner.reset() {
|
||||
warn!(?error, %conn_id, "{pool_name}: throwing away connection '{conn_info}' because connection could not be reset");
|
||||
return;
|
||||
}
|
||||
|
||||
if conn_count >= max_conn {
|
||||
info!(%conn_id, "{}: throwing away connection '{conn_info}' because pool is full", pool_name);
|
||||
info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because pool is full");
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -691,6 +695,7 @@ impl<C: ClientInnerExt> Deref for Client<C> {
|
||||
pub(crate) trait ClientInnerExt: Sync + Send + 'static {
|
||||
fn is_closed(&self) -> bool;
|
||||
fn get_process_id(&self) -> i32;
|
||||
fn reset(&mut self) -> Result<(), postgres_client::Error>;
|
||||
}
|
||||
|
||||
impl ClientInnerExt for postgres_client::Client {
|
||||
@@ -701,15 +706,13 @@ impl ClientInnerExt for postgres_client::Client {
|
||||
fn get_process_id(&self) -> i32 {
|
||||
self.get_process_id()
|
||||
}
|
||||
|
||||
fn reset(&mut self) -> Result<(), postgres_client::Error> {
|
||||
self.reset_session_background()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Discard<'_, C> {
|
||||
pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
|
||||
let conn_info = &self.conn_info;
|
||||
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
|
||||
info!("pool: throwing away connection '{conn_info}' because connection is not idle");
|
||||
}
|
||||
}
|
||||
pub(crate) fn discard(&mut self) {
|
||||
let conn_info = &self.conn_info;
|
||||
if std::mem::take(self.pool).strong_count() > 0 {
|
||||
|
||||
@@ -23,8 +23,8 @@ use crate::protocol2::ConnectionInfoExtra;
|
||||
use crate::types::EndpointCacheKey;
|
||||
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
|
||||
|
||||
pub(crate) type Send = http2::SendRequest<BoxBody<Bytes, hyper::Error>>;
|
||||
pub(crate) type Connect =
|
||||
pub(crate) type LocalProxyClient = http2::SendRequest<BoxBody<Bytes, hyper::Error>>;
|
||||
pub(crate) type LocalProxyConnection =
|
||||
http2::Connection<TokioIo<AsyncRW>, BoxBody<Bytes, hyper::Error>, TokioExecutor>;
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -189,14 +189,14 @@ impl<C: ClientInnerExt + Clone> GlobalConnPool<C, HttpConnPool<C>> {
|
||||
}
|
||||
|
||||
pub(crate) fn poll_http2_client(
|
||||
global_pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
|
||||
global_pool: Arc<GlobalConnPool<LocalProxyClient, HttpConnPool<LocalProxyClient>>>,
|
||||
ctx: &RequestContext,
|
||||
conn_info: &ConnInfo,
|
||||
client: Send,
|
||||
connection: Connect,
|
||||
client: LocalProxyClient,
|
||||
connection: LocalProxyConnection,
|
||||
conn_id: uuid::Uuid,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> Client<Send> {
|
||||
) -> Client<LocalProxyClient> {
|
||||
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
|
||||
let session_id = ctx.session_id();
|
||||
|
||||
@@ -285,7 +285,7 @@ impl<C: ClientInnerExt + Clone> Client<C> {
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientInnerExt for Send {
|
||||
impl ClientInnerExt for LocalProxyClient {
|
||||
fn is_closed(&self) -> bool {
|
||||
self.is_closed()
|
||||
}
|
||||
@@ -294,4 +294,10 @@ impl ClientInnerExt for Send {
|
||||
// ideally throw something meaningful
|
||||
-1
|
||||
}
|
||||
|
||||
fn reset(&mut self) -> Result<(), postgres_client::Error> {
|
||||
// We use HTTP/2.0 to talk to local proxy. HTTP is stateless,
|
||||
// so there's nothing to reset.
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -269,11 +269,6 @@ impl ClientInnerCommon<postgres_client::Client> {
|
||||
local_data.jti += 1;
|
||||
let token = resign_jwt(&local_data.key, payload, local_data.jti)?;
|
||||
|
||||
self.inner
|
||||
.discard_all()
|
||||
.await
|
||||
.map_err(SqlOverHttpError::InternalPostgres)?;
|
||||
|
||||
// initiates the auth session
|
||||
// this is safe from query injections as the jwt format free of any escape characters.
|
||||
let query = format!("select auth.jwt_session_init('{token}')");
|
||||
|
||||
@@ -46,7 +46,7 @@ use super::backend::{HttpConnError, LocalProxyConnError, PoolingBackend};
|
||||
use super::conn_pool::AuthData;
|
||||
use super::conn_pool_lib::ConnInfo;
|
||||
use super::error::{ConnInfoError, Credentials, HttpCodeError, ReadPayloadError};
|
||||
use super::http_conn_pool::{self, Send};
|
||||
use super::http_conn_pool::{self, LocalProxyClient};
|
||||
use super::http_util::{
|
||||
ALLOW_POOL, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_ISOLATION_LEVEL, TXN_READ_ONLY,
|
||||
get_conn_info, json_response, uuid_to_header_value,
|
||||
@@ -145,7 +145,7 @@ impl DbSchemaCache {
|
||||
endpoint_id: &EndpointCacheKey,
|
||||
auth_header: &HeaderValue,
|
||||
connection_string: &str,
|
||||
client: &mut http_conn_pool::Client<Send>,
|
||||
client: &mut http_conn_pool::Client<LocalProxyClient>,
|
||||
ctx: &RequestContext,
|
||||
config: &'static ProxyConfig,
|
||||
) -> Result<Arc<(ApiConfig, DbSchemaOwned)>, RestError> {
|
||||
@@ -190,7 +190,7 @@ impl DbSchemaCache {
|
||||
&self,
|
||||
auth_header: &HeaderValue,
|
||||
connection_string: &str,
|
||||
client: &mut http_conn_pool::Client<Send>,
|
||||
client: &mut http_conn_pool::Client<LocalProxyClient>,
|
||||
ctx: &RequestContext,
|
||||
config: &'static ProxyConfig,
|
||||
) -> Result<(ApiConfig, DbSchemaOwned), RestError> {
|
||||
@@ -430,7 +430,7 @@ struct BatchQueryData<'a> {
|
||||
}
|
||||
|
||||
async fn make_local_proxy_request<S: DeserializeOwned>(
|
||||
client: &mut http_conn_pool::Client<Send>,
|
||||
client: &mut http_conn_pool::Client<LocalProxyClient>,
|
||||
headers: impl IntoIterator<Item = (&HeaderName, HeaderValue)>,
|
||||
body: QueryData<'_>,
|
||||
max_len: usize,
|
||||
@@ -461,7 +461,7 @@ async fn make_local_proxy_request<S: DeserializeOwned>(
|
||||
}
|
||||
|
||||
async fn make_raw_local_proxy_request(
|
||||
client: &mut http_conn_pool::Client<Send>,
|
||||
client: &mut http_conn_pool::Client<LocalProxyClient>,
|
||||
headers: impl IntoIterator<Item = (&HeaderName, HeaderValue)>,
|
||||
body: String,
|
||||
) -> Result<Response<Incoming>, RestError> {
|
||||
|
||||
@@ -735,9 +735,7 @@ impl QueryData {
|
||||
|
||||
match batch_result {
|
||||
// The query successfully completed.
|
||||
Ok(status) => {
|
||||
discard.check_idle(status);
|
||||
|
||||
Ok(_) => {
|
||||
let json_output = String::from_utf8(json_buf).expect("json should be valid utf8");
|
||||
Ok(json_output)
|
||||
}
|
||||
@@ -793,7 +791,7 @@ impl BatchQueryData {
|
||||
{
|
||||
Ok(json_output) => {
|
||||
info!("commit");
|
||||
let status = transaction
|
||||
transaction
|
||||
.commit()
|
||||
.await
|
||||
.inspect_err(|_| {
|
||||
@@ -802,7 +800,6 @@ impl BatchQueryData {
|
||||
discard.discard();
|
||||
})
|
||||
.map_err(SqlOverHttpError::Postgres)?;
|
||||
discard.check_idle(status);
|
||||
json_output
|
||||
}
|
||||
Err(SqlOverHttpError::Cancelled(_)) => {
|
||||
@@ -815,17 +812,6 @@ impl BatchQueryData {
|
||||
return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
|
||||
}
|
||||
Err(err) => {
|
||||
info!("rollback");
|
||||
let status = transaction
|
||||
.rollback()
|
||||
.await
|
||||
.inspect_err(|_| {
|
||||
// if we cannot rollback - for now don't return connection to pool
|
||||
// TODO: get a query status from the error
|
||||
discard.discard();
|
||||
})
|
||||
.map_err(SqlOverHttpError::Postgres)?;
|
||||
discard.check_idle(status);
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
@@ -1012,12 +998,6 @@ impl Client {
|
||||
}
|
||||
|
||||
impl Discard<'_> {
|
||||
fn check_idle(&mut self, status: ReadyForQueryStatus) {
|
||||
match self {
|
||||
Discard::Remote(discard) => discard.check_idle(status),
|
||||
Discard::Local(discard) => discard.check_idle(status),
|
||||
}
|
||||
}
|
||||
fn discard(&mut self) {
|
||||
match self {
|
||||
Discard::Remote(discard) => discard.discard(),
|
||||
|
||||
Reference in New Issue
Block a user