diff --git a/libs/neon-shmem/src/hash.rs b/libs/neon-shmem/src/hash.rs index 364787e2b7..efe7c96cb3 100644 --- a/libs/neon-shmem/src/hash.rs +++ b/libs/neon-shmem/src/hash.rs @@ -282,7 +282,7 @@ where inner.buckets = buckets; } - /// Rehash the map. Intended for benchmarking only. + /// Rehash the map. Intended for benchmarking only. pub fn shuffle(&mut self) { let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); let inner = &mut map.inner; @@ -352,7 +352,7 @@ where Ok(()) } - /// Begin a shrink, limiting all new allocations to be in buckets with index less than `num_buckets`. + /// Begin a shrink, limiting all new allocations to be in buckets with index less than `num_buckets`. pub fn begin_shrink(&mut self, num_buckets: u32) { let map = unsafe { self.shared_ptr.as_mut() }.unwrap(); if num_buckets > map.inner.get_num_buckets() as u32 { @@ -377,32 +377,28 @@ where if inner.get_num_buckets() == num_buckets as usize { return Ok(()); - } + } else if inner.get_num_buckets() > num_buckets as usize { + panic!("called finish_shrink before enough entries were removed"); + } for i in (num_buckets as usize)..inner.buckets.len() { - if inner.buckets[i].inner.is_some() { - // TODO(quantumish) Do we want to treat this as a violation of an invariant - // or a legitimate error the caller can run into? Originally I thought this - // could return something like a UnevictedError(index) as soon as it runs - // into something (that way a caller could clear their soon-to-be-shrinked - // buckets by repeatedly trying to call `finish_shrink`). - // - // Would require making a wider error type enum with this and shmem errors. - panic!("unevicted entries in shrinked space") - } - match inner.buckets[i].prev { - PrevPos::First(_) => { - let next_pos = inner.buckets[i].next; - inner.free_head = next_pos; - if next_pos != INVALID_POS { - inner.buckets[next_pos as usize].prev = PrevPos::First(INVALID_POS); - } - }, - PrevPos::Chained(j) => { - let next_pos = inner.buckets[i].next; - inner.buckets[j as usize].next = next_pos; - if next_pos != INVALID_POS { - inner.buckets[next_pos as usize].prev = PrevPos::Chained(j); + if let Some((k, v)) = inner.buckets[i].inner.take() { + inner.alloc_bucket(k, v, inner.buckets[i].prev.unwrap_first()).unwrap(); + } else { + match inner.buckets[i].prev { + PrevPos::First(_) => { + let next_pos = inner.buckets[i].next; + inner.free_head = next_pos; + if next_pos != INVALID_POS { + inner.buckets[next_pos as usize].prev = PrevPos::First(INVALID_POS); + } + }, + PrevPos::Chained(j) => { + let next_pos = inner.buckets[i].next; + inner.buckets[j as usize].next = next_pos; + if next_pos != INVALID_POS { + inner.buckets[next_pos as usize].prev = PrevPos::Chained(j); + } } } } @@ -421,6 +417,5 @@ where inner.alloc_limit = INVALID_POS; Ok(()) - } - + } } diff --git a/libs/neon-shmem/src/hash/entry.rs b/libs/neon-shmem/src/hash/entry.rs index 64820b3d7b..7d3091c754 100644 --- a/libs/neon-shmem/src/hash/entry.rs +++ b/libs/neon-shmem/src/hash/entry.rs @@ -16,6 +16,15 @@ pub(crate) enum PrevPos { Chained(u32), } +impl PrevPos { + pub fn unwrap_first(&self) -> u32 { + match self { + Self::First(i) => *i, + _ => panic!("not first entry in chain") + } + } +} + pub struct OccupiedEntry<'a, 'b, K, V> { pub(crate) map: &'b mut CoreHashMap<'a, K, V>, pub(crate) _key: K, // The key of the occupied entry