From 210700d0d953b5f8045ab1ac17c064d65865f735 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Wed, 24 Jan 2024 16:38:10 +0000 Subject: [PATCH] proxy: add newtype wrappers for string based IDs (#6445) ## Problem too many string based IDs. easy to mix up ID types. ## Summary of changes Add a bunch of `SmolStr` wrappers that provide convenience methods but are type safe --- Cargo.lock | 4 +- proxy/src/auth/backend.rs | 19 ++++--- proxy/src/auth/credentials.rs | 59 +++++++++++---------- proxy/src/auth/password_hack.rs | 5 +- proxy/src/cache/project_info.rs | 64 +++++++++++----------- proxy/src/console/messages.rs | 11 ++-- proxy/src/console/provider.rs | 11 ++-- proxy/src/console/provider/neon.rs | 5 +- proxy/src/context.rs | 17 +++--- proxy/src/lib.rs | 76 +++++++++++++++++++++++++++ proxy/src/proxy.rs | 12 +++-- proxy/src/rate_limiter/limiter.rs | 15 +++--- proxy/src/redis/notifications.rs | 9 ++-- proxy/src/serverless/conn_pool.rs | 20 ++++--- proxy/src/serverless/sql_over_http.rs | 7 +-- proxy/src/usage_metrics.rs | 7 ++- 16 files changed, 215 insertions(+), 126 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 02a437ccf9..f2f31192f0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5113,9 +5113,9 @@ checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" [[package]] name = "smol_str" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74212e6bbe9a4352329b2f68ba3130c15a3f26fe88ff22dbdc6cdd58fa85e99c" +checksum = "e6845563ada680337a52d43bb0b29f396f2d911616f6573012645b9e3d048a49" dependencies = [ "serde", ] diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 1e03510119..b1634906c9 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -3,7 +3,6 @@ mod hacks; mod link; pub use link::LinkAuthError; -use smol_str::SmolStr; use tokio_postgres::config::AuthKeys; use crate::auth::credentials::check_peer_addr_is_in_list; @@ -16,7 +15,6 @@ use crate::context::RequestMonitoring; use crate::proxy::connect_compute::handle_try_wake; use crate::proxy::retry::retry_after; use crate::proxy::NeonOptions; -use crate::scram; use crate::stream::Stream; use crate::{ auth::{self, ComputeUserInfoMaybeEndpoint}, @@ -28,6 +26,7 @@ use crate::{ }, stream, url, }; +use crate::{scram, EndpointCacheKey, EndpointId, RoleName}; use futures::TryFutureExt; use std::borrow::Cow; use std::ops::ControlFlow; @@ -130,19 +129,19 @@ pub struct ComputeCredentials { #[derive(Debug, Clone)] pub struct ComputeUserInfoNoEndpoint { - pub user: SmolStr, + pub user: RoleName, pub options: NeonOptions, } #[derive(Debug, Clone)] pub struct ComputeUserInfo { - pub endpoint: SmolStr, - pub user: SmolStr, + pub endpoint: EndpointId, + pub user: RoleName, pub options: NeonOptions, } impl ComputeUserInfo { - pub fn endpoint_cache_key(&self) -> SmolStr { + pub fn endpoint_cache_key(&self) -> EndpointCacheKey { self.options.get_cache_key(&self.endpoint) } } @@ -158,7 +157,7 @@ impl TryFrom for ComputeUserInfo { type Error = ComputeUserInfoNoEndpoint; fn try_from(user_info: ComputeUserInfoMaybeEndpoint) -> Result { - match user_info.project { + match user_info.endpoint_id { None => Err(ComputeUserInfoNoEndpoint { user: user_info.user, options: user_info.options, @@ -317,11 +316,11 @@ async fn auth_and_wake_compute( impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { /// Get compute endpoint name from the credentials. - pub fn get_endpoint(&self) -> Option { + pub fn get_endpoint(&self) -> Option { use BackendType::*; match self { - Console(_, user_info) => user_info.project.clone(), + Console(_, user_info) => user_info.endpoint_id.clone(), Link(_) => Some("link".into()), #[cfg(test)] Test(_) => Some("test".into()), @@ -355,7 +354,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { Console(api, user_info) => { info!( user = &*user_info.user, - project = user_info.project(), + project = user_info.endpoint(), "performing authentication using the console" ); diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 342fd6fce9..bdb79f2517 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -2,7 +2,7 @@ use crate::{ auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError, - metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, + metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, EndpointId, RoleName, }; use itertools::Itertools; use pq_proto::StartupMessageParams; @@ -21,7 +21,10 @@ pub enum ComputeUserInfoParseError { SNI ('{}') and project option ('{}').", .domain, .option, )] - InconsistentProjectNames { domain: SmolStr, option: SmolStr }, + InconsistentProjectNames { + domain: EndpointId, + option: EndpointId, + }, #[error( "Common name inferred from SNI ('{}') is not known", @@ -30,7 +33,7 @@ pub enum ComputeUserInfoParseError { UnknownCommonName { cn: String }, #[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")] - MalformedProjectName(SmolStr), + MalformedProjectName(EndpointId), } impl UserFacingError for ComputeUserInfoParseError {} @@ -39,17 +42,15 @@ impl UserFacingError for ComputeUserInfoParseError {} /// Note that we don't store any kind of client key or password here. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ComputeUserInfoMaybeEndpoint { - pub user: SmolStr, - // TODO: this is a severe misnomer! We should think of a new name ASAP. - pub project: Option, - + pub user: RoleName, + pub endpoint_id: Option, pub options: NeonOptions, } impl ComputeUserInfoMaybeEndpoint { #[inline] - pub fn project(&self) -> Option<&str> { - self.project.as_deref() + pub fn endpoint(&self) -> Option<&str> { + self.endpoint_id.as_deref() } } @@ -79,15 +80,15 @@ impl ComputeUserInfoMaybeEndpoint { // Some parameters are stored in the startup message. let get_param = |key| params.get(key).ok_or(MissingKey(key)); - let user: SmolStr = get_param("user")?.into(); + let user: RoleName = get_param("user")?.into(); // record the values if we have them ctx.set_application(params.get("application_name").map(SmolStr::from)); ctx.set_user(user.clone()); - ctx.set_endpoint_id(sni.map(SmolStr::from)); + ctx.set_endpoint_id(sni.map(EndpointId::from)); // Project name might be passed via PG's command-line options. - let project_option = params + let endpoint_option = params .options_raw() .and_then(|options| { // We support both `project` (deprecated) and `endpoint` options for backward compatibility. @@ -100,9 +101,9 @@ impl ComputeUserInfoMaybeEndpoint { }) .map(|name| name.into()); - let project_from_domain = if let Some(sni_str) = sni { + let endpoint_from_domain = if let Some(sni_str) = sni { if let Some(cn) = common_names { - Some(SmolStr::from(endpoint_sni(sni_str, cn)?)) + Some(EndpointId::from(endpoint_sni(sni_str, cn)?)) } else { None } @@ -110,7 +111,7 @@ impl ComputeUserInfoMaybeEndpoint { None }; - let project = match (project_option, project_from_domain) { + let endpoint = match (endpoint_option, endpoint_from_domain) { // Invariant: if we have both project name variants, they should match. (Some(option), Some(domain)) if option != domain => { Some(Err(InconsistentProjectNames { domain, option })) @@ -123,13 +124,13 @@ impl ComputeUserInfoMaybeEndpoint { } .transpose()?; - info!(%user, project = project.as_deref(), "credentials"); + info!(%user, project = endpoint.as_deref(), "credentials"); if sni.is_some() { info!("Connection with sni"); NUM_CONNECTION_ACCEPTED_BY_SNI .with_label_values(&["sni"]) .inc(); - } else if project.is_some() { + } else if endpoint.is_some() { NUM_CONNECTION_ACCEPTED_BY_SNI .with_label_values(&["no_sni"]) .inc(); @@ -145,7 +146,7 @@ impl ComputeUserInfoMaybeEndpoint { Ok(Self { user, - project, + endpoint_id: endpoint.map(EndpointId::from), options, }) } @@ -238,7 +239,7 @@ mod tests { let mut ctx = RequestMonitoring::test(); let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?; assert_eq!(user_info.user, "john_doe"); - assert_eq!(user_info.project, None); + assert_eq!(user_info.endpoint_id, None); Ok(()) } @@ -253,7 +254,7 @@ mod tests { let mut ctx = RequestMonitoring::test(); let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?; assert_eq!(user_info.user, "john_doe"); - assert_eq!(user_info.project, None); + assert_eq!(user_info.endpoint_id, None); Ok(()) } @@ -269,7 +270,7 @@ mod tests { let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?; assert_eq!(user_info.user, "john_doe"); - assert_eq!(user_info.project.as_deref(), Some("foo")); + assert_eq!(user_info.endpoint_id.as_deref(), Some("foo")); assert_eq!(user_info.options.get_cache_key("foo"), "foo"); Ok(()) @@ -285,7 +286,7 @@ mod tests { let mut ctx = RequestMonitoring::test(); let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?; assert_eq!(user_info.user, "john_doe"); - assert_eq!(user_info.project.as_deref(), Some("bar")); + assert_eq!(user_info.endpoint_id.as_deref(), Some("bar")); Ok(()) } @@ -300,7 +301,7 @@ mod tests { let mut ctx = RequestMonitoring::test(); let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?; assert_eq!(user_info.user, "john_doe"); - assert_eq!(user_info.project.as_deref(), Some("bar")); + assert_eq!(user_info.endpoint_id.as_deref(), Some("bar")); Ok(()) } @@ -318,7 +319,7 @@ mod tests { let mut ctx = RequestMonitoring::test(); let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?; assert_eq!(user_info.user, "john_doe"); - assert!(user_info.project.is_none()); + assert!(user_info.endpoint_id.is_none()); Ok(()) } @@ -333,7 +334,7 @@ mod tests { let mut ctx = RequestMonitoring::test(); let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?; assert_eq!(user_info.user, "john_doe"); - assert!(user_info.project.is_none()); + assert!(user_info.endpoint_id.is_none()); Ok(()) } @@ -349,7 +350,7 @@ mod tests { let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?; assert_eq!(user_info.user, "john_doe"); - assert_eq!(user_info.project.as_deref(), Some("baz")); + assert_eq!(user_info.endpoint_id.as_deref(), Some("baz")); Ok(()) } @@ -363,14 +364,14 @@ mod tests { let mut ctx = RequestMonitoring::test(); let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?; - assert_eq!(user_info.project.as_deref(), Some("p1")); + assert_eq!(user_info.endpoint_id.as_deref(), Some("p1")); let common_names = Some(["a.com".into(), "b.com".into()].into()); let sni = Some("p1.b.com"); let mut ctx = RequestMonitoring::test(); let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?; - assert_eq!(user_info.project.as_deref(), Some("p1")); + assert_eq!(user_info.endpoint_id.as_deref(), Some("p1")); Ok(()) } @@ -427,7 +428,7 @@ mod tests { let mut ctx = RequestMonitoring::test(); let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?; - assert_eq!(user_info.project.as_deref(), Some("project")); + assert_eq!(user_info.endpoint_id.as_deref(), Some("project")); assert_eq!( user_info.options.get_cache_key("project"), "project endpoint_type:read_write lsn:0/2" diff --git a/proxy/src/auth/password_hack.rs b/proxy/src/auth/password_hack.rs index 372b0764ee..2ddf46fe25 100644 --- a/proxy/src/auth/password_hack.rs +++ b/proxy/src/auth/password_hack.rs @@ -4,10 +4,11 @@ //! UPDATE (Mon Aug 8 13:20:34 UTC 2022): the payload format has been simplified. use bstr::ByteSlice; -use smol_str::SmolStr; + +use crate::EndpointId; pub struct PasswordHackPayload { - pub endpoint: SmolStr, + pub endpoint: EndpointId, pub password: Vec, } diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index fa3d5d0e31..6f37868a8c 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -11,13 +11,16 @@ use smol_str::SmolStr; use tokio::time::Instant; use tracing::{debug, info}; -use crate::{auth::IpPattern, config::ProjectInfoCacheOptions, console::AuthSecret}; +use crate::{ + auth::IpPattern, config::ProjectInfoCacheOptions, console::AuthSecret, EndpointId, ProjectId, + RoleName, +}; use super::{Cache, Cached}; pub trait ProjectInfoCache { - fn invalidate_allowed_ips_for_project(&self, project_id: &SmolStr); - fn invalidate_role_secret_for_project(&self, project_id: &SmolStr, role_name: &SmolStr); + fn invalidate_allowed_ips_for_project(&self, project_id: &ProjectId); + fn invalidate_role_secret_for_project(&self, project_id: &ProjectId, role_name: &RoleName); fn enable_ttl(&self); fn disable_ttl(&self); } @@ -44,7 +47,7 @@ impl From for Entry { #[derive(Default)] struct EndpointInfo { - secret: std::collections::HashMap>>, + secret: std::collections::HashMap>>, allowed_ips: Option>>>, } @@ -57,7 +60,7 @@ impl EndpointInfo { } pub fn get_role_secret( &self, - role_name: &SmolStr, + role_name: &RoleName, valid_since: Instant, ignore_cache_since: Option, ) -> Option<(Option, bool)> { @@ -90,7 +93,7 @@ impl EndpointInfo { pub fn invalidate_allowed_ips(&mut self) { self.allowed_ips = None; } - pub fn invalidate_role_secret(&mut self, role_name: &SmolStr) { + pub fn invalidate_role_secret(&mut self, role_name: &RoleName) { self.secret.remove(role_name); } } @@ -103,9 +106,9 @@ impl EndpointInfo { /// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available? /// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache. pub struct ProjectInfoCacheImpl { - cache: DashMap, + cache: DashMap, - project2ep: DashMap>, + project2ep: DashMap>, config: ProjectInfoCacheOptions, start_time: Instant, @@ -113,7 +116,7 @@ pub struct ProjectInfoCacheImpl { } impl ProjectInfoCache for ProjectInfoCacheImpl { - fn invalidate_allowed_ips_for_project(&self, project_id: &SmolStr) { + fn invalidate_allowed_ips_for_project(&self, project_id: &ProjectId) { info!("invalidating allowed ips for project `{}`", project_id); let endpoints = self .project2ep @@ -126,7 +129,7 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { } } } - fn invalidate_role_secret_for_project(&self, project_id: &SmolStr, role_name: &SmolStr) { + fn invalidate_role_secret_for_project(&self, project_id: &ProjectId, role_name: &RoleName) { info!( "invalidating role secret for project_id `{}` and role_name `{}`", project_id, role_name @@ -167,8 +170,8 @@ impl ProjectInfoCacheImpl { pub fn get_role_secret( &self, - endpoint_id: &SmolStr, - role_name: &SmolStr, + endpoint_id: &EndpointId, + role_name: &RoleName, ) -> Option>> { let (valid_since, ignore_cache_since) = self.get_cache_times(); let endpoint_info = self.cache.get(endpoint_id)?; @@ -188,7 +191,7 @@ impl ProjectInfoCacheImpl { } pub fn get_allowed_ips( &self, - endpoint_id: &SmolStr, + endpoint_id: &EndpointId, ) -> Option>>> { let (valid_since, ignore_cache_since) = self.get_cache_times(); let endpoint_info = self.cache.get(endpoint_id)?; @@ -205,9 +208,9 @@ impl ProjectInfoCacheImpl { } pub fn insert_role_secret( &self, - project_id: &SmolStr, - endpoint_id: &SmolStr, - role_name: &SmolStr, + project_id: &ProjectId, + endpoint_id: &EndpointId, + role_name: &RoleName, secret: Option, ) { if self.cache.len() >= self.config.size { @@ -222,8 +225,8 @@ impl ProjectInfoCacheImpl { } pub fn insert_allowed_ips( &self, - project_id: &SmolStr, - endpoint_id: &SmolStr, + project_id: &ProjectId, + endpoint_id: &EndpointId, allowed_ips: Arc>, ) { if self.cache.len() >= self.config.size { @@ -236,7 +239,7 @@ impl ProjectInfoCacheImpl { .or_default() .allowed_ips = Some(allowed_ips.into()); } - fn inser_project2endpoint(&self, project_id: &SmolStr, endpoint_id: &SmolStr) { + fn inser_project2endpoint(&self, project_id: &ProjectId, endpoint_id: &EndpointId) { if let Some(mut endpoints) = self.project2ep.get_mut(project_id) { endpoints.insert(endpoint_id.clone()); } else { @@ -297,18 +300,18 @@ impl ProjectInfoCacheImpl { /// This is used to invalidate cache entries. pub struct CachedLookupInfo { /// Search by this key. - endpoint_id: SmolStr, + endpoint_id: EndpointId, lookup_type: LookupType, } impl CachedLookupInfo { - pub(self) fn new_role_secret(endpoint_id: SmolStr, role_name: SmolStr) -> Self { + pub(self) fn new_role_secret(endpoint_id: EndpointId, role_name: RoleName) -> Self { Self { endpoint_id, lookup_type: LookupType::RoleSecret(role_name), } } - pub(self) fn new_allowed_ips(endpoint_id: SmolStr) -> Self { + pub(self) fn new_allowed_ips(endpoint_id: EndpointId) -> Self { Self { endpoint_id, lookup_type: LookupType::AllowedIps, @@ -317,7 +320,7 @@ impl CachedLookupInfo { } enum LookupType { - RoleSecret(SmolStr), + RoleSecret(RoleName), AllowedIps, } @@ -348,7 +351,6 @@ impl Cache for ProjectInfoCacheImpl { mod tests { use super::*; use crate::{console::AuthSecret, scram::ServerSecret}; - use smol_str::SmolStr; use std::{sync::Arc, time::Duration}; #[tokio::test] @@ -362,8 +364,8 @@ mod tests { }); let project_id = "project".into(); let endpoint_id = "endpoint".into(); - let user1: SmolStr = "user1".into(); - let user2: SmolStr = "user2".into(); + let user1: RoleName = "user1".into(); + let user2: RoleName = "user2".into(); let secret1 = Some(AuthSecret::Scram(ServerSecret::mock( user1.as_str(), [1; 32], @@ -385,7 +387,7 @@ mod tests { assert_eq!(cached.value, secret2); // Shouldn't add more than 2 roles. - let user3: SmolStr = "user3".into(); + let user3: RoleName = "user3".into(); let secret3 = Some(AuthSecret::Scram(ServerSecret::mock( user3.as_str(), [3; 32], @@ -420,8 +422,8 @@ mod tests { let project_id = "project".into(); let endpoint_id = "endpoint".into(); - let user1: SmolStr = "user1".into(); - let user2: SmolStr = "user2".into(); + let user1: RoleName = "user1".into(); + let user2: RoleName = "user2".into(); let secret1 = Some(AuthSecret::Scram(ServerSecret::mock( user1.as_str(), [1; 32], @@ -475,8 +477,8 @@ mod tests { let project_id = "project".into(); let endpoint_id = "endpoint".into(); - let user1: SmolStr = "user1".into(); - let user2: SmolStr = "user2".into(); + let user1: RoleName = "user1".into(); + let user2: RoleName = "user2".into(); let secret1 = Some(AuthSecret::Scram(ServerSecret::mock( user1.as_str(), [1; 32], diff --git a/proxy/src/console/messages.rs b/proxy/src/console/messages.rs index 1cfa2d6192..6ef9bcf4eb 100644 --- a/proxy/src/console/messages.rs +++ b/proxy/src/console/messages.rs @@ -1,9 +1,10 @@ use serde::Deserialize; -use smol_str::SmolStr; use std::fmt; use crate::auth::IpPattern; +use crate::{BranchId, EndpointId, ProjectId}; + /// Generic error response with human-readable description. /// Note that we can't always present it to user as is. #[derive(Debug, Deserialize)] @@ -17,7 +18,7 @@ pub struct ConsoleError { pub struct GetRoleSecret { pub role_secret: Box, pub allowed_ips: Option>, - pub project_id: Option>, + pub project_id: Option, } // Manually implement debug to omit sensitive info. @@ -94,9 +95,9 @@ impl fmt::Debug for DatabaseInfo { /// Also known as `ProxyMetricsAuxInfo` in the console. #[derive(Debug, Deserialize, Clone, Default)] pub struct MetricsAuxInfo { - pub endpoint_id: SmolStr, - pub project_id: SmolStr, - pub branch_id: SmolStr, + pub endpoint_id: EndpointId, + pub project_id: ProjectId, + pub branch_id: BranchId, } impl MetricsAuxInfo { diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index 53c394f52f..a6dfbd79db 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -9,11 +9,10 @@ use crate::{ compute, config::{CacheOptions, ProjectInfoCacheOptions}, context::RequestMonitoring, - scram, + scram, EndpointCacheKey, ProjectId, }; use async_trait::async_trait; use dashmap::DashMap; -use smol_str::SmolStr; use std::{sync::Arc, time::Duration}; use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use tokio::time::Instant; @@ -214,7 +213,7 @@ pub struct AuthInfo { /// List of IP addresses allowed for the autorization. pub allowed_ips: Vec, /// Project ID. This is used for cache invalidation. - pub project_id: Option, + pub project_id: Option, } /// Info for establishing a connection to a compute node. @@ -233,7 +232,7 @@ pub struct NodeInfo { pub allow_self_signed_compute: bool, } -pub type NodeInfoCache = TimedLru; +pub type NodeInfoCache = TimedLru; pub type CachedNodeInfo = Cached<&'static NodeInfoCache>; pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option>; pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc>>; @@ -345,7 +344,7 @@ impl ApiCaches { /// Various caches for [`console`](super). pub struct ApiLocks { name: &'static str, - node_locks: DashMap>, + node_locks: DashMap>, permits: usize, timeout: Duration, registered: prometheus::IntCounter, @@ -413,7 +412,7 @@ impl ApiLocks { pub async fn get_wake_compute_permit( &self, - key: &SmolStr, + key: &EndpointCacheKey, ) -> Result { if self.permits == 0 { return Ok(WakeComputePermit { permit: None }); diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index 6574e079d5..33618faed8 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -14,7 +14,6 @@ use crate::{ }; use async_trait::async_trait; use futures::TryFutureExt; -use smol_str::SmolStr; use std::sync::Arc; use tokio::time::Instant; use tokio_postgres::config::SslMode; @@ -98,7 +97,7 @@ impl Api { Ok(AuthInfo { secret, allowed_ips, - project_id: body.project_id.map(SmolStr::from), + project_id: body.project_id, }) } .map_err(crate::error::log_error) @@ -239,7 +238,7 @@ impl super::Api for Api { // for some time (highly depends on the console's scale-to-zero policy); // The connection info remains the same during that period of time, // which means that we might cache it to reduce the load and latency. - if let Some(cached) = self.caches.node_info.get(&*key) { + if let Some(cached) = self.caches.node_info.get(&key) { info!(key = &*key, "found cached compute node info"); return Ok(cached); } diff --git a/proxy/src/context.rs b/proxy/src/context.rs index 8a1aa4aec9..9e2ea10031 100644 --- a/proxy/src/context.rs +++ b/proxy/src/context.rs @@ -7,7 +7,10 @@ use std::net::IpAddr; use tokio::sync::mpsc; use uuid::Uuid; -use crate::{console::messages::MetricsAuxInfo, error::ErrorKind, metrics::LatencyTimer}; +use crate::{ + console::messages::MetricsAuxInfo, error::ErrorKind, metrics::LatencyTimer, BranchId, + EndpointId, ProjectId, RoleName, +}; pub mod parquet; @@ -26,10 +29,10 @@ pub struct RequestMonitoring { region: &'static str, // filled in as they are discovered - project: Option, - branch: Option, - endpoint_id: Option, - user: Option, + project: Option, + branch: Option, + endpoint_id: Option, + user: Option, application: Option, error_kind: Option, success: bool, @@ -86,7 +89,7 @@ impl RequestMonitoring { self.project = Some(x.project_id); } - pub fn set_endpoint_id(&mut self, endpoint_id: Option) { + pub fn set_endpoint_id(&mut self, endpoint_id: Option) { self.endpoint_id = endpoint_id.or_else(|| self.endpoint_id.clone()); } @@ -94,7 +97,7 @@ impl RequestMonitoring { self.application = app.or_else(|| self.application.clone()); } - pub fn set_user(&mut self, user: SmolStr) { + pub fn set_user(&mut self, user: RoleName) { self.user = Some(user); } diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index a22b2459b8..a9e4a38302 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -62,3 +62,79 @@ pub async fn handle_signals(token: CancellationToken) -> anyhow::Result(r: Result, JoinError>) -> anyhow::Result { r.context("join error").and_then(|x| x) } + +macro_rules! smol_str_wrapper { + ($name:ident) => { + #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] + pub struct $name(smol_str::SmolStr); + + impl $name { + pub fn as_str(&self) -> &str { + self.0.as_str() + } + } + + impl std::fmt::Display for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } + } + + impl std::cmp::PartialEq for $name + where + smol_str::SmolStr: std::cmp::PartialEq, + { + fn eq(&self, other: &T) -> bool { + self.0.eq(other) + } + } + + impl From for $name + where + smol_str::SmolStr: From, + { + fn from(x: T) -> Self { + Self(x.into()) + } + } + + impl AsRef for $name { + fn as_ref(&self) -> &str { + self.0.as_ref() + } + } + + impl std::ops::Deref for $name { + type Target = str; + fn deref(&self) -> &str { + &*self.0 + } + } + + impl<'de> serde::de::Deserialize<'de> for $name { + fn deserialize>(d: D) -> Result { + >::deserialize(d).map(Self) + } + } + + impl serde::Serialize for $name { + fn serialize(&self, s: S) -> Result { + self.0.serialize(s) + } + } + }; +} + +// 90% of role name strings are 20 characters or less. +smol_str_wrapper!(RoleName); +// 50% of endpoint strings are 23 characters or less. +smol_str_wrapper!(EndpointId); +// 50% of branch strings are 23 characters or less. +smol_str_wrapper!(BranchId); +// 90% of project strings are 23 characters or less. +smol_str_wrapper!(ProjectId); + +// will usually equal endpoint ID +smol_str_wrapper!(EndpointCacheKey); + +smol_str_wrapper!(DbName); diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 635d157383..087cc7f7a9 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -19,6 +19,7 @@ use crate::{ rate_limiter::EndpointRateLimiter, stream::{PqStream, Stream}, usage_metrics::{Ids, USAGE_METRICS}, + EndpointCacheKey, }; use anyhow::{bail, Context}; use futures::TryFutureExt; @@ -26,7 +27,7 @@ use itertools::Itertools; use once_cell::sync::OnceCell; use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams}; use regex::Regex; -use smol_str::SmolStr; +use smol_str::{format_smolstr, SmolStr}; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio_util::sync::CancellationToken; @@ -516,20 +517,21 @@ impl NeonOptions { Self(options) } - pub fn get_cache_key(&self, prefix: &str) -> SmolStr { + pub fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey { // prefix + format!(" {k}:{v}") // kinda jank because SmolStr is immutable std::iter::once(prefix) .chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v])) - .collect() + .collect::() + .into() } /// DeepObject format /// `paramName[prop1]=value1¶mName[prop2]=value2&...` - pub fn to_deep_object(&self) -> Vec<(String, SmolStr)> { + pub fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> { self.0 .iter() - .map(|(k, v)| (format!("options[{}]", k), v.clone())) + .map(|(k, v)| (format_smolstr!("options[{}]", k), v.clone())) .collect() } } diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index a190b2cf8f..cbae72711c 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -11,11 +11,12 @@ use anyhow::bail; use dashmap::DashMap; use itertools::Itertools; use rand::{rngs::StdRng, Rng, SeedableRng}; -use smol_str::SmolStr; use tokio::sync::{Mutex as AsyncMutex, Semaphore, SemaphorePermit}; use tokio::time::{timeout, Duration, Instant}; use tracing::info; +use crate::EndpointId; + use super::{ limit_algorithm::{LimitAlgorithm, Sample}, RateLimiterConfig, @@ -33,7 +34,7 @@ use super::{ // does not look very nice (`SSL SYSCALL error: Undefined error: 0`), so for now // I went with a more expensive way that yields user-friendlier error messages. pub struct EndpointRateLimiter { - map: DashMap, Hasher>, + map: DashMap, Hasher>, info: &'static [RateBucketInfo], access_count: AtomicUsize, rand: Mutex, @@ -146,7 +147,7 @@ impl EndpointRateLimiter { } /// Check that number of connections to the endpoint is below `max_rps` rps. - pub fn check(&self, endpoint: SmolStr) -> bool { + pub fn check(&self, endpoint: EndpointId) -> bool { // do a partial GC every 2k requests. This cleans up ~ 1/64th of the map. // worst case memory usage is about: // = 2 * 2048 * 64 * (48B + 72B) @@ -493,11 +494,13 @@ mod tests { use futures::{task::noop_waker_ref, Future}; use rand::SeedableRng; use rustc_hash::FxHasher; - use smol_str::SmolStr; use tokio::time; use super::{EndpointRateLimiter, Limiter, Outcome}; - use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm}; + use crate::{ + rate_limiter::{RateBucketInfo, RateLimitAlgorithm}, + EndpointId, + }; #[tokio::test] async fn it_works() { @@ -654,7 +657,7 @@ mod tests { RateBucketInfo::validate(&mut rates).unwrap(); let limiter = EndpointRateLimiter::new(Vec::leak(rates)); - let endpoint = SmolStr::from("ep-my-endpoint-1234"); + let endpoint = EndpointId::from("ep-my-endpoint-1234"); time::pause(); diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index d28dcbd1a7..9cd70b109b 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -3,9 +3,8 @@ use std::{convert::Infallible, sync::Arc}; use futures::StreamExt; use redis::aio::PubSub; use serde::Deserialize; -use smol_str::SmolStr; -use crate::cache::project_info::ProjectInfoCache; +use crate::{cache::project_info::ProjectInfoCache, ProjectId, RoleName}; const CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20); @@ -46,12 +45,12 @@ enum Notification { } #[derive(Clone, Debug, Deserialize, Eq, PartialEq)] struct AllowedIpsUpdate { - project_id: SmolStr, + project_id: ProjectId, } #[derive(Clone, Debug, Deserialize, Eq, PartialEq)] struct PasswordUpdate { - project_id: SmolStr, - role_name: SmolStr, + project_id: ProjectId, + role_name: RoleName, } fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result where diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index c07cc2816e..5a7279ae63 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -31,6 +31,7 @@ use crate::{ metrics::NUM_DB_CONNECTIONS_GAUGE, proxy::connect_compute::ConnectMechanism, usage_metrics::{Ids, MetricCounter, USAGE_METRICS}, + DbName, EndpointCacheKey, RoleName, }; use crate::{compute, config}; @@ -42,17 +43,17 @@ pub const APP_NAME: SmolStr = SmolStr::new_inline("/sql_over_http"); #[derive(Debug, Clone)] pub struct ConnInfo { pub user_info: ComputeUserInfo, - pub dbname: SmolStr, + pub dbname: DbName, pub password: SmolStr, } impl ConnInfo { // hm, change to hasher to avoid cloning? - pub fn db_and_user(&self) -> (SmolStr, SmolStr) { + pub fn db_and_user(&self) -> (DbName, RoleName) { (self.dbname.clone(), self.user_info.user.clone()) } - pub fn endpoint_cache_key(&self) -> SmolStr { + pub fn endpoint_cache_key(&self) -> EndpointCacheKey { self.user_info.endpoint_cache_key() } } @@ -79,14 +80,14 @@ struct ConnPoolEntry { // Per-endpoint connection pool, (dbname, username) -> DbUserConnPool // Number of open connections is limited by the `max_conns_per_endpoint`. pub struct EndpointConnPool { - pools: HashMap<(SmolStr, SmolStr), DbUserConnPool>, + pools: HashMap<(DbName, RoleName), DbUserConnPool>, total_conns: usize, max_conns: usize, _guard: IntCounterPairGuard, } impl EndpointConnPool { - fn get_conn_entry(&mut self, db_user: (SmolStr, SmolStr)) -> Option { + fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option { let Self { pools, total_conns, .. } = self; @@ -95,7 +96,7 @@ impl EndpointConnPool { .and_then(|pool_entries| pool_entries.get_conn_entry(total_conns)) } - fn remove_client(&mut self, db_user: (SmolStr, SmolStr), conn_id: uuid::Uuid) -> bool { + fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool { let Self { pools, total_conns, .. } = self; @@ -196,7 +197,7 @@ pub struct GlobalConnPool { // // That should be a fairly conteded map, so return reference to the per-endpoint // pool as early as possible and release the lock. - global_pool: DashMap>>, + global_pool: DashMap>>, /// Number of endpoint-connection pools /// @@ -440,7 +441,10 @@ impl GlobalConnPool { Ok(Client::new(new_client, conn_info, endpoint_pool).await) } - fn get_or_create_endpoint_pool(&self, endpoint: &SmolStr) -> Arc> { + fn get_or_create_endpoint_pool( + &self, + endpoint: &EndpointCacheKey, + ) -> Arc> { // fast path if let Some(pool) = self.global_pool.get(endpoint) { return pool.clone(); diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 9b32ae7f25..f108ab34ab 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -13,7 +13,6 @@ use hyper::{Body, HeaderMap, Request}; use serde_json::json; use serde_json::Map; use serde_json::Value; -use smol_str::SmolStr; use tokio_postgres::error::DbError; use tokio_postgres::error::ErrorPosition; use tokio_postgres::types::Kind; @@ -36,6 +35,8 @@ use crate::config::TlsConfig; use crate::context::RequestMonitoring; use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE; use crate::proxy::NeonOptions; +use crate::EndpointId; +use crate::RoleName; use super::conn_pool::ConnInfo; use super::conn_pool::GlobalConnPool; @@ -155,7 +156,7 @@ fn get_conn_info( .next() .ok_or(anyhow::anyhow!("invalid database name"))?; - let username = SmolStr::from(connection_url.username()); + let username = RoleName::from(connection_url.username()); if username.is_empty() { return Err(anyhow::anyhow!("missing username")); } @@ -189,7 +190,7 @@ fn get_conn_info( let endpoint = endpoint_sni(hostname, &tls.common_names)?; - let endpoint: SmolStr = endpoint.into(); + let endpoint: EndpointId = endpoint.into(); ctx.set_endpoint_id(Some(endpoint.clone())); let pairs = connection_url.query_pairs(); diff --git a/proxy/src/usage_metrics.rs b/proxy/src/usage_metrics.rs index 789a4c680c..d75aedf89b 100644 --- a/proxy/src/usage_metrics.rs +++ b/proxy/src/usage_metrics.rs @@ -1,12 +1,11 @@ //! Periodically collect proxy consumption metrics //! and push them to a HTTP endpoint. -use crate::{config::MetricCollectionConfig, http}; +use crate::{config::MetricCollectionConfig, http, BranchId, EndpointId}; use chrono::{DateTime, Utc}; use consumption_metrics::{idempotency_key, Event, EventChunk, EventType, CHUNK_SIZE}; use dashmap::{mapref::entry::Entry, DashMap}; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; -use smol_str::SmolStr; use std::{ convert::Infallible, sync::{ @@ -30,8 +29,8 @@ const DEFAULT_HTTP_REPORTING_TIMEOUT: Duration = Duration::from_secs(60); /// because we enrich the event with project_id in the control-plane endpoint. #[derive(Eq, Hash, PartialEq, Serialize, Deserialize, Debug, Clone)] pub struct Ids { - pub endpoint_id: SmolStr, - pub branch_id: SmolStr, + pub endpoint_id: EndpointId, + pub branch_id: BranchId, } #[derive(Debug)]