Compare commits

...

2 Commits

Author SHA1 Message Date
Conrad Ludgate
5dd600f9f9 proxy: refactor redis batching 2025-05-13 09:20:52 +01:00
Conrad Ludgate
2122f962d5 proxy: fix redis batching support 2025-05-13 08:24:33 +01:00
2 changed files with 112 additions and 118 deletions

View File

@@ -6,12 +6,12 @@ use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use postgres_client::CancelToken;
use postgres_client::tls::MakeTlsConnect;
use pq_proto::CancelKeyData;
use redis::{FromRedisValue, Pipeline, Value, pipe};
use redis::{Cmd, FromRedisValue, Value};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot};
use tracing::{debug, info, warn};
use tracing::{debug, error, info, warn};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::{AuthError, check_peer_addr_is_in_list};
@@ -56,8 +56,86 @@ pub enum CancelKeyOp {
},
}
type Callback = Box<dyn FnOnce(anyhow::Result<&[redis::Value]>) + Send>;
pub struct Pipeline {
inner: redis::Pipeline,
// vec![(number of commands, fn(values))]
replies: Vec<(usize, Callback)>,
}
impl Pipeline {
fn with_capacity(n: usize) -> Self {
Self {
inner: redis::Pipeline::with_capacity(n),
replies: Vec::with_capacity(n),
}
}
async fn execute(&mut self, client: &mut RedisKVClient) {
let commands = self.inner.len();
let batch_size = self.replies.len();
match client.query(&self.inner).await {
Ok(Value::Array(values)) if values.len() == commands => {
debug!(
commands,
batch_size, "successfully completed cancellation jobs",
);
let mut values = &*values;
for (n, resp) in self.replies.drain(..) {
let (v, rest) = values.split_at(n);
values = rest;
resp(Ok(v));
}
}
Ok(value) => {
error!(
commands,
batch_size,
?value,
"unexpected redis return value"
);
for (_n, resp) in self.replies.drain(..) {
resp(Err(anyhow!("incorrect response type from redis")));
}
}
Err(err) => {
for (_n, resp) in self.replies.drain(..) {
resp(Err(anyhow!("could not send cmd to redis: {err}")));
}
}
}
self.inner.clear();
self.replies.clear();
}
/// Add a batch of commands to the pipeline, and run the resp fn when they are all done.
///
/// If multiple commands are provided, the response should be able to decode
/// all of the values. You can provide a tuple in that case.
fn add_commands<F, T, const N: usize>(&mut self, cmds: [Cmd; N], resp: F)
where
F: FnOnce(anyhow::Result<T>) + Send + 'static,
T: FromRedisValue,
{
for cmd in cmds {
self.inner.add_command(cmd);
}
let reply = Box::new(move |res: anyhow::Result<&[redis::Value]>| {
let res = match res {
Ok(v) => T::from_redis_value(&redis::Value::Array(v.to_owned()))
.context("could not parse value"),
Err(e) => Err(e),
};
resp(res);
});
self.replies.push((N, reply as Box<_>));
}
}
impl CancelKeyOp {
fn register(self, pipe: &mut Pipeline) -> Option<CancelReplyOp> {
fn register(self, pipe: &mut Pipeline) {
#[allow(clippy::used_underscore_binding)]
match self {
CancelKeyOp::StoreCancelKey {
@@ -68,18 +146,30 @@ impl CancelKeyOp {
_guard,
expire,
} => {
pipe.hset(&key, field, value);
pipe.expire(key, expire);
let resp_tx = resp_tx?;
Some(CancelReplyOp::StoreCancelKey { resp_tx, _guard })
pipe.add_commands(
[Cmd::hset(&key, field, value), Cmd::expire(key, expire)],
// ignore all results
move |res: anyhow::Result<()>| {
let _guard = _guard;
if let Some(resp_tx) = resp_tx {
if resp_tx.send(res).is_err() {
tracing::debug!("could not send reply");
}
}
},
);
}
CancelKeyOp::GetCancelData {
key,
resp_tx,
_guard,
} => {
pipe.hgetall(key);
Some(CancelReplyOp::GetCancelData { resp_tx, _guard })
pipe.add_commands([Cmd::hgetall(key)], move |res| {
let _guard = _guard;
if resp_tx.send(res).is_err() {
tracing::debug!("could not send reply");
}
});
}
CancelKeyOp::RemoveCancelKey {
key,
@@ -87,79 +177,14 @@ impl CancelKeyOp {
resp_tx,
_guard,
} => {
pipe.hdel(key, field);
let resp_tx = resp_tx?;
Some(CancelReplyOp::RemoveCancelKey { resp_tx, _guard })
}
}
}
}
// Message types for sending through mpsc channel
pub enum CancelReplyOp {
StoreCancelKey {
resp_tx: oneshot::Sender<anyhow::Result<()>>,
_guard: CancelChannelSizeGuard<'static>,
},
GetCancelData {
resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
_guard: CancelChannelSizeGuard<'static>,
},
RemoveCancelKey {
resp_tx: oneshot::Sender<anyhow::Result<()>>,
_guard: CancelChannelSizeGuard<'static>,
},
}
impl CancelReplyOp {
fn send_err(self, e: anyhow::Error) {
match self {
CancelReplyOp::StoreCancelKey { resp_tx, _guard } => {
resp_tx
.send(Err(e))
.inspect_err(|_| tracing::debug!("could not send reply"))
.ok();
}
CancelReplyOp::GetCancelData { resp_tx, _guard } => {
resp_tx
.send(Err(e))
.inspect_err(|_| tracing::debug!("could not send reply"))
.ok();
}
CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => {
resp_tx
.send(Err(e))
.inspect_err(|_| tracing::debug!("could not send reply"))
.ok();
}
}
}
fn send_value(self, v: redis::Value) {
match self {
CancelReplyOp::StoreCancelKey { resp_tx, _guard } => {
let send =
FromRedisValue::from_owned_redis_value(v).context("could not parse value");
resp_tx
.send(send)
.inspect_err(|_| tracing::debug!("could not send reply"))
.ok();
}
CancelReplyOp::GetCancelData { resp_tx, _guard } => {
let send =
FromRedisValue::from_owned_redis_value(v).context("could not parse value");
resp_tx
.send(send)
.inspect_err(|_| tracing::debug!("could not send reply"))
.ok();
}
CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => {
let send =
FromRedisValue::from_owned_redis_value(v).context("could not parse value");
resp_tx
.send(send)
.inspect_err(|_| tracing::debug!("could not send reply"))
.ok();
pipe.add_commands([Cmd::hdel(key, field)], move |res| {
let _guard = _guard;
if let Some(resp_tx) = resp_tx {
if resp_tx.send(res).is_err() {
tracing::debug!("could not send reply");
}
}
});
}
}
}
@@ -170,8 +195,8 @@ pub async fn handle_cancel_messages(
client: &mut RedisKVClient,
mut rx: mpsc::Receiver<CancelKeyOp>,
) -> anyhow::Result<()> {
let mut batch = Vec::new();
let mut replies = vec![];
let mut batch = Vec::with_capacity(BATCH_SIZE);
let mut pipeline = Pipeline::with_capacity(BATCH_SIZE);
loop {
if rx.recv_many(&mut batch, BATCH_SIZE).await == 0 {
@@ -182,42 +207,11 @@ pub async fn handle_cancel_messages(
let batch_size = batch.len();
debug!(batch_size, "running cancellation jobs");
let mut pipe = pipe();
for msg in batch.drain(..) {
if let Some(reply) = msg.register(&mut pipe) {
replies.push(reply);
} else {
pipe.ignore();
}
msg.register(&mut pipeline);
}
let responses = replies.len();
match client.query(pipe).await {
// for each reply, we expect that many values.
Ok(Value::Array(values)) if values.len() == responses => {
debug!(
batch_size,
responses, "successfully completed cancellation jobs",
);
for (value, reply) in std::iter::zip(values, replies.drain(..)) {
reply.send_value(value);
}
}
Ok(value) => {
debug!(?value, "unexpected redis return value");
for reply in replies.drain(..) {
reply.send_err(anyhow!("incorrect response type from redis"));
}
}
Err(err) => {
for reply in replies.drain(..) {
reply.send_err(anyhow!("could not send cmd to redis: {err}"));
}
}
}
replies.clear();
pipeline.execute(client).await;
}
}

View File

@@ -47,7 +47,7 @@ impl RedisKVClient {
pub(crate) async fn query<T: FromRedisValue>(
&mut self,
q: impl Queryable,
q: &impl Queryable,
) -> anyhow::Result<T> {
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping query");