diff --git a/Cargo.lock b/Cargo.lock index 02450709d1..c16331636a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2718,6 +2718,16 @@ dependencies = [ "libc", ] +[[package]] +name = "lasso" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4644821e1c3d7a560fe13d842d13f587c07348a1a05d3a797152d41c90c56df2" +dependencies = [ + "dashmap", + "hashbrown 0.13.2", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -4075,6 +4085,7 @@ dependencies = [ "hyper-tungstenite", "ipnet", "itertools", + "lasso", "md5", "metrics", "native-tls", @@ -4091,6 +4102,7 @@ dependencies = [ "pq_proto", "prometheus", "rand 0.8.5", + "rand_distr", "rcgen", "redis", "regex", @@ -6803,6 +6815,7 @@ dependencies = [ "futures-sink", "futures-util", "getrandom 0.2.11", + "hashbrown 0.13.2", "hashbrown 0.14.0", "hex", "hmac", diff --git a/Cargo.toml b/Cargo.toml index 0cfe522ff9..271edee742 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -95,6 +95,7 @@ inotify = "0.10.2" ipnet = "2.9.0" itertools = "0.10" jsonwebtoken = "9" +lasso = "0.7" libc = "0.2" md5 = "0.7.0" memoffset = "0.8" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 79abe639ed..1247f08ee6 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -31,6 +31,7 @@ hyper-tungstenite.workspace = true hyper.workspace = true ipnet.workspace = true itertools.workspace = true +lasso = { workspace = true, features = ["multi-threaded"] } md5.workspace = true metrics.workspace = true once_cell.workspace = true @@ -92,3 +93,4 @@ rcgen.workspace = true rstest.workspace = true tokio-postgres-rustls.workspace = true walkdir.workspace = true +rand_distr = "0.4" diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 6f37868a8c..62015312a9 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -12,15 +12,18 @@ use tokio::time::Instant; use tracing::{debug, info}; use crate::{ - auth::IpPattern, config::ProjectInfoCacheOptions, console::AuthSecret, EndpointId, ProjectId, - RoleName, + auth::IpPattern, + config::ProjectInfoCacheOptions, + console::AuthSecret, + intern::{EndpointIdInt, ProjectIdInt, RoleNameInt}, + EndpointId, ProjectId, RoleName, }; use super::{Cache, Cached}; pub trait ProjectInfoCache { - fn invalidate_allowed_ips_for_project(&self, project_id: &ProjectId); - fn invalidate_role_secret_for_project(&self, project_id: &ProjectId, role_name: &RoleName); + fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt); + fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt); fn enable_ttl(&self); fn disable_ttl(&self); } @@ -47,7 +50,7 @@ impl From for Entry { #[derive(Default)] struct EndpointInfo { - secret: std::collections::HashMap>>, + secret: std::collections::HashMap>>, allowed_ips: Option>>>, } @@ -60,11 +63,11 @@ impl EndpointInfo { } pub fn get_role_secret( &self, - role_name: &RoleName, + role_name: RoleNameInt, valid_since: Instant, ignore_cache_since: Option, ) -> Option<(Option, bool)> { - if let Some(secret) = self.secret.get(role_name) { + if let Some(secret) = self.secret.get(&role_name) { if valid_since < secret.created_at { return Some(( secret.value.clone(), @@ -93,8 +96,8 @@ impl EndpointInfo { pub fn invalidate_allowed_ips(&mut self) { self.allowed_ips = None; } - pub fn invalidate_role_secret(&mut self, role_name: &RoleName) { - self.secret.remove(role_name); + pub fn invalidate_role_secret(&mut self, role_name: RoleNameInt) { + self.secret.remove(&role_name); } } @@ -106,9 +109,9 @@ impl EndpointInfo { /// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available? /// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache. pub struct ProjectInfoCacheImpl { - cache: DashMap, + cache: DashMap, - project2ep: DashMap>, + project2ep: DashMap>, config: ProjectInfoCacheOptions, start_time: Instant, @@ -116,11 +119,11 @@ pub struct ProjectInfoCacheImpl { } impl ProjectInfoCache for ProjectInfoCacheImpl { - fn invalidate_allowed_ips_for_project(&self, project_id: &ProjectId) { + fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) { info!("invalidating allowed ips for project `{}`", project_id); let endpoints = self .project2ep - .get(project_id) + .get(&project_id) .map(|kv| kv.value().clone()) .unwrap_or_default(); for endpoint_id in endpoints { @@ -129,14 +132,14 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { } } } - fn invalidate_role_secret_for_project(&self, project_id: &ProjectId, role_name: &RoleName) { + fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) { info!( "invalidating role secret for project_id `{}` and role_name `{}`", - project_id, role_name + project_id, role_name, ); let endpoints = self .project2ep - .get(project_id) + .get(&project_id) .map(|kv| kv.value().clone()) .unwrap_or_default(); for endpoint_id in endpoints { @@ -173,15 +176,17 @@ impl ProjectInfoCacheImpl { endpoint_id: &EndpointId, role_name: &RoleName, ) -> Option>> { + let endpoint_id = EndpointIdInt::get(endpoint_id)?; + let role_name = RoleNameInt::get(role_name)?; let (valid_since, ignore_cache_since) = self.get_cache_times(); - let endpoint_info = self.cache.get(endpoint_id)?; + let endpoint_info = self.cache.get(&endpoint_id)?; let (value, ignore_cache) = endpoint_info.get_role_secret(role_name, valid_since, ignore_cache_since)?; if !ignore_cache { let cached = Cached { token: Some(( self, - CachedLookupInfo::new_role_secret(endpoint_id.clone(), role_name.clone()), + CachedLookupInfo::new_role_secret(endpoint_id, role_name), )), value, }; @@ -193,13 +198,14 @@ impl ProjectInfoCacheImpl { &self, endpoint_id: &EndpointId, ) -> Option>>> { + let endpoint_id = EndpointIdInt::get(endpoint_id)?; let (valid_since, ignore_cache_since) = self.get_cache_times(); - let endpoint_info = self.cache.get(endpoint_id)?; + let endpoint_info = self.cache.get(&endpoint_id)?; let value = endpoint_info.get_allowed_ips(valid_since, ignore_cache_since); let (value, ignore_cache) = value?; if !ignore_cache { let cached = Cached { - token: Some((self, CachedLookupInfo::new_allowed_ips(endpoint_id.clone()))), + token: Some((self, CachedLookupInfo::new_allowed_ips(endpoint_id))), value, }; return Some(cached); @@ -213,14 +219,17 @@ impl ProjectInfoCacheImpl { role_name: &RoleName, secret: Option, ) { + let project_id = ProjectIdInt::from(project_id); + let endpoint_id = EndpointIdInt::from(endpoint_id); + let role_name = RoleNameInt::from(role_name); if self.cache.len() >= self.config.size { // If there are too many entries, wait until the next gc cycle. return; } - self.inser_project2endpoint(project_id, endpoint_id); - let mut entry = self.cache.entry(endpoint_id.clone()).or_default(); + self.insert_project2endpoint(project_id, endpoint_id); + let mut entry = self.cache.entry(endpoint_id).or_default(); if entry.secret.len() < self.config.max_roles { - entry.secret.insert(role_name.clone(), secret.into()); + entry.secret.insert(role_name, secret.into()); } } pub fn insert_allowed_ips( @@ -229,22 +238,21 @@ impl ProjectInfoCacheImpl { endpoint_id: &EndpointId, allowed_ips: Arc>, ) { + let project_id = ProjectIdInt::from(project_id); + let endpoint_id = EndpointIdInt::from(endpoint_id); if self.cache.len() >= self.config.size { // If there are too many entries, wait until the next gc cycle. return; } - self.inser_project2endpoint(project_id, endpoint_id); - self.cache - .entry(endpoint_id.clone()) - .or_default() - .allowed_ips = Some(allowed_ips.into()); + self.insert_project2endpoint(project_id, endpoint_id); + self.cache.entry(endpoint_id).or_default().allowed_ips = Some(allowed_ips.into()); } - fn inser_project2endpoint(&self, project_id: &ProjectId, endpoint_id: &EndpointId) { - if let Some(mut endpoints) = self.project2ep.get_mut(project_id) { - endpoints.insert(endpoint_id.clone()); + fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) { + if let Some(mut endpoints) = self.project2ep.get_mut(&project_id) { + endpoints.insert(endpoint_id); } else { self.project2ep - .insert(project_id.clone(), HashSet::from([endpoint_id.clone()])); + .insert(project_id, HashSet::from([endpoint_id])); } } fn get_cache_times(&self) -> (Instant, Option) { @@ -300,18 +308,18 @@ impl ProjectInfoCacheImpl { /// This is used to invalidate cache entries. pub struct CachedLookupInfo { /// Search by this key. - endpoint_id: EndpointId, + endpoint_id: EndpointIdInt, lookup_type: LookupType, } impl CachedLookupInfo { - pub(self) fn new_role_secret(endpoint_id: EndpointId, role_name: RoleName) -> Self { + pub(self) fn new_role_secret(endpoint_id: EndpointIdInt, role_name: RoleNameInt) -> Self { Self { endpoint_id, lookup_type: LookupType::RoleSecret(role_name), } } - pub(self) fn new_allowed_ips(endpoint_id: EndpointId) -> Self { + pub(self) fn new_allowed_ips(endpoint_id: EndpointIdInt) -> Self { Self { endpoint_id, lookup_type: LookupType::AllowedIps, @@ -320,7 +328,7 @@ impl CachedLookupInfo { } enum LookupType { - RoleSecret(RoleName), + RoleSecret(RoleNameInt), AllowedIps, } @@ -335,7 +343,7 @@ impl Cache for ProjectInfoCacheImpl { match &key.lookup_type { LookupType::RoleSecret(role_name) => { if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { - endpoint_info.invalidate_role_secret(role_name); + endpoint_info.invalidate_role_secret(*role_name); } } LookupType::AllowedIps => { @@ -457,7 +465,7 @@ mod tests { assert_eq!(cached.value, secret2); // The only way to invalidate this value is to invalidate via the api. - cache.invalidate_role_secret_for_project(&project_id, &user2); + cache.invalidate_role_secret_for_project((&project_id).into(), (&user2).into()); assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); let cached = cache.get_allowed_ips(&endpoint_id).unwrap(); diff --git a/proxy/src/intern.rs b/proxy/src/intern.rs new file mode 100644 index 0000000000..a6519bdff9 --- /dev/null +++ b/proxy/src/intern.rs @@ -0,0 +1,237 @@ +use std::{ + hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index, sync::OnceLock, +}; + +use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo}; +use rustc_hash::FxHasher; + +use crate::{BranchId, EndpointId, ProjectId, RoleName}; + +pub trait InternId: Sized + 'static { + fn get_interner() -> &'static StringInterner; +} + +pub struct StringInterner { + inner: ThreadedRodeo>, + _id: PhantomData, +} + +#[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)] +pub struct InternedString { + inner: Spur, + _id: PhantomData, +} + +impl std::fmt::Display for InternedString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.as_str().fmt(f) + } +} + +impl InternedString { + pub fn as_str(&self) -> &'static str { + Id::get_interner().inner.resolve(&self.inner) + } + pub fn get(s: &str) -> Option { + Id::get_interner().get(s) + } +} + +impl AsRef for InternedString { + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl std::ops::Deref for InternedString { + type Target = str; + fn deref(&self) -> &str { + self.as_str() + } +} + +impl<'de, Id: InternId> serde::de::Deserialize<'de> for InternedString { + fn deserialize>(d: D) -> Result { + struct Visitor(PhantomData); + impl<'de, Id: InternId> serde::de::Visitor<'de> for Visitor { + type Value = InternedString; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a string") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + Ok(Id::get_interner().get_or_intern(v)) + } + } + d.deserialize_str(Visitor::(PhantomData)) + } +} + +impl serde::Serialize for InternedString { + fn serialize(&self, s: S) -> Result { + self.as_str().serialize(s) + } +} + +impl StringInterner { + pub fn new() -> Self { + StringInterner { + inner: ThreadedRodeo::with_capacity_memory_limits_and_hasher( + Capacity::new(2500, NonZeroUsize::new(1 << 16).unwrap()), + // unbounded + MemoryLimits::for_memory_usage(usize::MAX), + BuildHasherDefault::::default(), + ), + _id: PhantomData, + } + } + + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + pub fn len(&self) -> usize { + self.inner.len() + } + + pub fn current_memory_usage(&self) -> usize { + self.inner.current_memory_usage() + } + + pub fn get_or_intern(&self, s: &str) -> InternedString { + InternedString { + inner: self.inner.get_or_intern(s), + _id: PhantomData, + } + } + + pub fn get(&self, s: &str) -> Option> { + Some(InternedString { + inner: self.inner.get(s)?, + _id: PhantomData, + }) + } +} + +impl Index> for StringInterner { + type Output = str; + + fn index(&self, index: InternedString) -> &Self::Output { + self.inner.resolve(&index.inner) + } +} + +impl Default for StringInterner { + fn default() -> Self { + Self::new() + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct RoleNameTag; +impl InternId for RoleNameTag { + fn get_interner() -> &'static StringInterner { + pub static ROLE_NAMES: OnceLock> = OnceLock::new(); + ROLE_NAMES.get_or_init(Default::default) + } +} +pub type RoleNameInt = InternedString; +impl From<&RoleName> for RoleNameInt { + fn from(value: &RoleName) -> Self { + RoleNameTag::get_interner().get_or_intern(value) + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct EndpointIdTag; +impl InternId for EndpointIdTag { + fn get_interner() -> &'static StringInterner { + pub static ROLE_NAMES: OnceLock> = OnceLock::new(); + ROLE_NAMES.get_or_init(Default::default) + } +} +pub type EndpointIdInt = InternedString; +impl From<&EndpointId> for EndpointIdInt { + fn from(value: &EndpointId) -> Self { + EndpointIdTag::get_interner().get_or_intern(value) + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct BranchIdTag; +impl InternId for BranchIdTag { + fn get_interner() -> &'static StringInterner { + pub static ROLE_NAMES: OnceLock> = OnceLock::new(); + ROLE_NAMES.get_or_init(Default::default) + } +} +pub type BranchIdInt = InternedString; +impl From<&BranchId> for BranchIdInt { + fn from(value: &BranchId) -> Self { + BranchIdTag::get_interner().get_or_intern(value) + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct ProjectIdTag; +impl InternId for ProjectIdTag { + fn get_interner() -> &'static StringInterner { + pub static ROLE_NAMES: OnceLock> = OnceLock::new(); + ROLE_NAMES.get_or_init(Default::default) + } +} +pub type ProjectIdInt = InternedString; +impl From<&ProjectId> for ProjectIdInt { + fn from(value: &ProjectId) -> Self { + ProjectIdTag::get_interner().get_or_intern(value) + } +} + +#[cfg(test)] +mod tests { + use std::sync::OnceLock; + + use crate::intern::StringInterner; + + use super::InternId; + + struct MyId; + impl InternId for MyId { + fn get_interner() -> &'static StringInterner { + pub static ROLE_NAMES: OnceLock> = OnceLock::new(); + ROLE_NAMES.get_or_init(Default::default) + } + } + + #[test] + fn push_many_strings() { + use rand::{rngs::StdRng, Rng, SeedableRng}; + use rand_distr::Zipf; + + let endpoint_dist = Zipf::new(500000, 0.8).unwrap(); + let endpoints = StdRng::seed_from_u64(272488357).sample_iter(endpoint_dist); + + let interner = MyId::get_interner(); + + const N: usize = 100_000; + let mut verify = Vec::with_capacity(N); + for endpoint in endpoints.take(N) { + let endpoint = format!("ep-string-interning-{endpoint}"); + let key = interner.get_or_intern(&endpoint); + verify.push((endpoint, key)); + } + + for (s, key) in verify { + assert_eq!(interner[key], s); + } + + // 2031616/59861 = 34 bytes per string + assert_eq!(interner.len(), 59_861); + // will have other overhead for the internal hashmaps that are not accounted for. + assert_eq!(interner.current_memory_usage(), 2_031_616); + } +} diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index db6256d611..da7c7f3ed2 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -16,6 +16,7 @@ pub mod console; pub mod context; pub mod error; pub mod http; +pub mod intern; pub mod jemalloc; pub mod logging; pub mod metrics; diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 9cd70b109b..158884aa17 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -4,7 +4,10 @@ use futures::StreamExt; use redis::aio::PubSub; use serde::Deserialize; -use crate::{cache::project_info::ProjectInfoCache, ProjectId, RoleName}; +use crate::{ + cache::project_info::ProjectInfoCache, + intern::{ProjectIdInt, RoleNameInt}, +}; const CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20); @@ -45,12 +48,12 @@ enum Notification { } #[derive(Clone, Debug, Deserialize, Eq, PartialEq)] struct AllowedIpsUpdate { - project_id: ProjectId, + project_id: ProjectIdInt, } #[derive(Clone, Debug, Deserialize, Eq, PartialEq)] struct PasswordUpdate { - project_id: ProjectId, - role_name: RoleName, + project_id: ProjectIdInt, + role_name: RoleNameInt, } fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result where @@ -65,11 +68,11 @@ fn invalidate_cache(cache: Arc, msg: Notification) { use Notification::*; match msg { AllowedIpsUpdate { allowed_ips_update } => { - cache.invalidate_allowed_ips_for_project(&allowed_ips_update.project_id) + cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id) } PasswordUpdate { password_update } => cache.invalidate_role_secret_for_project( - &password_update.project_id, - &password_update.role_name, + password_update.project_id, + password_update.role_name, ), } } @@ -141,12 +144,14 @@ where #[cfg(test)] mod tests { + use crate::{ProjectId, RoleName}; + use super::*; use serde_json::json; #[test] fn parse_allowed_ips() -> anyhow::Result<()> { - let project_id = "new_project".to_string(); + let project_id: ProjectId = "new_project".into(); let data = format!("{{\"project_id\": \"{project_id}\"}}"); let text = json!({ "type": "message", @@ -161,7 +166,7 @@ mod tests { result, Notification::AllowedIpsUpdate { allowed_ips_update: AllowedIpsUpdate { - project_id: project_id.into() + project_id: (&project_id).into() } } ); @@ -171,8 +176,8 @@ mod tests { #[test] fn parse_password_updated() -> anyhow::Result<()> { - let project_id = "new_project".to_string(); - let role_name = "new_role".to_string(); + let project_id: ProjectId = "new_project".into(); + let role_name: RoleName = "new_role".into(); let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}"); let text = json!({ "type": "message", @@ -187,8 +192,8 @@ mod tests { result, Notification::PasswordUpdate { password_update: PasswordUpdate { - project_id: project_id.into(), - role_name: role_name.into() + project_id: (&project_id).into(), + role_name: (&role_name).into(), } } ); diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index f58b912a77..74464dd4c8 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -39,7 +39,8 @@ futures-io = { version = "0.3" } futures-sink = { version = "0.3" } futures-util = { version = "0.3", features = ["channel", "io", "sink"] } getrandom = { version = "0.2", default-features = false, features = ["std"] } -hashbrown = { version = "0.14", default-features = false, features = ["raw"] } +hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", default-features = false, features = ["raw"] } +hashbrown-594e8ee84c453af0 = { package = "hashbrown", version = "0.13", features = ["raw"] } hex = { version = "0.4", features = ["serde"] } hmac = { version = "0.12", default-features = false, features = ["reset"] } hyper = { version = "0.14", features = ["full"] } @@ -91,7 +92,7 @@ cc = { version = "1", default-features = false, features = ["parallel"] } chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] } either = { version = "1" } getrandom = { version = "0.2", default-features = false, features = ["std"] } -hashbrown = { version = "0.14", default-features = false, features = ["raw"] } +hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", default-features = false, features = ["raw"] } indexmap = { version = "1", default-features = false, features = ["std"] } itertools = { version = "0.10" } libc = { version = "0.2", features = ["extra_traits", "use_std"] }