Proxy added per ep rate limiter (#7636)

## Problem

There is no global per-ep rate limiter in proxy.

## Summary of changes

* Bring back the global per-endpoint (per-ep) rate limiter.
* Rename the wake compute rate limiter (the CLI flags were not used
anywhere, so it's safe to rename).
This commit is contained in:
Anna Khanova
2024-05-10 12:17:00 +02:00
committed by GitHub
parent b9fd8dcf13
commit be1a88e574
8 changed files with 126 additions and 35 deletions

View File

@@ -13,7 +13,7 @@ use tokio_postgres::config::AuthKeys;
use tracing::{info, warn};
use crate::auth::credentials::check_peer_addr_is_in_list;
use crate::auth::validate_password_and_exchange;
use crate::auth::{validate_password_and_exchange, AuthError};
use crate::cache::Cached;
use crate::console::errors::GetAuthInfoError;
use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
@@ -23,7 +23,7 @@ use crate::intern::EndpointIdInt;
use crate::metrics::Metrics;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::proxy::NeonOptions;
use crate::rate_limiter::{BucketRateLimiter, RateBucketInfo};
use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo};
use crate::stream::Stream;
use crate::{
auth::{self, ComputeUserInfoMaybeEndpoint},
@@ -280,6 +280,7 @@ async fn auth_quirks(
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<ComputeCredentials> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
@@ -305,6 +306,10 @@ async fn auth_quirks(
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr));
}
if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
return Err(AuthError::too_many_connections());
}
let cached_secret = match maybe_secret {
Some(secret) => secret,
None => api.get_role_secret(ctx, &info).await?,
@@ -417,6 +422,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<BackendType<'a, ComputeCredentials, NodeInfo>> {
use BackendType::*;
@@ -428,8 +434,16 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
"performing authentication using the console"
);
let credentials =
auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?;
let credentials = auth_quirks(
ctx,
&*api,
user_info,
client,
allow_cleartext,
config,
endpoint_rate_limiter,
)
.await?;
BackendType::Console(api, credentials)
}
// NOTE: this auth backend doesn't use client credentials.
@@ -539,7 +553,7 @@ mod tests {
},
context::RequestMonitoring,
proxy::NeonOptions,
rate_limiter::RateBucketInfo,
rate_limiter::{EndpointRateLimiter, RateBucketInfo},
scram::ServerSecret,
stream::{PqStream, Stream},
};
@@ -699,10 +713,20 @@ mod tests {
_ => panic!("wrong message"),
}
});
let endpoint_rate_limiter =
Arc::new(EndpointRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET));
let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, false, &CONFIG)
.await
.unwrap();
let _creds = auth_quirks(
&mut ctx,
&api,
user_info,
&mut stream,
false,
&CONFIG,
endpoint_rate_limiter,
)
.await
.unwrap();
handle.await.unwrap();
}
@@ -739,10 +763,20 @@ mod tests {
frontend::password_message(b"my-secret-password", &mut write).unwrap();
client.write_all(&write).await.unwrap();
});
let endpoint_rate_limiter =
Arc::new(EndpointRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET));
let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, &CONFIG)
.await
.unwrap();
let _creds = auth_quirks(
&mut ctx,
&api,
user_info,
&mut stream,
true,
&CONFIG,
endpoint_rate_limiter,
)
.await
.unwrap();
handle.await.unwrap();
}
@@ -780,9 +814,20 @@ mod tests {
client.write_all(&write).await.unwrap();
});
let creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, &CONFIG)
.await
.unwrap();
let endpoint_rate_limiter =
Arc::new(EndpointRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET));
let creds = auth_quirks(
&mut ctx,
&api,
user_info,
&mut stream,
true,
&CONFIG,
endpoint_rate_limiter,
)
.await
.unwrap();
assert_eq!(creds.info.endpoint, "my-endpoint");

View File

