use string interner for project cache (#6578)

## Problem

Memory profiling under a high concurrent request rate appears to show
some memory fragmentation.

## Summary of changes

Eventually, we will want to separate global memory (caches) from local
memory (per connection handshake and per passthrough).

Using a string interner for project info cache helps reduce some of the
fragmentation of the global cache by having a single heap dedicated to
project strings, and not scattering them throughout all the requests.

At the same time, the interned key is 4 bytes vs the 24 bytes that
`SmolStr` offers.

Important: we should only store verified strings in the interner because
there's no way to remove them afterwards. Good for caching responses
from console.
This commit is contained in:
Conrad Ludgate
2024-02-05 14:27:25 +00:00
committed by GitHub
parent 5e8deca268
commit 74c5e3d9b8
8 changed files with 321 additions and 53 deletions

13
Cargo.lock generated
View File

@@ -2718,6 +2718,16 @@ dependencies = [
"libc",
]
[[package]]
name = "lasso"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4644821e1c3d7a560fe13d842d13f587c07348a1a05d3a797152d41c90c56df2"
dependencies = [
"dashmap",
"hashbrown 0.13.2",
]
[[package]]
name = "lazy_static"
version = "1.4.0"
@@ -4075,6 +4085,7 @@ dependencies = [
"hyper-tungstenite",
"ipnet",
"itertools",
"lasso",
"md5",
"metrics",
"native-tls",
@@ -4091,6 +4102,7 @@ dependencies = [
"pq_proto",
"prometheus",
"rand 0.8.5",
"rand_distr",
"rcgen",
"redis",
"regex",
@@ -6803,6 +6815,7 @@ dependencies = [
"futures-sink",
"futures-util",
"getrandom 0.2.11",
"hashbrown 0.13.2",
"hashbrown 0.14.0",
"hex",
"hmac",

View File

@@ -95,6 +95,7 @@ inotify = "0.10.2"
ipnet = "2.9.0"
itertools = "0.10"
jsonwebtoken = "9"
lasso = "0.7"
libc = "0.2"
md5 = "0.7.0"
memoffset = "0.8"

View File

@@ -31,6 +31,7 @@ hyper-tungstenite.workspace = true
hyper.workspace = true
ipnet.workspace = true
itertools.workspace = true
lasso = { workspace = true, features = ["multi-threaded"] }
md5.workspace = true
metrics.workspace = true
once_cell.workspace = true
@@ -92,3 +93,4 @@ rcgen.workspace = true
rstest.workspace = true
tokio-postgres-rustls.workspace = true
walkdir.workspace = true
rand_distr = "0.4"

View File

@@ -12,15 +12,18 @@ use tokio::time::Instant;
use tracing::{debug, info};
use crate::{
auth::IpPattern, config::ProjectInfoCacheOptions, console::AuthSecret, EndpointId, ProjectId,
RoleName,
auth::IpPattern,
config::ProjectInfoCacheOptions,
console::AuthSecret,
intern::{EndpointIdInt, ProjectIdInt, RoleNameInt},
EndpointId, ProjectId, RoleName,
};
use super::{Cache, Cached};
pub trait ProjectInfoCache {
fn invalidate_allowed_ips_for_project(&self, project_id: &ProjectId);
fn invalidate_role_secret_for_project(&self, project_id: &ProjectId, role_name: &RoleName);
fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt);
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
fn enable_ttl(&self);
fn disable_ttl(&self);
}
@@ -47,7 +50,7 @@ impl<T> From<T> for Entry<T> {
#[derive(Default)]
struct EndpointInfo {
secret: std::collections::HashMap<RoleName, Entry<Option<AuthSecret>>>,
secret: std::collections::HashMap<RoleNameInt, Entry<Option<AuthSecret>>>,
allowed_ips: Option<Entry<Arc<Vec<IpPattern>>>>,
}
@@ -60,11 +63,11 @@ impl EndpointInfo {
}
pub fn get_role_secret(
&self,
role_name: &RoleName,
role_name: RoleNameInt,
valid_since: Instant,
ignore_cache_since: Option<Instant>,
) -> Option<(Option<AuthSecret>, bool)> {
if let Some(secret) = self.secret.get(role_name) {
if let Some(secret) = self.secret.get(&role_name) {
if valid_since < secret.created_at {
return Some((
secret.value.clone(),
@@ -93,8 +96,8 @@ impl EndpointInfo {
pub fn invalidate_allowed_ips(&mut self) {
self.allowed_ips = None;
}
pub fn invalidate_role_secret(&mut self, role_name: &RoleName) {
self.secret.remove(role_name);
pub fn invalidate_role_secret(&mut self, role_name: RoleNameInt) {
self.secret.remove(&role_name);
}
}
@@ -106,9 +109,9 @@ impl EndpointInfo {
/// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available?
/// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache.
pub struct ProjectInfoCacheImpl {
cache: DashMap<EndpointId, EndpointInfo>,
cache: DashMap<EndpointIdInt, EndpointInfo>,
project2ep: DashMap<ProjectId, HashSet<EndpointId>>,
project2ep: DashMap<ProjectIdInt, HashSet<EndpointIdInt>>,
config: ProjectInfoCacheOptions,
start_time: Instant,
@@ -116,11 +119,11 @@ pub struct ProjectInfoCacheImpl {
}
impl ProjectInfoCache for ProjectInfoCacheImpl {
fn invalidate_allowed_ips_for_project(&self, project_id: &ProjectId) {
fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) {
info!("invalidating allowed ips for project `{}`", project_id);
let endpoints = self
.project2ep
.get(project_id)
.get(&project_id)
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
@@ -129,14 +132,14 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
}
}
}
fn invalidate_role_secret_for_project(&self, project_id: &ProjectId, role_name: &RoleName) {
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) {
info!(
"invalidating role secret for project_id `{}` and role_name `{}`",
project_id, role_name
project_id, role_name,
);
let endpoints = self
.project2ep
.get(project_id)
.get(&project_id)
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
@@ -173,15 +176,17 @@ impl ProjectInfoCacheImpl {
endpoint_id: &EndpointId,
role_name: &RoleName,
) -> Option<Cached<&Self, Option<AuthSecret>>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
let role_name = RoleNameInt::get(role_name)?;
let (valid_since, ignore_cache_since) = self.get_cache_times();
let endpoint_info = self.cache.get(endpoint_id)?;
let endpoint_info = self.cache.get(&endpoint_id)?;
let (value, ignore_cache) =
endpoint_info.get_role_secret(role_name, valid_since, ignore_cache_since)?;
if !ignore_cache {
let cached = Cached {
token: Some((
self,
CachedLookupInfo::new_role_secret(endpoint_id.clone(), role_name.clone()),
CachedLookupInfo::new_role_secret(endpoint_id, role_name),
)),
value,
};
@@ -193,13 +198,14 @@ impl ProjectInfoCacheImpl {
&self,
endpoint_id: &EndpointId,
) -> Option<Cached<&Self, Arc<Vec<IpPattern>>>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
let (valid_since, ignore_cache_since) = self.get_cache_times();
let endpoint_info = self.cache.get(endpoint_id)?;
let endpoint_info = self.cache.get(&endpoint_id)?;
let value = endpoint_info.get_allowed_ips(valid_since, ignore_cache_since);
let (value, ignore_cache) = value?;
if !ignore_cache {
let cached = Cached {
token: Some((self, CachedLookupInfo::new_allowed_ips(endpoint_id.clone()))),
token: Some((self, CachedLookupInfo::new_allowed_ips(endpoint_id))),
value,
};
return Some(cached);
@@ -213,14 +219,17 @@ impl ProjectInfoCacheImpl {
role_name: &RoleName,
secret: Option<AuthSecret>,
) {
let project_id = ProjectIdInt::from(project_id);
let endpoint_id = EndpointIdInt::from(endpoint_id);
let role_name = RoleNameInt::from(role_name);
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
self.inser_project2endpoint(project_id, endpoint_id);
let mut entry = self.cache.entry(endpoint_id.clone()).or_default();
self.insert_project2endpoint(project_id, endpoint_id);
let mut entry = self.cache.entry(endpoint_id).or_default();
if entry.secret.len() < self.config.max_roles {
entry.secret.insert(role_name.clone(), secret.into());
entry.secret.insert(role_name, secret.into());
}
}
pub fn insert_allowed_ips(
@@ -229,22 +238,21 @@ impl ProjectInfoCacheImpl {
endpoint_id: &EndpointId,
allowed_ips: Arc<Vec<IpPattern>>,
) {
let project_id = ProjectIdInt::from(project_id);
let endpoint_id = EndpointIdInt::from(endpoint_id);
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
self.inser_project2endpoint(project_id, endpoint_id);
self.cache
.entry(endpoint_id.clone())
.or_default()
.allowed_ips = Some(allowed_ips.into());
self.insert_project2endpoint(project_id, endpoint_id);
self.cache.entry(endpoint_id).or_default().allowed_ips = Some(allowed_ips.into());
}
fn inser_project2endpoint(&self, project_id: &ProjectId, endpoint_id: &EndpointId) {
if let Some(mut endpoints) = self.project2ep.get_mut(project_id) {
endpoints.insert(endpoint_id.clone());
fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
if let Some(mut endpoints) = self.project2ep.get_mut(&project_id) {
endpoints.insert(endpoint_id);
} else {
self.project2ep
.insert(project_id.clone(), HashSet::from([endpoint_id.clone()]));
.insert(project_id, HashSet::from([endpoint_id]));
}
}
fn get_cache_times(&self) -> (Instant, Option<Instant>) {
@@ -300,18 +308,18 @@ impl ProjectInfoCacheImpl {
/// This is used to invalidate cache entries.
pub struct CachedLookupInfo {
/// Search by this key.
endpoint_id: EndpointId,
endpoint_id: EndpointIdInt,
lookup_type: LookupType,
}
impl CachedLookupInfo {
pub(self) fn new_role_secret(endpoint_id: EndpointId, role_name: RoleName) -> Self {
pub(self) fn new_role_secret(endpoint_id: EndpointIdInt, role_name: RoleNameInt) -> Self {
Self {
endpoint_id,
lookup_type: LookupType::RoleSecret(role_name),
}
}
pub(self) fn new_allowed_ips(endpoint_id: EndpointId) -> Self {
pub(self) fn new_allowed_ips(endpoint_id: EndpointIdInt) -> Self {
Self {
endpoint_id,
lookup_type: LookupType::AllowedIps,
@@ -320,7 +328,7 @@ impl CachedLookupInfo {
}
enum LookupType {
RoleSecret(RoleName),
RoleSecret(RoleNameInt),
AllowedIps,
}
@@ -335,7 +343,7 @@ impl Cache for ProjectInfoCacheImpl {
match &key.lookup_type {
LookupType::RoleSecret(role_name) => {
if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
endpoint_info.invalidate_role_secret(role_name);
endpoint_info.invalidate_role_secret(*role_name);
}
}
LookupType::AllowedIps => {
@@ -457,7 +465,7 @@ mod tests {
assert_eq!(cached.value, secret2);
// The only way to invalidate this value is to invalidate via the api.
cache.invalidate_role_secret_for_project(&project_id, &user2);
cache.invalidate_role_secret_for_project((&project_id).into(), (&user2).into());
assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
let cached = cache.get_allowed_ips(&endpoint_id).unwrap();

237
proxy/src/intern.rs Normal file
View File

@@ -0,0 +1,237 @@
use std::{
hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index, sync::OnceLock,
};
use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
use rustc_hash::FxHasher;
use crate::{BranchId, EndpointId, ProjectId, RoleName};
/// Tag trait that ties a marker type to the interner holding its strings.
///
/// Implementors are marker types; the tag appears as the `Id` parameter of
/// `StringInterner` and `InternedString`.
pub trait InternId: Sized + 'static {
/// Returns the `'static` interner associated with this tag type.
fn get_interner() -> &'static StringInterner<Self>;
}
/// String interner whose keys are tagged with the marker type `Id`.
///
/// Wraps a `lasso::ThreadedRodeo` that issues `Spur` keys and hashes with `FxHasher`.
pub struct StringInterner<Id> {
// Underlying lasso interner that stores the strings.
inner: ThreadedRodeo<Spur, BuildHasherDefault<FxHasher>>,
// Zero-sized marker carrying the tag type; stores no data.
_id: PhantomData<Id>,
}
/// Key referring to a string stored in the interner of tag type `Id`.
///
/// Note: each derived trait below is only implemented when `Id` implements that
/// trait as well, because `derive` adds an `Id: Trait` bound to the generated impl.
#[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
pub struct InternedString<Id> {
// Key issued by the underlying lasso interner.
inner: Spur,
// Zero-sized marker carrying the tag type; stores no data.
_id: PhantomData<Id>,
}
impl<Id: InternId> std::fmt::Display for InternedString<Id> {
    /// Formats the key as the string it stands for, delegating to `str`'s
    /// `Display` implementation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text: &str = self.as_str();
        std::fmt::Display::fmt(text, f)
    }
}
impl<Id: InternId> InternedString<Id> {
    /// Resolves this key to the string stored in the interner for `Id`.
    ///
    /// The returned slice is `'static` because the interner itself is `'static`.
    /// NOTE(review): lasso documents `resolve` as panicking for a key the interner
    /// does not know; keys are only produced by this module, so that should not
    /// happen here — confirm.
    pub fn as_str(&self) -> &'static str {
        let interner = Id::get_interner();
        interner.inner.resolve(&self.inner)
    }

    /// Looks up `s` in the interner for `Id` without inserting it.
    ///
    /// Returns `None` when `s` has not been interned.
    pub fn get(s: &str) -> Option<Self> {
        let interner = Id::get_interner();
        interner.get(s)
    }
}
impl<Id: InternId> AsRef<str> for InternedString<Id> {
    /// Borrows the interned text; equivalent to `as_str`.
    fn as_ref(&self) -> &str {
        Self::as_str(self)
    }
}
impl<Id: InternId> std::ops::Deref for InternedString<Id> {
    type Target = str;

    /// Dereferences the key to the interned text, so `str` methods can be
    /// called directly on an `InternedString`.
    fn deref(&self) -> &Self::Target {
        Self::as_str(self)
    }
}
impl<'de, Id: InternId> serde::de::Deserialize<'de> for InternedString<Id> {
fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
struct Visitor<Id>(PhantomData<Id>);
impl<'de, Id: InternId> serde::de::Visitor<'de> for Visitor<Id> {
type Value = InternedString<Id>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a string")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(Id::get_interner().get_or_intern(v))
}
}
d.deserialize_str(Visitor::<Id>(PhantomData))
}
}
impl<Id: InternId> serde::Serialize for InternedString<Id> {
    /// Serializes the key as the string it stands for.
    fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        let text: &str = self.as_str();
        serde::Serialize::serialize(text, s)
    }
}
impl<Id: InternId> StringInterner<Id> {
    /// Creates an empty interner with a pre-sized backing store and no memory cap.
    pub fn new() -> Self {
        // Initial capacity hints passed to lasso: 2500 strings and 1 << 16 bytes.
        let capacity = Capacity::new(2500, NonZeroUsize::new(1 << 16).unwrap());
        // unbounded
        let limits = MemoryLimits::for_memory_usage(usize::MAX);
        let hasher = BuildHasherDefault::<FxHasher>::default();

        Self {
            inner: ThreadedRodeo::with_capacity_memory_limits_and_hasher(capacity, limits, hasher),
            _id: PhantomData,
        }
    }

    /// Reports whether the underlying interner is empty.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Returns the length reported by the underlying interner.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns the memory usage reported by the underlying interner.
    pub fn current_memory_usage(&self) -> usize {
        self.inner.current_memory_usage()
    }

    /// Returns the key for `s`, interning it first if it is not present yet.
    pub fn get_or_intern(&self, s: &str) -> InternedString<Id> {
        let inner = self.inner.get_or_intern(s);
        InternedString {
            inner,
            _id: PhantomData,
        }
    }

    /// Returns the key for `s` if it is already interned; never inserts.
    pub fn get(&self, s: &str) -> Option<InternedString<Id>> {
        self.inner.get(s).map(|inner| InternedString {
            inner,
            _id: PhantomData,
        })
    }
}
impl<Id: InternId> Index<InternedString<Id>> for StringInterner<Id> {
type Output = str;
fn index(&self, index: InternedString<Id>) -> &Self::Output {
self.inner.resolve(&index.inner)
}
}
impl<Id: InternId> Default for StringInterner<Id> {
fn default() -> Self {
Self::new()
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct RoleNameTag;
impl InternId for RoleNameTag {
fn get_interner() -> &'static StringInterner<Self> {
pub static ROLE_NAMES: OnceLock<StringInterner<RoleNameTag>> = OnceLock::new();
ROLE_NAMES.get_or_init(Default::default)
}
}
pub type RoleNameInt = InternedString<RoleNameTag>;
impl From<&RoleName> for RoleNameInt {
fn from(value: &RoleName) -> Self {
RoleNameTag::get_interner().get_or_intern(value)
}
}
/// Marker type selecting the interner that stores endpoint ids.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct EndpointIdTag;

impl InternId for EndpointIdTag {
    /// Returns the interner for endpoint ids, creating it on first use.
    fn get_interner() -> &'static StringInterner<Self> {
        // Function-local static, named for what it holds. No `pub`: an item
        // declared inside a function body cannot be named from outside it.
        static ENDPOINT_IDS: OnceLock<StringInterner<EndpointIdTag>> = OnceLock::new();
        ENDPOINT_IDS.get_or_init(Default::default)
    }
}

/// Interned endpoint id.
pub type EndpointIdInt = InternedString<EndpointIdTag>;

impl From<&EndpointId> for EndpointIdInt {
    /// Interns `value`, inserting it into the endpoint-id interner if needed.
    fn from(value: &EndpointId) -> Self {
        EndpointIdTag::get_interner().get_or_intern(value)
    }
}
/// Marker type selecting the interner that stores branch ids.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct BranchIdTag;

impl InternId for BranchIdTag {
    /// Returns the interner for branch ids, creating it on first use.
    fn get_interner() -> &'static StringInterner<Self> {
        // Function-local static, named for what it holds. No `pub`: an item
        // declared inside a function body cannot be named from outside it.
        static BRANCH_IDS: OnceLock<StringInterner<BranchIdTag>> = OnceLock::new();
        BRANCH_IDS.get_or_init(Default::default)
    }
}

/// Interned branch id.
pub type BranchIdInt = InternedString<BranchIdTag>;

impl From<&BranchId> for BranchIdInt {
    /// Interns `value`, inserting it into the branch-id interner if needed.
    fn from(value: &BranchId) -> Self {
        BranchIdTag::get_interner().get_or_intern(value)
    }
}
/// Marker type selecting the interner that stores project ids.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct ProjectIdTag;

impl InternId for ProjectIdTag {
    /// Returns the interner for project ids, creating it on first use.
    fn get_interner() -> &'static StringInterner<Self> {
        // Function-local static, named for what it holds. No `pub`: an item
        // declared inside a function body cannot be named from outside it.
        static PROJECT_IDS: OnceLock<StringInterner<ProjectIdTag>> = OnceLock::new();
        PROJECT_IDS.get_or_init(Default::default)
    }
}

/// Interned project id.
pub type ProjectIdInt = InternedString<ProjectIdTag>;

impl From<&ProjectId> for ProjectIdInt {
    /// Interns `value`, inserting it into the project-id interner if needed.
    fn from(value: &ProjectId) -> Self {
        ProjectIdTag::get_interner().get_or_intern(value)
    }
}
#[cfg(test)]
mod tests {
    use std::sync::OnceLock;

    use crate::intern::StringInterner;

    use super::InternId;

    /// Tag with its own interner, so the counts asserted below are not
    /// affected by strings interned under the other tags.
    struct MyId;

    impl InternId for MyId {
        fn get_interner() -> &'static StringInterner<Self> {
            static TEST_INTERNER: OnceLock<StringInterner<MyId>> = OnceLock::new();
            TEST_INTERNER.get_or_init(StringInterner::default)
        }
    }

    /// Interns `N` Zipf-distributed endpoint names from a fixed seed, checks that
    /// every key resolves back to its string, and pins the resulting
    /// interner size and memory usage.
    #[test]
    fn push_many_strings() {
        use rand::{rngs::StdRng, Rng, SeedableRng};
        use rand_distr::Zipf;

        const N: usize = 100_000;

        let interner = MyId::get_interner();
        let endpoint_dist = Zipf::new(500000, 0.8).unwrap();
        let endpoints = StdRng::seed_from_u64(272488357).sample_iter(endpoint_dist);

        // Intern every sampled name, remembering the (string, key) pair.
        let verify: Vec<_> = endpoints
            .take(N)
            .map(|endpoint| {
                let endpoint = format!("ep-string-interning-{endpoint}");
                let key = interner.get_or_intern(&endpoint);
                (endpoint, key)
            })
            .collect();

        // Every key must resolve back to exactly the string it was created from.
        for (s, key) in verify {
            assert_eq!(interner[key], s);
        }

        // 2031616/59861 = 34 bytes per string
        assert_eq!(interner.len(), 59_861);
        // will have other overhead for the internal hashmaps that are not accounted for.
        assert_eq!(interner.current_memory_usage(), 2_031_616);
    }
}

View File

@@ -16,6 +16,7 @@ pub mod console;
pub mod context;
pub mod error;
pub mod http;
pub mod intern;
pub mod jemalloc;
pub mod logging;
pub mod metrics;

View File

@@ -4,7 +4,10 @@ use futures::StreamExt;
use redis::aio::PubSub;
use serde::Deserialize;
use crate::{cache::project_info::ProjectInfoCache, ProjectId, RoleName};
use crate::{
cache::project_info::ProjectInfoCache,
intern::{ProjectIdInt, RoleNameInt},
};
const CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
@@ -45,12 +48,12 @@ enum Notification {
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
struct AllowedIpsUpdate {
project_id: ProjectId,
project_id: ProjectIdInt,
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
struct PasswordUpdate {
project_id: ProjectId,
role_name: RoleName,
project_id: ProjectIdInt,
role_name: RoleNameInt,
}
fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
@@ -65,11 +68,11 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
use Notification::*;
match msg {
AllowedIpsUpdate { allowed_ips_update } => {
cache.invalidate_allowed_ips_for_project(&allowed_ips_update.project_id)
cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id)
}
PasswordUpdate { password_update } => cache.invalidate_role_secret_for_project(
&password_update.project_id,
&password_update.role_name,
password_update.project_id,
password_update.role_name,
),
}
}
@@ -141,12 +144,14 @@ where
#[cfg(test)]
mod tests {
use crate::{ProjectId, RoleName};
use super::*;
use serde_json::json;
#[test]
fn parse_allowed_ips() -> anyhow::Result<()> {
let project_id = "new_project".to_string();
let project_id: ProjectId = "new_project".into();
let data = format!("{{\"project_id\": \"{project_id}\"}}");
let text = json!({
"type": "message",
@@ -161,7 +166,7 @@ mod tests {
result,
Notification::AllowedIpsUpdate {
allowed_ips_update: AllowedIpsUpdate {
project_id: project_id.into()
project_id: (&project_id).into()
}
}
);
@@ -171,8 +176,8 @@ mod tests {
#[test]
fn parse_password_updated() -> anyhow::Result<()> {
let project_id = "new_project".to_string();
let role_name = "new_role".to_string();
let project_id: ProjectId = "new_project".into();
let role_name: RoleName = "new_role".into();
let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
let text = json!({
"type": "message",
@@ -187,8 +192,8 @@ mod tests {
result,
Notification::PasswordUpdate {
password_update: PasswordUpdate {
project_id: project_id.into(),
role_name: role_name.into()
project_id: (&project_id).into(),
role_name: (&role_name).into(),
}
}
);

View File

@@ -39,7 +39,8 @@ futures-io = { version = "0.3" }
futures-sink = { version = "0.3" }
futures-util = { version = "0.3", features = ["channel", "io", "sink"] }
getrandom = { version = "0.2", default-features = false, features = ["std"] }
hashbrown = { version = "0.14", default-features = false, features = ["raw"] }
hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", default-features = false, features = ["raw"] }
hashbrown-594e8ee84c453af0 = { package = "hashbrown", version = "0.13", features = ["raw"] }
hex = { version = "0.4", features = ["serde"] }
hmac = { version = "0.12", default-features = false, features = ["reset"] }
hyper = { version = "0.14", features = ["full"] }
@@ -91,7 +92,7 @@ cc = { version = "1", default-features = false, features = ["parallel"] }
chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] }
either = { version = "1" }
getrandom = { version = "0.2", default-features = false, features = ["std"] }
hashbrown = { version = "0.14", default-features = false, features = ["raw"] }
hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", default-features = false, features = ["raw"] }
indexmap = { version = "1", default-features = false, features = ["std"] }
itertools = { version = "0.10" }
libc = { version = "0.2", features = ["extra_traits", "use_std"] }