diff --git a/Cargo.toml b/Cargo.toml index a4167223..1d505920 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,10 +11,10 @@ license = "Apache-2.0" repository = "https://github.com/lancedb/lancedb" [workspace.dependencies] -lance = { "version" = "=0.9.10", "features" = ["dynamodb"] } -lance-index = { "version" = "=0.9.10" } -lance-linalg = { "version" = "=0.9.10" } -lance-testing = { "version" = "=0.9.10" } +lance = { "version" = "=0.9.12", "features" = ["dynamodb"] } +lance-index = { "version" = "=0.9.12" } +lance-linalg = { "version" = "=0.9.12" } +lance-testing = { "version" = "=0.9.12" } # Note that this one does not include pyarrow arrow = { version = "50.0", optional = false } arrow-array = "50.0" diff --git a/node/src/index.ts b/node/src/index.ts index ed34013a..9b7831b8 100644 --- a/node/src/index.ts +++ b/node/src/index.ts @@ -37,6 +37,7 @@ const { tableCountRows, tableDelete, tableUpdate, + tableMergeInsert, tableCleanupOldVersions, tableCompactFiles, tableListIndices, @@ -440,6 +441,38 @@ export interface Table { */ update: (args: UpdateArgs | UpdateSqlArgs) => Promise + /** + * Runs a "merge insert" operation on the table + * + * This operation can add rows, update rows, and remove rows all in a single + * transaction. It is a very generic tool that can be used to create + * behaviors like "insert if not exists", "update or insert (i.e. upsert)", + * or even replace a portion of existing data with new data (e.g. replace + * all data where month="january") + * + * The merge insert operation works by combining new data from a + * **source table** with existing data in a **target table** by using a + * join. There are three categories of records. + * + * "Matched" records are records that exist in both the source table and + * the target table. "Not matched" records exist only in the source table + * (e.g. 
these are new data) "Not matched by source" records exist only + * in the target table (this is old data) + * + * The MergeInsertArgs can be used to customize what should happen for + * each category of data. + * + * Please note that the data may appear to be reordered as part of this + * operation. This is because updated rows will be deleted from the + * dataset and then reinserted at the end with the new values. + * + * @param on a column to join on. This is how records from the source + * table and target table are matched. + * @param data the new data to insert + * @param args parameters controlling how the operation should behave + */ + mergeInsert: (on: string, data: Array> | ArrowTable, args: MergeInsertArgs) => Promise + /** * List the indicies on this table. */ @@ -483,6 +516,36 @@ export interface UpdateSqlArgs { valuesSql: Record } +export interface MergeInsertArgs { + /** + * If true then rows that exist in both the source table (new data) and + * the target table (old data) will be updated, replacing the old row + * with the corresponding matching row. + * + * If there are multiple matches then the behavior is undefined. + * Currently this causes multiple copies of the row to be created + * but that behavior is subject to change. + */ + whenMatchedUpdateAll?: boolean + /** + * If true then rows that exist only in the source table (new data) + * will be inserted into the target table. + */ + whenNotMatchedInsertAll?: boolean + /** + * If true then rows that exist only in the target table (old data) + * will be deleted. + * + * If this is a string then it will be treated as an SQL filter and + * only rows that both do not match any row in the source table and + * match the given filter will be deleted. + * + * This can be used to replace a selection of existing data with + * new data. 
+ */ + whenNotMatchedBySourceDelete?: string | boolean +} + export interface VectorIndex { columns: string[] name: string @@ -821,6 +884,38 @@ export class LocalTable implements Table { }) } + async mergeInsert (on: string, data: Array> | ArrowTable, args: MergeInsertArgs): Promise { + const whenMatchedUpdateAll = args.whenMatchedUpdateAll ?? false + const whenNotMatchedInsertAll = args.whenNotMatchedInsertAll ?? false + let whenNotMatchedBySourceDelete = false + let whenNotMatchedBySourceDeleteFilt = null + if (args.whenNotMatchedBySourceDelete !== undefined && args.whenNotMatchedBySourceDelete !== null && args.whenNotMatchedBySourceDelete !== false) { + whenNotMatchedBySourceDelete = true + if (args.whenNotMatchedBySourceDelete !== true) { + whenNotMatchedBySourceDeleteFilt = args.whenNotMatchedBySourceDelete + } + } + + const schema = await this.schema + let tbl: ArrowTable + if (data instanceof ArrowTable) { + tbl = data + } else { + tbl = makeArrowTable(data, { schema }) + } + const buffer = await fromTableToBuffer(tbl, this._embeddings, schema) + + this._tbl = await tableMergeInsert.call( + this._tbl, + on, + whenMatchedUpdateAll, + whenNotMatchedInsertAll, + whenNotMatchedBySourceDelete, + whenNotMatchedBySourceDeleteFilt, + buffer + ) + } + /** * Clean up old versions of the table, freeing disk space. 
* diff --git a/node/src/remote/index.ts b/node/src/remote/index.ts index b08d9e6c..28f43581 100644 --- a/node/src/remote/index.ts +++ b/node/src/remote/index.ts @@ -24,7 +24,8 @@ import { type IndexStats, type UpdateArgs, type UpdateSqlArgs, - makeArrowTable + makeArrowTable, + type MergeInsertArgs } from '../index' import { Query } from '../query' @@ -274,6 +275,52 @@ export class RemoteTable implements Table { throw new Error('Not implemented') } + async mergeInsert (on: string, data: Array> | ArrowTable, args: MergeInsertArgs): Promise { + let tbl: ArrowTable + if (data instanceof ArrowTable) { + tbl = data + } else { + tbl = makeArrowTable(data, await this.schema) + } + + const queryParams: any = { + on + } + if (args.whenMatchedUpdateAll ?? false) { + queryParams.when_matched_update_all = 'true' + } else { + queryParams.when_matched_update_all = 'false' + } + if (args.whenNotMatchedInsertAll ?? false) { + queryParams.when_not_matched_insert_all = 'true' + } else { + queryParams.when_not_matched_insert_all = 'false' + } + if (args.whenNotMatchedBySourceDelete !== false && args.whenNotMatchedBySourceDelete !== null && args.whenNotMatchedBySourceDelete !== undefined) { + queryParams.when_not_matched_by_source_delete = 'true' + if (typeof args.whenNotMatchedBySourceDelete === 'string') { + queryParams.when_not_matched_by_source_delete_filt = args.whenNotMatchedBySourceDelete + } + } else { + queryParams.when_not_matched_by_source_delete = 'false' + } + + const buffer = await fromTableToStreamBuffer(tbl, this._embeddings) + const res = await this._client.post( + `/v1/table/${this._name}/merge_insert/`, + buffer, + queryParams, + 'application/vnd.apache.arrow.stream' + ) + if (res.status !== 200) { + throw new Error( + `Server Error, status: ${res.status}, ` + + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions + `message: ${res.statusText}: ${res.data}` + ) + } + } + async add (data: Array> | ArrowTable): Promise { let tbl: ArrowTable if 
(data instanceof ArrowTable) { diff --git a/node/src/test/test.ts b/node/src/test/test.ts index 7cf2b1ae..db87cbc6 100644 --- a/node/src/test/test.ts +++ b/node/src/test/test.ts @@ -531,6 +531,44 @@ describe('LanceDB client', function () { assert.equal(await table.countRows(), 2) }) + it('can merge insert records into the table', async function () { + const dir = await track().mkdir('lancejs') + const con = await lancedb.connect(dir) + + const data = [{ id: 1, age: 1 }, { id: 2, age: 1 }] + const table = await con.createTable('my_table', data) + + let newData = [{ id: 2, age: 2 }, { id: 3, age: 2 }] + await table.mergeInsert('id', newData, { + whenNotMatchedInsertAll: true + }) + assert.equal(await table.countRows(), 3) + assert.equal((await table.filter('age = 2').execute()).length, 1) + + newData = [{ id: 3, age: 3 }, { id: 4, age: 3 }] + await table.mergeInsert('id', newData, { + whenNotMatchedInsertAll: true, + whenMatchedUpdateAll: true + }) + assert.equal(await table.countRows(), 4) + assert.equal((await table.filter('age = 3').execute()).length, 2) + + newData = [{ id: 5, age: 4 }] + await table.mergeInsert('id', newData, { + whenNotMatchedInsertAll: true, + whenMatchedUpdateAll: true, + whenNotMatchedBySourceDelete: 'age < 3' + }) + assert.equal(await table.countRows(), 3) + + await table.mergeInsert('id', newData, { + whenNotMatchedInsertAll: true, + whenMatchedUpdateAll: true, + whenNotMatchedBySourceDelete: true + }) + assert.equal(await table.countRows(), 1) + }) + it('can update records in the table', async function () { const uri = await createTestDB() const con = await lancedb.connect(uri) diff --git a/python/lancedb/merge.py b/python/lancedb/merge.py index e689513b..9180635a 100644 --- a/python/lancedb/merge.py +++ b/python/lancedb/merge.py @@ -12,7 +12,7 @@ # limitations under the License. 
from __future__ import annotations -from typing import TYPE_CHECKING, Iterable, Optional +from typing import TYPE_CHECKING, List, Optional if TYPE_CHECKING: from .common import DATA @@ -25,7 +25,7 @@ class LanceMergeInsertBuilder(object): more context """ - def __init__(self, table: "Table", on: Iterable[str]): # noqa: F821 + def __init__(self, table: "Table", on: List[str]): # noqa: F821 # Do not put a docstring here. This method should be hidden # from API docs. Users should use merge_insert to create # this object. @@ -77,10 +77,27 @@ class LanceMergeInsertBuilder(object): self._when_not_matched_by_source_condition = condition return self - def execute(self, new_data: DATA): + def execute( + self, + new_data: DATA, + on_bad_vectors: str = "error", + fill_value: float = 0.0, + ): """ Executes the merge insert operation Nothing is returned but the [`Table`][lancedb.table.Table] is updated + + Parameters + ---------- + new_data: DATA + New records which will be matched against the existing records + to potentially insert or update into the table. This parameter + can be anything you use for [`add`][lancedb.table.Table.add] + on_bad_vectors: str, default "error" + What to do if any of the vectors are not the same size or contains NaNs. + One of "error", "drop", "fill". + fill_value: float, default 0. + The value to use when filling vectors. Only used if on_bad_vectors="fill". 
""" - self._table._do_merge(self, new_data) + self._table._do_merge(self, new_data, on_bad_vectors, fill_value) diff --git a/python/lancedb/remote/table.py b/python/lancedb/remote/table.py index e751bcbb..9815dda6 100644 --- a/python/lancedb/remote/table.py +++ b/python/lancedb/remote/table.py @@ -19,6 +19,7 @@ import pyarrow as pa from lance import json_to_schema from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME +from lancedb.merge import LanceMergeInsertBuilder from ..query import LanceVectorQueryBuilder from ..table import Query, Table, _sanitize_data @@ -244,9 +245,46 @@ class RemoteTable(Table): result = self._conn._client.query(self._name, query) return result.to_arrow() - def _do_merge(self, *_args): - """_do_merge() is not supported on the LanceDB cloud yet""" - return NotImplementedError("_do_merge() is not supported on the LanceDB cloud") + def _do_merge( + self, + merge: LanceMergeInsertBuilder, + new_data: DATA, + on_bad_vectors: str, + fill_value: float, + ): + data = _sanitize_data( + new_data, + self.schema, + metadata=None, + on_bad_vectors=on_bad_vectors, + fill_value=fill_value, + ) + payload = to_ipc_binary(data) + + params = {} + if len(merge._on) != 1: + raise ValueError( + "RemoteTable only supports a single on key in merge_insert" + ) + params["on"] = merge._on[0] + params["when_matched_update_all"] = str(merge._when_matched_update_all).lower() + params["when_not_matched_insert_all"] = str( + merge._when_not_matched_insert_all + ).lower() + params["when_not_matched_by_source_delete"] = str( + merge._when_not_matched_by_source_delete + ).lower() + if merge._when_not_matched_by_source_condition is not None: + params[ + "when_not_matched_by_source_delete_filt" + ] = merge._when_not_matched_by_source_condition + + self._conn._client.post( + f"/v1/table/{self._name}/merge_insert/", + data=payload, + params=params, + content_type=ARROW_STREAM_CONTENT_TYPE, + ) def delete(self, predicate: str): """Delete rows from the table. 
diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 44decda7..a589a298 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -390,6 +390,8 @@ class Table(ABC): 2 3 y 3 4 z """ + on = [on] if isinstance(on, str) else list(on) + return LanceMergeInsertBuilder(self, on) @abstractmethod @@ -479,8 +481,8 @@ class Table(ABC): self, merge: LanceMergeInsertBuilder, new_data: DATA, - *, - schema: Optional[pa.Schema] = None, + on_bad_vectors: str, + fill_value: float, ): pass @@ -1305,7 +1307,20 @@ class LanceTable(Table): with_row_id=query.with_row_id, ) - def _do_merge(self, merge: LanceMergeInsertBuilder, new_data: DATA, *, schema=None): + def _do_merge( + self, + merge: LanceMergeInsertBuilder, + new_data: DATA, + on_bad_vectors: str, + fill_value: float, + ): + new_data = _sanitize_data( + new_data, + self.schema, + metadata=self.schema.metadata, + on_bad_vectors=on_bad_vectors, + fill_value=fill_value, + ) ds = self.to_lance() builder = ds.merge_insert(merge._on) if merge._when_matched_update_all: @@ -1315,7 +1330,7 @@ class LanceTable(Table): if merge._when_not_matched_by_source_delete: cond = merge._when_not_matched_by_source_condition builder.when_not_matched_by_source_delete(cond) - builder.execute(new_data, schema=schema) + builder.execute(new_data) def cleanup_old_versions( self, diff --git a/rust/ffi/node/src/lib.rs b/rust/ffi/node/src/lib.rs index 4fed6e28..41212030 100644 --- a/rust/ffi/node/src/lib.rs +++ b/rust/ffi/node/src/lib.rs @@ -260,6 +260,7 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> { cx.export_function("tableCountRows", JsTable::js_count_rows)?; cx.export_function("tableDelete", JsTable::js_delete)?; cx.export_function("tableUpdate", JsTable::js_update)?; + cx.export_function("tableMergeInsert", JsTable::js_merge_insert)?; cx.export_function("tableCleanupOldVersions", JsTable::js_cleanup)?; cx.export_function("tableCompactFiles", JsTable::js_compact)?; cx.export_function("tableListIndices", 
JsTable::js_list_indices)?; diff --git a/rust/ffi/node/src/table.rs b/rust/ffi/node/src/table.rs index ae01bb75..9f6f9f6f 100644 --- a/rust/ffi/node/src/table.rs +++ b/rust/ffi/node/src/table.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::ops::Deref; + use arrow_array::{RecordBatch, RecordBatchIterator}; use lance::dataset::optimize::CompactionOptions; use lance::dataset::{WriteMode, WriteParams}; @@ -166,6 +168,53 @@ impl JsTable { Ok(promise) } + pub(crate) fn js_merge_insert(mut cx: FunctionContext) -> JsResult { + let js_table = cx.this().downcast_or_throw::, _>(&mut cx)?; + let rt = runtime(&mut cx)?; + let (deferred, promise) = cx.promise(); + let channel = cx.channel(); + let table = js_table.table.clone(); + + let key = cx.argument::(0)?.value(&mut cx); + let mut builder = table.merge_insert(&[&key]); + if cx.argument::(1)?.value(&mut cx) { + builder.when_matched_update_all(); + } + if cx.argument::(2)?.value(&mut cx) { + builder.when_not_matched_insert_all(); + } + if cx.argument::(3)?.value(&mut cx) { + if let Some(filter) = cx.argument_opt(4) { + if filter.is_a::(&mut cx) { + builder.when_not_matched_by_source_delete(None); + } else { + let filter = filter + .downcast_or_throw::(&mut cx)? 
+ .deref() + .value(&mut cx); + builder.when_not_matched_by_source_delete(Some(filter)); + } + } else { + builder.when_not_matched_by_source_delete(None); + } + } + + let buffer = cx.argument::(5)?; + let (batches, schema) = + arrow_buffer_to_record_batch(buffer.as_slice(&cx)).or_throw(&mut cx)?; + + rt.spawn(async move { + let new_data = RecordBatchIterator::new(batches.into_iter().map(Ok), schema); + let merge_insert_result = builder.execute(Box::new(new_data)).await; + + deferred.settle_with(&channel, move |mut cx| { + merge_insert_result.or_throw(&mut cx)?; + Ok(cx.boxed(JsTable::from(table))) + }) + }); + Ok(promise) + } + pub(crate) fn js_update(mut cx: FunctionContext) -> JsResult { let js_table = cx.this().downcast_or_throw::, _>(&mut cx)?; let table = js_table.table.clone(); diff --git a/rust/vectordb/src/table.rs b/rust/vectordb/src/table.rs index 49035fca..d4a3ae28 100644 --- a/rust/vectordb/src/table.rs +++ b/rust/vectordb/src/table.rs @@ -19,6 +19,7 @@ use std::sync::{Arc, Mutex}; use arrow_array::RecordBatchReader; use arrow_schema::{Schema, SchemaRef}; +use async_trait::async_trait; use chrono::Duration; use lance::dataset::builder::DatasetBuilder; use lance::dataset::cleanup::RemovalStats; @@ -27,6 +28,7 @@ use lance::dataset::optimize::{ }; pub use lance::dataset::ReadParams; use lance::dataset::{Dataset, UpdateBuilder, WriteParams}; +use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource}; use lance::io::WrappingObjectStore; use lance_index::{optimize::OptimizeOptions, DatasetIndexExt}; use log::info; @@ -38,6 +40,10 @@ use crate::query::Query; use crate::utils::{PatchReadParam, PatchWriteParam}; use crate::WriteMode; +use self::merge::{MergeInsert, MergeInsertBuilder}; + +pub mod merge; + /// Optimize the dataset. 
/// /// Similar to `VACUUM` in PostgreSQL, it offers different options to @@ -170,6 +176,71 @@ pub trait Table: std::fmt::Display + Send + Sync { /// ``` fn create_index(&self, column: &[&str]) -> IndexBuilder; + /// Create a builder for a merge insert operation + /// + /// This operation can add rows, update rows, and remove rows all in a single + /// transaction. It is a very generic tool that can be used to create + /// behaviors like "insert if not exists", "update or insert (i.e. upsert)", + /// or even replace a portion of existing data with new data (e.g. replace + /// all data where month="january"). + /// + /// The merge insert operation works by combining new data from a + /// **source table** with existing data in a **target table** by using a + /// join. There are three categories of records. + /// + /// "Matched" records are records that exist in both the source table and + /// the target table. "Not matched" records exist only in the source table + /// (i.e. this is new data). "Not matched by source" records exist only + /// in the target table (i.e. this is old data). + /// + /// The builder returned by this method can be used to customize what + /// should happen for each category of data. + /// + /// Please note that the data may appear to be reordered as part of this + /// operation. This is because updated rows will be deleted from the + /// dataset and then reinserted at the end with the new values. + /// + /// # Arguments + /// + /// * `on` - One or more columns to join on. This is how records from the + /// source table and target table are matched. Typically this is some + /// kind of key or id column. 
+ /// + /// # Examples + /// + /// ```no_run + /// # use std::sync::Arc; + /// # use vectordb::connection::{Database, Connection}; + /// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch, + /// # RecordBatchIterator, Int32Array}; + /// # use arrow_schema::{Schema, Field, DataType}; + /// # tokio::runtime::Runtime::new().unwrap().block_on(async { + /// let tmpdir = tempfile::tempdir().unwrap(); + /// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap(); + /// # let tbl = db.open_table("idx_test").await.unwrap(); + /// # let schema = Arc::new(Schema::new(vec![ + /// # Field::new("id", DataType::Int32, false), + /// # Field::new("vector", DataType::FixedSizeList( + /// # Arc::new(Field::new("item", DataType::Float32, true)), 128), true), + /// # ])); + /// let new_data = RecordBatchIterator::new(vec![ + /// RecordBatch::try_new(schema.clone(), + /// vec![ + /// Arc::new(Int32Array::from_iter_values(0..10)), + /// Arc::new(FixedSizeListArray::from_iter_primitive::( + /// (0..10).map(|_| Some(vec![Some(1.0); 128])), 128)), + /// ]).unwrap() + /// ].into_iter().map(Ok), + /// schema.clone()); + /// // Perform an upsert operation + /// let mut merge_insert = tbl.merge_insert(&["id"]); + /// merge_insert.when_matched_update_all() + /// .when_not_matched_insert_all(); + /// merge_insert.execute(Box::new(new_data)).await.unwrap(); + /// # }); + /// ``` + fn merge_insert(&self, on: &[&str]) -> MergeInsertBuilder; + /// Search the table with a given query vector. /// /// This is a convenience method for preparing an ANN query. 
@@ -593,6 +664,42 @@ impl NativeTable { } } +#[async_trait] +impl MergeInsert for NativeTable { + async fn do_merge_insert( + &self, + params: MergeInsertBuilder, + new_data: Box, + ) -> Result<()> { + let dataset = Arc::new(self.clone_inner_dataset()); + let mut builder = LanceMergeInsertBuilder::try_new(dataset.clone(), params.on)?; + if params.when_matched_update_all { + builder.when_matched(lance::dataset::WhenMatched::UpdateAll); + } else { + builder.when_matched(lance::dataset::WhenMatched::DoNothing); + } + if params.when_not_matched_insert_all { + builder.when_not_matched(lance::dataset::WhenNotMatched::InsertAll); + } else { + builder.when_not_matched(lance::dataset::WhenNotMatched::DoNothing); + } + if params.when_not_matched_by_source_delete { + let behavior = if let Some(filter) = params.when_not_matched_by_source_delete_filt { + WhenNotMatchedBySource::delete_if(dataset.as_ref(), &filter)? + } else { + WhenNotMatchedBySource::Delete + }; + builder.when_not_matched_by_source(behavior); + } else { + builder.when_not_matched_by_source(WhenNotMatchedBySource::Keep); + } + let job = builder.try_build()?; + let new_dataset = job.execute_reader(new_data).await?; + self.reset_dataset((*new_dataset).clone()); + Ok(()) + } +} + #[async_trait::async_trait] impl Table for NativeTable { fn as_any(&self) -> &dyn std::any::Any { @@ -637,6 +744,11 @@ impl Table for NativeTable { Ok(()) } + fn merge_insert(&self, on: &[&str]) -> MergeInsertBuilder { + let on = Vec::from_iter(on.iter().map(|key| key.to_string())); + MergeInsertBuilder::new(Arc::new(self.clone()), on) + } + fn create_index(&self, columns: &[&str]) -> IndexBuilder { IndexBuilder::new(Arc::new(self.clone()), columns) } @@ -802,6 +914,38 @@ mod tests { assert_eq!(table.name, "test"); } + #[tokio::test] + async fn test_merge_insert() { + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + + // Create a dataset with i=0..10 + let batches = make_test_batches_with_offset(0); + let 
table = NativeTable::create(&uri, "test", batches, None, None) + .await + .unwrap(); + assert_eq!(table.count_rows().await.unwrap(), 10); + + // Create new data with i=5..15 + let new_batches = Box::new(make_test_batches_with_offset(5)); + + // Perform a "insert if not exists" + let mut merge_insert_builder = table.merge_insert(&["i"]); + merge_insert_builder.when_not_matched_insert_all(); + merge_insert_builder.execute(new_batches).await.unwrap(); + // Only 5 rows should actually be inserted + assert_eq!(table.count_rows().await.unwrap(), 15); + + // Create new data with i=15..25 (no id matches) + let new_batches = Box::new(make_test_batches_with_offset(15)); + // Perform a "bulk update" (should not affect anything) + let mut merge_insert_builder = table.merge_insert(&["i"]); + merge_insert_builder.when_matched_update_all(); + merge_insert_builder.execute(new_batches).await.unwrap(); + // No new rows should have been inserted + assert_eq!(table.count_rows().await.unwrap(), 15); + } + #[tokio::test] async fn test_add_overwrite() { let tmp_dir = tempdir().unwrap(); @@ -1148,17 +1292,25 @@ mod tests { assert!(wrapper.called()); } - fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static { + fn make_test_batches_with_offset( + offset: i32, + ) -> impl RecordBatchReader + Send + Sync + 'static { let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)])); RecordBatchIterator::new( vec![RecordBatch::try_new( schema.clone(), - vec![Arc::new(Int32Array::from_iter_values(0..10))], + vec![Arc::new(Int32Array::from_iter_values( + offset..(offset + 10), + ))], )], schema, ) } + fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static { + make_test_batches_with_offset(0) + } + #[tokio::test] async fn test_create_index() { use arrow_array::RecordBatch; diff --git a/rust/vectordb/src/table/merge.rs b/rust/vectordb/src/table/merge.rs new file mode 100644 index 00000000..26caa235 --- /dev/null +++ 
b/rust/vectordb/src/table/merge.rs @@ -0,0 +1,95 @@ +// Copyright 2024 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use arrow_array::RecordBatchReader; +use async_trait::async_trait; + +use crate::Result; + +#[async_trait] +pub(super) trait MergeInsert: Send + Sync { + async fn do_merge_insert( + &self, + params: MergeInsertBuilder, + new_data: Box, + ) -> Result<()>; +} + +/// A builder used to create and run a merge insert operation +/// +/// See [`super::Table::merge_insert`] for more context +pub struct MergeInsertBuilder { + table: Arc, + pub(super) on: Vec, + pub(super) when_matched_update_all: bool, + pub(super) when_not_matched_insert_all: bool, + pub(super) when_not_matched_by_source_delete: bool, + pub(super) when_not_matched_by_source_delete_filt: Option, +} + +impl MergeInsertBuilder { + pub(super) fn new(table: Arc, on: Vec) -> Self { + Self { + table, + on, + when_matched_update_all: false, + when_not_matched_insert_all: false, + when_not_matched_by_source_delete: false, + when_not_matched_by_source_delete_filt: None, + } + } + + /// Rows that exist in both the source table (new data) and + /// the target table (old data) will be updated, replacing + /// the old row with the corresponding matching row. + /// + /// If there are multiple matches then the behavior is undefined. + /// Currently this causes multiple copies of the row to be created + /// but that behavior is subject to change. 
+ pub fn when_matched_update_all(&mut self) -> &mut Self { + self.when_matched_update_all = true; + self + } + + /// Rows that exist only in the source table (new data) should + /// be inserted into the target table. + pub fn when_not_matched_insert_all(&mut self) -> &mut Self { + self.when_not_matched_insert_all = true; + self + } + + /// Rows that exist only in the target table (old data) will be + /// deleted. An optional filter can be provided to limit what + /// data is deleted. + /// + /// # Arguments + /// + /// * `filter` - If None then all such rows will be deleted. + /// Otherwise it will be used as an SQL filter to + /// limit what rows are deleted. + pub fn when_not_matched_by_source_delete(&mut self, filter: Option) -> &mut Self { + self.when_not_matched_by_source_delete = true; + self.when_not_matched_by_source_delete_filt = filter; + self + } + + /// Executes the merge insert operation + /// + /// Nothing is returned but the [`super::Table`] is updated + pub async fn execute(self, new_data: Box) -> Result<()> { + self.table.clone().do_merge_insert(self, new_data).await + } +}