proxy: Refactor http conn pool (#9785)

- Use the same ConnPoolEntry for the HTTP connection pool; GlobalConnPool is now generic over the per-endpoint pool type (see the sketch below).
- Rename EndpointConnPool to HttpConnPool.
- Narrow the Clone bound for the client.

Fixes #9284
Authored by Ivan Efremov on 2024-11-20 21:36:29 +02:00; committed by GitHub
parent 313ebfdb88
commit 2d6bf176a0
7 changed files with 187 additions and 238 deletions
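For orientation, here is a condensed, illustrative sketch of the shape this refactor converges on: conn_pool_lib gains an EndpointConnPoolExt trait, and GlobalConnPool becomes generic over the per-endpoint pool type, so the postgres pool (EndpointConnPool) and the HTTP/2 pool (HttpConnPool) share the shutdown/gc machinery. Type and field names follow the diff below; everything else (HashMap instead of DashMap, String instead of EndpointCacheKey, a &mut self gc) is a simplification, not the literal proxy code.

// Condensed sketch of the refactored pool shape (assumed simplifications noted above).
use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::Arc;

use parking_lot::RwLock;

pub trait ClientInnerExt: Sync + Send + 'static {}

// Behaviour the generic global pool needs from any per-endpoint pool.
pub trait EndpointConnPoolExt<C: ClientInnerExt> {
    // Drop closed connections and report how many were removed.
    fn clear_closed(&mut self) -> usize;
    // Number of connections currently held by this per-endpoint pool.
    fn total_conns(&self) -> usize;
}

// One global pool, generic over the per-endpoint pool type `P`, so both
// EndpointConnPool<C> (postgres) and HttpConnPool<C> (HTTP/2) can reuse it.
pub struct GlobalConnPool<C, P>
where
    C: ClientInnerExt,
    P: EndpointConnPoolExt<C>,
{
    global_pool: HashMap<String, Arc<RwLock<P>>>,
    _marker: PhantomData<C>,
}

impl<C, P> GlobalConnPool<C, P>
where
    C: ClientInnerExt,
    P: EndpointConnPoolExt<C>,
{
    // Shared reclamation pass: it only needs the two trait methods, so it no
    // longer has to be duplicated for each pool flavour.
    pub fn gc(&mut self) -> usize {
        let mut clients_removed = 0;
        self.global_pool.retain(|_endpoint, pool| {
            // Only touch pools that nothing else currently references.
            if let Some(pool) = Arc::get_mut(pool) {
                let pool = pool.get_mut();
                clients_removed += pool.clear_closed();
                return pool.total_conns() > 0;
            }
            true
        });
        clients_removed
    }
}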

View File

@@ -12,8 +12,8 @@ use tracing::field::display;
use tracing::{debug, info};
use super::conn_pool::poll_client;
use super::conn_pool_lib::{Client, ConnInfo, GlobalConnPool};
use super::http_conn_pool::{self, poll_http2_client, Send};
use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool};
use super::http_conn_pool::{self, poll_http2_client, HttpConnPool, Send};
use super::local_conn_pool::{self, LocalConnPool, EXT_NAME, EXT_SCHEMA, EXT_VERSION};
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
@@ -36,9 +36,10 @@ use crate::rate_limiter::EndpointRateLimiter;
use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool<Send>>,
pub(crate) http_conn_pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
pub(crate) local_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub(crate) pool:
Arc<GlobalConnPool<tokio_postgres::Client, EndpointConnPool<tokio_postgres::Client>>>,
pub(crate) config: &'static ProxyConfig,
pub(crate) auth_backend: &'static crate::auth::Backend<'static, ()>,
@@ -474,7 +475,7 @@ impl ShouldRetryWakeCompute for LocalProxyConnError {
}
struct TokioMechanism {
pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pool: Arc<GlobalConnPool<tokio_postgres::Client, EndpointConnPool<tokio_postgres::Client>>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
@@ -524,7 +525,7 @@ impl ConnectMechanism for TokioMechanism {
}
struct HyperMechanism {
pool: Arc<http_conn_pool::GlobalConnPool<Send>>,
pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,

View File

@@ -19,7 +19,8 @@ use {
};
use super::conn_pool_lib::{
Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, GlobalConnPool,
Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, EndpointConnPool,
GlobalConnPool,
};
use crate::context::RequestContext;
use crate::control_plane::messages::MetricsAuxInfo;
@@ -52,7 +53,7 @@ impl fmt::Display for ConnInfo {
}
pub(crate) fn poll_client<C: ClientInnerExt>(
global_pool: Arc<GlobalConnPool<C>>,
global_pool: Arc<GlobalConnPool<C, EndpointConnPool<C>>>,
ctx: &RequestContext,
conn_info: ConnInfo,
client: C,
@@ -167,6 +168,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
Client::new(inner, conn_info, pool_clone)
}
#[derive(Clone)]
pub(crate) struct ClientDataRemote {
session: tokio::sync::watch::Sender<uuid::Uuid>,
cancel: CancellationToken,

View File

@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::atomic::{self, AtomicUsize};
use std::sync::{Arc, Weak};
@@ -43,13 +44,14 @@ impl ConnInfo {
}
}
#[derive(Clone)]
pub(crate) enum ClientDataEnum {
Remote(ClientDataRemote),
Local(ClientDataLocal),
#[allow(dead_code)]
Http(ClientDataHttp),
}
#[derive(Clone)]
pub(crate) struct ClientInnerCommon<C: ClientInnerExt> {
pub(crate) inner: C,
pub(crate) aux: MetricsAuxInfo,
@@ -91,6 +93,7 @@ pub(crate) struct ConnPoolEntry<C: ClientInnerExt> {
pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
total_conns: usize,
/// max # connections per endpoint
max_conns: usize,
_guard: HttpEndpointPoolsGuard<'static>,
global_connections_count: Arc<AtomicUsize>,
@@ -317,24 +320,49 @@ impl<C: ClientInnerExt> DbUserConn<C> for DbUserConnPool<C> {
}
}
pub(crate) struct GlobalConnPool<C: ClientInnerExt> {
pub(crate) trait EndpointConnPoolExt<C: ClientInnerExt> {
fn clear_closed(&mut self) -> usize;
fn total_conns(&self) -> usize;
}
impl<C: ClientInnerExt> EndpointConnPoolExt<C> for EndpointConnPool<C> {
fn clear_closed(&mut self) -> usize {
let mut clients_removed: usize = 0;
for db_pool in self.pools.values_mut() {
clients_removed += db_pool.clear_closed_clients(&mut self.total_conns);
}
clients_removed
}
fn total_conns(&self) -> usize {
self.total_conns
}
}
pub(crate) struct GlobalConnPool<C, P>
where
C: ClientInnerExt,
P: EndpointConnPoolExt<C>,
{
// endpoint -> per-endpoint connection pool
//
// That should be a fairly contended map, so return a reference to the per-endpoint
// pool as early as possible and release the lock.
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,
pub(crate) global_pool: DashMap<EndpointCacheKey, Arc<RwLock<P>>>,
/// Number of endpoint-connection pools
///
/// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
/// That seems like far too much effort, so we're using a relaxed increment counter instead.
/// It's only used for diagnostics.
global_pool_size: AtomicUsize,
pub(crate) global_pool_size: AtomicUsize,
/// Total number of connections in the pool
global_connections_count: Arc<AtomicUsize>,
pub(crate) global_connections_count: Arc<AtomicUsize>,
config: &'static crate::config::HttpConfig,
pub(crate) config: &'static crate::config::HttpConfig,
_marker: PhantomData<C>,
}
#[derive(Debug, Clone, Copy)]
@@ -357,7 +385,11 @@ pub struct GlobalConnPoolOptions {
pub max_total_conns: usize,
}
impl<C: ClientInnerExt> GlobalConnPool<C> {
impl<C, P> GlobalConnPool<C, P>
where
C: ClientInnerExt,
P: EndpointConnPoolExt<C>,
{
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
let shards = config.pool_options.pool_shards;
Arc::new(Self {
@@ -365,6 +397,7 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
global_pool_size: AtomicUsize::new(0),
config,
global_connections_count: Arc::new(AtomicUsize::new(0)),
_marker: PhantomData,
})
}
@@ -378,6 +411,80 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
self.config.pool_options.idle_timeout
}
pub(crate) fn shutdown(&self) {
// drops all strong references to endpoint-pools
self.global_pool.clear();
}
pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
let epoch = self.config.pool_options.gc_epoch;
let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
loop {
interval.tick().await;
let shard = rng.gen_range(0..self.global_pool.shards().len());
self.gc(shard);
}
}
pub(crate) fn gc(&self, shard: usize) {
debug!(shard, "pool: performing epoch reclamation");
// acquire a random shard lock
let mut shard = self.global_pool.shards()[shard].write();
let timer = Metrics::get()
.proxy
.http_pool_reclaimation_lag_seconds
.start_timer();
let current_len = shard.len();
let mut clients_removed = 0;
shard.retain(|endpoint, x| {
// if the current endpoint pool is unique (no other strong or weak references)
// then it is currently not in use by any connections.
if let Some(pool) = Arc::get_mut(x.get_mut()) {
let endpoints = pool.get_mut();
clients_removed = endpoints.clear_closed();
if endpoints.total_conns() == 0 {
info!("pool: discarding pool for endpoint {endpoint}");
return false;
}
}
true
});
let new_len = shard.len();
drop(shard);
timer.observe();
// Do logging outside of the lock.
if clients_removed > 0 {
let size = self
.global_connections_count
.fetch_sub(clients_removed, atomic::Ordering::Relaxed)
- clients_removed;
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(clients_removed as i64);
info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
}
let removed = current_len - new_len;
if removed > 0 {
let global_pool_size = self
.global_pool_size
.fetch_sub(removed, atomic::Ordering::Relaxed)
- removed;
info!("pool: performed global pool gc. size now {global_pool_size}");
}
}
}
impl<C: ClientInnerExt> GlobalConnPool<C, EndpointConnPool<C>> {
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestContext,
@@ -432,85 +539,6 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
Ok(None)
}
pub(crate) fn shutdown(&self) {
// drops all strong references to endpoint-pools
self.global_pool.clear();
}
pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
let epoch = self.config.pool_options.gc_epoch;
let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
loop {
interval.tick().await;
let shard = rng.gen_range(0..self.global_pool.shards().len());
self.gc(shard);
}
}
pub(crate) fn gc(&self, shard: usize) {
debug!(shard, "pool: performing epoch reclamation");
// acquire a random shard lock
let mut shard = self.global_pool.shards()[shard].write();
let timer = Metrics::get()
.proxy
.http_pool_reclaimation_lag_seconds
.start_timer();
let current_len = shard.len();
let mut clients_removed = 0;
shard.retain(|endpoint, x| {
// if the current endpoint pool is unique (no other strong or weak references)
// then it is currently not in use by any connections.
if let Some(pool) = Arc::get_mut(x.get_mut()) {
let EndpointConnPool {
pools, total_conns, ..
} = pool.get_mut();
// ensure that closed clients are removed
for db_pool in pools.values_mut() {
clients_removed += db_pool.clear_closed_clients(total_conns);
}
// we only remove this pool if it has no active connections
if *total_conns == 0 {
info!("pool: discarding pool for endpoint {endpoint}");
return false;
}
}
true
});
let new_len = shard.len();
drop(shard);
timer.observe();
// Do logging outside of the lock.
if clients_removed > 0 {
let size = self
.global_connections_count
.fetch_sub(clients_removed, atomic::Ordering::Relaxed)
- clients_removed;
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(clients_removed as i64);
info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
}
let removed = current_len - new_len;
if removed > 0 {
let global_pool_size = self
.global_pool_size
.fetch_sub(removed, atomic::Ordering::Relaxed)
- removed;
info!("pool: performed global pool gc. size now {global_pool_size}");
}
}
pub(crate) fn get_or_create_endpoint_pool(
self: &Arc<Self>,
endpoint: &EndpointCacheKey,
@@ -556,7 +584,6 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
pool
}
}
pub(crate) struct Client<C: ClientInnerExt> {
span: Span,
inner: Option<ClientInnerCommon<C>>,

View File

@@ -2,16 +2,17 @@ use std::collections::VecDeque;
use std::sync::atomic::{self, AtomicUsize};
use std::sync::{Arc, Weak};
use dashmap::DashMap;
use hyper::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use rand::Rng;
use tokio::net::TcpStream;
use tracing::{debug, error, info, info_span, Instrument};
use super::backend::HttpConnError;
use super::conn_pool_lib::{ClientInnerExt, ConnInfo};
use super::conn_pool_lib::{
ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, ConnPoolEntry,
EndpointConnPoolExt, GlobalConnPool,
};
use crate::context::RequestContext;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
@@ -23,17 +24,11 @@ pub(crate) type Connect =
http2::Connection<TokioIo<TcpStream>, hyper::body::Incoming, TokioExecutor>;
#[derive(Clone)]
pub(crate) struct ConnPoolEntry<C: ClientInnerExt + Clone> {
conn: C,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
}
pub(crate) struct ClientDataHttp();
// Per-endpoint connection pool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub(crate) struct EndpointConnPool<C: ClientInnerExt + Clone> {
pub(crate) struct HttpConnPool<C: ClientInnerExt + Clone> {
// TODO(conrad):
// either we should open more connections depending on stream count
// (not exposed by hyper, need our own counter)
@@ -48,14 +43,19 @@ pub(crate) struct EndpointConnPool<C: ClientInnerExt + Clone> {
global_connections_count: Arc<AtomicUsize>,
}
impl<C: ClientInnerExt + Clone> EndpointConnPool<C> {
impl<C: ClientInnerExt + Clone> HttpConnPool<C> {
fn get_conn_entry(&mut self) -> Option<ConnPoolEntry<C>> {
let Self { conns, .. } = self;
loop {
let conn = conns.pop_front()?;
if !conn.conn.is_closed() {
conns.push_back(conn.clone());
if !conn.conn.inner.is_closed() {
let new_conn = ConnPoolEntry {
conn: conn.conn.clone(),
_last_access: std::time::Instant::now(),
};
conns.push_back(new_conn);
return Some(conn);
}
}
@@ -69,7 +69,7 @@ impl<C: ClientInnerExt + Clone> EndpointConnPool<C> {
} = self;
let old_len = conns.len();
conns.retain(|conn| conn.conn_id != conn_id);
conns.retain(|entry| entry.conn.conn_id != conn_id);
let new_len = conns.len();
let removed = old_len - new_len;
if removed > 0 {
@@ -84,7 +84,22 @@ impl<C: ClientInnerExt + Clone> EndpointConnPool<C> {
}
}
impl<C: ClientInnerExt + Clone> Drop for EndpointConnPool<C> {
impl<C: ClientInnerExt + Clone> EndpointConnPoolExt<C> for HttpConnPool<C> {
fn clear_closed(&mut self) -> usize {
let Self { conns, .. } = self;
let old_len = conns.len();
conns.retain(|entry| !entry.conn.inner.is_closed());
let new_len = conns.len();
old_len - new_len
}
fn total_conns(&self) -> usize {
self.conns.len()
}
}
impl<C: ClientInnerExt + Clone> Drop for HttpConnPool<C> {
fn drop(&mut self) {
if !self.conns.is_empty() {
self.global_connections_count
@@ -98,117 +113,7 @@ impl<C: ClientInnerExt + Clone> Drop for EndpointConnPool<C> {
}
}
pub(crate) struct GlobalConnPool<C: ClientInnerExt + Clone> {
// endpoint -> per-endpoint connection pool
//
// That should be a fairly contended map, so return a reference to the per-endpoint
// pool as early as possible and release the lock.
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,
/// Number of endpoint-connection pools
///
/// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
/// That seems like far too much effort, so we're using a relaxed increment counter instead.
/// It's only used for diagnostics.
global_pool_size: AtomicUsize,
/// Total number of connections in the pool
global_connections_count: Arc<AtomicUsize>,
config: &'static crate::config::HttpConfig,
}
impl<C: ClientInnerExt + Clone> GlobalConnPool<C> {
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
let shards = config.pool_options.pool_shards;
Arc::new(Self {
global_pool: DashMap::with_shard_amount(shards),
global_pool_size: AtomicUsize::new(0),
config,
global_connections_count: Arc::new(AtomicUsize::new(0)),
})
}
pub(crate) fn shutdown(&self) {
// drops all strong references to endpoint-pools
self.global_pool.clear();
}
pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
let epoch = self.config.pool_options.gc_epoch;
let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
loop {
interval.tick().await;
let shard = rng.gen_range(0..self.global_pool.shards().len());
self.gc(shard);
}
}
fn gc(&self, shard: usize) {
debug!(shard, "pool: performing epoch reclamation");
// acquire a random shard lock
let mut shard = self.global_pool.shards()[shard].write();
let timer = Metrics::get()
.proxy
.http_pool_reclaimation_lag_seconds
.start_timer();
let current_len = shard.len();
let mut clients_removed = 0;
shard.retain(|endpoint, x| {
// if the current endpoint pool is unique (no other strong or weak references)
// then it is currently not in use by any connections.
if let Some(pool) = Arc::get_mut(x.get_mut()) {
let EndpointConnPool { conns, .. } = pool.get_mut();
let old_len = conns.len();
conns.retain(|conn| !conn.conn.is_closed());
let new_len = conns.len();
let removed = old_len - new_len;
clients_removed += removed;
// we only remove this pool if it has no active connections
if conns.is_empty() {
info!("pool: discarding pool for endpoint {endpoint}");
return false;
}
}
true
});
let new_len = shard.len();
drop(shard);
timer.observe();
// Do logging outside of the lock.
if clients_removed > 0 {
let size = self
.global_connections_count
.fetch_sub(clients_removed, atomic::Ordering::Relaxed)
- clients_removed;
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(clients_removed as i64);
info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
}
let removed = current_len - new_len;
if removed > 0 {
let global_pool_size = self
.global_pool_size
.fetch_sub(removed, atomic::Ordering::Relaxed)
- removed;
info!("pool: performed global pool gc. size now {global_pool_size}");
}
}
impl<C: ClientInnerExt + Clone> GlobalConnPool<C, HttpConnPool<C>> {
#[expect(unused_results)]
pub(crate) fn get(
self: &Arc<Self>,
@@ -226,27 +131,28 @@ impl<C: ClientInnerExt + Clone> GlobalConnPool<C> {
return result;
};
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
tracing::Span::current().record("conn_id", tracing::field::display(client.conn.conn_id));
debug!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
);
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.success();
Ok(Some(Client::new(client.conn, client.aux)))
Ok(Some(Client::new(client.conn.clone())))
}
fn get_or_create_endpoint_pool(
self: &Arc<Self>,
endpoint: &EndpointCacheKey,
) -> Arc<RwLock<EndpointConnPool<C>>> {
) -> Arc<RwLock<HttpConnPool<C>>> {
// fast path
if let Some(pool) = self.global_pool.get(endpoint) {
return pool.clone();
}
// slow path
let new_pool = Arc::new(RwLock::new(EndpointConnPool {
let new_pool = Arc::new(RwLock::new(HttpConnPool {
conns: VecDeque::new(),
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
global_connections_count: self.global_connections_count.clone(),
@@ -279,7 +185,7 @@ impl<C: ClientInnerExt + Clone> GlobalConnPool<C> {
}
pub(crate) fn poll_http2_client(
global_pool: Arc<GlobalConnPool<Send>>,
global_pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
ctx: &RequestContext,
conn_info: &ConnInfo,
client: Send,
@@ -299,11 +205,15 @@ pub(crate) fn poll_http2_client(
let pool = match conn_info.endpoint_cache_key() {
Some(endpoint) => {
let pool = global_pool.get_or_create_endpoint_pool(&endpoint);
pool.write().conns.push_back(ConnPoolEntry {
conn: client.clone(),
conn_id,
let client = ClientInnerCommon {
inner: client.clone(),
aux: aux.clone(),
conn_id,
data: ClientDataEnum::Http(ClientDataHttp()),
};
pool.write().conns.push_back(ConnPoolEntry {
conn: client,
_last_access: std::time::Instant::now(),
});
Metrics::get()
.proxy
@@ -335,23 +245,30 @@ pub(crate) fn poll_http2_client(
.instrument(span),
);
Client::new(client, aux)
let client = ClientInnerCommon {
inner: client,
aux,
conn_id,
data: ClientDataEnum::Http(ClientDataHttp()),
};
Client::new(client)
}
pub(crate) struct Client<C: ClientInnerExt + Clone> {
pub(crate) inner: C,
aux: MetricsAuxInfo,
pub(crate) inner: ClientInnerCommon<C>,
}
impl<C: ClientInnerExt + Clone> Client<C> {
pub(self) fn new(inner: C, aux: MetricsAuxInfo) -> Self {
Self { inner, aux }
pub(self) fn new(inner: ClientInnerCommon<C>) -> Self {
Self { inner }
}
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
let aux = &self.inner.aux;
USAGE_METRICS.register(Ids {
endpoint_id: self.aux.endpoint_id,
branch_id: self.aux.branch_id,
endpoint_id: aux.endpoint_id,
branch_id: aux.branch_id,
})
}
}
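
To make the shared ConnPoolEntry work for HTTP, the checkout path has to respect HTTP/2 multiplexing: a live entry is cloned back into the queue before being handed out, which is why the Clone bound remains on C for this pool only. A rough, self-contained sketch of that path with stand-in types (uuid replaced by a plain integer, is_closed modelled as a small trait):

// Illustrative sketch of HttpConnPool::get_conn_entry with simplified types.
use std::collections::VecDeque;
use std::time::Instant;

trait Closable {
    fn is_closed(&self) -> bool;
}

#[derive(Clone)]
struct ClientInnerCommon<C: Clone> {
    inner: C,
    conn_id: u64, // uuid::Uuid in the real code
}

#[derive(Clone)]
struct ConnPoolEntry<C: Clone> {
    conn: ClientInnerCommon<C>,
    _last_access: Instant,
}

struct HttpConnPool<C: Clone + Closable> {
    conns: VecDeque<ConnPoolEntry<C>>,
}

impl<C: Clone + Closable> HttpConnPool<C> {
    // Pop entries until a live one is found. Because an HTTP/2 connection is
    // multiplexed, a clone of the live entry is pushed back immediately so
    // other requests can keep using the same connection.
    fn get_conn_entry(&mut self) -> Option<ConnPoolEntry<C>> {
        loop {
            let entry = self.conns.pop_front()?;
            if !entry.conn.inner.is_closed() {
                self.conns.push_back(ConnPoolEntry {
                    conn: entry.conn.clone(),
                    _last_access: Instant::now(),
                });
                return Some(entry);
            }
            // closed connections are simply dropped here
        }
    }
}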

View File

@@ -44,6 +44,7 @@ pub(crate) const EXT_NAME: &str = "pg_session_jwt";
pub(crate) const EXT_VERSION: &str = "0.1.2";
pub(crate) const EXT_SCHEMA: &str = "auth";
#[derive(Clone)]
pub(crate) struct ClientDataLocal {
session: tokio::sync::watch::Sender<uuid::Uuid>,
cancel: CancellationToken,

View File

@@ -88,7 +88,7 @@ pub async fn task_main(
}
});
let http_conn_pool = http_conn_pool::GlobalConnPool::new(&config.http_config);
let http_conn_pool = conn_pool_lib::GlobalConnPool::new(&config.http_config);
{
let http_conn_pool = Arc::clone(&http_conn_pool);
tokio::spawn(async move {

View File

@@ -779,6 +779,7 @@ async fn handle_auth_broker_inner(
let _metrics = client.metrics();
Ok(client
.inner
.inner
.send_request(req)
.await
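
The extra .inner in this call chain follows directly from the new client layout: http_conn_pool::Client<C> now wraps a ClientInnerCommon<C>, and the HTTP/2 sender lives in that struct's inner field, with the aux metrics info travelling alongside it instead of sitting on Client. A minimal sketch of that layering (field names from the diff; types simplified):

// Sketch of the layering behind `client.inner.inner.send_request(req)`.
// `C` stands for the proxy's `Send` alias (an HTTP/2 send handle).
struct MetricsAuxInfo; // placeholder

struct ClientInnerCommon<C> {
    inner: C,            // the actual HTTP/2 sender
    aux: MetricsAuxInfo, // metrics now live here, not on Client
    conn_id: u64,        // uuid::Uuid in the real code
}

struct Client<C> {
    inner: ClientInnerCommon<C>, // previously: `inner: C` plus a separate `aux` field
}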