common_procedure/
rwlock.rs1use std::collections::HashMap;
16use std::hash::Hash;
17use std::sync::{Arc, Mutex};
18
19use tokio::sync::{OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock};
20
21pub enum OwnedKeyRwLockGuard {
26 Read { _guard: OwnedRwLockReadGuard<()> },
29
30 Write { _guard: OwnedRwLockWriteGuard<()> },
34}
35
36impl From<OwnedRwLockReadGuard<()>> for OwnedKeyRwLockGuard {
37 fn from(guard: OwnedRwLockReadGuard<()>) -> Self {
38 OwnedKeyRwLockGuard::Read { _guard: guard }
39 }
40}
41
42impl From<OwnedRwLockWriteGuard<()>> for OwnedKeyRwLockGuard {
43 fn from(guard: OwnedRwLockWriteGuard<()>) -> Self {
44 OwnedKeyRwLockGuard::Write { _guard: guard }
45 }
46}
47
/// A set of read-write locks addressed by key.
///
/// Each key maps to its own reference-counted `RwLock<()>`; the locks carry
/// no data and exist purely for coordination.
#[derive(Debug)]
pub struct KeyRwLock<K> {
    /// Key-to-lock table. The mutex guards the table itself, not the
    /// per-key locks stored inside it.
    inner: Mutex<HashMap<K, Arc<RwLock<()>>>>,
}

// Written by hand rather than derived: `#[derive(Default)]` would emit
// `impl<K: Default> Default`, yet an empty table needs no key value at all.
impl<K> Default for KeyRwLock<K> {
    /// Creates a `KeyRwLock` whose table is empty.
    fn default() -> Self {
        Self {
            inner: Mutex::new(HashMap::new()),
        }
    }
}
54
55impl<K> KeyRwLock<K>
56where
57 K: Eq + Hash + Clone,
58{
59 pub fn new() -> Self {
60 KeyRwLock {
61 inner: Default::default(),
62 }
63 }
64
65 pub async fn read(&self, key: K) -> OwnedRwLockReadGuard<()> {
67 let lock = {
68 let mut locks = self.inner.lock().unwrap();
69 locks.entry(key).or_default().clone()
70 };
71
72 lock.read_owned().await
73 }
74
75 pub async fn write(&self, key: K) -> OwnedRwLockWriteGuard<()> {
77 let lock = {
78 let mut locks = self.inner.lock().unwrap();
79 locks.entry(key).or_default().clone()
80 };
81
82 lock.write_owned().await
83 }
84
85 pub fn clean_keys<'a>(&'a self, iter: impl IntoIterator<Item = &'a K>) {
91 let mut locks = self.inner.lock().unwrap();
92 let mut keys = Vec::new();
93 for key in iter {
94 if let Some(lock) = locks.get(key)
95 && lock.try_write().is_ok()
96 {
97 debug_assert_eq!(Arc::weak_count(lock), 0);
98 if Arc::strong_count(lock) == 1 {
100 keys.push(key);
101 }
102 }
103 }
104
105 for key in keys {
106 locks.remove(key);
107 }
108 }
109
110 pub fn clear(&self) {
114 self.inner.lock().unwrap().clear();
115 }
116}
117
#[cfg(test)]
impl<K> KeyRwLock<K>
where
    K: Eq + Hash + Clone,
{
    /// Attempts to take the read lock of `key` without waiting.
    ///
    /// The entry for `key` is created if it does not exist yet.
    pub fn try_read(&self, key: K) -> Result<OwnedRwLockReadGuard<()>, tokio::sync::TryLockError> {
        // The temporary mutex guard is dropped at the end of this statement,
        // before the per-key lock is touched.
        let rwlock = Arc::clone(self.inner.lock().unwrap().entry(key).or_default());
        rwlock.try_read_owned()
    }

    /// Attempts to take the write lock of `key` without waiting.
    ///
    /// The entry for `key` is created if it does not exist yet.
    pub fn try_write(
        &self,
        key: K,
    ) -> Result<OwnedRwLockWriteGuard<()>, tokio::sync::TryLockError> {
        // The temporary mutex guard is dropped at the end of this statement,
        // before the per-key lock is touched.
        let rwlock = Arc::clone(self.inner.lock().unwrap().entry(key).or_default());
        rwlock.try_write_owned()
    }

    /// Number of keys currently present in the table.
    pub fn len(&self) -> usize {
        let table = self.inner.lock().unwrap();
        table.len()
    }

    /// Whether the table currently holds no keys.
    pub fn is_empty(&self) -> bool {
        self.inner.lock().unwrap().is_empty()
    }
}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// Clones the lock stored under `key` straight out of the inner table,
    /// inserting a fresh one when the key is absent.
    fn checkout(lock_key: &KeyRwLock<&'static str>, key: &'static str) -> Arc<RwLock<()>> {
        let mut table = lock_key.inner.lock().unwrap();
        Arc::clone(table.entry(key).or_default())
    }

    /// Strong reference count of the lock stored under `key`.
    ///
    /// Panics if the key is not in the table.
    fn stored_strong_count(lock_key: &KeyRwLock<&'static str>, key: &str) -> usize {
        let table = lock_key.inner.lock().unwrap();
        Arc::strong_count(table.get(key).unwrap())
    }

    #[tokio::test]
    async fn test_naive() {
        let lock_key = KeyRwLock::new();

        // A held read guard admits further readers but no writer.
        {
            let _shared = lock_key.read("test1").await;
            assert_eq!(lock_key.len(), 1);
            assert!(lock_key.try_read("test1").is_ok());
            assert!(lock_key.try_write("test1").is_err());
        }

        // A held write guard admits neither readers nor writers.
        {
            let _exclusive2 = lock_key.write("test2").await;
            let _exclusive1 = lock_key.write("test1").await;
            assert_eq!(lock_key.len(), 2);
            assert!(lock_key.try_read("test1").is_err());
            assert!(lock_key.try_write("test1").is_err());
        }

        // Dropping the guards does not remove the entries by itself.
        assert_eq!(lock_key.len(), 2);

        lock_key.clean_keys(&["test1", "test2"]);
        assert!(lock_key.is_empty());

        let mut guards = Vec::new();
        for key in ["test1", "test2"] {
            let guard = lock_key.read(key).await;
            guards.push(guard);
        }
        // Release the guards one at a time, last acquired first.
        while guards.pop().is_some() {}

        lock_key.clean_keys([&"test1", &"test2"]);
        assert_eq!(lock_key.len(), 0);
    }

    #[tokio::test]
    async fn test_clean_keys() {
        let lock_key = KeyRwLock::<&str>::new();

        // While a read guard is alive, the table entry and the guard each
        // account for one strong reference.
        {
            let handle = checkout(&lock_key, "test");
            assert_eq!(Arc::strong_count(&handle), 2);
            let _guard = handle.read_owned().await;
            assert_eq!(stored_strong_count(&lock_key, "test"), 2);
        }

        // Same bookkeeping for a write guard.
        {
            let handle = checkout(&lock_key, "test");
            assert_eq!(Arc::strong_count(&handle), 2);
            let _guard = handle.write_owned().await;
            assert_eq!(stored_strong_count(&lock_key, "test"), 2);
        }

        // With every guard gone, only the table references the lock.
        assert_eq!(stored_strong_count(&lock_key, "test"), 1);

        // An outstanding handle must keep `clean_keys` from removing the key.
        let handle = checkout(&lock_key, "test");
        assert_eq!(Arc::strong_count(&handle), 2);
        lock_key.clean_keys(vec![&"test"]);
        assert!(lock_key.inner.lock().unwrap().contains_key("test"));
    }
}