diff --git a/libs/neon-shmem/src/hash.rs b/libs/neon-shmem/src/hash.rs
index ef9533d542..269c5ea9c3 100644
--- a/libs/neon-shmem/src/hash.rs
+++ b/libs/neon-shmem/src/hash.rs
@@ -1,3 +1,4 @@
+use std::cell::UnsafeCell;
 use std::hash::{BuildHasher, Hash};
 use std::mem::MaybeUninit;
 use std::ptr::NonNull;
@@ -14,7 +15,7 @@ pub mod entry;
 mod tests;
 
 use core::{
-    CoreHashMap, DictShard, EntryKey, EntryType,
+    CoreHashMap, DictShard, EntryKey, EntryTag,
     FullError, MaybeUninitDictShard
 };
 use bucket::{Bucket, BucketIdx};
@@ -134,11 +135,13 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
         }
         let shards: &mut [RwLock<MaybeUninitDictShard<K>>] =
             unsafe { std::slice::from_raw_parts_mut(shards_ptr.cast(), num_shards) };
-        let buckets =
-            unsafe { std::slice::from_raw_parts_mut(vals_ptr.cast(), num_buckets) };
+        let buckets: *const [MaybeUninit<Bucket<V>>] =
+            unsafe { std::slice::from_raw_parts(vals_ptr.cast(), num_buckets) };
 
-        let hashmap = CoreHashMap::new(buckets, shards);
-        unsafe { std::ptr::write(shared_ptr, hashmap); }
+        unsafe {
+            let hashmap = CoreHashMap::new(&*(buckets as *const UnsafeCell<_>), shards);
+            std::ptr::write(shared_ptr, hashmap);
+        }
         let resize_lock = Mutex::from_raw(
             unsafe { PthreadMutex::new(NonNull::new_unchecked(mutex_ptr)) },
             ()
@@ -313,18 +316,38 @@ where
     pub unsafe fn get_at_bucket(&self, pos: usize) -> Option<&V> {
         let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
-        if pos >= map.bucket_arr.buckets.len() {
+        if pos >= map.bucket_arr.len() {
             return None;
         }
-        let bucket = &map.bucket_arr.buckets[pos];
-        if bucket.next.load(Ordering::Relaxed) == BucketIdx::RESERVED {
+        let bucket = &map.bucket_arr[pos];
+        if bucket.next.load(Ordering::Relaxed).full_checked().is_some() {
            Some(unsafe { bucket.val.assume_init_ref() })
         } else {
            None
         }
     }
 
+    pub unsafe fn entry_at_bucket(&self, pos: usize) -> Option<entry::OccupiedEntry<K, V>> {
+        let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
+        if pos >= map.bucket_arr.len() {
+            return None;
+        }
+
+        let bucket = &map.bucket_arr[pos];
+        bucket.next.load(Ordering::Relaxed).full_checked().map(|entry_pos| {
+            let shard_size = map.get_num_buckets() / map.dict_shards.len();
+            let shard_index = entry_pos / shard_size;
+            let shard_off = entry_pos % shard_size;
+            entry::OccupiedEntry {
+                shard: map.dict_shards[shard_index].write(),
+                shard_pos: shard_off,
+                key_pos: entry_pos,
+                bucket_pos: pos,
+                bucket_arr: &map.bucket_arr,
+            }
+        })
+    }
+
-    /// bucket the number of buckets in the table.
+    /// Returns the number of buckets in the table.
     pub fn get_num_buckets(&self) -> usize {
         let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
@@ -335,9 +358,9 @@ where
     pub fn get_bucket_for_value(&self, val_ptr: *const V) -> usize {
         let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
 
-        let origin = map.bucket_arr.buckets.as_ptr();
+        let origin = map.bucket_arr.as_mut_ptr() as *const _;
         let idx = (val_ptr as usize - origin as usize) / size_of::<Bucket<V>>();
-        assert!(idx < map.bucket_arr.buckets.len());
+        assert!(idx < map.bucket_arr.len());
 
         idx
     }
@@ -368,8 +391,8 @@ where
         shards.iter_mut().for_each(|x| x.keys.iter_mut().for_each(|key| {
             match key.tag {
-                EntryType::Occupied => key.tag = EntryType::Rehash,
-                EntryType::Tombstone => key.tag = EntryType::RehashTombstone,
+                EntryTag::Occupied => key.tag = EntryTag::Rehash,
+                EntryTag::Tombstone => key.tag = EntryTag::RehashTombstone,
                 _ => (),
             }
         }));
@@ -379,9 +402,66 @@ where
         true
     }
+
+    // TODO(quantumish): off by one for return value logic?
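+    /// Perform one bounded chunk of incremental rehashing.
+    ///
+    /// Takes `resize_lock` (bailing out with `false` if it is contended),
+    /// advances `rehash_index` by up to `REHASH_CHUNK_SIZE` slots, and
+    /// re-inserts any keys in that range still tagged `EntryTag::Rehash`.
+    /// Returns `true` once `rehash_index` has caught up with `rehash_end`.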
+    fn do_rehash(&self) -> bool {
+        let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
+        // TODO(quantumish): refactor these out into settable quantities
+        const REHASH_CHUNK_SIZE: usize = 10;
+        const REHASH_ATTEMPTS: usize = 5;
+
+        let end = map.rehash_end.load(Ordering::Relaxed);
+        let ind = map.rehash_index.load(Ordering::Relaxed);
+        if ind >= end { return true }
+
+        let _guard = self.resize_lock.try_lock();
+        if _guard.is_none() { return false }
+
+        map.rehash_index.store((ind+REHASH_CHUNK_SIZE).min(end), Ordering::Relaxed);
+
+        let shard_size = map.get_num_buckets() / map.dict_shards.len();
+        for i in ind..(ind+REHASH_CHUNK_SIZE).min(end) {
+            let (shard_index, shard_off) = (i / shard_size, i % shard_size);
+            let mut shard = map.dict_shards[shard_index].write();
+            if shard.keys[shard_off].tag != EntryTag::Rehash {
+                continue;
+            }
+            loop {
+                let hash = self.get_hash_value(unsafe {
+                    shard.keys[shard_off].val.assume_init_ref()
+                });
+
+                let key = unsafe { shard.keys[shard_off].val.assume_init_ref() }.clone();
+                let new = map.entry(key, hash, |tag| match tag {
+                    EntryTag::Empty => core::MapEntryType::Empty,
+                    EntryTag::Occupied => core::MapEntryType::Occupied,
+                    EntryTag::Tombstone => core::MapEntryType::Skip,
+                    _ => core::MapEntryType::Tombstone,
+                }).unwrap();
+                let new_pos = new.pos();
+
+                match new.tag() {
+                    EntryTag::Empty | EntryTag::RehashTombstone => {
+                        shard.keys[shard_off].tag = EntryTag::Empty;
+                        unsafe {
+                            std::mem::swap(
+                                shard.keys[shard_off].val.assume_init_mut(),
+                                new.
+                    },
+                    EntryTag::Rehash => {
+
+                    },
+                    _ => unreachable!()
+                }
+            }
+        }
+        false
+    }
+
 
     pub fn finish_rehash(&self) {
         let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
-        while map.do_rehash() {}
+        while self.do_rehash() {}
     }
 
     pub fn shuffle(&self) {
@@ -422,7 +502,7 @@ where
         let _resize_guard = self.resize_lock.lock();
         let mut shards: Vec<_> = map.dict_shards.iter().map(|x| x.write()).collect();
-        let old_num_buckets = map.bucket_arr.buckets.len();
+        let old_num_buckets = map.bucket_arr.len();
         assert!(
             num_buckets >= old_num_buckets,
             "grow called with a smaller number of buckets"
         );
@@ -434,7 +514,7 @@ where
 
         // Grow memory areas and initialize each of them.
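+        // Each new bucket below is initialised with `Bucket::empty` and every new
+        // dictionary slot starts out as `EntryTag::Empty`, so the grown region
+        // looks like ordinary free space to later allocations.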
         self.resize_shmem(num_buckets)?;
         unsafe {
-            let buckets_ptr = map.bucket_arr.buckets.as_mut_ptr();
+            let buckets_ptr = map.bucket_arr.as_mut_ptr();
             for i in old_num_buckets..num_buckets {
                 let bucket = buckets_ptr.add(i);
                 bucket.write(Bucket::empty(
@@ -452,7 +532,7 @@ where
             for i in old_num_buckets..num_buckets {
                 let key = keys_ptr.add(i);
                 key.write(EntryKey {
-                    tag: EntryType::Empty,
+                    tag: EntryTag::Empty,
                     val: MaybeUninit::uninit(),
                 });
             }
@@ -492,7 +572,7 @@ where
     pub fn shrink_goal(&self) -> Option<usize> {
         let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
         let goal = map.bucket_arr.alloc_limit.load(Ordering::Relaxed);
-        goal.pos_checked()
+        goal.next_checked()
     }
 
     pub fn finish_shrink(&self) -> Result<(), shmem::Error> {
@@ -502,7 +582,7 @@ where
         let num_buckets = map.bucket_arr.alloc_limit
             .load(Ordering::Relaxed)
-            .pos_checked()
+            .next_checked()
             .expect("called finish_shrink when no shrink is in progress");
 
         if map.get_num_buckets() == num_buckets {
diff --git a/libs/neon-shmem/src/hash/bucket.rs b/libs/neon-shmem/src/hash/bucket.rs
index bbd69ca38f..2cae84472d 100644
--- a/libs/neon-shmem/src/hash/bucket.rs
+++ b/libs/neon-shmem/src/hash/bucket.rs
@@ -1,5 +1,7 @@
-use std::{mem::MaybeUninit, sync::atomic::{AtomicUsize, Ordering}};
+use std::cell::UnsafeCell;
+use std::mem::MaybeUninit;
+use std::sync::atomic::{AtomicUsize, Ordering};
 
 use atomic::Atomic;
 
@@ -10,34 +12,55 @@ pub(crate) struct BucketIdx(pub(super) u32);
 const _: () = assert!(Atomic::<BucketIdx>::is_lock_free());
 
 impl BucketIdx {
-    const MARK_TAG: u32 = 0x80000000;
-    pub const INVALID: Self = Self(0x7FFFFFFF);
-    pub const RESERVED: Self = Self(0x7FFFFFFE);
-    pub const MAX: usize = Self::RESERVED.0 as usize - 1;
+    /// Tag for next pointers in free entries.
+    const NEXT_TAG: u32 = 0b00 << 30;
+    /// Tag for marked next pointers in free entries.
+    const MARK_TAG: u32 = 0b01 << 30;
+    /// Tag for full entries.
+    const FULL_TAG: u32 = 0b10 << 30;
+    /// Reserved. Don't use me.
+    const RSVD_TAG: u32 = 0b11 << 30;
+
+    pub const INVALID: Self = Self(0x3FFFFFFF);
+    pub const MAX: usize = Self::INVALID.0 as usize - 1;
 
     pub(super) fn is_marked(&self) -> bool {
-        self.0 & Self::MARK_TAG != 0
+        self.0 & Self::RSVD_TAG == Self::MARK_TAG
     }
 
     pub(super) fn as_marked(self) -> Self {
-        Self(self.0 | Self::MARK_TAG)
+        Self((self.0 & Self::INVALID.0) | Self::MARK_TAG)
    }
 
     pub(super) fn get_unmarked(self) -> Self {
-        Self(self.0 & !Self::MARK_TAG)
+        Self(self.0 & Self::INVALID.0)
     }
 
     pub fn new(val: usize) -> Self {
+        debug_assert!(val < Self::MAX);
         Self(val as u32)
     }
+
+    pub fn new_full(val: usize) -> Self {
+        debug_assert!(val < Self::MAX);
+        Self(val as u32 | Self::FULL_TAG)
+    }
 
-    pub fn pos_checked(&self) -> Option<usize> {
+    pub fn next_checked(&self) -> Option<usize> {
         if *self == Self::INVALID || self.is_marked() {
             None
         } else {
             Some(self.0 as usize)
         }
     }
+
+    pub fn full_checked(&self) -> Option<usize> {
+        if self.0 & Self::RSVD_TAG == Self::FULL_TAG {
+            Some((self.0 & Self::INVALID.0) as usize)
+        } else {
+            None
+        }
+    }
 }
 
 impl std::fmt::Debug for BucketIdx {
@@ -48,7 +71,6 @@ impl std::fmt::Debug for BucketIdx {
             self.is_marked(),
             match *self {
                 Self::INVALID => "INVALID".to_string(),
-                Self::RESERVED => "RESERVED".to_string(),
                 _ => format!("{idx}")
             }
         )
@@ -76,8 +98,6 @@ impl<V> Bucket<V> {
             next: Atomic::new(BucketIdx::INVALID)
         }
     }
-
-    // pub is_full
 
     pub fn as_ref(&self) -> &V {
         unsafe { self.val.assume_init_ref() }
@@ -94,7 +114,7 @@ impl<V> Bucket<V> {
 
 pub(crate) struct BucketArray<'a, V> {
     /// Buckets containing values.
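+    /// Stored behind an [`UnsafeCell`] so buckets can be allocated and written
+    /// through a shared `BucketArray` reference; callers are assumed to
+    /// serialise access to an individual bucket (e.g. via the dictionary shard
+    /// locks) before going through `get_mut`/`IndexMut`.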
-    pub(crate) buckets: &'a mut [Bucket<V>],
+    pub(crate) buckets: &'a UnsafeCell<[Bucket<V>]>,
     /// Head of the freelist.
     pub(crate) free_head: Atomic<BucketIdx>,
     /// Maximum index of a bucket allowed to be allocated.
     pub(crate) alloc_limit: Atomic<BucketIdx>,
@@ -105,8 +125,24 @@ pub(crate) struct BucketArray<'a, V> {
     pub(crate) _user_list_head: Atomic<BucketIdx>,
 }
 
+impl <'a, V> std::ops::Index<usize> for BucketArray<'a, V> {
+    type Output = Bucket<V>;
+
+    fn index(&self, index: usize) -> &Self::Output {
+        let buckets: &[_] = unsafe { &*(self.buckets.get() as *mut _) };
+        &buckets[index]
+    }
+}
+
+impl <'a, V> std::ops::IndexMut<usize> for BucketArray<'a, V> {
+    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
+        let buckets: &mut [_] = unsafe { &mut *(self.buckets.get() as *mut _) };
+        &mut buckets[index]
+    }
+}
+
 impl<'a, V> BucketArray<'a, V> {
-    pub fn new(buckets: &'a mut [Bucket<V>]) -> Self {
+    pub fn new(buckets: &'a UnsafeCell<[Bucket<V>]>) -> Self {
         Self {
             buckets,
             free_head: Atomic::new(BucketIdx(0)),
@@ -115,18 +151,29 @@ impl<'a, V> BucketArray<'a, V> {
             buckets_in_use: 0.into(),
         }
     }
+
+    pub fn as_mut_ptr(&self) -> *mut Bucket<V> {
+        unsafe { (&mut *self.buckets.get()).as_mut_ptr() }
+    }
+
+    pub fn get_mut(&self, index: usize) -> &mut Bucket<V> {
+        let buckets: &mut [_] = unsafe { &mut *(self.buckets.get() as *mut _) };
+        &mut buckets[index]
+    }
 
-    pub fn dealloc_bucket(&mut self, pos: usize) -> V {
-        let bucket = &mut self.buckets[pos];
-        let pos = BucketIdx::new(pos);
+    pub fn len(&self) -> usize {
+        unsafe { (&*self.buckets.get()).len() }
+    }
+
+    pub fn dealloc_bucket(&self, pos: usize) -> V {
         loop {
             let free = self.free_head.load(Ordering::Relaxed);
-            bucket.next.store(free, Ordering::Relaxed);
+            self[pos].next.store(free, Ordering::Relaxed);
             if self.free_head.compare_exchange_weak(
-                free, pos, Ordering::Relaxed, Ordering::Relaxed
+                free, BucketIdx::new(pos), Ordering::Relaxed, Ordering::Relaxed
             ).is_ok() {
                 self.buckets_in_use.fetch_sub(1, Ordering::Relaxed);
-                return unsafe { bucket.val.assume_init_read() };
+                return unsafe { self[pos].val.assume_init_read() };
             }
         }
     }
@@ -140,8 +187,8 @@ impl<'a, V> BucketArray<'a, V> {
         loop {
             let mut t = BucketIdx::INVALID;
             let mut t_next = self.free_head.load(Ordering::Relaxed);
-            let alloc_limit = self.alloc_limit.load(Ordering::Relaxed).pos_checked();
-            while t_next.is_marked() || t.pos_checked()
+            let alloc_limit = self.alloc_limit.load(Ordering::Relaxed).next_checked();
+            while t_next.is_marked() || t.next_checked()
                 .map_or(true, |v| alloc_limit.map_or(false, |l| v > l))
             {
                 if !t_next.is_marked() {
@@ -150,12 +197,12 @@ impl<'a, V> BucketArray<'a, V> {
                 }
                 t = t_next.get_unmarked();
                 if t == BucketIdx::INVALID { break }
-                t_next = self.buckets[t.0 as usize].next.load(Ordering::Relaxed);
+                t_next = self[t.0 as usize].next.load(Ordering::Relaxed);
             }
             right_node = t;
 
             if left_node_next == right_node {
-                if right_node != BucketIdx::INVALID && self.buckets[right_node.0 as usize]
+                if right_node != BucketIdx::INVALID && self[right_node.0 as usize]
                     .next.load(Ordering::Relaxed).is_marked() {
                     continue;
@@ -165,13 +212,13 @@ impl<'a, V> BucketArray<'a, V> {
             }
 
             let left_ref = if left_node != BucketIdx::INVALID {
-                &self.buckets[left_node.0 as usize].next
+                &self[left_node.0 as usize].next
             } else {
                 &self.free_head
             };
             if left_ref.compare_exchange_weak(
                 left_node_next, right_node, Ordering::Relaxed, Ordering::Relaxed
             ).is_ok() {
-                if right_node != BucketIdx::INVALID && self.buckets[right_node.0 as usize]
+                if right_node != BucketIdx::INVALID && self[right_node.0 as usize]
                     .next.load(Ordering::Relaxed).is_marked() {
                     continue;
@@ -183,7 +230,7 @@ impl<'a, V> BucketArray<'a, V> {
     }
 
     #[allow(unused_assignments)]
-    pub(crate) fn alloc_bucket(&mut self, value: V) -> Option<BucketIdx> {
+    pub(crate) fn alloc_bucket(&self, value: V, key_pos: usize) -> Option<BucketIdx> {
         // println!("alloc()");
         let mut right_node_next = BucketIdx::INVALID;
         let mut left_idx = BucketIdx::INVALID;
@@ -195,7 +242,7 @@ impl<'a, V> BucketArray<'a, V> {
                 return None;
             }
 
-            let right = &self.buckets[right_idx.0 as usize];
+            let right = &self[right_idx.0 as usize];
             right_node_next = right.next.load(Ordering::Relaxed);
             if !right_node_next.is_marked() {
                 if right.next.compare_exchange_weak(
@@ -208,7 +255,7 @@ impl<'a, V> BucketArray<'a, V> {
             }
 
             let left_ref = if left_idx != BucketIdx::INVALID {
-                &self.buckets[left_idx.0 as usize].next
+                &self[left_idx.0 as usize].next
             } else {
                 &self.free_head
             };
@@ -221,17 +268,17 @@ impl<'a, V> BucketArray<'a, V> {
         }
         self.buckets_in_use.fetch_add(1, Ordering::Relaxed);
-        self.buckets[right_idx.0 as usize].val.write(value);
-        self.buckets[right_idx.0 as usize].next.store(
-            BucketIdx::RESERVED, Ordering::Relaxed
+        self[right_idx.0 as usize].next.store(
+            BucketIdx::new_full(key_pos), Ordering::Relaxed
         );
+        self.get_mut(right_idx.0 as usize).val.write(value);
 
         Some(right_idx)
     }
 
     pub fn clear(&mut self) {
-        for i in 0..self.buckets.len() {
-            self.buckets[i] = Bucket::empty(
-                if i < self.buckets.len() - 1 {
+        for i in 0..self.len() {
+            self[i] = Bucket::empty(
+                if i < self.len() - 1 {
                     BucketIdx::new(i + 1)
                 } else {
                     BucketIdx::INVALID
diff --git a/libs/neon-shmem/src/hash/core.rs b/libs/neon-shmem/src/hash/core.rs
index 647b267c45..f8b426ff2d 100644
--- a/libs/neon-shmem/src/hash/core.rs
+++ b/libs/neon-shmem/src/hash/core.rs
@@ -1,5 +1,6 @@
 //! Simple hash table with chaining.
 
+use std::cell::UnsafeCell;
 use std::hash::Hash;
 use std::mem::MaybeUninit;
 use std::sync::atomic::{Ordering, AtomicUsize};
@@ -11,7 +12,7 @@ use crate::hash::{
 };
 
 #[derive(PartialEq, Eq, Clone, Copy)]
-pub(crate) enum EntryType {
+pub(crate) enum EntryTag {
     Occupied,
     Rehash,
     Tombstone,
@@ -19,8 +20,15 @@ pub(crate) enum EntryType {
     Empty,
 }
 
+pub(crate) enum MapEntryType {
+    Occupied,
+    Tombstone,
+    Empty,
+    Skip
+}
+
 pub(crate) struct EntryKey<K> {
-    pub(crate) tag: EntryType,
+    pub(crate) tag: EntryTag,
     pub(crate) val: MaybeUninit<K>,
 }
 
@@ -55,9 +63,10 @@ pub struct FullError();
 
 impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
     pub fn new(
-        buckets: &'a mut [MaybeUninit<Bucket<V>>],
+        buckets_cell: &'a UnsafeCell<[MaybeUninit<Bucket<V>>]>,
         dict_shards: &'a mut [RwLock<MaybeUninitDictShard<'a, K>>],
     ) -> Self {
+        let buckets = unsafe { &mut *buckets_cell.get() };
         // Initialize the buckets
         for i in 0..buckets.len() {
             buckets[i].write(Bucket::empty(
@@ -74,7 +83,7 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
             let mut dicts = shard.write();
             for e in dicts.keys.iter_mut() {
                 e.write(EntryKey {
-                    tag: EntryType::Empty,
+                    tag: EntryTag::Empty,
                     val: MaybeUninit::uninit(),
                 });
             }
@@ -83,10 +92,10 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
             }
         }
 
+        let buckets_cell = unsafe {
+            &*(buckets_cell as *const _ as *const UnsafeCell<_>)
+        };
         // TODO: use std::slice::assume_init_mut() once it stabilizes
-        let buckets =
-            unsafe { std::slice::from_raw_parts_mut(buckets.as_mut_ptr().cast(),
-                                                     buckets.len()) };
         let dict_shards = unsafe {
             std::slice::from_raw_parts_mut(dict_shards.as_mut_ptr().cast(),
                                            dict_shards.len())
@@ -96,73 +105,64 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
             dict_shards,
             rehash_index: buckets.len().into(),
             rehash_end: buckets.len().into(),
-            bucket_arr: BucketArray::new(buckets),
+            bucket_arr: BucketArray::new(buckets_cell),
         }
     }
-
-    // TODO(quantumish): off by one for return value logic?
-    pub fn do_rehash(&mut self) -> bool {
-        // TODO(quantumish): refactor these out into settable quantities
-        const REHASH_CHUNK_SIZE: usize = 10;
-        const REHASH_ATTEMPTS: usize = 5;
-
-        let end = self.rehash_end.load(Ordering::Relaxed);
-        let mut ind = self.rehash_index.load(Ordering::Relaxed);
-        let mut i = 0;
-        loop {
-            if ind >= end {
-                // TODO(quantumish) questionable?
-                self.rehash_index.store(end, Ordering::Relaxed);
-                return true;
-            }
-            if i > REHASH_ATTEMPTS {
-                break;
-            }
-            match self.rehash_index.compare_exchange_weak(
-                ind, ind + REHASH_CHUNK_SIZE,
-                Ordering::Relaxed, Ordering::Relaxed
-            ) {
-                Err(new_ind) => ind = new_ind,
-                Ok(_) => break,
-            }
-            i += 1;
-        }
-
-        todo!("actual rehashing");
-        false
-    }
-
+
     pub fn get_with_hash(&'a self, key: &K, hash: u64) -> Option<ValueReadGuard<'a, V>> {
         let ind = self.rehash_index.load(Ordering::Relaxed);
         let end = self.rehash_end.load(Ordering::Relaxed);
-        let first = ind >= end || ind < end/2;
-        if let Some(res) = self.get(key, hash, first) {
-            return Some(res);
+        let res = self.get(key, hash, |tag| match tag {
+            EntryTag::Empty => MapEntryType::Empty,
+            EntryTag::Occupied => MapEntryType::Occupied,
+            _ => MapEntryType::Tombstone,
+        });
+        if res.is_some() {
+            return res;
         }
-        if ind < end && let Some(res) = self.get(key, hash, !first) {
-            return Some(res);
+
+        if ind < end {
+            self.get(key, hash, |tag| match tag {
+                EntryTag::Empty => MapEntryType::Empty,
+                EntryTag::Rehash => MapEntryType::Occupied,
+                _ => MapEntryType::Tombstone,
+            })
+        } else {
+            None
         }
-        None
     }
-
+
     pub fn entry_with_hash(&'a mut self, key: K, hash: u64) -> Result<Entry<'a, K, V>, FullError> {
         let ind = self.rehash_index.load(Ordering::Relaxed);
         let end = self.rehash_end.load(Ordering::Relaxed);
+        let res = self.entry(key.clone(), hash, |tag| match tag {
+            EntryTag::Empty => MapEntryType::Empty,
+            EntryTag::Occupied => MapEntryType::Occupied,
+            EntryTag::Rehash => MapEntryType::Skip,
+            _ => MapEntryType::Tombstone,
+        });
         if ind < end {
-            if let Ok(Entry::Occupied(res)) = self.entry(key.clone(), hash, true) {
-                return Ok(Entry::Occupied(res));
+            if let Ok(Entry::Occupied(_)) = res {
+                res
             } else {
-                return self.entry(key, hash, false);
+                self.entry(key, hash, |tag| match tag {
+                    EntryTag::Empty => MapEntryType::Empty,
+                    EntryTag::Occupied => MapEntryType::Skip,
+                    EntryTag::Rehash => MapEntryType::Occupied,
+                    _ => MapEntryType::Tombstone
+                })
            }
         } else {
-            return self.entry(key.clone(), hash, true);
+            res
         }
     }
 
     /// Get the value associated with a key (if it exists) given its hash.
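+    ///
+    /// The `f` closure decides how each slot's [`EntryTag`] should be treated by
+    /// this particular scan (for example, whether `Rehash` entries count as
+    /// occupied or as tombstones), which lets the same lookup serve both the
+    /// live table and the in-progress rehash view.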
-    fn get(&'a self, key: &K, hash: u64, ignore_remap: bool) -> Option<ValueReadGuard<'a, V>> {
+    fn get<F>(&'a self, key: &K, hash: u64, f: F) -> Option<ValueReadGuard<'a, V>>
+    where F: Fn(EntryTag) -> MapEntryType
+    {
         let num_buckets = self.get_num_buckets();
         let shard_size = num_buckets / self.dict_shards.len();
         let bucket_pos = hash as usize % num_buckets;
@@ -172,19 +172,17 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
             let shard = self.dict_shards[shard_idx].read();
             let entry_start = if off == 0 { bucket_pos % shard_size } else { 0 };
             for entry_idx in entry_start..shard.len() {
-                match shard.keys[entry_idx].tag {
-                    EntryType::Empty => return None,
-                    EntryType::Tombstone | EntryType::RehashTombstone => continue,
-                    t @ (EntryType::Occupied | EntryType::Rehash) => {
-                        if (t == EntryType::Occupied && ignore_remap) || (t == EntryType::Rehash && !ignore_remap) {
-                            let cand_key = unsafe { shard.keys[entry_idx].val.assume_init_ref() };
-                            if cand_key == key {
-                                let bucket_idx = shard.idxs[entry_idx].pos_checked()
-                                    .expect("position is valid");
-                                return Some(RwLockReadGuard::map(
-                                    shard, |_| self.bucket_arr.buckets[bucket_idx].as_ref()
-                                ));
-                            }
+                match f(shard.keys[entry_idx].tag) {
+                    MapEntryType::Empty => return None,
+                    MapEntryType::Tombstone | MapEntryType::Skip => continue,
+                    MapEntryType::Occupied => {
+                        let cand_key = unsafe { shard.keys[entry_idx].val.assume_init_ref() };
+                        if cand_key == key {
+                            let bucket_idx = shard.idxs[entry_idx].next_checked()
+                                .expect("position is valid");
+                            return Some(RwLockReadGuard::map(
+                                shard, |_| self.bucket_arr[bucket_idx].as_ref()
+                            ));
                         }
                     },
                 }
@@ -193,7 +191,9 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
         None
     }
 
-    fn entry(&'a mut self, key: K, hash: u64, ignore_remap: bool) -> Result<Entry<'a, K, V>, FullError> {
+    pub fn entry<F>(&'a self, key: K, hash: u64, f: F) -> Result<Entry<'a, K, V>, FullError>
+    where F: Fn(EntryTag) -> MapEntryType
+    {
         // We need to keep holding on the locks for each shard we process since if we don't find the
         // key anywhere, we want to insert it at the earliest possible position (which may be several
         // shards away). Ideally cross-shard chains are quite rare, so this shouldn't be a big deal.
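+        // Under `f`, a `Skip` slot is ignored entirely, while a `Tombstone` slot
+        // is remembered as the earliest reusable position in case the key is not
+        // found further down the chain.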
@@ -211,57 +211,57 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
             let mut inserted = false;
             let entry_start = if off == 0 { bucket_pos % shard_size } else { 0 };
             for entry_idx in entry_start..shard.len() {
-                match shard.keys[entry_idx].tag {
-                    EntryType::Empty => {
-                        let (shard, shard_pos) = match (insert_shard, insert_pos) {
-                            (Some(s), Some(p)) => (s, p),
-                            (None, Some(p)) => (shard, p),
-                            (None, None) => (shard, entry_idx),
+                match f(shard.keys[entry_idx].tag) {
+                    MapEntryType::Skip => continue,
+                    MapEntryType::Empty => {
+                        let ((shard, idx), shard_pos) = match (insert_shard, insert_pos) {
+                            (Some((s, i)), Some(p)) => ((s, i), p),
+                            (None, Some(p)) => ((shard, shard_idx), p),
+                            (None, None) => ((shard, shard_idx), entry_idx),
                             _ => unreachable!()
                         };
                         return Ok(Entry::Vacant(VacantEntry {
                             _key: key,
                             shard,
                             shard_pos,
-                            bucket_arr: &mut self.bucket_arr,
+                            key_pos: (shard_size * idx) + shard_pos,
+                            bucket_arr: &self.bucket_arr,
                         }))
                     },
-                    EntryType::Tombstone | EntryType::RehashTombstone => {
+                    MapEntryType::Tombstone => {
                         if insert_pos.is_none() {
                             insert_pos = Some(entry_idx);
                             inserted = true;
                         }
                     },
-                    t @ (EntryType::Occupied | EntryType::Rehash) => {
-                        if (t == EntryType::Occupied && ignore_remap) || (t == EntryType::Rehash && !ignore_remap) {
-                            let cand_key = unsafe { shard.keys[entry_idx].val.assume_init_ref() };
-                            if *cand_key == key {
-                                let bucket_pos = shard.idxs[entry_idx].pos_checked().unwrap();
-                                return Ok(Entry::Occupied(OccupiedEntry {
-                                    _key: key,
-                                    shard,
-                                    shard_pos: entry_idx,
-                                    bucket_pos,
-                                    bucket_arr: &mut self.bucket_arr,
-                                }));
-                            }
+                    MapEntryType::Occupied => {
+                        let cand_key = unsafe { shard.keys[entry_idx].val.assume_init_ref() };
+                        if *cand_key == key {
+                            let bucket_pos = shard.idxs[entry_idx].next_checked().unwrap();
+                            return Ok(Entry::Occupied(OccupiedEntry {
+                                shard,
+                                shard_pos: entry_idx,
+                                key_pos: (shard_size * shard_idx) + entry_idx,
+                                bucket_pos,
+                                bucket_arr: &self.bucket_arr,
+                            }));
                        }
                     }
                 }
             }
             if inserted {
-                insert_shard = Some(shard)
+                insert_shard = Some((shard, shard_idx));
             } else {
                 shards.push(shard);
             }
         }
 
-        if let (Some(shard), Some(shard_pos)) = (insert_shard, insert_pos) {
+        if let (Some((shard, idx)), Some(shard_pos)) = (insert_shard, insert_pos) {
             Ok(Entry::Vacant(VacantEntry {
                 _key: key,
                 shard,
                 shard_pos,
-                bucket_arr: &mut self.bucket_arr,
+                key_pos: (shard_size * idx) + shard_pos,
+                bucket_arr: &self.bucket_arr,
             }))
         } else {
             Err(FullError{})
@@ -270,14 +270,14 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
 
     /// Get number of buckets in map.
     pub fn get_num_buckets(&self) -> usize {
-        self.bucket_arr.buckets.len()
+        self.bucket_arr.len()
     }
 
     pub fn clear(&mut self) {
         let mut shards: Vec<_> = self.dict_shards.iter().map(|x| x.write()).collect();
         for shard in shards.iter_mut() {
             for e in shard.keys.iter_mut() {
-                e.tag = EntryType::Empty;
+                e.tag = EntryTag::Empty;
             }
             for e in shard.idxs.iter_mut() {
                 *e = BucketIdx::INVALID;
diff --git a/libs/neon-shmem/src/hash/entry.rs b/libs/neon-shmem/src/hash/entry.rs
index 6fcd2ac287..ee3d5cac70 100644
--- a/libs/neon-shmem/src/hash/entry.rs
+++ b/libs/neon-shmem/src/hash/entry.rs
@@ -1,49 +1,61 @@
 //! Equivalent of [`std::collections::hash_map::Entry`] for this hashmap.
 use crate::hash::{
-    core::{DictShard, EntryType},
+    core::{DictShard, EntryTag},
     bucket::{BucketArray, BucketIdx}
 };
 use crate::sync::{RwLockWriteGuard, ValueWriteGuard};
 
 use std::hash::Hash;
 
+use super::core::EntryKey;
+
 pub enum Entry<'a, K, V> {
     Occupied(OccupiedEntry<'a, K, V>),
     Vacant(VacantEntry<'a, K, V>),
 }
 
+impl<'a, K, V> Entry<'a, K, V> {
+    /// Tag currently stored in the dictionary slot backing this entry.
+    pub(crate) fn tag(&self) -> EntryTag {
+        match self {
+            Self::Occupied(o) => o.shard.keys[o.shard_pos].tag,
+            Self::Vacant(v) => v.shard.keys[v.shard_pos].tag
+        }
+    }
+
+    /// Logical position of the entry's key slot within the whole map.
+    pub(crate) fn pos(&self) -> usize {
+        match self {
+            Self::Occupied(o) => o.key_pos,
+            Self::Vacant(v) => v.key_pos
+        }
+    }
+}
+
 pub struct OccupiedEntry<'a, K, V> {
-    /// The key of the occupied entry
-    pub(crate) _key: K,
     /// Mutable reference to the shard of the map the entry is in.
     pub(crate) shard: RwLockWriteGuard<'a, DictShard<'a, K>>,
-    /// The position of the entry in the map.
+    /// The position of the entry in the shard.
     pub(crate) shard_pos: usize,
+    /// True logical position of the entry in the map.
+    pub(crate) key_pos: usize,
     /// Mutable reference to the bucket array containing entry.
-    pub(crate) bucket_arr: &'a mut BucketArray<'a, V>,
+    pub(crate) bucket_arr: &'a BucketArray<'a, V>,
     /// The position of the bucket in the [`CoreHashMap`] bucket array.
     pub(crate) bucket_pos: usize,
 }
 
 impl<K, V> OccupiedEntry<'_, K, V> {
     pub fn get(&self) -> &V {
-        self.bucket_arr.buckets[self.bucket_pos].as_ref()
+        self.bucket_arr[self.bucket_pos].as_ref()
     }
 
     pub fn get_mut(&mut self) -> &mut V {
-        self.bucket_arr.buckets[self.bucket_pos].as_mut()
+        self.bucket_arr.get_mut(self.bucket_pos).as_mut()
     }
 
     /// Inserts a value into the entry, replacing (and returning) the existing value.
     pub fn insert(&mut self, value: V) -> V {
-        self.bucket_arr.buckets[self.bucket_pos].replace(value)
+        self.bucket_arr.get_mut(self.bucket_pos).replace(value)
     }
 
     /// Removes the entry from the hash map, returning the value originally stored within it.
     pub fn remove(&mut self) -> V {
         self.shard.idxs[self.shard_pos] = BucketIdx::INVALID;
-        self.shard.keys[self.shard_pos].tag = EntryType::Tombstone;
+        self.shard.keys[self.shard_pos].tag = EntryTag::Tombstone;
         self.bucket_arr.dealloc_bucket(self.bucket_pos)
     }
 }
 
@@ -54,23 +66,28 @@ pub struct VacantEntry<'a, K, V> {
     pub(crate) _key: K,
     /// Mutable reference to the shard of the map the entry is in.
     pub(crate) shard: RwLockWriteGuard<'a, DictShard<'a, K>>,
-    /// The position of the entry in the map.
+    /// The position of the entry in the shard.
     pub(crate) shard_pos: usize,
+    /// True logical position of the entry in the map.
+    pub(crate) key_pos: usize,
     /// Mutable reference to the bucket array containing entry.
-    pub(crate) bucket_arr: &'a mut BucketArray<'a, V>,
+    pub(crate) bucket_arr: &'a BucketArray<'a, V>,
 }
 
 impl<'a, K: Clone + Hash + Eq, V> VacantEntry<'a, K, V> {
     /// Insert a value into the vacant entry, finding and populating an empty bucket in the process.
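+    ///
+    /// The freshly allocated bucket records this entry's `key_pos` (via
+    /// `BucketIdx::new_full`) in its `next` word, so the bucket can later be
+    /// traced back to the dictionary slot that owns it.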
     pub fn insert(mut self, value: V) -> ValueWriteGuard<'a, V> {
-        let pos = self.bucket_arr.alloc_bucket(value).expect("bucket is available if entry is");
-        self.shard.keys[self.shard_pos].tag = EntryType::Occupied;
+        let pos = self.bucket_arr.alloc_bucket(value, self.key_pos)
+            .expect("bucket is available if entry is");
+        self.shard.keys[self.shard_pos].tag = EntryTag::Occupied;
         self.shard.keys[self.shard_pos].val.write(self._key);
-        let idx = pos.pos_checked().expect("position is valid");
+        let idx = pos.next_checked().expect("position is valid");
         self.shard.idxs[self.shard_pos] = pos;
         RwLockWriteGuard::map(self.shard, |_| {
-            self.bucket_arr.buckets[idx].as_mut()
+            self.bucket_arr.get_mut(idx).as_mut()
         })
     }
 }
+
+