Files
lancedb/python/lancedb/query.py
Chang She 49333e522c feat(python): basic polars integration (#811)
We should now be able to directly ingest polars dataframes and return
results as polars dataframes


![image](https://github.com/lancedb/lancedb/assets/759245/828b1260-c791-45f1-a047-aa649575e798)
2024-01-13 16:38:16 -08:00

575 lines
19 KiB
Python

# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, List, Literal, Optional, Type, Union
import deprecation
import numpy as np
import pyarrow as pa
import pydantic
from . import __version__
from .common import VECTOR_COLUMN_NAME
from .util import safe_import_pandas
if TYPE_CHECKING:
from .pydantic import LanceModel
pd = safe_import_pandas()
class Query(pydantic.BaseModel):
"""The LanceDB Query
Attributes
----------
vector : List[float]
the vector to search for
filter : Optional[str]
sql filter to refine the query with, optional
prefilter : bool
if True then apply the filter before vector search
k : int
top k results to return
metric : str
the distance metric between a pair of vectors,
can support L2 (default), Cosine and Dot.
[metric definitions][search]
columns : Optional[List[str]]
which columns to return in the results
nprobes : int
The number of probes used - optional
- A higher number makes search more accurate but also slower.
- See discussion in [Querying an ANN Index][querying-an-ann-index] for
tuning advice.
refine_factor : Optional[int]
Refine the results by reading extra elements and re-ranking them in memory - optional
- A higher number makes search more accurate but also slower.
- See discussion in [Querying an ANN Index][querying-an-ann-index] for
tuning advice.
"""
vector_column: str = VECTOR_COLUMN_NAME
# vector to search for
vector: Union[List[float], List[List[float]]]
# sql filter to refine the query with
filter: Optional[str] = None
# if True then apply the filter before vector search
prefilter: bool = False
# top k results to return
k: int
# # metrics
metric: str = "L2"
# which columns to return in the results
columns: Optional[List[str]] = None
# optional query parameters for tuning the results,
# e.g. `{"nprobes": "10", "refine_factor": "10"}`
nprobes: int = 10
# Refine factor.
refine_factor: Optional[int] = None
class LanceQueryBuilder(ABC):
"""Build LanceDB query based on specific query type:
vector or full text search.
"""
@classmethod
def create(
cls,
table: "lancedb.table.Table",
query: Optional[Union[np.ndarray, str, "PIL.Image.Image"]],
query_type: str,
vector_column_name: str,
) -> LanceQueryBuilder:
if query is None:
return LanceEmptyQueryBuilder(table)
# convert "auto" query_type to "vector" or "fts"
# and convert the query to vector if needed
query, query_type = cls._resolve_query(
table, query, query_type, vector_column_name
)
if isinstance(query, str):
# fts
return LanceFtsQueryBuilder(table, query)
if isinstance(query, list):
query = np.array(query, dtype=np.float32)
elif isinstance(query, np.ndarray):
query = query.astype(np.float32)
else:
raise TypeError(f"Unsupported query type: {type(query)}")
return LanceVectorQueryBuilder(table, query, vector_column_name)
@classmethod
def _resolve_query(cls, table, query, query_type, vector_column_name):
# If query_type is fts, then query must be a string.
# otherwise raise TypeError
if query_type == "fts":
if not isinstance(query, str):
raise TypeError(f"'fts' queries must be a string: {type(query)}")
return query, query_type
elif query_type == "vector":
if not isinstance(query, (list, np.ndarray)):
conf = table.embedding_functions.get(vector_column_name)
if conf is not None:
query = conf.function.compute_query_embeddings_with_retry(query)[0]
else:
msg = f"No embedding function for {vector_column_name}"
raise ValueError(msg)
return query, query_type
elif query_type == "auto":
if isinstance(query, (list, np.ndarray)):
return query, "vector"
else:
conf = table.embedding_functions.get(vector_column_name)
if conf is not None:
query = conf.function.compute_query_embeddings_with_retry(query)[0]
return query, "vector"
else:
return query, "fts"
else:
raise ValueError(
f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
)
def __init__(self, table: "lancedb.table.Table"):
self._table = table
self._limit = 10
self._columns = None
self._where = None
@deprecation.deprecated(
deprecated_in="0.3.1",
removed_in="0.4.0",
current_version=__version__,
details="Use to_pandas() instead",
)
def to_df(self) -> "pd.DataFrame":
"""
*Deprecated alias for `to_pandas()`. Please use `to_pandas()` instead.*
Execute the query and return the results as a pandas DataFrame.
In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
return self.to_pandas()
def to_pandas(self, flatten: Optional[Union[int, bool]] = None) -> "pd.DataFrame":
"""
Execute the query and return the results as a pandas DataFrame.
In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query
vector and the returned vector.
Parameters
----------
flatten: Optional[Union[int, bool]]
If flatten is True, flatten all nested columns.
If flatten is an integer, flatten the nested columns up to the
specified depth.
If unspecified, do not flatten the nested columns.
"""
tbl = self.to_arrow()
if flatten is True:
while True:
tbl = tbl.flatten()
has_struct = False
# loop through all columns to check if there is any struct column
if any(pa.types.is_struct(col.type) for col in tbl.schema):
continue
else:
break
elif isinstance(flatten, int):
if flatten <= 0:
raise ValueError(
"Please specify a positive integer for flatten or the boolean value `True`"
)
while flatten > 0:
tbl = tbl.flatten()
flatten -= 1
return tbl.to_pandas()
@abstractmethod
def to_arrow(self) -> pa.Table:
"""
Execute the query and return the results as an
[Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).
In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query
vector and the returned vectors.
"""
raise NotImplementedError
def to_list(self) -> List[dict]:
"""
Execute the query and return the results as a list of dictionaries.
Each list entry is a dictionary with the selected column names as keys,
or all table columns if `select` is not called. The vector and the "_distance"
fields are returned whether or not they're explicitly selected.
"""
return self.to_arrow().to_pylist()
def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
"""Return the table as a list of pydantic models.
Parameters
----------
model: Type[LanceModel]
The pydantic model to use.
Returns
-------
List[LanceModel]
"""
return [
model(**{k: v for k, v in row.items() if k in model.field_names()})
for row in self.to_arrow().to_pylist()
]
def to_polars(self) -> "pl.DataFrame":
"""
Execute the query and return the results as a Polars DataFrame.
In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
import polars as pl
return pl.from_arrow(self.to_arrow())
def limit(self, limit: Union[int, None]) -> LanceQueryBuilder:
"""Set the maximum number of results to return.
Parameters
----------
limit: int
The maximum number of results to return.
By default the query is limited to the first 10.
Call this method and pass 0, a negative value,
or None to remove the limit.
*WARNING* if you have a large dataset, removing
the limit can potentially result in reading a
large amount of data into memory and cause
out of memory issues.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
if limit is None or limit <= 0:
self._limit = None
else:
self._limit = limit
return self
def select(self, columns: list) -> LanceQueryBuilder:
"""Set the columns to return.
Parameters
----------
columns: list
The columns to return.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
self._columns = columns
return self
def where(self, where: str, prefilter: bool = False) -> LanceQueryBuilder:
"""Set the where clause.
Parameters
----------
where: str
The where clause which is a valid SQL where clause. See
`Lance filter pushdown <https://lancedb.github.io/lance/read_and_write.html#filter-push-down>`_
for valid SQL expressions.
prefilter: bool, default False
If True, apply the filter before vector search, otherwise the
filter is applied on the result of vector search.
This feature is **EXPERIMENTAL** and may be removed and modified
without warning in the future.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
self._where = where
self._prefilter = prefilter
return self
class LanceVectorQueryBuilder(LanceQueryBuilder):
"""
Examples
--------
>>> import lancedb
>>> data = [{"vector": [1.1, 1.2], "b": 2},
... {"vector": [0.5, 1.3], "b": 4},
... {"vector": [0.4, 0.4], "b": 6},
... {"vector": [0.4, 0.4], "b": 10}]
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data=data)
>>> (table.search([0.4, 0.4])
... .metric("cosine")
... .where("b < 10")
... .select(["b"])
... .limit(2)
... .to_pandas())
b vector _distance
0 6 [0.4, 0.4] 0.0
"""
def __init__(
self,
table: "lancedb.table.Table",
query: Union[np.ndarray, list, "PIL.Image.Image"],
vector_column: str = VECTOR_COLUMN_NAME,
):
super().__init__(table)
self._query = query
self._metric = "L2"
self._nprobes = 20
self._refine_factor = None
self._vector_column = vector_column
self._prefilter = False
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
"""Set the distance metric to use.
Parameters
----------
metric: "L2" or "cosine"
The distance metric to use. By default "L2" is used.
Returns
-------
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._metric = metric
return self
def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder:
"""Set the number of probes to use.
Higher values will yield better recall (more likely to find vectors if
they exist) at the expense of latency.
See discussion in [Querying an ANN Index][querying-an-ann-index] for
tuning advice.
Parameters
----------
nprobes: int
The number of probes to use.
Returns
-------
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._nprobes = nprobes
return self
def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder:
"""Set the refine factor to use, increasing the number of vectors sampled.
As an example, a refine factor of 2 will sample 2x as many vectors as
requested, re-ranks them, and returns the top half most relevant results.
See discussion in [Querying an ANN Index][querying-an-ann-index] for
tuning advice.
Parameters
----------
refine_factor: int
The refine factor to use.
Returns
-------
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._refine_factor = refine_factor
return self
def to_arrow(self) -> pa.Table:
"""
Execute the query and return the results as an
[Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).
In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query
vector and the returned vectors.
"""
vector = self._query if isinstance(self._query, list) else self._query.tolist()
if isinstance(vector[0], np.ndarray):
vector = [v.tolist() for v in vector]
query = Query(
vector=vector,
filter=self._where,
prefilter=self._prefilter,
k=self._limit,
metric=self._metric,
columns=self._columns,
nprobes=self._nprobes,
refine_factor=self._refine_factor,
vector_column=self._vector_column,
)
return self._table._execute_query(query)
def where(self, where: str, prefilter: bool = False) -> LanceVectorQueryBuilder:
"""Set the where clause.
Parameters
----------
where: str
The where clause which is a valid SQL where clause. See
`Lance filter pushdown <https://lancedb.github.io/lance/read_and_write.html#filter-push-down>`_
for valid SQL expressions.
prefilter: bool, default False
If True, apply the filter before vector search, otherwise the
filter is applied on the result of vector search.
This feature is **EXPERIMENTAL** and may be removed and modified
without warning in the future.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
self._where = where
self._prefilter = prefilter
return self
class LanceFtsQueryBuilder(LanceQueryBuilder):
"""A builder for full text search for LanceDB."""
def __init__(self, table: "lancedb.table.Table", query: str):
super().__init__(table)
self._query = query
self._phrase_query = False
def phrase_query(self, phrase_query: bool = True) -> LanceFtsQueryBuilder:
"""Set whether to use phrase query.
Parameters
----------
phrase_query: bool, default True
If True, then the query will be wrapped in quotes and
double quotes replaced by single quotes.
Returns
-------
LanceFtsQueryBuilder
The LanceFtsQueryBuilder object.
"""
self._phrase_query = phrase_query
return self
def to_arrow(self) -> pa.Table:
try:
import tantivy
except ImportError:
raise ImportError(
"Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature."
)
from .fts import search_index
# get the index path
index_path = self._table._get_fts_index_path()
# check if the index exist
if not Path(index_path).exists():
raise FileNotFoundError(
"Fts index does not exist."
f"Please first call table.create_fts_index(['<field_names>']) to create the fts index."
)
# open the index
index = tantivy.Index.open(index_path)
# get the scores and doc ids
query = self._query
if self._phrase_query:
query = query.replace('"', "'")
query = f'"{query}"'
row_ids, scores = search_index(index, query, self._limit)
if len(row_ids) == 0:
empty_schema = pa.schema([pa.field("score", pa.float32())])
return pa.Table.from_pylist([], schema=empty_schema)
scores = pa.array(scores)
output_tbl = self._table.to_lance().take(row_ids, columns=self._columns)
output_tbl = output_tbl.append_column("score", scores)
if self._where is not None:
try:
# TODO would be great to have Substrait generate pyarrow compute expressions
# or conversely have pyarrow support SQL expressions using Substrait
import duckdb
output_tbl = (
duckdb.sql(f"SELECT * FROM output_tbl")
.filter(self._where)
.to_arrow_table()
)
except ImportError:
import lance
import tempfile
# TODO Use "memory://" instead once that's supported
with tempfile.TemporaryDirectory() as tmp:
ds = lance.write_dataset(output_tbl, tmp)
output_tbl = ds.to_table(filter=self._where)
return output_tbl
class LanceEmptyQueryBuilder(LanceQueryBuilder):
def to_arrow(self) -> pa.Table:
ds = self._table.to_lance()
return ds.to_table(
columns=self._columns,
filter=self._where,
limit=self._limit,
)