Compare commits

...

5 Commits

Author SHA1 Message Date
Conrad Ludgate
6e4209e7a7 faster math 2024-09-21 11:28:50 +01:00
Conrad Ludgate
0e51790ebd cache the pre-computed ready_at 2024-09-21 08:40:29 +01:00
Conrad Ludgate
28ba1cbb09 batch wakeup support 2024-09-21 08:29:08 +01:00
Conrad Ludgate
b2e2eb54ce who even needs unsafe 2024-09-20 17:30:43 +01:00
Conrad Ludgate
55866e99d7 utils: write my own mutex/queue using pin-list and some unsafe 2024-09-20 17:20:34 +01:00
4 changed files with 330 additions and 92 deletions

17
Cargo.lock generated
View File

@@ -3959,6 +3959,16 @@ dependencies = [
"siphasher",
]
[[package]]
name = "pin-list"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe91484d5a948b56f858ff2b92fd5b20b97d21b11d2d41041db8e5ec12d56c5e"
dependencies = [
"pin-project-lite",
"pinned-aliasable",
]
[[package]]
name = "pin-project"
version = "1.1.0"
@@ -3991,6 +4001,12 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "pinned-aliasable"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d0f9ae89bf0ed03b69ac1f3f7ea2e6e09b4fa5448011df2e67d581c2b850b7b"
[[package]]
name = "pkcs1"
version = "0.7.5"
@@ -6807,6 +6823,7 @@ dependencies = [
"metrics",
"nix 0.27.1",
"once_cell",
"pin-list",
"pin-project-lite",
"postgres_connection",
"pq_proto",

View File

@@ -27,6 +27,7 @@ futures = { workspace = true}
jsonwebtoken.workspace = true
nix.workspace = true
once_cell.workspace = true
pin-list = "0.1"
pin-project-lite.workspace = true
regex.workspace = true
routerify.workspace = true

View File

@@ -22,30 +22,48 @@
//! Another explanation can be found here: <https://brandur.org/rate-limiting>
use std::{
sync::{
atomic::{AtomicU64, Ordering},
Mutex,
},
fmt::Debug,
sync::Mutex,
task::{Poll, Waker},
time::Duration,
};
use tokio::{sync::Notify, time::Instant};
use pin_list::{Node, NodeData, PinList};
use tokio::time::Instant;
pub struct LeakyBucketConfig {
pub epoch: Instant,
/// This is the "time cost" of a single request unit.
/// Should loosely represent how long it takes to handle a request unit in active resource time.
/// Loosely speaking this is the inverse of the steady-rate requests-per-second
pub cost: Duration,
pub cost: Ns,
/// total size of the bucket
pub bucket_width: Duration,
pub bucket_width: Ns,
}
impl LeakyBucketConfig {
pub fn new(rps: f64, bucket_size: f64) -> Self {
let cost = Duration::from_secs_f64(rps.recip());
let bucket_width = cost.mul_f64(bucket_size);
Self { cost, bucket_width }
Self {
epoch: Instant::now(),
cost: Ns(cost.as_nanos() as u64),
bucket_width: Ns(bucket_width.as_nanos() as u64),
}
}
pub fn to_epoch(&self, t: Instant) -> InstantNs {
InstantNs((t - self.epoch).into())
}
pub fn from_epoch(&self, t: InstantNs) -> Instant {
self.epoch + t
}
pub fn now(&self) -> InstantNs {
self.to_epoch(Instant::now())
}
}
@@ -64,17 +82,17 @@ pub struct LeakyBucketState {
///
/// This is inspired by the generic cell rate algorithm (GCRA) and works
/// exactly the same as a leaky-bucket.
pub empty_at: Instant,
pub empty_at: InstantNs,
}
impl LeakyBucketState {
pub fn with_initial_tokens(config: &LeakyBucketConfig, initial_tokens: f64) -> Self {
LeakyBucketState {
empty_at: Instant::now() + config.cost.mul_f64(initial_tokens),
empty_at: InstantNs(config.cost * initial_tokens),
}
}
pub fn bucket_is_empty(&self, now: Instant) -> bool {
pub fn bucket_is_empty(&self, now: InstantNs) -> bool {
// if self.empty_at is after now, the bucket is not empty
self.empty_at <= now
}
@@ -96,8 +114,22 @@ impl LeakyBucketState {
started: Instant,
n: f64,
) -> Result<(), Instant> {
let now = Instant::now();
self.add_tokens_fast(
config.bucket_width,
config.to_epoch(started),
config.to_epoch(Instant::now()),
config.cost * n,
)
.map_err(|ready_at| config.from_epoch(ready_at))
}
pub fn add_tokens_fast(
&mut self,
bucket_width: Ns,
started: InstantNs,
now: InstantNs,
n: Ns,
) -> Result<(), InstantNs> {
// invariant: started <= now
debug_assert!(started <= now);
@@ -109,9 +141,8 @@ impl LeakyBucketState {
empty_at = started;
}
let n = config.cost.mul_f64(n);
let new_empty_at = empty_at + n;
let allow_at = new_empty_at.checked_sub(config.bucket_width);
let allow_at = new_empty_at.checked_sub(bucket_width);
// empty_at
// allow_at | new_empty_at
@@ -132,36 +163,242 @@ impl LeakyBucketState {
}
}
pub struct RateLimiter {
pub config: LeakyBucketConfig,
pub sleep_counter: AtomicU64,
pub state: Mutex<LeakyBucketState>,
/// a queue to provide this fair ordering.
pub queue: Notify,
// u64 nanoseconds allows for 584 years
#[derive(PartialEq, PartialOrd, Copy, Clone)]
pub struct Ns(u64);
impl Debug for Ns {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Duration::from_nanos(self.0).fmt(f)
}
}
struct Requeue<'a>(&'a Notify);
impl Ns {
fn as_secs_f64(self) -> f64 {
(self.0 as f64) / 1_000_000_000.0
}
}
impl Drop for Requeue<'_> {
fn drop(&mut self) {
self.0.notify_one();
impl From<Duration> for Ns {
fn from(value: Duration) -> Self {
Self(value.as_nanos() as u64)
}
}
impl std::ops::Mul<f64> for Ns {
type Output = Ns;
fn mul(self, rhs: f64) -> Ns {
Ns((self.0 as f64 * rhs) as u64)
}
}
impl std::ops::Mul<u64> for Ns {
type Output = Ns;
fn mul(self, rhs: u64) -> Ns {
Ns(self.0 * rhs)
}
}
// ns since some epoch
#[derive(PartialEq, PartialOrd, Copy, Clone)]
pub struct InstantNs(Ns);
impl std::ops::Add<Ns> for InstantNs {
type Output = InstantNs;
fn add(self, rhs: Ns) -> InstantNs {
Self(Ns(self.0 .0 + rhs.0))
}
}
impl std::ops::Sub for InstantNs {
type Output = Ns;
fn sub(self, rhs: InstantNs) -> Ns {
Ns(self.0 .0 - rhs.0 .0)
}
}
impl InstantNs {
fn checked_sub(self, rhs: Ns) -> Option<Self> {
self.0 .0.checked_sub(rhs.0).map(Ns).map(Self)
}
}
impl std::ops::Add<InstantNs> for Instant {
type Output = Instant;
fn add(self, rhs: InstantNs) -> Instant {
self + Duration::from_nanos(rhs.0 .0)
}
}
pub struct RateLimiter {
config: LeakyBucketConfig,
queue: Mutex<Queue>,
}
struct Queue {
sleep_counter: u64,
queue: PinList<RateLimitQueue>,
state: Option<LeakyBucketState>,
}
impl RateLimiter {
/// returns the sleep_counter start value on await.
/// sleep_counter end value can be found within the enqueued.
fn wait(&self, count: usize) -> Enqueued<'_> {
Enqueued {
entry: pin_list::Node::new(),
limiter: self,
sleep_counter: 0,
state: None,
count: self.config.cost * (count as u64),
start: self.config.to_epoch(Instant::now()),
}
}
}
type RateLimitQueue = dyn pin_list::Types<
Id = pin_list::id::Checked,
// the waker that lets us wake the next in the queue
// the instant is our acquire start time
// the duration is our GCRA token cost
Protected = (Waker, InstantNs, Ns),
// the token that gives us access to the rate limit state, along with when it should be ready.
// if None, then we were granted access already by the leader
Removed = Option<(LeakyBucketState, InstantNs)>,
// the sleep count at the start of the enqueue
Unprotected = u64,
>;
pin_project_lite::pin_project! {
struct Enqueued<'a> {
#[pin]
entry: Node<RateLimitQueue>,
state: Option<(LeakyBucketState, InstantNs)>,
sleep_counter: u64,
limiter: &'a RateLimiter,
start: InstantNs,
count: Ns,
}
impl<'a> PinnedDrop for Enqueued<'a> {
fn drop(this: Pin<&mut Self>) {
let this = this.project();
#[allow(clippy::mut_mutex_lock, reason = "false positive")]
let mut q = this.limiter.queue.lock().unwrap();
let mut state = if let Some(init) = this.entry.initialized_mut() {
let (data, _start_count) = init.reset(&mut q.queue) ;
match data {
// we were in the queue and are not holding any resources.
NodeData::Linked(_) | NodeData::Removed(None) => return,
// we were the head of the queue and were about to be the current leader
NodeData::Removed(Some((state, _ready_at))) => state
}
} else if let Some((state, _ready_at)) = this.state.take() {
// we were holding the lock, and are now releasing it.
q.sleep_counter = *this.sleep_counter;
state
} else {
// we apparently didn't even get into the queue to begin with
return;
};
let mut cursor = q.queue.cursor_front_mut();
let now = this.limiter.config.to_epoch(Instant::now());
loop {
match cursor.protected() {
Some((_waker, start, n)) => {
match state.add_tokens_fast(this.limiter.config.bucket_width, *start, now, *n) {
Ok(()) => {
let (waker, _, _) = cursor.remove_current(None)
.map_err(|_| {}).expect("we have just checked that the current node is in the list");
waker.wake();
},
// next in the queue has to sleep
Err(ready_at) => {
let (waker, _, _) = cursor.remove_current(Some((state, ready_at)))
.map_err(|_| {}).expect("we have just checked that the current node is in the list");
waker.wake();
break;
}
}
},
// no tasks left in the queue. unlocked
None => {
q.state = Some(state);
break;
}
}
}
}
}
}
impl std::future::Future for Enqueued<'_> {
type Output = u64;
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let mut node = this.entry;
#[allow(clippy::mut_mutex_lock, reason = "false positive")]
let mut q = this.limiter.queue.lock().unwrap();
if let Some(init) = node.as_mut().initialized_mut() {
// we are registered in the queue
match init.take_removed(&q.queue) {
// if we are removed from the queue, that means we are the new leader
Ok((state, start_count)) => {
*this.state = state;
*this.sleep_counter = q.sleep_counter;
Poll::Ready(start_count)
}
// if we are not removed from the queue, that means we are still waiting
// and had a spurious wake up
Err(init) => {
init.protected_mut(&mut q.queue)
.unwrap()
.0
.clone_from(cx.waker());
Poll::Pending
}
}
} else {
// we are not yet registered in the queue
let start_count = q.sleep_counter;
if let Some(state) = q.state.take() {
// we are the first in the queue and it is not yet acquired
*this.state = Some((state, *this.start));
*this.sleep_counter = q.sleep_counter;
Poll::Ready(start_count)
} else {
// we push ourselves to the back of the queue
q.queue.push_back(
node,
(cx.waker().clone(), *this.start, *this.count),
start_count,
);
Poll::Pending
}
}
}
}
impl RateLimiter {
pub fn with_initial_tokens(config: LeakyBucketConfig, initial_tokens: f64) -> Self {
RateLimiter {
sleep_counter: AtomicU64::new(0),
state: Mutex::new(LeakyBucketState::with_initial_tokens(
&config,
initial_tokens,
)),
queue: Mutex::new(Queue {
sleep_counter: 0,
queue: PinList::new(pin_list::id::Checked::new()),
state: Some(LeakyBucketState::with_initial_tokens(
&config,
initial_tokens,
)),
}),
config,
queue: {
let queue = Notify::new();
queue.notify_one();
queue
},
}
}
@@ -171,46 +408,29 @@ impl RateLimiter {
/// returns true if we did throttle
pub async fn acquire(&self, count: usize) -> bool {
let start = tokio::time::Instant::now();
let mut entry = std::pin::pin!(self.wait(count));
let start_count = entry.as_mut().await;
let entry = entry.project();
let start_count = self.sleep_counter.load(Ordering::Acquire);
let mut end_count = start_count;
// wait until we are the first in the queue
let mut notified = std::pin::pin!(self.queue.notified());
if !notified.as_mut().enable() {
notified.await;
end_count = self.sleep_counter.load(Ordering::Acquire);
}
// notify the next waiter in the queue when we are done.
let _guard = Requeue(&self.queue);
let Some((state, ready_at)) = entry.state.as_mut() else {
// we were woken up without the state,
// thus the state leader must have allowed us to continue
return start_count < *entry.sleep_counter;
};
let mut now = self.config.to_epoch(Instant::now());
loop {
let res = self
.state
.lock()
.unwrap()
.add_tokens(&self.config, start, count as f64);
match res {
Ok(()) => return end_count > start_count,
Err(ready_at) => {
struct Increment<'a>(&'a AtomicU64);
if *ready_at > now {
*entry.sleep_counter += 1;
tokio::time::sleep_until(self.config.from_epoch(*ready_at)).await;
now = self.config.to_epoch(Instant::now());
}
impl Drop for Increment<'_> {
fn drop(&mut self) {
self.0.fetch_add(1, Ordering::AcqRel);
}
}
// increment the counter after we finish sleeping (or cancel this task).
// this ensures that tasks that have already started the acquire will observe
// the new sleep count when they are allowed to resume on the notify.
let _inc = Increment(&self.sleep_counter);
end_count += 1;
tokio::time::sleep_until(ready_at).await;
}
match state.add_tokens_fast(self.config.bucket_width, *entry.start, now, *entry.count) {
Ok(()) => return start_count < *entry.sleep_counter,
// we might hit this branch if we were the first in the queue
// and the limit happened to be exhausted already
Err(next_ready_at) => *ready_at = next_ready_at,
}
}
}
@@ -226,16 +446,10 @@ mod tests {
#[tokio::test(start_paused = true)]
async fn check() {
let config = LeakyBucketConfig {
// average 100rps
cost: Duration::from_millis(10),
// burst up to 100 requests
bucket_width: Duration::from_millis(1000),
};
let mut state = LeakyBucketState {
empty_at: Instant::now(),
};
// average 100rps
// burst up to 100 requests
let config = LeakyBucketConfig::new(100.0, 100.0);
let mut state = LeakyBucketState::with_initial_tokens(&config, 0.0);
// supports burst
{
@@ -251,7 +465,7 @@ mod tests {
{
// after 1s we should have an empty bucket again.
tokio::time::advance(Duration::from_secs(1)).await;
assert!(state.bucket_is_empty(Instant::now()));
assert!(state.bucket_is_empty(config.now()));
// after 1s more, we should not over count the tokens and allow more than 200 requests.
tokio::time::advance(Duration::from_secs(1)).await;
@@ -278,7 +492,7 @@ mod tests {
{
// start the bucket completely empty
tokio::time::advance(Duration::from_secs(5)).await;
assert!(state.bucket_is_empty(Instant::now()));
assert!(state.bucket_is_empty(config.now()));
// requesting 200 tokens of space should take 200*cost = 2s
// but we already have 1s available, so we wait 1s from start.

View File

@@ -6,9 +6,8 @@ use std::{
use ahash::RandomState;
use dashmap::DashMap;
use rand::{thread_rng, Rng};
use tokio::time::Instant;
use tracing::info;
use utils::leaky_bucket::LeakyBucketState;
use utils::leaky_bucket::{InstantNs, LeakyBucketState};
use crate::intern::EndpointIdInt;
@@ -37,7 +36,7 @@ impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
/// Check that number of connections to the endpoint is below `max_rps` rps.
pub(crate) fn check(&self, key: K, n: u32) -> bool {
let now = Instant::now();
let now = self.config.now();
if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
self.do_gc(now);
@@ -48,10 +47,17 @@ impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
.entry(key)
.or_insert_with(|| LeakyBucketState { empty_at: now });
entry.add_tokens(&self.config, now, n as f64).is_ok()
entry
.add_tokens_fast(
self.config.bucket_width,
now,
now,
self.config.cost * (n as u64),
)
.is_ok()
}
fn do_gc(&self, now: Instant) {
fn do_gc(&self, now: InstantNs) {
info!(
"cleaning up bucket rate limiter, current size = {}",
self.map.len()
@@ -90,7 +96,7 @@ mod tests {
use std::time::Duration;
use tokio::time::Instant;
use utils::leaky_bucket::LeakyBucketState;
use utils::leaky_bucket::{LeakyBucketState, Ns};
use super::LeakyBucketConfig;
@@ -98,11 +104,11 @@ mod tests {
async fn check() {
let config: utils::leaky_bucket::LeakyBucketConfig =
LeakyBucketConfig::new(500.0, 2000.0).into();
assert_eq!(config.cost, Duration::from_millis(2));
assert_eq!(config.bucket_width, Duration::from_secs(4));
assert_eq!(config.cost, Ns::from(Duration::from_millis(2)));
assert_eq!(config.bucket_width, Ns::from(Duration::from_secs(4)));
let mut bucket = LeakyBucketState {
empty_at: Instant::now(),
empty_at: config.now(),
};
// should work for 2000 requests this second
@@ -110,7 +116,7 @@ mod tests {
bucket.add_tokens(&config, Instant::now(), 1.0).unwrap();
}
bucket.add_tokens(&config, Instant::now(), 1.0).unwrap_err();
assert_eq!(bucket.empty_at - Instant::now(), config.bucket_width);
assert_eq!(bucket.empty_at - config.now(), config.bucket_width);
// in 1ms we should drain 0.5 tokens.
// make sure we don't lose any tokens