@@ -144,6 +144,9 @@ struct ProxyCliArgs {
/// Can be given multiple times for different bucket sizes.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
endpoint_rps_limit: Vec<RateBucketInfo>,
/// Wake compute rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
wake_compute_limit: Vec<RateBucketInfo>,
/// Whether the auth rate limiter actually takes effect (for testing)
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
auth_rate_limit_enabled: bool,
@@ -154,7 +157,7 @@ struct ProxyCliArgs {
#[clap(long, default_value_t = 64)]
auth_rate_limit_ip_subnet: u8,
/// Redis rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
redis_rps_limit: Vec<RateBucketInfo>,
/// cache for `allowed_ips` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
@@ -365,6 +368,10 @@ async fn main() -> anyhow::Result<()> {
proxy::metrics::CancellationSource::FromClient,
));
let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
RateBucketInfo::validate(&mut endpoint_rps_limit)?;
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(endpoint_rps_limit));
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
let mut client_tasks = JoinSet::new();
@@ -373,6 +380,7 @@ async fn main() -> anyhow::Result<()> {
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
// TODO: rename the argument to something like serverless.
@@ -387,6 +395,7 @@ async fn main() -> anyhow::Result<()> {
serverless_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
}
@@ -559,11 +568,16 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let url = args.auth_endpoint.parse()?;
let endpoint = http::Endpoint::new(url, http::new_client());
let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
RateBucketInfo::validate(&mut endpoint_rps_limit)?;
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(endpoint_rps_limit));
let api =
console::provider::neon::Api::new(endpoint, caches, locks, endpoint_rate_limiter);
let mut wake_compute_rps_limit = args.wake_compute_limit.clone();
RateBucketInfo::validate(&mut wake_compute_rps_limit)?;
let wake_compute_endpoint_rate_limiter =
Arc::new(EndpointRateLimiter::new(wake_compute_rps_limit));
let api = console::provider::neon::Api::new(
endpoint,
caches,
locks,
wake_compute_endpoint_rate_limiter,
);
let api = console::provider::ConsoleBackend::Console(api);
auth::BackendType::Console(MaybeOwned::Owned(api), ())
}

View File

