From add1a0ad7859b1cc4f4291c305d69436fee6c0b9 Mon Sep 17 00:00:00 2001
From: David Freifeld
Date: Mon, 14 Jul 2025 09:02:56 -0700
Subject: [PATCH] Add initial work in implementing incremental resizing (WIP)

---
 libs/neon-shmem/src/hash.rs        |  50 ++++++------
 libs/neon-shmem/src/hash/bucket.rs |   5 +-
 libs/neon-shmem/src/hash/core.rs   | 122 +++++++++++++++++++++++------
 3 files changed, 122 insertions(+), 55 deletions(-)

diff --git a/libs/neon-shmem/src/hash.rs b/libs/neon-shmem/src/hash.rs
index 23864b2f2a..ef9533d542 100644
--- a/libs/neon-shmem/src/hash.rs
+++ b/libs/neon-shmem/src/hash.rs
@@ -354,26 +354,40 @@ where
         map.clear();
     }
 
-    pub fn rehash(
+    fn begin_rehash(
         &self,
         shards: &mut Vec>>,
         rehash_buckets: usize
-    ) {
+    ) -> bool {
         let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
         assert!(rehash_buckets <= map.get_num_buckets(), "rehashing subset of buckets");
+
+        if map.rehash_index.load(Ordering::Relaxed) >= map.rehash_end.load(Ordering::Relaxed) {
+            return false;
+        }
+
         shards.iter_mut().for_each(|x| x.keys.iter_mut().for_each(|key| {
-            if let EntryType::Occupied = key.tag {
-                key.tag = EntryType::Rehash;
+            match key.tag {
+                EntryType::Occupied => key.tag = EntryType::Rehash,
+                EntryType::Tombstone => key.tag = EntryType::RehashTombstone,
+                _ => (),
             }
         }));
-
-        todo!("solution with no memory allocation: split out metadata?")
+
+        map.rehash_index.store(0, Ordering::Relaxed);
+        map.rehash_end.store(rehash_buckets, Ordering::Relaxed);
+        true
     }
 
+    pub fn finish_rehash(&self) {
+        let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
+        while map.do_rehash() {}
+    }
+
     pub fn shuffle(&self) {
         let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
         let mut shards: Vec<_> = map.dict_shards.iter().map(|x| x.write()).collect();
-        self.rehash(&mut shards, map.get_num_buckets());
+        self.begin_rehash(&mut shards, map.get_num_buckets());
     }
 
     fn reshard(&self, shards: &mut Vec>>, num_buckets: usize) {
@@ -451,10 +465,10 @@ where
         }
 
         self.reshard(&mut shards, num_buckets);
-        self.rehash(&mut shards, old_num_buckets);
         map.bucket_arr.free_head.store(
             BucketIdx::new(old_num_buckets), Ordering::Relaxed
         );
+        self.begin_rehash(&mut shards, old_num_buckets);
 
         Ok(())
     }
@@ -500,30 +514,12 @@ where
             "called finish_shrink before enough entries were removed"
         );
 
-        // let shard_size = shards[0].len();
-        // for i in (num_buckets as usize)..map.buckets.len() {
-        //     let shard_start = num_buckets / shard_size;
-        //     let shard = shards[shard_start];
-        //     let entry_start = num_buckets % shard_size;
-        //     for entry_idx in entry_start..shard.len() {
-
-        //     }
-
-        //     if let EntryKey::Occupied(v) = map.[i].inner.take() {
-        //         // alloc_bucket increases count, so need to decrease since we're just moving
-        //         map.buckets_in_use.fetch_sub(1, Ordering::Relaxed);
-        //         map.alloc_bucket(k, v).unwrap();
-        //     }
-        // }
-
-        todo!("dry way to handle reinsertion");
-
         self.resize_shmem(num_buckets)?;
 
         self.reshard(&mut shards, num_buckets);
-        self.rehash(&mut shards, num_buckets);
         map.bucket_arr.alloc_limit.store(BucketIdx::INVALID, Ordering::Relaxed);
+        self.begin_rehash(&mut shards, num_buckets);
 
         Ok(())
     }
diff --git a/libs/neon-shmem/src/hash/bucket.rs b/libs/neon-shmem/src/hash/bucket.rs
index 1ccd6dd110..bbd69ca38f 100644
--- a/libs/neon-shmem/src/hash/bucket.rs
+++ b/libs/neon-shmem/src/hash/bucket.rs
@@ -7,6 +7,8 @@ use atomic::Atomic;
 #[repr(transparent)]
 pub(crate) struct BucketIdx(pub(super) u32);
 
+const _: () = assert!(Atomic::::is_lock_free());
+
 impl BucketIdx {
     const MARK_TAG: u32 = 0x80000000;
     pub const INVALID: Self = Self(0x7FFFFFFF);
@@ -104,8 +106,7 @@ pub(crate) struct BucketArray<'a, V> {
 }
 
 impl<'a, V> BucketArray<'a, V> {
-    pub fn new(buckets: &'a mut [Bucket]) -> Self {
-        debug_assert!(Atomic::::is_lock_free());
+    pub fn new(buckets: &'a mut [Bucket]) -> Self {
         Self {
             buckets,
             free_head: Atomic::new(BucketIdx(0)),
diff --git a/libs/neon-shmem/src/hash/core.rs b/libs/neon-shmem/src/hash/core.rs
index 2d31e7c556..647b267c45 100644
--- a/libs/neon-shmem/src/hash/core.rs
+++ b/libs/neon-shmem/src/hash/core.rs
@@ -2,6 +2,7 @@
 
 use std::hash::Hash;
 use std::mem::MaybeUninit;
+use std::sync::atomic::{Ordering, AtomicUsize};
 
 use crate::sync::*;
 use crate::hash::{
@@ -9,7 +10,7 @@ use crate::hash::{
     bucket::{BucketArray, Bucket, BucketIdx}
 };
 
-#[derive(PartialEq, Eq)]
+#[derive(PartialEq, Eq, Clone, Copy)]
 pub(crate) enum EntryType {
     Occupied,
     Rehash,
@@ -44,6 +45,8 @@ pub(crate) struct CoreHashMap<'a, K, V> {
     /// Dictionary used to map hashes to bucket indices.
     pub(crate) dict_shards: &'a mut [RwLock>],
     pub(crate) bucket_arr: BucketArray<'a, V>,
+    pub(crate) rehash_index: AtomicUsize,
+    pub(crate) rehash_end: AtomicUsize,
 }
 
 /// Error for when there are no empty buckets left but one is needed.
@@ -91,12 +94,75 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
 
         Self {
             dict_shards,
+            rehash_index: buckets.len().into(),
+            rehash_end: buckets.len().into(),
             bucket_arr: BucketArray::new(buckets),
         }
     }
+
+    // TODO(quantumish): off by one for return value logic?
+    pub fn do_rehash(&mut self) -> bool {
+        // TODO(quantumish): refactor these out into settable quantities
+        const REHASH_CHUNK_SIZE: usize = 10;
+        const REHASH_ATTEMPTS: usize = 5;
+
+        let end = self.rehash_end.load(Ordering::Relaxed);
+        let mut ind = self.rehash_index.load(Ordering::Relaxed);
+        let mut i = 0;
+        loop {
+            if ind >= end {
+                // TODO(quantumish) questionable?
+                self.rehash_index.store(end, Ordering::Relaxed);
+                return true;
+            }
+            if i > REHASH_ATTEMPTS {
+                break;
+            }
+            match self.rehash_index.compare_exchange_weak(
+                ind, ind + REHASH_CHUNK_SIZE,
+                Ordering::Relaxed, Ordering::Relaxed
+            ) {
+                Err(new_ind) => ind = new_ind,
+                Ok(_) => break,
+            }
+            i += 1;
+        }
+
+        todo!("actual rehashing");
+        false
+    }
+
+    pub fn get_with_hash(&'a self, key: &K, hash: u64) -> Option> {
+        let ind = self.rehash_index.load(Ordering::Relaxed);
+        let end = self.rehash_end.load(Ordering::Relaxed);
+
+        let first = ind >= end || ind < end/2;
+        if let Some(res) = self.get(key, hash, first) {
+            return Some(res);
+        }
+        if ind < end && let Some(res) = self.get(key, hash, !first) {
+            return Some(res);
+        }
+        None
+    }
+
+    pub fn entry_with_hash(&'a mut self, key: K, hash: u64) -> Result, FullError> {
+        let ind = self.rehash_index.load(Ordering::Relaxed);
+        let end = self.rehash_end.load(Ordering::Relaxed);
+
+        if ind < end {
+            if let Ok(Entry::Occupied(res)) = self.entry(key.clone(), hash, true) {
+                return Ok(Entry::Occupied(res));
+            } else {
+                return self.entry(key, hash, false);
+            }
+        } else {
+            return self.entry(key.clone(), hash, true);
+        }
+    }
 
     /// Get the value associated with a key (if it exists) given its hash.
-    pub fn get_with_hash(&'a self, key: &K, hash: u64) -> Option> {
+    fn get(&'a self, key: &K, hash: u64, ignore_remap: bool) -> Option> {
         let num_buckets = self.get_num_buckets();
         let shard_size = num_buckets / self.dict_shards.len();
         let bucket_pos = hash as usize % num_buckets;
@@ -108,24 +174,26 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
             for entry_idx in entry_start..shard.len() {
                 match shard.keys[entry_idx].tag {
                     EntryType::Empty => return None,
-                    EntryType::Tombstone => continue,
-                    EntryType::Occupied => {
-                        let cand_key = unsafe { shard.keys[entry_idx].val.assume_init_ref() };
-                        if cand_key == key {
-                            let bucket_idx = shard.idxs[entry_idx].pos_checked().expect("position is valid");
-                            return Some(RwLockReadGuard::map(
-                                shard, |_| self.bucket_arr.buckets[bucket_idx].as_ref()
-                            ));
-                        }
+                    EntryType::Tombstone | EntryType::RehashTombstone => continue,
+                    t @ (EntryType::Occupied | EntryType::Rehash) => {
+                        if (t == EntryType::Occupied && ignore_remap) || (t == EntryType::Rehash && !ignore_remap) {
+                            let cand_key = unsafe { shard.keys[entry_idx].val.assume_init_ref() };
+                            if cand_key == key {
+                                let bucket_idx = shard.idxs[entry_idx].pos_checked()
+                                    .expect("position is valid");
+                                return Some(RwLockReadGuard::map(
+                                    shard, |_| self.bucket_arr.buckets[bucket_idx].as_ref()
+                                ));
+                            }
+                        }
                     },
-                    _ => unreachable!(),
                 }
             }
         }
         None
     }
 
-    pub fn entry_with_hash(&'a mut self, key: K, hash: u64) -> Result, FullError> {
+    fn entry(&'a mut self, key: K, hash: u64, ignore_remap: bool) -> Result, FullError> {
         // We need to keep holding on the locks for each shard we process since if we don't find the
         // key anywhere, we want to insert it at the earliest possible position (which may be several
         // shards away). Ideally cross-shard chains are quite rare, so this shouldn't be a big deal.
@@ -158,26 +226,27 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
                         bucket_arr: &mut self.bucket_arr,
                     }))
                 },
-                EntryType::Tombstone => {
+                EntryType::Tombstone | EntryType::RehashTombstone => {
                     if insert_pos.is_none() {
                         insert_pos = Some(entry_idx);
                         inserted = true;
                     }
                 },
-                EntryType::Occupied => {
-                    let cand_key = unsafe { shard.keys[entry_idx].val.assume_init_ref() };
-                    if *cand_key == key {
-                        let bucket_pos = shard.idxs[entry_idx].pos_checked().unwrap();
-                        return Ok(Entry::Occupied(OccupiedEntry {
-                            _key: key,
-                            shard,
-                            shard_pos: entry_idx,
-                            bucket_pos,
-                            bucket_arr: &mut self.bucket_arr,
-                        }));
+                t @ (EntryType::Occupied | EntryType::Rehash) => {
+                    if (t == EntryType::Occupied && ignore_remap) || (t == EntryType::Rehash && !ignore_remap) {
+                        let cand_key = unsafe { shard.keys[entry_idx].val.assume_init_ref() };
+                        if *cand_key == key {
+                            let bucket_pos = shard.idxs[entry_idx].pos_checked().unwrap();
+                            return Ok(Entry::Occupied(OccupiedEntry {
+                                _key: key,
+                                shard,
+                                shard_pos: entry_idx,
+                                bucket_pos,
+                                bucket_arr: &mut self.bucket_arr,
+                            }));
+                        }
                     }
                 }
-                _ => unreachable!(),
             }
         }
         if inserted {
@@ -218,3 +287,4 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
         self.bucket_arr.clear();
     }
 }
+
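Note (not part of the patch): the claiming loop in do_rehash() above is the heart of the incremental scheme; concurrent callers carve the 0..rehash_end range into disjoint chunks with compare_exchange_weak instead of holding a lock while the table is migrated. Below is a minimal, self-contained Rust sketch of that same chunk-claiming pattern, since the per-chunk migration itself is still behind the todo!("actual rehashing"). All names here (Migrator, claim, CHUNK) are hypothetical illustration only and do not exist in the patch.

    use std::sync::atomic::{AtomicUsize, Ordering};

    const CHUNK: usize = 10;

    struct Migrator {
        cursor: AtomicUsize, // next index not yet claimed by any worker
        end: usize,          // one past the last index that needs migrating
    }

    impl Migrator {
        // Claim the next chunk of work. Returns a half-open range, or None once
        // the whole 0..end range has been handed out.
        fn claim(&self) -> Option<std::ops::Range<usize>> {
            let mut start = self.cursor.load(Ordering::Relaxed);
            loop {
                if start >= self.end {
                    return None;
                }
                let next = (start + CHUNK).min(self.end);
                // compare_exchange_weak may fail (or fail spuriously); on failure,
                // adopt the observed value and retry.
                match self
                    .cursor
                    .compare_exchange_weak(start, next, Ordering::Relaxed, Ordering::Relaxed)
                {
                    Ok(_) => return Some(start..next),
                    Err(observed) => start = observed,
                }
            }
        }
    }

    fn main() {
        let m = Migrator { cursor: AtomicUsize::new(0), end: 25 };
        while let Some(range) = m.claim() {
            // A real implementation would move the Rehash-tagged entries in
            // `range` to their post-resize dictionary positions here.
            println!("migrating buckets {:?}", range);
        }
    }

Because each claimed range is handed out exactly once, any number of threads can call claim() concurrently and the full range is still covered without duplicated work, which is the property the chunked cursor in do_rehash() appears to be aiming for.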