feat: add columns using pyarrow schema (#2284)

This commit is contained in:
Lei Xu
2025-03-28 08:51:50 -07:00
committed by GitHub
parent c321cccc12
commit f52d05d3fa
4 changed files with 75 additions and 8 deletions

View File

@@ -52,6 +52,7 @@ class Table:
async def list_indices(self) -> list[IndexConfig]: ...
async def delete(self, filter: str): ...
async def add_columns(self, columns: list[tuple[str, str]]) -> None: ...
async def add_columns_with_schema(self, schema: pa.Schema) -> None: ...
async def alter_columns(self, columns: list[dict[str, Any]]) -> None: ...
async def optimize(
self,

View File

@@ -1265,16 +1265,21 @@ class Table(ABC):
"""
@abstractmethod
def add_columns(self, transforms: Dict[str, str]):
def add_columns(
self, transforms: Dict[str, str] | pa.Field | List[pa.Field] | pa.Schema
):
"""
Add new columns with defined values.
Parameters
----------
transforms: Dict[str, str]
transforms: Dict[str, str], pa.Field, List[pa.Field], pa.Schema
A map of column name to a SQL expression to use to calculate the
value of the new column. These expressions will be evaluated for
each row in the table, and can reference existing columns.
Alternatively, a pyarrow Field or Schema can be provided to add
new columns with the specified data types. The new columns will
be initialized with null values.
"""
@abstractmethod
@@ -2460,7 +2465,9 @@ class LanceTable(Table):
"""
return LOOP.run(self._table.index_stats(index_name))
def add_columns(self, transforms: Dict[str, str]):
def add_columns(
self, transforms: Dict[str, str] | pa.field | List[pa.field] | pa.Schema
):
LOOP.run(self._table.add_columns(transforms))
def alter_columns(self, *alterations: Iterable[Dict[str, str]]):
@@ -3519,7 +3526,9 @@ class AsyncTable:
return await self._inner.update(updates_sql, where)
async def add_columns(self, transforms: dict[str, str]):
async def add_columns(
self, transforms: dict[str, str] | pa.field | List[pa.field] | pa.Schema
):
"""
Add new columns with defined values.
@@ -3529,8 +3538,19 @@ class AsyncTable:
A map of column name to a SQL expression to use to calculate the
value of the new column. These expressions will be evaluated for
each row in the table, and can reference existing columns.
Alternatively, you can pass a pyarrow field or schema to add
new columns with NULLs.
"""
await self._inner.add_columns(list(transforms.items()))
if isinstance(transforms, pa.Field):
transforms = [transforms]
if isinstance(transforms, list) and all(
{isinstance(f, pa.Field) for f in transforms}
):
transforms = pa.schema(transforms)
if isinstance(transforms, pa.Schema):
await self._inner.add_columns_with_schema(transforms)
else:
await self._inner.add_columns(list(transforms.items()))
async def alter_columns(self, *alterations: Iterable[dict[str, Any]]):
"""

View File

@@ -1384,6 +1384,37 @@ async def test_add_columns_async(mem_db_async: AsyncConnection):
assert data["new_col"].to_pylist() == [2, 3]
@pytest.mark.asyncio
async def test_add_columns_with_schema(mem_db_async: AsyncConnection):
data = pa.table({"id": [0, 1]})
table = await mem_db_async.create_table("my_table", data=data)
await table.add_columns(
[pa.field("x", pa.int64()), pa.field("vector", pa.list_(pa.float32(), 8))]
)
assert await table.schema() == pa.schema(
[
pa.field("id", pa.int64()),
pa.field("x", pa.int64()),
pa.field("vector", pa.list_(pa.float32(), 8)),
]
)
table = await mem_db_async.create_table("table2", data=data)
await table.add_columns(
pa.schema(
[pa.field("y", pa.int64()), pa.field("emb", pa.list_(pa.float32(), 8))]
)
)
assert await table.schema() == pa.schema(
[
pa.field("id", pa.int64()),
pa.field("y", pa.int64()),
pa.field("emb", pa.list_(pa.float32(), 8)),
]
)
def test_alter_columns(mem_db: DBConnection):
data = pa.table({"id": [0, 1]})
table = mem_db.create_table("my_table", data=data)