From 19b5618578ceab75b6ead3bf84d9ae607203624c Mon Sep 17 00:00:00 2001 From: David Freifeld Date: Wed, 2 Jul 2025 11:44:38 -0700 Subject: [PATCH] Switch to neon_shmem::sync lock_api and integrate into hashmap --- Cargo.lock | 1 + libs/neon-shmem/Cargo.toml | 1 + libs/neon-shmem/src/hash.rs | 233 ++++++++++++------------- libs/neon-shmem/src/hash/core.rs | 44 +---- libs/neon-shmem/src/hash/entry.rs | 27 +-- libs/neon-shmem/src/hash/tests.rs | 22 ++- libs/neon-shmem/src/sync.rs | 280 +++++++----------------------- 7 files changed, 210 insertions(+), 398 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4bb6d2d474..f798a1bdda 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4038,6 +4038,7 @@ dependencies = [ "foldhash", "hashbrown 0.15.4 (git+https://github.com/quantumish/hashbrown.git?rev=6610e6d)", "libc", + "lock_api", "nix 0.30.1", "rand 0.9.1", "rand_distr 0.5.1", diff --git a/libs/neon-shmem/Cargo.toml b/libs/neon-shmem/Cargo.toml index 8306bbf778..8ce5b52deb 100644 --- a/libs/neon-shmem/Cargo.toml +++ b/libs/neon-shmem/Cargo.toml @@ -11,6 +11,7 @@ workspace_hack = { version = "0.1", path = "../../workspace_hack" } rustc-hash = { version = "2.1.1" } rand = "0.9.1" libc.workspace = true +lock_api = "0.4.13" [dev-dependencies] criterion = { workspace = true, features = ["html_reports"] } diff --git a/libs/neon-shmem/src/hash.rs b/libs/neon-shmem/src/hash.rs index 7d47a4f5e5..733e4b6f33 100644 --- a/libs/neon-shmem/src/hash.rs +++ b/libs/neon-shmem/src/hash.rs @@ -16,9 +16,9 @@ use std::hash::{Hash, BuildHasher}; use std::mem::MaybeUninit; -use std::default::Default; -use crate::{shmem, shmem::ShmemHandle}; +use crate::{shmem, sync::*}; +use crate::shmem::ShmemHandle; mod core; pub mod entry; @@ -27,15 +27,14 @@ pub mod entry; mod tests; use core::{Bucket, CoreHashMap, INVALID_POS}; -use entry::{Entry, OccupiedEntry}; +use entry::{Entry, OccupiedEntry, VacantEntry, PrevPos}; /// Builder for a [`HashMapAccess`]. #[must_use] pub struct HashMapInit<'a, K, V, S = rustc_hash::FxBuildHasher> { shmem_handle: Option, - shared_ptr: *mut HashMapShared<'a, K, V>, + shared_ptr: *mut RwLock>, shared_size: usize, - shrink_mode: HashMapShrinkMode, hasher: S, num_buckets: u32, } @@ -45,28 +44,6 @@ pub struct HashMapAccess<'a, K, V, S = rustc_hash::FxBuildHasher> { shmem_handle: Option, shared_ptr: *mut HashMapShared<'a, K, V>, hasher: S, - shrink_mode: HashMapShrinkMode, -} - -/// Enum specifying what behavior to have surrounding occupied entries in what is -/// about-to-be-shrinked space during a call to [`HashMapAccess::finish_shrink`]. -#[derive(PartialEq, Eq)] -pub enum HashMapShrinkMode { - /// Remap entry to the range of buckets that will remain after shrinking. - /// - /// Requires that caller has left enough room within the map such that this is possible. - Remap, - /// Remove any entries remaining in soon to be deallocated space. - /// - /// Only really useful if you legitimately do not care what entries are removed. - /// Should primarily be used for testing. - Remove, -} - -impl Default for HashMapShrinkMode { - fn default() -> Self { - Self::Remap - } } unsafe impl Sync for HashMapAccess<'_, K, V, S> {} @@ -80,14 +57,9 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> { shared_ptr: self.shared_ptr, shared_size: self.shared_size, num_buckets: self.num_buckets, - shrink_mode: self.shrink_mode, } } - pub fn with_shrink_mode(self, mode: HashMapShrinkMode) -> Self { - Self { shrink_mode: mode, ..self } - } - /// Loosely (over)estimate the size needed to store a hash table with `num_buckets` buckets. pub fn estimate_size(num_buckets: u32) -> usize { // add some margin to cover alignment etc. @@ -96,13 +68,17 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> { /// Initialize a table for writing. pub fn attach_writer(self) -> HashMapAccess<'a, K, V, S> { - // carve out the HashMapShared struct from the area. let mut ptr: *mut u8 = self.shared_ptr.cast(); let end_ptr: *mut u8 = unsafe { ptr.add(self.shared_size) }; - ptr = unsafe { ptr.add(ptr.align_offset(align_of::>())) }; - let shared_ptr: *mut HashMapShared = ptr.cast(); - ptr = unsafe { ptr.add(size_of::>()) }; + // carve out area for the One Big Lock (TM) and the HashMapShared. + ptr = unsafe { ptr.add(ptr.align_offset(align_of::())) }; + let raw_lock_ptr = ptr; + ptr = unsafe { ptr.add(size_of::()) }; + ptr = unsafe { ptr.add(ptr.align_offset(align_of::>())) }; + let shared_ptr: *mut HashMapShared = ptr.cast(); + ptr = unsafe { ptr.add(size_of::>()) }; + // carve out the buckets ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::>())) }; let buckets_ptr = ptr; @@ -121,14 +97,14 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> { std::slice::from_raw_parts_mut(dictionary_ptr.cast(), dictionary_size as usize) }; let hashmap = CoreHashMap::new(buckets, dictionary); - unsafe { - std::ptr::write(shared_ptr, HashMapShared { inner: hashmap }); - } + let lock = RwLock::from_raw(PthreadRwLock::new(raw_lock_ptr.cast()), hashmap); + unsafe { + std::ptr::write(shared_ptr, lock); + } HashMapAccess { shmem_handle: self.shmem_handle, - shared_ptr: self.shared_ptr, - shrink_mode: self.shrink_mode, + shared_ptr, hasher: self.hasher, } } @@ -145,14 +121,13 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> { /// relies on the memory layout! The data structures are laid out in the contiguous shared memory /// area as follows: /// +/// [`libc::pthread_rwlock_t`] /// [`HashMapShared`] /// [buckets] /// [dictionary] /// /// In between the above parts, there can be padding bytes to align the parts correctly. -struct HashMapShared<'a, K, V> { - inner: CoreHashMap<'a, K, V> -} +type HashMapShared<'a, K, V> = RwLock>; impl<'a, K, V> HashMapInit<'a, K, V, rustc_hash::FxBuildHasher> where @@ -168,7 +143,6 @@ where shmem_handle: None, shared_ptr: area.as_mut_ptr().cast(), shared_size: area.len(), - shrink_mode: HashMapShrinkMode::default(), hasher: rustc_hash::FxBuildHasher, } } @@ -187,7 +161,6 @@ where shared_ptr: shmem.data_ptr.as_ptr().cast(), shmem_handle: Some(shmem), shared_size: size, - shrink_mode: HashMapShrinkMode::default(), hasher: rustc_hash::FxBuildHasher } } @@ -204,7 +177,6 @@ where shared_ptr: shmem.data_ptr.as_ptr().cast(), shmem_handle: Some(shmem), shared_size: size, - shrink_mode: HashMapShrinkMode::default(), hasher: rustc_hash::FxBuildHasher } } @@ -229,25 +201,64 @@ where self.hasher.hash_one(key) } + fn entry_with_hash(&self, key: K, hash: u64) -> Entry<'a, '_, K, V> { + let mut map = unsafe { self.shared_ptr.as_ref() }.unwrap().write(); + let dict_pos = hash as usize % map.dictionary.len(); + let first = map.dictionary[dict_pos]; + if first == INVALID_POS { + // no existing entry + return Entry::Vacant(VacantEntry { + map, + key, + dict_pos: dict_pos as u32, + }); + } + + let mut prev_pos = PrevPos::First(dict_pos as u32); + let mut next = first; + loop { + let bucket = &mut map.buckets[next as usize]; + let (bucket_key, _bucket_value) = bucket.inner.as_mut().expect("entry is in use"); + if *bucket_key == key { + // found existing entry + return Entry::Occupied(OccupiedEntry { + map, + _key: key, + prev_pos, + bucket_pos: next, + }); + } + + if bucket.next == INVALID_POS { + // No existing entry + return Entry::Vacant(VacantEntry { + map, + key, + dict_pos: dict_pos as u32, + }); + } + prev_pos = PrevPos::Chained(next); + next = bucket.next; + } + } + /// Get a reference to the corresponding value for a key. - pub fn get<'e>(&'e self, key: &K) -> Option<&'e V> { - let map = unsafe { self.shared_ptr.as_ref() }.unwrap(); + pub fn get<'e>(&'e self, key: &K) -> Option> { let hash = self.get_hash_value(key); - map.inner.get_with_hash(key, hash) + let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read(); + RwLockReadGuard::try_map(map, |m| m.get_with_hash(key, hash)).ok() } /// Get a reference to the entry containing a key. pub fn entry(&self, key: K) -> Entry<'a, '_, K, V> { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); let hash = self.get_hash_value(&key); - map.inner.entry_with_hash(key, hash) + self.entry_with_hash(key, hash) } /// Remove a key given its hash. Returns the associated value if it existed. pub fn remove(&self, key: &K) -> Option { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); let hash = self.get_hash_value(&key); - match map.inner.entry_with_hash(key.clone(), hash) { + match self.entry_with_hash(key.clone(), hash) { Entry::Occupied(e) => Some(e.remove()), Entry::Vacant(_) => None } @@ -258,12 +269,11 @@ where /// # Errors /// Will return [`core::FullError`] if there is no more space left in the map. pub fn insert(&self, key: K, value: V) -> Result, core::FullError> { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); let hash = self.get_hash_value(&key); - match map.inner.entry_with_hash(key.clone(), hash) { + match self.entry_with_hash(key.clone(), hash) { Entry::Occupied(mut e) => Ok(Some(e.insert(value))), Entry::Vacant(e) => { - e.insert(value)?; + _ = e.insert(value)?; Ok(None) } } @@ -275,13 +285,12 @@ where /// due to the [`OccupiedEntry`] type owning the key and also a hash of the key in order /// to enable repairing the hash chain if the entry is removed. pub fn entry_at_bucket(&self, pos: usize) -> Option> { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - let inner = &mut map.inner; - if pos >= inner.buckets.len() { + let map = unsafe { self.shared_ptr.as_mut() }.unwrap().write(); + if pos >= map.buckets.len() { return None; } - let entry = inner.buckets[pos].inner.as_ref(); + let entry = map.buckets[pos].inner.as_ref(); match entry { Some((key, _)) => Some(OccupiedEntry { _key: key.clone(), @@ -289,7 +298,7 @@ where prev_pos: entry::PrevPos::Unknown( self.get_hash_value(&key) ), - map: inner, + map, }), _ => None, } @@ -297,8 +306,8 @@ where /// Returns the number of buckets in the table. pub fn get_num_buckets(&self) -> usize { - let map = unsafe { self.shared_ptr.as_ref() }.unwrap(); - map.inner.get_num_buckets() + let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read(); + map.get_num_buckets() } /// Return the key and value stored in bucket with given index. This can be used to @@ -306,38 +315,35 @@ where // TODO: An Iterator might be nicer. The communicator's clock algorithm needs to // _slowly_ iterate through all buckets with its clock hand, without holding a lock. // If we switch to an Iterator, it must not hold the lock. - pub fn get_at_bucket(&self, pos: usize) -> Option<&(K, V)> { - let map = unsafe { self.shared_ptr.as_ref() }.unwrap(); - - if pos >= map.inner.buckets.len() { + pub fn get_at_bucket(&self, pos: usize) -> Option> { + let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read(); + if pos >= map.buckets.len() { return None; } - let bucket = &map.inner.buckets[pos]; - bucket.inner.as_ref() + RwLockReadGuard::try_map(map, |m| m.buckets[pos].inner.as_ref()).ok() } /// Returns the index of the bucket a given value corresponds to. pub fn get_bucket_for_value(&self, val_ptr: *const V) -> usize { - let map = unsafe { self.shared_ptr.as_ref() }.unwrap(); + let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read(); - let origin = map.inner.buckets.as_ptr(); + let origin = map.buckets.as_ptr(); let idx = (val_ptr as usize - origin as usize) / size_of::>(); - assert!(idx < map.inner.buckets.len()); + assert!(idx < map.buckets.len()); idx } /// Returns the number of occupied buckets in the table. pub fn get_num_buckets_in_use(&self) -> usize { - let map = unsafe { self.shared_ptr.as_ref() }.unwrap(); - map.inner.buckets_in_use as usize + let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read(); + map.buckets_in_use as usize } /// Clears all entries in a table. Does not reset any shrinking operations. pub fn clear(&self) { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - let inner = &mut map.inner; - inner.clear(); + let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write(); + map.clear(); } /// Perform an in-place rehash of some region (0..`rehash_buckets`) of the table and reset @@ -389,13 +395,12 @@ where /// Rehash the map without growing or shrinking. pub fn shuffle(&self) { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - let inner = &mut map.inner; - let num_buckets = inner.get_num_buckets() as u32; + let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write(); + let num_buckets = map.get_num_buckets() as u32; let size_bytes = HashMapInit::::estimate_size(num_buckets); let end_ptr: *mut u8 = unsafe { self.shared_ptr.byte_add(size_bytes).cast() }; - let buckets_ptr = inner.buckets.as_mut_ptr(); - self.rehash_dict(inner, buckets_ptr, end_ptr, num_buckets, num_buckets); + let buckets_ptr = map.buckets.as_mut_ptr(); + self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets); } /// Grow the number of buckets within the table. @@ -409,10 +414,9 @@ where /// /// # Errors /// Returns an [`shmem::Error`] if any errors occur resizing the memory region. - pub fn grow(&mut self, num_buckets: u32) -> Result<(), shmem::Error> { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - let inner = &mut map.inner; - let old_num_buckets = inner.buckets.len() as u32; + pub fn grow(&self, num_buckets: u32) -> Result<(), shmem::Error> { + let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write(); + let old_num_buckets = map.buckets.len() as u32; assert!(num_buckets >= old_num_buckets, "grow called with a smaller number of buckets"); if num_buckets == old_num_buckets { @@ -429,7 +433,7 @@ where // Initialize new buckets. The new buckets are linked to the free list. // NB: This overwrites the dictionary! - let buckets_ptr = inner.buckets.as_mut_ptr(); + let buckets_ptr = map.buckets.as_mut_ptr(); unsafe { for i in old_num_buckets..num_buckets { let bucket = buckets_ptr.add(i as usize); @@ -437,15 +441,15 @@ where next: if i < num_buckets-1 { i + 1 } else { - inner.free_head + map.free_head }, inner: None, }); } } - self.rehash_dict(inner, buckets_ptr, end_ptr, num_buckets, old_num_buckets); - inner.free_head = old_num_buckets; + self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, old_num_buckets); + map.free_head = old_num_buckets; Ok(()) } @@ -456,22 +460,22 @@ where /// Panics if called on a map initialized with [`HashMapInit::with_fixed`] or if `num_buckets` is /// greater than the number of buckets in the map. pub fn begin_shrink(&mut self, num_buckets: u32) { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); + let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write(); assert!( - num_buckets <= map.inner.get_num_buckets() as u32, + num_buckets <= map.get_num_buckets() as u32, "shrink called with a larger number of buckets" ); _ = self .shmem_handle .as_ref() .expect("shrink called on a fixed-size hash table"); - map.inner.alloc_limit = num_buckets; + map.alloc_limit = num_buckets; } /// If a shrink operation is underway, returns the target size of the map. Otherwise, returns None. pub fn shrink_goal(&self) -> Option { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - let goal = map.inner.alloc_limit; + let map = unsafe { self.shared_ptr.as_mut() }.unwrap().read(); + let goal = map.alloc_limit; if goal == INVALID_POS { None } else { Some(goal as usize) } } @@ -487,31 +491,28 @@ where /// # Errors /// Returns an [`shmem::Error`] if any errors occur resizing the memory region. pub fn finish_shrink(&self) -> Result<(), shmem::Error> { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - let inner = &mut map.inner; + let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write(); assert!( - inner.alloc_limit != INVALID_POS, + map.alloc_limit != INVALID_POS, "called finish_shrink when no shrink is in progress" ); - let num_buckets = inner.alloc_limit; + let num_buckets = map.alloc_limit; - if inner.get_num_buckets() == num_buckets as usize { + if map.get_num_buckets() == num_buckets as usize { return Ok(()); } - if self.shrink_mode == HashMapShrinkMode::Remap { - assert!( - inner.buckets_in_use <= num_buckets, - "called finish_shrink before enough entries were removed" - ); - - for i in (num_buckets as usize)..inner.buckets.len() { - if let Some((k, v)) = inner.buckets[i].inner.take() { - // alloc_bucket increases count, so need to decrease since we're just moving - inner.buckets_in_use -= 1; - inner.alloc_bucket(k, v).unwrap(); - } + assert!( + map.buckets_in_use <= num_buckets, + "called finish_shrink before enough entries were removed" + ); + + for i in (num_buckets as usize)..map.buckets.len() { + if let Some((k, v)) = map.buckets[i].inner.take() { + // alloc_bucket increases count, so need to decrease since we're just moving + map.buckets_in_use -= 1; + map.alloc_bucket(k, v).unwrap(); } } @@ -523,9 +524,9 @@ where let size_bytes = HashMapInit::::estimate_size(num_buckets); shmem_handle.set_size(size_bytes)?; let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) }; - let buckets_ptr = inner.buckets.as_mut_ptr(); - self.rehash_dict(inner, buckets_ptr, end_ptr, num_buckets, num_buckets); - inner.alloc_limit = INVALID_POS; + let buckets_ptr = map.buckets.as_mut_ptr(); + self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets); + map.alloc_limit = INVALID_POS; Ok(()) } diff --git a/libs/neon-shmem/src/hash/core.rs b/libs/neon-shmem/src/hash/core.rs index 28c58e851e..473c417ece 100644 --- a/libs/neon-shmem/src/hash/core.rs +++ b/libs/neon-shmem/src/hash/core.rs @@ -3,7 +3,7 @@ use std::hash::Hash; use std::mem::MaybeUninit; -use crate::hash::entry::{Entry, OccupiedEntry, PrevPos, VacantEntry}; +use crate::hash::entry::*; /// Invalid position within the map (either within the dictionary or bucket array). pub(crate) const INVALID_POS: u32 = u32::MAX; @@ -29,6 +29,7 @@ pub(crate) struct CoreHashMap<'a, K, V> { pub(crate) alloc_limit: u32, /// The number of currently occupied buckets. pub(crate) buckets_in_use: u32, + // pub(crate) lock: libc::pthread_mutex_t, // Unclear what the purpose of this is. pub(crate) _user_list_head: u32, } @@ -109,47 +110,6 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> { } } - /// Get the [`Entry`] associated with a key given hash. This should be used for updates/inserts. - pub fn entry_with_hash(&mut self, key: K, hash: u64) -> Entry<'a, '_, K, V> { - let dict_pos = hash as usize % self.dictionary.len(); - let first = self.dictionary[dict_pos]; - if first == INVALID_POS { - // no existing entry - return Entry::Vacant(VacantEntry { - map: self, - key, - dict_pos: dict_pos as u32, - }); - } - - let mut prev_pos = PrevPos::First(dict_pos as u32); - let mut next = first; - loop { - let bucket = &mut self.buckets[next as usize]; - let (bucket_key, _bucket_value) = bucket.inner.as_mut().expect("entry is in use"); - if *bucket_key == key { - // found existing entry - return Entry::Occupied(OccupiedEntry { - map: self, - _key: key, - prev_pos, - bucket_pos: next, - }); - } - - if bucket.next == INVALID_POS { - // No existing entry - return Entry::Vacant(VacantEntry { - map: self, - key, - dict_pos: dict_pos as u32, - }); - } - prev_pos = PrevPos::Chained(next); - next = bucket.next; - } - } - /// Get number of buckets in map. pub fn get_num_buckets(&self) -> usize { self.buckets.len() diff --git a/libs/neon-shmem/src/hash/entry.rs b/libs/neon-shmem/src/hash/entry.rs index b4c973d9f5..a5832665aa 100644 --- a/libs/neon-shmem/src/hash/entry.rs +++ b/libs/neon-shmem/src/hash/entry.rs @@ -1,11 +1,12 @@ //! Equivalent of [`std::collections::hash_map::Entry`] for this hashmap. use crate::hash::core::{CoreHashMap, FullError, INVALID_POS}; +use crate::sync::{RwLockWriteGuard, ValueWriteGuard}; use std::hash::Hash; use std::mem; -/// View into an entry in the map (either vacant or occupied). + pub enum Entry<'a, 'b, K, V> { Occupied(OccupiedEntry<'a, 'b, K, V>), Vacant(VacantEntry<'a, 'b, K, V>), @@ -22,10 +23,9 @@ pub(crate) enum PrevPos { Unknown(u64), } -/// View into an occupied entry within the map. pub struct OccupiedEntry<'a, 'b, K, V> { /// Mutable reference to the map containing this entry. - pub(crate) map: &'b mut CoreHashMap<'a, K, V>, + pub(crate) map: RwLockWriteGuard<'b, CoreHashMap<'a, K, V>>, /// The key of the occupied entry pub(crate) _key: K, /// The index of the previous entry in the chain. @@ -66,7 +66,7 @@ impl OccupiedEntry<'_, '_, K, V> { /// # Panics /// Panics if the `prev_pos` field is equal to [`PrevPos::Unknown`]. In practice, this means /// the entry was obtained via calling something like [`CoreHashMap::entry_at_bucket`]. - pub fn remove(self) -> V { + pub fn remove(mut self) -> V { // If this bucket was queried by index, go ahead and follow its chain from the start. let prev = if let PrevPos::Unknown(hash) = self.prev_pos { let dict_idx = hash as usize % self.map.dictionary.len(); @@ -90,15 +90,17 @@ impl OccupiedEntry<'_, '_, K, V> { self.map.dictionary[dict_pos as usize] = bucket.next; }, PrevPos::Chained(bucket_pos) => { + // println!("we think prev of {} is {bucket_pos}", self.bucket_pos); self.map.buckets[bucket_pos as usize].next = bucket.next; }, _ => unreachable!(), } - // and add it to the freelist + // and add it to the freelist + let free = self.map.free_head; let bucket = &mut self.map.buckets[self.bucket_pos as usize]; let old_value = bucket.inner.take(); - bucket.next = self.map.free_head; + bucket.next = free; self.map.free_head = self.bucket_pos; self.map.buckets_in_use -= 1; @@ -109,7 +111,7 @@ impl OccupiedEntry<'_, '_, K, V> { /// An abstract view into a vacant entry within the map. pub struct VacantEntry<'a, 'b, K, V> { /// Mutable reference to the map containing this entry. - pub(crate) map: &'b mut CoreHashMap<'a, K, V>, + pub(crate) map: RwLockWriteGuard<'b, CoreHashMap<'a, K, V>>, /// The key to be inserted into this entry. pub(crate) key: K, /// The position within the dictionary corresponding to the key's hash. @@ -121,16 +123,17 @@ impl<'b, K: Clone + Hash + Eq, V> VacantEntry<'_, 'b, K, V> { /// /// # Errors /// Will return [`FullError`] if there are no unoccupied buckets in the map. - pub fn insert(self, value: V) -> Result<&'b mut V, FullError> { + pub fn insert(mut self, value: V) -> Result, FullError> { let pos = self.map.alloc_bucket(self.key, value)?; if pos == INVALID_POS { return Err(FullError()); } - let bucket = &mut self.map.buckets[pos as usize]; - bucket.next = self.map.dictionary[self.dict_pos as usize]; + self.map.buckets[pos as usize].next = self.map.dictionary[self.dict_pos as usize]; self.map.dictionary[self.dict_pos as usize] = pos; - let result = &mut self.map.buckets[pos as usize].inner.as_mut().unwrap().1; - Ok(result) + Ok(RwLockWriteGuard::map( + self.map, + |m| &mut m.buckets[pos as usize].inner.as_mut().unwrap().1 + )) } } diff --git a/libs/neon-shmem/src/hash/tests.rs b/libs/neon-shmem/src/hash/tests.rs index d838aa0b86..0c760c56b7 100644 --- a/libs/neon-shmem/src/hash/tests.rs +++ b/libs/neon-shmem/src/hash/tests.rs @@ -36,7 +36,7 @@ impl<'a> From<&'a [u8]> for TestKey { } fn test_inserts + Copy>(keys: &[K]) { - let mut w = HashMapInit::::new_resizeable_named( + let w = HashMapInit::::new_resizeable_named( 100000, 120000, "test_inserts" ).attach_writer(); @@ -190,10 +190,6 @@ fn random_ops() { let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None }); apply_op(&op, &mut writer, &mut shadow); - - if i % 1000 == 0 { - eprintln!("{i} ops processed"); - } } } @@ -267,7 +263,7 @@ fn test_idx_remove() { } while let Some((key, val)) = shadow.pop_first() { - assert_eq!(writer.get(&key), Some(&val)); + assert_eq!(*writer.get(&key).unwrap(), val); } } @@ -326,8 +322,10 @@ fn test_shrink_grow_seq() { writer.grow(1500).unwrap(); do_random_ops(600, 1500, 0.1, &mut writer, &mut shadow, &mut rng); eprintln!("Shrinking to 200"); + while shadow.len() > 100 { + do_deletes(1, &mut writer, &mut shadow); + } do_shrink(&mut writer, &mut shadow, 200); - do_deletes(100, &mut writer, &mut shadow); do_random_ops(50, 1500, 0.25, &mut writer, &mut shadow, &mut rng); eprintln!("Growing to 10k"); writer.grow(10000).unwrap(); @@ -336,7 +334,7 @@ fn test_shrink_grow_seq() { #[test] fn test_bucket_ops() { - let mut writer = HashMapInit::::new_resizeable_named( + let writer = HashMapInit::::new_resizeable_named( 1000, 1200, "test_bucket_ops" ).attach_writer(); match writer.entry(1.into()) { @@ -345,21 +343,21 @@ fn test_bucket_ops() { } assert_eq!(writer.get_num_buckets_in_use(), 1); assert_eq!(writer.get_num_buckets(), 1000); - assert_eq!(writer.get(&1.into()), Some(&2)); + assert_eq!(*writer.get(&1.into()).unwrap(), 2); let pos = match writer.entry(1.into()) { Entry::Occupied(e) => { assert_eq!(e._key, 1.into()); let pos = e.bucket_pos as usize; assert_eq!(writer.entry_at_bucket(pos).unwrap()._key, 1.into()); - assert_eq!(writer.get_at_bucket(pos), Some(&(1.into(), 2))); + assert_eq!(*writer.get_at_bucket(pos).unwrap(), (1.into(), 2)); pos }, Entry::Vacant(_) => { panic!("Insert didn't affect entry"); }, }; - let ptr: *const usize = writer.get(&1.into()).unwrap(); + let ptr: *const usize = &*writer.get(&1.into()).unwrap(); assert_eq!(writer.get_bucket_for_value(ptr), pos); writer.remove(&1.into()); - assert_eq!(writer.get(&1.into()), None); + assert!(writer.get(&1.into()).is_none()); } #[test] diff --git a/libs/neon-shmem/src/sync.rs b/libs/neon-shmem/src/sync.rs index 68ef7b904d..8887299a92 100644 --- a/libs/neon-shmem/src/sync.rs +++ b/libs/neon-shmem/src/sync.rs @@ -2,43 +2,18 @@ use std::mem::MaybeUninit; use std::ptr::NonNull; -use std::cell::UnsafeCell; -use std::ops::{Deref, DerefMut}; -use thiserror::Error; + +pub type RwLock = lock_api::RwLock; +pub(crate) type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, PthreadRwLock, T>; +pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, PthreadRwLock, T>; +pub type ValueReadGuard<'a, T> = lock_api::MappedRwLockReadGuard<'a, PthreadRwLock, T>; +pub type ValueWriteGuard<'a, T> = lock_api::MappedRwLockWriteGuard<'a, PthreadRwLock, T>; /// Shared memory read-write lock. -struct RwLock<'a, T: ?Sized> { - inner: &'a mut libc::pthread_rwlock_t, - data: UnsafeCell, -} +pub struct PthreadRwLock(Option>); -/// RAII guard for a read lock. -struct RwLockReadGuard<'a, 'b, T: ?Sized> { - data: NonNull, - lock: &'a RwLock<'b, T>, -} - -/// RAII guard for a write lock. -struct RwLockWriteGuard<'a, 'b, T: ?Sized> { - lock: &'a RwLock<'b, T>, -} - -// TODO(quantumish): Support poisoning errors? -#[derive(Error, Debug)] -enum RwLockError { - #[error("deadlock detected")] - Deadlock, - #[error("max number of read locks exceeded")] - MaxReadLocks, - #[error("nonblocking operation would block")] - WouldBlock, -} - -unsafe impl Send for RwLock<'_, T> {} -unsafe impl Sync for RwLock<'_, T> {} - -impl<'a, T> RwLock<'a, T> { - fn new(lock: &'a mut MaybeUninit, data: T) -> Self { +impl PthreadRwLock { + pub fn new(lock: *mut libc::pthread_rwlock_t) -> Self { unsafe { let mut attrs = MaybeUninit::uninit(); // Ignoring return value here - only possible error is OOM. @@ -48,208 +23,81 @@ impl<'a, T> RwLock<'a, T> { libc::PTHREAD_PROCESS_SHARED ); // TODO(quantumish): worth making this function return Result? - libc::pthread_rwlock_init(lock.as_mut_ptr(), attrs.as_mut_ptr()); + libc::pthread_rwlock_init(lock, attrs.as_mut_ptr()); // Safety: POSIX specifies that "any function affecting the attributes // object (including destruction) shall not affect any previously // initialized read-write locks". libc::pthread_rwlockattr_destroy(attrs.as_mut_ptr()); - Self { - inner: lock.assume_init_mut(), - data: data.into(), - } - } - } - - fn read(&self) -> Result, RwLockError> { - unsafe { - let res = libc::pthread_rwlock_rdlock(self.inner as *const _ as *mut _); - match res { - 0 => (), - libc::EINVAL => panic!("failed to properly initialize lock"), - libc::EDEADLK => return Err(RwLockError::Deadlock), - libc::EAGAIN => return Err(RwLockError::MaxReadLocks), - e => panic!("unknown error code returned: {e}") - } - Ok(RwLockReadGuard { - data: NonNull::new_unchecked(self.data.get()), - lock: self - }) - } - } - - fn try_read(&self) -> Result, RwLockError> { - unsafe { - let res = libc::pthread_rwlock_tryrdlock(self.inner as *const _ as *mut _); - match res { - 0 => (), - libc::EINVAL => panic!("failed to properly initialize lock"), - libc::EDEADLK => return Err(RwLockError::Deadlock), - libc::EAGAIN => return Err(RwLockError::MaxReadLocks), - libc::EBUSY => return Err(RwLockError::WouldBlock), - e => panic!("unknown error code returned: {e}") - } - Ok(RwLockReadGuard { - data: NonNull::new_unchecked(self.data.get()), - lock: self - }) + Self(Some(NonNull::new_unchecked(lock))) } } - fn write(&self) -> Result, RwLockError> { + fn inner(&self) -> NonNull { + match self.0 { + None => panic!("PthreadRwLock constructed badly - something likely used RawMutex::INIT"), + Some(x) => x, + } + } +} + +unsafe impl lock_api::RawRwLock for PthreadRwLock { + type GuardMarker = lock_api::GuardSend; + const INIT: Self = Self(None); + + fn lock_shared(&self) { unsafe { - let res = libc::pthread_rwlock_wrlock(self.inner as *const _ as *mut _); - match res { - 0 => (), - libc::EINVAL => panic!("failed to properly initialize lock"), - libc::EDEADLK => return Err(RwLockError::Deadlock), - e => panic!("unknown error code returned: {e}") + let res = libc::pthread_rwlock_rdlock(self.inner().as_ptr()); + if res != 0 { + panic!("rdlock failed with {res}"); } } - Ok(RwLockWriteGuard { lock: self }) } - fn try_write(&self) -> Result, RwLockError> { + fn try_lock_shared(&self) -> bool { unsafe { - let res = libc::pthread_rwlock_trywrlock(self.inner as *const _ as *mut _); + let res = libc::pthread_rwlock_tryrdlock(self.inner().as_ptr()); match res { - 0 => (), - libc::EINVAL => panic!("failed to properly initialize lock"), - libc::EDEADLK => return Err(RwLockError::Deadlock), - libc::EBUSY => return Err(RwLockError::WouldBlock), - e => panic!("unknown error code returned: {e}") + 0 => true, + libc::EAGAIN => false, + o => panic!("try_rdlock failed with {o}") } } - Ok(RwLockWriteGuard { lock: self }) - } -} - -unsafe impl Sync for RwLockReadGuard<'_, '_, T> {} -unsafe impl Sync for RwLockWriteGuard<'_, '_, T> {} - -impl Deref for RwLockReadGuard<'_, '_, T> { - type Target = T; - - fn deref(&self) -> &T { - unsafe { self.data.as_ref() } - } -} - -impl Deref for RwLockWriteGuard<'_, '_, T> { - type Target = T; - - fn deref(&self) -> &T { - unsafe { &*self.lock.data.get() } - } -} - -impl DerefMut for RwLockWriteGuard<'_, '_, T> { - fn deref_mut(&mut self) -> &mut T { - unsafe { &mut *self.lock.data.get() } - } -} - -impl Drop for RwLockReadGuard<'_, '_, T> { - fn drop(&mut self) -> () { - let res = unsafe { libc::pthread_rwlock_unlock( - self.lock.inner as *const _ as *mut _ - ) }; - debug_assert!(res == 0); - } -} - -impl Drop for RwLockWriteGuard<'_, '_, T> { - fn drop(&mut self) -> () { - let res = unsafe { libc::pthread_rwlock_unlock( - self.lock.inner as *const _ as *mut _ - ) }; - debug_assert!(res == 0); - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use super::*; -use RwLockError::*; - - #[test] - fn test_single_process() { - let mut lock = MaybeUninit::uninit(); - let wrapper = RwLock::new(&mut lock, 0); - let mut writer = wrapper.write().unwrap(); - assert!(matches!(wrapper.try_write(), Err(Deadlock | WouldBlock))); - assert!(matches!(wrapper.try_read(), Err(Deadlock | WouldBlock))); - *writer = 5; - drop(writer); - let reader = wrapper.read().unwrap(); - assert!(matches!(wrapper.try_write(), Err(Deadlock | WouldBlock))); - assert!(matches!(wrapper.read(), Ok(_))); - assert_eq!(*reader, 5); - drop(reader); - assert!(matches!(wrapper.try_write(), Ok(_))); } - #[test] - fn test_multi_thread() { - let lock = Box::new(MaybeUninit::uninit()); - let wrapper = Arc::new(RwLock::new(Box::leak(lock), 0)); - let mut writer = wrapper.write().unwrap(); - let t1 = { - let wrapper = wrapper.clone(); - std::thread::spawn(move || { - let mut writer = wrapper.write().unwrap(); - *writer = 20; - }) - }; - assert_eq!(*writer, 0); - *writer = 10; - assert_eq!(*writer, 10); - drop(writer); - t1.join().unwrap(); - let mut writer = wrapper.write().unwrap(); - assert_eq!(*writer, 20); - drop(writer); - let mut handles = vec![]; - for _ in 0..5 { - handles.push({ - let wrapper = wrapper.clone(); - std::thread::spawn(move || { - let reader = wrapper.read().unwrap(); - assert_eq!(*reader, 20); - }) - }); + fn lock_exclusive(&self) { + unsafe { + let res = libc::pthread_rwlock_wrlock(self.inner().as_ptr()); + if res != 0 { + panic!("wrlock failed with {res}"); + } } - for h in handles { - h.join().unwrap(); - } - let writer = wrapper.write().unwrap(); - assert_eq!(*writer, 20); } - // // TODO(quantumish): Terrible time-based synchronization, fix me. - // #[test] - // fn test_multi_process() { - // let max_size = 100; - // let init_struct = crate::shmem::ShmemHandle::new("test_multi_process", 0, max_size).unwrap(); - // let ptr = init_struct.data_ptr.as_ptr(); - // let lock: &mut _ = unsafe { ptr.add( - // ptr.align_offset(std::mem::align_of::>()) - // ).cast::>().as_mut().unwrap() } ; - // let wrapper = RwLock::new(lock, 0); + fn try_lock_exclusive(&self) -> bool { + unsafe { + let res = libc::pthread_rwlock_trywrlock(self.inner().as_ptr()); + match res { + 0 => true, + libc::EAGAIN => false, + o => panic!("try_wrlock failed with {o}") + } + } + } - // let fork_result = unsafe { nix::unistd::fork().unwrap() }; - - // if !fork_result.is_parent() { - // let mut writer = wrapper.write().unwrap(); - // std::thread::sleep(std::time::Duration::from_secs(5)); - // *writer = 2; - // } else { - // std::thread::sleep(std::time::Duration::from_secs(1)); - // assert!(matches!(wrapper.try_write(), Err(WouldBlock))); - // std::thread::sleep(std::time::Duration::from_secs(10)); - // let writer = wrapper.try_write().unwrap(); - // assert_eq!(*writer, 2); - // } - // } + unsafe fn unlock_exclusive(&self) { + unsafe { + let res = libc::pthread_rwlock_unlock(self.inner().as_ptr()); + if res != 0 { + panic!("unlock failed with {res}"); + } + } + } + unsafe fn unlock_shared(&self) { + unsafe { + let res = libc::pthread_rwlock_unlock(self.inner().as_ptr()); + if res != 0 { + panic!("unlock failed with {res}"); + } + } + } }