[proxy]: BatchQueue::call is not cancel safe - make it directly cancellation aware (#12345)

## Problem

https://github.com/neondatabase/cloud/issues/30539

If the current leader cancels the `call` function, then it has removed
the jobs from the queue, but will never finish sending the responses.
Because of this, it is not cancellation safe.

## Summary of changes

Document these functions as not cancellation safe. Move cancellation of
the queued jobs into the queue itself.

## Alternatives considered

1. We could spawn the task that runs the batch, since that won't get
cancelled.
* This requires `fn call(self: Arc<Self>)` or `fn call(&'static self)`.
2. We could add another scopeguard and return the requests back to the
queue.
* This requires that requests are always retry safe, and also requires
requests to be `Clone`.
This commit is contained in:
Conrad Ludgate
2025-06-25 15:19:20 +01:00
committed by GitHub
parent 27ca1e21be
commit 517a3d0d86
3 changed files with 109 additions and 59 deletions

View File

@@ -6,7 +6,6 @@ use std::collections::BTreeMap;
use std::pin::pin;
use std::sync::Mutex;
use futures::future::Either;
use scopeguard::ScopeGuard;
use tokio::sync::oneshot::error::TryRecvError;
@@ -49,37 +48,67 @@ impl<P: QueueProcessing> BatchQueue<P> {
}
}
pub async fn call(&self, req: P::Req) -> P::Res {
/// Perform a single request-response process, this may be batched internally.
///
/// This function is not cancel safe.
pub async fn call<R>(
&self,
req: P::Req,
cancelled: impl Future<Output = R>,
) -> Result<P::Res, R> {
let (id, mut rx) = self.inner.lock_propagate_poison().register_job(req);
let guard = scopeguard::guard(id, move |id| {
let mut inner = self.inner.lock_propagate_poison();
if inner.queue.remove(&id).is_some() {
tracing::debug!("batched task cancelled before completion");
}
});
let mut cancelled = pin!(cancelled);
let resp = loop {
// try become the leader, or try wait for success.
let mut processor = match futures::future::select(rx, pin!(self.processor.lock())).await
{
// we got the resp.
Either::Left((resp, _)) => break resp.ok(),
// we are the leader.
Either::Right((p, rx_)) => {
rx = rx_;
p
}
let mut processor = tokio::select! {
// try become leader.
p = self.processor.lock() => p,
// wait for success.
resp = &mut rx => break resp.ok(),
// wait for cancellation.
cancel = cancelled.as_mut() => {
let mut inner = self.inner.lock_propagate_poison();
if inner.queue.remove(&id).is_some() {
tracing::warn!("batched task cancelled before completion");
}
return Err(cancel);
},
};
tracing::debug!(id, "batch: became leader");
let (reqs, resps) = self.inner.lock_propagate_poison().get_batch(&processor);
// snitch incase the task gets cancelled.
let cancel_safety = scopeguard::guard((), |()| {
if !std::thread::panicking() {
tracing::error!(
id,
"batch: leader cancelled, despite not being cancellation safe"
);
}
});
// apply a batch.
// if this is cancelled, jobs will not be completed and will panic.
let values = processor.apply(reqs).await;
// good: we didn't get cancelled.
ScopeGuard::into_inner(cancel_safety);
if values.len() != resps.len() {
tracing::error!(
"batch: invalid response size, expected={}, got={}",
resps.len(),
values.len()
);
}
// send response values.
for (tx, value) in std::iter::zip(resps, values) {
// sender hung up but that's fine.
drop(tx.send(value));
if tx.send(value).is_err() {
// receiver hung up but that's fine.
}
}
match rx.try_recv() {
@@ -98,10 +127,9 @@ impl<P: QueueProcessing> BatchQueue<P> {
}
};
// already removed.
ScopeGuard::into_inner(guard);
tracing::debug!(id, "batch: job completed");
resp.expect("no response found. batch processer should not panic")
Ok(resp.expect("no response found. batch processer should not panic"))
}
}
@@ -125,6 +153,8 @@ impl<P: QueueProcessing> BatchQueueInner<P> {
self.queue.insert(id, BatchJob { req, res: tx });
tracing::debug!(id, "batch: registered job in the queue");
(id, rx)
}
@@ -132,15 +162,19 @@ impl<P: QueueProcessing> BatchQueueInner<P> {
let batch_size = p.batch_size(self.queue.len());
let mut reqs = Vec::with_capacity(batch_size);
let mut resps = Vec::with_capacity(batch_size);
let mut ids = Vec::with_capacity(batch_size);
while reqs.len() < batch_size {
let Some((_, job)) = self.queue.pop_first() else {
let Some((id, job)) = self.queue.pop_first() else {
break;
};
reqs.push(job.req);
resps.push(job.res);
ids.push(id);
}
tracing::debug!(ids=?ids, "batch: acquired jobs");
(reqs, resps)
}
}