Switch to neon_shmem::sync lock_api and integrate into hashmap

This commit is contained in:
David Freifeld
2025-07-02 11:44:38 -07:00
parent 9d3e07ef2c
commit 19b5618578
7 changed files with 210 additions and 398 deletions

1
Cargo.lock generated
View File

@@ -4038,6 +4038,7 @@ dependencies = [
"foldhash",
"hashbrown 0.15.4 (git+https://github.com/quantumish/hashbrown.git?rev=6610e6d)",
"libc",
"lock_api",
"nix 0.30.1",
"rand 0.9.1",
"rand_distr 0.5.1",

View File

@@ -11,6 +11,7 @@ workspace_hack = { version = "0.1", path = "../../workspace_hack" }
rustc-hash = { version = "2.1.1" }
rand = "0.9.1"
libc.workspace = true
lock_api = "0.4.13"
[dev-dependencies]
criterion = { workspace = true, features = ["html_reports"] }

View File

@@ -16,9 +16,9 @@
use std::hash::{Hash, BuildHasher};
use std::mem::MaybeUninit;
use std::default::Default;
use crate::{shmem, shmem::ShmemHandle};
use crate::{shmem, sync::*};
use crate::shmem::ShmemHandle;
mod core;
pub mod entry;
@@ -27,15 +27,14 @@ pub mod entry;
mod tests;
use core::{Bucket, CoreHashMap, INVALID_POS};
use entry::{Entry, OccupiedEntry};
use entry::{Entry, OccupiedEntry, VacantEntry, PrevPos};
/// Builder for a [`HashMapAccess`].
#[must_use]
pub struct HashMapInit<'a, K, V, S = rustc_hash::FxBuildHasher> {
shmem_handle: Option<ShmemHandle>,
shared_ptr: *mut HashMapShared<'a, K, V>,
shared_ptr: *mut RwLock<HashMapShared<'a, K, V>>,
shared_size: usize,
shrink_mode: HashMapShrinkMode,
hasher: S,
num_buckets: u32,
}
@@ -45,28 +44,6 @@ pub struct HashMapAccess<'a, K, V, S = rustc_hash::FxBuildHasher> {
shmem_handle: Option<ShmemHandle>,
shared_ptr: *mut HashMapShared<'a, K, V>,
hasher: S,
shrink_mode: HashMapShrinkMode,
}
/// Enum specifying what behavior to have surrounding occupied entries in what is
/// about-to-be-shrinked space during a call to [`HashMapAccess::finish_shrink`].
#[derive(PartialEq, Eq)]
pub enum HashMapShrinkMode {
/// Remap entry to the range of buckets that will remain after shrinking.
///
/// Requires that caller has left enough room within the map such that this is possible.
Remap,
/// Remove any entries remaining in soon to be deallocated space.
///
/// Only really useful if you legitimately do not care what entries are removed.
/// Should primarily be used for testing.
Remove,
}
impl Default for HashMapShrinkMode {
fn default() -> Self {
Self::Remap
}
}
unsafe impl<K: Sync, V: Sync, S> Sync for HashMapAccess<'_, K, V, S> {}
@@ -80,14 +57,9 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
shared_ptr: self.shared_ptr,
shared_size: self.shared_size,
num_buckets: self.num_buckets,
shrink_mode: self.shrink_mode,
}
}
pub fn with_shrink_mode(self, mode: HashMapShrinkMode) -> Self {
Self { shrink_mode: mode, ..self }
}
/// Loosely (over)estimate the size needed to store a hash table with `num_buckets` buckets.
pub fn estimate_size(num_buckets: u32) -> usize {
// add some margin to cover alignment etc.
@@ -96,13 +68,17 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
/// Initialize a table for writing.
pub fn attach_writer(self) -> HashMapAccess<'a, K, V, S> {
// carve out the HashMapShared struct from the area.
let mut ptr: *mut u8 = self.shared_ptr.cast();
let end_ptr: *mut u8 = unsafe { ptr.add(self.shared_size) };
ptr = unsafe { ptr.add(ptr.align_offset(align_of::<HashMapShared<K, V>>())) };
let shared_ptr: *mut HashMapShared<K, V> = ptr.cast();
ptr = unsafe { ptr.add(size_of::<HashMapShared<K, V>>()) };
// carve out area for the One Big Lock (TM) and the HashMapShared.
ptr = unsafe { ptr.add(ptr.align_offset(align_of::<libc::pthread_rwlock_t>())) };
let raw_lock_ptr = ptr;
ptr = unsafe { ptr.add(size_of::<libc::pthread_rwlock_t>()) };
ptr = unsafe { ptr.add(ptr.align_offset(align_of::<HashMapShared<K, V>>())) };
let shared_ptr: *mut HashMapShared<K, V> = ptr.cast();
ptr = unsafe { ptr.add(size_of::<HashMapShared<K, V>>()) };
// carve out the buckets
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<core::Bucket<K, V>>())) };
let buckets_ptr = ptr;
@@ -121,14 +97,14 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
std::slice::from_raw_parts_mut(dictionary_ptr.cast(), dictionary_size as usize)
};
let hashmap = CoreHashMap::new(buckets, dictionary);
unsafe {
std::ptr::write(shared_ptr, HashMapShared { inner: hashmap });
}
let lock = RwLock::from_raw(PthreadRwLock::new(raw_lock_ptr.cast()), hashmap);
unsafe {
std::ptr::write(shared_ptr, lock);
}
HashMapAccess {
shmem_handle: self.shmem_handle,
shared_ptr: self.shared_ptr,
shrink_mode: self.shrink_mode,
shared_ptr,
hasher: self.hasher,
}
}
@@ -145,14 +121,13 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
/// relies on the memory layout! The data structures are laid out in the contiguous shared memory
/// area as follows:
///
/// [`libc::pthread_rwlock_t`]
/// [`HashMapShared`]
/// [buckets]
/// [dictionary]
///
/// In between the above parts, there can be padding bytes to align the parts correctly.
struct HashMapShared<'a, K, V> {
inner: CoreHashMap<'a, K, V>
}
type HashMapShared<'a, K, V> = RwLock<CoreHashMap<'a, K, V>>;
impl<'a, K, V> HashMapInit<'a, K, V, rustc_hash::FxBuildHasher>
where
@@ -168,7 +143,6 @@ where
shmem_handle: None,
shared_ptr: area.as_mut_ptr().cast(),
shared_size: area.len(),
shrink_mode: HashMapShrinkMode::default(),
hasher: rustc_hash::FxBuildHasher,
}
}
@@ -187,7 +161,6 @@ where
shared_ptr: shmem.data_ptr.as_ptr().cast(),
shmem_handle: Some(shmem),
shared_size: size,
shrink_mode: HashMapShrinkMode::default(),
hasher: rustc_hash::FxBuildHasher
}
}
@@ -204,7 +177,6 @@ where
shared_ptr: shmem.data_ptr.as_ptr().cast(),
shmem_handle: Some(shmem),
shared_size: size,
shrink_mode: HashMapShrinkMode::default(),
hasher: rustc_hash::FxBuildHasher
}
}
@@ -229,25 +201,64 @@ where
self.hasher.hash_one(key)
}
fn entry_with_hash(&self, key: K, hash: u64) -> Entry<'a, '_, K, V> {
let mut map = unsafe { self.shared_ptr.as_ref() }.unwrap().write();
let dict_pos = hash as usize % map.dictionary.len();
let first = map.dictionary[dict_pos];
if first == INVALID_POS {
// no existing entry
return Entry::Vacant(VacantEntry {
map,
key,
dict_pos: dict_pos as u32,
});
}
let mut prev_pos = PrevPos::First(dict_pos as u32);
let mut next = first;
loop {
let bucket = &mut map.buckets[next as usize];
let (bucket_key, _bucket_value) = bucket.inner.as_mut().expect("entry is in use");
if *bucket_key == key {
// found existing entry
return Entry::Occupied(OccupiedEntry {
map,
_key: key,
prev_pos,
bucket_pos: next,
});
}
if bucket.next == INVALID_POS {
// No existing entry
return Entry::Vacant(VacantEntry {
map,
key,
dict_pos: dict_pos as u32,
});
}
prev_pos = PrevPos::Chained(next);
next = bucket.next;
}
}
/// Get a reference to the corresponding value for a key.
pub fn get<'e>(&'e self, key: &K) -> Option<&'e V> {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
pub fn get<'e>(&'e self, key: &K) -> Option<ValueReadGuard<'e, V>> {
let hash = self.get_hash_value(key);
map.inner.get_with_hash(key, hash)
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
RwLockReadGuard::try_map(map, |m| m.get_with_hash(key, hash)).ok()
}
/// Get a reference to the entry containing a key.
pub fn entry(&self, key: K) -> Entry<'a, '_, K, V> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
let hash = self.get_hash_value(&key);
map.inner.entry_with_hash(key, hash)
self.entry_with_hash(key, hash)
}
/// Remove a key given its hash. Returns the associated value if it existed.
pub fn remove(&self, key: &K) -> Option<V> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
let hash = self.get_hash_value(&key);
match map.inner.entry_with_hash(key.clone(), hash) {
match self.entry_with_hash(key.clone(), hash) {
Entry::Occupied(e) => Some(e.remove()),
Entry::Vacant(_) => None
}
@@ -258,12 +269,11 @@ where
/// # Errors
/// Will return [`core::FullError`] if there is no more space left in the map.
pub fn insert(&self, key: K, value: V) -> Result<Option<V>, core::FullError> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
let hash = self.get_hash_value(&key);
match map.inner.entry_with_hash(key.clone(), hash) {
match self.entry_with_hash(key.clone(), hash) {
Entry::Occupied(mut e) => Ok(Some(e.insert(value))),
Entry::Vacant(e) => {
e.insert(value)?;
_ = e.insert(value)?;
Ok(None)
}
}
@@ -275,13 +285,12 @@ where
/// due to the [`OccupiedEntry`] type owning the key and also a hash of the key in order
/// to enable repairing the hash chain if the entry is removed.
pub fn entry_at_bucket(&self, pos: usize) -> Option<OccupiedEntry<'a, '_, K, V>> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
let inner = &mut map.inner;
if pos >= inner.buckets.len() {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
if pos >= map.buckets.len() {
return None;
}
let entry = inner.buckets[pos].inner.as_ref();
let entry = map.buckets[pos].inner.as_ref();
match entry {
Some((key, _)) => Some(OccupiedEntry {
_key: key.clone(),
@@ -289,7 +298,7 @@ where
prev_pos: entry::PrevPos::Unknown(
self.get_hash_value(&key)
),
map: inner,
map,
}),
_ => None,
}
@@ -297,8 +306,8 @@ where
/// Returns the number of buckets in the table.
pub fn get_num_buckets(&self) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
map.inner.get_num_buckets()
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
map.get_num_buckets()
}
/// Return the key and value stored in bucket with given index. This can be used to
@@ -306,38 +315,35 @@ where
// TODO: An Iterator might be nicer. The communicator's clock algorithm needs to
// _slowly_ iterate through all buckets with its clock hand, without holding a lock.
// If we switch to an Iterator, it must not hold the lock.
pub fn get_at_bucket(&self, pos: usize) -> Option<&(K, V)> {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
if pos >= map.inner.buckets.len() {
pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<(K, V)>> {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
if pos >= map.buckets.len() {
return None;
}
let bucket = &map.inner.buckets[pos];
bucket.inner.as_ref()
RwLockReadGuard::try_map(map, |m| m.buckets[pos].inner.as_ref()).ok()
}
/// Returns the index of the bucket a given value corresponds to.
pub fn get_bucket_for_value(&self, val_ptr: *const V) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
let origin = map.inner.buckets.as_ptr();
let origin = map.buckets.as_ptr();
let idx = (val_ptr as usize - origin as usize) / size_of::<Bucket<K, V>>();
assert!(idx < map.inner.buckets.len());
assert!(idx < map.buckets.len());
idx
}
/// Returns the number of occupied buckets in the table.
pub fn get_num_buckets_in_use(&self) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
map.inner.buckets_in_use as usize
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
map.buckets_in_use as usize
}
/// Clears all entries in a table. Does not reset any shrinking operations.
pub fn clear(&self) {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
let inner = &mut map.inner;
inner.clear();
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
map.clear();
}
/// Perform an in-place rehash of some region (0..`rehash_buckets`) of the table and reset
@@ -389,13 +395,12 @@ where
/// Rehash the map without growing or shrinking.
pub fn shuffle(&self) {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
let inner = &mut map.inner;
let num_buckets = inner.get_num_buckets() as u32;
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
let num_buckets = map.get_num_buckets() as u32;
let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
let end_ptr: *mut u8 = unsafe { self.shared_ptr.byte_add(size_bytes).cast() };
let buckets_ptr = inner.buckets.as_mut_ptr();
self.rehash_dict(inner, buckets_ptr, end_ptr, num_buckets, num_buckets);
let buckets_ptr = map.buckets.as_mut_ptr();
self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets);
}
/// Grow the number of buckets within the table.
@@ -409,10 +414,9 @@ where
///
/// # Errors
/// Returns an [`shmem::Error`] if any errors occur resizing the memory region.
pub fn grow(&mut self, num_buckets: u32) -> Result<(), shmem::Error> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
let inner = &mut map.inner;
let old_num_buckets = inner.buckets.len() as u32;
pub fn grow(&self, num_buckets: u32) -> Result<(), shmem::Error> {
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
let old_num_buckets = map.buckets.len() as u32;
assert!(num_buckets >= old_num_buckets, "grow called with a smaller number of buckets");
if num_buckets == old_num_buckets {
@@ -429,7 +433,7 @@ where
// Initialize new buckets. The new buckets are linked to the free list.
// NB: This overwrites the dictionary!
let buckets_ptr = inner.buckets.as_mut_ptr();
let buckets_ptr = map.buckets.as_mut_ptr();
unsafe {
for i in old_num_buckets..num_buckets {
let bucket = buckets_ptr.add(i as usize);
@@ -437,15 +441,15 @@ where
next: if i < num_buckets-1 {
i + 1
} else {
inner.free_head
map.free_head
},
inner: None,
});
}
}
self.rehash_dict(inner, buckets_ptr, end_ptr, num_buckets, old_num_buckets);
inner.free_head = old_num_buckets;
self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, old_num_buckets);
map.free_head = old_num_buckets;
Ok(())
}
@@ -456,22 +460,22 @@ where
/// Panics if called on a map initialized with [`HashMapInit::with_fixed`] or if `num_buckets` is
/// greater than the number of buckets in the map.
pub fn begin_shrink(&mut self, num_buckets: u32) {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
assert!(
num_buckets <= map.inner.get_num_buckets() as u32,
num_buckets <= map.get_num_buckets() as u32,
"shrink called with a larger number of buckets"
);
_ = self
.shmem_handle
.as_ref()
.expect("shrink called on a fixed-size hash table");
map.inner.alloc_limit = num_buckets;
map.alloc_limit = num_buckets;
}
/// If a shrink operation is underway, returns the target size of the map. Otherwise, returns None.
pub fn shrink_goal(&self) -> Option<usize> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
let goal = map.inner.alloc_limit;
let map = unsafe { self.shared_ptr.as_mut() }.unwrap().read();
let goal = map.alloc_limit;
if goal == INVALID_POS { None } else { Some(goal as usize) }
}
@@ -487,31 +491,28 @@ where
/// # Errors
/// Returns an [`shmem::Error`] if any errors occur resizing the memory region.
pub fn finish_shrink(&self) -> Result<(), shmem::Error> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
let inner = &mut map.inner;
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
assert!(
inner.alloc_limit != INVALID_POS,
map.alloc_limit != INVALID_POS,
"called finish_shrink when no shrink is in progress"
);
let num_buckets = inner.alloc_limit;
let num_buckets = map.alloc_limit;
if inner.get_num_buckets() == num_buckets as usize {
if map.get_num_buckets() == num_buckets as usize {
return Ok(());
}
if self.shrink_mode == HashMapShrinkMode::Remap {
assert!(
inner.buckets_in_use <= num_buckets,
"called finish_shrink before enough entries were removed"
);
for i in (num_buckets as usize)..inner.buckets.len() {
if let Some((k, v)) = inner.buckets[i].inner.take() {
// alloc_bucket increases count, so need to decrease since we're just moving
inner.buckets_in_use -= 1;
inner.alloc_bucket(k, v).unwrap();
}
assert!(
map.buckets_in_use <= num_buckets,
"called finish_shrink before enough entries were removed"
);
for i in (num_buckets as usize)..map.buckets.len() {
if let Some((k, v)) = map.buckets[i].inner.take() {
// alloc_bucket increases count, so need to decrease since we're just moving
map.buckets_in_use -= 1;
map.alloc_bucket(k, v).unwrap();
}
}
@@ -523,9 +524,9 @@ where
let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
shmem_handle.set_size(size_bytes)?;
let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
let buckets_ptr = inner.buckets.as_mut_ptr();
self.rehash_dict(inner, buckets_ptr, end_ptr, num_buckets, num_buckets);
inner.alloc_limit = INVALID_POS;
let buckets_ptr = map.buckets.as_mut_ptr();
self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets);
map.alloc_limit = INVALID_POS;
Ok(())
}

View File

@@ -3,7 +3,7 @@
use std::hash::Hash;
use std::mem::MaybeUninit;
use crate::hash::entry::{Entry, OccupiedEntry, PrevPos, VacantEntry};
use crate::hash::entry::*;
/// Invalid position within the map (either within the dictionary or bucket array).
pub(crate) const INVALID_POS: u32 = u32::MAX;
@@ -29,6 +29,7 @@ pub(crate) struct CoreHashMap<'a, K, V> {
pub(crate) alloc_limit: u32,
/// The number of currently occupied buckets.
pub(crate) buckets_in_use: u32,
// pub(crate) lock: libc::pthread_mutex_t,
// Unclear what the purpose of this is.
pub(crate) _user_list_head: u32,
}
@@ -109,47 +110,6 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
}
}
/// Get the [`Entry`] associated with a key given hash. This should be used for updates/inserts.
pub fn entry_with_hash(&mut self, key: K, hash: u64) -> Entry<'a, '_, K, V> {
let dict_pos = hash as usize % self.dictionary.len();
let first = self.dictionary[dict_pos];
if first == INVALID_POS {
// no existing entry
return Entry::Vacant(VacantEntry {
map: self,
key,
dict_pos: dict_pos as u32,
});
}
let mut prev_pos = PrevPos::First(dict_pos as u32);
let mut next = first;
loop {
let bucket = &mut self.buckets[next as usize];
let (bucket_key, _bucket_value) = bucket.inner.as_mut().expect("entry is in use");
if *bucket_key == key {
// found existing entry
return Entry::Occupied(OccupiedEntry {
map: self,
_key: key,
prev_pos,
bucket_pos: next,
});
}
if bucket.next == INVALID_POS {
// No existing entry
return Entry::Vacant(VacantEntry {
map: self,
key,
dict_pos: dict_pos as u32,
});
}
prev_pos = PrevPos::Chained(next);
next = bucket.next;
}
}
/// Get number of buckets in map.
pub fn get_num_buckets(&self) -> usize {
self.buckets.len()

View File

@@ -1,11 +1,12 @@
//! Equivalent of [`std::collections::hash_map::Entry`] for this hashmap.
use crate::hash::core::{CoreHashMap, FullError, INVALID_POS};
use crate::sync::{RwLockWriteGuard, ValueWriteGuard};
use std::hash::Hash;
use std::mem;
/// View into an entry in the map (either vacant or occupied).
pub enum Entry<'a, 'b, K, V> {
Occupied(OccupiedEntry<'a, 'b, K, V>),
Vacant(VacantEntry<'a, 'b, K, V>),
@@ -22,10 +23,9 @@ pub(crate) enum PrevPos {
Unknown(u64),
}
/// View into an occupied entry within the map.
pub struct OccupiedEntry<'a, 'b, K, V> {
/// Mutable reference to the map containing this entry.
pub(crate) map: &'b mut CoreHashMap<'a, K, V>,
pub(crate) map: RwLockWriteGuard<'b, CoreHashMap<'a, K, V>>,
/// The key of the occupied entry
pub(crate) _key: K,
/// The index of the previous entry in the chain.
@@ -66,7 +66,7 @@ impl<K, V> OccupiedEntry<'_, '_, K, V> {
/// # Panics
/// Panics if the `prev_pos` field is equal to [`PrevPos::Unknown`]. In practice, this means
/// the entry was obtained via calling something like [`CoreHashMap::entry_at_bucket`].
pub fn remove(self) -> V {
pub fn remove(mut self) -> V {
// If this bucket was queried by index, go ahead and follow its chain from the start.
let prev = if let PrevPos::Unknown(hash) = self.prev_pos {
let dict_idx = hash as usize % self.map.dictionary.len();
@@ -90,15 +90,17 @@ impl<K, V> OccupiedEntry<'_, '_, K, V> {
self.map.dictionary[dict_pos as usize] = bucket.next;
},
PrevPos::Chained(bucket_pos) => {
// println!("we think prev of {} is {bucket_pos}", self.bucket_pos);
self.map.buckets[bucket_pos as usize].next = bucket.next;
},
_ => unreachable!(),
}
// and add it to the freelist
// and add it to the freelist
let free = self.map.free_head;
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
let old_value = bucket.inner.take();
bucket.next = self.map.free_head;
bucket.next = free;
self.map.free_head = self.bucket_pos;
self.map.buckets_in_use -= 1;
@@ -109,7 +111,7 @@ impl<K, V> OccupiedEntry<'_, '_, K, V> {
/// An abstract view into a vacant entry within the map.
pub struct VacantEntry<'a, 'b, K, V> {
/// Mutable reference to the map containing this entry.
pub(crate) map: &'b mut CoreHashMap<'a, K, V>,
pub(crate) map: RwLockWriteGuard<'b, CoreHashMap<'a, K, V>>,
/// The key to be inserted into this entry.
pub(crate) key: K,
/// The position within the dictionary corresponding to the key's hash.
@@ -121,16 +123,17 @@ impl<'b, K: Clone + Hash + Eq, V> VacantEntry<'_, 'b, K, V> {
///
/// # Errors
/// Will return [`FullError`] if there are no unoccupied buckets in the map.
pub fn insert(self, value: V) -> Result<&'b mut V, FullError> {
pub fn insert(mut self, value: V) -> Result<ValueWriteGuard<'b, V>, FullError> {
let pos = self.map.alloc_bucket(self.key, value)?;
if pos == INVALID_POS {
return Err(FullError());
}
let bucket = &mut self.map.buckets[pos as usize];
bucket.next = self.map.dictionary[self.dict_pos as usize];
self.map.buckets[pos as usize].next = self.map.dictionary[self.dict_pos as usize];
self.map.dictionary[self.dict_pos as usize] = pos;
let result = &mut self.map.buckets[pos as usize].inner.as_mut().unwrap().1;
Ok(result)
Ok(RwLockWriteGuard::map(
self.map,
|m| &mut m.buckets[pos as usize].inner.as_mut().unwrap().1
))
}
}

View File

@@ -36,7 +36,7 @@ impl<'a> From<&'a [u8]> for TestKey {
}
fn test_inserts<K: Into<TestKey> + Copy>(keys: &[K]) {
let mut w = HashMapInit::<TestKey, usize>::new_resizeable_named(
let w = HashMapInit::<TestKey, usize>::new_resizeable_named(
100000, 120000, "test_inserts"
).attach_writer();
@@ -190,10 +190,6 @@ fn random_ops() {
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &mut writer, &mut shadow);
if i % 1000 == 0 {
eprintln!("{i} ops processed");
}
}
}
@@ -267,7 +263,7 @@ fn test_idx_remove() {
}
while let Some((key, val)) = shadow.pop_first() {
assert_eq!(writer.get(&key), Some(&val));
assert_eq!(*writer.get(&key).unwrap(), val);
}
}
@@ -326,8 +322,10 @@ fn test_shrink_grow_seq() {
writer.grow(1500).unwrap();
do_random_ops(600, 1500, 0.1, &mut writer, &mut shadow, &mut rng);
eprintln!("Shrinking to 200");
while shadow.len() > 100 {
do_deletes(1, &mut writer, &mut shadow);
}
do_shrink(&mut writer, &mut shadow, 200);
do_deletes(100, &mut writer, &mut shadow);
do_random_ops(50, 1500, 0.25, &mut writer, &mut shadow, &mut rng);
eprintln!("Growing to 10k");
writer.grow(10000).unwrap();
@@ -336,7 +334,7 @@ fn test_shrink_grow_seq() {
#[test]
fn test_bucket_ops() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(
let writer = HashMapInit::<TestKey, usize>::new_resizeable_named(
1000, 1200, "test_bucket_ops"
).attach_writer();
match writer.entry(1.into()) {
@@ -345,21 +343,21 @@ fn test_bucket_ops() {
}
assert_eq!(writer.get_num_buckets_in_use(), 1);
assert_eq!(writer.get_num_buckets(), 1000);
assert_eq!(writer.get(&1.into()), Some(&2));
assert_eq!(*writer.get(&1.into()).unwrap(), 2);
let pos = match writer.entry(1.into()) {
Entry::Occupied(e) => {
assert_eq!(e._key, 1.into());
let pos = e.bucket_pos as usize;
assert_eq!(writer.entry_at_bucket(pos).unwrap()._key, 1.into());
assert_eq!(writer.get_at_bucket(pos), Some(&(1.into(), 2)));
assert_eq!(*writer.get_at_bucket(pos).unwrap(), (1.into(), 2));
pos
},
Entry::Vacant(_) => { panic!("Insert didn't affect entry"); },
};
let ptr: *const usize = writer.get(&1.into()).unwrap();
let ptr: *const usize = &*writer.get(&1.into()).unwrap();
assert_eq!(writer.get_bucket_for_value(ptr), pos);
writer.remove(&1.into());
assert_eq!(writer.get(&1.into()), None);
assert!(writer.get(&1.into()).is_none());
}
#[test]

View File

@@ -2,43 +2,18 @@
use std::mem::MaybeUninit;
use std::ptr::NonNull;
use std::cell::UnsafeCell;
use std::ops::{Deref, DerefMut};
use thiserror::Error;
pub type RwLock<T> = lock_api::RwLock<PthreadRwLock, T>;
pub(crate) type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, PthreadRwLock, T>;
pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, PthreadRwLock, T>;
pub type ValueReadGuard<'a, T> = lock_api::MappedRwLockReadGuard<'a, PthreadRwLock, T>;
pub type ValueWriteGuard<'a, T> = lock_api::MappedRwLockWriteGuard<'a, PthreadRwLock, T>;
/// Shared memory read-write lock.
struct RwLock<'a, T: ?Sized> {
inner: &'a mut libc::pthread_rwlock_t,
data: UnsafeCell<T>,
}
pub struct PthreadRwLock(Option<NonNull<libc::pthread_rwlock_t>>);
/// RAII guard for a read lock.
struct RwLockReadGuard<'a, 'b, T: ?Sized> {
data: NonNull<T>,
lock: &'a RwLock<'b, T>,
}
/// RAII guard for a write lock.
struct RwLockWriteGuard<'a, 'b, T: ?Sized> {
lock: &'a RwLock<'b, T>,
}
// TODO(quantumish): Support poisoning errors?
#[derive(Error, Debug)]
enum RwLockError {
#[error("deadlock detected")]
Deadlock,
#[error("max number of read locks exceeded")]
MaxReadLocks,
#[error("nonblocking operation would block")]
WouldBlock,
}
unsafe impl<T: ?Sized + Send> Send for RwLock<'_, T> {}
unsafe impl<T: ?Sized + Send + Sync> Sync for RwLock<'_, T> {}
impl<'a, T> RwLock<'a, T> {
fn new(lock: &'a mut MaybeUninit<libc::pthread_rwlock_t>, data: T) -> Self {
impl PthreadRwLock {
pub fn new(lock: *mut libc::pthread_rwlock_t) -> Self {
unsafe {
let mut attrs = MaybeUninit::uninit();
// Ignoring return value here - only possible error is OOM.
@@ -48,208 +23,81 @@ impl<'a, T> RwLock<'a, T> {
libc::PTHREAD_PROCESS_SHARED
);
// TODO(quantumish): worth making this function return Result?
libc::pthread_rwlock_init(lock.as_mut_ptr(), attrs.as_mut_ptr());
libc::pthread_rwlock_init(lock, attrs.as_mut_ptr());
// Safety: POSIX specifies that "any function affecting the attributes
// object (including destruction) shall not affect any previously
// initialized read-write locks".
libc::pthread_rwlockattr_destroy(attrs.as_mut_ptr());
Self {
inner: lock.assume_init_mut(),
data: data.into(),
}
}
}
fn read(&self) -> Result<RwLockReadGuard<'_, '_, T>, RwLockError> {
unsafe {
let res = libc::pthread_rwlock_rdlock(self.inner as *const _ as *mut _);
match res {
0 => (),
libc::EINVAL => panic!("failed to properly initialize lock"),
libc::EDEADLK => return Err(RwLockError::Deadlock),
libc::EAGAIN => return Err(RwLockError::MaxReadLocks),
e => panic!("unknown error code returned: {e}")
}
Ok(RwLockReadGuard {
data: NonNull::new_unchecked(self.data.get()),
lock: self
})
}
}
fn try_read(&self) -> Result<RwLockReadGuard<'_, '_, T>, RwLockError> {
unsafe {
let res = libc::pthread_rwlock_tryrdlock(self.inner as *const _ as *mut _);
match res {
0 => (),
libc::EINVAL => panic!("failed to properly initialize lock"),
libc::EDEADLK => return Err(RwLockError::Deadlock),
libc::EAGAIN => return Err(RwLockError::MaxReadLocks),
libc::EBUSY => return Err(RwLockError::WouldBlock),
e => panic!("unknown error code returned: {e}")
}
Ok(RwLockReadGuard {
data: NonNull::new_unchecked(self.data.get()),
lock: self
})
Self(Some(NonNull::new_unchecked(lock)))
}
}
fn write(&self) -> Result<RwLockWriteGuard<'_, '_, T>, RwLockError> {
fn inner(&self) -> NonNull<libc::pthread_rwlock_t> {
match self.0 {
None => panic!("PthreadRwLock constructed badly - something likely used RawMutex::INIT"),
Some(x) => x,
}
}
}
unsafe impl lock_api::RawRwLock for PthreadRwLock {
type GuardMarker = lock_api::GuardSend;
const INIT: Self = Self(None);
fn lock_shared(&self) {
unsafe {
let res = libc::pthread_rwlock_wrlock(self.inner as *const _ as *mut _);
match res {
0 => (),
libc::EINVAL => panic!("failed to properly initialize lock"),
libc::EDEADLK => return Err(RwLockError::Deadlock),
e => panic!("unknown error code returned: {e}")
let res = libc::pthread_rwlock_rdlock(self.inner().as_ptr());
if res != 0 {
panic!("rdlock failed with {res}");
}
}
Ok(RwLockWriteGuard { lock: self })
}
fn try_write(&self) -> Result<RwLockWriteGuard<'_, '_, T>, RwLockError> {
fn try_lock_shared(&self) -> bool {
unsafe {
let res = libc::pthread_rwlock_trywrlock(self.inner as *const _ as *mut _);
let res = libc::pthread_rwlock_tryrdlock(self.inner().as_ptr());
match res {
0 => (),
libc::EINVAL => panic!("failed to properly initialize lock"),
libc::EDEADLK => return Err(RwLockError::Deadlock),
libc::EBUSY => return Err(RwLockError::WouldBlock),
e => panic!("unknown error code returned: {e}")
0 => true,
libc::EAGAIN => false,
o => panic!("try_rdlock failed with {o}")
}
}
Ok(RwLockWriteGuard { lock: self })
}
}
unsafe impl<T: ?Sized + Sync> Sync for RwLockReadGuard<'_, '_, T> {}
unsafe impl<T: ?Sized + Sync> Sync for RwLockWriteGuard<'_, '_, T> {}
impl<T: ?Sized> Deref for RwLockReadGuard<'_, '_, T> {
type Target = T;
fn deref(&self) -> &T {
unsafe { self.data.as_ref() }
}
}
impl<T: ?Sized> Deref for RwLockWriteGuard<'_, '_, T> {
type Target = T;
fn deref(&self) -> &T {
unsafe { &*self.lock.data.get() }
}
}
impl<T: ?Sized> DerefMut for RwLockWriteGuard<'_, '_, T> {
fn deref_mut(&mut self) -> &mut T {
unsafe { &mut *self.lock.data.get() }
}
}
impl<T: ?Sized> Drop for RwLockReadGuard<'_, '_, T> {
fn drop(&mut self) -> () {
let res = unsafe { libc::pthread_rwlock_unlock(
self.lock.inner as *const _ as *mut _
) };
debug_assert!(res == 0);
}
}
impl<T: ?Sized> Drop for RwLockWriteGuard<'_, '_, T> {
fn drop(&mut self) -> () {
let res = unsafe { libc::pthread_rwlock_unlock(
self.lock.inner as *const _ as *mut _
) };
debug_assert!(res == 0);
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::*;
use RwLockError::*;
#[test]
fn test_single_process() {
let mut lock = MaybeUninit::uninit();
let wrapper = RwLock::new(&mut lock, 0);
let mut writer = wrapper.write().unwrap();
assert!(matches!(wrapper.try_write(), Err(Deadlock | WouldBlock)));
assert!(matches!(wrapper.try_read(), Err(Deadlock | WouldBlock)));
*writer = 5;
drop(writer);
let reader = wrapper.read().unwrap();
assert!(matches!(wrapper.try_write(), Err(Deadlock | WouldBlock)));
assert!(matches!(wrapper.read(), Ok(_)));
assert_eq!(*reader, 5);
drop(reader);
assert!(matches!(wrapper.try_write(), Ok(_)));
}
#[test]
fn test_multi_thread() {
let lock = Box::new(MaybeUninit::uninit());
let wrapper = Arc::new(RwLock::new(Box::leak(lock), 0));
let mut writer = wrapper.write().unwrap();
let t1 = {
let wrapper = wrapper.clone();
std::thread::spawn(move || {
let mut writer = wrapper.write().unwrap();
*writer = 20;
})
};
assert_eq!(*writer, 0);
*writer = 10;
assert_eq!(*writer, 10);
drop(writer);
t1.join().unwrap();
let mut writer = wrapper.write().unwrap();
assert_eq!(*writer, 20);
drop(writer);
let mut handles = vec![];
for _ in 0..5 {
handles.push({
let wrapper = wrapper.clone();
std::thread::spawn(move || {
let reader = wrapper.read().unwrap();
assert_eq!(*reader, 20);
})
});
fn lock_exclusive(&self) {
unsafe {
let res = libc::pthread_rwlock_wrlock(self.inner().as_ptr());
if res != 0 {
panic!("wrlock failed with {res}");
}
}
for h in handles {
h.join().unwrap();
}
let writer = wrapper.write().unwrap();
assert_eq!(*writer, 20);
}
// // TODO(quantumish): Terrible time-based synchronization, fix me.
// #[test]
// fn test_multi_process() {
// let max_size = 100;
// let init_struct = crate::shmem::ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
// let ptr = init_struct.data_ptr.as_ptr();
// let lock: &mut _ = unsafe { ptr.add(
// ptr.align_offset(std::mem::align_of::<MaybeUninit<libc::pthread_rwlock_t>>())
// ).cast::<MaybeUninit<libc::pthread_rwlock_t>>().as_mut().unwrap() } ;
// let wrapper = RwLock::new(lock, 0);
fn try_lock_exclusive(&self) -> bool {
unsafe {
let res = libc::pthread_rwlock_trywrlock(self.inner().as_ptr());
match res {
0 => true,
libc::EAGAIN => false,
o => panic!("try_wrlock failed with {o}")
}
}
}
// let fork_result = unsafe { nix::unistd::fork().unwrap() };
// if !fork_result.is_parent() {
// let mut writer = wrapper.write().unwrap();
// std::thread::sleep(std::time::Duration::from_secs(5));
// *writer = 2;
// } else {
// std::thread::sleep(std::time::Duration::from_secs(1));
// assert!(matches!(wrapper.try_write(), Err(WouldBlock)));
// std::thread::sleep(std::time::Duration::from_secs(10));
// let writer = wrapper.try_write().unwrap();
// assert_eq!(*writer, 2);
// }
// }
unsafe fn unlock_exclusive(&self) {
unsafe {
let res = libc::pthread_rwlock_unlock(self.inner().as_ptr());
if res != 0 {
panic!("unlock failed with {res}");
}
}
}
unsafe fn unlock_shared(&self) {
unsafe {
let res = libc::pthread_rwlock_unlock(self.inner().as_ptr());
if res != 0 {
panic!("unlock failed with {res}");
}
}
}
}