From a113c48c43c9ff0130e404e47a55e4721bbb63a4 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 13 May 2025 09:33:53 +0100 Subject: [PATCH] proxy: fix redis batching support (#11905) ## Problem For `StoreCancelKey`, we were inserting 2 commands, but we were not inserting two replies. This mismatch leads to errors when decoding the response. ## Summary of changes Abstract the command + reply pipeline so that commands and replies are registered at the same time. --- proxy/src/cancellation.rs | 125 ++++++++++++++++++++++++-------------- proxy/src/redis/kv_ops.rs | 2 +- 2 files changed, 79 insertions(+), 48 deletions(-) diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index c5ba04eb8c..f34fb747ca 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -6,12 +6,12 @@ use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use postgres_client::CancelToken; use postgres_client::tls::MakeTlsConnect; use pq_proto::CancelKeyData; -use redis::{FromRedisValue, Pipeline, Value, pipe}; +use redis::{Cmd, FromRedisValue, Value}; use serde::{Deserialize, Serialize}; use thiserror::Error; use tokio::net::TcpStream; use tokio::sync::{mpsc, oneshot}; -use tracing::{debug, info, warn}; +use tracing::{debug, error, info, warn}; use crate::auth::backend::ComputeUserInfo; use crate::auth::{AuthError, check_peer_addr_is_in_list}; @@ -56,8 +56,70 @@ pub enum CancelKeyOp { }, } +pub struct Pipeline { + inner: redis::Pipeline, + replies: Vec, +} + +impl Pipeline { + fn with_capacity(n: usize) -> Self { + Self { + inner: redis::Pipeline::with_capacity(n), + replies: Vec::with_capacity(n), + } + } + + async fn execute(&mut self, client: &mut RedisKVClient) { + let responses = self.replies.len(); + let batch_size = self.inner.len(); + + match client.query(&self.inner).await { + // for each reply, we expect that many values. + Ok(Value::Array(values)) if values.len() == responses => { + debug!( + batch_size, + responses, "successfully completed cancellation jobs", + ); + for (value, reply) in std::iter::zip(values, self.replies.drain(..)) { + reply.send_value(value); + } + } + Ok(value) => { + error!(batch_size, ?value, "unexpected redis return value"); + for reply in self.replies.drain(..) { + reply.send_err(anyhow!("incorrect response type from redis")); + } + } + Err(err) => { + for reply in self.replies.drain(..) { + reply.send_err(anyhow!("could not send cmd to redis: {err}")); + } + } + } + + self.inner.clear(); + self.replies.clear(); + } + + fn add_command_with_reply(&mut self, cmd: Cmd, reply: CancelReplyOp) { + self.inner.add_command(cmd); + self.replies.push(reply); + } + + fn add_command_no_reply(&mut self, cmd: Cmd) { + self.inner.add_command(cmd).ignore(); + } + + fn add_command(&mut self, cmd: Cmd, reply: Option) { + match reply { + Some(reply) => self.add_command_with_reply(cmd, reply), + None => self.add_command_no_reply(cmd), + } + } +} + impl CancelKeyOp { - fn register(self, pipe: &mut Pipeline) -> Option { + fn register(self, pipe: &mut Pipeline) { #[allow(clippy::used_underscore_binding)] match self { CancelKeyOp::StoreCancelKey { @@ -68,18 +130,18 @@ impl CancelKeyOp { _guard, expire, } => { - pipe.hset(&key, field, value); - pipe.expire(key, expire); - let resp_tx = resp_tx?; - Some(CancelReplyOp::StoreCancelKey { resp_tx, _guard }) + let reply = + resp_tx.map(|resp_tx| CancelReplyOp::StoreCancelKey { resp_tx, _guard }); + pipe.add_command(Cmd::hset(&key, field, value), reply); + pipe.add_command_no_reply(Cmd::expire(key, expire)); } CancelKeyOp::GetCancelData { key, resp_tx, _guard, } => { - pipe.hgetall(key); - Some(CancelReplyOp::GetCancelData { resp_tx, _guard }) + let reply = CancelReplyOp::GetCancelData { resp_tx, _guard }; + pipe.add_command_with_reply(Cmd::hgetall(key), reply); } CancelKeyOp::RemoveCancelKey { key, @@ -87,9 +149,9 @@ impl CancelKeyOp { resp_tx, _guard, } => { - pipe.hdel(key, field); - let resp_tx = resp_tx?; - Some(CancelReplyOp::RemoveCancelKey { resp_tx, _guard }) + let reply = + resp_tx.map(|resp_tx| CancelReplyOp::RemoveCancelKey { resp_tx, _guard }); + pipe.add_command(Cmd::hdel(key, field), reply); } } } @@ -170,8 +232,8 @@ pub async fn handle_cancel_messages( client: &mut RedisKVClient, mut rx: mpsc::Receiver, ) -> anyhow::Result<()> { - let mut batch = Vec::new(); - let mut replies = vec![]; + let mut batch = Vec::with_capacity(BATCH_SIZE); + let mut pipeline = Pipeline::with_capacity(BATCH_SIZE); loop { if rx.recv_many(&mut batch, BATCH_SIZE).await == 0 { @@ -182,42 +244,11 @@ pub async fn handle_cancel_messages( let batch_size = batch.len(); debug!(batch_size, "running cancellation jobs"); - let mut pipe = pipe(); for msg in batch.drain(..) { - if let Some(reply) = msg.register(&mut pipe) { - replies.push(reply); - } else { - pipe.ignore(); - } + msg.register(&mut pipeline); } - let responses = replies.len(); - - match client.query(pipe).await { - // for each reply, we expect that many values. - Ok(Value::Array(values)) if values.len() == responses => { - debug!( - batch_size, - responses, "successfully completed cancellation jobs", - ); - for (value, reply) in std::iter::zip(values, replies.drain(..)) { - reply.send_value(value); - } - } - Ok(value) => { - debug!(?value, "unexpected redis return value"); - for reply in replies.drain(..) { - reply.send_err(anyhow!("incorrect response type from redis")); - } - } - Err(err) => { - for reply in replies.drain(..) { - reply.send_err(anyhow!("could not send cmd to redis: {err}")); - } - } - } - - replies.clear(); + pipeline.execute(client).await; } } diff --git a/proxy/src/redis/kv_ops.rs b/proxy/src/redis/kv_ops.rs index aa627b29a6..f71730c533 100644 --- a/proxy/src/redis/kv_ops.rs +++ b/proxy/src/redis/kv_ops.rs @@ -47,7 +47,7 @@ impl RedisKVClient { pub(crate) async fn query( &mut self, - q: impl Queryable, + q: &impl Queryable, ) -> anyhow::Result { if !self.limiter.check() { tracing::info!("Rate limit exceeded. Skipping query");