mirror of
https://github.com/neondatabase/neon.git
synced 2025-12-23 06:09:59 +00:00
Fix concurrency bugs in resizing (WIP)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
use std::cell::UnsafeCell;
|
||||
use std::hash::{BuildHasher, Hash};
|
||||
use std::mem::MaybeUninit;
|
||||
use std::ptr::NonNull;
|
||||
@@ -14,7 +15,7 @@ pub mod entry;
|
||||
mod tests;
|
||||
|
||||
use core::{
|
||||
CoreHashMap, DictShard, EntryKey, EntryType,
|
||||
CoreHashMap, DictShard, EntryKey, EntryTag,
|
||||
FullError, MaybeUninitDictShard
|
||||
};
|
||||
use bucket::{Bucket, BucketIdx};
|
||||
@@ -134,11 +135,13 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
|
||||
}
|
||||
let shards: &mut [RwLock<MaybeUninitDictShard<'_, K>>] =
|
||||
unsafe { std::slice::from_raw_parts_mut(shards_ptr.cast(), num_shards) };
|
||||
let buckets =
|
||||
unsafe { std::slice::from_raw_parts_mut(vals_ptr.cast(), num_buckets) };
|
||||
let buckets: *const [MaybeUninit<Bucket<V>>] =
|
||||
unsafe { std::slice::from_raw_parts(vals_ptr.cast(), num_buckets) };
|
||||
|
||||
let hashmap = CoreHashMap::new(buckets, shards);
|
||||
unsafe { std::ptr::write(shared_ptr, hashmap); }
|
||||
unsafe {
|
||||
let hashmap = CoreHashMap::new(&*(buckets as *const UnsafeCell<_>), shards);
|
||||
std::ptr::write(shared_ptr, hashmap);
|
||||
}
|
||||
|
||||
let resize_lock = Mutex::from_raw(
|
||||
unsafe { PthreadMutex::new(NonNull::new_unchecked(mutex_ptr)) }, ()
|
||||
@@ -313,18 +316,38 @@ where
|
||||
|
||||
pub unsafe fn get_at_bucket(&self, pos: usize) -> Option<&V> {
|
||||
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
|
||||
if pos >= map.bucket_arr.buckets.len() {
|
||||
if pos >= map.bucket_arr.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let bucket = &map.bucket_arr.buckets[pos];
|
||||
if bucket.next.load(Ordering::Relaxed) == BucketIdx::RESERVED {
|
||||
let bucket = &map.bucket_arr[pos];
|
||||
if bucket.next.load(Ordering::Relaxed).full_checked().is_some() {
|
||||
Some(unsafe { bucket.val.assume_init_ref() })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn entry_at_bucket(&self, pos: usize) -> Option<entry::OccupiedEntry<'a, K, V>> {
|
||||
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
|
||||
if pos >= map.bucket_arr.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let bucket = &map.bucket_arr[pos];
|
||||
bucket.next.load(Ordering::Relaxed).full_checked().map(|entry_pos| {
|
||||
let shard_size = map.get_num_buckets() / map.dict_shards.len();
|
||||
let shard_index = entry_pos / shard_size;
|
||||
let shard_off = entry_pos % shard_size;
|
||||
entry::OccupiedEntry {
|
||||
shard: map.dict_shards[shard_index].write(),
|
||||
shard_pos: shard_off,
|
||||
bucket_pos: pos,
|
||||
bucket_arr: &map.bucket_arr,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the number of buckets in the table.
|
||||
pub fn get_num_buckets(&self) -> usize {
|
||||
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
|
||||
@@ -335,9 +358,9 @@ where
|
||||
pub fn get_bucket_for_value(&self, val_ptr: *const V) -> usize {
|
||||
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
|
||||
|
||||
let origin = map.bucket_arr.buckets.as_ptr();
|
||||
let origin = map.bucket_arr.as_mut_ptr() as *const _;
|
||||
let idx = (val_ptr as usize - origin as usize) / size_of::<Bucket<V>>();
|
||||
assert!(idx < map.bucket_arr.buckets.len());
|
||||
assert!(idx < map.bucket_arr.len());
|
||||
|
||||
idx
|
||||
}
|
||||
@@ -368,8 +391,8 @@ where
|
||||
|
||||
shards.iter_mut().for_each(|x| x.keys.iter_mut().for_each(|key| {
|
||||
match key.tag {
|
||||
EntryType::Occupied => key.tag = EntryType::Rehash,
|
||||
EntryType::Tombstone => key.tag = EntryType::RehashTombstone,
|
||||
EntryTag::Occupied => key.tag = EntryTag::Rehash,
|
||||
EntryTag::Tombstone => key.tag = EntryTag::RehashTombstone,
|
||||
_ => (),
|
||||
}
|
||||
}));
|
||||
@@ -379,9 +402,66 @@ where
|
||||
true
|
||||
}
|
||||
|
||||
|
||||
// TODO(quantumish): off by one for return value logic?
|
||||
fn do_rehash(&self) -> bool {
|
||||
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
|
||||
// TODO(quantumish): refactor these out into settable quantities
|
||||
const REHASH_CHUNK_SIZE: usize = 10;
|
||||
const REHASH_ATTEMPTS: usize = 5;
|
||||
|
||||
let end = map.rehash_end.load(Ordering::Relaxed);
|
||||
let ind = map.rehash_index.load(Ordering::Relaxed);
|
||||
if ind >= end { return true }
|
||||
|
||||
let _guard = self.resize_lock.try_lock();
|
||||
if _guard.is_none() { return false }
|
||||
|
||||
map.rehash_index.store((ind+REHASH_CHUNK_SIZE).min(end), Ordering::Relaxed);
|
||||
|
||||
let shard_size = map.get_num_buckets() / map.dict_shards.len();
|
||||
for i in ind..(ind+REHASH_CHUNK_SIZE).min(end) {
|
||||
let (shard_index, shard_off) = (i / shard_size, i % shard_size);
|
||||
let mut shard = map.dict_shards[shard_index].write();
|
||||
if shard.keys[shard_off].tag != EntryTag::Rehash {
|
||||
continue;
|
||||
}
|
||||
loop {
|
||||
let hash = self.get_hash_value(unsafe {
|
||||
shard.keys[shard_off].val.assume_init_ref()
|
||||
});
|
||||
|
||||
let key = unsafe { shard.keys[shard_off].val.assume_init_ref() }.clone();
|
||||
let new = map.entry(key, hash, |tag| match tag {
|
||||
EntryTag::Empty => core::MapEntryType::Empty,
|
||||
EntryTag::Occupied => core::MapEntryType::Occupied,
|
||||
EntryTag::Tombstone => core::MapEntryType::Skip,
|
||||
_ => core::MapEntryType::Tombstone,
|
||||
}).unwrap();
|
||||
let new_pos = new.pos();
|
||||
|
||||
match new.tag() {
|
||||
EntryTag::Empty | EntryTag::RehashTombstone => {
|
||||
shard.keys[shard_off].tag = EntryTag::Empty;
|
||||
unsafe {
|
||||
std::mem::swap(
|
||||
shard.keys[shard_off].val.assume_init_mut(),
|
||||
new.
|
||||
},
|
||||
EntryTag::Rehash => {
|
||||
|
||||
},
|
||||
_ => unreachable!()
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
|
||||
pub fn finish_rehash(&self) {
|
||||
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
|
||||
while map.do_rehash() {}
|
||||
while self.do_rehash() {}
|
||||
}
|
||||
|
||||
pub fn shuffle(&self) {
|
||||
@@ -422,7 +502,7 @@ where
|
||||
let _resize_guard = self.resize_lock.lock();
|
||||
let mut shards: Vec<_> = map.dict_shards.iter().map(|x| x.write()).collect();
|
||||
|
||||
let old_num_buckets = map.bucket_arr.buckets.len();
|
||||
let old_num_buckets = map.bucket_arr.len();
|
||||
assert!(
|
||||
num_buckets >= old_num_buckets,
|
||||
"grow called with a smaller number of buckets"
|
||||
@@ -434,7 +514,7 @@ where
|
||||
// Grow memory areas and initialize each of them.
|
||||
self.resize_shmem(num_buckets)?;
|
||||
unsafe {
|
||||
let buckets_ptr = map.bucket_arr.buckets.as_mut_ptr();
|
||||
let buckets_ptr = map.bucket_arr.as_mut_ptr();
|
||||
for i in old_num_buckets..num_buckets {
|
||||
let bucket = buckets_ptr.add(i);
|
||||
bucket.write(Bucket::empty(
|
||||
@@ -452,7 +532,7 @@ where
|
||||
for i in old_num_buckets..num_buckets {
|
||||
let key = keys_ptr.add(i);
|
||||
key.write(EntryKey {
|
||||
tag: EntryType::Empty,
|
||||
tag: EntryTag::Empty,
|
||||
val: MaybeUninit::uninit(),
|
||||
});
|
||||
}
|
||||
@@ -492,7 +572,7 @@ where
|
||||
pub fn shrink_goal(&self) -> Option<usize> {
|
||||
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
|
||||
let goal = map.bucket_arr.alloc_limit.load(Ordering::Relaxed);
|
||||
goal.pos_checked()
|
||||
goal.next_checkeddd()
|
||||
}
|
||||
|
||||
pub fn finish_shrink(&self) -> Result<(), shmem::Error> {
|
||||
@@ -502,7 +582,7 @@ where
|
||||
|
||||
let num_buckets = map.bucket_arr.alloc_limit
|
||||
.load(Ordering::Relaxed)
|
||||
.pos_checked()
|
||||
.next_checkeddd()
|
||||
.expect("called finish_shrink when no shrink is in progress");
|
||||
|
||||
if map.get_num_buckets() == num_buckets {
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
|
||||
use std::{mem::MaybeUninit, sync::atomic::{AtomicUsize, Ordering}};
|
||||
use std::cell::UnsafeCell;
|
||||
use std::mem::MaybeUninit;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
use atomic::Atomic;
|
||||
|
||||
@@ -10,34 +12,55 @@ pub(crate) struct BucketIdx(pub(super) u32);
|
||||
const _: () = assert!(Atomic::<BucketIdx>::is_lock_free());
|
||||
|
||||
impl BucketIdx {
|
||||
const MARK_TAG: u32 = 0x80000000;
|
||||
pub const INVALID: Self = Self(0x7FFFFFFF);
|
||||
pub const RESERVED: Self = Self(0x7FFFFFFE);
|
||||
pub const MAX: usize = Self::RESERVED.0 as usize - 1;
|
||||
/// Tag for next pointers in free entries.
|
||||
const NEXT_TAG: u32 = 0b00 << 30;
|
||||
/// Tag for marked next pointers in free entries.
|
||||
const MARK_TAG: u32 = 0b01 << 30;
|
||||
/// Tag for full entries.
|
||||
const FULL_TAG: u32 = 0b10 << 30;
|
||||
/// Reserved. Don't use me.
|
||||
const RSVD_TAG: u32 = 0b11 << 30;
|
||||
|
||||
pub const INVALID: Self = Self(0x3FFFFFFF);
|
||||
pub const MAX: usize = Self::INVALID.0 as usize - 1;
|
||||
|
||||
pub(super) fn is_marked(&self) -> bool {
|
||||
self.0 & Self::MARK_TAG != 0
|
||||
self.0 & Self::RSVD_TAG == Self::MARK_TAG
|
||||
}
|
||||
|
||||
pub(super) fn as_marked(self) -> Self {
|
||||
Self(self.0 | Self::MARK_TAG)
|
||||
Self((self.0 & Self::INVALID.0) | Self::MARK_TAG)
|
||||
}
|
||||
|
||||
pub(super) fn get_unmarked(self) -> Self {
|
||||
Self(self.0 & !Self::MARK_TAG)
|
||||
Self(self.0 & Self::INVALID.0)
|
||||
}
|
||||
|
||||
pub fn new(val: usize) -> Self {
|
||||
debug_assert!(val < Self::MAX);
|
||||
Self(val as u32)
|
||||
}
|
||||
|
||||
pub fn new_full(val: usize) -> Self {
|
||||
debug_assert!(val < Self::MAX);
|
||||
Self(val as u32 | Self::FULL_TAG)
|
||||
}
|
||||
|
||||
pub fn pos_checked(&self) -> Option<usize> {
|
||||
pub fn next_checked(&self) -> Option<usize> {
|
||||
if *self == Self::INVALID || self.is_marked() {
|
||||
None
|
||||
} else {
|
||||
Some(self.0 as usize)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn full_checked(&self) -> Option<usize> {
|
||||
if self.0 & Self::RSVD_TAG == Self::FULL_TAG {
|
||||
Some((self.0 & Self::INVALID.0) as usize)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for BucketIdx {
|
||||
@@ -48,7 +71,6 @@ impl std::fmt::Debug for BucketIdx {
|
||||
self.is_marked(),
|
||||
match *self {
|
||||
Self::INVALID => "INVALID".to_string(),
|
||||
Self::RESERVED => "RESERVED".to_string(),
|
||||
_ => format!("{idx}")
|
||||
}
|
||||
)
|
||||
@@ -76,8 +98,6 @@ impl<V> Bucket<V> {
|
||||
next: Atomic::new(BucketIdx::INVALID)
|
||||
}
|
||||
}
|
||||
|
||||
// pub is_full
|
||||
|
||||
pub fn as_ref(&self) -> &V {
|
||||
unsafe { self.val.assume_init_ref() }
|
||||
@@ -94,7 +114,7 @@ impl<V> Bucket<V> {
|
||||
|
||||
pub(crate) struct BucketArray<'a, V> {
|
||||
/// Buckets containing values.
|
||||
pub(crate) buckets: &'a mut [Bucket<V>],
|
||||
pub(crate) buckets: &'a UnsafeCell<[Bucket<V>]>,
|
||||
/// Head of the freelist.
|
||||
pub(crate) free_head: Atomic<BucketIdx>,
|
||||
/// Maximum index of a bucket allowed to be allocated.
|
||||
@@ -105,8 +125,24 @@ pub(crate) struct BucketArray<'a, V> {
|
||||
pub(crate) _user_list_head: Atomic<BucketIdx>,
|
||||
}
|
||||
|
||||
impl <'a, V> std::ops::Index<usize> for BucketArray<'a, V> {
|
||||
type Output = Bucket<V>;
|
||||
|
||||
fn index(&self, index: usize) -> &Self::Output {
|
||||
let buckets: &[_] = unsafe { &*(self.buckets.get() as *mut _) };
|
||||
&buckets[index]
|
||||
}
|
||||
}
|
||||
|
||||
impl <'a, V> std::ops::IndexMut<usize> for BucketArray<'a, V> {
|
||||
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
|
||||
let buckets: &mut [_] = unsafe { &mut *(self.buckets.get() as *mut _) };
|
||||
&mut buckets[index]
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, V> BucketArray<'a, V> {
|
||||
pub fn new(buckets: &'a mut [Bucket<V>]) -> Self {
|
||||
pub fn new(buckets: &'a UnsafeCell<[Bucket<V>]>) -> Self {
|
||||
Self {
|
||||
buckets,
|
||||
free_head: Atomic::new(BucketIdx(0)),
|
||||
@@ -115,18 +151,29 @@ impl<'a, V> BucketArray<'a, V> {
|
||||
buckets_in_use: 0.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a raw pointer to the first bucket of the array.
///
/// The pointer is derived by casting the `UnsafeCell`'s slice pointer, so no
/// intermediate `&mut [Bucket<V>]` is created from `&self` just to obtain an
/// address.
pub fn as_mut_ptr(&self) -> *mut Bucket<V> {
    self.buckets.get() as *mut Bucket<V>
}
|
||||
|
||||
/// Returns a mutable reference to bucket `index` through the `UnsafeCell`.
///
/// Panics if `index` is out of bounds.
///
/// NOTE(review): this hands out `&mut` from `&self`; callers are presumably
/// expected to guarantee exclusive access to the bucket -- confirm.
pub fn get_mut(&self, index: usize) -> &mut Bucket<V> {
    let slice_ptr: *mut [Bucket<V>] = self.buckets.get();
    let slice = unsafe { &mut *slice_ptr };
    &mut slice[index]
}
|
||||
|
||||
pub fn dealloc_bucket(&mut self, pos: usize) -> V {
|
||||
let bucket = &mut self.buckets[pos];
|
||||
let pos = BucketIdx::new(pos);
|
||||
pub fn len(&self) -> usize {
|
||||
unsafe { (&*self.buckets.get()).len() }
|
||||
}
|
||||
|
||||
pub fn dealloc_bucket(&self, pos: usize) -> V {
|
||||
loop {
|
||||
let free = self.free_head.load(Ordering::Relaxed);
|
||||
bucket.next.store(free, Ordering::Relaxed);
|
||||
self[pos].next.store(free, Ordering::Relaxed);
|
||||
if self.free_head.compare_exchange_weak(
|
||||
free, pos, Ordering::Relaxed, Ordering::Relaxed
|
||||
free, BucketIdx::new(pos), Ordering::Relaxed, Ordering::Relaxed
|
||||
).is_ok() {
|
||||
self.buckets_in_use.fetch_sub(1, Ordering::Relaxed);
|
||||
return unsafe { bucket.val.assume_init_read() };
|
||||
return unsafe { self[pos].val.assume_init_read() };
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -140,8 +187,8 @@ impl<'a, V> BucketArray<'a, V> {
|
||||
loop {
|
||||
let mut t = BucketIdx::INVALID;
|
||||
let mut t_next = self.free_head.load(Ordering::Relaxed);
|
||||
let alloc_limit = self.alloc_limit.load(Ordering::Relaxed).pos_checked();
|
||||
while t_next.is_marked() || t.pos_checked()
|
||||
let alloc_limit = self.alloc_limit.load(Ordering::Relaxed).next_checked();
|
||||
while t_next.is_marked() || t.next_checked()
|
||||
.map_or(true, |v| alloc_limit.map_or(false, |l| v > l))
|
||||
{
|
||||
if !t_next.is_marked() {
|
||||
@@ -150,12 +197,12 @@ impl<'a, V> BucketArray<'a, V> {
|
||||
}
|
||||
t = t_next.get_unmarked();
|
||||
if t == BucketIdx::INVALID { break }
|
||||
t_next = self.buckets[t.0 as usize].next.load(Ordering::Relaxed);
|
||||
t_next = self[t.0 as usize].next.load(Ordering::Relaxed);
|
||||
}
|
||||
right_node = t;
|
||||
|
||||
if left_node_next == right_node {
|
||||
if right_node != BucketIdx::INVALID && self.buckets[right_node.0 as usize]
|
||||
if right_node != BucketIdx::INVALID && self[right_node.0 as usize]
|
||||
.next.load(Ordering::Relaxed).is_marked()
|
||||
{
|
||||
continue;
|
||||
@@ -165,13 +212,13 @@ impl<'a, V> BucketArray<'a, V> {
|
||||
}
|
||||
|
||||
let left_ref = if left_node != BucketIdx::INVALID {
|
||||
&self.buckets[left_node.0 as usize].next
|
||||
&self[left_node.0 as usize].next
|
||||
} else { &self.free_head };
|
||||
|
||||
if left_ref.compare_exchange_weak(
|
||||
left_node_next, right_node, Ordering::Relaxed, Ordering::Relaxed
|
||||
).is_ok() {
|
||||
if right_node != BucketIdx::INVALID && self.buckets[right_node.0 as usize]
|
||||
if right_node != BucketIdx::INVALID && self[right_node.0 as usize]
|
||||
.next.load(Ordering::Relaxed).is_marked()
|
||||
{
|
||||
continue;
|
||||
@@ -183,7 +230,7 @@ impl<'a, V> BucketArray<'a, V> {
|
||||
}
|
||||
|
||||
#[allow(unused_assignments)]
|
||||
pub(crate) fn alloc_bucket(&mut self, value: V) -> Option<BucketIdx> {
|
||||
pub(crate) fn alloc_bucket(&self, value: V, key_pos: usize) -> Option<BucketIdx> {
|
||||
// println!("alloc()");
|
||||
let mut right_node_next = BucketIdx::INVALID;
|
||||
let mut left_idx = BucketIdx::INVALID;
|
||||
@@ -195,7 +242,7 @@ impl<'a, V> BucketArray<'a, V> {
|
||||
return None;
|
||||
}
|
||||
|
||||
let right = &self.buckets[right_idx.0 as usize];
|
||||
let right = &self[right_idx.0 as usize];
|
||||
right_node_next = right.next.load(Ordering::Relaxed);
|
||||
if !right_node_next.is_marked() {
|
||||
if right.next.compare_exchange_weak(
|
||||
@@ -208,7 +255,7 @@ impl<'a, V> BucketArray<'a, V> {
|
||||
}
|
||||
|
||||
let left_ref = if left_idx != BucketIdx::INVALID {
|
||||
&self.buckets[left_idx.0 as usize].next
|
||||
&self[left_idx.0 as usize].next
|
||||
} else {
|
||||
&self.free_head
|
||||
};
|
||||
@@ -221,17 +268,17 @@ impl<'a, V> BucketArray<'a, V> {
|
||||
}
|
||||
|
||||
self.buckets_in_use.fetch_add(1, Ordering::Relaxed);
|
||||
self.buckets[right_idx.0 as usize].val.write(value);
|
||||
self.buckets[right_idx.0 as usize].next.store(
|
||||
BucketIdx::RESERVED, Ordering::Relaxed
|
||||
self[right_idx.0 as usize].next.store(
|
||||
BucketIdx::new_full(key_pos), Ordering::Relaxed
|
||||
);
|
||||
self.get_mut(right_idx.0 as usize).val.write(value);
|
||||
Some(right_idx)
|
||||
}
|
||||
|
||||
pub fn clear(&mut self) {
|
||||
for i in 0..self.buckets.len() {
|
||||
self.buckets[i] = Bucket::empty(
|
||||
if i < self.buckets.len() - 1 {
|
||||
for i in 0..self.len() {
|
||||
self[i] = Bucket::empty(
|
||||
if i < self.len() - 1 {
|
||||
BucketIdx::new(i + 1)
|
||||
} else {
|
||||
BucketIdx::INVALID
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
//! Simple hash table with chaining.
|
||||
|
||||
use std::cell::UnsafeCell;
|
||||
use std::hash::Hash;
|
||||
use std::mem::MaybeUninit;
|
||||
use std::sync::atomic::{Ordering, AtomicUsize};
|
||||
@@ -11,7 +12,7 @@ use crate::hash::{
|
||||
};
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||
pub(crate) enum EntryType {
|
||||
pub(crate) enum EntryTag {
|
||||
Occupied,
|
||||
Rehash,
|
||||
Tombstone,
|
||||
@@ -19,8 +20,15 @@ pub(crate) enum EntryType {
|
||||
Empty,
|
||||
}
|
||||
|
||||
/// How a probe should treat a key slot, as decided by the tag-mapping closure
/// passed to the lookup routines.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub(crate) enum MapEntryType {
    /// The slot holds a live key that should be compared against the probe key.
    Occupied,
    /// The slot is passed over by lookups and may be reused for an insertion.
    Tombstone,
    /// The slot is empty; the probe sequence ends here.
    Empty,
    /// The slot is passed over and is not considered for insertion.
    Skip,
}
|
||||
|
||||
pub(crate) struct EntryKey<K> {
|
||||
pub(crate) tag: EntryType,
|
||||
pub(crate) tag: EntryTag,
|
||||
pub(crate) val: MaybeUninit<K>,
|
||||
}
|
||||
|
||||
@@ -55,9 +63,10 @@ pub struct FullError();
|
||||
|
||||
impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
|
||||
pub fn new(
|
||||
buckets: &'a mut [MaybeUninit<Bucket<V>>],
|
||||
buckets_cell: &'a UnsafeCell<[MaybeUninit<Bucket<V>>]>,
|
||||
dict_shards: &'a mut [RwLock<MaybeUninitDictShard<'a, K>>],
|
||||
) -> Self {
|
||||
let buckets = unsafe { &mut *buckets_cell.get() };
|
||||
// Initialize the buckets
|
||||
for i in 0..buckets.len() {
|
||||
buckets[i].write(Bucket::empty(
|
||||
@@ -74,7 +83,7 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
|
||||
let mut dicts = shard.write();
|
||||
for e in dicts.keys.iter_mut() {
|
||||
e.write(EntryKey {
|
||||
tag: EntryType::Empty,
|
||||
tag: EntryTag::Empty,
|
||||
val: MaybeUninit::uninit(),
|
||||
});
|
||||
}
|
||||
@@ -83,10 +92,10 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
|
||||
}
|
||||
}
|
||||
|
||||
let buckets_cell = unsafe {
|
||||
&*(buckets_cell as *const _ as *const UnsafeCell<_>)
|
||||
};
|
||||
// TODO: use std::slice::assume_init_mut() once it stabilizes
|
||||
let buckets =
|
||||
unsafe { std::slice::from_raw_parts_mut(buckets.as_mut_ptr().cast(),
|
||||
buckets.len()) };
|
||||
let dict_shards = unsafe {
|
||||
std::slice::from_raw_parts_mut(dict_shards.as_mut_ptr().cast(),
|
||||
dict_shards.len())
|
||||
@@ -96,73 +105,64 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
|
||||
dict_shards,
|
||||
rehash_index: buckets.len().into(),
|
||||
rehash_end: buckets.len().into(),
|
||||
bucket_arr: BucketArray::new(buckets),
|
||||
bucket_arr: BucketArray::new(buckets_cell),
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(quantumish): off by one for return value logic?
|
||||
pub fn do_rehash(&mut self) -> bool {
|
||||
// TODO(quantumish): refactor these out into settable quantities
|
||||
const REHASH_CHUNK_SIZE: usize = 10;
|
||||
const REHASH_ATTEMPTS: usize = 5;
|
||||
|
||||
let end = self.rehash_end.load(Ordering::Relaxed);
|
||||
let mut ind = self.rehash_index.load(Ordering::Relaxed);
|
||||
let mut i = 0;
|
||||
loop {
|
||||
if ind >= end {
|
||||
// TODO(quantumish) questionable?
|
||||
self.rehash_index.store(end, Ordering::Relaxed);
|
||||
return true;
|
||||
}
|
||||
if i > REHASH_ATTEMPTS {
|
||||
break;
|
||||
}
|
||||
match self.rehash_index.compare_exchange_weak(
|
||||
ind, ind + REHASH_CHUNK_SIZE,
|
||||
Ordering::Relaxed, Ordering::Relaxed
|
||||
) {
|
||||
Err(new_ind) => ind = new_ind,
|
||||
Ok(_) => break,
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
|
||||
todo!("actual rehashing");
|
||||
false
|
||||
}
|
||||
|
||||
|
||||
pub fn get_with_hash(&'a self, key: &K, hash: u64) -> Option<ValueReadGuard<'a, V>> {
|
||||
let ind = self.rehash_index.load(Ordering::Relaxed);
|
||||
let end = self.rehash_end.load(Ordering::Relaxed);
|
||||
|
||||
let first = ind >= end || ind < end/2;
|
||||
if let Some(res) = self.get(key, hash, first) {
|
||||
return Some(res);
|
||||
let res = self.get(key, hash, |tag| match tag {
|
||||
EntryTag::Empty => MapEntryType::Empty,
|
||||
EntryTag::Occupied => MapEntryType::Occupied,
|
||||
_ => MapEntryType::Tombstone,
|
||||
});
|
||||
if res.is_some() {
|
||||
return res;
|
||||
}
|
||||
if ind < end && let Some(res) = self.get(key, hash, !first) {
|
||||
return Some(res);
|
||||
|
||||
if ind < end {
|
||||
self.get(key, hash, |tag| match tag {
|
||||
EntryTag::Empty => MapEntryType::Empty,
|
||||
EntryTag::Rehash => MapEntryType::Occupied,
|
||||
_ => MapEntryType::Tombstone,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
|
||||
pub fn entry_with_hash(&'a mut self, key: K, hash: u64) -> Result<Entry<'a, K, V>, FullError> {
|
||||
let ind = self.rehash_index.load(Ordering::Relaxed);
|
||||
let end = self.rehash_end.load(Ordering::Relaxed);
|
||||
|
||||
let res = self.entry(key.clone(), hash, |tag| match tag {
|
||||
EntryTag::Empty => MapEntryType::Empty,
|
||||
EntryTag::Occupied => MapEntryType::Occupied,
|
||||
EntryTag::Rehash => MapEntryType::Skip,
|
||||
_ => MapEntryType::Tombstone,
|
||||
});
|
||||
if ind < end {
|
||||
if let Ok(Entry::Occupied(res)) = self.entry(key.clone(), hash, true) {
|
||||
return Ok(Entry::Occupied(res));
|
||||
if let Ok(Entry::Occupied(_)) = res {
|
||||
res
|
||||
} else {
|
||||
return self.entry(key, hash, false);
|
||||
self.entry(key, hash, |tag| match tag {
|
||||
EntryTag::Empty => MapEntryType::Empty,
|
||||
EntryTag::Occupied => MapEntryType::Skip,
|
||||
EntryTag::Rehash => MapEntryType::Occupied,
|
||||
_ => MapEntryType::Tombstone
|
||||
})
|
||||
}
|
||||
} else {
|
||||
return self.entry(key.clone(), hash, true);
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the value associated with a key (if it exists) given its hash.
|
||||
fn get(&'a self, key: &K, hash: u64, ignore_remap: bool) -> Option<ValueReadGuard<'a, V>> {
|
||||
fn get<F>(&'a self, key: &K, hash: u64, f: F) -> Option<ValueReadGuard<'a, V>>
|
||||
where F: Fn(EntryTag) -> MapEntryType
|
||||
{
|
||||
let num_buckets = self.get_num_buckets();
|
||||
let shard_size = num_buckets / self.dict_shards.len();
|
||||
let bucket_pos = hash as usize % num_buckets;
|
||||
@@ -172,19 +172,17 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
|
||||
let shard = self.dict_shards[shard_idx].read();
|
||||
let entry_start = if off == 0 { bucket_pos % shard_size } else { 0 };
|
||||
for entry_idx in entry_start..shard.len() {
|
||||
match shard.keys[entry_idx].tag {
|
||||
EntryType::Empty => return None,
|
||||
EntryType::Tombstone | EntryType::RehashTombstone => continue,
|
||||
t @ (EntryType::Occupied | EntryType::Rehash) => {
|
||||
if (t == EntryType::Occupied && ignore_remap) || (t == EntryType::Rehash && !ignore_remap) {
|
||||
let cand_key = unsafe { shard.keys[entry_idx].val.assume_init_ref() };
|
||||
if cand_key == key {
|
||||
let bucket_idx = shard.idxs[entry_idx].pos_checked()
|
||||
.expect("position is valid");
|
||||
return Some(RwLockReadGuard::map(
|
||||
shard, |_| self.bucket_arr.buckets[bucket_idx].as_ref()
|
||||
));
|
||||
}
|
||||
match f(shard.keys[entry_idx].tag) {
|
||||
MapEntryType::Empty => return None,
|
||||
MapEntryType::Tombstone | MapEntryType::Skip => continue,
|
||||
MapEntryType::Occupied => {
|
||||
let cand_key = unsafe { shard.keys[entry_idx].val.assume_init_ref() };
|
||||
if cand_key == key {
|
||||
let bucket_idx = shard.idxs[entry_idx].next_checked()
|
||||
.expect("position is valid");
|
||||
return Some(RwLockReadGuard::map(
|
||||
shard, |_| self.bucket_arr[bucket_idx].as_ref()
|
||||
));
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -193,7 +191,9 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
|
||||
None
|
||||
}
|
||||
|
||||
fn entry(&'a mut self, key: K, hash: u64, ignore_remap: bool) -> Result<Entry<'a, K, V>, FullError> {
|
||||
pub fn entry<F>(&'a self, key: K, hash: u64, f: F) -> Result<Entry<'a, K, V>, FullError>
|
||||
where F: Fn(EntryTag) -> MapEntryType
|
||||
{
|
||||
// We need to keep holding on the locks for each shard we process since if we don't find the
|
||||
// key anywhere, we want to insert it at the earliest possible position (which may be several
|
||||
// shards away). Ideally cross-shard chains are quite rare, so this shouldn't be a big deal.
|
||||
@@ -211,57 +211,57 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
|
||||
let mut inserted = false;
|
||||
let entry_start = if off == 0 { bucket_pos % shard_size } else { 0 };
|
||||
for entry_idx in entry_start..shard.len() {
|
||||
match shard.keys[entry_idx].tag {
|
||||
EntryType::Empty => {
|
||||
let (shard, shard_pos) = match (insert_shard, insert_pos) {
|
||||
(Some(s), Some(p)) => (s, p),
|
||||
(None, Some(p)) => (shard, p),
|
||||
(None, None) => (shard, entry_idx),
|
||||
match f(shard.keys[entry_idx].tag) {
|
||||
MapEntryType::Skip => continue,
|
||||
MapEntryType::Empty => {
|
||||
let ((shard, idx), shard_pos) = match (insert_shard, insert_pos) {
|
||||
(Some((s, i)), Some(p)) => ((s, i), p),
|
||||
(None, Some(p)) => ((shard, shard_idx), p),
|
||||
(None, None) => ((shard, shard_idx), entry_idx),
|
||||
_ => unreachable!()
|
||||
};
|
||||
return Ok(Entry::Vacant(VacantEntry {
|
||||
_key: key,
|
||||
shard,
|
||||
shard_pos,
|
||||
bucket_arr: &mut self.bucket_arr,
|
||||
key_pos: (shard_size * idx) + shard_pos,
|
||||
bucket_arr: &self.bucket_arr,
|
||||
}))
|
||||
},
|
||||
EntryType::Tombstone | EntryType::RehashTombstone => {
|
||||
MapEntryType::Tombstone => {
|
||||
if insert_pos.is_none() {
|
||||
insert_pos = Some(entry_idx);
|
||||
inserted = true;
|
||||
}
|
||||
},
|
||||
t @ (EntryType::Occupied | EntryType::Rehash) => {
|
||||
if (t == EntryType::Occupied && ignore_remap) || (t == EntryType::Rehash && !ignore_remap) {
|
||||
let cand_key = unsafe { shard.keys[entry_idx].val.assume_init_ref() };
|
||||
if *cand_key == key {
|
||||
let bucket_pos = shard.idxs[entry_idx].pos_checked().unwrap();
|
||||
return Ok(Entry::Occupied(OccupiedEntry {
|
||||
_key: key,
|
||||
shard,
|
||||
shard_pos: entry_idx,
|
||||
bucket_pos,
|
||||
bucket_arr: &mut self.bucket_arr,
|
||||
}));
|
||||
}
|
||||
MapEntryType::Occupied => {
|
||||
let cand_key = unsafe { shard.keys[entry_idx].val.assume_init_ref() };
|
||||
if *cand_key == key {
|
||||
let bucket_pos = shard.idxs[entry_idx].next_checked().unwrap();
|
||||
return Ok(Entry::Occupied(OccupiedEntry {
|
||||
shard,
|
||||
shard_pos: entry_idx,
|
||||
bucket_pos,
|
||||
bucket_arr: &self.bucket_arr,
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if inserted {
|
||||
insert_shard = Some(shard)
|
||||
insert_shard = Some((shard, shard_idx));
|
||||
} else {
|
||||
shards.push(shard);
|
||||
}
|
||||
}
|
||||
|
||||
if let (Some(shard), Some(shard_pos)) = (insert_shard, insert_pos) {
|
||||
if let (Some((shard, idx)), Some(shard_pos)) = (insert_shard, insert_pos) {
|
||||
Ok(Entry::Vacant(VacantEntry {
|
||||
_key: key,
|
||||
shard,
|
||||
shard_pos,
|
||||
bucket_arr: &mut self.bucket_arr,
|
||||
key_pos: (shard_size * idx) + shard_pos,
|
||||
bucket_arr: &self.bucket_arr,
|
||||
}))
|
||||
} else {
|
||||
Err(FullError{})
|
||||
@@ -270,14 +270,14 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
|
||||
|
||||
/// Get number of buckets in map.
|
||||
pub fn get_num_buckets(&self) -> usize {
|
||||
self.bucket_arr.buckets.len()
|
||||
self.bucket_arr.len()
|
||||
}
|
||||
|
||||
pub fn clear(&mut self) {
|
||||
let mut shards: Vec<_> = self.dict_shards.iter().map(|x| x.write()).collect();
|
||||
for shard in shards.iter_mut() {
|
||||
for e in shard.keys.iter_mut() {
|
||||
e.tag = EntryType::Empty;
|
||||
e.tag = EntryTag::Empty;
|
||||
}
|
||||
for e in shard.idxs.iter_mut() {
|
||||
*e = BucketIdx::INVALID;
|
||||
|
||||
@@ -1,49 +1,61 @@
|
||||
//! Equivalent of [`std::collections::hash_map::Entry`] for this hashmap.
|
||||
|
||||
use crate::hash::{
|
||||
core::{DictShard, EntryType},
|
||||
core::{DictShard, EntryTag},
|
||||
bucket::{BucketArray, BucketIdx}
|
||||
};
|
||||
use crate::sync::{RwLockWriteGuard, ValueWriteGuard};
|
||||
|
||||
use std::hash::Hash;
|
||||
|
||||
use super::core::EntryKey;
|
||||
|
||||
/// Result of looking up a key in the map: either the existing entry or a
/// slot into which the key can be inserted.
pub enum Entry<'a, K, V> {
|
||||
/// The key is present; wraps the handle to the existing entry.
Occupied(OccupiedEntry<'a, K, V>),
|
||||
/// The key is absent; wraps the handle used to insert it.
Vacant(VacantEntry<'a, K, V>),
|
||||
}
|
||||
|
||||
// Accessors shared by both entry kinds.
impl<'a, K, V> Entry<'a, K, V> {
|
||||
// NOTE(review): work in progress. The signature promises a
// `(shard guard, position)` pair, but both arms evaluate to the tag of the
// key slot at `shard_pos`. One of the two has to change before this
// compiles; the intended contract cannot be determined from this file.
pub fn loc(&self) -> (RwLockWriteGuard<'a, DictShard<'a, K>>, usize) {
|
||||
match self {
|
||||
Self::Occupied(o) => o.shard.keys[o.shard_pos].tag,
|
||||
Self::Vacant(o) => o.shard.keys[o.shard_pos].tag
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
pub struct OccupiedEntry<'a, K, V> {
|
||||
/// The key of the occupied entry
|
||||
pub(crate) _key: K,
|
||||
/// Mutable reference to the shard of the map the entry is in.
|
||||
pub(crate) shard: RwLockWriteGuard<'a, DictShard<'a, K>>,
|
||||
/// The position of the entry in the map.
|
||||
/// The position of the entry in the shard.
|
||||
pub(crate) shard_pos: usize,
|
||||
/// True logical position of the entry in the map.
|
||||
pub(crate) key_pos: usize,
|
||||
/// Mutable reference to the bucket array containing entry.
|
||||
pub(crate) bucket_arr: &'a mut BucketArray<'a, V>,
|
||||
pub(crate) bucket_arr: &'a BucketArray<'a, V>,
|
||||
/// The position of the bucket in the [`CoreHashMap`] bucket array.
|
||||
pub(crate) bucket_pos: usize,
|
||||
}
|
||||
|
||||
impl<K, V> OccupiedEntry<'_, K, V> {
|
||||
pub fn get(&self) -> &V {
|
||||
self.bucket_arr.buckets[self.bucket_pos].as_ref()
|
||||
self.bucket_arr[self.bucket_pos].as_ref()
|
||||
}
|
||||
|
||||
pub fn get_mut(&mut self) -> &mut V {
|
||||
self.bucket_arr.buckets[self.bucket_pos].as_mut()
|
||||
self.bucket_arr.get_mut(self.bucket_pos).as_mut()
|
||||
}
|
||||
|
||||
/// Inserts a value into the entry, replacing (and returning) the existing value.
|
||||
pub fn insert(&mut self, value: V) -> V {
|
||||
self.bucket_arr.buckets[self.bucket_pos].replace(value)
|
||||
self.bucket_arr.get_mut(self.bucket_pos).replace(value)
|
||||
}
|
||||
|
||||
/// Removes the entry from the hash map, returning the value originally stored within it.
|
||||
pub fn remove(&mut self) -> V {
|
||||
self.shard.idxs[self.shard_pos] = BucketIdx::INVALID;
|
||||
self.shard.keys[self.shard_pos].tag = EntryType::Tombstone;
|
||||
self.shard.keys[self.shard_pos].tag = EntryTag::Tombstone;
|
||||
self.bucket_arr.dealloc_bucket(self.bucket_pos)
|
||||
}
|
||||
}
|
||||
@@ -54,23 +66,28 @@ pub struct VacantEntry<'a, K, V> {
|
||||
pub(crate) _key: K,
|
||||
/// Mutable reference to the shard of the map the entry is in.
|
||||
pub(crate) shard: RwLockWriteGuard<'a, DictShard<'a, K>>,
|
||||
/// The position of the entry in the map.
|
||||
/// The position of the entry in the shard.
|
||||
pub(crate) shard_pos: usize,
|
||||
/// True logical position of the entry in the map.
|
||||
pub(crate) key_pos: usize,
|
||||
/// Mutable reference to the bucket array containing entry.
|
||||
pub(crate) bucket_arr: &'a mut BucketArray<'a, V>,
|
||||
pub(crate) bucket_arr: &'a BucketArray<'a, V>,
|
||||
}
|
||||
|
||||
impl<'a, K: Clone + Hash + Eq, V> VacantEntry<'a, K, V> {
|
||||
/// Insert a value into the vacant entry, finding and populating an empty bucket in the process.
|
||||
pub fn insert(mut self, value: V) -> ValueWriteGuard<'a, V> {
|
||||
let pos = self.bucket_arr.alloc_bucket(value).expect("bucket is available if entry is");
|
||||
self.shard.keys[self.shard_pos].tag = EntryType::Occupied;
|
||||
let pos = self.bucket_arr.alloc_bucket(value, self.key_pos)
|
||||
.expect("bucket is available if entry is");
|
||||
self.shard.keys[self.shard_pos].tag = EntryTag::Occupied;
|
||||
self.shard.keys[self.shard_pos].val.write(self._key);
|
||||
let idx = pos.pos_checked().expect("position is valid");
|
||||
let idx = pos.next_checkeddd().expect("position is valid");
|
||||
self.shard.idxs[self.shard_pos] = pos;
|
||||
|
||||
RwLockWriteGuard::map(self.shard, |_| {
|
||||
self.bucket_arr.buckets[idx].as_mut()
|
||||
self.bucket_arr.get_mut(idx).as_mut()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user