[proxy]: BatchQueue::call is not cancel safe - make it directly cancellation aware (#12345)

## Problem

https://github.com/neondatabase/cloud/issues/30539

If the future returned by `call` is cancelled (dropped) while its caller is
the current batch leader, the leader has already removed the batched jobs
from the queue but will never finish sending their responses.
Because of this, `call` is not cancellation safe.

## Summary of changes

Document these functions as not cancellation safe. Move cancellation of
the queued jobs into the queue itself.

## Alternatives considered

1. We could spawn the task that runs the batch, since that won't get
   cancelled.
   * This requires `fn call(self: Arc<Self>)` or `fn call(&'static self)`.
2. We could add another scopeguard and return the requests back to the
   queue.
   * This requires that requests are always retry safe, and also requires
     requests to be `Clone`.
This commit is contained in:
Conrad Ludgate
2025-06-25 15:19:20 +01:00
committed by GitHub
parent 27ca1e21be
commit 517a3d0d86
3 changed files with 109 additions and 59 deletions

View File

@@ -6,7 +6,6 @@ use std::collections::BTreeMap;
use std::pin::pin; use std::pin::pin;
use std::sync::Mutex; use std::sync::Mutex;
use futures::future::Either;
use scopeguard::ScopeGuard; use scopeguard::ScopeGuard;
use tokio::sync::oneshot::error::TryRecvError; use tokio::sync::oneshot::error::TryRecvError;
@@ -49,37 +48,67 @@ impl<P: QueueProcessing> BatchQueue<P> {
} }
} }
pub async fn call(&self, req: P::Req) -> P::Res { /// Perform a single request-response process, this may be batched internally.
///
/// This function is not cancel safe.
pub async fn call<R>(
&self,
req: P::Req,
cancelled: impl Future<Output = R>,
) -> Result<P::Res, R> {
let (id, mut rx) = self.inner.lock_propagate_poison().register_job(req); let (id, mut rx) = self.inner.lock_propagate_poison().register_job(req);
let guard = scopeguard::guard(id, move |id| {
let mut inner = self.inner.lock_propagate_poison();
if inner.queue.remove(&id).is_some() {
tracing::debug!("batched task cancelled before completion");
}
});
let mut cancelled = pin!(cancelled);
let resp = loop { let resp = loop {
// try become the leader, or try wait for success. // try become the leader, or try wait for success.
let mut processor = match futures::future::select(rx, pin!(self.processor.lock())).await let mut processor = tokio::select! {
{ // try become leader.
// we got the resp. p = self.processor.lock() => p,
Either::Left((resp, _)) => break resp.ok(), // wait for success.
// we are the leader. resp = &mut rx => break resp.ok(),
Either::Right((p, rx_)) => { // wait for cancellation.
rx = rx_; cancel = cancelled.as_mut() => {
p let mut inner = self.inner.lock_propagate_poison();
} if inner.queue.remove(&id).is_some() {
tracing::warn!("batched task cancelled before completion");
}
return Err(cancel);
},
}; };
tracing::debug!(id, "batch: became leader");
let (reqs, resps) = self.inner.lock_propagate_poison().get_batch(&processor); let (reqs, resps) = self.inner.lock_propagate_poison().get_batch(&processor);
// snitch incase the task gets cancelled.
let cancel_safety = scopeguard::guard((), |()| {
if !std::thread::panicking() {
tracing::error!(
id,
"batch: leader cancelled, despite not being cancellation safe"
);
}
});
// apply a batch. // apply a batch.
// if this is cancelled, jobs will not be completed and will panic.
let values = processor.apply(reqs).await; let values = processor.apply(reqs).await;
// good: we didn't get cancelled.
ScopeGuard::into_inner(cancel_safety);
if values.len() != resps.len() {
tracing::error!(
"batch: invalid response size, expected={}, got={}",
resps.len(),
values.len()
);
}
// send response values. // send response values.
for (tx, value) in std::iter::zip(resps, values) { for (tx, value) in std::iter::zip(resps, values) {
// sender hung up but that's fine. if tx.send(value).is_err() {
drop(tx.send(value)); // receiver hung up but that's fine.
}
} }
match rx.try_recv() { match rx.try_recv() {
@@ -98,10 +127,9 @@ impl<P: QueueProcessing> BatchQueue<P> {
} }
}; };
// already removed. tracing::debug!(id, "batch: job completed");
ScopeGuard::into_inner(guard);
resp.expect("no response found. batch processer should not panic") Ok(resp.expect("no response found. batch processer should not panic"))
} }
} }
@@ -125,6 +153,8 @@ impl<P: QueueProcessing> BatchQueueInner<P> {
self.queue.insert(id, BatchJob { req, res: tx }); self.queue.insert(id, BatchJob { req, res: tx });
tracing::debug!(id, "batch: registered job in the queue");
(id, rx) (id, rx)
} }
@@ -132,15 +162,19 @@ impl<P: QueueProcessing> BatchQueueInner<P> {
let batch_size = p.batch_size(self.queue.len()); let batch_size = p.batch_size(self.queue.len());
let mut reqs = Vec::with_capacity(batch_size); let mut reqs = Vec::with_capacity(batch_size);
let mut resps = Vec::with_capacity(batch_size); let mut resps = Vec::with_capacity(batch_size);
let mut ids = Vec::with_capacity(batch_size);
while reqs.len() < batch_size { while reqs.len() < batch_size {
let Some((_, job)) = self.queue.pop_first() else { let Some((id, job)) = self.queue.pop_first() else {
break; break;
}; };
reqs.push(job.req); reqs.push(job.req);
resps.push(job.res); resps.push(job.res);
ids.push(id);
} }
tracing::debug!(ids=?ids, "batch: acquired jobs");
(reqs, resps) (reqs, resps)
} }
} }

