mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-10 06:52:55 +00:00
Added auth info cache with notifiations to redis. (#6208)
## Problem Current cache doesn't support any updates from the cplane. ## Summary of changes * Added redis notifier listner. * Added cache which can be invalidated with the notifier. If the notifier is not available, it's just a normal ttl cache. * Updated cplane api. The motivation behind this organization of the data is the following: * In the Neon data model there are projects. Projects could have multiple branches and each branch could have more than one endpoint. * Also there is one special `main` branch. * Password reset works per branch. * Allowed IPs are the same for every branch in the project (except, maybe, the main one). * The main branch can be changed to the other branch. * The endpoint can be moved between branches. Every event described above requires some special processing on the porxy (or cplane) side. The idea of invalidating for the project is that whenever one of the events above is happening with the project, proxy can invalidate all entries for the entire project. This approach also requires some additional API change (returning project_id inside the auth info).
This commit is contained in:
47
Cargo.lock
generated
47
Cargo.lock
generated
@@ -1134,6 +1134,20 @@ version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"
|
||||
|
||||
[[package]]
|
||||
name = "combine"
|
||||
version = "4.6.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "35ed6e9d84f0b51a7f52daf1c7d71dd136fd7a3f41a8462b8cdb8c78d920fad4"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"memchr",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "comfy-table"
|
||||
version = "6.1.4"
|
||||
@@ -3899,6 +3913,7 @@ dependencies = [
|
||||
"prometheus",
|
||||
"rand 0.8.5",
|
||||
"rcgen",
|
||||
"redis",
|
||||
"regex",
|
||||
"remote_storage",
|
||||
"reqwest",
|
||||
@@ -4061,6 +4076,32 @@ dependencies = [
|
||||
"yasna",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redis"
|
||||
version = "0.24.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c580d9cbbe1d1b479e8d67cf9daf6a62c957e6846048408b80b43ac3f6af84cd"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"combine",
|
||||
"futures-util",
|
||||
"itoa",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"rustls",
|
||||
"rustls-native-certs",
|
||||
"rustls-pemfile",
|
||||
"rustls-webpki 0.101.7",
|
||||
"ryu",
|
||||
"sha1_smol",
|
||||
"socket2 0.4.9",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tokio-util",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.2.16"
|
||||
@@ -4917,6 +4958,12 @@ dependencies = [
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha1_smol"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012"
|
||||
|
||||
[[package]]
|
||||
name = "sha2"
|
||||
version = "0.10.6"
|
||||
|
||||
@@ -114,6 +114,7 @@ pin-project-lite = "0.2"
|
||||
prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency
|
||||
prost = "0.11"
|
||||
rand = "0.8"
|
||||
redis = { version = "0.24.0", features = ["tokio-rustls-comp", "keep-alive"] }
|
||||
regex = "1.10.2"
|
||||
reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] }
|
||||
reqwest-tracing = { version = "0.4.0", features = ["opentelemetry_0_19"] }
|
||||
|
||||
@@ -79,6 +79,7 @@ x509-parser.workspace = true
|
||||
native-tls.workspace = true
|
||||
postgres-native-tls.workspace = true
|
||||
postgres-protocol.workspace = true
|
||||
redis.workspace = true
|
||||
smol_str.workspace = true
|
||||
|
||||
workspace_hack.workspace = true
|
||||
|
||||
@@ -8,6 +8,7 @@ use tokio_postgres::config::AuthKeys;
|
||||
|
||||
use crate::auth::credentials::check_peer_addr_is_in_list;
|
||||
use crate::auth::validate_password_and_exchange;
|
||||
use crate::cache::Cached;
|
||||
use crate::console::errors::GetAuthInfoError;
|
||||
use crate::console::AuthSecret;
|
||||
use crate::context::RequestMonitoring;
|
||||
@@ -20,7 +21,7 @@ use crate::{
|
||||
config::AuthenticationConfig,
|
||||
console::{
|
||||
self,
|
||||
provider::{CachedNodeInfo, ConsoleReqExtra},
|
||||
provider::{CachedAllowedIps, CachedNodeInfo, ConsoleReqExtra},
|
||||
Api,
|
||||
},
|
||||
stream, url,
|
||||
@@ -55,7 +56,7 @@ pub enum BackendType<'a, T> {
|
||||
|
||||
pub trait TestBackend: Send + Sync + 'static {
|
||||
fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
|
||||
fn get_allowed_ips(&self) -> Result<Arc<Vec<String>>, console::errors::GetAuthInfoError>;
|
||||
fn get_allowed_ips(&self) -> Result<Vec<SmolStr>, console::errors::GetAuthInfoError>;
|
||||
}
|
||||
|
||||
impl std::fmt::Display for BackendType<'_, ()> {
|
||||
@@ -190,18 +191,21 @@ async fn auth_quirks(
|
||||
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
|
||||
return Err(auth::AuthError::ip_address_not_allowed());
|
||||
}
|
||||
let cached_secret = api.get_role_secret(ctx, &info).await?;
|
||||
let maybe_secret = api.get_role_secret(ctx, &info).await?;
|
||||
|
||||
let secret = cached_secret.clone().unwrap_or_else(|| {
|
||||
let cached_secret = maybe_secret.unwrap_or_else(|| {
|
||||
// If we don't have an authentication secret, we mock one to
|
||||
// prevent malicious probing (possible due to missing protocol steps).
|
||||
// This mocked secret will never lead to successful authentication.
|
||||
info!("authentication info not found, mocking it");
|
||||
AuthSecret::Scram(scram::ServerSecret::mock(&info.inner.user, rand::random()))
|
||||
Cached::new_uncached(AuthSecret::Scram(scram::ServerSecret::mock(
|
||||
&info.inner.user,
|
||||
rand::random(),
|
||||
)))
|
||||
});
|
||||
match authenticate_with_secret(
|
||||
ctx,
|
||||
secret,
|
||||
cached_secret.value.clone(),
|
||||
info,
|
||||
client,
|
||||
unauthenticated_password,
|
||||
@@ -410,15 +414,15 @@ impl BackendType<'_, ComputeUserInfo> {
|
||||
pub async fn get_allowed_ips(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
|
||||
) -> Result<CachedAllowedIps, GetAuthInfoError> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(api, creds) => api.get_allowed_ips(ctx, creds).await,
|
||||
#[cfg(feature = "testing")]
|
||||
Postgres(api, creds) => api.get_allowed_ips(ctx, creds).await,
|
||||
Link(_) => Ok(Arc::new(vec![])),
|
||||
Link(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
|
||||
#[cfg(test)]
|
||||
Test(x) => x.get_allowed_ips(),
|
||||
Test(x) => Ok(Cached::new_uncached(Arc::new(x.get_allowed_ips()?))),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -155,7 +155,7 @@ impl ClientCredentials {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &Vec<String>) -> bool {
|
||||
pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &Vec<SmolStr>) -> bool {
|
||||
if ip_list.is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -3,15 +3,14 @@ use proxy::auth;
|
||||
use proxy::config::AuthenticationConfig;
|
||||
use proxy::config::CacheOptions;
|
||||
use proxy::config::HttpConfig;
|
||||
use proxy::config::ProjectInfoCacheOptions;
|
||||
use proxy::console;
|
||||
use proxy::console::provider::AllowedIpsCache;
|
||||
use proxy::console::provider::NodeInfoCache;
|
||||
use proxy::console::provider::RoleSecretCache;
|
||||
use proxy::context::parquet::ParquetUploadArgs;
|
||||
use proxy::http;
|
||||
use proxy::rate_limiter::EndpointRateLimiter;
|
||||
use proxy::rate_limiter::RateBucketInfo;
|
||||
use proxy::rate_limiter::RateLimiterConfig;
|
||||
use proxy::redis::notifications;
|
||||
use proxy::serverless::GlobalConnPoolOptions;
|
||||
use proxy::usage_metrics;
|
||||
|
||||
@@ -137,6 +136,12 @@ struct ProxyCliArgs {
|
||||
/// disable ip check for http requests. If it is too time consuming, it could be turned off.
|
||||
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||
disable_ip_check_for_http: bool,
|
||||
/// redis url for notifications.
|
||||
#[clap(long)]
|
||||
redis_notifications: Option<String>,
|
||||
/// cache for `project_info` (use `size=0` to disable)
|
||||
#[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)]
|
||||
project_info_cache: String,
|
||||
|
||||
#[clap(flatten)]
|
||||
parquet_upload: ParquetUploadArgs,
|
||||
@@ -243,6 +248,15 @@ async fn main() -> anyhow::Result<()> {
|
||||
maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
|
||||
}
|
||||
|
||||
if let auth::BackendType::Console(api, _) = &config.auth_backend {
|
||||
let cache = api.caches.project_info.clone();
|
||||
if let Some(url) = args.redis_notifications {
|
||||
info!("Starting redis notifications listener ({url})");
|
||||
maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone()));
|
||||
}
|
||||
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
|
||||
}
|
||||
|
||||
let maintenance = loop {
|
||||
// get one complete task
|
||||
match futures::future::select(
|
||||
@@ -308,32 +322,17 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
let auth_backend = match &args.auth_backend {
|
||||
AuthBackend::Console => {
|
||||
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
|
||||
let allowed_ips_cache_config: CacheOptions = args.allowed_ips_cache.parse()?;
|
||||
let role_secret_cache_config: CacheOptions = args.role_secret_cache.parse()?;
|
||||
let project_info_cache_config: ProjectInfoCacheOptions =
|
||||
args.project_info_cache.parse()?;
|
||||
|
||||
info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
|
||||
info!("Using AllowedIpsCache (wake_compute) with options={allowed_ips_cache_config:?}");
|
||||
info!("Using RoleSecretCache (wake_compute) with options={role_secret_cache_config:?}");
|
||||
let caches = Box::leak(Box::new(console::caches::ApiCaches {
|
||||
node_info: NodeInfoCache::new(
|
||||
"node_info_cache",
|
||||
wake_compute_cache_config.size,
|
||||
wake_compute_cache_config.ttl,
|
||||
true,
|
||||
),
|
||||
allowed_ips: AllowedIpsCache::new(
|
||||
"allowed_ips_cache",
|
||||
allowed_ips_cache_config.size,
|
||||
allowed_ips_cache_config.ttl,
|
||||
false,
|
||||
),
|
||||
role_secret: RoleSecretCache::new(
|
||||
"role_secret_cache",
|
||||
role_secret_cache_config.size,
|
||||
role_secret_cache_config.ttl,
|
||||
false,
|
||||
),
|
||||
}));
|
||||
info!(
|
||||
"Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
|
||||
);
|
||||
let caches = Box::leak(Box::new(console::caches::ApiCaches::new(
|
||||
wake_compute_cache_config,
|
||||
project_info_cache_config,
|
||||
)));
|
||||
|
||||
let config::WakeComputeLockOptions {
|
||||
shards,
|
||||
|
||||
@@ -1,311 +1,6 @@
|
||||
use std::{
|
||||
borrow::Borrow,
|
||||
hash::Hash,
|
||||
ops::{Deref, DerefMut},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tracing::debug;
|
||||
|
||||
// This seems to make more sense than `lru` or `cached`:
|
||||
//
|
||||
// * `near/nearcore` ditched `cached` in favor of `lru`
|
||||
// (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed).
|
||||
//
|
||||
// * `lru` methods use an obscure `KeyRef` type in their contraints (which is deliberately excluded from docs).
|
||||
// This severely hinders its usage both in terms of creating wrappers and supported key types.
|
||||
//
|
||||
// On the other hand, `hashlink` has good download stats and appears to be maintained.
|
||||
use hashlink::{linked_hash_map::RawEntryMut, LruCache};
|
||||
|
||||
/// A generic trait which exposes types of cache's key and value,
|
||||
/// as well as the notion of cache entry invalidation.
|
||||
/// This is useful for [`timed_lru::Cached`].
|
||||
pub trait Cache {
|
||||
/// Entry's key.
|
||||
type Key;
|
||||
|
||||
/// Entry's value.
|
||||
type Value;
|
||||
|
||||
/// Used for entry invalidation.
|
||||
type LookupInfo<Key>;
|
||||
|
||||
/// Invalidate an entry using a lookup info.
|
||||
/// We don't have an empty default impl because it's error-prone.
|
||||
fn invalidate(&self, _: &Self::LookupInfo<Self::Key>);
|
||||
}
|
||||
|
||||
impl<C: Cache> Cache for &C {
|
||||
type Key = C::Key;
|
||||
type Value = C::Value;
|
||||
type LookupInfo<Key> = C::LookupInfo<Key>;
|
||||
|
||||
fn invalidate(&self, info: &Self::LookupInfo<Self::Key>) {
|
||||
C::invalidate(self, info)
|
||||
}
|
||||
}
|
||||
pub mod common;
|
||||
pub mod project_info;
|
||||
mod timed_lru;
|
||||
|
||||
pub use common::{Cache, Cached};
|
||||
pub use timed_lru::TimedLru;
|
||||
pub mod timed_lru {
|
||||
use super::*;
|
||||
|
||||
/// An implementation of timed LRU cache with fixed capacity.
|
||||
/// Key properties:
|
||||
///
|
||||
/// * Whenever a new entry is inserted, the least recently accessed one is evicted.
|
||||
/// The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`).
|
||||
///
|
||||
/// * If `update_ttl_on_retrieval` is `true`. When the entry is about to be retrieved, we check its expiration timestamp.
|
||||
/// If the entry has expired, we remove it from the cache; Otherwise we bump the
|
||||
/// expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong
|
||||
/// its existence.
|
||||
///
|
||||
/// * There's an API for immediate invalidation (removal) of a cache entry;
|
||||
/// It's useful in case we know for sure that the entry is no longer correct.
|
||||
/// See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information.
|
||||
///
|
||||
/// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
|
||||
/// or by a successful lookup (i.e. the entry hasn't expired yet).
|
||||
/// There is no background job to reap the expired records.
|
||||
///
|
||||
/// * It's possible for an entry that has not yet expired entry to be evicted
|
||||
/// before expired items. That's a bit wasteful, but probably fine in practice.
|
||||
pub struct TimedLru<K, V> {
|
||||
/// Cache's name for tracing.
|
||||
name: &'static str,
|
||||
|
||||
/// The underlying cache implementation.
|
||||
cache: parking_lot::Mutex<LruCache<K, Entry<V>>>,
|
||||
|
||||
/// Default time-to-live of a single entry.
|
||||
ttl: Duration,
|
||||
|
||||
update_ttl_on_retrieval: bool,
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq, V> Cache for TimedLru<K, V> {
|
||||
type Key = K;
|
||||
type Value = V;
|
||||
type LookupInfo<Key> = LookupInfo<Key>;
|
||||
|
||||
fn invalidate(&self, info: &Self::LookupInfo<K>) {
|
||||
self.invalidate_raw(info)
|
||||
}
|
||||
}
|
||||
|
||||
struct Entry<T> {
|
||||
created_at: Instant,
|
||||
expires_at: Instant,
|
||||
value: T,
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq, V> TimedLru<K, V> {
|
||||
/// Construct a new LRU cache with timed entries.
|
||||
pub fn new(
|
||||
name: &'static str,
|
||||
capacity: usize,
|
||||
ttl: Duration,
|
||||
update_ttl_on_retrieval: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
name,
|
||||
cache: LruCache::new(capacity).into(),
|
||||
ttl,
|
||||
update_ttl_on_retrieval,
|
||||
}
|
||||
}
|
||||
|
||||
/// Drop an entry from the cache if it's outdated.
|
||||
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
|
||||
fn invalidate_raw(&self, info: &LookupInfo<K>) {
|
||||
let now = Instant::now();
|
||||
|
||||
// Do costly things before taking the lock.
|
||||
let mut cache = self.cache.lock();
|
||||
let raw_entry = match cache.raw_entry_mut().from_key(&info.key) {
|
||||
RawEntryMut::Vacant(_) => return,
|
||||
RawEntryMut::Occupied(x) => x,
|
||||
};
|
||||
|
||||
// Remove the entry if it was created prior to lookup timestamp.
|
||||
let entry = raw_entry.get();
|
||||
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
|
||||
let should_remove = created_at <= info.created_at || expires_at <= now;
|
||||
|
||||
if should_remove {
|
||||
raw_entry.remove();
|
||||
}
|
||||
|
||||
drop(cache); // drop lock before logging
|
||||
debug!(
|
||||
created_at = format_args!("{created_at:?}"),
|
||||
expires_at = format_args!("{expires_at:?}"),
|
||||
entry_removed = should_remove,
|
||||
"processed a cache entry invalidation event"
|
||||
);
|
||||
}
|
||||
|
||||
/// Try retrieving an entry by its key, then execute `extract` if it exists.
|
||||
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
|
||||
fn get_raw<Q, R>(&self, key: &Q, extract: impl FnOnce(&K, &Entry<V>) -> R) -> Option<R>
|
||||
where
|
||||
K: Borrow<Q>,
|
||||
Q: Hash + Eq + ?Sized,
|
||||
{
|
||||
let now = Instant::now();
|
||||
let deadline = now.checked_add(self.ttl).expect("time overflow");
|
||||
|
||||
// Do costly things before taking the lock.
|
||||
let mut cache = self.cache.lock();
|
||||
let mut raw_entry = match cache.raw_entry_mut().from_key(key) {
|
||||
RawEntryMut::Vacant(_) => return None,
|
||||
RawEntryMut::Occupied(x) => x,
|
||||
};
|
||||
|
||||
// Immeditely drop the entry if it has expired.
|
||||
let entry = raw_entry.get();
|
||||
if entry.expires_at <= now {
|
||||
raw_entry.remove();
|
||||
return None;
|
||||
}
|
||||
|
||||
let value = extract(raw_entry.key(), entry);
|
||||
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
|
||||
|
||||
// Update the deadline and the entry's position in the LRU list.
|
||||
if self.update_ttl_on_retrieval {
|
||||
raw_entry.get_mut().expires_at = deadline;
|
||||
}
|
||||
raw_entry.to_back();
|
||||
|
||||
drop(cache); // drop lock before logging
|
||||
debug!(
|
||||
created_at = format_args!("{created_at:?}"),
|
||||
old_expires_at = format_args!("{expires_at:?}"),
|
||||
new_expires_at = format_args!("{deadline:?}"),
|
||||
"accessed a cache entry"
|
||||
);
|
||||
|
||||
Some(value)
|
||||
}
|
||||
|
||||
/// Insert an entry to the cache. If an entry with the same key already
|
||||
/// existed, return the previous value and its creation timestamp.
|
||||
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
|
||||
fn insert_raw(&self, key: K, value: V) -> (Instant, Option<V>) {
|
||||
let created_at = Instant::now();
|
||||
let expires_at = created_at.checked_add(self.ttl).expect("time overflow");
|
||||
|
||||
let entry = Entry {
|
||||
created_at,
|
||||
expires_at,
|
||||
value,
|
||||
};
|
||||
|
||||
// Do costly things before taking the lock.
|
||||
let old = self
|
||||
.cache
|
||||
.lock()
|
||||
.insert(key, entry)
|
||||
.map(|entry| entry.value);
|
||||
|
||||
debug!(
|
||||
created_at = format_args!("{created_at:?}"),
|
||||
expires_at = format_args!("{expires_at:?}"),
|
||||
replaced = old.is_some(),
|
||||
"created a cache entry"
|
||||
);
|
||||
|
||||
(created_at, old)
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
|
||||
pub fn insert(&self, key: K, value: V) -> (Option<V>, Cached<&Self>) {
|
||||
let (created_at, old) = self.insert_raw(key.clone(), value.clone());
|
||||
|
||||
let cached = Cached {
|
||||
token: Some((self, LookupInfo { created_at, key })),
|
||||
value,
|
||||
};
|
||||
|
||||
(old, cached)
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
|
||||
/// Retrieve a cached entry in convenient wrapper.
|
||||
pub fn get<Q>(&self, key: &Q) -> Option<timed_lru::Cached<&Self>>
|
||||
where
|
||||
K: Borrow<Q> + Clone,
|
||||
Q: Hash + Eq + ?Sized,
|
||||
{
|
||||
self.get_raw(key, |key, entry| {
|
||||
let info = LookupInfo {
|
||||
created_at: entry.created_at,
|
||||
key: key.clone(),
|
||||
};
|
||||
|
||||
Cached {
|
||||
token: Some((self, info)),
|
||||
value: entry.value.clone(),
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Lookup information for key invalidation.
|
||||
pub struct LookupInfo<K> {
|
||||
/// Time of creation of a cache [`Entry`].
|
||||
/// We use this during invalidation lookups to prevent eviction of a newer
|
||||
/// entry sharing the same key (it might've been inserted by a different
|
||||
/// task after we got the entry we're trying to invalidate now).
|
||||
created_at: Instant,
|
||||
|
||||
/// Search by this key.
|
||||
key: K,
|
||||
}
|
||||
|
||||
/// Wrapper for convenient entry invalidation.
|
||||
pub struct Cached<C: Cache> {
|
||||
/// Cache + lookup info.
|
||||
token: Option<(C, C::LookupInfo<C::Key>)>,
|
||||
|
||||
/// The value itself.
|
||||
value: C::Value,
|
||||
}
|
||||
|
||||
impl<C: Cache> Cached<C> {
|
||||
/// Place any entry into this wrapper; invalidation will be a no-op.
|
||||
pub fn new_uncached(value: C::Value) -> Self {
|
||||
Self { token: None, value }
|
||||
}
|
||||
|
||||
/// Drop this entry from a cache if it's still there.
|
||||
pub fn invalidate(self) -> C::Value {
|
||||
if let Some((cache, info)) = &self.token {
|
||||
cache.invalidate(info);
|
||||
}
|
||||
self.value
|
||||
}
|
||||
|
||||
/// Tell if this entry is actually cached.
|
||||
pub fn cached(&self) -> bool {
|
||||
self.token.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Cache> Deref for Cached<C> {
|
||||
type Target = C::Value;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.value
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Cache> DerefMut for Cached<C> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
72
proxy/src/cache/common.rs
vendored
Normal file
72
proxy/src/cache/common.rs
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
/// A generic trait which exposes types of cache's key and value,
|
||||
/// as well as the notion of cache entry invalidation.
|
||||
/// This is useful for [`Cached`].
|
||||
pub trait Cache {
|
||||
/// Entry's key.
|
||||
type Key;
|
||||
|
||||
/// Entry's value.
|
||||
type Value;
|
||||
|
||||
/// Used for entry invalidation.
|
||||
type LookupInfo<Key>;
|
||||
|
||||
/// Invalidate an entry using a lookup info.
|
||||
/// We don't have an empty default impl because it's error-prone.
|
||||
fn invalidate(&self, _: &Self::LookupInfo<Self::Key>);
|
||||
}
|
||||
|
||||
impl<C: Cache> Cache for &C {
|
||||
type Key = C::Key;
|
||||
type Value = C::Value;
|
||||
type LookupInfo<Key> = C::LookupInfo<Key>;
|
||||
|
||||
fn invalidate(&self, info: &Self::LookupInfo<Self::Key>) {
|
||||
C::invalidate(self, info)
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper for convenient entry invalidation.
|
||||
pub struct Cached<C: Cache, V = <C as Cache>::Value> {
|
||||
/// Cache + lookup info.
|
||||
pub token: Option<(C, C::LookupInfo<C::Key>)>,
|
||||
|
||||
/// The value itself.
|
||||
pub value: V,
|
||||
}
|
||||
|
||||
impl<C: Cache, V> Cached<C, V> {
|
||||
/// Place any entry into this wrapper; invalidation will be a no-op.
|
||||
pub fn new_uncached(value: V) -> Self {
|
||||
Self { token: None, value }
|
||||
}
|
||||
|
||||
/// Drop this entry from a cache if it's still there.
|
||||
pub fn invalidate(self) -> V {
|
||||
if let Some((cache, info)) = &self.token {
|
||||
cache.invalidate(info);
|
||||
}
|
||||
self.value
|
||||
}
|
||||
|
||||
/// Tell if this entry is actually cached.
|
||||
pub fn cached(&self) -> bool {
|
||||
self.token.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Cache, V> Deref for Cached<C, V> {
|
||||
type Target = V;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.value
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Cache, V> DerefMut for Cached<C, V> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.value
|
||||
}
|
||||
}
|
||||
496
proxy/src/cache/project_info.rs
vendored
Normal file
496
proxy/src/cache/project_info.rs
vendored
Normal file
@@ -0,0 +1,496 @@
|
||||
use std::{
|
||||
collections::HashSet,
|
||||
convert::Infallible,
|
||||
sync::{atomic::AtomicU64, Arc},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use rand::{thread_rng, Rng};
|
||||
use smol_str::SmolStr;
|
||||
use tokio::time::Instant;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::{config::ProjectInfoCacheOptions, console::AuthSecret};
|
||||
|
||||
use super::{Cache, Cached};
|
||||
|
||||
pub trait ProjectInfoCache {
|
||||
fn invalidate_allowed_ips_for_project(&self, project_id: &SmolStr);
|
||||
fn invalidate_role_secret_for_project(&self, project_id: &SmolStr, role_name: &SmolStr);
|
||||
fn enable_ttl(&self);
|
||||
fn disable_ttl(&self);
|
||||
}
|
||||
|
||||
struct Entry<T> {
|
||||
created_at: Instant,
|
||||
value: T,
|
||||
}
|
||||
|
||||
impl<T> Entry<T> {
|
||||
pub fn new(value: T) -> Self {
|
||||
Self {
|
||||
created_at: Instant::now(),
|
||||
value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<T> for Entry<T> {
|
||||
fn from(value: T) -> Self {
|
||||
Self::new(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct EndpointInfo {
|
||||
secret: std::collections::HashMap<SmolStr, Entry<AuthSecret>>,
|
||||
allowed_ips: Option<Entry<Arc<Vec<SmolStr>>>>,
|
||||
}
|
||||
|
||||
impl EndpointInfo {
|
||||
fn check_ignore_cache(ignore_cache_since: Option<Instant>, created_at: Instant) -> bool {
|
||||
match ignore_cache_since {
|
||||
None => false,
|
||||
Some(t) => t < created_at,
|
||||
}
|
||||
}
|
||||
pub fn get_role_secret(
|
||||
&self,
|
||||
role_name: &SmolStr,
|
||||
valid_since: Instant,
|
||||
ignore_cache_since: Option<Instant>,
|
||||
) -> Option<(AuthSecret, bool)> {
|
||||
if let Some(secret) = self.secret.get(role_name) {
|
||||
if valid_since < secret.created_at {
|
||||
return Some((
|
||||
secret.value.clone(),
|
||||
Self::check_ignore_cache(ignore_cache_since, secret.created_at),
|
||||
));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn get_allowed_ips(
|
||||
&self,
|
||||
valid_since: Instant,
|
||||
ignore_cache_since: Option<Instant>,
|
||||
) -> Option<(Arc<Vec<SmolStr>>, bool)> {
|
||||
if let Some(allowed_ips) = &self.allowed_ips {
|
||||
if valid_since < allowed_ips.created_at {
|
||||
return Some((
|
||||
allowed_ips.value.clone(),
|
||||
Self::check_ignore_cache(ignore_cache_since, allowed_ips.created_at),
|
||||
));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
pub fn invalidate_allowed_ips(&mut self) {
|
||||
self.allowed_ips = None;
|
||||
}
|
||||
pub fn invalidate_role_secret(&mut self, role_name: &SmolStr) {
|
||||
self.secret.remove(role_name);
|
||||
}
|
||||
}
|
||||
|
||||
/// Cache for project info.
|
||||
/// This is used to cache auth data for endpoints.
|
||||
/// Invalidation is done by console notifications or by TTL (if console notifications are disabled).
|
||||
///
|
||||
/// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data.
|
||||
/// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available?
|
||||
/// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache.
|
||||
pub struct ProjectInfoCacheImpl {
|
||||
cache: DashMap<SmolStr, EndpointInfo>,
|
||||
|
||||
project2ep: DashMap<SmolStr, HashSet<SmolStr>>,
|
||||
config: ProjectInfoCacheOptions,
|
||||
|
||||
start_time: Instant,
|
||||
ttl_disabled_since_us: AtomicU64,
|
||||
}
|
||||
|
||||
impl ProjectInfoCache for ProjectInfoCacheImpl {
|
||||
fn invalidate_allowed_ips_for_project(&self, project_id: &SmolStr) {
|
||||
info!("invalidating allowed ips for project `{}`", project_id);
|
||||
let endpoints = self
|
||||
.project2ep
|
||||
.get(project_id)
|
||||
.map(|kv| kv.value().clone())
|
||||
.unwrap_or_default();
|
||||
for endpoint_id in endpoints {
|
||||
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
|
||||
endpoint_info.invalidate_allowed_ips();
|
||||
}
|
||||
}
|
||||
}
|
||||
fn invalidate_role_secret_for_project(&self, project_id: &SmolStr, role_name: &SmolStr) {
|
||||
info!(
|
||||
"invalidating role secret for project_id `{}` and role_name `{}`",
|
||||
project_id, role_name
|
||||
);
|
||||
let endpoints = self
|
||||
.project2ep
|
||||
.get(project_id)
|
||||
.map(|kv| kv.value().clone())
|
||||
.unwrap_or_default();
|
||||
for endpoint_id in endpoints {
|
||||
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
|
||||
endpoint_info.invalidate_role_secret(role_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
fn enable_ttl(&self) {
|
||||
self.ttl_disabled_since_us
|
||||
.store(u64::MAX, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn disable_ttl(&self) {
|
||||
let new_ttl = (self.start_time.elapsed() + self.config.ttl).as_micros() as u64;
|
||||
self.ttl_disabled_since_us
|
||||
.store(new_ttl, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
impl ProjectInfoCacheImpl {
|
||||
pub fn new(config: ProjectInfoCacheOptions) -> Self {
|
||||
Self {
|
||||
cache: DashMap::new(),
|
||||
project2ep: DashMap::new(),
|
||||
config,
|
||||
ttl_disabled_since_us: AtomicU64::new(u64::MAX),
|
||||
start_time: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_role_secret(
|
||||
&self,
|
||||
endpoint_id: &SmolStr,
|
||||
role_name: &SmolStr,
|
||||
) -> Option<Cached<&Self, AuthSecret>> {
|
||||
let (valid_since, ignore_cache_since) = self.get_cache_times();
|
||||
let endpoint_info = self.cache.get(endpoint_id)?;
|
||||
let (value, ignore_cache) =
|
||||
endpoint_info.get_role_secret(role_name, valid_since, ignore_cache_since)?;
|
||||
if !ignore_cache {
|
||||
let cached = Cached {
|
||||
token: Some((
|
||||
self,
|
||||
CachedLookupInfo::new_role_secret(endpoint_id.clone(), role_name.clone()),
|
||||
)),
|
||||
value,
|
||||
};
|
||||
return Some(cached);
|
||||
}
|
||||
Some(Cached::new_uncached(value))
|
||||
}
|
||||
pub fn get_allowed_ips(
|
||||
&self,
|
||||
endpoint_id: &SmolStr,
|
||||
) -> Option<Cached<&Self, Arc<Vec<SmolStr>>>> {
|
||||
let (valid_since, ignore_cache_since) = self.get_cache_times();
|
||||
let endpoint_info = self.cache.get(endpoint_id)?;
|
||||
let value = endpoint_info.get_allowed_ips(valid_since, ignore_cache_since);
|
||||
let (value, ignore_cache) = value?;
|
||||
if !ignore_cache {
|
||||
let cached = Cached {
|
||||
token: Some((self, CachedLookupInfo::new_allowed_ips(endpoint_id.clone()))),
|
||||
value,
|
||||
};
|
||||
return Some(cached);
|
||||
}
|
||||
Some(Cached::new_uncached(value))
|
||||
}
|
||||
pub fn insert_role_secret(
|
||||
&self,
|
||||
project_id: &SmolStr,
|
||||
endpoint_id: &SmolStr,
|
||||
role_name: &SmolStr,
|
||||
secret: AuthSecret,
|
||||
) {
|
||||
if self.cache.len() >= self.config.size {
|
||||
// If there are too many entries, wait until the next gc cycle.
|
||||
return;
|
||||
}
|
||||
self.inser_project2endpoint(project_id, endpoint_id);
|
||||
let mut entry = self.cache.entry(endpoint_id.clone()).or_default();
|
||||
if entry.secret.len() < self.config.max_roles {
|
||||
entry.secret.insert(role_name.clone(), secret.into());
|
||||
}
|
||||
}
|
||||
pub fn insert_allowed_ips(
|
||||
&self,
|
||||
project_id: &SmolStr,
|
||||
endpoint_id: &SmolStr,
|
||||
allowed_ips: Arc<Vec<SmolStr>>,
|
||||
) {
|
||||
if self.cache.len() >= self.config.size {
|
||||
// If there are too many entries, wait until the next gc cycle.
|
||||
return;
|
||||
}
|
||||
self.inser_project2endpoint(project_id, endpoint_id);
|
||||
self.cache
|
||||
.entry(endpoint_id.clone())
|
||||
.or_default()
|
||||
.allowed_ips = Some(allowed_ips.into());
|
||||
}
|
||||
fn inser_project2endpoint(&self, project_id: &SmolStr, endpoint_id: &SmolStr) {
|
||||
if let Some(mut endpoints) = self.project2ep.get_mut(project_id) {
|
||||
endpoints.insert(endpoint_id.clone());
|
||||
} else {
|
||||
self.project2ep
|
||||
.insert(project_id.clone(), HashSet::from([endpoint_id.clone()]));
|
||||
}
|
||||
}
|
||||
fn get_cache_times(&self) -> (Instant, Option<Instant>) {
|
||||
let mut valid_since = Instant::now() - self.config.ttl;
|
||||
// Only ignore cache if ttl is disabled.
|
||||
let ttl_disabled_since_us = self
|
||||
.ttl_disabled_since_us
|
||||
.load(std::sync::atomic::Ordering::Relaxed);
|
||||
let ignore_cache_since = if ttl_disabled_since_us != u64::MAX {
|
||||
let ignore_cache_since = self.start_time + Duration::from_micros(ttl_disabled_since_us);
|
||||
// We are fine if entry is not older than ttl or was added before we are getting notifications.
|
||||
valid_since = valid_since.min(ignore_cache_since);
|
||||
Some(ignore_cache_since)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
(valid_since, ignore_cache_since)
|
||||
}
|
||||
|
||||
pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
|
||||
let mut interval =
|
||||
tokio::time::interval(self.config.gc_interval / (self.cache.shards().len()) as u32);
|
||||
loop {
|
||||
interval.tick().await;
|
||||
if self.cache.len() <= self.config.size {
|
||||
// If there are not too many entries, wait until the next gc cycle.
|
||||
continue;
|
||||
}
|
||||
self.gc();
|
||||
}
|
||||
}
|
||||
|
||||
fn gc(&self) {
|
||||
let shard = thread_rng().gen_range(0..self.project2ep.shards().len());
|
||||
debug!(shard, "project_info_cache: performing epoch reclamation");
|
||||
|
||||
// acquire a random shard lock
|
||||
let mut removed = 0;
|
||||
let shard = self.project2ep.shards()[shard].write();
|
||||
for (_, endpoints) in shard.iter() {
|
||||
for endpoint in endpoints.get().iter() {
|
||||
self.cache.remove(endpoint);
|
||||
removed += 1;
|
||||
}
|
||||
}
|
||||
// We can drop this shard only after making sure that all endpoints are removed.
|
||||
drop(shard);
|
||||
info!("project_info_cache: removed {removed} endpoints");
|
||||
}
|
||||
}
|
||||
|
||||
/// Lookup info for project info cache.
|
||||
/// This is used to invalidate cache entries.
|
||||
pub struct CachedLookupInfo {
|
||||
/// Search by this key.
|
||||
endpoint_id: SmolStr,
|
||||
lookup_type: LookupType,
|
||||
}
|
||||
|
||||
impl CachedLookupInfo {
|
||||
pub(self) fn new_role_secret(endpoint_id: SmolStr, role_name: SmolStr) -> Self {
|
||||
Self {
|
||||
endpoint_id,
|
||||
lookup_type: LookupType::RoleSecret(role_name),
|
||||
}
|
||||
}
|
||||
pub(self) fn new_allowed_ips(endpoint_id: SmolStr) -> Self {
|
||||
Self {
|
||||
endpoint_id,
|
||||
lookup_type: LookupType::AllowedIps,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum LookupType {
|
||||
RoleSecret(SmolStr),
|
||||
AllowedIps,
|
||||
}
|
||||
|
||||
impl Cache for ProjectInfoCacheImpl {
|
||||
type Key = SmolStr;
|
||||
// Value is not really used here, but we need to specify it.
|
||||
type Value = SmolStr;
|
||||
|
||||
type LookupInfo<Key> = CachedLookupInfo;
|
||||
|
||||
fn invalidate(&self, key: &Self::LookupInfo<SmolStr>) {
|
||||
match &key.lookup_type {
|
||||
LookupType::RoleSecret(role_name) => {
|
||||
if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
|
||||
endpoint_info.invalidate_role_secret(role_name);
|
||||
}
|
||||
}
|
||||
LookupType::AllowedIps => {
|
||||
if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
|
||||
endpoint_info.invalidate_allowed_ips();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{console::AuthSecret, scram::ServerSecret};
|
||||
use smol_str::SmolStr;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_project_info_cache_settings() {
|
||||
tokio::time::pause();
|
||||
let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
|
||||
size: 2,
|
||||
max_roles: 2,
|
||||
ttl: Duration::from_secs(1),
|
||||
gc_interval: Duration::from_secs(600),
|
||||
});
|
||||
let project_id = "project".into();
|
||||
let endpoint_id = "endpoint".into();
|
||||
let user1: SmolStr = "user1".into();
|
||||
let user2: SmolStr = "user2".into();
|
||||
let secret1 = AuthSecret::Scram(ServerSecret::mock(user1.as_str(), [1; 32]));
|
||||
let secret2 = AuthSecret::Scram(ServerSecret::mock(user2.as_str(), [2; 32]));
|
||||
let allowed_ips = Arc::new(vec!["allowed_ip1".into(), "allowed_ip2".into()]);
|
||||
cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone());
|
||||
cache.insert_role_secret(&project_id, &endpoint_id, &user2, secret2.clone());
|
||||
cache.insert_allowed_ips(&project_id, &endpoint_id, allowed_ips.clone());
|
||||
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
|
||||
assert!(cached.cached());
|
||||
assert_eq!(cached.value, secret1);
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
|
||||
assert!(cached.cached());
|
||||
assert_eq!(cached.value, secret2);
|
||||
|
||||
// Shouldn't add more than 2 roles.
|
||||
let user3: SmolStr = "user3".into();
|
||||
let secret3 = AuthSecret::Scram(ServerSecret::mock(user3.as_str(), [3; 32]));
|
||||
cache.insert_role_secret(&project_id, &endpoint_id, &user3, secret3.clone());
|
||||
assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());
|
||||
|
||||
let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
|
||||
assert!(cached.cached());
|
||||
assert_eq!(cached.value, allowed_ips);
|
||||
|
||||
tokio::time::advance(Duration::from_secs(2)).await;
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user1);
|
||||
assert!(cached.is_none());
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user2);
|
||||
assert!(cached.is_none());
|
||||
let cached = cache.get_allowed_ips(&endpoint_id);
|
||||
assert!(cached.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_project_info_cache_invalidations() {
|
||||
tokio::time::pause();
|
||||
let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
|
||||
size: 2,
|
||||
max_roles: 2,
|
||||
ttl: Duration::from_secs(1),
|
||||
gc_interval: Duration::from_secs(600),
|
||||
}));
|
||||
cache.clone().disable_ttl();
|
||||
tokio::time::advance(Duration::from_secs(2)).await;
|
||||
|
||||
let project_id = "project".into();
|
||||
let endpoint_id = "endpoint".into();
|
||||
let user1: SmolStr = "user1".into();
|
||||
let user2: SmolStr = "user2".into();
|
||||
let secret1 = AuthSecret::Scram(ServerSecret::mock(user1.as_str(), [1; 32]));
|
||||
let secret2 = AuthSecret::Scram(ServerSecret::mock(user2.as_str(), [2; 32]));
|
||||
let allowed_ips = Arc::new(vec!["allowed_ip1".into(), "allowed_ip2".into()]);
|
||||
cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone());
|
||||
cache.insert_role_secret(&project_id, &endpoint_id, &user2, secret2.clone());
|
||||
cache.insert_allowed_ips(&project_id, &endpoint_id, allowed_ips.clone());
|
||||
|
||||
tokio::time::advance(Duration::from_secs(2)).await;
|
||||
// Nothing should be invalidated.
|
||||
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
|
||||
// TTL is disabled, so it should be impossible to invalidate this value.
|
||||
assert!(!cached.cached());
|
||||
assert_eq!(cached.value, secret1);
|
||||
|
||||
cached.invalidate(); // Shouldn't do anything.
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
|
||||
assert_eq!(cached.value, secret1);
|
||||
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
|
||||
assert!(!cached.cached());
|
||||
assert_eq!(cached.value, secret2);
|
||||
|
||||
// The only way to invalidate this value is to invalidate via the api.
|
||||
cache.invalidate_role_secret_for_project(&project_id, &user2);
|
||||
assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
|
||||
|
||||
let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
|
||||
assert!(!cached.cached());
|
||||
assert_eq!(cached.value, allowed_ips);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_disable_ttl_invalidate_added_before() {
|
||||
tokio::time::pause();
|
||||
let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
|
||||
size: 2,
|
||||
max_roles: 2,
|
||||
ttl: Duration::from_secs(1),
|
||||
gc_interval: Duration::from_secs(600),
|
||||
}));
|
||||
|
||||
let project_id = "project".into();
|
||||
let endpoint_id = "endpoint".into();
|
||||
let user1: SmolStr = "user1".into();
|
||||
let user2: SmolStr = "user2".into();
|
||||
let secret1 = AuthSecret::Scram(ServerSecret::mock(user1.as_str(), [1; 32]));
|
||||
let secret2 = AuthSecret::Scram(ServerSecret::mock(user2.as_str(), [2; 32]));
|
||||
let allowed_ips = Arc::new(vec!["allowed_ip1".into(), "allowed_ip2".into()]);
|
||||
cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone());
|
||||
cache.clone().disable_ttl();
|
||||
tokio::time::advance(Duration::from_millis(100)).await;
|
||||
cache.insert_role_secret(&project_id, &endpoint_id, &user2, secret2.clone());
|
||||
|
||||
// Added before ttl was disabled + ttl should be still cached.
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
|
||||
assert!(cached.cached());
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
|
||||
assert!(cached.cached());
|
||||
|
||||
tokio::time::advance(Duration::from_secs(1)).await;
|
||||
// Added before ttl was disabled + ttl should expire.
|
||||
assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
|
||||
assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
|
||||
|
||||
// Added after ttl was disabled + ttl should not be cached.
|
||||
cache.insert_allowed_ips(&project_id, &endpoint_id, allowed_ips.clone());
|
||||
let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
|
||||
assert!(!cached.cached());
|
||||
|
||||
tokio::time::advance(Duration::from_secs(1)).await;
|
||||
// Added before ttl was disabled + ttl still should expire.
|
||||
assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
|
||||
assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
|
||||
// Shouldn't be invalidated.
|
||||
|
||||
let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
|
||||
assert!(!cached.cached());
|
||||
assert_eq!(cached.value, allowed_ips);
|
||||
}
|
||||
}
|
||||
258
proxy/src/cache/timed_lru.rs
vendored
Normal file
258
proxy/src/cache/timed_lru.rs
vendored
Normal file
@@ -0,0 +1,258 @@
|
||||
use std::{
|
||||
borrow::Borrow,
|
||||
hash::Hash,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tracing::debug;
|
||||
|
||||
// This seems to make more sense than `lru` or `cached`:
|
||||
//
|
||||
// * `near/nearcore` ditched `cached` in favor of `lru`
|
||||
// (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed).
|
||||
//
|
||||
// * `lru` methods use an obscure `KeyRef` type in their contraints (which is deliberately excluded from docs).
|
||||
// This severely hinders its usage both in terms of creating wrappers and supported key types.
|
||||
//
|
||||
// On the other hand, `hashlink` has good download stats and appears to be maintained.
|
||||
use hashlink::{linked_hash_map::RawEntryMut, LruCache};
|
||||
|
||||
use super::{common::Cached, *};
|
||||
|
||||
/// An implementation of timed LRU cache with fixed capacity.
|
||||
/// Key properties:
|
||||
///
|
||||
/// * Whenever a new entry is inserted, the least recently accessed one is evicted.
|
||||
/// The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`).
|
||||
///
|
||||
/// * If `update_ttl_on_retrieval` is `true`. When the entry is about to be retrieved, we check its expiration timestamp.
|
||||
/// If the entry has expired, we remove it from the cache; Otherwise we bump the
|
||||
/// expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong
|
||||
/// its existence.
|
||||
///
|
||||
/// * There's an API for immediate invalidation (removal) of a cache entry;
|
||||
/// It's useful in case we know for sure that the entry is no longer correct.
|
||||
/// See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information.
|
||||
///
|
||||
/// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
|
||||
/// or by a successful lookup (i.e. the entry hasn't expired yet).
|
||||
/// There is no background job to reap the expired records.
|
||||
///
|
||||
/// * It's possible for an entry that has not yet expired entry to be evicted
|
||||
/// before expired items. That's a bit wasteful, but probably fine in practice.
|
||||
pub struct TimedLru<K, V> {
|
||||
/// Cache's name for tracing.
|
||||
name: &'static str,
|
||||
|
||||
/// The underlying cache implementation.
|
||||
cache: parking_lot::Mutex<LruCache<K, Entry<V>>>,
|
||||
|
||||
/// Default time-to-live of a single entry.
|
||||
ttl: Duration,
|
||||
|
||||
update_ttl_on_retrieval: bool,
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq, V> Cache for TimedLru<K, V> {
|
||||
type Key = K;
|
||||
type Value = V;
|
||||
type LookupInfo<Key> = LookupInfo<Key>;
|
||||
|
||||
fn invalidate(&self, info: &Self::LookupInfo<K>) {
|
||||
self.invalidate_raw(info)
|
||||
}
|
||||
}
|
||||
|
||||
struct Entry<T> {
|
||||
created_at: Instant,
|
||||
expires_at: Instant,
|
||||
value: T,
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq, V> TimedLru<K, V> {
|
||||
/// Construct a new LRU cache with timed entries.
|
||||
pub fn new(
|
||||
name: &'static str,
|
||||
capacity: usize,
|
||||
ttl: Duration,
|
||||
update_ttl_on_retrieval: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
name,
|
||||
cache: LruCache::new(capacity).into(),
|
||||
ttl,
|
||||
update_ttl_on_retrieval,
|
||||
}
|
||||
}
|
||||
|
||||
/// Drop an entry from the cache if it's outdated.
|
||||
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
|
||||
fn invalidate_raw(&self, info: &LookupInfo<K>) {
|
||||
let now = Instant::now();
|
||||
|
||||
// Do costly things before taking the lock.
|
||||
let mut cache = self.cache.lock();
|
||||
let raw_entry = match cache.raw_entry_mut().from_key(&info.key) {
|
||||
RawEntryMut::Vacant(_) => return,
|
||||
RawEntryMut::Occupied(x) => x,
|
||||
};
|
||||
|
||||
// Remove the entry if it was created prior to lookup timestamp.
|
||||
let entry = raw_entry.get();
|
||||
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
|
||||
let should_remove = created_at <= info.created_at || expires_at <= now;
|
||||
|
||||
if should_remove {
|
||||
raw_entry.remove();
|
||||
}
|
||||
|
||||
drop(cache); // drop lock before logging
|
||||
debug!(
|
||||
created_at = format_args!("{created_at:?}"),
|
||||
expires_at = format_args!("{expires_at:?}"),
|
||||
entry_removed = should_remove,
|
||||
"processed a cache entry invalidation event"
|
||||
);
|
||||
}
|
||||
|
||||
/// Try retrieving an entry by its key, then execute `extract` if it exists.
|
||||
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
|
||||
fn get_raw<Q, R>(&self, key: &Q, extract: impl FnOnce(&K, &Entry<V>) -> R) -> Option<R>
|
||||
where
|
||||
K: Borrow<Q>,
|
||||
Q: Hash + Eq + ?Sized,
|
||||
{
|
||||
let now = Instant::now();
|
||||
let deadline = now.checked_add(self.ttl).expect("time overflow");
|
||||
|
||||
// Do costly things before taking the lock.
|
||||
let mut cache = self.cache.lock();
|
||||
let mut raw_entry = match cache.raw_entry_mut().from_key(key) {
|
||||
RawEntryMut::Vacant(_) => return None,
|
||||
RawEntryMut::Occupied(x) => x,
|
||||
};
|
||||
|
||||
// Immeditely drop the entry if it has expired.
|
||||
let entry = raw_entry.get();
|
||||
if entry.expires_at <= now {
|
||||
raw_entry.remove();
|
||||
return None;
|
||||
}
|
||||
|
||||
let value = extract(raw_entry.key(), entry);
|
||||
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
|
||||
|
||||
// Update the deadline and the entry's position in the LRU list.
|
||||
if self.update_ttl_on_retrieval {
|
||||
raw_entry.get_mut().expires_at = deadline;
|
||||
}
|
||||
raw_entry.to_back();
|
||||
|
||||
drop(cache); // drop lock before logging
|
||||
debug!(
|
||||
created_at = format_args!("{created_at:?}"),
|
||||
old_expires_at = format_args!("{expires_at:?}"),
|
||||
new_expires_at = format_args!("{deadline:?}"),
|
||||
"accessed a cache entry"
|
||||
);
|
||||
|
||||
Some(value)
|
||||
}
|
||||
|
||||
/// Insert an entry to the cache. If an entry with the same key already
|
||||
/// existed, return the previous value and its creation timestamp.
|
||||
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
|
||||
fn insert_raw(&self, key: K, value: V) -> (Instant, Option<V>) {
|
||||
let created_at = Instant::now();
|
||||
let expires_at = created_at.checked_add(self.ttl).expect("time overflow");
|
||||
|
||||
let entry = Entry {
|
||||
created_at,
|
||||
expires_at,
|
||||
value,
|
||||
};
|
||||
|
||||
// Do costly things before taking the lock.
|
||||
let old = self
|
||||
.cache
|
||||
.lock()
|
||||
.insert(key, entry)
|
||||
.map(|entry| entry.value);
|
||||
|
||||
debug!(
|
||||
created_at = format_args!("{created_at:?}"),
|
||||
expires_at = format_args!("{expires_at:?}"),
|
||||
replaced = old.is_some(),
|
||||
"created a cache entry"
|
||||
);
|
||||
|
||||
(created_at, old)
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
|
||||
pub fn insert(&self, key: K, value: V) -> (Option<V>, Cached<&Self>) {
|
||||
let (created_at, old) = self.insert_raw(key.clone(), value.clone());
|
||||
|
||||
let cached = Cached {
|
||||
token: Some((self, LookupInfo { created_at, key })),
|
||||
value,
|
||||
};
|
||||
|
||||
(old, cached)
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
|
||||
/// Retrieve a cached entry in convenient wrapper.
|
||||
pub fn get<Q>(&self, key: &Q) -> Option<timed_lru::Cached<&Self>>
|
||||
where
|
||||
K: Borrow<Q> + Clone,
|
||||
Q: Hash + Eq + ?Sized,
|
||||
{
|
||||
self.get_raw(key, |key, entry| {
|
||||
let info = LookupInfo {
|
||||
created_at: entry.created_at,
|
||||
key: key.clone(),
|
||||
};
|
||||
|
||||
Cached {
|
||||
token: Some((self, info)),
|
||||
value: entry.value.clone(),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Retrieve a cached entry in convenient wrapper, ignoring its TTL.
|
||||
pub fn get_ignoring_ttl<Q>(&self, key: &Q) -> Option<timed_lru::Cached<&Self>>
|
||||
where
|
||||
K: Borrow<Q>,
|
||||
Q: Hash + Eq + ?Sized,
|
||||
{
|
||||
let mut cache = self.cache.lock();
|
||||
cache
|
||||
.get(key)
|
||||
.map(|entry| Cached::new_uncached(entry.value.clone()))
|
||||
}
|
||||
|
||||
/// Remove an entry from the cache.
|
||||
pub fn remove<Q>(&self, key: &Q) -> Option<V>
|
||||
where
|
||||
K: Borrow<Q> + Clone,
|
||||
Q: Hash + Eq + ?Sized,
|
||||
{
|
||||
let mut cache = self.cache.lock();
|
||||
cache.remove(key).map(|entry| entry.value)
|
||||
}
|
||||
}
|
||||
|
||||
/// Lookup information for key invalidation.
|
||||
pub struct LookupInfo<K> {
|
||||
/// Time of creation of a cache [`Entry`].
|
||||
/// We use this during invalidation lookups to prevent eviction of a newer
|
||||
/// entry sharing the same key (it might've been inserted by a different
|
||||
/// task after we got the entry we're trying to invalidate now).
|
||||
created_at: Instant,
|
||||
|
||||
/// Search by this key.
|
||||
key: K,
|
||||
}
|
||||
@@ -352,6 +352,69 @@ impl FromStr for CacheOptions {
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper for cmdline cache options parsing.
|
||||
#[derive(Debug)]
|
||||
pub struct ProjectInfoCacheOptions {
|
||||
/// Max number of entries.
|
||||
pub size: usize,
|
||||
/// Entry's time-to-live.
|
||||
pub ttl: Duration,
|
||||
/// Max number of roles per endpoint.
|
||||
pub max_roles: usize,
|
||||
/// Gc interval.
|
||||
pub gc_interval: Duration,
|
||||
}
|
||||
|
||||
impl ProjectInfoCacheOptions {
|
||||
/// Default options for [`crate::console::provider::NodeInfoCache`].
|
||||
pub const CACHE_DEFAULT_OPTIONS: &'static str =
|
||||
"size=10000,ttl=4m,max_roles=10,gc_interval=60m";
|
||||
|
||||
/// Parse cache options passed via cmdline.
|
||||
/// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
|
||||
fn parse(options: &str) -> anyhow::Result<Self> {
|
||||
let mut size = None;
|
||||
let mut ttl = None;
|
||||
let mut max_roles = None;
|
||||
let mut gc_interval = None;
|
||||
|
||||
for option in options.split(',') {
|
||||
let (key, value) = option
|
||||
.split_once('=')
|
||||
.with_context(|| format!("bad key-value pair: {option}"))?;
|
||||
|
||||
match key {
|
||||
"size" => size = Some(value.parse()?),
|
||||
"ttl" => ttl = Some(humantime::parse_duration(value)?),
|
||||
"max_roles" => max_roles = Some(value.parse()?),
|
||||
"gc_interval" => gc_interval = Some(humantime::parse_duration(value)?),
|
||||
unknown => bail!("unknown key: {unknown}"),
|
||||
}
|
||||
}
|
||||
|
||||
// TTL doesn't matter if cache is always empty.
|
||||
if let Some(0) = size {
|
||||
ttl.get_or_insert(Duration::default());
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
size: size.context("missing `size`")?,
|
||||
ttl: ttl.context("missing `ttl`")?,
|
||||
max_roles: max_roles.context("missing `max_roles`")?,
|
||||
gc_interval: gc_interval.context("missing `gc_interval`")?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for ProjectInfoCacheOptions {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(options: &str) -> Result<Self, Self::Err> {
|
||||
let error = || format!("failed to parse cache options '{options}'");
|
||||
Self::parse(options).with_context(error)
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper for cmdline cache options parsing.
|
||||
pub struct WakeComputeLockOptions {
|
||||
/// The number of shards the lock map should have
|
||||
|
||||
@@ -15,6 +15,7 @@ pub struct ConsoleError {
|
||||
pub struct GetRoleSecret {
|
||||
pub role_secret: Box<str>,
|
||||
pub allowed_ips: Option<Vec<Box<str>>>,
|
||||
pub project_id: Option<Box<str>>,
|
||||
}
|
||||
|
||||
// Manually implement debug to omit sensitive info.
|
||||
@@ -207,12 +208,17 @@ mod tests {
|
||||
"role_secret": "secret",
|
||||
});
|
||||
let _: GetRoleSecret = serde_json::from_str(&json.to_string())?;
|
||||
// Empty `allowed_ips` field.
|
||||
let json = json!({
|
||||
"role_secret": "secret",
|
||||
"allowed_ips": ["8.8.8.8"],
|
||||
});
|
||||
let _: GetRoleSecret = serde_json::from_str(&json.to_string())?;
|
||||
let json = json!({
|
||||
"role_secret": "secret",
|
||||
"allowed_ips": ["8.8.8.8"],
|
||||
"project_id": "project",
|
||||
});
|
||||
let _: GetRoleSecret = serde_json::from_str(&json.to_string())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -5,8 +5,9 @@ pub mod neon;
|
||||
use super::messages::MetricsAuxInfo;
|
||||
use crate::{
|
||||
auth::backend::ComputeUserInfo,
|
||||
cache::{timed_lru, TimedLru},
|
||||
cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
|
||||
compute,
|
||||
config::{CacheOptions, ProjectInfoCacheOptions},
|
||||
context::RequestMonitoring,
|
||||
scram,
|
||||
};
|
||||
@@ -14,10 +15,8 @@ use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use smol_str::SmolStr;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tokio::{
|
||||
sync::{OwnedSemaphorePermit, Semaphore},
|
||||
time::Instant,
|
||||
};
|
||||
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
|
||||
use tokio::time::Instant;
|
||||
use tracing::info;
|
||||
|
||||
pub mod errors {
|
||||
@@ -215,7 +214,7 @@ impl ConsoleReqExtra {
|
||||
}
|
||||
|
||||
/// Auth secret which is managed by the cloud.
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Eq, PartialEq, Debug)]
|
||||
pub enum AuthSecret {
|
||||
#[cfg(feature = "testing")]
|
||||
/// Md5 hash of user's password.
|
||||
@@ -229,7 +228,9 @@ pub enum AuthSecret {
|
||||
pub struct AuthInfo {
|
||||
pub secret: Option<AuthSecret>,
|
||||
/// List of IP addresses allowed for the autorization.
|
||||
pub allowed_ips: Vec<String>,
|
||||
pub allowed_ips: Vec<SmolStr>,
|
||||
/// Project ID. This is used for cache invalidation.
|
||||
pub project_id: Option<SmolStr>,
|
||||
}
|
||||
|
||||
/// Info for establishing a connection to a compute node.
|
||||
@@ -249,27 +250,28 @@ pub struct NodeInfo {
|
||||
}
|
||||
|
||||
pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>;
|
||||
pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>;
|
||||
pub type AllowedIpsCache = TimedLru<SmolStr, Arc<Vec<String>>>;
|
||||
pub type RoleSecretCache = TimedLru<(SmolStr, SmolStr), Option<AuthSecret>>;
|
||||
pub type CachedRoleSecret = timed_lru::Cached<&'static RoleSecretCache>;
|
||||
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
|
||||
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, AuthSecret>;
|
||||
pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<SmolStr>>>;
|
||||
|
||||
/// This will allocate per each call, but the http requests alone
|
||||
/// already require a few allocations, so it should be fine.
|
||||
#[async_trait]
|
||||
pub trait Api {
|
||||
/// Get the client's auth secret for authentication.
|
||||
/// Returns option because user not found situation is special.
|
||||
/// We still have to mock the scram to avoid leaking information that user doesn't exist.
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
|
||||
) -> Result<Option<CachedRoleSecret>, errors::GetAuthInfoError>;
|
||||
|
||||
async fn get_allowed_ips(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<Arc<Vec<String>>, errors::GetAuthInfoError>;
|
||||
) -> Result<CachedAllowedIps, errors::GetAuthInfoError>;
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
async fn wake_compute(
|
||||
@@ -284,10 +286,25 @@ pub trait Api {
|
||||
pub struct ApiCaches {
|
||||
/// Cache for the `wake_compute` API method.
|
||||
pub node_info: NodeInfoCache,
|
||||
/// Cache for the `get_allowed_ips`. TODO(anna): use notifications listener instead.
|
||||
pub allowed_ips: AllowedIpsCache,
|
||||
/// Cache for the `get_role_secret`. TODO(anna): use notifications listener instead.
|
||||
pub role_secret: RoleSecretCache,
|
||||
/// Cache which stores project_id -> endpoint_ids mapping.
|
||||
pub project_info: Arc<ProjectInfoCacheImpl>,
|
||||
}
|
||||
|
||||
impl ApiCaches {
|
||||
pub fn new(
|
||||
wake_compute_cache_config: CacheOptions,
|
||||
project_info_cache_config: ProjectInfoCacheOptions,
|
||||
) -> Self {
|
||||
Self {
|
||||
node_info: NodeInfoCache::new(
|
||||
"node_info_cache",
|
||||
wake_compute_cache_config.size,
|
||||
wake_compute_cache_config.ttl,
|
||||
true,
|
||||
),
|
||||
project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Various caches for [`console`](super).
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
//! Mock console backend which relies on a user-provided postgres instance.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::{
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
|
||||
};
|
||||
use crate::cache::Cached;
|
||||
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
|
||||
use crate::{console::provider::CachedRoleSecret, context::RequestMonitoring};
|
||||
use async_trait::async_trait;
|
||||
use futures::TryFutureExt;
|
||||
use smol_str::SmolStr;
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
use tokio_postgres::{config::SslMode, Client};
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
@@ -98,7 +100,8 @@ impl Api {
|
||||
.await?;
|
||||
Ok(AuthInfo {
|
||||
secret,
|
||||
allowed_ips,
|
||||
allowed_ips: allowed_ips.iter().map(SmolStr::from).collect(),
|
||||
project_id: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -147,18 +150,22 @@ impl super::Api for Api {
|
||||
&self,
|
||||
_ctx: &mut RequestMonitoring,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, GetAuthInfoError> {
|
||||
Ok(CachedRoleSecret::new_uncached(
|
||||
self.do_get_auth_info(creds).await?.secret,
|
||||
))
|
||||
) -> Result<Option<CachedRoleSecret>, GetAuthInfoError> {
|
||||
Ok(self
|
||||
.do_get_auth_info(creds)
|
||||
.await?
|
||||
.secret
|
||||
.map(CachedRoleSecret::new_uncached))
|
||||
}
|
||||
|
||||
async fn get_allowed_ips(
|
||||
&self,
|
||||
_ctx: &mut RequestMonitoring,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
|
||||
Ok(Arc::new(self.do_get_auth_info(creds).await?.allowed_ips))
|
||||
) -> Result<CachedAllowedIps, GetAuthInfoError> {
|
||||
Ok(Cached::new_uncached(Arc::new(
|
||||
self.do_get_auth_info(creds).await?.allowed_ips,
|
||||
)))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
|
||||
@@ -3,17 +3,19 @@
|
||||
use super::{
|
||||
super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedNodeInfo, CachedRoleSecret, ConsoleReqExtra,
|
||||
NodeInfo,
|
||||
ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret,
|
||||
ConsoleReqExtra, NodeInfo,
|
||||
};
|
||||
use crate::{auth::backend::ComputeUserInfo, compute, http, scram};
|
||||
use crate::{
|
||||
cache::Cached,
|
||||
context::RequestMonitoring,
|
||||
metrics::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER},
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use futures::TryFutureExt;
|
||||
use itertools::Itertools;
|
||||
use smol_str::SmolStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::config::SslMode;
|
||||
@@ -22,7 +24,7 @@ use tracing::{error, info, info_span, warn, Instrument};
|
||||
#[derive(Clone)]
|
||||
pub struct Api {
|
||||
endpoint: http::Endpoint,
|
||||
caches: &'static ApiCaches,
|
||||
pub caches: &'static ApiCaches,
|
||||
locks: &'static ApiLocks,
|
||||
jwt: String,
|
||||
}
|
||||
@@ -91,12 +93,13 @@ impl Api {
|
||||
.allowed_ips
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.map(String::from)
|
||||
.map(SmolStr::from)
|
||||
.collect_vec();
|
||||
ALLOWED_IPS_NUMBER.observe(allowed_ips.len() as f64);
|
||||
Ok(AuthInfo {
|
||||
secret: Some(secret),
|
||||
allowed_ips,
|
||||
project_id: body.project_id.map(SmolStr::from),
|
||||
})
|
||||
}
|
||||
.map_err(crate::error::log_error)
|
||||
@@ -170,46 +173,56 @@ impl super::Api for Api {
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, GetAuthInfoError> {
|
||||
let ep = creds.endpoint.clone();
|
||||
let user = creds.inner.user.clone();
|
||||
if let Some(role_secret) = self.caches.role_secret.get(&(ep.clone(), user.clone())) {
|
||||
return Ok(role_secret);
|
||||
) -> Result<Option<CachedRoleSecret>, GetAuthInfoError> {
|
||||
let ep = &creds.endpoint;
|
||||
let user = &creds.inner.user;
|
||||
if let Some(role_secret) = self.caches.project_info.get_role_secret(ep, user) {
|
||||
return Ok(Some(role_secret));
|
||||
}
|
||||
let auth_info = self.do_get_auth_info(ctx, creds).await?;
|
||||
let (_, secret) = self
|
||||
.caches
|
||||
.role_secret
|
||||
.insert((ep.clone(), user), auth_info.secret.clone());
|
||||
self.caches
|
||||
.allowed_ips
|
||||
.insert(ep, Arc::new(auth_info.allowed_ips));
|
||||
Ok(secret)
|
||||
let project_id = auth_info.project_id.unwrap_or(ep.clone());
|
||||
if let Some(secret) = &auth_info.secret {
|
||||
self.caches
|
||||
.project_info
|
||||
.insert_role_secret(&project_id, ep, user, secret.clone())
|
||||
}
|
||||
self.caches.project_info.insert_allowed_ips(
|
||||
&project_id,
|
||||
ep,
|
||||
Arc::new(auth_info.allowed_ips),
|
||||
);
|
||||
// When we just got a secret, we don't need to invalidate it.
|
||||
Ok(auth_info.secret.map(Cached::new_uncached))
|
||||
}
|
||||
|
||||
async fn get_allowed_ips(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
creds: &ComputeUserInfo,
|
||||
) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
|
||||
if let Some(allowed_ips) = self.caches.allowed_ips.get(&creds.endpoint) {
|
||||
) -> Result<CachedAllowedIps, GetAuthInfoError> {
|
||||
let ep = &creds.endpoint;
|
||||
if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(ep) {
|
||||
ALLOWED_IPS_BY_CACHE_OUTCOME
|
||||
.with_label_values(&["hit"])
|
||||
.inc();
|
||||
return Ok(Arc::new(allowed_ips.to_vec()));
|
||||
return Ok(allowed_ips);
|
||||
}
|
||||
ALLOWED_IPS_BY_CACHE_OUTCOME
|
||||
.with_label_values(&["miss"])
|
||||
.inc();
|
||||
let auth_info = self.do_get_auth_info(ctx, creds).await?;
|
||||
let allowed_ips = Arc::new(auth_info.allowed_ips);
|
||||
let ep = creds.endpoint.clone();
|
||||
let user = creds.inner.user.clone();
|
||||
let user = &creds.inner.user;
|
||||
let project_id = auth_info.project_id.unwrap_or(ep.clone());
|
||||
if let Some(secret) = &auth_info.secret {
|
||||
self.caches
|
||||
.project_info
|
||||
.insert_role_secret(&project_id, ep, user, secret.clone())
|
||||
}
|
||||
self.caches
|
||||
.role_secret
|
||||
.insert((ep.clone(), user), auth_info.secret);
|
||||
self.caches.allowed_ips.insert(ep, allowed_ips.clone());
|
||||
Ok(allowed_ips)
|
||||
.project_info
|
||||
.insert_allowed_ips(&project_id, ep, allowed_ips.clone());
|
||||
Ok(Cached::new_uncached(allowed_ips))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
|
||||
@@ -22,6 +22,7 @@ pub mod parse;
|
||||
pub mod protocol2;
|
||||
pub mod proxy;
|
||||
pub mod rate_limiter;
|
||||
pub mod redis;
|
||||
pub mod sasl;
|
||||
pub mod scram;
|
||||
pub mod serverless;
|
||||
|
||||
@@ -12,6 +12,7 @@ use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT};
|
||||
use crate::{auth, http, sasl, scram};
|
||||
use async_trait::async_trait;
|
||||
use rstest::rstest;
|
||||
use smol_str::SmolStr;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
|
||||
use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
|
||||
@@ -470,7 +471,7 @@ impl TestBackend for TestConnectMechanism {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_allowed_ips(&self) -> Result<Arc<Vec<String>>, console::errors::GetAuthInfoError> {
|
||||
fn get_allowed_ips(&self) -> Result<Vec<SmolStr>, console::errors::GetAuthInfoError> {
|
||||
unimplemented!("not used in tests")
|
||||
}
|
||||
}
|
||||
|
||||
1
proxy/src/redis.rs
Normal file
1
proxy/src/redis.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod notifications;
|
||||
202
proxy/src/redis/notifications.rs
Normal file
202
proxy/src/redis/notifications.rs
Normal file
@@ -0,0 +1,202 @@
|
||||
use std::{convert::Infallible, sync::Arc};
|
||||
|
||||
use futures::StreamExt;
|
||||
use redis::aio::PubSub;
|
||||
use serde::Deserialize;
|
||||
use smol_str::SmolStr;
|
||||
|
||||
use crate::cache::project_info::ProjectInfoCache;
|
||||
|
||||
const CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
|
||||
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
|
||||
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
|
||||
|
||||
struct ConsoleRedisClient {
|
||||
client: redis::Client,
|
||||
}
|
||||
|
||||
impl ConsoleRedisClient {
|
||||
pub fn new(url: &str) -> anyhow::Result<Self> {
|
||||
let client = redis::Client::open(url)?;
|
||||
Ok(Self { client })
|
||||
}
|
||||
async fn try_connect(&self) -> anyhow::Result<PubSub> {
|
||||
let mut conn = self.client.get_async_connection().await?.into_pubsub();
|
||||
tracing::info!("subscribing to a channel `{CHANNEL_NAME}`");
|
||||
conn.subscribe(CHANNEL_NAME).await?;
|
||||
Ok(conn)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
#[serde(tag = "topic", content = "data")]
|
||||
enum Notification {
|
||||
#[serde(
|
||||
rename = "/allowed_ips_updated",
|
||||
deserialize_with = "deserialize_json_string"
|
||||
)]
|
||||
AllowedIpsUpdate {
|
||||
allowed_ips_update: AllowedIpsUpdate,
|
||||
},
|
||||
#[serde(
|
||||
rename = "/password_updated",
|
||||
deserialize_with = "deserialize_json_string"
|
||||
)]
|
||||
PasswordUpdate { password_update: PasswordUpdate },
|
||||
}
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
struct AllowedIpsUpdate {
|
||||
#[serde(rename = "project")]
|
||||
project_id: SmolStr,
|
||||
}
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
struct PasswordUpdate {
|
||||
#[serde(rename = "project")]
|
||||
project_id: SmolStr,
|
||||
#[serde(rename = "role")]
|
||||
role_name: SmolStr,
|
||||
}
|
||||
fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
|
||||
where
|
||||
T: for<'de2> serde::Deserialize<'de2>,
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
|
||||
}
|
||||
|
||||
fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
|
||||
use Notification::*;
|
||||
match msg {
|
||||
AllowedIpsUpdate { allowed_ips_update } => {
|
||||
cache.invalidate_allowed_ips_for_project(&allowed_ips_update.project_id)
|
||||
}
|
||||
PasswordUpdate { password_update } => cache.invalidate_role_secret_for_project(
|
||||
&password_update.project_id,
|
||||
&password_update.role_name,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(cache))]
|
||||
fn handle_message<C>(msg: redis::Msg, cache: Arc<C>) -> anyhow::Result<()>
|
||||
where
|
||||
C: ProjectInfoCache + Send + Sync + 'static,
|
||||
{
|
||||
let payload: String = msg.get_payload()?;
|
||||
tracing::debug!(?payload, "received a message payload");
|
||||
|
||||
let msg: Notification = match serde_json::from_str(&payload) {
|
||||
Ok(msg) => msg,
|
||||
Err(e) => {
|
||||
tracing::error!("broken message: {e}");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
tracing::debug!(?msg, "received a message");
|
||||
invalidate_cache(cache.clone(), msg.clone());
|
||||
// It might happen that the invalid entry is on the way to be cached.
|
||||
// To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
|
||||
// TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
|
||||
tokio::spawn(async move {
|
||||
tokio::time::sleep(INVALIDATION_LAG).await;
|
||||
invalidate_cache(cache, msg.clone());
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle console's invalidation messages.
|
||||
#[tracing::instrument(name = "console_notifications", skip_all)]
|
||||
pub async fn task_main<C>(url: String, cache: Arc<C>) -> anyhow::Result<Infallible>
|
||||
where
|
||||
C: ProjectInfoCache + Send + Sync + 'static,
|
||||
{
|
||||
cache.enable_ttl();
|
||||
|
||||
loop {
|
||||
let redis = ConsoleRedisClient::new(&url)?;
|
||||
let conn = match redis.try_connect().await {
|
||||
Ok(conn) => {
|
||||
cache.disable_ttl();
|
||||
conn
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
"failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
|
||||
);
|
||||
tokio::time::sleep(RECONNECT_TIMEOUT).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let mut stream = conn.into_on_message();
|
||||
while let Some(msg) = stream.next().await {
|
||||
match handle_message(msg, cache.clone()) {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
tracing::error!("failed to handle message: {e}, will try to reconnect");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
cache.enable_ttl();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn parse_allowed_ips() -> anyhow::Result<()> {
|
||||
let project_id = "new_project".to_string();
|
||||
let data = format!("{{\"project\": \"{project_id}\"}}");
|
||||
let text = json!({
|
||||
"type": "message",
|
||||
"topic": "/allowed_ips_updated",
|
||||
"data": data,
|
||||
"extre_fields": "something"
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let result: Notification = serde_json::from_str(&text)?;
|
||||
assert_eq!(
|
||||
result,
|
||||
Notification::AllowedIpsUpdate {
|
||||
allowed_ips_update: AllowedIpsUpdate {
|
||||
project_id: project_id.into()
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_password_updated() -> anyhow::Result<()> {
|
||||
let project_id = "new_project".to_string();
|
||||
let role_name = "new_role".to_string();
|
||||
let data = format!("{{\"project\": \"{project_id}\", \"role\": \"{role_name}\"}}");
|
||||
let text = json!({
|
||||
"type": "message",
|
||||
"topic": "/password_updated",
|
||||
"data": data,
|
||||
"extre_fields": "something"
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let result: Notification = serde_json::from_str(&text)?;
|
||||
assert_eq!(
|
||||
result,
|
||||
Notification::PasswordUpdate {
|
||||
password_update: PasswordUpdate {
|
||||
project_id: project_id.into(),
|
||||
role_name: role_name.into()
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,7 @@ pub const SCRAM_KEY_LEN: usize = 32;
|
||||
/// One of the keys derived from the [password](super::password::SaltedPassword).
|
||||
/// We use the same structure for all keys, i.e.
|
||||
/// `ClientKey`, `StoredKey`, and `ServerKey`.
|
||||
#[derive(Clone, Default, PartialEq, Eq)]
|
||||
#[derive(Clone, Default, PartialEq, Eq, Debug)]
|
||||
#[repr(transparent)]
|
||||
pub struct ScramKey {
|
||||
bytes: [u8; SCRAM_KEY_LEN],
|
||||
|
||||
@@ -5,7 +5,7 @@ use super::key::ScramKey;
|
||||
|
||||
/// Server secret is produced from [password](super::password::SaltedPassword)
|
||||
/// and is used throughout the authentication process.
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Eq, PartialEq, Debug)]
|
||||
pub struct ServerSecret {
|
||||
/// Number of iterations for `PBKDF2` function.
|
||||
pub iterations: u32,
|
||||
|
||||
Reference in New Issue
Block a user