diff --git a/Cargo.lock b/Cargo.lock index 6409c79ef9..95af9af62b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -347,9 +347,9 @@ dependencies = [ [[package]] name = "aws-credential-types" -version = "1.1.8" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa8587ae17c8e967e4b05a62d495be2fb7701bec52a97f7acfe8a29f938384c8" +checksum = "33cc49dcdd31c8b6e79850a179af4c367669150c7ac0135f176c61bec81a70f7" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", @@ -359,9 +359,9 @@ dependencies = [ [[package]] name = "aws-runtime" -version = "1.1.8" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b13dc54b4b49f8288532334bba8f87386a40571c47c37b1304979b556dc613c8" +checksum = "eb031bff99877c26c28895766f7bb8484a05e24547e370768d6cc9db514662aa" dependencies = [ "aws-credential-types", "aws-sigv4", @@ -381,29 +381,6 @@ dependencies = [ "uuid", ] -[[package]] -name = "aws-sdk-iam" -version = "1.17.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8ae76026bfb1b80a6aed0bb400c1139cd9c0563e26bce1986cd021c6a968c7b" -dependencies = [ - "aws-credential-types", - "aws-runtime", - "aws-smithy-async", - "aws-smithy-http", - "aws-smithy-json", - "aws-smithy-query", - "aws-smithy-runtime", - "aws-smithy-runtime-api", - "aws-smithy-types", - "aws-smithy-xml", - "aws-types", - "http 0.2.9", - "once_cell", - "regex-lite", - "tracing", -] - [[package]] name = "aws-sdk-s3" version = "1.14.0" @@ -525,9 +502,9 @@ dependencies = [ [[package]] name = "aws-sigv4" -version = "1.2.0" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11d6f29688a4be9895c0ba8bef861ad0c0dac5c15e9618b9b7a6c233990fc263" +checksum = "c371c6b0ac54d4605eb6f016624fb5c7c2925d315fdf600ac1bf21b19d5f1742" dependencies = [ "aws-credential-types", "aws-smithy-eventstream", @@ -540,7 +517,7 @@ dependencies = [ "hex", "hmac", "http 0.2.9", - "http 1.1.0", + "http 1.0.0", "once_cell", "p256", "percent-encoding", @@ -554,9 +531,9 @@ dependencies = [ [[package]] name = "aws-smithy-async" -version = "1.1.8" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d26ea8fa03025b2face2b3038a63525a10891e3d8829901d502e5384a0d8cd46" +checksum = "72ee2d09cce0ef3ae526679b522835d63e75fb427aca5413cd371e490d52dcc6" dependencies = [ "futures-util", "pin-project-lite", @@ -597,9 +574,9 @@ dependencies = [ [[package]] name = "aws-smithy-http" -version = "0.60.7" +version = "0.60.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f10fa66956f01540051b0aa7ad54574640f748f9839e843442d99b970d3aff9" +checksum = "dab56aea3cd9e1101a0a999447fb346afb680ab1406cebc44b32346e25b4117d" dependencies = [ "aws-smithy-eventstream", "aws-smithy-runtime-api", @@ -618,18 +595,18 @@ dependencies = [ [[package]] name = "aws-smithy-json" -version = "0.60.7" +version = "0.60.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4683df9469ef09468dad3473d129960119a0d3593617542b7d52086c8486f2d6" +checksum = "fd3898ca6518f9215f62678870064398f00031912390efd03f1f6ef56d83aa8e" dependencies = [ "aws-smithy-types", ] [[package]] name = "aws-smithy-query" -version = "0.60.7" +version = "0.60.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2fbd61ceb3fe8a1cb7352e42689cec5335833cd9f94103a61e98f9bb61c64bb" +checksum = "bda4b1dfc9810e35fba8a620e900522cd1bd4f9578c446e82f49d1ce41d2e9f9" dependencies = [ "aws-smithy-types", "urlencoding", @@ -637,9 +614,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.1.8" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec81002d883e5a7fd2bb063d6fb51c4999eb55d404f4fff3dd878bf4733b9f01" +checksum = "fafdab38f40ad7816e7da5dec279400dd505160780083759f01441af1bbb10ea" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -662,15 +639,14 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.2.0" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9acb931e0adaf5132de878f1398d83f8677f90ba70f01f65ff87f6d7244be1c5" +checksum = "c18276dd28852f34b3bf501f4f3719781f4999a51c7bff1a5c6dc8c4529adc29" dependencies = [ "aws-smithy-async", "aws-smithy-types", "bytes", "http 0.2.9", - "http 1.1.0", "pin-project-lite", "tokio", "tracing", @@ -679,9 +655,9 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.1.8" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "abe14dceea1e70101d38fbf2a99e6a34159477c0fb95e68e05c66bd7ae4c3729" +checksum = "bb3e134004170d3303718baa2a4eb4ca64ee0a1c0a7041dca31b38be0fb414f3" dependencies = [ "base64-simd", "bytes", @@ -702,18 +678,18 @@ dependencies = [ [[package]] name = "aws-smithy-xml" -version = "0.60.7" +version = "0.60.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "872c68cf019c0e4afc5de7753c4f7288ce4b71663212771bf5e4542eb9346ca9" +checksum = "8604a11b25e9ecaf32f9aa56b9fe253c5e2f606a3477f0071e96d3155a5ed218" dependencies = [ "xmlparser", ] [[package]] name = "aws-types" -version = "1.1.8" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dbf2f3da841a8930f159163175cf6a3d16ddde517c1b0fba7aa776822800f40" +checksum = "789bbe008e65636fe1b6dbbb374c40c8960d1232b96af5ff4aec349f9c4accf4" dependencies = [ "aws-credential-types", "aws-smithy-async", @@ -2420,9 +2396,9 @@ dependencies = [ [[package]] name = "http" -version = "1.1.0" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +checksum = "b32afd38673a8016f7c9ae69e5af41a58f81b1d31689040f2f1959594ce194ea" dependencies = [ "bytes", "fnv", @@ -2522,7 +2498,7 @@ dependencies = [ "hyper", "log", "rustls 0.21.9", - "rustls-native-certs 0.6.2", + "rustls-native-certs", "tokio", "tokio-rustls 0.24.0", ] @@ -4223,10 +4199,6 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", - "aws-config", - "aws-sdk-iam", - "aws-sigv4", - "aws-types", "base64 0.13.1", "bstr", "bytes", @@ -4245,7 +4217,6 @@ dependencies = [ "hex", "hmac", "hostname", - "http 1.1.0", "humantime", "hyper", "hyper-tungstenite", @@ -4461,9 +4432,9 @@ dependencies = [ [[package]] name = "redis" -version = "0.25.2" +version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71d64e978fd98a0e6b105d066ba4889a7301fca65aeac850a877d8797343feeb" +checksum = "c580d9cbbe1d1b479e8d67cf9daf6a62c957e6846048408b80b43ac3f6af84cd" dependencies = [ "async-trait", "bytes", @@ -4472,15 +4443,15 @@ dependencies = [ "itoa", "percent-encoding", "pin-project-lite", - "rustls 0.22.2", - "rustls-native-certs 0.7.0", - "rustls-pemfile 2.1.1", - "rustls-pki-types", + "rustls 0.21.9", + "rustls-native-certs", + "rustls-pemfile 1.0.2", + "rustls-webpki 0.101.7", "ryu", "sha1_smol", - "socket2 0.5.5", + "socket2 0.4.9", "tokio", - "tokio-rustls 0.25.0", + "tokio-rustls 0.24.0", "tokio-util", "url", ] @@ -4909,19 +4880,6 @@ dependencies = [ "security-framework", ] -[[package]] -name = "rustls-native-certs" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f1fb85efa936c42c6d5fc28d2629bb51e4b2f4b8a5211e297d599cc5a093792" -dependencies = [ - "openssl-probe", - "rustls-pemfile 2.1.1", - "rustls-pki-types", - "schannel", - "security-framework", -] - [[package]] name = "rustls-pemfile" version = "1.0.2" @@ -6189,7 +6147,7 @@ dependencies = [ "percent-encoding", "pin-project", "prost", - "rustls-native-certs 0.6.2", + "rustls-native-certs", "rustls-pemfile 1.0.2", "tokio", "tokio-rustls 0.24.0", @@ -7074,6 +7032,7 @@ dependencies = [ "aws-sigv4", "aws-smithy-async", "aws-smithy-http", + "aws-smithy-runtime-api", "aws-smithy-types", "axum", "base64 0.21.1", diff --git a/Cargo.toml b/Cargo.toml index 4dda63ff58..065f190126 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,12 +53,9 @@ async-trait = "0.1" aws-config = { version = "1.1.4", default-features = false, features=["rustls"] } aws-sdk-s3 = "1.14" aws-sdk-secretsmanager = { version = "1.14.0" } -aws-sdk-iam = "1.15.0" aws-smithy-async = { version = "1.1.4", default-features = false, features=["rt-tokio"] } aws-smithy-types = "1.1.4" aws-credential-types = "1.1.4" -aws-sigv4 = { version = "1.2.0", features = ["sign-http"] } -aws-types = "1.1.7" axum = { version = "0.6.20", features = ["ws"] } base64 = "0.13.0" bincode = "1.3" @@ -92,7 +89,6 @@ hex = "0.4" hex-literal = "0.4" hmac = "0.12.1" hostname = "0.3.1" -http = {version = "1.1.0", features = ["std"]} http-types = { version = "2", default-features = false } humantime = "2.1" humantime-serde = "1.1.1" @@ -126,7 +122,7 @@ procfs = "0.14" prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency prost = "0.11" rand = "0.8" -redis = { version = "0.25.2", features = ["tokio-rustls-comp", "keep-alive"] } +redis = { version = "0.24.0", features = ["tokio-rustls-comp", "keep-alive"] } regex = "1.10.2" reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] } reqwest-tracing = { version = "0.4.7", features = ["opentelemetry_0_20"] } diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 57a2736d5b..601b99a42f 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -11,10 +11,6 @@ testing = [] [dependencies] anyhow.workspace = true async-trait.workspace = true -aws-config.workspace = true -aws-sdk-iam.workspace = true -aws-sigv4.workspace = true -aws-types.workspace = true base64.workspace = true bstr.workspace = true bytes = { workspace = true, features = ["serde"] } @@ -31,7 +27,6 @@ hashlink.workspace = true hex.workspace = true hmac.workspace = true hostname.workspace = true -http.workspace = true humantime.workspace = true hyper-tungstenite.workspace = true hyper.workspace = true diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index d38439c2a0..b3d4fc0411 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -1,10 +1,3 @@ -use aws_config::environment::EnvironmentVariableCredentialsProvider; -use aws_config::imds::credentials::ImdsCredentialsProvider; -use aws_config::meta::credentials::CredentialsProviderChain; -use aws_config::meta::region::RegionProviderChain; -use aws_config::profile::ProfileFileCredentialsProvider; -use aws_config::provider_config::ProviderConfig; -use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider; use futures::future::Either; use proxy::auth; use proxy::auth::backend::MaybeOwned; @@ -17,14 +10,11 @@ use proxy::config::ProjectInfoCacheOptions; use proxy::console; use proxy::context::parquet::ParquetUploadArgs; use proxy::http; -use proxy::metrics::NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT; use proxy::rate_limiter::EndpointRateLimiter; use proxy::rate_limiter::RateBucketInfo; use proxy::rate_limiter::RateLimiterConfig; -use proxy::redis::cancellation_publisher::RedisPublisherClient; -use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider; -use proxy::redis::elasticache; use proxy::redis::notifications; +use proxy::redis::publisher::RedisPublisherClient; use proxy::serverless::GlobalConnPoolOptions; use proxy::usage_metrics; @@ -160,24 +150,9 @@ struct ProxyCliArgs { /// disable ip check for http requests. If it is too time consuming, it could be turned off. #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] disable_ip_check_for_http: bool, - /// redis url for notifications (if empty, redis_host:port will be used for both notifications and streaming connections) + /// redis url for notifications. #[clap(long)] redis_notifications: Option, - /// redis host for streaming connections (might be different from the notifications host) - #[clap(long)] - redis_host: Option, - /// redis port for streaming connections (might be different from the notifications host) - #[clap(long)] - redis_port: Option, - /// redis cluster name, used in aws elasticache - #[clap(long)] - redis_cluster_name: Option, - /// redis user_id, used in aws elasticache - #[clap(long)] - redis_user_id: Option, - /// aws region to retrieve credentials - #[clap(long, default_value_t = String::new())] - aws_region: String, /// cache for `project_info` (use `size=0` to disable) #[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)] project_info_cache: String, @@ -241,61 +216,6 @@ async fn main() -> anyhow::Result<()> { let config = build_config(&args)?; info!("Authentication backend: {}", config.auth_backend); - info!("Using region: {}", config.aws_region); - - let region_provider = RegionProviderChain::default_provider().or_else(&*config.aws_region); // Replace with your Redis region if needed - let provider_conf = - ProviderConfig::without_region().with_region(region_provider.region().await); - let aws_credentials_provider = { - // uses "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY" - CredentialsProviderChain::first_try("env", EnvironmentVariableCredentialsProvider::new()) - // uses "AWS_PROFILE" / `aws sso login --profile ` - .or_else( - "profile-sso", - ProfileFileCredentialsProvider::builder() - .configure(&provider_conf) - .build(), - ) - // uses "AWS_WEB_IDENTITY_TOKEN_FILE", "AWS_ROLE_ARN", "AWS_ROLE_SESSION_NAME" - // needed to access remote extensions bucket - .or_else( - "token", - WebIdentityTokenCredentialsProvider::builder() - .configure(&provider_conf) - .build(), - ) - // uses imds v2 - .or_else("imds", ImdsCredentialsProvider::builder().build()) - }; - let elasticache_credentials_provider = Arc::new(elasticache::CredentialsProvider::new( - elasticache::AWSIRSAConfig::new( - config.aws_region.clone(), - args.redis_cluster_name, - args.redis_user_id, - ), - aws_credentials_provider, - )); - let redis_notifications_client = - match (args.redis_notifications, (args.redis_host, args.redis_port)) { - (Some(url), _) => { - info!("Starting redis notifications listener ({url})"); - Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url)) - } - (None, (Some(host), Some(port))) => Some( - ConnectionWithCredentialsProvider::new_with_credentials_provider( - host, - port, - elasticache_credentials_provider.clone(), - ), - ), - (None, (None, None)) => { - warn!("Redis is disabled"); - None - } - _ => { - bail!("redis-host and redis-port must be specified together"); - } - }; // Check that we can bind to address before further initialization let http_address: SocketAddr = args.http.parse()?; @@ -313,22 +233,17 @@ async fn main() -> anyhow::Result<()> { let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit)); let cancel_map = CancelMap::default(); - - // let redis_notifications_client = redis_notifications_client.map(|x| Box::leak(Box::new(x))); - let redis_publisher = match &redis_notifications_client { - Some(redis_publisher) => Some(Arc::new(Mutex::new(RedisPublisherClient::new( - redis_publisher.clone(), + let redis_publisher = match &args.redis_notifications { + Some(url) => Some(Arc::new(Mutex::new(RedisPublisherClient::new( + url, args.region.clone(), &config.redis_rps_limit, )?))), None => None, }; - let cancellation_handler = Arc::new(CancellationHandler::< - Option>>, - >::new( + let cancellation_handler = Arc::new(CancellationHandler::new( cancel_map.clone(), redis_publisher, - NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT, )); // client facing tasks. these will exit on error or on cancellation @@ -375,16 +290,17 @@ async fn main() -> anyhow::Result<()> { if let auth::BackendType::Console(api, _) = &config.auth_backend { if let proxy::console::provider::ConsoleBackend::Console(api) = &**api { - if let Some(redis_notifications_client) = redis_notifications_client { - let cache = api.caches.project_info.clone(); + let cache = api.caches.project_info.clone(); + if let Some(url) = args.redis_notifications { + info!("Starting redis notifications listener ({url})"); maintenance_tasks.spawn(notifications::task_main( - redis_notifications_client.clone(), + url.to_owned(), cache.clone(), cancel_map.clone(), args.region.clone(), )); - maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); } + maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); } } @@ -529,8 +445,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { endpoint_rps_limit, redis_rps_limit, handshake_timeout: args.handshake_timeout, + // TODO: add this argument region: args.region.clone(), - aws_region: args.aws_region.clone(), })); Ok(config) diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index 8054f33b6c..c9607909b3 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -1,3 +1,4 @@ +use async_trait::async_trait; use dashmap::DashMap; use pq_proto::CancelKeyData; use std::{net::SocketAddr, sync::Arc}; @@ -9,26 +10,18 @@ use tracing::info; use uuid::Uuid; use crate::{ - error::ReportableError, - metrics::NUM_CANCELLATION_REQUESTS, - redis::cancellation_publisher::{ - CancellationPublisher, CancellationPublisherMut, RedisPublisherClient, - }, + error::ReportableError, metrics::NUM_CANCELLATION_REQUESTS, + redis::publisher::RedisPublisherClient, }; pub type CancelMap = Arc>>; -pub type CancellationHandlerMain = CancellationHandler>>>; -pub type CancellationHandlerMainInternal = Option>>; /// Enables serving `CancelRequest`s. /// -/// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances. -pub struct CancellationHandler

{ +/// If there is a `RedisPublisherClient` available, it will be used to publish the cancellation key to other proxy instances. +pub struct CancellationHandler { map: CancelMap, - client: P, - /// This field used for the monitoring purposes. - /// Represents the source of the cancellation request. - from: &'static str, + redis_client: Option>>, } #[derive(Debug, Error)] @@ -51,9 +44,49 @@ impl ReportableError for CancelError { } } -impl CancellationHandler

{ +impl CancellationHandler { + pub fn new(map: CancelMap, redis_client: Option>>) -> Self { + Self { map, redis_client } + } + /// Cancel a running query for the corresponding connection. + pub async fn cancel_session( + &self, + key: CancelKeyData, + session_id: Uuid, + ) -> Result<(), CancelError> { + let from = "from_client"; + // NB: we should immediately release the lock after cloning the token. + let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else { + tracing::warn!("query cancellation key not found: {key}"); + if let Some(redis_client) = &self.redis_client { + NUM_CANCELLATION_REQUESTS + .with_label_values(&[from, "not_found"]) + .inc(); + info!("publishing cancellation key to Redis"); + match redis_client.lock().await.try_publish(key, session_id).await { + Ok(()) => { + info!("cancellation key successfuly published to Redis"); + } + Err(e) => { + tracing::error!("failed to publish a message: {e}"); + return Err(CancelError::IO(std::io::Error::new( + std::io::ErrorKind::Other, + e.to_string(), + ))); + } + } + } + return Ok(()); + }; + NUM_CANCELLATION_REQUESTS + .with_label_values(&[from, "found"]) + .inc(); + info!("cancelling query per user's request using key {key}"); + cancel_closure.try_cancel_query().await + } + /// Run async action within an ephemeral session identified by [`CancelKeyData`]. - pub fn get_session(self: Arc) -> Session

{ + pub fn get_session(self: Arc) -> Session { // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't // expose it and we don't want to do another roundtrip to query // for it. The client will be able to notice that this is not the @@ -79,39 +112,9 @@ impl CancellationHandler

{ cancellation_handler: self, } } - /// Try to cancel a running query for the corresponding connection. - /// If the cancellation key is not found, it will be published to Redis. - pub async fn cancel_session( - &self, - key: CancelKeyData, - session_id: Uuid, - ) -> Result<(), CancelError> { - // NB: we should immediately release the lock after cloning the token. - let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else { - tracing::warn!("query cancellation key not found: {key}"); - NUM_CANCELLATION_REQUESTS - .with_label_values(&[self.from, "not_found"]) - .inc(); - match self.client.try_publish(key, session_id).await { - Ok(()) => {} // do nothing - Err(e) => { - return Err(CancelError::IO(std::io::Error::new( - std::io::ErrorKind::Other, - e.to_string(), - ))); - } - } - return Ok(()); - }; - NUM_CANCELLATION_REQUESTS - .with_label_values(&[self.from, "found"]) - .inc(); - info!("cancelling query per user's request using key {key}"); - cancel_closure.try_cancel_query().await - } #[cfg(test)] - fn contains(&self, session: &Session

) -> bool { + fn contains(&self, session: &Session) -> bool { self.map.contains_key(&session.key) } @@ -121,19 +124,31 @@ impl CancellationHandler

{ } } -impl CancellationHandler<()> { - pub fn new(map: CancelMap, from: &'static str) -> Self { - Self { - map, - client: (), - from, - } - } +#[async_trait] +pub trait NotificationsCancellationHandler { + async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError>; } -impl CancellationHandler>>> { - pub fn new(map: CancelMap, client: Option>>, from: &'static str) -> Self { - Self { map, client, from } +#[async_trait] +impl NotificationsCancellationHandler for CancellationHandler { + async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError> { + let from = "from_redis"; + let cancel_closure = self.map.get(&key).and_then(|x| x.clone()); + match cancel_closure { + Some(cancel_closure) => { + NUM_CANCELLATION_REQUESTS + .with_label_values(&[from, "found"]) + .inc(); + cancel_closure.try_cancel_query().await + } + None => { + NUM_CANCELLATION_REQUESTS + .with_label_values(&[from, "not_found"]) + .inc(); + tracing::warn!("query cancellation key not found: {key}"); + Ok(()) + } + } } } @@ -163,14 +178,14 @@ impl CancelClosure { } /// Helper for registering query cancellation tokens. -pub struct Session

{ +pub struct Session { /// The user-facing key identifying this session. key: CancelKeyData, /// The [`CancelMap`] this session belongs to. - cancellation_handler: Arc>, + cancellation_handler: Arc, } -impl

Session

{ +impl Session { /// Store the cancel token for the given session. /// This enables query cancellation in `crate::proxy::prepare_client_connection`. pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData { @@ -183,7 +198,7 @@ impl

Session

{ } } -impl

Drop for Session

{ +impl Drop for Session { fn drop(&mut self) { self.cancellation_handler.map.remove(&self.key); info!("dropped query cancellation key {}", &self.key); @@ -192,16 +207,14 @@ impl

Drop for Session

{ #[cfg(test)] mod tests { - use crate::metrics::NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS; - use super::*; #[tokio::test] async fn check_session_drop() -> anyhow::Result<()> { - let cancellation_handler = Arc::new(CancellationHandler::<()>::new( - CancelMap::default(), - NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS, - )); + let cancellation_handler = Arc::new(CancellationHandler { + map: CancelMap::default(), + redis_client: None, + }); let session = cancellation_handler.clone().get_session(); assert!(cancellation_handler.contains(&session)); diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 45f8d76144..437ec9f401 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -28,7 +28,6 @@ pub struct ProxyConfig { pub redis_rps_limit: Vec, pub region: String, pub handshake_timeout: Duration, - pub aws_region: String, } #[derive(Debug)] diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index eed45e421b..02ebcd6aaa 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -161,9 +161,6 @@ pub static NUM_CANCELLATION_REQUESTS: Lazy = Lazy::new(|| { .unwrap() }); -pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT: &str = "from_client"; -pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS: &str = "from_redis"; - pub enum Waiting { Cplane, Client, diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 843bfc08cf..ab5bf5d494 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -10,7 +10,7 @@ pub mod wake_compute; use crate::{ auth, - cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal}, + cancellation::{self, CancellationHandler}, compute, config::{ProxyConfig, TlsConfig}, context::RequestMonitoring, @@ -62,7 +62,7 @@ pub async fn task_main( listener: tokio::net::TcpListener, cancellation_token: CancellationToken, endpoint_rate_limiter: Arc, - cancellation_handler: Arc, + cancellation_handler: Arc, ) -> anyhow::Result<()> { scopeguard::defer! { info!("proxy has shut down"); @@ -233,12 +233,12 @@ impl ReportableError for ClientRequestError { pub async fn handle_client( config: &'static ProxyConfig, ctx: &mut RequestMonitoring, - cancellation_handler: Arc, + cancellation_handler: Arc, stream: S, mode: ClientMode, endpoint_rate_limiter: Arc, conn_gauge: IntCounterPairGuard, -) -> Result>, ClientRequestError> { +) -> Result>, ClientRequestError> { info!("handling interactive connection from client"); let proto = ctx.protocol; @@ -338,9 +338,9 @@ pub async fn handle_client( /// Finish client connection initialization: confirm auth success, send params, etc. #[tracing::instrument(skip_all)] -async fn prepare_client_connection

( +async fn prepare_client_connection( node: &compute::PostgresConnection, - session: &cancellation::Session

, + session: &cancellation::Session, stream: &mut PqStream, ) -> Result<(), std::io::Error> { // Register compute's query cancellation token and produce a new, unique one. diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index f6d4314391..b2f682fd2f 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -55,17 +55,17 @@ pub async fn proxy_pass( Ok(()) } -pub struct ProxyPassthrough { +pub struct ProxyPassthrough { pub client: Stream, pub compute: PostgresConnection, pub aux: MetricsAuxInfo, pub req: IntCounterPairGuard, pub conn: IntCounterPairGuard, - pub cancel: cancellation::Session

, + pub cancel: cancellation::Session, } -impl ProxyPassthrough { +impl ProxyPassthrough { pub async fn proxy_pass(self) -> anyhow::Result<()> { let res = proxy_pass(self.client, self.compute.stream, self.aux).await; self.compute.cancel_closure.try_cancel_query().await?; diff --git a/proxy/src/redis.rs b/proxy/src/redis.rs index a322f0368c..35d6db074e 100644 --- a/proxy/src/redis.rs +++ b/proxy/src/redis.rs @@ -1,4 +1,2 @@ -pub mod cancellation_publisher; -pub mod connection_with_credentials_provider; -pub mod elasticache; pub mod notifications; +pub mod publisher; diff --git a/proxy/src/redis/cancellation_publisher.rs b/proxy/src/redis/cancellation_publisher.rs deleted file mode 100644 index d9efc3561b..0000000000 --- a/proxy/src/redis/cancellation_publisher.rs +++ /dev/null @@ -1,167 +0,0 @@ -use std::sync::Arc; - -use async_trait::async_trait; -use pq_proto::CancelKeyData; -use redis::AsyncCommands; -use tokio::sync::Mutex; -use uuid::Uuid; - -use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter}; - -use super::{ - connection_with_credentials_provider::ConnectionWithCredentialsProvider, - notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME}, -}; - -#[async_trait] -pub trait CancellationPublisherMut: Send + Sync + 'static { - async fn try_publish( - &mut self, - cancel_key_data: CancelKeyData, - session_id: Uuid, - ) -> anyhow::Result<()>; -} - -#[async_trait] -pub trait CancellationPublisher: Send + Sync + 'static { - async fn try_publish( - &self, - cancel_key_data: CancelKeyData, - session_id: Uuid, - ) -> anyhow::Result<()>; -} - -#[async_trait] -impl CancellationPublisherMut for () { - async fn try_publish( - &mut self, - _cancel_key_data: CancelKeyData, - _session_id: Uuid, - ) -> anyhow::Result<()> { - Ok(()) - } -} - -#[async_trait] -impl CancellationPublisher for P { - async fn try_publish( - &self, - _cancel_key_data: CancelKeyData, - _session_id: Uuid, - ) -> anyhow::Result<()> { - self.try_publish(_cancel_key_data, _session_id).await - } -} - -#[async_trait] -impl CancellationPublisher for Option

{ - async fn try_publish( - &self, - cancel_key_data: CancelKeyData, - session_id: Uuid, - ) -> anyhow::Result<()> { - if let Some(p) = self { - p.try_publish(cancel_key_data, session_id).await - } else { - Ok(()) - } - } -} - -#[async_trait] -impl CancellationPublisher for Arc> { - async fn try_publish( - &self, - cancel_key_data: CancelKeyData, - session_id: Uuid, - ) -> anyhow::Result<()> { - self.lock() - .await - .try_publish(cancel_key_data, session_id) - .await - } -} - -pub struct RedisPublisherClient { - client: ConnectionWithCredentialsProvider, - region_id: String, - limiter: RedisRateLimiter, -} - -impl RedisPublisherClient { - pub fn new( - client: ConnectionWithCredentialsProvider, - region_id: String, - info: &'static [RateBucketInfo], - ) -> anyhow::Result { - Ok(Self { - client, - region_id, - limiter: RedisRateLimiter::new(info), - }) - } - - async fn publish( - &mut self, - cancel_key_data: CancelKeyData, - session_id: Uuid, - ) -> anyhow::Result<()> { - let payload = serde_json::to_string(&Notification::Cancel(CancelSession { - region_id: Some(self.region_id.clone()), - cancel_key_data, - session_id, - }))?; - self.client.publish(PROXY_CHANNEL_NAME, payload).await?; - Ok(()) - } - pub async fn try_connect(&mut self) -> anyhow::Result<()> { - match self.client.connect().await { - Ok(()) => {} - Err(e) => { - tracing::error!("failed to connect to redis: {e}"); - return Err(e); - } - } - Ok(()) - } - async fn try_publish_internal( - &mut self, - cancel_key_data: CancelKeyData, - session_id: Uuid, - ) -> anyhow::Result<()> { - if !self.limiter.check() { - tracing::info!("Rate limit exceeded. Skipping cancellation message"); - return Err(anyhow::anyhow!("Rate limit exceeded")); - } - match self.publish(cancel_key_data, session_id).await { - Ok(()) => return Ok(()), - Err(e) => { - tracing::error!("failed to publish a message: {e}"); - } - } - tracing::info!("Publisher is disconnected. Reconnectiong..."); - self.try_connect().await?; - self.publish(cancel_key_data, session_id).await - } -} - -#[async_trait] -impl CancellationPublisherMut for RedisPublisherClient { - async fn try_publish( - &mut self, - cancel_key_data: CancelKeyData, - session_id: Uuid, - ) -> anyhow::Result<()> { - tracing::info!("publishing cancellation key to Redis"); - match self.try_publish_internal(cancel_key_data, session_id).await { - Ok(()) => { - tracing::info!("cancellation key successfuly published to Redis"); - Ok(()) - } - Err(e) => { - tracing::error!("failed to publish a message: {e}"); - Err(e) - } - } - } -} diff --git a/proxy/src/redis/connection_with_credentials_provider.rs b/proxy/src/redis/connection_with_credentials_provider.rs deleted file mode 100644 index d183abb53a..0000000000 --- a/proxy/src/redis/connection_with_credentials_provider.rs +++ /dev/null @@ -1,225 +0,0 @@ -use std::{sync::Arc, time::Duration}; - -use futures::FutureExt; -use redis::{ - aio::{ConnectionLike, MultiplexedConnection}, - ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult, -}; -use tokio::task::JoinHandle; -use tracing::{error, info}; - -use super::elasticache::CredentialsProvider; - -enum Credentials { - Static(ConnectionInfo), - Dynamic(Arc, redis::ConnectionAddr), -} - -impl Clone for Credentials { - fn clone(&self) -> Self { - match self { - Credentials::Static(info) => Credentials::Static(info.clone()), - Credentials::Dynamic(provider, addr) => { - Credentials::Dynamic(Arc::clone(provider), addr.clone()) - } - } - } -} - -/// A wrapper around `redis::MultiplexedConnection` that automatically refreshes the token. -/// Provides PubSub connection without credentials refresh. -pub struct ConnectionWithCredentialsProvider { - credentials: Credentials, - con: Option, - refresh_token_task: Option>, - mutex: tokio::sync::Mutex<()>, -} - -impl Clone for ConnectionWithCredentialsProvider { - fn clone(&self) -> Self { - Self { - credentials: self.credentials.clone(), - con: None, - refresh_token_task: None, - mutex: tokio::sync::Mutex::new(()), - } - } -} - -impl ConnectionWithCredentialsProvider { - pub fn new_with_credentials_provider( - host: String, - port: u16, - credentials_provider: Arc, - ) -> Self { - Self { - credentials: Credentials::Dynamic( - credentials_provider, - redis::ConnectionAddr::TcpTls { - host, - port, - insecure: false, - tls_params: None, - }, - ), - con: None, - refresh_token_task: None, - mutex: tokio::sync::Mutex::new(()), - } - } - - pub fn new_with_static_credentials(params: T) -> Self { - Self { - credentials: Credentials::Static(params.into_connection_info().unwrap()), - con: None, - refresh_token_task: None, - mutex: tokio::sync::Mutex::new(()), - } - } - - pub async fn connect(&mut self) -> anyhow::Result<()> { - let _guard = self.mutex.lock().await; - if let Some(con) = self.con.as_mut() { - match redis::cmd("PING").query_async(con).await { - Ok(()) => { - return Ok(()); - } - Err(e) => { - error!("Error during PING: {e:?}"); - } - } - } else { - info!("Connection is not established"); - } - info!("Establishing a new connection..."); - self.con = None; - if let Some(f) = self.refresh_token_task.take() { - f.abort() - } - let con = self - .get_client() - .await? - .get_multiplexed_tokio_connection() - .await?; - if let Credentials::Dynamic(credentials_provider, _) = &self.credentials { - let credentials_provider = credentials_provider.clone(); - let con2 = con.clone(); - let f = tokio::spawn(async move { - let _ = Self::keep_connection(con2, credentials_provider).await; - }); - self.refresh_token_task = Some(f); - } - self.con = Some(con); - Ok(()) - } - - async fn get_connection_info(&self) -> anyhow::Result { - match &self.credentials { - Credentials::Static(info) => Ok(info.clone()), - Credentials::Dynamic(provider, addr) => { - let (username, password) = provider.provide_credentials().await?; - Ok(ConnectionInfo { - addr: addr.clone(), - redis: RedisConnectionInfo { - db: 0, - username: Some(username), - password: Some(password.clone()), - }, - }) - } - } - } - - async fn get_client(&self) -> anyhow::Result { - let client = redis::Client::open(self.get_connection_info().await?)?; - Ok(client) - } - - // PubSub does not support credentials refresh. - // Requires manual reconnection every 12h. - pub async fn get_async_pubsub(&self) -> anyhow::Result { - Ok(self.get_client().await?.get_async_pubsub().await?) - } - - // The connection lives for 12h. - // It can be prolonged with sending `AUTH` commands with the refreshed token. - // https://docs.aws.amazon.com/AmazonElastiCache/latest/red-ug/auth-iam.html#auth-iam-limits - async fn keep_connection( - mut con: MultiplexedConnection, - credentials_provider: Arc, - ) -> anyhow::Result<()> { - loop { - // The connection lives for 12h, for the sanity check we refresh it every hour. - tokio::time::sleep(Duration::from_secs(60 * 60)).await; - match Self::refresh_token(&mut con, credentials_provider.clone()).await { - Ok(()) => { - info!("Token refreshed"); - } - Err(e) => { - error!("Error during token refresh: {e:?}"); - } - } - } - } - async fn refresh_token( - con: &mut MultiplexedConnection, - credentials_provider: Arc, - ) -> anyhow::Result<()> { - let (user, password) = credentials_provider.provide_credentials().await?; - redis::cmd("AUTH") - .arg(user) - .arg(password) - .query_async(con) - .await?; - Ok(()) - } - /// Sends an already encoded (packed) command into the TCP socket and - /// reads the single response from it. - pub async fn send_packed_command(&mut self, cmd: &redis::Cmd) -> RedisResult { - // Clone connection to avoid having to lock the ArcSwap in write mode - let con = self.con.as_mut().ok_or(redis::RedisError::from(( - redis::ErrorKind::IoError, - "Connection not established", - )))?; - con.send_packed_command(cmd).await - } - - /// Sends multiple already encoded (packed) command into the TCP socket - /// and reads `count` responses from it. This is used to implement - /// pipelining. - pub async fn send_packed_commands( - &mut self, - cmd: &redis::Pipeline, - offset: usize, - count: usize, - ) -> RedisResult> { - // Clone shared connection future to avoid having to lock the ArcSwap in write mode - let con = self.con.as_mut().ok_or(redis::RedisError::from(( - redis::ErrorKind::IoError, - "Connection not established", - )))?; - con.send_packed_commands(cmd, offset, count).await - } -} - -impl ConnectionLike for ConnectionWithCredentialsProvider { - fn req_packed_command<'a>( - &'a mut self, - cmd: &'a redis::Cmd, - ) -> redis::RedisFuture<'a, redis::Value> { - (async move { self.send_packed_command(cmd).await }).boxed() - } - - fn req_packed_commands<'a>( - &'a mut self, - cmd: &'a redis::Pipeline, - offset: usize, - count: usize, - ) -> redis::RedisFuture<'a, Vec> { - (async move { self.send_packed_commands(cmd, offset, count).await }).boxed() - } - - fn get_db(&self) -> i64 { - 0 - } -} diff --git a/proxy/src/redis/elasticache.rs b/proxy/src/redis/elasticache.rs deleted file mode 100644 index eded8250af..0000000000 --- a/proxy/src/redis/elasticache.rs +++ /dev/null @@ -1,110 +0,0 @@ -use std::time::{Duration, SystemTime}; - -use aws_config::meta::credentials::CredentialsProviderChain; -use aws_sdk_iam::config::ProvideCredentials; -use aws_sigv4::http_request::{ - self, SignableBody, SignableRequest, SignatureLocation, SigningSettings, -}; -use tracing::info; - -#[derive(Debug)] -pub struct AWSIRSAConfig { - region: String, - service_name: String, - cluster_name: String, - user_id: String, - token_ttl: Duration, - action: String, -} - -impl AWSIRSAConfig { - pub fn new(region: String, cluster_name: Option, user_id: Option) -> Self { - AWSIRSAConfig { - region, - service_name: "elasticache".to_string(), - cluster_name: cluster_name.unwrap_or_default(), - user_id: user_id.unwrap_or_default(), - // "The IAM authentication token is valid for 15 minutes" - // https://docs.aws.amazon.com/memorydb/latest/devguide/auth-iam.html#auth-iam-limits - token_ttl: Duration::from_secs(15 * 60), - action: "connect".to_string(), - } - } -} - -/// Credentials provider for AWS elasticache authentication. -/// -/// Official documentation: -/// -/// -/// Useful resources: -/// -pub struct CredentialsProvider { - config: AWSIRSAConfig, - credentials_provider: CredentialsProviderChain, -} - -impl CredentialsProvider { - pub fn new(config: AWSIRSAConfig, credentials_provider: CredentialsProviderChain) -> Self { - CredentialsProvider { - config, - credentials_provider, - } - } - pub async fn provide_credentials(&self) -> anyhow::Result<(String, String)> { - let aws_credentials = self - .credentials_provider - .provide_credentials() - .await? - .into(); - info!("AWS credentials successfully obtained"); - info!("Connecting to Redis with configuration: {:?}", self.config); - let mut settings = SigningSettings::default(); - settings.signature_location = SignatureLocation::QueryParams; - settings.expires_in = Some(self.config.token_ttl); - let signing_params = aws_sigv4::sign::v4::SigningParams::builder() - .identity(&aws_credentials) - .region(&self.config.region) - .name(&self.config.service_name) - .time(SystemTime::now()) - .settings(settings) - .build()? - .into(); - let auth_params = [ - ("Action", &self.config.action), - ("User", &self.config.user_id), - ]; - let auth_params = url::form_urlencoded::Serializer::new(String::new()) - .extend_pairs(auth_params) - .finish(); - let auth_uri = http::Uri::builder() - .scheme("http") - .authority(self.config.cluster_name.as_bytes()) - .path_and_query(format!("/?{auth_params}")) - .build()?; - info!("{}", auth_uri); - - // Convert the HTTP request into a signable request - let signable_request = SignableRequest::new( - "GET", - auth_uri.to_string(), - std::iter::empty(), - SignableBody::Bytes(&[]), - )?; - - // Sign and then apply the signature to the request - let (si, _) = http_request::sign(signable_request, &signing_params)?.into_parts(); - let mut signable_request = http::Request::builder() - .method("GET") - .uri(auth_uri) - .body(())?; - si.apply_to_request_http1x(&mut signable_request); - Ok(( - self.config.user_id.clone(), - signable_request - .uri() - .to_string() - .replacen("http://", "", 1), - )) - } -} diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 8b7e3e3419..6ae848c0d2 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -6,12 +6,11 @@ use redis::aio::PubSub; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::{ cache::project_info::ProjectInfoCache, - cancellation::{CancelMap, CancellationHandler}, + cancellation::{CancelMap, CancellationHandler, NotificationsCancellationHandler}, intern::{ProjectIdInt, RoleNameInt}, - metrics::{NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS, REDIS_BROKEN_MESSAGES}, + metrics::REDIS_BROKEN_MESSAGES, }; const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; @@ -19,13 +18,23 @@ pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates"; const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20); const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20); -async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result { - let mut conn = client.get_async_pubsub().await?; - tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`"); - conn.subscribe(CPLANE_CHANNEL_NAME).await?; - tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`"); - conn.subscribe(PROXY_CHANNEL_NAME).await?; - Ok(conn) +struct RedisConsumerClient { + client: redis::Client, +} + +impl RedisConsumerClient { + pub fn new(url: &str) -> anyhow::Result { + let client = redis::Client::open(url)?; + Ok(Self { client }) + } + async fn try_connect(&self) -> anyhow::Result { + let mut conn = self.client.get_async_connection().await?.into_pubsub(); + tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`"); + conn.subscribe(CPLANE_CHANNEL_NAME).await?; + tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`"); + conn.subscribe(PROXY_CHANNEL_NAME).await?; + Ok(conn) + } } #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] @@ -71,18 +80,21 @@ where serde_json::from_str(&s).map_err(::custom) } -struct MessageHandler { +struct MessageHandler< + C: ProjectInfoCache + Send + Sync + 'static, + H: NotificationsCancellationHandler + Send + Sync + 'static, +> { cache: Arc, - cancellation_handler: Arc>, + cancellation_handler: Arc, region_id: String, } -impl MessageHandler { - pub fn new( - cache: Arc, - cancellation_handler: Arc>, - region_id: String, - ) -> Self { +impl< + C: ProjectInfoCache + Send + Sync + 'static, + H: NotificationsCancellationHandler + Send + Sync + 'static, + > MessageHandler +{ + pub fn new(cache: Arc, cancellation_handler: Arc, region_id: String) -> Self { Self { cache, cancellation_handler, @@ -127,7 +139,7 @@ impl MessageHandler { // This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message. match self .cancellation_handler - .cancel_session(cancel_session.cancel_key_data, uuid::Uuid::nil()) + .cancel_session_no_publish(cancel_session.cancel_key_data) .await { Ok(()) => {} @@ -170,7 +182,7 @@ fn invalidate_cache(cache: Arc, msg: Notification) { /// Handle console's invalidation messages. #[tracing::instrument(name = "console_notifications", skip_all)] pub async fn task_main( - redis: ConnectionWithCredentialsProvider, + url: String, cache: Arc, cancel_map: CancelMap, region_id: String, @@ -181,15 +193,13 @@ where cache.enable_ttl(); let handler = MessageHandler::new( cache, - Arc::new(CancellationHandler::<()>::new( - cancel_map, - NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS, - )), + Arc::new(CancellationHandler::new(cancel_map, None)), region_id, ); loop { - let mut conn = match try_connect(&redis).await { + let redis = RedisConsumerClient::new(&url)?; + let conn = match redis.try_connect().await { Ok(conn) => { handler.disable_ttl(); conn @@ -202,7 +212,7 @@ where continue; } }; - let mut stream = conn.on_message(); + let mut stream = conn.into_on_message(); while let Some(msg) = stream.next().await { match handler.handle_message(msg).await { Ok(()) => {} diff --git a/proxy/src/redis/publisher.rs b/proxy/src/redis/publisher.rs new file mode 100644 index 0000000000..f85593afdd --- /dev/null +++ b/proxy/src/redis/publisher.rs @@ -0,0 +1,80 @@ +use pq_proto::CancelKeyData; +use redis::AsyncCommands; +use uuid::Uuid; + +use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter}; + +use super::notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME}; + +pub struct RedisPublisherClient { + client: redis::Client, + publisher: Option, + region_id: String, + limiter: RedisRateLimiter, +} + +impl RedisPublisherClient { + pub fn new( + url: &str, + region_id: String, + info: &'static [RateBucketInfo], + ) -> anyhow::Result { + let client = redis::Client::open(url)?; + Ok(Self { + client, + publisher: None, + region_id, + limiter: RedisRateLimiter::new(info), + }) + } + pub async fn try_publish( + &mut self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()> { + if !self.limiter.check() { + tracing::info!("Rate limit exceeded. Skipping cancellation message"); + return Err(anyhow::anyhow!("Rate limit exceeded")); + } + match self.publish(cancel_key_data, session_id).await { + Ok(()) => return Ok(()), + Err(e) => { + tracing::error!("failed to publish a message: {e}"); + self.publisher = None; + } + } + tracing::info!("Publisher is disconnected. Reconnectiong..."); + self.try_connect().await?; + self.publish(cancel_key_data, session_id).await + } + + async fn publish( + &mut self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()> { + let conn = self + .publisher + .as_mut() + .ok_or_else(|| anyhow::anyhow!("not connected"))?; + let payload = serde_json::to_string(&Notification::Cancel(CancelSession { + region_id: Some(self.region_id.clone()), + cancel_key_data, + session_id, + }))?; + conn.publish(PROXY_CHANNEL_NAME, payload).await?; + Ok(()) + } + pub async fn try_connect(&mut self) -> anyhow::Result<()> { + match self.client.get_async_connection().await { + Ok(conn) => { + self.publisher = Some(conn); + } + Err(e) => { + tracing::error!("failed to connect to redis: {e}"); + return Err(e.into()); + } + } + Ok(()) + } +} diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index a2010fd613..be9f90acde 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -21,12 +21,11 @@ pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; use tokio_util::task::TaskTracker; use tracing::instrument::Instrumented; -use crate::cancellation::CancellationHandlerMain; -use crate::config::ProxyConfig; use crate::context::RequestMonitoring; use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard}; use crate::rate_limiter::EndpointRateLimiter; use crate::serverless::backend::PoolingBackend; +use crate::{cancellation::CancellationHandler, config::ProxyConfig}; use hyper::{ server::conn::{AddrIncoming, AddrStream}, Body, Method, Request, Response, @@ -48,7 +47,7 @@ pub async fn task_main( ws_listener: TcpListener, cancellation_token: CancellationToken, endpoint_rate_limiter: Arc, - cancellation_handler: Arc, + cancellation_handler: Arc, ) -> anyhow::Result<()> { scopeguard::defer! { info!("websocket server has shut down"); @@ -238,7 +237,7 @@ async fn request_handler( config: &'static ProxyConfig, backend: Arc, ws_connections: TaskTracker, - cancellation_handler: Arc, + cancellation_handler: Arc, peer_addr: IpAddr, endpoint_rate_limiter: Arc, // used to cancel in-flight HTTP requests. not used to cancel websockets diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index ada6c974f4..a72ede6d0a 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -1,5 +1,5 @@ use crate::{ - cancellation::CancellationHandlerMain, + cancellation::CancellationHandler, config::ProxyConfig, context::RequestMonitoring, error::{io_error, ReportableError}, @@ -134,7 +134,7 @@ pub async fn serve_websocket( config: &'static ProxyConfig, mut ctx: RequestMonitoring, websocket: HyperWebsocket, - cancellation_handler: Arc, + cancellation_handler: Arc, hostname: Option, endpoint_rate_limiter: Arc, ) -> anyhow::Result<()> { diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 7b8228a082..152c452dd4 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -19,7 +19,8 @@ aws-runtime = { version = "1", default-features = false, features = ["event-stre aws-sigv4 = { version = "1", features = ["http0-compat", "sign-eventstream", "sigv4a"] } aws-smithy-async = { version = "1", default-features = false, features = ["rt-tokio"] } aws-smithy-http = { version = "0.60", default-features = false, features = ["event-stream"] } -aws-smithy-types = { version = "1", default-features = false, features = ["byte-stream-poll-next", "http-body-0-4-x", "rt-tokio", "test-util"] } +aws-smithy-runtime-api = { version = "1", features = ["client", "http-02x", "http-auth"] } +aws-smithy-types = { version = "1", default-features = false, features = ["byte-stream-poll-next", "http-body-0-4-x", "rt-tokio"] } axum = { version = "0.6", features = ["ws"] } base64 = { version = "0.21", features = ["alloc"] } base64ct = { version = "1", default-features = false, features = ["std"] }