proxy: add newtype wrappers for string based IDs (#6445)

## Problem

The proxy has too many string-based IDs, which makes it easy to mix up ID types.

## Summary of changes

Add a set of `SmolStr` newtype wrappers that keep the convenience methods of a string but are type-safe, so different kinds of IDs can no longer be confused with one another.
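
Roughly, each ID becomes its own newtype around `smol_str::SmolStr`. A minimal sketch of the idea (simplified; the real `smol_str_wrapper!` macro further down in the diff also provides `Display`, `Deref`, `AsRef<str>`, comparison, and serde impls):

```rust
use smol_str::SmolStr;

// Cut-down illustrations of the wrappers added in this PR; see the
// `smol_str_wrapper!` macro in the diff for the real definitions.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct EndpointId(SmolStr);

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct RoleName(SmolStr);

impl From<&str> for EndpointId {
    fn from(s: &str) -> Self {
        Self(s.into())
    }
}

impl From<&str> for RoleName {
    fn from(s: &str) -> Self {
        Self(s.into())
    }
}

// A function that can only ever receive an endpoint ID.
fn check_rate_limit(endpoint: &EndpointId) {
    println!("checking rate limit for {}", endpoint.0);
}

fn main() {
    let endpoint = EndpointId::from("ep-my-endpoint-1234");
    let role = RoleName::from("john_doe");

    check_rate_limit(&endpoint); // ok
    // check_rate_limit(&role);  // no longer compiles: expected `EndpointId`, found `RoleName`
}
```

Previously both values were plain `SmolStr`s, so swapping the arguments compiled silently.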
Author: Conrad Ludgate
Date: 2024-01-24 16:38:10 +00:00 (committed by GitHub)
Parent: a0a3ba85e7
Commit: 210700d0d9
16 changed files with 215 additions and 126 deletions

Cargo.lock
View File

@@ -5113,9 +5113,9 @@ checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9"
[[package]]
name = "smol_str"
version = "0.2.0"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74212e6bbe9a4352329b2f68ba3130c15a3f26fe88ff22dbdc6cdd58fa85e99c"
checksum = "e6845563ada680337a52d43bb0b29f396f2d911616f6573012645b9e3d048a49"
dependencies = [
"serde",
]

View File

@@ -3,7 +3,6 @@ mod hacks;
mod link;
pub use link::LinkAuthError;
use smol_str::SmolStr;
use tokio_postgres::config::AuthKeys;
use crate::auth::credentials::check_peer_addr_is_in_list;
@@ -16,7 +15,6 @@ use crate::context::RequestMonitoring;
use crate::proxy::connect_compute::handle_try_wake;
use crate::proxy::retry::retry_after;
use crate::proxy::NeonOptions;
use crate::scram;
use crate::stream::Stream;
use crate::{
auth::{self, ComputeUserInfoMaybeEndpoint},
@@ -28,6 +26,7 @@ use crate::{
},
stream, url,
};
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
use futures::TryFutureExt;
use std::borrow::Cow;
use std::ops::ControlFlow;
@@ -130,19 +129,19 @@ pub struct ComputeCredentials<T> {
#[derive(Debug, Clone)]
pub struct ComputeUserInfoNoEndpoint {
pub user: SmolStr,
pub user: RoleName,
pub options: NeonOptions,
}
#[derive(Debug, Clone)]
pub struct ComputeUserInfo {
pub endpoint: SmolStr,
pub user: SmolStr,
pub endpoint: EndpointId,
pub user: RoleName,
pub options: NeonOptions,
}
impl ComputeUserInfo {
pub fn endpoint_cache_key(&self) -> SmolStr {
pub fn endpoint_cache_key(&self) -> EndpointCacheKey {
self.options.get_cache_key(&self.endpoint)
}
}
@@ -158,7 +157,7 @@ impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
type Error = ComputeUserInfoNoEndpoint;
fn try_from(user_info: ComputeUserInfoMaybeEndpoint) -> Result<Self, Self::Error> {
match user_info.project {
match user_info.endpoint_id {
None => Err(ComputeUserInfoNoEndpoint {
user: user_info.user,
options: user_info.options,
@@ -317,11 +316,11 @@ async fn auth_and_wake_compute(
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
/// Get compute endpoint name from the credentials.
pub fn get_endpoint(&self) -> Option<SmolStr> {
pub fn get_endpoint(&self) -> Option<EndpointId> {
use BackendType::*;
match self {
Console(_, user_info) => user_info.project.clone(),
Console(_, user_info) => user_info.endpoint_id.clone(),
Link(_) => Some("link".into()),
#[cfg(test)]
Test(_) => Some("test".into()),
@@ -355,7 +354,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
Console(api, user_info) => {
info!(
user = &*user_info.user,
project = user_info.project(),
project = user_info.endpoint(),
"performing authentication using the console"
);

View File

@@ -2,7 +2,7 @@
use crate::{
auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError,
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions,
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, EndpointId, RoleName,
};
use itertools::Itertools;
use pq_proto::StartupMessageParams;
@@ -21,7 +21,10 @@ pub enum ComputeUserInfoParseError {
SNI ('{}') and project option ('{}').",
.domain, .option,
)]
InconsistentProjectNames { domain: SmolStr, option: SmolStr },
InconsistentProjectNames {
domain: EndpointId,
option: EndpointId,
},
#[error(
"Common name inferred from SNI ('{}') is not known",
@@ -30,7 +33,7 @@ pub enum ComputeUserInfoParseError {
UnknownCommonName { cn: String },
#[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
MalformedProjectName(SmolStr),
MalformedProjectName(EndpointId),
}
impl UserFacingError for ComputeUserInfoParseError {}
@@ -39,17 +42,15 @@ impl UserFacingError for ComputeUserInfoParseError {}
/// Note that we don't store any kind of client key or password here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ComputeUserInfoMaybeEndpoint {
pub user: SmolStr,
// TODO: this is a severe misnomer! We should think of a new name ASAP.
pub project: Option<SmolStr>,
pub user: RoleName,
pub endpoint_id: Option<EndpointId>,
pub options: NeonOptions,
}
impl ComputeUserInfoMaybeEndpoint {
#[inline]
pub fn project(&self) -> Option<&str> {
self.project.as_deref()
pub fn endpoint(&self) -> Option<&str> {
self.endpoint_id.as_deref()
}
}
@@ -79,15 +80,15 @@ impl ComputeUserInfoMaybeEndpoint {
// Some parameters are stored in the startup message.
let get_param = |key| params.get(key).ok_or(MissingKey(key));
let user: SmolStr = get_param("user")?.into();
let user: RoleName = get_param("user")?.into();
// record the values if we have them
ctx.set_application(params.get("application_name").map(SmolStr::from));
ctx.set_user(user.clone());
ctx.set_endpoint_id(sni.map(SmolStr::from));
ctx.set_endpoint_id(sni.map(EndpointId::from));
// Project name might be passed via PG's command-line options.
let project_option = params
let endpoint_option = params
.options_raw()
.and_then(|options| {
// We support both `project` (deprecated) and `endpoint` options for backward compatibility.
@@ -100,9 +101,9 @@ impl ComputeUserInfoMaybeEndpoint {
})
.map(|name| name.into());
let project_from_domain = if let Some(sni_str) = sni {
let endpoint_from_domain = if let Some(sni_str) = sni {
if let Some(cn) = common_names {
Some(SmolStr::from(endpoint_sni(sni_str, cn)?))
Some(EndpointId::from(endpoint_sni(sni_str, cn)?))
} else {
None
}
@@ -110,7 +111,7 @@ impl ComputeUserInfoMaybeEndpoint {
None
};
let project = match (project_option, project_from_domain) {
let endpoint = match (endpoint_option, endpoint_from_domain) {
// Invariant: if we have both project name variants, they should match.
(Some(option), Some(domain)) if option != domain => {
Some(Err(InconsistentProjectNames { domain, option }))
@@ -123,13 +124,13 @@ impl ComputeUserInfoMaybeEndpoint {
}
.transpose()?;
info!(%user, project = project.as_deref(), "credentials");
info!(%user, project = endpoint.as_deref(), "credentials");
if sni.is_some() {
info!("Connection with sni");
NUM_CONNECTION_ACCEPTED_BY_SNI
.with_label_values(&["sni"])
.inc();
} else if project.is_some() {
} else if endpoint.is_some() {
NUM_CONNECTION_ACCEPTED_BY_SNI
.with_label_values(&["no_sni"])
.inc();
@@ -145,7 +146,7 @@ impl ComputeUserInfoMaybeEndpoint {
Ok(Self {
user,
project,
endpoint_id: endpoint.map(EndpointId::from),
options,
})
}
@@ -238,7 +239,7 @@ mod tests {
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.project, None);
assert_eq!(user_info.endpoint_id, None);
Ok(())
}
@@ -253,7 +254,7 @@ mod tests {
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.project, None);
assert_eq!(user_info.endpoint_id, None);
Ok(())
}
@@ -269,7 +270,7 @@ mod tests {
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.project.as_deref(), Some("foo"));
assert_eq!(user_info.endpoint_id.as_deref(), Some("foo"));
assert_eq!(user_info.options.get_cache_key("foo"), "foo");
Ok(())
@@ -285,7 +286,7 @@ mod tests {
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.project.as_deref(), Some("bar"));
assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));
Ok(())
}
@@ -300,7 +301,7 @@ mod tests {
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.project.as_deref(), Some("bar"));
assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));
Ok(())
}
@@ -318,7 +319,7 @@ mod tests {
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert!(user_info.project.is_none());
assert!(user_info.endpoint_id.is_none());
Ok(())
}
@@ -333,7 +334,7 @@ mod tests {
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert!(user_info.project.is_none());
assert!(user_info.endpoint_id.is_none());
Ok(())
}
@@ -349,7 +350,7 @@ mod tests {
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.project.as_deref(), Some("baz"));
assert_eq!(user_info.endpoint_id.as_deref(), Some("baz"));
Ok(())
}
@@ -363,14 +364,14 @@ mod tests {
let mut ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.project.as_deref(), Some("p1"));
assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
let common_names = Some(["a.com".into(), "b.com".into()].into());
let sni = Some("p1.b.com");
let mut ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.project.as_deref(), Some("p1"));
assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
Ok(())
}
@@ -427,7 +428,7 @@ mod tests {
let mut ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.project.as_deref(), Some("project"));
assert_eq!(user_info.endpoint_id.as_deref(), Some("project"));
assert_eq!(
user_info.options.get_cache_key("project"),
"project endpoint_type:read_write lsn:0/2"

View File

@@ -4,10 +4,11 @@
//! UPDATE (Mon Aug 8 13:20:34 UTC 2022): the payload format has been simplified.
use bstr::ByteSlice;
use smol_str::SmolStr;
use crate::EndpointId;
pub struct PasswordHackPayload {
pub endpoint: SmolStr,
pub endpoint: EndpointId,
pub password: Vec<u8>,
}

View File

@@ -11,13 +11,16 @@ use smol_str::SmolStr;
use tokio::time::Instant;
use tracing::{debug, info};
use crate::{auth::IpPattern, config::ProjectInfoCacheOptions, console::AuthSecret};
use crate::{
auth::IpPattern, config::ProjectInfoCacheOptions, console::AuthSecret, EndpointId, ProjectId,
RoleName,
};
use super::{Cache, Cached};
pub trait ProjectInfoCache {
fn invalidate_allowed_ips_for_project(&self, project_id: &SmolStr);
fn invalidate_role_secret_for_project(&self, project_id: &SmolStr, role_name: &SmolStr);
fn invalidate_allowed_ips_for_project(&self, project_id: &ProjectId);
fn invalidate_role_secret_for_project(&self, project_id: &ProjectId, role_name: &RoleName);
fn enable_ttl(&self);
fn disable_ttl(&self);
}
@@ -44,7 +47,7 @@ impl<T> From<T> for Entry<T> {
#[derive(Default)]
struct EndpointInfo {
secret: std::collections::HashMap<SmolStr, Entry<Option<AuthSecret>>>,
secret: std::collections::HashMap<RoleName, Entry<Option<AuthSecret>>>,
allowed_ips: Option<Entry<Arc<Vec<IpPattern>>>>,
}
@@ -57,7 +60,7 @@ impl EndpointInfo {
}
pub fn get_role_secret(
&self,
role_name: &SmolStr,
role_name: &RoleName,
valid_since: Instant,
ignore_cache_since: Option<Instant>,
) -> Option<(Option<AuthSecret>, bool)> {
@@ -90,7 +93,7 @@ impl EndpointInfo {
pub fn invalidate_allowed_ips(&mut self) {
self.allowed_ips = None;
}
pub fn invalidate_role_secret(&mut self, role_name: &SmolStr) {
pub fn invalidate_role_secret(&mut self, role_name: &RoleName) {
self.secret.remove(role_name);
}
}
@@ -103,9 +106,9 @@ impl EndpointInfo {
/// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available?
/// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache.
pub struct ProjectInfoCacheImpl {
cache: DashMap<SmolStr, EndpointInfo>,
cache: DashMap<EndpointId, EndpointInfo>,
project2ep: DashMap<SmolStr, HashSet<SmolStr>>,
project2ep: DashMap<ProjectId, HashSet<EndpointId>>,
config: ProjectInfoCacheOptions,
start_time: Instant,
@@ -113,7 +116,7 @@ pub struct ProjectInfoCacheImpl {
}
impl ProjectInfoCache for ProjectInfoCacheImpl {
fn invalidate_allowed_ips_for_project(&self, project_id: &SmolStr) {
fn invalidate_allowed_ips_for_project(&self, project_id: &ProjectId) {
info!("invalidating allowed ips for project `{}`", project_id);
let endpoints = self
.project2ep
@@ -126,7 +129,7 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
}
}
}
fn invalidate_role_secret_for_project(&self, project_id: &SmolStr, role_name: &SmolStr) {
fn invalidate_role_secret_for_project(&self, project_id: &ProjectId, role_name: &RoleName) {
info!(
"invalidating role secret for project_id `{}` and role_name `{}`",
project_id, role_name
@@ -167,8 +170,8 @@ impl ProjectInfoCacheImpl {
pub fn get_role_secret(
&self,
endpoint_id: &SmolStr,
role_name: &SmolStr,
endpoint_id: &EndpointId,
role_name: &RoleName,
) -> Option<Cached<&Self, Option<AuthSecret>>> {
let (valid_since, ignore_cache_since) = self.get_cache_times();
let endpoint_info = self.cache.get(endpoint_id)?;
@@ -188,7 +191,7 @@ impl ProjectInfoCacheImpl {
}
pub fn get_allowed_ips(
&self,
endpoint_id: &SmolStr,
endpoint_id: &EndpointId,
) -> Option<Cached<&Self, Arc<Vec<IpPattern>>>> {
let (valid_since, ignore_cache_since) = self.get_cache_times();
let endpoint_info = self.cache.get(endpoint_id)?;
@@ -205,9 +208,9 @@ impl ProjectInfoCacheImpl {
}
pub fn insert_role_secret(
&self,
project_id: &SmolStr,
endpoint_id: &SmolStr,
role_name: &SmolStr,
project_id: &ProjectId,
endpoint_id: &EndpointId,
role_name: &RoleName,
secret: Option<AuthSecret>,
) {
if self.cache.len() >= self.config.size {
@@ -222,8 +225,8 @@ impl ProjectInfoCacheImpl {
}
pub fn insert_allowed_ips(
&self,
project_id: &SmolStr,
endpoint_id: &SmolStr,
project_id: &ProjectId,
endpoint_id: &EndpointId,
allowed_ips: Arc<Vec<IpPattern>>,
) {
if self.cache.len() >= self.config.size {
@@ -236,7 +239,7 @@ impl ProjectInfoCacheImpl {
.or_default()
.allowed_ips = Some(allowed_ips.into());
}
fn inser_project2endpoint(&self, project_id: &SmolStr, endpoint_id: &SmolStr) {
fn inser_project2endpoint(&self, project_id: &ProjectId, endpoint_id: &EndpointId) {
if let Some(mut endpoints) = self.project2ep.get_mut(project_id) {
endpoints.insert(endpoint_id.clone());
} else {
@@ -297,18 +300,18 @@ impl ProjectInfoCacheImpl {
/// This is used to invalidate cache entries.
pub struct CachedLookupInfo {
/// Search by this key.
endpoint_id: SmolStr,
endpoint_id: EndpointId,
lookup_type: LookupType,
}
impl CachedLookupInfo {
pub(self) fn new_role_secret(endpoint_id: SmolStr, role_name: SmolStr) -> Self {
pub(self) fn new_role_secret(endpoint_id: EndpointId, role_name: RoleName) -> Self {
Self {
endpoint_id,
lookup_type: LookupType::RoleSecret(role_name),
}
}
pub(self) fn new_allowed_ips(endpoint_id: SmolStr) -> Self {
pub(self) fn new_allowed_ips(endpoint_id: EndpointId) -> Self {
Self {
endpoint_id,
lookup_type: LookupType::AllowedIps,
@@ -317,7 +320,7 @@ impl CachedLookupInfo {
}
enum LookupType {
RoleSecret(SmolStr),
RoleSecret(RoleName),
AllowedIps,
}
@@ -348,7 +351,6 @@ impl Cache for ProjectInfoCacheImpl {
mod tests {
use super::*;
use crate::{console::AuthSecret, scram::ServerSecret};
use smol_str::SmolStr;
use std::{sync::Arc, time::Duration};
#[tokio::test]
@@ -362,8 +364,8 @@ mod tests {
});
let project_id = "project".into();
let endpoint_id = "endpoint".into();
let user1: SmolStr = "user1".into();
let user2: SmolStr = "user2".into();
let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into();
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock(
user1.as_str(),
[1; 32],
@@ -385,7 +387,7 @@ mod tests {
assert_eq!(cached.value, secret2);
// Shouldn't add more than 2 roles.
let user3: SmolStr = "user3".into();
let user3: RoleName = "user3".into();
let secret3 = Some(AuthSecret::Scram(ServerSecret::mock(
user3.as_str(),
[3; 32],
@@ -420,8 +422,8 @@ mod tests {
let project_id = "project".into();
let endpoint_id = "endpoint".into();
let user1: SmolStr = "user1".into();
let user2: SmolStr = "user2".into();
let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into();
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock(
user1.as_str(),
[1; 32],
@@ -475,8 +477,8 @@ mod tests {
let project_id = "project".into();
let endpoint_id = "endpoint".into();
let user1: SmolStr = "user1".into();
let user2: SmolStr = "user2".into();
let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into();
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock(
user1.as_str(),
[1; 32],

View File

@@ -1,9 +1,10 @@
use serde::Deserialize;
use smol_str::SmolStr;
use std::fmt;
use crate::auth::IpPattern;
use crate::{BranchId, EndpointId, ProjectId};
/// Generic error response with human-readable description.
/// Note that we can't always present it to user as is.
#[derive(Debug, Deserialize)]
@@ -17,7 +18,7 @@ pub struct ConsoleError {
pub struct GetRoleSecret {
pub role_secret: Box<str>,
pub allowed_ips: Option<Vec<IpPattern>>,
pub project_id: Option<Box<str>>,
pub project_id: Option<ProjectId>,
}
// Manually implement debug to omit sensitive info.
@@ -94,9 +95,9 @@ impl fmt::Debug for DatabaseInfo {
/// Also known as `ProxyMetricsAuxInfo` in the console.
#[derive(Debug, Deserialize, Clone, Default)]
pub struct MetricsAuxInfo {
pub endpoint_id: SmolStr,
pub project_id: SmolStr,
pub branch_id: SmolStr,
pub endpoint_id: EndpointId,
pub project_id: ProjectId,
pub branch_id: BranchId,
}
impl MetricsAuxInfo {

View File

@@ -9,11 +9,10 @@ use crate::{
compute,
config::{CacheOptions, ProjectInfoCacheOptions},
context::RequestMonitoring,
scram,
scram, EndpointCacheKey, ProjectId,
};
use async_trait::async_trait;
use dashmap::DashMap;
use smol_str::SmolStr;
use std::{sync::Arc, time::Duration};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio::time::Instant;
@@ -214,7 +213,7 @@ pub struct AuthInfo {
/// List of IP addresses allowed for the autorization.
pub allowed_ips: Vec<IpPattern>,
/// Project ID. This is used for cache invalidation.
pub project_id: Option<SmolStr>,
pub project_id: Option<ProjectId>,
}
/// Info for establishing a connection to a compute node.
@@ -233,7 +232,7 @@ pub struct NodeInfo {
pub allow_self_signed_compute: bool,
}
pub type NodeInfoCache = TimedLru<SmolStr, NodeInfo>;
pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>;
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
@@ -345,7 +344,7 @@ impl ApiCaches {
/// Various caches for [`console`](super).
pub struct ApiLocks {
name: &'static str,
node_locks: DashMap<SmolStr, Arc<Semaphore>>,
node_locks: DashMap<EndpointCacheKey, Arc<Semaphore>>,
permits: usize,
timeout: Duration,
registered: prometheus::IntCounter,
@@ -413,7 +412,7 @@ impl ApiLocks {
pub async fn get_wake_compute_permit(
&self,
key: &SmolStr,
key: &EndpointCacheKey,
) -> Result<WakeComputePermit, errors::WakeComputeError> {
if self.permits == 0 {
return Ok(WakeComputePermit { permit: None });

View File

@@ -14,7 +14,6 @@ use crate::{
};
use async_trait::async_trait;
use futures::TryFutureExt;
use smol_str::SmolStr;
use std::sync::Arc;
use tokio::time::Instant;
use tokio_postgres::config::SslMode;
@@ -98,7 +97,7 @@ impl Api {
Ok(AuthInfo {
secret,
allowed_ips,
project_id: body.project_id.map(SmolStr::from),
project_id: body.project_id,
})
}
.map_err(crate::error::log_error)
@@ -239,7 +238,7 @@ impl super::Api for Api {
// for some time (highly depends on the console's scale-to-zero policy);
// The connection info remains the same during that period of time,
// which means that we might cache it to reduce the load and latency.
if let Some(cached) = self.caches.node_info.get(&*key) {
if let Some(cached) = self.caches.node_info.get(&key) {
info!(key = &*key, "found cached compute node info");
return Ok(cached);
}

View File

@@ -7,7 +7,10 @@ use std::net::IpAddr;
use tokio::sync::mpsc;
use uuid::Uuid;
use crate::{console::messages::MetricsAuxInfo, error::ErrorKind, metrics::LatencyTimer};
use crate::{
console::messages::MetricsAuxInfo, error::ErrorKind, metrics::LatencyTimer, BranchId,
EndpointId, ProjectId, RoleName,
};
pub mod parquet;
@@ -26,10 +29,10 @@ pub struct RequestMonitoring {
region: &'static str,
// filled in as they are discovered
project: Option<SmolStr>,
branch: Option<SmolStr>,
endpoint_id: Option<SmolStr>,
user: Option<SmolStr>,
project: Option<ProjectId>,
branch: Option<BranchId>,
endpoint_id: Option<EndpointId>,
user: Option<RoleName>,
application: Option<SmolStr>,
error_kind: Option<ErrorKind>,
success: bool,
@@ -86,7 +89,7 @@ impl RequestMonitoring {
self.project = Some(x.project_id);
}
pub fn set_endpoint_id(&mut self, endpoint_id: Option<SmolStr>) {
pub fn set_endpoint_id(&mut self, endpoint_id: Option<EndpointId>) {
self.endpoint_id = endpoint_id.or_else(|| self.endpoint_id.clone());
}
@@ -94,7 +97,7 @@ impl RequestMonitoring {
self.application = app.or_else(|| self.application.clone());
}
pub fn set_user(&mut self, user: SmolStr) {
pub fn set_user(&mut self, user: RoleName) {
self.user = Some(user);
}

View File

@@ -62,3 +62,79 @@ pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<Infallib
pub fn flatten_err<T>(r: Result<anyhow::Result<T>, JoinError>) -> anyhow::Result<T> {
r.context("join error").and_then(|x| x)
}
macro_rules! smol_str_wrapper {
($name:ident) => {
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct $name(smol_str::SmolStr);
impl $name {
pub fn as_str(&self) -> &str {
self.0.as_str()
}
}
impl std::fmt::Display for $name {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl<T> std::cmp::PartialEq<T> for $name
where
smol_str::SmolStr: std::cmp::PartialEq<T>,
{
fn eq(&self, other: &T) -> bool {
self.0.eq(other)
}
}
impl<T> From<T> for $name
where
smol_str::SmolStr: From<T>,
{
fn from(x: T) -> Self {
Self(x.into())
}
}
impl AsRef<str> for $name {
fn as_ref(&self) -> &str {
self.0.as_ref()
}
}
impl std::ops::Deref for $name {
type Target = str;
fn deref(&self) -> &str {
&*self.0
}
}
impl<'de> serde::de::Deserialize<'de> for $name {
fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
<smol_str::SmolStr as serde::de::Deserialize<'de>>::deserialize(d).map(Self)
}
}
impl serde::Serialize for $name {
fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
self.0.serialize(s)
}
}
};
}
// 90% of role name strings are 20 characters or less.
smol_str_wrapper!(RoleName);
// 50% of endpoint strings are 23 characters or less.
smol_str_wrapper!(EndpointId);
// 50% of branch strings are 23 characters or less.
smol_str_wrapper!(BranchId);
// 90% of project strings are 23 characters or less.
smol_str_wrapper!(ProjectId);
// will usually equal endpoint ID
smol_str_wrapper!(EndpointCacheKey);
smol_str_wrapper!(DbName);
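
For illustration only (not part of the commit): given the blanket impls generated above, the wrappers keep most of the string ergonomics while the ID types stay distinct.

```rust
// Usage sketch, assuming the wrapper types defined by the macro above are in scope.
fn example() {
    let endpoint: EndpointId = "ep-my-endpoint-1234".into(); // From<T> where SmolStr: From<T>
    let role: RoleName = "john_doe".into();

    assert_eq!(endpoint, "ep-my-endpoint-1234"); // PartialEq<T> where SmolStr: PartialEq<T>
    assert!(endpoint.starts_with("ep-"));        // &str methods via Deref<Target = str>
    assert_eq!(role.as_str(), "john_doe");

    // `endpoint == role` does not compile: EndpointId and RoleName are different types.
}
```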

View File

@@ -19,6 +19,7 @@ use crate::{
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
usage_metrics::{Ids, USAGE_METRICS},
EndpointCacheKey,
};
use anyhow::{bail, Context};
use futures::TryFutureExt;
@@ -26,7 +27,7 @@ use itertools::Itertools;
use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use regex::Regex;
use smol_str::SmolStr;
use smol_str::{format_smolstr, SmolStr};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
@@ -516,20 +517,21 @@ impl NeonOptions {
Self(options)
}
pub fn get_cache_key(&self, prefix: &str) -> SmolStr {
pub fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey {
// prefix + format!(" {k}:{v}")
// kinda jank because SmolStr is immutable
std::iter::once(prefix)
.chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
.collect()
.collect::<SmolStr>()
.into()
}
/// <https://swagger.io/docs/specification/serialization/> DeepObject format
/// `paramName[prop1]=value1&paramName[prop2]=value2&...`
pub fn to_deep_object(&self) -> Vec<(String, SmolStr)> {
pub fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> {
self.0
.iter()
.map(|(k, v)| (format!("options[{}]", k), v.clone()))
.map(|(k, v)| (format_smolstr!("options[{}]", k), v.clone()))
.collect()
}
}

View File

@@ -11,11 +11,12 @@ use anyhow::bail;
use dashmap::DashMap;
use itertools::Itertools;
use rand::{rngs::StdRng, Rng, SeedableRng};
use smol_str::SmolStr;
use tokio::sync::{Mutex as AsyncMutex, Semaphore, SemaphorePermit};
use tokio::time::{timeout, Duration, Instant};
use tracing::info;
use crate::EndpointId;
use super::{
limit_algorithm::{LimitAlgorithm, Sample},
RateLimiterConfig,
@@ -33,7 +34,7 @@ use super::{
// does not look very nice (`SSL SYSCALL error: Undefined error: 0`), so for now
// I went with a more expensive way that yields user-friendlier error messages.
pub struct EndpointRateLimiter<Rand = StdRng, Hasher = RandomState> {
map: DashMap<SmolStr, Vec<RateBucket>, Hasher>,
map: DashMap<EndpointId, Vec<RateBucket>, Hasher>,
info: &'static [RateBucketInfo],
access_count: AtomicUsize,
rand: Mutex<Rand>,
@@ -146,7 +147,7 @@ impl<R: Rng, S: BuildHasher + Clone> EndpointRateLimiter<R, S> {
}
/// Check that number of connections to the endpoint is below `max_rps` rps.
pub fn check(&self, endpoint: SmolStr) -> bool {
pub fn check(&self, endpoint: EndpointId) -> bool {
// do a partial GC every 2k requests. This cleans up ~ 1/64th of the map.
// worst case memory usage is about:
// = 2 * 2048 * 64 * (48B + 72B)
@@ -493,11 +494,13 @@ mod tests {
use futures::{task::noop_waker_ref, Future};
use rand::SeedableRng;
use rustc_hash::FxHasher;
use smol_str::SmolStr;
use tokio::time;
use super::{EndpointRateLimiter, Limiter, Outcome};
use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm};
use crate::{
rate_limiter::{RateBucketInfo, RateLimitAlgorithm},
EndpointId,
};
#[tokio::test]
async fn it_works() {
@@ -654,7 +657,7 @@ mod tests {
RateBucketInfo::validate(&mut rates).unwrap();
let limiter = EndpointRateLimiter::new(Vec::leak(rates));
let endpoint = SmolStr::from("ep-my-endpoint-1234");
let endpoint = EndpointId::from("ep-my-endpoint-1234");
time::pause();

View File

@@ -3,9 +3,8 @@ use std::{convert::Infallible, sync::Arc};
use futures::StreamExt;
use redis::aio::PubSub;
use serde::Deserialize;
use smol_str::SmolStr;
use crate::cache::project_info::ProjectInfoCache;
use crate::{cache::project_info::ProjectInfoCache, ProjectId, RoleName};
const CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
@@ -46,12 +45,12 @@ enum Notification {
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
struct AllowedIpsUpdate {
project_id: SmolStr,
project_id: ProjectId,
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
struct PasswordUpdate {
project_id: SmolStr,
role_name: SmolStr,
project_id: ProjectId,
role_name: RoleName,
}
fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where

View File

@@ -31,6 +31,7 @@ use crate::{
metrics::NUM_DB_CONNECTIONS_GAUGE,
proxy::connect_compute::ConnectMechanism,
usage_metrics::{Ids, MetricCounter, USAGE_METRICS},
DbName, EndpointCacheKey, RoleName,
};
use crate::{compute, config};
@@ -42,17 +43,17 @@ pub const APP_NAME: SmolStr = SmolStr::new_inline("/sql_over_http");
#[derive(Debug, Clone)]
pub struct ConnInfo {
pub user_info: ComputeUserInfo,
pub dbname: SmolStr,
pub dbname: DbName,
pub password: SmolStr,
}
impl ConnInfo {
// hm, change to hasher to avoid cloning?
pub fn db_and_user(&self) -> (SmolStr, SmolStr) {
pub fn db_and_user(&self) -> (DbName, RoleName) {
(self.dbname.clone(), self.user_info.user.clone())
}
pub fn endpoint_cache_key(&self) -> SmolStr {
pub fn endpoint_cache_key(&self) -> EndpointCacheKey {
self.user_info.endpoint_cache_key()
}
}
@@ -79,14 +80,14 @@ struct ConnPoolEntry {
// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub struct EndpointConnPool {
pools: HashMap<(SmolStr, SmolStr), DbUserConnPool>,
pools: HashMap<(DbName, RoleName), DbUserConnPool>,
total_conns: usize,
max_conns: usize,
_guard: IntCounterPairGuard,
}
impl EndpointConnPool {
fn get_conn_entry(&mut self, db_user: (SmolStr, SmolStr)) -> Option<ConnPoolEntry> {
fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry> {
let Self {
pools, total_conns, ..
} = self;
@@ -95,7 +96,7 @@ impl EndpointConnPool {
.and_then(|pool_entries| pool_entries.get_conn_entry(total_conns))
}
fn remove_client(&mut self, db_user: (SmolStr, SmolStr), conn_id: uuid::Uuid) -> bool {
fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
let Self {
pools, total_conns, ..
} = self;
@@ -196,7 +197,7 @@ pub struct GlobalConnPool {
//
// That should be a fairly conteded map, so return reference to the per-endpoint
// pool as early as possible and release the lock.
global_pool: DashMap<SmolStr, Arc<RwLock<EndpointConnPool>>>,
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool>>>,
/// Number of endpoint-connection pools
///
@@ -440,7 +441,10 @@ impl GlobalConnPool {
Ok(Client::new(new_client, conn_info, endpoint_pool).await)
}
fn get_or_create_endpoint_pool(&self, endpoint: &SmolStr) -> Arc<RwLock<EndpointConnPool>> {
fn get_or_create_endpoint_pool(
&self,
endpoint: &EndpointCacheKey,
) -> Arc<RwLock<EndpointConnPool>> {
// fast path
if let Some(pool) = self.global_pool.get(endpoint) {
return pool.clone();

View File

@@ -13,7 +13,6 @@ use hyper::{Body, HeaderMap, Request};
use serde_json::json;
use serde_json::Map;
use serde_json::Value;
use smol_str::SmolStr;
use tokio_postgres::error::DbError;
use tokio_postgres::error::ErrorPosition;
use tokio_postgres::types::Kind;
@@ -36,6 +35,8 @@ use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
use crate::proxy::NeonOptions;
use crate::EndpointId;
use crate::RoleName;
use super::conn_pool::ConnInfo;
use super::conn_pool::GlobalConnPool;
@@ -155,7 +156,7 @@ fn get_conn_info(
.next()
.ok_or(anyhow::anyhow!("invalid database name"))?;
let username = SmolStr::from(connection_url.username());
let username = RoleName::from(connection_url.username());
if username.is_empty() {
return Err(anyhow::anyhow!("missing username"));
}
@@ -189,7 +190,7 @@ fn get_conn_info(
let endpoint = endpoint_sni(hostname, &tls.common_names)?;
let endpoint: SmolStr = endpoint.into();
let endpoint: EndpointId = endpoint.into();
ctx.set_endpoint_id(Some(endpoint.clone()));
let pairs = connection_url.query_pairs();

View File

@@ -1,12 +1,11 @@
//! Periodically collect proxy consumption metrics
//! and push them to a HTTP endpoint.
use crate::{config::MetricCollectionConfig, http};
use crate::{config::MetricCollectionConfig, http, BranchId, EndpointId};
use chrono::{DateTime, Utc};
use consumption_metrics::{idempotency_key, Event, EventChunk, EventType, CHUNK_SIZE};
use dashmap::{mapref::entry::Entry, DashMap};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use smol_str::SmolStr;
use std::{
convert::Infallible,
sync::{
@@ -30,8 +29,8 @@ const DEFAULT_HTTP_REPORTING_TIMEOUT: Duration = Duration::from_secs(60);
/// because we enrich the event with project_id in the control-plane endpoint.
#[derive(Eq, Hash, PartialEq, Serialize, Deserialize, Debug, Clone)]
pub struct Ids {
pub endpoint_id: SmolStr,
pub branch_id: SmolStr,
pub endpoint_id: EndpointId,
pub branch_id: BranchId,
}
#[derive(Debug)]