mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 05:19:58 +00:00
feat!: better api for manual hybrid queries (#1575)
Currently, the only documented way of performing hybrid search is by using embedding API and passing string queries that get automatically embedded. There are use cases where users might like to pass vectors and text manually instead. This ticket contains more information and historical context - https://github.com/lancedb/lancedb/issues/937 This breaks a undocumented pathway that allowed passing (vector, text) tuple queries which was intended to be temporary, so this is marked as a breaking change. For all practical purposes, this should not really impact most users ### usage ``` results = table.search(query_type="hybrid") .vector(vector_query) .text(text_query) .limit(5) .to_pandas() ```
This commit is contained in:
@@ -43,6 +43,19 @@ table.create_fts_index("text")
|
||||
# hybrid search with default re-ranker
|
||||
results = table.search("flower moon", query_type="hybrid").to_pandas()
|
||||
```
|
||||
!!! Note
|
||||
You can also pass the vector and text query manually. This is useful if you're not using the embedding API or if you're using a separate embedder service.
|
||||
### Explicitly passing the vector and text query
|
||||
```python
|
||||
vector_query = [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
text_query = "flower moon"
|
||||
results = table.search(query_type="hybrid")
|
||||
.vector(vector_query)
|
||||
.text(text_query)
|
||||
.limit(5)
|
||||
.to_pandas()
|
||||
|
||||
```
|
||||
|
||||
By default, LanceDB uses `LinearCombinationReranker(weight=0.7)` to combine and rerank the results of semantic and full-text search. You can customize the hyperparameters as needed or write your own custom reranker. Here's how you can use any of the available rerankers:
|
||||
|
||||
|
||||
@@ -34,7 +34,6 @@ import pydantic
|
||||
|
||||
from . import __version__
|
||||
from .arrow import AsyncRecordBatchReader
|
||||
from .common import VEC
|
||||
from .rerankers.base import Reranker
|
||||
from .rerankers.linear_combination import LinearCombinationReranker
|
||||
from .util import safe_import_pandas
|
||||
@@ -43,6 +42,7 @@ if TYPE_CHECKING:
|
||||
import PIL
|
||||
import polars as pl
|
||||
|
||||
from .common import VEC
|
||||
from ._lancedb import Query as LanceQuery
|
||||
from ._lancedb import VectorQuery as LanceVectorQuery
|
||||
from .pydantic import LanceModel
|
||||
@@ -151,15 +151,16 @@ class LanceQueryBuilder(ABC):
|
||||
vector_column_name: str
|
||||
The name of the vector column to use for vector search.
|
||||
"""
|
||||
if query is None:
|
||||
return LanceEmptyQueryBuilder(table)
|
||||
|
||||
# Check hybrid search first as it supports empty query pattern
|
||||
if query_type == "hybrid":
|
||||
# hybrid fts and vector query
|
||||
return LanceHybridQueryBuilder(
|
||||
table, query, vector_column_name, fts_columns=fts_columns
|
||||
)
|
||||
|
||||
if query is None:
|
||||
return LanceEmptyQueryBuilder(table)
|
||||
|
||||
# remember the string query for reranking purpose
|
||||
str_query = query if isinstance(query, str) else None
|
||||
|
||||
@@ -206,8 +207,6 @@ class LanceQueryBuilder(ABC):
|
||||
elif query_type == "auto":
|
||||
if isinstance(query, (list, np.ndarray)):
|
||||
return query, "vector"
|
||||
if isinstance(query, tuple):
|
||||
return query, "hybrid"
|
||||
else:
|
||||
conf = table.embedding_functions.get(vector_column_name)
|
||||
if conf is not None:
|
||||
@@ -238,6 +237,8 @@ class LanceQueryBuilder(ABC):
|
||||
self._where = None
|
||||
self._prefilter = False
|
||||
self._with_row_id = False
|
||||
self._vector = None
|
||||
self._text = None
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.3.1",
|
||||
@@ -465,6 +466,36 @@ class LanceQueryBuilder(ABC):
|
||||
},
|
||||
).explain_plan(verbose)
|
||||
|
||||
def vector(self, vector: Union[np.ndarray, list]) -> LanceQueryBuilder:
|
||||
"""Set the vector to search for.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
vector: np.ndarray or list
|
||||
The vector to search for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def text(self, text: str) -> LanceQueryBuilder:
|
||||
"""Set the text to search for.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text: str
|
||||
The text to search for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def rerank(self, reranker: Reranker) -> LanceQueryBuilder:
|
||||
"""Rerank the results using the specified reranker.
|
||||
@@ -895,40 +926,70 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
def __init__(
|
||||
self,
|
||||
table: "Table",
|
||||
query: str,
|
||||
vector_column: str,
|
||||
query: str = None,
|
||||
vector_column: str = None,
|
||||
fts_columns: Union[str, List[str]] = [],
|
||||
):
|
||||
super().__init__(table)
|
||||
vector_query, fts_query = self._validate_query(query)
|
||||
self._fts_query = LanceFtsQueryBuilder(
|
||||
table, fts_query, fts_columns=fts_columns
|
||||
)
|
||||
vector_query = self._query_to_vector(table, vector_query, vector_column)
|
||||
self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column)
|
||||
self._query = query
|
||||
self._vector_column = vector_column
|
||||
self._fts_columns = fts_columns
|
||||
self._norm = "score"
|
||||
self._reranker = LinearCombinationReranker(weight=0.7, fill=1.0)
|
||||
self._nprobes = None
|
||||
self._refine_factor = None
|
||||
|
||||
def _validate_query(self, query):
|
||||
# Temp hack to support vectorized queries for hybrid search
|
||||
if isinstance(query, str):
|
||||
return query, query
|
||||
elif isinstance(query, tuple):
|
||||
if len(query) != 2:
|
||||
raise ValueError(
|
||||
"The query must be a tuple of (vector_query, fts_query)."
|
||||
)
|
||||
if not isinstance(query[0], (list, np.ndarray, pa.Array, pa.ChunkedArray)):
|
||||
raise ValueError(f"The vector query must be one of {VEC}.")
|
||||
if not isinstance(query[1], str):
|
||||
raise ValueError("The fts query must be a string.")
|
||||
return query[0], query[1]
|
||||
else:
|
||||
def _validate_query(self, query, vector=None, text=None):
|
||||
if query is not None and (vector is not None or text is not None):
|
||||
raise ValueError(
|
||||
"The query must be either a string or a tuple of (vector, string)."
|
||||
"You can either provide a string query in search() method"
|
||||
"or set `vector()` and `text()` explicitly for hybrid search."
|
||||
"But not both."
|
||||
)
|
||||
|
||||
vector_query = vector if vector is not None else query
|
||||
if not isinstance(vector_query, (str, list, np.ndarray)):
|
||||
raise ValueError("Vector query must be either a string or a vector")
|
||||
|
||||
text_query = text or query
|
||||
if text_query is None:
|
||||
raise ValueError("Text query must be provided for hybrid search.")
|
||||
if not isinstance(text_query, str):
|
||||
raise ValueError("Text query must be a string")
|
||||
|
||||
return vector_query, text_query
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
vector_query, fts_query = self._validate_query(
|
||||
self._query, self._vector, self._text
|
||||
)
|
||||
self._fts_query = LanceFtsQueryBuilder(
|
||||
self._table, fts_query, fts_columns=self._fts_columns
|
||||
)
|
||||
vector_query = self._query_to_vector(
|
||||
self._table, vector_query, self._vector_column
|
||||
)
|
||||
self._vector_query = LanceVectorQueryBuilder(
|
||||
self._table, vector_query, self._vector_column
|
||||
)
|
||||
|
||||
if self._limit:
|
||||
self._vector_query.limit(self._limit)
|
||||
self._fts_query.limit(self._limit)
|
||||
if self._columns:
|
||||
self._vector_query.select(self._columns)
|
||||
self._fts_query.select(self._columns)
|
||||
if self._where:
|
||||
self._vector_query.where(self._where, self._prefilter)
|
||||
self._fts_query.where(self._where, self._prefilter)
|
||||
if self._with_row_id:
|
||||
self._vector_query.with_row_id(True)
|
||||
self._fts_query.with_row_id(True)
|
||||
if self._nprobes:
|
||||
self._vector_query.nprobes(self._nprobes)
|
||||
if self._refine_factor:
|
||||
self._vector_query.refine_factor(self._refine_factor)
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
|
||||
vector_future = executor.submit(
|
||||
@@ -1034,87 +1095,6 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
return self
|
||||
|
||||
def limit(self, limit: int) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Set the maximum number of results to return for both vector and fts search
|
||||
components.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
limit: int
|
||||
The maximum number of results to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
self._vector_query.limit(limit)
|
||||
self._fts_query.limit(limit)
|
||||
self._limit = limit
|
||||
|
||||
return self
|
||||
|
||||
def select(self, columns: list) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Set the columns to return for both vector and fts search.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
columns: list
|
||||
The columns to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
self._vector_query.select(columns)
|
||||
self._fts_query.select(columns)
|
||||
return self
|
||||
|
||||
def where(self, where: str, prefilter: bool = False) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Set the where clause for both vector and fts search.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
where: str
|
||||
The where clause which is a valid SQL where clause. See
|
||||
`Lance filter pushdown <https://lancedb.github.io/lance/read_and_write.html#filter-push-down>`_
|
||||
for valid SQL expressions.
|
||||
|
||||
prefilter: bool, default False
|
||||
If True, apply the filter before vector search, otherwise the
|
||||
filter is applied on the result of vector search.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
|
||||
self._vector_query.where(where, prefilter=prefilter)
|
||||
self._fts_query.where(where)
|
||||
return self
|
||||
|
||||
def metric(self, metric: Literal["L2", "cosine"]) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Set the distance metric to use for vector search.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metric: "L2" or "cosine"
|
||||
The distance metric to use. By default "L2" is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
self._vector_query.metric(metric)
|
||||
return self
|
||||
|
||||
def nprobes(self, nprobes: int) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Set the number of probes to use for vector search.
|
||||
@@ -1132,7 +1112,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
self._vector_query.nprobes(nprobes)
|
||||
self._nprobes = nprobes
|
||||
return self
|
||||
|
||||
def refine_factor(self, refine_factor: int) -> LanceHybridQueryBuilder:
|
||||
@@ -1150,7 +1130,15 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
self._vector_query.refine_factor(refine_factor)
|
||||
self._refine_factor = refine_factor
|
||||
return self
|
||||
|
||||
def vector(self, vector: Union[np.ndarray, list]) -> LanceHybridQueryBuilder:
|
||||
self._vector = vector
|
||||
return self
|
||||
|
||||
def text(self, text: str) -> LanceHybridQueryBuilder:
|
||||
self._text = text
|
||||
return self
|
||||
|
||||
|
||||
|
||||
@@ -111,7 +111,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query), vector_column_name="vector")
|
||||
table.search(query_type="hybrid", vector_column_name="vector")
|
||||
.vector(query_vector)
|
||||
.text(query)
|
||||
.limit(30)
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
@@ -207,14 +209,26 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query), vector_column_name="vector")
|
||||
table.search(query_type="hybrid", vector_column_name="vector")
|
||||
.vector(query_vector)
|
||||
.text(query)
|
||||
.limit(30)
|
||||
.rerank(normalize="score")
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
# Fail if both query and (vector or text) are provided
|
||||
with pytest.raises(ValueError):
|
||||
table.search(query, query_type="hybrid", vector_column_name="vector").vector(
|
||||
query_vector
|
||||
).to_arrow()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
table.search(query, query_type="hybrid", vector_column_name="vector").text(
|
||||
query
|
||||
).to_arrow()
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
|
||||
Reference in New Issue
Block a user