Compare commits

...

1 Commits

Author SHA1 Message Date
Heikki Linnakangas
c18b291244 Refactor cancellation Session to be more flexible.
Instead of with_session that calls a Future with the session, have a
more conventional constructor function, `new_session`, which returns a
Session. The session is automatically removed from the cancellation
map in Drop. This makes it nicer to use.
2023-03-13 11:27:03 +02:00
2 changed files with 42 additions and 34 deletions

View File

@@ -25,12 +25,11 @@ impl CancelMap {
cancel_closure.try_cancel_query().await cancel_closure.try_cancel_query().await
} }
/// Run async action within an ephemeral session identified by [`CancelKeyData`]. /// Create a new session, with a new client-facing random cancellation key.
pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V> ///
where /// Use `enable_query_cancellation` to register the Postgres backend's cancellation
F: FnOnce(Session<'a>) -> R, /// key with it.
R: std::future::Future<Output = anyhow::Result<V>>, pub fn new_session<'a>(&'a self) -> anyhow::Result<Session<'a>> {
{
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
// expose it and we don't want to do another roundtrip to query // expose it and we don't want to do another roundtrip to query
// for it. The client will be able to notice that this is not the // for it. The client will be able to notice that this is not the
@@ -44,17 +43,9 @@ impl CancelMap {
.write() .write()
.try_insert(key, None) .try_insert(key, None)
.map_err(|_| anyhow!("query cancellation key already exists: {key}"))?; .map_err(|_| anyhow!("query cancellation key already exists: {key}"))?;
// This will guarantee that the session gets dropped
// as soon as the future is finished.
scopeguard::defer! {
self.0.write().remove(&key);
info!("dropped query cancellation key {key}");
}
info!("registered new query cancellation key {key}"); info!("registered new query cancellation key {key}");
let session = Session::new(key, self);
f(session).await Ok(Session::new(key, self))
} }
#[cfg(test)] #[cfg(test)]
@@ -111,7 +102,7 @@ impl<'a> Session<'a> {
impl Session<'_> { impl Session<'_> {
/// Store the cancel token for the given session. /// Store the cancel token for the given session.
/// This enables query cancellation in [`crate::proxy::handshake`]. /// This enables query cancellation in [`crate::proxy::handshake`].
pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData { pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
info!("enabling query cancellation for this session"); info!("enabling query cancellation for this session");
self.cancel_map self.cancel_map
.0 .0
@@ -122,6 +113,14 @@ impl Session<'_> {
} }
} }
impl<'a> Drop for Session<'a> {
fn drop(&mut self) {
let key = &self.key;
self.cancel_map.0.write().remove(key);
info!("dropped query cancellation key {key}");
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@@ -132,14 +131,14 @@ mod tests {
static CANCEL_MAP: Lazy<CancelMap> = Lazy::new(Default::default); static CANCEL_MAP: Lazy<CancelMap> = Lazy::new(Default::default);
let (tx, rx) = tokio::sync::oneshot::channel(); let (tx, rx) = tokio::sync::oneshot::channel();
let task = tokio::spawn(CANCEL_MAP.with_session(|session| async move {
let session = CANCEL_MAP.new_session()?;
let task = tokio::spawn(async move {
assert!(CANCEL_MAP.contains(&session)); assert!(CANCEL_MAP.contains(&session));
tx.send(()).expect("failed to send"); tx.send(()).expect("failed to send");
futures::future::pending::<()>().await; // sleep forever futures::future::pending::<()>().await; // sleep forever
});
Ok(())
}));
// Wait until the task has been spawned. // Wait until the task has been spawned.
rx.await.context("failed to hear from the task")?; rx.await.context("failed to hear from the task")?;

View File

@@ -133,10 +133,14 @@ pub async fn handle_ws_client(
async { result }.or_else(|e| stream.throw_error(e)).await? async { result }.or_else(|e| stream.throw_error(e)).await?
}; };
let client = Client::new(stream, creds, &params, session_id); let client = Client::new(
cancel_map stream,
.with_session(|session| client.connect_to_db(session, true)) creds,
.await &params,
session_id,
cancel_map.new_session()?,
);
client.connect_to_db(true).await
} }
#[tracing::instrument(fields(session_id), skip_all)] #[tracing::instrument(fields(session_id), skip_all)]
@@ -172,10 +176,14 @@ async fn handle_client(
async { result }.or_else(|e| stream.throw_error(e)).await? async { result }.or_else(|e| stream.throw_error(e)).await?
}; };
let client = Client::new(stream, creds, &params, session_id); let client = Client::new(
cancel_map stream,
.with_session(|session| client.connect_to_db(session, false)) creds,
.await &params,
session_id,
cancel_map.new_session()?,
);
client.connect_to_db(false).await
} }
/// Establish a (most probably, secure) connection with the client. /// Establish a (most probably, secure) connection with the client.
@@ -381,6 +389,8 @@ struct Client<'a, S> {
params: &'a StartupMessageParams, params: &'a StartupMessageParams,
/// Unique connection ID. /// Unique connection ID.
session_id: uuid::Uuid, session_id: uuid::Uuid,
session: cancellation::Session<'a>,
} }
impl<'a, S> Client<'a, S> { impl<'a, S> Client<'a, S> {
@@ -390,28 +400,27 @@ impl<'a, S> Client<'a, S> {
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>, creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
params: &'a StartupMessageParams, params: &'a StartupMessageParams,
session_id: uuid::Uuid, session_id: uuid::Uuid,
session: cancellation::Session<'a>,
) -> Self { ) -> Self {
Self { Self {
stream, stream,
creds, creds,
params, params,
session_id, session_id,
session,
} }
} }
} }
impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> { impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
/// Let the client authenticate and connect to the designated compute node. /// Let the client authenticate and connect to the designated compute node.
async fn connect_to_db( async fn connect_to_db(self, allow_cleartext: bool) -> anyhow::Result<()> {
self,
session: cancellation::Session<'_>,
allow_cleartext: bool,
) -> anyhow::Result<()> {
let Self { let Self {
mut stream, mut stream,
mut creds, mut creds,
params, params,
session_id, session_id,
session,
} = self; } = self;
let extra = console::ConsoleReqExtra { let extra = console::ConsoleReqExtra {