diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 188dadb5d2..0383cf83c7 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -1,9 +1,10 @@ use dashmap::DashMap; -use futures::{future::poll_fn, Future}; +use futures::Future; use parking_lot::RwLock; +use pin_project_lite::pin_project; use rand::Rng; use smallvec::SmallVec; -use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration}; +use std::{collections::HashMap, sync::Arc, sync::Weak, time::Duration}; use std::{ fmt, task::{ready, Poll}, @@ -12,13 +13,13 @@ use std::{ ops::Deref, sync::atomic::{self, AtomicUsize}, }; -use tokio::time::Instant; +use tokio::time::{Instant, Sleep}; use tokio_postgres::tls::NoTlsStream; use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket}; -use tokio_util::sync::CancellationToken; +use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned}; use crate::console::messages::{ColdStartInfo, MetricsAuxInfo}; -use crate::metrics::{HttpEndpointPoolsGuard, Metrics}; +use crate::metrics::{HttpEndpointPoolsGuard, Metrics, NumDbConnectionsGuard}; use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; use crate::{ auth::backend::ComputeUserInfo, context::RequestMonitoring, DbName, EndpointCacheKey, RoleName, @@ -470,8 +471,8 @@ pub fn poll_client( aux: MetricsAuxInfo, ) -> Client { let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol); - let mut session_id = ctx.session_id; - let (tx, mut rx) = tokio::sync::watch::channel(session_id); + let session_id = ctx.session_id; + let (tx, rx) = tokio::sync::watch::channel(session_id); let span = info_span!(parent: None, "connection", %conn_id); let cold_start_info = ctx.cold_start_info; @@ -482,87 +483,27 @@ pub fn poll_client( Some(endpoint) => Arc::downgrade(&global_pool.get_or_create_endpoint_pool(&endpoint)), None => Weak::new(), }; - let pool_clone = pool.clone(); - let db_user = conn_info.db_and_user(); let idle = global_pool.get_idle_timeout(); let cancel = CancellationToken::new(); - let cancelled = cancel.clone().cancelled_owned(); - tokio::spawn( - async move { - let _conn_gauge = conn_gauge; - let mut idle_timeout = pin!(tokio::time::sleep(idle)); - let mut cancelled = pin!(cancelled); + let db_conn = DbConnection { + cancelled: cancel.clone().cancelled_owned(), + idle_timeout: tokio::time::sleep(idle), + idle, + db_user: conn_info.db_and_user(), + pool: pool.clone(), + session_id, + session_rx: rx, + conn_gauge, + conn_id, + connection, + }; - poll_fn(move |cx| { - if cancelled.as_mut().poll(cx).is_ready() { - info!("connection dropped"); - return Poll::Ready(()) - } + tokio::spawn(async move { + db_conn.instrument(span).await; + }); - match rx.has_changed() { - Ok(true) => { - session_id = *rx.borrow_and_update(); - info!(%session_id, "changed session"); - idle_timeout.as_mut().reset(Instant::now() + idle); - } - Err(_) => { - info!("connection dropped"); - return Poll::Ready(()) - } - _ => {} - } - - // 5 minute idle connection timeout - if idle_timeout.as_mut().poll(cx).is_ready() { - idle_timeout.as_mut().reset(Instant::now() + idle); - info!("connection idle"); - if let Some(pool) = pool.clone().upgrade() { - // remove client from pool - should close the connection if it's idle. - // does nothing if the client is currently checked-out and in-use - if pool.write().remove_client(db_user.clone(), conn_id) { - info!("idle connection removed"); - } - } - } - - loop { - let message = ready!(connection.poll_message(cx)); - - match message { - Some(Ok(AsyncMessage::Notice(notice))) => { - info!(%session_id, "notice: {}", notice); - } - Some(Ok(AsyncMessage::Notification(notif))) => { - warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received"); - } - Some(Ok(_)) => { - warn!(%session_id, "unknown message"); - } - Some(Err(e)) => { - error!(%session_id, "connection error: {}", e); - break - } - None => { - info!("connection closed"); - break - } - } - } - - // remove from connection pool - if let Some(pool) = pool.clone().upgrade() { - if pool.write().remove_client(db_user.clone(), conn_id) { - info!("closed connection removed"); - } - } - - Poll::Ready(()) - }).await; - - } - .instrument(span)); let inner = ClientInner { inner: client, session: tx, @@ -570,7 +511,109 @@ pub fn poll_client( aux, conn_id, }; - Client::new(inner, conn_info, pool_clone) + Client::new(inner, conn_info, pool) +} + +pin_project! { + struct DbConnection { + #[pin] + cancelled: WaitForCancellationFutureOwned, + + #[pin] + idle_timeout: Sleep, + idle: tokio::time::Duration, + + db_user: (DbName, RoleName), + pool: Weak>>, + + session_id: uuid::Uuid, + session_rx: tokio::sync::watch::Receiver, + + conn_gauge: NumDbConnectionsGuard<'static>, + conn_id: uuid::Uuid, + connection: tokio_postgres::Connection, + } +} + +impl Future for DbConnection { + type Output = (); + + fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + let mut this = self.project(); + if this.cancelled.as_mut().poll(cx).is_ready() { + info!("connection dropped"); + return Poll::Ready(()); + } + + match this.session_rx.has_changed() { + Ok(true) => { + *this.session_id = *this.session_rx.borrow_and_update(); + info!(%this.session_id, "changed session"); + this.idle_timeout + .as_mut() + .reset(Instant::now() + *this.idle); + } + Err(_) => { + info!("connection dropped"); + return Poll::Ready(()); + } + _ => {} + } + + // 5 minute idle connection timeout + if this.idle_timeout.as_mut().poll(cx).is_ready() { + this.idle_timeout + .as_mut() + .reset(Instant::now() + *this.idle); + info!("connection idle"); + if let Some(pool) = this.pool.clone().upgrade() { + // remove client from pool - should close the connection if it's idle. + // does nothing if the client is currently checked-out and in-use + if pool + .write() + .remove_client(this.db_user.clone(), *this.conn_id) + { + info!("idle connection removed"); + } + } + } + + loop { + let message = ready!(this.connection.poll_message(cx)); + + match message { + Some(Ok(AsyncMessage::Notice(notice))) => { + info!(session_id = %this.session_id, "notice: {}", notice); + } + Some(Ok(AsyncMessage::Notification(notif))) => { + warn!(session_id = %this.session_id, pid = notif.process_id(), channel = notif.channel(), "notification received"); + } + Some(Ok(_)) => { + warn!(session_id = %this.session_id, "unknown message"); + } + Some(Err(e)) => { + error!(session_id = %this.session_id, "connection error: {}", e); + break; + } + None => { + info!("connection closed"); + break; + } + } + } + + // remove from connection pool + if let Some(pool) = this.pool.upgrade() { + if pool + .write() + .remove_client(this.db_user.clone(), *this.conn_id) + { + info!("closed connection removed"); + } + } + + Poll::Ready(()) + } } struct ClientInner {