From b1e3161d4e0f602e0f56efa854a171312c00d6a1 Mon Sep 17 00:00:00 2001 From: David Freifeld Date: Thu, 26 Jun 2025 14:32:32 -0700 Subject: [PATCH] Satisfy `cargo clippy` lints, simplify shrinking API --- libs/neon-shmem/src/hash.rs | 193 ++++++++++++++++++------------ libs/neon-shmem/src/hash/core.rs | 24 ++-- libs/neon-shmem/src/hash/entry.rs | 28 +++-- libs/neon-shmem/src/hash/tests.rs | 40 +------ libs/neon-shmem/src/shmem.rs | 92 +++++++------- 5 files changed, 200 insertions(+), 177 deletions(-) diff --git a/libs/neon-shmem/src/hash.rs b/libs/neon-shmem/src/hash.rs index a3d465db93..36fbb1112c 100644 --- a/libs/neon-shmem/src/hash.rs +++ b/libs/neon-shmem/src/hash.rs @@ -1,8 +1,8 @@ -//! Resizable hash table implementation on top of byte-level storage (either `shmem` or fixed byte array). +//! Resizable hash table implementation on top of byte-level storage (either a [`ShmemHandle`] or a fixed byte array). //! //! This hash table has two major components: the bucket array and the dictionary. Each bucket within the -//! bucket array contains a Option<(K, V)> and an index of another bucket. In this way there is both an -//! implicit freelist within the bucket array (None buckets point to other None entries) and various hash +//! bucket array contains a `Option<(K, V)>` and an index of another bucket. In this way there is both an +//! implicit freelist within the bucket array (`None` buckets point to other `None` entries) and various hash //! chains within the bucket array (a Some bucket will point to other Some buckets that had the same hash). //! //! Buckets are never moved unless they are within a region that is being shrunk, and so the actual hash- @@ -10,14 +10,15 @@ //! within the dictionary is decided based on its hash, the data is inserted into an empty bucket based //! off of the freelist, and then the index of said bucket is placed in the dictionary. //! -//! This map is resizable (if initialized on top of a `ShmemHandle`). Both growing and shrinking happen +//! This map is resizable (if initialized on top of a [`ShmemHandle`]). Both growing and shrinking happen //! in-place and are at a high level achieved by expanding/reducing the bucket array and rebuilding the //! dictionary by rehashing all keys. use std::hash::{Hash, BuildHasher}; use std::mem::MaybeUninit; +use std::default::Default; -use crate::shmem::ShmemHandle; +use crate::{shmem, shmem::ShmemHandle}; mod core; pub mod entry; @@ -28,11 +29,13 @@ mod tests; use core::{Bucket, CoreHashMap, INVALID_POS}; use entry::{Entry, OccupiedEntry}; -/// Builder for a `HashMapAccess`. +/// Builder for a [`HashMapAccess`]. +#[must_use] pub struct HashMapInit<'a, K, V, S = rustc_hash::FxBuildHasher> { shmem_handle: Option, shared_ptr: *mut HashMapShared<'a, K, V>, shared_size: usize, + shrink_mode: HashMapShrinkMode, hasher: S, num_buckets: u32, } @@ -42,12 +45,34 @@ pub struct HashMapAccess<'a, K, V, S = rustc_hash::FxBuildHasher> { shmem_handle: Option, shared_ptr: *mut HashMapShared<'a, K, V>, hasher: S, + shrink_mode: HashMapShrinkMode, } -unsafe impl<'a, K: Sync, V: Sync, S> Sync for HashMapAccess<'a, K, V, S> {} -unsafe impl<'a, K: Send, V: Send, S> Send for HashMapAccess<'a, K, V, S> {} +/// Enum specifying what behavior to have surrounding occupied entries in what is +/// about-to-be-shrinked space during a call to [`HashMapAccess::finish_shrink`]. +#[derive(PartialEq, Eq)] +pub enum HashMapShrinkMode { + /// Remap entry to the range of buckets that will remain after shrinking. + /// + /// Requires that caller has left enough room within the map such that this is possible. + Remap, + /// Remove any entries remaining in soon to be deallocated space. + /// + /// Only really useful if you legitimately do not care what entries are removed. + /// Should primarily be used for testing. + Remove, +} -impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> { +impl Default for HashMapShrinkMode { + fn default() -> Self { + Self::Remap + } +} + +unsafe impl Sync for HashMapAccess<'_, K, V, S> {} +unsafe impl Send for HashMapAccess<'_, K, V, S> {} + +impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> { pub fn with_hasher(self, hasher: T) -> HashMapInit<'a, K, V, T> { HashMapInit { hasher, @@ -55,9 +80,14 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> { shared_ptr: self.shared_ptr, shared_size: self.shared_size, num_buckets: self.num_buckets, + shrink_mode: self.shrink_mode, } } + pub fn with_shrink_mode(self, mode: HashMapShrinkMode) -> Self { + Self { shrink_mode: mode, ..self } + } + /// Loosely (over)estimate the size needed to store a hash table with `num_buckets` buckets. pub fn estimate_size(num_buckets: u32) -> usize { // add some margin to cover alignment etc. @@ -98,11 +128,12 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> { HashMapAccess { shmem_handle: self.shmem_handle, shared_ptr: self.shared_ptr, + shrink_mode: self.shrink_mode, hasher: self.hasher, } } - /// Initialize a table for reading. Currently identical to `attach_writer`. + /// Initialize a table for reading. Currently identical to [`HashMapInit::attach_writer`]. pub fn attach_reader(self) -> HashMapAccess<'a, K, V, S> { self.attach_writer() } @@ -114,7 +145,7 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> { /// relies on the memory layout! The data structures are laid out in the contiguous shared memory /// area as follows: /// -/// HashMapShared +/// [`HashMapShared`] /// [buckets] /// [dictionary] /// @@ -131,18 +162,22 @@ where pub fn with_fixed( num_buckets: u32, area: &'a mut [MaybeUninit], - ) -> HashMapInit<'a, K, V> { + ) -> Self { Self { num_buckets, shmem_handle: None, shared_ptr: area.as_mut_ptr().cast(), shared_size: area.len(), - hasher: rustc_hash::FxBuildHasher::default(), + shrink_mode: HashMapShrinkMode::default(), + hasher: rustc_hash::FxBuildHasher, } } /// Place a new hash map in the given shared memory area - pub fn with_shmem(num_buckets: u32, shmem: ShmemHandle) -> HashMapInit<'a, K, V> { + /// + /// # Panics + /// Will panic on failure to resize area to expected map size. + pub fn with_shmem(num_buckets: u32, shmem: ShmemHandle) -> Self { let size = Self::estimate_size(num_buckets); shmem .set_size(size) @@ -152,12 +187,13 @@ where shared_ptr: shmem.data_ptr.as_ptr().cast(), shmem_handle: Some(shmem), shared_size: size, - hasher: rustc_hash::FxBuildHasher::default() + shrink_mode: HashMapShrinkMode::default(), + hasher: rustc_hash::FxBuildHasher } } /// Make a resizable hash map within a new shared memory area with the given name. - pub fn new_resizeable_named(num_buckets: u32, max_buckets: u32, name: &str) -> HashMapInit<'a, K, V> { + pub fn new_resizeable_named(num_buckets: u32, max_buckets: u32, name: &str) -> Self { let size = Self::estimate_size(num_buckets); let max_size = Self::estimate_size(max_buckets); let shmem = ShmemHandle::new(name, size, max_size) @@ -168,16 +204,17 @@ where shared_ptr: shmem.data_ptr.as_ptr().cast(), shmem_handle: Some(shmem), shared_size: size, - hasher: rustc_hash::FxBuildHasher::default() + shrink_mode: HashMapShrinkMode::default(), + hasher: rustc_hash::FxBuildHasher } } /// Make a resizable hash map within a new anonymous shared memory area. - pub fn new_resizeable(num_buckets: u32, max_buckets: u32) -> HashMapInit<'a, K, V> { + pub fn new_resizeable(num_buckets: u32, max_buckets: u32) -> Self { use std::sync::atomic::{AtomicUsize, Ordering}; - const COUNTER: AtomicUsize = AtomicUsize::new(0); + static COUNTER: AtomicUsize = AtomicUsize::new(0); let val = COUNTER.fetch_add(1, Ordering::Relaxed); - let name = format!("neon_shmem_hmap{}", val); + let name = format!("neon_shmem_hmap{val}"); Self::new_resizeable_named(num_buckets, max_buckets, &name) } } @@ -214,7 +251,7 @@ where e.remove(); } Entry::Vacant(_) => {} - }; + } } /// Optionally return the entry for a bucket at a given index if it exists. @@ -249,7 +286,7 @@ where let map = unsafe { self.shared_ptr.as_ref() }.unwrap(); let origin = map.inner.buckets.as_ptr(); - let idx = (val_ptr as usize - origin as usize) / (size_of::>() as usize); + let idx = (val_ptr as usize - origin as usize) / size_of::>(); assert!(idx < map.inner.buckets.len()); idx @@ -265,14 +302,14 @@ where pub fn clear(&mut self) { let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); let inner = &mut map.inner; - inner.clear() + inner.clear(); } /// Perform an in-place rehash of some region (0..`rehash_buckets`) of the table and reset /// the `buckets` and `dictionary` slices to be as long as `num_buckets`. Resets the freelist /// in the process. fn rehash_dict( - &mut self, + &self, inner: &mut CoreHashMap<'a, K, V>, buckets_ptr: *mut core::Bucket, end_ptr: *mut u8, @@ -293,22 +330,21 @@ where buckets = std::slice::from_raw_parts_mut(buckets_ptr, num_buckets as usize); dictionary = std::slice::from_raw_parts_mut(dictionary_ptr, dictionary_size); - (dictionary_ptr, dictionary_size) } - for i in 0..dictionary.len() { - dictionary[i] = INVALID_POS; + for e in dictionary.iter_mut() { + *e = INVALID_POS; } - for i in 0..rehash_buckets as usize { - if buckets[i].inner.is_none() { - buckets[i].next = inner.free_head; + for (i, bucket) in buckets.iter_mut().enumerate().take(rehash_buckets as usize) { + if bucket.inner.is_none() { + bucket.next = inner.free_head; inner.free_head = i as u32; continue; } - let hash = self.hasher.hash_one(&buckets[i].inner.as_ref().unwrap().0); + let hash = self.hasher.hash_one(&bucket.inner.as_ref().unwrap().0); let pos: usize = (hash % dictionary.len() as u64) as usize; - buckets[i].next = dictionary[pos]; + bucket.next = dictionary[pos]; dictionary[pos] = i as u32; } @@ -322,7 +358,7 @@ where let inner = &mut map.inner; let num_buckets = inner.get_num_buckets() as u32; let size_bytes = HashMapInit::::estimate_size(num_buckets); - let end_ptr: *mut u8 = unsafe { (self.shared_ptr as *mut u8).add(size_bytes) }; + let end_ptr: *mut u8 = unsafe { self.shared_ptr.byte_add(size_bytes).cast() }; let buckets_ptr = inner.buckets.as_mut_ptr(); self.rehash_dict(inner, buckets_ptr, end_ptr, num_buckets, num_buckets); } @@ -332,14 +368,18 @@ where /// 1. Grows the underlying shared memory area /// 2. Initializes new buckets and overwrites the current dictionary /// 3. Rehashes the dictionary - pub fn grow(&mut self, num_buckets: u32) -> Result<(), crate::shmem::Error> { + /// + /// # Panics + /// Panics if called on a map initialized with [`HashMapInit::with_fixed`]. + /// + /// # Errors + /// Returns an [`shmem::Error`] if any errors occur resizing the memory region. + pub fn grow(&mut self, num_buckets: u32) -> Result<(), shmem::Error> { let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); let inner = &mut map.inner; let old_num_buckets = inner.buckets.len() as u32; - if num_buckets < old_num_buckets { - panic!("grow called with a smaller number of buckets"); - } + assert!(num_buckets >= old_num_buckets, "grow called with a smaller number of buckets"); if num_buckets == old_num_buckets { return Ok(()); } @@ -352,15 +392,15 @@ where shmem_handle.set_size(size_bytes)?; let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) }; - // Initialize new buckets. The new buckets are linked to the free list. NB: This overwrites - // the dictionary! + // Initialize new buckets. The new buckets are linked to the free list. + // NB: This overwrites the dictionary! let buckets_ptr = inner.buckets.as_mut_ptr(); unsafe { for i in old_num_buckets..num_buckets { - let bucket_ptr = buckets_ptr.add(i as usize); - bucket_ptr.write(core::Bucket { + let bucket = buckets_ptr.add(i as usize); + bucket.write(core::Bucket { next: if i < num_buckets-1 { - i as u32 + 1 + i + 1 } else { inner.free_head }, @@ -376,11 +416,16 @@ where } /// Begin a shrink, limiting all new allocations to be in buckets with index below `num_buckets`. + /// + /// # Panics + /// Panics if called on a map initialized with [`HashMapInit::with_fixed`] or if `num_buckets` is + /// greater than the number of buckets in the map. pub fn begin_shrink(&mut self, num_buckets: u32) { let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - if num_buckets > map.inner.get_num_buckets() as u32 { - panic!("shrink called with a larger number of buckets"); - } + assert!( + num_buckets <= map.inner.get_num_buckets() as u32, + "shrink called with a larger number of buckets" + ); _ = self .shmem_handle .as_ref() @@ -388,47 +433,47 @@ where map.inner.alloc_limit = num_buckets; } - /// Returns whether a shrink operation is currently in progress. - pub fn is_shrinking(&self) -> bool { + /// If a shrink operation is underway, returns the target size of the map. Otherwise, returns None. + pub fn shrink_goal(&self) -> Option { let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - map.inner.is_shrinking() + let goal = map.inner.alloc_limit; + if goal == INVALID_POS { None } else { Some(goal as usize) } } - /// Returns how many entries need to be evicted before shrink can complete. - pub fn shrink_remaining(&self) -> usize { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - let inner = &mut map.inner; - if !inner.is_shrinking() { - panic!("shrink_remaining called when no ongoing shrink") - } else { - inner.buckets_in_use - .checked_sub(inner.alloc_limit) - .unwrap_or(0) - as usize - } - } - /// Complete a shrink after caller has evicted entries, removing the unused buckets and rehashing. - pub fn finish_shrink(&mut self) -> Result<(), crate::shmem::Error> { + /// + /// # Panics + /// The following cases result in a panic: + /// - Calling this function on a map initialized with [`HashMapInit::with_fixed`]. + /// - Calling this function on a map when no shrink operation is in progress. + /// - Calling this function on a map with `shrink_mode` set to [`HashMapShrinkMode::Remap`] and + /// [`HashMapAccess::get_num_buckets_in_use`] returns a value higher than [`HashMapAccess::shrink_goal`]. + /// + /// # Errors + /// Returns an [`shmem::Error`] if any errors occur resizing the memory region. + pub fn finish_shrink(&mut self) -> Result<(), shmem::Error> { let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); let inner = &mut map.inner; - if !inner.is_shrinking() { - panic!("called finish_shrink when no shrink is in progress"); - } + assert!(inner.is_shrinking(), "called finish_shrink when no shrink is in progress"); let num_buckets = inner.alloc_limit; if inner.get_num_buckets() == num_buckets as usize { return Ok(()); - } else if inner.buckets_in_use > num_buckets { - panic!("called finish_shrink before enough entries were removed"); - } + } - for i in (num_buckets as usize)..inner.buckets.len() { - if let Some((k, v)) = inner.buckets[i].inner.take() { - // alloc bucket increases buckets in use, so need to decrease since we're just moving - inner.buckets_in_use -= 1; - inner.alloc_bucket(k, v).unwrap(); + if self.shrink_mode == HashMapShrinkMode::Remap { + assert!( + inner.buckets_in_use <= num_buckets, + "called finish_shrink before enough entries were removed" + ); + + for i in (num_buckets as usize)..inner.buckets.len() { + if let Some((k, v)) = inner.buckets[i].inner.take() { + // alloc bucket increases buckets in use, so need to decrease since we're just moving + inner.buckets_in_use -= 1; + inner.alloc_bucket(k, v).unwrap(); + } } } diff --git a/libs/neon-shmem/src/hash/core.rs b/libs/neon-shmem/src/hash/core.rs index 22c44f20ac..b2cf788d21 100644 --- a/libs/neon-shmem/src/hash/core.rs +++ b/libs/neon-shmem/src/hash/core.rs @@ -5,6 +5,7 @@ use std::mem::MaybeUninit; use crate::hash::entry::{Entry, OccupiedEntry, PrevPos, VacantEntry}; +/// Invalid position within the map (either within the dictionary or bucket array). pub(crate) const INVALID_POS: u32 = u32::MAX; /// Fundamental storage unit within the hash table. Either empty or contains a key-value pair. @@ -18,13 +19,13 @@ pub(crate) struct Bucket { /// Core hash table implementation. pub(crate) struct CoreHashMap<'a, K, V> { - /// Dictionary used to map hashes to bucket indices. + /// Dictionary used to map hashes to bucket indices. pub(crate) dictionary: &'a mut [u32], /// Buckets containing key-value pairs. pub(crate) buckets: &'a mut [Bucket], /// Head of the freelist. pub(crate) free_head: u32, - /// Maximum index of a bucket allowed to be allocated. INVALID_POS if no limit. + /// Maximum index of a bucket allowed to be allocated. [`INVALID_POS`] if no limit. pub(crate) alloc_limit: u32, /// The number of currently occupied buckets. pub(crate) buckets_in_use: u32, @@ -36,10 +37,7 @@ pub(crate) struct CoreHashMap<'a, K, V> { #[derive(Debug)] pub struct FullError(); -impl<'a, K: Hash + Eq, V> CoreHashMap<'a, K, V> -where - K: Clone + Hash + Eq, -{ +impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> { const FILL_FACTOR: f32 = 0.60; /// Estimate the size of data contained within the the hash map. @@ -59,7 +57,7 @@ where pub fn new( buckets: &'a mut [MaybeUninit>], dictionary: &'a mut [MaybeUninit], - ) -> CoreHashMap<'a, K, V> { + ) -> Self { // Initialize the buckets for i in 0..buckets.len() { buckets[i].write(Bucket { @@ -73,8 +71,8 @@ where } // Initialize the dictionary - for i in 0..dictionary.len() { - dictionary[i].write(INVALID_POS); + for e in dictionary.iter_mut() { + e.write(INVALID_POS); } // TODO: use std::slice::assume_init_mut() once it stabilizes @@ -84,7 +82,7 @@ where std::slice::from_raw_parts_mut(dictionary.as_mut_ptr().cast(), dictionary.len()) }; - CoreHashMap { + Self { dictionary, buckets, free_head: 0, @@ -105,13 +103,13 @@ where let bucket = &self.buckets[next as usize]; let (bucket_key, bucket_value) = bucket.inner.as_ref().expect("entry is in use"); if bucket_key == key { - return Some(&bucket_value); + return Some(bucket_value); } next = bucket.next; } } - /// Get the `Entry` associated with a key given hash. This should be used for updates/inserts. + /// Get the [`Entry`] associated with a key given hash. This should be used for updates/inserts. pub fn entry_with_hash(&mut self, key: K, hash: u64) -> Entry<'a, '_, K, V> { let dict_pos = hash as usize % self.dictionary.len(); let first = self.dictionary[dict_pos]; @@ -236,7 +234,7 @@ where bucket.next = INVALID_POS; bucket.inner = Some((key, value)); - return Ok(pos); + Ok(pos) } } diff --git a/libs/neon-shmem/src/hash/entry.rs b/libs/neon-shmem/src/hash/entry.rs index 24c124189b..5231061b8e 100644 --- a/libs/neon-shmem/src/hash/entry.rs +++ b/libs/neon-shmem/src/hash/entry.rs @@ -1,4 +1,4 @@ -//! Like std::collections::hash_map::Entry; +//! Equivalent of [`std::collections::hash_map::Entry`] for this hashmap. use crate::hash::core::{CoreHashMap, FullError, INVALID_POS}; @@ -30,11 +30,11 @@ pub struct OccupiedEntry<'a, 'b, K, V> { pub(crate) _key: K, /// The index of the previous entry in the chain. pub(crate) prev_pos: PrevPos, - /// The position of the bucket in the CoreHashMap's buckets array. + /// The position of the bucket in the [`CoreHashMap`] bucket array. pub(crate) bucket_pos: u32, } -impl<'a, 'b, K, V> OccupiedEntry<'a, 'b, K, V> { +impl OccupiedEntry<'_, '_, K, V> { pub fn get(&self) -> &V { &self.map.buckets[self.bucket_pos as usize] .inner @@ -55,20 +55,25 @@ impl<'a, 'b, K, V> OccupiedEntry<'a, 'b, K, V> { pub fn insert(&mut self, value: V) -> V { let bucket = &mut self.map.buckets[self.bucket_pos as usize]; // This assumes inner is Some, which it must be for an OccupiedEntry - let old_value = mem::replace(&mut bucket.inner.as_mut().unwrap().1, value); - old_value + mem::replace(&mut bucket.inner.as_mut().unwrap().1, value) } /// Removes the entry from the hash map, returning the value originally stored within it. + /// + /// # Panics + /// Panics if the `prev_pos` field is equal to [`PrevPos::Unknown`]. In practice, this means + /// the entry was obtained via calling something like [`CoreHashMap::entry_at_bucket`]. pub fn remove(self) -> V { // CoreHashMap::remove returns Option<(K, V)>. We know it's Some for an OccupiedEntry. let bucket = &mut self.map.buckets[self.bucket_pos as usize]; // unlink it from the chain match self.prev_pos { - PrevPos::First(dict_pos) => self.map.dictionary[dict_pos as usize] = bucket.next, + PrevPos::First(dict_pos) => { + self.map.dictionary[dict_pos as usize] = bucket.next; + }, PrevPos::Chained(bucket_pos) => { - self.map.buckets[bucket_pos as usize].next = bucket.next + self.map.buckets[bucket_pos as usize].next = bucket.next; }, PrevPos::Unknown => panic!("can't safely remove entry with unknown previous entry"), } @@ -80,7 +85,7 @@ impl<'a, 'b, K, V> OccupiedEntry<'a, 'b, K, V> { self.map.free_head = self.bucket_pos; self.map.buckets_in_use -= 1; - return old_value.unwrap().1; + old_value.unwrap().1 } } @@ -94,8 +99,11 @@ pub struct VacantEntry<'a, 'b, K, V> { pub(crate) dict_pos: u32, } -impl<'a, 'b, K: Clone + Hash + Eq, V> VacantEntry<'a, 'b, K, V> { +impl<'b, K: Clone + Hash + Eq, V> VacantEntry<'_, 'b, K, V> { /// Insert a value into the vacant entry, finding and populating an empty bucket in the process. + /// + /// # Errors + /// Will return [`FullError`] if there are no unoccupied buckets in the map. pub fn insert(self, value: V) -> Result<&'b mut V, FullError> { let pos = self.map.alloc_bucket(self.key, value)?; if pos == INVALID_POS { @@ -106,6 +114,6 @@ impl<'a, 'b, K: Clone + Hash + Eq, V> VacantEntry<'a, 'b, K, V> { self.map.dictionary[self.dict_pos as usize] = pos; let result = &mut self.map.buckets[pos as usize].inner.as_mut().unwrap().1; - return Ok(result); + Ok(result) } } diff --git a/libs/neon-shmem/src/hash/tests.rs b/libs/neon-shmem/src/hash/tests.rs index 8a6e8a0a29..209db599b5 100644 --- a/libs/neon-shmem/src/hash/tests.rs +++ b/libs/neon-shmem/src/hash/tests.rs @@ -1,14 +1,11 @@ use std::collections::BTreeMap; use std::collections::HashSet; -use std::fmt::{Debug, Formatter}; -use std::mem::uninitialized; +use std::fmt::Debug; use std::mem::MaybeUninit; -use std::sync::atomic::{AtomicUsize, Ordering}; use crate::hash::HashMapAccess; use crate::hash::HashMapInit; use crate::hash::Entry; -use crate::shmem::ShmemHandle; use rand::seq::SliceRandom; use rand::{Rng, RngCore}; @@ -98,30 +95,6 @@ fn sparse() { test_inserts(&keys); } -struct TestValue(AtomicUsize); - -impl TestValue { - fn new(val: usize) -> TestValue { - TestValue(AtomicUsize::new(val)) - } - - fn load(&self) -> usize { - self.0.load(Ordering::Relaxed) - } -} - -impl Clone for TestValue { - fn clone(&self) -> TestValue { - TestValue::new(self.load()) - } -} - -impl Debug for TestValue { - fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { - write!(fmt, "{:?}", self.load()) - } -} - #[derive(Clone, Debug)] struct TestOp(TestKey, Option); @@ -177,7 +150,7 @@ fn do_deletes( writer: &mut HashMapAccess, shadow: &mut BTreeMap, ) { - for i in 0..num_ops { + for _ in 0..num_ops { let (k, _) = shadow.pop_first().unwrap(); let hash = writer.get_hash_value(&k); writer.remove_with_hash(&k, hash); @@ -187,7 +160,6 @@ fn do_deletes( fn do_shrink( writer: &mut HashMapAccess, shadow: &mut BTreeMap, - from: u32, to: u32 ) { writer.begin_shrink(to); @@ -195,7 +167,7 @@ fn do_shrink( let (k, _) = shadow.pop_first().unwrap(); let hash = writer.get_hash_value(&k); let entry = writer.entry_with_hash(k, hash); - if let Entry::Occupied(mut e) = entry { + if let Entry::Occupied(e) = entry { e.remove(); } } @@ -260,7 +232,7 @@ fn test_shrink() { let mut rng = rand::rng(); do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng); - do_shrink(&mut writer, &mut shadow, 1500, 1000); + do_shrink(&mut writer, &mut shadow, 1000); do_deletes(500, &mut writer, &mut shadow); do_random_ops(10000, 500, 0.75, &mut writer, &mut shadow, &mut rng); assert!(writer.get_num_buckets_in_use() <= 1000); @@ -276,13 +248,13 @@ fn test_shrink_grow_seq() { do_random_ops(500, 1000, 0.1, &mut writer, &mut shadow, &mut rng); eprintln!("Shrinking to 750"); - do_shrink(&mut writer, &mut shadow, 1000, 750); + do_shrink(&mut writer, &mut shadow, 750); do_random_ops(200, 1000, 0.5, &mut writer, &mut shadow, &mut rng); eprintln!("Growing to 1500"); writer.grow(1500).unwrap(); do_random_ops(600, 1500, 0.1, &mut writer, &mut shadow, &mut rng); eprintln!("Shrinking to 200"); - do_shrink(&mut writer, &mut shadow, 1500, 200); + do_shrink(&mut writer, &mut shadow, 200); do_deletes(100, &mut writer, &mut shadow); do_random_ops(50, 1500, 0.25, &mut writer, &mut shadow, &mut rng); eprintln!("Growing to 10k"); diff --git a/libs/neon-shmem/src/shmem.rs b/libs/neon-shmem/src/shmem.rs index 21b1454b10..7c7285f67e 100644 --- a/libs/neon-shmem/src/shmem.rs +++ b/libs/neon-shmem/src/shmem.rs @@ -12,14 +12,14 @@ use nix::sys::mman::mmap as nix_mmap; use nix::sys::mman::munmap as nix_munmap; use nix::unistd::ftruncate as nix_ftruncate; -/// ShmemHandle represents a shared memory area that can be shared by processes over fork(). -/// Unlike shared memory allocated by Postgres, this area is resizable, up to 'max_size' that's +/// `ShmemHandle` represents a shared memory area that can be shared by processes over `fork()`. +/// Unlike shared memory allocated by Postgres, this area is resizable, up to `max_size` that's /// specified at creation. /// -/// The area is backed by an anonymous file created with memfd_create(). The full address space for -/// 'max_size' is reserved up-front with mmap(), but whenever you call [`ShmemHandle::set_size`], +/// The area is backed by an anonymous file created with `memfd_create()`. The full address space for +/// `max_size` is reserved up-front with `mmap()`, but whenever you call [`ShmemHandle::set_size`], /// the underlying file is resized. Do not access the area beyond the current size. Currently, that -/// will cause the file to be expanded, but we might use mprotect() etc. to enforce that in the +/// will cause the file to be expanded, but we might use `mprotect()` etc. to enforce that in the /// future. pub struct ShmemHandle { /// memfd file descriptor @@ -38,7 +38,7 @@ pub struct ShmemHandle { struct SharedStruct { max_size: usize, - /// Current size of the backing file. The high-order bit is used for the RESIZE_IN_PROGRESS flag + /// Current size of the backing file. The high-order bit is used for the [`RESIZE_IN_PROGRESS`] flag. current_size: AtomicUsize, } @@ -46,7 +46,7 @@ const RESIZE_IN_PROGRESS: usize = 1 << 63; const HEADER_SIZE: usize = std::mem::size_of::(); -/// Error type returned by the ShmemHandle functions. +/// Error type returned by the [`ShmemHandle`] functions. #[derive(thiserror::Error, Debug)] #[error("{msg}: {errno}")] pub struct Error { @@ -55,8 +55,8 @@ pub struct Error { } impl Error { - fn new(msg: &str, errno: Errno) -> Error { - Error { + fn new(msg: &str, errno: Errno) -> Self { + Self { msg: msg.to_string(), errno, } @@ -65,11 +65,11 @@ impl Error { impl ShmemHandle { /// Create a new shared memory area. To communicate between processes, the processes need to be - /// fork()'d after calling this, so that the ShmemHandle is inherited by all processes. + /// `fork()`'d after calling this, so that the `ShmemHandle` is inherited by all processes. /// - /// If the ShmemHandle is dropped, the memory is unmapped from the current process. Other + /// If the `ShmemHandle` is dropped, the memory is unmapped from the current process. Other /// processes can continue using it, however. - pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result { + pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result { // create the backing anonymous file. let fd = create_backing_file(name)?; @@ -80,17 +80,17 @@ impl ShmemHandle { fd: OwnedFd, initial_size: usize, max_size: usize, - ) -> Result { - // We reserve the high-order bit for the RESIZE_IN_PROGRESS flag, and the actual size + ) -> Result { + // We reserve the high-order bit for the `RESIZE_IN_PROGRESS` flag, and the actual size // is a little larger than this because of the SharedStruct header. Make the upper limit // somewhat smaller than that, because with anything close to that, you'll run out of // memory anyway. - if max_size >= 1 << 48 { - panic!("max size {} too large", max_size); - } - if initial_size > max_size { - panic!("initial size {initial_size} larger than max size {max_size}"); - } + assert!(max_size < 1 << 48, "max size {max_size} too large"); + + assert!( + initial_size <= max_size, + "initial size {initial_size} larger than max size {max_size}" + ); // The actual initial / max size is the one given by the caller, plus the size of // 'SharedStruct'. @@ -110,7 +110,7 @@ impl ShmemHandle { 0, ) } - .map_err(|e| Error::new("mmap failed: {e}", e))?; + .map_err(|e| Error::new("mmap failed", e))?; // Reserve space for the initial size enlarge_file(fd.as_fd(), initial_size as u64)?; @@ -121,13 +121,13 @@ impl ShmemHandle { shared.write(SharedStruct { max_size: max_size.into(), current_size: AtomicUsize::new(initial_size), - }) - }; + }); + } // The user data begins after the header let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) }; - Ok(ShmemHandle { + Ok(Self { fd, max_size: max_size.into(), shared_ptr: shared, @@ -140,28 +140,28 @@ impl ShmemHandle { unsafe { self.shared_ptr.as_ref() } } - /// Resize the shared memory area. 'new_size' must not be larger than the 'max_size' specified + /// Resize the shared memory area. `new_size` must not be larger than the `max_size` specified /// when creating the area. /// /// This may only be called from one process/thread concurrently. We detect that case - /// and return an Error. + /// and return an [`shmem::Error`](Error). pub fn set_size(&self, new_size: usize) -> Result<(), Error> { let new_size = new_size + HEADER_SIZE; let shared = self.shared(); - if new_size > self.max_size { - panic!( - "new size ({} is greater than max size ({})", - new_size, self.max_size - ); - } - assert_eq!(self.max_size, shared.max_size); + assert!( + new_size <= self.max_size, + "new size ({new_size}) is greater than max size ({})", + self.max_size + ); - // Lock the area by setting the bit in 'current_size' + assert_eq!(self.max_size, shared.max_size); + + // Lock the area by setting the bit in `current_size` // // Ordering::Relaxed would probably be sufficient here, as we don't access any other memory - // and the posix_fallocate/ftruncate call is surely a synchronization point anyway. But - // since this is not performance-critical, better safe than sorry . + // and the `posix_fallocate`/`ftruncate` call is surely a synchronization point anyway. But + // since this is not performance-critical, better safe than sorry. let mut old_size = shared.current_size.load(Ordering::Acquire); loop { if (old_size & RESIZE_IN_PROGRESS) != 0 { @@ -188,7 +188,7 @@ impl ShmemHandle { use std::cmp::Ordering::{Equal, Greater, Less}; match new_size.cmp(&old_size) { Less => nix_ftruncate(&self.fd, new_size as i64).map_err(|e| { - Error::new("could not shrink shmem segment, ftruncate failed: {e}", e) + Error::new("could not shrink shmem segment, ftruncate failed", e) }), Equal => Ok(()), Greater => enlarge_file(self.fd.as_fd(), new_size as u64), @@ -206,8 +206,8 @@ impl ShmemHandle { /// Returns the current user-visible size of the shared memory segment. /// - /// NOTE: a concurrent set_size() call can change the size at any time. It is the caller's - /// responsibility not to access the area beyond the current size. + /// NOTE: a concurrent [`ShmemHandle::set_size()`] call can change the size at any time. + /// It is the caller's responsibility not to access the area beyond the current size. pub fn current_size(&self) -> usize { let total_current_size = self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS; @@ -224,23 +224,23 @@ impl Drop for ShmemHandle { } } -/// Create a "backing file" for the shared memory area. On Linux, use memfd_create(), to create an +/// Create a "backing file" for the shared memory area. On Linux, use `memfd_create()`, to create an /// anonymous in-memory file. One macos, fall back to a regular file. That's good enough for /// development and testing, but in production we want the file to stay in memory. /// -/// disable 'unused_variables' warnings, because in the macos path, 'name' is unused. +/// Disable unused variables warnings because `name` is unused in the macos path. #[allow(unused_variables)] fn create_backing_file(name: &str) -> Result { #[cfg(not(target_os = "macos"))] { nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty()) - .map_err(|e| Error::new("memfd_create failed: {e}", e)) + .map_err(|e| Error::new("memfd_create failed", e)) } #[cfg(target_os = "macos")] { let file = tempfile::tempfile().map_err(|e| { Error::new( - "could not create temporary file to back shmem area: {e}", + "could not create temporary file to back shmem area", nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)), ) })?; @@ -255,7 +255,7 @@ fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> { { nix::fcntl::posix_fallocate(fd, 0, size as i64).map_err(|e| { Error::new( - "could not grow shmem segment, posix_fallocate failed: {e}", + "could not grow shmem segment, posix_fallocate failed", e, ) }) @@ -264,7 +264,7 @@ fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> { #[cfg(target_os = "macos")] { nix::unistd::ftruncate(fd, size as i64) - .map_err(|e| Error::new("could not grow shmem segment, ftruncate failed: {e}", e)) + .map_err(|e| Error::new("could not grow shmem segment, ftruncate failed", e)) } } @@ -330,7 +330,7 @@ mod tests { Ok(()) } - /// This is used in tests to coordinate between test processes. It's like std::sync::Barrier, + /// This is used in tests to coordinate between test processes. It's like `std::sync::Barrier`, /// but is stored in the shared memory area and works across processes. It's implemented by /// polling, because e.g. standard rust mutexes are not guaranteed to work across processes. struct SimpleBarrier {