diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index 84f98cb8ad..3278d9a658 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -5,6 +5,7 @@ mod backend; pub mod cancel_set; mod conn_pool; +mod http_conn_pool; mod http_util; mod json; mod sql_over_http; @@ -81,7 +82,28 @@ pub async fn task_main( } }); + let http_conn_pool = http_conn_pool::GlobalConnPool::new(&config.http_config); + { + let http_conn_pool = Arc::clone(&http_conn_pool); + tokio::spawn(async move { + http_conn_pool.gc_worker(StdRng::from_entropy()).await; + }); + } + + // shutdown the connection pool + tokio::spawn({ + let cancellation_token = cancellation_token.clone(); + let http_conn_pool = http_conn_pool.clone(); + async move { + cancellation_token.cancelled().await; + tokio::task::spawn_blocking(move || http_conn_pool.shutdown()) + .await + .unwrap(); + } + }); + let backend = Arc::new(PoolingBackend { + http_conn_pool: Arc::clone(&http_conn_pool), pool: Arc::clone(&conn_pool), config, endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter), diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 995b7a7cda..3931dbb797 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -1,7 +1,12 @@ -use std::{sync::Arc, time::Duration}; +use std::{io, sync::Arc, time::Duration}; use async_trait::async_trait; -use tracing::{debug, field::display, info}; +use bytes::Bytes; +use http_body_util::Full; +use hyper1::client::conn::http2; +use hyper_util::rt::{TokioExecutor, TokioIo}; +use tokio::net::{lookup_host, TcpStream}; +use tracing::{field::display, info}; use crate::{ auth::{ @@ -27,9 +32,13 @@ use crate::{ Host, }; -use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool}; +use super::{ + conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool}, + http_conn_pool::{self, poll_http2_client}, +}; pub(crate) struct PoolingBackend { + pub(crate) http_conn_pool: Arc, pub(crate) pool: Arc>, pub(crate) 
config: &'static ProxyConfig, pub(crate) endpoint_rate_limiter: Arc, @@ -190,6 +199,39 @@ impl PoolingBackend { ) .await } + + // Return a pooled connection if one exists, else wake up the destination and connect + #[tracing::instrument(fields(conn_id = tracing::field::Empty), skip_all)] + pub(crate) async fn connect_to_local_proxy( + &self, + ctx: &RequestMonitoring, + conn_info: ConnInfo, + keys: ComputeCredentials, + ) -> Result { + info!("pool: looking for an existing connection"); + if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) { + return Ok(client); + } + + let conn_id = uuid::Uuid::new_v4(); + tracing::Span::current().record("conn_id", display(conn_id)); + info!(%conn_id, "pool: opening a new connection '{conn_info}'"); + let backend = self.config.auth_backend.as_ref().map(|()| keys); + crate::proxy::connect_compute::connect_to_compute( + ctx, + &HyperMechanism { + conn_id, + conn_info, + pool: self.http_conn_pool.clone(), + locks: &self.config.connect_compute_locks, + }, + &backend, + false, // do not allow self signed compute for http flow + self.config.wake_compute_retry_config, + self.config.connect_to_compute_retry_config, + ) + .await + } } #[derive(Debug, thiserror::Error)] @@ -198,6 +240,10 @@ pub(crate) enum HttpConnError { ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError), #[error("could not connection to compute")] ConnectionError(#[from] tokio_postgres::Error), + #[error("could not connection to compute")] + IoConnectionError(#[from] std::io::Error), + #[error("could not establish h2 connection to compute")] + H2ConnectionError(#[from] hyper1::Error), #[error("could not get auth info")] GetAuthInfo(#[from] GetAuthInfoError), @@ -214,6 +260,8 @@ impl ReportableError for HttpConnError { match self { HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute, HttpConnError::ConnectionError(p) => p.get_error_kind(), + HttpConnError::IoConnectionError(_) => ErrorKind::Compute, + HttpConnError::H2ConnectionError(_) => ErrorKind::Compute,
HttpConnError::GetAuthInfo(a) => a.get_error_kind(), HttpConnError::AuthError(a) => a.get_error_kind(), HttpConnError::WakeCompute(w) => w.get_error_kind(), @@ -227,6 +275,8 @@ impl UserFacingError for HttpConnError { match self { HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(), HttpConnError::ConnectionError(p) => p.to_string(), + HttpConnError::IoConnectionError(p) => p.to_string(), + HttpConnError::H2ConnectionError(_) => "Could not establish HTTP connection to the database".to_string(), HttpConnError::GetAuthInfo(c) => c.to_string_client(), HttpConnError::AuthError(c) => c.to_string_client(), HttpConnError::WakeCompute(c) => c.to_string_client(), @@ -241,6 +291,8 @@ impl CouldRetry for HttpConnError { fn could_retry(&self) -> bool { match self { HttpConnError::ConnectionError(e) => e.could_retry(), + HttpConnError::IoConnectionError(e) => e.could_retry(), + HttpConnError::H2ConnectionError(_) => false, HttpConnError::ConnectionClosedAbruptly(_) => false, HttpConnError::GetAuthInfo(_) => false, HttpConnError::AuthError(_) => false, @@ -309,3 +361,100 @@ impl ConnectMechanism for TokioMechanism { fn update_connect_config(&self, _config: &mut compute::ConnCfg) {} } + +struct HyperMechanism { + pool: Arc, + conn_info: ConnInfo, + conn_id: uuid::Uuid, + + /// connect_to_compute concurrency lock + locks: &'static ApiLocks, +} + +#[async_trait] +impl ConnectMechanism for HyperMechanism { + type Connection = http_conn_pool::Client; + type ConnectError = HttpConnError; + type Error = HttpConnError; + + async fn connect_once( + &self, + ctx: &RequestMonitoring, + node_info: &CachedNodeInfo, + timeout: Duration, + ) -> Result { + let host = node_info.config.get_host()?; + let permit = self.locks.get_permit(&host).await?; + + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); + + // let port = node_info.config.get_ports().first().unwrap_or_else(10432); + let res = connect_http2(&host, 10432, timeout).await; + drop(pause); + let (client, 
connection) = permit.release_result(res)?; + + Ok(poll_http2_client( + self.pool.clone(), + ctx, + self.conn_info.clone(), + client, + connection, + self.conn_id, + node_info.aux.clone(), + )) + } + + fn update_connect_config(&self, _config: &mut compute::ConnCfg) {} +} + +async fn connect_http2( + host: &str, + port: u16, + timeout: Duration, +) -> Result< + ( + http2::SendRequest>, + http2::Connection, Full, TokioExecutor>, + ), + HttpConnError, +> { + let mut addrs = lookup_host((host, port)).await?; + + let mut last_err = None; + + let stream = loop { + let Some(addr) = addrs.next() else { + return Err(last_err.unwrap_or_else(|| { + HttpConnError::IoConnectionError(io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve any addresses", + )) + })); + }; + + let stream = match tokio::time::timeout(timeout, TcpStream::connect(addr)).await { + Ok(Ok(stream)) => stream, + Ok(Err(e)) => { + last_err = Some(HttpConnError::IoConnectionError(e)); + continue; + } + Err(e) => { + last_err = Some(HttpConnError::IoConnectionError(io::Error::new( + io::ErrorKind::TimedOut, + e, + ))); + continue; + } + }; + + stream.set_nodelay(true)?; + + break stream; + }; + + let (client, connection) = hyper1::client::conn::http2::Builder::new(TokioExecutor::new()) + .handshake(TokioIo::new(stream)) + .await?; + + Ok((client, connection)) +} diff --git a/proxy/src/serverless/http_conn_pool.rs b/proxy/src/serverless/http_conn_pool.rs new file mode 100644 index 0000000000..1c92f86dc9 --- /dev/null +++ b/proxy/src/serverless/http_conn_pool.rs @@ -0,0 +1,360 @@ +use bytes::Bytes; +use dashmap::DashMap; +use http_body_util::Full; +use hyper1::client::conn::http2; +use hyper_util::rt::{TokioExecutor, TokioIo}; +use parking_lot::RwLock; +use rand::Rng; +use std::collections::VecDeque; +use std::{ + ops::Deref, + sync::atomic::{self, AtomicUsize}, +}; +use std::{sync::Arc, sync::Weak}; +use tokio::net::TcpStream; + +use crate::console::messages::{ColdStartInfo, MetricsAuxInfo}; +use 
crate::metrics::{HttpEndpointPoolsGuard, Metrics}; +use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; +use crate::{context::RequestMonitoring, EndpointCacheKey}; + +use tracing::{debug, error, Span}; +use tracing::{info, info_span, Instrument}; + +use super::conn_pool::ConnInfo; + +#[derive(Clone)] +struct ConnPoolEntry { + conn: http2::SendRequest>, + conn_id: uuid::Uuid, + aux: MetricsAuxInfo, +} + +// Per-endpoint connection pool +// Number of open connections is limited by the `max_conns_per_endpoint`. +pub(crate) struct EndpointConnPool { + conns: VecDeque, + _guard: HttpEndpointPoolsGuard<'static>, + global_connections_count: Arc, +} + +impl EndpointConnPool { + fn get_conn_entry(&mut self) -> Option { + let Self { conns, .. } = self; + + let conn = conns.pop_front()?; + conns.push_back(conn.clone()); + Some(conn) + } + + fn remove_conn(&mut self, conn_id: uuid::Uuid) -> bool { + let Self { + conns, + global_connections_count, + .. + } = self; + + let old_len = conns.len(); + conns.retain(|conn| conn.conn_id != conn_id); + let new_len = conns.len(); + let removed = old_len - new_len; + if removed > 0 { + global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed); + Metrics::get() + .proxy + .http_pool_opened_connections + .get_metric() + .dec_by(removed as i64); + } + removed > 0 + } +} + +impl Drop for EndpointConnPool { + fn drop(&mut self) { + if !self.conns.is_empty() { + self.global_connections_count + .fetch_sub(self.conns.len(), atomic::Ordering::Relaxed); + Metrics::get() + .proxy + .http_pool_opened_connections + .get_metric() + .dec_by(self.conns.len() as i64); + } + } +} + +pub(crate) struct GlobalConnPool { + // endpoint -> per-endpoint connection pool + // + // That should be a fairly contended map, so return a reference to the per-endpoint + // pool as early as possible and release the lock.
+ global_pool: DashMap>>, + + /// Number of endpoint-connection pools + /// + /// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each. + /// That seems like far too much effort, so we're using a relaxed increment counter instead. + /// It's only used for diagnostics. + global_pool_size: AtomicUsize, + + /// Total number of connections in the pool + global_connections_count: Arc, + + config: &'static crate::config::HttpConfig, +} + +impl GlobalConnPool { + pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc { + let shards = config.pool_options.pool_shards; + Arc::new(Self { + global_pool: DashMap::with_shard_amount(shards), + global_pool_size: AtomicUsize::new(0), + config, + global_connections_count: Arc::new(AtomicUsize::new(0)), + }) + } + + pub(crate) fn shutdown(&self) { + // drops all strong references to endpoint-pools + self.global_pool.clear(); + } + + pub(crate) async fn gc_worker(&self, mut rng: impl Rng) { + let epoch = self.config.pool_options.gc_epoch; + let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32); + loop { + interval.tick().await; + + let shard = rng.gen_range(0..self.global_pool.shards().len()); + self.gc(shard); + } + } + + fn gc(&self, shard: usize) { + debug!(shard, "pool: performing epoch reclamation"); + + // acquire a random shard lock + let mut shard = self.global_pool.shards()[shard].write(); + + let timer = Metrics::get() + .proxy + .http_pool_reclaimation_lag_seconds + .start_timer(); + let current_len = shard.len(); + let mut clients_removed = 0; + shard.retain(|endpoint, x| { + // if the current endpoint pool is unique (no other strong or weak references) + // then it is currently not in use by any connections. + if let Some(pool) = Arc::get_mut(x.get_mut()) { + let EndpointConnPool { conns, .. 
} = pool.get_mut(); + + let old_len = conns.len(); + + conns.retain(|conn| !conn.conn.is_closed()); + + let new_len = conns.len(); + let removed = old_len - new_len; + clients_removed += removed; + + // we only remove this pool if it has no active connections + if conns.is_empty() { + info!("pool: discarding pool for endpoint {endpoint}"); + return false; + } + } + + true + }); + + let new_len = shard.len(); + drop(shard); + timer.observe(); + + // Do logging outside of the lock. + if clients_removed > 0 { + let size = self + .global_connections_count + .fetch_sub(clients_removed, atomic::Ordering::Relaxed) + - clients_removed; + Metrics::get() + .proxy + .http_pool_opened_connections + .get_metric() + .dec_by(clients_removed as i64); + info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}"); + } + let removed = current_len - new_len; + + if removed > 0 { + let global_pool_size = self + .global_pool_size + .fetch_sub(removed, atomic::Ordering::Relaxed) + - removed; + info!("pool: performed global pool gc. 
size now {global_pool_size}"); + } + } + + pub(crate) fn get( + self: &Arc, + ctx: &RequestMonitoring, + conn_info: &ConnInfo, + ) -> Option { + let endpoint = conn_info.endpoint_cache_key()?; + let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint); + let client = endpoint_pool.write().get_conn_entry()?; + + if client.conn.is_closed() { + info!("pool: cached connection '{conn_info}' is closed, opening a new one"); + return None; + } + tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id)); + info!( + cold_start_info = ColdStartInfo::HttpPoolHit.as_str(), + "pool: reusing connection '{conn_info}'" + ); + ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit); + ctx.success(); + Some(Client::new(client.conn, conn_info.clone(), client.aux)) + } + + fn get_or_create_endpoint_pool( + self: &Arc, + endpoint: &EndpointCacheKey, + ) -> Arc> { + // fast path + if let Some(pool) = self.global_pool.get(endpoint) { + return pool.clone(); + } + + // slow path + let new_pool = Arc::new(RwLock::new(EndpointConnPool { + conns: VecDeque::new(), + _guard: Metrics::get().proxy.http_endpoint_pools.guard(), + global_connections_count: self.global_connections_count.clone(), + })); + + // find or create a pool for this endpoint + let mut created = false; + let pool = self + .global_pool + .entry(endpoint.clone()) + .or_insert_with(|| { + created = true; + new_pool + }) + .clone(); + + // log new global pool size + if created { + let global_pool_size = self + .global_pool_size + .fetch_add(1, atomic::Ordering::Relaxed) + + 1; + info!( + "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}" + ); + } + + pool + } +} + +pub(crate) fn poll_http2_client( + global_pool: Arc, + ctx: &RequestMonitoring, + conn_info: ConnInfo, + client: http2::SendRequest>, + connection: http2::Connection, Full, TokioExecutor>, + conn_id: uuid::Uuid, + aux: MetricsAuxInfo, +) -> Client { + let conn_gauge = 
Metrics::get().proxy.db_connections.guard(ctx.protocol()); + let session_id = ctx.session_id(); + + let span = info_span!(parent: None, "connection", %conn_id); + let cold_start_info = ctx.cold_start_info(); + span.in_scope(|| { + info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection"); + }); + + let pool = match conn_info.endpoint_cache_key() { + Some(endpoint) => { + let pool = global_pool.get_or_create_endpoint_pool(&endpoint); + + pool.write().conns.push_back(ConnPoolEntry { + conn: client.clone(), + conn_id, + aux: aux.clone(), + }); + + Arc::downgrade(&pool) + } + None => Weak::new(), + }; + + // let idle = global_pool.get_idle_timeout(); + + tokio::spawn( + async move { + let _conn_gauge = conn_gauge; + let res = connection.await; + match res { + Ok(()) => info!("connection closed"), + Err(e) => error!(%session_id, "connection error: {}", e), + } + + // remove from connection pool + if let Some(pool) = pool.clone().upgrade() { + if pool.write().remove_conn(conn_id) { + info!("closed connection removed"); + } + } + } + .instrument(span), + ); + + Client::new(client, conn_info, aux) +} + +impl Client { + pub(crate) fn metrics(&self) -> Arc { + USAGE_METRICS.register(Ids { + endpoint_id: self.aux.endpoint_id, + branch_id: self.aux.branch_id, + }) + } +} + +pub(crate) struct Client { + span: Span, + inner: http2::SendRequest>, + aux: MetricsAuxInfo, + conn_info: ConnInfo, +} + +impl Client { + pub(self) fn new( + inner: http2::SendRequest>, + conn_info: ConnInfo, + aux: MetricsAuxInfo, + ) -> Self { + Self { + inner, + span: Span::current(), + conn_info, + aux, + } + } + pub(crate) fn inner(&mut self) -> &mut http2::SendRequest> { + &mut self.inner + } +} + +impl Deref for Client { + type Target = http2::SendRequest>; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 97e280d252..dbe1df8bb0 100644 --- 
a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -514,18 +514,60 @@ async fn handle_inner( "handling interactive connection from client" ); - // - // Determine the destination and connection params - // - let headers = request.headers(); - - // TLS config should be there. - let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref())?; + let conn_info = get_conn_info(ctx, request.headers(), config.tls_config.as_ref())?; info!( user = conn_info.conn_info.user_info.user.as_str(), "credentials" ); + match conn_info.auth { + AuthData::Password(pw) => { + let res = handle_db_inner( + cancel, + config, + ctx, + request, + conn_info.conn_info, + &pw, + backend, + ) + .await?; + Ok(res) + } + AuthData::Jwt(jwt) => { + let keys = backend + .authenticate_with_jwt( + ctx, + &config.authentication_config, + &conn_info.conn_info.user_info, + jwt, + ) + .await + .map_err(HttpConnError::from)?; + + let _client = backend + .connect_to_local_proxy(ctx, conn_info.conn_info, keys) + .await?; + + todo!() + } + } +} + +async fn handle_db_inner( + cancel: CancellationToken, + config: &'static ProxyConfig, + ctx: &RequestMonitoring, + request: Request, + conn_info: ConnInfo, + password: &[u8], + backend: Arc, +) -> Result>, SqlOverHttpError> { + // + // Determine the destination and connection params + // + let headers = request.headers(); + // Allow connection pooling only if explicitly requested // or if we have decided that http pool is no longer opt-in let allow_pool = !config.http_config.pool_options.opt_in @@ -563,31 +605,17 @@ async fn handle_inner( let authenticate_and_connect = Box::pin( async { - let keys = match conn_info.auth { - AuthData::Password(pw) => { - backend - .authenticate_with_password( - ctx, - &config.authentication_config, - &conn_info.conn_info.user_info, - &pw, - ) - .await? 
- } - AuthData::Jwt(jwt) => { - backend - .authenticate_with_jwt( - ctx, - &config.authentication_config, - &conn_info.conn_info.user_info, - jwt, - ) - .await? - } - }; + let keys = backend + .authenticate_with_password( + ctx, + &config.authentication_config, + &conn_info.user_info, + password, + ) + .await?; let client = backend - .connect_to_compute(ctx, conn_info.conn_info, keys, !allow_pool) + .connect_to_compute(ctx, conn_info, keys, !allow_pool) .await?; // not strictly necessary to mark success here, // but it's just insurance for if we forget it somewhere else