From 946e971df8550ffd56c31584bc1cd372ab88599c Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Sat, 12 Apr 2025 10:16:22 +0100 Subject: [PATCH] feat(proxy): add batching to cancellation queue processing (#10607) Add batching to the redis queue, which allows us to clear it out quicker should it slow down temporarily. --- proxy/src/binary/proxy.rs | 9 +- proxy/src/cancellation.rs | 220 +++++++++++++++++++++++++++----------- proxy/src/redis/kv_ops.rs | 172 +++++------------------------ 3 files changed, 192 insertions(+), 209 deletions(-) diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 62fdc18207..e03f2f33d9 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -509,7 +509,14 @@ pub async fn run() -> anyhow::Result<()> { if let Some(mut redis_kv_client) = redis_kv_client { maintenance_tasks.spawn(async move { redis_kv_client.try_connect().await?; - handle_cancel_messages(&mut redis_kv_client, rx_cancel).await + handle_cancel_messages(&mut redis_kv_client, rx_cancel).await?; + + drop(redis_kv_client); + + // `handle_cancel_messages` was terminated due to the tx_cancel + // being dropped. this is not worthy of an error, and this task can only return `Err`, + // so let's wait forever instead. + std::future::pending().await }); } diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index d6a7406f67..c5ba04eb8c 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -1,16 +1,17 @@ -use std::convert::Infallible; use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; +use anyhow::{Context, anyhow}; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use postgres_client::CancelToken; use postgres_client::tls::MakeTlsConnect; use pq_proto::CancelKeyData; +use redis::{FromRedisValue, Pipeline, Value, pipe}; use serde::{Deserialize, Serialize}; use thiserror::Error; use tokio::net::TcpStream; use tokio::sync::{mpsc, oneshot}; -use tracing::{debug, info}; +use tracing::{debug, info, warn}; use crate::auth::backend::ComputeUserInfo; use crate::auth::{AuthError, check_peer_addr_is_in_list}; @@ -30,6 +31,7 @@ type IpSubnetKey = IpNet; const CANCEL_KEY_TTL: i64 = 1_209_600; // 2 weeks cancellation key expire time const REDIS_SEND_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(10); +const BATCH_SIZE: usize = 8; // Message types for sending through mpsc channel pub enum CancelKeyOp { @@ -54,78 +56,168 @@ pub enum CancelKeyOp { }, } +impl CancelKeyOp { + fn register(self, pipe: &mut Pipeline) -> Option { + #[allow(clippy::used_underscore_binding)] + match self { + CancelKeyOp::StoreCancelKey { + key, + field, + value, + resp_tx, + _guard, + expire, + } => { + pipe.hset(&key, field, value); + pipe.expire(key, expire); + let resp_tx = resp_tx?; + Some(CancelReplyOp::StoreCancelKey { resp_tx, _guard }) + } + CancelKeyOp::GetCancelData { + key, + resp_tx, + _guard, + } => { + pipe.hgetall(key); + Some(CancelReplyOp::GetCancelData { resp_tx, _guard }) + } + CancelKeyOp::RemoveCancelKey { + key, + field, + resp_tx, + _guard, + } => { + pipe.hdel(key, field); + let resp_tx = resp_tx?; + Some(CancelReplyOp::RemoveCancelKey { resp_tx, _guard }) + } + } + } +} + +// Message types for sending through mpsc channel +pub enum CancelReplyOp { + StoreCancelKey { + resp_tx: oneshot::Sender>, + _guard: CancelChannelSizeGuard<'static>, + }, + GetCancelData { + resp_tx: oneshot::Sender>>, + _guard: CancelChannelSizeGuard<'static>, + }, + RemoveCancelKey { + resp_tx: oneshot::Sender>, + _guard: CancelChannelSizeGuard<'static>, + }, +} + +impl 
CancelReplyOp { + fn send_err(self, e: anyhow::Error) { + match self { + CancelReplyOp::StoreCancelKey { resp_tx, _guard } => { + resp_tx + .send(Err(e)) + .inspect_err(|_| tracing::debug!("could not send reply")) + .ok(); + } + CancelReplyOp::GetCancelData { resp_tx, _guard } => { + resp_tx + .send(Err(e)) + .inspect_err(|_| tracing::debug!("could not send reply")) + .ok(); + } + CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => { + resp_tx + .send(Err(e)) + .inspect_err(|_| tracing::debug!("could not send reply")) + .ok(); + } + } + } + + fn send_value(self, v: redis::Value) { + match self { + CancelReplyOp::StoreCancelKey { resp_tx, _guard } => { + let send = + FromRedisValue::from_owned_redis_value(v).context("could not parse value"); + resp_tx + .send(send) + .inspect_err(|_| tracing::debug!("could not send reply")) + .ok(); + } + CancelReplyOp::GetCancelData { resp_tx, _guard } => { + let send = + FromRedisValue::from_owned_redis_value(v).context("could not parse value"); + resp_tx + .send(send) + .inspect_err(|_| tracing::debug!("could not send reply")) + .ok(); + } + CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => { + let send = + FromRedisValue::from_owned_redis_value(v).context("could not parse value"); + resp_tx + .send(send) + .inspect_err(|_| tracing::debug!("could not send reply")) + .ok(); + } + } + } +} + // Running as a separate task to accept messages through the rx channel -// In case of problems with RTT: switch to recv_many() + redis pipeline pub async fn handle_cancel_messages( client: &mut RedisKVClient, mut rx: mpsc::Receiver, -) -> anyhow::Result { +) -> anyhow::Result<()> { + let mut batch = Vec::new(); + let mut replies = vec![]; + loop { - if let Some(msg) = rx.recv().await { - match msg { - CancelKeyOp::StoreCancelKey { - key, - field, - value, - resp_tx, - _guard, - expire, - } => { - let res = client.hset(&key, field, value).await; - if let Some(resp_tx) = resp_tx { - if res.is_ok() { - resp_tx - .send(client.expire(key, expire).await) - .inspect_err(|e| { - tracing::debug!( - "failed to send StoreCancelKey response: {:?}", - e - ); - }) - .ok(); - } else { - resp_tx - .send(res) - .inspect_err(|e| { - tracing::debug!( - "failed to send StoreCancelKey response: {:?}", - e - ); - }) - .ok(); - } - } else if res.is_ok() { - drop(client.expire(key, expire).await); - } else { - tracing::warn!("failed to store cancel key: {:?}", res); - } + if rx.recv_many(&mut batch, BATCH_SIZE).await == 0 { + warn!("shutting down cancellation queue"); + break Ok(()); + } + + let batch_size = batch.len(); + debug!(batch_size, "running cancellation jobs"); + + let mut pipe = pipe(); + for msg in batch.drain(..) { + if let Some(reply) = msg.register(&mut pipe) { + replies.push(reply); + } else { + pipe.ignore(); + } + } + + let responses = replies.len(); + + match client.query(pipe).await { + // for each reply, we expect that many values. + Ok(Value::Array(values)) if values.len() == responses => { + debug!( + batch_size, + responses, "successfully completed cancellation jobs", + ); + for (value, reply) in std::iter::zip(values, replies.drain(..)) { + reply.send_value(value); } - CancelKeyOp::GetCancelData { - key, - resp_tx, - _guard, - } => { - drop(resp_tx.send(client.hget_all(key).await)); + } + Ok(value) => { + debug!(?value, "unexpected redis return value"); + for reply in replies.drain(..) 
{ + reply.send_err(anyhow!("incorrect response type from redis")); } - CancelKeyOp::RemoveCancelKey { - key, - field, - resp_tx, - _guard, - } => { - if let Some(resp_tx) = resp_tx { - resp_tx - .send(client.hdel(key, field).await) - .inspect_err(|e| { - tracing::debug!("failed to send StoreCancelKey response: {:?}", e); - }) - .ok(); - } else { - drop(client.hdel(key, field).await); - } + } + Err(err) => { + for reply in replies.drain(..) { + reply.send_err(anyhow!("could not send cmd to redis: {err}")); } } } + + replies.clear(); } } diff --git a/proxy/src/redis/kv_ops.rs b/proxy/src/redis/kv_ops.rs index 3689bf7ae2..aa627b29a6 100644 --- a/proxy/src/redis/kv_ops.rs +++ b/proxy/src/redis/kv_ops.rs @@ -1,4 +1,5 @@ -use redis::{AsyncCommands, ToRedisArgs}; +use redis::aio::ConnectionLike; +use redis::{Cmd, FromRedisValue, Pipeline, RedisResult}; use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo}; @@ -8,6 +9,23 @@ pub struct RedisKVClient { limiter: GlobalRateLimiter, } +#[allow(async_fn_in_trait)] +pub trait Queryable { + async fn query(&self, conn: &mut impl ConnectionLike) -> RedisResult; +} + +impl Queryable for Pipeline { + async fn query(&self, conn: &mut impl ConnectionLike) -> RedisResult { + self.query_async(conn).await + } +} + +impl Queryable for Cmd { + async fn query(&self, conn: &mut impl ConnectionLike) -> RedisResult { + self.query_async(conn).await + } +} + impl RedisKVClient { pub fn new(client: ConnectionWithCredentialsProvider, info: &'static [RateBucketInfo]) -> Self { Self { @@ -27,158 +45,24 @@ impl RedisKVClient { Ok(()) } - pub(crate) async fn hset(&mut self, key: K, field: F, value: V) -> anyhow::Result<()> - where - K: ToRedisArgs + Send + Sync, - F: ToRedisArgs + Send + Sync, - V: ToRedisArgs + Send + Sync, - { - if !self.limiter.check() { - tracing::info!("Rate limit exceeded. Skipping hset"); - return Err(anyhow::anyhow!("Rate limit exceeded")); - } - - match self.client.hset(&key, &field, &value).await { - Ok(()) => return Ok(()), - Err(e) => { - tracing::error!("failed to set a key-value pair: {e}"); - } - } - - tracing::info!("Redis client is disconnected. Reconnectiong..."); - self.try_connect().await?; - self.client - .hset(key, field, value) - .await - .map_err(anyhow::Error::new) - } - - #[allow(dead_code)] - pub(crate) async fn hset_multiple( + pub(crate) async fn query( &mut self, - key: &str, - items: &[(K, V)], - ) -> anyhow::Result<()> - where - K: ToRedisArgs + Send + Sync, - V: ToRedisArgs + Send + Sync, - { + q: impl Queryable, + ) -> anyhow::Result { if !self.limiter.check() { - tracing::info!("Rate limit exceeded. Skipping hset_multiple"); + tracing::info!("Rate limit exceeded. Skipping query"); return Err(anyhow::anyhow!("Rate limit exceeded")); } - match self.client.hset_multiple(key, items).await { - Ok(()) => return Ok(()), + match q.query(&mut self.client).await { + Ok(t) => return Ok(t), Err(e) => { - tracing::error!("failed to set a key-value pair: {e}"); + tracing::error!("failed to run query: {e}"); } } - tracing::info!("Redis client is disconnected. Reconnectiong..."); + tracing::info!("Redis client is disconnected. 
Reconnecting..."); self.try_connect().await?; - self.client - .hset_multiple(key, items) - .await - .map_err(anyhow::Error::new) - } - - #[allow(dead_code)] - pub(crate) async fn expire(&mut self, key: K, seconds: i64) -> anyhow::Result<()> - where - K: ToRedisArgs + Send + Sync, - { - if !self.limiter.check() { - tracing::info!("Rate limit exceeded. Skipping expire"); - return Err(anyhow::anyhow!("Rate limit exceeded")); - } - - match self.client.expire(&key, seconds).await { - Ok(()) => return Ok(()), - Err(e) => { - tracing::error!("failed to set a key-value pair: {e}"); - } - } - - tracing::info!("Redis client is disconnected. Reconnectiong..."); - self.try_connect().await?; - self.client - .expire(key, seconds) - .await - .map_err(anyhow::Error::new) - } - - #[allow(dead_code)] - pub(crate) async fn hget(&mut self, key: K, field: F) -> anyhow::Result - where - K: ToRedisArgs + Send + Sync, - F: ToRedisArgs + Send + Sync, - V: redis::FromRedisValue, - { - if !self.limiter.check() { - tracing::info!("Rate limit exceeded. Skipping hget"); - return Err(anyhow::anyhow!("Rate limit exceeded")); - } - - match self.client.hget(&key, &field).await { - Ok(value) => return Ok(value), - Err(e) => { - tracing::error!("failed to get a value: {e}"); - } - } - - tracing::info!("Redis client is disconnected. Reconnectiong..."); - self.try_connect().await?; - self.client - .hget(key, field) - .await - .map_err(anyhow::Error::new) - } - - pub(crate) async fn hget_all(&mut self, key: K) -> anyhow::Result - where - K: ToRedisArgs + Send + Sync, - V: redis::FromRedisValue, - { - if !self.limiter.check() { - tracing::info!("Rate limit exceeded. Skipping hgetall"); - return Err(anyhow::anyhow!("Rate limit exceeded")); - } - - match self.client.hgetall(&key).await { - Ok(value) => return Ok(value), - Err(e) => { - tracing::error!("failed to get a value: {e}"); - } - } - - tracing::info!("Redis client is disconnected. Reconnectiong..."); - self.try_connect().await?; - self.client.hgetall(key).await.map_err(anyhow::Error::new) - } - - pub(crate) async fn hdel(&mut self, key: K, field: F) -> anyhow::Result<()> - where - K: ToRedisArgs + Send + Sync, - F: ToRedisArgs + Send + Sync, - { - if !self.limiter.check() { - tracing::info!("Rate limit exceeded. Skipping hdel"); - return Err(anyhow::anyhow!("Rate limit exceeded")); - } - - match self.client.hdel(&key, &field).await { - Ok(()) => return Ok(()), - Err(e) => { - tracing::error!("failed to delete a key-value pair: {e}"); - } - } - - tracing::info!("Redis client is disconnected. Reconnectiong..."); - self.try_connect().await?; - self.client - .hdel(key, field) - .await - .map_err(anyhow::Error::new) + Ok(q.query(&mut self.client).await?) } }