proxy: experiment with idea to split crates

This commit is contained in:
Conrad Ludgate
2024-08-13 14:16:50 +01:00
parent a2968c6cf8
commit b62e7c0138
36 changed files with 549 additions and 264 deletions

31
Cargo.lock generated
View File

@@ -4472,7 +4472,7 @@ dependencies = [
"postgres-protocol",
"postgres_backend",
"pq_proto",
"prometheus",
"proxy-sasl",
"rand 0.8.5",
"rand_distr",
"rcgen",
@@ -4525,6 +4525,35 @@ dependencies = [
"x509-parser",
]
[[package]]
name = "proxy-sasl"
version = "0.1.0"
dependencies = [
"ahash",
"anyhow",
"base64 0.13.1",
"bytes",
"crossbeam-deque",
"hmac",
"itertools 0.10.5",
"lasso",
"measured",
"parking_lot 0.12.1",
"pbkdf2",
"postgres-protocol",
"pq_proto",
"rand 0.8.5",
"rustls 0.22.4",
"sha2",
"subtle",
"thiserror",
"tokio",
"tracing",
"uuid",
"workspace_hack",
"x509-parser",
]
[[package]]
name = "quick-xml"
version = "0.31.0"

View File

@@ -10,6 +10,7 @@ members = [
"pageserver/client",
"pageserver/pagebench",
"proxy/core",
"proxy/sasl",
"safekeeper",
"storage_broker",
"storage_controller",

View File

@@ -9,6 +9,8 @@ default = []
testing = []
[dependencies]
proxy-sasl = { version = "0.1", path = "../sasl" }
ahash.workspace = true
anyhow.workspace = true
arc-swap.workspace = true
@@ -59,7 +61,7 @@ parquet_derive.workspace = true
pin-project-lite.workspace = true
postgres_backend.workspace = true
pq_proto.workspace = true
prometheus.workspace = true
# prometheus.workspace = true
rand.workspace = true
regex.workspace = true
remote_storage = { version = "0.1", path = "../../libs/remote_storage/" }

View File

@@ -38,7 +38,7 @@ pub enum AuthErrorImpl {
/// SASL protocol errors (includes [SCRAM](crate::scram)).
#[error(transparent)]
Sasl(#[from] crate::sasl::Error),
Sasl(#[from] proxy_sasl::sasl::Error),
#[error("Unsupported authentication method: {0}")]
BadAuthMethod(Box<str>),
@@ -148,3 +148,28 @@ impl ReportableError for AuthError {
}
}
}
impl UserFacingError for proxy_sasl::sasl::Error {
fn to_string_client(&self) -> String {
match self {
proxy_sasl::sasl::Error::ChannelBindingFailed(m) => m.to_string(),
proxy_sasl::sasl::Error::ChannelBindingBadMethod(m) => {
format!("unsupported channel binding method {m}")
}
_ => "authentication protocol violation".to_string(),
}
}
}
impl ReportableError for proxy_sasl::sasl::Error {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
proxy_sasl::sasl::Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User,
proxy_sasl::sasl::Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User,
proxy_sasl::sasl::Error::BadClientMessage(_) => crate::error::ErrorKind::User,
proxy_sasl::sasl::Error::MissingBinding => crate::error::ErrorKind::Service,
proxy_sasl::sasl::Error::Base64(_) => crate::error::ErrorKind::ControlPlane,
proxy_sasl::sasl::Error::Io(_) => crate::error::ErrorKind::ClientDisconnect,
}
}
}

View File

@@ -9,6 +9,7 @@ use std::time::Duration;
use ipnet::{Ipv4Net, Ipv6Net};
pub use link::LinkAuthError;
use proxy_sasl::scram;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::config::AuthKeys;
use tracing::{info, warn};
@@ -36,7 +37,7 @@ use crate::{
},
stream, url,
};
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
use crate::{EndpointCacheKey, EndpointId, RoleName};
/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
pub enum MaybeOwned<'a, T> {
@@ -371,8 +372,8 @@ async fn authenticate_with_secret(
let auth_outcome =
validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
let keys = match auth_outcome {
crate::sasl::Outcome::Success(key) => key,
crate::sasl::Outcome::Failure(reason) => {
proxy_sasl::sasl::Outcome::Success(key) => key,
proxy_sasl::sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
return Err(auth::AuthError::auth_failed(&*info.user));
}
@@ -558,9 +559,9 @@ mod tests {
context::RequestMonitoring,
proxy::NeonOptions,
rate_limiter::{EndpointRateLimiter, RateBucketInfo},
scram::{threadpool::ThreadPool, ServerSecret},
stream::{PqStream, Stream},
};
use proxy_sasl::scram::{threadpool::ThreadPool, ServerSecret};
use super::{auth_quirks, AuthRateLimiter};
@@ -669,7 +670,11 @@ mod tests {
let ctx = RequestMonitoring::test();
let api = Auth {
ips: vec![],
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
secret: AuthSecret::Scram(
ServerSecret::build_test_secret("my-secret-password")
.await
.unwrap(),
),
};
let user_info = ComputeUserInfoMaybeEndpoint {
@@ -746,7 +751,11 @@ mod tests {
let ctx = RequestMonitoring::test();
let api = Auth {
ips: vec![],
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
secret: AuthSecret::Scram(
ServerSecret::build_test_secret("my-secret-password")
.await
.unwrap(),
),
};
let user_info = ComputeUserInfoMaybeEndpoint {
@@ -798,7 +807,11 @@ mod tests {
let ctx = RequestMonitoring::test();
let api = Auth {
ips: vec![],
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
secret: AuthSecret::Scram(
ServerSecret::build_test_secret("my-secret-password")
.await
.unwrap(),
),
};
let user_info = ComputeUserInfoMaybeEndpoint {

View File

@@ -5,9 +5,9 @@ use crate::{
config::AuthenticationConfig,
console::AuthSecret,
context::RequestMonitoring,
sasl,
stream::{PqStream, Stream},
};
use proxy_sasl::sasl;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};

View File

@@ -7,9 +7,9 @@ use crate::{
console::AuthSecret,
context::RequestMonitoring,
intern::EndpointIdInt,
sasl,
stream::{self, Stream},
};
use proxy_sasl::sasl;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};

View File

@@ -2,16 +2,17 @@
use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
use crate::{
config::TlsServerEndPoint,
console::AuthSecret,
context::RequestMonitoring,
intern::EndpointIdInt,
sasl,
scram::{self, threadpool::ThreadPool},
stream::{PqStream, Stream},
};
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use proxy_sasl::{
sasl,
scram::{self, threadpool::ThreadPool, TlsServerEndPoint},
};
use std::{io, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
@@ -56,7 +57,7 @@ impl AuthMethod for PasswordHack {
/// Use clear-text password auth called `password` in docs
/// <https://www.postgresql.org/docs/current/auth-password.html>
pub struct CleartextPassword {
pub pool: Arc<ThreadPool>,
pub pool: Arc<ThreadPool<EndpointIdInt>>,
pub endpoint: EndpointIdInt,
pub secret: AuthSecret,
}
@@ -174,7 +175,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
}
info!("client chooses {}", sasl.method);
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
let outcome = sasl::SaslStream::new(&mut self.stream.framed, sasl.message)
.authenticate(scram::Exchange::new(
secret,
rand::random,
@@ -191,7 +192,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
}
pub(crate) async fn validate_password_and_exchange(
pool: &ThreadPool,
pool: &ThreadPool<EndpointIdInt>,
endpoint: EndpointIdInt,
password: &[u8],
secret: AuthSecret,
@@ -206,7 +207,8 @@ pub(crate) async fn validate_password_and_exchange(
}
// perform scram authentication as both client and server to validate the keys
AuthSecret::Scram(scram_secret) => {
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
let outcome =
proxy_sasl::scram::exchange(pool, endpoint, &scram_secret, password).await?;
let client_key = match outcome {
sasl::Outcome::Success(client_key) => client_key,

View File

@@ -7,10 +7,11 @@ use std::{net::SocketAddr, sync::Arc};
use futures::future::Either;
use itertools::Itertools;
use proxy::config::TlsServerEndPoint;
use proxy::context::RequestMonitoring;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::metrics::Metrics;
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
use proxy_sasl::scram::threadpool::ThreadPoolMetrics;
use proxy_sasl::scram::TlsServerEndPoint;
use rustls::pki_types::PrivateKeyDer;
use tokio::net::TcpListener;

View File

@@ -30,7 +30,6 @@ use proxy::redis::cancellation_publisher::RedisPublisherClient;
use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use proxy::redis::elasticache;
use proxy::redis::notifications;
use proxy::scram::threadpool::ThreadPool;
use proxy::serverless::cancel_set::CancelSet;
use proxy::serverless::GlobalConnPoolOptions;
use proxy::usage_metrics;
@@ -38,6 +37,7 @@ use proxy::usage_metrics;
use anyhow::bail;
use proxy::config::{self, ProxyConfig};
use proxy::serverless;
use proxy_sasl::scram::threadpool::ThreadPool;
use remote_storage::RemoteStorageConfig;
use std::net::SocketAddr;
use std::pin::pin;
@@ -607,7 +607,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
timeout,
epoch,
&Metrics::get().wake_compute_lock,
)?));
)));
tokio::spawn(locks.garbage_collect_worker());
let url = args.auth_endpoint.parse()?;
@@ -658,7 +658,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
timeout,
epoch,
&Metrics::get().proxy.connect_compute_lock,
)?;
);
let http_config = HttpConfig {
pool_options: GlobalConnPoolOptions {

View File

@@ -371,7 +371,8 @@ impl Cache for ProjectInfoCacheImpl {
#[cfg(test)]
mod tests {
use super::*;
use crate::{scram::ServerSecret, ProjectId};
use crate::ProjectId;
use proxy_sasl::scram::ServerSecret;
#[tokio::test]
async fn test_project_info_cache_settings() {

View File

@@ -1,27 +1,26 @@
use crate::{
auth::{self, backend::AuthRateLimiter},
console::locks::ApiLocks,
intern::EndpointIdInt,
rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
scram::threadpool::ThreadPool,
serverless::{cancel_set::CancelSet, GlobalConnPoolOptions},
Host,
};
use anyhow::{bail, ensure, Context, Ok};
use itertools::Itertools;
use proxy_sasl::scram::{threadpool::ThreadPool, TlsServerEndPoint};
use remote_storage::RemoteStorageConfig;
use rustls::{
crypto::ring::sign,
pki_types::{CertificateDer, PrivateKeyDer},
};
use sha2::{Digest, Sha256};
use std::{
collections::{HashMap, HashSet},
str::FromStr,
sync::Arc,
time::Duration,
};
use tracing::{error, info};
use x509_parser::oid_registry;
pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
@@ -58,7 +57,7 @@ pub struct HttpConfig {
}
pub struct AuthenticationConfig {
pub thread_pool: Arc<ThreadPool>,
pub thread_pool: Arc<ThreadPool<EndpointIdInt>>,
pub scram_protocol_timeout: tokio::time::Duration,
pub rate_limiter_enabled: bool,
pub rate_limiter: AuthRateLimiter,
@@ -126,66 +125,6 @@ pub fn configure_tls(
})
}
/// Channel binding parameter
///
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
/// Description: The hash of the TLS server's certificate as it
/// appears, octet for octet, in the server's Certificate message. Note
/// that the Certificate message contains a certificate_list, in which
/// the first element is the server's certificate.
///
/// The hash function is to be selected as follows:
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function and that hash function neither MD5 nor SHA-1, then use
/// the hash function associated with the certificate's
/// signatureAlgorithm;
///
/// * if the certificate's signatureAlgorithm uses no hash functions or
/// uses multiple hash functions, then this channel binding type's
/// channel bindings are undefined at this time (updates to is channel
/// binding type may occur to address this issue if it ever arises).
#[derive(Debug, Clone, Copy)]
pub enum TlsServerEndPoint {
Sha256([u8; 32]),
Undefined,
}
impl TlsServerEndPoint {
pub fn new(cert: &CertificateDer) -> anyhow::Result<Self> {
let sha256_oids = [
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
oid_registry::OID_PKCS1_SHA256WITHRSA,
];
let pem = x509_parser::parse_x509_certificate(cert)
.context("Failed to parse PEM object from cerficiate")?
.1;
info!(subject = %pem.subject, "parsing TLS certificate");
let reg = oid_registry::OidRegistry::default().with_all_crypto();
let oid = pem.signature_algorithm.oid();
let alg = reg.get(oid);
if sha256_oids.contains(oid) {
let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
Ok(Self::Sha256(tls_server_end_point))
} else {
error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
Ok(Self::Undefined)
}
}
pub fn supported(&self) -> bool {
!matches!(self, TlsServerEndPoint::Undefined)
}
}
#[derive(Default, Debug)]
pub struct CertResolver {
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,

View File

@@ -16,9 +16,10 @@ use crate::{
intern::ProjectIdInt,
metrics::ApiLockMetrics,
rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token},
scram, EndpointCacheKey,
EndpointCacheKey,
};
use dashmap::DashMap;
use proxy_sasl::scram;
use std::{hash::Hash, sync::Arc, time::Duration};
use tokio::time::Instant;
use tracing::info;
@@ -469,15 +470,15 @@ impl<K: Hash + Eq + Clone> ApiLocks<K> {
timeout: Duration,
epoch: std::time::Duration,
metrics: &'static ApiLockMetrics,
) -> prometheus::Result<Self> {
Ok(Self {
) -> Self {
Self {
name,
node_locks: DashMap::with_shard_amount(shards),
config,
timeout,
epoch,
metrics,
})
}
}
pub async fn get_permit(&self, key: &K) -> Result<WakeComputePermit, ApiLockError> {

View File

@@ -5,7 +5,7 @@ use super::{
AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
};
use crate::context::RequestMonitoring;
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, url::ApiUrl};
use crate::{auth::IpPattern, cache::Cached};
use crate::{
console::{
@@ -15,6 +15,7 @@ use crate::{
BranchId, EndpointId, ProjectId,
};
use futures::TryFutureExt;
use proxy_sasl::scram;
use std::{str::FromStr, sync::Arc};
use thiserror::Error;
use tokio_postgres::{config::SslMode, Client};

View File

@@ -13,10 +13,11 @@ use crate::{
http,
metrics::{CacheOutcome, Metrics},
rate_limiter::WakeComputeRateLimiter,
scram, EndpointCacheKey,
EndpointCacheKey,
};
use crate::{cache::Cached, context::RequestMonitoring};
use futures::TryFutureExt;
use proxy_sasl::scram;
use std::{sync::Arc, time::Duration};
use tokio::time::Instant;
use tokio_postgres::config::SslMode;

View File

@@ -21,13 +21,13 @@ pub mod intern;
pub mod jemalloc;
pub mod logging;
pub mod metrics;
pub mod parse;
// pub mod parse;
pub mod protocol2;
pub mod proxy;
pub mod rate_limiter;
pub mod redis;
pub mod sasl;
pub mod scram;
// pub mod sasl;
// pub mod scram;
pub mod serverless;
pub mod stream;
pub mod url;

View File

@@ -2,13 +2,14 @@ use std::sync::{Arc, OnceLock};
use lasso::ThreadedRodeo;
use measured::{
label::{FixedCardinalitySet, LabelGroupSet, LabelName, LabelSet, LabelValue, StaticLabelSet},
label::StaticLabelSet,
metric::{histogram::Thresholds, name::MetricName},
Counter, CounterVec, FixedCardinalityLabel, Gauge, GaugeVec, Histogram, HistogramVec,
LabelGroup, MetricGroup,
Counter, CounterVec, FixedCardinalityLabel, Gauge, Histogram, HistogramVec, LabelGroup,
MetricGroup,
};
use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLog, HyperLogLogVec};
use proxy_sasl::scram::threadpool::ThreadPoolMetrics;
use tokio::time::{self, Instant};
use crate::console::messages::ColdStartInfo;
@@ -546,78 +547,3 @@ pub enum RedisEventsCount {
PasswordUpdate,
AllowedIpsUpdate,
}
pub struct ThreadPoolWorkers(usize);
pub struct ThreadPoolWorkerId(pub usize);
impl LabelValue for ThreadPoolWorkerId {
fn visit<V: measured::label::LabelVisitor>(&self, v: V) -> V::Output {
v.write_int(self.0 as i64)
}
}
impl LabelGroup for ThreadPoolWorkerId {
fn visit_values(&self, v: &mut impl measured::label::LabelGroupVisitor) {
v.write_value(LabelName::from_str("worker"), self);
}
}
impl LabelGroupSet for ThreadPoolWorkers {
type Group<'a> = ThreadPoolWorkerId;
fn cardinality(&self) -> Option<usize> {
Some(self.0)
}
fn encode_dense(&self, value: Self::Unique) -> Option<usize> {
Some(value)
}
fn decode_dense(&self, value: usize) -> Self::Group<'_> {
ThreadPoolWorkerId(value)
}
type Unique = usize;
fn encode(&self, value: Self::Group<'_>) -> Option<Self::Unique> {
Some(value.0)
}
fn decode(&self, value: &Self::Unique) -> Self::Group<'_> {
ThreadPoolWorkerId(*value)
}
}
impl LabelSet for ThreadPoolWorkers {
type Value<'a> = ThreadPoolWorkerId;
fn dynamic_cardinality(&self) -> Option<usize> {
Some(self.0)
}
fn encode(&self, value: Self::Value<'_>) -> Option<usize> {
(value.0 < self.0).then_some(value.0)
}
fn decode(&self, value: usize) -> Self::Value<'_> {
ThreadPoolWorkerId(value)
}
}
impl FixedCardinalitySet for ThreadPoolWorkers {
fn cardinality(&self) -> usize {
self.0
}
}
#[derive(MetricGroup)]
#[metric(new(workers: usize))]
pub struct ThreadPoolMetrics {
pub injector_queue_depth: Gauge,
#[metric(init = GaugeVec::with_label_set(ThreadPoolWorkers(workers)))]
pub worker_queue_depth: GaugeVec<ThreadPoolWorkers>,
#[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))]
pub worker_task_turns_total: CounterVec<ThreadPoolWorkers>,
#[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))]
pub worker_task_skips_total: CounterVec<ThreadPoolWorkers>,
}

View File

@@ -16,9 +16,10 @@ use crate::console::messages::{ConsoleError, Details, MetricsAuxInfo, Status};
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend};
use crate::console::{self, CachedNodeInfo, NodeInfo};
use crate::error::ErrorKind;
use crate::{http, sasl, scram, BranchId, EndpointId, ProjectId};
use crate::{http, BranchId, EndpointId, ProjectId};
use anyhow::{bail, Context};
use async_trait::async_trait;
use proxy_sasl::{sasl, scram};
use retry::{retry_after, ShouldRetryWakeCompute};
use rstest::rstest;
use rustls::pki_types;
@@ -137,7 +138,7 @@ struct Scram(scram::ServerSecret);
impl Scram {
async fn new(password: &str) -> anyhow::Result<Self> {
let secret = scram::ServerSecret::build(password)
let secret = scram::ServerSecret::build_test_secret(password)
.await
.context("failed to generate scram secret")?;
Ok(Scram(secret))

View File

@@ -79,11 +79,11 @@ impl PoolingBackend {
)
.await?;
let res = match auth_outcome {
crate::sasl::Outcome::Success(key) => {
proxy_sasl::sasl::Outcome::Success(key) => {
info!("user successfully authenticated");
Ok(key)
}
crate::sasl::Outcome::Failure(reason) => {
proxy_sasl::sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
Err(AuthError::auth_failed(&*conn_info.user_info.user))
}

View File

@@ -1,10 +1,10 @@
use crate::config::TlsServerEndPoint;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::Metrics;
use bytes::BytesMut;
use pq_proto::framed::{ConnectionError, Framed};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
use proxy_sasl::scram::TlsServerEndPoint;
use rustls::ServerConfig;
use std::pin::Pin;
use std::sync::Arc;

128
proxy/sasl/Cargo.toml Normal file
View File

@@ -0,0 +1,128 @@
[package]
name = "proxy-sasl"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[features]
default = []
testing = []
[dependencies]
ahash.workspace = true
anyhow.workspace = true
# arc-swap.workspace = true
# async-compression.workspace = true
# async-trait.workspace = true
# atomic-take.workspace = true
# aws-config.workspace = true
# aws-sdk-iam.workspace = true
# aws-sigv4.workspace = true
# aws-types.workspace = true
base64.workspace = true
# bstr.workspace = true
bytes = { workspace = true, features = ["serde"] }
# camino.workspace = true
# chrono.workspace = true
# clap.workspace = true
# consumption_metrics.workspace = true
crossbeam-deque.workspace = true
# dashmap.workspace = true
# env_logger.workspace = true
# framed-websockets.workspace = true
# futures.workspace = true
# git-version.workspace = true
# hashbrown.workspace = true
# hashlink.workspace = true
# hex.workspace = true
hmac.workspace = true
# hostname.workspace = true
# http.workspace = true
# humantime.workspace = true
# humantime-serde.workspace = true
# hyper.workspace = true
# hyper1 = { package = "hyper", version = "1.2", features = ["server"] }
# hyper-util = { version = "0.1", features = ["server", "http1", "http2", "tokio"] }
# http-body-util = { version = "0.1" }
# indexmap.workspace = true
# ipnet.workspace = true
itertools.workspace = true
lasso = { workspace = true, features = ["multi-threaded"] }
# md5.workspace = true
measured = { workspace = true, features = ["lasso"] }
# metrics.workspace = true
# once_cell.workspace = true
# opentelemetry.workspace = true
parking_lot.workspace = true
# parquet.workspace = true
# parquet_derive.workspace = true
# pin-project-lite.workspace = true
# postgres_backend.workspace = true
pq_proto.workspace = true
# prometheus.workspace = true
rand.workspace = true
# regex.workspace = true
# remote_storage = { version = "0.1", path = "../../libs/remote_storage/" }
# reqwest.workspace = true
# reqwest-middleware = { workspace = true, features = ["json"] }
# reqwest-retry.workspace = true
# reqwest-tracing.workspace = true
# routerify.workspace = true
# rustc-hash.workspace = true
# rustls-pemfile.workspace = true
rustls.workspace = true
# scopeguard.workspace = true
# serde.workspace = true
# serde_json.workspace = true
sha2 = { workspace = true, features = ["asm", "oid"] }
# smol_str.workspace = true
# smallvec.workspace = true
# socket2.workspace = true
subtle.workspace = true
# task-local-extensions.workspace = true
thiserror.workspace = true
# tikv-jemallocator.workspace = true
# tikv-jemalloc-ctl = { workspace = true, features = ["use_std"] }
# tokio-postgres.workspace = true
# tokio-postgres-rustls.workspace = true
# tokio-rustls.workspace = true
# tokio-util.workspace = true
tokio = { workspace = true, features = ["signal"] }
# tower-service.workspace = true
# tracing-opentelemetry.workspace = true
# tracing-subscriber.workspace = true
# tracing-utils.workspace = true
tracing.workspace = true
# try-lock.workspace = true
# typed-json.workspace = true
# url.workspace = true
# urlencoding.workspace = true
# utils.workspace = true
# uuid.workspace = true
# rustls-native-certs.workspace = true
x509-parser.workspace = true
postgres-protocol.workspace = true
# redis.workspace = true
# # jwt stuff
# jose-jwa = "0.1.2"
# jose-jwk = { version = "0.1.2", features = ["p256", "p384", "rsa"] }
# signature = "2"
# ecdsa = "0.16"
# p256 = "0.13"
# rsa = "0.9"
workspace_hack.workspace = true
[dev-dependencies]
# camino-tempfile.workspace = true
# fallible-iterator.workspace = true
# tokio-tungstenite.workspace = true
pbkdf2 = { workspace = true, features = ["simple", "std"] }
# rcgen.workspace = true
# rstest.workspace = true
# tokio-postgres-rustls.workspace = true
# walkdir.workspace = true
# rand_distr = "0.4"
uuid.workspace = true

3
proxy/sasl/src/lib.rs Normal file
View File

@@ -0,0 +1,3 @@
mod parse;
pub mod sasl;
pub mod scram;

43
proxy/sasl/src/parse.rs Normal file
View File

@@ -0,0 +1,43 @@
//! Small parsing helpers.
use std::ffi::CStr;
pub fn split_cstr(bytes: &[u8]) -> Option<(&CStr, &[u8])> {
let cstr = CStr::from_bytes_until_nul(bytes).ok()?;
let (_, other) = bytes.split_at(cstr.to_bytes_with_nul().len());
Some((cstr, other))
}
/// See <https://doc.rust-lang.org/std/primitive.slice.html#method.split_array_ref>.
pub fn split_at_const<const N: usize>(bytes: &[u8]) -> Option<(&[u8; N], &[u8])> {
(bytes.len() >= N).then(|| {
let (head, tail) = bytes.split_at(N);
(head.try_into().unwrap(), tail)
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_split_cstr() {
assert!(split_cstr(b"").is_none());
assert!(split_cstr(b"foo").is_none());
let (cstr, rest) = split_cstr(b"\0").expect("uh-oh");
assert_eq!(cstr.to_bytes(), b"");
assert_eq!(rest, b"");
let (cstr, rest) = split_cstr(b"foo\0bar").expect("uh-oh");
assert_eq!(cstr.to_bytes(), b"foo");
assert_eq!(rest, b"bar");
}
#[test]
fn test_split_at_const() {
assert!(split_at_const::<0>(b"").is_some());
assert!(split_at_const::<1>(b"").is_none());
assert!(matches!(split_at_const::<1>(b"ok"), Some((b"o", b"k"))));
}
}

View File

@@ -10,7 +10,7 @@ mod channel_binding;
mod messages;
mod stream;
use crate::error::{ReportableError, UserFacingError};
// use crate::error::{ReportableError, UserFacingError};
use std::io;
use thiserror::Error;
@@ -40,29 +40,29 @@ pub enum Error {
Io(#[from] io::Error),
}
impl UserFacingError for Error {
fn to_string_client(&self) -> String {
use Error::*;
match self {
ChannelBindingFailed(m) => m.to_string(),
ChannelBindingBadMethod(m) => format!("unsupported channel binding method {m}"),
_ => "authentication protocol violation".to_string(),
}
}
}
// impl UserFacingError for Error {
// fn to_string_client(&self) -> String {
// use Error::*;
// match self {
// ChannelBindingFailed(m) => m.to_string(),
// ChannelBindingBadMethod(m) => format!("unsupported channel binding method {m}"),
// _ => "authentication protocol violation".to_string(),
// }
// }
// }
impl ReportableError for Error {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User,
Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User,
Error::BadClientMessage(_) => crate::error::ErrorKind::User,
Error::MissingBinding => crate::error::ErrorKind::Service,
Error::Base64(_) => crate::error::ErrorKind::ControlPlane,
Error::Io(_) => crate::error::ErrorKind::ClientDisconnect,
}
}
}
// impl ReportableError for Error {
// fn get_error_kind(&self) -> crate::error::ErrorKind {
// match self {
// Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User,
// Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User,
// Error::BadClientMessage(_) => crate::error::ErrorKind::User,
// Error::MissingBinding => crate::error::ErrorKind::Service,
// Error::Base64(_) => crate::error::ErrorKind::ControlPlane,
// Error::Io(_) => crate::error::ErrorKind::ClientDisconnect,
// }
// }
// }
/// A convenient result type for SASL exchange.
pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -1,7 +1,10 @@
//! Abstraction for the string-oriented SASL protocols.
use super::{messages::ServerMessage, Mechanism};
use crate::stream::PqStream;
use pq_proto::{
framed::{ConnectionError, Framed},
FeMessage, ProtocolError,
};
use std::io;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
@@ -9,7 +12,7 @@ use tracing::info;
/// Abstracts away all peculiarities of the libpq's protocol.
pub struct SaslStream<'a, S> {
/// The underlying stream.
stream: &'a mut PqStream<S>,
stream: &'a mut Framed<S>,
/// Current password message we received from client.
current: bytes::Bytes,
/// First SASL message produced by client.
@@ -17,7 +20,7 @@ pub struct SaslStream<'a, S> {
}
impl<'a, S> SaslStream<'a, S> {
pub fn new(stream: &'a mut PqStream<S>, first: &'a str) -> Self {
pub fn new(stream: &'a mut Framed<S>, first: &'a str) -> Self {
Self {
stream,
current: bytes::Bytes::new(),
@@ -26,6 +29,27 @@ impl<'a, S> SaslStream<'a, S> {
}
}
fn err_connection() -> io::Error {
io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost")
}
pub async fn read_password_message<S: AsyncRead + Unpin>(
framed: &mut Framed<S>,
) -> io::Result<bytes::Bytes> {
let msg = framed
.read_message()
.await
.map_err(ConnectionError::into_io_error)?
.ok_or_else(err_connection)?;
match msg {
FeMessage::PasswordMessage(msg) => Ok(msg),
bad => Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("unexpected message type: {:?}", bad),
)),
}
}
impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
// Receive a new SASL message from the client.
async fn recv(&mut self) -> io::Result<&str> {
@@ -33,7 +57,7 @@ impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
return Ok(first);
}
self.current = self.stream.read_password_message().await?;
self.current = read_password_message(self.stream).await?;
let s = std::str::from_utf8(&self.current)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
@@ -44,7 +68,10 @@ impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
// Send a SASL message to the client.
async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
self.stream.write_message(&msg.to_reply()).await?;
self.stream
.write_message(&msg.to_reply())
.map_err(ProtocolError::into_io_error)?;
self.stream.flush().await?;
Ok(())
}
}

View File

@@ -15,12 +15,16 @@ mod secret;
mod signature;
pub mod threadpool;
use anyhow::Context;
pub use exchange::{exchange, Exchange};
pub use key::ScramKey;
use rustls::pki_types::CertificateDer;
pub use secret::ServerSecret;
use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};
use tracing::{error, info};
use x509_parser::oid_registry;
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
@@ -57,12 +61,71 @@ fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
hasher.finalize().into()
}
/// Channel binding parameter
///
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
/// Description: The hash of the TLS server's certificate as it
/// appears, octet for octet, in the server's Certificate message. Note
/// that the Certificate message contains a certificate_list, in which
/// the first element is the server's certificate.
///
/// The hash function is to be selected as follows:
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function and that hash function neither MD5 nor SHA-1, then use
/// the hash function associated with the certificate's
/// signatureAlgorithm;
///
/// * if the certificate's signatureAlgorithm uses no hash functions or
/// uses multiple hash functions, then this channel binding type's
/// channel bindings are undefined at this time (updates to is channel
/// binding type may occur to address this issue if it ever arises).
#[derive(Debug, Clone, Copy)]
pub enum TlsServerEndPoint {
Sha256([u8; 32]),
Undefined,
}
impl TlsServerEndPoint {
pub fn new(cert: &CertificateDer) -> anyhow::Result<Self> {
let sha256_oids = [
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
oid_registry::OID_PKCS1_SHA256WITHRSA,
];
let pem = x509_parser::parse_x509_certificate(cert)
.context("Failed to parse PEM object from cerficiate")?
.1;
info!(subject = %pem.subject, "parsing TLS certificate");
let reg = oid_registry::OidRegistry::default().with_all_crypto();
let oid = pem.signature_algorithm.oid();
let alg = reg.get(oid);
if sha256_oids.contains(oid) {
let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
Ok(Self::Sha256(tls_server_end_point))
} else {
error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
Ok(Self::Undefined)
}
}
pub fn supported(&self) -> bool {
!matches!(self, TlsServerEndPoint::Undefined)
}
}
#[cfg(test)]
mod tests {
use crate::{
intern::EndpointIdInt,
sasl::{Mechanism, Step},
EndpointId,
scram::TlsServerEndPoint,
};
use super::{threadpool::ThreadPool, Exchange, ServerSecret};
@@ -79,11 +142,7 @@ mod tests {
const NONCE: [u8; 18] = [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
];
let mut exchange = Exchange::new(
&secret,
|| NONCE,
crate::config::TlsServerEndPoint::Undefined,
);
let mut exchange = Exchange::new(&secret, || NONCE, TlsServerEndPoint::Undefined);
let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO";
let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0=";
@@ -120,11 +179,11 @@ mod tests {
async fn run_round_trip_test(server_password: &str, client_password: &str) {
let pool = ThreadPool::new(1);
let ep = "foo";
let ep = EndpointId::from("foo");
let ep = EndpointIdInt::from(ep);
let scram_secret = ServerSecret::build(server_password).await.unwrap();
let scram_secret = ServerSecret::build_test_secret(server_password)
.await
.unwrap();
let outcome = super::exchange(&pool, ep, &scram_secret, client_password.as_bytes())
.await
.unwrap();

View File

@@ -1,6 +1,7 @@
//! Implementation of the SCRAM authentication algorithm.
use std::convert::Infallible;
use std::hash::Hash;
use hmac::{Hmac, Mac};
use sha2::Sha256;
@@ -13,8 +14,6 @@ use super::secret::ServerSecret;
use super::signature::SignatureBuilder;
use super::threadpool::ThreadPool;
use super::ScramKey;
use crate::config;
use crate::intern::EndpointIdInt;
use crate::sasl::{self, ChannelBinding, Error as SaslError};
/// The only channel binding mode we currently support.
@@ -59,14 +58,14 @@ enum ExchangeState {
pub struct Exchange<'a> {
state: ExchangeState,
secret: &'a ServerSecret,
tls_server_end_point: config::TlsServerEndPoint,
tls_server_end_point: super::TlsServerEndPoint,
}
impl<'a> Exchange<'a> {
pub fn new(
secret: &'a ServerSecret,
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
tls_server_end_point: config::TlsServerEndPoint,
tls_server_end_point: super::TlsServerEndPoint,
) -> Self {
Self {
state: ExchangeState::Initial(SaslInitial { nonce }),
@@ -77,15 +76,15 @@ impl<'a> Exchange<'a> {
}
// copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
async fn derive_client_key(
pool: &ThreadPool,
endpoint: EndpointIdInt,
async fn derive_client_key<K: Hash + Send + 'static>(
pool: &ThreadPool<K>,
concurrency_key: K,
password: &[u8],
salt: &[u8],
iterations: u32,
) -> ScramKey {
let salted_password = pool
.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
.spawn_job(concurrency_key, Pbkdf2::start(password, salt, iterations))
.await
.expect("job should not be cancelled");
@@ -101,14 +100,15 @@ async fn derive_client_key(
make_key(b"Client Key").into()
}
pub async fn exchange(
pool: &ThreadPool,
endpoint: EndpointIdInt,
pub async fn exchange<K: Hash + Send + 'static>(
pool: &ThreadPool<K>,
concurrency_key: K,
secret: &ServerSecret,
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
let salt = base64::decode(&secret.salt_base64)?;
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
let client_key =
derive_client_key(pool, concurrency_key, password, &salt, secret.iterations).await;
if secret.is_password_invalid(&client_key).into() {
Ok(sasl::Outcome::Failure("password doesn't match"))
@@ -121,7 +121,7 @@ impl SaslInitial {
fn transition(
&self,
secret: &ServerSecret,
tls_server_end_point: &config::TlsServerEndPoint,
tls_server_end_point: &super::TlsServerEndPoint,
input: &str,
) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
let client_first_message = ClientFirstMessage::parse(input)
@@ -156,7 +156,7 @@ impl SaslSentInner {
fn transition(
&self,
secret: &ServerSecret,
tls_server_end_point: &config::TlsServerEndPoint,
tls_server_end_point: &super::TlsServerEndPoint,
input: &str,
) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
let Self {
@@ -169,8 +169,8 @@ impl SaslSentInner {
.ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
config::TlsServerEndPoint::Sha256(x) => Ok(x),
config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
super::TlsServerEndPoint::Sha256(x) => Ok(x),
super::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
})?;
// This might've been caused by a MITM attack

View File

@@ -64,9 +64,7 @@ impl ServerSecret {
}
/// Build a new server secret from the prerequisites.
/// XXX: We only use this function in tests.
#[cfg(test)]
pub async fn build(password: &str) -> Option<Self> {
pub async fn build_test_secret(password: &str) -> Option<Self> {
Self::parse(&postgres_protocol::password::scram_sha_256(password.as_bytes()).await)
}
}

View File

@@ -4,29 +4,36 @@
//! 1. Fairness per endpoint.
//! 2. Yield support for high iteration counts.
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc,
use std::{
hash::Hash,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
};
use crossbeam_deque::{Injector, Stealer, Worker};
use itertools::Itertools;
use measured::{
label::{FixedCardinalitySet, LabelGroupSet, LabelName, LabelSet, LabelValue},
CounterVec, Gauge, GaugeVec, LabelGroup, MetricGroup,
};
use parking_lot::{Condvar, Mutex};
use rand::Rng;
use rand::{rngs::SmallRng, SeedableRng};
use tokio::sync::oneshot;
use crate::{
intern::EndpointIdInt,
metrics::{ThreadPoolMetrics, ThreadPoolWorkerId},
// intern::EndpointIdInt,
// metrics::{ThreadPoolMetrics, ThreadPoolWorkerId},
scram::countmin::CountMinSketch,
};
use super::pbkdf2::Pbkdf2;
pub struct ThreadPool {
queue: Injector<JobSpec>,
stealers: Vec<Stealer<JobSpec>>,
pub struct ThreadPool<K> {
queue: Injector<JobSpec<K>>,
stealers: Vec<Stealer<JobSpec<K>>>,
parkers: Vec<(Condvar, Mutex<ThreadState>)>,
/// bitpacked representation.
/// lower 8 bits = number of sleeping threads
@@ -42,7 +49,7 @@ enum ThreadState {
Active,
}
impl ThreadPool {
impl<K: Hash + Send + 'static> ThreadPool<K> {
pub fn new(n_workers: u8) -> Arc<Self> {
let workers = (0..n_workers).map(|_| Worker::new_fifo()).collect_vec();
let stealers = workers.iter().map(|w| w.stealer()).collect_vec();
@@ -68,11 +75,7 @@ impl ThreadPool {
pool
}
pub fn spawn_job(
&self,
endpoint: EndpointIdInt,
pbkdf2: Pbkdf2,
) -> oneshot::Receiver<[u8; 32]> {
pub fn spawn_job(&self, key: K, pbkdf2: Pbkdf2) -> oneshot::Receiver<[u8; 32]> {
let (tx, rx) = oneshot::channel();
let queue_was_empty = self.queue.is_empty();
@@ -81,7 +84,7 @@ impl ThreadPool {
self.queue.push(JobSpec {
response: tx,
pbkdf2,
endpoint,
key,
});
// inspired from <https://github.com/rayon-rs/rayon/blob/3e3962cb8f7b50773bcc360b48a7a674a53a2c77/rayon-core/src/sleep/mod.rs#L242>
@@ -139,7 +142,12 @@ impl ThreadPool {
}
}
fn steal(&self, rng: &mut impl Rng, skip: usize, worker: &Worker<JobSpec>) -> Option<JobSpec> {
fn steal(
&self,
rng: &mut impl Rng,
skip: usize,
worker: &Worker<JobSpec<K>>,
) -> Option<JobSpec<K>> {
// announce thread as idle
self.counters.fetch_add(256, Ordering::SeqCst);
@@ -188,7 +196,11 @@ impl ThreadPool {
}
}
fn thread_rt(pool: Arc<ThreadPool>, worker: Worker<JobSpec>, index: usize) {
fn thread_rt<K: Hash + Send + 'static>(
pool: Arc<ThreadPool<K>>,
worker: Worker<JobSpec<K>>,
index: usize,
) {
/// interval when we should steal from the global queue
/// so that tail latencies are managed appropriately
const STEAL_INTERVAL: usize = 61;
@@ -236,7 +248,7 @@ fn thread_rt(pool: Arc<ThreadPool>, worker: Worker<JobSpec>, index: usize) {
// receiver is closed, cancel the task
if !job.response.is_closed() {
let rate = sketch.inc_and_return(&job.endpoint, job.pbkdf2.cost());
let rate = sketch.inc_and_return(&job.key, job.pbkdf2.cost());
const P: f64 = 2000.0;
// probability decreases as rate increases.
@@ -287,24 +299,96 @@ fn thread_rt(pool: Arc<ThreadPool>, worker: Worker<JobSpec>, index: usize) {
}
}
struct JobSpec {
struct JobSpec<K> {
response: oneshot::Sender<[u8; 32]>,
pbkdf2: Pbkdf2,
endpoint: EndpointIdInt,
key: K,
}
pub struct ThreadPoolWorkers(usize);
pub struct ThreadPoolWorkerId(pub usize);
impl LabelValue for ThreadPoolWorkerId {
fn visit<V: measured::label::LabelVisitor>(&self, v: V) -> V::Output {
v.write_int(self.0 as i64)
}
}
impl LabelGroup for ThreadPoolWorkerId {
fn visit_values(&self, v: &mut impl measured::label::LabelGroupVisitor) {
v.write_value(LabelName::from_str("worker"), self);
}
}
impl LabelGroupSet for ThreadPoolWorkers {
type Group<'a> = ThreadPoolWorkerId;
fn cardinality(&self) -> Option<usize> {
Some(self.0)
}
fn encode_dense(&self, value: Self::Unique) -> Option<usize> {
Some(value)
}
fn decode_dense(&self, value: usize) -> Self::Group<'_> {
ThreadPoolWorkerId(value)
}
type Unique = usize;
fn encode(&self, value: Self::Group<'_>) -> Option<Self::Unique> {
Some(value.0)
}
fn decode(&self, value: &Self::Unique) -> Self::Group<'_> {
ThreadPoolWorkerId(*value)
}
}
impl LabelSet for ThreadPoolWorkers {
type Value<'a> = ThreadPoolWorkerId;
fn dynamic_cardinality(&self) -> Option<usize> {
Some(self.0)
}
fn encode(&self, value: Self::Value<'_>) -> Option<usize> {
(value.0 < self.0).then_some(value.0)
}
fn decode(&self, value: usize) -> Self::Value<'_> {
ThreadPoolWorkerId(value)
}
}
impl FixedCardinalitySet for ThreadPoolWorkers {
fn cardinality(&self) -> usize {
self.0
}
}
#[derive(MetricGroup)]
#[metric(new(workers: usize))]
pub struct ThreadPoolMetrics {
pub injector_queue_depth: Gauge,
#[metric(init = GaugeVec::with_label_set(ThreadPoolWorkers(workers)))]
pub worker_queue_depth: GaugeVec<ThreadPoolWorkers>,
#[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))]
pub worker_task_turns_total: CounterVec<ThreadPoolWorkers>,
#[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))]
pub worker_task_skips_total: CounterVec<ThreadPoolWorkers>,
}
#[cfg(test)]
mod tests {
use crate::EndpointId;
use super::*;
#[tokio::test]
async fn hash_is_correct() {
let pool = ThreadPool::new(1);
let ep = EndpointId::from("foo");
let ep = EndpointIdInt::from(ep);
let ep = "foo";
let salt = [0x55; 32];
let actual = pool