Merge branch 'main' into local-proxy-lazy-ext-install

Conrad Ludgate
2024-10-17 12:38:51 +01:00
committed by GitHub
9 changed files with 751 additions and 665 deletions

View File

@@ -11,8 +11,9 @@ use tokio::net::{lookup_host, TcpStream};
use tracing::field::display;
use tracing::{debug, info};
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
use super::http_conn_pool::{self, poll_http2_client};
use super::conn_pool::poll_client;
use super::conn_pool_lib::{Client, ConnInfo, GlobalConnPool};
use super::http_conn_pool::{self, poll_http2_client, Send};
use super::local_conn_pool::{self, LocalClient, LocalConnPool, EXT_NAME, EXT_SCHEMA, EXT_VERSION};
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
@@ -34,7 +35,7 @@ use crate::rate_limiter::EndpointRateLimiter;
use crate::{compute, EndpointId, Host};
pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool>,
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool<Send>>,
pub(crate) local_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
@@ -203,7 +204,7 @@ impl PoolingBackend {
&self,
ctx: &RequestMonitoring,
conn_info: ConnInfo,
) -> Result<http_conn_pool::Client, HttpConnError> {
) -> Result<http_conn_pool::Client<Send>, HttpConnError> {
info!("pool: looking for an existing connection");
if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) {
return Ok(client);
@@ -524,7 +525,7 @@ impl ConnectMechanism for TokioMechanism {
}
struct HyperMechanism {
pool: Arc<http_conn_pool::GlobalConnPool>,
pool: Arc<http_conn_pool::GlobalConnPool<Send>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
@@ -534,7 +535,7 @@ struct HyperMechanism {
#[async_trait]
impl ConnectMechanism for HyperMechanism {
type Connection = http_conn_pool::Client;
type Connection = http_conn_pool::Client<Send>;
type ConnectError = HttpConnError;
type Error = HttpConnError;

View File

@@ -1,31 +1,29 @@
use std::collections::HashMap;
use std::fmt;
use std::ops::Deref;
use std::pin::pin;
use std::sync::atomic::{self, AtomicUsize};
use std::sync::{Arc, Weak};
use std::task::{ready, Poll};
use std::time::Duration;
use dashmap::DashMap;
use futures::future::poll_fn;
use futures::Future;
use parking_lot::RwLock;
use rand::Rng;
use smallvec::SmallVec;
use tokio::time::Instant;
use tokio_postgres::tls::NoTlsStream;
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
use tokio_postgres::{AsyncMessage, Socket};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, info_span, warn, Instrument, Span};
use tracing::{error, info, info_span, warn, Instrument};
use super::backend::HttpConnError;
use crate::auth::backend::ComputeUserInfo;
use crate::context::RequestMonitoring;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{DbName, EndpointCacheKey, RoleName};
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::Metrics;
use super::conn_pool_lib::{Client, ClientInnerExt, ConnInfo, GlobalConnPool};
#[cfg(test)]
use {
super::conn_pool_lib::GlobalConnPoolOptions,
crate::auth::backend::ComputeUserInfo,
std::{sync::atomic, time::Duration},
};
#[derive(Debug, Clone)]
pub(crate) struct ConnInfoWithAuth {
@@ -33,34 +31,12 @@ pub(crate) struct ConnInfoWithAuth {
pub(crate) auth: AuthData,
}
#[derive(Debug, Clone)]
pub(crate) struct ConnInfo {
pub(crate) user_info: ComputeUserInfo,
pub(crate) dbname: DbName,
}
#[derive(Debug, Clone)]
pub(crate) enum AuthData {
Password(SmallVec<[u8; 16]>),
Jwt(String),
}
impl ConnInfo {
// hm, change to hasher to avoid cloning?
pub(crate) fn db_and_user(&self) -> (DbName, RoleName) {
(self.dbname.clone(), self.user_info.user.clone())
}
pub(crate) fn endpoint_cache_key(&self) -> Option<EndpointCacheKey> {
// We don't want to cache http connections for ephemeral endpoints.
if self.user_info.options.is_ephemeral() {
None
} else {
Some(self.user_info.endpoint_cache_key())
}
}
}
impl fmt::Display for ConnInfo {
// use custom display to avoid logging password
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
@@ -75,402 +51,6 @@ impl fmt::Display for ConnInfo {
}
}
struct ConnPoolEntry<C: ClientInnerExt> {
conn: ClientInner<C>,
_last_access: std::time::Instant,
}
// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
total_conns: usize,
max_conns: usize,
_guard: HttpEndpointPoolsGuard<'static>,
global_connections_count: Arc<AtomicUsize>,
global_pool_size_max_conns: usize,
}
impl<C: ClientInnerExt> EndpointConnPool<C> {
fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
let Self {
pools,
total_conns,
global_connections_count,
..
} = self;
pools.get_mut(&db_user).and_then(|pool_entries| {
pool_entries.get_conn_entry(total_conns, global_connections_count.clone())
})
}
fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
let Self {
pools,
total_conns,
global_connections_count,
..
} = self;
if let Some(pool) = pools.get_mut(&db_user) {
let old_len = pool.conns.len();
pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
let new_len = pool.conns.len();
let removed = old_len - new_len;
if removed > 0 {
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(removed as i64);
}
*total_conns -= removed;
removed > 0
} else {
false
}
}
fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
let conn_id = client.conn_id;
if client.is_closed() {
info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
return;
}
let global_max_conn = pool.read().global_pool_size_max_conns;
if pool
.read()
.global_connections_count
.load(atomic::Ordering::Relaxed)
>= global_max_conn
{
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
return;
}
// return connection to the pool
let mut returned = false;
let mut per_db_size = 0;
let total_conns = {
let mut pool = pool.write();
if pool.total_conns < pool.max_conns {
let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
pool_entries.conns.push(ConnPoolEntry {
conn: client,
_last_access: std::time::Instant::now(),
});
returned = true;
per_db_size = pool_entries.conns.len();
pool.total_conns += 1;
pool.global_connections_count
.fetch_add(1, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.inc();
}
pool.total_conns
};
// do logging outside of the mutex
if returned {
info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
} else {
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
}
}
}
impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
fn drop(&mut self) {
if self.total_conns > 0 {
self.global_connections_count
.fetch_sub(self.total_conns, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(self.total_conns as i64);
}
}
}
pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
conns: Vec<ConnPoolEntry<C>>,
}
impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
fn default() -> Self {
Self { conns: Vec::new() }
}
}
impl<C: ClientInnerExt> DbUserConnPool<C> {
fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
let old_len = self.conns.len();
self.conns.retain(|conn| !conn.conn.is_closed());
let new_len = self.conns.len();
let removed = old_len - new_len;
*conns -= removed;
removed
}
fn get_conn_entry(
&mut self,
conns: &mut usize,
global_connections_count: Arc<AtomicUsize>,
) -> Option<ConnPoolEntry<C>> {
let mut removed = self.clear_closed_clients(conns);
let conn = self.conns.pop();
if conn.is_some() {
*conns -= 1;
removed += 1;
}
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(removed as i64);
conn
}
}
pub(crate) struct GlobalConnPool<C: ClientInnerExt> {
// endpoint -> per-endpoint connection pool
//
// That should be a fairly contended map, so return a reference to the per-endpoint
// pool as early as possible and release the lock.
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,
/// Number of endpoint-connection pools
///
/// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
/// That seems like far too much effort, so we're using a relaxed increment counter instead.
/// It's only used for diagnostics.
global_pool_size: AtomicUsize,
/// Total number of connections in the pool
global_connections_count: Arc<AtomicUsize>,
config: &'static crate::config::HttpConfig,
}
#[derive(Debug, Clone, Copy)]
pub struct GlobalConnPoolOptions {
// Maximum number of connections per one endpoint.
// Can mix different (dbname, username) connections.
// When running out of free slots for a particular endpoint,
// falls back to opening a new connection for each request.
pub max_conns_per_endpoint: usize,
pub gc_epoch: Duration,
pub pool_shards: usize,
pub idle_timeout: Duration,
pub opt_in: bool,
// Total number of connections in the pool.
pub max_total_conns: usize,
}
impl<C: ClientInnerExt> GlobalConnPool<C> {
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
let shards = config.pool_options.pool_shards;
Arc::new(Self {
global_pool: DashMap::with_shard_amount(shards),
global_pool_size: AtomicUsize::new(0),
config,
global_connections_count: Arc::new(AtomicUsize::new(0)),
})
}
#[cfg(test)]
pub(crate) fn get_global_connections_count(&self) -> usize {
self.global_connections_count
.load(atomic::Ordering::Relaxed)
}
pub(crate) fn get_idle_timeout(&self) -> Duration {
self.config.pool_options.idle_timeout
}
pub(crate) fn shutdown(&self) {
// drops all strong references to endpoint-pools
self.global_pool.clear();
}
pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
let epoch = self.config.pool_options.gc_epoch;
let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
loop {
interval.tick().await;
let shard = rng.gen_range(0..self.global_pool.shards().len());
self.gc(shard);
}
}
fn gc(&self, shard: usize) {
debug!(shard, "pool: performing epoch reclamation");
// acquire a random shard lock
let mut shard = self.global_pool.shards()[shard].write();
let timer = Metrics::get()
.proxy
.http_pool_reclaimation_lag_seconds
.start_timer();
let current_len = shard.len();
let mut clients_removed = 0;
shard.retain(|endpoint, x| {
// if the current endpoint pool is unique (no other strong or weak references)
// then it is currently not in use by any connections.
if let Some(pool) = Arc::get_mut(x.get_mut()) {
let EndpointConnPool {
pools, total_conns, ..
} = pool.get_mut();
// ensure that closed clients are removed
for db_pool in pools.values_mut() {
clients_removed += db_pool.clear_closed_clients(total_conns);
}
// we only remove this pool if it has no active connections
if *total_conns == 0 {
info!("pool: discarding pool for endpoint {endpoint}");
return false;
}
}
true
});
let new_len = shard.len();
drop(shard);
timer.observe();
// Do logging outside of the lock.
if clients_removed > 0 {
let size = self
.global_connections_count
.fetch_sub(clients_removed, atomic::Ordering::Relaxed)
- clients_removed;
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(clients_removed as i64);
info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
}
let removed = current_len - new_len;
if removed > 0 {
let global_pool_size = self
.global_pool_size
.fetch_sub(removed, atomic::Ordering::Relaxed)
- removed;
info!("pool: performed global pool gc. size now {global_pool_size}");
}
}
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
) -> Result<Option<Client<C>>, HttpConnError> {
let mut client: Option<ClientInner<C>> = None;
let Some(endpoint) = conn_info.endpoint_cache_key() else {
return Ok(None);
};
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
if let Some(entry) = endpoint_pool
.write()
.get_conn_entry(conn_info.db_and_user())
{
client = Some(entry.conn);
}
let endpoint_pool = Arc::downgrade(&endpoint_pool);
// return the cached connection if found, and establish a new one otherwise
if let Some(client) = client {
if client.is_closed() {
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
return Ok(None);
}
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
tracing::Span::current().record(
"pid",
tracing::field::display(client.inner.get_process_id()),
);
info!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
);
client.session.send(ctx.session_id())?;
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.success();
return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
}
Ok(None)
}
fn get_or_create_endpoint_pool(
self: &Arc<Self>,
endpoint: &EndpointCacheKey,
) -> Arc<RwLock<EndpointConnPool<C>>> {
// fast path
if let Some(pool) = self.global_pool.get(endpoint) {
return pool.clone();
}
// slow path
let new_pool = Arc::new(RwLock::new(EndpointConnPool {
pools: HashMap::new(),
total_conns: 0,
max_conns: self.config.pool_options.max_conns_per_endpoint,
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
global_connections_count: self.global_connections_count.clone(),
global_pool_size_max_conns: self.config.pool_options.max_total_conns,
}));
// find or create a pool for this endpoint
let mut created = false;
let pool = self
.global_pool
.entry(endpoint.clone())
.or_insert_with(|| {
created = true;
new_pool
})
.clone();
// log new global pool size
if created {
let global_pool_size = self
.global_pool_size
.fetch_add(1, atomic::Ordering::Relaxed)
+ 1;
info!(
"pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
);
}
pool
}
}
pub(crate) fn poll_client<C: ClientInnerExt>(
global_pool: Arc<GlobalConnPool<C>>,
ctx: &RequestMonitoring,
@@ -574,7 +154,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
}
.instrument(span));
let inner = ClientInner {
let inner = ClientInnerRemote {
inner: client,
session: tx,
cancel,
@@ -584,7 +164,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
Client::new(inner, conn_info, pool_clone)
}
struct ClientInner<C: ClientInnerExt> {
pub(crate) struct ClientInnerRemote<C: ClientInnerExt> {
inner: C,
session: tokio::sync::watch::Sender<uuid::Uuid>,
cancel: CancellationToken,
@@ -592,131 +172,36 @@ struct ClientInner<C: ClientInnerExt> {
conn_id: uuid::Uuid,
}
impl<C: ClientInnerExt> Drop for ClientInner<C> {
fn drop(&mut self) {
// on client drop, tell the conn to shut down
self.cancel.cancel();
impl<C: ClientInnerExt> ClientInnerRemote<C> {
pub(crate) fn inner_mut(&mut self) -> &mut C {
&mut self.inner
}
}
pub(crate) trait ClientInnerExt: Sync + Send + 'static {
fn is_closed(&self) -> bool;
fn get_process_id(&self) -> i32;
}
impl ClientInnerExt for tokio_postgres::Client {
fn is_closed(&self) -> bool {
self.is_closed()
pub(crate) fn inner(&self) -> &C {
&self.inner
}
pub(crate) fn session(&mut self) -> &mut tokio::sync::watch::Sender<uuid::Uuid> {
&mut self.session
}
pub(crate) fn aux(&self) -> &MetricsAuxInfo {
&self.aux
}
pub(crate) fn get_conn_id(&self) -> uuid::Uuid {
self.conn_id
}
fn get_process_id(&self) -> i32 {
self.get_process_id()
}
}
impl<C: ClientInnerExt> ClientInner<C> {
pub(crate) fn is_closed(&self) -> bool {
self.inner.is_closed()
}
}
impl<C: ClientInnerExt> Client<C> {
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
let aux = &self.inner.as_ref().unwrap().aux;
USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id,
branch_id: aux.branch_id,
})
}
}
pub(crate) struct Client<C: ClientInnerExt> {
span: Span,
inner: Option<ClientInner<C>>,
conn_info: ConnInfo,
pool: Weak<RwLock<EndpointConnPool<C>>>,
}
pub(crate) struct Discard<'a, C: ClientInnerExt> {
conn_info: &'a ConnInfo,
pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
}
impl<C: ClientInnerExt> Client<C> {
pub(self) fn new(
inner: ClientInner<C>,
conn_info: ConnInfo,
pool: Weak<RwLock<EndpointConnPool<C>>>,
) -> Self {
Self {
inner: Some(inner),
span: Span::current(),
conn_info,
pool,
}
}
pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
let Self {
inner,
pool,
conn_info,
span: _,
} = self;
let inner = inner.as_mut().expect("client inner should not be removed");
(&mut inner.inner, Discard { conn_info, pool })
}
}
impl<C: ClientInnerExt> Discard<'_, C> {
pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
let conn_info = &self.conn_info;
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is not idle");
}
}
pub(crate) fn discard(&mut self) {
let conn_info = &self.conn_info;
if std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
}
}
}
impl<C: ClientInnerExt> Deref for Client<C> {
type Target = C;
fn deref(&self) -> &Self::Target {
&self
.inner
.as_ref()
.expect("client inner should not be removed")
.inner
}
}
impl<C: ClientInnerExt> Client<C> {
fn do_drop(&mut self) -> Option<impl FnOnce()> {
let conn_info = self.conn_info.clone();
let client = self
.inner
.take()
.expect("client inner should not be removed");
if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
let current_span = self.span.clone();
// return connection to the pool
return Some(move || {
let _span = current_span.enter();
EndpointConnPool::put(&conn_pool, &conn_info, client);
});
}
None
}
}
impl<C: ClientInnerExt> Drop for Client<C> {
impl<C: ClientInnerExt> Drop for ClientInnerRemote<C> {
fn drop(&mut self) {
if let Some(drop) = self.do_drop() {
tokio::task::spawn_blocking(drop);
}
// on client drop, tell the conn to shut down
self.cancel.cancel();
}
}
@@ -745,12 +230,12 @@ mod tests {
}
}
fn create_inner() -> ClientInner<MockClient> {
fn create_inner() -> ClientInnerRemote<MockClient> {
create_inner_with(MockClient::new(false))
}
fn create_inner_with(client: MockClient) -> ClientInner<MockClient> {
ClientInner {
fn create_inner_with(client: MockClient) -> ClientInnerRemote<MockClient> {
ClientInnerRemote {
inner: client,
session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
cancel: CancellationToken::new(),
@@ -797,7 +282,7 @@ mod tests {
{
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
assert_eq!(0, pool.get_global_connections_count());
client.inner().1.discard();
client.inner_mut().1.discard();
// Discard should not return the connection to the pool.
assert_eq!(0, pool.get_global_connections_count());
}

View File

@@ -0,0 +1,562 @@
use dashmap::DashMap;
use parking_lot::RwLock;
use rand::Rng;
use std::{collections::HashMap, sync::Arc, sync::Weak, time::Duration};
use std::{
ops::Deref,
sync::atomic::{self, AtomicUsize},
};
use tokio_postgres::ReadyForQueryStatus;
use crate::control_plane::messages::ColdStartInfo;
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{
auth::backend::ComputeUserInfo, context::RequestMonitoring, DbName, EndpointCacheKey, RoleName,
};
use super::conn_pool::ClientInnerRemote;
use tracing::info;
use tracing::{debug, Span};
use super::backend::HttpConnError;
#[derive(Debug, Clone)]
pub(crate) struct ConnInfo {
pub(crate) user_info: ComputeUserInfo,
pub(crate) dbname: DbName,
}
impl ConnInfo {
// hm, change to hasher to avoid cloning?
pub(crate) fn db_and_user(&self) -> (DbName, RoleName) {
(self.dbname.clone(), self.user_info.user.clone())
}
pub(crate) fn endpoint_cache_key(&self) -> Option<EndpointCacheKey> {
// We don't want to cache http connections for ephemeral endpoints.
if self.user_info.options.is_ephemeral() {
None
} else {
Some(self.user_info.endpoint_cache_key())
}
}
}
pub(crate) struct ConnPoolEntry<C: ClientInnerExt> {
pub(crate) conn: ClientInnerRemote<C>,
pub(crate) _last_access: std::time::Instant,
}
// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
total_conns: usize,
max_conns: usize,
_guard: HttpEndpointPoolsGuard<'static>,
global_connections_count: Arc<AtomicUsize>,
global_pool_size_max_conns: usize,
}
impl<C: ClientInnerExt> EndpointConnPool<C> {
fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
let Self {
pools,
total_conns,
global_connections_count,
..
} = self;
pools.get_mut(&db_user).and_then(|pool_entries| {
let (entry, removed) = pool_entries.get_conn_entry(total_conns);
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
entry
})
}
pub(crate) fn remove_client(
&mut self,
db_user: (DbName, RoleName),
conn_id: uuid::Uuid,
) -> bool {
let Self {
pools,
total_conns,
global_connections_count,
..
} = self;
if let Some(pool) = pools.get_mut(&db_user) {
let old_len = pool.conns.len();
pool.conns.retain(|conn| conn.conn.get_conn_id() != conn_id);
let new_len = pool.conns.len();
let removed = old_len - new_len;
if removed > 0 {
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(removed as i64);
}
*total_conns -= removed;
removed > 0
} else {
false
}
}
pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInnerRemote<C>) {
let conn_id = client.get_conn_id();
if client.is_closed() {
info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
return;
}
let global_max_conn = pool.read().global_pool_size_max_conns;
if pool
.read()
.global_connections_count
.load(atomic::Ordering::Relaxed)
>= global_max_conn
{
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
return;
}
// return connection to the pool
let mut returned = false;
let mut per_db_size = 0;
let total_conns = {
let mut pool = pool.write();
if pool.total_conns < pool.max_conns {
let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
pool_entries.conns.push(ConnPoolEntry {
conn: client,
_last_access: std::time::Instant::now(),
});
returned = true;
per_db_size = pool_entries.conns.len();
pool.total_conns += 1;
pool.global_connections_count
.fetch_add(1, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.inc();
}
pool.total_conns
};
// do logging outside of the mutex
if returned {
info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
} else {
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
}
}
}
impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
fn drop(&mut self) {
if self.total_conns > 0 {
self.global_connections_count
.fetch_sub(self.total_conns, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(self.total_conns as i64);
}
}
}
pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
pub(crate) conns: Vec<ConnPoolEntry<C>>,
}
impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
fn default() -> Self {
Self { conns: Vec::new() }
}
}
impl<C: ClientInnerExt> DbUserConnPool<C> {
fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
let old_len = self.conns.len();
self.conns.retain(|conn| !conn.conn.is_closed());
let new_len = self.conns.len();
let removed = old_len - new_len;
*conns -= removed;
removed
}
pub(crate) fn get_conn_entry(
&mut self,
conns: &mut usize,
) -> (Option<ConnPoolEntry<C>>, usize) {
let mut removed = self.clear_closed_clients(conns);
let conn = self.conns.pop();
if conn.is_some() {
*conns -= 1;
removed += 1;
}
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(removed as i64);
(conn, removed)
}
}
pub(crate) struct GlobalConnPool<C: ClientInnerExt> {
// endpoint -> per-endpoint connection pool
//
// That should be a fairly contended map, so return a reference to the per-endpoint
// pool as early as possible and release the lock.
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,
/// Number of endpoint-connection pools
///
/// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
/// That seems like far too much effort, so we're using a relaxed increment counter instead.
/// It's only used for diagnostics.
global_pool_size: AtomicUsize,
/// Total number of connections in the pool
global_connections_count: Arc<AtomicUsize>,
config: &'static crate::config::HttpConfig,
}
#[derive(Debug, Clone, Copy)]
pub struct GlobalConnPoolOptions {
// Maximum number of connections per one endpoint.
// Can mix different (dbname, username) connections.
// When running out of free slots for a particular endpoint,
// falls back to opening a new connection for each request.
pub max_conns_per_endpoint: usize,
pub gc_epoch: Duration,
pub pool_shards: usize,
pub idle_timeout: Duration,
pub opt_in: bool,
// Total number of connections in the pool.
pub max_total_conns: usize,
}
impl<C: ClientInnerExt> GlobalConnPool<C> {
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
let shards = config.pool_options.pool_shards;
Arc::new(Self {
global_pool: DashMap::with_shard_amount(shards),
global_pool_size: AtomicUsize::new(0),
config,
global_connections_count: Arc::new(AtomicUsize::new(0)),
})
}
#[cfg(test)]
pub(crate) fn get_global_connections_count(&self) -> usize {
self.global_connections_count
.load(atomic::Ordering::Relaxed)
}
pub(crate) fn get_idle_timeout(&self) -> Duration {
self.config.pool_options.idle_timeout
}
pub(crate) fn shutdown(&self) {
// drops all strong references to endpoint-pools
self.global_pool.clear();
}
pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
let epoch = self.config.pool_options.gc_epoch;
let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
loop {
interval.tick().await;
let shard = rng.gen_range(0..self.global_pool.shards().len());
self.gc(shard);
}
}
pub(crate) fn gc(&self, shard: usize) {
debug!(shard, "pool: performing epoch reclamation");
// acquire a random shard lock
let mut shard = self.global_pool.shards()[shard].write();
let timer = Metrics::get()
.proxy
.http_pool_reclaimation_lag_seconds
.start_timer();
let current_len = shard.len();
let mut clients_removed = 0;
shard.retain(|endpoint, x| {
// if the current endpoint pool is unique (no other strong or weak references)
// then it is currently not in use by any connections.
if let Some(pool) = Arc::get_mut(x.get_mut()) {
let EndpointConnPool {
pools, total_conns, ..
} = pool.get_mut();
// ensure that closed clients are removed
for db_pool in pools.values_mut() {
clients_removed += db_pool.clear_closed_clients(total_conns);
}
// we only remove this pool if it has no active connections
if *total_conns == 0 {
info!("pool: discarding pool for endpoint {endpoint}");
return false;
}
}
true
});
let new_len = shard.len();
drop(shard);
timer.observe();
// Do logging outside of the lock.
if clients_removed > 0 {
let size = self
.global_connections_count
.fetch_sub(clients_removed, atomic::Ordering::Relaxed)
- clients_removed;
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(clients_removed as i64);
info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
}
let removed = current_len - new_len;
if removed > 0 {
let global_pool_size = self
.global_pool_size
.fetch_sub(removed, atomic::Ordering::Relaxed)
- removed;
info!("pool: performed global pool gc. size now {global_pool_size}");
}
}
pub(crate) fn get_or_create_endpoint_pool(
self: &Arc<Self>,
endpoint: &EndpointCacheKey,
) -> Arc<RwLock<EndpointConnPool<C>>> {
// fast path
if let Some(pool) = self.global_pool.get(endpoint) {
return pool.clone();
}
// slow path
let new_pool = Arc::new(RwLock::new(EndpointConnPool {
pools: HashMap::new(),
total_conns: 0,
max_conns: self.config.pool_options.max_conns_per_endpoint,
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
global_connections_count: self.global_connections_count.clone(),
global_pool_size_max_conns: self.config.pool_options.max_total_conns,
}));
// find or create a pool for this endpoint
let mut created = false;
let pool = self
.global_pool
.entry(endpoint.clone())
.or_insert_with(|| {
created = true;
new_pool
})
.clone();
// log new global pool size
if created {
let global_pool_size = self
.global_pool_size
.fetch_add(1, atomic::Ordering::Relaxed)
+ 1;
info!(
"pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
);
}
pool
}
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
) -> Result<Option<Client<C>>, HttpConnError> {
let mut client: Option<ClientInnerRemote<C>> = None;
let Some(endpoint) = conn_info.endpoint_cache_key() else {
return Ok(None);
};
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
if let Some(entry) = endpoint_pool
.write()
.get_conn_entry(conn_info.db_and_user())
{
client = Some(entry.conn);
}
let endpoint_pool = Arc::downgrade(&endpoint_pool);
// return the cached connection if found, and establish a new one otherwise
if let Some(mut client) = client {
if client.is_closed() {
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
return Ok(None);
}
tracing::Span::current()
.record("conn_id", tracing::field::display(client.get_conn_id()));
tracing::Span::current().record(
"pid",
tracing::field::display(client.inner().get_process_id()),
);
info!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
);
client.session().send(ctx.session_id())?;
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.success();
return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
}
Ok(None)
}
}
impl<C: ClientInnerExt> Client<C> {
pub(crate) fn new(
inner: ClientInnerRemote<C>,
conn_info: ConnInfo,
pool: Weak<RwLock<EndpointConnPool<C>>>,
) -> Self {
Self {
inner: Some(inner),
span: Span::current(),
conn_info,
pool,
}
}
pub(crate) fn inner_mut(&mut self) -> (&mut C, Discard<'_, C>) {
let Self {
inner,
pool,
conn_info,
span: _,
} = self;
let inner = inner.as_mut().expect("client inner should not be removed");
let inner_ref = inner.inner_mut();
(inner_ref, Discard { conn_info, pool })
}
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
let aux = &self.inner.as_ref().unwrap().aux();
USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id,
branch_id: aux.branch_id,
})
}
pub(crate) fn do_drop(&mut self) -> Option<impl FnOnce()> {
let conn_info = self.conn_info.clone();
let client = self
.inner
.take()
.expect("client inner should not be removed");
if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
let current_span = self.span.clone();
// return connection to the pool
return Some(move || {
let _span = current_span.enter();
EndpointConnPool::put(&conn_pool, &conn_info, client);
});
}
None
}
}
pub(crate) struct Client<C: ClientInnerExt> {
span: Span,
inner: Option<ClientInnerRemote<C>>,
conn_info: ConnInfo,
pool: Weak<RwLock<EndpointConnPool<C>>>,
}
impl<C: ClientInnerExt> Drop for Client<C> {
fn drop(&mut self) {
if let Some(drop) = self.do_drop() {
tokio::task::spawn_blocking(drop);
}
}
}
impl<C: ClientInnerExt> Deref for Client<C> {
type Target = C;
fn deref(&self) -> &Self::Target {
self.inner
.as_ref()
.expect("client inner should not be removed")
.inner()
}
}
pub(crate) trait ClientInnerExt: Sync + Send + 'static {
fn is_closed(&self) -> bool;
fn get_process_id(&self) -> i32;
}
impl ClientInnerExt for tokio_postgres::Client {
fn is_closed(&self) -> bool {
self.is_closed()
}
fn get_process_id(&self) -> i32 {
self.get_process_id()
}
}
pub(crate) struct Discard<'a, C: ClientInnerExt> {
conn_info: &'a ConnInfo,
pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
}
impl<C: ClientInnerExt> Discard<'_, C> {
pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
let conn_info = &self.conn_info;
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is not idle");
}
}
pub(crate) fn discard(&mut self) {
let conn_info = &self.conn_info;
if std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
}
}
}

