mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-06 20:02:58 +00:00
feat(python): flatten in AsyncQuery (#1967)
PR fixes #1949 --------- Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
@@ -28,7 +28,7 @@ from .arrow import AsyncRecordBatchReader
|
|||||||
from .rerankers.base import Reranker
|
from .rerankers.base import Reranker
|
||||||
from .rerankers.rrf import RRFReranker
|
from .rerankers.rrf import RRFReranker
|
||||||
from .rerankers.util import check_reranker_result
|
from .rerankers.util import check_reranker_result
|
||||||
from .util import safe_import_pandas
|
from .util import safe_import_pandas, flatten_columns
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import PIL
|
import PIL
|
||||||
@@ -293,24 +293,7 @@ class LanceQueryBuilder(ABC):
|
|||||||
specified depth.
|
specified depth.
|
||||||
If unspecified, do not flatten the nested columns.
|
If unspecified, do not flatten the nested columns.
|
||||||
"""
|
"""
|
||||||
tbl = self.to_arrow()
|
tbl = flatten_columns(self.to_arrow(), flatten)
|
||||||
if flatten is True:
|
|
||||||
while True:
|
|
||||||
tbl = tbl.flatten()
|
|
||||||
# loop through all columns to check if there is any struct column
|
|
||||||
if any(pa.types.is_struct(col.type) for col in tbl.schema):
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
elif isinstance(flatten, int):
|
|
||||||
if flatten <= 0:
|
|
||||||
raise ValueError(
|
|
||||||
"Please specify a positive integer for flatten or the boolean "
|
|
||||||
"value `True`"
|
|
||||||
)
|
|
||||||
while flatten > 0:
|
|
||||||
tbl = tbl.flatten()
|
|
||||||
flatten -= 1
|
|
||||||
return tbl.to_pandas()
|
return tbl.to_pandas()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -1595,7 +1578,9 @@ class AsyncQueryBase(object):
|
|||||||
"""
|
"""
|
||||||
return (await self.to_arrow()).to_pylist()
|
return (await self.to_arrow()).to_pylist()
|
||||||
|
|
||||||
async def to_pandas(self) -> "pd.DataFrame":
|
async def to_pandas(
|
||||||
|
self, flatten: Optional[Union[int, bool]] = None
|
||||||
|
) -> "pd.DataFrame":
|
||||||
"""
|
"""
|
||||||
Execute the query and collect the results into a pandas DataFrame.
|
Execute the query and collect the results into a pandas DataFrame.
|
||||||
|
|
||||||
@@ -1615,8 +1600,16 @@ class AsyncQueryBase(object):
|
|||||||
... async for batch in await table.query().to_batches():
|
... async for batch in await table.query().to_batches():
|
||||||
... batch_df = batch.to_pandas()
|
... batch_df = batch.to_pandas()
|
||||||
>>> asyncio.run(doctest_example())
|
>>> asyncio.run(doctest_example())
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
flatten: Optional[Union[int, bool]]
|
||||||
|
If flatten is True, flatten all nested columns.
|
||||||
|
If flatten is an integer, flatten the nested columns up to the
|
||||||
|
specified depth.
|
||||||
|
If unspecified, do not flatten the nested columns.
|
||||||
"""
|
"""
|
||||||
return (await self.to_arrow()).to_pandas()
|
return (flatten_columns(await self.to_arrow(), flatten)).to_pandas()
|
||||||
|
|
||||||
async def to_polars(self) -> "pl.DataFrame":
|
async def to_polars(self) -> "pl.DataFrame":
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -174,6 +174,38 @@ def safe_import_polars():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_columns(
    tbl: "pa.Table", flatten: Optional[Union[int, bool]] = None
) -> "pa.Table":
    """
    Flatten the struct columns of a table.

    Parameters
    ----------
    tbl: pa.Table
        The table whose struct columns should be flattened.
    flatten: Optional[Union[int, bool]]
        If flatten is True, flatten all nested columns.
        If flatten is an integer, flatten the nested columns up to the
        specified depth.
        If unspecified (None) or False, do not flatten the nested columns.

    Returns
    -------
    pa.Table
        The flattened table, or ``tbl`` itself when no flattening was
        requested.

    Raises
    ------
    ValueError
        If flatten is an integer that is not positive.
    """
    # ``bool`` is a subclass of ``int``: handle None/False up front so that
    # ``flatten=False`` is not mistaken for the non-positive integer 0.
    if flatten is None or flatten is False:
        return tbl

    if flatten is True:
        # Always flatten once, then keep going until no struct column remains.
        tbl = tbl.flatten()
        while any(pa.types.is_struct(field.type) for field in tbl.schema):
            tbl = tbl.flatten()
        return tbl

    if isinstance(flatten, int):
        if flatten <= 0:
            raise ValueError(
                "Please specify a positive integer for flatten or the boolean "
                "value `True`"
            )
        # Each call to ``Table.flatten`` unnests exactly one level.
        for _ in range(flatten):
            tbl = tbl.flatten()

    return tbl
|
||||||
|
|
||||||
|
|
||||||
def inf_vector_column_query(schema: pa.Schema) -> str:
|
def inf_vector_column_query(schema: pa.Schema) -> str:
|
||||||
"""
|
"""
|
||||||
Get the vector column name
|
Get the vector column name
|
||||||
|
|||||||
@@ -52,6 +52,17 @@ async def table_async(tmp_path) -> AsyncTable:
|
|||||||
return await conn.create_table("test", data)
|
return await conn.create_table("test", data)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def table_struct_async(tmp_path) -> AsyncTable:
    """
    Async fixture: a two-row table named "test_struct" with a struct
    column "a" and an integer column "month".

    The two struct rows are built from dicts with different key sets
    ("n_legs"/"animals" and "year"/"n_legs").
    """
    db = await lancedb.connect_async(
        tmp_path, read_consistency_interval=timedelta(seconds=0)
    )
    struct_rows = [
        {"n_legs": 2, "animals": "Parrot"},
        {"year": 2022, "n_legs": 4},
    ]
    columns = [pa.array(struct_rows), pa.array([4, 6])]
    data = pa.Table.from_arrays(columns, names=["a", "month"])
    return await db.create_table("test_struct", data)
|
||||||
|
|
||||||
|
|
||||||
def test_cast(table):
|
def test_cast(table):
|
||||||
class TestModel(LanceModel):
|
class TestModel(LanceModel):
|
||||||
vector: Vector(2)
|
vector: Vector(2)
|
||||||
@@ -400,6 +411,15 @@ async def test_query_to_pandas_async(table_async: AsyncTable):
|
|||||||
assert df.shape == (0, 4)
|
assert df.shape == (0, 4)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_query_to_pandas_flatten_async(table_struct_async: AsyncTable):
    """Check the DataFrame shape of an async query with and without flatten."""
    # Default: the struct column stays nested, so only two columns come back.
    nested_df = await table_struct_async.query().to_pandas()
    assert nested_df.shape == (2, 2)

    # flatten=True: the struct column is expanded, giving four columns.
    flat_df = await table_struct_async.query().to_pandas(flatten=True)
    assert flat_df.shape == (2, 4)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_to_polars_async(table_async: AsyncTable):
|
async def test_query_to_polars_async(table_async: AsyncTable):
|
||||||
df = await table_async.query().to_polars()
|
df = await table_async.query().to_polars()
|
||||||
|
|||||||
Reference in New Issue
Block a user