make connection polling generic

This commit is contained in:
Conrad Ludgate
2025-04-16 16:18:55 +01:00
parent 86fb432ab2
commit 897cea978a
5 changed files with 106 additions and 146 deletions

View File

@@ -16,9 +16,9 @@ use tracing::field::display;
use tracing::{debug, info};
use super::AsyncRW;
use super::conn_pool::poll_client;
use super::conn_pool::poll_client_generic;
use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool};
use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client};
use super::http_conn_pool::{self, HttpConnPool, Send};
use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool};
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
@@ -577,7 +577,7 @@ impl ConnectMechanism for TokioMechanism {
info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id);
}
Ok(poll_client(
Ok(poll_client_generic(
self.pool.clone(),
ctx,
self.conn_info.clone(),
@@ -638,10 +638,10 @@ impl ConnectMechanism for HyperMechanism {
info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id);
}
Ok(poll_http2_client(
Ok(poll_client_generic(
self.pool.clone(),
ctx,
&self.conn_info,
self.conn_info.clone(),
client,
connection,
self.conn_id,

View File

@@ -11,7 +11,7 @@ use smallvec::SmallVec;
use tokio::net::TcpStream;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, error, info, info_span, warn};
use tracing::{error, info, info_span, warn};
#[cfg(test)]
use {
super::conn_pool_lib::GlobalConnPoolOptions,
@@ -20,8 +20,7 @@ use {
};
use super::conn_pool_lib::{
Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, EndpointConnPool,
GlobalConnPool,
ClientDataEnum, ClientInnerCommon, ConnInfo, EndpointConnPoolExt, GlobalConnPool,
};
use crate::context::RequestContext;
use crate::control_plane::messages::MetricsAuxInfo;
@@ -29,6 +28,7 @@ use crate::metrics::Metrics;
use crate::tls::postgres_rustls::MakeRustlsConnect;
type TlsStream = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::Stream;
pub(super) type Conn = postgres_client::Connection<TcpStream, TlsStream>;
#[derive(Debug, Clone)]
pub(crate) struct ConnInfoWithAuth {
@@ -56,20 +56,20 @@ impl fmt::Display for ConnInfo {
}
}
pub(crate) fn poll_client<C: ClientInnerExt>(
global_pool: Arc<GlobalConnPool<EndpointConnPool<C>>>,
pub(crate) fn poll_client_generic<P: EndpointConnPoolExt>(
global_pool: Arc<GlobalConnPool<P>>,
ctx: &RequestContext,
conn_info: ConnInfo,
client: C,
mut connection: postgres_client::Connection<TcpStream, TlsStream>,
client: P::ClientInner,
connection: P::Connection,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> Client<C> {
) -> P::Client {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let mut session_id = ctx.session_id();
let session_id = ctx.session_id();
let (tx, mut rx) = tokio::sync::watch::channel(session_id);
let span = info_span!(parent: None, "connection", %conn_id);
let span = info_span!(parent: None, "connection", %conn_id, %session_id);
let cold_start_info = ctx.cold_start_info();
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
@@ -85,27 +85,30 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
let cancel = CancellationToken::new();
let cancelled = cancel.clone().cancelled_owned();
tokio::spawn(
async move {
tokio::spawn(async move {
let _conn_gauge = conn_gauge;
let mut idle_timeout = pin!(tokio::time::sleep(idle));
let mut cancelled = pin!(cancelled);
let mut connection = pin!(P::spawn_conn(connection));
poll_fn(move |cx| {
let _enter = span.enter();
if cancelled.as_mut().poll(cx).is_ready() {
info!("connection dropped");
return Poll::Ready(())
return Poll::Ready(());
}
match rx.has_changed() {
Ok(true) => {
session_id = *rx.borrow_and_update();
info!(%session_id, "changed session");
let session_id = *rx.borrow_and_update();
span.record("session_id", tracing::field::display(session_id));
info!("changed session");
idle_timeout.as_mut().reset(Instant::now() + idle);
}
Err(_) => {
info!("connection dropped");
return Poll::Ready(())
return Poll::Ready(());
}
_ => {}
}
@@ -117,48 +120,25 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
if let Some(pool) = pool.clone().upgrade() {
// remove client from pool - should close the connection if it's idle.
// does nothing if the client is currently checked-out and in-use
if pool.write().remove_client(db_user.clone(), conn_id) {
if pool.write().remove_conn(db_user.clone(), conn_id) {
info!("idle connection removed");
}
}
}
loop {
let message = ready!(connection.poll_message(cx));
match message {
Some(Ok(AsyncMessage::Notice(notice))) => {
info!(%session_id, "notice: {}", notice);
}
Some(Ok(AsyncMessage::Notification(notif))) => {
warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
}
Some(Ok(_)) => {
warn!(%session_id, "unknown message");
}
Some(Err(e)) => {
error!(%session_id, "connection error: {}", e);
break
}
None => {
info!("connection closed");
break
}
}
}
ready!(connection.as_mut().poll(cx));
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.write().remove_client(db_user.clone(), conn_id) {
if pool.write().remove_conn(db_user.clone(), conn_id) {
info!("closed connection removed");
}
}
Poll::Ready(())
}).await;
}
.instrument(span));
})
.await;
});
let inner = ClientInnerCommon {
inner: client,
aux,
@@ -169,7 +149,42 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
}),
};
Client::new(inner, conn_info, pool_clone)
P::wrap_client(inner, conn_info, pool_clone)
}
pub async fn poll_tokio_postgres_conn_really(mut connection: Conn) {
poll_fn(move |cx| {
loop {
let message = ready!(connection.poll_message(cx));
match message {
Some(Ok(AsyncMessage::Notice(notice))) => {
info!("notice: {}", notice);
}
Some(Ok(AsyncMessage::Notification(notif))) => {
warn!(
pid = notif.process_id(),
channel = notif.channel(),
"notification received"
);
}
Some(Ok(_)) => {
warn!("unknown message");
}
Some(Err(e)) => {
error!("connection error: {}", e);
break;
}
None => {
info!("connection closed");
break;
}
}
}
Poll::Ready(())
})
.await;
}
#[derive(Clone)]
@@ -183,7 +198,7 @@ impl ClientDataRemote {
&self.session
}
pub fn cancel(&mut self) {
pub fn cancel(&self) {
self.cancel.cancel();
}
}
@@ -195,6 +210,7 @@ mod tests {
use super::*;
use crate::proxy::NeonOptions;
use crate::serverless::cancel_set::CancelSet;
use crate::serverless::conn_pool_lib::{Client, ClientInnerExt};
use crate::types::{BranchId, EndpointId, ProjectId};
struct MockClient(Arc<AtomicBool>);

View File

@@ -11,8 +11,7 @@ use rand::Rng;
use smol_str::ToSmolStr;
use tracing::{Span, debug, info};
use super::conn_pool::ClientDataRemote;
use super::http_conn_pool::ClientDataHttp;
use super::conn_pool::{ClientDataRemote, poll_tokio_postgres_conn_really};
use super::local_conn_pool::ClientDataLocal;
use crate::auth::backend::ComputeUserInfo;
use crate::config::HttpConfig;
@@ -50,7 +49,6 @@ impl ConnInfo {
pub(crate) enum ClientDataEnum {
Remote(ClientDataRemote),
Local(ClientDataLocal),
Http(ClientDataHttp),
}
#[derive(Clone)]
@@ -63,14 +61,9 @@ pub(crate) struct ClientInnerCommon<C: ClientInnerExt> {
impl<C: ClientInnerExt> Drop for ClientInnerCommon<C> {
fn drop(&mut self) {
match &mut self.data {
ClientDataEnum::Remote(remote_data) => {
remote_data.cancel();
}
ClientDataEnum::Local(local_data) => {
local_data.cancel();
}
ClientDataEnum::Http(_http_data) => (),
match &self.data {
ClientDataEnum::Remote(remote_data) => remote_data.cancel(),
ClientDataEnum::Local(local_data) => local_data.cancel(),
}
}
}
@@ -325,9 +318,10 @@ impl<C: ClientInnerExt> DbUserConn<C> for DbUserConnPool<C> {
}
}
pub(crate) trait EndpointConnPoolExt {
pub(crate) trait EndpointConnPoolExt: Send + Sync + 'static {
type Client;
type ClientInner: ClientInnerExt;
type Connection: Send + 'static;
fn create(config: &HttpConfig, global_connections_count: Arc<AtomicUsize>) -> Self;
fn wrap_client(
@@ -340,6 +334,9 @@ pub(crate) trait EndpointConnPoolExt {
&mut self,
db_user: (DbName, RoleName),
) -> Option<ClientInnerCommon<Self::ClientInner>>;
fn remove_conn(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool;
fn spawn_conn(conn: Self::Connection) -> impl Future<Output = ()> + Send + 'static;
fn clear_closed(&mut self) -> usize;
fn total_conns(&self) -> usize;
@@ -348,6 +345,7 @@ pub(crate) trait EndpointConnPoolExt {
impl<C: ClientInnerExt> EndpointConnPoolExt for EndpointConnPool<C> {
type Client = Client<C>;
type ClientInner = C;
type Connection = super::conn_pool::Conn;
fn create(config: &HttpConfig, global_connections_count: Arc<AtomicUsize>) -> Self {
EndpointConnPool {
@@ -376,6 +374,14 @@ impl<C: ClientInnerExt> EndpointConnPoolExt for EndpointConnPool<C> {
Some(self.get_conn_entry(db_user)?.conn)
}
fn remove_conn(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
self.remove_client(db_user, conn_id)
}
async fn spawn_conn(conn: Self::Connection) {
poll_tokio_postgres_conn_really(conn).await;
}
fn clear_closed(&mut self) -> usize {
let mut clients_removed: usize = 0;
for db_pool in self.pools.values_mut() {
@@ -568,7 +574,6 @@ impl<P: EndpointConnPoolExt> GlobalConnPool<P> {
ClientDataEnum::Remote(data) => {
data.session().send(ctx.session_id()).ok()?;
}
ClientDataEnum::Http(_) => (),
}
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);

View File

@@ -5,16 +5,14 @@ use std::sync::{Arc, Weak};
use hyper::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use smol_str::ToSmolStr;
use tracing::{Instrument, error, info, info_span};
use tracing::{error, info};
use super::AsyncRW;
use super::conn_pool_lib::{
ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, ConnPoolEntry,
EndpointConnPoolExt, GlobalConnPool,
ClientInnerCommon, ClientInnerExt, ConnInfo, ConnPoolEntry, EndpointConnPoolExt,
};
use crate::config::HttpConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::protocol2::ConnectionInfoExtra;
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
@@ -22,9 +20,6 @@ use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
pub(crate) type Send = http2::SendRequest<hyper::body::Incoming>;
pub(crate) type Connect = http2::Connection<TokioIo<AsyncRW>, hyper::body::Incoming, TokioExecutor>;
#[derive(Clone)]
pub(crate) struct ClientDataHttp();
// Per-endpoint connection pool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub(crate) struct HttpConnPool {
@@ -86,6 +81,7 @@ impl HttpConnPool {
impl EndpointConnPoolExt for HttpConnPool {
type Client = Client<Send>;
type ClientInner = Send;
type Connection = Connect;
fn create(_config: &HttpConfig, global_connections_count: Arc<AtomicUsize>) -> Self {
HttpConnPool {
@@ -110,6 +106,22 @@ impl EndpointConnPoolExt for HttpConnPool {
Some(self.get_conn_entry()?.conn)
}
fn remove_conn(
&mut self,
_db_user: (crate::types::DbName, crate::types::RoleName),
conn_id: uuid::Uuid,
) -> bool {
self.remove_conn(conn_id)
}
async fn spawn_conn(conn: Self::Connection) {
let res = conn.await;
match res {
Ok(()) => info!("connection closed"),
Err(e) => error!("connection error: {e:?}"),
}
}
fn clear_closed(&mut self) -> usize {
let Self { conns, .. } = self;
let old_len = conns.len();
@@ -138,77 +150,6 @@ impl Drop for HttpConnPool {
}
}
pub(crate) fn poll_http2_client(
global_pool: Arc<GlobalConnPool<HttpConnPool>>,
ctx: &RequestContext,
conn_info: &ConnInfo,
client: Send,
connection: Connect,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> Client<Send> {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let session_id = ctx.session_id();
let span = info_span!(parent: None, "connection", %conn_id);
let cold_start_info = ctx.cold_start_info();
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
});
let pool = match conn_info.endpoint_cache_key() {
Some(endpoint) => {
let pool = global_pool.get_or_create_endpoint_pool(&endpoint);
let client = ClientInnerCommon {
inner: client.clone(),
aux: aux.clone(),
conn_id,
data: ClientDataEnum::Http(ClientDataHttp()),
};
pool.write().conns.push_back(ConnPoolEntry {
conn: client,
_last_access: std::time::Instant::now(),
});
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.inc();
Arc::downgrade(&pool)
}
None => Weak::new(),
};
tokio::spawn(
async move {
let _conn_gauge = conn_gauge;
let res = connection.await;
match res {
Ok(()) => info!("connection closed"),
Err(e) => error!(%session_id, "connection error: {e:?}"),
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.write().remove_conn(conn_id) {
info!("closed connection removed");
}
}
}
.instrument(span),
);
let client = ClientInnerCommon {
inner: client,
aux,
conn_id,
data: ClientDataEnum::Http(ClientDataHttp()),
};
Client::new(client)
}
pub(crate) struct Client<C: ClientInnerExt + Clone> {
pub(crate) inner: ClientInnerCommon<C>,
}

View File

@@ -57,7 +57,7 @@ impl ClientDataLocal {
&self.session
}
pub fn cancel(&mut self) {
pub fn cancel(&self) {
self.cancel.cancel();
}
}
@@ -120,11 +120,9 @@ impl<C: ClientInnerExt> LocalConnPool<C> {
ClientDataEnum::Local(data) => {
data.session().send(ctx.session_id())?;
}
ClientDataEnum::Remote(data) => {
data.session().send(ctx.session_id())?;
}
ClientDataEnum::Http(_) => (),
}
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);