From 9cfe08e3d9f1181a163322705ff41cbcfb11db3b Mon Sep 17 00:00:00 2001
From: Conrad Ludgate
Date: Wed, 22 May 2024 18:05:43 +0100
Subject: [PATCH] proxy password threadpool (#7806)

## Problem

Despite making password hashing async, it can still take time away from the
network code.

## Summary of changes

Introduce a custom threadpool, inspired by rayon. Features:

### Fairness

Each task is tagged with its endpoint ID. The more times we have seen the
endpoint, the more likely we are to skip the task when it comes up in the
queue. This uses a count-min sketch to estimate the number of times we have
seen the endpoint, resetting the sketch every 1000+ steps.

Since tasks are immediately rescheduled if they do not complete, a worker
could get stuck in an "always work available" loop. To combat this, we check
the global queue every 61 steps to ensure all tasks quickly get a worker
assigned to them.

### Balanced

Using crossbeam_deque, like rayon does, we get work stealing out of the box.
I've tested it a fair amount and it seems to balance the workload evenly.

---
 Cargo.lock | 20 +-
 Cargo.toml | 2 +
 proxy/Cargo.toml | 4 +-
 proxy/src/auth/backend.rs | 10 +-
 proxy/src/auth/backend/hacks.rs | 11 +-
 proxy/src/auth/flow.rs | 24 ++-
 proxy/src/bin/proxy.rs | 8 +
 proxy/src/config.rs | 2 +
 proxy/src/metrics.rs | 89 +++++++--
 proxy/src/scram.rs | 18 +-
 proxy/src/scram/countmin.rs | 173 +++++++++++++++++
 proxy/src/scram/exchange.rs | 49 ++---
 proxy/src/scram/pbkdf2.rs | 89 +++++++++
 proxy/src/scram/threadpool.rs | 321 ++++++++++++++++++++++++++++++++
 proxy/src/serverless/backend.rs | 11 +-
 workspace_hack/Cargo.toml | 2 +
 16 files changed, 759 insertions(+), 74 deletions(-)
 create mode 100644 proxy/src/scram/countmin.rs
 create mode 100644 proxy/src/scram/pbkdf2.rs
 create mode 100644 proxy/src/scram/threadpool.rs

diff --git a/Cargo.lock b/Cargo.lock
index e6060c82f5..d8f9021eb8 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1471,26 +1471,21 @@ dependencies = [
 [[package]]
 name = "crossbeam-deque"
-version = "0.8.3"
+version = "0.8.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef"
+checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
 dependencies = [
- "cfg-if",
 "crossbeam-epoch",
 "crossbeam-utils",
 ]
 [[package]]
 name = "crossbeam-epoch"
-version = "0.9.14"
+version = "0.9.18"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "46bd5f3f85273295a9d14aedfb86f6aadbff6d8f5295c4a9edb08e819dcf5695"
+checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
 dependencies = [
- "autocfg",
- "cfg-if",
 "crossbeam-utils",
- "memoffset 0.8.0",
- "scopeguard",
 ]
 [[package]]
@@ -3961,9 +3956,9 @@ checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c"
 [[package]]
 name = "pbkdf2"
-version = "0.12.1"
+version = "0.12.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f0ca0b5a68607598bf3bad68f32227a8164f6254833f84eafaac409cd6746c31"
+checksum = "f8ed6a7761f76e3b9f92dfb0a60a6a6477c61024b775147ff0973a02653abaf2"
 dependencies = [
 "digest",
 "hmac",
@@ -4386,6 +4381,7 @@ dependencies = [
 name = "proxy"
 version = "0.1.0"
 dependencies = [
+ "ahash",
 "anyhow",
 "async-compression",
 "async-trait",
@@ -4402,6 +4398,7 @@ dependencies = [
 "chrono",
 "clap",
 "consumption_metrics",
+ "crossbeam-deque",
 "dashmap",
 "env_logger",
 "fallible-iterator",
@@ -7473,6 +7470,7 @@ dependencies = [
 name = "workspace_hack"
 version = "0.1.0"
dependencies = [ + "ahash", "anyhow", "aws-config", "aws-runtime", diff --git a/Cargo.toml b/Cargo.toml index 2a7dea447e..0887c039f8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,7 @@ license = "Apache-2.0" ## All dependency versions, used in the project [workspace.dependencies] +ahash = "0.8" anyhow = { version = "1.0", features = ["backtrace"] } arc-swap = "1.6" async-compression = { version = "0.4.0", features = ["tokio", "gzip", "zstd"] } @@ -74,6 +75,7 @@ clap = { version = "4.0", features = ["derive"] } comfy-table = "6.1" const_format = "0.2" crc32c = "0.6" +crossbeam-deque = "0.8.5" crossbeam-utils = "0.8.5" dashmap = { version = "5.5.0", features = ["raw-api"] } either = "1.8" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 5f9b0aa75b..7da0763bc1 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -9,6 +9,7 @@ default = [] testing = [] [dependencies] +ahash.workspace = true anyhow.workspace = true async-compression.workspace = true async-trait.workspace = true @@ -24,6 +25,7 @@ camino.workspace = true chrono.workspace = true clap.workspace = true consumption_metrics.workspace = true +crossbeam-deque.workspace = true dashmap.workspace = true env_logger.workspace = true framed-websockets.workspace = true @@ -52,7 +54,6 @@ opentelemetry.workspace = true parking_lot.workspace = true parquet.workspace = true parquet_derive.workspace = true -pbkdf2 = { workspace = true, features = ["simple", "std"] } pin-project-lite.workspace = true postgres_backend.workspace = true pq_proto.workspace = true @@ -106,6 +107,7 @@ workspace_hack.workspace = true camino-tempfile.workspace = true fallible-iterator.workspace = true tokio-tungstenite.workspace = true +pbkdf2 = { workspace = true, features = ["simple", "std"] } rcgen.workspace = true rstest.workspace = true tokio-postgres-rustls.workspace = true diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 6a906b299b..3555eba543 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -365,7 +365,10 @@ async fn authenticate_with_secret( config: &'static AuthenticationConfig, ) -> auth::Result { if let Some(password) = unauthenticated_password { - let auth_outcome = validate_password_and_exchange(&password, secret).await?; + let ep = EndpointIdInt::from(&info.endpoint); + + let auth_outcome = + validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?; let keys = match auth_outcome { crate::sasl::Outcome::Success(key) => key, crate::sasl::Outcome::Failure(reason) => { @@ -386,7 +389,7 @@ async fn authenticate_with_secret( // Currently, we use it for websocket connections (latency). if allow_cleartext { ctx.set_auth_method(crate::context::AuthMethod::Cleartext); - return hacks::authenticate_cleartext(ctx, info, client, secret).await; + return hacks::authenticate_cleartext(ctx, info, client, secret, config).await; } // Finally, proceed with the main auth flow (SCRAM-based). 
@@ -554,7 +557,7 @@ mod tests { context::RequestMonitoring, proxy::NeonOptions, rate_limiter::{EndpointRateLimiter, RateBucketInfo}, - scram::ServerSecret, + scram::{threadpool::ThreadPool, ServerSecret}, stream::{PqStream, Stream}, }; @@ -596,6 +599,7 @@ mod tests { } static CONFIG: Lazy = Lazy::new(|| AuthenticationConfig { + thread_pool: ThreadPool::new(1), scram_protocol_timeout: std::time::Duration::from_secs(5), rate_limiter_enabled: true, rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET), diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index f7241be4a9..6b0f5e1726 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -3,8 +3,10 @@ use super::{ }; use crate::{ auth::{self, AuthFlow}, + config::AuthenticationConfig, console::AuthSecret, context::RequestMonitoring, + intern::EndpointIdInt, sasl, stream::{self, Stream}, }; @@ -20,6 +22,7 @@ pub async fn authenticate_cleartext( info: ComputeUserInfo, client: &mut stream::PqStream>, secret: AuthSecret, + config: &'static AuthenticationConfig, ) -> auth::Result { warn!("cleartext auth flow override is enabled, proceeding"); ctx.set_auth_method(crate::context::AuthMethod::Cleartext); @@ -27,8 +30,14 @@ pub async fn authenticate_cleartext( // pause the timer while we communicate with the client let paused = ctx.latency_timer.pause(crate::metrics::Waiting::Client); + let ep = EndpointIdInt::from(&info.endpoint); + let auth_flow = AuthFlow::new(client) - .begin(auth::CleartextPassword(secret)) + .begin(auth::CleartextPassword { + secret, + endpoint: ep, + pool: config.thread_pool.clone(), + }) .await?; drop(paused); // cleartext auth is only allowed to the ws/http protocol. diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 45bbad8cb2..59d1ac17f4 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -5,12 +5,14 @@ use crate::{ config::TlsServerEndPoint, console::AuthSecret, context::RequestMonitoring, - sasl, scram, + intern::EndpointIdInt, + sasl, + scram::{self, threadpool::ThreadPool}, stream::{PqStream, Stream}, }; use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS}; use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be}; -use std::io; +use std::{io, sync::Arc}; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; @@ -53,7 +55,11 @@ impl AuthMethod for PasswordHack { /// Use clear-text password auth called `password` in docs /// -pub struct CleartextPassword(pub AuthSecret); +pub struct CleartextPassword { + pub pool: Arc, + pub endpoint: EndpointIdInt, + pub secret: AuthSecret, +} impl AuthMethod for CleartextPassword { #[inline(always)] @@ -126,7 +132,13 @@ impl AuthFlow<'_, S, CleartextPassword> { .strip_suffix(&[0]) .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?; - let outcome = validate_password_and_exchange(password, self.state.0).await?; + let outcome = validate_password_and_exchange( + &self.state.pool, + self.state.endpoint, + password, + self.state.secret, + ) + .await?; if let sasl::Outcome::Success(_) = &outcome { self.stream.write_message_noflush(&Be::AuthenticationOk)?; @@ -181,6 +193,8 @@ impl AuthFlow<'_, S, Scram<'_>> { } pub(crate) async fn validate_password_and_exchange( + pool: &ThreadPool, + endpoint: EndpointIdInt, password: &[u8], secret: AuthSecret, ) -> super::Result> { @@ -194,7 +208,7 @@ pub(crate) async fn validate_password_and_exchange( } // perform scram authentication as both client and server to validate the keys 
AuthSecret::Scram(scram_secret) => { - let outcome = crate::scram::exchange(&scram_secret, password).await?; + let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?; let client_key = match outcome { sasl::Outcome::Success(client_key) => client_key, diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index be7d961b8c..30f2e6f4b7 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -27,6 +27,7 @@ use proxy::redis::cancellation_publisher::RedisPublisherClient; use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use proxy::redis::elasticache; use proxy::redis::notifications; +use proxy::scram::threadpool::ThreadPool; use proxy::serverless::cancel_set::CancelSet; use proxy::serverless::GlobalConnPoolOptions; use proxy::usage_metrics; @@ -132,6 +133,9 @@ struct ProxyCliArgs { /// timeout for scram authentication protocol #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)] scram_protocol_timeout: tokio::time::Duration, + /// size of the threadpool for password hashing + #[clap(long, default_value_t = 4)] + scram_thread_pool_size: u8, /// Require that all incoming requests have a Proxy Protocol V2 packet **and** have an IP address associated. #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] require_client_ip: bool, @@ -489,6 +493,9 @@ async fn main() -> anyhow::Result<()> { /// ProxyConfig is created at proxy startup, and lives forever. fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { + let thread_pool = ThreadPool::new(args.scram_thread_pool_size); + Metrics::install(thread_pool.metrics.clone()); + let tls_config = match (&args.tls_key, &args.tls_cert) { (Some(key_path), Some(cert_path)) => Some(config::configure_tls( key_path, @@ -624,6 +631,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold, }; let authentication_config = AuthenticationConfig { + thread_pool, scram_protocol_timeout: args.scram_protocol_timeout, rate_limiter_enabled: args.auth_rate_limit_enabled, rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()), diff --git a/proxy/src/config.rs b/proxy/src/config.rs index b7ab2c00f9..5a0c251ce2 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -2,6 +2,7 @@ use crate::{ auth::{self, backend::AuthRateLimiter}, console::locks::ApiLocks, rate_limiter::RateBucketInfo, + scram::threadpool::ThreadPool, serverless::{cancel_set::CancelSet, GlobalConnPoolOptions}, Host, }; @@ -61,6 +62,7 @@ pub struct HttpConfig { } pub struct AuthenticationConfig { + pub thread_pool: Arc, pub scram_protocol_timeout: tokio::time::Duration, pub rate_limiter_enabled: bool, pub rate_limiter: AuthRateLimiter, diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 1590316925..e2a75a8720 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -1,11 +1,11 @@ -use std::sync::OnceLock; +use std::sync::{Arc, OnceLock}; use lasso::ThreadedRodeo; use measured::{ - label::StaticLabelSet, + label::{FixedCardinalitySet, LabelName, LabelSet, LabelValue, StaticLabelSet}, metric::{histogram::Thresholds, name::MetricName}, - Counter, CounterVec, FixedCardinalityLabel, Gauge, Histogram, HistogramVec, LabelGroup, - MetricGroup, + Counter, CounterVec, FixedCardinalityLabel, Gauge, GaugeVec, Histogram, HistogramVec, + LabelGroup, MetricGroup, }; use 
metrics::{CounterPairAssoc, CounterPairVec, HyperLogLog, HyperLogLogVec}; @@ -14,26 +14,36 @@ use tokio::time::{self, Instant}; use crate::console::messages::ColdStartInfo; #[derive(MetricGroup)] +#[metric(new(thread_pool: Arc))] pub struct Metrics { #[metric(namespace = "proxy")] + #[metric(init = ProxyMetrics::new(thread_pool))] pub proxy: ProxyMetrics, #[metric(namespace = "wake_compute_lock")] pub wake_compute_lock: ApiLockMetrics, } +static SELF: OnceLock = OnceLock::new(); impl Metrics { + pub fn install(thread_pool: Arc) { + SELF.set(Metrics::new(thread_pool)) + .ok() + .expect("proxy metrics must not be installed more than once"); + } + pub fn get() -> &'static Self { - static SELF: OnceLock = OnceLock::new(); - SELF.get_or_init(|| Metrics { - proxy: ProxyMetrics::default(), - wake_compute_lock: ApiLockMetrics::new(), - }) + #[cfg(test)] + return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0)))); + + #[cfg(not(test))] + SELF.get() + .expect("proxy metrics must be installed by the main() function") } } #[derive(MetricGroup)] -#[metric(new())] +#[metric(new(thread_pool: Arc))] pub struct ProxyMetrics { #[metric(flatten)] pub db_connections: CounterPairVec, @@ -129,6 +139,10 @@ pub struct ProxyMetrics { #[metric(namespace = "connect_compute_lock")] pub connect_compute_lock: ApiLockMetrics, + + #[metric(namespace = "scram_pool")] + #[metric(init = thread_pool)] + pub scram_pool: Arc, } #[derive(MetricGroup)] @@ -146,12 +160,6 @@ pub struct ApiLockMetrics { pub semaphore_acquire_seconds: Histogram<16>, } -impl Default for ProxyMetrics { - fn default() -> Self { - Self::new() - } -} - impl Default for ApiLockMetrics { fn default() -> Self { Self::new() @@ -553,3 +561,52 @@ pub enum RedisEventsCount { PasswordUpdate, AllowedIpsUpdate, } + +pub struct ThreadPoolWorkers(usize); +pub struct ThreadPoolWorkerId(pub usize); + +impl LabelValue for ThreadPoolWorkerId { + fn visit(&self, v: V) -> V::Output { + v.write_int(self.0 as i64) + } +} + +impl LabelGroup for ThreadPoolWorkerId { + fn visit_values(&self, v: &mut impl measured::label::LabelGroupVisitor) { + v.write_value(LabelName::from_str("worker"), self); + } +} + +impl LabelSet for ThreadPoolWorkers { + type Value<'a> = ThreadPoolWorkerId; + + fn dynamic_cardinality(&self) -> Option { + Some(self.0) + } + + fn encode(&self, value: Self::Value<'_>) -> Option { + (value.0 < self.0).then_some(value.0) + } + + fn decode(&self, value: usize) -> Self::Value<'_> { + ThreadPoolWorkerId(value) + } +} + +impl FixedCardinalitySet for ThreadPoolWorkers { + fn cardinality(&self) -> usize { + self.0 + } +} + +#[derive(MetricGroup)] +#[metric(new(workers: usize))] +pub struct ThreadPoolMetrics { + pub injector_queue_depth: Gauge, + #[metric(init = GaugeVec::with_label_set(ThreadPoolWorkers(workers)))] + pub worker_queue_depth: GaugeVec, + #[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))] + pub worker_task_turns_total: CounterVec, + #[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))] + pub worker_task_skips_total: CounterVec, +} diff --git a/proxy/src/scram.rs b/proxy/src/scram.rs index ed80675f8a..862facb4e5 100644 --- a/proxy/src/scram.rs +++ b/proxy/src/scram.rs @@ -6,11 +6,14 @@ //! * //! 
* +mod countmin; mod exchange; mod key; mod messages; +mod pbkdf2; mod secret; mod signature; +pub mod threadpool; pub use exchange::{exchange, Exchange}; pub use key::ScramKey; @@ -56,9 +59,13 @@ fn sha256<'a>(parts: impl IntoIterator) -> [u8; 32] { #[cfg(test)] mod tests { - use crate::sasl::{Mechanism, Step}; + use crate::{ + intern::EndpointIdInt, + sasl::{Mechanism, Step}, + EndpointId, + }; - use super::{Exchange, ServerSecret}; + use super::{threadpool::ThreadPool, Exchange, ServerSecret}; #[test] fn snapshot() { @@ -112,8 +119,13 @@ mod tests { } async fn run_round_trip_test(server_password: &str, client_password: &str) { + let pool = ThreadPool::new(1); + + let ep = EndpointId::from("foo"); + let ep = EndpointIdInt::from(ep); + let scram_secret = ServerSecret::build(server_password).await.unwrap(); - let outcome = super::exchange(&scram_secret, client_password.as_bytes()) + let outcome = super::exchange(&pool, ep, &scram_secret, client_password.as_bytes()) .await .unwrap(); diff --git a/proxy/src/scram/countmin.rs b/proxy/src/scram/countmin.rs new file mode 100644 index 0000000000..f2b794e5fe --- /dev/null +++ b/proxy/src/scram/countmin.rs @@ -0,0 +1,173 @@ +use std::hash::Hash; + +/// estimator of hash jobs per second. +/// +pub struct CountMinSketch { + // one for each depth + hashers: Vec, + width: usize, + depth: usize, + // buckets, width*depth + buckets: Vec, +} + +impl CountMinSketch { + /// Given parameters (ε, δ), + /// set width = ceil(e/ε) + /// set depth = ceil(ln(1/δ)) + /// + /// guarantees: + /// actual <= estimate + /// estimate <= actual + ε * N with probability 1 - δ + /// where N is the cardinality of the stream + pub fn with_params(epsilon: f64, delta: f64) -> Self { + CountMinSketch::new( + (std::f64::consts::E / epsilon).ceil() as usize, + (1.0_f64 / delta).ln().ceil() as usize, + ) + } + + fn new(width: usize, depth: usize) -> Self { + Self { + #[cfg(test)] + hashers: (0..depth) + .map(|i| { + // digits of pi for good randomness + ahash::RandomState::with_seeds( + 314159265358979323, + 84626433832795028, + 84197169399375105, + 82097494459230781 + i as u64, + ) + }) + .collect(), + #[cfg(not(test))] + hashers: (0..depth).map(|_| ahash::RandomState::new()).collect(), + width, + depth, + buckets: vec![0; width * depth], + } + } + + pub fn inc_and_return(&mut self, t: &T, x: u32) -> u32 { + let mut min = u32::MAX; + for row in 0..self.depth { + let col = (self.hashers[row].hash_one(t) as usize) % self.width; + + let row = &mut self.buckets[row * self.width..][..self.width]; + row[col] = row[col].saturating_add(x); + min = std::cmp::min(min, row[col]); + } + min + } + + pub fn reset(&mut self) { + self.buckets.clear(); + self.buckets.resize(self.width * self.depth, 0); + } +} + +#[cfg(test)] +mod tests { + use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; + + use super::CountMinSketch; + + fn eval_precision(n: usize, p: f64, q: f64) -> usize { + // fixed value of phi for consistent test + let mut rng = StdRng::seed_from_u64(16180339887498948482); + + #[allow(non_snake_case)] + let mut N = 0; + + let mut ids = vec![]; + + for _ in 0..n { + // number of insert operations + let n = rng.gen_range(1..100); + // number to insert at once + let m = rng.gen_range(1..4096); + + let id = uuid::Builder::from_random_bytes(rng.gen()).into_uuid(); + ids.push((id, n, m)); + + // N = sum(actual) + N += n * m; + } + + // q% of counts will be within p of the actual value + let mut sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q); + + 
dbg!(sketch.buckets.len()); + + // insert a bunch of entries in a random order + let mut ids2 = ids.clone(); + while !ids2.is_empty() { + ids2.shuffle(&mut rng); + + let mut i = 0; + while i < ids2.len() { + sketch.inc_and_return(&ids2[i].0, ids2[i].1); + ids2[i].2 -= 1; + if ids2[i].2 == 0 { + ids2.remove(i); + } else { + i += 1; + } + } + } + + let mut within_p = 0; + for (id, n, m) in ids { + let actual = n * m; + let estimate = sketch.inc_and_return(&id, 0); + + // This estimate has the guarantee that actual <= estimate + assert!(actual <= estimate); + + // This estimate has the guarantee that estimate <= actual + εN with probability 1 - δ. + // ε = p / N, δ = 1 - q; + // therefore, estimate <= actual + p with probability q. + if estimate as f64 <= actual as f64 + p { + within_p += 1; + } + } + within_p + } + + #[test] + fn precision() { + assert_eq!(eval_precision(100, 100.0, 0.99), 100); + assert_eq!(eval_precision(1000, 100.0, 0.99), 1000); + assert_eq!(eval_precision(100, 4096.0, 0.99), 100); + assert_eq!(eval_precision(1000, 4096.0, 0.99), 1000); + + // seems to be more precise than the literature indicates? + // probably numbers are too small to truly represent the probabilities. + assert_eq!(eval_precision(100, 4096.0, 0.90), 100); + assert_eq!(eval_precision(1000, 4096.0, 0.90), 1000); + assert_eq!(eval_precision(100, 4096.0, 0.1), 98); + assert_eq!(eval_precision(1000, 4096.0, 0.1), 991); + } + + // returns memory usage in bytes, and the time complexity per insert. + fn eval_cost(p: f64, q: f64) -> (usize, usize) { + #[allow(non_snake_case)] + // N = sum(actual) + // Let's assume 1021 samples, all of 4096 + let N = 1021 * 4096; + let sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q); + + let memory = std::mem::size_of::() * sketch.buckets.len(); + let time = sketch.depth; + (memory, time) + } + + #[test] + fn memory_usage() { + assert_eq!(eval_cost(100.0, 0.99), (2273580, 5)); + assert_eq!(eval_cost(4096.0, 0.99), (55520, 5)); + assert_eq!(eval_cost(4096.0, 0.90), (33312, 3)); + assert_eq!(eval_cost(4096.0, 0.1), (11104, 1)); + } +} diff --git a/proxy/src/scram/exchange.rs b/proxy/src/scram/exchange.rs index 89dd33e59f..d0adbc780e 100644 --- a/proxy/src/scram/exchange.rs +++ b/proxy/src/scram/exchange.rs @@ -4,15 +4,17 @@ use std::convert::Infallible; use hmac::{Hmac, Mac}; use sha2::Sha256; -use tokio::task::yield_now; use super::messages::{ ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN, }; +use super::pbkdf2::Pbkdf2; use super::secret::ServerSecret; use super::signature::SignatureBuilder; +use super::threadpool::ThreadPool; use super::ScramKey; use crate::config; +use crate::intern::EndpointIdInt; use crate::sasl::{self, ChannelBinding, Error as SaslError}; /// The only channel binding mode we currently support. 
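As a concrete check of the `with_params` arithmetic in `countmin.rs` above (this snippet is not part of the patch): with the parameters the worker threads use in `threadpool.rs` further down, ε = 1/1021 and δ = 0.01, the sketch comes out at width 2776 and depth 5.

```rust
// Mirrors `CountMinSketch::with_params` above: width = ceil(e / ε), depth = ceil(ln(1 / δ)).
fn sketch_dimensions(epsilon: f64, delta: f64) -> (usize, usize) {
    let width = (std::f64::consts::E / epsilon).ceil() as usize;
    let depth = (1.0_f64 / delta).ln().ceil() as usize;
    (width, depth)
}

fn main() {
    // Each worker thread builds its sketch with ε = 1/1021 and δ = 0.01
    // (see SKETCH_RESET_INTERVAL in threadpool.rs below).
    let (width, depth) = sketch_dimensions(1.0 / 1021.0, 0.01);
    assert_eq!((width, depth), (2776, 5));
    // u32 buckets: width * depth * 4 bytes ≈ 54 KiB per worker, and 5 hashes per insert,
    // matching the (55520, 5) expectation of eval_cost(4096.0, 0.99) in the tests above.
    assert_eq!(width * depth * 4, 55_520);
}
```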
@@ -74,37 +76,18 @@ impl<'a> Exchange<'a> { } } -// copied from -async fn pbkdf2(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] { - let hmac = Hmac::::new_from_slice(str).expect("HMAC is able to accept all key sizes"); - let mut prev = hmac - .clone() - .chain_update(salt) - .chain_update(1u32.to_be_bytes()) - .finalize() - .into_bytes(); - - let mut hi = prev; - - for i in 1..iterations { - prev = hmac.clone().chain_update(prev).finalize().into_bytes(); - - for (hi, prev) in hi.iter_mut().zip(prev) { - *hi ^= prev; - } - // yield every ~250us - // hopefully reduces tail latencies - if i % 1024 == 0 { - yield_now().await - } - } - - hi.into() -} - // copied from -async fn derive_client_key(password: &[u8], salt: &[u8], iterations: u32) -> ScramKey { - let salted_password = pbkdf2(password, salt, iterations).await; +async fn derive_client_key( + pool: &ThreadPool, + endpoint: EndpointIdInt, + password: &[u8], + salt: &[u8], + iterations: u32, +) -> ScramKey { + let salted_password = pool + .spawn_job(endpoint, Pbkdf2::start(password, salt, iterations)) + .await + .expect("job should not be cancelled"); let make_key = |name| { let key = Hmac::::new_from_slice(&salted_password) @@ -119,11 +102,13 @@ async fn derive_client_key(password: &[u8], salt: &[u8], iterations: u32) -> Scr } pub async fn exchange( + pool: &ThreadPool, + endpoint: EndpointIdInt, secret: &ServerSecret, password: &[u8], ) -> sasl::Result> { let salt = base64::decode(&secret.salt_base64)?; - let client_key = derive_client_key(password, &salt, secret.iterations).await; + let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await; if secret.is_password_invalid(&client_key).into() { Ok(sasl::Outcome::Failure("password doesn't match")) diff --git a/proxy/src/scram/pbkdf2.rs b/proxy/src/scram/pbkdf2.rs new file mode 100644 index 0000000000..a803ba7e1b --- /dev/null +++ b/proxy/src/scram/pbkdf2.rs @@ -0,0 +1,89 @@ +use hmac::{ + digest::{consts::U32, generic_array::GenericArray}, + Hmac, Mac, +}; +use sha2::Sha256; + +pub struct Pbkdf2 { + hmac: Hmac, + prev: GenericArray, + hi: GenericArray, + iterations: u32, +} + +// inspired from +impl Pbkdf2 { + pub fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self { + let hmac = + Hmac::::new_from_slice(str).expect("HMAC is able to accept all key sizes"); + + let prev = hmac + .clone() + .chain_update(salt) + .chain_update(1u32.to_be_bytes()) + .finalize() + .into_bytes(); + + Self { + hmac, + // one consumed for the hash above + iterations: iterations - 1, + hi: prev, + prev, + } + } + + pub fn cost(&self) -> u32 { + (self.iterations).clamp(0, 4096) + } + + pub fn turn(&mut self) -> std::task::Poll<[u8; 32]> { + let Self { + hmac, + prev, + hi, + iterations, + } = self; + + // only do 4096 iterations per turn before sharing the thread for fairness + let n = (*iterations).clamp(0, 4096); + for _ in 0..n { + *prev = hmac.clone().chain_update(*prev).finalize().into_bytes(); + + for (hi, prev) in hi.iter_mut().zip(*prev) { + *hi ^= prev; + } + } + + *iterations -= n; + if *iterations == 0 { + std::task::Poll::Ready((*hi).into()) + } else { + std::task::Poll::Pending + } + } +} + +#[cfg(test)] +mod tests { + use super::Pbkdf2; + use pbkdf2::pbkdf2_hmac_array; + use sha2::Sha256; + + #[test] + fn works() { + let salt = b"sodium chloride"; + let pass = b"Ne0n_!5_50_C007"; + + let mut job = Pbkdf2::start(pass, salt, 600000); + let hash = loop { + let std::task::Poll::Ready(hash) = job.turn() else { + continue; + }; + break hash; + }; + + let 
expected = pbkdf2_hmac_array::(pass, salt, 600000); + assert_eq!(hash, expected) + } +} diff --git a/proxy/src/scram/threadpool.rs b/proxy/src/scram/threadpool.rs new file mode 100644 index 0000000000..7701b869a3 --- /dev/null +++ b/proxy/src/scram/threadpool.rs @@ -0,0 +1,321 @@ +//! Custom threadpool implementation for password hashing. +//! +//! Requirements: +//! 1. Fairness per endpoint. +//! 2. Yield support for high iteration counts. + +use std::sync::{ + atomic::{AtomicU64, Ordering}, + Arc, +}; + +use crossbeam_deque::{Injector, Stealer, Worker}; +use itertools::Itertools; +use parking_lot::{Condvar, Mutex}; +use rand::Rng; +use rand::{rngs::SmallRng, SeedableRng}; +use tokio::sync::oneshot; + +use crate::{ + intern::EndpointIdInt, + metrics::{ThreadPoolMetrics, ThreadPoolWorkerId}, + scram::countmin::CountMinSketch, +}; + +use super::pbkdf2::Pbkdf2; + +pub struct ThreadPool { + queue: Injector, + stealers: Vec>, + parkers: Vec<(Condvar, Mutex)>, + /// bitpacked representation. + /// lower 8 bits = number of sleeping threads + /// next 8 bits = number of idle threads (searching for work) + counters: AtomicU64, + + pub metrics: Arc, +} + +#[derive(PartialEq)] +enum ThreadState { + Parked, + Active, +} + +impl ThreadPool { + pub fn new(n_workers: u8) -> Arc { + let workers = (0..n_workers).map(|_| Worker::new_fifo()).collect_vec(); + let stealers = workers.iter().map(|w| w.stealer()).collect_vec(); + + let parkers = (0..n_workers) + .map(|_| (Condvar::new(), Mutex::new(ThreadState::Active))) + .collect_vec(); + + let pool = Arc::new(Self { + queue: Injector::new(), + stealers, + parkers, + // threads start searching for work + counters: AtomicU64::new((n_workers as u64) << 8), + metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)), + }); + + for (i, worker) in workers.into_iter().enumerate() { + let pool = Arc::clone(&pool); + std::thread::spawn(move || thread_rt(pool, worker, i)); + } + + pool + } + + pub fn spawn_job( + &self, + endpoint: EndpointIdInt, + pbkdf2: Pbkdf2, + ) -> oneshot::Receiver<[u8; 32]> { + let (tx, rx) = oneshot::channel(); + + let queue_was_empty = self.queue.is_empty(); + + self.metrics.injector_queue_depth.inc(); + self.queue.push(JobSpec { + response: tx, + pbkdf2, + endpoint, + }); + + // inspired from + let counts = self.counters.load(Ordering::SeqCst); + let num_awake_but_idle = (counts >> 8) & 0xff; + let num_sleepers = counts & 0xff; + + // If the queue is non-empty, then we always wake up a worker + // -- clearly the existing idle jobs aren't enough. Otherwise, + // check to see if we have enough idle workers. + if !queue_was_empty || num_awake_but_idle == 0 { + let num_to_wake = Ord::min(1, num_sleepers); + self.wake_any_threads(num_to_wake); + } + + rx + } + + #[cold] + fn wake_any_threads(&self, mut num_to_wake: u64) { + if num_to_wake > 0 { + for i in 0..self.parkers.len() { + if self.wake_specific_thread(i) { + num_to_wake -= 1; + if num_to_wake == 0 { + return; + } + } + } + } + } + + fn wake_specific_thread(&self, index: usize) -> bool { + let (condvar, lock) = &self.parkers[index]; + + let mut state = lock.lock(); + if *state == ThreadState::Parked { + condvar.notify_one(); + + // When the thread went to sleep, it will have incremented + // this value. When we wake it, its our job to decrement + // it. We could have the thread do it, but that would + // introduce a delay between when the thread was + // *notified* and when this counter was decremented. 
That + // might mislead people with new work into thinking that + // there are sleeping threads that they should try to + // wake, when in fact there is nothing left for them to + // do. + self.counters.fetch_sub(1, Ordering::SeqCst); + *state = ThreadState::Active; + + true + } else { + false + } + } + + fn steal(&self, rng: &mut impl Rng, skip: usize, worker: &Worker) -> Option { + // announce thread as idle + self.counters.fetch_add(256, Ordering::SeqCst); + + // try steal from the global queue + loop { + match self.queue.steal_batch_and_pop(worker) { + crossbeam_deque::Steal::Success(job) => { + self.metrics + .injector_queue_depth + .set(self.queue.len() as i64); + // no longer idle + self.counters.fetch_sub(256, Ordering::SeqCst); + return Some(job); + } + crossbeam_deque::Steal::Retry => continue, + crossbeam_deque::Steal::Empty => break, + } + } + + // try steal from our neighbours + loop { + let mut retry = false; + let start = rng.gen_range(0..self.stealers.len()); + let job = (start..self.stealers.len()) + .chain(0..start) + .filter(|i| *i != skip) + .find_map( + |victim| match self.stealers[victim].steal_batch_and_pop(worker) { + crossbeam_deque::Steal::Success(job) => Some(job), + crossbeam_deque::Steal::Empty => None, + crossbeam_deque::Steal::Retry => { + retry = true; + None + } + }, + ); + if job.is_some() { + // no longer idle + self.counters.fetch_sub(256, Ordering::SeqCst); + return job; + } + if !retry { + return None; + } + } + } +} + +fn thread_rt(pool: Arc, worker: Worker, index: usize) { + /// interval when we should steal from the global queue + /// so that tail latencies are managed appropriately + const STEAL_INTERVAL: usize = 61; + + /// How often to reset the sketch values + const SKETCH_RESET_INTERVAL: usize = 1021; + + let mut rng = SmallRng::from_entropy(); + + // used to determine whether we should temporarily skip tasks for fairness. + // 99% of estimates will overcount by no more than 4096 samples + let mut sketch = CountMinSketch::with_params(1.0 / (SKETCH_RESET_INTERVAL as f64), 0.01); + + let (condvar, lock) = &pool.parkers[index]; + + 'wait: loop { + // wait for notification of work + { + let mut lock = lock.lock(); + + // queue is empty + pool.metrics + .worker_queue_depth + .set(ThreadPoolWorkerId(index), 0); + + // subtract 1 from idle count, add 1 to sleeping count. + pool.counters.fetch_sub(255, Ordering::SeqCst); + + *lock = ThreadState::Parked; + condvar.wait(&mut lock); + } + + for i in 0.. { + let mut job = match worker + .pop() + .or_else(|| pool.steal(&mut rng, index, &worker)) + { + Some(job) => job, + None => continue 'wait, + }; + + pool.metrics + .worker_queue_depth + .set(ThreadPoolWorkerId(index), worker.len() as i64); + + // receiver is closed, cancel the task + if !job.response.is_closed() { + let rate = sketch.inc_and_return(&job.endpoint, job.pbkdf2.cost()); + + const P: f64 = 2000.0; + // probability decreases as rate increases. + // lower probability, higher chance of being skipped + // + // estimates (rate in terms of 4096 rounds): + // rate = 0 => probability = 100% + // rate = 10 => probability = 71.3% + // rate = 50 => probability = 62.1% + // rate = 500 => probability = 52.3% + // rate = 1021 => probability = 49.8% + // + // My expectation is that the pool queue will only begin backing up at ~1000rps + // in which case the SKETCH_RESET_INTERVAL represents 1 second. Thus, the rates above + // are in requests per second. 
+ let probability = P.ln() / (P + rate as f64).ln(); + if pool.queue.len() > 32 || rng.gen_bool(probability) { + pool.metrics + .worker_task_turns_total + .inc(ThreadPoolWorkerId(index)); + + match job.pbkdf2.turn() { + std::task::Poll::Ready(result) => { + let _ = job.response.send(result); + } + std::task::Poll::Pending => worker.push(job), + } + } else { + pool.metrics + .worker_task_skips_total + .inc(ThreadPoolWorkerId(index)); + + // skip for now + worker.push(job) + } + } + + // if we get stuck with a few long lived jobs in the queue + // it's better to try and steal from the queue too for fairness + if i % STEAL_INTERVAL == 0 { + let _ = pool.queue.steal_batch(&worker); + } + + if i % SKETCH_RESET_INTERVAL == 0 { + sketch.reset(); + } + } + } +} + +struct JobSpec { + response: oneshot::Sender<[u8; 32]>, + pbkdf2: Pbkdf2, + endpoint: EndpointIdInt, +} + +#[cfg(test)] +mod tests { + use crate::EndpointId; + + use super::*; + + #[tokio::test] + async fn hash_is_correct() { + let pool = ThreadPool::new(1); + + let ep = EndpointId::from("foo"); + let ep = EndpointIdInt::from(ep); + + let salt = [0x55; 32]; + let actual = pool + .spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096)) + .await + .unwrap(); + + let expected = [ + 10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242, + 178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140, + ]; + assert_eq!(actual, expected) + } +} diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 6b79c12316..52fc7b556a 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -15,6 +15,7 @@ use crate::{ }, context::RequestMonitoring, error::{ErrorKind, ReportableError, UserFacingError}, + intern::EndpointIdInt, proxy::{connect_compute::ConnectMechanism, retry::ShouldRetry}, rate_limiter::EndpointRateLimiter, Host, @@ -66,8 +67,14 @@ impl PoolingBackend { return Err(AuthError::auth_failed(&*user_info.user)); } }; - let auth_outcome = - crate::auth::validate_password_and_exchange(&conn_info.password, secret).await?; + let ep = EndpointIdInt::from(&conn_info.user_info.endpoint); + let auth_outcome = crate::auth::validate_password_and_exchange( + &config.thread_pool, + ep, + &conn_info.password, + secret, + ) + .await?; let res = match auth_outcome { crate::sasl::Outcome::Success(key) => { info!("user successfully authenticated"); diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 7582562450..f364a6c2e0 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -13,6 +13,7 @@ publish = false ### BEGIN HAKARI SECTION [dependencies] +ahash = { version = "0.8" } anyhow = { version = "1", features = ["backtrace"] } aws-config = { version = "1", default-features = false, features = ["rustls", "sso"] } aws-runtime = { version = "1", default-features = false, features = ["event-stream", "http-02x", "sigv4a"] } @@ -85,6 +86,7 @@ zstd-safe = { version = "7", default-features = false, features = ["arrays", "le zstd-sys = { version = "2", default-features = false, features = ["legacy", "std", "zdict_builder"] } [build-dependencies] +ahash = { version = "0.8" } anyhow = { version = "1", features = ["backtrace"] } bytes = { version = "1", features = ["serde"] } cc = { version = "1", default-features = false, features = ["parallel"] }
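For reference, the probability table in the `thread_rt` comment above can be reproduced from the formula `P.ln() / (P + rate as f64).ln()`. The standalone snippet below is not part of the patch and assumes, as the comment suggests, that the listed rates are expressed in units of 4096 PBKDF2 rounds:

```rust
// Reproduces the probability table in `thread_rt` above:
// probability = ln(P) / ln(P + rate), with P = 2000 and `rate` counted in
// raw PBKDF2 rounds (the table lists it in units of 4096 rounds).
fn skip_probability(rate_in_4096_rounds: u32) -> f64 {
    const P: f64 = 2000.0;
    let rate = f64::from(rate_in_4096_rounds) * 4096.0;
    P.ln() / (P + rate).ln()
}

fn main() {
    for rate in [0u32, 10, 50, 500, 1021] {
        println!(
            "rate = {rate:>4} => probability = {:.1}%",
            100.0 * skip_probability(rate)
        );
    }
    // Prints roughly 100.0%, 71.3%, 62.1%, 52.3%, 49.8%, matching the comment above.
}
```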
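The bit-packed `counters` field takes a moment to decode, so here is a self-contained illustration (not part of the patch, assuming a 4-worker pool) of how the `fetch_sub(255)` and `fetch_sub(1)` updates in `thread_rt` and `wake_specific_thread` move a thread between the idle and sleeping counts:

```rust
use std::sync::atomic::{AtomicU64, Ordering};

// Decodes the bit-packed `counters` field above:
// bits 0..8 = number of sleeping threads, bits 8..16 = number of idle threads.
fn sleeping_and_idle(counters: u64) -> (u64, u64) {
    (counters & 0xff, (counters >> 8) & 0xff)
}

fn main() {
    // A 4-worker pool starts with every thread "searching for work" (idle).
    let counters = AtomicU64::new(4 << 8);
    assert_eq!(sleeping_and_idle(counters.load(Ordering::SeqCst)), (0, 4));

    // A worker parking itself does `fetch_sub(255)`:
    // subtracting 255 is subtracting 256 (idle - 1) plus adding 1 (sleeping + 1).
    counters.fetch_sub(255, Ordering::SeqCst);
    assert_eq!(sleeping_and_idle(counters.load(Ordering::SeqCst)), (1, 3));

    // `wake_specific_thread` does `fetch_sub(1)`, removing it from the sleep count;
    // the worker only counts itself idle again when it next calls `steal`.
    counters.fetch_sub(1, Ordering::SeqCst);
    assert_eq!(sleeping_and_idle(counters.load(Ordering::SeqCst)), (0, 3));
}
```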