Files
neon/proxy/src/scram/threadpool.rs
Conrad Ludgate 03e604e432 Nightly lints and small tweaks (#12456)
Let chains available in 1.88 :D new clippy lints coming up in future
releases.
2025-07-03 14:47:12 +00:00

213 lines
6.4 KiB
Rust

//! Custom threadpool implementation for password hashing.
//!
//! Requirements:
//! 1. Fairness per endpoint.
//! 2. Yield support for high iteration counts.
use std::cell::RefCell;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Weak};
use std::task::{Context, Poll};
use futures::FutureExt;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use super::pbkdf2::Pbkdf2;
use crate::intern::EndpointIdInt;
use crate::metrics::{ThreadPoolMetrics, ThreadPoolWorkerId};
use crate::scram::countmin::CountMinSketch;
/// Thread pool used for password hashing, backed by its own tokio runtime.
///
/// Jobs are submitted with `spawn_job` and run as tasks on the runtime's worker threads.
pub struct ThreadPool {
/// The runtime that executes the hashing jobs. It is set in `new` and only taken out in
/// `Drop`, where it has to be moved out of `&mut self` to call `shutdown_background`.
runtime: Option<tokio::runtime::Runtime>,
/// Metrics for this pool; the per-worker turn and skip counters are bumped on every job poll.
pub metrics: Arc<ThreadPoolMetrics>,
}
/// How often to reset the sketch values, measured in job polls per worker thread:
/// a worker clears its count-min sketch whenever its poll counter is a multiple of this value.
///
/// The reciprocal of this value is also the first parameter passed to
/// `CountMinSketch::with_params` when a worker's sketch is created.
const SKETCH_RESET_INTERVAL: u64 = 1021;
// Per-thread scheduler state. It starts out as `None` on every thread and is filled in by the
// `on_thread_start` hook that `ThreadPool::new` registers on the runtime's threads.
// `JobSpec::poll` expects it to be `Some` and panics otherwise.
thread_local! {
static STATE: RefCell<Option<ThreadRt>> = const { RefCell::new(None) };
}
impl ThreadPool {
    /// Creates a pool whose tokio runtime has `n_workers` worker threads.
    ///
    /// A request for zero workers is treated as a request for one. Each thread started by
    /// the runtime stores a fresh [`ThreadRt`] in the `STATE` thread-local; that state holds
    /// a weak reference back to the pool, a worker id drawn from a shared counter, its own
    /// RNG and its own count-min sketch.
    ///
    /// # Panics
    ///
    /// Panics if the tokio runtime cannot be built.
    pub fn new(n_workers: u8) -> Arc<Self> {
        // rayon would be nice here, but yielding in rayon does not work well afaict.
        let worker_count = usize::from(n_workers.max(1));

        Arc::new_cyclic(|weak_pool| {
            let weak_pool = weak_pool.clone();
            let next_worker_id = AtomicUsize::new(0);

            // Runs once on every thread the runtime starts.
            let install_worker_state = move || {
                let worker = ThreadRt {
                    pool: weak_pool.clone(),
                    id: ThreadPoolWorkerId(next_worker_id.fetch_add(1, Ordering::Relaxed)),
                    rng: SmallRng::from_entropy(),
                    // used to determine whether we should temporarily skip tasks for fairness.
                    // 99% of estimates will overcount by no more than 4096 samples
                    countmin: CountMinSketch::with_params(
                        1.0 / (SKETCH_RESET_INTERVAL as f64),
                        0.01,
                    ),
                    tick: 0,
                };
                STATE.with_borrow_mut(|slot| *slot = Some(worker));
            };

            let runtime = tokio::runtime::Builder::new_multi_thread()
                .worker_threads(worker_count)
                .on_thread_start(install_worker_state)
                .build()
                .expect("password threadpool runtime should be configured correctly");

            let metrics = Arc::new(ThreadPoolMetrics::new(worker_count));

            Self {
                runtime: Some(runtime),
                metrics,
            }
        })
    }

