mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-14 17:02:56 +00:00
281 lines
9.2 KiB
Rust
281 lines
9.2 KiB
Rust
pub mod cplane_proxy_v1;
|
|
#[cfg(any(test, feature = "testing"))]
|
|
pub mod mock;
|
|
|
|
use std::hash::Hash;
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
|
|
use clashmap::ClashMap;
|
|
use tokio::time::Instant;
|
|
use tracing::{debug};
|
|
|
|
use super::{EndpointAccessControl, RoleAccessControl};
|
|
use crate::auth::backend::ComputeUserInfo;
|
|
use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
|
|
use crate::cache::endpoints::EndpointsCache;
|
|
use crate::cache::project_info::ProjectInfoCacheImpl;
|
|
use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions};
|
|
use crate::context::RequestContext;
|
|
use crate::control_plane::{CachedNodeInfo, ControlPlaneApi, NodeInfoCache, errors};
|
|
use crate::error::ReportableError;
|
|
use crate::metrics::ApiLockMetrics;
|
|
use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
|
|
use crate::types::EndpointId;
|
|
|
|
#[non_exhaustive]
|
|
#[derive(Clone)]
|
|
pub enum ControlPlaneClient {
|
|
/// Proxy V1 control plane API
|
|
ProxyV1(cplane_proxy_v1::NeonControlPlaneClient),
|
|
/// Local mock control plane.
|
|
#[cfg(any(test, feature = "testing"))]
|
|
PostgresMock(mock::MockControlPlane),
|
|
/// Internal testing
|
|
#[cfg(test)]
|
|
#[allow(private_interfaces)]
|
|
Test(Box<dyn TestControlPlaneClient>),
|
|
}
|
|
|
|
impl ControlPlaneApi for ControlPlaneClient {
|
|
async fn get_role_access_control(
|
|
&self,
|
|
ctx: &RequestContext,
|
|
endpoint: &EndpointId,
|
|
role: &crate::types::RoleName,
|
|
) -> Result<RoleAccessControl, errors::GetAuthInfoError> {
|
|
match self {
|
|
Self::ProxyV1(api) => api.get_role_access_control(ctx, endpoint, role).await,
|
|
#[cfg(any(test, feature = "testing"))]
|
|
Self::PostgresMock(api) => api.get_role_access_control(ctx, endpoint, role).await,
|
|
#[cfg(test)]
|
|
Self::Test(_api) => {
|
|
unreachable!("this function should never be called in the test backend")
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn get_endpoint_access_control(
|
|
&self,
|
|
ctx: &RequestContext,
|
|
endpoint: &EndpointId,
|
|
role: &crate::types::RoleName,
|
|
) -> Result<EndpointAccessControl, errors::GetAuthInfoError> {
|
|
match self {
|
|
Self::ProxyV1(api) => api.get_endpoint_access_control(ctx, endpoint, role).await,
|
|
#[cfg(any(test, feature = "testing"))]
|
|
Self::PostgresMock(api) => api.get_endpoint_access_control(ctx, endpoint, role).await,
|
|
#[cfg(test)]
|
|
Self::Test(api) => api.get_access_control(),
|
|
}
|
|
}
|
|
|
|
async fn get_endpoint_jwks(
|
|
&self,
|
|
ctx: &RequestContext,
|
|
endpoint: &EndpointId,
|
|
) -> Result<Vec<AuthRule>, errors::GetEndpointJwksError> {
|
|
match self {
|
|
Self::ProxyV1(api) => api.get_endpoint_jwks(ctx, endpoint).await,
|
|
#[cfg(any(test, feature = "testing"))]
|
|
Self::PostgresMock(api) => api.get_endpoint_jwks(ctx, endpoint).await,
|
|
#[cfg(test)]
|
|
Self::Test(_api) => Ok(vec![]),
|
|
}
|
|
}
|
|
|
|
async fn wake_compute(
|
|
&self,
|
|
ctx: &RequestContext,
|
|
user_info: &ComputeUserInfo,
|
|
) -> Result<CachedNodeInfo, errors::WakeComputeError> {
|
|
match self {
|
|
Self::ProxyV1(api) => api.wake_compute(ctx, user_info).await,
|
|
#[cfg(any(test, feature = "testing"))]
|
|
Self::PostgresMock(api) => api.wake_compute(ctx, user_info).await,
|
|
#[cfg(test)]
|
|
Self::Test(api) => api.wake_compute(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Canned-response provider backing the `ControlPlaneClient::Test` variant.
#[cfg(test)]
pub(crate) trait TestControlPlaneClient: Send + Sync + 'static {
    /// Result handed back for a `wake_compute` request.
    fn wake_compute(&self) -> Result<CachedNodeInfo, errors::WakeComputeError>;

    /// Result handed back for an endpoint access-control lookup.
    fn get_access_control(&self) -> Result<EndpointAccessControl, errors::GetAuthInfoError>;

    /// Clones `self` into a new box; this is what lets
    /// `Box<dyn TestControlPlaneClient>` implement `Clone`.
    fn dyn_clone(&self) -> Box<dyn TestControlPlaneClient>;
}
|
|
|
|
/// Boxed test backends are cloned through their object-safe `dyn_clone` hook.
#[cfg(test)]
impl Clone for Box<dyn TestControlPlaneClient> {
    fn clone(&self) -> Self {
        // Deref twice (`&Box<dyn T>` -> `dyn T`) so the call dispatches
        // through the trait object's vtable.
        (**self).dyn_clone()
    }
}
|
|
|
|
/// Various caches for [`control_plane`](super).
|
|
pub struct ApiCaches {
|
|
/// Cache for the `wake_compute` API method.
|
|
pub(crate) node_info: NodeInfoCache,
|
|
/// Cache which stores project_id -> endpoint_ids mapping.
|
|
pub project_info: Arc<ProjectInfoCacheImpl>,
|
|
/// List of all valid endpoints.
|
|
pub endpoints_cache: Arc<EndpointsCache>,
|
|
}
|
|
|
|
impl ApiCaches {
|
|
pub fn new(
|
|
wake_compute_cache_config: CacheOptions,
|
|
project_info_cache_config: ProjectInfoCacheOptions,
|
|
endpoint_cache_config: EndpointCacheConfig,
|
|
) -> Self {
|
|
Self {
|
|
node_info: NodeInfoCache::new(
|
|
"node_info_cache",
|
|
wake_compute_cache_config.size,
|
|
wake_compute_cache_config.ttl,
|
|
true,
|
|
),
|
|
project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
|
|
endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Various caches for [`control_plane`](super).
|
|
pub struct ApiLocks<K> {
|
|
name: &'static str,
|
|
node_locks: ClashMap<K, Arc<DynamicLimiter>>,
|
|
config: RateLimiterConfig,
|
|
timeout: Duration,
|
|
epoch: std::time::Duration,
|
|
metrics: &'static ApiLockMetrics,
|
|
}
|
|
|
|
#[derive(Debug, thiserror::Error)]
|
|
pub(crate) enum ApiLockError {
|
|
#[error("timeout acquiring resource permit")]
|
|
TimeoutError(#[from] tokio::time::error::Elapsed),
|
|
}
|
|
|
|
impl ReportableError for ApiLockError {
|
|
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
|
match self {
|
|
ApiLockError::TimeoutError(_) => crate::error::ErrorKind::RateLimit,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<K: Hash + Eq + Clone> ApiLocks<K> {
|
|
pub fn new(
|
|
name: &'static str,
|
|
config: RateLimiterConfig,
|
|
shards: usize,
|
|
timeout: Duration,
|
|
epoch: std::time::Duration,
|
|
metrics: &'static ApiLockMetrics,
|
|
) -> Self {
|
|
Self {
|
|
name,
|
|
node_locks: ClashMap::with_shard_amount(shards),
|
|
config,
|
|
timeout,
|
|
epoch,
|
|
metrics,
|
|
}
|
|
}
|
|
|
|
pub(crate) async fn get_permit(&self, key: &K) -> Result<WakeComputePermit, ApiLockError> {
|
|
if self.config.initial_limit == 0 {
|
|
return Ok(WakeComputePermit {
|
|
permit: Token::disabled(),
|
|
});
|
|
}
|
|
let now = Instant::now();
|
|
let semaphore = {
|
|
// get fast path
|
|
if let Some(semaphore) = self.node_locks.get(key) {
|
|
semaphore.clone()
|
|
} else {
|
|
self.node_locks
|
|
.entry(key.clone())
|
|
.or_insert_with(|| {
|
|
self.metrics.semaphores_registered.inc();
|
|
DynamicLimiter::new(self.config)
|
|
})
|
|
.clone()
|
|
}
|
|
};
|
|
let permit = semaphore.acquire_timeout(self.timeout).await;
|
|
|
|
self.metrics
|
|
.semaphore_acquire_seconds
|
|
.observe(now.elapsed().as_secs_f64());
|
|
debug!("acquired permit {:?}", now.elapsed().as_secs_f64());
|
|
Ok(WakeComputePermit { permit: permit? })
|
|
}
|
|
|
|
pub async fn garbage_collect_worker(&self) {
|
|
if self.config.initial_limit == 0 {
|
|
return;
|
|
}
|
|
let mut interval =
|
|
tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
|
|
loop {
|
|
for (i, shard) in self.node_locks.shards().iter().enumerate() {
|
|
interval.tick().await;
|
|
// temporary lock a single shard and then clear any semaphores that aren't currently checked out
|
|
// race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
|
|
// therefore releasing it is safe from race conditions
|
|
debug!( //FIXME: is anything depending on this being info?
|
|
name = self.name,
|
|
shard = i,
|
|
"performing epoch reclamation on api lock"
|
|
);
|
|
let mut lock = shard.write();
|
|
let timer = self.metrics.reclamation_lag_seconds.start_timer();
|
|
let count = lock
|
|
.extract_if(|(_, semaphore)| Arc::strong_count(semaphore) == 1)
|
|
.count();
|
|
drop(lock);
|
|
self.metrics.semaphores_unregistered.inc_by(count as u64);
|
|
timer.observe();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
pub(crate) struct WakeComputePermit {
|
|
permit: Token,
|
|
}
|
|
|
|
impl WakeComputePermit {
|
|
pub(crate) fn should_check_cache(&self) -> bool {
|
|
!self.permit.is_disabled()
|
|
}
|
|
pub(crate) fn release(self, outcome: Outcome) {
|
|
self.permit.release(outcome);
|
|
}
|
|
pub(crate) fn release_result<T, E>(self, res: Result<T, E>) -> Result<T, E> {
|
|
match res {
|
|
Ok(_) => self.release(Outcome::Success),
|
|
Err(_) => self.release(Outcome::Overload),
|
|
}
|
|
res
|
|
}
|
|
}
|
|
|
|
impl FetchAuthRules for ControlPlaneClient {
|
|
async fn fetch_auth_rules(
|
|
&self,
|
|
ctx: &RequestContext,
|
|
endpoint: EndpointId,
|
|
) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
|
|
self.get_endpoint_jwks(ctx, &endpoint)
|
|
.await
|
|
.map_err(FetchAuthRulesError::GetEndpointJwks)
|
|
}
|
|
}
|