Skip to main content

common_meta/cache/
container.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Borrow;
16use std::hash::Hash;
17use std::sync::Arc;
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::time::Duration;
20
21use backon::{BackoffBuilder, ExponentialBuilder};
22use futures::future::BoxFuture;
23use moka::future::Cache;
24use snafu::{OptionExt, ResultExt};
25use tokio::time::sleep;
26
27use crate::cache_invalidator::{CacheInvalidator, Context};
28use crate::error::{self, Error, Result};
29use crate::instruction::CacheIdent;
30use crate::metrics;
31
32/// Filters out unused [CacheToken]s
33pub type TokenFilter<CacheToken> = Box<dyn Fn(&CacheToken) -> bool + Send + Sync>;
34
35/// Invalidates cached values by [CacheToken]s.
36pub type Invalidator<K, V, CacheToken> = Box<
37    dyn for<'a> Fn(&'a Cache<K, V>, &'a [&CacheToken]) -> BoxFuture<'a, Result<()>> + Send + Sync,
38>;
39
40/// Initializes value (i.e., fetches from remote).
41pub type Initializer<K, V> = Arc<dyn Fn(&'_ K) -> BoxFuture<'_, Result<Option<V>>> + Send + Sync>;
42
43#[derive(Debug, Clone, Copy)]
44/// Initialization strategy for cache-miss loading.
45///
46/// This strategy is selected when building [CacheContainer] and remains immutable
47/// for the lifetime of the container instance.
48pub enum InitStrategy {
49    /// Fast path: load once without version conflict retry.
50    ///
51    /// Under concurrent invalidation, callers may observe stale/dirty value.
52    Unchecked,
53    /// Strict path: retry load when version changes during initialization.
54    ///
55    /// This avoids returning dirty value under invalidate/load races.
56    VersionChecked,
57}
58
59/// [CacheContainer] provides ability to:
60/// - Cache value loaded by [Initializer].
61/// - Invalidate caches by [Invalidator].
62pub struct CacheContainer<K, V, CacheToken> {
63    name: String,
64    cache: Cache<K, V>,
65    invalidator: Invalidator<K, V, CacheToken>,
66    initializer: Initializer<K, V>,
67    token_filter: fn(&CacheToken) -> bool,
68    version: Arc<AtomicUsize>,
69    init_strategy: InitStrategy,
70}
71
72fn latest_get_backoff() -> impl Iterator<Item = Duration> {
73    ExponentialBuilder::default()
74        .with_min_delay(Duration::from_millis(10))
75        .with_max_delay(Duration::from_millis(100))
76        .with_max_times(3)
77        .build()
78}
79
80impl<K, V, CacheToken> CacheContainer<K, V, CacheToken>
81where
82    K: Send + Sync,
83    V: Send + Sync,
84    CacheToken: Send + Sync,
85{
86    /// Constructs an [CacheContainer] with [InitStrategy::Unchecked].
87    ///
88    /// This keeps the historical behavior and can return stale/dirty value under
89    /// concurrent invalidation.
90    pub fn new(
91        name: String,
92        cache: Cache<K, V>,
93        invalidator: Invalidator<K, V, CacheToken>,
94        initializer: Initializer<K, V>,
95        token_filter: fn(&CacheToken) -> bool,
96    ) -> Self {
97        Self::with_strategy(
98            name,
99            cache,
100            invalidator,
101            initializer,
102            token_filter,
103            InitStrategy::Unchecked,
104        )
105    }
106
107    /// Constructs an [CacheContainer] with explicit [InitStrategy].
108    ///
109    /// The strategy is fixed at construction time and cannot be changed later.
110    pub fn with_strategy(
111        name: String,
112        cache: Cache<K, V>,
113        invalidator: Invalidator<K, V, CacheToken>,
114        initializer: Initializer<K, V>,
115        token_filter: fn(&CacheToken) -> bool,
116        init_strategy: InitStrategy,
117    ) -> Self {
118        Self {
119            name,
120            cache,
121            invalidator,
122            initializer,
123            token_filter,
124            version: Arc::new(AtomicUsize::new(0)),
125            init_strategy,
126        }
127    }
128
129    /// Returns the `name`.
130    pub fn name(&self) -> &str {
131        &self.name
132    }
133}
134
135impl<K, V, CacheToken> CacheContainer<K, V, CacheToken> {
136    fn inc_version(&self) {
137        self.version.fetch_add(1, Ordering::Relaxed);
138    }
139}
140
141async fn init<'a, K, V>(init: Initializer<K, V>, key: K, cache_name: &'a str) -> Result<V>
142where
143    K: Send + Sync + 'a,
144    V: Send + 'a,
145{
146    metrics::CACHE_CONTAINER_CACHE_MISS
147        .with_label_values(&[cache_name])
148        .inc();
149    let _timer = metrics::CACHE_CONTAINER_LOAD_CACHE
150        .with_label_values(&[cache_name])
151        .start_timer();
152    init(&key)
153        .await
154        .transpose()
155        .context(error::ValueNotExistSnafu)?
156}
157
158async fn init_with_retry<'a, K, V>(
159    init: Initializer<K, V>,
160    key: K,
161    mut backoff: impl Iterator<Item = Duration> + 'a,
162    version: Arc<AtomicUsize>,
163    cache_name: &'a str,
164) -> Result<V>
165where
166    K: Send + Sync + 'a,
167    V: Send + 'a,
168{
169    let mut attempts = 1usize;
170    loop {
171        let pre_version = version.load(Ordering::Relaxed);
172        metrics::CACHE_CONTAINER_CACHE_MISS
173            .with_label_values(&[cache_name])
174            .inc();
175        let _timer = metrics::CACHE_CONTAINER_LOAD_CACHE
176            .with_label_values(&[cache_name])
177            .start_timer();
178        let value = init(&key)
179            .await
180            .transpose()
181            .context(error::ValueNotExistSnafu)??;
182
183        if pre_version == version.load(Ordering::Relaxed) {
184            return Ok(value);
185        }
186
187        if let Some(duration) = backoff.next() {
188            sleep(duration).await;
189            attempts += 1;
190        } else {
191            return error::GetLatestCacheRetryExceededSnafu { attempts }.fail();
192        }
193    }
194}
195
196#[async_trait::async_trait]
197impl<K, V> CacheInvalidator for CacheContainer<K, V, CacheIdent>
198where
199    K: Hash + Eq + Send + Sync + 'static,
200    V: Clone + Send + Sync + 'static,
201{
202    async fn invalidate(&self, _ctx: &Context, caches: &[CacheIdent]) -> Result<()> {
203        let idents = caches
204            .iter()
205            .filter(|token| (self.token_filter)(token))
206            .collect::<Vec<_>>();
207        if !idents.is_empty() {
208            self.inc_version();
209            (self.invalidator)(&self.cache, &idents).await?;
210        }
211
212        Ok(())
213    }
214
215    fn invalidate_all(&self) -> Result<()> {
216        self.inc_version();
217        self.cache.invalidate_all();
218        Ok(())
219    }
220}
221
222impl<K, V, CacheToken> CacheContainer<K, V, CacheToken>
223where
224    K: Copy + Hash + Eq + Send + Sync + 'static,
225    V: Clone + Send + Sync + 'static,
226{
227    /// Returns a value from cache for copyable keys.
228    ///
229    /// With [InitStrategy::Unchecked], this method prioritizes latency and may
230    /// return stale/dirty value. With [InitStrategy::VersionChecked], this method
231    /// retries initialization on version change and avoids dirty returns.
232    pub async fn get(&self, key: K) -> Result<Option<V>> {
233        metrics::CACHE_CONTAINER_CACHE_GET
234            .with_label_values(&[&self.name])
235            .inc();
236
237        let result = match self.init_strategy {
238            InitStrategy::Unchecked => {
239                self.cache
240                    .try_get_with(key, init(self.initializer.clone(), key, &self.name))
241                    .await
242            }
243            InitStrategy::VersionChecked => {
244                self.cache
245                    .try_get_with(
246                        key,
247                        init_with_retry(
248                            self.initializer.clone(),
249                            key,
250                            latest_get_backoff(),
251                            self.version.clone(),
252                            &self.name,
253                        ),
254                    )
255                    .await
256            }
257        };
258
259        match result {
260            Ok(value) => Ok(Some(value)),
261            Err(err) => match err.as_ref() {
262                Error::ValueNotExist { .. } => Ok(None),
263                _ => Err(err).context(error::GetCacheSnafu),
264            },
265        }
266    }
267}
268
269impl<K, V, CacheToken> CacheContainer<K, V, CacheToken>
270where
271    K: Hash + Eq + Send + Sync + 'static,
272    V: Clone + Send + Sync + 'static,
273{
274    /// Invalidates cache by [CacheToken].
275    pub async fn invalidate(&self, caches: &[CacheToken]) -> Result<()> {
276        let idents = caches
277            .iter()
278            .filter(|token| (self.token_filter)(token))
279            .collect::<Vec<_>>();
280        if !idents.is_empty() {
281            self.inc_version();
282            (self.invalidator)(&self.cache, &idents).await?;
283        }
284
285        Ok(())
286    }
287
288    /// Returns true if the cache contains a value for the key.
289    pub fn contains_key<Q>(&self, key: &Q) -> bool
290    where
291        K: Borrow<Q>,
292        Q: Hash + Eq + ?Sized,
293    {
294        self.cache.contains_key(key)
295    }
296
297    /// Returns a value from cache by key reference.
298    ///
299    /// With [InitStrategy::Unchecked], this method prioritizes latency and may
300    /// return stale/dirty value. With [InitStrategy::VersionChecked], this method
301    /// retries initialization on version change and avoids dirty returns.
302    pub async fn get_by_ref<Q>(&self, key: &Q) -> Result<Option<V>>
303    where
304        K: Borrow<Q>,
305        Q: ToOwned<Owned = K> + Hash + Eq + ?Sized,
306    {
307        metrics::CACHE_CONTAINER_CACHE_GET
308            .with_label_values(&[&self.name])
309            .inc();
310        let result = match self.init_strategy {
311            InitStrategy::Unchecked => {
312                self.cache
313                    .try_get_with_by_ref(
314                        key,
315                        init(self.initializer.clone(), key.to_owned(), &self.name),
316                    )
317                    .await
318            }
319            InitStrategy::VersionChecked => {
320                self.cache
321                    .try_get_with_by_ref(
322                        key,
323                        init_with_retry(
324                            self.initializer.clone(),
325                            key.to_owned(),
326                            latest_get_backoff(),
327                            self.version.clone(),
328                            &self.name,
329                        ),
330                    )
331                    .await
332            }
333        };
334
335        match result {
336            Ok(value) => Ok(Some(value)),
337            Err(err) => match err.as_ref() {
338                Error::ValueNotExist { .. } => Ok(None),
339                _ => Err(err).context(error::GetCacheSnafu),
340            },
341        }
342    }
343}
344
345#[cfg(test)]
346mod tests {
347    use std::sync::Arc;
348    use std::sync::atomic::{AtomicI32, Ordering};
349
350    use moka::future::{Cache, CacheBuilder};
351
352    use super::*;
353
354    #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
355    struct NameKey<'a> {
356        name: &'a str,
357    }
358
359    fn always_true_filter(_: &String) -> bool {
360        true
361    }
362
363    #[tokio::test]
364    async fn test_get() {
365        let cache: Cache<NameKey, String> = CacheBuilder::new(128).build();
366        let counter = Arc::new(AtomicI32::new(0));
367        let moved_counter = counter.clone();
368        let init: Initializer<NameKey, String> = Arc::new(move |_| {
369            moved_counter.fetch_add(1, Ordering::Relaxed);
370            Box::pin(async { Ok(Some("hi".to_string())) })
371        });
372        let invalidator: Invalidator<NameKey, String, String> =
373            Box::new(|_, _| Box::pin(async { Ok(()) }));
374
375        let adv_cache = CacheContainer::new(
376            "test".to_string(),
377            cache,
378            invalidator,
379            init,
380            always_true_filter,
381        );
382        let key = NameKey { name: "key" };
383        let value = adv_cache.get(key).await.unwrap().unwrap();
384        assert_eq!(value, "hi");
385        assert_eq!(counter.load(Ordering::Relaxed), 1);
386        let key = NameKey { name: "key" };
387        let value = adv_cache.get(key).await.unwrap().unwrap();
388        assert_eq!(value, "hi");
389        assert_eq!(counter.load(Ordering::Relaxed), 1);
390    }
391
392    #[tokio::test]
393    async fn test_get_by_ref() {
394        let cache: Cache<String, String> = CacheBuilder::new(128).build();
395        let counter = Arc::new(AtomicI32::new(0));
396        let moved_counter = counter.clone();
397        let init: Initializer<String, String> = Arc::new(move |_| {
398            moved_counter.fetch_add(1, Ordering::Relaxed);
399            Box::pin(async { Ok(Some("hi".to_string())) })
400        });
401        let invalidator: Invalidator<String, String, String> =
402            Box::new(|_, _| Box::pin(async { Ok(()) }));
403
404        let adv_cache = CacheContainer::new(
405            "test".to_string(),
406            cache,
407            invalidator,
408            init,
409            always_true_filter,
410        );
411        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
412        assert_eq!(value, "hi");
413        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
414        assert_eq!(value, "hi");
415        assert_eq!(counter.load(Ordering::Relaxed), 1);
416        let value = adv_cache.get_by_ref("bar").await.unwrap().unwrap();
417        assert_eq!(value, "hi");
418        assert_eq!(counter.load(Ordering::Relaxed), 2);
419    }
420
421    #[tokio::test]
422    async fn test_get_value_not_exits() {
423        let cache: Cache<String, String> = CacheBuilder::new(128).build();
424        let init: Initializer<String, String> =
425            Arc::new(move |_| Box::pin(async { error::ValueNotExistSnafu {}.fail() }));
426        let invalidator: Invalidator<String, String, String> =
427            Box::new(|_, _| Box::pin(async { Ok(()) }));
428
429        let adv_cache = CacheContainer::new(
430            "test".to_string(),
431            cache,
432            invalidator,
433            init,
434            always_true_filter,
435        );
436        let value = adv_cache.get_by_ref("foo").await.unwrap();
437        assert!(value.is_none());
438    }
439
440    #[tokio::test]
441    async fn test_invalidate() {
442        let cache: Cache<String, String> = CacheBuilder::new(128).build();
443        let counter = Arc::new(AtomicI32::new(0));
444        let moved_counter = counter.clone();
445        let init: Initializer<String, String> = Arc::new(move |_| {
446            moved_counter.fetch_add(1, Ordering::Relaxed);
447            Box::pin(async { Ok(Some("hi".to_string())) })
448        });
449        let invalidator: Invalidator<String, String, String> = Box::new(|cache, keys| {
450            Box::pin(async move {
451                for key in keys {
452                    cache.invalidate(*key).await;
453                }
454                Ok(())
455            })
456        });
457
458        let adv_cache = CacheContainer::new(
459            "test".to_string(),
460            cache,
461            invalidator,
462            init,
463            always_true_filter,
464        );
465        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
466        assert_eq!(value, "hi");
467        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
468        assert_eq!(value, "hi");
469        assert_eq!(counter.load(Ordering::Relaxed), 1);
470        adv_cache
471            .invalidate(&["foo".to_string(), "bar".to_string()])
472            .await
473            .unwrap();
474        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
475        assert_eq!(value, "hi");
476        assert_eq!(counter.load(Ordering::Relaxed), 2);
477    }
478
479    #[tokio::test(flavor = "multi_thread")]
480    async fn test_get_by_ref_returns_fresh_value_after_invalidate() {
481        let cache: Cache<String, String> = CacheBuilder::new(128).build();
482        let counter = Arc::new(AtomicI32::new(0));
483        let moved_counter = counter.clone();
484        let init: Initializer<String, String> = Arc::new(move |_| {
485            let counter = moved_counter.clone();
486            Box::pin(async move {
487                let n = counter.fetch_add(1, Ordering::Relaxed) + 1;
488                sleep(Duration::from_millis(100)).await;
489                Ok(Some(format!("v{n}")))
490            })
491        });
492        let invalidator: Invalidator<String, String, String> = Box::new(|cache, keys| {
493            Box::pin(async move {
494                for key in keys {
495                    cache.invalidate(*key).await;
496                }
497                Ok(())
498            })
499        });
500
501        let adv_cache = Arc::new(CacheContainer::with_strategy(
502            "test".to_string(),
503            cache,
504            invalidator,
505            init,
506            always_true_filter,
507            InitStrategy::VersionChecked,
508        ));
509
510        let moved_cache = adv_cache.clone();
511        let get_task = tokio::spawn(async move { moved_cache.get_by_ref("foo").await });
512
513        sleep(Duration::from_millis(50)).await;
514        adv_cache.invalidate(&["foo".to_string()]).await.unwrap();
515
516        let value = get_task.await.unwrap().unwrap().unwrap();
517        assert_eq!(value, "v2");
518        assert_eq!(counter.load(Ordering::Relaxed), 2);
519    }
520}