diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 3aadde5a..02f57fdc 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -28,7 +28,7 @@ from .arrow import AsyncRecordBatchReader from .rerankers.base import Reranker from .rerankers.rrf import RRFReranker from .rerankers.util import check_reranker_result -from .util import safe_import_pandas +from .util import safe_import_pandas, flatten_columns if TYPE_CHECKING: import PIL @@ -293,24 +293,7 @@ class LanceQueryBuilder(ABC): specified depth. If unspecified, do not flatten the nested columns. """ - tbl = self.to_arrow() - if flatten is True: - while True: - tbl = tbl.flatten() - # loop through all columns to check if there is any struct column - if any(pa.types.is_struct(col.type) for col in tbl.schema): - continue - else: - break - elif isinstance(flatten, int): - if flatten <= 0: - raise ValueError( - "Please specify a positive integer for flatten or the boolean " - "value `True`" - ) - while flatten > 0: - tbl = tbl.flatten() - flatten -= 1 + tbl = flatten_columns(self.to_arrow(), flatten) return tbl.to_pandas() @abstractmethod @@ -1595,7 +1578,9 @@ class AsyncQueryBase(object): """ return (await self.to_arrow()).to_pylist() - async def to_pandas(self) -> "pd.DataFrame": + async def to_pandas( + self, flatten: Optional[Union[int, bool]] = None + ) -> "pd.DataFrame": """ Execute the query and collect the results into a pandas DataFrame. @@ -1615,8 +1600,16 @@ class AsyncQueryBase(object): ... async for batch in await table.query().to_batches(): ... batch_df = batch.to_pandas() >>> asyncio.run(doctest_example()) + + Parameters + ---------- + flatten: Optional[Union[int, bool]] + If flatten is True, flatten all nested columns. + If flatten is an integer, flatten the nested columns up to the + specified depth. + If unspecified, do not flatten the nested columns. """ - return (await self.to_arrow()).to_pandas() + return (flatten_columns(await self.to_arrow(), flatten)).to_pandas() async def to_polars(self) -> "pl.DataFrame": """ diff --git a/python/python/lancedb/util.py b/python/python/lancedb/util.py index eda7ce06..dd248de0 100644 --- a/python/python/lancedb/util.py +++ b/python/python/lancedb/util.py @@ -174,6 +174,38 @@ def safe_import_polars(): return None +def flatten_columns(tbl: pa.Table, flatten: Optional[Union[int, bool]] = None): + """ + Flatten all struct columns in a table. + + Parameters + ---------- + flatten: Optional[Union[int, bool]] + If flatten is True, flatten all nested columns. + If flatten is an integer, flatten the nested columns up to the + specified depth. + If unspecified, do not flatten the nested columns. + """ + if flatten is True: + while True: + tbl = tbl.flatten() + # loop through all columns to check if there is any struct column + if any(pa.types.is_struct(col.type) for col in tbl.schema): + continue + else: + break + elif isinstance(flatten, int): + if flatten <= 0: + raise ValueError( + "Please specify a positive integer for flatten or the boolean " + "value `True`" + ) + while flatten > 0: + tbl = tbl.flatten() + flatten -= 1 + return tbl + + def inf_vector_column_query(schema: pa.Schema) -> str: """ Get the vector column name diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index 908c6ce9..cc796545 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -52,6 +52,17 @@ async def table_async(tmp_path) -> AsyncTable: return await conn.create_table("test", data) +@pytest_asyncio.fixture +async def table_struct_async(tmp_path) -> AsyncTable: + conn = await lancedb.connect_async( + tmp_path, read_consistency_interval=timedelta(seconds=0) + ) + struct = pa.array([{"n_legs": 2, "animals": "Parrot"}, {"year": 2022, "n_legs": 4}]) + month = pa.array([4, 6]) + table = pa.Table.from_arrays([struct, month], names=["a", "month"]) + return await conn.create_table("test_struct", table) + + def test_cast(table): class TestModel(LanceModel): vector: Vector(2) @@ -400,6 +411,15 @@ async def test_query_to_pandas_async(table_async: AsyncTable): assert df.shape == (0, 4) +@pytest.mark.asyncio +async def test_query_to_pandas_flatten_async(table_struct_async: AsyncTable): + df = await table_struct_async.query().to_pandas() + assert df.shape == (2, 2) + + df = await table_struct_async.query().to_pandas(flatten=True) + assert df.shape == (2, 4) + + @pytest.mark.asyncio async def test_query_to_polars_async(table_async: AsyncTable): df = await table_async.query().to_polars()