Skip to main content

common_procedure/
rwlock.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::HashMap;
16use std::hash::Hash;
17use std::sync::{Arc, Mutex};
18
19use tokio::sync::{OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock};
20
21/// A guard that owns a read or write lock on a key.
22///
23/// This enum wraps either a read or write lock guard obtained from a `KeyRwLock`.
24/// The guard is automatically released when it is dropped.
25pub enum OwnedKeyRwLockGuard {
26    /// Represents a shared read lock on a key.
27    /// Multiple read locks can be held simultaneously for the same key.
28    Read { _guard: OwnedRwLockReadGuard<()> },
29
30    /// Represents an exclusive write lock on a key.
31    /// Only one write lock can be held at a time for a given key,
32    /// and no read locks can be held simultaneously with a write lock.
33    Write { _guard: OwnedRwLockWriteGuard<()> },
34}
35
36impl From<OwnedRwLockReadGuard<()>> for OwnedKeyRwLockGuard {
37    fn from(guard: OwnedRwLockReadGuard<()>) -> Self {
38        OwnedKeyRwLockGuard::Read { _guard: guard }
39    }
40}
41
42impl From<OwnedRwLockWriteGuard<()>> for OwnedKeyRwLockGuard {
43    fn from(guard: OwnedRwLockWriteGuard<()>) -> Self {
44        OwnedKeyRwLockGuard::Write { _guard: guard }
45    }
46}
47
48/// Locks based on a key, allowing other keys to lock independently.
49#[derive(Debug, Default)]
50pub struct KeyRwLock<K> {
51    /// The inner map of locks for specific keys.
52    inner: Mutex<HashMap<K, Arc<RwLock<()>>>>,
53}
54
55impl<K> KeyRwLock<K>
56where
57    K: Eq + Hash + Clone,
58{
59    pub fn new() -> Self {
60        KeyRwLock {
61            inner: Default::default(),
62        }
63    }
64
65    /// Locks the key with shared read access, returning a guard.
66    pub async fn read(&self, key: K) -> OwnedRwLockReadGuard<()> {
67        let lock = {
68            let mut locks = self.inner.lock().unwrap();
69            locks.entry(key).or_default().clone()
70        };
71
72        lock.read_owned().await
73    }
74
75    /// Locks the key with exclusive write access, returning a guard.
76    pub async fn write(&self, key: K) -> OwnedRwLockWriteGuard<()> {
77        let lock = {
78            let mut locks = self.inner.lock().unwrap();
79            locks.entry(key).or_default().clone()
80        };
81
82        lock.write_owned().await
83    }
84
85    /// Clean up stale locks.
86    ///
87    /// Note: It only cleans a lock if
88    /// - Its strong ref count equals one.
89    /// - Able to acquire the write lock.
90    pub fn clean_keys<'a>(&'a self, iter: impl IntoIterator<Item = &'a K>) {
91        let mut locks = self.inner.lock().unwrap();
92        let mut keys = Vec::new();
93        for key in iter {
94            if let Some(lock) = locks.get(key)
95                && lock.try_write().is_ok()
96            {
97                debug_assert_eq!(Arc::weak_count(lock), 0);
98                // Ensures nobody keeps this ref.
99                if Arc::strong_count(lock) == 1 {
100                    keys.push(key);
101                }
102            }
103        }
104
105        for key in keys {
106            locks.remove(key);
107        }
108    }
109
110    /// Clears all key locks.
111    ///
112    /// Callers must ensure no tasks are holding or waiting for these locks.
113    pub fn clear(&self) {
114        self.inner.lock().unwrap().clear();
115    }
116}
117
#[cfg(test)]
impl<K> KeyRwLock<K>
where
    K: Eq + Hash + Clone,
{
    /// Attempts to take shared read access on `key` without waiting.
    ///
    /// A lock for `key` is inserted if none exists yet. Returns `Err` when shared
    /// access cannot be granted immediately.
    pub fn try_read(&self, key: K) -> Result<OwnedRwLockReadGuard<()>, tokio::sync::TryLockError> {
        // The map guard is a temporary of this statement, so it is released
        // before the per-key lock is tried.
        let handle = Arc::clone(self.inner.lock().unwrap().entry(key).or_default());
        handle.try_read_owned()
    }

    /// Attempts to take exclusive write access on `key` without waiting.
    ///
    /// A lock for `key` is inserted if none exists yet. Returns `Err` when
    /// exclusive access cannot be granted immediately.
    pub fn try_write(
        &self,
        key: K,
    ) -> Result<OwnedRwLockWriteGuard<()>, tokio::sync::TryLockError> {
        // The map guard is a temporary of this statement, so it is released
        // before the per-key lock is tried.
        let handle = Arc::clone(self.inner.lock().unwrap().entry(key).or_default());
        handle.try_write_owned()
    }

    /// Number of keys currently present in the map.
    pub fn len(&self) -> usize {
        let map = self.inner.lock().unwrap();
        map.len()
    }

    /// Whether the map currently holds no keys.
    pub fn is_empty(&self) -> bool {
        self.inner.lock().unwrap().is_empty()
    }
}
156
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_naive() {
        let lock_key = KeyRwLock::new();

        {
            let _guard = lock_key.read("test1").await;
            assert_eq!(lock_key.len(), 1);
            assert!(lock_key.try_read("test1").is_ok());
            assert!(lock_key.try_write("test1").is_err());
        }

        {
            let _guard0 = lock_key.write("test2").await;
            let _guard = lock_key.write("test1").await;
            assert_eq!(lock_key.len(), 2);
            assert!(lock_key.try_read("test1").is_err());
            assert!(lock_key.try_write("test1").is_err());
        }

        assert_eq!(lock_key.len(), 2);

        lock_key.clean_keys(&["test1", "test2"]);
        assert!(lock_key.is_empty());

        let mut guards = Vec::new();
        for key in ["test1", "test2"] {
            guards.push(lock_key.read(key).await);
        }
        // Release every read guard before cleaning up.
        guards.clear();
        lock_key.clean_keys(vec![&"test1", &"test2"]);
        assert_eq!(lock_key.len(), 0);
    }

    #[tokio::test]
    async fn test_clean_keys_skips_locked_and_missing_keys() {
        let lock_key = KeyRwLock::<&str>::new();

        let guard = lock_key.write("held").await;
        drop(lock_key.read("free").await);
        assert_eq!(lock_key.len(), 2);

        // "held" is locked, "free" is stale, and "missing" was never inserted.
        lock_key.clean_keys([&"held", &"free", &"missing"]);
        assert_eq!(lock_key.len(), 1);
        // The original entry must survive: a fresh lock would grant this read.
        assert!(lock_key.try_read("held").is_err());

        drop(guard);
        lock_key.clean_keys([&"held"]);
        assert!(lock_key.is_empty());
    }

    #[tokio::test]
    async fn test_clean_keys() {
        let lock_key = KeyRwLock::<&str>::new();
        {
            let rwlock = {
                lock_key
                    .inner
                    .lock()
                    .unwrap()
                    .entry("test")
                    .or_default()
                    .clone()
            };
            assert_eq!(Arc::strong_count(&rwlock), 2);
            let _guard = rwlock.read_owned().await;

            {
                let inner = lock_key.inner.lock().unwrap();
                let rwlock = inner.get("test").unwrap();
                assert_eq!(Arc::strong_count(rwlock), 2);
            }
        }

        {
            let rwlock = {
                lock_key
                    .inner
                    .lock()
                    .unwrap()
                    .entry("test")
                    .or_default()
                    .clone()
            };
            assert_eq!(Arc::strong_count(&rwlock), 2);
            let _guard = rwlock.write_owned().await;

            {
                let inner = lock_key.inner.lock().unwrap();
                let rwlock = inner.get("test").unwrap();
                assert_eq!(Arc::strong_count(rwlock), 2);
            }
        }

        {
            let inner = lock_key.inner.lock().unwrap();
            let rwlock = inner.get("test").unwrap();
            assert_eq!(Arc::strong_count(rwlock), 1);
        }

        // Someone holds a reference to the rwlock but has not been granted the lock yet.
        let rwlock = {
            lock_key
                .inner
                .lock()
                .unwrap()
                .entry("test")
                .or_default()
                .clone()
        };
        assert_eq!(Arc::strong_count(&rwlock), 2);
        // Cleaning the "test" key must therefore have no effect.
        lock_key.clean_keys(vec![&"test"]);
        // The entry must still be present.
        {
            let inner = lock_key.inner.lock().unwrap();
            assert!(inner.contains_key("test"));
        }

        // Once the last outside reference is released, the key becomes stale.
        drop(rwlock);
        lock_key.clean_keys(vec![&"test"]);
        assert!(lock_key.is_empty());
    }
}
264}