diff --git a/Cargo.lock b/Cargo.lock index 5eac648fd9..5ef94063b5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1598,6 +1598,20 @@ dependencies = [ "parking_lot_core 0.9.8", ] +[[package]] +name = "dashmap" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23fadfd577acfd4485fb258011b0fd080882ea83359b6fd41304900b94ccf487" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core 0.9.8", +] + [[package]] name = "data-encoding" version = "2.4.0" @@ -2848,7 +2862,7 @@ version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4644821e1c3d7a560fe13d842d13f587c07348a1a05d3a797152d41c90c56df2" dependencies = [ - "dashmap", + "dashmap 5.5.0", "hashbrown 0.13.2", ] @@ -4296,7 +4310,7 @@ dependencies = [ "clap", "consumption_metrics", "crossbeam-deque", - "dashmap", + "dashmap 6.0.0", "env_logger", "fallible-iterator", "framed-websockets", diff --git a/Cargo.toml b/Cargo.toml index 8fddaaef12..173d5c2ccd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,7 +77,7 @@ const_format = "0.2" crc32c = "0.6" crossbeam-deque = "0.8.5" crossbeam-utils = "0.8.5" -dashmap = { version = "5.5.0", features = ["raw-api"] } +dashmap = { version = "6.0", features = ["raw-api"] } either = "1.8" enum-map = "2.4.2" enumset = "1.0.12" diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 10cc4ceee1..77e5937215 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -305,7 +305,7 @@ impl ProjectInfoCacheImpl { // acquire a random shard lock let mut removed = 0; let shard = self.project2ep.shards()[shard].write(); - for (_, endpoints) in shard.iter() { + for (_, endpoints) in crate::rawtable::iter(&*shard) { for endpoint in endpoints.get().iter() { self.cache.remove(endpoint); removed += 1; diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index 915c2ee7a6..846136089a 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -517,11 +517,18 @@ impl ApiLocks { ); let mut lock = shard.write(); let timer = self.metrics.reclamation_lag_seconds.start_timer(); - let count = lock - .extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1) - .count(); + + let mut removed = 0; + crate::rawtable::retain(&mut *lock, |_, semaphore| { + let remove = Arc::strong_count(semaphore.get_mut()) == 1; + if remove { + removed += 1; + } + !remove + }); + drop(lock); - self.metrics.semaphores_unregistered.inc_by(count as u64); + self.metrics.semaphores_unregistered.inc_by(removed as u64); timer.observe(); } } diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index ea92eaaa55..f11d68bac7 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -25,6 +25,7 @@ pub mod parse; pub mod protocol2; pub mod proxy; pub mod rate_limiter; +mod rawtable; pub mod redis; pub mod sasl; pub mod scram; diff --git a/proxy/src/rawtable.rs b/proxy/src/rawtable.rs new file mode 100644 index 0000000000..3a66e852da --- /dev/null +++ b/proxy/src/rawtable.rs @@ -0,0 +1,61 @@ +//! Dashmap moved to using RawTable for the shards. +//! Some of the APIs we used before are unsafe to access, but we can copy the implementations from the safe +//! HashMap wrappers for our needs. + +// Safety info: All implementations here are taken directly from hashbrown HashMap impl. + +use std::marker::PhantomData; + +use hashbrown::raw; + +// taken from https://docs.rs/hashbrown/0.14.5/src/hashbrown/map.rs.html#919-932 +pub fn retain(table: &mut raw::RawTable<(K, V)>, mut f: F) +where + F: FnMut(&K, &mut V) -> bool, +{ + // SAFETY: Here we only use `iter` as a temporary, preventing use-after-free + unsafe { + for item in table.iter() { + let &mut (ref key, ref mut value) = item.as_mut(); + if !f(key, value) { + table.erase(item); + } + } + } +} + +// taken from https://docs.rs/hashbrown/0.14.5/src/hashbrown/map.rs.html#756-764 +pub fn iter(table: &raw::RawTable<(K, V)>) -> impl Iterator + '_ { + pub struct Iter<'a, K, V> { + inner: raw::RawIter<(K, V)>, + marker: PhantomData<(&'a K, &'a V)>, + } + + impl<'a, K, V> Iterator for Iter<'a, K, V> { + type Item = (&'a K, &'a V); + + #[cfg_attr(feature = "inline-more", inline)] + fn next(&mut self) -> Option<(&'a K, &'a V)> { + let x = self.inner.next()?; + // SAFETY: the borrows do not outlive the rawtable + unsafe { + let r = x.as_ref(); + Some((&r.0, &r.1)) + } + } + #[cfg_attr(feature = "inline-more", inline)] + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } + } + + // SAFETY: + // > It is up to the caller to ensure that the RawTable outlives the RawIter + // Here we tie the lifetime of self to the iter. + unsafe { + Iter { + inner: table.iter(), + marker: PhantomData, + } + } +} diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 5fa253acf8..6ea10cb7ca 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -324,7 +324,8 @@ impl GlobalConnPool { .start_timer(); let current_len = shard.len(); let mut clients_removed = 0; - shard.retain(|endpoint, x| { + + crate::rawtable::retain(&mut *shard, |endpoint, x| { // if the current endpoint pool is unique (no other strong or weak references) // then it is currently not in use by any connections. if let Some(pool) = Arc::get_mut(x.get_mut()) {