diff --git a/Cargo.lock b/Cargo.lock index 5f544a05c6..512145f6c8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1134,6 +1134,20 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" +[[package]] +name = "combine" +version = "4.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35ed6e9d84f0b51a7f52daf1c7d71dd136fd7a3f41a8462b8cdb8c78d920fad4" +dependencies = [ + "bytes", + "futures-core", + "memchr", + "pin-project-lite", + "tokio", + "tokio-util", +] + [[package]] name = "comfy-table" version = "6.1.4" @@ -3899,6 +3913,7 @@ dependencies = [ "prometheus", "rand 0.8.5", "rcgen", + "redis", "regex", "remote_storage", "reqwest", @@ -4061,6 +4076,32 @@ dependencies = [ "yasna", ] +[[package]] +name = "redis" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c580d9cbbe1d1b479e8d67cf9daf6a62c957e6846048408b80b43ac3f6af84cd" +dependencies = [ + "async-trait", + "bytes", + "combine", + "futures-util", + "itoa", + "percent-encoding", + "pin-project-lite", + "rustls", + "rustls-native-certs", + "rustls-pemfile", + "rustls-webpki 0.101.7", + "ryu", + "sha1_smol", + "socket2 0.4.9", + "tokio", + "tokio-rustls", + "tokio-util", + "url", +] + [[package]] name = "redox_syscall" version = "0.2.16" @@ -4917,6 +4958,12 @@ dependencies = [ "digest", ] +[[package]] +name = "sha1_smol" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012" + [[package]] name = "sha2" version = "0.10.6" diff --git a/Cargo.toml b/Cargo.toml index e9172809d7..2d8fbaffa8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -114,6 +114,7 @@ pin-project-lite = "0.2" prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency prost = "0.11" rand = "0.8" +redis = { version = 
"0.24.0", features = ["tokio-rustls-comp", "keep-alive"] } regex = "1.10.2" reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] } reqwest-tracing = { version = "0.4.0", features = ["opentelemetry_0_19"] } diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 5fdfd00a6a..23a9bb178d 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -79,6 +79,7 @@ x509-parser.workspace = true native-tls.workspace = true postgres-native-tls.workspace = true postgres-protocol.workspace = true +redis.workspace = true smol_str.workspace = true workspace_hack.workspace = true diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index d9bddff139..a4c5512521 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -8,6 +8,7 @@ use tokio_postgres::config::AuthKeys; use crate::auth::credentials::check_peer_addr_is_in_list; use crate::auth::validate_password_and_exchange; +use crate::cache::Cached; use crate::console::errors::GetAuthInfoError; use crate::console::AuthSecret; use crate::context::RequestMonitoring; @@ -20,7 +21,7 @@ use crate::{ config::AuthenticationConfig, console::{ self, - provider::{CachedNodeInfo, ConsoleReqExtra}, + provider::{CachedAllowedIps, CachedNodeInfo, ConsoleReqExtra}, Api, }, stream, url, @@ -55,7 +56,7 @@ pub enum BackendType<'a, T> { pub trait TestBackend: Send + Sync + 'static { fn wake_compute(&self) -> Result; - fn get_allowed_ips(&self) -> Result>, console::errors::GetAuthInfoError>; + fn get_allowed_ips(&self) -> Result, console::errors::GetAuthInfoError>; } impl std::fmt::Display for BackendType<'_, ()> { @@ -190,18 +191,21 @@ async fn auth_quirks( if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) { return Err(auth::AuthError::ip_address_not_allowed()); } - let cached_secret = api.get_role_secret(ctx, &info).await?; + let maybe_secret = api.get_role_secret(ctx, &info).await?; - let secret = cached_secret.clone().unwrap_or_else(|| { + let cached_secret = 
maybe_secret.unwrap_or_else(|| { // If we don't have an authentication secret, we mock one to // prevent malicious probing (possible due to missing protocol steps). // This mocked secret will never lead to successful authentication. info!("authentication info not found, mocking it"); - AuthSecret::Scram(scram::ServerSecret::mock(&info.inner.user, rand::random())) + Cached::new_uncached(AuthSecret::Scram(scram::ServerSecret::mock( + &info.inner.user, + rand::random(), + ))) }); match authenticate_with_secret( ctx, - secret, + cached_secret.value.clone(), info, client, unauthenticated_password, @@ -410,15 +414,15 @@ impl BackendType<'_, ComputeUserInfo> { pub async fn get_allowed_ips( &self, ctx: &mut RequestMonitoring, - ) -> Result>, GetAuthInfoError> { + ) -> Result { use BackendType::*; match self { Console(api, creds) => api.get_allowed_ips(ctx, creds).await, #[cfg(feature = "testing")] Postgres(api, creds) => api.get_allowed_ips(ctx, creds).await, - Link(_) => Ok(Arc::new(vec![])), + Link(_) => Ok(Cached::new_uncached(Arc::new(vec![]))), #[cfg(test)] - Test(x) => x.get_allowed_ips(), + Test(x) => Ok(Cached::new_uncached(Arc::new(x.get_allowed_ips()?))), } } diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 4ddfa722e1..d282e894c8 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -155,7 +155,7 @@ impl ClientCredentials { } } -pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &Vec) -> bool { +pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &Vec) -> bool { if ip_list.is_empty() { return true; } diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index d42906aa4a..e1dac34a59 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -3,15 +3,14 @@ use proxy::auth; use proxy::config::AuthenticationConfig; use proxy::config::CacheOptions; use proxy::config::HttpConfig; +use proxy::config::ProjectInfoCacheOptions; use proxy::console; -use 
proxy::console::provider::AllowedIpsCache; -use proxy::console::provider::NodeInfoCache; -use proxy::console::provider::RoleSecretCache; use proxy::context::parquet::ParquetUploadArgs; use proxy::http; use proxy::rate_limiter::EndpointRateLimiter; use proxy::rate_limiter::RateBucketInfo; use proxy::rate_limiter::RateLimiterConfig; +use proxy::redis::notifications; use proxy::serverless::GlobalConnPoolOptions; use proxy::usage_metrics; @@ -137,6 +136,12 @@ struct ProxyCliArgs { /// disable ip check for http requests. If it is too time consuming, it could be turned off. #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] disable_ip_check_for_http: bool, + /// redis url for notifications. + #[clap(long)] + redis_notifications: Option, + /// cache for `project_info` (use `size=0` to disable) + #[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)] + project_info_cache: String, #[clap(flatten)] parquet_upload: ParquetUploadArgs, @@ -243,6 +248,15 @@ async fn main() -> anyhow::Result<()> { maintenance_tasks.spawn(usage_metrics::task_main(metrics_config)); } + if let auth::BackendType::Console(api, _) = &config.auth_backend { + let cache = api.caches.project_info.clone(); + if let Some(url) = args.redis_notifications { + info!("Starting redis notifications listener ({url})"); + maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone())); + } + maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); + } + let maintenance = loop { // get one complete task match futures::future::select( @@ -308,32 +322,17 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { let auth_backend = match &args.auth_backend { AuthBackend::Console => { let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?; - let allowed_ips_cache_config: CacheOptions = args.allowed_ips_cache.parse()?; - let 
role_secret_cache_config: CacheOptions = args.role_secret_cache.parse()?; + let project_info_cache_config: ProjectInfoCacheOptions = + args.project_info_cache.parse()?; info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}"); - info!("Using AllowedIpsCache (wake_compute) with options={allowed_ips_cache_config:?}"); - info!("Using RoleSecretCache (wake_compute) with options={role_secret_cache_config:?}"); - let caches = Box::leak(Box::new(console::caches::ApiCaches { - node_info: NodeInfoCache::new( - "node_info_cache", - wake_compute_cache_config.size, - wake_compute_cache_config.ttl, - true, - ), - allowed_ips: AllowedIpsCache::new( - "allowed_ips_cache", - allowed_ips_cache_config.size, - allowed_ips_cache_config.ttl, - false, - ), - role_secret: RoleSecretCache::new( - "role_secret_cache", - role_secret_cache_config.size, - role_secret_cache_config.ttl, - false, - ), - })); + info!( + "Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}" + ); + let caches = Box::leak(Box::new(console::caches::ApiCaches::new( + wake_compute_cache_config, + project_info_cache_config, + ))); let config::WakeComputeLockOptions { shards, diff --git a/proxy/src/cache.rs b/proxy/src/cache.rs index f54f360b01..fc5f416395 100644 --- a/proxy/src/cache.rs +++ b/proxy/src/cache.rs @@ -1,311 +1,6 @@ -use std::{ - borrow::Borrow, - hash::Hash, - ops::{Deref, DerefMut}, - time::{Duration, Instant}, -}; -use tracing::debug; - -// This seems to make more sense than `lru` or `cached`: -// -// * `near/nearcore` ditched `cached` in favor of `lru` -// (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed). -// -// * `lru` methods use an obscure `KeyRef` type in their contraints (which is deliberately excluded from docs). -// This severely hinders its usage both in terms of creating wrappers and supported key types. -// -// On the other hand, `hashlink` has good download stats and appears to be maintained. 
-use hashlink::{linked_hash_map::RawEntryMut, LruCache}; - -/// A generic trait which exposes types of cache's key and value, -/// as well as the notion of cache entry invalidation. -/// This is useful for [`timed_lru::Cached`]. -pub trait Cache { - /// Entry's key. - type Key; - - /// Entry's value. - type Value; - - /// Used for entry invalidation. - type LookupInfo; - - /// Invalidate an entry using a lookup info. - /// We don't have an empty default impl because it's error-prone. - fn invalidate(&self, _: &Self::LookupInfo); -} - -impl Cache for &C { - type Key = C::Key; - type Value = C::Value; - type LookupInfo = C::LookupInfo; - - fn invalidate(&self, info: &Self::LookupInfo) { - C::invalidate(self, info) - } -} +pub mod common; +pub mod project_info; +mod timed_lru; +pub use common::{Cache, Cached}; pub use timed_lru::TimedLru; -pub mod timed_lru { - use super::*; - - /// An implementation of timed LRU cache with fixed capacity. - /// Key properties: - /// - /// * Whenever a new entry is inserted, the least recently accessed one is evicted. - /// The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`). - /// - /// * If `update_ttl_on_retrieval` is `true`. When the entry is about to be retrieved, we check its expiration timestamp. - /// If the entry has expired, we remove it from the cache; Otherwise we bump the - /// expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong - /// its existence. - /// - /// * There's an API for immediate invalidation (removal) of a cache entry; - /// It's useful in case we know for sure that the entry is no longer correct. - /// See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information. - /// - /// * Expired entries are kept in the cache, until they are evicted by the LRU policy, - /// or by a successful lookup (i.e. the entry hasn't expired yet). - /// There is no background job to reap the expired records. 
- /// - /// * It's possible for an entry that has not yet expired entry to be evicted - /// before expired items. That's a bit wasteful, but probably fine in practice. - pub struct TimedLru { - /// Cache's name for tracing. - name: &'static str, - - /// The underlying cache implementation. - cache: parking_lot::Mutex>>, - - /// Default time-to-live of a single entry. - ttl: Duration, - - update_ttl_on_retrieval: bool, - } - - impl Cache for TimedLru { - type Key = K; - type Value = V; - type LookupInfo = LookupInfo; - - fn invalidate(&self, info: &Self::LookupInfo) { - self.invalidate_raw(info) - } - } - - struct Entry { - created_at: Instant, - expires_at: Instant, - value: T, - } - - impl TimedLru { - /// Construct a new LRU cache with timed entries. - pub fn new( - name: &'static str, - capacity: usize, - ttl: Duration, - update_ttl_on_retrieval: bool, - ) -> Self { - Self { - name, - cache: LruCache::new(capacity).into(), - ttl, - update_ttl_on_retrieval, - } - } - - /// Drop an entry from the cache if it's outdated. - #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] - fn invalidate_raw(&self, info: &LookupInfo) { - let now = Instant::now(); - - // Do costly things before taking the lock. - let mut cache = self.cache.lock(); - let raw_entry = match cache.raw_entry_mut().from_key(&info.key) { - RawEntryMut::Vacant(_) => return, - RawEntryMut::Occupied(x) => x, - }; - - // Remove the entry if it was created prior to lookup timestamp. 
- let entry = raw_entry.get(); - let (created_at, expires_at) = (entry.created_at, entry.expires_at); - let should_remove = created_at <= info.created_at || expires_at <= now; - - if should_remove { - raw_entry.remove(); - } - - drop(cache); // drop lock before logging - debug!( - created_at = format_args!("{created_at:?}"), - expires_at = format_args!("{expires_at:?}"), - entry_removed = should_remove, - "processed a cache entry invalidation event" - ); - } - - /// Try retrieving an entry by its key, then execute `extract` if it exists. - #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] - fn get_raw(&self, key: &Q, extract: impl FnOnce(&K, &Entry) -> R) -> Option - where - K: Borrow, - Q: Hash + Eq + ?Sized, - { - let now = Instant::now(); - let deadline = now.checked_add(self.ttl).expect("time overflow"); - - // Do costly things before taking the lock. - let mut cache = self.cache.lock(); - let mut raw_entry = match cache.raw_entry_mut().from_key(key) { - RawEntryMut::Vacant(_) => return None, - RawEntryMut::Occupied(x) => x, - }; - - // Immeditely drop the entry if it has expired. - let entry = raw_entry.get(); - if entry.expires_at <= now { - raw_entry.remove(); - return None; - } - - let value = extract(raw_entry.key(), entry); - let (created_at, expires_at) = (entry.created_at, entry.expires_at); - - // Update the deadline and the entry's position in the LRU list. - if self.update_ttl_on_retrieval { - raw_entry.get_mut().expires_at = deadline; - } - raw_entry.to_back(); - - drop(cache); // drop lock before logging - debug!( - created_at = format_args!("{created_at:?}"), - old_expires_at = format_args!("{expires_at:?}"), - new_expires_at = format_args!("{deadline:?}"), - "accessed a cache entry" - ); - - Some(value) - } - - /// Insert an entry to the cache. If an entry with the same key already - /// existed, return the previous value and its creation timestamp. 
- #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] - fn insert_raw(&self, key: K, value: V) -> (Instant, Option) { - let created_at = Instant::now(); - let expires_at = created_at.checked_add(self.ttl).expect("time overflow"); - - let entry = Entry { - created_at, - expires_at, - value, - }; - - // Do costly things before taking the lock. - let old = self - .cache - .lock() - .insert(key, entry) - .map(|entry| entry.value); - - debug!( - created_at = format_args!("{created_at:?}"), - expires_at = format_args!("{expires_at:?}"), - replaced = old.is_some(), - "created a cache entry" - ); - - (created_at, old) - } - } - - impl TimedLru { - pub fn insert(&self, key: K, value: V) -> (Option, Cached<&Self>) { - let (created_at, old) = self.insert_raw(key.clone(), value.clone()); - - let cached = Cached { - token: Some((self, LookupInfo { created_at, key })), - value, - }; - - (old, cached) - } - } - - impl TimedLru { - /// Retrieve a cached entry in convenient wrapper. - pub fn get(&self, key: &Q) -> Option> - where - K: Borrow + Clone, - Q: Hash + Eq + ?Sized, - { - self.get_raw(key, |key, entry| { - let info = LookupInfo { - created_at: entry.created_at, - key: key.clone(), - }; - - Cached { - token: Some((self, info)), - value: entry.value.clone(), - } - }) - } - } - - /// Lookup information for key invalidation. - pub struct LookupInfo { - /// Time of creation of a cache [`Entry`]. - /// We use this during invalidation lookups to prevent eviction of a newer - /// entry sharing the same key (it might've been inserted by a different - /// task after we got the entry we're trying to invalidate now). - created_at: Instant, - - /// Search by this key. - key: K, - } - - /// Wrapper for convenient entry invalidation. - pub struct Cached { - /// Cache + lookup info. - token: Option<(C, C::LookupInfo)>, - - /// The value itself. - value: C::Value, - } - - impl Cached { - /// Place any entry into this wrapper; invalidation will be a no-op. 
- pub fn new_uncached(value: C::Value) -> Self { - Self { token: None, value } - } - - /// Drop this entry from a cache if it's still there. - pub fn invalidate(self) -> C::Value { - if let Some((cache, info)) = &self.token { - cache.invalidate(info); - } - self.value - } - - /// Tell if this entry is actually cached. - pub fn cached(&self) -> bool { - self.token.is_some() - } - } - - impl Deref for Cached { - type Target = C::Value; - - fn deref(&self) -> &Self::Target { - &self.value - } - } - - impl DerefMut for Cached { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.value - } - } -} diff --git a/proxy/src/cache/common.rs b/proxy/src/cache/common.rs new file mode 100644 index 0000000000..2af6a70e90 --- /dev/null +++ b/proxy/src/cache/common.rs @@ -0,0 +1,72 @@ +use std::ops::{Deref, DerefMut}; + +/// A generic trait which exposes types of cache's key and value, +/// as well as the notion of cache entry invalidation. +/// This is useful for [`Cached`]. +pub trait Cache { + /// Entry's key. + type Key; + + /// Entry's value. + type Value; + + /// Used for entry invalidation. + type LookupInfo; + + /// Invalidate an entry using a lookup info. + /// We don't have an empty default impl because it's error-prone. + fn invalidate(&self, _: &Self::LookupInfo); +} + +impl Cache for &C { + type Key = C::Key; + type Value = C::Value; + type LookupInfo = C::LookupInfo; + + fn invalidate(&self, info: &Self::LookupInfo) { + C::invalidate(self, info) + } +} + +/// Wrapper for convenient entry invalidation. +pub struct Cached::Value> { + /// Cache + lookup info. + pub token: Option<(C, C::LookupInfo)>, + + /// The value itself. + pub value: V, +} + +impl Cached { + /// Place any entry into this wrapper; invalidation will be a no-op. + pub fn new_uncached(value: V) -> Self { + Self { token: None, value } + } + + /// Drop this entry from a cache if it's still there. 
+ pub fn invalidate(self) -> V { + if let Some((cache, info)) = &self.token { + cache.invalidate(info); + } + self.value + } + + /// Tell if this entry is actually cached. + pub fn cached(&self) -> bool { + self.token.is_some() + } +} + +impl Deref for Cached { + type Target = V; + + fn deref(&self) -> &Self::Target { + &self.value + } +} + +impl DerefMut for Cached { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.value + } +} diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs new file mode 100644 index 0000000000..7af2118873 --- /dev/null +++ b/proxy/src/cache/project_info.rs @@ -0,0 +1,496 @@ +use std::{ + collections::HashSet, + convert::Infallible, + sync::{atomic::AtomicU64, Arc}, + time::Duration, +}; + +use dashmap::DashMap; +use rand::{thread_rng, Rng}; +use smol_str::SmolStr; +use tokio::time::Instant; +use tracing::{debug, info}; + +use crate::{config::ProjectInfoCacheOptions, console::AuthSecret}; + +use super::{Cache, Cached}; + +pub trait ProjectInfoCache { + fn invalidate_allowed_ips_for_project(&self, project_id: &SmolStr); + fn invalidate_role_secret_for_project(&self, project_id: &SmolStr, role_name: &SmolStr); + fn enable_ttl(&self); + fn disable_ttl(&self); +} + +struct Entry { + created_at: Instant, + value: T, +} + +impl Entry { + pub fn new(value: T) -> Self { + Self { + created_at: Instant::now(), + value, + } + } +} + +impl From for Entry { + fn from(value: T) -> Self { + Self::new(value) + } +} + +#[derive(Default)] +struct EndpointInfo { + secret: std::collections::HashMap>, + allowed_ips: Option>>>, +} + +impl EndpointInfo { + fn check_ignore_cache(ignore_cache_since: Option, created_at: Instant) -> bool { + match ignore_cache_since { + None => false, + Some(t) => t < created_at, + } + } + pub fn get_role_secret( + &self, + role_name: &SmolStr, + valid_since: Instant, + ignore_cache_since: Option, + ) -> Option<(AuthSecret, bool)> { + if let Some(secret) = self.secret.get(role_name) { + if 
valid_since < secret.created_at { + return Some(( + secret.value.clone(), + Self::check_ignore_cache(ignore_cache_since, secret.created_at), + )); + } + } + None + } + + pub fn get_allowed_ips( + &self, + valid_since: Instant, + ignore_cache_since: Option, + ) -> Option<(Arc>, bool)> { + if let Some(allowed_ips) = &self.allowed_ips { + if valid_since < allowed_ips.created_at { + return Some(( + allowed_ips.value.clone(), + Self::check_ignore_cache(ignore_cache_since, allowed_ips.created_at), + )); + } + } + None + } + pub fn invalidate_allowed_ips(&mut self) { + self.allowed_ips = None; + } + pub fn invalidate_role_secret(&mut self, role_name: &SmolStr) { + self.secret.remove(role_name); + } +} + +/// Cache for project info. +/// This is used to cache auth data for endpoints. +/// Invalidation is done by console notifications or by TTL (if console notifications are disabled). +/// +/// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data. +/// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available? +/// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache. 
+pub struct ProjectInfoCacheImpl { + cache: DashMap, + + project2ep: DashMap>, + config: ProjectInfoCacheOptions, + + start_time: Instant, + ttl_disabled_since_us: AtomicU64, +} + +impl ProjectInfoCache for ProjectInfoCacheImpl { + fn invalidate_allowed_ips_for_project(&self, project_id: &SmolStr) { + info!("invalidating allowed ips for project `{}`", project_id); + let endpoints = self + .project2ep + .get(project_id) + .map(|kv| kv.value().clone()) + .unwrap_or_default(); + for endpoint_id in endpoints { + if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { + endpoint_info.invalidate_allowed_ips(); + } + } + } + fn invalidate_role_secret_for_project(&self, project_id: &SmolStr, role_name: &SmolStr) { + info!( + "invalidating role secret for project_id `{}` and role_name `{}`", + project_id, role_name + ); + let endpoints = self + .project2ep + .get(project_id) + .map(|kv| kv.value().clone()) + .unwrap_or_default(); + for endpoint_id in endpoints { + if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { + endpoint_info.invalidate_role_secret(role_name); + } + } + } + fn enable_ttl(&self) { + self.ttl_disabled_since_us + .store(u64::MAX, std::sync::atomic::Ordering::Relaxed); + } + + fn disable_ttl(&self) { + let new_ttl = (self.start_time.elapsed() + self.config.ttl).as_micros() as u64; + self.ttl_disabled_since_us + .store(new_ttl, std::sync::atomic::Ordering::Relaxed); + } +} + +impl ProjectInfoCacheImpl { + pub fn new(config: ProjectInfoCacheOptions) -> Self { + Self { + cache: DashMap::new(), + project2ep: DashMap::new(), + config, + ttl_disabled_since_us: AtomicU64::new(u64::MAX), + start_time: Instant::now(), + } + } + + pub fn get_role_secret( + &self, + endpoint_id: &SmolStr, + role_name: &SmolStr, + ) -> Option> { + let (valid_since, ignore_cache_since) = self.get_cache_times(); + let endpoint_info = self.cache.get(endpoint_id)?; + let (value, ignore_cache) = + endpoint_info.get_role_secret(role_name, valid_since, 
ignore_cache_since)?; + if !ignore_cache { + let cached = Cached { + token: Some(( + self, + CachedLookupInfo::new_role_secret(endpoint_id.clone(), role_name.clone()), + )), + value, + }; + return Some(cached); + } + Some(Cached::new_uncached(value)) + } + pub fn get_allowed_ips( + &self, + endpoint_id: &SmolStr, + ) -> Option>>> { + let (valid_since, ignore_cache_since) = self.get_cache_times(); + let endpoint_info = self.cache.get(endpoint_id)?; + let value = endpoint_info.get_allowed_ips(valid_since, ignore_cache_since); + let (value, ignore_cache) = value?; + if !ignore_cache { + let cached = Cached { + token: Some((self, CachedLookupInfo::new_allowed_ips(endpoint_id.clone()))), + value, + }; + return Some(cached); + } + Some(Cached::new_uncached(value)) + } + pub fn insert_role_secret( + &self, + project_id: &SmolStr, + endpoint_id: &SmolStr, + role_name: &SmolStr, + secret: AuthSecret, + ) { + if self.cache.len() >= self.config.size { + // If there are too many entries, wait until the next gc cycle. + return; + } + self.inser_project2endpoint(project_id, endpoint_id); + let mut entry = self.cache.entry(endpoint_id.clone()).or_default(); + if entry.secret.len() < self.config.max_roles { + entry.secret.insert(role_name.clone(), secret.into()); + } + } + pub fn insert_allowed_ips( + &self, + project_id: &SmolStr, + endpoint_id: &SmolStr, + allowed_ips: Arc>, + ) { + if self.cache.len() >= self.config.size { + // If there are too many entries, wait until the next gc cycle. 
+ return; + } + self.inser_project2endpoint(project_id, endpoint_id); + self.cache + .entry(endpoint_id.clone()) + .or_default() + .allowed_ips = Some(allowed_ips.into()); + } + fn inser_project2endpoint(&self, project_id: &SmolStr, endpoint_id: &SmolStr) { + if let Some(mut endpoints) = self.project2ep.get_mut(project_id) { + endpoints.insert(endpoint_id.clone()); + } else { + self.project2ep + .insert(project_id.clone(), HashSet::from([endpoint_id.clone()])); + } + } + fn get_cache_times(&self) -> (Instant, Option) { + let mut valid_since = Instant::now() - self.config.ttl; + // Only ignore cache if ttl is disabled. + let ttl_disabled_since_us = self + .ttl_disabled_since_us + .load(std::sync::atomic::Ordering::Relaxed); + let ignore_cache_since = if ttl_disabled_since_us != u64::MAX { + let ignore_cache_since = self.start_time + Duration::from_micros(ttl_disabled_since_us); + // We are fine if entry is not older than ttl or was added before we are getting notifications. + valid_since = valid_since.min(ignore_cache_since); + Some(ignore_cache_since) + } else { + None + }; + (valid_since, ignore_cache_since) + } + + pub async fn gc_worker(&self) -> anyhow::Result { + let mut interval = + tokio::time::interval(self.config.gc_interval / (self.cache.shards().len()) as u32); + loop { + interval.tick().await; + if self.cache.len() <= self.config.size { + // If there are not too many entries, wait until the next gc cycle. + continue; + } + self.gc(); + } + } + + fn gc(&self) { + let shard = thread_rng().gen_range(0..self.project2ep.shards().len()); + debug!(shard, "project_info_cache: performing epoch reclamation"); + + // acquire a random shard lock + let mut removed = 0; + let shard = self.project2ep.shards()[shard].write(); + for (_, endpoints) in shard.iter() { + for endpoint in endpoints.get().iter() { + self.cache.remove(endpoint); + removed += 1; + } + } + // We can drop this shard only after making sure that all endpoints are removed. 
+ drop(shard); + info!("project_info_cache: removed {removed} endpoints"); + } +} + +/// Lookup info for project info cache. +/// This is used to invalidate cache entries. +pub struct CachedLookupInfo { + /// Search by this key. + endpoint_id: SmolStr, + lookup_type: LookupType, +} + +impl CachedLookupInfo { + pub(self) fn new_role_secret(endpoint_id: SmolStr, role_name: SmolStr) -> Self { + Self { + endpoint_id, + lookup_type: LookupType::RoleSecret(role_name), + } + } + pub(self) fn new_allowed_ips(endpoint_id: SmolStr) -> Self { + Self { + endpoint_id, + lookup_type: LookupType::AllowedIps, + } + } +} + +enum LookupType { + RoleSecret(SmolStr), + AllowedIps, +} + +impl Cache for ProjectInfoCacheImpl { + type Key = SmolStr; + // Value is not really used here, but we need to specify it. + type Value = SmolStr; + + type LookupInfo = CachedLookupInfo; + + fn invalidate(&self, key: &Self::LookupInfo) { + match &key.lookup_type { + LookupType::RoleSecret(role_name) => { + if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { + endpoint_info.invalidate_role_secret(role_name); + } + } + LookupType::AllowedIps => { + if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { + endpoint_info.invalidate_allowed_ips(); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{console::AuthSecret, scram::ServerSecret}; + use smol_str::SmolStr; + use std::{sync::Arc, time::Duration}; + + #[tokio::test] + async fn test_project_info_cache_settings() { + tokio::time::pause(); + let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions { + size: 2, + max_roles: 2, + ttl: Duration::from_secs(1), + gc_interval: Duration::from_secs(600), + }); + let project_id = "project".into(); + let endpoint_id = "endpoint".into(); + let user1: SmolStr = "user1".into(); + let user2: SmolStr = "user2".into(); + let secret1 = AuthSecret::Scram(ServerSecret::mock(user1.as_str(), [1; 32])); + let secret2 = 
AuthSecret::Scram(ServerSecret::mock(user2.as_str(), [2; 32])); + let allowed_ips = Arc::new(vec!["allowed_ip1".into(), "allowed_ip2".into()]); + cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone()); + cache.insert_role_secret(&project_id, &endpoint_id, &user2, secret2.clone()); + cache.insert_allowed_ips(&project_id, &endpoint_id, allowed_ips.clone()); + + let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); + assert!(cached.cached()); + assert_eq!(cached.value, secret1); + let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); + assert!(cached.cached()); + assert_eq!(cached.value, secret2); + + // Shouldn't add more than 2 roles. + let user3: SmolStr = "user3".into(); + let secret3 = AuthSecret::Scram(ServerSecret::mock(user3.as_str(), [3; 32])); + cache.insert_role_secret(&project_id, &endpoint_id, &user3, secret3.clone()); + assert!(cache.get_role_secret(&endpoint_id, &user3).is_none()); + + let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); + assert!(cached.cached()); + assert_eq!(cached.value, allowed_ips); + + tokio::time::advance(Duration::from_secs(2)).await; + let cached = cache.get_role_secret(&endpoint_id, &user1); + assert!(cached.is_none()); + let cached = cache.get_role_secret(&endpoint_id, &user2); + assert!(cached.is_none()); + let cached = cache.get_allowed_ips(&endpoint_id); + assert!(cached.is_none()); + } + + #[tokio::test] + async fn test_project_info_cache_invalidations() { + tokio::time::pause(); + let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions { + size: 2, + max_roles: 2, + ttl: Duration::from_secs(1), + gc_interval: Duration::from_secs(600), + })); + cache.clone().disable_ttl(); + tokio::time::advance(Duration::from_secs(2)).await; + + let project_id = "project".into(); + let endpoint_id = "endpoint".into(); + let user1: SmolStr = "user1".into(); + let user2: SmolStr = "user2".into(); + let secret1 = AuthSecret::Scram(ServerSecret::mock(user1.as_str(), [1; 
32])); + let secret2 = AuthSecret::Scram(ServerSecret::mock(user2.as_str(), [2; 32])); + let allowed_ips = Arc::new(vec!["allowed_ip1".into(), "allowed_ip2".into()]); + cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone()); + cache.insert_role_secret(&project_id, &endpoint_id, &user2, secret2.clone()); + cache.insert_allowed_ips(&project_id, &endpoint_id, allowed_ips.clone()); + + tokio::time::advance(Duration::from_secs(2)).await; + // Nothing should be invalidated. + + let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); + // TTL is disabled, so it should be impossible to invalidate this value. + assert!(!cached.cached()); + assert_eq!(cached.value, secret1); + + cached.invalidate(); // Shouldn't do anything. + let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); + assert_eq!(cached.value, secret1); + + let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); + assert!(!cached.cached()); + assert_eq!(cached.value, secret2); + + // The only way to invalidate this value is to invalidate via the api. 
+ cache.invalidate_role_secret_for_project(&project_id, &user2); + assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); + + let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); + assert!(!cached.cached()); + assert_eq!(cached.value, allowed_ips); + } + + #[tokio::test] + async fn test_disable_ttl_invalidate_added_before() { + tokio::time::pause(); + let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions { + size: 2, + max_roles: 2, + ttl: Duration::from_secs(1), + gc_interval: Duration::from_secs(600), + })); + + let project_id = "project".into(); + let endpoint_id = "endpoint".into(); + let user1: SmolStr = "user1".into(); + let user2: SmolStr = "user2".into(); + let secret1 = AuthSecret::Scram(ServerSecret::mock(user1.as_str(), [1; 32])); + let secret2 = AuthSecret::Scram(ServerSecret::mock(user2.as_str(), [2; 32])); + let allowed_ips = Arc::new(vec!["allowed_ip1".into(), "allowed_ip2".into()]); + cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone()); + cache.clone().disable_ttl(); + tokio::time::advance(Duration::from_millis(100)).await; + cache.insert_role_secret(&project_id, &endpoint_id, &user2, secret2.clone()); + + // Added before ttl was disabled + ttl should be still cached. + let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); + assert!(cached.cached()); + let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); + assert!(cached.cached()); + + tokio::time::advance(Duration::from_secs(1)).await; + // Added before ttl was disabled + ttl should expire. + assert!(cache.get_role_secret(&endpoint_id, &user1).is_none()); + assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); + + // Added after ttl was disabled + ttl should not be cached. 
+ cache.insert_allowed_ips(&project_id, &endpoint_id, allowed_ips.clone()); + let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); + assert!(!cached.cached()); + + tokio::time::advance(Duration::from_secs(1)).await; + // Added before ttl was disabled + ttl still should expire. + assert!(cache.get_role_secret(&endpoint_id, &user1).is_none()); + assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); + // Shouldn't be invalidated. + + let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); + assert!(!cached.cached()); + assert_eq!(cached.value, allowed_ips); + } +} diff --git a/proxy/src/cache/timed_lru.rs b/proxy/src/cache/timed_lru.rs new file mode 100644 index 0000000000..3b21381bb9 --- /dev/null +++ b/proxy/src/cache/timed_lru.rs @@ -0,0 +1,258 @@ +use std::{ + borrow::Borrow, + hash::Hash, + time::{Duration, Instant}, +}; +use tracing::debug; + +// This seems to make more sense than `lru` or `cached`: +// +// * `near/nearcore` ditched `cached` in favor of `lru` +// (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed). +// +// * `lru` methods use an obscure `KeyRef` type in their constraints (which is deliberately excluded from docs). +// This severely hinders its usage both in terms of creating wrappers and supported key types. +// +// On the other hand, `hashlink` has good download stats and appears to be maintained. +use hashlink::{linked_hash_map::RawEntryMut, LruCache}; + +use super::{common::Cached, *}; + +/// An implementation of timed LRU cache with fixed capacity. +/// Key properties: +/// +/// * Whenever a new entry is inserted, the least recently accessed one is evicted. +/// The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`). +/// +/// * If `update_ttl_on_retrieval` is `true`. When the entry is about to be retrieved, we check its expiration timestamp. +/// If the entry has expired, we remove it from the cache; Otherwise we bump the +/// expiration timestamp (e.g. 
+5mins) and change its place in LRU list to prolong +/// its existence. +/// +/// * There's an API for immediate invalidation (removal) of a cache entry; +/// It's useful in case we know for sure that the entry is no longer correct. +/// See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information. +/// +/// * Expired entries are kept in the cache, until they are evicted by the LRU policy, +/// or removed by a lookup or invalidation that finds them expired. +/// There is no background job to reap the expired records. +/// +/// * It's possible for an entry that has not yet expired to be evicted +/// before expired items. That's a bit wasteful, but probably fine in practice. +pub struct TimedLru { + /// Cache's name for tracing. + name: &'static str, + + /// The underlying cache implementation. + cache: parking_lot::Mutex>>, + + /// Default time-to-live of a single entry. + ttl: Duration, + + update_ttl_on_retrieval: bool, +} + +impl Cache for TimedLru { + type Key = K; + type Value = V; + type LookupInfo = LookupInfo; + + fn invalidate(&self, info: &Self::LookupInfo) { + self.invalidate_raw(info) + } +} + +struct Entry { + created_at: Instant, + expires_at: Instant, + value: T, +} + +impl TimedLru { + /// Construct a new LRU cache with timed entries. + pub fn new( + name: &'static str, + capacity: usize, + ttl: Duration, + update_ttl_on_retrieval: bool, + ) -> Self { + Self { + name, + cache: LruCache::new(capacity).into(), + ttl, + update_ttl_on_retrieval, + } + } + + /// Drop an entry from the cache if it's outdated. + #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] + fn invalidate_raw(&self, info: &LookupInfo) { + let now = Instant::now(); + + // Do costly things before taking the lock. 
+ let mut cache = self.cache.lock(); + let raw_entry = match cache.raw_entry_mut().from_key(&info.key) { + RawEntryMut::Vacant(_) => return, + RawEntryMut::Occupied(x) => x, + }; + + // Remove the entry if it was created at or before the lookup timestamp, or if it has expired. + let entry = raw_entry.get(); + let (created_at, expires_at) = (entry.created_at, entry.expires_at); + let should_remove = created_at <= info.created_at || expires_at <= now; + + if should_remove { + raw_entry.remove(); + } + + drop(cache); // drop lock before logging + debug!( + created_at = format_args!("{created_at:?}"), + expires_at = format_args!("{expires_at:?}"), + entry_removed = should_remove, + "processed a cache entry invalidation event" + ); + } + + /// Try retrieving an entry by its key, then execute `extract` if it exists. + #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] + fn get_raw(&self, key: &Q, extract: impl FnOnce(&K, &Entry) -> R) -> Option + where + K: Borrow, + Q: Hash + Eq + ?Sized, + { + let now = Instant::now(); + let deadline = now.checked_add(self.ttl).expect("time overflow"); + + // Do costly things before taking the lock. + let mut cache = self.cache.lock(); + let mut raw_entry = match cache.raw_entry_mut().from_key(key) { + RawEntryMut::Vacant(_) => return None, + RawEntryMut::Occupied(x) => x, + }; + + // Immediately drop the entry if it has expired. + let entry = raw_entry.get(); + if entry.expires_at <= now { + raw_entry.remove(); + return None; + } + + let value = extract(raw_entry.key(), entry); + let (created_at, expires_at) = (entry.created_at, entry.expires_at); + + // Update the deadline and the entry's position in the LRU list. 
+ if self.update_ttl_on_retrieval { + raw_entry.get_mut().expires_at = deadline; + } + raw_entry.to_back(); + + drop(cache); // drop lock before logging + debug!( + created_at = format_args!("{created_at:?}"), + old_expires_at = format_args!("{expires_at:?}"), + new_expires_at = format_args!("{deadline:?}"), + "accessed a cache entry" + ); + + Some(value) + } + + /// Insert an entry to the cache. If an entry with the same key already + /// existed, return the previous value and its creation timestamp. + #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] + fn insert_raw(&self, key: K, value: V) -> (Instant, Option) { + let created_at = Instant::now(); + let expires_at = created_at.checked_add(self.ttl).expect("time overflow"); + + let entry = Entry { + created_at, + expires_at, + value, + }; + + // Do costly things before taking the lock. + let old = self + .cache + .lock() + .insert(key, entry) + .map(|entry| entry.value); + + debug!( + created_at = format_args!("{created_at:?}"), + expires_at = format_args!("{expires_at:?}"), + replaced = old.is_some(), + "created a cache entry" + ); + + (created_at, old) + } +} + +impl TimedLru { + pub fn insert(&self, key: K, value: V) -> (Option, Cached<&Self>) { + let (created_at, old) = self.insert_raw(key.clone(), value.clone()); + + let cached = Cached { + token: Some((self, LookupInfo { created_at, key })), + value, + }; + + (old, cached) + } +} + +impl TimedLru { + /// Retrieve a cached entry in convenient wrapper. + pub fn get(&self, key: &Q) -> Option> + where + K: Borrow + Clone, + Q: Hash + Eq + ?Sized, + { + self.get_raw(key, |key, entry| { + let info = LookupInfo { + created_at: entry.created_at, + key: key.clone(), + }; + + Cached { + token: Some((self, info)), + value: entry.value.clone(), + } + }) + } + + /// Retrieve a cached entry in convenient wrapper, ignoring its TTL. 
+ pub fn get_ignoring_ttl(&self, key: &Q) -> Option> + where + K: Borrow, + Q: Hash + Eq + ?Sized, + { + let mut cache = self.cache.lock(); + cache + .get(key) + .map(|entry| Cached::new_uncached(entry.value.clone())) + } + + /// Remove an entry from the cache. + pub fn remove(&self, key: &Q) -> Option + where + K: Borrow + Clone, + Q: Hash + Eq + ?Sized, + { + let mut cache = self.cache.lock(); + cache.remove(key).map(|entry| entry.value) + } +} + +/// Lookup information for key invalidation. +pub struct LookupInfo { + /// Time of creation of a cache [`Entry`]. + /// We use this during invalidation lookups to prevent eviction of a newer + /// entry sharing the same key (it might've been inserted by a different + /// task after we got the entry we're trying to invalidate now). + created_at: Instant, + + /// Search by this key. + key: K, +} diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 90956f84d3..043d8d0791 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -352,6 +352,69 @@ impl FromStr for CacheOptions { } } +/// Helper for cmdline project info cache options parsing. +#[derive(Debug)] +pub struct ProjectInfoCacheOptions { + /// Max number of entries. + pub size: usize, + /// Entry's time-to-live. + pub ttl: Duration, + /// Max number of roles per endpoint. + pub max_roles: usize, + /// Gc interval. + pub gc_interval: Duration, +} + +impl ProjectInfoCacheOptions { + /// Default options for [`crate::cache::project_info::ProjectInfoCacheImpl`]. + pub const CACHE_DEFAULT_OPTIONS: &'static str = + "size=10000,ttl=4m,max_roles=10,gc_interval=60m"; + + /// Parse cache options passed via cmdline. + /// Example: [`Self::CACHE_DEFAULT_OPTIONS`]. 
+ fn parse(options: &str) -> anyhow::Result { + let mut size = None; + let mut ttl = None; + let mut max_roles = None; + let mut gc_interval = None; + + for option in options.split(',') { + let (key, value) = option + .split_once('=') + .with_context(|| format!("bad key-value pair: {option}"))?; + + match key { + "size" => size = Some(value.parse()?), + "ttl" => ttl = Some(humantime::parse_duration(value)?), + "max_roles" => max_roles = Some(value.parse()?), + "gc_interval" => gc_interval = Some(humantime::parse_duration(value)?), + unknown => bail!("unknown key: {unknown}"), + } + } + + // TTL doesn't matter if cache is always empty. + if let Some(0) = size { + ttl.get_or_insert(Duration::default()); + } + + Ok(Self { + size: size.context("missing `size`")?, + ttl: ttl.context("missing `ttl`")?, + max_roles: max_roles.context("missing `max_roles`")?, + gc_interval: gc_interval.context("missing `gc_interval`")?, + }) + } +} + +impl FromStr for ProjectInfoCacheOptions { + type Err = anyhow::Error; + + fn from_str(options: &str) -> Result { + let error = || format!("failed to parse cache options '{options}'"); + Self::parse(options).with_context(error) + } +} + /// Helper for cmdline cache options parsing. pub struct WakeComputeLockOptions { /// The number of shards the lock map should have diff --git a/proxy/src/console/messages.rs b/proxy/src/console/messages.rs index 837379b21f..c02d65668f 100644 --- a/proxy/src/console/messages.rs +++ b/proxy/src/console/messages.rs @@ -15,6 +15,7 @@ pub struct ConsoleError { pub struct GetRoleSecret { pub role_secret: Box, pub allowed_ips: Option>>, + pub project_id: Option>, } // Manually implement debug to omit sensitive info. @@ -207,12 +208,17 @@ mod tests { "role_secret": "secret", }); let _: GetRoleSecret = serde_json::from_str(&json.to_string())?; - // Empty `allowed_ips` field. 
let json = json!({ "role_secret": "secret", "allowed_ips": ["8.8.8.8"], }); let _: GetRoleSecret = serde_json::from_str(&json.to_string())?; + let json = json!({ + "role_secret": "secret", + "allowed_ips": ["8.8.8.8"], + "project_id": "project", + }); + let _: GetRoleSecret = serde_json::from_str(&json.to_string())?; Ok(()) } diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index 974384bd5b..9497d36bc7 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -5,8 +5,9 @@ pub mod neon; use super::messages::MetricsAuxInfo; use crate::{ auth::backend::ComputeUserInfo, - cache::{timed_lru, TimedLru}, + cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru}, compute, + config::{CacheOptions, ProjectInfoCacheOptions}, context::RequestMonitoring, scram, }; @@ -14,10 +15,8 @@ use async_trait::async_trait; use dashmap::DashMap; use smol_str::SmolStr; use std::{sync::Arc, time::Duration}; -use tokio::{ - sync::{OwnedSemaphorePermit, Semaphore}, - time::Instant, -}; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use tokio::time::Instant; use tracing::info; pub mod errors { @@ -215,7 +214,7 @@ impl ConsoleReqExtra { } /// Auth secret which is managed by the cloud. -#[derive(Clone)] +#[derive(Clone, Eq, PartialEq, Debug)] pub enum AuthSecret { #[cfg(feature = "testing")] /// Md5 hash of user's password. @@ -229,7 +228,9 @@ pub enum AuthSecret { pub struct AuthInfo { pub secret: Option, /// List of IP addresses allowed for the autorization. - pub allowed_ips: Vec, + pub allowed_ips: Vec, + /// Project ID. This is used for cache invalidation. + pub project_id: Option, } /// Info for establishing a connection to a compute node. 
@@ -249,27 +250,28 @@ pub struct NodeInfo { } pub type NodeInfoCache = TimedLru, NodeInfo>; -pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>; -pub type AllowedIpsCache = TimedLru>>; -pub type RoleSecretCache = TimedLru<(SmolStr, SmolStr), Option>; -pub type CachedRoleSecret = timed_lru::Cached<&'static RoleSecretCache>; +pub type CachedNodeInfo = Cached<&'static NodeInfoCache>; +pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, AuthSecret>; +pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc>>; /// This will allocate per each call, but the http requests alone /// already require a few allocations, so it should be fine. #[async_trait] pub trait Api { /// Get the client's auth secret for authentication. + /// Returns option because user not found situation is special. + /// We still have to mock the scram to avoid leaking information that user doesn't exist. async fn get_role_secret( &self, ctx: &mut RequestMonitoring, creds: &ComputeUserInfo, - ) -> Result; + ) -> Result, errors::GetAuthInfoError>; async fn get_allowed_ips( &self, ctx: &mut RequestMonitoring, creds: &ComputeUserInfo, - ) -> Result>, errors::GetAuthInfoError>; + ) -> Result; /// Wake up the compute node and return the corresponding connection info. async fn wake_compute( @@ -284,10 +286,25 @@ pub trait Api { pub struct ApiCaches { /// Cache for the `wake_compute` API method. pub node_info: NodeInfoCache, - /// Cache for the `get_allowed_ips`. TODO(anna): use notifications listener instead. - pub allowed_ips: AllowedIpsCache, - /// Cache for the `get_role_secret`. TODO(anna): use notifications listener instead. - pub role_secret: RoleSecretCache, + /// Cache which stores project_id -> endpoint_ids mapping. 
+ pub project_info: Arc, +} + +impl ApiCaches { + pub fn new( + wake_compute_cache_config: CacheOptions, + project_info_cache_config: ProjectInfoCacheOptions, + ) -> Self { + Self { + node_info: NodeInfoCache::new( + "node_info_cache", + wake_compute_cache_config.size, + wake_compute_cache_config.ttl, + true, + ), + project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)), + } + } } /// Various caches for [`console`](super). diff --git a/proxy/src/console/provider/mock.rs b/proxy/src/console/provider/mock.rs index fa61ec3221..8f50865288 100644 --- a/proxy/src/console/provider/mock.rs +++ b/proxy/src/console/provider/mock.rs @@ -1,15 +1,17 @@ //! Mock console backend which relies on a user-provided postgres instance. -use std::sync::Arc; - use super::{ errors::{ApiError, GetAuthInfoError, WakeComputeError}, AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo, }; +use crate::cache::Cached; +use crate::console::provider::{CachedAllowedIps, CachedRoleSecret}; +use crate::context::RequestMonitoring; use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl}; -use crate::{console::provider::CachedRoleSecret, context::RequestMonitoring}; use async_trait::async_trait; use futures::TryFutureExt; +use smol_str::SmolStr; +use std::sync::Arc; use thiserror::Error; use tokio_postgres::{config::SslMode, Client}; use tracing::{error, info, info_span, warn, Instrument}; @@ -98,7 +100,8 @@ impl Api { .await?; Ok(AuthInfo { secret, - allowed_ips, + allowed_ips: allowed_ips.iter().map(SmolStr::from).collect(), + project_id: None, }) } @@ -147,18 +150,22 @@ impl super::Api for Api { &self, _ctx: &mut RequestMonitoring, creds: &ComputeUserInfo, - ) -> Result { - Ok(CachedRoleSecret::new_uncached( - self.do_get_auth_info(creds).await?.secret, - )) + ) -> Result, GetAuthInfoError> { + Ok(self + .do_get_auth_info(creds) + .await? 
+ .secret + .map(CachedRoleSecret::new_uncached)) } async fn get_allowed_ips( &self, _ctx: &mut RequestMonitoring, creds: &ComputeUserInfo, - ) -> Result>, GetAuthInfoError> { - Ok(Arc::new(self.do_get_auth_info(creds).await?.allowed_ips)) + ) -> Result { + Ok(Cached::new_uncached(Arc::new( + self.do_get_auth_info(creds).await?.allowed_ips, + ))) } #[tracing::instrument(skip_all)] diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index 7867a1e933..e0bb7952b5 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -3,17 +3,19 @@ use super::{ super::messages::{ConsoleError, GetRoleSecret, WakeCompute}, errors::{ApiError, GetAuthInfoError, WakeComputeError}, - ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedNodeInfo, CachedRoleSecret, ConsoleReqExtra, - NodeInfo, + ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, + ConsoleReqExtra, NodeInfo, }; use crate::{auth::backend::ComputeUserInfo, compute, http, scram}; use crate::{ + cache::Cached, context::RequestMonitoring, metrics::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER}, }; use async_trait::async_trait; use futures::TryFutureExt; use itertools::Itertools; +use smol_str::SmolStr; use std::sync::Arc; use tokio::time::Instant; use tokio_postgres::config::SslMode; @@ -22,7 +24,7 @@ use tracing::{error, info, info_span, warn, Instrument}; #[derive(Clone)] pub struct Api { endpoint: http::Endpoint, - caches: &'static ApiCaches, + pub caches: &'static ApiCaches, locks: &'static ApiLocks, jwt: String, } @@ -91,12 +93,13 @@ impl Api { .allowed_ips .into_iter() .flatten() - .map(String::from) + .map(SmolStr::from) .collect_vec(); ALLOWED_IPS_NUMBER.observe(allowed_ips.len() as f64); Ok(AuthInfo { secret: Some(secret), allowed_ips, + project_id: body.project_id.map(SmolStr::from), }) } .map_err(crate::error::log_error) @@ -170,46 +173,56 @@ impl super::Api for Api { &self, ctx: &mut RequestMonitoring, 
creds: &ComputeUserInfo, - ) -> Result { - let ep = creds.endpoint.clone(); - let user = creds.inner.user.clone(); - if let Some(role_secret) = self.caches.role_secret.get(&(ep.clone(), user.clone())) { - return Ok(role_secret); + ) -> Result, GetAuthInfoError> { + let ep = &creds.endpoint; + let user = &creds.inner.user; + if let Some(role_secret) = self.caches.project_info.get_role_secret(ep, user) { + return Ok(Some(role_secret)); } let auth_info = self.do_get_auth_info(ctx, creds).await?; - let (_, secret) = self - .caches - .role_secret - .insert((ep.clone(), user), auth_info.secret.clone()); - self.caches - .allowed_ips - .insert(ep, Arc::new(auth_info.allowed_ips)); - Ok(secret) + let project_id = auth_info.project_id.unwrap_or(ep.clone()); + if let Some(secret) = &auth_info.secret { + self.caches + .project_info + .insert_role_secret(&project_id, ep, user, secret.clone()) + } + self.caches.project_info.insert_allowed_ips( + &project_id, + ep, + Arc::new(auth_info.allowed_ips), + ); + // When we just got a secret, we don't need to invalidate it. 
+ Ok(auth_info.secret.map(Cached::new_uncached)) } async fn get_allowed_ips( &self, ctx: &mut RequestMonitoring, creds: &ComputeUserInfo, - ) -> Result>, GetAuthInfoError> { - if let Some(allowed_ips) = self.caches.allowed_ips.get(&creds.endpoint) { + ) -> Result { + let ep = &creds.endpoint; + if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(ep) { ALLOWED_IPS_BY_CACHE_OUTCOME .with_label_values(&["hit"]) .inc(); - return Ok(Arc::new(allowed_ips.to_vec())); + return Ok(allowed_ips); } ALLOWED_IPS_BY_CACHE_OUTCOME .with_label_values(&["miss"]) .inc(); let auth_info = self.do_get_auth_info(ctx, creds).await?; let allowed_ips = Arc::new(auth_info.allowed_ips); - let ep = creds.endpoint.clone(); - let user = creds.inner.user.clone(); + let user = &creds.inner.user; + let project_id = auth_info.project_id.unwrap_or(ep.clone()); + if let Some(secret) = &auth_info.secret { + self.caches + .project_info + .insert_role_secret(&project_id, ep, user, secret.clone()) + } self.caches - .role_secret - .insert((ep.clone(), user), auth_info.secret); - self.caches.allowed_ips.insert(ep, allowed_ips.clone()); - Ok(allowed_ips) + .project_info + .insert_allowed_ips(&project_id, ep, allowed_ips.clone()); + Ok(Cached::new_uncached(allowed_ips)) } #[tracing::instrument(skip_all)] diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index 87ae8894e1..a22b2459b8 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -22,6 +22,7 @@ pub mod parse; pub mod protocol2; pub mod proxy; pub mod rate_limiter; +pub mod redis; pub mod sasl; pub mod scram; pub mod serverless; diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 0957f33a92..1d7c9bac54 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -12,6 +12,7 @@ use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT}; use crate::{auth, http, sasl, scram}; use async_trait::async_trait; use rstest::rstest; +use smol_str::SmolStr; use tokio_postgres::config::SslMode; use 
tokio_postgres::tls::{MakeTlsConnect, NoTls}; use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream}; @@ -470,7 +471,7 @@ impl TestBackend for TestConnectMechanism { } } - fn get_allowed_ips(&self) -> Result>, console::errors::GetAuthInfoError> { + fn get_allowed_ips(&self) -> Result, console::errors::GetAuthInfoError> { unimplemented!("not used in tests") } } diff --git a/proxy/src/redis.rs b/proxy/src/redis.rs new file mode 100644 index 0000000000..c2a91bed97 --- /dev/null +++ b/proxy/src/redis.rs @@ -0,0 +1 @@ +pub mod notifications; diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs new file mode 100644 index 0000000000..933f2a1bdb --- /dev/null +++ b/proxy/src/redis/notifications.rs @@ -0,0 +1,202 @@ +use std::{convert::Infallible, sync::Arc}; + +use futures::StreamExt; +use redis::aio::PubSub; +use serde::Deserialize; +use smol_str::SmolStr; + +use crate::cache::project_info::ProjectInfoCache; + +const CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; +const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20); +const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20); + +struct ConsoleRedisClient { + client: redis::Client, +} + +impl ConsoleRedisClient { + pub fn new(url: &str) -> anyhow::Result { + let client = redis::Client::open(url)?; + Ok(Self { client }) + } + async fn try_connect(&self) -> anyhow::Result { + let mut conn = self.client.get_async_connection().await?.into_pubsub(); + tracing::info!("subscribing to a channel `{CHANNEL_NAME}`"); + conn.subscribe(CHANNEL_NAME).await?; + Ok(conn) + } +} + +#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +#[serde(tag = "topic", content = "data")] +enum Notification { + #[serde( + rename = "/allowed_ips_updated", + deserialize_with = "deserialize_json_string" + )] + AllowedIpsUpdate { + allowed_ips_update: AllowedIpsUpdate, + }, + #[serde( + rename = "/password_updated", + deserialize_with = "deserialize_json_string" + )] + 
PasswordUpdate { password_update: PasswordUpdate }, +} +#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +struct AllowedIpsUpdate { + #[serde(rename = "project")] + project_id: SmolStr, +} +#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +struct PasswordUpdate { + #[serde(rename = "project")] + project_id: SmolStr, + #[serde(rename = "role")] + role_name: SmolStr, +} +fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result +where + T: for<'de2> serde::Deserialize<'de2>, + D: serde::Deserializer<'de>, +{ + let s = String::deserialize(deserializer)?; + serde_json::from_str(&s).map_err(::custom) +} + +fn invalidate_cache(cache: Arc, msg: Notification) { + use Notification::*; + match msg { + AllowedIpsUpdate { allowed_ips_update } => { + cache.invalidate_allowed_ips_for_project(&allowed_ips_update.project_id) + } + PasswordUpdate { password_update } => cache.invalidate_role_secret_for_project( + &password_update.project_id, + &password_update.role_name, + ), + } +} + +#[tracing::instrument(skip(cache))] +fn handle_message(msg: redis::Msg, cache: Arc) -> anyhow::Result<()> +where + C: ProjectInfoCache + Send + Sync + 'static, +{ + let payload: String = msg.get_payload()?; + tracing::debug!(?payload, "received a message payload"); + + let msg: Notification = match serde_json::from_str(&payload) { + Ok(msg) => msg, + Err(e) => { + tracing::error!("broken message: {e}"); + return Ok(()); + } + }; + tracing::debug!(?msg, "received a message"); + invalidate_cache(cache.clone(), msg.clone()); + // It might happen that the invalid entry is on the way to be cached. + // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds. + // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message. 
+ tokio::spawn(async move { + tokio::time::sleep(INVALIDATION_LAG).await; + invalidate_cache(cache, msg.clone()); + }); + + Ok(()) +} + +/// Handle console's invalidation messages. +#[tracing::instrument(name = "console_notifications", skip_all)] +pub async fn task_main(url: String, cache: Arc) -> anyhow::Result +where + C: ProjectInfoCache + Send + Sync + 'static, +{ + cache.enable_ttl(); + + loop { + let redis = ConsoleRedisClient::new(&url)?; + let conn = match redis.try_connect().await { + Ok(conn) => { + cache.disable_ttl(); + conn + } + Err(e) => { + tracing::error!( + "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}" + ); + tokio::time::sleep(RECONNECT_TIMEOUT).await; + continue; + } + }; + let mut stream = conn.into_on_message(); + while let Some(msg) = stream.next().await { + match handle_message(msg, cache.clone()) { + Ok(()) => {} + Err(e) => { + tracing::error!("failed to handle message: {e}, will try to reconnect"); + break; + } + } + } + cache.enable_ttl(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn parse_allowed_ips() -> anyhow::Result<()> { + let project_id = "new_project".to_string(); + let data = format!("{{\"project\": \"{project_id}\"}}"); + let text = json!({ + "type": "message", + "topic": "/allowed_ips_updated", + "data": data, + "extre_fields": "something" + }) + .to_string(); + + let result: Notification = serde_json::from_str(&text)?; + assert_eq!( + result, + Notification::AllowedIpsUpdate { + allowed_ips_update: AllowedIpsUpdate { + project_id: project_id.into() + } + } + ); + + Ok(()) + } + + #[test] + fn parse_password_updated() -> anyhow::Result<()> { + let project_id = "new_project".to_string(); + let role_name = "new_role".to_string(); + let data = format!("{{\"project\": \"{project_id}\", \"role\": \"{role_name}\"}}"); + let text = json!({ + "type": "message", + "topic": "/password_updated", + "data": data, + "extre_fields": "something" + }) + 
.to_string(); + + let result: Notification = serde_json::from_str(&text)?; + assert_eq!( + result, + Notification::PasswordUpdate { + password_update: PasswordUpdate { + project_id: project_id.into(), + role_name: role_name.into() + } + } + ); + + Ok(()) + } +} diff --git a/proxy/src/scram/key.rs b/proxy/src/scram/key.rs index bd93fb2b70..66c2c6b207 100644 --- a/proxy/src/scram/key.rs +++ b/proxy/src/scram/key.rs @@ -6,7 +6,7 @@ pub const SCRAM_KEY_LEN: usize = 32; /// One of the keys derived from the [password](super::password::SaltedPassword). /// We use the same structure for all keys, i.e. /// `ClientKey`, `StoredKey`, and `ServerKey`. -#[derive(Clone, Default, PartialEq, Eq)] +#[derive(Clone, Default, PartialEq, Eq, Debug)] #[repr(transparent)] pub struct ScramKey { bytes: [u8; SCRAM_KEY_LEN], diff --git a/proxy/src/scram/secret.rs b/proxy/src/scram/secret.rs index 9e74e07af1..041548014a 100644 --- a/proxy/src/scram/secret.rs +++ b/proxy/src/scram/secret.rs @@ -5,7 +5,7 @@ use super::key::ScramKey; /// Server secret is produced from [password](super::password::SaltedPassword) /// and is used throughout the authentication process. -#[derive(Clone)] +#[derive(Clone, Eq, PartialEq, Debug)] pub struct ServerSecret { /// Number of iterations for `PBKDF2` function. pub iterations: u32,