mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-04 02:42:57 +00:00
feat(python): basic polars integration (#811)
We should now be able to directly ingest polars dataframes and return results as polars dataframes 
This commit is contained in:
@@ -260,6 +260,17 @@ class LanceQueryBuilder(ABC):
|
||||
for row in self.to_arrow().to_pylist()
|
||||
]
|
||||
|
||||
def to_polars(self) -> "pl.DataFrame":
|
||||
"""
|
||||
Execute the query and return the results as a Polars DataFrame.
|
||||
In addition to the selected columns, LanceDB also returns a vector
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
import polars as pl
|
||||
|
||||
return pl.from_arrow(self.to_arrow())
|
||||
|
||||
def limit(self, limit: Union[int, None]) -> LanceQueryBuilder:
|
||||
"""Set the maximum number of results to return.
|
||||
|
||||
|
||||
@@ -31,7 +31,13 @@ from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
from .pydantic import LanceModel, model_to_dict
|
||||
from .query import LanceQueryBuilder, Query
|
||||
from .util import fs_from_uri, safe_import_pandas, value_to_sql, join_uri
|
||||
from .util import (
|
||||
fs_from_uri,
|
||||
safe_import_pandas,
|
||||
safe_import_polars,
|
||||
value_to_sql,
|
||||
join_uri,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
@@ -40,6 +46,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
pd = safe_import_pandas()
|
||||
pl = safe_import_polars()
|
||||
|
||||
|
||||
def _sanitize_data(
|
||||
@@ -65,6 +72,8 @@ def _sanitize_data(
|
||||
meta = data.schema.metadata if data.schema.metadata is not None else {}
|
||||
meta = {k: v for k, v in meta.items() if k != b"pandas"}
|
||||
data = data.replace_schema_metadata(meta)
|
||||
elif pl is not None and isinstance(data, pl.DataFrame):
|
||||
data = data.to_arrow()
|
||||
|
||||
if isinstance(data, pa.Table):
|
||||
if metadata:
|
||||
@@ -1268,7 +1277,8 @@ def _sanitize_vector_column(
|
||||
"""
|
||||
# ChunkedArray is annoying to work with, so we combine chunks here
|
||||
vec_arr = data[vector_column_name].combine_chunks()
|
||||
if pa.types.is_list(data[vector_column_name].type):
|
||||
typ = data[vector_column_name].type
|
||||
if pa.types.is_list(typ) or pa.types.is_large_list(typ):
|
||||
# if it's a variable size list array,
|
||||
# we make sure the dimensions are all the same
|
||||
has_jagged_ndims = len(vec_arr.values) % len(data) != 0
|
||||
|
||||
@@ -123,6 +123,15 @@ def safe_import_pandas():
|
||||
return None
|
||||
|
||||
|
||||
def safe_import_polars():
|
||||
try:
|
||||
import polars as pl
|
||||
|
||||
return pl
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
@singledispatch
|
||||
def value_to_sql(value):
|
||||
raise NotImplementedError("SQL conversion is not implemented for this type")
|
||||
|
||||
@@ -45,7 +45,7 @@ classifiers = [
|
||||
repository = "https://github.com/lancedb/lancedb"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz"]
|
||||
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars"]
|
||||
dev = ["ruff", "pre-commit"]
|
||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||
clip = ["torch", "pillow", "open-clip"]
|
||||
|
||||
@@ -20,6 +20,7 @@ from unittest.mock import PropertyMock, patch
|
||||
import lance
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import polars as pl
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
@@ -182,6 +183,36 @@ def test_add_pydantic_model(db):
|
||||
assert len(really_flattened.columns) == 7
|
||||
|
||||
|
||||
def test_polars(db):
|
||||
data = {
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
"item": ["foo", "bar"],
|
||||
"price": [10.0, 20.0],
|
||||
}
|
||||
table = LanceTable.create(db, "test", data=pl.DataFrame(data))
|
||||
assert len(table) == 2
|
||||
|
||||
result = table.to_pandas()
|
||||
assert np.allclose(result["vector"].tolist(), data["vector"])
|
||||
assert result["item"].tolist() == data["item"]
|
||||
assert np.allclose(result["price"].tolist(), data["price"])
|
||||
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||
pa.field("item", pa.large_string()),
|
||||
pa.field("price", pa.float64()),
|
||||
]
|
||||
)
|
||||
assert table.schema == schema
|
||||
|
||||
q = [3.1, 4.1]
|
||||
result = table.search(q).limit(1).to_polars()
|
||||
assert np.allclose(result["vector"][0], q)
|
||||
assert result["item"][0] == "foo"
|
||||
assert np.allclose(result["price"][0], 10.0)
|
||||
|
||||
|
||||
def _add(table, schema):
|
||||
# table = LanceTable(db, "test")
|
||||
assert len(table) == 2
|
||||
|
||||
Reference in New Issue
Block a user