From 108f08f982c1add03311ad0dd2d60be004b2ceee Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 28 Jun 2024 09:04:54 +0100 Subject: [PATCH] proxy: cache a compressed version of the node info --- proxy/src/auth/backend.rs | 2 +- proxy/src/auth/credentials.rs | 6 +++-- proxy/src/cache/common.rs | 9 ++++++++ proxy/src/compute.rs | 2 +- proxy/src/console/provider.rs | 35 ++++++++++++++++++++++++---- proxy/src/console/provider/neon.rs | 37 ++++++++++++++---------------- proxy/src/lib.rs | 14 +++++++++-- proxy/src/proxy.rs | 16 ++++++++++--- proxy/src/proxy/connect_compute.rs | 4 ++-- proxy/src/proxy/tests.rs | 21 +++++++++++------ proxy/src/serverless/backend.rs | 4 ++-- proxy/src/serverless/conn_pool.rs | 2 +- 12 files changed, 107 insertions(+), 45 deletions(-) diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index f757a15fbb..229524f8ca 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -153,7 +153,7 @@ pub struct ComputeUserInfo { impl ComputeUserInfo { pub fn endpoint_cache_key(&self) -> EndpointCacheKey { - self.options.get_cache_key(&self.endpoint) + self.options.get_cache_key((&self.endpoint).into()) } } diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index d06f5614f1..4a622507c1 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -241,6 +241,8 @@ fn project_name_valid(name: &str) -> bool { #[cfg(test)] mod tests { + use crate::intern::EndpointIdInt; + use super::*; use serde_json::json; use ComputeUserInfoParseError::*; @@ -284,7 +286,6 @@ mod tests { ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?; assert_eq!(user_info.user, "john_doe"); assert_eq!(user_info.endpoint_id.as_deref(), Some("foo")); - assert_eq!(user_info.options.get_cache_key("foo"), "foo"); Ok(()) } @@ -442,8 +443,9 @@ mod tests { let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?; 
assert_eq!(user_info.endpoint_id.as_deref(), Some("project")); + let project = EndpointIdInt::from(EndpointId::from("project")); assert_eq!( - user_info.options.get_cache_key("project"), + user_info.options.get_cache_key(project).to_string(), "project endpoint_type:read_write lsn:0/2" ); diff --git a/proxy/src/cache/common.rs b/proxy/src/cache/common.rs index bc1c37512b..43444d21c8 100644 --- a/proxy/src/cache/common.rs +++ b/proxy/src/cache/common.rs @@ -43,6 +43,15 @@ impl Cached { Self { token: None, value } } + /// Transform the wrapped value with `f`, keeping the invalidation token unchanged. + pub fn map(self, f: impl FnOnce(V) -> U) -> Cached { + let token = self.token; + Cached { + token, + value: f(self.value), + } + } + pub fn take_value(self) -> (Cached, V) { ( Cached { diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index a50a96e5e8..2a6879e441 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -93,7 +93,7 @@ pub type ScramKeys = tokio_postgres::config::ScramKeys<32>; /// Eventually, `tokio_postgres` will be replaced with something better. /// Newtype allows us to implement methods on top of it. #[derive(Clone, Default)] -pub struct ConnCfg(Box); +pub struct ConnCfg(tokio_postgres::Config); /// Creation and initialization routines. 
impl ConnCfg { diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index bec55a8343..4fdda58dd2 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -9,14 +9,14 @@ use crate::{ IpPattern, }, cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru}, - compute, + compute::{self, ConnCfg}, config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions}, context::RequestMonitoring, error::ReportableError, intern::ProjectIdInt, metrics::ApiLockMetrics, rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token}, - scram, EndpointCacheKey, + scram, EndpointCacheKey, Host, }; use dashmap::DashMap; use std::{hash::Hash, sync::Arc, time::Duration}; @@ -289,6 +289,33 @@ pub struct NodeInfo { pub allow_self_signed_compute: bool, } +/// Cached info for establishing a connection to a compute node. +#[derive(Clone)] +pub struct NodeCachedInfo { + pub host: Host, + pub port: u16, + + /// Labels for proxy's metrics. 
+ pub aux: MetricsAuxInfo, + + /// Whether we should accept self-signed certificates (for testing) + pub allow_self_signed_compute: bool, +} + +impl NodeCachedInfo { + pub fn into_node_info(self) -> NodeInfo { + let mut config = ConnCfg::default(); + config.ssl_mode(tokio_postgres::config::SslMode::Disable); + config.host(&self.host); + config.port(self.port); + NodeInfo { + config, + aux: self.aux, + allow_self_signed_compute: self.allow_self_signed_compute, + } + } +} + impl NodeInfo { pub async fn connect( &self, @@ -317,8 +344,8 @@ impl NodeInfo { } } -pub type NodeInfoCache = TimedLru; -pub type CachedNodeInfo = Cached<&'static NodeInfoCache>; +pub type NodeInfoCache = TimedLru; +pub type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option>; pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc>>; diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index 41bd2f4956..5bc5fa1768 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -4,22 +4,20 @@ use super::{ super::messages::{ConsoleError, GetRoleSecret, WakeCompute}, errors::{ApiError, GetAuthInfoError, WakeComputeError}, ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, - NodeInfo, + NodeCachedInfo, }; use crate::{ auth::backend::ComputeUserInfo, - compute, console::messages::ColdStartInfo, http, metrics::{CacheOutcome, Metrics}, rate_limiter::EndpointRateLimiter, - scram, EndpointCacheKey, + scram, EndpointCacheKey, Host, }; use crate::{cache::Cached, context::RequestMonitoring}; use futures::TryFutureExt; use std::sync::Arc; use tokio::time::Instant; -use tokio_postgres::config::SslMode; use tracing::{error, info, info_span, warn, Instrument}; pub struct Api { @@ -132,7 +130,7 @@ impl Api { &self, ctx: &mut RequestMonitoring, user_info: &ComputeUserInfo, - ) -> Result { + ) -> Result { let 
request_id = ctx.session_id.to_string(); let application_name = ctx.console_application_name(); async { @@ -167,15 +165,11 @@ impl Api { None => return Err(WakeComputeError::BadComputeAddress(body.address)), Some(x) => x, }; + let host = Host(host.into()); - // Don't set anything but host and port! This config will be cached. - // We'll set username and such later using the startup message. - // TODO: add more type safety (in progress). - let mut config = compute::ConnCfg::new(); - config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes. - - let node = NodeInfo { - config, + let node = NodeCachedInfo { + host, + port, aux: body.aux, allow_self_signed_compute: false, }; @@ -278,9 +272,9 @@ impl super::Api for Api { // The connection info remains the same during that period of time, // which means that we might cache it to reduce the load and latency. if let Some(cached) = self.caches.node_info.get(&key) { - info!(key = &*key, "found cached compute node info"); + info!(key = display(&key), "found cached compute node info"); ctx.set_project(cached.aux.clone()); - return Ok(cached); + return Ok(cached.map(NodeCachedInfo::into_node_info)); } let permit = self.locks.get_permit(&key).await?; @@ -289,9 +283,9 @@ impl super::Api for Api { // double check if permit.should_check_cache() { if let Some(cached) = self.caches.node_info.get(&key) { - info!(key = &*key, "found cached compute node info"); + info!(key = display(&key), "found cached compute node info"); ctx.set_project(cached.aux.clone()); - return Ok(cached); + return Ok(cached.map(NodeCachedInfo::into_node_info)); } } @@ -300,7 +294,7 @@ impl super::Api for Api { .wake_compute_endpoint_rate_limiter .check(user_info.endpoint.normalize_intern(), 1) { - info!(key = &*key, "found cached compute node info"); + info!(key = display(&key), "found cached compute node info"); return Err(WakeComputeError::TooManyConnections); } @@ -314,9 +308,12 @@ impl super::Api for Api { let (_, 
mut cached) = self.caches.node_info.insert(key.clone(), node); cached.aux.cold_start_info = cold_start_info; - info!(key = &*key, "created a cache entry for compute node info"); + info!( + key = display(&key), + "created a cache entry for compute node info" + ); - Ok(cached) + Ok(cached.map(NodeCachedInfo::into_node_info)) } } diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index ea92eaaa55..78f0c350a9 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -157,8 +157,18 @@ smol_str_wrapper!(BranchId); // 90% of project strings are 23 characters or less. smol_str_wrapper!(ProjectId); -// will usually equal endpoint ID -smol_str_wrapper!(EndpointCacheKey); +// key:value neon_option fields +smol_str_wrapper!(EndpointCacheKeyExtra); +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub struct EndpointCacheKey { + pub id: EndpointIdInt, + pub extra: EndpointCacheKeyExtra, +} +impl std::fmt::Display for EndpointCacheKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}{}", &self.id, &self.extra) + } +} smol_str_wrapper!(DbName); diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 3edefcf21a..84a0aae652 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -10,6 +10,8 @@ pub mod wake_compute; pub use copy_bidirectional::copy_bidirectional_client_compute; pub use copy_bidirectional::ErrorSource; +use crate::intern::EndpointIdInt; +use crate::EndpointCacheKeyExtra; use crate::{ auth, cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal}, @@ -404,11 +406,19 @@ impl NeonOptions { Self(options) } - pub fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey { + pub fn get_cache_key(&self, endpoint: EndpointIdInt) -> EndpointCacheKey { + EndpointCacheKey { + id: endpoint, + extra: self.get_cache_key_extras(), + } + } + + pub fn get_cache_key_extras(&self) -> EndpointCacheKeyExtra { // prefix + format!(" {k}:{v}") // kinda jank because SmolStr is immutable - std::iter::once(prefix) - 
.chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v])) + self.0 + .iter() + .flat_map(|(k, v)| [" ", &**k, ":", &**v]) .collect::() .into() } diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 82180aaee3..d81500bcd4 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -47,7 +47,7 @@ pub trait ConnectMechanism { async fn connect_once( &self, ctx: &mut RequestMonitoring, - node_info: &console::CachedNodeInfo, + node_info: &NodeInfo, timeout: time::Duration, ) -> Result; @@ -82,7 +82,7 @@ impl ConnectMechanism for TcpMechanism<'_> { async fn connect_once( &self, ctx: &mut RequestMonitoring, - node_info: &console::CachedNodeInfo, + node_info: &NodeInfo, timeout: time::Duration, ) -> Result { let host = node_info.config.get_host()?; diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 8119f39fae..fc93512805 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -13,8 +13,10 @@ use crate::auth::backend::{ use crate::config::{CertResolver, RetryConfig}; use crate::console::caches::NodeInfoCache; use crate::console::messages::{ConsoleError, Details, MetricsAuxInfo, Status}; -use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend}; -use crate::console::{self, CachedNodeInfo, NodeInfo}; +use crate::console::provider::{ + CachedAllowedIps, CachedRoleSecret, ConsoleBackend, NodeCachedInfo, +}; +use crate::console::{self, CachedNodeInfo}; use crate::error::ErrorKind; use crate::{http, sasl, scram, BranchId, EndpointId, ProjectId}; use anyhow::{bail, Context}; @@ -458,7 +460,7 @@ impl ConnectMechanism for TestConnectMechanism { async fn connect_once( &self, _ctx: &mut RequestMonitoring, - _node_info: &console::CachedNodeInfo, + _node_info: &console::NodeInfo, _timeout: std::time::Duration, ) -> Result { let mut counter = self.counter.lock().unwrap(); @@ -530,8 +532,9 @@ impl TestBackend for TestConnectMechanism { } fn 
helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo { - let node = NodeInfo { - config: compute::ConnCfg::new(), + let node = NodeCachedInfo { + host: "localhost".into(), + port: 5432, aux: MetricsAuxInfo { endpoint_id: (&EndpointId::from("endpoint")).into(), project_id: (&ProjectId::from("project")).into(), @@ -540,8 +543,12 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn }, allow_self_signed_compute: false, }; - let (_, node) = cache.insert("key".into(), node); - node + let key = EndpointCacheKey { + id: node.aux.endpoint_id, + extra: "".into(), + }; + let (_, node) = cache.insert(key, node); + node.map(NodeCachedInfo::into_node_info) } fn helper_create_connect_info( diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 6c34d48338..6bac80fcc4 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -11,7 +11,7 @@ use crate::{ errors::{GetAuthInfoError, WakeComputeError}, locks::ApiLocks, provider::ApiLockError, - CachedNodeInfo, + NodeInfo, }, context::RequestMonitoring, error::{ErrorKind, ReportableError, UserFacingError}, @@ -223,7 +223,7 @@ impl ConnectMechanism for TokioMechanism { async fn connect_once( &self, ctx: &mut RequestMonitoring, - node_info: &CachedNodeInfo, + node_info: &NodeInfo, timeout: Duration, ) -> Result { let host = node_info.config.get_host()?; diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 170bda062e..923f7d78d2 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -61,7 +61,7 @@ impl fmt::Display for ConnInfo { self.user_info.user, self.user_info.endpoint, self.dbname, - self.user_info.options.get_cache_key("") + self.user_info.options.get_cache_key_extras() ) } }