Mirror of https://github.com/neondatabase/neon.git, synced 2026-01-10 15:02:56 +00:00.
This reverts commit dbac2d2c47.

## Problem

Proxy pods fail to install in k8s clusters, which blocks the cplane release.

## Summary of changes

Revert the commit.
@@ -27,7 +27,7 @@ use crate::{
},
stream, url,
};
use crate::{scram, EndpointCacheKey, EndpointId, Normalize, RoleName};
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};

@@ -186,7 +186,7 @@ impl AuthenticationConfig {
is_cleartext: bool,
) -> auth::Result<AuthSecret> {
// we have validated the endpoint exists, so let's intern it.
let endpoint_int = EndpointIdInt::from(endpoint.normalize());
let endpoint_int = EndpointIdInt::from(endpoint);

// only count the full hash count if password hack or websocket flow.
// in other words, if proxy needs to run the hashing

@@ -189,9 +189,7 @@ struct ProxyCliArgs {
/// cache for `project_info` (use `size=0` to disable)
#[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)]
project_info_cache: String,
/// cache for all valid endpoints
#[clap(long, default_value = config::EndpointCacheConfig::CACHE_DEFAULT_OPTIONS)]
endpoint_cache_config: String,

#[clap(flatten)]
parquet_upload: ParquetUploadArgs,

@@ -403,7 +401,6 @@ async fn main() -> anyhow::Result<()> {

if let auth::BackendType::Console(api, _) = &config.auth_backend {
if let proxy::console::provider::ConsoleBackend::Console(api) = &**api {
maintenance_tasks.spawn(api.locks.garbage_collect_worker());
if let Some(redis_notifications_client) = redis_notifications_client {
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(

@@ -413,9 +410,6 @@ async fn main() -> anyhow::Result<()> {
args.region.clone(),
));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
let cache = api.caches.endpoints_cache.clone();
let con = redis_notifications_client.clone();
maintenance_tasks.spawn(async move { cache.do_read(con).await });
}
}
}

@@ -495,18 +489,14 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
let project_info_cache_config: ProjectInfoCacheOptions =
args.project_info_cache.parse()?;
let endpoint_cache_config: config::EndpointCacheConfig =
args.endpoint_cache_config.parse()?;

info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
info!(
"Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
);
info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}");
let caches = Box::leak(Box::new(console::caches::ApiCaches::new(
wake_compute_cache_config,
project_info_cache_config,
endpoint_cache_config,
)));

let config::WakeComputeLockOptions {

@@ -517,9 +507,10 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
} = args.wake_compute_lock.parse()?;
info!(permits, shards, ?epoch, "Using NodeLocks (wake_compute)");
let locks = Box::leak(Box::new(
console::locks::ApiLocks::new("wake_compute_lock", permits, shards, timeout, epoch)
console::locks::ApiLocks::new("wake_compute_lock", permits, shards, timeout)
.unwrap(),
));
tokio::spawn(locks.garbage_collect_worker(epoch));

let url = args.auth_endpoint.parse()?;
let endpoint = http::Endpoint::new(url, http::new_client(rate_limiter_config));

@@ -1,5 +1,4 @@
pub mod common;
pub mod endpoints;
pub mod project_info;
mod timed_lru;
proxy/src/cache/endpoints.rs (190 lines deleted)

@@ -1,190 +0,0 @@
use std::{
convert::Infallible,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};

use dashmap::DashSet;
use redis::{
streams::{StreamReadOptions, StreamReadReply},
AsyncCommands, FromRedisValue, Value,
};
use serde::Deserialize;
use tokio::sync::Mutex;

use crate::{
config::EndpointCacheConfig,
context::RequestMonitoring,
intern::{BranchIdInt, EndpointIdInt, ProjectIdInt},
metrics::REDIS_BROKEN_MESSAGES,
rate_limiter::GlobalRateLimiter,
redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider,
EndpointId,
};

#[derive(Deserialize, Debug, Clone)]
#[serde(rename_all(deserialize = "snake_case"))]
pub enum ControlPlaneEventKey {
EndpointCreated,
BranchCreated,
ProjectCreated,
}

pub struct EndpointsCache {
config: EndpointCacheConfig,
endpoints: DashSet<EndpointIdInt>,
branches: DashSet<BranchIdInt>,
projects: DashSet<ProjectIdInt>,
ready: AtomicBool,
limiter: Arc<Mutex<GlobalRateLimiter>>,
}

impl EndpointsCache {
pub fn new(config: EndpointCacheConfig) -> Self {
Self {
limiter: Arc::new(Mutex::new(GlobalRateLimiter::new(
config.limiter_info.clone(),
))),
config,
endpoints: DashSet::new(),
branches: DashSet::new(),
projects: DashSet::new(),
ready: AtomicBool::new(false),
}
}
pub async fn is_valid(&self, ctx: &mut RequestMonitoring, endpoint: &EndpointId) -> bool {
if !self.ready.load(Ordering::Acquire) {
return true;
}
// If cache is disabled, just collect the metrics and return.
if self.config.disable_cache {
ctx.set_rejected(self.should_reject(endpoint));
return true;
}
// If the limiter allows, we don't need to check the cache.
if self.limiter.lock().await.check() {
return true;
}
let rejected = self.should_reject(endpoint);
ctx.set_rejected(rejected);
!rejected
}
fn should_reject(&self, endpoint: &EndpointId) -> bool {
if endpoint.is_endpoint() {
!self.endpoints.contains(&EndpointIdInt::from(endpoint))
} else if endpoint.is_branch() {
!self
.branches
.contains(&BranchIdInt::from(&endpoint.as_branch()))
} else {
!self
.projects
.contains(&ProjectIdInt::from(&endpoint.as_project()))
}
}
fn insert_event(&self, key: ControlPlaneEventKey, value: String) {
// Do not do normalization here, we expect the events to be normalized.
match key {
ControlPlaneEventKey::EndpointCreated => {
self.endpoints.insert(EndpointIdInt::from(&value.into()));
}
ControlPlaneEventKey::BranchCreated => {
self.branches.insert(BranchIdInt::from(&value.into()));
}
ControlPlaneEventKey::ProjectCreated => {
self.projects.insert(ProjectIdInt::from(&value.into()));
}
}
}
pub async fn do_read(
&self,
mut con: ConnectionWithCredentialsProvider,
) -> anyhow::Result<Infallible> {
let mut last_id = "0-0".to_string();
loop {
self.ready.store(false, Ordering::Release);
if let Err(e) = con.connect().await {
tracing::error!("error connecting to redis: {:?}", e);
continue;
}
if let Err(e) = self.read_from_stream(&mut con, &mut last_id).await {
tracing::error!("error reading from redis: {:?}", e);
}
}
}
async fn read_from_stream(
&self,
con: &mut ConnectionWithCredentialsProvider,
last_id: &mut String,
) -> anyhow::Result<()> {
tracing::info!("reading endpoints/branches/projects from redis");
self.batch_read(
con,
StreamReadOptions::default().count(self.config.initial_batch_size),
last_id,
true,
)
.await?;
tracing::info!("ready to filter user requests");
self.ready.store(true, Ordering::Release);
self.batch_read(
con,
StreamReadOptions::default()
.count(self.config.initial_batch_size)
.block(self.config.xread_timeout.as_millis() as usize),
last_id,
false,
)
.await
}
fn parse_key_value(key: &str, value: &Value) -> anyhow::Result<(ControlPlaneEventKey, String)> {
Ok((serde_json::from_str(key)?, String::from_redis_value(value)?))
}
async fn batch_read(
&self,
conn: &mut ConnectionWithCredentialsProvider,
opts: StreamReadOptions,
last_id: &mut String,
return_when_finish: bool,
) -> anyhow::Result<()> {
let mut total: usize = 0;
loop {
let mut res: StreamReadReply = conn
.xread_options(&[&self.config.stream_name], &[last_id.as_str()], &opts)
.await?;
if res.keys.len() != 1 {
anyhow::bail!("Cannot read from redis stream {}", self.config.stream_name);
}

let res = res.keys.pop().expect("Checked length above");

if return_when_finish && res.ids.len() <= self.config.default_batch_size {
break;
}
for x in res.ids {
total += 1;
for (k, v) in x.map {
let (key, value) = match Self::parse_key_value(&k, &v) {
Ok(x) => x,
Err(e) => {
REDIS_BROKEN_MESSAGES
.with_label_values(&[&self.config.stream_name])
.inc();
tracing::error!("error parsing key-value {k}-{v:?}: {e:?}");
continue;
}
};
self.insert_event(key, value);
}
if total.is_power_of_two() {
tracing::debug!("endpoints read {}", total);
}
*last_id = x.id;
}
}
tracing::info!("read {} endpoints/branches/projects from redis", total);
Ok(())
}
}
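For orientation, the admission logic in the deleted `is_valid`/`should_reject` pair is small enough to sketch standalone. Below is a minimal sketch using only std types; the `MiniEndpointsCache` name and the endpoint strings are hypothetical, and the rate-limiter shortcut from the real `is_valid` is omitted:

```rust
use std::collections::HashSet;
use std::sync::atomic::{AtomicBool, Ordering};

// Hypothetical stand-in for the deleted EndpointsCache admission logic.
struct MiniEndpointsCache {
    ready: AtomicBool,   // set once the initial Redis batch has been read
    disable_cache: bool, // report metrics only, never reject
    known_endpoints: HashSet<String>,
}

impl MiniEndpointsCache {
    fn should_reject(&self, endpoint: &str) -> bool {
        !self.known_endpoints.contains(endpoint)
    }

    // Mirrors the order of checks in the deleted `is_valid`:
    // not ready -> allow; cache disabled -> allow (but record); otherwise reject unknowns.
    fn is_valid(&self, endpoint: &str) -> bool {
        if !self.ready.load(Ordering::Acquire) {
            return true;
        }
        if self.disable_cache {
            let _rejected = self.should_reject(endpoint); // would feed ctx.set_rejected(..)
            return true;
        }
        !self.should_reject(endpoint)
    }
}

fn main() {
    let mut known = HashSet::new();
    known.insert("ep-known".to_string());
    let cache = MiniEndpointsCache {
        ready: AtomicBool::new(true),
        disable_cache: false,
        known_endpoints: known,
    };
    assert!(cache.is_valid("ep-known"));
    assert!(!cache.is_valid("ep-unknown"));
}
```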
@@ -313,75 +313,6 @@ impl CertResolver {
}
}

#[derive(Debug)]
pub struct EndpointCacheConfig {
/// Batch size to receive all endpoints on the startup.
pub initial_batch_size: usize,
/// Batch size to receive endpoints.
pub default_batch_size: usize,
/// Timeouts for the stream read operation.
pub xread_timeout: Duration,
/// Stream name to read from.
pub stream_name: String,
/// Limiter info (to distinguish when to enable cache).
pub limiter_info: Vec<RateBucketInfo>,
/// Disable cache.
/// If true, cache is ignored, but reports all statistics.
pub disable_cache: bool,
}

impl EndpointCacheConfig {
/// Default options for [`crate::console::provider::NodeInfoCache`].
/// Notice that by default the limiter is empty, which means that cache is disabled.
pub const CACHE_DEFAULT_OPTIONS: &'static str =
"initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s";

/// Parse cache options passed via cmdline.
/// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
fn parse(options: &str) -> anyhow::Result<Self> {
let mut initial_batch_size = None;
let mut default_batch_size = None;
let mut xread_timeout = None;
let mut stream_name = None;
let mut limiter_info = vec![];
let mut disable_cache = false;

for option in options.split(',') {
let (key, value) = option
.split_once('=')
.with_context(|| format!("bad key-value pair: {option}"))?;

match key {
"initial_batch_size" => initial_batch_size = Some(value.parse()?),
"default_batch_size" => default_batch_size = Some(value.parse()?),
"xread_timeout" => xread_timeout = Some(humantime::parse_duration(value)?),
"stream_name" => stream_name = Some(value.to_string()),
"limiter_info" => limiter_info.push(RateBucketInfo::from_str(value)?),
"disable_cache" => disable_cache = value.parse()?,
unknown => bail!("unknown key: {unknown}"),
}
}
RateBucketInfo::validate(&mut limiter_info)?;

Ok(Self {
initial_batch_size: initial_batch_size.context("missing `initial_batch_size`")?,
default_batch_size: default_batch_size.context("missing `default_batch_size`")?,
xread_timeout: xread_timeout.context("missing `xread_timeout`")?,
stream_name: stream_name.context("missing `stream_name`")?,
disable_cache,
limiter_info,
})
}
}

impl FromStr for EndpointCacheConfig {
type Err = anyhow::Error;

fn from_str(options: &str) -> Result<Self, Self::Err> {
let error = || format!("failed to parse endpoint cache options '{options}'");
Self::parse(options).with_context(error)
}
}
#[derive(Debug)]
pub struct MetricBackupCollectionConfig {
pub interval: Duration,
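The `CACHE_DEFAULT_OPTIONS` string above uses the same `key=value,...` format as the other cache options. A rough standalone sketch of how such a string breaks down; the `MiniEndpointCacheConfig` type and its parser are illustrative stand-ins, not the crate's `EndpointCacheConfig::parse`:

```rust
#[derive(Debug, Default)]
struct MiniEndpointCacheConfig {
    initial_batch_size: Option<usize>,
    default_batch_size: Option<usize>,
    xread_timeout: Option<String>, // the real parser uses humantime for values like "5m"
    stream_name: Option<String>,
    disable_cache: bool,
    limiter_info: Vec<String>, // the real type is RateBucketInfo, e.g. "1000@1s"
}

fn parse(options: &str) -> Result<MiniEndpointCacheConfig, String> {
    let mut cfg = MiniEndpointCacheConfig::default();
    for option in options.split(',') {
        let (key, value) = option
            .split_once('=')
            .ok_or_else(|| format!("bad key-value pair: {option}"))?;
        match key {
            "initial_batch_size" => cfg.initial_batch_size = value.parse().ok(),
            "default_batch_size" => cfg.default_batch_size = value.parse().ok(),
            "xread_timeout" => cfg.xread_timeout = Some(value.to_string()),
            "stream_name" => cfg.stream_name = Some(value.to_string()),
            "limiter_info" => cfg.limiter_info.push(value.to_string()),
            "disable_cache" => cfg.disable_cache = value.parse().map_err(|e| format!("{e}"))?,
            unknown => return Err(format!("unknown key: {unknown}")),
        }
    }
    Ok(cfg)
}

fn main() {
    // The default options string from the diff above.
    let cfg = parse(
        "initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,\
         stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s",
    )
    .unwrap();
    println!("{cfg:?}");
}
```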
@@ -8,15 +8,15 @@ use crate::{
backend::{ComputeCredentialKeys, ComputeUserInfo},
IpPattern,
},
cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
compute,
config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions},
config::{CacheOptions, ProjectInfoCacheOptions},
context::RequestMonitoring,
intern::ProjectIdInt,
scram, EndpointCacheKey,
};
use dashmap::DashMap;
use std::{convert::Infallible, sync::Arc, time::Duration};
use std::{sync::Arc, time::Duration};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio::time::Instant;
use tracing::info;

@@ -416,15 +416,12 @@ pub struct ApiCaches {
pub node_info: NodeInfoCache,
/// Cache which stores project_id -> endpoint_ids mapping.
pub project_info: Arc<ProjectInfoCacheImpl>,
/// List of all valid endpoints.
pub endpoints_cache: Arc<EndpointsCache>,
}

impl ApiCaches {
pub fn new(
wake_compute_cache_config: CacheOptions,
project_info_cache_config: ProjectInfoCacheOptions,
endpoint_cache_config: EndpointCacheConfig,
) -> Self {
Self {
node_info: NodeInfoCache::new(

@@ -434,7 +431,6 @@ impl ApiCaches {
true,
),
project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)),
}
}
}

@@ -445,7 +441,6 @@ pub struct ApiLocks {
node_locks: DashMap<EndpointCacheKey, Arc<Semaphore>>,
permits: usize,
timeout: Duration,
epoch: std::time::Duration,
registered: prometheus::IntCounter,
unregistered: prometheus::IntCounter,
reclamation_lag: prometheus::Histogram,

@@ -458,7 +453,6 @@ impl ApiLocks {
permits: usize,
shards: usize,
timeout: Duration,
epoch: std::time::Duration,
) -> prometheus::Result<Self> {
let registered = prometheus::IntCounter::with_opts(
prometheus::Opts::new(

@@ -503,7 +497,6 @@ impl ApiLocks {
node_locks: DashMap::with_shard_amount(shards),
permits,
timeout,
epoch,
lock_acquire_lag,
registered,
unregistered,

@@ -543,9 +536,12 @@ impl ApiLocks {
})
}

pub async fn garbage_collect_worker(&self) -> anyhow::Result<Infallible> {
let mut interval =
tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
pub async fn garbage_collect_worker(&self, epoch: std::time::Duration) {
if self.permits == 0 {
return;
}

let mut interval = tokio::time::interval(epoch / (self.node_locks.shards().len()) as u32);
loop {
for (i, shard) in self.node_locks.shards().iter().enumerate() {
interval.tick().await;
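Both versions of `garbage_collect_worker` pace one full sweep of the shard map over `epoch`, ticking once per shard. A simplified, self-contained sketch of that pacing; it uses plain `Vec<Mutex<HashMap>>` shards instead of the crate's `DashMap`, and does a single pass where the real worker loops forever:

```rust
use std::collections::HashMap;
use std::sync::Mutex;
use std::time::Duration;

#[tokio::main]
async fn main() {
    // Hypothetical shard layout; the real worker iterates dashmap shards.
    let shards: Vec<Mutex<HashMap<String, usize>>> =
        (0..8).map(|_| Mutex::new(HashMap::new())).collect();
    let epoch = Duration::from_secs(8);

    // One full pass over all shards takes roughly `epoch`; the real worker
    // wraps this in `loop { .. }` so sweeps repeat forever.
    let mut interval = tokio::time::interval(epoch / shards.len() as u32);
    for (i, shard) in shards.iter().enumerate() {
        interval.tick().await;
        let mut map = shard.lock().unwrap();
        // Drop idle entries; the real code reclaims semaphores with no permits handed out.
        map.retain(|_, in_use| *in_use > 0);
        println!("swept shard {i} ({} entries kept)", map.len());
    }
}
```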
@@ -8,7 +8,6 @@ use super::{
};
use crate::{
auth::backend::ComputeUserInfo, compute, console::messages::ColdStartInfo, http, scram,
Normalize,
};
use crate::{
cache::Cached,

@@ -24,7 +23,7 @@ use tracing::{error, info, info_span, warn, Instrument};
pub struct Api {
endpoint: http::Endpoint,
pub caches: &'static ApiCaches,
pub locks: &'static ApiLocks,
locks: &'static ApiLocks,
jwt: String,
}

@@ -56,15 +55,6 @@ impl Api {
ctx: &mut RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<AuthInfo, GetAuthInfoError> {
if !self
.caches
.endpoints_cache
.is_valid(ctx, &user_info.endpoint.normalize())
.await
{
info!("endpoint is not valid, skipping the request");
return Ok(AuthInfo::default());
}
let request_id = ctx.session_id.to_string();
let application_name = ctx.console_application_name();
async {

@@ -91,9 +81,7 @@ impl Api {
Ok(body) => body,
// Error 404 is special: it's ok not to have a secret.
Err(e) => match e.http_status_code() {
Some(http::StatusCode::NOT_FOUND) => {
return Ok(AuthInfo::default());
}
Some(http::StatusCode::NOT_FOUND) => return Ok(AuthInfo::default()),
_otherwise => return Err(e.into()),
},
};

@@ -186,27 +174,23 @@ impl super::Api for Api {
ctx: &mut RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
let normalized_ep = &user_info.endpoint.normalize();
let ep = &user_info.endpoint;
let user = &user_info.user;
if let Some(role_secret) = self
.caches
.project_info
.get_role_secret(normalized_ep, user)
{
if let Some(role_secret) = self.caches.project_info.get_role_secret(ep, user) {
return Ok(role_secret);
}
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
if let Some(project_id) = auth_info.project_id {
let normalized_ep_int = normalized_ep.into();
let ep_int = ep.into();
self.caches.project_info.insert_role_secret(
project_id,
normalized_ep_int,
ep_int,
user.into(),
auth_info.secret.clone(),
);
self.caches.project_info.insert_allowed_ips(
project_id,
normalized_ep_int,
ep_int,
Arc::new(auth_info.allowed_ips),
);
ctx.set_project_id(project_id);

@@ -220,8 +204,8 @@ impl super::Api for Api {
ctx: &mut RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
let normalized_ep = &user_info.endpoint.normalize();
if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) {
let ep = &user_info.endpoint;
if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(ep) {
ALLOWED_IPS_BY_CACHE_OUTCOME
.with_label_values(&["hit"])
.inc();

@@ -234,18 +218,16 @@ impl super::Api for Api {
let allowed_ips = Arc::new(auth_info.allowed_ips);
let user = &user_info.user;
if let Some(project_id) = auth_info.project_id {
let normalized_ep_int = normalized_ep.into();
let ep_int = ep.into();
self.caches.project_info.insert_role_secret(
project_id,
normalized_ep_int,
ep_int,
user.into(),
auth_info.secret.clone(),
);
self.caches.project_info.insert_allowed_ips(
project_id,
normalized_ep_int,
allowed_ips.clone(),
);
self.caches
.project_info
.insert_allowed_ips(project_id, ep_int, allowed_ips.clone());
ctx.set_project_id(project_id);
}
Ok((

@@ -12,9 +12,7 @@ use crate::{
console::messages::{ColdStartInfo, MetricsAuxInfo},
error::ErrorKind,
intern::{BranchIdInt, ProjectIdInt},
metrics::{
bool_to_str, LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND, NUM_INVALID_ENDPOINTS,
},
metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND},
DbName, EndpointId, RoleName,
};

@@ -52,8 +50,6 @@ pub struct RequestMonitoring {
// This sender is here to keep the request monitoring channel open while requests are taking place.
sender: Option<mpsc::UnboundedSender<RequestData>>,
pub latency_timer: LatencyTimer,
// Whether proxy decided that it's not a valid endpoint end rejected it before going to cplane.
rejected: bool,
}

#[derive(Clone, Debug)]

@@ -97,7 +93,6 @@ impl RequestMonitoring {
error_kind: None,
auth_method: None,
success: false,
rejected: false,
cold_start_info: ColdStartInfo::Unknown,

sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),

@@ -118,10 +113,6 @@ impl RequestMonitoring {
)
}

pub fn set_rejected(&mut self, rejected: bool) {
self.rejected = rejected;
}

pub fn set_cold_start_info(&mut self, info: ColdStartInfo) {
self.cold_start_info = info;
self.latency_timer.cold_start_info(info);

@@ -187,10 +178,6 @@ impl RequestMonitoring {

impl Drop for RequestMonitoring {
fn drop(&mut self) {
let outcome = if self.success { "success" } else { "failure" };
NUM_INVALID_ENDPOINTS
.with_label_values(&[self.protocol, bool_to_str(self.rejected), outcome])
.inc();
if let Some(tx) = self.sender.take() {
let _: Result<(), _> = tx.send(RequestData::from(&*self));
}

@@ -160,11 +160,6 @@ impl From<&EndpointId> for EndpointIdInt {
EndpointIdTag::get_interner().get_or_intern(value)
}
}
impl From<EndpointId> for EndpointIdInt {
fn from(value: EndpointId) -> Self {
EndpointIdTag::get_interner().get_or_intern(&value)
}
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct BranchIdTag;

@@ -180,11 +175,6 @@ impl From<&BranchId> for BranchIdInt {
BranchIdTag::get_interner().get_or_intern(value)
}
}
impl From<BranchId> for BranchIdInt {
fn from(value: BranchId) -> Self {
BranchIdTag::get_interner().get_or_intern(&value)
}
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct ProjectIdTag;

@@ -200,11 +190,6 @@ impl From<&ProjectId> for ProjectIdInt {
ProjectIdTag::get_interner().get_or_intern(value)
}
}
impl From<ProjectId> for ProjectIdInt {
fn from(value: ProjectId) -> Self {
ProjectIdTag::get_interner().get_or_intern(&value)
}
}

#[cfg(test)]
mod tests {
@@ -127,24 +127,6 @@ macro_rules! smol_str_wrapper {
};
}

const POOLER_SUFFIX: &str = "-pooler";

pub trait Normalize {
fn normalize(&self) -> Self;
}

impl<S: Clone + AsRef<str> + From<String>> Normalize for S {
fn normalize(&self) -> Self {
if self.as_ref().ends_with(POOLER_SUFFIX) {
let mut s = self.as_ref().to_string();
s.truncate(s.len() - POOLER_SUFFIX.len());
s.into()
} else {
self.clone()
}
}
}

// 90% of role name strings are 20 characters or less.
smol_str_wrapper!(RoleName);
// 50% of endpoint strings are 23 characters or less.
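The removed `Normalize` trait only strips a trailing `-pooler` from endpoint-like names. The same behavior as a free function, in a self-contained sketch (the endpoint names in `main` are made up for illustration):

```rust
// Sketch of the removed Normalize behavior: names coming through the pooler
// carry a "-pooler" suffix that is stripped before cache lookups.
const POOLER_SUFFIX: &str = "-pooler";

fn normalize(name: &str) -> String {
    match name.strip_suffix(POOLER_SUFFIX) {
        Some(stripped) => stripped.to_string(),
        None => name.to_string(),
    }
}

fn main() {
    assert_eq!(normalize("ep-steep-bush-12345-pooler"), "ep-steep-bush-12345");
    assert_eq!(normalize("ep-steep-bush-12345"), "ep-steep-bush-12345");
}
```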
@@ -158,22 +140,3 @@ smol_str_wrapper!(ProjectId);
smol_str_wrapper!(EndpointCacheKey);

smol_str_wrapper!(DbName);

// Endpoints are a bit tricky. Rare they might be branches or projects.
impl EndpointId {
pub fn is_endpoint(&self) -> bool {
self.0.starts_with("ep-")
}
pub fn is_branch(&self) -> bool {
self.0.starts_with("br-")
}
pub fn is_project(&self) -> bool {
!self.is_endpoint() && !self.is_branch()
}
pub fn as_branch(&self) -> BranchId {
BranchId(self.0.clone())
}
pub fn as_project(&self) -> ProjectId {
ProjectId(self.0.clone())
}
}
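The removed `EndpointId` helpers classify an identifier purely by prefix. The same rule in standalone form (the example identifiers are hypothetical):

```rust
// Prefix-based classification from the removed EndpointId helpers:
// "ep-" => endpoint, "br-" => branch, anything else => project id.
#[derive(Debug, PartialEq)]
enum IdKind {
    Endpoint,
    Branch,
    Project,
}

fn classify(id: &str) -> IdKind {
    if id.starts_with("ep-") {
        IdKind::Endpoint
    } else if id.starts_with("br-") {
        IdKind::Branch
    } else {
        IdKind::Project
    }
}

fn main() {
    assert_eq!(classify("ep-quiet-sea-123"), IdKind::Endpoint);
    assert_eq!(classify("br-main-456"), IdKind::Branch);
    assert_eq!(classify("quiet-sea-123"), IdKind::Project);
}
```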
@@ -169,18 +169,6 @@ pub static NUM_CANCELLATION_REQUESTS: Lazy<IntCounterVec> = Lazy::new(|| {
.unwrap()
});

pub static NUM_INVALID_ENDPOINTS: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_invalid_endpoints_total",
"Number of invalid endpoints (per protocol, per rejected).",
// http/ws/tcp, true/false, success/failure
// TODO(anna): the last dimension is just a proxy to what we actually want to measure.
// We need to measure whether the endpoint was found by cplane or not.
&["protocol", "rejected", "outcome"],
)
.unwrap()
});

pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT: &str = "from_client";
pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS: &str = "from_redis";

@@ -20,7 +20,7 @@ use crate::{
proxy::handshake::{handshake, HandshakeData},
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
EndpointCacheKey, Normalize,
EndpointCacheKey,
};
use futures::TryFutureExt;
use itertools::Itertools;

@@ -280,7 +280,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(

// check rate limit
if let Some(ep) = user_info.get_endpoint() {
if !endpoint_rate_limiter.check(ep.normalize(), 1) {
if !endpoint_rate_limiter.check(ep, 1) {
return stream
.throw_error(auth::AuthError::too_many_connections())
.await?;

@@ -4,4 +4,4 @@ mod limiter;
pub use aimd::Aimd;
pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig};
pub use limiter::Limiter;
pub use limiter::{AuthRateLimiter, EndpointRateLimiter, GlobalRateLimiter, RateBucketInfo};
pub use limiter::{AuthRateLimiter, EndpointRateLimiter, RateBucketInfo, RedisRateLimiter};

@@ -24,13 +24,13 @@ use super::{
RateLimiterConfig,
};

pub struct GlobalRateLimiter {
pub struct RedisRateLimiter {
data: Vec<RateBucket>,
info: Vec<RateBucketInfo>,
info: &'static [RateBucketInfo],
}

impl GlobalRateLimiter {
pub fn new(info: Vec<RateBucketInfo>) -> Self {
impl RedisRateLimiter {
pub fn new(info: &'static [RateBucketInfo]) -> Self {
Self {
data: vec![
RateBucket {

@@ -50,7 +50,7 @@ impl GlobalRateLimiter {
let should_allow_request = self
.data
.iter_mut()
.zip(&self.info)
.zip(self.info)
.all(|(bucket, info)| bucket.should_allow_request(info, now, 1));

if should_allow_request {
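The `GlobalRateLimiter`/`RedisRateLimiter` rename above does not change the bucket logic: each `RateBucketInfo` (for example `1000@1s`) caps requests per interval, and a request passes only if every bucket still has room. A simplified standalone sketch of that check; the types and field names here are illustrative, not the crate's:

```rust
use std::time::{Duration, Instant};

// Simplified stand-in for RateBucketInfo, e.g. "1000@1s" => 1000 requests per second.
struct BucketInfo {
    max_rpi: u64,
    interval: Duration,
}

struct Bucket {
    start: Instant,
    count: u64,
}

struct MiniRateLimiter {
    info: Vec<BucketInfo>,
    data: Vec<Bucket>,
}

impl MiniRateLimiter {
    fn new(info: Vec<BucketInfo>) -> Self {
        let data = info
            .iter()
            .map(|_| Bucket { start: Instant::now(), count: 0 })
            .collect();
        Self { info, data }
    }

    // Allow the request only if every configured bucket still has room.
    fn check(&mut self) -> bool {
        let now = Instant::now();
        let allowed = self
            .data
            .iter_mut()
            .zip(&self.info)
            .all(|(bucket, info)| {
                // Reset a bucket once its interval has elapsed.
                if now.duration_since(bucket.start) >= info.interval {
                    bucket.start = now;
                    bucket.count = 0;
                }
                bucket.count < info.max_rpi
            });
        if allowed {
            for bucket in &mut self.data {
                bucket.count += 1;
            }
        }
        allowed
    }
}

fn main() {
    let mut limiter = MiniRateLimiter::new(vec![BucketInfo {
        max_rpi: 2,
        interval: Duration::from_secs(1),
    }]);
    assert!(limiter.check());
    assert!(limiter.check());
    assert!(!limiter.check()); // third request within the same second is rejected
}
```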
@@ -5,7 +5,7 @@ use redis::AsyncCommands;
use tokio::sync::Mutex;
use uuid::Uuid;

use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};
use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter};

use super::{
connection_with_credentials_provider::ConnectionWithCredentialsProvider,

@@ -80,7 +80,7 @@ impl<P: CancellationPublisherMut> CancellationPublisher for Arc<Mutex<P>> {
pub struct RedisPublisherClient {
client: ConnectionWithCredentialsProvider,
region_id: String,
limiter: GlobalRateLimiter,
limiter: RedisRateLimiter,
}

impl RedisPublisherClient {

@@ -92,7 +92,7 @@ impl RedisPublisherClient {
Ok(Self {
client,
region_id,
limiter: GlobalRateLimiter::new(info.into()),
limiter: RedisRateLimiter::new(info),
})
}
test_runner/regress/test_proxy_rate_limiter.py (84 lines added)

@@ -0,0 +1,84 @@
import asyncio
import time
from pathlib import Path
from typing import Iterator

import pytest
from fixtures.neon_fixtures import (
PSQL,
NeonProxy,
)
from fixtures.port_distributor import PortDistributor
from pytest_httpserver import HTTPServer
from werkzeug.wrappers.response import Response


def waiting_handler(status_code: int) -> Response:
# wait more than timeout to make sure that both (two) connections are open.
# It would be better to use a barrier here, but I don't know how to do that together with pytest-httpserver.
time.sleep(2)
return Response(status=status_code)


@pytest.fixture(scope="function")
def proxy_with_rate_limit(
port_distributor: PortDistributor,
neon_binpath: Path,
httpserver_listen_address,
test_output_dir: Path,
) -> Iterator[NeonProxy]:
"""Neon proxy that routes directly to vanilla postgres."""

proxy_port = port_distributor.get_port()
mgmt_port = port_distributor.get_port()
http_port = port_distributor.get_port()
external_http_port = port_distributor.get_port()
(host, port) = httpserver_listen_address
endpoint = f"http://{host}:{port}/billing/api/v1/usage_events"

with NeonProxy(
neon_binpath=neon_binpath,
test_output_dir=test_output_dir,
proxy_port=proxy_port,
http_port=http_port,
mgmt_port=mgmt_port,
external_http_port=external_http_port,
auth_backend=NeonProxy.Console(endpoint, fixed_rate_limit=5),
) as proxy:
proxy.start()
yield proxy


@pytest.mark.asyncio
async def test_proxy_rate_limit(
httpserver: HTTPServer,
proxy_with_rate_limit: NeonProxy,
):
uri = "/billing/api/v1/usage_events/proxy_get_role_secret"
# mock control plane service
httpserver.expect_ordered_request(uri, method="GET").respond_with_handler(
lambda _: Response(status=200)
)
httpserver.expect_ordered_request(uri, method="GET").respond_with_handler(
lambda _: waiting_handler(429)
)
httpserver.expect_ordered_request(uri, method="GET").respond_with_handler(
lambda _: waiting_handler(500)
)

psql = PSQL(host=proxy_with_rate_limit.host, port=proxy_with_rate_limit.proxy_port)
f = await psql.run("select 42;")
await proxy_with_rate_limit.find_auth_link(uri, f)
# Limit should be 2.

# Run two queries in parallel.
f1, f2 = await asyncio.gather(psql.run("select 42;"), psql.run("select 42;"))
await proxy_with_rate_limit.find_auth_link(uri, f1)
await proxy_with_rate_limit.find_auth_link(uri, f2)

# Now limit should be 0.
f = await psql.run("select 42;")
await proxy_with_rate_limit.find_auth_link(uri, f)

# There last query shouldn't reach the http-server.
assert httpserver.assertions == []