mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-15 09:22:55 +00:00
Add global connection cache for ws
This commit is contained in:
@@ -12,6 +12,7 @@ use hyper::{
|
||||
use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
|
||||
use pin_project_lite::pin_project;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use percent_encoding::percent_decode;
|
||||
use std::collections::HashMap;
|
||||
@@ -165,6 +166,7 @@ async fn ws_handler(
|
||||
config: &'static ProxyConfig,
|
||||
cancel_map: Arc<CancelMap>,
|
||||
session_id: uuid::Uuid,
|
||||
cache: Arc<Mutex<ConnectionCache>>,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
let host = request
|
||||
.headers()
|
||||
@@ -188,7 +190,7 @@ async fn ws_handler(
|
||||
// Return the response so the spawned future can continue.
|
||||
Ok(response)
|
||||
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
|
||||
match handle_sql(config, request).await {
|
||||
match handle_sql(config, request, cache).await {
|
||||
Ok(resp) => json_response(StatusCode::OK, resp).map(|mut r| {
|
||||
r.headers_mut().insert(
|
||||
"Access-Control-Allow-Origin",
|
||||
@@ -211,6 +213,7 @@ async fn ws_handler(
|
||||
async fn handle_sql(
|
||||
config: &'static ProxyConfig,
|
||||
request: Request<Body>,
|
||||
cache: Arc<Mutex<ConnectionCache>>,
|
||||
) -> anyhow::Result<String> {
|
||||
let get_params = request
|
||||
.uri()
|
||||
@@ -279,42 +282,77 @@ async fn handle_sql(
|
||||
dbname
|
||||
);
|
||||
|
||||
info!("!!!! connecting to: {}", conn_string);
|
||||
ConnectionCache::execute(&cache, conn_string, &hostname, sql).await
|
||||
}
|
||||
|
||||
let (client, connection) = tokio_postgres::connect(conn_string, tokio_postgres::NoTls).await?;
|
||||
pub struct ConnectionCache {
|
||||
connections: HashMap<String, tokio_postgres::Client>,
|
||||
}
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = connection.await {
|
||||
eprintln!("connection error: {}", e);
|
||||
}
|
||||
});
|
||||
impl ConnectionCache {
|
||||
pub fn new() -> Arc<Mutex<Self>> {
|
||||
Arc::new(Mutex::new(Self {
|
||||
connections: HashMap::new(),
|
||||
}))
|
||||
}
|
||||
|
||||
let sql = percent_decode(sql.as_bytes()).decode_utf8()?.to_string();
|
||||
pub async fn execute(
|
||||
cache: &Arc<Mutex<ConnectionCache>>,
|
||||
conn_string: &str,
|
||||
hostname: &str,
|
||||
sql: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
// TODO: let go mutex when establishing connection
|
||||
let mut cache = cache.lock().await;
|
||||
let cache_key = format!("connstr={}, hostname={}", conn_string, hostname);
|
||||
let client = if let Some(client) = cache.connections.get(&cache_key) {
|
||||
info!("using cached connection {}", conn_string);
|
||||
client
|
||||
} else {
|
||||
info!("!!!! connecting to: {}", conn_string);
|
||||
|
||||
let rows: Vec<HashMap<_, _>> = client
|
||||
.simple_query(&sql)
|
||||
.await?
|
||||
.into_iter()
|
||||
.filter_map(|el| {
|
||||
if let tokio_postgres::SimpleQueryMessage::Row(row) = el {
|
||||
let mut serilaized_row: HashMap<String, String> = HashMap::new();
|
||||
for i in 0..row.len() {
|
||||
let col = row.columns().get(i).map_or("?", |c| c.name());
|
||||
let val = row.get(i).unwrap_or("?");
|
||||
serilaized_row.insert(col.into(), val.into());
|
||||
let (client, connection) =
|
||||
tokio_postgres::connect(conn_string, tokio_postgres::NoTls).await?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
// TODO: remove connection from cache
|
||||
if let Err(e) = connection.await {
|
||||
eprintln!("connection error: {}", e);
|
||||
}
|
||||
Some(serilaized_row)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
});
|
||||
|
||||
Ok(serde_json::to_string(&rows)?)
|
||||
cache.connections.insert(cache_key.clone(), client);
|
||||
cache.connections.get(&cache_key).unwrap()
|
||||
};
|
||||
|
||||
let sql = percent_decode(sql.as_bytes()).decode_utf8()?.to_string();
|
||||
|
||||
let rows: Vec<HashMap<_, _>> = client
|
||||
.simple_query(&sql)
|
||||
.await?
|
||||
.into_iter()
|
||||
.filter_map(|el| {
|
||||
if let tokio_postgres::SimpleQueryMessage::Row(row) = el {
|
||||
let mut serilaized_row: HashMap<String, String> = HashMap::new();
|
||||
for i in 0..row.len() {
|
||||
let col = row.columns().get(i).map_or("?", |c| c.name());
|
||||
let val = row.get(i).unwrap_or("?");
|
||||
serilaized_row.insert(col.into(), val.into());
|
||||
}
|
||||
Some(serilaized_row)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(serde_json::to_string(&rows)?)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
cache: &'static Arc<Mutex<ConnectionCache>>,
|
||||
ws_listener: TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -348,7 +386,7 @@ pub async fn task_main(
|
||||
move |req: Request<Body>| async move {
|
||||
let cancel_map = Arc::new(CancelMap::default());
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
ws_handler(req, config, cancel_map, session_id)
|
||||
ws_handler(req, config, cancel_map, session_id, cache.clone())
|
||||
.instrument(info_span!(
|
||||
"ws-client",
|
||||
session = format_args!("{session_id}")
|
||||
|
||||
@@ -32,6 +32,8 @@ use tokio_util::sync::CancellationToken;
|
||||
use tracing::{info, warn};
|
||||
use utils::{project_git_version, sentry_init::init_sentry};
|
||||
|
||||
use crate::http::websocket::ConnectionCache;
|
||||
|
||||
project_git_version!(GIT_VERSION);
|
||||
|
||||
/// Flattens `Result<Result<T>>` into `Result<T>`.
|
||||
@@ -53,6 +55,8 @@ async fn main() -> anyhow::Result<()> {
|
||||
let args = cli().get_matches();
|
||||
let config = build_config(&args)?;
|
||||
|
||||
let wsconn_cache = Box::leak(Box::new(ConnectionCache::new()));
|
||||
|
||||
info!("Authentication backend: {}", config.auth_backend);
|
||||
|
||||
// Check that we can bind to address before further initialization
|
||||
@@ -82,6 +86,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
client_tasks.push(tokio::spawn(http::websocket::task_main(
|
||||
config,
|
||||
wsconn_cache,
|
||||
wss_listener,
|
||||
cancellation_token.clone(),
|
||||
)));
|
||||
|
||||
Reference in New Issue
Block a user