diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 6411cd33..60aac438 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -17,7 +17,7 @@ import inspect import os from abc import ABC, abstractmethod from functools import cached_property -from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union import lance import numpy as np @@ -30,7 +30,7 @@ from .common import DATA, VEC, VECTOR_COLUMN_NAME from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry from .pydantic import LanceModel from .query import LanceQueryBuilder, Query -from .util import fs_from_uri, safe_import_pandas +from .util import fs_from_uri, safe_import_pandas, value_to_sql from .utils.events import register_event if TYPE_CHECKING: @@ -913,30 +913,35 @@ class LanceTable(Table): def delete(self, where: str): self._dataset.delete(where) - def update(self, where: str, values: dict): + def update( + self, + where: Optional[str] = None, + values: Optional[dict] = None, + *, + values_sql: Optional[Dict[str, str]] = None, + ): """ - EXPERIMENTAL: Update rows in the table (not threadsafe). - This can be used to update zero to all rows depending on how many rows match the where clause. Parameters ---------- - where: str + where: str, optional The SQL where clause to use when updating rows. For example, 'x = 2' or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error. - values: dict + values: dict, optional The values to update. The keys are the column names and the values are the values to set. + values_sql: dict, optional + The values to update, expressed as SQL expression strings. These can + reference existing columns. For example, {"x": "x + 1"} will increment + the x column by 1. Examples -------- >>> import lancedb - >>> data = [ - ... {"x": 1, "vector": [1, 2]}, - ... {"x": 2, "vector": [3, 4]}, - ... {"x": 3, "vector": [5, 6]} - ... 
] + >>> import pandas as pd + >>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]}) >>> db = lancedb.connect("./.lancedb") >>> table = db.create_table("my_table", data) >>> table.to_pandas() @@ -952,18 +957,15 @@ class LanceTable(Table): 2 2 [10.0, 10.0] """ - orig_data = self._dataset.to_table(filter=where).combine_chunks() - if len(orig_data) == 0: - return - for col, val in values.items(): - i = orig_data.column_names.index(col) - if i < 0: - raise ValueError(f"Column {col} does not exist") - orig_data = orig_data.set_column( - i, col, pa.array([val] * len(orig_data), type=orig_data[col].type) - ) - self.delete(where) - self.add(orig_data, mode="append") + if values is not None and values_sql is not None: + raise ValueError("Only one of values or values_sql can be provided") + if values is None and values_sql is None: + raise ValueError("Either values or values_sql must be provided") + + if values is not None: + values_sql = {k: value_to_sql(v) for k, v in values.items()} + + self.to_lance().update(values_sql, where) self._reset_dataset() register_event("update") diff --git a/python/lancedb/util.py b/python/lancedb/util.py index 15b1d427..4774c429 100644 --- a/python/lancedb/util.py +++ b/python/lancedb/util.py @@ -12,9 +12,12 @@ # limitations under the License. 
@singledispatch
def value_to_sql(value) -> str:
    """Render a Python value as a SQL literal expression string.

    Dispatch is on the runtime type of ``value``; handlers are registered
    below for ``str``, ``int``, ``float``, ``bool``, ``None``, ``datetime``,
    ``date``, ``list`` and ``numpy.ndarray``.

    Parameters
    ----------
    value
        The Python value to convert.

    Returns
    -------
    str
        The SQL text representing ``value``.

    Raises
    ------
    NotImplementedError
        If no handler is registered for the type of ``value``.
    """
    raise NotImplementedError("SQL conversion is not implemented for this type")


@value_to_sql.register(str)
def _(value: str) -> str:
    """Quote a string, doubling embedded single quotes."""
    # Without escaping, a value such as "it's" would close the literal early
    # and yield malformed (or injectable) SQL. Doubling the quote is the
    # standard SQL way to embed a single quote inside a string literal.
    escaped = value.replace("'", "''")
    return f"'{escaped}'"


@value_to_sql.register(int)
def _(value: int) -> str:
    """Render an integer as its decimal text."""
    return str(value)


@value_to_sql.register(float)
def _(value: float) -> str:
    """Render a float using Python's ``str`` formatting."""
    return str(value)


@value_to_sql.register(bool)
def _(value: bool) -> str:
    """Render a boolean as ``TRUE`` or ``FALSE``."""
    # bool is a subclass of int; singledispatch picks this more specific
    # handler, so booleans never fall through to the int handler.
    return str(value).upper()


@value_to_sql.register(type(None))
def _(value: None) -> str:
    """Render ``None`` as SQL ``NULL``."""
    return "NULL"


@value_to_sql.register(datetime)
def _(value: datetime) -> str:
    """Render a datetime as a quoted ISO-8601 string."""
    return f"'{value.isoformat()}'"


@value_to_sql.register(date)
def _(value: date) -> str:
    """Render a date as a quoted ISO-8601 string."""
    return f"'{value.isoformat()}'"


@value_to_sql.register(list)
def _(value: list) -> str:
    """Render a list as a bracketed, comma-separated list of SQL literals."""
    # Each element is converted recursively, so nested lists and strings
    # containing quotes are handled by their own handlers.
    return "[" + ", ".join(map(value_to_sql, value)) + "]"


@value_to_sql.register(np.ndarray)
def _(value: np.ndarray) -> str:
    """Render a numpy array by converting it to a (nested) Python list."""
    return value_to_sql(value.tolist())
def test_update_types(db):
    """Check that both update paths write every supported column type.

    A single-row table is updated first through raw SQL expressions
    (``values_sql``) and then through plain Python values (``values``);
    after each update the stored row is compared with the expected row.
    """
    seed_row = {
        "id": 0,
        "str": "foo",
        "float": 1.1,
        "timestamp": datetime(2021, 1, 1),
        "date": date(2021, 1, 1),
        "vector1": [1.0, 0.0],
        "vector2": [1.0, 1.0],
    }
    table = LanceTable.create(db, "my_table", data=[seed_row])

    # Pass 1: every column assigned from a SQL expression string.
    sql_assignments = {
        "id": "1",
        "str": "'bar'",
        "float": "2.2",
        "timestamp": "TIMESTAMP '2021-01-02 00:00:00'",
        "date": "DATE '2021-01-02'",
        "vector1": "[2.0, 2.0]",
        "vector2": "[3.0, 3.0]",
    }
    table.update(values_sql=sql_assignments)
    row_after_sql = table.to_arrow().to_pylist()[0]
    assert row_after_sql == {
        "id": 1,
        "str": "bar",
        "float": 2.2,
        "timestamp": datetime(2021, 1, 2),
        "date": date(2021, 1, 2),
        "vector1": [2.0, 2.0],
        "vector2": [3.0, 3.0],
    }

    # Pass 2: every column assigned from a Python value, including a numpy
    # array for one of the vector columns.
    python_assignments = {
        "id": 2,
        "str": "baz",
        "float": 3.3,
        "timestamp": datetime(2021, 1, 3),
        "date": date(2021, 1, 3),
        "vector1": [3.0, 3.0],
        "vector2": np.array([4.0, 4.0]),
    }
    table.update(values=python_assignments)
    row_after_values = table.to_arrow().to_pylist()[0]
    assert row_after_values == {
        "id": 2,
        "str": "baz",
        "float": 3.3,
        "timestamp": datetime(2021, 1, 3),
        "date": date(2021, 1, 3),
        "vector1": [3.0, 3.0],
        "vector2": [4.0, 4.0],
    }