diff --git a/proxy/src/batch.rs b/proxy/src/batch.rs index cf866ab9a3..6a82df5ae4 100644 --- a/proxy/src/batch.rs +++ b/proxy/src/batch.rs @@ -63,16 +63,19 @@ impl BatchQueue

{ } } + pub fn enqueue(&self, req: P::Req) -> (u64, oneshot::Receiver>) { + self.inner.lock_propagate_poison().register_job(req) + } + /// Perform a single request-response process, this may be batched internally. /// /// This function is not cancel safe. pub async fn call( &self, - req: P::Req, + id: u64, + mut rx: oneshot::Receiver>, cancelled: impl Future, ) -> Result> { - let (id, mut rx) = self.inner.lock_propagate_poison().register_job(req); - let mut cancelled = pin!(cancelled); let resp: Option> = loop { // try become the leader, or try wait for success. diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index 77062d3bb4..617cf21b2c 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -6,6 +6,7 @@ use std::time::Duration; use futures::FutureExt; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; +use metrics::MeasuredCounterPairGuard; use postgres_client::RawCancelToken; use postgres_client::tls::MakeTlsConnect; use redis::{Cmd, FromRedisValue, SetExpiry, SetOptions, Value}; @@ -23,7 +24,9 @@ use crate::context::RequestContext; use crate::control_plane::ControlPlaneApi; use crate::error::ReportableError; use crate::ext::LockExt; -use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind}; +use crate::metrics::{ + CancelChannelSizeGauge, CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind, +}; use crate::pqproto::CancelKeyData; use crate::rate_limiter::LeakyBucketRateLimiter; use crate::redis::keys::KeyPrefix; @@ -54,6 +57,24 @@ pub enum CancelKeyOp { }, } +impl CancelKeyOp { + fn redis_msg_kind(&self) -> RedisMsgKind { + match self { + CancelKeyOp::Store { .. } => RedisMsgKind::Set, + CancelKeyOp::Refresh { .. } => RedisMsgKind::Expire, + CancelKeyOp::Get { .. } => RedisMsgKind::Get, + CancelKeyOp::GetOld { .. } => RedisMsgKind::HGet, + } + } + + fn metric_guard(&self) -> MeasuredCounterPairGuard<'static, CancelChannelSizeGauge> { + Metrics::get() + .proxy + .cancel_channel_size + .guard(self.redis_msg_kind()) + } +} + #[derive(thiserror::Error, Debug, Clone)] pub enum PipelineError { #[error("could not send cmd to redis: {0}")] @@ -268,14 +289,11 @@ impl CancellationHandler { return Err(CancelError::InternalError); }; - let guard = Metrics::get() - .proxy - .cancel_channel_size - .guard(RedisMsgKind::Get); let op = CancelKeyOp::Get { key }; + let (id, rx) = tx.enqueue((op.metric_guard(), op)); let result = timeout( TIMEOUT, - tx.call((guard, op), std::future::pending::()), + tx.call(id, rx, std::future::pending::()), ) .await .map_err(|_| { @@ -293,14 +311,11 @@ impl CancellationHandler { && let Some(errcode) = err.code() && errcode == "WRONGTYPE" { - let guard = Metrics::get() - .proxy - .cancel_channel_size - .guard(RedisMsgKind::HGet); let op = CancelKeyOp::GetOld { key }; + let (id, rx) = tx.enqueue((op.metric_guard(), op)); timeout( TIMEOUT, - tx.call((guard, op), std::future::pending::()), + tx.call(id, rx, std::future::pending::()), ) .await .map_err(|_| { @@ -489,44 +504,36 @@ impl Session { let mut state = State::Set; loop { - let guard_op = match state { + let op = match state { State::Set => { - let guard = Metrics::get() - .proxy - .cancel_channel_size - .guard(RedisMsgKind::Set); - let op = CancelKeyOp::Store { - key: self.key, - value: closure_json.clone(), - expire: CANCEL_KEY_TTL, - }; tracing::debug!( src=%self.key, dest=?cancel_closure.cancel_token, "registering cancellation key" ); - (guard, op) + CancelKeyOp::Store { + key: self.key, + value: closure_json.clone(), + expire: CANCEL_KEY_TTL, + } } State::Refresh => { - let guard = Metrics::get() - .proxy - .cancel_channel_size - .guard(RedisMsgKind::Expire); - let op = CancelKeyOp::Refresh { - key: self.key, - expire: CANCEL_KEY_TTL, - }; tracing::debug!( src=%self.key, dest=?cancel_closure.cancel_token, "refreshing cancellation key" ); - (guard, op) + CancelKeyOp::Refresh { + key: self.key, + expire: CANCEL_KEY_TTL, + } } }; - match tx.call(guard_op, cancel.as_mut()).await { + let (id, rx) = tx.enqueue((op.metric_guard(), op)); + + match tx.call(id, rx, cancel.as_mut()).await { // SET returns OK Ok(Value::Okay) => { tracing::debug!(