mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-29 11:00:38 +00:00
proxy: Changes to rate limits and GetEndpointAccessControl caches. (#12048)
Precursor to https://github.com/neondatabase/cloud/issues/28333. We want per-endpoint configuration for rate limits, which will be distributed via the `GetEndpointAccessControl` API. This lays some of the ground work. 1. Allow the endpoint rate limiter to accept a custom leaky bucket config on check. 2. Remove the unused auth rate limiter, as I don't want to think about how it fits into this. 3. Refactor the caching of `GetEndpointAccessControl`, as it adds friction for adding new cached data to the API. That third one was rather large. I couldn't find any way to split it up. The core idea is that there's now only 2 cache APIs. `get_endpoint_access_controls` and `get_role_access_controls`. I'm pretty sure the behaviour is unchanged, except I did a drive by change to fix #8989 because it felt harmless. The change in question is that when a password validation fails, we eagerly expire the role cache if the role was cached for 5 minutes. This is to allow for edge cases where a user tries to connect with a reset password, but the cache never expires the entry due to some redis related quirk (lag, or misconfiguration, or cplane error)
This commit is contained in:
@@ -15,7 +15,6 @@ use tracing::{Instrument, debug, info, info_span, warn};
|
||||
use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute};
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::backend::jwt::AuthRule;
|
||||
use crate::cache::Cached;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::caches::ApiCaches;
|
||||
use crate::control_plane::errors::{
|
||||
@@ -24,12 +23,12 @@ use crate::control_plane::errors::{
|
||||
use crate::control_plane::locks::ApiLocks;
|
||||
use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason};
|
||||
use crate::control_plane::{
|
||||
AccessBlockerFlags, AuthInfo, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps,
|
||||
CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, NodeInfo,
|
||||
AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo,
|
||||
RoleAccessControl,
|
||||
};
|
||||
use crate::metrics::{CacheOutcome, Metrics};
|
||||
use crate::metrics::Metrics;
|
||||
use crate::rate_limiter::WakeComputeRateLimiter;
|
||||
use crate::types::{EndpointCacheKey, EndpointId};
|
||||
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
|
||||
use crate::{compute, http, scram};
|
||||
|
||||
pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
|
||||
@@ -66,65 +65,34 @@ impl NeonControlPlaneClient {
|
||||
self.endpoint.url().as_str()
|
||||
}
|
||||
|
||||
async fn do_get_auth_info(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<AuthInfo, GetAuthInfoError> {
|
||||
if !self
|
||||
.caches
|
||||
.endpoints_cache
|
||||
.is_valid(ctx, &user_info.endpoint.normalize())
|
||||
{
|
||||
// TODO: refactor this because it's weird
|
||||
// this is a failure to authenticate but we return Ok.
|
||||
info!("endpoint is not valid, skipping the request");
|
||||
return Ok(AuthInfo::default());
|
||||
}
|
||||
self.do_get_auth_req(user_info, &ctx.session_id(), Some(ctx))
|
||||
.await
|
||||
}
|
||||
|
||||
async fn do_get_auth_req(
|
||||
&self,
|
||||
user_info: &ComputeUserInfo,
|
||||
session_id: &uuid::Uuid,
|
||||
ctx: Option<&RequestContext>,
|
||||
ctx: &RequestContext,
|
||||
endpoint: &EndpointId,
|
||||
role: &RoleName,
|
||||
) -> Result<AuthInfo, GetAuthInfoError> {
|
||||
let request_id: String = session_id.to_string();
|
||||
let application_name = if let Some(ctx) = ctx {
|
||||
ctx.console_application_name()
|
||||
} else {
|
||||
"auth_cancellation".to_string()
|
||||
};
|
||||
|
||||
async {
|
||||
let request = self
|
||||
.endpoint
|
||||
.get_path("get_endpoint_access_control")
|
||||
.header(X_REQUEST_ID, &request_id)
|
||||
.header(X_REQUEST_ID, ctx.session_id().to_string())
|
||||
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
|
||||
.query(&[("session_id", session_id)])
|
||||
.query(&[("session_id", ctx.session_id())])
|
||||
.query(&[
|
||||
("application_name", application_name.as_str()),
|
||||
("endpointish", user_info.endpoint.as_str()),
|
||||
("role", user_info.user.as_str()),
|
||||
("application_name", ctx.console_application_name().as_str()),
|
||||
("endpointish", endpoint.as_str()),
|
||||
("role", role.as_str()),
|
||||
])
|
||||
.build()?;
|
||||
|
||||
debug!(url = request.url().as_str(), "sending http request");
|
||||
let start = Instant::now();
|
||||
let response = match ctx {
|
||||
Some(ctx) => {
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
|
||||
let rsp = self.endpoint.execute(request).await;
|
||||
drop(pause);
|
||||
rsp?
|
||||
}
|
||||
None => self.endpoint.execute(request).await?,
|
||||
let response = {
|
||||
let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane);
|
||||
self.endpoint.execute(request).await?
|
||||
};
|
||||
|
||||
info!(duration = ?start.elapsed(), "received http response");
|
||||
|
||||
let body = match parse_body::<GetEndpointAccessControl>(response).await {
|
||||
Ok(body) => body,
|
||||
// Error 404 is special: it's ok not to have a secret.
|
||||
@@ -180,7 +148,7 @@ impl NeonControlPlaneClient {
|
||||
async fn do_get_endpoint_jwks(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
endpoint: EndpointId,
|
||||
endpoint: &EndpointId,
|
||||
) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
|
||||
if !self
|
||||
.caches
|
||||
@@ -313,225 +281,104 @@ impl NeonControlPlaneClient {
|
||||
|
||||
impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_role_secret(
|
||||
async fn get_role_access_control(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, GetAuthInfoError> {
|
||||
let normalized_ep = &user_info.endpoint.normalize();
|
||||
let user = &user_info.user;
|
||||
if let Some(role_secret) = self
|
||||
endpoint: &EndpointId,
|
||||
role: &RoleName,
|
||||
) -> Result<RoleAccessControl, crate::control_plane::errors::GetAuthInfoError> {
|
||||
let normalized_ep = &endpoint.normalize();
|
||||
if let Some(secret) = self
|
||||
.caches
|
||||
.project_info
|
||||
.get_role_secret(normalized_ep, user)
|
||||
.get_role_secret(normalized_ep, role)
|
||||
{
|
||||
return Ok(role_secret);
|
||||
return Ok(secret);
|
||||
}
|
||||
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
|
||||
let account_id = auth_info.account_id;
|
||||
|
||||
if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) {
|
||||
info!("endpoint is not valid, skipping the request");
|
||||
return Err(GetAuthInfoError::UnknownEndpoint);
|
||||
}
|
||||
|
||||
let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?;
|
||||
|
||||
let control = EndpointAccessControl {
|
||||
allowed_ips: Arc::new(auth_info.allowed_ips),
|
||||
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
|
||||
flags: auth_info.access_blocker_flags,
|
||||
};
|
||||
let role_control = RoleAccessControl {
|
||||
secret: auth_info.secret,
|
||||
};
|
||||
|
||||
if let Some(project_id) = auth_info.project_id {
|
||||
let normalized_ep_int = normalized_ep.into();
|
||||
self.caches.project_info.insert_role_secret(
|
||||
|
||||
self.caches.project_info.insert_endpoint_access(
|
||||
auth_info.account_id,
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
user.into(),
|
||||
auth_info.secret.clone(),
|
||||
);
|
||||
self.caches.project_info.insert_allowed_ips(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
Arc::new(auth_info.allowed_ips),
|
||||
);
|
||||
self.caches.project_info.insert_allowed_vpc_endpoint_ids(
|
||||
account_id,
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
Arc::new(auth_info.allowed_vpc_endpoint_ids),
|
||||
);
|
||||
self.caches.project_info.insert_block_public_or_vpc_access(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
auth_info.access_blocker_flags,
|
||||
role.into(),
|
||||
control,
|
||||
role_control.clone(),
|
||||
);
|
||||
ctx.set_project_id(project_id);
|
||||
}
|
||||
// When we just got a secret, we don't need to invalidate it.
|
||||
Ok(Cached::new_uncached(auth_info.secret))
|
||||
|
||||
Ok(role_control)
|
||||
}
|
||||
|
||||
async fn get_allowed_ips(
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_endpoint_access_control(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedAllowedIps, GetAuthInfoError> {
|
||||
let normalized_ep = &user_info.endpoint.normalize();
|
||||
if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.allowed_ips_cache_misses // TODO SR: Should we rename this variable to something like allowed_ip_cache_stats?
|
||||
.inc(CacheOutcome::Hit);
|
||||
return Ok(allowed_ips);
|
||||
endpoint: &EndpointId,
|
||||
role: &RoleName,
|
||||
) -> Result<EndpointAccessControl, GetAuthInfoError> {
|
||||
let normalized_ep = &endpoint.normalize();
|
||||
if let Some(control) = self.caches.project_info.get_endpoint_access(normalized_ep) {
|
||||
return Ok(control);
|
||||
}
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.allowed_ips_cache_misses
|
||||
.inc(CacheOutcome::Miss);
|
||||
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
|
||||
let allowed_ips = Arc::new(auth_info.allowed_ips);
|
||||
let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids);
|
||||
let access_blocker_flags = auth_info.access_blocker_flags;
|
||||
let user = &user_info.user;
|
||||
let account_id = auth_info.account_id;
|
||||
|
||||
if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) {
|
||||
info!("endpoint is not valid, skipping the request");
|
||||
return Err(GetAuthInfoError::UnknownEndpoint);
|
||||
}
|
||||
|
||||
let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?;
|
||||
|
||||
let control = EndpointAccessControl {
|
||||
allowed_ips: Arc::new(auth_info.allowed_ips),
|
||||
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
|
||||
flags: auth_info.access_blocker_flags,
|
||||
};
|
||||
let role_control = RoleAccessControl {
|
||||
secret: auth_info.secret,
|
||||
};
|
||||
|
||||
if let Some(project_id) = auth_info.project_id {
|
||||
let normalized_ep_int = normalized_ep.into();
|
||||
self.caches.project_info.insert_role_secret(
|
||||
|
||||
self.caches.project_info.insert_endpoint_access(
|
||||
auth_info.account_id,
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
user.into(),
|
||||
auth_info.secret.clone(),
|
||||
);
|
||||
self.caches.project_info.insert_allowed_ips(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
allowed_ips.clone(),
|
||||
);
|
||||
self.caches.project_info.insert_allowed_vpc_endpoint_ids(
|
||||
account_id,
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
allowed_vpc_endpoint_ids.clone(),
|
||||
);
|
||||
self.caches.project_info.insert_block_public_or_vpc_access(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
access_blocker_flags,
|
||||
role.into(),
|
||||
control.clone(),
|
||||
role_control,
|
||||
);
|
||||
ctx.set_project_id(project_id);
|
||||
}
|
||||
Ok(Cached::new_uncached(allowed_ips))
|
||||
}
|
||||
|
||||
async fn get_allowed_vpc_endpoint_ids(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedAllowedVpcEndpointIds, GetAuthInfoError> {
|
||||
let normalized_ep = &user_info.endpoint.normalize();
|
||||
if let Some(allowed_vpc_endpoint_ids) = self
|
||||
.caches
|
||||
.project_info
|
||||
.get_allowed_vpc_endpoint_ids(normalized_ep)
|
||||
{
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.vpc_endpoint_id_cache_stats
|
||||
.inc(CacheOutcome::Hit);
|
||||
return Ok(allowed_vpc_endpoint_ids);
|
||||
}
|
||||
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.vpc_endpoint_id_cache_stats
|
||||
.inc(CacheOutcome::Miss);
|
||||
|
||||
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
|
||||
let allowed_ips = Arc::new(auth_info.allowed_ips);
|
||||
let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids);
|
||||
let access_blocker_flags = auth_info.access_blocker_flags;
|
||||
let user = &user_info.user;
|
||||
let account_id = auth_info.account_id;
|
||||
if let Some(project_id) = auth_info.project_id {
|
||||
let normalized_ep_int = normalized_ep.into();
|
||||
self.caches.project_info.insert_role_secret(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
user.into(),
|
||||
auth_info.secret.clone(),
|
||||
);
|
||||
self.caches.project_info.insert_allowed_ips(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
allowed_ips.clone(),
|
||||
);
|
||||
self.caches.project_info.insert_allowed_vpc_endpoint_ids(
|
||||
account_id,
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
allowed_vpc_endpoint_ids.clone(),
|
||||
);
|
||||
self.caches.project_info.insert_block_public_or_vpc_access(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
access_blocker_flags,
|
||||
);
|
||||
ctx.set_project_id(project_id);
|
||||
}
|
||||
Ok(Cached::new_uncached(allowed_vpc_endpoint_ids))
|
||||
}
|
||||
|
||||
async fn get_block_public_or_vpc_access(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedAccessBlockerFlags, GetAuthInfoError> {
|
||||
let normalized_ep = &user_info.endpoint.normalize();
|
||||
if let Some(access_blocker_flags) = self
|
||||
.caches
|
||||
.project_info
|
||||
.get_block_public_or_vpc_access(normalized_ep)
|
||||
{
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.access_blocker_flags_cache_stats
|
||||
.inc(CacheOutcome::Hit);
|
||||
return Ok(access_blocker_flags);
|
||||
}
|
||||
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.access_blocker_flags_cache_stats
|
||||
.inc(CacheOutcome::Miss);
|
||||
|
||||
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
|
||||
let allowed_ips = Arc::new(auth_info.allowed_ips);
|
||||
let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids);
|
||||
let access_blocker_flags = auth_info.access_blocker_flags;
|
||||
let user = &user_info.user;
|
||||
let account_id = auth_info.account_id;
|
||||
if let Some(project_id) = auth_info.project_id {
|
||||
let normalized_ep_int = normalized_ep.into();
|
||||
self.caches.project_info.insert_role_secret(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
user.into(),
|
||||
auth_info.secret.clone(),
|
||||
);
|
||||
self.caches.project_info.insert_allowed_ips(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
allowed_ips.clone(),
|
||||
);
|
||||
self.caches.project_info.insert_allowed_vpc_endpoint_ids(
|
||||
account_id,
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
allowed_vpc_endpoint_ids.clone(),
|
||||
);
|
||||
self.caches.project_info.insert_block_public_or_vpc_access(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
access_blocker_flags.clone(),
|
||||
);
|
||||
ctx.set_project_id(project_id);
|
||||
}
|
||||
Ok(Cached::new_uncached(access_blocker_flags))
|
||||
Ok(control)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
endpoint: EndpointId,
|
||||
endpoint: &EndpointId,
|
||||
) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
|
||||
self.do_get_endpoint_jwks(ctx, endpoint).await
|
||||
}
|
||||
|
||||
@@ -15,14 +15,14 @@ use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::backend::jwt::AuthRule;
|
||||
use crate::cache::Cached;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::client::{
|
||||
CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedRoleSecret,
|
||||
};
|
||||
use crate::control_plane::errors::{
|
||||
ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
|
||||
};
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::control_plane::{AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo};
|
||||
use crate::control_plane::{
|
||||
AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo,
|
||||
RoleAccessControl,
|
||||
};
|
||||
use crate::intern::RoleNameInt;
|
||||
use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
|
||||
use crate::url::ApiUrl;
|
||||
@@ -66,7 +66,8 @@ impl MockControlPlane {
|
||||
|
||||
async fn do_get_auth_info(
|
||||
&self,
|
||||
user_info: &ComputeUserInfo,
|
||||
endpoint: &EndpointId,
|
||||
role: &RoleName,
|
||||
) -> Result<AuthInfo, GetAuthInfoError> {
|
||||
let (secret, allowed_ips) = async {
|
||||
// Perhaps we could persist this connection, but then we'd have to
|
||||
@@ -80,7 +81,7 @@ impl MockControlPlane {
|
||||
let secret = if let Some(entry) = get_execute_postgres_query(
|
||||
&client,
|
||||
"select rolpassword from pg_catalog.pg_authid where rolname = $1",
|
||||
&[&&*user_info.user],
|
||||
&[&role.as_str()],
|
||||
"rolpassword",
|
||||
)
|
||||
.await?
|
||||
@@ -89,7 +90,7 @@ impl MockControlPlane {
|
||||
let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
|
||||
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
|
||||
} else {
|
||||
warn!("user '{}' does not exist", user_info.user);
|
||||
warn!("user '{role}' does not exist");
|
||||
None
|
||||
};
|
||||
|
||||
@@ -97,7 +98,7 @@ impl MockControlPlane {
|
||||
match get_execute_postgres_query(
|
||||
&client,
|
||||
"select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
|
||||
&[&user_info.endpoint.as_str()],
|
||||
&[&endpoint.as_str()],
|
||||
"allowed_ips",
|
||||
)
|
||||
.await?
|
||||
@@ -133,7 +134,7 @@ impl MockControlPlane {
|
||||
|
||||
async fn do_get_endpoint_jwks(
|
||||
&self,
|
||||
endpoint: EndpointId,
|
||||
endpoint: &EndpointId,
|
||||
) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
|
||||
let (client, connection) =
|
||||
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
|
||||
@@ -222,53 +223,36 @@ async fn get_execute_postgres_query(
|
||||
}
|
||||
|
||||
impl super::ControlPlaneApi for MockControlPlane {
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_role_secret(
|
||||
async fn get_endpoint_access_control(
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, GetAuthInfoError> {
|
||||
Ok(CachedRoleSecret::new_uncached(
|
||||
self.do_get_auth_info(user_info).await?.secret,
|
||||
))
|
||||
endpoint: &EndpointId,
|
||||
role: &RoleName,
|
||||
) -> Result<EndpointAccessControl, GetAuthInfoError> {
|
||||
let info = self.do_get_auth_info(endpoint, role).await?;
|
||||
Ok(EndpointAccessControl {
|
||||
allowed_ips: Arc::new(info.allowed_ips),
|
||||
allowed_vpce: Arc::new(info.allowed_vpc_endpoint_ids),
|
||||
flags: info.access_blocker_flags,
|
||||
})
|
||||
}
|
||||
|
||||
async fn get_allowed_ips(
|
||||
async fn get_role_access_control(
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedAllowedIps, GetAuthInfoError> {
|
||||
Ok(Cached::new_uncached(Arc::new(
|
||||
self.do_get_auth_info(user_info).await?.allowed_ips,
|
||||
)))
|
||||
}
|
||||
|
||||
async fn get_allowed_vpc_endpoint_ids(
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedAllowedVpcEndpointIds, super::errors::GetAuthInfoError> {
|
||||
Ok(Cached::new_uncached(Arc::new(
|
||||
self.do_get_auth_info(user_info)
|
||||
.await?
|
||||
.allowed_vpc_endpoint_ids,
|
||||
)))
|
||||
}
|
||||
|
||||
async fn get_block_public_or_vpc_access(
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<super::CachedAccessBlockerFlags, super::errors::GetAuthInfoError> {
|
||||
Ok(Cached::new_uncached(
|
||||
self.do_get_auth_info(user_info).await?.access_blocker_flags,
|
||||
))
|
||||
endpoint: &EndpointId,
|
||||
role: &RoleName,
|
||||
) -> Result<RoleAccessControl, GetAuthInfoError> {
|
||||
let info = self.do_get_auth_info(endpoint, role).await?;
|
||||
Ok(RoleAccessControl {
|
||||
secret: info.secret,
|
||||
})
|
||||
}
|
||||
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
endpoint: EndpointId,
|
||||
endpoint: &EndpointId,
|
||||
) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
|
||||
self.do_get_endpoint_jwks(endpoint).await
|
||||
}
|
||||
|
||||
@@ -16,15 +16,14 @@ use crate::cache::endpoints::EndpointsCache;
|
||||
use crate::cache::project_info::ProjectInfoCacheImpl;
|
||||
use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions};
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::{
|
||||
CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo,
|
||||
CachedRoleSecret, ControlPlaneApi, NodeInfoCache, errors,
|
||||
};
|
||||
use crate::control_plane::{CachedNodeInfo, ControlPlaneApi, NodeInfoCache, errors};
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::ApiLockMetrics;
|
||||
use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
|
||||
use crate::types::EndpointId;
|
||||
|
||||
use super::{EndpointAccessControl, RoleAccessControl};
|
||||
|
||||
#[non_exhaustive]
|
||||
#[derive(Clone)]
|
||||
pub enum ControlPlaneClient {
|
||||
@@ -40,68 +39,42 @@ pub enum ControlPlaneClient {
|
||||
}
|
||||
|
||||
impl ControlPlaneApi for ControlPlaneClient {
|
||||
async fn get_role_secret(
|
||||
async fn get_role_access_control(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
|
||||
endpoint: &EndpointId,
|
||||
role: &crate::types::RoleName,
|
||||
) -> Result<RoleAccessControl, errors::GetAuthInfoError> {
|
||||
match self {
|
||||
Self::ProxyV1(api) => api.get_role_secret(ctx, user_info).await,
|
||||
Self::ProxyV1(api) => api.get_role_access_control(ctx, endpoint, role).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.get_role_secret(ctx, user_info).await,
|
||||
Self::PostgresMock(api) => api.get_role_access_control(ctx, endpoint, role).await,
|
||||
#[cfg(test)]
|
||||
Self::Test(_) => {
|
||||
Self::Test(_api) => {
|
||||
unreachable!("this function should never be called in the test backend")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_allowed_ips(
|
||||
async fn get_endpoint_access_control(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedAllowedIps, errors::GetAuthInfoError> {
|
||||
endpoint: &EndpointId,
|
||||
role: &crate::types::RoleName,
|
||||
) -> Result<EndpointAccessControl, errors::GetAuthInfoError> {
|
||||
match self {
|
||||
Self::ProxyV1(api) => api.get_allowed_ips(ctx, user_info).await,
|
||||
Self::ProxyV1(api) => api.get_endpoint_access_control(ctx, endpoint, role).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.get_allowed_ips(ctx, user_info).await,
|
||||
Self::PostgresMock(api) => api.get_endpoint_access_control(ctx, endpoint, role).await,
|
||||
#[cfg(test)]
|
||||
Self::Test(api) => api.get_allowed_ips(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_allowed_vpc_endpoint_ids(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedAllowedVpcEndpointIds, errors::GetAuthInfoError> {
|
||||
match self {
|
||||
Self::ProxyV1(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await,
|
||||
#[cfg(test)]
|
||||
Self::Test(api) => api.get_allowed_vpc_endpoint_ids(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_block_public_or_vpc_access(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedAccessBlockerFlags, errors::GetAuthInfoError> {
|
||||
match self {
|
||||
Self::ProxyV1(api) => api.get_block_public_or_vpc_access(ctx, user_info).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.get_block_public_or_vpc_access(ctx, user_info).await,
|
||||
#[cfg(test)]
|
||||
Self::Test(api) => api.get_block_public_or_vpc_access(),
|
||||
Self::Test(api) => api.get_access_control(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
endpoint: EndpointId,
|
||||
endpoint: &EndpointId,
|
||||
) -> Result<Vec<AuthRule>, errors::GetEndpointJwksError> {
|
||||
match self {
|
||||
Self::ProxyV1(api) => api.get_endpoint_jwks(ctx, endpoint).await,
|
||||
@@ -131,15 +104,7 @@ impl ControlPlaneApi for ControlPlaneClient {
|
||||
pub(crate) trait TestControlPlaneClient: Send + Sync + 'static {
|
||||
fn wake_compute(&self) -> Result<CachedNodeInfo, errors::WakeComputeError>;
|
||||
|
||||
fn get_allowed_ips(&self) -> Result<CachedAllowedIps, errors::GetAuthInfoError>;
|
||||
|
||||
fn get_allowed_vpc_endpoint_ids(
|
||||
&self,
|
||||
) -> Result<CachedAllowedVpcEndpointIds, errors::GetAuthInfoError>;
|
||||
|
||||
fn get_block_public_or_vpc_access(
|
||||
&self,
|
||||
) -> Result<CachedAccessBlockerFlags, errors::GetAuthInfoError>;
|
||||
fn get_access_control(&self) -> Result<EndpointAccessControl, errors::GetAuthInfoError>;
|
||||
|
||||
fn dyn_clone(&self) -> Box<dyn TestControlPlaneClient>;
|
||||
}
|
||||
@@ -309,7 +274,7 @@ impl FetchAuthRules for ControlPlaneClient {
|
||||
ctx: &RequestContext,
|
||||
endpoint: EndpointId,
|
||||
) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
|
||||
self.get_endpoint_jwks(ctx, endpoint)
|
||||
self.get_endpoint_jwks(ctx, &endpoint)
|
||||
.await
|
||||
.map_err(FetchAuthRulesError::GetEndpointJwks)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user