@@ -26,7 +26,7 @@ pub struct Api {
endpoint: http::Endpoint,
pub caches: &'static ApiCaches,
pub locks: &'static ApiLocks<EndpointCacheKey>,
pub endpoint_rate_limiter: Arc<EndpointRateLimiter>,
pub wake_compute_endpoint_rate_limiter: Arc<EndpointRateLimiter>,
jwt: String,
}
@@ -36,7 +36,7 @@ impl Api {
endpoint: http::Endpoint,
caches: &'static ApiCaches,
locks: &'static ApiLocks<EndpointCacheKey>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
wake_compute_endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Self {
let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") {
Ok(v) => v,
@@ -46,7 +46,7 @@ impl Api {
endpoint,
caches,
locks,
endpoint_rate_limiter,
wake_compute_endpoint_rate_limiter,
jwt,
}
}
@@ -283,7 +283,7 @@ impl super::Api for Api {
// check rate limit
if !self
.endpoint_rate_limiter
.wake_compute_endpoint_rate_limiter
.check(user_info.endpoint.normalize().into(), 1)
{
return Err(WakeComputeError::TooManyConnections);

View File

@@ -19,6 +19,7 @@ use crate::{
metrics::{Metrics, NumClientConnectionsGuard},
protocol2::read_proxy_protocol,
proxy::handshake::{handshake, HandshakeData},
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
EndpointCacheKey,
};
@@ -61,6 +62,7 @@ pub async fn task_main(
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
@@ -86,6 +88,7 @@ pub async fn task_main(
let cancellation_handler = Arc::clone(&cancellation_handler);
tracing::info!(protocol = "tcp", %session_id, "accepted new TCP connection");
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
connections.spawn(async move {
let (socket, peer_addr) = match read_proxy_protocol(socket).await{
@@ -123,6 +126,7 @@ pub async fn task_main(
cancellation_handler,
socket,
ClientMode::Tcp,
endpoint_rate_limiter2,
conn_gauge,
)
.instrument(span.clone())
@@ -234,6 +238,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
cancellation_handler: Arc<CancellationHandlerMain>,
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
info!(
@@ -243,7 +248,6 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let metrics = &Metrics::get().proxy;
let proto = ctx.protocol;
// let _client_gauge = metrics.client_connections.guard(proto);
let _request_gauge = metrics.connection_requests.guard(proto);
let tls = config.tls_config.as_ref();
@@ -286,6 +290,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
&mut stream,
mode.allow_cleartext(),
&config.authentication_config,
endpoint_rate_limiter,
)
.await
{

View File

@@ -128,12 +128,18 @@ impl std::str::FromStr for RateBucketInfo {
}
impl RateBucketInfo {
pub const DEFAULT_ENDPOINT_SET: [Self; 3] = [
pub const DEFAULT_SET: [Self; 3] = [
Self::new(300, Duration::from_secs(1)),
Self::new(200, Duration::from_secs(60)),
Self::new(100, Duration::from_secs(600)),
];
pub const DEFAULT_ENDPOINT_SET: [Self; 3] = [
Self::new(500, Duration::from_secs(1)),
Self::new(300, Duration::from_secs(60)),
Self::new(200, Duration::from_secs(600)),
];
pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
info.sort_unstable_by_key(|info| info.interval);
let invalid = info
@@ -266,7 +272,7 @@ mod tests {
#[test]
fn default_rate_buckets() {
let mut defaults = RateBucketInfo::DEFAULT_ENDPOINT_SET;
let mut defaults = RateBucketInfo::DEFAULT_SET;
RateBucketInfo::validate(&mut defaults[..]).unwrap();
}
@@ -333,11 +339,8 @@ mod tests {
let rand = rand::rngs::StdRng::from_seed([1; 32]);
let hasher = BuildHasherDefault::<FxHasher>::default();
let limiter = BucketRateLimiter::new_with_rand_and_hasher(
&RateBucketInfo::DEFAULT_ENDPOINT_SET,
rand,
hasher,
);
let limiter =
BucketRateLimiter::new_with_rand_and_hasher(&RateBucketInfo::DEFAULT_SET, rand, hasher);
for i in 0..1_000_000 {
limiter.check(i, 1);
}

View File

@@ -36,6 +36,7 @@ use crate::context::RequestMonitoring;
use crate::metrics::Metrics;
use crate::protocol2::read_proxy_protocol;
use crate::proxy::run_until_cancelled;
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_util::{api_error_into_response, json_response};
@@ -54,6 +55,7 @@ pub async fn task_main(
ws_listener: TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("websocket server has shut down");
@@ -82,6 +84,7 @@ pub async fn task_main(
let backend = Arc::new(PoolingBackend {
pool: Arc::clone(&conn_pool),
config,
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
});
let tls_config = match config.tls_config.as_ref() {
@@ -129,6 +132,7 @@ pub async fn task_main(
backend.clone(),
connections.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
conn_token.clone(),
server.clone(),
tls_acceptor.clone(),
@@ -162,6 +166,7 @@ async fn connection_handler(
backend: Arc<PoolingBackend>,
connections: TaskTracker,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
server: Builder<TokioExecutor>,
tls_acceptor: TlsAcceptor,
@@ -245,6 +250,7 @@ async fn connection_handler(
session_id,
peer_addr,
http_request_token,
endpoint_rate_limiter.clone(),
)
.in_current_span()
.map_ok_or_else(api_error_into_response, |r| r),
@@ -285,6 +291,7 @@ async fn request_handler(
peer_addr: IpAddr,
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Response<Full<Bytes>>, ApiError> {
let host = request
.headers()
@@ -310,9 +317,15 @@ async fn request_handler(
ws_connections.spawn(
async move {
if let Err(e) =
websocket::serve_websocket(config, ctx, websocket, cancellation_handler, host)
.await
if let Err(e) = websocket::serve_websocket(
config,
ctx,
websocket,
cancellation_handler,
endpoint_rate_limiter,
host,
)
.await
{
error!("error in websocket connection: {e:#}");
}

View File

@@ -16,6 +16,7 @@ use crate::{
context::RequestMonitoring,
error::{ErrorKind, ReportableError, UserFacingError},
proxy::{connect_compute::ConnectMechanism, retry::ShouldRetry},
rate_limiter::EndpointRateLimiter,
Host,
};
@@ -24,6 +25,7 @@ use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
pub struct PoolingBackend {
pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub config: &'static ProxyConfig,
pub endpoint_rate_limiter: Arc<EndpointRateLimiter>,
}
impl PoolingBackend {
@@ -39,6 +41,12 @@ impl PoolingBackend {
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr));
}
if !self
.endpoint_rate_limiter
.check(conn_info.user_info.endpoint.clone().into(), 1)
{
return Err(AuthError::too_many_connections());
}
let cached_secret = match maybe_secret {
Some(secret) => secret,
None => backend.get_role_secret(ctx).await?,

View File

@@ -5,6 +5,7 @@ use crate::{
error::{io_error, ReportableError},
metrics::Metrics,
proxy::{handle_client, ClientMode},
rate_limiter::EndpointRateLimiter,
};
use bytes::{Buf, Bytes};
use futures::{Sink, Stream};
@@ -134,6 +135,7 @@ pub async fn serve_websocket(
mut ctx: RequestMonitoring,
websocket: HyperWebsocket,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
hostname: Option<String>,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
@@ -148,6 +150,7 @@ pub async fn serve_websocket(
cancellation_handler,
WebSocketRw::new(websocket),
ClientMode::Websockets { hostname },
endpoint_rate_limiter,
conn_gauge,
)
.await;