proxy: move endpoint rate limiter (#7413)

## Problem

## Summary of changes

Rate limit for wake_compute calls
This commit is contained in:
Conrad Ludgate
2024-04-18 06:00:33 +01:00
committed by GitHub
parent d5708e7435
commit a54ea8fb1c
9 changed files with 39 additions and 56 deletions

View File

@@ -331,7 +331,6 @@ async fn main() -> anyhow::Result<()> {
let proxy_listener = TcpListener::bind(proxy_address).await?;
let cancellation_token = CancellationToken::new();
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit));
let cancel_map = CancelMap::default();
let redis_publisher = match &regional_redis_client {
@@ -357,7 +356,6 @@ async fn main() -> anyhow::Result<()> {
config,
proxy_listener,
cancellation_token.clone(),
endpoint_rate_limiter.clone(),
cancellation_handler.clone(),
));
@@ -372,7 +370,6 @@ async fn main() -> anyhow::Result<()> {
config,
serverless_listener,
cancellation_token.clone(),
endpoint_rate_limiter.clone(),
cancellation_handler.clone(),
));
}
@@ -533,7 +530,11 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let url = args.auth_endpoint.parse()?;
let endpoint = http::Endpoint::new(url, http::new_client());
let api = console::provider::neon::Api::new(endpoint, caches, locks);
let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
RateBucketInfo::validate(&mut endpoint_rps_limit)?;
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(endpoint_rps_limit));
let api =
console::provider::neon::Api::new(endpoint, caches, locks, endpoint_rate_limiter);
let api = console::provider::ConsoleBackend::Console(api);
auth::BackendType::Console(MaybeOwned::Owned(api), ())
}
@@ -567,8 +568,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
};
let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
RateBucketInfo::validate(&mut endpoint_rps_limit)?;
let mut redis_rps_limit = args.redis_rps_limit.clone();
RateBucketInfo::validate(&mut redis_rps_limit)?;
@@ -581,7 +580,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
authentication_config,
require_client_ip: args.require_client_ip,
disable_ip_check_for_http: args.disable_ip_check_for_http,
endpoint_rps_limit,
redis_rps_limit,
handshake_timeout: args.handshake_timeout,
region: args.region.clone(),

View File

@@ -29,7 +29,6 @@ pub struct ProxyConfig {
pub authentication_config: AuthenticationConfig,
pub require_client_ip: bool,
pub disable_ip_check_for_http: bool,
pub endpoint_rps_limit: Vec<RateBucketInfo>,
pub redis_rps_limit: Vec<RateBucketInfo>,
pub region: String,
pub handshake_timeout: Duration,

View File

@@ -208,6 +208,9 @@ pub mod errors {
#[error(transparent)]
ApiError(ApiError),
#[error("Too many connections attempts")]
TooManyConnections,
#[error("Timeout waiting to acquire wake compute lock")]
TimeoutError,
}
@@ -240,6 +243,8 @@ pub mod errors {
// However, API might return a meaningful error.
ApiError(e) => e.to_string_client(),
TooManyConnections => self.to_string(),
TimeoutError => "timeout while acquiring the compute resource lock".to_owned(),
}
}
@@ -250,6 +255,7 @@ pub mod errors {
match self {
WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
WakeComputeError::ApiError(e) => e.get_error_kind(),
WakeComputeError::TooManyConnections => crate::error::ErrorKind::RateLimit,
WakeComputeError::TimeoutError => crate::error::ErrorKind::ServiceRateLimit,
}
}

View File

@@ -12,6 +12,7 @@ use crate::{
console::messages::ColdStartInfo,
http,
metrics::{CacheOutcome, Metrics},
rate_limiter::EndpointRateLimiter,
scram, Normalize,
};
use crate::{cache::Cached, context::RequestMonitoring};
@@ -25,6 +26,7 @@ pub struct Api {
endpoint: http::Endpoint,
pub caches: &'static ApiCaches,
pub locks: &'static ApiLocks,
pub endpoint_rate_limiter: Arc<EndpointRateLimiter>,
jwt: String,
}
@@ -34,6 +36,7 @@ impl Api {
endpoint: http::Endpoint,
caches: &'static ApiCaches,
locks: &'static ApiLocks,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Self {
let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") {
Ok(v) => v,
@@ -43,6 +46,7 @@ impl Api {
endpoint,
caches,
locks,
endpoint_rate_limiter,
jwt,
}
}
@@ -277,6 +281,14 @@ impl super::Api for Api {
return Ok(cached);
}
// check rate limit
if !self
.endpoint_rate_limiter
.check(user_info.endpoint.normalize().into(), 1)
{
return Err(WakeComputeError::TooManyConnections);
}
let permit = self.locks.get_wake_compute_permit(&key).await?;
// after getting back a permit - it's possible the cache was filled

