diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index e421798067..71e9da18bc 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -27,7 +27,7 @@ use crate::{ }, stream, url, }; -use crate::{scram, EndpointCacheKey, EndpointId, RoleName}; +use crate::{scram, EndpointCacheKey, EndpointId, Normalize, RoleName}; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, warn}; @@ -186,7 +186,7 @@ impl AuthenticationConfig { is_cleartext: bool, ) -> auth::Result { // we have validated the endpoint exists, so let's intern it. - let endpoint_int = EndpointIdInt::from(endpoint); + let endpoint_int = EndpointIdInt::from(endpoint.normalize()); // only count the full hash count if password hack or websocket flow. // in other words, if proxy needs to run the hashing diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 56a3ef79cd..9302b31d5c 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -189,7 +189,9 @@ struct ProxyCliArgs { /// cache for `project_info` (use `size=0` to disable) #[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)] project_info_cache: String, - + /// cache for all valid endpoints + #[clap(long, default_value = config::EndpointCacheConfig::CACHE_DEFAULT_OPTIONS)] + endpoint_cache_config: String, #[clap(flatten)] parquet_upload: ParquetUploadArgs, @@ -401,6 +403,7 @@ async fn main() -> anyhow::Result<()> { if let auth::BackendType::Console(api, _) = &config.auth_backend { if let proxy::console::provider::ConsoleBackend::Console(api) = &**api { + maintenance_tasks.spawn(api.locks.garbage_collect_worker()); if let Some(redis_notifications_client) = redis_notifications_client { let cache = api.caches.project_info.clone(); maintenance_tasks.spawn(notifications::task_main( @@ -410,6 +413,9 @@ async fn main() -> anyhow::Result<()> { args.region.clone(), )); maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); + let cache = api.caches.endpoints_cache.clone(); + let con = redis_notifications_client.clone(); + maintenance_tasks.spawn(async move { cache.do_read(con).await }); } } } @@ -489,14 +495,18 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?; let project_info_cache_config: ProjectInfoCacheOptions = args.project_info_cache.parse()?; + let endpoint_cache_config: config::EndpointCacheConfig = + args.endpoint_cache_config.parse()?; info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}"); info!( "Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}" ); + info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}"); let caches = Box::leak(Box::new(console::caches::ApiCaches::new( wake_compute_cache_config, project_info_cache_config, + endpoint_cache_config, ))); let config::WakeComputeLockOptions { @@ -507,10 +517,9 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { } = args.wake_compute_lock.parse()?; info!(permits, shards, ?epoch, "Using NodeLocks (wake_compute)"); let locks = Box::leak(Box::new( - console::locks::ApiLocks::new("wake_compute_lock", permits, shards, timeout) + console::locks::ApiLocks::new("wake_compute_lock", permits, shards, timeout, epoch) .unwrap(), )); - tokio::spawn(locks.garbage_collect_worker(epoch)); let url = args.auth_endpoint.parse()?; let endpoint = http::Endpoint::new(url, http::new_client(rate_limiter_config)); diff --git a/proxy/src/cache.rs b/proxy/src/cache.rs index fc5f416395..d1d4087241 100644 --- a/proxy/src/cache.rs +++ b/proxy/src/cache.rs @@ -1,4 +1,5 @@ pub mod common; +pub mod endpoints; pub mod project_info; mod timed_lru; diff --git a/proxy/src/cache/endpoints.rs b/proxy/src/cache/endpoints.rs new file mode 100644 index 0000000000..9bc019c2d8 --- /dev/null +++ b/proxy/src/cache/endpoints.rs @@ -0,0 +1,191 @@ +use std::{ + convert::Infallible, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; + +use dashmap::DashSet; +use redis::{ + streams::{StreamReadOptions, StreamReadReply}, + AsyncCommands, FromRedisValue, Value, +}; +use serde::Deserialize; +use tokio::sync::Mutex; + +use crate::{ + config::EndpointCacheConfig, + context::RequestMonitoring, + intern::{BranchIdInt, EndpointIdInt, ProjectIdInt}, + metrics::REDIS_BROKEN_MESSAGES, + rate_limiter::GlobalRateLimiter, + redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider, + EndpointId, Normalize, +}; + +#[derive(Deserialize, Debug, Clone)] +#[serde(rename_all(deserialize = "snake_case"))] +pub enum ControlPlaneEventKey { + EndpointCreated, + BranchCreated, + ProjectCreated, +} + +pub struct EndpointsCache { + config: EndpointCacheConfig, + endpoints: DashSet, + branches: DashSet, + projects: DashSet, + ready: AtomicBool, + limiter: Arc>, +} + +impl EndpointsCache { + pub fn new(config: EndpointCacheConfig) -> Self { + Self { + limiter: Arc::new(Mutex::new(GlobalRateLimiter::new( + config.limiter_info.clone(), + ))), + config, + endpoints: DashSet::new(), + branches: DashSet::new(), + projects: DashSet::new(), + ready: AtomicBool::new(false), + } + } + pub async fn is_valid(&self, ctx: &mut RequestMonitoring, endpoint: &EndpointId) -> bool { + if !self.ready.load(Ordering::Acquire) { + return true; + } + // If cache is disabled, just collect the metrics and return. + if self.config.disable_cache { + ctx.set_rejected(self.should_reject(endpoint)); + return true; + } + // If the limiter allows, we don't need to check the cache. + if self.limiter.lock().await.check() { + return true; + } + let rejected = self.should_reject(endpoint); + ctx.set_rejected(rejected); + !rejected + } + fn should_reject(&self, endpoint: &EndpointId) -> bool { + let endpoint = endpoint.normalize(); + if endpoint.is_endpoint() { + !self.endpoints.contains(&EndpointIdInt::from(&endpoint)) + } else if endpoint.is_branch() { + !self + .branches + .contains(&BranchIdInt::from(&endpoint.as_branch())) + } else { + !self + .projects + .contains(&ProjectIdInt::from(&endpoint.as_project())) + } + } + fn insert_event(&self, key: ControlPlaneEventKey, value: String) { + // Do not do normalization here, we expect the events to be normalized. + match key { + ControlPlaneEventKey::EndpointCreated => { + self.endpoints.insert(EndpointIdInt::from(&value.into())); + } + ControlPlaneEventKey::BranchCreated => { + self.branches.insert(BranchIdInt::from(&value.into())); + } + ControlPlaneEventKey::ProjectCreated => { + self.projects.insert(ProjectIdInt::from(&value.into())); + } + } + } + pub async fn do_read( + &self, + mut con: ConnectionWithCredentialsProvider, + ) -> anyhow::Result { + let mut last_id = "0-0".to_string(); + loop { + self.ready.store(false, Ordering::Release); + if let Err(e) = con.connect().await { + tracing::error!("error connecting to redis: {:?}", e); + continue; + } + if let Err(e) = self.read_from_stream(&mut con, &mut last_id).await { + tracing::error!("error reading from redis: {:?}", e); + } + } + } + async fn read_from_stream( + &self, + con: &mut ConnectionWithCredentialsProvider, + last_id: &mut String, + ) -> anyhow::Result<()> { + tracing::info!("reading endpoints/branches/projects from redis"); + self.batch_read( + con, + StreamReadOptions::default().count(self.config.initial_batch_size), + last_id, + true, + ) + .await?; + tracing::info!("ready to filter user requests"); + self.ready.store(true, Ordering::Release); + self.batch_read( + con, + StreamReadOptions::default() + .count(self.config.initial_batch_size) + .block(self.config.xread_timeout.as_millis() as usize), + last_id, + false, + ) + .await + } + fn parse_key_value(key: &str, value: &Value) -> anyhow::Result<(ControlPlaneEventKey, String)> { + Ok((serde_json::from_str(key)?, String::from_redis_value(value)?)) + } + async fn batch_read( + &self, + conn: &mut ConnectionWithCredentialsProvider, + opts: StreamReadOptions, + last_id: &mut String, + return_when_finish: bool, + ) -> anyhow::Result<()> { + let mut total: usize = 0; + loop { + let mut res: StreamReadReply = conn + .xread_options(&[&self.config.stream_name], &[last_id.as_str()], &opts) + .await?; + if res.keys.len() != 1 { + anyhow::bail!("Cannot read from redis stream {}", self.config.stream_name); + } + + let res = res.keys.pop().expect("Checked length above"); + + if return_when_finish && res.ids.len() <= self.config.default_batch_size { + break; + } + for x in res.ids { + total += 1; + for (k, v) in x.map { + let (key, value) = match Self::parse_key_value(&k, &v) { + Ok(x) => x, + Err(e) => { + REDIS_BROKEN_MESSAGES + .with_label_values(&[&self.config.stream_name]) + .inc(); + tracing::error!("error parsing key-value {k}-{v:?}: {e:?}"); + continue; + } + }; + self.insert_event(key, value); + } + if total.is_power_of_two() { + tracing::debug!("endpoints read {}", total); + } + *last_id = x.id; + } + } + tracing::info!("read {} endpoints/branches/projects from redis", total); + Ok(()) + } +} diff --git a/proxy/src/config.rs b/proxy/src/config.rs index fc490c7348..3bdfb3cfad 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -313,6 +313,75 @@ impl CertResolver { } } +#[derive(Debug)] +pub struct EndpointCacheConfig { + /// Batch size to receive all endpoints on the startup. + pub initial_batch_size: usize, + /// Batch size to receive endpoints. + pub default_batch_size: usize, + /// Timeouts for the stream read operation. + pub xread_timeout: Duration, + /// Stream name to read from. + pub stream_name: String, + /// Limiter info (to distinguish when to enable cache). + pub limiter_info: Vec, + /// Disable cache. + /// If true, cache is ignored, but reports all statistics. + pub disable_cache: bool, +} + +impl EndpointCacheConfig { + /// Default options for [`crate::console::provider::NodeInfoCache`]. + /// Notice that by default the limiter is empty, which means that cache is disabled. + pub const CACHE_DEFAULT_OPTIONS: &'static str = + "initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s"; + + /// Parse cache options passed via cmdline. + /// Example: [`Self::CACHE_DEFAULT_OPTIONS`]. + fn parse(options: &str) -> anyhow::Result { + let mut initial_batch_size = None; + let mut default_batch_size = None; + let mut xread_timeout = None; + let mut stream_name = None; + let mut limiter_info = vec![]; + let mut disable_cache = false; + + for option in options.split(',') { + let (key, value) = option + .split_once('=') + .with_context(|| format!("bad key-value pair: {option}"))?; + + match key { + "initial_batch_size" => initial_batch_size = Some(value.parse()?), + "default_batch_size" => default_batch_size = Some(value.parse()?), + "xread_timeout" => xread_timeout = Some(humantime::parse_duration(value)?), + "stream_name" => stream_name = Some(value.to_string()), + "limiter_info" => limiter_info.push(RateBucketInfo::from_str(value)?), + "disable_cache" => disable_cache = value.parse()?, + unknown => bail!("unknown key: {unknown}"), + } + } + RateBucketInfo::validate(&mut limiter_info)?; + + Ok(Self { + initial_batch_size: initial_batch_size.context("missing `initial_batch_size`")?, + default_batch_size: default_batch_size.context("missing `default_batch_size`")?, + xread_timeout: xread_timeout.context("missing `xread_timeout`")?, + stream_name: stream_name.context("missing `stream_name`")?, + disable_cache, + limiter_info, + }) + } +} + +impl FromStr for EndpointCacheConfig { + type Err = anyhow::Error; + + fn from_str(options: &str) -> Result { + let error = || format!("failed to parse endpoint cache options '{options}'"); + Self::parse(options).with_context(error) + } +} #[derive(Debug)] pub struct MetricBackupCollectionConfig { pub interval: Duration, diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index f7d621fb12..ee2bc866ab 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -8,15 +8,15 @@ use crate::{ backend::{ComputeCredentialKeys, ComputeUserInfo}, IpPattern, }, - cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru}, + cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru}, compute, - config::{CacheOptions, ProjectInfoCacheOptions}, + config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions}, context::RequestMonitoring, intern::ProjectIdInt, scram, EndpointCacheKey, }; use dashmap::DashMap; -use std::{sync::Arc, time::Duration}; +use std::{convert::Infallible, sync::Arc, time::Duration}; use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use tokio::time::Instant; use tracing::info; @@ -416,12 +416,15 @@ pub struct ApiCaches { pub node_info: NodeInfoCache, /// Cache which stores project_id -> endpoint_ids mapping. pub project_info: Arc, + /// List of all valid endpoints. + pub endpoints_cache: Arc, } impl ApiCaches { pub fn new( wake_compute_cache_config: CacheOptions, project_info_cache_config: ProjectInfoCacheOptions, + endpoint_cache_config: EndpointCacheConfig, ) -> Self { Self { node_info: NodeInfoCache::new( @@ -431,6 +434,7 @@ impl ApiCaches { true, ), project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)), + endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)), } } } @@ -441,6 +445,7 @@ pub struct ApiLocks { node_locks: DashMap>, permits: usize, timeout: Duration, + epoch: std::time::Duration, registered: prometheus::IntCounter, unregistered: prometheus::IntCounter, reclamation_lag: prometheus::Histogram, @@ -453,6 +458,7 @@ impl ApiLocks { permits: usize, shards: usize, timeout: Duration, + epoch: std::time::Duration, ) -> prometheus::Result { let registered = prometheus::IntCounter::with_opts( prometheus::Opts::new( @@ -497,6 +503,7 @@ impl ApiLocks { node_locks: DashMap::with_shard_amount(shards), permits, timeout, + epoch, lock_acquire_lag, registered, unregistered, @@ -536,12 +543,9 @@ impl ApiLocks { }) } - pub async fn garbage_collect_worker(&self, epoch: std::time::Duration) { - if self.permits == 0 { - return; - } - - let mut interval = tokio::time::interval(epoch / (self.node_locks.shards().len()) as u32); + pub async fn garbage_collect_worker(&self) -> anyhow::Result { + let mut interval = + tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32); loop { for (i, shard) in self.node_locks.shards().iter().enumerate() { interval.tick().await; diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index 1a3e2ca795..3a0e5609d8 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -8,6 +8,7 @@ use super::{ }; use crate::{ auth::backend::ComputeUserInfo, compute, console::messages::ColdStartInfo, http, scram, + Normalize, }; use crate::{ cache::Cached, @@ -23,7 +24,7 @@ use tracing::{error, info, info_span, warn, Instrument}; pub struct Api { endpoint: http::Endpoint, pub caches: &'static ApiCaches, - locks: &'static ApiLocks, + pub locks: &'static ApiLocks, jwt: String, } @@ -55,6 +56,15 @@ impl Api { ctx: &mut RequestMonitoring, user_info: &ComputeUserInfo, ) -> Result { + if !self + .caches + .endpoints_cache + .is_valid(ctx, &user_info.endpoint) + .await + { + info!("endpoint is not valid, skipping the request"); + return Ok(AuthInfo::default()); + } let request_id = ctx.session_id.to_string(); let application_name = ctx.console_application_name(); async { @@ -81,7 +91,9 @@ impl Api { Ok(body) => body, // Error 404 is special: it's ok not to have a secret. Err(e) => match e.http_status_code() { - Some(http::StatusCode::NOT_FOUND) => return Ok(AuthInfo::default()), + Some(http::StatusCode::NOT_FOUND) => { + return Ok(AuthInfo::default()); + } _otherwise => return Err(e.into()), }, }; @@ -181,7 +193,7 @@ impl super::Api for Api { } let auth_info = self.do_get_auth_info(ctx, user_info).await?; if let Some(project_id) = auth_info.project_id { - let ep_int = ep.into(); + let ep_int = ep.normalize().into(); self.caches.project_info.insert_role_secret( project_id, ep_int, @@ -218,7 +230,7 @@ impl super::Api for Api { let allowed_ips = Arc::new(auth_info.allowed_ips); let user = &user_info.user; if let Some(project_id) = auth_info.project_id { - let ep_int = ep.into(); + let ep_int = ep.normalize().into(); self.caches.project_info.insert_role_secret( project_id, ep_int, diff --git a/proxy/src/context.rs b/proxy/src/context.rs index fec95f4722..85544f1d65 100644 --- a/proxy/src/context.rs +++ b/proxy/src/context.rs @@ -12,7 +12,9 @@ use crate::{ console::messages::{ColdStartInfo, MetricsAuxInfo}, error::ErrorKind, intern::{BranchIdInt, ProjectIdInt}, - metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND}, + metrics::{ + bool_to_str, LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND, NUM_INVALID_ENDPOINTS, + }, DbName, EndpointId, RoleName, }; @@ -50,6 +52,8 @@ pub struct RequestMonitoring { // This sender is here to keep the request monitoring channel open while requests are taking place. sender: Option>, pub latency_timer: LatencyTimer, + // Whether proxy decided that it's not a valid endpoint end rejected it before going to cplane. + rejected: bool, } #[derive(Clone, Debug)] @@ -93,6 +97,7 @@ impl RequestMonitoring { error_kind: None, auth_method: None, success: false, + rejected: false, cold_start_info: ColdStartInfo::Unknown, sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()), @@ -113,6 +118,10 @@ impl RequestMonitoring { ) } + pub fn set_rejected(&mut self, rejected: bool) { + self.rejected = rejected; + } + pub fn set_cold_start_info(&mut self, info: ColdStartInfo) { self.cold_start_info = info; self.latency_timer.cold_start_info(info); @@ -178,6 +187,10 @@ impl RequestMonitoring { impl Drop for RequestMonitoring { fn drop(&mut self) { + let outcome = if self.success { "success" } else { "failure" }; + NUM_INVALID_ENDPOINTS + .with_label_values(&[self.protocol, bool_to_str(self.rejected), outcome]) + .inc(); if let Some(tx) = self.sender.take() { let _: Result<(), _> = tx.send(RequestData::from(&*self)); } diff --git a/proxy/src/intern.rs b/proxy/src/intern.rs index a6519bdff9..e38135dd22 100644 --- a/proxy/src/intern.rs +++ b/proxy/src/intern.rs @@ -160,6 +160,11 @@ impl From<&EndpointId> for EndpointIdInt { EndpointIdTag::get_interner().get_or_intern(value) } } +impl From for EndpointIdInt { + fn from(value: EndpointId) -> Self { + EndpointIdTag::get_interner().get_or_intern(&value) + } +} #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] pub struct BranchIdTag; @@ -175,6 +180,11 @@ impl From<&BranchId> for BranchIdInt { BranchIdTag::get_interner().get_or_intern(value) } } +impl From for BranchIdInt { + fn from(value: BranchId) -> Self { + BranchIdTag::get_interner().get_or_intern(&value) + } +} #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] pub struct ProjectIdTag; @@ -190,6 +200,11 @@ impl From<&ProjectId> for ProjectIdInt { ProjectIdTag::get_interner().get_or_intern(value) } } +impl From for ProjectIdInt { + fn from(value: ProjectId) -> Self { + ProjectIdTag::get_interner().get_or_intern(&value) + } +} #[cfg(test)] mod tests { diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index da7c7f3ed2..3f6d985fe8 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -127,6 +127,24 @@ macro_rules! smol_str_wrapper { }; } +const POOLER_SUFFIX: &str = "-pooler"; + +pub trait Normalize { + fn normalize(&self) -> Self; +} + +impl + From> Normalize for S { + fn normalize(&self) -> Self { + if self.as_ref().ends_with(POOLER_SUFFIX) { + let mut s = self.as_ref().to_string(); + s.truncate(s.len() - POOLER_SUFFIX.len()); + s.into() + } else { + self.clone() + } + } +} + // 90% of role name strings are 20 characters or less. smol_str_wrapper!(RoleName); // 50% of endpoint strings are 23 characters or less. @@ -140,3 +158,22 @@ smol_str_wrapper!(ProjectId); smol_str_wrapper!(EndpointCacheKey); smol_str_wrapper!(DbName); + +// Endpoints are a bit tricky. Rare they might be branches or projects. +impl EndpointId { + pub fn is_endpoint(&self) -> bool { + self.0.starts_with("ep-") + } + pub fn is_branch(&self) -> bool { + self.0.starts_with("br-") + } + pub fn is_project(&self) -> bool { + !self.is_endpoint() && !self.is_branch() + } + pub fn as_branch(&self) -> BranchId { + BranchId(self.0.clone()) + } + pub fn as_project(&self) -> ProjectId { + ProjectId(self.0.clone()) + } +} diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 59ee899c08..f299313e0a 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -169,6 +169,18 @@ pub static NUM_CANCELLATION_REQUESTS: Lazy = Lazy::new(|| { .unwrap() }); +pub static NUM_INVALID_ENDPOINTS: Lazy = Lazy::new(|| { + register_int_counter_vec!( + "proxy_invalid_endpoints_total", + "Number of invalid endpoints (per protocol, per rejected).", + // http/ws/tcp, true/false, success/failure + // TODO(anna): the last dimension is just a proxy to what we actually want to measure. + // We need to measure whether the endpoint was found by cplane or not. + &["protocol", "rejected", "outcome"], + ) + .unwrap() +}); + pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT: &str = "from_client"; pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS: &str = "from_redis"; diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 6051c0a812..166e761a4e 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -20,7 +20,7 @@ use crate::{ proxy::handshake::{handshake, HandshakeData}, rate_limiter::EndpointRateLimiter, stream::{PqStream, Stream}, - EndpointCacheKey, + EndpointCacheKey, Normalize, }; use futures::TryFutureExt; use itertools::Itertools; @@ -280,7 +280,7 @@ pub async fn handle_client( // check rate limit if let Some(ep) = user_info.get_endpoint() { - if !endpoint_rate_limiter.check(ep, 1) { + if !endpoint_rate_limiter.check(ep.normalize(), 1) { return stream .throw_error(auth::AuthError::too_many_connections()) .await?; diff --git a/proxy/src/rate_limiter.rs b/proxy/src/rate_limiter.rs index 13dffffca0..a3b83e5e50 100644 --- a/proxy/src/rate_limiter.rs +++ b/proxy/src/rate_limiter.rs @@ -4,4 +4,4 @@ mod limiter; pub use aimd::Aimd; pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig}; pub use limiter::Limiter; -pub use limiter::{AuthRateLimiter, EndpointRateLimiter, RateBucketInfo, RedisRateLimiter}; +pub use limiter::{AuthRateLimiter, EndpointRateLimiter, GlobalRateLimiter, RateBucketInfo}; diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index f590896dd9..0503deb311 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -24,13 +24,13 @@ use super::{ RateLimiterConfig, }; -pub struct RedisRateLimiter { +pub struct GlobalRateLimiter { data: Vec, - info: &'static [RateBucketInfo], + info: Vec, } -impl RedisRateLimiter { - pub fn new(info: &'static [RateBucketInfo]) -> Self { +impl GlobalRateLimiter { + pub fn new(info: Vec) -> Self { Self { data: vec![ RateBucket { @@ -50,7 +50,7 @@ impl RedisRateLimiter { let should_allow_request = self .data .iter_mut() - .zip(self.info) + .zip(&self.info) .all(|(bucket, info)| bucket.should_allow_request(info, now, 1)); if should_allow_request { diff --git a/proxy/src/redis/cancellation_publisher.rs b/proxy/src/redis/cancellation_publisher.rs index 422789813c..7baf104374 100644 --- a/proxy/src/redis/cancellation_publisher.rs +++ b/proxy/src/redis/cancellation_publisher.rs @@ -5,7 +5,7 @@ use redis::AsyncCommands; use tokio::sync::Mutex; use uuid::Uuid; -use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter}; +use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo}; use super::{ connection_with_credentials_provider::ConnectionWithCredentialsProvider, @@ -80,7 +80,7 @@ impl CancellationPublisher for Arc> { pub struct RedisPublisherClient { client: ConnectionWithCredentialsProvider, region_id: String, - limiter: RedisRateLimiter, + limiter: GlobalRateLimiter, } impl RedisPublisherClient { @@ -92,7 +92,7 @@ impl RedisPublisherClient { Ok(Self { client, region_id, - limiter: RedisRateLimiter::new(info), + limiter: GlobalRateLimiter::new(info.into()), }) } diff --git a/test_runner/regress/test_proxy_rate_limiter.py b/test_runner/regress/test_proxy_rate_limiter.py deleted file mode 100644 index f39f0cad07..0000000000 --- a/test_runner/regress/test_proxy_rate_limiter.py +++ /dev/null @@ -1,84 +0,0 @@ -import asyncio -import time -from pathlib import Path -from typing import Iterator - -import pytest -from fixtures.neon_fixtures import ( - PSQL, - NeonProxy, -) -from fixtures.port_distributor import PortDistributor -from pytest_httpserver import HTTPServer -from werkzeug.wrappers.response import Response - - -def waiting_handler(status_code: int) -> Response: - # wait more than timeout to make sure that both (two) connections are open. - # It would be better to use a barrier here, but I don't know how to do that together with pytest-httpserver. - time.sleep(2) - return Response(status=status_code) - - -@pytest.fixture(scope="function") -def proxy_with_rate_limit( - port_distributor: PortDistributor, - neon_binpath: Path, - httpserver_listen_address, - test_output_dir: Path, -) -> Iterator[NeonProxy]: - """Neon proxy that routes directly to vanilla postgres.""" - - proxy_port = port_distributor.get_port() - mgmt_port = port_distributor.get_port() - http_port = port_distributor.get_port() - external_http_port = port_distributor.get_port() - (host, port) = httpserver_listen_address - endpoint = f"http://{host}:{port}/billing/api/v1/usage_events" - - with NeonProxy( - neon_binpath=neon_binpath, - test_output_dir=test_output_dir, - proxy_port=proxy_port, - http_port=http_port, - mgmt_port=mgmt_port, - external_http_port=external_http_port, - auth_backend=NeonProxy.Console(endpoint, fixed_rate_limit=5), - ) as proxy: - proxy.start() - yield proxy - - -@pytest.mark.asyncio -async def test_proxy_rate_limit( - httpserver: HTTPServer, - proxy_with_rate_limit: NeonProxy, -): - uri = "/billing/api/v1/usage_events/proxy_get_role_secret" - # mock control plane service - httpserver.expect_ordered_request(uri, method="GET").respond_with_handler( - lambda _: Response(status=200) - ) - httpserver.expect_ordered_request(uri, method="GET").respond_with_handler( - lambda _: waiting_handler(429) - ) - httpserver.expect_ordered_request(uri, method="GET").respond_with_handler( - lambda _: waiting_handler(500) - ) - - psql = PSQL(host=proxy_with_rate_limit.host, port=proxy_with_rate_limit.proxy_port) - f = await psql.run("select 42;") - await proxy_with_rate_limit.find_auth_link(uri, f) - # Limit should be 2. - - # Run two queries in parallel. - f1, f2 = await asyncio.gather(psql.run("select 42;"), psql.run("select 42;")) - await proxy_with_rate_limit.find_auth_link(uri, f1) - await proxy_with_rate_limit.find_auth_link(uri, f2) - - # Now limit should be 0. - f = await psql.run("select 42;") - await proxy_with_rate_limit.find_auth_link(uri, f) - - # There last query shouldn't reach the http-server. - assert httpserver.assertions == []