mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 14:49:57 +00:00
feat: add update to the async API (#1093)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
@@ -30,6 +30,7 @@ class Table:
|
||||
def __repr__(self) -> str: ...
|
||||
async def schema(self) -> pa.Schema: ...
|
||||
async def add(self, data: pa.RecordBatchReader, mode: str) -> None: ...
|
||||
async def update(self, updates: Dict[str, str], where: Optional[str]) -> None: ...
|
||||
async def count_rows(self, filter: Optional[str]) -> int: ...
|
||||
async def create_index(
|
||||
self, column: str, config: Optional[Index], replace: Optional[bool]
|
||||
|
||||
@@ -2206,58 +2206,57 @@ class AsyncTable:
|
||||
|
||||
async def update(
|
||||
self,
|
||||
where: Optional[str] = None,
|
||||
values: Optional[dict] = None,
|
||||
updates: Optional[Dict[str, Any]] = None,
|
||||
*,
|
||||
values_sql: Optional[Dict[str, str]] = None,
|
||||
where: Optional[str] = None,
|
||||
updates_sql: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""
|
||||
This can be used to update zero to all rows depending on how many
|
||||
rows match the where clause. If no where clause is provided, then
|
||||
all rows will be updated.
|
||||
This can be used to update zero to all rows in the table.
|
||||
|
||||
Either `values` or `values_sql` must be provided. You cannot provide
|
||||
both.
|
||||
If a filter is provided with `where` then only rows matching the
|
||||
filter will be updated. Otherwise all rows will be updated.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
updates: dict, optional
|
||||
The updates to apply. The keys should be the name of the column to
|
||||
update. The values should be the new values to assign. This is
|
||||
required unless updates_sql is supplied.
|
||||
where: str, optional
|
||||
The SQL where clause to use when updating rows. For example, 'x = 2'
|
||||
or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
|
||||
values: dict, optional
|
||||
The values to update. The keys are the column names and the values
|
||||
are the values to set.
|
||||
values_sql: dict, optional
|
||||
The values to update, expressed as SQL expression strings. These can
|
||||
reference existing columns. For example, {"x": "x + 1"} will increment
|
||||
the x column by 1.
|
||||
An SQL filter that controls which rows are updated. For example, 'x = 2'
|
||||
or 'x IN (1, 2, 3)'. Only rows that satisfy this filter will be udpated.
|
||||
updates_sql: dict, optional
|
||||
The updates to apply, expressed as SQL expression strings. The keys should
|
||||
be column names. The values should be SQL expressions. These can be SQL
|
||||
literals (e.g. "7" or "'foo'") or they can be expressions based on the
|
||||
previous value of the row (e.g. "x + 1" to increment the x column by 1)
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import asyncio
|
||||
>>> import lancedb
|
||||
>>> import pandas as pd
|
||||
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", data)
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 1 [1.0, 2.0]
|
||||
1 2 [3.0, 4.0]
|
||||
2 3 [5.0, 6.0]
|
||||
>>> table.update(where="x = 2", values={"vector": [10, 10]})
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 1 [1.0, 2.0]
|
||||
1 3 [5.0, 6.0]
|
||||
2 2 [10.0, 10.0]
|
||||
>>> table.update(values_sql={"x": "x + 1"})
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 2 [1.0, 2.0]
|
||||
1 4 [5.0, 6.0]
|
||||
2 3 [10.0, 10.0]
|
||||
>>> async def demo_update():
|
||||
... data = pd.DataFrame({"x": [1, 2], "vector": [[1, 2], [3, 4]]})
|
||||
... db = await lancedb.connect_async("./.lancedb")
|
||||
... table = await db.create_table("my_table", data)
|
||||
... # x is [1, 2], vector is [[1, 2], [3, 4]]
|
||||
... await table.update({"vector": [10, 10]}, where="x = 2")
|
||||
... # x is [1, 2], vector is [[1, 2], [10, 10]]
|
||||
... await table.update(updates_sql={"x": "x + 1"})
|
||||
... # x is [2, 3], vector is [[1, 2], [10, 10]]
|
||||
>>> asyncio.run(demo_update())
|
||||
"""
|
||||
raise NotImplementedError
|
||||
if updates is not None and updates_sql is not None:
|
||||
raise ValueError("Only one of updates or updates_sql can be provided")
|
||||
if updates is None and updates_sql is None:
|
||||
raise ValueError("Either updates or updates_sql must be provided")
|
||||
|
||||
if updates is not None:
|
||||
updates_sql = {k: value_to_sql(v) for k, v in updates.items()}
|
||||
|
||||
return await self._inner.update(updates_sql, where)
|
||||
|
||||
async def cleanup_old_versions(
|
||||
self,
|
||||
|
||||
@@ -85,6 +85,23 @@ async def test_close(db_async: AsyncConnection):
|
||||
assert str(table) == "ClosedTable(some_table)"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_async(db_async: AsyncConnection):
|
||||
table = await db_async.create_table("some_table", data=[{"id": 0}])
|
||||
assert await table.count_rows("id == 0") == 1
|
||||
assert await table.count_rows("id == 7") == 0
|
||||
await table.update({"id": 7})
|
||||
assert await table.count_rows("id == 7") == 1
|
||||
assert await table.count_rows("id == 0") == 0
|
||||
await table.add([{"id": 2}])
|
||||
await table.update(where="id % 2 == 0", updates_sql={"id": "5"})
|
||||
assert await table.count_rows("id == 7") == 1
|
||||
assert await table.count_rows("id == 2") == 0
|
||||
assert await table.count_rows("id == 5") == 1
|
||||
await table.update({"id": 10}, where="id == 5")
|
||||
assert await table.count_rows("id == 10") == 1
|
||||
|
||||
|
||||
def test_create_table(db):
|
||||
schema = pa.schema(
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user