From ffb28dd4fc204b89d46ece4a60c03d331f11d591 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 16 Sep 2024 08:19:25 -0700 Subject: [PATCH] feat(rust): remote endpoints for schema, version, count_rows (#1644) A handful of additional endpoints. --- rust/lancedb/src/remote.rs | 2 + rust/lancedb/src/remote/db.rs | 3 +- rust/lancedb/src/remote/table.rs | 199 +++++++++++++++++++++++++++++-- rust/lancedb/src/table.rs | 25 ++++ 4 files changed, 220 insertions(+), 9 deletions(-) diff --git a/rust/lancedb/src/remote.rs b/rust/lancedb/src/remote.rs index dfdf6224..ce00a370 100644 --- a/rust/lancedb/src/remote.rs +++ b/rust/lancedb/src/remote.rs @@ -21,3 +21,5 @@ pub mod client; pub mod db; pub mod table; pub mod util; + +const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream"; diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs index 5198d0c6..93fbe86e 100644 --- a/rust/lancedb/src/remote/db.rs +++ b/rust/lancedb/src/remote/db.rs @@ -31,8 +31,7 @@ use crate::Table; use super::client::{HttpSend, RestfulLanceDbClient, Sender}; use super::table::RemoteTable; use super::util::batches_to_ipc_bytes; - -const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream"; +use super::ARROW_STREAM_CONTENT_TYPE; #[derive(Deserialize)] struct ListTablesResponse { diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 3406889d..44cac37a 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -1,12 +1,16 @@ use std::sync::Arc; use crate::table::dataset::DatasetReadGuard; +use crate::Error; use arrow_array::RecordBatchReader; use arrow_schema::SchemaRef; use async_trait::async_trait; use datafusion_physical_plan::ExecutionPlan; +use http::StatusCode; +use lance::arrow::json::JsonSchema; use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner}; use lance::dataset::{ColumnAlteration, NewColumnTransform}; +use serde::Deserialize; use crate::{ connection::NoData, @@ -32,6 +36,31 @@ impl RemoteTable { pub fn new(client: RestfulLanceDbClient, name: String) -> Self { Self { client, name } } + + async fn describe(&self) -> Result { + let request = self.client.post(&format!("/table/{}/describe/", self.name)); + let response = self.client.send(request).await?; + + if response.status() == StatusCode::NOT_FOUND { + return Err(Error::TableNotFound { + name: self.name.clone(), + }); + } + + let response = self.client.check_response(response).await?; + + let body = response.text().await?; + + serde_json::from_str(&body).map_err(|e| Error::Http { + message: format!("Failed to parse table description: {}", e), + }) + } +} + +#[derive(Deserialize)] +struct TableDescription { + version: u64, + schema: JsonSchema, } impl std::fmt::Display for RemoteTable { @@ -40,6 +69,24 @@ impl std::fmt::Display for RemoteTable { } } +#[cfg(all(test, feature = "remote"))] +mod test_utils { + use super::*; + use crate::remote::client::test_utils::client_with_handler; + use crate::remote::client::test_utils::MockSender; + + impl RemoteTable { + pub fn new_mock(name: String, handler: F) -> Self + where + F: Fn(reqwest::Request) -> http::Response + Send + Sync + 'static, + T: Into, + { + let client = client_with_handler(handler); + Self { client, name } + } + } +} + #[async_trait] impl TableInternal for RemoteTable { fn as_any(&self) -> &dyn std::any::Any { @@ -52,22 +99,53 @@ impl TableInternal for RemoteTable { &self.name } async fn version(&self) -> Result { - todo!() + self.describe().await.map(|desc| desc.version) } async fn checkout(&self, _version: u64) -> Result<()> { - todo!() + Err(Error::NotSupported { + message: "checkout is not supported on LanceDB cloud.".into(), + }) } async fn checkout_latest(&self) -> Result<()> { - todo!() + Err(Error::NotSupported { + message: "checkout is not supported on LanceDB cloud.".into(), + }) } async fn restore(&self) -> Result<()> { - todo!() + Err(Error::NotSupported { + message: "restore is not supported on LanceDB cloud.".into(), + }) } async fn schema(&self) -> Result { - todo!() + let schema = self.describe().await?.schema; + Ok(Arc::new(schema.try_into()?)) } - async fn count_rows(&self, _filter: Option) -> Result { - todo!() + async fn count_rows(&self, filter: Option) -> Result { + let mut request = self + .client + .post(&format!("/table/{}/count_rows/", self.name)); + + if let Some(filter) = filter { + request = request.json(&serde_json::json!({ "filter": filter })); + } else { + request = request.json(&serde_json::json!({})); + } + + let response = self.client.send(request).await?; + + if response.status() == StatusCode::NOT_FOUND { + return Err(Error::TableNotFound { + name: self.name.clone(), + }); + } + + let response = self.client.check_response(response).await?; + + let body = response.text().await?; + + serde_json::from_str(&body).map_err(|e| Error::Http { + message: format!("Failed to parse row count: {}", e), + }) } async fn add( &self, @@ -140,3 +218,110 @@ impl TableInternal for RemoteTable { todo!() } } + +#[cfg(test)] +mod tests { + use super::*; + + use arrow_schema::{DataType, Field, Schema}; + use futures::{future::BoxFuture, TryFutureExt}; + + use crate::{Error, Table}; + + #[tokio::test] + async fn test_not_found() { + let table = Table::new_with_handler("my_table", |_| { + http::Response::builder() + .status(404) + .body("table my_table not found") + .unwrap() + }); + + // All endpoints should translate 404 to TableNotFound. + let results: Vec>> = vec![ + Box::pin(table.version().map_ok(|_| ())), + Box::pin(table.schema().map_ok(|_| ())), + Box::pin(table.count_rows(None).map_ok(|_| ())), + // TODO: other endpoints. + ]; + + for result in results { + let result = result.await; + assert!(result.is_err()); + assert!(matches!(result, Err(Error::TableNotFound { name }) if name == "my_table")); + } + } + + #[tokio::test] + async fn test_version() { + let table = Table::new_with_handler("my_table", |request| { + assert_eq!(request.method(), "POST"); + assert_eq!(request.url().path(), "/table/my_table/describe/"); + + http::Response::builder() + .status(200) + .body(r#"{"version": 42, "schema": { "fields": [] }}"#) + .unwrap() + }); + + let version = table.version().await.unwrap(); + assert_eq!(version, 42); + } + + #[tokio::test] + async fn test_schema() { + let table = Table::new_with_handler("my_table", |request| { + assert_eq!(request.method(), "POST"); + assert_eq!(request.url().path(), "/table/my_table/describe/"); + + http::Response::builder() + .status(200) + .body( + r#"{"version": 42, "schema": {"fields": [ + {"name": "a", "type": { "type": "int32" }, "nullable": false}, + {"name": "b", "type": { "type": "string" }, "nullable": true} + ], "metadata": {"key": "value"}}}"#, + ) + .unwrap() + }); + + let expected = Arc::new( + Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, true), + ]) + .with_metadata([("key".into(), "value".into())].into()), + ); + + let schema = table.schema().await.unwrap(); + assert_eq!(schema, expected); + } + + #[tokio::test] + async fn test_count_rows() { + let table = Table::new_with_handler("my_table", |request| { + assert_eq!(request.method(), "POST"); + assert_eq!(request.url().path(), "/table/my_table/count_rows/"); + assert_eq!(request.body().unwrap().as_bytes().unwrap(), br#"{}"#); + + http::Response::builder().status(200).body("42").unwrap() + }); + + let count = table.count_rows(None).await.unwrap(); + assert_eq!(count, 42); + + let table = Table::new_with_handler("my_table", |request| { + assert_eq!(request.method(), "POST"); + assert_eq!(request.url().path(), "/table/my_table/count_rows/"); + assert_eq!( + request.body().unwrap().as_bytes().unwrap(), + br#"{"filter":"a > 10"}"# + ); + + http::Response::builder().status(200).body("42").unwrap() + }); + + let count = table.count_rows(Some("a > 10".into())).await.unwrap(); + assert_eq!(count, 42); + } +} diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 2244ef7f..0044d0fc 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -428,6 +428,31 @@ pub struct Table { embedding_registry: Arc, } +#[cfg(all(test, feature = "remote"))] +mod test_utils { + use super::*; + + impl Table { + pub fn new_with_handler( + name: impl Into, + handler: impl Fn(reqwest::Request) -> http::Response + Clone + Send + Sync + 'static, + ) -> Self + where + T: Into, + { + let inner = Arc::new(crate::remote::table::RemoteTable::new_mock( + name.into(), + handler, + )); + Self { + inner, + // Registry is unused. + embedding_registry: Arc::new(MemoryRegistry::new()), + } + } + } +} + impl std::fmt::Display for Table { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.inner)