From 7cecb71df0da73036ad1b398fad18beedc013fd9 Mon Sep 17 00:00:00 2001 From: Bert Date: Thu, 21 Nov 2024 11:28:46 -0500 Subject: [PATCH] feat: support for checkout and checkout_latest in remote sdks (#1863) --- python/python/lancedb/remote/table.py | 6 + python/python/lancedb/table.py | 29 +++ python/python/tests/test_remote_db.py | 48 ++++ rust/lancedb/src/remote/table.rs | 306 ++++++++++++++++++++++++-- 4 files changed, 371 insertions(+), 18 deletions(-) diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index c1106bb0..c897cb6b 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -86,6 +86,12 @@ class RemoteTable(Table): """to_pandas() is not yet supported on LanceDB cloud.""" return NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.") + def checkout(self, version): + return self._loop.run_until_complete(self._table.checkout(version)) + + def checkout_latest(self): + return self._loop.run_until_complete(self._table.checkout_latest()) + def list_indices(self): """List all the indices on the table""" return self._loop.run_until_complete(self._table.list_indices()) diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index eee14dd9..4c77beb2 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -1012,6 +1012,35 @@ class Table(ABC): The names of the columns to drop. """ + @abstractmethod + def checkout(self): + """ + Checks out a specific version of the Table + + Any read operation on the table will now access the data at the checked out + version. As a consequence, calling this method will disable any read consistency + interval that was previously set. + + This is a read-only operation that turns the table into a sort of "view" + or "detached head". Other table instances will not be affected. To make the + change permanent you can use the `[Self::restore]` method. + + Any operation that modifies the table will fail while the table is in a checked + out state. + + To return the table to a normal state use `[Self::checkout_latest]` + """ + + @abstractmethod + def checkout_latest(self): + """ + Ensures the table is pointing at the latest version + + This can be used to manually update a table when the read_consistency_interval + is None + It can also be used to undo a `[Self::checkout]` operation + """ + @cached_property def _dataset_uri(self) -> str: return _table_uri(self._conn.uri, self.name) diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index cd0691e8..fbf432b1 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -103,6 +103,47 @@ async def test_async_remote_db(): assert table_names == [] +@pytest.mark.asyncio +async def test_async_checkout(): + def handler(request): + if request.path == "/v1/table/test/describe/": + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + response = json.dumps({"version": 42, "schema": {"fields": []}}) + request.wfile.write(response.encode()) + return + + content_len = int(request.headers.get("Content-Length")) + body = request.rfile.read(content_len) + body = json.loads(body) + + print("body is", body) + + count = 0 + if body["version"] == 1: + count = 100 + elif body["version"] == 2: + count = 200 + elif body["version"] is None: + count = 300 + + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(json.dumps(count).encode()) + + async with mock_lancedb_connection_async(handler) as db: + table = await db.open_table("test") + assert await table.count_rows() == 300 + await table.checkout(1) + assert await table.count_rows() == 100 + await table.checkout(2) + assert await table.count_rows() == 200 + await table.checkout_latest() + assert await table.count_rows() == 300 + + @pytest.mark.asyncio async def test_http_error(): request_id_holder = {"request_id": None} @@ -188,6 +229,7 @@ def test_query_sync_minimal(): "ef": None, "vector": [1.0, 2.0, 3.0], "nprobes": 20, + "version": None, } return pa.table({"id": [1, 2, 3]}) @@ -205,6 +247,7 @@ def test_query_sync_empty_query(): "filter": "true", "vector": [], "columns": ["id"], + "version": None, } return pa.table({"id": [1, 2, 3]}) @@ -230,6 +273,7 @@ def test_query_sync_maximal(): "vector_column": "vector2", "fast_search": True, "with_row_id": True, + "version": None, } return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]}) @@ -268,6 +312,7 @@ def test_query_sync_fts(): }, "k": 10, "vector": [], + "version": None, } return pa.table({"id": [1, 2, 3]}) @@ -284,6 +329,7 @@ def test_query_sync_fts(): "k": 42, "vector": [], "with_row_id": True, + "version": None, } return pa.table({"id": [1, 2, 3]}) @@ -309,6 +355,7 @@ def test_query_sync_hybrid(): "k": 42, "vector": [], "with_row_id": True, + "version": None, } return pa.table({"_rowid": [1, 2, 3], "_score": [0.1, 0.2, 0.3]}) else: @@ -322,6 +369,7 @@ def test_query_sync_hybrid(): "nprobes": 20, "ef": None, "with_row_id": True, + "version": None, } return pa.table({"_rowid": [1, 2, 3], "_distance": [0.1, 0.2, 0.3]}) diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 30fb59e2..4388a78d 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -22,6 +22,7 @@ use lance::dataset::scanner::DatasetRecordBatchStream; use lance::dataset::{ColumnAlteration, NewColumnTransform}; use lance_datafusion::exec::OneShotExec; use serde::{Deserialize, Serialize}; +use tokio::sync::RwLock; use crate::{ connection::NoData, @@ -43,17 +44,32 @@ pub struct RemoteTable { #[allow(dead_code)] client: RestfulLanceDbClient, name: String, + + version: RwLock>, } impl RemoteTable { pub fn new(client: RestfulLanceDbClient, name: String) -> Self { - Self { client, name } + Self { + client, + name, + version: RwLock::new(None), + } } async fn describe(&self) -> Result { - let request = self + let version = self.current_version().await; + self.describe_version(version).await + } + + async fn describe_version(&self, version: Option) -> Result { + let mut request = self .client .post(&format!("/v1/table/{}/describe/", self.name)); + + let body = serde_json::json!({ "version": version }); + request = request.json(&body); + let (request_id, response) = self.client.send(request, true).await?; let response = self.check_table_response(&request_id, response).await?; @@ -251,6 +267,24 @@ impl RemoteTable { } } } + + async fn check_mutable(&self) -> Result<()> { + let read_guard = self.version.read().await; + match *read_guard { + None => Ok(()), + Some(version) => Err(Error::NotSupported { + message: format!( + "Cannot mutate table reference fixed at version {}. Call checkout_latest() to get a mutable table reference.", + version + ) + }) + } + } + + async fn current_version(&self) -> Option { + let read_guard = self.version.read().await; + *read_guard + } } #[derive(Deserialize)] @@ -278,7 +312,11 @@ mod test_utils { T: Into, { let client = client_with_handler(handler); - Self { client, name } + Self { + client, + name, + version: RwLock::new(None), + } } } } @@ -297,17 +335,30 @@ impl TableInternal for RemoteTable { async fn version(&self) -> Result { self.describe().await.map(|desc| desc.version) } - async fn checkout(&self, _version: u64) -> Result<()> { - Err(Error::NotSupported { - message: "checkout is not supported on LanceDB cloud.".into(), - }) + async fn checkout(&self, version: u64) -> Result<()> { + // check that the version exists + self.describe_version(Some(version)) + .await + .map_err(|e| match e { + // try to map the error to a more user-friendly error telling them + // specifically that the version does not exist + Error::TableNotFound { name } => Error::TableNotFound { + name: format!("{} (version: {})", name, version), + }, + e => e, + })?; + + let mut write_guard = self.version.write().await; + *write_guard = Some(version); + Ok(()) } async fn checkout_latest(&self) -> Result<()> { - Err(Error::NotSupported { - message: "checkout is not supported on LanceDB cloud.".into(), - }) + let mut write_guard = self.version.write().await; + *write_guard = None; + Ok(()) } async fn restore(&self) -> Result<()> { + self.check_mutable().await?; Err(Error::NotSupported { message: "restore is not supported on LanceDB cloud.".into(), }) @@ -321,10 +372,13 @@ impl TableInternal for RemoteTable { .client .post(&format!("/v1/table/{}/count_rows/", self.name)); + let version = self.current_version().await; + if let Some(filter) = filter { - request = request.json(&serde_json::json!({ "predicate": filter })); + request = request.json(&serde_json::json!({ "predicate": filter, "version": version })); } else { - request = request.json(&serde_json::json!({})); + let body = serde_json::json!({ "version": version }); + request = request.json(&body); } let (request_id, response) = self.client.send(request, true).await?; @@ -344,6 +398,7 @@ impl TableInternal for RemoteTable { add: AddDataBuilder, data: Box, ) -> Result<()> { + self.check_mutable().await?; let body = Self::reader_as_body(data)?; let mut request = self .client @@ -372,7 +427,8 @@ impl TableInternal for RemoteTable { ) -> Result> { let request = self.client.post(&format!("/v1/table/{}/query/", self.name)); - let body = serde_json::Value::Object(Default::default()); + let version = self.current_version().await; + let body = serde_json::json!({ "version": version }); let bodies = Self::apply_vector_query_params(body, query)?; let mut futures = Vec::with_capacity(bodies.len()); @@ -407,7 +463,8 @@ impl TableInternal for RemoteTable { .post(&format!("/v1/table/{}/query/", self.name)) .header(CONTENT_TYPE, JSON_CONTENT_TYPE); - let mut body = serde_json::Value::Object(Default::default()); + let version = self.current_version().await; + let mut body = serde_json::json!({ "version": version }); Self::apply_query_params(&mut body, query)?; // Empty vector can be passed if no vector search is performed. body["vector"] = serde_json::Value::Array(Vec::new()); @@ -421,6 +478,7 @@ impl TableInternal for RemoteTable { Ok(DatasetRecordBatchStream::new(stream)) } async fn update(&self, update: UpdateBuilder) -> Result { + self.check_mutable().await?; let request = self .client .post(&format!("/v1/table/{}/update/", self.name)); @@ -442,6 +500,7 @@ impl TableInternal for RemoteTable { Ok(0) // TODO: support returning number of modified rows once supported in SaaS. } async fn delete(&self, predicate: &str) -> Result<()> { + self.check_mutable().await?; let body = serde_json::json!({ "predicate": predicate }); let request = self .client @@ -453,6 +512,7 @@ impl TableInternal for RemoteTable { } async fn create_index(&self, mut index: IndexBuilder) -> Result<()> { + self.check_mutable().await?; let request = self .client .post(&format!("/v1/table/{}/create_index/", self.name)); @@ -531,6 +591,7 @@ impl TableInternal for RemoteTable { params: MergeInsertBuilder, new_data: Box, ) -> Result<()> { + self.check_mutable().await?; let query = MergeInsertRequest::try_from(params)?; let body = Self::reader_as_body(new_data)?; let request = self @@ -547,6 +608,7 @@ impl TableInternal for RemoteTable { Ok(()) } async fn optimize(&self, _action: OptimizeAction) -> Result { + self.check_mutable().await?; Err(Error::NotSupported { message: "optimize is not supported on LanceDB cloud.".into(), }) @@ -556,16 +618,19 @@ impl TableInternal for RemoteTable { _transforms: NewColumnTransform, _read_columns: Option>, ) -> Result<()> { + self.check_mutable().await?; Err(Error::NotSupported { message: "add_columns is not yet supported.".into(), }) } async fn alter_columns(&self, _alterations: &[ColumnAlteration]) -> Result<()> { + self.check_mutable().await?; Err(Error::NotSupported { message: "alter_columns is not yet supported.".into(), }) } async fn drop_columns(&self, _columns: &[&str]) -> Result<()> { + self.check_mutable().await?; Err(Error::NotSupported { message: "drop_columns is not yet supported.".into(), }) @@ -573,9 +638,13 @@ impl TableInternal for RemoteTable { async fn list_indices(&self) -> Result> { // Make request to list the indices - let request = self + let mut request = self .client .post(&format!("/v1/table/{}/index/list/", self.name)); + let version = self.current_version().await; + let body = serde_json::json!({ "version": version }); + request = request.json(&body); + let (request_id, response) = self.client.send(request, true).await?; let response = self.check_table_response(&request_id, response).await?; @@ -625,10 +694,14 @@ impl TableInternal for RemoteTable { } async fn index_stats(&self, index_name: &str) -> Result> { - let request = self.client.post(&format!( + let mut request = self.client.post(&format!( "/v1/table/{}/index/{}/stats/", self.name, index_name )); + let version = self.current_version().await; + let body = serde_json::json!({ "version": version }); + request = request.json(&body); + let (request_id, response) = self.client.send(request, true).await?; if response.status() == StatusCode::NOT_FOUND { @@ -806,7 +879,10 @@ mod tests { request.headers().get("Content-Type").unwrap(), JSON_CONTENT_TYPE ); - assert_eq!(request.body().unwrap().as_bytes().unwrap(), br#"{}"#); + assert_eq!( + request.body().unwrap().as_bytes().unwrap(), + br#"{"version":null}"# + ); http::Response::builder().status(200).body("42").unwrap() }); @@ -823,7 +899,7 @@ mod tests { ); assert_eq!( request.body().unwrap().as_bytes().unwrap(), - br#"{"predicate":"a > 10"}"# + br#"{"predicate":"a > 10","version":null}"# ); http::Response::builder().status(200).body("42").unwrap() @@ -1124,6 +1200,7 @@ mod tests { "nprobes": 20, "ef": Option::::None, "refine_factor": null, + "version": null, }); // Pass vector separately to make sure it matches f32 precision. expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into(); @@ -1170,6 +1247,7 @@ mod tests { "nprobes": 12, "ef": Option::::None, "refine_factor": 2, + "version": null, }); // Pass vector separately to make sure it matches f32 precision. expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into(); @@ -1225,6 +1303,7 @@ mod tests { "k": 10, "vector": [], "with_row_id": true, + "version": null }); assert_eq!(body, expected_body); @@ -1454,4 +1533,195 @@ mod tests { let indices = table.index_stats("my_index").await.unwrap(); assert!(indices.is_none()); } + + #[tokio::test] + async fn test_passes_version() { + let table = Table::new_with_handler("my_table", |request| { + let body = request.body().unwrap().as_bytes().unwrap(); + let body: serde_json::Value = serde_json::from_slice(body).unwrap(); + let version = body + .as_object() + .unwrap() + .get("version") + .unwrap() + .as_u64() + .unwrap(); + assert_eq!(version, 42); + + let response_body = match request.url().path() { + "/v1/table/my_table/describe/" => { + serde_json::json!({ + "version": 42, + "schema": { "fields": [] } + }) + } + "/v1/table/my_table/index/list/" => { + serde_json::json!({ + "indexes": [] + }) + } + "/v1/table/my_table/index/my_idx/stats/" => { + serde_json::json!({ + "num_indexed_rows": 100000, + "num_unindexed_rows": 0, + "index_type": "IVF_PQ", + "distance_type": "l2" + }) + } + "/v1/table/my_table/count_rows/" => { + serde_json::json!(1000) + } + "/v1/table/my_table/query/" => { + let expected_data = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let expected_data_ref = expected_data.clone(); + let response_body = write_ipc_file(&expected_data_ref); + return http::Response::builder() + .status(200) + .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE) + .body(response_body) + .unwrap(); + } + + path => panic!("Unexpected path: {}", path), + }; + + http::Response::builder() + .status(200) + .body( + serde_json::to_string(&response_body) + .unwrap() + .as_bytes() + .to_vec(), + ) + .unwrap() + }); + + table.checkout(42).await.unwrap(); + + // ensure that version is passed to the /describe endpoint + let version = table.version().await.unwrap(); + assert_eq!(version, 42); + + // ensure it's passed to other read API calls + table.list_indices().await.unwrap(); + table.index_stats("my_idx").await.unwrap(); + table.count_rows(None).await.unwrap(); + table + .query() + .nearest_to(vec![0.1, 0.2, 0.3]) + .unwrap() + .execute() + .await + .unwrap(); + } + + #[tokio::test] + async fn test_fails_if_checkout_version_doesnt_exist() { + let table = Table::new_with_handler("my_table", |request| { + let body = request.body().unwrap().as_bytes().unwrap(); + let body: serde_json::Value = serde_json::from_slice(body).unwrap(); + let version = body + .as_object() + .unwrap() + .get("version") + .unwrap() + .as_u64() + .unwrap(); + if version != 42 { + return http::Response::builder() + .status(404) + .body(format!("Table my_table (version: {}) not found", version)) + .unwrap(); + } + + let response_body = match request.url().path() { + "/v1/table/my_table/describe/" => { + serde_json::json!({ + "version": 42, + "schema": { "fields": [] } + }) + } + _ => panic!("Unexpected path"), + }; + + http::Response::builder() + .status(200) + .body(serde_json::to_string(&response_body).unwrap()) + .unwrap() + }); + + let res = table.checkout(43).await; + println!("{:?}", res); + assert!( + matches!(res, Err(Error::TableNotFound { name }) if name == "my_table (version: 43)") + ); + } + + #[tokio::test] + async fn test_timetravel_immutable() { + let table = Table::new_with_handler::("my_table", |request| { + let response_body = match request.url().path() { + "/v1/table/my_table/describe/" => { + serde_json::json!({ + "version": 42, + "schema": { "fields": [] } + }) + } + _ => panic!("Should not have made a request: {:?}", request), + }; + + http::Response::builder() + .status(200) + .body(serde_json::to_string(&response_body).unwrap()) + .unwrap() + }); + + table.checkout(42).await.unwrap(); + + // Ensure that all mutable operations fail. + let res = table + .update() + .column("a", "a + 1") + .column("b", "b - 1") + .only_if("b > 10") + .execute() + .await; + assert!(matches!(res, Err(Error::NotSupported { .. }))); + + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let data = Box::new(RecordBatchIterator::new( + [Ok(batch.clone())], + batch.schema(), + )); + let res = table.merge_insert(&["some_col"]).execute(data).await; + assert!(matches!(res, Err(Error::NotSupported { .. }))); + + let res = table.delete("id in (1, 2, 3)").await; + assert!(matches!(res, Err(Error::NotSupported { .. }))); + + let data = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let res = table + .add(RecordBatchIterator::new([Ok(data.clone())], data.schema())) + .execute() + .await; + assert!(matches!(res, Err(Error::NotSupported { .. }))); + + let res = table + .create_index(&["a"], Index::IvfPq(Default::default())) + .execute() + .await; + assert!(matches!(res, Err(Error::NotSupported { .. }))); + } }