Compare commits

...

2 Commits

Author SHA1 Message Date
Conrad Ludgate
4c78a5067f compress cache key 2024-06-28 09:12:18 +01:00
Conrad Ludgate
108f08f982 proxy: cache a compressed version of the node info 2024-06-28 09:04:54 +01:00
12 changed files with 107 additions and 49 deletions

View File

@@ -153,7 +153,7 @@ pub struct ComputeUserInfo {
impl ComputeUserInfo { impl ComputeUserInfo {
pub fn endpoint_cache_key(&self) -> EndpointCacheKey { pub fn endpoint_cache_key(&self) -> EndpointCacheKey {
self.options.get_cache_key(&self.endpoint) self.options.get_cache_key((&self.endpoint).into())
} }
} }

View File

@@ -241,6 +241,8 @@ fn project_name_valid(name: &str) -> bool {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::intern::EndpointIdInt;
use super::*; use super::*;
use serde_json::json; use serde_json::json;
use ComputeUserInfoParseError::*; use ComputeUserInfoParseError::*;
@@ -284,7 +286,6 @@ mod tests {
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?; ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.user, "john_doe"); assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.endpoint_id.as_deref(), Some("foo")); assert_eq!(user_info.endpoint_id.as_deref(), Some("foo"));
assert_eq!(user_info.options.get_cache_key("foo"), "foo");
Ok(()) Ok(())
} }
@@ -442,8 +443,9 @@ mod tests {
let user_info = let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?; ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.endpoint_id.as_deref(), Some("project")); assert_eq!(user_info.endpoint_id.as_deref(), Some("project"));
let project = EndpointIdInt::from(EndpointId::from("project"));
assert_eq!( assert_eq!(
user_info.options.get_cache_key("project"), user_info.options.get_cache_key(project).to_string(),
"project endpoint_type:read_write lsn:0/2" "project endpoint_type:read_write lsn:0/2"
); );

View File

