diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index cb13283791..7f7ddeec50 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -16,9 +16,9 @@ use tracing::field::display; use tracing::{debug, info}; use super::AsyncRW; -use super::conn_pool::poll_client; +use super::conn_pool::poll_client_generic; use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool}; -use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client}; +use super::http_conn_pool::{self, HttpConnPool, Send}; use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool}; use crate::auth::backend::local::StaticAuthRules; use crate::auth::backend::{ComputeCredentials, ComputeUserInfo}; @@ -577,7 +577,7 @@ impl ConnectMechanism for TokioMechanism { info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id); } - Ok(poll_client( + Ok(poll_client_generic( self.pool.clone(), ctx, self.conn_info.clone(), @@ -638,10 +638,10 @@ impl ConnectMechanism for HyperMechanism { info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id); } - Ok(poll_http2_client( + Ok(poll_client_generic( self.pool.clone(), ctx, - &self.conn_info, + self.conn_info.clone(), client, connection, self.conn_id, diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index b5f32986c2..da552bfadf 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -11,7 +11,7 @@ use smallvec::SmallVec; use tokio::net::TcpStream; use tokio::time::Instant; use tokio_util::sync::CancellationToken; -use tracing::{Instrument, error, info, info_span, warn}; +use tracing::{error, info, info_span, warn}; #[cfg(test)] use { super::conn_pool_lib::GlobalConnPoolOptions, @@ -20,8 +20,7 @@ use { }; use super::conn_pool_lib::{ - Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, EndpointConnPool, - GlobalConnPool, + ClientDataEnum, ClientInnerCommon, ConnInfo, EndpointConnPoolExt, GlobalConnPool, }; use crate::context::RequestContext; use crate::control_plane::messages::MetricsAuxInfo; @@ -29,6 +28,7 @@ use crate::metrics::Metrics; use crate::tls::postgres_rustls::MakeRustlsConnect; type TlsStream = >::Stream; +pub(super) type Conn = postgres_client::Connection; #[derive(Debug, Clone)] pub(crate) struct ConnInfoWithAuth { @@ -56,20 +56,20 @@ impl fmt::Display for ConnInfo { } } -pub(crate) fn poll_client( - global_pool: Arc>>, +pub(crate) fn poll_client_generic( + global_pool: Arc>, ctx: &RequestContext, conn_info: ConnInfo, - client: C, - mut connection: postgres_client::Connection, + client: P::ClientInner, + connection: P::Connection, conn_id: uuid::Uuid, aux: MetricsAuxInfo, -) -> Client { +) -> P::Client { let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol()); - let mut session_id = ctx.session_id(); + let session_id = ctx.session_id(); let (tx, mut rx) = tokio::sync::watch::channel(session_id); - let span = info_span!(parent: None, "connection", %conn_id); + let span = info_span!(parent: None, "connection", %conn_id, %session_id); let cold_start_info = ctx.cold_start_info(); span.in_scope(|| { info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection"); @@ -85,27 +85,30 @@ pub(crate) fn poll_client( let cancel = CancellationToken::new(); let cancelled = cancel.clone().cancelled_owned(); - tokio::spawn( - async move { + tokio::spawn(async move { let _conn_gauge = conn_gauge; let mut idle_timeout = pin!(tokio::time::sleep(idle)); let mut cancelled = pin!(cancelled); + let mut connection = pin!(P::spawn_conn(connection)); poll_fn(move |cx| { + let _enter = span.enter(); + if cancelled.as_mut().poll(cx).is_ready() { info!("connection dropped"); - return Poll::Ready(()) + return Poll::Ready(()); } match rx.has_changed() { Ok(true) => { - session_id = *rx.borrow_and_update(); - info!(%session_id, "changed session"); + let session_id = *rx.borrow_and_update(); + span.record("session_id", tracing::field::display(session_id)); + info!("changed session"); idle_timeout.as_mut().reset(Instant::now() + idle); } Err(_) => { info!("connection dropped"); - return Poll::Ready(()) + return Poll::Ready(()); } _ => {} } @@ -117,48 +120,25 @@ pub(crate) fn poll_client( if let Some(pool) = pool.clone().upgrade() { // remove client from pool - should close the connection if it's idle. // does nothing if the client is currently checked-out and in-use - if pool.write().remove_client(db_user.clone(), conn_id) { + if pool.write().remove_conn(db_user.clone(), conn_id) { info!("idle connection removed"); } } } - loop { - let message = ready!(connection.poll_message(cx)); - - match message { - Some(Ok(AsyncMessage::Notice(notice))) => { - info!(%session_id, "notice: {}", notice); - } - Some(Ok(AsyncMessage::Notification(notif))) => { - warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received"); - } - Some(Ok(_)) => { - warn!(%session_id, "unknown message"); - } - Some(Err(e)) => { - error!(%session_id, "connection error: {}", e); - break - } - None => { - info!("connection closed"); - break - } - } - } + ready!(connection.as_mut().poll(cx)); // remove from connection pool if let Some(pool) = pool.clone().upgrade() { - if pool.write().remove_client(db_user.clone(), conn_id) { + if pool.write().remove_conn(db_user.clone(), conn_id) { info!("closed connection removed"); } } Poll::Ready(()) - }).await; - - } - .instrument(span)); + }) + .await; + }); let inner = ClientInnerCommon { inner: client, aux, @@ -169,7 +149,42 @@ pub(crate) fn poll_client( }), }; - Client::new(inner, conn_info, pool_clone) + P::wrap_client(inner, conn_info, pool_clone) +} + +pub async fn poll_tokio_postgres_conn_really(mut connection: Conn) { + poll_fn(move |cx| { + loop { + let message = ready!(connection.poll_message(cx)); + + match message { + Some(Ok(AsyncMessage::Notice(notice))) => { + info!("notice: {}", notice); + } + Some(Ok(AsyncMessage::Notification(notif))) => { + warn!( + pid = notif.process_id(), + channel = notif.channel(), + "notification received" + ); + } + Some(Ok(_)) => { + warn!("unknown message"); + } + Some(Err(e)) => { + error!("connection error: {}", e); + break; + } + None => { + info!("connection closed"); + break; + } + } + } + + Poll::Ready(()) + }) + .await; } #[derive(Clone)] @@ -183,7 +198,7 @@ impl ClientDataRemote { &self.session } - pub fn cancel(&mut self) { + pub fn cancel(&self) { self.cancel.cancel(); } } @@ -195,6 +210,7 @@ mod tests { use super::*; use crate::proxy::NeonOptions; use crate::serverless::cancel_set::CancelSet; + use crate::serverless::conn_pool_lib::{Client, ClientInnerExt}; use crate::types::{BranchId, EndpointId, ProjectId}; struct MockClient(Arc); diff --git a/proxy/src/serverless/conn_pool_lib.rs b/proxy/src/serverless/conn_pool_lib.rs index c3e1000f6b..b10e82b006 100644 --- a/proxy/src/serverless/conn_pool_lib.rs +++ b/proxy/src/serverless/conn_pool_lib.rs @@ -11,8 +11,7 @@ use rand::Rng; use smol_str::ToSmolStr; use tracing::{Span, debug, info}; -use super::conn_pool::ClientDataRemote; -use super::http_conn_pool::ClientDataHttp; +use super::conn_pool::{ClientDataRemote, poll_tokio_postgres_conn_really}; use super::local_conn_pool::ClientDataLocal; use crate::auth::backend::ComputeUserInfo; use crate::config::HttpConfig; @@ -50,7 +49,6 @@ impl ConnInfo { pub(crate) enum ClientDataEnum { Remote(ClientDataRemote), Local(ClientDataLocal), - Http(ClientDataHttp), } #[derive(Clone)] @@ -63,14 +61,9 @@ pub(crate) struct ClientInnerCommon { impl Drop for ClientInnerCommon { fn drop(&mut self) { - match &mut self.data { - ClientDataEnum::Remote(remote_data) => { - remote_data.cancel(); - } - ClientDataEnum::Local(local_data) => { - local_data.cancel(); - } - ClientDataEnum::Http(_http_data) => (), + match &self.data { + ClientDataEnum::Remote(remote_data) => remote_data.cancel(), + ClientDataEnum::Local(local_data) => local_data.cancel(), } } } @@ -325,9 +318,10 @@ impl DbUserConn for DbUserConnPool { } } -pub(crate) trait EndpointConnPoolExt { +pub(crate) trait EndpointConnPoolExt: Send + Sync + 'static { type Client; type ClientInner: ClientInnerExt; + type Connection: Send + 'static; fn create(config: &HttpConfig, global_connections_count: Arc) -> Self; fn wrap_client( @@ -340,6 +334,9 @@ pub(crate) trait EndpointConnPoolExt { &mut self, db_user: (DbName, RoleName), ) -> Option>; + fn remove_conn(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool; + + fn spawn_conn(conn: Self::Connection) -> impl Future + Send + 'static; fn clear_closed(&mut self) -> usize; fn total_conns(&self) -> usize; @@ -348,6 +345,7 @@ pub(crate) trait EndpointConnPoolExt { impl EndpointConnPoolExt for EndpointConnPool { type Client = Client; type ClientInner = C; + type Connection = super::conn_pool::Conn; fn create(config: &HttpConfig, global_connections_count: Arc) -> Self { EndpointConnPool { @@ -376,6 +374,14 @@ impl EndpointConnPoolExt for EndpointConnPool { Some(self.get_conn_entry(db_user)?.conn) } + fn remove_conn(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool { + self.remove_client(db_user, conn_id) + } + + async fn spawn_conn(conn: Self::Connection) { + poll_tokio_postgres_conn_really(conn).await; + } + fn clear_closed(&mut self) -> usize { let mut clients_removed: usize = 0; for db_pool in self.pools.values_mut() { @@ -568,7 +574,6 @@ impl GlobalConnPool

{ ClientDataEnum::Remote(data) => { data.session().send(ctx.session_id()).ok()?; } - ClientDataEnum::Http(_) => (), } ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit); diff --git a/proxy/src/serverless/http_conn_pool.rs b/proxy/src/serverless/http_conn_pool.rs index 7e77da8408..0e36a1a288 100644 --- a/proxy/src/serverless/http_conn_pool.rs +++ b/proxy/src/serverless/http_conn_pool.rs @@ -5,16 +5,14 @@ use std::sync::{Arc, Weak}; use hyper::client::conn::http2; use hyper_util::rt::{TokioExecutor, TokioIo}; use smol_str::ToSmolStr; -use tracing::{Instrument, error, info, info_span}; +use tracing::{error, info}; use super::AsyncRW; use super::conn_pool_lib::{ - ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, ConnPoolEntry, - EndpointConnPoolExt, GlobalConnPool, + ClientInnerCommon, ClientInnerExt, ConnInfo, ConnPoolEntry, EndpointConnPoolExt, }; use crate::config::HttpConfig; use crate::context::RequestContext; -use crate::control_plane::messages::MetricsAuxInfo; use crate::metrics::{HttpEndpointPoolsGuard, Metrics}; use crate::protocol2::ConnectionInfoExtra; use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; @@ -22,9 +20,6 @@ use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; pub(crate) type Send = http2::SendRequest; pub(crate) type Connect = http2::Connection, hyper::body::Incoming, TokioExecutor>; -#[derive(Clone)] -pub(crate) struct ClientDataHttp(); - // Per-endpoint connection pool // Number of open connections is limited by the `max_conns_per_endpoint`. pub(crate) struct HttpConnPool { @@ -86,6 +81,7 @@ impl HttpConnPool { impl EndpointConnPoolExt for HttpConnPool { type Client = Client; type ClientInner = Send; + type Connection = Connect; fn create(_config: &HttpConfig, global_connections_count: Arc) -> Self { HttpConnPool { @@ -110,6 +106,22 @@ impl EndpointConnPoolExt for HttpConnPool { Some(self.get_conn_entry()?.conn) } + fn remove_conn( + &mut self, + _db_user: (crate::types::DbName, crate::types::RoleName), + conn_id: uuid::Uuid, + ) -> bool { + self.remove_conn(conn_id) + } + + async fn spawn_conn(conn: Self::Connection) { + let res = conn.await; + match res { + Ok(()) => info!("connection closed"), + Err(e) => error!("connection error: {e:?}"), + } + } + fn clear_closed(&mut self) -> usize { let Self { conns, .. } = self; let old_len = conns.len(); @@ -138,77 +150,6 @@ impl Drop for HttpConnPool { } } -pub(crate) fn poll_http2_client( - global_pool: Arc>, - ctx: &RequestContext, - conn_info: &ConnInfo, - client: Send, - connection: Connect, - conn_id: uuid::Uuid, - aux: MetricsAuxInfo, -) -> Client { - let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol()); - let session_id = ctx.session_id(); - - let span = info_span!(parent: None, "connection", %conn_id); - let cold_start_info = ctx.cold_start_info(); - span.in_scope(|| { - info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection"); - }); - - let pool = match conn_info.endpoint_cache_key() { - Some(endpoint) => { - let pool = global_pool.get_or_create_endpoint_pool(&endpoint); - let client = ClientInnerCommon { - inner: client.clone(), - aux: aux.clone(), - conn_id, - data: ClientDataEnum::Http(ClientDataHttp()), - }; - pool.write().conns.push_back(ConnPoolEntry { - conn: client, - _last_access: std::time::Instant::now(), - }); - Metrics::get() - .proxy - .http_pool_opened_connections - .get_metric() - .inc(); - - Arc::downgrade(&pool) - } - None => Weak::new(), - }; - - tokio::spawn( - async move { - let _conn_gauge = conn_gauge; - let res = connection.await; - match res { - Ok(()) => info!("connection closed"), - Err(e) => error!(%session_id, "connection error: {e:?}"), - } - - // remove from connection pool - if let Some(pool) = pool.clone().upgrade() { - if pool.write().remove_conn(conn_id) { - info!("closed connection removed"); - } - } - } - .instrument(span), - ); - - let client = ClientInnerCommon { - inner: client, - aux, - conn_id, - data: ClientDataEnum::Http(ClientDataHttp()), - }; - - Client::new(client) -} - pub(crate) struct Client { pub(crate) inner: ClientInnerCommon, } diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs index ed55bc2063..3d07f242f9 100644 --- a/proxy/src/serverless/local_conn_pool.rs +++ b/proxy/src/serverless/local_conn_pool.rs @@ -57,7 +57,7 @@ impl ClientDataLocal { &self.session } - pub fn cancel(&mut self) { + pub fn cancel(&self) { self.cancel.cancel(); } } @@ -120,11 +120,9 @@ impl LocalConnPool { ClientDataEnum::Local(data) => { data.session().send(ctx.session_id())?; } - ClientDataEnum::Remote(data) => { data.session().send(ctx.session_id())?; } - ClientDataEnum::Http(_) => (), } ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);