From 74330920ee6ced0099010b8ecf0974fbbb539ad6 Mon Sep 17 00:00:00 2001 From: David Freifeld Date: Fri, 27 Jun 2025 17:11:22 -0700 Subject: [PATCH] Simplify API, squash bugs, and expand hashmap test suite --- libs/neon-shmem/src/hash.rs | 108 ++++++++++++++++------------ libs/neon-shmem/src/hash/core.rs | 29 +------- libs/neon-shmem/src/hash/entry.rs | 27 +++++-- libs/neon-shmem/src/hash/tests.rs | 114 ++++++++++++++++++++++++------ 4 files changed, 178 insertions(+), 100 deletions(-) diff --git a/libs/neon-shmem/src/hash.rs b/libs/neon-shmem/src/hash.rs index 6cc641814a..7d47a4f5e5 100644 --- a/libs/neon-shmem/src/hash.rs +++ b/libs/neon-shmem/src/hash.rs @@ -224,40 +224,75 @@ where K: Clone + Hash + Eq, { /// Hash a key using the map's hasher. - pub fn get_hash_value(&self, key: &K) -> u64 { + #[inline] + fn get_hash_value(&self, key: &K) -> u64 { self.hasher.hash_one(key) } - /// Get a reference to the corresponding value for a key given its hash. - pub fn get_with_hash<'e>(&'e self, key: &K, hash: u64) -> Option<&'e V> { + /// Get a reference to the corresponding value for a key. + pub fn get<'e>(&'e self, key: &K) -> Option<&'e V> { let map = unsafe { self.shared_ptr.as_ref() }.unwrap(); - + let hash = self.get_hash_value(key); map.inner.get_with_hash(key, hash) } - /// Get a reference to the entry containing a key given its hash. - pub fn entry_with_hash(&mut self, key: K, hash: u64) -> Entry<'a, '_, K, V> { + /// Get a reference to the entry containing a key. + pub fn entry(&self, key: K) -> Entry<'a, '_, K, V> { let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - + let hash = self.get_hash_value(&key); map.inner.entry_with_hash(key, hash) } - /// Remove a key given its hash. Does nothing if key is not present. - pub fn remove_with_hash(&mut self, key: &K, hash: u64) { + /// Remove a key given its hash. Returns the associated value if it existed. + pub fn remove(&self, key: &K) -> Option { let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - + let hash = self.get_hash_value(&key); match map.inner.entry_with_hash(key.clone(), hash) { - Entry::Occupied(e) => { - e.remove(); - } - Entry::Vacant(_) => {} + Entry::Occupied(e) => Some(e.remove()), + Entry::Vacant(_) => None } } - /// Optionally return the entry for a bucket at a given index if it exists. - pub fn entry_at_bucket(&mut self, pos: usize) -> Option> { + /// Insert/update a key. Returns the previous associated value if it existed. + /// + /// # Errors + /// Will return [`core::FullError`] if there is no more space left in the map. + pub fn insert(&self, key: K, value: V) -> Result, core::FullError> { let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - map.inner.entry_at_bucket(pos) + let hash = self.get_hash_value(&key); + match map.inner.entry_with_hash(key.clone(), hash) { + Entry::Occupied(mut e) => Ok(Some(e.insert(value))), + Entry::Vacant(e) => { + e.insert(value)?; + Ok(None) + } + } + } + + /// Optionally return the entry for a bucket at a given index if it exists. + /// + /// Has more overhead than one would intuitively expect: performs both a clone of the key + /// due to the [`OccupiedEntry`] type owning the key and also a hash of the key in order + /// to enable repairing the hash chain if the entry is removed. + pub fn entry_at_bucket(&self, pos: usize) -> Option> { + let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); + let inner = &mut map.inner; + if pos >= inner.buckets.len() { + return None; + } + + let entry = inner.buckets[pos].inner.as_ref(); + match entry { + Some((key, _)) => Some(OccupiedEntry { + _key: key.clone(), + bucket_pos: pos as u32, + prev_pos: entry::PrevPos::Unknown( + self.get_hash_value(&key) + ), + map: inner, + }), + _ => None, + } } /// Returns the number of buckets in the table. @@ -299,7 +334,7 @@ where } /// Clears all entries in a table. Does not reset any shrinking operations. - pub fn clear(&mut self) { + pub fn clear(&self) { let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); let inner = &mut map.inner; inner.clear(); @@ -353,7 +388,7 @@ where } /// Rehash the map without growing or shrinking. - pub fn shuffle(&mut self) { + pub fn shuffle(&self) { let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); let inner = &mut map.inner; let num_buckets = inner.get_num_buckets() as u32; @@ -447,14 +482,17 @@ where /// - Calling this function on a map initialized with [`HashMapInit::with_fixed`]. /// - Calling this function on a map when no shrink operation is in progress. /// - Calling this function on a map with `shrink_mode` set to [`HashMapShrinkMode::Remap`] and - /// [`HashMapAccess::get_num_buckets_in_use`] returns a value higher than [`HashMapAccess::shrink_goal`]. + /// there are more buckets in use than the value returned by [`HashMapAccess::shrink_goal`]. /// /// # Errors /// Returns an [`shmem::Error`] if any errors occur resizing the memory region. - pub fn finish_shrink(&mut self) -> Result<(), shmem::Error> { + pub fn finish_shrink(&self) -> Result<(), shmem::Error> { let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); let inner = &mut map.inner; - assert!(inner.is_shrinking(), "called finish_shrink when no shrink is in progress"); + assert!( + inner.alloc_limit != INVALID_POS, + "called finish_shrink when no shrink is in progress" + ); let num_buckets = inner.alloc_limit; @@ -470,7 +508,7 @@ where for i in (num_buckets as usize)..inner.buckets.len() { if let Some((k, v)) = inner.buckets[i].inner.take() { - // alloc bucket increases buckets in use, so need to decrease since we're just moving + // alloc_bucket increases count, so need to decrease since we're just moving inner.buckets_in_use -= 1; inner.alloc_bucket(k, v).unwrap(); } @@ -491,28 +529,4 @@ where Ok(()) } - - #[cfg(feature = "stats")] - pub fn dict_len(&self) -> usize { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - map.inner.dictionary.len() - } - - #[cfg(feature = "stats")] - pub fn chain_distribution(&self) -> (Vec<(usize, usize)>, usize) { - let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - let mut out = Vec::new(); - let mut max = 0; - for (i, d) in map.inner.dictionary.iter().enumerate() { - let mut curr = *d; - let mut len = 0; - while curr != INVALID_POS { - curr = map.inner.buckets[curr as usize].next; - len += 1; - } - out.push((i, len)); - max = max.max(len); - } - (out, max) - } } diff --git a/libs/neon-shmem/src/hash/core.rs b/libs/neon-shmem/src/hash/core.rs index b2cf788d21..28c58e851e 100644 --- a/libs/neon-shmem/src/hash/core.rs +++ b/libs/neon-shmem/src/hash/core.rs @@ -34,7 +34,7 @@ pub(crate) struct CoreHashMap<'a, K, V> { } /// Error for when there are no empty buckets left but one is needed. -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct FullError(); impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> { @@ -155,11 +155,6 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> { self.buckets.len() } - /// Returns whether there is an ongoing shrink operation. - pub fn is_shrinking(&self) -> bool { - self.alloc_limit != INVALID_POS - } - /// Clears all entries from the hashmap. /// /// Does not reset any allocation limits, but does clear any entries beyond them. @@ -174,32 +169,14 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> { inner: None, } } - for i in 0..self.dictionary.len() { self.dictionary[i] = INVALID_POS; } + self.free_head = 0; self.buckets_in_use = 0; } - /// Optionally gets the entry at an index if it is occupied. - pub fn entry_at_bucket(&mut self, pos: usize) -> Option> { - if pos >= self.buckets.len() { - return None; - } - - let entry = self.buckets[pos].inner.as_ref(); - match entry { - Some((key, _)) => Some(OccupiedEntry { - _key: key.clone(), - bucket_pos: pos as u32, - prev_pos: PrevPos::Unknown, - map: self, - }), - _ => None, - } - } - /// Find the position of an unused bucket via the freelist and initialize it. pub(crate) fn alloc_bucket(&mut self, key: K, value: V) -> Result { let mut pos = self.free_head; @@ -225,7 +202,7 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> { let next_pos = self.buckets[pos as usize].next; self.buckets[p as usize].next = next_pos; }, - PrevPos::Unknown => unreachable!() + _ => unreachable!() } // Initialize the bucket. diff --git a/libs/neon-shmem/src/hash/entry.rs b/libs/neon-shmem/src/hash/entry.rs index 5231061b8e..b4c973d9f5 100644 --- a/libs/neon-shmem/src/hash/entry.rs +++ b/libs/neon-shmem/src/hash/entry.rs @@ -19,7 +19,7 @@ pub(crate) enum PrevPos { /// Regular index within the buckets. Chained(u32), /// Unknown - e.g. the associated entry was retrieved by index instead of chain. - Unknown, + Unknown(u64), } /// View into an occupied entry within the map. @@ -31,7 +31,7 @@ pub struct OccupiedEntry<'a, 'b, K, V> { /// The index of the previous entry in the chain. pub(crate) prev_pos: PrevPos, /// The position of the bucket in the [`CoreHashMap`] bucket array. - pub(crate) bucket_pos: u32, + pub(crate) bucket_pos: u32, } impl OccupiedEntry<'_, '_, K, V> { @@ -60,22 +60,39 @@ impl OccupiedEntry<'_, '_, K, V> { /// Removes the entry from the hash map, returning the value originally stored within it. /// + /// This may result in multiple bucket accesses if the entry was obtained by index as the + /// previous chain entry needs to be discovered in this case. + /// /// # Panics /// Panics if the `prev_pos` field is equal to [`PrevPos::Unknown`]. In practice, this means /// the entry was obtained via calling something like [`CoreHashMap::entry_at_bucket`]. pub fn remove(self) -> V { + // If this bucket was queried by index, go ahead and follow its chain from the start. + let prev = if let PrevPos::Unknown(hash) = self.prev_pos { + let dict_idx = hash as usize % self.map.dictionary.len(); + let mut prev = PrevPos::First(dict_idx as u32); + let mut curr = self.map.dictionary[dict_idx]; + while curr != self.bucket_pos { + curr = self.map.buckets[curr as usize].next; + prev = PrevPos::Chained(curr); + } + prev + } else { + self.prev_pos + }; + // CoreHashMap::remove returns Option<(K, V)>. We know it's Some for an OccupiedEntry. let bucket = &mut self.map.buckets[self.bucket_pos as usize]; - + // unlink it from the chain - match self.prev_pos { + match prev { PrevPos::First(dict_pos) => { self.map.dictionary[dict_pos as usize] = bucket.next; }, PrevPos::Chained(bucket_pos) => { self.map.buckets[bucket_pos as usize].next = bucket.next; }, - PrevPos::Unknown => panic!("can't safely remove entry with unknown previous entry"), + _ => unreachable!(), } // and add it to the freelist diff --git a/libs/neon-shmem/src/hash/tests.rs b/libs/neon-shmem/src/hash/tests.rs index 209db599b5..d838aa0b86 100644 --- a/libs/neon-shmem/src/hash/tests.rs +++ b/libs/neon-shmem/src/hash/tests.rs @@ -6,6 +6,7 @@ use std::mem::MaybeUninit; use crate::hash::HashMapAccess; use crate::hash::HashMapInit; use crate::hash::Entry; +use crate::hash::core::FullError; use rand::seq::SliceRandom; use rand::{Rng, RngCore}; @@ -40,8 +41,7 @@ fn test_inserts + Copy>(keys: &[K]) { ).attach_writer(); for (idx, k) in keys.iter().enumerate() { - let hash = w.get_hash_value(&(*k).into()); - let res = w.entry_with_hash((*k).into(), hash); + let res = w.entry((*k).into()); match res { Entry::Occupied(mut e) => { e.insert(idx); } Entry::Vacant(e) => { @@ -52,8 +52,7 @@ fn test_inserts + Copy>(keys: &[K]) { } for (idx, k) in keys.iter().enumerate() { - let hash = w.get_hash_value(&(*k).into()); - let x = w.get_with_hash(&(*k).into(), hash); + let x = w.get(&(*k).into()); let value = x.as_deref().copied(); assert_eq!(value, Some(idx)); } @@ -110,8 +109,7 @@ fn apply_op( shadow.remove(&op.0) }; - let hash = map.get_hash_value(&op.0); - let entry = map.entry_with_hash(op.0, hash); + let entry = map.entry(op.0); let hash_existing = match op.1 { Some(new) => { match entry { @@ -152,8 +150,7 @@ fn do_deletes( ) { for _ in 0..num_ops { let (k, _) = shadow.pop_first().unwrap(); - let hash = writer.get_hash_value(&k); - writer.remove_with_hash(&k, hash); + writer.remove(&k); } } @@ -162,16 +159,20 @@ fn do_shrink( shadow: &mut BTreeMap, to: u32 ) { + assert!(writer.shrink_goal().is_none()); writer.begin_shrink(to); + assert_eq!(writer.shrink_goal(), Some(to as usize)); while writer.get_num_buckets_in_use() > to as usize { let (k, _) = shadow.pop_first().unwrap(); - let hash = writer.get_hash_value(&k); - let entry = writer.entry_with_hash(k, hash); + let entry = writer.entry(k); if let Entry::Occupied(e) = entry { e.remove(); } } + let old_usage = writer.get_num_buckets_in_use(); writer.finish_shrink().unwrap(); + assert!(writer.shrink_goal().is_none()); + assert_eq!(writer.get_num_buckets_in_use(), old_usage); } #[test] @@ -219,10 +220,80 @@ fn test_grow() { let mut rng = rand::rng(); do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng); + let old_usage = writer.get_num_buckets_in_use(); writer.grow(1500).unwrap(); + assert_eq!(writer.get_num_buckets_in_use(), old_usage); + assert_eq!(writer.get_num_buckets(), 1500); do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng); } +#[test] +fn test_clear() { + let mut writer = HashMapInit::::new_resizeable_named( + 1500, 2000, "test_clear" + ).attach_writer(); + let mut shadow: std::collections::BTreeMap = BTreeMap::new(); + let mut rng = rand::rng(); + do_random_ops(2000, 1500, 0.75, &mut writer, &mut shadow, &mut rng); + writer.clear(); + assert_eq!(writer.get_num_buckets_in_use(), 0); + assert_eq!(writer.get_num_buckets(), 1500); + while let Some((key, _)) = shadow.pop_first() { + assert!(writer.get(&key).is_none()); + } + do_random_ops(2000, 1500, 0.75, &mut writer, &mut shadow, &mut rng); + for i in 0..(1500 - writer.get_num_buckets_in_use()) { + writer.insert((1500 + i as u128).into(), 0).unwrap(); + } + assert_eq!(writer.insert(5000.into(), 0), Err(FullError {})); + writer.clear(); + assert!(writer.insert(5000.into(), 0).is_ok()); +} + +#[test] +fn test_idx_remove() { + let mut writer = HashMapInit::::new_resizeable_named( + 1500, 2000, "test_clear" + ).attach_writer(); + let mut shadow: std::collections::BTreeMap = BTreeMap::new(); + let mut rng = rand::rng(); + do_random_ops(2000, 1500, 0.25, &mut writer, &mut shadow, &mut rng); + for _ in 0..100 { + let idx = (rng.next_u32() % 1500) as usize; + if let Some(e) = writer.entry_at_bucket(idx) { + shadow.remove(&e._key); + e.remove(); + } + + } + while let Some((key, val)) = shadow.pop_first() { + assert_eq!(writer.get(&key), Some(&val)); + } +} + +#[test] +fn test_idx_get() { + let mut writer = HashMapInit::::new_resizeable_named( + 1500, 2000, "test_clear" + ).attach_writer(); + let mut shadow: std::collections::BTreeMap = BTreeMap::new(); + let mut rng = rand::rng(); + do_random_ops(2000, 1500, 0.25, &mut writer, &mut shadow, &mut rng); + for _ in 0..100 { + let idx = (rng.next_u32() % 1500) as usize; + if let Some(mut e) = writer.entry_at_bucket(idx) { + { + let v: *const usize = e.get(); + assert_eq!(writer.get_bucket_for_value(v), idx); + } + { + let v: *const usize = e.get_mut(); + assert_eq!(writer.get_bucket_for_value(v), idx); + } + } + } +} + #[test] fn test_shrink() { let mut writer = HashMapInit::::new_resizeable_named( @@ -231,8 +302,9 @@ fn test_shrink() { let mut shadow: std::collections::BTreeMap = BTreeMap::new(); let mut rng = rand::rng(); - do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng); - do_shrink(&mut writer, &mut shadow, 1000); + do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng); + do_shrink(&mut writer, &mut shadow, 1000); + assert_eq!(writer.get_num_buckets(), 1000); do_deletes(500, &mut writer, &mut shadow); do_random_ops(10000, 500, 0.75, &mut writer, &mut shadow, &mut rng); assert!(writer.get_num_buckets_in_use() <= 1000); @@ -267,15 +339,14 @@ fn test_bucket_ops() { let mut writer = HashMapInit::::new_resizeable_named( 1000, 1200, "test_bucket_ops" ).attach_writer(); - let hash = writer.get_hash_value(&1.into()); - match writer.entry_with_hash(1.into(), hash) { + match writer.entry(1.into()) { Entry::Occupied(mut e) => { e.insert(2); }, Entry::Vacant(e) => { e.insert(2).unwrap(); }, } assert_eq!(writer.get_num_buckets_in_use(), 1); assert_eq!(writer.get_num_buckets(), 1000); - assert_eq!(writer.get_with_hash(&1.into(), hash), Some(&2)); - let pos = match writer.entry_with_hash(1.into(), hash) { + assert_eq!(writer.get(&1.into()), Some(&2)); + let pos = match writer.entry(1.into()) { Entry::Occupied(e) => { assert_eq!(e._key, 1.into()); let pos = e.bucket_pos as usize; @@ -285,10 +356,10 @@ fn test_bucket_ops() { }, Entry::Vacant(_) => { panic!("Insert didn't affect entry"); }, }; - let ptr: *const usize = writer.get_with_hash(&1.into(), hash).unwrap(); + let ptr: *const usize = writer.get(&1.into()).unwrap(); assert_eq!(writer.get_bucket_for_value(ptr), pos); - writer.remove_with_hash(&1.into(), hash); - assert_eq!(writer.get_with_hash(&1.into(), hash), None); + writer.remove(&1.into()); + assert_eq!(writer.get(&1.into()), None); } #[test] @@ -302,15 +373,14 @@ fn test_shrink_zero() { } writer.finish_shrink().unwrap(); assert_eq!(writer.get_num_buckets_in_use(), 0); - let hash = writer.get_hash_value(&1.into()); - let entry = writer.entry_with_hash(1.into(), hash); + let entry = writer.entry(1.into()); if let Entry::Vacant(v) = entry { assert!(v.insert(2).is_err()); } else { panic!("Somehow got non-vacant entry in empty map.") } writer.grow(50).unwrap(); - let entry = writer.entry_with_hash(1.into(), hash); + let entry = writer.entry(1.into()); if let Entry::Vacant(v) = entry { assert!(v.insert(2).is_ok()); } else {