From 8460654f61d59dc639db282473148e2b8ef50b09 Mon Sep 17 00:00:00 2001 From: Stas Kelvich Date: Wed, 13 Dec 2023 03:52:10 +0200 Subject: [PATCH] Add per-endpoint rate limiter to proxy --- proxy/src/auth.rs | 8 ++++ proxy/src/bin/proxy.rs | 4 ++ proxy/src/config.rs | 1 + proxy/src/proxy.rs | 21 +++++++++ proxy/src/rate_limiter.rs | 1 + proxy/src/rate_limiter/limiter.rs | 71 +++++++++++++++++++++++++++++++ proxy/src/serverless.rs | 7 +++ proxy/src/serverless/websocket.rs | 4 ++ 8 files changed, 117 insertions(+) diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index 7d79d34045..eadb9abd43 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -62,6 +62,9 @@ pub enum AuthErrorImpl { Please add it to the allowed list in the Neon console." )] IpAddressNotAllowed, + + #[error("Too many connections to this endpoint. Please try again later.")] + TooManyConnections, } #[derive(Debug, Error)] @@ -80,6 +83,10 @@ impl AuthError { pub fn ip_address_not_allowed() -> Self { AuthErrorImpl::IpAddressNotAllowed.into() } + + pub fn too_many_connections() -> Self { + AuthErrorImpl::TooManyConnections.into() + } } impl> From for AuthError { @@ -102,6 +109,7 @@ impl UserFacingError for AuthError { MissingEndpointName => self.to_string(), Io(_) => "Internal error".to_string(), IpAddressNotAllowed => self.to_string(), + TooManyConnections => self.to_string(), } } } diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index fc1c44809a..1fa2d5599f 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -112,6 +112,9 @@ struct ProxyCliArgs { /// Timeout for rate limiter. If it didn't manage to aquire a permit in this time, it will return an error. #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)] rate_limiter_timeout: tokio::time::Duration, + /// Endpoint rate limiter max number of requests per second. + #[clap(long, default_value_t = 300)] + endpoint_rps_limit: u32, /// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`. #[clap(long, default_value_t = 100)] initial_limit: usize, @@ -317,6 +320,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { authentication_config, require_client_ip: args.require_client_ip, disable_ip_check_for_http: args.disable_ip_check_for_http, + endpoint_rps_limit: args.endpoint_rps_limit, })); Ok(config) diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 182d71f9be..dea446eb22 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -20,6 +20,7 @@ pub struct ProxyConfig { pub authentication_config: AuthenticationConfig, pub require_client_ip: bool, pub disable_ip_check_for_http: bool, + pub endpoint_rps_limit: u32, } #[derive(Debug)] diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 152c894ca9..ae8b294841 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -9,6 +9,7 @@ use crate::{ console::{self, errors::WakeComputeError, messages::MetricsAuxInfo, Api}, http::StatusCode, protocol2::WithClientIp, + rate_limiter::EndpointRateLimiter, stream::{PqStream, Stream}, usage_metrics::{Ids, USAGE_METRICS}, }; @@ -307,6 +308,7 @@ pub async fn task_main( let connections = tokio_util::task::task_tracker::TaskTracker::new(); let cancel_map = Arc::new(CancelMap::default()); + let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(config.endpoint_rps_limit)); while let Some(accept_result) = run_until_cancelled(listener.accept(), &cancellation_token).await @@ -315,6 +317,8 @@ pub async fn task_main( let session_id = uuid::Uuid::new_v4(); let cancel_map = Arc::clone(&cancel_map); + let endpoint_rate_limiter = endpoint_rate_limiter.clone(); + connections.spawn( async move { info!("accepted postgres client connection"); @@ -340,6 +344,7 @@ pub async fn task_main( socket, ClientMode::Tcp, peer_addr.ip(), + endpoint_rate_limiter, ) .await } @@ -415,6 +420,7 @@ pub async fn handle_client( stream: S, mode: ClientMode, peer_addr: IpAddr, + endpoint_rate_limiter: Arc, ) -> anyhow::Result<()> { info!( protocol = mode.protocol_label(), @@ -463,6 +469,7 @@ pub async fn handle_client( ¶ms, session_id, mode.allow_self_signed_compute(config), + endpoint_rate_limiter, ); cancel_map .with_session(|session| client.connect_to_db(session, mode, &config.authentication_config)) @@ -928,6 +935,8 @@ struct Client<'a, S> { session_id: uuid::Uuid, /// Allow self-signed certificates (for testing). allow_self_signed_compute: bool, + /// Rate limiter for endpoints + endpoint_rate_limiter: Arc, } impl<'a, S> Client<'a, S> { @@ -938,6 +947,7 @@ impl<'a, S> Client<'a, S> { params: &'a StartupMessageParams, session_id: uuid::Uuid, allow_self_signed_compute: bool, + endpoint_rate_limiter: Arc, ) -> Self { Self { stream, @@ -945,6 +955,7 @@ impl<'a, S> Client<'a, S> { params, session_id, allow_self_signed_compute, + endpoint_rate_limiter, } } } @@ -966,8 +977,18 @@ impl Client<'_, S> { params, session_id, allow_self_signed_compute, + endpoint_rate_limiter, } = self; + // check rate limit + if let Some(ep) = creds.get_endpoint() { + if !endpoint_rate_limiter.check(ep) { + return stream + .throw_error(auth::AuthError::too_many_connections()) + .await; + } + } + let proto = mode.protocol_label(); let extra = console::ConsoleReqExtra { session_id, // aka this connection's id diff --git a/proxy/src/rate_limiter.rs b/proxy/src/rate_limiter.rs index 5622c44a68..f40b8dbd1c 100644 --- a/proxy/src/rate_limiter.rs +++ b/proxy/src/rate_limiter.rs @@ -3,4 +3,5 @@ mod limit_algorithm; mod limiter; pub use aimd::Aimd; pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig}; +pub use limiter::EndpointRateLimiter; pub use limiter::Limiter; diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index 3a9fed3919..9d28bb67b3 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -6,6 +6,9 @@ use std::{ time::Duration, }; +use dashmap::DashMap; +use parking_lot::Mutex; +use smol_str::SmolStr; use tokio::sync::{Mutex as AsyncMutex, Semaphore, SemaphorePermit}; use tokio::time::{timeout, Instant}; use tracing::info; @@ -15,6 +18,74 @@ use super::{ RateLimiterConfig, }; +// Simple per-endpoint rate limiter. +// +// Check that number of connections to the endpoint is below `max_rps` rps. +// Purposefully ignore user name and database name as clients can reconnect +// with different names, so we'll end up sending some http requests to +// the control plane. +// +// We also may save quite a lot of CPU (I think) by bailing out right after we +// saw SNI, before doing TLS handshake. User-side error messages in that case +// does not look very nice (`SSL SYSCALL error: Undefined error: 0`), so for now +// I went with a more expensive way that yields user-friendlier error messages. +// +// TODO: add a better bucketing here, e.g. not more than 300 requests per second, +// and not more than 1000 requests per 10 seconds, etc. Short bursts of reconnects +// are noramal during redeployments, so we should not block them. +pub struct EndpointRateLimiter { + map: DashMap>>, + max_rps: u32, + access_count: AtomicUsize, +} + +impl EndpointRateLimiter { + pub fn new(max_rps: u32) -> Self { + Self { + map: DashMap::new(), + max_rps, + access_count: AtomicUsize::new(1), // start from 1 to avoid GC on the first request + } + } + + /// Check that number of connections to the endpoint is below `max_rps` rps. + pub fn check(&self, endpoint: SmolStr) -> bool { + // do GC every 100k requests (worst case memory usage is about 10MB) + if self.access_count.fetch_add(1, Ordering::AcqRel) % 100_000 == 0 { + self.do_gc(); + } + + let now = chrono::Utc::now().naive_utc().time(); + let entry = self + .map + .entry(endpoint) + .or_insert_with(|| Arc::new(Mutex::new((now, 0)))); + let mut entry = entry.lock(); + let (last_time, count) = *entry; + + if now - last_time < chrono::Duration::seconds(1) { + if count >= self.max_rps { + return false; + } + *entry = (last_time, count + 1); + } else { + *entry = (now, 1); + } + true + } + + /// Clean the map. Simple strategy: remove all entries. At worst, we'll + /// double the effective max_rps during the cleanup. But that way deletion + /// does not aquire mutex on each entry access. + pub fn do_gc(&self) { + info!( + "cleaning up endpoint rate limiter, current size = {}", + self.map.len() + ); + self.map.clear(); + } +} + /// Limits the number of concurrent jobs. /// /// Concurrency is limited through the use of [Token]s. Acquire a token to run a job, and release the diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index cd496ff01e..92d6e2d851 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -14,6 +14,7 @@ use tokio_util::task::TaskTracker; use crate::protocol2::{ProxyProtocolAccept, WithClientIp}; use crate::proxy::{NUM_CLIENT_CONNECTION_CLOSED_COUNTER, NUM_CLIENT_CONNECTION_OPENED_COUNTER}; +use crate::rate_limiter::EndpointRateLimiter; use crate::{cancellation::CancelMap, config::ProxyConfig}; use futures::StreamExt; use hyper::{ @@ -43,6 +44,7 @@ pub async fn task_main( } let conn_pool = conn_pool::GlobalConnPool::new(config); + let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(config.endpoint_rps_limit)); // shutdown the connection pool tokio::spawn({ @@ -91,6 +93,7 @@ pub async fn task_main( let sni_name = tls.server_name().map(|s| s.to_string()); let conn_pool = conn_pool.clone(); let ws_connections = ws_connections.clone(); + let endpoint_rate_limiter = endpoint_rate_limiter.clone(); async move { let peer_addr = match client_addr { @@ -103,6 +106,7 @@ pub async fn task_main( let sni_name = sni_name.clone(); let conn_pool = conn_pool.clone(); let ws_connections = ws_connections.clone(); + let endpoint_rate_limiter = endpoint_rate_limiter.clone(); async move { let cancel_map = Arc::new(CancelMap::default()); @@ -117,6 +121,7 @@ pub async fn task_main( session_id, sni_name, peer_addr.ip(), + endpoint_rate_limiter, ) .instrument(info_span!( "serverless", @@ -190,6 +195,7 @@ async fn request_handler( session_id: uuid::Uuid, sni_hostname: Option, peer_addr: IpAddr, + endpoint_rate_limiter: Arc, ) -> Result, ApiError> { let host = request .headers() @@ -214,6 +220,7 @@ async fn request_handler( session_id, host, peer_addr, + endpoint_rate_limiter, ) .await { diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 199b03550d..cd6184cdee 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -3,6 +3,7 @@ use crate::{ config::ProxyConfig, error::io_error, proxy::{handle_client, ClientMode}, + rate_limiter::EndpointRateLimiter, }; use bytes::{Buf, Bytes}; use futures::{Sink, Stream}; @@ -13,6 +14,7 @@ use pin_project_lite::pin_project; use std::{ net::IpAddr, pin::Pin, + sync::Arc, task::{ready, Context, Poll}, }; use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; @@ -134,6 +136,7 @@ pub async fn serve_websocket( session_id: uuid::Uuid, hostname: Option, peer_addr: IpAddr, + endpoint_rate_limiter: Arc, ) -> anyhow::Result<()> { let websocket = websocket.await?; handle_client( @@ -143,6 +146,7 @@ pub async fn serve_websocket( WebSocketRw::new(websocket), ClientMode::Websockets { hostname }, peer_addr, + endpoint_rate_limiter, ) .await?; Ok(())