diff --git a/Cargo.lock b/Cargo.lock index 824cac13b3..dcf1c49924 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -347,9 +347,9 @@ dependencies = [ [[package]] name = "aws-credential-types" -version = "1.1.4" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33cc49dcdd31c8b6e79850a179af4c367669150c7ac0135f176c61bec81a70f7" +checksum = "fa8587ae17c8e967e4b05a62d495be2fb7701bec52a97f7acfe8a29f938384c8" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", @@ -359,9 +359,9 @@ dependencies = [ [[package]] name = "aws-runtime" -version = "1.1.4" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb031bff99877c26c28895766f7bb8484a05e24547e370768d6cc9db514662aa" +checksum = "b13dc54b4b49f8288532334bba8f87386a40571c47c37b1304979b556dc613c8" dependencies = [ "aws-credential-types", "aws-sigv4", @@ -381,6 +381,29 @@ dependencies = [ "uuid", ] +[[package]] +name = "aws-sdk-iam" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8ae76026bfb1b80a6aed0bb400c1139cd9c0563e26bce1986cd021c6a968c7b" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-query", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-smithy-xml", + "aws-types", + "http 0.2.9", + "once_cell", + "regex-lite", + "tracing", +] + [[package]] name = "aws-sdk-s3" version = "1.14.0" @@ -502,9 +525,9 @@ dependencies = [ [[package]] name = "aws-sigv4" -version = "1.1.4" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c371c6b0ac54d4605eb6f016624fb5c7c2925d315fdf600ac1bf21b19d5f1742" +checksum = "11d6f29688a4be9895c0ba8bef861ad0c0dac5c15e9618b9b7a6c233990fc263" dependencies = [ "aws-credential-types", "aws-smithy-eventstream", @@ -517,7 +540,7 @@ dependencies = [ "hex", "hmac", "http 0.2.9", - "http 1.0.0", + "http 1.1.0", "once_cell", "p256", "percent-encoding", @@ -531,9 +554,9 @@ dependencies = [ [[package]] name = "aws-smithy-async" -version = "1.1.4" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72ee2d09cce0ef3ae526679b522835d63e75fb427aca5413cd371e490d52dcc6" +checksum = "d26ea8fa03025b2face2b3038a63525a10891e3d8829901d502e5384a0d8cd46" dependencies = [ "futures-util", "pin-project-lite", @@ -574,9 +597,9 @@ dependencies = [ [[package]] name = "aws-smithy-http" -version = "0.60.4" +version = "0.60.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dab56aea3cd9e1101a0a999447fb346afb680ab1406cebc44b32346e25b4117d" +checksum = "3f10fa66956f01540051b0aa7ad54574640f748f9839e843442d99b970d3aff9" dependencies = [ "aws-smithy-eventstream", "aws-smithy-runtime-api", @@ -595,18 +618,18 @@ dependencies = [ [[package]] name = "aws-smithy-json" -version = "0.60.4" +version = "0.60.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd3898ca6518f9215f62678870064398f00031912390efd03f1f6ef56d83aa8e" +checksum = "4683df9469ef09468dad3473d129960119a0d3593617542b7d52086c8486f2d6" dependencies = [ "aws-smithy-types", ] [[package]] name = "aws-smithy-query" -version = "0.60.4" +version = "0.60.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda4b1dfc9810e35fba8a620e900522cd1bd4f9578c446e82f49d1ce41d2e9f9" +checksum = "f2fbd61ceb3fe8a1cb7352e42689cec5335833cd9f94103a61e98f9bb61c64bb" dependencies = [ "aws-smithy-types", "urlencoding", @@ -614,9 +637,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.1.4" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fafdab38f40ad7816e7da5dec279400dd505160780083759f01441af1bbb10ea" +checksum = "ec81002d883e5a7fd2bb063d6fb51c4999eb55d404f4fff3dd878bf4733b9f01" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -639,14 +662,15 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.1.4" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c18276dd28852f34b3bf501f4f3719781f4999a51c7bff1a5c6dc8c4529adc29" +checksum = "9acb931e0adaf5132de878f1398d83f8677f90ba70f01f65ff87f6d7244be1c5" dependencies = [ "aws-smithy-async", "aws-smithy-types", "bytes", "http 0.2.9", + "http 1.1.0", "pin-project-lite", "tokio", "tracing", @@ -655,9 +679,9 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.1.4" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb3e134004170d3303718baa2a4eb4ca64ee0a1c0a7041dca31b38be0fb414f3" +checksum = "abe14dceea1e70101d38fbf2a99e6a34159477c0fb95e68e05c66bd7ae4c3729" dependencies = [ "base64-simd", "bytes", @@ -678,18 +702,18 @@ dependencies = [ [[package]] name = "aws-smithy-xml" -version = "0.60.4" +version = "0.60.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8604a11b25e9ecaf32f9aa56b9fe253c5e2f606a3477f0071e96d3155a5ed218" +checksum = "872c68cf019c0e4afc5de7753c4f7288ce4b71663212771bf5e4542eb9346ca9" dependencies = [ "xmlparser", ] [[package]] name = "aws-types" -version = "1.1.4" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "789bbe008e65636fe1b6dbbb374c40c8960d1232b96af5ff4aec349f9c4accf4" +checksum = "0dbf2f3da841a8930f159163175cf6a3d16ddde517c1b0fba7aa776822800f40" dependencies = [ "aws-credential-types", "aws-smithy-async", @@ -2396,9 +2420,9 @@ dependencies = [ [[package]] name = "http" -version = "1.0.0" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b32afd38673a8016f7c9ae69e5af41a58f81b1d31689040f2f1959594ce194ea" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" dependencies = [ "bytes", "fnv", @@ -2498,7 +2522,7 @@ dependencies = [ "hyper", "log", "rustls 0.21.9", - "rustls-native-certs", + "rustls-native-certs 0.6.2", "tokio", "tokio-rustls 0.24.0", ] @@ -4199,6 +4223,10 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "aws-config", + "aws-sdk-iam", + "aws-sigv4", + "aws-types", "base64 0.13.1", "bstr", "bytes", @@ -4216,6 +4244,7 @@ dependencies = [ "hex", "hmac", "hostname", + "http 1.1.0", "humantime", "hyper", "hyper-tungstenite", @@ -4431,9 +4460,9 @@ dependencies = [ [[package]] name = "redis" -version = "0.24.0" +version = "0.25.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c580d9cbbe1d1b479e8d67cf9daf6a62c957e6846048408b80b43ac3f6af84cd" +checksum = "71d64e978fd98a0e6b105d066ba4889a7301fca65aeac850a877d8797343feeb" dependencies = [ "async-trait", "bytes", @@ -4442,15 +4471,15 @@ dependencies = [ "itoa", "percent-encoding", "pin-project-lite", - "rustls 0.21.9", - "rustls-native-certs", - "rustls-pemfile 1.0.2", - "rustls-webpki 0.101.7", + "rustls 0.22.2", + "rustls-native-certs 0.7.0", + "rustls-pemfile 2.1.1", + "rustls-pki-types", "ryu", "sha1_smol", - "socket2 0.4.9", + "socket2 0.5.5", "tokio", - "tokio-rustls 0.24.0", + "tokio-rustls 0.25.0", "tokio-util", "url", ] @@ -4879,6 +4908,19 @@ dependencies = [ "security-framework", ] +[[package]] +name = "rustls-native-certs" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f1fb85efa936c42c6d5fc28d2629bb51e4b2f4b8a5211e297d599cc5a093792" +dependencies = [ + "openssl-probe", + "rustls-pemfile 2.1.1", + "rustls-pki-types", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "1.0.2" @@ -6146,7 +6188,7 @@ dependencies = [ "percent-encoding", "pin-project", "prost", - "rustls-native-certs", + "rustls-native-certs 0.6.2", "rustls-pemfile 1.0.2", "tokio", "tokio-rustls 0.24.0", @@ -7031,7 +7073,6 @@ dependencies = [ "aws-sigv4", "aws-smithy-async", "aws-smithy-http", - "aws-smithy-runtime-api", "aws-smithy-types", "axum", "base64 0.21.1", diff --git a/Cargo.toml b/Cargo.toml index 44e6ec9744..2741bd046b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,9 +53,12 @@ async-trait = "0.1" aws-config = { version = "1.1.4", default-features = false, features=["rustls"] } aws-sdk-s3 = "1.14" aws-sdk-secretsmanager = { version = "1.14.0" } +aws-sdk-iam = "1.15.0" aws-smithy-async = { version = "1.1.4", default-features = false, features=["rt-tokio"] } aws-smithy-types = "1.1.4" aws-credential-types = "1.1.4" +aws-sigv4 = { version = "1.2.0", features = ["sign-http"] } +aws-types = "1.1.7" axum = { version = "0.6.20", features = ["ws"] } base64 = "0.13.0" bincode = "1.3" @@ -88,6 +91,7 @@ hex = "0.4" hex-literal = "0.4" hmac = "0.12.1" hostname = "0.3.1" +http = {version = "1.1.0", features = ["std"]} http-types = { version = "2", default-features = false } humantime = "2.1" humantime-serde = "1.1.1" @@ -121,7 +125,7 @@ procfs = "0.14" prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency prost = "0.11" rand = "0.8" -redis = { version = "0.24.0", features = ["tokio-rustls-comp", "keep-alive"] } +redis = { version = "0.25.2", features = ["tokio-rustls-comp", "keep-alive"] } regex = "1.10.2" reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] } reqwest-tracing = { version = "0.4.7", features = ["opentelemetry_0_20"] } diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 93a1fe85db..3566d8b728 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -11,6 +11,10 @@ testing = [] [dependencies] anyhow.workspace = true async-trait.workspace = true +aws-config.workspace = true +aws-sdk-iam.workspace = true +aws-sigv4.workspace = true +aws-types.workspace = true base64.workspace = true bstr.workspace = true bytes = { workspace = true, features = ["serde"] } @@ -27,6 +31,7 @@ hashlink.workspace = true hex.workspace = true hmac.workspace = true hostname.workspace = true +http.workspace = true humantime.workspace = true hyper-tungstenite.workspace = true hyper.workspace = true diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index b3d4fc0411..d38439c2a0 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -1,3 +1,10 @@ +use aws_config::environment::EnvironmentVariableCredentialsProvider; +use aws_config::imds::credentials::ImdsCredentialsProvider; +use aws_config::meta::credentials::CredentialsProviderChain; +use aws_config::meta::region::RegionProviderChain; +use aws_config::profile::ProfileFileCredentialsProvider; +use aws_config::provider_config::ProviderConfig; +use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider; use futures::future::Either; use proxy::auth; use proxy::auth::backend::MaybeOwned; @@ -10,11 +17,14 @@ use proxy::config::ProjectInfoCacheOptions; use proxy::console; use proxy::context::parquet::ParquetUploadArgs; use proxy::http; +use proxy::metrics::NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT; use proxy::rate_limiter::EndpointRateLimiter; use proxy::rate_limiter::RateBucketInfo; use proxy::rate_limiter::RateLimiterConfig; +use proxy::redis::cancellation_publisher::RedisPublisherClient; +use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider; +use proxy::redis::elasticache; use proxy::redis::notifications; -use proxy::redis::publisher::RedisPublisherClient; use proxy::serverless::GlobalConnPoolOptions; use proxy::usage_metrics; @@ -150,9 +160,24 @@ struct ProxyCliArgs { /// disable ip check for http requests. If it is too time consuming, it could be turned off. #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] disable_ip_check_for_http: bool, - /// redis url for notifications. + /// redis url for notifications (if empty, redis_host:port will be used for both notifications and streaming connections) #[clap(long)] redis_notifications: Option, + /// redis host for streaming connections (might be different from the notifications host) + #[clap(long)] + redis_host: Option, + /// redis port for streaming connections (might be different from the notifications host) + #[clap(long)] + redis_port: Option, + /// redis cluster name, used in aws elasticache + #[clap(long)] + redis_cluster_name: Option, + /// redis user_id, used in aws elasticache + #[clap(long)] + redis_user_id: Option, + /// aws region to retrieve credentials + #[clap(long, default_value_t = String::new())] + aws_region: String, /// cache for `project_info` (use `size=0` to disable) #[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)] project_info_cache: String, @@ -216,6 +241,61 @@ async fn main() -> anyhow::Result<()> { let config = build_config(&args)?; info!("Authentication backend: {}", config.auth_backend); + info!("Using region: {}", config.aws_region); + + let region_provider = RegionProviderChain::default_provider().or_else(&*config.aws_region); // Replace with your Redis region if needed + let provider_conf = + ProviderConfig::without_region().with_region(region_provider.region().await); + let aws_credentials_provider = { + // uses "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY" + CredentialsProviderChain::first_try("env", EnvironmentVariableCredentialsProvider::new()) + // uses "AWS_PROFILE" / `aws sso login --profile ` + .or_else( + "profile-sso", + ProfileFileCredentialsProvider::builder() + .configure(&provider_conf) + .build(), + ) + // uses "AWS_WEB_IDENTITY_TOKEN_FILE", "AWS_ROLE_ARN", "AWS_ROLE_SESSION_NAME" + // needed to access remote extensions bucket + .or_else( + "token", + WebIdentityTokenCredentialsProvider::builder() + .configure(&provider_conf) + .build(), + ) + // uses imds v2 + .or_else("imds", ImdsCredentialsProvider::builder().build()) + }; + let elasticache_credentials_provider = Arc::new(elasticache::CredentialsProvider::new( + elasticache::AWSIRSAConfig::new( + config.aws_region.clone(), + args.redis_cluster_name, + args.redis_user_id, + ), + aws_credentials_provider, + )); + let redis_notifications_client = + match (args.redis_notifications, (args.redis_host, args.redis_port)) { + (Some(url), _) => { + info!("Starting redis notifications listener ({url})"); + Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url)) + } + (None, (Some(host), Some(port))) => Some( + ConnectionWithCredentialsProvider::new_with_credentials_provider( + host, + port, + elasticache_credentials_provider.clone(), + ), + ), + (None, (None, None)) => { + warn!("Redis is disabled"); + None + } + _ => { + bail!("redis-host and redis-port must be specified together"); + } + }; // Check that we can bind to address before further initialization let http_address: SocketAddr = args.http.parse()?; @@ -233,17 +313,22 @@ async fn main() -> anyhow::Result<()> { let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit)); let cancel_map = CancelMap::default(); - let redis_publisher = match &args.redis_notifications { - Some(url) => Some(Arc::new(Mutex::new(RedisPublisherClient::new( - url, + + // let redis_notifications_client = redis_notifications_client.map(|x| Box::leak(Box::new(x))); + let redis_publisher = match &redis_notifications_client { + Some(redis_publisher) => Some(Arc::new(Mutex::new(RedisPublisherClient::new( + redis_publisher.clone(), args.region.clone(), &config.redis_rps_limit, )?))), None => None, }; - let cancellation_handler = Arc::new(CancellationHandler::new( + let cancellation_handler = Arc::new(CancellationHandler::< + Option>>, + >::new( cancel_map.clone(), redis_publisher, + NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT, )); // client facing tasks. these will exit on error or on cancellation @@ -290,17 +375,16 @@ async fn main() -> anyhow::Result<()> { if let auth::BackendType::Console(api, _) = &config.auth_backend { if let proxy::console::provider::ConsoleBackend::Console(api) = &**api { - let cache = api.caches.project_info.clone(); - if let Some(url) = args.redis_notifications { - info!("Starting redis notifications listener ({url})"); + if let Some(redis_notifications_client) = redis_notifications_client { + let cache = api.caches.project_info.clone(); maintenance_tasks.spawn(notifications::task_main( - url.to_owned(), + redis_notifications_client.clone(), cache.clone(), cancel_map.clone(), args.region.clone(), )); + maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); } - maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); } } @@ -445,8 +529,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { endpoint_rps_limit, redis_rps_limit, handshake_timeout: args.handshake_timeout, - // TODO: add this argument region: args.region.clone(), + aws_region: args.aws_region.clone(), })); Ok(config) diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index c9607909b3..8054f33b6c 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -1,4 +1,3 @@ -use async_trait::async_trait; use dashmap::DashMap; use pq_proto::CancelKeyData; use std::{net::SocketAddr, sync::Arc}; @@ -10,18 +9,26 @@ use tracing::info; use uuid::Uuid; use crate::{ - error::ReportableError, metrics::NUM_CANCELLATION_REQUESTS, - redis::publisher::RedisPublisherClient, + error::ReportableError, + metrics::NUM_CANCELLATION_REQUESTS, + redis::cancellation_publisher::{ + CancellationPublisher, CancellationPublisherMut, RedisPublisherClient, + }, }; pub type CancelMap = Arc>>; +pub type CancellationHandlerMain = CancellationHandler>>>; +pub type CancellationHandlerMainInternal = Option>>; /// Enables serving `CancelRequest`s. /// -/// If there is a `RedisPublisherClient` available, it will be used to publish the cancellation key to other proxy instances. -pub struct CancellationHandler { +/// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances. +pub struct CancellationHandler

{ map: CancelMap, - redis_client: Option>>, + client: P, + /// This field used for the monitoring purposes. + /// Represents the source of the cancellation request. + from: &'static str, } #[derive(Debug, Error)] @@ -44,49 +51,9 @@ impl ReportableError for CancelError { } } -impl CancellationHandler { - pub fn new(map: CancelMap, redis_client: Option>>) -> Self { - Self { map, redis_client } - } - /// Cancel a running query for the corresponding connection. - pub async fn cancel_session( - &self, - key: CancelKeyData, - session_id: Uuid, - ) -> Result<(), CancelError> { - let from = "from_client"; - // NB: we should immediately release the lock after cloning the token. - let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else { - tracing::warn!("query cancellation key not found: {key}"); - if let Some(redis_client) = &self.redis_client { - NUM_CANCELLATION_REQUESTS - .with_label_values(&[from, "not_found"]) - .inc(); - info!("publishing cancellation key to Redis"); - match redis_client.lock().await.try_publish(key, session_id).await { - Ok(()) => { - info!("cancellation key successfuly published to Redis"); - } - Err(e) => { - tracing::error!("failed to publish a message: {e}"); - return Err(CancelError::IO(std::io::Error::new( - std::io::ErrorKind::Other, - e.to_string(), - ))); - } - } - } - return Ok(()); - }; - NUM_CANCELLATION_REQUESTS - .with_label_values(&[from, "found"]) - .inc(); - info!("cancelling query per user's request using key {key}"); - cancel_closure.try_cancel_query().await - } - +impl CancellationHandler

{ /// Run async action within an ephemeral session identified by [`CancelKeyData`]. - pub fn get_session(self: Arc) -> Session { + pub fn get_session(self: Arc) -> Session

{ // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't // expose it and we don't want to do another roundtrip to query // for it. The client will be able to notice that this is not the @@ -112,9 +79,39 @@ impl CancellationHandler { cancellation_handler: self, } } + /// Try to cancel a running query for the corresponding connection. + /// If the cancellation key is not found, it will be published to Redis. + pub async fn cancel_session( + &self, + key: CancelKeyData, + session_id: Uuid, + ) -> Result<(), CancelError> { + // NB: we should immediately release the lock after cloning the token. + let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else { + tracing::warn!("query cancellation key not found: {key}"); + NUM_CANCELLATION_REQUESTS + .with_label_values(&[self.from, "not_found"]) + .inc(); + match self.client.try_publish(key, session_id).await { + Ok(()) => {} // do nothing + Err(e) => { + return Err(CancelError::IO(std::io::Error::new( + std::io::ErrorKind::Other, + e.to_string(), + ))); + } + } + return Ok(()); + }; + NUM_CANCELLATION_REQUESTS + .with_label_values(&[self.from, "found"]) + .inc(); + info!("cancelling query per user's request using key {key}"); + cancel_closure.try_cancel_query().await + } #[cfg(test)] - fn contains(&self, session: &Session) -> bool { + fn contains(&self, session: &Session

) -> bool { self.map.contains_key(&session.key) } @@ -124,31 +121,19 @@ impl CancellationHandler { } } -#[async_trait] -pub trait NotificationsCancellationHandler { - async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError>; +impl CancellationHandler<()> { + pub fn new(map: CancelMap, from: &'static str) -> Self { + Self { + map, + client: (), + from, + } + } } -#[async_trait] -impl NotificationsCancellationHandler for CancellationHandler { - async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError> { - let from = "from_redis"; - let cancel_closure = self.map.get(&key).and_then(|x| x.clone()); - match cancel_closure { - Some(cancel_closure) => { - NUM_CANCELLATION_REQUESTS - .with_label_values(&[from, "found"]) - .inc(); - cancel_closure.try_cancel_query().await - } - None => { - NUM_CANCELLATION_REQUESTS - .with_label_values(&[from, "not_found"]) - .inc(); - tracing::warn!("query cancellation key not found: {key}"); - Ok(()) - } - } +impl CancellationHandler>>> { + pub fn new(map: CancelMap, client: Option>>, from: &'static str) -> Self { + Self { map, client, from } } } @@ -178,14 +163,14 @@ impl CancelClosure { } /// Helper for registering query cancellation tokens. -pub struct Session { +pub struct Session

{ /// The user-facing key identifying this session. key: CancelKeyData, /// The [`CancelMap`] this session belongs to. - cancellation_handler: Arc, + cancellation_handler: Arc>, } -impl Session { +impl

Session

{ /// Store the cancel token for the given session. /// This enables query cancellation in `crate::proxy::prepare_client_connection`. pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData { @@ -198,7 +183,7 @@ impl Session { } } -impl Drop for Session { +impl

Drop for Session

{ fn drop(&mut self) { self.cancellation_handler.map.remove(&self.key); info!("dropped query cancellation key {}", &self.key); @@ -207,14 +192,16 @@ impl Drop for Session { #[cfg(test)] mod tests { + use crate::metrics::NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS; + use super::*; #[tokio::test] async fn check_session_drop() -> anyhow::Result<()> { - let cancellation_handler = Arc::new(CancellationHandler { - map: CancelMap::default(), - redis_client: None, - }); + let cancellation_handler = Arc::new(CancellationHandler::<()>::new( + CancelMap::default(), + NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS, + )); let session = cancellation_handler.clone().get_session(); assert!(cancellation_handler.contains(&session)); diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 437ec9f401..45f8d76144 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -28,6 +28,7 @@ pub struct ProxyConfig { pub redis_rps_limit: Vec, pub region: String, pub handshake_timeout: Duration, + pub aws_region: String, } #[derive(Debug)] diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 02ebcd6aaa..eed45e421b 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -161,6 +161,9 @@ pub static NUM_CANCELLATION_REQUESTS: Lazy = Lazy::new(|| { .unwrap() }); +pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT: &str = "from_client"; +pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS: &str = "from_redis"; + pub enum Waiting { Cplane, Client, diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index ab5bf5d494..843bfc08cf 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -10,7 +10,7 @@ pub mod wake_compute; use crate::{ auth, - cancellation::{self, CancellationHandler}, + cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal}, compute, config::{ProxyConfig, TlsConfig}, context::RequestMonitoring, @@ -62,7 +62,7 @@ pub async fn task_main( listener: tokio::net::TcpListener, cancellation_token: CancellationToken, endpoint_rate_limiter: Arc, - cancellation_handler: Arc, + cancellation_handler: Arc, ) -> anyhow::Result<()> { scopeguard::defer! { info!("proxy has shut down"); @@ -233,12 +233,12 @@ impl ReportableError for ClientRequestError { pub async fn handle_client( config: &'static ProxyConfig, ctx: &mut RequestMonitoring, - cancellation_handler: Arc, + cancellation_handler: Arc, stream: S, mode: ClientMode, endpoint_rate_limiter: Arc, conn_gauge: IntCounterPairGuard, -) -> Result>, ClientRequestError> { +) -> Result>, ClientRequestError> { info!("handling interactive connection from client"); let proto = ctx.protocol; @@ -338,9 +338,9 @@ pub async fn handle_client( /// Finish client connection initialization: confirm auth success, send params, etc. #[tracing::instrument(skip_all)] -async fn prepare_client_connection( +async fn prepare_client_connection

( node: &compute::PostgresConnection, - session: &cancellation::Session, + session: &cancellation::Session

, stream: &mut PqStream, ) -> Result<(), std::io::Error> { // Register compute's query cancellation token and produce a new, unique one. diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index b2f682fd2f..f6d4314391 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -55,17 +55,17 @@ pub async fn proxy_pass( Ok(()) } -pub struct ProxyPassthrough { +pub struct ProxyPassthrough { pub client: Stream, pub compute: PostgresConnection, pub aux: MetricsAuxInfo, pub req: IntCounterPairGuard, pub conn: IntCounterPairGuard, - pub cancel: cancellation::Session, + pub cancel: cancellation::Session

, } -impl ProxyPassthrough { +impl ProxyPassthrough { pub async fn proxy_pass(self) -> anyhow::Result<()> { let res = proxy_pass(self.client, self.compute.stream, self.aux).await; self.compute.cancel_closure.try_cancel_query().await?; diff --git a/proxy/src/redis.rs b/proxy/src/redis.rs index 35d6db074e..a322f0368c 100644 --- a/proxy/src/redis.rs +++ b/proxy/src/redis.rs @@ -1,2 +1,4 @@ +pub mod cancellation_publisher; +pub mod connection_with_credentials_provider; +pub mod elasticache; pub mod notifications; -pub mod publisher; diff --git a/proxy/src/redis/cancellation_publisher.rs b/proxy/src/redis/cancellation_publisher.rs new file mode 100644 index 0000000000..d9efc3561b --- /dev/null +++ b/proxy/src/redis/cancellation_publisher.rs @@ -0,0 +1,167 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use pq_proto::CancelKeyData; +use redis::AsyncCommands; +use tokio::sync::Mutex; +use uuid::Uuid; + +use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter}; + +use super::{ + connection_with_credentials_provider::ConnectionWithCredentialsProvider, + notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME}, +}; + +#[async_trait] +pub trait CancellationPublisherMut: Send + Sync + 'static { + async fn try_publish( + &mut self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()>; +} + +#[async_trait] +pub trait CancellationPublisher: Send + Sync + 'static { + async fn try_publish( + &self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()>; +} + +#[async_trait] +impl CancellationPublisherMut for () { + async fn try_publish( + &mut self, + _cancel_key_data: CancelKeyData, + _session_id: Uuid, + ) -> anyhow::Result<()> { + Ok(()) + } +} + +#[async_trait] +impl CancellationPublisher for P { + async fn try_publish( + &self, + _cancel_key_data: CancelKeyData, + _session_id: Uuid, + ) -> anyhow::Result<()> { + self.try_publish(_cancel_key_data, _session_id).await + } +} + +#[async_trait] +impl CancellationPublisher for Option

{ + async fn try_publish( + &self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()> { + if let Some(p) = self { + p.try_publish(cancel_key_data, session_id).await + } else { + Ok(()) + } + } +} + +#[async_trait] +impl CancellationPublisher for Arc> { + async fn try_publish( + &self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()> { + self.lock() + .await + .try_publish(cancel_key_data, session_id) + .await + } +} + +pub struct RedisPublisherClient { + client: ConnectionWithCredentialsProvider, + region_id: String, + limiter: RedisRateLimiter, +} + +impl RedisPublisherClient { + pub fn new( + client: ConnectionWithCredentialsProvider, + region_id: String, + info: &'static [RateBucketInfo], + ) -> anyhow::Result { + Ok(Self { + client, + region_id, + limiter: RedisRateLimiter::new(info), + }) + } + + async fn publish( + &mut self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()> { + let payload = serde_json::to_string(&Notification::Cancel(CancelSession { + region_id: Some(self.region_id.clone()), + cancel_key_data, + session_id, + }))?; + self.client.publish(PROXY_CHANNEL_NAME, payload).await?; + Ok(()) + } + pub async fn try_connect(&mut self) -> anyhow::Result<()> { + match self.client.connect().await { + Ok(()) => {} + Err(e) => { + tracing::error!("failed to connect to redis: {e}"); + return Err(e); + } + } + Ok(()) + } + async fn try_publish_internal( + &mut self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()> { + if !self.limiter.check() { + tracing::info!("Rate limit exceeded. Skipping cancellation message"); + return Err(anyhow::anyhow!("Rate limit exceeded")); + } + match self.publish(cancel_key_data, session_id).await { + Ok(()) => return Ok(()), + Err(e) => { + tracing::error!("failed to publish a message: {e}"); + } + } + tracing::info!("Publisher is disconnected. Reconnectiong..."); + self.try_connect().await?; + self.publish(cancel_key_data, session_id).await + } +} + +#[async_trait] +impl CancellationPublisherMut for RedisPublisherClient { + async fn try_publish( + &mut self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()> { + tracing::info!("publishing cancellation key to Redis"); + match self.try_publish_internal(cancel_key_data, session_id).await { + Ok(()) => { + tracing::info!("cancellation key successfuly published to Redis"); + Ok(()) + } + Err(e) => { + tracing::error!("failed to publish a message: {e}"); + Err(e) + } + } + } +} diff --git a/proxy/src/redis/connection_with_credentials_provider.rs b/proxy/src/redis/connection_with_credentials_provider.rs new file mode 100644 index 0000000000..d183abb53a --- /dev/null +++ b/proxy/src/redis/connection_with_credentials_provider.rs @@ -0,0 +1,225 @@ +use std::{sync::Arc, time::Duration}; + +use futures::FutureExt; +use redis::{ + aio::{ConnectionLike, MultiplexedConnection}, + ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult, +}; +use tokio::task::JoinHandle; +use tracing::{error, info}; + +use super::elasticache::CredentialsProvider; + +enum Credentials { + Static(ConnectionInfo), + Dynamic(Arc, redis::ConnectionAddr), +} + +impl Clone for Credentials { + fn clone(&self) -> Self { + match self { + Credentials::Static(info) => Credentials::Static(info.clone()), + Credentials::Dynamic(provider, addr) => { + Credentials::Dynamic(Arc::clone(provider), addr.clone()) + } + } + } +} + +/// A wrapper around `redis::MultiplexedConnection` that automatically refreshes the token. +/// Provides PubSub connection without credentials refresh. +pub struct ConnectionWithCredentialsProvider { + credentials: Credentials, + con: Option, + refresh_token_task: Option>, + mutex: tokio::sync::Mutex<()>, +} + +impl Clone for ConnectionWithCredentialsProvider { + fn clone(&self) -> Self { + Self { + credentials: self.credentials.clone(), + con: None, + refresh_token_task: None, + mutex: tokio::sync::Mutex::new(()), + } + } +} + +impl ConnectionWithCredentialsProvider { + pub fn new_with_credentials_provider( + host: String, + port: u16, + credentials_provider: Arc, + ) -> Self { + Self { + credentials: Credentials::Dynamic( + credentials_provider, + redis::ConnectionAddr::TcpTls { + host, + port, + insecure: false, + tls_params: None, + }, + ), + con: None, + refresh_token_task: None, + mutex: tokio::sync::Mutex::new(()), + } + } + + pub fn new_with_static_credentials(params: T) -> Self { + Self { + credentials: Credentials::Static(params.into_connection_info().unwrap()), + con: None, + refresh_token_task: None, + mutex: tokio::sync::Mutex::new(()), + } + } + + pub async fn connect(&mut self) -> anyhow::Result<()> { + let _guard = self.mutex.lock().await; + if let Some(con) = self.con.as_mut() { + match redis::cmd("PING").query_async(con).await { + Ok(()) => { + return Ok(()); + } + Err(e) => { + error!("Error during PING: {e:?}"); + } + } + } else { + info!("Connection is not established"); + } + info!("Establishing a new connection..."); + self.con = None; + if let Some(f) = self.refresh_token_task.take() { + f.abort() + } + let con = self + .get_client() + .await? + .get_multiplexed_tokio_connection() + .await?; + if let Credentials::Dynamic(credentials_provider, _) = &self.credentials { + let credentials_provider = credentials_provider.clone(); + let con2 = con.clone(); + let f = tokio::spawn(async move { + let _ = Self::keep_connection(con2, credentials_provider).await; + }); + self.refresh_token_task = Some(f); + } + self.con = Some(con); + Ok(()) + } + + async fn get_connection_info(&self) -> anyhow::Result { + match &self.credentials { + Credentials::Static(info) => Ok(info.clone()), + Credentials::Dynamic(provider, addr) => { + let (username, password) = provider.provide_credentials().await?; + Ok(ConnectionInfo { + addr: addr.clone(), + redis: RedisConnectionInfo { + db: 0, + username: Some(username), + password: Some(password.clone()), + }, + }) + } + } + } + + async fn get_client(&self) -> anyhow::Result { + let client = redis::Client::open(self.get_connection_info().await?)?; + Ok(client) + } + + // PubSub does not support credentials refresh. + // Requires manual reconnection every 12h. + pub async fn get_async_pubsub(&self) -> anyhow::Result { + Ok(self.get_client().await?.get_async_pubsub().await?) + } + + // The connection lives for 12h. + // It can be prolonged with sending `AUTH` commands with the refreshed token. + // https://docs.aws.amazon.com/AmazonElastiCache/latest/red-ug/auth-iam.html#auth-iam-limits + async fn keep_connection( + mut con: MultiplexedConnection, + credentials_provider: Arc, + ) -> anyhow::Result<()> { + loop { + // The connection lives for 12h, for the sanity check we refresh it every hour. + tokio::time::sleep(Duration::from_secs(60 * 60)).await; + match Self::refresh_token(&mut con, credentials_provider.clone()).await { + Ok(()) => { + info!("Token refreshed"); + } + Err(e) => { + error!("Error during token refresh: {e:?}"); + } + } + } + } + async fn refresh_token( + con: &mut MultiplexedConnection, + credentials_provider: Arc, + ) -> anyhow::Result<()> { + let (user, password) = credentials_provider.provide_credentials().await?; + redis::cmd("AUTH") + .arg(user) + .arg(password) + .query_async(con) + .await?; + Ok(()) + } + /// Sends an already encoded (packed) command into the TCP socket and + /// reads the single response from it. + pub async fn send_packed_command(&mut self, cmd: &redis::Cmd) -> RedisResult { + // Clone connection to avoid having to lock the ArcSwap in write mode + let con = self.con.as_mut().ok_or(redis::RedisError::from(( + redis::ErrorKind::IoError, + "Connection not established", + )))?; + con.send_packed_command(cmd).await + } + + /// Sends multiple already encoded (packed) command into the TCP socket + /// and reads `count` responses from it. This is used to implement + /// pipelining. + pub async fn send_packed_commands( + &mut self, + cmd: &redis::Pipeline, + offset: usize, + count: usize, + ) -> RedisResult> { + // Clone shared connection future to avoid having to lock the ArcSwap in write mode + let con = self.con.as_mut().ok_or(redis::RedisError::from(( + redis::ErrorKind::IoError, + "Connection not established", + )))?; + con.send_packed_commands(cmd, offset, count).await + } +} + +impl ConnectionLike for ConnectionWithCredentialsProvider { + fn req_packed_command<'a>( + &'a mut self, + cmd: &'a redis::Cmd, + ) -> redis::RedisFuture<'a, redis::Value> { + (async move { self.send_packed_command(cmd).await }).boxed() + } + + fn req_packed_commands<'a>( + &'a mut self, + cmd: &'a redis::Pipeline, + offset: usize, + count: usize, + ) -> redis::RedisFuture<'a, Vec> { + (async move { self.send_packed_commands(cmd, offset, count).await }).boxed() + } + + fn get_db(&self) -> i64 { + 0 + } +} diff --git a/proxy/src/redis/elasticache.rs b/proxy/src/redis/elasticache.rs new file mode 100644 index 0000000000..eded8250af --- /dev/null +++ b/proxy/src/redis/elasticache.rs @@ -0,0 +1,110 @@ +use std::time::{Duration, SystemTime}; + +use aws_config::meta::credentials::CredentialsProviderChain; +use aws_sdk_iam::config::ProvideCredentials; +use aws_sigv4::http_request::{ + self, SignableBody, SignableRequest, SignatureLocation, SigningSettings, +}; +use tracing::info; + +#[derive(Debug)] +pub struct AWSIRSAConfig { + region: String, + service_name: String, + cluster_name: String, + user_id: String, + token_ttl: Duration, + action: String, +} + +impl AWSIRSAConfig { + pub fn new(region: String, cluster_name: Option, user_id: Option) -> Self { + AWSIRSAConfig { + region, + service_name: "elasticache".to_string(), + cluster_name: cluster_name.unwrap_or_default(), + user_id: user_id.unwrap_or_default(), + // "The IAM authentication token is valid for 15 minutes" + // https://docs.aws.amazon.com/memorydb/latest/devguide/auth-iam.html#auth-iam-limits + token_ttl: Duration::from_secs(15 * 60), + action: "connect".to_string(), + } + } +} + +/// Credentials provider for AWS elasticache authentication. +/// +/// Official documentation: +/// +/// +/// Useful resources: +/// +pub struct CredentialsProvider { + config: AWSIRSAConfig, + credentials_provider: CredentialsProviderChain, +} + +impl CredentialsProvider { + pub fn new(config: AWSIRSAConfig, credentials_provider: CredentialsProviderChain) -> Self { + CredentialsProvider { + config, + credentials_provider, + } + } + pub async fn provide_credentials(&self) -> anyhow::Result<(String, String)> { + let aws_credentials = self + .credentials_provider + .provide_credentials() + .await? + .into(); + info!("AWS credentials successfully obtained"); + info!("Connecting to Redis with configuration: {:?}", self.config); + let mut settings = SigningSettings::default(); + settings.signature_location = SignatureLocation::QueryParams; + settings.expires_in = Some(self.config.token_ttl); + let signing_params = aws_sigv4::sign::v4::SigningParams::builder() + .identity(&aws_credentials) + .region(&self.config.region) + .name(&self.config.service_name) + .time(SystemTime::now()) + .settings(settings) + .build()? + .into(); + let auth_params = [ + ("Action", &self.config.action), + ("User", &self.config.user_id), + ]; + let auth_params = url::form_urlencoded::Serializer::new(String::new()) + .extend_pairs(auth_params) + .finish(); + let auth_uri = http::Uri::builder() + .scheme("http") + .authority(self.config.cluster_name.as_bytes()) + .path_and_query(format!("/?{auth_params}")) + .build()?; + info!("{}", auth_uri); + + // Convert the HTTP request into a signable request + let signable_request = SignableRequest::new( + "GET", + auth_uri.to_string(), + std::iter::empty(), + SignableBody::Bytes(&[]), + )?; + + // Sign and then apply the signature to the request + let (si, _) = http_request::sign(signable_request, &signing_params)?.into_parts(); + let mut signable_request = http::Request::builder() + .method("GET") + .uri(auth_uri) + .body(())?; + si.apply_to_request_http1x(&mut signable_request); + Ok(( + self.config.user_id.clone(), + signable_request + .uri() + .to_string() + .replacen("http://", "", 1), + )) + } +} diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 6ae848c0d2..8b7e3e3419 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -6,11 +6,12 @@ use redis::aio::PubSub; use serde::{Deserialize, Serialize}; use uuid::Uuid; +use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::{ cache::project_info::ProjectInfoCache, - cancellation::{CancelMap, CancellationHandler, NotificationsCancellationHandler}, + cancellation::{CancelMap, CancellationHandler}, intern::{ProjectIdInt, RoleNameInt}, - metrics::REDIS_BROKEN_MESSAGES, + metrics::{NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS, REDIS_BROKEN_MESSAGES}, }; const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; @@ -18,23 +19,13 @@ pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates"; const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20); const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20); -struct RedisConsumerClient { - client: redis::Client, -} - -impl RedisConsumerClient { - pub fn new(url: &str) -> anyhow::Result { - let client = redis::Client::open(url)?; - Ok(Self { client }) - } - async fn try_connect(&self) -> anyhow::Result { - let mut conn = self.client.get_async_connection().await?.into_pubsub(); - tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`"); - conn.subscribe(CPLANE_CHANNEL_NAME).await?; - tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`"); - conn.subscribe(PROXY_CHANNEL_NAME).await?; - Ok(conn) - } +async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result { + let mut conn = client.get_async_pubsub().await?; + tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`"); + conn.subscribe(CPLANE_CHANNEL_NAME).await?; + tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`"); + conn.subscribe(PROXY_CHANNEL_NAME).await?; + Ok(conn) } #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] @@ -80,21 +71,18 @@ where serde_json::from_str(&s).map_err(::custom) } -struct MessageHandler< - C: ProjectInfoCache + Send + Sync + 'static, - H: NotificationsCancellationHandler + Send + Sync + 'static, -> { +struct MessageHandler { cache: Arc, - cancellation_handler: Arc, + cancellation_handler: Arc>, region_id: String, } -impl< - C: ProjectInfoCache + Send + Sync + 'static, - H: NotificationsCancellationHandler + Send + Sync + 'static, - > MessageHandler -{ - pub fn new(cache: Arc, cancellation_handler: Arc, region_id: String) -> Self { +impl MessageHandler { + pub fn new( + cache: Arc, + cancellation_handler: Arc>, + region_id: String, + ) -> Self { Self { cache, cancellation_handler, @@ -139,7 +127,7 @@ impl< // This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message. match self .cancellation_handler - .cancel_session_no_publish(cancel_session.cancel_key_data) + .cancel_session(cancel_session.cancel_key_data, uuid::Uuid::nil()) .await { Ok(()) => {} @@ -182,7 +170,7 @@ fn invalidate_cache(cache: Arc, msg: Notification) { /// Handle console's invalidation messages. #[tracing::instrument(name = "console_notifications", skip_all)] pub async fn task_main( - url: String, + redis: ConnectionWithCredentialsProvider, cache: Arc, cancel_map: CancelMap, region_id: String, @@ -193,13 +181,15 @@ where cache.enable_ttl(); let handler = MessageHandler::new( cache, - Arc::new(CancellationHandler::new(cancel_map, None)), + Arc::new(CancellationHandler::<()>::new( + cancel_map, + NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS, + )), region_id, ); loop { - let redis = RedisConsumerClient::new(&url)?; - let conn = match redis.try_connect().await { + let mut conn = match try_connect(&redis).await { Ok(conn) => { handler.disable_ttl(); conn @@ -212,7 +202,7 @@ where continue; } }; - let mut stream = conn.into_on_message(); + let mut stream = conn.on_message(); while let Some(msg) = stream.next().await { match handler.handle_message(msg).await { Ok(()) => {} diff --git a/proxy/src/redis/publisher.rs b/proxy/src/redis/publisher.rs deleted file mode 100644 index f85593afdd..0000000000 --- a/proxy/src/redis/publisher.rs +++ /dev/null @@ -1,80 +0,0 @@ -use pq_proto::CancelKeyData; -use redis::AsyncCommands; -use uuid::Uuid; - -use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter}; - -use super::notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME}; - -pub struct RedisPublisherClient { - client: redis::Client, - publisher: Option, - region_id: String, - limiter: RedisRateLimiter, -} - -impl RedisPublisherClient { - pub fn new( - url: &str, - region_id: String, - info: &'static [RateBucketInfo], - ) -> anyhow::Result { - let client = redis::Client::open(url)?; - Ok(Self { - client, - publisher: None, - region_id, - limiter: RedisRateLimiter::new(info), - }) - } - pub async fn try_publish( - &mut self, - cancel_key_data: CancelKeyData, - session_id: Uuid, - ) -> anyhow::Result<()> { - if !self.limiter.check() { - tracing::info!("Rate limit exceeded. Skipping cancellation message"); - return Err(anyhow::anyhow!("Rate limit exceeded")); - } - match self.publish(cancel_key_data, session_id).await { - Ok(()) => return Ok(()), - Err(e) => { - tracing::error!("failed to publish a message: {e}"); - self.publisher = None; - } - } - tracing::info!("Publisher is disconnected. Reconnectiong..."); - self.try_connect().await?; - self.publish(cancel_key_data, session_id).await - } - - async fn publish( - &mut self, - cancel_key_data: CancelKeyData, - session_id: Uuid, - ) -> anyhow::Result<()> { - let conn = self - .publisher - .as_mut() - .ok_or_else(|| anyhow::anyhow!("not connected"))?; - let payload = serde_json::to_string(&Notification::Cancel(CancelSession { - region_id: Some(self.region_id.clone()), - cancel_key_data, - session_id, - }))?; - conn.publish(PROXY_CHANNEL_NAME, payload).await?; - Ok(()) - } - pub async fn try_connect(&mut self) -> anyhow::Result<()> { - match self.client.get_async_connection().await { - Ok(conn) => { - self.publisher = Some(conn); - } - Err(e) => { - tracing::error!("failed to connect to redis: {e}"); - return Err(e.into()); - } - } - Ok(()) - } -} diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index be9f90acde..a2010fd613 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -21,11 +21,12 @@ pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; use tokio_util::task::TaskTracker; use tracing::instrument::Instrumented; +use crate::cancellation::CancellationHandlerMain; +use crate::config::ProxyConfig; use crate::context::RequestMonitoring; use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard}; use crate::rate_limiter::EndpointRateLimiter; use crate::serverless::backend::PoolingBackend; -use crate::{cancellation::CancellationHandler, config::ProxyConfig}; use hyper::{ server::conn::{AddrIncoming, AddrStream}, Body, Method, Request, Response, @@ -47,7 +48,7 @@ pub async fn task_main( ws_listener: TcpListener, cancellation_token: CancellationToken, endpoint_rate_limiter: Arc, - cancellation_handler: Arc, + cancellation_handler: Arc, ) -> anyhow::Result<()> { scopeguard::defer! { info!("websocket server has shut down"); @@ -237,7 +238,7 @@ async fn request_handler( config: &'static ProxyConfig, backend: Arc, ws_connections: TaskTracker, - cancellation_handler: Arc, + cancellation_handler: Arc, peer_addr: IpAddr, endpoint_rate_limiter: Arc, // used to cancel in-flight HTTP requests. not used to cancel websockets diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index a72ede6d0a..ada6c974f4 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -1,5 +1,5 @@ use crate::{ - cancellation::CancellationHandler, + cancellation::CancellationHandlerMain, config::ProxyConfig, context::RequestMonitoring, error::{io_error, ReportableError}, @@ -134,7 +134,7 @@ pub async fn serve_websocket( config: &'static ProxyConfig, mut ctx: RequestMonitoring, websocket: HyperWebsocket, - cancellation_handler: Arc, + cancellation_handler: Arc, hostname: Option, endpoint_rate_limiter: Arc, ) -> anyhow::Result<()> { diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 152c452dd4..7b8228a082 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -19,8 +19,7 @@ aws-runtime = { version = "1", default-features = false, features = ["event-stre aws-sigv4 = { version = "1", features = ["http0-compat", "sign-eventstream", "sigv4a"] } aws-smithy-async = { version = "1", default-features = false, features = ["rt-tokio"] } aws-smithy-http = { version = "0.60", default-features = false, features = ["event-stream"] } -aws-smithy-runtime-api = { version = "1", features = ["client", "http-02x", "http-auth"] } -aws-smithy-types = { version = "1", default-features = false, features = ["byte-stream-poll-next", "http-body-0-4-x", "rt-tokio"] } +aws-smithy-types = { version = "1", default-features = false, features = ["byte-stream-poll-next", "http-body-0-4-x", "rt-tokio", "test-util"] } axum = { version = "0.6", features = ["ws"] } base64 = { version = "0.21", features = ["alloc"] } base64ct = { version = "1", default-features = false, features = ["std"] }