diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 3555eba543..f757a15fbb 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -35,7 +35,7 @@ use crate::{ }, stream, url, }; -use crate::{scram, EndpointCacheKey, EndpointId, Normalize, RoleName}; +use crate::{scram, EndpointCacheKey, EndpointId, RoleName}; /// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality pub enum MaybeOwned<'a, T> { diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index 5d691e5f15..d72229b029 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -13,7 +13,7 @@ use crate::{ http, metrics::{CacheOutcome, Metrics}, rate_limiter::EndpointRateLimiter, - scram, EndpointCacheKey, Normalize, + scram, EndpointCacheKey, }; use crate::{cache::Cached, context::RequestMonitoring}; use futures::TryFutureExt; @@ -281,14 +281,6 @@ impl super::Api for Api { return Ok(cached); } - // check rate limit - if !self - .wake_compute_endpoint_rate_limiter - .check(user_info.endpoint.normalize().into(), 1) - { - return Err(WakeComputeError::TooManyConnections); - } - let permit = self.locks.get_permit(&key).await?; // after getting back a permit - it's possible the cache was filled @@ -301,6 +293,15 @@ impl super::Api for Api { } } + // check rate limit + if !self + .wake_compute_endpoint_rate_limiter + .check(user_info.endpoint.normalize_intern(), 1) + { + info!(key = &*key, "found cached compute node info"); + return Err(WakeComputeError::TooManyConnections); + } + let mut node = permit.release_result(self.do_wake_compute(ctx, user_info).await)?; ctx.set_project(node.aux.clone()); let cold_start_info = node.aux.cold_start_info; diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index 35c1616481..ea92eaaa55 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -3,6 +3,7 @@ use std::convert::Infallible; use anyhow::{bail, Context}; +use intern::{EndpointIdInt, EndpointIdTag, InternId}; use tokio::task::JoinError; use tokio_util::sync::CancellationToken; use tracing::warn; @@ -129,20 +130,22 @@ macro_rules! smol_str_wrapper { const POOLER_SUFFIX: &str = "-pooler"; -pub trait Normalize { - fn normalize(&self) -> Self; -} - -impl + From> Normalize for S { +impl EndpointId { fn normalize(&self) -> Self { - if self.as_ref().ends_with(POOLER_SUFFIX) { - let mut s = self.as_ref().to_string(); - s.truncate(s.len() - POOLER_SUFFIX.len()); - s.into() + if let Some(stripped) = self.as_ref().strip_suffix(POOLER_SUFFIX) { + stripped.into() } else { self.clone() } } + + fn normalize_intern(&self) -> EndpointIdInt { + if let Some(stripped) = self.as_ref().strip_suffix(POOLER_SUFFIX) { + EndpointIdTag::get_interner().get_or_intern(stripped) + } else { + self.into() + } + } } // 90% of role name strings are 20 characters or less.