diff --git a/Cargo.lock b/Cargo.lock index 215b3360bc..137b883a6d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2534,6 +2534,18 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "getrandom" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasi 0.14.2+wasi-0.2.4", +] + [[package]] name = "gettid" version = "0.1.3" @@ -3607,9 +3619,9 @@ checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" [[package]] name = "lock_api" -version = "0.4.10" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" +checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765" dependencies = [ "autocfg", "scopeguard", @@ -3759,7 +3771,7 @@ dependencies = [ "procfs", "prometheus", "rand 0.8.5", - "rand_distr", + "rand_distr 0.4.3", "twox-hash", ] @@ -3847,7 +3859,12 @@ checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" name = "neon-shmem" version = "0.1.0" dependencies = [ + "libc", + "lock_api", "nix 0.30.1", + "rand 0.9.1", + "rand_distr 0.5.1", + "rustc-hash 2.1.1", "tempfile", "thiserror 1.0.69", "workspace_hack", @@ -5348,7 +5365,7 @@ dependencies = [ "postgres_backend", "pq_proto", "rand 0.8.5", - "rand_distr", + "rand_distr 0.4.3", "rcgen", "redis", "regex", @@ -5359,7 +5376,7 @@ dependencies = [ "reqwest-tracing", "rsa", "rstest", - "rustc-hash 1.1.0", + "rustc-hash 2.1.1", "rustls 0.23.27", "rustls-native-certs 0.8.0", "rustls-pemfile 2.1.1", @@ -5452,6 +5469,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + [[package]] name = "rand" version = "0.7.3" @@ -5476,6 +5499,16 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "rand" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.3", +] + [[package]] name = "rand_chacha" version = "0.2.2" @@ -5496,6 +5529,16 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.3", +] + [[package]] name = "rand_core" version = "0.5.1" @@ -5514,6 +5557,15 @@ dependencies = [ "getrandom 0.2.11", ] +[[package]] +name = "rand_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +dependencies = [ + "getrandom 0.3.3", +] + [[package]] name = "rand_distr" version = "0.4.3" @@ -5524,6 +5576,16 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "rand_distr" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" +dependencies = [ + "num-traits", + "rand 0.9.1", +] + [[package]] name = "rand_hc" version = "0.2.0" @@ -8351,6 +8413,15 @@ version = "0.11.0+wasi-snapshot-preview1" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasi" +version = "0.14.2+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" +dependencies = [ + "wit-bindgen-rt", +] + [[package]] name = "wasite" version = "0.1.0" @@ -8708,6 +8779,15 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "wit-bindgen-rt" +version = "0.39.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" +dependencies = [ + "bitflags 2.8.0", +] + [[package]] name = "workspace_hack" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index df2064a4a7..6d91262882 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -130,6 +130,7 @@ jemalloc_pprof = { version = "0.7", features = ["symbolize", "flamegraph"] } jsonwebtoken = "9" lasso = "0.7" libc = "0.2" +lock_api = "0.4.13" md5 = "0.7.0" measured = { version = "0.0.22", features=["lasso"] } measured-process = { version = "0.0.22" } @@ -165,7 +166,7 @@ reqwest-middleware = "0.4" reqwest-retry = "0.7" routerify = "3" rpds = "0.13" -rustc-hash = "1.1.0" +rustc-hash = "2.1.1" rustls = { version = "0.23.16", default-features = false } rustls-pemfile = "2" rustls-pki-types = "1.11" diff --git a/libs/neon-shmem/Cargo.toml b/libs/neon-shmem/Cargo.toml index 2a636bec40..7ed991502e 100644 --- a/libs/neon-shmem/Cargo.toml +++ b/libs/neon-shmem/Cargo.toml @@ -8,6 +8,13 @@ license.workspace = true thiserror.workspace = true nix.workspace=true workspace_hack = { version = "0.1", path = "../../workspace_hack" } +libc.workspace = true +lock_api.workspace = true +rustc-hash.workspace = true [target.'cfg(target_os = "macos")'.dependencies] tempfile = "3.14.0" + +[dev-dependencies] +rand = "0.9" +rand_distr = "0.5.1" diff --git a/libs/neon-shmem/src/hash.rs b/libs/neon-shmem/src/hash.rs new file mode 100644 index 0000000000..58726b9ba3 --- /dev/null +++ b/libs/neon-shmem/src/hash.rs @@ -0,0 +1,583 @@ +//! Resizable hash table implementation on top of byte-level storage (either a [`ShmemHandle`] or a fixed byte array). +//! +//! This hash table has two major components: the bucket array and the dictionary. Each bucket within the +//! bucket array contains a `Option<(K, V)>` and an index of another bucket. In this way there is both an +//! implicit freelist within the bucket array (`None` buckets point to other `None` entries) and various hash +//! chains within the bucket array (a Some bucket will point to other Some buckets that had the same hash). +//! +//! Buckets are never moved unless they are within a region that is being shrunk, and so the actual hash- +//! dependent component is done with the dictionary. When a new key is inserted into the map, a position +//! within the dictionary is decided based on its hash, the data is inserted into an empty bucket based +//! off of the freelist, and then the index of said bucket is placed in the dictionary. +//! +//! This map is resizable (if initialized on top of a [`ShmemHandle`]). Both growing and shrinking happen +//! in-place and are at a high level achieved by expanding/reducing the bucket array and rebuilding the +//! dictionary by rehashing all keys. +//! +//! Concurrency is managed very simply: the entire map is guarded by one shared-memory RwLock. 
+
+use std::hash::{BuildHasher, Hash};
+use std::mem::MaybeUninit;
+
+use crate::shmem::ShmemHandle;
+use crate::{shmem, sync::*};
+
+mod core;
+pub mod entry;
+
+#[cfg(test)]
+mod tests;
+
+use core::{Bucket, CoreHashMap, INVALID_POS};
+use entry::{Entry, OccupiedEntry, PrevPos, VacantEntry};
+
+use thiserror::Error;
+
+/// Error type for a hashmap shrink operation.
+#[derive(Error, Debug)]
+pub enum HashMapShrinkError {
+    /// There was an error encountered while resizing the memory area.
+    #[error("shmem resize failed: {0}")]
+    ResizeError(shmem::Error),
+    /// Occupied entries in to-be-shrunk space were encountered beginning at the given index.
+    #[error("occupied entry in deallocated space found at {0}")]
+    RemainingEntries(usize),
+}
+
+/// This represents a hash table that (possibly) lives in shared memory.
+/// If a new process is launched with fork(), the child process inherits
+/// this struct.
+#[must_use]
+pub struct HashMapInit<'a, K, V, S = rustc_hash::FxBuildHasher> {
+    shmem_handle: Option<ShmemHandle>,
+    shared_ptr: *mut HashMapShared<'a, K, V>,
+    shared_size: usize,
+    hasher: S,
+    num_buckets: u32,
+}
+
+/// This is a per-process handle to a hash table that (possibly) lives in shared memory.
+/// If a child process is launched with fork(), the child process should
+/// get its own HashMapAccess by calling HashMapInit::attach_writer/reader().
+///
+/// XXX: We're not making use of it at the moment, but this struct could
+/// hold process-local information in the future.
+pub struct HashMapAccess<'a, K, V, S = rustc_hash::FxBuildHasher> {
+    shmem_handle: Option<ShmemHandle>,
+    shared_ptr: *mut HashMapShared<'a, K, V>,
+    hasher: S,
+}
+
+unsafe impl<K, V, S> Sync for HashMapAccess<'_, K, V, S> {}
+unsafe impl<K, V, S> Send for HashMapAccess<'_, K, V, S> {}
+
+impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
+    /// Change the 'hasher' used by the hash table.
+    ///
+    /// NOTE: This must be called right after creating the hash table,
+    /// before inserting any entries and before calling attach_writer/reader.
+    /// Otherwise different accessors could be using different hash functions,
+    /// with confusing results.
+    pub fn with_hasher<T>(self, hasher: T) -> HashMapInit<'a, K, V, T> {
+        HashMapInit {
+            hasher,
+            shmem_handle: self.shmem_handle,
+            shared_ptr: self.shared_ptr,
+            shared_size: self.shared_size,
+            num_buckets: self.num_buckets,
+        }
+    }
+
+    /// Loosely (over)estimate the size needed to store a hash table with `num_buckets` buckets.
+    pub fn estimate_size(num_buckets: u32) -> usize {
+        // add some margin to cover alignment etc.
+        CoreHashMap::<K, V>::estimate_size(num_buckets) + size_of::<HashMapShared<'a, K, V>>() + 1000
+    }
+
+    fn new(
+        num_buckets: u32,
+        shmem_handle: Option<ShmemHandle>,
+        area_ptr: *mut u8,
+        area_size: usize,
+        hasher: S,
+    ) -> Self {
+        let mut ptr: *mut u8 = area_ptr;
+        let end_ptr: *mut u8 = unsafe { ptr.add(area_size) };
+
+        // carve out area for the One Big Lock (TM) and the HashMapShared.
+        ptr = unsafe { ptr.add(ptr.align_offset(align_of::<libc::pthread_rwlock_t>())) };
+        let raw_lock_ptr = ptr;
+        ptr = unsafe { ptr.add(size_of::<libc::pthread_rwlock_t>()) };
+        ptr = unsafe { ptr.add(ptr.align_offset(align_of::<HashMapShared<'a, K, V>>())) };
+        let shared_ptr: *mut HashMapShared<'a, K, V> = ptr.cast();
+        ptr = unsafe { ptr.add(size_of::<HashMapShared<'a, K, V>>()) };
+
+        // carve out the buckets
+        ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<Bucket<K, V>>())) };
+        let buckets_ptr = ptr;
+        ptr = unsafe { ptr.add(size_of::<Bucket<K, V>>() * num_buckets as usize) };
+
+        // use remaining space for the dictionary
+        ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<u32>())) };
+        assert!(ptr.addr() < end_ptr.addr());
+        let dictionary_ptr = ptr;
+        let dictionary_size = unsafe { end_ptr.byte_offset_from(ptr) / size_of::<u32>() as isize };
+        assert!(dictionary_size > 0);
+
+        let buckets =
+            unsafe { std::slice::from_raw_parts_mut(buckets_ptr.cast(), num_buckets as usize) };
+        let dictionary = unsafe {
+            std::slice::from_raw_parts_mut(dictionary_ptr.cast(), dictionary_size as usize)
+        };
+
+        let hashmap = CoreHashMap::new(buckets, dictionary);
+        unsafe {
+            let lock = RwLock::from_raw(PthreadRwLock::new(raw_lock_ptr.cast()), hashmap);
+            std::ptr::write(shared_ptr, lock);
+        }
+
+        Self {
+            num_buckets,
+            shmem_handle,
+            shared_ptr,
+            shared_size: area_size,
+            hasher,
+        }
+    }
+
+    /// Attach to a hash table for writing.
+    pub fn attach_writer(self) -> HashMapAccess<'a, K, V, S> {
+        HashMapAccess {
+            shmem_handle: self.shmem_handle,
+            shared_ptr: self.shared_ptr,
+            hasher: self.hasher,
+        }
+    }
+
+    /// Initialize a table for reading. Currently identical to [`HashMapInit::attach_writer`].
+    ///
+    /// This is a holdover from a previous implementation and is being kept around for
+    /// backwards compatibility reasons.
+    pub fn attach_reader(self) -> HashMapAccess<'a, K, V, S> {
+        self.attach_writer()
+    }
+}
+
+/// Hash table data that is actually stored in the shared memory area.
+///
+/// NOTE: We carve out the parts from a contiguous chunk. Growing and shrinking the hash table
+/// relies on the memory layout! The data structures are laid out in the contiguous shared memory
+/// area as follows:
+///
+/// [`libc::pthread_rwlock_t`]
+/// [`HashMapShared`]
+/// buckets
+/// dictionary
+///
+/// In between the above parts, there can be padding bytes to align the parts correctly.
+type HashMapShared<'a, K, V> = RwLock<CoreHashMap<'a, K, V>>;
+
+impl<'a, K, V> HashMapInit<'a, K, V, rustc_hash::FxBuildHasher>
+where
+    K: Clone + Hash + Eq,
+{
+    /// Place the hash table within a user-supplied fixed memory area.
+    pub fn with_fixed(num_buckets: u32, area: &'a mut [MaybeUninit<u8>]) -> Self {
+        Self::new(
+            num_buckets,
+            None,
+            area.as_mut_ptr().cast(),
+            area.len(),
+            rustc_hash::FxBuildHasher,
+        )
+    }
+
+    /// Place a new hash map in the given shared memory area.
+    ///
+    /// # Panics
+    /// Will panic on failure to resize the area to the expected map size.
+    pub fn with_shmem(num_buckets: u32, shmem: ShmemHandle) -> Self {
+        let size = Self::estimate_size(num_buckets);
+        shmem
+            .set_size(size)
+            .expect("could not resize shared memory area");
+        let ptr = shmem.data_ptr.as_ptr().cast();
+        Self::new(
+            num_buckets,
+            Some(shmem),
+            ptr,
+            size,
+            rustc_hash::FxBuildHasher,
+        )
+    }
+
+    /// Make a resizable hash map within a new shared memory area with the given name.
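+    ///
+    /// For example (a sketch; the name must not collide with any other shared memory
+    /// object on the host):
+    ///
+    /// ```ignore
+    /// let map = HashMapInit::<u64, u64>::new_resizeable_named(1_000, 10_000, "my_map")
+    ///     .attach_writer();
+    /// ```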
+    pub fn new_resizeable_named(num_buckets: u32, max_buckets: u32, name: &str) -> Self {
+        let size = Self::estimate_size(num_buckets);
+        let max_size = Self::estimate_size(max_buckets);
+        let shmem =
+            ShmemHandle::new(name, size, max_size).expect("failed to make shared memory area");
+        let ptr = shmem.data_ptr.as_ptr().cast();
+
+        Self::new(
+            num_buckets,
+            Some(shmem),
+            ptr,
+            size,
+            rustc_hash::FxBuildHasher,
+        )
+    }
+
+    /// Make a resizable hash map within a new anonymous shared memory area.
+    pub fn new_resizeable(num_buckets: u32, max_buckets: u32) -> Self {
+        use std::sync::atomic::{AtomicUsize, Ordering};
+        static COUNTER: AtomicUsize = AtomicUsize::new(0);
+        let val = COUNTER.fetch_add(1, Ordering::Relaxed);
+        let name = format!("neon_shmem_hmap{val}");
+        Self::new_resizeable_named(num_buckets, max_buckets, &name)
+    }
+}
+
+impl<'a, K, V, S: BuildHasher> HashMapAccess<'a, K, V, S>
+where
+    K: Clone + Hash + Eq,
+{
+    /// Hash a key using the map's hasher.
+    #[inline]
+    fn get_hash_value(&self, key: &K) -> u64 {
+        self.hasher.hash_one(key)
+    }
+
+    fn entry_with_hash(&self, key: K, hash: u64) -> Entry<'a, '_, K, V> {
+        let mut map = unsafe { self.shared_ptr.as_ref() }.unwrap().write();
+        let dict_pos = hash as usize % map.dictionary.len();
+        let first = map.dictionary[dict_pos];
+        if first == INVALID_POS {
+            // no existing entry
+            return Entry::Vacant(VacantEntry {
+                map,
+                key,
+                dict_pos: dict_pos as u32,
+            });
+        }
+
+        let mut prev_pos = PrevPos::First(dict_pos as u32);
+        let mut next = first;
+        loop {
+            let bucket = &mut map.buckets[next as usize];
+            let (bucket_key, _bucket_value) = bucket.inner.as_mut().expect("entry is in use");
+            if *bucket_key == key {
+                // found existing entry
+                return Entry::Occupied(OccupiedEntry {
+                    map,
+                    _key: key,
+                    prev_pos,
+                    bucket_pos: next,
+                });
+            }
+
+            if bucket.next == INVALID_POS {
+                // No existing entry
+                return Entry::Vacant(VacantEntry {
+                    map,
+                    key,
+                    dict_pos: dict_pos as u32,
+                });
+            }
+            prev_pos = PrevPos::Chained(next);
+            next = bucket.next;
+        }
+    }
+
+    /// Get a reference to the corresponding value for a key.
+    pub fn get<'e>(&'e self, key: &K) -> Option<ValueReadGuard<'e, V>> {
+        let hash = self.get_hash_value(key);
+        let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
+        RwLockReadGuard::try_map(map, |m| m.get_with_hash(key, hash)).ok()
+    }
+
+    /// Get a reference to the entry containing a key.
+    ///
+    /// NB: This takes a write lock as there's no way to distinguish in advance whether the
+    /// intention is to use the entry for reading or for writing.
+    pub fn entry(&self, key: K) -> Entry<'a, '_, K, V> {
+        let hash = self.get_hash_value(&key);
+        self.entry_with_hash(key, hash)
+    }
+
+    /// Remove a key from the map. Returns the associated value if it existed.
+    pub fn remove(&self, key: &K) -> Option<V> {
+        let hash = self.get_hash_value(key);
+        match self.entry_with_hash(key.clone(), hash) {
+            Entry::Occupied(e) => Some(e.remove()),
+            Entry::Vacant(_) => None,
+        }
+    }
+
+    /// Insert/update a key. Returns the previous associated value if it existed.
+    ///
+    /// # Errors
+    /// Will return [`core::FullError`] if there is no more space left in the map.
+    pub fn insert(&self, key: K, value: V) -> Result<Option<V>, core::FullError> {
+        let hash = self.get_hash_value(&key);
+        match self.entry_with_hash(key.clone(), hash) {
+            Entry::Occupied(mut e) => Ok(Some(e.insert(value))),
+            Entry::Vacant(e) => {
+                _ = e.insert(value)?;
+                Ok(None)
+            }
+        }
+    }
+
+    /// Optionally return the entry for a bucket at a given index if it exists.
+    ///
+    /// Has more overhead than one would intuitively expect: performs both a clone of the key
+    /// due to the [`OccupiedEntry`] type owning the key and also a hash of the key in order
+    /// to enable repairing the hash chain if the entry is removed.
+    pub fn entry_at_bucket(&self, pos: usize) -> Option<OccupiedEntry<'a, '_, K, V>> {
+        let map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
+        if pos >= map.buckets.len() {
+            return None;
+        }
+
+        let entry = map.buckets[pos].inner.as_ref();
+        match entry {
+            Some((key, _)) => Some(OccupiedEntry {
+                _key: key.clone(),
+                bucket_pos: pos as u32,
+                prev_pos: entry::PrevPos::Unknown(self.get_hash_value(key)),
+                map,
+            }),
+            _ => None,
+        }
+    }
+
+    /// Returns the number of buckets in the table.
+    pub fn get_num_buckets(&self) -> usize {
+        let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
+        map.get_num_buckets()
+    }
+
+    /// Return the key and value stored in the bucket with the given index. This can be used to
+    /// iterate through the hash map.
+    // TODO: An Iterator might be nicer. The communicator's clock algorithm needs to
+    // _slowly_ iterate through all buckets with its clock hand, without holding a lock.
+    // If we switch to an Iterator, it must not hold the lock.
+    pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<'_, (K, V)>> {
+        let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
+        if pos >= map.buckets.len() {
+            return None;
+        }
+        RwLockReadGuard::try_map(map, |m| m.buckets[pos].inner.as_ref()).ok()
+    }
+
+    /// Returns the index of the bucket a given value corresponds to.
+    pub fn get_bucket_for_value(&self, val_ptr: *const V) -> usize {
+        let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
+
+        let origin = map.buckets.as_ptr();
+        let idx = (val_ptr as usize - origin as usize) / size_of::<Bucket<K, V>>();
+        assert!(idx < map.buckets.len());
+
+        idx
+    }
+
+    /// Returns the number of occupied buckets in the table.
+    pub fn get_num_buckets_in_use(&self) -> usize {
+        let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
+        map.buckets_in_use as usize
+    }
+
+    /// Clears all entries in a table. Does not reset any shrinking operations.
+    pub fn clear(&self) {
+        let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
+        map.clear();
+    }
+
+    /// Perform an in-place rehash of some region (0..`rehash_buckets`) of the table and reset
+    /// the `buckets` and `dictionary` slices to be as long as `num_buckets`. Resets the freelist
+    /// in the process.
+    fn rehash_dict(
+        &self,
+        inner: &mut CoreHashMap<'a, K, V>,
+        buckets_ptr: *mut core::Bucket<K, V>,
+        end_ptr: *mut u8,
+        num_buckets: u32,
+        rehash_buckets: u32,
+    ) {
+        inner.free_head = INVALID_POS;
+
+        let buckets;
+        let dictionary;
+        unsafe {
+            let buckets_end_ptr = buckets_ptr.add(num_buckets as usize);
+            let dictionary_ptr: *mut u32 = buckets_end_ptr
+                .byte_add(buckets_end_ptr.align_offset(align_of::<u32>()))
+                .cast();
+            let dictionary_size: usize =
+                end_ptr.byte_offset_from(buckets_end_ptr) as usize / size_of::<u32>();
+
+            buckets = std::slice::from_raw_parts_mut(buckets_ptr, num_buckets as usize);
+            dictionary = std::slice::from_raw_parts_mut(dictionary_ptr, dictionary_size);
+        }
+        for e in dictionary.iter_mut() {
+            *e = INVALID_POS;
+        }
+
+        for (i, bucket) in buckets.iter_mut().enumerate().take(rehash_buckets as usize) {
+            if bucket.inner.is_none() {
+                bucket.next = inner.free_head;
+                inner.free_head = i as u32;
+                continue;
+            }
+
+            let hash = self.hasher.hash_one(&bucket.inner.as_ref().unwrap().0);
+            let pos: usize = (hash % dictionary.len() as u64) as usize;
+            bucket.next = dictionary[pos];
+            dictionary[pos] = i as u32;
+        }
+
+        inner.dictionary = dictionary;
+        inner.buckets = buckets;
+    }
+
+    /// Rehash the map without growing or shrinking.
+    pub fn shuffle(&self) {
+        let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
+        let num_buckets = map.get_num_buckets() as u32;
+        let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
+        let end_ptr: *mut u8 = unsafe { self.shared_ptr.byte_add(size_bytes).cast() };
+        let buckets_ptr = map.buckets.as_mut_ptr();
+        self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets);
+    }
+
+    /// Grow the number of buckets within the table.
+    ///
+    /// 1. Grows the underlying shared memory area
+    /// 2. Initializes new buckets and overwrites the current dictionary
+    /// 3. Rehashes the dictionary
+    ///
+    /// # Panics
+    /// Panics if called on a map initialized with [`HashMapInit::with_fixed`].
+    ///
+    /// # Errors
+    /// Returns a [`shmem::Error`] if any errors occur resizing the memory region.
+    pub fn grow(&self, num_buckets: u32) -> Result<(), shmem::Error> {
+        let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
+        let old_num_buckets = map.buckets.len() as u32;
+
+        assert!(
+            num_buckets >= old_num_buckets,
+            "grow called with a smaller number of buckets"
+        );
+        if num_buckets == old_num_buckets {
+            return Ok(());
+        }
+        let shmem_handle = self
+            .shmem_handle
+            .as_ref()
+            .expect("grow called on a fixed-size hash table");
+
+        let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
+        shmem_handle.set_size(size_bytes)?;
+        let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
+
+        // Initialize new buckets. The new buckets are linked to the free list.
+        // NB: This overwrites the dictionary!
+        let buckets_ptr = map.buckets.as_mut_ptr();
+        unsafe {
+            for i in old_num_buckets..num_buckets {
+                let bucket = buckets_ptr.add(i as usize);
+                bucket.write(core::Bucket {
+                    next: if i < num_buckets - 1 {
+                        i + 1
+                    } else {
+                        map.free_head
+                    },
+                    inner: None,
+                });
+            }
+        }
+
+        self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, old_num_buckets);
+        map.free_head = old_num_buckets;
+
+        Ok(())
+    }
+
+    /// Begin a shrink, limiting all new allocations to be in buckets with index below `num_buckets`.
+    ///
+    /// # Panics
+    /// Panics if called on a map initialized with [`HashMapInit::with_fixed`] or if `num_buckets` is
+    /// greater than the number of buckets in the map.
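+    ///
+    /// Shrinking is a three-step protocol, sketched below (this mirrors the `do_shrink`
+    /// helper in the tests; the eviction policy is up to the caller):
+    ///
+    /// ```ignore
+    /// access.begin_shrink(new_size);                    // 1. forbid allocations at or above new_size
+    /// for i in new_size as usize..access.get_num_buckets() {
+    ///     if let Some(e) = access.entry_at_bucket(i) {  // 2. evict everything above the limit
+    ///         e.remove();
+    ///     }
+    /// }
+    /// access.finish_shrink().unwrap();                  // 3. release the memory and rehash
+    /// ```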
+    pub fn begin_shrink(&mut self, num_buckets: u32) {
+        let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
+        assert!(
+            num_buckets <= map.get_num_buckets() as u32,
+            "shrink called with a larger number of buckets"
+        );
+        _ = self
+            .shmem_handle
+            .as_ref()
+            .expect("shrink called on a fixed-size hash table");
+        map.alloc_limit = num_buckets;
+    }
+
+    /// If a shrink operation is underway, returns the target size of the map. Otherwise, returns None.
+    pub fn shrink_goal(&self) -> Option<usize> {
+        let map = unsafe { self.shared_ptr.as_mut() }.unwrap().read();
+        let goal = map.alloc_limit;
+        if goal == INVALID_POS {
+            None
+        } else {
+            Some(goal as usize)
+        }
+    }
+
+    /// Complete a shrink after the caller has evicted entries, removing the unused buckets and rehashing.
+    ///
+    /// # Panics
+    /// The following cases result in a panic:
+    /// - Calling this function on a map initialized with [`HashMapInit::with_fixed`].
+    /// - Calling this function on a map when no shrink operation is in progress.
+    pub fn finish_shrink(&self) -> Result<(), HashMapShrinkError> {
+        let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
+        assert!(
+            map.alloc_limit != INVALID_POS,
+            "called finish_shrink when no shrink is in progress"
+        );
+
+        let num_buckets = map.alloc_limit;
+
+        if map.get_num_buckets() == num_buckets as usize {
+            return Ok(());
+        }
+
+        assert!(
+            map.buckets_in_use <= num_buckets,
+            "called finish_shrink before enough entries were removed"
+        );
+
+        for i in (num_buckets as usize)..map.buckets.len() {
+            if map.buckets[i].inner.is_some() {
+                return Err(HashMapShrinkError::RemainingEntries(i));
+            }
+        }
+
+        let shmem_handle = self
+            .shmem_handle
+            .as_ref()
+            .expect("shrink called on a fixed-size hash table");
+
+        let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
+        if let Err(e) = shmem_handle.set_size(size_bytes) {
+            return Err(HashMapShrinkError::ResizeError(e));
+        }
+        let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
+        let buckets_ptr = map.buckets.as_mut_ptr();
+        self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets);
+        map.alloc_limit = INVALID_POS;
+
+        Ok(())
+    }
+}
diff --git a/libs/neon-shmem/src/hash/core.rs b/libs/neon-shmem/src/hash/core.rs
new file mode 100644
index 0000000000..4665c36adb
--- /dev/null
+++ b/libs/neon-shmem/src/hash/core.rs
@@ -0,0 +1,174 @@
+//! Simple hash table with chaining.
+
+use std::hash::Hash;
+use std::mem::MaybeUninit;
+
+use crate::hash::entry::*;
+
+/// Invalid position within the map (either within the dictionary or bucket array).
+pub(crate) const INVALID_POS: u32 = u32::MAX;
+
+/// Fundamental storage unit within the hash table. Either empty or contains a key-value pair.
+/// Always part of a chain of some kind (either a freelist if empty or a hash chain if full).
+pub(crate) struct Bucket<K, V> {
+    /// Index of next bucket in the chain.
+    pub(crate) next: u32,
+    /// Key-value pair contained within bucket.
+    pub(crate) inner: Option<(K, V)>,
+}
+
+/// Core hash table implementation.
+pub(crate) struct CoreHashMap<'a, K, V> {
+    /// Dictionary used to map hashes to bucket indices.
+    pub(crate) dictionary: &'a mut [u32],
+    /// Buckets containing key-value pairs.
+    pub(crate) buckets: &'a mut [Bucket<K, V>],
+    /// Head of the freelist.
+    pub(crate) free_head: u32,
+    /// Maximum index of a bucket allowed to be allocated. [`INVALID_POS`] if no limit.
+    pub(crate) alloc_limit: u32,
+    /// The number of currently occupied buckets.
+    pub(crate) buckets_in_use: u32,
+}
+
+/// Error for when there are no empty buckets left but one is needed.
+#[derive(Debug, PartialEq)]
+pub struct FullError;
+
+impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
+    const FILL_FACTOR: f32 = 0.60;
+
+    /// Estimate the size of data contained within the hash map.
+    pub fn estimate_size(num_buckets: u32) -> usize {
+        let mut size = 0;
+
+        // buckets
+        size += size_of::<Bucket<K, V>>() * num_buckets as usize;
+
+        // dictionary
+        size += (f32::ceil((size_of::<u32>() * num_buckets as usize) as f32 / Self::FILL_FACTOR))
+            as usize;
+
+        size
+    }
+
+    pub fn new(
+        buckets: &'a mut [MaybeUninit<Bucket<K, V>>],
+        dictionary: &'a mut [MaybeUninit<u32>],
+    ) -> Self {
+        // Initialize the buckets
+        for i in 0..buckets.len() {
+            buckets[i].write(Bucket {
+                next: if i < buckets.len() - 1 {
+                    i as u32 + 1
+                } else {
+                    INVALID_POS
+                },
+                inner: None,
+            });
+        }
+
+        // Initialize the dictionary
+        for e in dictionary.iter_mut() {
+            e.write(INVALID_POS);
+        }
+
+        // TODO: use std::slice::assume_init_mut() once it stabilizes
+        let buckets =
+            unsafe { std::slice::from_raw_parts_mut(buckets.as_mut_ptr().cast(), buckets.len()) };
+        let dictionary = unsafe {
+            std::slice::from_raw_parts_mut(dictionary.as_mut_ptr().cast(), dictionary.len())
+        };
+
+        Self {
+            dictionary,
+            buckets,
+            free_head: 0,
+            buckets_in_use: 0,
+            alloc_limit: INVALID_POS,
+        }
+    }
+
+    /// Get the value associated with a key (if it exists) given its hash.
+    pub fn get_with_hash(&self, key: &K, hash: u64) -> Option<&V> {
+        let mut next = self.dictionary[hash as usize % self.dictionary.len()];
+        loop {
+            if next == INVALID_POS {
+                return None;
+            }
+
+            let bucket = &self.buckets[next as usize];
+            let (bucket_key, bucket_value) = bucket.inner.as_ref().expect("entry is in use");
+            if bucket_key == key {
+                return Some(bucket_value);
+            }
+            next = bucket.next;
+        }
+    }
+
+    /// Get number of buckets in map.
+    pub fn get_num_buckets(&self) -> usize {
+        self.buckets.len()
+    }
+
+    /// Clears all entries from the hashmap.
+    ///
+    /// Does not reset any allocation limits, but does clear any entries beyond them.
+    pub fn clear(&mut self) {
+        for i in 0..self.buckets.len() {
+            self.buckets[i] = Bucket {
+                next: if i < self.buckets.len() - 1 {
+                    i as u32 + 1
+                } else {
+                    INVALID_POS
+                },
+                inner: None,
+            }
+        }
+        for i in 0..self.dictionary.len() {
+            self.dictionary[i] = INVALID_POS;
+        }
+
+        self.free_head = 0;
+        self.buckets_in_use = 0;
+    }
+
+    /// Find the position of an unused bucket via the freelist and initialize it.
+    pub(crate) fn alloc_bucket(&mut self, key: K, value: V) -> Result<u32, FullError> {
+        let mut pos = self.free_head;
+
+        // Find the first bucket we're *allowed* to use.
+        let mut prev = PrevPos::First(self.free_head);
+        while pos != INVALID_POS && pos >= self.alloc_limit {
+            let bucket = &mut self.buckets[pos as usize];
+            prev = PrevPos::Chained(pos);
+            pos = bucket.next;
+        }
+        if pos == INVALID_POS {
+            return Err(FullError);
+        }
+
+        // Repair the freelist.
+        match prev {
+            PrevPos::First(_) => {
+                let next_pos = self.buckets[pos as usize].next;
+                self.free_head = next_pos;
+            }
+            PrevPos::Chained(p) => {
+                if p != INVALID_POS {
+                    let next_pos = self.buckets[pos as usize].next;
+                    self.buckets[p as usize].next = next_pos;
+                }
+            }
+            _ => unreachable!(),
+        }
+
+        // Initialize the bucket.
+        let bucket = &mut self.buckets[pos as usize];
+        self.buckets_in_use += 1;
+        bucket.next = INVALID_POS;
+        bucket.inner = Some((key, value));
+
+        Ok(pos)
+    }
+}
diff --git a/libs/neon-shmem/src/hash/entry.rs b/libs/neon-shmem/src/hash/entry.rs
new file mode 100644
index 0000000000..560a20db1d
--- /dev/null
+++ b/libs/neon-shmem/src/hash/entry.rs
@@ -0,0 +1,130 @@
+//! Equivalent of [`std::collections::hash_map::Entry`] for this hashmap.
+
+use crate::hash::core::{CoreHashMap, FullError, INVALID_POS};
+use crate::sync::{RwLockWriteGuard, ValueWriteGuard};
+
+use std::hash::Hash;
+use std::mem;
+
+pub enum Entry<'a, 'b, K, V> {
+    Occupied(OccupiedEntry<'a, 'b, K, V>),
+    Vacant(VacantEntry<'a, 'b, K, V>),
+}
+
+/// Enum representing the previous position within a chain.
+#[derive(Clone, Copy)]
+pub(crate) enum PrevPos {
+    /// Starting index within the dictionary.
+    First(u32),
+    /// Regular index within the buckets.
+    Chained(u32),
+    /// Unknown - e.g. the associated entry was retrieved by index instead of chain.
+    Unknown(u64),
+}
+
+pub struct OccupiedEntry<'a, 'b, K, V> {
+    /// Mutable reference to the map containing this entry.
+    pub(crate) map: RwLockWriteGuard<'b, CoreHashMap<'a, K, V>>,
+    /// The key of the occupied entry.
+    pub(crate) _key: K,
+    /// The index of the previous entry in the chain.
+    pub(crate) prev_pos: PrevPos,
+    /// The position of the bucket in the [`CoreHashMap`] bucket array.
+    pub(crate) bucket_pos: u32,
+}
+
+impl<K, V> OccupiedEntry<'_, '_, K, V> {
+    pub fn get(&self) -> &V {
+        &self.map.buckets[self.bucket_pos as usize]
+            .inner
+            .as_ref()
+            .unwrap()
+            .1
+    }
+
+    pub fn get_mut(&mut self) -> &mut V {
+        &mut self.map.buckets[self.bucket_pos as usize]
+            .inner
+            .as_mut()
+            .unwrap()
+            .1
+    }
+
+    /// Inserts a value into the entry, replacing (and returning) the existing value.
+    pub fn insert(&mut self, value: V) -> V {
+        let bucket = &mut self.map.buckets[self.bucket_pos as usize];
+        // This assumes inner is Some, which it must be for an OccupiedEntry
+        mem::replace(&mut bucket.inner.as_mut().unwrap().1, value)
+    }
+
+    /// Removes the entry from the hash map, returning the value originally stored within it.
+    ///
+    /// This may result in multiple bucket accesses if the entry was obtained by index, as the
+    /// previous chain entry needs to be discovered in that case.
+    pub fn remove(mut self) -> V {
+        // If this bucket was queried by index, go ahead and follow its chain from the start.
+        let prev = if let PrevPos::Unknown(hash) = self.prev_pos {
+            let dict_idx = hash as usize % self.map.dictionary.len();
+            let mut prev = PrevPos::First(dict_idx as u32);
+            let mut curr = self.map.dictionary[dict_idx];
+            while curr != self.bucket_pos {
+                assert!(curr != INVALID_POS);
+                prev = PrevPos::Chained(curr);
+                curr = self.map.buckets[curr as usize].next;
+            }
+            prev
+        } else {
+            self.prev_pos
+        };
+
+        // Unlink the bucket from its chain, then push it onto the freelist.
+        let bucket_next = self.map.buckets[self.bucket_pos as usize].next;
+
+        // unlink it from the chain
+        match prev {
+            PrevPos::First(dict_pos) => {
+                self.map.dictionary[dict_pos as usize] = bucket_next;
+            }
+            PrevPos::Chained(bucket_pos) => {
+                self.map.buckets[bucket_pos as usize].next = bucket_next;
+            }
+            _ => unreachable!(),
+        }
+
+        // and add it to the freelist
+        let free = self.map.free_head;
+        let bucket = &mut self.map.buckets[self.bucket_pos as usize];
+        let old_value = bucket.inner.take();
+        bucket.next = free;
+        self.map.free_head = self.bucket_pos;
+        self.map.buckets_in_use -= 1;
+
+        old_value.unwrap().1
+    }
+}
+
+/// An abstract view into a vacant entry within the map.
+pub struct VacantEntry<'a, 'b, K, V> {
+    /// Mutable reference to the map containing this entry.
+    pub(crate) map: RwLockWriteGuard<'b, CoreHashMap<'a, K, V>>,
+    /// The key to be inserted into this entry.
+    pub(crate) key: K,
+    /// The position within the dictionary corresponding to the key's hash.
+    pub(crate) dict_pos: u32,
+}
+
+impl<'b, K: Clone + Hash + Eq, V> VacantEntry<'_, 'b, K, V> {
+    /// Insert a value into the vacant entry, finding and populating an empty bucket in the process.
+    ///
+    /// # Errors
+    /// Will return [`FullError`] if there are no unoccupied buckets in the map.
+    pub fn insert(mut self, value: V) -> Result<ValueWriteGuard<'b, V>, FullError> {
+        let pos = self.map.alloc_bucket(self.key, value)?;
+        self.map.buckets[pos as usize].next = self.map.dictionary[self.dict_pos as usize];
+        self.map.dictionary[self.dict_pos as usize] = pos;
+
+        Ok(RwLockWriteGuard::map(self.map, |m| {
+            &mut m.buckets[pos as usize].inner.as_mut().unwrap().1
+        }))
+    }
+}
diff --git a/libs/neon-shmem/src/hash/tests.rs b/libs/neon-shmem/src/hash/tests.rs
new file mode 100644
index 0000000000..92233e8140
--- /dev/null
+++ b/libs/neon-shmem/src/hash/tests.rs
@@ -0,0 +1,428 @@
+use std::collections::BTreeMap;
+use std::collections::HashSet;
+use std::fmt::Debug;
+use std::mem::MaybeUninit;
+
+use crate::hash::Entry;
+use crate::hash::HashMapAccess;
+use crate::hash::HashMapInit;
+use crate::hash::core::FullError;
+
+use rand::seq::SliceRandom;
+use rand::{Rng, RngCore};
+use rand_distr::Zipf;
+
+const TEST_KEY_LEN: usize = 16;
+
+#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
+struct TestKey([u8; TEST_KEY_LEN]);
+
+impl From<&TestKey> for u128 {
+    fn from(val: &TestKey) -> u128 {
+        u128::from_be_bytes(val.0)
+    }
+}
+
+impl From<u128> for TestKey {
+    fn from(val: u128) -> TestKey {
+        TestKey(val.to_be_bytes())
+    }
+}
+
+impl<'a> From<&'a [u8]> for TestKey {
+    fn from(bytes: &'a [u8]) -> TestKey {
+        TestKey(bytes.try_into().unwrap())
+    }
+}
+
+fn test_inserts<K: Into<TestKey> + Copy>(keys: &[K]) {
+    let w = HashMapInit::<TestKey, usize>::new_resizeable_named(100000, 120000, "test_inserts")
+        .attach_writer();
+
+    for (idx, k) in keys.iter().enumerate() {
+        let res = w.entry((*k).into());
+        match res {
+            Entry::Occupied(mut e) => {
+                e.insert(idx);
+            }
+            Entry::Vacant(e) => {
+                let res = e.insert(idx);
+                assert!(res.is_ok());
+            }
+        };
+    }
+
+    for (idx, k) in keys.iter().enumerate() {
+        let x = w.get(&(*k).into());
+        let value = x.as_deref().copied();
+        assert_eq!(value, Some(idx));
+    }
+}
+
+#[test]
+fn dense() {
+    // A few dense keys plus an outlier
+    let keys: &[u128] = &[0, 1, 2, 3, 256];
+    test_inserts(keys);
+
+    // Dense keys
+    let mut keys: Vec<u128> = (0..10000).collect();
+    test_inserts(&keys);
+
+    // Do the same in random orders
+    for _ in 1..10 {
+        keys.shuffle(&mut rand::rng());
+        test_inserts(&keys);
+    }
+}
+
+#[test]
+fn sparse() {
+    // sparse keys
+    let mut keys: Vec<TestKey> = Vec::new();
+    let mut used_keys = HashSet::new();
+    for _ in 0..10000 {
+        loop {
+            let key = rand::random::<u128>();
+            if used_keys.contains(&key) {
+                continue;
+            }
+            used_keys.insert(key);
+            keys.push(key.into());
+            break;
+        }
+    }
+    test_inserts(&keys);
+}
+
+#[derive(Clone, Debug)]
+struct TestOp(TestKey, Option<usize>);
+
+fn apply_op(
+    op: &TestOp,
+    map: &mut HashMapAccess<TestKey, usize>,
+    shadow: &mut BTreeMap<TestKey, usize>,
+) {
+    // apply the change to the shadow tree first
+    let shadow_existing = if let Some(v) = op.1 {
+        shadow.insert(op.0, v)
+    } else {
+        shadow.remove(&op.0)
+    };
+
+    let entry = map.entry(op.0);
+    let hash_existing = match op.1 {
+        Some(new) => match entry {
+            Entry::Occupied(mut e) => Some(e.insert(new)),
+            Entry::Vacant(e) => {
+                _ = e.insert(new).unwrap();
+                None
+            }
+        },
+        None => match entry {
+            Entry::Occupied(e) => Some(e.remove()),
+            Entry::Vacant(_) => None,
+        },
+    };
+
+    assert_eq!(shadow_existing, hash_existing);
+}
+
+fn do_random_ops(
+    num_ops: usize,
+    size: u32,
+    insert_prob: f64,
+    writer: &mut HashMapAccess<TestKey, usize>,
+    shadow: &mut BTreeMap<TestKey, usize>,
+    rng: &mut rand::rngs::ThreadRng,
+) {
+    for i in 0..num_ops {
+        let key: TestKey = ((rng.next_u32() % size) as u128).into();
+        // With probability `insert_prob` the op is an insert/update; otherwise a removal.
+        let op = TestOp(
+            key,
+            if rng.random_bool(insert_prob) {
+                Some(i)
+            } else {
+                None
+            },
+        );
+        apply_op(&op, writer, shadow);
+    }
+}
+
+fn do_deletes(
+    num_ops: usize,
+    writer: &mut HashMapAccess<TestKey, usize>,
+    shadow: &mut BTreeMap<TestKey, usize>,
+) {
+    for _ in 0..num_ops {
+        let (k, _) = shadow.pop_first().unwrap();
+        writer.remove(&k);
+    }
+}
+
+fn do_shrink(
+    writer: &mut HashMapAccess<TestKey, usize>,
+    shadow: &mut BTreeMap<TestKey, usize>,
+    from: u32,
+    to: u32,
+) {
+    assert!(writer.shrink_goal().is_none());
+    writer.begin_shrink(to);
+    assert_eq!(writer.shrink_goal(), Some(to as usize));
+    for i in to..from {
+        if let Some(entry) = writer.entry_at_bucket(i as usize) {
+            shadow.remove(&entry._key);
+            entry.remove();
+        }
+    }
+    let old_usage = writer.get_num_buckets_in_use();
+    writer.finish_shrink().unwrap();
+    assert!(writer.shrink_goal().is_none());
+    assert_eq!(writer.get_num_buckets_in_use(), old_usage);
+}
+
+#[test]
+fn random_ops() {
+    let mut writer =
+        HashMapInit::<TestKey, usize>::new_resizeable_named(100000, 120000, "test_random")
+            .attach_writer();
+    let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
+
+    let distribution = Zipf::new(u128::MAX as f64, 1.1).unwrap();
+    let mut rng = rand::rng();
+    for i in 0..100000 {
+        let key: TestKey = (rng.sample(distribution) as u128).into();
+
+        let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
+
+        apply_op(&op, &mut writer, &mut shadow);
+    }
+}
+
+#[test]
+fn test_shuffle() {
+    let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 1200, "test_shuf")
+        .attach_writer();
+    let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
+    let mut rng = rand::rng();
+
+    do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng);
+    writer.shuffle();
+    do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng);
+}
+
+#[test]
+fn test_grow() {
+    let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 2000, "test_grow")
+        .attach_writer();
+    let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
+    let mut rng = rand::rng();
+
+    do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng);
+    let old_usage = writer.get_num_buckets_in_use();
+    writer.grow(1500).unwrap();
+    assert_eq!(writer.get_num_buckets_in_use(), old_usage);
+    assert_eq!(writer.get_num_buckets(), 1500);
+    do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
+}
+
+#[test]
+fn test_clear() {
+    let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_clear")
+        .attach_writer();
+    let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
+    let mut rng = rand::rng();
+    do_random_ops(2000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
+    writer.clear();
+    assert_eq!(writer.get_num_buckets_in_use(), 0);
+    assert_eq!(writer.get_num_buckets(), 1500);
+    while let Some((key, _)) = shadow.pop_first() {
+        assert!(writer.get(&key).is_none());
+    }
+    do_random_ops(2000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
+    for i in 0..(1500 - writer.get_num_buckets_in_use()) {
+        writer.insert((1500 + i as u128).into(), 0).unwrap();
+    }
+    assert_eq!(writer.insert(5000.into(), 0), Err(FullError {}));
+    writer.clear();
+    assert!(writer.insert(5000.into(), 0).is_ok());
+}
+
+#[test]
+fn test_idx_remove() {
+    let mut writer =
+        HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_idx_remove")
+            .attach_writer();
+    let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
+    let mut rng = rand::rng();
+    do_random_ops(2000, 1500, 0.25, &mut writer, &mut shadow, &mut rng);
+    for _ in 0..100 {
+        let idx = (rng.next_u32() % 1500) as usize;
+        if let Some(e) = writer.entry_at_bucket(idx) {
+            shadow.remove(&e._key);
+            e.remove();
+        }
+    }
+    while let Some((key, val)) = shadow.pop_first() {
+        assert_eq!(*writer.get(&key).unwrap(), val);
+    }
+}
+
+#[test]
+fn test_idx_get() {
+    let mut writer =
+        HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_idx_get")
+            .attach_writer();
+    let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
+    let mut rng = rand::rng();
+    do_random_ops(2000, 1500, 0.25, &mut writer, &mut shadow, &mut rng);
+    for _ in 0..100 {
+        let idx = (rng.next_u32() % 1500) as usize;
+        if let Some(pair) = writer.get_at_bucket(idx) {
+            {
+                let v: *const usize = &pair.1;
+                assert_eq!(writer.get_bucket_for_value(v), idx);
+            }
+            {
+                let v: *const usize = &pair.1;
+                assert_eq!(writer.get_bucket_for_value(v), idx);
+            }
+        }
+    }
+}
+
+#[test]
+fn test_shrink() {
+    let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_shrink")
+        .attach_writer();
+    let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
+    let mut rng = rand::rng();
+
+    do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
+    do_shrink(&mut writer, &mut shadow, 1500, 1000);
+    assert_eq!(writer.get_num_buckets(), 1000);
+    do_deletes(500, &mut writer, &mut shadow);
+    do_random_ops(10000, 500, 0.75, &mut writer, &mut shadow, &mut rng);
+    assert!(writer.get_num_buckets_in_use() <= 1000);
+}
+
+#[test]
+fn test_shrink_grow_seq() {
+    let mut writer =
+        HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 20000, "test_grow_seq")
+            .attach_writer();
+    let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
+    let mut rng = rand::rng();
+
+    do_random_ops(500, 1000, 0.1, &mut writer, &mut shadow, &mut rng);
+    eprintln!("Shrinking to 750");
+    do_shrink(&mut writer, &mut shadow, 1000, 750);
+    do_random_ops(200, 1000, 0.5, &mut writer, &mut shadow, &mut rng);
+    eprintln!("Growing to 1500");
+    writer.grow(1500).unwrap();
+    do_random_ops(600, 1500, 0.1, &mut writer, &mut shadow, &mut rng);
+    eprintln!("Shrinking to 200");
+    while shadow.len() > 100 {
+        do_deletes(1, &mut writer, &mut shadow);
+    }
+    do_shrink(&mut writer, &mut shadow, 1500, 200);
+    do_random_ops(50, 1500, 0.25, &mut writer, &mut shadow, &mut rng);
+    eprintln!("Growing to 10k");
+    writer.grow(10000).unwrap();
+    do_random_ops(10000, 5000, 0.25, &mut writer, &mut shadow, &mut rng);
+}
+
+#[test]
+fn test_bucket_ops() {
+    let writer =
+        HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 1200, "test_bucket_ops")
+            .attach_writer();
+    match writer.entry(1.into()) {
+        Entry::Occupied(mut e) => {
+            e.insert(2);
+        }
+        Entry::Vacant(e) => {
+            _ = e.insert(2).unwrap();
+        }
+    }
+    assert_eq!(writer.get_num_buckets_in_use(), 1);
+    assert_eq!(writer.get_num_buckets(), 1000);
+    assert_eq!(*writer.get(&1.into()).unwrap(), 2);
+    let pos = match writer.entry(1.into()) {
+        Entry::Occupied(e) => {
+            assert_eq!(e._key, 1.into());
+            e.bucket_pos as usize
+        }
+        Entry::Vacant(_) => {
+            panic!("Insert didn't affect entry");
+        }
+    };
+    assert_eq!(writer.entry_at_bucket(pos).unwrap()._key, 1.into());
+    assert_eq!(*writer.get_at_bucket(pos).unwrap(), (1.into(), 2));
+    {
+        let ptr: *const usize = &*writer.get(&1.into()).unwrap();
+        assert_eq!(writer.get_bucket_for_value(ptr), pos);
+    }
+    writer.remove(&1.into());
+    assert!(writer.get(&1.into()).is_none());
+}
+
+#[test]
+fn test_shrink_zero() {
+    let mut writer =
+        HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_shrink_zero")
+            .attach_writer();
+    writer.begin_shrink(0);
+    for i in 0..1500 {
+        writer.entry_at_bucket(i).map(|x| x.remove());
+    }
+    writer.finish_shrink().unwrap();
+    assert_eq!(writer.get_num_buckets_in_use(), 0);
+    let entry = writer.entry(1.into());
+    if let Entry::Vacant(v) = entry {
+        assert!(v.insert(2).is_err());
+    } else {
+        panic!("Somehow got non-vacant entry in empty map.")
+    }
+    writer.grow(50).unwrap();
+    let entry = writer.entry(1.into());
+    if let Entry::Vacant(v) = entry {
+        assert!(v.insert(2).is_ok());
+    } else {
+        panic!("Somehow got non-vacant entry in empty map.")
+    }
+    assert_eq!(writer.get_num_buckets_in_use(), 1);
+}
+
+#[test]
+#[should_panic]
+fn test_grow_oom() {
+    let writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_grow_oom")
+        .attach_writer();
+    writer.grow(20000).unwrap();
+}
+
+#[test]
+#[should_panic]
+fn test_shrink_bigger() {
+    let mut writer =
+        HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2500, "test_shrink_bigger")
+            .attach_writer();
+    writer.begin_shrink(2000);
+}
+
+#[test]
+#[should_panic]
+fn test_shrink_early_finish() {
+    let writer =
+        HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2500, "test_shrink_early_finish")
+            .attach_writer();
+    writer.finish_shrink().unwrap();
+}
+
+#[test]
+#[should_panic]
+fn test_shrink_fixed_size() {
+    let mut area = [MaybeUninit::uninit(); 10000];
+    let init_struct = HashMapInit::<TestKey, usize>::with_fixed(3, &mut area);
+    let mut writer = init_struct.attach_writer();
+    writer.begin_shrink(1);
+}
diff --git a/libs/neon-shmem/src/lib.rs b/libs/neon-shmem/src/lib.rs
index 50d3fbb3cf..226cc0c22d 100644
--- a/libs/neon-shmem/src/lib.rs
+++ b/libs/neon-shmem/src/lib.rs
@@ -1 +1,3 @@
+pub mod hash;
 pub mod shmem;
+pub mod sync;
diff --git a/libs/neon-shmem/src/sync.rs b/libs/neon-shmem/src/sync.rs
new file mode 100644
index 0000000000..95719778ba
--- /dev/null
+++ b/libs/neon-shmem/src/sync.rs
@@ -0,0 +1,111 @@
+//! Simple utilities akin to what's in [`std::sync`] but designed to work with shared memory.
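+//!
+//! The central primitive is a process-shared pthread read-write lock exposed through the
+//! [`lock_api`] traits. A typical use, sketched under the assumption that `lock_ptr` points
+//! to enough shared memory to hold a `pthread_rwlock_t`:
+//!
+//! ```ignore
+//! use neon_shmem::sync::{PthreadRwLock, RwLock};
+//!
+//! // Initialize the pthread lock in place, then wrap it together with the guarded value.
+//! let raw = unsafe { PthreadRwLock::new(lock_ptr) };
+//! let lock: RwLock<u32> = RwLock::from_raw(raw, 0);
+//! *lock.write() += 1;
+//! assert_eq!(*lock.read(), 1);
+//! ```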
+
+use std::mem::MaybeUninit;
+use std::ptr::NonNull;
+
+use nix::errno::Errno;
+
+pub type RwLock<T> = lock_api::RwLock<PthreadRwLock, T>;
+pub type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, PthreadRwLock, T>;
+pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, PthreadRwLock, T>;
+pub type ValueReadGuard<'a, T> = lock_api::MappedRwLockReadGuard<'a, PthreadRwLock, T>;
+pub type ValueWriteGuard<'a, T> = lock_api::MappedRwLockWriteGuard<'a, PthreadRwLock, T>;
+
+/// Shared memory read-write lock.
+pub struct PthreadRwLock(Option<NonNull<libc::pthread_rwlock_t>>);
+
+/// Simple macro that calls a function in the libc namespace and panics if the return value is nonzero.
+macro_rules! libc_checked {
+    ($fn_name:ident ( $($arg:expr),* )) => {{
+        let res = libc::$fn_name($($arg),*);
+        if res != 0 {
+            panic!("{} failed with {}", stringify!($fn_name), Errno::from_raw(res));
+        }
+    }};
+}
+
+impl PthreadRwLock {
+    /// Creates a new `PthreadRwLock` on top of a pointer to a pthread rwlock.
+    ///
+    /// # Safety
+    /// `lock` must be non-null. Every underlying pthread call will panic in the event of an error.
+    pub unsafe fn new(lock: *mut libc::pthread_rwlock_t) -> Self {
+        unsafe {
+            let mut attrs = MaybeUninit::uninit();
+            libc_checked!(pthread_rwlockattr_init(attrs.as_mut_ptr()));
+            libc_checked!(pthread_rwlockattr_setpshared(
+                attrs.as_mut_ptr(),
+                libc::PTHREAD_PROCESS_SHARED
+            ));
+            libc_checked!(pthread_rwlock_init(lock, attrs.as_mut_ptr()));
+            // Safety: POSIX specifies that "any function affecting the attributes
+            // object (including destruction) shall not affect any previously
+            // initialized read-write locks".
+            libc_checked!(pthread_rwlockattr_destroy(attrs.as_mut_ptr()));
+            Self(Some(NonNull::new_unchecked(lock)))
+        }
+    }
+
+    fn inner(&self) -> NonNull<libc::pthread_rwlock_t> {
+        match self.0 {
+            None => {
+                panic!("PthreadRwLock constructed badly - something likely used RawRwLock::INIT")
+            }
+            Some(x) => x,
+        }
+    }
+}
+
+unsafe impl lock_api::RawRwLock for PthreadRwLock {
+    type GuardMarker = lock_api::GuardSend;
+    const INIT: Self = Self(None);
+
+    fn try_lock_shared(&self) -> bool {
+        unsafe {
+            let res = libc::pthread_rwlock_tryrdlock(self.inner().as_ptr());
+            match res {
+                0 => true,
+                // EBUSY: the lock is held for writing. EAGAIN: the maximum number of
+                // read locks has been reached. Either way, we didn't get the lock.
+                libc::EBUSY | libc::EAGAIN => false,
+                _ => panic!(
+                    "pthread_rwlock_tryrdlock failed with {}",
+                    Errno::from_raw(res)
+                ),
+            }
+        }
+    }
+
+    fn try_lock_exclusive(&self) -> bool {
+        unsafe {
+            let res = libc::pthread_rwlock_trywrlock(self.inner().as_ptr());
+            match res {
+                0 => true,
+                // EBUSY: the lock is already held by a reader or writer.
+                libc::EBUSY => false,
+                _ => panic!("try_wrlock failed with {}", Errno::from_raw(res)),
+            }
+        }
+    }
+
+    fn lock_shared(&self) {
+        unsafe {
+            libc_checked!(pthread_rwlock_rdlock(self.inner().as_ptr()));
+        }
+    }
+
+    fn lock_exclusive(&self) {
+        unsafe {
+            libc_checked!(pthread_rwlock_wrlock(self.inner().as_ptr()));
+        }
+    }
+
+    unsafe fn unlock_exclusive(&self) {
+        unsafe {
+            libc_checked!(pthread_rwlock_unlock(self.inner().as_ptr()));
+        }
+    }
+
+    unsafe fn unlock_shared(&self) {
+        unsafe {
+            libc_checked!(pthread_rwlock_unlock(self.inner().as_ptr()));
+        }
+    }
+}