feat: support for checkout and checkout_latest in remote sdks (#1863)

This commit is contained in:
Bert
2024-11-21 11:28:46 -05:00
committed by GitHub
parent 285071e5c8
commit 7cecb71df0
4 changed files with 371 additions and 18 deletions

View File

@@ -86,6 +86,12 @@ class RemoteTable(Table):
"""to_pandas() is not yet supported on LanceDB cloud."""
return NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.")
def checkout(self, version):
return self._loop.run_until_complete(self._table.checkout(version))
def checkout_latest(self):
return self._loop.run_until_complete(self._table.checkout_latest())
def list_indices(self):
"""List all the indices on the table"""
return self._loop.run_until_complete(self._table.list_indices())

View File

@@ -1012,6 +1012,35 @@ class Table(ABC):
The names of the columns to drop.
"""
@abstractmethod
def checkout(self):
"""
Checks out a specific version of the Table
Any read operation on the table will now access the data at the checked out
version. As a consequence, calling this method will disable any read consistency
interval that was previously set.
This is a read-only operation that turns the table into a sort of "view"
or "detached head". Other table instances will not be affected. To make the
change permanent you can use the `[Self::restore]` method.
Any operation that modifies the table will fail while the table is in a checked
out state.
To return the table to a normal state use `[Self::checkout_latest]`
"""
@abstractmethod
def checkout_latest(self):
"""
Ensures the table is pointing at the latest version
This can be used to manually update a table when the read_consistency_interval
is None
It can also be used to undo a `[Self::checkout]` operation
"""
@cached_property
def _dataset_uri(self) -> str:
return _table_uri(self._conn.uri, self.name)

View File

@@ -103,6 +103,47 @@ async def test_async_remote_db():
assert table_names == []
@pytest.mark.asyncio
async def test_async_checkout():
def handler(request):
if request.path == "/v1/table/test/describe/":
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
response = json.dumps({"version": 42, "schema": {"fields": []}})
request.wfile.write(response.encode())
return
content_len = int(request.headers.get("Content-Length"))
body = request.rfile.read(content_len)
body = json.loads(body)
print("body is", body)
count = 0
if body["version"] == 1:
count = 100
elif body["version"] == 2:
count = 200
elif body["version"] is None:
count = 300
request.send_response(200)
request.send_header("Content-Type", "application/json")
request.end_headers()
request.wfile.write(json.dumps(count).encode())
async with mock_lancedb_connection_async(handler) as db:
table = await db.open_table("test")
assert await table.count_rows() == 300
await table.checkout(1)
assert await table.count_rows() == 100
await table.checkout(2)
assert await table.count_rows() == 200
await table.checkout_latest()
assert await table.count_rows() == 300
@pytest.mark.asyncio
async def test_http_error():
request_id_holder = {"request_id": None}
@@ -188,6 +229,7 @@ def test_query_sync_minimal():
"ef": None,
"vector": [1.0, 2.0, 3.0],
"nprobes": 20,
"version": None,
}
return pa.table({"id": [1, 2, 3]})
@@ -205,6 +247,7 @@ def test_query_sync_empty_query():
"filter": "true",
"vector": [],
"columns": ["id"],
"version": None,
}
return pa.table({"id": [1, 2, 3]})
@@ -230,6 +273,7 @@ def test_query_sync_maximal():
"vector_column": "vector2",
"fast_search": True,
"with_row_id": True,
"version": None,
}
return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
@@ -268,6 +312,7 @@ def test_query_sync_fts():
},
"k": 10,
"vector": [],
"version": None,
}
return pa.table({"id": [1, 2, 3]})
@@ -284,6 +329,7 @@ def test_query_sync_fts():
"k": 42,
"vector": [],
"with_row_id": True,
"version": None,
}
return pa.table({"id": [1, 2, 3]})
@@ -309,6 +355,7 @@ def test_query_sync_hybrid():
"k": 42,
"vector": [],
"with_row_id": True,
"version": None,
}
return pa.table({"_rowid": [1, 2, 3], "_score": [0.1, 0.2, 0.3]})
else:
@@ -322,6 +369,7 @@ def test_query_sync_hybrid():
"nprobes": 20,
"ef": None,
"with_row_id": True,
"version": None,
}
return pa.table({"_rowid": [1, 2, 3], "_distance": [0.1, 0.2, 0.3]})

View File

@@ -22,6 +22,7 @@ use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::dataset::{ColumnAlteration, NewColumnTransform};
use lance_datafusion::exec::OneShotExec;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use crate::{
connection::NoData,
@@ -43,17 +44,32 @@ pub struct RemoteTable<S: HttpSend = Sender> {
#[allow(dead_code)]
client: RestfulLanceDbClient<S>,
name: String,
version: RwLock<Option<u64>>,
}
impl<S: HttpSend> RemoteTable<S> {
pub fn new(client: RestfulLanceDbClient<S>, name: String) -> Self {
Self { client, name }
Self {
client,
name,
version: RwLock::new(None),
}
}
async fn describe(&self) -> Result<TableDescription> {
let request = self
let version = self.current_version().await;
self.describe_version(version).await
}
async fn describe_version(&self, version: Option<u64>) -> Result<TableDescription> {
let mut request = self
.client
.post(&format!("/v1/table/{}/describe/", self.name));
let body = serde_json::json!({ "version": version });
request = request.json(&body);
let (request_id, response) = self.client.send(request, true).await?;
let response = self.check_table_response(&request_id, response).await?;
@@ -251,6 +267,24 @@ impl<S: HttpSend> RemoteTable<S> {
}
}
}
async fn check_mutable(&self) -> Result<()> {
let read_guard = self.version.read().await;
match *read_guard {
None => Ok(()),
Some(version) => Err(Error::NotSupported {
message: format!(
"Cannot mutate table reference fixed at version {}. Call checkout_latest() to get a mutable table reference.",
version
)
})
}
}
async fn current_version(&self) -> Option<u64> {
let read_guard = self.version.read().await;
*read_guard
}
}
#[derive(Deserialize)]
@@ -278,7 +312,11 @@ mod test_utils {
T: Into<reqwest::Body>,
{
let client = client_with_handler(handler);
Self { client, name }
Self {
client,
name,
version: RwLock::new(None),
}
}
}
}
@@ -297,17 +335,30 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
async fn version(&self) -> Result<u64> {
self.describe().await.map(|desc| desc.version)
}
async fn checkout(&self, _version: u64) -> Result<()> {
Err(Error::NotSupported {
message: "checkout is not supported on LanceDB cloud.".into(),
})
async fn checkout(&self, version: u64) -> Result<()> {
// check that the version exists
self.describe_version(Some(version))
.await
.map_err(|e| match e {
// try to map the error to a more user-friendly error telling them
// specifically that the version does not exist
Error::TableNotFound { name } => Error::TableNotFound {
name: format!("{} (version: {})", name, version),
},
e => e,
})?;
let mut write_guard = self.version.write().await;
*write_guard = Some(version);
Ok(())
}
async fn checkout_latest(&self) -> Result<()> {
Err(Error::NotSupported {
message: "checkout is not supported on LanceDB cloud.".into(),
})
let mut write_guard = self.version.write().await;
*write_guard = None;
Ok(())
}
async fn restore(&self) -> Result<()> {
self.check_mutable().await?;
Err(Error::NotSupported {
message: "restore is not supported on LanceDB cloud.".into(),
})
@@ -321,10 +372,13 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
.client
.post(&format!("/v1/table/{}/count_rows/", self.name));
let version = self.current_version().await;
if let Some(filter) = filter {
request = request.json(&serde_json::json!({ "predicate": filter }));
request = request.json(&serde_json::json!({ "predicate": filter, "version": version }));
} else {
request = request.json(&serde_json::json!({}));
let body = serde_json::json!({ "version": version });
request = request.json(&body);
}
let (request_id, response) = self.client.send(request, true).await?;
@@ -344,6 +398,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
add: AddDataBuilder<NoData>,
data: Box<dyn RecordBatchReader + Send>,
) -> Result<()> {
self.check_mutable().await?;
let body = Self::reader_as_body(data)?;
let mut request = self
.client
@@ -372,7 +427,8 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
) -> Result<Arc<dyn ExecutionPlan>> {
let request = self.client.post(&format!("/v1/table/{}/query/", self.name));
let body = serde_json::Value::Object(Default::default());
let version = self.current_version().await;
let body = serde_json::json!({ "version": version });
let bodies = Self::apply_vector_query_params(body, query)?;
let mut futures = Vec::with_capacity(bodies.len());
@@ -407,7 +463,8 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
.post(&format!("/v1/table/{}/query/", self.name))
.header(CONTENT_TYPE, JSON_CONTENT_TYPE);
let mut body = serde_json::Value::Object(Default::default());
let version = self.current_version().await;
let mut body = serde_json::json!({ "version": version });
Self::apply_query_params(&mut body, query)?;
// Empty vector can be passed if no vector search is performed.
body["vector"] = serde_json::Value::Array(Vec::new());
@@ -421,6 +478,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
Ok(DatasetRecordBatchStream::new(stream))
}
async fn update(&self, update: UpdateBuilder) -> Result<u64> {
self.check_mutable().await?;
let request = self
.client
.post(&format!("/v1/table/{}/update/", self.name));
@@ -442,6 +500,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
Ok(0) // TODO: support returning number of modified rows once supported in SaaS.
}
async fn delete(&self, predicate: &str) -> Result<()> {
self.check_mutable().await?;
let body = serde_json::json!({ "predicate": predicate });
let request = self
.client
@@ -453,6 +512,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
}
async fn create_index(&self, mut index: IndexBuilder) -> Result<()> {
self.check_mutable().await?;
let request = self
.client
.post(&format!("/v1/table/{}/create_index/", self.name));
@@ -531,6 +591,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
params: MergeInsertBuilder,
new_data: Box<dyn RecordBatchReader + Send>,
) -> Result<()> {
self.check_mutable().await?;
let query = MergeInsertRequest::try_from(params)?;
let body = Self::reader_as_body(new_data)?;
let request = self
@@ -547,6 +608,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
Ok(())
}
async fn optimize(&self, _action: OptimizeAction) -> Result<OptimizeStats> {
self.check_mutable().await?;
Err(Error::NotSupported {
message: "optimize is not supported on LanceDB cloud.".into(),
})
@@ -556,16 +618,19 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
_transforms: NewColumnTransform,
_read_columns: Option<Vec<String>>,
) -> Result<()> {
self.check_mutable().await?;
Err(Error::NotSupported {
message: "add_columns is not yet supported.".into(),
})
}
async fn alter_columns(&self, _alterations: &[ColumnAlteration]) -> Result<()> {
self.check_mutable().await?;
Err(Error::NotSupported {
message: "alter_columns is not yet supported.".into(),
})
}
async fn drop_columns(&self, _columns: &[&str]) -> Result<()> {
self.check_mutable().await?;
Err(Error::NotSupported {
message: "drop_columns is not yet supported.".into(),
})
@@ -573,9 +638,13 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
async fn list_indices(&self) -> Result<Vec<IndexConfig>> {
// Make request to list the indices
let request = self
let mut request = self
.client
.post(&format!("/v1/table/{}/index/list/", self.name));
let version = self.current_version().await;
let body = serde_json::json!({ "version": version });
request = request.json(&body);
let (request_id, response) = self.client.send(request, true).await?;
let response = self.check_table_response(&request_id, response).await?;
@@ -625,10 +694,14 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
}
async fn index_stats(&self, index_name: &str) -> Result<Option<IndexStatistics>> {
let request = self.client.post(&format!(
let mut request = self.client.post(&format!(
"/v1/table/{}/index/{}/stats/",
self.name, index_name
));
let version = self.current_version().await;
let body = serde_json::json!({ "version": version });
request = request.json(&body);
let (request_id, response) = self.client.send(request, true).await?;
if response.status() == StatusCode::NOT_FOUND {
@@ -806,7 +879,10 @@ mod tests {
request.headers().get("Content-Type").unwrap(),
JSON_CONTENT_TYPE
);
assert_eq!(request.body().unwrap().as_bytes().unwrap(), br#"{}"#);
assert_eq!(
request.body().unwrap().as_bytes().unwrap(),
br#"{"version":null}"#
);
http::Response::builder().status(200).body("42").unwrap()
});
@@ -823,7 +899,7 @@ mod tests {
);
assert_eq!(
request.body().unwrap().as_bytes().unwrap(),
br#"{"predicate":"a > 10"}"#
br#"{"predicate":"a > 10","version":null}"#
);
http::Response::builder().status(200).body("42").unwrap()
@@ -1124,6 +1200,7 @@ mod tests {
"nprobes": 20,
"ef": Option::<usize>::None,
"refine_factor": null,
"version": null,
});
// Pass vector separately to make sure it matches f32 precision.
expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into();
@@ -1170,6 +1247,7 @@ mod tests {
"nprobes": 12,
"ef": Option::<usize>::None,
"refine_factor": 2,
"version": null,
});
// Pass vector separately to make sure it matches f32 precision.
expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into();
@@ -1225,6 +1303,7 @@ mod tests {
"k": 10,
"vector": [],
"with_row_id": true,
"version": null
});
assert_eq!(body, expected_body);
@@ -1454,4 +1533,195 @@ mod tests {
let indices = table.index_stats("my_index").await.unwrap();
assert!(indices.is_none());
}
#[tokio::test]
async fn test_passes_version() {
let table = Table::new_with_handler("my_table", |request| {
let body = request.body().unwrap().as_bytes().unwrap();
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
let version = body
.as_object()
.unwrap()
.get("version")
.unwrap()
.as_u64()
.unwrap();
assert_eq!(version, 42);
let response_body = match request.url().path() {
"/v1/table/my_table/describe/" => {
serde_json::json!({
"version": 42,
"schema": { "fields": [] }
})
}
"/v1/table/my_table/index/list/" => {
serde_json::json!({
"indexes": []
})
}
"/v1/table/my_table/index/my_idx/stats/" => {
serde_json::json!({
"num_indexed_rows": 100000,
"num_unindexed_rows": 0,
"index_type": "IVF_PQ",
"distance_type": "l2"
})
}
"/v1/table/my_table/count_rows/" => {
serde_json::json!(1000)
}
"/v1/table/my_table/query/" => {
let expected_data = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let expected_data_ref = expected_data.clone();
let response_body = write_ipc_file(&expected_data_ref);
return http::Response::builder()
.status(200)
.header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
.body(response_body)
.unwrap();
}
path => panic!("Unexpected path: {}", path),
};
http::Response::builder()
.status(200)
.body(
serde_json::to_string(&response_body)
.unwrap()
.as_bytes()
.to_vec(),
)
.unwrap()
});
table.checkout(42).await.unwrap();
// ensure that version is passed to the /describe endpoint
let version = table.version().await.unwrap();
assert_eq!(version, 42);
// ensure it's passed to other read API calls
table.list_indices().await.unwrap();
table.index_stats("my_idx").await.unwrap();
table.count_rows(None).await.unwrap();
table
.query()
.nearest_to(vec![0.1, 0.2, 0.3])
.unwrap()
.execute()
.await
.unwrap();
}
#[tokio::test]
async fn test_fails_if_checkout_version_doesnt_exist() {
let table = Table::new_with_handler("my_table", |request| {
let body = request.body().unwrap().as_bytes().unwrap();
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
let version = body
.as_object()
.unwrap()
.get("version")
.unwrap()
.as_u64()
.unwrap();
if version != 42 {
return http::Response::builder()
.status(404)
.body(format!("Table my_table (version: {}) not found", version))
.unwrap();
}
let response_body = match request.url().path() {
"/v1/table/my_table/describe/" => {
serde_json::json!({
"version": 42,
"schema": { "fields": [] }
})
}
_ => panic!("Unexpected path"),
};
http::Response::builder()
.status(200)
.body(serde_json::to_string(&response_body).unwrap())
.unwrap()
});
let res = table.checkout(43).await;
println!("{:?}", res);
assert!(
matches!(res, Err(Error::TableNotFound { name }) if name == "my_table (version: 43)")
);
}
#[tokio::test]
async fn test_timetravel_immutable() {
let table = Table::new_with_handler::<String>("my_table", |request| {
let response_body = match request.url().path() {
"/v1/table/my_table/describe/" => {
serde_json::json!({
"version": 42,
"schema": { "fields": [] }
})
}
_ => panic!("Should not have made a request: {:?}", request),
};
http::Response::builder()
.status(200)
.body(serde_json::to_string(&response_body).unwrap())
.unwrap()
});
table.checkout(42).await.unwrap();
// Ensure that all mutable operations fail.
let res = table
.update()
.column("a", "a + 1")
.column("b", "b - 1")
.only_if("b > 10")
.execute()
.await;
assert!(matches!(res, Err(Error::NotSupported { .. })));
let batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let data = Box::new(RecordBatchIterator::new(
[Ok(batch.clone())],
batch.schema(),
));
let res = table.merge_insert(&["some_col"]).execute(data).await;
assert!(matches!(res, Err(Error::NotSupported { .. })));
let res = table.delete("id in (1, 2, 3)").await;
assert!(matches!(res, Err(Error::NotSupported { .. })));
let data = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let res = table
.add(RecordBatchIterator::new([Ok(data.clone())], data.schema()))
.execute()
.await;
assert!(matches!(res, Err(Error::NotSupported { .. })));
let res = table
.create_index(&["a"], Index::IvfPq(Default::default()))
.execute()
.await;
assert!(matches!(res, Err(Error::NotSupported { .. })));
}
}