diff --git a/Cargo.lock b/Cargo.lock index dee15b6aa7..23867eb2e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4472,7 +4472,7 @@ dependencies = [ "postgres-protocol", "postgres_backend", "pq_proto", - "prometheus", + "proxy-sasl", "rand 0.8.5", "rand_distr", "rcgen", @@ -4525,6 +4525,35 @@ dependencies = [ "x509-parser", ] +[[package]] +name = "proxy-sasl" +version = "0.1.0" +dependencies = [ + "ahash", + "anyhow", + "base64 0.13.1", + "bytes", + "crossbeam-deque", + "hmac", + "itertools 0.10.5", + "lasso", + "measured", + "parking_lot 0.12.1", + "pbkdf2", + "postgres-protocol", + "pq_proto", + "rand 0.8.5", + "rustls 0.22.4", + "sha2", + "subtle", + "thiserror", + "tokio", + "tracing", + "uuid", + "workspace_hack", + "x509-parser", +] + [[package]] name = "quick-xml" version = "0.31.0" diff --git a/Cargo.toml b/Cargo.toml index 22815b9e80..8f2512fd5e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ members = [ "pageserver/client", "pageserver/pagebench", "proxy/core", + "proxy/sasl", "safekeeper", "storage_broker", "storage_controller", diff --git a/proxy/core/Cargo.toml b/proxy/core/Cargo.toml index 0f2a80ddc9..6adfc9b2e0 100644 --- a/proxy/core/Cargo.toml +++ b/proxy/core/Cargo.toml @@ -9,6 +9,8 @@ default = [] testing = [] [dependencies] +proxy-sasl = { version = "0.1", path = "../sasl" } + ahash.workspace = true anyhow.workspace = true arc-swap.workspace = true @@ -59,7 +61,7 @@ parquet_derive.workspace = true pin-project-lite.workspace = true postgres_backend.workspace = true pq_proto.workspace = true -prometheus.workspace = true +# prometheus.workspace = true rand.workspace = true regex.workspace = true remote_storage = { version = "0.1", path = "../../libs/remote_storage/" } diff --git a/proxy/core/src/auth.rs b/proxy/core/src/auth.rs index 8c44823c98..bab9b67d63 100644 --- a/proxy/core/src/auth.rs +++ b/proxy/core/src/auth.rs @@ -38,7 +38,7 @@ pub enum AuthErrorImpl { /// SASL protocol errors (includes [SCRAM](crate::scram)). #[error(transparent)] - Sasl(#[from] crate::sasl::Error), + Sasl(#[from] proxy_sasl::sasl::Error), #[error("Unsupported authentication method: {0}")] BadAuthMethod(Box), @@ -148,3 +148,28 @@ impl ReportableError for AuthError { } } } + +impl UserFacingError for proxy_sasl::sasl::Error { + fn to_string_client(&self) -> String { + match self { + proxy_sasl::sasl::Error::ChannelBindingFailed(m) => m.to_string(), + proxy_sasl::sasl::Error::ChannelBindingBadMethod(m) => { + format!("unsupported channel binding method {m}") + } + _ => "authentication protocol violation".to_string(), + } + } +} + +impl ReportableError for proxy_sasl::sasl::Error { + fn get_error_kind(&self) -> crate::error::ErrorKind { + match self { + proxy_sasl::sasl::Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User, + proxy_sasl::sasl::Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User, + proxy_sasl::sasl::Error::BadClientMessage(_) => crate::error::ErrorKind::User, + proxy_sasl::sasl::Error::MissingBinding => crate::error::ErrorKind::Service, + proxy_sasl::sasl::Error::Base64(_) => crate::error::ErrorKind::ControlPlane, + proxy_sasl::sasl::Error::Io(_) => crate::error::ErrorKind::ClientDisconnect, + } + } +} diff --git a/proxy/core/src/auth/backend.rs b/proxy/core/src/auth/backend.rs index c6a0b2af5a..01ffdeba9d 100644 --- a/proxy/core/src/auth/backend.rs +++ b/proxy/core/src/auth/backend.rs @@ -9,6 +9,7 @@ use std::time::Duration; use ipnet::{Ipv4Net, Ipv6Net}; pub use link::LinkAuthError; +use proxy_sasl::scram; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_postgres::config::AuthKeys; use tracing::{info, warn}; @@ -36,7 +37,7 @@ use crate::{ }, stream, url, }; -use crate::{scram, EndpointCacheKey, EndpointId, RoleName}; +use crate::{EndpointCacheKey, EndpointId, RoleName}; /// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality pub enum MaybeOwned<'a, T> { @@ -371,8 +372,8 @@ async fn authenticate_with_secret( let auth_outcome = validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?; let keys = match auth_outcome { - crate::sasl::Outcome::Success(key) => key, - crate::sasl::Outcome::Failure(reason) => { + proxy_sasl::sasl::Outcome::Success(key) => key, + proxy_sasl::sasl::Outcome::Failure(reason) => { info!("auth backend failed with an error: {reason}"); return Err(auth::AuthError::auth_failed(&*info.user)); } @@ -558,9 +559,9 @@ mod tests { context::RequestMonitoring, proxy::NeonOptions, rate_limiter::{EndpointRateLimiter, RateBucketInfo}, - scram::{threadpool::ThreadPool, ServerSecret}, stream::{PqStream, Stream}, }; + use proxy_sasl::scram::{threadpool::ThreadPool, ServerSecret}; use super::{auth_quirks, AuthRateLimiter}; @@ -669,7 +670,11 @@ mod tests { let ctx = RequestMonitoring::test(); let api = Auth { ips: vec![], - secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()), + secret: AuthSecret::Scram( + ServerSecret::build_test_secret("my-secret-password") + .await + .unwrap(), + ), }; let user_info = ComputeUserInfoMaybeEndpoint { @@ -746,7 +751,11 @@ mod tests { let ctx = RequestMonitoring::test(); let api = Auth { ips: vec![], - secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()), + secret: AuthSecret::Scram( + ServerSecret::build_test_secret("my-secret-password") + .await + .unwrap(), + ), }; let user_info = ComputeUserInfoMaybeEndpoint { @@ -798,7 +807,11 @@ mod tests { let ctx = RequestMonitoring::test(); let api = Auth { ips: vec![], - secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()), + secret: AuthSecret::Scram( + ServerSecret::build_test_secret("my-secret-password") + .await + .unwrap(), + ), }; let user_info = ComputeUserInfoMaybeEndpoint { diff --git a/proxy/core/src/auth/backend/classic.rs b/proxy/core/src/auth/backend/classic.rs index 285fa29428..d9a051bf81 100644 --- a/proxy/core/src/auth/backend/classic.rs +++ b/proxy/core/src/auth/backend/classic.rs @@ -5,9 +5,9 @@ use crate::{ config::AuthenticationConfig, console::AuthSecret, context::RequestMonitoring, - sasl, stream::{PqStream, Stream}, }; +use proxy_sasl::sasl; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, warn}; diff --git a/proxy/core/src/auth/backend/hacks.rs b/proxy/core/src/auth/backend/hacks.rs index 56921dd949..339308e905 100644 --- a/proxy/core/src/auth/backend/hacks.rs +++ b/proxy/core/src/auth/backend/hacks.rs @@ -7,9 +7,9 @@ use crate::{ console::AuthSecret, context::RequestMonitoring, intern::EndpointIdInt, - sasl, stream::{self, Stream}, }; +use proxy_sasl::sasl; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, warn}; diff --git a/proxy/core/src/auth/flow.rs b/proxy/core/src/auth/flow.rs index acf7b4f6b6..2212a4c2b8 100644 --- a/proxy/core/src/auth/flow.rs +++ b/proxy/core/src/auth/flow.rs @@ -2,16 +2,17 @@ use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload}; use crate::{ - config::TlsServerEndPoint, console::AuthSecret, context::RequestMonitoring, intern::EndpointIdInt, - sasl, - scram::{self, threadpool::ThreadPool}, stream::{PqStream, Stream}, }; use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS}; use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be}; +use proxy_sasl::{ + sasl, + scram::{self, threadpool::ThreadPool, TlsServerEndPoint}, +}; use std::{io, sync::Arc}; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; @@ -56,7 +57,7 @@ impl AuthMethod for PasswordHack { /// Use clear-text password auth called `password` in docs /// pub struct CleartextPassword { - pub pool: Arc, + pub pool: Arc>, pub endpoint: EndpointIdInt, pub secret: AuthSecret, } @@ -174,7 +175,7 @@ impl AuthFlow<'_, S, Scram<'_>> { } info!("client chooses {}", sasl.method); - let outcome = sasl::SaslStream::new(self.stream, sasl.message) + let outcome = sasl::SaslStream::new(&mut self.stream.framed, sasl.message) .authenticate(scram::Exchange::new( secret, rand::random, @@ -191,7 +192,7 @@ impl AuthFlow<'_, S, Scram<'_>> { } pub(crate) async fn validate_password_and_exchange( - pool: &ThreadPool, + pool: &ThreadPool, endpoint: EndpointIdInt, password: &[u8], secret: AuthSecret, @@ -206,7 +207,8 @@ pub(crate) async fn validate_password_and_exchange( } // perform scram authentication as both client and server to validate the keys AuthSecret::Scram(scram_secret) => { - let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?; + let outcome = + proxy_sasl::scram::exchange(pool, endpoint, &scram_secret, password).await?; let client_key = match outcome { sasl::Outcome::Success(client_key) => client_key, diff --git a/proxy/core/src/bin/pg_sni_router.rs b/proxy/core/src/bin/pg_sni_router.rs index 1038fa5116..0ac7e6d965 100644 --- a/proxy/core/src/bin/pg_sni_router.rs +++ b/proxy/core/src/bin/pg_sni_router.rs @@ -7,10 +7,11 @@ use std::{net::SocketAddr, sync::Arc}; use futures::future::Either; use itertools::Itertools; -use proxy::config::TlsServerEndPoint; use proxy::context::RequestMonitoring; -use proxy::metrics::{Metrics, ThreadPoolMetrics}; +use proxy::metrics::Metrics; use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource}; +use proxy_sasl::scram::threadpool::ThreadPoolMetrics; +use proxy_sasl::scram::TlsServerEndPoint; use rustls::pki_types::PrivateKeyDer; use tokio::net::TcpListener; diff --git a/proxy/core/src/bin/proxy.rs b/proxy/core/src/bin/proxy.rs index b44e0ddd2f..b9c43e017d 100644 --- a/proxy/core/src/bin/proxy.rs +++ b/proxy/core/src/bin/proxy.rs @@ -30,7 +30,6 @@ use proxy::redis::cancellation_publisher::RedisPublisherClient; use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use proxy::redis::elasticache; use proxy::redis::notifications; -use proxy::scram::threadpool::ThreadPool; use proxy::serverless::cancel_set::CancelSet; use proxy::serverless::GlobalConnPoolOptions; use proxy::usage_metrics; @@ -38,6 +37,7 @@ use proxy::usage_metrics; use anyhow::bail; use proxy::config::{self, ProxyConfig}; use proxy::serverless; +use proxy_sasl::scram::threadpool::ThreadPool; use remote_storage::RemoteStorageConfig; use std::net::SocketAddr; use std::pin::pin; @@ -607,7 +607,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { timeout, epoch, &Metrics::get().wake_compute_lock, - )?)); + ))); tokio::spawn(locks.garbage_collect_worker()); let url = args.auth_endpoint.parse()?; @@ -658,7 +658,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { timeout, epoch, &Metrics::get().proxy.connect_compute_lock, - )?; + ); let http_config = HttpConfig { pool_options: GlobalConnPoolOptions { diff --git a/proxy/core/src/cache/project_info.rs b/proxy/core/src/cache/project_info.rs index 10cc4ceee1..14af8f5d3d 100644 --- a/proxy/core/src/cache/project_info.rs +++ b/proxy/core/src/cache/project_info.rs @@ -371,7 +371,8 @@ impl Cache for ProjectInfoCacheImpl { #[cfg(test)] mod tests { use super::*; - use crate::{scram::ServerSecret, ProjectId}; + use crate::ProjectId; + use proxy_sasl::scram::ServerSecret; #[tokio::test] async fn test_project_info_cache_settings() { diff --git a/proxy/core/src/config.rs b/proxy/core/src/config.rs index 1412095505..126376e02e 100644 --- a/proxy/core/src/config.rs +++ b/proxy/core/src/config.rs @@ -1,27 +1,26 @@ use crate::{ auth::{self, backend::AuthRateLimiter}, console::locks::ApiLocks, + intern::EndpointIdInt, rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig}, - scram::threadpool::ThreadPool, serverless::{cancel_set::CancelSet, GlobalConnPoolOptions}, Host, }; + use anyhow::{bail, ensure, Context, Ok}; use itertools::Itertools; +use proxy_sasl::scram::{threadpool::ThreadPool, TlsServerEndPoint}; use remote_storage::RemoteStorageConfig; use rustls::{ crypto::ring::sign, pki_types::{CertificateDer, PrivateKeyDer}, }; -use sha2::{Digest, Sha256}; use std::{ collections::{HashMap, HashSet}, str::FromStr, sync::Arc, time::Duration, }; -use tracing::{error, info}; -use x509_parser::oid_registry; pub struct ProxyConfig { pub tls_config: Option, @@ -58,7 +57,7 @@ pub struct HttpConfig { } pub struct AuthenticationConfig { - pub thread_pool: Arc, + pub thread_pool: Arc>, pub scram_protocol_timeout: tokio::time::Duration, pub rate_limiter_enabled: bool, pub rate_limiter: AuthRateLimiter, @@ -126,66 +125,6 @@ pub fn configure_tls( }) } -/// Channel binding parameter -/// -/// -/// Description: The hash of the TLS server's certificate as it -/// appears, octet for octet, in the server's Certificate message. Note -/// that the Certificate message contains a certificate_list, in which -/// the first element is the server's certificate. -/// -/// The hash function is to be selected as follows: -/// -/// * if the certificate's signatureAlgorithm uses a single hash -/// function, and that hash function is either MD5 or SHA-1, then use SHA-256; -/// -/// * if the certificate's signatureAlgorithm uses a single hash -/// function and that hash function neither MD5 nor SHA-1, then use -/// the hash function associated with the certificate's -/// signatureAlgorithm; -/// -/// * if the certificate's signatureAlgorithm uses no hash functions or -/// uses multiple hash functions, then this channel binding type's -/// channel bindings are undefined at this time (updates to is channel -/// binding type may occur to address this issue if it ever arises). -#[derive(Debug, Clone, Copy)] -pub enum TlsServerEndPoint { - Sha256([u8; 32]), - Undefined, -} - -impl TlsServerEndPoint { - pub fn new(cert: &CertificateDer) -> anyhow::Result { - let sha256_oids = [ - // I'm explicitly not adding MD5 or SHA1 here... They're bad. - oid_registry::OID_SIG_ECDSA_WITH_SHA256, - oid_registry::OID_PKCS1_SHA256WITHRSA, - ]; - - let pem = x509_parser::parse_x509_certificate(cert) - .context("Failed to parse PEM object from cerficiate")? - .1; - - info!(subject = %pem.subject, "parsing TLS certificate"); - - let reg = oid_registry::OidRegistry::default().with_all_crypto(); - let oid = pem.signature_algorithm.oid(); - let alg = reg.get(oid); - if sha256_oids.contains(oid) { - let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into(); - info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding"); - Ok(Self::Sha256(tls_server_end_point)) - } else { - error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding"); - Ok(Self::Undefined) - } - } - - pub fn supported(&self) -> bool { - !matches!(self, TlsServerEndPoint::Undefined) - } -} - #[derive(Default, Debug)] pub struct CertResolver { certs: HashMap, TlsServerEndPoint)>, diff --git a/proxy/core/src/console/provider.rs b/proxy/core/src/console/provider.rs index 15fc0134b3..e83d0cd5c9 100644 --- a/proxy/core/src/console/provider.rs +++ b/proxy/core/src/console/provider.rs @@ -16,9 +16,10 @@ use crate::{ intern::ProjectIdInt, metrics::ApiLockMetrics, rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token}, - scram, EndpointCacheKey, + EndpointCacheKey, }; use dashmap::DashMap; +use proxy_sasl::scram; use std::{hash::Hash, sync::Arc, time::Duration}; use tokio::time::Instant; use tracing::info; @@ -469,15 +470,15 @@ impl ApiLocks { timeout: Duration, epoch: std::time::Duration, metrics: &'static ApiLockMetrics, - ) -> prometheus::Result { - Ok(Self { + ) -> Self { + Self { name, node_locks: DashMap::with_shard_amount(shards), config, timeout, epoch, metrics, - }) + } } pub async fn get_permit(&self, key: &K) -> Result { diff --git a/proxy/core/src/console/provider/mock.rs b/proxy/core/src/console/provider/mock.rs index 2093da7562..178143d979 100644 --- a/proxy/core/src/console/provider/mock.rs +++ b/proxy/core/src/console/provider/mock.rs @@ -5,7 +5,7 @@ use super::{ AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo, }; use crate::context::RequestMonitoring; -use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl}; +use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, url::ApiUrl}; use crate::{auth::IpPattern, cache::Cached}; use crate::{ console::{ @@ -15,6 +15,7 @@ use crate::{ BranchId, EndpointId, ProjectId, }; use futures::TryFutureExt; +use proxy_sasl::scram; use std::{str::FromStr, sync::Arc}; use thiserror::Error; use tokio_postgres::{config::SslMode, Client}; diff --git a/proxy/core/src/console/provider/neon.rs b/proxy/core/src/console/provider/neon.rs index 7eda238b66..294c4637a0 100644 --- a/proxy/core/src/console/provider/neon.rs +++ b/proxy/core/src/console/provider/neon.rs @@ -13,10 +13,11 @@ use crate::{ http, metrics::{CacheOutcome, Metrics}, rate_limiter::WakeComputeRateLimiter, - scram, EndpointCacheKey, + EndpointCacheKey, }; use crate::{cache::Cached, context::RequestMonitoring}; use futures::TryFutureExt; +use proxy_sasl::scram; use std::{sync::Arc, time::Duration}; use tokio::time::Instant; use tokio_postgres::config::SslMode; diff --git a/proxy/core/src/lib.rs b/proxy/core/src/lib.rs index ea92eaaa55..024c9459b3 100644 --- a/proxy/core/src/lib.rs +++ b/proxy/core/src/lib.rs @@ -21,13 +21,13 @@ pub mod intern; pub mod jemalloc; pub mod logging; pub mod metrics; -pub mod parse; +// pub mod parse; pub mod protocol2; pub mod proxy; pub mod rate_limiter; pub mod redis; -pub mod sasl; -pub mod scram; +// pub mod sasl; +// pub mod scram; pub mod serverless; pub mod stream; pub mod url; diff --git a/proxy/core/src/metrics.rs b/proxy/core/src/metrics.rs index 0167553e30..62fb80fae9 100644 --- a/proxy/core/src/metrics.rs +++ b/proxy/core/src/metrics.rs @@ -2,13 +2,14 @@ use std::sync::{Arc, OnceLock}; use lasso::ThreadedRodeo; use measured::{ - label::{FixedCardinalitySet, LabelGroupSet, LabelName, LabelSet, LabelValue, StaticLabelSet}, + label::StaticLabelSet, metric::{histogram::Thresholds, name::MetricName}, - Counter, CounterVec, FixedCardinalityLabel, Gauge, GaugeVec, Histogram, HistogramVec, - LabelGroup, MetricGroup, + Counter, CounterVec, FixedCardinalityLabel, Gauge, Histogram, HistogramVec, LabelGroup, + MetricGroup, }; use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLog, HyperLogLogVec}; +use proxy_sasl::scram::threadpool::ThreadPoolMetrics; use tokio::time::{self, Instant}; use crate::console::messages::ColdStartInfo; @@ -546,78 +547,3 @@ pub enum RedisEventsCount { PasswordUpdate, AllowedIpsUpdate, } - -pub struct ThreadPoolWorkers(usize); -pub struct ThreadPoolWorkerId(pub usize); - -impl LabelValue for ThreadPoolWorkerId { - fn visit(&self, v: V) -> V::Output { - v.write_int(self.0 as i64) - } -} - -impl LabelGroup for ThreadPoolWorkerId { - fn visit_values(&self, v: &mut impl measured::label::LabelGroupVisitor) { - v.write_value(LabelName::from_str("worker"), self); - } -} - -impl LabelGroupSet for ThreadPoolWorkers { - type Group<'a> = ThreadPoolWorkerId; - - fn cardinality(&self) -> Option { - Some(self.0) - } - - fn encode_dense(&self, value: Self::Unique) -> Option { - Some(value) - } - - fn decode_dense(&self, value: usize) -> Self::Group<'_> { - ThreadPoolWorkerId(value) - } - - type Unique = usize; - - fn encode(&self, value: Self::Group<'_>) -> Option { - Some(value.0) - } - - fn decode(&self, value: &Self::Unique) -> Self::Group<'_> { - ThreadPoolWorkerId(*value) - } -} - -impl LabelSet for ThreadPoolWorkers { - type Value<'a> = ThreadPoolWorkerId; - - fn dynamic_cardinality(&self) -> Option { - Some(self.0) - } - - fn encode(&self, value: Self::Value<'_>) -> Option { - (value.0 < self.0).then_some(value.0) - } - - fn decode(&self, value: usize) -> Self::Value<'_> { - ThreadPoolWorkerId(value) - } -} - -impl FixedCardinalitySet for ThreadPoolWorkers { - fn cardinality(&self) -> usize { - self.0 - } -} - -#[derive(MetricGroup)] -#[metric(new(workers: usize))] -pub struct ThreadPoolMetrics { - pub injector_queue_depth: Gauge, - #[metric(init = GaugeVec::with_label_set(ThreadPoolWorkers(workers)))] - pub worker_queue_depth: GaugeVec, - #[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))] - pub worker_task_turns_total: CounterVec, - #[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))] - pub worker_task_skips_total: CounterVec, -} diff --git a/proxy/core/src/proxy/tests.rs b/proxy/core/src/proxy/tests.rs index d8308c4f2a..5c84f78f7a 100644 --- a/proxy/core/src/proxy/tests.rs +++ b/proxy/core/src/proxy/tests.rs @@ -16,9 +16,10 @@ use crate::console::messages::{ConsoleError, Details, MetricsAuxInfo, Status}; use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend}; use crate::console::{self, CachedNodeInfo, NodeInfo}; use crate::error::ErrorKind; -use crate::{http, sasl, scram, BranchId, EndpointId, ProjectId}; +use crate::{http, BranchId, EndpointId, ProjectId}; use anyhow::{bail, Context}; use async_trait::async_trait; +use proxy_sasl::{sasl, scram}; use retry::{retry_after, ShouldRetryWakeCompute}; use rstest::rstest; use rustls::pki_types; @@ -137,7 +138,7 @@ struct Scram(scram::ServerSecret); impl Scram { async fn new(password: &str) -> anyhow::Result { - let secret = scram::ServerSecret::build(password) + let secret = scram::ServerSecret::build_test_secret(password) .await .context("failed to generate scram secret")?; Ok(Scram(secret)) diff --git a/proxy/core/src/serverless/backend.rs b/proxy/core/src/serverless/backend.rs index 295ea1a1c7..c33e22a595 100644 --- a/proxy/core/src/serverless/backend.rs +++ b/proxy/core/src/serverless/backend.rs @@ -79,11 +79,11 @@ impl PoolingBackend { ) .await?; let res = match auth_outcome { - crate::sasl::Outcome::Success(key) => { + proxy_sasl::sasl::Outcome::Success(key) => { info!("user successfully authenticated"); Ok(key) } - crate::sasl::Outcome::Failure(reason) => { + proxy_sasl::sasl::Outcome::Failure(reason) => { info!("auth backend failed with an error: {reason}"); Err(AuthError::auth_failed(&*conn_info.user_info.user)) } diff --git a/proxy/core/src/stream.rs b/proxy/core/src/stream.rs index 690e92ffb1..f235207fee 100644 --- a/proxy/core/src/stream.rs +++ b/proxy/core/src/stream.rs @@ -1,10 +1,10 @@ -use crate::config::TlsServerEndPoint; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::metrics::Metrics; use bytes::BytesMut; use pq_proto::framed::{ConnectionError, Framed}; use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError}; +use proxy_sasl::scram::TlsServerEndPoint; use rustls::ServerConfig; use std::pin::Pin; use std::sync::Arc; diff --git a/proxy/sasl/Cargo.toml b/proxy/sasl/Cargo.toml new file mode 100644 index 0000000000..b27ba74c7a --- /dev/null +++ b/proxy/sasl/Cargo.toml @@ -0,0 +1,128 @@ +[package] +name = "proxy-sasl" +version = "0.1.0" +edition.workspace = true +license.workspace = true + +[features] +default = [] +testing = [] + +[dependencies] +ahash.workspace = true +anyhow.workspace = true +# arc-swap.workspace = true +# async-compression.workspace = true +# async-trait.workspace = true +# atomic-take.workspace = true +# aws-config.workspace = true +# aws-sdk-iam.workspace = true +# aws-sigv4.workspace = true +# aws-types.workspace = true +base64.workspace = true +# bstr.workspace = true +bytes = { workspace = true, features = ["serde"] } +# camino.workspace = true +# chrono.workspace = true +# clap.workspace = true +# consumption_metrics.workspace = true +crossbeam-deque.workspace = true +# dashmap.workspace = true +# env_logger.workspace = true +# framed-websockets.workspace = true +# futures.workspace = true +# git-version.workspace = true +# hashbrown.workspace = true +# hashlink.workspace = true +# hex.workspace = true +hmac.workspace = true +# hostname.workspace = true +# http.workspace = true +# humantime.workspace = true +# humantime-serde.workspace = true +# hyper.workspace = true +# hyper1 = { package = "hyper", version = "1.2", features = ["server"] } +# hyper-util = { version = "0.1", features = ["server", "http1", "http2", "tokio"] } +# http-body-util = { version = "0.1" } +# indexmap.workspace = true +# ipnet.workspace = true +itertools.workspace = true +lasso = { workspace = true, features = ["multi-threaded"] } +# md5.workspace = true +measured = { workspace = true, features = ["lasso"] } +# metrics.workspace = true +# once_cell.workspace = true +# opentelemetry.workspace = true +parking_lot.workspace = true +# parquet.workspace = true +# parquet_derive.workspace = true +# pin-project-lite.workspace = true +# postgres_backend.workspace = true +pq_proto.workspace = true +# prometheus.workspace = true +rand.workspace = true +# regex.workspace = true +# remote_storage = { version = "0.1", path = "../../libs/remote_storage/" } +# reqwest.workspace = true +# reqwest-middleware = { workspace = true, features = ["json"] } +# reqwest-retry.workspace = true +# reqwest-tracing.workspace = true +# routerify.workspace = true +# rustc-hash.workspace = true +# rustls-pemfile.workspace = true +rustls.workspace = true +# scopeguard.workspace = true +# serde.workspace = true +# serde_json.workspace = true +sha2 = { workspace = true, features = ["asm", "oid"] } +# smol_str.workspace = true +# smallvec.workspace = true +# socket2.workspace = true +subtle.workspace = true +# task-local-extensions.workspace = true +thiserror.workspace = true +# tikv-jemallocator.workspace = true +# tikv-jemalloc-ctl = { workspace = true, features = ["use_std"] } +# tokio-postgres.workspace = true +# tokio-postgres-rustls.workspace = true +# tokio-rustls.workspace = true +# tokio-util.workspace = true +tokio = { workspace = true, features = ["signal"] } +# tower-service.workspace = true +# tracing-opentelemetry.workspace = true +# tracing-subscriber.workspace = true +# tracing-utils.workspace = true +tracing.workspace = true +# try-lock.workspace = true +# typed-json.workspace = true +# url.workspace = true +# urlencoding.workspace = true +# utils.workspace = true +# uuid.workspace = true +# rustls-native-certs.workspace = true +x509-parser.workspace = true +postgres-protocol.workspace = true +# redis.workspace = true + +# # jwt stuff +# jose-jwa = "0.1.2" +# jose-jwk = { version = "0.1.2", features = ["p256", "p384", "rsa"] } +# signature = "2" +# ecdsa = "0.16" +# p256 = "0.13" +# rsa = "0.9" + +workspace_hack.workspace = true + +[dev-dependencies] +# camino-tempfile.workspace = true +# fallible-iterator.workspace = true +# tokio-tungstenite.workspace = true +pbkdf2 = { workspace = true, features = ["simple", "std"] } +# rcgen.workspace = true +# rstest.workspace = true +# tokio-postgres-rustls.workspace = true +# walkdir.workspace = true +# rand_distr = "0.4" + +uuid.workspace = true diff --git a/proxy/sasl/src/lib.rs b/proxy/sasl/src/lib.rs new file mode 100644 index 0000000000..9eab9b3e0f --- /dev/null +++ b/proxy/sasl/src/lib.rs @@ -0,0 +1,3 @@ +mod parse; +pub mod sasl; +pub mod scram; diff --git a/proxy/sasl/src/parse.rs b/proxy/sasl/src/parse.rs new file mode 100644 index 0000000000..0d03574901 --- /dev/null +++ b/proxy/sasl/src/parse.rs @@ -0,0 +1,43 @@ +//! Small parsing helpers. + +use std::ffi::CStr; + +pub fn split_cstr(bytes: &[u8]) -> Option<(&CStr, &[u8])> { + let cstr = CStr::from_bytes_until_nul(bytes).ok()?; + let (_, other) = bytes.split_at(cstr.to_bytes_with_nul().len()); + Some((cstr, other)) +} + +/// See . +pub fn split_at_const(bytes: &[u8]) -> Option<(&[u8; N], &[u8])> { + (bytes.len() >= N).then(|| { + let (head, tail) = bytes.split_at(N); + (head.try_into().unwrap(), tail) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_split_cstr() { + assert!(split_cstr(b"").is_none()); + assert!(split_cstr(b"foo").is_none()); + + let (cstr, rest) = split_cstr(b"\0").expect("uh-oh"); + assert_eq!(cstr.to_bytes(), b""); + assert_eq!(rest, b""); + + let (cstr, rest) = split_cstr(b"foo\0bar").expect("uh-oh"); + assert_eq!(cstr.to_bytes(), b"foo"); + assert_eq!(rest, b"bar"); + } + + #[test] + fn test_split_at_const() { + assert!(split_at_const::<0>(b"").is_some()); + assert!(split_at_const::<1>(b"").is_none()); + assert!(matches!(split_at_const::<1>(b"ok"), Some((b"o", b"k")))); + } +} diff --git a/proxy/core/src/sasl.rs b/proxy/sasl/src/sasl.rs similarity index 65% rename from proxy/core/src/sasl.rs rename to proxy/sasl/src/sasl.rs index 0811416ca2..9939a9d71d 100644 --- a/proxy/core/src/sasl.rs +++ b/proxy/sasl/src/sasl.rs @@ -10,7 +10,7 @@ mod channel_binding; mod messages; mod stream; -use crate::error::{ReportableError, UserFacingError}; +// use crate::error::{ReportableError, UserFacingError}; use std::io; use thiserror::Error; @@ -40,29 +40,29 @@ pub enum Error { Io(#[from] io::Error), } -impl UserFacingError for Error { - fn to_string_client(&self) -> String { - use Error::*; - match self { - ChannelBindingFailed(m) => m.to_string(), - ChannelBindingBadMethod(m) => format!("unsupported channel binding method {m}"), - _ => "authentication protocol violation".to_string(), - } - } -} +// impl UserFacingError for Error { +// fn to_string_client(&self) -> String { +// use Error::*; +// match self { +// ChannelBindingFailed(m) => m.to_string(), +// ChannelBindingBadMethod(m) => format!("unsupported channel binding method {m}"), +// _ => "authentication protocol violation".to_string(), +// } +// } +// } -impl ReportableError for Error { - fn get_error_kind(&self) -> crate::error::ErrorKind { - match self { - Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User, - Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User, - Error::BadClientMessage(_) => crate::error::ErrorKind::User, - Error::MissingBinding => crate::error::ErrorKind::Service, - Error::Base64(_) => crate::error::ErrorKind::ControlPlane, - Error::Io(_) => crate::error::ErrorKind::ClientDisconnect, - } - } -} +// impl ReportableError for Error { +// fn get_error_kind(&self) -> crate::error::ErrorKind { +// match self { +// Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User, +// Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User, +// Error::BadClientMessage(_) => crate::error::ErrorKind::User, +// Error::MissingBinding => crate::error::ErrorKind::Service, +// Error::Base64(_) => crate::error::ErrorKind::ControlPlane, +// Error::Io(_) => crate::error::ErrorKind::ClientDisconnect, +// } +// } +// } /// A convenient result type for SASL exchange. pub type Result = std::result::Result; diff --git a/proxy/core/src/sasl/channel_binding.rs b/proxy/sasl/src/sasl/channel_binding.rs similarity index 100% rename from proxy/core/src/sasl/channel_binding.rs rename to proxy/sasl/src/sasl/channel_binding.rs diff --git a/proxy/core/src/sasl/messages.rs b/proxy/sasl/src/sasl/messages.rs similarity index 100% rename from proxy/core/src/sasl/messages.rs rename to proxy/sasl/src/sasl/messages.rs diff --git a/proxy/core/src/sasl/stream.rs b/proxy/sasl/src/sasl/stream.rs similarity index 72% rename from proxy/core/src/sasl/stream.rs rename to proxy/sasl/src/sasl/stream.rs index 9115b0f61a..e98932cfcb 100644 --- a/proxy/core/src/sasl/stream.rs +++ b/proxy/sasl/src/sasl/stream.rs @@ -1,7 +1,10 @@ //! Abstraction for the string-oriented SASL protocols. use super::{messages::ServerMessage, Mechanism}; -use crate::stream::PqStream; +use pq_proto::{ + framed::{ConnectionError, Framed}, + FeMessage, ProtocolError, +}; use std::io; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; @@ -9,7 +12,7 @@ use tracing::info; /// Abstracts away all peculiarities of the libpq's protocol. pub struct SaslStream<'a, S> { /// The underlying stream. - stream: &'a mut PqStream, + stream: &'a mut Framed, /// Current password message we received from client. current: bytes::Bytes, /// First SASL message produced by client. @@ -17,7 +20,7 @@ pub struct SaslStream<'a, S> { } impl<'a, S> SaslStream<'a, S> { - pub fn new(stream: &'a mut PqStream, first: &'a str) -> Self { + pub fn new(stream: &'a mut Framed, first: &'a str) -> Self { Self { stream, current: bytes::Bytes::new(), @@ -26,6 +29,27 @@ impl<'a, S> SaslStream<'a, S> { } } +fn err_connection() -> io::Error { + io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost") +} + +pub async fn read_password_message( + framed: &mut Framed, +) -> io::Result { + let msg = framed + .read_message() + .await + .map_err(ConnectionError::into_io_error)? + .ok_or_else(err_connection)?; + match msg { + FeMessage::PasswordMessage(msg) => Ok(msg), + bad => Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("unexpected message type: {:?}", bad), + )), + } +} + impl SaslStream<'_, S> { // Receive a new SASL message from the client. async fn recv(&mut self) -> io::Result<&str> { @@ -33,7 +57,7 @@ impl SaslStream<'_, S> { return Ok(first); } - self.current = self.stream.read_password_message().await?; + self.current = read_password_message(self.stream).await?; let s = std::str::from_utf8(&self.current) .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?; @@ -44,7 +68,10 @@ impl SaslStream<'_, S> { impl SaslStream<'_, S> { // Send a SASL message to the client. async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> { - self.stream.write_message(&msg.to_reply()).await?; + self.stream + .write_message(&msg.to_reply()) + .map_err(ProtocolError::into_io_error)?; + self.stream.flush().await?; Ok(()) } } diff --git a/proxy/core/src/scram.rs b/proxy/sasl/src/scram.rs similarity index 61% rename from proxy/core/src/scram.rs rename to proxy/sasl/src/scram.rs index 862facb4e5..5cfa0e5984 100644 --- a/proxy/core/src/scram.rs +++ b/proxy/sasl/src/scram.rs @@ -15,12 +15,16 @@ mod secret; mod signature; pub mod threadpool; +use anyhow::Context; pub use exchange::{exchange, Exchange}; pub use key::ScramKey; +use rustls::pki_types::CertificateDer; pub use secret::ServerSecret; use hmac::{Hmac, Mac}; use sha2::{Digest, Sha256}; +use tracing::{error, info}; +use x509_parser::oid_registry; const SCRAM_SHA_256: &str = "SCRAM-SHA-256"; const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS"; @@ -57,12 +61,71 @@ fn sha256<'a>(parts: impl IntoIterator) -> [u8; 32] { hasher.finalize().into() } +/// Channel binding parameter +/// +/// +/// Description: The hash of the TLS server's certificate as it +/// appears, octet for octet, in the server's Certificate message. Note +/// that the Certificate message contains a certificate_list, in which +/// the first element is the server's certificate. +/// +/// The hash function is to be selected as follows: +/// +/// * if the certificate's signatureAlgorithm uses a single hash +/// function, and that hash function is either MD5 or SHA-1, then use SHA-256; +/// +/// * if the certificate's signatureAlgorithm uses a single hash +/// function and that hash function neither MD5 nor SHA-1, then use +/// the hash function associated with the certificate's +/// signatureAlgorithm; +/// +/// * if the certificate's signatureAlgorithm uses no hash functions or +/// uses multiple hash functions, then this channel binding type's +/// channel bindings are undefined at this time (updates to is channel +/// binding type may occur to address this issue if it ever arises). +#[derive(Debug, Clone, Copy)] +pub enum TlsServerEndPoint { + Sha256([u8; 32]), + Undefined, +} + +impl TlsServerEndPoint { + pub fn new(cert: &CertificateDer) -> anyhow::Result { + let sha256_oids = [ + // I'm explicitly not adding MD5 or SHA1 here... They're bad. + oid_registry::OID_SIG_ECDSA_WITH_SHA256, + oid_registry::OID_PKCS1_SHA256WITHRSA, + ]; + + let pem = x509_parser::parse_x509_certificate(cert) + .context("Failed to parse PEM object from cerficiate")? + .1; + + info!(subject = %pem.subject, "parsing TLS certificate"); + + let reg = oid_registry::OidRegistry::default().with_all_crypto(); + let oid = pem.signature_algorithm.oid(); + let alg = reg.get(oid); + if sha256_oids.contains(oid) { + let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into(); + info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding"); + Ok(Self::Sha256(tls_server_end_point)) + } else { + error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding"); + Ok(Self::Undefined) + } + } + + pub fn supported(&self) -> bool { + !matches!(self, TlsServerEndPoint::Undefined) + } +} + #[cfg(test)] mod tests { use crate::{ - intern::EndpointIdInt, sasl::{Mechanism, Step}, - EndpointId, + scram::TlsServerEndPoint, }; use super::{threadpool::ThreadPool, Exchange, ServerSecret}; @@ -79,11 +142,7 @@ mod tests { const NONCE: [u8; 18] = [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, ]; - let mut exchange = Exchange::new( - &secret, - || NONCE, - crate::config::TlsServerEndPoint::Undefined, - ); + let mut exchange = Exchange::new(&secret, || NONCE, TlsServerEndPoint::Undefined); let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO"; let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0="; @@ -120,11 +179,11 @@ mod tests { async fn run_round_trip_test(server_password: &str, client_password: &str) { let pool = ThreadPool::new(1); + let ep = "foo"; - let ep = EndpointId::from("foo"); - let ep = EndpointIdInt::from(ep); - - let scram_secret = ServerSecret::build(server_password).await.unwrap(); + let scram_secret = ServerSecret::build_test_secret(server_password) + .await + .unwrap(); let outcome = super::exchange(&pool, ep, &scram_secret, client_password.as_bytes()) .await .unwrap(); diff --git a/proxy/core/src/scram/countmin.rs b/proxy/sasl/src/scram/countmin.rs similarity index 100% rename from proxy/core/src/scram/countmin.rs rename to proxy/sasl/src/scram/countmin.rs diff --git a/proxy/core/src/scram/exchange.rs b/proxy/sasl/src/scram/exchange.rs similarity index 89% rename from proxy/core/src/scram/exchange.rs rename to proxy/sasl/src/scram/exchange.rs index d0adbc780e..06ae61f5b2 100644 --- a/proxy/core/src/scram/exchange.rs +++ b/proxy/sasl/src/scram/exchange.rs @@ -1,6 +1,7 @@ //! Implementation of the SCRAM authentication algorithm. use std::convert::Infallible; +use std::hash::Hash; use hmac::{Hmac, Mac}; use sha2::Sha256; @@ -13,8 +14,6 @@ use super::secret::ServerSecret; use super::signature::SignatureBuilder; use super::threadpool::ThreadPool; use super::ScramKey; -use crate::config; -use crate::intern::EndpointIdInt; use crate::sasl::{self, ChannelBinding, Error as SaslError}; /// The only channel binding mode we currently support. @@ -59,14 +58,14 @@ enum ExchangeState { pub struct Exchange<'a> { state: ExchangeState, secret: &'a ServerSecret, - tls_server_end_point: config::TlsServerEndPoint, + tls_server_end_point: super::TlsServerEndPoint, } impl<'a> Exchange<'a> { pub fn new( secret: &'a ServerSecret, nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN], - tls_server_end_point: config::TlsServerEndPoint, + tls_server_end_point: super::TlsServerEndPoint, ) -> Self { Self { state: ExchangeState::Initial(SaslInitial { nonce }), @@ -77,15 +76,15 @@ impl<'a> Exchange<'a> { } // copied from -async fn derive_client_key( - pool: &ThreadPool, - endpoint: EndpointIdInt, +async fn derive_client_key( + pool: &ThreadPool, + concurrency_key: K, password: &[u8], salt: &[u8], iterations: u32, ) -> ScramKey { let salted_password = pool - .spawn_job(endpoint, Pbkdf2::start(password, salt, iterations)) + .spawn_job(concurrency_key, Pbkdf2::start(password, salt, iterations)) .await .expect("job should not be cancelled"); @@ -101,14 +100,15 @@ async fn derive_client_key( make_key(b"Client Key").into() } -pub async fn exchange( - pool: &ThreadPool, - endpoint: EndpointIdInt, +pub async fn exchange( + pool: &ThreadPool, + concurrency_key: K, secret: &ServerSecret, password: &[u8], ) -> sasl::Result> { let salt = base64::decode(&secret.salt_base64)?; - let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await; + let client_key = + derive_client_key(pool, concurrency_key, password, &salt, secret.iterations).await; if secret.is_password_invalid(&client_key).into() { Ok(sasl::Outcome::Failure("password doesn't match")) @@ -121,7 +121,7 @@ impl SaslInitial { fn transition( &self, secret: &ServerSecret, - tls_server_end_point: &config::TlsServerEndPoint, + tls_server_end_point: &super::TlsServerEndPoint, input: &str, ) -> sasl::Result> { let client_first_message = ClientFirstMessage::parse(input) @@ -156,7 +156,7 @@ impl SaslSentInner { fn transition( &self, secret: &ServerSecret, - tls_server_end_point: &config::TlsServerEndPoint, + tls_server_end_point: &super::TlsServerEndPoint, input: &str, ) -> sasl::Result> { let Self { @@ -169,8 +169,8 @@ impl SaslSentInner { .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?; let channel_binding = cbind_flag.encode(|_| match tls_server_end_point { - config::TlsServerEndPoint::Sha256(x) => Ok(x), - config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding), + super::TlsServerEndPoint::Sha256(x) => Ok(x), + super::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding), })?; // This might've been caused by a MITM attack diff --git a/proxy/core/src/scram/key.rs b/proxy/sasl/src/scram/key.rs similarity index 100% rename from proxy/core/src/scram/key.rs rename to proxy/sasl/src/scram/key.rs diff --git a/proxy/core/src/scram/messages.rs b/proxy/sasl/src/scram/messages.rs similarity index 100% rename from proxy/core/src/scram/messages.rs rename to proxy/sasl/src/scram/messages.rs diff --git a/proxy/core/src/scram/pbkdf2.rs b/proxy/sasl/src/scram/pbkdf2.rs similarity index 100% rename from proxy/core/src/scram/pbkdf2.rs rename to proxy/sasl/src/scram/pbkdf2.rs diff --git a/proxy/core/src/scram/secret.rs b/proxy/sasl/src/scram/secret.rs similarity index 96% rename from proxy/core/src/scram/secret.rs rename to proxy/sasl/src/scram/secret.rs index 44c4f9e44a..1f5b184b33 100644 --- a/proxy/core/src/scram/secret.rs +++ b/proxy/sasl/src/scram/secret.rs @@ -64,9 +64,7 @@ impl ServerSecret { } /// Build a new server secret from the prerequisites. - /// XXX: We only use this function in tests. - #[cfg(test)] - pub async fn build(password: &str) -> Option { + pub async fn build_test_secret(password: &str) -> Option { Self::parse(&postgres_protocol::password::scram_sha_256(password.as_bytes()).await) } } diff --git a/proxy/core/src/scram/signature.rs b/proxy/sasl/src/scram/signature.rs similarity index 100% rename from proxy/core/src/scram/signature.rs rename to proxy/sasl/src/scram/signature.rs diff --git a/proxy/core/src/scram/threadpool.rs b/proxy/sasl/src/scram/threadpool.rs similarity index 76% rename from proxy/core/src/scram/threadpool.rs rename to proxy/sasl/src/scram/threadpool.rs index 7701b869a3..3a6c689378 100644 --- a/proxy/core/src/scram/threadpool.rs +++ b/proxy/sasl/src/scram/threadpool.rs @@ -4,29 +4,36 @@ //! 1. Fairness per endpoint. //! 2. Yield support for high iteration counts. -use std::sync::{ - atomic::{AtomicU64, Ordering}, - Arc, +use std::{ + hash::Hash, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, }; use crossbeam_deque::{Injector, Stealer, Worker}; use itertools::Itertools; +use measured::{ + label::{FixedCardinalitySet, LabelGroupSet, LabelName, LabelSet, LabelValue}, + CounterVec, Gauge, GaugeVec, LabelGroup, MetricGroup, +}; use parking_lot::{Condvar, Mutex}; use rand::Rng; use rand::{rngs::SmallRng, SeedableRng}; use tokio::sync::oneshot; use crate::{ - intern::EndpointIdInt, - metrics::{ThreadPoolMetrics, ThreadPoolWorkerId}, + // intern::EndpointIdInt, + // metrics::{ThreadPoolMetrics, ThreadPoolWorkerId}, scram::countmin::CountMinSketch, }; use super::pbkdf2::Pbkdf2; -pub struct ThreadPool { - queue: Injector, - stealers: Vec>, +pub struct ThreadPool { + queue: Injector>, + stealers: Vec>>, parkers: Vec<(Condvar, Mutex)>, /// bitpacked representation. /// lower 8 bits = number of sleeping threads @@ -42,7 +49,7 @@ enum ThreadState { Active, } -impl ThreadPool { +impl ThreadPool { pub fn new(n_workers: u8) -> Arc { let workers = (0..n_workers).map(|_| Worker::new_fifo()).collect_vec(); let stealers = workers.iter().map(|w| w.stealer()).collect_vec(); @@ -68,11 +75,7 @@ impl ThreadPool { pool } - pub fn spawn_job( - &self, - endpoint: EndpointIdInt, - pbkdf2: Pbkdf2, - ) -> oneshot::Receiver<[u8; 32]> { + pub fn spawn_job(&self, key: K, pbkdf2: Pbkdf2) -> oneshot::Receiver<[u8; 32]> { let (tx, rx) = oneshot::channel(); let queue_was_empty = self.queue.is_empty(); @@ -81,7 +84,7 @@ impl ThreadPool { self.queue.push(JobSpec { response: tx, pbkdf2, - endpoint, + key, }); // inspired from @@ -139,7 +142,12 @@ impl ThreadPool { } } - fn steal(&self, rng: &mut impl Rng, skip: usize, worker: &Worker) -> Option { + fn steal( + &self, + rng: &mut impl Rng, + skip: usize, + worker: &Worker>, + ) -> Option> { // announce thread as idle self.counters.fetch_add(256, Ordering::SeqCst); @@ -188,7 +196,11 @@ impl ThreadPool { } } -fn thread_rt(pool: Arc, worker: Worker, index: usize) { +fn thread_rt( + pool: Arc>, + worker: Worker>, + index: usize, +) { /// interval when we should steal from the global queue /// so that tail latencies are managed appropriately const STEAL_INTERVAL: usize = 61; @@ -236,7 +248,7 @@ fn thread_rt(pool: Arc, worker: Worker, index: usize) { // receiver is closed, cancel the task if !job.response.is_closed() { - let rate = sketch.inc_and_return(&job.endpoint, job.pbkdf2.cost()); + let rate = sketch.inc_and_return(&job.key, job.pbkdf2.cost()); const P: f64 = 2000.0; // probability decreases as rate increases. @@ -287,24 +299,96 @@ fn thread_rt(pool: Arc, worker: Worker, index: usize) { } } -struct JobSpec { +struct JobSpec { response: oneshot::Sender<[u8; 32]>, pbkdf2: Pbkdf2, - endpoint: EndpointIdInt, + key: K, +} + +pub struct ThreadPoolWorkers(usize); +pub struct ThreadPoolWorkerId(pub usize); + +impl LabelValue for ThreadPoolWorkerId { + fn visit(&self, v: V) -> V::Output { + v.write_int(self.0 as i64) + } +} + +impl LabelGroup for ThreadPoolWorkerId { + fn visit_values(&self, v: &mut impl measured::label::LabelGroupVisitor) { + v.write_value(LabelName::from_str("worker"), self); + } +} + +impl LabelGroupSet for ThreadPoolWorkers { + type Group<'a> = ThreadPoolWorkerId; + + fn cardinality(&self) -> Option { + Some(self.0) + } + + fn encode_dense(&self, value: Self::Unique) -> Option { + Some(value) + } + + fn decode_dense(&self, value: usize) -> Self::Group<'_> { + ThreadPoolWorkerId(value) + } + + type Unique = usize; + + fn encode(&self, value: Self::Group<'_>) -> Option { + Some(value.0) + } + + fn decode(&self, value: &Self::Unique) -> Self::Group<'_> { + ThreadPoolWorkerId(*value) + } +} + +impl LabelSet for ThreadPoolWorkers { + type Value<'a> = ThreadPoolWorkerId; + + fn dynamic_cardinality(&self) -> Option { + Some(self.0) + } + + fn encode(&self, value: Self::Value<'_>) -> Option { + (value.0 < self.0).then_some(value.0) + } + + fn decode(&self, value: usize) -> Self::Value<'_> { + ThreadPoolWorkerId(value) + } +} + +impl FixedCardinalitySet for ThreadPoolWorkers { + fn cardinality(&self) -> usize { + self.0 + } +} + +#[derive(MetricGroup)] +#[metric(new(workers: usize))] +pub struct ThreadPoolMetrics { + pub injector_queue_depth: Gauge, + #[metric(init = GaugeVec::with_label_set(ThreadPoolWorkers(workers)))] + pub worker_queue_depth: GaugeVec, + #[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))] + pub worker_task_turns_total: CounterVec, + #[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))] + pub worker_task_skips_total: CounterVec, } #[cfg(test)] mod tests { - use crate::EndpointId; - use super::*; #[tokio::test] async fn hash_is_correct() { let pool = ThreadPool::new(1); - let ep = EndpointId::from("foo"); - let ep = EndpointIdInt::from(ep); + let ep = "foo"; let salt = [0x55; 32]; let actual = pool