diff --git a/Cargo.lock b/Cargo.lock
index 4cada013d7..a083af020a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3519,6 +3519,7 @@ dependencies = [
  "sha2",
  "socket2 0.5.3",
  "sync_wrapper",
+ "task-local-extensions",
  "thiserror",
  "tls-listener",
  "tokio",
diff --git a/Cargo.toml b/Cargo.toml
index 363d4c6fe4..bfdb0442ab 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -136,6 +136,7 @@ strum_macros = "0.24"
 svg_fmt = "0.4.1"
 sync_wrapper = "0.1.2"
 tar = "0.4"
+task-local-extensions = "0.1.4"
 test-context = "0.1"
 thiserror = "1.0"
 tls-listener = { version = "0.7", features = ["rustls", "hyper-h1"] }
diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml
index 92498d3ecd..0ec7efd316 100644
--- a/proxy/Cargo.toml
+++ b/proxy/Cargo.toml
@@ -51,6 +51,7 @@ serde_json.workspace = true
 sha2.workspace = true
 socket2.workspace = true
 sync_wrapper.workspace = true
+task-local-extensions.workspace = true
 thiserror.workspace = true
 tls-listener.workspace = true
 tokio-postgres.workspace = true
diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs
index 7d1b7eaaae..570cf0943a 100644
--- a/proxy/src/bin/proxy.rs
+++ b/proxy/src/bin/proxy.rs
@@ -4,6 +4,7 @@ use proxy::config::AuthenticationConfig;
 use proxy::config::HttpConfig;
 use proxy::console;
 use proxy::http;
+use proxy::rate_limiter::RateLimiterConfig;
 use proxy::usage_metrics;
 
 use anyhow::bail;
@@ -95,6 +96,20 @@ struct ProxyCliArgs {
     /// Require that all incoming requests have a Proxy Protocol V2 packet **and** have an IP address associated.
     #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
     require_client_ip: bool,
+    /// Disable the dynamic rate limiter while still recording its metrics, to validate its behaviour in production.
+    #[clap(long, default_value_t = true, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
+    disable_dynamic_rate_limiter: bool,
+    /// Rate limit algorithm. Makes sense only if `disable_dynamic_rate_limiter` is `false`.
+    #[clap(value_enum, long, default_value_t = proxy::rate_limiter::RateLimitAlgorithm::Aimd)]
+    rate_limit_algorithm: proxy::rate_limiter::RateLimitAlgorithm,
+    /// Timeout for the rate limiter. If a permit cannot be acquired within this time, an error is returned.
+    #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
+    rate_limiter_timeout: tokio::time::Duration,
+    /// Initial limit for the dynamic rate limiter. Makes sense only if `disable_dynamic_rate_limiter` is `false`.
+    #[clap(long, default_value_t = 100)]
+    initial_limit: usize,
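+    // Illustrative invocation enabling the limiter (example values, matching the
+    // defaults above rather than anything prescribed by this change):
+    //   proxy --disable-dynamic-rate-limiter false --rate-limit-algorithm aimd \
+    //         --rate-limiter-timeout 15s --initial-limit 100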
+    #[clap(flatten)]
+    aimd_config: proxy::rate_limiter::AimdConfig,
 }
 
 #[tokio::main]
@@ -213,6 +228,13 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
             and metric-collection-interval must be specified"
         ),
     };
+    let rate_limiter_config = RateLimiterConfig {
+        disable: args.disable_dynamic_rate_limiter,
+        algorithm: args.rate_limit_algorithm,
+        timeout: args.rate_limiter_timeout,
+        initial_limit: args.initial_limit,
+        aimd_config: Some(args.aimd_config),
+    };
 
     let auth_backend = match &args.auth_backend {
         AuthBackend::Console => {
@@ -237,7 +259,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
             tokio::spawn(locks.garbage_collect_worker(epoch));
 
             let url = args.auth_endpoint.parse()?;
-            let endpoint = http::Endpoint::new(url, http::new_client());
+            let endpoint = http::Endpoint::new(url, http::new_client(rate_limiter_config));
 
             let api = console::provider::neon::Api::new(endpoint, caches, locks);
             auth::BackendType::Console(Cow::Owned(api), ())
diff --git a/proxy/src/http.rs b/proxy/src/http.rs
index 14a9072a45..159b949da3 100644
--- a/proxy/src/http.rs
+++ b/proxy/src/http.rs
@@ -13,13 +13,13 @@ pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
 use tokio::time::Instant;
 use tracing::trace;
 
-use crate::url::ApiUrl;
+use crate::{rate_limiter, url::ApiUrl};
 use reqwest_middleware::RequestBuilder;
 
 /// This is the preferred way to create new http clients,
 /// because it takes care of observability (OpenTelemetry).
 /// We deliberately don't want to replace this with a public static.
-pub fn new_client() -> ClientWithMiddleware {
+pub fn new_client(rate_limiter_config: rate_limiter::RateLimiterConfig) -> ClientWithMiddleware {
     let client = reqwest::ClientBuilder::new()
         .dns_resolver(Arc::new(GaiResolver::default()))
         .connection_verbose(true)
@@ -28,6 +28,7 @@ pub fn new_client() -> ClientWithMiddleware {
 
     reqwest_middleware::ClientBuilder::new(client)
         .with(reqwest_tracing::TracingMiddleware::default())
+        .with(rate_limiter::Limiter::new(rate_limiter_config))
         .build()
 }
diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs
index 4a95e473f6..a22600cbb3 100644
--- a/proxy/src/lib.rs
+++ b/proxy/src/lib.rs
@@ -19,6 +19,7 @@ pub mod logging;
 pub mod parse;
 pub mod protocol2;
 pub mod proxy;
+pub mod rate_limiter;
 pub mod sasl;
 pub mod scram;
 pub mod serverless;
diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs
index a1ebf03545..2d38acd05b 100644
--- a/proxy/src/proxy.rs
+++ b/proxy/src/proxy.rs
@@ -19,7 +19,10 @@ use itertools::Itertools;
 use metrics::{exponential_buckets, register_int_counter_vec, IntCounterVec};
 use once_cell::sync::{Lazy, OnceCell};
 use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
-use prometheus::{register_histogram_vec, HistogramVec};
+use prometheus::{
+    register_histogram, register_histogram_vec, register_int_gauge_vec, Histogram, HistogramVec,
+    IntGaugeVec,
+};
 use regex::Regex;
 use std::{error::Error, io, ops::ControlFlow, sync::Arc, time::Instant};
 use tokio::{
@@ -107,6 +110,25 @@ static COMPUTE_CONNECTION_LATENCY: Lazy<HistogramVec> = Lazy::new(|| {
     .unwrap()
 });
 
+pub static RATE_LIMITER_ACQUIRE_LATENCY: Lazy<Histogram> = Lazy::new(|| {
+    register_histogram!(
+        "semaphore_control_plane_token_acquire_seconds",
+        "Time it took for proxy to acquire a token from the control plane rate limiter",
+        // largest finite bucket = 0.0005s * 2^15 ≈ 16s
+        exponential_buckets(0.0005, 2.0, 16).unwrap(),
+    )
+    .unwrap()
+});
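+
+// For reference, exponential_buckets(0.0005, 2.0, 16) yields finite upper bounds
+// 0.5ms, 1ms, 2ms, ..., 16.384s (0.0005 * 2^15); slower acquires land in +Inf.
+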
+pub static RATE_LIMITER_LIMIT: Lazy<IntGaugeVec> = Lazy::new(|| {
+    register_int_gauge_vec!(
+        "semaphore_control_plane_limit",
+        "Current limit of the semaphore control plane",
+        &["limit"], // two label values are used: "expected" and "actual"
+    )
+    .unwrap()
+});
+
 pub struct LatencyTimer {
     // time since the stopwatch was started
     start: Option<Instant>,
diff --git a/proxy/src/rate_limiter.rs b/proxy/src/rate_limiter.rs
new file mode 100644
index 0000000000..5622c44a68
--- /dev/null
+++ b/proxy/src/rate_limiter.rs
@@ -0,0 +1,6 @@
+mod aimd;
+mod limit_algorithm;
+mod limiter;
+pub use aimd::Aimd;
+pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig};
+pub use limiter::Limiter;
diff --git a/proxy/src/rate_limiter/aimd.rs b/proxy/src/rate_limiter/aimd.rs
new file mode 100644
index 0000000000..c6c532ae53
--- /dev/null
+++ b/proxy/src/rate_limiter/aimd.rs
@@ -0,0 +1,199 @@
+use async_trait::async_trait;
+
+use super::limit_algorithm::{AimdConfig, LimitAlgorithm, Sample};
+
+use super::limiter::Outcome;
+
+/// Loss-based congestion avoidance.
+///
+/// Additive increase, multiplicative decrease.
+///
+/// Adds available concurrency when:
+/// 1. no load-based errors are observed, and
+/// 2. the utilisation of the current limit is high.
+///
+/// Reduces available concurrency by a factor when load-based errors are detected.
+pub struct Aimd {
+    min_limit: usize,
+    max_limit: usize,
+    decrease_factor: f32,
+    increase_by: usize,
+    min_utilisation_threshold: f32,
+}
+
+impl Aimd {
+    pub fn new(config: AimdConfig) -> Self {
+        Self {
+            min_limit: config.aimd_min_limit,
+            max_limit: config.aimd_max_limit,
+            decrease_factor: config.aimd_decrease_factor,
+            increase_by: config.aimd_increase_by,
+            min_utilisation_threshold: config.aimd_min_utilisation_threshold,
+        }
+    }
+
+    pub fn decrease_factor(self, factor: f32) -> Self {
+        assert!((0.5..1.0).contains(&factor));
+        Self {
+            decrease_factor: factor,
+            ..self
+        }
+    }
+
+    pub fn increase_by(self, increase: usize) -> Self {
+        assert!(increase > 0);
+        Self {
+            increase_by: increase,
+            ..self
+        }
+    }
+
+    pub fn with_max_limit(self, max: usize) -> Self {
+        assert!(max > 0);
+        Self {
+            max_limit: max,
+            ..self
+        }
+    }
+
+    /// A threshold below which the limit won't be increased. 0.5 = 50%.
+    pub fn with_min_utilisation_threshold(self, min_util: f32) -> Self {
+        assert!(min_util > 0. && min_util < 1.);
+        Self {
+            min_utilisation_threshold: min_util,
+            ..self
+        }
+    }
+}
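+
+// Illustrative builder usage (a sketch; the values are arbitrary but satisfy the asserts):
+//   let aimd = Aimd::new(AimdConfig::default())
+//       .decrease_factor(0.5)
+//       .increase_by(5)
+//       .with_max_limit(100)
+//       .with_min_utilisation_threshold(0.5);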
+
+#[async_trait]
+impl LimitAlgorithm for Aimd {
+    async fn update(&mut self, old_limit: usize, sample: Sample) -> usize {
+        use Outcome::*;
+        match sample.outcome {
+            Success => {
+                let utilisation = sample.in_flight as f32 / old_limit as f32;
+
+                if utilisation > self.min_utilisation_threshold {
+                    let limit = old_limit + self.increase_by;
+                    limit.clamp(self.min_limit, self.max_limit)
+                } else {
+                    old_limit
+                }
+            }
+            Overload => {
+                let limit = old_limit as f32 * self.decrease_factor;
+
+                // Floor instead of round, so the limit reduces even with small numbers.
+                // E.g. round(2 * 0.9) = 2, but floor(2 * 0.9) = 1
+                let limit = limit.floor() as usize;
+
+                limit.clamp(self.min_limit, self.max_limit)
+            }
+        }
+    }
+}
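+
+// Worked example of the update rule (illustrative numbers):
+//   limit = 10, increase_by = 1, decrease_factor = 0.5, threshold = 0.5
+//   Success with 6 in flight: 6/10 = 0.6 > 0.5, so limit -> 11
+//   Success with 2 in flight: 2/10 = 0.2 <= 0.5, so limit stays 10
+//   Overload: floor(10 * 0.5) = 5, so limit -> 5 (clamped to [min_limit, max_limit])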
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use tokio::sync::Notify;
+
+    use super::*;
+
+    use crate::rate_limiter::{Limiter, RateLimiterConfig};
+
+    #[tokio::test]
+    async fn should_decrease_limit_on_overload() {
+        let config = RateLimiterConfig {
+            initial_limit: 10,
+            aimd_config: Some(AimdConfig {
+                aimd_decrease_factor: 0.5,
+                ..Default::default()
+            }),
+            disable: false,
+            ..Default::default()
+        };
+
+        let release_notifier = Arc::new(Notify::new());
+
+        let limiter = Limiter::new(config).with_release_notifier(release_notifier.clone());
+
+        let token = limiter.try_acquire().unwrap();
+        limiter.release(token, Some(Outcome::Overload)).await;
+        release_notifier.notified().await;
+        assert_eq!(limiter.state().limit(), 5, "overload: decrease");
+    }
+
+    #[tokio::test]
+    async fn should_increase_limit_on_success_when_using_gt_util_threshold() {
+        let config = RateLimiterConfig {
+            initial_limit: 4,
+            aimd_config: Some(AimdConfig {
+                aimd_decrease_factor: 0.5,
+                aimd_min_utilisation_threshold: 0.5,
+                aimd_increase_by: 1,
+                ..Default::default()
+            }),
+            disable: false,
+            ..Default::default()
+        };
+
+        let limiter = Limiter::new(config);
+
+        let token = limiter.try_acquire().unwrap();
+        let _token = limiter.try_acquire().unwrap();
+        let _token = limiter.try_acquire().unwrap();
+
+        limiter.release(token, Some(Outcome::Success)).await;
+        assert_eq!(limiter.state().limit(), 5, "success: increase");
+    }
+
+    #[tokio::test]
+    async fn should_not_change_limit_on_success_when_using_lt_util_threshold() {
+        let config = RateLimiterConfig {
+            initial_limit: 4,
+            aimd_config: Some(AimdConfig {
+                aimd_decrease_factor: 0.5,
+                aimd_min_utilisation_threshold: 0.5,
+                ..Default::default()
+            }),
+            disable: false,
+            ..Default::default()
+        };
+
+        let limiter = Limiter::new(config);
+
+        let token = limiter.try_acquire().unwrap();
+
+        limiter.release(token, Some(Outcome::Success)).await;
+        assert_eq!(
+            limiter.state().limit(),
+            4,
+            "success: ignore when < half limit"
+        );
+    }
+
+    #[tokio::test]
+    async fn should_not_change_limit_when_no_outcome() {
+        let config = RateLimiterConfig {
+            initial_limit: 10,
+            aimd_config: Some(AimdConfig {
+                aimd_decrease_factor: 0.5,
+                aimd_min_utilisation_threshold: 0.5,
+                ..Default::default()
+            }),
+            disable: false,
+            ..Default::default()
+        };
+
+        let limiter = Limiter::new(config);
+
+        let token = limiter.try_acquire().unwrap();
+        limiter.release(token, None).await;
+        assert_eq!(limiter.state().limit(), 10, "ignore");
+    }
+}
diff --git a/proxy/src/rate_limiter/limit_algorithm.rs b/proxy/src/rate_limiter/limit_algorithm.rs
new file mode 100644
index 0000000000..5cd2d5ebb7
--- /dev/null
+++ b/proxy/src/rate_limiter/limit_algorithm.rs
@@ -0,0 +1,98 @@
+//! Algorithms for controlling concurrency limits.
+use async_trait::async_trait;
+use std::time::Duration;
+
+use super::{limiter::Outcome, Aimd};
+
+/// An algorithm for controlling a concurrency limit.
+#[async_trait]
+pub trait LimitAlgorithm: Send + Sync + 'static {
+    /// Update the concurrency limit in response to a new job completion.
+    async fn update(&mut self, old_limit: usize, sample: Sample) -> usize;
+}
+
+/// The result of a job (or jobs), including the [Outcome] (loss) and latency (delay).
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct Sample {
+    pub(crate) latency: Duration,
+    /// Jobs in flight when the sample was taken.
+    pub(crate) in_flight: usize,
+    pub(crate) outcome: Outcome,
+}
+
+#[derive(Clone, Copy, Debug, Default, clap::ValueEnum)]
+pub enum RateLimitAlgorithm {
+    Fixed,
+    #[default]
+    Aimd,
+}
+
+pub struct Fixed;
+
+#[async_trait]
+impl LimitAlgorithm for Fixed {
+    async fn update(&mut self, old_limit: usize, _sample: Sample) -> usize {
+        old_limit
+    }
+}
+
+#[derive(Clone, Copy, Debug)]
+pub struct RateLimiterConfig {
+    pub disable: bool,
+    pub algorithm: RateLimitAlgorithm,
+    pub timeout: Duration,
+    pub initial_limit: usize,
+    pub aimd_config: Option<AimdConfig>,
+}
+
+impl RateLimiterConfig {
+    pub fn create_rate_limit_algorithm(self) -> Box<dyn LimitAlgorithm> {
+        match self.algorithm {
+            RateLimitAlgorithm::Fixed => Box::new(Fixed),
+            RateLimitAlgorithm::Aimd => Box::new(Aimd::new(self.aimd_config.unwrap())), // For the Aimd algorithm, the config is mandatory.
+        }
+    }
+}
+
+impl Default for RateLimiterConfig {
+    fn default() -> Self {
+        Self {
+            disable: true,
+            algorithm: RateLimitAlgorithm::Aimd,
+            timeout: Duration::from_secs(1),
+            initial_limit: 100,
+            aimd_config: Some(AimdConfig::default()),
+        }
+    }
+}
+
+#[derive(clap::Parser, Clone, Copy, Debug)]
+pub struct AimdConfig {
+    /// Minimum limit for the AIMD algorithm. Makes sense only if `rate_limit_algorithm` is `Aimd`.
+    #[clap(long, default_value_t = 1)]
+    pub aimd_min_limit: usize,
+    /// Maximum limit for the AIMD algorithm. Makes sense only if `rate_limit_algorithm` is `Aimd`.
+    #[clap(long, default_value_t = 1500)]
+    pub aimd_max_limit: usize,
+    /// Amount by which the AIMD limit is increased on success. Makes sense only if `rate_limit_algorithm` is `Aimd`.
+    #[clap(long, default_value_t = 10)]
+    pub aimd_increase_by: usize,
+    /// Factor by which the AIMD limit is multiplied on timeout/429. Makes sense only if `rate_limit_algorithm` is `Aimd`.
+    #[clap(long, default_value_t = 0.9)]
+    pub aimd_decrease_factor: f32,
+    /// A threshold below which the limit won't be increased. Makes sense only if `rate_limit_algorithm` is `Aimd`.
+    #[clap(long, default_value_t = 0.8)]
+    pub aimd_min_utilisation_threshold: f32,
+}
+
+impl Default for AimdConfig {
+    fn default() -> Self {
+        Self {
+            aimd_min_limit: 1,
+            aimd_max_limit: 1500,
+            aimd_increase_by: 10,
+            aimd_decrease_factor: 0.9,
+            aimd_min_utilisation_threshold: 0.8,
+        }
+    }
+}
diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs
new file mode 100644
index 0000000000..3a9fed3919
--- /dev/null
+++ b/proxy/src/rate_limiter/limiter.rs
@@ -0,0 +1,441 @@
+use std::{
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc,
+    },
+    time::Duration,
+};
+
+use tokio::sync::{Mutex as AsyncMutex, Semaphore, SemaphorePermit};
+use tokio::time::{timeout, Instant};
+use tracing::info;
+
+use super::{
+    limit_algorithm::{LimitAlgorithm, Sample},
+    RateLimiterConfig,
+};
+
+/// Limits the number of concurrent jobs.
+///
+/// Concurrency is limited through the use of [Token]s. Acquire a token to run a job, and release the
+/// token once the job is finished.
+///
+/// The limit will be automatically adjusted based on observed latency (delay) and/or failures
+/// caused by overload (loss).
+pub struct Limiter {
+    limit_algo: AsyncMutex<Box<dyn LimitAlgorithm>>,
+    semaphore: std::sync::Arc<Semaphore>,
+    config: RateLimiterConfig,
+
+    // ONLY WRITE WHEN LIMIT_ALGO IS LOCKED
+    limits: AtomicUsize,
+
+    // ONLY USE ATOMIC ADD/SUB
+    in_flight: Arc<AtomicUsize>,
+
+    #[cfg(test)]
+    notifier: Option<std::sync::Arc<tokio::sync::Notify>>,
+}
+
+/// A concurrency token, required to run a job.
+///
+/// Release the token back to the [Limiter] after the job is complete.
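+///
+/// A minimal usage sketch (`do_job` is a hypothetical async job, not part of this crate):
+/// ```ignore
+/// if let Some(token) = limiter.try_acquire() {
+///     let outcome = do_job().await; // returns an Outcome
+///     limiter.release(token, Some(outcome)).await;
+/// }
+/// ```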
+#[derive(Debug)]
+pub struct Token<'t> {
+    permit: Option<SemaphorePermit<'t>>,
+    start: Instant,
+    in_flight: Arc<AtomicUsize>,
+}
+
+/// A snapshot of the state of the [Limiter].
+///
+/// Not guaranteed to be consistent under high concurrency.
+#[derive(Debug, Clone, Copy)]
+pub struct LimiterState {
+    limit: usize,
+    available: usize,
+    in_flight: usize,
+}
+
+/// Whether a job succeeded or failed as a result of congestion/overload.
+///
+/// Errors not considered to be caused by overload should be ignored.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum Outcome {
+    /// The job succeeded, or failed in a way unrelated to overload.
+    Success,
+    /// The job failed because of overload, e.g. it timed out or an explicit backpressure signal
+    /// was observed.
+    Overload,
+}
+
+impl Outcome {
+    fn from_reqwest_error(error: &reqwest_middleware::Error) -> Self {
+        match error {
+            reqwest_middleware::Error::Middleware(_) => Outcome::Success,
+            reqwest_middleware::Error::Reqwest(e) => {
+                if let Some(status) = e.status() {
+                    if status.is_server_error()
+                        || reqwest::StatusCode::TOO_MANY_REQUESTS.as_u16() == status
+                    {
+                        Outcome::Overload
+                    } else {
+                        Outcome::Success
+                    }
+                } else {
+                    Outcome::Success
+                }
+            }
+        }
+    }
+    fn from_reqwest_response(response: &reqwest::Response) -> Self {
+        if response.status().is_server_error()
+            || response.status() == reqwest::StatusCode::TOO_MANY_REQUESTS
+        {
+            Outcome::Overload
+        } else {
+            Outcome::Success
+        }
+    }
+}
+
+impl Limiter {
+    /// Create a limiter with a given limit control algorithm.
+    pub fn new(config: RateLimiterConfig) -> Self {
+        assert!(config.initial_limit > 0);
+        Self {
+            limit_algo: AsyncMutex::new(config.create_rate_limit_algorithm()),
+            semaphore: Arc::new(Semaphore::new(config.initial_limit)),
+            config,
+            limits: AtomicUsize::new(config.initial_limit),
+            in_flight: Arc::new(AtomicUsize::new(0)),
+            #[cfg(test)]
+            notifier: None,
+        }
+    }
+
+    /// In some cases [Token]s are acquired asynchronously when updating the limit.
+    #[cfg(test)]
+    pub fn with_release_notifier(mut self, n: std::sync::Arc<tokio::sync::Notify>) -> Self {
+        self.notifier = Some(n);
+        self
+    }
+
+    /// Try to immediately acquire a concurrency [Token].
+    ///
+    /// Returns `None` if there are none available.
+    pub fn try_acquire(&self) -> Option<Token> {
+        let result = if self.config.disable {
+            // If the rate limiter is disabled, we can always acquire a token.
+            Some(Token::new(None, self.in_flight.clone()))
+        } else {
+            self.semaphore
+                .try_acquire()
+                .map(|permit| Token::new(Some(permit), self.in_flight.clone()))
+                .ok()
+        };
+        if result.is_some() {
+            self.in_flight.fetch_add(1, Ordering::AcqRel);
+        }
+        result
+    }
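+
+    // `try_acquire` never waits, while `acquire_timeout` below waits up to
+    // `duration` for a permit. Both bypass the semaphore entirely when the
+    // limiter is disabled, so acquisition always succeeds in that mode.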
+    /// Try to acquire a concurrency [Token], waiting for `duration` if there are none available.
+    ///
+    /// Returns `None` if there are none available after `duration`.
+    pub async fn acquire_timeout(&self, duration: Duration) -> Option<Token<'_>> {
+        info!("acquiring token: {:?}", self.semaphore.available_permits());
+        let result = if self.config.disable {
+            // If the rate limiter is disabled, we can always acquire a token.
+            Some(Token::new(None, self.in_flight.clone()))
+        } else {
+            match timeout(duration, self.semaphore.acquire()).await {
+                Ok(maybe_permit) => maybe_permit
+                    .map(|permit| Token::new(Some(permit), self.in_flight.clone()))
+                    .ok(),
+                Err(_) => None,
+            }
+        };
+        if result.is_some() {
+            self.in_flight.fetch_add(1, Ordering::AcqRel);
+        }
+        result
+    }
+
+    /// Return the concurrency [Token], along with the outcome of the job.
+    ///
+    /// The [Outcome] of the job, and the time taken to perform it, may be used
+    /// to update the concurrency limit.
+    ///
+    /// Set the outcome to `None` to ignore the job.
+    pub async fn release(&self, mut token: Token<'_>, outcome: Option<Outcome>) {
+        tracing::info!("outcome is {:?}", outcome);
+        let in_flight = self.in_flight.load(Ordering::Acquire);
+        let old_limit = self.limits.load(Ordering::Acquire);
+        let available = if self.config.disable {
+            0 // Not used by the algorithm; when the limiter is disabled it makes sense to set it to 0.
+        } else {
+            self.semaphore.available_permits()
+        };
+        let total = in_flight + available;
+
+        let mut algo = self.limit_algo.lock().await;
+
+        let new_limit = if let Some(outcome) = outcome {
+            let sample = Sample {
+                latency: token.start.elapsed(),
+                in_flight,
+                outcome,
+            };
+            algo.update(old_limit, sample).await
+        } else {
+            old_limit
+        };
+        tracing::info!("new limit is {}", new_limit);
+        let actual_limit = if new_limit < total {
+            token.forget();
+            total.saturating_sub(1)
+        } else {
+            if !self.config.disable {
+                self.semaphore.add_permits(new_limit.saturating_sub(total));
+            }
+            new_limit
+        };
+        crate::proxy::RATE_LIMITER_LIMIT
+            .with_label_values(&["expected"])
+            .set(new_limit as i64);
+        crate::proxy::RATE_LIMITER_LIMIT
+            .with_label_values(&["actual"])
+            .set(actual_limit as i64);
+        self.limits.store(new_limit, Ordering::Release);
+        #[cfg(test)]
+        if let Some(n) = &self.notifier {
+            n.notify_one();
+        }
+    }
+
+    /// The current state of the limiter.
+    pub fn state(&self) -> LimiterState {
+        let limit = self.limits.load(Ordering::Relaxed);
+        let in_flight = self.in_flight.load(Ordering::Relaxed);
+        LimiterState {
+            limit,
+            available: limit.saturating_sub(in_flight),
+            in_flight,
+        }
+    }
+}
+
+impl<'t> Token<'t> {
+    fn new(permit: Option<SemaphorePermit<'t>>, in_flight: Arc<AtomicUsize>) -> Self {
+        Self {
+            permit,
+            start: Instant::now(),
+            in_flight,
+        }
+    }
+
+    #[cfg(test)]
+    pub fn set_latency(&mut self, latency: Duration) {
+        use std::ops::Sub;
+
+        self.start = Instant::now().sub(latency);
+    }
+
+    pub fn forget(&mut self) {
+        if let Some(permit) = self.permit.take() {
+            permit.forget();
+        }
+    }
+}
+
+impl Drop for Token<'_> {
+    fn drop(&mut self) {
+        self.in_flight.fetch_sub(1, Ordering::AcqRel);
+    }
+}
+
+impl LimiterState {
+    /// The current concurrency limit.
+    pub fn limit(&self) -> usize {
+        self.limit
+    }
+    /// The amount of concurrency available to use.
+    pub fn available(&self) -> usize {
+        self.available
+    }
+    /// The number of jobs in flight.
+    pub fn in_flight(&self) -> usize {
+        self.in_flight
+    }
+}
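+
+// How this middleware classifies results into an [Outcome]:
+//   response with HTTP 5xx or 429 -> Overload (shrinks the limit)
+//   any other response            -> Success
+//   error without an HTTP status  -> Success (not treated as load-related)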
+#[async_trait::async_trait]
+impl reqwest_middleware::Middleware for Limiter {
+    async fn handle(
+        &self,
+        req: reqwest::Request,
+        extensions: &mut task_local_extensions::Extensions,
+        next: reqwest_middleware::Next<'_>,
+    ) -> reqwest_middleware::Result<reqwest::Response> {
+        let start = Instant::now();
+        let token = self
+            .acquire_timeout(self.config.timeout)
+            .await
+            .ok_or_else(|| {
+                reqwest_middleware::Error::Middleware(
+                    // TODO: Should we map it into user facing errors?
+                    crate::console::errors::ApiError::Console {
+                        status: crate::http::StatusCode::TOO_MANY_REQUESTS,
+                        text: "Too many requests".into(),
+                    }
+                    .into(),
+                )
+            })?;
+        info!(duration = ?start.elapsed(), "waited for a token to connect to the control plane");
+        crate::proxy::RATE_LIMITER_ACQUIRE_LATENCY.observe(start.elapsed().as_secs_f64());
+        match next.run(req, extensions).await {
+            Ok(response) => {
+                self.release(token, Some(Outcome::from_reqwest_response(&response)))
+                    .await;
+                Ok(response)
+            }
+            Err(e) => {
+                self.release(token, Some(Outcome::from_reqwest_error(&e)))
+                    .await;
+                Err(e)
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::{pin::pin, task::Context, time::Duration};
+
+    use futures::{task::noop_waker_ref, Future};
+
+    use super::{Limiter, Outcome};
+    use crate::rate_limiter::RateLimitAlgorithm;
+
+    #[tokio::test]
+    async fn it_works() {
+        let config = super::RateLimiterConfig {
+            algorithm: RateLimitAlgorithm::Fixed,
+            timeout: Duration::from_secs(1),
+            initial_limit: 10,
+            disable: false,
+            ..Default::default()
+        };
+        let limiter = Limiter::new(config);
+
+        let token = limiter.try_acquire().unwrap();
+
+        limiter.release(token, Some(Outcome::Success)).await;
+
+        assert_eq!(limiter.state().limit(), 10);
+    }
+
+    #[tokio::test]
+    async fn is_fair() {
+        let config = super::RateLimiterConfig {
+            algorithm: RateLimitAlgorithm::Fixed,
+            timeout: Duration::from_secs(1),
+            initial_limit: 1,
+            disable: false,
+            ..Default::default()
+        };
+        let limiter = Limiter::new(config);
+
+        // === TOKEN 1 ===
+        let token1 = limiter.try_acquire().unwrap();
+
+        let mut token2_fut = pin!(limiter.acquire_timeout(Duration::from_secs(1)));
+        assert!(
+            token2_fut
+                .as_mut()
+                .poll(&mut Context::from_waker(noop_waker_ref()))
+                .is_pending(),
+            "token is acquired by token1"
+        );
+
+        let mut token3_fut = pin!(limiter.acquire_timeout(Duration::from_secs(1)));
+        assert!(
+            token3_fut
+                .as_mut()
+                .poll(&mut Context::from_waker(noop_waker_ref()))
+                .is_pending(),
+            "token is acquired by token1"
+        );
+
+        limiter.release(token1, Some(Outcome::Success)).await;
+        // === END TOKEN 1 ===
+
+        // === TOKEN 2 ===
+        assert!(
+            limiter.try_acquire().is_none(),
+            "token is acquired by token2"
+        );
+
+        assert!(
+            token3_fut
+                .as_mut()
+                .poll(&mut Context::from_waker(noop_waker_ref()))
+                .is_pending(),
+            "token is acquired by token2"
+        );
+
+        let token2 = token2_fut.await.unwrap();
+
+        limiter.release(token2, Some(Outcome::Success)).await;
+        // === END TOKEN 2 ===
+
+        // === TOKEN 3 ===
+        assert!(
+            limiter.try_acquire().is_none(),
+            "token is acquired by token3"
+        );
+
+        let token3 = token3_fut.await.unwrap();
+        limiter.release(token3, Some(Outcome::Success)).await;
+        // === END TOKEN 3 ===
+
+        // === TOKEN 4 ===
+        let token4 = limiter.try_acquire().unwrap();
+        limiter.release(token4, Some(Outcome::Success)).await;
+    }
+
+    #[tokio::test]
+    async fn disable() {
+        let config = super::RateLimiterConfig {
+            algorithm: RateLimitAlgorithm::Fixed,
+            timeout: Duration::from_secs(1),
+            initial_limit: 1,
+            disable: true,
+            ..Default::default()
+        };
+        let limiter = Limiter::new(config);
+
+        // === TOKEN 1 ===
+        let token1 = limiter.try_acquire().unwrap();
+        let token2 = limiter.try_acquire().unwrap();
+        let state = limiter.state();
+        assert_eq!(state.limit(), 1);
+        assert_eq!(state.in_flight(), 2); // Expected for a disabled limiter.
+        limiter.release(token1, None).await;
+        limiter.release(token2, None).await;
+    }
+}
diff --git a/proxy/src/usage_metrics.rs b/proxy/src/usage_metrics.rs
index cfeec5622b..180b5f7199 100644
--- a/proxy/src/usage_metrics.rs
+++ b/proxy/src/usage_metrics.rs
@@ -249,7 +249,7 @@ mod tests {
     use url::Url;
 
     use super::{collect_metrics_iteration, Ids, Metrics};
-    use crate::http;
+    use crate::{http, rate_limiter::RateLimiterConfig};
 
     #[tokio::test]
     async fn metrics() {
@@ -279,7 +279,7 @@ mod tests {
         tokio::spawn(server);
 
         let metrics = Metrics::default();
-        let client = http::new_client();
+        let client = http::new_client(RateLimiterConfig::default());
         let endpoint = Url::parse(&format!("http://{addr}")).unwrap();
         let now = Utc::now();
 
diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py
index b45b0a12c0..6737ca5fe3 100644
--- a/test_runner/fixtures/neon_fixtures.py
+++ b/test_runner/fixtures/neon_fixtures.py
@@ -2179,6 +2179,29 @@ class NeonProxy(PgProtocol):
                 *["--allow-self-signed-compute", "true"],
             ]
 
+    class Console(AuthBackend):
+        def __init__(self, endpoint: str, fixed_rate_limit: Optional[int] = None):
+            self.endpoint = endpoint
+            self.fixed_rate_limit = fixed_rate_limit
+
+        def extra_args(self) -> list[str]:
+            args = [
+                # Console auth backend params
+                *["--auth-backend", "console"],
+                *["--auth-endpoint", self.endpoint],
+            ]
+            if self.fixed_rate_limit is not None:
+                args += [
+                    *["--disable-dynamic-rate-limiter", "false"],
+                    *["--rate-limit-algorithm", "aimd"],
+                    *["--initial-limit", str(1)],
+                    *["--rate-limiter-timeout", "1s"],
+                    *["--aimd-min-limit", "0"],
+                    *["--aimd-increase-by", "1"],
+                    *["--wake-compute-cache", "size=0"],  # Disable cache to test rate limiter.
+                ]
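+                # With an initial limit of 1, one successful request at full
+                # utilisation raises the limit to 2; each 429/500 then multiplies
+                # the limit by 0.9 and floors it, so two overloads drive it to 0.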
+            return args
+
 @dataclass(frozen=True)
 class Postgres(AuthBackend):
     pg_conn_url: str
diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py
index c93cdf637a..0f2cd9768f 100644
--- a/test_runner/regress/test_proxy.py
+++ b/test_runner/regress/test_proxy.py
@@ -1,3 +1,4 @@
+import asyncio
 import json
 import subprocess
 import time
@@ -11,6 +12,29 @@ from fixtures.neon_fixtures import PSQL, NeonProxy, VanillaPostgres
 GET_CONNECTION_PID_QUERY = "SELECT pid FROM pg_stat_activity WHERE state = 'active'"
 
 
+@pytest.mark.asyncio
+async def test_http_pool_begin_1(static_proxy: NeonProxy):
+    static_proxy.safe_psql("create user http_auth with password 'http' superuser")
+
+    def query(*args) -> Any:
+        return static_proxy.http_query(
+            "SELECT pg_sleep(10);",
+            args,
+            user="http_auth",
+            password="http",
+            expected_code=200,
+        )
+
+    query()
+    loop = asyncio.get_running_loop()
+    tasks = [loop.run_in_executor(None, query) for _ in range(10)]
+    # Wait for all the tasks to complete
+    completed, pending = await asyncio.wait(tasks)
+    # Get the results
+    results = [task.result() for task in completed]
+    print(results)
+
+
 def test_proxy_select_1(static_proxy: NeonProxy):
     """
     A simplest smoke test: check proxy against a local postgres instance.
diff --git a/test_runner/regress/test_proxy_rate_limiter.py b/test_runner/regress/test_proxy_rate_limiter.py
new file mode 100644
index 0000000000..f39f0cad07
--- /dev/null
+++ b/test_runner/regress/test_proxy_rate_limiter.py
@@ -0,0 +1,84 @@
+import asyncio
+import time
+from pathlib import Path
+from typing import Iterator
+
+import pytest
+from fixtures.neon_fixtures import (
+    PSQL,
+    NeonProxy,
+)
+from fixtures.port_distributor import PortDistributor
+from pytest_httpserver import HTTPServer
+from werkzeug.wrappers.response import Response
+
+
+def waiting_handler(status_code: int) -> Response:
+    # Wait longer than the rate limiter timeout to make sure that both connections are open.
+    # A barrier would be better here, but it is unclear how to combine one with pytest-httpserver.
+    time.sleep(2)
+    return Response(status=status_code)
+
+
+@pytest.fixture(scope="function")
+def proxy_with_rate_limit(
+    port_distributor: PortDistributor,
+    neon_binpath: Path,
+    httpserver_listen_address,
+    test_output_dir: Path,
+) -> Iterator[NeonProxy]:
+    """Neon proxy that uses the console auth backend, pointed at a mock control plane."""
+
+    proxy_port = port_distributor.get_port()
+    mgmt_port = port_distributor.get_port()
+    http_port = port_distributor.get_port()
+    external_http_port = port_distributor.get_port()
+    (host, port) = httpserver_listen_address
+    endpoint = f"http://{host}:{port}/billing/api/v1/usage_events"
+
+    with NeonProxy(
+        neon_binpath=neon_binpath,
+        test_output_dir=test_output_dir,
+        proxy_port=proxy_port,
+        http_port=http_port,
+        mgmt_port=mgmt_port,
+        external_http_port=external_http_port,
+        auth_backend=NeonProxy.Console(endpoint, fixed_rate_limit=5),
+    ) as proxy:
+        proxy.start()
+        yield proxy
+
+
+@pytest.mark.asyncio
+async def test_proxy_rate_limit(
+    httpserver: HTTPServer,
+    proxy_with_rate_limit: NeonProxy,
+):
+    uri = "/billing/api/v1/usage_events/proxy_get_role_secret"
+    # Mock the control plane service.
+    httpserver.expect_ordered_request(uri, method="GET").respond_with_handler(
+        lambda _: Response(status=200)
+    )
+    httpserver.expect_ordered_request(uri, method="GET").respond_with_handler(
+        lambda _: waiting_handler(429)
+    )
+    httpserver.expect_ordered_request(uri, method="GET").respond_with_handler(
+        lambda _: waiting_handler(500)
+    )
+
+    psql = PSQL(host=proxy_with_rate_limit.host, port=proxy_with_rate_limit.proxy_port)
+    f = await psql.run("select 42;")
+    await proxy_with_rate_limit.find_auth_link(uri, f)
+    # Limit should be 2 now.
+
+    # Run two queries in parallel.
+    f1, f2 = await asyncio.gather(psql.run("select 42;"), psql.run("select 42;"))
+    await proxy_with_rate_limit.find_auth_link(uri, f1)
+    await proxy_with_rate_limit.find_auth_link(uri, f2)
+
+    # Now the limit should be 0.
+    f = await psql.run("select 42;")
+    await proxy_with_rate_limit.find_auth_link(uri, f)
+
+    # The last query shouldn't reach the http server.
+    assert httpserver.assertions == []