mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-08 22:12:56 +00:00
Add per-endpoint rate limiter to proxy
This commit is contained in:
@@ -62,6 +62,9 @@ pub enum AuthErrorImpl {
|
||||
Please add it to the allowed list in the Neon console."
|
||||
)]
|
||||
IpAddressNotAllowed,
|
||||
|
||||
#[error("Too many connections to this endpoint. Please try again later.")]
|
||||
TooManyConnections,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
@@ -80,6 +83,10 @@ impl AuthError {
|
||||
pub fn ip_address_not_allowed() -> Self {
|
||||
AuthErrorImpl::IpAddressNotAllowed.into()
|
||||
}
|
||||
|
||||
pub fn too_many_connections() -> Self {
|
||||
AuthErrorImpl::TooManyConnections.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Into<AuthErrorImpl>> From<E> for AuthError {
|
||||
@@ -102,6 +109,7 @@ impl UserFacingError for AuthError {
|
||||
MissingEndpointName => self.to_string(),
|
||||
Io(_) => "Internal error".to_string(),
|
||||
IpAddressNotAllowed => self.to_string(),
|
||||
TooManyConnections => self.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,6 +112,9 @@ struct ProxyCliArgs {
|
||||
/// Timeout for rate limiter. If it didn't manage to acquire a permit in this time, it will return an error.
|
||||
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
|
||||
rate_limiter_timeout: tokio::time::Duration,
|
||||
/// Endpoint rate limiter max number of requests per second.
|
||||
#[clap(long, default_value_t = 300)]
|
||||
endpoint_rps_limit: u32,
|
||||
/// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`.
|
||||
#[clap(long, default_value_t = 100)]
|
||||
initial_limit: usize,
|
||||
@@ -317,6 +320,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
authentication_config,
|
||||
require_client_ip: args.require_client_ip,
|
||||
disable_ip_check_for_http: args.disable_ip_check_for_http,
|
||||
endpoint_rps_limit: args.endpoint_rps_limit,
|
||||
}));
|
||||
|
||||
Ok(config)
|
||||
|
||||
@@ -20,6 +20,7 @@ pub struct ProxyConfig {
|
||||
pub authentication_config: AuthenticationConfig,
|
||||
pub require_client_ip: bool,
|
||||
pub disable_ip_check_for_http: bool,
|
||||
pub endpoint_rps_limit: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
||||
@@ -9,6 +9,7 @@ use crate::{
|
||||
console::{self, errors::WakeComputeError, messages::MetricsAuxInfo, Api},
|
||||
http::StatusCode,
|
||||
protocol2::WithClientIp,
|
||||
rate_limiter::EndpointRateLimiter,
|
||||
stream::{PqStream, Stream},
|
||||
usage_metrics::{Ids, USAGE_METRICS},
|
||||
};
|
||||
@@ -307,6 +308,7 @@ pub async fn task_main(
|
||||
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
let cancel_map = Arc::new(CancelMap::default());
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(config.endpoint_rps_limit));
|
||||
|
||||
while let Some(accept_result) =
|
||||
run_until_cancelled(listener.accept(), &cancellation_token).await
|
||||
@@ -315,6 +317,8 @@ pub async fn task_main(
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancel_map = Arc::clone(&cancel_map);
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
|
||||
connections.spawn(
|
||||
async move {
|
||||
info!("accepted postgres client connection");
|
||||
@@ -340,6 +344,7 @@ pub async fn task_main(
|
||||
socket,
|
||||
ClientMode::Tcp,
|
||||
peer_addr.ip(),
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -415,6 +420,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
stream: S,
|
||||
mode: ClientMode,
|
||||
peer_addr: IpAddr,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> anyhow::Result<()> {
|
||||
info!(
|
||||
protocol = mode.protocol_label(),
|
||||
@@ -463,6 +469,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
&params,
|
||||
session_id,
|
||||
mode.allow_self_signed_compute(config),
|
||||
endpoint_rate_limiter,
|
||||
);
|
||||
cancel_map
|
||||
.with_session(|session| client.connect_to_db(session, mode, &config.authentication_config))
|
||||
@@ -928,6 +935,8 @@ struct Client<'a, S> {
|
||||
session_id: uuid::Uuid,
|
||||
/// Allow self-signed certificates (for testing).
|
||||
allow_self_signed_compute: bool,
|
||||
/// Rate limiter for endpoints
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
}
|
||||
|
||||
impl<'a, S> Client<'a, S> {
|
||||
@@ -938,6 +947,7 @@ impl<'a, S> Client<'a, S> {
|
||||
params: &'a StartupMessageParams,
|
||||
session_id: uuid::Uuid,
|
||||
allow_self_signed_compute: bool,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
@@ -945,6 +955,7 @@ impl<'a, S> Client<'a, S> {
|
||||
params,
|
||||
session_id,
|
||||
allow_self_signed_compute,
|
||||
endpoint_rate_limiter,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -966,8 +977,18 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
|
||||
params,
|
||||
session_id,
|
||||
allow_self_signed_compute,
|
||||
endpoint_rate_limiter,
|
||||
} = self;
|
||||
|
||||
// check rate limit
|
||||
if let Some(ep) = creds.get_endpoint() {
|
||||
if !endpoint_rate_limiter.check(ep) {
|
||||
return stream
|
||||
.throw_error(auth::AuthError::too_many_connections())
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
let proto = mode.protocol_label();
|
||||
let extra = console::ConsoleReqExtra {
|
||||
session_id, // aka this connection's id
|
||||
|
||||
@@ -3,4 +3,5 @@ mod limit_algorithm;
|
||||
mod limiter;
|
||||
pub use aimd::Aimd;
|
||||
pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig};
|
||||
pub use limiter::EndpointRateLimiter;
|
||||
pub use limiter::Limiter;
|
||||
|
||||
@@ -6,6 +6,9 @@ use std::{
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use parking_lot::Mutex;
|
||||
use smol_str::SmolStr;
|
||||
use tokio::sync::{Mutex as AsyncMutex, Semaphore, SemaphorePermit};
|
||||
use tokio::time::{timeout, Instant};
|
||||
use tracing::info;
|
||||
@@ -15,6 +18,74 @@ use super::{
|
||||
RateLimiterConfig,
|
||||
};
|
||||
|
||||
// Simple per-endpoint rate limiter.
|
||||
//
|
||||
// Check that number of connections to the endpoint is below `max_rps` rps.
|
||||
// Purposefully ignore user name and database name as clients can reconnect
|
||||
// with different names, so we'll end up sending some http requests to
|
||||
// the control plane.
|
||||
//
|
||||
// We also may save quite a lot of CPU (I think) by bailing out right after we
|
||||
// saw SNI, before doing TLS handshake. User-side error messages in that case
|
||||
// do not look very nice (`SSL SYSCALL error: Undefined error: 0`), so for now
|
||||
// I went with a more expensive way that yields user-friendlier error messages.
|
||||
//
|
||||
// TODO: add a better bucketing here, e.g. not more than 300 requests per second,
|
||||
// and not more than 1000 requests per 10 seconds, etc. Short bursts of reconnects
|
||||
// are normal during redeployments, so we should not block them.
|
||||
pub struct EndpointRateLimiter {
|
||||
map: DashMap<SmolStr, Arc<Mutex<(chrono::NaiveTime, u32)>>>,
|
||||
max_rps: u32,
|
||||
access_count: AtomicUsize,
|
||||
}
|
||||
|
||||
impl EndpointRateLimiter {
|
||||
pub fn new(max_rps: u32) -> Self {
|
||||
Self {
|
||||
map: DashMap::new(),
|
||||
max_rps,
|
||||
access_count: AtomicUsize::new(1), // start from 1 to avoid GC on the first request
|
||||
}
|
||||
}
|
||||
|
||||
/// Check that number of connections to the endpoint is below `max_rps` rps.
|
||||
pub fn check(&self, endpoint: SmolStr) -> bool {
|
||||
// do GC every 100k requests (worst case memory usage is about 10MB)
|
||||
if self.access_count.fetch_add(1, Ordering::AcqRel) % 100_000 == 0 {
|
||||
self.do_gc();
|
||||
}
|
||||
|
||||
let now = chrono::Utc::now().naive_utc().time();
|
||||
let entry = self
|
||||
.map
|
||||
.entry(endpoint)
|
||||
.or_insert_with(|| Arc::new(Mutex::new((now, 0))));
|
||||
let mut entry = entry.lock();
|
||||
let (last_time, count) = *entry;
|
||||
|
||||
if now - last_time < chrono::Duration::seconds(1) {
|
||||
if count >= self.max_rps {
|
||||
return false;
|
||||
}
|
||||
*entry = (last_time, count + 1);
|
||||
} else {
|
||||
*entry = (now, 1);
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
/// Clean the map. Simple strategy: remove all entries. At worst, we'll
|
||||
/// double the effective max_rps during the cleanup. But that way deletion
|
||||
/// does not aquire mutex on each entry access.
|
||||
pub fn do_gc(&self) {
|
||||
info!(
|
||||
"cleaning up endpoint rate limiter, current size = {}",
|
||||
self.map.len()
|
||||
);
|
||||
self.map.clear();
|
||||
}
|
||||
}
|
||||
|
||||
/// Limits the number of concurrent jobs.
|
||||
///
|
||||
/// Concurrency is limited through the use of [Token]s. Acquire a token to run a job, and release the
|
||||
|
||||
@@ -14,6 +14,7 @@ use tokio_util::task::TaskTracker;
|
||||
|
||||
use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
|
||||
use crate::proxy::{NUM_CLIENT_CONNECTION_CLOSED_COUNTER, NUM_CLIENT_CONNECTION_OPENED_COUNTER};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::{cancellation::CancelMap, config::ProxyConfig};
|
||||
use futures::StreamExt;
|
||||
use hyper::{
|
||||
@@ -43,6 +44,7 @@ pub async fn task_main(
|
||||
}
|
||||
|
||||
let conn_pool = conn_pool::GlobalConnPool::new(config);
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(config.endpoint_rps_limit));
|
||||
|
||||
// shutdown the connection pool
|
||||
tokio::spawn({
|
||||
@@ -91,6 +93,7 @@ pub async fn task_main(
|
||||
let sni_name = tls.server_name().map(|s| s.to_string());
|
||||
let conn_pool = conn_pool.clone();
|
||||
let ws_connections = ws_connections.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
|
||||
async move {
|
||||
let peer_addr = match client_addr {
|
||||
@@ -103,6 +106,7 @@ pub async fn task_main(
|
||||
let sni_name = sni_name.clone();
|
||||
let conn_pool = conn_pool.clone();
|
||||
let ws_connections = ws_connections.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
|
||||
async move {
|
||||
let cancel_map = Arc::new(CancelMap::default());
|
||||
@@ -117,6 +121,7 @@ pub async fn task_main(
|
||||
session_id,
|
||||
sni_name,
|
||||
peer_addr.ip(),
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.instrument(info_span!(
|
||||
"serverless",
|
||||
@@ -190,6 +195,7 @@ async fn request_handler(
|
||||
session_id: uuid::Uuid,
|
||||
sni_hostname: Option<String>,
|
||||
peer_addr: IpAddr,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
let host = request
|
||||
.headers()
|
||||
@@ -214,6 +220,7 @@ async fn request_handler(
|
||||
session_id,
|
||||
host,
|
||||
peer_addr,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
||||
@@ -3,6 +3,7 @@ use crate::{
|
||||
config::ProxyConfig,
|
||||
error::io_error,
|
||||
proxy::{handle_client, ClientMode},
|
||||
rate_limiter::EndpointRateLimiter,
|
||||
};
|
||||
use bytes::{Buf, Bytes};
|
||||
use futures::{Sink, Stream};
|
||||
@@ -13,6 +14,7 @@ use pin_project_lite::pin_project;
|
||||
use std::{
|
||||
net::IpAddr,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{ready, Context, Poll},
|
||||
};
|
||||
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
|
||||
@@ -134,6 +136,7 @@ pub async fn serve_websocket(
|
||||
session_id: uuid::Uuid,
|
||||
hostname: Option<String>,
|
||||
peer_addr: IpAddr,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> anyhow::Result<()> {
|
||||
let websocket = websocket.await?;
|
||||
handle_client(
|
||||
@@ -143,6 +146,7 @@ pub async fn serve_websocket(
|
||||
WebSocketRw::new(websocket),
|
||||
ClientMode::Websockets { hostname },
|
||||
peer_addr,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
|
||||
Reference in New Issue
Block a user