    /// Queues a hashing job for `endpoint` on the pool's runtime.
    ///
    /// The returned [`JobHandle`] resolves to the 32-byte result once the job completes.
    pub(crate) fn spawn_job(&self, endpoint: EndpointIdInt, pbkdf2: Pbkdf2) -> JobHandle {
        let runtime = self.runtime.as_ref().expect("runtime is always set");
        let job = JobSpec { pbkdf2, endpoint };
        JobHandle(runtime.spawn(job))
    }
}
impl Drop for ThreadPool {
    /// Shuts the runtime down without waiting for outstanding jobs to finish.
    ///
    /// `shutdown_background` consumes the runtime, which is why it is moved out of the
    /// `Option` first. The runtime is expected to always be present here; if it is not,
    /// there is nothing left to shut down, so this does nothing rather than panic —
    /// a panic raised from a destructor while another panic is unwinding aborts the process.
    fn drop(&mut self) {
        if let Some(runtime) = self.runtime.take() {
            runtime.shutdown_background();
        }
    }
}
/// Scheduler state owned by a single thread of the pool's runtime.
///
/// One instance is stored in the `STATE` thread-local by the `on_thread_start` hook
/// registered in `ThreadPool::new`, and it is updated on every `JobSpec` poll.
struct ThreadRt {
/// Non-owning handle to the pool; upgraded on each poll to reach the pool's metrics.
pool: Weak<ThreadPool>,
/// Identifier drawn at thread start from a counter shared by the pool's threads;
/// passed to the per-worker metric counters.
id: ThreadPoolWorkerId,
/// Randomness source for the run-or-skip decision made in `should_run`.
rng: SmallRng,
/// Sketch that accumulates job cost per endpoint; cleared every
/// `SKETCH_RESET_INTERVAL` polls.
countmin: CountMinSketch,
/// Number of job polls handled by this thread (wraps on overflow).
tick: u64,
}
impl ThreadRt {
    /// Randomly decides whether `job` may make progress on this poll.
    ///
    /// The job's cost is added to its endpoint's entry in the count-min sketch, and the
    /// value the sketch returns is used as that endpoint's current rate. The higher the
    /// rate, the lower the chance that this returns `true`.
    fn should_run(&mut self, job: &JobSpec) -> bool {
        const P: f64 = 2000.0;

        let cost = job.pbkdf2.cost();
        let rate = self.countmin.inc_and_return(&job.endpoint, cost);

        // probability decreases as rate increases.
        // lower probability, higher chance of being skipped
        //
        // estimates (rate in terms of 4096 rounds):
        // rate = 0    => probability = 100%
        // rate = 10   => probability = 71.3%
        // rate = 50   => probability = 62.1%
        // rate = 500  => probability = 52.3%
        // rate = 1021 => probability = 49.8%
        //
        // My expectation is that the pool queue will only begin backing up at ~1000rps
        // in which case the SKETCH_RESET_INTERVAL represents 1 second. Thus, the rates above
        // are in requests per second.
        let numerator = P.ln();
        let denominator = (P + rate as f64).ln();
        let run_probability = numerator / denominator;

        self.rng.gen_bool(run_probability)
    }
}
/// A single hashing job; it is the future that gets spawned onto the pool's runtime.
struct JobSpec {
/// Hashing state, advanced by one `turn()` on each poll in which the job is allowed to run.
pbkdf2: Pbkdf2,
/// Endpoint the job is accounted against when deciding whether to run or skip it.
endpoint: EndpointIdInt,
}
impl Future for JobSpec {
    type Output = [u8; 32];

    /// Advances the job by at most one `turn()` of the underlying hash.
    ///
    /// Each poll bumps the current thread's tick (clearing its sketch every
    /// `SKETCH_RESET_INTERVAL` ticks) and then asks [`ThreadRt::should_run`] whether the job
    /// may proceed. A skipped job, and a job that still has work left after its turn, wakes
    /// itself immediately so that it is polled again.
    ///
    /// # Panics
    ///
    /// Panics if the current thread's `STATE` has not been initialised.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        STATE.with_borrow_mut(|slot| {
            let worker = slot.as_mut().expect("should be set on thread startup");

            worker.tick = worker.tick.wrapping_add(1);
            if worker.tick.is_multiple_of(SKETCH_RESET_INTERVAL) {
                worker.countmin.reset();
            }

            // Skipped for fairness: count the skip and ask to be polled again.
            if !worker.should_run(&self) {
                if let Some(pool) = worker.pool.upgrade() {
                    pool.metrics.worker_task_skips_total.inc(worker.id);
                }
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }

            if let Some(pool) = worker.pool.upgrade() {
                pool.metrics.worker_task_turns_total.inc(worker.id);
            }

            let progress = self.pbkdf2.turn();
            if progress.is_pending() {
                // more to do, we shall requeue
                cx.waker().wake_by_ref();
            }
            progress
        })
    }
}
/// Handle to a hashing job spawned on the pool's runtime.
///
/// Awaiting it yields the 32-byte result; dropping it aborts the underlying task.
pub(crate) struct JobHandle(tokio::task::JoinHandle<[u8; 32]>);
impl Future for JobHandle {
    type Output = [u8; 32];

    /// Polls the spawned task and unwraps its join result.
    ///
    /// If the task panicked, that panic is resumed on the caller.
    ///
    /// NOTE(review): `JoinError::into_panic` itself panics when the error is a cancellation
    /// rather than a panic, so a job cancelled underneath a live handle also ends in a
    /// panic here.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.0.poll_unpin(cx).map(|joined| match joined {
            Ok(hash) => hash,
            Err(join_err) => std::panic::resume_unwind(join_err.into_panic()),
        })
    }
}
impl Drop for JobHandle {
fn drop(&mut self) {
self.0.abort();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::EndpointId;

    /// Expected result for password `b"password"`, salt `[0x55; 32]` and 4096 iterations.
    const EXPECTED_HASH: [u8; 32] = [
        10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242, 178,
        43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
    ];

    /// A job run through a single-worker pool produces the expected hash.
    #[tokio::test]
    async fn hash_is_correct() {
        let pool = ThreadPool::new(1);
        let endpoint = EndpointIdInt::from(EndpointId::from("foo"));
        let salt = [0x55; 32];

        let job = Pbkdf2::start(b"password", &salt, 4096);
        let actual = pool.spawn_job(endpoint, job).await;

        assert_eq!(actual, EXPECTED_HASH);
    }
}