diff --git a/proxy/src/http.rs b/proxy/src/http.rs index 153cdc02e5..e281240380 100644 --- a/proxy/src/http.rs +++ b/proxy/src/http.rs @@ -2,6 +2,7 @@ //! Other modules should use stuff from this module instead of //! directly relying on deps like `reqwest` (think loose coupling). +pub mod conn_pool; pub mod server; pub mod sql_over_http; pub mod websocket; diff --git a/proxy/src/http/conn_pool.rs b/proxy/src/http/conn_pool.rs new file mode 100644 index 0000000000..52c1e2f2ce --- /dev/null +++ b/proxy/src/http/conn_pool.rs @@ -0,0 +1,278 @@
+use parking_lot::Mutex;
+use pq_proto::StartupMessageParams;
+use std::fmt;
+use std::{collections::HashMap, sync::Arc};
+
+use futures::TryFutureExt;
+
+use crate::config;
+use crate::{auth, console};
+
+use super::sql_over_http::MAX_RESPONSE_SIZE;
+
+use crate::proxy::invalidate_cache;
+use crate::proxy::NUM_RETRIES_WAKE_COMPUTE;
+
+use tracing::error;
+use tracing::info;
+
+/// Application name reported to the console and to the compute node.
+pub const APP_NAME: &str = "sql_over_http";
+/// Default cap on the number of idle connections kept per endpoint.
+const MAX_CONNS_PER_ENDPOINT: usize = 20;
+
+/// Connection parameters of a single SQL-over-HTTP request.
+///
+/// `Debug` is implemented by hand (not derived) so that the password can
+/// never end up in logs through a `{:?}` format.
+pub struct ConnInfo {
+    pub username: String,
+    pub dbname: String,
+    pub hostname: String,
+    pub password: String,
+}
+
+impl ConnInfo {
+    /// Returns the `(dbname, username)` pair used as the key of the
+    /// per-endpoint pool map.
+    // hm, change to hasher to avoid cloning?
+    pub fn db_and_user(&self) -> (String, String) {
+        (self.dbname.clone(), self.username.clone())
+    }
+}
+
+impl fmt::Debug for ConnInfo {
+    // same fields as a derived impl, but the password is redacted
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("ConnInfo")
+            .field("username", &self.username)
+            .field("dbname", &self.dbname)
+            .field("hostname", &self.hostname)
+            .field("password", &"<redacted>")
+            .finish()
+    }
+}
+
+impl fmt::Display for ConnInfo {
+    // use custom display to avoid logging password
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}@{}/{}", self.username, self.hostname, self.dbname)
+    }
+}
+
+/// A pooled connection together with the time it was put into the pool.
+struct ConnPoolEntry {
+    conn: tokio_postgres::Client,
+    // Recorded when the entry is created; not read anywhere in this file yet.
+    _last_access: std::time::Instant,
+}
+
+// Per-endpoint connection pool, (dbname, username) -> Vec<ConnPoolEntry>
+// Number of open connections is limited by the `max_conns_per_endpoint`.
+pub struct EndpointConnPool {
+    // idle connections grouped by (dbname, username)
+    pools: HashMap<(String, String), Vec<ConnPoolEntry>>,
+    // number of idle connections stored across all `pools` entries
+    total_conns: usize,
+}
+
+pub struct GlobalConnPool {
+    // endpoint -> per-endpoint connection pool
+    //
+    // That should be a fairly contended map, so return reference to the per-endpoint
+    // pool as early as possible and release the lock.
+    global_pool: Mutex<HashMap<String, Arc<Mutex<EndpointConnPool>>>>,
+
+    // Maximum number of connections per one endpoint.
+    // Can mix different (dbname, username) connections.
+    // When running out of free slots for a particular endpoint,
+    // falls back to opening a new connection for each request.
+    max_conns_per_endpoint: usize,
+
+    proxy_config: &'static crate::config::ProxyConfig,
+}
+
+impl GlobalConnPool {
+    /// Creates an empty pool that keeps at most `MAX_CONNS_PER_ENDPOINT`
+    /// idle connections per endpoint.
+    pub fn new(config: &'static crate::config::ProxyConfig) -> Arc<Self> {
+        Arc::new(Self {
+            global_pool: Mutex::new(HashMap::new()),
+            max_conns_per_endpoint: MAX_CONNS_PER_ENDPOINT,
+            proxy_config: config,
+        })
+    }
+
+    /// Returns a connection for `conn_info`.
+    ///
+    /// Unless `force_new` is set, an idle pooled connection with the same
+    /// endpoint, database and user is reused when a live one exists; closed
+    /// pooled connections found on the way are dropped. Otherwise a new
+    /// connection is opened.
+    ///
+    /// # Errors
+    ///
+    /// Fails if a new connection has to be opened and that attempt fails.
+    pub async fn get(
+        &self,
+        conn_info: &ConnInfo,
+        force_new: bool,
+    ) -> anyhow::Result<tokio_postgres::Client> {
+        if !force_new {
+            // the lock is taken and released inside `take_idle`, so no guard
+            // is alive across the `.await` below
+            let (client, discarded) = self.take_idle(conn_info);
+            if discarded > 0 {
+                info!("pool: dropped {discarded} closed cached connection(s) for '{conn_info}'");
+            }
+            if let Some(client) = client {
+                info!("pool: reusing connection '{conn_info}'");
+                return Ok(client);
+            }
+        }
+
+        info!("pool: opening a new connection '{conn_info}'");
+        connect_to_compute(self.proxy_config, conn_info).await
+    }
+
+    /// Hands `client` back to the pool of its endpoint.
+    ///
+    /// The connection is dropped instead of stored when it is already closed
+    /// or when the endpoint already holds `max_conns_per_endpoint` idle
+    /// connections. Always returns `Ok(())`.
+    pub async fn put(
+        &self,
+        conn_info: &ConnInfo,
+        client: tokio_postgres::Client,
+    ) -> anyhow::Result<()> {
+        // a closed connection would only waste a slot
+        if client.is_closed() {
+            info!("pool: throwing away connection '{conn_info}' because connection is closed");
+            return Ok(());
+        }
+
+        let pool = self.get_endpoint_pool(&conn_info.hostname);
+
+        // return connection to the pool; `per_db_size` is `Some` only if it was stored
+        let (total_conns, per_db_size) = {
+            let mut pool = pool.lock();
+            if pool.total_conns < self.max_conns_per_endpoint {
+                // create the (db, user) entry only when something is stored in it
+                let pool_entries = pool
+                    .pools
+                    .entry(conn_info.db_and_user())
+                    .or_insert_with(|| Vec::with_capacity(1));
+                pool_entries.push(ConnPoolEntry {
+                    conn: client,
+                    _last_access: std::time::Instant::now(),
+                });
+                let per_db_size = pool_entries.len();
+
+                pool.total_conns += 1;
+                (pool.total_conns, Some(per_db_size))
+            } else {
+                (pool.total_conns, None)
+            }
+        };
+
+        // do logging outside of the mutex
+        match per_db_size {
+            Some(per_db_size) => {
+                info!("pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
+            }
+            None => {
+                info!("pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Pops idle connections for `conn_info` until a live one is found.
+    ///
+    /// Returns the live connection (if any) and the number of closed
+    /// connections that were discarded on the way.
+    fn take_idle(&self, conn_info: &ConnInfo) -> (Option<tokio_postgres::Client>, usize) {
+        let pool = self.get_endpoint_pool(&conn_info.hostname);
+        let mut pool = pool.lock();
+
+        let mut found = None;
+        let mut discarded = 0;
+
+        // find a pool entry by (dbname, username) if exists
+        if let Some(pool_entries) = pool.pools.get_mut(&conn_info.db_and_user()) {
+            while let Some(entry) = pool_entries.pop() {
+                if entry.conn.is_closed() {
+                    discarded += 1;
+                } else {
+                    found = Some(entry.conn);
+                    break;
+                }
+            }
+        }
+
+        // every popped entry, live or closed, leaves the pool
+        let removed = discarded + usize::from(found.is_some());
+        pool.total_conns = pool.total_conns.saturating_sub(removed);
+
+        (found, discarded)
+    }
+
+    /// Finds the pool of `endpoint`, creating an empty one on first use.
+    fn get_endpoint_pool(&self, endpoint: &str) -> Arc<Mutex<EndpointConnPool>> {
+        let mut global_pool = self.global_pool.lock();
+
+        // common case: the pool exists, no need to allocate a key
+        if let Some(pool) = global_pool.get(endpoint) {
+            return Arc::clone(pool);
+        }
+
+        let pool = Arc::new(Mutex::new(EndpointConnPool {
+            pools: HashMap::new(),
+            total_conns: 0,
+        }));
+        global_pool.insert(endpoint.to_owned(), Arc::clone(&pool));
+        let global_pool_size = global_pool.len();
+        drop(global_pool);
+
+        // log new global pool size
+        info!("pool: created new pool for '{endpoint}', global pool size now {global_pool_size}");
+
+        pool
+    }
+}
+
+//
+// Wake up the destination if needed. Code here is a bit involved because
+// we reuse the code from the usual proxy and we need to prepare few structures
+// that this code expects.
+//
+async fn connect_to_compute(
+    config: &config::ProxyConfig,
+    conn_info: &ConnInfo,
+) -> anyhow::Result<tokio_postgres::Client> {
+    let tls = config.tls_config.as_ref();
+    let common_names = tls.and_then(|tls| tls.common_names.clone());
+
+    let credential_params = StartupMessageParams::new([
+        ("user", &conn_info.username),
+        ("database", &conn_info.dbname),
+        ("application_name", APP_NAME),
+    ]);
+
+    let creds = config
+        .auth_backend
+        .as_ref()
+        .map(|_| {
+            auth::ClientCredentials::parse(
+                &credential_params,
+                Some(&conn_info.hostname),
+                common_names,
+            )
+        })
+        .transpose()?;
+    let extra = console::ConsoleReqExtra {
+        session_id: uuid::Uuid::new_v4(),
+        application_name: Some(APP_NAME),
+    };
+
+    // `wake_compute` may yield no node info (the retry loop below treats that
+    // as "link auth"); report it as an error instead of panicking the task.
+    let mut node_info = creds.wake_compute(&extra).await?.ok_or_else(|| {
+        anyhow::anyhow!("no compute node info available for '{conn_info}'")
+    })?;
+
+    // This code is a copy of `connect_to_compute` from `src/proxy.rs` with
+    // the difference that it uses `tokio_postgres` for the connection.
+    let mut num_retries: usize = NUM_RETRIES_WAKE_COMPUTE;
+    loop {
+        match connect_to_compute_once(&node_info, conn_info).await {
+            Err(e) if num_retries > 0 => {
+                info!("compute node's state has changed; requesting a wake-up");
+                match creds.wake_compute(&extra).await? {
+                    // Update `node_info` and try one more time.
+                    Some(new) => node_info = new,
+                    // Link auth doesn't work that way, so we just exit.
+                    None => return Err(e),
+                }
+            }
+            // success, or a failure with no retries left
+            other => return other,
+        }
+
+        num_retries -= 1;
+        info!("retrying after wake-up ({num_retries} attempts left)");
+    }
+}
+
+/// Makes a single connection attempt to the compute node described by
+/// `node_info`, authenticating with the user, password and database from
+/// `conn_info`.
+///
+/// On failure the error is logged and `invalidate_cache` is called for
+/// `node_info`. On success the connection driver is spawned as a background
+/// task and the client handle is returned.
+async fn connect_to_compute_once(
+    node_info: &console::CachedNodeInfo,
+    conn_info: &ConnInfo,
+) -> anyhow::Result<tokio_postgres::Client> {
+    let mut config = (*node_info.config).clone();
+
+    let (client, connection) = config
+        .user(&conn_info.username)
+        .password(&conn_info.password)
+        .dbname(&conn_info.dbname)
+        .max_backend_message_size(MAX_RESPONSE_SIZE)
+        .connect(tokio_postgres::NoTls)
+        .inspect_err(|e: &tokio_postgres::Error| {
+            error!(
+                "failed to connect to compute node hosts={:?} ports={:?}: {}",
+                node_info.config.get_hosts(),
+                node_info.config.get_ports(),
+                e
+            );
+            invalidate_cache(node_info)
+        })
+        .await?;
+
+    // the connection object performs the actual I/O and has to be polled
+    // for the client to make progress
+    tokio::spawn(async move {
+        if let Err(e) = connection.await {
+            error!("connection error: {}", e);
+        }
+    });
+
+    Ok(client)
+}
diff --git a/proxy/src/http/sql_over_http.rs b/proxy/src/http/sql_over_http.rs index e8ad2d04f3..adf7252f72 100644 --- a/proxy/src/http/sql_over_http.rs +++ b/proxy/src/http/sql_over_http.rs @@ -1,25 +1,21 @@ +use std::sync::Arc; + use futures::pin_mut; use futures::StreamExt; -use futures::TryFutureExt; use hyper::body::HttpBody; use hyper::http::HeaderName; use hyper::http::HeaderValue; use hyper::{Body, HeaderMap, Request}; -use pq_proto::StartupMessageParams; use serde_json::json; use serde_json::Map; use serde_json::Value; use tokio_postgres::types::Kind; use tokio_postgres::types::Type; use tokio_postgres::Row; -use tracing::error; -use tracing::info; -use tracing::instrument; use url::Url; -use crate::proxy::invalidate_cache; -use crate::proxy::NUM_RETRIES_WAKE_COMPUTE; -use crate::{auth, config::ProxyConfig, console}; +use super::conn_pool::ConnInfo; +use super::conn_pool::GlobalConnPool; #[derive(serde::Deserialize)] struct QueryData { @@ -27,12 +23,13 @@ struct QueryData { params: Vec, } -const APP_NAME: &str = "sql_over_http"; -const MAX_RESPONSE_SIZE: usize =
1024 * 1024; // 1 MB +pub const MAX_RESPONSE_SIZE: usize = 1024 * 1024; // 1 MB const MAX_REQUEST_SIZE: u64 = 1024 * 1024; // 1 MB static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output"); static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode"); +static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in"); + static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true"); // @@ -96,13 +93,6 @@ fn json_array_to_pg_array(value: &Value) -> Result, serde_json::E } } -struct ConnInfo { - username: String, - dbname: String, - hostname: String, - password: String, -} - fn get_conn_info( headers: &HeaderMap, sni_hostname: Option, @@ -169,50 +159,23 @@ fn get_conn_info( // TODO: return different http error codes pub async fn handle( - config: &'static ProxyConfig, request: Request, sni_hostname: Option, + conn_pool: Arc, ) -> anyhow::Result { // // Determine the destination and connection params // let headers = request.headers(); let conn_info = get_conn_info(headers, sni_hostname)?; - let credential_params = StartupMessageParams::new([ - ("user", &conn_info.username), - ("database", &conn_info.dbname), - ("application_name", APP_NAME), - ]); // Determine the output options. Default behaviour is 'false'. Anything that is not // strictly 'true' assumed to be false. let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE); let array_mode = headers.get(&ARRAY_MODE) == Some(&HEADER_VALUE_TRUE); - // - // Wake up the destination if needed. Code here is a bit involved because - // we reuse the code from the usual proxy and we need to prepare few structures - // that this code expects. 
- // - let tls = config.tls_config.as_ref(); - let common_names = tls.and_then(|tls| tls.common_names.clone()); - let creds = config - .auth_backend - .as_ref() - .map(|_| { - auth::ClientCredentials::parse( - &credential_params, - Some(&conn_info.hostname), - common_names, - ) - }) - .transpose()?; - let extra = console::ConsoleReqExtra { - session_id: uuid::Uuid::new_v4(), - application_name: Some(APP_NAME), - }; - - let mut node_info = creds.wake_compute(&extra).await?.expect("msg"); + // Allow connection pooling only if explicitly requested + let allow_pool = headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE); let request_content_length = match request.body().size_hint().upper() { Some(v) => v, @@ -235,7 +198,8 @@ pub async fn handle( // // Now execute the query and return the result // - let client = connect_to_compute(&mut node_info, &extra, &creds, &conn_info).await?; + let client = conn_pool.get(&conn_info, !allow_pool).await?; + let row_stream = client.query_raw_txt(query, query_params).await?; // Manually drain the stream into a vector to leave row_stream hanging @@ -292,6 +256,13 @@ pub async fn handle( .map(|row| pg_text_row_to_json(row, raw_output, array_mode)) .collect::, _>>()?; + if allow_pool { + // return connection to the pool + tokio::task::spawn(async move { + let _ = conn_pool.put(&conn_info, client).await; + }); + } + // resulting JSON format is based on the format of node-postgres result Ok(json!({ "command": command_tag_name, @@ -302,70 +273,6 @@ pub async fn handle( })) } -/// This function is a copy of `connect_to_compute` from `src/proxy.rs` with -/// the difference that it uses `tokio_postgres` for the connection. 
-#[instrument(skip_all)] -async fn connect_to_compute( - node_info: &mut console::CachedNodeInfo, - extra: &console::ConsoleReqExtra<'_>, - creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>, - conn_info: &ConnInfo, -) -> anyhow::Result { - let mut num_retries: usize = NUM_RETRIES_WAKE_COMPUTE; - - loop { - match connect_to_compute_once(node_info, conn_info).await { - Err(e) if num_retries > 0 => { - info!("compute node's state has changed; requesting a wake-up"); - match creds.wake_compute(extra).await? { - // Update `node_info` and try one more time. - Some(new) => { - *node_info = new; - } - // Link auth doesn't work that way, so we just exit. - None => return Err(e), - } - } - other => return other, - } - - num_retries -= 1; - info!("retrying after wake-up ({num_retries} attempts left)"); - } -} - -async fn connect_to_compute_once( - node_info: &console::CachedNodeInfo, - conn_info: &ConnInfo, -) -> anyhow::Result { - let mut config = (*node_info.config).clone(); - - let (client, connection) = config - .user(&conn_info.username) - .password(&conn_info.password) - .dbname(&conn_info.dbname) - .max_backend_message_size(MAX_RESPONSE_SIZE) - .connect(tokio_postgres::NoTls) - .inspect_err(|e: &tokio_postgres::Error| { - error!( - "failed to connect to compute node hosts={:?} ports={:?}: {}", - node_info.config.get_hosts(), - node_info.config.get_ports(), - e - ); - invalidate_cache(node_info) - }) - .await?; - - tokio::spawn(async move { - if let Err(e) = connection.await { - error!("connection error: {}", e); - } - }); - - Ok(client) -} - // // Convert postgres row with text-encoded values to JSON object // diff --git a/proxy/src/http/websocket.rs b/proxy/src/http/websocket.rs index 9f467aceb7..83ba034e57 100644 --- a/proxy/src/http/websocket.rs +++ b/proxy/src/http/websocket.rs @@ -35,7 +35,7 @@ use utils::http::{error::ApiError, json::json_response}; // Tracking issue: https://github.com/rust-lang/rust/issues/98407. 
use sync_wrapper::SyncWrapper; -use super::sql_over_http; +use super::{conn_pool::GlobalConnPool, sql_over_http}; pin_project! { /// This is a wrapper around a [`WebSocketStream`] that @@ -164,6 +164,7 @@ async fn serve_websocket( async fn ws_handler( mut request: Request, config: &'static ProxyConfig, + conn_pool: Arc, cancel_map: Arc, session_id: uuid::Uuid, sni_hostname: Option, @@ -192,7 +193,7 @@ async fn ws_handler( // TODO: that deserves a refactor as now this function also handles http json client besides websockets. // Right now I don't want to blow up sql-over-http patch with file renames and do that as a follow up instead. } else if request.uri().path() == "/sql" && request.method() == Method::POST { - let result = sql_over_http::handle(config, request, sni_hostname) + let result = sql_over_http::handle(request, sni_hostname, conn_pool) .instrument(info_span!("sql-over-http")) .await; let status_code = match result { @@ -234,6 +235,8 @@ pub async fn task_main( info!("websocket server has shut down"); } + let conn_pool: Arc = GlobalConnPool::new(config); + let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config()); let tls_acceptor: tokio_rustls::TlsAcceptor = match tls_config { Some(config) => config.into(), @@ -258,15 +261,18 @@ pub async fn task_main( let make_svc = hyper::service::make_service_fn(|stream: &tokio_rustls::server::TlsStream| { let sni_name = stream.get_ref().1.sni_hostname().map(|s| s.to_string()); + let conn_pool = conn_pool.clone(); async move { Ok::<_, Infallible>(hyper::service::service_fn(move |req: Request| { let sni_name = sni_name.clone(); + let conn_pool = conn_pool.clone(); + async move { let cancel_map = Arc::new(CancelMap::default()); let session_id = uuid::Uuid::new_v4(); - ws_handler(req, config, cancel_map, session_id, sni_name) + ws_handler(req, config, conn_pool, cancel_map, session_id, sni_name) .instrument(info_span!( "ws-client", session = format_args!("{session_id}")