From 1c123b58d8a17d7b1b63c56f52242ad9fc7c7c18 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 13 Sep 2024 10:53:27 -0700 Subject: [PATCH] feat: implement Remote connection for LanceDB Rust (#1639) * Adding a simple test facility, which allows you to mock a single endpoint at a time with a closure. * Implementing all the database-level endpoints Table-level APIs will be done in a follow up PR. --------- Co-authored-by: Weston Pace --- rust/lancedb/Cargo.toml | 3 +- rust/lancedb/src/connection.rs | 27 +++- rust/lancedb/src/remote/client.rs | 139 +++++++++++++---- rust/lancedb/src/remote/db.rs | 245 ++++++++++++++++++++++++++++-- rust/lancedb/src/remote/table.rs | 14 +- 5 files changed, 368 insertions(+), 60 deletions(-) diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml index 20f144582..1fd48d9f1 100644 --- a/rust/lancedb/Cargo.toml +++ b/rust/lancedb/Cargo.toml @@ -46,6 +46,7 @@ async-openai = { version = "0.20.0", optional = true } serde_with = { version = "3.8.1" } # For remote feature reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true } +http = { version = "0.2", optional = true } # Matching what is in reqwest polars-arrow = { version = ">=0.37,<0.40.0", optional = true } polars = { version = ">=0.37,<0.40.0", optional = true } hf-hub = { version = "0.3.2", optional = true } @@ -68,7 +69,7 @@ aws-smithy-runtime = { version = "1.3" } [features] default = [] -remote = ["dep:reqwest"] +remote = ["dep:reqwest", "dep:http"] fp16kernels = ["lance-linalg/fp16kernels"] s3-test = [] openai = ["dep:async-openai", "dep:reqwest"] diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index e46899a48..e660c1e51 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -335,8 +335,8 @@ impl CreateTableBuilder { #[derive(Clone, Debug)] pub struct OpenTableBuilder { - parent: Arc, - name: String, + pub(crate) parent: Arc, + pub(crate) name: String, index_cache_size: u32, lance_read_params: Option, } @@ -1095,6 +1095,25 @@ impl ConnectionInternal for Database { } } +#[cfg(all(test, feature = "remote"))] +mod test_utils { + use super::*; + impl Connection { + pub fn new_with_handler( + handler: impl Fn(reqwest::Request) -> http::Response + Clone + Send + Sync + 'static, + ) -> Self + where + T: Into, + { + let internal = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler)); + Self { + internal, + uri: "db://test".to_string(), + } + } + } +} + #[cfg(test)] mod tests { use arrow_schema::{DataType, Field, Schema}; @@ -1208,9 +1227,9 @@ mod tests { assert_eq!(tables, vec!["table1".to_owned()]); } - fn make_data() -> impl RecordBatchReader + Send + 'static { + fn make_data() -> Box { let id = Box::new(IncrementingInt32::new().named("id".to_string())); - BatchGenerator::new().col(id).batches(10, 2000) + Box::new(BatchGenerator::new().col(id).batches(10, 2000)) } #[tokio::test] diff --git a/rust/lancedb/src/remote/client.rs b/rust/lancedb/src/remote/client.rs index 2b2c1fd74..41964e0c1 100644 --- a/rust/lancedb/src/remote/client.rs +++ b/rust/lancedb/src/remote/client.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::time::Duration; +use std::{future::Future, time::Duration}; use reqwest::{ header::{HeaderMap, HeaderValue}, @@ -21,13 +21,66 @@ use reqwest::{ use crate::error::{Error, Result}; +// We use the `HttpSend` trait to abstract over the `reqwest::Client` so that +// we can mock responses in tests. Based on the patterns from this blog post: +// https://write.as/balrogboogie/testing-reqwest-based-clients #[derive(Clone, Debug)] -pub struct RestfulLanceDbClient { +pub struct RestfulLanceDbClient { client: reqwest::Client, host: String, + sender: S, } -impl RestfulLanceDbClient { +pub trait HttpSend: Clone + Send + Sync + std::fmt::Debug + 'static { + fn send(&self, req: RequestBuilder) -> impl Future> + Send; +} + +// Default implementation of HttpSend which sends the request normally with reqwest +#[derive(Clone, Debug)] +pub struct Sender; +impl HttpSend for Sender { + async fn send(&self, request: reqwest::RequestBuilder) -> Result { + Ok(request.send().await?) + } +} + +impl RestfulLanceDbClient { + pub fn try_new( + db_url: &str, + api_key: &str, + region: &str, + host_override: Option, + ) -> Result { + let parsed_url = url::Url::parse(db_url)?; + debug_assert_eq!(parsed_url.scheme(), "db"); + if !parsed_url.has_host() { + return Err(Error::Http { + message: format!("Invalid database URL (missing host) '{}'", db_url), + }); + } + let db_name = parsed_url.host_str().unwrap(); + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(30)) + .default_headers(Self::default_headers( + api_key, + region, + db_name, + host_override.is_some(), + )?) + .build()?; + let host = match host_override { + Some(host_override) => host_override, + None => format!("https://{}.{}.api.lancedb.com", db_name, region), + }; + Ok(Self { + client, + host, + sender: Sender, + }) + } +} + +impl RestfulLanceDbClient { pub fn host(&self) -> &str { &self.host } @@ -66,36 +119,6 @@ impl RestfulLanceDbClient { Ok(headers) } - pub fn try_new( - db_url: &str, - api_key: &str, - region: &str, - host_override: Option, - ) -> Result { - let parsed_url = url::Url::parse(db_url)?; - debug_assert_eq!(parsed_url.scheme(), "db"); - if !parsed_url.has_host() { - return Err(Error::Http { - message: format!("Invalid database URL (missing host) '{}'", db_url), - }); - } - let db_name = parsed_url.host_str().unwrap(); - let client = reqwest::Client::builder() - .timeout(Duration::from_secs(30)) - .default_headers(Self::default_headers( - api_key, - region, - db_name, - host_override.is_some(), - )?) - .build()?; - let host = match host_override { - Some(host_override) => host_override, - None => format!("https://{}.{}.api.lancedb.com", db_name, region), - }; - Ok(Self { client, host }) - } - pub fn get(&self, uri: &str) -> RequestBuilder { let full_uri = format!("{}{}", self.host, uri); self.client.get(full_uri) @@ -106,6 +129,10 @@ impl RestfulLanceDbClient { self.client.post(full_uri) } + pub async fn send(&self, req: RequestBuilder) -> Result { + self.sender.send(req).await + } + async fn rsp_to_str(response: Response) -> String { let status = response.status(); response.text().await.unwrap_or_else(|_| status.to_string()) @@ -126,3 +153,49 @@ impl RestfulLanceDbClient { } } } + +#[cfg(test)] +pub mod test_utils { + use std::sync::Arc; + + use super::*; + + #[derive(Clone)] + pub struct MockSender { + f: Arc reqwest::Response + Send + Sync + 'static>, + } + + impl std::fmt::Debug for MockSender { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "MockSender") + } + } + + impl HttpSend for MockSender { + async fn send(&self, request: reqwest::RequestBuilder) -> Result { + let request = request.build().unwrap(); + let response = (self.f)(request); + Ok(response) + } + } + + pub fn client_with_handler( + handler: impl Fn(reqwest::Request) -> http::response::Response + Send + Sync + 'static, + ) -> RestfulLanceDbClient + where + T: Into, + { + let wrapper = move |req: reqwest::Request| { + let response = handler(req); + response.into() + }; + + RestfulLanceDbClient { + client: reqwest::Client::new(), + host: "http://localhost".to_string(), + sender: MockSender { + f: Arc::new(wrapper), + }, + } + } +} diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs index b8062367d..5198d0c61 100644 --- a/rust/lancedb/src/remote/db.rs +++ b/rust/lancedb/src/remote/db.rs @@ -16,6 +16,7 @@ use std::sync::Arc; use arrow_array::RecordBatchReader; use async_trait::async_trait; +use http::StatusCode; use reqwest::header::CONTENT_TYPE; use serde::Deserialize; use tokio::task::spawn_blocking; @@ -27,7 +28,7 @@ use crate::embeddings::EmbeddingRegistry; use crate::error::Result; use crate::Table; -use super::client::RestfulLanceDbClient; +use super::client::{HttpSend, RestfulLanceDbClient, Sender}; use super::table::RemoteTable; use super::util::batches_to_ipc_bytes; @@ -39,8 +40,8 @@ struct ListTablesResponse { } #[derive(Debug)] -pub struct RemoteDatabase { - client: RestfulLanceDbClient, +pub struct RemoteDatabase { + client: RestfulLanceDbClient, } impl RemoteDatabase { @@ -55,14 +56,32 @@ impl RemoteDatabase { } } -impl std::fmt::Display for RemoteDatabase { +#[cfg(all(test, feature = "remote"))] +mod test_utils { + use super::*; + use crate::remote::client::test_utils::client_with_handler; + use crate::remote::client::test_utils::MockSender; + + impl RemoteDatabase { + pub fn new_mock(handler: F) -> Self + where + F: Fn(reqwest::Request) -> http::Response + Send + Sync + 'static, + T: Into, + { + let client = client_with_handler(handler); + Self { client } + } + } +} + +impl std::fmt::Display for RemoteDatabase { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "RemoteDatabase(host={})", self.client.host()) } } #[async_trait] -impl ConnectionInternal for RemoteDatabase { +impl ConnectionInternal for RemoteDatabase { async fn table_names(&self, options: TableNamesBuilder) -> Result> { let mut req = self.client.get("/v1/table/"); if let Some(limit) = options.limit { @@ -71,7 +90,7 @@ impl ConnectionInternal for RemoteDatabase { if let Some(start_after) = options.start_after { req = req.query(&[("page_token", start_after)]); } - let rsp = req.send().await?; + let rsp = self.client.send(req).await?; let rsp = self.client.check_response(rsp).await?; Ok(rsp.json::().await?.tables) } @@ -88,15 +107,24 @@ impl ConnectionInternal for RemoteDatabase { .await .unwrap()?; - let rsp = self + let req = self .client .post(&format!("/v1/table/{}/create/", options.name)) .body(data_buffer) .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE) // This is currently expected by LanceDb cloud but will be removed soon. - .header("x-request-id", "na") - .send() - .await?; + .header("x-request-id", "na"); + let rsp = self.client.send(req).await?; + + if rsp.status() == StatusCode::BAD_REQUEST { + let body = rsp.text().await?; + if body.contains("already exists") { + return Err(crate::Error::TableAlreadyExists { name: options.name }); + } else { + return Err(crate::Error::InvalidInput { message: body }); + } + } + self.client.check_response(rsp).await?; Ok(Table::new(Arc::new(RemoteTable::new( @@ -105,19 +133,206 @@ impl ConnectionInternal for RemoteDatabase { )))) } - async fn do_open_table(&self, _options: OpenTableBuilder) -> Result { - todo!() + async fn do_open_table(&self, options: OpenTableBuilder) -> Result
{ + // We describe the table to confirm it exists before moving on. + // TODO: a TTL cache of table existence + let req = self + .client + .get(&format!("/v1/table/{}/describe/", options.name)); + let resp = self.client.send(req).await?; + if resp.status() == StatusCode::NOT_FOUND { + return Err(crate::Error::TableNotFound { name: options.name }); + } + self.client.check_response(resp).await?; + Ok(Table::new(Arc::new(RemoteTable::new( + self.client.clone(), + options.name, + )))) } - async fn drop_table(&self, _name: &str) -> Result<()> { - todo!() + async fn drop_table(&self, name: &str) -> Result<()> { + let req = self.client.post(&format!("/v1/table/{}/drop/", name)); + let resp = self.client.send(req).await?; + self.client.check_response(resp).await?; + Ok(()) } async fn drop_db(&self) -> Result<()> { - todo!() + Err(crate::Error::NotSupported { + message: "Dropping databases is not supported in the remote API".to_string(), + }) } fn embedding_registry(&self) -> &dyn EmbeddingRegistry { todo!() } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator}; + use arrow_schema::{DataType, Field, Schema}; + + use crate::{remote::db::ARROW_STREAM_CONTENT_TYPE, Connection}; + + #[tokio::test] + async fn test_table_names() { + let conn = Connection::new_with_handler(|request| { + assert_eq!(request.method(), &reqwest::Method::GET); + assert_eq!(request.url().path(), "/v1/table/"); + assert_eq!(request.url().query(), None); + + http::Response::builder() + .status(200) + .body(r#"{"tables": ["table1", "table2"]}"#) + .unwrap() + }); + let names = conn.table_names().execute().await.unwrap(); + assert_eq!(names, vec!["table1", "table2"]); + } + + #[tokio::test] + async fn test_table_names_pagination() { + let conn = Connection::new_with_handler(|request| { + assert_eq!(request.method(), &reqwest::Method::GET); + assert_eq!(request.url().path(), "/v1/table/"); + assert!(request.url().query().unwrap().contains("limit=2")); + assert!(request.url().query().unwrap().contains("page_token=table2")); + + http::Response::builder() + .status(200) + .body(r#"{"tables": ["table3", "table4"], "page_token": "token"}"#) + .unwrap() + }); + let names = conn + .table_names() + .start_after("table2") + .limit(2) + .execute() + .await + .unwrap(); + assert_eq!(names, vec!["table3", "table4"]); + } + + #[tokio::test] + async fn test_open_table() { + let conn = Connection::new_with_handler(|request| { + assert_eq!(request.method(), &reqwest::Method::GET); + assert_eq!(request.url().path(), "/v1/table/table1/describe/"); + assert_eq!(request.url().query(), None); + + http::Response::builder() + .status(200) + .body(r#"{"table": "table1"}"#) + .unwrap() + }); + let table = conn.open_table("table1").execute().await.unwrap(); + assert_eq!(table.name(), "table1"); + + // Storage options should be ignored. + let table = conn + .open_table("table1") + .storage_option("key", "value") + .execute() + .await + .unwrap(); + assert_eq!(table.name(), "table1"); + } + + #[tokio::test] + async fn test_open_table_not_found() { + let conn = Connection::new_with_handler(|_| { + http::Response::builder() + .status(404) + .body("table not found") + .unwrap() + }); + let result = conn.open_table("table1").execute().await; + assert!(result.is_err()); + assert!(matches!(result, Err(crate::Error::TableNotFound { .. }))); + } + + #[tokio::test] + async fn test_create_table() { + let conn = Connection::new_with_handler(|request| { + assert_eq!(request.method(), &reqwest::Method::POST); + assert_eq!(request.url().path(), "/v1/table/table1/create/"); + assert_eq!( + request + .headers() + .get(reqwest::header::CONTENT_TYPE) + .unwrap(), + ARROW_STREAM_CONTENT_TYPE.as_bytes() + ); + + http::Response::builder().status(200).body("").unwrap() + }); + let data = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let reader = RecordBatchIterator::new([Ok(data.clone())], data.schema()); + let table = conn.create_table("table1", reader).execute().await.unwrap(); + assert_eq!(table.name(), "table1"); + } + + #[tokio::test] + async fn test_create_table_already_exists() { + let conn = Connection::new_with_handler(|_| { + http::Response::builder() + .status(400) + .body("table table1 already exists") + .unwrap() + }); + let data = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let reader = RecordBatchIterator::new([Ok(data.clone())], data.schema()); + let result = conn.create_table("table1", reader).execute().await; + assert!(result.is_err()); + assert!( + matches!(result, Err(crate::Error::TableAlreadyExists { name }) if name == "table1") + ); + } + + #[tokio::test] + async fn test_create_table_empty() { + let conn = Connection::new_with_handler(|request| { + assert_eq!(request.method(), &reqwest::Method::POST); + assert_eq!(request.url().path(), "/v1/table/table1/create/"); + assert_eq!( + request + .headers() + .get(reqwest::header::CONTENT_TYPE) + .unwrap(), + ARROW_STREAM_CONTENT_TYPE.as_bytes() + ); + + http::Response::builder().status(200).body("").unwrap() + }); + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + conn.create_empty_table("table1", schema) + .execute() + .await + .unwrap(); + } + + #[tokio::test] + async fn test_drop_table() { + let conn = Connection::new_with_handler(|request| { + assert_eq!(request.method(), &reqwest::Method::POST); + assert_eq!(request.url().path(), "/v1/table/table1/drop/"); + assert_eq!(request.url().query(), None); + assert!(request.body().is_none()); + + http::Response::builder().status(200).body("").unwrap() + }); + conn.drop_table("table1").await.unwrap(); + // NOTE: the API will return 200 even if the table does not exist. So we shouldn't expect 404. + } +} diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 5add6e6ad..3406889d1 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -19,29 +19,29 @@ use crate::{ }, }; -use super::client::RestfulLanceDbClient; +use super::client::{HttpSend, RestfulLanceDbClient, Sender}; #[derive(Debug)] -pub struct RemoteTable { +pub struct RemoteTable { #[allow(dead_code)] - client: RestfulLanceDbClient, + client: RestfulLanceDbClient, name: String, } -impl RemoteTable { - pub fn new(client: RestfulLanceDbClient, name: String) -> Self { +impl RemoteTable { + pub fn new(client: RestfulLanceDbClient, name: String) -> Self { Self { client, name } } } -impl std::fmt::Display for RemoteTable { +impl std::fmt::Display for RemoteTable { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "RemoteTable({})", self.name) } } #[async_trait] -impl TableInternal for RemoteTable { +impl TableInternal for RemoteTable { fn as_any(&self) -> &dyn std::any::Any { self }