Compare commits

...

2 Commits

Author SHA1 Message Date
Lance Release
600bfd7237 [python] Bump version: 0.3.4 → 0.3.5 2023-12-14 19:31:22 +00:00
Will Jones
d087e7891d feat(python): add update query support for Python (#654)
Closes #69

Will not pass until https://github.com/lancedb/lance/pull/1585 is
released
2023-12-14 11:28:32 -08:00
5 changed files with 151 additions and 31 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.3.4
current_version = 0.3.5
commit = True
message = [python] Bump version: {current_version} → {new_version}
tag = True

View File

@@ -17,7 +17,7 @@ import inspect
import os
from abc import ABC, abstractmethod
from functools import cached_property
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
import lance
import numpy as np
@@ -30,7 +30,7 @@ from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from .pydantic import LanceModel
from .query import LanceQueryBuilder, Query
from .util import fs_from_uri, safe_import_pandas
from .util import fs_from_uri, safe_import_pandas, value_to_sql
from .utils.events import register_event
if TYPE_CHECKING:
@@ -913,30 +913,35 @@ class LanceTable(Table):
def delete(self, where: str):
self._dataset.delete(where)
def update(self, where: str, values: dict):
def update(
self,
where: Optional[str] = None,
values: Optional[dict] = None,
*,
values_sql: Optional[Dict[str, str]] = None,
):
"""
EXPERIMENTAL: Update rows in the table (not threadsafe).
This can be used to update zero to all rows depending on how many
rows match the where clause.
Parameters
----------
where: str
where: str, optional
The SQL where clause to use when updating rows. For example, 'x = 2'
or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
values: dict
values: dict, optional
The values to update. The keys are the column names and the values
are the values to set.
values_sql: dict, optional
The values to update, expressed as SQL expression strings. These can
reference existing columns. For example, {"x": "x + 1"} will increment
the x column by 1.
Examples
--------
>>> import lancedb
>>> data = [
... {"x": 1, "vector": [1, 2]},
... {"x": 2, "vector": [3, 4]},
... {"x": 3, "vector": [5, 6]}
... ]
>>> import pandas as pd
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data)
>>> table.to_pandas()
@@ -952,18 +957,15 @@ class LanceTable(Table):
2 2 [10.0, 10.0]
"""
orig_data = self._dataset.to_table(filter=where).combine_chunks()
if len(orig_data) == 0:
return
for col, val in values.items():
i = orig_data.column_names.index(col)
if i < 0:
raise ValueError(f"Column {col} does not exist")
orig_data = orig_data.set_column(
i, col, pa.array([val] * len(orig_data), type=orig_data[col].type)
)
self.delete(where)
self.add(orig_data, mode="append")
if values is not None and values_sql is not None:
raise ValueError("Only one of values or values_sql can be provided")
if values is None and values_sql is None:
raise ValueError("Either values or values_sql must be provided")
if values is not None:
values_sql = {k: value_to_sql(v) for k, v in values.items()}
self.to_lance().update(values_sql, where)
self._reset_dataset()
register_event("update")

View File

@@ -12,9 +12,12 @@
# limitations under the License.
import os
from datetime import date, datetime
from functools import singledispatch
from typing import Tuple
from urllib.parse import urlparse
import numpy as np
import pyarrow.fs as pa_fs
@@ -88,3 +91,53 @@ def safe_import_pandas():
return pd
except ImportError:
return None
@singledispatch
def value_to_sql(value):
raise NotImplementedError("SQL conversion is not implemented for this type")
@value_to_sql.register(str)
def _(value: str):
return f"'{value}'"
@value_to_sql.register(int)
def _(value: int):
return str(value)
@value_to_sql.register(float)
def _(value: float):
return str(value)
@value_to_sql.register(bool)
def _(value: bool):
return str(value).upper()
@value_to_sql.register(type(None))
def _(value: type(None)):
return "NULL"
@value_to_sql.register(datetime)
def _(value: datetime):
return f"'{value.isoformat()}'"
@value_to_sql.register(date)
def _(value: date):
return f"'{value.isoformat()}'"
@value_to_sql.register(list)
def _(value: list):
return "[" + ", ".join(map(value_to_sql, value)) + "]"
@value_to_sql.register(np.ndarray)
def _(value: np.ndarray):
return value_to_sql(value.tolist())

View File

@@ -1,12 +1,12 @@
[project]
name = "lancedb"
version = "0.3.4"
version = "0.3.5"
dependencies = [
"deprecation",
"pylance==0.8.17",
"pylance==0.8.21",
"ratelimiter~=1.0",
"retry>=0.9.2",
"tqdm>=4.1.0",
"tqdm>=4.27.0",
"aiohttp",
"pydantic>=1.10",
"attrs>=21.3.0",

View File

@@ -12,7 +12,7 @@
# limitations under the License.
import functools
from datetime import timedelta
from datetime import date, datetime, timedelta
from pathlib import Path
from typing import List
from unittest.mock import PropertyMock, patch
@@ -348,14 +348,79 @@ def test_update(db):
assert len(table) == 2
assert len(table.list_versions()) == 2
table.update(where="id=0", values={"vector": [1.1, 1.1]})
assert len(table.list_versions()) == 4
assert table.version == 4
assert len(table.list_versions()) == 3
assert table.version == 3
assert len(table) == 2
v = table.to_arrow()["vector"].combine_chunks()
v = v.values.to_numpy().reshape(2, 2)
assert np.allclose(v, np.array([[1.2, 1.9], [1.1, 1.1]]))
def test_update_types(db):
table = LanceTable.create(
db,
"my_table",
data=[
{
"id": 0,
"str": "foo",
"float": 1.1,
"timestamp": datetime(2021, 1, 1),
"date": date(2021, 1, 1),
"vector1": [1.0, 0.0],
"vector2": [1.0, 1.0],
}
],
)
# Update with SQL
table.update(
values_sql=dict(
id="1",
str="'bar'",
float="2.2",
timestamp="TIMESTAMP '2021-01-02 00:00:00'",
date="DATE '2021-01-02'",
vector1="[2.0, 2.0]",
vector2="[3.0, 3.0]",
)
)
actual = table.to_arrow().to_pylist()[0]
expected = dict(
id=1,
str="bar",
float=2.2,
timestamp=datetime(2021, 1, 2),
date=date(2021, 1, 2),
vector1=[2.0, 2.0],
vector2=[3.0, 3.0],
)
assert actual == expected
# Update with values
table.update(
values=dict(
id=2,
str="baz",
float=3.3,
timestamp=datetime(2021, 1, 3),
date=date(2021, 1, 3),
vector1=[3.0, 3.0],
vector2=np.array([4.0, 4.0]),
)
)
actual = table.to_arrow().to_pylist()[0]
expected = dict(
id=2,
str="baz",
float=3.3,
timestamp=datetime(2021, 1, 3),
date=date(2021, 1, 3),
vector1=[3.0, 3.0],
vector2=[4.0, 4.0],
)
assert actual == expected
def test_create_with_embedding_function(db):
class MyTable(LanceModel):
text: str