feat(python): flatten in AsyncQuery (#1967)

PR fixes #1949

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
Gagan Bhullar
2025-01-06 11:52:03 -07:00
committed by GitHub
parent 2c05ffed52
commit b474f98049
3 changed files with 66 additions and 21 deletions

View File

@@ -28,7 +28,7 @@ from .arrow import AsyncRecordBatchReader
from .rerankers.base import Reranker from .rerankers.base import Reranker
from .rerankers.rrf import RRFReranker from .rerankers.rrf import RRFReranker
from .rerankers.util import check_reranker_result from .rerankers.util import check_reranker_result
from .util import safe_import_pandas from .util import safe_import_pandas, flatten_columns
if TYPE_CHECKING: if TYPE_CHECKING:
import PIL import PIL
@@ -293,24 +293,7 @@ class LanceQueryBuilder(ABC):
specified depth. specified depth.
If unspecified, do not flatten the nested columns. If unspecified, do not flatten the nested columns.
""" """
tbl = self.to_arrow() tbl = flatten_columns(self.to_arrow(), flatten)
if flatten is True:
while True:
tbl = tbl.flatten()
# loop through all columns to check if there is any struct column
if any(pa.types.is_struct(col.type) for col in tbl.schema):
continue
else:
break
elif isinstance(flatten, int):
if flatten <= 0:
raise ValueError(
"Please specify a positive integer for flatten or the boolean "
"value `True`"
)
while flatten > 0:
tbl = tbl.flatten()
flatten -= 1
return tbl.to_pandas() return tbl.to_pandas()
@abstractmethod @abstractmethod
@@ -1595,7 +1578,9 @@ class AsyncQueryBase(object):
""" """
return (await self.to_arrow()).to_pylist() return (await self.to_arrow()).to_pylist()
async def to_pandas(self) -> "pd.DataFrame": async def to_pandas(
self, flatten: Optional[Union[int, bool]] = None
) -> "pd.DataFrame":
""" """
Execute the query and collect the results into a pandas DataFrame. Execute the query and collect the results into a pandas DataFrame.
@@ -1615,8 +1600,16 @@ class AsyncQueryBase(object):
... async for batch in await table.query().to_batches(): ... async for batch in await table.query().to_batches():
... batch_df = batch.to_pandas() ... batch_df = batch.to_pandas()
>>> asyncio.run(doctest_example()) >>> asyncio.run(doctest_example())
Parameters
----------
flatten: Optional[Union[int, bool]]
If flatten is True, flatten all nested columns.
If flatten is an integer, flatten the nested columns up to the
specified depth.
If unspecified, do not flatten the nested columns.
""" """
return (await self.to_arrow()).to_pandas() return (flatten_columns(await self.to_arrow(), flatten)).to_pandas()
async def to_polars(self) -> "pl.DataFrame": async def to_polars(self) -> "pl.DataFrame":
""" """

View File

@@ -174,6 +174,38 @@ def safe_import_polars():
return None return None
def flatten_columns(tbl: pa.Table, flatten: Optional[Union[int, bool]] = None):
"""
Flatten all struct columns in a table.
Parameters
----------
flatten: Optional[Union[int, bool]]
If flatten is True, flatten all nested columns.
If flatten is an integer, flatten the nested columns up to the
specified depth.
If unspecified, do not flatten the nested columns.
"""
if flatten is True:
while True:
tbl = tbl.flatten()
# loop through all columns to check if there is any struct column
if any(pa.types.is_struct(col.type) for col in tbl.schema):
continue
else:
break
elif isinstance(flatten, int):
if flatten <= 0:
raise ValueError(
"Please specify a positive integer for flatten or the boolean "
"value `True`"
)
while flatten > 0:
tbl = tbl.flatten()
flatten -= 1
return tbl
def inf_vector_column_query(schema: pa.Schema) -> str: def inf_vector_column_query(schema: pa.Schema) -> str:
""" """
Get the vector column name Get the vector column name

View File

@@ -52,6 +52,17 @@ async def table_async(tmp_path) -> AsyncTable:
return await conn.create_table("test", data) return await conn.create_table("test", data)
@pytest_asyncio.fixture
async def table_struct_async(tmp_path) -> AsyncTable:
conn = await lancedb.connect_async(
tmp_path, read_consistency_interval=timedelta(seconds=0)
)
struct = pa.array([{"n_legs": 2, "animals": "Parrot"}, {"year": 2022, "n_legs": 4}])
month = pa.array([4, 6])
table = pa.Table.from_arrays([struct, month], names=["a", "month"])
return await conn.create_table("test_struct", table)
def test_cast(table): def test_cast(table):
class TestModel(LanceModel): class TestModel(LanceModel):
vector: Vector(2) vector: Vector(2)
@@ -400,6 +411,15 @@ async def test_query_to_pandas_async(table_async: AsyncTable):
assert df.shape == (0, 4) assert df.shape == (0, 4)
@pytest.mark.asyncio
async def test_query_to_pandas_flatten_async(table_struct_async: AsyncTable):
df = await table_struct_async.query().to_pandas()
assert df.shape == (2, 2)
df = await table_struct_async.query().to_pandas(flatten=True)
assert df.shape == (2, 4)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_to_polars_async(table_async: AsyncTable): async def test_query_to_polars_async(table_async: AsyncTable):
df = await table_async.query().to_polars() df = await table_async.query().to_polars()