Implement shrinking, add basic tests for core operations

This commit is contained in:
David Freifeld
2025-06-16 13:13:38 -07:00
parent b6b122e07b
commit ac87544e79
3 changed files with 193 additions and 98 deletions

View File

@@ -19,7 +19,7 @@ pub mod entry;
#[cfg(test)]
mod tests;
use core::CoreHashMap;
use core::{CoreHashMap, INVALID_POS};
use entry::{Entry, OccupiedEntry};
#[derive(Debug)]
@@ -210,6 +210,53 @@ where
map.inner.buckets_in_use as usize
}
/// Helper function that abstracts the common logic between growing and shrinking.
/// The only significant difference in the rehashing step is how many buckets to rehash!
///
/// Rebuilds the bucket and dictionary slices over the (already resized) memory area, resets
/// every dictionary slot to `INVALID_POS`, and then pushes each occupied bucket among the
/// first `rehash_buckets` onto the head of the chain selected by its key's hash.
///
/// * `inner` - map whose `buckets` and `dictionary` slices are replaced.
/// * `buckets_ptr` - start of the bucket array.
/// * `end_ptr` - end of the memory area; the dictionary occupies the space between the end
///   of the bucket array and this pointer.
/// * `num_buckets` - number of buckets in the bucket array.
/// * `rehash_buckets` - number of leading buckets to scan for entries to re-link.
///
/// # Panics
///
/// Panics if `rehash_buckets` exceeds `num_buckets`, or if an occupied bucket is found while
/// the dictionary has no slots.
fn rehash_dict(
    &mut self,
    inner: &mut CoreHashMap<'a, K, V>,
    buckets_ptr: *mut core::Bucket<K, V>,
    end_ptr: *mut u8,
    num_buckets: u32,
    rehash_buckets: u32,
) {
    // Recalculate the dictionary
    let buckets;
    let dictionary;
    // SAFETY: NOTE(review) this relies on the caller having sized the memory area so that
    // `buckets_ptr..end_ptr` holds `num_buckets` initialized buckets followed by the
    // dictionary, with nothing else aliasing that range - confirm at each call site.
    unsafe {
        let buckets_end_ptr: *mut u8 = buckets_ptr.add(num_buckets as usize).cast();
        // `align_offset` counts in elements of the pointee type, so the padding must be
        // computed on a byte pointer for it to be a byte count.
        let padding = buckets_end_ptr.align_offset(align_of::<u32>());
        let dictionary_ptr: *mut u32 = buckets_end_ptr.add(padding).cast();
        // Measure from the aligned start so the slice can never extend past `end_ptr`.
        let dictionary_size: usize =
            end_ptr.byte_offset_from(dictionary_ptr) as usize / size_of::<u32>();
        buckets = std::slice::from_raw_parts_mut(buckets_ptr, num_buckets as usize);
        dictionary = std::slice::from_raw_parts_mut(dictionary_ptr, dictionary_size);
    }

    // Start from an empty dictionary: every chain head is "no bucket".
    dictionary.fill(INVALID_POS);

    // NOTE(review): only `next` is re-linked here; `prev` is left untouched - confirm that
    // `prev` is not expected to mirror the hash chains.
    for (i, bucket) in buckets[..rehash_buckets as usize].iter_mut().enumerate() {
        let Some((key, _)) = bucket.inner.as_ref() else {
            continue;
        };
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        let hash = hasher.finish();
        let pos: usize = (hash % dictionary.len() as u64) as usize;
        // Push the bucket onto the head of its chain. `i` is below `rehash_buckets`, which
        // is a `u32`, so the cast cannot truncate.
        bucket.next = dictionary[pos];
        dictionary[pos] = i as u32;
    }

    // Finally, update the CoreHashMap struct
    inner.dictionary = dictionary;
    inner.buckets = buckets;
}
/// Grow
///
/// 1. grow the underlying shared memory area
@@ -247,46 +294,17 @@ where
} else {
inner.free_head
},
prev: if i > 0 {
i as u32 - 1
} else {
INVALID_POS
},
inner: None,
});
}
}
// Recalculate the dictionary
let buckets;
let dictionary;
unsafe {
let buckets_end_ptr = buckets_ptr.add(num_buckets as usize);
let dictionary_ptr: *mut u32 = buckets_end_ptr
.byte_add(buckets_end_ptr.align_offset(align_of::<u32>()))
.cast();
let dictionary_size: usize =
end_ptr.byte_offset_from(buckets_end_ptr) as usize / size_of::<u32>();
buckets = std::slice::from_raw_parts_mut(buckets_ptr, num_buckets as usize);
dictionary = std::slice::from_raw_parts_mut(dictionary_ptr, dictionary_size);
}
for i in 0..dictionary.len() {
dictionary[i] = core::INVALID_POS;
}
for i in 0..old_num_buckets as usize {
if buckets[i].inner.is_none() {
continue;
}
let mut hasher = DefaultHasher::new();
buckets[i].inner.as_ref().unwrap().0.hash(&mut hasher);
let hash = hasher.finish();
let pos: usize = (hash % dictionary.len() as u64) as usize;
buckets[i].next = dictionary[pos];
dictionary[pos] = i as u32;
}
// Finally, update the CoreHashMap struct
inner.dictionary = dictionary;
inner.buckets = buckets;
self.rehash_dict(inner, buckets_ptr, end_ptr, num_buckets, old_num_buckets);
inner.free_head = old_num_buckets;
Ok(())
@@ -294,7 +312,7 @@ where
fn begin_shrink(&mut self, num_buckets: u32) {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
if num_buckets < map.inner.get_num_buckets() as u32 {
if num_buckets > map.inner.get_num_buckets() as u32 {
panic!("shrink called with a larger number of buckets");
}
map.inner.alloc_limit = num_buckets;
@@ -307,14 +325,14 @@ where
panic!("called finish_shrink when no shrink is in progress");
}
let new_num_buckets = inner.alloc_limit;
let num_buckets = inner.alloc_limit;
if inner.get_num_buckets() == new_num_buckets as usize {
if inner.get_num_buckets() == num_buckets as usize {
return Ok(());
}
for b in &inner.buckets[new_num_buckets as usize..] {
if b.inner.is_some() {
for i in (num_buckets as usize)..inner.buckets.len() {
if inner.buckets[i].inner.is_some() {
// TODO(quantumish) Do we want to treat this as a violation of an invariant
// or a legitimate error the caller can run into? Originally I thought this
// could return something like a UnevictedError(index) as soon as it runs
@@ -324,6 +342,10 @@ where
// Would require making a wider error type enum with this and shmem errors.
panic!("unevicted entries in shrinked space")
}
let prev_pos = inner.buckets[i].prev;
if prev_pos != INVALID_POS {
inner.buckets[prev_pos as usize].next = inner.buckets[i].next;
}
}
let shmem_handle = self
@@ -331,22 +353,13 @@ where
.as_ref()
.expect("shrink called on a fixed-size hash table");
let size_bytes = HashMapInit::<K, V>::estimate_size(new_num_buckets);
let size_bytes = HashMapInit::<K, V>::estimate_size(num_buckets);
shmem_handle.set_size(size_bytes)?;
let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
let buckets_ptr = inner.buckets.as_mut_ptr();
self.rehash_dict(inner, buckets_ptr, end_ptr, num_buckets, num_buckets);
Ok(())
}
// TODO: Shrinking is a multi-step process that requires co-operation from the caller
//
// 1. The caller must first call begin_shrink(). That forbids allocation of higher-numbered
// buckets.
//
// 2. Next, the caller must evict all entries in higher-numbered buckets.
//
// 3. Finally, call finish_shrink(). This recomputes the dictionary and shrinks the underlying
// shmem area
}

View File

@@ -13,6 +13,7 @@ pub(crate) const INVALID_POS: u32 = u32::MAX;
/// A single slot of the hash table's bucket array.
pub(crate) struct Bucket<K, V> {
/// Index of the next bucket in whichever list this bucket is currently linked into,
/// or `INVALID_POS` at the end of that list.
pub(crate) next: u32,
/// Index of the preceding bucket, or `INVALID_POS` when there is none.
/// NOTE(review): appears to be maintained for the free list only - confirm.
pub(crate) prev: u32,
/// The stored `(key, value)` pair; `None` while the bucket is unoccupied.
pub(crate) inner: Option<(K, V)>,
}
@@ -22,6 +23,7 @@ pub(crate) struct CoreHashMap<'a, K, V> {
pub(crate) free_head: u32,
pub(crate) _user_list_head: u32,
/// Maximum index of a bucket allowed to be allocated. INVALID_POS if no limit.
pub(crate) alloc_limit: u32,
// metrics
@@ -62,6 +64,11 @@ where
} else {
INVALID_POS
},
prev: if i > 0 {
i as u32 - 1
} else {
INVALID_POS
},
inner: None,
});
}
@@ -153,45 +160,61 @@ where
self.alloc_limit != INVALID_POS
}
pub fn entry_at_bucket(&mut self, pos: usize) -> Option<OccupiedEntry<K, V>> {
if pos >= self.buckets.len() {
return None;
}
pub fn entry_at_bucket(&mut self, pos: usize) -> Option<OccupiedEntry<'a, '_, K, V>> {
if pos >= self.buckets.len() {
return None;
}
let prev = self.buckets[pos].prev;
let entry = self.buckets[pos].inner.as_ref();
if entry.is_none() {
return None;
}
let (key, _) = entry.unwrap();
}
let (key, _) = entry.unwrap();
Some(OccupiedEntry {
_key: key.clone(), // TODO(quantumish): clone unavoidable?
bucket_pos: pos as u32,
map: self,
prev_pos: todo!(), // TODO(quantumish): possibly needs O(n) traversals to rediscover - costly!
prev_pos: if prev == INVALID_POS {
// TODO(quantumish): populating this correctly would require an O(n) scan over the dictionary
// (perhaps not if we refactored the prev field to be itself something like PrevPos). The real
// question though is whether this even needs to be populated correctly? All downstream uses of
// this function so far are just for deletion, which isn't really concerned with the dictionary.
// Then again, it's unintuitive to appear to return a normal OccupiedEntry which really is fake.
PrevPos::First(todo!("unclear what to do here"))
} else {
PrevPos::Chained(prev)
}
})
}
pub(crate) fn alloc_bucket(&mut self, key: K, value: V) -> Result<u32, FullError> {
let mut pos = self.free_head;
let mut pos = self.free_head;
// TODO(quantumish): relies on INVALID_POS being u32::MAX by default!
// instead add a clause `pos != INVALID_POS`?
let mut prev = PrevPos::First(self.free_head);
while pos < self.alloc_limit {
if pos == INVALID_POS {
return Err(FullError());
}
while pos!= INVALID_POS && pos >= self.alloc_limit {
let bucket = &mut self.buckets[pos as usize];
prev = PrevPos::Chained(pos);
pos = bucket.next;
}
let bucket = &mut self.buckets[pos as usize];
if pos == INVALID_POS {
return Err(FullError());
}
match prev {
PrevPos::First(_) => self.free_head = bucket.next,
PrevPos::Chained(p) => self.buckets[p].next = bucket.next,
PrevPos::First(_) => {
let next_pos = self.buckets[pos as usize].next;
self.free_head = next_pos;
self.buckets[next_pos as usize].prev = INVALID_POS;
}
PrevPos::Chained(p) => if p != INVALID_POS {
let next_pos = self.buckets[pos as usize].next;
self.buckets[p as usize].next = next_pos;
self.buckets[next_pos as usize].prev = p;
},
}
let bucket = &mut self.buckets[pos as usize];
self.buckets_in_use += 1;
bucket.next = INVALID_POS;
bucket.inner = Some((key, value));
@@ -199,3 +222,5 @@ where
return Ok(pos);
}
}

