1use std::borrow::Borrow;
16use std::hash::Hash;
17use std::sync::Arc;
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::time::Duration;
20
21use backon::{BackoffBuilder, ExponentialBuilder};
22use futures::future::BoxFuture;
23use moka::future::Cache;
24use snafu::{OptionExt, ResultExt};
25use tokio::time::sleep;
26
27use crate::cache_invalidator::{CacheInvalidator, Context};
28use crate::error::{self, Error, Result};
29use crate::instruction::CacheIdent;
30use crate::metrics;
31
/// Boxed predicate over a cache token.
///
/// The closure receives a borrowed token and returns `true` or `false`; it must
/// be `Send + Sync` so it can be shared across threads.
pub type TokenFilter<CacheToken> =
    Box<dyn for<'token> Fn(&'token CacheToken) -> bool + Send + Sync>;
34
35pub type Invalidator<K, V, CacheToken> = Box<
37 dyn for<'a> Fn(&'a Cache<K, V>, &'a [&CacheToken]) -> BoxFuture<'a, Result<()>> + Send + Sync,
38>;
39
40pub type Initializer<K, V> = Arc<dyn Fn(&'_ K) -> BoxFuture<'_, Result<Option<V>>> + Send + Sync>;
42
/// Selects how a [`CacheContainer`] loads a value on a cache miss.
#[derive(Clone, Copy, Debug)]
pub enum InitStrategy {
    /// Run the initializer once and cache whatever it returns; the container's
    /// version counter is not consulted.
    Unchecked,
    /// Re-run the initializer (with a bounded backoff) whenever the container's
    /// version counter changed while the value was being loaded.
    VersionChecked,
}
58
59pub struct CacheContainer<K, V, CacheToken> {
63 name: String,
64 cache: Cache<K, V>,
65 invalidator: Invalidator<K, V, CacheToken>,
66 initializer: Initializer<K, V>,
67 token_filter: fn(&CacheToken) -> bool,
68 version: Arc<AtomicUsize>,
69 init_strategy: InitStrategy,
70}
71
72fn latest_get_backoff() -> impl Iterator<Item = Duration> {
73 ExponentialBuilder::default()
74 .with_min_delay(Duration::from_millis(10))
75 .with_max_delay(Duration::from_millis(100))
76 .with_max_times(3)
77 .build()
78}
79
80impl<K, V, CacheToken> CacheContainer<K, V, CacheToken>
81where
82 K: Send + Sync,
83 V: Send + Sync,
84 CacheToken: Send + Sync,
85{
86 pub fn new(
91 name: String,
92 cache: Cache<K, V>,
93 invalidator: Invalidator<K, V, CacheToken>,
94 initializer: Initializer<K, V>,
95 token_filter: fn(&CacheToken) -> bool,
96 ) -> Self {
97 Self::with_strategy(
98 name,
99 cache,
100 invalidator,
101 initializer,
102 token_filter,
103 InitStrategy::Unchecked,
104 )
105 }
106
107 pub fn with_strategy(
111 name: String,
112 cache: Cache<K, V>,
113 invalidator: Invalidator<K, V, CacheToken>,
114 initializer: Initializer<K, V>,
115 token_filter: fn(&CacheToken) -> bool,
116 init_strategy: InitStrategy,
117 ) -> Self {
118 Self {
119 name,
120 cache,
121 invalidator,
122 initializer,
123 token_filter,
124 version: Arc::new(AtomicUsize::new(0)),
125 init_strategy,
126 }
127 }
128
129 pub fn name(&self) -> &str {
131 &self.name
132 }
133}
134
135impl<K, V, CacheToken> CacheContainer<K, V, CacheToken> {
136 fn inc_version(&self) {
137 self.version.fetch_add(1, Ordering::Relaxed);
138 }
139}
140
141async fn init<'a, K, V>(init: Initializer<K, V>, key: K, cache_name: &'a str) -> Result<V>
142where
143 K: Send + Sync + 'a,
144 V: Send + 'a,
145{
146 metrics::CACHE_CONTAINER_CACHE_MISS
147 .with_label_values(&[cache_name])
148 .inc();
149 let _timer = metrics::CACHE_CONTAINER_LOAD_CACHE
150 .with_label_values(&[cache_name])
151 .start_timer();
152 init(&key)
153 .await
154 .transpose()
155 .context(error::ValueNotExistSnafu)?
156}
157
158async fn init_with_retry<'a, K, V>(
159 init: Initializer<K, V>,
160 key: K,
161 mut backoff: impl Iterator<Item = Duration> + 'a,
162 version: Arc<AtomicUsize>,
163 cache_name: &'a str,
164) -> Result<V>
165where
166 K: Send + Sync + 'a,
167 V: Send + 'a,
168{
169 let mut attempts = 1usize;
170 loop {
171 let pre_version = version.load(Ordering::Relaxed);
172 metrics::CACHE_CONTAINER_CACHE_MISS
173 .with_label_values(&[cache_name])
174 .inc();
175 let _timer = metrics::CACHE_CONTAINER_LOAD_CACHE
176 .with_label_values(&[cache_name])
177 .start_timer();
178 let value = init(&key)
179 .await
180 .transpose()
181 .context(error::ValueNotExistSnafu)??;
182
183 if pre_version == version.load(Ordering::Relaxed) {
184 return Ok(value);
185 }
186
187 if let Some(duration) = backoff.next() {
188 sleep(duration).await;
189 attempts += 1;
190 } else {
191 return error::GetLatestCacheRetryExceededSnafu { attempts }.fail();
192 }
193 }
194}
195
196#[async_trait::async_trait]
197impl<K, V> CacheInvalidator for CacheContainer<K, V, CacheIdent>
198where
199 K: Hash + Eq + Send + Sync + 'static,
200 V: Clone + Send + Sync + 'static,
201{
202 async fn invalidate(&self, _ctx: &Context, caches: &[CacheIdent]) -> Result<()> {
203 let idents = caches
204 .iter()
205 .filter(|token| (self.token_filter)(token))
206 .collect::<Vec<_>>();
207 if !idents.is_empty() {
208 self.inc_version();
209 (self.invalidator)(&self.cache, &idents).await?;
210 }
211
212 Ok(())
213 }
214
215 fn invalidate_all(&self) -> Result<()> {
216 self.inc_version();
217 self.cache.invalidate_all();
218 Ok(())
219 }
220}
221
222impl<K, V, CacheToken> CacheContainer<K, V, CacheToken>
223where
224 K: Copy + Hash + Eq + Send + Sync + 'static,
225 V: Clone + Send + Sync + 'static,
226{
227 pub async fn get(&self, key: K) -> Result<Option<V>> {
233 metrics::CACHE_CONTAINER_CACHE_GET
234 .with_label_values(&[&self.name])
235 .inc();
236
237 let result = match self.init_strategy {
238 InitStrategy::Unchecked => {
239 self.cache
240 .try_get_with(key, init(self.initializer.clone(), key, &self.name))
241 .await
242 }
243 InitStrategy::VersionChecked => {
244 self.cache
245 .try_get_with(
246 key,
247 init_with_retry(
248 self.initializer.clone(),
249 key,
250 latest_get_backoff(),
251 self.version.clone(),
252 &self.name,
253 ),
254 )
255 .await
256 }
257 };
258
259 match result {
260 Ok(value) => Ok(Some(value)),
261 Err(err) => match err.as_ref() {
262 Error::ValueNotExist { .. } => Ok(None),
263 _ => Err(err).context(error::GetCacheSnafu),
264 },
265 }
266 }
267}
268
269impl<K, V, CacheToken> CacheContainer<K, V, CacheToken>
270where
271 K: Hash + Eq + Send + Sync + 'static,
272 V: Clone + Send + Sync + 'static,
273{
274 pub async fn invalidate(&self, caches: &[CacheToken]) -> Result<()> {
276 let idents = caches
277 .iter()
278 .filter(|token| (self.token_filter)(token))
279 .collect::<Vec<_>>();
280 if !idents.is_empty() {
281 self.inc_version();
282 (self.invalidator)(&self.cache, &idents).await?;
283 }
284
285 Ok(())
286 }
287
288 pub fn contains_key<Q>(&self, key: &Q) -> bool
290 where
291 K: Borrow<Q>,
292 Q: Hash + Eq + ?Sized,
293 {
294 self.cache.contains_key(key)
295 }
296
297 pub async fn get_by_ref<Q>(&self, key: &Q) -> Result<Option<V>>
303 where
304 K: Borrow<Q>,
305 Q: ToOwned<Owned = K> + Hash + Eq + ?Sized,
306 {
307 metrics::CACHE_CONTAINER_CACHE_GET
308 .with_label_values(&[&self.name])
309 .inc();
310 let result = match self.init_strategy {
311 InitStrategy::Unchecked => {
312 self.cache
313 .try_get_with_by_ref(
314 key,
315 init(self.initializer.clone(), key.to_owned(), &self.name),
316 )
317 .await
318 }
319 InitStrategy::VersionChecked => {
320 self.cache
321 .try_get_with_by_ref(
322 key,
323 init_with_retry(
324 self.initializer.clone(),
325 key.to_owned(),
326 latest_get_backoff(),
327 self.version.clone(),
328 &self.name,
329 ),
330 )
331 .await
332 }
333 };
334
335 match result {
336 Ok(value) => Ok(Some(value)),
337 Err(err) => match err.as_ref() {
338 Error::ValueNotExist { .. } => Ok(None),
339 _ => Err(err).context(error::GetCacheSnafu),
340 },
341 }
342 }
343}
344
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicI32, Ordering};

    use moka::future::{Cache, CacheBuilder};

    use super::*;

    #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
    struct NameKey<'a> {
        name: &'a str,
    }

    /// Token filter that accepts every token.
    fn always_true_filter(_: &String) -> bool {
        true
    }

    /// Initializer that bumps `counter` on every call and always yields `"hi"`.
    fn counting_initializer(counter: &Arc<AtomicI32>) -> Initializer<String, String> {
        let calls = Arc::clone(counter);
        Arc::new(move |_| {
            calls.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("hi".to_string())) })
        })
    }

    /// Invalidator that leaves the cache untouched.
    fn noop_invalidator() -> Invalidator<String, String, String> {
        Box::new(|_, _| Box::pin(async { Ok(()) }))
    }

    /// Invalidator that removes the entry for each token, using the token as key.
    fn key_invalidator() -> Invalidator<String, String, String> {
        Box::new(|cache, keys| {
            Box::pin(async move {
                for key in keys {
                    cache.invalidate(*key).await;
                }
                Ok(())
            })
        })
    }

    #[tokio::test]
    async fn test_get() {
        let cache: Cache<NameKey, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        let init: Initializer<NameKey, String> = Arc::new(move |_| {
            moved_counter.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("hi".to_string())) })
        });
        let invalidator: Invalidator<NameKey, String, String> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
        );

        // The first lookup loads the value; the second one is served from cache.
        let value = adv_cache.get(NameKey { name: "key" }).await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 1);
        let value = adv_cache.get(NameKey { name: "key" }).await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_get_by_ref() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            noop_invalidator(),
            counting_initializer(&counter),
            always_true_filter,
        );

        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 1);

        // A different key triggers a second load.
        let value = adv_cache.get_by_ref("bar").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }

    #[tokio::test]
    async fn test_get_value_not_exists() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let init: Initializer<String, String> =
            Arc::new(move |_| Box::pin(async { error::ValueNotExistSnafu {}.fail() }));

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            noop_invalidator(),
            init,
            always_true_filter,
        );

        let value = adv_cache.get_by_ref("foo").await.unwrap();
        assert!(value.is_none());
    }

    #[tokio::test]
    async fn test_invalidate() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            key_invalidator(),
            counting_initializer(&counter),
            always_true_filter,
        );

        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 1);

        adv_cache
            .invalidate(&["foo".to_string(), "bar".to_string()])
            .await
            .unwrap();

        // After invalidation the value has to be loaded again.
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_get_by_ref_returns_fresh_value_after_invalidate() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        // Slow initializer that returns "v1", "v2", ... on successive calls.
        let init: Initializer<String, String> = Arc::new(move |_| {
            let counter = moved_counter.clone();
            Box::pin(async move {
                let n = counter.fetch_add(1, Ordering::Relaxed) + 1;
                sleep(Duration::from_millis(100)).await;
                Ok(Some(format!("v{n}")))
            })
        });

        let adv_cache = Arc::new(CacheContainer::with_strategy(
            "test".to_string(),
            cache,
            key_invalidator(),
            init,
            always_true_filter,
            InitStrategy::VersionChecked,
        ));

        let moved_cache = adv_cache.clone();
        let get_task = tokio::spawn(async move { moved_cache.get_by_ref("foo").await });

        // Invalidate while the first load is still in flight.
        sleep(Duration::from_millis(50)).await;
        adv_cache.invalidate(&["foo".to_string()]).await.unwrap();

        // The first load is discarded and the second one is returned.
        let value = get_task.await.unwrap().unwrap().unwrap();
        assert_eq!(value, "v2");
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }
}