feat(python): flatten in AsyncQuery (#1967)

PR fixes #1949

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
Gagan Bhullar
2025-01-06 11:52:03 -07:00
committed by GitHub
parent 2c05ffed52
commit b474f98049
3 changed files with 66 additions and 21 deletions

View File

@@ -28,7 +28,7 @@ from .arrow import AsyncRecordBatchReader
from .rerankers.base import Reranker
from .rerankers.rrf import RRFReranker
from .rerankers.util import check_reranker_result
from .util import safe_import_pandas
from .util import safe_import_pandas, flatten_columns
if TYPE_CHECKING:
import PIL
@@ -293,24 +293,7 @@ class LanceQueryBuilder(ABC):
specified depth.
If unspecified, do not flatten the nested columns.
"""
tbl = self.to_arrow()
if flatten is True:
while True:
tbl = tbl.flatten()
# loop through all columns to check if there is any struct column
if any(pa.types.is_struct(col.type) for col in tbl.schema):
continue
else:
break
elif isinstance(flatten, int):
if flatten <= 0:
raise ValueError(
"Please specify a positive integer for flatten or the boolean "
"value `True`"
)
while flatten > 0:
tbl = tbl.flatten()
flatten -= 1
tbl = flatten_columns(self.to_arrow(), flatten)
return tbl.to_pandas()
@abstractmethod
@@ -1595,7 +1578,9 @@ class AsyncQueryBase(object):
"""
return (await self.to_arrow()).to_pylist()
async def to_pandas(self) -> "pd.DataFrame":
async def to_pandas(
self, flatten: Optional[Union[int, bool]] = None
) -> "pd.DataFrame":
"""
Execute the query and collect the results into a pandas DataFrame.
@@ -1615,8 +1600,16 @@ class AsyncQueryBase(object):
... async for batch in await table.query().to_batches():
... batch_df = batch.to_pandas()
>>> asyncio.run(doctest_example())
Parameters
----------
flatten: Optional[Union[int, bool]]
If flatten is True, flatten all nested columns.
If flatten is an integer, flatten the nested columns up to the
specified depth.
If unspecified, do not flatten the nested columns.
"""
return (await self.to_arrow()).to_pandas()
return (flatten_columns(await self.to_arrow(), flatten)).to_pandas()
async def to_polars(self) -> "pl.DataFrame":
"""

View File

@@ -174,6 +174,38 @@ def safe_import_polars():
return None
def flatten_columns(
    tbl: pa.Table, flatten: Optional[Union[int, bool]] = None
) -> pa.Table:
    """
    Flatten struct columns in a table.

    Parameters
    ----------
    tbl: pa.Table
        The table whose struct columns should be flattened.
    flatten: Optional[Union[int, bool]]
        If flatten is True, flatten all nested columns.
        If flatten is an integer, flatten the nested columns up to the
        specified depth.
        If unspecified (None) or False, do not flatten the nested columns.

    Returns
    -------
    pa.Table
        The flattened table, or ``tbl`` unchanged when no flattening is
        requested.

    Raises
    ------
    ValueError
        If flatten is an integer that is not positive.
    """
    # `bool` is a subclass of `int`, so `False` has to be handled before the
    # integer branch; otherwise it would be rejected as a non-positive depth
    # even though the signature explicitly accepts a boolean.
    if flatten is None or flatten is False:
        return tbl

    if flatten is True:
        # Always flatten at least once, then keep going for as long as the
        # resulting schema still contains a struct column.
        while True:
            tbl = tbl.flatten()
            if not any(pa.types.is_struct(col.type) for col in tbl.schema):
                break
    elif isinstance(flatten, int):
        if flatten <= 0:
            raise ValueError(
                "Please specify a positive integer for flatten or the boolean "
                "value `True`"
            )
        # One flatten() call per requested level of depth.
        for _ in range(flatten):
            tbl = tbl.flatten()
    return tbl
def inf_vector_column_query(schema: pa.Schema) -> str:
"""
Get the vector column name

View File

@@ -52,6 +52,17 @@ async def table_async(tmp_path) -> AsyncTable:
return await conn.create_table("test", data)
@pytest_asyncio.fixture
async def table_struct_async(tmp_path) -> AsyncTable:
    """
    Create the async table "test_struct" with a struct column "a" and an
    integer column "month", two rows each.
    """
    db = await lancedb.connect_async(
        tmp_path, read_consistency_interval=timedelta(seconds=0)
    )
    # The two rows use different key sets (n_legs/animals vs. year/n_legs).
    struct_rows = [
        {"n_legs": 2, "animals": "Parrot"},
        {"year": 2022, "n_legs": 4},
    ]
    columns = [pa.array(struct_rows), pa.array([4, 6])]
    data = pa.Table.from_arrays(columns, names=["a", "month"])
    return await db.create_table("test_struct", data)
def test_cast(table):
class TestModel(LanceModel):
vector: Vector(2)
@@ -400,6 +411,15 @@ async def test_query_to_pandas_async(table_async: AsyncTable):
assert df.shape == (0, 4)
@pytest.mark.asyncio
async def test_query_to_pandas_flatten_async(table_struct_async: AsyncTable):
    """Check the resulting DataFrame shape with and without ``flatten=True``."""
    # Default call: no flatten argument is passed.
    nested_df = await table_struct_async.query().to_pandas()
    assert nested_df.shape == (2, 2)

    # flatten=True: the same query is expected to yield four columns.
    flat_df = await table_struct_async.query().to_pandas(flatten=True)
    assert flat_df.shape == (2, 4)
@pytest.mark.asyncio
async def test_query_to_polars_async(table_async: AsyncTable):
df = await table_async.query().to_polars()