diff --git a/docs/src/hybrid_search/hybrid_search.md b/docs/src/hybrid_search/hybrid_search.md index 244e8740..1503a07b 100644 --- a/docs/src/hybrid_search/hybrid_search.md +++ b/docs/src/hybrid_search/hybrid_search.md @@ -43,6 +43,19 @@ table.create_fts_index("text") # hybrid search with default re-ranker results = table.search("flower moon", query_type="hybrid").to_pandas() ``` +!!! Note + You can also pass the vector and text query manually. This is useful if you're not using the embedding API or if you're using a separate embedder service. +### Explicitly passing the vector and text query +```python +vector_query = [0.1, 0.2, 0.3, 0.4, 0.5] +text_query = "flower moon" +results = table.search(query_type="hybrid") + .vector(vector_query) + .text(text_query) + .limit(5) + .to_pandas() + +``` By default, LanceDB uses `LinearCombinationReranker(weight=0.7)` to combine and rerank the results of semantic and full-text search. You can customize the hyperparameters as needed or write your own custom reranker. Here's how you can use any of the available rerankers: diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 9c33740a..7e7538b0 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -34,7 +34,6 @@ import pydantic from . import __version__ from .arrow import AsyncRecordBatchReader -from .common import VEC from .rerankers.base import Reranker from .rerankers.linear_combination import LinearCombinationReranker from .util import safe_import_pandas @@ -43,6 +42,7 @@ if TYPE_CHECKING: import PIL import polars as pl + from .common import VEC from ._lancedb import Query as LanceQuery from ._lancedb import VectorQuery as LanceVectorQuery from .pydantic import LanceModel @@ -151,15 +151,16 @@ class LanceQueryBuilder(ABC): vector_column_name: str The name of the vector column to use for vector search. """ - if query is None: - return LanceEmptyQueryBuilder(table) - + # Check hybrid search first as it supports empty query pattern if query_type == "hybrid": # hybrid fts and vector query return LanceHybridQueryBuilder( table, query, vector_column_name, fts_columns=fts_columns ) + if query is None: + return LanceEmptyQueryBuilder(table) + # remember the string query for reranking purpose str_query = query if isinstance(query, str) else None @@ -206,8 +207,6 @@ class LanceQueryBuilder(ABC): elif query_type == "auto": if isinstance(query, (list, np.ndarray)): return query, "vector" - if isinstance(query, tuple): - return query, "hybrid" else: conf = table.embedding_functions.get(vector_column_name) if conf is not None: @@ -238,6 +237,8 @@ class LanceQueryBuilder(ABC): self._where = None self._prefilter = False self._with_row_id = False + self._vector = None + self._text = None @deprecation.deprecated( deprecated_in="0.3.1", @@ -465,6 +466,36 @@ class LanceQueryBuilder(ABC): }, ).explain_plan(verbose) + def vector(self, vector: Union[np.ndarray, list]) -> LanceQueryBuilder: + """Set the vector to search for. + + Parameters + ---------- + vector: np.ndarray or list + The vector to search for. + + Returns + ------- + LanceQueryBuilder + The LanceQueryBuilder object. + """ + raise NotImplementedError + + def text(self, text: str) -> LanceQueryBuilder: + """Set the text to search for. + + Parameters + ---------- + text: str + The text to search for. + + Returns + ------- + LanceQueryBuilder + The LanceQueryBuilder object. + """ + raise NotImplementedError + @abstractmethod def rerank(self, reranker: Reranker) -> LanceQueryBuilder: """Rerank the results using the specified reranker. @@ -895,40 +926,70 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): def __init__( self, table: "Table", - query: str, - vector_column: str, + query: str = None, + vector_column: str = None, fts_columns: Union[str, List[str]] = [], ): super().__init__(table) - vector_query, fts_query = self._validate_query(query) - self._fts_query = LanceFtsQueryBuilder( - table, fts_query, fts_columns=fts_columns - ) - vector_query = self._query_to_vector(table, vector_query, vector_column) - self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column) + self._query = query + self._vector_column = vector_column + self._fts_columns = fts_columns self._norm = "score" self._reranker = LinearCombinationReranker(weight=0.7, fill=1.0) + self._nprobes = None + self._refine_factor = None - def _validate_query(self, query): - # Temp hack to support vectorized queries for hybrid search - if isinstance(query, str): - return query, query - elif isinstance(query, tuple): - if len(query) != 2: - raise ValueError( - "The query must be a tuple of (vector_query, fts_query)." - ) - if not isinstance(query[0], (list, np.ndarray, pa.Array, pa.ChunkedArray)): - raise ValueError(f"The vector query must be one of {VEC}.") - if not isinstance(query[1], str): - raise ValueError("The fts query must be a string.") - return query[0], query[1] - else: + def _validate_query(self, query, vector=None, text=None): + if query is not None and (vector is not None or text is not None): raise ValueError( - "The query must be either a string or a tuple of (vector, string)." + "You can either provide a string query in search() method" + "or set `vector()` and `text()` explicitly for hybrid search." + "But not both." ) + vector_query = vector if vector is not None else query + if not isinstance(vector_query, (str, list, np.ndarray)): + raise ValueError("Vector query must be either a string or a vector") + + text_query = text or query + if text_query is None: + raise ValueError("Text query must be provided for hybrid search.") + if not isinstance(text_query, str): + raise ValueError("Text query must be a string") + + return vector_query, text_query + def to_arrow(self) -> pa.Table: + vector_query, fts_query = self._validate_query( + self._query, self._vector, self._text + ) + self._fts_query = LanceFtsQueryBuilder( + self._table, fts_query, fts_columns=self._fts_columns + ) + vector_query = self._query_to_vector( + self._table, vector_query, self._vector_column + ) + self._vector_query = LanceVectorQueryBuilder( + self._table, vector_query, self._vector_column + ) + + if self._limit: + self._vector_query.limit(self._limit) + self._fts_query.limit(self._limit) + if self._columns: + self._vector_query.select(self._columns) + self._fts_query.select(self._columns) + if self._where: + self._vector_query.where(self._where, self._prefilter) + self._fts_query.where(self._where, self._prefilter) + if self._with_row_id: + self._vector_query.with_row_id(True) + self._fts_query.with_row_id(True) + if self._nprobes: + self._vector_query.nprobes(self._nprobes) + if self._refine_factor: + self._vector_query.refine_factor(self._refine_factor) + with ThreadPoolExecutor() as executor: fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow) vector_future = executor.submit( @@ -1034,87 +1095,6 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): return self - def limit(self, limit: int) -> LanceHybridQueryBuilder: - """ - Set the maximum number of results to return for both vector and fts search - components. - - Parameters - ---------- - limit: int - The maximum number of results to return. - - Returns - ------- - LanceHybridQueryBuilder - The LanceHybridQueryBuilder object. - """ - self._vector_query.limit(limit) - self._fts_query.limit(limit) - self._limit = limit - - return self - - def select(self, columns: list) -> LanceHybridQueryBuilder: - """ - Set the columns to return for both vector and fts search. - - Parameters - ---------- - columns: list - The columns to return. - - Returns - ------- - LanceHybridQueryBuilder - The LanceHybridQueryBuilder object. - """ - self._vector_query.select(columns) - self._fts_query.select(columns) - return self - - def where(self, where: str, prefilter: bool = False) -> LanceHybridQueryBuilder: - """ - Set the where clause for both vector and fts search. - - Parameters - ---------- - where: str - The where clause which is a valid SQL where clause. See - `Lance filter pushdown `_ - for valid SQL expressions. - - prefilter: bool, default False - If True, apply the filter before vector search, otherwise the - filter is applied on the result of vector search. - - Returns - ------- - LanceHybridQueryBuilder - The LanceHybridQueryBuilder object. - """ - - self._vector_query.where(where, prefilter=prefilter) - self._fts_query.where(where) - return self - - def metric(self, metric: Literal["L2", "cosine"]) -> LanceHybridQueryBuilder: - """ - Set the distance metric to use for vector search. - - Parameters - ---------- - metric: "L2" or "cosine" - The distance metric to use. By default "L2" is used. - - Returns - ------- - LanceHybridQueryBuilder - The LanceHybridQueryBuilder object. - """ - self._vector_query.metric(metric) - return self - def nprobes(self, nprobes: int) -> LanceHybridQueryBuilder: """ Set the number of probes to use for vector search. @@ -1132,7 +1112,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): LanceHybridQueryBuilder The LanceHybridQueryBuilder object. """ - self._vector_query.nprobes(nprobes) + self._nprobes = nprobes return self def refine_factor(self, refine_factor: int) -> LanceHybridQueryBuilder: @@ -1150,7 +1130,15 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): LanceHybridQueryBuilder The LanceHybridQueryBuilder object. """ - self._vector_query.refine_factor(refine_factor) + self._refine_factor = refine_factor + return self + + def vector(self, vector: Union[np.ndarray, list]) -> LanceHybridQueryBuilder: + self._vector = vector + return self + + def text(self, text: str) -> LanceHybridQueryBuilder: + self._text = text return self diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py index fca0850c..23132066 100644 --- a/python/python/tests/test_rerankers.py +++ b/python/python/tests/test_rerankers.py @@ -111,7 +111,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema): query_vector = table.to_pandas()["vector"][0] result = ( - table.search((query_vector, query), vector_column_name="vector") + table.search(query_type="hybrid", vector_column_name="vector") + .vector(query_vector) + .text(query) .limit(30) .rerank(reranker=reranker) .to_arrow() @@ -207,14 +209,26 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy): query = "Our father who art in heaven" query_vector = table.to_pandas()["vector"][0] result = ( - table.search((query_vector, query), vector_column_name="vector") + table.search(query_type="hybrid", vector_column_name="vector") + .vector(query_vector) + .text(query) .limit(30) .rerank(normalize="score") .to_arrow() ) - assert len(result) == 30 + # Fail if both query and (vector or text) are provided + with pytest.raises(ValueError): + table.search(query, query_type="hybrid", vector_column_name="vector").vector( + query_vector + ).to_arrow() + + with pytest.raises(ValueError): + table.search(query, query_type="hybrid", vector_column_name="vector").text( + query + ).to_arrow() + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), ( "The _relevance_score column of the results returned by the reranker " "represents the relevance of the result to the query & should "