From a54ea8fb1cd26396a06d2fd715bcf19b8b7a7226 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 18 Apr 2024 06:00:33 +0100 Subject: [PATCH] proxy: move endpoint rate limiter (#7413) ## Problem ## Summary of changes Rate limit for wake_compute calls --- proxy/src/bin/proxy.rs | 12 +++++------- proxy/src/config.rs | 1 - proxy/src/console/provider.rs | 6 ++++++ proxy/src/console/provider/neon.rs | 12 ++++++++++++ proxy/src/proxy.rs | 16 +--------------- proxy/src/proxy/wake_compute.rs | 1 + proxy/src/rate_limiter/limiter.rs | 26 +++++++++++--------------- proxy/src/serverless.rs | 18 +++--------------- proxy/src/serverless/websocket.rs | 3 --- 9 files changed, 39 insertions(+), 56 deletions(-) diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 71283dd606..b54f8c131c 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -331,7 +331,6 @@ async fn main() -> anyhow::Result<()> { let proxy_listener = TcpListener::bind(proxy_address).await?; let cancellation_token = CancellationToken::new(); - let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit)); let cancel_map = CancelMap::default(); let redis_publisher = match ®ional_redis_client { @@ -357,7 +356,6 @@ async fn main() -> anyhow::Result<()> { config, proxy_listener, cancellation_token.clone(), - endpoint_rate_limiter.clone(), cancellation_handler.clone(), )); @@ -372,7 +370,6 @@ async fn main() -> anyhow::Result<()> { config, serverless_listener, cancellation_token.clone(), - endpoint_rate_limiter.clone(), cancellation_handler.clone(), )); } @@ -533,7 +530,11 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { let url = args.auth_endpoint.parse()?; let endpoint = http::Endpoint::new(url, http::new_client()); - let api = console::provider::neon::Api::new(endpoint, caches, locks); + let mut endpoint_rps_limit = args.endpoint_rps_limit.clone(); + RateBucketInfo::validate(&mut endpoint_rps_limit)?; + let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(endpoint_rps_limit)); + let api = + console::provider::neon::Api::new(endpoint, caches, locks, endpoint_rate_limiter); let api = console::provider::ConsoleBackend::Console(api); auth::BackendType::Console(MaybeOwned::Owned(api), ()) } @@ -567,8 +568,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet, }; - let mut endpoint_rps_limit = args.endpoint_rps_limit.clone(); - RateBucketInfo::validate(&mut endpoint_rps_limit)?; let mut redis_rps_limit = args.redis_rps_limit.clone(); RateBucketInfo::validate(&mut redis_rps_limit)?; @@ -581,7 +580,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { authentication_config, require_client_ip: args.require_client_ip, disable_ip_check_for_http: args.disable_ip_check_for_http, - endpoint_rps_limit, redis_rps_limit, handshake_timeout: args.handshake_timeout, region: args.region.clone(), diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 7b4c02393b..f9519c7645 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -29,7 +29,6 @@ pub struct ProxyConfig { pub authentication_config: AuthenticationConfig, pub require_client_ip: bool, pub disable_ip_check_for_http: bool, - pub endpoint_rps_limit: Vec, pub redis_rps_limit: Vec, pub region: String, pub handshake_timeout: Duration, diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index 3fa7221f98..aa1800a9da 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -208,6 +208,9 @@ pub mod errors { #[error(transparent)] ApiError(ApiError), + #[error("Too many connections attempts")] + TooManyConnections, + #[error("Timeout waiting to acquire wake compute lock")] TimeoutError, } @@ -240,6 +243,8 @@ pub mod errors { // However, API might return a meaningful error. ApiError(e) => e.to_string_client(), + TooManyConnections => self.to_string(), + TimeoutError => "timeout while acquiring the compute resource lock".to_owned(), } } @@ -250,6 +255,7 @@ pub mod errors { match self { WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane, WakeComputeError::ApiError(e) => e.get_error_kind(), + WakeComputeError::TooManyConnections => crate::error::ErrorKind::RateLimit, WakeComputeError::TimeoutError => crate::error::ErrorKind::ServiceRateLimit, } } diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index 138acdf578..58b2a1570c 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -12,6 +12,7 @@ use crate::{ console::messages::ColdStartInfo, http, metrics::{CacheOutcome, Metrics}, + rate_limiter::EndpointRateLimiter, scram, Normalize, }; use crate::{cache::Cached, context::RequestMonitoring}; @@ -25,6 +26,7 @@ pub struct Api { endpoint: http::Endpoint, pub caches: &'static ApiCaches, pub locks: &'static ApiLocks, + pub endpoint_rate_limiter: Arc, jwt: String, } @@ -34,6 +36,7 @@ impl Api { endpoint: http::Endpoint, caches: &'static ApiCaches, locks: &'static ApiLocks, + endpoint_rate_limiter: Arc, ) -> Self { let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") { Ok(v) => v, @@ -43,6 +46,7 @@ impl Api { endpoint, caches, locks, + endpoint_rate_limiter, jwt, } } @@ -277,6 +281,14 @@ impl super::Api for Api { return Ok(cached); } + // check rate limit + if !self + .endpoint_rate_limiter + .check(user_info.endpoint.normalize().into(), 1) + { + return Err(WakeComputeError::TooManyConnections); + } + let permit = self.locks.get_wake_compute_permit(&key).await?; // after getting back a permit - it's possible the cache was filled diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index f80ced91c8..4321bad968 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -19,9 +19,8 @@ use crate::{ metrics::{Metrics, NumClientConnectionsGuard}, protocol2::WithClientIp, proxy::handshake::{handshake, HandshakeData}, - rate_limiter::EndpointRateLimiter, stream::{PqStream, Stream}, - EndpointCacheKey, Normalize, + EndpointCacheKey, }; use futures::TryFutureExt; use itertools::Itertools; @@ -61,7 +60,6 @@ pub async fn task_main( config: &'static ProxyConfig, listener: tokio::net::TcpListener, cancellation_token: CancellationToken, - endpoint_rate_limiter: Arc, cancellation_handler: Arc, ) -> anyhow::Result<()> { scopeguard::defer! { @@ -86,7 +84,6 @@ pub async fn task_main( let session_id = uuid::Uuid::new_v4(); let cancellation_handler = Arc::clone(&cancellation_handler); - let endpoint_rate_limiter = endpoint_rate_limiter.clone(); tracing::info!(protocol = "tcp", %session_id, "accepted new TCP connection"); @@ -128,7 +125,6 @@ pub async fn task_main( cancellation_handler, socket, ClientMode::Tcp, - endpoint_rate_limiter, conn_gauge, ) .instrument(span.clone()) @@ -242,7 +238,6 @@ pub async fn handle_client( cancellation_handler: Arc, stream: S, mode: ClientMode, - endpoint_rate_limiter: Arc, conn_gauge: NumClientConnectionsGuard<'static>, ) -> Result>, ClientRequestError> { info!( @@ -288,15 +283,6 @@ pub async fn handle_client( Err(e) => stream.throw_error(e).await?, }; - // check rate limit - if let Some(ep) = user_info.get_endpoint() { - if !endpoint_rate_limiter.check(ep.normalize(), 1) { - return stream - .throw_error(auth::AuthError::too_many_connections()) - .await?; - } - } - let user = user_info.get_user().to_owned(); let user_info = match user_info .authenticate( diff --git a/proxy/src/proxy/wake_compute.rs b/proxy/src/proxy/wake_compute.rs index f8154b1a94..fe228ab33d 100644 --- a/proxy/src/proxy/wake_compute.rs +++ b/proxy/src/proxy/wake_compute.rs @@ -90,6 +90,7 @@ fn report_error(e: &WakeComputeError, retry: bool) { WakeComputeError::ApiError(ApiError::Console { .. }) => { WakeupFailureKind::ApiConsoleOtherError } + WakeComputeError::TooManyConnections => WakeupFailureKind::ApiConsoleLocked, WakeComputeError::TimeoutError => WakeupFailureKind::TimeoutError, }; Metrics::get() diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index 3796b22ae9..5ba2c36436 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -15,7 +15,7 @@ use rand::{rngs::StdRng, Rng, SeedableRng}; use tokio::time::{Duration, Instant}; use tracing::info; -use crate::EndpointId; +use crate::intern::EndpointIdInt; pub struct GlobalRateLimiter { data: Vec, @@ -61,12 +61,7 @@ impl GlobalRateLimiter { // Purposefully ignore user name and database name as clients can reconnect // with different names, so we'll end up sending some http requests to // the control plane. -// -// We also may save quite a lot of CPU (I think) by bailing out right after we -// saw SNI, before doing TLS handshake. User-side error messages in that case -// does not look very nice (`SSL SYSCALL error: Undefined error: 0`), so for now -// I went with a more expensive way that yields user-friendlier error messages. -pub type EndpointRateLimiter = BucketRateLimiter; +pub type EndpointRateLimiter = BucketRateLimiter; pub struct BucketRateLimiter { map: DashMap, Hasher>, @@ -245,7 +240,7 @@ mod tests { use tokio::time; use super::{BucketRateLimiter, EndpointRateLimiter}; - use crate::{rate_limiter::RateBucketInfo, EndpointId}; + use crate::{intern::EndpointIdInt, rate_limiter::RateBucketInfo, EndpointId}; #[test] fn rate_bucket_rpi() { @@ -295,39 +290,40 @@ mod tests { let limiter = EndpointRateLimiter::new(rates); let endpoint = EndpointId::from("ep-my-endpoint-1234"); + let endpoint = EndpointIdInt::from(endpoint); time::pause(); for _ in 0..100 { - assert!(limiter.check(endpoint.clone(), 1)); + assert!(limiter.check(endpoint, 1)); } // more connections fail - assert!(!limiter.check(endpoint.clone(), 1)); + assert!(!limiter.check(endpoint, 1)); // fail even after 500ms as it's in the same bucket time::advance(time::Duration::from_millis(500)).await; - assert!(!limiter.check(endpoint.clone(), 1)); + assert!(!limiter.check(endpoint, 1)); // after a full 1s, 100 requests are allowed again time::advance(time::Duration::from_millis(500)).await; for _ in 1..6 { for _ in 0..50 { - assert!(limiter.check(endpoint.clone(), 2)); + assert!(limiter.check(endpoint, 2)); } time::advance(time::Duration::from_millis(1000)).await; } // more connections after 600 will exceed the 20rps@30s limit - assert!(!limiter.check(endpoint.clone(), 1)); + assert!(!limiter.check(endpoint, 1)); // will still fail before the 30 second limit time::advance(time::Duration::from_millis(30_000 - 6_000 - 1)).await; - assert!(!limiter.check(endpoint.clone(), 1)); + assert!(!limiter.check(endpoint, 1)); // after the full 30 seconds, 100 requests are allowed again time::advance(time::Duration::from_millis(1)).await; for _ in 0..100 { - assert!(limiter.check(endpoint.clone(), 1)); + assert!(limiter.check(endpoint, 1)); } } diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index f3c42cdb01..b0f4026c76 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -35,7 +35,6 @@ use crate::context::RequestMonitoring; use crate::metrics::Metrics; use crate::protocol2::WithClientIp; use crate::proxy::run_until_cancelled; -use crate::rate_limiter::EndpointRateLimiter; use crate::serverless::backend::PoolingBackend; use crate::serverless::http_util::{api_error_into_response, json_response}; @@ -53,7 +52,6 @@ pub async fn task_main( config: &'static ProxyConfig, ws_listener: TcpListener, cancellation_token: CancellationToken, - endpoint_rate_limiter: Arc, cancellation_handler: Arc, ) -> anyhow::Result<()> { scopeguard::defer! { @@ -117,7 +115,6 @@ pub async fn task_main( backend.clone(), connections.clone(), cancellation_handler.clone(), - endpoint_rate_limiter.clone(), cancellation_token.clone(), server.clone(), tls_acceptor.clone(), @@ -147,7 +144,6 @@ async fn connection_handler( backend: Arc, connections: TaskTracker, cancellation_handler: Arc, - endpoint_rate_limiter: Arc, cancellation_token: CancellationToken, server: Builder, tls_acceptor: TlsAcceptor, @@ -231,7 +227,6 @@ async fn connection_handler( cancellation_handler.clone(), session_id, peer_addr, - endpoint_rate_limiter.clone(), http_request_token, ) .in_current_span() @@ -270,7 +265,6 @@ async fn request_handler( cancellation_handler: Arc, session_id: uuid::Uuid, peer_addr: IpAddr, - endpoint_rate_limiter: Arc, // used to cancel in-flight HTTP requests. not used to cancel websockets http_cancellation_token: CancellationToken, ) -> Result>, ApiError> { @@ -298,15 +292,9 @@ async fn request_handler( ws_connections.spawn( async move { - if let Err(e) = websocket::serve_websocket( - config, - ctx, - websocket, - cancellation_handler, - host, - endpoint_rate_limiter, - ) - .await + if let Err(e) = + websocket::serve_websocket(config, ctx, websocket, cancellation_handler, host) + .await { error!("error in websocket connection: {e:#}"); } diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index d054877126..eddd278b7d 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -5,7 +5,6 @@ use crate::{ error::{io_error, ReportableError}, metrics::Metrics, proxy::{handle_client, ClientMode}, - rate_limiter::EndpointRateLimiter, }; use bytes::{Buf, Bytes}; use futures::{Sink, Stream}; @@ -136,7 +135,6 @@ pub async fn serve_websocket( websocket: HyperWebsocket, cancellation_handler: Arc, hostname: Option, - endpoint_rate_limiter: Arc, ) -> anyhow::Result<()> { let websocket = websocket.await?; let conn_gauge = Metrics::get() @@ -150,7 +148,6 @@ pub async fn serve_websocket( cancellation_handler, WebSocketRw::new(websocket), ClientMode::Websockets { hostname }, - endpoint_rate_limiter, conn_gauge, ) .await;