feat(python): basic polars integration (#811)

We should now be able to directly ingest polars dataframes and return
results as polars dataframes

![image](https://github.com/lancedb/lancedb/assets/759245/828b1260-c791-45f1-a047-aa649575e798)
This commit is contained in:
Chang She
2024-01-13 16:38:16 -08:00
committed by Weston Pace
parent 2f72d5138e
commit 17dcb70076
5 changed files with 64 additions and 3 deletions

View File

@@ -260,6 +260,17 @@ class LanceQueryBuilder(ABC):
for row in self.to_arrow().to_pylist()
]
def to_polars(self) -> "pl.DataFrame":
"""
Execute the query and return the results as a Polars DataFrame.
In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
import polars as pl
return pl.from_arrow(self.to_arrow())
def limit(self, limit: Union[int, None]) -> LanceQueryBuilder:
"""Set the maximum number of results to return.

View File

@@ -31,7 +31,13 @@ from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from .pydantic import LanceModel, model_to_dict
from .query import LanceQueryBuilder, Query
from .util import fs_from_uri, safe_import_pandas, value_to_sql, join_uri
from .util import (
fs_from_uri,
safe_import_pandas,
safe_import_polars,
value_to_sql,
join_uri,
)
if TYPE_CHECKING:
from datetime import timedelta
@@ -40,6 +46,7 @@ if TYPE_CHECKING:
pd = safe_import_pandas()
pl = safe_import_polars()
def _sanitize_data(
@@ -65,6 +72,8 @@ def _sanitize_data(
meta = data.schema.metadata if data.schema.metadata is not None else {}
meta = {k: v for k, v in meta.items() if k != b"pandas"}
data = data.replace_schema_metadata(meta)
elif pl is not None and isinstance(data, pl.DataFrame):
data = data.to_arrow()
if isinstance(data, pa.Table):
if metadata:
@@ -1268,7 +1277,8 @@ def _sanitize_vector_column(
"""
# ChunkedArray is annoying to work with, so we combine chunks here
vec_arr = data[vector_column_name].combine_chunks()
if pa.types.is_list(data[vector_column_name].type):
typ = data[vector_column_name].type
if pa.types.is_list(typ) or pa.types.is_large_list(typ):
# if it's a variable size list array,
# we make sure the dimensions are all the same
has_jagged_ndims = len(vec_arr.values) % len(data) != 0

View File

@@ -123,6 +123,15 @@ def safe_import_pandas():
return None
def safe_import_polars():
try:
import polars as pl
return pl
except ImportError:
return None
@singledispatch
def value_to_sql(value):
raise NotImplementedError("SQL conversion is not implemented for this type")