diff --git a/Cargo.toml b/Cargo.toml index a4167223..6285681c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,10 +11,11 @@ license = "Apache-2.0" repository = "https://github.com/lancedb/lancedb" [workspace.dependencies] -lance = { "version" = "=0.9.10", "features" = ["dynamodb"] } -lance-index = { "version" = "=0.9.10" } -lance-linalg = { "version" = "=0.9.10" } -lance-testing = { "version" = "=0.9.10" } +lance = { "version" = "=0.9.11", "features" = ["dynamodb"] } +lance-datafusion = { "version" = "=0.9.11" } +lance-index = { "version" = "=0.9.11" } +lance-linalg = { "version" = "=0.9.11" } +lance-testing = { "version" = "=0.9.11" } # Note that this one does not include pyarrow arrow = { version = "50.0", optional = false } arrow-array = "50.0" diff --git a/rust/vectordb/Cargo.toml b/rust/vectordb/Cargo.toml index b91077c6..7a74a174 100644 --- a/rust/vectordb/Cargo.toml +++ b/rust/vectordb/Cargo.toml @@ -22,6 +22,7 @@ object_store = { workspace = true } snafu = { workspace = true } half = { workspace = true } lance = { workspace = true } +lance-datafusion = { workspace = true } lance-index = { workspace = true } lance-linalg = { workspace = true } lance-testing = { workspace = true } diff --git a/rust/vectordb/src/merge_insert.rs b/rust/vectordb/src/merge_insert.rs index ab3cb4fe..99c852b3 100644 --- a/rust/vectordb/src/merge_insert.rs +++ b/rust/vectordb/src/merge_insert.rs @@ -12,8 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::pin::Pin; +use std::sync::Arc; use arrow_array::RecordBatchReader; -use lance::dataset; +use lance::dataset::{self, WhenMatched, WhenNotMatched, WhenNotMatchedBySource}; +use lance_datafusion::utils::reader_to_stream; use crate::TableRef; use crate::error::{Error, Result}; @@ -57,7 +60,42 @@ impl MergeInsertBuilder { self } - pub async fn execute(batches: Box) -> Result<()> { + pub async fn execute( + mut self, + batches: Box, + ) -> Result<()> { + let native_table = self.table.as_native().unwrap(); // TODO no unwrap + let ds = native_table.clone_inner_dataset(); + let mut builder = dataset::MergeInsertBuilder::try_new( + Arc::new(ds), + vec!["vectors".to_string()], + ) + .unwrap(); // TODO no unwrap + + if self.when_matched_update_all { + builder.when_matched(WhenMatched::UpdateAll); + } + + if self.when_not_matched_insert_all { + builder.when_not_matched(WhenNotMatched::InsertAll); + } + + if self.when_not_matched_by_source_delete { + builder.when_not_matched_by_source(WhenNotMatchedBySource::Delete); + } + + // TODO + // if self.when_not_matched_by_source_condition { + // builder.when_not_matched_by_source(WhenNotMatchedBySource::DeleteIf(())); + // } + + let job = builder.try_build().unwrap(); // TODO no unwrap + let bitches = reader_to_stream(batches).await.unwrap().0; // TODO no unwrap + let ds2 = job.execute(bitches).await.unwrap(); // TODO no unwrap + + native_table.reset_dataset(ds2.as_ref().clone()); + + Ok(()) } } \ No newline at end of file