From 380d167b7ca2c8312fafffef30ae8cbdea7fd8a0 Mon Sep 17 00:00:00 2001 From: Folke Behrens Date: Fri, 11 Jul 2025 21:35:42 +0200 Subject: [PATCH] proxy: For cancellation data replace HSET+EXPIRE/HGET with SET..EX/GET (#12553) ## Problem To store cancellation data we send two commands to redis because the redis server version doesn't support HSET with EX. Also, HSET is not really needed. ## Summary of changes * Replace the HSET + EXPIRE command pair with one SET .. EX command. * Replace HGET with GET. * Leave a workaround for old keys set with HSET. * Replace some anyhow errors with specific errors to surface the WRONGTYPE error from redis. --- Cargo.lock | 1 + proxy/Cargo.toml | 3 +- proxy/src/batch.rs | 68 +++++++---- proxy/src/cancellation.rs | 111 ++++++++++++------ proxy/src/metrics.rs | 6 +- .../connection_with_credentials_provider.rs | 24 +++- proxy/src/redis/elasticache.rs | 20 +++- proxy/src/redis/kv_ops.rs | 16 ++- 8 files changed, 175 insertions(+), 74 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 025f4e4116..4323254f0a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5289,6 +5289,7 @@ dependencies = [ "async-trait", "atomic-take", "aws-config", + "aws-credential-types", "aws-sdk-iam", "aws-sigv4", "base64 0.22.1", diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index ce8610be24..0a406d1ca8 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -16,6 +16,7 @@ async-compression.workspace = true async-trait.workspace = true atomic-take.workspace = true aws-config.workspace = true +aws-credential-types.workspace = true aws-sdk-iam.workspace = true aws-sigv4.workspace = true base64.workspace = true @@ -127,4 +128,4 @@ rstest.workspace = true walkdir.workspace = true rand_distr = "0.4" tokio-postgres.workspace = true -tracing-test = "0.2" \ No newline at end of file +tracing-test = "0.2" diff --git a/proxy/src/batch.rs b/proxy/src/batch.rs index 33e08797f2..cf866ab9a3 100644 --- a/proxy/src/batch.rs +++ b/proxy/src/batch.rs @@ -7,13 +7,17 @@ use std::pin::pin; use std::sync::Mutex; use scopeguard::ScopeGuard; +use tokio::sync::oneshot; use tokio::sync::oneshot::error::TryRecvError; use crate::ext::LockExt; +type ProcResult
<P> = Result<<P as QueueProcessing>::Res, <P as QueueProcessing>
::Err>; + pub trait QueueProcessing: Send + 'static { type Req: Send + 'static; type Res: Send; + type Err: Send + Clone; /// Get the desired batch size. fn batch_size(&self, queue_size: usize) -> usize; @@ -24,7 +28,18 @@ pub trait QueueProcessing: Send + 'static { /// If this apply can error, it's expected that errors be forwarded to each Self::Res. /// /// Batching does not need to happen atomically. - fn apply(&mut self, req: Vec) -> impl Future> + Send; + fn apply( + &mut self, + req: Vec, + ) -> impl Future, Self::Err>> + Send; +} + +#[derive(thiserror::Error)] +pub enum BatchQueueError { + #[error(transparent)] + Result(E), + #[error(transparent)] + Cancelled(C), } pub struct BatchQueue { @@ -34,7 +49,7 @@ pub struct BatchQueue { struct BatchJob { req: P::Req, - res: tokio::sync::oneshot::Sender, + res: tokio::sync::oneshot::Sender>, } impl BatchQueue
<P>
{ @@ -55,11 +70,11 @@ impl BatchQueue
<P>
{ &self, req: P::Req, cancelled: impl Future, - ) -> Result { + ) -> Result> { let (id, mut rx) = self.inner.lock_propagate_poison().register_job(req); let mut cancelled = pin!(cancelled); - let resp = loop { + let resp: Option> = loop { // try become the leader, or try wait for success. let mut processor = tokio::select! { // try become leader. @@ -72,7 +87,7 @@ impl BatchQueue
<P>
{ if inner.queue.remove(&id).is_some() { tracing::warn!("batched task cancelled before completion"); } - return Err(cancel); + return Err(BatchQueueError::Cancelled(cancel)); }, }; @@ -96,18 +111,30 @@ impl BatchQueue
<P>
{ // good: we didn't get cancelled. ScopeGuard::into_inner(cancel_safety); - if values.len() != resps.len() { - tracing::error!( - "batch: invalid response size, expected={}, got={}", - resps.len(), - values.len() - ); - } + match values { + Ok(values) => { + if values.len() != resps.len() { + tracing::error!( + "batch: invalid response size, expected={}, got={}", + resps.len(), + values.len() + ); + } - // send response values. - for (tx, value) in std::iter::zip(resps, values) { - if tx.send(value).is_err() { - // receiver hung up but that's fine. + // send response values. + for (tx, value) in std::iter::zip(resps, values) { + if tx.send(Ok(value)).is_err() { + // receiver hung up but that's fine. + } + } + } + + Err(err) => { + for tx in resps { + if tx.send(Err(err.clone())).is_err() { + // receiver hung up but that's fine. + } + } } } @@ -129,7 +156,8 @@ impl BatchQueue
<P>
{ tracing::debug!(id, "batch: job completed"); - Ok(resp.expect("no response found. batch processer should not panic")) + resp.expect("no response found. batch processer should not panic") + .map_err(BatchQueueError::Result) } } @@ -139,8 +167,8 @@ struct BatchQueueInner { } impl BatchQueueInner
<P>
{ - fn register_job(&mut self, req: P::Req) -> (u64, tokio::sync::oneshot::Receiver) { - let (tx, rx) = tokio::sync::oneshot::channel(); + fn register_job(&mut self, req: P::Req) -> (u64, oneshot::Receiver>) { + let (tx, rx) = oneshot::channel(); let id = self.version; @@ -158,7 +186,7 @@ impl BatchQueueInner
<P>
{ (id, rx) } - fn get_batch(&mut self, p: &P) -> (Vec, Vec>) { + fn get_batch(&mut self, p: &P) -> (Vec, Vec>>) { let batch_size = p.batch_size(self.queue.len()); let mut reqs = Vec::with_capacity(batch_size); let mut resps = Vec::with_capacity(batch_size); diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index 74413f1a7d..4ea4c4ea54 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -4,12 +4,11 @@ use std::pin::pin; use std::sync::{Arc, OnceLock}; use std::time::Duration; -use anyhow::anyhow; use futures::FutureExt; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use postgres_client::RawCancelToken; use postgres_client::tls::MakeTlsConnect; -use redis::{Cmd, FromRedisValue, Value}; +use redis::{Cmd, FromRedisValue, SetExpiry, SetOptions, Value}; use serde::{Deserialize, Serialize}; use thiserror::Error; use tokio::net::TcpStream; @@ -18,7 +17,7 @@ use tracing::{debug, error, info}; use crate::auth::AuthError; use crate::auth::backend::ComputeUserInfo; -use crate::batch::{BatchQueue, QueueProcessing}; +use crate::batch::{BatchQueue, BatchQueueError, QueueProcessing}; use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::ControlPlaneApi; @@ -28,7 +27,7 @@ use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, Redis use crate::pqproto::CancelKeyData; use crate::rate_limiter::LeakyBucketRateLimiter; use crate::redis::keys::KeyPrefix; -use crate::redis::kv_ops::RedisKVClient; +use crate::redis::kv_ops::{RedisKVClient, RedisKVClientError}; type IpSubnetKey = IpNet; @@ -45,6 +44,17 @@ pub enum CancelKeyOp { GetCancelData { key: CancelKeyData, }, + GetCancelDataOld { + key: CancelKeyData, + }, +} + +#[derive(thiserror::Error, Debug, Clone)] +pub enum PipelineError { + #[error("could not send cmd to redis: {0}")] + RedisKVClient(Arc), + #[error("incorrect number of responses from redis")] + IncorrectNumberOfResponses, } pub struct Pipeline { @@ -60,7 +70,7 @@ impl Pipeline { } } - async fn execute(self, client: &mut RedisKVClient) -> Vec> { + async fn execute(self, client: &mut RedisKVClient) -> Result, PipelineError> { let responses = self.replies; let batch_size = self.inner.len(); @@ -78,30 +88,20 @@ impl Pipeline { batch_size, responses, "successfully completed cancellation jobs", ); - values.into_iter().map(Ok).collect() + Ok(values.into_iter().collect()) } Ok(value) => { error!(batch_size, ?value, "unexpected redis return value"); - std::iter::repeat_with(|| Err(anyhow!("incorrect response type from redis"))) - .take(responses) - .collect() - } - Err(err) => { - std::iter::repeat_with(|| Err(anyhow!("could not send cmd to redis: {err}"))) - .take(responses) - .collect() + Err(PipelineError::IncorrectNumberOfResponses) } + Err(err) => Err(PipelineError::RedisKVClient(Arc::new(err))), } } - fn add_command_with_reply(&mut self, cmd: Cmd) { + fn add_command(&mut self, cmd: Cmd) { self.inner.add_command(cmd); self.replies += 1; } - - fn add_command_no_reply(&mut self, cmd: Cmd) { - self.inner.add_command(cmd).ignore(); - } } impl CancelKeyOp { @@ -109,12 +109,19 @@ impl CancelKeyOp { match self { CancelKeyOp::StoreCancelKey { key, value, expire } => { let key = KeyPrefix::Cancel(*key).build_redis_key(); - pipe.add_command_with_reply(Cmd::hset(&key, "data", &**value)); - pipe.add_command_no_reply(Cmd::expire(&key, expire.as_secs() as i64)); + pipe.add_command(Cmd::set_options( + &key, + &**value, + SetOptions::default().with_expiration(SetExpiry::EX(expire.as_secs())), + )); + } + 
CancelKeyOp::GetCancelDataOld { key } => { + let key = KeyPrefix::Cancel(*key).build_redis_key(); + pipe.add_command(Cmd::hget(key, "data")); } CancelKeyOp::GetCancelData { key } => { let key = KeyPrefix::Cancel(*key).build_redis_key(); - pipe.add_command_with_reply(Cmd::hget(key, "data")); + pipe.add_command(Cmd::get(key)); } } } @@ -127,13 +134,14 @@ pub struct CancellationProcessor { impl QueueProcessing for CancellationProcessor { type Req = (CancelChannelSizeGuard<'static>, CancelKeyOp); - type Res = anyhow::Result; + type Res = redis::Value; + type Err = PipelineError; fn batch_size(&self, _queue_size: usize) -> usize { self.batch_size } - async fn apply(&mut self, batch: Vec) -> Vec { + async fn apply(&mut self, batch: Vec) -> Result, Self::Err> { if !self.client.credentials_refreshed() { // this will cause a timeout for cancellation operations tracing::debug!( @@ -244,18 +252,18 @@ impl CancellationHandler { &self, key: CancelKeyData, ) -> Result, CancelError> { - let guard = Metrics::get() - .proxy - .cancel_channel_size - .guard(RedisMsgKind::HGet); - let op = CancelKeyOp::GetCancelData { key }; + const TIMEOUT: Duration = Duration::from_secs(5); let Some(tx) = self.tx.get() else { tracing::warn!("cancellation handler is not available"); return Err(CancelError::InternalError); }; - const TIMEOUT: Duration = Duration::from_secs(5); + let guard = Metrics::get() + .proxy + .cancel_channel_size + .guard(RedisMsgKind::Get); + let op = CancelKeyOp::GetCancelData { key }; let result = timeout( TIMEOUT, tx.call((guard, op), std::future::pending::()), @@ -264,10 +272,37 @@ impl CancellationHandler { .map_err(|_| { tracing::warn!("timed out waiting to receive GetCancelData response"); CancelError::RateLimit - })? - // cannot be cancelled - .unwrap_or_else(|x| match x {}) - .map_err(|e| { + })?; + + // We may still have cancel keys set with HSET "data". + // Check error type and retry with HGET. + // TODO: remove code after HSET is not used anymore. + let result = if let Err(err) = result.as_ref() + && let BatchQueueError::Result(err) = err + && let PipelineError::RedisKVClient(err) = err + && let RedisKVClientError::Redis(err) = &**err + && let Some(errcode) = err.code() + && errcode == "WRONGTYPE" + { + let guard = Metrics::get() + .proxy + .cancel_channel_size + .guard(RedisMsgKind::HGet); + let op = CancelKeyOp::GetCancelDataOld { key }; + timeout( + TIMEOUT, + tx.call((guard, op), std::future::pending::()), + ) + .await + .map_err(|_| { + tracing::warn!("timed out waiting to receive GetCancelData response"); + CancelError::RateLimit + })? + } else { + result + }; + + let result = result.map_err(|e| { tracing::warn!("failed to receive GetCancelData response: {e}"); CancelError::InternalError })?; @@ -442,7 +477,7 @@ impl Session { let guard = Metrics::get() .proxy .cancel_channel_size - .guard(RedisMsgKind::HSet); + .guard(RedisMsgKind::Set); let op = CancelKeyOp::StoreCancelKey { key: self.key, value: closure_json.clone(), @@ -456,7 +491,7 @@ impl Session { ); match tx.call((guard, op), cancel.as_mut()).await { - Ok(Ok(_)) => { + Ok(_) => { tracing::debug!( src=%self.key, dest=?cancel_closure.cancel_token, @@ -467,10 +502,10 @@ impl Session { tokio::time::sleep(CANCEL_KEY_REFRESH).await; } // retry immediately. 
- Ok(Err(error)) => { + Err(BatchQueueError::Result(error)) => { tracing::warn!(?error, "error registering cancellation key"); } - Err(Err(_cancelled)) => break, + Err(BatchQueueError::Cancelled(Err(_cancelled))) => break, } } diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 9d1a3d4358..8439082498 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -374,11 +374,9 @@ pub enum Waiting { #[label(singleton = "kind")] #[allow(clippy::enum_variant_names)] pub enum RedisMsgKind { - HSet, - HSetMultiple, + Set, + Get, HGet, - HGetAll, - HDel, } #[derive(Default, Clone)] diff --git a/proxy/src/redis/connection_with_credentials_provider.rs b/proxy/src/redis/connection_with_credentials_provider.rs index 35a3fe4334..b0bf332e44 100644 --- a/proxy/src/redis/connection_with_credentials_provider.rs +++ b/proxy/src/redis/connection_with_credentials_provider.rs @@ -4,11 +4,12 @@ use std::time::Duration; use futures::FutureExt; use redis::aio::{ConnectionLike, MultiplexedConnection}; -use redis::{ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult}; +use redis::{ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisError, RedisResult}; use tokio::task::AbortHandle; use tracing::{error, info, warn}; use super::elasticache::CredentialsProvider; +use crate::redis::elasticache::CredentialsProviderError; enum Credentials { Static(ConnectionInfo), @@ -26,6 +27,14 @@ impl Clone for Credentials { } } +#[derive(thiserror::Error, Debug)] +pub enum ConnectionProviderError { + #[error(transparent)] + Redis(#[from] RedisError), + #[error(transparent)] + CredentialsProvider(#[from] CredentialsProviderError), +} + /// A wrapper around `redis::MultiplexedConnection` that automatically refreshes the token. /// Provides PubSub connection without credentials refresh. 
pub struct ConnectionWithCredentialsProvider { @@ -86,15 +95,18 @@ impl ConnectionWithCredentialsProvider { } } - async fn ping(con: &mut MultiplexedConnection) -> RedisResult<()> { - redis::cmd("PING").query_async(con).await + async fn ping(con: &mut MultiplexedConnection) -> Result<(), ConnectionProviderError> { + redis::cmd("PING") + .query_async(con) + .await + .map_err(Into::into) } pub(crate) fn credentials_refreshed(&self) -> bool { self.credentials_refreshed.load(Ordering::Relaxed) } - pub(crate) async fn connect(&mut self) -> anyhow::Result<()> { + pub(crate) async fn connect(&mut self) -> Result<(), ConnectionProviderError> { let _guard = self.mutex.lock().await; if let Some(con) = self.con.as_mut() { match Self::ping(con).await { @@ -141,7 +153,7 @@ impl ConnectionWithCredentialsProvider { Ok(()) } - async fn get_connection_info(&self) -> anyhow::Result { + async fn get_connection_info(&self) -> Result { match &self.credentials { Credentials::Static(info) => Ok(info.clone()), Credentials::Dynamic(provider, addr) => { @@ -160,7 +172,7 @@ impl ConnectionWithCredentialsProvider { } } - async fn get_client(&self) -> anyhow::Result { + async fn get_client(&self) -> Result { let client = redis::Client::open(self.get_connection_info().await?)?; self.credentials_refreshed.store(true, Ordering::Relaxed); Ok(client) diff --git a/proxy/src/redis/elasticache.rs b/proxy/src/redis/elasticache.rs index 58e3c889a7..6f3b34d381 100644 --- a/proxy/src/redis/elasticache.rs +++ b/proxy/src/redis/elasticache.rs @@ -9,10 +9,12 @@ use aws_config::meta::region::RegionProviderChain; use aws_config::profile::ProfileFileCredentialsProvider; use aws_config::provider_config::ProviderConfig; use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider; +use aws_credential_types::provider::error::CredentialsError; use aws_sdk_iam::config::ProvideCredentials; use aws_sigv4::http_request::{ - self, SignableBody, SignableRequest, SignatureLocation, SigningSettings, + self, SignableBody, SignableRequest, SignatureLocation, SigningError, SigningSettings, }; +use aws_sigv4::sign::v4::signing_params::BuildError; use tracing::info; #[derive(Debug)] @@ -40,6 +42,18 @@ impl AWSIRSAConfig { } } +#[derive(thiserror::Error, Debug)] +pub enum CredentialsProviderError { + #[error(transparent)] + AwsCredentials(#[from] CredentialsError), + #[error(transparent)] + AwsSigv4Build(#[from] BuildError), + #[error(transparent)] + AwsSigv4Singing(#[from] SigningError), + #[error(transparent)] + Http(#[from] http::Error), +} + /// Credentials provider for AWS elasticache authentication. 
/// /// Official documentation: @@ -92,7 +106,9 @@ impl CredentialsProvider { }) } - pub(crate) async fn provide_credentials(&self) -> anyhow::Result<(String, String)> { + pub(crate) async fn provide_credentials( + &self, + ) -> Result<(String, String), CredentialsProviderError> { let aws_credentials = self .credentials_provider .provide_credentials() diff --git a/proxy/src/redis/kv_ops.rs b/proxy/src/redis/kv_ops.rs index cfdbc21839..d1e97b6b09 100644 --- a/proxy/src/redis/kv_ops.rs +++ b/proxy/src/redis/kv_ops.rs @@ -2,9 +2,18 @@ use std::time::Duration; use futures::FutureExt; use redis::aio::ConnectionLike; -use redis::{Cmd, FromRedisValue, Pipeline, RedisResult}; +use redis::{Cmd, FromRedisValue, Pipeline, RedisError, RedisResult}; use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; +use crate::redis::connection_with_credentials_provider::ConnectionProviderError; + +#[derive(thiserror::Error, Debug)] +pub enum RedisKVClientError { + #[error(transparent)] + Redis(#[from] RedisError), + #[error(transparent)] + ConnectionProvider(#[from] ConnectionProviderError), +} pub struct RedisKVClient { client: ConnectionWithCredentialsProvider, @@ -32,12 +41,13 @@ impl RedisKVClient { Self { client } } - pub async fn try_connect(&mut self) -> anyhow::Result<()> { + pub async fn try_connect(&mut self) -> Result<(), RedisKVClientError> { self.client .connect() .boxed() .await .inspect_err(|e| tracing::error!("failed to connect to redis: {e}")) + .map_err(Into::into) } pub(crate) fn credentials_refreshed(&self) -> bool { @@ -47,7 +57,7 @@ impl RedisKVClient { pub(crate) async fn query( &mut self, q: &impl Queryable, - ) -> anyhow::Result { + ) -> Result { let e = match q.query(&mut self.client).await { Ok(t) => return Ok(t), Err(e) => e,