From 734a960696e6e433ed1660466c6ebbf99b3a6883 Mon Sep 17 00:00:00 2001
From: Ivan Efremov
Date: Fri, 11 Oct 2024 17:20:11 +0300
Subject: [PATCH] proxy: abstract away connection pools

---
 proxy/src/serverless/backend.rs         |  10 +-
 proxy/src/serverless/conn_pool.rs       | 129 ++++-----
 proxy/src/serverless/local_conn_pool.rs | 352 +++++-------------------
 proxy/src/serverless/mod.rs             |   4 +-
 4 files changed, 135 insertions(+), 360 deletions(-)

diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs
index f54476b51d..dc8dd0e8bc 100644
--- a/proxy/src/serverless/backend.rs
+++ b/proxy/src/serverless/backend.rs
@@ -32,15 +32,15 @@ use crate::{
 };
 
 use super::{
-    conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool},
+    conn_pool::{poll_client, Client, ConnInfo, ConnPool, EndpointConnPool},
     http_conn_pool::{self, poll_http2_client},
-    local_conn_pool::{self, LocalClient, LocalConnPool},
+    local_conn_pool::{self, LocalClient},
 };
 
 pub(crate) struct PoolingBackend {
     pub(crate) http_conn_pool: Arc<http_conn_pool::GlobalConnPool>,
-    pub(crate) local_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
-    pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
+    pub(crate) local_pool: Arc<ConnPool<tokio_postgres::Client>>,
+    pub(crate) pool: Arc<ConnPool<tokio_postgres::Client>>,
     pub(crate) config: &'static ProxyConfig,
     pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
 }
@@ -439,7 +439,7 @@ impl ShouldRetryWakeCompute for LocalProxyConnError {
 }
 
 struct TokioMechanism {
-    pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
+    pool: Arc<ConnPool<tokio_postgres::Client>>,
     conn_info: ConnInfo,
     conn_id: uuid::Uuid,
 
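The hunks above collapse the two pool handles on PoolingBackend into one generic type; the conn_pool.rs diff below does the renaming and opens the shared pieces up with pub(crate). As a rough standalone sketch of the resulting shape — the names ConnPool, ClientInnerExt, put and get mirror the patch, but the bodies here are assumed simplifications, not the real implementation:

// Standalone illustration (not code from the patch): one generic pool,
// parameterized over a client trait, serves every pool flavor, so helpers
// like put()/get() are written exactly once.
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

pub trait ClientInnerExt {
    fn is_closed(&self) -> bool;
}

pub struct ConnPool<C: ClientInnerExt> {
    // (db, user) -> idle connections, like EndpointConnPool::pools above
    pools: Mutex<HashMap<(String, String), Vec<C>>>,
}

impl<C: ClientInnerExt> ConnPool<C> {
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            pools: Mutex::new(HashMap::new()),
        })
    }

    pub fn put(&self, db_user: (String, String), client: C) {
        // closed connections are never pooled, cf. EndpointConnPool::put
        if client.is_closed() {
            return;
        }
        self.pools.lock().unwrap().entry(db_user).or_default().push(client);
    }

    pub fn get(&self, db_user: &(String, String)) -> Option<C> {
        let mut pools = self.pools.lock().unwrap();
        let conns = pools.get_mut(db_user)?;
        // evict closed connections first, cf. clear_closed_clients
        conns.retain(|conn| !conn.is_closed());
        conns.pop()
    }
}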
diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs
index 2e576e0ded..0010aa6c0d 100644
--- a/proxy/src/serverless/conn_pool.rs
+++ b/proxy/src/serverless/conn_pool.rs
@@ -77,7 +77,7 @@ impl fmt::Display for ConnInfo {
     }
 }
 
-struct ConnPoolEntry<C: ClientInnerExt> {
+pub(crate) struct ConnPoolEntry<C: ClientInnerExt> {
     conn: ClientInner<C>,
     _last_access: std::time::Instant,
 }
@@ -87,10 +87,11 @@ struct ConnPoolEntry<C: ClientInnerExt> {
 pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
     pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
     total_conns: usize,
-    max_conns: usize,
+    max_conns: usize, // max conns per endpoint
     _guard: HttpEndpointPoolsGuard<'static>,
     global_connections_count: Arc<AtomicUsize>,
     global_pool_size_max_conns: usize,
+    pool_name: String, // used for logging
 }
 
 impl<C: ClientInnerExt> EndpointConnPool<C> {
@@ -133,21 +134,23 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
         }
     }
 
-    fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
+    pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
         let conn_id = client.conn_id;
+        let p_name = pool.read().pool_name.clone();
 
         if client.is_closed() {
-            info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
+            info!(%conn_id, "{p_name}: throwing away connection '{conn_info}' because connection is closed");
             return;
         }
         let global_max_conn = pool.read().global_pool_size_max_conns;
+
         if pool
             .read()
             .global_connections_count
             .load(atomic::Ordering::Relaxed)
             >= global_max_conn
         {
-            info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
+            info!(%conn_id, "{p_name}: throwing away connection '{conn_info}' because pool is full");
             return;
         }
 
@@ -182,9 +185,11 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
 
         // do logging outside of the mutex
         if returned {
-            info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
+            info!(%conn_id, "{p_name}: returning connection '{conn_info}' back to the pool, \
+                total_conns={total_conns}, for this (db, user)={per_db_size}");
         } else {
-            info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
+            info!(%conn_id, "{p_name}: throwing away connection '{conn_info}' because pool is full, \
+                total_conns={total_conns}");
         }
     }
 }
@@ -214,7 +219,7 @@ impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
 }
 
 impl<C: ClientInnerExt> DbUserConnPool<C> {
-    fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
+    pub(crate) fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
         let old_len = self.conns.len();
 
         self.conns.retain(|conn| !conn.conn.is_closed());
@@ -225,7 +230,7 @@ impl<C: ClientInnerExt> DbUserConnPool<C> {
         removed
     }
 
-    fn get_conn_entry(
+    pub(crate) fn get_conn_entry(
         &mut self,
         conns: &mut usize,
         global_connections_count: Arc<AtomicUsize>,
@@ -246,12 +251,12 @@ impl<C: ClientInnerExt> DbUserConnPool<C> {
     }
 }
 
-pub(crate) struct GlobalConnPool<C: ClientInnerExt> {
+pub(crate) struct ConnPool<C: ClientInnerExt> {
     // endpoint -> per-endpoint connection pool
     //
     // That should be a fairly contended map, so return reference to the per-endpoint
     // pool as early as possible and release the lock.
-    global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,
+    pub(crate) global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,
 
     /// Number of endpoint-connection pools
     ///
@@ -286,7 +291,7 @@ pub struct GlobalConnPoolOptions {
     pub max_total_conns: usize,
 }
 
-impl<C: ClientInnerExt> GlobalConnPool<C> {
+impl<C: ClientInnerExt> ConnPool<C> {
     pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
         let shards = config.pool_options.pool_shards;
         Arc::new(Self {
@@ -428,7 +433,7 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
         Ok(None)
     }
 
-    fn get_or_create_endpoint_pool(
+    pub(crate) fn get_or_create_endpoint_pool(
         self: &Arc<Self>,
         endpoint: &EndpointCacheKey,
     ) -> Arc<RwLock<EndpointConnPool<C>>> {
@@ -445,6 +450,7 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
             _guard: Metrics::get().proxy.http_endpoint_pools.guard(),
             global_connections_count: self.global_connections_count.clone(),
             global_pool_size_max_conns: self.config.pool_options.max_total_conns,
+            pool_name: String::from("global_pool"),
         }));
 
         // find or create a pool for this endpoint
@@ -474,7 +480,7 @@
 }
 
 pub(crate) fn poll_client<C: ClientInnerExt>(
-    global_pool: Arc<GlobalConnPool<C>>,
+    global_pool: Arc<ConnPool<C>>,
     ctx: &RequestMonitoring,
     conn_info: ConnInfo,
     client: C,
@@ -594,6 +600,12 @@ struct ClientInner<C: ClientInnerExt> {
     conn_id: uuid::Uuid,
 }
 
+impl<C: ClientInnerExt> ClientInner<C> {
+    pub(crate) fn is_closed(&self) -> bool {
+        self.inner.is_closed()
+    }
+}
+
 impl<C: ClientInnerExt> Drop for ClientInner<C> {
     fn drop(&mut self) {
         // on client drop, tell the conn to shut down
@@ -615,22 +627,6 @@ impl ClientInnerExt for tokio_postgres::Client {
     }
 }
 
-impl<C: ClientInnerExt> ClientInner<C> {
-    pub(crate) fn is_closed(&self) -> bool {
-        self.inner.is_closed()
-    }
-}
-
-impl<C: ClientInnerExt> Client<C> {
-    pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
-        let aux = &self.inner.as_ref().unwrap().aux;
-        USAGE_METRICS.register(Ids {
-            endpoint_id: aux.endpoint_id,
-            branch_id: aux.branch_id,
-        })
-    }
-}
-
 pub(crate) struct Client<C: ClientInnerExt> {
     span: Span,
     inner: Option<ClientInner<C>>,
@@ -638,11 +634,6 @@ pub(crate) struct Client<C: ClientInnerExt> {
     pool: Weak<RwLock<EndpointConnPool<C>>>,
 }
 
-pub(crate) struct Discard<'a, C: ClientInnerExt> {
-    conn_info: &'a ConnInfo,
-    pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
-}
-
 impl<C: ClientInnerExt> Client<C> {
     pub(self) fn new(
         inner: ClientInner<C>,
         conn_info: ConnInfo,
         pool: Weak<RwLock<EndpointConnPool<C>>>,
     ) -> Self {
         Self {
             inner: Some(inner),
             span: Span::current(),
             conn_info,
             pool,
         }
     }
+
     pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
         let Self {
             inner,
             conn_info,
             pool,
             ..
         } = self;
         let inner = inner.as_mut().expect("client inner should not be removed");
         (&mut inner.inner, Discard { conn_info, pool })
     }
-}
 
-impl<C: ClientInnerExt> Discard<'_, C> {
-    pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
-        let conn_info = &self.conn_info;
-        if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
-            info!("pool: throwing away connection '{conn_info}' because connection is not idle");
-        }
+    pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
+        let aux = &self.inner.as_ref().unwrap().aux;
+        USAGE_METRICS.register(Ids {
+            endpoint_id: aux.endpoint_id,
+            branch_id: aux.branch_id,
+        })
     }
-    pub(crate) fn discard(&mut self) {
-        let conn_info = &self.conn_info;
-        if std::mem::take(self.pool).strong_count() > 0 {
-            info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
-        }
-    }
-}
 
-impl<C: ClientInnerExt> Deref for Client<C> {
-    type Target = C;
-
-    fn deref(&self) -> &Self::Target {
-        &self
-            .inner
-            .as_ref()
-            .expect("client inner should not be removed")
-            .inner
-    }
-}
-
-impl<C: ClientInnerExt> Client<C> {
     fn do_drop(&mut self) -> Option<impl FnOnce()> {
         let conn_info = self.conn_info.clone();
         let client = self
             .inner
             .take()
             .expect("client inner should not be removed");
         if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
             let current_span = self.span.clone();
             // return connection to the pool
             return Some(move || {
                 let _span = current_span.enter();
                 EndpointConnPool::put(&conn_pool, &conn_info, client);
             });
         }
         None
     }
 }
 
+impl<C: ClientInnerExt> Deref for Client<C> {
+    type Target = C;
+
+    fn deref(&self) -> &Self::Target {
+        &self
+            .inner
+            .as_ref()
+            .expect("client inner should not be removed")
+            .inner
+    }
+}
+
 impl<C: ClientInnerExt> Drop for Client<C> {
     fn drop(&mut self) {
         if let Some(drop) = self.do_drop() {
@@ -722,6 +705,26 @@ impl<C: ClientInnerExt> Drop for Client<C> {
     }
 }
 
+pub(crate) struct Discard<'a, C: ClientInnerExt> {
+    conn_info: &'a ConnInfo,
+    pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
+}
+
+impl<C: ClientInnerExt> Discard<'_, C> {
+    pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
+        let conn_info = &self.conn_info;
+        if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
+            info!("pool: throwing away connection '{conn_info}' because connection is not idle");
+        }
+    }
+    pub(crate) fn discard(&mut self) {
+        let conn_info = &self.conn_info;
+        if std::mem::take(self.pool).strong_count() > 0 {
+            info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::{mem, sync::atomic::AtomicBool};
@@ -784,7 +787,7 @@
             max_request_size_bytes: u64::MAX,
             max_response_size_bytes: usize::MAX,
         }));
-        let pool = GlobalConnPool::new(config);
+        let pool = ConnPool::new(config);
         let conn_info = ConnInfo {
             user_info: ComputeUserInfo {
                 user: "user".into(),
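One pattern in conn_pool.rs above is worth spelling out before the local-pool diff: a pooled Client keeps only a Weak reference to its EndpointConnPool, and Discard::check_idle/Discard::discard detach a suspect connection by std::mem::take-ing that Weak, so the deferred do_drop closure can no longer return it to the pool. A self-contained sketch of the idea under simplified, assumed types (plain strings instead of real connections):

// Standalone sketch of the Weak-handle discard pattern (assumed simplification;
// strings stand in for live connections, Mutex<Vec<..>> for the real pool).
use std::sync::{Arc, Mutex, Weak};

struct Pool {
    conns: Mutex<Vec<String>>,
}

struct Client {
    conn: Option<String>,
    pool: Weak<Pool>, // an empty Weak means "do not return this connection"
}

impl Client {
    // Mirrors Discard::discard(): std::mem::take leaves Weak::new() behind,
    // whose strong_count() is 0, so Drop below skips the pool.
    fn discard(&mut self) {
        if std::mem::take(&mut self.pool).strong_count() > 0 {
            println!("throwing away connection because it is potentially broken");
        }
    }
}

impl Drop for Client {
    // Mirrors do_drop(): only a still-attached client returns its connection.
    fn drop(&mut self) {
        if let (Some(conn), Some(pool)) = (self.conn.take(), self.pool.upgrade()) {
            pool.conns.lock().unwrap().push(conn);
        }
    }
}

fn main() {
    let pool = Arc::new(Pool { conns: Mutex::new(Vec::new()) });
    let healthy = Client { conn: Some("a".into()), pool: Arc::downgrade(&pool) };
    let mut broken = Client { conn: Some("b".into()), pool: Arc::downgrade(&pool) };
    broken.discard();
    drop(healthy);
    drop(broken);
    assert_eq!(pool.conns.lock().unwrap().len(), 1); // only "a" came back
}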
diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs
index 1dde5952e1..0a84e466f9 100644
--- a/proxy/src/serverless/local_conn_pool.rs
+++ b/proxy/src/serverless/local_conn_pool.rs
@@ -6,7 +6,15 @@ use rand::rngs::OsRng;
 use serde_json::Value;
 use signature::Signer;
 use std::task::{ready, Poll};
-use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
+use std::{
+    collections::HashMap,
+    pin::pin,
+    sync::atomic::{self, AtomicUsize},
+    sync::Arc,
+    sync::Weak,
+    time::Duration,
+};
+
 use tokio::time::Instant;
 use tokio_postgres::tls::NoTlsStream;
 use tokio_postgres::types::ToSql;
@@ -15,7 +23,7 @@ use tokio_util::sync::CancellationToken;
 use typed_json::json;
 
 use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
-use crate::metrics::Metrics;
+use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
 use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
 use crate::{context::RequestMonitoring, DbName, RoleName};
 
@@ -23,230 +31,10 @@ use tracing::{debug, error, warn, Span};
 use tracing::{info, info_span, Instrument};
 
 use super::backend::HttpConnError;
-use super::conn_pool::{ClientInnerExt, ConnInfo};
+use super::conn_pool::{ClientInnerExt, ConnInfo, ConnPool, EndpointConnPool};
 
-struct ConnPoolEntry<C: ClientInnerExt> {
-    conn: ClientInner<C>,
-    _last_access: std::time::Instant,
-}
-
-// /// key id for the pg_session_jwt state
-// static PG_SESSION_JWT_KID: AtomicU64 = AtomicU64::new(1);
-
-// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
-// Number of open connections is limited by the `max_conns_per_endpoint`.
-pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
-    pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
-    total_conns: usize,
-    max_conns: usize,
-    global_pool_size_max_conns: usize,
-}
-
-impl<C: ClientInnerExt> EndpointConnPool<C> {
-    fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
-        let Self {
-            pools, total_conns, ..
-        } = self;
-        pools
-            .get_mut(&db_user)
-            .and_then(|pool_entries| pool_entries.get_conn_entry(total_conns))
-    }
-
-    fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
-        let Self {
-            pools, total_conns, ..
-        } = self;
-        if let Some(pool) = pools.get_mut(&db_user) {
-            let old_len = pool.conns.len();
-            pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
-            let new_len = pool.conns.len();
-            let removed = old_len - new_len;
-            if removed > 0 {
-                Metrics::get()
-                    .proxy
-                    .http_pool_opened_connections
-                    .get_metric()
-                    .dec_by(removed as i64);
-            }
-            *total_conns -= removed;
-            removed > 0
-        } else {
-            false
-        }
-    }
-
-    fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
-        let conn_id = client.conn_id;
-
-        if client.is_closed() {
-            info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because connection is closed");
-            return;
-        }
-        let global_max_conn = pool.read().global_pool_size_max_conns;
-        if pool.read().total_conns >= global_max_conn {
-            info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because pool is full");
-            return;
-        }
-
-        // return connection to the pool
-        let mut returned = false;
-        let mut per_db_size = 0;
-        let total_conns = {
-            let mut pool = pool.write();
-
-            if pool.total_conns < pool.max_conns {
-                let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
-                pool_entries.conns.push(ConnPoolEntry {
-                    conn: client,
-                    _last_access: std::time::Instant::now(),
-                });
-
-                returned = true;
-                per_db_size = pool_entries.conns.len();
-
-                pool.total_conns += 1;
-                Metrics::get()
-                    .proxy
-                    .http_pool_opened_connections
-                    .get_metric()
-                    .inc();
-            }
-
-            pool.total_conns
-        };
-
-        // do logging outside of the mutex
-        if returned {
-            info!(%conn_id, "local_pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
-        } else {
-            info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
-        }
-    }
-}
-
-impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
-    fn drop(&mut self) {
-        if self.total_conns > 0 {
-            Metrics::get()
-                .proxy
-                .http_pool_opened_connections
-                .get_metric()
-                .dec_by(self.total_conns as i64);
-        }
-    }
-}
-
-pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
-    conns: Vec<ConnPoolEntry<C>>,
-}
-
-impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
-    fn default() -> Self {
-        Self { conns: Vec::new() }
-    }
-}
-
-impl<C: ClientInnerExt> DbUserConnPool<C> {
-    fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
-        let old_len = self.conns.len();
-
-        self.conns.retain(|conn| !conn.conn.is_closed());
-
-        let new_len = self.conns.len();
-        let removed = old_len - new_len;
-        *conns -= removed;
-        removed
-    }
-
-    fn get_conn_entry(&mut self, conns: &mut usize) -> Option<ConnPoolEntry<C>> {
-        let mut removed = self.clear_closed_clients(conns);
-        let conn = self.conns.pop();
-        if conn.is_some() {
-            *conns -= 1;
-            removed += 1;
-        }
-        Metrics::get()
-            .proxy
-            .http_pool_opened_connections
-            .get_metric()
-            .dec_by(removed as i64);
-        conn
-    }
-}
-
-pub(crate) struct LocalConnPool<C: ClientInnerExt> {
-    global_pool: RwLock<EndpointConnPool<C>>,
-
-    config: &'static crate::config::HttpConfig,
-}
-
-impl<C: ClientInnerExt> LocalConnPool<C> {
-    pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
-        Arc::new(Self {
-            global_pool: RwLock::new(EndpointConnPool {
-                pools: HashMap::new(),
-                total_conns: 0,
-                max_conns: config.pool_options.max_conns_per_endpoint,
-                global_pool_size_max_conns: config.pool_options.max_total_conns,
-            }),
-            config,
-        })
-    }
-
-    pub(crate) fn get_idle_timeout(&self) -> Duration {
-        self.config.pool_options.idle_timeout
-    }
-
-    // pub(crate) fn shutdown(&self) {
-    //     let mut pool = self.global_pool.write();
-    //     pool.pools.clear();
-    //     pool.total_conns = 0;
-    // }
-
-    pub(crate) fn get(
-        self: &Arc<Self>,
-        ctx: &RequestMonitoring,
-        conn_info: &ConnInfo,
-    ) -> Result<Option<LocalClient<tokio_postgres::Client>>, HttpConnError> {
-        let mut client: Option<ClientInner<tokio_postgres::Client>> = None;
-        if let Some(entry) = self
-            .global_pool
-            .write()
-            .get_conn_entry(conn_info.db_and_user())
-        {
-            client = Some(entry.conn);
-        }
-
-        // ok return cached connection if found and establish a new one otherwise
-        if let Some(client) = client {
-            if client.is_closed() {
-                info!("local_pool: cached connection '{conn_info}' is closed, opening a new one");
-                return Ok(None);
-            }
-            tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
-            tracing::Span::current().record(
-                "pid",
-                tracing::field::display(client.inner.get_process_id()),
-            );
-            info!(
-                cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
-                "local_pool: reusing connection '{conn_info}'"
-            );
-            client.session.send(ctx.session_id())?;
-            ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
-            ctx.success();
-            return Ok(Some(LocalClient::new(
-                client,
-                conn_info.clone(),
-                Arc::downgrade(self),
-            )));
-        }
-        Ok(None)
-    }
-}
-
-pub(crate) fn poll_client(
-    global_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
+pub(crate) fn poll_client(
+    local_pool: Arc<ConnPool<tokio_postgres::Client>>,
     ctx: &RequestMonitoring,
     conn_info: ConnInfo,
     client: tokio_postgres::Client,
@@ -263,11 +51,11 @@ pub(crate) fn poll_client(
     span.in_scope(|| {
         info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
     });
-    let pool = Arc::downgrade(&global_pool);
+    let pool = Arc::downgrade(&local_pool);
     let pool_clone = pool.clone();
 
     let db_user = conn_info.db_and_user();
-    let idle = global_pool.get_idle_timeout();
+    let idle = local_pool.get_idle_timeout();
     let cancel = CancellationToken::new();
     let cancelled = cancel.clone().cancelled_owned();
 
@@ -335,7 +123,7 @@ pub(crate) fn poll_client(
 
             // remove from connection pool
            if let Some(pool) = pool.clone().upgrade() {
-                if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
+                if pool.write().remove_client(db_user.clone(), conn_id) {
                     info!("closed connection removed");
                 }
             }
@@ -372,47 +160,15 @@ struct ClientInner<C: ClientInnerExt> {
     jti: u64,
 }
 
-impl<C: ClientInnerExt> Drop for ClientInner<C> {
-    fn drop(&mut self) {
-        // on client drop, tell the conn to shut down
-        self.cancel.cancel();
-    }
-}
-
-impl<C: ClientInnerExt> ClientInner<C> {
-    pub(crate) fn is_closed(&self) -> bool {
-        self.inner.is_closed()
-    }
-}
-
-impl<C: ClientInnerExt> LocalClient<C> {
-    pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
-        let aux = &self.inner.as_ref().unwrap().aux;
-        USAGE_METRICS.register(Ids {
-            endpoint_id: aux.endpoint_id,
-            branch_id: aux.branch_id,
-        })
-    }
-}
-
 pub(crate) struct LocalClient<C: ClientInnerExt> {
     span: Span,
     inner: Option<ClientInner<C>>,
     conn_info: ConnInfo,
-    pool: Weak<LocalConnPool<C>>,
+    pool: Weak<ConnPool<C>>,
 }
 
-pub(crate) struct Discard<'a, C: ClientInnerExt> {
-    conn_info: &'a ConnInfo,
-    pool: &'a mut Weak<LocalConnPool<C>>,
-}
-
 impl<C: ClientInnerExt> LocalClient<C> {
-    pub(self) fn new(
-        inner: ClientInner<C>,
-        conn_info: ConnInfo,
-        pool: Weak<LocalConnPool<C>>,
-    ) -> Self {
+    pub(self) fn new(inner: ClientInner<C>, conn_info: ConnInfo, pool: Weak<ConnPool<C>>) -> Self {
         Self {
             inner: Some(inner),
             span: Span::current(),
             conn_info,
             pool,
         }
     }
+
+    pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
+        let aux = &self.inner.as_ref().unwrap().aux;
+        USAGE_METRICS.register(Ids {
+            endpoint_id: aux.endpoint_id,
+            branch_id: aux.branch_id,
+        })
+    }
+
     pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
         let Self {
             inner,
             conn_info,
             pool,
             ..
         } = self;
         let inner = inner.as_mut().expect("client inner should not be removed");
         (&mut inner.inner, Discard { conn_info, pool })
     }
+
     pub(crate) fn key(&self) -> &SigningKey {
         let inner = &self
             .inner
             .as_ref()
             .expect("client inner should not be removed");
         &inner.key
     }
+
+    pub fn get_client(&self) -> &C {
+        &self
+            .inner
+            .as_ref()
+            .expect("client inner should not be removed")
+            .inner
+    }
+
+    fn do_drop(&mut self) -> Option<impl FnOnce()> {
+        let conn_info = self.conn_info.clone();
+        let client = self
+            .inner
+            .take()
+            .expect("client inner should not be removed");
+        if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
+            let current_span = self.span.clone();
+            // return connection to the pool
+            return Some(move || {
+                let _span = current_span.enter();
+                EndpointConnPool::put(&conn_pool.local_pool, &conn_info, client);
+            });
+        }
+        None
+    }
 }
 
 impl LocalClient<tokio_postgres::Client> {
@@ -491,6 +282,11 @@ fn sign_jwt(sk: &SigningKey, header: String, payload: String) -> String {
     format!("{message}.{base64_sig}")
 }
 
+pub(crate) struct Discard<'a, C: ClientInnerExt> {
+    conn_info: &'a ConnInfo,
+    pool: &'a mut Weak<ConnPool<C>>,
+}
+
 impl<C: ClientInnerExt> Discard<'_, C> {
     pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
         let conn_info = &self.conn_info;
@@ -503,38 +299,14 @@ impl<C: ClientInnerExt> Discard<'_, C> {
     pub(crate) fn discard(&mut self) {
         let conn_info = &self.conn_info;
         if std::mem::take(self.pool).strong_count() > 0 {
-            info!("local_pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
+            info!(
+                "local_pool: throwing away connection '{conn_info}' \
+                because connection is potentially in a broken state"
+            );
         }
     }
 }
 
-impl<C: ClientInnerExt> LocalClient<C> {
-    pub fn get_client(&self) -> &C {
-        &self
-            .inner
-            .as_ref()
-            .expect("client inner should not be removed")
-            .inner
-    }
-
-    fn do_drop(&mut self) -> Option<impl FnOnce()> {
-        let conn_info = self.conn_info.clone();
-        let client = self
-            .inner
-            .take()
-            .expect("client inner should not be removed");
-        if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
-            let current_span = self.span.clone();
-            // return connection to the pool
-            return Some(move || {
-                let _span = current_span.enter();
-                EndpointConnPool::put(&conn_pool.global_pool, &conn_info, client);
-            });
-        }
-        None
-    }
-}
-
 impl<C: ClientInnerExt> Drop for LocalClient<C> {
     fn drop(&mut self) {
         if let Some(drop) = self.do_drop() {
diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs
index 9be6b592bd..e8b5b37922 100644
--- a/proxy/src/serverless/mod.rs
+++ b/proxy/src/serverless/mod.rs
@@ -64,8 +64,8 @@ pub async fn task_main(
         info!("websocket server has shut down");
     }
 
-    let local_pool = local_conn_pool::LocalConnPool::new(&config.http_config);
-    let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config);
+    let local_pool = conn_pool::ConnPool::new(&config.http_config);
+    let conn_pool = conn_pool::ConnPool::new(&config.http_config);
     {
         let conn_pool = Arc::clone(&conn_pool);
         tokio::spawn(async move {
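A closing observation on the wiring: task_main now builds both pools with the same constructor, and the only instance-specific bit left in the shared code path is the pool_name stamped onto each EndpointConnPool ("global_pool" in get_or_create_endpoint_pool above, while local_conn_pool.rs keeps its literal local_pool: log prefix). A minimal self-contained sketch of that logging scheme, with illustrative names throughout:

// Standalone sketch (illustrative names): one pool type, two instances,
// told apart in shared code paths only by a name chosen at construction,
// cf. pool_name: String::from("global_pool") in get_or_create_endpoint_pool.
use std::sync::Arc;

struct NamedPool {
    pool_name: String,
}

impl NamedPool {
    fn new(pool_name: &str) -> Arc<Self> {
        Arc::new(Self { pool_name: pool_name.to_string() })
    }

    // Shared code path: the log prefix is the only per-instance difference.
    fn throw_away(&self, conn_info: &str, reason: &str) {
        let p_name = &self.pool_name;
        println!("{p_name}: throwing away connection '{conn_info}' because {reason}");
    }
}

fn main() {
    let global = NamedPool::new("global_pool");
    let local = NamedPool::new("local_pool");
    global.throw_away("db=neondb user=web", "pool is full");
    local.throw_away("db=neondb user=web", "connection is closed");
}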