https://github.com/lancedb/lancedb.git
feat: add columns using pyarrow schema (#2284)
@@ -52,6 +52,7 @@ class Table:
     async def list_indices(self) -> list[IndexConfig]: ...
     async def delete(self, filter: str): ...
     async def add_columns(self, columns: list[tuple[str, str]]) -> None: ...
+    async def add_columns_with_schema(self, schema: pa.Schema) -> None: ...
     async def alter_columns(self, columns: list[dict[str, Any]]) -> None: ...
     async def optimize(
         self,
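For orientation, the stub above splits the work between two native entry points: `add_columns` keeps taking `(column_name, SQL expression)` tuples, while the new `add_columns_with_schema` takes a pyarrow schema. A minimal sketch of the two calls, assuming `inner` is the native table handle held by the Python wrapper (the handle name is illustrative only):

```python
import pyarrow as pa

async def sketch(inner):
    # Existing entry point: SQL expressions, passed as (name, expression) pairs.
    await inner.add_columns([("double_id", "id * 2")])

    # New entry point: a pyarrow schema; the columns are created filled with nulls.
    new_cols = pa.schema([pa.field("embedding", pa.list_(pa.float32(), 8))])
    await inner.add_columns_with_schema(new_cols)
```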
@@ -1265,16 +1265,21 @@ class Table(ABC):
         """

     @abstractmethod
-    def add_columns(self, transforms: Dict[str, str]):
+    def add_columns(
+        self, transforms: Dict[str, str] | pa.Field | List[pa.Field] | pa.Schema
+    ):
         """
         Add new columns with defined values.

         Parameters
         ----------
-        transforms: Dict[str, str]
+        transforms: Dict[str, str], pa.Field, List[pa.Field], pa.Schema
             A map of column name to a SQL expression to use to calculate the
             value of the new column. These expressions will be evaluated for
             each row in the table, and can reference existing columns.
+            Alternatively, a pyarrow Field or Schema can be provided to add
+            new columns with the specified data types. The new columns will
+            be initialized with null values.
         """

     @abstractmethod
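As a quick illustration of the widened signature, here is a sketch of the two call forms the docstring describes, using the synchronous API; the connection and table setup are assumed for the example and are not part of this change:

```python
import lancedb
import pyarrow as pa

db = lancedb.connect("./example-db")  # illustrative path
table = db.create_table("items", data=pa.table({"id": [0, 1]}))

# SQL-expression form: each value is computed per row and may reference existing columns.
table.add_columns({"id_times_two": "id * 2"})

# pyarrow form: a Field, a list of Fields, or a Schema; the new columns start out as nulls.
table.add_columns(pa.field("note", pa.string()))
table.add_columns(pa.schema([pa.field("vector", pa.list_(pa.float32(), 8))]))
```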
@@ -2460,7 +2465,9 @@ class LanceTable(Table):
         """
         return LOOP.run(self._table.index_stats(index_name))

-    def add_columns(self, transforms: Dict[str, str]):
+    def add_columns(
+        self, transforms: Dict[str, str] | pa.Field | List[pa.Field] | pa.Schema
+    ):
         LOOP.run(self._table.add_columns(transforms))

     def alter_columns(self, *alterations: Iterable[Dict[str, str]]):
@@ -3519,7 +3526,9 @@ class AsyncTable:

         return await self._inner.update(updates_sql, where)

-    async def add_columns(self, transforms: dict[str, str]):
+    async def add_columns(
+        self, transforms: dict[str, str] | pa.Field | List[pa.Field] | pa.Schema
+    ):
         """
         Add new columns with defined values.

@@ -3529,8 +3538,19 @@ class AsyncTable:
             A map of column name to a SQL expression to use to calculate the
             value of the new column. These expressions will be evaluated for
             each row in the table, and can reference existing columns.
+            Alternatively, you can pass a pyarrow field or schema to add
+            new columns with NULLs.
         """
-        await self._inner.add_columns(list(transforms.items()))
+        if isinstance(transforms, pa.Field):
+            transforms = [transforms]
+        if isinstance(transforms, list) and all(
+            {isinstance(f, pa.Field) for f in transforms}
+        ):
+            transforms = pa.schema(transforms)
+        if isinstance(transforms, pa.Schema):
+            await self._inner.add_columns_with_schema(transforms)
+        else:
+            await self._inner.add_columns(list(transforms.items()))

     async def alter_columns(self, *alterations: Iterable[dict[str, Any]]):
         """
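The implementation above normalizes the input in stages: a single `pa.Field` becomes a one-element list, a list of fields becomes a `pa.Schema`, and a schema is routed to the new `add_columns_with_schema` binding, while everything else falls through to the original SQL-expression path. A brief async usage sketch, with the connection setup assumed for illustration:

```python
import lancedb
import pyarrow as pa

async def example():
    db = await lancedb.connect_async("./example-db")  # illustrative path
    table = await db.create_table("items", data=pa.table({"id": [0, 1]}))

    # SQL-expression path: forwarded as (name, expression) pairs.
    await table.add_columns({"id_doubled": "id * 2"})

    # pyarrow path: a single Field is accepted; the new column is filled with nulls.
    await table.add_columns(pa.field("label", pa.string()))
```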
@@ -1384,6 +1384,37 @@ async def test_add_columns_async(mem_db_async: AsyncConnection):
     assert data["new_col"].to_pylist() == [2, 3]


+@pytest.mark.asyncio
+async def test_add_columns_with_schema(mem_db_async: AsyncConnection):
+    data = pa.table({"id": [0, 1]})
+    table = await mem_db_async.create_table("my_table", data=data)
+    await table.add_columns(
+        [pa.field("x", pa.int64()), pa.field("vector", pa.list_(pa.float32(), 8))]
+    )
+
+    assert await table.schema() == pa.schema(
+        [
+            pa.field("id", pa.int64()),
+            pa.field("x", pa.int64()),
+            pa.field("vector", pa.list_(pa.float32(), 8)),
+        ]
+    )
+
+    table = await mem_db_async.create_table("table2", data=data)
+    await table.add_columns(
+        pa.schema(
+            [pa.field("y", pa.int64()), pa.field("emb", pa.list_(pa.float32(), 8))]
+        )
+    )
+    assert await table.schema() == pa.schema(
+        [
+            pa.field("id", pa.int64()),
+            pa.field("y", pa.int64()),
+            pa.field("emb", pa.list_(pa.float32(), 8)),
+        ]
+    )
+
+
 def test_alter_columns(mem_db: DBConnection):
     data = pa.table({"id": [0, 1]})
     table = mem_db.create_table("my_table", data=data)
@@ -1,9 +1,11 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright The LanceDB Authors
+use std::{collections::HashMap, sync::Arc};
+
 use arrow::{
-    datatypes::DataType,
+    datatypes::{DataType, Schema},
     ffi_stream::ArrowArrayStreamReader,
-    pyarrow::{FromPyArrow, ToPyArrow},
+    pyarrow::{FromPyArrow, PyArrowType, ToPyArrow},
 };
 use lancedb::table::{
     AddDataMode, ColumnAlteration, Duration, NewColumnTransform, OptimizeAction, OptimizeOptions,
@@ -16,7 +18,6 @@ use pyo3::{
     Bound, FromPyObject, PyAny, PyRef, PyResult, Python,
 };
 use pyo3_async_runtimes::tokio::future_into_py;
-use std::collections::HashMap;

 use crate::{
     error::PythonErrorExt,
@@ -444,6 +445,20 @@ impl Table {
         })
     }

+    pub fn add_columns_with_schema(
+        self_: PyRef<'_, Self>,
+        schema: PyArrowType<Schema>,
+    ) -> PyResult<Bound<'_, PyAny>> {
+        let arrow_schema = &schema.0;
+        let transform = NewColumnTransform::AllNulls(Arc::new(arrow_schema.clone()));
+
+        let inner = self_.inner_ref()?.clone();
+        future_into_py(self_.py(), async move {
+            inner.add_columns(transform, None).await.infer_error()?;
+            Ok(())
+        })
+    }
+
     pub fn alter_columns<'a>(
         self_: PyRef<'a, Self>,
         alterations: Vec<Bound<PyDict>>,
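On the Rust side the incoming schema is wrapped in `NewColumnTransform::AllNulls`, which is why the added columns come back entirely null. A minimal way to observe that from Python, using the synchronous API; the path and table name are illustrative:

```python
import lancedb
import pyarrow as pa

db = lancedb.connect("./example-db")  # illustrative path
table = db.create_table("nulls_demo", data=pa.table({"id": [0, 1]}))
table.add_columns(pa.field("score", pa.float32()))

# AllNulls means the new column exists in the schema but every row is null.
tbl = table.to_arrow()
assert tbl.column("score").null_count == tbl.num_rows
```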