diff --git a/Cargo.lock b/Cargo.lock
index 4fd5f5802b..4bb6d2d474 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4037,6 +4037,7 @@ dependencies = [
  "criterion",
  "foldhash",
  "hashbrown 0.15.4 (git+https://github.com/quantumish/hashbrown.git?rev=6610e6d)",
+ "libc",
  "nix 0.30.1",
  "rand 0.9.1",
  "rand_distr 0.5.1",
diff --git a/libs/neon-shmem/Cargo.toml b/libs/neon-shmem/Cargo.toml
index 08b5ac2138..8306bbf778 100644
--- a/libs/neon-shmem/Cargo.toml
+++ b/libs/neon-shmem/Cargo.toml
@@ -10,6 +10,7 @@ nix.workspace = true
 workspace_hack = { version = "0.1", path = "../../workspace_hack" }
 rustc-hash = { version = "2.1.1" }
 rand = "0.9.1"
+libc.workspace = true
 
 [dev-dependencies]
 criterion = { workspace = true, features = ["html_reports"] }
diff --git a/libs/neon-shmem/src/lib.rs b/libs/neon-shmem/src/lib.rs
index f601010122..61ca168073 100644
--- a/libs/neon-shmem/src/lib.rs
+++ b/libs/neon-shmem/src/lib.rs
@@ -2,3 +2,4 @@
 
 pub mod hash;
 pub mod shmem;
+pub mod sync;
diff --git a/libs/neon-shmem/src/sync.rs b/libs/neon-shmem/src/sync.rs
new file mode 100644
index 0000000000..68ef7b904d
--- /dev/null
+++ b/libs/neon-shmem/src/sync.rs
@@ -0,0 +1,336 @@
+//! Simple utilities akin to what's in [`std::sync`] but designed to work with shared memory.
+
+use std::cell::UnsafeCell;
+use std::marker::PhantomData;
+use std::mem::MaybeUninit;
+use std::ops::{Deref, DerefMut};
+use std::ptr::NonNull;
+
+use thiserror::Error;
+
+/// Read-write lock built on a process-shared `pthread_rwlock_t`.
+///
+/// The storage for the pthread lock is supplied by the caller of [`RwLock::new`], while the
+/// protected value is stored inline in this struct. The pthread lock is never destroyed by
+/// this type.
+pub struct RwLock<'a, T: ?Sized> {
+    /// Caller-provided lock storage, initialized by [`RwLock::new`].
+    ///
+    /// Held as a raw pointer rather than a reference because the pthread functions mutate
+    /// the lock object while only `&self` is available.
+    inner: NonNull<libc::pthread_rwlock_t>,
+    /// Keeps the storage behind `inner` exclusively borrowed for `'a`.
+    _storage: PhantomData<&'a mut libc::pthread_rwlock_t>,
+    /// The protected value. This must remain the last field because `T` may be unsized.
+    data: UnsafeCell<T>,
+}
+
+/// RAII guard for a read lock.
+pub struct RwLockReadGuard<'a, 'b, T: ?Sized> {
+    /// Pointer to the protected value. `NonNull` also keeps the guard `!Send`, which matters
+    /// because the lock has to be released by the thread that acquired it.
+    data: NonNull<T>,
+    lock: &'a RwLock<'b, T>,
+}
+
+/// RAII guard for a write lock.
+pub struct RwLockWriteGuard<'a, 'b, T: ?Sized> {
+    lock: &'a RwLock<'b, T>,
+    /// Keeps the guard `!Send`: the lock has to be released by the thread that acquired it.
+    _not_send: PhantomData<*mut ()>,
+}
+
+// TODO(quantumish): Support poisoning errors?
+/// Errors returned when acquiring an [`RwLock`] fails.
+#[derive(Error, Debug, Clone, Copy, PartialEq, Eq)]
+pub enum RwLockError {
+    /// The pthread call reported `EDEADLK`.
+    #[error("deadlock detected")]
+    Deadlock,
+    /// The pthread call reported `EAGAIN`.
+    #[error("max number of read locks exceeded")]
+    MaxReadLocks,
+    /// A non-blocking call reported `EBUSY`.
+    #[error("nonblocking operation would block")]
+    WouldBlock,
+}
+
+// SAFETY: the pthread lock is process-shared and usable from any thread. Moving the `RwLock`
+// moves the `T` stored inside it, hence `T: Send`.
+unsafe impl<T: ?Sized + Send> Send for RwLock<'_, T> {}
+// SAFETY: a shared `RwLock` hands out `&T` to concurrent readers (`T: Sync`) and `&mut T` to
+// a writer on another thread (`T: Send`); the pthread lock serializes those accesses.
+unsafe impl<T: ?Sized + Send + Sync> Sync for RwLock<'_, T> {}
+
+impl<'a, T> RwLock<'a, T> {
+    /// Initializes `lock` as a process-shared pthread read-write lock guarding `data`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if any of the pthread initialization calls fails, since the lock storage
+    /// would otherwise be used uninitialized or without being process-shared.
+    // TODO(quantumish): worth making this function return Result?
+    pub fn new(lock: &'a mut MaybeUninit<libc::pthread_rwlock_t>, data: T) -> Self {
+        let mut attrs = MaybeUninit::<libc::pthread_rwlockattr_t>::uninit();
+        // SAFETY: `attrs` and `lock` are valid, exclusively borrowed storage for the objects
+        // being initialized, and `attrs` is initialized before it is used.
+        unsafe {
+            let res = libc::pthread_rwlockattr_init(attrs.as_mut_ptr());
+            assert_eq!(res, 0, "pthread_rwlockattr_init failed: {res}");
+            let res = libc::pthread_rwlockattr_setpshared(
+                attrs.as_mut_ptr(),
+                libc::PTHREAD_PROCESS_SHARED,
+            );
+            assert_eq!(res, 0, "pthread_rwlockattr_setpshared failed: {res}");
+            let res = libc::pthread_rwlock_init(lock.as_mut_ptr(), attrs.as_ptr());
+            // POSIX specifies that "any function affecting the attributes object
+            // (including destruction) shall not affect any previously initialized
+            // read-write locks", so the attributes can be released right away.
+            libc::pthread_rwlockattr_destroy(attrs.as_mut_ptr());
+            assert_eq!(res, 0, "pthread_rwlock_init failed: {res}");
+        }
+        Self {
+            inner: NonNull::from(lock).cast(),
+            _storage: PhantomData,
+            data: UnsafeCell::new(data),
+        }
+    }
+}
+
+impl<'a, T: ?Sized> RwLock<'a, T> {
+    /// Acquires a read lock, blocking until it is available.
+    ///
+    /// # Errors
+    ///
+    /// [`RwLockError::Deadlock`] on `EDEADLK`, [`RwLockError::MaxReadLocks`] on `EAGAIN`.
+    ///
+    /// # Panics
+    ///
+    /// Panics on `EINVAL` or any other unexpected error code.
+    pub fn read(&self) -> Result<RwLockReadGuard<'_, 'a, T>, RwLockError> {
+        // SAFETY: `inner` points to the lock initialized by `new` and borrowed for `'a`.
+        match unsafe { libc::pthread_rwlock_rdlock(self.inner.as_ptr()) } {
+            0 => Ok(self.read_guard()),
+            libc::EINVAL => panic!("failed to properly initialize lock"),
+            libc::EDEADLK => Err(RwLockError::Deadlock),
+            libc::EAGAIN => Err(RwLockError::MaxReadLocks),
+            e => panic!("unknown error code returned: {e}"),
+        }
+    }
+
+    /// Attempts to acquire a read lock without blocking.
+    ///
+    /// # Errors
+    ///
+    /// [`RwLockError::WouldBlock`] on `EBUSY`, [`RwLockError::Deadlock`] on `EDEADLK`,
+    /// [`RwLockError::MaxReadLocks`] on `EAGAIN`.
+    ///
+    /// # Panics
+    ///
+    /// Panics on `EINVAL` or any other unexpected error code.
+    pub fn try_read(&self) -> Result<RwLockReadGuard<'_, 'a, T>, RwLockError> {
+        // SAFETY: `inner` points to the lock initialized by `new` and borrowed for `'a`.
+        match unsafe { libc::pthread_rwlock_tryrdlock(self.inner.as_ptr()) } {
+            0 => Ok(self.read_guard()),
+            libc::EINVAL => panic!("failed to properly initialize lock"),
+            libc::EDEADLK => Err(RwLockError::Deadlock),
+            libc::EAGAIN => Err(RwLockError::MaxReadLocks),
+            libc::EBUSY => Err(RwLockError::WouldBlock),
+            e => panic!("unknown error code returned: {e}"),
+        }
+    }
+
+    /// Acquires the write lock, blocking until it is available.
+    ///
+    /// # Errors
+    ///
+    /// [`RwLockError::Deadlock`] on `EDEADLK`.
+    ///
+    /// # Panics
+    ///
+    /// Panics on `EINVAL` or any other unexpected error code.
+    pub fn write(&self) -> Result<RwLockWriteGuard<'_, 'a, T>, RwLockError> {
+        // SAFETY: `inner` points to the lock initialized by `new` and borrowed for `'a`.
+        match unsafe { libc::pthread_rwlock_wrlock(self.inner.as_ptr()) } {
+            0 => Ok(self.write_guard()),
+            libc::EINVAL => panic!("failed to properly initialize lock"),
+            libc::EDEADLK => Err(RwLockError::Deadlock),
+            e => panic!("unknown error code returned: {e}"),
+        }
+    }
+
+    /// Attempts to acquire the write lock without blocking.
+    ///
+    /// # Errors
+    ///
+    /// [`RwLockError::WouldBlock`] on `EBUSY`, [`RwLockError::Deadlock`] on `EDEADLK`.
+    ///
+    /// # Panics
+    ///
+    /// Panics on `EINVAL` or any other unexpected error code.
+    pub fn try_write(&self) -> Result<RwLockWriteGuard<'_, 'a, T>, RwLockError> {
+        // SAFETY: `inner` points to the lock initialized by `new` and borrowed for `'a`.
+        match unsafe { libc::pthread_rwlock_trywrlock(self.inner.as_ptr()) } {
+            0 => Ok(self.write_guard()),
+            libc::EINVAL => panic!("failed to properly initialize lock"),
+            libc::EDEADLK => Err(RwLockError::Deadlock),
+            libc::EBUSY => Err(RwLockError::WouldBlock),
+            e => panic!("unknown error code returned: {e}"),
+        }
+    }
+
+    /// Builds the guard for a read lock that has just been acquired.
+    fn read_guard(&self) -> RwLockReadGuard<'_, 'a, T> {
+        RwLockReadGuard {
+            // SAFETY: `UnsafeCell::get` never returns a null pointer.
+            data: unsafe { NonNull::new_unchecked(self.data.get()) },
+            lock: self,
+        }
+    }
+
+    /// Builds the guard for a write lock that has just been acquired.
+    fn write_guard(&self) -> RwLockWriteGuard<'_, 'a, T> {
+        RwLockWriteGuard {
+            lock: self,
+            _not_send: PhantomData,
+        }
+    }
+}
+
+// SAFETY: a shared read guard only exposes `&T`, which may be shared when `T: Sync`.
+unsafe impl<T: ?Sized + Sync> Sync for RwLockReadGuard<'_, '_, T> {}
+// SAFETY: a shared write guard only exposes `&T`, which may be shared when `T: Sync`.
+unsafe impl<T: ?Sized + Sync> Sync for RwLockWriteGuard<'_, '_, T> {}
+
+impl<T: ?Sized> Deref for RwLockReadGuard<'_, '_, T> {
+    type Target = T;
+
+    fn deref(&self) -> &T {
+        // SAFETY: the read lock is held for as long as the guard exists, so no writer can
+        // be mutating the value.
+        unsafe { self.data.as_ref() }
+    }
+}
+
+impl<T: ?Sized> Deref for RwLockWriteGuard<'_, '_, T> {
+    type Target = T;
+
+    fn deref(&self) -> &T {
+        // SAFETY: the write lock is held for as long as the guard exists.
+        unsafe { &*self.lock.data.get() }
+    }
+}
+
+impl<T: ?Sized> DerefMut for RwLockWriteGuard<'_, '_, T> {
+    fn deref_mut(&mut self) -> &mut T {
+        // SAFETY: the write lock is held for as long as the guard exists, and `&mut self`
+        // rules out other references obtained through this guard.
+        unsafe { &mut *self.lock.data.get() }
+    }
+}
+
+impl<T: ?Sized> Drop for RwLockReadGuard<'_, '_, T> {
+    fn drop(&mut self) {
+        // SAFETY: the guard's existence means this thread holds a read lock on `inner`.
+        let res = unsafe { libc::pthread_rwlock_unlock(self.lock.inner.as_ptr()) };
+        debug_assert_eq!(res, 0, "pthread_rwlock_unlock failed");
+    }
+}
+
+impl<T: ?Sized> Drop for RwLockWriteGuard<'_, '_, T> {
+    fn drop(&mut self) {
+        // SAFETY: the guard's existence means this thread holds the write lock on `inner`.
+        let res = unsafe { libc::pthread_rwlock_unlock(self.lock.inner.as_ptr()) };
+        debug_assert_eq!(res, 0, "pthread_rwlock_unlock failed");
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use super::*;
+    use RwLockError::*;
+
+    #[test]
+    fn test_single_process() {
+        let mut lock = MaybeUninit::uninit();
+        let wrapper = RwLock::new(&mut lock, 0);
+        let mut writer = wrapper.write().unwrap();
+        assert!(matches!(wrapper.try_write(), Err(Deadlock | WouldBlock)));
+        assert!(matches!(wrapper.try_read(), Err(Deadlock | WouldBlock)));
+        *writer = 5;
+        drop(writer);
+        let reader = wrapper.read().unwrap();
+        assert!(matches!(wrapper.try_write(), Err(Deadlock | WouldBlock)));
+        assert!(matches!(wrapper.read(), Ok(_)));
+        assert_eq!(*reader, 5);
+        drop(reader);
+        assert!(matches!(wrapper.try_write(), Ok(_)));
+    }
+
+    #[test]
+    fn test_multi_thread() {
+        let lock = Box::new(MaybeUninit::uninit());
+        let wrapper = Arc::new(RwLock::new(Box::leak(lock), 0));
+        let mut writer = wrapper.write().unwrap();
+        let t1 = {
+            let wrapper = wrapper.clone();
+            std::thread::spawn(move || {
+                let mut writer = wrapper.write().unwrap();
+                *writer = 20;
+            })
+        };
+        assert_eq!(*writer, 0);
+        *writer = 10;
+        assert_eq!(*writer, 10);
+        drop(writer);
+        t1.join().unwrap();
+        let writer = wrapper.write().unwrap();
+        assert_eq!(*writer, 20);
+        drop(writer);
+        let mut handles = vec![];
+        for _ in 0..5 {
+            handles.push({
+                let wrapper = wrapper.clone();
+                std::thread::spawn(move || {
+                    let reader = wrapper.read().unwrap();
+                    assert_eq!(*reader, 20);
+                })
+            });
+        }
+        for h in handles {
+            h.join().unwrap();
+        }
+        let writer = wrapper.write().unwrap();
+        assert_eq!(*writer, 20);
+    }
+
+    // // TODO(quantumish): Terrible time-based synchronization, fix me.
+    // // NOTE(review): `wrapper` (and the value it guards) is a local here rather than part of
+    // // the shared mapping, so the child's write may not be visible to the parent - confirm.
+    // #[test]
+    // fn test_multi_process() {
+    //     let max_size = 100;
+    //     let init_struct = crate::shmem::ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
+    //     let ptr = init_struct.data_ptr.as_ptr();
+    //     let lock: &mut _ = unsafe { ptr.add(
+    //         ptr.align_offset(std::mem::align_of::<MaybeUninit<libc::pthread_rwlock_t>>())
+    //     ).cast::<MaybeUninit<libc::pthread_rwlock_t>>().as_mut().unwrap() } ;
+    //     let wrapper = RwLock::new(lock, 0);
+
+    //     let fork_result = unsafe { nix::unistd::fork().unwrap() };
+
+    //     if !fork_result.is_parent() {
+    //         let mut writer = wrapper.write().unwrap();
+    //         std::thread::sleep(std::time::Duration::from_secs(5));
+    //         *writer = 2;
+    //     } else {
+    //         std::thread::sleep(std::time::Duration::from_secs(1));
+    //         assert!(matches!(wrapper.try_write(), Err(WouldBlock)));
+    //         std::thread::sleep(std::time::Duration::from_secs(10));
+    //         let writer = wrapper.try_write().unwrap();
+    //         assert_eq!(*writer, 2);
+    //     }
+    // }
+}