feat(proxy): add batching to cancellation queue processing (#10607)

Add batching to the redis queue, which allows us to clear it out quicker
should it slow down temporarily.
This commit is contained in:
Conrad Ludgate
2025-04-12 10:16:22 +01:00
committed by GitHub
parent d109bf8c1d
commit 946e971df8
3 changed files with 192 additions and 209 deletions

View File

@@ -509,7 +509,14 @@ pub async fn run() -> anyhow::Result<()> {
if let Some(mut redis_kv_client) = redis_kv_client {
maintenance_tasks.spawn(async move {
redis_kv_client.try_connect().await?;
handle_cancel_messages(&mut redis_kv_client, rx_cancel).await
handle_cancel_messages(&mut redis_kv_client, rx_cancel).await?;
drop(redis_kv_client);
// `handle_cancel_messages` was terminated due to the tx_cancel
// being dropped. this is not worthy of an error, and this task can only return `Err`,
// so let's wait forever instead.
std::future::pending().await
});
}

View File

@@ -1,16 +1,17 @@
use std::convert::Infallible;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use anyhow::{Context, anyhow};
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use postgres_client::CancelToken;
use postgres_client::tls::MakeTlsConnect;
use pq_proto::CancelKeyData;
use redis::{FromRedisValue, Pipeline, Value, pipe};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot};
use tracing::{debug, info};
use tracing::{debug, info, warn};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::{AuthError, check_peer_addr_is_in_list};
@@ -30,6 +31,7 @@ type IpSubnetKey = IpNet;
const CANCEL_KEY_TTL: i64 = 1_209_600; // 2 weeks cancellation key expire time
const REDIS_SEND_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(10);
const BATCH_SIZE: usize = 8;
// Message types for sending through mpsc channel
pub enum CancelKeyOp {
@@ -54,78 +56,168 @@ pub enum CancelKeyOp {
},
}
impl CancelKeyOp {
fn register(self, pipe: &mut Pipeline) -> Option<CancelReplyOp> {
#[allow(clippy::used_underscore_binding)]
match self {
CancelKeyOp::StoreCancelKey {
key,
field,
value,
resp_tx,
_guard,
expire,
} => {
pipe.hset(&key, field, value);
pipe.expire(key, expire);
let resp_tx = resp_tx?;
Some(CancelReplyOp::StoreCancelKey { resp_tx, _guard })
}
CancelKeyOp::GetCancelData {
key,
resp_tx,
_guard,
} => {
pipe.hgetall(key);
Some(CancelReplyOp::GetCancelData { resp_tx, _guard })
}
CancelKeyOp::RemoveCancelKey {
key,
field,
resp_tx,
_guard,
} => {
pipe.hdel(key, field);
let resp_tx = resp_tx?;
Some(CancelReplyOp::RemoveCancelKey { resp_tx, _guard })
}
}
}
}
// Message types for sending through mpsc channel
pub enum CancelReplyOp {
StoreCancelKey {
resp_tx: oneshot::Sender<anyhow::Result<()>>,
_guard: CancelChannelSizeGuard<'static>,
},
GetCancelData {
resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
_guard: CancelChannelSizeGuard<'static>,
},
RemoveCancelKey {
resp_tx: oneshot::Sender<anyhow::Result<()>>,
_guard: CancelChannelSizeGuard<'static>,
},
}
impl CancelReplyOp {
fn send_err(self, e: anyhow::Error) {
match self {
CancelReplyOp::StoreCancelKey { resp_tx, _guard } => {
resp_tx
.send(Err(e))
.inspect_err(|_| tracing::debug!("could not send reply"))
.ok();
}
CancelReplyOp::GetCancelData { resp_tx, _guard } => {
resp_tx
.send(Err(e))
.inspect_err(|_| tracing::debug!("could not send reply"))
.ok();
}
CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => {
resp_tx
.send(Err(e))
.inspect_err(|_| tracing::debug!("could not send reply"))
.ok();
}
}
}
fn send_value(self, v: redis::Value) {
match self {
CancelReplyOp::StoreCancelKey { resp_tx, _guard } => {
let send =
FromRedisValue::from_owned_redis_value(v).context("could not parse value");
resp_tx
.send(send)
.inspect_err(|_| tracing::debug!("could not send reply"))
.ok();
}
CancelReplyOp::GetCancelData { resp_tx, _guard } => {
let send =
FromRedisValue::from_owned_redis_value(v).context("could not parse value");
resp_tx
.send(send)
.inspect_err(|_| tracing::debug!("could not send reply"))
.ok();
}
CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => {
let send =
FromRedisValue::from_owned_redis_value(v).context("could not parse value");
resp_tx
.send(send)
.inspect_err(|_| tracing::debug!("could not send reply"))
.ok();
}
}
}
}
// Running as a separate task to accept messages through the rx channel
// In case of problems with RTT: switch to recv_many() + redis pipeline
pub async fn handle_cancel_messages(
client: &mut RedisKVClient,
mut rx: mpsc::Receiver<CancelKeyOp>,
) -> anyhow::Result<Infallible> {
) -> anyhow::Result<()> {
let mut batch = Vec::new();
let mut replies = vec![];
loop {
if let Some(msg) = rx.recv().await {
match msg {
CancelKeyOp::StoreCancelKey {
key,
field,
value,
resp_tx,
_guard,
expire,
} => {
let res = client.hset(&key, field, value).await;
if let Some(resp_tx) = resp_tx {
if res.is_ok() {
resp_tx
.send(client.expire(key, expire).await)
.inspect_err(|e| {
tracing::debug!(
"failed to send StoreCancelKey response: {:?}",
e
);
})
.ok();
} else {
resp_tx
.send(res)
.inspect_err(|e| {
tracing::debug!(
"failed to send StoreCancelKey response: {:?}",
e
);
})
.ok();
}
} else if res.is_ok() {
drop(client.expire(key, expire).await);
} else {
tracing::warn!("failed to store cancel key: {:?}", res);
}
if rx.recv_many(&mut batch, BATCH_SIZE).await == 0 {
warn!("shutting down cancellation queue");
break Ok(());
}
let batch_size = batch.len();
debug!(batch_size, "running cancellation jobs");
let mut pipe = pipe();
for msg in batch.drain(..) {
if let Some(reply) = msg.register(&mut pipe) {
replies.push(reply);
} else {
pipe.ignore();
}
}
let responses = replies.len();
match client.query(pipe).await {
// for each reply, we expect that many values.
Ok(Value::Array(values)) if values.len() == responses => {
debug!(
batch_size,
responses, "successfully completed cancellation jobs",
);
for (value, reply) in std::iter::zip(values, replies.drain(..)) {
reply.send_value(value);
}
CancelKeyOp::GetCancelData {
key,
resp_tx,
_guard,
} => {
drop(resp_tx.send(client.hget_all(key).await));
}
Ok(value) => {
debug!(?value, "unexpected redis return value");
for reply in replies.drain(..) {
reply.send_err(anyhow!("incorrect response type from redis"));
}
CancelKeyOp::RemoveCancelKey {
key,
field,
resp_tx,
_guard,
} => {
if let Some(resp_tx) = resp_tx {
resp_tx
.send(client.hdel(key, field).await)
.inspect_err(|e| {
tracing::debug!("failed to send StoreCancelKey response: {:?}", e);
})
.ok();
} else {
drop(client.hdel(key, field).await);
}
}
Err(err) => {
for reply in replies.drain(..) {
reply.send_err(anyhow!("could not send cmd to redis: {err}"));
}
}
}
replies.clear();
}
}

View File

@@ -1,4 +1,5 @@
use redis::{AsyncCommands, ToRedisArgs};
use redis::aio::ConnectionLike;
use redis::{Cmd, FromRedisValue, Pipeline, RedisResult};
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};
@@ -8,6 +9,23 @@ pub struct RedisKVClient {
limiter: GlobalRateLimiter,
}
#[allow(async_fn_in_trait)]
pub trait Queryable {
async fn query<T: FromRedisValue>(&self, conn: &mut impl ConnectionLike) -> RedisResult<T>;
}
impl Queryable for Pipeline {
async fn query<T: FromRedisValue>(&self, conn: &mut impl ConnectionLike) -> RedisResult<T> {
self.query_async(conn).await
}
}
impl Queryable for Cmd {
async fn query<T: FromRedisValue>(&self, conn: &mut impl ConnectionLike) -> RedisResult<T> {
self.query_async(conn).await
}
}
impl RedisKVClient {
pub fn new(client: ConnectionWithCredentialsProvider, info: &'static [RateBucketInfo]) -> Self {
Self {
@@ -27,158 +45,24 @@ impl RedisKVClient {
Ok(())
}
pub(crate) async fn hset<K, F, V>(&mut self, key: K, field: F, value: V) -> anyhow::Result<()>
where
K: ToRedisArgs + Send + Sync,
F: ToRedisArgs + Send + Sync,
V: ToRedisArgs + Send + Sync,
{
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping hset");
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
match self.client.hset(&key, &field, &value).await {
Ok(()) => return Ok(()),
Err(e) => {
tracing::error!("failed to set a key-value pair: {e}");
}
}
tracing::info!("Redis client is disconnected. Reconnectiong...");
self.try_connect().await?;
self.client
.hset(key, field, value)
.await
.map_err(anyhow::Error::new)
}
#[allow(dead_code)]
pub(crate) async fn hset_multiple<K, V>(
pub(crate) async fn query<T: FromRedisValue>(
&mut self,
key: &str,
items: &[(K, V)],
) -> anyhow::Result<()>
where
K: ToRedisArgs + Send + Sync,
V: ToRedisArgs + Send + Sync,
{
q: impl Queryable,
) -> anyhow::Result<T> {
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping hset_multiple");
tracing::info!("Rate limit exceeded. Skipping query");
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
match self.client.hset_multiple(key, items).await {
Ok(()) => return Ok(()),
match q.query(&mut self.client).await {
Ok(t) => return Ok(t),
Err(e) => {
tracing::error!("failed to set a key-value pair: {e}");
tracing::error!("failed to run query: {e}");
}
}
tracing::info!("Redis client is disconnected. Reconnectiong...");
tracing::info!("Redis client is disconnected. Reconnecting...");
self.try_connect().await?;
self.client
.hset_multiple(key, items)
.await
.map_err(anyhow::Error::new)
}
#[allow(dead_code)]
pub(crate) async fn expire<K>(&mut self, key: K, seconds: i64) -> anyhow::Result<()>
where
K: ToRedisArgs + Send + Sync,
{
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping expire");
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
match self.client.expire(&key, seconds).await {
Ok(()) => return Ok(()),
Err(e) => {
tracing::error!("failed to set a key-value pair: {e}");
}
}
tracing::info!("Redis client is disconnected. Reconnectiong...");
self.try_connect().await?;
self.client
.expire(key, seconds)
.await
.map_err(anyhow::Error::new)
}
#[allow(dead_code)]
pub(crate) async fn hget<K, F, V>(&mut self, key: K, field: F) -> anyhow::Result<V>
where
K: ToRedisArgs + Send + Sync,
F: ToRedisArgs + Send + Sync,
V: redis::FromRedisValue,
{
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping hget");
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
match self.client.hget(&key, &field).await {
Ok(value) => return Ok(value),
Err(e) => {
tracing::error!("failed to get a value: {e}");
}
}
tracing::info!("Redis client is disconnected. Reconnectiong...");
self.try_connect().await?;
self.client
.hget(key, field)
.await
.map_err(anyhow::Error::new)
}
pub(crate) async fn hget_all<K, V>(&mut self, key: K) -> anyhow::Result<V>
where
K: ToRedisArgs + Send + Sync,
V: redis::FromRedisValue,
{
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping hgetall");
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
match self.client.hgetall(&key).await {
Ok(value) => return Ok(value),
Err(e) => {
tracing::error!("failed to get a value: {e}");
}
}
tracing::info!("Redis client is disconnected. Reconnectiong...");
self.try_connect().await?;
self.client.hgetall(key).await.map_err(anyhow::Error::new)
}
pub(crate) async fn hdel<K, F>(&mut self, key: K, field: F) -> anyhow::Result<()>
where
K: ToRedisArgs + Send + Sync,
F: ToRedisArgs + Send + Sync,
{
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping hdel");
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
match self.client.hdel(&key, &field).await {
Ok(()) => return Ok(()),
Err(e) => {
tracing::error!("failed to delete a key-value pair: {e}");
}
}
tracing::info!("Redis client is disconnected. Reconnectiong...");
self.try_connect().await?;
self.client
.hdel(key, field)
.await
.map_err(anyhow::Error::new)
Ok(q.query(&mut self.client).await?)
}
}