Add initial work on implementing incremental resizing (WIP)

David Freifeld
2025-07-14 09:02:56 -07:00
parent 282b90df28
commit add1a0ad78
3 changed files with 122 additions and 55 deletions
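The core idea: instead of migrating every entry in one stop-the-world pass, a resize only re-tags the existing entries and publishes a cursor (rehash_index, rehash_end); later operations each claim and migrate a bounded chunk of buckets until the cursor reaches the end. A minimal, self-contained sketch of the chunk-claiming scheme that the new do_rehash below implements (hypothetical names, not this crate's API):

use std::ops::Range;
use std::sync::atomic::{AtomicUsize, Ordering};

// Hypothetical stand-in for the (rehash_index, rehash_end) pair added below.
struct RehashCursor {
    next: AtomicUsize, // first bucket index not yet claimed
    end: usize,        // one past the last bucket that needs migrating
}

impl RehashCursor {
    // Claim the next chunk of buckets; None means all work has been claimed.
    fn claim(&self, chunk: usize) -> Option<Range<usize>> {
        let mut ind = self.next.load(Ordering::Relaxed);
        loop {
            if ind >= self.end {
                return None;
            }
            match self.next.compare_exchange_weak(
                ind, ind + chunk, Ordering::Relaxed, Ordering::Relaxed,
            ) {
                // We now own [ind, ind + chunk), clamped to the table end.
                Ok(_) => return Some(ind..(ind + chunk).min(self.end)),
                Err(current) => ind = current, // lost a race; retry from the fresh value
            }
        }
    }
}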

View File

@@ -354,26 +354,40 @@ where
         map.clear();
     }
 
-    pub fn rehash(
+    fn begin_rehash(
         &self,
         shards: &mut Vec<RwLockWriteGuard<'_, DictShard<'_, K>>>,
         rehash_buckets: usize
-    ) {
+    ) -> bool {
         let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
         assert!(rehash_buckets <= map.get_num_buckets(), "rehashing subset of buckets");
+        if map.rehash_index.load(Ordering::Relaxed) >= map.rehash_end.load(Ordering::Relaxed) {
+            return false;
+        }
         shards.iter_mut().for_each(|x| x.keys.iter_mut().for_each(|key| {
-            if let EntryType::Occupied = key.tag {
-                key.tag = EntryType::Rehash;
+            match key.tag {
+                EntryType::Occupied => key.tag = EntryType::Rehash,
+                EntryType::Tombstone => key.tag = EntryType::RehashTombstone,
+                _ => (),
             }
         }));
-        todo!("solution with no memory allocation: split out metadata?")
+        map.rehash_index.store(0, Ordering::Relaxed);
+        map.rehash_end.store(rehash_buckets, Ordering::Relaxed);
+        true
     }
 
+    pub fn finish_rehash(&self) {
+        let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
+        while map.do_rehash() {}
+    }
+
     pub fn shuffle(&self) {
         let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
         let mut shards: Vec<_> = map.dict_shards.iter().map(|x| x.write()).collect();
-        self.rehash(&mut shards, map.get_num_buckets());
+        self.begin_rehash(&mut shards, map.get_num_buckets());
     }
 
     fn reshard(&self, shards: &mut Vec<RwLockWriteGuard<'_, DictShard<'_, K>>>, num_buckets: usize) {
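begin_rehash runs with every shard write-locked, so the tag flip above is a plain sweep: Occupied entries become Rehash (live, but still placed under the old bucket count) and Tombstones become RehashTombstones, presumably so old-generation probes still step over deleted slots rather than stopping early. A standalone sketch of the sweep, with EntryType mirroring the enum in the third file:

#[derive(Clone, Copy, PartialEq, Eq)]
enum EntryType { Empty, Occupied, Rehash, Tombstone, RehashTombstone }

// Flip every entry into its "pending rehash" state; Empty slots are untouched.
fn mark_for_rehash(tags: &mut [EntryType]) {
    for tag in tags.iter_mut() {
        *tag = match *tag {
            EntryType::Occupied => EntryType::Rehash,
            EntryType::Tombstone => EntryType::RehashTombstone,
            other => other,
        };
    }
}

This is also why EntryType gains the Clone/Copy derive in the third file: the new lookup paths bind tags by value and compare them with ==.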
@@ -451,10 +465,10 @@ where
         }
         self.reshard(&mut shards, num_buckets);
-        self.rehash(&mut shards, old_num_buckets);
         map.bucket_arr.free_head.store(
             BucketIdx::new(old_num_buckets), Ordering::Relaxed
         );
+        self.begin_rehash(&mut shards, old_num_buckets);
 
         Ok(())
     }
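In this grow path the eager rehash call is replaced: resharding happens first, the free head is pointed at the first newly added bucket (everything at or above old_num_buckets is fresh), and only the original old_num_buckets buckets are tagged for incremental migration. A toy model of the allocator state, assuming (as the stores above suggest) a free list threaded by bucket index:

// Toy stand-in for BucketArray's allocator state around a grow; the real
// free list lives inside BucketArray and is threaded through the buckets.
struct ToyAlloc { free_head: u32, total: u32 }

impl ToyAlloc {
    // After a grow, allocation starts at the first newly added bucket, so
    // inserts made while the rehash is still in flight draw from the fresh
    // region [old_total, new_total).
    fn after_grow(old_total: u32, new_total: u32) -> Self {
        ToyAlloc { free_head: old_total, total: new_total }
    }

    fn alloc(&mut self) -> Option<u32> {
        if self.free_head >= self.total {
            return None; // table genuinely full
        }
        let idx = self.free_head;
        self.free_head += 1; // toy simplification: sequential free list
        Some(idx)
    }
}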
@@ -500,30 +514,12 @@ where
             "called finish_shrink before enough entries were removed"
         );
-        // let shard_size = shards[0].len();
-        // for i in (num_buckets as usize)..map.buckets.len() {
-        //     let shard_start = num_buckets / shard_size;
-        //     let shard = shards[shard_start];
-        //     let entry_start = num_buckets % shard_size;
-        //     for entry_idx in entry_start..shard.len() {
-        //     }
-        //     if let EntryKey::Occupied(v) = map.[i].inner.take() {
-        //         // alloc_bucket increases count, so need to decrease since we're just moving
-        //         map.buckets_in_use.fetch_sub(1, Ordering::Relaxed);
-        //         map.alloc_bucket(k, v).unwrap();
-        //     }
-        // }
-        todo!("dry way to handle reinsertion");
         self.resize_shmem(num_buckets)?;
         self.reshard(&mut shards, num_buckets);
-        self.rehash(&mut shards, num_buckets);
         map.bucket_arr.alloc_limit.store(BucketIdx::INVALID, Ordering::Relaxed);
+        self.begin_rehash(&mut shards, num_buckets);
 
         Ok(())
     }
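finish_shrink sheds the commented-out eager reinsertion loop (and its todo! about a DRY way to handle reinsertion) entirely: the incremental rehash now covers moving the survivors, and the remaining bookkeeping is lifting the allocation cap by storing BucketIdx::INVALID into alloc_limit. A hedged sketch of what such a cap check looks like, assuming INVALID (0x7FFFFFFF in the next file) doubles as a "no limit" sentinel:

// Hypothetical allocation-cap check; during a shrink the limit is presumably
// the target bucket count, so new buckets land in the surviving range.
const INVALID: u32 = 0x7FFF_FFFF; // mirrors BucketIdx::INVALID below

fn may_allocate(candidate_idx: u32, alloc_limit: u32) -> bool {
    alloc_limit == INVALID || candidate_idx < alloc_limit
}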

View File

@@ -7,6 +7,8 @@ use atomic::Atomic;
 #[repr(transparent)]
 pub(crate) struct BucketIdx(pub(super) u32);
 
+const _: () = assert!(Atomic::<BucketIdx>::is_lock_free());
+
 impl BucketIdx {
     const MARK_TAG: u32 = 0x80000000;
     pub const INVALID: Self = Self(0x7FFFFFFF);
@@ -104,8 +106,7 @@ pub(crate) struct BucketArray<'a, V> {
 }
 
 impl<'a, V> BucketArray<'a, V> {
-    pub fn new(buckets: &'a mut [Bucket<V>]) -> Self {
-        debug_assert!(Atomic::<BucketIdx>::is_lock_free());
+    pub fn new(buckets: &'a mut [Bucket<V>]) -> Self {
         Self {
             buckets,
             free_head: Atomic::new(BucketIdx(0)),
View File

@@ -2,6 +2,7 @@
 use std::hash::Hash;
 use std::mem::MaybeUninit;
+use std::sync::atomic::{Ordering, AtomicUsize};
 
 use crate::sync::*;
 use crate::hash::{
@@ -9,7 +10,7 @@ use crate::hash::{
     bucket::{BucketArray, Bucket, BucketIdx}
 };
 
-#[derive(PartialEq, Eq)]
+#[derive(PartialEq, Eq, Clone, Copy)]
 pub(crate) enum EntryType {
     Occupied,
     Rehash,
@@ -44,6 +45,8 @@ pub(crate) struct CoreHashMap<'a, K, V> {
     /// Dictionary used to map hashes to bucket indices.
     pub(crate) dict_shards: &'a mut [RwLock<DictShard<'a, K>>],
     pub(crate) bucket_arr: BucketArray<'a, V>,
+    pub(crate) rehash_index: AtomicUsize,
+    pub(crate) rehash_end: AtomicUsize,
 }
 
 /// Error for when there are no empty buckets left but one is needed.
@@ -91,12 +94,75 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
         Self {
             dict_shards,
+            rehash_index: buckets.len().into(),
+            rehash_end: buckets.len().into(),
             bucket_arr: BucketArray::new(buckets),
         }
     }
 
+    // TODO(quantumish): off by one for return value logic?
+    pub fn do_rehash(&mut self) -> bool {
+        // TODO(quantumish): refactor these out into settable quantities
+        const REHASH_CHUNK_SIZE: usize = 10;
+        const REHASH_ATTEMPTS: usize = 5;
+        let end = self.rehash_end.load(Ordering::Relaxed);
+        let mut ind = self.rehash_index.load(Ordering::Relaxed);
+        let mut i = 0;
+        loop {
+            if ind >= end {
+                // TODO(quantumish) questionable?
+                self.rehash_index.store(end, Ordering::Relaxed);
+                return true;
+            }
+            if i > REHASH_ATTEMPTS {
+                break;
+            }
+            match self.rehash_index.compare_exchange_weak(
+                ind, ind + REHASH_CHUNK_SIZE,
+                Ordering::Relaxed, Ordering::Relaxed
+            ) {
+                Err(new_ind) => ind = new_ind,
+                Ok(_) => break,
+            }
+            i += 1;
+        }
+        todo!("actual rehashing");
+        false
+    }
+
+    pub fn get_with_hash(&'a self, key: &K, hash: u64) -> Option<ValueReadGuard<'a, V>> {
+        let ind = self.rehash_index.load(Ordering::Relaxed);
+        let end = self.rehash_end.load(Ordering::Relaxed);
+        let first = ind >= end || ind < end/2;
+        if let Some(res) = self.get(key, hash, first) {
+            return Some(res);
+        }
+        if ind < end && let Some(res) = self.get(key, hash, !first) {
+            return Some(res);
+        }
+        None
+    }
+
+    pub fn entry_with_hash(&'a mut self, key: K, hash: u64) -> Result<Entry<'a, K, V>, FullError> {
+        let ind = self.rehash_index.load(Ordering::Relaxed);
+        let end = self.rehash_end.load(Ordering::Relaxed);
+        if ind < end {
+            if let Ok(Entry::Occupied(res)) = self.entry(key.clone(), hash, true) {
+                return Ok(Entry::Occupied(res));
+            } else {
+                return self.entry(key, hash, false);
+            }
+        } else {
+            return self.entry(key.clone(), hash, true);
+        }
+    }
+
     /// Get the value associated with a key (if it exists) given its hash.
-    pub fn get_with_hash(&'a self, key: &K, hash: u64) -> Option<ValueReadGuard<'a, V>> {
+    fn get(&'a self, key: &K, hash: u64, ignore_remap: bool) -> Option<ValueReadGuard<'a, V>> {
         let num_buckets = self.get_num_buckets();
         let shard_size = num_buckets / self.dict_shards.len();
         let bucket_pos = hash as usize % num_buckets;
@@ -108,24 +174,26 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
             for entry_idx in entry_start..shard.len() {
                 match shard.keys[entry_idx].tag {
                     EntryType::Empty => return None,
-                    EntryType::Tombstone => continue,
-                    EntryType::Occupied => {
-                        let cand_key = unsafe { shard.keys[entry_idx].val.assume_init_ref() };
-                        if cand_key == key {
-                            let bucket_idx = shard.idxs[entry_idx].pos_checked().expect("position is valid");
-                            return Some(RwLockReadGuard::map(
-                                shard, |_| self.bucket_arr.buckets[bucket_idx].as_ref()
-                            ));
-                        }
+                    EntryType::Tombstone | EntryType::RehashTombstone => continue,
+                    t @ (EntryType::Occupied | EntryType::Rehash) => {
+                        if (t == EntryType::Occupied && ignore_remap) || (t == EntryType::Rehash && !ignore_remap) {
+                            let cand_key = unsafe { shard.keys[entry_idx].val.assume_init_ref() };
+                            if cand_key == key {
+                                let bucket_idx = shard.idxs[entry_idx].pos_checked()
+                                    .expect("position is valid");
+                                return Some(RwLockReadGuard::map(
+                                    shard, |_| self.bucket_arr.buckets[bucket_idx].as_ref()
+                                ));
+                            }
+                        }
                     },
-                    _ => unreachable!(),
                 }
             }
         }
         None
     }
 
-    pub fn entry_with_hash(&'a mut self, key: K, hash: u64) -> Result<Entry<'a, K, V>, FullError> {
+    fn entry(&'a mut self, key: K, hash: u64, ignore_remap: bool) -> Result<Entry<'a, K, V>, FullError> {
         // We need to keep holding the locks for each shard we process since if we don't find the
         // key anywhere, we want to insert it at the earliest possible position (which may be several
         // shards away). Ideally cross-shard chains are quite rare, so this shouldn't be a big deal.
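With the rehash state published on the map, get_with_hash and entry_with_hash become thin wrappers over tag-filtered probes: during an in-flight rehash each lookup runs up to two passes, one over entries tagged Occupied (ignore_remap == true) and one over entries tagged Rehash, with the pass order picked by the progress heuristic ind >= end || ind < end/2 and the second pass skipped once migration is done. The per-entry filter from the match arms is just:

// The filter both probe passes apply (EntryType as in the enum above): each
// pass sees exactly one generation of live entries and skips the other.
#[derive(Clone, Copy, PartialEq, Eq)]
enum EntryType { Empty, Occupied, Rehash, Tombstone, RehashTombstone }

fn visible_in_pass(tag: EntryType, ignore_remap: bool) -> bool {
    (tag == EntryType::Occupied && ignore_remap)
        || (tag == EntryType::Rehash && !ignore_remap)
}

Both tombstone flavors keep a probe chain moving while Empty still terminates it; entry_with_hash tries the Occupied generation first and falls back to the Rehash view only when no already-migrated entry matches.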
@@ -158,26 +226,27 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
                         bucket_arr: &mut self.bucket_arr,
                     }))
                 },
-                EntryType::Tombstone => {
+                EntryType::Tombstone | EntryType::RehashTombstone => {
                     if insert_pos.is_none() {
                         insert_pos = Some(entry_idx);
                         inserted = true;
                     }
                 },
-                EntryType::Occupied => {
-                    let cand_key = unsafe { shard.keys[entry_idx].val.assume_init_ref() };
-                    if *cand_key == key {
-                        let bucket_pos = shard.idxs[entry_idx].pos_checked().unwrap();
-                        return Ok(Entry::Occupied(OccupiedEntry {
-                            _key: key,
-                            shard,
-                            shard_pos: entry_idx,
-                            bucket_pos,
-                            bucket_arr: &mut self.bucket_arr,
-                        }));
-                    }
+                t @ (EntryType::Occupied | EntryType::Rehash) => {
+                    if (t == EntryType::Occupied && ignore_remap) || (t == EntryType::Rehash && !ignore_remap) {
+                        let cand_key = unsafe { shard.keys[entry_idx].val.assume_init_ref() };
+                        if *cand_key == key {
+                            let bucket_pos = shard.idxs[entry_idx].pos_checked().unwrap();
+                            return Ok(Entry::Occupied(OccupiedEntry {
+                                _key: key,
+                                shard,
+                                shard_pos: entry_idx,
+                                bucket_pos,
+                                bucket_arr: &mut self.bucket_arr,
+                            }));
+                        }
+                    }
                 }
-                _ => unreachable!(),
             }
         }
         if inserted {
@@ -218,3 +287,4 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
         self.bucket_arr.clear();
     }
 }
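Taken together, the lifecycle this WIP commit sets up is: grow/shrink reshards and calls begin_rehash to tag the old generation, every subsequent operation can pay for one do_rehash chunk, and finish_rehash drains whatever remains. Note that, as written, the already-finished path of do_rehash returns true, which would keep finish_rehash's while map.do_rehash() {} spinning; presumably that is what the author's "off by one for return value logic?" TODO refers to, and the migration itself is still a todo!. A toy, single-threaded simulation of the amortized drain:

use std::sync::atomic::{AtomicUsize, Ordering};

fn main() {
    const CHUNK: usize = 10; // mirrors REHASH_CHUNK_SIZE above
    let total = 64;
    let cursor = AtomicUsize::new(0); // plays the role of rehash_index
    let mut migrated = vec![false; total];

    // Each iteration stands in for one map operation doing a bounded slice
    // of the rehash, the way do_rehash() claims a chunk per call.
    loop {
        let start = cursor.fetch_add(CHUNK, Ordering::Relaxed);
        if start >= total {
            break;
        }
        for b in start..(start + CHUNK).min(total) {
            migrated[b] = true; // re-insert bucket b under the new bucket count
        }
    }

    assert!(migrated.iter().all(|&m| m));
    println!("all {total} buckets migrated in chunks of {CHUNK}");
}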