diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs index 3f53ee24c3..2185677159 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -16,7 +16,7 @@ use crate::context::RequestMonitoring; use crate::control_plane::errors::GetEndpointJwksError; use crate::http::parse_json_body_with_limit; use crate::intern::RoleNameInt; -use crate::{EndpointId, RoleName}; +use crate::types::{EndpointId, RoleName}; // TODO(conrad): make these configurable. const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30); @@ -669,7 +669,7 @@ mod tests { use tokio::net::TcpListener; use super::*; - use crate::RoleName; + use crate::types::RoleName; fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) { let sk = p256::SecretKey::random(&mut OsRng); diff --git a/proxy/src/auth/backend/local.rs b/proxy/src/auth/backend/local.rs index 1e029ff609..f9cb085daf 100644 --- a/proxy/src/auth/backend/local.rs +++ b/proxy/src/auth/backend/local.rs @@ -10,9 +10,10 @@ use crate::compute_ctl::ComputeCtlApi; use crate::context::RequestMonitoring; use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, MetricsAuxInfo}; use crate::control_plane::NodeInfo; +use crate::http; use crate::intern::{BranchIdTag, EndpointIdTag, InternId, ProjectIdTag}; +use crate::types::EndpointId; use crate::url::ApiUrl; -use crate::{http, EndpointId}; pub struct LocalBackend { pub(crate) initialize: Semaphore, diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index a4db130b61..17334b9cbb 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -32,7 +32,8 @@ use crate::proxy::connect_compute::ComputeConnectBackend; use crate::proxy::NeonOptions; use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo}; use crate::stream::Stream; -use crate::{scram, stream, EndpointCacheKey, EndpointId, RoleName}; +use crate::types::{EndpointCacheKey, EndpointId, RoleName}; +use crate::{scram, stream}; /// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality pub enum MaybeOwned<'a, T> { @@ -551,7 +552,7 @@ mod tests { async fn get_endpoint_jwks( &self, _ctx: &RequestMonitoring, - _endpoint: crate::EndpointId, + _endpoint: crate::types::EndpointId, ) -> Result, control_plane::errors::GetEndpointJwksError> { unimplemented!() diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 465e427f7c..ddecae6af5 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -15,7 +15,7 @@ use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, SniKind}; use crate::proxy::NeonOptions; use crate::serverless::SERVERLESS_DRIVER_SNI; -use crate::{EndpointId, RoleName}; +use crate::types::{EndpointId, RoleName}; #[derive(Debug, Error, PartialEq, Eq, Clone)] pub(crate) enum ComputeUserInfoParseError { diff --git a/proxy/src/auth/password_hack.rs b/proxy/src/auth/password_hack.rs index 8585b8ff48..b934c28a78 100644 --- a/proxy/src/auth/password_hack.rs +++ b/proxy/src/auth/password_hack.rs @@ -5,7 +5,7 @@ use bstr::ByteSlice; -use crate::EndpointId; +use crate::types::EndpointId; pub(crate) struct PasswordHackPayload { pub(crate) endpoint: EndpointId, diff --git a/proxy/src/bin/local_proxy.rs b/proxy/src/bin/local_proxy.rs index a16c288e5d..df3628465f 100644 --- a/proxy/src/bin/local_proxy.rs +++ b/proxy/src/bin/local_proxy.rs @@ -25,8 +25,8 @@ use proxy::rate_limiter::{ use proxy::scram::threadpool::ThreadPool; use proxy::serverless::cancel_set::CancelSet; use proxy::serverless::{self, GlobalConnPoolOptions}; +use proxy::types::RoleName; use proxy::url::ApiUrl; -use proxy::RoleName; project_git_version!(GIT_VERSION); project_build_tag!(BUILD_TAG); @@ -177,7 +177,7 @@ async fn main() -> anyhow::Result<()> { let mut maintenance_tasks = JoinSet::new(); let refresh_config_notify = Arc::new(Notify::new()); - maintenance_tasks.spawn(proxy::handle_signals(shutdown.clone(), { + maintenance_tasks.spawn(proxy::signals::handle(shutdown.clone(), { let refresh_config_notify = Arc::clone(&refresh_config_notify); move || { refresh_config_notify.notify_one(); @@ -216,7 +216,7 @@ async fn main() -> anyhow::Result<()> { match futures::future::select(pin!(maintenance_tasks.join_next()), pin!(task)).await { // exit immediately on maintenance task completion - Either::Left((Some(res), _)) => match proxy::flatten_err(res)? {}, + Either::Left((Some(res), _)) => match proxy::error::flatten_err(res)? {}, // exit with error immediately if all maintenance tasks have ceased (should be caught by branch above) Either::Left((None, _)) => bail!("no maintenance tasks running. invalid state"), // exit immediately on client task error diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs index 13b7fdd40a..025053d3cb 100644 --- a/proxy/src/bin/pg_sni_router.rs +++ b/proxy/src/bin/pg_sni_router.rs @@ -133,14 +133,14 @@ async fn main() -> anyhow::Result<()> { proxy_listener, cancellation_token.clone(), )); - let signals_task = tokio::spawn(proxy::handle_signals(cancellation_token, || {})); + let signals_task = tokio::spawn(proxy::signals::handle(cancellation_token, || {})); // the signal task cant ever succeed. // the main task can error, or can succeed on cancellation. // we want to immediately exit on either of these cases let signal = match futures::future::select(signals_task, main).await { - Either::Left((res, _)) => proxy::flatten_err(res)?, - Either::Right((res, _)) => return proxy::flatten_err(res), + Either::Left((res, _)) => proxy::error::flatten_err(res)?, + Either::Right((res, _)) => return proxy::error::flatten_err(res), }; // maintenance tasks return `Infallible` success values, this is an impossible value diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 96a71e69c6..6e190029aa 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -495,7 +495,7 @@ async fn main() -> anyhow::Result<()> { // maintenance tasks. these never return unless there's an error let mut maintenance_tasks = JoinSet::new(); - maintenance_tasks.spawn(proxy::handle_signals(cancellation_token.clone(), || {})); + maintenance_tasks.spawn(proxy::signals::handle(cancellation_token.clone(), || {})); maintenance_tasks.spawn(http::health_server::task_main( http_listener, AppMetrics { @@ -561,11 +561,11 @@ async fn main() -> anyhow::Result<()> { .await { // exit immediately on maintenance task completion - Either::Left((Some(res), _)) => break proxy::flatten_err(res)?, + Either::Left((Some(res), _)) => break proxy::error::flatten_err(res)?, // exit with error immediately if all maintenance tasks have ceased (should be caught by branch above) Either::Left((None, _)) => bail!("no maintenance tasks running. invalid state"), // exit immediately on client task error - Either::Right((Some(res), _)) => proxy::flatten_err(res)?, + Either::Right((Some(res), _)) => proxy::error::flatten_err(res)?, // exit if all our client tasks have shutdown gracefully Either::Right((None, _)) => return Ok(()), } diff --git a/proxy/src/cache/endpoints.rs b/proxy/src/cache/endpoints.rs index 82f3247fa7..12c33169bf 100644 --- a/proxy/src/cache/endpoints.rs +++ b/proxy/src/cache/endpoints.rs @@ -17,7 +17,7 @@ use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt}; use crate::metrics::{Metrics, RedisErrors, RedisEventsCount}; use crate::rate_limiter::GlobalRateLimiter; use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider; -use crate::EndpointId; +use crate::types::EndpointId; #[derive(Deserialize, Debug, Clone)] pub(crate) struct ControlPlaneEventKey { diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 31d1dc96e7..84430dc812 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -17,7 +17,7 @@ use crate::auth::IpPattern; use crate::config::ProjectInfoCacheOptions; use crate::control_plane::AuthSecret; use crate::intern::{EndpointIdInt, ProjectIdInt, RoleNameInt}; -use crate::{EndpointId, RoleName}; +use crate::types::{EndpointId, RoleName}; #[async_trait] pub(crate) trait ProjectInfoCache { @@ -368,7 +368,7 @@ impl Cache for ProjectInfoCacheImpl { mod tests { use super::*; use crate::scram::ServerSecret; - use crate::ProjectId; + use crate::types::ProjectId; #[tokio::test] async fn test_project_info_cache_settings() { diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index a7c2cab4a1..b97942ee5d 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -25,7 +25,7 @@ use crate::control_plane::provider::ApiLockError; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumDbConnectionsGuard}; use crate::proxy::neon_option; -use crate::Host; +use crate::types::Host; pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node"; diff --git a/proxy/src/compute_ctl/mod.rs b/proxy/src/compute_ctl/mod.rs index 2b57897223..60fdf107d4 100644 --- a/proxy/src/compute_ctl/mod.rs +++ b/proxy/src/compute_ctl/mod.rs @@ -4,8 +4,9 @@ use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use thiserror::Error; +use crate::http; +use crate::types::{DbName, RoleName}; use crate::url::ApiUrl; -use crate::{http, DbName, RoleName}; pub struct ComputeCtlApi { pub(crate) api: http::Endpoint, diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 3baa7ec751..5183f22fa3 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -20,7 +20,7 @@ use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig} use crate::scram::threadpool::ThreadPool; use crate::serverless::cancel_set::CancelSet; use crate::serverless::GlobalConnPoolOptions; -use crate::Host; +use crate::types::Host; pub struct ProxyConfig { pub tls_config: Option, diff --git a/proxy/src/context/mod.rs b/proxy/src/context/mod.rs index e2d2c1b766..ca3b808a1b 100644 --- a/proxy/src/context/mod.rs +++ b/proxy/src/context/mod.rs @@ -19,7 +19,7 @@ use crate::intern::{BranchIdInt, ProjectIdInt}; use crate::metrics::{ ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol, Waiting, }; -use crate::{DbName, EndpointId, RoleName}; +use crate::types::{DbName, EndpointId, RoleName}; pub mod parquet; diff --git a/proxy/src/control_plane/provider/mock.rs b/proxy/src/control_plane/provider/mock.rs index fb061376e7..75a242d8d3 100644 --- a/proxy/src/control_plane/provider/mock.rs +++ b/proxy/src/control_plane/provider/mock.rs @@ -21,8 +21,9 @@ use crate::control_plane::messages::MetricsAuxInfo; use crate::control_plane::provider::{CachedAllowedIps, CachedRoleSecret}; use crate::error::io_error; use crate::intern::RoleNameInt; +use crate::types::{BranchId, EndpointId, ProjectId, RoleName}; use crate::url::ApiUrl; -use crate::{compute, scram, BranchId, EndpointId, ProjectId, RoleName}; +use crate::{compute, scram}; #[derive(Debug, Error)] enum MockApiError { diff --git a/proxy/src/control_plane/provider/mod.rs b/proxy/src/control_plane/provider/mod.rs index 88399dffa8..49e57b6b7e 100644 --- a/proxy/src/control_plane/provider/mod.rs +++ b/proxy/src/control_plane/provider/mod.rs @@ -23,7 +23,8 @@ use crate::error::ReportableError; use crate::intern::ProjectIdInt; use crate::metrics::ApiLockMetrics; use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token}; -use crate::{compute, scram, EndpointCacheKey, EndpointId}; +use crate::types::{EndpointCacheKey, EndpointId}; +use crate::{compute, scram}; pub(crate) mod errors { use thiserror::Error; diff --git a/proxy/src/control_plane/provider/neon.rs b/proxy/src/control_plane/provider/neon.rs index 5d0692c7ca..8ea91d7875 100644 --- a/proxy/src/control_plane/provider/neon.rs +++ b/proxy/src/control_plane/provider/neon.rs @@ -24,7 +24,8 @@ use crate::control_plane::errors::GetEndpointJwksError; use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason}; use crate::metrics::{CacheOutcome, Metrics}; use crate::rate_limiter::WakeComputeRateLimiter; -use crate::{compute, http, scram, EndpointCacheKey, EndpointId}; +use crate::types::{EndpointCacheKey, EndpointId}; +use crate::{compute, http, scram}; const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id"); diff --git a/proxy/src/error.rs b/proxy/src/error.rs index e71ed0c048..7b693a7418 100644 --- a/proxy/src/error.rs +++ b/proxy/src/error.rs @@ -1,7 +1,9 @@ use std::error::Error as StdError; use std::{fmt, io}; +use anyhow::Context; use measured::FixedCardinalityLabel; +use tokio::task::JoinError; /// Upcast (almost) any error into an opaque [`io::Error`]. pub(crate) fn io_error(e: impl Into>) -> io::Error { @@ -97,3 +99,8 @@ impl ReportableError for tokio_postgres::error::Error { } } } + +/// Flattens `Result>` into `Result`. +pub fn flatten_err(r: Result, JoinError>) -> anyhow::Result { + r.context("join error").and_then(|x| x) +} diff --git a/proxy/src/intern.rs b/proxy/src/intern.rs index 49aab917e4..f56d92a6b3 100644 --- a/proxy/src/intern.rs +++ b/proxy/src/intern.rs @@ -7,7 +7,7 @@ use std::sync::OnceLock; use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo}; use rustc_hash::FxHasher; -use crate::{BranchId, EndpointId, ProjectId, RoleName}; +use crate::types::{BranchId, EndpointId, ProjectId, RoleName}; pub trait InternId: Sized + 'static { fn get_interner() -> &'static StringInterner; diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index ea17a88067..f95d645c23 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -78,14 +78,6 @@ // List of temporarily allowed lints to unblock beta/nightly. #![allow(unknown_lints)] -use std::convert::Infallible; - -use anyhow::{bail, Context}; -use intern::{EndpointIdInt, EndpointIdTag, InternId}; -use tokio::task::JoinError; -use tokio_util::sync::CancellationToken; -use tracing::warn; - pub mod auth; pub mod cache; pub mod cancellation; @@ -109,165 +101,9 @@ pub mod redis; pub mod sasl; pub mod scram; pub mod serverless; +pub mod signals; pub mod stream; +pub mod types; pub mod url; pub mod usage_metrics; pub mod waiters; - -/// Handle unix signals appropriately. -pub async fn handle_signals( - token: CancellationToken, - mut refresh_config: F, -) -> anyhow::Result -where - F: FnMut(), -{ - use tokio::signal::unix::{signal, SignalKind}; - - let mut hangup = signal(SignalKind::hangup())?; - let mut interrupt = signal(SignalKind::interrupt())?; - let mut terminate = signal(SignalKind::terminate())?; - - loop { - tokio::select! { - // Hangup is commonly used for config reload. - _ = hangup.recv() => { - warn!("received SIGHUP"); - refresh_config(); - } - // Shut down the whole application. - _ = interrupt.recv() => { - warn!("received SIGINT, exiting immediately"); - bail!("interrupted"); - } - _ = terminate.recv() => { - warn!("received SIGTERM, shutting down once all existing connections have closed"); - token.cancel(); - } - } - } -} - -/// Flattens `Result>` into `Result`. -pub fn flatten_err(r: Result, JoinError>) -> anyhow::Result { - r.context("join error").and_then(|x| x) -} - -macro_rules! smol_str_wrapper { - ($name:ident) => { - #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] - pub struct $name(smol_str::SmolStr); - - impl $name { - #[allow(unused)] - pub(crate) fn as_str(&self) -> &str { - self.0.as_str() - } - } - - impl std::fmt::Display for $name { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.0.fmt(f) - } - } - - impl std::cmp::PartialEq for $name - where - smol_str::SmolStr: std::cmp::PartialEq, - { - fn eq(&self, other: &T) -> bool { - self.0.eq(other) - } - } - - impl From for $name - where - smol_str::SmolStr: From, - { - fn from(x: T) -> Self { - Self(x.into()) - } - } - - impl AsRef for $name { - fn as_ref(&self) -> &str { - self.0.as_ref() - } - } - - impl std::ops::Deref for $name { - type Target = str; - fn deref(&self) -> &str { - &*self.0 - } - } - - impl<'de> serde::de::Deserialize<'de> for $name { - fn deserialize>(d: D) -> Result { - >::deserialize(d).map(Self) - } - } - - impl serde::Serialize for $name { - fn serialize(&self, s: S) -> Result { - self.0.serialize(s) - } - } - }; -} - -const POOLER_SUFFIX: &str = "-pooler"; - -impl EndpointId { - fn normalize(&self) -> Self { - if let Some(stripped) = self.as_ref().strip_suffix(POOLER_SUFFIX) { - stripped.into() - } else { - self.clone() - } - } - - fn normalize_intern(&self) -> EndpointIdInt { - if let Some(stripped) = self.as_ref().strip_suffix(POOLER_SUFFIX) { - EndpointIdTag::get_interner().get_or_intern(stripped) - } else { - self.into() - } - } -} - -// 90% of role name strings are 20 characters or less. -smol_str_wrapper!(RoleName); -// 50% of endpoint strings are 23 characters or less. -smol_str_wrapper!(EndpointId); -// 50% of branch strings are 23 characters or less. -smol_str_wrapper!(BranchId); -// 90% of project strings are 23 characters or less. -smol_str_wrapper!(ProjectId); - -// will usually equal endpoint ID -smol_str_wrapper!(EndpointCacheKey); - -smol_str_wrapper!(DbName); - -// postgres hostname, will likely be a port:ip addr -smol_str_wrapper!(Host); - -// Endpoints are a bit tricky. Rare they might be branches or projects. -impl EndpointId { - pub(crate) fn is_endpoint(&self) -> bool { - self.0.starts_with("ep-") - } - pub(crate) fn is_branch(&self) -> bool { - self.0.starts_with("br-") - } - // pub(crate) fn is_project(&self) -> bool { - // !self.is_endpoint() && !self.is_branch() - // } - pub(crate) fn as_branch(&self) -> BranchId { - BranchId(self.0.clone()) - } - pub(crate) fn as_project(&self) -> ProjectId { - ProjectId(self.0.clone()) - } -} diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 8e9663626a..659b7afa68 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -17,7 +17,7 @@ use crate::metrics::{ }; use crate::proxy::retry::{retry_after, should_retry, CouldRetry}; use crate::proxy::wake_compute::wake_compute; -use crate::Host; +use crate::types::Host; const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2); diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index f646862caa..2970d93393 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -32,7 +32,8 @@ use crate::protocol2::read_proxy_protocol; use crate::proxy::handshake::{handshake, HandshakeData}; use crate::rate_limiter::EndpointRateLimiter; use crate::stream::{PqStream, Stream}; -use crate::{auth, compute, EndpointCacheKey}; +use crate::types::EndpointCacheKey; +use crate::{auth, compute}; const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 3f54b0661b..fe62fee204 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -28,7 +28,8 @@ use crate::control_plane::provider::{ }; use crate::control_plane::{self, CachedNodeInfo, NodeInfo}; use crate::error::ErrorKind; -use crate::{sasl, scram, BranchId, EndpointId, ProjectId}; +use crate::types::{BranchId, EndpointId, ProjectId}; +use crate::{sasl, scram}; /// Generate a set of TLS certificates: CA + server. fn generate_certs( diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index 5de64c2254..4259fd04f4 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -250,7 +250,7 @@ mod tests { use super::{BucketRateLimiter, WakeComputeRateLimiter}; use crate::intern::EndpointIdInt; use crate::rate_limiter::RateBucketInfo; - use crate::EndpointId; + use crate::types::EndpointId; #[test] fn rate_bucket_rpi() { diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index e56c5a3414..62e7b1b565 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -271,7 +271,7 @@ mod tests { use serde_json::json; use super::*; - use crate::{ProjectId, RoleName}; + use crate::types::{ProjectId, RoleName}; #[test] fn parse_allowed_ips() -> anyhow::Result<()> { diff --git a/proxy/src/scram/mod.rs b/proxy/src/scram/mod.rs index 97644b6282..718445f61d 100644 --- a/proxy/src/scram/mod.rs +++ b/proxy/src/scram/mod.rs @@ -62,7 +62,7 @@ mod tests { use super::{Exchange, ServerSecret}; use crate::intern::EndpointIdInt; use crate::sasl::{Mechanism, Step}; - use crate::EndpointId; + use crate::types::EndpointId; #[test] fn snapshot() { diff --git a/proxy/src/scram/threadpool.rs b/proxy/src/scram/threadpool.rs index cc1b69fcf9..ebc6dd2a3c 100644 --- a/proxy/src/scram/threadpool.rs +++ b/proxy/src/scram/threadpool.rs @@ -189,7 +189,7 @@ impl Drop for JobHandle { #[cfg(test)] mod tests { use super::*; - use crate::EndpointId; + use crate::types::EndpointId; #[tokio::test] async fn hash_is_correct() { diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 5d59b4d252..07e0e30148 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -18,6 +18,7 @@ use super::local_conn_pool::{self, LocalClient, LocalConnPool, EXT_NAME, EXT_SCH use crate::auth::backend::local::StaticAuthRules; use crate::auth::backend::{ComputeCredentials, ComputeUserInfo}; use crate::auth::{self, check_peer_addr_is_in_list, AuthError}; +use crate::compute; use crate::compute_ctl::{ ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest, }; @@ -32,7 +33,7 @@ use crate::intern::EndpointIdInt; use crate::proxy::connect_compute::ConnectMechanism; use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute}; use crate::rate_limiter::EndpointRateLimiter; -use crate::{compute, EndpointId, Host}; +use crate::types::{EndpointId, Host}; pub(crate) struct PoolingBackend { pub(crate) http_conn_pool: Arc>, diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 8401e3a1c9..7fa3357b5b 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -211,7 +211,7 @@ mod tests { use super::*; use crate::proxy::NeonOptions; use crate::serverless::cancel_set::CancelSet; - use crate::{BranchId, EndpointId, ProjectId}; + use crate::types::{BranchId, EndpointId, ProjectId}; struct MockClient(Arc); impl MockClient { diff --git a/proxy/src/serverless/conn_pool_lib.rs b/proxy/src/serverless/conn_pool_lib.rs index 844730194d..8830cddf0c 100644 --- a/proxy/src/serverless/conn_pool_lib.rs +++ b/proxy/src/serverless/conn_pool_lib.rs @@ -16,8 +16,8 @@ use crate::auth::backend::ComputeUserInfo; use crate::context::RequestMonitoring; use crate::control_plane::messages::ColdStartInfo; use crate::metrics::{HttpEndpointPoolsGuard, Metrics}; +use crate::types::{DbName, EndpointCacheKey, RoleName}; use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; -use crate::{DbName, EndpointCacheKey, RoleName}; #[derive(Debug, Clone)] pub(crate) struct ConnInfo { diff --git a/proxy/src/serverless/http_conn_pool.rs b/proxy/src/serverless/http_conn_pool.rs index 363e397976..934a50c14f 100644 --- a/proxy/src/serverless/http_conn_pool.rs +++ b/proxy/src/serverless/http_conn_pool.rs @@ -14,8 +14,8 @@ use super::conn_pool_lib::{ClientInnerExt, ConnInfo}; use crate::context::RequestMonitoring; use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo}; use crate::metrics::{HttpEndpointPoolsGuard, Metrics}; +use crate::types::EndpointCacheKey; use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; -use crate::EndpointCacheKey; pub(crate) type Send = http2::SendRequest; pub(crate) type Connect = diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs index e1ad46c751..064e7db7b3 100644 --- a/proxy/src/serverless/local_conn_pool.rs +++ b/proxy/src/serverless/local_conn_pool.rs @@ -35,8 +35,8 @@ use super::conn_pool_lib::{ClientInnerExt, ConnInfo}; use crate::context::RequestMonitoring; use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo}; use crate::metrics::Metrics; +use crate::types::{DbName, RoleName}; use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; -use crate::{DbName, RoleName}; pub(crate) const EXT_NAME: &str = "pg_session_jwt"; pub(crate) const EXT_VERSION: &str = "0.1.2"; diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 6fbb044669..8e2d4c126a 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -38,8 +38,8 @@ use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::metrics::{HttpDirection, Metrics}; use crate::proxy::{run_until_cancelled, NeonOptions}; use crate::serverless::backend::HttpConnError; +use crate::types::{DbName, RoleName}; use crate::usage_metrics::{MetricCounter, MetricCounterRecorder}; -use crate::{DbName, RoleName}; #[derive(serde::Deserialize)] #[serde(rename_all = "camelCase")] diff --git a/proxy/src/signals.rs b/proxy/src/signals.rs new file mode 100644 index 0000000000..514a83d5eb --- /dev/null +++ b/proxy/src/signals.rs @@ -0,0 +1,39 @@ +use std::convert::Infallible; + +use anyhow::bail; +use tokio_util::sync::CancellationToken; +use tracing::warn; + +/// Handle unix signals appropriately. +pub async fn handle( + token: CancellationToken, + mut refresh_config: F, +) -> anyhow::Result +where + F: FnMut(), +{ + use tokio::signal::unix::{signal, SignalKind}; + + let mut hangup = signal(SignalKind::hangup())?; + let mut interrupt = signal(SignalKind::interrupt())?; + let mut terminate = signal(SignalKind::terminate())?; + + loop { + tokio::select! { + // Hangup is commonly used for config reload. + _ = hangup.recv() => { + warn!("received SIGHUP"); + refresh_config(); + } + // Shut down the whole application. + _ = interrupt.recv() => { + warn!("received SIGINT, exiting immediately"); + bail!("interrupted"); + } + _ = terminate.recv() => { + warn!("received SIGTERM, shutting down once all existing connections have closed"); + token.cancel(); + } + } + } +} diff --git a/proxy/src/types.rs b/proxy/src/types.rs new file mode 100644 index 0000000000..b0408a51d1 --- /dev/null +++ b/proxy/src/types.rs @@ -0,0 +1,122 @@ +use crate::intern::{EndpointIdInt, EndpointIdTag, InternId}; + +macro_rules! smol_str_wrapper { + ($name:ident) => { + #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] + pub struct $name(smol_str::SmolStr); + + impl $name { + #[allow(unused)] + pub(crate) fn as_str(&self) -> &str { + self.0.as_str() + } + } + + impl std::fmt::Display for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } + } + + impl std::cmp::PartialEq for $name + where + smol_str::SmolStr: std::cmp::PartialEq, + { + fn eq(&self, other: &T) -> bool { + self.0.eq(other) + } + } + + impl From for $name + where + smol_str::SmolStr: From, + { + fn from(x: T) -> Self { + Self(x.into()) + } + } + + impl AsRef for $name { + fn as_ref(&self) -> &str { + self.0.as_ref() + } + } + + impl std::ops::Deref for $name { + type Target = str; + fn deref(&self) -> &str { + &*self.0 + } + } + + impl<'de> serde::de::Deserialize<'de> for $name { + fn deserialize>(d: D) -> Result { + >::deserialize(d).map(Self) + } + } + + impl serde::Serialize for $name { + fn serialize(&self, s: S) -> Result { + self.0.serialize(s) + } + } + }; +} + +const POOLER_SUFFIX: &str = "-pooler"; + +impl EndpointId { + #[must_use] + pub fn normalize(&self) -> Self { + if let Some(stripped) = self.as_ref().strip_suffix(POOLER_SUFFIX) { + stripped.into() + } else { + self.clone() + } + } + + #[must_use] + pub fn normalize_intern(&self) -> EndpointIdInt { + if let Some(stripped) = self.as_ref().strip_suffix(POOLER_SUFFIX) { + EndpointIdTag::get_interner().get_or_intern(stripped) + } else { + self.into() + } + } +} + +// 90% of role name strings are 20 characters or less. +smol_str_wrapper!(RoleName); +// 50% of endpoint strings are 23 characters or less. +smol_str_wrapper!(EndpointId); +// 50% of branch strings are 23 characters or less. +smol_str_wrapper!(BranchId); +// 90% of project strings are 23 characters or less. +smol_str_wrapper!(ProjectId); + +// will usually equal endpoint ID +smol_str_wrapper!(EndpointCacheKey); + +smol_str_wrapper!(DbName); + +// postgres hostname, will likely be a port:ip addr +smol_str_wrapper!(Host); + +// Endpoints are a bit tricky. Rare they might be branches or projects. +impl EndpointId { + pub(crate) fn is_endpoint(&self) -> bool { + self.0.starts_with("ep-") + } + pub(crate) fn is_branch(&self) -> bool { + self.0.starts_with("br-") + } + // pub(crate) fn is_project(&self) -> bool { + // !self.is_endpoint() && !self.is_branch() + // } + pub(crate) fn as_branch(&self) -> BranchId { + BranchId(self.0.clone()) + } + pub(crate) fn as_project(&self) -> ProjectId { + ProjectId(self.0.clone()) + } +} diff --git a/proxy/src/usage_metrics.rs b/proxy/src/usage_metrics.rs index f944d5aec3..c5e8588623 100644 --- a/proxy/src/usage_metrics.rs +++ b/proxy/src/usage_metrics.rs @@ -497,7 +497,8 @@ mod tests { use url::Url; use super::*; - use crate::{http, BranchId, EndpointId}; + use crate::http; + use crate::types::{BranchId, EndpointId}; #[tokio::test] async fn metrics() {