mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-11 14:22:59 +00:00
feat(python): add option to flatten output in to_pandas (#722)
Closes https://github.com/lancedb/lance/issues/1738 We add a `flatten` parameter to the signature of `to_pandas`. By default this is None and does nothing. If set to True or -1, then LanceDB will flatten structs before converting to a pandas dataframe. All nested structs are also flattened. If set to any positive integer, then LanceDB will flatten structs up to the specified level of nesting. --------- Co-authored-by: Weston Pace <weston.pace@gmail.com>
This commit is contained in:
@@ -185,14 +185,40 @@ class LanceQueryBuilder(ABC):
|
||||
"""
|
||||
return self.to_pandas()
|
||||
|
||||
def to_pandas(self) -> "pd.DataFrame":
|
||||
def to_pandas(self, flatten: Optional[Union[int, bool]] = None) -> "pd.DataFrame":
|
||||
"""
|
||||
Execute the query and return the results as a pandas DataFrame.
|
||||
In addition to the selected columns, LanceDB also returns a vector
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
flatten: Optional[Union[int, bool]]
|
||||
If flatten is True, flatten all nested columns.
|
||||
If flatten is an integer, flatten the nested columns up to the
|
||||
specified depth.
|
||||
If unspecified, do not flatten the nested columns.
|
||||
"""
|
||||
return self.to_arrow().to_pandas()
|
||||
tbl = self.to_arrow()
|
||||
if flatten is True:
|
||||
while True:
|
||||
tbl = tbl.flatten()
|
||||
has_struct = False
|
||||
# loop through all columns to check if there is any struct column
|
||||
if any(pa.types.is_struct(col.type) for col in tbl.schema):
|
||||
continue
|
||||
else:
|
||||
break
|
||||
elif isinstance(flatten, int):
|
||||
if flatten <= 0:
|
||||
raise ValueError(
|
||||
"Please specify a positive integer for flatten or the boolean value `True`"
|
||||
)
|
||||
while flatten > 0:
|
||||
tbl = tbl.flatten()
|
||||
flatten -= 1
|
||||
return tbl.to_pandas()
|
||||
|
||||
@abstractmethod
|
||||
def to_arrow(self) -> pa.Table:
|
||||
|
||||
@@ -144,9 +144,13 @@ def test_add(db):
|
||||
def test_add_pydantic_model(db):
|
||||
# https://github.com/lancedb/lancedb/issues/562
|
||||
|
||||
class Metadata(BaseModel):
|
||||
source: str
|
||||
timestamp: datetime
|
||||
|
||||
class Document(BaseModel):
|
||||
content: str
|
||||
source: str
|
||||
meta: Metadata
|
||||
|
||||
class LanceSchema(LanceModel):
|
||||
id: str
|
||||
@@ -162,13 +166,21 @@ def test_add_pydantic_model(db):
|
||||
id="id",
|
||||
vector=[0.0, 0.0],
|
||||
li=[1, 2, 3],
|
||||
payload=Document(content="foo", source="bar"),
|
||||
payload=Document(
|
||||
content="foo", meta=Metadata(source="bar", timestamp=datetime.now())
|
||||
),
|
||||
)
|
||||
tbl.add([expected])
|
||||
|
||||
result = tbl.search([0.0, 0.0]).limit(1).to_pydantic(LanceSchema)[0]
|
||||
assert result == expected
|
||||
|
||||
flattened = tbl.search([0.0, 0.0]).limit(1).to_pandas(flatten=1)
|
||||
assert len(flattened.columns) == 6 # _distance is automatically added
|
||||
|
||||
really_flattened = tbl.search([0.0, 0.0]).limit(1).to_pandas(flatten=True)
|
||||
assert len(really_flattened.columns) == 7
|
||||
|
||||
|
||||
def _add(table, schema):
|
||||
# table = LanceTable(db, "test")
|
||||
|
||||
Reference in New Issue
Block a user