mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-06 20:02:58 +00:00
feat(rust): remote endpoints for schema, version, count_rows (#1644)
A handful of additional endpoints.
This commit is contained in:
@@ -21,3 +21,5 @@ pub mod client;
|
|||||||
pub mod db;
|
pub mod db;
|
||||||
pub mod table;
|
pub mod table;
|
||||||
pub mod util;
|
pub mod util;
|
||||||
|
|
||||||
|
const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
|
||||||
|
|||||||
@@ -31,8 +31,7 @@ use crate::Table;
|
|||||||
use super::client::{HttpSend, RestfulLanceDbClient, Sender};
|
use super::client::{HttpSend, RestfulLanceDbClient, Sender};
|
||||||
use super::table::RemoteTable;
|
use super::table::RemoteTable;
|
||||||
use super::util::batches_to_ipc_bytes;
|
use super::util::batches_to_ipc_bytes;
|
||||||
|
use super::ARROW_STREAM_CONTENT_TYPE;
|
||||||
const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct ListTablesResponse {
|
struct ListTablesResponse {
|
||||||
|
|||||||
@@ -1,12 +1,16 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::table::dataset::DatasetReadGuard;
|
use crate::table::dataset::DatasetReadGuard;
|
||||||
|
use crate::Error;
|
||||||
use arrow_array::RecordBatchReader;
|
use arrow_array::RecordBatchReader;
|
||||||
use arrow_schema::SchemaRef;
|
use arrow_schema::SchemaRef;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use datafusion_physical_plan::ExecutionPlan;
|
use datafusion_physical_plan::ExecutionPlan;
|
||||||
|
use http::StatusCode;
|
||||||
|
use lance::arrow::json::JsonSchema;
|
||||||
use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
|
use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
|
||||||
use lance::dataset::{ColumnAlteration, NewColumnTransform};
|
use lance::dataset::{ColumnAlteration, NewColumnTransform};
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
connection::NoData,
|
connection::NoData,
|
||||||
@@ -32,6 +36,31 @@ impl<S: HttpSend> RemoteTable<S> {
|
|||||||
pub fn new(client: RestfulLanceDbClient<S>, name: String) -> Self {
|
pub fn new(client: RestfulLanceDbClient<S>, name: String) -> Self {
|
||||||
Self { client, name }
|
Self { client, name }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn describe(&self) -> Result<TableDescription> {
|
||||||
|
let request = self.client.post(&format!("/table/{}/describe/", self.name));
|
||||||
|
let response = self.client.send(request).await?;
|
||||||
|
|
||||||
|
if response.status() == StatusCode::NOT_FOUND {
|
||||||
|
return Err(Error::TableNotFound {
|
||||||
|
name: self.name.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = self.client.check_response(response).await?;
|
||||||
|
|
||||||
|
let body = response.text().await?;
|
||||||
|
|
||||||
|
serde_json::from_str(&body).map_err(|e| Error::Http {
|
||||||
|
message: format!("Failed to parse table description: {}", e),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct TableDescription {
|
||||||
|
version: u64,
|
||||||
|
schema: JsonSchema,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: HttpSend> std::fmt::Display for RemoteTable<S> {
|
impl<S: HttpSend> std::fmt::Display for RemoteTable<S> {
|
||||||
@@ -40,6 +69,24 @@ impl<S: HttpSend> std::fmt::Display for RemoteTable<S> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test-only constructors for [`RemoteTable`] backed by a [`MockSender`].
#[cfg(all(test, feature = "remote"))]
mod test_utils {
    use super::*;
    use crate::remote::client::test_utils::{client_with_handler, MockSender};

    impl RemoteTable<MockSender> {
        /// Builds a table called `name` whose client is produced by
        /// `client_with_handler(handler)`.
        pub fn new_mock<F, T>(name: String, handler: F) -> Self
        where
            F: Fn(reqwest::Request) -> http::Response<T> + Send + Sync + 'static,
            T: Into<reqwest::Body>,
        {
            Self {
                client: client_with_handler(handler),
                name,
            }
        }
    }
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||||
fn as_any(&self) -> &dyn std::any::Any {
|
fn as_any(&self) -> &dyn std::any::Any {
|
||||||
@@ -52,22 +99,53 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
|||||||
&self.name
|
&self.name
|
||||||
}
|
}
|
||||||
/// Returns the table version reported by the `describe` endpoint.
async fn version(&self) -> Result<u64> {
    let description = self.describe().await?;
    Ok(description.version)
}
|
||||||
async fn checkout(&self, _version: u64) -> Result<()> {
|
async fn checkout(&self, _version: u64) -> Result<()> {
|
||||||
todo!()
|
Err(Error::NotSupported {
|
||||||
|
message: "checkout is not supported on LanceDB cloud.".into(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
async fn checkout_latest(&self) -> Result<()> {
|
async fn checkout_latest(&self) -> Result<()> {
|
||||||
todo!()
|
Err(Error::NotSupported {
|
||||||
|
message: "checkout is not supported on LanceDB cloud.".into(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
async fn restore(&self) -> Result<()> {
|
async fn restore(&self) -> Result<()> {
|
||||||
todo!()
|
Err(Error::NotSupported {
|
||||||
|
message: "restore is not supported on LanceDB cloud.".into(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
async fn schema(&self) -> Result<SchemaRef> {
|
async fn schema(&self) -> Result<SchemaRef> {
|
||||||
todo!()
|
let schema = self.describe().await?.schema;
|
||||||
|
Ok(Arc::new(schema.try_into()?))
|
||||||
}
|
}
|
||||||
async fn count_rows(&self, _filter: Option<String>) -> Result<usize> {
|
async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
|
||||||
todo!()
|
let mut request = self
|
||||||
|
.client
|
||||||
|
.post(&format!("/table/{}/count_rows/", self.name));
|
||||||
|
|
||||||
|
if let Some(filter) = filter {
|
||||||
|
request = request.json(&serde_json::json!({ "filter": filter }));
|
||||||
|
} else {
|
||||||
|
request = request.json(&serde_json::json!({}));
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = self.client.send(request).await?;
|
||||||
|
|
||||||
|
if response.status() == StatusCode::NOT_FOUND {
|
||||||
|
return Err(Error::TableNotFound {
|
||||||
|
name: self.name.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = self.client.check_response(response).await?;
|
||||||
|
|
||||||
|
let body = response.text().await?;
|
||||||
|
|
||||||
|
serde_json::from_str(&body).map_err(|e| Error::Http {
|
||||||
|
message: format!("Failed to parse row count: {}", e),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
async fn add(
|
async fn add(
|
||||||
&self,
|
&self,
|
||||||
@@ -140,3 +218,110 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
|||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    use arrow_schema::{DataType, Field, Schema};
    use futures::{future::BoxFuture, TryFutureExt};

    use crate::{Error, Table};

    /// Every implemented endpoint must surface an HTTP 404 as
    /// `Error::TableNotFound` carrying the table's name.
    #[tokio::test]
    async fn test_not_found() {
        let table = Table::new_with_handler("my_table", |_| {
            http::Response::builder()
                .status(404)
                .body("table my_table not found")
                .unwrap()
        });

        // All endpoints should translate 404 to TableNotFound.
        let pending: Vec<BoxFuture<'_, Result<()>>> = vec![
            Box::pin(table.version().map_ok(|_| ())),
            Box::pin(table.schema().map_ok(|_| ())),
            Box::pin(table.count_rows(None).map_ok(|_| ())),
            // TODO: other endpoints.
        ];

        for call in pending {
            let outcome = call.await;
            assert!(outcome.is_err());
            assert!(matches!(outcome, Err(Error::TableNotFound { name }) if name == "my_table"));
        }
    }

    /// `version` must hit the describe endpoint and return its `version` field.
    #[tokio::test]
    async fn test_version() {
        let table = Table::new_with_handler("my_table", |request| {
            assert_eq!(request.method(), "POST");
            assert_eq!(request.url().path(), "/table/my_table/describe/");

            http::Response::builder()
                .status(200)
                .body(r#"{"version": 42, "schema": { "fields": [] }}"#)
                .unwrap()
        });

        assert_eq!(table.version().await.unwrap(), 42);
    }

    /// `schema` must hit the describe endpoint and convert the JSON schema,
    /// including field nullability and schema metadata.
    #[tokio::test]
    async fn test_schema() {
        let table = Table::new_with_handler("my_table", |request| {
            assert_eq!(request.method(), "POST");
            assert_eq!(request.url().path(), "/table/my_table/describe/");

            http::Response::builder()
                .status(200)
                .body(
                    r#"{"version": 42, "schema": {"fields": [
                    {"name": "a", "type": { "type": "int32" }, "nullable": false},
                    {"name": "b", "type": { "type": "string" }, "nullable": true}
                    ], "metadata": {"key": "value"}}}"#,
                )
                .unwrap()
        });

        let fields = vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, true),
        ];
        let expected =
            Arc::new(Schema::new(fields).with_metadata([("key".into(), "value".into())].into()));

        let actual = table.schema().await.unwrap();
        assert_eq!(actual, expected);
    }

    /// `count_rows` must post `{}` without a filter and `{"filter": ...}` with
    /// one, and parse the bare integer body.
    #[tokio::test]
    async fn test_count_rows() {
        // Without a filter the request body is an empty JSON object.
        let unfiltered = Table::new_with_handler("my_table", |request| {
            assert_eq!(request.method(), "POST");
            assert_eq!(request.url().path(), "/table/my_table/count_rows/");
            assert_eq!(request.body().unwrap().as_bytes().unwrap(), br#"{}"#);

            http::Response::builder().status(200).body("42").unwrap()
        });
        assert_eq!(unfiltered.count_rows(None).await.unwrap(), 42);

        // With a filter the expression is forwarded under the "filter" key.
        let filtered = Table::new_with_handler("my_table", |request| {
            assert_eq!(request.method(), "POST");
            assert_eq!(request.url().path(), "/table/my_table/count_rows/");
            assert_eq!(
                request.body().unwrap().as_bytes().unwrap(),
                br#"{"filter":"a > 10"}"#
            );

            http::Response::builder().status(200).body("42").unwrap()
        });
        let count = filtered.count_rows(Some("a > 10".into())).await.unwrap();
        assert_eq!(count, 42);
    }
}
|
||||||
|
|||||||
@@ -428,6 +428,31 @@ pub struct Table {
|
|||||||
embedding_registry: Arc<dyn EmbeddingRegistry>,
|
embedding_registry: Arc<dyn EmbeddingRegistry>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test-only helpers for building a [`Table`] on top of a mocked remote table.
#[cfg(all(test, feature = "remote"))]
mod test_utils {
    use super::*;

    impl Table {
        /// Creates a `Table` named `name` whose inner table is built with
        /// `RemoteTable::new_mock(name, handler)`.
        pub fn new_with_handler<T>(
            name: impl Into<String>,
            handler: impl Fn(reqwest::Request) -> http::Response<T> + Clone + Send + Sync + 'static,
        ) -> Self
        where
            T: Into<reqwest::Body>,
        {
            let remote = crate::remote::table::RemoteTable::new_mock(name.into(), handler);
            Self {
                inner: Arc::new(remote),
                // Registry is unused.
                embedding_registry: Arc::new(MemoryRegistry::new()),
            }
        }
    }
}
|
||||||
|
|
||||||
impl std::fmt::Display for Table {
|
impl std::fmt::Display for Table {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
write!(f, "{}", self.inner)
|
write!(f, "{}", self.inner)
|
||||||
|
|||||||
Reference in New Issue
Block a user