mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
feat: add columns using pyarrow schema (#2284)
This commit is contained in:
@@ -52,6 +52,7 @@ class Table:
|
||||
# Stub: resolves to a list with one IndexConfig entry per index (per the return annotation).
async def list_indices(self) -> list[IndexConfig]: ...
|
||||
# Stub: `filter` is presumably a SQL predicate selecting the rows to delete -- TODO confirm against the native implementation.
async def delete(self, filter: str): ...
|
||||
# Stub: the Python wrapper passes `list(transforms.items())` here, i.e. (column name, SQL expression) pairs.
async def add_columns(self, columns: list[tuple[str, str]]) -> None: ...
|
||||
# Stub: schema-based variant; the Python wrapper routes pa.Field / list-of-Field / pa.Schema input here.
async def add_columns_with_schema(self, schema: pa.Schema) -> None: ...
|
||||
# Stub: takes one dict per column alteration; the expected keys are not visible from this stub -- check the implementation.
async def alter_columns(self, columns: list[dict[str, Any]]) -> None: ...
|
||||
async def optimize(
|
||||
self,
|
||||
|
||||
@@ -1265,16 +1265,21 @@ class Table(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
def add_columns(
    self, transforms: Dict[str, str] | pa.Field | List[pa.Field] | pa.Schema
):
    """
    Add new columns to the table.

    Parameters
    ----------
    transforms: Dict[str, str] or pa.Field or List[pa.Field] or pa.Schema
        When a dict is given, it is a map of column name to a SQL
        expression to use to calculate the value of the new column. These
        expressions will be evaluated for each row in the table, and can
        reference existing columns.

        Alternatively, a pyarrow Field, a list of Fields, or a Schema can
        be provided to add new columns with the specified data types. The
        new columns will be initialized with null values.
    """
|
||||
|
||||
@abstractmethod
|
||||
@@ -2460,7 +2465,9 @@ class LanceTable(Table):
|
||||
"""
|
||||
return LOOP.run(self._table.index_stats(index_name))
|
||||
|
||||
def add_columns(
    self, transforms: Dict[str, str] | pa.Field | List[pa.Field] | pa.Schema
):
    """
    Add new columns to the table.

    Delegates to ``self._table.add_columns`` and drives the call through
    ``LOOP.run``; ``transforms`` is forwarded unchanged.

    Parameters
    ----------
    transforms: Dict[str, str] or pa.Field or List[pa.Field] or pa.Schema
        Either a map of column name to a SQL expression used to calculate
        the new column's value, or a pyarrow Field, list of Fields, or
        Schema describing the columns to add.
    """
    # The annotation names the ``pa.Field`` class; ``pa.field`` is the
    # factory function that builds one and is not a type.
    LOOP.run(self._table.add_columns(transforms))
|
||||
|
||||
def alter_columns(self, *alterations: Iterable[Dict[str, str]]):
|
||||
@@ -3519,7 +3526,9 @@ class AsyncTable:
|
||||
|
||||
return await self._inner.update(updates_sql, where)
|
||||
|
||||
async def add_columns(
    self, transforms: dict[str, str] | pa.Field | List[pa.Field] | pa.Schema
):
    """
    Add new columns with defined values.

    Parameters
    ----------
    transforms: dict[str, str] or pa.Field or List[pa.Field] or pa.Schema
        A map of column name to a SQL expression to use to calculate the
        value of the new column. These expressions will be evaluated for
        each row in the table, and can reference existing columns.

        Alternatively, you can pass a pyarrow field, a list of fields, or
        a schema to add new columns with NULLs.
    """
    # A lone field is treated as a one-element list of fields.
    if isinstance(transforms, pa.Field):
        transforms = [transforms]

    # A list consisting only of fields is promoted to a schema. The
    # generator lets all() stop at the first non-field element; an empty
    # list also takes this path, since all() of no items is True.
    if isinstance(transforms, list) and all(
        isinstance(f, pa.Field) for f in transforms
    ):
        transforms = pa.schema(transforms)

    if isinstance(transforms, pa.Schema):
        await self._inner.add_columns_with_schema(transforms)
    else:
        # Anything else is expected to be a mapping of name -> SQL
        # expression; the inner table takes it as a list of pairs.
        await self._inner.add_columns(list(transforms.items()))
|
||||
|
||||
async def alter_columns(self, *alterations: Iterable[dict[str, Any]]):
|
||||
"""
|
||||
|
||||
@@ -1384,6 +1384,37 @@ async def test_add_columns_async(mem_db_async: AsyncConnection):
|
||||
assert data["new_col"].to_pylist() == [2, 3]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_columns_with_schema(mem_db_async: AsyncConnection):
    """Check add_columns with a list of fields, a schema, and a single field."""
    data = pa.table({"id": [0, 1]})

    # Case 1: a list of pyarrow fields.
    table = await mem_db_async.create_table("my_table", data=data)
    await table.add_columns(
        [pa.field("x", pa.int64()), pa.field("vector", pa.list_(pa.float32(), 8))]
    )

    assert await table.schema() == pa.schema(
        [
            pa.field("id", pa.int64()),
            pa.field("x", pa.int64()),
            pa.field("vector", pa.list_(pa.float32(), 8)),
        ]
    )

    # Case 2: a pyarrow schema.
    table = await mem_db_async.create_table("table2", data=data)
    await table.add_columns(
        pa.schema(
            [pa.field("y", pa.int64()), pa.field("emb", pa.list_(pa.float32(), 8))]
        )
    )
    assert await table.schema() == pa.schema(
        [
            pa.field("id", pa.int64()),
            pa.field("y", pa.int64()),
            pa.field("emb", pa.list_(pa.float32(), 8)),
        ]
    )

    # Case 3: a single pyarrow field, which add_columns wraps in a list
    # before promoting it to a schema.
    table = await mem_db_async.create_table("table3", data=data)
    await table.add_columns(pa.field("z", pa.int64()))
    assert await table.schema() == pa.schema(
        [
            pa.field("id", pa.int64()),
            pa.field("z", pa.int64()),
        ]
    )
|
||||
|
||||
|
||||
def test_alter_columns(mem_db: DBConnection):
|
||||
data = pa.table({"id": [0, 1]})
|
||||
table = mem_db.create_table("my_table", data=data)
|
||||
|
||||
Reference in New Issue
Block a user