feat(python): allow the entire table to be converted to a polars dataframe (#814)

This commit is contained in:
Chang She
2024-01-15 15:49:16 -08:00
committed by GitHub
parent be4ab9eef3
commit af8263af94
3 changed files with 36 additions and 2 deletions

View File

@@ -73,7 +73,7 @@ def _sanitize_data(
meta = data.schema.metadata if data.schema.metadata is not None else {}
meta = {k: v for k, v in meta.items() if k != b"pandas"}
data = data.replace_schema_metadata(meta)
elif pl is not None and isinstance(data, pl.DataFrame):
elif pl is not None and isinstance(data, (pl.DataFrame, pl.LazyFrame)):
data = data.to_arrow()
if isinstance(data, pa.Table):
@@ -697,6 +697,30 @@ class LanceTable(Table):
pa.Table"""
return self._dataset.to_table()
def to_polars(self, batch_size=None) -> "pl.LazyFrame":
    """Return the table as a polars LazyFrame.

    Parameters
    ----------
    batch_size: int, optional
        Passed to polars. This is the maximum row count for
        scanned pyarrow record batches

    Note
    ----
    1. This requires polars to be installed separately
    2. Currently we've disabled push-down of the filters from polars
       because polars pushdown into pyarrow uses pyarrow compute
       expressions rather than SQL strings (which LanceDB supports)

    Returns
    -------
    pl.LazyFrame

    Raises
    ------
    ImportError
        If polars is not available (``pl`` is None).
    """
    # This module treats `pl` as possibly None (polars is optional), so
    # fail with an actionable message instead of an AttributeError on None.
    if pl is None:
        raise ImportError(
            "polars is required for to_polars(); "
            "install it with `pip install polars`"
        )
    return pl.scan_pyarrow_dataset(
        self.to_lance(), allow_pyarrow_filter=False, batch_size=batch_size
    )
@property
def _dataset_uri(self) -> str:
    """Connection URI joined with this table's ``<name>.lance`` component."""
    base_uri = self._conn.uri
    return join_uri(base_uri, f"{self.name}.lance")

View File

@@ -189,6 +189,7 @@ def test_polars(db):
"item": ["foo", "bar"],
"price": [10.0, 20.0],
}
# Ingest polars dataframe
table = LanceTable.create(db, "test", data=pl.DataFrame(data))
assert len(table) == 2
@@ -206,12 +207,21 @@ def test_polars(db):
)
assert table.schema == schema
# search results to polars dataframe
q = [3.1, 4.1]
result = table.search(q).limit(1).to_polars()
assert np.allclose(result["vector"][0], q)
assert result["item"][0] == "foo"
assert np.allclose(result["price"][0], 10.0)
# entire table to polars dataframe
result = table.to_polars()
assert np.allclose(result.collect()["vector"].to_list(), data["vector"])
# make sure filtering isn't broken
filtered_result = result.filter(pl.col("item").is_in(["foo", "bar"])).collect()
assert len(filtered_result) == 2
def _add(table, schema):
# table = LanceTable(db, "test")