[proxy]: BatchQueue::call is not cancel safe - make it directly cancellation aware (#12345)

## Problem

https://github.com/neondatabase/cloud/issues/30539

If the current leader cancels the `call` function, then it has removed
the jobs from the queue, but will never finish sending the responses.
Because of this, it is not cancellation safe.

## Summary of changes

Document these functions as not cancellation safe. Move cancellation of
the queued jobs into the queue itself.

## Alternatives considered

1. We could spawn the task that runs the batch, since that won't get
cancelled.
* This requires `fn call(self: Arc<Self>)` or `fn call(&'static self)`.
2. We could add another scopeguard and return the requests back to the
queue.
* This requires that requests are always retry safe, and also requires
requests to be `Clone`.
This commit is contained in:
Conrad Ludgate
2025-06-25 15:19:20 +01:00
committed by GitHub
parent 27ca1e21be
commit 517a3d0d86
3 changed files with 109 additions and 59 deletions

View File

@@ -6,7 +6,6 @@ use std::collections::BTreeMap;
use std::pin::pin;
use std::sync::Mutex;
use futures::future::Either;
use scopeguard::ScopeGuard;
use tokio::sync::oneshot::error::TryRecvError;
@@ -49,37 +48,67 @@ impl<P: QueueProcessing> BatchQueue<P> {
}
}
pub async fn call(&self, req: P::Req) -> P::Res {
/// Perform a single request-response process, this may be batched internally.
///
/// This function is not cancel safe.
pub async fn call<R>(
&self,
req: P::Req,
cancelled: impl Future<Output = R>,
) -> Result<P::Res, R> {
let (id, mut rx) = self.inner.lock_propagate_poison().register_job(req);
let guard = scopeguard::guard(id, move |id| {
let mut inner = self.inner.lock_propagate_poison();
if inner.queue.remove(&id).is_some() {
tracing::debug!("batched task cancelled before completion");
}
});
let mut cancelled = pin!(cancelled);
let resp = loop {
// try become the leader, or try wait for success.
let mut processor = match futures::future::select(rx, pin!(self.processor.lock())).await
{
// we got the resp.
Either::Left((resp, _)) => break resp.ok(),
// we are the leader.
Either::Right((p, rx_)) => {
rx = rx_;
p
}
let mut processor = tokio::select! {
// try become leader.
p = self.processor.lock() => p,
// wait for success.
resp = &mut rx => break resp.ok(),
// wait for cancellation.
cancel = cancelled.as_mut() => {
let mut inner = self.inner.lock_propagate_poison();
if inner.queue.remove(&id).is_some() {
tracing::warn!("batched task cancelled before completion");
}
return Err(cancel);
},
};
tracing::debug!(id, "batch: became leader");
let (reqs, resps) = self.inner.lock_propagate_poison().get_batch(&processor);
// snitch incase the task gets cancelled.
let cancel_safety = scopeguard::guard((), |()| {
if !std::thread::panicking() {
tracing::error!(
id,
"batch: leader cancelled, despite not being cancellation safe"
);
}
});
// apply a batch.
// if this is cancelled, jobs will not be completed and will panic.
let values = processor.apply(reqs).await;
// good: we didn't get cancelled.
ScopeGuard::into_inner(cancel_safety);
if values.len() != resps.len() {
tracing::error!(
"batch: invalid response size, expected={}, got={}",
resps.len(),
values.len()
);
}
// send response values.
for (tx, value) in std::iter::zip(resps, values) {
// sender hung up but that's fine.
drop(tx.send(value));
if tx.send(value).is_err() {
// receiver hung up but that's fine.
}
}
match rx.try_recv() {
@@ -98,10 +127,9 @@ impl<P: QueueProcessing> BatchQueue<P> {
}
};
// already removed.
ScopeGuard::into_inner(guard);
tracing::debug!(id, "batch: job completed");
resp.expect("no response found. batch processer should not panic")
Ok(resp.expect("no response found. batch processer should not panic"))
}
}
@@ -125,6 +153,8 @@ impl<P: QueueProcessing> BatchQueueInner<P> {
self.queue.insert(id, BatchJob { req, res: tx });
tracing::debug!(id, "batch: registered job in the queue");
(id, rx)
}
@@ -132,15 +162,19 @@ impl<P: QueueProcessing> BatchQueueInner<P> {
let batch_size = p.batch_size(self.queue.len());
let mut reqs = Vec::with_capacity(batch_size);
let mut resps = Vec::with_capacity(batch_size);
let mut ids = Vec::with_capacity(batch_size);
while reqs.len() < batch_size {
let Some((_, job)) = self.queue.pop_first() else {
let Some((id, job)) = self.queue.pop_first() else {
break;
};
reqs.push(job.req);
resps.push(job.res);
ids.push(id);
}
tracing::debug!(ids=?ids, "batch: acquired jobs");
(reqs, resps)
}
}

View File

@@ -1,5 +1,6 @@
use std::convert::Infallible;
use std::net::{IpAddr, SocketAddr};
use std::pin::pin;
use std::sync::{Arc, OnceLock};
use std::time::Duration;
@@ -98,7 +99,6 @@ impl Pipeline {
impl CancelKeyOp {
fn register(&self, pipe: &mut Pipeline) {
#[allow(clippy::used_underscore_binding)]
match self {
CancelKeyOp::StoreCancelKey { key, value, expire } => {
let key = KeyPrefix::Cancel(*key).build_redis_key();
@@ -224,6 +224,7 @@ impl CancellationHandler {
}
}
/// This is not cancel safe
async fn get_cancel_key(
&self,
key: CancelKeyData,
@@ -240,16 +241,21 @@ impl CancellationHandler {
};
const TIMEOUT: Duration = Duration::from_secs(5);
let result = timeout(TIMEOUT, tx.call((guard, op)))
.await
.map_err(|_| {
tracing::warn!("timed out waiting to receive GetCancelData response");
CancelError::RateLimit
})?
.map_err(|e| {
tracing::warn!("failed to receive GetCancelData response: {e}");
CancelError::InternalError
})?;
let result = timeout(
TIMEOUT,
tx.call((guard, op), std::future::pending::<Infallible>()),
)
.await
.map_err(|_| {
tracing::warn!("timed out waiting to receive GetCancelData response");
CancelError::RateLimit
})?
// cannot be cancelled
.unwrap_or_else(|x| match x {})
.map_err(|e| {
tracing::warn!("failed to receive GetCancelData response: {e}");
CancelError::InternalError
})?;
let cancel_state_str = String::from_owned_redis_value(result).map_err(|e| {
tracing::warn!("failed to receive GetCancelData response: {e}");
@@ -271,6 +277,8 @@ impl CancellationHandler {
/// Will fetch IP allowlist internally.
///
/// return Result primarily for tests
///
/// This is not cancel safe
pub(crate) async fn cancel_session<T: ControlPlaneApi>(
&self,
key: CancelKeyData,
@@ -394,6 +402,8 @@ impl Session {
/// Ensure the cancel key is continously refreshed,
/// but stop when the channel is dropped.
///
/// This is not cancel safe
pub(crate) async fn maintain_cancel_key(
&self,
session_id: uuid::Uuid,
@@ -401,27 +411,6 @@ impl Session {
cancel_closure: &CancelClosure,
compute_config: &ComputeConfig,
) {
futures::future::select(
std::pin::pin!(self.maintain_redis_cancel_key(cancel_closure)),
cancel,
)
.await;
if let Err(err) = cancel_closure
.try_cancel_query(compute_config)
.boxed()
.await
{
tracing::warn!(
?session_id,
?err,
"could not cancel the query in the database"
);
}
}
// Ensure the cancel key is continously refreshed.
async fn maintain_redis_cancel_key(&self, cancel_closure: &CancelClosure) -> ! {
let Some(tx) = self.cancellation_handler.tx.get() else {
tracing::warn!("cancellation handler is not available");
// don't exit, as we only want to exit if cancelled externally.
@@ -432,6 +421,8 @@ impl Session {
.expect("serialising to json string should not fail")
.into_boxed_str();
let mut cancel = pin!(cancel);
loop {
let guard = Metrics::get()
.proxy
@@ -449,9 +440,35 @@ impl Session {
"registering cancellation key"
);
if tx.call((guard, op)).await.is_ok() {
tokio::time::sleep(CANCEL_KEY_REFRESH).await;
match tx.call((guard, op), cancel.as_mut()).await {
Ok(Ok(_)) => {
tracing::debug!(
src=%self.key,
dest=?cancel_closure.cancel_token,
"registered cancellation key"
);
// wait before continuing.
tokio::time::sleep(CANCEL_KEY_REFRESH).await;
}
// retry immediately.
Ok(Err(error)) => {
tracing::warn!(?error, "error registering cancellation key");
}
Err(Err(_cancelled)) => break,
}
}
if let Err(err) = cancel_closure
.try_cancel_query(compute_config)
.boxed()
.await
{
tracing::warn!(
?session_id,
?err,
"could not cancel the query in the database"
);
}
}
}

View File

@@ -23,9 +23,8 @@ impl KeyPrefix {
#[cfg(test)]
mod tests {
use crate::pqproto::id_to_cancel_key;
use super::*;
use crate::pqproto::id_to_cancel_key;
#[test]
fn test_build_redis_key() {