Mirror of https://github.com/neondatabase/neon.git, synced 2026-01-10 23:12:54 +00:00
## Problem

Password hashing for sql-over-http takes up a lot of CPU. We may be able to get away with temporarily caching some intermediate steps so that only a few rounds need to be recomputed, which will save CPU time.

## Summary of changes

The output of PBKDF2 is the XOR of the outputs of each iteration round, e.g. `U1 ^ U2 ^ ... ^ U15 ^ U16 ^ U17 ^ ... ^ Un`. We cache the suffix of the expression, `U16 ^ U17 ^ ... ^ Un`. To compute the result from the cached suffix, we only need to compute the prefix `U1 ^ U2 ^ ... ^ U15`. The suffix by itself is useless, which prevents its use in brute-force attacks should this cached memory leak. We also cache the full 4096-round hash in memory, which can be used for brute-force attacks, and this suffix could be used to speed such an attack up. My hope/expectation is that, since these live in different allocations, any such memory exploitation becomes much harder. Since the full hash cache might be invalidated while the suffix is still cached, I'm storing the timestamp of the computation as a way to identify the match. I also added `zeroize()` to clear the sensitive state from the stack/heap. For the most security-conscious customers, we hope to roll out OIDC soon, so they can disable passwords entirely.

---

The numbers for the threadpool cache were fairly arbitrary, but according to our busiest region for sql-over-http, we only see about 150 unique endpoints every minute, so storing ~100 of the most common endpoints for that minute should cover the vast majority of requests. One minute was chosen so we don't keep data in memory for too long.
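As a rough illustration of the prefix/suffix split described above (not the actual `pbkdf2` module referenced by this file), here is a minimal sketch using the `hmac` and `sha2` crates; the `pbkdf2_split` helper and the choice of 15 prefix rounds are illustrative assumptions:

```rust
// Sketch only: assumes the `hmac` and `sha2` crates. The real pbkdf2 module
// in this crate differs in detail (blocks, yielding, zeroization).
use hmac::{Hmac, Mac};
use sha2::Sha256;

type HmacSha256 = Hmac<Sha256>;

/// Compute one PBKDF2-HMAC-SHA256 block, split into
/// `prefix = U1 ^ ... ^ Uk` and `suffix = U(k+1) ^ ... ^ Un`.
fn pbkdf2_split(password: &[u8], salt: &[u8], n: u32, k: u32) -> ([u8; 32], [u8; 32]) {
    // U1 = HMAC(password, salt || INT(1))
    let mut mac = HmacSha256::new_from_slice(password).expect("hmac accepts any key length");
    mac.update(salt);
    mac.update(&1u32.to_be_bytes());
    let mut u: [u8; 32] = mac.finalize().into_bytes().into();

    let mut prefix = u; // accumulates U1 ^ ... ^ Uk
    let mut suffix = [0u8; 32]; // accumulates U(k+1) ^ ... ^ Un
    for i in 2..=n {
        // Ui = HMAC(password, U(i-1))
        let mut mac = HmacSha256::new_from_slice(password).expect("hmac accepts any key length");
        mac.update(&u);
        u = mac.finalize().into_bytes().into();

        let acc = if i <= k { &mut prefix } else { &mut suffix };
        for (a, b) in acc.iter_mut().zip(&u) {
            *a ^= *b;
        }
    }
    (prefix, suffix)
}

fn main() {
    let (prefix, suffix) = pbkdf2_split(b"password", &[0x55; 32], 4096, 15);
    // The cached value is `suffix`; later verification only recomputes the
    // first 15 rounds, then XORs in the cached suffix to recover the full hash.
    let full: Vec<u8> = prefix.iter().zip(&suffix).map(|(a, b)| a ^ b).collect();
    println!("{full:02x?}");
}
```

The point of the split is that a leaked `suffix` reveals nothing useful on its own: recovering the full hash still requires knowing the password-dependent prefix rounds.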
220 lines · 6.6 KiB · Rust
//! Custom threadpool implementation for password hashing.
//!
//! Requirements:
//! 1. Fairness per endpoint.
//! 2. Yield support for high iteration counts.

use std::cell::RefCell;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Weak};
use std::task::{Context, Poll};

use futures::FutureExt;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

use super::cache::Pbkdf2Cache;
use super::pbkdf2;
use super::pbkdf2::Pbkdf2;
use crate::intern::EndpointIdInt;
use crate::metrics::{ThreadPoolMetrics, ThreadPoolWorkerId};
use crate::scram::countmin::CountMinSketch;

pub struct ThreadPool {
    runtime: Option<tokio::runtime::Runtime>,
    pub metrics: Arc<ThreadPoolMetrics>,

    // we hash a lot of passwords.
    // we keep a cache of partial hashes for faster validation.
    pub(super) cache: Pbkdf2Cache,
}

/// How often to reset the sketch values
const SKETCH_RESET_INTERVAL: u64 = 1021;

thread_local! {
    static STATE: RefCell<Option<ThreadRt>> = const { RefCell::new(None) };
}

impl ThreadPool {
    pub fn new(mut n_workers: u8) -> Arc<Self> {
        // rayon would be nice here, but yielding in rayon does not work well afaict.

        if n_workers == 0 {
            n_workers = 1;
        }

        Arc::new_cyclic(|pool| {
            let pool = pool.clone();
            let worker_id = AtomicUsize::new(0);

            let runtime = tokio::runtime::Builder::new_multi_thread()
                .worker_threads(n_workers as usize)
                .on_thread_start(move || {
                    STATE.with_borrow_mut(|state| {
                        *state = Some(ThreadRt {
                            pool: pool.clone(),
                            id: ThreadPoolWorkerId(worker_id.fetch_add(1, Ordering::Relaxed)),
                            rng: SmallRng::from_os_rng(),
                            // used to determine whether we should temporarily skip tasks for fairness.
                            // 99% of estimates will overcount by no more than 4096 samples
                            countmin: CountMinSketch::with_params(
                                1.0 / (SKETCH_RESET_INTERVAL as f64),
                                0.01,
                            ),
                            tick: 0,
                        });
                    });
                })
                .build()
                .expect("password threadpool runtime should be configured correctly");

            Self {
                runtime: Some(runtime),
                metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
                cache: Pbkdf2Cache::new(),
            }
        })
    }

    pub(crate) fn spawn_job(&self, endpoint: EndpointIdInt, pbkdf2: Pbkdf2) -> JobHandle {
        JobHandle(
            self.runtime
                .as_ref()
                .expect("runtime is always set")
                .spawn(JobSpec { pbkdf2, endpoint }),
        )
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        self.runtime
            .take()
            .expect("runtime is always set")
            .shutdown_background();
    }
}

struct ThreadRt {
    pool: Weak<ThreadPool>,
    id: ThreadPoolWorkerId,
    rng: SmallRng,
    countmin: CountMinSketch,
    tick: u64,
}

impl ThreadRt {
    fn should_run(&mut self, job: &JobSpec) -> bool {
        let rate = self
            .countmin
            .inc_and_return(&job.endpoint, job.pbkdf2.cost());

        const P: f64 = 2000.0;
        // probability decreases as rate increases.
        // lower probability, higher chance of being skipped
        //
        // estimates (rate in terms of 4096 rounds):
        // rate = 0    => probability = 100%
        // rate = 10   => probability = 71.3%
        // rate = 50   => probability = 62.1%
        // rate = 500  => probability = 52.3%
        // rate = 1021 => probability = 49.8%
        //
        // My expectation is that the pool queue will only begin backing up at ~1000rps,
        // in which case the SKETCH_RESET_INTERVAL represents 1 second. Thus, the rates above
        // are in requests per second.
        let probability = P.ln() / (P + rate as f64).ln();
        self.rng.random_bool(probability)
    }
}

struct JobSpec {
    pbkdf2: Pbkdf2,
    endpoint: EndpointIdInt,
}

impl Future for JobSpec {
    type Output = pbkdf2::Block;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        STATE.with_borrow_mut(|state| {
            let state = state.as_mut().expect("should be set on thread startup");

            state.tick = state.tick.wrapping_add(1);
            if state.tick.is_multiple_of(SKETCH_RESET_INTERVAL) {
                state.countmin.reset();
            }

            if state.should_run(&self) {
                if let Some(pool) = state.pool.upgrade() {
                    pool.metrics.worker_task_turns_total.inc(state.id);
                }

                match self.pbkdf2.turn() {
                    Poll::Ready(result) => Poll::Ready(result),
                    // more to do, we shall requeue
                    Poll::Pending => {
                        cx.waker().wake_by_ref();
                        Poll::Pending
                    }
                }
            } else {
                if let Some(pool) = state.pool.upgrade() {
                    pool.metrics.worker_task_skips_total.inc(state.id);
                }

                cx.waker().wake_by_ref();
                Poll::Pending
            }
        })
    }
}

pub(crate) struct JobHandle(tokio::task::JoinHandle<pbkdf2::Block>);

impl Future for JobHandle {
    type Output = pbkdf2::Block;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.0.poll_unpin(cx) {
            Poll::Ready(Ok(ok)) => Poll::Ready(ok),
            Poll::Ready(Err(err)) => std::panic::resume_unwind(err.into_panic()),
            Poll::Pending => Poll::Pending,
        }
    }
}

impl Drop for JobHandle {
    fn drop(&mut self) {
        self.0.abort();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::EndpointId;

    #[tokio::test]
    async fn hash_is_correct() {
        let pool = ThreadPool::new(1);

        let ep = EndpointId::from("foo");
        let ep = EndpointIdInt::from(ep);

        let salt = [0x55; 32];
        let actual = pool
            .spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096))
            .await;

        let expected = &[
            10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
            178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
        ];
        assert_eq!(actual.as_slice(), expected);
    }
}