1use std::sync::Arc;
16
17use anymap2::SendSyncAnyMap;
18use futures::future::join_all;
19
20use crate::cache_invalidator::{CacheInvalidator, Context};
21use crate::error::Result;
22use crate::instruction::CacheIdent;
23
/// Shared, reference-counted handle to a [`CacheRegistry`].
pub type CacheRegistryRef = Arc<CacheRegistry>;
/// Shared, reference-counted handle to a [`LayeredCacheRegistry`].
pub type LayeredCacheRegistryRef = Arc<LayeredCacheRegistry>;
26
#[derive(Default)]
/// Builder for a [`LayeredCacheRegistry`].
///
/// Layers are appended in call order via `add_cache_registry`; `build` returns
/// the assembled registry unchanged.
pub struct LayeredCacheRegistryBuilder {
    /// The layered registry being assembled.
    registry: LayeredCacheRegistry,
}
32
33impl LayeredCacheRegistryBuilder {
34 pub fn add_cache_registry(mut self, registry: CacheRegistry) -> Self {
39 self.registry.layers.push(registry);
40
41 self
42 }
43
44 pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
46 self.registry.get()
47 }
48
49 pub fn build(self) -> LayeredCacheRegistry {
51 self.registry
52 }
53}
54
#[derive(Default)]
/// An ordered stack of [`CacheRegistry`] layers.
///
/// `get` returns the value from the first layer that holds the requested type.
/// Invalidation visits layers in insertion order and awaits each layer before
/// moving on to the next.
pub struct LayeredCacheRegistry {
    /// Layers in insertion order: first added is first searched and first invalidated.
    layers: Vec<CacheRegistry>,
}
60
61#[async_trait::async_trait]
62impl CacheInvalidator for LayeredCacheRegistry {
63 async fn invalidate(&self, ctx: &Context, caches: &[CacheIdent]) -> Result<()> {
64 let mut results = Vec::with_capacity(self.layers.len());
65 for registry in &self.layers {
66 results.push(registry.invalidate(ctx, caches).await);
67 }
68 results.into_iter().collect::<Result<Vec<_>>>().map(|_| ())
69 }
70
71 fn invalidate_all(&self) -> Result<()> {
72 for registry in &self.layers {
73 registry.invalidate_all()?;
74 }
75 Ok(())
76 }
77}
78
79impl LayeredCacheRegistry {
80 pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
82 for registry in &self.layers {
83 if let Some(cache) = registry.get::<T>() {
84 return Some(cache);
85 }
86 }
87
88 None
89 }
90}
91
#[derive(Default)]
/// Builder for a [`CacheRegistry`].
///
/// Caches are registered with `add_cache`; `build` returns the assembled
/// registry unchanged.
pub struct CacheRegistryBuilder {
    /// The registry being assembled.
    registry: CacheRegistry,
}
99
100impl CacheRegistryBuilder {
101 pub fn add_cache<T: CacheInvalidator + 'static>(mut self, cache: Arc<T>) -> Self {
103 self.registry.register(cache);
104 self
105 }
106
107 pub fn build(self) -> CacheRegistry {
109 self.registry
110 }
111}
112
#[derive(Default)]
/// A collection of caches that are looked up by type and invalidated together.
///
/// Each registered cache is stored twice: as a type-erased invalidator, used for
/// invalidation, and in a type map keyed by its concrete `Arc<T>` type, used by `get`.
pub struct CacheRegistry {
    /// Every registered cache as an invalidator, in registration order.
    indexes: Vec<Arc<dyn CacheInvalidator>>,
    /// Registered caches keyed by their concrete `Arc<T>` type.
    registry: SendSyncAnyMap,
}
120
121#[async_trait::async_trait]
122impl CacheInvalidator for CacheRegistry {
123 async fn invalidate(&self, ctx: &Context, caches: &[CacheIdent]) -> Result<()> {
124 let tasks = self
125 .indexes
126 .iter()
127 .map(|invalidator| invalidator.invalidate(ctx, caches));
128 join_all(tasks)
129 .await
130 .into_iter()
131 .collect::<Result<Vec<_>>>()?;
132 Ok(())
133 }
134
135 fn invalidate_all(&self) -> Result<()> {
136 for invalidator in &self.indexes {
137 invalidator.invalidate_all()?;
138 }
139 Ok(())
140 }
141}
142
143impl CacheRegistry {
144 fn register<T: CacheInvalidator + 'static>(&mut self, cache: Arc<T>) -> bool {
147 self.indexes.push(cache.clone());
148 self.registry.insert(cache).is_some()
149 }
150
151 pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
153 self.registry.get().cloned()
154 }
155}
156
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicI32, Ordering};

    use moka::future::{Cache, CacheBuilder};

    use crate::cache::registry::CacheRegistryBuilder;
    use crate::cache::*;
    use crate::cache_invalidator::{CacheInvalidator, Context};
    use crate::error::Result;
    use crate::instruction::CacheIdent;

    /// Cache filter that accepts every [`CacheIdent`].
    fn always_true_filter(_: &CacheIdent) -> bool {
        true
    }

    /// Builds a `String -> String` cache container named `name`, backed by
    /// `CacheBuilder::new(128)`.
    ///
    /// Its initializer resolves every key to `Some("hi")`. The load counter it
    /// increments is never read.
    fn test_cache(
        name: &str,
        invalidator: Invalidator<String, String, CacheIdent>,
    ) -> CacheContainer<String, String, CacheIdent> {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        let init: Initializer<String, String> = Arc::new(move |_| {
            moved_counter.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("hi".to_string())) })
        });

        CacheContainer::new(
            name.to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
        )
    }

    /// Builds an `i32 -> String` cache container named `name`, backed by
    /// `CacheBuilder::new(128)`.
    ///
    /// Its initializer resolves every key to `Some("foo")`. The load counter it
    /// increments is never read.
    fn test_i32_cache(
        name: &str,
        invalidator: Invalidator<i32, String, CacheIdent>,
    ) -> CacheContainer<i32, String, CacheIdent> {
        let cache: Cache<i32, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        let init: Initializer<i32, String> = Arc::new(move |_| {
            moved_counter.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("foo".to_string())) })
        });

        CacheContainer::new(
            name.to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
        )
    }

    /// Two caches of different concrete types can be registered in one
    /// registry, and each can be retrieved by its `Arc<CacheContainer<..>>` type.
    #[tokio::test]
    async fn test_register() {
        let builder = CacheRegistryBuilder::default();
        let invalidator: Invalidator<_, String, CacheIdent> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));
        let i32_cache = Arc::new(test_i32_cache("i32_cache", invalidator));
        let invalidator: Invalidator<_, String, CacheIdent> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));
        let cache = Arc::new(test_cache("string_cache", invalidator));
        let registry = builder.add_cache(i32_cache).add_cache(cache).build();

        let cache = registry
            .get::<Arc<CacheContainer<i32, String, CacheIdent>>>()
            .unwrap();
        assert_eq!(cache.name(), "i32_cache");

        let cache = registry
            .get::<Arc<CacheContainer<String, String, CacheIdent>>>()
            .unwrap();
        assert_eq!(cache.name(), "string_cache");
    }

    /// A layered registry finds caches registered in any of its layers.
    ///
    /// NOTE(review): the two invalidators below encode an ordering check. The
    /// first layer sets `counter` and the second asserts it is already set.
    /// This test never calls `invalidate`, though, so those assertions are not
    /// exercised. Only the `get` lookups are verified. Consider calling
    /// `registry.invalidate(..)` with a `CacheIdent` the containers accept. It is
    /// presumably forwarded to these invalidators via `CacheContainer`; TODO confirm.
    #[tokio::test]
    async fn test_layered_registry() {
        let builder = LayeredCacheRegistryBuilder::default();
        let counter = Arc::new(AtomicBool::new(false));
        // First layer: its invalidator expects to run first and flips `counter`.
        let moved_counter = counter.clone();
        let invalidator: Invalidator<String, String, CacheIdent> = Box::new(move |_, _| {
            let counter = moved_counter.clone();
            Box::pin(async move {
                assert!(!counter.load(Ordering::Relaxed));
                counter.store(true, Ordering::Relaxed);
                Ok(())
            })
        });
        let cache = Arc::new(test_cache("string_cache", invalidator));
        let builder =
            builder.add_cache_registry(CacheRegistryBuilder::default().add_cache(cache).build());
        let moved_counter = counter.clone();
        // Second layer: its invalidator expects the first layer to have already run.
        let invalidator: Invalidator<i32, String, CacheIdent> = Box::new(move |_, _| {
            let counter = moved_counter.clone();
            Box::pin(async move {
                assert!(counter.load(Ordering::Relaxed));
                Ok(())
            })
        });
        let i32_cache = Arc::new(test_i32_cache("i32_cache", invalidator));
        let builder = builder
            .add_cache_registry(CacheRegistryBuilder::default().add_cache(i32_cache).build());

        let registry = builder.build();
        let cache = registry
            .get::<Arc<CacheContainer<i32, String, CacheIdent>>>()
            .unwrap();
        assert_eq!(cache.name(), "i32_cache");
        let cache = registry
            .get::<Arc<CacheContainer<String, String, CacheIdent>>>()
            .unwrap();
        assert_eq!(cache.name(), "string_cache");
    }

    /// Populates two caches through their initializers, then checks that
    /// `CacheRegistry::invalidate_all` removes the loaded entries from both.
    #[tokio::test]
    async fn test_registry_invalidate_all() {
        let invalidator: Invalidator<_, String, CacheIdent> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));
        let i32_cache = Arc::new(test_i32_cache("i32_cache", invalidator));
        let invalidator: Invalidator<_, String, CacheIdent> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));
        let string_cache = Arc::new(test_cache("string_cache", invalidator));

        i32_cache.get(1).await.unwrap();
        string_cache.get_by_ref("foo").await.unwrap();
        assert!(i32_cache.contains_key(&1));
        assert!(string_cache.contains_key("foo"));

        let registry = CacheRegistryBuilder::default()
            .add_cache(i32_cache.clone())
            .add_cache(string_cache.clone())
            .build();

        registry.invalidate_all().unwrap();

        assert!(!i32_cache.contains_key(&1));
        assert!(!string_cache.contains_key("foo"));
    }

    /// Test invalidator that records the order of `invalidate_all` calls.
    ///
    /// Each call increments the shared `order` counter and asserts that the value
    /// before the increment equals `expected_order`.
    struct LayerOrderInvalidator {
        // Position (0-based) at which this invalidator expects to be called.
        expected_order: i32,
        // Counter shared by all `LayerOrderInvalidator`s in a test.
        order: Arc<AtomicI32>,
    }

    #[async_trait::async_trait]
    impl CacheInvalidator for LayerOrderInvalidator {
        async fn invalidate(&self, _ctx: &Context, _caches: &[CacheIdent]) -> Result<()> {
            Ok(())
        }

        fn invalidate_all(&self) -> Result<()> {
            let previous = self.order.fetch_add(1, Ordering::Relaxed);
            assert_eq!(self.expected_order, previous);
            Ok(())
        }
    }

    /// `LayeredCacheRegistry::invalidate_all` visits layers in insertion order
    /// (checked via `LayerOrderInvalidator`) and clears caches in every layer.
    #[tokio::test]
    async fn test_layered_registry_invalidate_all() {
        let order = Arc::new(AtomicI32::new(0));
        let invalidator: Invalidator<_, String, CacheIdent> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));
        let first_layer_cache = Arc::new(test_cache("first_layer_cache", invalidator));
        let first_layer_order = Arc::new(LayerOrderInvalidator {
            expected_order: 0,
            order: order.clone(),
        });
        let first_layer = CacheRegistryBuilder::default()
            .add_cache(first_layer_order)
            .add_cache(first_layer_cache.clone())
            .build();

        let invalidator: Invalidator<_, String, CacheIdent> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));
        let second_layer_cache = Arc::new(test_i32_cache("second_layer_cache", invalidator));
        let second_layer_order = Arc::new(LayerOrderInvalidator {
            expected_order: 1,
            order: order.clone(),
        });
        let second_layer = CacheRegistryBuilder::default()
            .add_cache(second_layer_order)
            .add_cache(second_layer_cache.clone())
            .build();

        first_layer_cache.get_by_ref("foo").await.unwrap();
        second_layer_cache.get(1).await.unwrap();
        assert!(first_layer_cache.contains_key("foo"));
        assert!(second_layer_cache.contains_key(&1));

        let registry = LayeredCacheRegistryBuilder::default()
            .add_cache_registry(first_layer)
            .add_cache_registry(second_layer)
            .build();

        registry.invalidate_all().unwrap();

        // Both order checkers ran, and each layer's cache was emptied.
        assert_eq!(2, order.load(Ordering::Relaxed));
        assert!(!first_layer_cache.contains_key("foo"));
        assert!(!second_layer_cache.contains_key(&1));
    }
}