View File

@@ -5,7 +5,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};
use crate::hash::HashMapAccess;
use crate::hash::HashMapInit;
use crate::hash::UpdateAction;
use crate::hash::Entry;
use crate::shmem::ShmemHandle;
use rand::seq::SliceRandom;
@@ -35,20 +35,28 @@ impl<'a> From<&'a [u8]> for TestKey {
}
}
fn test_inserts<K: Into<TestKey> + Copy>(keys: &[K]) {
fn test_inserts<K: Into<TestKey> + Copy>(keys: &[K]) {
const MAX_MEM_SIZE: usize = 10000000;
let shmem = ShmemHandle::new("test_inserts", 0, MAX_MEM_SIZE).unwrap();
let init_struct = HashMapInit::<TestKey, usize>::init_in_shmem(100000, shmem);
let w = init_struct.attach_writer();
let mut w = init_struct.attach_writer();
for (idx, k) in keys.iter().enumerate() {
let res = w.insert(&(*k).into(), idx);
assert!(res.is_ok());
let hash = w.get_hash_value(&(*k).into());
let res = w.entry_with_hash((*k).into(), hash);
match res {
Entry::Occupied(mut e) => { e.insert(idx); }
Entry::Vacant(e) => {
let res = e.insert(idx);
assert!(res.is_ok());
},
};
}
for (idx, k) in keys.iter().enumerate() {
let x = w.get(&(*k).into());
let hash = w.get_hash_value(&(*k).into());
let x = w.get_with_hash(&(*k).into(), hash);
let value = x.as_deref().copied();
assert_eq!(value, Some(idx));
}
@@ -121,7 +129,7 @@ struct TestOp(TestKey, Option<usize>);
fn apply_op(
op: &TestOp,
sut: &HashMapAccess<TestKey, TestValue>,
map: &mut HashMapAccess<TestKey, usize>,
shadow: &mut BTreeMap<TestKey, usize>,
) {
eprintln!("applying op: {op:?}");
@@ -133,21 +141,24 @@ fn apply_op(
shadow.remove(&op.0)
};
// apply to Art tree
sut.update_with_fn(&op.0, |existing| {
assert_eq!(existing.map(TestValue::load), shadow_existing);
let hash = map.get_hash_value(&op.0);
let entry = map.entry_with_hash(op.0, hash);
let hash_existing = match op.1 {
Some(new) => {
match entry {
Entry::Occupied(mut e) => Some(e.insert(new)),
Entry::Vacant(e) => { e.insert(new).unwrap(); None },
}
},
None => {
match entry {
Entry::Occupied(e) => Some(e.remove()),
Entry::Vacant(_) => None,
}
},
};
match (existing, op.1) {
(None, None) => UpdateAction::Nothing,
(None, Some(new_val)) => UpdateAction::Insert(TestValue::new(new_val)),
(Some(_old_val), None) => UpdateAction::Remove,
(Some(old_val), Some(new_val)) => {
old_val.0.store(new_val, Ordering::Relaxed);
UpdateAction::Nothing
}
}
})
.expect("out of memory");
assert_eq!(shadow_existing, hash_existing);
}
#[test]
@@ -155,8 +166,8 @@ fn random_ops() {
const MAX_MEM_SIZE: usize = 10000000;
let shmem = ShmemHandle::new("test_inserts", 0, MAX_MEM_SIZE).unwrap();
let init_struct = HashMapInit::<TestKey, TestValue>::init_in_shmem(100000, shmem);
let writer = init_struct.attach_writer();
let init_struct = HashMapInit::<TestKey, usize>::init_in_shmem(100000, shmem);
let mut writer = init_struct.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
@@ -167,7 +178,7 @@ fn random_ops() {
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &writer, &mut shadow);
apply_op(&op, &mut writer, &mut shadow);
if i % 1000 == 0 {
eprintln!("{i} ops processed");
@@ -182,8 +193,8 @@ fn test_grow() {
const MEM_SIZE: usize = 10000000;
let shmem = ShmemHandle::new("test_grow", 0, MEM_SIZE).unwrap();
let init_struct = HashMapInit::<TestKey, TestValue>::init_in_shmem(1000, shmem);
let writer = init_struct.attach_writer();
let init_struct = HashMapInit::<TestKey, usize>::init_in_shmem(1000, shmem);
let mut writer = init_struct.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
@@ -193,7 +204,7 @@ fn test_grow() {
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &writer, &mut shadow);
apply_op(&op, &mut writer, &mut shadow);
if i % 1000 == 0 {
eprintln!("{i} ops processed");
@@ -209,7 +220,7 @@ fn test_grow() {
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &writer, &mut shadow);
apply_op(&op, &mut writer, &mut shadow);
if i % 1000 == 0 {
eprintln!("{i} ops processed");
@@ -218,3 +229,49 @@ fn test_grow() {
}
}
}
/// Applies random operations to a 1500-bucket map, shrinks it to 1000 buckets (removing any
/// entry found in the truncated bucket range from both the map and the shadow), and then keeps
/// applying random operations, with `apply_op` comparing the map against the shadow `BTreeMap`.
#[test]
fn test_shrink() {
    const MEM_SIZE: usize = 10000000;

    let handle = ShmemHandle::new("test_shrink", 0, MEM_SIZE).unwrap();
    let init = HashMapInit::<TestKey, usize>::init_in_shmem(1500, handle);
    let mut writer = init.attach_writer();
    let mut shadow = BTreeMap::<TestKey, usize>::new();
    let mut rng = rand::rng();

    // Phase 1: populate the full-size table.
    for i in 0..100 {
        let key: TestKey = u128::from(rng.next_u32() % 1500).into();
        let value = rng.random_bool(0.75).then_some(i);
        apply_op(&TestOp(key, value), &mut writer, &mut shadow);
        if i % 1000 == 0 {
            eprintln!("{i} ops processed");
        }
    }

    // Phase 2: shrink, evicting everything that lives in the buckets about to disappear.
    writer.begin_shrink(1000);
    for bucket_pos in 1000..1500 {
        let Some(entry) = writer.entry_at_bucket(bucket_pos) else {
            continue;
        };
        shadow.remove(&entry._key);
        entry.remove();
    }
    writer.finish_shrink().unwrap();

    // Phase 3: the shrunken table must keep agreeing with the shadow map.
    for i in 0..10000 {
        let key: TestKey = u128::from(rng.next_u32() % 1000).into();
        let value = rng.random_bool(0.75).then_some(i);
        apply_op(&TestOp(key, value), &mut writer, &mut shadow);
        if i % 1000 == 0 {
            eprintln!("{i} ops processed");
        }
    }
}