[python] Temporary update feature (#457)

Combine delete and append to make a temporary update feature that is
only enabled for the local python lancedb.

This is temporary because it first has to load the data that matches the
where clause into memory, which is technically unbounded.
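
A minimal sketch of the idea, for illustration only (the table name and data
are made up, and the real implementation is in the diff below): the matching
rows are pulled into an in-memory Arrow table, rewritten, deleted from the
dataset, and appended back.

    import lancedb
    import pyarrow as pa

    db = lancedb.connect("./.lancedb")
    table = db.create_table("update_demo", [{"x": 1}, {"x": 2}, {"x": 3}])

    # Load every row matching the filter into memory -- this is the unbounded part.
    matched = table.to_lance().to_table(filter="x = 2").combine_chunks()
    i = matched.column_names.index("x")
    matched = matched.set_column(
        i, "x", pa.array([20] * len(matched), type=matched["x"].type)
    )

    table.delete("x = 2")  # drop the old rows
    table.add(matched)     # append the rewritten rows; they land at the end of the table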

---------

Co-authored-by: Chang She <chang@lancedb.com>
Chang She authored on 2023-08-30 00:25:26 -07:00, committed by GitHub
parent 8391ffee84
commit 0cba0f4f92
2 changed files with 82 additions and 0 deletions


@@ -678,6 +678,56 @@ class LanceTable(Table):
    def delete(self, where: str):
        self._dataset.delete(where)
    def update(self, where: str, values: dict):
        """
        EXPERIMENTAL: Update rows in the table (not threadsafe).

        This can be used to update zero to all rows depending on how many
        rows match the where clause.

        Parameters
        ----------
        where: str
            The SQL where clause to use when updating rows. For example, 'x = 2'
            or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
        values: dict
            The values to update. The keys are the column names and the values
            are the values to set.

        Examples
        --------
        >>> import lancedb
        >>> import pandas as pd
        >>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
        >>> db = lancedb.connect("./.lancedb")
        >>> table = db.create_table("my_table", data)
        >>> table.to_pandas()
           x      vector
        0  1  [1.0, 2.0]
        1  2  [3.0, 4.0]
        2  3  [5.0, 6.0]
        >>> table.update(where="x = 2", values={"vector": [10, 10]})
        >>> table.to_pandas()
           x        vector
        0  1    [1.0, 2.0]
        1  3    [5.0, 6.0]
        2  2  [10.0, 10.0]
        """
        # Load all rows matching the filter into memory. This is why the
        # feature is temporary: the matched data is unbounded in size.
        orig_data = self._dataset.to_table(filter=where).combine_chunks()
        if len(orig_data) == 0:
            return
        for col, val in values.items():
            # list.index() raises ValueError on a missing column, so check
            # membership first to give a clearer error message.
            if col not in orig_data.column_names:
                raise ValueError(f"Column {col} does not exist")
            i = orig_data.column_names.index(col)
            orig_data = orig_data.set_column(
                i, col, pa.array([val] * len(orig_data), type=orig_data[col].type)
            )
        # Rewrite: delete the old rows, then append the updated copies.
        self.delete(where)
        self.add(orig_data, mode="append")
        self._reset_dataset()
    def _execute_query(self, query: Query) -> pa.Table:
        ds = self.to_lance()
        return ds.to_table(
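
The per-column rewrite builds a new Arrow array by repeating the value once per
matched row, cast to the column's existing type, and swaps it in with
set_column. A standalone pyarrow sketch of just that step (no LanceDB involved,
column name and values are illustrative):

    import pyarrow as pa

    tbl = pa.table({"x": pa.array([1, 2, 3], type=pa.int64())})
    # Repeat the new value for every row, keeping the column's original type.
    new_col = pa.array([10] * len(tbl), type=tbl["x"].type)
    tbl = tbl.set_column(tbl.column_names.index("x"), "x", new_col)
    assert tbl["x"].to_pylist() == [10, 10, 10]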


@@ -316,3 +316,35 @@ def test_merge(db, tmp_path):
    other_dataset = lance.write_dataset(other_table, tmp_path / "other_table.lance")
    table.restore(1)
    table.merge(other_dataset, left_on="id")


def test_delete(db):
    table = LanceTable.create(
        db,
        "my_table",
        data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
    )
    assert len(table) == 2
    assert len(table.list_versions()) == 1
    table.delete("id=0")
    assert len(table.list_versions()) == 2
    assert table.version == 2
    assert len(table) == 1
    assert table.to_pandas()["id"].tolist() == [1]


def test_update(db):
    table = LanceTable.create(
        db,
        "my_table",
        data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
    )
    assert len(table) == 2
    assert len(table.list_versions()) == 1
    table.update(where="id=0", values={"vector": [1.1, 1.1]})
    assert len(table.list_versions()) == 3
    assert table.version == 3
    assert len(table) == 2
    v = table.to_arrow()["vector"].combine_chunks()
    v = v.values.to_numpy().reshape(2, 2)
    assert np.allclose(v, np.array([[1.2, 1.9], [1.1, 1.1]]))
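
Because update() is a delete followed by an append, each call advances the
table by two versions and the rewritten rows end up at the end of the table,
which is what the version and vector-order assertions above are checking. A
small illustrative sketch of the same behavior, reusing the db fixture from
these tests (the table name is made up):

    def test_update_versions(db):
        table = LanceTable.create(
            db,
            "version_demo",
            data=[{"vector": [1.0, 1.0], "id": 0}, {"vector": [2.0, 2.0], "id": 1}],
        )
        assert table.version == 1
        table.update(where="id=0", values={"id": 10})
        # One new version for the delete, one for the append.
        assert table.version == 3
        # The updated row is re-appended, so it now comes after the untouched row.
        assert table.to_pandas()["id"].tolist() == [1, 10]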