fix: prevent stale in-flight cache refill after invalidation in CacheContainer (#7825)

* fix: prevent stale cache refill after invalidate

Signed-off-by: WenyXu <wenymedia@gmail.com>

* chore: apply suggestions from CR

Signed-off-by: WenyXu <wenymedia@gmail.com>

* feat: introduce `get_latest`

Signed-off-by: WenyXu <wenymedia@gmail.com>

* chore: styling

Signed-off-by: WenyXu <wenymedia@gmail.com>

* fix: enforce construction-time cache init strategy

Make cache initialization behavior explicit via InitStrategy selected at construction and document dirty-vs-checked semantics. Keep latest-read call compatibility while partition manager uses strategy-driven get paths.

Signed-off-by: WenyXu <wenymedia@gmail.com>

* test: rename get_by_ref freshness test

Signed-off-by: WenyXu <wenymedia@gmail.com>

* feat: use `InitStrategy::VersionChecked` for table route cache

Signed-off-by: WenyXu <wenymedia@gmail.com>

* chore: apply suggestions

Signed-off-by: WenyXu <wenymedia@gmail.com>

* chore: apply suggestions from CR

Signed-off-by: WenyXu <wenymedia@gmail.com>

* chore: apply suggestions from CR

Signed-off-by: WenyXu <wenymedia@gmail.com>

---------

Signed-off-by: WenyXu <wenymedia@gmail.com>
This commit is contained in:
Weny Xu
2026-03-24 12:24:15 +08:00
committed by GitHub
parent 5231ee40c8
commit 9bd983ea40
12 changed files with 313 additions and 88 deletions

View File

@@ -65,11 +65,13 @@ fn init_factory(
fn invalidator<'a>(
cache: &'a Cache<TableName, TableRef>,
ident: &'a CacheIdent,
idents: &'a [&CacheIdent],
) -> BoxFuture<'a, MetaResult<()>> {
Box::pin(async move {
if let CacheIdent::TableName(table_name) = ident {
cache.invalidate(table_name).await
for ident in idents {
if let CacheIdent::TableName(table_name) = ident {
cache.invalidate(table_name).await
}
}
Ok(())
})

View File

@@ -8,7 +8,6 @@ license.workspace = true
testing = []
pg_kvbackend = [
"dep:tokio-postgres",
"dep:backon",
"dep:deadpool-postgres",
"dep:deadpool",
"dep:tokio-postgres-rustls",
@@ -16,7 +15,7 @@ pg_kvbackend = [
"dep:rustls-native-certs",
"dep:rustls",
]
mysql_kvbackend = ["dep:sqlx", "dep:backon"]
mysql_kvbackend = ["dep:sqlx"]
enterprise = ["prost-types"]
[lints]
@@ -28,7 +27,7 @@ api.workspace = true
async-recursion = "1.0"
async-stream.workspace = true
async-trait.workspace = true
backon = { workspace = true, optional = true }
backon.workspace = true
base64.workspace = true
bytes.workspace = true
chrono.workspace = true

View File

@@ -15,10 +15,14 @@
use std::borrow::Borrow;
use std::hash::Hash;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use futures::future::{BoxFuture, join_all};
use backon::{BackoffBuilder, ExponentialBuilder};
use futures::future::BoxFuture;
use moka::future::Cache;
use snafu::{OptionExt, ResultExt};
use tokio::time::sleep;
use crate::cache_invalidator::{CacheInvalidator, Context};
use crate::error::{self, Error, Result};
@@ -29,12 +33,29 @@ use crate::metrics;
pub type TokenFilter<CacheToken> = Box<dyn Fn(&CacheToken) -> bool + Send + Sync>;
/// Invalidates cached values by [CacheToken]s.
pub type Invalidator<K, V, CacheToken> =
Box<dyn for<'a> Fn(&'a Cache<K, V>, &'a CacheToken) -> BoxFuture<'a, Result<()>> + Send + Sync>;
pub type Invalidator<K, V, CacheToken> = Box<
dyn for<'a> Fn(&'a Cache<K, V>, &'a [&CacheToken]) -> BoxFuture<'a, Result<()>> + Send + Sync,
>;
/// Initializes value (i.e., fetches from remote).
pub type Initializer<K, V> = Arc<dyn Fn(&'_ K) -> BoxFuture<'_, Result<Option<V>>> + Send + Sync>;
/// Selects how a [CacheContainer] loads a value on a cache miss.
///
/// The strategy is chosen when the [CacheContainer] is built and stays fixed
/// for the lifetime of that container instance.
#[derive(Debug, Clone, Copy)]
pub enum InitStrategy {
    /// Fast path: run the initializer once, without any version-conflict retry.
    ///
    /// If an invalidation happens concurrently, callers may observe a
    /// stale/dirty value.
    Unchecked,
    /// Strict path: run the initializer again when the version changes while
    /// a load is in progress.
    ///
    /// This avoids returning a dirty value under invalidate/load races.
    VersionChecked,
}
/// [CacheContainer] provides ability to:
/// - Cache value loaded by [Initializer].
/// - Invalidate caches by [Invalidator].
@@ -44,6 +65,16 @@ pub struct CacheContainer<K, V, CacheToken> {
invalidator: Invalidator<K, V, CacheToken>,
initializer: Initializer<K, V>,
token_filter: fn(&CacheToken) -> bool,
version: Arc<AtomicUsize>,
init_strategy: InitStrategy,
}
/// Builds the delay schedule handed to `init_with_retry` by the
/// [InitStrategy::VersionChecked] read paths.
///
/// The schedule comes from an [ExponentialBuilder] configured with a 10ms
/// minimum delay, a 100ms maximum delay and `with_max_times(3)`.
fn latest_get_backoff() -> impl Iterator<Item = Duration> {
    let min_delay = Duration::from_millis(10);
    let max_delay = Duration::from_millis(100);
    let builder = ExponentialBuilder::default()
        .with_min_delay(min_delay)
        .with_max_delay(max_delay)
        .with_max_times(3);
    builder.build()
}
impl<K, V, CacheToken> CacheContainer<K, V, CacheToken>
@@ -52,13 +83,37 @@ where
V: Send + Sync,
CacheToken: Send + Sync,
{
/// Constructs an [CacheContainer].
/// Builds a [CacheContainer] that uses [InitStrategy::Unchecked].
///
/// This preserves the historical behavior, which means a stale/dirty value
/// can be returned under concurrent invalidation.
pub fn new(
    name: String,
    cache: Cache<K, V>,
    invalidator: Invalidator<K, V, CacheToken>,
    initializer: Initializer<K, V>,
    token_filter: fn(&CacheToken) -> bool,
) -> Self {
    // Delegate to the strategy-aware constructor with the unchecked strategy.
    let init_strategy = InitStrategy::Unchecked;
    Self::with_strategy(name, cache, invalidator, initializer, token_filter, init_strategy)
}
/// Constructs an [CacheContainer] with explicit [InitStrategy].
///
/// The strategy is fixed at construction time and cannot be changed later.
pub fn with_strategy(
name: String,
cache: Cache<K, V>,
invalidator: Invalidator<K, V, CacheToken>,
initializer: Initializer<K, V>,
token_filter: fn(&CacheToken) -> bool,
init_strategy: InitStrategy,
) -> Self {
Self {
name,
@@ -66,6 +121,8 @@ where
invalidator,
initializer,
token_filter,
version: Arc::new(AtomicUsize::new(0)),
init_strategy,
}
}
@@ -75,6 +132,67 @@ where
}
}
impl<K, V, CacheToken> CacheContainer<K, V, CacheToken> {
    /// Bumps the container's version counter by one.
    ///
    /// The `invalidate` paths call this right before running the invalidator;
    /// version-checked loads compare the counter before and after loading to
    /// detect that an invalidation happened in between.
    fn inc_version(&self) {
        // The previous counter value is not needed by any caller.
        let _previous = self.version.fetch_add(1, Ordering::Relaxed);
    }
}
async fn init<'a, K, V>(init: Initializer<K, V>, key: K, cache_name: &'a str) -> Result<V>
where
K: Send + Sync + 'a,
V: Send + 'a,
{
metrics::CACHE_CONTAINER_CACHE_MISS
.with_label_values(&[cache_name])
.inc();
let _timer = metrics::CACHE_CONTAINER_LOAD_CACHE
.with_label_values(&[cache_name])
.start_timer();
init(&key)
.await
.transpose()
.context(error::ValueNotExistSnafu)?
}
async fn init_with_retry<'a, K, V>(
init: Initializer<K, V>,
key: K,
mut backoff: impl Iterator<Item = Duration> + 'a,
version: Arc<AtomicUsize>,
cache_name: &'a str,
) -> Result<V>
where
K: Send + Sync + 'a,
V: Send + 'a,
{
let mut attempts = 1usize;
loop {
let pre_version = version.load(Ordering::Relaxed);
metrics::CACHE_CONTAINER_CACHE_MISS
.with_label_values(&[cache_name])
.inc();
let _timer = metrics::CACHE_CONTAINER_LOAD_CACHE
.with_label_values(&[cache_name])
.start_timer();
let value = init(&key)
.await
.transpose()
.context(error::ValueNotExistSnafu)??;
if pre_version == version.load(Ordering::Relaxed) {
return Ok(value);
}
if let Some(duration) = backoff.next() {
sleep(duration).await;
attempts += 1;
} else {
return error::GetLatestCacheRetryExceededSnafu { attempts }.fail();
}
}
}
#[async_trait::async_trait]
impl<K, V> CacheInvalidator for CacheContainer<K, V, CacheIdent>
where
@@ -82,14 +200,15 @@ where
V: Send + Sync,
{
async fn invalidate(&self, _ctx: &Context, caches: &[CacheIdent]) -> Result<()> {
let tasks = caches
let idents = caches
.iter()
.filter(|token| (self.token_filter)(token))
.map(|token| (self.invalidator)(&self.cache, token));
join_all(tasks)
.await
.into_iter()
.collect::<Result<Vec<_>>>()?;
.collect::<Vec<_>>();
if !idents.is_empty() {
self.inc_version();
(self.invalidator)(&self.cache, &idents).await?;
}
Ok(())
}
}
@@ -99,27 +218,39 @@ where
K: Copy + Hash + Eq + Send + Sync + 'static,
V: Clone + Send + Sync + 'static,
{
/// Returns a _clone_ of the value corresponding to the key.
/// Returns a value from cache for copyable keys.
///
/// With [InitStrategy::Unchecked], this method prioritizes latency and may
/// return stale/dirty value. With [InitStrategy::VersionChecked], this method
/// retries initialization on version change and avoids dirty returns.
pub async fn get(&self, key: K) -> Result<Option<V>> {
metrics::CACHE_CONTAINER_CACHE_GET
.with_label_values(&[&self.name])
.inc();
let moved_init = self.initializer.clone();
let moved_key = key;
let init = async move {
metrics::CACHE_CONTAINER_CACHE_MISS
.with_label_values(&[&self.name])
.inc();
let _timer = metrics::CACHE_CONTAINER_LOAD_CACHE
.with_label_values(&[&self.name])
.start_timer();
moved_init(&moved_key)
.await
.transpose()
.context(error::ValueNotExistSnafu)?
let result = match self.init_strategy {
InitStrategy::Unchecked => {
self.cache
.try_get_with(key, init(self.initializer.clone(), key, &self.name))
.await
}
InitStrategy::VersionChecked => {
self.cache
.try_get_with(
key,
init_with_retry(
self.initializer.clone(),
key,
latest_get_backoff(),
self.version.clone(),
&self.name,
),
)
.await
}
};
match self.cache.try_get_with(key, init).await {
match result {
Ok(value) => Ok(Some(value)),
Err(err) => match err.as_ref() {
Error::ValueNotExist { .. } => Ok(None),
@@ -136,14 +267,15 @@ where
{
/// Invalidates cache by [CacheToken].
pub async fn invalidate(&self, caches: &[CacheToken]) -> Result<()> {
let tasks = caches
let idents = caches
.iter()
.filter(|token| (self.token_filter)(token))
.map(|token| (self.invalidator)(&self.cache, token));
join_all(tasks)
.await
.into_iter()
.collect::<Result<Vec<_>>>()?;
.collect::<Vec<_>>();
if !idents.is_empty() {
self.inc_version();
(self.invalidator)(&self.cache, &idents).await?;
}
Ok(())
}
@@ -156,7 +288,11 @@ where
self.cache.contains_key(key)
}
/// Returns a _clone_ of the value corresponding to the key.
/// Returns a value from cache by key reference.
///
/// With [InitStrategy::Unchecked], this method prioritizes latency and may
/// return stale/dirty value. With [InitStrategy::VersionChecked], this method
/// retries initialization on version change and avoids dirty returns.
pub async fn get_by_ref<Q>(&self, key: &Q) -> Result<Option<V>>
where
K: Borrow<Q>,
@@ -165,24 +301,32 @@ where
metrics::CACHE_CONTAINER_CACHE_GET
.with_label_values(&[&self.name])
.inc();
let moved_init = self.initializer.clone();
let moved_key = key.to_owned();
let init = async move {
metrics::CACHE_CONTAINER_CACHE_MISS
.with_label_values(&[&self.name])
.inc();
let _timer = metrics::CACHE_CONTAINER_LOAD_CACHE
.with_label_values(&[&self.name])
.start_timer();
moved_init(&moved_key)
.await
.transpose()
.context(error::ValueNotExistSnafu)?
let result = match self.init_strategy {
InitStrategy::Unchecked => {
self.cache
.try_get_with_by_ref(
key,
init(self.initializer.clone(), key.to_owned(), &self.name),
)
.await
}
InitStrategy::VersionChecked => {
self.cache
.try_get_with_by_ref(
key,
init_with_retry(
self.initializer.clone(),
key.to_owned(),
latest_get_backoff(),
self.version.clone(),
&self.name,
),
)
.await
}
};
match self.cache.try_get_with_by_ref(key, init).await {
match result {
Ok(value) => Ok(Some(value)),
Err(err) => match err.as_ref() {
Error::ValueNotExist { .. } => Ok(None),
@@ -296,9 +440,11 @@ mod tests {
moved_counter.fetch_add(1, Ordering::Relaxed);
Box::pin(async { Ok(Some("hi".to_string())) })
});
let invalidator: Invalidator<String, String, String> = Box::new(|cache, key| {
let invalidator: Invalidator<String, String, String> = Box::new(|cache, keys| {
Box::pin(async move {
cache.invalidate(key).await;
for key in keys {
cache.invalidate(*key).await;
}
Ok(())
})
});
@@ -323,4 +469,46 @@ mod tests {
assert_eq!(value, "hi");
assert_eq!(counter.load(Ordering::Relaxed), 2);
}
/// Checks that a version-checked `get_by_ref` racing with an invalidation
/// returns the value of the second load rather than the first one.
#[tokio::test(flavor = "multi_thread")]
async fn test_get_by_ref_returns_fresh_value_after_invalidate() {
    let load_count = Arc::new(AtomicI32::new(0));
    let shared_load_count = load_count.clone();
    // Each load is numbered up front and then takes 100ms, leaving a window in
    // which the invalidation below can land.
    let init: Initializer<String, String> = Arc::new(move |_| {
        let load_count = shared_load_count.clone();
        Box::pin(async move {
            let n = load_count.fetch_add(1, Ordering::Relaxed) + 1;
            sleep(Duration::from_millis(100)).await;
            Ok(Some(format!("v{n}")))
        })
    });
    let invalidator: Invalidator<String, String, String> = Box::new(|cache, keys| {
        Box::pin(async move {
            for &key in keys {
                cache.invalidate(key).await;
            }
            Ok(())
        })
    });
    let cache: Cache<String, String> = CacheBuilder::new(128).build();
    let container = Arc::new(CacheContainer::with_strategy(
        "test".to_string(),
        cache,
        invalidator,
        init,
        always_true_filter,
        InitStrategy::VersionChecked,
    ));

    // Start the read, then invalidate while its first load is still sleeping.
    let reader = container.clone();
    let pending_get = tokio::spawn(async move { reader.get_by_ref("foo").await });
    sleep(Duration::from_millis(50)).await;
    container.invalidate(&["foo".to_string()]).await.unwrap();

    let value = pending_get.await.unwrap().unwrap().unwrap();
    assert_eq!(value, "v2");
    assert_eq!(load_count.load(Ordering::Relaxed), 2);
}
}

View File

@@ -170,20 +170,22 @@ async fn handle_drop_flow(
fn invalidator<'a>(
cache: &'a Cache<TableId, FlownodeFlowSet>,
ident: &'a CacheIdent,
idents: &'a [&CacheIdent],
) -> BoxFuture<'a, Result<()>> {
Box::pin(async move {
match ident {
CacheIdent::CreateFlow(create_flow) => handle_create_flow(cache, create_flow).await,
CacheIdent::DropFlow(drop_flow) => handle_drop_flow(cache, drop_flow).await,
CacheIdent::FlowNodeAddressChange(node_id) => {
info!(
"Invalidate flow node cache for node_id in table_flownode: {}",
node_id
);
cache.invalidate_all();
for ident in idents {
match ident {
CacheIdent::CreateFlow(create_flow) => handle_create_flow(cache, create_flow).await,
CacheIdent::DropFlow(drop_flow) => handle_drop_flow(cache, drop_flow).await,
CacheIdent::FlowNodeAddressChange(node_id) => {
info!(
"Invalidate flow node cache for node_id in table_flownode: {}",
node_id
);
cache.invalidate_all();
}
_ => {}
}
_ => {}
}
Ok(())
})

View File

@@ -58,11 +58,13 @@ fn init_factory(schema_manager: SchemaManager) -> Initializer<SchemaName, Arc<Sc
fn invalidator<'a>(
cache: &'a Cache<SchemaName, Arc<SchemaNameValue>>,
ident: &'a CacheIdent,
idents: &'a [&CacheIdent],
) -> BoxFuture<'a, crate::error::Result<()>> {
Box::pin(async move {
if let CacheIdent::SchemaName(schema_name) = ident {
cache.invalidate(schema_name).await
for ident in idents {
if let CacheIdent::SchemaName(schema_name) = ident {
cache.invalidate(schema_name).await
}
}
Ok(())
})

View File

@@ -61,11 +61,13 @@ fn init_factory(table_info_manager: TableInfoManagerRef) -> Initializer<TableId,
fn invalidator<'a>(
cache: &'a Cache<TableId, Arc<TableInfo>>,
ident: &'a CacheIdent,
idents: &'a [&CacheIdent],
) -> BoxFuture<'a, Result<()>> {
Box::pin(async move {
if let CacheIdent::TableId(table_id) = ident {
cache.invalidate(table_id).await
for ident in idents {
if let CacheIdent::TableId(table_id) = ident {
cache.invalidate(table_id).await
}
}
Ok(())
})

View File

@@ -71,11 +71,13 @@ fn init_factory(table_name_manager: TableNameManagerRef) -> Initializer<TableNam
fn invalidator<'a>(
cache: &'a Cache<TableName, TableId>,
ident: &'a CacheIdent,
idents: &'a [&CacheIdent],
) -> BoxFuture<'a, Result<()>> {
Box::pin(async move {
if let CacheIdent::TableName(table_name) = ident {
cache.invalidate(table_name).await
for ident in idents {
if let CacheIdent::TableName(table_name) = ident {
cache.invalidate(table_name).await
}
}
Ok(())
})

View File

@@ -19,6 +19,7 @@ use moka::future::Cache;
use snafu::OptionExt;
use store_api::storage::TableId;
use crate::cache::container::InitStrategy;
use crate::cache::{CacheContainer, Initializer};
use crate::error;
use crate::error::Result;
@@ -65,7 +66,14 @@ pub fn new_table_route_cache(
let table_info_manager = Arc::new(TableRouteManager::new(kv_backend));
let init = init_factory(table_info_manager);
CacheContainer::new(name, cache, Box::new(invalidator), init, filter)
CacheContainer::with_strategy(
name,
cache,
Box::new(invalidator),
init,
filter,
InitStrategy::VersionChecked,
)
}
fn init_factory(
@@ -92,11 +100,13 @@ fn init_factory(
fn invalidator<'a>(
cache: &'a Cache<TableId, Arc<TableRoute>>,
ident: &'a CacheIdent,
idents: &'a [&CacheIdent],
) -> BoxFuture<'a, Result<()>> {
Box::pin(async move {
if let CacheIdent::TableId(table_id) = ident {
cache.invalidate(table_id).await
for ident in idents {
if let CacheIdent::TableId(table_id) = ident {
cache.invalidate(table_id).await
}
}
Ok(())
})

View File

@@ -65,7 +65,7 @@ fn init_factory(table_info_manager: TableInfoManager) -> Initializer<TableId, Ar
/// Never invalidates table id schema cache.
fn invalidator<'a>(
_cache: &'a Cache<TableId, Arc<SchemaName>>,
_ident: &'a CacheIdent,
_idents: &'a [&CacheIdent],
) -> BoxFuture<'a, error::Result<()>> {
Box::pin(std::future::ready(Ok(())))
}

View File

@@ -60,11 +60,13 @@ fn init_factory(view_info_manager: ViewInfoManagerRef) -> Initializer<TableId, A
fn invalidator<'a>(
cache: &'a Cache<TableId, Arc<ViewInfoValue>>,
ident: &'a CacheIdent,
idents: &'a [&CacheIdent],
) -> BoxFuture<'a, Result<()>> {
Box::pin(async move {
if let CacheIdent::TableId(view_id) = ident {
cache.invalidate(view_id).await
for ident in idents {
if let CacheIdent::TableId(view_id) = ident {
cache.invalidate(view_id).await
}
}
Ok(())
})

View File

@@ -714,6 +714,16 @@ pub enum Error {
#[snafu(display("Failed to get cache"))]
GetCache { source: Arc<Error> },
#[snafu(display(
"Failed to get latest cache value after {} attempts due to concurrent invalidation",
attempts
))]
GetLatestCacheRetryExceeded {
attempts: usize,
#[snafu(implicit)]
location: Location,
},
#[cfg(feature = "pg_kvbackend")]
#[snafu(display("Failed to execute via Postgres, sql: {}", sql))]
PostgresExecution {
@@ -1063,6 +1073,7 @@ impl ErrorExt for Error {
| ConnectEtcd { .. }
| MoveValues { .. }
| GetCache { .. }
| GetLatestCacheRetryExceeded { .. }
| SerializeToJson { .. }
| DeserializeFromJson { .. } => StatusCode::Internal,
@@ -1243,7 +1254,10 @@ impl Error {
/// Determine whether it is a retry later type through [StatusCode]
pub fn is_retry_later(&self) -> bool {
matches!(self, Error::RetryLater { .. })
matches!(
self,
Error::RetryLater { .. } | Error::GetLatestCacheRetryExceeded { .. }
)
}
/// Determine whether it needs to clean poisons.

View File

@@ -121,10 +121,12 @@ pub fn new_partition_info_cache(
CacheContainer::new(
name,
cache,
Box::new(|cache, ident| {
Box::new(|cache, idents| {
Box::pin(async move {
if let CacheIdent::TableId(table_id) = ident {
cache.invalidate(table_id).await
for ident in idents {
if let CacheIdent::TableId(table_id) = ident {
cache.invalidate(table_id).await
}
}
Ok(())
})