diff --git a/python/lancedb/table.py b/python/lancedb/table.py index ebfd9713..6d361496 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -678,6 +678,56 @@ class LanceTable(Table): def delete(self, where: str): self._dataset.delete(where) + def update(self, where: str, values: dict): + """ + EXPERIMENTAL: Update rows in the table (not threadsafe). + + This can be used to update zero to all rows depending on how many + rows match the where clause. + + Parameters + ---------- + where: str + The SQL where clause to use when updating rows. For example, 'x = 2' + or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error. + values: dict + The values to update. The keys are the column names and the values + are the values to set. + + Examples + -------- + >>> import lancedb + >>> import pandas as pd + >>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]}) + >>> db = lancedb.connect("./.lancedb") + >>> table = db.create_table("my_table", data) + >>> table.to_pandas() + x vector + 0 1 [1.0, 2.0] + 1 2 [3.0, 4.0] + 2 3 [5.0, 6.0] + >>> table.update(where="x = 2", values={"vector": [10, 10]}) + >>> table.to_pandas() + x vector + 0 1 [1.0, 2.0] + 1 3 [5.0, 6.0] + 2 2 [10.0, 10.0] + + """ + orig_data = self._dataset.to_table(filter=where).combine_chunks() + if len(orig_data) == 0: + return + for col, val in values.items(): + i = orig_data.column_names.index(col) + if i < 0: + raise ValueError(f"Column {col} does not exist") + orig_data = orig_data.set_column( + i, col, pa.array([val] * len(orig_data), type=orig_data[col].type) + ) + self.delete(where) + self.add(orig_data, mode="append") + self._reset_dataset() + def _execute_query(self, query: Query) -> pa.Table: ds = self.to_lance() return ds.to_table( diff --git a/python/tests/test_table.py b/python/tests/test_table.py index 7d75b400..d79357da 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -316,3 +316,35 @@ def test_merge(db, tmp_path): other_dataset = lance.write_dataset(other_table, tmp_path / "other_table.lance") table.restore(1) table.merge(other_dataset, left_on="id") + + +def test_delete(db): + table = LanceTable.create( + db, + "my_table", + data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}], + ) + assert len(table) == 2 + assert len(table.list_versions()) == 1 + table.delete("id=0") + assert len(table.list_versions()) == 2 + assert table.version == 2 + assert len(table) == 1 + assert table.to_pandas()["id"].tolist() == [1] + + +def test_update(db): + table = LanceTable.create( + db, + "my_table", + data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}], + ) + assert len(table) == 2 + assert len(table.list_versions()) == 1 + table.update(where="id=0", values={"vector": [1.1, 1.1]}) + assert len(table.list_versions()) == 3 + assert table.version == 3 + assert len(table) == 2 + v = table.to_arrow()["vector"].combine_chunks() + v = v.values.to_numpy().reshape(2, 2) + assert np.allclose(v, np.array([[1.2, 1.9], [1.1, 1.1]]))