diff --git a/libs/neon-shmem/src/hash.rs b/libs/neon-shmem/src/hash.rs index 9b1c1cee89..907a32bfca 100644 --- a/libs/neon-shmem/src/hash.rs +++ b/libs/neon-shmem/src/hash.rs @@ -19,7 +19,7 @@ pub mod entry; #[cfg(test)] mod tests; -use core::CoreHashMap; +use core::{CoreHashMap, INVALID_POS}; use entry::{Entry, OccupiedEntry}; #[derive(Debug)] @@ -210,6 +210,53 @@ where map.inner.buckets_in_use as usize } + /// Helper function that abstracts the common logic between growing and shrinking. + /// The only significant difference in the rehashing step is how many buckets to rehash! + fn rehash_dict( + &mut self, + inner: &mut CoreHashMap<'a, K, V>, + buckets_ptr: *mut core::Bucket, + end_ptr: *mut u8, + num_buckets: u32, + rehash_buckets: u32, + ) { + // Recalculate the dictionary + let buckets; + let dictionary; + unsafe { + let buckets_end_ptr = buckets_ptr.add(num_buckets as usize); + let dictionary_ptr: *mut u32 = buckets_end_ptr + .byte_add(buckets_end_ptr.align_offset(align_of::())) + .cast(); + let dictionary_size: usize = + end_ptr.byte_offset_from(buckets_end_ptr) as usize / size_of::(); + + buckets = std::slice::from_raw_parts_mut(buckets_ptr, num_buckets as usize); + dictionary = std::slice::from_raw_parts_mut(dictionary_ptr, dictionary_size); + } + for i in 0..dictionary.len() { + dictionary[i] = INVALID_POS; + } + + for i in 0..rehash_buckets as usize { + if buckets[i].inner.is_none() { + continue; + } + + let mut hasher = DefaultHasher::new(); + buckets[i].inner.as_ref().unwrap().0.hash(&mut hasher); + let hash = hasher.finish(); + + let pos: usize = (hash % dictionary.len() as u64) as usize; + buckets[i].next = dictionary[pos]; + dictionary[pos] = i as u32; + } + + // Finally, update the CoreHashMap struct + inner.dictionary = dictionary; + inner.buckets = buckets; + } + /// Grow /// /// 1. grow the underlying shared memory area @@ -247,46 +294,17 @@ where } else { inner.free_head }, + prev: if i > 0 { + i as u32 - 1 + } else { + INVALID_POS + }, inner: None, }); } } - // Recalculate the dictionary - let buckets; - let dictionary; - unsafe { - let buckets_end_ptr = buckets_ptr.add(num_buckets as usize); - let dictionary_ptr: *mut u32 = buckets_end_ptr - .byte_add(buckets_end_ptr.align_offset(align_of::())) - .cast(); - let dictionary_size: usize = - end_ptr.byte_offset_from(buckets_end_ptr) as usize / size_of::(); - - buckets = std::slice::from_raw_parts_mut(buckets_ptr, num_buckets as usize); - dictionary = std::slice::from_raw_parts_mut(dictionary_ptr, dictionary_size); - } - for i in 0..dictionary.len() { - dictionary[i] = core::INVALID_POS; - } - - for i in 0..old_num_buckets as usize { - if buckets[i].inner.is_none() { - continue; - } - - let mut hasher = DefaultHasher::new(); - buckets[i].inner.as_ref().unwrap().0.hash(&mut hasher); - let hash = hasher.finish(); - - let pos: usize = (hash % dictionary.len() as u64) as usize; - buckets[i].next = dictionary[pos]; - dictionary[pos] = i as u32; - } - - // Finally, update the CoreHashMap struct - inner.dictionary = dictionary; - inner.buckets = buckets; + self.rehash_dict(inner, buckets_ptr, end_ptr, num_buckets, old_num_buckets); inner.free_head = old_num_buckets; Ok(()) @@ -294,7 +312,7 @@ where fn begin_shrink(&mut self, num_buckets: u32) { let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); - if num_buckets < map.inner.get_num_buckets() as u32 { + if num_buckets > map.inner.get_num_buckets() as u32 { panic!("shrink called with a larger number of buckets"); } map.inner.alloc_limit = num_buckets; @@ -307,14 +325,14 @@ where panic!("called finish_shrink when no shrink is in progress"); } - let new_num_buckets = inner.alloc_limit; + let num_buckets = inner.alloc_limit; - if inner.get_num_buckets() == new_num_buckets as usize { + if inner.get_num_buckets() == num_buckets as usize { return Ok(()); } - for b in &inner.buckets[new_num_buckets as usize..] { - if b.inner.is_some() { + for i in (num_buckets as usize)..inner.buckets.len() { + if inner.buckets[i].inner.is_some() { // TODO(quantumish) Do we want to treat this as a violation of an invariant // or a legitimate error the caller can run into? Originally I thought this // could return something like a UnevictedError(index) as soon as it runs @@ -324,6 +342,10 @@ where // Would require making a wider error type enum with this and shmem errors. panic!("unevicted entries in shrinked space") } + let prev_pos = inner.buckets[i].prev; + if prev_pos != INVALID_POS { + inner.buckets[prev_pos as usize].next = inner.buckets[i].next; + } } let shmem_handle = self @@ -331,22 +353,13 @@ where .as_ref() .expect("shrink called on a fixed-size hash table"); - let size_bytes = HashMapInit::::estimate_size(new_num_buckets); + let size_bytes = HashMapInit::::estimate_size(num_buckets); shmem_handle.set_size(size_bytes)?; let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) }; - let buckets_ptr = inner.buckets.as_mut_ptr(); + self.rehash_dict(inner, buckets_ptr, end_ptr, num_buckets, num_buckets); Ok(()) } - // TODO: Shrinking is a multi-step process that requires co-operation from the caller - // - // 1. The caller must first call begin_shrink(). That forbids allocation of higher-numbered - // buckets. - // - // 2. Next, the caller must evict all entries in higher-numbered buckets. - // - // 3. Finally, call finish_shrink(). This recomputes the dictionary and shrinks the underlying - // shmem area } diff --git a/libs/neon-shmem/src/hash/core.rs b/libs/neon-shmem/src/hash/core.rs index eb60e21bad..1e0ebede4a 100644 --- a/libs/neon-shmem/src/hash/core.rs +++ b/libs/neon-shmem/src/hash/core.rs @@ -13,6 +13,7 @@ pub(crate) const INVALID_POS: u32 = u32::MAX; // Bucket pub(crate) struct Bucket { pub(crate) next: u32, + pub(crate) prev: u32, pub(crate) inner: Option<(K, V)>, } @@ -22,6 +23,7 @@ pub(crate) struct CoreHashMap<'a, K, V> { pub(crate) free_head: u32, pub(crate) _user_list_head: u32, + /// Maximum index of a bucket allowed to be allocated. INVALID_POS if no limit. pub(crate) alloc_limit: u32, // metrics @@ -62,6 +64,11 @@ where } else { INVALID_POS }, + prev: if i > 0 { + i as u32 - 1 + } else { + INVALID_POS + }, inner: None, }); } @@ -153,45 +160,61 @@ where self.alloc_limit != INVALID_POS } - pub fn entry_at_bucket(&mut self, pos: usize) -> Option> { - if pos >= self.buckets.len() { - return None; - } + pub fn entry_at_bucket(&mut self, pos: usize) -> Option> { + if pos >= self.buckets.len() { + return None; + } + let prev = self.buckets[pos].prev; let entry = self.buckets[pos].inner.as_ref(); if entry.is_none() { return None; - } - - let (key, _) = entry.unwrap(); + } + + let (key, _) = entry.unwrap(); Some(OccupiedEntry { _key: key.clone(), // TODO(quantumish): clone unavoidable? bucket_pos: pos as u32, map: self, - prev_pos: todo!(), // TODO(quantumish): possibly needs O(n) traversals to rediscover - costly! + prev_pos: if prev == INVALID_POS { + // TODO(quantumish): populating this correctly would require an O(n) scan over the dictionary + // (perhaps not if we refactored the prev field to be itself something like PrevPos). The real + // question though is whether this even needs to be populated correctly? All downstream uses of + // this function so far are just for deletion, which isn't really concerned with the dictionary. + // Then again, it's unintuitive to appear to return a normal OccupiedEntry which really is fake. + PrevPos::First(todo!("unclear what to do here")) + } else { + PrevPos::Chained(prev) + } }) } pub(crate) fn alloc_bucket(&mut self, key: K, value: V) -> Result { - let mut pos = self.free_head; + let mut pos = self.free_head; - // TODO(quantumish): relies on INVALID_POS being u32::MAX by default! - // instead add a clause `pos != INVALID_POS`? let mut prev = PrevPos::First(self.free_head); - while pos < self.alloc_limit { - if pos == INVALID_POS { - return Err(FullError()); - } + while pos!= INVALID_POS && pos >= self.alloc_limit { let bucket = &mut self.buckets[pos as usize]; prev = PrevPos::Chained(pos); pos = bucket.next; } - let bucket = &mut self.buckets[pos as usize]; + if pos == INVALID_POS { + return Err(FullError()); + } match prev { - PrevPos::First(_) => self.free_head = bucket.next, - PrevPos::Chained(p) => self.buckets[p].next = bucket.next, + PrevPos::First(_) => { + let next_pos = self.buckets[pos as usize].next; + self.free_head = next_pos; + self.buckets[next_pos as usize].prev = INVALID_POS; + } + PrevPos::Chained(p) => if p != INVALID_POS { + let next_pos = self.buckets[pos as usize].next; + self.buckets[p as usize].next = next_pos; + self.buckets[next_pos as usize].prev = p; + }, } + let bucket = &mut self.buckets[pos as usize]; self.buckets_in_use += 1; bucket.next = INVALID_POS; bucket.inner = Some((key, value)); @@ -199,3 +222,5 @@ where return Ok(pos); } } + + diff --git a/libs/neon-shmem/src/hash/tests.rs b/libs/neon-shmem/src/hash/tests.rs index 073aea5220..c207e35a56 100644 --- a/libs/neon-shmem/src/hash/tests.rs +++ b/libs/neon-shmem/src/hash/tests.rs @@ -5,7 +5,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use crate::hash::HashMapAccess; use crate::hash::HashMapInit; -use crate::hash::UpdateAction; +use crate::hash::Entry; use crate::shmem::ShmemHandle; use rand::seq::SliceRandom; @@ -35,20 +35,28 @@ impl<'a> From<&'a [u8]> for TestKey { } } -fn test_inserts + Copy>(keys: &[K]) { +fn test_inserts + Copy>(keys: &[K]) { const MAX_MEM_SIZE: usize = 10000000; let shmem = ShmemHandle::new("test_inserts", 0, MAX_MEM_SIZE).unwrap(); let init_struct = HashMapInit::::init_in_shmem(100000, shmem); - let w = init_struct.attach_writer(); + let mut w = init_struct.attach_writer(); for (idx, k) in keys.iter().enumerate() { - let res = w.insert(&(*k).into(), idx); - assert!(res.is_ok()); + let hash = w.get_hash_value(&(*k).into()); + let res = w.entry_with_hash((*k).into(), hash); + match res { + Entry::Occupied(mut e) => { e.insert(idx); } + Entry::Vacant(e) => { + let res = e.insert(idx); + assert!(res.is_ok()); + }, + }; } for (idx, k) in keys.iter().enumerate() { - let x = w.get(&(*k).into()); + let hash = w.get_hash_value(&(*k).into()); + let x = w.get_with_hash(&(*k).into(), hash); let value = x.as_deref().copied(); assert_eq!(value, Some(idx)); } @@ -121,7 +129,7 @@ struct TestOp(TestKey, Option); fn apply_op( op: &TestOp, - sut: &HashMapAccess, + map: &mut HashMapAccess, shadow: &mut BTreeMap, ) { eprintln!("applying op: {op:?}"); @@ -133,21 +141,24 @@ fn apply_op( shadow.remove(&op.0) }; - // apply to Art tree - sut.update_with_fn(&op.0, |existing| { - assert_eq!(existing.map(TestValue::load), shadow_existing); + let hash = map.get_hash_value(&op.0); + let entry = map.entry_with_hash(op.0, hash); + let hash_existing = match op.1 { + Some(new) => { + match entry { + Entry::Occupied(mut e) => Some(e.insert(new)), + Entry::Vacant(e) => { e.insert(new).unwrap(); None }, + } + }, + None => { + match entry { + Entry::Occupied(e) => Some(e.remove()), + Entry::Vacant(_) => None, + } + }, + }; - match (existing, op.1) { - (None, None) => UpdateAction::Nothing, - (None, Some(new_val)) => UpdateAction::Insert(TestValue::new(new_val)), - (Some(_old_val), None) => UpdateAction::Remove, - (Some(old_val), Some(new_val)) => { - old_val.0.store(new_val, Ordering::Relaxed); - UpdateAction::Nothing - } - } - }) - .expect("out of memory"); + assert_eq!(shadow_existing, hash_existing); } #[test] @@ -155,8 +166,8 @@ fn random_ops() { const MAX_MEM_SIZE: usize = 10000000; let shmem = ShmemHandle::new("test_inserts", 0, MAX_MEM_SIZE).unwrap(); - let init_struct = HashMapInit::::init_in_shmem(100000, shmem); - let writer = init_struct.attach_writer(); + let init_struct = HashMapInit::::init_in_shmem(100000, shmem); + let mut writer = init_struct.attach_writer(); let mut shadow: std::collections::BTreeMap = BTreeMap::new(); @@ -167,7 +178,7 @@ fn random_ops() { let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None }); - apply_op(&op, &writer, &mut shadow); + apply_op(&op, &mut writer, &mut shadow); if i % 1000 == 0 { eprintln!("{i} ops processed"); @@ -182,8 +193,8 @@ fn test_grow() { const MEM_SIZE: usize = 10000000; let shmem = ShmemHandle::new("test_grow", 0, MEM_SIZE).unwrap(); - let init_struct = HashMapInit::::init_in_shmem(1000, shmem); - let writer = init_struct.attach_writer(); + let init_struct = HashMapInit::::init_in_shmem(1000, shmem); + let mut writer = init_struct.attach_writer(); let mut shadow: std::collections::BTreeMap = BTreeMap::new(); @@ -193,7 +204,7 @@ fn test_grow() { let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None }); - apply_op(&op, &writer, &mut shadow); + apply_op(&op, &mut writer, &mut shadow); if i % 1000 == 0 { eprintln!("{i} ops processed"); @@ -209,7 +220,7 @@ fn test_grow() { let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None }); - apply_op(&op, &writer, &mut shadow); + apply_op(&op, &mut writer, &mut shadow); if i % 1000 == 0 { eprintln!("{i} ops processed"); @@ -218,3 +229,49 @@ fn test_grow() { } } } + + +#[test] +fn test_shrink() { + const MEM_SIZE: usize = 10000000; + let shmem = ShmemHandle::new("test_shrink", 0, MEM_SIZE).unwrap(); + + let init_struct = HashMapInit::::init_in_shmem(1500, shmem); + let mut writer = init_struct.attach_writer(); + + let mut shadow: std::collections::BTreeMap = BTreeMap::new(); + + let mut rng = rand::rng(); + for i in 0..100 { + let key: TestKey = ((rng.next_u32() % 1500) as u128).into(); + + let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None }); + + apply_op(&op, &mut writer, &mut shadow); + + if i % 1000 == 0 { + eprintln!("{i} ops processed"); + } + } + + writer.begin_shrink(1000); + for i in 1000..1500 { + if let Some(entry) = writer.entry_at_bucket(i) { + shadow.remove(&entry._key); + entry.remove(); + } + } + writer.finish_shrink().unwrap(); + + for i in 0..10000 { + let key: TestKey = ((rng.next_u32() % 1000) as u128).into(); + + let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None }); + + apply_op(&op, &mut writer, &mut shadow); + + if i % 1000 == 0 { + eprintln!("{i} ops processed"); + } + } +}