mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-27 01:50:38 +00:00
Switch to neon_shmem::sync lock_api and integrate into hashmap
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -4038,6 +4038,7 @@ dependencies = [
|
||||
"foldhash",
|
||||
"hashbrown 0.15.4 (git+https://github.com/quantumish/hashbrown.git?rev=6610e6d)",
|
||||
"libc",
|
||||
"lock_api",
|
||||
"nix 0.30.1",
|
||||
"rand 0.9.1",
|
||||
"rand_distr 0.5.1",
|
||||
|
||||
@@ -11,6 +11,7 @@ workspace_hack = { version = "0.1", path = "../../workspace_hack" }
|
||||
rustc-hash = { version = "2.1.1" }
|
||||
rand = "0.9.1"
|
||||
libc.workspace = true
|
||||
lock_api = "0.4.13"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = { workspace = true, features = ["html_reports"] }
|
||||
|
||||
@@ -16,9 +16,9 @@
|
||||
|
||||
use std::hash::{Hash, BuildHasher};
|
||||
use std::mem::MaybeUninit;
|
||||
use std::default::Default;
|
||||
|
||||
use crate::{shmem, shmem::ShmemHandle};
|
||||
use crate::{shmem, sync::*};
|
||||
use crate::shmem::ShmemHandle;
|
||||
|
||||
mod core;
|
||||
pub mod entry;
|
||||
@@ -27,15 +27,14 @@ pub mod entry;
|
||||
mod tests;
|
||||
|
||||
use core::{Bucket, CoreHashMap, INVALID_POS};
|
||||
use entry::{Entry, OccupiedEntry};
|
||||
use entry::{Entry, OccupiedEntry, VacantEntry, PrevPos};
|
||||
|
||||
/// Builder for a [`HashMapAccess`].
|
||||
#[must_use]
|
||||
pub struct HashMapInit<'a, K, V, S = rustc_hash::FxBuildHasher> {
|
||||
shmem_handle: Option<ShmemHandle>,
|
||||
shared_ptr: *mut HashMapShared<'a, K, V>,
|
||||
shared_ptr: *mut RwLock<HashMapShared<'a, K, V>>,
|
||||
shared_size: usize,
|
||||
shrink_mode: HashMapShrinkMode,
|
||||
hasher: S,
|
||||
num_buckets: u32,
|
||||
}
|
||||
@@ -45,28 +44,6 @@ pub struct HashMapAccess<'a, K, V, S = rustc_hash::FxBuildHasher> {
|
||||
shmem_handle: Option<ShmemHandle>,
|
||||
shared_ptr: *mut HashMapShared<'a, K, V>,
|
||||
hasher: S,
|
||||
shrink_mode: HashMapShrinkMode,
|
||||
}
|
||||
|
||||
/// Enum specifying what behavior to have surrounding occupied entries in what is
|
||||
/// about-to-be-shrinked space during a call to [`HashMapAccess::finish_shrink`].
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub enum HashMapShrinkMode {
|
||||
/// Remap entry to the range of buckets that will remain after shrinking.
|
||||
///
|
||||
/// Requires that caller has left enough room within the map such that this is possible.
|
||||
Remap,
|
||||
/// Remove any entries remaining in soon to be deallocated space.
|
||||
///
|
||||
/// Only really useful if you legitimately do not care what entries are removed.
|
||||
/// Should primarily be used for testing.
|
||||
Remove,
|
||||
}
|
||||
|
||||
impl Default for HashMapShrinkMode {
|
||||
fn default() -> Self {
|
||||
Self::Remap
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<K: Sync, V: Sync, S> Sync for HashMapAccess<'_, K, V, S> {}
|
||||
@@ -80,14 +57,9 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
|
||||
shared_ptr: self.shared_ptr,
|
||||
shared_size: self.shared_size,
|
||||
num_buckets: self.num_buckets,
|
||||
shrink_mode: self.shrink_mode,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_shrink_mode(self, mode: HashMapShrinkMode) -> Self {
|
||||
Self { shrink_mode: mode, ..self }
|
||||
}
|
||||
|
||||
/// Loosely (over)estimate the size needed to store a hash table with `num_buckets` buckets.
|
||||
pub fn estimate_size(num_buckets: u32) -> usize {
|
||||
// add some margin to cover alignment etc.
|
||||
@@ -96,13 +68,17 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
|
||||
|
||||
/// Initialize a table for writing.
|
||||
pub fn attach_writer(self) -> HashMapAccess<'a, K, V, S> {
|
||||
// carve out the HashMapShared struct from the area.
|
||||
let mut ptr: *mut u8 = self.shared_ptr.cast();
|
||||
let end_ptr: *mut u8 = unsafe { ptr.add(self.shared_size) };
|
||||
ptr = unsafe { ptr.add(ptr.align_offset(align_of::<HashMapShared<K, V>>())) };
|
||||
let shared_ptr: *mut HashMapShared<K, V> = ptr.cast();
|
||||
ptr = unsafe { ptr.add(size_of::<HashMapShared<K, V>>()) };
|
||||
|
||||
// carve out area for the One Big Lock (TM) and the HashMapShared.
|
||||
ptr = unsafe { ptr.add(ptr.align_offset(align_of::<libc::pthread_rwlock_t>())) };
|
||||
let raw_lock_ptr = ptr;
|
||||
ptr = unsafe { ptr.add(size_of::<libc::pthread_rwlock_t>()) };
|
||||
ptr = unsafe { ptr.add(ptr.align_offset(align_of::<HashMapShared<K, V>>())) };
|
||||
let shared_ptr: *mut HashMapShared<K, V> = ptr.cast();
|
||||
ptr = unsafe { ptr.add(size_of::<HashMapShared<K, V>>()) };
|
||||
|
||||
// carve out the buckets
|
||||
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<core::Bucket<K, V>>())) };
|
||||
let buckets_ptr = ptr;
|
||||
@@ -121,14 +97,14 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
|
||||
std::slice::from_raw_parts_mut(dictionary_ptr.cast(), dictionary_size as usize)
|
||||
};
|
||||
let hashmap = CoreHashMap::new(buckets, dictionary);
|
||||
unsafe {
|
||||
std::ptr::write(shared_ptr, HashMapShared { inner: hashmap });
|
||||
}
|
||||
let lock = RwLock::from_raw(PthreadRwLock::new(raw_lock_ptr.cast()), hashmap);
|
||||
unsafe {
|
||||
std::ptr::write(shared_ptr, lock);
|
||||
}
|
||||
|
||||
HashMapAccess {
|
||||
shmem_handle: self.shmem_handle,
|
||||
shared_ptr: self.shared_ptr,
|
||||
shrink_mode: self.shrink_mode,
|
||||
shared_ptr,
|
||||
hasher: self.hasher,
|
||||
}
|
||||
}
|
||||
@@ -145,14 +121,13 @@ impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
|
||||
/// relies on the memory layout! The data structures are laid out in the contiguous shared memory
|
||||
/// area as follows:
|
||||
///
|
||||
/// [`libc::pthread_rwlock_t`]
|
||||
/// [`HashMapShared`]
|
||||
/// [buckets]
|
||||
/// [dictionary]
|
||||
///
|
||||
/// In between the above parts, there can be padding bytes to align the parts correctly.
|
||||
struct HashMapShared<'a, K, V> {
|
||||
inner: CoreHashMap<'a, K, V>
|
||||
}
|
||||
type HashMapShared<'a, K, V> = RwLock<CoreHashMap<'a, K, V>>;
|
||||
|
||||
impl<'a, K, V> HashMapInit<'a, K, V, rustc_hash::FxBuildHasher>
|
||||
where
|
||||
@@ -168,7 +143,6 @@ where
|
||||
shmem_handle: None,
|
||||
shared_ptr: area.as_mut_ptr().cast(),
|
||||
shared_size: area.len(),
|
||||
shrink_mode: HashMapShrinkMode::default(),
|
||||
hasher: rustc_hash::FxBuildHasher,
|
||||
}
|
||||
}
|
||||
@@ -187,7 +161,6 @@ where
|
||||
shared_ptr: shmem.data_ptr.as_ptr().cast(),
|
||||
shmem_handle: Some(shmem),
|
||||
shared_size: size,
|
||||
shrink_mode: HashMapShrinkMode::default(),
|
||||
hasher: rustc_hash::FxBuildHasher
|
||||
}
|
||||
}
|
||||
@@ -204,7 +177,6 @@ where
|
||||
shared_ptr: shmem.data_ptr.as_ptr().cast(),
|
||||
shmem_handle: Some(shmem),
|
||||
shared_size: size,
|
||||
shrink_mode: HashMapShrinkMode::default(),
|
||||
hasher: rustc_hash::FxBuildHasher
|
||||
}
|
||||
}
|
||||
@@ -229,25 +201,64 @@ where
|
||||
self.hasher.hash_one(key)
|
||||
}
|
||||
|
||||
fn entry_with_hash(&self, key: K, hash: u64) -> Entry<'a, '_, K, V> {
|
||||
let mut map = unsafe { self.shared_ptr.as_ref() }.unwrap().write();
|
||||
let dict_pos = hash as usize % map.dictionary.len();
|
||||
let first = map.dictionary[dict_pos];
|
||||
if first == INVALID_POS {
|
||||
// no existing entry
|
||||
return Entry::Vacant(VacantEntry {
|
||||
map,
|
||||
key,
|
||||
dict_pos: dict_pos as u32,
|
||||
});
|
||||
}
|
||||
|
||||
let mut prev_pos = PrevPos::First(dict_pos as u32);
|
||||
let mut next = first;
|
||||
loop {
|
||||
let bucket = &mut map.buckets[next as usize];
|
||||
let (bucket_key, _bucket_value) = bucket.inner.as_mut().expect("entry is in use");
|
||||
if *bucket_key == key {
|
||||
// found existing entry
|
||||
return Entry::Occupied(OccupiedEntry {
|
||||
map,
|
||||
_key: key,
|
||||
prev_pos,
|
||||
bucket_pos: next,
|
||||
});
|
||||
}
|
||||
|
||||
if bucket.next == INVALID_POS {
|
||||
// No existing entry
|
||||
return Entry::Vacant(VacantEntry {
|
||||
map,
|
||||
key,
|
||||
dict_pos: dict_pos as u32,
|
||||
});
|
||||
}
|
||||
prev_pos = PrevPos::Chained(next);
|
||||
next = bucket.next;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a reference to the corresponding value for a key.
|
||||
pub fn get<'e>(&'e self, key: &K) -> Option<&'e V> {
|
||||
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
|
||||
pub fn get<'e>(&'e self, key: &K) -> Option<ValueReadGuard<'e, V>> {
|
||||
let hash = self.get_hash_value(key);
|
||||
map.inner.get_with_hash(key, hash)
|
||||
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
|
||||
RwLockReadGuard::try_map(map, |m| m.get_with_hash(key, hash)).ok()
|
||||
}
|
||||
|
||||
/// Get a reference to the entry containing a key.
|
||||
pub fn entry(&self, key: K) -> Entry<'a, '_, K, V> {
|
||||
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
|
||||
let hash = self.get_hash_value(&key);
|
||||
map.inner.entry_with_hash(key, hash)
|
||||
self.entry_with_hash(key, hash)
|
||||
}
|
||||
|
||||
/// Remove a key given its hash. Returns the associated value if it existed.
|
||||
pub fn remove(&self, key: &K) -> Option<V> {
|
||||
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
|
||||
let hash = self.get_hash_value(&key);
|
||||
match map.inner.entry_with_hash(key.clone(), hash) {
|
||||
match self.entry_with_hash(key.clone(), hash) {
|
||||
Entry::Occupied(e) => Some(e.remove()),
|
||||
Entry::Vacant(_) => None
|
||||
}
|
||||
@@ -258,12 +269,11 @@ where
|
||||
/// # Errors
|
||||
/// Will return [`core::FullError`] if there is no more space left in the map.
|
||||
pub fn insert(&self, key: K, value: V) -> Result<Option<V>, core::FullError> {
|
||||
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
|
||||
let hash = self.get_hash_value(&key);
|
||||
match map.inner.entry_with_hash(key.clone(), hash) {
|
||||
match self.entry_with_hash(key.clone(), hash) {
|
||||
Entry::Occupied(mut e) => Ok(Some(e.insert(value))),
|
||||
Entry::Vacant(e) => {
|
||||
e.insert(value)?;
|
||||
_ = e.insert(value)?;
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
@@ -275,13 +285,12 @@ where
|
||||
/// due to the [`OccupiedEntry`] type owning the key and also a hash of the key in order
|
||||
/// to enable repairing the hash chain if the entry is removed.
|
||||
pub fn entry_at_bucket(&self, pos: usize) -> Option<OccupiedEntry<'a, '_, K, V>> {
|
||||
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
|
||||
let inner = &mut map.inner;
|
||||
if pos >= inner.buckets.len() {
|
||||
let map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
|
||||
if pos >= map.buckets.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let entry = inner.buckets[pos].inner.as_ref();
|
||||
let entry = map.buckets[pos].inner.as_ref();
|
||||
match entry {
|
||||
Some((key, _)) => Some(OccupiedEntry {
|
||||
_key: key.clone(),
|
||||
@@ -289,7 +298,7 @@ where
|
||||
prev_pos: entry::PrevPos::Unknown(
|
||||
self.get_hash_value(&key)
|
||||
),
|
||||
map: inner,
|
||||
map,
|
||||
}),
|
||||
_ => None,
|
||||
}
|
||||
@@ -297,8 +306,8 @@ where
|
||||
|
||||
/// Returns the number of buckets in the table.
|
||||
pub fn get_num_buckets(&self) -> usize {
|
||||
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
|
||||
map.inner.get_num_buckets()
|
||||
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
|
||||
map.get_num_buckets()
|
||||
}
|
||||
|
||||
/// Return the key and value stored in bucket with given index. This can be used to
|
||||
@@ -306,38 +315,35 @@ where
|
||||
// TODO: An Iterator might be nicer. The communicator's clock algorithm needs to
|
||||
// _slowly_ iterate through all buckets with its clock hand, without holding a lock.
|
||||
// If we switch to an Iterator, it must not hold the lock.
|
||||
pub fn get_at_bucket(&self, pos: usize) -> Option<&(K, V)> {
|
||||
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
|
||||
|
||||
if pos >= map.inner.buckets.len() {
|
||||
pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<(K, V)>> {
|
||||
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
|
||||
if pos >= map.buckets.len() {
|
||||
return None;
|
||||
}
|
||||
let bucket = &map.inner.buckets[pos];
|
||||
bucket.inner.as_ref()
|
||||
RwLockReadGuard::try_map(map, |m| m.buckets[pos].inner.as_ref()).ok()
|
||||
}
|
||||
|
||||
/// Returns the index of the bucket a given value corresponds to.
|
||||
pub fn get_bucket_for_value(&self, val_ptr: *const V) -> usize {
|
||||
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
|
||||
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
|
||||
|
||||
let origin = map.inner.buckets.as_ptr();
|
||||
let origin = map.buckets.as_ptr();
|
||||
let idx = (val_ptr as usize - origin as usize) / size_of::<Bucket<K, V>>();
|
||||
assert!(idx < map.inner.buckets.len());
|
||||
assert!(idx < map.buckets.len());
|
||||
|
||||
idx
|
||||
}
|
||||
|
||||
/// Returns the number of occupied buckets in the table.
|
||||
pub fn get_num_buckets_in_use(&self) -> usize {
|
||||
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
|
||||
map.inner.buckets_in_use as usize
|
||||
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
|
||||
map.buckets_in_use as usize
|
||||
}
|
||||
|
||||
/// Clears all entries in a table. Does not reset any shrinking operations.
|
||||
pub fn clear(&self) {
|
||||
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
|
||||
let inner = &mut map.inner;
|
||||
inner.clear();
|
||||
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
|
||||
map.clear();
|
||||
}
|
||||
|
||||
/// Perform an in-place rehash of some region (0..`rehash_buckets`) of the table and reset
|
||||
@@ -389,13 +395,12 @@ where
|
||||
|
||||
/// Rehash the map without growing or shrinking.
|
||||
pub fn shuffle(&self) {
|
||||
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
|
||||
let inner = &mut map.inner;
|
||||
let num_buckets = inner.get_num_buckets() as u32;
|
||||
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
|
||||
let num_buckets = map.get_num_buckets() as u32;
|
||||
let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
|
||||
let end_ptr: *mut u8 = unsafe { self.shared_ptr.byte_add(size_bytes).cast() };
|
||||
let buckets_ptr = inner.buckets.as_mut_ptr();
|
||||
self.rehash_dict(inner, buckets_ptr, end_ptr, num_buckets, num_buckets);
|
||||
let buckets_ptr = map.buckets.as_mut_ptr();
|
||||
self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets);
|
||||
}
|
||||
|
||||
/// Grow the number of buckets within the table.
|
||||
@@ -409,10 +414,9 @@ where
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an [`shmem::Error`] if any errors occur resizing the memory region.
|
||||
pub fn grow(&mut self, num_buckets: u32) -> Result<(), shmem::Error> {
|
||||
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
|
||||
let inner = &mut map.inner;
|
||||
let old_num_buckets = inner.buckets.len() as u32;
|
||||
pub fn grow(&self, num_buckets: u32) -> Result<(), shmem::Error> {
|
||||
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
|
||||
let old_num_buckets = map.buckets.len() as u32;
|
||||
|
||||
assert!(num_buckets >= old_num_buckets, "grow called with a smaller number of buckets");
|
||||
if num_buckets == old_num_buckets {
|
||||
@@ -429,7 +433,7 @@ where
|
||||
|
||||
// Initialize new buckets. The new buckets are linked to the free list.
|
||||
// NB: This overwrites the dictionary!
|
||||
let buckets_ptr = inner.buckets.as_mut_ptr();
|
||||
let buckets_ptr = map.buckets.as_mut_ptr();
|
||||
unsafe {
|
||||
for i in old_num_buckets..num_buckets {
|
||||
let bucket = buckets_ptr.add(i as usize);
|
||||
@@ -437,15 +441,15 @@ where
|
||||
next: if i < num_buckets-1 {
|
||||
i + 1
|
||||
} else {
|
||||
inner.free_head
|
||||
map.free_head
|
||||
},
|
||||
inner: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
self.rehash_dict(inner, buckets_ptr, end_ptr, num_buckets, old_num_buckets);
|
||||
inner.free_head = old_num_buckets;
|
||||
self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, old_num_buckets);
|
||||
map.free_head = old_num_buckets;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -456,22 +460,22 @@ where
|
||||
/// Panics if called on a map initialized with [`HashMapInit::with_fixed`] or if `num_buckets` is
|
||||
/// greater than the number of buckets in the map.
|
||||
pub fn begin_shrink(&mut self, num_buckets: u32) {
|
||||
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
|
||||
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
|
||||
assert!(
|
||||
num_buckets <= map.inner.get_num_buckets() as u32,
|
||||
num_buckets <= map.get_num_buckets() as u32,
|
||||
"shrink called with a larger number of buckets"
|
||||
);
|
||||
_ = self
|
||||
.shmem_handle
|
||||
.as_ref()
|
||||
.expect("shrink called on a fixed-size hash table");
|
||||
map.inner.alloc_limit = num_buckets;
|
||||
map.alloc_limit = num_buckets;
|
||||
}
|
||||
|
||||
/// If a shrink operation is underway, returns the target size of the map. Otherwise, returns None.
|
||||
pub fn shrink_goal(&self) -> Option<usize> {
|
||||
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
|
||||
let goal = map.inner.alloc_limit;
|
||||
let map = unsafe { self.shared_ptr.as_mut() }.unwrap().read();
|
||||
let goal = map.alloc_limit;
|
||||
if goal == INVALID_POS { None } else { Some(goal as usize) }
|
||||
}
|
||||
|
||||
@@ -487,31 +491,28 @@ where
|
||||
/// # Errors
|
||||
/// Returns an [`shmem::Error`] if any errors occur resizing the memory region.
|
||||
pub fn finish_shrink(&self) -> Result<(), shmem::Error> {
|
||||
let map = unsafe { self.shared_ptr.as_mut() }.unwrap();
|
||||
let inner = &mut map.inner;
|
||||
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
|
||||
assert!(
|
||||
inner.alloc_limit != INVALID_POS,
|
||||
map.alloc_limit != INVALID_POS,
|
||||
"called finish_shrink when no shrink is in progress"
|
||||
);
|
||||
|
||||
let num_buckets = inner.alloc_limit;
|
||||
let num_buckets = map.alloc_limit;
|
||||
|
||||
if inner.get_num_buckets() == num_buckets as usize {
|
||||
if map.get_num_buckets() == num_buckets as usize {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if self.shrink_mode == HashMapShrinkMode::Remap {
|
||||
assert!(
|
||||
inner.buckets_in_use <= num_buckets,
|
||||
"called finish_shrink before enough entries were removed"
|
||||
);
|
||||
|
||||
for i in (num_buckets as usize)..inner.buckets.len() {
|
||||
if let Some((k, v)) = inner.buckets[i].inner.take() {
|
||||
// alloc_bucket increases count, so need to decrease since we're just moving
|
||||
inner.buckets_in_use -= 1;
|
||||
inner.alloc_bucket(k, v).unwrap();
|
||||
}
|
||||
assert!(
|
||||
map.buckets_in_use <= num_buckets,
|
||||
"called finish_shrink before enough entries were removed"
|
||||
);
|
||||
|
||||
for i in (num_buckets as usize)..map.buckets.len() {
|
||||
if let Some((k, v)) = map.buckets[i].inner.take() {
|
||||
// alloc_bucket increases count, so need to decrease since we're just moving
|
||||
map.buckets_in_use -= 1;
|
||||
map.alloc_bucket(k, v).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -523,9 +524,9 @@ where
|
||||
let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
|
||||
shmem_handle.set_size(size_bytes)?;
|
||||
let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
|
||||
let buckets_ptr = inner.buckets.as_mut_ptr();
|
||||
self.rehash_dict(inner, buckets_ptr, end_ptr, num_buckets, num_buckets);
|
||||
inner.alloc_limit = INVALID_POS;
|
||||
let buckets_ptr = map.buckets.as_mut_ptr();
|
||||
self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets);
|
||||
map.alloc_limit = INVALID_POS;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
use std::hash::Hash;
|
||||
use std::mem::MaybeUninit;
|
||||
|
||||
use crate::hash::entry::{Entry, OccupiedEntry, PrevPos, VacantEntry};
|
||||
use crate::hash::entry::*;
|
||||
|
||||
/// Invalid position within the map (either within the dictionary or bucket array).
|
||||
pub(crate) const INVALID_POS: u32 = u32::MAX;
|
||||
@@ -29,6 +29,7 @@ pub(crate) struct CoreHashMap<'a, K, V> {
|
||||
pub(crate) alloc_limit: u32,
|
||||
/// The number of currently occupied buckets.
|
||||
pub(crate) buckets_in_use: u32,
|
||||
// pub(crate) lock: libc::pthread_mutex_t,
|
||||
// Unclear what the purpose of this is.
|
||||
pub(crate) _user_list_head: u32,
|
||||
}
|
||||
@@ -109,47 +110,6 @@ impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the [`Entry`] associated with a key given hash. This should be used for updates/inserts.
|
||||
pub fn entry_with_hash(&mut self, key: K, hash: u64) -> Entry<'a, '_, K, V> {
|
||||
let dict_pos = hash as usize % self.dictionary.len();
|
||||
let first = self.dictionary[dict_pos];
|
||||
if first == INVALID_POS {
|
||||
// no existing entry
|
||||
return Entry::Vacant(VacantEntry {
|
||||
map: self,
|
||||
key,
|
||||
dict_pos: dict_pos as u32,
|
||||
});
|
||||
}
|
||||
|
||||
let mut prev_pos = PrevPos::First(dict_pos as u32);
|
||||
let mut next = first;
|
||||
loop {
|
||||
let bucket = &mut self.buckets[next as usize];
|
||||
let (bucket_key, _bucket_value) = bucket.inner.as_mut().expect("entry is in use");
|
||||
if *bucket_key == key {
|
||||
// found existing entry
|
||||
return Entry::Occupied(OccupiedEntry {
|
||||
map: self,
|
||||
_key: key,
|
||||
prev_pos,
|
||||
bucket_pos: next,
|
||||
});
|
||||
}
|
||||
|
||||
if bucket.next == INVALID_POS {
|
||||
// No existing entry
|
||||
return Entry::Vacant(VacantEntry {
|
||||
map: self,
|
||||
key,
|
||||
dict_pos: dict_pos as u32,
|
||||
});
|
||||
}
|
||||
prev_pos = PrevPos::Chained(next);
|
||||
next = bucket.next;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get number of buckets in map.
|
||||
pub fn get_num_buckets(&self) -> usize {
|
||||
self.buckets.len()
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
//! Equivalent of [`std::collections::hash_map::Entry`] for this hashmap.
|
||||
|
||||
use crate::hash::core::{CoreHashMap, FullError, INVALID_POS};
|
||||
use crate::sync::{RwLockWriteGuard, ValueWriteGuard};
|
||||
|
||||
use std::hash::Hash;
|
||||
use std::mem;
|
||||
|
||||
/// View into an entry in the map (either vacant or occupied).
|
||||
|
||||
pub enum Entry<'a, 'b, K, V> {
|
||||
Occupied(OccupiedEntry<'a, 'b, K, V>),
|
||||
Vacant(VacantEntry<'a, 'b, K, V>),
|
||||
@@ -22,10 +23,9 @@ pub(crate) enum PrevPos {
|
||||
Unknown(u64),
|
||||
}
|
||||
|
||||
/// View into an occupied entry within the map.
|
||||
pub struct OccupiedEntry<'a, 'b, K, V> {
|
||||
/// Mutable reference to the map containing this entry.
|
||||
pub(crate) map: &'b mut CoreHashMap<'a, K, V>,
|
||||
pub(crate) map: RwLockWriteGuard<'b, CoreHashMap<'a, K, V>>,
|
||||
/// The key of the occupied entry
|
||||
pub(crate) _key: K,
|
||||
/// The index of the previous entry in the chain.
|
||||
@@ -66,7 +66,7 @@ impl<K, V> OccupiedEntry<'_, '_, K, V> {
|
||||
/// # Panics
|
||||
/// Panics if the `prev_pos` field is equal to [`PrevPos::Unknown`]. In practice, this means
|
||||
/// the entry was obtained via calling something like [`CoreHashMap::entry_at_bucket`].
|
||||
pub fn remove(self) -> V {
|
||||
pub fn remove(mut self) -> V {
|
||||
// If this bucket was queried by index, go ahead and follow its chain from the start.
|
||||
let prev = if let PrevPos::Unknown(hash) = self.prev_pos {
|
||||
let dict_idx = hash as usize % self.map.dictionary.len();
|
||||
@@ -90,15 +90,17 @@ impl<K, V> OccupiedEntry<'_, '_, K, V> {
|
||||
self.map.dictionary[dict_pos as usize] = bucket.next;
|
||||
},
|
||||
PrevPos::Chained(bucket_pos) => {
|
||||
// println!("we think prev of {} is {bucket_pos}", self.bucket_pos);
|
||||
self.map.buckets[bucket_pos as usize].next = bucket.next;
|
||||
},
|
||||
_ => unreachable!(),
|
||||
}
|
||||
|
||||
// and add it to the freelist
|
||||
// and add it to the freelist
|
||||
let free = self.map.free_head;
|
||||
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
|
||||
let old_value = bucket.inner.take();
|
||||
bucket.next = self.map.free_head;
|
||||
bucket.next = free;
|
||||
self.map.free_head = self.bucket_pos;
|
||||
self.map.buckets_in_use -= 1;
|
||||
|
||||
@@ -109,7 +111,7 @@ impl<K, V> OccupiedEntry<'_, '_, K, V> {
|
||||
/// An abstract view into a vacant entry within the map.
|
||||
pub struct VacantEntry<'a, 'b, K, V> {
|
||||
/// Mutable reference to the map containing this entry.
|
||||
pub(crate) map: &'b mut CoreHashMap<'a, K, V>,
|
||||
pub(crate) map: RwLockWriteGuard<'b, CoreHashMap<'a, K, V>>,
|
||||
/// The key to be inserted into this entry.
|
||||
pub(crate) key: K,
|
||||
/// The position within the dictionary corresponding to the key's hash.
|
||||
@@ -121,16 +123,17 @@ impl<'b, K: Clone + Hash + Eq, V> VacantEntry<'_, 'b, K, V> {
|
||||
///
|
||||
/// # Errors
|
||||
/// Will return [`FullError`] if there are no unoccupied buckets in the map.
|
||||
pub fn insert(self, value: V) -> Result<&'b mut V, FullError> {
|
||||
pub fn insert(mut self, value: V) -> Result<ValueWriteGuard<'b, V>, FullError> {
|
||||
let pos = self.map.alloc_bucket(self.key, value)?;
|
||||
if pos == INVALID_POS {
|
||||
return Err(FullError());
|
||||
}
|
||||
let bucket = &mut self.map.buckets[pos as usize];
|
||||
bucket.next = self.map.dictionary[self.dict_pos as usize];
|
||||
self.map.buckets[pos as usize].next = self.map.dictionary[self.dict_pos as usize];
|
||||
self.map.dictionary[self.dict_pos as usize] = pos;
|
||||
|
||||
let result = &mut self.map.buckets[pos as usize].inner.as_mut().unwrap().1;
|
||||
Ok(result)
|
||||
Ok(RwLockWriteGuard::map(
|
||||
self.map,
|
||||
|m| &mut m.buckets[pos as usize].inner.as_mut().unwrap().1
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ impl<'a> From<&'a [u8]> for TestKey {
|
||||
}
|
||||
|
||||
fn test_inserts<K: Into<TestKey> + Copy>(keys: &[K]) {
|
||||
let mut w = HashMapInit::<TestKey, usize>::new_resizeable_named(
|
||||
let w = HashMapInit::<TestKey, usize>::new_resizeable_named(
|
||||
100000, 120000, "test_inserts"
|
||||
).attach_writer();
|
||||
|
||||
@@ -190,10 +190,6 @@ fn random_ops() {
|
||||
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
|
||||
|
||||
apply_op(&op, &mut writer, &mut shadow);
|
||||
|
||||
if i % 1000 == 0 {
|
||||
eprintln!("{i} ops processed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -267,7 +263,7 @@ fn test_idx_remove() {
|
||||
|
||||
}
|
||||
while let Some((key, val)) = shadow.pop_first() {
|
||||
assert_eq!(writer.get(&key), Some(&val));
|
||||
assert_eq!(*writer.get(&key).unwrap(), val);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -326,8 +322,10 @@ fn test_shrink_grow_seq() {
|
||||
writer.grow(1500).unwrap();
|
||||
do_random_ops(600, 1500, 0.1, &mut writer, &mut shadow, &mut rng);
|
||||
eprintln!("Shrinking to 200");
|
||||
while shadow.len() > 100 {
|
||||
do_deletes(1, &mut writer, &mut shadow);
|
||||
}
|
||||
do_shrink(&mut writer, &mut shadow, 200);
|
||||
do_deletes(100, &mut writer, &mut shadow);
|
||||
do_random_ops(50, 1500, 0.25, &mut writer, &mut shadow, &mut rng);
|
||||
eprintln!("Growing to 10k");
|
||||
writer.grow(10000).unwrap();
|
||||
@@ -336,7 +334,7 @@ fn test_shrink_grow_seq() {
|
||||
|
||||
#[test]
|
||||
fn test_bucket_ops() {
|
||||
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(
|
||||
let writer = HashMapInit::<TestKey, usize>::new_resizeable_named(
|
||||
1000, 1200, "test_bucket_ops"
|
||||
).attach_writer();
|
||||
match writer.entry(1.into()) {
|
||||
@@ -345,21 +343,21 @@ fn test_bucket_ops() {
|
||||
}
|
||||
assert_eq!(writer.get_num_buckets_in_use(), 1);
|
||||
assert_eq!(writer.get_num_buckets(), 1000);
|
||||
assert_eq!(writer.get(&1.into()), Some(&2));
|
||||
assert_eq!(*writer.get(&1.into()).unwrap(), 2);
|
||||
let pos = match writer.entry(1.into()) {
|
||||
Entry::Occupied(e) => {
|
||||
assert_eq!(e._key, 1.into());
|
||||
let pos = e.bucket_pos as usize;
|
||||
assert_eq!(writer.entry_at_bucket(pos).unwrap()._key, 1.into());
|
||||
assert_eq!(writer.get_at_bucket(pos), Some(&(1.into(), 2)));
|
||||
assert_eq!(*writer.get_at_bucket(pos).unwrap(), (1.into(), 2));
|
||||
pos
|
||||
},
|
||||
Entry::Vacant(_) => { panic!("Insert didn't affect entry"); },
|
||||
};
|
||||
let ptr: *const usize = writer.get(&1.into()).unwrap();
|
||||
let ptr: *const usize = &*writer.get(&1.into()).unwrap();
|
||||
assert_eq!(writer.get_bucket_for_value(ptr), pos);
|
||||
writer.remove(&1.into());
|
||||
assert_eq!(writer.get(&1.into()), None);
|
||||
assert!(writer.get(&1.into()).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -2,43 +2,18 @@
|
||||
|
||||
use std::mem::MaybeUninit;
|
||||
use std::ptr::NonNull;
|
||||
use std::cell::UnsafeCell;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use thiserror::Error;
|
||||
|
||||
pub type RwLock<T> = lock_api::RwLock<PthreadRwLock, T>;
|
||||
pub(crate) type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, PthreadRwLock, T>;
|
||||
pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, PthreadRwLock, T>;
|
||||
pub type ValueReadGuard<'a, T> = lock_api::MappedRwLockReadGuard<'a, PthreadRwLock, T>;
|
||||
pub type ValueWriteGuard<'a, T> = lock_api::MappedRwLockWriteGuard<'a, PthreadRwLock, T>;
|
||||
|
||||
/// Shared memory read-write lock.
|
||||
struct RwLock<'a, T: ?Sized> {
|
||||
inner: &'a mut libc::pthread_rwlock_t,
|
||||
data: UnsafeCell<T>,
|
||||
}
|
||||
pub struct PthreadRwLock(Option<NonNull<libc::pthread_rwlock_t>>);
|
||||
|
||||
/// RAII guard for a read lock.
|
||||
struct RwLockReadGuard<'a, 'b, T: ?Sized> {
|
||||
data: NonNull<T>,
|
||||
lock: &'a RwLock<'b, T>,
|
||||
}
|
||||
|
||||
/// RAII guard for a write lock.
|
||||
struct RwLockWriteGuard<'a, 'b, T: ?Sized> {
|
||||
lock: &'a RwLock<'b, T>,
|
||||
}
|
||||
|
||||
// TODO(quantumish): Support poisoning errors?
|
||||
#[derive(Error, Debug)]
|
||||
enum RwLockError {
|
||||
#[error("deadlock detected")]
|
||||
Deadlock,
|
||||
#[error("max number of read locks exceeded")]
|
||||
MaxReadLocks,
|
||||
#[error("nonblocking operation would block")]
|
||||
WouldBlock,
|
||||
}
|
||||
|
||||
unsafe impl<T: ?Sized + Send> Send for RwLock<'_, T> {}
|
||||
unsafe impl<T: ?Sized + Send + Sync> Sync for RwLock<'_, T> {}
|
||||
|
||||
impl<'a, T> RwLock<'a, T> {
|
||||
fn new(lock: &'a mut MaybeUninit<libc::pthread_rwlock_t>, data: T) -> Self {
|
||||
impl PthreadRwLock {
|
||||
pub fn new(lock: *mut libc::pthread_rwlock_t) -> Self {
|
||||
unsafe {
|
||||
let mut attrs = MaybeUninit::uninit();
|
||||
// Ignoring return value here - only possible error is OOM.
|
||||
@@ -48,208 +23,81 @@ impl<'a, T> RwLock<'a, T> {
|
||||
libc::PTHREAD_PROCESS_SHARED
|
||||
);
|
||||
// TODO(quantumish): worth making this function return Result?
|
||||
libc::pthread_rwlock_init(lock.as_mut_ptr(), attrs.as_mut_ptr());
|
||||
libc::pthread_rwlock_init(lock, attrs.as_mut_ptr());
|
||||
// Safety: POSIX specifies that "any function affecting the attributes
|
||||
// object (including destruction) shall not affect any previously
|
||||
// initialized read-write locks".
|
||||
libc::pthread_rwlockattr_destroy(attrs.as_mut_ptr());
|
||||
Self {
|
||||
inner: lock.assume_init_mut(),
|
||||
data: data.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn read(&self) -> Result<RwLockReadGuard<'_, '_, T>, RwLockError> {
|
||||
unsafe {
|
||||
let res = libc::pthread_rwlock_rdlock(self.inner as *const _ as *mut _);
|
||||
match res {
|
||||
0 => (),
|
||||
libc::EINVAL => panic!("failed to properly initialize lock"),
|
||||
libc::EDEADLK => return Err(RwLockError::Deadlock),
|
||||
libc::EAGAIN => return Err(RwLockError::MaxReadLocks),
|
||||
e => panic!("unknown error code returned: {e}")
|
||||
}
|
||||
Ok(RwLockReadGuard {
|
||||
data: NonNull::new_unchecked(self.data.get()),
|
||||
lock: self
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn try_read(&self) -> Result<RwLockReadGuard<'_, '_, T>, RwLockError> {
|
||||
unsafe {
|
||||
let res = libc::pthread_rwlock_tryrdlock(self.inner as *const _ as *mut _);
|
||||
match res {
|
||||
0 => (),
|
||||
libc::EINVAL => panic!("failed to properly initialize lock"),
|
||||
libc::EDEADLK => return Err(RwLockError::Deadlock),
|
||||
libc::EAGAIN => return Err(RwLockError::MaxReadLocks),
|
||||
libc::EBUSY => return Err(RwLockError::WouldBlock),
|
||||
e => panic!("unknown error code returned: {e}")
|
||||
}
|
||||
Ok(RwLockReadGuard {
|
||||
data: NonNull::new_unchecked(self.data.get()),
|
||||
lock: self
|
||||
})
|
||||
Self(Some(NonNull::new_unchecked(lock)))
|
||||
}
|
||||
}
|
||||
|
||||
fn write(&self) -> Result<RwLockWriteGuard<'_, '_, T>, RwLockError> {
|
||||
fn inner(&self) -> NonNull<libc::pthread_rwlock_t> {
|
||||
match self.0 {
|
||||
None => panic!("PthreadRwLock constructed badly - something likely used RawMutex::INIT"),
|
||||
Some(x) => x,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl lock_api::RawRwLock for PthreadRwLock {
|
||||
type GuardMarker = lock_api::GuardSend;
|
||||
const INIT: Self = Self(None);
|
||||
|
||||
fn lock_shared(&self) {
|
||||
unsafe {
|
||||
let res = libc::pthread_rwlock_wrlock(self.inner as *const _ as *mut _);
|
||||
match res {
|
||||
0 => (),
|
||||
libc::EINVAL => panic!("failed to properly initialize lock"),
|
||||
libc::EDEADLK => return Err(RwLockError::Deadlock),
|
||||
e => panic!("unknown error code returned: {e}")
|
||||
let res = libc::pthread_rwlock_rdlock(self.inner().as_ptr());
|
||||
if res != 0 {
|
||||
panic!("rdlock failed with {res}");
|
||||
}
|
||||
}
|
||||
Ok(RwLockWriteGuard { lock: self })
|
||||
}
|
||||
|
||||
fn try_write(&self) -> Result<RwLockWriteGuard<'_, '_, T>, RwLockError> {
|
||||
fn try_lock_shared(&self) -> bool {
|
||||
unsafe {
|
||||
let res = libc::pthread_rwlock_trywrlock(self.inner as *const _ as *mut _);
|
||||
let res = libc::pthread_rwlock_tryrdlock(self.inner().as_ptr());
|
||||
match res {
|
||||
0 => (),
|
||||
libc::EINVAL => panic!("failed to properly initialize lock"),
|
||||
libc::EDEADLK => return Err(RwLockError::Deadlock),
|
||||
libc::EBUSY => return Err(RwLockError::WouldBlock),
|
||||
e => panic!("unknown error code returned: {e}")
|
||||
0 => true,
|
||||
libc::EAGAIN => false,
|
||||
o => panic!("try_rdlock failed with {o}")
|
||||
}
|
||||
}
|
||||
Ok(RwLockWriteGuard { lock: self })
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<T: ?Sized + Sync> Sync for RwLockReadGuard<'_, '_, T> {}
|
||||
unsafe impl<T: ?Sized + Sync> Sync for RwLockWriteGuard<'_, '_, T> {}
|
||||
|
||||
impl<T: ?Sized> Deref for RwLockReadGuard<'_, '_, T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &T {
|
||||
unsafe { self.data.as_ref() }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ?Sized> Deref for RwLockWriteGuard<'_, '_, T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &T {
|
||||
unsafe { &*self.lock.data.get() }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ?Sized> DerefMut for RwLockWriteGuard<'_, '_, T> {
|
||||
fn deref_mut(&mut self) -> &mut T {
|
||||
unsafe { &mut *self.lock.data.get() }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ?Sized> Drop for RwLockReadGuard<'_, '_, T> {
|
||||
fn drop(&mut self) -> () {
|
||||
let res = unsafe { libc::pthread_rwlock_unlock(
|
||||
self.lock.inner as *const _ as *mut _
|
||||
) };
|
||||
debug_assert!(res == 0);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: ?Sized> Drop for RwLockWriteGuard<'_, '_, T> {
|
||||
fn drop(&mut self) -> () {
|
||||
let res = unsafe { libc::pthread_rwlock_unlock(
|
||||
self.lock.inner as *const _ as *mut _
|
||||
) };
|
||||
debug_assert!(res == 0);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use RwLockError::*;
|
||||
|
||||
#[test]
|
||||
fn test_single_process() {
|
||||
let mut lock = MaybeUninit::uninit();
|
||||
let wrapper = RwLock::new(&mut lock, 0);
|
||||
let mut writer = wrapper.write().unwrap();
|
||||
assert!(matches!(wrapper.try_write(), Err(Deadlock | WouldBlock)));
|
||||
assert!(matches!(wrapper.try_read(), Err(Deadlock | WouldBlock)));
|
||||
*writer = 5;
|
||||
drop(writer);
|
||||
let reader = wrapper.read().unwrap();
|
||||
assert!(matches!(wrapper.try_write(), Err(Deadlock | WouldBlock)));
|
||||
assert!(matches!(wrapper.read(), Ok(_)));
|
||||
assert_eq!(*reader, 5);
|
||||
drop(reader);
|
||||
assert!(matches!(wrapper.try_write(), Ok(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multi_thread() {
|
||||
let lock = Box::new(MaybeUninit::uninit());
|
||||
let wrapper = Arc::new(RwLock::new(Box::leak(lock), 0));
|
||||
let mut writer = wrapper.write().unwrap();
|
||||
let t1 = {
|
||||
let wrapper = wrapper.clone();
|
||||
std::thread::spawn(move || {
|
||||
let mut writer = wrapper.write().unwrap();
|
||||
*writer = 20;
|
||||
})
|
||||
};
|
||||
assert_eq!(*writer, 0);
|
||||
*writer = 10;
|
||||
assert_eq!(*writer, 10);
|
||||
drop(writer);
|
||||
t1.join().unwrap();
|
||||
let mut writer = wrapper.write().unwrap();
|
||||
assert_eq!(*writer, 20);
|
||||
drop(writer);
|
||||
let mut handles = vec![];
|
||||
for _ in 0..5 {
|
||||
handles.push({
|
||||
let wrapper = wrapper.clone();
|
||||
std::thread::spawn(move || {
|
||||
let reader = wrapper.read().unwrap();
|
||||
assert_eq!(*reader, 20);
|
||||
})
|
||||
});
|
||||
fn lock_exclusive(&self) {
|
||||
unsafe {
|
||||
let res = libc::pthread_rwlock_wrlock(self.inner().as_ptr());
|
||||
if res != 0 {
|
||||
panic!("wrlock failed with {res}");
|
||||
}
|
||||
}
|
||||
for h in handles {
|
||||
h.join().unwrap();
|
||||
}
|
||||
let writer = wrapper.write().unwrap();
|
||||
assert_eq!(*writer, 20);
|
||||
}
|
||||
|
||||
// // TODO(quantumish): Terrible time-based synchronization, fix me.
|
||||
// #[test]
|
||||
// fn test_multi_process() {
|
||||
// let max_size = 100;
|
||||
// let init_struct = crate::shmem::ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
|
||||
// let ptr = init_struct.data_ptr.as_ptr();
|
||||
// let lock: &mut _ = unsafe { ptr.add(
|
||||
// ptr.align_offset(std::mem::align_of::<MaybeUninit<libc::pthread_rwlock_t>>())
|
||||
// ).cast::<MaybeUninit<libc::pthread_rwlock_t>>().as_mut().unwrap() } ;
|
||||
// let wrapper = RwLock::new(lock, 0);
|
||||
fn try_lock_exclusive(&self) -> bool {
|
||||
unsafe {
|
||||
let res = libc::pthread_rwlock_trywrlock(self.inner().as_ptr());
|
||||
match res {
|
||||
0 => true,
|
||||
libc::EAGAIN => false,
|
||||
o => panic!("try_wrlock failed with {o}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// let fork_result = unsafe { nix::unistd::fork().unwrap() };
|
||||
|
||||
// if !fork_result.is_parent() {
|
||||
// let mut writer = wrapper.write().unwrap();
|
||||
// std::thread::sleep(std::time::Duration::from_secs(5));
|
||||
// *writer = 2;
|
||||
// } else {
|
||||
// std::thread::sleep(std::time::Duration::from_secs(1));
|
||||
// assert!(matches!(wrapper.try_write(), Err(WouldBlock)));
|
||||
// std::thread::sleep(std::time::Duration::from_secs(10));
|
||||
// let writer = wrapper.try_write().unwrap();
|
||||
// assert_eq!(*writer, 2);
|
||||
// }
|
||||
// }
|
||||
unsafe fn unlock_exclusive(&self) {
|
||||
unsafe {
|
||||
let res = libc::pthread_rwlock_unlock(self.inner().as_ptr());
|
||||
if res != 0 {
|
||||
panic!("unlock failed with {res}");
|
||||
}
|
||||
}
|
||||
}
|
||||
unsafe fn unlock_shared(&self) {
|
||||
unsafe {
|
||||
let res = libc::pthread_rwlock_unlock(self.inner().as_ptr());
|
||||
if res != 0 {
|
||||
panic!("unlock failed with {res}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user