diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs
index 8ed4b6236..facf15343 100644
--- a/rust/lancedb/src/remote/table.rs
+++ b/rust/lancedb/src/remote/table.rs
@@ -26,7 +26,8 @@ use async_trait::async_trait;
 use datafusion_common::DataFusionError;
 use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
 use datafusion_physical_plan::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
-use futures::TryStreamExt;
+use futures::future::{BoxFuture, Shared};
+use futures::{FutureExt, TryStreamExt};
 use http::header::CONTENT_TYPE;
 use http::{HeaderName, StatusCode};
 use lance::arrow::json::{JsonDataType, JsonSchema};
@@ -41,7 +42,7 @@ use std::collections::HashMap;
 use std::io::Cursor;
 use std::pin::Pin;
 use std::sync::{Arc, Mutex};
-use std::time::Duration;
+use std::time::{Duration, Instant};
 use tokio::sync::RwLock;
 
 use super::client::RequestResultExt;
@@ -63,6 +64,60 @@ use crate::{
 const REQUEST_TIMEOUT_HEADER: HeaderName = HeaderName::from_static("x-request-timeout-ms");
 const METRIC_TYPE_KEY: &str = "metric_type";
 const INDEX_TYPE_KEY: &str = "index_type";
+const SCHEMA_CACHE_TTL: Duration = Duration::from_secs(30);
+const SCHEMA_CACHE_REFRESH_WINDOW: Duration = Duration::from_secs(5);
+
+type SharedSchemaFuture =
+    Shared<BoxFuture<'static, std::result::Result<SchemaRef, Arc<Error>>>>;
+
+enum SchemaState {
+    Empty,
+    Current(SchemaRef, Instant),
+    Refreshing {
+        previous: Option<(SchemaRef, Instant)>,
+        future: SharedSchemaFuture,
+    },
+}
+
+struct SchemaCache {
+    state: SchemaState,
+    /// Incremented on invalidation. Background fetches check this to avoid
+    /// overwriting with stale data after a concurrent invalidation.
+    generation: u64,
+}
+
+enum SchemaAction {
+    Return(SchemaRef),
+    Wait(SharedSchemaFuture),
+}
+
+impl SchemaState {
+    /// Returns the schema if it's fresh (not in the refresh window).
+    fn fresh_schema(&self) -> Option<SchemaRef> {
+        match self {
+            Self::Current(schema, cached_at) => {
+                let elapsed = clock::now().duration_since(*cached_at);
+                if elapsed < SCHEMA_CACHE_TTL - SCHEMA_CACHE_REFRESH_WINDOW {
+                    Some(schema.clone())
+                } else {
+                    None
+                }
+            }
+            Self::Refreshing {
+                previous: Some((schema, cached_at)),
+                ..
+            } => {
+                let elapsed = clock::now().duration_since(*cached_at);
+                if elapsed < SCHEMA_CACHE_TTL - SCHEMA_CACHE_REFRESH_WINDOW {
+                    Some(schema.clone())
+                } else {
+                    None
+                }
+            }
+            _ => None,
+        }
+    }
+}
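+
+// A rough sketch of the cache lifecycle implied by the types above, with the
+// constants as configured (TTL = 30s, refresh window = 5s); ages are measured
+// from `cached_at`:
+//
+//   age < 25s          Current     -> fresh_schema() serves it directly
+//   25s <= age < 30s   Refreshing  -> the previous value is still served
+//                                     while one background fetch runs
+//   age >= 30s         Empty, or an expired previous -> callers await the
+//                                     shared fetch future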
 
 pub struct RemoteTags<'a, S: HttpSend = Sender> {
     inner: &'a RemoteTable<S>,
@@ -198,7 +253,6 @@ impl<S: HttpSend> Tags for RemoteTags<'_, S> {
     }
 }
 
-#[derive(Debug)]
 pub struct RemoteTable<S: HttpSend = Sender> {
     #[allow(dead_code)]
     client: RestfulLanceDbClient<S>,
@@ -209,6 +263,16 @@ pub struct RemoteTable<S: HttpSend = Sender> {
 
     version: RwLock<Option<u64>>,
     location: RwLock<Option<String>>,
+    schema_cache: Arc<Mutex<SchemaCache>>,
+}
+
+impl<S: HttpSend> std::fmt::Debug for RemoteTable<S> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("RemoteTable")
+            .field("name", &self.name)
+            .field("identifier", &self.identifier)
+            .finish_non_exhaustive()
+    }
 }
 
 impl<S: HttpSend> RemoteTable<S> {
@@ -227,6 +291,10 @@ impl<S: HttpSend> RemoteTable<S> {
             server_version,
             version: RwLock::new(None),
             location: RwLock::new(None),
+            schema_cache: Arc::new(Mutex::new(SchemaCache {
+                state: SchemaState::Empty,
+                generation: 0,
+            })),
         }
     }
@@ -470,14 +538,41 @@ impl<S: HttpSend> RemoteTable<S> {
         Ok(response)
     }
 
+    /// Check if a status code should trigger schema cache invalidation
+    fn should_invalidate_cache_for_status(status: StatusCode) -> bool {
+        // Only invalidate for errors that could be schema-related.
+        // Don't invalidate for auth errors (401, 403) or temporary failures (502, 503).
+        matches!(
+            status,
+            StatusCode::BAD_REQUEST // 400 - could be schema mismatch
+                | StatusCode::NOT_FOUND // 404 - table might have been recreated
+                | StatusCode::UNPROCESSABLE_ENTITY // 422 - schema validation error
+                | StatusCode::INTERNAL_SERVER_ERROR // 500 - could be schema issue on server
+        )
+    }
+
     async fn check_table_response(
         &self,
         request_id: &str,
        response: reqwest::Response,
     ) -> Result<reqwest::Response> {
-        let response = Self::handle_table_not_found(&self.name, response, request_id).await?;
+        let status = response.status();
+        let not_found_result = Self::handle_table_not_found(&self.name, response, request_id).await;
 
-        self.client.check_response(request_id, response).await
+        // Invalidate the schema cache if the table was not found (404)
+        if not_found_result.is_err() && Self::should_invalidate_cache_for_status(status) {
+            self.invalidate_schema_cache();
+        }
+
+        let response = not_found_result?;
+        let result = self.client.check_response(request_id, response).await;
+
+        // Invalidate the schema cache on errors that could be schema-related
+        if result.is_err() && Self::should_invalidate_cache_for_status(status) {
+            self.invalidate_schema_cache();
+        }
+
+        result
    }
 
     async fn read_arrow_stream(
@@ -747,6 +842,138 @@ impl<S: HttpSend> RemoteTable<S> {
             AnyQuery::VectorQuery(query) => self.apply_vector_query_params(base_body, query),
         }
     }
+
+    fn invalidate_schema_cache(&self) {
+        let mut cache = self.schema_cache.lock().unwrap();
+        cache.state = SchemaState::Empty;
+        cache.generation += 1;
+    }
+
+    fn handle_error_invalidation(&self, error: &Error) {
+        let status_code = match error {
+            Error::Http { status_code, .. } => *status_code,
+            Error::Retry { status_code, .. } => *status_code,
+            _ => None,
+        };
+        if let Some(status_code) = status_code {
+            if Self::should_invalidate_cache_for_status(status_code) {
+                self.invalidate_schema_cache();
+            }
+        }
+    }
+
+    fn determine_schema_action(
+        &self,
+        cache: &mut SchemaCache,
+        version: Option<u64>,
+    ) -> SchemaAction {
+        match &cache.state {
+            SchemaState::Empty => {
+                let (shared, _) = self.start_schema_fetch(cache, version, None);
+                SchemaAction::Wait(shared)
+            }
+            SchemaState::Current(schema, cached_at) => {
+                let elapsed = clock::now().duration_since(*cached_at);
+                if elapsed < SCHEMA_CACHE_TTL - SCHEMA_CACHE_REFRESH_WINDOW {
+                    SchemaAction::Return(schema.clone())
+                } else if elapsed < SCHEMA_CACHE_TTL {
+                    // In refresh window: start background fetch, return current value
+                    let schema = schema.clone();
+                    let previous = Some((schema.clone(), *cached_at));
+                    let _ = self.start_schema_fetch(cache, version, previous);
+                    SchemaAction::Return(schema)
+                } else {
+                    // Expired: must wait for fetch
+                    let previous = Some((schema.clone(), *cached_at));
+                    let (shared, _) = self.start_schema_fetch(cache, version, previous);
+                    SchemaAction::Wait(shared)
+                }
+            }
+            SchemaState::Refreshing { previous, future } => {
+                // If the background fetch already completed (spawned task hasn't
+                // run yet to update state), transition the state and re-evaluate.
+                if let Some(result) = future.peek() {
+                    match result {
+                        Ok(schema) => {
+                            cache.state = SchemaState::Current(schema.clone(), clock::now());
+                        }
+                        Err(_) => {
+                            cache.state = match previous.clone() {
+                                Some((s, t)) => SchemaState::Current(s, t),
+                                None => SchemaState::Empty,
+                            };
+                        }
+                    }
+                    return self.determine_schema_action(cache, version);
+                }
+
+                if let Some((schema, cached_at)) = previous {
+                    if clock::now().duration_since(*cached_at) < SCHEMA_CACHE_TTL {
+                        SchemaAction::Return(schema.clone())
+                    } else {
+                        SchemaAction::Wait(future.clone())
+                    }
+                } else {
+                    SchemaAction::Wait(future.clone())
+                }
+            }
+        }
+    }
+
+    fn start_schema_fetch(
+        &self,
+        cache: &mut SchemaCache,
+        version: Option<u64>,
+        previous: Option<(SchemaRef, Instant)>,
+    ) -> (SharedSchemaFuture, u64) {
+        let client = self.client.clone();
+        let identifier = self.identifier.clone();
+        let table_name = self.name.clone();
+        let generation = cache.generation;
+
+        let shared = async move {
+            fetch_schema(&client, &identifier, &table_name, version)
+                .await
+                .map_err(Arc::new)
+        }
+        .boxed()
+        .shared();
+
+        // Spawn task to eagerly drive the future and update state on completion
+        let schema_cache = self.schema_cache.clone();
+        let fut_for_spawn = shared.clone();
+        tokio::spawn(async move {
+            let result = fut_for_spawn.await;
+            let mut cache = schema_cache.lock().unwrap();
+            // Only update if no invalidation has happened since we started
+            if cache.generation != generation {
+                return;
+            }
+            match result {
+                Ok(schema) => {
+                    cache.state = SchemaState::Current(schema, clock::now());
+                }
+                Err(_) => {
+                    // Revert to previous cached value if available
+                    let prev = match &cache.state {
+                        SchemaState::Refreshing { previous, .. } => previous.clone(),
+                        _ => None,
+                    };
+                    cache.state = match prev {
+                        Some((s, t)) => SchemaState::Current(s, t),
+                        None => SchemaState::Empty,
+                    };
+                }
+            }
+        });
+
+        cache.state = SchemaState::Refreshing {
+            previous,
+            future: shared.clone(),
+        };
+
+        (shared, generation)
+    }
 }
 
 #[derive(Deserialize)]
@@ -756,6 +983,68 @@ struct TableDescription {
     location: Option<String>,
 }
 
+/// Extract an `Error` from an `Arc<Error>`, reconstructing it if the `Arc` is shared.
+/// This is needed because `Shared` futures cache results internally, so
+/// `Arc::try_unwrap` typically fails.
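+///
+/// A minimal sketch of the intended behavior (the error value is hypothetical):
+///
+/// ```ignore
+/// let shared = Arc::new(Error::Runtime { message: "boom".into() });
+/// let _other_handle = shared.clone();             // the Arc is now shared
+/// let owned: Error = unwrap_shared_error(shared); // reconstructed, not moved
+/// assert!(owned.to_string().contains("boom"));
+/// ```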
+fn unwrap_shared_error(arc: Arc<Error>) -> Error {
+    match Arc::try_unwrap(arc) {
+        Ok(err) => err,
+        Err(arc) => match &*arc {
+            Error::TableNotFound { name, source } => Error::TableNotFound {
+                name: name.clone(),
+                source: source.to_string().into(),
+            },
+            _ => Error::Runtime {
+                message: arc.to_string(),
+            },
+        },
+    }
+}
+
+async fn fetch_schema<S: HttpSend>(
+    client: &RestfulLanceDbClient<S>,
+    identifier: &str,
+    table_name: &str,
+    version: Option<u64>,
+) -> Result<SchemaRef> {
+    let request = client
+        .post(&format!("/v1/table/{}/describe/", identifier))
+        .json(&serde_json::json!({ "version": version }));
+
+    let (request_id, response) = client.send_with_retry(request, None, true).await?;
+
+    if response.status() == StatusCode::NOT_FOUND {
+        let body = response.text().await.ok().unwrap_or_default();
+        return Err(Error::TableNotFound {
+            name: table_name.to_string(),
+            source: Box::new(Error::Http {
+                source: body.into(),
+                request_id,
+                status_code: Some(StatusCode::NOT_FOUND),
+            }),
+        });
+    }
+
+    let response = client.check_response(&request_id, response).await?;
+    let body = response.text().await.map_err(|e| {
+        let status_code = e.status();
+        Error::Http {
+            source: Box::new(e),
+            request_id: request_id.clone(),
+            status_code,
+        }
+    })?;
+
+    let description: TableDescription = serde_json::from_str(&body).map_err(|e| Error::Http {
+        source: format!("Failed to parse table description: {}", e).into(),
+        request_id,
+        status_code: None,
+    })?;
+
+    let arrow_schema: arrow_schema::Schema = description.schema.try_into()?;
+    Ok(Arc::new(arrow_schema))
+}
+
 impl<S: HttpSend> std::fmt::Display for RemoteTable<S> {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(f, "RemoteTable({})", self.identifier)
@@ -784,6 +1073,10 @@ mod test_utils {
             server_version: version.map(ServerVersion).unwrap_or_default(),
             version: RwLock::new(None),
             location: RwLock::new(None),
+            schema_cache: Arc::new(Mutex::new(SchemaCache {
+                state: SchemaState::Empty,
+                generation: 0,
+            })),
         }
     }
@@ -801,6 +1094,10 @@ mod test_utils {
             server_version: ServerVersion::default(),
             version: RwLock::new(None),
             location: RwLock::new(None),
+            schema_cache: Arc::new(Mutex::new(SchemaCache {
+                state: SchemaState::Empty,
+                generation: 0,
+            })),
         }
     }
 }
@@ -841,11 +1138,21 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
 
         let mut write_guard = self.version.write().await;
         *write_guard = Some(version);
+        drop(write_guard);
+
+        // Invalidate schema cache since we're switching versions
+        self.invalidate_schema_cache();
+
         Ok(())
     }
     async fn checkout_latest(&self) -> Result<()> {
         let mut write_guard = self.version.write().await;
         *write_guard = None;
+        drop(write_guard);
+
+        // Invalidate schema cache since we're switching versions
+        self.invalidate_schema_cache();
+
         Ok(())
     }
     async fn restore(&self) -> Result<()> {
@@ -890,9 +1197,30 @@
     }
 
     async fn schema(&self) -> Result<SchemaRef> {
-        let schema = self.describe().await?.schema;
-        Ok(Arc::new(schema.try_into()?))
+        // Fast path: check if cache is fresh (not even in refresh window)
+        {
+            let cache = self.schema_cache.lock().unwrap();
+            if let Some(schema) = cache.state.fresh_schema() {
+                return Ok(schema);
+            }
+        }
+
+        // Slow path: may need to fetch or start background refresh
+        let version = self.current_version().await;
+        let action = {
+            let mut cache = self.schema_cache.lock().unwrap();
+            self.determine_schema_action(&mut cache, version)
+        };
+
+        match action {
+            SchemaAction::Return(schema) => Ok(schema),
+            SchemaAction::Wait(fut) => match fut.await {
+                Ok(schema) => Ok(schema),
+                Err(arc_err) => Err(unwrap_shared_error(arc_err)),
+            },
+        }
     }
+
     async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
         let mut request = self
             .client
             request = request.json(&body);
         }
 
-        let (request_id, response) = self.send(request, true).await?;
-
-        let response = self.check_table_response(&request_id, response).await?;
+        let (request_id, response) = match self.send(request, true).await {
+            Ok((id, resp)) => {
+                // check_table_response handles error-based invalidation
+                let response = self.check_table_response(&id, resp).await?;
+                (id, response)
+            }
+            Err(e) => {
+                self.handle_error_invalidation(&e);
+                return Err(e);
+            }
+        };
 
         let body = response.text().await.err_to_http(request_id.clone())?;
@@ -951,6 +1287,11 @@
             request_id,
             status_code: None,
         })?;
+
+        if matches!(add.mode, AddDataMode::Overwrite) {
+            self.invalidate_schema_cache();
+        }
+
         Ok(add_response)
     }
@@ -1354,6 +1695,11 @@
         let version = tags.get_version(tag).await?;
         let mut write_guard = self.version.write().await;
         *write_guard = Some(version);
+        drop(write_guard);
+
+        // Invalidate schema cache since we're switching versions
+        self.invalidate_schema_cache();
+
         Ok(())
     }
     async fn optimize(&self, _action: OptimizeAction) -> Result<OptimizeStats> {
@@ -1400,6 +1746,8 @@
                     status_code: None,
                 })?;
 
+                self.invalidate_schema_cache();
+
                 Ok(result)
             }
             _ => {
@@ -1452,6 +1800,8 @@
             status_code: None,
         })?;
 
+        self.invalidate_schema_cache();
+
         Ok(result)
     }
@@ -1477,6 +1827,8 @@
             status_code: None,
         })?;
 
+        self.invalidate_schema_cache();
+
         Ok(result)
     }
@@ -1705,10 +2057,47 @@ impl TryFrom<MergeInsertBuilder> for MergeInsertRequest {
     }
 }
 
+// Clock module for testing with mock time
+#[cfg(test)]
+mod clock {
+    use std::cell::Cell;
+    use std::time::{Duration, Instant};
+
+    thread_local! {
+        static MOCK_NOW: Cell<Option<Instant>> = const { Cell::new(None) };
+    }
+
+    pub fn now() -> Instant {
+        MOCK_NOW.with(|mock| mock.get().unwrap_or_else(Instant::now))
+    }
+
+    pub fn advance_by(duration: Duration) {
+        MOCK_NOW.with(|mock| {
+            let current = mock.get().unwrap_or_else(Instant::now);
+            mock.set(Some(current + duration));
+        });
+    }
+
+    #[allow(dead_code)]
+    pub fn clear_mock() {
+        MOCK_NOW.with(|mock| mock.set(None));
+    }
+}
+
+#[cfg(not(test))]
+mod clock {
+    use std::time::Instant;
+
+    pub fn now() -> Instant {
+        Instant::now()
+    }
+}
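+
+// The `clock` seam above is what lets the TTL tests below run without real
+// sleeps: all cache code calls `clock::now()`, and under #[cfg(test)] the
+// thread-local override makes the passage of time explicit. The pattern the
+// tests below rely on looks like this:
+//
+//     let schema = table.schema().await?;          // populates the cache
+//     clock::advance_by(Duration::from_secs(31));  // jump past the 30s TTL
+//     let schema = table.schema().await?;          // forces a re-fetch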
+
 #[cfg(test)]
 mod tests {
     use std::sync::atomic::{AtomicUsize, Ordering};
     use std::sync::Arc;
+    use std::time::Duration;
     use std::{collections::HashMap, pin::Pin};
 
     use super::*;
@@ -3569,6 +3958,745 @@
         assert_eq!(call_count.load(Ordering::SeqCst), 1); // Still 1, no new call
     }
 
+    /// Test that schema is fetched once and cached for subsequent calls
+    #[tokio::test]
+    async fn test_schema_caching() {
+        let call_count = Arc::new(AtomicUsize::new(0));
+        let call_count_clone = call_count.clone();
+
+        let table = Table::new_with_handler("my_table", move |request| {
+            assert_eq!(request.url().path(), "/v1/table/my_table/describe/");
+            call_count_clone.fetch_add(1, Ordering::SeqCst);
+
+            http::Response::builder()
+                .status(200)
+                .body(
+                    r#"{"version": 1, "schema": {"fields": [
+                        {"name": "a", "type": { "type": "int32" }, "nullable": false}
+                    ]}}"#,
+                )
+                .unwrap()
+        });
+
+        // First call should fetch from server
+        let schema1 = table.schema().await.unwrap();
+        assert_eq!(schema1.fields().len(), 1);
+        assert_eq!(schema1.field(0).name(), "a");
+        assert_eq!(call_count.load(Ordering::SeqCst), 1);
+
+        // Second call should use cached value
+        let schema2 = table.schema().await.unwrap();
+        assert_eq!(Arc::as_ptr(&schema2), Arc::as_ptr(&schema1));
+        assert_eq!(call_count.load(Ordering::SeqCst), 1); // Still 1, no new call
+
+        // Third call should still use cached value
+        let schema3 = table.schema().await.unwrap();
+        assert_eq!(Arc::as_ptr(&schema3), Arc::as_ptr(&schema1));
+        assert_eq!(call_count.load(Ordering::SeqCst), 1); // Still 1, no new call
+    }
+
+    /// Test that the schema cache expires after the 30-second TTL
+    #[tokio::test]
+    async fn test_schema_cache_invalidation_after_ttl() {
+        let call_count = Arc::new(AtomicUsize::new(0));
+        let call_count_clone = call_count.clone();
+
+        let table = Table::new_with_handler("my_table", move |request| {
+            assert_eq!(request.url().path(), "/v1/table/my_table/describe/");
+            call_count_clone.fetch_add(1, Ordering::SeqCst);
+
+            http::Response::builder()
+                .status(200)
+                .body(
+                    r#"{"version": 1, "schema": {"fields": [
+                        {"name": "a", "type": { "type": "int32" }, "nullable": false}
+                    ]}}"#,
+                )
+                .unwrap()
+        });
+
+        // First call should fetch from server
+        let schema1 = table.schema().await.unwrap();
+        assert_eq!(call_count.load(Ordering::SeqCst), 1);
+
+        // Second call should use cached value (within TTL)
+        let schema2 = table.schema().await.unwrap();
+        assert_eq!(Arc::as_ptr(&schema2), Arc::as_ptr(&schema1));
+        assert_eq!(call_count.load(Ordering::SeqCst), 1);
+
+        // Advance mock time past TTL (no real wait)
+        clock::advance_by(Duration::from_secs(31));
+
+        // Third call should re-fetch from server (TTL expired)
+        let schema3 = table.schema().await.unwrap();
+        assert_ne!(Arc::as_ptr(&schema3), Arc::as_ptr(&schema1));
+        assert_eq!(call_count.load(Ordering::SeqCst), 2);
+    }
+
+    /// Test that schema cache is invalidated after schema-changing operations
+    #[rstest]
+    #[case("overwrite")]
+    #[case("add_columns")]
+    #[case("drop_columns")]
+    #[case("alter_columns")]
+    #[tokio::test]
+    async fn test_schema_cache_invalidation_after_operation(#[case] operation: &str) {
+        let call_count = Arc::new(AtomicUsize::new(0));
+        let call_count_clone = call_count.clone();
+
+        let table = Table::new_with_handler("my_table", move |request| {
+            let path = request.url().path();
+
+            if path == "/v1/table/my_table/describe/" {
+                call_count_clone.fetch_add(1, Ordering::SeqCst);
+                http::Response::builder()
+                    .status(200)
+                    .body(
+                        r#"{"version": 1, "schema": {"fields": [
+                            {"name": "a", "type": { "type": "int32" }, "nullable": false},
+                            {"name": "b", "type": { "type": "int32" }, "nullable": false}
+                        ]}}"#,
+                    )
+                    .unwrap()
+            } else if path == "/v1/table/my_table/insert/"
+                || path == "/v1/table/my_table/add_columns/"
+                || path == "/v1/table/my_table/drop_columns/"
+                || path == "/v1/table/my_table/alter_columns/"
+            {
+                http::Response::builder()
+                    .status(200)
+                    .body(r#"{"version": 2}"#)
+                    .unwrap()
+            } else {
+                http::Response::builder()
+                    .status(404)
+                    .body("not found")
+                    .unwrap()
+            }
+        });
+
+        // First schema call should fetch from server
+        let schema1 = table.schema().await.unwrap();
+        assert_eq!(call_count.load(Ordering::SeqCst), 1);
+
+        // Second schema call should use cached value
+        let schema2 = table.schema().await.unwrap();
+        assert_eq!(Arc::as_ptr(&schema2), Arc::as_ptr(&schema1));
+        assert_eq!(call_count.load(Ordering::SeqCst), 1);
+
+        // Perform the schema-changing operation
+        match operation {
+            "overwrite" => {
+                let data = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
+                let _ = table.add(data).mode(AddDataMode::Overwrite).execute().await;
+            }
+            "add_columns" => {
+                let _ = table
+                    .add_columns(
+                        NewColumnTransform::SqlExpressions(vec![("c".into(), "a + 1".into())]),
+                        None,
+                    )
+                    .await;
+            }
+            "drop_columns" => {
+                let _ = table.drop_columns(&["b"]).await;
+            }
+            "alter_columns" => {
+                let alterations = vec![ColumnAlteration::new("a".into()).rename("new_a".into())];
+                let _ = table.alter_columns(&alterations).await;
+            }
+            _ => panic!("Unknown operation: {}", operation),
+        }
+
+        // Schema call after operation should re-fetch from server (cache invalidated)
+        let schema3 = table.schema().await.unwrap();
+        assert_ne!(Arc::as_ptr(&schema3), Arc::as_ptr(&schema1));
+        assert_eq!(call_count.load(Ordering::SeqCst), 2);
+    }
+
+    /// Test that schema cache is invalidated when server returns certain error codes
+    #[rstest]
+    #[case(400, true)] // 400 Bad Request should invalidate cache
+    #[case(401, false)] // 401 Unauthorized should NOT invalidate cache
+    #[case(403, false)] // 403 Forbidden should NOT invalidate cache
+    #[case(404, true)] // 404 Not Found should invalidate (table might be recreated)
+    #[case(500, true)] // 500 Internal Server Error should invalidate cache
+    #[case(503, false)] // 503 Service Unavailable should NOT invalidate cache
+    #[tokio::test]
+    async fn test_schema_cache_invalidation_on_errors(
+        #[case] error_status: u16,
+        #[case] should_invalidate: bool,
+    ) {
+        let call_count = Arc::new(AtomicUsize::new(0));
+        let call_count_clone = call_count.clone();
+
+        let table = Table::new_with_handler("my_table", move |request| {
+            let path = request.url().path();
+            let current_count = call_count_clone.load(Ordering::SeqCst);
+
+            if path == "/v1/table/my_table/describe/" {
+                call_count_clone.fetch_add(1, Ordering::SeqCst);
+                http::Response::builder()
+                    .status(200)
+                    .body(
+                        r#"{"version": 1, "schema": {"fields": [
+                            {"name": "a", "type": { "type": "int32" }, "nullable": false}
+                        ]}}"#,
+                    )
+                    .unwrap()
+            } else if path == "/v1/table/my_table/count_rows/" {
+                // Return error on first count_rows call
+                if current_count == 1 {
+                    http::Response::builder()
+                        .status(error_status)
+                        .body("error")
+                        .unwrap()
+                } else {
+                    http::Response::builder().status(200).body("10").unwrap()
+                }
+            } else {
+                http::Response::builder()
+                    .status(404)
+                    .body("not found")
+                    .unwrap()
+            }
+        });
+
+        // First schema call should fetch from server
+        let schema1 = table.schema().await.unwrap();
+        assert_eq!(call_count.load(Ordering::SeqCst), 1);
+
+        // Second schema call should use cached value
+        let schema2 = table.schema().await.unwrap();
+        assert_eq!(Arc::as_ptr(&schema2), Arc::as_ptr(&schema1));
+        assert_eq!(call_count.load(Ordering::SeqCst), 1);
+
+        // Perform operation that returns error
+        let result = table.count_rows(None).await;
+        assert!(result.is_err());
+
+        // Schema call after error - check if cache was invalidated
+        let schema3 = table.schema().await.unwrap();
+        if should_invalidate {
+            assert_eq!(
+                call_count.load(Ordering::SeqCst),
+                2,
+                "Cache should be invalidated for {} error",
+                error_status
+            );
+            assert_ne!(Arc::as_ptr(&schema3), Arc::as_ptr(&schema1));
+        } else {
+            assert_eq!(
+                call_count.load(Ordering::SeqCst),
+                1,
+                "Cache should NOT be invalidated for {} error",
+                error_status
+            );
+            assert_eq!(Arc::as_ptr(&schema3), Arc::as_ptr(&schema1));
+        }
+    }
+
+    /// Test that schema cache is invalidated after checkout
+    #[tokio::test]
+    async fn test_schema_cache_invalidation_on_checkout() {
+        let call_count = Arc::new(AtomicUsize::new(0));
+        let call_count_clone = call_count.clone();
+
+        let table = Table::new_with_handler("my_table", move |request| {
+            let path = request.url().path();
+            call_count_clone.fetch_add(1, Ordering::SeqCst);
+            let count = call_count_clone.load(Ordering::SeqCst);
+
+            if path == "/v1/table/my_table/describe/" {
+                // Return different schemas for different calls
+                if count <= 2 {
+                    // First schema call and checkout validation
+                    http::Response::builder()
+                        .status(200)
+                        .body(
+                            r#"{"version": 1, "schema": {"fields": [
+                                {"name": "a", "type": { "type": "int32" }, "nullable": false}
+                            ]}}"#,
+                        )
+                        .unwrap()
+                } else {
+                    // After checkout
+                    http::Response::builder()
+                        .status(200)
+                        .body(
+                            r#"{"version": 2, "schema": {"fields": [
+                                {"name": "a", "type": { "type": "int32" }, "nullable": false},
+                                {"name": "b", "type": { "type": "int32" }, "nullable": false}
+                            ]}}"#,
+                        )
+                        .unwrap()
+                }
+            } else {
+                http::Response::builder()
+                    .status(404)
+                    .body("not found")
+                    .unwrap()
+            }
+        });
+
+        // First schema call
+        let schema1 = table.schema().await.unwrap();
+        assert_eq!(schema1.fields().len(), 1);
+
+        // Second schema call should use cached value (no new call)
+        let call_count_before = call_count.load(Ordering::SeqCst);
+        let schema2 = table.schema().await.unwrap();
+        assert_eq!(Arc::as_ptr(&schema2), Arc::as_ptr(&schema1));
+        assert_eq!(call_count.load(Ordering::SeqCst), call_count_before);
+
+        // Checkout to version 2 (makes a describe call to validate)
+        let _ = table.checkout(2).await;
+
+        // Schema call after checkout should re-fetch (cache was invalidated)
+        let schema3 = table.schema().await.unwrap();
+        assert_eq!(schema3.fields().len(), 2);
+        assert_ne!(Arc::as_ptr(&schema3), Arc::as_ptr(&schema1));
+    }
+
+    /// Test that schema cache is invalidated after checkout_latest
+    #[tokio::test]
+    async fn test_schema_cache_invalidation_on_checkout_latest() {
+        let call_count = Arc::new(AtomicUsize::new(0));
+        let call_count_clone = call_count.clone();
+
+        let table = Table::new_with_handler("my_table", move |request| {
+            let path = request.url().path();
+
+            if path == "/v1/table/my_table/describe/" {
+                call_count_clone.fetch_add(1, Ordering::SeqCst);
+                http::Response::builder()
+                    .status(200)
+                    .body(
+                        r#"{"version": 1, "schema": {"fields": [
+                            {"name": "a", "type": { "type": "int32" }, "nullable": false}
+                        ]}}"#,
+                    )
+                    .unwrap()
+            } else {
+                http::Response::builder()
+                    .status(404)
+                    .body("not found")
+                    .unwrap()
+            }
+        });
+
+        // First schema call
+        let schema1 = table.schema().await.unwrap();
+        assert_eq!(call_count.load(Ordering::SeqCst), 1);
+
+        // Second schema call should use cached value
+        let schema2 = table.schema().await.unwrap();
+        assert_eq!(Arc::as_ptr(&schema2), Arc::as_ptr(&schema1));
+        assert_eq!(call_count.load(Ordering::SeqCst), 1);
+
+        // Checkout latest
+        let _ = table.checkout_latest().await;
+
+        // Schema call after checkout_latest should re-fetch (cache was invalidated)
+        let schema3 = table.schema().await.unwrap();
+        assert_ne!(Arc::as_ptr(&schema3), Arc::as_ptr(&schema1));
+        assert_eq!(call_count.load(Ordering::SeqCst), 2);
+    }
+
+    /// Test that schema cache is invalidated after checkout_tag
+    #[tokio::test]
+    async fn test_schema_cache_invalidation_on_checkout_tag() {
+        let call_count = Arc::new(AtomicUsize::new(0));
+        let call_count_clone = call_count.clone();
+
+        let table = Table::new_with_handler("my_table", move |request| {
+            let path = request.url().path();
+
+            if path == "/v1/table/my_table/describe/" {
+                call_count_clone.fetch_add(1, Ordering::SeqCst);
+                http::Response::builder()
+                    .status(200)
+                    .body(
+                        r#"{"version": 1, "schema": {"fields": [
+                            {"name": "a", "type": { "type": "int32" }, "nullable": false}
+                        ]}}"#,
+                    )
+                    .unwrap()
+            } else if path == "/v1/table/my_table/tags/list/" {
+                http::Response::builder()
+                    .status(200)
+                    .body(r#"{"tags": {"v2": {"version": 2}}}"#)
+                    .unwrap()
+            } else if path == "/v1/table/my_table/tags/version/" {
+                http::Response::builder()
+                    .status(200)
+                    .body(r#"{"version": 2}"#)
+                    .unwrap()
+            } else {
+                http::Response::builder()
+                    .status(404)
+                    .body("not found")
+                    .unwrap()
+            }
+        });
+
+        // First schema call
+        let schema1 = table.schema().await.unwrap();
+        assert_eq!(call_count.load(Ordering::SeqCst), 1);
+
+        // Second schema call should use cached value
+        let schema2 = table.schema().await.unwrap();
+        assert_eq!(Arc::as_ptr(&schema2), Arc::as_ptr(&schema1));
+        assert_eq!(call_count.load(Ordering::SeqCst), 1);
+
+        // Checkout tag
+        table
+            .checkout_tag("v2")
+            .await
+            .expect("checkout_tag should succeed");
+
+        // Schema call after checkout_tag should re-fetch (cache was invalidated)
+        let schema3 = table.schema().await.unwrap();
+        assert_eq!(
+            call_count.load(Ordering::SeqCst),
+            2,
+            "Cache should have been invalidated and re-fetched"
+        );
+        assert_ne!(
+            Arc::as_ptr(&schema3),
+            Arc::as_ptr(&schema1),
+            "Should be different Arc instances"
+        );
+    }
+
+    /// Test that restore invalidates cache (via checkout_latest)
+    #[tokio::test]
+    async fn test_schema_cache_invalidation_on_restore() {
+        let call_count = Arc::new(AtomicUsize::new(0));
+        let call_count_clone = call_count.clone();
+
+        let table = Table::new_with_handler("my_table", move |request| {
+            let path = request.url().path();
+
+            if path == "/v1/table/my_table/describe/" {
+                call_count_clone.fetch_add(1, Ordering::SeqCst);
+                http::Response::builder()
+                    .status(200)
+                    .body(
+                        r#"{"version": 1, "schema": {"fields": [
+                            {"name": "a", "type": { "type": "int32" }, "nullable": false}
+                        ]}}"#,
+                    )
+                    .unwrap()
+            } else if path == "/v1/table/my_table/restore/" {
== "/v1/table/my_table/restore/" { + http::Response::builder() + .status(200) + .body(r#"{"version": 1}"#) + .unwrap() + } else { + http::Response::builder() + .status(404) + .body("not found") + .unwrap() + } + }); + + // First schema call + let schema1 = table.schema().await.unwrap(); + assert_eq!(call_count.load(Ordering::SeqCst), 1); + + // Second schema call uses cache + let schema2 = table.schema().await.unwrap(); + assert_eq!(Arc::as_ptr(&schema2), Arc::as_ptr(&schema1)); + assert_eq!(call_count.load(Ordering::SeqCst), 1); + + // Restore operation + let _ = table.restore().await; + + // Schema call after restore should re-fetch (cache invalidated) + let schema3 = table.schema().await.unwrap(); + assert_ne!(Arc::as_ptr(&schema3), Arc::as_ptr(&schema1)); + assert_eq!(call_count.load(Ordering::SeqCst), 2); + } + + /// Test that centralized error handling invalidates cache on query errors + #[tokio::test] + async fn test_centralized_error_invalidation_on_query() { + let call_count = Arc::new(AtomicUsize::new(0)); + let call_count_clone = call_count.clone(); + + let table = Table::new_with_handler("my_table", move |request| { + let path = request.url().path(); + let current_count = call_count_clone.load(Ordering::SeqCst); + + if path == "/v1/table/my_table/describe/" { + call_count_clone.fetch_add(1, Ordering::SeqCst); + http::Response::builder() + .status(200) + .body( + r#"{"version": 1, "schema": {"fields": [ + {"name": "a", "type": { "type": "int32" }, "nullable": false} + ]}}"#, + ) + .unwrap() + } else if path == "/v1/table/my_table/query/" { + // Return 400 error on first query (could be schema mismatch) + if current_count == 1 { + http::Response::builder() + .status(400) + .body("Bad request") + .unwrap() + } else { + // Return empty result for successful query + http::Response::builder() + .status(200) + .header("content-type", "application/vnd.apache.arrow.stream") + .body("") + .unwrap() + } + } else { + http::Response::builder() + .status(404) + .body("not found") + .unwrap() + } + }); + + // First schema call + let schema1 = table.schema().await.unwrap(); + assert_eq!(call_count.load(Ordering::SeqCst), 1); + + // Second schema call uses cache + let schema2 = table.schema().await.unwrap(); + assert_eq!(Arc::as_ptr(&schema2), Arc::as_ptr(&schema1)); + assert_eq!(call_count.load(Ordering::SeqCst), 1); + + // Query that returns 400 error + let result = table.query().execute().await; + assert!(result.is_err()); + + // Schema call after error should re-fetch (cache invalidated by centralized handler) + let schema3 = table.schema().await.unwrap(); + assert_ne!(Arc::as_ptr(&schema3), Arc::as_ptr(&schema1)); + assert_eq!(call_count.load(Ordering::SeqCst), 2); + } + + /// Test that concurrent schema() calls with an empty cache only trigger one fetch. 
+    #[tokio::test]
+    async fn test_concurrent_schema_calls_single_fetch() {
+        let call_count = Arc::new(AtomicUsize::new(0));
+        let call_count_clone = call_count.clone();
+
+        let table = Arc::new(Table::new_with_handler("my_table", move |request| {
+            let path = request.url().path();
+            if path == "/v1/table/my_table/describe/" {
+                call_count_clone.fetch_add(1, Ordering::SeqCst);
+                http::Response::builder()
+                    .status(200)
+                    .body(
+                        r#"{"version": 1, "schema": {"fields": [
+                            {"name": "a", "type": { "type": "int32" }, "nullable": false}
+                        ]}}"#,
+                    )
+                    .unwrap()
+            } else {
+                panic!("Unexpected request: {}", path);
+            }
+        }));
+
+        let mut handles = Vec::new();
+        for _ in 0..10 {
+            let table = table.clone();
+            handles.push(tokio::spawn(async move { table.schema().await.unwrap() }));
+        }
+
+        let schemas: Vec<SchemaRef> = futures::future::try_join_all(handles).await.unwrap();
+
+        // All callers should get the same Arc
+        for schema in &schemas {
+            assert_eq!(Arc::as_ptr(schema), Arc::as_ptr(&schemas[0]));
+        }
+        // Only one describe call should have been made
+        assert_eq!(call_count.load(Ordering::SeqCst), 1);
+    }
+
+    /// Test that a background refresh is triggered in the refresh window and
+    /// returns the cached value immediately.
+    #[tokio::test]
+    async fn test_background_refresh_triggers_in_window() {
+        let call_count = Arc::new(AtomicUsize::new(0));
+        let call_count_clone = call_count.clone();
+
+        let table = Table::new_with_handler("my_table", move |request| {
+            let path = request.url().path();
+            if path == "/v1/table/my_table/describe/" {
+                let count = call_count_clone.fetch_add(1, Ordering::SeqCst);
+                if count == 0 {
+                    http::Response::builder()
+                        .status(200)
+                        .body(
+                            r#"{"version": 1, "schema": {"fields": [
+                                {"name": "a", "type": { "type": "int32" }, "nullable": false}
+                            ]}}"#,
+                        )
+                        .unwrap()
+                } else {
+                    http::Response::builder()
+                        .status(200)
+                        .body(
+                            r#"{"version": 2, "schema": {"fields": [
+                                {"name": "a", "type": { "type": "int32" }, "nullable": false},
+                                {"name": "b", "type": { "type": "string" }, "nullable": true}
+                            ]}}"#,
+                        )
+                        .unwrap()
+                }
+            } else {
+                panic!("Unexpected request: {}", path);
+            }
+        });
+
+        // Populate cache and trigger peek transition to Current state
+        let schema1 = table.schema().await.unwrap();
+        assert_eq!(schema1.fields().len(), 1);
+        assert_eq!(call_count.load(Ordering::SeqCst), 1);
+        // Second call transitions cache from Refreshing to Current via peek()
+        let schema2 = table.schema().await.unwrap();
+        assert_eq!(Arc::as_ptr(&schema2), Arc::as_ptr(&schema1));
+
+        // Advance into refresh window (TTL=30s, window=5s, so 26s is in window)
+        clock::advance_by(Duration::from_secs(26));
+
+        // This call enters the refresh window: returns cached value and creates
+        // a background shared future (Refreshing state with previous).
+        let schema3 = table.schema().await.unwrap();
+        assert_eq!(Arc::as_ptr(&schema3), Arc::as_ptr(&schema1));
+        // Only the initial fetch so far
+        assert_eq!(call_count.load(Ordering::SeqCst), 1);
+
+        // Advance past TTL so the previous value expires. This forces the next
+        // schema() to Wait on the in-flight shared future, driving it to completion.
+        clock::advance_by(Duration::from_secs(30));
+
+        let schema4 = table.schema().await.unwrap();
+        assert_eq!(call_count.load(Ordering::SeqCst), 2);
+        assert_eq!(schema4.fields().len(), 2);
+        assert_ne!(Arc::as_ptr(&schema4), Arc::as_ptr(&schema1));
+    }
+
+    /// Test that multiple calls during the refresh window don't trigger
+    /// duplicate background refreshes.
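+    ///
+    /// Expected timeline under the mocked clock (TTL = 30s, window = 5s):
+    ///
+    /// ```text
+    /// t = 0s   initial fetch: describe call #1, state becomes Current
+    /// t = 26s  five schema() calls; the first enters the refresh window and
+    ///          starts a single background fetch, the rest reuse the cached Arc
+    /// t = 56s  the cached value has expired, so awaiting the shared future
+    ///          surfaces the one background refresh: describe call #2, no more
+    /// ```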
+    #[tokio::test]
+    async fn test_no_duplicate_background_refreshes() {
+        let call_count = Arc::new(AtomicUsize::new(0));
+        let call_count_clone = call_count.clone();
+
+        let table = Table::new_with_handler("my_table", move |request| {
+            let path = request.url().path();
+            if path == "/v1/table/my_table/describe/" {
+                call_count_clone.fetch_add(1, Ordering::SeqCst);
+                http::Response::builder()
+                    .status(200)
+                    .body(
+                        r#"{"version": 1, "schema": {"fields": [
+                            {"name": "a", "type": { "type": "int32" }, "nullable": false}
+                        ]}}"#,
+                    )
+                    .unwrap()
+            } else {
+                panic!("Unexpected request: {}", path);
+            }
+        });
+
+        // Populate cache and transition to Current state
+        let schema1 = table.schema().await.unwrap();
+        assert_eq!(call_count.load(Ordering::SeqCst), 1);
+        let _ = table.schema().await.unwrap(); // peek transition
+
+        // Advance into refresh window
+        clock::advance_by(Duration::from_secs(26));
+
+        // Multiple rapid calls should all return cached. The first one enters
+        // the refresh window and starts a background fetch (Refreshing state).
+        // Subsequent calls see Refreshing with a valid previous and return it.
+        for _ in 0..5 {
+            let schema = table.schema().await.unwrap();
+            assert_eq!(Arc::as_ptr(&schema), Arc::as_ptr(&schema1));
+        }
+
+        // Advance past TTL and drive the shared future to completion
+        clock::advance_by(Duration::from_secs(30));
+        let _ = table.schema().await.unwrap();
+
+        // Only one additional describe call (the background refresh),
+        // not five separate ones
+        assert_eq!(call_count.load(Ordering::SeqCst), 2);
+    }
+
+    /// Test that if a background refresh fails, the previously cached value
+    /// is preserved and still returned.
+    #[tokio::test]
+    async fn test_background_refresh_error_preserves_cache() {
+        let call_count = Arc::new(AtomicUsize::new(0));
+        let call_count_clone = call_count.clone();
+
+        let table = Table::new_with_handler("my_table", move |request| {
+            let path = request.url().path();
+            if path == "/v1/table/my_table/describe/" {
+                let count = call_count_clone.fetch_add(1, Ordering::SeqCst);
+                if count == 0 {
+                    // First call succeeds
+                    http::Response::builder()
+                        .status(200)
+                        .body(
+                            r#"{"version": 1, "schema": {"fields": [
+                                {"name": "a", "type": { "type": "int32" }, "nullable": false}
+                            ]}}"#,
+                        )
+                        .unwrap()
+                } else {
+                    // Subsequent calls fail (422 is not retried)
+                    http::Response::builder()
+                        .status(422)
+                        .body("Unprocessable Entity")
+                        .unwrap()
+                }
+            } else {
+                panic!("Unexpected request: {}", path);
+            }
+        });
+
+        // Populate cache and transition to Current state
+        let schema1 = table.schema().await.unwrap();
+        assert_eq!(schema1.fields().len(), 1);
+        assert_eq!(call_count.load(Ordering::SeqCst), 1);
+        let _ = table.schema().await.unwrap(); // peek transition
+
+        // Advance into refresh window
+        clock::advance_by(Duration::from_secs(26));
+
+        // Trigger background refresh (returns cached value). The background
+        // fetch will fail but the previous value should be preserved.
+        let schema2 = table.schema().await.unwrap();
+        assert_eq!(Arc::as_ptr(&schema2), Arc::as_ptr(&schema1));
+
+        // Still in the refresh window: the previous value is valid,
+        // so calling schema() should still return it.
+        let schema3 = table.schema().await.unwrap();
+        assert_eq!(Arc::as_ptr(&schema3), Arc::as_ptr(&schema1));
+
+        // Advance past TTL. The shared future will be driven and fail.
+        // The peek() error path should revert to the previous cached value.
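+        // Exactly one more describe call is expected from here: the test's
+        // current-thread runtime has not polled the spawned updater yet, so
+        // the waiter below is what drives the shared future to its 422 error.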
+        clock::advance_by(Duration::from_secs(30));
+
+        // After the error, the previous is restored but its timestamp is old,
+        // so the next call triggers a new fetch which also fails.
+        let result = table.schema().await;
+        assert_eq!(call_count.load(Ordering::SeqCst), 2);
+        // The error from the failed fetch should be propagated
+        assert!(result.is_err());
+    }
+
     #[tokio::test]
     async fn test_add_retries_rescannable_data() {
         let call_count = Arc::new(AtomicUsize::new(0));