From 282b90df287b1c1c20e3e1d03a9ef1b8f584e3fa Mon Sep 17 00:00:00 2001 From: David Freifeld Date: Fri, 11 Jul 2025 12:47:27 -0700 Subject: [PATCH] Fix freelist bug, clean up interface to BucketIdx --- libs/neon-shmem/src/hash.rs | 16 +++--- libs/neon-shmem/src/hash/bucket.rs | 81 +++++++++++++++++------------- libs/neon-shmem/src/hash/core.rs | 6 +-- libs/neon-shmem/src/hash/entry.rs | 2 +- libs/neon-shmem/src/hash/tests.rs | 5 +- 5 files changed, 64 insertions(+), 46 deletions(-) diff --git a/libs/neon-shmem/src/hash.rs b/libs/neon-shmem/src/hash.rs index 1e1ba76851..23864b2f2a 100644 --- a/libs/neon-shmem/src/hash.rs +++ b/libs/neon-shmem/src/hash.rs @@ -311,17 +311,21 @@ where } } - /// Optionally return reference to a bucket at a given index if it exists. - pub fn get_at_bucket(&self, pos: usize) -> Option<&V> { + pub unsafe fn get_at_bucket(&self, pos: usize) -> Option<&V> { let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); if pos >= map.bucket_arr.buckets.len() { return None; } - todo!("safely check if a given bucket is empty? always mark?"); + let bucket = &map.bucket_arr.buckets[pos]; + if bucket.next.load(Ordering::Relaxed) == BucketIdx::RESERVED { + Some(unsafe { bucket.val.assume_init_ref() }) + } else { + None + } } - /// Returns the number of buckets in the table. + /// bucket the number of buckets in the table. pub fn get_num_buckets(&self) -> usize { let map = unsafe { self.shared_ptr.as_ref() }.unwrap(); map.get_num_buckets() @@ -442,7 +446,7 @@ where let idxs_ptr = first_shard.idxs.as_mut_ptr(); for i in old_num_buckets..num_buckets { let idx = idxs_ptr.add(i); - idx.write(BucketIdx::invalid()); + idx.write(BucketIdx::INVALID); } } @@ -519,7 +523,7 @@ where self.reshard(&mut shards, num_buckets); self.rehash(&mut shards, num_buckets); - map.bucket_arr.alloc_limit.store(BucketIdx::invalid(), Ordering::Relaxed); + map.bucket_arr.alloc_limit.store(BucketIdx::INVALID, Ordering::Relaxed); Ok(()) } diff --git a/libs/neon-shmem/src/hash/bucket.rs b/libs/neon-shmem/src/hash/bucket.rs index 4b0780ab75..1ccd6dd110 100644 --- a/libs/neon-shmem/src/hash/bucket.rs +++ b/libs/neon-shmem/src/hash/bucket.rs @@ -8,35 +8,29 @@ use atomic::Atomic; pub(crate) struct BucketIdx(pub(super) u32); impl BucketIdx { - pub const INVALID: u32 = 0x7FFFFFFF; - pub const MARK_TAG: u32 = 0x80000000; + const MARK_TAG: u32 = 0x80000000; + pub const INVALID: Self = Self(0x7FFFFFFF); + pub const RESERVED: Self = Self(0x7FFFFFFE); + pub const MAX: usize = Self::RESERVED.0 as usize - 1; pub(super) fn is_marked(&self) -> bool { self.0 & Self::MARK_TAG != 0 } - pub(super) fn is_invalid(self) -> bool { - self.0 & Self::INVALID == Self::INVALID - } - pub(super) fn as_marked(self) -> Self { Self(self.0 | Self::MARK_TAG) } pub(super) fn get_unmarked(self) -> Self { - Self(self.0 & Self::INVALID) + Self(self.0 & !Self::MARK_TAG) } pub fn new(val: usize) -> Self { Self(val as u32) } - - pub fn invalid() -> Self { - Self(Self::INVALID) - } - + pub fn pos_checked(&self) -> Option { - if self.0 == Self::INVALID || self.is_marked() { + if *self == Self::INVALID || self.is_marked() { None } else { Some(self.0 as usize) @@ -44,7 +38,22 @@ impl BucketIdx { } } -/// Fundamental storage unit within the hash table. Either empty or contains a key-value pair. +impl std::fmt::Debug for BucketIdx { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + let idx = self.get_unmarked().0; + write!( + f, "BucketIdx(marked={}, idx={})", + self.is_marked(), + match *self { + Self::INVALID => "INVALID".to_string(), + Self::RESERVED => "RESERVED".to_string(), + _ => format!("{idx}") + } + ) + } +} + +/// format storage unit within the hash table. Either empty or contains a key-value pair. /// Always part of a chain of some kind (either a freelist if empty or a hash chain if full). pub(crate) struct Bucket { pub val: MaybeUninit, @@ -62,7 +71,7 @@ impl Bucket { pub fn full(val: V) -> Self { Self { val: MaybeUninit::new(val), - next: Atomic::new(BucketIdx::invalid()) + next: Atomic::new(BucketIdx::INVALID) } } @@ -101,7 +110,7 @@ impl<'a, V> BucketArray<'a, V> { buckets, free_head: Atomic::new(BucketIdx(0)), _user_list_head: Atomic::new(BucketIdx(0)), - alloc_limit: Atomic::new(BucketIdx::invalid()), + alloc_limit: Atomic::new(BucketIdx::INVALID), buckets_in_use: 0.into(), } } @@ -111,7 +120,7 @@ impl<'a, V> BucketArray<'a, V> { let pos = BucketIdx::new(pos); loop { let free = self.free_head.load(Ordering::Relaxed); - bucket.next = Atomic::new(free); + bucket.next.store(free, Ordering::Relaxed); if self.free_head.compare_exchange_weak( free, pos, Ordering::Relaxed, Ordering::Relaxed ).is_ok() { @@ -123,64 +132,65 @@ impl<'a, V> BucketArray<'a, V> { #[allow(unused_assignments)] fn find_bucket(&self) -> (BucketIdx, BucketIdx) { - let mut left_node = BucketIdx::invalid(); - let mut right_node = BucketIdx::invalid(); - let mut left_node_next = BucketIdx::invalid(); + let mut left_node = BucketIdx::INVALID; + let mut right_node = BucketIdx::INVALID; + let mut left_node_next = BucketIdx::INVALID; loop { - let mut t = BucketIdx::invalid(); + let mut t = BucketIdx::INVALID; let mut t_next = self.free_head.load(Ordering::Relaxed); let alloc_limit = self.alloc_limit.load(Ordering::Relaxed).pos_checked(); while t_next.is_marked() || t.pos_checked() .map_or(true, |v| alloc_limit.map_or(false, |l| v > l)) { - if t_next.is_marked() { + if !t_next.is_marked() { left_node = t; left_node_next = t_next; } t = t_next.get_unmarked(); - if t.is_invalid() { break } + if t == BucketIdx::INVALID { break } t_next = self.buckets[t.0 as usize].next.load(Ordering::Relaxed); } right_node = t; if left_node_next == right_node { - if !right_node.is_invalid() && self.buckets[right_node.0 as usize] + if right_node != BucketIdx::INVALID && self.buckets[right_node.0 as usize] .next.load(Ordering::Relaxed).is_marked() - { + { continue; } else { return (left_node, right_node); } } - let left_ref = if !left_node.is_invalid() { + let left_ref = if left_node != BucketIdx::INVALID { &self.buckets[left_node.0 as usize].next } else { &self.free_head }; if left_ref.compare_exchange_weak( left_node_next, right_node, Ordering::Relaxed, Ordering::Relaxed ).is_ok() { - if !right_node.is_invalid() && self.buckets[right_node.0 as usize] + if right_node != BucketIdx::INVALID && self.buckets[right_node.0 as usize] .next.load(Ordering::Relaxed).is_marked() { continue; } else { return (left_node, right_node); } - } + } } } #[allow(unused_assignments)] pub(crate) fn alloc_bucket(&mut self, value: V) -> Option { - let mut right_node_next = BucketIdx::invalid(); - let mut left_idx = BucketIdx::invalid(); - let mut right_idx = BucketIdx::invalid(); + // println!("alloc()"); + let mut right_node_next = BucketIdx::INVALID; + let mut left_idx = BucketIdx::INVALID; + let mut right_idx = BucketIdx::INVALID; loop { (left_idx, right_idx) = self.find_bucket(); - if right_idx.is_invalid() { + if right_idx == BucketIdx::INVALID { return None; } @@ -196,7 +206,7 @@ impl<'a, V> BucketArray<'a, V> { } } - let left_ref = if !left_idx.is_invalid() { + let left_ref = if left_idx != BucketIdx::INVALID { &self.buckets[left_idx.0 as usize].next } else { &self.free_head @@ -211,6 +221,9 @@ impl<'a, V> BucketArray<'a, V> { self.buckets_in_use.fetch_add(1, Ordering::Relaxed); self.buckets[right_idx.0 as usize].val.write(value); + self.buckets[right_idx.0 as usize].next.store( + BucketIdx::RESERVED, Ordering::Relaxed + ); Some(right_idx) } @@ -220,7 +233,7 @@ impl<'a, V> BucketArray<'a, V> { if i < self.buckets.len() - 1 { BucketIdx::new(i + 1) } else { - BucketIdx::invalid() + BucketIdx::INVALID } ); } diff --git a/libs/neon-shmem/src/hash/core.rs b/libs/neon-shmem/src/hash/core.rs index 5627f6e912..2d31e7c556 100644 --- a/libs/neon-shmem/src/hash/core.rs +++ b/libs/neon-shmem/src/hash/core.rs @@ -61,7 +61,7 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> { if i < buckets.len() - 1 { BucketIdx::new(i + 1) } else { - BucketIdx::invalid() + BucketIdx::INVALID }) ); } @@ -76,7 +76,7 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> { }); } for e in dicts.idxs.iter_mut() { - e.write(BucketIdx::invalid()); + e.write(BucketIdx::INVALID); } } @@ -211,7 +211,7 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> { e.tag = EntryType::Empty; } for e in shard.idxs.iter_mut() { - *e = BucketIdx::invalid(); + *e = BucketIdx::INVALID; } } diff --git a/libs/neon-shmem/src/hash/entry.rs b/libs/neon-shmem/src/hash/entry.rs index 008c92ea70..6fcd2ac287 100644 --- a/libs/neon-shmem/src/hash/entry.rs +++ b/libs/neon-shmem/src/hash/entry.rs @@ -42,7 +42,7 @@ impl OccupiedEntry<'_, K, V> { /// Removes the entry from the hash map, returning the value originally stored within it. pub fn remove(&mut self) -> V { - self.shard.idxs[self.shard_pos] = BucketIdx::invalid(); + self.shard.idxs[self.shard_pos] = BucketIdx::INVALID; self.shard.keys[self.shard_pos].tag = EntryType::Tombstone; self.bucket_arr.dealloc_bucket(self.bucket_pos) } diff --git a/libs/neon-shmem/src/hash/tests.rs b/libs/neon-shmem/src/hash/tests.rs index 2e12b029a4..7262eec86a 100644 --- a/libs/neon-shmem/src/hash/tests.rs +++ b/libs/neon-shmem/src/hash/tests.rs @@ -332,7 +332,8 @@ fn test_clear() { #[test] fn test_bucket_ops() { - let writer = HashMapInit::::new_resizeable_named(1000, 1200, 10, "test_bucket_ops") + + let writer = HashMapInit::::new_resizeable_named(1000, 1200, 10, "test_bucket_ops") .attach_writer(); match writer.entry(1.into()).unwrap() { Entry::Occupied(mut e) => { @@ -355,7 +356,7 @@ fn test_bucket_ops() { panic!("Insert didn't affect entry"); } }; - assert_eq!(writer.get_at_bucket(pos).unwrap(), &2); + assert_eq!(unsafe { writer.get_at_bucket(pos).unwrap() }, &2); { let ptr: *const usize = &*writer.get(&1.into()).unwrap(); assert_eq!(writer.get_bucket_for_value(ptr), pos);