# proxy password threadpool (#7806)

## Problem

Even though password hashing was made async, the PBKDF2 work still runs on the
same runtime threads as the network code and can take CPU time away from it.

## Summary of changes

Introduce a custom threadpool, inspired by rayon. Features:

### Fairness

Each task is tagged with its endpoint ID. The more work we have already seen
from that endpoint, the more likely we are to skip the task when it comes up in
the queue. We track this with a count-min sketch estimate of how much work each
endpoint has submitted, resetting the sketch every 1021 steps.
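
As an illustration only (not part of the diff), the skip decision boils down to the following, using the same constants as the new `thread_rt` worker loop; `rate` is the sketch's running estimate of PBKDF2 rounds submitted by the endpoint:

```rust
use rand::Rng;

/// Probability of running (rather than skipping) a job right now, given the
/// count-min sketch estimate `rate` of PBKDF2 rounds already submitted by
/// this endpoint. Mirrors the formula used in `thread_rt`.
fn should_run_now(rate: u32, rng: &mut impl Rng) -> bool {
    const P: f64 = 2000.0;
    // rate = 0 rounds           => probability = 100%
    // rate = 10 * 4096 rounds   => probability ~= 71%
    // rate = 1021 * 4096 rounds => probability ~= 50%
    let probability = P.ln() / (P + rate as f64).ln();
    rng.gen_bool(probability)
}
```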

Since unfinished tasks are immediately rescheduled onto the worker's local
queue, a worker could get stuck in an "always work available" loop and never
look at the global queue. To combat this, we steal from the global queue every
61 steps so that new tasks quickly get a worker assigned to them.
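
For reference, a condensed sketch of how those periodic checks sit in the worker loop (`Job` and `run` are placeholders for the pool's `JobSpec` and its `turn` handling; the real `thread_rt` below also parks idle threads, steals from sibling workers, and applies the skip probability):

```rust
use crossbeam_deque::{Injector, Worker};

/// Condensed worker loop showing only the periodic housekeeping.
fn worker_loop<Job>(global: &Injector<Job>, local: &Worker<Job>, mut run: impl FnMut(Job)) {
    const STEAL_INTERVAL: usize = 61;
    const SKETCH_RESET_INTERVAL: usize = 1021;
    for i in 0.. {
        if let Some(job) = local.pop() {
            // an unfinished job gets pushed back onto `local`, so this loop
            // can stay busy indefinitely...
            run(job);
        }
        // ...which is why we periodically pull from the global queue anyway
        if i % STEAL_INTERVAL == 0 {
            let _ = global.steal_batch(local);
        }
        if i % SKETCH_RESET_INTERVAL == 0 {
            // reset the per-endpoint count-min sketch here
        }
        if local.is_empty() && global.is_empty() {
            break; // the real worker parks and waits for a wakeup instead
        }
    }
}
```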

### Balanced

Using crossbeam_deque, like rayon does, we get work stealing out of the box.
I've tested it a fair amount and it seems to balance the workload well across
the workers.
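
The queues themselves are plain crossbeam_deque primitives. A minimal sketch of the layout (the real `ThreadPool::new` also wires up the parking condvars, the sleeping/idle counters, and metrics):

```rust
use crossbeam_deque::{Injector, Stealer, Worker};

/// Stand-in for the pool's `JobSpec`.
struct Job;

/// One global injector that `spawn_job` pushes into, one local FIFO deque per
/// worker thread, and a stealer handle onto every deque so that idle workers
/// can take work from busy siblings.
fn build_queues(n_workers: usize) -> (Injector<Job>, Vec<Worker<Job>>, Vec<Stealer<Job>>) {
    let injector = Injector::new();
    let workers: Vec<Worker<Job>> = (0..n_workers).map(|_| Worker::new_fifo()).collect();
    let stealers: Vec<Stealer<Job>> = workers.iter().map(|w| w.stealer()).collect();
    (injector, workers, stealers)
}
```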
Commit 9cfe08e3d9 (parent 64577cfddc) by Conrad Ludgate, 2024-05-22 18:05:43 +01:00, committed via GitHub. 16 changed files with 759 additions and 74 deletions.

Cargo.lock (generated)

@@ -1471,26 +1471,21 @@ dependencies = [
[[package]]
name = "crossbeam-deque"
version = "0.8.3"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef"
checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
dependencies = [
"cfg-if",
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.14"
version = "0.9.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46bd5f3f85273295a9d14aedfb86f6aadbff6d8f5295c4a9edb08e819dcf5695"
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
dependencies = [
"autocfg",
"cfg-if",
"crossbeam-utils",
"memoffset 0.8.0",
"scopeguard",
]
[[package]]
@@ -3961,9 +3956,9 @@ checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c"
[[package]]
name = "pbkdf2"
version = "0.12.1"
version = "0.12.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0ca0b5a68607598bf3bad68f32227a8164f6254833f84eafaac409cd6746c31"
checksum = "f8ed6a7761f76e3b9f92dfb0a60a6a6477c61024b775147ff0973a02653abaf2"
dependencies = [
"digest",
"hmac",
@@ -4386,6 +4381,7 @@ dependencies = [
name = "proxy"
version = "0.1.0"
dependencies = [
"ahash",
"anyhow",
"async-compression",
"async-trait",
@@ -4402,6 +4398,7 @@ dependencies = [
"chrono",
"clap",
"consumption_metrics",
"crossbeam-deque",
"dashmap",
"env_logger",
"fallible-iterator",
@@ -7473,6 +7470,7 @@ dependencies = [
name = "workspace_hack"
version = "0.1.0"
dependencies = [
"ahash",
"anyhow",
"aws-config",
"aws-runtime",

Cargo.toml

@@ -41,6 +41,7 @@ license = "Apache-2.0"
## All dependency versions, used in the project
[workspace.dependencies]
ahash = "0.8"
anyhow = { version = "1.0", features = ["backtrace"] }
arc-swap = "1.6"
async-compression = { version = "0.4.0", features = ["tokio", "gzip", "zstd"] }
@@ -74,6 +75,7 @@ clap = { version = "4.0", features = ["derive"] }
comfy-table = "6.1"
const_format = "0.2"
crc32c = "0.6"
crossbeam-deque = "0.8.5"
crossbeam-utils = "0.8.5"
dashmap = { version = "5.5.0", features = ["raw-api"] }
either = "1.8"

proxy/Cargo.toml

@@ -9,6 +9,7 @@ default = []
testing = []
[dependencies]
ahash.workspace = true
anyhow.workspace = true
async-compression.workspace = true
async-trait.workspace = true
@@ -24,6 +25,7 @@ camino.workspace = true
chrono.workspace = true
clap.workspace = true
consumption_metrics.workspace = true
crossbeam-deque.workspace = true
dashmap.workspace = true
env_logger.workspace = true
framed-websockets.workspace = true
@@ -52,7 +54,6 @@ opentelemetry.workspace = true
parking_lot.workspace = true
parquet.workspace = true
parquet_derive.workspace = true
pbkdf2 = { workspace = true, features = ["simple", "std"] }
pin-project-lite.workspace = true
postgres_backend.workspace = true
pq_proto.workspace = true
@@ -106,6 +107,7 @@ workspace_hack.workspace = true
camino-tempfile.workspace = true
fallible-iterator.workspace = true
tokio-tungstenite.workspace = true
pbkdf2 = { workspace = true, features = ["simple", "std"] }
rcgen.workspace = true
rstest.workspace = true
tokio-postgres-rustls.workspace = true


@@ -365,7 +365,10 @@ async fn authenticate_with_secret(
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
if let Some(password) = unauthenticated_password {
let auth_outcome = validate_password_and_exchange(&password, secret).await?;
let ep = EndpointIdInt::from(&info.endpoint);
let auth_outcome =
validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
let keys = match auth_outcome {
crate::sasl::Outcome::Success(key) => key,
crate::sasl::Outcome::Failure(reason) => {
@@ -386,7 +389,7 @@ async fn authenticate_with_secret(
// Currently, we use it for websocket connections (latency).
if allow_cleartext {
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
return hacks::authenticate_cleartext(ctx, info, client, secret).await;
return hacks::authenticate_cleartext(ctx, info, client, secret, config).await;
}
// Finally, proceed with the main auth flow (SCRAM-based).
@@ -554,7 +557,7 @@ mod tests {
context::RequestMonitoring,
proxy::NeonOptions,
rate_limiter::{EndpointRateLimiter, RateBucketInfo},
scram::ServerSecret,
scram::{threadpool::ThreadPool, ServerSecret},
stream::{PqStream, Stream},
};
@@ -596,6 +599,7 @@ mod tests {
}
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
thread_pool: ThreadPool::new(1),
scram_protocol_timeout: std::time::Duration::from_secs(5),
rate_limiter_enabled: true,
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),


@@ -3,8 +3,10 @@ use super::{
};
use crate::{
auth::{self, AuthFlow},
config::AuthenticationConfig,
console::AuthSecret,
context::RequestMonitoring,
intern::EndpointIdInt,
sasl,
stream::{self, Stream},
};
@@ -20,6 +22,7 @@ pub async fn authenticate_cleartext(
info: ComputeUserInfo,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
secret: AuthSecret,
config: &'static AuthenticationConfig,
) -> auth::Result<ComputeCredentials> {
warn!("cleartext auth flow override is enabled, proceeding");
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
@@ -27,8 +30,14 @@ pub async fn authenticate_cleartext(
// pause the timer while we communicate with the client
let paused = ctx.latency_timer.pause(crate::metrics::Waiting::Client);
let ep = EndpointIdInt::from(&info.endpoint);
let auth_flow = AuthFlow::new(client)
.begin(auth::CleartextPassword(secret))
.begin(auth::CleartextPassword {
secret,
endpoint: ep,
pool: config.thread_pool.clone(),
})
.await?;
drop(paused);
// cleartext auth is only allowed to the ws/http protocol.


@@ -5,12 +5,14 @@ use crate::{
config::TlsServerEndPoint,
console::AuthSecret,
context::RequestMonitoring,
sasl, scram,
intern::EndpointIdInt,
sasl,
scram::{self, threadpool::ThreadPool},
stream::{PqStream, Stream},
};
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use std::io;
use std::{io, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
@@ -53,7 +55,11 @@ impl AuthMethod for PasswordHack {
/// Use clear-text password auth called `password` in docs
/// <https://www.postgresql.org/docs/current/auth-password.html>
pub struct CleartextPassword(pub AuthSecret);
pub struct CleartextPassword {
pub pool: Arc<ThreadPool>,
pub endpoint: EndpointIdInt,
pub secret: AuthSecret,
}
impl AuthMethod for CleartextPassword {
#[inline(always)]
@@ -126,7 +132,13 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
.strip_suffix(&[0])
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
let outcome = validate_password_and_exchange(password, self.state.0).await?;
let outcome = validate_password_and_exchange(
&self.state.pool,
self.state.endpoint,
password,
self.state.secret,
)
.await?;
if let sasl::Outcome::Success(_) = &outcome {
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
@@ -181,6 +193,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
}
pub(crate) async fn validate_password_and_exchange(
pool: &ThreadPool,
endpoint: EndpointIdInt,
password: &[u8],
secret: AuthSecret,
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
@@ -194,7 +208,7 @@ pub(crate) async fn validate_password_and_exchange(
}
// perform scram authentication as both client and server to validate the keys
AuthSecret::Scram(scram_secret) => {
let outcome = crate::scram::exchange(&scram_secret, password).await?;
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
let client_key = match outcome {
sasl::Outcome::Success(client_key) => client_key,


@@ -27,6 +27,7 @@ use proxy::redis::cancellation_publisher::RedisPublisherClient;
use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use proxy::redis::elasticache;
use proxy::redis::notifications;
use proxy::scram::threadpool::ThreadPool;
use proxy::serverless::cancel_set::CancelSet;
use proxy::serverless::GlobalConnPoolOptions;
use proxy::usage_metrics;
@@ -132,6 +133,9 @@ struct ProxyCliArgs {
/// timeout for scram authentication protocol
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
scram_protocol_timeout: tokio::time::Duration,
/// size of the threadpool for password hashing
#[clap(long, default_value_t = 4)]
scram_thread_pool_size: u8,
/// Require that all incoming requests have a Proxy Protocol V2 packet **and** have an IP address associated.
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
require_client_ip: bool,
@@ -489,6 +493,9 @@ async fn main() -> anyhow::Result<()> {
/// ProxyConfig is created at proxy startup, and lives forever.
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
Metrics::install(thread_pool.metrics.clone());
let tls_config = match (&args.tls_key, &args.tls_cert) {
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
key_path,
@@ -624,6 +631,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
};
let authentication_config = AuthenticationConfig {
thread_pool,
scram_protocol_timeout: args.scram_protocol_timeout,
rate_limiter_enabled: args.auth_rate_limit_enabled,
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),


@@ -2,6 +2,7 @@ use crate::{
auth::{self, backend::AuthRateLimiter},
console::locks::ApiLocks,
rate_limiter::RateBucketInfo,
scram::threadpool::ThreadPool,
serverless::{cancel_set::CancelSet, GlobalConnPoolOptions},
Host,
};
@@ -61,6 +62,7 @@ pub struct HttpConfig {
}
pub struct AuthenticationConfig {
pub thread_pool: Arc<ThreadPool>,
pub scram_protocol_timeout: tokio::time::Duration,
pub rate_limiter_enabled: bool,
pub rate_limiter: AuthRateLimiter,


@@ -1,11 +1,11 @@
use std::sync::OnceLock;
use std::sync::{Arc, OnceLock};
use lasso::ThreadedRodeo;
use measured::{
label::StaticLabelSet,
label::{FixedCardinalitySet, LabelName, LabelSet, LabelValue, StaticLabelSet},
metric::{histogram::Thresholds, name::MetricName},
Counter, CounterVec, FixedCardinalityLabel, Gauge, Histogram, HistogramVec, LabelGroup,
MetricGroup,
Counter, CounterVec, FixedCardinalityLabel, Gauge, GaugeVec, Histogram, HistogramVec,
LabelGroup, MetricGroup,
};
use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLog, HyperLogLogVec};
@@ -14,26 +14,36 @@ use tokio::time::{self, Instant};
use crate::console::messages::ColdStartInfo;
#[derive(MetricGroup)]
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
pub struct Metrics {
#[metric(namespace = "proxy")]
#[metric(init = ProxyMetrics::new(thread_pool))]
pub proxy: ProxyMetrics,
#[metric(namespace = "wake_compute_lock")]
pub wake_compute_lock: ApiLockMetrics,
}
static SELF: OnceLock<Metrics> = OnceLock::new();
impl Metrics {
pub fn install(thread_pool: Arc<ThreadPoolMetrics>) {
SELF.set(Metrics::new(thread_pool))
.ok()
.expect("proxy metrics must not be installed more than once");
}
pub fn get() -> &'static Self {
static SELF: OnceLock<Metrics> = OnceLock::new();
SELF.get_or_init(|| Metrics {
proxy: ProxyMetrics::default(),
wake_compute_lock: ApiLockMetrics::new(),
})
#[cfg(test)]
return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0))));
#[cfg(not(test))]
SELF.get()
.expect("proxy metrics must be installed by the main() function")
}
}
#[derive(MetricGroup)]
#[metric(new())]
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
pub struct ProxyMetrics {
#[metric(flatten)]
pub db_connections: CounterPairVec<NumDbConnectionsGauge>,
@@ -129,6 +139,10 @@ pub struct ProxyMetrics {
#[metric(namespace = "connect_compute_lock")]
pub connect_compute_lock: ApiLockMetrics,
#[metric(namespace = "scram_pool")]
#[metric(init = thread_pool)]
pub scram_pool: Arc<ThreadPoolMetrics>,
}
#[derive(MetricGroup)]
@@ -146,12 +160,6 @@ pub struct ApiLockMetrics {
pub semaphore_acquire_seconds: Histogram<16>,
}
impl Default for ProxyMetrics {
fn default() -> Self {
Self::new()
}
}
impl Default for ApiLockMetrics {
fn default() -> Self {
Self::new()
@@ -553,3 +561,52 @@ pub enum RedisEventsCount {
PasswordUpdate,
AllowedIpsUpdate,
}
pub struct ThreadPoolWorkers(usize);
pub struct ThreadPoolWorkerId(pub usize);
impl LabelValue for ThreadPoolWorkerId {
fn visit<V: measured::label::LabelVisitor>(&self, v: V) -> V::Output {
v.write_int(self.0 as i64)
}
}
impl LabelGroup for ThreadPoolWorkerId {
fn visit_values(&self, v: &mut impl measured::label::LabelGroupVisitor) {
v.write_value(LabelName::from_str("worker"), self);
}
}
impl LabelSet for ThreadPoolWorkers {
type Value<'a> = ThreadPoolWorkerId;
fn dynamic_cardinality(&self) -> Option<usize> {
Some(self.0)
}
fn encode(&self, value: Self::Value<'_>) -> Option<usize> {
(value.0 < self.0).then_some(value.0)
}
fn decode(&self, value: usize) -> Self::Value<'_> {
ThreadPoolWorkerId(value)
}
}
impl FixedCardinalitySet for ThreadPoolWorkers {
fn cardinality(&self) -> usize {
self.0
}
}
#[derive(MetricGroup)]
#[metric(new(workers: usize))]
pub struct ThreadPoolMetrics {
pub injector_queue_depth: Gauge,
#[metric(init = GaugeVec::with_label_set(ThreadPoolWorkers(workers)))]
pub worker_queue_depth: GaugeVec<ThreadPoolWorkers>,
#[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))]
pub worker_task_turns_total: CounterVec<ThreadPoolWorkers>,
#[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))]
pub worker_task_skips_total: CounterVec<ThreadPoolWorkers>,
}


@@ -6,11 +6,14 @@
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
mod countmin;
mod exchange;
mod key;
mod messages;
mod pbkdf2;
mod secret;
mod signature;
pub mod threadpool;
pub use exchange::{exchange, Exchange};
pub use key::ScramKey;
@@ -56,9 +59,13 @@ fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
#[cfg(test)]
mod tests {
use crate::sasl::{Mechanism, Step};
use crate::{
intern::EndpointIdInt,
sasl::{Mechanism, Step},
EndpointId,
};
use super::{Exchange, ServerSecret};
use super::{threadpool::ThreadPool, Exchange, ServerSecret};
#[test]
fn snapshot() {
@@ -112,8 +119,13 @@ mod tests {
}
async fn run_round_trip_test(server_password: &str, client_password: &str) {
let pool = ThreadPool::new(1);
let ep = EndpointId::from("foo");
let ep = EndpointIdInt::from(ep);
let scram_secret = ServerSecret::build(server_password).await.unwrap();
let outcome = super::exchange(&scram_secret, client_password.as_bytes())
let outcome = super::exchange(&pool, ep, &scram_secret, client_password.as_bytes())
.await
.unwrap();

proxy/src/scram/countmin.rs (new file)

@@ -0,0 +1,173 @@
use std::hash::Hash;
/// estimator of hash jobs per second.
/// <https://en.wikipedia.org/wiki/Count%E2%80%93min_sketch>
pub struct CountMinSketch {
// one for each depth
hashers: Vec<ahash::RandomState>,
width: usize,
depth: usize,
// buckets, width*depth
buckets: Vec<u32>,
}
impl CountMinSketch {
/// Given parameters (ε, δ),
/// set width = ceil(e/ε)
/// set depth = ceil(ln(1/δ))
///
/// guarantees:
/// actual <= estimate
/// estimate <= actual + ε * N with probability 1 - δ
/// where N is the cardinality of the stream
pub fn with_params(epsilon: f64, delta: f64) -> Self {
CountMinSketch::new(
(std::f64::consts::E / epsilon).ceil() as usize,
(1.0_f64 / delta).ln().ceil() as usize,
)
}
fn new(width: usize, depth: usize) -> Self {
Self {
#[cfg(test)]
hashers: (0..depth)
.map(|i| {
// digits of pi for good randomness
ahash::RandomState::with_seeds(
314159265358979323,
84626433832795028,
84197169399375105,
82097494459230781 + i as u64,
)
})
.collect(),
#[cfg(not(test))]
hashers: (0..depth).map(|_| ahash::RandomState::new()).collect(),
width,
depth,
buckets: vec![0; width * depth],
}
}
pub fn inc_and_return<T: Hash>(&mut self, t: &T, x: u32) -> u32 {
let mut min = u32::MAX;
for row in 0..self.depth {
let col = (self.hashers[row].hash_one(t) as usize) % self.width;
let row = &mut self.buckets[row * self.width..][..self.width];
row[col] = row[col].saturating_add(x);
min = std::cmp::min(min, row[col]);
}
min
}
pub fn reset(&mut self) {
self.buckets.clear();
self.buckets.resize(self.width * self.depth, 0);
}
}
#[cfg(test)]
mod tests {
use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng};
use super::CountMinSketch;
fn eval_precision(n: usize, p: f64, q: f64) -> usize {
// fixed value of phi for consistent test
let mut rng = StdRng::seed_from_u64(16180339887498948482);
#[allow(non_snake_case)]
let mut N = 0;
let mut ids = vec![];
for _ in 0..n {
// number of insert operations
let n = rng.gen_range(1..100);
// number to insert at once
let m = rng.gen_range(1..4096);
let id = uuid::Builder::from_random_bytes(rng.gen()).into_uuid();
ids.push((id, n, m));
// N = sum(actual)
N += n * m;
}
// q% of counts will be within p of the actual value
let mut sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q);
dbg!(sketch.buckets.len());
// insert a bunch of entries in a random order
let mut ids2 = ids.clone();
while !ids2.is_empty() {
ids2.shuffle(&mut rng);
let mut i = 0;
while i < ids2.len() {
sketch.inc_and_return(&ids2[i].0, ids2[i].1);
ids2[i].2 -= 1;
if ids2[i].2 == 0 {
ids2.remove(i);
} else {
i += 1;
}
}
}
let mut within_p = 0;
for (id, n, m) in ids {
let actual = n * m;
let estimate = sketch.inc_and_return(&id, 0);
// This estimate has the guarantee that actual <= estimate
assert!(actual <= estimate);
// This estimate has the guarantee that estimate <= actual + εN with probability 1 - δ.
// ε = p / N, δ = 1 - q;
// therefore, estimate <= actual + p with probability q.
if estimate as f64 <= actual as f64 + p {
within_p += 1;
}
}
within_p
}
#[test]
fn precision() {
assert_eq!(eval_precision(100, 100.0, 0.99), 100);
assert_eq!(eval_precision(1000, 100.0, 0.99), 1000);
assert_eq!(eval_precision(100, 4096.0, 0.99), 100);
assert_eq!(eval_precision(1000, 4096.0, 0.99), 1000);
// seems to be more precise than the literature indicates?
// probably numbers are too small to truly represent the probabilities.
assert_eq!(eval_precision(100, 4096.0, 0.90), 100);
assert_eq!(eval_precision(1000, 4096.0, 0.90), 1000);
assert_eq!(eval_precision(100, 4096.0, 0.1), 98);
assert_eq!(eval_precision(1000, 4096.0, 0.1), 991);
}
// returns memory usage in bytes, and the time complexity per insert.
fn eval_cost(p: f64, q: f64) -> (usize, usize) {
#[allow(non_snake_case)]
// N = sum(actual)
// Let's assume 1021 samples, all of 4096
let N = 1021 * 4096;
let sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q);
let memory = std::mem::size_of::<u32>() * sketch.buckets.len();
let time = sketch.depth;
(memory, time)
}
#[test]
fn memory_usage() {
assert_eq!(eval_cost(100.0, 0.99), (2273580, 5));
assert_eq!(eval_cost(4096.0, 0.99), (55520, 5));
assert_eq!(eval_cost(4096.0, 0.90), (33312, 3));
assert_eq!(eval_cost(4096.0, 0.1), (11104, 1));
}
}
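
To make the `with_params` guarantees concrete: the new worker loop (in threadpool.rs below) builds its sketch with `with_params(1.0 / 1021.0, 0.01)`, so with the sketch reset every 1021 steps of at most 4096 rounds each, 99% of estimates overcount by no more than 4096 rounds. Working the formulas through (illustration only) reproduces the numbers in the `memory_usage` test above:

```rust
// Illustration only: dimensions of the sketch built by
// CountMinSketch::with_params(1.0 / 1021.0, 0.01) in the worker loop.
fn sketch_dimensions() -> (usize, usize, usize) {
    let epsilon = 1.0 / 1021.0;
    let delta = 0.01;
    let width = (std::f64::consts::E / epsilon).ceil() as usize; // e * 1021 -> 2776
    let depth = (1.0_f64 / delta).ln().ceil() as usize; // ln(100) -> 5
    let bytes = width * depth * std::mem::size_of::<u32>(); // 13_880 buckets -> 55_520 bytes
    (width, depth, bytes)
}
```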


@@ -4,15 +4,17 @@ use std::convert::Infallible;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use tokio::task::yield_now;
use super::messages::{
ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
};
use super::pbkdf2::Pbkdf2;
use super::secret::ServerSecret;
use super::signature::SignatureBuilder;
use super::threadpool::ThreadPool;
use super::ScramKey;
use crate::config;
use crate::intern::EndpointIdInt;
use crate::sasl::{self, ChannelBinding, Error as SaslError};
/// The only channel binding mode we currently support.
@@ -74,37 +76,18 @@ impl<'a> Exchange<'a> {
}
}
// copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
async fn pbkdf2(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
let hmac = Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
let mut prev = hmac
.clone()
.chain_update(salt)
.chain_update(1u32.to_be_bytes())
.finalize()
.into_bytes();
let mut hi = prev;
for i in 1..iterations {
prev = hmac.clone().chain_update(prev).finalize().into_bytes();
for (hi, prev) in hi.iter_mut().zip(prev) {
*hi ^= prev;
}
// yield every ~250us
// hopefully reduces tail latencies
if i % 1024 == 0 {
yield_now().await
}
}
hi.into()
}
// copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
async fn derive_client_key(password: &[u8], salt: &[u8], iterations: u32) -> ScramKey {
let salted_password = pbkdf2(password, salt, iterations).await;
async fn derive_client_key(
pool: &ThreadPool,
endpoint: EndpointIdInt,
password: &[u8],
salt: &[u8],
iterations: u32,
) -> ScramKey {
let salted_password = pool
.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
.await
.expect("job should not be cancelled");
let make_key = |name| {
let key = Hmac::<Sha256>::new_from_slice(&salted_password)
@@ -119,11 +102,13 @@ async fn derive_client_key(password: &[u8], salt: &[u8], iterations: u32) -> Scr
}
pub async fn exchange(
pool: &ThreadPool,
endpoint: EndpointIdInt,
secret: &ServerSecret,
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
let salt = base64::decode(&secret.salt_base64)?;
let client_key = derive_client_key(password, &salt, secret.iterations).await;
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
if secret.is_password_invalid(&client_key).into() {
Ok(sasl::Outcome::Failure("password doesn't match"))

proxy/src/scram/pbkdf2.rs (new file)

@@ -0,0 +1,89 @@
use hmac::{
digest::{consts::U32, generic_array::GenericArray},
Hmac, Mac,
};
use sha2::Sha256;
pub struct Pbkdf2 {
hmac: Hmac<Sha256>,
prev: GenericArray<u8, U32>,
hi: GenericArray<u8, U32>,
iterations: u32,
}
// inspired from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
impl Pbkdf2 {
pub fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self {
let hmac =
Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
let prev = hmac
.clone()
.chain_update(salt)
.chain_update(1u32.to_be_bytes())
.finalize()
.into_bytes();
Self {
hmac,
// one consumed for the hash above
iterations: iterations - 1,
hi: prev,
prev,
}
}
pub fn cost(&self) -> u32 {
(self.iterations).clamp(0, 4096)
}
pub fn turn(&mut self) -> std::task::Poll<[u8; 32]> {
let Self {
hmac,
prev,
hi,
iterations,
} = self;
// only do 4096 iterations per turn before sharing the thread for fairness
let n = (*iterations).clamp(0, 4096);
for _ in 0..n {
*prev = hmac.clone().chain_update(*prev).finalize().into_bytes();
for (hi, prev) in hi.iter_mut().zip(*prev) {
*hi ^= prev;
}
}
*iterations -= n;
if *iterations == 0 {
std::task::Poll::Ready((*hi).into())
} else {
std::task::Poll::Pending
}
}
}
#[cfg(test)]
mod tests {
use super::Pbkdf2;
use pbkdf2::pbkdf2_hmac_array;
use sha2::Sha256;
#[test]
fn works() {
let salt = b"sodium chloride";
let pass = b"Ne0n_!5_50_C007";
let mut job = Pbkdf2::start(pass, salt, 600000);
let hash = loop {
let std::task::Poll::Ready(hash) = job.turn() else {
continue;
};
break hash;
};
let expected = pbkdf2_hmac_array::<Sha256, 32>(pass, salt, 600000);
assert_eq!(hash, expected)
}
}
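
A rough worked example of the chunking (not part of the diff): with the 600,000 iterations used in the `works` test above, each `turn` performs at most 4096 HMAC rounds, so the job yields back to the worker about 147 times before finishing, and each of those yields is an opportunity for the worker to pick up a different endpoint's job:

```rust
// Illustration only: number of `turn` calls a job needs, given that
// `Pbkdf2::start` performs the first round and each turn does up to 4096 more.
fn turns_needed(iterations: u32) -> u32 {
    let remaining = iterations - 1;
    (remaining + 4095) / 4096 // ceiling division; 600_000 iterations -> 147 turns
}
```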

proxy/src/scram/threadpool.rs (new file)

@@ -0,0 +1,321 @@
//! Custom threadpool implementation for password hashing.
//!
//! Requirements:
//! 1. Fairness per endpoint.
//! 2. Yield support for high iteration counts.
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc,
};
use crossbeam_deque::{Injector, Stealer, Worker};
use itertools::Itertools;
use parking_lot::{Condvar, Mutex};
use rand::Rng;
use rand::{rngs::SmallRng, SeedableRng};
use tokio::sync::oneshot;
use crate::{
intern::EndpointIdInt,
metrics::{ThreadPoolMetrics, ThreadPoolWorkerId},
scram::countmin::CountMinSketch,
};
use super::pbkdf2::Pbkdf2;
pub struct ThreadPool {
queue: Injector<JobSpec>,
stealers: Vec<Stealer<JobSpec>>,
parkers: Vec<(Condvar, Mutex<ThreadState>)>,
/// bitpacked representation.
/// lower 8 bits = number of sleeping threads
/// next 8 bits = number of idle threads (searching for work)
counters: AtomicU64,
pub metrics: Arc<ThreadPoolMetrics>,
}
#[derive(PartialEq)]
enum ThreadState {
Parked,
Active,
}
impl ThreadPool {
pub fn new(n_workers: u8) -> Arc<Self> {
let workers = (0..n_workers).map(|_| Worker::new_fifo()).collect_vec();
let stealers = workers.iter().map(|w| w.stealer()).collect_vec();
let parkers = (0..n_workers)
.map(|_| (Condvar::new(), Mutex::new(ThreadState::Active)))
.collect_vec();
let pool = Arc::new(Self {
queue: Injector::new(),
stealers,
parkers,
// threads start searching for work
counters: AtomicU64::new((n_workers as u64) << 8),
metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
});
for (i, worker) in workers.into_iter().enumerate() {
let pool = Arc::clone(&pool);
std::thread::spawn(move || thread_rt(pool, worker, i));
}
pool
}
pub fn spawn_job(
&self,
endpoint: EndpointIdInt,
pbkdf2: Pbkdf2,
) -> oneshot::Receiver<[u8; 32]> {
let (tx, rx) = oneshot::channel();
let queue_was_empty = self.queue.is_empty();
self.metrics.injector_queue_depth.inc();
self.queue.push(JobSpec {
response: tx,
pbkdf2,
endpoint,
});
// inspired from <https://github.com/rayon-rs/rayon/blob/3e3962cb8f7b50773bcc360b48a7a674a53a2c77/rayon-core/src/sleep/mod.rs#L242>
let counts = self.counters.load(Ordering::SeqCst);
let num_awake_but_idle = (counts >> 8) & 0xff;
let num_sleepers = counts & 0xff;
// If the queue is non-empty, then we always wake up a worker
// -- clearly the existing idle jobs aren't enough. Otherwise,
// check to see if we have enough idle workers.
if !queue_was_empty || num_awake_but_idle == 0 {
let num_to_wake = Ord::min(1, num_sleepers);
self.wake_any_threads(num_to_wake);
}
rx
}
#[cold]
fn wake_any_threads(&self, mut num_to_wake: u64) {
if num_to_wake > 0 {
for i in 0..self.parkers.len() {
if self.wake_specific_thread(i) {
num_to_wake -= 1;
if num_to_wake == 0 {
return;
}
}
}
}
}
fn wake_specific_thread(&self, index: usize) -> bool {
let (condvar, lock) = &self.parkers[index];
let mut state = lock.lock();
if *state == ThreadState::Parked {
condvar.notify_one();
// When the thread went to sleep, it will have incremented
// this value. When we wake it, its our job to decrement
// it. We could have the thread do it, but that would
// introduce a delay between when the thread was
// *notified* and when this counter was decremented. That
// might mislead people with new work into thinking that
// there are sleeping threads that they should try to
// wake, when in fact there is nothing left for them to
// do.
self.counters.fetch_sub(1, Ordering::SeqCst);
*state = ThreadState::Active;
true
} else {
false
}
}
fn steal(&self, rng: &mut impl Rng, skip: usize, worker: &Worker<JobSpec>) -> Option<JobSpec> {
// announce thread as idle
self.counters.fetch_add(256, Ordering::SeqCst);
// try steal from the global queue
loop {
match self.queue.steal_batch_and_pop(worker) {
crossbeam_deque::Steal::Success(job) => {
self.metrics
.injector_queue_depth
.set(self.queue.len() as i64);
// no longer idle
self.counters.fetch_sub(256, Ordering::SeqCst);
return Some(job);
}
crossbeam_deque::Steal::Retry => continue,
crossbeam_deque::Steal::Empty => break,
}
}
// try steal from our neighbours
loop {
let mut retry = false;
let start = rng.gen_range(0..self.stealers.len());
let job = (start..self.stealers.len())
.chain(0..start)
.filter(|i| *i != skip)
.find_map(
|victim| match self.stealers[victim].steal_batch_and_pop(worker) {
crossbeam_deque::Steal::Success(job) => Some(job),
crossbeam_deque::Steal::Empty => None,
crossbeam_deque::Steal::Retry => {
retry = true;
None
}
},
);
if job.is_some() {
// no longer idle
self.counters.fetch_sub(256, Ordering::SeqCst);
return job;
}
if !retry {
return None;
}
}
}
}
fn thread_rt(pool: Arc<ThreadPool>, worker: Worker<JobSpec>, index: usize) {
/// interval when we should steal from the global queue
/// so that tail latencies are managed appropriately
const STEAL_INTERVAL: usize = 61;
/// How often to reset the sketch values
const SKETCH_RESET_INTERVAL: usize = 1021;
let mut rng = SmallRng::from_entropy();
// used to determine whether we should temporarily skip tasks for fairness.
// 99% of estimates will overcount by no more than 4096 samples
let mut sketch = CountMinSketch::with_params(1.0 / (SKETCH_RESET_INTERVAL as f64), 0.01);
let (condvar, lock) = &pool.parkers[index];
'wait: loop {
// wait for notification of work
{
let mut lock = lock.lock();
// queue is empty
pool.metrics
.worker_queue_depth
.set(ThreadPoolWorkerId(index), 0);
// subtract 1 from idle count, add 1 to sleeping count.
pool.counters.fetch_sub(255, Ordering::SeqCst);
*lock = ThreadState::Parked;
condvar.wait(&mut lock);
}
for i in 0.. {
let mut job = match worker
.pop()
.or_else(|| pool.steal(&mut rng, index, &worker))
{
Some(job) => job,
None => continue 'wait,
};
pool.metrics
.worker_queue_depth
.set(ThreadPoolWorkerId(index), worker.len() as i64);
// receiver is closed, cancel the task
if !job.response.is_closed() {
let rate = sketch.inc_and_return(&job.endpoint, job.pbkdf2.cost());
const P: f64 = 2000.0;
// probability decreases as rate increases.
// lower probability, higher chance of being skipped
//
// estimates (rate in terms of 4096 rounds):
// rate = 0 => probability = 100%
// rate = 10 => probability = 71.3%
// rate = 50 => probability = 62.1%
// rate = 500 => probability = 52.3%
// rate = 1021 => probability = 49.8%
//
// My expectation is that the pool queue will only begin backing up at ~1000rps
// in which case the SKETCH_RESET_INTERVAL represents 1 second. Thus, the rates above
// are in requests per second.
let probability = P.ln() / (P + rate as f64).ln();
if pool.queue.len() > 32 || rng.gen_bool(probability) {
pool.metrics
.worker_task_turns_total
.inc(ThreadPoolWorkerId(index));
match job.pbkdf2.turn() {
std::task::Poll::Ready(result) => {
let _ = job.response.send(result);
}
std::task::Poll::Pending => worker.push(job),
}
} else {
pool.metrics
.worker_task_skips_total
.inc(ThreadPoolWorkerId(index));
// skip for now
worker.push(job)
}
}
// if we get stuck with a few long lived jobs in the queue
// it's better to try and steal from the queue too for fairness
if i % STEAL_INTERVAL == 0 {
let _ = pool.queue.steal_batch(&worker);
}
if i % SKETCH_RESET_INTERVAL == 0 {
sketch.reset();
}
}
}
}
struct JobSpec {
response: oneshot::Sender<[u8; 32]>,
pbkdf2: Pbkdf2,
endpoint: EndpointIdInt,
}
#[cfg(test)]
mod tests {
use crate::EndpointId;
use super::*;
#[tokio::test]
async fn hash_is_correct() {
let pool = ThreadPool::new(1);
let ep = EndpointId::from("foo");
let ep = EndpointIdInt::from(ep);
let salt = [0x55; 32];
let actual = pool
.spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096))
.await
.unwrap();
let expected = [
10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
];
assert_eq!(actual, expected)
}
}


@@ -15,6 +15,7 @@ use crate::{
},
context::RequestMonitoring,
error::{ErrorKind, ReportableError, UserFacingError},
intern::EndpointIdInt,
proxy::{connect_compute::ConnectMechanism, retry::ShouldRetry},
rate_limiter::EndpointRateLimiter,
Host,
@@ -66,8 +67,14 @@ impl PoolingBackend {
return Err(AuthError::auth_failed(&*user_info.user));
}
};
let auth_outcome =
crate::auth::validate_password_and_exchange(&conn_info.password, secret).await?;
let ep = EndpointIdInt::from(&conn_info.user_info.endpoint);
let auth_outcome = crate::auth::validate_password_and_exchange(
&config.thread_pool,
ep,
&conn_info.password,
secret,
)
.await?;
let res = match auth_outcome {
crate::sasl::Outcome::Success(key) => {
info!("user successfully authenticated");

workspace_hack/Cargo.toml

@@ -13,6 +13,7 @@ publish = false
### BEGIN HAKARI SECTION
[dependencies]
ahash = { version = "0.8" }
anyhow = { version = "1", features = ["backtrace"] }
aws-config = { version = "1", default-features = false, features = ["rustls", "sso"] }
aws-runtime = { version = "1", default-features = false, features = ["event-stream", "http-02x", "sigv4a"] }
@@ -85,6 +86,7 @@ zstd-safe = { version = "7", default-features = false, features = ["arrays", "le
zstd-sys = { version = "2", default-features = false, features = ["legacy", "std", "zdict_builder"] }
[build-dependencies]
ahash = { version = "0.8" }
anyhow = { version = "1", features = ["backtrace"] }
bytes = { version = "1", features = ["serde"] }
cc = { version = "1", default-features = false, features = ["parallel"] }