mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-06 11:52:57 +00:00
feat: return version for all write operations (#2368)
return version info for all write operations (add, update, merge_insert and column modification operations) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Table modification operations (add, update, delete, merge, add/alter/drop columns) now return detailed result objects including version numbers and operation statistics. - Result objects provide clearer feedback such as rows affected and new table version after each operation. - **Documentation** - Updated documentation to describe new result objects and their fields for all relevant table operations. - Added documentation for new result interfaces and updated method return types in Node.js and Python APIs. - **Tests** - Enhanced test coverage to assert correctness of returned versioning and operation metadata after table modifications. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
@@ -36,8 +36,10 @@ class Table:
|
||||
async def schema(self) -> pa.Schema: ...
|
||||
async def add(
|
||||
self, data: pa.RecordBatchReader, mode: Literal["append", "overwrite"]
|
||||
) -> None: ...
|
||||
async def update(self, updates: Dict[str, str], where: Optional[str]) -> None: ...
|
||||
) -> AddResult: ...
|
||||
async def update(
|
||||
self, updates: Dict[str, str], where: Optional[str]
|
||||
) -> UpdateResult: ...
|
||||
async def count_rows(self, filter: Optional[str]) -> int: ...
|
||||
async def create_index(
|
||||
self,
|
||||
@@ -51,10 +53,12 @@ class Table:
|
||||
async def checkout_latest(self): ...
|
||||
async def restore(self, version: Optional[int] = None): ...
|
||||
async def list_indices(self) -> list[IndexConfig]: ...
|
||||
async def delete(self, filter: str): ...
|
||||
async def add_columns(self, columns: list[tuple[str, str]]) -> None: ...
|
||||
async def add_columns_with_schema(self, schema: pa.Schema) -> None: ...
|
||||
async def alter_columns(self, columns: list[dict[str, Any]]) -> None: ...
|
||||
async def delete(self, filter: str) -> DeleteResult: ...
|
||||
async def add_columns(self, columns: list[tuple[str, str]]) -> AddColumnsResult: ...
|
||||
async def add_columns_with_schema(self, schema: pa.Schema) -> AddColumnsResult: ...
|
||||
async def alter_columns(
|
||||
self, columns: list[dict[str, Any]]
|
||||
) -> AlterColumnsResult: ...
|
||||
async def optimize(
|
||||
self,
|
||||
*,
|
||||
@@ -208,3 +212,28 @@ class OptimizeStats:
|
||||
class Tag(TypedDict):
|
||||
version: int
|
||||
manifest_size: int
|
||||
|
||||
class AddResult:
|
||||
version: int
|
||||
|
||||
class DeleteResult:
|
||||
version: int
|
||||
|
||||
class UpdateResult:
|
||||
rows_updated: int
|
||||
version: int
|
||||
|
||||
class MergeResult:
|
||||
version: int
|
||||
num_updated_rows: int
|
||||
num_inserted_rows: int
|
||||
num_deleted_rows: int
|
||||
|
||||
class AddColumnsResult:
|
||||
version: int
|
||||
|
||||
class AlterColumnsResult:
|
||||
version: int
|
||||
|
||||
class DropColumnsResult:
|
||||
version: int
|
||||
|
||||
@@ -8,6 +8,9 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .common import DATA
|
||||
from ._lancedb import (
|
||||
MergeInsertResult,
|
||||
)
|
||||
|
||||
|
||||
class LanceMergeInsertBuilder(object):
|
||||
@@ -78,7 +81,7 @@ class LanceMergeInsertBuilder(object):
|
||||
new_data: DATA,
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
):
|
||||
) -> MergeInsertResult:
|
||||
"""
|
||||
Executes the merge insert operation
|
||||
|
||||
@@ -95,5 +98,10 @@ class LanceMergeInsertBuilder(object):
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
|
||||
Returns
|
||||
-------
|
||||
MergeInsertResult
|
||||
version: the new version number of the table after doing merge insert.
|
||||
"""
|
||||
return self._table._do_merge(self, new_data, on_bad_vectors, fill_value)
|
||||
|
||||
@@ -415,6 +415,7 @@ class LanceModel(pydantic.BaseModel):
|
||||
>>> table.add([
|
||||
... TestModel(name="test", vector=[1.0, 2.0])
|
||||
... ])
|
||||
AddResult(version=2)
|
||||
>>> table.search([0., 0.]).limit(1).to_pydantic(TestModel)
|
||||
[TestModel(name='test', vector=FixedSizeList(dim=2))]
|
||||
"""
|
||||
|
||||
@@ -7,7 +7,16 @@ from functools import cached_property
|
||||
from typing import Dict, Iterable, List, Optional, Union, Literal
|
||||
import warnings
|
||||
|
||||
from lancedb._lancedb import IndexConfig
|
||||
from lancedb._lancedb import (
|
||||
AddColumnsResult,
|
||||
AddResult,
|
||||
AlterColumnsResult,
|
||||
DeleteResult,
|
||||
DropColumnsResult,
|
||||
IndexConfig,
|
||||
MergeResult,
|
||||
UpdateResult,
|
||||
)
|
||||
from lancedb.embeddings.base import EmbeddingFunctionConfig
|
||||
from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfFlat, IvfPq, LabelList
|
||||
from lancedb.remote.db import LOOP
|
||||
@@ -263,7 +272,7 @@ class RemoteTable(Table):
|
||||
mode: str = "append",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
) -> int:
|
||||
) -> AddResult:
|
||||
"""Add more data to the [Table](Table). It has the same API signature as
|
||||
the OSS version.
|
||||
|
||||
@@ -286,8 +295,12 @@ class RemoteTable(Table):
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
|
||||
Returns
|
||||
-------
|
||||
AddResult
|
||||
An object containing the new version number of the table after adding data.
|
||||
"""
|
||||
LOOP.run(
|
||||
return LOOP.run(
|
||||
self._table.add(
|
||||
data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
)
|
||||
@@ -413,10 +426,12 @@ class RemoteTable(Table):
|
||||
new_data: DATA,
|
||||
on_bad_vectors: str,
|
||||
fill_value: float,
|
||||
):
|
||||
LOOP.run(self._table._do_merge(merge, new_data, on_bad_vectors, fill_value))
|
||||
) -> MergeResult:
|
||||
return LOOP.run(
|
||||
self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)
|
||||
)
|
||||
|
||||
def delete(self, predicate: str):
|
||||
def delete(self, predicate: str) -> DeleteResult:
|
||||
"""Delete rows from the table.
|
||||
|
||||
This can be used to delete a single row, many rows, all rows, or
|
||||
@@ -431,6 +446,11 @@ class RemoteTable(Table):
|
||||
|
||||
The filter must not be empty, or it will error.
|
||||
|
||||
Returns
|
||||
-------
|
||||
DeleteResult
|
||||
An object containing the new version number of the table after deletion.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
@@ -463,7 +483,7 @@ class RemoteTable(Table):
|
||||
x vector _distance # doctest: +SKIP
|
||||
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
|
||||
"""
|
||||
LOOP.run(self._table.delete(predicate))
|
||||
return LOOP.run(self._table.delete(predicate))
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -471,7 +491,7 @@ class RemoteTable(Table):
|
||||
values: Optional[dict] = None,
|
||||
*,
|
||||
values_sql: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
) -> UpdateResult:
|
||||
"""
|
||||
This can be used to update zero to all rows depending on how many
|
||||
rows match the where clause.
|
||||
@@ -489,6 +509,12 @@ class RemoteTable(Table):
|
||||
reference existing columns. For example, {"x": "x + 1"} will increment
|
||||
the x column by 1.
|
||||
|
||||
Returns
|
||||
-------
|
||||
UpdateResult
|
||||
- rows_updated: The number of rows that were updated
|
||||
- version: The new version number of the table after the update
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
@@ -513,7 +539,7 @@ class RemoteTable(Table):
|
||||
2 2 [10.0, 10.0] # doctest: +SKIP
|
||||
|
||||
"""
|
||||
LOOP.run(
|
||||
return LOOP.run(
|
||||
self._table.update(where=where, updates=values, updates_sql=values_sql)
|
||||
)
|
||||
|
||||
@@ -561,13 +587,15 @@ class RemoteTable(Table):
|
||||
def count_rows(self, filter: Optional[str] = None) -> int:
|
||||
return LOOP.run(self._table.count_rows(filter))
|
||||
|
||||
def add_columns(self, transforms: Dict[str, str]):
|
||||
def add_columns(self, transforms: Dict[str, str]) -> AddColumnsResult:
|
||||
return LOOP.run(self._table.add_columns(transforms))
|
||||
|
||||
def alter_columns(self, *alterations: Iterable[Dict[str, str]]):
|
||||
def alter_columns(
|
||||
self, *alterations: Iterable[Dict[str, str]]
|
||||
) -> AlterColumnsResult:
|
||||
return LOOP.run(self._table.alter_columns(*alterations))
|
||||
|
||||
def drop_columns(self, columns: Iterable[str]):
|
||||
def drop_columns(self, columns: Iterable[str]) -> DropColumnsResult:
|
||||
return LOOP.run(self._table.drop_columns(columns))
|
||||
|
||||
def drop_index(self, index_name: str):
|
||||
|
||||
@@ -78,6 +78,13 @@ if TYPE_CHECKING:
|
||||
CleanupStats,
|
||||
CompactionStats,
|
||||
Tag,
|
||||
AddColumnsResult,
|
||||
AddResult,
|
||||
AlterColumnsResult,
|
||||
DeleteResult,
|
||||
DropColumnsResult,
|
||||
MergeResult,
|
||||
UpdateResult,
|
||||
)
|
||||
from .db import LanceDBConnection
|
||||
from .index import IndexConfig
|
||||
@@ -550,6 +557,7 @@ class Table(ABC):
|
||||
Can append new data with [Table.add()][lancedb.table.Table.add].
|
||||
|
||||
>>> table.add([{"vector": [0.5, 1.3], "b": 4}])
|
||||
AddResult(version=2)
|
||||
|
||||
Can query the table with [Table.search][lancedb.table.Table.search].
|
||||
|
||||
@@ -894,7 +902,7 @@ class Table(ABC):
|
||||
mode: AddMode = "append",
|
||||
on_bad_vectors: OnBadVectorsType = "error",
|
||||
fill_value: float = 0.0,
|
||||
):
|
||||
) -> AddResult:
|
||||
"""Add more data to the [Table](Table).
|
||||
|
||||
Parameters
|
||||
@@ -916,6 +924,10 @@ class Table(ABC):
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
|
||||
Returns
|
||||
-------
|
||||
AddResult
|
||||
An object containing the new version number of the table after adding data.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -962,12 +974,12 @@ class Table(ABC):
|
||||
>>> table = db.create_table("my_table", data)
|
||||
>>> new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})
|
||||
>>> # Perform a "upsert" operation
|
||||
>>> stats = table.merge_insert("a") \\
|
||||
>>> res = table.merge_insert("a") \\
|
||||
... .when_matched_update_all() \\
|
||||
... .when_not_matched_insert_all() \\
|
||||
... .execute(new_data)
|
||||
>>> stats
|
||||
{'num_inserted_rows': 1, 'num_updated_rows': 2, 'num_deleted_rows': 0}
|
||||
>>> res
|
||||
MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0)
|
||||
>>> # The order of new rows is non-deterministic since we use
|
||||
>>> # a hash-join as part of this operation and so we sort here
|
||||
>>> table.to_arrow().sort_by("a").to_pandas()
|
||||
@@ -976,7 +988,7 @@ class Table(ABC):
|
||||
1 2 x
|
||||
2 3 y
|
||||
3 4 z
|
||||
"""
|
||||
""" # noqa: E501
|
||||
on = [on] if isinstance(on, str) else list(iter(on))
|
||||
|
||||
return LanceMergeInsertBuilder(self, on)
|
||||
@@ -1091,10 +1103,10 @@ class Table(ABC):
|
||||
new_data: DATA,
|
||||
on_bad_vectors: OnBadVectorsType,
|
||||
fill_value: float,
|
||||
): ...
|
||||
) -> MergeResult: ...
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, where: str):
|
||||
def delete(self, where: str) -> DeleteResult:
|
||||
"""Delete rows from the table.
|
||||
|
||||
This can be used to delete a single row, many rows, all rows, or
|
||||
@@ -1109,6 +1121,11 @@ class Table(ABC):
|
||||
|
||||
The filter must not be empty, or it will error.
|
||||
|
||||
Returns
|
||||
-------
|
||||
DeleteResult
|
||||
An object containing the new version number of the table after deletion.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
@@ -1125,6 +1142,7 @@ class Table(ABC):
|
||||
1 2 [3.0, 4.0]
|
||||
2 3 [5.0, 6.0]
|
||||
>>> table.delete("x = 2")
|
||||
DeleteResult(version=2)
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 1 [1.0, 2.0]
|
||||
@@ -1138,6 +1156,7 @@ class Table(ABC):
|
||||
>>> to_remove
|
||||
'1, 5'
|
||||
>>> table.delete(f"x IN ({to_remove})")
|
||||
DeleteResult(version=3)
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 3 [5.0, 6.0]
|
||||
@@ -1151,7 +1170,7 @@ class Table(ABC):
|
||||
values: Optional[dict] = None,
|
||||
*,
|
||||
values_sql: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
) -> UpdateResult:
|
||||
"""
|
||||
This can be used to update zero to all rows depending on how many
|
||||
rows match the where clause. If no where clause is provided, then
|
||||
@@ -1173,6 +1192,12 @@ class Table(ABC):
|
||||
reference existing columns. For example, {"x": "x + 1"} will increment
|
||||
the x column by 1.
|
||||
|
||||
Returns
|
||||
-------
|
||||
UpdateResult
|
||||
- rows_updated: The number of rows that were updated
|
||||
- version: The new version number of the table after the update
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
@@ -1186,12 +1211,14 @@ class Table(ABC):
|
||||
1 2 [3.0, 4.0]
|
||||
2 3 [5.0, 6.0]
|
||||
>>> table.update(where="x = 2", values={"vector": [10.0, 10]})
|
||||
UpdateResult(rows_updated=1, version=2)
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 1 [1.0, 2.0]
|
||||
1 3 [5.0, 6.0]
|
||||
2 2 [10.0, 10.0]
|
||||
>>> table.update(values_sql={"x": "x + 1"})
|
||||
UpdateResult(rows_updated=3, version=3)
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 2 [1.0, 2.0]
|
||||
@@ -1354,6 +1381,11 @@ class Table(ABC):
|
||||
Alternatively, a pyarrow Field or Schema can be provided to add
|
||||
new columns with the specified data types. The new columns will
|
||||
be initialized with null values.
|
||||
|
||||
Returns
|
||||
-------
|
||||
AddColumnsResult
|
||||
version: the new version number of the table after adding columns.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@@ -1379,10 +1411,15 @@ class Table(ABC):
|
||||
nullability is not changed. Only non-nullable columns can be changed
|
||||
to nullable. Currently, you cannot change a nullable column to
|
||||
non-nullable.
|
||||
|
||||
Returns
|
||||
-------
|
||||
AlterColumnsResult
|
||||
version: the new version number of the table after the alteration.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def drop_columns(self, columns: Iterable[str]):
|
||||
def drop_columns(self, columns: Iterable[str]) -> DropColumnsResult:
|
||||
"""
|
||||
Drop columns from the table.
|
||||
|
||||
@@ -1390,6 +1427,11 @@ class Table(ABC):
|
||||
----------
|
||||
columns : Iterable[str]
|
||||
The names of the columns to drop.
|
||||
|
||||
Returns
|
||||
-------
|
||||
DropColumnsResult
|
||||
version: the new version number of the table dropping the columns.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@@ -1611,6 +1653,7 @@ class LanceTable(Table):
|
||||
... [{"vector": [1.1, 0.9], "type": "vector"}])
|
||||
>>> table.tags.create("v1", table.version)
|
||||
>>> table.add([{"vector": [0.5, 0.2], "type": "vector"}])
|
||||
AddResult(version=2)
|
||||
>>> tags = table.tags.list()
|
||||
>>> print(tags["v1"]["version"])
|
||||
1
|
||||
@@ -1649,6 +1692,7 @@ class LanceTable(Table):
|
||||
vector type
|
||||
0 [1.1, 0.9] vector
|
||||
>>> table.add([{"vector": [0.5, 0.2], "type": "vector"}])
|
||||
AddResult(version=2)
|
||||
>>> table.version
|
||||
2
|
||||
>>> table.checkout(1)
|
||||
@@ -1691,6 +1735,7 @@ class LanceTable(Table):
|
||||
vector type
|
||||
0 [1.1, 0.9] vector
|
||||
>>> table.add([{"vector": [0.5, 0.2], "type": "vector"}])
|
||||
AddResult(version=2)
|
||||
>>> table.version
|
||||
2
|
||||
>>> table.restore(1)
|
||||
@@ -2055,7 +2100,7 @@ class LanceTable(Table):
|
||||
mode: AddMode = "append",
|
||||
on_bad_vectors: OnBadVectorsType = "error",
|
||||
fill_value: float = 0.0,
|
||||
):
|
||||
) -> AddResult:
|
||||
"""Add data to the table.
|
||||
If vector columns are missing and the table
|
||||
has embedding functions, then the vector columns
|
||||
@@ -2079,7 +2124,7 @@ class LanceTable(Table):
|
||||
int
|
||||
The number of vectors in the table.
|
||||
"""
|
||||
LOOP.run(
|
||||
return LOOP.run(
|
||||
self._table.add(
|
||||
data, mode=mode, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
)
|
||||
@@ -2409,8 +2454,8 @@ class LanceTable(Table):
|
||||
)
|
||||
return self
|
||||
|
||||
def delete(self, where: str):
|
||||
LOOP.run(self._table.delete(where))
|
||||
def delete(self, where: str) -> DeleteResult:
|
||||
return LOOP.run(self._table.delete(where))
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -2418,7 +2463,7 @@ class LanceTable(Table):
|
||||
values: Optional[dict] = None,
|
||||
*,
|
||||
values_sql: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
) -> UpdateResult:
|
||||
"""
|
||||
This can be used to update zero to all rows depending on how many
|
||||
rows match the where clause.
|
||||
@@ -2436,6 +2481,12 @@ class LanceTable(Table):
|
||||
reference existing columns. For example, {"x": "x + 1"} will increment
|
||||
the x column by 1.
|
||||
|
||||
Returns
|
||||
-------
|
||||
UpdateResult
|
||||
- rows_updated: The number of rows that were updated
|
||||
- version: The new version number of the table after the update
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
@@ -2449,6 +2500,7 @@ class LanceTable(Table):
|
||||
1 2 [3.0, 4.0]
|
||||
2 3 [5.0, 6.0]
|
||||
>>> table.update(where="x = 2", values={"vector": [10.0, 10]})
|
||||
UpdateResult(rows_updated=1, version=2)
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 1 [1.0, 2.0]
|
||||
@@ -2456,7 +2508,7 @@ class LanceTable(Table):
|
||||
2 2 [10.0, 10.0]
|
||||
|
||||
"""
|
||||
LOOP.run(self._table.update(values, where=where, updates_sql=values_sql))
|
||||
return LOOP.run(self._table.update(values, where=where, updates_sql=values_sql))
|
||||
|
||||
def _execute_query(
|
||||
self,
|
||||
@@ -2490,7 +2542,7 @@ class LanceTable(Table):
|
||||
new_data: DATA,
|
||||
on_bad_vectors: OnBadVectorsType,
|
||||
fill_value: float,
|
||||
):
|
||||
) -> MergeResult:
|
||||
return LOOP.run(
|
||||
self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)
|
||||
)
|
||||
@@ -2635,14 +2687,16 @@ class LanceTable(Table):
|
||||
|
||||
def add_columns(
|
||||
self, transforms: Dict[str, str] | pa.field | List[pa.field] | pa.Schema
|
||||
):
|
||||
LOOP.run(self._table.add_columns(transforms))
|
||||
) -> AddColumnsResult:
|
||||
return LOOP.run(self._table.add_columns(transforms))
|
||||
|
||||
def alter_columns(self, *alterations: Iterable[Dict[str, str]]):
|
||||
LOOP.run(self._table.alter_columns(*alterations))
|
||||
def alter_columns(
|
||||
self, *alterations: Iterable[Dict[str, str]]
|
||||
) -> AlterColumnsResult:
|
||||
return LOOP.run(self._table.alter_columns(*alterations))
|
||||
|
||||
def drop_columns(self, columns: Iterable[str]):
|
||||
LOOP.run(self._table.drop_columns(columns))
|
||||
def drop_columns(self, columns: Iterable[str]) -> DropColumnsResult:
|
||||
return LOOP.run(self._table.drop_columns(columns))
|
||||
|
||||
def uses_v2_manifest_paths(self) -> bool:
|
||||
"""
|
||||
@@ -3197,7 +3251,7 @@ class AsyncTable:
|
||||
mode: Optional[Literal["append", "overwrite"]] = "append",
|
||||
on_bad_vectors: Optional[OnBadVectorsType] = None,
|
||||
fill_value: Optional[float] = None,
|
||||
):
|
||||
) -> AddResult:
|
||||
"""Add more data to the [Table](Table).
|
||||
|
||||
Parameters
|
||||
@@ -3236,7 +3290,7 @@ class AsyncTable:
|
||||
if isinstance(data, pa.Table):
|
||||
data = data.to_reader()
|
||||
|
||||
await self._inner.add(data, mode or "append")
|
||||
return await self._inner.add(data, mode or "append")
|
||||
|
||||
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
||||
"""
|
||||
@@ -3281,12 +3335,12 @@ class AsyncTable:
|
||||
>>> table = db.create_table("my_table", data)
|
||||
>>> new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})
|
||||
>>> # Perform a "upsert" operation
|
||||
>>> stats = table.merge_insert("a") \\
|
||||
>>> res = table.merge_insert("a") \\
|
||||
... .when_matched_update_all() \\
|
||||
... .when_not_matched_insert_all() \\
|
||||
... .execute(new_data)
|
||||
>>> stats
|
||||
{'num_inserted_rows': 1, 'num_updated_rows': 2, 'num_deleted_rows': 0}
|
||||
>>> res
|
||||
MergeResult(version=2, num_updated_rows=2, num_inserted_rows=1, num_deleted_rows=0)
|
||||
>>> # The order of new rows is non-deterministic since we use
|
||||
>>> # a hash-join as part of this operation and so we sort here
|
||||
>>> table.to_arrow().sort_by("a").to_pandas()
|
||||
@@ -3295,7 +3349,7 @@ class AsyncTable:
|
||||
1 2 x
|
||||
2 3 y
|
||||
3 4 z
|
||||
"""
|
||||
""" # noqa: E501
|
||||
on = [on] if isinstance(on, str) else list(iter(on))
|
||||
|
||||
return LanceMergeInsertBuilder(self, on)
|
||||
@@ -3626,7 +3680,7 @@ class AsyncTable:
|
||||
new_data: DATA,
|
||||
on_bad_vectors: OnBadVectorsType,
|
||||
fill_value: float,
|
||||
):
|
||||
) -> MergeResult:
|
||||
schema = await self.schema()
|
||||
if on_bad_vectors is None:
|
||||
on_bad_vectors = "error"
|
||||
@@ -3654,7 +3708,7 @@ class AsyncTable:
|
||||
),
|
||||
)
|
||||
|
||||
async def delete(self, where: str):
|
||||
async def delete(self, where: str) -> DeleteResult:
|
||||
"""Delete rows from the table.
|
||||
|
||||
This can be used to delete a single row, many rows, all rows, or
|
||||
@@ -3685,6 +3739,7 @@ class AsyncTable:
|
||||
1 2 [3.0, 4.0]
|
||||
2 3 [5.0, 6.0]
|
||||
>>> table.delete("x = 2")
|
||||
DeleteResult(version=2)
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 1 [1.0, 2.0]
|
||||
@@ -3698,6 +3753,7 @@ class AsyncTable:
|
||||
>>> to_remove
|
||||
'1, 5'
|
||||
>>> table.delete(f"x IN ({to_remove})")
|
||||
DeleteResult(version=3)
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 3 [5.0, 6.0]
|
||||
@@ -3710,7 +3766,7 @@ class AsyncTable:
|
||||
*,
|
||||
where: Optional[str] = None,
|
||||
updates_sql: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
) -> UpdateResult:
|
||||
"""
|
||||
This can be used to update zero to all rows in the table.
|
||||
|
||||
@@ -3732,6 +3788,13 @@ class AsyncTable:
|
||||
literals (e.g. "7" or "'foo'") or they can be expressions based on the
|
||||
previous value of the row (e.g. "x + 1" to increment the x column by 1)
|
||||
|
||||
Returns
|
||||
-------
|
||||
UpdateResult
|
||||
An object containing:
|
||||
- rows_updated: The number of rows that were updated
|
||||
- version: The new version number of the table after the update
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import asyncio
|
||||
@@ -3760,7 +3823,7 @@ class AsyncTable:
|
||||
|
||||
async def add_columns(
|
||||
self, transforms: dict[str, str] | pa.field | List[pa.field] | pa.Schema
|
||||
):
|
||||
) -> AddColumnsResult:
|
||||
"""
|
||||
Add new columns with defined values.
|
||||
|
||||
@@ -3772,6 +3835,12 @@ class AsyncTable:
|
||||
each row in the table, and can reference existing columns.
|
||||
Alternatively, you can pass a pyarrow field or schema to add
|
||||
new columns with NULLs.
|
||||
|
||||
Returns
|
||||
-------
|
||||
AddColumnsResult
|
||||
version: the new version number of the table after adding columns.
|
||||
|
||||
"""
|
||||
if isinstance(transforms, pa.Field):
|
||||
transforms = [transforms]
|
||||
@@ -3780,11 +3849,13 @@ class AsyncTable:
|
||||
):
|
||||
transforms = pa.schema(transforms)
|
||||
if isinstance(transforms, pa.Schema):
|
||||
await self._inner.add_columns_with_schema(transforms)
|
||||
return await self._inner.add_columns_with_schema(transforms)
|
||||
else:
|
||||
await self._inner.add_columns(list(transforms.items()))
|
||||
return await self._inner.add_columns(list(transforms.items()))
|
||||
|
||||
async def alter_columns(self, *alterations: Iterable[dict[str, Any]]):
|
||||
async def alter_columns(
|
||||
self, *alterations: Iterable[dict[str, Any]]
|
||||
) -> AlterColumnsResult:
|
||||
"""
|
||||
Alter column names and nullability.
|
||||
|
||||
@@ -3804,8 +3875,13 @@ class AsyncTable:
|
||||
nullability is not changed. Only non-nullable columns can be changed
|
||||
to nullable. Currently, you cannot change a nullable column to
|
||||
non-nullable.
|
||||
|
||||
Returns
|
||||
-------
|
||||
AlterColumnsResult
|
||||
version: the new version number of the table after the alteration.
|
||||
"""
|
||||
await self._inner.alter_columns(alterations)
|
||||
return await self._inner.alter_columns(alterations)
|
||||
|
||||
async def drop_columns(self, columns: Iterable[str]):
|
||||
"""
|
||||
@@ -3816,7 +3892,7 @@ class AsyncTable:
|
||||
columns : Iterable[str]
|
||||
The names of the columns to drop.
|
||||
"""
|
||||
await self._inner.drop_columns(columns)
|
||||
return await self._inner.drop_columns(columns)
|
||||
|
||||
async def version(self) -> int:
|
||||
"""
|
||||
|
||||
@@ -18,19 +18,19 @@ def test_upsert(mem_db):
|
||||
{"id": 1, "name": "Bobby"},
|
||||
{"id": 2, "name": "Charlie"},
|
||||
]
|
||||
stats = (
|
||||
res = (
|
||||
table.merge_insert("id")
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.execute(new_users)
|
||||
)
|
||||
table.count_rows() # 3
|
||||
stats # {'num_inserted_rows': 1, 'num_updated_rows': 1, 'num_deleted_rows': 0}
|
||||
res # {'num_inserted_rows': 1, 'num_updated_rows': 1, 'num_deleted_rows': 0}
|
||||
# --8<-- [end:upsert_basic]
|
||||
assert table.count_rows() == 3
|
||||
assert stats["num_inserted_rows"] == 1
|
||||
assert stats["num_updated_rows"] == 1
|
||||
assert stats["num_deleted_rows"] == 0
|
||||
assert res.num_inserted_rows == 1
|
||||
assert res.num_deleted_rows == 0
|
||||
assert res.num_updated_rows == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -48,19 +48,22 @@ async def test_upsert_async(mem_db_async):
|
||||
{"id": 1, "name": "Bobby"},
|
||||
{"id": 2, "name": "Charlie"},
|
||||
]
|
||||
stats = await (
|
||||
res = await (
|
||||
table.merge_insert("id")
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.execute(new_users)
|
||||
)
|
||||
await table.count_rows() # 3
|
||||
stats # {'num_inserted_rows': 1, 'num_updated_rows': 1, 'num_deleted_rows': 0}
|
||||
res
|
||||
# MergeResult(version=2, num_updated_rows=1,
|
||||
# num_inserted_rows=1, num_deleted_rows=0)
|
||||
# --8<-- [end:upsert_basic_async]
|
||||
assert await table.count_rows() == 3
|
||||
assert stats["num_inserted_rows"] == 1
|
||||
assert stats["num_updated_rows"] == 1
|
||||
assert stats["num_deleted_rows"] == 0
|
||||
assert res.version == 2
|
||||
assert res.num_inserted_rows == 1
|
||||
assert res.num_deleted_rows == 0
|
||||
assert res.num_updated_rows == 1
|
||||
|
||||
|
||||
def test_insert_if_not_exists(mem_db):
|
||||
@@ -77,16 +80,19 @@ def test_insert_if_not_exists(mem_db):
|
||||
{"domain": "google.com", "name": "Google"},
|
||||
{"domain": "facebook.com", "name": "Facebook"},
|
||||
]
|
||||
stats = (
|
||||
res = (
|
||||
table.merge_insert("domain").when_not_matched_insert_all().execute(new_domains)
|
||||
)
|
||||
table.count_rows() # 3
|
||||
stats # {'num_inserted_rows': 1, 'num_updated_rows': 0, 'num_deleted_rows': 0}
|
||||
res
|
||||
# MergeResult(version=2, num_updated_rows=0,
|
||||
# num_inserted_rows=1, num_deleted_rows=0)
|
||||
# --8<-- [end:insert_if_not_exists]
|
||||
assert table.count_rows() == 3
|
||||
assert stats["num_inserted_rows"] == 1
|
||||
assert stats["num_updated_rows"] == 0
|
||||
assert stats["num_deleted_rows"] == 0
|
||||
assert res.version == 2
|
||||
assert res.num_inserted_rows == 1
|
||||
assert res.num_deleted_rows == 0
|
||||
assert res.num_updated_rows == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -104,16 +110,19 @@ async def test_insert_if_not_exists_async(mem_db_async):
|
||||
{"domain": "google.com", "name": "Google"},
|
||||
{"domain": "facebook.com", "name": "Facebook"},
|
||||
]
|
||||
stats = await (
|
||||
res = await (
|
||||
table.merge_insert("domain").when_not_matched_insert_all().execute(new_domains)
|
||||
)
|
||||
await table.count_rows() # 3
|
||||
stats # {'num_inserted_rows': 1, 'num_updated_rows': 0, 'num_deleted_rows': 0}
|
||||
# --8<-- [end:insert_if_not_exists_async]
|
||||
res
|
||||
# MergeResult(version=2, num_updated_rows=0,
|
||||
# num_inserted_rows=1, num_deleted_rows=0)
|
||||
# --8<-- [end:insert_if_not_exists]
|
||||
assert await table.count_rows() == 3
|
||||
assert stats["num_inserted_rows"] == 1
|
||||
assert stats["num_updated_rows"] == 0
|
||||
assert stats["num_deleted_rows"] == 0
|
||||
assert res.version == 2
|
||||
assert res.num_inserted_rows == 1
|
||||
assert res.num_deleted_rows == 0
|
||||
assert res.num_updated_rows == 0
|
||||
|
||||
|
||||
def test_replace_range(mem_db):
|
||||
@@ -131,7 +140,7 @@ def test_replace_range(mem_db):
|
||||
new_chunks = [
|
||||
{"doc_id": 1, "chunk_id": 0, "text": "Baz"},
|
||||
]
|
||||
stats = (
|
||||
res = (
|
||||
table.merge_insert(["doc_id", "chunk_id"])
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
@@ -139,12 +148,15 @@ def test_replace_range(mem_db):
|
||||
.execute(new_chunks)
|
||||
)
|
||||
table.count_rows("doc_id = 1") # 1
|
||||
stats # {'num_inserted_rows': 0, 'num_updated_rows': 1, 'num_deleted_rows': 1}
|
||||
# --8<-- [end:replace_range]
|
||||
res
|
||||
# MergeResult(version=2, num_updated_rows=1,
|
||||
# num_inserted_rows=0, num_deleted_rows=1)
|
||||
# --8<-- [end:insert_if_not_exists]
|
||||
assert table.count_rows("doc_id = 1") == 1
|
||||
assert stats["num_inserted_rows"] == 0
|
||||
assert stats["num_updated_rows"] == 1
|
||||
assert stats["num_deleted_rows"] == 1
|
||||
assert res.version == 2
|
||||
assert res.num_inserted_rows == 0
|
||||
assert res.num_deleted_rows == 1
|
||||
assert res.num_updated_rows == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -163,7 +175,7 @@ async def test_replace_range_async(mem_db_async):
|
||||
new_chunks = [
|
||||
{"doc_id": 1, "chunk_id": 0, "text": "Baz"},
|
||||
]
|
||||
stats = await (
|
||||
res = await (
|
||||
table.merge_insert(["doc_id", "chunk_id"])
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
@@ -171,9 +183,12 @@ async def test_replace_range_async(mem_db_async):
|
||||
.execute(new_chunks)
|
||||
)
|
||||
await table.count_rows("doc_id = 1") # 1
|
||||
stats # {'num_inserted_rows': 0, 'num_updated_rows': 1, 'num_deleted_rows': 1}
|
||||
# --8<-- [end:replace_range_async]
|
||||
res
|
||||
# MergeResult(version=2, num_updated_rows=1,
|
||||
# num_inserted_rows=0, num_deleted_rows=1)
|
||||
# --8<-- [end:insert_if_not_exists]
|
||||
assert await table.count_rows("doc_id = 1") == 1
|
||||
assert stats["num_inserted_rows"] == 0
|
||||
assert stats["num_updated_rows"] == 1
|
||||
assert stats["num_deleted_rows"] == 1
|
||||
assert res.version == 2
|
||||
assert res.num_inserted_rows == 0
|
||||
assert res.num_deleted_rows == 1
|
||||
assert res.num_updated_rows == 1
|
||||
|
||||
@@ -106,15 +106,22 @@ async def test_update_async(mem_db_async: AsyncConnection):
|
||||
table = await mem_db_async.create_table("some_table", data=[{"id": 0}])
|
||||
assert await table.count_rows("id == 0") == 1
|
||||
assert await table.count_rows("id == 7") == 0
|
||||
await table.update({"id": 7})
|
||||
update_res = await table.update({"id": 7})
|
||||
assert update_res.rows_updated == 1
|
||||
assert update_res.version == 2
|
||||
assert await table.count_rows("id == 7") == 1
|
||||
assert await table.count_rows("id == 0") == 0
|
||||
await table.add([{"id": 2}])
|
||||
await table.update(where="id % 2 == 0", updates_sql={"id": "5"})
|
||||
add_res = await table.add([{"id": 2}])
|
||||
assert add_res.version == 3
|
||||
update_res = await table.update(where="id % 2 == 0", updates_sql={"id": "5"})
|
||||
assert update_res.rows_updated == 1
|
||||
assert update_res.version == 4
|
||||
assert await table.count_rows("id == 7") == 1
|
||||
assert await table.count_rows("id == 2") == 0
|
||||
assert await table.count_rows("id == 5") == 1
|
||||
await table.update({"id": 10}, where="id == 5")
|
||||
update_res = await table.update({"id": 10}, where="id == 5")
|
||||
assert update_res.rows_updated == 1
|
||||
assert update_res.version == 5
|
||||
assert await table.count_rows("id == 10") == 1
|
||||
|
||||
|
||||
@@ -437,7 +444,8 @@ def test_add_pydantic_model(mem_db: DBConnection):
|
||||
content="foo", meta=Metadata(source="bar", timestamp=datetime.now())
|
||||
),
|
||||
)
|
||||
tbl.add([expected])
|
||||
add_res = tbl.add([expected])
|
||||
assert add_res.version == 2
|
||||
|
||||
result = tbl.search([0.0, 0.0]).limit(1).to_pydantic(LanceSchema)[0]
|
||||
assert result == expected
|
||||
@@ -459,11 +467,12 @@ async def test_add_async(mem_db_async: AsyncConnection):
|
||||
],
|
||||
)
|
||||
assert await table.count_rows() == 2
|
||||
await table.add(
|
||||
add_res = await table.add(
|
||||
data=[
|
||||
{"vector": [10.0, 11.0], "item": "baz", "price": 30.0},
|
||||
],
|
||||
)
|
||||
assert add_res.version == 2
|
||||
assert await table.count_rows() == 3
|
||||
|
||||
|
||||
@@ -795,7 +804,8 @@ def test_delete(mem_db: DBConnection):
|
||||
)
|
||||
assert len(table) == 2
|
||||
assert len(table.list_versions()) == 1
|
||||
table.delete("id=0")
|
||||
delete_res = table.delete("id=0")
|
||||
assert delete_res.version == 2
|
||||
assert len(table.list_versions()) == 2
|
||||
assert table.version == 2
|
||||
assert len(table) == 1
|
||||
@@ -809,7 +819,9 @@ def test_update(mem_db: DBConnection):
|
||||
)
|
||||
assert len(table) == 2
|
||||
assert len(table.list_versions()) == 1
|
||||
table.update(where="id=0", values={"vector": [1.1, 1.1]})
|
||||
update_res = table.update(where="id=0", values={"vector": [1.1, 1.1]})
|
||||
assert update_res.version == 2
|
||||
assert update_res.rows_updated == 1
|
||||
assert len(table.list_versions()) == 2
|
||||
assert table.version == 2
|
||||
assert len(table) == 2
|
||||
@@ -898,9 +910,16 @@ def test_merge_insert(mem_db: DBConnection):
|
||||
new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})
|
||||
|
||||
# upsert
|
||||
table.merge_insert(
|
||||
"a"
|
||||
).when_matched_update_all().when_not_matched_insert_all().execute(new_data)
|
||||
merge_insert_res = (
|
||||
table.merge_insert("a")
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.execute(new_data)
|
||||
)
|
||||
assert merge_insert_res.version == 2
|
||||
assert merge_insert_res.num_inserted_rows == 1
|
||||
assert merge_insert_res.num_updated_rows == 2
|
||||
assert merge_insert_res.num_deleted_rows == 0
|
||||
|
||||
expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]})
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
@@ -908,17 +927,28 @@ def test_merge_insert(mem_db: DBConnection):
|
||||
table.restore(version)
|
||||
|
||||
# conditional update
|
||||
table.merge_insert("a").when_matched_update_all(where="target.b = 'b'").execute(
|
||||
new_data
|
||||
merge_insert_res = (
|
||||
table.merge_insert("a")
|
||||
.when_matched_update_all(where="target.b = 'b'")
|
||||
.execute(new_data)
|
||||
)
|
||||
assert merge_insert_res.version == 4
|
||||
assert merge_insert_res.num_inserted_rows == 0
|
||||
assert merge_insert_res.num_updated_rows == 1
|
||||
assert merge_insert_res.num_deleted_rows == 0
|
||||
expected = pa.table({"a": [1, 2, 3], "b": ["a", "x", "c"]})
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
|
||||
table.restore(version)
|
||||
|
||||
# insert-if-not-exists
|
||||
table.merge_insert("a").when_not_matched_insert_all().execute(new_data)
|
||||
|
||||
merge_insert_res = (
|
||||
table.merge_insert("a").when_not_matched_insert_all().execute(new_data)
|
||||
)
|
||||
assert merge_insert_res.version == 6
|
||||
assert merge_insert_res.num_inserted_rows == 1
|
||||
assert merge_insert_res.num_updated_rows == 0
|
||||
assert merge_insert_res.num_deleted_rows == 0
|
||||
expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "b", "c", "z"]})
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
|
||||
@@ -927,13 +957,17 @@ def test_merge_insert(mem_db: DBConnection):
|
||||
new_data = pa.table({"a": [2, 4], "b": ["x", "z"]})
|
||||
|
||||
# replace-range
|
||||
(
|
||||
merge_insert_res = (
|
||||
table.merge_insert("a")
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.when_not_matched_by_source_delete("a > 2")
|
||||
.execute(new_data)
|
||||
)
|
||||
assert merge_insert_res.version == 8
|
||||
assert merge_insert_res.num_inserted_rows == 1
|
||||
assert merge_insert_res.num_updated_rows == 1
|
||||
assert merge_insert_res.num_deleted_rows == 1
|
||||
|
||||
expected = pa.table({"a": [1, 2, 4], "b": ["a", "x", "z"]})
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
@@ -941,11 +975,17 @@ def test_merge_insert(mem_db: DBConnection):
|
||||
table.restore(version)
|
||||
|
||||
# replace-range no condition
|
||||
table.merge_insert(
|
||||
"a"
|
||||
).when_matched_update_all().when_not_matched_insert_all().when_not_matched_by_source_delete().execute(
|
||||
new_data
|
||||
merge_insert_res = (
|
||||
table.merge_insert("a")
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.when_not_matched_by_source_delete()
|
||||
.execute(new_data)
|
||||
)
|
||||
assert merge_insert_res.version == 10
|
||||
assert merge_insert_res.num_inserted_rows == 1
|
||||
assert merge_insert_res.num_updated_rows == 1
|
||||
assert merge_insert_res.num_deleted_rows == 2
|
||||
|
||||
expected = pa.table({"a": [2, 4], "b": ["x", "z"]})
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
@@ -1478,11 +1518,13 @@ def test_restore_consistency(tmp_path):
|
||||
def test_add_columns(mem_db: DBConnection):
|
||||
data = pa.table({"id": [0, 1]})
|
||||
table = LanceTable.create(mem_db, "my_table", data=data)
|
||||
table.add_columns({"new_col": "id + 2"})
|
||||
add_columns_res = table.add_columns({"new_col": "id + 2"})
|
||||
assert add_columns_res.version == 2
|
||||
assert table.to_arrow().column_names == ["id", "new_col"]
|
||||
assert table.to_arrow()["new_col"].to_pylist() == [2, 3]
|
||||
|
||||
table.add_columns({"null_int": "cast(null as bigint)"})
|
||||
add_columns_res = table.add_columns({"null_int": "cast(null as bigint)"})
|
||||
assert add_columns_res.version == 3
|
||||
assert table.schema.field("null_int").type == pa.int64()
|
||||
|
||||
|
||||
@@ -1490,7 +1532,8 @@ def test_add_columns(mem_db: DBConnection):
|
||||
async def test_add_columns_async(mem_db_async: AsyncConnection):
|
||||
data = pa.table({"id": [0, 1]})
|
||||
table = await mem_db_async.create_table("my_table", data=data)
|
||||
await table.add_columns({"new_col": "id + 2"})
|
||||
add_columns_res = await table.add_columns({"new_col": "id + 2"})
|
||||
assert add_columns_res.version == 2
|
||||
data = await table.to_arrow()
|
||||
assert data.column_names == ["id", "new_col"]
|
||||
assert data["new_col"].to_pylist() == [2, 3]
|
||||
@@ -1500,9 +1543,10 @@ async def test_add_columns_async(mem_db_async: AsyncConnection):
|
||||
async def test_add_columns_with_schema(mem_db_async: AsyncConnection):
|
||||
data = pa.table({"id": [0, 1]})
|
||||
table = await mem_db_async.create_table("my_table", data=data)
|
||||
await table.add_columns(
|
||||
add_columns_res = await table.add_columns(
|
||||
[pa.field("x", pa.int64()), pa.field("vector", pa.list_(pa.float32(), 8))]
|
||||
)
|
||||
assert add_columns_res.version == 2
|
||||
|
||||
assert await table.schema() == pa.schema(
|
||||
[
|
||||
@@ -1513,11 +1557,12 @@ async def test_add_columns_with_schema(mem_db_async: AsyncConnection):
|
||||
)
|
||||
|
||||
table = await mem_db_async.create_table("table2", data=data)
|
||||
await table.add_columns(
|
||||
add_columns_res = await table.add_columns(
|
||||
pa.schema(
|
||||
[pa.field("y", pa.int64()), pa.field("emb", pa.list_(pa.float32(), 8))]
|
||||
)
|
||||
)
|
||||
assert add_columns_res.version == 2
|
||||
assert await table.schema() == pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int64()),
|
||||
@@ -1530,7 +1575,8 @@ async def test_add_columns_with_schema(mem_db_async: AsyncConnection):
|
||||
def test_alter_columns(mem_db: DBConnection):
|
||||
data = pa.table({"id": [0, 1]})
|
||||
table = mem_db.create_table("my_table", data=data)
|
||||
table.alter_columns({"path": "id", "rename": "new_id"})
|
||||
alter_columns_res = table.alter_columns({"path": "id", "rename": "new_id"})
|
||||
assert alter_columns_res.version == 2
|
||||
assert table.to_arrow().column_names == ["new_id"]
|
||||
|
||||
|
||||
@@ -1538,9 +1584,13 @@ def test_alter_columns(mem_db: DBConnection):
|
||||
async def test_alter_columns_async(mem_db_async: AsyncConnection):
|
||||
data = pa.table({"id": [0, 1]})
|
||||
table = await mem_db_async.create_table("my_table", data=data)
|
||||
await table.alter_columns({"path": "id", "rename": "new_id"})
|
||||
alter_columns_res = await table.alter_columns({"path": "id", "rename": "new_id"})
|
||||
assert alter_columns_res.version == 2
|
||||
assert (await table.to_arrow()).column_names == ["new_id"]
|
||||
await table.alter_columns(dict(path="new_id", data_type=pa.int16(), nullable=True))
|
||||
alter_columns_res = await table.alter_columns(
|
||||
dict(path="new_id", data_type=pa.int16(), nullable=True)
|
||||
)
|
||||
assert alter_columns_res.version == 3
|
||||
data = await table.to_arrow()
|
||||
assert data.column(0).type == pa.int16()
|
||||
assert data.schema.field(0).nullable
|
||||
@@ -1549,7 +1599,8 @@ async def test_alter_columns_async(mem_db_async: AsyncConnection):
|
||||
def test_drop_columns(mem_db: DBConnection):
|
||||
data = pa.table({"id": [0, 1], "category": ["a", "b"]})
|
||||
table = mem_db.create_table("my_table", data=data)
|
||||
table.drop_columns(["category"])
|
||||
drop_columns_res = table.drop_columns(["category"])
|
||||
assert drop_columns_res.version == 2
|
||||
assert table.to_arrow().column_names == ["id"]
|
||||
|
||||
|
||||
@@ -1557,7 +1608,8 @@ def test_drop_columns(mem_db: DBConnection):
|
||||
async def test_drop_columns_async(mem_db_async: AsyncConnection):
|
||||
data = pa.table({"id": [0, 1], "category": ["a", "b"]})
|
||||
table = await mem_db_async.create_table("my_table", data=data)
|
||||
await table.drop_columns(["category"])
|
||||
drop_columns_res = await table.drop_columns(["category"])
|
||||
assert drop_columns_res.version == 2
|
||||
assert (await table.to_arrow()).column_names == ["id"]
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,10 @@ use pyo3::{
|
||||
wrap_pyfunction, Bound, PyResult, Python,
|
||||
};
|
||||
use query::{FTSQuery, HybridQuery, Query, VectorQuery};
|
||||
use table::Table;
|
||||
use table::{
|
||||
AddColumnsResult, AddResult, AlterColumnsResult, DeleteResult, DropColumnsResult, MergeResult,
|
||||
Table, UpdateResult,
|
||||
};
|
||||
|
||||
pub mod arrow;
|
||||
pub mod connection;
|
||||
@@ -35,6 +38,13 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<HybridQuery>()?;
|
||||
m.add_class::<VectorQuery>()?;
|
||||
m.add_class::<RecordBatchStream>()?;
|
||||
m.add_class::<AddColumnsResult>()?;
|
||||
m.add_class::<AlterColumnsResult>()?;
|
||||
m.add_class::<AddResult>()?;
|
||||
m.add_class::<MergeResult>()?;
|
||||
m.add_class::<DeleteResult>()?;
|
||||
m.add_class::<DropColumnsResult>()?;
|
||||
m.add_class::<UpdateResult>()?;
|
||||
m.add_function(wrap_pyfunction!(connect, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
|
||||
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
||||
|
||||
@@ -58,6 +58,170 @@ pub struct OptimizeStats {
|
||||
pub prune: RemovalStats,
|
||||
}
|
||||
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct UpdateResult {
|
||||
pub rows_updated: u64,
|
||||
pub version: u64,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl UpdateResult {
|
||||
pub fn __repr__(&self) -> String {
|
||||
format!(
|
||||
"UpdateResult(rows_updated={}, version={})",
|
||||
self.rows_updated, self.version
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<lancedb::table::UpdateResult> for UpdateResult {
|
||||
fn from(result: lancedb::table::UpdateResult) -> Self {
|
||||
Self {
|
||||
rows_updated: result.rows_updated,
|
||||
version: result.version,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AddResult {
|
||||
pub version: u64,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl AddResult {
|
||||
pub fn __repr__(&self) -> String {
|
||||
format!("AddResult(version={})", self.version)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<lancedb::table::AddResult> for AddResult {
|
||||
fn from(result: lancedb::table::AddResult) -> Self {
|
||||
Self {
|
||||
version: result.version,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct DeleteResult {
|
||||
pub version: u64,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl DeleteResult {
|
||||
pub fn __repr__(&self) -> String {
|
||||
format!("DeleteResult(version={})", self.version)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<lancedb::table::DeleteResult> for DeleteResult {
|
||||
fn from(result: lancedb::table::DeleteResult) -> Self {
|
||||
Self {
|
||||
version: result.version,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MergeResult {
|
||||
pub version: u64,
|
||||
pub num_updated_rows: u64,
|
||||
pub num_inserted_rows: u64,
|
||||
pub num_deleted_rows: u64,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl MergeResult {
|
||||
pub fn __repr__(&self) -> String {
|
||||
format!(
|
||||
"MergeResult(version={}, num_updated_rows={}, num_inserted_rows={}, num_deleted_rows={})",
|
||||
self.version,
|
||||
self.num_updated_rows,
|
||||
self.num_inserted_rows,
|
||||
self.num_deleted_rows
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<lancedb::table::MergeResult> for MergeResult {
|
||||
fn from(result: lancedb::table::MergeResult) -> Self {
|
||||
Self {
|
||||
version: result.version,
|
||||
num_updated_rows: result.num_updated_rows,
|
||||
num_inserted_rows: result.num_inserted_rows,
|
||||
num_deleted_rows: result.num_deleted_rows,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AddColumnsResult {
|
||||
pub version: u64,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl AddColumnsResult {
|
||||
pub fn __repr__(&self) -> String {
|
||||
format!("AddColumnsResult(version={})", self.version)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<lancedb::table::AddColumnsResult> for AddColumnsResult {
|
||||
fn from(result: lancedb::table::AddColumnsResult) -> Self {
|
||||
Self {
|
||||
version: result.version,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AlterColumnsResult {
|
||||
pub version: u64,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl AlterColumnsResult {
|
||||
pub fn __repr__(&self) -> String {
|
||||
format!("AlterColumnsResult(version={})", self.version)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<lancedb::table::AlterColumnsResult> for AlterColumnsResult {
|
||||
fn from(result: lancedb::table::AlterColumnsResult) -> Self {
|
||||
Self {
|
||||
version: result.version,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(get_all)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct DropColumnsResult {
|
||||
pub version: u64,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl DropColumnsResult {
|
||||
pub fn __repr__(&self) -> String {
|
||||
format!("DropColumnsResult(version={})", self.version)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<lancedb::table::DropColumnsResult> for DropColumnsResult {
|
||||
fn from(result: lancedb::table::DropColumnsResult) -> Self {
|
||||
Self {
|
||||
version: result.version,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
pub struct Table {
|
||||
// We keep a copy of the name to use if the inner table is dropped
|
||||
@@ -132,15 +296,16 @@ impl Table {
|
||||
}
|
||||
|
||||
future_into_py(self_.py(), async move {
|
||||
op.execute().await.infer_error()?;
|
||||
Ok(())
|
||||
let result = op.execute().await.infer_error()?;
|
||||
Ok(AddResult::from(result))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn delete(self_: PyRef<'_, Self>, condition: String) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.delete(&condition).await.infer_error()
|
||||
let result = inner.delete(&condition).await.infer_error()?;
|
||||
Ok(DeleteResult::from(result))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -160,8 +325,8 @@ impl Table {
|
||||
op = op.column(column_name, value);
|
||||
}
|
||||
future_into_py(self_.py(), async move {
|
||||
op.execute().await.infer_error()?;
|
||||
Ok(())
|
||||
let result = op.execute().await.infer_error()?;
|
||||
Ok(UpdateResult::from(result))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -489,14 +654,8 @@ impl Table {
|
||||
}
|
||||
|
||||
future_into_py(self_.py(), async move {
|
||||
let stats = builder.execute(Box::new(batches)).await.infer_error()?;
|
||||
Python::with_gil(|py| {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("num_inserted_rows", stats.num_inserted_rows)?;
|
||||
dict.set_item("num_updated_rows", stats.num_updated_rows)?;
|
||||
dict.set_item("num_deleted_rows", stats.num_deleted_rows)?;
|
||||
Ok(dict.unbind())
|
||||
})
|
||||
let res = builder.execute(Box::new(batches)).await.infer_error()?;
|
||||
Ok(MergeResult::from(res))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -532,8 +691,8 @@ impl Table {
|
||||
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.add_columns(definitions, None).await.infer_error()?;
|
||||
Ok(())
|
||||
let result = inner.add_columns(definitions, None).await.infer_error()?;
|
||||
Ok(AddColumnsResult::from(result))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -546,8 +705,8 @@ impl Table {
|
||||
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.add_columns(transform, None).await.infer_error()?;
|
||||
Ok(())
|
||||
let result = inner.add_columns(transform, None).await.infer_error()?;
|
||||
Ok(AddColumnsResult::from(result))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -590,8 +749,8 @@ impl Table {
|
||||
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.alter_columns(&alterations).await.infer_error()?;
|
||||
Ok(())
|
||||
let result = inner.alter_columns(&alterations).await.infer_error()?;
|
||||
Ok(AlterColumnsResult::from(result))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -599,8 +758,8 @@ impl Table {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let column_refs = columns.iter().map(String::as_str).collect::<Vec<&str>>();
|
||||
inner.drop_columns(&column_refs).await.infer_error()?;
|
||||
Ok(())
|
||||
let result = inner.drop_columns(&column_refs).await.infer_error()?;
|
||||
Ok(DropColumnsResult::from(result))
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user