mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-08 14:02:55 +00:00
proxy: clear lib.rs of code items (#9479)
We keep lib.rs for crate configs, lint configs and re-exports for the binaries.
This commit is contained in:
@@ -16,7 +16,7 @@ use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::errors::GetEndpointJwksError;
|
||||
use crate::http::parse_json_body_with_limit;
|
||||
use crate::intern::RoleNameInt;
|
||||
use crate::{EndpointId, RoleName};
|
||||
use crate::types::{EndpointId, RoleName};
|
||||
|
||||
// TODO(conrad): make these configurable.
|
||||
const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
|
||||
@@ -669,7 +669,7 @@ mod tests {
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use super::*;
|
||||
use crate::RoleName;
|
||||
use crate::types::RoleName;
|
||||
|
||||
fn new_ec_jwk(kid: String) -> (p256::SecretKey, jose_jwk::Jwk) {
|
||||
let sk = p256::SecretKey::random(&mut OsRng);
|
||||
|
||||
@@ -10,9 +10,10 @@ use crate::compute_ctl::ComputeCtlApi;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, MetricsAuxInfo};
|
||||
use crate::control_plane::NodeInfo;
|
||||
use crate::http;
|
||||
use crate::intern::{BranchIdTag, EndpointIdTag, InternId, ProjectIdTag};
|
||||
use crate::types::EndpointId;
|
||||
use crate::url::ApiUrl;
|
||||
use crate::{http, EndpointId};
|
||||
|
||||
pub struct LocalBackend {
|
||||
pub(crate) initialize: Semaphore,
|
||||
|
||||
@@ -32,7 +32,8 @@ use crate::proxy::connect_compute::ComputeConnectBackend;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo};
|
||||
use crate::stream::Stream;
|
||||
use crate::{scram, stream, EndpointCacheKey, EndpointId, RoleName};
|
||||
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
|
||||
use crate::{scram, stream};
|
||||
|
||||
/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
|
||||
pub enum MaybeOwned<'a, T> {
|
||||
@@ -551,7 +552,7 @@ mod tests {
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
_endpoint: crate::EndpointId,
|
||||
_endpoint: crate::types::EndpointId,
|
||||
) -> Result<Vec<super::jwt::AuthRule>, control_plane::errors::GetEndpointJwksError>
|
||||
{
|
||||
unimplemented!()
|
||||
|
||||
@@ -15,7 +15,7 @@ use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::metrics::{Metrics, SniKind};
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::serverless::SERVERLESS_DRIVER_SNI;
|
||||
use crate::{EndpointId, RoleName};
|
||||
use crate::types::{EndpointId, RoleName};
|
||||
|
||||
#[derive(Debug, Error, PartialEq, Eq, Clone)]
|
||||
pub(crate) enum ComputeUserInfoParseError {
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
use bstr::ByteSlice;
|
||||
|
||||
use crate::EndpointId;
|
||||
use crate::types::EndpointId;
|
||||
|
||||
pub(crate) struct PasswordHackPayload {
|
||||
pub(crate) endpoint: EndpointId,
|
||||
|
||||
@@ -25,8 +25,8 @@ use proxy::rate_limiter::{
|
||||
use proxy::scram::threadpool::ThreadPool;
|
||||
use proxy::serverless::cancel_set::CancelSet;
|
||||
use proxy::serverless::{self, GlobalConnPoolOptions};
|
||||
use proxy::types::RoleName;
|
||||
use proxy::url::ApiUrl;
|
||||
use proxy::RoleName;
|
||||
|
||||
project_git_version!(GIT_VERSION);
|
||||
project_build_tag!(BUILD_TAG);
|
||||
@@ -177,7 +177,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
let mut maintenance_tasks = JoinSet::new();
|
||||
|
||||
let refresh_config_notify = Arc::new(Notify::new());
|
||||
maintenance_tasks.spawn(proxy::handle_signals(shutdown.clone(), {
|
||||
maintenance_tasks.spawn(proxy::signals::handle(shutdown.clone(), {
|
||||
let refresh_config_notify = Arc::clone(&refresh_config_notify);
|
||||
move || {
|
||||
refresh_config_notify.notify_one();
|
||||
@@ -216,7 +216,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
match futures::future::select(pin!(maintenance_tasks.join_next()), pin!(task)).await {
|
||||
// exit immediately on maintenance task completion
|
||||
Either::Left((Some(res), _)) => match proxy::flatten_err(res)? {},
|
||||
Either::Left((Some(res), _)) => match proxy::error::flatten_err(res)? {},
|
||||
// exit with error immediately if all maintenance tasks have ceased (should be caught by branch above)
|
||||
Either::Left((None, _)) => bail!("no maintenance tasks running. invalid state"),
|
||||
// exit immediately on client task error
|
||||
|
||||
@@ -133,14 +133,14 @@ async fn main() -> anyhow::Result<()> {
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
));
|
||||
let signals_task = tokio::spawn(proxy::handle_signals(cancellation_token, || {}));
|
||||
let signals_task = tokio::spawn(proxy::signals::handle(cancellation_token, || {}));
|
||||
|
||||
// the signal task cant ever succeed.
|
||||
// the main task can error, or can succeed on cancellation.
|
||||
// we want to immediately exit on either of these cases
|
||||
let signal = match futures::future::select(signals_task, main).await {
|
||||
Either::Left((res, _)) => proxy::flatten_err(res)?,
|
||||
Either::Right((res, _)) => return proxy::flatten_err(res),
|
||||
Either::Left((res, _)) => proxy::error::flatten_err(res)?,
|
||||
Either::Right((res, _)) => return proxy::error::flatten_err(res),
|
||||
};
|
||||
|
||||
// maintenance tasks return `Infallible` success values, this is an impossible value
|
||||
|
||||
@@ -495,7 +495,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
// maintenance tasks. these never return unless there's an error
|
||||
let mut maintenance_tasks = JoinSet::new();
|
||||
maintenance_tasks.spawn(proxy::handle_signals(cancellation_token.clone(), || {}));
|
||||
maintenance_tasks.spawn(proxy::signals::handle(cancellation_token.clone(), || {}));
|
||||
maintenance_tasks.spawn(http::health_server::task_main(
|
||||
http_listener,
|
||||
AppMetrics {
|
||||
@@ -561,11 +561,11 @@ async fn main() -> anyhow::Result<()> {
|
||||
.await
|
||||
{
|
||||
// exit immediately on maintenance task completion
|
||||
Either::Left((Some(res), _)) => break proxy::flatten_err(res)?,
|
||||
Either::Left((Some(res), _)) => break proxy::error::flatten_err(res)?,
|
||||
// exit with error immediately if all maintenance tasks have ceased (should be caught by branch above)
|
||||
Either::Left((None, _)) => bail!("no maintenance tasks running. invalid state"),
|
||||
// exit immediately on client task error
|
||||
Either::Right((Some(res), _)) => proxy::flatten_err(res)?,
|
||||
Either::Right((Some(res), _)) => proxy::error::flatten_err(res)?,
|
||||
// exit if all our client tasks have shutdown gracefully
|
||||
Either::Right((None, _)) => return Ok(()),
|
||||
}
|
||||
|
||||
2
proxy/src/cache/endpoints.rs
vendored
2
proxy/src/cache/endpoints.rs
vendored
@@ -17,7 +17,7 @@ use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
|
||||
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
|
||||
use crate::rate_limiter::GlobalRateLimiter;
|
||||
use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
|
||||
use crate::EndpointId;
|
||||
use crate::types::EndpointId;
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub(crate) struct ControlPlaneEventKey {
|
||||
|
||||
4
proxy/src/cache/project_info.rs
vendored
4
proxy/src/cache/project_info.rs
vendored
@@ -17,7 +17,7 @@ use crate::auth::IpPattern;
|
||||
use crate::config::ProjectInfoCacheOptions;
|
||||
use crate::control_plane::AuthSecret;
|
||||
use crate::intern::{EndpointIdInt, ProjectIdInt, RoleNameInt};
|
||||
use crate::{EndpointId, RoleName};
|
||||
use crate::types::{EndpointId, RoleName};
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait ProjectInfoCache {
|
||||
@@ -368,7 +368,7 @@ impl Cache for ProjectInfoCacheImpl {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::scram::ServerSecret;
|
||||
use crate::ProjectId;
|
||||
use crate::types::ProjectId;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_project_info_cache_settings() {
|
||||
|
||||
@@ -25,7 +25,7 @@ use crate::control_plane::provider::ApiLockError;
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::metrics::{Metrics, NumDbConnectionsGuard};
|
||||
use crate::proxy::neon_option;
|
||||
use crate::Host;
|
||||
use crate::types::Host;
|
||||
|
||||
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
|
||||
|
||||
|
||||
@@ -4,8 +4,9 @@ use serde::de::DeserializeOwned;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::http;
|
||||
use crate::types::{DbName, RoleName};
|
||||
use crate::url::ApiUrl;
|
||||
use crate::{http, DbName, RoleName};
|
||||
|
||||
pub struct ComputeCtlApi {
|
||||
pub(crate) api: http::Endpoint,
|
||||
|
||||
@@ -20,7 +20,7 @@ use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig}
|
||||
use crate::scram::threadpool::ThreadPool;
|
||||
use crate::serverless::cancel_set::CancelSet;
|
||||
use crate::serverless::GlobalConnPoolOptions;
|
||||
use crate::Host;
|
||||
use crate::types::Host;
|
||||
|
||||
pub struct ProxyConfig {
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
|
||||
@@ -19,7 +19,7 @@ use crate::intern::{BranchIdInt, ProjectIdInt};
|
||||
use crate::metrics::{
|
||||
ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol, Waiting,
|
||||
};
|
||||
use crate::{DbName, EndpointId, RoleName};
|
||||
use crate::types::{DbName, EndpointId, RoleName};
|
||||
|
||||
pub mod parquet;
|
||||
|
||||
|
||||
@@ -21,8 +21,9 @@ use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::control_plane::provider::{CachedAllowedIps, CachedRoleSecret};
|
||||
use crate::error::io_error;
|
||||
use crate::intern::RoleNameInt;
|
||||
use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
|
||||
use crate::url::ApiUrl;
|
||||
use crate::{compute, scram, BranchId, EndpointId, ProjectId, RoleName};
|
||||
use crate::{compute, scram};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum MockApiError {
|
||||
|
||||
@@ -23,7 +23,8 @@ use crate::error::ReportableError;
|
||||
use crate::intern::ProjectIdInt;
|
||||
use crate::metrics::ApiLockMetrics;
|
||||
use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
|
||||
use crate::{compute, scram, EndpointCacheKey, EndpointId};
|
||||
use crate::types::{EndpointCacheKey, EndpointId};
|
||||
use crate::{compute, scram};
|
||||
|
||||
pub(crate) mod errors {
|
||||
use thiserror::Error;
|
||||
|
||||
@@ -24,7 +24,8 @@ use crate::control_plane::errors::GetEndpointJwksError;
|
||||
use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason};
|
||||
use crate::metrics::{CacheOutcome, Metrics};
|
||||
use crate::rate_limiter::WakeComputeRateLimiter;
|
||||
use crate::{compute, http, scram, EndpointCacheKey, EndpointId};
|
||||
use crate::types::{EndpointCacheKey, EndpointId};
|
||||
use crate::{compute, http, scram};
|
||||
|
||||
const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use std::error::Error as StdError;
|
||||
use std::{fmt, io};
|
||||
|
||||
use anyhow::Context;
|
||||
use measured::FixedCardinalityLabel;
|
||||
use tokio::task::JoinError;
|
||||
|
||||
/// Upcast (almost) any error into an opaque [`io::Error`].
|
||||
pub(crate) fn io_error(e: impl Into<Box<dyn StdError + Send + Sync>>) -> io::Error {
|
||||
@@ -97,3 +99,8 @@ impl ReportableError for tokio_postgres::error::Error {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Flattens `Result<Result<T>>` into `Result<T>`.
|
||||
pub fn flatten_err<T>(r: Result<anyhow::Result<T>, JoinError>) -> anyhow::Result<T> {
|
||||
r.context("join error").and_then(|x| x)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ use std::sync::OnceLock;
|
||||
use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
|
||||
use rustc_hash::FxHasher;
|
||||
|
||||
use crate::{BranchId, EndpointId, ProjectId, RoleName};
|
||||
use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
|
||||
|
||||
pub trait InternId: Sized + 'static {
|
||||
fn get_interner() -> &'static StringInterner<Self>;
|
||||
|
||||
168
proxy/src/lib.rs
168
proxy/src/lib.rs
@@ -78,14 +78,6 @@
|
||||
// List of temporarily allowed lints to unblock beta/nightly.
|
||||
#![allow(unknown_lints)]
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use intern::{EndpointIdInt, EndpointIdTag, InternId};
|
||||
use tokio::task::JoinError;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::warn;
|
||||
|
||||
pub mod auth;
|
||||
pub mod cache;
|
||||
pub mod cancellation;
|
||||
@@ -109,165 +101,9 @@ pub mod redis;
|
||||
pub mod sasl;
|
||||
pub mod scram;
|
||||
pub mod serverless;
|
||||
pub mod signals;
|
||||
pub mod stream;
|
||||
pub mod types;
|
||||
pub mod url;
|
||||
pub mod usage_metrics;
|
||||
pub mod waiters;
|
||||
|
||||
/// Handle unix signals appropriately.
|
||||
pub async fn handle_signals<F>(
|
||||
token: CancellationToken,
|
||||
mut refresh_config: F,
|
||||
) -> anyhow::Result<Infallible>
|
||||
where
|
||||
F: FnMut(),
|
||||
{
|
||||
use tokio::signal::unix::{signal, SignalKind};
|
||||
|
||||
let mut hangup = signal(SignalKind::hangup())?;
|
||||
let mut interrupt = signal(SignalKind::interrupt())?;
|
||||
let mut terminate = signal(SignalKind::terminate())?;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Hangup is commonly used for config reload.
|
||||
_ = hangup.recv() => {
|
||||
warn!("received SIGHUP");
|
||||
refresh_config();
|
||||
}
|
||||
// Shut down the whole application.
|
||||
_ = interrupt.recv() => {
|
||||
warn!("received SIGINT, exiting immediately");
|
||||
bail!("interrupted");
|
||||
}
|
||||
_ = terminate.recv() => {
|
||||
warn!("received SIGTERM, shutting down once all existing connections have closed");
|
||||
token.cancel();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Flattens `Result<Result<T>>` into `Result<T>`.
|
||||
pub fn flatten_err<T>(r: Result<anyhow::Result<T>, JoinError>) -> anyhow::Result<T> {
|
||||
r.context("join error").and_then(|x| x)
|
||||
}
|
||||
|
||||
macro_rules! smol_str_wrapper {
|
||||
($name:ident) => {
|
||||
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
|
||||
pub struct $name(smol_str::SmolStr);
|
||||
|
||||
impl $name {
|
||||
#[allow(unused)]
|
||||
pub(crate) fn as_str(&self) -> &str {
|
||||
self.0.as_str()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for $name {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::cmp::PartialEq<T> for $name
|
||||
where
|
||||
smol_str::SmolStr: std::cmp::PartialEq<T>,
|
||||
{
|
||||
fn eq(&self, other: &T) -> bool {
|
||||
self.0.eq(other)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<T> for $name
|
||||
where
|
||||
smol_str::SmolStr: From<T>,
|
||||
{
|
||||
fn from(x: T) -> Self {
|
||||
Self(x.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<str> for $name {
|
||||
fn as_ref(&self) -> &str {
|
||||
self.0.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for $name {
|
||||
type Target = str;
|
||||
fn deref(&self) -> &str {
|
||||
&*self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> serde::de::Deserialize<'de> for $name {
|
||||
fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
|
||||
<smol_str::SmolStr as serde::de::Deserialize<'de>>::deserialize(d).map(Self)
|
||||
}
|
||||
}
|
||||
|
||||
impl serde::Serialize for $name {
|
||||
fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
|
||||
self.0.serialize(s)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
const POOLER_SUFFIX: &str = "-pooler";
|
||||
|
||||
impl EndpointId {
|
||||
fn normalize(&self) -> Self {
|
||||
if let Some(stripped) = self.as_ref().strip_suffix(POOLER_SUFFIX) {
|
||||
stripped.into()
|
||||
} else {
|
||||
self.clone()
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_intern(&self) -> EndpointIdInt {
|
||||
if let Some(stripped) = self.as_ref().strip_suffix(POOLER_SUFFIX) {
|
||||
EndpointIdTag::get_interner().get_or_intern(stripped)
|
||||
} else {
|
||||
self.into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 90% of role name strings are 20 characters or less.
|
||||
smol_str_wrapper!(RoleName);
|
||||
// 50% of endpoint strings are 23 characters or less.
|
||||
smol_str_wrapper!(EndpointId);
|
||||
// 50% of branch strings are 23 characters or less.
|
||||
smol_str_wrapper!(BranchId);
|
||||
// 90% of project strings are 23 characters or less.
|
||||
smol_str_wrapper!(ProjectId);
|
||||
|
||||
// will usually equal endpoint ID
|
||||
smol_str_wrapper!(EndpointCacheKey);
|
||||
|
||||
smol_str_wrapper!(DbName);
|
||||
|
||||
// postgres hostname, will likely be a port:ip addr
|
||||
smol_str_wrapper!(Host);
|
||||
|
||||
// Endpoints are a bit tricky. Rare they might be branches or projects.
|
||||
impl EndpointId {
|
||||
pub(crate) fn is_endpoint(&self) -> bool {
|
||||
self.0.starts_with("ep-")
|
||||
}
|
||||
pub(crate) fn is_branch(&self) -> bool {
|
||||
self.0.starts_with("br-")
|
||||
}
|
||||
// pub(crate) fn is_project(&self) -> bool {
|
||||
// !self.is_endpoint() && !self.is_branch()
|
||||
// }
|
||||
pub(crate) fn as_branch(&self) -> BranchId {
|
||||
BranchId(self.0.clone())
|
||||
}
|
||||
pub(crate) fn as_project(&self) -> ProjectId {
|
||||
ProjectId(self.0.clone())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ use crate::metrics::{
|
||||
};
|
||||
use crate::proxy::retry::{retry_after, should_retry, CouldRetry};
|
||||
use crate::proxy::wake_compute::wake_compute;
|
||||
use crate::Host;
|
||||
use crate::types::Host;
|
||||
|
||||
const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
|
||||
|
||||
|
||||
@@ -32,7 +32,8 @@ use crate::protocol2::read_proxy_protocol;
|
||||
use crate::proxy::handshake::{handshake, HandshakeData};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::{auth, compute, EndpointCacheKey};
|
||||
use crate::types::EndpointCacheKey;
|
||||
use crate::{auth, compute};
|
||||
|
||||
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||
|
||||
|
||||
@@ -28,7 +28,8 @@ use crate::control_plane::provider::{
|
||||
};
|
||||
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
|
||||
use crate::error::ErrorKind;
|
||||
use crate::{sasl, scram, BranchId, EndpointId, ProjectId};
|
||||
use crate::types::{BranchId, EndpointId, ProjectId};
|
||||
use crate::{sasl, scram};
|
||||
|
||||
/// Generate a set of TLS certificates: CA + server.
|
||||
fn generate_certs(
|
||||
|
||||
@@ -250,7 +250,7 @@ mod tests {
|
||||
use super::{BucketRateLimiter, WakeComputeRateLimiter};
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::rate_limiter::RateBucketInfo;
|
||||
use crate::EndpointId;
|
||||
use crate::types::EndpointId;
|
||||
|
||||
#[test]
|
||||
fn rate_bucket_rpi() {
|
||||
|
||||
@@ -271,7 +271,7 @@ mod tests {
|
||||
use serde_json::json;
|
||||
|
||||
use super::*;
|
||||
use crate::{ProjectId, RoleName};
|
||||
use crate::types::{ProjectId, RoleName};
|
||||
|
||||
#[test]
|
||||
fn parse_allowed_ips() -> anyhow::Result<()> {
|
||||
|
||||
@@ -62,7 +62,7 @@ mod tests {
|
||||
use super::{Exchange, ServerSecret};
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::sasl::{Mechanism, Step};
|
||||
use crate::EndpointId;
|
||||
use crate::types::EndpointId;
|
||||
|
||||
#[test]
|
||||
fn snapshot() {
|
||||
|
||||
@@ -189,7 +189,7 @@ impl Drop for JobHandle {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::EndpointId;
|
||||
use crate::types::EndpointId;
|
||||
|
||||
#[tokio::test]
|
||||
async fn hash_is_correct() {
|
||||
|
||||
@@ -18,6 +18,7 @@ use super::local_conn_pool::{self, LocalClient, LocalConnPool, EXT_NAME, EXT_SCH
|
||||
use crate::auth::backend::local::StaticAuthRules;
|
||||
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
|
||||
use crate::auth::{self, check_peer_addr_is_in_list, AuthError};
|
||||
use crate::compute;
|
||||
use crate::compute_ctl::{
|
||||
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
|
||||
};
|
||||
@@ -32,7 +33,7 @@ use crate::intern::EndpointIdInt;
|
||||
use crate::proxy::connect_compute::ConnectMechanism;
|
||||
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::{compute, EndpointId, Host};
|
||||
use crate::types::{EndpointId, Host};
|
||||
|
||||
pub(crate) struct PoolingBackend {
|
||||
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool<Send>>,
|
||||
|
||||
@@ -211,7 +211,7 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::serverless::cancel_set::CancelSet;
|
||||
use crate::{BranchId, EndpointId, ProjectId};
|
||||
use crate::types::{BranchId, EndpointId, ProjectId};
|
||||
|
||||
struct MockClient(Arc<AtomicBool>);
|
||||
impl MockClient {
|
||||
|
||||
@@ -16,8 +16,8 @@ use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::messages::ColdStartInfo;
|
||||
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
|
||||
use crate::types::{DbName, EndpointCacheKey, RoleName};
|
||||
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
|
||||
use crate::{DbName, EndpointCacheKey, RoleName};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct ConnInfo {
|
||||
|
||||
@@ -14,8 +14,8 @@ use super::conn_pool_lib::{ClientInnerExt, ConnInfo};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
|
||||
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
|
||||
use crate::types::EndpointCacheKey;
|
||||
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
|
||||
use crate::EndpointCacheKey;
|
||||
|
||||
pub(crate) type Send = http2::SendRequest<hyper::body::Incoming>;
|
||||
pub(crate) type Connect =
|
||||
|
||||
@@ -35,8 +35,8 @@ use super::conn_pool_lib::{ClientInnerExt, ConnInfo};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
|
||||
use crate::metrics::Metrics;
|
||||
use crate::types::{DbName, RoleName};
|
||||
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
|
||||
use crate::{DbName, RoleName};
|
||||
|
||||
pub(crate) const EXT_NAME: &str = "pg_session_jwt";
|
||||
pub(crate) const EXT_VERSION: &str = "0.1.2";
|
||||
|
||||
@@ -38,8 +38,8 @@ use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::metrics::{HttpDirection, Metrics};
|
||||
use crate::proxy::{run_until_cancelled, NeonOptions};
|
||||
use crate::serverless::backend::HttpConnError;
|
||||
use crate::types::{DbName, RoleName};
|
||||
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
|
||||
use crate::{DbName, RoleName};
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
|
||||
39
proxy/src/signals.rs
Normal file
39
proxy/src/signals.rs
Normal file
@@ -0,0 +1,39 @@
|
||||
use std::convert::Infallible;
|
||||
|
||||
use anyhow::bail;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::warn;
|
||||
|
||||
/// Handle unix signals appropriately.
|
||||
pub async fn handle<F>(
|
||||
token: CancellationToken,
|
||||
mut refresh_config: F,
|
||||
) -> anyhow::Result<Infallible>
|
||||
where
|
||||
F: FnMut(),
|
||||
{
|
||||
use tokio::signal::unix::{signal, SignalKind};
|
||||
|
||||
let mut hangup = signal(SignalKind::hangup())?;
|
||||
let mut interrupt = signal(SignalKind::interrupt())?;
|
||||
let mut terminate = signal(SignalKind::terminate())?;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Hangup is commonly used for config reload.
|
||||
_ = hangup.recv() => {
|
||||
warn!("received SIGHUP");
|
||||
refresh_config();
|
||||
}
|
||||
// Shut down the whole application.
|
||||
_ = interrupt.recv() => {
|
||||
warn!("received SIGINT, exiting immediately");
|
||||
bail!("interrupted");
|
||||
}
|
||||
_ = terminate.recv() => {
|
||||
warn!("received SIGTERM, shutting down once all existing connections have closed");
|
||||
token.cancel();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
122
proxy/src/types.rs
Normal file
122
proxy/src/types.rs
Normal file
@@ -0,0 +1,122 @@
|
||||
use crate::intern::{EndpointIdInt, EndpointIdTag, InternId};
|
||||
|
||||
macro_rules! smol_str_wrapper {
|
||||
($name:ident) => {
|
||||
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
|
||||
pub struct $name(smol_str::SmolStr);
|
||||
|
||||
impl $name {
|
||||
#[allow(unused)]
|
||||
pub(crate) fn as_str(&self) -> &str {
|
||||
self.0.as_str()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for $name {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::cmp::PartialEq<T> for $name
|
||||
where
|
||||
smol_str::SmolStr: std::cmp::PartialEq<T>,
|
||||
{
|
||||
fn eq(&self, other: &T) -> bool {
|
||||
self.0.eq(other)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<T> for $name
|
||||
where
|
||||
smol_str::SmolStr: From<T>,
|
||||
{
|
||||
fn from(x: T) -> Self {
|
||||
Self(x.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<str> for $name {
|
||||
fn as_ref(&self) -> &str {
|
||||
self.0.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for $name {
|
||||
type Target = str;
|
||||
fn deref(&self) -> &str {
|
||||
&*self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> serde::de::Deserialize<'de> for $name {
|
||||
fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
|
||||
<smol_str::SmolStr as serde::de::Deserialize<'de>>::deserialize(d).map(Self)
|
||||
}
|
||||
}
|
||||
|
||||
impl serde::Serialize for $name {
|
||||
fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
|
||||
self.0.serialize(s)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
const POOLER_SUFFIX: &str = "-pooler";
|
||||
|
||||
impl EndpointId {
|
||||
#[must_use]
|
||||
pub fn normalize(&self) -> Self {
|
||||
if let Some(stripped) = self.as_ref().strip_suffix(POOLER_SUFFIX) {
|
||||
stripped.into()
|
||||
} else {
|
||||
self.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn normalize_intern(&self) -> EndpointIdInt {
|
||||
if let Some(stripped) = self.as_ref().strip_suffix(POOLER_SUFFIX) {
|
||||
EndpointIdTag::get_interner().get_or_intern(stripped)
|
||||
} else {
|
||||
self.into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 90% of role name strings are 20 characters or less.
|
||||
smol_str_wrapper!(RoleName);
|
||||
// 50% of endpoint strings are 23 characters or less.
|
||||
smol_str_wrapper!(EndpointId);
|
||||
// 50% of branch strings are 23 characters or less.
|
||||
smol_str_wrapper!(BranchId);
|
||||
// 90% of project strings are 23 characters or less.
|
||||
smol_str_wrapper!(ProjectId);
|
||||
|
||||
// will usually equal endpoint ID
|
||||
smol_str_wrapper!(EndpointCacheKey);
|
||||
|
||||
smol_str_wrapper!(DbName);
|
||||
|
||||
// postgres hostname, will likely be a port:ip addr
|
||||
smol_str_wrapper!(Host);
|
||||
|
||||
// Endpoints are a bit tricky. Rare they might be branches or projects.
|
||||
impl EndpointId {
|
||||
pub(crate) fn is_endpoint(&self) -> bool {
|
||||
self.0.starts_with("ep-")
|
||||
}
|
||||
pub(crate) fn is_branch(&self) -> bool {
|
||||
self.0.starts_with("br-")
|
||||
}
|
||||
// pub(crate) fn is_project(&self) -> bool {
|
||||
// !self.is_endpoint() && !self.is_branch()
|
||||
// }
|
||||
pub(crate) fn as_branch(&self) -> BranchId {
|
||||
BranchId(self.0.clone())
|
||||
}
|
||||
pub(crate) fn as_project(&self) -> ProjectId {
|
||||
ProjectId(self.0.clone())
|
||||
}
|
||||
}
|
||||
@@ -497,7 +497,8 @@ mod tests {
|
||||
use url::Url;
|
||||
|
||||
use super::*;
|
||||
use crate::{http, BranchId, EndpointId};
|
||||
use crate::http;
|
||||
use crate::types::{BranchId, EndpointId};
|
||||
|
||||
#[tokio::test]
|
||||
async fn metrics() {
|
||||
|
||||
Reference in New Issue
Block a user