[proxy] Replace private static map with a public CancelMap

This is a cleaner approach which might facilitate testing.
This commit is contained in:
Dmitry Ivanov
2022-02-16 18:51:18 +03:00
parent a47dade622
commit a26d565282
2 changed files with 85 additions and 60 deletions

View File

@@ -1,15 +1,57 @@
use anyhow::{anyhow, Context};
use hashbrown::HashMap;
use lazy_static::lazy_static;
use parking_lot::Mutex;
use std::net::SocketAddr;
use tokio::net::TcpStream;
use tokio_postgres::{CancelToken, NoTls};
use zenith_utils::pq_proto::CancelKeyData;
lazy_static! {
/// Enables serving CancelRequests.
static ref CANCEL_MAP: Mutex<HashMap<CancelKeyData, Option<CancelClosure>>> = Default::default();
/// Enables serving CancelRequests.
#[derive(Default)]
pub struct CancelMap(Mutex<HashMap<CancelKeyData, Option<CancelClosure>>>);
impl CancelMap {
/// Cancel a running query for the corresponding connection.
pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
let cancel_closure = self
.0
.lock()
.get(&key)
.and_then(|x| x.clone())
.with_context(|| format!("unknown session: {:?}", key))?;
cancel_closure.try_cancel_query().await
}
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
where
F: FnOnce(Session<'a>) -> R,
R: std::future::Future<Output = anyhow::Result<V>>,
{
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
// expose it and we don't want to do another roundtrip to query
// for it. The client will be able to notice that this is not the
// actual backend_pid, but backend_pid is not used for anything
// so it doesn't matter.
let key = rand::random();
// Random key collisions are unlikely to happen here, but they're still possible,
// which is why we have to take care not to rewrite an existing key.
self.0
.lock()
.try_insert(key, None)
.map_err(|_| anyhow!("session already exists: {:?}", key))?;
// This will guarantee that the session gets dropped
// as soon as the future is finished.
scopeguard::defer! {
self.0.lock().remove(&key);
}
let session = Session::new(key, self);
f(session).await
}
}
/// This should've been a [`std::future::Future`], but
@@ -38,54 +80,27 @@ impl CancelClosure {
}
}
/// Cancel a running query for the corresponding connection.
pub async fn cancel_session(key: CancelKeyData) -> anyhow::Result<()> {
let cancel_closure = CANCEL_MAP
.lock()
.get(&key)
.and_then(|x| x.clone())
.with_context(|| format!("unknown session: {:?}", key))?;
cancel_closure.try_cancel_query().await
}
/// Helper for registering query cancellation tokens.
pub struct Session(CancelKeyData);
pub struct Session<'a> {
/// The user-facing key identifying this session.
key: CancelKeyData,
/// The [`CancelMap`] this session belongs to.
cancel_map: &'a CancelMap,
}
impl<'a> Session<'a> {
fn new(key: CancelKeyData, cancel_map: &'a CancelMap) -> Self {
Self { key, cancel_map }
}
impl Session {
/// Store the cancel token for the given session.
/// This enables query cancellation in [`crate::proxy::handshake`].
pub fn enable_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
CANCEL_MAP.lock().insert(self.0, Some(cancel_closure));
self.0
self.cancel_map
.0
.lock()
.insert(self.key, Some(cancel_closure));
self.key
}
}
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
pub async fn with_session<F, R, V>(f: F) -> anyhow::Result<V>
where
F: FnOnce(Session) -> R,
R: std::future::Future<Output = anyhow::Result<V>>,
{
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
// expose it and we don't want to do another roundtrip to query
// for it. The client will be able to notice that this is not the
// actual backend_pid, but backend_pid is not used for anything
// so it doesn't matter.
let key = rand::random();
// Random key collisions are unlikely to happen here, but they're still possible,
// which is why we have to take care not to rewrite an existing key.
CANCEL_MAP
.lock()
.try_insert(key, None)
.map_err(|_| anyhow!("session already exists: {:?}", key))?;
// This will guarantee that the session gets dropped
// as soon as the future is finished.
scopeguard::defer! {
CANCEL_MAP.lock().remove(&key);
}
let session = Session(key);
f(session).await
}

View File

@@ -1,10 +1,11 @@
use crate::auth;
use crate::cancellation::{self, CancelClosure};
use crate::cancellation::{self, CancelClosure, CancelMap};
use crate::compute::DatabaseInfo;
use crate::config::{ProxyConfig, TlsConfig};
use crate::stream::{MetricsStream, PqStream, Stream};
use anyhow::{bail, Context};
use lazy_static::lazy_static;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_postgres::NoTls;
@@ -47,22 +48,25 @@ pub async fn thread_main(
println!("proxy has shut down");
}
let cancel_map = Arc::new(CancelMap::default());
loop {
let (socket, peer_addr) = listener.accept().await?;
println!("accepted connection from {}", peer_addr);
tokio::spawn(log_error(async {
let cancel_map = Arc::clone(&cancel_map);
tokio::spawn(log_error(async move {
socket
.set_nodelay(true)
.context("failed to set socket option")?;
handle_client(config, socket).await
handle_client(config, &cancel_map, socket).await
}));
}
}
async fn handle_client(
config: &ProxyConfig,
cancel_map: &CancelMap,
stream: impl AsyncRead + AsyncWrite + Unpin,
) -> anyhow::Result<()> {
// The `closed` counter will increase when this future is destroyed.
@@ -72,11 +76,12 @@ async fn handle_client(
}
let tls = config.tls_config.clone();
if let Some((stream, creds)) = handshake(stream, tls).await? {
cancellation::with_session(|session| async {
connect_client_to_db(config, stream, creds, session).await
})
.await?;
if let Some((client, creds)) = handshake(stream, tls, cancel_map).await? {
cancel_map
.with_session(|session| async {
connect_client_to_db(config, session, client, creds).await
})
.await?;
}
Ok(())
@@ -88,6 +93,7 @@ async fn handle_client(
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<TlsConfig>,
cancel_map: &CancelMap,
) -> anyhow::Result<Option<(PqStream<Stream<S>>, auth::ClientCredentials)>> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
@@ -136,7 +142,7 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
break Ok(Some((stream, params.try_into()?)));
}
CancelRequest(cancel_key_data) => {
cancellation::cancel_session(cancel_key_data).await?;
cancel_map.cancel_session(cancel_key_data).await?;
break Ok(None);
}
@@ -146,9 +152,9 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
async fn connect_client_to_db(
config: &ProxyConfig,
session: cancellation::Session<'_>,
mut client: PqStream<impl AsyncRead + AsyncWrite + Unpin>,
creds: auth::ClientCredentials,
session: cancellation::Session,
) -> anyhow::Result<()> {
let db_info = creds.authenticate(config, &mut client).await?;
let (db, version, cancel_closure) = connect_to_db(db_info).await?;
@@ -211,8 +217,12 @@ mod tests {
client: impl AsyncRead + AsyncWrite + Unpin,
tls: Option<TlsConfig>,
) -> anyhow::Result<()> {
let cancel_map = CancelMap::default();
// TODO: add some infra + tests for credentials
let (mut stream, _creds) = handshake(client, tls).await?.context("no stream")?;
let (mut stream, _creds) = handshake(client, tls, &cancel_map)
.await?
.context("no stream")?;
stream
.write_message_noflush(&Be::AuthenticationOk)?