simplify generics more

This commit is contained in:
Conrad Ludgate
2025-04-16 15:25:59 +01:00
parent cf05d4e4b2
commit 42e36ba5e8
4 changed files with 26 additions and 37 deletions

View File

@@ -42,10 +42,9 @@ use crate::rate_limiter::EndpointRateLimiter;
use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<GlobalConnPool<Send, HttpConnPool>>,
pub(crate) http_conn_pool: Arc<GlobalConnPool<HttpConnPool>>,
pub(crate) local_pool: Arc<LocalConnPool<postgres_client::Client>>,
pub(crate) pool:
Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
pub(crate) pool: Arc<GlobalConnPool<EndpointConnPool<postgres_client::Client>>>,
pub(crate) config: &'static ProxyConfig,
pub(crate) auth_backend: &'static crate::auth::Backend<'static, ()>,
@@ -248,7 +247,7 @@ impl PoolingBackend {
conn_info: ConnInfo,
) -> Result<http_conn_pool::Client<Send>, HttpConnError> {
debug!("pool: looking for an existing connection");
if let Ok(Some(client)) = self.http_conn_pool.get(ctx, &conn_info) {
if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) {
return Ok(client);
}
@@ -532,7 +531,7 @@ impl ShouldRetryWakeCompute for LocalProxyConnError {
}
struct TokioMechanism {
pool: Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
pool: Arc<GlobalConnPool<EndpointConnPool<postgres_client::Client>>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
@@ -593,7 +592,7 @@ impl ConnectMechanism for TokioMechanism {
}
struct HyperMechanism {
pool: Arc<GlobalConnPool<Send, HttpConnPool>>,
pool: Arc<GlobalConnPool<HttpConnPool>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,

View File

@@ -57,7 +57,7 @@ impl fmt::Display for ConnInfo {
}
pub(crate) fn poll_client<C: ClientInnerExt>(
global_pool: Arc<GlobalConnPool<C, EndpointConnPool<C>>>,
global_pool: Arc<GlobalConnPool<EndpointConnPool<C>>>,
ctx: &RequestContext,
conn_info: ConnInfo,
client: C,

View File

@@ -1,5 +1,4 @@
use std::collections::HashMap;
use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::atomic::{self, AtomicUsize};
use std::sync::{Arc, Weak};
@@ -326,12 +325,15 @@ impl<C: ClientInnerExt> DbUserConn<C> for DbUserConnPool<C> {
}
}
pub(crate) trait EndpointConnPoolExt<C: ClientInnerExt> {
pub(crate) trait EndpointConnPoolExt {
type Client;
fn clear_closed(&mut self) -> usize;
fn total_conns(&self) -> usize;
}
impl<C: ClientInnerExt> EndpointConnPoolExt<C> for EndpointConnPool<C> {
impl<C: ClientInnerExt> EndpointConnPoolExt for EndpointConnPool<C> {
type Client = Client<C>;
fn clear_closed(&mut self) -> usize {
let mut clients_removed: usize = 0;
for db_pool in self.pools.values_mut() {
@@ -345,10 +347,9 @@ impl<C: ClientInnerExt> EndpointConnPoolExt<C> for EndpointConnPool<C> {
}
}
pub(crate) struct GlobalConnPool<C, P>
pub(crate) struct GlobalConnPool<P>
where
C: ClientInnerExt,
P: EndpointConnPoolExt<C>,
P: EndpointConnPoolExt,
{
// endpoint -> per-endpoint connection pool
//
@@ -367,8 +368,6 @@ where
pub(crate) global_connections_count: Arc<AtomicUsize>,
pub(crate) config: &'static crate::config::HttpConfig,
_marker: PhantomData<C>,
}
#[derive(Debug, Clone, Copy)]
@@ -391,10 +390,9 @@ pub struct GlobalConnPoolOptions {
pub max_total_conns: usize,
}
impl<C, P> GlobalConnPool<C, P>
impl<P> GlobalConnPool<P>
where
C: ClientInnerExt,
P: EndpointConnPoolExt<C>,
P: EndpointConnPoolExt,
{
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
let shards = config.pool_options.pool_shards;
@@ -403,7 +401,6 @@ where
global_pool_size: AtomicUsize::new(0),
config,
global_connections_count: Arc::new(AtomicUsize::new(0)),
_marker: PhantomData,
})
}
@@ -492,7 +489,7 @@ where
}
}
impl<C: ClientInnerExt> GlobalConnPool<C, EndpointConnPool<C>> {
impl<C: ClientInnerExt> GlobalConnPool<EndpointConnPool<C>> {
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestContext,

View File

@@ -9,7 +9,6 @@ use smol_str::ToSmolStr;
use tracing::{Instrument, debug, error, info, info_span};
use super::AsyncRW;
use super::backend::HttpConnError;
use super::conn_pool_lib::{
ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, ConnPoolEntry,
EndpointConnPoolExt, GlobalConnPool,
@@ -85,7 +84,9 @@ impl HttpConnPool {
}
}
impl EndpointConnPoolExt<Send> for HttpConnPool {
impl EndpointConnPoolExt for HttpConnPool {
type Client = Client<Send>;
fn clear_closed(&mut self) -> usize {
let Self { conns, .. } = self;
let old_len = conns.len();
@@ -114,23 +115,15 @@ impl Drop for HttpConnPool {
}
}
impl GlobalConnPool<Send, HttpConnPool> {
#[expect(unused_results)]
impl GlobalConnPool<HttpConnPool> {
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestContext,
conn_info: &ConnInfo,
) -> Result<Option<Client<Send>>, HttpConnError> {
let result: Result<Option<Client<Send>>, HttpConnError>;
let Some(endpoint) = conn_info.endpoint_cache_key() else {
result = Ok(None);
return result;
};
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
let Some(client) = endpoint_pool.write().get_conn_entry() else {
result = Ok(None);
return result;
};
) -> Option<Client<Send>> {
let endpoint = conn_info.endpoint_cache_key()?;
let endpoint_pool = self.global_pool.get(&endpoint)?.clone();
let client = endpoint_pool.write().get_conn_entry()?;
tracing::Span::current().record("conn_id", tracing::field::display(client.conn.conn_id));
debug!(
@@ -140,7 +133,7 @@ impl GlobalConnPool<Send, HttpConnPool> {
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.success();
Ok(Some(Client::new(client.conn.clone())))
Some(Client::new(client.conn.clone()))
}
fn get_or_create_endpoint_pool(
@@ -186,7 +179,7 @@ impl GlobalConnPool<Send, HttpConnPool> {
}
pub(crate) fn poll_http2_client(
global_pool: Arc<GlobalConnPool<Send, HttpConnPool>>,
global_pool: Arc<GlobalConnPool<HttpConnPool>>,
ctx: &RequestContext,
conn_info: &ConnInfo,
client: Send,