mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-10 22:02:58 +00:00
feat(python): basic polars integration (#811)
We should now be able to directly ingest polars dataframes and return results as polars dataframes 
This commit is contained in:
@@ -260,6 +260,17 @@ class LanceQueryBuilder(ABC):
|
||||
for row in self.to_arrow().to_pylist()
|
||||
]
|
||||
|
||||
def to_polars(self) -> "pl.DataFrame":
|
||||
"""
|
||||
Execute the query and return the results as a Polars DataFrame.
|
||||
In addition to the selected columns, LanceDB also returns a vector
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
import polars as pl
|
||||
|
||||
return pl.from_arrow(self.to_arrow())
|
||||
|
||||
def limit(self, limit: Union[int, None]) -> LanceQueryBuilder:
|
||||
"""Set the maximum number of results to return.
|
||||
|
||||
|
||||
@@ -31,7 +31,13 @@ from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
from .pydantic import LanceModel, model_to_dict
|
||||
from .query import LanceQueryBuilder, Query
|
||||
from .util import fs_from_uri, safe_import_pandas, value_to_sql, join_uri
|
||||
from .util import (
|
||||
fs_from_uri,
|
||||
safe_import_pandas,
|
||||
safe_import_polars,
|
||||
value_to_sql,
|
||||
join_uri,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
@@ -40,6 +46,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
pd = safe_import_pandas()
|
||||
pl = safe_import_polars()
|
||||
|
||||
|
||||
def _sanitize_data(
|
||||
@@ -65,6 +72,8 @@ def _sanitize_data(
|
||||
meta = data.schema.metadata if data.schema.metadata is not None else {}
|
||||
meta = {k: v for k, v in meta.items() if k != b"pandas"}
|
||||
data = data.replace_schema_metadata(meta)
|
||||
elif pl is not None and isinstance(data, pl.DataFrame):
|
||||
data = data.to_arrow()
|
||||
|
||||
if isinstance(data, pa.Table):
|
||||
if metadata:
|
||||
@@ -1268,7 +1277,8 @@ def _sanitize_vector_column(
|
||||
"""
|
||||
# ChunkedArray is annoying to work with, so we combine chunks here
|
||||
vec_arr = data[vector_column_name].combine_chunks()
|
||||
if pa.types.is_list(data[vector_column_name].type):
|
||||
typ = data[vector_column_name].type
|
||||
if pa.types.is_list(typ) or pa.types.is_large_list(typ):
|
||||
# if it's a variable size list array,
|
||||
# we make sure the dimensions are all the same
|
||||
has_jagged_ndims = len(vec_arr.values) % len(data) != 0
|
||||
|
||||
@@ -123,6 +123,15 @@ def safe_import_pandas():
|
||||
return None
|
||||
|
||||
|
||||
def safe_import_polars():
|
||||
try:
|
||||
import polars as pl
|
||||
|
||||
return pl
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
@singledispatch
|
||||
def value_to_sql(value):
|
||||
raise NotImplementedError("SQL conversion is not implemented for this type")
|
||||
|
||||
Reference in New Issue
Block a user