mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-03 18:32:55 +00:00
feat(python): allow the entire table to be converted to a polars dataframe (#814)
This commit is contained in:
@@ -73,7 +73,7 @@ def _sanitize_data(
|
||||
meta = data.schema.metadata if data.schema.metadata is not None else {}
|
||||
meta = {k: v for k, v in meta.items() if k != b"pandas"}
|
||||
data = data.replace_schema_metadata(meta)
|
||||
elif pl is not None and isinstance(data, pl.DataFrame):
|
||||
elif pl is not None and isinstance(data, (pl.DataFrame, pl.LazyFrame)):
|
||||
data = data.to_arrow()
|
||||
|
||||
if isinstance(data, pa.Table):
|
||||
@@ -697,6 +697,30 @@ class LanceTable(Table):
|
||||
pa.Table"""
|
||||
return self._dataset.to_table()
|
||||
|
||||
def to_polars(self, batch_size=None) -> "pl.LazyFrame":
    """Return the table as a polars LazyFrame.

    Parameters
    ----------
    batch_size: int, optional
        Passed to polars. This is the maximum row count for
        scanned pyarrow record batches

    Note
    ----
    1. This requires polars to be installed separately
    2. Currently we've disabled push-down of the filters from polars
       because polars pushdown into pyarrow uses pyarrow compute
       expressions rather than SQL strings (which LanceDB supports)

    Returns
    -------
    pl.LazyFrame

    Raises
    ------
    ImportError
        If polars is not available (``pl`` is ``None``).
    """
    # Other code in this file guards on ``pl is not None``, so ``pl`` may be
    # None here; fail with an actionable message instead of an opaque
    # AttributeError on NoneType.
    if pl is None:
        raise ImportError(
            "polars is required for LanceTable.to_polars(); "
            "install it with `pip install polars`"
        )
    # Filter push-down is disabled on purpose; see note 2 in the docstring.
    return pl.scan_pyarrow_dataset(
        self.to_lance(), allow_pyarrow_filter=False, batch_size=batch_size
    )
|
||||
|
||||
@property
def _dataset_uri(self) -> str:
    """URI built by joining the connection URI with ``<name>.lance``."""
    base_uri = self._conn.uri
    return join_uri(base_uri, f"{self.name}.lance")
|
||||
|
||||
@@ -189,6 +189,7 @@ def test_polars(db):
|
||||
"item": ["foo", "bar"],
|
||||
"price": [10.0, 20.0],
|
||||
}
|
||||
# Ingest polars dataframe
|
||||
table = LanceTable.create(db, "test", data=pl.DataFrame(data))
|
||||
assert len(table) == 2
|
||||
|
||||
@@ -206,12 +207,21 @@ def test_polars(db):
|
||||
)
|
||||
assert table.schema == schema
|
||||
|
||||
# search results to polars dataframe
|
||||
q = [3.1, 4.1]
|
||||
result = table.search(q).limit(1).to_polars()
|
||||
assert np.allclose(result["vector"][0], q)
|
||||
assert result["item"][0] == "foo"
|
||||
assert np.allclose(result["price"][0], 10.0)
|
||||
|
||||
# entire table to polars dataframe
|
||||
result = table.to_polars()
|
||||
assert np.allclose(result.collect()["vector"].to_list(), data["vector"])
|
||||
|
||||
# make sure filtering isn't broken
|
||||
filtered_result = result.filter(pl.col("item").is_in(["foo", "bar"])).collect()
|
||||
assert len(filtered_result) == 2
|
||||
|
||||
|
||||
def _add(table, schema):
|
||||
# table = LanceTable(db, "test")
|
||||
|
||||
Reference in New Issue
Block a user