View File

@@ -10,11 +10,12 @@ use rand::Rng;
use tokio::net::TcpStream;
use tracing::{debug, error, info, info_span, Instrument};
use super::conn_pool::ConnInfo;
use crate::context::RequestMonitoring;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use super::conn_pool_lib::{ClientInnerExt, ConnInfo};
use crate::EndpointCacheKey;
pub(crate) type Send = http2::SendRequest<hyper::body::Incoming>;
@@ -22,15 +23,15 @@ pub(crate) type Connect =
http2::Connection<TokioIo<TcpStream>, hyper::body::Incoming, TokioExecutor>;
#[derive(Clone)]
struct ConnPoolEntry {
conn: Send,
pub(crate) struct ConnPoolEntry<C: ClientInnerExt + Clone> {
conn: C,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
}
// Per-endpoint connection pool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub(crate) struct EndpointConnPool {
pub(crate) struct EndpointConnPool<C: ClientInnerExt + Clone> {
// TODO(conrad):
// either we should open more connections depending on stream count
// (not exposed by hyper, need our own counter)
@@ -40,13 +41,13 @@ pub(crate) struct EndpointConnPool {
// seems somewhat redundant though.
//
// Probably we should run a semaphore and just the single conn. TBD.
conns: VecDeque<ConnPoolEntry>,
conns: VecDeque<ConnPoolEntry<C>>,
_guard: HttpEndpointPoolsGuard<'static>,
global_connections_count: Arc<AtomicUsize>,
}
impl EndpointConnPool {
fn get_conn_entry(&mut self) -> Option<ConnPoolEntry> {
impl<C: ClientInnerExt + Clone> EndpointConnPool<C> {
fn get_conn_entry(&mut self) -> Option<ConnPoolEntry<C>> {
let Self { conns, .. } = self;
loop {
@@ -81,7 +82,7 @@ impl EndpointConnPool {
}
}
impl Drop for EndpointConnPool {
impl<C: ClientInnerExt + Clone> Drop for EndpointConnPool<C> {
fn drop(&mut self) {
if !self.conns.is_empty() {
self.global_connections_count
@@ -95,12 +96,12 @@ impl Drop for EndpointConnPool {
}
}
pub(crate) struct GlobalConnPool {
pub(crate) struct GlobalConnPool<C: ClientInnerExt + Clone> {
// endpoint -> per-endpoint connection pool
//
// That should be a fairly contended map, so return a reference to the per-endpoint
// pool as early as possible and release the lock.
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool>>>,
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,
/// Number of endpoint-connection pools
///
@@ -115,7 +116,7 @@ pub(crate) struct GlobalConnPool {
config: &'static crate::config::HttpConfig,
}
impl GlobalConnPool {
impl<C: ClientInnerExt + Clone> GlobalConnPool<C> {
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
let shards = config.pool_options.pool_shards;
Arc::new(Self {
@@ -210,7 +211,7 @@ impl GlobalConnPool {
self: &Arc<Self>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
) -> Option<Client> {
) -> Option<Client<C>> {
let endpoint = conn_info.endpoint_cache_key()?;
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
let client = endpoint_pool.write().get_conn_entry()?;
@@ -228,7 +229,7 @@ impl GlobalConnPool {
fn get_or_create_endpoint_pool(
self: &Arc<Self>,
endpoint: &EndpointCacheKey,
) -> Arc<RwLock<EndpointConnPool>> {
) -> Arc<RwLock<EndpointConnPool<C>>> {
// fast path
if let Some(pool) = self.global_pool.get(endpoint) {
return pool.clone();
@@ -268,14 +269,14 @@ impl GlobalConnPool {
}
pub(crate) fn poll_http2_client(
global_pool: Arc<GlobalConnPool>,
global_pool: Arc<GlobalConnPool<Send>>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
client: Send,
connection: Connect,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> Client {
) -> Client<Send> {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let session_id = ctx.session_id();
@@ -322,13 +323,13 @@ pub(crate) fn poll_http2_client(
Client::new(client, aux)
}
pub(crate) struct Client {
pub(crate) inner: Send,
pub(crate) struct Client<C: ClientInnerExt + Clone> {
pub(crate) inner: C,
aux: MetricsAuxInfo,
}
impl Client {
pub(self) fn new(inner: Send, aux: MetricsAuxInfo) -> Self {
impl<C: ClientInnerExt + Clone> Client<C> {
pub(self) fn new(inner: C, aux: MetricsAuxInfo) -> Self {
Self { inner, aux }
}
@@ -339,3 +340,14 @@ impl Client {
})
}
}
impl ClientInnerExt for Send {
fn is_closed(&self) -> bool {
self.is_closed()
}
fn get_process_id(&self) -> i32 {
// ideally throw something meaningful
-1
}
}

View File

@@ -31,11 +31,12 @@ use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, warn, Instrument, Span};
use super::backend::HttpConnError;
use super::conn_pool::{ClientInnerExt, ConnInfo};
use super::conn_pool_lib::{ClientInnerExt, ConnInfo};
use crate::context::RequestMonitoring;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::Metrics;
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{DbName, RoleName};
pub(crate) const EXT_NAME: &str = "pg_session_jwt";
@@ -389,7 +390,7 @@ pub(crate) fn poll_client(
LocalClient::new(inner, conn_info, pool_clone)
}
struct ClientInner<C: ClientInnerExt> {
pub(crate) struct ClientInner<C: ClientInnerExt> {
inner: C,
session: tokio::sync::watch::Sender<uuid::Uuid>,
cancel: CancellationToken,
@@ -414,13 +415,24 @@ impl<C: ClientInnerExt> ClientInner<C> {
}
}
impl<C: ClientInnerExt> LocalClient<C> {
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
let aux = &self.inner.as_ref().unwrap().aux;
USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id,
branch_id: aux.branch_id,
})
impl ClientInner<tokio_postgres::Client> {
pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), HttpConnError> {
self.jti += 1;
let token = resign_jwt(&self.key, payload, self.jti)?;
// initiates the auth session
self.inner.simple_query("discard all").await?;
self.inner
.query(
"select auth.jwt_session_init($1)",
&[&token as &(dyn ToSql + Sync)],
)
.await?;
let pid = self.inner.get_process_id();
info!(pid, jti = self.jti, "user session state init");
Ok(())
}
}
@@ -449,6 +461,18 @@ impl<C: ClientInnerExt> LocalClient<C> {
pool,
}
}
pub(crate) fn client_inner(&mut self) -> (&mut ClientInner<C>, Discard<'_, C>) {
let Self {
inner,
pool,
conn_info,
span: _,
} = self;
let inner_m = inner.as_mut().expect("client inner should not be removed");
(inner_m, Discard { conn_info, pool })
}
pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
let Self {
inner,
@@ -461,33 +485,6 @@ impl<C: ClientInnerExt> LocalClient<C> {
}
}
impl LocalClient<tokio_postgres::Client> {
pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), HttpConnError> {
let inner = self
.inner
.as_mut()
.expect("client inner should not be removed");
inner.jti += 1;
let token = resign_jwt(&inner.key, payload, inner.jti)?;
// initiates the auth session
inner.inner.simple_query("discard all").await?;
inner
.inner
.query(
"select auth.jwt_session_init($1)",
&[&token as &(dyn ToSql + Sync)],
)
.await?;
let pid = inner.inner.get_process_id();
info!(pid, jti = inner.jti, "user session state init");
Ok(())
}
}
/// implements relatively efficient in-place json object key upserting
///
/// only supports top-level keys
@@ -551,24 +548,15 @@ fn sign_jwt(sk: &SigningKey, payload: &[u8]) -> String {
jwt
}
impl<C: ClientInnerExt> Discard<'_, C> {
pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
let conn_info = &self.conn_info;
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
info!(
"local_pool: throwing away connection '{conn_info}' because connection is not idle"
);
}
}
pub(crate) fn discard(&mut self) {
let conn_info = &self.conn_info;
if std::mem::take(self.pool).strong_count() > 0 {
info!("local_pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
}
}
}
impl<C: ClientInnerExt> LocalClient<C> {
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
let aux = &self.inner.as_ref().unwrap().aux;
USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id,
branch_id: aux.branch_id,
})
}
fn do_drop(&mut self) -> Option<impl FnOnce()> {
let conn_info = self.conn_info.clone();
let client = self
@@ -595,6 +583,23 @@ impl<C: ClientInnerExt> Drop for LocalClient<C> {
}
}
impl<C: ClientInnerExt> Discard<'_, C> {
pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
let conn_info = &self.conn_info;
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
info!(
"local_pool: throwing away connection '{conn_info}' because connection is not idle"
);
}
}
pub(crate) fn discard(&mut self) {
let conn_info = &self.conn_info;
if std::mem::take(self.pool).strong_count() > 0 {
info!("local_pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
}
}
}
#[cfg(test)]
mod tests {
use p256::ecdsa::SigningKey;

View File

@@ -5,6 +5,7 @@
mod backend;
pub mod cancel_set;
mod conn_pool;
mod conn_pool_lib;
mod http_conn_pool;
mod http_util;
mod json;
@@ -20,7 +21,7 @@ use anyhow::Context;
use async_trait::async_trait;
use atomic_take::AtomicTake;
use bytes::Bytes;
pub use conn_pool::GlobalConnPoolOptions;
pub use conn_pool_lib::GlobalConnPoolOptions;
use futures::future::{select, Either};
use futures::TryFutureExt;
use http::{Method, Response, StatusCode};
@@ -65,7 +66,7 @@ pub async fn task_main(
}
let local_pool = local_conn_pool::LocalConnPool::new(&config.http_config);
let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config);
let conn_pool = conn_pool_lib::GlobalConnPool::new(&config.http_config);
{
let conn_pool = Arc::clone(&conn_pool);
tokio::spawn(async move {

View File

@@ -25,10 +25,11 @@ use urlencoding;
use utils::http::error::ApiError;
use super::backend::{LocalProxyConnError, PoolingBackend};
use super::conn_pool::{AuthData, ConnInfo, ConnInfoWithAuth};
use super::conn_pool::{AuthData, ConnInfoWithAuth};
use super::conn_pool_lib::{self, ConnInfo};
use super::http_util::json_response;
use super::json::{json_to_pg_text, pg_text_row_to_json, JsonConversionError};
use super::{conn_pool, local_conn_pool};
use super::local_conn_pool;
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::{endpoint_sni, ComputeUserInfoParseError};
use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig, TlsConfig};
@@ -37,6 +38,7 @@ use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::{HttpDirection, Metrics};
use crate::proxy::{run_until_cancelled, NeonOptions};
use crate::serverless::backend::HttpConnError;
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
use crate::{DbName, RoleName};
@@ -607,7 +609,8 @@ async fn handle_db_inner(
let client = match keys.keys {
ComputeCredentialKeys::JwtPayload(payload) if is_local_proxy => {
let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?;
client.set_jwt_session(&payload).await?;
let (cli_inner, _dsc) = client.client_inner();
cli_inner.set_jwt_session(&payload).await?;
Client::Local(client)
}
_ => {
@@ -1021,12 +1024,12 @@ async fn query_to_json<T: GenericClient>(
}
enum Client {
Remote(conn_pool::Client<tokio_postgres::Client>),
Remote(conn_pool_lib::Client<tokio_postgres::Client>),
Local(local_conn_pool::LocalClient<tokio_postgres::Client>),
}
enum Discard<'a> {
Remote(conn_pool::Discard<'a, tokio_postgres::Client>),
Remote(conn_pool_lib::Discard<'a, tokio_postgres::Client>),
Local(local_conn_pool::Discard<'a, tokio_postgres::Client>),
}
@@ -1041,7 +1044,7 @@ impl Client {
fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) {
match self {
Client::Remote(client) => {
let (c, d) = client.inner();
let (c, d) = client.inner_mut();
(c, Discard::Remote(d))
}
Client::Local(local_client) => {

View File

@@ -138,7 +138,7 @@ pub struct ProjectData {
pub name: String,
pub region_id: String,
pub platform_id: String,
pub user_id: String,
pub user_id: Option<String>,
pub pageserver_id: Option<u64>,
#[serde(deserialize_with = "from_nullable_id")]
pub tenant: TenantId,

View File

@@ -16,13 +16,13 @@ use remote_storage::{GenericRemoteStorage, ListingMode, ListingObject, RemotePat
use serde::{Deserialize, Serialize};
use tokio_stream::StreamExt;
use tokio_util::sync::CancellationToken;
use utils::id::TenantId;
use utils::{backoff, id::TenantId};
use crate::{
cloud_admin_api::{CloudAdminApiClient, MaybeDeleted, ProjectData},
init_remote, list_objects_with_retries,
metadata_stream::{stream_tenant_timelines, stream_tenants},
BucketConfig, ConsoleConfig, NodeKind, TenantShardTimelineId, TraversingDepth,
BucketConfig, ConsoleConfig, NodeKind, TenantShardTimelineId, TraversingDepth, MAX_RETRIES,
};
#[derive(Serialize, Deserialize, Debug)]
@@ -250,13 +250,16 @@ async fn find_garbage_inner(
&target.tenant_root(&tenant_shard_id),
)
.await?;
let object = tenant_objects.keys.first().unwrap();
if object.key.get_path().as_str().ends_with("heatmap-v1.json") {
tracing::info!("Tenant {tenant_shard_id}: is missing in console and is only a heatmap (known historic deletion bug)");
garbage.append_buggy(GarbageEntity::Tenant(tenant_shard_id));
continue;
if let Some(object) = tenant_objects.keys.first() {
if object.key.get_path().as_str().ends_with("heatmap-v1.json") {
tracing::info!("Tenant {tenant_shard_id}: is missing in console and is only a heatmap (known historic deletion bug)");
garbage.append_buggy(GarbageEntity::Tenant(tenant_shard_id));
continue;
} else {
tracing::info!("Tenant {tenant_shard_id} is missing in console and contains one object: {}", object.key);
}
} else {
tracing::info!("Tenant {tenant_shard_id} is missing in console and contains one object: {}", object.key);
tracing::info!("Tenant {tenant_shard_id} is missing in console appears to have been deleted while we ran");
}
} else {
// A console-unknown tenant with timelines: check if these timelines only contain initdb.tar.zst, from the initial
@@ -406,14 +409,17 @@ pub async fn get_tenant_objects(
// TODO: apply extra validation based on object modification time. Don't purge
// tenants where any timeline's index_part.json has been touched recently.
let list = s3_client
.list(
Some(&tenant_root),
ListingMode::NoDelimiter,
None,
&CancellationToken::new(),
)
.await?;
let cancel = CancellationToken::new();
let list = backoff::retry(
|| s3_client.list(Some(&tenant_root), ListingMode::NoDelimiter, None, &cancel),
|_| false,
3,
MAX_RETRIES as u32,
"get_tenant_objects",
&cancel,
)
.await
.expect("dummy cancellation token")?;
Ok(list.keys)
}
@@ -424,14 +430,25 @@ pub async fn get_timeline_objects(
tracing::debug!("Listing objects in timeline {ttid}");
let timeline_root = super::remote_timeline_path_id(&ttid);
let list = s3_client
.list(
Some(&timeline_root),
ListingMode::NoDelimiter,
None,
&CancellationToken::new(),
)
.await?;
let cancel = CancellationToken::new();
let list = backoff::retry(
|| {
s3_client.list(
Some(&timeline_root),
ListingMode::NoDelimiter,
None,
&cancel,
)
},
|_| false,
3,
MAX_RETRIES as u32,
"get_timeline_objects",
&cancel,
)
.await
.expect("dummy cancellation token")?;
Ok(list.keys)
}