Skip to main content

common_meta/cache/
registry.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use anymap2::SendSyncAnyMap;
18use futures::future::join_all;
19
20use crate::cache_invalidator::{CacheInvalidator, Context};
21use crate::error::Result;
22use crate::instruction::CacheIdent;
23
24pub type CacheRegistryRef = Arc<CacheRegistry>;
25pub type LayeredCacheRegistryRef = Arc<LayeredCacheRegistry>;
26
27/// [LayeredCacheRegistry] Builder.
28#[derive(Default)]
29pub struct LayeredCacheRegistryBuilder {
30    registry: LayeredCacheRegistry,
31}
32
33impl LayeredCacheRegistryBuilder {
34    /// Adds [CacheRegistry] into the next layer.
35    ///
36    /// During cache invalidation, [LayeredCacheRegistry] ensures sequential invalidation
37    /// of each layer (after the previous layer).
38    pub fn add_cache_registry(mut self, registry: CacheRegistry) -> Self {
39        self.registry.layers.push(registry);
40
41        self
42    }
43
44    /// Returns __cloned__ the value stored in the collection for the type `T`, if it exists.
45    pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
46        self.registry.get()
47    }
48
49    /// Builds the [LayeredCacheRegistry]
50    pub fn build(self) -> LayeredCacheRegistry {
51        self.registry
52    }
53}
54
55/// [LayeredCacheRegistry] invalidate caches sequentially from the first layer.
56#[derive(Default)]
57pub struct LayeredCacheRegistry {
58    layers: Vec<CacheRegistry>,
59}
60
61#[async_trait::async_trait]
62impl CacheInvalidator for LayeredCacheRegistry {
63    async fn invalidate(&self, ctx: &Context, caches: &[CacheIdent]) -> Result<()> {
64        let mut results = Vec::with_capacity(self.layers.len());
65        for registry in &self.layers {
66            results.push(registry.invalidate(ctx, caches).await);
67        }
68        results.into_iter().collect::<Result<Vec<_>>>().map(|_| ())
69    }
70
71    fn invalidate_all(&self) -> Result<()> {
72        for registry in &self.layers {
73            registry.invalidate_all()?;
74        }
75        Ok(())
76    }
77}
78
79impl LayeredCacheRegistry {
80    /// Returns __cloned__ the value stored in the collection for the type `T`, if it exists.
81    pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
82        for registry in &self.layers {
83            if let Some(cache) = registry.get::<T>() {
84                return Some(cache);
85            }
86        }
87
88        None
89    }
90}
91
92/// [CacheRegistryBuilder] provides ability of
93/// - Register the `cache` which implements the [CacheInvalidator] trait into [CacheRegistry].
94/// - Build a [CacheRegistry]
95#[derive(Default)]
96pub struct CacheRegistryBuilder {
97    registry: CacheRegistry,
98}
99
100impl CacheRegistryBuilder {
101    /// Adds the cache.
102    pub fn add_cache<T: CacheInvalidator + 'static>(mut self, cache: Arc<T>) -> Self {
103        self.registry.register(cache);
104        self
105    }
106
107    /// Builds [CacheRegistry].
108    pub fn build(self) -> CacheRegistry {
109        self.registry
110    }
111}
112
113/// [CacheRegistry] provides ability of
114/// - Get a specific `cache`.
115#[derive(Default)]
116pub struct CacheRegistry {
117    indexes: Vec<Arc<dyn CacheInvalidator>>,
118    registry: SendSyncAnyMap,
119}
120
121#[async_trait::async_trait]
122impl CacheInvalidator for CacheRegistry {
123    async fn invalidate(&self, ctx: &Context, caches: &[CacheIdent]) -> Result<()> {
124        let tasks = self
125            .indexes
126            .iter()
127            .map(|invalidator| invalidator.invalidate(ctx, caches));
128        join_all(tasks)
129            .await
130            .into_iter()
131            .collect::<Result<Vec<_>>>()?;
132        Ok(())
133    }
134
135    fn invalidate_all(&self) -> Result<()> {
136        for invalidator in &self.indexes {
137            invalidator.invalidate_all()?;
138        }
139        Ok(())
140    }
141}
142
143impl CacheRegistry {
144    /// Sets the value stored in the collection for the type `T`.
145    /// Returns true if the collection already had a value of type `T`
146    fn register<T: CacheInvalidator + 'static>(&mut self, cache: Arc<T>) -> bool {
147        self.indexes.push(cache.clone());
148        self.registry.insert(cache).is_some()
149    }
150
151    /// Returns __cloned__ the value stored in the collection for the type `T`, if it exists.
152    pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
153        self.registry.get().cloned()
154    }
155}
156
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicI32, Ordering};

    use moka::future::{Cache, CacheBuilder};

    use crate::cache::registry::CacheRegistryBuilder;
    use crate::cache::*;
    use crate::cache_invalidator::{CacheInvalidator, Context};
    use crate::error::Result;
    use crate::instruction::CacheIdent;

    /// Token filter that lets every [CacheIdent] through to the cache's invalidator.
    fn always_true_filter(_: &CacheIdent) -> bool {
        true
    }

    /// Builds a `String -> String` cache whose initializer always yields `"hi"`.
    fn test_cache(
        name: &str,
        invalidator: Invalidator<String, String, CacheIdent>,
    ) -> CacheContainer<String, String, CacheIdent> {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        let init: Initializer<String, String> = Arc::new(move |_| {
            moved_counter.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("hi".to_string())) })
        });

        CacheContainer::new(
            name.to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
        )
    }

    /// Builds an `i32 -> String` cache whose initializer always yields `"foo"`.
    fn test_i32_cache(
        name: &str,
        invalidator: Invalidator<i32, String, CacheIdent>,
    ) -> CacheContainer<i32, String, CacheIdent> {
        let cache: Cache<i32, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        let init: Initializer<i32, String> = Arc::new(move |_| {
            moved_counter.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("foo".to_string())) })
        });

        CacheContainer::new(
            name.to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
        )
    }

    #[tokio::test]
    async fn test_register() {
        let builder = CacheRegistryBuilder::default();
        let invalidator: Invalidator<_, String, CacheIdent> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));
        let i32_cache = Arc::new(test_i32_cache("i32_cache", invalidator));
        let invalidator: Invalidator<_, String, CacheIdent> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));
        let cache = Arc::new(test_cache("string_cache", invalidator));
        let registry = builder.add_cache(i32_cache).add_cache(cache).build();

        let cache = registry
            .get::<Arc<CacheContainer<i32, String, CacheIdent>>>()
            .unwrap();
        assert_eq!(cache.name(), "i32_cache");

        let cache = registry
            .get::<Arc<CacheContainer<String, String, CacheIdent>>>()
            .unwrap();
        assert_eq!(cache.name(), "string_cache");
    }

    #[tokio::test]
    async fn test_layered_registry() {
        let builder = LayeredCacheRegistryBuilder::default();
        // 1st layer: asserts it runs before the flag is set, then sets it.
        let counter = Arc::new(AtomicBool::new(false));
        let moved_counter = counter.clone();
        let invalidator: Invalidator<String, String, CacheIdent> = Box::new(move |_, _| {
            let counter = moved_counter.clone();
            Box::pin(async move {
                assert!(!counter.load(Ordering::Relaxed));
                counter.store(true, Ordering::Relaxed);
                Ok(())
            })
        });
        let cache = Arc::new(test_cache("string_cache", invalidator));
        let builder =
            builder.add_cache_registry(CacheRegistryBuilder::default().add_cache(cache).build());
        // 2nd layer: asserts the 1st layer has already been invalidated.
        let moved_counter = counter.clone();
        let invalidator: Invalidator<i32, String, CacheIdent> = Box::new(move |_, _| {
            let counter = moved_counter.clone();
            Box::pin(async move {
                assert!(counter.load(Ordering::Relaxed));
                Ok(())
            })
        });
        let i32_cache = Arc::new(test_i32_cache("i32_cache", invalidator));
        let builder = builder
            .add_cache_registry(CacheRegistryBuilder::default().add_cache(i32_cache).build());

        let registry = builder.build();
        let cache = registry
            .get::<Arc<CacheContainer<i32, String, CacheIdent>>>()
            .unwrap();
        assert_eq!(cache.name(), "i32_cache");
        let cache = registry
            .get::<Arc<CacheContainer<String, String, CacheIdent>>>()
            .unwrap();
        assert_eq!(cache.name(), "string_cache");

        // Actually trigger invalidation so the ordering assertions inside the invalidators run.
        // Exactly one ident is passed, because the 1st-layer invalidator asserts the flag is
        // still unset each time it is called.
        registry
            .invalidate(&Context::default(), &[CacheIdent::TableId(1024)])
            .await
            .unwrap();
        assert!(counter.load(Ordering::Relaxed));
    }

    #[tokio::test]
    async fn test_registry_invalidate_all() {
        let invalidator: Invalidator<_, String, CacheIdent> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));
        let i32_cache = Arc::new(test_i32_cache("i32_cache", invalidator));
        let invalidator: Invalidator<_, String, CacheIdent> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));
        let string_cache = Arc::new(test_cache("string_cache", invalidator));

        i32_cache.get(1).await.unwrap();
        string_cache.get_by_ref("foo").await.unwrap();
        assert!(i32_cache.contains_key(&1));
        assert!(string_cache.contains_key("foo"));

        let registry = CacheRegistryBuilder::default()
            .add_cache(i32_cache.clone())
            .add_cache(string_cache.clone())
            .build();

        registry.invalidate_all().unwrap();

        assert!(!i32_cache.contains_key(&1));
        assert!(!string_cache.contains_key("foo"));
    }

    /// Invalidator that checks, on `invalidate_all`, that it is the `expected_order`-th
    /// invalidator to run against the shared `order` counter.
    struct LayerOrderInvalidator {
        expected_order: i32,
        order: Arc<AtomicI32>,
    }

    #[async_trait::async_trait]
    impl CacheInvalidator for LayerOrderInvalidator {
        async fn invalidate(&self, _ctx: &Context, _caches: &[CacheIdent]) -> Result<()> {
            Ok(())
        }

        fn invalidate_all(&self) -> Result<()> {
            let previous = self.order.fetch_add(1, Ordering::Relaxed);
            assert_eq!(self.expected_order, previous);
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_layered_registry_invalidate_all() {
        let order = Arc::new(AtomicI32::new(0));
        let invalidator: Invalidator<_, String, CacheIdent> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));
        let first_layer_cache = Arc::new(test_cache("first_layer_cache", invalidator));
        let first_layer_order = Arc::new(LayerOrderInvalidator {
            expected_order: 0,
            order: order.clone(),
        });
        let first_layer = CacheRegistryBuilder::default()
            .add_cache(first_layer_order)
            .add_cache(first_layer_cache.clone())
            .build();

        let invalidator: Invalidator<_, String, CacheIdent> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));
        let second_layer_cache = Arc::new(test_i32_cache("second_layer_cache", invalidator));
        let second_layer_order = Arc::new(LayerOrderInvalidator {
            expected_order: 1,
            order: order.clone(),
        });
        let second_layer = CacheRegistryBuilder::default()
            .add_cache(second_layer_order)
            .add_cache(second_layer_cache.clone())
            .build();

        first_layer_cache.get_by_ref("foo").await.unwrap();
        second_layer_cache.get(1).await.unwrap();
        assert!(first_layer_cache.contains_key("foo"));
        assert!(second_layer_cache.contains_key(&1));

        let registry = LayeredCacheRegistryBuilder::default()
            .add_cache_registry(first_layer)
            .add_cache_registry(second_layer)
            .build();

        registry.invalidate_all().unwrap();

        assert_eq!(2, order.load(Ordering::Relaxed));
        assert!(!first_layer_cache.contains_key("foo"));
        assert!(!second_layer_cache.contains_key(&1));
    }
}