From 0bb04ebe19c1dd024c7762926ecce166f4259d82 Mon Sep 17 00:00:00 2001
From: Anna Khanova <32508607+khanova@users.noreply.github.com>
Date: Wed, 10 Apr 2024 12:12:55 +0200
Subject: [PATCH] Revert "Proxy read ids from redis (#7205)" (#7350)

This reverts commit dbac2d2c473f3648251f0a64e36d066f444dfe00.

## Problem

Proxy pods fail to install in k8s clusters, blocking the cplane release.

## Summary of changes

Revert the change.
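## Context: what is being reverted

#7205 streamed known endpoint/branch/project ids from a Redis stream into an in-memory cache and rejected requests for unknown endpoints before calling the control plane. A minimal sketch of that gating, using a plain `HashSet<String>` and `AtomicBool` as illustrative stand-ins for the proxy's interned-id `DashSet`s (the real `is_valid` also consults a rate limiter and a `disable_cache` flag; the types below are not the actual proxy API):

```rust
use std::collections::HashSet;
use std::sync::atomic::{AtomicBool, Ordering};

/// Simplified stand-in for the reverted `EndpointsCache`.
struct EndpointsCache {
    /// Endpoint ids observed on the Redis stream.
    endpoints: HashSet<String>,
    /// Set once the initial batch has been read from the stream.
    ready: AtomicBool,
}

impl EndpointsCache {
    /// Mirrors the reverted `is_valid` flow: fail open while the cache
    /// is still warming up, then reject endpoints never seen on the stream.
    fn is_valid(&self, endpoint: &str) -> bool {
        if !self.ready.load(Ordering::Acquire) {
            return true;
        }
        self.endpoints.contains(endpoint)
    }
}

fn main() {
    let mut known = HashSet::new();
    known.insert("ep-known-endpoint".to_string());
    let cache = EndpointsCache {
        endpoints: known,
        ready: AtomicBool::new(true),
    };
    assert!(cache.is_valid("ep-known-endpoint"));
    assert!(!cache.is_valid("ep-unknown-endpoint"));
}
```

The revert also restores the previous `RedisRateLimiter` name (constructed from `&'static [RateBucketInfo]`) in place of `GlobalRateLimiter`, and moves the wake-compute lock GC back to a `tokio::spawn`ed task that takes `epoch` as an argument.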
---
 proxy/src/auth/backend.rs                 |   4 +-
 proxy/src/bin/proxy.rs                    |  15 +-
 proxy/src/cache.rs                        |   1 -
 proxy/src/cache/endpoints.rs              | 190 ------------------
 proxy/src/config.rs                       |  69 -------
 proxy/src/console/provider.rs             |  22 +-
 proxy/src/console/provider/neon.rs        |  46 ++---
 proxy/src/context.rs                      |  15 +-
 proxy/src/intern.rs                       |  15 --
 proxy/src/lib.rs                          |  37 ----
 proxy/src/metrics.rs                      |  12 --
 proxy/src/proxy.rs                        |   4 +-
 proxy/src/rate_limiter.rs                 |   2 +-
 proxy/src/rate_limiter/limiter.rs         |  10 +-
 proxy/src/redis/cancellation_publisher.rs |   6 +-
 .../regress/test_proxy_rate_limiter.py    |  84 ++++++++
 16 files changed, 124 insertions(+), 408 deletions(-)
 delete mode 100644 proxy/src/cache/endpoints.rs
 create mode 100644 test_runner/regress/test_proxy_rate_limiter.py

diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs
index 71e9da18bc..e421798067 100644
--- a/proxy/src/auth/backend.rs
+++ b/proxy/src/auth/backend.rs
@@ -27,7 +27,7 @@ use crate::{
     },
     stream, url,
 };
-use crate::{scram, EndpointCacheKey, EndpointId, Normalize, RoleName};
+use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
 use std::sync::Arc;
 use tokio::io::{AsyncRead, AsyncWrite};
 use tracing::{info, warn};
@@ -186,7 +186,7 @@ impl AuthenticationConfig {
         is_cleartext: bool,
     ) -> auth::Result {
         // we have validated the endpoint exists, so let's intern it.
-        let endpoint_int = EndpointIdInt::from(endpoint.normalize());
+        let endpoint_int = EndpointIdInt::from(endpoint);
 
         // only count the full hash count if password hack or websocket flow.
         // in other words, if proxy needs to run the hashing
diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs
index 9302b31d5c..56a3ef79cd 100644
--- a/proxy/src/bin/proxy.rs
+++ b/proxy/src/bin/proxy.rs
@@ -189,9 +189,7 @@ struct ProxyCliArgs {
     /// cache for `project_info` (use `size=0` to disable)
     #[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)]
     project_info_cache: String,
-    /// cache for all valid endpoints
-    #[clap(long, default_value = config::EndpointCacheConfig::CACHE_DEFAULT_OPTIONS)]
-    endpoint_cache_config: String,
+
     #[clap(flatten)]
     parquet_upload: ParquetUploadArgs,
 
@@ -403,7 +401,6 @@ async fn main() -> anyhow::Result<()> {
 
     if let auth::BackendType::Console(api, _) = &config.auth_backend {
         if let proxy::console::provider::ConsoleBackend::Console(api) = &**api {
-            maintenance_tasks.spawn(api.locks.garbage_collect_worker());
             if let Some(redis_notifications_client) = redis_notifications_client {
                 let cache = api.caches.project_info.clone();
                 maintenance_tasks.spawn(notifications::task_main(
@@ -413,9 +410,6 @@ async fn main() -> anyhow::Result<()> {
                     args.region.clone(),
                 ));
                 maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
-                let cache = api.caches.endpoints_cache.clone();
-                let con = redis_notifications_client.clone();
-                maintenance_tasks.spawn(async move { cache.do_read(con).await });
             }
         }
     }
@@ -495,18 +489,14 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
     let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
     let project_info_cache_config: ProjectInfoCacheOptions = args.project_info_cache.parse()?;
-    let endpoint_cache_config: config::EndpointCacheConfig =
-        args.endpoint_cache_config.parse()?;
 
     info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
    info!(
         "Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
     );
-    info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}");
     let caches = Box::leak(Box::new(console::caches::ApiCaches::new(
         wake_compute_cache_config,
         project_info_cache_config,
-        endpoint_cache_config,
     )));
 
     let config::WakeComputeLockOptions {
@@ -517,9 +507,10 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
     } = args.wake_compute_lock.parse()?;
     info!(permits, shards, ?epoch, "Using NodeLocks (wake_compute)");
     let locks = Box::leak(Box::new(
-        console::locks::ApiLocks::new("wake_compute_lock", permits, shards, timeout, epoch)
+        console::locks::ApiLocks::new("wake_compute_lock", permits, shards, timeout)
             .unwrap(),
     ));
+    tokio::spawn(locks.garbage_collect_worker(epoch));
 
     let url = args.auth_endpoint.parse()?;
     let endpoint = http::Endpoint::new(url, http::new_client(rate_limiter_config));
diff --git a/proxy/src/cache.rs b/proxy/src/cache.rs
index d1d4087241..fc5f416395 100644
--- a/proxy/src/cache.rs
+++ b/proxy/src/cache.rs
@@ -1,5 +1,4 @@
 pub mod common;
-pub mod endpoints;
 pub mod project_info;
 mod timed_lru;
 
diff --git a/proxy/src/cache/endpoints.rs b/proxy/src/cache/endpoints.rs
deleted file mode 100644
index 31e3ef6891..0000000000
--- a/proxy/src/cache/endpoints.rs
+++ /dev/null
@@ -1,190 +0,0 @@
-use std::{
-    convert::Infallible,
-    sync::{
-        atomic::{AtomicBool, Ordering},
-        Arc,
-    },
-};
-
-use dashmap::DashSet;
-use redis::{
-    streams::{StreamReadOptions, StreamReadReply},
-    AsyncCommands, FromRedisValue, Value,
-};
-use serde::Deserialize;
-use tokio::sync::Mutex;
-
-use crate::{
-    config::EndpointCacheConfig,
-    context::RequestMonitoring,
-    intern::{BranchIdInt, EndpointIdInt, ProjectIdInt},
-    metrics::REDIS_BROKEN_MESSAGES,
-    rate_limiter::GlobalRateLimiter,
-    redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider,
-    EndpointId,
-};
-
-#[derive(Deserialize, Debug, Clone)]
-#[serde(rename_all(deserialize = "snake_case"))]
-pub enum ControlPlaneEventKey {
-    EndpointCreated,
-    BranchCreated,
-    ProjectCreated,
-}
-
-pub struct EndpointsCache {
-    config: EndpointCacheConfig,
-    endpoints: DashSet<EndpointIdInt>,
-    branches: DashSet<BranchIdInt>,
-    projects: DashSet<ProjectIdInt>,
-    ready: AtomicBool,
-    limiter: Arc<Mutex<GlobalRateLimiter>>,
-}
-
-impl EndpointsCache {
-    pub fn new(config: EndpointCacheConfig) -> Self {
-        Self {
-            limiter: Arc::new(Mutex::new(GlobalRateLimiter::new(
-                config.limiter_info.clone(),
-            ))),
-            config,
-            endpoints: DashSet::new(),
-            branches: DashSet::new(),
-            projects: DashSet::new(),
-            ready: AtomicBool::new(false),
-        }
-    }
-    pub async fn is_valid(&self, ctx: &mut RequestMonitoring, endpoint: &EndpointId) -> bool {
-        if !self.ready.load(Ordering::Acquire) {
-            return true;
-        }
-        // If cache is disabled, just collect the metrics and return.
-        if self.config.disable_cache {
-            ctx.set_rejected(self.should_reject(endpoint));
-            return true;
-        }
-        // If the limiter allows, we don't need to check the cache.
-        if self.limiter.lock().await.check() {
-            return true;
-        }
-        let rejected = self.should_reject(endpoint);
-        ctx.set_rejected(rejected);
-        !rejected
-    }
-    fn should_reject(&self, endpoint: &EndpointId) -> bool {
-        if endpoint.is_endpoint() {
-            !self.endpoints.contains(&EndpointIdInt::from(endpoint))
-        } else if endpoint.is_branch() {
-            !self
-                .branches
-                .contains(&BranchIdInt::from(&endpoint.as_branch()))
-        } else {
-            !self
-                .projects
-                .contains(&ProjectIdInt::from(&endpoint.as_project()))
-        }
-    }
-    fn insert_event(&self, key: ControlPlaneEventKey, value: String) {
-        // Do not do normalization here, we expect the events to be normalized.
-        match key {
-            ControlPlaneEventKey::EndpointCreated => {
-                self.endpoints.insert(EndpointIdInt::from(&value.into()));
-            }
-            ControlPlaneEventKey::BranchCreated => {
-                self.branches.insert(BranchIdInt::from(&value.into()));
-            }
-            ControlPlaneEventKey::ProjectCreated => {
-                self.projects.insert(ProjectIdInt::from(&value.into()));
-            }
-        }
-    }
-    pub async fn do_read(
-        &self,
-        mut con: ConnectionWithCredentialsProvider,
-    ) -> anyhow::Result<Infallible> {
-        let mut last_id = "0-0".to_string();
-        loop {
-            self.ready.store(false, Ordering::Release);
-            if let Err(e) = con.connect().await {
-                tracing::error!("error connecting to redis: {:?}", e);
-                continue;
-            }
-            if let Err(e) = self.read_from_stream(&mut con, &mut last_id).await {
-                tracing::error!("error reading from redis: {:?}", e);
-            }
-        }
-    }
-    async fn read_from_stream(
-        &self,
-        con: &mut ConnectionWithCredentialsProvider,
-        last_id: &mut String,
-    ) -> anyhow::Result<()> {
-        tracing::info!("reading endpoints/branches/projects from redis");
-        self.batch_read(
-            con,
-            StreamReadOptions::default().count(self.config.initial_batch_size),
-            last_id,
-            true,
-        )
-        .await?;
-        tracing::info!("ready to filter user requests");
-        self.ready.store(true, Ordering::Release);
-        self.batch_read(
-            con,
-            StreamReadOptions::default()
-                .count(self.config.initial_batch_size)
-                .block(self.config.xread_timeout.as_millis() as usize),
-            last_id,
-            false,
-        )
-        .await
-    }
-    fn parse_key_value(key: &str, value: &Value) -> anyhow::Result<(ControlPlaneEventKey, String)> {
-        Ok((serde_json::from_str(key)?, String::from_redis_value(value)?))
-    }
-    async fn batch_read(
-        &self,
-        conn: &mut ConnectionWithCredentialsProvider,
-        opts: StreamReadOptions,
-        last_id: &mut String,
-        return_when_finish: bool,
-    ) -> anyhow::Result<()> {
-        let mut total: usize = 0;
-        loop {
-            let mut res: StreamReadReply = conn
-                .xread_options(&[&self.config.stream_name], &[last_id.as_str()], &opts)
-                .await?;
-            if res.keys.len() != 1 {
-                anyhow::bail!("Cannot read from redis stream {}", self.config.stream_name);
-            }
-
-            let res = res.keys.pop().expect("Checked length above");
-
-            if return_when_finish && res.ids.len() <= self.config.default_batch_size {
-                break;
-            }
-            for x in res.ids {
-                total += 1;
-                for (k, v) in x.map {
-                    let (key, value) = match Self::parse_key_value(&k, &v) {
-                        Ok(x) => x,
-                        Err(e) => {
-                            REDIS_BROKEN_MESSAGES
-                                .with_label_values(&[&self.config.stream_name])
-                                .inc();
-                            tracing::error!("error parsing key-value {k}-{v:?}: {e:?}");
-                            continue;
-                        }
-                    };
-                    self.insert_event(key, value);
-                }
-                if total.is_power_of_two() {
-                    tracing::debug!("endpoints read {}", total);
-                }
-                *last_id = x.id;
-            }
-        }
-        tracing::info!("read {} endpoints/branches/projects from redis", total);
-        Ok(())
-    }
-}
diff --git a/proxy/src/config.rs b/proxy/src/config.rs
index 3bdfb3cfad..fc490c7348 100644
--- a/proxy/src/config.rs
+++ b/proxy/src/config.rs
@@ -313,75 +313,6 @@ impl CertResolver {
     }
 }
 
-#[derive(Debug)]
-pub struct EndpointCacheConfig {
-    /// Batch size to receive all endpoints on the startup.
-    pub initial_batch_size: usize,
-    /// Batch size to receive endpoints.
-    pub default_batch_size: usize,
-    /// Timeouts for the stream read operation.
-    pub xread_timeout: Duration,
-    /// Stream name to read from.
-    pub stream_name: String,
-    /// Limiter info (to distinguish when to enable cache).
-    pub limiter_info: Vec<RateBucketInfo>,
-    /// Disable cache.
-    /// If true, cache is ignored, but reports all statistics.
-    pub disable_cache: bool,
-}
-
-impl EndpointCacheConfig {
-    /// Default options for [`crate::console::provider::NodeInfoCache`].
-    /// Notice that by default the limiter is empty, which means that cache is disabled.
-    pub const CACHE_DEFAULT_OPTIONS: &'static str =
-        "initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s";
-
-    /// Parse cache options passed via cmdline.
-    /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
-    fn parse(options: &str) -> anyhow::Result<Self> {
-        let mut initial_batch_size = None;
-        let mut default_batch_size = None;
-        let mut xread_timeout = None;
-        let mut stream_name = None;
-        let mut limiter_info = vec![];
-        let mut disable_cache = false;
-
-        for option in options.split(',') {
-            let (key, value) = option
-                .split_once('=')
-                .with_context(|| format!("bad key-value pair: {option}"))?;
-
-            match key {
-                "initial_batch_size" => initial_batch_size = Some(value.parse()?),
-                "default_batch_size" => default_batch_size = Some(value.parse()?),
-                "xread_timeout" => xread_timeout = Some(humantime::parse_duration(value)?),
-                "stream_name" => stream_name = Some(value.to_string()),
-                "limiter_info" => limiter_info.push(RateBucketInfo::from_str(value)?),
-                "disable_cache" => disable_cache = value.parse()?,
-                unknown => bail!("unknown key: {unknown}"),
-            }
-        }
-        RateBucketInfo::validate(&mut limiter_info)?;
-
-        Ok(Self {
-            initial_batch_size: initial_batch_size.context("missing `initial_batch_size`")?,
-            default_batch_size: default_batch_size.context("missing `default_batch_size`")?,
-            xread_timeout: xread_timeout.context("missing `xread_timeout`")?,
-            stream_name: stream_name.context("missing `stream_name`")?,
-            disable_cache,
-            limiter_info,
-        })
-    }
-}
-
-impl FromStr for EndpointCacheConfig {
-    type Err = anyhow::Error;
-
-    fn from_str(options: &str) -> Result<Self, Self::Err> {
-        let error = || format!("failed to parse endpoint cache options '{options}'");
-        Self::parse(options).with_context(error)
-    }
-}
 #[derive(Debug)]
 pub struct MetricBackupCollectionConfig {
     pub interval: Duration,
diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs
index ee2bc866ab..f7d621fb12 100644
--- a/proxy/src/console/provider.rs
+++ b/proxy/src/console/provider.rs
@@ -8,15 +8,15 @@ use crate::{
         backend::{ComputeCredentialKeys, ComputeUserInfo},
         IpPattern,
     },
-    cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
+    cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
     compute,
-    config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions},
+    config::{CacheOptions, ProjectInfoCacheOptions},
     context::RequestMonitoring,
     intern::ProjectIdInt,
     scram, EndpointCacheKey,
 };
 use dashmap::DashMap;
-use std::{convert::Infallible, sync::Arc, time::Duration};
+use std::{sync::Arc, time::Duration};
 use tokio::sync::{OwnedSemaphorePermit, Semaphore};
 use tokio::time::Instant;
 use tracing::info;
@@ -416,15 +416,12 @@ pub struct ApiCaches {
     pub node_info: NodeInfoCache,
     /// Cache which stores project_id -> endpoint_ids mapping.
     pub project_info: Arc<ProjectInfoCacheImpl>,
-    /// List of all valid endpoints.
-    pub endpoints_cache: Arc<EndpointsCache>,
 }
 
 impl ApiCaches {
     pub fn new(
         wake_compute_cache_config: CacheOptions,
         project_info_cache_config: ProjectInfoCacheOptions,
-        endpoint_cache_config: EndpointCacheConfig,
     ) -> Self {
         Self {
             node_info: NodeInfoCache::new(
@@ -434,7 +431,6 @@ impl ApiCaches {
                 true,
             ),
             project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
-            endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)),
         }
     }
 }
@@ -445,7 +441,6 @@ pub struct ApiLocks {
     node_locks: DashMap<EndpointCacheKey, Arc<Semaphore>>,
     permits: usize,
     timeout: Duration,
-    epoch: std::time::Duration,
     registered: prometheus::IntCounter,
     unregistered: prometheus::IntCounter,
     reclamation_lag: prometheus::Histogram,
@@ -458,7 +453,6 @@ impl ApiLocks {
         permits: usize,
         shards: usize,
         timeout: Duration,
-        epoch: std::time::Duration,
     ) -> prometheus::Result<Self> {
         let registered = prometheus::IntCounter::with_opts(
             prometheus::Opts::new(
@@ -503,7 +497,6 @@ impl ApiLocks {
             node_locks: DashMap::with_shard_amount(shards),
             permits,
             timeout,
-            epoch,
             lock_acquire_lag,
             registered,
             unregistered,
@@ -543,9 +536,12 @@ impl ApiLocks {
         })
     }
 
-    pub async fn garbage_collect_worker(&self) -> anyhow::Result<Infallible> {
-        let mut interval =
-            tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
+    pub async fn garbage_collect_worker(&self, epoch: std::time::Duration) {
+        if self.permits == 0 {
+            return;
+        }
+
+        let mut interval = tokio::time::interval(epoch / (self.node_locks.shards().len()) as u32);
         loop {
             for (i, shard) in self.node_locks.shards().iter().enumerate() {
                 interval.tick().await;
diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs
index 68b91447f9..1a3e2ca795 100644
--- a/proxy/src/console/provider/neon.rs
+++ b/proxy/src/console/provider/neon.rs
@@ -8,7 +8,6 @@ use super::{
 };
 use crate::{
     auth::backend::ComputeUserInfo, compute, console::messages::ColdStartInfo, http, scram,
-    Normalize,
 };
 use crate::{
     cache::Cached,
@@ -24,7 +23,7 @@ use tracing::{error, info, info_span, warn, Instrument};
 
 pub struct Api {
     endpoint: http::Endpoint,
     pub caches: &'static ApiCaches,
-    pub locks: &'static ApiLocks,
+    locks: &'static ApiLocks,
     jwt: String,
 }
@@ -56,15 +55,6 @@ impl Api {
         ctx: &mut RequestMonitoring,
         user_info: &ComputeUserInfo,
     ) -> Result<AuthInfo, GetAuthInfoError> {
-        if !self
-            .caches
-            .endpoints_cache
-            .is_valid(ctx, &user_info.endpoint.normalize())
-            .await
-        {
-            info!("endpoint is not valid, skipping the request");
-            return Ok(AuthInfo::default());
-        }
         let request_id = ctx.session_id.to_string();
         let application_name = ctx.console_application_name();
         async {
@@ -91,9 +81,7 @@ impl Api {
                 Ok(body) => body,
                 // Error 404 is special: it's ok not to have a secret.
                 Err(e) => match e.http_status_code() {
-                    Some(http::StatusCode::NOT_FOUND) => {
-                        return Ok(AuthInfo::default());
-                    }
+                    Some(http::StatusCode::NOT_FOUND) => return Ok(AuthInfo::default()),
                     _otherwise => return Err(e.into()),
                 },
             };
@@ -186,27 +174,23 @@ impl super::Api for Api {
         ctx: &mut RequestMonitoring,
         user_info: &ComputeUserInfo,
     ) -> Result<CachedRoleSecret, GetAuthInfoError> {
-        let normalized_ep = &user_info.endpoint.normalize();
+        let ep = &user_info.endpoint;
         let user = &user_info.user;
-        if let Some(role_secret) = self
-            .caches
-            .project_info
-            .get_role_secret(normalized_ep, user)
-        {
+        if let Some(role_secret) = self.caches.project_info.get_role_secret(ep, user) {
             return Ok(role_secret);
         }
         let auth_info = self.do_get_auth_info(ctx, user_info).await?;
         if let Some(project_id) = auth_info.project_id {
-            let normalized_ep_int = normalized_ep.into();
+            let ep_int = ep.into();
             self.caches.project_info.insert_role_secret(
                 project_id,
-                normalized_ep_int,
+                ep_int,
                 user.into(),
                 auth_info.secret.clone(),
             );
             self.caches.project_info.insert_allowed_ips(
                 project_id,
-                normalized_ep_int,
+                ep_int,
                 Arc::new(auth_info.allowed_ips),
             );
             ctx.set_project_id(project_id);
@@ -220,8 +204,8 @@ impl super::Api for Api {
         ctx: &mut RequestMonitoring,
         user_info: &ComputeUserInfo,
     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
-        let normalized_ep = &user_info.endpoint.normalize();
-        if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) {
+        let ep = &user_info.endpoint;
+        if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(ep) {
             ALLOWED_IPS_BY_CACHE_OUTCOME
                 .with_label_values(&["hit"])
                 .inc();
@@ -234,18 +218,16 @@ impl super::Api for Api {
         let allowed_ips = Arc::new(auth_info.allowed_ips);
         let user = &user_info.user;
         if let Some(project_id) = auth_info.project_id {
-            let normalized_ep_int = normalized_ep.into();
+            let ep_int = ep.into();
             self.caches.project_info.insert_role_secret(
                 project_id,
-                normalized_ep_int,
+                ep_int,
                 user.into(),
                 auth_info.secret.clone(),
             );
-            self.caches.project_info.insert_allowed_ips(
-                project_id,
-                normalized_ep_int,
-                allowed_ips.clone(),
-            );
+            self.caches
+                .project_info
+                .insert_allowed_ips(project_id, ep_int, allowed_ips.clone());
             ctx.set_project_id(project_id);
         }
         Ok((
diff --git a/proxy/src/context.rs b/proxy/src/context.rs
index 85544f1d65..fec95f4722 100644
--- a/proxy/src/context.rs
+++ b/proxy/src/context.rs
@@ -12,9 +12,7 @@ use crate::{
     console::messages::{ColdStartInfo, MetricsAuxInfo},
     error::ErrorKind,
     intern::{BranchIdInt, ProjectIdInt},
-    metrics::{
-        bool_to_str, LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND, NUM_INVALID_ENDPOINTS,
-    },
+    metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND},
     DbName, EndpointId, RoleName,
 };
 
@@ -52,8 +50,6 @@ pub struct RequestMonitoring {
     // This sender is here to keep the request monitoring channel open while requests are taking place.
     sender: Option<mpsc::UnboundedSender<RequestData>>,
     pub latency_timer: LatencyTimer,
-    // Whether proxy decided that it's not a valid endpoint end rejected it before going to cplane.
-    rejected: bool,
 }
 
 #[derive(Clone, Debug)]
@@ -97,7 +93,6 @@ impl RequestMonitoring {
             error_kind: None,
             auth_method: None,
             success: false,
-            rejected: false,
             cold_start_info: ColdStartInfo::Unknown,
 
             sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
@@ -118,10 +113,6 @@ impl RequestMonitoring {
         )
     }
 
-    pub fn set_rejected(&mut self, rejected: bool) {
-        self.rejected = rejected;
-    }
-
     pub fn set_cold_start_info(&mut self, info: ColdStartInfo) {
         self.cold_start_info = info;
         self.latency_timer.cold_start_info(info);
@@ -187,10 +178,6 @@ impl RequestMonitoring {
 
 impl Drop for RequestMonitoring {
     fn drop(&mut self) {
-        let outcome = if self.success { "success" } else { "failure" };
-        NUM_INVALID_ENDPOINTS
-            .with_label_values(&[self.protocol, bool_to_str(self.rejected), outcome])
-            .inc();
         if let Some(tx) = self.sender.take() {
             let _: Result<(), _> = tx.send(RequestData::from(&*self));
         }
diff --git a/proxy/src/intern.rs b/proxy/src/intern.rs
index e38135dd22..a6519bdff9 100644
--- a/proxy/src/intern.rs
+++ b/proxy/src/intern.rs
@@ -160,11 +160,6 @@ impl From<&EndpointId> for EndpointIdInt {
         EndpointIdTag::get_interner().get_or_intern(value)
     }
 }
-impl From<EndpointId> for EndpointIdInt {
-    fn from(value: EndpointId) -> Self {
-        EndpointIdTag::get_interner().get_or_intern(&value)
-    }
-}
 
 #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
 pub struct BranchIdTag;
@@ -180,11 +175,6 @@ impl From<&BranchId> for BranchIdInt {
         BranchIdTag::get_interner().get_or_intern(value)
     }
 }
-impl From<BranchId> for BranchIdInt {
-    fn from(value: BranchId) -> Self {
-        BranchIdTag::get_interner().get_or_intern(&value)
-    }
-}
 
 #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
 pub struct ProjectIdTag;
@@ -200,11 +190,6 @@ impl From<&ProjectId> for ProjectIdInt {
         ProjectIdTag::get_interner().get_or_intern(value)
     }
 }
-impl From<ProjectId> for ProjectIdInt {
-    fn from(value: ProjectId) -> Self {
-        ProjectIdTag::get_interner().get_or_intern(&value)
-    }
-}
 
 #[cfg(test)]
 mod tests {
diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs
index 3f6d985fe8..da7c7f3ed2 100644
--- a/proxy/src/lib.rs
+++ b/proxy/src/lib.rs
@@ -127,24 +127,6 @@ macro_rules! smol_str_wrapper {
     };
 }
 
-const POOLER_SUFFIX: &str = "-pooler";
-
-pub trait Normalize {
-    fn normalize(&self) -> Self;
-}
-
-impl<S: Clone + AsRef<str> + From<String>> Normalize for S {
-    fn normalize(&self) -> Self {
-        if self.as_ref().ends_with(POOLER_SUFFIX) {
-            let mut s = self.as_ref().to_string();
-            s.truncate(s.len() - POOLER_SUFFIX.len());
-            s.into()
-        } else {
-            self.clone()
-        }
-    }
-}
-
 // 90% of role name strings are 20 characters or less.
 smol_str_wrapper!(RoleName);
 // 50% of endpoint strings are 23 characters or less.
@@ -158,22 +140,3 @@ smol_str_wrapper!(ProjectId);
 
 smol_str_wrapper!(EndpointCacheKey);
 smol_str_wrapper!(DbName);
-
-// Endpoints are a bit tricky. Rare they might be branches or projects.
-impl EndpointId {
-    pub fn is_endpoint(&self) -> bool {
-        self.0.starts_with("ep-")
-    }
-    pub fn is_branch(&self) -> bool {
-        self.0.starts_with("br-")
-    }
-    pub fn is_project(&self) -> bool {
-        !self.is_endpoint() && !self.is_branch()
-    }
-    pub fn as_branch(&self) -> BranchId {
-        BranchId(self.0.clone())
-    }
-    pub fn as_project(&self) -> ProjectId {
-        ProjectId(self.0.clone())
-    }
-}
diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs
index f299313e0a..59ee899c08 100644
--- a/proxy/src/metrics.rs
+++ b/proxy/src/metrics.rs
@@ -169,18 +169,6 @@ pub static NUM_CANCELLATION_REQUESTS: Lazy<IntCounterVec> = Lazy::new(|| {
     .unwrap()
 });
 
-pub static NUM_INVALID_ENDPOINTS: Lazy<IntCounterVec> = Lazy::new(|| {
-    register_int_counter_vec!(
-        "proxy_invalid_endpoints_total",
-        "Number of invalid endpoints (per protocol, per rejected).",
-        // http/ws/tcp, true/false, success/failure
-        // TODO(anna): the last dimension is just a proxy to what we actually want to measure.
-        // We need to measure whether the endpoint was found by cplane or not.
-        &["protocol", "rejected", "outcome"],
-    )
-    .unwrap()
-});
-
 pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT: &str = "from_client";
 pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS: &str = "from_redis";
diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs
index 166e761a4e..6051c0a812 100644
--- a/proxy/src/proxy.rs
+++ b/proxy/src/proxy.rs
@@ -20,7 +20,7 @@ use crate::{
     proxy::handshake::{handshake, HandshakeData},
     rate_limiter::EndpointRateLimiter,
     stream::{PqStream, Stream},
-    EndpointCacheKey, Normalize,
+    EndpointCacheKey,
 };
 use futures::TryFutureExt;
 use itertools::Itertools;
@@ -280,7 +280,7 @@ pub async fn handle_client(
 
     // check rate limit
     if let Some(ep) = user_info.get_endpoint() {
-        if !endpoint_rate_limiter.check(ep.normalize(), 1) {
+        if !endpoint_rate_limiter.check(ep, 1) {
             return stream
                 .throw_error(auth::AuthError::too_many_connections())
                 .await?;
diff --git a/proxy/src/rate_limiter.rs b/proxy/src/rate_limiter.rs
index a3b83e5e50..13dffffca0 100644
--- a/proxy/src/rate_limiter.rs
+++ b/proxy/src/rate_limiter.rs
@@ -4,4 +4,4 @@ mod limiter;
 pub use aimd::Aimd;
 pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig};
 pub use limiter::Limiter;
-pub use limiter::{AuthRateLimiter, EndpointRateLimiter, GlobalRateLimiter, RateBucketInfo};
+pub use limiter::{AuthRateLimiter, EndpointRateLimiter, RateBucketInfo, RedisRateLimiter};
diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs
index 0503deb311..f590896dd9 100644
--- a/proxy/src/rate_limiter/limiter.rs
+++ b/proxy/src/rate_limiter/limiter.rs
@@ -24,13 +24,13 @@ use super::{
     RateLimiterConfig,
 };
 
-pub struct GlobalRateLimiter {
+pub struct RedisRateLimiter {
     data: Vec<RateBucket>,
-    info: Vec<RateBucketInfo>,
+    info: &'static [RateBucketInfo],
 }
 
-impl GlobalRateLimiter {
-    pub fn new(info: Vec<RateBucketInfo>) -> Self {
+impl RedisRateLimiter {
+    pub fn new(info: &'static [RateBucketInfo]) -> Self {
         Self {
             data: vec![
                 RateBucket {
@@ -50,7 +50,7 @@
         let should_allow_request = self
             .data
             .iter_mut()
-            .zip(&self.info)
+            .zip(self.info)
             .all(|(bucket, info)| bucket.should_allow_request(info, now, 1));
 
         if should_allow_request {
diff --git a/proxy/src/redis/cancellation_publisher.rs b/proxy/src/redis/cancellation_publisher.rs
index 7baf104374..422789813c 100644
--- a/proxy/src/redis/cancellation_publisher.rs
+++ b/proxy/src/redis/cancellation_publisher.rs
@@ -5,7 +5,7 @@ use redis::AsyncCommands;
 use tokio::sync::Mutex;
 use uuid::Uuid;
 
-use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};
+use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter};
 
 use super::{
     connection_with_credentials_provider::ConnectionWithCredentialsProvider,
@@ -80,7 +80,7 @@ impl CancellationPublisher for Arc<Mutex<RedisPublisherClient>> {
 
 pub struct RedisPublisherClient {
     client: ConnectionWithCredentialsProvider,
     region_id: String,
-    limiter: GlobalRateLimiter,
+    limiter: RedisRateLimiter,
 }
 
 impl RedisPublisherClient {
@@ -92,7 +92,7 @@ impl RedisPublisherClient {
         Ok(Self {
             client,
             region_id,
-            limiter: GlobalRateLimiter::new(info.into()),
+            limiter: RedisRateLimiter::new(info),
         })
     }
 
diff --git a/test_runner/regress/test_proxy_rate_limiter.py b/test_runner/regress/test_proxy_rate_limiter.py
new file mode 100644
index 0000000000..f39f0cad07
--- /dev/null
+++ b/test_runner/regress/test_proxy_rate_limiter.py
@@ -0,0 +1,84 @@
+import asyncio
+import time
+from pathlib import Path
+from typing import Iterator
+
+import pytest
+from fixtures.neon_fixtures import (
+    PSQL,
+    NeonProxy,
+)
+from fixtures.port_distributor import PortDistributor
+from pytest_httpserver import HTTPServer
+from werkzeug.wrappers.response import Response
+
+
+def waiting_handler(status_code: int) -> Response:
+    # Wait for more than the timeout to make sure that both (two) connections are open.
+    # It would be better to use a barrier here, but I don't know how to do that together with pytest-httpserver.
+    time.sleep(2)
+    return Response(status=status_code)
+
+
+@pytest.fixture(scope="function")
+def proxy_with_rate_limit(
+    port_distributor: PortDistributor,
+    neon_binpath: Path,
+    httpserver_listen_address,
+    test_output_dir: Path,
+) -> Iterator[NeonProxy]:
+    """Neon proxy that routes directly to vanilla postgres."""
+
+    proxy_port = port_distributor.get_port()
+    mgmt_port = port_distributor.get_port()
+    http_port = port_distributor.get_port()
+    external_http_port = port_distributor.get_port()
+    (host, port) = httpserver_listen_address
+    endpoint = f"http://{host}:{port}/billing/api/v1/usage_events"
+
+    with NeonProxy(
+        neon_binpath=neon_binpath,
+        test_output_dir=test_output_dir,
+        proxy_port=proxy_port,
+        http_port=http_port,
+        mgmt_port=mgmt_port,
+        external_http_port=external_http_port,
+        auth_backend=NeonProxy.Console(endpoint, fixed_rate_limit=5),
+    ) as proxy:
+        proxy.start()
+        yield proxy
+
+
+@pytest.mark.asyncio
+async def test_proxy_rate_limit(
+    httpserver: HTTPServer,
+    proxy_with_rate_limit: NeonProxy,
+):
+    uri = "/billing/api/v1/usage_events/proxy_get_role_secret"
+    # mock control plane service
+    httpserver.expect_ordered_request(uri, method="GET").respond_with_handler(
+        lambda _: Response(status=200)
+    )
+    httpserver.expect_ordered_request(uri, method="GET").respond_with_handler(
+        lambda _: waiting_handler(429)
+    )
+    httpserver.expect_ordered_request(uri, method="GET").respond_with_handler(
+        lambda _: waiting_handler(500)
+    )
+
+    psql = PSQL(host=proxy_with_rate_limit.host, port=proxy_with_rate_limit.proxy_port)
+    f = await psql.run("select 42;")
+    await proxy_with_rate_limit.find_auth_link(uri, f)
+    # Limit should be 2.
+
+    # Run two queries in parallel.
+    f1, f2 = await asyncio.gather(psql.run("select 42;"), psql.run("select 42;"))
+    await proxy_with_rate_limit.find_auth_link(uri, f1)
+    await proxy_with_rate_limit.find_auth_link(uri, f2)
+
+    # Now limit should be 0.
+    f = await psql.run("select 42;")
+    await proxy_with_rate_limit.find_auth_link(uri, f)
+
+    # The last query shouldn't reach the http-server.
+    assert httpserver.assertions == []