diff --git a/libs/neon-shmem/src/hash.rs b/libs/neon-shmem/src/hash.rs
index 45e593fc48..f173f42a6d 100644
--- a/libs/neon-shmem/src/hash.rs
+++ b/libs/neon-shmem/src/hash.rs
@@ -131,7 +131,7 @@ where
         }
 
         HashMapInit {
-            shmem_handle: shmem_handle,
+            shmem_handle,
             shared_ptr,
         }
     }
@@ -211,7 +211,7 @@ where
     }
 
     /// Helper function that abstracts the common logic between growing and shrinking.
-    /// The only significant difference in the rehashing step is how many buckets to rehash!
+    /// The only significant difference in the rehashing step is how many buckets to rehash.
     fn rehash_dict(
         &mut self,
         inner: &mut CoreHashMap<'a, K, V>,
@@ -310,7 +310,8 @@ where
         Ok(())
     }
 
-    fn begin_shrink(&mut self, num_buckets: u32) {
+    /// Begin a shrink, limiting all new allocations to be in buckets with index less than `num_buckets`.
+    pub fn begin_shrink(&mut self, num_buckets: u32) {
         let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
         if num_buckets > map.inner.get_num_buckets() as u32 {
             panic!("shrink called with a larger number of buckets");
@@ -322,7 +323,8 @@ where
         map.inner.alloc_limit = num_buckets;
     }
 
-    fn finish_shrink(&mut self) -> Result<(), crate::shmem::Error> {
+    /// Complete a shrink after the caller has evicted entries, removing the unused buckets and rehashing.
+    pub fn finish_shrink(&mut self) -> Result<(), crate::shmem::Error> {
         let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
         let inner = &mut map.inner;
         if !inner.is_shrinking() {
diff --git a/libs/neon-shmem/src/hash/core.rs b/libs/neon-shmem/src/hash/core.rs
index d02d234a87..80f41fd8d4 100644
--- a/libs/neon-shmem/src/hash/core.rs
+++ b/libs/neon-shmem/src/hash/core.rs
@@ -159,7 +159,7 @@ where
     pub fn is_shrinking(&self) -> bool {
         self.alloc_limit != INVALID_POS
     }
-    
+
     pub fn entry_at_bucket(&mut self, pos: usize) -> Option<OccupiedEntry<'_, 'a, K, V>> {
         if pos >= self.buckets.len() {
             return None;
@@ -167,22 +167,22 @@ where
         let prev = self.buckets[pos].prev;
         let entry = self.buckets[pos].inner.as_ref();
 
-        if entry.is_none() {
-            return None;
+        match entry {
+            Some((key, _)) => Some(OccupiedEntry {
+                _key: key.clone(),
+                bucket_pos: pos as u32,
+                prev_pos: prev,
+                map: self,
+            }),
+            _ => None,
         }
-
-        let (key, _) = entry.unwrap();
-        Some(OccupiedEntry {
-            _key: key.clone(), // TODO(quantumish): clone unavoidable?
-            bucket_pos: pos as u32,
-            map: self,
-            prev_pos: prev,
-        })
     }
-
-    pub(crate) fn alloc_bucket(&mut self, key: K, value: V) -> Result<u32, FullError> {
+
+    /// Find the position of an unused bucket via the freelist and initialize it.
+    pub(crate) fn alloc_bucket(&mut self, key: K, value: V, dict_pos: u32) -> Result<u32, FullError> {
         let mut pos = self.free_head;
+        // Find the first bucket we're *allowed* to use.
         let mut prev = PrevPos::First(self.free_head);
         while pos != INVALID_POS && pos >= self.alloc_limit {
             let bucket = &mut self.buckets[pos as usize];
             prev = PrevPos::Chained(pos);
@@ -192,15 +192,14 @@ where
         if pos == INVALID_POS {
             return Err(FullError());
         }
+
+        // Repair the freelist.
         match prev {
             PrevPos::First(_) => {
                 let next_pos = self.buckets[pos as usize].next;
                 self.free_head = next_pos;
-                // HACK(quantumish): Really, the INVALID_POS should be the position within the dictionary.
-                // This isn't passed into this function, though, and so for now rather than changing that
-                // we can just check it from `alloc_bucket`. Not a great solution.
                 if next_pos != INVALID_POS {
-                    self.buckets[next_pos as usize].prev = PrevPos::First(INVALID_POS);
+                    self.buckets[next_pos as usize].prev = PrevPos::First(dict_pos);
                 }
             }
             PrevPos::Chained(p) => if p != INVALID_POS {
@@ -211,7 +210,8 @@ where
                 }
             },
         }
-
+
+        // Initialize the bucket.
         let bucket = &mut self.buckets[pos as usize];
         self.buckets_in_use += 1;
         bucket.next = INVALID_POS;
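
Reviewer note: with `begin_shrink` and `finish_shrink` now public, the intended calling sequence is `begin_shrink(n)`, then caller-driven eviction of every occupied bucket at index `n` and above, then `finish_shrink()`. The sketch below restates that protocol outside the test harness; it is illustrative only, and the `TestKey`/`usize` type parameters and the `shrink_map` name are assumptions, not part of this patch.

```rust
// Minimal sketch of the two-phase shrink protocol (illustrative names/types).
fn shrink_map(writer: &mut HashMapAccess<TestKey, usize>, old_len: u32, new_len: u32) {
    // Phase 1: cap the allocator so new entries only land in buckets [0, new_len).
    writer.begin_shrink(new_len);

    // Phase 2: the caller decides what to do with entries above the limit;
    // here we simply evict them, as do_shrink does in the tests.
    for i in new_len..old_len {
        if let Some(entry) = writer.entry_at_bucket(i as usize) {
            entry.remove();
        }
    }

    // Phase 3: drop the now-unused buckets and rehash the survivors.
    writer.finish_shrink().unwrap();
}
```
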
diff --git a/libs/neon-shmem/src/hash/entry.rs b/libs/neon-shmem/src/hash/entry.rs
index 147a464745..64820b3d7b 100644
--- a/libs/neon-shmem/src/hash/entry.rs
+++ b/libs/neon-shmem/src/hash/entry.rs
@@ -82,7 +82,7 @@ pub struct VacantEntry<'a, 'b, K, V> {
 
 impl<'a, 'b, K: Clone + Hash + Eq, V> VacantEntry<'a, 'b, K, V> {
     pub fn insert(self, value: V) -> Result<&'b mut V, FullError> {
-        let pos = self.map.alloc_bucket(self.key, value)?;
+        let pos = self.map.alloc_bucket(self.key, value, self.dict_pos)?;
         if pos == INVALID_POS {
             return Err(FullError());
         }
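
This one-line change is the caller half of the `alloc_bucket` fix above: `VacantEntry` already knows which dictionary slot the key hashed to, so threading `self.dict_pos` through lets the freelist repair record a real back-pointer where it previously wrote `PrevPos::First(INVALID_POS)`, which is what the removed HACK comment asked for. A paraphrased view of the linkage follows; the simplified definitions are assumptions, not the real ones in core.rs.

```rust
// Simplified, illustrative view of how a bucket records what points at it.
enum PrevPos {
    First(u32),   // head of a chain: the payload names a dictionary slot
    Chained(u32), // middle of a chain: the payload names the previous bucket
}

// When alloc_bucket pops the head of the freelist, the next free bucket
// becomes the new head. With dict_pos threaded down from VacantEntry::insert,
// its back-pointer can name the dictionary slot involved in the insertion:
//
//     self.buckets[next_pos as usize].prev = PrevPos::First(dict_pos);
//
// rather than the placeholder PrevPos::First(INVALID_POS).
```
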
diff --git a/libs/neon-shmem/src/hash/tests.rs b/libs/neon-shmem/src/hash/tests.rs
index b95c33e578..ee6acfc144 100644
--- a/libs/neon-shmem/src/hash/tests.rs
+++ b/libs/neon-shmem/src/hash/tests.rs
@@ -62,8 +62,6 @@ fn test_inserts<K: Into<TestKey> + Copy>(keys: &[K]) {
         let value = x.as_deref().copied();
         assert_eq!(value, Some(idx));
     }
-
-    //eprintln!("stats: {:?}", tree_writer.get_statistics());
 }
 
 #[test]
@@ -188,6 +186,23 @@ fn do_deletes(
     }
 }
 
+fn do_shrink(
+    writer: &mut HashMapAccess<TestKey, usize>,
+    shadow: &mut BTreeMap<TestKey, usize>,
+    from: u32,
+    to: u32
+) {
+    writer.begin_shrink(to);
+    for i in to..from {
+        if let Some(entry) = writer.entry_at_bucket(i as usize) {
+            shadow.remove(&entry._key);
+            entry.remove();
+        }
+    }
+    writer.finish_shrink().unwrap();
+
+}
+
 #[test]
 fn random_ops() {
     let shmem = ShmemHandle::new("test_inserts", 0, 10000000).unwrap();
@@ -208,8 +223,6 @@
 
         if i % 1000 == 0 {
            eprintln!("{i} ops processed");
-            //eprintln!("stats: {:?}", tree_writer.get_statistics());
-            //test_iter(&tree_writer, &shadow);
        }
    }
 }
@@ -227,23 +240,6 @@
     do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
 }
 
-fn do_shrink(
-    writer: &mut HashMapAccess<TestKey, usize>,
-    shadow: &mut BTreeMap<TestKey, usize>,
-    from: u32,
-    to: u32
-) {
-    writer.begin_shrink(to);
-    for i in to..from {
-        if let Some(entry) = writer.entry_at_bucket(i as usize) {
-            shadow.remove(&entry._key);
-            entry.remove();
-        }
-    }
-    writer.finish_shrink().unwrap();
-
-}
-
 #[test]
 fn test_shrink() {
     let shmem = ShmemHandle::new("test_shrink", 0, 10000000).unwrap();
@@ -261,7 +257,7 @@
 
 #[test]
 fn test_shrink_grow_seq() {
-    let shmem = ShmemHandle::new("test_shrink", 0, 10000000).unwrap();
+    let shmem = ShmemHandle::new("test_shrink_grow_seq", 0, 10000000).unwrap();
     let init_struct = HashMapInit::<TestKey, usize>::init_in_shmem(1500, shmem);
     let mut writer = init_struct.attach_writer();
     let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
@@ -283,11 +279,73 @@
     do_random_ops(10000, 5000, 0.25, &mut writer, &mut shadow, &mut rng);
 }
 
+#[test]
+fn test_bucket_ops() {
+    let shmem = ShmemHandle::new("test_bucket_ops", 0, 10000000).unwrap();
+    let init_struct = HashMapInit::<TestKey, usize>::init_in_shmem(1000, shmem);
+    let mut writer = init_struct.attach_writer();
+    let hash = writer.get_hash_value(&1.into());
+    match writer.entry_with_hash(1.into(), hash) {
+        Entry::Occupied(mut e) => { e.insert(2); },
+        Entry::Vacant(e) => { e.insert(2).unwrap(); },
+    }
+    assert_eq!(writer.get_num_buckets_in_use(), 1);
+    assert_eq!(writer.get_num_buckets(), 1000);
+    assert_eq!(writer.get_with_hash(&1.into(), hash), Some(&2));
+    match writer.entry_with_hash(1.into(), hash) {
+        Entry::Occupied(e) => {
+            assert_eq!(e._key, 1.into());
+            let pos = e.bucket_pos as usize;
+            assert_eq!(writer.entry_at_bucket(pos).unwrap()._key, 1.into());
+            assert_eq!(writer.get_at_bucket(pos), Some(&(1.into(), 2)));
+        },
+        Entry::Vacant(_) => { panic!("Insert didn't affect entry"); },
+    }
+    writer.remove_with_hash(&1.into(), hash);
+    assert_eq!(writer.get_with_hash(&1.into(), hash), None);
+}
+
+#[test]
+fn test_shrink_zero() {
+    let shmem = ShmemHandle::new("test_shrink_zero", 0, 10000000).unwrap();
+    let init_struct = HashMapInit::<TestKey, usize>::init_in_shmem(1500, shmem);
+    let mut writer = init_struct.attach_writer();
+    writer.begin_shrink(0);
+    for i in 0..1500 {
+        writer.entry_at_bucket(i).map(|x| x.remove());
+    }
+    writer.finish_shrink().unwrap();
+    assert_eq!(writer.get_num_buckets_in_use(), 0);
+    let hash = writer.get_hash_value(&1.into());
+    let entry = writer.entry_with_hash(1.into(), hash);
+    if let Entry::Vacant(v) = entry {
+        assert!(v.insert(2).is_err());
+    } else {
+        panic!("Somehow got non-vacant entry in empty map.")
+    }
+    writer.grow(50).unwrap();
+    let entry = writer.entry_with_hash(1.into(), hash);
+    if let Entry::Vacant(v) = entry {
+        assert!(v.insert(2).is_ok());
+    } else {
+        panic!("Somehow got non-vacant entry in empty map.")
+    }
+    assert_eq!(writer.get_num_buckets_in_use(), 1);
+}
+
+#[test]
+#[should_panic]
+fn test_grow_oom() {
+    let shmem = ShmemHandle::new("test_grow_oom", 0, 500).unwrap();
+    let init_struct = HashMapInit::<TestKey, usize>::init_in_shmem(5, shmem);
+    let mut writer = init_struct.attach_writer();
+    writer.grow(20000).unwrap();
+}
 #[test]
 #[should_panic]
 fn test_shrink_bigger() {
-    let shmem = ShmemHandle::new("test_shrink", 0, 10000000).unwrap();
+    let shmem = ShmemHandle::new("test_shrink_bigger", 0, 10000000).unwrap();
     let init_struct = HashMapInit::<TestKey, usize>::init_in_shmem(1500, shmem);
     let mut writer = init_struct.attach_writer();
     writer.begin_shrink(2000);
 }
@@ -296,7 +354,7 @@
 #[test]
 #[should_panic]
 fn test_shrink_early_finish() {
-    let shmem = ShmemHandle::new("test_shrink", 0, 10000000).unwrap();
+    let shmem = ShmemHandle::new("test_shrink_early_finish", 0, 10000000).unwrap();
     let init_struct = HashMapInit::<TestKey, usize>::init_in_shmem(1500, shmem);
     let mut writer = init_struct.attach_writer();
     writer.finish_shrink().unwrap();
@@ -310,3 +368,4 @@ fn test_shrink_fixed_size() {
     let mut writer = init_struct.attach_writer();
     writer.begin_shrink(1);
 }
+
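
Taken together, the new tests walk the whole writer-side lifecycle: init in shared memory, entry-based insertion, bucket introspection, grow, and two-phase shrink. A condensed sketch stringing the same calls together, under the same assumptions as above (the `TestKey`/`usize` parameters and the function name are illustrative):

```rust
// Condensed writer lifecycle, mirroring test_bucket_ops and do_shrink.
fn lifecycle_sketch() {
    // Shared memory region and a 1000-bucket map, as in test_bucket_ops.
    let shmem = ShmemHandle::new("lifecycle_sketch", 0, 10000000).unwrap();
    let init_struct = HashMapInit::<TestKey, usize>::init_in_shmem(1000, shmem);
    let mut writer = init_struct.attach_writer();

    // Insert through the entry API.
    let hash = writer.get_hash_value(&1.into());
    match writer.entry_with_hash(1.into(), hash) {
        Entry::Occupied(mut e) => { e.insert(2); },
        Entry::Vacant(e) => { e.insert(2).unwrap(); },
    }
    assert_eq!(writer.get_with_hash(&1.into(), hash), Some(&2));

    // Grow, then shrink back down with the two-phase protocol.
    writer.grow(2000).unwrap();
    writer.begin_shrink(500);
    for i in 500..2000 {
        if let Some(entry) = writer.entry_at_bucket(i) {
            entry.remove();
        }
    }
    writer.finish_shrink().unwrap();
}
```
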