use std::collections::HashMap; use std::ops::Deref; use std::sync::atomic::{self, AtomicUsize}; use std::sync::{Arc, Weak}; use std::time::Duration; use dashmap::DashMap; use parking_lot::RwLock; use rand::Rng; use tokio_postgres::ReadyForQueryStatus; use tracing::{debug, info, Span}; use super::backend::HttpConnError; use super::conn_pool::ClientInnerRemote; use crate::auth::backend::ComputeUserInfo; use crate::context::RequestMonitoring; use crate::control_plane::messages::ColdStartInfo; use crate::metrics::{HttpEndpointPoolsGuard, Metrics}; use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; use crate::{DbName, EndpointCacheKey, RoleName}; #[derive(Debug, Clone)] pub(crate) struct ConnInfo { pub(crate) user_info: ComputeUserInfo, pub(crate) dbname: DbName, } impl ConnInfo { // hm, change to hasher to avoid cloning? pub(crate) fn db_and_user(&self) -> (DbName, RoleName) { (self.dbname.clone(), self.user_info.user.clone()) } pub(crate) fn endpoint_cache_key(&self) -> Option { // We don't want to cache http connections for ephemeral endpoints. if self.user_info.options.is_ephemeral() { None } else { Some(self.user_info.endpoint_cache_key()) } } } pub(crate) struct ConnPoolEntry { pub(crate) conn: ClientInnerRemote, pub(crate) _last_access: std::time::Instant, } // Per-endpoint connection pool, (dbname, username) -> DbUserConnPool // Number of open connections is limited by the `max_conns_per_endpoint`. pub(crate) struct EndpointConnPool { pools: HashMap<(DbName, RoleName), DbUserConnPool>, total_conns: usize, max_conns: usize, _guard: HttpEndpointPoolsGuard<'static>, global_connections_count: Arc, global_pool_size_max_conns: usize, } impl EndpointConnPool { fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option> { let Self { pools, total_conns, global_connections_count, .. } = self; pools.get_mut(&db_user).and_then(|pool_entries| { let (entry, removed) = pool_entries.get_conn_entry(total_conns); global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed); entry }) } pub(crate) fn remove_client( &mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid, ) -> bool { let Self { pools, total_conns, global_connections_count, .. } = self; if let Some(pool) = pools.get_mut(&db_user) { let old_len = pool.conns.len(); pool.conns.retain(|conn| conn.conn.get_conn_id() != conn_id); let new_len = pool.conns.len(); let removed = old_len - new_len; if removed > 0 { global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed); Metrics::get() .proxy .http_pool_opened_connections .get_metric() .dec_by(removed as i64); } *total_conns -= removed; removed > 0 } else { false } } pub(crate) fn put(pool: &RwLock, conn_info: &ConnInfo, client: ClientInnerRemote) { let conn_id = client.get_conn_id(); if client.is_closed() { info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed"); return; } let global_max_conn = pool.read().global_pool_size_max_conns; if pool .read() .global_connections_count .load(atomic::Ordering::Relaxed) >= global_max_conn { info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full"); return; } // return connection to the pool let mut returned = false; let mut per_db_size = 0; let total_conns = { let mut pool = pool.write(); if pool.total_conns < pool.max_conns { let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default(); pool_entries.conns.push(ConnPoolEntry { conn: client, _last_access: std::time::Instant::now(), }); returned = true; per_db_size = pool_entries.conns.len(); pool.total_conns += 1; pool.global_connections_count .fetch_add(1, atomic::Ordering::Relaxed); Metrics::get() .proxy .http_pool_opened_connections .get_metric() .inc(); } pool.total_conns }; // do logging outside of the mutex if returned { info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}"); } else { info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}"); } } } impl Drop for EndpointConnPool { fn drop(&mut self) { if self.total_conns > 0 { self.global_connections_count .fetch_sub(self.total_conns, atomic::Ordering::Relaxed); Metrics::get() .proxy .http_pool_opened_connections .get_metric() .dec_by(self.total_conns as i64); } } } pub(crate) struct DbUserConnPool { pub(crate) conns: Vec>, } impl Default for DbUserConnPool { fn default() -> Self { Self { conns: Vec::new() } } } impl DbUserConnPool { fn clear_closed_clients(&mut self, conns: &mut usize) -> usize { let old_len = self.conns.len(); self.conns.retain(|conn| !conn.conn.is_closed()); let new_len = self.conns.len(); let removed = old_len - new_len; *conns -= removed; removed } pub(crate) fn get_conn_entry( &mut self, conns: &mut usize, ) -> (Option>, usize) { let mut removed = self.clear_closed_clients(conns); let conn = self.conns.pop(); if conn.is_some() { *conns -= 1; removed += 1; } Metrics::get() .proxy .http_pool_opened_connections .get_metric() .dec_by(removed as i64); (conn, removed) } } pub(crate) struct GlobalConnPool { // endpoint -> per-endpoint connection pool // // That should be a fairly conteded map, so return reference to the per-endpoint // pool as early as possible and release the lock. global_pool: DashMap>>>, /// Number of endpoint-connection pools /// /// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each. /// That seems like far too much effort, so we're using a relaxed increment counter instead. /// It's only used for diagnostics. global_pool_size: AtomicUsize, /// Total number of connections in the pool global_connections_count: Arc, config: &'static crate::config::HttpConfig, } #[derive(Debug, Clone, Copy)] pub struct GlobalConnPoolOptions { // Maximum number of connections per one endpoint. // Can mix different (dbname, username) connections. // When running out of free slots for a particular endpoint, // falls back to opening a new connection for each request. pub max_conns_per_endpoint: usize, pub gc_epoch: Duration, pub pool_shards: usize, pub idle_timeout: Duration, pub opt_in: bool, // Total number of connections in the pool. pub max_total_conns: usize, } impl GlobalConnPool { pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc { let shards = config.pool_options.pool_shards; Arc::new(Self { global_pool: DashMap::with_shard_amount(shards), global_pool_size: AtomicUsize::new(0), config, global_connections_count: Arc::new(AtomicUsize::new(0)), }) } #[cfg(test)] pub(crate) fn get_global_connections_count(&self) -> usize { self.global_connections_count .load(atomic::Ordering::Relaxed) } pub(crate) fn get_idle_timeout(&self) -> Duration { self.config.pool_options.idle_timeout } pub(crate) fn shutdown(&self) { // drops all strong references to endpoint-pools self.global_pool.clear(); } pub(crate) async fn gc_worker(&self, mut rng: impl Rng) { let epoch = self.config.pool_options.gc_epoch; let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32); loop { interval.tick().await; let shard = rng.gen_range(0..self.global_pool.shards().len()); self.gc(shard); } } pub(crate) fn gc(&self, shard: usize) { debug!(shard, "pool: performing epoch reclamation"); // acquire a random shard lock let mut shard = self.global_pool.shards()[shard].write(); let timer = Metrics::get() .proxy .http_pool_reclaimation_lag_seconds .start_timer(); let current_len = shard.len(); let mut clients_removed = 0; shard.retain(|endpoint, x| { // if the current endpoint pool is unique (no other strong or weak references) // then it is currently not in use by any connections. if let Some(pool) = Arc::get_mut(x.get_mut()) { let EndpointConnPool { pools, total_conns, .. } = pool.get_mut(); // ensure that closed clients are removed for db_pool in pools.values_mut() { clients_removed += db_pool.clear_closed_clients(total_conns); } // we only remove this pool if it has no active connections if *total_conns == 0 { info!("pool: discarding pool for endpoint {endpoint}"); return false; } } true }); let new_len = shard.len(); drop(shard); timer.observe(); // Do logging outside of the lock. if clients_removed > 0 { let size = self .global_connections_count .fetch_sub(clients_removed, atomic::Ordering::Relaxed) - clients_removed; Metrics::get() .proxy .http_pool_opened_connections .get_metric() .dec_by(clients_removed as i64); info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}"); } let removed = current_len - new_len; if removed > 0 { let global_pool_size = self .global_pool_size .fetch_sub(removed, atomic::Ordering::Relaxed) - removed; info!("pool: performed global pool gc. size now {global_pool_size}"); } } pub(crate) fn get_or_create_endpoint_pool( self: &Arc, endpoint: &EndpointCacheKey, ) -> Arc>> { // fast path if let Some(pool) = self.global_pool.get(endpoint) { return pool.clone(); } // slow path let new_pool = Arc::new(RwLock::new(EndpointConnPool { pools: HashMap::new(), total_conns: 0, max_conns: self.config.pool_options.max_conns_per_endpoint, _guard: Metrics::get().proxy.http_endpoint_pools.guard(), global_connections_count: self.global_connections_count.clone(), global_pool_size_max_conns: self.config.pool_options.max_total_conns, })); // find or create a pool for this endpoint let mut created = false; let pool = self .global_pool .entry(endpoint.clone()) .or_insert_with(|| { created = true; new_pool }) .clone(); // log new global pool size if created { let global_pool_size = self .global_pool_size .fetch_add(1, atomic::Ordering::Relaxed) + 1; info!( "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}" ); } pool } pub(crate) fn get( self: &Arc, ctx: &RequestMonitoring, conn_info: &ConnInfo, ) -> Result>, HttpConnError> { let mut client: Option> = None; let Some(endpoint) = conn_info.endpoint_cache_key() else { return Ok(None); }; let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint); if let Some(entry) = endpoint_pool .write() .get_conn_entry(conn_info.db_and_user()) { client = Some(entry.conn); } let endpoint_pool = Arc::downgrade(&endpoint_pool); // ok return cached connection if found and establish a new one otherwise if let Some(mut client) = client { if client.is_closed() { info!("pool: cached connection '{conn_info}' is closed, opening a new one"); return Ok(None); } tracing::Span::current() .record("conn_id", tracing::field::display(client.get_conn_id())); tracing::Span::current().record( "pid", tracing::field::display(client.inner().get_process_id()), ); info!( cold_start_info = ColdStartInfo::HttpPoolHit.as_str(), "pool: reusing connection '{conn_info}'" ); client.session().send(ctx.session_id())?; ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit); ctx.success(); return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool))); } Ok(None) } } impl Client { pub(crate) fn new( inner: ClientInnerRemote, conn_info: ConnInfo, pool: Weak>>, ) -> Self { Self { inner: Some(inner), span: Span::current(), conn_info, pool, } } pub(crate) fn inner_mut(&mut self) -> (&mut C, Discard<'_, C>) { let Self { inner, pool, conn_info, span: _, } = self; let inner = inner.as_mut().expect("client inner should not be removed"); let inner_ref = inner.inner_mut(); (inner_ref, Discard { conn_info, pool }) } pub(crate) fn metrics(&self) -> Arc { let aux = &self.inner.as_ref().unwrap().aux(); USAGE_METRICS.register(Ids { endpoint_id: aux.endpoint_id, branch_id: aux.branch_id, }) } pub(crate) fn do_drop(&mut self) -> Option> { let conn_info = self.conn_info.clone(); let client = self .inner .take() .expect("client inner should not be removed"); if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() { let current_span = self.span.clone(); // return connection to the pool return Some(move || { let _span = current_span.enter(); EndpointConnPool::put(&conn_pool, &conn_info, client); }); } None } } pub(crate) struct Client { span: Span, inner: Option>, conn_info: ConnInfo, pool: Weak>>, } impl Drop for Client { fn drop(&mut self) { if let Some(drop) = self.do_drop() { tokio::task::spawn_blocking(drop); } } } impl Deref for Client { type Target = C; fn deref(&self) -> &Self::Target { self.inner .as_ref() .expect("client inner should not be removed") .inner() } } pub(crate) trait ClientInnerExt: Sync + Send + 'static { fn is_closed(&self) -> bool; fn get_process_id(&self) -> i32; } impl ClientInnerExt for tokio_postgres::Client { fn is_closed(&self) -> bool { self.is_closed() } fn get_process_id(&self) -> i32 { self.get_process_id() } } pub(crate) struct Discard<'a, C: ClientInnerExt> { conn_info: &'a ConnInfo, pool: &'a mut Weak>>, } impl Discard<'_, C> { pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) { let conn_info = &self.conn_info; if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 { info!("pool: throwing away connection '{conn_info}' because connection is not idle"); } } pub(crate) fn discard(&mut self) { let conn_info = &self.conn_info; if std::mem::take(self.pool).strong_count() > 0 { info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state"); } } }