mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
feat(python): Hybrid search & Reranker API (#824)
based on https://github.com/lancedb/lancedb/pull/713 - The Reranker api can be plugged into vector only or fts only search but this PR doesn't do that (see example - https://txt.cohere.com/rerank/) ### Default reranker -- `LinearCombinationReranker(weight=0.7, fill=1.0)` ``` table.search("hello", query_type="hybrid").rerank(normalize="score").to_pandas() ``` ### Available rerankers LinearCombinationReranker ``` from lancedb.rerankers import LinearCombinationReranker # Same as default table.search("hello", query_type="hybrid").rerank( normalize="score", reranker=LinearCombinationReranker() ).to_pandas() # with custom params reranker = LinearCombinationReranker(weight=0.3, fill=1.0) table.search("hello", query_type="hybrid").rerank( normalize="score", reranker=reranker ).to_pandas() ``` Cohere Reranker ``` from lancedb.rerankers import CohereReranker # default model.. English and multi-lingual supported. See docstring for available custom params table.search("hello", query_type="hybrid").rerank( normalize="rank", # score or rank reranker=CohereReranker() ).to_pandas() ``` CrossEncoderReranker ``` from lancedb.rerankers import CrossEncoderReranker table.search("hello", query_type="hybrid").rerank( normalize="rank", reranker=CrossEncoderReranker() ).to_pandas() ``` ## Using custom Reranker ``` from lancedb.rerankers import Reranker class CustomReranker(Reranker): def rerank_hybrid(self, query_builder, vector_results, fts_results): combined_res = self.merge_results(vector_results, fts_results) # or use custom combination logic # Custom rerank logic here return combined_res ``` - [x] Expand testing - [x] Make sure usage makes sense - [x] Run simple benchmarks for correctness (Seeing weird result from cohere reranker in the toy example) - Support diverse rerankers by default: - [x] Cross encoding - [x] Cohere - [x] Reciprocal Rank Fusion --------- Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com> Co-authored-by: Prashanth Rao 
<35005448+prrao87@users.noreply.github.com>
This commit is contained in:
@@ -16,9 +16,9 @@ from typing import Iterable, List, Union
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
|
||||
from .util import safe_import_pandas
|
||||
from .util import safe_import
|
||||
|
||||
pd = safe_import_pandas()
|
||||
pd = safe_import("pandas")
|
||||
|
||||
DATA = Union[List[dict], dict, "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]]
|
||||
VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]
|
||||
|
||||
@@ -16,9 +16,9 @@ import deprecation
|
||||
|
||||
from . import __version__
|
||||
from .exceptions import MissingColumnError, MissingValueError
|
||||
from .util import safe_import_pandas
|
||||
from .util import safe_import
|
||||
|
||||
pd = safe_import_pandas()
|
||||
pd = safe_import("pandas")
|
||||
|
||||
|
||||
def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
|
||||
|
||||
@@ -26,10 +26,10 @@ import pyarrow as pa
|
||||
from lance.vector import vec_to_table
|
||||
from retry import retry
|
||||
|
||||
from ..util import safe_import_pandas
|
||||
from ..util import safe_import
|
||||
from ..utils.general import LOGGER
|
||||
|
||||
pd = safe_import_pandas()
|
||||
pd = safe_import("pandas")
|
||||
|
||||
DATA = Union[pa.Table, "pd.DataFrame"]
|
||||
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
|
||||
|
||||
@@ -14,8 +14,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional, Type, Union
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Type, Union
|
||||
|
||||
import deprecation
|
||||
import numpy as np
|
||||
@@ -23,8 +24,10 @@ import pyarrow as pa
|
||||
import pydantic
|
||||
|
||||
from . import __version__
|
||||
from .common import VECTOR_COLUMN_NAME
|
||||
from .util import safe_import_pandas
|
||||
from .common import VEC, VECTOR_COLUMN_NAME
|
||||
from .rerankers.base import Reranker
|
||||
from .rerankers.linear_combination import LinearCombinationReranker
|
||||
from .util import safe_import
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import PIL
|
||||
@@ -33,7 +36,7 @@ if TYPE_CHECKING:
|
||||
from .pydantic import LanceModel
|
||||
from .table import Table
|
||||
|
||||
pd = safe_import_pandas()
|
||||
pd = safe_import("pandas")
|
||||
|
||||
|
||||
class Query(pydantic.BaseModel):
|
||||
@@ -99,6 +102,8 @@ class Query(pydantic.BaseModel):
|
||||
# Refine factor.
|
||||
refine_factor: Optional[int] = None
|
||||
|
||||
with_row_id: bool = False
|
||||
|
||||
|
||||
class LanceQueryBuilder(ABC):
|
||||
"""Build LanceDB query based on specific query type:
|
||||
@@ -109,19 +114,26 @@ class LanceQueryBuilder(ABC):
|
||||
def create(
|
||||
cls,
|
||||
table: "Table",
|
||||
query: Optional[Union[np.ndarray, str, "PIL.Image.Image"]],
|
||||
query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
|
||||
query_type: str,
|
||||
vector_column_name: str,
|
||||
) -> LanceQueryBuilder:
|
||||
if query is None:
|
||||
return LanceEmptyQueryBuilder(table)
|
||||
|
||||
# convert "auto" query_type to "vector" or "fts"
|
||||
# and convert the query to vector if needed
|
||||
if query_type == "hybrid":
|
||||
# hybrid fts and vector query
|
||||
return LanceHybridQueryBuilder(table, query, vector_column_name)
|
||||
|
||||
# convert "auto" query_type to "vector", "fts"
|
||||
# or "hybrid" and convert the query to vector if needed
|
||||
query, query_type = cls._resolve_query(
|
||||
table, query, query_type, vector_column_name
|
||||
)
|
||||
|
||||
if query_type == "hybrid":
|
||||
return LanceHybridQueryBuilder(table, query, vector_column_name)
|
||||
|
||||
if isinstance(query, str):
|
||||
# fts
|
||||
return LanceFtsQueryBuilder(table, query)
|
||||
@@ -144,17 +156,13 @@ class LanceQueryBuilder(ABC):
|
||||
raise TypeError(f"'fts' queries must be a string: {type(query)}")
|
||||
return query, query_type
|
||||
elif query_type == "vector":
|
||||
if not isinstance(query, (list, np.ndarray)):
|
||||
conf = table.embedding_functions.get(vector_column_name)
|
||||
if conf is not None:
|
||||
query = conf.function.compute_query_embeddings_with_retry(query)[0]
|
||||
else:
|
||||
msg = f"No embedding function for {vector_column_name}"
|
||||
raise ValueError(msg)
|
||||
query = cls._query_to_vector(table, query, vector_column_name)
|
||||
return query, query_type
|
||||
elif query_type == "auto":
|
||||
if isinstance(query, (list, np.ndarray)):
|
||||
return query, "vector"
|
||||
if isinstance(query, tuple):
|
||||
return query, "hybrid"
|
||||
else:
|
||||
conf = table.embedding_functions.get(vector_column_name)
|
||||
if conf is not None:
|
||||
@@ -167,11 +175,23 @@ class LanceQueryBuilder(ABC):
|
||||
f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _query_to_vector(cls, table, query, vector_column_name):
|
||||
if isinstance(query, (list, np.ndarray)):
|
||||
return query
|
||||
conf = table.embedding_functions.get(vector_column_name)
|
||||
if conf is not None:
|
||||
return conf.function.compute_query_embeddings_with_retry(query)[0]
|
||||
else:
|
||||
msg = f"No embedding function for {vector_column_name}"
|
||||
raise ValueError(msg)
|
||||
|
||||
def __init__(self, table: "Table"):
|
||||
self._table = table
|
||||
self._limit = 10
|
||||
self._columns = None
|
||||
self._where = None
|
||||
self._with_row_id = False
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.3.1",
|
||||
@@ -341,6 +361,22 @@ class LanceQueryBuilder(ABC):
|
||||
self._prefilter = prefilter
|
||||
return self
|
||||
|
||||
def with_row_id(self, with_row_id: bool) -> LanceQueryBuilder:
|
||||
"""Set whether to return row ids.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
with_row_id: bool
|
||||
If True, return _rowid column in the results.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._with_row_id = with_row_id
|
||||
return self
|
||||
|
||||
|
||||
class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
"""
|
||||
@@ -459,6 +495,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
nprobes=self._nprobes,
|
||||
refine_factor=self._refine_factor,
|
||||
vector_column=self._vector_column,
|
||||
with_row_id=self._with_row_id,
|
||||
)
|
||||
return self._table._execute_query(query)
|
||||
|
||||
@@ -568,6 +605,10 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
ds = lance.write_dataset(output_tbl, tmp)
|
||||
output_tbl = ds.to_table(filter=self._where)
|
||||
|
||||
if self._with_row_id:
|
||||
# Need to set this to uint explicitly as vector results are in uint64
|
||||
row_ids = pa.array(row_ids, type=pa.uint64())
|
||||
output_tbl = output_tbl.append_column("_rowid", row_ids)
|
||||
return output_tbl
|
||||
|
||||
|
||||
@@ -579,3 +620,258 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
||||
filter=self._where,
|
||||
limit=self._limit,
|
||||
)
|
||||
|
||||
|
||||
class LanceHybridQueryBuilder(LanceQueryBuilder):
    """Build a hybrid query that fuses vector search and full-text search.

    The vector and FTS sub-queries are executed concurrently, their score
    columns are min-max normalized (optionally after being converted to ranks)
    and both result tables are handed to a ``Reranker`` that produces the final
    table.

    Parameters
    ----------
    table: Table
        The table to search. It must have a full-text search index.
    query: str or tuple
        Either a string, used both as the FTS query and as the input of the
        embedding function, or a ``(vector, fts_query)`` tuple.
    vector_column: str
        The name of the vector column to search.
    """

    def __init__(self, table: "Table", query: Union[str, Tuple], vector_column: str):
        super().__init__(table)
        self._validate_fts_index()
        # The raw query is kept because rerankers read it from the builder.
        self._query = query
        vector_query, fts_query = self._validate_query(query)
        self._fts_query = LanceFtsQueryBuilder(table, fts_query)
        vector_query = self._query_to_vector(table, vector_query, vector_column)
        self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column)
        self._norm = "score"
        self._reranker = LinearCombinationReranker(weight=0.7, fill=1.0)

    def _validate_fts_index(self):
        """Raise ``ValueError`` if the table has no full-text search index."""
        if self._table._get_fts_index_path() is None:
            raise ValueError(
                "Please create a full-text search index to perform hybrid search."
            )

    def _validate_query(self, query):
        """Split ``query`` into its ``(vector_query, fts_query)`` parts.

        A string is used for both parts (it is embedded later on). A tuple must
        be ``(vector, fts_query)``.

        Raises
        ------
        ValueError
            If the query is neither a string nor a well-formed tuple.
        """
        # Temp hack to support vectorized queries for hybrid search
        if isinstance(query, str):
            return query, query
        if not isinstance(query, tuple):
            raise ValueError(
                "The query must be either a string or a tuple of (vector, string)."
            )
        if len(query) != 2:
            raise ValueError("The query must be a tuple of (vector_query, fts_query).")
        vector_query, fts_query = query
        if not isinstance(vector_query, (list, np.ndarray, pa.Array, pa.ChunkedArray)):
            raise ValueError(f"The vector query must be one of {VEC}.")
        if not isinstance(fts_query, str):
            raise ValueError("The fts query must be a string.")
        if isinstance(vector_query, (pa.Array, pa.ChunkedArray)):
            # _query_to_vector only passes list/ndarray through untouched and
            # would send an Arrow array to the embedding function instead.
            vector_query = np.asarray(vector_query)
        return vector_query, fts_query

    def to_arrow(self) -> pa.Table:
        """Run both searches, rerank the results and return them as a table.

        Raises
        ------
        TypeError
            If the reranker does not return a ``pyarrow.Table``.
        """
        # Row ids are always requested from the sub-queries, whatever
        # `with_row_id` says; the column is removed again below if unwanted.
        with ThreadPoolExecutor() as executor:
            fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
            vector_future = executor.submit(
                self._vector_query.with_row_id(True).to_arrow
            )
            fts_results = fts_future.result()
            vector_results = vector_future.result()

        # convert to ranks first if needed
        if self._norm == "rank":
            vector_results = self._rank(vector_results, "_distance")
            fts_results = self._rank(fts_results, "score")
        # normalize the scores to be between 0 and 1, 0 being most relevant
        vector_results = self._normalize_scores(vector_results, "_distance")

        # In fts higher scores represent relevance. Not inverting them here as
        # rerankers might need to preserve this score to support `return_score="all"`
        fts_results = self._normalize_scores(fts_results, "score")

        results = self._reranker.rerank_hybrid(self, vector_results, fts_results)
        if not isinstance(results, pa.Table):  # Enforce type
            raise TypeError(
                f"rerank_hybrid must return a pyarrow.Table, got {type(results)}"
            )

        # A custom reranker may already have removed the column.
        if not self._with_row_id and "_rowid" in results.column_names:
            results = results.drop(["_rowid"])
        return results

    def _rank(self, results: pa.Table, column: str, ascending: bool = True):
        """Replace ``column`` by the 1-based rank of each value.

        With ``ascending=True`` the smallest value gets rank 1. Empty tables
        are returned unchanged.
        """
        if len(results) == 0:
            return results
        scores = results.column(column).to_numpy()
        sort_indices = np.argsort(scores)
        if not ascending:
            sort_indices = sort_indices[::-1]
        # Scatter 1..n back to the original row positions.
        ranks = np.empty_like(sort_indices)
        ranks[sort_indices] = np.arange(len(scores)) + 1
        column_idx = results.column_names.index(column)
        return results.set_column(
            column_idx, column, pa.array(ranks, type=pa.float32())
        )

    def _normalize_scores(self, results: pa.Table, column: str, invert=False):
        """Min-max normalize ``column`` and store it back as float32.

        When all values are (nearly) equal the maximum is used as the divisor;
        if that is zero too, every normalized score is 0 instead of NaN. With
        ``invert=True`` the result is ``1 - normalized``. Empty tables are
        returned unchanged.
        """
        if len(results) == 0:
            return results
        scores = results.column(column).to_numpy()
        highest, lowest = np.max(scores), np.min(scores)
        span = highest if np.isclose(highest, lowest) else highest - lowest
        shifted = scores - lowest
        if span == 0:
            # All scores are zero: avoid a 0/0 division that would yield NaN.
            scores = np.zeros_like(shifted)
        else:
            scores = shifted / span
        if invert:
            scores = 1 - scores
        # replace the column with the normalized scores
        column_idx = results.column_names.index(column)
        return results.set_column(
            column_idx, column, pa.array(scores, type=pa.float32())
        )

    def rerank(
        self,
        normalize="score",
        reranker: Optional[Reranker] = None,
    ) -> LanceHybridQueryBuilder:
        """
        Rerank the hybrid search results using the specified reranker. The reranker
        must be an instance of Reranker class.

        Parameters
        ----------
        normalize: str, default "score"
            The method to normalize the scores. Can be "rank" or "score". If "rank",
            the scores are converted to ranks and then normalized. If "score", the
            scores are normalized directly.
        reranker: Reranker, default LinearCombinationReranker(weight=0.7, fill=1.0)
            The reranker to use. Must be an instance of Reranker class. When
            omitted (or None) a new default reranker is created.

        Returns
        -------
        LanceHybridQueryBuilder
            The LanceHybridQueryBuilder object.

        Raises
        ------
        ValueError
            If ``normalize`` or ``reranker`` is invalid.
        """
        if normalize not in ["rank", "score"]:
            raise ValueError("normalize must be 'rank' or 'score'.")
        if reranker is None:
            # Built per call so that builders never share one default instance.
            reranker = LinearCombinationReranker(weight=0.7, fill=1.0)
        elif not isinstance(reranker, Reranker):
            raise ValueError("reranker must be an instance of Reranker class.")

        self._norm = normalize
        self._reranker = reranker

        return self

    def limit(self, limit: int) -> LanceHybridQueryBuilder:
        """
        Set the maximum number of results to return for both vector and fts search
        components.

        Parameters
        ----------
        limit: int
            The maximum number of results to return.

        Returns
        -------
        LanceHybridQueryBuilder
            The LanceHybridQueryBuilder object.
        """
        self._vector_query.limit(limit)
        self._fts_query.limit(limit)
        return self

    def select(self, columns: list) -> LanceHybridQueryBuilder:
        """
        Set the columns to return for both vector and fts search.

        Parameters
        ----------
        columns: list
            The columns to return.

        Returns
        -------
        LanceHybridQueryBuilder
            The LanceHybridQueryBuilder object.
        """
        self._vector_query.select(columns)
        self._fts_query.select(columns)
        return self

    def where(self, where: str, prefilter: bool = False) -> LanceHybridQueryBuilder:
        """
        Set the where clause for both vector and fts search.

        Parameters
        ----------
        where: str
            The where clause which is a valid SQL where clause. See
            `Lance filter pushdown <https://lancedb.github.io/lance/read_and_write.html#filter-push-down>`_
            for valid SQL expressions.

        prefilter: bool, default False
            If True, apply the filter before vector search, otherwise the
            filter is applied on the result of vector search.

        Returns
        -------
        LanceHybridQueryBuilder
            The LanceHybridQueryBuilder object.
        """
        # `prefilter` is only forwarded to the vector sub-query.
        self._vector_query.where(where, prefilter=prefilter)
        self._fts_query.where(where)
        return self

    def metric(self, metric: Literal["L2", "cosine"]) -> LanceHybridQueryBuilder:
        """
        Set the distance metric to use for vector search.

        Parameters
        ----------
        metric: "L2" or "cosine"
            The distance metric to use. By default "L2" is used.

        Returns
        -------
        LanceHybridQueryBuilder
            The LanceHybridQueryBuilder object.
        """
        self._vector_query.metric(metric)
        return self

    def nprobes(self, nprobes: int) -> LanceHybridQueryBuilder:
        """
        Set the number of probes to use for vector search.

        Higher values will yield better recall (more likely to find vectors if
        they exist) at the expense of latency.

        Parameters
        ----------
        nprobes: int
            The number of probes to use.

        Returns
        -------
        LanceHybridQueryBuilder
            The LanceHybridQueryBuilder object.
        """
        self._vector_query.nprobes(nprobes)
        return self

    def refine_factor(self, refine_factor: int) -> LanceHybridQueryBuilder:
        """
        Refine the vector search results by reading extra elements and
        re-ranking them in memory.

        Parameters
        ----------
        refine_factor: int
            The refine factor to use.

        Returns
        -------
        LanceHybridQueryBuilder
            The LanceHybridQueryBuilder object.
        """
        self._vector_query.refine_factor(refine_factor)
        return self
|
||||
|
||||
11
python/lancedb/rerankers/__init__.py
Normal file
11
python/lancedb/rerankers/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from .base import Reranker
|
||||
from .cohere import CohereReranker
|
||||
from .cross_encoder import CrossEncoderReranker
|
||||
from .linear_combination import LinearCombinationReranker
|
||||
|
||||
__all__ = [
|
||||
"Reranker",
|
||||
"CrossEncoderReranker",
|
||||
"CohereReranker",
|
||||
"LinearCombinationReranker",
|
||||
]
|
||||
109
python/lancedb/rerankers/base.py
Normal file
109
python/lancedb/rerankers/base.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import lancedb
|
||||
|
||||
|
||||
class Reranker(ABC):
    """Base class of the rerankers used to combine vector and FTS results."""

    def __init__(self, return_score: str = "relevance"):
        """
        Interface for a reranker. A reranker is used to rerank the results from a
        vector and FTS search. This is useful for combining the results from both
        search methods.

        Parameters
        ----------
        return_score : str, default "relevance"
            options are "relevance" or "all"
            The type of score to return. If "relevance", will return only the relevance
            score. If "all", will return all scores from the vector and FTS search along
            with the relevance score.

        Raises
        ------
        ValueError
            If ``return_score`` is not one of the supported options.
        """
        if return_score not in ["relevance", "all"]:
            raise ValueError("score must be either 'relevance' or 'all'")
        self.score = return_score

    @abstractmethod
    def rerank_hybrid(
        self,
        query_builder: "lancedb.HybridQueryBuilder",
        vector_results: pa.Table,
        fts_results: pa.Table,
    ):
        """
        Rerank function receives the individual results from the vector and FTS search
        results. You can choose to use any of the results to generate the final results,
        allowing maximum flexibility. This is mandatory to implement

        Parameters
        ----------
        query_builder : "lancedb.HybridQueryBuilder"
            The query builder object that was used to generate the results
        vector_results : pa.Table
            The results from the vector search
        fts_results : pa.Table
            The results from the FTS search
        """
        pass

    def rerank_vector(
        self, query_builder: "lancedb.VectorQueryBuilder", vector_results: pa.Table
    ):
        """
        Rerank function receives the individual results from the vector search.
        This isn't mandatory to implement

        Parameters
        ----------
        query_builder : "lancedb.VectorQueryBuilder"
            The query builder object that was used to generate the results
        vector_results : pa.Table
            The results from the vector search

        Raises
        ------
        NotImplementedError
            Always, unless a subclass overrides this method.
        """
        raise NotImplementedError("Vector Reranking is not implemented")

    def rerank_fts(
        self, query_builder: "lancedb.FTSQueryBuilder", fts_results: pa.Table
    ):
        """
        Rerank function receives the individual results from the FTS search.
        This isn't mandatory to implement

        Parameters
        ----------
        query_builder : "lancedb.FTSQueryBuilder"
            The query builder object that was used to generate the results
        fts_results : pa.Table
            The results from the FTS search

        Raises
        ------
        NotImplementedError
            Always, unless a subclass overrides this method.
        """
        raise NotImplementedError("FTS Reranking is not implemented")

    def merge_results(self, vector_results: pa.Table, fts_results: pa.Table):
        """
        Merge the results from the vector and FTS search. This is a vanilla merging
        function that just concatenates the results and removes the duplicates.

        NOTE: This doesn't take score into account. It'll keep the instance that was
        encountered first. This is designed for rerankers that don't use the score.
        In case you want to use the score, or support `return_score="all"` you'll
        have to implement your own merging function.

        Parameters
        ----------
        vector_results : pa.Table
            The results from the vector search
        fts_results : pa.Table
            The results from the FTS search

        Returns
        -------
        pa.Table
            The concatenated rows with one row per distinct ``_rowid``.
        """
        # promote=True lets the two tables have different columns.
        combined = pa.concat_tables([vector_results, fts_results], promote=True)
        row_id = combined.column("_rowid")

        # deduplicate: np.unique reports the first position of each row id
        mask = np.full((combined.shape[0]), False)
        _, mask_indices = np.unique(np.array(row_id), return_index=True)
        mask[mask_indices] = True
        combined = combined.filter(mask=mask)

        return combined
|
||||
85
python/lancedb/rerankers/cohere.py
Normal file
85
python/lancedb/rerankers/cohere.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import os
|
||||
import typing
|
||||
from functools import cached_property
|
||||
from typing import Union
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from ..util import safe_import
|
||||
from .base import Reranker
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import lancedb
|
||||
|
||||
|
||||
class CohereReranker(Reranker):
    """
    Reranks the results using the Cohere Rerank API.
    https://docs.cohere.com/docs/rerank-guide

    Parameters
    ----------
    model_name : str, default "rerank-english-v2.0"
        The name of the cross encoder model to use. Available cohere models are:
        - rerank-english-v2.0
        - rerank-multilingual-v2.0
    column : str, default "text"
        The name of the column to use as input to the cross encoder model.
    top_n : int, default None
        The number of results to return. If None, will return all results.
    return_score : str, default "relevance"
        Only "relevance" is supported by ``rerank_hybrid``; "all" raises
        NotImplementedError.
    api_key : str, default None
        The Cohere API key. The ``COHERE_API_KEY`` environment variable takes
        precedence over this argument when both are set.
    """

    def __init__(
        self,
        model_name: str = "rerank-english-v2.0",
        column: str = "text",
        top_n: Union[int, None] = None,
        return_score="relevance",
        api_key: Union[str, None] = None,
    ):
        super().__init__(return_score)
        self.model_name = model_name
        self.column = column
        self.top_n = top_n
        self.api_key = api_key

    @cached_property
    def _client(self):
        """The Cohere client, created on first use.

        Raises
        ------
        ValueError
            If no API key is available.
        """
        cohere = safe_import("cohere")
        api_key = os.environ.get("COHERE_API_KEY") or self.api_key
        if not api_key:
            raise ValueError(
                "COHERE_API_KEY not set. Either set it in your environment or "
                "pass it as `api_key` argument to the CohereReranker."
            )
        return cohere.Client(api_key)

    def rerank_hybrid(
        self,
        query_builder: "lancedb.HybridQueryBuilder",
        vector_results: pa.Table,
        fts_results: pa.Table,
    ):
        """Merge both result sets and order them with the Cohere Rerank API.

        Returns the merged rows sorted by descending relevance, with a
        ``_relevance_score`` column and without ``score`` / ``_distance``.

        Raises
        ------
        NotImplementedError
            If the reranker was created with ``return_score="all"``.
        """
        # Fail before the remote call rather than after paying for it.
        if self.score == "all":
            raise NotImplementedError(
                "return_score='all' not implemented for cohere reranker"
            )

        combined_results = self.merge_results(vector_results, fts_results)
        docs = combined_results[self.column].to_pylist()

        # A hybrid query may be a (vector, text) tuple; only the text part can
        # be sent to the rerank API.
        query = query_builder._query
        if isinstance(query, tuple):
            query = query[1]

        ranked = []
        if docs:  # nothing to send when both searches came back empty
            results = self._client.rerank(
                query=query,
                documents=docs,
                top_n=self.top_n,
                model=self.model_name,
            )  # returns list (text, idx, relevance) attributes sorted descending by score
            ranked = [(result.index, result.relevance_score) for result in results]
        indices = [index for index, _ in ranked]
        scores = [score for _, score in ranked]

        # Explicit types keep the empty case from being inferred as null arrays.
        combined_results = combined_results.take(pa.array(indices, type=pa.int64()))
        # add the scores
        combined_results = combined_results.append_column(
            "_relevance_score", pa.array(scores, type=pa.float32())
        )

        if self.score == "relevance":
            combined_results = combined_results.drop_columns(["score", "_distance"])
        return combined_results
|
||||
78
python/lancedb/rerankers/cross_encoder.py
Normal file
78
python/lancedb/rerankers/cross_encoder.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import typing
|
||||
from functools import cached_property
|
||||
from typing import Union
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from ..util import safe_import
|
||||
from .base import Reranker
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import lancedb
|
||||
|
||||
|
||||
class CrossEncoderReranker(Reranker):
    """
    Reranks the results using a cross encoder model. The cross encoder model is
    used to score the query and each result. The results are then sorted by the score.

    Parameters
    ----------
    model_name : str, default "cross-encoder/ms-marco-TinyBERT-L-6"
        The name of the cross encoder model to use. See the sentence transformers
        documentation for a list of available models.
    column : str, default "text"
        The name of the column to use as input to the cross encoder model.
    device : str, default None
        The device to use for the cross encoder model. If None, will use "cuda"
        if available, otherwise "cpu".
    return_score : str, default "relevance"
        Only "relevance" is supported by ``rerank_hybrid``; "all" raises
        NotImplementedError.
    """

    def __init__(
        self,
        model_name: str = "cross-encoder/ms-marco-TinyBERT-L-6",
        column: str = "text",
        device: Union[str, None] = None,
        return_score="relevance",
    ):
        super().__init__(return_score)
        torch = safe_import("torch")
        self.model_name = model_name
        self.column = column
        self.device = device
        if self.device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

    @cached_property
    def model(self):
        """The cross encoder, loaded on first use on the configured device."""
        sbert = safe_import("sentence_transformers")
        # Pass the device on; otherwise the `device` argument has no effect.
        cross_encoder = sbert.CrossEncoder(self.model_name, device=self.device)

        return cross_encoder

    def rerank_hybrid(
        self,
        query_builder: "lancedb.HybridQueryBuilder",
        vector_results: pa.Table,
        fts_results: pa.Table,
    ):
        """Merge both result sets and order them by cross encoder score.

        Returns the merged rows sorted by descending ``_relevance_score``,
        without the ``score`` / ``_distance`` columns.

        Raises
        ------
        NotImplementedError
            If the reranker was created with ``return_score="all"``.
        """
        # Fail before running the model rather than after.
        if self.score == "all":
            raise NotImplementedError(
                "return_score='all' not implemented for CrossEncoderReranker"
            )

        combined_results = self.merge_results(vector_results, fts_results)
        passages = combined_results[self.column].to_pylist()

        # A hybrid query may be a (vector, text) tuple; the model needs the text.
        query = query_builder._query
        if isinstance(query, tuple):
            query = query[1]

        cross_inp = [[query, passage] for passage in passages]
        cross_scores = self.model.predict(cross_inp)
        combined_results = combined_results.append_column(
            "_relevance_score", pa.array(cross_scores, type=pa.float32())
        )

        if self.score == "relevance":
            combined_results = combined_results.drop_columns(["score", "_distance"])
        # sort the results by relevance, most relevant first
        combined_results = combined_results.sort_by(
            [("_relevance_score", "descending")]
        )

        return combined_results
|
||||
117
python/lancedb/rerankers/linear_combination.py
Normal file
117
python/lancedb/rerankers/linear_combination.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from typing import List
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from .base import Reranker
|
||||
|
||||
|
||||
class LinearCombinationReranker(Reranker):
    """
    Reranks the results using a linear combination of the scores from the
    vector and FTS search. For missing scores, fill with `fill` value.

    Parameters
    ----------
    weight : float, default 0.7
        The weight to give to the vector score. Must be between 0 and 1.
    fill : float, default 1.0
        The score to give to results that are only in one of the two result sets.
        This is treated as penalty, so a higher value means a lower score.
        TODO: We should just hardcode this--
        its pretty confusing as we invert scores to calculate final score
    return_score : str, default "relevance"
        options are "relevance" or "all"
        The type of score to return. If "relevance", will return only the relevance
        score. If "all", will return all scores from the vector and FTS search along
        with the relevance score.
    """

    def __init__(
        self, weight: float = 0.7, fill: float = 1.0, return_score="relevance"
    ):
        if weight < 0 or weight > 1:
            raise ValueError("weight must be between 0 and 1.")
        super().__init__(return_score)
        self.weight = weight
        self.fill = fill

    def rerank_hybrid(
        self,
        query_builder: "lancedb.HybridQueryBuilder",  # noqa: F821
        vector_results: pa.Table,
        fts_results: pa.Table,
    ):
        """Combine both result sets, using ``self.fill`` for missing scores."""
        combined_results = self.merge_results(vector_results, fts_results, self.fill)

        return combined_results

    def merge_results(
        self, vector_results: pa.Table, fts_results: pa.Table, fill: float
    ):
        """Join the two result sets on ``_rowid`` and score every row.

        Rows present in both sets are scored from their vector distance and
        their inverted FTS score. Rows present in only one set use ``fill`` in
        place of the missing distance.

        Parameters
        ----------
        vector_results : pa.Table
            The vector search results, with ``_rowid`` and ``_distance``.
        fts_results : pa.Table
            The FTS results, with ``_rowid`` and ``score``.
        fill : float
            The distance assumed for the search a row is missing from.

        Returns
        -------
        pa.Table
            The merged rows sorted by descending ``_relevance_score``. The
            ``score`` and ``_distance`` columns are dropped unless the
            reranker was created with ``return_score="all"``.
        """
        # Index the FTS rows by row id; matched rows are popped so that what is
        # left afterwards are exactly the FTS-only rows (none of them is lost,
        # and each one is scored from its own FTS score).
        fts_by_rowid = {row["_rowid"]: row for row in fts_results.to_pylist()}

        combined_list = []
        for vi in vector_results.to_pylist():
            fj = fts_by_rowid.pop(vi["_rowid"], None)
            if fj is None:
                vi["_relevance_score"] = self._combine_score(vi["_distance"], fill)
            else:
                # invert the fts score from relevance to distance
                vi["_relevance_score"] = self._combine_score(
                    vi["_distance"], self._invert_score(fj["score"])
                )
                vi["score"] = fj["score"]  # keep the original score
            combined_list.append(vi)

        for fj in fts_by_rowid.values():
            # The vector distance is the missing one here, so `fill` goes in
            # the vector slot and `weight` keeps applying to the vector side.
            fj["_relevance_score"] = self._combine_score(
                fill, self._invert_score(fj["score"])
            )
            combined_list.append(fj)

        relevance_score_schema = pa.schema(
            [
                pa.field("_relevance_score", pa.float32()),
            ]
        )
        combined_schema = pa.unify_schemas(
            [vector_results.schema, fts_results.schema, relevance_score_schema]
        )
        # Building from the unified schema also covers empty inputs, so the
        # output always carries a `_relevance_score` column.
        tbl = pa.Table.from_pylist(combined_list, schema=combined_schema).sort_by(
            [("_relevance_score", "descending")]
        )
        if self.score == "relevance":
            unwanted = [
                name for name in ("score", "_distance") if name in tbl.column_names
            ]
            tbl = tbl.drop_columns(unwanted)
        return tbl

    def _combine_score(self, score1, score2):
        """Return the relevance of a (vector distance, FTS distance) pair."""
        # these scores represent distance
        return 1 - (self.weight * score1 + (1 - self.weight) * score2)

    def _invert_score(self, scores: float):
        """Convert a normalized score between relevance and distance."""
        # Invert the scores between relevance and distance
        return 1 - scores
|
||||
@@ -16,7 +16,7 @@ from __future__ import annotations
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import lance
|
||||
import numpy as np
|
||||
@@ -33,8 +33,7 @@ from .query import LanceQueryBuilder, Query
|
||||
from .util import (
|
||||
fs_from_uri,
|
||||
join_uri,
|
||||
safe_import_pandas,
|
||||
safe_import_polars,
|
||||
safe_import,
|
||||
value_to_sql,
|
||||
)
|
||||
from .utils.events import register_event
|
||||
@@ -48,8 +47,8 @@ if TYPE_CHECKING:
|
||||
from .db import LanceDBConnection
|
||||
|
||||
|
||||
pd = safe_import_pandas()
|
||||
pl = safe_import_polars()
|
||||
pd = safe_import("pandas")
|
||||
pl = safe_import("polars")
|
||||
|
||||
|
||||
def _sanitize_data(
|
||||
@@ -338,7 +337,7 @@ class Table(ABC):
|
||||
@abstractmethod
|
||||
def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
query_type: str = "auto",
|
||||
) -> LanceQueryBuilder:
|
||||
@@ -924,7 +923,7 @@ class LanceTable(Table):
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
query_type: str = "auto",
|
||||
) -> LanceQueryBuilder:
|
||||
@@ -1194,6 +1193,7 @@ class LanceTable(Table):
|
||||
"nprobes": query.nprobes,
|
||||
"refine_factor": query.refine_factor,
|
||||
},
|
||||
with_row_id=query.with_row_id,
|
||||
)
|
||||
|
||||
def cleanup_old_versions(
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import pathlib
|
||||
from datetime import date, datetime
|
||||
@@ -114,22 +115,23 @@ def join_uri(base: Union[str, pathlib.Path], *parts: str) -> str:
|
||||
return "/".join([p.rstrip("/") for p in [base, *parts]])
|
||||
|
||||
|
||||
def safe_import_pandas():
|
||||
def safe_import(module: str, mitigation=None):
|
||||
"""
|
||||
Import the specified module. If the module is not installed,
|
||||
raise an ImportError with a helpful message.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
module : str
|
||||
The name of the module to import
|
||||
mitigation : Optional[str]
|
||||
The package(s) to install to mitigate the error.
|
||||
If not provided then the module name will be used.
|
||||
"""
|
||||
try:
|
||||
import pandas as pd
|
||||
|
||||
return pd
|
||||
return importlib.import_module(module)
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def safe_import_polars():
|
||||
try:
|
||||
import polars as pl
|
||||
|
||||
return pl
|
||||
except ImportError:
|
||||
return None
|
||||
raise ImportError(f"Please install {mitigation or module}")
|
||||
|
||||
|
||||
@singledispatch
|
||||
|
||||
168
python/tests/test_rerankers.py
Normal file
168
python/tests/test_rerankers.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import lancedb
|
||||
from lancedb.conftest import MockTextEmbeddingFunction # noqa
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.rerankers import CohereReranker, CrossEncoderReranker
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
|
||||
def get_test_table(tmp_path):
    """
    Build a small LanceDB table with text + vector columns and an FTS index.

    Parameters
    ----------
    tmp_path
        Location passed to ``lancedb.connect`` for the database.

    Returns
    -------
    tuple
        ``(table, MyTable)`` - the populated table and the model class
        that was used as its schema.
    """
    connection = lancedb.connect(tmp_path)
    # The "test" embedding function is looked up from the registry and
    # instantiated; it drives both the source and the vector field.
    embedder = EmbeddingFunctionRegistry.get_instance().get("test")()

    class MyTable(LanceModel):
        text: str = embedder.SourceField()
        vector: Vector(embedder.ndims()) = embedder.VectorField()

    tbl = LanceTable.create(connection, "my_table", schema=MyTable)

    # A larger batch of phrases so that result ordering is exercised.
    corpus = (
        "great kid don't get cocky",
        "now that's a name I haven't heard in a long time",
        "if you strike me down I shall become more powerful than you imagine",
        "I find your lack of faith disturbing",
        "I've got a bad feeling about this",
        "never tell me the odds",
        "I am your father",
        "somebody has to save our skins",
        "New strategy R2 let the wookiee win",
        "Arrrrggghhhhhhh",
        "I see a mansard roof through the trees",
        "I see a salty message written in the eves",
        "the ground beneath my feet",
        "the hot garbage and concrete",
        "and now the tops of buildings",
        "everybody with a worried mind could never forgive the sight",
        "of wicked snakes inside a place you thought was dignified",
        "I don't wanna live like this",
        "but I don't wanna die",
        "The templars want control",
        "the brotherhood of assassins want freedom",
        "if only they could both see the world as it really is",
        "there would be peace",
        "but the war goes on",
        "altair's legacy was a warning",
        "Kratos had a son",
        "he was a god",
        "the god of war",
        "but his son was mortal",
        "there hasn't been a good battlefield game since 2142",
        "I wish they would make another one",
        "campains are not as good as they used to be",
        "Multiplayer and open world games have destroyed the single player experience",
        "Maybe the future is console games",
        "I don't know",
    )

    # Only the text is supplied for each row.
    rows = [{"text": phrase} for phrase in corpus]
    tbl.add(rows)

    # Full-text index on the text column, needed for hybrid queries.
    tbl.create_fts_index("text")

    return tbl, MyTable
|
||||
|
||||
|
||||
## These tests are pretty loose, we should also check for correctness
|
||||
def test_linear_combination(tmp_path):
    """
    Exercise hybrid search with the reranker that is used when none is given.

    Checks that an explicit ``rerank(normalize="score")`` matches a hybrid
    search with no ``rerank`` call, and that ``_relevance_score`` comes
    back in non-increasing order.
    """
    table, schema = get_test_table(tmp_path)

    # Explicit rerank call without a reranker argument.
    result1 = (
        table.search("Our father who art in heaven", query_type="hybrid")
        .rerank(normalize="score")
        .to_pydantic(schema)
    )
    # Only executed to make sure rank normalization runs; the value is not
    # compared against anything.
    result2 = (  # noqa
        table.search("Our father who art in heaven.", query_type="hybrid")
        .rerank(normalize="rank")
        .to_pydantic(schema)
    )
    # No rerank() call at all.
    # NOTE(review): the three queries differ only in trailing periods;
    # the reason is not visible here - confirm before normalizing them.
    result3 = table.search(
        "Our father who art in heaven..", query_type="hybrid"
    ).to_pydantic(schema)

    # 1 & 3 should be the same as they both use score normalization
    assert result1 == result3

    result = (
        table.search("Our father who art in heaven", query_type="hybrid")
        .limit(50)
        .rerank(normalize="score")
        .to_arrow()
    )
    # Consecutive differences <= 0 means the column never increases.
    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
        "The _relevance_score column of the results returned by the reranker "
        "represents the relevance of the result to the query & should "
        "be descending."
    )
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
)
def test_cohere_reranker(tmp_path):
    """
    Exercise hybrid search reranked by ``CohereReranker``.

    Skipped when ``COHERE_API_KEY`` is unset or the ``cohere`` package is
    not importable. Checks that score and rank normalization give the same
    results and that ``_relevance_score`` is in non-increasing order.
    """
    pytest.importorskip("cohere")
    table, schema = get_test_table(tmp_path)

    # Same query, reranked by Cohere under each normalization mode.
    result1 = (
        table.search("Our father who art in heaven", query_type="hybrid")
        .rerank(normalize="score", reranker=CohereReranker())
        .to_pydantic(schema)
    )
    result2 = (
        table.search("Our father who art in heaven", query_type="hybrid")
        .rerank(normalize="rank", reranker=CohereReranker())
        .to_pydantic(schema)
    )
    assert result1 == result2

    result = (
        table.search("Our father who art in heaven", query_type="hybrid")
        .limit(50)
        .rerank(reranker=CohereReranker())
        .to_arrow()
    )
    # Consecutive differences <= 0 means the column never increases.
    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
        "The _relevance_score column of the results returned by the reranker "
        "represents the relevance of the result to the query & should "
        "be descending."
    )
|
||||
|
||||
|
||||
def test_cross_encoder_reranker(tmp_path):
    """
    Exercise hybrid search reranked by ``CrossEncoderReranker``.

    Skipped when ``sentence_transformers`` is not importable. Checks that
    score and rank normalization give the same results and that
    ``_relevance_score`` is in non-increasing order.
    """
    pytest.importorskip("sentence_transformers")
    table, schema = get_test_table(tmp_path)

    # Same query, reranked by the cross encoder under each normalization mode.
    result1 = (
        table.search("Our father who art in heaven", query_type="hybrid")
        .rerank(normalize="score", reranker=CrossEncoderReranker())
        .to_pydantic(schema)
    )
    result2 = (
        table.search("Our father who art in heaven", query_type="hybrid")
        .rerank(normalize="rank", reranker=CrossEncoderReranker())
        .to_pydantic(schema)
    )
    assert result1 == result2

    result = (
        table.search("Our father who art in heaven", query_type="hybrid")
        .limit(50)
        .rerank(reranker=CrossEncoderReranker())
        .to_arrow()
    )
    # Consecutive differences <= 0 means the column never increases.
    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
        "The _relevance_score column of the results returned by the reranker "
        "represents the relevance of the result to the query & should "
        "be descending."
    )
|
||||
@@ -682,3 +682,57 @@ def test_count_rows(db):
|
||||
assert len(table) == 2
|
||||
assert table.count_rows() == 2
|
||||
assert table.count_rows(filter="text='bar'") == 1
|
||||
|
||||
|
||||
def test_hybrid_search(db):
    """
    Smoke-test hybrid search on a ten-phrase table.

    Checks that an explicit ``rerank(normalize="score")`` returns the same
    rows as a hybrid search with no ``rerank`` call.
    """
    # NOTE: the incoming ``db`` fixture value is discarded. The path is
    # hard-coded for now because this test fails with the tmp_path-backed
    # mock db (probably the path is not parsed right by the fts).
    db = MockDB("~/lancedb_")

    # Embedding function registered as "test" supplies the schema fields.
    embedder = EmbeddingFunctionRegistry.get_instance().get("test")()

    class MyTable(LanceModel):
        text: str = embedder.SourceField()
        vector: Vector(embedder.ndims()) = embedder.VectorField()

    tbl = LanceTable.create(db, "my_table", schema=MyTable)

    # Ten distinct English phrases.
    corpus = (
        "great kid don't get cocky",
        "now that's a name I haven't heard in a long time",
        "if you strike me down I shall become more powerful than you imagine",
        "I find your lack of faith disturbing",
        "I've got a bad feeling about this",
        "never tell me the odds",
        "I am your father",
        "somebody has to save our skins",
        "New strategy R2 let the wookiee win",
        "Arrrrggghhhhhhh",
    )

    # Only the text is supplied for each row.
    rows = [{"text": phrase} for phrase in corpus]
    tbl.add(rows)

    # Full-text index on the text column, needed for hybrid queries.
    tbl.create_fts_index("text")

    question = "Our father who art in heaven"

    scored = tbl.search(question, query_type="hybrid").rerank(normalize="score")
    result1 = scored.to_pydantic(MyTable)

    # Run only to make sure rank normalization executes; value is unused.
    ranked = tbl.search(question, query_type="hybrid").rerank(normalize="rank")
    result2 = ranked.to_pydantic(MyTable)  # noqa

    result3 = tbl.search(question, query_type="hybrid").to_pydantic(MyTable)

    assert result1 == result3
|
||||
|
||||
Reference in New Issue
Block a user