diff --git a/src/common/base/src/lib.rs b/src/common/base/src/lib.rs index d4e1454e9d..539da1ba8c 100644 --- a/src/common/base/src/lib.rs +++ b/src/common/base/src/lib.rs @@ -15,68 +15,12 @@ pub mod bit_vec; pub mod buffer; pub mod bytes; +pub mod plugins; #[allow(clippy::all)] pub mod readable_size; pub mod secrets; -use core::any::Any; -use std::sync::{Arc, Mutex, MutexGuard}; - pub type AffectedRows = usize; pub use bit_vec::BitVec; - -/// [`Plugins`] is a wrapper of Arc contents. -/// Make it Cloneable and we can treat it like an Arc struct. -#[derive(Default, Clone)] -pub struct Plugins { - inner: Arc>>, -} - -impl Plugins { - pub fn new() -> Self { - Self { - inner: Arc::new(Mutex::new(anymap::Map::new())), - } - } - - fn lock(&self) -> MutexGuard> { - self.inner.lock().unwrap() - } - - pub fn insert(&self, value: T) { - let _ = self.lock().insert(value); - } - - pub fn get(&self) -> Option { - let binding = self.lock(); - binding.get::().cloned() - } - - pub fn map_mut(&self, mapper: F) -> R - where - F: FnOnce(Option<&mut T>) -> R, - { - let mut binding = self.lock(); - let opt = binding.get_mut::(); - mapper(opt) - } - - pub fn map(&self, mapper: F) -> Option - where - F: FnOnce(&T) -> R, - { - let binding = self.lock(); - binding.get::().map(mapper) - } - - pub fn len(&self) -> usize { - let binding = self.lock(); - binding.len() - } - - pub fn is_empty(&self) -> bool { - let binding = self.lock(); - binding.is_empty() - } -} +pub use plugins::Plugins; diff --git a/src/common/base/src/plugins.rs b/src/common/base/src/plugins.rs new file mode 100644 index 0000000000..84d78b0c91 --- /dev/null +++ b/src/common/base/src/plugins.rs @@ -0,0 +1,127 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::any::Any; +use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}; + +/// [`Plugins`] is a wrapper of [AnyMap](https://github.com/chris-morgan/anymap) and provides a thread-safe way to store and retrieve plugins. +/// Make it Cloneable and we can treat it like an Arc struct. +#[derive(Default, Clone)] +pub struct Plugins { + inner: Arc>>, +} + +impl Plugins { + pub fn new() -> Self { + Self { + inner: Arc::new(RwLock::new(anymap::Map::new())), + } + } + + pub fn insert(&self, value: T) { + let _ = self.write().insert(value); + } + + pub fn get(&self) -> Option { + self.read().get::().cloned() + } + + pub fn map_mut(&self, mapper: F) -> R + where + F: FnOnce(Option<&mut T>) -> R, + { + let mut binding = self.write(); + let opt = binding.get_mut::(); + mapper(opt) + } + + pub fn map(&self, mapper: F) -> Option + where + F: FnOnce(&T) -> R, + { + self.read().get::().map(mapper) + } + + pub fn len(&self) -> usize { + self.read().len() + } + + pub fn is_empty(&self) -> bool { + self.read().is_empty() + } + + fn read(&self) -> RwLockReadGuard> { + self.inner.read().unwrap() + } + + fn write(&self) -> RwLockWriteGuard> { + self.inner.write().unwrap() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_plugins() { + #[derive(Debug, Clone)] + struct FooPlugin { + x: i32, + } + + #[derive(Debug, Clone)] + struct BarPlugin { + y: String, + } + + let plugins = Plugins::new(); + + let m = plugins.clone(); + let thread1 = std::thread::spawn(move || { + m.insert(FooPlugin { x: 42 }); + + if let Some(foo) = m.get::() { + assert_eq!(foo.x, 42); + } + + assert_eq!(m.map::(|foo| foo.x * 2), Some(84)); + }); + + let m = plugins.clone(); + let thread2 = std::thread::spawn(move || { + m.clone().insert(BarPlugin { + y: "hello".to_string(), + }); + + if let Some(bar) = m.get::() { + assert_eq!(bar.y, "hello"); + } + + m.map_mut::(|bar| { + if let Some(bar) = bar { + bar.y = "world".to_string(); + } + }); + + assert_eq!(m.get::().unwrap().y, "world"); + }); + + thread1.join().unwrap(); + thread2.join().unwrap(); + + assert_eq!(plugins.len(), 2); + assert!(!plugins.is_empty()); + } +}