From 517a3d0d86303ee3084ffd99ffac7042e2eca5c2 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Wed, 25 Jun 2025 15:19:20 +0100 Subject: [PATCH] [proxy]: BatchQueue::call is not cancel safe - make it directly cancellation aware (#12345) ## Problem https://github.com/neondatabase/cloud/issues/30539 If the current leader cancels the `call` function, then it has removed the jobs from the queue, but will never finish sending the responses. Because of this, it is not cancellation safe. ## Summary of changes Document these functions as not cancellation safe. Move cancellation of the queued jobs into the queue itself. ## Alternatives considered 1. We could spawn the task that runs the batch, since that won't get cancelled. * This requires `fn call(self: Arc)` or `fn call(&'static self)`. 2. We could add another scopeguard and return the requests back to the queue. * This requires that requests are always retry safe, and also requires requests to be `Clone`. --- proxy/src/batch.rs | 80 +++++++++++++++++++++++++----------- proxy/src/cancellation.rs | 85 +++++++++++++++++++++++---------------- proxy/src/redis/keys.rs | 3 +- 3 files changed, 109 insertions(+), 59 deletions(-) diff --git a/proxy/src/batch.rs b/proxy/src/batch.rs index 61bdf2b747..33e08797f2 100644 --- a/proxy/src/batch.rs +++ b/proxy/src/batch.rs @@ -6,7 +6,6 @@ use std::collections::BTreeMap; use std::pin::pin; use std::sync::Mutex; -use futures::future::Either; use scopeguard::ScopeGuard; use tokio::sync::oneshot::error::TryRecvError; @@ -49,37 +48,67 @@ impl BatchQueue

{ } } - pub async fn call(&self, req: P::Req) -> P::Res { + /// Perform a single request-response process, this may be batched internally. + /// + /// This function is not cancel safe. + pub async fn call( + &self, + req: P::Req, + cancelled: impl Future, + ) -> Result { let (id, mut rx) = self.inner.lock_propagate_poison().register_job(req); - let guard = scopeguard::guard(id, move |id| { - let mut inner = self.inner.lock_propagate_poison(); - if inner.queue.remove(&id).is_some() { - tracing::debug!("batched task cancelled before completion"); - } - }); + let mut cancelled = pin!(cancelled); let resp = loop { // try become the leader, or try wait for success. - let mut processor = match futures::future::select(rx, pin!(self.processor.lock())).await - { - // we got the resp. - Either::Left((resp, _)) => break resp.ok(), - // we are the leader. - Either::Right((p, rx_)) => { - rx = rx_; - p - } + let mut processor = tokio::select! { + // try become leader. + p = self.processor.lock() => p, + // wait for success. + resp = &mut rx => break resp.ok(), + // wait for cancellation. + cancel = cancelled.as_mut() => { + let mut inner = self.inner.lock_propagate_poison(); + if inner.queue.remove(&id).is_some() { + tracing::warn!("batched task cancelled before completion"); + } + return Err(cancel); + }, }; + tracing::debug!(id, "batch: became leader"); let (reqs, resps) = self.inner.lock_propagate_poison().get_batch(&processor); + // snitch incase the task gets cancelled. + let cancel_safety = scopeguard::guard((), |()| { + if !std::thread::panicking() { + tracing::error!( + id, + "batch: leader cancelled, despite not being cancellation safe" + ); + } + }); + // apply a batch. + // if this is cancelled, jobs will not be completed and will panic. let values = processor.apply(reqs).await; + // good: we didn't get cancelled. + ScopeGuard::into_inner(cancel_safety); + + if values.len() != resps.len() { + tracing::error!( + "batch: invalid response size, expected={}, got={}", + resps.len(), + values.len() + ); + } + // send response values. for (tx, value) in std::iter::zip(resps, values) { - // sender hung up but that's fine. - drop(tx.send(value)); + if tx.send(value).is_err() { + // receiver hung up but that's fine. + } } match rx.try_recv() { @@ -98,10 +127,9 @@ impl BatchQueue

{ } }; - // already removed. - ScopeGuard::into_inner(guard); + tracing::debug!(id, "batch: job completed"); - resp.expect("no response found. batch processer should not panic") + Ok(resp.expect("no response found. batch processer should not panic")) } } @@ -125,6 +153,8 @@ impl BatchQueueInner

{ self.queue.insert(id, BatchJob { req, res: tx }); + tracing::debug!(id, "batch: registered job in the queue"); + (id, rx) } @@ -132,15 +162,19 @@ impl BatchQueueInner

