From 48ddc833ddb2f3a1d41309239fb6c10a52215dc8 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 20 Feb 2026 11:18:33 -0800 Subject: [PATCH] feat: check for dataset updates in the background (#3021) This updates `DatasetConsistencyWrapper` to block less: 1. `DatasetConsistencyWrapper::get()` just returns `Arc<Dataset>` now, instead of a guard that blocks writes. `DatasetConsistencyWrapper::get_mut()` is gone; now write methods just use `get()` and then later call `update()` with the new version. This means a given table handle can do concurrent reads **and** writes. 2. In weak consistency mode, it checks for dataset updates in the background instead of blocking calls to `get()`. --------- Co-authored-by: Claude Sonnet 4.5 --- rust/lancedb/src/remote/table.rs | 262 +------- rust/lancedb/src/table.rs | 81 +-- rust/lancedb/src/table/datafusion/insert.rs | 2 +- rust/lancedb/src/table/dataset.rs | 703 +++++++++++++------- rust/lancedb/src/table/delete.rs | 15 +- rust/lancedb/src/table/merge.rs | 4 +- rust/lancedb/src/table/optimize.rs | 23 +- rust/lancedb/src/table/schema_evolution.rs | 27 +- rust/lancedb/src/table/update.rs | 11 +- rust/lancedb/src/utils/background_cache.rs | 593 +++++++++++++++++ rust/lancedb/src/{utils.rs => utils/mod.rs} | 2 + 11 files changed, 1162 insertions(+), 561 deletions(-) create mode 100644 rust/lancedb/src/utils/background_cache.rs rename rust/lancedb/src/{utils.rs => utils/mod.rs} (99%) diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index f488c6377..fcfca3b5a 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -24,6 +24,7 @@ use crate::table::MergeResult; use crate::table::Tags; use crate::table::UpdateResult; use crate::table::{AddDataMode, AnyQuery, Filter, TableStatistics}; +use crate::utils::background_cache::BackgroundCache; use crate::utils::{supported_btree_data_type, supported_vector_data_type}; use crate::{ error::Result, @@ -42,8 +43,7 @@ use 
async_trait::async_trait; use datafusion_common::DataFusionError; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use datafusion_physical_plan::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream}; -use futures::future::Shared; -use futures::{FutureExt, TryStreamExt}; +use futures::TryStreamExt; use http::header::CONTENT_TYPE; use http::{HeaderName, StatusCode}; use lance::arrow::json::{JsonDataType, JsonSchema}; @@ -58,7 +58,7 @@ use std::collections::HashMap; use std::io::Cursor; use std::pin::Pin; use std::sync::{Arc, Mutex}; -use std::time::{Duration, Instant}; +use std::time::Duration; use tokio::sync::RwLock; const REQUEST_TIMEOUT_HEADER: HeaderName = HeaderName::from_static("x-request-timeout-ms"); @@ -67,58 +67,6 @@ const INDEX_TYPE_KEY: &str = "index_type"; const SCHEMA_CACHE_TTL: Duration = Duration::from_secs(30); const SCHEMA_CACHE_REFRESH_WINDOW: Duration = Duration::from_secs(5); -type SharedSchemaFuture = - Shared>>>; - -enum SchemaState { - Empty, - Current(SchemaRef, Instant), - Refreshing { - previous: Option<(SchemaRef, Instant)>, - future: SharedSchemaFuture, - }, -} - -struct SchemaCache { - state: SchemaState, - /// Incremented on invalidation. Background fetches check this to avoid - /// overwriting with stale data after a concurrent invalidation. - generation: u64, -} - -enum SchemaAction { - Return(SchemaRef), - Wait(SharedSchemaFuture), -} - -impl SchemaState { - /// Returns the schema if it's fresh (not in the refresh window). - fn fresh_schema(&self) -> Option { - match self { - Self::Current(schema, cached_at) => { - let elapsed = clock::now().duration_since(*cached_at); - if elapsed < SCHEMA_CACHE_TTL - SCHEMA_CACHE_REFRESH_WINDOW { - Some(schema.clone()) - } else { - None - } - } - Self::Refreshing { - previous: Some((schema, cached_at)), - .. 
- } => { - let elapsed = clock::now().duration_since(*cached_at); - if elapsed < SCHEMA_CACHE_TTL - SCHEMA_CACHE_REFRESH_WINDOW { - Some(schema.clone()) - } else { - None - } - } - _ => None, - } - } -} - pub struct RemoteTags<'a, S: HttpSend = Sender> { inner: &'a RemoteTable, } @@ -263,7 +211,7 @@ pub struct RemoteTable { version: RwLock>, location: RwLock>, - schema_cache: Arc>, + schema_cache: BackgroundCache, } impl std::fmt::Debug for RemoteTable { @@ -291,10 +239,7 @@ impl RemoteTable { server_version, version: RwLock::new(None), location: RwLock::new(None), - schema_cache: Arc::new(Mutex::new(SchemaCache { - state: SchemaState::Empty, - generation: 0, - })), + schema_cache: BackgroundCache::new(SCHEMA_CACHE_TTL, SCHEMA_CACHE_REFRESH_WINDOW), } } @@ -844,9 +789,7 @@ impl RemoteTable { } fn invalidate_schema_cache(&self) { - let mut cache = self.schema_cache.lock().unwrap(); - cache.state = SchemaState::Empty; - cache.generation += 1; + self.schema_cache.invalidate(); } fn handle_error_invalidation(&self, error: &Error) { @@ -861,119 +804,6 @@ impl RemoteTable { } } } - - fn determine_schema_action( - &self, - cache: &mut SchemaCache, - version: Option, - ) -> SchemaAction { - match &cache.state { - SchemaState::Empty => { - let (shared, _) = self.start_schema_fetch(cache, version, None); - SchemaAction::Wait(shared) - } - SchemaState::Current(schema, cached_at) => { - let elapsed = clock::now().duration_since(*cached_at); - if elapsed < SCHEMA_CACHE_TTL - SCHEMA_CACHE_REFRESH_WINDOW { - SchemaAction::Return(schema.clone()) - } else if elapsed < SCHEMA_CACHE_TTL { - // In refresh window: start background fetch, return current value - let schema = schema.clone(); - let previous = Some((schema.clone(), *cached_at)); - let _ = self.start_schema_fetch(cache, version, previous); - SchemaAction::Return(schema) - } else { - // Expired: must wait for fetch - let previous = Some((schema.clone(), *cached_at)); - let (shared, _) = self.start_schema_fetch(cache, version, 
previous); - SchemaAction::Wait(shared) - } - } - SchemaState::Refreshing { previous, future } => { - // If the background fetch already completed (spawned task hasn't - // run yet to update state), transition the state and re-evaluate. - if let Some(result) = future.peek() { - match result { - Ok(schema) => { - cache.state = SchemaState::Current(schema.clone(), clock::now()); - } - Err(_) => { - cache.state = match previous.clone() { - Some((s, t)) => SchemaState::Current(s, t), - None => SchemaState::Empty, - }; - } - } - return self.determine_schema_action(cache, version); - } - - if let Some((schema, cached_at)) = previous { - if clock::now().duration_since(*cached_at) < SCHEMA_CACHE_TTL { - SchemaAction::Return(schema.clone()) - } else { - SchemaAction::Wait(future.clone()) - } - } else { - SchemaAction::Wait(future.clone()) - } - } - } - } - - fn start_schema_fetch( - &self, - cache: &mut SchemaCache, - version: Option, - previous: Option<(SchemaRef, Instant)>, - ) -> (SharedSchemaFuture, u64) { - let client = self.client.clone(); - let identifier = self.identifier.clone(); - let table_name = self.name.clone(); - let generation = cache.generation; - - let shared = async move { - fetch_schema(&client, &identifier, &table_name, version) - .await - .map_err(Arc::new) - } - .boxed() - .shared(); - - // Spawn task to eagerly drive the future and update state on completion - let schema_cache = self.schema_cache.clone(); - let fut_for_spawn = shared.clone(); - tokio::spawn(async move { - let result = fut_for_spawn.await; - let mut cache = schema_cache.lock().unwrap(); - // Only update if no invalidation has happened since we started - if cache.generation != generation { - return; - } - match result { - Ok(schema) => { - cache.state = SchemaState::Current(schema, clock::now()); - } - Err(_) => { - // Revert to previous cached value if available - let prev = match &cache.state { - SchemaState::Refreshing { previous, .. 
} => previous.clone(), - _ => None, - }; - cache.state = match prev { - Some((s, t)) => SchemaState::Current(s, t), - None => SchemaState::Empty, - }; - } - } - }); - - cache.state = SchemaState::Refreshing { - previous, - future: shared.clone(), - }; - - (shared, generation) - } } #[derive(Deserialize)] @@ -1054,8 +884,8 @@ impl std::fmt::Display for RemoteTable { #[cfg(all(test, feature = "remote"))] mod test_utils { use super::*; - use crate::remote::client::test_utils::MockSender; - use crate::remote::client::test_utils::{client_with_handler, client_with_handler_and_config}; + use crate::remote::client::test_utils::client_with_handler; + use crate::remote::client::test_utils::{client_with_handler_and_config, MockSender}; use crate::remote::ClientConfig; impl RemoteTable { @@ -1073,10 +903,7 @@ mod test_utils { server_version: version.map(ServerVersion).unwrap_or_default(), version: RwLock::new(None), location: RwLock::new(None), - schema_cache: Arc::new(Mutex::new(SchemaCache { - state: SchemaState::Empty, - generation: 0, - })), + schema_cache: BackgroundCache::new(SCHEMA_CACHE_TTL, SCHEMA_CACHE_REFRESH_WINDOW), } } @@ -1094,10 +921,7 @@ mod test_utils { server_version: ServerVersion::default(), version: RwLock::new(None), location: RwLock::new(None), - schema_cache: Arc::new(Mutex::new(SchemaCache { - state: SchemaState::Empty, - generation: 0, - })), + schema_cache: BackgroundCache::new(SCHEMA_CACHE_TTL, SCHEMA_CACHE_REFRESH_WINDOW), } } } @@ -1197,28 +1021,21 @@ impl BaseTable for RemoteTable { } async fn schema(&self) -> Result { - // Fast path: check if cache is fresh (not even in refresh window) - { - let cache = self.schema_cache.lock().unwrap(); - if let Some(schema) = cache.state.fresh_schema() { - return Ok(schema); - } + if let Some(schema) = self.schema_cache.try_get() { + return Ok(schema); } - // Slow path: may need to fetch or start background refresh let version = self.current_version().await; - let action = { - let mut cache = 
self.schema_cache.lock().unwrap(); - self.determine_schema_action(&mut cache, version) - }; + let client = self.client.clone(); + let identifier = self.identifier.clone(); + let table_name = self.name.clone(); - match action { - SchemaAction::Return(schema) => Ok(schema), - SchemaAction::Wait(fut) => match fut.await { - Ok(schema) => Ok(schema), - Err(arc_err) => Err(unwrap_shared_error(arc_err)), - }, - } + self.schema_cache + .get(move || async move { + fetch_schema(&client, &identifier, &table_name, version).await + }) + .await + .map_err(unwrap_shared_error) } async fn count_rows(&self, filter: Option) -> Result { @@ -2057,42 +1874,6 @@ impl TryFrom for MergeInsertRequest { } } -// Clock module for testing with mock time -#[cfg(test)] -mod clock { - use std::cell::Cell; - use std::time::{Duration, Instant}; - - thread_local! { - static MOCK_NOW: Cell> = const { Cell::new(None) }; - } - - pub fn now() -> Instant { - MOCK_NOW.with(|mock| mock.get().unwrap_or_else(Instant::now)) - } - - pub fn advance_by(duration: Duration) { - MOCK_NOW.with(|mock| { - let current = mock.get().unwrap_or_else(Instant::now); - mock.set(Some(current + duration)); - }); - } - - #[allow(dead_code)] - pub fn clear_mock() { - MOCK_NOW.with(|mock| mock.set(None)); - } -} - -#[cfg(not(test))] -mod clock { - use std::time::Instant; - - pub fn now() -> Instant { - Instant::now() - } -} - #[cfg(test)] mod tests { use std::sync::atomic::{AtomicUsize, Ordering}; @@ -2116,6 +1897,7 @@ mod tests { use crate::index::vector::{IvfFlatIndexBuilder, IvfHnswSqIndexBuilder}; use crate::remote::db::DEFAULT_SERVER_VERSION; use crate::remote::JSON_CONTENT_TYPE; + use crate::utils::background_cache::clock; use crate::{ index::{vector::IvfPqIndexBuilder, Index, IndexStatistics, IndexType}, query::{ExecutableQuery, QueryBase}, diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 7d62e3271..f1449e10c 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -1640,11 +1640,10 
@@ impl NativeTable { left_on: &str, right_on: &str, ) -> Result<()> { - self.dataset - .get_mut() - .await? - .merge(batches, left_on, right_on) - .await?; + self.dataset.ensure_mutable()?; + let mut dataset = (*self.dataset.get().await?).clone(); + dataset.merge(batches, left_on, right_on).await?; + self.dataset.update(dataset); Ok(()) } @@ -1964,8 +1963,10 @@ impl NativeTable { /// You can use [Self::uses_v2_manifest_paths] to check if the table is already /// using V2 manifest paths. pub async fn migrate_manifest_paths_v2(&self) -> Result<()> { - let mut dataset = self.dataset.get_mut().await?; + self.dataset.ensure_mutable()?; + let mut dataset = (*self.dataset.get().await?).clone(); dataset.migrate_manifest_paths_v2().await?; + self.dataset.update(dataset); Ok(()) } @@ -1980,17 +1981,21 @@ impl NativeTable { &self, upsert_values: impl IntoIterator, ) -> Result<()> { - let mut dataset = self.dataset.get_mut().await?; + self.dataset.ensure_mutable()?; + let mut dataset = (*self.dataset.get().await?).clone(); dataset.update_config(upsert_values).await?; + self.dataset.update(dataset); Ok(()) } /// Delete keys from the config pub async fn delete_config_keys(&self, delete_keys: &[&str]) -> Result<()> { - let mut dataset = self.dataset.get_mut().await?; + self.dataset.ensure_mutable()?; + let mut dataset = (*self.dataset.get().await?).clone(); // TODO: update this when we implement metadata APIs #[allow(deprecated)] dataset.delete_config_keys(delete_keys).await?; + self.dataset.update(dataset); Ok(()) } @@ -1999,10 +2004,12 @@ impl NativeTable { &self, upsert_values: impl IntoIterator, ) -> Result<()> { - let mut dataset = self.dataset.get_mut().await?; + self.dataset.ensure_mutable()?; + let mut dataset = (*self.dataset.get().await?).clone(); // TODO: update this when we implement metadata APIs #[allow(deprecated)] dataset.replace_schema_metadata(upsert_values).await?; + self.dataset.update(dataset); Ok(()) } @@ -2017,8 +2024,10 @@ impl NativeTable { &self, 
new_values: impl IntoIterator)>, ) -> Result<()> { - let mut dataset = self.dataset.get_mut().await?; + self.dataset.ensure_mutable()?; + let mut dataset = (*self.dataset.get().await?).clone(); dataset.replace_field_metadata(new_values).await?; + self.dataset.update(dataset); Ok(()) } } @@ -2054,9 +2063,7 @@ impl BaseTable for NativeTable { } async fn checkout_latest(&self) -> Result<()> { - self.dataset - .as_latest(self.read_consistency_interval) - .await?; + self.dataset.as_latest().await?; self.dataset.reload().await } @@ -2065,24 +2072,19 @@ impl BaseTable for NativeTable { } async fn restore(&self) -> Result<()> { - let version = - self.dataset - .time_travel_version() - .await - .ok_or_else(|| Error::InvalidInput { - message: "you must run checkout before running restore".to_string(), - })?; + let version = self + .dataset + .time_travel_version() + .ok_or_else(|| Error::InvalidInput { + message: "you must run checkout before running restore".to_string(), + })?; { - // Use get_mut_unchecked as restore is the only "write" operation that is allowed - // when the table is in time travel mode. - // Also, drop the guard after .restore because as_latest will need it - let mut dataset = self.dataset.get_mut_unchecked().await?; + // restore is the only "write" operation allowed in time travel mode + let mut dataset = (*self.dataset.get().await?).clone(); debug_assert_eq!(dataset.version().version, version); dataset.restore().await?; } - self.dataset - .as_latest(self.read_consistency_interval) - .await?; + self.dataset.as_latest().await?; Ok(()) } @@ -2121,16 +2123,15 @@ impl BaseTable for NativeTable { let data = scannable_with_embeddings(add.data, &table_def, add.embedding_registry.as_ref())?; - let dataset = { - // Limited scope for the mutable borrow of self.dataset avoids deadlock. - let ds = self.dataset.get_mut().await?; - InsertBuilder::new(Arc::new(ds.clone())) - .with_params(&lance_params) - .execute_stream(data) - .await? 
- }; + self.dataset.ensure_mutable()?; + let ds = self.dataset.get().await?; + let dataset = InsertBuilder::new(ds) + .with_params(&lance_params) + .execute_stream(data) + .await?; + let version = dataset.manifest().version; - self.dataset.set_latest(dataset).await; + self.dataset.update(dataset); Ok(AddResult { version }) } @@ -2147,7 +2148,8 @@ impl BaseTable for NativeTable { let lance_idx_params = self.make_index_params(field, opts.index.clone()).await?; let index_type = self.get_index_type_for_field(field, &opts.index); let columns = [field.name().as_str()]; - let mut dataset = self.dataset.get_mut().await?; + self.dataset.ensure_mutable()?; + let mut dataset = (*self.dataset.get().await?).clone(); let mut builder = dataset .create_index_builder(&columns, index_type, lance_idx_params.as_ref()) .train(opts.train) @@ -2157,12 +2159,15 @@ impl BaseTable for NativeTable { builder = builder.name(name); } builder.await?; + self.dataset.update(dataset); Ok(()) } async fn drop_index(&self, index_name: &str) -> Result<()> { - let mut dataset = self.dataset.get_mut().await?; + self.dataset.ensure_mutable()?; + let mut dataset = (*self.dataset.get().await?).clone(); dataset.drop_index(index_name).await?; + self.dataset.update(dataset); Ok(()) } diff --git a/rust/lancedb/src/table/datafusion/insert.rs b/rust/lancedb/src/table/datafusion/insert.rs index 53ae9520c..ae0489156 100644 --- a/rust/lancedb/src/table/datafusion/insert.rs +++ b/rust/lancedb/src/table/datafusion/insert.rs @@ -200,7 +200,7 @@ impl ExecutionPlan for InsertExec { let new_dataset = CommitBuilder::new(dataset.clone()) .execute(merged_txn) .await?; - ds_wrapper.set_latest(new_dataset).await; + ds_wrapper.update(new_dataset); } } diff --git a/rust/lancedb/src/table/dataset.rs b/rust/lancedb/src/table/dataset.rs index bbec316c7..31cb7fa8f 100644 --- a/rust/lancedb/src/table/dataset.rs +++ b/rust/lancedb/src/table/dataset.rs @@ -2,301 +2,499 @@ // SPDX-FileCopyrightText: Copyright The LanceDB Authors use 
std::{ - ops::{Deref, DerefMut}, - sync::Arc, - time::{self, Duration, Instant}, + sync::{Arc, Mutex}, + time::Duration, }; use lance::{dataset::refs, Dataset}; -use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; -use crate::error::Result; - -/// A wrapper around a [Dataset] that provides lazy-loading and consistency checks. -/// -/// This can be cloned cheaply. It supports concurrent reads or exclusive writes. -#[derive(Debug, Clone)] -pub struct DatasetConsistencyWrapper(Arc>); +use crate::{error::Result, utils::background_cache::BackgroundCache, Error}; /// A wrapper around a [Dataset] that provides consistency checks. /// -/// The dataset is lazily loaded, and starts off as None. On the first access, -/// the dataset is loaded. +/// This can be cloned cheaply. Callers get an [`Arc`] from [`get()`](Self::get) +/// and call [`update()`](Self::update) after writes to store the new version. #[derive(Debug, Clone)] -enum DatasetRef { - /// In this mode, the dataset is always the latest version. - Latest { - dataset: Dataset, - read_consistency_interval: Option, - last_consistency_check: Option, - }, - /// In this mode, the dataset is a specific version. It cannot be mutated. - TimeTravel { dataset: Dataset, version: u64 }, +pub struct DatasetConsistencyWrapper { + state: Arc>, + consistency: ConsistencyMode, } -impl DatasetRef { - /// Reload the dataset to the appropriate version. - async fn reload(&mut self) -> Result<()> { - match self { - Self::Latest { - dataset, - last_consistency_check, - .. - } => { - dataset.checkout_latest().await?; - last_consistency_check.replace(Instant::now()); - } - Self::TimeTravel { dataset, version } => { - dataset.checkout_version(*version).await?; - } - } - Ok(()) - } +/// The current dataset and whether it is pinned to a specific version. +#[derive(Debug, Clone)] +struct DatasetState { + dataset: Arc, + /// `Some(version)` = pinned to a specific version (time travel), + /// `None` = tracking latest. 
+ pinned_version: Option, +} - fn is_latest(&self) -> bool { - matches!(self, Self::Latest { .. }) - } - - async fn as_latest(&mut self, read_consistency_interval: Option) -> Result<()> { - match self { - Self::Latest { .. } => Ok(()), - Self::TimeTravel { dataset, .. } => { - dataset - .checkout_version(dataset.latest_version_id().await?) - .await?; - *self = Self::Latest { - dataset: dataset.clone(), - read_consistency_interval, - last_consistency_check: Some(Instant::now()), - }; - Ok(()) - } - } - } - - async fn as_time_travel(&mut self, target_version: impl Into) -> Result<()> { - let target_ref = target_version.into(); - - match self { - Self::Latest { dataset, .. } => { - let new_dataset = dataset.checkout_version(target_ref.clone()).await?; - let version_value = new_dataset.version().version; - - *self = Self::TimeTravel { - dataset: new_dataset, - version: version_value, - }; - } - Self::TimeTravel { dataset, version } => { - let should_checkout = match &target_ref { - refs::Ref::Version(_, Some(target_ver)) => version != target_ver, - refs::Ref::Version(_, None) => true, // No specific version, always checkout - refs::Ref::VersionNumber(target_ver) => version != target_ver, - refs::Ref::Tag(_) => true, // Always checkout for tags - }; - - if should_checkout { - let new_dataset = dataset.checkout_version(target_ref).await?; - let version_value = new_dataset.version().version; - - *self = Self::TimeTravel { - dataset: new_dataset, - version: version_value, - }; - } - } - } - Ok(()) - } - - fn is_up_to_date(&self) -> bool { - match self { - Self::Latest { - read_consistency_interval, - last_consistency_check, - .. 
- } => match (read_consistency_interval, last_consistency_check) { - (None, _) => true, - (Some(_), None) => false, - (Some(interval), Some(last_check)) => last_check.elapsed() < *interval, - }, - Self::TimeTravel { dataset, version } => dataset.version().version == *version, - } - } - - fn time_travel_version(&self) -> Option { - match self { - Self::Latest { .. } => None, - Self::TimeTravel { version, .. } => Some(*version), - } - } - - fn set_latest(&mut self, dataset: Dataset) { - match self { - Self::Latest { - dataset: ref mut ds, - .. - } => { - if dataset.manifest().version > ds.manifest().version { - *ds = dataset; - } - } - _ => unreachable!("Dataset should be in latest mode at this point"), - } - } +#[derive(Debug, Clone)] +enum ConsistencyMode { + /// Only update table state when explicitly asked. + Lazy, + /// Always check for a new version on every read. + Strong, + /// Periodically check for new version in the background. If the table is being + /// regularly accessed, refresh will happen in the background. If the table is idle for a while, + /// the next access will trigger a refresh before returning the dataset. + /// + /// read_consistency_interval = TTL + /// refresh_window = min(3s, TTL/4) + /// + /// | t < TTL - refresh_window | t < TTL | t >= TTL | + /// | Return value | Background refresh & return value | synchronous refresh | + Eventual(BackgroundCache, Error>), } impl DatasetConsistencyWrapper { /// Create a new wrapper in the latest version mode. 
pub fn new_latest(dataset: Dataset, read_consistency_interval: Option) -> Self { - Self(Arc::new(RwLock::new(DatasetRef::Latest { - dataset, - read_consistency_interval, - last_consistency_check: Some(Instant::now()), - }))) + let dataset = Arc::new(dataset); + let consistency = match read_consistency_interval { + Some(d) if d == Duration::ZERO => ConsistencyMode::Strong, + Some(d) => { + let refresh_window = std::cmp::min(std::time::Duration::from_secs(3), d / 4); + let cache = BackgroundCache::new(d, refresh_window); + cache.seed(dataset.clone()); + ConsistencyMode::Eventual(cache) + } + None => ConsistencyMode::Lazy, + }; + Self { + state: Arc::new(Mutex::new(DatasetState { + dataset, + pinned_version: None, + })), + consistency, + } } - /// Get an immutable reference to the dataset. - pub async fn get(&self) -> Result> { - self.ensure_up_to_date().await?; - Ok(DatasetReadGuard { - guard: self.0.read().await, - }) - } - - /// Get a mutable reference to the dataset. + /// Get the current dataset. /// - /// If the dataset is in time travel mode this will fail - pub async fn get_mut(&self) -> Result> { - self.ensure_mutable().await?; - self.ensure_up_to_date().await?; - Ok(DatasetWriteGuard { - guard: self.0.write().await, - }) - } - - /// Get a mutable reference to the dataset without requiring the - /// dataset to be in a Latest mode. - pub async fn get_mut_unchecked(&self) -> Result> { - self.ensure_up_to_date().await?; - Ok(DatasetWriteGuard { - guard: self.0.write().await, - }) - } - - /// Convert into a wrapper in latest version mode - pub async fn as_latest(&self, read_consistency_interval: Option) -> Result<()> { - if self.0.read().await.is_latest() { - return Ok(()); + /// Behavior depends on the consistency mode: + /// - **Lazy** (`None`): returns the cached dataset immediately. + /// - **Strong** (`Some(ZERO)`): checks for a new version before returning. 
+ /// - **Eventual** (`Some(d)` where `d > 0`): returns a cached value immediately + /// while refreshing in the background when the TTL expires. + /// + /// If pinned to a specific version (time travel), always returns the + /// pinned dataset regardless of consistency mode. + pub async fn get(&self) -> Result> { + { + let state = self.state.lock().unwrap(); + if state.pinned_version.is_some() { + return Ok(state.dataset.clone()); + } } - let mut write_guard = self.0.write().await; - if write_guard.is_latest() { - return Ok(()); + match &self.consistency { + ConsistencyMode::Eventual(bg_cache) => { + if let Some(dataset) = bg_cache.try_get() { + return Ok(dataset); + } + let state = self.state.clone(); + bg_cache + .get(move || refresh_latest(state)) + .await + .map_err(unwrap_shared_error) + } + ConsistencyMode::Strong => refresh_latest(self.state.clone()).await, + ConsistencyMode::Lazy => { + let state = self.state.lock().unwrap(); + Ok(state.dataset.clone()) + } } + } - write_guard.as_latest(read_consistency_interval).await + /// Store a new dataset version after a write operation. + /// + /// Only stores the dataset if its version is newer than the current one. + /// If the wrapper has since transitioned to time-travel mode (e.g. via a + /// concurrent [`as_time_travel`](Self::as_time_travel) call), the update + /// is silently ignored — the write already committed to storage. + pub fn update(&self, dataset: Dataset) { + let mut state = self.state.lock().unwrap(); + if state.pinned_version.is_some() { + // A concurrent as_time_travel() beat us here. The write succeeded + // in storage, but since we're now pinned we don't advance the + // cached pointer. + return; + } + if dataset.manifest().version > state.dataset.manifest().version { + state.dataset = Arc::new(dataset); + } + drop(state); + if let ConsistencyMode::Eventual(bg_cache) = &self.consistency { + bg_cache.invalidate(); + } + } + + /// Checkout a branch and track its HEAD for new versions. 
+ pub async fn as_branch(&self, _branch: impl Into) -> Result<()> { + todo!("Branch support not yet implemented") + } + + /// Check that the dataset is in a mutable mode (Latest). + pub fn ensure_mutable(&self) -> Result<()> { + let state = self.state.lock().unwrap(); + if state.pinned_version.is_some() { + Err(crate::Error::InvalidInput { + message: "table cannot be modified when a specific version is checked out" + .to_string(), + }) + } else { + Ok(()) + } + } + + /// Returns the version, if in time travel mode, or None otherwise. + pub fn time_travel_version(&self) -> Option { + self.state.lock().unwrap().pinned_version + } + + /// Convert into a wrapper in latest version mode. + pub async fn as_latest(&self) -> Result<()> { + let dataset = { + let state = self.state.lock().unwrap(); + if state.pinned_version.is_none() { + return Ok(()); + } + state.dataset.clone() + }; + + let latest_version = dataset.latest_version_id().await?; + let new_dataset = dataset.checkout_version(latest_version).await?; + + let mut state = self.state.lock().unwrap(); + if state.pinned_version.is_some() { + state.dataset = Arc::new(new_dataset); + state.pinned_version = None; + } + drop(state); + if let ConsistencyMode::Eventual(bg_cache) = &self.consistency { + bg_cache.invalidate(); + } + Ok(()) } pub async fn as_time_travel(&self, target_version: impl Into) -> Result<()> { - self.0.write().await.as_time_travel(target_version).await - } + let target_ref = target_version.into(); - /// Provide a known latest version of the dataset. - /// - /// This is usually done after some write operation, which inherently will - /// have the latest version. 
- pub async fn set_latest(&self, dataset: Dataset) { - self.0.write().await.set_latest(dataset); + let (should_checkout, dataset) = { + let state = self.state.lock().unwrap(); + let should = match state.pinned_version { + None => true, + Some(version) => match &target_ref { + refs::Ref::Version(_, Some(target_ver)) => version != *target_ver, + refs::Ref::Version(_, None) => true, + refs::Ref::VersionNumber(target_ver) => version != *target_ver, + refs::Ref::Tag(_) => true, + }, + }; + (should, state.dataset.clone()) + }; + + if !should_checkout { + return Ok(()); + } + + let new_dataset = dataset.checkout_version(target_ref).await?; + let version_value = new_dataset.version().version; + + let mut state = self.state.lock().unwrap(); + state.dataset = Arc::new(new_dataset); + state.pinned_version = Some(version_value); + Ok(()) } pub async fn reload(&self) -> Result<()> { - self.0.write().await.reload().await - } + let (dataset, pinned_version) = { + let state = self.state.lock().unwrap(); + (state.dataset.clone(), state.pinned_version) + }; - /// Returns the version, if in time travel mode, or None otherwise - pub async fn time_travel_version(&self) -> Option { - self.0.read().await.time_travel_version() - } + match pinned_version { + None => { + refresh_latest(self.state.clone()).await?; + if let ConsistencyMode::Eventual(bg_cache) = &self.consistency { + bg_cache.invalidate(); + } + } + Some(version) => { + if dataset.version().version == version { + return Ok(()); + } - pub async fn ensure_mutable(&self) -> Result<()> { - let dataset_ref = self.0.read().await; - match &*dataset_ref { - DatasetRef::Latest { .. } => Ok(()), - DatasetRef::TimeTravel { .. 
} => Err(crate::Error::InvalidInput { - message: "table cannot be modified when a specific version is checked out" - .to_string(), - }), - } - } + let new_dataset = dataset.checkout_version(version).await?; - async fn is_up_to_date(&self) -> bool { - self.0.read().await.is_up_to_date() - } - - /// Ensures that the dataset is loaded and up-to-date with consistency and - /// version parameters. - async fn ensure_up_to_date(&self) -> Result<()> { - if !self.is_up_to_date().await { - // Re-check under write lock — another task may have reloaded - // while we waited for the lock. - let mut write_guard = self.0.write().await; - if !write_guard.is_up_to_date() { - write_guard.reload().await?; + let mut state = self.state.lock().unwrap(); + if state.pinned_version == Some(version) { + state.dataset = Arc::new(new_dataset); + } } } + Ok(()) } } -pub struct DatasetReadGuard<'a> { - guard: RwLockReadGuard<'a, DatasetRef>, -} +async fn refresh_latest(state: Arc>) -> Result> { + let dataset = { state.lock().unwrap().dataset.clone() }; -impl Deref for DatasetReadGuard<'_> { - type Target = Dataset; + let mut ds = (*dataset).clone(); + ds.checkout_latest().await?; + let new_arc = Arc::new(ds); - fn deref(&self) -> &Self::Target { - match &*self.guard { - DatasetRef::Latest { dataset, .. } => dataset, - DatasetRef::TimeTravel { dataset, .. } => dataset, + { + let mut state = state.lock().unwrap(); + if state.pinned_version.is_none() + && new_arc.manifest().version >= state.dataset.manifest().version + { + state.dataset = new_arc.clone(); } } + + Ok(new_arc) } -pub struct DatasetWriteGuard<'a> { - guard: RwLockWriteGuard<'a, DatasetRef>, -} - -impl Deref for DatasetWriteGuard<'_> { - type Target = Dataset; - - fn deref(&self) -> &Self::Target { - match &*self.guard { - DatasetRef::Latest { dataset, .. } => dataset, - DatasetRef::TimeTravel { dataset, .. 
} => dataset, - } - } -} - -impl DerefMut for DatasetWriteGuard<'_> { - fn deref_mut(&mut self) -> &mut Self::Target { - match &mut *self.guard { - DatasetRef::Latest { dataset, .. } => dataset, - DatasetRef::TimeTravel { dataset, .. } => dataset, - } +fn unwrap_shared_error(arc: Arc) -> Error { + match Arc::try_unwrap(arc) { + Ok(err) => err, + Err(arc) => Error::Runtime { + message: arc.to_string(), + }, } } #[cfg(test)] mod tests { + use std::time::Instant; + + use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator}; use arrow_schema::{DataType, Field, Schema}; - use lance::{dataset::WriteParams, io::ObjectStoreParams}; + use lance::{ + dataset::{WriteMode, WriteParams}, + io::ObjectStoreParams, + }; use super::*; use crate::{connect, io::object_store::io_tracking::IoStatsHolder, table::WriteOptions}; + async fn create_test_dataset(uri: &str) -> Dataset { + let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + Dataset::write( + RecordBatchIterator::new(vec![Ok(batch)], schema), + uri, + Some(WriteParams::default()), + ) + .await + .unwrap() + } + + async fn append_to_dataset(uri: &str) -> Dataset { + let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![4, 5, 6]))], + ) + .unwrap(); + Dataset::write( + RecordBatchIterator::new(vec![Ok(batch)], schema), + uri, + Some(WriteParams { + mode: WriteMode::Append, + ..Default::default() + }), + ) + .await + .unwrap() + } + + #[tokio::test] + async fn test_get_returns_dataset() { + let dir = tempfile::tempdir().unwrap(); + let uri = dir.path().to_str().unwrap(); + let ds = create_test_dataset(uri).await; + let version = ds.version().version; + + let wrapper = DatasetConsistencyWrapper::new_latest(ds, None); + let ds1 = 
wrapper.get().await.unwrap(); + let ds2 = wrapper.get().await.unwrap(); + + assert_eq!(ds1.version().version, version); + assert_eq!(ds2.version().version, version); + + // Arc is independent — not borrowing from wrapper + drop(wrapper); + assert_eq!(ds1.version().version, version); + } + + #[tokio::test] + async fn test_update_stores_newer_version() { + let dir = tempfile::tempdir().unwrap(); + let uri = dir.path().to_str().unwrap(); + let ds_v1 = create_test_dataset(uri).await; + assert_eq!(ds_v1.version().version, 1); + + let wrapper = DatasetConsistencyWrapper::new_latest(ds_v1, None); + + let ds_v2 = append_to_dataset(uri).await; + assert_eq!(ds_v2.version().version, 2); + + wrapper.update(ds_v2); + + let ds = wrapper.get().await.unwrap(); + assert_eq!(ds.version().version, 2); + } + + #[tokio::test] + async fn test_update_ignores_older_version() { + let dir = tempfile::tempdir().unwrap(); + let uri = dir.path().to_str().unwrap(); + let ds_v1 = create_test_dataset(uri).await; + let ds_v2 = append_to_dataset(uri).await; + + let wrapper = DatasetConsistencyWrapper::new_latest(ds_v2, None); + wrapper.update(ds_v1); + + let ds = wrapper.get().await.unwrap(); + assert_eq!(ds.version().version, 2); + } + + #[tokio::test] + async fn test_ensure_mutable_allows_latest() { + let dir = tempfile::tempdir().unwrap(); + let uri = dir.path().to_str().unwrap(); + let ds = create_test_dataset(uri).await; + + let wrapper = DatasetConsistencyWrapper::new_latest(ds, None); + assert!(wrapper.ensure_mutable().is_ok()); + } + + #[tokio::test] + async fn test_ensure_mutable_rejects_time_travel() { + let dir = tempfile::tempdir().unwrap(); + let uri = dir.path().to_str().unwrap(); + let ds = create_test_dataset(uri).await; + + let wrapper = DatasetConsistencyWrapper::new_latest(ds, None); + wrapper.as_time_travel(1u64).await.unwrap(); + + assert!(wrapper.ensure_mutable().is_err()); + } + + #[tokio::test] + async fn test_time_travel_version() { + let dir = tempfile::tempdir().unwrap(); 
+ let uri = dir.path().to_str().unwrap(); + let ds = create_test_dataset(uri).await; + + let wrapper = DatasetConsistencyWrapper::new_latest(ds, None); + assert_eq!(wrapper.time_travel_version(), None); + + wrapper.as_time_travel(1u64).await.unwrap(); + assert_eq!(wrapper.time_travel_version(), Some(1)); + } + + #[tokio::test] + async fn test_as_latest_from_time_travel() { + let dir = tempfile::tempdir().unwrap(); + let uri = dir.path().to_str().unwrap(); + let ds = create_test_dataset(uri).await; + + let wrapper = DatasetConsistencyWrapper::new_latest(ds, None); + wrapper.as_time_travel(1u64).await.unwrap(); + assert!(wrapper.ensure_mutable().is_err()); + + wrapper.as_latest().await.unwrap(); + assert!(wrapper.ensure_mutable().is_ok()); + assert_eq!(wrapper.time_travel_version(), None); + } + + #[tokio::test] + async fn test_lazy_consistency_never_refreshes() { + let dir = tempfile::tempdir().unwrap(); + let uri = dir.path().to_str().unwrap(); + let ds = create_test_dataset(uri).await; + + let wrapper = DatasetConsistencyWrapper::new_latest(ds, None); + let v1 = wrapper.get().await.unwrap().version().version; + + // External write + append_to_dataset(uri).await; + + // Lazy consistency should not pick up external write + let v_after = wrapper.get().await.unwrap().version().version; + assert_eq!(v1, v_after); + } + + #[tokio::test] + async fn test_strong_consistency_always_refreshes() { + let dir = tempfile::tempdir().unwrap(); + let uri = dir.path().to_str().unwrap(); + let ds = create_test_dataset(uri).await; + + let wrapper = DatasetConsistencyWrapper::new_latest(ds, Some(Duration::ZERO)); + let v1 = wrapper.get().await.unwrap().version().version; + + // External write + append_to_dataset(uri).await; + + // Strong consistency should pick up external write + let v_after = wrapper.get().await.unwrap().version().version; + assert_eq!(v_after, v1 + 1); + } + + #[tokio::test] + async fn test_eventual_consistency_background_refresh() { + let dir = 
tempfile::tempdir().unwrap(); + let uri = dir.path().to_str().unwrap(); + let ds = create_test_dataset(uri).await; + + let wrapper = DatasetConsistencyWrapper::new_latest(ds, Some(Duration::from_millis(200))); + + // Populate the cache + let v1 = wrapper.get().await.unwrap().version().version; + assert_eq!(v1, 1); + + // External write + append_to_dataset(uri).await; + + // Should return cached value immediately (within TTL) + let v_cached = wrapper.get().await.unwrap().version().version; + assert_eq!(v_cached, 1); + + // Wait for TTL to expire, then get() should trigger a refresh + tokio::time::sleep(Duration::from_millis(300)).await; + let v_after = wrapper.get().await.unwrap().version().version; + assert_eq!(v_after, 2); + } + + #[tokio::test] + async fn test_eventual_consistency_update_invalidates_cache() { + let dir = tempfile::tempdir().unwrap(); + let uri = dir.path().to_str().unwrap(); + let ds_v1 = create_test_dataset(uri).await; + + let wrapper = DatasetConsistencyWrapper::new_latest(ds_v1, Some(Duration::from_secs(60))); + + // Simulate a write that produces v2 + let ds_v2 = append_to_dataset(uri).await; + wrapper.update(ds_v2); + + // get() should return v2 immediately (update invalidated the bg_cache, + // and the mutex state was updated) + let v = wrapper.get().await.unwrap().version().version; + assert_eq!(v, 2); + } + #[tokio::test] async fn test_iops_open_strong_consistency() { let db = connect("memory://") @@ -312,7 +510,7 @@ mod tests { .create_empty_table("test", schema) .write_options(WriteOptions { lance_write_params: Some(WriteParams { - store_params: Some(ObjectStoreParams { + store_params: Some(lance::io::ObjectStoreParams { object_store_wrapper: Some(Arc::new(io_stats.clone())), ..Default::default() }), @@ -332,6 +530,31 @@ mod tests { assert_eq!(stats.read_iops, 1); } + /// Regression test: a write that races with as_time_travel() must not panic. 
+ /// + /// Sequence: ensure_mutable() passes → as_time_travel() completes → write + /// calls update(). Previously the assert!() in update() would fire. + #[tokio::test] + async fn test_update_after_concurrent_time_travel_does_not_panic() { + let dir = tempfile::tempdir().unwrap(); + let uri = dir.path().to_str().unwrap(); + let ds_v1 = create_test_dataset(uri).await; + + let wrapper = DatasetConsistencyWrapper::new_latest(ds_v1, None); + + // Simulate: as_time_travel() completes just before the write's update(). + wrapper.as_time_travel(1u64).await.unwrap(); + assert_eq!(wrapper.time_travel_version(), Some(1)); + + // The write already committed to storage; now it calls update(). + // This must not panic, and the wrapper must stay pinned. + let ds_v2 = append_to_dataset(uri).await; + wrapper.update(ds_v2); + + let ds = wrapper.get().await.unwrap(); + assert_eq!(ds.version().version, 1); + } + /// Regression test: before the fix, the reload fast-path (no version change) /// did not reset `last_consistency_check`, causing a list call on every /// subsequent query once the interval expired. diff --git a/rust/lancedb/src/table/delete.rs b/rust/lancedb/src/table/delete.rs index 4c4c304ba..dbc24ceb7 100644 --- a/rust/lancedb/src/table/delete.rs +++ b/rust/lancedb/src/table/delete.rs @@ -18,17 +18,12 @@ pub struct DeleteResult { /// /// This logic was moved from NativeTable::delete to keep table.rs clean. pub(crate) async fn execute_delete(table: &NativeTable, predicate: &str) -> Result { - // We access the dataset from the table. Since this is in the same module hierarchy (super), - // and 'dataset' is pub(crate), we can access it. 
- let mut dataset = table.dataset.get_mut().await?; - - // Perform the actual delete on the Lance dataset + table.dataset.ensure_mutable()?; + let mut dataset = (*table.dataset.get().await?).clone(); dataset.delete(predicate).await?; - - // Return the result with the new version - Ok(DeleteResult { - version: dataset.version().version, - }) + let version = dataset.version().version; + table.dataset.update(dataset); + Ok(DeleteResult { version }) } #[cfg(test)] diff --git a/rust/lancedb/src/table/merge.rs b/rust/lancedb/src/table/merge.rs index 1f2b6777e..d8805acb8 100644 --- a/rust/lancedb/src/table/merge.rs +++ b/rust/lancedb/src/table/merge.rs @@ -165,7 +165,7 @@ pub(crate) async fn execute_merge_insert( params: MergeInsertBuilder, new_data: Box, ) -> Result { - let dataset = Arc::new(table.dataset.get().await?.clone()); + let dataset = table.dataset.get().await?; let mut builder = LanceMergeInsertBuilder::try_new(dataset.clone(), params.on)?; match ( params.when_matched_update_all, @@ -210,7 +210,7 @@ pub(crate) async fn execute_merge_insert( }; let (new_dataset, stats) = future.await?; let version = new_dataset.manifest().version; - table.dataset.set_latest(new_dataset.as_ref().clone()).await; + table.dataset.update(new_dataset.as_ref().clone()); Ok(MergeResult { version, num_updated_rows: stats.num_updated_rows, diff --git a/rust/lancedb/src/table/optimize.rs b/rust/lancedb/src/table/optimize.rs index f75671e5b..abe660b38 100644 --- a/rust/lancedb/src/table/optimize.rs +++ b/rust/lancedb/src/table/optimize.rs @@ -105,12 +105,10 @@ pub struct OptimizeStats { /// This logic was moved from NativeTable to keep table.rs clean. pub(crate) async fn optimize_indices(table: &NativeTable, options: &OptimizeOptions) -> Result<()> { info!("LanceDB: optimizing indices: {:?}", options); - table - .dataset - .get_mut() - .await? 
- .optimize_indices(options) - .await?; + table.dataset.ensure_mutable()?; + let mut dataset = (*table.dataset.get().await?).clone(); + dataset.optimize_indices(options).await?; + table.dataset.update(dataset); Ok(()) } @@ -131,10 +129,9 @@ pub(crate) async fn cleanup_old_versions( delete_unverified: Option, error_if_tagged_old_versions: Option, ) -> Result { - Ok(table - .dataset - .get_mut() - .await? + table.dataset.ensure_mutable()?; + let dataset = table.dataset.get().await?; + Ok(dataset .cleanup_old_versions(older_than, delete_unverified, error_if_tagged_old_versions) .await?) } @@ -150,8 +147,10 @@ pub(crate) async fn compact_files_impl( options: CompactionOptions, remap_options: Option>, ) -> Result { - let mut dataset_mut = table.dataset.get_mut().await?; - let metrics = compact_files(&mut dataset_mut, options, remap_options).await?; + table.dataset.ensure_mutable()?; + let mut dataset = (*table.dataset.get().await?).clone(); + let metrics = compact_files(&mut dataset, options, remap_options).await?; + table.dataset.update(dataset); Ok(metrics) } diff --git a/rust/lancedb/src/table/schema_evolution.rs b/rust/lancedb/src/table/schema_evolution.rs index 3f774cd98..6adf7f3a8 100644 --- a/rust/lancedb/src/table/schema_evolution.rs +++ b/rust/lancedb/src/table/schema_evolution.rs @@ -52,11 +52,12 @@ pub(crate) async fn execute_add_columns( transforms: NewColumnTransform, read_columns: Option>, ) -> Result { - let mut dataset = table.dataset.get_mut().await?; + table.dataset.ensure_mutable()?; + let mut dataset = (*table.dataset.get().await?).clone(); dataset.add_columns(transforms, read_columns, None).await?; - Ok(AddColumnsResult { - version: dataset.version().version, - }) + let version = dataset.version().version; + table.dataset.update(dataset); + Ok(AddColumnsResult { version }) } /// Internal implementation of the alter columns logic. 
@@ -66,11 +67,12 @@ pub(crate) async fn execute_alter_columns( table: &NativeTable, alterations: &[ColumnAlteration], ) -> Result { - let mut dataset = table.dataset.get_mut().await?; + table.dataset.ensure_mutable()?; + let mut dataset = (*table.dataset.get().await?).clone(); dataset.alter_columns(alterations).await?; - Ok(AlterColumnsResult { - version: dataset.version().version, - }) + let version = dataset.version().version; + table.dataset.update(dataset); + Ok(AlterColumnsResult { version }) } /// Internal implementation of the drop columns logic. @@ -80,11 +82,12 @@ pub(crate) async fn execute_drop_columns( table: &NativeTable, columns: &[&str], ) -> Result { - let mut dataset = table.dataset.get_mut().await?; + table.dataset.ensure_mutable()?; + let mut dataset = (*table.dataset.get().await?).clone(); dataset.drop_columns(columns).await?; - Ok(DropColumnsResult { - version: dataset.version().version, - }) + let version = dataset.version().version; + table.dataset.update(dataset); + Ok(DropColumnsResult { version }) } #[cfg(test)] diff --git a/rust/lancedb/src/table/update.rs b/rust/lancedb/src/table/update.rs index 1d5ca5f4d..07ff7fd8e 100644 --- a/rust/lancedb/src/table/update.rs +++ b/rust/lancedb/src/table/update.rs @@ -78,11 +78,13 @@ pub(crate) async fn execute_update( table: &NativeTable, update: UpdateBuilder, ) -> Result { + table.dataset.ensure_mutable()?; + // 1. Snapshot the current dataset - let dataset = table.dataset.get().await?.clone(); + let dataset = table.dataset.get().await?; // 2. Initialize the Lance Core builder - let mut builder = LanceUpdateBuilder::new(Arc::new(dataset)); + let mut builder = LanceUpdateBuilder::new(dataset); // 3. Apply the filter (WHERE clause) if let Some(predicate) = update.filter { @@ -99,10 +101,7 @@ pub(crate) async fn execute_update( let res = operation.execute().await?; // 6. 
Update the table's view of the latest version - table - .dataset - .set_latest(res.new_dataset.as_ref().clone()) - .await; + table.dataset.update(res.new_dataset.as_ref().clone()); Ok(UpdateResult { rows_updated: res.rows_updated, diff --git a/rust/lancedb/src/utils/background_cache.rs b/rust/lancedb/src/utils/background_cache.rs new file mode 100644 index 000000000..37d0aa2b8 --- /dev/null +++ b/rust/lancedb/src/utils/background_cache.rs @@ -0,0 +1,593 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +//! A cache that refreshes values in the background before they expire. +//! +//! See [`BackgroundCache`] for details. + +use std::future::Future; +use std::sync::{Arc, Mutex}; +use std::time::Duration; + +use futures::future::{BoxFuture, Shared}; +use futures::FutureExt; + +type SharedFut = Shared>>>; + +enum State { + Empty, + Current(V, clock::Instant), + Refreshing { + previous: Option<(V, clock::Instant)>, + future: SharedFut, + }, +} + +impl State { + fn fresh_value(&self, ttl: Duration, refresh_window: Duration) -> Option { + let fresh_threshold = ttl - refresh_window; + match self { + Self::Current(value, cached_at) => { + if clock::now().duration_since(*cached_at) < fresh_threshold { + Some(value.clone()) + } else { + None + } + } + Self::Refreshing { + previous: Some((value, cached_at)), + .. + } => { + if clock::now().duration_since(*cached_at) < fresh_threshold { + Some(value.clone()) + } else { + None + } + } + _ => None, + } + } +} + +struct CacheInner { + state: State, + /// Incremented on invalidation. Background fetches check this to avoid + /// overwriting with stale data after a concurrent invalidation. + generation: u64, +} + +enum Action { + Return(V), + Wait(SharedFut), +} + +/// A cache that refreshes values in the background before they expire. +/// +/// The cache has three states: +/// - **Empty**: No cached value. The next [`get()`](Self::get) blocks until a fetch completes. 
+/// - **Current**: A valid cached value with a timestamp. Returns immediately if fresh. +/// - **Refreshing**: A fetch is in progress. Returns the previous value if still valid, +/// otherwise blocks until the fetch completes. +/// +/// When the cached value enters the refresh window (close to TTL expiry), +/// [`get()`](Self::get) starts a background fetch and returns the current value +/// immediately. Multiple concurrent callers share a single in-flight fetch. +pub struct BackgroundCache { + inner: Arc>>, + ttl: Duration, + refresh_window: Duration, +} + +impl std::fmt::Debug for BackgroundCache { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BackgroundCache") + .field("ttl", &self.ttl) + .field("refresh_window", &self.refresh_window) + .finish_non_exhaustive() + } +} + +impl Clone for BackgroundCache { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + ttl: self.ttl, + refresh_window: self.refresh_window, + } + } +} + +impl BackgroundCache +where + V: Clone + Send + Sync + 'static, + E: Send + Sync + 'static, +{ + pub fn new(ttl: Duration, refresh_window: Duration) -> Self { + assert!( + refresh_window < ttl, + "refresh_window ({refresh_window:?}) must be less than ttl ({ttl:?})" + ); + Self { + inner: Arc::new(Mutex::new(CacheInner { + state: State::Empty, + generation: 0, + })), + ttl, + refresh_window, + } + } + + /// Returns the cached value if it's fresh (not in the refresh window). + /// + /// This is a cheap synchronous check useful as a fast path before + /// constructing a fetch closure for [`get()`](Self::get). + pub fn try_get(&self) -> Option { + let cache = self.inner.lock().unwrap(); + cache.state.fresh_value(self.ttl, self.refresh_window) + } + + /// Get the cached value, fetching if needed. + /// + /// The closure is called to create the fetch future only when a new fetch + /// is needed. 
If the cache already has an in-flight fetch, the closure is + /// not called and the caller joins the existing fetch. + pub async fn get(&self, fetch: F) -> Result> + where + F: FnOnce() -> Fut + Send + 'static, + Fut: Future> + Send + 'static, + { + // Fast path: check if cache is fresh + { + let cache = self.inner.lock().unwrap(); + if let Some(value) = cache.state.fresh_value(self.ttl, self.refresh_window) { + return Ok(value); + } + } + + // Slow path + let mut fetch = Some(fetch); + let action = { + let mut cache = self.inner.lock().unwrap(); + self.determine_action(&mut cache, &mut fetch) + }; + + match action { + Action::Return(value) => Ok(value), + Action::Wait(fut) => fut.await, + } + } + + /// Pre-populate the cache with an initial value. + /// + /// This avoids a blocking fetch on the first [`get()`](Self::get) call. + pub fn seed(&self, value: V) { + let mut cache = self.inner.lock().unwrap(); + cache.state = State::Current(value, clock::now()); + } + + /// Invalidate the cache. The next [`get()`](Self::get) will start a fresh fetch. + /// + /// Any in-flight background fetch from before this call will not update the + /// cache (the generation counter prevents stale writes). 
+ pub fn invalidate(&self) { + let mut cache = self.inner.lock().unwrap(); + cache.state = State::Empty; + cache.generation += 1; + } + + fn determine_action( + &self, + cache: &mut CacheInner, + fetch: &mut Option, + ) -> Action + where + F: FnOnce() -> Fut + Send + 'static, + Fut: Future> + Send + 'static, + { + match &cache.state { + State::Empty => { + let f = fetch + .take() + .expect("fetch closure required for empty cache"); + let shared = self.start_fetch(cache, f, None); + Action::Wait(shared) + } + State::Current(value, cached_at) => { + let elapsed = clock::now().duration_since(*cached_at); + if elapsed < self.ttl - self.refresh_window { + Action::Return(value.clone()) + } else if elapsed < self.ttl { + // In refresh window: start background fetch, return current value + let value = value.clone(); + let previous = Some((value.clone(), *cached_at)); + if let Some(f) = fetch.take() { + // The spawned task inside start_fetch drives the future; + // we don't need to await the returned handle here. + drop(self.start_fetch(cache, f, previous)); + } + Action::Return(value) + } else { + // Expired: must wait for fetch + let previous = Some((value.clone(), *cached_at)); + let f = fetch + .take() + .expect("fetch closure required for expired cache"); + let shared = self.start_fetch(cache, f, previous); + Action::Wait(shared) + } + } + State::Refreshing { previous, future } => { + // If the background fetch already completed (spawned task hasn't + // run yet to update state), transition the state and re-evaluate. 
+ if let Some(result) = future.peek() { + match result { + Ok(value) => { + cache.state = State::Current(value.clone(), clock::now()); + } + Err(_) => { + cache.state = match previous.clone() { + Some((v, t)) => State::Current(v, t), + None => State::Empty, + }; + } + } + return self.determine_action(cache, fetch); + } + + if let Some((value, cached_at)) = previous { + if clock::now().duration_since(*cached_at) < self.ttl { + Action::Return(value.clone()) + } else { + Action::Wait(future.clone()) + } + } else { + Action::Wait(future.clone()) + } + } + } + } + + fn start_fetch( + &self, + cache: &mut CacheInner, + fetch: F, + previous: Option<(V, clock::Instant)>, + ) -> SharedFut + where + F: FnOnce() -> Fut + Send + 'static, + Fut: Future> + Send + 'static, + { + let generation = cache.generation; + let shared = async move { (fetch)().await.map_err(Arc::new) } + .boxed() + .shared(); + + // Spawn task to eagerly drive the future and update state on completion + let inner = self.inner.clone(); + let fut_for_spawn = shared.clone(); + tokio::spawn(async move { + let result = fut_for_spawn.await; + let mut cache = inner.lock().unwrap(); + // Only update if no invalidation has happened since we started + if cache.generation != generation { + return; + } + match result { + Ok(value) => { + cache.state = State::Current(value, clock::now()); + } + Err(_) => { + let prev = match &cache.state { + State::Refreshing { previous, .. } => previous.clone(), + _ => None, + }; + cache.state = match prev { + Some((v, t)) => State::Current(v, t), + None => State::Empty, + }; + } + } + }); + + cache.state = State::Refreshing { + previous, + future: shared.clone(), + }; + + shared + } +} + +#[cfg(test)] +pub mod clock { + use std::cell::Cell; + use std::time::Duration; + + // Re-export Instant so callers use the same type + pub use std::time::Instant; + + thread_local! 
{ + static MOCK_NOW: Cell> = const { Cell::new(None) }; + } + + pub fn now() -> Instant { + MOCK_NOW.with(|mock| mock.get().unwrap_or_else(Instant::now)) + } + + pub fn advance_by(duration: Duration) { + MOCK_NOW.with(|mock| { + let current = mock.get().unwrap_or_else(Instant::now); + mock.set(Some(current + duration)); + }); + } + + #[allow(dead_code)] + pub fn clear_mock() { + MOCK_NOW.with(|mock| mock.set(None)); + } +} + +#[cfg(not(test))] +mod clock { + // Re-export Instant so callers use the same type + pub use std::time::Instant; + + pub fn now() -> Instant { + Instant::now() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::{AtomicUsize, Ordering}; + + #[derive(Debug)] + struct TestError(String); + + impl std::fmt::Display for TestError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } + } + + const TEST_TTL: Duration = Duration::from_secs(30); + const TEST_REFRESH_WINDOW: Duration = Duration::from_secs(5); + + fn new_cache() -> BackgroundCache { + BackgroundCache::new(TEST_TTL, TEST_REFRESH_WINDOW) + } + + fn ok_fetcher( + counter: Arc, + value: &str, + ) -> impl FnOnce() -> BoxFuture<'static, Result> + Send + 'static { + let value = value.to_string(); + move || { + counter.fetch_add(1, Ordering::SeqCst); + async move { Ok(value) }.boxed() + } + } + + fn err_fetcher( + counter: Arc, + msg: &str, + ) -> impl FnOnce() -> BoxFuture<'static, Result> + Send + 'static { + let msg = msg.to_string(); + move || { + counter.fetch_add(1, Ordering::SeqCst); + async move { Err(TestError(msg)) }.boxed() + } + } + + #[tokio::test] + async fn test_basic_caching() { + let cache = new_cache(); + let count = Arc::new(AtomicUsize::new(0)); + + let v1 = cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap(); + assert_eq!(v1, "hello"); + assert_eq!(count.load(Ordering::SeqCst), 1); + + // Second call triggers peek transition to Current, returns cached + let v2 = 
cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap(); + assert_eq!(v2, "hello"); + assert_eq!(count.load(Ordering::SeqCst), 1); + + // Third call still cached + let v3 = cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap(); + assert_eq!(v3, "hello"); + assert_eq!(count.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn test_try_get_returns_none_when_empty() { + let cache: BackgroundCache = new_cache(); + assert!(cache.try_get().is_none()); + } + + #[tokio::test] + async fn test_try_get_returns_value_when_fresh() { + let cache = new_cache(); + let count = Arc::new(AtomicUsize::new(0)); + + cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap(); + // Peek transition + cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap(); + + assert_eq!(cache.try_get().unwrap(), "hello"); + } + + #[tokio::test] + async fn test_try_get_returns_none_in_refresh_window() { + let cache = new_cache(); + let count = Arc::new(AtomicUsize::new(0)); + + cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap(); + cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap(); // peek + + clock::advance_by(Duration::from_secs(26)); + assert!(cache.try_get().is_none()); + } + + #[tokio::test] + async fn test_ttl_expiration() { + let cache = new_cache(); + let count = Arc::new(AtomicUsize::new(0)); + + cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap(); + cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap(); // peek + assert_eq!(count.load(Ordering::SeqCst), 1); + + clock::advance_by(Duration::from_secs(31)); + + let v = cache.get(ok_fetcher(count.clone(), "v2")).await.unwrap(); + assert_eq!(v, "v2"); + assert_eq!(count.load(Ordering::SeqCst), 2); + } + + #[tokio::test] + async fn test_invalidate_forces_refetch() { + let cache = new_cache(); + let count = Arc::new(AtomicUsize::new(0)); + + cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap(); + cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap(); // peek + 
assert_eq!(count.load(Ordering::SeqCst), 1); + + cache.invalidate(); + + let v = cache.get(ok_fetcher(count.clone(), "v2")).await.unwrap(); + assert_eq!(v, "v2"); + assert_eq!(count.load(Ordering::SeqCst), 2); + } + + #[tokio::test] + async fn test_concurrent_get_single_fetch() { + let cache = Arc::new(new_cache()); + let count = Arc::new(AtomicUsize::new(0)); + + let mut handles = Vec::new(); + for _ in 0..10 { + let cache = cache.clone(); + let count = count.clone(); + handles.push(tokio::spawn(async move { + cache.get(ok_fetcher(count, "hello")).await.unwrap() + })); + } + + let results: Vec = futures::future::try_join_all(handles).await.unwrap(); + for r in &results { + assert_eq!(r, "hello"); + } + assert_eq!(count.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn test_background_refresh_in_window() { + let cache = new_cache(); + let count = Arc::new(AtomicUsize::new(0)); + + // Populate and transition to Current + cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap(); + cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap(); // peek + assert_eq!(count.load(Ordering::SeqCst), 1); + + // Move into refresh window + clock::advance_by(Duration::from_secs(26)); + + // Returns cached value and starts background fetch + let v = cache.get(ok_fetcher(count.clone(), "v2")).await.unwrap(); + assert_eq!(v, "v1"); // Still old value + assert_eq!(count.load(Ordering::SeqCst), 1); // bg task hasn't run yet + + // Advance past TTL to force waiting on the shared future + clock::advance_by(Duration::from_secs(30)); + + let v = cache.get(ok_fetcher(count.clone(), "v3")).await.unwrap(); + assert_eq!(count.load(Ordering::SeqCst), 2); + assert_eq!(v, "v2"); // Got the bg refresh result + } + + #[tokio::test] + async fn test_no_duplicate_background_refreshes() { + let cache = new_cache(); + let count = Arc::new(AtomicUsize::new(0)); + + // Populate and transition to Current + cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap(); + 
cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap(); // peek + assert_eq!(count.load(Ordering::SeqCst), 1); + + // Move into refresh window + clock::advance_by(Duration::from_secs(26)); + + // Multiple calls should all return cached, only one bg fetch + for _ in 0..5 { + let v = cache.get(ok_fetcher(count.clone(), "v2")).await.unwrap(); + assert_eq!(v, "v1"); + } + + // Drive the shared future to completion + clock::advance_by(Duration::from_secs(30)); + cache.get(ok_fetcher(count.clone(), "v3")).await.unwrap(); + + // Only 1 additional fetch (the background refresh) + assert_eq!(count.load(Ordering::SeqCst), 2); + } + + #[tokio::test] + async fn test_background_refresh_error_preserves_cache() { + let cache = new_cache(); + let count = Arc::new(AtomicUsize::new(0)); + + // Populate and transition to Current + cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap(); + cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap(); // peek + assert_eq!(count.load(Ordering::SeqCst), 1); + + // Move into refresh window + clock::advance_by(Duration::from_secs(26)); + + // Start bg refresh that will fail, returns cached value + let v = cache.get(err_fetcher(count.clone(), "fail")).await.unwrap(); + assert_eq!(v, "v1"); + + // Still in refresh window, previous is valid + let v = cache.get(err_fetcher(count.clone(), "fail")).await.unwrap(); + assert_eq!(v, "v1"); + + // Advance past TTL to drive the failed future + clock::advance_by(Duration::from_secs(30)); + + // The peek error path restores previous, but it's expired, + // so a new fetch is needed. This one also fails. 
+ let result = cache.get(err_fetcher(count.clone(), "fail again")).await; + assert!(result.is_err()); + assert_eq!(count.load(Ordering::SeqCst), 2); + } + + #[tokio::test] + async fn test_invalidation_during_fetch_prevents_stale_update() { + let cache = new_cache(); + let count = Arc::new(AtomicUsize::new(0)); + + // Populate and transition to Current + cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap(); + cache.get(ok_fetcher(count.clone(), "v1")).await.unwrap(); // peek + + // Move into refresh window to start background fetch + clock::advance_by(Duration::from_secs(26)); + cache.get(ok_fetcher(count.clone(), "stale")).await.unwrap(); + + // Invalidate before bg task completes + cache.invalidate(); + + // Advance past TTL + clock::advance_by(Duration::from_secs(30)); + + // Should get fresh data, not the stale background result + let v = cache.get(ok_fetcher(count.clone(), "fresh")).await.unwrap(); + assert_eq!(v, "fresh"); + } +} diff --git a/rust/lancedb/src/utils.rs b/rust/lancedb/src/utils/mod.rs similarity index 99% rename from rust/lancedb/src/utils.rs rename to rust/lancedb/src/utils/mod.rs index 970512a5e..031487a50 100644 --- a/rust/lancedb/src/utils.rs +++ b/rust/lancedb/src/utils/mod.rs @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors +pub(crate) mod background_cache; + use std::sync::Arc; use arrow_array::RecordBatch;