[proxy] Cache GetEndpointAccessControl errors (#12571)

Related to https://github.com/neondatabase/cloud/issues/19353
This commit is contained in:
Krzysztof Szafrański
2025-07-18 12:17:58 +02:00
committed by GitHub
parent 8e95455aef
commit 96bcfba79e
5 changed files with 376 additions and 119 deletions

View File

@@ -10,6 +10,7 @@ use tokio::time::Instant;
use tracing::{debug, info};
use crate::config::ProjectInfoCacheOptions;
use crate::control_plane::messages::{ControlPlaneErrorMessage, Reason};
use crate::control_plane::{EndpointAccessControl, RoleAccessControl};
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::types::{EndpointId, RoleName};
@@ -36,22 +37,37 @@ impl<T> Entry<T> {
}
pub(crate) fn get(&self) -> Option<&T> {
(self.expires_at > Instant::now()).then_some(&self.value)
(!self.is_expired()).then_some(&self.value)
}
/// Reports whether this entry's deadline has been reached.
///
/// The deadline itself counts as expired: the entry is live only while
/// `expires_at` is strictly in the future.
fn is_expired(&self) -> bool {
    Instant::now() >= self.expires_at
}
}
struct EndpointInfo {
role_controls: HashMap<RoleNameInt, Entry<RoleAccessControl>>,
controls: Option<Entry<EndpointAccessControl>>,
role_controls: HashMap<RoleNameInt, Entry<ControlPlaneResult<RoleAccessControl>>>,
controls: Option<Entry<ControlPlaneResult<EndpointAccessControl>>>,
}
/// Value stored in the cache for a control-plane lookup: either the successful
/// payload `T`, or the boxed [`ControlPlaneErrorMessage`] that was cached in its place.
type ControlPlaneResult<T> = Result<T, Box<ControlPlaneErrorMessage>>;
impl EndpointInfo {
pub(crate) fn get_role_secret(&self, role_name: RoleNameInt) -> Option<RoleAccessControl> {
self.role_controls.get(&role_name)?.get().cloned()
/// Looks up the cached result for `role_name` together with the time left
/// until that entry expires.
///
/// Returns `None` when no entry exists for the role or the entry has expired.
/// The cached value may itself be an `Err` holding a cached error message.
pub(crate) fn get_role_secret_with_ttl(
    &self,
    role_name: RoleNameInt,
) -> Option<(ControlPlaneResult<RoleAccessControl>, Duration)> {
    let cached = self.role_controls.get(&role_name)?;
    // Read the clock for the TTL first; the expiry check in `get()` runs afterwards.
    let remaining = cached.expires_at - Instant::now();
    cached.get().cloned().map(|outcome| (outcome, remaining))
}
pub(crate) fn get_controls(&self) -> Option<EndpointAccessControl> {
self.controls.as_ref()?.get().cloned()
/// Returns the cached endpoint-level access control result together with the
/// time left until that entry expires.
///
/// Yields `None` when nothing is cached in `controls` or the entry has expired.
/// The cached value may itself be an `Err` holding a cached error message.
pub(crate) fn get_controls_with_ttl(
    &self,
) -> Option<(ControlPlaneResult<EndpointAccessControl>, Duration)> {
    let cached = self.controls.as_ref()?;
    // Read the clock for the TTL first; the expiry check in `get()` runs afterwards.
    let remaining = cached.expires_at - Instant::now();
    let outcome = cached.get()?;
    Some((outcome.clone(), remaining))
}
pub(crate) fn invalidate_endpoint(&mut self) {
@@ -153,28 +169,28 @@ impl ProjectInfoCacheImpl {
self.cache.get(&endpoint_id)
}
pub(crate) fn get_role_secret(
pub(crate) fn get_role_secret_with_ttl(
&self,
endpoint_id: &EndpointId,
role_name: &RoleName,
) -> Option<RoleAccessControl> {
) -> Option<(ControlPlaneResult<RoleAccessControl>, Duration)> {
let role_name = RoleNameInt::get(role_name)?;
let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
endpoint_info.get_role_secret(role_name)
endpoint_info.get_role_secret_with_ttl(role_name)
}
pub(crate) fn get_endpoint_access(
pub(crate) fn get_endpoint_access_with_ttl(
&self,
endpoint_id: &EndpointId,
) -> Option<EndpointAccessControl> {
) -> Option<(ControlPlaneResult<EndpointAccessControl>, Duration)> {
let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
endpoint_info.get_controls()
endpoint_info.get_controls_with_ttl()
}
pub(crate) fn insert_endpoint_access(
&self,
account_id: Option<AccountIdInt>,
project_id: ProjectIdInt,
project_id: Option<ProjectIdInt>,
endpoint_id: EndpointIdInt,
role_name: RoleNameInt,
controls: EndpointAccessControl,
@@ -183,26 +199,89 @@ impl ProjectInfoCacheImpl {
if let Some(account_id) = account_id {
self.insert_account2endpoint(account_id, endpoint_id);
}
self.insert_project2endpoint(project_id, endpoint_id);
if let Some(project_id) = project_id {
self.insert_project2endpoint(project_id, endpoint_id);
}
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
let controls = Entry::new(controls, self.config.ttl);
let role_controls = Entry::new(role_controls, self.config.ttl);
debug!(
key = &*endpoint_id,
"created a cache entry for endpoint access"
);
let controls = Some(Entry::new(Ok(controls), self.config.ttl));
let role_controls = Entry::new(Ok(role_controls), self.config.ttl);
match self.cache.entry(endpoint_id) {
clashmap::Entry::Vacant(e) => {
e.insert(EndpointInfo {
role_controls: HashMap::from_iter([(role_name, role_controls)]),
controls: Some(controls),
controls,
});
}
clashmap::Entry::Occupied(mut e) => {
let ep = e.get_mut();
ep.controls = Some(controls);
ep.controls = controls;
if ep.role_controls.len() < self.config.max_roles {
ep.role_controls.insert(role_name, role_controls);
}
}
}
}
pub(crate) fn insert_endpoint_access_err(
&self,
endpoint_id: EndpointIdInt,
role_name: RoleNameInt,
msg: Box<ControlPlaneErrorMessage>,
ttl: Option<Duration>,
) {
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
debug!(
key = &*endpoint_id,
"created a cache entry for an endpoint access error"
);
let ttl = ttl.unwrap_or(self.config.ttl);
let controls = if msg.get_reason() == Reason::RoleProtected {
// RoleProtected is the only role-specific error that control plane can give us.
// If a given role name does not exist, it still returns a successful response,
// just with an empty secret.
None
} else {
// We can cache all the other errors in EndpointInfo.controls,
// because they don't depend on what role name we pass to control plane.
Some(Entry::new(Err(msg.clone()), ttl))
};
let role_controls = Entry::new(Err(msg), ttl);
match self.cache.entry(endpoint_id) {
clashmap::Entry::Vacant(e) => {
e.insert(EndpointInfo {
role_controls: HashMap::from_iter([(role_name, role_controls)]),
controls,
});
}
clashmap::Entry::Occupied(mut e) => {
let ep = e.get_mut();
if let Some(entry) = &ep.controls
&& !entry.is_expired()
&& entry.value.is_ok()
{
// If we have cached non-expired, non-error controls, keep them.
} else {
ep.controls = controls;
}
if ep.role_controls.len() < self.config.max_roles {
ep.role_controls.insert(role_name, role_controls);
}
@@ -245,7 +324,7 @@ impl ProjectInfoCacheImpl {
return;
};
if role_controls.get().expires_at <= Instant::now() {
if role_controls.get().is_expired() {
role_controls.remove();
}
}
@@ -284,13 +363,11 @@ impl ProjectInfoCacheImpl {
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::*;
use crate::control_plane::messages::EndpointRateLimitConfig;
use crate::control_plane::messages::{Details, EndpointRateLimitConfig, ErrorInfo, Status};
use crate::control_plane::{AccessBlockerFlags, AuthSecret};
use crate::scram::ServerSecret;
use crate::types::ProjectId;
use std::sync::Arc;
#[tokio::test]
async fn test_project_info_cache_settings() {
@@ -301,9 +378,9 @@ mod tests {
ttl: Duration::from_secs(1),
gc_interval: Duration::from_secs(600),
});
let project_id: ProjectId = "project".into();
let project_id: Option<ProjectIdInt> = Some(ProjectIdInt::from(&"project".into()));
let endpoint_id: EndpointId = "endpoint".into();
let account_id: Option<AccountIdInt> = None;
let account_id = None;
let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into();
@@ -316,7 +393,7 @@ mod tests {
cache.insert_endpoint_access(
account_id,
(&project_id).into(),
project_id,
(&endpoint_id).into(),
(&user1).into(),
EndpointAccessControl {
@@ -332,7 +409,7 @@ mod tests {
cache.insert_endpoint_access(
account_id,
(&project_id).into(),
project_id,
(&endpoint_id).into(),
(&user2).into(),
EndpointAccessControl {
@@ -346,11 +423,17 @@ mod tests {
},
);
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
assert_eq!(cached.secret, secret1);
let (cached, ttl) = cache
.get_role_secret_with_ttl(&endpoint_id, &user1)
.unwrap();
assert_eq!(cached.unwrap().secret, secret1);
assert_eq!(ttl, cache.config.ttl);
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
assert_eq!(cached.secret, secret2);
let (cached, ttl) = cache
.get_role_secret_with_ttl(&endpoint_id, &user2)
.unwrap();
assert_eq!(cached.unwrap().secret, secret2);
assert_eq!(ttl, cache.config.ttl);
// Shouldn't add more than 2 roles.
let user3: RoleName = "user3".into();
@@ -358,7 +441,7 @@ mod tests {
cache.insert_endpoint_access(
account_id,
(&project_id).into(),
project_id,
(&endpoint_id).into(),
(&user3).into(),
EndpointAccessControl {
@@ -372,17 +455,144 @@ mod tests {
},
);
assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());
assert!(
cache
.get_role_secret_with_ttl(&endpoint_id, &user3)
.is_none()
);
let cached = cache.get_endpoint_access(&endpoint_id).unwrap();
let cached = cache
.get_endpoint_access_with_ttl(&endpoint_id)
.unwrap()
.0
.unwrap();
assert_eq!(cached.allowed_ips, allowed_ips);
tokio::time::advance(Duration::from_secs(2)).await;
let cached = cache.get_role_secret(&endpoint_id, &user1);
let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user1);
assert!(cached.is_none());
let cached = cache.get_role_secret(&endpoint_id, &user2);
let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user2);
assert!(cached.is_none());
let cached = cache.get_endpoint_access(&endpoint_id);
let cached = cache.get_endpoint_access_with_ttl(&endpoint_id);
assert!(cached.is_none());
}
/// Exercises `insert_endpoint_access_err` and how cached errors interact with
/// cached successes inserted via `insert_endpoint_access`.
///
/// The steps build on each other's cache state, so their order matters:
/// 1. a `RoleProtected` error is cached for the role only, not for endpoint access;
/// 2. any other error is cached for both the role and endpoint access;
/// 3. an error cached for one role is not visible to other roles;
/// 4. a success for another role keeps the first role's error but replaces the
///    endpoint-access error;
/// 5. a later error replaces that role's entry but not a live, successful
///    endpoint-access entry.
///
/// Every error is inserted with `ttl = None`, so the cache's configured TTL applies.
#[tokio::test]
async fn test_caching_project_info_errors() {
let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
size: 10,
max_roles: 10,
ttl: Duration::from_secs(1),
gc_interval: Duration::from_secs(600),
});
let project_id = Some(ProjectIdInt::from(&"project".into()));
let endpoint_id: EndpointId = "endpoint".into();
let account_id = None;
let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into();
let secret = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
// An error whose reason is `Reason::RoleProtected`, i.e. the role-specific case.
let role_msg = Box::new(ControlPlaneErrorMessage {
error: "role is protected and cannot be used for password-based authentication"
.to_owned()
.into_boxed_str(),
http_status_code: http::StatusCode::NOT_FOUND,
status: Some(Status {
code: "PERMISSION_DENIED".to_owned().into_boxed_str(),
message: "role is protected and cannot be used for password-based authentication"
.to_owned()
.into_boxed_str(),
details: Details {
error_info: Some(ErrorInfo {
reason: Reason::RoleProtected,
}),
retry_info: None,
user_facing_message: None,
},
}),
});
// An error with no `status`, used here as the non-role-specific case.
let generic_msg = Box::new(ControlPlaneErrorMessage {
error: "oh noes".to_owned().into_boxed_str(),
http_status_code: http::StatusCode::NOT_FOUND,
status: None,
});
// Helpers that unwrap the cache lookup (panicking on a miss) and drop the TTL.
let get_role_secret = |endpoint_id, role_name| {
cache
.get_role_secret_with_ttl(endpoint_id, role_name)
.unwrap()
.0
};
let get_endpoint_access =
|endpoint_id| cache.get_endpoint_access_with_ttl(endpoint_id).unwrap().0;
// stores role-specific errors only for get_role_secret
cache.insert_endpoint_access_err(
(&endpoint_id).into(),
(&user1).into(),
role_msg.clone(),
None,
);
assert_eq!(
get_role_secret(&endpoint_id, &user1).unwrap_err().error,
role_msg.error
);
assert!(cache.get_endpoint_access_with_ttl(&endpoint_id).is_none());
// stores non-role specific errors for both get_role_secret and get_endpoint_access
cache.insert_endpoint_access_err(
(&endpoint_id).into(),
(&user1).into(),
generic_msg.clone(),
None,
);
assert_eq!(
get_role_secret(&endpoint_id, &user1).unwrap_err().error,
generic_msg.error
);
assert_eq!(
get_endpoint_access(&endpoint_id).unwrap_err().error,
generic_msg.error
);
// error isn't returned for other roles in the same endpoint
assert!(
cache
.get_role_secret_with_ttl(&endpoint_id, &user2)
.is_none()
);
// success for a role does not overwrite errors for other roles
cache.insert_endpoint_access(
account_id,
project_id,
(&endpoint_id).into(),
(&user2).into(),
EndpointAccessControl {
allowed_ips: Arc::new(vec![]),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
rate_limits: EndpointRateLimitConfig::default(),
},
RoleAccessControl {
secret: secret.clone(),
},
);
assert!(get_role_secret(&endpoint_id, &user1).is_err());
assert!(get_role_secret(&endpoint_id, &user2).is_ok());
// ...but does clear the access control error
assert!(get_endpoint_access(&endpoint_id).is_ok());
// storing an error does not overwrite successful access control response
cache.insert_endpoint_access_err(
(&endpoint_id).into(),
(&user2).into(),
generic_msg.clone(),
None,
);
// The error does replace user2's own cached role entry.
assert!(get_role_secret(&endpoint_id, &user2).is_err());
assert!(get_endpoint_access(&endpoint_id).is_ok());
}
}