Compare commits

...

4 Commits

Author          SHA1        Message                                                    Date
Conrad Ludgate  b9a4326fbd  fmt                                                        2024-05-07 10:44:33 +01:00
Conrad Ludgate  85033e05c9  hakari                                                     2024-05-07 08:35:18 +01:00
Conrad Ludgate  ca578449e4  simplify Cache invalidate trait, reduce EndpointCacheKey  2024-05-07 08:34:21 +01:00
Conrad Ludgate  ef3a9dfafa  proxy: moka cache                                          2024-05-07 07:59:23 +01:00
17 changed files with 235 additions and 351 deletions

Cargo.lock (generated)
View File

@@ -213,9 +213,9 @@ dependencies = [
 [[package]]
 name = "async-lock"
-version = "3.2.0"
+version = "3.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7125e42787d53db9dd54261812ef17e937c95a51e4d291373b670342fa44310c"
+checksum = "d034b430882f8381900d3fe6f0aaa3ad94f2cb4ac519b429692a1bc2dda4ae7b"
 dependencies = [
  "event-listener 4.0.0",
  "event-listener-strategy",
@@ -1239,9 +1239,9 @@ dependencies = [
 [[package]]
 name = "concurrent-queue"
-version = "2.3.0"
+version = "2.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f057a694a54f12365049b0958a1685bb52d567f5593b355fbf685838e873d400"
+checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973"
 dependencies = [
  "crossbeam-utils",
 ]
@@ -1875,6 +1875,17 @@ dependencies = [
  "pin-project-lite",
 ]
+[[package]]
+name = "event-listener"
+version = "5.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6d9944b8ca13534cdfb2800775f8dd4902ff3fc75a50101466decadfdf322a24"
+dependencies = [
+ "concurrent-queue",
+ "parking",
+ "pin-project-lite",
+]
 [[package]]
 name = "event-listener-strategy"
 version = "0.4.0"
@@ -3121,6 +3132,30 @@ dependencies = [
  "windows-sys 0.48.0",
 ]
+[[package]]
+name = "moka"
+version = "0.12.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9e0d88686dc561d743b40de8269b26eaf0dc58781bde087b0984646602021d08"
+dependencies = [
+ "async-lock",
+ "async-trait",
+ "crossbeam-channel",
+ "crossbeam-epoch",
+ "crossbeam-utils",
+ "event-listener 5.3.0",
+ "futures-util",
+ "once_cell",
+ "parking_lot 0.12.1",
+ "quanta",
+ "rustc_version",
+ "smallvec",
+ "tagptr",
+ "thiserror",
+ "triomphe",
+ "uuid",
+]
 [[package]]
 name = "multimap"
 version = "0.8.3"
@@ -4377,6 +4412,7 @@ dependencies = [
  "md5",
  "measured",
  "metrics",
+ "moka",
  "native-tls",
  "once_cell",
  "opentelemetry",
@@ -4438,6 +4474,21 @@ dependencies = [
  "x509-parser",
 ]
+[[package]]
+name = "quanta"
+version = "0.12.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5"
+dependencies = [
+ "crossbeam-utils",
+ "libc",
+ "once_cell",
+ "raw-cpuid",
+ "wasi 0.11.0+wasi-snapshot-preview1",
+ "web-sys",
+ "winapi",
+]
 [[package]]
 name = "quick-xml"
 version = "0.31.0"
@@ -4549,6 +4600,15 @@ dependencies = [
  "rand_core 0.5.1",
 ]
+[[package]]
+name = "raw-cpuid"
+version = "11.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e29830cbb1290e404f24c73af91c5d8d631ce7e128691e9477556b540cd01ecd"
+dependencies = [
+ "bitflags 2.4.1",
+]
 [[package]]
 name = "rayon"
 version = "1.7.0"
@@ -5990,6 +6050,12 @@ dependencies = [
  "winapi",
 ]
+[[package]]
+name = "tagptr"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417"
 [[package]]
 name = "tar"
 version = "0.4.40"
@@ -6641,6 +6707,12 @@ dependencies = [
  "workspace_hack",
 ]
+[[package]]
+name = "triomphe"
+version = "0.1.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "859eb650cfee7434994602c3a68b25d77ad9e68c8a6cd491616ef86661382eb3"
 [[package]]
 name = "try-lock"
 version = "0.2.4"
@@ -7470,6 +7542,7 @@ dependencies = [
  "chrono",
  "clap",
  "clap_builder",
+ "crossbeam-epoch",
  "crossbeam-utils",
  "either",
  "fail",

View File

@@ -9,6 +9,8 @@ default = []
 testing = []
 [dependencies]
+workspace_hack.workspace = true
 anyhow.workspace = true
 async-compression.workspace = true
 async-trait.workspace = true
@@ -46,6 +48,7 @@ lasso = { workspace = true, features = ["multi-threaded"] }
 md5.workspace = true
 measured = { workspace = true, features = ["lasso"] }
 metrics.workspace = true
+moka = { version = "0.12.7", features = ["future"] }
 once_cell.workspace = true
 opentelemetry.workspace = true
 parking_lot.workspace = true
@@ -100,8 +103,6 @@ postgres-native-tls.workspace = true
 postgres-protocol.workspace = true
 redis.workspace = true
-workspace_hack.workspace = true
 [dev-dependencies]
 camino-tempfile.workspace = true
 fallible-iterator.workspace = true
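
For reference, the "future" feature enables moka's async cache. A minimal sketch of the API surface the proxy starts relying on here (builder, async get/insert/invalidate); the key/value types, capacity, TTL and the tokio runtime are illustrative, not taken from this diff:

use std::time::Duration;

use moka::future::Cache;

#[tokio::main]
async fn main() {
    // Mirrors the knobs used for the proxy's node_info cache further down:
    // a bounded capacity plus an idle-based expiry.
    let cache: Cache<String, String> = Cache::builder()
        .name("example")
        .max_capacity(10_000)
        .time_to_idle(Duration::from_secs(300))
        .build();

    cache.insert("key".to_string(), "value".to_string()).await;
    assert_eq!(cache.get("key").await.as_deref(), Some("value"));

    // Removal is async too, which is what motivates making the proxy's
    // `Cache` trait expose `async fn invalidate` later in this diff.
    cache.invalidate("key").await;
    assert!(cache.get("key").await.is_none());
}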

View File

@@ -69,8 +69,10 @@ pub enum BackendType<'a, T, D> {
     Link(MaybeOwned<'a, url::ApiUrl>, D),
 }
+#[cfg(test)]
+#[async_trait::async_trait]
 pub trait TestBackend: Send + Sync + 'static {
-    fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
+    async fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
     fn get_allowed_ips_and_secret(
         &self,
     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
@@ -153,7 +155,13 @@ pub struct ComputeUserInfo {
 impl ComputeUserInfo {
     pub fn endpoint_cache_key(&self) -> EndpointCacheKey {
-        self.options.get_cache_key(&self.endpoint)
+        let id = EndpointIdInt::from(&self.endpoint);
+        let key = EndpointCacheKey::from(id);
+        if self.options.is_empty() {
+            key
+        } else {
+            key.with_options(self.options.to_string())
+        }
     }
 }
@@ -343,7 +351,7 @@ async fn auth_quirks(
         Err(e) => {
             if e.is_auth_failed() {
                 // The password could have been changed, so we invalidate the cache.
-                cached_entry.invalidate();
+                cached_entry.invalidate().await;
             }
             Err(e)
         }

View File

@@ -292,7 +292,7 @@ mod tests {
            ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.user, "john_doe");
        assert_eq!(user_info.endpoint_id.as_deref(), Some("foo"));
-       assert_eq!(user_info.options.get_cache_key("foo"), "foo");
+       assert_eq!(user_info.options.to_string(), "");
        Ok(())
    }
@@ -451,8 +451,8 @@ mod tests {
            ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
        assert_eq!(user_info.endpoint_id.as_deref(), Some("project"));
        assert_eq!(
-           user_info.options.get_cache_key("project"),
-           "project endpoint_type:read_write lsn:0/2"
+           user_info.options.to_string(),
+           "endpoint_type:read_write lsn:0/2"
        );
        Ok(())

View File

@@ -1,7 +1,5 @@
 pub mod common;
 pub mod endpoints;
 pub mod project_info;
-mod timed_lru;
 pub use common::{Cache, Cached};
-pub use timed_lru::TimedLru;

View File

@@ -3,35 +3,28 @@ use std::ops::{Deref, DerefMut};
 /// A generic trait which exposes types of cache's key and value,
 /// as well as the notion of cache entry invalidation.
 /// This is useful for [`Cached`].
+#[allow(async_fn_in_trait)]
 pub trait Cache {
-    /// Entry's key.
-    type Key;
-    /// Entry's value.
-    type Value;
     /// Used for entry invalidation.
-    type LookupInfo<Key>;
+    type LookupInfo;
     /// Invalidate an entry using a lookup info.
     /// We don't have an empty default impl because it's error-prone.
-    fn invalidate(&self, _: &Self::LookupInfo<Self::Key>);
+    async fn invalidate(&self, _: &Self::LookupInfo);
 }
 impl<C: Cache> Cache for &C {
-    type Key = C::Key;
-    type Value = C::Value;
-    type LookupInfo<Key> = C::LookupInfo<Key>;
-    fn invalidate(&self, info: &Self::LookupInfo<Self::Key>) {
-        C::invalidate(self, info)
+    type LookupInfo = C::LookupInfo;
+    async fn invalidate(&self, info: &Self::LookupInfo) {
+        C::invalidate(self, info).await
     }
 }
 /// Wrapper for convenient entry invalidation.
-pub struct Cached<C: Cache, V = <C as Cache>::Value> {
+pub struct Cached<C: Cache, V> {
     /// Cache + lookup info.
-    pub token: Option<(C, C::LookupInfo<C::Key>)>,
+    pub token: Option<(C, C::LookupInfo)>,
     /// The value itself.
     pub value: V,
@@ -54,9 +47,9 @@ impl<C: Cache, V> Cached<C, V> {
     }
     /// Drop this entry from a cache if it's still there.
-    pub fn invalidate(self) -> V {
+    pub async fn invalidate(self) -> V {
         if let Some((cache, info)) = &self.token {
-            cache.invalidate(info);
+            cache.invalidate(info).await;
         }
         self.value
     }
@@ -80,3 +73,16 @@ impl<C: Cache, V> DerefMut for Cached<C, V> {
         &mut self.value
     }
 }
+impl<K, V, S> Cache for moka::future::Cache<K, V, S>
+where
+    K: std::hash::Hash + Eq + Send + Sync + 'static,
+    V: Clone + Send + Sync + 'static,
+    S: std::hash::BuildHasher + Clone + Send + Sync + 'static,
+{
+    type LookupInfo = K;
+    async fn invalidate(&self, key: &Self::LookupInfo) {
+        moka::future::Cache::invalidate(self, key).await
+    }
+}
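
A rough sketch of how the reworked wrapper is meant to be used together with the blanket moka impl above. The function name and the String/u32 key/value types are made up for illustration, and it assumes the code sits next to the `Cache`/`Cached` definitions shown in this file:

// With the blanket impl, the lookup info for a moka cache is just its key.
async fn invalidate_example(cache: &'static moka::future::Cache<String, u32>) {
    cache.insert("endpoint".to_string(), 42).await;

    // The token pairs a cache handle with the key needed to drop the entry later.
    let cached = Cached {
        token: Some((cache, "endpoint".to_string())),
        value: 42u32,
    };

    // Consumes the wrapper, removes the entry from the cache, and hands back the value.
    let value = cached.invalidate().await;
    assert_eq!(value, 42);
}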

View File

@@ -8,7 +8,6 @@ use std::{
 use async_trait::async_trait;
 use dashmap::DashMap;
 use rand::{thread_rng, Rng};
-use smol_str::SmolStr;
 use tokio::sync::Mutex;
 use tokio::time::Instant;
 use tracing::{debug, info};
@@ -346,13 +345,9 @@ enum LookupType {
 }
 impl Cache for ProjectInfoCacheImpl {
-    type Key = SmolStr;
-    // Value is not really used here, but we need to specify it.
-    type Value = SmolStr;
-    type LookupInfo<Key> = CachedLookupInfo;
-    fn invalidate(&self, key: &Self::LookupInfo<SmolStr>) {
+    type LookupInfo = CachedLookupInfo;
+    async fn invalidate(&self, key: &Self::LookupInfo) {
         match &key.lookup_type {
             LookupType::RoleSecret(role_name) => {
                 if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
@@ -489,7 +484,7 @@ mod tests {
        assert!(!cached.cached());
        assert_eq!(cached.value, secret1);
-       cached.invalidate(); // Shouldn't do anything.
+       cached.invalidate().await; // Shouldn't do anything.
        let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
        assert_eq!(cached.value, secret1);

View File

@@ -1,258 +0,0 @@
-use std::{
-    borrow::Borrow,
-    hash::Hash,
-    time::{Duration, Instant},
-};
-use tracing::debug;
-// This seems to make more sense than `lru` or `cached`:
-//
-// * `near/nearcore` ditched `cached` in favor of `lru`
-//   (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed).
-//
-// * `lru` methods use an obscure `KeyRef` type in their constraints (which is deliberately excluded from docs).
-//   This severely hinders its usage both in terms of creating wrappers and supported key types.
-//
-// On the other hand, `hashlink` has good download stats and appears to be maintained.
-use hashlink::{linked_hash_map::RawEntryMut, LruCache};
-use super::{common::Cached, *};
-/// An implementation of timed LRU cache with fixed capacity.
-/// Key properties:
-///
-/// * Whenever a new entry is inserted, the least recently accessed one is evicted.
-///   The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`).
-///
-/// * If `update_ttl_on_retrieval` is `true`, then when an entry is about to be retrieved we check its
-///   expiration timestamp. If the entry has expired, we remove it from the cache; otherwise we bump the
-///   expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong
-///   its existence.
-///
-/// * There's an API for immediate invalidation (removal) of a cache entry;
-///   it's useful in case we know for sure that the entry is no longer correct.
-///   See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information.
-///
-/// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
-///   or by a successful lookup (i.e. the entry hasn't expired yet).
-///   There is no background job to reap the expired records.
-///
-/// * It's possible for an entry that has not yet expired to be evicted
-///   before expired items. That's a bit wasteful, but probably fine in practice.
-pub struct TimedLru<K, V> {
-    /// Cache's name for tracing.
-    name: &'static str,
-    /// The underlying cache implementation.
-    cache: parking_lot::Mutex<LruCache<K, Entry<V>>>,
-    /// Default time-to-live of a single entry.
-    ttl: Duration,
-    update_ttl_on_retrieval: bool,
-}
-impl<K: Hash + Eq, V> Cache for TimedLru<K, V> {
-    type Key = K;
-    type Value = V;
-    type LookupInfo<Key> = LookupInfo<Key>;
-    fn invalidate(&self, info: &Self::LookupInfo<K>) {
-        self.invalidate_raw(info)
-    }
-}
-struct Entry<T> {
-    created_at: Instant,
-    expires_at: Instant,
-    value: T,
-}
-impl<K: Hash + Eq, V> TimedLru<K, V> {
-    /// Construct a new LRU cache with timed entries.
-    pub fn new(
-        name: &'static str,
-        capacity: usize,
-        ttl: Duration,
-        update_ttl_on_retrieval: bool,
-    ) -> Self {
-        Self {
-            name,
-            cache: LruCache::new(capacity).into(),
-            ttl,
-            update_ttl_on_retrieval,
-        }
-    }
-    /// Drop an entry from the cache if it's outdated.
-    #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
-    fn invalidate_raw(&self, info: &LookupInfo<K>) {
-        let now = Instant::now();
-        // Do costly things before taking the lock.
-        let mut cache = self.cache.lock();
-        let raw_entry = match cache.raw_entry_mut().from_key(&info.key) {
-            RawEntryMut::Vacant(_) => return,
-            RawEntryMut::Occupied(x) => x,
-        };
-        // Remove the entry if it was created prior to lookup timestamp.
-        let entry = raw_entry.get();
-        let (created_at, expires_at) = (entry.created_at, entry.expires_at);
-        let should_remove = created_at <= info.created_at || expires_at <= now;
-        if should_remove {
-            raw_entry.remove();
-        }
-        drop(cache); // drop lock before logging
-        debug!(
-            created_at = format_args!("{created_at:?}"),
-            expires_at = format_args!("{expires_at:?}"),
-            entry_removed = should_remove,
-            "processed a cache entry invalidation event"
-        );
-    }
-    /// Try retrieving an entry by its key, then execute `extract` if it exists.
-    #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
-    fn get_raw<Q, R>(&self, key: &Q, extract: impl FnOnce(&K, &Entry<V>) -> R) -> Option<R>
-    where
-        K: Borrow<Q>,
-        Q: Hash + Eq + ?Sized,
-    {
-        let now = Instant::now();
-        let deadline = now.checked_add(self.ttl).expect("time overflow");
-        // Do costly things before taking the lock.
-        let mut cache = self.cache.lock();
-        let mut raw_entry = match cache.raw_entry_mut().from_key(key) {
-            RawEntryMut::Vacant(_) => return None,
-            RawEntryMut::Occupied(x) => x,
-        };
-        // Immediately drop the entry if it has expired.
-        let entry = raw_entry.get();
-        if entry.expires_at <= now {
-            raw_entry.remove();
-            return None;
-        }
-        let value = extract(raw_entry.key(), entry);
-        let (created_at, expires_at) = (entry.created_at, entry.expires_at);
-        // Update the deadline and the entry's position in the LRU list.
-        if self.update_ttl_on_retrieval {
-            raw_entry.get_mut().expires_at = deadline;
-        }
-        raw_entry.to_back();
-        drop(cache); // drop lock before logging
-        debug!(
-            created_at = format_args!("{created_at:?}"),
-            old_expires_at = format_args!("{expires_at:?}"),
-            new_expires_at = format_args!("{deadline:?}"),
-            "accessed a cache entry"
-        );
-        Some(value)
-    }
-    /// Insert an entry into the cache. If an entry with the same key already
-    /// existed, return the previous value and its creation timestamp.
-    #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
-    fn insert_raw(&self, key: K, value: V) -> (Instant, Option<V>) {
-        let created_at = Instant::now();
-        let expires_at = created_at.checked_add(self.ttl).expect("time overflow");
-        let entry = Entry {
-            created_at,
-            expires_at,
-            value,
-        };
-        // Do costly things before taking the lock.
-        let old = self
-            .cache
-            .lock()
-            .insert(key, entry)
-            .map(|entry| entry.value);
-        debug!(
-            created_at = format_args!("{created_at:?}"),
-            expires_at = format_args!("{expires_at:?}"),
-            replaced = old.is_some(),
-            "created a cache entry"
-        );
-        (created_at, old)
-    }
-}
-impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
-    pub fn insert(&self, key: K, value: V) -> (Option<V>, Cached<&Self>) {
-        let (created_at, old) = self.insert_raw(key.clone(), value.clone());
-        let cached = Cached {
-            token: Some((self, LookupInfo { created_at, key })),
-            value,
-        };
-        (old, cached)
-    }
-}
-impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
-    /// Retrieve a cached entry in a convenient wrapper.
-    pub fn get<Q>(&self, key: &Q) -> Option<timed_lru::Cached<&Self>>
-    where
-        K: Borrow<Q> + Clone,
-        Q: Hash + Eq + ?Sized,
-    {
-        self.get_raw(key, |key, entry| {
-            let info = LookupInfo {
-                created_at: entry.created_at,
-                key: key.clone(),
-            };
-            Cached {
-                token: Some((self, info)),
-                value: entry.value.clone(),
-            }
-        })
-    }
-    /// Retrieve a cached entry in a convenient wrapper, ignoring its TTL.
-    pub fn get_ignoring_ttl<Q>(&self, key: &Q) -> Option<timed_lru::Cached<&Self>>
-    where
-        K: Borrow<Q>,
-        Q: Hash + Eq + ?Sized,
-    {
-        let mut cache = self.cache.lock();
-        cache
-            .get(key)
-            .map(|entry| Cached::new_uncached(entry.value.clone()))
-    }
-    /// Remove an entry from the cache.
-    pub fn remove<Q>(&self, key: &Q) -> Option<V>
-    where
-        K: Borrow<Q> + Clone,
-        Q: Hash + Eq + ?Sized,
-    {
-        let mut cache = self.cache.lock();
-        cache.remove(key).map(|entry| entry.value)
-    }
-}
-/// Lookup information for key invalidation.
-pub struct LookupInfo<K> {
-    /// Time of creation of a cache [`Entry`].
-    /// We use this during invalidation lookups to prevent eviction of a newer
-    /// entry sharing the same key (it might've been inserted by a different
-    /// task after we got the entry we're trying to invalidate now).
-    created_at: Instant,
-    /// Search by this key.
-    key: K,
-}
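
The removed TimedLru maps onto moka roughly as follows (a sketch of the correspondence, not code from this diff; the helper name and String key/value types are made up): `capacity` becomes `max_capacity`, and `update_ttl_on_retrieval: true` — bumping an entry's deadline on every successful lookup — corresponds to an idle-based expiry, while `false` would correspond to a fixed time-to-live:

use std::time::Duration;

// Hypothetical helper; the real construction happens in ApiCaches::new below.
fn moka_equivalent(
    capacity: u64,
    ttl: Duration,
    update_ttl_on_retrieval: bool,
) -> moka::future::Cache<String, String> {
    let builder = moka::future::Cache::builder().max_capacity(capacity);
    let builder = if update_ttl_on_retrieval {
        // The expiry clock restarts on each read, like TimedLru's deadline bump.
        builder.time_to_idle(ttl)
    } else {
        // Fixed lifetime counted from insertion, like TimedLru with the flag off.
        builder.time_to_live(ttl)
    };
    builder.build()
}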

View File

@@ -411,7 +411,7 @@ pub fn remote_storage_from_toml(s: &str) -> anyhow::Result<OptRemoteStorageConfi
 #[derive(Debug)]
 pub struct CacheOptions {
     /// Max number of entries.
-    pub size: usize,
+    pub size: u64,
     /// Entry's time-to-live.
     pub ttl: Duration,
 }

View File

@@ -8,7 +8,7 @@ use crate::{
         backend::{ComputeCredentialKeys, ComputeUserInfo},
         IpPattern,
     },
-    cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
+    cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached},
     compute,
     config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions},
     context::RequestMonitoring,
@@ -326,8 +326,8 @@ impl NodeInfo {
     }
 }
-pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>;
-pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
+pub type NodeInfoCache = moka::future::Cache<EndpointCacheKey, NodeInfo>;
+pub type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
 pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
 pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
@@ -412,7 +412,7 @@ impl Api for ConsoleBackend {
            #[cfg(any(test, feature = "testing"))]
            Postgres(api) => api.wake_compute(ctx, user_info).await,
            #[cfg(test)]
-           Test(api) => api.wake_compute(),
+           Test(api) => api.wake_compute().await,
        }
    }
 }
@@ -434,12 +434,11 @@ impl ApiCaches {
        endpoint_cache_config: EndpointCacheConfig,
    ) -> Self {
        Self {
-           node_info: NodeInfoCache::new(
-               "node_info_cache",
-               wake_compute_cache_config.size,
-               wake_compute_cache_config.ttl,
-               true,
-           ),
+           node_info: moka::future::Cache::builder()
+               .max_capacity(wake_compute_cache_config.size)
+               .time_to_idle(wake_compute_cache_config.ttl)
+               .name("node_info_cache")
+               .build(),
            project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
            endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)),
        }

View File

@@ -275,10 +275,13 @@ impl super::Api for Api {
        // for some time (highly depends on the console's scale-to-zero policy);
        // The connection info remains the same during that period of time,
        // which means that we might cache it to reduce the load and latency.
-       if let Some(cached) = self.caches.node_info.get(&key) {
-           info!(key = &*key, "found cached compute node info");
+       if let Some(cached) = self.caches.node_info.get(&key).await {
+           info!(key = %key, "found cached compute node info");
            ctx.set_project(cached.aux.clone());
-           return Ok(cached);
+           return Ok(CachedNodeInfo {
+               token: Some((&self.caches.node_info, key)),
+               value: cached,
+           });
        }
        // check rate limit
@@ -294,10 +297,13 @@
        // after getting back a permit - it's possible the cache was filled
        // double check
        if permit.should_check_cache() {
-           if let Some(cached) = self.caches.node_info.get(&key) {
-               info!(key = &*key, "found cached compute node info");
+           if let Some(cached) = self.caches.node_info.get(&key).await {
+               info!(key = %key, "found cached compute node info");
                ctx.set_project(cached.aux.clone());
-               return Ok(cached);
+               return Ok(CachedNodeInfo {
+                   token: Some((&self.caches.node_info, key)),
+                   value: cached,
+               });
            }
        }
@@ -308,12 +314,18 @@
        // store the cached node as 'warm'
        node.aux.cold_start_info = ColdStartInfo::WarmCached;
-       let (_, mut cached) = self.caches.node_info.insert(key.clone(), node);
-       cached.aux.cold_start_info = cold_start_info;
-       info!(key = &*key, "created a cache entry for compute node info");
-       Ok(cached)
+       self.caches
+           .node_info
+           .insert(key.clone(), node.clone())
+           .await;
+       node.aux.cold_start_info = cold_start_info;
+       info!(key = %key, "created a cache entry for compute node info");
+       Ok(CachedNodeInfo {
+           token: Some((&self.caches.node_info, key)),
+           value: node,
+       })
    }
 }

View File

@@ -154,9 +154,6 @@ smol_str_wrapper!(BranchId);
 // 90% of project strings are 23 characters or less.
 smol_str_wrapper!(ProjectId);
-// will usually equal endpoint ID
-smol_str_wrapper!(EndpointCacheKey);
 smol_str_wrapper!(DbName);
 // postgres hostname, will likely be a port:ip addr
@@ -180,3 +177,35 @@ impl EndpointId {
         ProjectId(self.0.clone())
     }
 }
+#[derive(Hash, PartialEq, Eq, Debug, Clone)]
+pub struct EndpointCacheKey {
+    endpoint: intern::EndpointIdInt,
+    options: Option<String>,
+}
+impl std::fmt::Display for EndpointCacheKey {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(self.endpoint.as_str())?;
+        if let Some(options) = &self.options {
+            f.write_str(" ")?;
+            f.write_str(options)?;
+        }
+        Ok(())
+    }
+}
+impl From<intern::EndpointIdInt> for EndpointCacheKey {
+    fn from(value: intern::EndpointIdInt) -> Self {
+        Self {
+            endpoint: value,
+            options: None,
+        }
+    }
+}
+impl EndpointCacheKey {
+    pub fn with_options(mut self, options: String) -> Self {
+        self.options = Some(options);
+        self
+    }
+}
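
To illustrate how the new key composes: the endpoint name below is made up, and the assertion assumes the interned EndpointIdInt resolves back to the original string via `as_str`, which is what the `Display` impl above relies on. A hypothetical test sitting next to these definitions:

#[test]
fn endpoint_cache_key_display() {
    let ep: EndpointId = "ep-snowy-dust-123456".into();
    let ep = intern::EndpointIdInt::from(ep);

    let key = EndpointCacheKey::from(ep)
        .with_options("endpoint_type:read_write lsn:0/2".to_string());

    // Endpoint and options are joined with a single space, matching the
    // format the old `NeonOptions::get_cache_key` produced.
    assert_eq!(
        key.to_string(),
        "ep-snowy-dust-123456 endpoint_type:read_write lsn:0/2"
    );
}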

View File

@@ -20,7 +20,6 @@ use crate::{
     protocol2::read_proxy_protocol,
     proxy::handshake::{handshake, HandshakeData},
     stream::{PqStream, Stream},
-    EndpointCacheKey,
 };
 use futures::TryFutureExt;
 use itertools::Itertools;
@@ -391,13 +390,8 @@ impl NeonOptions {
         Self(options)
     }
-    pub fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey {
-        // prefix + format!(" {k}:{v}")
-        // kinda jank because SmolStr is immutable
-        std::iter::once(prefix)
-            .chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
-            .collect::<SmolStr>()
-            .into()
+    pub fn is_empty(&self) -> bool {
+        self.0.is_empty()
     }
     /// <https://swagger.io/docs/specification/serialization/> DeepObject format
@@ -418,3 +412,20 @@ pub fn neon_option(bytes: &str) -> Option<(&str, &str)> {
     let (_, [k, v]) = cap.extract();
     Some((k, v))
 }
+impl std::fmt::Display for NeonOptions {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let mut space = false;
+        for (k, v) in &self.0 {
+            if space {
+                f.write_str(" ")?;
+            } else {
+                space = true;
+            }
+            f.write_str(k)?;
+            f.write_str(":")?;
+            f.write_str(v)?;
+        }
+        Ok(())
+    }
+}

View File

@@ -23,7 +23,7 @@ const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
 /// (e.g. the compute node's address might've changed at the wrong time).
 /// Invalidate the cache entry (if any) to prevent subsequent errors.
 #[tracing::instrument(name = "invalidate_cache", skip_all)]
-pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo {
+pub async fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo {
     let is_cached = node_info.cached();
     if is_cached {
         warn!("invalidating stalled compute node info cache entry");
@@ -34,7 +34,7 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo {
     };
     Metrics::get().proxy.connection_failures_total.inc(label);
-    node_info.invalidate()
+    node_info.invalidate().await
 }
 #[async_trait]
@@ -156,7 +156,7 @@ where
     } else {
         // if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
         info!("compute node's state has likely changed; requesting a wake-up");
-        let old_node_info = invalidate_cache(node_info);
+        let old_node_info = invalidate_cache(node_info).await;
         let mut node_info =
             wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
         node_info.reuse_settings(old_node_info);

View File

@@ -16,8 +16,9 @@ use crate::console::messages::MetricsAuxInfo;
 use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend};
 use crate::console::{self, CachedNodeInfo, NodeInfo};
 use crate::error::ErrorKind;
+use crate::intern::EndpointIdInt;
 use crate::proxy::retry::retry_after;
-use crate::{http, sasl, scram, BranchId, EndpointId, ProjectId};
+use crate::{http, sasl, scram, BranchId, EndpointCacheKey, EndpointId, ProjectId};
 use anyhow::{bail, Context};
 use async_trait::async_trait;
 use rstest::rstest;
@@ -405,12 +406,13 @@ impl TestConnectMechanism {
         Self {
             counter: Arc::new(std::sync::Mutex::new(0)),
             sequence,
-            cache: Box::leak(Box::new(NodeInfoCache::new(
-                "test",
-                1,
-                Duration::from_secs(100),
-                false,
-            ))),
+            cache: Box::leak(Box::new(
+                NodeInfoCache::builder()
+                    .name("test")
+                    .max_capacity(1)
+                    .time_to_live(Duration::from_secs(100))
+                    .build(),
+            )),
         }
     }
 }
@@ -476,13 +478,17 @@ impl ConnectMechanism for TestConnectMechanism {
     fn update_connect_config(&self, _conf: &mut compute::ConnCfg) {}
 }
+#[async_trait]
 impl TestBackend for TestConnectMechanism {
-    fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
-        let mut counter = self.counter.lock().unwrap();
-        let action = self.sequence[*counter];
-        *counter += 1;
+    async fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
+        let action = {
+            let mut counter = self.counter.lock().unwrap();
+            let action = self.sequence[*counter];
+            *counter += 1;
+            action
+        };
         match action {
-            ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)),
+            ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache).await),
            ConnectAction::WakeFail => {
                let err = console::errors::ApiError::Console {
                    status: http::StatusCode::FORBIDDEN,
@@ -514,7 +520,7 @@ impl TestBackend for TestConnectMechanism {
     }
 }
-fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
+async fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
     let node = NodeInfo {
         config: compute::ConnCfg::new(),
         aux: MetricsAuxInfo {
@@ -525,8 +531,14 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn
         },
         allow_self_signed_compute: false,
     };
-    let (_, node) = cache.insert("key".into(), node);
-    node
+    let ep: EndpointId = "key".into();
+    let ep = EndpointIdInt::from(ep);
+    let key = EndpointCacheKey::from(ep);
+    cache.insert(key.clone(), node.clone()).await;
+    CachedNodeInfo {
+        token: Some((cache, key)),
+        value: node,
+    }
 }
 fn helper_create_connect_info(

View File

@@ -58,10 +58,7 @@ impl fmt::Display for ConnInfo {
         write!(
             f,
             "{}@{}/{}?{}",
-            self.user_info.user,
-            self.user_info.endpoint,
-            self.dbname,
-            self.user_info.options.get_cache_key("")
+            self.user_info.user, self.user_info.endpoint, self.dbname, self.user_info.options
         )
     }
 }

View File

@@ -27,6 +27,7 @@ bytes = { version = "1", features = ["serde"] }
 chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] }
 clap = { version = "4", features = ["derive", "string"] }
 clap_builder = { version = "4", default-features = false, features = ["color", "help", "std", "string", "suggestions", "usage"] }
+crossbeam-epoch = { version = "0.9" }
 crossbeam-utils = { version = "0.8" }
 either = { version = "1" }
 fail = { version = "0.5", default-features = false, features = ["failpoints"] }