@@ -43,6 +43,15 @@ impl<C: Cache, V> Cached<C, V> {
Self { token: None, value } Self { token: None, value }
} }
/// Place any entry into this wrapper; invalidation will be a no-op.
pub fn map<U>(self, f: impl FnOnce(V) -> U) -> Cached<C, U> {
let token = self.token;
Cached {
token,
value: f(self.value),
}
}
pub fn take_value(self) -> (Cached<C, ()>, V) { pub fn take_value(self) -> (Cached<C, ()>, V) {
( (
Cached { Cached {

View File

@@ -93,7 +93,7 @@ pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
/// Eventually, `tokio_postgres` will be replaced with something better. /// Eventually, `tokio_postgres` will be replaced with something better.
/// Newtype allows us to implement methods on top of it. /// Newtype allows us to implement methods on top of it.
#[derive(Clone, Default)] #[derive(Clone, Default)]
pub struct ConnCfg(Box<tokio_postgres::Config>); pub struct ConnCfg(tokio_postgres::Config);
/// Creation and initialization routines. /// Creation and initialization routines.
impl ConnCfg { impl ConnCfg {

View File

@@ -9,14 +9,14 @@ use crate::{
IpPattern, IpPattern,
}, },
cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru}, cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
compute, compute::{self, ConnCfg},
config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions}, config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions},
context::RequestMonitoring, context::RequestMonitoring,
error::ReportableError, error::ReportableError,
intern::ProjectIdInt, intern::ProjectIdInt,
metrics::ApiLockMetrics, metrics::ApiLockMetrics,
rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token}, rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token},
scram, EndpointCacheKey, scram, EndpointCacheKey, Host,
}; };
use dashmap::DashMap; use dashmap::DashMap;
use std::{hash::Hash, sync::Arc, time::Duration}; use std::{hash::Hash, sync::Arc, time::Duration};
@@ -289,6 +289,33 @@ pub struct NodeInfo {
pub allow_self_signed_compute: bool, pub allow_self_signed_compute: bool,
} }
/// Cached info for establishing a connection to a compute node.
#[derive(Clone)]
pub struct NodeCachedInfo {
pub host: Host,
pub port: u16,
/// Labels for proxy's metrics.
pub aux: MetricsAuxInfo,
/// Whether we should accept self-signed certificates (for testing)
pub allow_self_signed_compute: bool,
}
impl NodeCachedInfo {
pub fn into_node_info(self) -> NodeInfo {
let mut config = ConnCfg::default();
config.ssl_mode(tokio_postgres::config::SslMode::Disable);
config.host(&self.host);
config.port(self.port);
NodeInfo {
config,
aux: self.aux,
allow_self_signed_compute: self.allow_self_signed_compute,
}
}
}
impl NodeInfo { impl NodeInfo {
pub async fn connect( pub async fn connect(
&self, &self,
@@ -317,8 +344,8 @@ impl NodeInfo {
} }
} }
pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>; pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeCachedInfo>;
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>; pub type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>; pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>; pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;

View File

@@ -4,22 +4,20 @@ use super::{
super::messages::{ConsoleError, GetRoleSecret, WakeCompute}, super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
errors::{ApiError, GetAuthInfoError, WakeComputeError}, errors::{ApiError, GetAuthInfoError, WakeComputeError},
ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret,
NodeInfo, NodeCachedInfo,
}; };
use crate::{ use crate::{
auth::backend::ComputeUserInfo, auth::backend::ComputeUserInfo,
compute,
console::messages::ColdStartInfo, console::messages::ColdStartInfo,
http, http,
metrics::{CacheOutcome, Metrics}, metrics::{CacheOutcome, Metrics},
rate_limiter::EndpointRateLimiter, rate_limiter::EndpointRateLimiter,
scram, EndpointCacheKey, scram, EndpointCacheKey, Host,
}; };
use crate::{cache::Cached, context::RequestMonitoring}; use crate::{cache::Cached, context::RequestMonitoring};
use futures::TryFutureExt; use futures::TryFutureExt;
use std::sync::Arc; use std::sync::Arc;
use tokio::time::Instant; use tokio::time::Instant;
use tokio_postgres::config::SslMode;
use tracing::{error, info, info_span, warn, Instrument}; use tracing::{error, info, info_span, warn, Instrument};
pub struct Api { pub struct Api {
@@ -132,7 +130,7 @@ impl Api {
&self, &self,
ctx: &mut RequestMonitoring, ctx: &mut RequestMonitoring,
user_info: &ComputeUserInfo, user_info: &ComputeUserInfo,
) -> Result<NodeInfo, WakeComputeError> { ) -> Result<NodeCachedInfo, WakeComputeError> {
let request_id = ctx.session_id.to_string(); let request_id = ctx.session_id.to_string();
let application_name = ctx.console_application_name(); let application_name = ctx.console_application_name();
async { async {
@@ -167,15 +165,11 @@ impl Api {
None => return Err(WakeComputeError::BadComputeAddress(body.address)), None => return Err(WakeComputeError::BadComputeAddress(body.address)),
Some(x) => x, Some(x) => x,
}; };
let host = Host(host.into());
// Don't set anything but host and port! This config will be cached. let node = NodeCachedInfo {
// We'll set username and such later using the startup message. host,
// TODO: add more type safety (in progress). port,
let mut config = compute::ConnCfg::new();
config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
let node = NodeInfo {
config,
aux: body.aux, aux: body.aux,
allow_self_signed_compute: false, allow_self_signed_compute: false,
}; };
@@ -278,9 +272,9 @@ impl super::Api for Api {
// The connection info remains the same during that period of time, // The connection info remains the same during that period of time,
// which means that we might cache it to reduce the load and latency. // which means that we might cache it to reduce the load and latency.
if let Some(cached) = self.caches.node_info.get(&key) { if let Some(cached) = self.caches.node_info.get(&key) {
info!(key = &*key, "found cached compute node info"); info!(key = display(&key), "found cached compute node info");
ctx.set_project(cached.aux.clone()); ctx.set_project(cached.aux.clone());
return Ok(cached); return Ok(cached.map(NodeCachedInfo::into_node_info));
} }
let permit = self.locks.get_permit(&key).await?; let permit = self.locks.get_permit(&key).await?;
@@ -289,9 +283,9 @@ impl super::Api for Api {
// double check // double check
if permit.should_check_cache() { if permit.should_check_cache() {
if let Some(cached) = self.caches.node_info.get(&key) { if let Some(cached) = self.caches.node_info.get(&key) {
info!(key = &*key, "found cached compute node info"); info!(key = display(&key), "found cached compute node info");
ctx.set_project(cached.aux.clone()); ctx.set_project(cached.aux.clone());
return Ok(cached); return Ok(cached.map(NodeCachedInfo::into_node_info));
} }
} }
@@ -300,7 +294,7 @@ impl super::Api for Api {
.wake_compute_endpoint_rate_limiter .wake_compute_endpoint_rate_limiter
.check(user_info.endpoint.normalize_intern(), 1) .check(user_info.endpoint.normalize_intern(), 1)
{ {
info!(key = &*key, "found cached compute node info"); info!(key = display(&key), "found cached compute node info");
return Err(WakeComputeError::TooManyConnections); return Err(WakeComputeError::TooManyConnections);
} }
@@ -314,9 +308,12 @@ impl super::Api for Api {
let (_, mut cached) = self.caches.node_info.insert(key.clone(), node); let (_, mut cached) = self.caches.node_info.insert(key.clone(), node);
cached.aux.cold_start_info = cold_start_info; cached.aux.cold_start_info = cold_start_info;
info!(key = &*key, "created a cache entry for compute node info"); info!(
key = display(&key),
"created a cache entry for compute node info"
);
Ok(cached) Ok(cached.map(NodeCachedInfo::into_node_info))
} }
} }

View File

@@ -157,8 +157,16 @@ smol_str_wrapper!(BranchId);
// 90% of project strings are 23 characters or less. // 90% of project strings are 23 characters or less.
smol_str_wrapper!(ProjectId); smol_str_wrapper!(ProjectId);
// will usually equal endpoint ID #[derive(PartialEq, Eq, Hash, Debug, Clone)]
smol_str_wrapper!(EndpointCacheKey); pub struct EndpointCacheKey {
pub id: EndpointIdInt,
pub extra: Box<str>,
}
impl std::fmt::Display for EndpointCacheKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}{}", &self.id, &self.extra)
}
}
smol_str_wrapper!(DbName); smol_str_wrapper!(DbName);