{ let batch_size = p.batch_size(self.queue.len()); let mut reqs = Vec::with_capacity(batch_size); let mut resps = Vec::with_capacity(batch_size); + let mut ids = Vec::with_capacity(batch_size); while reqs.len() < batch_size { - let Some((_, job)) = self.queue.pop_first() else { + let Some((id, job)) = self.queue.pop_first() else { break; }; reqs.push(job.req); resps.push(job.res); + ids.push(id); } + tracing::debug!(ids=?ids, "batch: acquired jobs"); + (reqs, resps) } } diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index 036f36c7f6..ffc0cf43f1 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -1,5 +1,6 @@ use std::convert::Infallible; use std::net::{IpAddr, SocketAddr}; +use std::pin::pin; use std::sync::{Arc, OnceLock}; use std::time::Duration; @@ -98,7 +99,6 @@ impl Pipeline { impl CancelKeyOp { fn register(&self, pipe: &mut Pipeline) { - #[allow(clippy::used_underscore_binding)] match self { CancelKeyOp::StoreCancelKey { key, value, expire } => { let key = KeyPrefix::Cancel(*key).build_redis_key(); @@ -224,6 +224,7 @@ impl CancellationHandler { } } + /// This is not cancel safe async fn get_cancel_key( &self, key: CancelKeyData, @@ -240,16 +241,21 @@ impl CancellationHandler { }; const TIMEOUT: Duration = Duration::from_secs(5); - let result = timeout(TIMEOUT, tx.call((guard, op))) - .await - .map_err(|_| { - tracing::warn!("timed out waiting to receive GetCancelData response"); - CancelError::RateLimit - })? - .map_err(|e| { - tracing::warn!("failed to receive GetCancelData response: {e}"); - CancelError::InternalError - })?; + let result = timeout( + TIMEOUT, + tx.call((guard, op), std::future::pending::()), + ) + .await + .map_err(|_| { + tracing::warn!("timed out waiting to receive GetCancelData response"); + CancelError::RateLimit + })? + // cannot be cancelled + .unwrap_or_else(|x| match x {}) + .map_err(|e| { + tracing::warn!("failed to receive GetCancelData response: {e}"); + CancelError::InternalError + })?; let cancel_state_str = String::from_owned_redis_value(result).map_err(|e| { tracing::warn!("failed to receive GetCancelData response: {e}"); @@ -271,6 +277,8 @@ impl CancellationHandler { /// Will fetch IP allowlist internally. /// /// return Result primarily for tests + /// + /// This is not cancel safe pub(crate) async fn cancel_session( &self, key: CancelKeyData, @@ -394,6 +402,8 @@ impl Session { /// Ensure the cancel key is continously refreshed, /// but stop when the channel is dropped. + /// + /// This is not cancel safe pub(crate) async fn maintain_cancel_key( &self, session_id: uuid::Uuid, @@ -401,27 +411,6 @@ impl Session { cancel_closure: &CancelClosure, compute_config: &ComputeConfig, ) { - futures::future::select( - std::pin::pin!(self.maintain_redis_cancel_key(cancel_closure)), - cancel, - ) - .await; - - if let Err(err) = cancel_closure - .try_cancel_query(compute_config) - .boxed() - .await - { - tracing::warn!( - ?session_id, - ?err, - "could not cancel the query in the database" - ); - } - } - - // Ensure the cancel key is continously refreshed. - async fn maintain_redis_cancel_key(&self, cancel_closure: &CancelClosure) -> ! { let Some(tx) = self.cancellation_handler.tx.get() else { tracing::warn!("cancellation handler is not available"); // don't exit, as we only want to exit if cancelled externally. @@ -432,6 +421,8 @@ impl Session { .expect("serialising to json string should not fail") .into_boxed_str(); + let mut cancel = pin!(cancel); + loop { let guard = Metrics::get() .proxy @@ -449,9 +440,35 @@ impl Session { "registering cancellation key" ); - if tx.call((guard, op)).await.is_ok() { - tokio::time::sleep(CANCEL_KEY_REFRESH).await; + match tx.call((guard, op), cancel.as_mut()).await { + Ok(Ok(_)) => { + tracing::debug!( + src=%self.key, + dest=?cancel_closure.cancel_token, + "registered cancellation key" + ); + + // wait before continuing. + tokio::time::sleep(CANCEL_KEY_REFRESH).await; + } + // retry immediately. + Ok(Err(error)) => { + tracing::warn!(?error, "error registering cancellation key"); + } + Err(Err(_cancelled)) => break, } } + + if let Err(err) = cancel_closure + .try_cancel_query(compute_config) + .boxed() + .await + { + tracing::warn!( + ?session_id, + ?err, + "could not cancel the query in the database" + ); + } } } diff --git a/proxy/src/redis/keys.rs b/proxy/src/redis/keys.rs index b453e6851c..ffb7bc876b 100644 --- a/proxy/src/redis/keys.rs +++ b/proxy/src/redis/keys.rs @@ -23,9 +23,8 @@ impl KeyPrefix { #[cfg(test)] mod tests { - use crate::pqproto::id_to_cancel_key; - use super::*; + use crate::pqproto::id_to_cancel_key; #[test] fn test_build_redis_key() {