From 441d74476d762e7728fb8f6173cdfd0215561061 Mon Sep 17 00:00:00 2001
From: Conrad Ludgate
Date: Thu, 18 Apr 2024 10:46:56 +0100
Subject: [PATCH] refactor conn pool to use intrusive linked list

---
 Cargo.lock                        |  17 ++++
 proxy/Cargo.toml                  |   1 +
 proxy/src/serverless/backend.rs   |   3 +-
 proxy/src/serverless/conn_pool.rs | 161 ++++++++++++++++++++++++------
 4 files changed, 153 insertions(+), 29 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index e4bf71c64f..26457095df 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3938,6 +3938,16 @@ dependencies = [
  "siphasher",
 ]
 
+[[package]]
+name = "pin-list"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fe91484d5a948b56f858ff2b92fd5b20b97d21b11d2d41041db8e5ec12d56c5e"
+dependencies = [
+ "pin-project-lite",
+ "pinned-aliasable",
+]
+
 [[package]]
 name = "pin-project"
 version = "1.1.0"
@@ -3970,6 +3980,12 @@ version = "0.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
 
+[[package]]
+name = "pinned-aliasable"
+version = "0.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5d0f9ae89bf0ed03b69ac1f3f7ea2e6e09b4fa5448011df2e67d581c2b850b7b"
+
 [[package]]
 name = "pkcs8"
 version = "0.9.0"
@@ -4349,6 +4365,7 @@ dependencies = [
  "parquet",
  "parquet_derive",
  "pbkdf2",
+ "pin-list",
  "pin-project-lite",
  "postgres-native-tls",
  "postgres-protocol",
diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml
index 6b8f2ecbf4..95c6edfafa 100644
--- a/proxy/Cargo.toml
+++ b/proxy/Cargo.toml
@@ -100,6 +100,7 @@ postgres-protocol.workspace = true
 redis.workspace = true
 workspace_hack.workspace = true
 
+pin-list = { version = "0.1.0", features = ["std"] }
 
 [dev-dependencies]
 camino-tempfile.workspace = true
diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs
index 9267449b6f..7afb1c2275 100644
--- a/proxy/src/serverless/backend.rs
+++ b/proxy/src/serverless/backend.rs
@@ -192,7 +192,8 @@ impl ConnectMechanism for TokioMechanism {
             connection,
             self.conn_id,
             node_info.aux.clone(),
-        ))
+        )
+        .await)
     }
 
     fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs
index e7fd1f8864..c23ffe08ab 100644
--- a/proxy/src/serverless/conn_pool.rs
+++ b/proxy/src/serverless/conn_pool.rs
@@ -1,9 +1,11 @@
 use dashmap::DashMap;
 use futures::Future;
 use parking_lot::RwLock;
+use pin_list::Node;
 use pin_project_lite::pin_project;
 use rand::Rng;
 use smallvec::SmallVec;
+use std::pin::Pin;
 use std::sync::Weak;
 use std::{collections::HashMap, sync::Arc, time::Duration};
 use std::{
@@ -14,6 +16,7 @@ use std::{
     ops::Deref,
     sync::atomic::{self, AtomicUsize},
 };
+use tokio::sync::mpsc::error::TrySendError;
 use tokio::time::{Instant, Sleep};
 use tokio_postgres::tls::NoTlsStream;
 use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
@@ -85,7 +88,11 @@ pub struct EndpointConnPool<C: ClientInnerExt> {
 }
 
 impl<C: ClientInnerExt> EndpointConnPool<C> {
-    fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
+    fn get_conn_entry(
+        &mut self,
+        db_user: (DbName, RoleName),
+        session_id: uuid::Uuid,
+    ) -> Option<ConnPoolEntry<C>> {
         let Self {
             pools,
             total_conns,
@@ -93,7 +100,7 @@
             global_connections_count,
             ..
         } = self;
         pools.get_mut(&db_user).and_then(|pool_entries| {
-            pool_entries.get_conn_entry(total_conns, global_connections_count)
+            pool_entries.get_conn_entry(total_conns, global_connections_count, session_id)
         })
     }
 
@@ -105,10 +112,17 @@
             ..
         } = self;
         if let Some(pool) = pools.get_mut(&db_user) {
-            let old_len = pool.conns.len();
-            pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
-            let new_len = pool.conns.len();
-            let removed = old_len - new_len;
+            let mut removed = 0;
+            let mut cursor = pool.conns2.cursor_front_mut();
+            while let Some(client) = cursor.protected() {
+                if client.conn.conn_id == conn_id {
+                    let _ = cursor.remove_current(uuid::Uuid::nil());
+                    removed += 1;
+                } else {
+                    cursor.move_next()
+                }
+            }
+
             if removed > 0 {
                 global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
                 Metrics::get()
@@ -124,7 +138,13 @@
         }
     }
 
-    fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
+    fn put(
+        pool: &RwLock<Self>,
+        mut node: Pin<&mut Node<PinListTypes<C>>>,
+        db_user: &(DbName, RoleName),
+        client: ClientInner<C>,
+        conn_info: ConnInfo,
+    ) -> bool {
         let conn_id = client.conn_id;
 
         {
@@ -135,7 +155,7 @@
                 >= pool.global_pool_size_max_conns
             {
                 info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
-                return;
+                return false;
             }
         }
 
@@ -146,14 +166,24 @@
         let mut pool = pool.write();
         if pool.total_conns < pool.max_conns {
-            let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
-            pool_entries.conns.push(ConnPoolEntry {
-                conn: client,
-                _last_access: std::time::Instant::now(),
-            });
+            let pool_entries = pool.pools.entry(db_user.clone()).or_default();
+
+            if let Some(node) = node.as_mut().initialized_mut() {
+                if node.take_removed(&pool_entries.conns2).is_err() {
+                    panic!("client is already in the pool")
+                };
+            }
+            pool_entries.conns2.cursor_front_mut().insert_after(
+                node,
+                ConnPoolEntry {
+                    conn: client,
+                    _last_access: std::time::Instant::now(),
+                },
+                (),
+            );
 
             returned = true;
-            per_db_size = pool_entries.conns.len();
+            per_db_size = pool_entries.len;
 
             pool.total_conns += 1;
             pool.global_connections_count
@@ -174,6 +204,8 @@
         } else {
             info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
         }
+
+        returned
     }
 }
 
@@ -192,23 +224,33 @@ impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
 }
 
 pub struct DbUserConnPool<C: ClientInnerExt> {
-    conns: Vec<ConnPoolEntry<C>>,
+    conns2: pin_list::PinList<PinListTypes<C>>,
+    len: usize,
 }
 
 impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
     fn default() -> Self {
-        Self { conns: Vec::new() }
+        Self {
+            conns2: pin_list::PinList::new(pin_list::id::Checked::new()),
+            len: 0,
+        }
     }
 }
 
 impl<C: ClientInnerExt> DbUserConnPool<C> {
     fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
-        let old_len = self.conns.len();
+        let mut removed = 0;
 
-        self.conns.retain(|conn| !conn.conn.is_closed());
+        let mut cursor = self.conns2.cursor_front_mut();
+        while let Some(client) = cursor.protected() {
+            if client.conn.is_closed() {
+                let _ = cursor.remove_current(uuid::Uuid::nil());
+                removed += 1;
+            } else {
+                cursor.move_next()
+            }
+        }
 
-        let new_len = self.conns.len();
-        let removed = old_len - new_len;
         *conns -= removed;
         removed
     }
@@ -217,19 +259,26 @@ fn get_conn_entry(
         &mut self,
         conns: &mut usize,
         global_connections_count: &AtomicUsize,
+        session_id: uuid::Uuid,
     ) -> Option<ConnPoolEntry<C>> {
         let mut removed = self.clear_closed_clients(conns);
-        let conn = self.conns.pop();
+
+        let conn = self
+            .conns2
+            .cursor_front_mut()
+            .remove_current(session_id)
+            .ok();
+
         if conn.is_some() {
             *conns -= 1;
             removed += 1;
         }
         global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
         Metrics::get()
             .proxy
             .http_pool_opened_connections
             .get_metric()
             .dec_by(removed as i64);
         conn
     }
 }
 
@@ -387,7 +436,7 @@
         let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
         if let Some(entry) = endpoint_pool
             .write()
-            .get_conn_entry(conn_info.db_and_user())
+            .get_conn_entry(conn_info.db_and_user(), ctx.session_id)
         {
             client = Some(entry.conn)
         }
@@ -462,7 +511,16 @@
     }
 }
 
-pub fn poll_tokio_client(
+type PinListTypes<C> = dyn pin_list::Types<
+    Id = pin_list::id::Checked,
+    Protected = ConnPoolEntry<C>,
+    // session ID
+    Removed = uuid::Uuid,
+    // conn ID
+    Unprotected = (),
+>;
+
+pub async fn poll_tokio_client(
     global_pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
     ctx: &mut RequestMonitoring,
     conn_info: ConnInfo,
@@ -538,14 +596,21 @@ pub fn poll_client<C: ClientInnerExt + Send + 'static>(
     let idle = global_pool.get_idle_timeout();
     let cancel = CancellationToken::new();
+    let (send_client, recv_client) = tokio::sync::mpsc::channel(1);
     let db_conn = DbConnection {
         cancelled: cancel.clone().cancelled_owned(),
+
         idle_timeout: tokio::time::sleep(idle),
         idle,
+
+        node: Node::<PinListTypes<C>>::new(),
+        recv_client,
         db_user: conn_info.db_and_user(),
         pool: pool.clone(),
+
         session_span,
         session_rx: rx,
+
         conn_gauge,
         conn_id,
         connection,
@@ -556,6 +621,7 @@
     let inner = ClientInner {
         inner: client,
         session: tx,
+        pool: send_client,
         cancel,
         aux,
         conn_id,
@@ -565,19 +631,27 @@
 pin_project! {
     struct DbConnection<C: ClientInnerExt> {
+        // Used to close the current conn if the client is dropped
         #[pin]
         cancelled: WaitForCancellationFutureOwned,
 
+        // Used to close the current conn if it's idle
         #[pin]
         idle_timeout: Sleep,
         idle: tokio::time::Duration,
 
+        // Used to add/remove conn from the conn pool
+        #[pin]
+        node: Node<PinListTypes<C>>,
+        recv_client: tokio::sync::mpsc::Receiver<(tracing::Span, ClientInner<C>, ConnInfo)>,
         db_user: (DbName, RoleName),
         pool: Weak<RwLock<EndpointConnPool<C>>>,
 
+        // Used for reporting the current session the conn is attached to
         session_span: tracing::Span,
         session_rx: tokio::sync::watch::Receiver<uuid::Uuid>,
 
+        // Static connection state
         conn_gauge: NumDbConnectionsGuard<'static>,
         conn_id: uuid::Uuid,
         #[pin]
@@ -596,6 +670,24 @@ impl<C: ClientInnerExt> Future for DbConnection<C> {
             return Poll::Ready(());
         }
 
+        if let Poll::Ready(client) = this.recv_client.poll_recv(cx) {
+            // if the send_client is dropped, then the client is dropped
+            let Some((span, client, conn_info)) = client else {
+                info!("connection dropped");
+                return Poll::Ready(());
+            };
+            // if there's no pool, then this client will be closed.
+            let Some(pool) = this.pool.upgrade() else {
+                info!("connection dropped");
+                return Poll::Ready(());
+            };
+
+            let _span = span.enter();
+            if !EndpointConnPool::put(&*pool, this.node.as_mut(), this.db_user, client, conn_info) {
+                return Poll::Ready(());
+            }
+        }
+
         match this.session_rx.has_changed() {
             Ok(true) => {
                 let session_id = *this.session_rx.borrow_and_update();
@@ -653,6 +745,7 @@
 struct ClientInner<C: ClientInnerExt> {
     inner: C,
     session: tokio::sync::watch::Sender<uuid::Uuid>,
+    pool: tokio::sync::mpsc::Sender<(tracing::Span, ClientInner<C>, ConnInfo)>,
     cancel: CancellationToken,
     aux: MetricsAuxInfo,
     conn_id: uuid::Uuid,
@@ -774,22 +867,27 @@ impl<C: ClientInnerExt> Client<C> {
             return;
         }
 
-        if let Some(conn_pool) = self.pool.upgrade() {
-            // return connection to the pool
-            let _span = self.span.enter();
-            EndpointConnPool::put(&conn_pool, &conn_info, client);
+        let tx = client.pool.clone();
+        match tx.try_send((self.span.clone(), client, self.conn_info.clone())) {
+            Ok(_) => {}
+            Err(TrySendError::Closed(_)) => {}
+            Err(TrySendError::Full(_)) => {
+                error!("client channel should not be full")
+            }
         }
     }
 }
 
 impl<C: ClientInnerExt> Drop for Client<C> {
     fn drop(&mut self) {
-        self.do_drop();
+        self.do_drop()
     }
 }
 
 #[cfg(test)]
 mod tests {
+    use tokio::task::yield_now;
+
     use crate::{BranchId, EndpointId, ProjectId};
 
     use super::*;
@@ -855,24 +953,28 @@ mod tests {
            assert_eq!(0, pool.get_global_connections_count());
            client.inner().1.discard();
            drop(client);
+           yield_now().await;
            // Discard should not add the connection from the pool.
            assert_eq!(0, pool.get_global_connections_count());
        }
        {
            let (client, _) = create_inner(pool.clone(), conn_info.clone());
            drop(client);
+           yield_now().await;
            assert_eq!(1, pool.get_global_connections_count());
        }
        {
            let (client, cancel) = create_inner(pool.clone(), conn_info.clone());
            cancel.cancel();
            drop(client);
+           yield_now().await;
            // The closed client shouldn't be added to the pool.
            assert_eq!(1, pool.get_global_connections_count());
        }
        let cancel = {
            let (client, cancel) = create_inner(pool.clone(), conn_info.clone());
            drop(client);
+           yield_now().await;
            // The client should be added to the pool.
            assert_eq!(2, pool.get_global_connections_count());
            cancel
@@ -880,6 +982,7 @@
        };
        {
            let client = create_inner(pool.clone(), conn_info.clone());
            drop(client);
+           yield_now().await;
            // The client shouldn't be added to the pool. Because the ep-pool is full.
            assert_eq!(2, pool.get_global_connections_count());
        }
@@ -896,11 +999,13 @@
        {
            let client = create_inner(pool.clone(), conn_info.clone());
            drop(client);
+           yield_now().await;
            assert_eq!(3, pool.get_global_connections_count());
        }
        {
            let client = create_inner(pool.clone(), conn_info.clone());
            drop(client);
+           yield_now().await;
            // The client shouldn't be added to the pool. Because the global pool is full.
            assert_eq!(3, pool.get_global_connections_count());
        }
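Reviewer note on the pattern this patch introduces: `Client::do_drop` no longer locks the pool from the client's `Drop` impl; it hands the finished `ClientInner` back over a capacity-1 `tokio::sync::mpsc` channel, and the connection task (`DbConnection::poll`) re-links the client into the intrusive `pin_list` node it owns. The sketch below is a minimal, self-contained illustration of that "return the resource to its owning task on Drop" idea using plain tokio. It is not code from the proxy: `PooledClient`, the `u32` payload, and the printed messages are made-up stand-ins for `ClientInner`/`ConnInfo` and the real pool bookkeeping.

// Minimal sketch of the drop-time hand-back pattern (assumed names, not proxy types).
use tokio::sync::mpsc::{self, error::TrySendError};

struct PooledClient {
    id: u32,
    // Capacity-1 channel back to the connection task that owns the pool node.
    return_tx: mpsc::Sender<u32>,
}

impl Drop for PooledClient {
    fn drop(&mut self) {
        // Drop can't await, so a non-blocking try_send is used, mirroring do_drop().
        match self.return_tx.try_send(self.id) {
            Ok(()) => {}
            // The owning task is gone; there is nothing to return the client to.
            Err(TrySendError::Closed(_)) => {}
            // Each connection task hands out at most one client at a time,
            // so a full channel would indicate a logic error.
            Err(TrySendError::Full(_)) => unreachable!("one outstanding client per task"),
        }
    }
}

#[tokio::main]
async fn main() {
    let (return_tx, mut return_rx) = mpsc::channel(1);

    // Stand-in for the DbConnection future: it owns the pool slot and waits for
    // its client to come back instead of the client locking the pool itself.
    let conn_task = tokio::spawn(async move {
        while let Some(id) = return_rx.recv().await {
            println!("client {id} returned; would re-link the pin-list node here");
        }
        println!("all client handles gone; closing the connection");
    });

    let client = PooledClient { id: 7, return_tx };
    drop(client); // sends the id back to conn_task via try_send in Drop

    conn_task.await.unwrap();
}

The capacity of 1 matches the patch's `tokio::sync::mpsc::channel(1)`: a connection task only ever has one client checked out, so `TrySendError::Full` is treated as a bug rather than back-pressure, and a closed channel simply means the connection task has already shut down.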