diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index d09554a922..f0d6794782 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -1,7 +1,7 @@ use anyhow::Context; use async_trait::async_trait; use dashmap::DashMap; -use futures::future::poll_fn; +use futures::{future::poll_fn, TryStreamExt}; use parking_lot::RwLock; use pbkdf2::{ password_hash::{PasswordHashString, PasswordHasher, PasswordVerifier, SaltString}, @@ -210,12 +210,7 @@ impl GlobalConnPool { client.session.send(session_id)?; latency_timer.pool_hit(); latency_timer.success(); - return Ok(Client { - conn_id: client.conn_id, - inner: Some(client), - span: Span::current(), - pool, - }); + return Ok(Client::new(client, pool).await); } } else { let conn_id = uuid::Uuid::new_v4(); @@ -263,15 +258,11 @@ impl GlobalConnPool { _ => {} } - new_client.map(|inner| Client { - conn_id: inner.conn_id, - inner: Some(inner), - span: Span::current(), - pool, - }) + // new_client.map(|inner| Client::new(inner, pool).await) + Ok(Client::new(new_client?, pool).await) } - fn put(&self, conn_info: &ConnInfo, client: ClientInner) -> anyhow::Result<()> { + fn put(&self, conn_info: &ConnInfo, client: ClientInner, pid: i32) -> anyhow::Result<()> { let conn_id = client.conn_id; // We want to hold this open while we return. This ensures that the pool can't close @@ -315,9 +306,9 @@ impl GlobalConnPool { // do logging outside of the mutex if returned { - info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}"); + info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}, pid={pid}"); } else { - info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}"); + info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}, pid={pid}"); } Ok(()) @@ -528,6 +519,20 @@ struct ClientInner { conn_id: uuid::Uuid, } +impl ClientInner { + pub async fn get_pid(&mut self) -> anyhow::Result { + let rows = self.inner.query("select pg_backend_pid();", &[]).await?; + if rows.len() != 1 { + Err(anyhow::anyhow!( + "expected 1 row from pg_backend_pid(), got {}", + rows.len() + )) + } else { + Ok(rows[0].get(0)) + } + } +} + impl Client { pub fn metrics(&self) -> Arc { USAGE_METRICS.register(self.inner.as_ref().unwrap().ids.clone()) @@ -539,6 +544,7 @@ pub struct Client { span: Span, inner: Option, pool: Option<(ConnInfo, Arc)>, + pid: i32, } pub struct Discard<'a> { @@ -547,12 +553,22 @@ pub struct Discard<'a> { } impl Client { + pub async fn new(mut inner: ClientInner, pool: Option<(ConnInfo, Arc)>) -> Self { + Self { + conn_id: inner.conn_id, + pid: inner.get_pid().await.unwrap_or(-1), + inner: Some(inner), + span: Span::current(), + pool, + } + } pub fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) { let Self { inner, pool, conn_id, span: _, + pid: -1, } = self; ( &mut inner @@ -612,7 +628,7 @@ impl Drop for Client { // return connection to the pool tokio::task::spawn_blocking(move || { let _span = current_span.enter(); - let _ = conn_pool.put(&conn_info, client); + let _ = conn_pool.put(&conn_info, client, self.pid); }); } }