diff --git a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs
index e3be454713..423ecf821e 100644
--- a/proxy/src/binary/local_proxy.rs
+++ b/proxy/src/binary/local_proxy.rs
@@ -279,7 +279,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
         },
         proxy_protocol_v2: config::ProxyProtocolV2::Rejected,
         handshake_timeout: Duration::from_secs(10),
-        region: "local".into(),
         wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
         connect_compute_locks,
         connect_to_compute: compute_config,
diff --git a/proxy/src/binary/pg_sni_router.rs b/proxy/src/binary/pg_sni_router.rs
index 481bd8501c..070c73cdcf 100644
--- a/proxy/src/binary/pg_sni_router.rs
+++ b/proxy/src/binary/pg_sni_router.rs
@@ -236,7 +236,6 @@ pub(super) async fn task_main(
                     extra: None,
                 },
                 crate::metrics::Protocol::SniRouter,
-                "sni",
             );
             handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
         }
diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs
index 9215dbf73f..9ead05d492 100644
--- a/proxy/src/binary/proxy.rs
+++ b/proxy/src/binary/proxy.rs
@@ -123,12 +123,6 @@ struct ProxyCliArgs {
     /// timeout for the TLS handshake
     #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
    handshake_timeout: tokio::time::Duration,
-    /// http endpoint to receive periodic metric updates
-    #[clap(long)]
-    metric_collection_endpoint: Option<String>,
-    /// how often metrics should be sent to a collection endpoint
-    #[clap(long)]
-    metric_collection_interval: Option<String>,
     /// cache for `wake_compute` api method (use `size=0` to disable)
     #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
     wake_compute_cache: String,
@@ -155,40 +149,31 @@ struct ProxyCliArgs {
     /// Wake compute rate limiter max number of requests per second.
     #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
     wake_compute_limit: Vec<RateBucketInfo>,
-    /// Redis rate limiter max number of requests per second.
-    #[clap(long, default_values_t = RateBucketInfo::DEFAULT_REDIS_SET)]
-    redis_rps_limit: Vec<RateBucketInfo>,
     /// Cancellation channel size (max queue size for redis kv client)
     #[clap(long, default_value_t = 1024)]
     cancellation_ch_size: usize,
     /// Cancellation ops batch size for redis
     #[clap(long, default_value_t = 8)]
     cancellation_batch_size: usize,
-    /// cache for `allowed_ips` (use `size=0` to disable)
-    #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
-    allowed_ips_cache: String,
-    /// cache for `role_secret` (use `size=0` to disable)
-    #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
-    role_secret_cache: String,
-    /// redis url for notifications (if empty, redis_host:port will be used for both notifications and streaming connections)
-    #[clap(long)]
-    redis_notifications: Option<String>,
-    /// what from the available authentications type to use for the regional redis we have. Supported are "irsa" and "plain".
+    /// redis url for plain authentication
+    #[clap(long, alias("redis-notifications"))]
+    redis_plain: Option<String>,
+    /// what from the available authentications type to use for redis. Supported are "irsa" and "plain".
     #[clap(long, default_value = "irsa")]
     redis_auth_type: String,
-    /// redis host for streaming connections (might be different from the notifications host)
+    /// redis host for irsa authentication
     #[clap(long)]
     redis_host: Option<String>,
-    /// redis port for streaming connections (might be different from the notifications host)
+    /// redis port for irsa authentication
     #[clap(long)]
     redis_port: Option<u16>,
-    /// redis cluster name, used in aws elasticache
+    /// redis cluster name for irsa authentication
     #[clap(long)]
     redis_cluster_name: Option<String>,
-    /// redis user_id, used in aws elasticache
+    /// redis user_id for irsa authentication
     #[clap(long)]
     redis_user_id: Option<String>,
-    /// aws region to retrieve credentials
+    /// aws region for irsa authentication
     #[clap(long, default_value_t = String::new())]
     aws_region: String,
     /// cache for `project_info` (use `size=0` to disable)
@@ -200,6 +185,12 @@ struct ProxyCliArgs {
     #[clap(flatten)]
     parquet_upload: ParquetUploadArgs,
 
+    /// http endpoint to receive periodic metric updates
+    #[clap(long)]
+    metric_collection_endpoint: Option<String>,
+    /// how often metrics should be sent to a collection endpoint
+    #[clap(long)]
+    metric_collection_interval: Option<String>,
     /// interval for backup metric collection
     #[clap(long, default_value = "10m", value_parser = humantime::parse_duration)]
     metric_backup_collection_interval: std::time::Duration,
@@ -212,6 +203,7 @@ struct ProxyCliArgs {
     /// Size of each event is no more than 400 bytes, so 2**22 is about 200MB before the compression.
     #[clap(long, default_value = "4194304")]
     metric_backup_collection_chunk_size: usize,
+
     /// Whether to retry the connection to the compute node
     #[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)]
     connect_to_compute_retry: String,
@@ -331,7 +323,7 @@ pub async fn run() -> anyhow::Result<()> {
         Either::Right(auth_backend) => info!("Authentication backend: {auth_backend:?}"),
     }
     info!("Using region: {}", args.aws_region);
-    let (regional_redis_client, redis_notifications_client) = configure_redis(&args).await?;
+    let redis_client = configure_redis(&args).await?;
 
     // Check that we can bind to address before further initialization
     info!("Starting http on {}", args.http);
@@ -386,13 +378,6 @@ pub async fn run() -> anyhow::Result<()> {
 
     let cancellation_token = CancellationToken::new();
 
-    let redis_rps_limit = Vec::leak(args.redis_rps_limit.clone());
-    RateBucketInfo::validate(redis_rps_limit)?;
-
-    let redis_kv_client = regional_redis_client
-        .as_ref()
-        .map(|redis_publisher| RedisKVClient::new(redis_publisher.clone(), redis_rps_limit));
-
     let cancellation_handler = Arc::new(CancellationHandler::new(&config.connect_to_compute));
 
     let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
@@ -472,6 +457,7 @@ pub async fn run() -> anyhow::Result<()> {
     client_tasks.spawn(crate::context::parquet::worker(
         cancellation_token.clone(),
         args.parquet_upload,
+        args.region,
     ));
 
     // maintenance tasks. these never return unless there's an error
@@ -495,32 +481,17 @@ pub async fn run() -> anyhow::Result<()> {
     #[cfg_attr(not(any(test, feature = "testing")), expect(irrefutable_let_patterns))]
     if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend {
         if let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api {
-            match (redis_notifications_client, regional_redis_client.clone()) {
-                (None, None) => {}
-                (client1, client2) => {
-                    let cache = api.caches.project_info.clone();
-                    if let Some(client) = client1 {
-                        maintenance_tasks.spawn(notifications::task_main(
-                            client,
-                            cache.clone(),
-                            args.region.clone(),
-                        ));
-                    }
-                    if let Some(client) = client2 {
-                        maintenance_tasks.spawn(notifications::task_main(
-                            client,
-                            cache.clone(),
-                            args.region.clone(),
-                        ));
-                    }
-                    maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
-                }
-            }
+            if let Some(client) = redis_client {
+                // project info cache and invalidation of that cache.
+                let cache = api.caches.project_info.clone();
+                maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
+                maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
 
-            // Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
-            // This prevents immediate exit and pod restart,
-            // which can cause hammering of the redis in case of connection issues.
-            if let Some(mut redis_kv_client) = redis_kv_client {
+                // Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
+                // This prevents immediate exit and pod restart,
+                // which can cause hammering of the redis in case of connection issues.
+                // cancellation key management
+                let mut redis_kv_client = RedisKVClient::new(client.clone());
                 for attempt in (0..3).with_position() {
                     match redis_kv_client.try_connect().await {
                         Ok(()) => {
@@ -545,14 +516,12 @@ pub async fn run() -> anyhow::Result<()> {
                         }
                     }
                 }
-            }
 
-            if let Some(regional_redis_client) = regional_redis_client {
+                // listen for notifications of new projects/endpoints/branches
                 let cache = api.caches.endpoints_cache.clone();
-                let con = regional_redis_client;
                 let span = tracing::info_span!("endpoints_cache");
                 maintenance_tasks.spawn(
-                    async move { cache.do_read(con, cancellation_token.clone()).await }
+                    async move { cache.do_read(client, cancellation_token.clone()).await }
                         .instrument(span),
                 );
             }
@@ -681,7 +650,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
         authentication_config,
         proxy_protocol_v2: args.proxy_protocol_v2,
         handshake_timeout: args.handshake_timeout,
-        region: args.region.clone(),
         wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
         connect_compute_locks,
         connect_to_compute: compute_config,
@@ -843,21 +811,18 @@ fn build_auth_backend(
 
 async fn configure_redis(
     args: &ProxyCliArgs,
-) -> anyhow::Result<(
-    Option<ConnectionWithCredentialsProvider>,
-    Option<ConnectionWithCredentialsProvider>,
-)> {
+) -> anyhow::Result<Option<ConnectionWithCredentialsProvider>> {
     // TODO: untangle the config args
-    let regional_redis_client = match (args.redis_auth_type.as_str(), &args.redis_notifications) {
-        ("plain", redis_url) => match redis_url {
+    let redis_client = match &*args.redis_auth_type {
+        "plain" => match &args.redis_plain {
             None => {
-                bail!("plain auth requires redis_notifications to be set");
+                bail!("plain auth requires redis_plain to be set");
             }
             Some(url) => {
                 Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone()))
             }
         },
-        ("irsa", _) => match (&args.redis_host, args.redis_port) {
+        "irsa" => match (&args.redis_host, args.redis_port) {
            (Some(host), Some(port)) => Some(
                ConnectionWithCredentialsProvider::new_with_credentials_provider(
                    host.clone(),
@@ -881,18 +846,12 @@ async fn configure_redis(
                 bail!("redis-host and redis-port must be specified together");
             }
         },
-        _ => {
-            bail!("unknown auth type given");
+        auth_type => {
+            bail!("unknown auth type {auth_type:?} given")
         }
     };
 
-    let redis_notifications_client = if let Some(url) = &args.redis_notifications {
-        Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url))
-    } else {
-        regional_redis_client.clone()
-    };
-
-    Ok((regional_redis_client, redis_notifications_client))
+    Ok(redis_client)
 }
 
 #[cfg(test)]
diff --git a/proxy/src/config.rs b/proxy/src/config.rs
index 248584a19a..cee15ac7fa 100644
--- a/proxy/src/config.rs
+++ b/proxy/src/config.rs
@@ -22,7 +22,6 @@ pub struct ProxyConfig {
     pub http_config: HttpConfig,
     pub authentication_config: AuthenticationConfig,
     pub proxy_protocol_v2: ProxyProtocolV2,
-    pub region: String,
     pub handshake_timeout: Duration,
     pub wake_compute_retry_config: RetryConfig,
     pub connect_compute_locks: ApiLocks<Host>,
diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs
index 113a11beab..112465a89b 100644
--- a/proxy/src/console_redirect_proxy.rs
+++ b/proxy/src/console_redirect_proxy.rs
@@ -89,12 +89,7 @@ pub async fn task_main(
             }
         }
 
-        let ctx = RequestContext::new(
-            session_id,
-            conn_info,
-            crate::metrics::Protocol::Tcp,
-            &config.region,
-        );
+        let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
 
         let res = handle_client(
             config,
diff --git a/proxy/src/context/mod.rs b/proxy/src/context/mod.rs
index 24268997ba..df1c4e194a 100644
--- a/proxy/src/context/mod.rs
+++ b/proxy/src/context/mod.rs
@@ -46,7 +46,6 @@ struct RequestContextInner {
     pub(crate) session_id: Uuid,
     pub(crate) protocol: Protocol,
     first_packet: chrono::DateTime<Utc>,
-    region: &'static str,
     pub(crate) span: Span,
 
     // filled in as they are discovered
@@ -94,7 +93,6 @@ impl Clone for RequestContext {
             session_id: inner.session_id,
             protocol: inner.protocol,
             first_packet: inner.first_packet,
-            region: inner.region,
             span: info_span!("background_task"),
 
             project: inner.project,
@@ -124,12 +122,7 @@ impl Clone for RequestContext {
 }
 
 impl RequestContext {
-    pub fn new(
-        session_id: Uuid,
-        conn_info: ConnectionInfo,
-        protocol: Protocol,
-        region: &'static str,
-    ) -> Self {
+    pub fn new(session_id: Uuid, conn_info: ConnectionInfo, protocol: Protocol) -> Self {
         // TODO: be careful with long lived spans
         let span = info_span!(
             "connect_request",
@@ -145,7 +138,6 @@ impl RequestContext {
             session_id,
             protocol,
             first_packet: Utc::now(),
-            region,
             span,
 
             project: None,
@@ -179,7 +171,7 @@ impl RequestContext {
         let ip = IpAddr::from([127, 0, 0, 1]);
         let addr = SocketAddr::new(ip, 5432);
         let conn_info = ConnectionInfo { addr, extra: None };
-        RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp, "test")
+        RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp)
     }
 
     pub(crate) fn console_application_name(&self) -> String {
diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs
index c9d3905abd..b55cc14532 100644
--- a/proxy/src/context/parquet.rs
+++ b/proxy/src/context/parquet.rs
@@ -74,7 +74,7 @@ pub(crate) const FAILED_UPLOAD_MAX_RETRIES: u32 = 10;
 
 #[derive(parquet_derive::ParquetRecordWriter)]
 pub(crate) struct RequestData {
-    region: &'static str,
+    region: String,
     protocol: &'static str,
     /// Must be UTC. The derive macro doesn't like the timezones
     timestamp: chrono::NaiveDateTime,
@@ -147,7 +147,7 @@ impl From<&RequestContextInner> for RequestData {
             }),
             jwt_issuer: value.jwt_issuer.clone(),
             protocol: value.protocol.as_str(),
-            region: value.region,
+            region: String::new(),
             error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
             success: value.success,
             cold_start_info: value.cold_start_info.as_str(),
@@ -167,6 +167,7 @@ impl From<&RequestContextInner> for RequestData {
 pub async fn worker(
     cancellation_token: CancellationToken,
     config: ParquetUploadArgs,
+    region: String,
 ) -> anyhow::Result<()> {
     let Some(remote_storage_config) = config.parquet_upload_remote_storage else {
         tracing::warn!("parquet request upload: no s3 bucket configured");
@@ -232,12 +233,17 @@ pub async fn worker(
             .context("remote storage for disconnect events init")?;
         let parquet_config_disconnect = parquet_config.clone();
         tokio::try_join!(
-            worker_inner(storage, rx, parquet_config),
-            worker_inner(storage_disconnect, rx_disconnect, parquet_config_disconnect)
+            worker_inner(storage, rx, parquet_config, &region),
+            worker_inner(
+                storage_disconnect,
+                rx_disconnect,
+                parquet_config_disconnect,
+                &region
+            )
         )
         .map(|_| ())
     } else {
-        worker_inner(storage, rx, parquet_config).await
+        worker_inner(storage, rx, parquet_config, &region).await
     }
 }
 
@@ -257,6 +263,7 @@ async fn worker_inner(
     storage: GenericRemoteStorage,
     rx: impl Stream<Item = RequestData>,
     config: ParquetConfig,
+    region: &str,
 ) -> anyhow::Result<()> {
     #[cfg(any(test, feature = "testing"))]
     let storage = if config.test_remote_failures > 0 {
@@ -277,7 +284,8 @@ async fn worker_inner(
     let mut last_upload = time::Instant::now();
     let mut len = 0;
 
-    while let Some(row) = rx.next().await {
+    while let Some(mut row) = rx.next().await {
+        region.clone_into(&mut row.region);
         rows.push(row);
         let force = last_upload.elapsed() > config.max_duration;
         if rows.len() == config.rows_per_group || force {
@@ -533,7 +541,7 @@ mod tests {
             auth_method: None,
             jwt_issuer: None,
             protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)],
-            region: "us-east-1",
+            region: String::new(),
             error: None,
             success: rng.r#gen(),
             cold_start_info: "no",
@@ -565,7 +573,9 @@ mod tests {
         .await
         .unwrap();
 
-        worker_inner(storage, rx, config).await.unwrap();
+        worker_inner(storage, rx, config, "us-east-1")
+            .await
+            .unwrap();
 
         let mut files = WalkDir::new(tmpdir.as_std_path())
             .into_iter()
diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs
index 6947e07488..6b84e47982 100644
--- a/proxy/src/proxy/mod.rs
+++ b/proxy/src/proxy/mod.rs
@@ -122,12 +122,7 @@ pub async fn task_main(
             }
         }
 
-        let ctx = RequestContext::new(
-            session_id,
-            conn_info,
-            crate::metrics::Protocol::Tcp,
-            &config.region,
-        );
+        let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
 
         let res = handle_client(
             config,
diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs
index 0cd539188a..2e40f5bf60 100644
--- a/proxy/src/rate_limiter/limiter.rs
+++ b/proxy/src/rate_limiter/limiter.rs
@@ -139,12 +139,6 @@ impl RateBucketInfo {
         Self::new(200, Duration::from_secs(600)),
     ];
 
-    // For all the sessions will be cancel key. So this limit is essentially global proxy limit.
-    pub const DEFAULT_REDIS_SET: [Self; 2] = [
-        Self::new(100_000, Duration::from_secs(1)),
-        Self::new(50_000, Duration::from_secs(10)),
-    ];
-
     pub fn rps(&self) -> f64 {
         (self.max_rpi as f64) / self.interval.as_secs_f64()
     }
diff --git a/proxy/src/redis/kv_ops.rs b/proxy/src/redis/kv_ops.rs
index f8d3b5cc66..671fe09b0b 100644
--- a/proxy/src/redis/kv_ops.rs
+++ b/proxy/src/redis/kv_ops.rs
@@ -5,11 +5,9 @@ use redis::aio::ConnectionLike;
 use redis::{Cmd, FromRedisValue, Pipeline, RedisResult};
 
 use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
-use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};
 
 pub struct RedisKVClient {
     client: ConnectionWithCredentialsProvider,
-    limiter: GlobalRateLimiter,
 }
 
 #[allow(async_fn_in_trait)]
@@ -30,11 +28,8 @@ impl Queryable for Cmd {
 }
 
 impl RedisKVClient {
-    pub fn new(client: ConnectionWithCredentialsProvider, info: &'static [RateBucketInfo]) -> Self {
-        Self {
-            client,
-            limiter: GlobalRateLimiter::new(info.into()),
-        }
+    pub fn new(client: ConnectionWithCredentialsProvider) -> Self {
+        Self { client }
     }
 
     pub async fn try_connect(&mut self) -> anyhow::Result<()> {
@@ -49,11 +44,6 @@ impl RedisKVClient {
         &mut self,
         q: &impl Queryable,
     ) -> anyhow::Result<T> {
-        if !self.limiter.check() {
-            tracing::info!("Rate limit exceeded. Skipping query");
-            return Err(anyhow::anyhow!("Rate limit exceeded"));
-        }
-
         let e = match q.query(&mut self.client).await {
             Ok(t) => return Ok(t),
             Err(e) => e,
diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs
index 6c8260027f..973a4c5b02 100644
--- a/proxy/src/redis/notifications.rs
+++ b/proxy/src/redis/notifications.rs
@@ -141,29 +141,19 @@ where
 
 struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
     cache: Arc<C>,
-    region_id: String,
 }
 
 impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
     fn clone(&self) -> Self {
         Self {
             cache: self.cache.clone(),
-            region_id: self.region_id.clone(),
         }
     }
 }
 
 impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
-    pub(crate) fn new(cache: Arc<C>, region_id: String) -> Self {
-        Self { cache, region_id }
-    }
-
-    pub(crate) async fn increment_active_listeners(&self) {
-        self.cache.increment_active_listeners().await;
-    }
-
-    pub(crate) async fn decrement_active_listeners(&self) {
-        self.cache.decrement_active_listeners().await;
+    pub(crate) fn new(cache: Arc<C>) -> Self {
+        Self { cache }
     }
 
     #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
@@ -276,7 +266,7 @@ async fn handle_messages(
         }
         let mut conn = match try_connect(&redis).await {
             Ok(conn) => {
-                handler.increment_active_listeners().await;
+                handler.cache.increment_active_listeners().await;
                 conn
             }
             Err(e) => {
@@ -297,11 +287,11 @@ async fn handle_messages(
                 }
             }
             if cancellation_token.is_cancelled() {
-                handler.decrement_active_listeners().await;
+                handler.cache.decrement_active_listeners().await;
                 return Ok(());
             }
         }
-        handler.decrement_active_listeners().await;
+        handler.cache.decrement_active_listeners().await;
     }
 }
 
@@ -310,12 +300,11 @@ async fn handle_messages(
 pub async fn task_main<C>(
     redis: ConnectionWithCredentialsProvider,
     cache: Arc<C>,
-    region_id: String,
 ) -> anyhow::Result<Infallible>
 where
     C: ProjectInfoCache + Send + Sync + 'static,
 {
-    let handler = MessageHandler::new(cache, region_id);
+    let handler = MessageHandler::new(cache);
     // 6h - 1m.
     // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
     let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs
index ed33bf1246..d8942bb814 100644
--- a/proxy/src/serverless/mod.rs
+++ b/proxy/src/serverless/mod.rs
@@ -417,12 +417,7 @@ async fn request_handler(
     if config.http_config.accept_websockets
         && framed_websockets::upgrade::is_upgrade_request(&request)
     {
-        let ctx = RequestContext::new(
-            session_id,
-            conn_info,
-            crate::metrics::Protocol::Ws,
-            &config.region,
-        );
+        let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Ws);
 
         ctx.set_user_agent(
             request
@@ -462,12 +457,7 @@ async fn request_handler(
         // Return the response so the spawned future can continue.
         Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
     } else if request.uri().path() == "/sql" && *request.method() == Method::POST {
-        let ctx = RequestContext::new(
-            session_id,
-            conn_info,
-            crate::metrics::Protocol::Http,
-            &config.region,
-        );
+        let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Http);
         let span = ctx.span();
 
         let testodrome_id = request