diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml
index 1fd48d9f..3bdca38c 100644
--- a/rust/lancedb/Cargo.toml
+++ b/rust/lancedb/Cargo.toml
@@ -45,8 +45,8 @@ serde_json = { version = "1" }
 async-openai = { version = "0.20.0", optional = true }
 serde_with = { version = "3.8.1" }
 # For remote feature
-reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
-http = { version = "0.2", optional = true } # Matching what is in reqwest
+reqwest = { version = "0.12.0", features = ["gzip", "json", "stream"], optional = true }
+http = { version = "1", optional = true } # Matching what is in reqwest
 polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
 polars = { version = ">=0.37,<0.40.0", optional = true }
 hf-hub = { version = "0.3.2", optional = true }
@@ -66,6 +66,7 @@ aws-sdk-s3 = { version = "1.38.0" }
 aws-sdk-kms = { version = "1.37" }
 aws-config = { version = "1.0" }
 aws-smithy-runtime = { version = "1.3" }
+http-body = "1" # Matching reqwest
 
 [features]
 default = []
diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs
index 44cac37a..c7a13b49 100644
--- a/rust/lancedb/src/remote/table.rs
+++ b/rust/lancedb/src/remote/table.rs
@@ -1,16 +1,18 @@
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
 use crate::table::dataset::DatasetReadGuard;
+use crate::table::AddDataMode;
 use crate::Error;
 use arrow_array::RecordBatchReader;
 use arrow_schema::SchemaRef;
 use async_trait::async_trait;
 use datafusion_physical_plan::ExecutionPlan;
+use http::header::CONTENT_TYPE;
 use http::StatusCode;
 use lance::arrow::json::JsonSchema;
 use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
 use lance::dataset::{ColumnAlteration, NewColumnTransform};
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 
 use crate::{
     connection::NoData,
@@ -24,6 +26,7 @@ use crate::{
 };
 
 use super::client::{HttpSend, RestfulLanceDbClient, Sender};
+use super::ARROW_STREAM_CONTENT_TYPE;
 
 #[derive(Debug)]
 pub struct RemoteTable<S: HttpSend = Sender> {
@@ -41,13 +44,7 @@ impl<S: HttpSend> RemoteTable<S> {
         let request = self.client.post(&format!("/table/{}/describe/", self.name));
         let response = self.client.send(request).await?;
 
-        if response.status() == StatusCode::NOT_FOUND {
-            return Err(Error::TableNotFound {
-                name: self.name.clone(),
-            });
-        }
-
-        let response = self.client.check_response(response).await?;
+        let response = self.check_table_response(response).await?;
 
         let body = response.text().await?;
 
@@ -55,6 +52,39 @@ serde_json::from_str(&body).map_err(|e| Error::Http {
             message: format!("Failed to parse table description: {}", e),
         })
     }
+
+    fn reader_as_body(data: Box<dyn RecordBatchReader + Send>) -> Result<reqwest::Body> {
+        // TODO: Once Phalanx supports compression, we should use it here.
+        let mut writer = arrow_ipc::writer::StreamWriter::try_new(Vec::new(), &data.schema())?;
+
+        // Mutex is just here to make it sync. We shouldn't have any contention.
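+        // Each pull of the iterator below writes one batch through the IPC
+        // StreamWriter and yields the writer's accumulated bytes as one body
+        // chunk; the final None arm flushes the end-of-stream marker.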
+        let mut data = Mutex::new(data);
+        let body_iter = std::iter::from_fn(move || match data.get_mut().unwrap().next() {
+            Some(Ok(batch)) => {
+                writer.write(&batch).ok()?;
+                let buffer = std::mem::take(writer.get_mut());
+                Some(Ok(buffer))
+            }
+            Some(Err(e)) => Some(Err(e)),
+            None => {
+                writer.finish().ok()?;
+                let buffer = std::mem::take(writer.get_mut());
+                Some(Ok(buffer))
+            }
+        });
+        let body_stream = futures::stream::iter(body_iter);
+        Ok(reqwest::Body::wrap_stream(body_stream))
+    }
+
+    async fn check_table_response(&self, response: reqwest::Response) -> Result<reqwest::Response> {
+        if response.status() == StatusCode::NOT_FOUND {
+            return Err(Error::TableNotFound {
+                name: self.name.clone(),
+            });
+        }
+
+        self.client.check_response(response).await
+    }
 }
 
 #[derive(Deserialize)]
@@ -133,13 +163,7 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
         let response = self.client.send(request).await?;
 
-        if response.status() == StatusCode::NOT_FOUND {
-            return Err(Error::TableNotFound {
-                name: self.name.clone(),
-            });
-        }
-
-        let response = self.client.check_response(response).await?;
+        let response = self.check_table_response(response).await?;
 
         let body = response.text().await?;
@@ -149,10 +173,28 @@
     }
     async fn add(
         &self,
-        _add: AddDataBuilder<NoData>,
-        _data: Box<dyn RecordBatchReader + Send>,
+        add: AddDataBuilder<NoData>,
+        data: Box<dyn RecordBatchReader + Send>,
     ) -> Result<()> {
-        todo!()
+        let body = Self::reader_as_body(data)?;
+        let mut request = self
+            .client
+            .post(&format!("/table/{}/insert/", self.name))
+            .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
+            .body(body);
+
+        match add.mode {
+            AddDataMode::Append => {}
+            AddDataMode::Overwrite => {
+                request = request.query(&[("mode", "overwrite")]);
+            }
+        }
+
+        let response = self.client.send(request).await?;
+
+        self.check_table_response(response).await?;
+
+        Ok(())
     }
     async fn build_plan(
         &self,
@@ -160,71 +202,170 @@
         _query: &VectorQuery,
         _options: Option<QueryExecutionOptions>,
     ) -> Result<Scanner> {
-        todo!()
+        Err(Error::NotSupported {
+            message: "build_plan is not supported on LanceDB cloud.".into(),
+        })
     }
     async fn create_plan(
         &self,
         _query: &VectorQuery,
         _options: QueryExecutionOptions,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        unimplemented!()
+        Err(Error::NotSupported {
+            message: "create_plan is not supported on LanceDB cloud.".into(),
+        })
     }
     async fn explain_plan(&self, _query: &VectorQuery, _verbose: bool) -> Result<String> {
-        todo!()
+        Err(Error::NotSupported {
+            message: "explain_plan is not supported on LanceDB cloud.".into(),
+        })
     }
     async fn plain_query(
         &self,
         _query: &Query,
         _options: QueryExecutionOptions,
     ) -> Result<DatasetRecordBatchStream> {
-        todo!()
+        Err(Error::NotSupported {
+            message: "plain_query is not yet supported on LanceDB cloud.".into(),
+        })
     }
-    async fn update(&self, _update: UpdateBuilder) -> Result<()> {
-        todo!()
+    async fn update(&self, update: UpdateBuilder) -> Result<()> {
+        let request = self.client.post(&format!("/table/{}/update/", self.name));
+
+        let mut updates = Vec::new();
+        for (column, expression) in update.columns {
+            updates.push(column);
+            updates.push(expression);
+        }
+
+        let request = request.json(&serde_json::json!({
+            "updates": updates,
+            "only_if": update.filter,
+        }));
+
+        let response = self.client.send(request).await?;
+
+        self.check_table_response(response).await?;
+
+        Ok(())
     }
-    async fn delete(&self, _predicate: &str) -> Result<()> {
-        todo!()
+    async fn delete(&self, predicate: &str) -> Result<()> {
+        let body = serde_json::json!({ "predicate": predicate });
+        let request = self
+            .client
+            .post(&format!("/table/{}/delete/", self.name))
+            .json(&body);
+        let response = self.client.send(request).await?;
+        self.check_table_response(response).await?;
+        Ok(())
     }
     async fn create_index(&self, _index: IndexBuilder) -> Result<()> {
-        todo!()
+        Err(Error::NotSupported {
+            message: "create_index is not yet supported on LanceDB cloud.".into(),
+        })
     }
     async fn merge_insert(
         &self,
-        _params: MergeInsertBuilder,
-        _new_data: Box<dyn RecordBatchReader + Send>,
+        params: MergeInsertBuilder,
+        new_data: Box<dyn RecordBatchReader + Send>,
     ) -> Result<()> {
-        todo!()
+        let query = MergeInsertRequest::try_from(params)?;
+        let body = Self::reader_as_body(new_data)?;
+        let request = self
+            .client
+            .post(&format!("/table/{}/merge_insert/", self.name))
+            .query(&query)
+            .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
+            .body(body);
+
+        let response = self.client.send(request).await?;
+
+        self.check_table_response(response).await?;
+
+        Ok(())
     }
     async fn optimize(&self, _action: OptimizeAction) -> Result<OptimizeStats> {
-        todo!()
+        Err(Error::NotSupported {
+            message: "optimize is not supported on LanceDB cloud.".into(),
+        })
     }
     async fn add_columns(
         &self,
         _transforms: NewColumnTransform,
         _read_columns: Option<Vec<String>>,
     ) -> Result<()> {
-        todo!()
+        Err(Error::NotSupported {
+            message: "add_columns is not yet supported.".into(),
+        })
     }
     async fn alter_columns(&self, _alterations: &[ColumnAlteration]) -> Result<()> {
-        todo!()
+        Err(Error::NotSupported {
+            message: "alter_columns is not yet supported.".into(),
+        })
     }
     async fn drop_columns(&self, _columns: &[&str]) -> Result<()> {
-        todo!()
+        Err(Error::NotSupported {
+            message: "drop_columns is not yet supported.".into(),
+        })
     }
     async fn list_indices(&self) -> Result<Vec<IndexConfig>> {
-        todo!()
+        Err(Error::NotSupported {
+            message: "list_indices is not yet supported.".into(),
+        })
     }
     async fn table_definition(&self) -> Result<TableDefinition> {
-        todo!()
+        Err(Error::NotSupported {
+            message: "table_definition is not supported on LanceDB cloud.".into(),
+        })
     }
 }
+
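+// Serialized into the merge_insert query string via reqwest's `.query()`;
+// `None` filters are omitted from the URL rather than sent as empty values.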
+#[derive(Serialize)]
+struct MergeInsertRequest {
+    on: String,
+    when_matched_update_all: bool,
+    when_matched_update_all_filt: Option<String>,
+    when_not_matched_insert_all: bool,
+    when_not_matched_by_source_delete: bool,
+    when_not_matched_by_source_delete_filt: Option<String>,
+}
+
+impl TryFrom<MergeInsertBuilder> for MergeInsertRequest {
+    type Error = Error;
+
+    fn try_from(value: MergeInsertBuilder) -> Result<Self> {
+        if value.on.is_empty() {
+            return Err(Error::InvalidInput {
+                message: "MergeInsertBuilder missing required 'on' field".into(),
+            });
+        } else if value.on.len() > 1 {
+            return Err(Error::NotSupported {
+                message: "MergeInsertBuilder only supports a single 'on' column".into(),
+            });
+        }
+        let on = value.on[0].clone();
+
+        Ok(Self {
+            on,
+            when_matched_update_all: value.when_matched_update_all,
+            when_matched_update_all_filt: value.when_matched_update_all_filt,
+            when_not_matched_insert_all: value.when_not_matched_insert_all,
+            when_not_matched_by_source_delete: value.when_not_matched_by_source_delete,
+            when_not_matched_by_source_delete_filt: value.when_not_matched_by_source_delete_filt,
+        })
+    }
+}
 
 #[cfg(test)]
 mod tests {
+    use std::{collections::HashMap, pin::Pin};
+
     use super::*;
+    use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
     use arrow_schema::{DataType, Field, Schema};
-    use futures::{future::BoxFuture, TryFutureExt};
+    use futures::{future::BoxFuture, StreamExt, TryFutureExt};
+    use reqwest::Body;
 
     use crate::{Error, Table};
 
@@ -237,12 +378,27 @@
             .unwrap()
         });
 
+        let batch = RecordBatch::try_new(
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
+            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
+        )
+        .unwrap();
+        let example_data = || {
+            Box::new(RecordBatchIterator::new(
+                [Ok(batch.clone())],
+                batch.schema(),
+            ))
+        };
+
+        // All endpoints should translate 404 to TableNotFound.
         let results: Vec<BoxFuture<'_, Result<()>>> = vec![
             Box::pin(table.version().map_ok(|_| ())),
             Box::pin(table.schema().map_ok(|_| ())),
             Box::pin(table.count_rows(None).map_ok(|_| ())),
-            // TODO: other endpoints.
+            Box::pin(table.update().column("a", "a + 1").execute()),
+            Box::pin(table.add(example_data()).execute().map_ok(|_| ())),
+            Box::pin(table.merge_insert(&["test"]).execute(example_data())),
+            Box::pin(table.delete("false")), // TODO: other endpoints.
         ];
 
         for result in results {
@@ -324,4 +480,245 @@
         let count = table.count_rows(Some("a > 10".into())).await.unwrap();
         assert_eq!(count, 42);
     }
+
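+    // Drains a reqwest::Body by polling http-body 1.x frames directly,
+    // which avoids pulling in `http-body-util` just for the tests.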
+    async fn collect_body(body: Body) -> Vec<u8> {
+        use http_body::Body;
+        let mut body = body;
+        let mut data = Vec::new();
+        let mut body_pin = Pin::new(&mut body);
+        futures::stream::poll_fn(|cx| body_pin.as_mut().poll_frame(cx))
+            .for_each(|frame| {
+                data.extend_from_slice(frame.unwrap().data_ref().unwrap());
+                futures::future::ready(())
+            })
+            .await;
+        data
+    }
+
+    fn write_ipc_stream(data: &RecordBatch) -> Vec<u8> {
+        let mut body = Vec::new();
+        {
+            let mut writer = arrow_ipc::writer::StreamWriter::try_new(&mut body, &data.schema())
+                .expect("Failed to create writer");
+            writer.write(data).expect("Failed to write data");
+            writer.finish().expect("Failed to finish");
+        }
+        body
+    }
+
+    #[tokio::test]
+    async fn test_add_append() {
+        let data = RecordBatch::try_new(
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
+            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
+        )
+        .unwrap();
+
+        let (sender, receiver) = std::sync::mpsc::channel();
+        let table = Table::new_with_handler("my_table", move |mut request| {
+            assert_eq!(request.method(), "POST");
+            assert_eq!(request.url().path(), "/table/my_table/insert/");
+            // If mode is specified, it should be "append". Append is the
+            // default, so it's not required.
+            assert!(request
+                .url()
+                .query_pairs()
+                .filter(|(k, _)| k == "mode")
+                .all(|(_, v)| v == "append"));
+
+            assert_eq!(
+                request.headers().get("Content-Type").unwrap(),
+                ARROW_STREAM_CONTENT_TYPE
+            );
+
+            let mut body_out = reqwest::Body::from(Vec::new());
+            std::mem::swap(request.body_mut().as_mut().unwrap(), &mut body_out);
+            sender.send(body_out).unwrap();
+
+            http::Response::builder().status(200).body("").unwrap()
+        });
+
+        table
+            .add(RecordBatchIterator::new([Ok(data.clone())], data.schema()))
+            .execute()
+            .await
+            .unwrap();
+
+        let body = receiver.recv().unwrap();
+        let body = collect_body(body).await;
+        let expected_body = write_ipc_stream(&data);
+        assert_eq!(&body, &expected_body);
+    }
+
+    #[tokio::test]
+    async fn test_add_overwrite() {
+        let data = RecordBatch::try_new(
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
+            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
+        )
+        .unwrap();
+
+        let (sender, receiver) = std::sync::mpsc::channel();
+        let table = Table::new_with_handler("my_table", move |mut request| {
+            assert_eq!(request.method(), "POST");
+            assert_eq!(request.url().path(), "/table/my_table/insert/");
+            assert_eq!(
+                request
+                    .url()
+                    .query_pairs()
+                    .find(|(k, _)| k == "mode")
+                    .map(|kv| kv.1)
+                    .as_deref(),
+                Some("overwrite"),
+                "Expected mode=overwrite"
+            );
+
+            assert_eq!(
+                request.headers().get("Content-Type").unwrap(),
+                ARROW_STREAM_CONTENT_TYPE
+            );
+
+            let mut body_out = reqwest::Body::from(Vec::new());
+            std::mem::swap(request.body_mut().as_mut().unwrap(), &mut body_out);
+            sender.send(body_out).unwrap();
+
+            http::Response::builder().status(200).body("").unwrap()
+        });
+
+        table
+            .add(RecordBatchIterator::new([Ok(data.clone())], data.schema()))
+            .mode(AddDataMode::Overwrite)
+            .execute()
+            .await
+            .unwrap();
+
+        let body = receiver.recv().unwrap();
+        let body = collect_body(body).await;
+        let expected_body = write_ipc_stream(&data);
+        assert_eq!(&body, &expected_body);
+    }
+
+    #[tokio::test]
+    async fn test_update() {
+        let table = Table::new_with_handler("my_table", |request| {
+            assert_eq!(request.method(), "POST");
+            assert_eq!(request.url().path(), "/table/my_table/update/");
+
+            if let Some(body) = request.body().unwrap().as_bytes() {
+                let body = std::str::from_utf8(body).unwrap();
+                let value: serde_json::Value = serde_json::from_str(body).unwrap();
+                let updates = value.get("updates").unwrap().as_array().unwrap();
+                assert!(updates.len() == 2);
+
+                let col_name = updates[0].as_str().unwrap();
+                let expression = updates[1].as_str().unwrap();
+                assert_eq!(col_name, "a");
+                assert_eq!(expression, "a + 1");
+
+                let only_if = value.get("only_if").unwrap().as_str().unwrap();
+                assert_eq!(only_if, "b > 10");
+            }
+
+            http::Response::builder().status(200).body("").unwrap()
+        });
+
+        table
+            .update()
+            .column("a", "a + 1")
+            .only_if("b > 10")
+            .execute()
+            .await
+            .unwrap();
+    }
+
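+    // merge_insert sends its options in the query string and the Arrow IPC
+    // payload in the body, so the two are asserted separately below.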
+    #[tokio::test]
+    async fn test_merge_insert() {
+        let batch = RecordBatch::try_new(
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
+            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
+        )
+        .unwrap();
+        let data = Box::new(RecordBatchIterator::new(
+            [Ok(batch.clone())],
+            batch.schema(),
+        ));
+
+        // Default parameters
+        let table = Table::new_with_handler("my_table", |request| {
+            assert_eq!(request.method(), "POST");
+            assert_eq!(request.url().path(), "/table/my_table/merge_insert/");
+
+            let params = request.url().query_pairs().collect::<HashMap<_, _>>();
+            assert_eq!(params["on"], "some_col");
+            assert_eq!(params["when_matched_update_all"], "false");
+            assert_eq!(params["when_not_matched_insert_all"], "false");
+            assert_eq!(params["when_not_matched_by_source_delete"], "false");
+            assert!(!params.contains_key("when_matched_update_all_filt"));
+            assert!(!params.contains_key("when_not_matched_by_source_delete_filt"));
+
+            http::Response::builder().status(200).body("").unwrap()
+        });
+
+        table
+            .merge_insert(&["some_col"])
+            .execute(data)
+            .await
+            .unwrap();
+
+        // All parameters specified
+        let (sender, receiver) = std::sync::mpsc::channel();
+        let table = Table::new_with_handler("my_table", move |mut request| {
+            assert_eq!(request.method(), "POST");
+            assert_eq!(request.url().path(), "/table/my_table/merge_insert/");
+            assert_eq!(
+                request.headers().get("Content-Type").unwrap(),
+                ARROW_STREAM_CONTENT_TYPE
+            );
+
+            let params = request.url().query_pairs().collect::<HashMap<_, _>>();
+            assert_eq!(params["on"], "some_col");
+            assert_eq!(params["when_matched_update_all"], "true");
+            assert_eq!(params["when_not_matched_insert_all"], "false");
+            assert_eq!(params["when_not_matched_by_source_delete"], "true");
+            assert_eq!(params["when_matched_update_all_filt"], "a = 1");
+            assert_eq!(params["when_not_matched_by_source_delete_filt"], "b = 2");
+
+            let mut body_out = reqwest::Body::from(Vec::new());
+            std::mem::swap(request.body_mut().as_mut().unwrap(), &mut body_out);
+            sender.send(body_out).unwrap();
+
+            http::Response::builder().status(200).body("").unwrap()
+        });
+        let mut builder = table.merge_insert(&["some_col"]);
+        builder
+            .when_matched_update_all(Some("a = 1".into()))
+            .when_not_matched_by_source_delete(Some("b = 2".into()));
+        let data = Box::new(RecordBatchIterator::new(
+            [Ok(batch.clone())],
+            batch.schema(),
+        ));
+        builder.execute(data).await.unwrap();
+
+        let body = receiver.recv().unwrap();
+        let body = collect_body(body).await;
+        let expected_body = write_ipc_stream(&batch);
+        assert_eq!(&body, &expected_body);
+    }
+
+    #[tokio::test]
+    async fn test_delete() {
+        let table = Table::new_with_handler("my_table", |request| {
+            assert_eq!(request.method(), "POST");
+            assert_eq!(request.url().path(), "/table/my_table/delete/");
+
+            let body = request.body().unwrap().as_bytes().unwrap();
+            let body: serde_json::Value = serde_json::from_slice(body).unwrap();
+            let predicate = body.get("predicate").unwrap().as_str().unwrap();
+            assert_eq!(predicate, "id in (1, 2, 3)");
+
+            http::Response::builder().status(200).body("").unwrap()
+        });
+
+        table.delete("id in (1, 2, 3)").await.unwrap();
+    }
 }
diff --git a/rust/lancedb/src/table/merge.rs b/rust/lancedb/src/table/merge.rs
index 5c422b9d..159ede1c 100644
--- a/rust/lancedb/src/table/merge.rs
+++ b/rust/lancedb/src/table/merge.rs
@@ -26,12 +26,12 @@ use super::TableInternal;
 #[derive(Debug, Clone)]
 pub struct MergeInsertBuilder {
     table: Arc<dyn TableInternal>,
-    pub(super) on: Vec<String>,
-    pub(super) when_matched_update_all: bool,
-    pub(super) when_matched_update_all_filt: Option<String>,
-    pub(super) when_not_matched_insert_all: bool,
-    pub(super) when_not_matched_by_source_delete: bool,
-    pub(super) when_not_matched_by_source_delete_filt: Option<String>,
+    pub(crate) on: Vec<String>,
+    pub(crate) when_matched_update_all: bool,
+    pub(crate) when_matched_update_all_filt: Option<String>,
+    pub(crate) when_not_matched_insert_all: bool,
+    pub(crate) when_not_matched_by_source_delete: bool,
+    pub(crate) when_not_matched_by_source_delete_filt: Option<String>,
 }
 
 impl MergeInsertBuilder {
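Usage sketch (illustrative, not part of the patch): the endpoints wired up above are reached through the public Table API exactly as the tests exercise them. The function name, the column "a", and the filter strings are placeholders, and the crate's public Table/Result re-exports are assumed.

use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};

async fn exercise_remote_table(table: &lancedb::Table) -> lancedb::Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;

    // POST /table/{name}/insert/ with an Arrow IPC stream body;
    // append is the default, so no mode query parameter is sent.
    table
        .add(RecordBatchIterator::new([Ok(batch.clone())], schema.clone()))
        .execute()
        .await?;

    // POST /table/{name}/update/ with {"updates": ["a", "a + 1"], "only_if": "a > 10"}.
    table
        .update()
        .column("a", "a + 1")
        .only_if("a > 10")
        .execute()
        .await?;

    // POST /table/{name}/merge_insert/ with the options in the query string
    // and the new data as an Arrow IPC stream body.
    let mut merge = table.merge_insert(&["a"]);
    merge.when_matched_update_all(None);
    merge
        .execute(Box::new(RecordBatchIterator::new([Ok(batch)], schema)))
        .await?;

    // POST /table/{name}/delete/ with {"predicate": "a = 2"}.
    table.delete("a = 2").await?;

    Ok(())
}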