From 3ffed897936f5eb57d59c8c53e158335d072caaf Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Tue, 30 Jan 2024 19:10:33 +0530 Subject: [PATCH] feat(python): Hybrid search & Reranker API (#824) based on https://github.com/lancedb/lancedb/pull/713 - The Reranker api can be plugged into vector only or fts only search but this PR doesn't do that (see example - https://txt.cohere.com/rerank/) ### Default reranker -- `LinearCombinationReranker(weight=0.7, fill=1.0)` ``` table.search("hello", query_type="hybrid").rerank(normalize="score").to_pandas() ``` ### Available rerankers LinearCombinationReranker ``` from lancedb.rerankers import LinearCombinationReranker # Same as default table.search("hello", query_type="hybrid").rerank( normalize="score", reranker=LinearCombinationReranker() ).to_pandas() # with custom params reranker = LinearCombinationReranker(weight=0.3, fill=1.0) table.search("hello", query_type="hybrid").rerank( normalize="score", reranker=reranker ).to_pandas() ``` Cohere Reranker ``` from lancedb.rerankers import CohereReranker # default model.. English and multi-lingual supported. See docstring for available custom params table.search("hello", query_type="hybrid").rerank( normalize="rank", # score or rank reranker=CohereReranker() ).to_pandas() ``` CrossEncoderReranker ``` from lancedb.rerankers import CrossEncoderReranker table.search("hello", query_type="hybrid").rerank( normalize="rank", reranker=CrossEncoderReranker() ).to_pandas() ``` ## Using custom Reranker ``` from lancedb.reranker import Reranker class CustomReranker(Reranker): def rerank_hybrid(self, vector_result, fts_result): combined_res = self.merge_results(vector_results, fts_results) # or use custom combination logic # Custom rerank logic here return combined_res ``` - [x] Expand testing - [x] Make sure usage makes sense - [x] Run simple benchmarks for correctness (Seeing weird result from cohere reranker in the toy example) - Support diverse rerankers by default: - [x] Cross encoding - [x] Cohere - [x] Reciprocal Rank Fusion --------- Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com> Co-authored-by: Prashanth Rao <35005448+prrao87@users.noreply.github.com> --- docs/mkdocs.yml | 2 + docs/src/hybrid_search.md | 172 ++++++++++ docs/test/md_testing.py | 1 + python/lancedb/common.py | 4 +- python/lancedb/context.py | 4 +- python/lancedb/embeddings/utils.py | 4 +- python/lancedb/query.py | 324 +++++++++++++++++- python/lancedb/rerankers/__init__.py | 11 + python/lancedb/rerankers/base.py | 109 ++++++ python/lancedb/rerankers/cohere.py | 85 +++++ python/lancedb/rerankers/cross_encoder.py | 78 +++++ .../lancedb/rerankers/linear_combination.py | 117 +++++++ python/lancedb/table.py | 14 +- python/lancedb/util.py | 30 +- python/tests/test_rerankers.py | 168 +++++++++ python/tests/test_table.py | 54 +++ 16 files changed, 1136 insertions(+), 41 deletions(-) create mode 100644 docs/src/hybrid_search.md create mode 100644 python/lancedb/rerankers/__init__.py create mode 100644 python/lancedb/rerankers/base.py create mode 100644 python/lancedb/rerankers/cohere.py create mode 100644 python/lancedb/rerankers/cross_encoder.py create mode 100644 python/lancedb/rerankers/linear_combination.py create mode 100644 python/tests/test_rerankers.py diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 259ee0d9..b2cbe9eb 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -90,6 +90,7 @@ nav: - Building an ANN index: ann_indexes.md - Vector Search: search.md - Full-text search: fts.md + - Hybrid search: hybrid_search.md - Filtering: sql.md - Versioning & Reproducibility: notebooks/reproducibility.ipynb - Configuring Storage: guides/storage.md @@ -151,6 +152,7 @@ nav: - Building an ANN index: ann_indexes.md - Vector Search: search.md - Full-text search: fts.md + - Hybrid search: hybrid_search.md - Filtering: sql.md - Versioning & Reproducibility: notebooks/reproducibility.ipynb - Configuring Storage: guides/storage.md diff --git a/docs/src/hybrid_search.md b/docs/src/hybrid_search.md new file mode 100644 index 00000000..e6ea06ef --- /dev/null +++ b/docs/src/hybrid_search.md @@ -0,0 +1,172 @@ +# Hybrid Search + +LanceDB supports both semantic and keyword-based search. In real world applications, it is often useful to combine these two approaches to get the best best results. For example, you may want to search for a document that is semantically similar to a query document, but also contains a specific keyword. This is an example of *hybrid search*, a search algorithm that combines multiple search techniques. + +## Hybrid search in LanceDB +You can perform hybrid search in LanceDB by combining the results of semantic and full-text search via a reranking algorithm of your choice. LanceDB provides multiple rerankers out of the box. However, you can always write a custom reranker if your use case need more sophisticated logic . + +```python +import lancedb +from lancedb.embeddings import get_registry +from lancedb.pydanatic import LanceModel, Vector + +db = lancedb.connect("~/.lancedb") + +# Ingest embedding function in LanceDB table +embeddings = get_registry().get("openai").create() + +class Documents(LanceModel): + vector: Vector(embeddings.ndims) = embeddings.VectorField() + text: str = embeddings.SourceField() + +table = db.create_table("documents", schema=Documents) + +data = [ + { "text": "rebel spaceships striking from a hidden base"}, + { "text": "have won their first victory against the evil Galactic Empire"}, + { "text": "during the battle rebel spies managed to steal secret plans"}, + { "text": "to the Empire's ultimate weapon the Death Star"} +] + +# ingest docs with auto-vectorization +table.add(data) + +# hybrid search with default re-ranker +results = table.search("flower moon", query_type="hybrid").to_pandas() +``` + +By default, LanceDB uses `LinearCombinationReranker(weights=0.7)` to combine and rerank the results of semantic and full-text search. You can customize the hyperparameters as needed or write your own custom reranker. Here's how you can use any of the available rerankers: + + +### `rerank()` arguments +* `normalize`: `str`, default `"score"`: + The method to normalize the scores. Can be "rank" or "score". If "rank", the scores are converted to ranks and then normalized. If "score", the scores are normalized directly. +* `reranker`: `Reranker`, default `LinearCombinationReranker(weights=0.7)`. + The reranker to use. If not specified, the default reranker is used. + + +## Available Rerankers +LanceDB provides a number of re-rankers out of the box. You can use any of these re-rankers by passing them to the `rerank()` method. Here's a list of available re-rankers: + +### Linear Combination Reranker +This is the default re-ranker used by LanceDB. It combines the results of semantic and full-text search using a linear combination of the scores. The weights for the linear combination can be specified. It defaults to 0.7, i.e, 70% weight for semantic search and 30% weight for full-text search. + + +```python +from lancedb.rerankers import LinearCombinationReranker + +reranker = LinearCombinationReranker(weights=0.3) # Use 0.3 as the weight for vector search + +results = table.search("rebel", query_type="hybrid").rerank(reranker=reranker).to_pandas() +``` + +Arguments +---------------- +* `weight`: `float`, default `0.7`: + The weight to use for the semantic search score. The weight for the full-text search score is `1 - weights`. +* `fill`: `float`, default `1.0`: + The score to give to results that are only in one of the two result sets.This is treated as penalty, so a higher value means a lower score. + TODO: We should just hardcode this-- its pretty confusing as we invert scores to calculate final score +* `return_score` : str, default `"relevance"` + options are "relevance" or "all" + The type of score to return. If "relevance", will return only the `_relevance_score. If "all", will return all scores from the vector and FTS search along with the relevance score. + +### Cohere Reranker +This re-ranker uses the [Cohere](https://cohere.ai/) API to combine the results of semantic and full-text search. You can use this re-ranker by passing `CohereReranker()` to the `rerank()` method. Note that you'll need to set the `COHERE_API_KEY` environment variable to use this re-ranker. + +```python +from lancedb.rerankers import CohereReranker + +reranker = CohereReranker() + +results = table.search("vampire weekend", query_type="hybrid").rerank(reranker=reranker).to_pandas() +``` + +Arguments +---------------- +* `model_name`` : str, default `"rerank-english-v2.0"`` + The name of the cross encoder model to use. Available cohere models are: + - rerank-english-v2.0 + - rerank-multilingual-v2.0 +* `column` : str, default `"text"` + The name of the column to use as input to the cross encoder model. +* `top_n` : str, default `None` + The number of results to return. If None, will return all results. + +!!! Note + Only returns `_relevance_score`. Does not support `return_score = "all"`. + +### Cross Encoder Reranker +This reranker uses the [Sentence Transformers](https://www.sbert.net/) library to combine the results of semantic and full-text search. You can use it by passing `CrossEncoderReranker()` to the `rerank()` method. + +```python +from lancedb.rerankers import CrossEncoderReranker + +reranker = CrossEncoderReranker() + +results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas() +``` + + +Arguments +---------------- +* `model` : str, default `"cross-encoder/ms-marco-TinyBERT-L-6"` + The name of the cross encoder model to use. Available cross encoder models can be found [here](https://www.sbert.net/docs/pretrained_cross-encoders.html) +* `column` : str, default `"text"` + The name of the column to use as input to the cross encoder model. +* `device` : str, default `None` + The device to use for the cross encoder model. If None, will use "cuda" if available, otherwise "cpu". + +!!! Note + Only returns `_relevance_score`. Does not support `return_score = "all"`. + + +## Building Custom Rerankers +You can build your own custom reranker by subclassing the `Reranker` class and implementing the `rerank_hybrid()` method. Here's an example of a custom reranker that combines the results of semantic and full-text search using a linear combination of the scores. + +The `Reranker` base interface comes with a `merge_results()` method that can be used to combine the results of semantic and full-text search. This is a vanilla merging algorithm that simply concatenates the results and removes the duplicates without taking the scores into consideration. It only keeps the first copy of the row encountered. This works well in cases that don't require the scores of semantic and full-text search to combine the results. If you want to use the scores or want to support `return_score="all"`, you'll need to implement your own merging algorithm. + +```python + +from lancedb.rerankers import Reranker +import pyarrow as pa + +class MyReranker(Reranker): + def __init__(self, param1, param2, ..., return_score="relevance"): + super().__init__(return_score) + self.param1 = param1 + self.param2 = param2 + + def rerank_hybrid(self, vector_results: pa.Table, fts_results: pa.Table): + # Use the built-in merging function + combined_result = self.merge_results(vector_results, fts_results) + + # Do something with the combined results + # ... + + # Return the combined results + return combined_result + +``` + +You can also accept additional arguments like a filter along with fts and vector search results + +```python + +from lancedb.rerankers import Reranker +import pyarrow as pa + +class MyReranker(Reranker): + ... + + def rerank_hybrid(self, vector_results: pa.Table, fts_results: pa.Table, filter: str): + # Use the built-in merging function + combined_result = self.merge_results(vector_results, fts_results) + + # Do something with the combined results & filter + # ... + + # Return the combined results + return combined_result + +``` diff --git a/docs/test/md_testing.py b/docs/test/md_testing.py index 2a4012c5..0ea8431f 100644 --- a/docs/test/md_testing.py +++ b/docs/test/md_testing.py @@ -14,6 +14,7 @@ excluded_globs = [ "../src/concepts/*.md", "../src/ann_indexes.md", "../src/basic.md", + "../src/hybrid_search.md", ] python_prefix = "py" diff --git a/python/lancedb/common.py b/python/lancedb/common.py index 54c7c9e0..6be2f228 100644 --- a/python/lancedb/common.py +++ b/python/lancedb/common.py @@ -16,9 +16,9 @@ from typing import Iterable, List, Union import numpy as np import pyarrow as pa -from .util import safe_import_pandas +from .util import safe_import -pd = safe_import_pandas() +pd = safe_import("pandas") DATA = Union[List[dict], dict, "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]] VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray] diff --git a/python/lancedb/context.py b/python/lancedb/context.py index bd7b04c8..20484c47 100644 --- a/python/lancedb/context.py +++ b/python/lancedb/context.py @@ -16,9 +16,9 @@ import deprecation from . import __version__ from .exceptions import MissingColumnError, MissingValueError -from .util import safe_import_pandas +from .util import safe_import -pd = safe_import_pandas() +pd = safe_import("pandas") def contextualize(raw_df: "pd.DataFrame") -> Contextualizer: diff --git a/python/lancedb/embeddings/utils.py b/python/lancedb/embeddings/utils.py index 325145f4..4708dfd7 100644 --- a/python/lancedb/embeddings/utils.py +++ b/python/lancedb/embeddings/utils.py @@ -26,10 +26,10 @@ import pyarrow as pa from lance.vector import vec_to_table from retry import retry -from ..util import safe_import_pandas +from ..util import safe_import from ..utils.general import LOGGER -pd = safe_import_pandas() +pd = safe_import("pandas") DATA = Union[pa.Table, "pd.DataFrame"] TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray] diff --git a/python/lancedb/query.py b/python/lancedb/query.py index 23e76c0d..d564e736 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -14,8 +14,9 @@ from __future__ import annotations from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import TYPE_CHECKING, List, Literal, Optional, Type, Union +from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Type, Union import deprecation import numpy as np @@ -23,8 +24,10 @@ import pyarrow as pa import pydantic from . import __version__ -from .common import VECTOR_COLUMN_NAME -from .util import safe_import_pandas +from .common import VEC, VECTOR_COLUMN_NAME +from .rerankers.base import Reranker +from .rerankers.linear_combination import LinearCombinationReranker +from .util import safe_import if TYPE_CHECKING: import PIL @@ -33,7 +36,7 @@ if TYPE_CHECKING: from .pydantic import LanceModel from .table import Table -pd = safe_import_pandas() +pd = safe_import("pandas") class Query(pydantic.BaseModel): @@ -99,6 +102,8 @@ class Query(pydantic.BaseModel): # Refine factor. refine_factor: Optional[int] = None + with_row_id: bool = False + class LanceQueryBuilder(ABC): """Build LanceDB query based on specific query type: @@ -109,19 +114,26 @@ class LanceQueryBuilder(ABC): def create( cls, table: "Table", - query: Optional[Union[np.ndarray, str, "PIL.Image.Image"]], + query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]], query_type: str, vector_column_name: str, ) -> LanceQueryBuilder: if query is None: return LanceEmptyQueryBuilder(table) - # convert "auto" query_type to "vector" or "fts" - # and convert the query to vector if needed + if query_type == "hybrid": + # hybrid fts and vector query + return LanceHybridQueryBuilder(table, query, vector_column_name) + + # convert "auto" query_type to "vector", "fts" + # or "hybrid" and convert the query to vector if needed query, query_type = cls._resolve_query( table, query, query_type, vector_column_name ) + if query_type == "hybrid": + return LanceHybridQueryBuilder(table, query, vector_column_name) + if isinstance(query, str): # fts return LanceFtsQueryBuilder(table, query) @@ -144,17 +156,13 @@ class LanceQueryBuilder(ABC): raise TypeError(f"'fts' queries must be a string: {type(query)}") return query, query_type elif query_type == "vector": - if not isinstance(query, (list, np.ndarray)): - conf = table.embedding_functions.get(vector_column_name) - if conf is not None: - query = conf.function.compute_query_embeddings_with_retry(query)[0] - else: - msg = f"No embedding function for {vector_column_name}" - raise ValueError(msg) + query = cls._query_to_vector(table, query, vector_column_name) return query, query_type elif query_type == "auto": if isinstance(query, (list, np.ndarray)): return query, "vector" + if isinstance(query, tuple): + return query, "hybrid" else: conf = table.embedding_functions.get(vector_column_name) if conf is not None: @@ -167,11 +175,23 @@ class LanceQueryBuilder(ABC): f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}" ) + @classmethod + def _query_to_vector(cls, table, query, vector_column_name): + if isinstance(query, (list, np.ndarray)): + return query + conf = table.embedding_functions.get(vector_column_name) + if conf is not None: + return conf.function.compute_query_embeddings_with_retry(query)[0] + else: + msg = f"No embedding function for {vector_column_name}" + raise ValueError(msg) + def __init__(self, table: "Table"): self._table = table self._limit = 10 self._columns = None self._where = None + self._with_row_id = False @deprecation.deprecated( deprecated_in="0.3.1", @@ -341,6 +361,22 @@ class LanceQueryBuilder(ABC): self._prefilter = prefilter return self + def with_row_id(self, with_row_id: bool) -> LanceQueryBuilder: + """Set whether to return row ids. + + Parameters + ---------- + with_row_id: bool + If True, return _rowid column in the results. + + Returns + ------- + LanceQueryBuilder + The LanceQueryBuilder object. + """ + self._with_row_id = with_row_id + return self + class LanceVectorQueryBuilder(LanceQueryBuilder): """ @@ -459,6 +495,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): nprobes=self._nprobes, refine_factor=self._refine_factor, vector_column=self._vector_column, + with_row_id=self._with_row_id, ) return self._table._execute_query(query) @@ -568,6 +605,10 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): ds = lance.write_dataset(output_tbl, tmp) output_tbl = ds.to_table(filter=self._where) + if self._with_row_id: + # Need to set this to uint explicitly as vector results are in uint64 + row_ids = pa.array(row_ids, type=pa.uint64()) + output_tbl = output_tbl.append_column("_rowid", row_ids) return output_tbl @@ -579,3 +620,258 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder): filter=self._where, limit=self._limit, ) + + +class LanceHybridQueryBuilder(LanceQueryBuilder): + def __init__(self, table: "Table", query: str, vector_column: str): + super().__init__(table) + self._validate_fts_index() + self._query = query + vector_query, fts_query = self._validate_query(query) + self._fts_query = LanceFtsQueryBuilder(table, fts_query) + vector_query = self._query_to_vector(table, vector_query, vector_column) + self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column) + self._norm = "score" + self._reranker = LinearCombinationReranker(weight=0.7, fill=1.0) + + def _validate_fts_index(self): + if self._table._get_fts_index_path() is None: + raise ValueError( + "Please create a full-text search index " "to perform hybrid search." + ) + + def _validate_query(self, query): + # Temp hack to support vectorized queries for hybrid search + if isinstance(query, str): + return query, query + elif isinstance(query, tuple): + if len(query) != 2: + raise ValueError( + "The query must be a tuple of (vector_query, fts_query)." + ) + if not isinstance(query[0], (list, np.ndarray, pa.Array, pa.ChunkedArray)): + raise ValueError(f"The vector query must be one of {VEC}.") + if not isinstance(query[1], str): + raise ValueError("The fts query must be a string.") + return query[0], query[1] + else: + raise ValueError( + "The query must be either a string or a tuple of (vector, string)." + ) + + def to_arrow(self) -> pa.Table: + with ThreadPoolExecutor() as executor: + fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow) + vector_future = executor.submit( + self._vector_query.with_row_id(True).to_arrow + ) + fts_results = fts_future.result() + vector_results = vector_future.result() + + # convert to ranks first if needed + if self._norm == "rank": + vector_results = self._rank(vector_results, "_distance") + fts_results = self._rank(fts_results, "score") + # normalize the scores to be between 0 and 1, 0 being most relevant + vector_results = self._normalize_scores(vector_results, "_distance") + + # In fts higher scores represent relevance. Not inverting them here as + # rerankers might need to preserve this score to support `return_score="all"` + fts_results = self._normalize_scores(fts_results, "score") + + results = self._reranker.rerank_hybrid(self, vector_results, fts_results) + if not isinstance(results, pa.Table): # Enforce type + raise TypeError( + f"rerank_hybrid must return a pyarrow.Table, got {type(results)}" + ) + + if not self._with_row_id: + results = results.drop(["_rowid"]) + return results + + def _rank(self, results: pa.Table, column: str, ascending: bool = True): + if len(results) == 0: + return results + # Get the _score column from results + scores = results.column(column).to_numpy() + sort_indices = np.argsort(scores) + if not ascending: + sort_indices = sort_indices[::-1] + ranks = np.empty_like(sort_indices) + ranks[sort_indices] = np.arange(len(scores)) + 1 + # replace the _score column with the ranks + _score_idx = results.column_names.index(column) + results = results.set_column( + _score_idx, column, pa.array(ranks, type=pa.float32()) + ) + return results + + def _normalize_scores(self, results: pa.Table, column: str, invert=False): + if len(results) == 0: + return results + # Get the _score column from results + scores = results.column(column).to_numpy() + # normalize the scores by subtracting the min and dividing by the max + max, min = np.max(scores), np.min(scores) + if np.isclose(max, min): + rng = max + else: + rng = max - min + scores = (scores - min) / rng + if invert: + scores = 1 - scores + # replace the _score column with the ranks + _score_idx = results.column_names.index(column) + results = results.set_column( + _score_idx, column, pa.array(scores, type=pa.float32()) + ) + return results + + def rerank( + self, + normalize="score", + reranker: Reranker = LinearCombinationReranker(weight=0.7, fill=1.0), + ) -> LanceHybridQueryBuilder: + """ + Rerank the hybrid search results using the specified reranker. The reranker + must be an instance of Reranker class. + + Parameters + ---------- + normalize: str, default "score" + The method to normalize the scores. Can be "rank" or "score". If "rank", + the scores are converted to ranks and then normalized. If "score", the + scores are normalized directly. + reranker: Reranker, default LinearCombinationReranker(weight=0.7, fill=1.0) + The reranker to use. Must be an instance of Reranker class. + Returns + ------- + LanceHybridQueryBuilder + The LanceHybridQueryBuilder object. + """ + if normalize not in ["rank", "score"]: + raise ValueError("normalize must be 'rank' or 'score'.") + if reranker and not isinstance(reranker, Reranker): + raise ValueError("reranker must be an instance of Reranker class.") + + self._norm = normalize + self._reranker = reranker + + return self + + def limit(self, limit: int) -> LanceHybridQueryBuilder: + """ + Set the maximum number of results to return for both vector and fts search + components. + + Parameters + ---------- + limit: int + The maximum number of results to return. + + Returns + ------- + LanceHybridQueryBuilder + The LanceHybridQueryBuilder object. + """ + self._vector_query.limit(limit) + self._fts_query.limit(limit) + return self + + def select(self, columns: list) -> LanceHybridQueryBuilder: + """ + Set the columns to return for both vector and fts search. + + Parameters + ---------- + columns: list + The columns to return. + + Returns + ------- + LanceHybridQueryBuilder + The LanceHybridQueryBuilder object. + """ + self._vector_query.select(columns) + self._fts_query.select(columns) + return self + + def where(self, where: str, prefilter: bool = False) -> LanceHybridQueryBuilder: + """ + Set the where clause for both vector and fts search. + + Parameters + ---------- + where: str + The where clause which is a valid SQL where clause. See + `Lance filter pushdown `_ + for valid SQL expressions. + + prefilter: bool, default False + If True, apply the filter before vector search, otherwise the + filter is applied on the result of vector search. + + Returns + ------- + LanceHybridQueryBuilder + The LanceHybridQueryBuilder object. + """ + + self._vector_query.where(where, prefilter=prefilter) + self._fts_query.where(where) + return self + + def metric(self, metric: Literal["L2", "cosine"]) -> LanceHybridQueryBuilder: + """ + Set the distance metric to use for vector search. + + Parameters + ---------- + metric: "L2" or "cosine" + The distance metric to use. By default "L2" is used. + + Returns + ------- + LanceHybridQueryBuilder + The LanceHybridQueryBuilder object. + """ + self._vector_query.metric(metric) + return self + + def nprobes(self, nprobes: int) -> LanceHybridQueryBuilder: + """ + Set the number of probes to use for vector search. + + Higher values will yield better recall (more likely to find vectors if + they exist) at the expense of latency. + + Parameters + ---------- + nprobes: int + The number of probes to use. + + Returns + ------- + LanceHybridQueryBuilder + The LanceHybridQueryBuilder object. + """ + self._vector_query.nprobes(nprobes) + return self + + def refine_factor(self, refine_factor: int) -> LanceHybridQueryBuilder: + """ + Refine the vector search results by reading extra elements and + re-ranking them in memory. + + Parameters + ---------- + refine_factor: int + The refine factor to use. + + Returns + ------- + LanceHybridQueryBuilder + The LanceHybridQueryBuilder object. + """ + self._vector_query.refine_factor(refine_factor) + return self diff --git a/python/lancedb/rerankers/__init__.py b/python/lancedb/rerankers/__init__.py new file mode 100644 index 00000000..6d43636a --- /dev/null +++ b/python/lancedb/rerankers/__init__.py @@ -0,0 +1,11 @@ +from .base import Reranker +from .cohere import CohereReranker +from .cross_encoder import CrossEncoderReranker +from .linear_combination import LinearCombinationReranker + +__all__ = [ + "Reranker", + "CrossEncoderReranker", + "CohereReranker", + "LinearCombinationReranker", +] diff --git a/python/lancedb/rerankers/base.py b/python/lancedb/rerankers/base.py new file mode 100644 index 00000000..b1036ec5 --- /dev/null +++ b/python/lancedb/rerankers/base.py @@ -0,0 +1,109 @@ +import typing +from abc import ABC, abstractmethod + +import numpy as np +import pyarrow as pa + +if typing.TYPE_CHECKING: + import lancedb + + +class Reranker(ABC): + def __init__(self, return_score: str = "relevance"): + """ + Interface for a reranker. A reranker is used to rerank the results from a + vector and FTS search. This is useful for combining the results from both + search methods. + + Parameters + ---------- + return_score : str, default "relevance" + opntions are "relevance" or "all" + The type of score to return. If "relevance", will return only the relevance + score. If "all", will return all scores from the vector and FTS search along + with the relevance score. + + """ + if return_score not in ["relevance", "all"]: + raise ValueError("score must be either 'relevance' or 'all'") + self.score = return_score + + @abstractmethod + def rerank_hybrid( + query_builder: "lancedb.HybridQueryBuilder", + vector_results: pa.Table, + fts_results: pa.Table, + ): + """ + Rerank function receives the individual results from the vector and FTS search + results. You can choose to use any of the results to generate the final results, + allowing maximum flexibility. This is mandatory to implement + + Parameters + ---------- + query_builder : "lancedb.HybridQueryBuilder" + The query builder object that was used to generate the results + vector_results : pa.Table + The results from the vector search + fts_results : pa.Table + The results from the FTS search + """ + pass + + def rerank_vector( + query_builder: "lancedb.VectorQueryBuilder", vector_results: pa.Table + ): + """ + Rerank function receives the individual results from the vector search. + This isn't mandatory to implement + + Parameters + ---------- + query_builder : "lancedb.VectorQueryBuilder" + The query builder object that was used to generate the results + vector_results : pa.Table + The results from the vector search + """ + raise NotImplementedError("Vector Reranking is not implemented") + + def rerank_fts(query_builder: "lancedb.FTSQueryBuilder", fts_results: pa.Table): + """ + Rerank function receives the individual results from the FTS search. + This isn't mandatory to implement + + Parameters + ---------- + query_builder : "lancedb.FTSQueryBuilder" + The query builder object that was used to generate the results + fts_results : pa.Table + The results from the FTS search + """ + raise NotImplementedError("FTS Reranking is not implemented") + + def merge_results(self, vector_results: pa.Table, fts_results: pa.Table): + """ + Merge the results from the vector and FTS search. This is a vanilla merging + function that just concatenates the results and removes the duplicates. + + NOTE: This doesn't take score into account. It'll keep the instance that was + encountered first. This is designed for rerankers that don't use the score. + In case you want to use the score, or support `return_scores="all"` you'll + have to implement your own merging function. + + Parameters + ---------- + vector_results : pa.Table + The results from the vector search + fts_results : pa.Table + The results from the FTS search + """ + combined = pa.concat_tables([vector_results, fts_results], promote=True) + row_id = combined.column("_rowid") + + # deduplicate + mask = np.full((combined.shape[0]), False) + _, mask_indices = np.unique(np.array(row_id), return_index=True) + mask[mask_indices] = True + combined = combined.filter(mask=mask) + + return combined diff --git a/python/lancedb/rerankers/cohere.py b/python/lancedb/rerankers/cohere.py new file mode 100644 index 00000000..22363bc2 --- /dev/null +++ b/python/lancedb/rerankers/cohere.py @@ -0,0 +1,85 @@ +import os +import typing +from functools import cached_property +from typing import Union + +import pyarrow as pa + +from ..util import safe_import +from .base import Reranker + +if typing.TYPE_CHECKING: + import lancedb + + +class CohereReranker(Reranker): + """ + Reranks the results using the Cohere Rerank API. + https://docs.cohere.com/docs/rerank-guide + + Parameters + ---------- + model_name : str, default "rerank-english-v2.0" + The name of the cross encoder model to use. Available cohere models are: + - rerank-english-v2.0 + - rerank-multilingual-v2.0 + column : str, default "text" + The name of the column to use as input to the cross encoder model. + top_n : str, default None + The number of results to return. If None, will return all results. + """ + + def __init__( + self, + model_name: str = "rerank-english-v2.0", + column: str = "text", + top_n: Union[int, None] = None, + return_score="relevance", + api_key: Union[str, None] = None, + ): + super().__init__(return_score) + self.model_name = model_name + self.column = column + self.top_n = top_n + self.api_key = api_key + + @cached_property + def _client(self): + cohere = safe_import("cohere") + if os.environ.get("COHERE_API_KEY") is None and self.api_key is None: + raise ValueError( + "COHERE_API_KEY not set. Either set it in your environment or \ + pass it as `api_key` argument to the CohereReranker." + ) + return cohere.Client(os.environ.get("COHERE_API_KEY") or self.api_key) + + def rerank_hybrid( + self, + query_builder: "lancedb.HybridQueryBuilder", + vector_results: pa.Table, + fts_results: pa.Table, + ): + combined_results = self.merge_results(vector_results, fts_results) + docs = combined_results[self.column].to_pylist() + results = self._client.rerank( + query=query_builder._query, + documents=docs, + top_n=self.top_n, + model=self.model_name, + ) # returns list (text, idx, relevance) attributes sorted descending by score + indices, scores = list( + zip(*[(result.index, result.relevance_score) for result in results]) + ) # tuples + combined_results = combined_results.take(list(indices)) + # add the scores + combined_results = combined_results.append_column( + "_relevance_score", pa.array(scores, type=pa.float32()) + ) + + if self.score == "relevance": + combined_results = combined_results.drop_columns(["score", "_distance"]) + elif self.score == "all": + raise NotImplementedError( + "return_score='all' not implemented for cohere reranker" + ) + return combined_results diff --git a/python/lancedb/rerankers/cross_encoder.py b/python/lancedb/rerankers/cross_encoder.py new file mode 100644 index 00000000..4d7e1c42 --- /dev/null +++ b/python/lancedb/rerankers/cross_encoder.py @@ -0,0 +1,78 @@ +import typing +from functools import cached_property +from typing import Union + +import pyarrow as pa + +from ..util import safe_import +from .base import Reranker + +if typing.TYPE_CHECKING: + import lancedb + + +class CrossEncoderReranker(Reranker): + """ + Reranks the results using a cross encoder model. The cross encoder model is + used to score the query and each result. The results are then sorted by the score. + + Parameters + ---------- + model : str, default "cross-encoder/ms-marco-TinyBERT-L-6" + The name of the cross encoder model to use. See the sentence transformers + documentation for a list of available models. + column : str, default "text" + The name of the column to use as input to the cross encoder model. + device : str, default None + The device to use for the cross encoder model. If None, will use "cuda" + if available, otherwise "cpu". + """ + + def __init__( + self, + model_name: str = "cross-encoder/ms-marco-TinyBERT-L-6", + column: str = "text", + device: Union[str, None] = None, + return_score="relevance", + ): + super().__init__(return_score) + torch = safe_import("torch") + self.model_name = model_name + self.column = column + self.device = device + if self.device is None: + self.device = "cuda" if torch.cuda.is_available() else "cpu" + + @cached_property + def model(self): + sbert = safe_import("sentence_transformers") + cross_encoder = sbert.CrossEncoder(self.model_name) + + return cross_encoder + + def rerank_hybrid( + self, + query_builder: "lancedb.HybridQueryBuilder", + vector_results: pa.Table, + fts_results: pa.Table, + ): + combined_results = self.merge_results(vector_results, fts_results) + passages = combined_results[self.column].to_pylist() + cross_inp = [[query_builder._query, passage] for passage in passages] + cross_scores = self.model.predict(cross_inp) + combined_results = combined_results.append_column( + "_relevance_score", pa.array(cross_scores, type=pa.float32()) + ) + + # sort the results by _score + if self.score == "relevance": + combined_results = combined_results.drop_columns(["score", "_distance"]) + elif self.score == "all": + raise NotImplementedError( + "return_score='all' not implemented for CrossEncoderReranker" + ) + combined_results = combined_results.sort_by( + [("_relevance_score", "descending")] + ) + + return combined_results diff --git a/python/lancedb/rerankers/linear_combination.py b/python/lancedb/rerankers/linear_combination.py new file mode 100644 index 00000000..d5032999 --- /dev/null +++ b/python/lancedb/rerankers/linear_combination.py @@ -0,0 +1,117 @@ +from typing import List + +import pyarrow as pa + +from .base import Reranker + + +class LinearCombinationReranker(Reranker): + """ + Reranks the results using a linear combination of the scores from the + vector and FTS search. For missing scores, fill with `fill` value. + Parameters + ---------- + weight : float, default 0.7 + The weight to give to the vector score. Must be between 0 and 1. + fill : float, default 1.0 + The score to give to results that are only in one of the two result sets. + This is treated as penalty, so a higher value means a lower score. + TODO: We should just hardcode this-- + its pretty confusing as we invert scores to calculate final score + return_score : str, default "relevance" + opntions are "relevance" or "all" + The type of score to return. If "relevance", will return only the relevance + score. If "all", will return all scores from the vector and FTS search along + with the relevance score. + """ + + def __init__( + self, weight: float = 0.7, fill: float = 1.0, return_score="relevance" + ): + if weight < 0 or weight > 1: + raise ValueError("weight must be between 0 and 1.") + super().__init__(return_score) + self.weight = weight + self.fill = fill + + def rerank_hybrid( + self, + query_builder: "lancedb.HybridQueryBuilder", # noqa: F821 + vector_results: pa.Table, + fts_results: pa.Table, + ): + combined_results = self.merge_results(vector_results, fts_results, self.fill) + + return combined_results + + def merge_results( + self, vector_results: pa.Table, fts_results: pa.Table, fill: float + ): + # If both are empty then just return an empty table + if len(vector_results) == 0 and len(fts_results) == 0: + return vector_results + # If one is empty then return the other + if len(vector_results) == 0: + return fts_results + if len(fts_results) == 0: + return vector_results + + # sort both input tables on _rowid + combined_list = [] + vector_list = vector_results.sort_by("_rowid").to_pylist() + fts_list = fts_results.sort_by("_rowid").to_pylist() + i, j = 0, 0 + while i < len(vector_list): + if j >= len(fts_list): + for vi in vector_list[i:]: + vi["_relevance_score"] = self._combine_score(vi["_distance"], fill) + combined_list.append(vi) + break + + vi = vector_list[i] + fj = fts_list[j] + # invert the fts score from relevance to distance + inverted_fts_score = self._invert_score(fj["score"]) + if vi["_rowid"] == fj["_rowid"]: + vi["_relevance_score"] = self._combine_score( + vi["_distance"], inverted_fts_score + ) + vi["score"] = fj["score"] # keep the original score + combined_list.append(vi) + i += 1 + j += 1 + elif vector_list[i]["_rowid"] < fts_list[j]["_rowid"]: + vi["_relevance_score"] = self._combine_score(vi["_distance"], fill) + combined_list.append(vi) + i += 1 + else: + fj["_relevance_score"] = self._combine_score(inverted_fts_score, fill) + combined_list.append(fj) + j += 1 + if j < len(fts_list) - 1: + for fj in fts_list[j:]: + fj["_relevance_score"] = self._combine_score(inverted_fts_score, fill) + combined_list.append(fj) + + relevance_score_schema = pa.schema( + [ + pa.field("_relevance_score", pa.float32()), + ] + ) + combined_schema = pa.unify_schemas( + [vector_results.schema, fts_results.schema, relevance_score_schema] + ) + tbl = pa.Table.from_pylist(combined_list, schema=combined_schema).sort_by( + [("_relevance_score", "descending")] + ) + if self.score == "relevance": + tbl = tbl.drop_columns(["score", "_distance"]) + return tbl + + def _combine_score(self, score1, score2): + # these scores represent distance + return 1 - (self.weight * score1 + (1 - self.weight) * score2) + + def _invert_score(self, scores: List[float]): + # Invert the scores between relevance and distance + return 1 - scores diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 898d3bb5..d33c3b73 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -16,7 +16,7 @@ from __future__ import annotations import inspect from abc import ABC, abstractmethod from functools import cached_property -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union import lance import numpy as np @@ -33,8 +33,7 @@ from .query import LanceQueryBuilder, Query from .util import ( fs_from_uri, join_uri, - safe_import_pandas, - safe_import_polars, + safe_import, value_to_sql, ) from .utils.events import register_event @@ -48,8 +47,8 @@ if TYPE_CHECKING: from .db import LanceDBConnection -pd = safe_import_pandas() -pl = safe_import_polars() +pd = safe_import("pandas") +pl = safe_import("polars") def _sanitize_data( @@ -338,7 +337,7 @@ class Table(ABC): @abstractmethod def search( self, - query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None, + query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, vector_column_name: str = VECTOR_COLUMN_NAME, query_type: str = "auto", ) -> LanceQueryBuilder: @@ -924,7 +923,7 @@ class LanceTable(Table): def search( self, - query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None, + query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, vector_column_name: str = VECTOR_COLUMN_NAME, query_type: str = "auto", ) -> LanceQueryBuilder: @@ -1194,6 +1193,7 @@ class LanceTable(Table): "nprobes": query.nprobes, "refine_factor": query.refine_factor, }, + with_row_id=query.with_row_id, ) def cleanup_old_versions( diff --git a/python/lancedb/util.py b/python/lancedb/util.py index 5b4f3c65..02f37095 100644 --- a/python/lancedb/util.py +++ b/python/lancedb/util.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib import os import pathlib from datetime import date, datetime @@ -114,22 +115,23 @@ def join_uri(base: Union[str, pathlib.Path], *parts: str) -> str: return "/".join([p.rstrip("/") for p in [base, *parts]]) -def safe_import_pandas(): +def safe_import(module: str, mitigation=None): + """ + Import the specified module. If the module is not installed, + raise an ImportError with a helpful message. + + Parameters + ---------- + module : str + The name of the module to import + mitigation : Optional[str] + The package(s) to install to mitigate the error. + If not provided then the module name will be used. + """ try: - import pandas as pd - - return pd + return importlib.import_module(module) except ImportError: - return None - - -def safe_import_polars(): - try: - import polars as pl - - return pl - except ImportError: - return None + raise ImportError(f"Please install {mitigation or module}") @singledispatch diff --git a/python/tests/test_rerankers.py b/python/tests/test_rerankers.py new file mode 100644 index 00000000..19be1f8e --- /dev/null +++ b/python/tests/test_rerankers.py @@ -0,0 +1,168 @@ +import os + +import numpy as np +import pytest + +import lancedb +from lancedb.conftest import MockTextEmbeddingFunction # noqa +from lancedb.embeddings import EmbeddingFunctionRegistry +from lancedb.pydantic import LanceModel, Vector +from lancedb.rerankers import CohereReranker, CrossEncoderReranker +from lancedb.table import LanceTable + + +def get_test_table(tmp_path): + db = lancedb.connect(tmp_path) + # Create a LanceDB table schema with a vector and a text column + emb = EmbeddingFunctionRegistry.get_instance().get("test")() + + class MyTable(LanceModel): + text: str = emb.SourceField() + vector: Vector(emb.ndims()) = emb.VectorField() + + # Initialize the table using the schema + table = LanceTable.create( + db, + "my_table", + schema=MyTable, + ) + + # Need to test with a bunch of phrases to make sure sorting is consistent + phrases = [ + "great kid don't get cocky", + "now that's a name I haven't heard in a long time", + "if you strike me down I shall become more powerful than you imagine", + "I find your lack of faith disturbing", + "I've got a bad feeling about this", + "never tell me the odds", + "I am your father", + "somebody has to save our skins", + "New strategy R2 let the wookiee win", + "Arrrrggghhhhhhh", + "I see a mansard roof through the trees", + "I see a salty message written in the eves", + "the ground beneath my feet", + "the hot garbage and concrete", + "and now the tops of buildings", + "everybody with a worried mind could never forgive the sight", + "of wicked snakes inside a place you thought was dignified", + "I don't wanna live like this", + "but I don't wanna die", + "The templars want control", + "the brotherhood of assassins want freedom", + "if only they could both see the world as it really is", + "there would be peace", + "but the war goes on", + "altair's legacy was a warning", + "Kratos had a son", + "he was a god", + "the god of war", + "but his son was mortal", + "there hasn't been a good battlefield game since 2142", + "I wish they would make another one", + "campains are not as good as they used to be", + "Multiplayer and open world games have destroyed the single player experience", + "Maybe the future is console games", + "I don't know", + ] + + # Add the phrases and vectors to the table + table.add([{"text": p} for p in phrases]) + + # Create a fts index + table.create_fts_index("text") + + return table, MyTable + + +## These tests are pretty loose, we should also check for correctness +def test_linear_combination(tmp_path): + table, schema = get_test_table(tmp_path) + # The default reranker + result1 = ( + table.search("Our father who art in heaven", query_type="hybrid") + .rerank(normalize="score") + .to_pydantic(schema) + ) + result2 = ( # noqa + table.search("Our father who art in heaven.", query_type="hybrid") + .rerank(normalize="rank") + .to_pydantic(schema) + ) + result3 = table.search( + "Our father who art in heaven..", query_type="hybrid" + ).to_pydantic(schema) + + assert result1 == result3 # 2 & 3 should be the same as they use score as score + + result = ( + table.search("Our father who art in heaven", query_type="hybrid") + .limit(50) + .rerank(normalize="score") + .to_arrow() + ) + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), ( + "The _score column of the results returned by the reranker " + "represents the relevance of the result to the query & should " + "be descending." + ) + + +@pytest.mark.skipif( + os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set" +) +def test_cohere_reranker(tmp_path): + pytest.importorskip("cohere") + table, schema = get_test_table(tmp_path) + # The default reranker + result1 = ( + table.search("Our father who art in heaven", query_type="hybrid") + .rerank(normalize="score", reranker=CohereReranker()) + .to_pydantic(schema) + ) + result2 = ( + table.search("Our father who art in heaven", query_type="hybrid") + .rerank(normalize="rank", reranker=CohereReranker()) + .to_pydantic(schema) + ) + assert result1 == result2 + + result = ( + table.search("Our father who art in heaven", query_type="hybrid") + .limit(50) + .rerank(reranker=CohereReranker()) + .to_arrow() + ) + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), ( + "The _score column of the results returned by the reranker " + "represents the relevance of the result to the query & should " + "be descending." + ) + + +def test_cross_encoder_reranker(tmp_path): + pytest.importorskip("sentence_transformers") + table, schema = get_test_table(tmp_path) + result1 = ( + table.search("Our father who art in heaven", query_type="hybrid") + .rerank(normalize="score", reranker=CrossEncoderReranker()) + .to_pydantic(schema) + ) + result2 = ( + table.search("Our father who art in heaven", query_type="hybrid") + .rerank(normalize="rank", reranker=CrossEncoderReranker()) + .to_pydantic(schema) + ) + assert result1 == result2 + + result = ( + table.search("Our father who art in heaven", query_type="hybrid") + .limit(50) + .rerank(reranker=CrossEncoderReranker()) + .to_arrow() + ) + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), ( + "The _score column of the results returned by the reranker " + "represents the relevance of the result to the query & should " + "be descending." + ) diff --git a/python/tests/test_table.py b/python/tests/test_table.py index 0dadbbbd..f5118caa 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -682,3 +682,57 @@ def test_count_rows(db): assert len(table) == 2 assert table.count_rows() == 2 assert table.count_rows(filter="text='bar'") == 1 + + +def test_hybrid_search(db): + # hardcoding temporarily.. this test is failing with tmp_path mockdb. + # Probably not being parsed right by the fts + db = MockDB("~/lancedb_") + # Create a LanceDB table schema with a vector and a text column + emb = EmbeddingFunctionRegistry.get_instance().get("test")() + + class MyTable(LanceModel): + text: str = emb.SourceField() + vector: Vector(emb.ndims()) = emb.VectorField() + + # Initialize the table using the schema + table = LanceTable.create( + db, + "my_table", + schema=MyTable, + ) + + # Create a list of 10 unique english phrases + phrases = [ + "great kid don't get cocky", + "now that's a name I haven't heard in a long time", + "if you strike me down I shall become more powerful than you imagine", + "I find your lack of faith disturbing", + "I've got a bad feeling about this", + "never tell me the odds", + "I am your father", + "somebody has to save our skins", + "New strategy R2 let the wookiee win", + "Arrrrggghhhhhhh", + ] + + # Add the phrases and vectors to the table + table.add([{"text": p} for p in phrases]) + + # Create a fts index + table.create_fts_index("text") + + result1 = ( + table.search("Our father who art in heaven", query_type="hybrid") + .rerank(normalize="score") + .to_pydantic(MyTable) + ) + result2 = ( # noqa + table.search("Our father who art in heaven", query_type="hybrid") + .rerank(normalize="rank") + .to_pydantic(MyTable) + ) + result3 = table.search( + "Our father who art in heaven", query_type="hybrid" + ).to_pydantic(MyTable) + assert result1 == result3