From ca578449e473a23cd7e7aaa7e9cc33ec58b72a13 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Tue, 7 May 2024 08:34:21 +0100 Subject: [PATCH] simplify Cache invalidate trait, reduce EndpointCacheKey --- proxy/src/auth/backend.rs | 8 +++++++- proxy/src/auth/credentials.rs | 6 +++--- proxy/src/cache/common.rs | 26 ++++++++---------------- proxy/src/cache/project_info.rs | 9 ++------- proxy/src/console/provider.rs | 2 +- proxy/src/console/provider/neon.rs | 6 +++--- proxy/src/lib.rs | 32 +++++++++++++++++++++++++++--- proxy/src/proxy.rs | 27 +++++++++++++++++-------- proxy/src/proxy/tests.rs | 10 +++++++--- proxy/src/serverless/conn_pool.rs | 2 +- 10 files changed, 80 insertions(+), 48 deletions(-) diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index dd80f502eb..bc34636212 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -155,7 +155,13 @@ pub struct ComputeUserInfo { impl ComputeUserInfo { pub fn endpoint_cache_key(&self) -> EndpointCacheKey { - self.options.get_cache_key(&self.endpoint) + let id = EndpointIdInt::from(&self.endpoint); + let key = EndpointCacheKey::from(id); + if self.options.is_empty() { + key + } else { + key.with_options(self.options.to_string()) + } } } diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 783a1a5a21..819409a751 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -292,7 +292,7 @@ mod tests { ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?; assert_eq!(user_info.user, "john_doe"); assert_eq!(user_info.endpoint_id.as_deref(), Some("foo")); - assert_eq!(user_info.options.get_cache_key("foo"), "foo"); + assert_eq!(user_info.options.to_string(), ""); Ok(()) } @@ -451,8 +451,8 @@ mod tests { ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?; assert_eq!(user_info.endpoint_id.as_deref(), Some("project")); assert_eq!( - user_info.options.get_cache_key("project"), - "project endpoint_type:read_write lsn:0/2" + user_info.options.to_string(), + "endpoint_type:read_write lsn:0/2" ); Ok(()) diff --git a/proxy/src/cache/common.rs b/proxy/src/cache/common.rs index f228e7c93f..2d6c3d5123 100644 --- a/proxy/src/cache/common.rs +++ b/proxy/src/cache/common.rs @@ -5,34 +5,26 @@ use std::ops::{Deref, DerefMut}; /// This is useful for [`Cached`]. #[allow(async_fn_in_trait)] pub trait Cache { - /// Entry's key. - type Key; - - /// Entry's value. - type Value; - /// Used for entry invalidation. - type LookupInfo; + type LookupInfo; /// Invalidate an entry using a lookup info. /// We don't have an empty default impl because it's error-prone. - async fn invalidate(&self, _: &Self::LookupInfo); + async fn invalidate(&self, _: &Self::LookupInfo); } impl Cache for &C { - type Key = C::Key; - type Value = C::Value; - type LookupInfo = C::LookupInfo; + type LookupInfo = C::LookupInfo; - async fn invalidate(&self, info: &Self::LookupInfo) { + async fn invalidate(&self, info: &Self::LookupInfo) { C::invalidate(self, info).await } } /// Wrapper for convenient entry invalidation. -pub struct Cached::Value> { +pub struct Cached { /// Cache + lookup info. - pub token: Option<(C, C::LookupInfo)>, + pub token: Option<(C, C::LookupInfo)>, /// The value itself. pub value: V, @@ -88,11 +80,9 @@ where V: Clone + Send + Sync + 'static, S: std::hash::BuildHasher + Clone + Send + Sync + 'static, { - type Key = K; - type Value = V; - type LookupInfo = Key; + type LookupInfo = K; - async fn invalidate(&self, key: &Self::LookupInfo) { + async fn invalidate(&self, key: &Self::LookupInfo) { moka::future::Cache::invalidate(self, key).await } } diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 4820c7ec74..cb08a211ad 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -8,7 +8,6 @@ use std::{ use async_trait::async_trait; use dashmap::DashMap; use rand::{thread_rng, Rng}; -use smol_str::SmolStr; use tokio::sync::Mutex; use tokio::time::Instant; use tracing::{debug, info}; @@ -346,13 +345,9 @@ enum LookupType { } impl Cache for ProjectInfoCacheImpl { - type Key = SmolStr; - // Value is not really used here, but we need to specify it. - type Value = SmolStr; + type LookupInfo = CachedLookupInfo; - type LookupInfo = CachedLookupInfo; - - async fn invalidate(&self, key: &Self::LookupInfo) { + async fn invalidate(&self, key: &Self::LookupInfo) { match &key.lookup_type { LookupType::RoleSecret(role_name) => { if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) { diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index 1b287f39d8..f5d9a2825f 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -327,7 +327,7 @@ impl NodeInfo { } pub type NodeInfoCache = moka::future::Cache; -pub type CachedNodeInfo = Cached<&'static NodeInfoCache>; +pub type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option>; pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc>>; diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index 5ebea76728..e9b4e6cf62 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -276,7 +276,7 @@ impl super::Api for Api { // The connection info remains the same during that period of time, // which means that we might cache it to reduce the load and latency. if let Some(cached) = self.caches.node_info.get(&key).await { - info!(key = &*key, "found cached compute node info"); + info!(key = %key, "found cached compute node info"); ctx.set_project(cached.aux.clone()); return Ok(CachedNodeInfo { token: Some((&self.caches.node_info, key)), @@ -298,7 +298,7 @@ impl super::Api for Api { // double check if permit.should_check_cache() { if let Some(cached) = self.caches.node_info.get(&key).await { - info!(key = &*key, "found cached compute node info"); + info!(key = %key, "found cached compute node info"); ctx.set_project(cached.aux.clone()); return Ok(CachedNodeInfo { token: Some((&self.caches.node_info, key)), @@ -320,7 +320,7 @@ impl super::Api for Api { .await; node.aux.cold_start_info = cold_start_info; - info!(key = &*key, "created a cache entry for compute node info"); + info!(key = %key, "created a cache entry for compute node info"); Ok(CachedNodeInfo { token: Some((&self.caches.node_info, key)), diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index 35c1616481..008f5b1336 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -154,9 +154,6 @@ smol_str_wrapper!(BranchId); // 90% of project strings are 23 characters or less. smol_str_wrapper!(ProjectId); -// will usually equal endpoint ID -smol_str_wrapper!(EndpointCacheKey); - smol_str_wrapper!(DbName); // postgres hostname, will likely be a port:ip addr @@ -180,3 +177,32 @@ impl EndpointId { ProjectId(self.0.clone()) } } + +#[derive(Hash, PartialEq, Eq, Debug, Clone)] +pub struct EndpointCacheKey { + endpoint: intern::EndpointIdInt, + options: Option, +} + +impl std::fmt::Display for EndpointCacheKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.endpoint.as_str())?; + if let Some(options) = &self.options { + f.write_str(" ")?; + f.write_str(options)?; + } + Ok(()) + } +} + +impl From for EndpointCacheKey { + fn from(value: intern::EndpointIdInt) -> Self { + Self { endpoint: value, options: None } + } +} +impl EndpointCacheKey { + pub fn with_options(mut self, options: String) -> Self { + self.options = Some(options); + self + } +} diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index e4e095d77d..8f2524c07f 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -20,7 +20,6 @@ use crate::{ protocol2::read_proxy_protocol, proxy::handshake::{handshake, HandshakeData}, stream::{PqStream, Stream}, - EndpointCacheKey, }; use futures::TryFutureExt; use itertools::Itertools; @@ -391,13 +390,8 @@ impl NeonOptions { Self(options) } - pub fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey { - // prefix + format!(" {k}:{v}") - // kinda jank because SmolStr is immutable - std::iter::once(prefix) - .chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v])) - .collect::() - .into() + pub fn is_empty(&self) -> bool { + self.0.is_empty() } /// DeepObject format @@ -418,3 +412,20 @@ pub fn neon_option(bytes: &str) -> Option<(&str, &str)> { let (_, [k, v]) = cap.extract(); Some((k, v)) } + +impl std::fmt::Display for NeonOptions { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut space = false; + for (k, v) in &self.0 { + if space { + f.write_str(" ")?; + } else { + space = true; + } + f.write_str(k)?; + f.write_str(":")?; + f.write_str(v)?; + } + Ok(()) + } +} diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 0424dceb27..632e35f153 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -16,8 +16,9 @@ use crate::console::messages::MetricsAuxInfo; use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend}; use crate::console::{self, CachedNodeInfo, NodeInfo}; use crate::error::ErrorKind; +use crate::intern::EndpointIdInt; use crate::proxy::retry::retry_after; -use crate::{http, sasl, scram, BranchId, EndpointId, ProjectId}; +use crate::{http, sasl, scram, BranchId, EndpointCacheKey, EndpointId, ProjectId}; use anyhow::{bail, Context}; use async_trait::async_trait; use rstest::rstest; @@ -530,9 +531,12 @@ async fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> Cached }, allow_self_signed_compute: false, }; - cache.insert("key".into(), node.clone()).await; + let ep: EndpointId = "key".into(); + let ep = EndpointIdInt::from(ep); + let key = EndpointCacheKey::from(ep); + cache.insert(key.clone(), node.clone()).await; CachedNodeInfo { - token: Some((cache, "key".into())), + token: Some((cache, key)), value: node, } } diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 798e488509..b6b8455436 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -61,7 +61,7 @@ impl fmt::Display for ConnInfo { self.user_info.user, self.user_info.endpoint, self.dbname, - self.user_info.options.get_cache_key("") + self.user_info.options ) } }