From c63e3e7e84c2dd9c9792619cc4fee15b07cfe7d7 Mon Sep 17 00:00:00 2001 From: Anna Khanova <32508607+khanova@users.noreply.github.com> Date: Thu, 8 Feb 2024 12:57:05 +0100 Subject: [PATCH] Proxy: improve http-pool (#6577) ## Problem The password check logic for the sql-over-http is a bit non-intuitive. ## Summary of changes 1. Perform scram auth using the same logic as for websocket cleartext password. 2. Split establish connection logic and connection pool. 3. Parallelize param parsing logic with authentication + wake compute. 4. Limit the total number of clients --- Cargo.lock | 1 + proxy/Cargo.toml | 1 + proxy/src/auth/backend.rs | 12 + proxy/src/auth/flow.rs | 2 +- proxy/src/bin/proxy.rs | 5 + proxy/src/console/provider/neon.rs | 2 + proxy/src/context.rs | 4 + proxy/src/metrics.rs | 44 +- proxy/src/proxy/connect_compute.rs | 22 +- proxy/src/proxy/tests.rs | 3 + proxy/src/serverless.rs | 41 +- proxy/src/serverless/backend.rs | 157 +++++ proxy/src/serverless/conn_pool.rs | 797 +++++++++++++------------- proxy/src/serverless/json.rs | 28 +- proxy/src/serverless/sql_over_http.rs | 92 ++- test_runner/regress/test_proxy.py | 20 +- 16 files changed, 753 insertions(+), 478 deletions(-) create mode 100644 proxy/src/serverless/backend.rs diff --git a/Cargo.lock b/Cargo.lock index 30e233ecc1..c0c319cd89 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4079,6 +4079,7 @@ dependencies = [ "clap", "consumption_metrics", "dashmap", + "env_logger", "futures", "git-version", "hashbrown 0.13.2", diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 1247f08ee6..83cab381b3 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -19,6 +19,7 @@ chrono.workspace = true clap.workspace = true consumption_metrics.workspace = true dashmap.workspace = true +env_logger.workspace = true futures.workspace = true git-version.workspace = true hashbrown.workspace = true diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 236567163e..fa2782bee3 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -68,6 +68,7 @@ pub trait TestBackend: Send + Sync + 'static { fn get_allowed_ips_and_secret( &self, ) -> Result<(CachedAllowedIps, Option), console::errors::GetAuthInfoError>; + fn get_role_secret(&self) -> Result; } impl std::fmt::Display for BackendType<'_, ()> { @@ -358,6 +359,17 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { } impl BackendType<'_, ComputeUserInfo> { + pub async fn get_role_secret( + &self, + ctx: &mut RequestMonitoring, + ) -> Result { + use BackendType::*; + match self { + Console(api, user_info) => api.get_role_secret(ctx, user_info).await, + Link(_) => Ok(Cached::new_uncached(None)), + } + } + pub async fn get_allowed_ips_and_secret( &self, ctx: &mut RequestMonitoring, diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 077178d107..c2783e236c 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -167,7 +167,7 @@ impl AuthFlow<'_, S, Scram<'_>> { } } -pub(super) fn validate_password_and_exchange( +pub(crate) fn validate_password_and_exchange( password: &[u8], secret: AuthSecret, ) -> super::Result> { diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 3bbb87808d..6974f1a274 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -165,6 +165,10 @@ struct SqlOverHttpArgs { #[clap(long, default_value_t = 20)] sql_over_http_pool_max_conns_per_endpoint: usize, + /// How many connections to pool for each endpoint. 
Excess connections are discarded + #[clap(long, default_value_t = 20000)] + sql_over_http_pool_max_total_conns: usize, + /// How long pooled connections should remain idle for before closing #[clap(long, default_value = "5m", value_parser = humantime::parse_duration)] sql_over_http_idle_timeout: tokio::time::Duration, @@ -387,6 +391,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { pool_shards: args.sql_over_http.sql_over_http_pool_shards, idle_timeout: args.sql_over_http.sql_over_http_idle_timeout, opt_in: args.sql_over_http.sql_over_http_pool_opt_in, + max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns, }, }; let authentication_config = AuthenticationConfig { diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index 0785419790..71b34cb676 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -188,6 +188,7 @@ impl super::Api for Api { ep, Arc::new(auth_info.allowed_ips), ); + ctx.set_project_id(project_id); } // When we just got a secret, we don't need to invalidate it. Ok(Cached::new_uncached(auth_info.secret)) @@ -221,6 +222,7 @@ impl super::Api for Api { self.caches .project_info .insert_allowed_ips(&project_id, ep, allowed_ips.clone()); + ctx.set_project_id(project_id); } Ok(( Cached::new_uncached(allowed_ips), diff --git a/proxy/src/context.rs b/proxy/src/context.rs index e2b0294cd3..fe204534b7 100644 --- a/proxy/src/context.rs +++ b/proxy/src/context.rs @@ -89,6 +89,10 @@ impl RequestMonitoring { self.project = Some(x.project_id); } + pub fn set_project_id(&mut self, project_id: ProjectId) { + self.project = Some(project_id); + } + pub fn set_endpoint_id(&mut self, endpoint_id: EndpointId) { crate::metrics::CONNECTING_ENDPOINTS .with_label_values(&[self.protocol]) diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index fa663d8ff6..e2d96a9c27 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -1,8 +1,10 @@ use ::metrics::{ exponential_buckets, register_histogram, register_histogram_vec, register_hll_vec, - register_int_counter_pair_vec, register_int_counter_vec, register_int_gauge_vec, Histogram, - HistogramVec, HyperLogLogVec, IntCounterPairVec, IntCounterVec, IntGaugeVec, + register_int_counter_pair_vec, register_int_counter_vec, register_int_gauge, + register_int_gauge_vec, Histogram, HistogramVec, HyperLogLogVec, IntCounterPairVec, + IntCounterVec, IntGauge, IntGaugeVec, }; +use metrics::{register_int_counter_pair, IntCounterPair}; use once_cell::sync::Lazy; use tokio::time; @@ -112,6 +114,44 @@ pub static ALLOWED_IPS_NUMBER: Lazy = Lazy::new(|| { .unwrap() }); +pub static HTTP_CONTENT_LENGTH: Lazy = Lazy::new(|| { + register_histogram!( + "proxy_http_conn_content_length_bytes", + "Time it took for proxy to establish a connection to the compute endpoint", + // largest bucket = 3^16 * 0.05ms = 2.15s + exponential_buckets(8.0, 2.0, 20).unwrap() + ) + .unwrap() +}); + +pub static GC_LATENCY: Lazy = Lazy::new(|| { + register_histogram!( + "proxy_http_pool_reclaimation_lag_seconds", + "Time it takes to reclaim unused connection pools", + // 1us -> 65ms + exponential_buckets(1e-6, 2.0, 16).unwrap(), + ) + .unwrap() +}); + +pub static ENDPOINT_POOLS: Lazy = Lazy::new(|| { + register_int_counter_pair!( + "proxy_http_pool_endpoints_registered_total", + "Number of endpoints we have registered pools for", + "proxy_http_pool_endpoints_unregistered_total", + "Number of endpoints we have unregistered pools for", + ) + .unwrap() +}); + +pub static 
NUM_OPEN_CLIENTS_IN_HTTP_POOL: Lazy = Lazy::new(|| { + register_int_gauge!( + "proxy_http_pool_opened_connections", + "Number of opened connections to a database.", + ) + .unwrap() +}); + #[derive(Clone)] pub struct LatencyTimer { // time since the stopwatch was started diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 58c59dba36..b9346aa743 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -34,21 +34,6 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg node_info.invalidate().config } -/// Try to connect to the compute node once. -#[tracing::instrument(name = "connect_once", fields(pid = tracing::field::Empty), skip_all)] -async fn connect_to_compute_once( - ctx: &mut RequestMonitoring, - node_info: &console::CachedNodeInfo, - timeout: time::Duration, -) -> Result { - let allow_self_signed_compute = node_info.allow_self_signed_compute; - - node_info - .config - .connect(ctx, allow_self_signed_compute, timeout) - .await -} - #[async_trait] pub trait ConnectMechanism { type Connection; @@ -75,13 +60,18 @@ impl ConnectMechanism for TcpMechanism<'_> { type ConnectError = compute::ConnectionError; type Error = compute::ConnectionError; + #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)] async fn connect_once( &self, ctx: &mut RequestMonitoring, node_info: &console::CachedNodeInfo, timeout: time::Duration, ) -> Result { - connect_to_compute_once(ctx, node_info, timeout).await + let allow_self_signed_compute = node_info.allow_self_signed_compute; + node_info + .config + .connect(ctx, allow_self_signed_compute, timeout) + .await } fn update_connect_config(&self, config: &mut compute::ConnCfg) { diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 2000774224..656cabac75 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -478,6 +478,9 @@ impl TestBackend for TestConnectMechanism { { unimplemented!("not used in tests") } + fn get_role_secret(&self) -> Result { + unimplemented!("not used in tests") + } } fn helper_create_cached_node_info() -> CachedNodeInfo { diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index 7ff93b23b8..58aa925a6a 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -2,6 +2,7 @@ //! //! Handles both SQL over HTTP and SQL over Websockets. 
+mod backend; mod conn_pool; mod json; mod sql_over_http; @@ -18,11 +19,11 @@ pub use reqwest_middleware::{ClientWithMiddleware, Error}; pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; use tokio_util::task::TaskTracker; -use crate::config::TlsConfig; use crate::context::RequestMonitoring; use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE; use crate::protocol2::{ProxyProtocolAccept, WithClientIp}; use crate::rate_limiter::EndpointRateLimiter; +use crate::serverless::backend::PoolingBackend; use crate::{cancellation::CancelMap, config::ProxyConfig}; use futures::StreamExt; use hyper::{ @@ -54,12 +55,13 @@ pub async fn task_main( info!("websocket server has shut down"); } - let conn_pool = conn_pool::GlobalConnPool::new(config); - - let conn_pool2 = Arc::clone(&conn_pool); - tokio::spawn(async move { - conn_pool2.gc_worker(StdRng::from_entropy()).await; - }); + let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config); + { + let conn_pool = Arc::clone(&conn_pool); + tokio::spawn(async move { + conn_pool.gc_worker(StdRng::from_entropy()).await; + }); + } // shutdown the connection pool tokio::spawn({ @@ -73,6 +75,11 @@ pub async fn task_main( } }); + let backend = Arc::new(PoolingBackend { + pool: Arc::clone(&conn_pool), + config, + }); + let tls_config = match config.tls_config.as_ref() { Some(config) => config, None => { @@ -106,7 +113,7 @@ pub async fn task_main( let client_addr = io.client_addr(); let remote_addr = io.inner.remote_addr(); let sni_name = tls.server_name().map(|s| s.to_string()); - let conn_pool = conn_pool.clone(); + let backend = backend.clone(); let ws_connections = ws_connections.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); @@ -119,7 +126,7 @@ pub async fn task_main( Ok(MetricService::new(hyper::service::service_fn( move |req: Request| { let sni_name = sni_name.clone(); - let conn_pool = conn_pool.clone(); + let backend = backend.clone(); let ws_connections = ws_connections.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); @@ -130,8 +137,7 @@ pub async fn task_main( request_handler( req, config, - tls_config, - conn_pool, + backend, ws_connections, cancel_map, session_id, @@ -200,8 +206,7 @@ where async fn request_handler( mut request: Request, config: &'static ProxyConfig, - tls: &'static TlsConfig, - conn_pool: Arc, + backend: Arc, ws_connections: TaskTracker, cancel_map: Arc, session_id: uuid::Uuid, @@ -248,15 +253,7 @@ async fn request_handler( } else if request.uri().path() == "/sql" && request.method() == Method::POST { let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region); - sql_over_http::handle( - tls, - &config.http_config, - &mut ctx, - request, - sni_hostname, - conn_pool, - ) - .await + sql_over_http::handle(config, &mut ctx, request, sni_hostname, backend).await } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS { Response::builder() .header("Allow", "OPTIONS, POST") diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs new file mode 100644 index 0000000000..466a74f0ea --- /dev/null +++ b/proxy/src/serverless/backend.rs @@ -0,0 +1,157 @@ +use std::{sync::Arc, time::Duration}; + +use anyhow::Context; +use async_trait::async_trait; +use tracing::info; + +use crate::{ + auth::{backend::ComputeCredentialKeys, check_peer_addr_is_in_list, AuthError}, + compute, + config::ProxyConfig, + console::CachedNodeInfo, + context::RequestMonitoring, + proxy::connect_compute::ConnectMechanism, +}; + +use 
super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool, APP_NAME}; + +pub struct PoolingBackend { + pub pool: Arc>, + pub config: &'static ProxyConfig, +} + +impl PoolingBackend { + pub async fn authenticate( + &self, + ctx: &mut RequestMonitoring, + conn_info: &ConnInfo, + ) -> Result { + let user_info = conn_info.user_info.clone(); + let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone()); + let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?; + if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) { + return Err(AuthError::ip_address_not_allowed()); + } + let cached_secret = match maybe_secret { + Some(secret) => secret, + None => backend.get_role_secret(ctx).await?, + }; + + let secret = match cached_secret.value.clone() { + Some(secret) => secret, + None => { + // If we don't have an authentication secret, for the http flow we can just return an error. + info!("authentication info not found"); + return Err(AuthError::auth_failed(&*user_info.user)); + } + }; + let auth_outcome = + crate::auth::validate_password_and_exchange(conn_info.password.as_bytes(), secret)?; + match auth_outcome { + crate::sasl::Outcome::Success(key) => Ok(key), + crate::sasl::Outcome::Failure(reason) => { + info!("auth backend failed with an error: {reason}"); + Err(AuthError::auth_failed(&*conn_info.user_info.user)) + } + } + } + + // Wake up the destination if needed. Code here is a bit involved because + // we reuse the code from the usual proxy and we need to prepare few structures + // that this code expects. + #[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)] + pub async fn connect_to_compute( + &self, + ctx: &mut RequestMonitoring, + conn_info: ConnInfo, + keys: ComputeCredentialKeys, + force_new: bool, + ) -> anyhow::Result> { + let maybe_client = if !force_new { + info!("pool: looking for an existing connection"); + self.pool.get(ctx, &conn_info).await? + } else { + info!("pool: pool is disabled"); + None + }; + + if let Some(client) = maybe_client { + return Ok(client); + } + let conn_id = uuid::Uuid::new_v4(); + info!(%conn_id, "pool: opening a new connection '{conn_info}'"); + ctx.set_application(Some(APP_NAME)); + let backend = self + .config + .auth_backend + .as_ref() + .map(|_| conn_info.user_info.clone()); + + let mut node_info = backend + .wake_compute(ctx) + .await? 
+ .context("missing cache entry from wake_compute")?; + + match keys { + #[cfg(any(test, feature = "testing"))] + ComputeCredentialKeys::Password(password) => node_info.config.password(password), + ComputeCredentialKeys::AuthKeys(auth_keys) => node_info.config.auth_keys(auth_keys), + }; + + ctx.set_project(node_info.aux.clone()); + + crate::proxy::connect_compute::connect_to_compute( + ctx, + &TokioMechanism { + conn_id, + conn_info, + pool: self.pool.clone(), + }, + node_info, + &backend, + ) + .await + } +} + +struct TokioMechanism { + pool: Arc>, + conn_info: ConnInfo, + conn_id: uuid::Uuid, +} + +#[async_trait] +impl ConnectMechanism for TokioMechanism { + type Connection = Client; + type ConnectError = tokio_postgres::Error; + type Error = anyhow::Error; + + async fn connect_once( + &self, + ctx: &mut RequestMonitoring, + node_info: &CachedNodeInfo, + timeout: Duration, + ) -> Result { + let mut config = (*node_info.config).clone(); + let config = config + .user(&self.conn_info.user_info.user) + .password(&*self.conn_info.password) + .dbname(&self.conn_info.dbname) + .connect_timeout(timeout); + + let (client, connection) = config.connect(tokio_postgres::NoTls).await?; + + tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id())); + Ok(poll_client( + self.pool.clone(), + ctx, + self.conn_info.clone(), + client, + connection, + self.conn_id, + node_info.aux.clone(), + )) + } + + fn update_connect_config(&self, _config: &mut compute::ConnCfg) {} +} diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 312fa2b36f..a7b2c532d2 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -1,15 +1,7 @@ -use anyhow::Context; -use async_trait::async_trait; use dashmap::DashMap; use futures::{future::poll_fn, Future}; -use metrics::{register_int_counter_pair, IntCounterPair, IntCounterPairGuard}; -use once_cell::sync::Lazy; +use metrics::IntCounterPairGuard; use parking_lot::RwLock; -use pbkdf2::{ - password_hash::{PasswordHashString, PasswordHasher, PasswordVerifier, SaltString}, - Params, Pbkdf2, -}; -use prometheus::{exponential_buckets, register_histogram, Histogram}; use rand::Rng; use smol_str::SmolStr; use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration}; @@ -21,19 +13,17 @@ use std::{ ops::Deref, sync::atomic::{self, AtomicUsize}, }; -use tokio::time::{self, Instant}; -use tokio_postgres::{AsyncMessage, ReadyForQueryStatus}; +use tokio::time::Instant; +use tokio_postgres::tls::NoTlsStream; +use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket}; +use crate::console::messages::MetricsAuxInfo; +use crate::metrics::{ENDPOINT_POOLS, GC_LATENCY, NUM_OPEN_CLIENTS_IN_HTTP_POOL}; +use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; use crate::{ - auth::{self, backend::ComputeUserInfo, check_peer_addr_is_in_list}, - console::{self, messages::MetricsAuxInfo}, - context::RequestMonitoring, - metrics::NUM_DB_CONNECTIONS_GAUGE, - proxy::connect_compute::ConnectMechanism, - usage_metrics::{Ids, MetricCounter, USAGE_METRICS}, + auth::backend::ComputeUserInfo, context::RequestMonitoring, metrics::NUM_DB_CONNECTIONS_GAUGE, DbName, EndpointCacheKey, RoleName, }; -use crate::{compute, config}; use tracing::{debug, error, warn, Span}; use tracing::{info, info_span, Instrument}; @@ -72,39 +62,51 @@ impl fmt::Display for ConnInfo { } } -struct ConnPoolEntry { - conn: ClientInner, +struct ConnPoolEntry { + conn: ClientInner, _last_access: std::time::Instant, } // 
Per-endpoint connection pool, (dbname, username) -> DbUserConnPool // Number of open connections is limited by the `max_conns_per_endpoint`. -pub struct EndpointConnPool { - pools: HashMap<(DbName, RoleName), DbUserConnPool>, +pub struct EndpointConnPool { + pools: HashMap<(DbName, RoleName), DbUserConnPool>, total_conns: usize, max_conns: usize, _guard: IntCounterPairGuard, + global_connections_count: Arc, + global_pool_size_max_conns: usize, } -impl EndpointConnPool { - fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option { +impl EndpointConnPool { + fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option> { let Self { - pools, total_conns, .. + pools, + total_conns, + global_connections_count, + .. } = self; - pools - .get_mut(&db_user) - .and_then(|pool_entries| pool_entries.get_conn_entry(total_conns)) + pools.get_mut(&db_user).and_then(|pool_entries| { + pool_entries.get_conn_entry(total_conns, global_connections_count.clone()) + }) } fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool { let Self { - pools, total_conns, .. + pools, + total_conns, + global_connections_count, + .. } = self; if let Some(pool) = pools.get_mut(&db_user) { let old_len = pool.conns.len(); pool.conns.retain(|conn| conn.conn.conn_id != conn_id); let new_len = pool.conns.len(); let removed = old_len - new_len; + if removed > 0 { + global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed); + NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(removed as i64); + } *total_conns -= removed; removed > 0 } else { @@ -112,13 +114,27 @@ impl EndpointConnPool { } } - fn put(pool: &RwLock, conn_info: &ConnInfo, client: ClientInner) -> anyhow::Result<()> { + fn put( + pool: &RwLock, + conn_info: &ConnInfo, + client: ClientInner, + ) -> anyhow::Result<()> { let conn_id = client.conn_id; - if client.inner.is_closed() { + if client.is_closed() { info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed"); return Ok(()); } + let global_max_conn = pool.read().global_pool_size_max_conns; + if pool + .read() + .global_connections_count + .load(atomic::Ordering::Relaxed) + >= global_max_conn + { + info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full"); + return Ok(()); + } // return connection to the pool let mut returned = false; @@ -127,18 +143,19 @@ impl EndpointConnPool { let mut pool = pool.write(); if pool.total_conns < pool.max_conns { - // we create this db-user entry in get, so it should not be None - if let Some(pool_entries) = pool.pools.get_mut(&conn_info.db_and_user()) { - pool_entries.conns.push(ConnPoolEntry { - conn: client, - _last_access: std::time::Instant::now(), - }); + let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default(); + pool_entries.conns.push(ConnPoolEntry { + conn: client, + _last_access: std::time::Instant::now(), + }); - returned = true; - per_db_size = pool_entries.conns.len(); + returned = true; + per_db_size = pool_entries.conns.len(); - pool.total_conns += 1; - } + pool.total_conns += 1; + pool.global_connections_count + .fetch_add(1, atomic::Ordering::Relaxed); + NUM_OPEN_CLIENTS_IN_HTTP_POOL.inc(); } pool.total_conns @@ -155,49 +172,61 @@ impl EndpointConnPool { } } -/// 4096 is the number of rounds that SCRAM-SHA-256 recommends. -/// It's not the 600,000 that OWASP recommends... but our passwords are high entropy anyway. -/// -/// Still takes 1.4ms to hash on my hardware. 
-/// We don't want to ruin the latency improvements of using the pool by making password verification take too long -const PARAMS: Params = Params { - rounds: 4096, - output_length: 32, -}; - -#[derive(Default)] -pub struct DbUserConnPool { - conns: Vec, - password_hash: Option, +impl Drop for EndpointConnPool { + fn drop(&mut self) { + if self.total_conns > 0 { + self.global_connections_count + .fetch_sub(self.total_conns, atomic::Ordering::Relaxed); + NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(self.total_conns as i64); + } + } } -impl DbUserConnPool { - fn clear_closed_clients(&mut self, conns: &mut usize) { +pub struct DbUserConnPool { + conns: Vec>, +} + +impl Default for DbUserConnPool { + fn default() -> Self { + Self { conns: Vec::new() } + } +} + +impl DbUserConnPool { + fn clear_closed_clients(&mut self, conns: &mut usize) -> usize { let old_len = self.conns.len(); - self.conns.retain(|conn| !conn.conn.inner.is_closed()); + self.conns.retain(|conn| !conn.conn.is_closed()); let new_len = self.conns.len(); let removed = old_len - new_len; *conns -= removed; + removed } - fn get_conn_entry(&mut self, conns: &mut usize) -> Option { - self.clear_closed_clients(conns); + fn get_conn_entry( + &mut self, + conns: &mut usize, + global_connections_count: Arc, + ) -> Option> { + let mut removed = self.clear_closed_clients(conns); let conn = self.conns.pop(); if conn.is_some() { *conns -= 1; + removed += 1; } + global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed); + NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(removed as i64); conn } } -pub struct GlobalConnPool { +pub struct GlobalConnPool { // endpoint -> per-endpoint connection pool // // That should be a fairly conteded map, so return reference to the per-endpoint // pool as early as possible and release the lock. - global_pool: DashMap>>, + global_pool: DashMap>>>, /// Number of endpoint-connection pools /// @@ -206,7 +235,10 @@ pub struct GlobalConnPool { /// It's only used for diagnostics. global_pool_size: AtomicUsize, - proxy_config: &'static crate::config::ProxyConfig, + /// Total number of connections in the pool + global_connections_count: Arc, + + config: &'static crate::config::HttpConfig, } #[derive(Debug, Clone, Copy)] @@ -224,45 +256,39 @@ pub struct GlobalConnPoolOptions { pub idle_timeout: Duration, pub opt_in: bool, + + // Total number of connections in the pool. 
+ pub max_total_conns: usize, } -pub static GC_LATENCY: Lazy = Lazy::new(|| { - register_histogram!( - "proxy_http_pool_reclaimation_lag_seconds", - "Time it takes to reclaim unused connection pools", - // 1us -> 65ms - exponential_buckets(1e-6, 2.0, 16).unwrap(), - ) - .unwrap() -}); - -pub static ENDPOINT_POOLS: Lazy = Lazy::new(|| { - register_int_counter_pair!( - "proxy_http_pool_endpoints_registered_total", - "Number of endpoints we have registered pools for", - "proxy_http_pool_endpoints_unregistered_total", - "Number of endpoints we have unregistered pools for", - ) - .unwrap() -}); - -impl GlobalConnPool { - pub fn new(config: &'static crate::config::ProxyConfig) -> Arc { - let shards = config.http_config.pool_options.pool_shards; +impl GlobalConnPool { + pub fn new(config: &'static crate::config::HttpConfig) -> Arc { + let shards = config.pool_options.pool_shards; Arc::new(Self { global_pool: DashMap::with_shard_amount(shards), global_pool_size: AtomicUsize::new(0), - proxy_config: config, + config, + global_connections_count: Arc::new(AtomicUsize::new(0)), }) } + #[cfg(test)] + pub fn get_global_connections_count(&self) -> usize { + self.global_connections_count + .load(atomic::Ordering::Relaxed) + } + + pub fn get_idle_timeout(&self) -> Duration { + self.config.pool_options.idle_timeout + } + pub fn shutdown(&self) { // drops all strong references to endpoint-pools self.global_pool.clear(); } pub async fn gc_worker(&self, mut rng: impl Rng) { - let epoch = self.proxy_config.http_config.pool_options.gc_epoch; + let epoch = self.config.pool_options.gc_epoch; let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32); loop { interval.tick().await; @@ -280,6 +306,7 @@ impl GlobalConnPool { let timer = GC_LATENCY.start_timer(); let current_len = shard.len(); + let mut clients_removed = 0; shard.retain(|endpoint, x| { // if the current endpoint pool is unique (no other strong or weak references) // then it is currently not in use by any connections. @@ -289,9 +316,9 @@ impl GlobalConnPool { } = pool.get_mut(); // ensure that closed clients are removed - pools - .iter_mut() - .for_each(|(_, db_pool)| db_pool.clear_closed_clients(total_conns)); + pools.iter_mut().for_each(|(_, db_pool)| { + clients_removed += db_pool.clear_closed_clients(total_conns); + }); // we only remove this pool if it has no active connections if *total_conns == 0 { @@ -302,10 +329,20 @@ impl GlobalConnPool { true }); + let new_len = shard.len(); drop(shard); timer.observe_duration(); + // Do logging outside of the lock. + if clients_removed > 0 { + let size = self + .global_connections_count + .fetch_sub(clients_removed, atomic::Ordering::Relaxed) + - clients_removed; + NUM_OPEN_CLIENTS_IN_HTTP_POOL.sub(clients_removed as i64); + info!("pool: performed global pool gc. 
removed {clients_removed} clients, total number of clients in pool is {size}"); + } let removed = current_len - new_len; if removed > 0 { @@ -320,61 +357,24 @@ impl GlobalConnPool { pub async fn get( self: &Arc, ctx: &mut RequestMonitoring, - conn_info: ConnInfo, - force_new: bool, - ) -> anyhow::Result { - let mut client: Option = None; + conn_info: &ConnInfo, + ) -> anyhow::Result>> { + let mut client: Option> = None; - let mut hash_valid = false; - let mut endpoint_pool = Weak::new(); - if !force_new { - let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key()); - endpoint_pool = Arc::downgrade(&pool); - let mut hash = None; - - // find a pool entry by (dbname, username) if exists - { - let pool = pool.read(); - if let Some(pool_entries) = pool.pools.get(&conn_info.db_and_user()) { - if !pool_entries.conns.is_empty() { - hash = pool_entries.password_hash.clone(); - } - } - } - - // a connection exists in the pool, verify the password hash - if let Some(hash) = hash { - let pw = conn_info.password.clone(); - let validate = tokio::task::spawn_blocking(move || { - Pbkdf2.verify_password(pw.as_bytes(), &hash.password_hash()) - }) - .await?; - - // if the hash is invalid, don't error - // we will continue with the regular connection flow - if validate.is_ok() { - hash_valid = true; - if let Some(entry) = pool.write().get_conn_entry(conn_info.db_and_user()) { - client = Some(entry.conn) - } - } - } + let endpoint_pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key()); + if let Some(entry) = endpoint_pool + .write() + .get_conn_entry(conn_info.db_and_user()) + { + client = Some(entry.conn) } + let endpoint_pool = Arc::downgrade(&endpoint_pool); // ok return cached connection if found and establish a new one otherwise - let new_client = if let Some(client) = client { - ctx.set_project(client.aux.clone()); - if client.inner.is_closed() { - let conn_id = uuid::Uuid::new_v4(); - info!(%conn_id, "pool: cached connection '{conn_info}' is closed, opening a new one"); - connect_to_compute( - self.proxy_config, - ctx, - &conn_info, - conn_id, - endpoint_pool.clone(), - ) - .await + if let Some(client) = client { + if client.is_closed() { + info!("pool: cached connection '{conn_info}' is closed, opening a new one"); + return Ok(None); } else { info!("pool: reusing connection '{conn_info}'"); client.session.send(ctx.session_id)?; @@ -384,67 +384,16 @@ impl GlobalConnPool { ); ctx.latency_timer.pool_hit(); ctx.latency_timer.success(); - return Ok(Client::new(client, conn_info, endpoint_pool).await); + return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool))); } - } else { - let conn_id = uuid::Uuid::new_v4(); - info!(%conn_id, "pool: opening a new connection '{conn_info}'"); - connect_to_compute( - self.proxy_config, - ctx, - &conn_info, - conn_id, - endpoint_pool.clone(), - ) - .await - }; - if let Ok(client) = &new_client { - tracing::Span::current().record( - "pid", - &tracing::field::display(client.inner.get_process_id()), - ); } - - match &new_client { - // clear the hash. 
it's no longer valid - // TODO: update tokio-postgres fork to allow access to this error kind directly - Err(err) - if hash_valid && err.to_string().contains("password authentication failed") => - { - let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key()); - let mut pool = pool.write(); - if let Some(entry) = pool.pools.get_mut(&conn_info.db_and_user()) { - entry.password_hash = None; - } - } - // new password is valid and we should insert/update it - Ok(_) if !force_new && !hash_valid => { - let pw = conn_info.password.clone(); - let new_hash = tokio::task::spawn_blocking(move || { - let salt = SaltString::generate(rand::rngs::OsRng); - Pbkdf2 - .hash_password_customized(pw.as_bytes(), None, None, PARAMS, &salt) - .map(|s| s.serialize()) - }) - .await??; - - let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key()); - let mut pool = pool.write(); - pool.pools - .entry(conn_info.db_and_user()) - .or_default() - .password_hash = Some(new_hash); - } - _ => {} - } - let new_client = new_client?; - Ok(Client::new(new_client, conn_info, endpoint_pool).await) + Ok(None) } fn get_or_create_endpoint_pool( - &self, + self: &Arc, endpoint: &EndpointCacheKey, - ) -> Arc> { + ) -> Arc>> { // fast path if let Some(pool) = self.global_pool.get(endpoint) { return pool.clone(); @@ -454,12 +403,10 @@ impl GlobalConnPool { let new_pool = Arc::new(RwLock::new(EndpointConnPool { pools: HashMap::new(), total_conns: 0, - max_conns: self - .proxy_config - .http_config - .pool_options - .max_conns_per_endpoint, + max_conns: self.config.pool_options.max_conns_per_endpoint, _guard: ENDPOINT_POOLS.guard(), + global_connections_count: self.global_connections_count.clone(), + global_pool_size_max_conns: self.config.pool_options.max_total_conns, })); // find or create a pool for this endpoint @@ -488,196 +435,128 @@ impl GlobalConnPool { } } -struct TokioMechanism<'a> { - pool: Weak>, - conn_info: &'a ConnInfo, - conn_id: uuid::Uuid, - idle: Duration, -} - -#[async_trait] -impl ConnectMechanism for TokioMechanism<'_> { - type Connection = ClientInner; - type ConnectError = tokio_postgres::Error; - type Error = anyhow::Error; - - async fn connect_once( - &self, - ctx: &mut RequestMonitoring, - node_info: &console::CachedNodeInfo, - timeout: time::Duration, - ) -> Result { - connect_to_compute_once( - ctx, - node_info, - self.conn_info, - timeout, - self.conn_id, - self.pool.clone(), - self.idle, - ) - .await - } - - fn update_connect_config(&self, _config: &mut compute::ConnCfg) {} -} - -// Wake up the destination if needed. Code here is a bit involved because -// we reuse the code from the usual proxy and we need to prepare few structures -// that this code expects. -#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)] -async fn connect_to_compute( - config: &config::ProxyConfig, +pub fn poll_client( + global_pool: Arc>, ctx: &mut RequestMonitoring, - conn_info: &ConnInfo, + conn_info: ConnInfo, + client: C, + mut connection: tokio_postgres::Connection, conn_id: uuid::Uuid, - pool: Weak>, -) -> anyhow::Result { - ctx.set_application(Some(APP_NAME)); - let backend = config - .auth_backend - .as_ref() - .map(|_| conn_info.user_info.clone()); - - if !config.disable_ip_check_for_http { - let (allowed_ips, _) = backend.get_allowed_ips_and_secret(ctx).await?; - if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) { - return Err(auth::AuthError::ip_address_not_allowed().into()); - } - } - let node_info = backend - .wake_compute(ctx) - .await? 
- .context("missing cache entry from wake_compute")?; - - ctx.set_project(node_info.aux.clone()); - - crate::proxy::connect_compute::connect_to_compute( - ctx, - &TokioMechanism { - conn_id, - conn_info, - pool, - idle: config.http_config.pool_options.idle_timeout, - }, - node_info, - &backend, - ) - .await -} - -async fn connect_to_compute_once( - ctx: &mut RequestMonitoring, - node_info: &console::CachedNodeInfo, - conn_info: &ConnInfo, - timeout: time::Duration, - conn_id: uuid::Uuid, - pool: Weak>, - idle: Duration, -) -> Result { - let mut config = (*node_info.config).clone(); - let mut session = ctx.session_id; - - let (client, mut connection) = config - .user(&conn_info.user_info.user) - .password(&*conn_info.password) - .dbname(&conn_info.dbname) - .connect_timeout(timeout) - .connect(tokio_postgres::NoTls) - .await?; - + aux: MetricsAuxInfo, +) -> Client { let conn_gauge = NUM_DB_CONNECTIONS_GAUGE .with_label_values(&[ctx.protocol]) .guard(); - - tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id())); - - let (tx, mut rx) = tokio::sync::watch::channel(session); + let mut session_id = ctx.session_id; + let (tx, mut rx) = tokio::sync::watch::channel(session_id); let span = info_span!(parent: None, "connection", %conn_id); span.in_scope(|| { - info!(%conn_info, %session, "new connection"); + info!(%conn_info, %session_id, "new connection"); }); + let pool = + Arc::downgrade(&global_pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key())); + let pool_clone = pool.clone(); let db_user = conn_info.db_and_user(); + let idle = global_pool.get_idle_timeout(); tokio::spawn( - async move { - let _conn_gauge = conn_gauge; - let mut idle_timeout = pin!(tokio::time::sleep(idle)); - poll_fn(move |cx| { - if matches!(rx.has_changed(), Ok(true)) { - session = *rx.borrow_and_update(); - info!(%session, "changed session"); - idle_timeout.as_mut().reset(Instant::now() + idle); - } + async move { + let _conn_gauge = conn_gauge; + let mut idle_timeout = pin!(tokio::time::sleep(idle)); + poll_fn(move |cx| { + if matches!(rx.has_changed(), Ok(true)) { + session_id = *rx.borrow_and_update(); + info!(%session_id, "changed session"); + idle_timeout.as_mut().reset(Instant::now() + idle); + } - // 5 minute idle connection timeout - if idle_timeout.as_mut().poll(cx).is_ready() { - idle_timeout.as_mut().reset(Instant::now() + idle); - info!("connection idle"); - if let Some(pool) = pool.clone().upgrade() { - // remove client from pool - should close the connection if it's idle. 
- // does nothing if the client is currently checked-out and in-use - if pool.write().remove_client(db_user.clone(), conn_id) { - info!("idle connection removed"); - } - } - } - - loop { - let message = ready!(connection.poll_message(cx)); - - match message { - Some(Ok(AsyncMessage::Notice(notice))) => { - info!(%session, "notice: {}", notice); - } - Some(Ok(AsyncMessage::Notification(notif))) => { - warn!(%session, pid = notif.process_id(), channel = notif.channel(), "notification received"); - } - Some(Ok(_)) => { - warn!(%session, "unknown message"); - } - Some(Err(e)) => { - error!(%session, "connection error: {}", e); - break - } - None => { - info!("connection closed"); - break - } - } - } - - // remove from connection pool + // 5 minute idle connection timeout + if idle_timeout.as_mut().poll(cx).is_ready() { + idle_timeout.as_mut().reset(Instant::now() + idle); + info!("connection idle"); if let Some(pool) = pool.clone().upgrade() { + // remove client from pool - should close the connection if it's idle. + // does nothing if the client is currently checked-out and in-use if pool.write().remove_client(db_user.clone(), conn_id) { - info!("closed connection removed"); + info!("idle connection removed"); } } + } - Poll::Ready(()) - }).await; + loop { + let message = ready!(connection.poll_message(cx)); - } - .instrument(span) - ); + match message { + Some(Ok(AsyncMessage::Notice(notice))) => { + info!(%session_id, "notice: {}", notice); + } + Some(Ok(AsyncMessage::Notification(notif))) => { + warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received"); + } + Some(Ok(_)) => { + warn!(%session_id, "unknown message"); + } + Some(Err(e)) => { + error!(%session_id, "connection error: {}", e); + break + } + None => { + info!("connection closed"); + break + } + } + } - Ok(ClientInner { + // remove from connection pool + if let Some(pool) = pool.clone().upgrade() { + if pool.write().remove_client(db_user.clone(), conn_id) { + info!("closed connection removed"); + } + } + + Poll::Ready(()) + }).await; + + } + .instrument(span)); + let inner = ClientInner { inner: client, session: tx, - aux: node_info.aux.clone(), + aux, conn_id, - }) + }; + Client::new(inner, conn_info, pool_clone) } -struct ClientInner { - inner: tokio_postgres::Client, +struct ClientInner { + inner: C, session: tokio::sync::watch::Sender, aux: MetricsAuxInfo, conn_id: uuid::Uuid, } -impl Client { +pub trait ClientInnerExt: Sync + Send + 'static { + fn is_closed(&self) -> bool; + fn get_process_id(&self) -> i32; +} + +impl ClientInnerExt for tokio_postgres::Client { + fn is_closed(&self) -> bool { + self.is_closed() + } + fn get_process_id(&self) -> i32 { + self.get_process_id() + } +} + +impl ClientInner { + pub fn is_closed(&self) -> bool { + self.inner.is_closed() + } +} + +impl Client { pub fn metrics(&self) -> Arc { let aux = &self.inner.as_ref().unwrap().aux; USAGE_METRICS.register(Ids { @@ -687,51 +566,46 @@ impl Client { } } -pub struct Client { - conn_id: uuid::Uuid, +pub struct Client { span: Span, - inner: Option, + inner: Option>, conn_info: ConnInfo, - pool: Weak>, + pool: Weak>>, } -pub struct Discard<'a> { +pub struct Discard<'a, C: ClientInnerExt> { conn_id: uuid::Uuid, conn_info: &'a ConnInfo, - pool: &'a mut Weak>, + pool: &'a mut Weak>>, } -impl Client { - pub(self) async fn new( - inner: ClientInner, +impl Client { + pub(self) fn new( + inner: ClientInner, conn_info: ConnInfo, - pool: Weak>, + pool: Weak>>, ) -> Self { Self { - conn_id: inner.conn_id, inner: Some(inner), 
span: Span::current(), conn_info, pool, } } - pub fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) { + pub fn inner(&mut self) -> (&mut C, Discard<'_, C>) { let Self { inner, pool, - conn_id, conn_info, span: _, } = self; + let inner = inner.as_mut().expect("client inner should not be removed"); ( - &mut inner - .as_mut() - .expect("client inner should not be removed") - .inner, + &mut inner.inner, Discard { pool, conn_info, - conn_id: *conn_id, + conn_id: inner.conn_id, }, ) } @@ -744,7 +618,7 @@ impl Client { } } -impl Discard<'_> { +impl Discard<'_, C> { pub fn check_idle(&mut self, status: ReadyForQueryStatus) { let conn_info = &self.conn_info; if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 { @@ -759,8 +633,8 @@ impl Discard<'_> { } } -impl Deref for Client { - type Target = tokio_postgres::Client; +impl Deref for Client { + type Target = C; fn deref(&self) -> &Self::Target { &self @@ -771,8 +645,8 @@ impl Deref for Client { } } -impl Drop for Client { - fn drop(&mut self) { +impl Client { + fn do_drop(&mut self) -> Option { let conn_info = self.conn_info.clone(); let client = self .inner @@ -781,10 +655,161 @@ impl Drop for Client { if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() { let current_span = self.span.clone(); // return connection to the pool - tokio::task::spawn_blocking(move || { + return Some(move || { let _span = current_span.enter(); let _ = EndpointConnPool::put(&conn_pool, &conn_info, client); }); } + None + } +} + +impl Drop for Client { + fn drop(&mut self) { + if let Some(drop) = self.do_drop() { + tokio::task::spawn_blocking(drop); + } + } +} + +#[cfg(test)] +mod tests { + use env_logger; + use std::{mem, sync::atomic::AtomicBool}; + + use super::*; + + struct MockClient(Arc); + impl MockClient { + fn new(is_closed: bool) -> Self { + MockClient(Arc::new(is_closed.into())) + } + } + impl ClientInnerExt for MockClient { + fn is_closed(&self) -> bool { + self.0.load(atomic::Ordering::Relaxed) + } + fn get_process_id(&self) -> i32 { + 0 + } + } + + fn create_inner() -> ClientInner { + create_inner_with(MockClient::new(false)) + } + + fn create_inner_with(client: MockClient) -> ClientInner { + ClientInner { + inner: client, + session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()), + aux: Default::default(), + conn_id: uuid::Uuid::new_v4(), + } + } + + #[tokio::test] + async fn test_pool() { + let _ = env_logger::try_init(); + let config = Box::leak(Box::new(crate::config::HttpConfig { + pool_options: GlobalConnPoolOptions { + max_conns_per_endpoint: 2, + gc_epoch: Duration::from_secs(1), + pool_shards: 2, + idle_timeout: Duration::from_secs(1), + opt_in: false, + max_total_conns: 3, + }, + request_timeout: Duration::from_secs(1), + })); + let pool = GlobalConnPool::new(config); + let conn_info = ConnInfo { + user_info: ComputeUserInfo { + user: "user".into(), + endpoint: "endpoint".into(), + options: Default::default(), + }, + dbname: "dbname".into(), + password: "password".into(), + }; + let ep_pool = + Arc::downgrade(&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key())); + { + let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone()); + assert_eq!(0, pool.get_global_connections_count()); + client.discard(); + // Discard should not add the connection from the pool. 
+ assert_eq!(0, pool.get_global_connections_count()); + } + { + let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone()); + client.do_drop().unwrap()(); + mem::forget(client); // drop the client + assert_eq!(1, pool.get_global_connections_count()); + } + { + let mut closed_client = Client::new( + create_inner_with(MockClient::new(true)), + conn_info.clone(), + ep_pool.clone(), + ); + closed_client.do_drop().unwrap()(); + mem::forget(closed_client); // drop the client + // The closed client shouldn't be added to the pool. + assert_eq!(1, pool.get_global_connections_count()); + } + let is_closed: Arc = Arc::new(false.into()); + { + let mut client = Client::new( + create_inner_with(MockClient(is_closed.clone())), + conn_info.clone(), + ep_pool.clone(), + ); + client.do_drop().unwrap()(); + mem::forget(client); // drop the client + + // The client should be added to the pool. + assert_eq!(2, pool.get_global_connections_count()); + } + { + let mut client = Client::new(create_inner(), conn_info, ep_pool); + client.do_drop().unwrap()(); + mem::forget(client); // drop the client + + // The client shouldn't be added to the pool. Because the ep-pool is full. + assert_eq!(2, pool.get_global_connections_count()); + } + + let conn_info = ConnInfo { + user_info: ComputeUserInfo { + user: "user".into(), + endpoint: "endpoint-2".into(), + options: Default::default(), + }, + dbname: "dbname".into(), + password: "password".into(), + }; + let ep_pool = + Arc::downgrade(&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key())); + { + let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone()); + client.do_drop().unwrap()(); + mem::forget(client); // drop the client + assert_eq!(3, pool.get_global_connections_count()); + } + { + let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone()); + client.do_drop().unwrap()(); + mem::forget(client); // drop the client + + // The client shouldn't be added to the pool. Because the global pool is full. + assert_eq!(3, pool.get_global_connections_count()); + } + + is_closed.store(true, atomic::Ordering::Relaxed); + // Do gc for all shards. + pool.gc(0); + pool.gc(1); + // Closed client should be removed from the pool. + assert_eq!(2, pool.get_global_connections_count()); } } diff --git a/proxy/src/serverless/json.rs b/proxy/src/serverless/json.rs index 05835b23ce..a089d34040 100644 --- a/proxy/src/serverless/json.rs +++ b/proxy/src/serverless/json.rs @@ -9,23 +9,23 @@ use tokio_postgres::Row; // as parameters. 
// pub fn json_to_pg_text(json: Vec) -> Vec> { - json.iter() - .map(|value| { - match value { - // special care for nulls - Value::Null => None, + json.iter().map(json_value_to_pg_text).collect() +} - // convert to text with escaping - v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()), +fn json_value_to_pg_text(value: &Value) -> Option { + match value { + // special care for nulls + Value::Null => None, - // avoid escaping here, as we pass this as a parameter - Value::String(s) => Some(s.to_string()), + // convert to text with escaping + v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()), - // special care for arrays - Value::Array(_) => json_array_to_pg_array(value), - } - }) - .collect() + // avoid escaping here, as we pass this as a parameter + Value::String(s) => Some(s.to_string()), + + // special care for arrays + Value::Array(_) => json_array_to_pg_array(value), + } } // diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 96bf39c915..7092b65f03 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -13,6 +13,7 @@ use hyper::StatusCode; use hyper::{Body, HeaderMap, Request}; use serde_json::json; use serde_json::Value; +use tokio::join; use tokio_postgres::error::DbError; use tokio_postgres::error::ErrorPosition; use tokio_postgres::GenericClient; @@ -20,6 +21,7 @@ use tokio_postgres::IsolationLevel; use tokio_postgres::ReadyForQueryStatus; use tokio_postgres::Transaction; use tracing::error; +use tracing::info; use tracing::instrument; use url::Url; use utils::http::error::ApiError; @@ -27,22 +29,25 @@ use utils::http::json::json_response; use crate::auth::backend::ComputeUserInfo; use crate::auth::endpoint_sni; -use crate::config::HttpConfig; +use crate::config::ProxyConfig; use crate::config::TlsConfig; use crate::context::RequestMonitoring; +use crate::metrics::HTTP_CONTENT_LENGTH; use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE; use crate::proxy::NeonOptions; use crate::RoleName; +use super::backend::PoolingBackend; use super::conn_pool::ConnInfo; -use super::conn_pool::GlobalConnPool; -use super::json::{json_to_pg_text, pg_text_row_to_json}; +use super::json::json_to_pg_text; +use super::json::pg_text_row_to_json; use super::SERVERLESS_DRIVER_SNI; #[derive(serde::Deserialize)] struct QueryData { query: String, - params: Vec, + #[serde(deserialize_with = "bytes_to_pg_text")] + params: Vec>, } #[derive(serde::Deserialize)] @@ -69,6 +74,15 @@ static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrab static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true"); +fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result>, D::Error> +where + D: serde::de::Deserializer<'de>, +{ + // TODO: consider avoiding the allocation here. 
+ let json: Vec = serde::de::Deserialize::deserialize(deserializer)?; + Ok(json_to_pg_text(json)) +} + fn get_conn_info( ctx: &mut RequestMonitoring, headers: &HeaderMap, @@ -171,16 +185,15 @@ fn check_matches(sni_hostname: &str, hostname: &str) -> Result, sni_hostname: Option, - conn_pool: Arc, + backend: Arc, ) -> Result, ApiError> { let result = tokio::time::timeout( - config.request_timeout, - handle_inner(tls, config, ctx, request, sni_hostname, conn_pool), + config.http_config.request_timeout, + handle_inner(config, ctx, request, sni_hostname, backend), ) .await; let mut response = match result { @@ -265,7 +278,7 @@ pub async fn handle( Err(_) => { let message = format!( "HTTP-Connection timed out, execution time exeeded {} seconds", - config.request_timeout.as_secs() + config.http_config.request_timeout.as_secs() ); error!(message); json_response( @@ -283,22 +296,36 @@ pub async fn handle( #[instrument(name = "sql-over-http", fields(pid = tracing::field::Empty), skip_all)] async fn handle_inner( - tls: &'static TlsConfig, - config: &'static HttpConfig, + config: &'static ProxyConfig, ctx: &mut RequestMonitoring, request: Request, sni_hostname: Option, - conn_pool: Arc, + backend: Arc, ) -> anyhow::Result> { let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE - .with_label_values(&["http"]) + .with_label_values(&[ctx.protocol]) .guard(); + info!( + protocol = ctx.protocol, + "handling interactive connection from client" + ); // // Determine the destination and connection params // let headers = request.headers(); - let conn_info = get_conn_info(ctx, headers, sni_hostname, tls)?; + // TLS config should be there. + let conn_info = get_conn_info( + ctx, + headers, + sni_hostname, + config.tls_config.as_ref().unwrap(), + )?; + info!( + user = conn_info.user_info.user.as_str(), + project = conn_info.user_info.endpoint.as_str(), + "credentials" + ); // Determine the output options. Default behaviour is 'false'. Anything that is not // strictly 'true' assumed to be false. 
@@ -307,8 +334,8 @@ async fn handle_inner( // Allow connection pooling only if explicitly requested // or if we have decided that http pool is no longer opt-in - let allow_pool = - !config.pool_options.opt_in || headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE); + let allow_pool = !config.http_config.pool_options.opt_in + || headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE); // isolation level, read only and deferrable @@ -333,6 +360,8 @@ async fn handle_inner( None => MAX_REQUEST_SIZE + 1, }; drop(paused); + info!(request_content_length, "request size in bytes"); + HTTP_CONTENT_LENGTH.observe(request_content_length as f64); // we don't have a streaming request support yet so this is to prevent OOM // from a malicious user sending an extremely large request body @@ -342,13 +371,28 @@ async fn handle_inner( )); } - // - // Read the query and query params from the request body - // - let body = hyper::body::to_bytes(request.into_body()).await?; - let payload: Payload = serde_json::from_slice(&body)?; + let fetch_and_process_request = async { + let body = hyper::body::to_bytes(request.into_body()) + .await + .map_err(anyhow::Error::from)?; + let payload: Payload = serde_json::from_slice(&body)?; + Ok::(payload) // Adjust error type accordingly + }; - let mut client = conn_pool.get(ctx, conn_info, !allow_pool).await?; + let authenticate_and_connect = async { + let keys = backend.authenticate(ctx, &conn_info).await?; + backend + .connect_to_compute(ctx, conn_info, keys, !allow_pool) + .await + }; + + // Run both operations in parallel + let (payload_result, auth_and_connect_result) = + join!(fetch_and_process_request, authenticate_and_connect,); + + // Handle the results + let payload = payload_result?; // Handle errors appropriately + let mut client = auth_and_connect_result?; // Handle errors appropriately let mut response = Response::builder() .status(StatusCode::OK) @@ -482,7 +526,7 @@ async fn query_to_json( raw_output: bool, array_mode: bool, ) -> anyhow::Result<(ReadyForQueryStatus, Value)> { - let query_params = json_to_pg_text(data.params); + let query_params = data.params; let row_stream = client.query_raw_txt(&data.query, query_params).await?; // Manually drain the stream into a vector to leave row_stream hanging diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index 1d62f09840..b3b35e446d 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -393,11 +393,11 @@ def test_sql_over_http_batch(static_proxy: NeonProxy): def test_sql_over_http_pool(static_proxy: NeonProxy): static_proxy.safe_psql("create user http_auth with password 'http' superuser") - def get_pid(status: int, pw: str) -> Any: + def get_pid(status: int, pw: str, user="http_auth") -> Any: return static_proxy.http_query( GET_CONNECTION_PID_QUERY, [], - user="http_auth", + user=user, password=pw, expected_code=status, ) @@ -418,20 +418,14 @@ def test_sql_over_http_pool(static_proxy: NeonProxy): static_proxy.safe_psql("alter user http_auth with password 'http2'") - # after password change, should open a new connection to verify it - pid2 = get_pid(200, "http2")["rows"][0]["pid"] - assert pid1 != pid2 + # after password change, shouldn't open a new connection because it checks password in proxy. 
+ rows = get_pid(200, "http2")["rows"] + assert rows == [{"pid": pid1}] time.sleep(0.02) - # query should be on an existing connection - pid = get_pid(200, "http2")["rows"][0]["pid"] - assert pid in [pid1, pid2] - - time.sleep(0.02) - - # old password should not work - res = get_pid(400, "http") + # incorrect user shouldn't reveal that the user doesn't exist + res = get_pid(400, "http", user="http_auth2") assert "password authentication failed for user" in res["message"]
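
Note on point 3 of the summary: a minimal, self-contained sketch of the parallelization pattern used in `handle_inner` in `sql_over_http.rs`. The two stand-in futures below mirror the names of the local `fetch_and_process_request` and `authenticate_and_connect` futures from the patch, but their bodies and return types are simplified placeholders; only `tokio` (with the `rt` and `macros` features) and `anyhow` are assumed.

```rust
use tokio::join;

// Stand-in for reading the request body with hyper and deserializing the JSON payload.
async fn fetch_and_process_request() -> anyhow::Result<Vec<Option<String>>> {
    Ok(vec![Some("42".to_string())])
}

// Stand-in for `backend.authenticate` followed by `backend.connect_to_compute`.
async fn authenticate_and_connect() -> anyhow::Result<String> {
    Ok("pooled or freshly opened client".to_string())
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Both futures make progress concurrently on the same task; errors from either
    // side are only propagated once both have completed, as in the patch.
    let (payload, client) = join!(fetch_and_process_request(), authenticate_and_connect());
    let params = payload?;
    let client = client?;
    println!("connected: {client}, params: {params:?}");
    Ok(())
}
```

As in the patch, body parsing no longer has to finish before authentication and compute wake-up start, so a slow client upload does not serialize with establishing the database connection.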