Compare commits

...

4 Commits

Author SHA1 Message Date
Conrad Ludgate
4c573e13d4 fix tests 2024-05-01 12:00:40 +01:00
Conrad Ludgate
f40bf86575 separate connection from poll client 2024-05-01 11:45:36 +01:00
Conrad Ludgate
eb48df0bbf move connection future out of poll_fn 2024-05-01 11:28:16 +01:00
Conrad Ludgate
e05bfd6fd5 slight optimisation 2024-05-01 11:28:16 +01:00
2 changed files with 228 additions and 177 deletions

View File

@@ -16,7 +16,7 @@ use crate::{
proxy::connect_compute::ConnectMechanism,
};
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
use super::conn_pool::{poll_tokio_client, Client, ConnInfo, GlobalConnPool};
pub struct PoolingBackend {
pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
@@ -184,7 +184,7 @@ impl ConnectMechanism for TokioMechanism {
drop(pause);
tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
Ok(poll_client(
Ok(poll_tokio_client(
self.pool.clone(),
ctx,
self.conn_info.clone(),

View File

@@ -1,9 +1,11 @@
use dashmap::DashMap;
use futures::{future::poll_fn, Future};
use futures::Future;
use parking_lot::RwLock;
use pin_project_lite::pin_project;
use rand::Rng;
use smallvec::SmallVec;
use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
use std::sync::Weak;
use std::{collections::HashMap, sync::Arc, time::Duration};
use std::{
fmt,
task::{ready, Poll},
@@ -12,13 +14,13 @@ use std::{
ops::Deref,
sync::atomic::{self, AtomicUsize},
};
use tokio::time::Instant;
use tokio::time::{Instant, Sleep};
use tokio_postgres::tls::NoTlsStream;
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
use tokio_util::sync::CancellationToken;
use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned};
use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics, NumDbConnectionsGuard};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{
auth::backend::ComputeUserInfo, context::RequestMonitoring, DbName, EndpointCacheKey, RoleName,
@@ -91,7 +93,7 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
..
} = self;
pools.get_mut(&db_user).and_then(|pool_entries| {
pool_entries.get_conn_entry(total_conns, global_connections_count.clone())
pool_entries.get_conn_entry(total_conns, global_connections_count)
})
}
@@ -125,19 +127,16 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
let conn_id = client.conn_id;
if client.is_closed() {
info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
return;
}
let global_max_conn = pool.read().global_pool_size_max_conns;
if pool
.read()
.global_connections_count
.load(atomic::Ordering::Relaxed)
>= global_max_conn
{
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
return;
let pool = pool.read();
if pool
.global_connections_count
.load(atomic::Ordering::Relaxed)
>= pool.global_pool_size_max_conns
{
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
return;
}
}
// return connection to the pool
@@ -217,7 +216,7 @@ impl<C: ClientInnerExt> DbUserConnPool<C> {
fn get_conn_entry(
&mut self,
conns: &mut usize,
global_connections_count: Arc<AtomicUsize>,
global_connections_count: &AtomicUsize,
) -> Option<ConnPoolEntry<C>> {
let mut removed = self.clear_closed_clients(conns);
let conn = self.conns.pop();
@@ -463,109 +462,97 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
}
}
pub fn poll_client<C: ClientInnerExt>(
pub fn poll_tokio_client(
global_pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
ctx: &mut RequestMonitoring,
conn_info: ConnInfo,
client: tokio_postgres::Client,
mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> Client<tokio_postgres::Client> {
let connection = std::future::poll_fn(move |cx| {
loop {
let message = ready!(connection.poll_message(cx));
match message {
Some(Ok(AsyncMessage::Notice(notice))) => {
info!("notice: {}", notice);
}
Some(Ok(AsyncMessage::Notification(notif))) => {
warn!(
pid = notif.process_id(),
channel = notif.channel(),
"notification received"
);
}
Some(Ok(_)) => {
warn!("unknown message");
}
Some(Err(e)) => {
error!("connection error: {}", e);
break;
}
None => {
info!("connection closed");
break;
}
}
}
Poll::Ready(())
});
poll_client(
global_pool,
ctx,
conn_info,
client,
connection,
conn_id,
aux,
)
}
pub fn poll_client<C: ClientInnerExt, I: Future<Output = ()> + Send + 'static>(
global_pool: Arc<GlobalConnPool<C>>,
ctx: &mut RequestMonitoring,
conn_info: ConnInfo,
client: C,
mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
connection: I,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> Client<C> {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol);
let mut session_id = ctx.session_id;
let (tx, mut rx) = tokio::sync::watch::channel(session_id);
let session_id = ctx.session_id;
let (tx, rx) = tokio::sync::watch::channel(session_id);
let span = info_span!(parent: None, "connection", %conn_id);
let cold_start_info = ctx.cold_start_info;
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
let session_span = info_span!(parent: span.clone(), "", %session_id);
session_span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, "new connection");
});
let pool = match conn_info.endpoint_cache_key() {
Some(endpoint) => Arc::downgrade(&global_pool.get_or_create_endpoint_pool(&endpoint)),
None => Weak::new(),
};
let pool_clone = pool.clone();
let db_user = conn_info.db_and_user();
let idle = global_pool.get_idle_timeout();
let cancel = CancellationToken::new();
let cancelled = cancel.clone().cancelled_owned();
tokio::spawn(
async move {
let _conn_gauge = conn_gauge;
let mut idle_timeout = pin!(tokio::time::sleep(idle));
let mut cancelled = pin!(cancelled);
let db_conn = DbConnection {
cancelled: cancel.clone().cancelled_owned(),
idle_timeout: tokio::time::sleep(idle),
idle,
db_user: conn_info.db_and_user(),
pool: pool.clone(),
session_span,
session_rx: rx,
conn_gauge,
conn_id,
connection,
};
poll_fn(move |cx| {
if cancelled.as_mut().poll(cx).is_ready() {
info!("connection dropped");
return Poll::Ready(())
}
tokio::spawn(db_conn.instrument(span));
match rx.has_changed() {
Ok(true) => {
session_id = *rx.borrow_and_update();
info!(%session_id, "changed session");
idle_timeout.as_mut().reset(Instant::now() + idle);
}
Err(_) => {
info!("connection dropped");
return Poll::Ready(())
}
_ => {}
}
// 5 minute idle connection timeout
if idle_timeout.as_mut().poll(cx).is_ready() {
idle_timeout.as_mut().reset(Instant::now() + idle);
info!("connection idle");
if let Some(pool) = pool.clone().upgrade() {
// remove client from pool - should close the connection if it's idle.
// does nothing if the client is currently checked-out and in-use
if pool.write().remove_client(db_user.clone(), conn_id) {
info!("idle connection removed");
}
}
}
loop {
let message = ready!(connection.poll_message(cx));
match message {
Some(Ok(AsyncMessage::Notice(notice))) => {
info!(%session_id, "notice: {}", notice);
}
Some(Ok(AsyncMessage::Notification(notif))) => {
warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
}
Some(Ok(_)) => {
warn!(%session_id, "unknown message");
}
Some(Err(e)) => {
error!(%session_id, "connection error: {}", e);
break
}
None => {
info!("connection closed");
break
}
}
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.write().remove_client(db_user.clone(), conn_id) {
info!("closed connection removed");
}
}
Poll::Ready(())
}).await;
}
.instrument(span));
let inner = ClientInner {
inner: client,
session: tx,
@@ -573,7 +560,94 @@ pub fn poll_client<C: ClientInnerExt>(
aux,
conn_id,
};
Client::new(inner, conn_info, pool_clone)
Client::new(inner, conn_info, pool)
}
pin_project! {
struct DbConnection<C: ClientInnerExt, Inner> {
#[pin]
cancelled: WaitForCancellationFutureOwned,
#[pin]
idle_timeout: Sleep,
idle: tokio::time::Duration,
db_user: (DbName, RoleName),
pool: Weak<RwLock<EndpointConnPool<C>>>,
session_span: tracing::Span,
session_rx: tokio::sync::watch::Receiver<uuid::Uuid>,
conn_gauge: NumDbConnectionsGuard<'static>,
conn_id: uuid::Uuid,
#[pin]
connection: Inner,
}
}
impl<C: ClientInnerExt, I: Future<Output = ()>> Future for DbConnection<C, I> {
type Output = ();
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
if this.cancelled.as_mut().poll(cx).is_ready() {
let _span = this.session_span.enter();
info!("connection dropped");
return Poll::Ready(());
}
match this.session_rx.has_changed() {
Ok(true) => {
let session_id = *this.session_rx.borrow_and_update();
*this.session_span = info_span!("", %session_id);
let _span = this.session_span.enter();
info!("changed session");
this.idle_timeout
.as_mut()
.reset(Instant::now() + *this.idle);
}
Err(_) => {
let _span = this.session_span.enter();
info!("connection dropped");
return Poll::Ready(());
}
_ => {}
}
let _span = this.session_span.enter();
// 5 minute idle connection timeout
if this.idle_timeout.as_mut().poll(cx).is_ready() {
this.idle_timeout
.as_mut()
.reset(Instant::now() + *this.idle);
info!("connection idle");
if let Some(pool) = this.pool.upgrade() {
// remove client from pool - should close the connection if it's idle.
// does nothing if the client is currently checked-out and in-use
if pool
.write()
.remove_client(this.db_user.clone(), *this.conn_id)
{
info!("idle connection removed");
}
}
}
ready!(this.connection.poll(cx));
// remove from connection pool
if let Some(pool) = this.pool.upgrade() {
if pool
.write()
.remove_client(this.db_user.clone(), *this.conn_id)
{
info!("closed connection removed");
}
}
Poll::Ready(())
}
}
struct ClientInner<C: ClientInnerExt> {
@@ -686,72 +760,70 @@ impl<C: ClientInnerExt> Deref for Client<C> {
}
impl<C: ClientInnerExt> Client<C> {
fn do_drop(&mut self) -> Option<impl FnOnce()> {
fn do_drop(&mut self) {
let conn_info = self.conn_info.clone();
let client = self
.inner
.take()
.expect("client inner should not be removed");
if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
let current_span = self.span.clone();
// return connection to the pool
return Some(move || {
let _span = current_span.enter();
EndpointConnPool::put(&conn_pool, &conn_info, client);
});
let conn_id = client.conn_id;
if client.is_closed() {
info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
return;
}
if let Some(conn_pool) = self.pool.upgrade() {
// return connection to the pool
let _span = self.span.enter();
EndpointConnPool::put(&conn_pool, &conn_info, client);
}
None
}
}
impl<C: ClientInnerExt> Drop for Client<C> {
fn drop(&mut self) {
if let Some(drop) = self.do_drop() {
tokio::task::spawn_blocking(drop);
}
self.do_drop();
}
}
#[cfg(test)]
mod tests {
use std::{mem, sync::atomic::AtomicBool};
use crate::{BranchId, EndpointId, ProjectId};
use super::*;
struct MockClient(Arc<AtomicBool>);
impl MockClient {
fn new(is_closed: bool) -> Self {
MockClient(Arc::new(is_closed.into()))
}
}
struct MockClient(CancellationToken);
impl ClientInnerExt for MockClient {
fn is_closed(&self) -> bool {
self.0.load(atomic::Ordering::Relaxed)
self.0.is_cancelled()
}
fn get_process_id(&self) -> i32 {
0
}
}
fn create_inner() -> ClientInner<MockClient> {
create_inner_with(MockClient::new(false))
}
fn create_inner_with(client: MockClient) -> ClientInner<MockClient> {
ClientInner {
inner: client,
session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
cancel: CancellationToken::new(),
aux: MetricsAuxInfo {
fn create_inner(
global_pool: Arc<GlobalConnPool<MockClient>>,
conn_info: ConnInfo,
) -> (Client<MockClient>, CancellationToken) {
let cancelled = CancellationToken::new();
let client = poll_client(
global_pool,
&mut RequestMonitoring::test(),
conn_info,
MockClient(cancelled.clone()),
cancelled.clone().cancelled_owned(),
uuid::Uuid::new_v4(),
MetricsAuxInfo {
endpoint_id: (&EndpointId::from("endpoint")).into(),
project_id: (&ProjectId::from("project")).into(),
branch_id: (&BranchId::from("branch")).into(),
cold_start_info: crate::console::messages::ColdStartInfo::Warm,
},
conn_id: uuid::Uuid::new_v4(),
}
);
(client, cancelled)
}
#[tokio::test]
@@ -778,51 +850,36 @@ mod tests {
dbname: "dbname".into(),
password: "password".as_bytes().into(),
};
let ep_pool = Arc::downgrade(
&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
);
{
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
let (mut client, _) = create_inner(pool.clone(), conn_info.clone());
assert_eq!(0, pool.get_global_connections_count());
client.inner().1.discard();
drop(client);
// Discard should not add the connection from the pool.
assert_eq!(0, pool.get_global_connections_count());
}
{
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
client.do_drop().unwrap()();
mem::forget(client); // drop the client
let (client, _) = create_inner(pool.clone(), conn_info.clone());
drop(client);
assert_eq!(1, pool.get_global_connections_count());
}
{
let mut closed_client = Client::new(
create_inner_with(MockClient::new(true)),
conn_info.clone(),
ep_pool.clone(),
);
closed_client.do_drop().unwrap()();
mem::forget(closed_client); // drop the client
// The closed client shouldn't be added to the pool.
let (client, cancel) = create_inner(pool.clone(), conn_info.clone());
cancel.cancel();
drop(client);
// The closed client shouldn't be added to the pool.
assert_eq!(1, pool.get_global_connections_count());
}
let is_closed: Arc<AtomicBool> = Arc::new(false.into());
{
let mut client = Client::new(
create_inner_with(MockClient(is_closed.clone())),
conn_info.clone(),
ep_pool.clone(),
);
client.do_drop().unwrap()();
mem::forget(client); // drop the client
let cancel = {
let (client, cancel) = create_inner(pool.clone(), conn_info.clone());
drop(client);
// The client should be added to the pool.
assert_eq!(2, pool.get_global_connections_count());
}
cancel
};
{
let mut client = Client::new(create_inner(), conn_info, ep_pool);
client.do_drop().unwrap()();
mem::forget(client); // drop the client
let client = create_inner(pool.clone(), conn_info.clone());
drop(client);
// The client shouldn't be added to the pool. Because the ep-pool is full.
assert_eq!(2, pool.get_global_connections_count());
}
@@ -836,25 +893,19 @@ mod tests {
dbname: "dbname".into(),
password: "password".as_bytes().into(),
};
let ep_pool = Arc::downgrade(
&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
);
{
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
client.do_drop().unwrap()();
mem::forget(client); // drop the client
let client = create_inner(pool.clone(), conn_info.clone());
drop(client);
assert_eq!(3, pool.get_global_connections_count());
}
{
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
client.do_drop().unwrap()();
mem::forget(client); // drop the client
let client = create_inner(pool.clone(), conn_info.clone());
drop(client);
// The client shouldn't be added to the pool. Because the global pool is full.
assert_eq!(3, pool.get_global_connections_count());
}
is_closed.store(true, atomic::Ordering::Relaxed);
cancel.cancel();
// Do gc for all shards.
pool.gc(0);
pool.gc(1);