diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index b1d27d95..a68de4e2 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -52,6 +52,7 @@ class Table: async def list_indices(self) -> list[IndexConfig]: ... async def delete(self, filter: str): ... async def add_columns(self, columns: list[tuple[str, str]]) -> None: ... + async def add_columns_with_schema(self, schema: pa.Schema) -> None: ... async def alter_columns(self, columns: list[dict[str, Any]]) -> None: ... async def optimize( self, diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 0beb4936..d579f9d8 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -1265,16 +1265,21 @@ class Table(ABC): """ @abstractmethod - def add_columns(self, transforms: Dict[str, str]): + def add_columns( + self, transforms: Dict[str, str] | pa.Field | List[pa.Field] | pa.Schema + ): """ Add new columns with defined values. Parameters ---------- - transforms: Dict[str, str] + transforms: Dict[str, str], pa.Field, List[pa.Field], pa.Schema A map of column name to a SQL expression to use to calculate the value of the new column. These expressions will be evaluated for each row in the table, and can reference existing columns. + Alternatively, a pyarrow Field or Schema can be provided to add + new columns with the specified data types. The new columns will + be initialized with null values. 
""" @abstractmethod @@ -2460,7 +2465,9 @@ class LanceTable(Table): """ return LOOP.run(self._table.index_stats(index_name)) - def add_columns(self, transforms: Dict[str, str]): + def add_columns( + self, transforms: Dict[str, str] | pa.field | List[pa.field] | pa.Schema + ): LOOP.run(self._table.add_columns(transforms)) def alter_columns(self, *alterations: Iterable[Dict[str, str]]): @@ -3519,7 +3526,9 @@ class AsyncTable: return await self._inner.update(updates_sql, where) - async def add_columns(self, transforms: dict[str, str]): + async def add_columns( + self, transforms: dict[str, str] | pa.field | List[pa.field] | pa.Schema + ): """ Add new columns with defined values. @@ -3529,8 +3538,19 @@ class AsyncTable: A map of column name to a SQL expression to use to calculate the value of the new column. These expressions will be evaluated for each row in the table, and can reference existing columns. + Alternatively, you can pass a pyarrow field or schema to add + new columns with NULLs. 
""" - await self._inner.add_columns(list(transforms.items())) + if isinstance(transforms, pa.Field): + transforms = [transforms] + if isinstance(transforms, list) and all( + {isinstance(f, pa.Field) for f in transforms} + ): + transforms = pa.schema(transforms) + if isinstance(transforms, pa.Schema): + await self._inner.add_columns_with_schema(transforms) + else: + await self._inner.add_columns(list(transforms.items())) async def alter_columns(self, *alterations: Iterable[dict[str, Any]]): """ diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index e816dc71..2bb18989 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -1384,6 +1384,37 @@ async def test_add_columns_async(mem_db_async: AsyncConnection): assert data["new_col"].to_pylist() == [2, 3] +@pytest.mark.asyncio +async def test_add_columns_with_schema(mem_db_async: AsyncConnection): + data = pa.table({"id": [0, 1]}) + table = await mem_db_async.create_table("my_table", data=data) + await table.add_columns( + [pa.field("x", pa.int64()), pa.field("vector", pa.list_(pa.float32(), 8))] + ) + + assert await table.schema() == pa.schema( + [ + pa.field("id", pa.int64()), + pa.field("x", pa.int64()), + pa.field("vector", pa.list_(pa.float32(), 8)), + ] + ) + + table = await mem_db_async.create_table("table2", data=data) + await table.add_columns( + pa.schema( + [pa.field("y", pa.int64()), pa.field("emb", pa.list_(pa.float32(), 8))] + ) + ) + assert await table.schema() == pa.schema( + [ + pa.field("id", pa.int64()), + pa.field("y", pa.int64()), + pa.field("emb", pa.list_(pa.float32(), 8)), + ] + ) + + def test_alter_columns(mem_db: DBConnection): data = pa.table({"id": [0, 1]}) table = mem_db.create_table("my_table", data=data) diff --git a/python/src/table.rs b/python/src/table.rs index 9e3436f9..60d88a1f 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -1,9 +1,11 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 
Copyright The LanceDB Authors +use std::{collections::HashMap, sync::Arc}; + use arrow::{ - datatypes::DataType, + datatypes::{DataType, Schema}, ffi_stream::ArrowArrayStreamReader, - pyarrow::{FromPyArrow, PyArrowType, ToPyArrow}, + pyarrow::{FromPyArrow, PyArrowType, ToPyArrow}, }; use lancedb::table::{ AddDataMode, ColumnAlteration, Duration, NewColumnTransform, OptimizeAction, OptimizeOptions, @@ -16,7 +18,6 @@ use pyo3::{ Bound, FromPyObject, PyAny, PyRef, PyResult, Python, }; use pyo3_async_runtimes::tokio::future_into_py; -use std::collections::HashMap; use crate::{ error::PythonErrorExt, @@ -444,6 +445,20 @@ impl Table { }) } + pub fn add_columns_with_schema( + self_: PyRef<'_, Self>, + schema: PyArrowType<Schema>, + ) -> PyResult<Bound<'_, PyAny>> { + let arrow_schema = &schema.0; + let transform = NewColumnTransform::AllNulls(Arc::new(arrow_schema.clone())); + + let inner = self_.inner_ref()?.clone(); + future_into_py(self_.py(), async move { + inner.add_columns(transform, None).await.infer_error()?; + Ok(()) + }) + } + pub fn alter_columns<'a>( self_: PyRef<'a, Self>, alterations: Vec<Bound<'a, PyAny>>,