Compare commits

...

1 Commits

Author SHA1 Message Date
Heikki Linnakangas
c18b291244 Refactor cancellation Session to be more flexible.
Instead of with_session that calls a Future with the session, have a
more conventional constructor function, `new_session`, which returns a
Session. The session is automatically removed from the cancellation
map in Drop. This makes it nicer to use.
2023-03-13 11:27:03 +02:00
2 changed files with 42 additions and 34 deletions

View File

@@ -25,12 +25,11 @@ impl CancelMap {
cancel_closure.try_cancel_query().await
}
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
where
F: FnOnce(Session<'a>) -> R,
R: std::future::Future<Output = anyhow::Result<V>>,
{
/// Create a new session, with a new client-facing random cancellation key.
///
/// Use `enable_query_cancellation` to register the Postgres backend's cancellation
/// key with it.
pub fn new_session<'a>(&'a self) -> anyhow::Result<Session<'a>> {
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
// expose it and we don't want to do another roundtrip to query
// for it. The client will be able to notice that this is not the
@@ -44,17 +43,9 @@ impl CancelMap {
.write()
.try_insert(key, None)
.map_err(|_| anyhow!("query cancellation key already exists: {key}"))?;
// This will guarantee that the session gets dropped
// as soon as the future is finished.
scopeguard::defer! {
self.0.write().remove(&key);
info!("dropped query cancellation key {key}");
}
info!("registered new query cancellation key {key}");
let session = Session::new(key, self);
f(session).await
Ok(Session::new(key, self))
}
#[cfg(test)]
@@ -111,7 +102,7 @@ impl<'a> Session<'a> {
impl Session<'_> {
/// Store the cancel token for the given session.
/// This enables query cancellation in [`crate::proxy::handshake`].
pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
info!("enabling query cancellation for this session");
self.cancel_map
.0
@@ -122,6 +113,14 @@ impl Session<'_> {
}
}
impl<'a> Drop for Session<'a> {
fn drop(&mut self) {
let key = &self.key;
self.cancel_map.0.write().remove(key);
info!("dropped query cancellation key {key}");
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -132,14 +131,14 @@ mod tests {
static CANCEL_MAP: Lazy<CancelMap> = Lazy::new(Default::default);
let (tx, rx) = tokio::sync::oneshot::channel();
let task = tokio::spawn(CANCEL_MAP.with_session(|session| async move {
let session = CANCEL_MAP.new_session()?;
let task = tokio::spawn(async move {
assert!(CANCEL_MAP.contains(&session));
tx.send(()).expect("failed to send");
futures::future::pending::<()>().await; // sleep forever
Ok(())
}));
});
// Wait until the task has been spawned.
rx.await.context("failed to hear from the task")?;

View File

@@ -133,10 +133,14 @@ pub async fn handle_ws_client(
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let client = Client::new(stream, creds, &params, session_id);
cancel_map
.with_session(|session| client.connect_to_db(session, true))
.await
let client = Client::new(
stream,
creds,
&params,
session_id,
cancel_map.new_session()?,
);
client.connect_to_db(true).await
}
#[tracing::instrument(fields(session_id), skip_all)]
@@ -172,10 +176,14 @@ async fn handle_client(
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let client = Client::new(stream, creds, &params, session_id);
cancel_map
.with_session(|session| client.connect_to_db(session, false))
.await
let client = Client::new(
stream,
creds,
&params,
session_id,
cancel_map.new_session()?,
);
client.connect_to_db(false).await
}
/// Establish a (most probably, secure) connection with the client.
@@ -381,6 +389,8 @@ struct Client<'a, S> {
params: &'a StartupMessageParams,
/// Unique connection ID.
session_id: uuid::Uuid,
session: cancellation::Session<'a>,
}
impl<'a, S> Client<'a, S> {
@@ -390,28 +400,27 @@ impl<'a, S> Client<'a, S> {
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
params: &'a StartupMessageParams,
session_id: uuid::Uuid,
session: cancellation::Session<'a>,
) -> Self {
Self {
stream,
creds,
params,
session_id,
session,
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
/// Let the client authenticate and connect to the designated compute node.
async fn connect_to_db(
self,
session: cancellation::Session<'_>,
allow_cleartext: bool,
) -> anyhow::Result<()> {
async fn connect_to_db(self, allow_cleartext: bool) -> anyhow::Result<()> {
let Self {
mut stream,
mut creds,
params,
session_id,
session,
} = self;
let extra = console::ConsoleReqExtra {