feat(rust): remote endpoints for schema, version, count_rows (#1644)

Implements a handful of additional remote table endpoints: `schema`, `version`, and `count_rows`.
This commit is contained in:
Will Jones
2024-09-16 08:19:25 -07:00
committed by GitHub
parent 32af962c0c
commit ffb28dd4fc
4 changed files with 220 additions and 9 deletions

View File

@@ -21,3 +21,5 @@ pub mod client;
pub mod db;
pub mod table;
pub mod util;
/// Content-type string identifying an Apache Arrow stream payload.
///
/// Declared in the parent module so that child modules can import it via
/// `super::ARROW_STREAM_CONTENT_TYPE` instead of each defining their own copy.
const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";

View File

@@ -31,8 +31,7 @@ use crate::Table;
use super::client::{HttpSend, RestfulLanceDbClient, Sender};
use super::table::RemoteTable;
use super::util::batches_to_ipc_bytes;
// The Arrow stream content type is defined once in the parent module; a
// module-local `const` with the same name would clash with this import.
use super::ARROW_STREAM_CONTENT_TYPE;
#[derive(Deserialize)]
struct ListTablesResponse {

View File

@@ -1,12 +1,16 @@
use std::sync::Arc;
use crate::table::dataset::DatasetReadGuard;
use crate::Error;
use arrow_array::RecordBatchReader;
use arrow_schema::SchemaRef;
use async_trait::async_trait;
use datafusion_physical_plan::ExecutionPlan;
use http::StatusCode;
use lance::arrow::json::JsonSchema;
use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
use lance::dataset::{ColumnAlteration, NewColumnTransform};
use serde::Deserialize;
use crate::{
connection::NoData,
@@ -32,6 +36,31 @@ impl<S: HttpSend> RemoteTable<S> {
/// Create a handle to the remote table called `name`, issuing all requests
/// through `client`.
pub fn new(client: RestfulLanceDbClient<S>, name: String) -> Self {
    Self { name, client }
}
async fn describe(&self) -> Result<TableDescription> {
let request = self.client.post(&format!("/table/{}/describe/", self.name));
let response = self.client.send(request).await?;
if response.status() == StatusCode::NOT_FOUND {
return Err(Error::TableNotFound {
name: self.name.clone(),
});
}
let response = self.client.check_response(response).await?;
let body = response.text().await?;
serde_json::from_str(&body).map_err(|e| Error::Http {
message: format!("Failed to parse table description: {}", e),
})
}
}
/// Deserialized body of the response from the `/table/{name}/describe/`
/// endpoint.
#[derive(Deserialize)]
struct TableDescription {
/// Version number of the table as reported by the server.
version: u64,
/// Table schema in Lance's JSON schema representation.
schema: JsonSchema,
}
impl<S: HttpSend> std::fmt::Display for RemoteTable<S> {
@@ -40,6 +69,24 @@ impl<S: HttpSend> std::fmt::Display for RemoteTable<S> {
}
}
/// Test-only constructors for [`RemoteTable`] backed by a mock HTTP sender.
#[cfg(all(test, feature = "remote"))]
mod test_utils {
    use super::*;
    use crate::remote::client::test_utils::{client_with_handler, MockSender};

    impl RemoteTable<MockSender> {
        /// Build a `RemoteTable` named `name` whose client is created by
        /// `client_with_handler` from the supplied `handler`.
        pub fn new_mock<F, T>(name: String, handler: F) -> Self
        where
            F: Fn(reqwest::Request) -> http::Response<T> + Send + Sync + 'static,
            T: Into<reqwest::Body>,
        {
            Self {
                client: client_with_handler(handler),
                name,
            }
        }
    }
}
#[async_trait]
impl<S: HttpSend> TableInternal for RemoteTable<S> {
fn as_any(&self) -> &dyn std::any::Any {
@@ -52,22 +99,53 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
&self.name
}
/// Return the table's version as reported by the describe endpoint.
///
/// # Errors
///
/// Propagates any error from `describe`, including
/// [`Error::TableNotFound`] on HTTP 404.
async fn version(&self) -> Result<u64> {
    self.describe().await.map(|desc| desc.version)
}
/// Checking out a specific version is not available for remote tables.
///
/// # Errors
///
/// Always returns [`Error::NotSupported`].
async fn checkout(&self, _version: u64) -> Result<()> {
    Err(Error::NotSupported {
        message: "checkout is not supported on LanceDB cloud.".into(),
    })
}
/// Checking out the latest version is not available for remote tables.
///
/// # Errors
///
/// Always returns [`Error::NotSupported`].
async fn checkout_latest(&self) -> Result<()> {
    Err(Error::NotSupported {
        message: "checkout is not supported on LanceDB cloud.".into(),
    })
}
/// Restoring a previous version is not available for remote tables.
///
/// # Errors
///
/// Always returns [`Error::NotSupported`].
async fn restore(&self) -> Result<()> {
    Err(Error::NotSupported {
        message: "restore is not supported on LanceDB cloud.".into(),
    })
}
/// Return the table's schema, converted from the JSON schema carried in the
/// describe response.
///
/// # Errors
///
/// Propagates any error from `describe` (including
/// [`Error::TableNotFound`] on HTTP 404) and any failure converting the
/// JSON schema into an Arrow schema.
async fn schema(&self) -> Result<SchemaRef> {
    let schema = self.describe().await?.schema;
    Ok(Arc::new(schema.try_into()?))
}
async fn count_rows(&self, _filter: Option<String>) -> Result<usize> {
todo!()
async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
let mut request = self
.client
.post(&format!("/table/{}/count_rows/", self.name));
if let Some(filter) = filter {
request = request.json(&serde_json::json!({ "filter": filter }));
} else {
request = request.json(&serde_json::json!({}));
}
let response = self.client.send(request).await?;
if response.status() == StatusCode::NOT_FOUND {
return Err(Error::TableNotFound {
name: self.name.clone(),
});
}
let response = self.client.check_response(response).await?;
let body = response.text().await?;
serde_json::from_str(&body).map_err(|e| Error::Http {
message: format!("Failed to parse row count: {}", e),
})
}
async fn add(
&self,
@@ -140,3 +218,110 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
todo!()
}
}
/// Tests for the remote table endpoints, driven by a mock HTTP handler.
#[cfg(test)]
mod tests {
    use super::*;

    use arrow_schema::{DataType, Field, Schema};
    use futures::{future::BoxFuture, TryFutureExt};

    use crate::{Error, Table};

    /// Build a mock table whose handler checks that the request is a POST to
    /// the `count_rows` endpoint carrying exactly `expected_body`, and then
    /// replies with the row count `42`.
    fn count_rows_table(expected_body: &'static [u8]) -> Table {
        Table::new_with_handler("my_table", move |request| {
            assert_eq!(request.method(), "POST");
            assert_eq!(request.url().path(), "/table/my_table/count_rows/");
            assert_eq!(request.body().unwrap().as_bytes().unwrap(), expected_body);
            http::Response::builder().status(200).body("42").unwrap()
        })
    }

    /// Every implemented endpoint must map an HTTP 404 to `TableNotFound`.
    #[tokio::test]
    async fn test_not_found() {
        let table = Table::new_with_handler("my_table", |_| {
            http::Response::builder()
                .status(404)
                .body("table my_table not found")
                .unwrap()
        });

        // All endpoints should translate 404 to TableNotFound.
        let pending: Vec<BoxFuture<'_, Result<()>>> = vec![
            Box::pin(table.version().map_ok(|_| ())),
            Box::pin(table.schema().map_ok(|_| ())),
            Box::pin(table.count_rows(None).map_ok(|_| ())),
            // TODO: other endpoints.
        ];

        for fut in pending {
            let outcome = fut.await;
            assert!(outcome.is_err());
            assert!(matches!(outcome, Err(Error::TableNotFound { name }) if name == "my_table"));
        }
    }

    /// `version` reads the `version` field of the describe response.
    #[tokio::test]
    async fn test_version() {
        let table = Table::new_with_handler("my_table", |request| {
            assert_eq!(request.method(), "POST");
            assert_eq!(request.url().path(), "/table/my_table/describe/");

            http::Response::builder()
                .status(200)
                .body(r#"{"version": 42, "schema": { "fields": [] }}"#)
                .unwrap()
        });

        assert_eq!(table.version().await.unwrap(), 42);
    }

    /// `schema` converts the JSON schema of the describe response into an
    /// Arrow schema, including its metadata.
    #[tokio::test]
    async fn test_schema() {
        let table = Table::new_with_handler("my_table", |request| {
            assert_eq!(request.method(), "POST");
            assert_eq!(request.url().path(), "/table/my_table/describe/");

            http::Response::builder()
                .status(200)
                .body(
                    r#"{"version": 42, "schema": {"fields": [
{"name": "a", "type": { "type": "int32" }, "nullable": false},
{"name": "b", "type": { "type": "string" }, "nullable": true}
], "metadata": {"key": "value"}}}"#,
                )
                .unwrap()
        });

        let fields = vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, true),
        ];
        let expected =
            Arc::new(Schema::new(fields).with_metadata([("key".into(), "value".into())].into()));

        let schema = table.schema().await.unwrap();
        assert_eq!(schema, expected);
    }

    /// `count_rows` sends `{}` without a filter and `{"filter": ...}` with
    /// one, and parses the numeric response body.
    #[tokio::test]
    async fn test_count_rows() {
        let unfiltered = count_rows_table(br#"{}"#);
        let count = unfiltered.count_rows(None).await.unwrap();
        assert_eq!(count, 42);

        let filtered = count_rows_table(br#"{"filter":"a > 10"}"#);
        let count = filtered.count_rows(Some("a > 10".into())).await.unwrap();
        assert_eq!(count, 42);
    }
}

View File

@@ -428,6 +428,31 @@ pub struct Table {
embedding_registry: Arc<dyn EmbeddingRegistry>,
}
/// Test-only constructors for [`Table`] backed by a mocked remote table.
#[cfg(all(test, feature = "remote"))]
mod test_utils {
    use super::*;

    impl Table {
        /// Build a `Table` named `name` that wraps a mock `RemoteTable`
        /// created from `handler`.
        pub fn new_with_handler<T>(
            name: impl Into<String>,
            handler: impl Fn(reqwest::Request) -> http::Response<T> + Clone + Send + Sync + 'static,
        ) -> Self
        where
            T: Into<reqwest::Body>,
        {
            let table_name = name.into();
            let remote = crate::remote::table::RemoteTable::new_mock(table_name, handler);
            Self {
                inner: Arc::new(remote),
                // Registry is unused.
                embedding_registry: Arc::new(MemoryRegistry::new()),
            }
        }
    }
}
impl std::fmt::Display for Table {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.inner)