feat(python): basic polars integration (#811)

We should now be able to directly ingest polars dataframes and return
results as polars dataframes

![image](https://github.com/lancedb/lancedb/assets/759245/828b1260-c791-45f1-a047-aa649575e798)
This commit is contained in:
Chang She
2024-01-13 16:38:16 -08:00
committed by Weston Pace
parent 2f72d5138e
commit 17dcb70076
5 changed files with 64 additions and 3 deletions

View File

@@ -260,6 +260,17 @@ class LanceQueryBuilder(ABC):
for row in self.to_arrow().to_pylist()
]
def to_polars(self) -> "pl.DataFrame":
"""
Execute the query and return the results as a Polars DataFrame.
In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
import polars as pl
return pl.from_arrow(self.to_arrow())
def limit(self, limit: Union[int, None]) -> LanceQueryBuilder:
"""Set the maximum number of results to return.

View File

@@ -31,7 +31,13 @@ from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from .pydantic import LanceModel, model_to_dict
from .query import LanceQueryBuilder, Query
from .util import fs_from_uri, safe_import_pandas, value_to_sql, join_uri
from .util import (
fs_from_uri,
safe_import_pandas,
safe_import_polars,
value_to_sql,
join_uri,
)
if TYPE_CHECKING:
from datetime import timedelta
@@ -40,6 +46,7 @@ if TYPE_CHECKING:
pd = safe_import_pandas()
pl = safe_import_polars()
def _sanitize_data(
@@ -65,6 +72,8 @@ def _sanitize_data(
meta = data.schema.metadata if data.schema.metadata is not None else {}
meta = {k: v for k, v in meta.items() if k != b"pandas"}
data = data.replace_schema_metadata(meta)
elif pl is not None and isinstance(data, pl.DataFrame):
data = data.to_arrow()
if isinstance(data, pa.Table):
if metadata:
@@ -1268,7 +1277,8 @@ def _sanitize_vector_column(
"""
# ChunkedArray is annoying to work with, so we combine chunks here
vec_arr = data[vector_column_name].combine_chunks()
if pa.types.is_list(data[vector_column_name].type):
typ = data[vector_column_name].type
if pa.types.is_list(typ) or pa.types.is_large_list(typ):
# if it's a variable size list array,
# we make sure the dimensions are all the same
has_jagged_ndims = len(vec_arr.values) % len(data) != 0

View File

@@ -123,6 +123,15 @@ def safe_import_pandas():
return None
def safe_import_polars():
try:
import polars as pl
return pl
except ImportError:
return None
@singledispatch
def value_to_sql(value):
raise NotImplementedError("SQL conversion is not implemented for this type")

View File

@@ -45,7 +45,7 @@ classifiers = [
repository = "https://github.com/lancedb/lancedb"
[project.optional-dependencies]
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz"]
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars"]
dev = ["ruff", "pre-commit"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"]

View File

@@ -20,6 +20,7 @@ from unittest.mock import PropertyMock, patch
import lance
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pytest
from pydantic import BaseModel
@@ -182,6 +183,36 @@ def test_add_pydantic_model(db):
assert len(really_flattened.columns) == 7
def test_polars(db):
data = {
"vector": [[3.1, 4.1], [5.9, 26.5]],
"item": ["foo", "bar"],
"price": [10.0, 20.0],
}
table = LanceTable.create(db, "test", data=pl.DataFrame(data))
assert len(table) == 2
result = table.to_pandas()
assert np.allclose(result["vector"].tolist(), data["vector"])
assert result["item"].tolist() == data["item"]
assert np.allclose(result["price"].tolist(), data["price"])
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 2)),
pa.field("item", pa.large_string()),
pa.field("price", pa.float64()),
]
)
assert table.schema == schema
q = [3.1, 4.1]
result = table.search(q).limit(1).to_polars()
assert np.allclose(result["vector"][0], q)
assert result["item"][0] == "foo"
assert np.allclose(result["price"][0], 10.0)
def _add(table, schema):
# table = LanceTable(db, "test")
assert len(table) == 2