mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-05 03:12:57 +00:00
feat(rust): remote endpoints for schema, version, count_rows (#1644)
A handful of additional endpoints.
This commit is contained in:
@@ -21,3 +21,5 @@ pub mod client;
|
||||
pub mod db;
|
||||
pub mod table;
|
||||
pub mod util;
|
||||
|
||||
const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
|
||||
|
||||
@@ -31,8 +31,7 @@ use crate::Table;
|
||||
use super::client::{HttpSend, RestfulLanceDbClient, Sender};
|
||||
use super::table::RemoteTable;
|
||||
use super::util::batches_to_ipc_bytes;
|
||||
|
||||
const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
|
||||
use super::ARROW_STREAM_CONTENT_TYPE;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ListTablesResponse {
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::table::dataset::DatasetReadGuard;
|
||||
use crate::Error;
|
||||
use arrow_array::RecordBatchReader;
|
||||
use arrow_schema::SchemaRef;
|
||||
use async_trait::async_trait;
|
||||
use datafusion_physical_plan::ExecutionPlan;
|
||||
use http::StatusCode;
|
||||
use lance::arrow::json::JsonSchema;
|
||||
use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
|
||||
use lance::dataset::{ColumnAlteration, NewColumnTransform};
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::{
|
||||
connection::NoData,
|
||||
@@ -32,6 +36,31 @@ impl<S: HttpSend> RemoteTable<S> {
|
||||
pub fn new(client: RestfulLanceDbClient<S>, name: String) -> Self {
|
||||
Self { client, name }
|
||||
}
|
||||
|
||||
async fn describe(&self) -> Result<TableDescription> {
|
||||
let request = self.client.post(&format!("/table/{}/describe/", self.name));
|
||||
let response = self.client.send(request).await?;
|
||||
|
||||
if response.status() == StatusCode::NOT_FOUND {
|
||||
return Err(Error::TableNotFound {
|
||||
name: self.name.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
let response = self.client.check_response(response).await?;
|
||||
|
||||
let body = response.text().await?;
|
||||
|
||||
serde_json::from_str(&body).map_err(|e| Error::Http {
|
||||
message: format!("Failed to parse table description: {}", e),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct TableDescription {
|
||||
version: u64,
|
||||
schema: JsonSchema,
|
||||
}
|
||||
|
||||
impl<S: HttpSend> std::fmt::Display for RemoteTable<S> {
|
||||
@@ -40,6 +69,24 @@ impl<S: HttpSend> std::fmt::Display for RemoteTable<S> {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "remote"))]
|
||||
mod test_utils {
|
||||
use super::*;
|
||||
use crate::remote::client::test_utils::client_with_handler;
|
||||
use crate::remote::client::test_utils::MockSender;
|
||||
|
||||
impl RemoteTable<MockSender> {
|
||||
pub fn new_mock<F, T>(name: String, handler: F) -> Self
|
||||
where
|
||||
F: Fn(reqwest::Request) -> http::Response<T> + Send + Sync + 'static,
|
||||
T: Into<reqwest::Body>,
|
||||
{
|
||||
let client = client_with_handler(handler);
|
||||
Self { client, name }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
@@ -52,22 +99,53 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
&self.name
|
||||
}
|
||||
async fn version(&self) -> Result<u64> {
|
||||
todo!()
|
||||
self.describe().await.map(|desc| desc.version)
|
||||
}
|
||||
async fn checkout(&self, _version: u64) -> Result<()> {
|
||||
todo!()
|
||||
Err(Error::NotSupported {
|
||||
message: "checkout is not supported on LanceDB cloud.".into(),
|
||||
})
|
||||
}
|
||||
async fn checkout_latest(&self) -> Result<()> {
|
||||
todo!()
|
||||
Err(Error::NotSupported {
|
||||
message: "checkout is not supported on LanceDB cloud.".into(),
|
||||
})
|
||||
}
|
||||
async fn restore(&self) -> Result<()> {
|
||||
todo!()
|
||||
Err(Error::NotSupported {
|
||||
message: "restore is not supported on LanceDB cloud.".into(),
|
||||
})
|
||||
}
|
||||
async fn schema(&self) -> Result<SchemaRef> {
|
||||
todo!()
|
||||
let schema = self.describe().await?.schema;
|
||||
Ok(Arc::new(schema.try_into()?))
|
||||
}
|
||||
async fn count_rows(&self, _filter: Option<String>) -> Result<usize> {
|
||||
todo!()
|
||||
async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
|
||||
let mut request = self
|
||||
.client
|
||||
.post(&format!("/table/{}/count_rows/", self.name));
|
||||
|
||||
if let Some(filter) = filter {
|
||||
request = request.json(&serde_json::json!({ "filter": filter }));
|
||||
} else {
|
||||
request = request.json(&serde_json::json!({}));
|
||||
}
|
||||
|
||||
let response = self.client.send(request).await?;
|
||||
|
||||
if response.status() == StatusCode::NOT_FOUND {
|
||||
return Err(Error::TableNotFound {
|
||||
name: self.name.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
let response = self.client.check_response(response).await?;
|
||||
|
||||
let body = response.text().await?;
|
||||
|
||||
serde_json::from_str(&body).map_err(|e| Error::Http {
|
||||
message: format!("Failed to parse row count: {}", e),
|
||||
})
|
||||
}
|
||||
async fn add(
|
||||
&self,
|
||||
@@ -140,3 +218,110 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use futures::{future::BoxFuture, TryFutureExt};
|
||||
|
||||
use crate::{Error, Table};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_not_found() {
|
||||
let table = Table::new_with_handler("my_table", |_| {
|
||||
http::Response::builder()
|
||||
.status(404)
|
||||
.body("table my_table not found")
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
// All endpoints should translate 404 to TableNotFound.
|
||||
let results: Vec<BoxFuture<'_, Result<()>>> = vec![
|
||||
Box::pin(table.version().map_ok(|_| ())),
|
||||
Box::pin(table.schema().map_ok(|_| ())),
|
||||
Box::pin(table.count_rows(None).map_ok(|_| ())),
|
||||
// TODO: other endpoints.
|
||||
];
|
||||
|
||||
for result in results {
|
||||
let result = result.await;
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(result, Err(Error::TableNotFound { name }) if name == "my_table"));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_version() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/table/my_table/describe/");
|
||||
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"version": 42, "schema": { "fields": [] }}"#)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let version = table.version().await.unwrap();
|
||||
assert_eq!(version, 42);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_schema() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/table/my_table/describe/");
|
||||
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(
|
||||
r#"{"version": 42, "schema": {"fields": [
|
||||
{"name": "a", "type": { "type": "int32" }, "nullable": false},
|
||||
{"name": "b", "type": { "type": "string" }, "nullable": true}
|
||||
], "metadata": {"key": "value"}}}"#,
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let expected = Arc::new(
|
||||
Schema::new(vec![
|
||||
Field::new("a", DataType::Int32, false),
|
||||
Field::new("b", DataType::Utf8, true),
|
||||
])
|
||||
.with_metadata([("key".into(), "value".into())].into()),
|
||||
);
|
||||
|
||||
let schema = table.schema().await.unwrap();
|
||||
assert_eq!(schema, expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_count_rows() {
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/table/my_table/count_rows/");
|
||||
assert_eq!(request.body().unwrap().as_bytes().unwrap(), br#"{}"#);
|
||||
|
||||
http::Response::builder().status(200).body("42").unwrap()
|
||||
});
|
||||
|
||||
let count = table.count_rows(None).await.unwrap();
|
||||
assert_eq!(count, 42);
|
||||
|
||||
let table = Table::new_with_handler("my_table", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/table/my_table/count_rows/");
|
||||
assert_eq!(
|
||||
request.body().unwrap().as_bytes().unwrap(),
|
||||
br#"{"filter":"a > 10"}"#
|
||||
);
|
||||
|
||||
http::Response::builder().status(200).body("42").unwrap()
|
||||
});
|
||||
|
||||
let count = table.count_rows(Some("a > 10".into())).await.unwrap();
|
||||
assert_eq!(count, 42);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -428,6 +428,31 @@ pub struct Table {
|
||||
embedding_registry: Arc<dyn EmbeddingRegistry>,
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "remote"))]
|
||||
mod test_utils {
|
||||
use super::*;
|
||||
|
||||
impl Table {
|
||||
pub fn new_with_handler<T>(
|
||||
name: impl Into<String>,
|
||||
handler: impl Fn(reqwest::Request) -> http::Response<T> + Clone + Send + Sync + 'static,
|
||||
) -> Self
|
||||
where
|
||||
T: Into<reqwest::Body>,
|
||||
{
|
||||
let inner = Arc::new(crate::remote::table::RemoteTable::new_mock(
|
||||
name.into(),
|
||||
handler,
|
||||
));
|
||||
Self {
|
||||
inner,
|
||||
// Registry is unused.
|
||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Table {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.inner)
|
||||
|
||||
Reference in New Issue
Block a user