View File

@@ -1,5 +1,6 @@
use std::convert::Infallible; use std::convert::Infallible;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::pin::pin;
use std::sync::{Arc, OnceLock}; use std::sync::{Arc, OnceLock};
use std::time::Duration; use std::time::Duration;
@@ -98,7 +99,6 @@ impl Pipeline {
impl CancelKeyOp { impl CancelKeyOp {
fn register(&self, pipe: &mut Pipeline) { fn register(&self, pipe: &mut Pipeline) {
#[allow(clippy::used_underscore_binding)]
match self { match self {
CancelKeyOp::StoreCancelKey { key, value, expire } => { CancelKeyOp::StoreCancelKey { key, value, expire } => {
let key = KeyPrefix::Cancel(*key).build_redis_key(); let key = KeyPrefix::Cancel(*key).build_redis_key();
@@ -224,6 +224,7 @@ impl CancellationHandler {
} }
} }
/// This is not cancel safe
async fn get_cancel_key( async fn get_cancel_key(
&self, &self,
key: CancelKeyData, key: CancelKeyData,
@@ -240,16 +241,21 @@ impl CancellationHandler {
}; };
const TIMEOUT: Duration = Duration::from_secs(5); const TIMEOUT: Duration = Duration::from_secs(5);
let result = timeout(TIMEOUT, tx.call((guard, op))) let result = timeout(
.await TIMEOUT,
.map_err(|_| { tx.call((guard, op), std::future::pending::<Infallible>()),
tracing::warn!("timed out waiting to receive GetCancelData response"); )
CancelError::RateLimit .await
})? .map_err(|_| {
.map_err(|e| { tracing::warn!("timed out waiting to receive GetCancelData response");
tracing::warn!("failed to receive GetCancelData response: {e}"); CancelError::RateLimit
CancelError::InternalError })?
})?; // cannot be cancelled
.unwrap_or_else(|x| match x {})
.map_err(|e| {
tracing::warn!("failed to receive GetCancelData response: {e}");
CancelError::InternalError
})?;
let cancel_state_str = String::from_owned_redis_value(result).map_err(|e| { let cancel_state_str = String::from_owned_redis_value(result).map_err(|e| {
tracing::warn!("failed to receive GetCancelData response: {e}"); tracing::warn!("failed to receive GetCancelData response: {e}");
@@ -271,6 +277,8 @@ impl CancellationHandler {
/// Will fetch IP allowlist internally. /// Will fetch IP allowlist internally.
/// ///
/// return Result primarily for tests /// return Result primarily for tests
///
/// This is not cancel safe
pub(crate) async fn cancel_session<T: ControlPlaneApi>( pub(crate) async fn cancel_session<T: ControlPlaneApi>(
&self, &self,
key: CancelKeyData, key: CancelKeyData,
@@ -394,6 +402,8 @@ impl Session {
/// Ensure the cancel key is continously refreshed, /// Ensure the cancel key is continously refreshed,
/// but stop when the channel is dropped. /// but stop when the channel is dropped.
///
/// This is not cancel safe
pub(crate) async fn maintain_cancel_key( pub(crate) async fn maintain_cancel_key(
&self, &self,
session_id: uuid::Uuid, session_id: uuid::Uuid,
@@ -401,27 +411,6 @@ impl Session {
cancel_closure: &CancelClosure, cancel_closure: &CancelClosure,
compute_config: &ComputeConfig, compute_config: &ComputeConfig,
) { ) {
futures::future::select(
std::pin::pin!(self.maintain_redis_cancel_key(cancel_closure)),
cancel,
)
.await;
if let Err(err) = cancel_closure
.try_cancel_query(compute_config)
.boxed()
.await
{
tracing::warn!(
?session_id,
?err,
"could not cancel the query in the database"
);
}
}
// Ensure the cancel key is continously refreshed.
async fn maintain_redis_cancel_key(&self, cancel_closure: &CancelClosure) -> ! {
let Some(tx) = self.cancellation_handler.tx.get() else { let Some(tx) = self.cancellation_handler.tx.get() else {
tracing::warn!("cancellation handler is not available"); tracing::warn!("cancellation handler is not available");
// don't exit, as we only want to exit if cancelled externally. // don't exit, as we only want to exit if cancelled externally.
@@ -432,6 +421,8 @@ impl Session {
.expect("serialising to json string should not fail") .expect("serialising to json string should not fail")
.into_boxed_str(); .into_boxed_str();
let mut cancel = pin!(cancel);
loop { loop {
let guard = Metrics::get() let guard = Metrics::get()
.proxy .proxy
@@ -449,9 +440,35 @@ impl Session {
"registering cancellation key" "registering cancellation key"
); );
if tx.call((guard, op)).await.is_ok() { match tx.call((guard, op), cancel.as_mut()).await {
tokio::time::sleep(CANCEL_KEY_REFRESH).await; Ok(Ok(_)) => {
tracing::debug!(
src=%self.key,
dest=?cancel_closure.cancel_token,
"registered cancellation key"
);
// wait before continuing.
tokio::time::sleep(CANCEL_KEY_REFRESH).await;
}
// retry immediately.
Ok(Err(error)) => {
tracing::warn!(?error, "error registering cancellation key");
}
Err(Err(_cancelled)) => break,
} }
} }
if let Err(err) = cancel_closure
.try_cancel_query(compute_config)
.boxed()
.await
{
tracing::warn!(
?session_id,
?err,
"could not cancel the query in the database"
);
}
} }
} }

View File

@@ -23,9 +23,8 @@ impl KeyPrefix {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::pqproto::id_to_cancel_key;
use super::*; use super::*;
use crate::pqproto::id_to_cancel_key;
#[test] #[test]
fn test_build_redis_key() { fn test_build_redis_key() {