View File

@@ -10,6 +10,7 @@ pub mod wake_compute;
pub use copy_bidirectional::copy_bidirectional_client_compute; pub use copy_bidirectional::copy_bidirectional_client_compute;
pub use copy_bidirectional::ErrorSource; pub use copy_bidirectional::ErrorSource;
use crate::intern::EndpointIdInt;
use crate::{ use crate::{
auth, auth,
cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal}, cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal},
@@ -404,13 +405,20 @@ impl NeonOptions {
Self(options) Self(options)
} }
pub fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey { pub fn get_cache_key(&self, endpoint: EndpointIdInt) -> EndpointCacheKey {
// prefix + format!(" {k}:{v}") EndpointCacheKey {
// kinda jank because SmolStr is immutable id: endpoint,
std::iter::once(prefix) extra: self.get_cache_key_extras(),
.chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v])) }
.collect::<SmolStr>() }
.into()
pub fn get_cache_key_extras(&self) -> Box<str> {
let mut extras = String::new();
for (k, v) in &self.0 {
use std::fmt::Write;
write!(&mut extras, " {k}:{v}").unwrap();
}
extras.into_boxed_str()
} }
/// <https://swagger.io/docs/specification/serialization/> DeepObject format /// <https://swagger.io/docs/specification/serialization/> DeepObject format

View File

