mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-05 11:22:58 +00:00
feat(python): allow the entire table to be converted a polars dataframe (#814)
This commit is contained in:
@@ -189,6 +189,7 @@ def test_polars(db):
|
||||
"item": ["foo", "bar"],
|
||||
"price": [10.0, 20.0],
|
||||
}
|
||||
# Ingest polars dataframe
|
||||
table = LanceTable.create(db, "test", data=pl.DataFrame(data))
|
||||
assert len(table) == 2
|
||||
|
||||
@@ -206,12 +207,21 @@ def test_polars(db):
|
||||
)
|
||||
assert table.schema == schema
|
||||
|
||||
# search results to polars dataframe
|
||||
q = [3.1, 4.1]
|
||||
result = table.search(q).limit(1).to_polars()
|
||||
assert np.allclose(result["vector"][0], q)
|
||||
assert result["item"][0] == "foo"
|
||||
assert np.allclose(result["price"][0], 10.0)
|
||||
|
||||
# enter table to polars dataframe
|
||||
result = table.to_polars()
|
||||
assert np.allclose(result.collect()["vector"].to_list(), data["vector"])
|
||||
|
||||
# make sure filtering isn't broken
|
||||
filtered_result = result.filter(pl.col("item").is_in(["foo", "bar"])).collect()
|
||||
assert len(filtered_result) == 2
|
||||
|
||||
|
||||
def _add(table, schema):
|
||||
# table = LanceTable(db, "test")
|
||||
|
||||
Reference in New Issue
Block a user