proxy: fix redis batching support (#11905)

## Problem

For `StoreCancelKey`, we were inserting 2 commands, but we were not
inserting two replies. This mismatch leads to errors when decoding the
response.

## Summary of changes

Abstract the command + reply pipeline so that commands and replies are
registered at the same time.
This commit is contained in:
Conrad Ludgate
2025-05-13 09:33:53 +01:00
committed by GitHub
parent 9971fba584
commit a113c48c43
2 changed files with 79 additions and 48 deletions

View File

@@ -6,12 +6,12 @@ use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use postgres_client::CancelToken;
use postgres_client::tls::MakeTlsConnect;
use pq_proto::CancelKeyData;
use redis::{FromRedisValue, Pipeline, Value, pipe};
use redis::{Cmd, FromRedisValue, Value};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot};
use tracing::{debug, info, warn};
use tracing::{debug, error, info, warn};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::{AuthError, check_peer_addr_is_in_list};
@@ -56,8 +56,70 @@ pub enum CancelKeyOp {
},
}
pub struct Pipeline {
inner: redis::Pipeline,
replies: Vec<CancelReplyOp>,
}
impl Pipeline {
fn with_capacity(n: usize) -> Self {
Self {
inner: redis::Pipeline::with_capacity(n),
replies: Vec::with_capacity(n),
}
}
async fn execute(&mut self, client: &mut RedisKVClient) {
let responses = self.replies.len();
let batch_size = self.inner.len();
match client.query(&self.inner).await {
// for each reply, we expect that many values.
Ok(Value::Array(values)) if values.len() == responses => {
debug!(
batch_size,
responses, "successfully completed cancellation jobs",
);
for (value, reply) in std::iter::zip(values, self.replies.drain(..)) {
reply.send_value(value);
}
}
Ok(value) => {
error!(batch_size, ?value, "unexpected redis return value");
for reply in self.replies.drain(..) {
reply.send_err(anyhow!("incorrect response type from redis"));
}
}
Err(err) => {
for reply in self.replies.drain(..) {
reply.send_err(anyhow!("could not send cmd to redis: {err}"));
}
}
}
self.inner.clear();
self.replies.clear();
}
fn add_command_with_reply(&mut self, cmd: Cmd, reply: CancelReplyOp) {
self.inner.add_command(cmd);
self.replies.push(reply);
}
fn add_command_no_reply(&mut self, cmd: Cmd) {
self.inner.add_command(cmd).ignore();
}
fn add_command(&mut self, cmd: Cmd, reply: Option<CancelReplyOp>) {
match reply {
Some(reply) => self.add_command_with_reply(cmd, reply),
None => self.add_command_no_reply(cmd),
}
}
}
impl CancelKeyOp {
fn register(self, pipe: &mut Pipeline) -> Option<CancelReplyOp> {
fn register(self, pipe: &mut Pipeline) {
#[allow(clippy::used_underscore_binding)]
match self {
CancelKeyOp::StoreCancelKey {
@@ -68,18 +130,18 @@ impl CancelKeyOp {
_guard,
expire,
} => {
pipe.hset(&key, field, value);
pipe.expire(key, expire);
let resp_tx = resp_tx?;
Some(CancelReplyOp::StoreCancelKey { resp_tx, _guard })
let reply =
resp_tx.map(|resp_tx| CancelReplyOp::StoreCancelKey { resp_tx, _guard });
pipe.add_command(Cmd::hset(&key, field, value), reply);
pipe.add_command_no_reply(Cmd::expire(key, expire));
}
CancelKeyOp::GetCancelData {
key,
resp_tx,
_guard,
} => {
pipe.hgetall(key);
Some(CancelReplyOp::GetCancelData { resp_tx, _guard })
let reply = CancelReplyOp::GetCancelData { resp_tx, _guard };
pipe.add_command_with_reply(Cmd::hgetall(key), reply);
}
CancelKeyOp::RemoveCancelKey {
key,
@@ -87,9 +149,9 @@ impl CancelKeyOp {
resp_tx,
_guard,
} => {
pipe.hdel(key, field);
let resp_tx = resp_tx?;
Some(CancelReplyOp::RemoveCancelKey { resp_tx, _guard })
let reply =
resp_tx.map(|resp_tx| CancelReplyOp::RemoveCancelKey { resp_tx, _guard });
pipe.add_command(Cmd::hdel(key, field), reply);
}
}
}
@@ -170,8 +232,8 @@ pub async fn handle_cancel_messages(
client: &mut RedisKVClient,
mut rx: mpsc::Receiver<CancelKeyOp>,
) -> anyhow::Result<()> {
let mut batch = Vec::new();
let mut replies = vec![];
let mut batch = Vec::with_capacity(BATCH_SIZE);
let mut pipeline = Pipeline::with_capacity(BATCH_SIZE);
loop {
if rx.recv_many(&mut batch, BATCH_SIZE).await == 0 {
@@ -182,42 +244,11 @@ pub async fn handle_cancel_messages(
let batch_size = batch.len();
debug!(batch_size, "running cancellation jobs");
let mut pipe = pipe();
for msg in batch.drain(..) {
if let Some(reply) = msg.register(&mut pipe) {
replies.push(reply);
} else {
pipe.ignore();
}
msg.register(&mut pipeline);
}
let responses = replies.len();
match client.query(pipe).await {
// for each reply, we expect that many values.
Ok(Value::Array(values)) if values.len() == responses => {
debug!(
batch_size,
responses, "successfully completed cancellation jobs",
);
for (value, reply) in std::iter::zip(values, replies.drain(..)) {
reply.send_value(value);
}
}
Ok(value) => {
debug!(?value, "unexpected redis return value");
for reply in replies.drain(..) {
reply.send_err(anyhow!("incorrect response type from redis"));
}
}
Err(err) => {
for reply in replies.drain(..) {
reply.send_err(anyhow!("could not send cmd to redis: {err}"));
}
}
}
replies.clear();
pipeline.execute(client).await;
}
}

View File

@@ -47,7 +47,7 @@ impl RedisKVClient {
pub(crate) async fn query<T: FromRedisValue>(
&mut self,
q: impl Queryable,
q: &impl Queryable,
) -> anyhow::Result<T> {
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping query");