diff --git a/docs/src/python/python.md b/docs/src/python/python.md index 8d21d4e5..6438c8bd 100644 --- a/docs/src/python/python.md +++ b/docs/src/python/python.md @@ -58,6 +58,8 @@ pip install lancedb ::: lancedb.schema.vector +::: lancedb.merge.LanceMergeInsertBuilder + ## Integrations ### Pydantic diff --git a/python/lancedb/merge.py b/python/lancedb/merge.py new file mode 100644 index 00000000..e689513b --- /dev/null +++ b/python/lancedb/merge.py @@ -0,0 +1,86 @@ +# Copyright 2023 LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterable, Optional + +if TYPE_CHECKING: + from .common import DATA + + +class LanceMergeInsertBuilder(object): + """Builder for a LanceDB merge insert operation + + See [`merge_insert`][lancedb.table.Table.merge_insert] for + more context + """ + + def __init__(self, table: "Table", on: Iterable[str]): # noqa: F821 + # Do not put a docstring here. This method should be hidden + # from API docs. Users should use merge_insert to create + # this object. + self._table = table + self._on = on + self._when_matched_update_all = False + self._when_not_matched_insert_all = False + self._when_not_matched_by_source_delete = False + self._when_not_matched_by_source_condition = None + + def when_matched_update_all(self) -> LanceMergeInsertBuilder: + """ + Rows that exist in both the source table (new data) and + the target table (old data) will be updated, replacing + the old row with the corresponding matching row. + + If there are multiple matches then the behavior is undefined. + Currently this causes multiple copies of the row to be created + but that behavior is subject to change. + """ + self._when_matched_update_all = True + return self + + def when_not_matched_insert_all(self) -> LanceMergeInsertBuilder: + """ + Rows that exist only in the source table (new data) should + be inserted into the target table. + """ + self._when_not_matched_insert_all = True + return self + + def when_not_matched_by_source_delete( + self, condition: Optional[str] = None + ) -> LanceMergeInsertBuilder: + """ + Rows that exist only in the target table (old data) will be + deleted. An optional condition can be provided to limit what + data is deleted. + + Parameters + ---------- + condition: Optional[str], default None + If None then all such rows will be deleted. Otherwise the + condition will be used as an SQL filter to limit what rows + are deleted. + """ + self._when_not_matched_by_source_delete = True + if condition is not None: + self._when_not_matched_by_source_condition = condition + return self + + def execute(self, new_data: DATA): + """ + Executes the merge insert operation + + Nothing is returned but the [`Table`][lancedb.table.Table] is updated + """ + self._table._do_merge(self, new_data) diff --git a/python/lancedb/remote/table.py b/python/lancedb/remote/table.py index 4878bc1d..a313d2c9 100644 --- a/python/lancedb/remote/table.py +++ b/python/lancedb/remote/table.py @@ -244,6 +244,10 @@ class RemoteTable(Table): result = self._conn._client.query(self._name, query) return result.to_arrow() + def _do_merge(self, *_args): + """_do_merge() is not supported on the LanceDB cloud yet""" + return NotImplementedError("_do_merge() is not supported on the LanceDB cloud") + def delete(self, predicate: str): """Delete rows from the table. diff --git a/python/lancedb/table.py b/python/lancedb/table.py index d33c3b73..6e433425 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -28,6 +28,7 @@ from lance.vector import vec_to_table from .common import DATA, VEC, VECTOR_COLUMN_NAME from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry +from .merge import LanceMergeInsertBuilder from .pydantic import LanceModel, model_to_dict from .query import LanceQueryBuilder, Query from .util import ( @@ -334,6 +335,64 @@ class Table(ABC): """ raise NotImplementedError + def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder: + """ + Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder] + that can be used to create a "merge insert" operation + + This operation can add rows, update rows, and remove rows all in a single + transaction. It is a very generic tool that can be used to create + behaviors like "insert if not exists", "update or insert (i.e. upsert)", + or even replace a portion of existing data with new data (e.g. replace + all data where month="january") + + The merge insert operation works by combining new data from a + **source table** with existing data in a **target table** by using a + join. There are three categories of records. + + "Matched" records are records that exist in both the source table and + the target table. "Not matched" records exist only in the source table + (e.g. these are new data) "Not matched by source" records exist only + in the target table (this is old data) + + The builder returned by this method can be used to customize what + should happen for each category of data. + + Please note that the data may appear to be reordered as part of this + operation. This is because updated rows will be deleted from the + dataset and then reinserted at the end with the new values. + + Parameters + ---------- + + on: Union[str, Iterable[str]] + A column (or columns) to join on. This is how records from the + source table and target table are matched. Typically this is some + kind of key or id column. + + Examples + -------- + >>> import lancedb + >>> data = pa.table({"a": [2, 1, 3], "b": ["a", "b", "c"]}) + >>> db = lancedb.connect("./.lancedb") + >>> table = db.create_table("my_table", data) + >>> new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]}) + >>> # Perform a "upsert" operation + >>> table.merge_insert("a") \\ + ... .when_matched_update_all() \\ + ... .when_not_matched_insert_all() \\ + ... .execute(new_data) + >>> # The order of new rows is non-deterministic since we use + >>> # a hash-join as part of this operation and so we sort here + >>> table.to_arrow().sort_by("a").to_pandas() + a b + 0 1 b + 1 2 x + 2 3 y + 3 4 z + """ + return LanceMergeInsertBuilder(self, on) + @abstractmethod def search( self, @@ -414,6 +473,16 @@ class Table(ABC): def _execute_query(self, query: Query) -> pa.Table: pass + @abstractmethod + def _do_merge( + self, + merge: LanceMergeInsertBuilder, + new_data: DATA, + *, + schema: Optional[pa.Schema] = None, + ): + pass + @abstractmethod def delete(self, where: str): """Delete rows from the table. @@ -1196,6 +1265,18 @@ class LanceTable(Table): with_row_id=query.with_row_id, ) + def _do_merge(self, merge: LanceMergeInsertBuilder, new_data: DATA, *, schema=None): + ds = self.to_lance() + builder = ds.merge_insert(merge._on) + if merge._when_matched_update_all: + builder.when_matched_update_all() + if merge._when_not_matched_insert_all: + builder.when_not_matched_insert_all() + if merge._when_not_matched_by_source_delete: + cond = merge._when_not_matched_by_source_condition + builder.when_not_matched_by_source_delete(cond) + builder.execute(new_data, schema=schema) + def cleanup_old_versions( self, older_than: Optional[timedelta] = None, diff --git a/python/pyproject.toml b/python/pyproject.toml index e43b635f..cf9f9eb3 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -3,7 +3,7 @@ name = "lancedb" version = "0.5.1" dependencies = [ "deprecation", - "pylance==0.9.10", + "pylance==0.9.11", "ratelimiter~=1.0", "retry>=0.9.2", "tqdm>=4.27.0", diff --git a/python/tests/test_table.py b/python/tests/test_table.py index f5118caa..41d72da7 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -493,6 +493,62 @@ def test_update_types(db): assert actual == expected +def test_merge_insert(db): + table = LanceTable.create( + db, + "my_table", + data=pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}), + ) + assert len(table) == 3 + version = table.version + + new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]}) + + # upsert + table.merge_insert( + "a" + ).when_matched_update_all().when_not_matched_insert_all().execute(new_data) + + expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]}) + # These `sort_by` calls can be removed once lance#1892 + # is merged (it fixes the ordering) + assert table.to_arrow().sort_by("a") == expected + + table.restore(version) + + # insert-if-not-exists + table.merge_insert("a").when_not_matched_insert_all().execute(new_data) + + expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "b", "c", "z"]}) + assert table.to_arrow().sort_by("a") == expected + + table.restore(version) + + new_data = pa.table({"a": [2, 4], "b": ["x", "z"]}) + + # replace-range + table.merge_insert( + "a" + ).when_matched_update_all().when_not_matched_insert_all().when_not_matched_by_source_delete( + "a > 2" + ).execute(new_data) + + expected = pa.table({"a": [1, 2, 4], "b": ["a", "x", "z"]}) + assert table.to_arrow().sort_by("a") == expected + + table.restore(version) + + # replace-range no condition + table.merge_insert( + "a" + ).when_matched_update_all().when_not_matched_insert_all().when_not_matched_by_source_delete().execute( + new_data + ) + + expected = pa.table({"a": [2, 4], "b": ["x", "z"]}) + assert table.to_arrow().sort_by("a") == expected + + def test_create_with_embedding_function(db): class MyTable(LanceModel): text: str