proxy: Changes to rate limits and GetEndpointAccessControl caches. (#12048)

Precursor to https://github.com/neondatabase/cloud/issues/28333.

We want per-endpoint configuration for rate limits, which will be
distributed via the `GetEndpointAccessControl` API. This lays some of
the ground work.

1. Allow the endpoint rate limiter to accept a custom leaky bucket
config on check.
2. Remove the unused auth rate limiter, as I don't want to think about
how it fits into this.
3. Refactor the caching of `GetEndpointAccessControl`, as it adds
friction for adding new cached data to the API.

That third one was rather large. I couldn't find any way to split it up.
The core idea is that there's now only 2 cache APIs.
`get_endpoint_access_controls` and `get_role_access_controls`.

I'm pretty sure the behaviour is unchanged, except I did a drive by
change to fix #8989 because it felt harmless. The change in question is
that when a password validation fails, we eagerly expire the role cache
if the role was cached for 5 minutes. This is to allow for edge cases
where a user tries to connect with a reset password, but the cache never
expires the entry due to some redis related quirk (lag, or
misconfiguration, or cplane error)
This commit is contained in:
Conrad Ludgate
2025-06-02 09:38:35 +01:00
committed by GitHub
parent 87179e26b3
commit 589bfdfd02
21 changed files with 551 additions and 1348 deletions

View File

@@ -1,30 +1,25 @@
use std::collections::HashSet;
use std::collections::{HashMap, HashSet, hash_map};
use std::convert::Infallible;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::time::Duration;
use async_trait::async_trait;
use clashmap::ClashMap;
use clashmap::mapref::one::Ref;
use rand::{Rng, thread_rng};
use smol_str::SmolStr;
use tokio::sync::Mutex;
use tokio::time::Instant;
use tracing::{debug, info};
use super::{Cache, Cached};
use crate::auth::IpPattern;
use crate::config::ProjectInfoCacheOptions;
use crate::control_plane::{AccessBlockerFlags, AuthSecret};
use crate::control_plane::{EndpointAccessControl, RoleAccessControl};
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::types::{EndpointId, RoleName};
#[async_trait]
pub(crate) trait ProjectInfoCache {
fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt);
fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec<ProjectIdInt>);
fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt);
fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt);
fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt);
fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt);
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
async fn decrement_active_listeners(&self);
async fn increment_active_listeners(&self);
@@ -42,6 +37,10 @@ impl<T> Entry<T> {
value,
}
}
pub(crate) fn get(&self, valid_since: Instant) -> Option<&T> {
(valid_since < self.created_at).then_some(&self.value)
}
}
impl<T> From<T> for Entry<T> {
@@ -50,101 +49,32 @@ impl<T> From<T> for Entry<T> {
}
}
#[derive(Default)]
struct EndpointInfo {
secret: std::collections::HashMap<RoleNameInt, Entry<Option<AuthSecret>>>,
allowed_ips: Option<Entry<Arc<Vec<IpPattern>>>>,
block_public_or_vpc_access: Option<Entry<AccessBlockerFlags>>,
allowed_vpc_endpoint_ids: Option<Entry<Arc<Vec<String>>>>,
role_controls: HashMap<RoleNameInt, Entry<RoleAccessControl>>,
controls: Option<Entry<EndpointAccessControl>>,
}
impl EndpointInfo {
fn check_ignore_cache(ignore_cache_since: Option<Instant>, created_at: Instant) -> bool {
match ignore_cache_since {
None => false,
Some(t) => t < created_at,
}
}
pub(crate) fn get_role_secret(
&self,
role_name: RoleNameInt,
valid_since: Instant,
ignore_cache_since: Option<Instant>,
) -> Option<(Option<AuthSecret>, bool)> {
if let Some(secret) = self.secret.get(&role_name) {
if valid_since < secret.created_at {
return Some((
secret.value.clone(),
Self::check_ignore_cache(ignore_cache_since, secret.created_at),
));
}
}
None
) -> Option<RoleAccessControl> {
let controls = self.role_controls.get(&role_name)?;
controls.get(valid_since).cloned()
}
pub(crate) fn get_allowed_ips(
&self,
valid_since: Instant,
ignore_cache_since: Option<Instant>,
) -> Option<(Arc<Vec<IpPattern>>, bool)> {
if let Some(allowed_ips) = &self.allowed_ips {
if valid_since < allowed_ips.created_at {
return Some((
allowed_ips.value.clone(),
Self::check_ignore_cache(ignore_cache_since, allowed_ips.created_at),
));
}
}
None
}
pub(crate) fn get_allowed_vpc_endpoint_ids(
&self,
valid_since: Instant,
ignore_cache_since: Option<Instant>,
) -> Option<(Arc<Vec<String>>, bool)> {
if let Some(allowed_vpc_endpoint_ids) = &self.allowed_vpc_endpoint_ids {
if valid_since < allowed_vpc_endpoint_ids.created_at {
return Some((
allowed_vpc_endpoint_ids.value.clone(),
Self::check_ignore_cache(
ignore_cache_since,
allowed_vpc_endpoint_ids.created_at,
),
));
}
}
None
}
pub(crate) fn get_block_public_or_vpc_access(
&self,
valid_since: Instant,
ignore_cache_since: Option<Instant>,
) -> Option<(AccessBlockerFlags, bool)> {
if let Some(block_public_or_vpc_access) = &self.block_public_or_vpc_access {
if valid_since < block_public_or_vpc_access.created_at {
return Some((
block_public_or_vpc_access.value.clone(),
Self::check_ignore_cache(
ignore_cache_since,
block_public_or_vpc_access.created_at,
),
));
}
}
None
pub(crate) fn get_controls(&self, valid_since: Instant) -> Option<EndpointAccessControl> {
let controls = self.controls.as_ref()?;
controls.get(valid_since).cloned()
}
pub(crate) fn invalidate_allowed_ips(&mut self) {
self.allowed_ips = None;
}
pub(crate) fn invalidate_allowed_vpc_endpoint_ids(&mut self) {
self.allowed_vpc_endpoint_ids = None;
}
pub(crate) fn invalidate_block_public_or_vpc_access(&mut self) {
self.block_public_or_vpc_access = None;
pub(crate) fn invalidate_endpoint(&mut self) {
self.controls = None;
}
pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) {
self.secret.remove(&role_name);
self.role_controls.remove(&role_name);
}
}
@@ -170,34 +100,22 @@ pub struct ProjectInfoCacheImpl {
#[async_trait]
impl ProjectInfoCache for ProjectInfoCacheImpl {
fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec<ProjectIdInt>) {
info!(
"invalidating allowed vpc endpoint ids for projects `{}`",
project_ids
.iter()
.map(|id| id.to_string())
.collect::<Vec<_>>()
.join(", ")
);
for project_id in project_ids {
let endpoints = self
.project2ep
.get(&project_id)
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_allowed_vpc_endpoint_ids();
}
fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
info!("invalidating endpoint access for project `{project_id}`");
let endpoints = self
.project2ep
.get(&project_id)
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_endpoint();
}
}
}
fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt) {
info!(
"invalidating allowed vpc endpoint ids for org `{}`",
account_id
);
fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) {
info!("invalidating endpoint access for org `{account_id}`");
let endpoints = self
.account2ep
.get(&account_id)
@@ -205,41 +123,11 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_allowed_vpc_endpoint_ids();
endpoint_info.invalidate_endpoint();
}
}
}
fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt) {
info!(
"invalidating block public or vpc access for project `{}`",
project_id
);
let endpoints = self
.project2ep
.get(&project_id)
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_block_public_or_vpc_access();
}
}
}
fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) {
info!("invalidating allowed ips for project `{}`", project_id);
let endpoints = self
.project2ep
.get(&project_id)
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_allowed_ips();
}
}
}
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) {
info!(
"invalidating role secret for project_id `{}` and role_name `{}`",
@@ -256,6 +144,7 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
}
}
}
async fn decrement_active_listeners(&self) {
let mut listeners_guard = self.active_listeners_lock.lock().await;
if *listeners_guard == 0 {
@@ -293,155 +182,71 @@ impl ProjectInfoCacheImpl {
}
}
fn get_endpoint_cache(
&self,
endpoint_id: &EndpointId,
) -> Option<Ref<'_, EndpointIdInt, EndpointInfo>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
self.cache.get(&endpoint_id)
}
pub(crate) fn get_role_secret(
&self,
endpoint_id: &EndpointId,
role_name: &RoleName,
) -> Option<Cached<&Self, Option<AuthSecret>>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
) -> Option<RoleAccessControl> {
let valid_since = self.get_cache_times();
let role_name = RoleNameInt::get(role_name)?;
let (valid_since, ignore_cache_since) = self.get_cache_times();
let endpoint_info = self.cache.get(&endpoint_id)?;
let (value, ignore_cache) =
endpoint_info.get_role_secret(role_name, valid_since, ignore_cache_since)?;
if !ignore_cache {
let cached = Cached {
token: Some((
self,
CachedLookupInfo::new_role_secret(endpoint_id, role_name),
)),
value,
};
return Some(cached);
}
Some(Cached::new_uncached(value))
}
pub(crate) fn get_allowed_ips(
&self,
endpoint_id: &EndpointId,
) -> Option<Cached<&Self, Arc<Vec<IpPattern>>>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
let (valid_since, ignore_cache_since) = self.get_cache_times();
let endpoint_info = self.cache.get(&endpoint_id)?;
let value = endpoint_info.get_allowed_ips(valid_since, ignore_cache_since);
let (value, ignore_cache) = value?;
if !ignore_cache {
let cached = Cached {
token: Some((self, CachedLookupInfo::new_allowed_ips(endpoint_id))),
value,
};
return Some(cached);
}
Some(Cached::new_uncached(value))
}
pub(crate) fn get_allowed_vpc_endpoint_ids(
&self,
endpoint_id: &EndpointId,
) -> Option<Cached<&Self, Arc<Vec<String>>>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
let (valid_since, ignore_cache_since) = self.get_cache_times();
let endpoint_info = self.cache.get(&endpoint_id)?;
let value = endpoint_info.get_allowed_vpc_endpoint_ids(valid_since, ignore_cache_since);
let (value, ignore_cache) = value?;
if !ignore_cache {
let cached = Cached {
token: Some((
self,
CachedLookupInfo::new_allowed_vpc_endpoint_ids(endpoint_id),
)),
value,
};
return Some(cached);
}
Some(Cached::new_uncached(value))
}
pub(crate) fn get_block_public_or_vpc_access(
&self,
endpoint_id: &EndpointId,
) -> Option<Cached<&Self, AccessBlockerFlags>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
let (valid_since, ignore_cache_since) = self.get_cache_times();
let endpoint_info = self.cache.get(&endpoint_id)?;
let value = endpoint_info.get_block_public_or_vpc_access(valid_since, ignore_cache_since);
let (value, ignore_cache) = value?;
if !ignore_cache {
let cached = Cached {
token: Some((
self,
CachedLookupInfo::new_block_public_or_vpc_access(endpoint_id),
)),
value,
};
return Some(cached);
}
Some(Cached::new_uncached(value))
let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
endpoint_info.get_role_secret(role_name, valid_since)
}
pub(crate) fn insert_role_secret(
pub(crate) fn get_endpoint_access(
&self,
project_id: ProjectIdInt,
endpoint_id: EndpointIdInt,
role_name: RoleNameInt,
secret: Option<AuthSecret>,
) {
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
self.insert_project2endpoint(project_id, endpoint_id);
let mut entry = self.cache.entry(endpoint_id).or_default();
if entry.secret.len() < self.config.max_roles {
entry.secret.insert(role_name, secret.into());
}
endpoint_id: &EndpointId,
) -> Option<EndpointAccessControl> {
let valid_since = self.get_cache_times();
let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
endpoint_info.get_controls(valid_since)
}
pub(crate) fn insert_allowed_ips(
&self,
project_id: ProjectIdInt,
endpoint_id: EndpointIdInt,
allowed_ips: Arc<Vec<IpPattern>>,
) {
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
self.insert_project2endpoint(project_id, endpoint_id);
self.cache.entry(endpoint_id).or_default().allowed_ips = Some(allowed_ips.into());
}
pub(crate) fn insert_allowed_vpc_endpoint_ids(
pub(crate) fn insert_endpoint_access(
&self,
account_id: Option<AccountIdInt>,
project_id: ProjectIdInt,
endpoint_id: EndpointIdInt,
allowed_vpc_endpoint_ids: Arc<Vec<String>>,
role_name: RoleNameInt,
controls: EndpointAccessControl,
role_controls: RoleAccessControl,
) {
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
if let Some(account_id) = account_id {
self.insert_account2endpoint(account_id, endpoint_id);
}
self.insert_project2endpoint(project_id, endpoint_id);
self.cache
.entry(endpoint_id)
.or_default()
.allowed_vpc_endpoint_ids = Some(allowed_vpc_endpoint_ids.into());
}
pub(crate) fn insert_block_public_or_vpc_access(
&self,
project_id: ProjectIdInt,
endpoint_id: EndpointIdInt,
access_blockers: AccessBlockerFlags,
) {
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
self.insert_project2endpoint(project_id, endpoint_id);
self.cache
.entry(endpoint_id)
.or_default()
.block_public_or_vpc_access = Some(access_blockers.into());
let controls = Entry::from(controls);
let role_controls = Entry::from(role_controls);
match self.cache.entry(endpoint_id) {
clashmap::Entry::Vacant(e) => {
e.insert(EndpointInfo {
role_controls: HashMap::from_iter([(role_name, role_controls)]),
controls: Some(controls),
});
}
clashmap::Entry::Occupied(mut e) => {
let ep = e.get_mut();
ep.controls = Some(controls);
if ep.role_controls.len() < self.config.max_roles {
ep.role_controls.insert(role_name, role_controls);
}
}
}
}
fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
@@ -452,6 +257,7 @@ impl ProjectInfoCacheImpl {
.insert(project_id, HashSet::from([endpoint_id]));
}
}
fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) {
if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) {
endpoints.insert(endpoint_id);
@@ -460,21 +266,57 @@ impl ProjectInfoCacheImpl {
.insert(account_id, HashSet::from([endpoint_id]));
}
}
fn get_cache_times(&self) -> (Instant, Option<Instant>) {
let mut valid_since = Instant::now() - self.config.ttl;
// Only ignore cache if ttl is disabled.
fn ignore_ttl_since(&self) -> Option<Instant> {
let ttl_disabled_since_us = self
.ttl_disabled_since_us
.load(std::sync::atomic::Ordering::Relaxed);
let ignore_cache_since = if ttl_disabled_since_us == u64::MAX {
None
} else {
let ignore_cache_since = self.start_time + Duration::from_micros(ttl_disabled_since_us);
if ttl_disabled_since_us == u64::MAX {
return None;
}
Some(self.start_time + Duration::from_micros(ttl_disabled_since_us))
}
fn get_cache_times(&self) -> Instant {
let mut valid_since = Instant::now() - self.config.ttl;
if let Some(ignore_ttl_since) = self.ignore_ttl_since() {
// We are fine if entry is not older than ttl or was added before we are getting notifications.
valid_since = valid_since.min(ignore_cache_since);
Some(ignore_cache_since)
valid_since = valid_since.min(ignore_ttl_since);
}
valid_since
}
pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) {
let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else {
return;
};
(valid_since, ignore_cache_since)
let Some(role_name) = RoleNameInt::get(role_name) else {
return;
};
let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) else {
return;
};
let entry = endpoint_info.role_controls.entry(role_name);
let hash_map::Entry::Occupied(role_controls) = entry else {
return;
};
let created_at = role_controls.get().created_at;
let expire = match self.ignore_ttl_since() {
// if ignoring TTL, we should still try and roll the password if it's old
// and we the client gave an incorrect password. There could be some lag on the redis channel.
Some(_) => created_at + self.config.ttl < Instant::now(),
// edge case: redis is down, let's be generous and invalidate the cache immediately.
None => true,
};
if expire {
role_controls.remove();
}
}
pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
@@ -509,84 +351,12 @@ impl ProjectInfoCacheImpl {
}
}
/// Lookup info for project info cache.
/// This is used to invalidate cache entries.
pub(crate) struct CachedLookupInfo {
/// Search by this key.
endpoint_id: EndpointIdInt,
lookup_type: LookupType,
}
impl CachedLookupInfo {
pub(self) fn new_role_secret(endpoint_id: EndpointIdInt, role_name: RoleNameInt) -> Self {
Self {
endpoint_id,
lookup_type: LookupType::RoleSecret(role_name),
}
}
pub(self) fn new_allowed_ips(endpoint_id: EndpointIdInt) -> Self {
Self {
endpoint_id,
lookup_type: LookupType::AllowedIps,
}
}
pub(self) fn new_allowed_vpc_endpoint_ids(endpoint_id: EndpointIdInt) -> Self {
Self {
endpoint_id,
lookup_type: LookupType::AllowedVpcEndpointIds,
}
}
pub(self) fn new_block_public_or_vpc_access(endpoint_id: EndpointIdInt) -> Self {
Self {
endpoint_id,
lookup_type: LookupType::BlockPublicOrVpcAccess,
}
}
}
enum LookupType {
RoleSecret(RoleNameInt),
AllowedIps,
AllowedVpcEndpointIds,
BlockPublicOrVpcAccess,
}
impl Cache for ProjectInfoCacheImpl {
type Key = SmolStr;
// Value is not really used here, but we need to specify it.
type Value = SmolStr;
type LookupInfo<Key> = CachedLookupInfo;
fn invalidate(&self, key: &Self::LookupInfo<SmolStr>) {
match &key.lookup_type {
LookupType::RoleSecret(role_name) => {
if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
endpoint_info.invalidate_role_secret(*role_name);
}
}
LookupType::AllowedIps => {
if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
endpoint_info.invalidate_allowed_ips();
}
}
LookupType::AllowedVpcEndpointIds => {
if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
endpoint_info.invalidate_allowed_vpc_endpoint_ids();
}
}
LookupType::BlockPublicOrVpcAccess => {
if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
endpoint_info.invalidate_block_public_or_vpc_access();
}
}
}
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::*;
use crate::control_plane::{AccessBlockerFlags, AuthSecret};
use crate::scram::ServerSecret;
use crate::types::ProjectId;
@@ -601,6 +371,8 @@ mod tests {
});
let project_id: ProjectId = "project".into();
let endpoint_id: EndpointId = "endpoint".into();
let account_id: Option<AccountIdInt> = None;
let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into();
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
@@ -609,183 +381,73 @@ mod tests {
"127.0.0.1".parse().unwrap(),
"127.0.0.2".parse().unwrap(),
]);
cache.insert_role_secret(
cache.insert_endpoint_access(
account_id,
(&project_id).into(),
(&endpoint_id).into(),
(&user1).into(),
secret1.clone(),
EndpointAccessControl {
allowed_ips: allowed_ips.clone(),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
},
RoleAccessControl {
secret: secret1.clone(),
},
);
cache.insert_role_secret(
cache.insert_endpoint_access(
account_id,
(&project_id).into(),
(&endpoint_id).into(),
(&user2).into(),
secret2.clone(),
);
cache.insert_allowed_ips(
(&project_id).into(),
(&endpoint_id).into(),
allowed_ips.clone(),
EndpointAccessControl {
allowed_ips: allowed_ips.clone(),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
},
RoleAccessControl {
secret: secret2.clone(),
},
);
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
assert!(cached.cached());
assert_eq!(cached.value, secret1);
assert_eq!(cached.secret, secret1);
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
assert!(cached.cached());
assert_eq!(cached.value, secret2);
assert_eq!(cached.secret, secret2);
// Shouldn't add more than 2 roles.
let user3: RoleName = "user3".into();
let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32])));
cache.insert_role_secret(
cache.insert_endpoint_access(
account_id,
(&project_id).into(),
(&endpoint_id).into(),
(&user3).into(),
secret3.clone(),
EndpointAccessControl {
allowed_ips: allowed_ips.clone(),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
},
RoleAccessControl {
secret: secret3.clone(),
},
);
assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());
let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
assert!(cached.cached());
assert_eq!(cached.value, allowed_ips);
let cached = cache.get_endpoint_access(&endpoint_id).unwrap();
assert_eq!(cached.allowed_ips, allowed_ips);
tokio::time::advance(Duration::from_secs(2)).await;
let cached = cache.get_role_secret(&endpoint_id, &user1);
assert!(cached.is_none());
let cached = cache.get_role_secret(&endpoint_id, &user2);
assert!(cached.is_none());
let cached = cache.get_allowed_ips(&endpoint_id);
let cached = cache.get_endpoint_access(&endpoint_id);
assert!(cached.is_none());
}
#[tokio::test]
async fn test_project_info_cache_invalidations() {
tokio::time::pause();
let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
size: 2,
max_roles: 2,
ttl: Duration::from_secs(1),
gc_interval: Duration::from_secs(600),
}));
cache.clone().increment_active_listeners().await;
tokio::time::advance(Duration::from_secs(2)).await;
let project_id: ProjectId = "project".into();
let endpoint_id: EndpointId = "endpoint".into();
let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into();
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32])));
let allowed_ips = Arc::new(vec![
"127.0.0.1".parse().unwrap(),
"127.0.0.2".parse().unwrap(),
]);
cache.insert_role_secret(
(&project_id).into(),
(&endpoint_id).into(),
(&user1).into(),
secret1.clone(),
);
cache.insert_role_secret(
(&project_id).into(),
(&endpoint_id).into(),
(&user2).into(),
secret2.clone(),
);
cache.insert_allowed_ips(
(&project_id).into(),
(&endpoint_id).into(),
allowed_ips.clone(),
);
tokio::time::advance(Duration::from_secs(2)).await;
// Nothing should be invalidated.
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
// TTL is disabled, so it should be impossible to invalidate this value.
assert!(!cached.cached());
assert_eq!(cached.value, secret1);
cached.invalidate(); // Shouldn't do anything.
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
assert_eq!(cached.value, secret1);
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
assert!(!cached.cached());
assert_eq!(cached.value, secret2);
// The only way to invalidate this value is to invalidate via the api.
cache.invalidate_role_secret_for_project((&project_id).into(), (&user2).into());
assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
assert!(!cached.cached());
assert_eq!(cached.value, allowed_ips);
}
#[tokio::test]
async fn test_increment_active_listeners_invalidate_added_before() {
tokio::time::pause();
let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
size: 2,
max_roles: 2,
ttl: Duration::from_secs(1),
gc_interval: Duration::from_secs(600),
}));
let project_id: ProjectId = "project".into();
let endpoint_id: EndpointId = "endpoint".into();
let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into();
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32])));
let allowed_ips = Arc::new(vec![
"127.0.0.1".parse().unwrap(),
"127.0.0.2".parse().unwrap(),
]);
cache.insert_role_secret(
(&project_id).into(),
(&endpoint_id).into(),
(&user1).into(),
secret1.clone(),
);
cache.clone().increment_active_listeners().await;
tokio::time::advance(Duration::from_millis(100)).await;
cache.insert_role_secret(
(&project_id).into(),
(&endpoint_id).into(),
(&user2).into(),
secret2.clone(),
);
// Added before ttl was disabled + ttl should be still cached.
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
assert!(cached.cached());
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
assert!(cached.cached());
tokio::time::advance(Duration::from_secs(1)).await;
// Added before ttl was disabled + ttl should expire.
assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
// Added after ttl was disabled + ttl should not be cached.
cache.insert_allowed_ips(
(&project_id).into(),
(&endpoint_id).into(),
allowed_ips.clone(),
);
let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
assert!(!cached.cached());
tokio::time::advance(Duration::from_secs(1)).await;
// Added before ttl was disabled + ttl still should expire.
assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
// Shouldn't be invalidated.
let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
assert!(!cached.cached());
assert_eq!(cached.value, allowed_ips);
}
}