diff --git a/Cargo.lock b/Cargo.lock
index 025f4e4116..4323254f0a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5289,6 +5289,7 @@ dependencies = [
"async-trait",
"atomic-take",
"aws-config",
+ "aws-credential-types",
"aws-sdk-iam",
"aws-sigv4",
"base64 0.22.1",
diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml
index ce8610be24..0a406d1ca8 100644
--- a/proxy/Cargo.toml
+++ b/proxy/Cargo.toml
@@ -16,6 +16,7 @@ async-compression.workspace = true
async-trait.workspace = true
atomic-take.workspace = true
aws-config.workspace = true
+aws-credential-types.workspace = true
aws-sdk-iam.workspace = true
aws-sigv4.workspace = true
base64.workspace = true
@@ -127,4 +128,4 @@ rstest.workspace = true
walkdir.workspace = true
rand_distr = "0.4"
tokio-postgres.workspace = true
-tracing-test = "0.2"
\ No newline at end of file
+tracing-test = "0.2"
diff --git a/proxy/src/batch.rs b/proxy/src/batch.rs
index 33e08797f2..cf866ab9a3 100644
--- a/proxy/src/batch.rs
+++ b/proxy/src/batch.rs
@@ -7,13 +7,17 @@ use std::pin::pin;
use std::sync::Mutex;
use scopeguard::ScopeGuard;
+use tokio::sync::oneshot;
use tokio::sync::oneshot::error::TryRecvError;
use crate::ext::LockExt;
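+/// Result of a single queued job: the processor's per-job response, or the batch-wide error.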
+type ProcResult<P> = Result<<P as QueueProcessing>::Res, <P as QueueProcessing>::Err>;
+
pub trait QueueProcessing: Send + 'static {
type Req: Send + 'static;
type Res: Send;
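+ /// Batch-wide error. `Clone` lets one failure be delivered to every job in the batch.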
+ type Err: Send + Clone;
/// Get the desired batch size.
fn batch_size(&self, queue_size: usize) -> usize;
@@ -24,7 +28,18 @@ pub trait QueueProcessing: Send + 'static {
/// If this apply can error, it's expected that errors be forwarded to each Self::Res.
///
/// Batching does not need to happen atomically.
- fn apply(&mut self, req: Vec<Self::Req>) -> impl Future<Output = Vec<Self::Res>> + Send;
+ fn apply(
+ &mut self,
+ req: Vec<Self::Req>,
+ ) -> impl Future<Output = Result<Vec<Self::Res>, Self::Err>> + Send;
+}
+
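+/// Error returned by [`BatchQueue::call`]: the batch processor failed (`Result`), or the
+/// caller's cancellation future resolved first (`Cancelled`).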
+#[derive(thiserror::Error, Debug)]
+pub enum BatchQueueError<E, C> {
+ #[error(transparent)]
+ Result(E),
+ #[error(transparent)]
+ Cancelled(C),
}
pub struct BatchQueue<P: QueueProcessing> {
@@ -34,7 +49,7 @@ pub struct BatchQueue<P: QueueProcessing> {
struct BatchJob<P: QueueProcessing> {
req: P::Req,
- res: tokio::sync::oneshot::Sender<P::Res>,
+ res: tokio::sync::oneshot::Sender<ProcResult<P>>,
}
impl<P: QueueProcessing> BatchQueue<P> {
@@ -55,11 +70,11 @@ impl<P: QueueProcessing> BatchQueue<P> {
&self,
req: P::Req,
cancelled: impl Future<Output = C>,
- ) -> Result<P::Res, C> {
+ ) -> Result<P::Res, BatchQueueError<P::Err, C>> {
let (id, mut rx) = self.inner.lock_propagate_poison().register_job(req);
let mut cancelled = pin!(cancelled);
- let resp = loop {
+ let resp: Option<ProcResult<P>> = loop {
// try become the leader, or try wait for success.
let mut processor = tokio::select! {
// try become leader.
@@ -72,7 +87,7 @@ impl<P: QueueProcessing> BatchQueue<P> {
if inner.queue.remove(&id).is_some() {
tracing::warn!("batched task cancelled before completion");
}
- return Err(cancel);
+ return Err(BatchQueueError::Cancelled(cancel));
},
};
@@ -96,18 +111,30 @@ impl<P: QueueProcessing> BatchQueue<P> {
// good: we didn't get cancelled.
ScopeGuard::into_inner(cancel_safety);
- if values.len() != resps.len() {
- tracing::error!(
- "batch: invalid response size, expected={}, got={}",
- resps.len(),
- values.len()
- );
- }
+ match values {
+ Ok(values) => {
+ if values.len() != resps.len() {
+ tracing::error!(
+ "batch: invalid response size, expected={}, got={}",
+ resps.len(),
+ values.len()
+ );
+ }
- // send response values.
- for (tx, value) in std::iter::zip(resps, values) {
- if tx.send(value).is_err() {
- // receiver hung up but that's fine.
+ // send response values.
+ for (tx, value) in std::iter::zip(resps, values) {
+ if tx.send(Ok(value)).is_err() {
+ // receiver hung up but that's fine.
+ }
+ }
+ }
+
+ Err(err) => {
+ for tx in resps {
+ if tx.send(Err(err.clone())).is_err() {
+ // receiver hung up but that's fine.
+ }
+ }
}
}
@@ -129,7 +156,8 @@ impl<P: QueueProcessing> BatchQueue<P> {
tracing::debug!(id, "batch: job completed");
- Ok(resp.expect("no response found. batch processer should not panic"))
+ resp.expect("no response found. batch processor should not panic")
+ .map_err(BatchQueueError::Result)
}
}
@@ -139,8 +167,8 @@ struct BatchQueueInner<P: QueueProcessing> {
}
impl<P: QueueProcessing> BatchQueueInner<P> {
- fn register_job(&mut self, req: P::Req) -> (u64, tokio::sync::oneshot::Receiver<P::Res>) {
- let (tx, rx) = tokio::sync::oneshot::channel();
+ fn register_job(&mut self, req: P::Req) -> (u64, oneshot::Receiver<ProcResult<P>>) {
+ let (tx, rx) = oneshot::channel();
let id = self.version;
@@ -158,7 +186,7 @@ impl<P: QueueProcessing> BatchQueueInner<P> {
(id, rx)
}
- fn get_batch(&mut self, p: &P) -> (Vec<P::Req>, Vec<tokio::sync::oneshot::Sender<P::Res>>) {
+ fn get_batch(&mut self, p: &P) -> (Vec<P::Req>, Vec<oneshot::Sender<ProcResult<P>>>) {
let batch_size = p.batch_size(self.queue.len());
let mut reqs = Vec::with_capacity(batch_size);
let mut resps = Vec::with_capacity(batch_size);
diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs
index 74413f1a7d..4ea4c4ea54 100644
--- a/proxy/src/cancellation.rs
+++ b/proxy/src/cancellation.rs
@@ -4,12 +4,11 @@ use std::pin::pin;
use std::sync::{Arc, OnceLock};
use std::time::Duration;
-use anyhow::anyhow;
use futures::FutureExt;
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use postgres_client::RawCancelToken;
use postgres_client::tls::MakeTlsConnect;
-use redis::{Cmd, FromRedisValue, Value};
+use redis::{Cmd, FromRedisValue, SetExpiry, SetOptions, Value};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::net::TcpStream;
@@ -18,7 +17,7 @@ use tracing::{debug, error, info};
use crate::auth::AuthError;
use crate::auth::backend::ComputeUserInfo;
-use crate::batch::{BatchQueue, QueueProcessing};
+use crate::batch::{BatchQueue, BatchQueueError, QueueProcessing};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::ControlPlaneApi;
@@ -28,7 +27,7 @@ use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, Redis
use crate::pqproto::CancelKeyData;
use crate::rate_limiter::LeakyBucketRateLimiter;
use crate::redis::keys::KeyPrefix;
-use crate::redis::kv_ops::RedisKVClient;
+use crate::redis::kv_ops::{RedisKVClient, RedisKVClientError};
type IpSubnetKey = IpNet;
@@ -45,6 +44,17 @@ pub enum CancelKeyOp {
GetCancelData {
key: CancelKeyData,
},
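+ /// Reads cancel data stored under the legacy HSET layout; kept until all HSET keys
+ /// expire (see the WRONGTYPE fallback below).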
+ GetCancelDataOld {
+ key: CancelKeyData,
+ },
+}
+
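+/// Error covering a whole pipeline batch. `Clone` (with the client error behind an `Arc`)
+/// so a single failure can be fanned out to every queued job.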
+#[derive(thiserror::Error, Debug, Clone)]
+pub enum PipelineError {
+ #[error("could not send cmd to redis: {0}")]
+ RedisKVClient(Arc),
+ #[error("incorrect number of responses from redis")]
+ IncorrectNumberOfResponses,
}
pub struct Pipeline {
@@ -60,7 +70,7 @@ impl Pipeline {
}
}
- async fn execute(self, client: &mut RedisKVClient) -> Vec<anyhow::Result<Value>> {
+ async fn execute(self, client: &mut RedisKVClient) -> Result<Vec<Value>, PipelineError> {
let responses = self.replies;
let batch_size = self.inner.len();
@@ -78,30 +88,20 @@ impl Pipeline {
batch_size,
responses, "successfully completed cancellation jobs",
);
- values.into_iter().map(Ok).collect()
+ Ok(values.into_iter().collect())
}
Ok(value) => {
error!(batch_size, ?value, "unexpected redis return value");
- std::iter::repeat_with(|| Err(anyhow!("incorrect response type from redis")))
- .take(responses)
- .collect()
- }
- Err(err) => {
- std::iter::repeat_with(|| Err(anyhow!("could not send cmd to redis: {err}")))
- .take(responses)
- .collect()
+ Err(PipelineError::IncorrectNumberOfResponses)
}
+ Err(err) => Err(PipelineError::RedisKVClient(Arc::new(err))),
}
}
- fn add_command_with_reply(&mut self, cmd: Cmd) {
+ fn add_command(&mut self, cmd: Cmd) {
self.inner.add_command(cmd);
self.replies += 1;
}
-
- fn add_command_no_reply(&mut self, cmd: Cmd) {
- self.inner.add_command(cmd).ignore();
- }
}
impl CancelKeyOp {
@@ -109,12 +109,19 @@ impl CancelKeyOp {
match self {
CancelKeyOp::StoreCancelKey { key, value, expire } => {
let key = KeyPrefix::Cancel(*key).build_redis_key();
- pipe.add_command_with_reply(Cmd::hset(&key, "data", &**value));
- pipe.add_command_no_reply(Cmd::expire(&key, expire.as_secs() as i64));
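+ // A single SET with an EX expiry replaces the HSET + EXPIRE pair above,
+ // applying the value and its TTL atomically in one command with one reply.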
+ pipe.add_command(Cmd::set_options(
+ &key,
+ &**value,
+ SetOptions::default().with_expiration(SetExpiry::EX(expire.as_secs())),
+ ));
+ }
+ CancelKeyOp::GetCancelDataOld { key } => {
+ let key = KeyPrefix::Cancel(*key).build_redis_key();
+ pipe.add_command(Cmd::hget(key, "data"));
}
CancelKeyOp::GetCancelData { key } => {
let key = KeyPrefix::Cancel(*key).build_redis_key();
- pipe.add_command_with_reply(Cmd::hget(key, "data"));
+ pipe.add_command(Cmd::get(key));
}
}
}
@@ -127,13 +134,14 @@ pub struct CancellationProcessor {
impl QueueProcessing for CancellationProcessor {
type Req = (CancelChannelSizeGuard<'static>, CancelKeyOp);
- type Res = anyhow::Result;
+ type Res = redis::Value;
+ type Err = PipelineError;
fn batch_size(&self, _queue_size: usize) -> usize {
self.batch_size
}
- async fn apply(&mut self, batch: Vec<Self::Req>) -> Vec<Self::Res> {
+ async fn apply(&mut self, batch: Vec<Self::Req>) -> Result<Vec<Self::Res>, Self::Err> {
if !self.client.credentials_refreshed() {
// this will cause a timeout for cancellation operations
tracing::debug!(
@@ -244,18 +252,18 @@ impl CancellationHandler {
&self,
key: CancelKeyData,
) -> Result<Option<CancelClosure>, CancelError> {
- let guard = Metrics::get()
- .proxy
- .cancel_channel_size
- .guard(RedisMsgKind::HGet);
- let op = CancelKeyOp::GetCancelData { key };
+ const TIMEOUT: Duration = Duration::from_secs(5);
let Some(tx) = self.tx.get() else {
tracing::warn!("cancellation handler is not available");
return Err(CancelError::InternalError);
};
- const TIMEOUT: Duration = Duration::from_secs(5);
+ let guard = Metrics::get()
+ .proxy
+ .cancel_channel_size
+ .guard(RedisMsgKind::Get);
+ let op = CancelKeyOp::GetCancelData { key };
let result = timeout(
TIMEOUT,
tx.call((guard, op), std::future::pending::<Infallible>()),
@@ -264,10 +272,37 @@ impl CancellationHandler {
.map_err(|_| {
tracing::warn!("timed out waiting to receive GetCancelData response");
CancelError::RateLimit
- })?
- // cannot be cancelled
- .unwrap_or_else(|x| match x {})
- .map_err(|e| {
+ })?;
+
+ // We may still have cancel keys set with HSET "data".
+ // Check error type and retry with HGET.
+ // TODO: remove code after HSET is not used anymore.
+ let result = if let Err(err) = result.as_ref()
+ && let BatchQueueError::Result(err) = err
+ && let PipelineError::RedisKVClient(err) = err
+ && let RedisKVClientError::Redis(err) = &**err
+ && let Some(errcode) = err.code()
+ && errcode == "WRONGTYPE"
+ {
+ let guard = Metrics::get()
+ .proxy
+ .cancel_channel_size
+ .guard(RedisMsgKind::HGet);
+ let op = CancelKeyOp::GetCancelDataOld { key };
+ timeout(
+ TIMEOUT,
+ tx.call((guard, op), std::future::pending::<Infallible>()),
+ )
+ .await
+ .map_err(|_| {
+ tracing::warn!("timed out waiting to receive GetCancelDataOld response");
+ CancelError::RateLimit
+ })?
+ } else {
+ result
+ };
+
+ let result = result.map_err(|e| {
tracing::warn!("failed to receive GetCancelData response: {e}");
CancelError::InternalError
})?;
@@ -442,7 +477,7 @@ impl Session {
let guard = Metrics::get()
.proxy
.cancel_channel_size
- .guard(RedisMsgKind::HSet);
+ .guard(RedisMsgKind::Set);
let op = CancelKeyOp::StoreCancelKey {
key: self.key,
value: closure_json.clone(),
@@ -456,7 +491,7 @@ impl Session {
);
match tx.call((guard, op), cancel.as_mut()).await {
- Ok(Ok(_)) => {
+ Ok(_) => {
tracing::debug!(
src=%self.key,
dest=?cancel_closure.cancel_token,
@@ -467,10 +502,10 @@ impl Session {
tokio::time::sleep(CANCEL_KEY_REFRESH).await;
}
// retry immediately.
- Ok(Err(error)) => {
+ Err(BatchQueueError::Result(error)) => {
tracing::warn!(?error, "error registering cancellation key");
}
- Err(Err(_cancelled)) => break,
+ Err(BatchQueueError::Cancelled(Err(_cancelled))) => break,
}
}
diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs
index 9d1a3d4358..8439082498 100644
--- a/proxy/src/metrics.rs
+++ b/proxy/src/metrics.rs
@@ -374,11 +374,9 @@ pub enum Waiting {
#[label(singleton = "kind")]
#[allow(clippy::enum_variant_names)]
pub enum RedisMsgKind {
- HSet,
- HSetMultiple,
+ Set,
+ Get,
HGet,
- HGetAll,
- HDel,
}
#[derive(Default, Clone)]
diff --git a/proxy/src/redis/connection_with_credentials_provider.rs b/proxy/src/redis/connection_with_credentials_provider.rs
index 35a3fe4334..b0bf332e44 100644
--- a/proxy/src/redis/connection_with_credentials_provider.rs
+++ b/proxy/src/redis/connection_with_credentials_provider.rs
@@ -4,11 +4,12 @@ use std::time::Duration;
use futures::FutureExt;
use redis::aio::{ConnectionLike, MultiplexedConnection};
-use redis::{ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult};
+use redis::{ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisError, RedisResult};
use tokio::task::AbortHandle;
use tracing::{error, info, warn};
use super::elasticache::CredentialsProvider;
+use crate::redis::elasticache::CredentialsProviderError;
enum Credentials {
Static(ConnectionInfo),
@@ -26,6 +27,14 @@ impl Clone for Credentials {
}
}
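+/// Errors from (re)establishing a Redis connection: the Redis handshake itself, or
+/// refreshing AWS credentials.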
+#[derive(thiserror::Error, Debug)]
+pub enum ConnectionProviderError {
+ #[error(transparent)]
+ Redis(#[from] RedisError),
+ #[error(transparent)]
+ CredentialsProvider(#[from] CredentialsProviderError),
+}
+
/// A wrapper around `redis::MultiplexedConnection` that automatically refreshes the token.
/// Provides PubSub connection without credentials refresh.
pub struct ConnectionWithCredentialsProvider {
@@ -86,15 +95,18 @@ impl ConnectionWithCredentialsProvider {
}
}
- async fn ping(con: &mut MultiplexedConnection) -> RedisResult<()> {
- redis::cmd("PING").query_async(con).await
+ async fn ping(con: &mut MultiplexedConnection) -> Result<(), ConnectionProviderError> {
+ redis::cmd("PING")
+ .query_async(con)
+ .await
+ .map_err(Into::into)
}
pub(crate) fn credentials_refreshed(&self) -> bool {
self.credentials_refreshed.load(Ordering::Relaxed)
}
- pub(crate) async fn connect(&mut self) -> anyhow::Result<()> {
+ pub(crate) async fn connect(&mut self) -> Result<(), ConnectionProviderError> {
let _guard = self.mutex.lock().await;
if let Some(con) = self.con.as_mut() {
match Self::ping(con).await {
@@ -141,7 +153,7 @@ impl ConnectionWithCredentialsProvider {
Ok(())
}
- async fn get_connection_info(&self) -> anyhow::Result<ConnectionInfo> {
+ async fn get_connection_info(&self) -> Result<ConnectionInfo, ConnectionProviderError> {
match &self.credentials {
Credentials::Static(info) => Ok(info.clone()),
Credentials::Dynamic(provider, addr) => {
@@ -160,7 +172,7 @@ impl ConnectionWithCredentialsProvider {
}
}
- async fn get_client(&self) -> anyhow::Result<redis::Client> {
+ async fn get_client(&self) -> Result<redis::Client, ConnectionProviderError> {
let client = redis::Client::open(self.get_connection_info().await?)?;
self.credentials_refreshed.store(true, Ordering::Relaxed);
Ok(client)
diff --git a/proxy/src/redis/elasticache.rs b/proxy/src/redis/elasticache.rs
index 58e3c889a7..6f3b34d381 100644
--- a/proxy/src/redis/elasticache.rs
+++ b/proxy/src/redis/elasticache.rs
@@ -9,10 +9,12 @@ use aws_config::meta::region::RegionProviderChain;
use aws_config::profile::ProfileFileCredentialsProvider;
use aws_config::provider_config::ProviderConfig;
use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
+use aws_credential_types::provider::error::CredentialsError;
use aws_sdk_iam::config::ProvideCredentials;
use aws_sigv4::http_request::{
- self, SignableBody, SignableRequest, SignatureLocation, SigningSettings,
+ self, SignableBody, SignableRequest, SignatureLocation, SigningError, SigningSettings,
};
+use aws_sigv4::sign::v4::signing_params::BuildError;
use tracing::info;
#[derive(Debug)]
@@ -40,6 +42,18 @@ impl AWSIRSAConfig {
}
}
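+/// Errors from producing an ElastiCache IAM auth token: fetching AWS credentials, building
+/// the SigV4 signing parameters, signing, or assembling the HTTP request to sign.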
+#[derive(thiserror::Error, Debug)]
+pub enum CredentialsProviderError {
+ #[error(transparent)]
+ AwsCredentials(#[from] CredentialsError),
+ #[error(transparent)]
+ AwsSigv4Build(#[from] BuildError),
+ #[error(transparent)]
+ AwsSigv4Signing(#[from] SigningError),
+ #[error(transparent)]
+ Http(#[from] http::Error),
+}
+
/// Credentials provider for AWS elasticache authentication.
///
/// Official documentation:
@@ -92,7 +106,9 @@ impl CredentialsProvider {
})
}
- pub(crate) async fn provide_credentials(&self) -> anyhow::Result<(String, String)> {
+ pub(crate) async fn provide_credentials(
+ &self,
+ ) -> Result<(String, String), CredentialsProviderError> {
let aws_credentials = self
.credentials_provider
.provide_credentials()
diff --git a/proxy/src/redis/kv_ops.rs b/proxy/src/redis/kv_ops.rs
index cfdbc21839..d1e97b6b09 100644
--- a/proxy/src/redis/kv_ops.rs
+++ b/proxy/src/redis/kv_ops.rs
@@ -2,9 +2,18 @@ use std::time::Duration;
use futures::FutureExt;
use redis::aio::ConnectionLike;
-use redis::{Cmd, FromRedisValue, Pipeline, RedisResult};
+use redis::{Cmd, FromRedisValue, Pipeline, RedisError, RedisResult};
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
+use crate::redis::connection_with_credentials_provider::ConnectionProviderError;
+
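+/// Errors surfaced by [`RedisKVClient`]: a failed Redis command, or a failed connection
+/// (re)establishment.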
+#[derive(thiserror::Error, Debug)]
+pub enum RedisKVClientError {
+ #[error(transparent)]
+ Redis(#[from] RedisError),
+ #[error(transparent)]
+ ConnectionProvider(#[from] ConnectionProviderError),
+}
pub struct RedisKVClient {
client: ConnectionWithCredentialsProvider,
@@ -32,12 +41,13 @@ impl RedisKVClient {
Self { client }
}
- pub async fn try_connect(&mut self) -> anyhow::Result<()> {
+ pub async fn try_connect(&mut self) -> Result<(), RedisKVClientError> {
self.client
.connect()
.boxed()
.await
.inspect_err(|e| tracing::error!("failed to connect to redis: {e}"))
+ .map_err(Into::into)
}
pub(crate) fn credentials_refreshed(&self) -> bool {
@@ -47,7 +57,7 @@ impl RedisKVClient {
pub(crate) async fn query<T: FromRedisValue>(
&mut self,
q: &impl Queryable,
- ) -> anyhow::Result<T> {
+ ) -> Result<T, RedisKVClientError> {
let e = match q.query(&mut self.client).await {
Ok(t) => return Ok(t),
Err(e) => e,