start on http2 local proxy connection pool

This commit is contained in:
Conrad Ludgate
2024-09-18 13:42:20 +01:00
parent 76515cdae3
commit 2703abccc7
4 changed files with 592 additions and 33 deletions

View File

@@ -5,6 +5,7 @@
mod backend;
pub mod cancel_set;
mod conn_pool;
mod http_conn_pool;
mod http_util;
mod json;
mod sql_over_http;
@@ -81,7 +82,28 @@ pub async fn task_main(
}
});
let http_conn_pool = http_conn_pool::GlobalConnPool::new(&config.http_config);
{
let http_conn_pool = Arc::clone(&http_conn_pool);
tokio::spawn(async move {
http_conn_pool.gc_worker(StdRng::from_entropy()).await;
});
}
// shutdown the connection pool
tokio::spawn({
let cancellation_token = cancellation_token.clone();
let http_conn_pool = http_conn_pool.clone();
async move {
cancellation_token.cancelled().await;
tokio::task::spawn_blocking(move || http_conn_pool.shutdown())
.await
.unwrap();
}
});
let backend = Arc::new(PoolingBackend {
http_conn_pool: Arc::clone(&http_conn_pool),
pool: Arc::clone(&conn_pool),
config,
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),

View File

@@ -1,7 +1,12 @@
use std::{sync::Arc, time::Duration};
use std::{io, sync::Arc, time::Duration};
use async_trait::async_trait;
use tracing::{debug, field::display, info};
use bytes::Bytes;
use http_body_util::Full;
use hyper1::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use tokio::net::{lookup_host, TcpStream};
use tracing::{field::display, info};
use crate::{
auth::{
@@ -27,9 +32,13 @@ use crate::{
Host,
};
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
use super::{
conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool},
http_conn_pool::{self, poll_http2_client},
};
pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool>,
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub(crate) config: &'static ProxyConfig,
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -190,6 +199,39 @@ impl PoolingBackend {
)
.await
}
// Wake up the destination if needed
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
pub(crate) async fn connect_to_local_proxy(
&self,
ctx: &RequestMonitoring,
conn_info: ConnInfo,
keys: ComputeCredentials,
) -> Result<http_conn_pool::Client, HttpConnError> {
info!("pool: looking for an existing connection");
if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) {
return Ok(client);
}
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.config.auth_backend.as_ref().map(|()| keys);
crate::proxy::connect_compute::connect_to_compute(
ctx,
&HyperMechanism {
conn_id,
conn_info,
pool: self.http_conn_pool.clone(),
locks: &self.config.connect_compute_locks,
},
&backend,
false, // do not allow self signed compute for http flow
self.config.wake_compute_retry_config,
self.config.connect_to_compute_retry_config,
)
.await
}
}
#[derive(Debug, thiserror::Error)]
@@ -198,6 +240,10 @@ pub(crate) enum HttpConnError {
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
#[error("could not connection to compute")]
ConnectionError(#[from] tokio_postgres::Error),
#[error("could not connection to compute")]
IoConnectionError(#[from] std::io::Error),
#[error("could not establish h2 connection to compute")]
H2ConnectionError(#[from] hyper1::Error),
#[error("could not get auth info")]
GetAuthInfo(#[from] GetAuthInfoError),
@@ -214,6 +260,8 @@ impl ReportableError for HttpConnError {
match self {
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
HttpConnError::ConnectionError(p) => p.get_error_kind(),
HttpConnError::IoConnectionError(_) => ErrorKind::Compute,
HttpConnError::H2ConnectionError(_) => ErrorKind::Compute,
HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
HttpConnError::AuthError(a) => a.get_error_kind(),
HttpConnError::WakeCompute(w) => w.get_error_kind(),
@@ -227,6 +275,8 @@ impl UserFacingError for HttpConnError {
match self {
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
HttpConnError::ConnectionError(p) => p.to_string(),
HttpConnError::IoConnectionError(p) => p.to_string(),
HttpConnError::H2ConnectionError(_) => "Could not establish HTTP connection to the database".to_string(),
HttpConnError::GetAuthInfo(c) => c.to_string_client(),
HttpConnError::AuthError(c) => c.to_string_client(),
HttpConnError::WakeCompute(c) => c.to_string_client(),
@@ -241,6 +291,8 @@ impl CouldRetry for HttpConnError {
fn could_retry(&self) -> bool {
match self {
HttpConnError::ConnectionError(e) => e.could_retry(),
HttpConnError::IoConnectionError(e) => e.could_retry(),
HttpConnError::H2ConnectionError(_) => false,
HttpConnError::ConnectionClosedAbruptly(_) => false,
HttpConnError::GetAuthInfo(_) => false,
HttpConnError::AuthError(_) => false,
@@ -309,3 +361,100 @@ impl ConnectMechanism for TokioMechanism {
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}
struct HyperMechanism {
pool: Arc<http_conn_pool::GlobalConnPool>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
}
#[async_trait]
impl ConnectMechanism for HyperMechanism {
type Connection = http_conn_pool::Client;
type ConnectError = HttpConnError;
type Error = HttpConnError;
async fn connect_once(
&self,
ctx: &RequestMonitoring,
node_info: &CachedNodeInfo,
timeout: Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host()?;
let permit = self.locks.get_permit(&host).await?;
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
// let port = node_info.config.get_ports().first().unwrap_or_else(10432);
let res = connect_http2(&host, 10432, timeout).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
Ok(poll_http2_client(
self.pool.clone(),
ctx,
self.conn_info.clone(),
client,
connection,
self.conn_id,
node_info.aux.clone(),
))
}
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}
async fn connect_http2(
host: &str,
port: u16,
timeout: Duration,
) -> Result<
(
http2::SendRequest<Full<Bytes>>,
http2::Connection<TokioIo<TcpStream>, Full<Bytes>, TokioExecutor>,
),
HttpConnError,
> {
let mut addrs = lookup_host((host, port)).await?;
let mut last_err = None;
let stream = loop {
let Some(addr) = addrs.next() else {
return Err(last_err.unwrap_or_else(|| {
HttpConnError::IoConnectionError(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve any addresses",
))
}));
};
let stream = match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
Ok(Ok(stream)) => stream,
Ok(Err(e)) => {
last_err = Some(HttpConnError::IoConnectionError(e));
continue;
}
Err(e) => {
last_err = Some(HttpConnError::IoConnectionError(io::Error::new(
io::ErrorKind::TimedOut,
e,
)));
continue;
}
};
stream.set_nodelay(true)?;
break stream;
};
let (client, connection) = hyper1::client::conn::http2::Builder::new(TokioExecutor::new())
.handshake(TokioIo::new(stream))
.await?;
Ok((client, connection))
}

View File

@@ -0,0 +1,360 @@
use bytes::Bytes;
use dashmap::DashMap;
use http_body_util::Full;
use hyper1::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use rand::Rng;
use std::collections::VecDeque;
use std::{
ops::Deref,
sync::atomic::{self, AtomicUsize},
};
use std::{sync::Arc, sync::Weak};
use tokio::net::TcpStream;
use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{context::RequestMonitoring, EndpointCacheKey};
use tracing::{debug, error, Span};
use tracing::{info, info_span, Instrument};
use super::conn_pool::ConnInfo;
#[derive(Clone)]
struct ConnPoolEntry {
conn: http2::SendRequest<Full<Bytes>>,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
}
// Per-endpoint connection pool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub(crate) struct EndpointConnPool {
conns: VecDeque<ConnPoolEntry>,
_guard: HttpEndpointPoolsGuard<'static>,
global_connections_count: Arc<AtomicUsize>,
}
impl EndpointConnPool {
fn get_conn_entry(&mut self) -> Option<ConnPoolEntry> {
let Self { conns, .. } = self;
let conn = conns.pop_front()?;
conns.push_back(conn.clone());
Some(conn)
}
fn remove_conn(&mut self, conn_id: uuid::Uuid) -> bool {
let Self {
conns,
global_connections_count,
..
} = self;
let old_len = conns.len();
conns.retain(|conn| conn.conn_id != conn_id);
let new_len = conns.len();
let removed = old_len - new_len;
if removed > 0 {
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(removed as i64);
}
removed > 0
}
}
impl Drop for EndpointConnPool {
fn drop(&mut self) {
if !self.conns.is_empty() {
self.global_connections_count
.fetch_sub(self.conns.len(), atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(self.conns.len() as i64);
}
}
}
pub(crate) struct GlobalConnPool {
// endpoint -> per-endpoint connection pool
//
// That should be a fairly conteded map, so return reference to the per-endpoint
// pool as early as possible and release the lock.
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool>>>,
/// Number of endpoint-connection pools
///
/// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
/// That seems like far too much effort, so we're using a relaxed increment counter instead.
/// It's only used for diagnostics.
global_pool_size: AtomicUsize,
/// Total number of connections in the pool
global_connections_count: Arc<AtomicUsize>,
config: &'static crate::config::HttpConfig,
}
impl GlobalConnPool {
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
let shards = config.pool_options.pool_shards;
Arc::new(Self {
global_pool: DashMap::with_shard_amount(shards),
global_pool_size: AtomicUsize::new(0),
config,
global_connections_count: Arc::new(AtomicUsize::new(0)),
})
}
pub(crate) fn shutdown(&self) {
// drops all strong references to endpoint-pools
self.global_pool.clear();
}
pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
let epoch = self.config.pool_options.gc_epoch;
let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
loop {
interval.tick().await;
let shard = rng.gen_range(0..self.global_pool.shards().len());
self.gc(shard);
}
}
fn gc(&self, shard: usize) {
debug!(shard, "pool: performing epoch reclamation");
// acquire a random shard lock
let mut shard = self.global_pool.shards()[shard].write();
let timer = Metrics::get()
.proxy
.http_pool_reclaimation_lag_seconds
.start_timer();
let current_len = shard.len();
let mut clients_removed = 0;
shard.retain(|endpoint, x| {
// if the current endpoint pool is unique (no other strong or weak references)
// then it is currently not in use by any connections.
if let Some(pool) = Arc::get_mut(x.get_mut()) {
let EndpointConnPool { conns, .. } = pool.get_mut();
let old_len = conns.len();
conns.retain(|conn| !conn.conn.is_closed());
let new_len = conns.len();
let removed = old_len - new_len;
clients_removed += removed;
// we only remove this pool if it has no active connections
if conns.is_empty() {
info!("pool: discarding pool for endpoint {endpoint}");
return false;
}
}
true
});
let new_len = shard.len();
drop(shard);
timer.observe();
// Do logging outside of the lock.
if clients_removed > 0 {
let size = self
.global_connections_count
.fetch_sub(clients_removed, atomic::Ordering::Relaxed)
- clients_removed;
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(clients_removed as i64);
info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
}
let removed = current_len - new_len;
if removed > 0 {
let global_pool_size = self
.global_pool_size
.fetch_sub(removed, atomic::Ordering::Relaxed)
- removed;
info!("pool: performed global pool gc. size now {global_pool_size}");
}
}
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
) -> Option<Client> {
let endpoint = conn_info.endpoint_cache_key()?;
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
let client = endpoint_pool.write().get_conn_entry()?;
if client.conn.is_closed() {
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
return None;
}
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
info!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
);
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.success();
Some(Client::new(client.conn, conn_info.clone(), client.aux))
}
fn get_or_create_endpoint_pool(
self: &Arc<Self>,
endpoint: &EndpointCacheKey,
) -> Arc<RwLock<EndpointConnPool>> {
// fast path
if let Some(pool) = self.global_pool.get(endpoint) {
return pool.clone();
}
// slow path
let new_pool = Arc::new(RwLock::new(EndpointConnPool {
conns: VecDeque::new(),
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
global_connections_count: self.global_connections_count.clone(),
}));
// find or create a pool for this endpoint
let mut created = false;
let pool = self
.global_pool
.entry(endpoint.clone())
.or_insert_with(|| {
created = true;
new_pool
})
.clone();
// log new global pool size
if created {
let global_pool_size = self
.global_pool_size
.fetch_add(1, atomic::Ordering::Relaxed)
+ 1;
info!(
"pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
);
}
pool
}
}
pub(crate) fn poll_http2_client(
global_pool: Arc<GlobalConnPool>,
ctx: &RequestMonitoring,
conn_info: ConnInfo,
client: http2::SendRequest<Full<Bytes>>,
connection: http2::Connection<TokioIo<TcpStream>, Full<Bytes>, TokioExecutor>,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> Client {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let session_id = ctx.session_id();
let span = info_span!(parent: None, "connection", %conn_id);
let cold_start_info = ctx.cold_start_info();
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
});
let pool = match conn_info.endpoint_cache_key() {
Some(endpoint) => {
let pool = global_pool.get_or_create_endpoint_pool(&endpoint);
pool.write().conns.push_back(ConnPoolEntry {
conn: client.clone(),
conn_id,
aux: aux.clone(),
});
Arc::downgrade(&pool)
}
None => Weak::new(),
};
// let idle = global_pool.get_idle_timeout();
tokio::spawn(
async move {
let _conn_gauge = conn_gauge;
let res = connection.await;
match res {
Ok(()) => info!("connection closed"),
Err(e) => error!(%session_id, "connection error: {}", e),
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.write().remove_conn(conn_id) {
info!("closed connection removed");
}
}
}
.instrument(span),
);
Client::new(client, conn_info, aux)
}
impl Client {
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
USAGE_METRICS.register(Ids {
endpoint_id: self.aux.endpoint_id,
branch_id: self.aux.branch_id,
})
}
}
pub(crate) struct Client {
span: Span,
inner: http2::SendRequest<Full<Bytes>>,
aux: MetricsAuxInfo,
conn_info: ConnInfo,
}
impl Client {
pub(self) fn new(
inner: http2::SendRequest<Full<Bytes>>,
conn_info: ConnInfo,
aux: MetricsAuxInfo,
) -> Self {
Self {
inner,
span: Span::current(),
conn_info,
aux,
}
}
pub(crate) fn inner(&mut self) -> &mut http2::SendRequest<Full<Bytes>> {
&mut self.inner
}
}
impl Deref for Client {
type Target = http2::SendRequest<Full<Bytes>>;
fn deref(&self) -> &Self::Target {
&self.inner
}
}

View File

@@ -514,18 +514,60 @@ async fn handle_inner(
"handling interactive connection from client"
);
//
// Determine the destination and connection params
//
let headers = request.headers();
// TLS config should be there.
let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref())?;
let conn_info = get_conn_info(ctx, request.headers(), config.tls_config.as_ref())?;
info!(
user = conn_info.conn_info.user_info.user.as_str(),
"credentials"
);
match conn_info.auth {
AuthData::Password(pw) => {
let res = handle_db_inner(
cancel,
config,
ctx,
request,
conn_info.conn_info,
&pw,
backend,
)
.await?;
Ok(res)
}
AuthData::Jwt(jwt) => {
let keys = backend
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.conn_info.user_info,
jwt,
)
.await
.map_err(HttpConnError::from)?;
let _client = backend
.connect_to_local_proxy(ctx, conn_info.conn_info, keys)
.await?;
todo!()
}
}
}
async fn handle_db_inner(
cancel: CancellationToken,
config: &'static ProxyConfig,
ctx: &RequestMonitoring,
request: Request<Incoming>,
conn_info: ConnInfo,
password: &[u8],
backend: Arc<PoolingBackend>,
) -> Result<Response<Full<Bytes>>, SqlOverHttpError> {
//
// Determine the destination and connection params
//
let headers = request.headers();
// Allow connection pooling only if explicitly requested
// or if we have decided that http pool is no longer opt-in
let allow_pool = !config.http_config.pool_options.opt_in
@@ -563,31 +605,17 @@ async fn handle_inner(
let authenticate_and_connect = Box::pin(
async {
let keys = match conn_info.auth {
AuthData::Password(pw) => {
backend
.authenticate_with_password(
ctx,
&config.authentication_config,
&conn_info.conn_info.user_info,
&pw,
)
.await?
}
AuthData::Jwt(jwt) => {
backend
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.conn_info.user_info,
jwt,
)
.await?
}
};
let keys = backend
.authenticate_with_password(
ctx,
&config.authentication_config,
&conn_info.user_info,
password,
)
.await?;
let client = backend
.connect_to_compute(ctx, conn_info.conn_info, keys, !allow_pool)
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
.await?;
// not strictly necessary to mark success here,
// but it's just insurance for if we forget it somewhere else