Fix concurrency bugs in resizing (WIP)

quantumish
2025-08-12 10:34:29 -07:00
parent add1a0ad78
commit d33071a386
4 changed files with 305 additions and 161 deletions

View File

@@ -1,3 +1,4 @@
use std::cell::UnsafeCell;
use std::hash::{BuildHasher, Hash};
use std::mem::MaybeUninit;
use std::ptr::NonNull;
@@ -14,7 +15,7 @@ pub mod entry;
mod tests;
use core::{
CoreHashMap, DictShard, EntryKey, EntryType,
CoreHashMap, DictShard, EntryKey, EntryTag,
FullError, MaybeUninitDictShard
};
use bucket::{Bucket, BucketIdx};
@@ -134,11 +135,13 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
}
let shards: &mut [RwLock<MaybeUninitDictShard<'_, K>>] =
unsafe { std::slice::from_raw_parts_mut(shards_ptr.cast(), num_shards) };
let buckets =
unsafe { std::slice::from_raw_parts_mut(vals_ptr.cast(), num_buckets) };
let buckets: *const [MaybeUninit<Bucket<V>>] =
unsafe { std::slice::from_raw_parts(vals_ptr.cast(), num_buckets) };
let hashmap = CoreHashMap::new(buckets, shards);
unsafe { std::ptr::write(shared_ptr, hashmap); }
unsafe {
let hashmap = CoreHashMap::new(&*(buckets as *const UnsafeCell<_>), shards);
std::ptr::write(shared_ptr, hashmap);
}
let resize_lock = Mutex::from_raw(
unsafe { PthreadMutex::new(NonNull::new_unchecked(mutex_ptr)) }, ()
@@ -313,18 +316,38 @@ where
pub unsafe fn get_at_bucket(&self, pos: usize) -> Option<&V> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
if pos >= map.bucket_arr.buckets.len() {
if pos >= map.bucket_arr.len() {
return None;
}
let bucket = &map.bucket_arr.buckets[pos];
if bucket.next.load(Ordering::Relaxed) == BucketIdx::RESERVED {
let bucket = &map.bucket_arr[pos];
if bucket.next.load(Ordering::Relaxed).full_checked().is_some() {
Some(unsafe { bucket.val.assume_init_ref() })
} else {
None
}
}
pub unsafe fn entry_at_bucket(&self, pos: usize) -> Option<entry::OccupiedEntry<'a, K, V>> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
if pos >= map.bucket_arr.len() {
return None;
}
let bucket = &map.bucket_arr[pos];
bucket.next.load(Ordering::Relaxed).full_checked().map(|entry_pos| {
let shard_size = map.get_num_buckets() / map.dict_shards.len();
let shard_index = entry_pos / shard_size;
let shard_off = entry_pos % shard_size;
entry::OccupiedEntry {
shard: map.dict_shards[shard_index].write(),
shard_pos: shard_off,
bucket_pos: pos,
bucket_arr: &map.bucket_arr,
}
})
}
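The shard arithmetic above is a plain flat-index split: the key position recovered from the bucket's next word is divided by the per-shard entry count to get the shard index, and the remainder is the offset inside that shard. A standalone sketch of the same split (helper name is illustrative, not part of the crate):

fn split_key_pos(key_pos: usize, shard_size: usize) -> (usize, usize) {
    (key_pos / shard_size, key_pos % shard_size)
}

fn main() {
    // e.g. shards of 128 entries: flat position 300 is shard 2, offset 44
    assert_eq!(split_key_pos(300, 128), (2, 44));
}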
/// Returns the number of buckets in the table.
pub fn get_num_buckets(&self) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
@@ -335,9 +358,9 @@ where
pub fn get_bucket_for_value(&self, val_ptr: *const V) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
let origin = map.bucket_arr.buckets.as_ptr();
let origin = map.bucket_arr.as_mut_ptr() as *const _;
let idx = (val_ptr as usize - origin as usize) / size_of::<Bucket<V>>();
assert!(idx < map.bucket_arr.buckets.len());
assert!(idx < map.bucket_arr.len());
idx
}
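get_bucket_for_value recovers the index purely from pointer arithmetic: the byte distance between the value pointer and the start of the bucket array, divided by the bucket size. A minimal standalone sketch of that idea over an ordinary slice (assumed names, not the crate's API):

fn index_of<T>(slice: &[T], elem: *const T) -> usize {
    // Byte offset from the slice start divided by the element size gives the index.
    let origin = slice.as_ptr() as usize;
    let idx = (elem as usize - origin) / std::mem::size_of::<T>();
    assert!(idx < slice.len());
    idx
}

fn main() {
    let data = [10u64, 20, 30, 40];
    assert_eq!(index_of(&data, &data[2] as *const u64), 2);
}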
@@ -368,8 +391,8 @@ where
shards.iter_mut().for_each(|x| x.keys.iter_mut().for_each(|key| {
match key.tag {
EntryType::Occupied => key.tag = EntryType::Rehash,
EntryType::Tombstone => key.tag = EntryType::RehashTombstone,
EntryTag::Occupied => key.tag = EntryTag::Rehash,
EntryTag::Tombstone => key.tag = EntryTag::RehashTombstone,
_ => (),
}
}));
@@ -379,9 +402,66 @@ where
true
}
// TODO(quantumish): off by one for return value logic?
fn do_rehash(&self) -> bool {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
// TODO(quantumish): refactor these out into settable quantities
const REHASH_CHUNK_SIZE: usize = 10;
const REHASH_ATTEMPTS: usize = 5;
let end = map.rehash_end.load(Ordering::Relaxed);
let ind = map.rehash_index.load(Ordering::Relaxed);
if ind >= end { return true }
let _guard = self.resize_lock.try_lock();
if _guard.is_none() { return false }
map.rehash_index.store((ind+REHASH_CHUNK_SIZE).min(end), Ordering::Relaxed);
let shard_size = map.get_num_buckets() / map.dict_shards.len();
for i in ind..(ind+REHASH_CHUNK_SIZE).min(end) {
let (shard_index, shard_off) = (i / shard_size, i % shard_size);
let mut shard = map.dict_shards[shard_index].write();
if shard.keys[shard_off].tag != EntryTag::Rehash {
continue;
}
loop {
let hash = self.get_hash_value(unsafe {
shard.keys[shard_off].val.assume_init_ref()
});
let key = unsafe { shard.keys[shard_off].val.assume_init_ref() }.clone();
let new = map.entry(key, hash, |tag| match tag {
EntryTag::Empty => core::MapEntryType::Empty,
EntryTag::Occupied => core::MapEntryType::Occupied,
EntryTag::Tombstone => core::MapEntryType::Skip,
_ => core::MapEntryType::Tombstone,
}).unwrap();
let new_pos = new.pos();
match new.tag() {
EntryTag::Empty | EntryTag::RehashTombstone => {
shard.keys[shard_off].tag = EntryTag::Empty;
unsafe {
std::mem::swap(
shard.keys[shard_off].val.assume_init_mut(),
// WIP: the destination should be the key slot of the entry that
// `new` points at; that accessor does not exist yet in this commit.
todo!("key slot of `new`"),
);
}
},
EntryTag::Rehash => {
// WIP: displacing an entry that itself still awaits rehashing is
// not handled yet.
},
_ => unreachable!()
}
}
}
false
}
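This do_rehash claims work in fixed-size chunks: read the shared cursor, bump it by REHASH_CHUNK_SIZE (capped at rehash_end), process only the claimed range, and report true once the cursor has passed the end. A simplified standalone sketch of that cursor pattern, with the resize-lock guard and the actual per-entry rehash elided (names are illustrative):

use std::sync::atomic::{AtomicUsize, Ordering};

const REHASH_CHUNK_SIZE: usize = 10;

// Returns true once the whole [0, end) range has been claimed.
fn rehash_step(cursor: &AtomicUsize, end: usize, mut work: impl FnMut(usize)) -> bool {
    let ind = cursor.load(Ordering::Relaxed);
    if ind >= end {
        return true;
    }
    // Claim the next chunk (the real code holds the resize lock around this).
    cursor.store((ind + REHASH_CHUNK_SIZE).min(end), Ordering::Relaxed);
    for i in ind..(ind + REHASH_CHUNK_SIZE).min(end) {
        work(i);
    }
    false
}

fn main() {
    let cursor = AtomicUsize::new(0);
    let mut seen = Vec::new();
    while !rehash_step(&cursor, 25, |i| seen.push(i)) {}
    assert_eq!(seen.len(), 25);
}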
pub fn finish_rehash(&self) {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
while map.do_rehash() {}
while self.do_rehash() {}
}
pub fn shuffle(&self) {
@@ -422,7 +502,7 @@ where
let _resize_guard = self.resize_lock.lock();
let mut shards: Vec<_> = map.dict_shards.iter().map(|x| x.write()).collect();
let old_num_buckets = map.bucket_arr.buckets.len();
let old_num_buckets = map.bucket_arr.len();
assert!(
num_buckets >= old_num_buckets,
"grow called with a smaller number of buckets"
@@ -434,7 +514,7 @@ where
// Grow memory areas and initialize each of them.
self.resize_shmem(num_buckets)?;
unsafe {
let buckets_ptr = map.bucket_arr.buckets.as_mut_ptr();
let buckets_ptr = map.bucket_arr.as_mut_ptr();
for i in old_num_buckets..num_buckets {
let bucket = buckets_ptr.add(i);
bucket.write(Bucket::empty(
@@ -452,7 +532,7 @@ where
for i in old_num_buckets..num_buckets {
let key = keys_ptr.add(i);
key.write(EntryKey {
tag: EntryType::Empty,
tag: EntryTag::Empty,
val: MaybeUninit::uninit(),
});
}
@@ -492,7 +572,7 @@ where
pub fn shrink_goal(&self) -> Option<usize> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
let goal = map.bucket_arr.alloc_limit.load(Ordering::Relaxed);
goal.pos_checked()
goal.next_checked()
}
pub fn finish_shrink(&self) -> Result<(), shmem::Error> {
@@ -502,7 +582,7 @@ where
let num_buckets = map.bucket_arr.alloc_limit
.load(Ordering::Relaxed)
.pos_checked()
.next_checked()
.expect("called finish_shrink when no shrink is in progress");
if map.get_num_buckets() == num_buckets {

View File

@@ -1,5 +1,7 @@
use std::{mem::MaybeUninit, sync::atomic::{AtomicUsize, Ordering}};
use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicUsize, Ordering};
use atomic::Atomic;
@@ -10,34 +12,55 @@ pub(crate) struct BucketIdx(pub(super) u32);
const _: () = assert!(Atomic::<BucketIdx>::is_lock_free());
impl BucketIdx {
const MARK_TAG: u32 = 0x80000000;
pub const INVALID: Self = Self(0x7FFFFFFF);
pub const RESERVED: Self = Self(0x7FFFFFFE);
pub const MAX: usize = Self::RESERVED.0 as usize - 1;
/// Tag for next pointers in free entries.
const NEXT_TAG: u32 = 0b00 << 30;
/// Tag for marked next pointers in free entries.
const MARK_TAG: u32 = 0b01 << 30;
/// Tag for full entries.
const FULL_TAG: u32 = 0b10 << 30;
/// Reserved. Don't use me.
const RSVD_TAG: u32 = 0b11 << 30;
pub const INVALID: Self = Self(0x3FFFFFFF);
pub const MAX: usize = Self::INVALID.0 as usize - 1;
pub(super) fn is_marked(&self) -> bool {
self.0 & Self::MARK_TAG != 0
self.0 & Self::RSVD_TAG == Self::MARK_TAG
}
pub(super) fn as_marked(self) -> Self {
Self(self.0 | Self::MARK_TAG)
Self((self.0 & Self::INVALID.0) | Self::MARK_TAG)
}
pub(super) fn get_unmarked(self) -> Self {
Self(self.0 & !Self::MARK_TAG)
Self(self.0 & Self::INVALID.0)
}
pub fn new(val: usize) -> Self {
debug_assert!(val < Self::MAX);
Self(val as u32)
}
pub fn pos_checked(&self) -> Option<usize> {
pub fn new_full(val: usize) -> Self {
debug_assert!(val < Self::MAX);
Self(val as u32 | Self::FULL_TAG)
}
pub fn next_checked(&self) -> Option<usize> {
if *self == Self::INVALID || self.is_marked() {
None
} else {
Some(self.0 as usize)
}
}
pub fn full_checked(&self) -> Option<usize> {
if self.0 & Self::RSVD_TAG == Self::FULL_TAG {
Some((self.0 & Self::INVALID.0) as usize)
} else {
None
}
}
}
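The reworked index word keeps the payload in the low 30 bits and uses the top two bits as a variant tag, which is why INVALID shrank to 0x3FFFFFFF and the old RESERVED sentinel is replaced by "full" entries that carry their key position. A standalone sketch of the encode/decode (constant and function names are illustrative, not the crate's API):

const PAYLOAD_MASK: u32 = 0x3FFF_FFFF; // low 30 bits
const TAG_MASK: u32 = 0b11 << 30;      // top two bits
const FULL_TAG: u32 = 0b10 << 30;      // "bucket holds a value"

fn encode_full(key_pos: usize) -> u32 {
    debug_assert!(key_pos < PAYLOAD_MASK as usize);
    key_pos as u32 | FULL_TAG
}

// Mirrors full_checked(): only words carrying FULL_TAG decode to a key position.
fn decode_full(word: u32) -> Option<usize> {
    if word & TAG_MASK == FULL_TAG {
        Some((word & PAYLOAD_MASK) as usize)
    } else {
        None
    }
}

fn main() {
    assert_eq!(decode_full(encode_full(1234)), Some(1234));
    assert_eq!(decode_full(5), None); // a plain freelist "next" index
}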
impl std::fmt::Debug for BucketIdx {
@@ -48,7 +71,6 @@ impl std::fmt::Debug for BucketIdx {
self.is_marked(),
match *self {
Self::INVALID => "INVALID".to_string(),
Self::RESERVED => "RESERVED".to_string(),
_ => format!("{idx}")
}
)
@@ -77,8 +99,6 @@ impl<V> Bucket<V> {
}
}
// pub is_full
pub fn as_ref(&self) -> &V {
unsafe { self.val.assume_init_ref() }
}
@@ -94,7 +114,7 @@ impl<V> Bucket<V> {
pub(crate) struct BucketArray<'a, V> {
/// Buckets containing values.
pub(crate) buckets: &'a mut [Bucket<V>],
pub(crate) buckets: &'a UnsafeCell<[Bucket<V>]>,
/// Head of the freelist.
pub(crate) free_head: Atomic<BucketIdx>,
/// Maximum index of a bucket allowed to be allocated.
@@ -105,8 +125,24 @@ pub(crate) struct BucketArray<'a, V> {
pub(crate) _user_list_head: Atomic<BucketIdx>,
}
impl <'a, V> std::ops::Index<usize> for BucketArray<'a, V> {
type Output = Bucket<V>;
fn index(&self, index: usize) -> &Self::Output {
let buckets: &[_] = unsafe { &*(self.buckets.get() as *mut _) };
&buckets[index]
}
}
impl <'a, V> std::ops::IndexMut<usize> for BucketArray<'a, V> {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
let buckets: &mut [_] = unsafe { &mut *(self.buckets.get() as *mut _) };
&mut buckets[index]
}
}
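Storing the buckets behind &UnsafeCell<[Bucket<V>]> is what lets these Index impls, and the &self allocation paths further down, hand out element access from a shared reference; the expectation is that the shard locks and the freelist CAS protocol serialize any conflicting access. A minimal sketch of the pattern under that assumption (types and names are illustrative):

use std::cell::UnsafeCell;

struct Slots<'a> {
    inner: &'a UnsafeCell<[u32]>,
}

impl<'a> Slots<'a> {
    // Both accessors take &self; UnsafeCell supplies the interior mutability.
    fn get(&self, i: usize) -> u32 {
        unsafe { (*self.inner.get())[i] }
    }
    fn set(&self, i: usize, v: u32) {
        unsafe { (*self.inner.get())[i] = v }
    }
}

fn main() {
    let backing = UnsafeCell::new([0u32; 4]);
    // Unsized coercion: &UnsafeCell<[u32; 4]> -> &UnsafeCell<[u32]>
    let slots = Slots { inner: &backing };
    slots.set(2, 7);
    assert_eq!(slots.get(2), 7);
}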
impl<'a, V> BucketArray<'a, V> {
pub fn new(buckets: &'a mut [Bucket<V>]) -> Self {
pub fn new(buckets: &'a UnsafeCell<[Bucket<V>]>) -> Self {
Self {
buckets,
free_head: Atomic::new(BucketIdx(0)),
@@ -116,17 +152,28 @@ impl<'a, V> BucketArray<'a, V> {
}
}
pub fn dealloc_bucket(&mut self, pos: usize) -> V {
let bucket = &mut self.buckets[pos];
let pos = BucketIdx::new(pos);
pub fn as_mut_ptr(&self) -> *mut Bucket<V> {
unsafe { (&mut *self.buckets.get()).as_mut_ptr() }
}
pub fn get_mut(&self, index: usize) -> &mut Bucket<V> {
let buckets: &mut [_] = unsafe { &mut *(self.buckets.get() as *mut _) };
&mut buckets[index]
}
pub fn len(&self) -> usize {
unsafe { (&*self.buckets.get()).len() }
}
pub fn dealloc_bucket(&self, pos: usize) -> V {
loop {
let free = self.free_head.load(Ordering::Relaxed);
bucket.next.store(free, Ordering::Relaxed);
self[pos].next.store(free, Ordering::Relaxed);
if self.free_head.compare_exchange_weak(
free, pos, Ordering::Relaxed, Ordering::Relaxed
free, BucketIdx::new(pos), Ordering::Relaxed, Ordering::Relaxed
).is_ok() {
self.buckets_in_use.fetch_sub(1, Ordering::Relaxed);
return unsafe { bucket.val.assume_init_read() };
return unsafe { self[pos].val.assume_init_read() };
}
}
}
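dealloc_bucket is a lock-free freelist push: write the current head into the freed slot's next field, then try to CAS the head over to the freed slot, retrying if another thread moved the head in between. A self-contained sketch of the same loop over plain atomics (indices stand in for BucketIdx; names are illustrative):

use std::sync::atomic::{AtomicU32, Ordering};

const INVALID: u32 = u32::MAX;

struct FreeList {
    head: AtomicU32,
    next: Vec<AtomicU32>, // next[i] = index of the following free slot
}

impl FreeList {
    fn push(&self, pos: u32) {
        loop {
            let head = self.head.load(Ordering::Relaxed);
            self.next[pos as usize].store(head, Ordering::Relaxed);
            // Publish `pos` as the new head only if nobody moved it meanwhile.
            if self
                .head
                .compare_exchange_weak(head, pos, Ordering::Release, Ordering::Relaxed)
                .is_ok()
            {
                return;
            }
        }
    }
}

fn main() {
    let fl = FreeList {
        head: AtomicU32::new(INVALID),
        next: (0..4).map(|_| AtomicU32::new(INVALID)).collect(),
    };
    fl.push(2);
    fl.push(0);
    assert_eq!(fl.head.load(Ordering::Relaxed), 0);
    assert_eq!(fl.next[0].load(Ordering::Relaxed), 2);
}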
@@ -140,8 +187,8 @@ impl<'a, V> BucketArray<'a, V> {
loop {
let mut t = BucketIdx::INVALID;
let mut t_next = self.free_head.load(Ordering::Relaxed);
let alloc_limit = self.alloc_limit.load(Ordering::Relaxed).pos_checked();
while t_next.is_marked() || t.pos_checked()
let alloc_limit = self.alloc_limit.load(Ordering::Relaxed).next_checked();
while t_next.is_marked() || t.next_checked()
.map_or(true, |v| alloc_limit.map_or(false, |l| v > l))
{
if !t_next.is_marked() {
@@ -150,12 +197,12 @@ impl<'a, V> BucketArray<'a, V> {
}
t = t_next.get_unmarked();
if t == BucketIdx::INVALID { break }
t_next = self.buckets[t.0 as usize].next.load(Ordering::Relaxed);
t_next = self[t.0 as usize].next.load(Ordering::Relaxed);
}
right_node = t;
if left_node_next == right_node {
if right_node != BucketIdx::INVALID && self.buckets[right_node.0 as usize]
if right_node != BucketIdx::INVALID && self[right_node.0 as usize]
.next.load(Ordering::Relaxed).is_marked()
{
continue;
@@ -165,13 +212,13 @@ impl<'a, V> BucketArray<'a, V> {
}
let left_ref = if left_node != BucketIdx::INVALID {
&self.buckets[left_node.0 as usize].next
&self[left_node.0 as usize].next
} else { &self.free_head };
if left_ref.compare_exchange_weak(
left_node_next, right_node, Ordering::Relaxed, Ordering::Relaxed
).is_ok() {
if right_node != BucketIdx::INVALID && self.buckets[right_node.0 as usize]
if right_node != BucketIdx::INVALID && self[right_node.0 as usize]
.next.load(Ordering::Relaxed).is_marked()
{
continue;
@@ -183,7 +230,7 @@ impl<'a, V> BucketArray<'a, V> {
}
#[allow(unused_assignments)]
pub(crate) fn alloc_bucket(&mut self, value: V) -> Option<BucketIdx> {
pub(crate) fn alloc_bucket(&self, value: V, key_pos: usize) -> Option<BucketIdx> {
// println!("alloc()");
let mut right_node_next = BucketIdx::INVALID;
let mut left_idx = BucketIdx::INVALID;
@@ -195,7 +242,7 @@ impl<'a, V> BucketArray<'a, V> {
return None;
}
let right = &self.buckets[right_idx.0 as usize];
let right = &self[right_idx.0 as usize];
right_node_next = right.next.load(Ordering::Relaxed);
if !right_node_next.is_marked() {
if right.next.compare_exchange_weak(
@@ -208,7 +255,7 @@ impl<'a, V> BucketArray<'a, V> {
}
let left_ref = if left_idx != BucketIdx::INVALID {
&self.buckets[left_idx.0 as usize].next
&self[left_idx.0 as usize].next
} else {
&self.free_head
};
@@ -221,17 +268,17 @@ impl<'a, V> BucketArray<'a, V> {
}
self.buckets_in_use.fetch_add(1, Ordering::Relaxed);
self.buckets[right_idx.0 as usize].val.write(value);
self.buckets[right_idx.0 as usize].next.store(
BucketIdx::RESERVED, Ordering::Relaxed
self[right_idx.0 as usize].next.store(
BucketIdx::new_full(key_pos), Ordering::Relaxed
);
self.get_mut(right_idx.0 as usize).val.write(value);
Some(right_idx)
}
pub fn clear(&mut self) {
for i in 0..self.buckets.len() {
self.buckets[i] = Bucket::empty(
if i < self.buckets.len() - 1 {
for i in 0..self.len() {
self[i] = Bucket::empty(
if i < self.len() - 1 {
BucketIdx::new(i + 1)
} else {
BucketIdx::INVALID

View File

@@ -1,5 +1,6 @@
//! Simple hash table with chaining.
use std::cell::UnsafeCell;
use std::hash::Hash;
use std::mem::MaybeUninit;
use std::sync::atomic::{Ordering, AtomicUsize};
@@ -11,7 +12,7 @@ use crate::hash::{
};
#[derive(PartialEq, Eq, Clone, Copy)]
pub(crate) enum EntryType {
pub(crate) enum EntryTag {
Occupied,
Rehash,
Tombstone,
@@ -19,8 +20,15 @@ pub(crate) enum EntryType {
Empty,
}
pub(crate) enum MapEntryType {
Occupied,
Tombstone,
Empty,
Skip
}
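MapEntryType is what the closure passed to the reworked get()/entry() returns for each raw tag: Empty ends the probe, Occupied marks a slot worth key-comparing, Tombstone keeps probing (and, in entry(), remembers an insertion slot), and Skip keeps probing without remembering one. A standalone sketch of how two different closures give the "new table" view and the "rehash in progress" view over the same tags (both enums re-declared locally for illustration):

#[allow(dead_code)]
#[derive(Clone, Copy)]
enum EntryTag { Occupied, Rehash, Tombstone, RehashTombstone, Empty }

#[allow(dead_code)]
enum MapEntryType { Occupied, Tombstone, Empty, Skip }

// Walk a chain of tags and return the first slot the classifier calls occupied.
fn probe<F: Fn(EntryTag) -> MapEntryType>(tags: &[EntryTag], classify: F) -> Option<usize> {
    for (i, &tag) in tags.iter().enumerate() {
        match classify(tag) {
            MapEntryType::Empty => return None, // end of chain
            MapEntryType::Tombstone | MapEntryType::Skip => continue,
            MapEntryType::Occupied => return Some(i),
        }
    }
    None
}

fn main() {
    let tags = [EntryTag::Tombstone, EntryTag::Rehash, EntryTag::Occupied];
    // "New table" view: only Occupied counts as a hit.
    let new_view = probe(&tags, |t| match t {
        EntryTag::Empty => MapEntryType::Empty,
        EntryTag::Occupied => MapEntryType::Occupied,
        _ => MapEntryType::Tombstone,
    });
    // During rehash: entries still tagged Rehash count as occupied instead.
    let old_view = probe(&tags, |t| match t {
        EntryTag::Empty => MapEntryType::Empty,
        EntryTag::Rehash => MapEntryType::Occupied,
        _ => MapEntryType::Tombstone,
    });
    assert_eq!(new_view, Some(2));
    assert_eq!(old_view, Some(1));
}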
pub(crate) struct EntryKey<K> {
pub(crate) tag: EntryType,
pub(crate) tag: EntryTag,
pub(crate) val: MaybeUninit<K>,
}
@@ -55,9 +63,10 @@ pub struct FullError();
impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
pub fn new(
buckets: &'a mut [MaybeUninit<Bucket<V>>],
buckets_cell: &'a UnsafeCell<[MaybeUninit<Bucket<V>>]>,
dict_shards: &'a mut [RwLock<MaybeUninitDictShard<'a, K>>],
) -> Self {
let buckets = unsafe { &mut *buckets_cell.get() };
// Initialize the buckets
for i in 0..buckets.len() {
buckets[i].write(Bucket::empty(
@@ -74,7 +83,7 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
let mut dicts = shard.write();
for e in dicts.keys.iter_mut() {
e.write(EntryKey {
tag: EntryType::Empty,
tag: EntryTag::Empty,
val: MaybeUninit::uninit(),
});
}
@@ -83,10 +92,10 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
}
}
let buckets_cell = unsafe {
&*(buckets_cell as *const _ as *const UnsafeCell<_>)
};
// TODO: use std::slice::assume_init_mut() once it stabilizes
let buckets =
unsafe { std::slice::from_raw_parts_mut(buckets.as_mut_ptr().cast(),
buckets.len()) };
let dict_shards = unsafe {
std::slice::from_raw_parts_mut(dict_shards.as_mut_ptr().cast(),
dict_shards.len())
@@ -96,73 +105,64 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
dict_shards,
rehash_index: buckets.len().into(),
rehash_end: buckets.len().into(),
bucket_arr: BucketArray::new(buckets),
bucket_arr: BucketArray::new(buckets_cell),
}
}
// TODO(quantumish): off by one for return value logic?
pub fn do_rehash(&mut self) -> bool {
// TODO(quantumish): refactor these out into settable quantities
const REHASH_CHUNK_SIZE: usize = 10;
const REHASH_ATTEMPTS: usize = 5;
let end = self.rehash_end.load(Ordering::Relaxed);
let mut ind = self.rehash_index.load(Ordering::Relaxed);
let mut i = 0;
loop {
if ind >= end {
// TODO(quantumish) questionable?
self.rehash_index.store(end, Ordering::Relaxed);
return true;
}
if i > REHASH_ATTEMPTS {
break;
}
match self.rehash_index.compare_exchange_weak(
ind, ind + REHASH_CHUNK_SIZE,
Ordering::Relaxed, Ordering::Relaxed
) {
Err(new_ind) => ind = new_ind,
Ok(_) => break,
}
i += 1;
}
todo!("actual rehashing");
false
}
pub fn get_with_hash(&'a self, key: &K, hash: u64) -> Option<ValueReadGuard<'a, V>> {
let ind = self.rehash_index.load(Ordering::Relaxed);
let end = self.rehash_end.load(Ordering::Relaxed);
let first = ind >= end || ind < end/2;
if let Some(res) = self.get(key, hash, first) {
return Some(res);
}
if ind < end && let Some(res) = self.get(key, hash, !first) {
return Some(res);
let res = self.get(key, hash, |tag| match tag {
EntryTag::Empty => MapEntryType::Empty,
EntryTag::Occupied => MapEntryType::Occupied,
_ => MapEntryType::Tombstone,
});
if res.is_some() {
return res;
}
if ind < end {
self.get(key, hash, |tag| match tag {
EntryTag::Empty => MapEntryType::Empty,
EntryTag::Rehash => MapEntryType::Occupied,
_ => MapEntryType::Tombstone,
})
} else {
None
}
}
pub fn entry_with_hash(&'a mut self, key: K, hash: u64) -> Result<Entry<'a, K, V>, FullError> {
let ind = self.rehash_index.load(Ordering::Relaxed);
let end = self.rehash_end.load(Ordering::Relaxed);
let res = self.entry(key.clone(), hash, |tag| match tag {
EntryTag::Empty => MapEntryType::Empty,
EntryTag::Occupied => MapEntryType::Occupied,
EntryTag::Rehash => MapEntryType::Skip,
_ => MapEntryType::Tombstone,
});
if ind < end {
if let Ok(Entry::Occupied(res)) = self.entry(key.clone(), hash, true) {
return Ok(Entry::Occupied(res));
if let Ok(Entry::Occupied(_)) = res {
res
} else {
return self.entry(key, hash, false);
self.entry(key, hash, |tag| match tag {
EntryTag::Empty => MapEntryType::Empty,
EntryTag::Occupied => MapEntryType::Skip,
EntryTag::Rehash => MapEntryType::Occupied,
_ => MapEntryType::Tombstone
})
}
} else {
return self.entry(key.clone(), hash, true);
res
}
}
/// Get the value associated with a key (if it exists) given its hash.
fn get(&'a self, key: &K, hash: u64, ignore_remap: bool) -> Option<ValueReadGuard<'a, V>> {
fn get<F>(&'a self, key: &K, hash: u64, f: F) -> Option<ValueReadGuard<'a, V>>
where F: Fn(EntryTag) -> MapEntryType
{
let num_buckets = self.get_num_buckets();
let shard_size = num_buckets / self.dict_shards.len();
let bucket_pos = hash as usize % num_buckets;
@@ -172,20 +172,18 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
let shard = self.dict_shards[shard_idx].read();
let entry_start = if off == 0 { bucket_pos % shard_size } else { 0 };
for entry_idx in entry_start..shard.len() {
match shard.keys[entry_idx].tag {
EntryType::Empty => return None,
EntryType::Tombstone | EntryType::RehashTombstone => continue,
t @ (EntryType::Occupied | EntryType::Rehash) => {
if (t == EntryType::Occupied && ignore_remap) || (t == EntryType::Rehash && !ignore_remap) {
match f(shard.keys[entry_idx].tag) {
MapEntryType::Empty => return None,
MapEntryType::Tombstone | MapEntryType::Skip => continue,
MapEntryType::Occupied => {
let cand_key = unsafe { shard.keys[entry_idx].val.assume_init_ref() };
if cand_key == key {
let bucket_idx = shard.idxs[entry_idx].pos_checked()
let bucket_idx = shard.idxs[entry_idx].next_checked()
.expect("position is valid");
return Some(RwLockReadGuard::map(
shard, |_| self.bucket_arr.buckets[bucket_idx].as_ref()
shard, |_| self.bucket_arr[bucket_idx].as_ref()
));
}
}
},
}
}
@@ -193,7 +191,9 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
None
}
fn entry(&'a mut self, key: K, hash: u64, ignore_remap: bool) -> Result<Entry<'a, K, V>, FullError> {
pub fn entry<F>(&'a self, key: K, hash: u64, f: F) -> Result<Entry<'a, K, V>, FullError>
where F: Fn(EntryTag) -> MapEntryType
{
// We need to keep holding on the locks for each shard we process since if we don't find the
// key anywhere, we want to insert it at the earliest possible position (which may be several
// shards away). Ideally cross-shard chains are quite rare, so this shouldn't be a big deal.
@@ -211,57 +211,57 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
let mut inserted = false;
let entry_start = if off == 0 { bucket_pos % shard_size } else { 0 };
for entry_idx in entry_start..shard.len() {
match shard.keys[entry_idx].tag {
EntryType::Empty => {
let (shard, shard_pos) = match (insert_shard, insert_pos) {
(Some(s), Some(p)) => (s, p),
(None, Some(p)) => (shard, p),
(None, None) => (shard, entry_idx),
match f(shard.keys[entry_idx].tag) {
MapEntryType::Skip => continue,
MapEntryType::Empty => {
let ((shard, idx), shard_pos) = match (insert_shard, insert_pos) {
(Some((s, i)), Some(p)) => ((s, i), p),
(None, Some(p)) => ((shard, shard_idx), p),
(None, None) => ((shard, shard_idx), entry_idx),
_ => unreachable!()
};
return Ok(Entry::Vacant(VacantEntry {
_key: key,
shard,
shard_pos,
bucket_arr: &mut self.bucket_arr,
key_pos: (shard_size * idx) + shard_pos,
bucket_arr: &self.bucket_arr,
}))
},
EntryType::Tombstone | EntryType::RehashTombstone => {
MapEntryType::Tombstone => {
if insert_pos.is_none() {
insert_pos = Some(entry_idx);
inserted = true;
}
},
t @ (EntryType::Occupied | EntryType::Rehash) => {
if (t == EntryType::Occupied && ignore_remap) || (t == EntryType::Rehash && !ignore_remap) {
MapEntryType::Occupied => {
let cand_key = unsafe { shard.keys[entry_idx].val.assume_init_ref() };
if *cand_key == key {
let bucket_pos = shard.idxs[entry_idx].pos_checked().unwrap();
let bucket_pos = shard.idxs[entry_idx].next_checked().unwrap();
return Ok(Entry::Occupied(OccupiedEntry {
_key: key,
shard,
shard_pos: entry_idx,
bucket_pos,
bucket_arr: &mut self.bucket_arr,
bucket_arr: &self.bucket_arr,
}));
}
}
}
}
}
if inserted {
insert_shard = Some(shard)
insert_shard = Some((shard, shard_idx));
} else {
shards.push(shard);
}
}
if let (Some(shard), Some(shard_pos)) = (insert_shard, insert_pos) {
if let (Some((shard, idx)), Some(shard_pos)) = (insert_shard, insert_pos) {
Ok(Entry::Vacant(VacantEntry {
_key: key,
shard,
shard_pos,
bucket_arr: &mut self.bucket_arr,
key_pos: (shard_size * idx) + shard_pos,
bucket_arr: &self.bucket_arr,
}))
} else {
Err(FullError{})
@@ -270,14 +270,14 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
/// Get number of buckets in map.
pub fn get_num_buckets(&self) -> usize {
self.bucket_arr.buckets.len()
self.bucket_arr.len()
}
pub fn clear(&mut self) {
let mut shards: Vec<_> = self.dict_shards.iter().map(|x| x.write()).collect();
for shard in shards.iter_mut() {
for e in shard.keys.iter_mut() {
e.tag = EntryType::Empty;
e.tag = EntryTag::Empty;
}
for e in shard.idxs.iter_mut() {
*e = BucketIdx::INVALID;

View File

@@ -1,49 +1,61 @@
//! Equivalent of [`std::collections::hash_map::Entry`] for this hashmap.
use crate::hash::{
core::{DictShard, EntryType},
core::{DictShard, EntryTag},
bucket::{BucketArray, BucketIdx}
};
use crate::sync::{RwLockWriteGuard, ValueWriteGuard};
use std::hash::Hash;
use super::core::EntryKey;
pub enum Entry<'a, K, V> {
Occupied(OccupiedEntry<'a, K, V>),
Vacant(VacantEntry<'a, K, V>),
}
impl<'a, K, V> Entry<'a, K, V> {
pub fn tag(&self) -> EntryTag {
match self {
Self::Occupied(o) => o.shard.keys[o.shard_pos].tag,
Self::Vacant(o) => o.shard.keys[o.shard_pos].tag
}
}
pub fn pos(&self) -> usize {
match self {
Self::Occupied(o) => o.key_pos,
Self::Vacant(o) => o.key_pos
}
}
}
pub struct OccupiedEntry<'a, K, V> {
/// The key of the occupied entry
pub(crate) _key: K,
/// Mutable reference to the shard of the map the entry is in.
pub(crate) shard: RwLockWriteGuard<'a, DictShard<'a, K>>,
/// The position of the entry in the map.
/// The position of the entry in the shard.
pub(crate) shard_pos: usize,
/// True logical position of the entry in the map.
pub(crate) key_pos: usize,
/// Mutable reference to the bucket array containing entry.
pub(crate) bucket_arr: &'a mut BucketArray<'a, V>,
pub(crate) bucket_arr: &'a BucketArray<'a, V>,
/// The position of the bucket in the [`CoreHashMap`] bucket array.
pub(crate) bucket_pos: usize,
}
impl<K, V> OccupiedEntry<'_, K, V> {
pub fn get(&self) -> &V {
self.bucket_arr.buckets[self.bucket_pos].as_ref()
self.bucket_arr[self.bucket_pos].as_ref()
}
pub fn get_mut(&mut self) -> &mut V {
self.bucket_arr.buckets[self.bucket_pos].as_mut()
self.bucket_arr.get_mut(self.bucket_pos).as_mut()
}
/// Inserts a value into the entry, replacing (and returning) the existing value.
pub fn insert(&mut self, value: V) -> V {
self.bucket_arr.buckets[self.bucket_pos].replace(value)
self.bucket_arr.get_mut(self.bucket_pos).replace(value)
}
/// Removes the entry from the hash map, returning the value originally stored within it.
pub fn remove(&mut self) -> V {
self.shard.idxs[self.shard_pos] = BucketIdx::INVALID;
self.shard.keys[self.shard_pos].tag = EntryType::Tombstone;
self.shard.keys[self.shard_pos].tag = EntryTag::Tombstone;
self.bucket_arr.dealloc_bucket(self.bucket_pos)
}
}
@@ -54,23 +66,28 @@ pub struct VacantEntry<'a, K, V> {
pub(crate) _key: K,
/// Mutable reference to the shard of the map the entry is in.
pub(crate) shard: RwLockWriteGuard<'a, DictShard<'a, K>>,
/// The position of the entry in the map.
/// The position of the entry in the shard.
pub(crate) shard_pos: usize,
/// True logical position of the entry in the map.
pub(crate) key_pos: usize,
/// Mutable reference to the bucket array containing entry.
pub(crate) bucket_arr: &'a mut BucketArray<'a, V>,
pub(crate) bucket_arr: &'a BucketArray<'a, V>,
}
impl<'a, K: Clone + Hash + Eq, V> VacantEntry<'a, K, V> {
/// Insert a value into the vacant entry, finding and populating an empty bucket in the process.
pub fn insert(mut self, value: V) -> ValueWriteGuard<'a, V> {
let pos = self.bucket_arr.alloc_bucket(value).expect("bucket is available if entry is");
self.shard.keys[self.shard_pos].tag = EntryType::Occupied;
let pos = self.bucket_arr.alloc_bucket(value, self.key_pos)
.expect("bucket is available if entry is");
self.shard.keys[self.shard_pos].tag = EntryTag::Occupied;
self.shard.keys[self.shard_pos].val.write(self._key);
let idx = pos.pos_checked().expect("position is valid");
let idx = pos.next_checked().expect("position is valid");
self.shard.idxs[self.shard_pos] = pos;
RwLockWriteGuard::map(self.shard, |_| {
self.bucket_arr.buckets[idx].as_mut()
self.bucket_arr.get_mut(idx).as_mut()
})
}
}