View File

@@ -19,9 +19,8 @@ use crate::{
metrics::{Metrics, NumClientConnectionsGuard},
protocol2::WithClientIp,
proxy::handshake::{handshake, HandshakeData},
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
EndpointCacheKey, Normalize,
EndpointCacheKey,
};
use futures::TryFutureExt;
use itertools::Itertools;
@@ -61,7 +60,6 @@ pub async fn task_main(
config: &'static ProxyConfig,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_handler: Arc<CancellationHandlerMain>,
) -> anyhow::Result<()> {
scopeguard::defer! {
@@ -86,7 +84,6 @@ pub async fn task_main(
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
tracing::info!(protocol = "tcp", %session_id, "accepted new TCP connection");
@@ -128,7 +125,6 @@ pub async fn task_main(
cancellation_handler,
socket,
ClientMode::Tcp,
endpoint_rate_limiter,
conn_gauge,
)
.instrument(span.clone())
@@ -242,7 +238,6 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
cancellation_handler: Arc<CancellationHandlerMain>,
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
info!(
@@ -288,15 +283,6 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
Err(e) => stream.throw_error(e).await?,
};
// check rate limit
if let Some(ep) = user_info.get_endpoint() {
if !endpoint_rate_limiter.check(ep.normalize(), 1) {
return stream
.throw_error(auth::AuthError::too_many_connections())
.await?;
}
}
let user = user_info.get_user().to_owned();
let user_info = match user_info
.authenticate(

View File

@@ -90,6 +90,7 @@ fn report_error(e: &WakeComputeError, retry: bool) {
WakeComputeError::ApiError(ApiError::Console { .. }) => {
WakeupFailureKind::ApiConsoleOtherError
}
WakeComputeError::TooManyConnections => WakeupFailureKind::ApiConsoleLocked,
WakeComputeError::TimeoutError => WakeupFailureKind::TimeoutError,
};
Metrics::get()

View File

@@ -15,7 +15,7 @@ use rand::{rngs::StdRng, Rng, SeedableRng};
use tokio::time::{Duration, Instant};
use tracing::info;
use crate::EndpointId;
use crate::intern::EndpointIdInt;
pub struct GlobalRateLimiter {
data: Vec<RateBucket>,
@@ -61,12 +61,7 @@ impl GlobalRateLimiter {
// Purposefully ignore user name and database name as clients can reconnect
// with different names, so we'll end up sending some http requests to
// the control plane.
//
// We also may save quite a lot of CPU (I think) by bailing out right after we
// saw SNI, before doing TLS handshake. User-side error messages in that case
// does not look very nice (`SSL SYSCALL error: Undefined error: 0`), so for now
// I went with a more expensive way that yields user-friendlier error messages.
pub type EndpointRateLimiter = BucketRateLimiter<EndpointId, StdRng, RandomState>;
pub type EndpointRateLimiter = BucketRateLimiter<EndpointIdInt, StdRng, RandomState>;
pub struct BucketRateLimiter<Key, Rand = StdRng, Hasher = RandomState> {
map: DashMap<Key, Vec<RateBucket>, Hasher>,
@@ -245,7 +240,7 @@ mod tests {
use tokio::time;
use super::{BucketRateLimiter, EndpointRateLimiter};
use crate::{rate_limiter::RateBucketInfo, EndpointId};
use crate::{intern::EndpointIdInt, rate_limiter::RateBucketInfo, EndpointId};
#[test]
fn rate_bucket_rpi() {
@@ -295,39 +290,40 @@ mod tests {
let limiter = EndpointRateLimiter::new(rates);
let endpoint = EndpointId::from("ep-my-endpoint-1234");
let endpoint = EndpointIdInt::from(endpoint);
time::pause();
for _ in 0..100 {
assert!(limiter.check(endpoint.clone(), 1));
assert!(limiter.check(endpoint, 1));
}
// more connections fail
assert!(!limiter.check(endpoint.clone(), 1));
assert!(!limiter.check(endpoint, 1));
// fail even after 500ms as it's in the same bucket
time::advance(time::Duration::from_millis(500)).await;
assert!(!limiter.check(endpoint.clone(), 1));
assert!(!limiter.check(endpoint, 1));
// after a full 1s, 100 requests are allowed again
time::advance(time::Duration::from_millis(500)).await;
for _ in 1..6 {
for _ in 0..50 {
assert!(limiter.check(endpoint.clone(), 2));
assert!(limiter.check(endpoint, 2));
}
time::advance(time::Duration::from_millis(1000)).await;
}
// more connections after 600 will exceed the 20rps@30s limit
assert!(!limiter.check(endpoint.clone(), 1));
assert!(!limiter.check(endpoint, 1));
// will still fail before the 30 second limit
time::advance(time::Duration::from_millis(30_000 - 6_000 - 1)).await;
assert!(!limiter.check(endpoint.clone(), 1));
assert!(!limiter.check(endpoint, 1));
// after the full 30 seconds, 100 requests are allowed again
time::advance(time::Duration::from_millis(1)).await;
for _ in 0..100 {
assert!(limiter.check(endpoint.clone(), 1));
assert!(limiter.check(endpoint, 1));
}
}

View File

@@ -35,7 +35,6 @@ use crate::context::RequestMonitoring;
use crate::metrics::Metrics;
use crate::protocol2::WithClientIp;
use crate::proxy::run_until_cancelled;
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_util::{api_error_into_response, json_response};
@@ -53,7 +52,6 @@ pub async fn task_main(
config: &'static ProxyConfig,
ws_listener: TcpListener,
cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_handler: Arc<CancellationHandlerMain>,
) -> anyhow::Result<()> {
scopeguard::defer! {
@@ -117,7 +115,6 @@ pub async fn task_main(
backend.clone(),
connections.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
cancellation_token.clone(),
server.clone(),
tls_acceptor.clone(),
@@ -147,7 +144,6 @@ async fn connection_handler(
backend: Arc<PoolingBackend>,
connections: TaskTracker,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
server: Builder<TokioExecutor>,
tls_acceptor: TlsAcceptor,
@@ -231,7 +227,6 @@ async fn connection_handler(
cancellation_handler.clone(),
session_id,
peer_addr,
endpoint_rate_limiter.clone(),
http_request_token,
)
.in_current_span()
@@ -270,7 +265,6 @@ async fn request_handler(
cancellation_handler: Arc<CancellationHandlerMain>,
session_id: uuid::Uuid,
peer_addr: IpAddr,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
) -> Result<Response<Full<Bytes>>, ApiError> {
@@ -298,15 +292,9 @@ async fn request_handler(
ws_connections.spawn(
async move {
if let Err(e) = websocket::serve_websocket(
config,
ctx,
websocket,
cancellation_handler,
host,
endpoint_rate_limiter,
)
.await
if let Err(e) =
websocket::serve_websocket(config, ctx, websocket, cancellation_handler, host)
.await
{
error!("error in websocket connection: {e:#}");
}

View File

@@ -5,7 +5,6 @@ use crate::{
error::{io_error, ReportableError},
metrics::Metrics,
proxy::{handle_client, ClientMode},
rate_limiter::EndpointRateLimiter,
};
use bytes::{Buf, Bytes};
use futures::{Sink, Stream};
@@ -136,7 +135,6 @@ pub async fn serve_websocket(
websocket: HyperWebsocket,
cancellation_handler: Arc<CancellationHandlerMain>,
hostname: Option<String>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
let conn_gauge = Metrics::get()
@@ -150,7 +148,6 @@ pub async fn serve_websocket(
cancellation_handler,
WebSocketRw::new(websocket),
ClientMode::Websockets { hostname },
endpoint_rate_limiter,
conn_gauge,
)
.await;