feat: add merge_insert to the node and rust APIs (#915)

Weston Pace
2024-02-02 13:16:51 -08:00
parent 2e75b16403
commit 18f7bad3dd
11 changed files with 565 additions and 18 deletions

View File

@@ -11,10 +11,10 @@ license = "Apache-2.0"
repository = "https://github.com/lancedb/lancedb"
[workspace.dependencies]
lance = { "version" = "=0.9.10", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.9.10" }
lance-linalg = { "version" = "=0.9.10" }
lance-testing = { "version" = "=0.9.10" }
lance = { "version" = "=0.9.12", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.9.12" }
lance-linalg = { "version" = "=0.9.12" }
lance-testing = { "version" = "=0.9.12" }
# Note that this one does not include pyarrow
arrow = { version = "50.0", optional = false }
arrow-array = "50.0"

View File

@@ -37,6 +37,7 @@ const {
tableCountRows,
tableDelete,
tableUpdate,
tableMergeInsert,
tableCleanupOldVersions,
tableCompactFiles,
tableListIndices,
@@ -440,6 +441,38 @@ export interface Table<T = number[]> {
*/
update: (args: UpdateArgs | UpdateSqlArgs) => Promise<void>
/**
* Runs a "merge insert" operation on the table
*
* This operation can add rows, update rows, and remove rows all in a single
* transaction. It is a very generic tool that can be used to create
* behaviors like "insert if not exists", "update or insert (i.e. upsert)",
* or even replace a portion of existing data with new data (e.g. replace
* all data where month="january").
*
* The merge insert operation works by combining new data from a
* **source table** with existing data in a **target table** by using a
* join. There are three categories of records.
*
* "Matched" records are records that exist in both the source table and
* the target table. "Not matched" records exist only in the source table
* (e.g. these are new data) "Not matched by source" records exist only
* in the target table (this is old data)
*
* The MergeInsertArgs can be used to customize what should happen for
* each category of data.
*
* Please note that the data may appear to be reordered as part of this
* operation. This is because updated rows will be deleted from the
* dataset and then reinserted at the end with the new values.
*
* @param on a column to join on. This is how records from the source
* table and target table are matched.
* @param data the new data to insert
* @param args parameters controlling how the operation should behave
*/
mergeInsert: (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs) => Promise<void>
/**
* List the indices on this table.
*/
@@ -483,6 +516,36 @@ export interface UpdateSqlArgs {
valuesSql: Record<string, string>
}
export interface MergeInsertArgs {
/**
* If true then rows that exist in both the source table (new data) and
* the target table (old data) will be updated, replacing the old row
* with the corresponding matching row.
*
* If there are multiple matches then the behavior is undefined.
* Currently this causes multiple copies of the row to be created
* but that behavior is subject to change.
*/
whenMatchedUpdateAll?: boolean
/**
* If true then rows that exist only in the source table (new data)
* will be inserted into the target table.
*/
whenNotMatchedInsertAll?: boolean
/**
* If true then rows that exist only in the target table (old data)
* will be deleted.
*
* If this is a string then it will be treated as an SQL filter and
* only rows that both do not match any row in the source table and
* match the given filter will be deleted.
*
* This can be used to replace a selection of existing data with
* new data.
*/
whenNotMatchedBySourceDelete?: string | boolean
}
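
// For orientation, a minimal usage sketch of the API defined above
// (illustrative only, not part of this diff; assumes a connection `db`
// from `lancedb.connect` and sample data shaped like the tests below):
//
//   const tbl = await db.createTable('people', [{ id: 1, age: 1 }, { id: 2, age: 1 }])
//   // Upsert: update matched rows, insert source-only rows
//   await tbl.mergeInsert('id', [{ id: 2, age: 2 }, { id: 3, age: 2 }], {
//     whenMatchedUpdateAll: true,
//     whenNotMatchedInsertAll: true
//   })
//   // Conditional replace: additionally delete target-only rows where age < 3
//   await tbl.mergeInsert('id', [{ id: 3, age: 3 }], {
//     whenMatchedUpdateAll: true,
//     whenNotMatchedInsertAll: true,
//     whenNotMatchedBySourceDelete: 'age < 3'
//   })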
export interface VectorIndex {
columns: string[]
name: string
@@ -821,6 +884,38 @@ export class LocalTable<T = number[]> implements Table<T> {
})
}
async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
const whenMatchedUpdateAll = args.whenMatchedUpdateAll ?? false
const whenNotMatchedInsertAll = args.whenNotMatchedInsertAll ?? false
let whenNotMatchedBySourceDelete = false
let whenNotMatchedBySourceDeleteFilt = null
if (args.whenNotMatchedBySourceDelete !== undefined && args.whenNotMatchedBySourceDelete !== null) {
whenNotMatchedBySourceDelete = true
if (args.whenNotMatchedBySourceDelete !== true) {
whenNotMatchedBySourceDeleteFilt = args.whenNotMatchedBySourceDelete
}
}
const schema = await this.schema
let tbl: ArrowTable
if (data instanceof ArrowTable) {
tbl = data
} else {
tbl = makeArrowTable(data, { schema })
}
const buffer = await fromTableToBuffer(tbl, this._embeddings, schema)
this._tbl = await tableMergeInsert.call(
this._tbl,
on,
whenMatchedUpdateAll,
whenNotMatchedInsertAll,
whenNotMatchedBySourceDelete,
whenNotMatchedBySourceDeleteFilt,
buffer
)
}
/**
* Clean up old versions of the table, freeing disk space.
*

View File

@@ -24,7 +24,8 @@ import {
type IndexStats,
type UpdateArgs,
type UpdateSqlArgs,
makeArrowTable
makeArrowTable,
type MergeInsertArgs
} from '../index'
import { Query } from '../query'
@@ -274,6 +275,52 @@ export class RemoteTable<T = number[]> implements Table<T> {
throw new Error('Not implemented')
}
async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
let tbl: ArrowTable
if (data instanceof ArrowTable) {
tbl = data
} else {
tbl = makeArrowTable(data, await this.schema)
}
const queryParams: any = {
on
}
if (args.whenMatchedUpdateAll ?? false) {
queryParams.when_matched_update_all = 'true'
} else {
queryParams.when_matched_update_all = 'false'
}
if (args.whenNotMatchedInsertAll ?? false) {
queryParams.when_not_matched_insert_all = 'true'
} else {
queryParams.when_not_matched_insert_all = 'false'
}
if (args.whenNotMatchedBySourceDelete !== false && args.whenNotMatchedBySourceDelete !== null && args.whenNotMatchedBySourceDelete !== undefined) {
queryParams.when_not_matched_by_source_delete = 'true'
if (typeof args.whenNotMatchedBySourceDelete === 'string') {
queryParams.when_not_matched_by_source_delete_filt = args.whenNotMatchedBySourceDelete
}
} else {
queryParams.when_not_matched_by_source_delete = 'false'
}
const buffer = await fromTableToStreamBuffer(tbl, this._embeddings)
const res = await this._client.post(
`/v1/table/${this._name}/merge_insert/`,
buffer,
queryParams,
'application/vnd.apache.arrow.stream'
)
if (res.status !== 200) {
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
)
}
}
async add (data: Array<Record<string, unknown>> | ArrowTable): Promise<number> {
let tbl: ArrowTable
if (data instanceof ArrowTable) {

View File

@@ -531,6 +531,44 @@ describe('LanceDB client', function () {
assert.equal(await table.countRows(), 2)
})
it('can merge insert records into the table', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const data = [{ id: 1, age: 1 }, { id: 2, age: 1 }]
const table = await con.createTable('my_table', data)
let newData = [{ id: 2, age: 2 }, { id: 3, age: 2 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true
})
assert.equal(await table.countRows(), 3)
assert.equal((await table.filter('age = 2').execute()).length, 1)
newData = [{ id: 3, age: 3 }, { id: 4, age: 3 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true,
whenMatchedUpdateAll: true
})
assert.equal(await table.countRows(), 4)
assert.equal((await table.filter('age = 3').execute()).length, 2)
newData = [{ id: 5, age: 4 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true,
whenMatchedUpdateAll: true,
whenNotMatchedBySourceDelete: 'age < 3'
})
assert.equal(await table.countRows(), 3)
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true,
whenMatchedUpdateAll: true,
whenNotMatchedBySourceDelete: true
})
assert.equal(await table.countRows(), 1)
})
it('can update records in the table', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)

View File

@@ -12,7 +12,7 @@
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING, Iterable, Optional
from typing import TYPE_CHECKING, List, Optional
if TYPE_CHECKING:
from .common import DATA
@@ -25,7 +25,7 @@ class LanceMergeInsertBuilder(object):
more context
"""
def __init__(self, table: "Table", on: Iterable[str]): # noqa: F821
def __init__(self, table: "Table", on: List[str]): # noqa: F821
# Do not put a docstring here. This method should be hidden
# from API docs. Users should use merge_insert to create
# this object.
@@ -77,10 +77,27 @@ class LanceMergeInsertBuilder(object):
self._when_not_matched_by_source_condition = condition
return self
def execute(self, new_data: DATA):
def execute(
self,
new_data: DATA,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
):
"""
Executes the merge insert operation
Nothing is returned but the [`Table`][lancedb.table.Table] is updated
Parameters
----------
new_data: DATA
New records which will be matched against the existing records
to potentially insert or update into the table. This parameter
can be anything you use for [`add`][lancedb.table.Table.add]
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contain NaNs.
One of "error", "drop", "fill".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
self._table._do_merge(self, new_data)
self._table._do_merge(self, new_data, on_bad_vectors, fill_value)
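
# A hedged sketch of how this builder is driven end to end (assumes an open
# LanceDB `table` and `new_data` in any form accepted by `add`; method names
# as defined in this module):
#
#     # Upsert: update matched rows, insert the rest, dropping malformed vectors
#     table.merge_insert("id") \
#         .when_matched_update_all() \
#         .when_not_matched_insert_all() \
#         .execute(new_data, on_bad_vectors="drop")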

View File

@@ -19,6 +19,7 @@ import pyarrow as pa
from lance import json_to_schema
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from lancedb.merge import LanceMergeInsertBuilder
from ..query import LanceVectorQueryBuilder
from ..table import Query, Table, _sanitize_data
@@ -244,9 +245,46 @@ class RemoteTable(Table):
result = self._conn._client.query(self._name, query)
return result.to_arrow()
def _do_merge(self, *_args):
"""_do_merge() is not supported on the LanceDB cloud yet"""
return NotImplementedError("_do_merge() is not supported on the LanceDB cloud")
def _do_merge(
self,
merge: LanceMergeInsertBuilder,
new_data: DATA,
on_bad_vectors: str,
fill_value: float,
):
data = _sanitize_data(
new_data,
self.schema,
metadata=None,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
payload = to_ipc_binary(data)
params = {}
if len(merge._on) != 1:
raise ValueError(
"RemoteTable only supports a single on key in merge_insert"
)
params["on"] = merge._on[0]
params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
params["when_not_matched_insert_all"] = str(
merge._when_not_matched_insert_all
).lower()
params["when_not_matched_by_source_delete"] = str(
merge._when_not_matched_by_source_delete
).lower()
if merge._when_not_matched_by_source_condition is not None:
params[
"when_not_matched_by_source_delete_filt"
] = merge._when_not_matched_by_source_condition
self._conn._client.post(
f"/v1/table/{self._name}/merge_insert/",
data=payload,
params=params,
content_type=ARROW_STREAM_CONTENT_TYPE,
)
def delete(self, predicate: str):
"""Delete rows from the table.

View File

@@ -390,6 +390,8 @@ class Table(ABC):
2 3 y
3 4 z
"""
on = [on] if isinstance(on, str) else list(on)
return LanceMergeInsertBuilder(self, on)
@abstractmethod
@@ -479,8 +481,8 @@ class Table(ABC):
self,
merge: LanceMergeInsertBuilder,
new_data: DATA,
*,
schema: Optional[pa.Schema] = None,
on_bad_vectors: str,
fill_value: float,
):
pass
@@ -1305,7 +1307,20 @@ class LanceTable(Table):
with_row_id=query.with_row_id,
)
def _do_merge(self, merge: LanceMergeInsertBuilder, new_data: DATA, *, schema=None):
def _do_merge(
self,
merge: LanceMergeInsertBuilder,
new_data: DATA,
on_bad_vectors: str,
fill_value: float,
):
new_data = _sanitize_data(
new_data,
self.schema,
metadata=self.schema.metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
ds = self.to_lance()
builder = ds.merge_insert(merge._on)
if merge._when_matched_update_all:
@@ -1315,7 +1330,7 @@ class LanceTable(Table):
if merge._when_not_matched_by_source_delete:
cond = merge._when_not_matched_by_source_condition
builder.when_not_matched_by_source_delete(cond)
builder.execute(new_data, schema=schema)
builder.execute(new_data)
def cleanup_old_versions(
self,

View File

@@ -260,6 +260,7 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> {
cx.export_function("tableCountRows", JsTable::js_count_rows)?;
cx.export_function("tableDelete", JsTable::js_delete)?;
cx.export_function("tableUpdate", JsTable::js_update)?;
cx.export_function("tableMergeInsert", JsTable::js_merge_insert)?;
cx.export_function("tableCleanupOldVersions", JsTable::js_cleanup)?;
cx.export_function("tableCompactFiles", JsTable::js_compact)?;
cx.export_function("tableListIndices", JsTable::js_list_indices)?;

View File

@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::ops::Deref;
use arrow_array::{RecordBatch, RecordBatchIterator};
use lance::dataset::optimize::CompactionOptions;
use lance::dataset::{WriteMode, WriteParams};
@@ -166,6 +168,53 @@ impl JsTable {
Ok(promise)
}
pub(crate) fn js_merge_insert(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise();
let channel = cx.channel();
let table = js_table.table.clone();
let key = cx.argument::<JsString>(0)?.value(&mut cx);
let mut builder = table.merge_insert(&[&key]);
if cx.argument::<JsBoolean>(1)?.value(&mut cx) {
builder.when_matched_update_all();
}
if cx.argument::<JsBoolean>(2)?.value(&mut cx) {
builder.when_not_matched_insert_all();
}
if cx.argument::<JsBoolean>(3)?.value(&mut cx) {
if let Some(filter) = cx.argument_opt(4) {
if filter.is_a::<JsNull, _>(&mut cx) {
builder.when_not_matched_by_source_delete(None);
} else {
let filter = filter
.downcast_or_throw::<JsString, _>(&mut cx)?
.deref()
.value(&mut cx);
builder.when_not_matched_by_source_delete(Some(filter));
}
} else {
builder.when_not_matched_by_source_delete(None);
}
}
let buffer = cx.argument::<JsBuffer>(5)?;
let (batches, schema) =
arrow_buffer_to_record_batch(buffer.as_slice(&cx)).or_throw(&mut cx)?;
rt.spawn(async move {
let new_data = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
let merge_insert_result = builder.execute(Box::new(new_data)).await;
deferred.settle_with(&channel, move |mut cx| {
merge_insert_result.or_throw(&mut cx)?;
Ok(cx.boxed(JsTable::from(table)))
})
});
Ok(promise)
}
pub(crate) fn js_update(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let table = js_table.table.clone();

View File

@@ -19,6 +19,7 @@ use std::sync::{Arc, Mutex};
use arrow_array::RecordBatchReader;
use arrow_schema::{Schema, SchemaRef};
use async_trait::async_trait;
use chrono::Duration;
use lance::dataset::builder::DatasetBuilder;
use lance::dataset::cleanup::RemovalStats;
@@ -27,6 +28,7 @@ use lance::dataset::optimize::{
};
pub use lance::dataset::ReadParams;
use lance::dataset::{Dataset, UpdateBuilder, WriteParams};
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
use lance::io::WrappingObjectStore;
use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
use log::info;
@@ -38,6 +40,10 @@ use crate::query::Query;
use crate::utils::{PatchReadParam, PatchWriteParam};
use crate::WriteMode;
use self::merge::{MergeInsert, MergeInsertBuilder};
pub mod merge;
/// Optimize the dataset.
///
/// Similar to `VACUUM` in PostgreSQL, it offers different options to
@@ -170,6 +176,71 @@ pub trait Table: std::fmt::Display + Send + Sync {
/// ```
fn create_index(&self, column: &[&str]) -> IndexBuilder;
/// Create a builder for a merge insert operation
///
/// This operation can add rows, update rows, and remove rows all in a single
/// transaction. It is a very generic tool that can be used to create
/// behaviors like "insert if not exists", "update or insert (i.e. upsert)",
/// or even replace a portion of existing data with new data (e.g. replace
/// all data where month="january").
///
/// The merge insert operation works by combining new data from a
/// **source table** with existing data in a **target table** by using a
/// join. There are three categories of records.
///
/// "Matched" records are records that exist in both the source table and
/// the target table. "Not matched" records exist only in the source table
/// (e.g. these are new data) "Not matched by source" records exist only
/// in the target table (this is old data)
///
/// The builder returned by this method can be used to customize what
/// should happen for each category of data.
///
/// Please note that the data may appear to be reordered as part of this
/// operation. This is because updated rows will be deleted from the
/// dataset and then reinserted at the end with the new values.
///
/// # Arguments
///
/// * `on` One or more columns to join on. This is how records from the
/// source table and target table are matched. Typically this is some
/// kind of key or id column.
///
/// # Examples
///
/// ```no_run
/// # use std::sync::Arc;
/// # use vectordb::connection::{Database, Connection};
/// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
/// # RecordBatchIterator, Int32Array};
/// # use arrow_schema::{Schema, Field, DataType};
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let tmpdir = tempfile::tempdir().unwrap();
/// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
/// # let tbl = db.open_table("idx_test").await.unwrap();
/// # let schema = Arc::new(Schema::new(vec![
/// # Field::new("id", DataType::Int32, false),
/// # Field::new("vector", DataType::FixedSizeList(
/// # Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
/// # ]));
/// let new_data = RecordBatchIterator::new(vec![
/// RecordBatch::try_new(schema.clone(),
/// vec![
/// Arc::new(Int32Array::from_iter_values(0..10)),
/// Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
/// (0..10).map(|_| Some(vec![Some(1.0); 128])), 128)),
/// ]).unwrap()
/// ].into_iter().map(Ok),
/// schema.clone());
/// // Perform an upsert operation
/// let mut merge_insert = tbl.merge_insert(&["id"]);
/// merge_insert.when_matched_update_all()
/// .when_not_matched_insert_all();
/// merge_insert.execute(Box::new(new_data)).await.unwrap();
/// # });
/// ```
fn merge_insert(&self, on: &[&str]) -> MergeInsertBuilder;
/// Search the table with a given query vector.
///
/// This is a convenience method for preparing an ANN query.
@@ -593,6 +664,42 @@ impl NativeTable {
}
}
#[async_trait]
impl MergeInsert for NativeTable {
async fn do_merge_insert(
&self,
params: MergeInsertBuilder,
new_data: Box<dyn RecordBatchReader + Send>,
) -> Result<()> {
let dataset = Arc::new(self.clone_inner_dataset());
let mut builder = LanceMergeInsertBuilder::try_new(dataset.clone(), params.on)?;
if params.when_matched_update_all {
builder.when_matched(lance::dataset::WhenMatched::UpdateAll);
} else {
builder.when_matched(lance::dataset::WhenMatched::DoNothing);
}
if params.when_not_matched_insert_all {
builder.when_not_matched(lance::dataset::WhenNotMatched::InsertAll);
} else {
builder.when_not_matched(lance::dataset::WhenNotMatched::DoNothing);
}
if params.when_not_matched_by_source_delete {
let behavior = if let Some(filter) = params.when_not_matched_by_source_delete_filt {
WhenNotMatchedBySource::delete_if(dataset.as_ref(), &filter)?
} else {
WhenNotMatchedBySource::Delete
};
builder.when_not_matched_by_source(behavior);
} else {
builder.when_not_matched_by_source(WhenNotMatchedBySource::Keep);
}
let job = builder.try_build()?;
let new_dataset = job.execute_reader(new_data).await?;
self.reset_dataset((*new_dataset).clone());
Ok(())
}
}
#[async_trait::async_trait]
impl Table for NativeTable {
fn as_any(&self) -> &dyn std::any::Any {
@@ -637,6 +744,11 @@ impl Table for NativeTable {
Ok(())
}
fn merge_insert(&self, on: &[&str]) -> MergeInsertBuilder {
let on = Vec::from_iter(on.iter().map(|key| key.to_string()));
MergeInsertBuilder::new(Arc::new(self.clone()), on)
}
fn create_index(&self, columns: &[&str]) -> IndexBuilder {
IndexBuilder::new(Arc::new(self.clone()), columns)
}
@@ -802,6 +914,38 @@ mod tests {
assert_eq!(table.name, "test");
}
#[tokio::test]
async fn test_merge_insert() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
// Create a dataset with i=0..10
let batches = make_test_batches_with_offset(0);
let table = NativeTable::create(&uri, "test", batches, None, None)
.await
.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
// Create new data with i=5..15
let new_batches = Box::new(make_test_batches_with_offset(5));
// Perform a "insert if not exists"
let mut merge_insert_builder = table.merge_insert(&["i"]);
merge_insert_builder.when_not_matched_insert_all();
merge_insert_builder.execute(new_batches).await.unwrap();
// Only 5 rows should actually be inserted
assert_eq!(table.count_rows().await.unwrap(), 15);
// Create new data with i=15..25 (no id matches)
let new_batches = Box::new(make_test_batches_with_offset(15));
// Perform a "bulk update" (should not affect anything)
let mut merge_insert_builder = table.merge_insert(&["i"]);
merge_insert_builder.when_matched_update_all();
merge_insert_builder.execute(new_batches).await.unwrap();
// No new rows should have been inserted
assert_eq!(table.count_rows().await.unwrap(), 15);
}
#[tokio::test]
async fn test_add_overwrite() {
let tmp_dir = tempdir().unwrap();
@@ -1148,17 +1292,25 @@ mod tests {
assert!(wrapper.called());
}
fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
fn make_test_batches_with_offset(
offset: i32,
) -> impl RecordBatchReader + Send + Sync + 'static {
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..10))],
vec![Arc::new(Int32Array::from_iter_values(
offset..(offset + 10),
))],
)],
schema,
)
}
fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
make_test_batches_with_offset(0)
}
#[tokio::test]
async fn test_create_index() {
use arrow_array::RecordBatch;

View File

@@ -0,0 +1,95 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use arrow_array::RecordBatchReader;
use async_trait::async_trait;
use crate::Result;
#[async_trait]
pub(super) trait MergeInsert: Send + Sync {
async fn do_merge_insert(
&self,
params: MergeInsertBuilder,
new_data: Box<dyn RecordBatchReader + Send>,
) -> Result<()>;
}
/// A builder used to create and run a merge insert operation
///
/// See [`super::Table::merge_insert`] for more context
pub struct MergeInsertBuilder {
table: Arc<dyn MergeInsert>,
pub(super) on: Vec<String>,
pub(super) when_matched_update_all: bool,
pub(super) when_not_matched_insert_all: bool,
pub(super) when_not_matched_by_source_delete: bool,
pub(super) when_not_matched_by_source_delete_filt: Option<String>,
}
impl MergeInsertBuilder {
pub(super) fn new(table: Arc<dyn MergeInsert>, on: Vec<String>) -> Self {
Self {
table,
on,
when_matched_update_all: false,
when_not_matched_insert_all: false,
when_not_matched_by_source_delete: false,
when_not_matched_by_source_delete_filt: None,
}
}
/// Rows that exist in both the source table (new data) and
/// the target table (old data) will be updated, replacing
/// the old row with the corresponding matching row.
///
/// If there are multiple matches then the behavior is undefined.
/// Currently this causes multiple copies of the row to be created
/// but that behavior is subject to change.
pub fn when_matched_update_all(&mut self) -> &mut Self {
self.when_matched_update_all = true;
self
}
/// Rows that exist only in the source table (new data) should
/// be inserted into the target table.
pub fn when_not_matched_insert_all(&mut self) -> &mut Self {
self.when_not_matched_insert_all = true;
self
}
/// Rows that exist only in the target table (old data) will be
/// deleted. An optional condition can be provided to limit what
/// data is deleted.
///
/// # Arguments
///
/// * `filter` - If None then all such rows will be deleted.
/// Otherwise the filter will be used as an SQL filter to
/// limit what rows are deleted.
pub fn when_not_matched_by_source_delete(&mut self, filter: Option<String>) -> &mut Self {
self.when_not_matched_by_source_delete = true;
self.when_not_matched_by_source_delete_filt = filter;
self
}
/// Executes the merge insert operation
///
/// Nothing is returned but the [`super::Table`] is updated
pub async fn execute(self, new_data: Box<dyn RecordBatchReader + Send>) -> Result<()> {
self.table.clone().do_merge_insert(self, new_data).await
}
}
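
// A compact sketch of driving this builder, including the conditional-delete
// path that the trait-level example does not show (assumes a `table`
// implementing this crate's `Table` and `new_data: Box<dyn RecordBatchReader + Send>`):
//
//     let mut builder = table.merge_insert(&["id"]);
//     builder
//         .when_matched_update_all()
//         .when_not_matched_insert_all()
//         .when_not_matched_by_source_delete(Some("age < 3".to_string()));
//     builder.execute(new_data).await?;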