mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-06 11:52:57 +00:00
[python] Temporary update feature (#457)
Combine delete and append to make a temporary update feature that is only enabled for the local python lancedb. The reason why this is temporary is because it first has to load the data that matches the where clause into memory, which is technical unbounded. --------- Co-authored-by: Chang She <chang@lancedb.com>
This commit is contained in:
@@ -678,6 +678,56 @@ class LanceTable(Table):
|
|||||||
def delete(self, where: str):
|
def delete(self, where: str):
|
||||||
self._dataset.delete(where)
|
self._dataset.delete(where)
|
||||||
|
|
||||||
|
def update(self, where: str, values: dict):
|
||||||
|
"""
|
||||||
|
EXPERIMENTAL: Update rows in the table (not threadsafe).
|
||||||
|
|
||||||
|
This can be used to update zero to all rows depending on how many
|
||||||
|
rows match the where clause.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
where: str
|
||||||
|
The SQL where clause to use when updating rows. For example, 'x = 2'
|
||||||
|
or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
|
||||||
|
values: dict
|
||||||
|
The values to update. The keys are the column names and the values
|
||||||
|
are the values to set.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> import lancedb
|
||||||
|
>>> import pandas as pd
|
||||||
|
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
|
||||||
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
|
>>> table = db.create_table("my_table", data)
|
||||||
|
>>> table.to_pandas()
|
||||||
|
x vector
|
||||||
|
0 1 [1.0, 2.0]
|
||||||
|
1 2 [3.0, 4.0]
|
||||||
|
2 3 [5.0, 6.0]
|
||||||
|
>>> table.update(where="x = 2", values={"vector": [10, 10]})
|
||||||
|
>>> table.to_pandas()
|
||||||
|
x vector
|
||||||
|
0 1 [1.0, 2.0]
|
||||||
|
1 3 [5.0, 6.0]
|
||||||
|
2 2 [10.0, 10.0]
|
||||||
|
|
||||||
|
"""
|
||||||
|
orig_data = self._dataset.to_table(filter=where).combine_chunks()
|
||||||
|
if len(orig_data) == 0:
|
||||||
|
return
|
||||||
|
for col, val in values.items():
|
||||||
|
i = orig_data.column_names.index(col)
|
||||||
|
if i < 0:
|
||||||
|
raise ValueError(f"Column {col} does not exist")
|
||||||
|
orig_data = orig_data.set_column(
|
||||||
|
i, col, pa.array([val] * len(orig_data), type=orig_data[col].type)
|
||||||
|
)
|
||||||
|
self.delete(where)
|
||||||
|
self.add(orig_data, mode="append")
|
||||||
|
self._reset_dataset()
|
||||||
|
|
||||||
def _execute_query(self, query: Query) -> pa.Table:
|
def _execute_query(self, query: Query) -> pa.Table:
|
||||||
ds = self.to_lance()
|
ds = self.to_lance()
|
||||||
return ds.to_table(
|
return ds.to_table(
|
||||||
|
|||||||
@@ -316,3 +316,35 @@ def test_merge(db, tmp_path):
|
|||||||
other_dataset = lance.write_dataset(other_table, tmp_path / "other_table.lance")
|
other_dataset = lance.write_dataset(other_table, tmp_path / "other_table.lance")
|
||||||
table.restore(1)
|
table.restore(1)
|
||||||
table.merge(other_dataset, left_on="id")
|
table.merge(other_dataset, left_on="id")
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete(db):
|
||||||
|
table = LanceTable.create(
|
||||||
|
db,
|
||||||
|
"my_table",
|
||||||
|
data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
|
||||||
|
)
|
||||||
|
assert len(table) == 2
|
||||||
|
assert len(table.list_versions()) == 1
|
||||||
|
table.delete("id=0")
|
||||||
|
assert len(table.list_versions()) == 2
|
||||||
|
assert table.version == 2
|
||||||
|
assert len(table) == 1
|
||||||
|
assert table.to_pandas()["id"].tolist() == [1]
|
||||||
|
|
||||||
|
|
||||||
|
def test_update(db):
|
||||||
|
table = LanceTable.create(
|
||||||
|
db,
|
||||||
|
"my_table",
|
||||||
|
data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
|
||||||
|
)
|
||||||
|
assert len(table) == 2
|
||||||
|
assert len(table.list_versions()) == 1
|
||||||
|
table.update(where="id=0", values={"vector": [1.1, 1.1]})
|
||||||
|
assert len(table.list_versions()) == 3
|
||||||
|
assert table.version == 3
|
||||||
|
assert len(table) == 2
|
||||||
|
v = table.to_arrow()["vector"].combine_chunks()
|
||||||
|
v = v.values.to_numpy().reshape(2, 2)
|
||||||
|
assert np.allclose(v, np.array([[1.2, 1.9], [1.1, 1.1]]))
|
||||||
|
|||||||
Reference in New Issue
Block a user