From dbf88cf2d72c2298ee6606bf89c5b6bf854fc195 Mon Sep 17 00:00:00 2001 From: Stas Kelvich Date: Tue, 4 Jul 2023 13:27:08 +0300 Subject: [PATCH] Minimalistic pool for http endpoint compute connections (under opt-in flag) Cache up to 20 connections per endpoint. Once all pooled connections are used current implementation can open an extra connection, so the maximum number of simultaneous connections is not enforced. There are more things to do here, especially with background clean-up of closed connections, and checks for transaction state. But current implementation allows to check for smaller coonection latencies that this cache should bring. --- proxy/src/http.rs | 1 + proxy/src/http/conn_pool.rs | 278 ++++++++++++++++++++++++++++++++ proxy/src/http/sql_over_http.rs | 131 +++------------ proxy/src/http/websocket.rs | 12 +- 4 files changed, 307 insertions(+), 115 deletions(-) create mode 100644 proxy/src/http/conn_pool.rs diff --git a/proxy/src/http.rs b/proxy/src/http.rs index 153cdc02e5..e281240380 100644 --- a/proxy/src/http.rs +++ b/proxy/src/http.rs @@ -2,6 +2,7 @@ //! Other modules should use stuff from this module instead of //! directly relying on deps like `reqwest` (think loose coupling). +pub mod conn_pool; pub mod server; pub mod sql_over_http; pub mod websocket; diff --git a/proxy/src/http/conn_pool.rs b/proxy/src/http/conn_pool.rs new file mode 100644 index 0000000000..52c1e2f2ce --- /dev/null +++ b/proxy/src/http/conn_pool.rs @@ -0,0 +1,278 @@ +use parking_lot::Mutex; +use pq_proto::StartupMessageParams; +use std::fmt; +use std::{collections::HashMap, sync::Arc}; + +use futures::TryFutureExt; + +use crate::config; +use crate::{auth, console}; + +use super::sql_over_http::MAX_RESPONSE_SIZE; + +use crate::proxy::invalidate_cache; +use crate::proxy::NUM_RETRIES_WAKE_COMPUTE; + +use tracing::error; +use tracing::info; + +pub const APP_NAME: &str = "sql_over_http"; +const MAX_CONNS_PER_ENDPOINT: usize = 20; + +#[derive(Debug)] +pub struct ConnInfo { + pub username: String, + pub dbname: String, + pub hostname: String, + pub password: String, +} + +impl ConnInfo { + // hm, change to hasher to avoid cloning? + pub fn db_and_user(&self) -> (String, String) { + (self.dbname.clone(), self.username.clone()) + } +} + +impl fmt::Display for ConnInfo { + // use custom display to avoid logging password + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}@{}/{}", self.username, self.hostname, self.dbname) + } +} + +struct ConnPoolEntry { + conn: tokio_postgres::Client, + _last_access: std::time::Instant, +} + +// Per-endpoint connection pool, (dbname, username) -> Vec +// Number of open connections is limited by the `max_conns_per_endpoint`. +pub struct EndpointConnPool { + pools: HashMap<(String, String), Vec>, + total_conns: usize, +} + +pub struct GlobalConnPool { + // endpoint -> per-endpoint connection pool + // + // That should be a fairly conteded map, so return reference to the per-endpoint + // pool as early as possible and release the lock. + global_pool: Mutex>>>, + + // Maximum number of connections per one endpoint. + // Can mix different (dbname, username) connections. + // When running out of free slots for a particular endpoint, + // falls back to opening a new connection for each request. + max_conns_per_endpoint: usize, + + proxy_config: &'static crate::config::ProxyConfig, +} + +impl GlobalConnPool { + pub fn new(config: &'static crate::config::ProxyConfig) -> Arc { + Arc::new(Self { + global_pool: Mutex::new(HashMap::new()), + max_conns_per_endpoint: MAX_CONNS_PER_ENDPOINT, + proxy_config: config, + }) + } + + pub async fn get( + &self, + conn_info: &ConnInfo, + force_new: bool, + ) -> anyhow::Result { + let mut client: Option = None; + + if !force_new { + let pool = self.get_endpoint_pool(&conn_info.hostname).await; + + // find a pool entry by (dbname, username) if exists + let mut pool = pool.lock(); + let pool_entries = pool.pools.get_mut(&conn_info.db_and_user()); + if let Some(pool_entries) = pool_entries { + if let Some(entry) = pool_entries.pop() { + client = Some(entry.conn); + pool.total_conns -= 1; + } + } + } + + // ok return cached connection if found and establish a new one otherwise + if let Some(client) = client { + if client.is_closed() { + info!("pool: cached connection '{conn_info}' is closed, opening a new one"); + connect_to_compute(self.proxy_config, conn_info).await + } else { + info!("pool: reusing connection '{conn_info}'"); + Ok(client) + } + } else { + info!("pool: opening a new connection '{conn_info}'"); + connect_to_compute(self.proxy_config, conn_info).await + } + } + + pub async fn put( + &self, + conn_info: &ConnInfo, + client: tokio_postgres::Client, + ) -> anyhow::Result<()> { + let pool = self.get_endpoint_pool(&conn_info.hostname).await; + + // return connection to the pool + let mut total_conns; + let mut returned = false; + let mut per_db_size = 0; + { + let mut pool = pool.lock(); + total_conns = pool.total_conns; + + let pool_entries: &mut Vec = pool + .pools + .entry(conn_info.db_and_user()) + .or_insert_with(|| Vec::with_capacity(1)); + if total_conns < self.max_conns_per_endpoint { + pool_entries.push(ConnPoolEntry { + conn: client, + _last_access: std::time::Instant::now(), + }); + + total_conns += 1; + returned = true; + per_db_size = pool_entries.len(); + + pool.total_conns += 1; + } + } + + // do logging outside of the mutex + if returned { + info!("pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}"); + } else { + info!("pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}"); + } + + Ok(()) + } + + async fn get_endpoint_pool(&self, endpoint: &String) -> Arc> { + // find or create a pool for this endpoint + let mut created = false; + let mut global_pool = self.global_pool.lock(); + let pool = global_pool + .entry(endpoint.clone()) + .or_insert_with(|| { + created = true; + Arc::new(Mutex::new(EndpointConnPool { + pools: HashMap::new(), + total_conns: 0, + })) + }) + .clone(); + let global_pool_size = global_pool.len(); + drop(global_pool); + + // log new global pool size + if created { + info!( + "pool: created new pool for '{endpoint}', global pool size now {global_pool_size}" + ); + } + + pool + } +} + +// +// Wake up the destination if needed. Code here is a bit involved because +// we reuse the code from the usual proxy and we need to prepare few structures +// that this code expects. +// +async fn connect_to_compute( + config: &config::ProxyConfig, + conn_info: &ConnInfo, +) -> anyhow::Result { + let tls = config.tls_config.as_ref(); + let common_names = tls.and_then(|tls| tls.common_names.clone()); + + let credential_params = StartupMessageParams::new([ + ("user", &conn_info.username), + ("database", &conn_info.dbname), + ("application_name", APP_NAME), + ]); + + let creds = config + .auth_backend + .as_ref() + .map(|_| { + auth::ClientCredentials::parse( + &credential_params, + Some(&conn_info.hostname), + common_names, + ) + }) + .transpose()?; + let extra = console::ConsoleReqExtra { + session_id: uuid::Uuid::new_v4(), + application_name: Some(APP_NAME), + }; + + let node_info = &mut creds.wake_compute(&extra).await?.expect("msg"); + + // This code is a copy of `connect_to_compute` from `src/proxy.rs` with + // the difference that it uses `tokio_postgres` for the connection. + let mut num_retries: usize = NUM_RETRIES_WAKE_COMPUTE; + loop { + match connect_to_compute_once(node_info, conn_info).await { + Err(e) if num_retries > 0 => { + info!("compute node's state has changed; requesting a wake-up"); + match creds.wake_compute(&extra).await? { + // Update `node_info` and try one more time. + Some(new) => { + *node_info = new; + } + // Link auth doesn't work that way, so we just exit. + None => return Err(e), + } + } + other => return other, + } + + num_retries -= 1; + info!("retrying after wake-up ({num_retries} attempts left)"); + } +} + +async fn connect_to_compute_once( + node_info: &console::CachedNodeInfo, + conn_info: &ConnInfo, +) -> anyhow::Result { + let mut config = (*node_info.config).clone(); + + let (client, connection) = config + .user(&conn_info.username) + .password(&conn_info.password) + .dbname(&conn_info.dbname) + .max_backend_message_size(MAX_RESPONSE_SIZE) + .connect(tokio_postgres::NoTls) + .inspect_err(|e: &tokio_postgres::Error| { + error!( + "failed to connect to compute node hosts={:?} ports={:?}: {}", + node_info.config.get_hosts(), + node_info.config.get_ports(), + e + ); + invalidate_cache(node_info) + }) + .await?; + + tokio::spawn(async move { + if let Err(e) = connection.await { + error!("connection error: {}", e); + } + }); + + Ok(client) +} diff --git a/proxy/src/http/sql_over_http.rs b/proxy/src/http/sql_over_http.rs index e8ad2d04f3..adf7252f72 100644 --- a/proxy/src/http/sql_over_http.rs +++ b/proxy/src/http/sql_over_http.rs @@ -1,25 +1,21 @@ +use std::sync::Arc; + use futures::pin_mut; use futures::StreamExt; -use futures::TryFutureExt; use hyper::body::HttpBody; use hyper::http::HeaderName; use hyper::http::HeaderValue; use hyper::{Body, HeaderMap, Request}; -use pq_proto::StartupMessageParams; use serde_json::json; use serde_json::Map; use serde_json::Value; use tokio_postgres::types::Kind; use tokio_postgres::types::Type; use tokio_postgres::Row; -use tracing::error; -use tracing::info; -use tracing::instrument; use url::Url; -use crate::proxy::invalidate_cache; -use crate::proxy::NUM_RETRIES_WAKE_COMPUTE; -use crate::{auth, config::ProxyConfig, console}; +use super::conn_pool::ConnInfo; +use super::conn_pool::GlobalConnPool; #[derive(serde::Deserialize)] struct QueryData { @@ -27,12 +23,13 @@ struct QueryData { params: Vec, } -const APP_NAME: &str = "sql_over_http"; -const MAX_RESPONSE_SIZE: usize = 1024 * 1024; // 1 MB +pub const MAX_RESPONSE_SIZE: usize = 1024 * 1024; // 1 MB const MAX_REQUEST_SIZE: u64 = 1024 * 1024; // 1 MB static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output"); static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode"); +static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in"); + static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true"); // @@ -96,13 +93,6 @@ fn json_array_to_pg_array(value: &Value) -> Result, serde_json::E } } -struct ConnInfo { - username: String, - dbname: String, - hostname: String, - password: String, -} - fn get_conn_info( headers: &HeaderMap, sni_hostname: Option, @@ -169,50 +159,23 @@ fn get_conn_info( // TODO: return different http error codes pub async fn handle( - config: &'static ProxyConfig, request: Request, sni_hostname: Option, + conn_pool: Arc, ) -> anyhow::Result { // // Determine the destination and connection params // let headers = request.headers(); let conn_info = get_conn_info(headers, sni_hostname)?; - let credential_params = StartupMessageParams::new([ - ("user", &conn_info.username), - ("database", &conn_info.dbname), - ("application_name", APP_NAME), - ]); // Determine the output options. Default behaviour is 'false'. Anything that is not // strictly 'true' assumed to be false. let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE); let array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE); - // - // Wake up the destination if needed. Code here is a bit involved because - // we reuse the code from the usual proxy and we need to prepare few structures - // that this code expects. - // - let tls = config.tls_config.as_ref(); - let common_names = tls.and_then(|tls| tls.common_names.clone()); - let creds = config - .auth_backend - .as_ref() - .map(|_| { - auth::ClientCredentials::parse( - &credential_params, - Some(&conn_info.hostname), - common_names, - ) - }) - .transpose()?; - let extra = console::ConsoleReqExtra { - session_id: uuid::Uuid::new_v4(), - application_name: Some(APP_NAME), - }; - - let mut node_info = creds.wake_compute(&extra).await?.expect("msg"); + // Allow connection pooling only if explicitly requested + let allow_pool = headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE); let request_content_length = match request.body().size_hint().upper() { Some(v) => v, @@ -235,7 +198,8 @@ pub async fn handle( // // Now execute the query and return the result // - let client = connect_to_compute(&mut node_info, &extra, &creds, &conn_info).await?; + let client = conn_pool.get(&conn_info, !allow_pool).await?; + let row_stream = client.query_raw_txt(query, query_params).await?; // Manually drain the stream into a vector to leave row_stream hanging @@ -292,6 +256,13 @@ pub async fn handle( .map(|row| pg_text_row_to_json(row, raw_output, array_mode)) .collect::, _>>()?; + if allow_pool { + // return connection to the pool + tokio::task::spawn(async move { + let _ = conn_pool.put(&conn_info, client).await; + }); + } + // resulting JSON format is based on the format of node-postgres result Ok(json!({ "command": command_tag_name, @@ -302,70 +273,6 @@ pub async fn handle( })) } -/// This function is a copy of `connect_to_compute` from `src/proxy.rs` with -/// the difference that it uses `tokio_postgres` for the connection. -#[instrument(skip_all)] -async fn connect_to_compute( - node_info: &mut console::CachedNodeInfo, - extra: &console::ConsoleReqExtra<'_>, - creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>, - conn_info: &ConnInfo, -) -> anyhow::Result { - let mut num_retries: usize = NUM_RETRIES_WAKE_COMPUTE; - - loop { - match connect_to_compute_once(node_info, conn_info).await { - Err(e) if num_retries > 0 => { - info!("compute node's state has changed; requesting a wake-up"); - match creds.wake_compute(extra).await? { - // Update `node_info` and try one more time. - Some(new) => { - *node_info = new; - } - // Link auth doesn't work that way, so we just exit. - None => return Err(e), - } - } - other => return other, - } - - num_retries -= 1; - info!("retrying after wake-up ({num_retries} attempts left)"); - } -} - -async fn connect_to_compute_once( - node_info: &console::CachedNodeInfo, - conn_info: &ConnInfo, -) -> anyhow::Result { - let mut config = (*node_info.config).clone(); - - let (client, connection) = config - .user(&conn_info.username) - .password(&conn_info.password) - .dbname(&conn_info.dbname) - .max_backend_message_size(MAX_RESPONSE_SIZE) - .connect(tokio_postgres::NoTls) - .inspect_err(|e: &tokio_postgres::Error| { - error!( - "failed to connect to compute node hosts={:?} ports={:?}: {}", - node_info.config.get_hosts(), - node_info.config.get_ports(), - e - ); - invalidate_cache(node_info) - }) - .await?; - - tokio::spawn(async move { - if let Err(e) = connection.await { - error!("connection error: {}", e); - } - }); - - Ok(client) -} - // // Convert postgres row with text-encoded values to JSON object // diff --git a/proxy/src/http/websocket.rs b/proxy/src/http/websocket.rs index 9f467aceb7..83ba034e57 100644 --- a/proxy/src/http/websocket.rs +++ b/proxy/src/http/websocket.rs @@ -35,7 +35,7 @@ use utils::http::{error::ApiError, json::json_response}; // Tracking issue: https://github.com/rust-lang/rust/issues/98407. use sync_wrapper::SyncWrapper; -use super::sql_over_http; +use super::{conn_pool::GlobalConnPool, sql_over_http}; pin_project! { /// This is a wrapper around a [`WebSocketStream`] that @@ -164,6 +164,7 @@ async fn serve_websocket( async fn ws_handler( mut request: Request, config: &'static ProxyConfig, + conn_pool: Arc, cancel_map: Arc, session_id: uuid::Uuid, sni_hostname: Option, @@ -192,7 +193,7 @@ async fn ws_handler( // TODO: that deserves a refactor as now this function also handles http json client besides websockets. // Right now I don't want to blow up sql-over-http patch with file renames and do that as a follow up instead. } else if request.uri().path() == "/sql" && request.method() == Method::POST { - let result = sql_over_http::handle(config, request, sni_hostname) + let result = sql_over_http::handle(request, sni_hostname, conn_pool) .instrument(info_span!("sql-over-http")) .await; let status_code = match result { @@ -234,6 +235,8 @@ pub async fn task_main( info!("websocket server has shut down"); } + let conn_pool: Arc = GlobalConnPool::new(config); + let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config()); let tls_acceptor: tokio_rustls::TlsAcceptor = match tls_config { Some(config) => config.into(), @@ -258,15 +261,18 @@ pub async fn task_main( let make_svc = hyper::service::make_service_fn(|stream: &tokio_rustls::server::TlsStream| { let sni_name = stream.get_ref().1.sni_hostname().map(|s| s.to_string()); + let conn_pool = conn_pool.clone(); async move { Ok::<_, Infallible>(hyper::service::service_fn(move |req: Request| { let sni_name = sni_name.clone(); + let conn_pool = conn_pool.clone(); + async move { let cancel_map = Arc::new(CancelMap::default()); let session_id = uuid::Uuid::new_v4(); - ws_handler(req, config, cancel_map, session_id, sni_name) + ws_handler(req, config, conn_pool, cancel_map, session_id, sni_name) .instrument(info_span!( "ws-client", session = format_args!("{session_id}")