From a26d5652829170cbc8359deeca23f07737a3a416 Mon Sep 17 00:00:00 2001 From: Dmitry Ivanov Date: Wed, 16 Feb 2022 18:51:18 +0300 Subject: [PATCH] [proxy] Replace private static map with a public `CancelMap` This is a cleaner approach which might facilitate testing. --- proxy/src/cancellation.rs | 113 +++++++++++++++++++++----------------- proxy/src/proxy.rs | 32 +++++++---- 2 files changed, 85 insertions(+), 60 deletions(-) diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index 62f195c3d2..c1a7e81be9 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -1,15 +1,57 @@ use anyhow::{anyhow, Context}; use hashbrown::HashMap; -use lazy_static::lazy_static; use parking_lot::Mutex; use std::net::SocketAddr; use tokio::net::TcpStream; use tokio_postgres::{CancelToken, NoTls}; use zenith_utils::pq_proto::CancelKeyData; -lazy_static! { - /// Enables serving CancelRequests. - static ref CANCEL_MAP: Mutex>> = Default::default(); +/// Enables serving CancelRequests. +#[derive(Default)] +pub struct CancelMap(Mutex>>); + +impl CancelMap { + /// Cancel a running query for the corresponding connection. + pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> { + let cancel_closure = self + .0 + .lock() + .get(&key) + .and_then(|x| x.clone()) + .with_context(|| format!("unknown session: {:?}", key))?; + + cancel_closure.try_cancel_query().await + } + + /// Run async action within an ephemeral session identified by [`CancelKeyData`]. + pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result + where + F: FnOnce(Session<'a>) -> R, + R: std::future::Future>, + { + // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't + // expose it and we don't want to do another roundtrip to query + // for it. The client will be able to notice that this is not the + // actual backend_pid, but backend_pid is not used for anything + // so it doesn't matter. + let key = rand::random(); + + // Random key collisions are unlikely to happen here, but they're still possible, + // which is why we have to take care not to rewrite an existing key. + self.0 + .lock() + .try_insert(key, None) + .map_err(|_| anyhow!("session already exists: {:?}", key))?; + + // This will guarantee that the session gets dropped + // as soon as the future is finished. + scopeguard::defer! { + self.0.lock().remove(&key); + } + + let session = Session::new(key, self); + f(session).await + } } /// This should've been a [`std::future::Future`], but @@ -38,54 +80,27 @@ impl CancelClosure { } } -/// Cancel a running query for the corresponding connection. -pub async fn cancel_session(key: CancelKeyData) -> anyhow::Result<()> { - let cancel_closure = CANCEL_MAP - .lock() - .get(&key) - .and_then(|x| x.clone()) - .with_context(|| format!("unknown session: {:?}", key))?; - - cancel_closure.try_cancel_query().await -} - /// Helper for registering query cancellation tokens. -pub struct Session(CancelKeyData); +pub struct Session<'a> { + /// The user-facing key identifying this session. + key: CancelKeyData, + /// The [`CancelMap`] this session belongs to. + cancel_map: &'a CancelMap, +} + +impl<'a> Session<'a> { + fn new(key: CancelKeyData, cancel_map: &'a CancelMap) -> Self { + Self { key, cancel_map } + } -impl Session { /// Store the cancel token for the given session. + /// This enables query cancellation in [`crate::proxy::handshake`]. pub fn enable_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData { - CANCEL_MAP.lock().insert(self.0, Some(cancel_closure)); - self.0 + self.cancel_map + .0 + .lock() + .insert(self.key, Some(cancel_closure)); + + self.key } } - -/// Run async action within an ephemeral session identified by [`CancelKeyData`]. -pub async fn with_session(f: F) -> anyhow::Result -where - F: FnOnce(Session) -> R, - R: std::future::Future>, -{ - // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't - // expose it and we don't want to do another roundtrip to query - // for it. The client will be able to notice that this is not the - // actual backend_pid, but backend_pid is not used for anything - // so it doesn't matter. - let key = rand::random(); - - // Random key collisions are unlikely to happen here, but they're still possible, - // which is why we have to take care not to rewrite an existing key. - CANCEL_MAP - .lock() - .try_insert(key, None) - .map_err(|_| anyhow!("session already exists: {:?}", key))?; - - // This will guarantee that the session gets dropped - // as soon as the future is finished. - scopeguard::defer! { - CANCEL_MAP.lock().remove(&key); - } - - let session = Session(key); - f(session).await -} diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 1bf48f89cc..1dc301b792 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -1,10 +1,11 @@ use crate::auth; -use crate::cancellation::{self, CancelClosure}; +use crate::cancellation::{self, CancelClosure, CancelMap}; use crate::compute::DatabaseInfo; use crate::config::{ProxyConfig, TlsConfig}; use crate::stream::{MetricsStream, PqStream, Stream}; use anyhow::{bail, Context}; use lazy_static::lazy_static; +use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; use tokio_postgres::NoTls; @@ -47,22 +48,25 @@ pub async fn thread_main( println!("proxy has shut down"); } + let cancel_map = Arc::new(CancelMap::default()); loop { let (socket, peer_addr) = listener.accept().await?; println!("accepted connection from {}", peer_addr); - tokio::spawn(log_error(async { + let cancel_map = Arc::clone(&cancel_map); + tokio::spawn(log_error(async move { socket .set_nodelay(true) .context("failed to set socket option")?; - handle_client(config, socket).await + handle_client(config, &cancel_map, socket).await })); } } async fn handle_client( config: &ProxyConfig, + cancel_map: &CancelMap, stream: impl AsyncRead + AsyncWrite + Unpin, ) -> anyhow::Result<()> { // The `closed` counter will increase when this future is destroyed. @@ -72,11 +76,12 @@ async fn handle_client( } let tls = config.tls_config.clone(); - if let Some((stream, creds)) = handshake(stream, tls).await? { - cancellation::with_session(|session| async { - connect_client_to_db(config, stream, creds, session).await - }) - .await?; + if let Some((client, creds)) = handshake(stream, tls, cancel_map).await? { + cancel_map + .with_session(|session| async { + connect_client_to_db(config, session, client, creds).await + }) + .await?; } Ok(()) @@ -88,6 +93,7 @@ async fn handle_client( async fn handshake( stream: S, mut tls: Option, + cancel_map: &CancelMap, ) -> anyhow::Result>, auth::ClientCredentials)>> { // Client may try upgrading to each protocol only once let (mut tried_ssl, mut tried_gss) = (false, false); @@ -136,7 +142,7 @@ async fn handshake( break Ok(Some((stream, params.try_into()?))); } CancelRequest(cancel_key_data) => { - cancellation::cancel_session(cancel_key_data).await?; + cancel_map.cancel_session(cancel_key_data).await?; break Ok(None); } @@ -146,9 +152,9 @@ async fn handshake( async fn connect_client_to_db( config: &ProxyConfig, + session: cancellation::Session<'_>, mut client: PqStream, creds: auth::ClientCredentials, - session: cancellation::Session, ) -> anyhow::Result<()> { let db_info = creds.authenticate(config, &mut client).await?; let (db, version, cancel_closure) = connect_to_db(db_info).await?; @@ -211,8 +217,12 @@ mod tests { client: impl AsyncRead + AsyncWrite + Unpin, tls: Option, ) -> anyhow::Result<()> { + let cancel_map = CancelMap::default(); + // TODO: add some infra + tests for credentials - let (mut stream, _creds) = handshake(client, tls).await?.context("no stream")?; + let (mut stream, _creds) = handshake(client, tls, &cancel_map) + .await? + .context("no stream")?; stream .write_message_noflush(&Be::AuthenticationOk)?