From 595248532c3908418597ca5fbcd0d29d37cc6214 Mon Sep 17 00:00:00 2001 From: Arthur Petukhovsky Date: Tue, 18 Apr 2023 15:29:00 +0000 Subject: [PATCH] Add global connection cache for ws --- proxy/src/http/websocket.rs | 94 ++++++++++++++++++++++++++----------- proxy/src/main.rs | 5 ++ 2 files changed, 71 insertions(+), 28 deletions(-) diff --git a/proxy/src/http/websocket.rs b/proxy/src/http/websocket.rs index 6eed496910..1988348adb 100644 --- a/proxy/src/http/websocket.rs +++ b/proxy/src/http/websocket.rs @@ -12,6 +12,7 @@ use hyper::{ use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream}; use pin_project_lite::pin_project; use pq_proto::StartupMessageParams; +use tokio::sync::Mutex; use percent_encoding::percent_decode; use std::collections::HashMap; @@ -165,6 +166,7 @@ async fn ws_handler( config: &'static ProxyConfig, cancel_map: Arc, session_id: uuid::Uuid, + cache: Arc>, ) -> Result, ApiError> { let host = request .headers() @@ -188,7 +190,7 @@ async fn ws_handler( // Return the response so the spawned future can continue. Ok(response) } else if request.uri().path() == "/sql" && request.method() == Method::POST { - match handle_sql(config, request).await { + match handle_sql(config, request, cache).await { Ok(resp) => json_response(StatusCode::OK, resp).map(|mut r| { r.headers_mut().insert( "Access-Control-Allow-Origin", @@ -211,6 +213,7 @@ async fn ws_handler( async fn handle_sql( config: &'static ProxyConfig, request: Request, + cache: Arc>, ) -> anyhow::Result { let get_params = request .uri() @@ -279,42 +282,77 @@ async fn handle_sql( dbname ); - info!("!!!! connecting to: {}", conn_string); + ConnectionCache::execute(&cache, conn_string, &hostname, sql).await +} - let (client, connection) = tokio_postgres::connect(conn_string, tokio_postgres::NoTls).await?; +pub struct ConnectionCache { + connections: HashMap, +} - tokio::spawn(async move { - if let Err(e) = connection.await { - eprintln!("connection error: {}", e); - } - }); +impl ConnectionCache { + pub fn new() -> Arc> { + Arc::new(Mutex::new(Self { + connections: HashMap::new(), + })) + } - let sql = percent_decode(sql.as_bytes()).decode_utf8()?.to_string(); + pub async fn execute( + cache: &Arc>, + conn_string: &str, + hostname: &str, + sql: &str, + ) -> anyhow::Result { + // TODO: let go mutex when establishing connection + let mut cache = cache.lock().await; + let cache_key = format!("connstr={}, hostname={}", conn_string, hostname); + let client = if let Some(client) = cache.connections.get(&cache_key) { + info!("using cached connection {}", conn_string); + client + } else { + info!("!!!! connecting to: {}", conn_string); - let rows: Vec> = client - .simple_query(&sql) - .await? - .into_iter() - .filter_map(|el| { - if let tokio_postgres::SimpleQueryMessage::Row(row) = el { - let mut serilaized_row: HashMap = HashMap::new(); - for i in 0..row.len() { - let col = row.columns().get(i).map_or("?", |c| c.name()); - let val = row.get(i).unwrap_or("?"); - serilaized_row.insert(col.into(), val.into()); + let (client, connection) = + tokio_postgres::connect(conn_string, tokio_postgres::NoTls).await?; + + tokio::spawn(async move { + // TODO: remove connection from cache + if let Err(e) = connection.await { + eprintln!("connection error: {}", e); } - Some(serilaized_row) - } else { - None - } - }) - .collect(); + }); - Ok(serde_json::to_string(&rows)?) + cache.connections.insert(cache_key.clone(), client); + cache.connections.get(&cache_key).unwrap() + }; + + let sql = percent_decode(sql.as_bytes()).decode_utf8()?.to_string(); + + let rows: Vec> = client + .simple_query(&sql) + .await? + .into_iter() + .filter_map(|el| { + if let tokio_postgres::SimpleQueryMessage::Row(row) = el { + let mut serilaized_row: HashMap = HashMap::new(); + for i in 0..row.len() { + let col = row.columns().get(i).map_or("?", |c| c.name()); + let val = row.get(i).unwrap_or("?"); + serilaized_row.insert(col.into(), val.into()); + } + Some(serilaized_row) + } else { + None + } + }) + .collect(); + + Ok(serde_json::to_string(&rows)?) + } } pub async fn task_main( config: &'static ProxyConfig, + cache: &'static Arc>, ws_listener: TcpListener, cancellation_token: CancellationToken, ) -> anyhow::Result<()> { @@ -348,7 +386,7 @@ pub async fn task_main( move |req: Request| async move { let cancel_map = Arc::new(CancelMap::default()); let session_id = uuid::Uuid::new_v4(); - ws_handler(req, config, cancel_map, session_id) + ws_handler(req, config, cancel_map, session_id, cache.clone()) .instrument(info_span!( "ws-client", session = format_args!("{session_id}") diff --git a/proxy/src/main.rs b/proxy/src/main.rs index 1fd13c9f68..3b51697e1d 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -32,6 +32,8 @@ use tokio_util::sync::CancellationToken; use tracing::{info, warn}; use utils::{project_git_version, sentry_init::init_sentry}; +use crate::http::websocket::ConnectionCache; + project_git_version!(GIT_VERSION); /// Flattens `Result>` into `Result`. @@ -53,6 +55,8 @@ async fn main() -> anyhow::Result<()> { let args = cli().get_matches(); let config = build_config(&args)?; + let wsconn_cache = Box::leak(Box::new(ConnectionCache::new())); + info!("Authentication backend: {}", config.auth_backend); // Check that we can bind to address before further initialization @@ -82,6 +86,7 @@ async fn main() -> anyhow::Result<()> { client_tasks.push(tokio::spawn(http::websocket::task_main( config, + wsconn_cache, wss_listener, cancellation_token.clone(), )));