diff --git a/python/lancedb/query.py b/python/lancedb/query.py index 886ec966..51eab73a 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -260,6 +260,17 @@ class LanceQueryBuilder(ABC): for row in self.to_arrow().to_pylist() ] + def to_polars(self) -> "pl.DataFrame": + """ + Execute the query and return the results as a Polars DataFrame. + In addition to the selected columns, LanceDB also returns a vector + and also the "_distance" column which is the distance between the query + vector and the returned vector. + """ + import polars as pl + + return pl.from_arrow(self.to_arrow()) + def limit(self, limit: Union[int, None]) -> LanceQueryBuilder: """Set the maximum number of results to return. diff --git a/python/lancedb/table.py b/python/lancedb/table.py index ae959c8d..2e01469e 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -31,7 +31,13 @@ from .common import DATA, VEC, VECTOR_COLUMN_NAME from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry from .pydantic import LanceModel, model_to_dict from .query import LanceQueryBuilder, Query -from .util import fs_from_uri, safe_import_pandas, value_to_sql, join_uri +from .util import ( + fs_from_uri, + safe_import_pandas, + safe_import_polars, + value_to_sql, + join_uri, +) if TYPE_CHECKING: from datetime import timedelta @@ -40,6 +46,7 @@ if TYPE_CHECKING: pd = safe_import_pandas() +pl = safe_import_polars() def _sanitize_data( @@ -65,6 +72,8 @@ def _sanitize_data( meta = data.schema.metadata if data.schema.metadata is not None else {} meta = {k: v for k, v in meta.items() if k != b"pandas"} data = data.replace_schema_metadata(meta) + elif pl is not None and isinstance(data, pl.DataFrame): + data = data.to_arrow() if isinstance(data, pa.Table): if metadata: @@ -1268,7 +1277,8 @@ def _sanitize_vector_column( """ # ChunkedArray is annoying to work with, so we combine chunks here vec_arr = data[vector_column_name].combine_chunks() - if pa.types.is_list(data[vector_column_name].type): + typ = data[vector_column_name].type + if pa.types.is_list(typ) or pa.types.is_large_list(typ): # if it's a variable size list array, # we make sure the dimensions are all the same has_jagged_ndims = len(vec_arr.values) % len(data) != 0 diff --git a/python/lancedb/util.py b/python/lancedb/util.py index f38ed49a..6122e4e0 100644 --- a/python/lancedb/util.py +++ b/python/lancedb/util.py @@ -123,6 +123,15 @@ def safe_import_pandas(): return None +def safe_import_polars(): + try: + import polars as pl + + return pl + except ImportError: + return None + + @singledispatch def value_to_sql(value): raise NotImplementedError("SQL conversion is not implemented for this type") diff --git a/python/pyproject.toml b/python/pyproject.toml index e993d9d7..6c52fbb8 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -45,7 +45,7 @@ classifiers = [ repository = "https://github.com/lancedb/lancedb" [project.optional-dependencies] -tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz"] +tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars"] dev = ["ruff", "pre-commit"] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] clip = ["torch", "pillow", "open-clip"] diff --git a/python/tests/test_table.py b/python/tests/test_table.py index 3f096a53..c38b074f 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -20,6 +20,7 @@ from unittest.mock import PropertyMock, patch import lance import numpy as np import pandas as pd +import polars as pl import pyarrow as pa import pytest from pydantic import BaseModel @@ -182,6 +183,36 @@ def test_add_pydantic_model(db): assert len(really_flattened.columns) == 7 +def test_polars(db): + data = { + "vector": [[3.1, 4.1], [5.9, 26.5]], + "item": ["foo", "bar"], + "price": [10.0, 20.0], + } + table = LanceTable.create(db, "test", data=pl.DataFrame(data)) + assert len(table) == 2 + + result = table.to_pandas() + assert np.allclose(result["vector"].tolist(), data["vector"]) + assert result["item"].tolist() == data["item"] + assert np.allclose(result["price"].tolist(), data["price"]) + + schema = pa.schema( + [ + pa.field("vector", pa.list_(pa.float32(), 2)), + pa.field("item", pa.large_string()), + pa.field("price", pa.float64()), + ] + ) + assert table.schema == schema + + q = [3.1, 4.1] + result = table.search(q).limit(1).to_polars() + assert np.allclose(result["vector"][0], q) + assert result["item"][0] == "foo" + assert np.allclose(result["price"][0], 10.0) + + def _add(table, schema): # table = LanceTable(db, "test") assert len(table) == 2