diff --git a/python/python/lancedb/merge.py b/python/python/lancedb/merge.py index 69671c5e..48cc9847 100644 --- a/python/python/lancedb/merge.py +++ b/python/python/lancedb/merge.py @@ -104,4 +104,4 @@ class LanceMergeInsertBuilder(object): fill_value: float, default 0. The value to use when filling vectors. Only used if on_bad_vectors="fill". """ - self._table._do_merge(self, new_data, on_bad_vectors, fill_value) + return self._table._do_merge(self, new_data, on_bad_vectors, fill_value) diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 59e0d465..b4f8a2e6 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -2464,7 +2464,31 @@ class AsyncTable: on_bad_vectors: str, fill_value: float, ): - pass + schema = await self.schema() + if on_bad_vectors is None: + on_bad_vectors = "error" + if fill_value is None: + fill_value = 0.0 + data, _ = _sanitize_data( + new_data, + schema, + metadata=schema.metadata, + on_bad_vectors=on_bad_vectors, + fill_value=fill_value, + ) + if isinstance(data, pa.Table): + data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches()) + await self._inner.execute_merge_insert( + data, + dict( + on=merge._on, + when_matched_update_all=merge._when_matched_update_all, + when_matched_update_all_condition=merge._when_matched_update_all_condition, + when_not_matched_insert_all=merge._when_not_matched_insert_all, + when_not_matched_by_source_delete=merge._when_not_matched_by_source_delete, + when_not_matched_by_source_condition=merge._when_not_matched_by_source_condition, + ), + ) async def delete(self, where: str): """Delete rows from the table. diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index c32a5c98..cc5ecbd2 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -636,11 +636,13 @@ def test_merge_insert(db): new_data = pa.table({"a": [2, 4], "b": ["x", "z"]}) # replace-range - table.merge_insert( - "a" - ).when_matched_update_all().when_not_matched_insert_all().when_not_matched_by_source_delete( - "a > 2" - ).execute(new_data) + ( + table.merge_insert("a") + .when_matched_update_all() + .when_not_matched_insert_all() + .when_not_matched_by_source_delete("a > 2") + .execute(new_data) + ) expected = pa.table({"a": [1, 2, 4], "b": ["a", "x", "z"]}) assert table.to_arrow().sort_by("a") == expected @@ -658,6 +660,75 @@ def test_merge_insert(db): assert table.to_arrow().sort_by("a") == expected +@pytest.mark.asyncio +async def test_merge_insert_async(db_async: AsyncConnection): + data = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + table = await db_async.create_table("some_table", data=data) + assert await table.count_rows() == 3 + version = await table.version() + + new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]}) + + # upsert + await ( + table.merge_insert("a") + .when_matched_update_all() + .when_not_matched_insert_all() + .execute(new_data) + ) + expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]}) + assert (await table.to_arrow()).sort_by("a") == expected + + await table.checkout(version) + await table.restore() + + # conditional update + await ( + table.merge_insert("a") + .when_matched_update_all(where="target.b = 'b'") + .execute(new_data) + ) + expected = pa.table({"a": [1, 2, 3], "b": ["a", "x", "c"]}) + assert (await table.to_arrow()).sort_by("a") == expected + + await table.checkout(version) + await table.restore() + + # insert-if-not-exists + await table.merge_insert("a").when_not_matched_insert_all().execute(new_data) + expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "b", "c", "z"]}) + assert (await table.to_arrow()).sort_by("a") == expected + + await table.checkout(version) + await table.restore() + + # replace-range + new_data = pa.table({"a": [2, 4], "b": ["x", "z"]}) + await ( + table.merge_insert("a") + .when_matched_update_all() + .when_not_matched_insert_all() + .when_not_matched_by_source_delete("a > 2") + .execute(new_data) + ) + expected = pa.table({"a": [1, 2, 4], "b": ["a", "x", "z"]}) + assert (await table.to_arrow()).sort_by("a") == expected + + await table.checkout(version) + await table.restore() + + # replace-range no condition + await ( + table.merge_insert("a") + .when_matched_update_all() + .when_not_matched_insert_all() + .when_not_matched_by_source_delete() + .execute(new_data) + ) + expected = pa.table({"a": [2, 4], "b": ["x", "z"]}) + assert (await table.to_arrow()).sort_by("a") == expected + + def test_create_with_embedding_function(db): class MyTable(LanceModel): text: str diff --git a/python/src/table.rs b/python/src/table.rs index 957bf76f..b5087f8d 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -9,7 +9,7 @@ use pyo3::{ exceptions::{PyRuntimeError, PyValueError}, pyclass, pymethods, types::{PyDict, PyDictMethods, PyString}, - Bound, PyAny, PyRef, PyResult, Python, ToPyObject, + Bound, FromPyObject, PyAny, PyRef, PyResult, Python, ToPyObject, }; use pyo3_asyncio_0_21::tokio::future_into_py; @@ -331,6 +331,31 @@ impl Table { }) } + pub fn execute_merge_insert<'a>( + self_: PyRef<'a, Self>, + data: Bound<'a, PyAny>, + parameters: MergeInsertParams, + ) -> PyResult> { + let batches: ArrowArrayStreamReader = ArrowArrayStreamReader::from_pyarrow_bound(&data)?; + let on = parameters.on.iter().map(|s| s.as_str()).collect::>(); + let mut builder = self_.inner_ref()?.merge_insert(&on); + if parameters.when_matched_update_all { + builder.when_matched_update_all(parameters.when_matched_update_all_condition); + } + if parameters.when_not_matched_insert_all { + builder.when_not_matched_insert_all(); + } + if parameters.when_not_matched_by_source_delete { + builder + .when_not_matched_by_source_delete(parameters.when_not_matched_by_source_condition); + } + + future_into_py(self_.py(), async move { + builder.execute(Box::new(batches)).await.infer_error()?; + Ok(()) + }) + } + pub fn uses_v2_manifest_paths(self_: PyRef<'_, Self>) -> PyResult> { let inner = self_.inner_ref()?.clone(); future_into_py(self_.py(), async move { @@ -355,3 +380,14 @@ impl Table { }) } } + +#[derive(FromPyObject)] +#[pyo3(from_item_all)] +pub struct MergeInsertParams { + on: Vec, + when_matched_update_all: bool, + when_matched_update_all_condition: Option, + when_not_matched_insert_all: bool, + when_not_matched_by_source_delete: bool, + when_not_matched_by_source_condition: Option, +}