mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-07 13:32:57 +00:00
proxy: add newtype wrappers for string based IDs (#6445)
## Problem too many string based IDs. easy to mix up ID types. ## Summary of changes Add a bunch of `SmolStr` wrappers that provide convenience methods but are type safe
This commit is contained in:
4
Cargo.lock
generated
4
Cargo.lock
generated
@@ -5113,9 +5113,9 @@ checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9"
|
||||
|
||||
[[package]]
|
||||
name = "smol_str"
|
||||
version = "0.2.0"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "74212e6bbe9a4352329b2f68ba3130c15a3f26fe88ff22dbdc6cdd58fa85e99c"
|
||||
checksum = "e6845563ada680337a52d43bb0b29f396f2d911616f6573012645b9e3d048a49"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
@@ -3,7 +3,6 @@ mod hacks;
|
||||
mod link;
|
||||
|
||||
pub use link::LinkAuthError;
|
||||
use smol_str::SmolStr;
|
||||
use tokio_postgres::config::AuthKeys;
|
||||
|
||||
use crate::auth::credentials::check_peer_addr_is_in_list;
|
||||
@@ -16,7 +15,6 @@ use crate::context::RequestMonitoring;
|
||||
use crate::proxy::connect_compute::handle_try_wake;
|
||||
use crate::proxy::retry::retry_after;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::scram;
|
||||
use crate::stream::Stream;
|
||||
use crate::{
|
||||
auth::{self, ComputeUserInfoMaybeEndpoint},
|
||||
@@ -28,6 +26,7 @@ use crate::{
|
||||
},
|
||||
stream, url,
|
||||
};
|
||||
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
|
||||
use futures::TryFutureExt;
|
||||
use std::borrow::Cow;
|
||||
use std::ops::ControlFlow;
|
||||
@@ -130,19 +129,19 @@ pub struct ComputeCredentials<T> {
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ComputeUserInfoNoEndpoint {
|
||||
pub user: SmolStr,
|
||||
pub user: RoleName,
|
||||
pub options: NeonOptions,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ComputeUserInfo {
|
||||
pub endpoint: SmolStr,
|
||||
pub user: SmolStr,
|
||||
pub endpoint: EndpointId,
|
||||
pub user: RoleName,
|
||||
pub options: NeonOptions,
|
||||
}
|
||||
|
||||
impl ComputeUserInfo {
|
||||
pub fn endpoint_cache_key(&self) -> SmolStr {
|
||||
pub fn endpoint_cache_key(&self) -> EndpointCacheKey {
|
||||
self.options.get_cache_key(&self.endpoint)
|
||||
}
|
||||
}
|
||||
@@ -158,7 +157,7 @@ impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
|
||||
type Error = ComputeUserInfoNoEndpoint;
|
||||
|
||||
fn try_from(user_info: ComputeUserInfoMaybeEndpoint) -> Result<Self, Self::Error> {
|
||||
match user_info.project {
|
||||
match user_info.endpoint_id {
|
||||
None => Err(ComputeUserInfoNoEndpoint {
|
||||
user: user_info.user,
|
||||
options: user_info.options,
|
||||
@@ -317,11 +316,11 @@ async fn auth_and_wake_compute(
|
||||
|
||||
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
/// Get compute endpoint name from the credentials.
|
||||
pub fn get_endpoint(&self) -> Option<SmolStr> {
|
||||
pub fn get_endpoint(&self) -> Option<EndpointId> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(_, user_info) => user_info.project.clone(),
|
||||
Console(_, user_info) => user_info.endpoint_id.clone(),
|
||||
Link(_) => Some("link".into()),
|
||||
#[cfg(test)]
|
||||
Test(_) => Some("test".into()),
|
||||
@@ -355,7 +354,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
Console(api, user_info) => {
|
||||
info!(
|
||||
user = &*user_info.user,
|
||||
project = user_info.project(),
|
||||
project = user_info.endpoint(),
|
||||
"performing authentication using the console"
|
||||
);
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
use crate::{
|
||||
auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError,
|
||||
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions,
|
||||
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions, EndpointId, RoleName,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use pq_proto::StartupMessageParams;
|
||||
@@ -21,7 +21,10 @@ pub enum ComputeUserInfoParseError {
|
||||
SNI ('{}') and project option ('{}').",
|
||||
.domain, .option,
|
||||
)]
|
||||
InconsistentProjectNames { domain: SmolStr, option: SmolStr },
|
||||
InconsistentProjectNames {
|
||||
domain: EndpointId,
|
||||
option: EndpointId,
|
||||
},
|
||||
|
||||
#[error(
|
||||
"Common name inferred from SNI ('{}') is not known",
|
||||
@@ -30,7 +33,7 @@ pub enum ComputeUserInfoParseError {
|
||||
UnknownCommonName { cn: String },
|
||||
|
||||
#[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
|
||||
MalformedProjectName(SmolStr),
|
||||
MalformedProjectName(EndpointId),
|
||||
}
|
||||
|
||||
impl UserFacingError for ComputeUserInfoParseError {}
|
||||
@@ -39,17 +42,15 @@ impl UserFacingError for ComputeUserInfoParseError {}
|
||||
/// Note that we don't store any kind of client key or password here.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ComputeUserInfoMaybeEndpoint {
|
||||
pub user: SmolStr,
|
||||
// TODO: this is a severe misnomer! We should think of a new name ASAP.
|
||||
pub project: Option<SmolStr>,
|
||||
|
||||
pub user: RoleName,
|
||||
pub endpoint_id: Option<EndpointId>,
|
||||
pub options: NeonOptions,
|
||||
}
|
||||
|
||||
impl ComputeUserInfoMaybeEndpoint {
|
||||
#[inline]
|
||||
pub fn project(&self) -> Option<&str> {
|
||||
self.project.as_deref()
|
||||
pub fn endpoint(&self) -> Option<&str> {
|
||||
self.endpoint_id.as_deref()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,15 +80,15 @@ impl ComputeUserInfoMaybeEndpoint {
|
||||
|
||||
// Some parameters are stored in the startup message.
|
||||
let get_param = |key| params.get(key).ok_or(MissingKey(key));
|
||||
let user: SmolStr = get_param("user")?.into();
|
||||
let user: RoleName = get_param("user")?.into();
|
||||
|
||||
// record the values if we have them
|
||||
ctx.set_application(params.get("application_name").map(SmolStr::from));
|
||||
ctx.set_user(user.clone());
|
||||
ctx.set_endpoint_id(sni.map(SmolStr::from));
|
||||
ctx.set_endpoint_id(sni.map(EndpointId::from));
|
||||
|
||||
// Project name might be passed via PG's command-line options.
|
||||
let project_option = params
|
||||
let endpoint_option = params
|
||||
.options_raw()
|
||||
.and_then(|options| {
|
||||
// We support both `project` (deprecated) and `endpoint` options for backward compatibility.
|
||||
@@ -100,9 +101,9 @@ impl ComputeUserInfoMaybeEndpoint {
|
||||
})
|
||||
.map(|name| name.into());
|
||||
|
||||
let project_from_domain = if let Some(sni_str) = sni {
|
||||
let endpoint_from_domain = if let Some(sni_str) = sni {
|
||||
if let Some(cn) = common_names {
|
||||
Some(SmolStr::from(endpoint_sni(sni_str, cn)?))
|
||||
Some(EndpointId::from(endpoint_sni(sni_str, cn)?))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
@@ -110,7 +111,7 @@ impl ComputeUserInfoMaybeEndpoint {
|
||||
None
|
||||
};
|
||||
|
||||
let project = match (project_option, project_from_domain) {
|
||||
let endpoint = match (endpoint_option, endpoint_from_domain) {
|
||||
// Invariant: if we have both project name variants, they should match.
|
||||
(Some(option), Some(domain)) if option != domain => {
|
||||
Some(Err(InconsistentProjectNames { domain, option }))
|
||||
@@ -123,13 +124,13 @@ impl ComputeUserInfoMaybeEndpoint {
|
||||
}
|
||||
.transpose()?;
|
||||
|
||||
info!(%user, project = project.as_deref(), "credentials");
|
||||
info!(%user, project = endpoint.as_deref(), "credentials");
|
||||
if sni.is_some() {
|
||||
info!("Connection with sni");
|
||||
NUM_CONNECTION_ACCEPTED_BY_SNI
|
||||
.with_label_values(&["sni"])
|
||||
.inc();
|
||||
} else if project.is_some() {
|
||||
} else if endpoint.is_some() {
|
||||
NUM_CONNECTION_ACCEPTED_BY_SNI
|
||||
.with_label_values(&["no_sni"])
|
||||
.inc();
|
||||
@@ -145,7 +146,7 @@ impl ComputeUserInfoMaybeEndpoint {
|
||||
|
||||
Ok(Self {
|
||||
user,
|
||||
project,
|
||||
endpoint_id: endpoint.map(EndpointId::from),
|
||||
options,
|
||||
})
|
||||
}
|
||||
@@ -238,7 +239,7 @@ mod tests {
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.project, None);
|
||||
assert_eq!(user_info.endpoint_id, None);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -253,7 +254,7 @@ mod tests {
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.project, None);
|
||||
assert_eq!(user_info.endpoint_id, None);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -269,7 +270,7 @@ mod tests {
|
||||
let user_info =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.project.as_deref(), Some("foo"));
|
||||
assert_eq!(user_info.endpoint_id.as_deref(), Some("foo"));
|
||||
assert_eq!(user_info.options.get_cache_key("foo"), "foo");
|
||||
|
||||
Ok(())
|
||||
@@ -285,7 +286,7 @@ mod tests {
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.project.as_deref(), Some("bar"));
|
||||
assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -300,7 +301,7 @@ mod tests {
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.project.as_deref(), Some("bar"));
|
||||
assert_eq!(user_info.endpoint_id.as_deref(), Some("bar"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -318,7 +319,7 @@ mod tests {
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert!(user_info.project.is_none());
|
||||
assert!(user_info.endpoint_id.is_none());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -333,7 +334,7 @@ mod tests {
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert!(user_info.project.is_none());
|
||||
assert!(user_info.endpoint_id.is_none());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -349,7 +350,7 @@ mod tests {
|
||||
let user_info =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
|
||||
assert_eq!(user_info.user, "john_doe");
|
||||
assert_eq!(user_info.project.as_deref(), Some("baz"));
|
||||
assert_eq!(user_info.endpoint_id.as_deref(), Some("baz"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -363,14 +364,14 @@ mod tests {
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let user_info =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
|
||||
assert_eq!(user_info.project.as_deref(), Some("p1"));
|
||||
assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
|
||||
|
||||
let common_names = Some(["a.com".into(), "b.com".into()].into());
|
||||
let sni = Some("p1.b.com");
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let user_info =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
|
||||
assert_eq!(user_info.project.as_deref(), Some("p1"));
|
||||
assert_eq!(user_info.endpoint_id.as_deref(), Some("p1"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -427,7 +428,7 @@ mod tests {
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let user_info =
|
||||
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
|
||||
assert_eq!(user_info.project.as_deref(), Some("project"));
|
||||
assert_eq!(user_info.endpoint_id.as_deref(), Some("project"));
|
||||
assert_eq!(
|
||||
user_info.options.get_cache_key("project"),
|
||||
"project endpoint_type:read_write lsn:0/2"
|
||||
|
||||
@@ -4,10 +4,11 @@
|
||||
//! UPDATE (Mon Aug 8 13:20:34 UTC 2022): the payload format has been simplified.
|
||||
|
||||
use bstr::ByteSlice;
|
||||
use smol_str::SmolStr;
|
||||
|
||||
use crate::EndpointId;
|
||||
|
||||
pub struct PasswordHackPayload {
|
||||
pub endpoint: SmolStr,
|
||||
pub endpoint: EndpointId,
|
||||
pub password: Vec<u8>,
|
||||
}
|
||||
|
||||
|
||||
64
proxy/src/cache/project_info.rs
vendored
64
proxy/src/cache/project_info.rs
vendored
@@ -11,13 +11,16 @@ use smol_str::SmolStr;
|
||||
use tokio::time::Instant;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::{auth::IpPattern, config::ProjectInfoCacheOptions, console::AuthSecret};
|
||||
use crate::{
|
||||
auth::IpPattern, config::ProjectInfoCacheOptions, console::AuthSecret, EndpointId, ProjectId,
|
||||
RoleName,
|
||||
};
|
||||
|
||||
use super::{Cache, Cached};
|
||||
|
||||
pub trait ProjectInfoCache {
|
||||
fn invalidate_allowed_ips_for_project(&self, project_id: &SmolStr);
|
||||
fn invalidate_role_secret_for_project(&self, project_id: &SmolStr, role_name: &SmolStr);
|
||||
fn invalidate_allowed_ips_for_project(&self, project_id: &ProjectId);
|
||||
fn invalidate_role_secret_for_project(&self, project_id: &ProjectId, role_name: &RoleName);
|
||||
fn enable_ttl(&self);
|
||||
fn disable_ttl(&self);
|
||||
}
|
||||
@@ -44,7 +47,7 @@ impl<T> From<T> for Entry<T> {
|
||||
|
||||
#[derive(Default)]
|
||||
struct EndpointInfo {
|
||||
secret: std::collections::HashMap<SmolStr, Entry<Option<AuthSecret>>>,
|
||||
secret: std::collections::HashMap<RoleName, Entry<Option<AuthSecret>>>,
|
||||
allowed_ips: Option<Entry<Arc<Vec<IpPattern>>>>,
|
||||
}
|
||||
|
||||
@@ -57,7 +60,7 @@ impl EndpointInfo {
|
||||
}
|
||||
pub fn get_role_secret(
|
||||
&self,
|
||||
role_name: &SmolStr,
|
||||
role_name: &RoleName,
|
||||
valid_since: Instant,
|
||||
ignore_cache_since: Option<Instant>,
|
||||
) -> Option<(Option<AuthSecret>, bool)> {
|
||||
@@ -90,7 +93,7 @@ impl EndpointInfo {
|
||||
pub fn invalidate_allowed_ips(&mut self) {
|
||||
self.allowed_ips = None;
|
||||
}
|
||||
pub fn invalidate_role_secret(&mut self, role_name: &SmolStr) {
|
||||
pub fn invalidate_role_secret(&mut self, role_name: &RoleName) {
|
||||
self.secret.remove(role_name);
|
||||
}
|
||||
}
|
||||
@@ -103,9 +106,9 @@ impl EndpointInfo {
|
||||
/// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available?
|
||||
/// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache.
|
||||
pub struct ProjectInfoCacheImpl {
|
||||
cache: DashMap<SmolStr, EndpointInfo>,
|
||||
cache: DashMap<EndpointId, EndpointInfo>,
|
||||
|
||||
project2ep: DashMap<SmolStr, HashSet<SmolStr>>,
|
||||
project2ep: DashMap<ProjectId, HashSet<EndpointId>>,
|
||||
config: ProjectInfoCacheOptions,
|
||||
|
||||
start_time: Instant,
|
||||
@@ -113,7 +116,7 @@ pub struct ProjectInfoCacheImpl {
|
||||
}
|
||||
|
||||
impl ProjectInfoCache for ProjectInfoCacheImpl {
|
||||
fn invalidate_allowed_ips_for_project(&self, project_id: &SmolStr) {
|
||||
fn invalidate_allowed_ips_for_project(&self, project_id: &ProjectId) {
|
||||
info!("invalidating allowed ips for project `{}`", project_id);
|
||||
let endpoints = self
|
||||
.project2ep
|
||||
@@ -126,7 +129,7 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
|
||||
}
|
||||
}
|
||||
}
|
||||
fn invalidate_role_secret_for_project(&self, project_id: &SmolStr, role_name: &SmolStr) {
|
||||
fn invalidate_role_secret_for_project(&self, project_id: &ProjectId, role_name: &RoleName) {
|
||||
info!(
|
||||
"invalidating role secret for project_id `{}` and role_name `{}`",
|
||||
project_id, role_name
|
||||
@@ -167,8 +170,8 @@ impl ProjectInfoCacheImpl {
|
||||
|
||||
pub fn get_role_secret(
|
||||
&self,
|
||||
endpoint_id: &SmolStr,
|
||||
role_name: &SmolStr,
|
||||
endpoint_id: &EndpointId,
|
||||
role_name: &RoleName,
|
||||
) -> Option<Cached<&Self, Option<AuthSecret>>> {
|
||||
let (valid_since, ignore_cache_since) = self.get_cache_times();
|
||||
let endpoint_info = self.cache.get(endpoint_id)?;
|
||||
@@ -188,7 +191,7 @@ impl ProjectInfoCacheImpl {
|
||||
}
|
||||
pub fn get_allowed_ips(
|
||||
&self,
|
||||
endpoint_id: &SmolStr,
|
||||
endpoint_id: &EndpointId,
|
||||
) -> Option<Cached<&Self, Arc<Vec<IpPattern>>>> {
|
||||
let (valid_since, ignore_cache_since) = self.get_cache_times();
|
||||
let endpoint_info = self.cache.get(endpoint_id)?;
|
||||
@@ -205,9 +208,9 @@ impl ProjectInfoCacheImpl {
|
||||
}
|
||||
pub fn insert_role_secret(
|
||||
&self,
|
||||
project_id: &SmolStr,
|
||||
endpoint_id: &SmolStr,
|
||||
role_name: &SmolStr,
|
||||
project_id: &ProjectId,
|
||||
endpoint_id: &EndpointId,
|
||||
role_name: &RoleName,
|
||||
secret: Option<AuthSecret>,
|
||||
) {
|
||||
if self.cache.len() >= self.config.size {
|
||||
@@ -222,8 +225,8 @@ impl ProjectInfoCacheImpl {
|
||||
}
|
||||
pub fn insert_allowed_ips(
|
||||
&self,
|
||||
project_id: &SmolStr,
|
||||
endpoint_id: &SmolStr,
|
||||
project_id: &ProjectId,
|
||||
endpoint_id: &EndpointId,
|
||||
allowed_ips: Arc<Vec<IpPattern>>,
|
||||
) {
|
||||
if self.cache.len() >= self.config.size {
|
||||
@@ -236,7 +239,7 @@ impl ProjectInfoCacheImpl {
|
||||
.or_default()
|
||||
.allowed_ips = Some(allowed_ips.into());
|
||||
}
|
||||
fn inser_project2endpoint(&self, project_id: &SmolStr, endpoint_id: &SmolStr) {
|
||||
fn inser_project2endpoint(&self, project_id: &ProjectId, endpoint_id: &EndpointId) {
|
||||
if let Some(mut endpoints) = self.project2ep.get_mut(project_id) {
|
||||
endpoints.insert(endpoint_id.clone());
|
||||
} else {
|
||||
@@ -297,18 +300,18 @@ impl ProjectInfoCacheImpl {
|
||||
/// This is used to invalidate cache entries.
|
||||
pub struct CachedLookupInfo {
|
||||
/// Search by this key.
|
||||
endpoint_id: SmolStr,
|
||||
endpoint_id: EndpointId,
|
||||
lookup_type: LookupType,
|
||||
}
|
||||
|
||||
impl CachedLookupInfo {
|
||||
pub(self) fn new_role_secret(endpoint_id: SmolStr, role_name: SmolStr) -> Self {
|
||||
pub(self) fn new_role_secret(endpoint_id: EndpointId, role_name: RoleName) -> Self {
|
||||
Self {
|
||||
endpoint_id,
|
||||
lookup_type: LookupType::RoleSecret(role_name),
|
||||
}
|
||||
}
|
||||
pub(self) fn new_allowed_ips(endpoint_id: SmolStr) -> Self {
|
||||
pub(self) fn new_allowed_ips(endpoint_id: EndpointId) -> Self {
|
||||
Self {
|
||||
endpoint_id,
|
||||
lookup_type: LookupType::AllowedIps,
|
||||
@@ -317,7 +320,7 @@ impl CachedLookupInfo {
|
||||
}
|
||||
|
||||
enum LookupType {
|
||||
RoleSecret(SmolStr),
|
||||
RoleSecret(RoleName),
|
||||
AllowedIps,
|
||||
}
|
||||
|
||||
@@ -348,7 +351,6 @@ impl Cache for ProjectInfoCacheImpl {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{console::AuthSecret, scram::ServerSecret};
|
||||
use smol_str::SmolStr;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
#[tokio::test]
|
||||
@@ -362,8 +364,8 @@ mod tests {
|
||||
});
|
||||
let project_id = "project".into();
|
||||
let endpoint_id = "endpoint".into();
|
||||
let user1: SmolStr = "user1".into();
|
||||
let user2: SmolStr = "user2".into();
|
||||
let user1: RoleName = "user1".into();
|
||||
let user2: RoleName = "user2".into();
|
||||
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock(
|
||||
user1.as_str(),
|
||||
[1; 32],
|
||||
@@ -385,7 +387,7 @@ mod tests {
|
||||
assert_eq!(cached.value, secret2);
|
||||
|
||||
// Shouldn't add more than 2 roles.
|
||||
let user3: SmolStr = "user3".into();
|
||||
let user3: RoleName = "user3".into();
|
||||
let secret3 = Some(AuthSecret::Scram(ServerSecret::mock(
|
||||
user3.as_str(),
|
||||
[3; 32],
|
||||
@@ -420,8 +422,8 @@ mod tests {
|
||||
|
||||
let project_id = "project".into();
|
||||
let endpoint_id = "endpoint".into();
|
||||
let user1: SmolStr = "user1".into();
|
||||
let user2: SmolStr = "user2".into();
|
||||
let user1: RoleName = "user1".into();
|
||||
let user2: RoleName = "user2".into();
|
||||
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock(
|
||||
user1.as_str(),
|
||||
[1; 32],
|
||||
@@ -475,8 +477,8 @@ mod tests {
|
||||
|
||||
let project_id = "project".into();
|
||||
let endpoint_id = "endpoint".into();
|
||||
let user1: SmolStr = "user1".into();
|
||||
let user2: SmolStr = "user2".into();
|
||||
let user1: RoleName = "user1".into();
|
||||
let user2: RoleName = "user2".into();
|
||||
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock(
|
||||
user1.as_str(),
|
||||
[1; 32],
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
use serde::Deserialize;
|
||||
use smol_str::SmolStr;
|
||||
use std::fmt;
|
||||
|
||||
use crate::auth::IpPattern;
|
||||
|
||||
use crate::{BranchId, EndpointId, ProjectId};
|
||||
|
||||
/// Generic error response with human-readable description.
|
||||
/// Note that we can't always present it to user as is.
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -17,7 +18,7 @@ pub struct ConsoleError {
|
||||
pub struct GetRoleSecret {
|
||||
pub role_secret: Box<str>,
|
||||
pub allowed_ips: Option<Vec<IpPattern>>,
|
||||
pub project_id: Option<Box<str>>,
|
||||
pub project_id: Option<ProjectId>,
|
||||
}
|
||||
|
||||
// Manually implement debug to omit sensitive info.
|
||||
@@ -94,9 +95,9 @@ impl fmt::Debug for DatabaseInfo {
|
||||
/// Also known as `ProxyMetricsAuxInfo` in the console.
|
||||
#[derive(Debug, Deserialize, Clone, Default)]
|
||||
pub struct MetricsAuxInfo {
|
||||
pub endpoint_id: SmolStr,
|
||||
pub project_id: SmolStr,
|
||||
pub branch_id: SmolStr,
|
||||
pub endpoint_id: EndpointId,
|
||||
pub project_id: ProjectId,
|
||||
pub branch_id: BranchId,
|
||||
}
|
||||
|
||||
impl MetricsAuxInfo {
|
||||
|
||||
@@ -9,11 +9,10 @@ use crate::{
|
||||
compute,
|
||||
config::{CacheOptions, ProjectInfoCacheOptions},
|
||||
context::RequestMonitoring,
|
||||
scram,
|
||||
scram, EndpointCacheKey, ProjectId,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use smol_str::SmolStr;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
|
||||
use tokio::time::Instant;
|
||||
@@ -214,7 +213,7 @@ pub struct AuthInfo {
|
||||
/// List of IP addresses allowed for the autorization.
|
||||
pub allowed_ips: Vec<IpPattern>,
|
||||
/// Project ID. This is used for cache invalidation.
|
||||
pub project_id: Option<SmolStr>,
|
||||
pub project_id: Option<ProjectId>,
|
||||
}
|
||||
|
||||
/// Info for establishing a connection to a compute node.
|
||||
@@ -233,7 +232,7 @@ pub struct NodeInfo {
|
||||
pub allow_self_signed_compute: bool,
|
||||
}
|
||||
|
||||
pub type NodeInfoCache = TimedLru<SmolStr, NodeInfo>;
|
||||
pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>;
|
||||
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
|
||||
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
|
||||
pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
|
||||
@@ -345,7 +344,7 @@ impl ApiCaches {
|
||||
/// Various caches for [`console`](super).
|
||||
pub struct ApiLocks {
|
||||
name: &'static str,
|
||||
node_locks: DashMap<SmolStr, Arc<Semaphore>>,
|
||||
node_locks: DashMap<EndpointCacheKey, Arc<Semaphore>>,
|
||||
permits: usize,
|
||||
timeout: Duration,
|
||||
registered: prometheus::IntCounter,
|
||||
@@ -413,7 +412,7 @@ impl ApiLocks {
|
||||
|
||||
pub async fn get_wake_compute_permit(
|
||||
&self,
|
||||
key: &SmolStr,
|
||||
key: &EndpointCacheKey,
|
||||
) -> Result<WakeComputePermit, errors::WakeComputeError> {
|
||||
if self.permits == 0 {
|
||||
return Ok(WakeComputePermit { permit: None });
|
||||
|
||||
@@ -14,7 +14,6 @@ use crate::{
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use futures::TryFutureExt;
|
||||
use smol_str::SmolStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::config::SslMode;
|
||||
@@ -98,7 +97,7 @@ impl Api {
|
||||
Ok(AuthInfo {
|
||||
secret,
|
||||
allowed_ips,
|
||||
project_id: body.project_id.map(SmolStr::from),
|
||||
project_id: body.project_id,
|
||||
})
|
||||
}
|
||||
.map_err(crate::error::log_error)
|
||||
@@ -239,7 +238,7 @@ impl super::Api for Api {
|
||||
// for some time (highly depends on the console's scale-to-zero policy);
|
||||
// The connection info remains the same during that period of time,
|
||||
// which means that we might cache it to reduce the load and latency.
|
||||
if let Some(cached) = self.caches.node_info.get(&*key) {
|
||||
if let Some(cached) = self.caches.node_info.get(&key) {
|
||||
info!(key = &*key, "found cached compute node info");
|
||||
return Ok(cached);
|
||||
}
|
||||
|
||||
@@ -7,7 +7,10 @@ use std::net::IpAddr;
|
||||
use tokio::sync::mpsc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{console::messages::MetricsAuxInfo, error::ErrorKind, metrics::LatencyTimer};
|
||||
use crate::{
|
||||
console::messages::MetricsAuxInfo, error::ErrorKind, metrics::LatencyTimer, BranchId,
|
||||
EndpointId, ProjectId, RoleName,
|
||||
};
|
||||
|
||||
pub mod parquet;
|
||||
|
||||
@@ -26,10 +29,10 @@ pub struct RequestMonitoring {
|
||||
region: &'static str,
|
||||
|
||||
// filled in as they are discovered
|
||||
project: Option<SmolStr>,
|
||||
branch: Option<SmolStr>,
|
||||
endpoint_id: Option<SmolStr>,
|
||||
user: Option<SmolStr>,
|
||||
project: Option<ProjectId>,
|
||||
branch: Option<BranchId>,
|
||||
endpoint_id: Option<EndpointId>,
|
||||
user: Option<RoleName>,
|
||||
application: Option<SmolStr>,
|
||||
error_kind: Option<ErrorKind>,
|
||||
success: bool,
|
||||
@@ -86,7 +89,7 @@ impl RequestMonitoring {
|
||||
self.project = Some(x.project_id);
|
||||
}
|
||||
|
||||
pub fn set_endpoint_id(&mut self, endpoint_id: Option<SmolStr>) {
|
||||
pub fn set_endpoint_id(&mut self, endpoint_id: Option<EndpointId>) {
|
||||
self.endpoint_id = endpoint_id.or_else(|| self.endpoint_id.clone());
|
||||
}
|
||||
|
||||
@@ -94,7 +97,7 @@ impl RequestMonitoring {
|
||||
self.application = app.or_else(|| self.application.clone());
|
||||
}
|
||||
|
||||
pub fn set_user(&mut self, user: SmolStr) {
|
||||
pub fn set_user(&mut self, user: RoleName) {
|
||||
self.user = Some(user);
|
||||
}
|
||||
|
||||
|
||||
@@ -62,3 +62,79 @@ pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<Infallib
|
||||
pub fn flatten_err<T>(r: Result<anyhow::Result<T>, JoinError>) -> anyhow::Result<T> {
|
||||
r.context("join error").and_then(|x| x)
|
||||
}
|
||||
|
||||
macro_rules! smol_str_wrapper {
|
||||
($name:ident) => {
|
||||
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
|
||||
pub struct $name(smol_str::SmolStr);
|
||||
|
||||
impl $name {
|
||||
pub fn as_str(&self) -> &str {
|
||||
self.0.as_str()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for $name {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::cmp::PartialEq<T> for $name
|
||||
where
|
||||
smol_str::SmolStr: std::cmp::PartialEq<T>,
|
||||
{
|
||||
fn eq(&self, other: &T) -> bool {
|
||||
self.0.eq(other)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<T> for $name
|
||||
where
|
||||
smol_str::SmolStr: From<T>,
|
||||
{
|
||||
fn from(x: T) -> Self {
|
||||
Self(x.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<str> for $name {
|
||||
fn as_ref(&self) -> &str {
|
||||
self.0.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for $name {
|
||||
type Target = str;
|
||||
fn deref(&self) -> &str {
|
||||
&*self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> serde::de::Deserialize<'de> for $name {
|
||||
fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
|
||||
<smol_str::SmolStr as serde::de::Deserialize<'de>>::deserialize(d).map(Self)
|
||||
}
|
||||
}
|
||||
|
||||
impl serde::Serialize for $name {
|
||||
fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
|
||||
self.0.serialize(s)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// 90% of role name strings are 20 characters or less.
|
||||
smol_str_wrapper!(RoleName);
|
||||
// 50% of endpoint strings are 23 characters or less.
|
||||
smol_str_wrapper!(EndpointId);
|
||||
// 50% of branch strings are 23 characters or less.
|
||||
smol_str_wrapper!(BranchId);
|
||||
// 90% of project strings are 23 characters or less.
|
||||
smol_str_wrapper!(ProjectId);
|
||||
|
||||
// will usually equal endpoint ID
|
||||
smol_str_wrapper!(EndpointCacheKey);
|
||||
|
||||
smol_str_wrapper!(DbName);
|
||||
|
||||
@@ -19,6 +19,7 @@ use crate::{
|
||||
rate_limiter::EndpointRateLimiter,
|
||||
stream::{PqStream, Stream},
|
||||
usage_metrics::{Ids, USAGE_METRICS},
|
||||
EndpointCacheKey,
|
||||
};
|
||||
use anyhow::{bail, Context};
|
||||
use futures::TryFutureExt;
|
||||
@@ -26,7 +27,7 @@ use itertools::Itertools;
|
||||
use once_cell::sync::OnceCell;
|
||||
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
|
||||
use regex::Regex;
|
||||
use smol_str::SmolStr;
|
||||
use smol_str::{format_smolstr, SmolStr};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -516,20 +517,21 @@ impl NeonOptions {
|
||||
Self(options)
|
||||
}
|
||||
|
||||
pub fn get_cache_key(&self, prefix: &str) -> SmolStr {
|
||||
pub fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey {
|
||||
// prefix + format!(" {k}:{v}")
|
||||
// kinda jank because SmolStr is immutable
|
||||
std::iter::once(prefix)
|
||||
.chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
|
||||
.collect()
|
||||
.collect::<SmolStr>()
|
||||
.into()
|
||||
}
|
||||
|
||||
/// <https://swagger.io/docs/specification/serialization/> DeepObject format
|
||||
/// `paramName[prop1]=value1¶mName[prop2]=value2&...`
|
||||
pub fn to_deep_object(&self) -> Vec<(String, SmolStr)> {
|
||||
pub fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> {
|
||||
self.0
|
||||
.iter()
|
||||
.map(|(k, v)| (format!("options[{}]", k), v.clone()))
|
||||
.map(|(k, v)| (format_smolstr!("options[{}]", k), v.clone()))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,11 +11,12 @@ use anyhow::bail;
|
||||
use dashmap::DashMap;
|
||||
use itertools::Itertools;
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
use smol_str::SmolStr;
|
||||
use tokio::sync::{Mutex as AsyncMutex, Semaphore, SemaphorePermit};
|
||||
use tokio::time::{timeout, Duration, Instant};
|
||||
use tracing::info;
|
||||
|
||||
use crate::EndpointId;
|
||||
|
||||
use super::{
|
||||
limit_algorithm::{LimitAlgorithm, Sample},
|
||||
RateLimiterConfig,
|
||||
@@ -33,7 +34,7 @@ use super::{
|
||||
// does not look very nice (`SSL SYSCALL error: Undefined error: 0`), so for now
|
||||
// I went with a more expensive way that yields user-friendlier error messages.
|
||||
pub struct EndpointRateLimiter<Rand = StdRng, Hasher = RandomState> {
|
||||
map: DashMap<SmolStr, Vec<RateBucket>, Hasher>,
|
||||
map: DashMap<EndpointId, Vec<RateBucket>, Hasher>,
|
||||
info: &'static [RateBucketInfo],
|
||||
access_count: AtomicUsize,
|
||||
rand: Mutex<Rand>,
|
||||
@@ -146,7 +147,7 @@ impl<R: Rng, S: BuildHasher + Clone> EndpointRateLimiter<R, S> {
|
||||
}
|
||||
|
||||
/// Check that number of connections to the endpoint is below `max_rps` rps.
|
||||
pub fn check(&self, endpoint: SmolStr) -> bool {
|
||||
pub fn check(&self, endpoint: EndpointId) -> bool {
|
||||
// do a partial GC every 2k requests. This cleans up ~ 1/64th of the map.
|
||||
// worst case memory usage is about:
|
||||
// = 2 * 2048 * 64 * (48B + 72B)
|
||||
@@ -493,11 +494,13 @@ mod tests {
|
||||
use futures::{task::noop_waker_ref, Future};
|
||||
use rand::SeedableRng;
|
||||
use rustc_hash::FxHasher;
|
||||
use smol_str::SmolStr;
|
||||
use tokio::time;
|
||||
|
||||
use super::{EndpointRateLimiter, Limiter, Outcome};
|
||||
use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm};
|
||||
use crate::{
|
||||
rate_limiter::{RateBucketInfo, RateLimitAlgorithm},
|
||||
EndpointId,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_works() {
|
||||
@@ -654,7 +657,7 @@ mod tests {
|
||||
RateBucketInfo::validate(&mut rates).unwrap();
|
||||
let limiter = EndpointRateLimiter::new(Vec::leak(rates));
|
||||
|
||||
let endpoint = SmolStr::from("ep-my-endpoint-1234");
|
||||
let endpoint = EndpointId::from("ep-my-endpoint-1234");
|
||||
|
||||
time::pause();
|
||||
|
||||
|
||||
@@ -3,9 +3,8 @@ use std::{convert::Infallible, sync::Arc};
|
||||
use futures::StreamExt;
|
||||
use redis::aio::PubSub;
|
||||
use serde::Deserialize;
|
||||
use smol_str::SmolStr;
|
||||
|
||||
use crate::cache::project_info::ProjectInfoCache;
|
||||
use crate::{cache::project_info::ProjectInfoCache, ProjectId, RoleName};
|
||||
|
||||
const CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
|
||||
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
|
||||
@@ -46,12 +45,12 @@ enum Notification {
|
||||
}
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
struct AllowedIpsUpdate {
|
||||
project_id: SmolStr,
|
||||
project_id: ProjectId,
|
||||
}
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
struct PasswordUpdate {
|
||||
project_id: SmolStr,
|
||||
role_name: SmolStr,
|
||||
project_id: ProjectId,
|
||||
role_name: RoleName,
|
||||
}
|
||||
fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
|
||||
where
|
||||
|
||||
@@ -31,6 +31,7 @@ use crate::{
|
||||
metrics::NUM_DB_CONNECTIONS_GAUGE,
|
||||
proxy::connect_compute::ConnectMechanism,
|
||||
usage_metrics::{Ids, MetricCounter, USAGE_METRICS},
|
||||
DbName, EndpointCacheKey, RoleName,
|
||||
};
|
||||
use crate::{compute, config};
|
||||
|
||||
@@ -42,17 +43,17 @@ pub const APP_NAME: SmolStr = SmolStr::new_inline("/sql_over_http");
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConnInfo {
|
||||
pub user_info: ComputeUserInfo,
|
||||
pub dbname: SmolStr,
|
||||
pub dbname: DbName,
|
||||
pub password: SmolStr,
|
||||
}
|
||||
|
||||
impl ConnInfo {
|
||||
// hm, change to hasher to avoid cloning?
|
||||
pub fn db_and_user(&self) -> (SmolStr, SmolStr) {
|
||||
pub fn db_and_user(&self) -> (DbName, RoleName) {
|
||||
(self.dbname.clone(), self.user_info.user.clone())
|
||||
}
|
||||
|
||||
pub fn endpoint_cache_key(&self) -> SmolStr {
|
||||
pub fn endpoint_cache_key(&self) -> EndpointCacheKey {
|
||||
self.user_info.endpoint_cache_key()
|
||||
}
|
||||
}
|
||||
@@ -79,14 +80,14 @@ struct ConnPoolEntry {
|
||||
// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
|
||||
// Number of open connections is limited by the `max_conns_per_endpoint`.
|
||||
pub struct EndpointConnPool {
|
||||
pools: HashMap<(SmolStr, SmolStr), DbUserConnPool>,
|
||||
pools: HashMap<(DbName, RoleName), DbUserConnPool>,
|
||||
total_conns: usize,
|
||||
max_conns: usize,
|
||||
_guard: IntCounterPairGuard,
|
||||
}
|
||||
|
||||
impl EndpointConnPool {
|
||||
fn get_conn_entry(&mut self, db_user: (SmolStr, SmolStr)) -> Option<ConnPoolEntry> {
|
||||
fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry> {
|
||||
let Self {
|
||||
pools, total_conns, ..
|
||||
} = self;
|
||||
@@ -95,7 +96,7 @@ impl EndpointConnPool {
|
||||
.and_then(|pool_entries| pool_entries.get_conn_entry(total_conns))
|
||||
}
|
||||
|
||||
fn remove_client(&mut self, db_user: (SmolStr, SmolStr), conn_id: uuid::Uuid) -> bool {
|
||||
fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
|
||||
let Self {
|
||||
pools, total_conns, ..
|
||||
} = self;
|
||||
@@ -196,7 +197,7 @@ pub struct GlobalConnPool {
|
||||
//
|
||||
// That should be a fairly conteded map, so return reference to the per-endpoint
|
||||
// pool as early as possible and release the lock.
|
||||
global_pool: DashMap<SmolStr, Arc<RwLock<EndpointConnPool>>>,
|
||||
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool>>>,
|
||||
|
||||
/// Number of endpoint-connection pools
|
||||
///
|
||||
@@ -440,7 +441,10 @@ impl GlobalConnPool {
|
||||
Ok(Client::new(new_client, conn_info, endpoint_pool).await)
|
||||
}
|
||||
|
||||
fn get_or_create_endpoint_pool(&self, endpoint: &SmolStr) -> Arc<RwLock<EndpointConnPool>> {
|
||||
fn get_or_create_endpoint_pool(
|
||||
&self,
|
||||
endpoint: &EndpointCacheKey,
|
||||
) -> Arc<RwLock<EndpointConnPool>> {
|
||||
// fast path
|
||||
if let Some(pool) = self.global_pool.get(endpoint) {
|
||||
return pool.clone();
|
||||
|
||||
@@ -13,7 +13,6 @@ use hyper::{Body, HeaderMap, Request};
|
||||
use serde_json::json;
|
||||
use serde_json::Map;
|
||||
use serde_json::Value;
|
||||
use smol_str::SmolStr;
|
||||
use tokio_postgres::error::DbError;
|
||||
use tokio_postgres::error::ErrorPosition;
|
||||
use tokio_postgres::types::Kind;
|
||||
@@ -36,6 +35,8 @@ use crate::config::TlsConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::EndpointId;
|
||||
use crate::RoleName;
|
||||
|
||||
use super::conn_pool::ConnInfo;
|
||||
use super::conn_pool::GlobalConnPool;
|
||||
@@ -155,7 +156,7 @@ fn get_conn_info(
|
||||
.next()
|
||||
.ok_or(anyhow::anyhow!("invalid database name"))?;
|
||||
|
||||
let username = SmolStr::from(connection_url.username());
|
||||
let username = RoleName::from(connection_url.username());
|
||||
if username.is_empty() {
|
||||
return Err(anyhow::anyhow!("missing username"));
|
||||
}
|
||||
@@ -189,7 +190,7 @@ fn get_conn_info(
|
||||
|
||||
let endpoint = endpoint_sni(hostname, &tls.common_names)?;
|
||||
|
||||
let endpoint: SmolStr = endpoint.into();
|
||||
let endpoint: EndpointId = endpoint.into();
|
||||
ctx.set_endpoint_id(Some(endpoint.clone()));
|
||||
|
||||
let pairs = connection_url.query_pairs();
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
//! Periodically collect proxy consumption metrics
|
||||
//! and push them to a HTTP endpoint.
|
||||
use crate::{config::MetricCollectionConfig, http};
|
||||
use crate::{config::MetricCollectionConfig, http, BranchId, EndpointId};
|
||||
use chrono::{DateTime, Utc};
|
||||
use consumption_metrics::{idempotency_key, Event, EventChunk, EventType, CHUNK_SIZE};
|
||||
use dashmap::{mapref::entry::Entry, DashMap};
|
||||
use once_cell::sync::Lazy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol_str::SmolStr;
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
sync::{
|
||||
@@ -30,8 +29,8 @@ const DEFAULT_HTTP_REPORTING_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
/// because we enrich the event with project_id in the control-plane endpoint.
|
||||
#[derive(Eq, Hash, PartialEq, Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct Ids {
|
||||
pub endpoint_id: SmolStr,
|
||||
pub branch_id: SmolStr,
|
||||
pub endpoint_id: EndpointId,
|
||||
pub branch_id: BranchId,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
||||
Reference in New Issue
Block a user