diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 2ad42dd0..560fb05e 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -54,11 +54,15 @@ impl RemoteTable { } async fn describe(&self) -> Result { + let version = self.current_version().await; + self.describe_version(version).await + } + + async fn describe_version(&self, version: Option) -> Result { let mut request = self .client .post(&format!("/v1/table/{}/describe/", self.name)); - let version = self.current_version().await; let body = serde_json::json!({ "version": version }); request = request.json(&body); @@ -326,9 +330,16 @@ impl TableInternal for RemoteTable { self.describe().await.map(|desc| desc.version) } async fn checkout(&self, version: u64) -> Result<()> { - // TODO check that the version exists - // we can do this when the list_versions changes land - // https://github.com/lancedb/lancedb/pull/1850 + // check that the version exists + self.describe_version(Some(version)).await.map_err(|e| match e { + // try to map the error to a more user-friendly error telling them + // specifically that the version does not exist + Error::TableNotFound { name } => Error::TableNotFound { + name: format!("{} (version: {})", name, version), + }, + e => e, + })?; + let mut write_guard = self.version.write().await; *write_guard = Some(version); Ok(()) @@ -702,7 +713,6 @@ impl TableInternal for RemoteTable { Ok(Some(stats)) } async fn table_definition(&self) -> Result { - self.check_mutable().await?; Err(Error::NotSupported { message: "table_definition is not supported on LanceDB cloud.".into(), }) @@ -861,7 +871,7 @@ mod tests { request.headers().get("Content-Type").unwrap(), JSON_CONTENT_TYPE ); - assert_eq!(request.body().unwrap().as_bytes().unwrap(), br#"{}"#); + assert_eq!(request.body().unwrap().as_bytes().unwrap(), br#"{"version":null}"#); http::Response::builder().status(200).body("42").unwrap() }); @@ -878,7 +888,7 @@ mod tests { ); assert_eq!( request.body().unwrap().as_bytes().unwrap(), - br#"{"predicate":"a > 10"}"# + br#"{"predicate":"a > 10","version":null}"# ); http::Response::builder().status(200).body("42").unwrap() @@ -1179,6 +1189,7 @@ mod tests { "nprobes": 20, "ef": Option::::None, "refine_factor": null, + "version": null, }); // Pass vector separately to make sure it matches f32 precision. expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into(); @@ -1225,6 +1236,7 @@ mod tests { "nprobes": 12, "ef": Option::::None, "refine_factor": 2, + "version": null, }); // Pass vector separately to make sure it matches f32 precision. expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into(); @@ -1280,6 +1292,7 @@ mod tests { "k": 10, "vector": [], "with_row_id": true, + "version": null }); assert_eq!(body, expected_body); @@ -1509,4 +1522,176 @@ mod tests { let indices = table.index_stats("my_index").await.unwrap(); assert!(indices.is_none()); } + + #[tokio::test] + async fn test_passes_version() { + let table = Table::new_with_handler("my_table", |request| { + let body = request.body().unwrap().as_bytes().unwrap(); + let body: serde_json::Value = serde_json::from_slice(body).unwrap(); + let version = body.as_object().unwrap().get("version").unwrap().as_u64().unwrap(); + assert_eq!(version, 42); + + let response_body = match request.url().path() { + "/v1/table/my_table/describe/" => { + serde_json::json!({ + "version": 42, + "schema": { "fields": [] } + }) + }, + "/v1/table/my_table/index/list/" => { + serde_json::json!({ + "indexes": [] + }) + }, + "/v1/table/my_table/index/my_idx/stats/" => { + serde_json::json!({ + "num_indexed_rows": 100000, + "num_unindexed_rows": 0, + "index_type": "IVF_PQ", + "distance_type": "l2" + }) + }, + "/v1/table/my_table/count_rows/" => { + serde_json::json!(1000) + } + "/v1/table/my_table/query/" => { + let expected_data = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let expected_data_ref = expected_data.clone(); + let response_body = write_ipc_file(&expected_data_ref); + return http::Response::builder() + .status(200) + .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE) + .body(response_body) + .unwrap(); + }, + + path => panic!("Unexpected path: {}", path), + }; + + http::Response::builder() + .status(200) + .body(serde_json::to_string(&response_body).unwrap().as_bytes().to_vec()) + .unwrap() + }); + + table.checkout(42).await.unwrap(); + + // ensure that version is passed to the /describe endpoint + let version = table.version().await.unwrap(); + assert_eq!(version, 42); + + // ensure it's passed to other read API calls + table.list_indices().await.unwrap(); + table.index_stats("my_idx").await.unwrap(); + table.count_rows(None).await.unwrap(); + table + .query() + .nearest_to(vec![0.1, 0.2, 0.3]) + .unwrap() + .execute() + .await + .unwrap(); + } + + #[tokio::test] + async fn test_fails_if_checkout_version_doesnt_exist() { + let table = Table::new_with_handler("my_table", |request| { + let body = request.body().unwrap().as_bytes().unwrap(); + let body: serde_json::Value = serde_json::from_slice(body).unwrap(); + let version = body.as_object().unwrap().get("version").unwrap().as_u64().unwrap(); + if version != 42 { + return http::Response::builder() + .status(404) + .body(format!("Table my_table (version: {}) not found", version)) + .unwrap(); + } + + let response_body = match request.url().path() { + "/v1/table/my_table/describe/" => { + serde_json::json!({ + "version": 42, + "schema": { "fields": [] } + }) + }, + _ => panic!("Unexpected path"), + }; + + http::Response::builder() + .status(200) + .body(serde_json::to_string(&response_body).unwrap()) + .unwrap() + }); + + let res = table.checkout(43).await; + println!("{:?}", res); + assert!(matches!(res, Err(Error::TableNotFound { name }) if name == "my_table (version: 43)")); + } + + #[tokio::test] + async fn test_timetravel_immutable() { + let table = Table::new_with_handler::("my_table", |request| { + let response_body = match request.url().path() { + "/v1/table/my_table/describe/" => { + serde_json::json!({ + "version": 42, + "schema": { "fields": [] } + }) + }, + _ => panic!("Should not have made a request: {:?}", request), + }; + + http::Response::builder() + .status(200) + .body(serde_json::to_string(&response_body).unwrap()) + .unwrap() + }); + + table.checkout(42).await.unwrap(); + + // Ensure that all mutable operations fail. + let res = table + .update() + .column("a", "a + 1") + .column("b", "b - 1") + .only_if("b > 10") + .execute() + .await; + assert!(matches!(res, Err(Error::NotSupported { .. }))); + + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let data = Box::new(RecordBatchIterator::new( + [Ok(batch.clone())], + batch.schema(), + )); + let res = table + .merge_insert(&["some_col"]) + .execute(data) + .await; + assert!(matches!(res, Err(Error::NotSupported { .. }))); + + let res = table.delete("id in (1, 2, 3)").await; + assert!(matches!(res, Err(Error::NotSupported { .. }))); + + let data = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let res = table + .add(RecordBatchIterator::new([Ok(data.clone())], data.schema())) + .execute() + .await; + assert!(matches!(res, Err(Error::NotSupported { .. }))); + + let res = table.create_index(&["a"], Index::IvfPq(Default::default())).execute().await; + assert!(matches!(res, Err(Error::NotSupported { .. }))); + } }