feat!: better api for manual hybrid queries (#1575)

Currently, the only documented way of performing hybrid search is by
using embedding API and passing string queries that get automatically
embedded. There are use cases where users might like to pass vectors and
text manually instead.
This ticket contains more information and historical context -
https://github.com/lancedb/lancedb/issues/937

This breaks a undocumented pathway that allowed passing (vector, text)
tuple queries which was intended to be temporary, so this is marked as a
breaking change. For all practical purposes, this should not really
impact most users

### usage
```
results = table.search(query_type="hybrid")
                .vector(vector_query)
                .text(text_query)
                .limit(5)
                .to_pandas()
```
This commit is contained in:
Ayush Chaurasia
2024-08-30 17:37:58 +05:30
committed by GitHub
parent 1521435193
commit dc72ece847
3 changed files with 131 additions and 116 deletions

View File

@@ -43,6 +43,19 @@ table.create_fts_index("text")
# hybrid search with default re-ranker
results = table.search("flower moon", query_type="hybrid").to_pandas()
```
!!! Note
You can also pass the vector and text query manually. This is useful if you're not using the embedding API or if you're using a separate embedder service.
### Explicitly passing the vector and text query
```python
vector_query = [0.1, 0.2, 0.3, 0.4, 0.5]
text_query = "flower moon"
results = table.search(query_type="hybrid")
.vector(vector_query)
.text(text_query)
.limit(5)
.to_pandas()
```
By default, LanceDB uses `LinearCombinationReranker(weight=0.7)` to combine and rerank the results of semantic and full-text search. You can customize the hyperparameters as needed or write your own custom reranker. Here's how you can use any of the available rerankers:

View File

@@ -34,7 +34,6 @@ import pydantic
from . import __version__
from .arrow import AsyncRecordBatchReader
from .common import VEC
from .rerankers.base import Reranker
from .rerankers.linear_combination import LinearCombinationReranker
from .util import safe_import_pandas
@@ -43,6 +42,7 @@ if TYPE_CHECKING:
import PIL
import polars as pl
from .common import VEC
from ._lancedb import Query as LanceQuery
from ._lancedb import VectorQuery as LanceVectorQuery
from .pydantic import LanceModel
@@ -151,15 +151,16 @@ class LanceQueryBuilder(ABC):
vector_column_name: str
The name of the vector column to use for vector search.
"""
if query is None:
return LanceEmptyQueryBuilder(table)
# Check hybrid search first as it supports empty query pattern
if query_type == "hybrid":
# hybrid fts and vector query
return LanceHybridQueryBuilder(
table, query, vector_column_name, fts_columns=fts_columns
)
if query is None:
return LanceEmptyQueryBuilder(table)
# remember the string query for reranking purpose
str_query = query if isinstance(query, str) else None
@@ -206,8 +207,6 @@ class LanceQueryBuilder(ABC):
elif query_type == "auto":
if isinstance(query, (list, np.ndarray)):
return query, "vector"
if isinstance(query, tuple):
return query, "hybrid"
else:
conf = table.embedding_functions.get(vector_column_name)
if conf is not None:
@@ -238,6 +237,8 @@ class LanceQueryBuilder(ABC):
self._where = None
self._prefilter = False
self._with_row_id = False
self._vector = None
self._text = None
@deprecation.deprecated(
deprecated_in="0.3.1",
@@ -465,6 +466,36 @@ class LanceQueryBuilder(ABC):
},
).explain_plan(verbose)
def vector(self, vector: Union[np.ndarray, list]) -> LanceQueryBuilder:
"""Set the vector to search for.
Parameters
----------
vector: np.ndarray or list
The vector to search for.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
raise NotImplementedError
def text(self, text: str) -> LanceQueryBuilder:
"""Set the text to search for.
Parameters
----------
text: str
The text to search for.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
raise NotImplementedError
@abstractmethod
def rerank(self, reranker: Reranker) -> LanceQueryBuilder:
"""Rerank the results using the specified reranker.
@@ -895,40 +926,70 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
def __init__(
self,
table: "Table",
query: str,
vector_column: str,
query: str = None,
vector_column: str = None,
fts_columns: Union[str, List[str]] = [],
):
super().__init__(table)
vector_query, fts_query = self._validate_query(query)
self._fts_query = LanceFtsQueryBuilder(
table, fts_query, fts_columns=fts_columns
)
vector_query = self._query_to_vector(table, vector_query, vector_column)
self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column)
self._query = query
self._vector_column = vector_column
self._fts_columns = fts_columns
self._norm = "score"
self._reranker = LinearCombinationReranker(weight=0.7, fill=1.0)
self._nprobes = None
self._refine_factor = None
def _validate_query(self, query):
# Temp hack to support vectorized queries for hybrid search
if isinstance(query, str):
return query, query
elif isinstance(query, tuple):
if len(query) != 2:
raise ValueError(
"The query must be a tuple of (vector_query, fts_query)."
)
if not isinstance(query[0], (list, np.ndarray, pa.Array, pa.ChunkedArray)):
raise ValueError(f"The vector query must be one of {VEC}.")
if not isinstance(query[1], str):
raise ValueError("The fts query must be a string.")
return query[0], query[1]
else:
def _validate_query(self, query, vector=None, text=None):
if query is not None and (vector is not None or text is not None):
raise ValueError(
"The query must be either a string or a tuple of (vector, string)."
"You can either provide a string query in search() method"
"or set `vector()` and `text()` explicitly for hybrid search."
"But not both."
)
vector_query = vector if vector is not None else query
if not isinstance(vector_query, (str, list, np.ndarray)):
raise ValueError("Vector query must be either a string or a vector")
text_query = text or query
if text_query is None:
raise ValueError("Text query must be provided for hybrid search.")
if not isinstance(text_query, str):
raise ValueError("Text query must be a string")
return vector_query, text_query
def to_arrow(self) -> pa.Table:
vector_query, fts_query = self._validate_query(
self._query, self._vector, self._text
)
self._fts_query = LanceFtsQueryBuilder(
self._table, fts_query, fts_columns=self._fts_columns
)
vector_query = self._query_to_vector(
self._table, vector_query, self._vector_column
)
self._vector_query = LanceVectorQueryBuilder(
self._table, vector_query, self._vector_column
)
if self._limit:
self._vector_query.limit(self._limit)
self._fts_query.limit(self._limit)
if self._columns:
self._vector_query.select(self._columns)
self._fts_query.select(self._columns)
if self._where:
self._vector_query.where(self._where, self._prefilter)
self._fts_query.where(self._where, self._prefilter)
if self._with_row_id:
self._vector_query.with_row_id(True)
self._fts_query.with_row_id(True)
if self._nprobes:
self._vector_query.nprobes(self._nprobes)
if self._refine_factor:
self._vector_query.refine_factor(self._refine_factor)
with ThreadPoolExecutor() as executor:
fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
vector_future = executor.submit(
@@ -1034,87 +1095,6 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
return self
def limit(self, limit: int) -> LanceHybridQueryBuilder:
"""
Set the maximum number of results to return for both vector and fts search
components.
Parameters
----------
limit: int
The maximum number of results to return.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.limit(limit)
self._fts_query.limit(limit)
self._limit = limit
return self
def select(self, columns: list) -> LanceHybridQueryBuilder:
"""
Set the columns to return for both vector and fts search.
Parameters
----------
columns: list
The columns to return.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.select(columns)
self._fts_query.select(columns)
return self
def where(self, where: str, prefilter: bool = False) -> LanceHybridQueryBuilder:
"""
Set the where clause for both vector and fts search.
Parameters
----------
where: str
The where clause which is a valid SQL where clause. See
`Lance filter pushdown <https://lancedb.github.io/lance/read_and_write.html#filter-push-down>`_
for valid SQL expressions.
prefilter: bool, default False
If True, apply the filter before vector search, otherwise the
filter is applied on the result of vector search.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.where(where, prefilter=prefilter)
self._fts_query.where(where)
return self
def metric(self, metric: Literal["L2", "cosine"]) -> LanceHybridQueryBuilder:
"""
Set the distance metric to use for vector search.
Parameters
----------
metric: "L2" or "cosine"
The distance metric to use. By default "L2" is used.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.metric(metric)
return self
def nprobes(self, nprobes: int) -> LanceHybridQueryBuilder:
"""
Set the number of probes to use for vector search.
@@ -1132,7 +1112,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.nprobes(nprobes)
self._nprobes = nprobes
return self
def refine_factor(self, refine_factor: int) -> LanceHybridQueryBuilder:
@@ -1150,7 +1130,15 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.refine_factor(refine_factor)
self._refine_factor = refine_factor
return self
def vector(self, vector: Union[np.ndarray, list]) -> LanceHybridQueryBuilder:
self._vector = vector
return self
def text(self, text: str) -> LanceHybridQueryBuilder:
self._text = text
return self

View File

@@ -111,7 +111,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query), vector_column_name="vector")
table.search(query_type="hybrid", vector_column_name="vector")
.vector(query_vector)
.text(query)
.limit(30)
.rerank(reranker=reranker)
.to_arrow()
@@ -207,14 +209,26 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query), vector_column_name="vector")
table.search(query_type="hybrid", vector_column_name="vector")
.vector(query_vector)
.text(query)
.limit(30)
.rerank(normalize="score")
.to_arrow()
)
assert len(result) == 30
# Fail if both query and (vector or text) are provided
with pytest.raises(ValueError):
table.search(query, query_type="hybrid", vector_column_name="vector").vector(
query_vector
).to_arrow()
with pytest.raises(ValueError):
table.search(query, query_type="hybrid", vector_column_name="vector").text(
query
).to_arrow()
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "