proxy: Unify local and remote conn pool client structures (#9604)
Unify Client, EndpointConnPool and DbUserConnPool for remote and local connections.

- Use the new ClientDataEnum for additional client data.
- Add the ClientInnerCommon client structure.
- Remove the Client and EndpointConnPool code from local_conn_pool.rs.
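The diff below spans the serverless proxy's backend, conn_pool, conn_pool_lib, http_conn_pool, local_conn_pool and sql_over_http modules (this mirror has dropped the per-file headers). The pivot of the change is the pair of types introduced in conn_pool_lib.rs; condensed here for orientation, with my own field comments, everything else taken from the hunks below:

    // Flavor-specific state carried by every pooled client.
    pub(crate) enum ClientDataEnum {
        Remote(ClientDataRemote), // session watch channel + cancellation token
        Local(ClientDataLocal),   // session, cancel, plus JWT signing key and jti counter
        Http(ClientDataHttp),     // no extra state yet
    }

    // The one client wrapper now shared by the remote and local pools.
    pub(crate) struct ClientInnerCommon<C: ClientInnerExt> {
        pub(crate) inner: C,             // e.g. tokio_postgres::Client
        pub(crate) aux: MetricsAuxInfo,  // endpoint/branch ids for usage metrics
        pub(crate) conn_id: uuid::Uuid,
        pub(crate) data: ClientDataEnum, // custom client data like session, key, jti
    }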
@@ -14,7 +14,7 @@ use tracing::{debug, info};
 use super::conn_pool::poll_client;
 use super::conn_pool_lib::{Client, ConnInfo, GlobalConnPool};
 use super::http_conn_pool::{self, poll_http2_client, Send};
-use super::local_conn_pool::{self, LocalClient, LocalConnPool, EXT_NAME, EXT_SCHEMA, EXT_VERSION};
+use super::local_conn_pool::{self, LocalConnPool, EXT_NAME, EXT_SCHEMA, EXT_VERSION};
 use crate::auth::backend::local::StaticAuthRules;
 use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
 use crate::auth::{self, check_peer_addr_is_in_list, AuthError};
@@ -205,7 +205,7 @@ impl PoolingBackend {
         conn_info: ConnInfo,
     ) -> Result<http_conn_pool::Client<Send>, HttpConnError> {
         info!("pool: looking for an existing connection");
-        if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) {
+        if let Ok(Some(client)) = self.http_conn_pool.get(ctx, &conn_info) {
             return Ok(client);
         }
@@ -248,7 +248,7 @@ impl PoolingBackend {
         &self,
         ctx: &RequestMonitoring,
         conn_info: ConnInfo,
-    ) -> Result<LocalClient<tokio_postgres::Client>, HttpConnError> {
+    ) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
         if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
             return Ok(client);
         }
@@ -18,7 +18,9 @@ use {
    std::{sync::atomic, time::Duration},
 };

-use super::conn_pool_lib::{Client, ClientInnerExt, ConnInfo, GlobalConnPool};
+use super::conn_pool_lib::{
+    Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, GlobalConnPool,
+};
 use crate::context::RequestMonitoring;
 use crate::control_plane::messages::MetricsAuxInfo;
 use crate::metrics::Metrics;
@@ -152,53 +154,30 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
     }
     .instrument(span));
-    let inner = ClientInnerRemote {
+    let inner = ClientInnerCommon {
         inner: client,
-        session: tx,
-        cancel,
         aux,
         conn_id,
+        data: ClientDataEnum::Remote(ClientDataRemote {
+            session: tx,
+            cancel,
+        }),
     };

     Client::new(inner, conn_info, pool_clone)
 }

-pub(crate) struct ClientInnerRemote<C: ClientInnerExt> {
-    inner: C,
+pub(crate) struct ClientDataRemote {
     session: tokio::sync::watch::Sender<uuid::Uuid>,
     cancel: CancellationToken,
-    aux: MetricsAuxInfo,
-    conn_id: uuid::Uuid,
 }

-impl<C: ClientInnerExt> ClientInnerRemote<C> {
-    pub(crate) fn inner_mut(&mut self) -> &mut C {
-        &mut self.inner
-    }
-
-    pub(crate) fn inner(&self) -> &C {
-        &self.inner
-    }
-
-    pub(crate) fn session(&mut self) -> &mut tokio::sync::watch::Sender<uuid::Uuid> {
+impl ClientDataRemote {
+    pub fn session(&mut self) -> &mut tokio::sync::watch::Sender<uuid::Uuid> {
         &mut self.session
     }

-    pub(crate) fn aux(&self) -> &MetricsAuxInfo {
-        &self.aux
-    }
-
-    pub(crate) fn get_conn_id(&self) -> uuid::Uuid {
-        self.conn_id
-    }
-
-    pub(crate) fn is_closed(&self) -> bool {
-        self.inner.is_closed()
-    }
-}
-
-impl<C: ClientInnerExt> Drop for ClientInnerRemote<C> {
-    fn drop(&mut self) {
-        // on client drop, tell the conn to shut down
+    pub fn cancel(&mut self) {
         self.cancel.cancel();
     }
 }
@@ -228,15 +207,13 @@ mod tests {
         }
     }

-    fn create_inner() -> ClientInnerRemote<MockClient> {
+    fn create_inner() -> ClientInnerCommon<MockClient> {
         create_inner_with(MockClient::new(false))
     }

-    fn create_inner_with(client: MockClient) -> ClientInnerRemote<MockClient> {
-        ClientInnerRemote {
+    fn create_inner_with(client: MockClient) -> ClientInnerCommon<MockClient> {
+        ClientInnerCommon {
             inner: client,
-            session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
-            cancel: CancellationToken::new(),
             aux: MetricsAuxInfo {
                 endpoint_id: (&EndpointId::from("endpoint")).into(),
                 project_id: (&ProjectId::from("project")).into(),
@@ -244,6 +221,10 @@ mod tests {
                 cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
             },
             conn_id: uuid::Uuid::new_v4(),
+            data: ClientDataEnum::Remote(ClientDataRemote {
+                session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
+                cancel: CancellationToken::new(),
+            }),
         }
     }
@@ -280,7 +261,7 @@ mod tests {
         {
             let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
             assert_eq!(0, pool.get_global_connections_count());
-            client.inner_mut().1.discard();
+            client.inner().1.discard();
             // Discard should not add the connection from the pool.
             assert_eq!(0, pool.get_global_connections_count());
         }
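Both ClientDataRemote above and ClientDataLocal later in this diff keep a tokio::sync::watch::Sender<uuid::Uuid>; on pool reuse the pool pushes the new request's session id through it so the connection's background task can re-tag its logging. A minimal, self-contained sketch of that handoff, with illustrative names that are not part of the patch:

    use tokio::sync::watch;
    use uuid::Uuid;

    #[tokio::main]
    async fn main() {
        // One channel per pooled connection, carrying the current session id.
        let (tx, mut rx) = watch::channel(Uuid::new_v4());

        // Stand-in for the connection's background poll task.
        let task = tokio::spawn(async move {
            while rx.changed().await.is_ok() {
                println!("connection now serving session {}", *rx.borrow());
            }
        });

        // Stand-in for `data.session().send(ctx.session_id())?` on pool reuse.
        tx.send(Uuid::new_v4()).unwrap();

        drop(tx); // closing the sender ends the background loop
        task.await.unwrap();
    }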
@@ -11,10 +11,13 @@ use tokio_postgres::ReadyForQueryStatus;
 use tracing::{debug, info, Span};

 use super::backend::HttpConnError;
-use super::conn_pool::ClientInnerRemote;
+use super::conn_pool::ClientDataRemote;
+use super::http_conn_pool::ClientDataHttp;
+use super::local_conn_pool::ClientDataLocal;
 use crate::auth::backend::ComputeUserInfo;
 use crate::context::RequestMonitoring;
 use crate::control_plane::messages::ColdStartInfo;
 use crate::control_plane::messages::MetricsAuxInfo;
 use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
 use crate::types::{DbName, EndpointCacheKey, RoleName};
 use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
@@ -41,8 +44,46 @@ impl ConnInfo {
     }
 }

+pub(crate) enum ClientDataEnum {
+    Remote(ClientDataRemote),
+    Local(ClientDataLocal),
+    #[allow(dead_code)]
+    Http(ClientDataHttp),
+}
+
+pub(crate) struct ClientInnerCommon<C: ClientInnerExt> {
+    pub(crate) inner: C,
+    pub(crate) aux: MetricsAuxInfo,
+    pub(crate) conn_id: uuid::Uuid,
+    pub(crate) data: ClientDataEnum, // custom client data like session, key, jti
+}
+
+impl<C: ClientInnerExt> Drop for ClientInnerCommon<C> {
+    fn drop(&mut self) {
+        match &mut self.data {
+            ClientDataEnum::Remote(remote_data) => {
+                remote_data.cancel();
+            }
+            ClientDataEnum::Local(local_data) => {
+                local_data.cancel();
+            }
+            ClientDataEnum::Http(_http_data) => (),
+        }
+    }
+}
+
+impl<C: ClientInnerExt> ClientInnerCommon<C> {
+    pub(crate) fn get_conn_id(&self) -> uuid::Uuid {
+        self.conn_id
+    }
+
+    pub(crate) fn get_data(&mut self) -> &mut ClientDataEnum {
+        &mut self.data
+    }
+}
+
 pub(crate) struct ConnPoolEntry<C: ClientInnerExt> {
-    pub(crate) conn: ClientInnerRemote<C>,
+    pub(crate) conn: ClientInnerCommon<C>,
     pub(crate) _last_access: std::time::Instant,
 }
@@ -55,10 +96,33 @@ pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
     _guard: HttpEndpointPoolsGuard<'static>,
     global_connections_count: Arc<AtomicUsize>,
     global_pool_size_max_conns: usize,
+    pool_name: String,
 }

 impl<C: ClientInnerExt> EndpointConnPool<C> {
-    fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
+    pub(crate) fn new(
+        hmap: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
+        tconns: usize,
+        max_conns_per_endpoint: usize,
+        global_connections_count: Arc<AtomicUsize>,
+        max_total_conns: usize,
+        pname: String,
+    ) -> Self {
+        Self {
+            pools: hmap,
+            total_conns: tconns,
+            max_conns: max_conns_per_endpoint,
+            _guard: Metrics::get().proxy.http_endpoint_pools.guard(),
+            global_connections_count,
+            global_pool_size_max_conns: max_total_conns,
+            pool_name: pname,
+        }
+    }
+
+    pub(crate) fn get_conn_entry(
+        &mut self,
+        db_user: (DbName, RoleName),
+    ) -> Option<ConnPoolEntry<C>> {
         let Self {
             pools,
             total_conns,
@@ -84,9 +148,10 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
             ..
         } = self;
         if let Some(pool) = pools.get_mut(&db_user) {
-            let old_len = pool.conns.len();
-            pool.conns.retain(|conn| conn.conn.get_conn_id() != conn_id);
-            let new_len = pool.conns.len();
+            let old_len = pool.get_conns().len();
+            pool.get_conns()
+                .retain(|conn| conn.conn.get_conn_id() != conn_id);
+            let new_len = pool.get_conns().len();
             let removed = old_len - new_len;
             if removed > 0 {
                 global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
@@ -103,11 +168,26 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
         }
     }

-    pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInnerRemote<C>) {
-        let conn_id = client.get_conn_id();
+    pub(crate) fn get_name(&self) -> &str {
+        &self.pool_name
+    }

-        if client.is_closed() {
-            info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
+    pub(crate) fn get_pool(&self, db_user: (DbName, RoleName)) -> Option<&DbUserConnPool<C>> {
+        self.pools.get(&db_user)
+    }
+
+    pub(crate) fn get_pool_mut(
+        &mut self,
+        db_user: (DbName, RoleName),
+    ) -> Option<&mut DbUserConnPool<C>> {
+        self.pools.get_mut(&db_user)
+    }
+
+    pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInnerCommon<C>) {
+        let conn_id = client.get_conn_id();
+        let pool_name = pool.read().get_name().to_string();
+        if client.inner.is_closed() {
+            info!(%conn_id, "{}: throwing away connection '{conn_info}' because connection is closed", pool_name);
             return;
         }
@@ -118,7 +198,7 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
             .load(atomic::Ordering::Relaxed)
             >= global_max_conn
         {
-            info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
+            info!(%conn_id, "{}: throwing away connection '{conn_info}' because pool is full", pool_name);
             return;
         }
@@ -130,13 +210,13 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {

             if pool.total_conns < pool.max_conns {
                 let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
-                pool_entries.conns.push(ConnPoolEntry {
+                pool_entries.get_conns().push(ConnPoolEntry {
                     conn: client,
                     _last_access: std::time::Instant::now(),
                 });

                 returned = true;
-                per_db_size = pool_entries.conns.len();
+                per_db_size = pool_entries.get_conns().len();

                 pool.total_conns += 1;
                 pool.global_connections_count
@@ -153,9 +233,9 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {

         // do logging outside of the mutex
         if returned {
-            info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
+            info!(%conn_id, "{pool_name}: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
         } else {
-            info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
+            info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
         }
     }
 }
@@ -176,19 +256,39 @@ impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {

 pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
     pub(crate) conns: Vec<ConnPoolEntry<C>>,
+    pub(crate) initialized: Option<bool>, // a bit ugly, exists only for local pools
 }

 impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
     fn default() -> Self {
-        Self { conns: Vec::new() }
+        Self {
+            conns: Vec::new(),
+            initialized: None,
+        }
     }
 }

-impl<C: ClientInnerExt> DbUserConnPool<C> {
+pub(crate) trait DbUserConn<C: ClientInnerExt>: Default {
+    fn set_initialized(&mut self);
+    fn is_initialized(&self) -> bool;
+    fn clear_closed_clients(&mut self, conns: &mut usize) -> usize;
+    fn get_conn_entry(&mut self, conns: &mut usize) -> (Option<ConnPoolEntry<C>>, usize);
+    fn get_conns(&mut self) -> &mut Vec<ConnPoolEntry<C>>;
+}
+
+impl<C: ClientInnerExt> DbUserConn<C> for DbUserConnPool<C> {
+    fn set_initialized(&mut self) {
+        self.initialized = Some(true);
+    }
+
+    fn is_initialized(&self) -> bool {
+        self.initialized.unwrap_or(false)
+    }
+
     fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
         let old_len = self.conns.len();

-        self.conns.retain(|conn| !conn.conn.is_closed());
+        self.conns.retain(|conn| !conn.conn.inner.is_closed());

         let new_len = self.conns.len();
         let removed = old_len - new_len;
@@ -196,10 +296,7 @@ impl<C: ClientInnerExt> DbUserConnPool<C> {
         removed
     }

-    pub(crate) fn get_conn_entry(
-        &mut self,
-        conns: &mut usize,
-    ) -> (Option<ConnPoolEntry<C>>, usize) {
+    fn get_conn_entry(&mut self, conns: &mut usize) -> (Option<ConnPoolEntry<C>>, usize) {
         let mut removed = self.clear_closed_clients(conns);
         let conn = self.conns.pop();
         if conn.is_some() {
@@ -215,6 +312,10 @@ impl<C: ClientInnerExt> DbUserConnPool<C> {

         (conn, removed)
     }
+
+    fn get_conns(&mut self) -> &mut Vec<ConnPoolEntry<C>> {
+        &mut self.conns
+    }
 }

 pub(crate) struct GlobalConnPool<C: ClientInnerExt> {
@@ -278,6 +379,60 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
         self.config.pool_options.idle_timeout
     }

+    pub(crate) fn get(
+        self: &Arc<Self>,
+        ctx: &RequestMonitoring,
+        conn_info: &ConnInfo,
+    ) -> Result<Option<Client<C>>, HttpConnError> {
+        let mut client: Option<ClientInnerCommon<C>> = None;
+        let Some(endpoint) = conn_info.endpoint_cache_key() else {
+            return Ok(None);
+        };
+
+        let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
+        if let Some(entry) = endpoint_pool
+            .write()
+            .get_conn_entry(conn_info.db_and_user())
+        {
+            client = Some(entry.conn);
+        }
+        let endpoint_pool = Arc::downgrade(&endpoint_pool);
+
+        // ok return cached connection if found and establish a new one otherwise
+        if let Some(mut client) = client {
+            if client.inner.is_closed() {
+                info!("pool: cached connection '{conn_info}' is closed, opening a new one");
+                return Ok(None);
+            }
+            tracing::Span::current()
+                .record("conn_id", tracing::field::display(client.get_conn_id()));
+            tracing::Span::current().record(
+                "pid",
+                tracing::field::display(client.inner.get_process_id()),
+            );
+            info!(
+                cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
+                "pool: reusing connection '{conn_info}'"
+            );
+
+            match client.get_data() {
+                ClientDataEnum::Local(data) => {
+                    data.session().send(ctx.session_id())?;
+                }
+
+                ClientDataEnum::Remote(data) => {
+                    data.session().send(ctx.session_id())?;
+                }
+                ClientDataEnum::Http(_) => (),
+            }
+
+            ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
+            ctx.success();
+            return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
+        }
+        Ok(None)
+    }
+
     pub(crate) fn shutdown(&self) {
         // drops all strong references to endpoint-pools
         self.global_pool.clear();
@@ -374,6 +529,7 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
                 _guard: Metrics::get().proxy.http_endpoint_pools.guard(),
                 global_connections_count: self.global_connections_count.clone(),
                 global_pool_size_max_conns: self.config.pool_options.max_total_conns,
+                pool_name: String::from("remote"),
             }));

         // find or create a pool for this endpoint
@@ -400,55 +556,23 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {

         pool
     }
 }

-    pub(crate) fn get(
-        self: &Arc<Self>,
-        ctx: &RequestMonitoring,
-        conn_info: &ConnInfo,
-    ) -> Result<Option<Client<C>>, HttpConnError> {
-        let mut client: Option<ClientInnerRemote<C>> = None;
-        let Some(endpoint) = conn_info.endpoint_cache_key() else {
-            return Ok(None);
-        };
+pub(crate) struct Client<C: ClientInnerExt> {
+    span: Span,
+    inner: Option<ClientInnerCommon<C>>,
+    conn_info: ConnInfo,
+    pool: Weak<RwLock<EndpointConnPool<C>>>,
+}

-        let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
-        if let Some(entry) = endpoint_pool
-            .write()
-            .get_conn_entry(conn_info.db_and_user())
-        {
-            client = Some(entry.conn);
-        }
-        let endpoint_pool = Arc::downgrade(&endpoint_pool);
-
-        // ok return cached connection if found and establish a new one otherwise
-        if let Some(mut client) = client {
-            if client.is_closed() {
-                info!("pool: cached connection '{conn_info}' is closed, opening a new one");
-                return Ok(None);
-            }
-            tracing::Span::current()
-                .record("conn_id", tracing::field::display(client.get_conn_id()));
-            tracing::Span::current().record(
-                "pid",
-                tracing::field::display(client.inner().get_process_id()),
-            );
-            info!(
-                cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
-                "pool: reusing connection '{conn_info}'"
-            );
-
-            client.session().send(ctx.session_id())?;
-            ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
-            ctx.success();
-            return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
-        }
-        Ok(None)
-    }
+pub(crate) struct Discard<'a, C: ClientInnerExt> {
+    conn_info: &'a ConnInfo,
+    pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
+}

 impl<C: ClientInnerExt> Client<C> {
     pub(crate) fn new(
-        inner: ClientInnerRemote<C>,
+        inner: ClientInnerCommon<C>,
         conn_info: ConnInfo,
         pool: Weak<RwLock<EndpointConnPool<C>>>,
     ) -> Self {
@@ -460,7 +584,18 @@ impl<C: ClientInnerExt> Client<C> {
         }
     }

-    pub(crate) fn inner_mut(&mut self) -> (&mut C, Discard<'_, C>) {
+    pub(crate) fn client_inner(&mut self) -> (&mut ClientInnerCommon<C>, Discard<'_, C>) {
         let Self {
             inner,
             pool,
             conn_info,
             span: _,
         } = self;
+        let inner_m = inner.as_mut().expect("client inner should not be removed");
+        (inner_m, Discard { conn_info, pool })
+    }
+
+    pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
+        let Self {
+            inner,
+            pool,
@@ -468,12 +603,11 @@ impl<C: ClientInnerExt> Client<C> {
             span: _,
         } = self;
         let inner = inner.as_mut().expect("client inner should not be removed");
-        let inner_ref = inner.inner_mut();
-        (inner_ref, Discard { conn_info, pool })
+        (&mut inner.inner, Discard { conn_info, pool })
     }

     pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
-        let aux = &self.inner.as_ref().unwrap().aux();
+        let aux = &self.inner.as_ref().unwrap().aux;
         USAGE_METRICS.register(Ids {
             endpoint_id: aux.endpoint_id,
             branch_id: aux.branch_id,
@@ -498,13 +632,6 @@ impl<C: ClientInnerExt> Client<C> {
     }
 }

-pub(crate) struct Client<C: ClientInnerExt> {
-    span: Span,
-    inner: Option<ClientInnerRemote<C>>,
-    conn_info: ConnInfo,
-    pool: Weak<RwLock<EndpointConnPool<C>>>,
-}
-
 impl<C: ClientInnerExt> Drop for Client<C> {
     fn drop(&mut self) {
         if let Some(drop) = self.do_drop() {
@@ -517,10 +644,11 @@ impl<C: ClientInnerExt> Deref for Client<C> {
     type Target = C;

     fn deref(&self) -> &Self::Target {
-        self.inner
+        &self
+            .inner
             .as_ref()
             .expect("client inner should not be removed")
-            .inner()
+            .inner
     }
 }
@@ -539,11 +667,6 @@ impl ClientInnerExt for tokio_postgres::Client {
     }
 }

-pub(crate) struct Discard<'a, C: ClientInnerExt> {
-    conn_info: &'a ConnInfo,
-    pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
-}
-
 impl<C: ClientInnerExt> Discard<'_, C> {
     pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
         let conn_info = &self.conn_info;
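Throughout conn_pool_lib.rs the Client holds only a Weak reference to its EndpointConnPool, so a client dropped after pool shutdown silently discards its connection instead of keeping the pool alive. A minimal sketch of this drop-time return pattern, with simplified types that are not the actual implementation:

    use std::sync::{Arc, Mutex, Weak};

    struct Conn; // stand-in for a database connection
    struct Pool {
        conns: Mutex<Vec<Conn>>,
    }

    struct Client {
        conn: Option<Conn>,
        pool: Weak<Pool>, // weak so clients never keep a dead pool alive
    }

    impl Drop for Client {
        fn drop(&mut self) {
            let conn = self.conn.take().expect("connection taken twice");
            // Return the connection only if the pool still exists.
            if let Some(pool) = self.pool.upgrade() {
                pool.conns.lock().unwrap().push(conn);
            } // otherwise the connection is closed by simply being dropped
        }
    }

    fn main() {
        let pool = Arc::new(Pool { conns: Mutex::new(Vec::new()) });
        {
            let _client = Client { conn: Some(Conn), pool: Arc::downgrade(&pool) };
        } // client dropped here: its connection goes back to the pool
        assert_eq!(pool.conns.lock().unwrap().len(), 1);
    }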
@@ -7,9 +7,11 @@ use hyper::client::conn::http2;
 use hyper_util::rt::{TokioExecutor, TokioIo};
 use parking_lot::RwLock;
 use rand::Rng;
+use std::result::Result::Ok;
 use tokio::net::TcpStream;
 use tracing::{debug, error, info, info_span, Instrument};

+use super::backend::HttpConnError;
 use super::conn_pool_lib::{ClientInnerExt, ConnInfo};
 use crate::context::RequestMonitoring;
 use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
@@ -28,6 +30,8 @@ pub(crate) struct ConnPoolEntry<C: ClientInnerExt + Clone> {
     aux: MetricsAuxInfo,
 }

+pub(crate) struct ClientDataHttp();
+
 // Per-endpoint connection pool
 // Number of open connections is limited by the `max_conns_per_endpoint`.
 pub(crate) struct EndpointConnPool<C: ClientInnerExt + Clone> {
@@ -206,14 +210,22 @@ impl<C: ClientInnerExt + Clone> GlobalConnPool<C> {
         }
     }

+    #[expect(unused_results)]
     pub(crate) fn get(
         self: &Arc<Self>,
         ctx: &RequestMonitoring,
         conn_info: &ConnInfo,
-    ) -> Option<Client<C>> {
-        let endpoint = conn_info.endpoint_cache_key()?;
+    ) -> Result<Option<Client<C>>, HttpConnError> {
+        let result: Result<Option<Client<C>>, HttpConnError>;
+        let Some(endpoint) = conn_info.endpoint_cache_key() else {
+            result = Ok(None);
+            return result;
+        };
         let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
-        let client = endpoint_pool.write().get_conn_entry()?;
+        let Some(client) = endpoint_pool.write().get_conn_entry() else {
+            result = Ok(None);
+            return result;
+        };

         tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
         info!(
@@ -222,7 +234,7 @@ impl<C: ClientInnerExt + Clone> GlobalConnPool<C> {
         );
         ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
         ctx.success();
-        Some(Client::new(client.conn, client.aux))
+        Ok(Some(Client::new(client.conn, client.aux)))
     }

     fn get_or_create_endpoint_pool(
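The http pool hunks above change get() from Option<Client<C>> to Result<Option<Client<C>>, HttpConnError> so its signature lines up with the other pools, trading `?` short-circuits for explicit let-else returns. The shape of that refactor on a toy lookup (hypothetical names; the error branch is never actually taken here):

    use std::collections::HashMap;

    // Before: `?` on Option short-circuits straight to None.
    fn lookup_v1(map: &HashMap<String, i32>, key: Option<&str>) -> Option<i32> {
        let key = key?;
        let value = map.get(key)?;
        Some(*value)
    }

    #[derive(Debug)]
    struct LookupError; // placeholder for a real error type

    // After: same control flow, but the signature has room for errors.
    fn lookup_v2(map: &HashMap<String, i32>, key: Option<&str>) -> Result<Option<i32>, LookupError> {
        let Some(key) = key else {
            return Ok(None); // absence is not an error
        };
        let Some(value) = map.get(key) else {
            return Ok(None);
        };
        Ok(Some(*value))
    }

    fn main() {
        let mut m = HashMap::new();
        m.insert("k".to_string(), 7);
        assert_eq!(lookup_v1(&m, Some("k")), Some(7));
        assert_eq!(lookup_v2(&m, None).unwrap(), None);
    }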
@@ -11,7 +11,8 @@

 use std::collections::HashMap;
 use std::pin::pin;
-use std::sync::{Arc, Weak};
+use std::sync::atomic::AtomicUsize;
+use std::sync::Arc;
 use std::task::{ready, Poll};
 use std::time::Duration;
@@ -26,177 +27,42 @@ use signature::Signer;
 use tokio::time::Instant;
 use tokio_postgres::tls::NoTlsStream;
 use tokio_postgres::types::ToSql;
-use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
+use tokio_postgres::{AsyncMessage, Socket};
 use tokio_util::sync::CancellationToken;
-use tracing::{error, info, info_span, warn, Instrument, Span};
+use tracing::{error, info, info_span, warn, Instrument};

 use super::backend::HttpConnError;
-use super::conn_pool_lib::{ClientInnerExt, ConnInfo};
+use super::conn_pool_lib::{
+    Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, DbUserConn,
+    EndpointConnPool,
+};
 use crate::context::RequestMonitoring;
 use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
 use crate::metrics::Metrics;
-use crate::types::{DbName, RoleName};
 use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};

 pub(crate) const EXT_NAME: &str = "pg_session_jwt";
 pub(crate) const EXT_VERSION: &str = "0.1.2";
 pub(crate) const EXT_SCHEMA: &str = "auth";

-struct ConnPoolEntry<C: ClientInnerExt> {
-    conn: ClientInner<C>,
-    _last_access: std::time::Instant,
+pub(crate) struct ClientDataLocal {
+    session: tokio::sync::watch::Sender<uuid::Uuid>,
+    cancel: CancellationToken,
+    key: SigningKey,
+    jti: u64,
 }

-// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
-// Number of open connections is limited by the `max_conns_per_endpoint`.
-pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
-    pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
-    total_conns: usize,
-    max_conns: usize,
-    global_pool_size_max_conns: usize,
-}
-
-impl<C: ClientInnerExt> EndpointConnPool<C> {
-    fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
-        let Self {
-            pools, total_conns, ..
-        } = self;
-        pools
-            .get_mut(&db_user)
-            .and_then(|pool_entries| pool_entries.get_conn_entry(total_conns))
+impl ClientDataLocal {
+    pub fn session(&mut self) -> &mut tokio::sync::watch::Sender<uuid::Uuid> {
+        &mut self.session
     }

-    fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
-        let Self {
-            pools, total_conns, ..
-        } = self;
-        if let Some(pool) = pools.get_mut(&db_user) {
-            let old_len = pool.conns.len();
-            pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
-            let new_len = pool.conns.len();
-            let removed = old_len - new_len;
-            if removed > 0 {
-                Metrics::get()
-                    .proxy
-                    .http_pool_opened_connections
-                    .get_metric()
-                    .dec_by(removed as i64);
-            }
-            *total_conns -= removed;
-            removed > 0
-        } else {
-            false
-        }
-    }
-
-    fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
-        let conn_id = client.conn_id;
-
-        if client.is_closed() {
-            info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because connection is closed");
-            return;
-        }
-        let global_max_conn = pool.read().global_pool_size_max_conns;
-        if pool.read().total_conns >= global_max_conn {
-            info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because pool is full");
-            return;
-        }
-
-        // return connection to the pool
-        let mut returned = false;
-        let mut per_db_size = 0;
-        let total_conns = {
-            let mut pool = pool.write();
-
-            if pool.total_conns < pool.max_conns {
-                let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
-                pool_entries.conns.push(ConnPoolEntry {
-                    conn: client,
-                    _last_access: std::time::Instant::now(),
-                });
-
-                returned = true;
-                per_db_size = pool_entries.conns.len();
-
-                pool.total_conns += 1;
-                Metrics::get()
-                    .proxy
-                    .http_pool_opened_connections
-                    .get_metric()
-                    .inc();
-            }
-
-            pool.total_conns
-        };
-
-        // do logging outside of the mutex
-        if returned {
-            info!(%conn_id, "local_pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
-        } else {
-            info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
-        }
-    }
-}
-
-impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
-    fn drop(&mut self) {
-        if self.total_conns > 0 {
-            Metrics::get()
-                .proxy
-                .http_pool_opened_connections
-                .get_metric()
-                .dec_by(self.total_conns as i64);
-        }
-    }
-}
-
-pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
-    conns: Vec<ConnPoolEntry<C>>,
-
-    // true if we have definitely installed the extension and
-    // granted the role access to the auth schema.
-    initialized: bool,
-}
-
-impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
-    fn default() -> Self {
-        Self {
-            conns: Vec::new(),
-            initialized: false,
-        }
-    }
-}
-
-impl<C: ClientInnerExt> DbUserConnPool<C> {
-    fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
-        let old_len = self.conns.len();
-
-        self.conns.retain(|conn| !conn.conn.is_closed());
-
-        let new_len = self.conns.len();
-        let removed = old_len - new_len;
-        *conns -= removed;
-        removed
-    }
-
-    fn get_conn_entry(&mut self, conns: &mut usize) -> Option<ConnPoolEntry<C>> {
-        let mut removed = self.clear_closed_clients(conns);
-        let conn = self.conns.pop();
-        if conn.is_some() {
-            *conns -= 1;
-            removed += 1;
-        }
-        Metrics::get()
-            .proxy
-            .http_pool_opened_connections
-            .get_metric()
-            .dec_by(removed as i64);
-        conn
+    pub fn cancel(&mut self) {
+        self.cancel.cancel();
     }
 }

 pub(crate) struct LocalConnPool<C: ClientInnerExt> {
-    global_pool: RwLock<EndpointConnPool<C>>,
+    global_pool: Arc<RwLock<EndpointConnPool<C>>>,

     config: &'static crate::config::HttpConfig,
 }
@@ -204,12 +70,14 @@ pub(crate) struct LocalConnPool<C: ClientInnerExt> {
 impl<C: ClientInnerExt> LocalConnPool<C> {
     pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
         Arc::new(Self {
-            global_pool: RwLock::new(EndpointConnPool {
-                pools: HashMap::new(),
-                total_conns: 0,
-                max_conns: config.pool_options.max_conns_per_endpoint,
-                global_pool_size_max_conns: config.pool_options.max_total_conns,
-            }),
+            global_pool: Arc::new(RwLock::new(EndpointConnPool::new(
+                HashMap::new(),
+                0,
+                config.pool_options.max_conns_per_endpoint,
+                Arc::new(AtomicUsize::new(0)),
+                config.pool_options.max_total_conns,
+                String::from("local_pool"),
+            ))),
             config,
         })
     }
@@ -222,7 +90,7 @@ impl<C: ClientInnerExt> LocalConnPool<C> {
         self: &Arc<Self>,
         ctx: &RequestMonitoring,
         conn_info: &ConnInfo,
-    ) -> Result<Option<LocalClient<C>>, HttpConnError> {
+    ) -> Result<Option<Client<C>>, HttpConnError> {
         let client = self
             .global_pool
             .write()
@@ -230,12 +98,14 @@ impl<C: ClientInnerExt> LocalConnPool<C> {
             .map(|entry| entry.conn);

         // ok return cached connection if found and establish a new one otherwise
-        if let Some(client) = client {
-            if client.is_closed() {
+        if let Some(mut client) = client {
+            if client.inner.is_closed() {
                 info!("local_pool: cached connection '{conn_info}' is closed, opening a new one");
                 return Ok(None);
             }
-            tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
+
+            tracing::Span::current()
+                .record("conn_id", tracing::field::display(client.get_conn_id()));
+            tracing::Span::current().record(
+                "pid",
+                tracing::field::display(client.inner.get_process_id()),
@@ -244,47 +114,59 @@ impl<C: ClientInnerExt> LocalConnPool<C> {
                 cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
                 "local_pool: reusing connection '{conn_info}'"
             );
-            client.session.send(ctx.session_id())?;
+
+            match client.get_data() {
+                ClientDataEnum::Local(data) => {
+                    data.session().send(ctx.session_id())?;
+                }
+
+                ClientDataEnum::Remote(data) => {
+                    data.session().send(ctx.session_id())?;
+                }
+                ClientDataEnum::Http(_) => (),
+            }
+
             ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
             ctx.success();
-            return Ok(Some(LocalClient::new(
+
+            return Ok(Some(Client::new(
                 client,
                 conn_info.clone(),
-                Arc::downgrade(self),
+                Arc::downgrade(&self.global_pool),
             )));
         }
         Ok(None)
     }

     pub(crate) fn initialized(self: &Arc<Self>, conn_info: &ConnInfo) -> bool {
-        self.global_pool
-            .read()
-            .pools
-            .get(&conn_info.db_and_user())
-            .map_or(false, |pool| pool.initialized)
+        if let Some(pool) = self.global_pool.read().get_pool(conn_info.db_and_user()) {
+            return pool.is_initialized();
+        }
+        false
     }

     pub(crate) fn set_initialized(self: &Arc<Self>, conn_info: &ConnInfo) {
-        self.global_pool
+        if let Some(pool) = self
+            .global_pool
             .write()
-            .pools
-            .entry(conn_info.db_and_user())
-            .or_default()
-            .initialized = true;
+            .get_pool_mut(conn_info.db_and_user())
+        {
+            pool.set_initialized();
+        }
     }
 }

+#[allow(clippy::too_many_arguments)]
-pub(crate) fn poll_client(
-    global_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
+pub(crate) fn poll_client<C: ClientInnerExt>(
+    global_pool: Arc<LocalConnPool<C>>,
     ctx: &RequestMonitoring,
     conn_info: ConnInfo,
-    client: tokio_postgres::Client,
+    client: C,
     mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
     key: SigningKey,
     conn_id: uuid::Uuid,
     aux: MetricsAuxInfo,
-) -> LocalClient<tokio_postgres::Client> {
+) -> Client<C> {
     let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
     let mut session_id = ctx.session_id();
     let (tx, mut rx) = tokio::sync::watch::channel(session_id);
@@ -377,111 +259,47 @@ pub(crate) fn poll_client(
    }
    .instrument(span));

-    let inner = ClientInner {
+    let inner = ClientInnerCommon {
        inner: client,
-        session: tx,
-        cancel,
        aux,
        conn_id,
-        key,
-        jti: 0,
+        data: ClientDataEnum::Local(ClientDataLocal {
+            session: tx,
+            cancel,
+            key,
+            jti: 0,
+        }),
    };
-    LocalClient::new(inner, conn_info, pool_clone)
+
+    Client::new(
+        inner,
+        conn_info,
+        Arc::downgrade(&pool_clone.upgrade().unwrap().global_pool),
+    )
 }

-pub(crate) struct ClientInner<C: ClientInnerExt> {
-    inner: C,
-    session: tokio::sync::watch::Sender<uuid::Uuid>,
-    cancel: CancellationToken,
-    aux: MetricsAuxInfo,
-    conn_id: uuid::Uuid,
-
-    // needed for pg_session_jwt state
-    key: SigningKey,
-    jti: u64,
-}
-
-impl<C: ClientInnerExt> Drop for ClientInner<C> {
-    fn drop(&mut self) {
-        // on client drop, tell the conn to shut down
-        self.cancel.cancel();
-    }
-}
-
-impl<C: ClientInnerExt> ClientInner<C> {
-    pub(crate) fn is_closed(&self) -> bool {
-        self.inner.is_closed()
-    }
-}
-
-impl ClientInner<tokio_postgres::Client> {
+impl ClientInnerCommon<tokio_postgres::Client> {
     pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), HttpConnError> {
-        self.jti += 1;
-        let token = resign_jwt(&self.key, payload, self.jti)?;
+        if let ClientDataEnum::Local(local_data) = &mut self.data {
+            local_data.jti += 1;
+            let token = resign_jwt(&local_data.key, payload, local_data.jti)?;

-        // initiates the auth session
-        self.inner.simple_query("discard all").await?;
-        self.inner
-            .query(
-                "select auth.jwt_session_init($1)",
-                &[&token as &(dyn ToSql + Sync)],
-            )
-            .await?;
+            // initiates the auth session
+            self.inner.simple_query("discard all").await?;
+            self.inner
+                .query(
+                    "select auth.jwt_session_init($1)",
+                    &[&token as &(dyn ToSql + Sync)],
+                )
+                .await?;

-        let pid = self.inner.get_process_id();
-        info!(pid, jti = self.jti, "user session state init");
-
-        Ok(())
-    }
-}
-
-pub(crate) struct LocalClient<C: ClientInnerExt> {
-    span: Span,
-    inner: Option<ClientInner<C>>,
-    conn_info: ConnInfo,
-    pool: Weak<LocalConnPool<C>>,
-}
-
-pub(crate) struct Discard<'a, C: ClientInnerExt> {
-    conn_info: &'a ConnInfo,
-    pool: &'a mut Weak<LocalConnPool<C>>,
-}
-
-impl<C: ClientInnerExt> LocalClient<C> {
-    pub(self) fn new(
-        inner: ClientInner<C>,
-        conn_info: ConnInfo,
-        pool: Weak<LocalConnPool<C>>,
-    ) -> Self {
-        Self {
-            inner: Some(inner),
-            span: Span::current(),
-            conn_info,
-            pool,
+            let pid = self.inner.get_process_id();
+            info!(pid, jti = local_data.jti, "user session state init");
+            Ok(())
+        } else {
+            panic!("unexpected client data type");
         }
     }

-    pub(crate) fn client_inner(&mut self) -> (&mut ClientInner<C>, Discard<'_, C>) {
-        let Self {
-            inner,
-            pool,
-            conn_info,
-            span: _,
-        } = self;
-        let inner_m = inner.as_mut().expect("client inner should not be removed");
-        (inner_m, Discard { conn_info, pool })
-    }
-
-    pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
-        let Self {
-            inner,
-            pool,
-            conn_info,
-            span: _,
-        } = self;
-        let inner = inner.as_mut().expect("client inner should not be removed");
-        (&mut inner.inner, Discard { conn_info, pool })
-    }
-}

 /// implements relatively efficient in-place json object key upserting
@@ -547,58 +365,6 @@ fn sign_jwt(sk: &SigningKey, payload: &[u8]) -> String {
     jwt
 }

-impl<C: ClientInnerExt> LocalClient<C> {
-    pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
-        let aux = &self.inner.as_ref().unwrap().aux;
-        USAGE_METRICS.register(Ids {
-            endpoint_id: aux.endpoint_id,
-            branch_id: aux.branch_id,
-        })
-    }
-
-    fn do_drop(&mut self) -> Option<impl FnOnce() + use<C>> {
-        let conn_info = self.conn_info.clone();
-        let client = self
-            .inner
-            .take()
-            .expect("client inner should not be removed");
-        if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
-            let current_span = self.span.clone();
-            // return connection to the pool
-            return Some(move || {
-                let _span = current_span.enter();
-                EndpointConnPool::put(&conn_pool.global_pool, &conn_info, client);
-            });
-        }
-        None
-    }
-}
-
-impl<C: ClientInnerExt> Drop for LocalClient<C> {
-    fn drop(&mut self) {
-        if let Some(drop) = self.do_drop() {
-            tokio::task::spawn_blocking(drop);
-        }
-    }
-}
-
-impl<C: ClientInnerExt> Discard<'_, C> {
-    pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
-        let conn_info = &self.conn_info;
-        if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
-            info!(
-                "local_pool: throwing away connection '{conn_info}' because connection is not idle"
-            );
-        }
-    }
-    pub(crate) fn discard(&mut self) {
-        let conn_info = &self.conn_info;
-        if std::mem::take(self.pool).strong_count() > 0 {
-            info!("local_pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
-        }
-    }
-}

 #[cfg(test)]
 mod tests {
     use p256::ecdsa::SigningKey;
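The surviving set_jwt_session path above only works when the client carries Local data: it bumps the per-connection jti (JWT ID), re-signs the payload with the connection's key, and replays auth.jwt_session_init. A rough sketch of that counter-plus-resign flow, with the actual ES256 signing reduced to a stub (resign below is illustrative, not the crate's resign_jwt):

    struct LocalData {
        jti: u64,
        key: [u8; 32], // stand-in for the p256 SigningKey
    }

    // Stub: a real version would rebuild the JWT claims with the new jti
    // and sign them with the ES256 key.
    fn resign(key: &[u8], payload: &[u8], jti: u64) -> String {
        format!("jwt({} key bytes, {} payload bytes, jti={jti})", key.len(), payload.len())
    }

    fn set_jwt_session(data: &mut LocalData, payload: &[u8]) -> String {
        data.jti += 1; // monotonically increasing per connection
        resign(&data.key, payload, data.jti)
    }

    fn main() {
        let mut data = LocalData { jti: 0, key: [0u8; 32] };
        assert!(set_jwt_session(&mut data, b"{}").contains("jti=1"));
        assert!(set_jwt_session(&mut data, b"{}").contains("jti=2"));
    }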
@@ -31,7 +31,6 @@ use super::conn_pool_lib::{self, ConnInfo};
 use super::error::HttpCodeError;
 use super::http_util::json_response;
 use super::json::{json_to_pg_text, pg_text_row_to_json, JsonConversionError};
-use super::local_conn_pool;
 use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
 use crate::auth::{endpoint_sni, ComputeUserInfoParseError};
 use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig, TlsConfig};
@@ -1052,12 +1051,12 @@ async fn query_to_json<T: GenericClient>(

 enum Client {
     Remote(conn_pool_lib::Client<tokio_postgres::Client>),
-    Local(local_conn_pool::LocalClient<tokio_postgres::Client>),
+    Local(conn_pool_lib::Client<tokio_postgres::Client>),
 }

 enum Discard<'a> {
     Remote(conn_pool_lib::Discard<'a, tokio_postgres::Client>),
-    Local(local_conn_pool::Discard<'a, tokio_postgres::Client>),
+    Local(conn_pool_lib::Discard<'a, tokio_postgres::Client>),
 }

 impl Client {
@@ -1071,7 +1070,7 @@ impl Client {
     fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) {
         match self {
             Client::Remote(client) => {
-                let (c, d) = client.inner_mut();
+                let (c, d) = client.inner();
                 (c, Discard::Remote(d))
             }
             Client::Local(local_client) => {