feat(python): allow the entire table to be converted to a polars dataframe (#814)

Chang She
2024-01-15 15:49:16 -08:00
committed by GitHub
parent be4ab9eef3
commit af8263af94
3 changed files with 36 additions and 2 deletions

@@ -67,7 +67,7 @@ We'll cover the basics of using LanceDB on your local machine in this section.
     !!! warning
         If the table already exists, LanceDB will raise an error by default.
-        If you want to overwrite the table, you can pass in `mode="overwrite"`
+        If you want to make sure you overwrite the table, pass in `mode="overwrite"`
         to the `createTable` function.
 
 === "Javascript"

@@ -73,7 +73,7 @@ def _sanitize_data(
         meta = data.schema.metadata if data.schema.metadata is not None else {}
         meta = {k: v for k, v in meta.items() if k != b"pandas"}
         data = data.replace_schema_metadata(meta)
-    elif pl is not None and isinstance(data, pl.DataFrame):
+    elif pl is not None and isinstance(data, (pl.DataFrame, pl.LazyFrame)):
         data = data.to_arrow()
     if isinstance(data, pa.Table):
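With this change, `_sanitize_data` routes polars frames through `to_arrow()` before ingestion, so a polars DataFrame can be passed straight to table creation. A minimal sketch of the call site (database path and table name are illustrative; assumes polars is installed):

```python
import lancedb
import polars as pl

db = lancedb.connect("/tmp/lancedb")  # illustrative local path
df = pl.DataFrame(
    {
        "vector": [[3.1, 4.1], [5.9, 26.5]],
        "item": ["foo", "bar"],
        "price": [10.0, 20.0],
    }
)
# _sanitize_data now matches the polars frame and converts it to pyarrow
table = db.create_table("pl_demo", data=df)
```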
@@ -697,6 +697,30 @@ class LanceTable(Table):
         pa.Table"""
         return self._dataset.to_table()
 
+    def to_polars(self, batch_size=None) -> "pl.LazyFrame":
+        """Return the table as a polars LazyFrame.
+
+        Parameters
+        ----------
+        batch_size: int, optional
+            Passed to polars. This is the maximum row count for
+            scanned pyarrow record batches.
+
+        Note
+        ----
+        1. This requires polars to be installed separately.
+        2. Currently we've disabled push-down of the filters from polars,
+           because polars pushdown into pyarrow uses pyarrow compute
+           expressions rather than SQL strings (which LanceDB supports).
+
+        Returns
+        -------
+        pl.LazyFrame
+        """
+        return pl.scan_pyarrow_dataset(
+            self.to_lance(), allow_pyarrow_filter=False, batch_size=batch_size
+        )
+
     @property
     def _dataset_uri(self) -> str:
         return join_uri(self._conn.uri, f"{self.name}.lance")
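A quick usage sketch for the new method (assumes `table` is a `LanceTable` like the one above; column names are illustrative). Because the method passes `allow_pyarrow_filter=False`, any filter applied to the LazyFrame is evaluated by polars after scanning rather than pushed down into pyarrow:

```python
import polars as pl

lf = table.to_polars()  # pl.LazyFrame; no data is read until collect()
df = (
    lf.filter(pl.col("price") > 15.0)  # evaluated by polars, not pushed into the scan
    .select(["item", "price"])
    .collect()  # materializes an eager pl.DataFrame
)
```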

@@ -189,6 +189,7 @@ def test_polars(db):
         "item": ["foo", "bar"],
         "price": [10.0, 20.0],
     }
+    # Ingest polars dataframe
     table = LanceTable.create(db, "test", data=pl.DataFrame(data))
     assert len(table) == 2
@@ -206,12 +207,21 @@ def test_polars(db):
     )
     assert table.schema == schema
 
+    # search results to polars dataframe
     q = [3.1, 4.1]
     result = table.search(q).limit(1).to_polars()
     assert np.allclose(result["vector"][0], q)
     assert result["item"][0] == "foo"
     assert np.allclose(result["price"][0], 10.0)
 
+    # entire table to polars dataframe
+    result = table.to_polars()
+    assert np.allclose(result.collect()["vector"].to_list(), data["vector"])
+
+    # make sure filtering isn't broken
+    filtered_result = result.filter(pl.col("item").is_in(["foo", "bar"])).collect()
+    assert len(filtered_result) == 2
+
 
 def _add(table, schema):
     # table = LanceTable(db, "test")