From 0cba0f4f92d9bf19d705000fa34ba71f3ee476ae Mon Sep 17 00:00:00 2001 From: Chang She <759245+changhiskhan@users.noreply.github.com> Date: Wed, 30 Aug 2023 00:25:26 -0700 Subject: [PATCH] [python] Temporary update feature (#457) Combine delete and append to make a temporary update feature that is only enabled for the local python lancedb. The reason why this is temporary is because it first has to load the data that matches the where clause into memory, which is technical unbounded. --------- Co-authored-by: Chang She --- python/lancedb/table.py | 50 ++++++++++++++++++++++++++++++++++++++ python/tests/test_table.py | 32 ++++++++++++++++++++++++ 2 files changed, 82 insertions(+) diff --git a/python/lancedb/table.py b/python/lancedb/table.py index ebfd9713..6d361496 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -678,6 +678,56 @@ class LanceTable(Table): def delete(self, where: str): self._dataset.delete(where) + def update(self, where: str, values: dict): + """ + EXPERIMENTAL: Update rows in the table (not threadsafe). + + This can be used to update zero to all rows depending on how many + rows match the where clause. + + Parameters + ---------- + where: str + The SQL where clause to use when updating rows. For example, 'x = 2' + or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error. + values: dict + The values to update. The keys are the column names and the values + are the values to set. + + Examples + -------- + >>> import lancedb + >>> import pandas as pd + >>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]}) + >>> db = lancedb.connect("./.lancedb") + >>> table = db.create_table("my_table", data) + >>> table.to_pandas() + x vector + 0 1 [1.0, 2.0] + 1 2 [3.0, 4.0] + 2 3 [5.0, 6.0] + >>> table.update(where="x = 2", values={"vector": [10, 10]}) + >>> table.to_pandas() + x vector + 0 1 [1.0, 2.0] + 1 3 [5.0, 6.0] + 2 2 [10.0, 10.0] + + """ + orig_data = self._dataset.to_table(filter=where).combine_chunks() + if len(orig_data) == 0: + return + for col, val in values.items(): + i = orig_data.column_names.index(col) + if i < 0: + raise ValueError(f"Column {col} does not exist") + orig_data = orig_data.set_column( + i, col, pa.array([val] * len(orig_data), type=orig_data[col].type) + ) + self.delete(where) + self.add(orig_data, mode="append") + self._reset_dataset() + def _execute_query(self, query: Query) -> pa.Table: ds = self.to_lance() return ds.to_table( diff --git a/python/tests/test_table.py b/python/tests/test_table.py index 7d75b400..d79357da 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -316,3 +316,35 @@ def test_merge(db, tmp_path): other_dataset = lance.write_dataset(other_table, tmp_path / "other_table.lance") table.restore(1) table.merge(other_dataset, left_on="id") + + +def test_delete(db): + table = LanceTable.create( + db, + "my_table", + data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}], + ) + assert len(table) == 2 + assert len(table.list_versions()) == 1 + table.delete("id=0") + assert len(table.list_versions()) == 2 + assert table.version == 2 + assert len(table) == 1 + assert table.to_pandas()["id"].tolist() == [1] + + +def test_update(db): + table = LanceTable.create( + db, + "my_table", + data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}], + ) + assert len(table) == 2 + assert len(table.list_versions()) == 1 + table.update(where="id=0", values={"vector": [1.1, 1.1]}) + assert len(table.list_versions()) == 3 + assert table.version == 3 + assert len(table) == 2 + v = table.to_arrow()["vector"].combine_chunks() + v = v.values.to_numpy().reshape(2, 2) + assert np.allclose(v, np.array([[1.2, 1.9], [1.1, 1.1]]))