Mirror of https://github.com/neondatabase/neon.git, synced 2025-12-22 21:59:59 +00:00
[proxy] introduce moka for the project-info cache (#12710)
## Problem

LKB-2502

The garbage collection of the project-info cache is ineffective. What we observed:

- If we get unlucky, we might throw away a very hot entry when the cache is full.
- The GC loop depends on getting a lucky shard of the project2ep table that clears a lot of cold entries.
- The GC does not take active use into account, and the interval it runs at is too sparse to do any good.

Can we switch to a proper cache implementation? Complications:

1. We need to invalidate by project/account.
2. We need to expire based on `retry_delay_ms`.

## Summary of changes

1. Replace `retry_delay_ms: Duration` with `retry_at: Instant` when deserializing.
2. Split the EndpointControls from the RoleControls into two different caches.
3. Introduce an expiry policy based on error retry info.
4. Introduce `moka` as a dependency, replacing our `TimedLru`. See the follow-up PR for changing all TimedLru instances to use moka: #12726.
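For a concrete picture of changes 3 and 4, the sketch below shows a moka cache whose `Ok` entries fall back to the cache-wide TTL while `Err` entries expire at their retry deadline. It is a minimal illustration, not the proxy's code: `FakeError` and `ErrorExpiry` are simplified stand-ins for `ControlPlaneErrorMessage` and `CplaneExpiry`, and it assumes the `moka` crate with the `sync` feature enabled.

```rust
use std::time::{Duration, Instant};

use moka::{Expiry, sync::Cache};

/// Simplified stand-in for a control-plane error carrying an optional retry deadline.
#[derive(Clone)]
struct FakeError {
    retry_at: Option<Instant>,
}

/// Per-entry expiry: errors live until their retry deadline (or a short default),
/// successes return None and therefore keep the builder-level time_to_live.
#[derive(Clone, Copy)]
struct ErrorExpiry {
    default_error_ttl: Duration,
}

impl<K> Expiry<K, Result<String, FakeError>> for ErrorExpiry {
    fn expire_after_create(
        &self,
        _key: &K,
        value: &Result<String, FakeError>,
        created_at: Instant,
    ) -> Option<Duration> {
        match value {
            Ok(_) => None,
            Err(e) => Some(e.retry_at.map_or(self.default_error_ttl, |at| {
                at.saturating_duration_since(created_at)
            })),
        }
    }
}

fn main() {
    let cache: Cache<&'static str, Result<String, FakeError>> = Cache::builder()
        .max_capacity(10_000)
        .time_to_live(Duration::from_secs(300))
        .expire_after(ErrorExpiry {
            default_error_ttl: Duration::from_secs(30),
        })
        .build();

    // A successful entry keeps the 300s TTL; an error entry expires once its retry deadline passes.
    cache.insert("ep-ok", Ok("endpoint controls".to_owned()));
    cache.insert(
        "ep-err",
        Err(FakeError {
            retry_at: Some(Instant::now() + Duration::from_secs(5)),
        }),
    );
    assert!(cache.get("ep-ok").is_some());
}
```

Returning `None` from the expiry hook keeps the builder's `time_to_live`, so only error entries get a shortened lifetime; this mirrors what `CplaneExpiry` in `proxy/src/cache/common.rs` does for the real control-plane types.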
@@ -54,6 +54,7 @@ json = { path = "../libs/proxy/json" }
lasso = { workspace = true, features = ["multi-threaded"] }
measured = { workspace = true, features = ["lasso"] }
metrics.workspace = true
moka.workspace = true
once_cell.workspace = true
opentelemetry = { workspace = true, features = ["trace"] }
papaya = "0.2.0"
proxy/src/cache/common.rs (68 changed lines, vendored)
@@ -1,4 +1,14 @@
use std::ops::{Deref, DerefMut};
use std::{
    ops::{Deref, DerefMut},
    time::{Duration, Instant},
};

use moka::Expiry;

use crate::control_plane::messages::ControlPlaneErrorMessage;

/// Default TTL used when caching errors from control plane.
pub const DEFAULT_ERROR_TTL: Duration = Duration::from_secs(30);

/// A generic trait which exposes types of cache's key and value,
/// as well as the notion of cache entry invalidation.
@@ -87,3 +97,59 @@ impl<C: Cache, V> DerefMut for Cached<C, V> {
        &mut self.value
    }
}

pub type ControlPlaneResult<T> = Result<T, Box<ControlPlaneErrorMessage>>;

#[derive(Clone, Copy)]
pub struct CplaneExpiry {
    pub error: Duration,
}

impl Default for CplaneExpiry {
    fn default() -> Self {
        Self {
            error: DEFAULT_ERROR_TTL,
        }
    }
}

impl CplaneExpiry {
    pub fn expire_early<V>(
        &self,
        value: &ControlPlaneResult<V>,
        updated: Instant,
    ) -> Option<Duration> {
        match value {
            Ok(_) => None,
            Err(err) => Some(self.expire_err_early(err, updated)),
        }
    }

    pub fn expire_err_early(&self, err: &ControlPlaneErrorMessage, updated: Instant) -> Duration {
        err.status
            .as_ref()
            .and_then(|s| s.details.retry_info.as_ref())
            .map_or(self.error, |r| r.retry_at.into_std() - updated)
    }
}

impl<K, V> Expiry<K, ControlPlaneResult<V>> for CplaneExpiry {
    fn expire_after_create(
        &self,
        _key: &K,
        value: &ControlPlaneResult<V>,
        created_at: Instant,
    ) -> Option<Duration> {
        self.expire_early(value, created_at)
    }

    fn expire_after_update(
        &self,
        _key: &K,
        value: &ControlPlaneResult<V>,
        updated_at: Instant,
        _duration_until_expiry: Option<Duration>,
    ) -> Option<Duration> {
        self.expire_early(value, updated_at)
    }
}
proxy/src/cache/project_info.rs (356 changed lines, vendored)
@@ -1,84 +1,17 @@
use std::collections::{HashMap, HashSet, hash_map};
use std::collections::HashSet;
use std::convert::Infallible;
use std::time::Duration;

use async_trait::async_trait;
use clashmap::ClashMap;
use clashmap::mapref::one::Ref;
use rand::Rng;
use tokio::time::Instant;
use moka::sync::Cache;
use tracing::{debug, info};

use crate::cache::common::{ControlPlaneResult, CplaneExpiry};
use crate::config::ProjectInfoCacheOptions;
use crate::control_plane::messages::{ControlPlaneErrorMessage, Reason};
use crate::control_plane::{EndpointAccessControl, RoleAccessControl};
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::types::{EndpointId, RoleName};

#[async_trait]
pub(crate) trait ProjectInfoCache {
    fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt);
    fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt);
    fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt);
    fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
}

struct Entry<T> {
    expires_at: Instant,
    value: T,
}

impl<T> Entry<T> {
    pub(crate) fn new(value: T, ttl: Duration) -> Self {
        Self {
            expires_at: Instant::now() + ttl,
            value,
        }
    }

    pub(crate) fn get(&self) -> Option<&T> {
        (!self.is_expired()).then_some(&self.value)
    }

    fn is_expired(&self) -> bool {
        self.expires_at <= Instant::now()
    }
}

struct EndpointInfo {
    role_controls: HashMap<RoleNameInt, Entry<ControlPlaneResult<RoleAccessControl>>>,
    controls: Option<Entry<ControlPlaneResult<EndpointAccessControl>>>,
}

type ControlPlaneResult<T> = Result<T, Box<ControlPlaneErrorMessage>>;

impl EndpointInfo {
    pub(crate) fn get_role_secret_with_ttl(
        &self,
        role_name: RoleNameInt,
    ) -> Option<(ControlPlaneResult<RoleAccessControl>, Duration)> {
        let entry = self.role_controls.get(&role_name)?;
        let ttl = entry.expires_at - Instant::now();
        Some((entry.get()?.clone(), ttl))
    }

    pub(crate) fn get_controls_with_ttl(
        &self,
    ) -> Option<(ControlPlaneResult<EndpointAccessControl>, Duration)> {
        let entry = self.controls.as_ref()?;
        let ttl = entry.expires_at - Instant::now();
        Some((entry.get()?.clone(), ttl))
    }

    pub(crate) fn invalidate_endpoint(&mut self) {
        self.controls = None;
    }

    pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) {
        self.role_controls.remove(&role_name);
    }
}

/// Cache for project info.
/// This is used to cache auth data for endpoints.
/// Invalidation is done by console notifications or by TTL (if console notifications are disabled).
@@ -86,8 +19,9 @@ impl EndpointInfo {
/// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data.
/// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available?
/// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache.
pub struct ProjectInfoCacheImpl {
    cache: ClashMap<EndpointIdInt, EndpointInfo>,
pub struct ProjectInfoCache {
    role_controls: Cache<(EndpointIdInt, RoleNameInt), ControlPlaneResult<RoleAccessControl>>,
    ep_controls: Cache<EndpointIdInt, ControlPlaneResult<EndpointAccessControl>>,

    project2ep: ClashMap<ProjectIdInt, HashSet<EndpointIdInt>>,
    // FIXME(stefan): we need a way to GC the account2ep map.
@@ -96,16 +30,13 @@ pub struct ProjectInfoCacheImpl {
    config: ProjectInfoCacheOptions,
}

#[async_trait]
impl ProjectInfoCache for ProjectInfoCacheImpl {
    fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) {
impl ProjectInfoCache {
    pub fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) {
        info!("invalidating endpoint access for `{endpoint_id}`");
        if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
            endpoint_info.invalidate_endpoint();
        }
        self.ep_controls.invalidate(&endpoint_id);
    }

    fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
    pub fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
        info!("invalidating endpoint access for project `{project_id}`");
        let endpoints = self
            .project2ep
@@ -113,13 +44,11 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
            .map(|kv| kv.value().clone())
            .unwrap_or_default();
        for endpoint_id in endpoints {
            if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
                endpoint_info.invalidate_endpoint();
            }
            self.ep_controls.invalidate(&endpoint_id);
        }
    }

    fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) {
    pub fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) {
        info!("invalidating endpoint access for org `{account_id}`");
        let endpoints = self
            .account2ep
@@ -127,13 +56,15 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
            .map(|kv| kv.value().clone())
            .unwrap_or_default();
        for endpoint_id in endpoints {
            if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
                endpoint_info.invalidate_endpoint();
            }
            self.ep_controls.invalidate(&endpoint_id);
        }
    }

    fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) {
    pub fn invalidate_role_secret_for_project(
        &self,
        project_id: ProjectIdInt,
        role_name: RoleNameInt,
    ) {
        info!(
            "invalidating role secret for project_id `{}` and role_name `{}`",
            project_id, role_name,
@@ -144,47 +75,52 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
            .map(|kv| kv.value().clone())
            .unwrap_or_default();
        for endpoint_id in endpoints {
            if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
                endpoint_info.invalidate_role_secret(role_name);
            }
            self.role_controls.invalidate(&(endpoint_id, role_name));
        }
    }
}

impl ProjectInfoCacheImpl {
impl ProjectInfoCache {
    pub(crate) fn new(config: ProjectInfoCacheOptions) -> Self {
        // we cache errors for 30 seconds, unless retry_at is set.
        let expiry = CplaneExpiry::default();
        Self {
            cache: ClashMap::new(),
            role_controls: Cache::builder()
                .name("role_access_controls")
                .max_capacity(config.size * config.max_roles)
                .time_to_live(config.ttl)
                .expire_after(expiry)
                .build(),
            ep_controls: Cache::builder()
                .name("endpoint_access_controls")
                .max_capacity(config.size)
                .time_to_live(config.ttl)
                .expire_after(expiry)
                .build(),
            project2ep: ClashMap::new(),
            account2ep: ClashMap::new(),
            config,
        }
    }

    fn get_endpoint_cache(
        &self,
        endpoint_id: &EndpointId,
    ) -> Option<Ref<'_, EndpointIdInt, EndpointInfo>> {
        let endpoint_id = EndpointIdInt::get(endpoint_id)?;
        self.cache.get(&endpoint_id)
    }

    pub(crate) fn get_role_secret_with_ttl(
    pub(crate) fn get_role_secret(
        &self,
        endpoint_id: &EndpointId,
        role_name: &RoleName,
    ) -> Option<(ControlPlaneResult<RoleAccessControl>, Duration)> {
    ) -> Option<ControlPlaneResult<RoleAccessControl>> {
        let endpoint_id = EndpointIdInt::get(endpoint_id)?;
        let role_name = RoleNameInt::get(role_name)?;
        let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
        endpoint_info.get_role_secret_with_ttl(role_name)

        self.role_controls.get(&(endpoint_id, role_name))
    }

    pub(crate) fn get_endpoint_access_with_ttl(
    pub(crate) fn get_endpoint_access(
        &self,
        endpoint_id: &EndpointId,
    ) -> Option<(ControlPlaneResult<EndpointAccessControl>, Duration)> {
        let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
        endpoint_info.get_controls_with_ttl()
    ) -> Option<ControlPlaneResult<EndpointAccessControl>> {
        let endpoint_id = EndpointIdInt::get(endpoint_id)?;

        self.ep_controls.get(&endpoint_id)
    }

    pub(crate) fn insert_endpoint_access(
@@ -203,34 +139,14 @@ impl ProjectInfoCacheImpl {
            self.insert_project2endpoint(project_id, endpoint_id);
        }

        if self.cache.len() >= self.config.size {
            // If there are too many entries, wait until the next gc cycle.
            return;
        }

        debug!(
            key = &*endpoint_id,
            "created a cache entry for endpoint access"
        );

        let controls = Some(Entry::new(Ok(controls), self.config.ttl));
        let role_controls = Entry::new(Ok(role_controls), self.config.ttl);

        match self.cache.entry(endpoint_id) {
            clashmap::Entry::Vacant(e) => {
                e.insert(EndpointInfo {
                    role_controls: HashMap::from_iter([(role_name, role_controls)]),
                    controls,
                });
            }
            clashmap::Entry::Occupied(mut e) => {
                let ep = e.get_mut();
                ep.controls = controls;
                if ep.role_controls.len() < self.config.max_roles {
                    ep.role_controls.insert(role_name, role_controls);
                }
            }
        }
        self.ep_controls.insert(endpoint_id, Ok(controls));
        self.role_controls
            .insert((endpoint_id, role_name), Ok(role_controls));
    }

    pub(crate) fn insert_endpoint_access_err(
@@ -238,55 +154,30 @@ impl ProjectInfoCacheImpl {
        endpoint_id: EndpointIdInt,
        role_name: RoleNameInt,
        msg: Box<ControlPlaneErrorMessage>,
        ttl: Option<Duration>,
    ) {
        if self.cache.len() >= self.config.size {
            // If there are too many entries, wait until the next gc cycle.
            return;
        }

        debug!(
            key = &*endpoint_id,
            "created a cache entry for an endpoint access error"
        );

        let ttl = ttl.unwrap_or(self.config.ttl);

        let controls = if msg.get_reason() == Reason::RoleProtected {
            // RoleProtected is the only role-specific error that control plane can give us.
            // If a given role name does not exist, it still returns a successful response,
            // just with an empty secret.
            None
        } else {
            // We can cache all the other errors in EndpointInfo.controls,
            // because they don't depend on what role name we pass to control plane.
            Some(Entry::new(Err(msg.clone()), ttl))
        };

        let role_controls = Entry::new(Err(msg), ttl);

        match self.cache.entry(endpoint_id) {
            clashmap::Entry::Vacant(e) => {
                e.insert(EndpointInfo {
                    role_controls: HashMap::from_iter([(role_name, role_controls)]),
                    controls,
        // RoleProtected is the only role-specific error that control plane can give us.
        // If a given role name does not exist, it still returns a successful response,
        // just with an empty secret.
        if msg.get_reason() != Reason::RoleProtected {
            // We can cache all the other errors in ep_controls because they don't
            // depend on what role name we pass to control plane.
            self.ep_controls
                .entry(endpoint_id)
                .and_compute_with(|entry| match entry {
                    // leave the entry alone if it's already Ok
                    Some(entry) if entry.value().is_ok() => moka::ops::compute::Op::Nop,
                    // replace the entry
                    _ => moka::ops::compute::Op::Put(Err(msg.clone())),
                });
            }
            clashmap::Entry::Occupied(mut e) => {
                let ep = e.get_mut();
                if let Some(entry) = &ep.controls
                    && !entry.is_expired()
                    && entry.value.is_ok()
                {
                    // If we have cached non-expired, non-error controls, keep them.
                } else {
                    ep.controls = controls;
                }
                if ep.role_controls.len() < self.config.max_roles {
                    ep.role_controls.insert(role_name, role_controls);
                }
            }
        }

        self.role_controls
            .insert((endpoint_id, role_name), Err(msg));
    }

    fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
@@ -307,58 +198,19 @@ impl ProjectInfoCacheImpl {
        }
    }

    pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) {
        let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else {
            return;
        };
        let Some(role_name) = RoleNameInt::get(role_name) else {
            return;
        };

        let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) else {
            return;
        };

        let entry = endpoint_info.role_controls.entry(role_name);
        let hash_map::Entry::Occupied(role_controls) = entry else {
            return;
        };

        if role_controls.get().is_expired() {
            role_controls.remove();
        }
    pub fn maybe_invalidate_role_secret(&self, _endpoint_id: &EndpointId, _role_name: &RoleName) {
        // TODO: Expire the value early if the key is idle.
        // Currently not an issue as we would just use the TTL to decide, which is what already happens.
    }

    pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
        let mut interval =
            tokio::time::interval(self.config.gc_interval / (self.cache.shards().len()) as u32);
        let mut interval = tokio::time::interval(self.config.gc_interval);
        loop {
            interval.tick().await;
            if self.cache.len() < self.config.size {
                // If there are not too many entries, wait until the next gc cycle.
                continue;
            }
            self.gc();
            self.ep_controls.run_pending_tasks();
            self.role_controls.run_pending_tasks();
        }
    }

    fn gc(&self) {
        let shard = rand::rng().random_range(0..self.project2ep.shards().len());
        debug!(shard, "project_info_cache: performing epoch reclamation");

        // acquire a random shard lock
        let mut removed = 0;
        let shard = self.project2ep.shards()[shard].write();
        for (_, endpoints) in shard.iter() {
            for endpoint in endpoints {
                self.cache.remove(endpoint);
                removed += 1;
            }
        }
        // We can drop this shard only after making sure that all endpoints are removed.
        drop(shard);
        info!("project_info_cache: removed {removed} endpoints");
    }
}

#[cfg(test)]
@@ -368,12 +220,12 @@ mod tests {
    use crate::control_plane::{AccessBlockerFlags, AuthSecret};
    use crate::scram::ServerSecret;
    use std::sync::Arc;
    use std::time::Duration;

    #[tokio::test]
    async fn test_project_info_cache_settings() {
        tokio::time::pause();
        let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
            size: 2,
        let cache = ProjectInfoCache::new(ProjectInfoCacheOptions {
            size: 1,
            max_roles: 2,
            ttl: Duration::from_secs(1),
            gc_interval: Duration::from_secs(600),
@@ -423,22 +275,17 @@ mod tests {
            },
        );

        let (cached, ttl) = cache
            .get_role_secret_with_ttl(&endpoint_id, &user1)
            .unwrap();
        let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
        assert_eq!(cached.unwrap().secret, secret1);
        assert_eq!(ttl, cache.config.ttl);

        let (cached, ttl) = cache
            .get_role_secret_with_ttl(&endpoint_id, &user2)
            .unwrap();
        let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
        assert_eq!(cached.unwrap().secret, secret2);
        assert_eq!(ttl, cache.config.ttl);

        // Shouldn't add more than 2 roles.
        let user3: RoleName = "user3".into();
        let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32])));

        cache.role_controls.run_pending_tasks();
        cache.insert_endpoint_access(
            account_id,
            project_id,
@@ -455,31 +302,18 @@ mod tests {
            },
        );

        assert!(
            cache
                .get_role_secret_with_ttl(&endpoint_id, &user3)
                .is_none()
        );
        cache.role_controls.run_pending_tasks();
        assert_eq!(cache.role_controls.entry_count(), 2);

        let cached = cache
            .get_endpoint_access_with_ttl(&endpoint_id)
            .unwrap()
            .0
            .unwrap();
        assert_eq!(cached.allowed_ips, allowed_ips);
        tokio::time::sleep(Duration::from_secs(2)).await;

        tokio::time::advance(Duration::from_secs(2)).await;
        let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user1);
        assert!(cached.is_none());
        let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user2);
        assert!(cached.is_none());
        let cached = cache.get_endpoint_access_with_ttl(&endpoint_id);
        assert!(cached.is_none());
        cache.role_controls.run_pending_tasks();
        assert_eq!(cache.role_controls.entry_count(), 0);
    }

    #[tokio::test]
    async fn test_caching_project_info_errors() {
        let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
        let cache = ProjectInfoCache::new(ProjectInfoCacheOptions {
            size: 10,
            max_roles: 10,
            ttl: Duration::from_secs(1),
@@ -519,34 +353,23 @@ mod tests {
            status: None,
        });

        let get_role_secret = |endpoint_id, role_name| {
            cache
                .get_role_secret_with_ttl(endpoint_id, role_name)
                .unwrap()
                .0
        };
        let get_endpoint_access =
            |endpoint_id| cache.get_endpoint_access_with_ttl(endpoint_id).unwrap().0;
        let get_role_secret =
            |endpoint_id, role_name| cache.get_role_secret(endpoint_id, role_name).unwrap();
        let get_endpoint_access = |endpoint_id| cache.get_endpoint_access(endpoint_id).unwrap();

        // stores role-specific errors only for get_role_secret
        cache.insert_endpoint_access_err(
            (&endpoint_id).into(),
            (&user1).into(),
            role_msg.clone(),
            None,
        );
        cache.insert_endpoint_access_err((&endpoint_id).into(), (&user1).into(), role_msg.clone());
        assert_eq!(
            get_role_secret(&endpoint_id, &user1).unwrap_err().error,
            role_msg.error
        );
        assert!(cache.get_endpoint_access_with_ttl(&endpoint_id).is_none());
        assert!(cache.get_endpoint_access(&endpoint_id).is_none());

        // stores non-role specific errors for both get_role_secret and get_endpoint_access
        cache.insert_endpoint_access_err(
            (&endpoint_id).into(),
            (&user1).into(),
            generic_msg.clone(),
            None,
        );
        assert_eq!(
            get_role_secret(&endpoint_id, &user1).unwrap_err().error,
@@ -558,11 +381,7 @@ mod tests {
        );

        // error isn't returned for other roles in the same endpoint
        assert!(
            cache
                .get_role_secret_with_ttl(&endpoint_id, &user2)
                .is_none()
        );
        assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());

        // success for a role does not overwrite errors for other roles
        cache.insert_endpoint_access(
@@ -590,7 +409,6 @@ mod tests {
            (&endpoint_id).into(),
            (&user2).into(),
            generic_msg.clone(),
            None,
        );
        assert!(get_role_secret(&endpoint_id, &user2).is_err());
        assert!(get_endpoint_access(&endpoint_id).is_ok());
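One design point worth calling out: moka can only invalidate entries by their exact key, which is why the `project2ep` and `account2ep` side maps survive the rewrite. The sketch below shows the invalidate-by-project pattern in isolation; it is a simplified illustration with string keys and a `Mutex<HashMap>` standing in for the proxy's interned IDs and concurrent `ClashMap`.

```rust
use std::collections::{HashMap, HashSet};
use std::sync::Mutex;
use std::time::Duration;

use moka::sync::Cache;

/// Simplified sketch: moka invalidates by exact key only, so a secondary
/// index maps each project to the endpoint keys cached under it.
struct EndpointCache {
    controls: Cache<String, String>,
    project2ep: Mutex<HashMap<String, HashSet<String>>>,
}

impl EndpointCache {
    fn new() -> Self {
        Self {
            controls: Cache::builder()
                .max_capacity(10_000)
                .time_to_live(Duration::from_secs(300))
                .build(),
            project2ep: Mutex::new(HashMap::new()),
        }
    }

    fn insert(&self, project_id: &str, endpoint_id: &str, controls: String) {
        // Record the endpoint under its project, then cache the controls.
        self.project2ep
            .lock()
            .unwrap()
            .entry(project_id.to_owned())
            .or_default()
            .insert(endpoint_id.to_owned());
        self.controls.insert(endpoint_id.to_owned(), controls);
    }

    fn invalidate_project(&self, project_id: &str) {
        // Snapshot the endpoint set, then drop each cache entry by key.
        let endpoints = self
            .project2ep
            .lock()
            .unwrap()
            .get(project_id)
            .cloned()
            .unwrap_or_default();
        for endpoint_id in endpoints {
            self.controls.invalidate(&endpoint_id);
        }
    }
}

fn main() {
    let cache = EndpointCache::new();
    cache.insert("proj-1", "ep-1", "allowed_ips=*".to_owned());
    cache.invalidate_project("proj-1");
    assert!(cache.controls.get("ep-1").is_none());
}
```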
proxy/src/cache/timed_lru.rs (7 changed lines, vendored)
@@ -246,17 +246,14 @@ impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {

impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
    /// Retrieve a cached entry in convenient wrapper, alongside timing information.
    pub(crate) fn get_with_created_at<Q>(
        &self,
        key: &Q,
    ) -> Option<Cached<&Self, (<Self as Cache>::Value, Instant)>>
    pub(crate) fn get<Q>(&self, key: &Q) -> Option<Cached<&Self, <Self as Cache>::Value>>
    where
        K: Borrow<Q> + Clone,
        Q: Hash + Eq + ?Sized,
    {
        self.get_raw(key, |key, entry| Cached {
            token: Some((self, key.clone())),
            value: (entry.value.clone(), entry.created_at),
            value: entry.value.clone(),
        })
    }
}
@@ -159,11 +159,11 @@ impl FromStr for CacheOptions {
#[derive(Debug)]
pub struct ProjectInfoCacheOptions {
    /// Max number of entries.
    pub size: usize,
    pub size: u64,
    /// Entry's time-to-live.
    pub ttl: Duration,
    /// Max number of roles per endpoint.
    pub max_roles: usize,
    pub max_roles: u64,
    /// Gc interval.
    pub gc_interval: Duration,
}
@@ -3,7 +3,6 @@
use std::net::IpAddr;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;

use ::http::HeaderName;
use ::http::header::AUTHORIZATION;
@@ -17,6 +16,7 @@ use tracing::{Instrument, debug, info, info_span, warn};
use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::cache::common::DEFAULT_ERROR_TTL;
use crate::context::RequestContext;
use crate::control_plane::caches::ApiCaches;
use crate::control_plane::errors::{
@@ -118,7 +118,6 @@ impl NeonControlPlaneClient {
                cache_key.into(),
                role.into(),
                msg.clone(),
                retry_info.map(|r| Duration::from_millis(r.retry_delay_ms)),
            );

            Err(err)
@@ -347,18 +346,11 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
    ) -> Result<RoleAccessControl, GetAuthInfoError> {
        let key = endpoint.normalize();

        if let Some((role_control, ttl)) = self
            .caches
            .project_info
            .get_role_secret_with_ttl(&key, role)
        {
        if let Some(role_control) = self.caches.project_info.get_role_secret(&key, role) {
            return match role_control {
                Err(mut msg) => {
                Err(msg) => {
                    info!(key = &*key, "found cached get_role_access_control error");

                    // if retry_delay_ms is set change it to the remaining TTL
                    replace_retry_delay_ms(&mut msg, |_| ttl.as_millis() as u64);

                    Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg)))
                }
                Ok(role_control) => {
@@ -383,17 +375,14 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
    ) -> Result<EndpointAccessControl, GetAuthInfoError> {
        let key = endpoint.normalize();

        if let Some((control, ttl)) = self.caches.project_info.get_endpoint_access_with_ttl(&key) {
        if let Some(control) = self.caches.project_info.get_endpoint_access(&key) {
            return match control {
                Err(mut msg) => {
                Err(msg) => {
                    info!(
                        key = &*key,
                        "found cached get_endpoint_access_control error"
                    );

                    // if retry_delay_ms is set change it to the remaining TTL
                    replace_retry_delay_ms(&mut msg, |_| ttl.as_millis() as u64);

                    Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg)))
                }
                Ok(control) => {
@@ -426,17 +415,12 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {

        macro_rules! check_cache {
            () => {
                if let Some(cached) = self.caches.node_info.get_with_created_at(&key) {
                    let (cached, (info, created_at)) = cached.take_value();
                if let Some(cached) = self.caches.node_info.get(&key) {
                    let (cached, info) = cached.take_value();
                    return match info {
                        Err(mut msg) => {
                        Err(msg) => {
                            info!(key = &*key, "found cached wake_compute error");

                            // if retry_delay_ms is set, reduce it by the amount of time it spent in cache
                            replace_retry_delay_ms(&mut msg, |delay| {
                                delay.saturating_sub(created_at.elapsed().as_millis() as u64)
                            });

                            Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
                                msg,
                            )))
@@ -503,9 +487,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
            "created a cache entry for the wake compute error"
        );

        let ttl = retry_info.map_or(Duration::from_secs(30), |r| {
            Duration::from_millis(r.retry_delay_ms)
        });
        let ttl = retry_info.map_or(DEFAULT_ERROR_TTL, |r| r.retry_at - Instant::now());

        self.caches.node_info.insert_ttl(key, Err(msg.clone()), ttl);

@@ -517,14 +499,6 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
    }
}

fn replace_retry_delay_ms(msg: &mut ControlPlaneErrorMessage, f: impl FnOnce(u64) -> u64) {
    if let Some(status) = &mut msg.status
        && let Some(retry_info) = &mut status.details.retry_info
    {
        retry_info.retry_delay_ms = f(retry_info.retry_delay_ms);
    }
}

/// Parse http response body, taking status code into account.
fn parse_body<T: for<'a> serde::Deserialize<'a>>(
    status: StatusCode,
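Since `retry_at` is now an absolute deadline, a cached error no longer needs its `retry_delay_ms` rewritten on every cache hit (hence the deleted `replace_retry_delay_ms` helper above); any consumer that still wants a relative delay can derive it at the point of use. A hypothetical helper illustrating that, not part of this diff:

```rust
use std::time::Instant;

/// Remaining retry delay, in milliseconds, at the moment a cached error is served.
fn remaining_retry_ms(retry_at: Instant) -> u64 {
    retry_at
        .saturating_duration_since(Instant::now())
        .as_millis() as u64
}
```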
@@ -13,7 +13,7 @@ use tracing::{debug, info};
use super::{EndpointAccessControl, RoleAccessControl};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
use crate::cache::project_info::ProjectInfoCacheImpl;
use crate::cache::project_info::ProjectInfoCache;
use crate::config::{CacheOptions, ProjectInfoCacheOptions};
use crate::context::RequestContext;
use crate::control_plane::{CachedNodeInfo, ControlPlaneApi, NodeInfoCache, errors};
@@ -119,7 +119,7 @@ pub struct ApiCaches {
    /// Cache for the `wake_compute` API method.
    pub(crate) node_info: NodeInfoCache,
    /// Cache which stores project_id -> endpoint_ids mapping.
    pub project_info: Arc<ProjectInfoCacheImpl>,
    pub project_info: Arc<ProjectInfoCache>,
}

impl ApiCaches {
@@ -134,7 +134,7 @@ impl ApiCaches {
                wake_compute_cache_config.ttl,
                true,
            ),
            project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
            project_info: Arc::new(ProjectInfoCache::new(project_info_cache_config)),
        }
    }
}
@@ -1,8 +1,10 @@
use std::fmt::{self, Display};
use std::time::Duration;

use measured::FixedCardinalityLabel;
use serde::{Deserialize, Serialize};
use smol_str::SmolStr;
use tokio::time::Instant;

use crate::auth::IpPattern;
use crate::intern::{AccountIdInt, BranchIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
@@ -231,7 +233,13 @@ impl Reason {
#[derive(Copy, Clone, Debug, Deserialize)]
#[allow(dead_code)]
pub(crate) struct RetryInfo {
    pub(crate) retry_delay_ms: u64,
    #[serde(rename = "retry_delay_ms", deserialize_with = "milliseconds_from_now")]
    pub(crate) retry_at: Instant,
}

fn milliseconds_from_now<'de, D: serde::Deserializer<'de>>(d: D) -> Result<Instant, D::Error> {
    let millis = u64::deserialize(d)?;
    Ok(Instant::now() + Duration::from_millis(millis))
}

#[derive(Debug, Deserialize, Clone)]
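The new deserializer turns the wire format's relative `retry_delay_ms` into an absolute deadline at parse time. Below is a minimal sketch of the same idea in isolation; it uses `std::time::Instant` instead of tokio's so it runs without a runtime, and it assumes `serde` (with `derive`) plus `serde_json` are available.

```rust
use std::time::{Duration, Instant};

use serde::{Deserialize, Deserializer};

// Convert a relative delay in milliseconds into an absolute deadline.
fn milliseconds_from_now<'de, D: Deserializer<'de>>(d: D) -> Result<Instant, D::Error> {
    let millis = u64::deserialize(d)?;
    Ok(Instant::now() + Duration::from_millis(millis))
}

#[derive(Deserialize)]
struct RetryInfo {
    #[serde(rename = "retry_delay_ms", deserialize_with = "milliseconds_from_now")]
    retry_at: Instant,
}

fn main() {
    let info: RetryInfo = serde_json::from_str(r#"{"retry_delay_ms": 1500}"#).unwrap();
    // The deadline sits roughly 1.5s in the future relative to deserialization time.
    assert!(info.retry_at > Instant::now());
    assert!(info.retry_at <= Instant::now() + Duration::from_millis(1500));
}
```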
@@ -15,6 +15,7 @@ use rstest::rstest;
use rustls::crypto::ring;
use rustls::pki_types;
use tokio::io::{AsyncRead, AsyncWrite, DuplexStream};
use tokio::time::Instant;
use tracing_test::traced_test;

use super::retry::CouldRetry;
@@ -501,7 +502,7 @@ impl TestControlPlaneClient for TestConnectMechanism {
            details: Details {
                error_info: None,
                retry_info: Some(control_plane::messages::RetryInfo {
                    retry_delay_ms: 1,
                    retry_at: Instant::now() + Duration::from_millis(1),
                }),
                user_facing_message: None,
            },
@@ -131,11 +131,11 @@ where
    Ok(())
}

struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
struct MessageHandler<C: Send + Sync + 'static> {
    cache: Arc<C>,
}

impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
impl<C: Send + Sync + 'static> Clone for MessageHandler<C> {
    fn clone(&self) -> Self {
        Self {
            cache: self.cache.clone(),
@@ -143,8 +143,8 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
    }
}

impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
    pub(crate) fn new(cache: Arc<C>) -> Self {
impl MessageHandler<ProjectInfoCache> {
    pub(crate) fn new(cache: Arc<ProjectInfoCache>) -> Self {
        Self { cache }
    }

@@ -224,7 +224,7 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
    }
}

fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
fn invalidate_cache(cache: Arc<ProjectInfoCache>, msg: Notification) {
    match msg {
        Notification::EndpointSettingsUpdate(ids) => ids
            .iter()
@@ -247,8 +247,8 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
    }
}

async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
    handler: MessageHandler<C>,
async fn handle_messages(
    handler: MessageHandler<ProjectInfoCache>,
    redis: ConnectionWithCredentialsProvider,
    cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
@@ -284,13 +284,10 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(

/// Handle console's invalidation messages.
#[tracing::instrument(name = "redis_notifications", skip_all)]
pub async fn task_main<C>(
pub async fn task_main(
    redis: ConnectionWithCredentialsProvider,
    cache: Arc<C>,
) -> anyhow::Result<Infallible>
where
    C: ProjectInfoCache + Send + Sync + 'static,
{
    cache: Arc<ProjectInfoCache>,
) -> anyhow::Result<Infallible> {
    let handler = MessageHandler::new(cache);
    // 6h - 1m.
    // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
@@ -149,8 +149,8 @@ impl DbSchemaCache {
        ctx: &RequestContext,
        config: &'static ProxyConfig,
    ) -> Result<Arc<(ApiConfig, DbSchemaOwned)>, RestError> {
        match self.get_with_created_at(endpoint_id) {
            Some(Cached { value: (v, _), .. }) => Ok(v),
        match self.get(endpoint_id) {
            Some(Cached { value: v, .. }) => Ok(v),
            None => {
                info!("db_schema cache miss for endpoint: {:?}", endpoint_id);
                let remote_value = self