simplify Cache invalidate trait, reduce EndpointCacheKey

This commit is contained in:
Conrad Ludgate
2024-05-07 08:34:21 +01:00
parent ef3a9dfafa
commit ca578449e4
10 changed files with 80 additions and 48 deletions

View File

@@ -155,7 +155,13 @@ pub struct ComputeUserInfo {
impl ComputeUserInfo {
pub fn endpoint_cache_key(&self) -> EndpointCacheKey {
self.options.get_cache_key(&self.endpoint)
let id = EndpointIdInt::from(&self.endpoint);
let key = EndpointCacheKey::from(id);
if self.options.is_empty() {
key
} else {
key.with_options(self.options.to_string())
}
}
}

View File

@@ -292,7 +292,7 @@ mod tests {
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("foo"));
assert_eq!(user_info.options.get_cache_key("foo"), "foo");
assert_eq!(user_info.options.to_string(), "");
Ok(())
}
@@ -451,8 +451,8 @@ mod tests {
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.endpoint_id.as_deref(), Some("project"));
assert_eq!(
user_info.options.get_cache_key("project"),
"project endpoint_type:read_write lsn:0/2"
user_info.options.to_string(),
"endpoint_type:read_write lsn:0/2"
);
Ok(())

View File

@@ -5,34 +5,26 @@ use std::ops::{Deref, DerefMut};
/// This is useful for [`Cached`].
#[allow(async_fn_in_trait)]
pub trait Cache {
/// Entry's key.
type Key;
/// Entry's value.
type Value;
/// Used for entry invalidation.
type LookupInfo<Key>;
type LookupInfo;
/// Invalidate an entry using a lookup info.
/// We don't have an empty default impl because it's error-prone.
async fn invalidate(&self, _: &Self::LookupInfo<Self::Key>);
async fn invalidate(&self, _: &Self::LookupInfo);
}
impl<C: Cache> Cache for &C {
type Key = C::Key;
type Value = C::Value;
type LookupInfo<Key> = C::LookupInfo<Key>;
type LookupInfo = C::LookupInfo;
async fn invalidate(&self, info: &Self::LookupInfo<Self::Key>) {
async fn invalidate(&self, info: &Self::LookupInfo) {
C::invalidate(self, info).await
}
}
/// Wrapper for convenient entry invalidation.
pub struct Cached<C: Cache, V = <C as Cache>::Value> {
pub struct Cached<C: Cache, V> {
/// Cache + lookup info.
pub token: Option<(C, C::LookupInfo<C::Key>)>,
pub token: Option<(C, C::LookupInfo)>,
/// The value itself.
pub value: V,
@@ -88,11 +80,9 @@ where
V: Clone + Send + Sync + 'static,
S: std::hash::BuildHasher + Clone + Send + Sync + 'static,
{
type Key = K;
type Value = V;
type LookupInfo<Key> = Key;
type LookupInfo = K;
async fn invalidate(&self, key: &Self::LookupInfo<Self::Key>) {
async fn invalidate(&self, key: &Self::LookupInfo) {
moka::future::Cache::invalidate(self, key).await
}
}

View File

@@ -8,7 +8,6 @@ use std::{
use async_trait::async_trait;
use dashmap::DashMap;
use rand::{thread_rng, Rng};
use smol_str::SmolStr;
use tokio::sync::Mutex;
use tokio::time::Instant;
use tracing::{debug, info};
@@ -346,13 +345,9 @@ enum LookupType {
}
impl Cache for ProjectInfoCacheImpl {
type Key = SmolStr;
// Value is not really used here, but we need to specify it.
type Value = SmolStr;
type LookupInfo = CachedLookupInfo;
type LookupInfo<Key> = CachedLookupInfo;
async fn invalidate(&self, key: &Self::LookupInfo<SmolStr>) {
async fn invalidate(&self, key: &Self::LookupInfo) {
match &key.lookup_type {
LookupType::RoleSecret(role_name) => {
if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {

View File

@@ -327,7 +327,7 @@ impl NodeInfo {
}
pub type NodeInfoCache = moka::future::Cache<EndpointCacheKey, NodeInfo>;
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
pub type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;

View File

@@ -276,7 +276,7 @@ impl super::Api for Api {
// The connection info remains the same during that period of time,
// which means that we might cache it to reduce the load and latency.
if let Some(cached) = self.caches.node_info.get(&key).await {
info!(key = &*key, "found cached compute node info");
info!(key = %key, "found cached compute node info");
ctx.set_project(cached.aux.clone());
return Ok(CachedNodeInfo {
token: Some((&self.caches.node_info, key)),
@@ -298,7 +298,7 @@ impl super::Api for Api {
// double check
if permit.should_check_cache() {
if let Some(cached) = self.caches.node_info.get(&key).await {
info!(key = &*key, "found cached compute node info");
info!(key = %key, "found cached compute node info");
ctx.set_project(cached.aux.clone());
return Ok(CachedNodeInfo {
token: Some((&self.caches.node_info, key)),
@@ -320,7 +320,7 @@ impl super::Api for Api {
.await;
node.aux.cold_start_info = cold_start_info;
info!(key = &*key, "created a cache entry for compute node info");
info!(key = %key, "created a cache entry for compute node info");
Ok(CachedNodeInfo {
token: Some((&self.caches.node_info, key)),

View File

@@ -154,9 +154,6 @@ smol_str_wrapper!(BranchId);
// 90% of project strings are 23 characters or less.
smol_str_wrapper!(ProjectId);
// will usually equal endpoint ID
smol_str_wrapper!(EndpointCacheKey);
smol_str_wrapper!(DbName);
// postgres hostname, will likely be a port:ip addr
@@ -180,3 +177,32 @@ impl EndpointId {
ProjectId(self.0.clone())
}
}
#[derive(Hash, PartialEq, Eq, Debug, Clone)]
pub struct EndpointCacheKey {
endpoint: intern::EndpointIdInt,
options: Option<String>,
}
impl std::fmt::Display for EndpointCacheKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.endpoint.as_str())?;
if let Some(options) = &self.options {
f.write_str(" ")?;
f.write_str(options)?;
}
Ok(())
}
}
impl From<intern::EndpointIdInt> for EndpointCacheKey {
fn from(value: intern::EndpointIdInt) -> Self {
Self { endpoint: value, options: None }
}
}
impl EndpointCacheKey {
pub fn with_options(mut self, options: String) -> Self {
self.options = Some(options);
self
}
}

View File

@@ -20,7 +20,6 @@ use crate::{
protocol2::read_proxy_protocol,
proxy::handshake::{handshake, HandshakeData},
stream::{PqStream, Stream},
EndpointCacheKey,
};
use futures::TryFutureExt;
use itertools::Itertools;
@@ -391,13 +390,8 @@ impl NeonOptions {
Self(options)
}
pub fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey {
// prefix + format!(" {k}:{v}")
// kinda jank because SmolStr is immutable
std::iter::once(prefix)
.chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
.collect::<SmolStr>()
.into()
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
/// <https://swagger.io/docs/specification/serialization/> DeepObject format
@@ -418,3 +412,20 @@ pub fn neon_option(bytes: &str) -> Option<(&str, &str)> {
let (_, [k, v]) = cap.extract();
Some((k, v))
}
impl std::fmt::Display for NeonOptions {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut space = false;
for (k, v) in &self.0 {
if space {
f.write_str(" ")?;
} else {
space = true;
}
f.write_str(k)?;
f.write_str(":")?;
f.write_str(v)?;
}
Ok(())
}
}

View File

@@ -16,8 +16,9 @@ use crate::console::messages::MetricsAuxInfo;
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend};
use crate::console::{self, CachedNodeInfo, NodeInfo};
use crate::error::ErrorKind;
use crate::intern::EndpointIdInt;
use crate::proxy::retry::retry_after;
use crate::{http, sasl, scram, BranchId, EndpointId, ProjectId};
use crate::{http, sasl, scram, BranchId, EndpointCacheKey, EndpointId, ProjectId};
use anyhow::{bail, Context};
use async_trait::async_trait;
use rstest::rstest;
@@ -530,9 +531,12 @@ async fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> Cached
},
allow_self_signed_compute: false,
};
cache.insert("key".into(), node.clone()).await;
let ep: EndpointId = "key".into();
let ep = EndpointIdInt::from(ep);
let key = EndpointCacheKey::from(ep);
cache.insert(key.clone(), node.clone()).await;
CachedNodeInfo {
token: Some((cache, "key".into())),
token: Some((cache, key)),
value: node,
}
}

View File

@@ -61,7 +61,7 @@ impl fmt::Display for ConnInfo {
self.user_info.user,
self.user_info.endpoint,
self.dbname,
self.user_info.options.get_cache_key("")
self.user_info.options
)
}
}