From ac3d95ec3489db167f2e0cf406958de34a347853 Mon Sep 17 00:00:00 2001 From: Chang She <759245+changhiskhan@users.noreply.github.com> Date: Mon, 15 Jan 2024 15:49:16 -0800 Subject: [PATCH] feat(python): allow the entire table to be converted a polars dataframe (#814) --- docs/src/basic.md | 2 +- python/lancedb/table.py | 26 +++++++++++++++++++++++++- python/tests/test_table.py | 10 ++++++++++ 3 files changed, 36 insertions(+), 2 deletions(-) diff --git a/docs/src/basic.md b/docs/src/basic.md index 7d8850b3..98553ce6 100644 --- a/docs/src/basic.md +++ b/docs/src/basic.md @@ -67,7 +67,7 @@ We'll cover the basics of using LanceDB on your local machine in this section. !!! warning If the table already exists, LanceDB will raise an error by default. - If you want to overwrite the table, you can pass in `mode="overwrite"` + If you want to make sure you overwrite the table, pass in `mode="overwrite"` to the `createTable` function. === "Javascript" diff --git a/python/lancedb/table.py b/python/lancedb/table.py index f914838d..1a6e6ff7 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -72,7 +72,7 @@ def _sanitize_data( meta = data.schema.metadata if data.schema.metadata is not None else {} meta = {k: v for k, v in meta.items() if k != b"pandas"} data = data.replace_schema_metadata(meta) - elif pl is not None and isinstance(data, pl.DataFrame): + elif pl is not None and isinstance(data, (pl.DataFrame, pl.LazyFrame)): data = data.to_arrow() if isinstance(data, pa.Table): @@ -696,6 +696,30 @@ class LanceTable(Table): pa.Table""" return self._dataset.to_table() + def to_polars(self, batch_size=None) -> "pl.LazyFrame": + """Return the table as a polars LazyFrame. + + Parameters + ---------- + batch_size: int, optional + Passed to polars. This is the maximum row count for + scanned pyarrow record batches + + Note + ---- + 1. This requires polars to be installed separately + 2. Currently we've disabled push-down of the filters from polars + because polars pushdown into pyarrow uses pyarrow compute + expressions rather than SQl strings (which LanceDB supports) + + Returns + ------- + pl.LazyFrame + """ + return pl.scan_pyarrow_dataset( + self.to_lance(), allow_pyarrow_filter=False, batch_size=batch_size + ) + @property def _dataset_uri(self) -> str: return join_uri(self._conn.uri, f"{self.name}.lance") diff --git a/python/tests/test_table.py b/python/tests/test_table.py index c38b074f..339f8668 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -189,6 +189,7 @@ def test_polars(db): "item": ["foo", "bar"], "price": [10.0, 20.0], } + # Ingest polars dataframe table = LanceTable.create(db, "test", data=pl.DataFrame(data)) assert len(table) == 2 @@ -206,12 +207,21 @@ def test_polars(db): ) assert table.schema == schema + # search results to polars dataframe q = [3.1, 4.1] result = table.search(q).limit(1).to_polars() assert np.allclose(result["vector"][0], q) assert result["item"][0] == "foo" assert np.allclose(result["price"][0], 10.0) + # enter table to polars dataframe + result = table.to_polars() + assert np.allclose(result.collect()["vector"].to_list(), data["vector"]) + + # make sure filtering isn't broken + filtered_result = result.filter(pl.col("item").is_in(["foo", "bar"])).collect() + assert len(filtered_result) == 2 + def _add(table, schema): # table = LanceTable(db, "test")