mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-06 20:02:58 +00:00
feat(python): allow the entire table to be converted to a polars dataframe (#814)
This commit is contained in:
@@ -67,7 +67,7 @@ We'll cover the basics of using LanceDB on your local machine in this section.
|
|||||||
!!! warning
|
!!! warning
|
||||||
|
|
||||||
If the table already exists, LanceDB will raise an error by default.
|
If the table already exists, LanceDB will raise an error by default.
|
||||||
If you want to overwrite the table, you can pass in `mode="overwrite"`
|
If you want to make sure you overwrite the table, pass in `mode="overwrite"`
|
||||||
to the `createTable` function.
|
to the `createTable` function.
|
||||||
|
|
||||||
=== "Javascript"
|
=== "Javascript"
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ def _sanitize_data(
|
|||||||
meta = data.schema.metadata if data.schema.metadata is not None else {}
|
meta = data.schema.metadata if data.schema.metadata is not None else {}
|
||||||
meta = {k: v for k, v in meta.items() if k != b"pandas"}
|
meta = {k: v for k, v in meta.items() if k != b"pandas"}
|
||||||
data = data.replace_schema_metadata(meta)
|
data = data.replace_schema_metadata(meta)
|
||||||
elif pl is not None and isinstance(data, pl.DataFrame):
|
elif pl is not None and isinstance(data, (pl.DataFrame, pl.LazyFrame)):
|
||||||
data = data.to_arrow()
|
data = data.to_arrow()
|
||||||
|
|
||||||
if isinstance(data, pa.Table):
|
if isinstance(data, pa.Table):
|
||||||
@@ -697,6 +697,30 @@ class LanceTable(Table):
|
|||||||
pa.Table"""
|
pa.Table"""
|
||||||
return self._dataset.to_table()
|
return self._dataset.to_table()
|
||||||
|
|
||||||
|
def to_polars(self, batch_size=None) -> "pl.LazyFrame":
|
||||||
|
"""Return the table as a polars LazyFrame.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
batch_size: int, optional
|
||||||
|
Passed to polars. This is the maximum row count for
|
||||||
|
scanned pyarrow record batches
|
||||||
|
|
||||||
|
Note
|
||||||
|
----
|
||||||
|
1. This requires polars to be installed separately
|
||||||
|
2. Currently we've disabled push-down of the filters from polars
|
||||||
|
because polars pushdown into pyarrow uses pyarrow compute
|
||||||
|
expressions rather than SQL strings (which LanceDB supports)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
pl.LazyFrame
|
||||||
|
"""
|
||||||
|
return pl.scan_pyarrow_dataset(
|
||||||
|
self.to_lance(), allow_pyarrow_filter=False, batch_size=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _dataset_uri(self) -> str:
|
def _dataset_uri(self) -> str:
|
||||||
return join_uri(self._conn.uri, f"{self.name}.lance")
|
return join_uri(self._conn.uri, f"{self.name}.lance")
|
||||||
|
|||||||
@@ -189,6 +189,7 @@ def test_polars(db):
|
|||||||
"item": ["foo", "bar"],
|
"item": ["foo", "bar"],
|
||||||
"price": [10.0, 20.0],
|
"price": [10.0, 20.0],
|
||||||
}
|
}
|
||||||
|
# Ingest polars dataframe
|
||||||
table = LanceTable.create(db, "test", data=pl.DataFrame(data))
|
table = LanceTable.create(db, "test", data=pl.DataFrame(data))
|
||||||
assert len(table) == 2
|
assert len(table) == 2
|
||||||
|
|
||||||
@@ -206,12 +207,21 @@ def test_polars(db):
|
|||||||
)
|
)
|
||||||
assert table.schema == schema
|
assert table.schema == schema
|
||||||
|
|
||||||
|
# search results to polars dataframe
|
||||||
q = [3.1, 4.1]
|
q = [3.1, 4.1]
|
||||||
result = table.search(q).limit(1).to_polars()
|
result = table.search(q).limit(1).to_polars()
|
||||||
assert np.allclose(result["vector"][0], q)
|
assert np.allclose(result["vector"][0], q)
|
||||||
assert result["item"][0] == "foo"
|
assert result["item"][0] == "foo"
|
||||||
assert np.allclose(result["price"][0], 10.0)
|
assert np.allclose(result["price"][0], 10.0)
|
||||||
|
|
||||||
|
# entire table to polars dataframe
|
||||||
|
result = table.to_polars()
|
||||||
|
assert np.allclose(result.collect()["vector"].to_list(), data["vector"])
|
||||||
|
|
||||||
|
# make sure filtering isn't broken
|
||||||
|
filtered_result = result.filter(pl.col("item").is_in(["foo", "bar"])).collect()
|
||||||
|
assert len(filtered_result) == 2
|
||||||
|
|
||||||
|
|
||||||
def _add(table, schema):
|
def _add(table, schema):
|
||||||
# table = LanceTable(db, "test")
|
# table = LanceTable(db, "test")
|
||||||
|
|||||||
Reference in New Issue
Block a user