diff --git a/Cargo.lock b/Cargo.lock index 542d382eab..f987471256 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4135,6 +4135,8 @@ dependencies = [ "criterion", "foldhash", "hashbrown 0.15.4 (git+https://github.com/quantumish/hashbrown.git?rev=6610e6d)", + "libc", + "lock_api", "nix 0.30.1", "rand 0.9.1", "rand_distr 0.5.1", diff --git a/libs/neon-shmem/Cargo.toml b/libs/neon-shmem/Cargo.toml index 8a8af1f734..8ce5b52deb 100644 --- a/libs/neon-shmem/Cargo.toml +++ b/libs/neon-shmem/Cargo.toml @@ -10,6 +10,8 @@ nix.workspace = true workspace_hack = { version = "0.1", path = "../../workspace_hack" } rustc-hash = { version = "2.1.1" } rand = "0.9.1" +libc.workspace = true +lock_api = "0.4.13" [dev-dependencies] criterion = { workspace = true, features = ["html_reports"] } @@ -21,9 +23,10 @@ seahash = "4.1.0" hashbrown = { git = "https://github.com/quantumish/hashbrown.git", rev = "6610e6d" } foldhash = "0.1.5" + [target.'cfg(target_os = "macos")'.dependencies] tempfile = "3.14.0" [[bench]] name = "hmap_resize" -harness = false \ No newline at end of file +harness = false diff --git a/libs/neon-shmem/benches/hmap_resize.rs b/libs/neon-shmem/benches/hmap_resize.rs index 30a3dca296..6b86e7ed27 100644 --- a/libs/neon-shmem/benches/hmap_resize.rs +++ b/libs/neon-shmem/benches/hmap_resize.rs @@ -1,9 +1,7 @@ -use std::hint::black_box; use criterion::{criterion_group, criterion_main, BatchSize, Criterion, BenchmarkId}; use neon_shmem::hash::HashMapAccess; use neon_shmem::hash::HashMapInit; use neon_shmem::hash::entry::Entry; -use neon_shmem::shmem::ShmemHandle; use rand::prelude::*; use rand::distr::{Distribution, StandardUniform}; use std::hash::BuildHasher; @@ -65,14 +63,13 @@ fn apply_op( op: TestOp, map: &mut HashMapAccess, ) { - let hash = map.get_hash_value(&op.0); - let entry = map.entry_with_hash(op.0, hash); + let entry = map.entry(op.0); match op.1 { Some(new) => { match entry { Entry::Occupied(mut e) => Some(e.insert(new)), - Entry::Vacant(e) => { e.insert(new).unwrap(); None }, + Entry::Vacant(e) => { _ = e.insert(new).unwrap(); None }, } }, None => { @@ -184,15 +181,14 @@ fn real_benchs(c: &mut Criterion) { let mut rng = rand::rng(); b.iter_batched( || HashMapInit::new_resizeable(size, size * 2).attach_writer(), - |mut writer| { - for i in 0..ideal_filled { + |writer| { + for _ in 0..ideal_filled { let key: FileCacheKey = rng.random(); let val = FileCacheEntry::dummy(); - let hash = writer.get_hash_value(&key); - let entry = writer.entry_with_hash(key, hash); + let entry = writer.entry(key); std::hint::black_box(match entry { Entry::Occupied(mut e) => { e.insert(val); }, - Entry::Vacant(e) => { e.insert(val).unwrap(); }, + Entry::Vacant(e) => { _ = e.insert(val).unwrap(); }, }) } }, diff --git a/libs/neon-shmem/src/hash.rs b/libs/neon-shmem/src/hash.rs index e97ad51b0d..b46a58faaf 100644 --- a/libs/neon-shmem/src/hash.rs +++ b/libs/neon-shmem/src/hash.rs @@ -1,8 +1,8 @@ -//! Resizable hash table implementation on top of byte-level storage (either `shmem` or fixed byte array). +//! Resizable hash table implementation on top of byte-level storage (either a [`ShmemHandle`] or a fixed byte array). //! //! This hash table has two major components: the bucket array and the dictionary. Each bucket within the -//! bucket array contains a Option<(K, V)> and an index of another bucket. In this way there is both an -//! implicit freelist within the bucket array (None buckets point to other None entries) and various hash +//! bucket array contains a `Option<(K, V)>` and an index of another bucket. 
In this way there is both an +//! implicit freelist within the bucket array (`None` buckets point to other `None` entries) and various hash //! chains within the bucket array (a Some bucket will point to other Some buckets that had the same hash). //! //! Buckets are never moved unless they are within a region that is being shrunk, and so the actual hash- @@ -10,13 +10,14 @@ //! within the dictionary is decided based on its hash, the data is inserted into an empty bucket based //! off of the freelist, and then the index of said bucket is placed in the dictionary. //! -//! This map is resizable (if initialized on top of a `ShmemHandle`). Both growing and shrinking happen +//! This map is resizable (if initialized on top of a [`ShmemHandle`]). Both growing and shrinking happen //! in-place and are at a high level achieved by expanding/reducing the bucket array and rebuilding the //! dictionary by rehashing all keys. use std::hash::{Hash, BuildHasher}; use std::mem::MaybeUninit; +use crate::{shmem, sync::*}; use crate::shmem::ShmemHandle; mod core; @@ -26,12 +27,13 @@ pub mod entry; mod tests; use core::{Bucket, CoreHashMap, INVALID_POS}; -use entry::{Entry, OccupiedEntry}; +use entry::{Entry, OccupiedEntry, VacantEntry, PrevPos}; -/// Builder for a `HashMapAccess`. +/// Builder for a [`HashMapAccess`]. +#[must_use] pub struct HashMapInit<'a, K, V, S = rustc_hash::FxBuildHasher> { shmem_handle: Option, - shared_ptr: *mut HashMapShared<'a, K, V>, + shared_ptr: *mut RwLock>, shared_size: usize, hasher: S, num_buckets: u32, @@ -44,10 +46,10 @@ pub struct HashMapAccess<'a, K, V, S = rustc_hash::FxBuildHasher> { hasher: S, } -unsafe impl<'a, K: Sync, V: Sync, S> Sync for HashMapAccess<'a, K, V, S> {} -unsafe impl<'a, K: Send, V: Send, S> Send for HashMapAccess<'a, K, V, S> {} +unsafe impl Sync for HashMapAccess<'_, K, V, S> {} +unsafe impl Send for HashMapAccess<'_, K, V, S> {} -impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> { +impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> { pub fn with_hasher(self, hasher: T) -> HashMapInit<'a, K, V, T> { HashMapInit { hasher, @@ -66,13 +68,17 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> { /// Initialize a table for writing. pub fn attach_writer(self) -> HashMapAccess<'a, K, V, S> { - // carve out the HashMapShared struct from the area. let mut ptr: *mut u8 = self.shared_ptr.cast(); let end_ptr: *mut u8 = unsafe { ptr.add(self.shared_size) }; - ptr = unsafe { ptr.add(ptr.align_offset(align_of::>())) }; - let shared_ptr: *mut HashMapShared = ptr.cast(); - ptr = unsafe { ptr.add(size_of::>()) }; + // carve out area for the One Big Lock (TM) and the HashMapShared. 
+ ptr = unsafe { ptr.add(ptr.align_offset(align_of::())) }; + let raw_lock_ptr = ptr; + ptr = unsafe { ptr.add(size_of::()) }; + ptr = unsafe { ptr.add(ptr.align_offset(align_of::>())) }; + let shared_ptr: *mut HashMapShared = ptr.cast(); + ptr = unsafe { ptr.add(size_of::>()) }; + // carve out the buckets ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::>())) }; let buckets_ptr = ptr; @@ -91,18 +97,19 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> { std::slice::from_raw_parts_mut(dictionary_ptr.cast(), dictionary_size as usize) }; let hashmap = CoreHashMap::new(buckets, dictionary); - unsafe { - std::ptr::write(shared_ptr, HashMapShared { inner: hashmap }); - } + let lock = RwLock::from_raw(PthreadRwLock::new(raw_lock_ptr.cast()), hashmap); + unsafe { + std::ptr::write(shared_ptr, lock); + } HashMapAccess { shmem_handle: self.shmem_handle, - shared_ptr: self.shared_ptr, + shared_ptr, hasher: self.hasher, } } - /// Initialize a table for reading. Currently identical to `attach_writer`. + /// Initialize a table for reading. Currently identical to [`HashMapInit::attach_writer`]. pub fn attach_reader(self) -> HashMapAccess<'a, K, V, S> { self.attach_writer() } @@ -114,14 +121,13 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> { /// relies on the memory layout! The data structures are laid out in the contiguous shared memory /// area as follows: /// -/// HashMapShared +/// [`libc::pthread_rwlock_t`] +/// [`HashMapShared`] /// [buckets] /// [dictionary] /// /// In between the above parts, there can be padding bytes to align the parts correctly. -struct HashMapShared<'a, K, V> { - inner: CoreHashMap<'a, K, V> -} +type HashMapShared<'a, K, V> = RwLock>; impl<'a, K, V> HashMapInit<'a, K, V, rustc_hash::FxBuildHasher> where @@ -131,18 +137,21 @@ where pub fn with_fixed( num_buckets: u32, area: &'a mut [MaybeUninit], - ) -> HashMapInit<'a, K, V> { + ) -> Self { Self { num_buckets, shmem_handle: None, shared_ptr: area.as_mut_ptr().cast(), shared_size: area.len(), - hasher: rustc_hash::FxBuildHasher::default(), + hasher: rustc_hash::FxBuildHasher, } } /// Place a new hash map in the given shared memory area - pub fn with_shmem(num_buckets: u32, shmem: ShmemHandle) -> HashMapInit<'a, K, V> { + /// + /// # Panics + /// Will panic on failure to resize area to expected map size. + pub fn with_shmem(num_buckets: u32, shmem: ShmemHandle) -> Self { let size = Self::estimate_size(num_buckets); shmem .set_size(size) @@ -152,12 +161,12 @@ where shared_ptr: shmem.data_ptr.as_ptr().cast(), shmem_handle: Some(shmem), shared_size: size, - hasher: rustc_hash::FxBuildHasher::default() + hasher: rustc_hash::FxBuildHasher } } /// Make a resizable hash map within a new shared memory area with the given name. - pub fn new_resizeable_named(num_buckets: u32, max_buckets: u32, name: &str) -> HashMapInit<'a, K, V> { + pub fn new_resizeable_named(num_buckets: u32, max_buckets: u32, name: &str) -> Self { let size = Self::estimate_size(num_buckets); let max_size = Self::estimate_size(max_buckets); let shmem = ShmemHandle::new(name, size, max_size) @@ -168,16 +177,16 @@ where shared_ptr: shmem.data_ptr.as_ptr().cast(), shmem_handle: Some(shmem), shared_size: size, - hasher: rustc_hash::FxBuildHasher::default() + hasher: rustc_hash::FxBuildHasher } } /// Make a resizable hash map within a new anonymous shared memory area. 
- pub fn new_resizeable(num_buckets: u32, max_buckets: u32) -> HashMapInit<'a, K, V> { + pub fn new_resizeable(num_buckets: u32, max_buckets: u32) -> Self { use std::sync::atomic::{AtomicUsize, Ordering}; - const COUNTER: AtomicUsize = AtomicUsize::new(0); + static COUNTER: AtomicUsize = AtomicUsize::new(0); let val = COUNTER.fetch_add(1, Ordering::Relaxed); - let name = format!("neon_shmem_hmap{}", val); + let name = format!("neon_shmem_hmap{val}"); Self::new_resizeable_named(num_buckets, max_buckets, &name) } } @@ -187,46 +196,119 @@ where K: Clone + Hash + Eq, { /// Hash a key using the map's hasher. - pub fn get_hash_value(&self, key: &K) -> u64 { + #[inline] + fn get_hash_value(&self, key: &K) -> u64 { self.hasher.hash_one(key) } - /// Get a reference to the corresponding value for a key given its hash. - pub fn get_with_hash<'e>(&'e self, key: &K, hash: u64) -> Option<&'e V> { - let map = unsafe { self.shared_ptr.as_ref() }.unwrap(); + fn entry_with_hash(&self, key: K, hash: u64) -> Entry<'a, '_, K, V> { + let mut map = unsafe { self.shared_ptr.as_ref() }.unwrap().write(); + let dict_pos = hash as usize % map.dictionary.len(); + let first = map.dictionary[dict_pos]; + if first == INVALID_POS { + // no existing entry + return Entry::Vacant(VacantEntry { + map, + key, + dict_pos: dict_pos as u32, + }); + } - map.inner.get_with_hash(key, hash) - } - - /// Get a reference to the entry containing a key given its hash. - pub fn entry_with_hash(&self, key: K, hash: u64) -> Entry<'a, '_, K, V> { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - - map.inner.entry_with_hash(key, hash) - } - - /// Remove a key given its hash. Does nothing if key is not present. - pub fn remove_with_hash(&self, key: &K, hash: u64) { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - - match map.inner.entry_with_hash(key.clone(), hash) { - Entry::Occupied(e) => { - e.remove(); + let mut prev_pos = PrevPos::First(dict_pos as u32); + let mut next = first; + loop { + let bucket = &mut map.buckets[next as usize]; + let (bucket_key, _bucket_value) = bucket.inner.as_mut().expect("entry is in use"); + if *bucket_key == key { + // found existing entry + return Entry::Occupied(OccupiedEntry { + map, + _key: key, + prev_pos, + bucket_pos: next, + }); } - Entry::Vacant(_) => {} - }; + + if bucket.next == INVALID_POS { + // No existing entry + return Entry::Vacant(VacantEntry { + map, + key, + dict_pos: dict_pos as u32, + }); + } + prev_pos = PrevPos::Chained(next); + next = bucket.next; + } + } + + /// Get a reference to the corresponding value for a key. + pub fn get<'e>(&'e self, key: &K) -> Option> { + let hash = self.get_hash_value(key); + let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read(); + RwLockReadGuard::try_map(map, |m| m.get_with_hash(key, hash)).ok() } + /// Get a reference to the entry containing a key. + pub fn entry(&self, key: K) -> Entry<'a, '_, K, V> { + let hash = self.get_hash_value(&key); + self.entry_with_hash(key, hash) + } + + /// Remove a key given its hash. Returns the associated value if it existed. + pub fn remove(&self, key: &K) -> Option { + let hash = self.get_hash_value(&key); + match self.entry_with_hash(key.clone(), hash) { + Entry::Occupied(e) => Some(e.remove()), + Entry::Vacant(_) => None + } + } + + /// Insert/update a key. Returns the previous associated value if it existed. + /// + /// # Errors + /// Will return [`core::FullError`] if there is no more space left in the map. 
+ pub fn insert(&self, key: K, value: V) -> Result, core::FullError> { + let hash = self.get_hash_value(&key); + match self.entry_with_hash(key.clone(), hash) { + Entry::Occupied(mut e) => Ok(Some(e.insert(value))), + Entry::Vacant(e) => { + _ = e.insert(value)?; + Ok(None) + } + } + } + /// Optionally return the entry for a bucket at a given index if it exists. + /// + /// Has more overhead than one would intuitively expect: performs both a clone of the key + /// due to the [`OccupiedEntry`] type owning the key and also a hash of the key in order + /// to enable repairing the hash chain if the entry is removed. pub fn entry_at_bucket(&self, pos: usize) -> Option> { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - map.inner.entry_at_bucket(pos) + let map = unsafe { self.shared_ptr.as_mut() }.unwrap().write(); + if pos >= map.buckets.len() { + return None; + } + + let entry = map.buckets[pos].inner.as_ref(); + match entry { + Some((key, _)) => Some(OccupiedEntry { + _key: key.clone(), + bucket_pos: pos as u32, + prev_pos: entry::PrevPos::Unknown( + self.get_hash_value(&key) + ), + map, + }), + _ => None, + } +>>>>>>> quantumish/lfc-resizable-map } /// Returns the number of buckets in the table. pub fn get_num_buckets(&self) -> usize { - let map = unsafe { self.shared_ptr.as_ref() }.unwrap(); - map.inner.get_num_buckets() + let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read(); + map.get_num_buckets() } /// Return the key and value stored in bucket with given index. This can be used to @@ -234,38 +316,35 @@ where // TODO: An Iterator might be nicer. The communicator's clock algorithm needs to // _slowly_ iterate through all buckets with its clock hand, without holding a lock. // If we switch to an Iterator, it must not hold the lock. - pub fn get_at_bucket(&self, pos: usize) -> Option<&(K, V)> { - let map = unsafe { self.shared_ptr.as_ref() }.unwrap(); - - if pos >= map.inner.buckets.len() { + pub fn get_at_bucket(&self, pos: usize) -> Option> { + let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read(); + if pos >= map.buckets.len() { return None; } - let bucket = &map.inner.buckets[pos]; - bucket.inner.as_ref() + RwLockReadGuard::try_map(map, |m| m.buckets[pos].inner.as_ref()).ok() } /// Returns the index of the bucket a given value corresponds to. pub fn get_bucket_for_value(&self, val_ptr: *const V) -> usize { - let map = unsafe { self.shared_ptr.as_ref() }.unwrap(); + let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read(); - let origin = map.inner.buckets.as_ptr(); - let idx = (val_ptr as usize - origin as usize) / (size_of::>() as usize); - assert!(idx < map.inner.buckets.len()); + let origin = map.buckets.as_ptr(); + let idx = (val_ptr as usize - origin as usize) / size_of::>(); + assert!(idx < map.buckets.len()); idx } /// Returns the number of occupied buckets in the table. pub fn get_num_buckets_in_use(&self) -> usize { - let map = unsafe { self.shared_ptr.as_ref() }.unwrap(); - map.inner.buckets_in_use as usize + let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read(); + map.buckets_in_use as usize } /// Clears all entries in a table. Does not reset any shrinking operations. 
- pub fn clear(&mut self) { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - let inner = &mut map.inner; - inner.clear() + pub fn clear(&self) { + let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write(); + map.clear(); } /// Perform an in-place rehash of some region (0..`rehash_buckets`) of the table and reset @@ -294,20 +373,20 @@ where buckets = std::slice::from_raw_parts_mut(buckets_ptr, num_buckets as usize); dictionary = std::slice::from_raw_parts_mut(dictionary_ptr, dictionary_size); } - for i in 0..dictionary.len() { - dictionary[i] = INVALID_POS; + for e in dictionary.iter_mut() { + *e = INVALID_POS; } - for i in 0..rehash_buckets as usize { - if buckets[i].inner.is_none() { - buckets[i].next = inner.free_head; + for (i, bucket) in buckets.iter_mut().enumerate().take(rehash_buckets as usize) { + if bucket.inner.is_none() { + bucket.next = inner.free_head; inner.free_head = i as u32; continue; } - let hash = self.hasher.hash_one(&buckets[i].inner.as_ref().unwrap().0); + let hash = self.hasher.hash_one(&bucket.inner.as_ref().unwrap().0); let pos: usize = (hash % dictionary.len() as u64) as usize; - buckets[i].next = dictionary[pos]; + bucket.next = dictionary[pos]; dictionary[pos] = i as u32; } @@ -316,14 +395,13 @@ where } /// Rehash the map without growing or shrinking. - pub fn shuffle(&mut self) { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - let inner = &mut map.inner; - let num_buckets = inner.get_num_buckets() as u32; + pub fn shuffle(&self) { + let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write(); + let num_buckets = map.get_num_buckets() as u32; let size_bytes = HashMapInit::::estimate_size(num_buckets); - let end_ptr: *mut u8 = unsafe { (self.shared_ptr as *mut u8).add(size_bytes) }; - let buckets_ptr = inner.buckets.as_mut_ptr(); - self.rehash_dict(inner, buckets_ptr, end_ptr, num_buckets, num_buckets); + let end_ptr: *mut u8 = unsafe { self.shared_ptr.byte_add(size_bytes).cast() }; + let buckets_ptr = map.buckets.as_mut_ptr(); + self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets); } /// Grow the number of buckets within the table. @@ -331,14 +409,17 @@ where /// 1. Grows the underlying shared memory area /// 2. Initializes new buckets and overwrites the current dictionary /// 3. Rehashes the dictionary - pub fn grow(&self, num_buckets: u32) -> Result<(), crate::shmem::Error> { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - let inner = &mut map.inner; - let old_num_buckets = inner.buckets.len() as u32; + /// + /// # Panics + /// Panics if called on a map initialized with [`HashMapInit::with_fixed`]. + /// + /// # Errors + /// Returns an [`shmem::Error`] if any errors occur resizing the memory region. + pub fn grow(&self, num_buckets: u32) -> Result<(), shmem::Error> { + let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write(); + let old_num_buckets = map.buckets.len() as u32; - if num_buckets < old_num_buckets { - panic!("grow called with a smaller number of buckets"); - } + assert!(num_buckets >= old_num_buckets, "grow called with a smaller number of buckets"); if num_buckets == old_num_buckets { return Ok(()); } @@ -351,83 +432,88 @@ where shmem_handle.set_size(size_bytes)?; let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) }; - // Initialize new buckets. The new buckets are linked to the free list. NB: This overwrites - // the dictionary! - let buckets_ptr = inner.buckets.as_mut_ptr(); + // Initialize new buckets. 
The new buckets are linked to the free list. + // NB: This overwrites the dictionary! + let buckets_ptr = map.buckets.as_mut_ptr(); unsafe { for i in old_num_buckets..num_buckets { - let bucket_ptr = buckets_ptr.add(i as usize); - bucket_ptr.write(core::Bucket { + let bucket = buckets_ptr.add(i as usize); + bucket.write(core::Bucket { next: if i < num_buckets-1 { - i as u32 + 1 + i + 1 } else { - inner.free_head + map.free_head }, inner: None, }); } } - self.rehash_dict(inner, buckets_ptr, end_ptr, num_buckets, old_num_buckets); - inner.free_head = old_num_buckets; + self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, old_num_buckets); + map.free_head = old_num_buckets; Ok(()) } /// Begin a shrink, limiting all new allocations to be in buckets with index below `num_buckets`. + /// + /// # Panics + /// Panics if called on a map initialized with [`HashMapInit::with_fixed`] or if `num_buckets` is + /// greater than the number of buckets in the map. pub fn begin_shrink(&mut self, num_buckets: u32) { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - if num_buckets > map.inner.get_num_buckets() as u32 { - panic!("shrink called with a larger number of buckets"); - } + let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write(); + assert!( + num_buckets <= map.get_num_buckets() as u32, + "shrink called with a larger number of buckets" + ); _ = self .shmem_handle .as_ref() .expect("shrink called on a fixed-size hash table"); - map.inner.alloc_limit = num_buckets; + map.alloc_limit = num_buckets; } - /// Returns whether a shrink operation is currently in progress. - pub fn is_shrinking(&self) -> bool { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - map.inner.is_shrinking() + /// If a shrink operation is underway, returns the target size of the map. Otherwise, returns None. + pub fn shrink_goal(&self) -> Option { + let map = unsafe { self.shared_ptr.as_mut() }.unwrap().read(); + let goal = map.alloc_limit; + if goal == INVALID_POS { None } else { Some(goal as usize) } } - /// Returns how many entries need to be evicted before shrink can complete. - pub fn shrink_remaining(&self) -> usize { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - let inner = &mut map.inner; - if !inner.is_shrinking() { - panic!("shrink_remaining called when no ongoing shrink") - } else { - inner.buckets_in_use - .checked_sub(inner.alloc_limit) - .unwrap_or(0) - as usize - } - } - /// Complete a shrink after caller has evicted entries, removing the unused buckets and rehashing. - pub fn finish_shrink(&self) -> Result<(), crate::shmem::Error> { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - let inner = &mut map.inner; - if !inner.is_shrinking() { - panic!("called finish_shrink when no shrink is in progress"); - } + /// + /// # Panics + /// The following cases result in a panic: + /// - Calling this function on a map initialized with [`HashMapInit::with_fixed`]. + /// - Calling this function on a map when no shrink operation is in progress. + /// - Calling this function on a map with `shrink_mode` set to [`HashMapShrinkMode::Remap`] and + /// there are more buckets in use than the value returned by [`HashMapAccess::shrink_goal`]. + /// + /// # Errors + /// Returns an [`shmem::Error`] if any errors occur resizing the memory region. 
+ pub fn finish_shrink(&self) -> Result<(), shmem::Error> { + let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write(); + assert!( + map.alloc_limit != INVALID_POS, + "called finish_shrink when no shrink is in progress" + ); - let num_buckets = inner.alloc_limit; + let num_buckets = map.alloc_limit; - if inner.get_num_buckets() == num_buckets as usize { + if map.get_num_buckets() == num_buckets as usize { return Ok(()); - } else if inner.buckets_in_use > num_buckets { - panic!("called finish_shrink before enough entries were removed"); - } + } - for i in (num_buckets as usize)..inner.buckets.len() { - if let Some((k, v)) = inner.buckets[i].inner.take() { - // alloc bucket increases buckets in use, so need to decrease since we're just moving - inner.buckets_in_use -= 1; - inner.alloc_bucket(k, v).unwrap(); + assert!( + map.buckets_in_use <= num_buckets, + "called finish_shrink before enough entries were removed" + ); + + for i in (num_buckets as usize)..map.buckets.len() { + if let Some((k, v)) = map.buckets[i].inner.take() { + // alloc_bucket increases count, so need to decrease since we're just moving + map.buckets_in_use -= 1; + map.alloc_bucket(k, v).unwrap(); } } @@ -439,10 +525,10 @@ where let size_bytes = HashMapInit::::estimate_size(num_buckets); shmem_handle.set_size(size_bytes)?; let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) }; - let buckets_ptr = inner.buckets.as_mut_ptr(); - self.rehash_dict(inner, buckets_ptr, end_ptr, num_buckets, num_buckets); - inner.alloc_limit = INVALID_POS; + let buckets_ptr = map.buckets.as_mut_ptr(); + self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets); + map.alloc_limit = INVALID_POS; Ok(()) - } + } } diff --git a/libs/neon-shmem/src/hash/core.rs b/libs/neon-shmem/src/hash/core.rs index ef81ba422d..aea89358df 100644 --- a/libs/neon-shmem/src/hash/core.rs +++ b/libs/neon-shmem/src/hash/core.rs @@ -3,8 +3,9 @@ use std::hash::Hash; use std::mem::MaybeUninit; -use crate::hash::entry::{Entry, OccupiedEntry, PrevPos, VacantEntry}; +use crate::hash::entry::*; +/// Invalid position within the map (either within the dictionary or bucket array). pub(crate) const INVALID_POS: u32 = u32::MAX; /// Fundamental storage unit within the hash table. Either empty or contains a key-value pair. @@ -18,28 +19,26 @@ pub(crate) struct Bucket { /// Core hash table implementation. pub(crate) struct CoreHashMap<'a, K, V> { - /// Dictionary used to map hashes to bucket indices. + /// Dictionary used to map hashes to bucket indices. pub(crate) dictionary: &'a mut [u32], /// Buckets containing key-value pairs. pub(crate) buckets: &'a mut [Bucket], /// Head of the freelist. pub(crate) free_head: u32, - /// Maximum index of a bucket allowed to be allocated. INVALID_POS if no limit. + /// Maximum index of a bucket allowed to be allocated. [`INVALID_POS`] if no limit. pub(crate) alloc_limit: u32, /// The number of currently occupied buckets. pub(crate) buckets_in_use: u32, + // pub(crate) lock: libc::pthread_mutex_t, // Unclear what the purpose of this is. pub(crate) _user_list_head: u32, } /// Error for when there are no empty buckets left but one is needed. -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct FullError(); -impl<'a, K: Hash + Eq, V> CoreHashMap<'a, K, V> -where - K: Clone + Hash + Eq, -{ +impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> { const FILL_FACTOR: f32 = 0.60; /// Estimate the size of data contained within the the hash map. 
@@ -59,7 +58,7 @@ where pub fn new( buckets: &'a mut [MaybeUninit>], dictionary: &'a mut [MaybeUninit], - ) -> CoreHashMap<'a, K, V> { + ) -> Self { // Initialize the buckets for i in 0..buckets.len() { buckets[i].write(Bucket { @@ -73,8 +72,8 @@ where } // Initialize the dictionary - for i in 0..dictionary.len() { - dictionary[i].write(INVALID_POS); + for e in dictionary.iter_mut() { + e.write(INVALID_POS); } // TODO: use std::slice::assume_init_mut() once it stabilizes @@ -84,7 +83,7 @@ where std::slice::from_raw_parts_mut(dictionary.as_mut_ptr().cast(), dictionary.len()) }; - CoreHashMap { + Self { dictionary, buckets, free_head: 0, @@ -105,63 +104,17 @@ where let bucket = &self.buckets[next as usize]; let (bucket_key, bucket_value) = bucket.inner.as_ref().expect("entry is in use"); if bucket_key == key { - return Some(&bucket_value); + return Some(bucket_value); } next = bucket.next; } } - /// Get the `Entry` associated with a key given hash. This should be used for updates/inserts. - pub fn entry_with_hash(&mut self, key: K, hash: u64) -> Entry<'a, '_, K, V> { - let dict_pos = hash as usize % self.dictionary.len(); - let first = self.dictionary[dict_pos]; - if first == INVALID_POS { - // no existing entry - return Entry::Vacant(VacantEntry { - map: self, - key, - dict_pos: dict_pos as u32, - }); - } - - let mut prev_pos = PrevPos::First(dict_pos as u32); - let mut next = first; - loop { - let bucket = &mut self.buckets[next as usize]; - let (bucket_key, _bucket_value) = bucket.inner.as_mut().expect("entry is in use"); - if *bucket_key == key { - // found existing entry - return Entry::Occupied(OccupiedEntry { - map: self, - _key: key, - prev_pos, - bucket_pos: next, - }); - } - - if bucket.next == INVALID_POS { - // No existing entry - return Entry::Vacant(VacantEntry { - map: self, - key, - dict_pos: dict_pos as u32, - }); - } - prev_pos = PrevPos::Chained(next); - next = bucket.next; - } - } - /// Get number of buckets in map. pub fn get_num_buckets(&self) -> usize { self.buckets.len() } - /// Returns whether there is an ongoing shrink operation. - pub fn is_shrinking(&self) -> bool { - self.alloc_limit != INVALID_POS - } - /// Clears all entries from the hashmap. /// /// Does not reset any allocation limits, but does clear any entries beyond them. @@ -176,32 +129,14 @@ where inner: None, } } - for i in 0..self.dictionary.len() { self.dictionary[i] = INVALID_POS; } + self.free_head = 0; self.buckets_in_use = 0; } - /// Optionally gets the entry at an index if it is occupied. - pub fn entry_at_bucket(&mut self, pos: usize) -> Option> { - if pos >= self.buckets.len() { - return None; - } - - let entry = self.buckets[pos].inner.as_ref(); - match entry { - Some((key, _)) => Some(OccupiedEntry { - _key: key.clone(), - bucket_pos: pos as u32, - prev_pos: PrevPos::Unknown, - map: self, - }), - _ => None, - } - } - /// Find the position of an unused bucket via the freelist and initialize it. pub(crate) fn alloc_bucket(&mut self, key: K, value: V) -> Result { let mut pos = self.free_head; @@ -227,7 +162,7 @@ where let next_pos = self.buckets[pos as usize].next; self.buckets[p as usize].next = next_pos; }, - PrevPos::Unknown => unreachable!() + _ => unreachable!() } // Initialize the bucket. 
@@ -236,7 +171,7 @@ where bucket.next = INVALID_POS; bucket.inner = Some((key, value)); - return Ok(pos); + Ok(pos) } } diff --git a/libs/neon-shmem/src/hash/entry.rs b/libs/neon-shmem/src/hash/entry.rs index 24c124189b..a5832665aa 100644 --- a/libs/neon-shmem/src/hash/entry.rs +++ b/libs/neon-shmem/src/hash/entry.rs @@ -1,11 +1,12 @@ -//! Like std::collections::hash_map::Entry; +//! Equivalent of [`std::collections::hash_map::Entry`] for this hashmap. use crate::hash::core::{CoreHashMap, FullError, INVALID_POS}; +use crate::sync::{RwLockWriteGuard, ValueWriteGuard}; use std::hash::Hash; use std::mem; -/// View into an entry in the map (either vacant or occupied). + pub enum Entry<'a, 'b, K, V> { Occupied(OccupiedEntry<'a, 'b, K, V>), Vacant(VacantEntry<'a, 'b, K, V>), @@ -19,22 +20,21 @@ pub(crate) enum PrevPos { /// Regular index within the buckets. Chained(u32), /// Unknown - e.g. the associated entry was retrieved by index instead of chain. - Unknown, + Unknown(u64), } -/// View into an occupied entry within the map. pub struct OccupiedEntry<'a, 'b, K, V> { /// Mutable reference to the map containing this entry. - pub(crate) map: &'b mut CoreHashMap<'a, K, V>, + pub(crate) map: RwLockWriteGuard<'b, CoreHashMap<'a, K, V>>, /// The key of the occupied entry pub(crate) _key: K, /// The index of the previous entry in the chain. pub(crate) prev_pos: PrevPos, - /// The position of the bucket in the CoreHashMap's buckets array. - pub(crate) bucket_pos: u32, + /// The position of the bucket in the [`CoreHashMap`] bucket array. + pub(crate) bucket_pos: u32, } -impl<'a, 'b, K, V> OccupiedEntry<'a, 'b, K, V> { +impl OccupiedEntry<'_, '_, K, V> { pub fn get(&self) -> &V { &self.map.buckets[self.bucket_pos as usize] .inner @@ -55,57 +55,85 @@ impl<'a, 'b, K, V> OccupiedEntry<'a, 'b, K, V> { pub fn insert(&mut self, value: V) -> V { let bucket = &mut self.map.buckets[self.bucket_pos as usize]; // This assumes inner is Some, which it must be for an OccupiedEntry - let old_value = mem::replace(&mut bucket.inner.as_mut().unwrap().1, value); - old_value + mem::replace(&mut bucket.inner.as_mut().unwrap().1, value) } /// Removes the entry from the hash map, returning the value originally stored within it. - pub fn remove(self) -> V { + /// + /// This may result in multiple bucket accesses if the entry was obtained by index as the + /// previous chain entry needs to be discovered in this case. + /// + /// # Panics + /// Panics if the `prev_pos` field is equal to [`PrevPos::Unknown`]. In practice, this means + /// the entry was obtained via calling something like [`CoreHashMap::entry_at_bucket`]. + pub fn remove(mut self) -> V { + // If this bucket was queried by index, go ahead and follow its chain from the start. + let prev = if let PrevPos::Unknown(hash) = self.prev_pos { + let dict_idx = hash as usize % self.map.dictionary.len(); + let mut prev = PrevPos::First(dict_idx as u32); + let mut curr = self.map.dictionary[dict_idx]; + while curr != self.bucket_pos { + curr = self.map.buckets[curr as usize].next; + prev = PrevPos::Chained(curr); + } + prev + } else { + self.prev_pos + }; + // CoreHashMap::remove returns Option<(K, V)>. We know it's Some for an OccupiedEntry. 
let bucket = &mut self.map.buckets[self.bucket_pos as usize]; - + // unlink it from the chain - match self.prev_pos { - PrevPos::First(dict_pos) => self.map.dictionary[dict_pos as usize] = bucket.next, + match prev { + PrevPos::First(dict_pos) => { + self.map.dictionary[dict_pos as usize] = bucket.next; + }, PrevPos::Chained(bucket_pos) => { - self.map.buckets[bucket_pos as usize].next = bucket.next + // println!("we think prev of {} is {bucket_pos}", self.bucket_pos); + self.map.buckets[bucket_pos as usize].next = bucket.next; }, - PrevPos::Unknown => panic!("can't safely remove entry with unknown previous entry"), + _ => unreachable!(), } - // and add it to the freelist + // and add it to the freelist + let free = self.map.free_head; let bucket = &mut self.map.buckets[self.bucket_pos as usize]; let old_value = bucket.inner.take(); - bucket.next = self.map.free_head; + bucket.next = free; self.map.free_head = self.bucket_pos; self.map.buckets_in_use -= 1; - return old_value.unwrap().1; + old_value.unwrap().1 } } /// An abstract view into a vacant entry within the map. pub struct VacantEntry<'a, 'b, K, V> { /// Mutable reference to the map containing this entry. - pub(crate) map: &'b mut CoreHashMap<'a, K, V>, + pub(crate) map: RwLockWriteGuard<'b, CoreHashMap<'a, K, V>>, /// The key to be inserted into this entry. pub(crate) key: K, /// The position within the dictionary corresponding to the key's hash. pub(crate) dict_pos: u32, } -impl<'a, 'b, K: Clone + Hash + Eq, V> VacantEntry<'a, 'b, K, V> { +impl<'b, K: Clone + Hash + Eq, V> VacantEntry<'_, 'b, K, V> { /// Insert a value into the vacant entry, finding and populating an empty bucket in the process. - pub fn insert(self, value: V) -> Result<&'b mut V, FullError> { + /// + /// # Errors + /// Will return [`FullError`] if there are no unoccupied buckets in the map. 
+ pub fn insert(mut self, value: V) -> Result, FullError> { let pos = self.map.alloc_bucket(self.key, value)?; if pos == INVALID_POS { return Err(FullError()); } - let bucket = &mut self.map.buckets[pos as usize]; - bucket.next = self.map.dictionary[self.dict_pos as usize]; + self.map.buckets[pos as usize].next = self.map.dictionary[self.dict_pos as usize]; self.map.dictionary[self.dict_pos as usize] = pos; - let result = &mut self.map.buckets[pos as usize].inner.as_mut().unwrap().1; - return Ok(result); + Ok(RwLockWriteGuard::map( + self.map, + |m| &mut m.buckets[pos as usize].inner.as_mut().unwrap().1 + )) } } diff --git a/libs/neon-shmem/src/hash/tests.rs b/libs/neon-shmem/src/hash/tests.rs index 987859ae60..baa971098e 100644 --- a/libs/neon-shmem/src/hash/tests.rs +++ b/libs/neon-shmem/src/hash/tests.rs @@ -1,14 +1,12 @@ use std::collections::BTreeMap; use std::collections::HashSet; -use std::fmt::{Debug, Formatter}; -use std::mem::uninitialized; +use std::fmt::Debug; use std::mem::MaybeUninit; -use std::sync::atomic::{AtomicUsize, Ordering}; use crate::hash::HashMapAccess; use crate::hash::HashMapInit; use crate::hash::Entry; -use crate::shmem::ShmemHandle; +use crate::hash::core::FullError; use rand::seq::SliceRandom; use rand::{Rng, RngCore}; @@ -38,13 +36,12 @@ impl<'a> From<&'a [u8]> for TestKey { } fn test_inserts + Copy>(keys: &[K]) { - let mut w = HashMapInit::::new_resizeable_named( + let w = HashMapInit::::new_resizeable_named( 100000, 120000, "test_inserts" ).attach_writer(); for (idx, k) in keys.iter().enumerate() { - let hash = w.get_hash_value(&(*k).into()); - let res = w.entry_with_hash((*k).into(), hash); + let res = w.entry((*k).into()); match res { Entry::Occupied(mut e) => { e.insert(idx); } Entry::Vacant(e) => { @@ -55,8 +52,7 @@ fn test_inserts + Copy>(keys: &[K]) { } for (idx, k) in keys.iter().enumerate() { - let hash = w.get_hash_value(&(*k).into()); - let x = w.get_with_hash(&(*k).into(), hash); + let x = w.get(&(*k).into()); let value = x.as_deref().copied(); assert_eq!(value, Some(idx)); } @@ -98,30 +94,6 @@ fn sparse() { test_inserts(&keys); } -struct TestValue(AtomicUsize); - -impl TestValue { - fn new(val: usize) -> TestValue { - TestValue(AtomicUsize::new(val)) - } - - fn load(&self) -> usize { - self.0.load(Ordering::Relaxed) - } -} - -impl Clone for TestValue { - fn clone(&self) -> TestValue { - TestValue::new(self.load()) - } -} - -impl Debug for TestValue { - fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { - write!(fmt, "{:?}", self.load()) - } -} - #[derive(Clone, Debug)] struct TestOp(TestKey, Option); @@ -137,13 +109,12 @@ fn apply_op( shadow.remove(&op.0) }; - let hash = map.get_hash_value(&op.0); - let entry = map.entry_with_hash(op.0, hash); + let entry = map.entry(op.0); let hash_existing = match op.1 { Some(new) => { match entry { Entry::Occupied(mut e) => Some(e.insert(new)), - Entry::Vacant(e) => { e.insert(new).unwrap(); None }, + Entry::Vacant(e) => { _ = e.insert(new).unwrap(); None }, } }, None => { @@ -177,29 +148,31 @@ fn do_deletes( writer: &mut HashMapAccess, shadow: &mut BTreeMap, ) { - for i in 0..num_ops { + for _ in 0..num_ops { let (k, _) = shadow.pop_first().unwrap(); - let hash = writer.get_hash_value(&k); - writer.remove_with_hash(&k, hash); + writer.remove(&k); } } fn do_shrink( writer: &mut HashMapAccess, shadow: &mut BTreeMap, - from: u32, to: u32 ) { + assert!(writer.shrink_goal().is_none()); writer.begin_shrink(to); + assert_eq!(writer.shrink_goal(), Some(to as usize)); while 
writer.get_num_buckets_in_use() > to as usize { let (k, _) = shadow.pop_first().unwrap(); - let hash = writer.get_hash_value(&k); - let entry = writer.entry_with_hash(k, hash); - if let Entry::Occupied(mut e) = entry { + let entry = writer.entry(k); + if let Entry::Occupied(e) = entry { e.remove(); } } + let old_usage = writer.get_num_buckets_in_use(); writer.finish_shrink().unwrap(); + assert!(writer.shrink_goal().is_none()); + assert_eq!(writer.get_num_buckets_in_use(), old_usage); } #[test] @@ -217,10 +190,6 @@ fn random_ops() { let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None }); apply_op(&op, &mut writer, &mut shadow); - - if i % 1000 == 0 { - eprintln!("{i} ops processed"); - } } } @@ -247,10 +216,80 @@ fn test_grow() { let mut rng = rand::rng(); do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng); + let old_usage = writer.get_num_buckets_in_use(); writer.grow(1500).unwrap(); + assert_eq!(writer.get_num_buckets_in_use(), old_usage); + assert_eq!(writer.get_num_buckets(), 1500); do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng); } +#[test] +fn test_clear() { + let mut writer = HashMapInit::::new_resizeable_named( + 1500, 2000, "test_clear" + ).attach_writer(); + let mut shadow: std::collections::BTreeMap = BTreeMap::new(); + let mut rng = rand::rng(); + do_random_ops(2000, 1500, 0.75, &mut writer, &mut shadow, &mut rng); + writer.clear(); + assert_eq!(writer.get_num_buckets_in_use(), 0); + assert_eq!(writer.get_num_buckets(), 1500); + while let Some((key, _)) = shadow.pop_first() { + assert!(writer.get(&key).is_none()); + } + do_random_ops(2000, 1500, 0.75, &mut writer, &mut shadow, &mut rng); + for i in 0..(1500 - writer.get_num_buckets_in_use()) { + writer.insert((1500 + i as u128).into(), 0).unwrap(); + } + assert_eq!(writer.insert(5000.into(), 0), Err(FullError {})); + writer.clear(); + assert!(writer.insert(5000.into(), 0).is_ok()); +} + +#[test] +fn test_idx_remove() { + let mut writer = HashMapInit::::new_resizeable_named( + 1500, 2000, "test_clear" + ).attach_writer(); + let mut shadow: std::collections::BTreeMap = BTreeMap::new(); + let mut rng = rand::rng(); + do_random_ops(2000, 1500, 0.25, &mut writer, &mut shadow, &mut rng); + for _ in 0..100 { + let idx = (rng.next_u32() % 1500) as usize; + if let Some(e) = writer.entry_at_bucket(idx) { + shadow.remove(&e._key); + e.remove(); + } + + } + while let Some((key, val)) = shadow.pop_first() { + assert_eq!(*writer.get(&key).unwrap(), val); + } +} + +#[test] +fn test_idx_get() { + let mut writer = HashMapInit::::new_resizeable_named( + 1500, 2000, "test_clear" + ).attach_writer(); + let mut shadow: std::collections::BTreeMap = BTreeMap::new(); + let mut rng = rand::rng(); + do_random_ops(2000, 1500, 0.25, &mut writer, &mut shadow, &mut rng); + for _ in 0..100 { + let idx = (rng.next_u32() % 1500) as usize; + if let Some(pair) = writer.get_at_bucket(idx) { + { + let v: *const usize = &pair.1; + assert_eq!(writer.get_bucket_for_value(v), idx); + } + { + let v: *const usize = &pair.1; + assert_eq!(writer.get_bucket_for_value(v), idx); + } + } + } +} + #[test] fn test_shrink() { let mut writer = HashMapInit::::new_resizeable_named( @@ -259,8 +298,9 @@ fn test_shrink() { let mut shadow: std::collections::BTreeMap = BTreeMap::new(); let mut rng = rand::rng(); - do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng); - do_shrink(&mut writer, &mut shadow, 1500, 1000); + do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng); + do_shrink(&mut writer, 
&mut shadow, 1000); + assert_eq!(writer.get_num_buckets(), 1000); do_deletes(500, &mut writer, &mut shadow); do_random_ops(10000, 500, 0.75, &mut writer, &mut shadow, &mut rng); assert!(writer.get_num_buckets_in_use() <= 1000); @@ -276,14 +316,16 @@ fn test_shrink_grow_seq() { do_random_ops(500, 1000, 0.1, &mut writer, &mut shadow, &mut rng); eprintln!("Shrinking to 750"); - do_shrink(&mut writer, &mut shadow, 1000, 750); + do_shrink(&mut writer, &mut shadow, 750); do_random_ops(200, 1000, 0.5, &mut writer, &mut shadow, &mut rng); eprintln!("Growing to 1500"); writer.grow(1500).unwrap(); do_random_ops(600, 1500, 0.1, &mut writer, &mut shadow, &mut rng); eprintln!("Shrinking to 200"); - do_shrink(&mut writer, &mut shadow, 1500, 200); - do_deletes(100, &mut writer, &mut shadow); + while shadow.len() > 100 { + do_deletes(1, &mut writer, &mut shadow); + } + do_shrink(&mut writer, &mut shadow, 200); do_random_ops(50, 1500, 0.25, &mut writer, &mut shadow, &mut rng); eprintln!("Growing to 10k"); writer.grow(10000).unwrap(); @@ -292,31 +334,32 @@ fn test_shrink_grow_seq() { #[test] fn test_bucket_ops() { - let mut writer = HashMapInit::::new_resizeable_named( + let writer = HashMapInit::::new_resizeable_named( 1000, 1200, "test_bucket_ops" ).attach_writer(); - let hash = writer.get_hash_value(&1.into()); - match writer.entry_with_hash(1.into(), hash) { + match writer.entry(1.into()) { Entry::Occupied(mut e) => { e.insert(2); }, - Entry::Vacant(e) => { e.insert(2).unwrap(); }, + Entry::Vacant(e) => { _ = e.insert(2).unwrap(); }, } assert_eq!(writer.get_num_buckets_in_use(), 1); assert_eq!(writer.get_num_buckets(), 1000); - assert_eq!(writer.get_with_hash(&1.into(), hash), Some(&2)); - let pos = match writer.entry_with_hash(1.into(), hash) { + assert_eq!(*writer.get(&1.into()).unwrap(), 2); + let pos = match writer.entry(1.into()) { Entry::Occupied(e) => { assert_eq!(e._key, 1.into()); let pos = e.bucket_pos as usize; - assert_eq!(writer.entry_at_bucket(pos).unwrap()._key, 1.into()); - assert_eq!(writer.get_at_bucket(pos), Some(&(1.into(), 2))); pos }, Entry::Vacant(_) => { panic!("Insert didn't affect entry"); }, }; - let ptr: *const usize = writer.get_with_hash(&1.into(), hash).unwrap(); - assert_eq!(writer.get_bucket_for_value(ptr), pos); - writer.remove_with_hash(&1.into(), hash); - assert_eq!(writer.get_with_hash(&1.into(), hash), None); + assert_eq!(writer.entry_at_bucket(pos).unwrap()._key, 1.into()); + assert_eq!(*writer.get_at_bucket(pos).unwrap(), (1.into(), 2)); + { + let ptr: *const usize = &*writer.get(&1.into()).unwrap(); + assert_eq!(writer.get_bucket_for_value(ptr), pos); + } + writer.remove(&1.into()); + assert!(writer.get(&1.into()).is_none()); } #[test] @@ -330,15 +373,14 @@ fn test_shrink_zero() { } writer.finish_shrink().unwrap(); assert_eq!(writer.get_num_buckets_in_use(), 0); - let hash = writer.get_hash_value(&1.into()); - let entry = writer.entry_with_hash(1.into(), hash); + let entry = writer.entry(1.into()); if let Entry::Vacant(v) = entry { assert!(v.insert(2).is_err()); } else { panic!("Somehow got non-vacant entry in empty map.") } writer.grow(50).unwrap(); - let entry = writer.entry_with_hash(1.into(), hash); + let entry = writer.entry(1.into()); if let Entry::Vacant(v) = entry { assert!(v.insert(2).is_ok()); } else { @@ -350,7 +392,7 @@ fn test_shrink_zero() { #[test] #[should_panic] fn test_grow_oom() { - let mut writer = HashMapInit::::new_resizeable_named( + let writer = HashMapInit::::new_resizeable_named( 1500, 2000, "test_grow_oom" ).attach_writer(); 
writer.grow(20000).unwrap(); @@ -368,7 +410,7 @@ fn test_shrink_bigger() { #[test] #[should_panic] fn test_shrink_early_finish() { - let mut writer = HashMapInit::::new_resizeable_named( + let writer = HashMapInit::::new_resizeable_named( 1500, 2500, "test_shrink_early_finish" ).attach_writer(); writer.finish_shrink().unwrap(); diff --git a/libs/neon-shmem/src/lib.rs b/libs/neon-shmem/src/lib.rs index f601010122..61ca168073 100644 --- a/libs/neon-shmem/src/lib.rs +++ b/libs/neon-shmem/src/lib.rs @@ -2,3 +2,4 @@ pub mod hash; pub mod shmem; +pub mod sync; diff --git a/libs/neon-shmem/src/shmem.rs b/libs/neon-shmem/src/shmem.rs index 21b1454b10..7c7285f67e 100644 --- a/libs/neon-shmem/src/shmem.rs +++ b/libs/neon-shmem/src/shmem.rs @@ -12,14 +12,14 @@ use nix::sys::mman::mmap as nix_mmap; use nix::sys::mman::munmap as nix_munmap; use nix::unistd::ftruncate as nix_ftruncate; -/// ShmemHandle represents a shared memory area that can be shared by processes over fork(). -/// Unlike shared memory allocated by Postgres, this area is resizable, up to 'max_size' that's +/// `ShmemHandle` represents a shared memory area that can be shared by processes over `fork()`. +/// Unlike shared memory allocated by Postgres, this area is resizable, up to `max_size` that's /// specified at creation. /// -/// The area is backed by an anonymous file created with memfd_create(). The full address space for -/// 'max_size' is reserved up-front with mmap(), but whenever you call [`ShmemHandle::set_size`], +/// The area is backed by an anonymous file created with `memfd_create()`. The full address space for +/// `max_size` is reserved up-front with `mmap()`, but whenever you call [`ShmemHandle::set_size`], /// the underlying file is resized. Do not access the area beyond the current size. Currently, that -/// will cause the file to be expanded, but we might use mprotect() etc. to enforce that in the +/// will cause the file to be expanded, but we might use `mprotect()` etc. to enforce that in the /// future. pub struct ShmemHandle { /// memfd file descriptor @@ -38,7 +38,7 @@ pub struct ShmemHandle { struct SharedStruct { max_size: usize, - /// Current size of the backing file. The high-order bit is used for the RESIZE_IN_PROGRESS flag + /// Current size of the backing file. The high-order bit is used for the [`RESIZE_IN_PROGRESS`] flag. current_size: AtomicUsize, } @@ -46,7 +46,7 @@ const RESIZE_IN_PROGRESS: usize = 1 << 63; const HEADER_SIZE: usize = std::mem::size_of::(); -/// Error type returned by the ShmemHandle functions. +/// Error type returned by the [`ShmemHandle`] functions. #[derive(thiserror::Error, Debug)] #[error("{msg}: {errno}")] pub struct Error { @@ -55,8 +55,8 @@ pub struct Error { } impl Error { - fn new(msg: &str, errno: Errno) -> Error { - Error { + fn new(msg: &str, errno: Errno) -> Self { + Self { msg: msg.to_string(), errno, } @@ -65,11 +65,11 @@ impl Error { impl ShmemHandle { /// Create a new shared memory area. To communicate between processes, the processes need to be - /// fork()'d after calling this, so that the ShmemHandle is inherited by all processes. + /// `fork()`'d after calling this, so that the `ShmemHandle` is inherited by all processes. /// - /// If the ShmemHandle is dropped, the memory is unmapped from the current process. Other + /// If the `ShmemHandle` is dropped, the memory is unmapped from the current process. Other /// processes can continue using it, however. 
- pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result { + pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result { // create the backing anonymous file. let fd = create_backing_file(name)?; @@ -80,17 +80,17 @@ impl ShmemHandle { fd: OwnedFd, initial_size: usize, max_size: usize, - ) -> Result { - // We reserve the high-order bit for the RESIZE_IN_PROGRESS flag, and the actual size + ) -> Result { + // We reserve the high-order bit for the `RESIZE_IN_PROGRESS` flag, and the actual size // is a little larger than this because of the SharedStruct header. Make the upper limit // somewhat smaller than that, because with anything close to that, you'll run out of // memory anyway. - if max_size >= 1 << 48 { - panic!("max size {} too large", max_size); - } - if initial_size > max_size { - panic!("initial size {initial_size} larger than max size {max_size}"); - } + assert!(max_size < 1 << 48, "max size {max_size} too large"); + + assert!( + initial_size <= max_size, + "initial size {initial_size} larger than max size {max_size}" + ); // The actual initial / max size is the one given by the caller, plus the size of // 'SharedStruct'. @@ -110,7 +110,7 @@ impl ShmemHandle { 0, ) } - .map_err(|e| Error::new("mmap failed: {e}", e))?; + .map_err(|e| Error::new("mmap failed", e))?; // Reserve space for the initial size enlarge_file(fd.as_fd(), initial_size as u64)?; @@ -121,13 +121,13 @@ impl ShmemHandle { shared.write(SharedStruct { max_size: max_size.into(), current_size: AtomicUsize::new(initial_size), - }) - }; + }); + } // The user data begins after the header let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) }; - Ok(ShmemHandle { + Ok(Self { fd, max_size: max_size.into(), shared_ptr: shared, @@ -140,28 +140,28 @@ impl ShmemHandle { unsafe { self.shared_ptr.as_ref() } } - /// Resize the shared memory area. 'new_size' must not be larger than the 'max_size' specified + /// Resize the shared memory area. `new_size` must not be larger than the `max_size` specified /// when creating the area. /// /// This may only be called from one process/thread concurrently. We detect that case - /// and return an Error. + /// and return an [`shmem::Error`](Error). pub fn set_size(&self, new_size: usize) -> Result<(), Error> { let new_size = new_size + HEADER_SIZE; let shared = self.shared(); - if new_size > self.max_size { - panic!( - "new size ({} is greater than max size ({})", - new_size, self.max_size - ); - } - assert_eq!(self.max_size, shared.max_size); + assert!( + new_size <= self.max_size, + "new size ({new_size}) is greater than max size ({})", + self.max_size + ); - // Lock the area by setting the bit in 'current_size' + assert_eq!(self.max_size, shared.max_size); + + // Lock the area by setting the bit in `current_size` // // Ordering::Relaxed would probably be sufficient here, as we don't access any other memory - // and the posix_fallocate/ftruncate call is surely a synchronization point anyway. But - // since this is not performance-critical, better safe than sorry . + // and the `posix_fallocate`/`ftruncate` call is surely a synchronization point anyway. But + // since this is not performance-critical, better safe than sorry. 
let mut old_size = shared.current_size.load(Ordering::Acquire); loop { if (old_size & RESIZE_IN_PROGRESS) != 0 { @@ -188,7 +188,7 @@ impl ShmemHandle { use std::cmp::Ordering::{Equal, Greater, Less}; match new_size.cmp(&old_size) { Less => nix_ftruncate(&self.fd, new_size as i64).map_err(|e| { - Error::new("could not shrink shmem segment, ftruncate failed: {e}", e) + Error::new("could not shrink shmem segment, ftruncate failed", e) }), Equal => Ok(()), Greater => enlarge_file(self.fd.as_fd(), new_size as u64), @@ -206,8 +206,8 @@ impl ShmemHandle { /// Returns the current user-visible size of the shared memory segment. /// - /// NOTE: a concurrent set_size() call can change the size at any time. It is the caller's - /// responsibility not to access the area beyond the current size. + /// NOTE: a concurrent [`ShmemHandle::set_size()`] call can change the size at any time. + /// It is the caller's responsibility not to access the area beyond the current size. pub fn current_size(&self) -> usize { let total_current_size = self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS; @@ -224,23 +224,23 @@ impl Drop for ShmemHandle { } } -/// Create a "backing file" for the shared memory area. On Linux, use memfd_create(), to create an +/// Create a "backing file" for the shared memory area. On Linux, use `memfd_create()`, to create an /// anonymous in-memory file. One macos, fall back to a regular file. That's good enough for /// development and testing, but in production we want the file to stay in memory. /// -/// disable 'unused_variables' warnings, because in the macos path, 'name' is unused. +/// Disable unused variables warnings because `name` is unused in the macos path. #[allow(unused_variables)] fn create_backing_file(name: &str) -> Result { #[cfg(not(target_os = "macos"))] { nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty()) - .map_err(|e| Error::new("memfd_create failed: {e}", e)) + .map_err(|e| Error::new("memfd_create failed", e)) } #[cfg(target_os = "macos")] { let file = tempfile::tempfile().map_err(|e| { Error::new( - "could not create temporary file to back shmem area: {e}", + "could not create temporary file to back shmem area", nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)), ) })?; @@ -255,7 +255,7 @@ fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> { { nix::fcntl::posix_fallocate(fd, 0, size as i64).map_err(|e| { Error::new( - "could not grow shmem segment, posix_fallocate failed: {e}", + "could not grow shmem segment, posix_fallocate failed", e, ) }) @@ -264,7 +264,7 @@ fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> { #[cfg(target_os = "macos")] { nix::unistd::ftruncate(fd, size as i64) - .map_err(|e| Error::new("could not grow shmem segment, ftruncate failed: {e}", e)) + .map_err(|e| Error::new("could not grow shmem segment, ftruncate failed", e)) } } @@ -330,7 +330,7 @@ mod tests { Ok(()) } - /// This is used in tests to coordinate between test processes. It's like std::sync::Barrier, + /// This is used in tests to coordinate between test processes. It's like `std::sync::Barrier`, /// but is stored in the shared memory area and works across processes. It's implemented by /// polling, because e.g. standard rust mutexes are not guaranteed to work across processes. struct SimpleBarrier { diff --git a/libs/neon-shmem/src/sync.rs b/libs/neon-shmem/src/sync.rs new file mode 100644 index 0000000000..fc39df9100 --- /dev/null +++ b/libs/neon-shmem/src/sync.rs @@ -0,0 +1,105 @@ +//! 
Simple utilities akin to what's in [`std::sync`] but designed to work with shared memory. + +use std::mem::MaybeUninit; +use std::ptr::NonNull; + +use nix::errno::Errno; + +pub type RwLock = lock_api::RwLock; +pub(crate) type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, PthreadRwLock, T>; +pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, PthreadRwLock, T>; +pub type ValueReadGuard<'a, T> = lock_api::MappedRwLockReadGuard<'a, PthreadRwLock, T>; +pub type ValueWriteGuard<'a, T> = lock_api::MappedRwLockWriteGuard<'a, PthreadRwLock, T>; + +/// Shared memory read-write lock. +pub struct PthreadRwLock(Option>); + +impl PthreadRwLock { + pub fn new(lock: *mut libc::pthread_rwlock_t) -> Self { + unsafe { + let mut attrs = MaybeUninit::uninit(); + // Ignoring return value here - only possible error is OOM. + libc::pthread_rwlockattr_init(attrs.as_mut_ptr()); + libc::pthread_rwlockattr_setpshared( + attrs.as_mut_ptr(), + libc::PTHREAD_PROCESS_SHARED + ); + // TODO(quantumish): worth making this function return Result? + libc::pthread_rwlock_init(lock, attrs.as_mut_ptr()); + // Safety: POSIX specifies that "any function affecting the attributes + // object (including destruction) shall not affect any previously + // initialized read-write locks". + libc::pthread_rwlockattr_destroy(attrs.as_mut_ptr()); + Self(Some(NonNull::new_unchecked(lock))) + } + } + + fn inner(&self) -> NonNull { + match self.0 { + None => panic!("PthreadRwLock constructed badly - something likely used RawMutex::INIT"), + Some(x) => x, + } + } +} + +unsafe impl lock_api::RawRwLock for PthreadRwLock { + type GuardMarker = lock_api::GuardSend; + const INIT: Self = Self(None); + + fn lock_shared(&self) { + unsafe { + let res = libc::pthread_rwlock_rdlock(self.inner().as_ptr()); + if res != 0 { + panic!("rdlock failed with {}", Errno::from_raw(res)); + } + } + } + + fn try_lock_shared(&self) -> bool { + unsafe { + let res = libc::pthread_rwlock_tryrdlock(self.inner().as_ptr()); + match res { + 0 => true, + libc::EAGAIN => false, + o => panic!("try_rdlock failed with {}", Errno::from_raw(res)), + } + } + } + + fn lock_exclusive(&self) { + unsafe { + let res = libc::pthread_rwlock_wrlock(self.inner().as_ptr()); + if res != 0 { + panic!("wrlock failed with {}", Errno::from_raw(res)); + } + } + } + + fn try_lock_exclusive(&self) -> bool { + unsafe { + let res = libc::pthread_rwlock_trywrlock(self.inner().as_ptr()); + match res { + 0 => true, + libc::EAGAIN => false, + o => panic!("try_wrlock failed with {}", Errno::from_raw(res)), + } + } + } + + unsafe fn unlock_exclusive(&self) { + unsafe { + let res = libc::pthread_rwlock_unlock(self.inner().as_ptr()); + if res != 0 { + panic!("unlock failed with {}", Errno::from_raw(res)); + } + } + } + unsafe fn unlock_shared(&self) { + unsafe { + let res = libc::pthread_rwlock_unlock(self.inner().as_ptr()); + if res != 0 { + panic!("unlock failed with {}", Errno::from_raw(res)); + } + } + } +}
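
For review context, a minimal usage sketch of the lock-guarded API this patch introduces (insert/get/remove/entry plus in-place grow). The key/value types and bucket counts below are arbitrary placeholders, not taken from the patch:

// Sketch only: exercises HashMapInit/HashMapAccess as defined in this patch.
use neon_shmem::hash::HashMapInit;
use neon_shmem::hash::entry::Entry;

fn example() {
    // Writer creates the map in an anonymous, resizable shared memory area.
    let map = HashMapInit::<u64, u64>::new_resizeable(1024, 4096).attach_writer();

    // insert()/get()/remove() take the per-map pthread rwlock internally;
    // get() returns a mapped read guard that derefs to the value.
    map.insert(1, 10).unwrap();
    assert_eq!(map.get(&1).as_deref().copied(), Some(10));

    // entry() holds the write lock for as long as the Entry lives, so drop it
    // before touching `map` again (the lock is not reentrant).
    match map.entry(2) {
        Entry::Occupied(mut e) => { e.insert(20); }
        Entry::Vacant(e) => { let _ = e.insert(20).unwrap(); }
    }
    assert_eq!(map.remove(&2), Some(20));

    // Growing resizes the shmem area and rebuilds the dictionary in place.
    map.grow(2048).unwrap();
    assert_eq!(map.get_num_buckets(), 2048);
}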