@@ -47,7 +47,7 @@ pub trait ConnectMechanism {
async fn connect_once( async fn connect_once(
&self, &self,
ctx: &mut RequestMonitoring, ctx: &mut RequestMonitoring,
node_info: &console::CachedNodeInfo, node_info: &NodeInfo,
timeout: time::Duration, timeout: time::Duration,
) -> Result<Self::Connection, Self::ConnectError>; ) -> Result<Self::Connection, Self::ConnectError>;
@@ -82,7 +82,7 @@ impl ConnectMechanism for TcpMechanism<'_> {
async fn connect_once( async fn connect_once(
&self, &self,
ctx: &mut RequestMonitoring, ctx: &mut RequestMonitoring,
node_info: &console::CachedNodeInfo, node_info: &NodeInfo,
timeout: time::Duration, timeout: time::Duration,
) -> Result<PostgresConnection, Self::Error> { ) -> Result<PostgresConnection, Self::Error> {
let host = node_info.config.get_host()?; let host = node_info.config.get_host()?;

View File

@@ -13,8 +13,10 @@ use crate::auth::backend::{
use crate::config::{CertResolver, RetryConfig}; use crate::config::{CertResolver, RetryConfig};
use crate::console::caches::NodeInfoCache; use crate::console::caches::NodeInfoCache;
use crate::console::messages::{ConsoleError, Details, MetricsAuxInfo, Status}; use crate::console::messages::{ConsoleError, Details, MetricsAuxInfo, Status};
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend}; use crate::console::provider::{
use crate::console::{self, CachedNodeInfo, NodeInfo}; CachedAllowedIps, CachedRoleSecret, ConsoleBackend, NodeCachedInfo,
};
use crate::console::{self, CachedNodeInfo};
use crate::error::ErrorKind; use crate::error::ErrorKind;
use crate::{http, sasl, scram, BranchId, EndpointId, ProjectId}; use crate::{http, sasl, scram, BranchId, EndpointId, ProjectId};
use anyhow::{bail, Context}; use anyhow::{bail, Context};
@@ -458,7 +460,7 @@ impl ConnectMechanism for TestConnectMechanism {
async fn connect_once( async fn connect_once(
&self, &self,
_ctx: &mut RequestMonitoring, _ctx: &mut RequestMonitoring,
_node_info: &console::CachedNodeInfo, _node_info: &console::NodeInfo,
_timeout: std::time::Duration, _timeout: std::time::Duration,
) -> Result<Self::Connection, Self::ConnectError> { ) -> Result<Self::Connection, Self::ConnectError> {
let mut counter = self.counter.lock().unwrap(); let mut counter = self.counter.lock().unwrap();
@@ -530,8 +532,9 @@ impl TestBackend for TestConnectMechanism {
} }
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo { fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
let node = NodeInfo { let node = NodeCachedInfo {
config: compute::ConnCfg::new(), host: "localhost".into(),
port: 5432,
aux: MetricsAuxInfo { aux: MetricsAuxInfo {
endpoint_id: (&EndpointId::from("endpoint")).into(), endpoint_id: (&EndpointId::from("endpoint")).into(),
project_id: (&ProjectId::from("project")).into(), project_id: (&ProjectId::from("project")).into(),
@@ -540,8 +543,12 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn
}, },
allow_self_signed_compute: false, allow_self_signed_compute: false,
}; };
let (_, node) = cache.insert("key".into(), node); let key = EndpointCacheKey {
node id: node.aux.endpoint_id,
extra: "".into(),
};
let (_, node) = cache.insert(key, node);
node.map(NodeCachedInfo::into_node_info)
} }
fn helper_create_connect_info( fn helper_create_connect_info(

View File

@@ -11,7 +11,7 @@ use crate::{
errors::{GetAuthInfoError, WakeComputeError}, errors::{GetAuthInfoError, WakeComputeError},
locks::ApiLocks, locks::ApiLocks,
provider::ApiLockError, provider::ApiLockError,
CachedNodeInfo, NodeInfo,
}, },
context::RequestMonitoring, context::RequestMonitoring,
error::{ErrorKind, ReportableError, UserFacingError}, error::{ErrorKind, ReportableError, UserFacingError},
@@ -223,7 +223,7 @@ impl ConnectMechanism for TokioMechanism {
async fn connect_once( async fn connect_once(
&self, &self,
ctx: &mut RequestMonitoring, ctx: &mut RequestMonitoring,
node_info: &CachedNodeInfo, node_info: &NodeInfo,
timeout: Duration, timeout: Duration,
) -> Result<Self::Connection, Self::ConnectError> { ) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host()?; let host = node_info.config.get_host()?;

View File

@@ -61,7 +61,7 @@ impl fmt::Display for ConnInfo {
self.user_info.user, self.user_info.user,
self.user_info.endpoint, self.user_info.endpoint,
self.dbname, self.dbname,
self.user_info.options.get_cache_key("") self.user_info.options.get_cache_key_extras()
) )
} }
} }