mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-03 02:12:56 +00:00
feat(python): flatten in AsyncQuery (#1967)
PR fixes #1949 --------- Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
@@ -28,7 +28,7 @@ from .arrow import AsyncRecordBatchReader
|
||||
from .rerankers.base import Reranker
|
||||
from .rerankers.rrf import RRFReranker
|
||||
from .rerankers.util import check_reranker_result
|
||||
from .util import safe_import_pandas
|
||||
from .util import safe_import_pandas, flatten_columns
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import PIL
|
||||
@@ -293,24 +293,7 @@ class LanceQueryBuilder(ABC):
|
||||
specified depth.
|
||||
If unspecified, do not flatten the nested columns.
|
||||
"""
|
||||
tbl = self.to_arrow()
|
||||
if flatten is True:
|
||||
while True:
|
||||
tbl = tbl.flatten()
|
||||
# loop through all columns to check if there is any struct column
|
||||
if any(pa.types.is_struct(col.type) for col in tbl.schema):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
elif isinstance(flatten, int):
|
||||
if flatten <= 0:
|
||||
raise ValueError(
|
||||
"Please specify a positive integer for flatten or the boolean "
|
||||
"value `True`"
|
||||
)
|
||||
while flatten > 0:
|
||||
tbl = tbl.flatten()
|
||||
flatten -= 1
|
||||
tbl = flatten_columns(self.to_arrow(), flatten)
|
||||
return tbl.to_pandas()
|
||||
|
||||
@abstractmethod
|
||||
@@ -1595,7 +1578,9 @@ class AsyncQueryBase(object):
|
||||
"""
|
||||
return (await self.to_arrow()).to_pylist()
|
||||
|
||||
async def to_pandas(self) -> "pd.DataFrame":
|
||||
async def to_pandas(
|
||||
self, flatten: Optional[Union[int, bool]] = None
|
||||
) -> "pd.DataFrame":
|
||||
"""
|
||||
Execute the query and collect the results into a pandas DataFrame.
|
||||
|
||||
@@ -1615,8 +1600,16 @@ class AsyncQueryBase(object):
|
||||
... async for batch in await table.query().to_batches():
|
||||
... batch_df = batch.to_pandas()
|
||||
>>> asyncio.run(doctest_example())
|
||||
|
||||
Parameters
|
||||
----------
|
||||
flatten: Optional[Union[int, bool]]
|
||||
If flatten is True, flatten all nested columns.
|
||||
If flatten is an integer, flatten the nested columns up to the
|
||||
specified depth.
|
||||
If unspecified, do not flatten the nested columns.
|
||||
"""
|
||||
return (await self.to_arrow()).to_pandas()
|
||||
return (flatten_columns(await self.to_arrow(), flatten)).to_pandas()
|
||||
|
||||
async def to_polars(self) -> "pl.DataFrame":
|
||||
"""
|
||||
|
||||
@@ -174,6 +174,38 @@ def safe_import_polars():
|
||||
return None
|
||||
|
||||
|
||||
def flatten_columns(tbl: pa.Table, flatten: Optional[Union[int, bool]] = None):
|
||||
"""
|
||||
Flatten all struct columns in a table.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
flatten: Optional[Union[int, bool]]
|
||||
If flatten is True, flatten all nested columns.
|
||||
If flatten is an integer, flatten the nested columns up to the
|
||||
specified depth.
|
||||
If unspecified, do not flatten the nested columns.
|
||||
"""
|
||||
if flatten is True:
|
||||
while True:
|
||||
tbl = tbl.flatten()
|
||||
# loop through all columns to check if there is any struct column
|
||||
if any(pa.types.is_struct(col.type) for col in tbl.schema):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
elif isinstance(flatten, int):
|
||||
if flatten <= 0:
|
||||
raise ValueError(
|
||||
"Please specify a positive integer for flatten or the boolean "
|
||||
"value `True`"
|
||||
)
|
||||
while flatten > 0:
|
||||
tbl = tbl.flatten()
|
||||
flatten -= 1
|
||||
return tbl
|
||||
|
||||
|
||||
def inf_vector_column_query(schema: pa.Schema) -> str:
|
||||
"""
|
||||
Get the vector column name
|
||||
|
||||
@@ -52,6 +52,17 @@ async def table_async(tmp_path) -> AsyncTable:
|
||||
return await conn.create_table("test", data)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def table_struct_async(tmp_path) -> AsyncTable:
|
||||
conn = await lancedb.connect_async(
|
||||
tmp_path, read_consistency_interval=timedelta(seconds=0)
|
||||
)
|
||||
struct = pa.array([{"n_legs": 2, "animals": "Parrot"}, {"year": 2022, "n_legs": 4}])
|
||||
month = pa.array([4, 6])
|
||||
table = pa.Table.from_arrays([struct, month], names=["a", "month"])
|
||||
return await conn.create_table("test_struct", table)
|
||||
|
||||
|
||||
def test_cast(table):
|
||||
class TestModel(LanceModel):
|
||||
vector: Vector(2)
|
||||
@@ -400,6 +411,15 @@ async def test_query_to_pandas_async(table_async: AsyncTable):
|
||||
assert df.shape == (0, 4)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_to_pandas_flatten_async(table_struct_async: AsyncTable):
|
||||
df = await table_struct_async.query().to_pandas()
|
||||
assert df.shape == (2, 2)
|
||||
|
||||
df = await table_struct_async.query().to_pandas(flatten=True)
|
||||
assert df.shape == (2, 4)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_to_polars_async(table_async: AsyncTable):
|
||||
df = await table_async.query().to_polars()
|
||||
|
||||
Reference in New Issue
Block a user