feat(python): Hybrid search & Reranker API (#824)

based on https://github.com/lancedb/lancedb/pull/713
- The Reranker api can be plugged into vector only or fts only search
but this PR doesn't do that (see example -
https://txt.cohere.com/rerank/)


### Default reranker -- `LinearCombinationReranker(weight=0.7,
fill=1.0)`

```
table.search("hello", query_type="hybrid").rerank(normalize="score").to_pandas()
```
### Available rerankers
LinearCombinationReranker
```
from lancedb.rerankers import LinearCombinationReranker

# Same as default 
table.search("hello", query_type="hybrid").rerank(
                                      normalize="score", 
                                      reranker=LinearCombinationReranker()
                                     ).to_pandas()

# with custom params
reranker = LinearCombinationReranker(weight=0.3, fill=1.0)
table.search("hello", query_type="hybrid").rerank(
                                      normalize="score", 
                                      reranker=reranker
                                     ).to_pandas()
```

Cohere Reranker
```
from lancedb.rerankers import CohereReranker

# default model.. English and multi-lingual supported. See docstring for available custom params
table.search("hello", query_type="hybrid").rerank(
                                      normalize="rank",  # score or rank
                                      reranker=CohereReranker()
                                     ).to_pandas()

```

CrossEncoderReranker

```
from lancedb.rerankers import CrossEncoderReranker

table.search("hello", query_type="hybrid").rerank(
                                      normalize="rank", 
                                      reranker=CrossEncoderReranker()
                                     ).to_pandas()

```

## Using custom Reranker
```
from lancedb.reranker import Reranker

class CustomReranker(Reranker):
    def rerank_hybrid(self, vector_result, fts_result):
           combined_res = self.merge_results(vector_results, fts_results) # or use custom combination logic
           # Custom rerank logic here
           
           return combined_res
```

- [x] Expand testing
- [x] Make sure usage makes sense
- [x] Run simple benchmarks for correctness (Seeing weird result from
cohere reranker in the toy example)
- Support diverse rerankers by default:
- [x] Cross encoding
- [x] Cohere
- [x] Reciprocal Rank Fusion

---------

Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com>
Co-authored-by: Prashanth Rao <35005448+prrao87@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2024-01-30 19:10:33 +05:30
committed by GitHub
parent f150768739
commit 3ffed89793
16 changed files with 1136 additions and 41 deletions

View File

@@ -16,9 +16,9 @@ from typing import Iterable, List, Union
import numpy as np
import pyarrow as pa
from .util import safe_import_pandas
from .util import safe_import
pd = safe_import_pandas()
pd = safe_import("pandas")
DATA = Union[List[dict], dict, "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]]
VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]

View File

@@ -16,9 +16,9 @@ import deprecation
from . import __version__
from .exceptions import MissingColumnError, MissingValueError
from .util import safe_import_pandas
from .util import safe_import
pd = safe_import_pandas()
pd = safe_import("pandas")
def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:

View File

@@ -26,10 +26,10 @@ import pyarrow as pa
from lance.vector import vec_to_table
from retry import retry
from ..util import safe_import_pandas
from ..util import safe_import
from ..utils.general import LOGGER
pd = safe_import_pandas()
pd = safe_import("pandas")
DATA = Union[pa.Table, "pd.DataFrame"]
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]

View File

@@ -14,8 +14,9 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import TYPE_CHECKING, List, Literal, Optional, Type, Union
from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Type, Union
import deprecation
import numpy as np
@@ -23,8 +24,10 @@ import pyarrow as pa
import pydantic
from . import __version__
from .common import VECTOR_COLUMN_NAME
from .util import safe_import_pandas
from .common import VEC, VECTOR_COLUMN_NAME
from .rerankers.base import Reranker
from .rerankers.linear_combination import LinearCombinationReranker
from .util import safe_import
if TYPE_CHECKING:
import PIL
@@ -33,7 +36,7 @@ if TYPE_CHECKING:
from .pydantic import LanceModel
from .table import Table
pd = safe_import_pandas()
pd = safe_import("pandas")
class Query(pydantic.BaseModel):
@@ -99,6 +102,8 @@ class Query(pydantic.BaseModel):
# Refine factor.
refine_factor: Optional[int] = None
with_row_id: bool = False
class LanceQueryBuilder(ABC):
"""Build LanceDB query based on specific query type:
@@ -109,19 +114,26 @@ class LanceQueryBuilder(ABC):
def create(
cls,
table: "Table",
query: Optional[Union[np.ndarray, str, "PIL.Image.Image"]],
query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
query_type: str,
vector_column_name: str,
) -> LanceQueryBuilder:
if query is None:
return LanceEmptyQueryBuilder(table)
# convert "auto" query_type to "vector" or "fts"
# and convert the query to vector if needed
if query_type == "hybrid":
# hybrid fts and vector query
return LanceHybridQueryBuilder(table, query, vector_column_name)
# convert "auto" query_type to "vector", "fts"
# or "hybrid" and convert the query to vector if needed
query, query_type = cls._resolve_query(
table, query, query_type, vector_column_name
)
if query_type == "hybrid":
return LanceHybridQueryBuilder(table, query, vector_column_name)
if isinstance(query, str):
# fts
return LanceFtsQueryBuilder(table, query)
@@ -144,17 +156,13 @@ class LanceQueryBuilder(ABC):
raise TypeError(f"'fts' queries must be a string: {type(query)}")
return query, query_type
elif query_type == "vector":
if not isinstance(query, (list, np.ndarray)):
conf = table.embedding_functions.get(vector_column_name)
if conf is not None:
query = conf.function.compute_query_embeddings_with_retry(query)[0]
else:
msg = f"No embedding function for {vector_column_name}"
raise ValueError(msg)
query = cls._query_to_vector(table, query, vector_column_name)
return query, query_type
elif query_type == "auto":
if isinstance(query, (list, np.ndarray)):
return query, "vector"
if isinstance(query, tuple):
return query, "hybrid"
else:
conf = table.embedding_functions.get(vector_column_name)
if conf is not None:
@@ -167,11 +175,23 @@ class LanceQueryBuilder(ABC):
f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
)
@classmethod
def _query_to_vector(cls, table, query, vector_column_name):
if isinstance(query, (list, np.ndarray)):
return query
conf = table.embedding_functions.get(vector_column_name)
if conf is not None:
return conf.function.compute_query_embeddings_with_retry(query)[0]
else:
msg = f"No embedding function for {vector_column_name}"
raise ValueError(msg)
def __init__(self, table: "Table"):
self._table = table
self._limit = 10
self._columns = None
self._where = None
self._with_row_id = False
@deprecation.deprecated(
deprecated_in="0.3.1",
@@ -341,6 +361,22 @@ class LanceQueryBuilder(ABC):
self._prefilter = prefilter
return self
def with_row_id(self, with_row_id: bool) -> LanceQueryBuilder:
"""Set whether to return row ids.
Parameters
----------
with_row_id: bool
If True, return _rowid column in the results.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
self._with_row_id = with_row_id
return self
class LanceVectorQueryBuilder(LanceQueryBuilder):
"""
@@ -459,6 +495,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
nprobes=self._nprobes,
refine_factor=self._refine_factor,
vector_column=self._vector_column,
with_row_id=self._with_row_id,
)
return self._table._execute_query(query)
@@ -568,6 +605,10 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
ds = lance.write_dataset(output_tbl, tmp)
output_tbl = ds.to_table(filter=self._where)
if self._with_row_id:
# Need to set this to uint explicitly as vector results are in uint64
row_ids = pa.array(row_ids, type=pa.uint64())
output_tbl = output_tbl.append_column("_rowid", row_ids)
return output_tbl
@@ -579,3 +620,258 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
filter=self._where,
limit=self._limit,
)
class LanceHybridQueryBuilder(LanceQueryBuilder):
def __init__(self, table: "Table", query: str, vector_column: str):
super().__init__(table)
self._validate_fts_index()
self._query = query
vector_query, fts_query = self._validate_query(query)
self._fts_query = LanceFtsQueryBuilder(table, fts_query)
vector_query = self._query_to_vector(table, vector_query, vector_column)
self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column)
self._norm = "score"
self._reranker = LinearCombinationReranker(weight=0.7, fill=1.0)
def _validate_fts_index(self):
if self._table._get_fts_index_path() is None:
raise ValueError(
"Please create a full-text search index " "to perform hybrid search."
)
def _validate_query(self, query):
# Temp hack to support vectorized queries for hybrid search
if isinstance(query, str):
return query, query
elif isinstance(query, tuple):
if len(query) != 2:
raise ValueError(
"The query must be a tuple of (vector_query, fts_query)."
)
if not isinstance(query[0], (list, np.ndarray, pa.Array, pa.ChunkedArray)):
raise ValueError(f"The vector query must be one of {VEC}.")
if not isinstance(query[1], str):
raise ValueError("The fts query must be a string.")
return query[0], query[1]
else:
raise ValueError(
"The query must be either a string or a tuple of (vector, string)."
)
def to_arrow(self) -> pa.Table:
with ThreadPoolExecutor() as executor:
fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
vector_future = executor.submit(
self._vector_query.with_row_id(True).to_arrow
)
fts_results = fts_future.result()
vector_results = vector_future.result()
# convert to ranks first if needed
if self._norm == "rank":
vector_results = self._rank(vector_results, "_distance")
fts_results = self._rank(fts_results, "score")
# normalize the scores to be between 0 and 1, 0 being most relevant
vector_results = self._normalize_scores(vector_results, "_distance")
# In fts higher scores represent relevance. Not inverting them here as
# rerankers might need to preserve this score to support `return_score="all"`
fts_results = self._normalize_scores(fts_results, "score")
results = self._reranker.rerank_hybrid(self, vector_results, fts_results)
if not isinstance(results, pa.Table): # Enforce type
raise TypeError(
f"rerank_hybrid must return a pyarrow.Table, got {type(results)}"
)
if not self._with_row_id:
results = results.drop(["_rowid"])
return results
def _rank(self, results: pa.Table, column: str, ascending: bool = True):
if len(results) == 0:
return results
# Get the _score column from results
scores = results.column(column).to_numpy()
sort_indices = np.argsort(scores)
if not ascending:
sort_indices = sort_indices[::-1]
ranks = np.empty_like(sort_indices)
ranks[sort_indices] = np.arange(len(scores)) + 1
# replace the _score column with the ranks
_score_idx = results.column_names.index(column)
results = results.set_column(
_score_idx, column, pa.array(ranks, type=pa.float32())
)
return results
def _normalize_scores(self, results: pa.Table, column: str, invert=False):
if len(results) == 0:
return results
# Get the _score column from results
scores = results.column(column).to_numpy()
# normalize the scores by subtracting the min and dividing by the max
max, min = np.max(scores), np.min(scores)
if np.isclose(max, min):
rng = max
else:
rng = max - min
scores = (scores - min) / rng
if invert:
scores = 1 - scores
# replace the _score column with the ranks
_score_idx = results.column_names.index(column)
results = results.set_column(
_score_idx, column, pa.array(scores, type=pa.float32())
)
return results
def rerank(
self,
normalize="score",
reranker: Reranker = LinearCombinationReranker(weight=0.7, fill=1.0),
) -> LanceHybridQueryBuilder:
"""
Rerank the hybrid search results using the specified reranker. The reranker
must be an instance of Reranker class.
Parameters
----------
normalize: str, default "score"
The method to normalize the scores. Can be "rank" or "score". If "rank",
the scores are converted to ranks and then normalized. If "score", the
scores are normalized directly.
reranker: Reranker, default LinearCombinationReranker(weight=0.7, fill=1.0)
The reranker to use. Must be an instance of Reranker class.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
if normalize not in ["rank", "score"]:
raise ValueError("normalize must be 'rank' or 'score'.")
if reranker and not isinstance(reranker, Reranker):
raise ValueError("reranker must be an instance of Reranker class.")
self._norm = normalize
self._reranker = reranker
return self
def limit(self, limit: int) -> LanceHybridQueryBuilder:
"""
Set the maximum number of results to return for both vector and fts search
components.
Parameters
----------
limit: int
The maximum number of results to return.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.limit(limit)
self._fts_query.limit(limit)
return self
def select(self, columns: list) -> LanceHybridQueryBuilder:
"""
Set the columns to return for both vector and fts search.
Parameters
----------
columns: list
The columns to return.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.select(columns)
self._fts_query.select(columns)
return self
def where(self, where: str, prefilter: bool = False) -> LanceHybridQueryBuilder:
"""
Set the where clause for both vector and fts search.
Parameters
----------
where: str
The where clause which is a valid SQL where clause. See
`Lance filter pushdown <https://lancedb.github.io/lance/read_and_write.html#filter-push-down>`_
for valid SQL expressions.
prefilter: bool, default False
If True, apply the filter before vector search, otherwise the
filter is applied on the result of vector search.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.where(where, prefilter=prefilter)
self._fts_query.where(where)
return self
def metric(self, metric: Literal["L2", "cosine"]) -> LanceHybridQueryBuilder:
"""
Set the distance metric to use for vector search.
Parameters
----------
metric: "L2" or "cosine"
The distance metric to use. By default "L2" is used.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.metric(metric)
return self
def nprobes(self, nprobes: int) -> LanceHybridQueryBuilder:
"""
Set the number of probes to use for vector search.
Higher values will yield better recall (more likely to find vectors if
they exist) at the expense of latency.
Parameters
----------
nprobes: int
The number of probes to use.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.nprobes(nprobes)
return self
def refine_factor(self, refine_factor: int) -> LanceHybridQueryBuilder:
"""
Refine the vector search results by reading extra elements and
re-ranking them in memory.
Parameters
----------
refine_factor: int
The refine factor to use.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.refine_factor(refine_factor)
return self

View File

@@ -0,0 +1,11 @@
from .base import Reranker
from .cohere import CohereReranker
from .cross_encoder import CrossEncoderReranker
from .linear_combination import LinearCombinationReranker
__all__ = [
"Reranker",
"CrossEncoderReranker",
"CohereReranker",
"LinearCombinationReranker",
]

View File

@@ -0,0 +1,109 @@
import typing
from abc import ABC, abstractmethod
import numpy as np
import pyarrow as pa
if typing.TYPE_CHECKING:
import lancedb
class Reranker(ABC):
def __init__(self, return_score: str = "relevance"):
"""
Interface for a reranker. A reranker is used to rerank the results from a
vector and FTS search. This is useful for combining the results from both
search methods.
Parameters
----------
return_score : str, default "relevance"
opntions are "relevance" or "all"
The type of score to return. If "relevance", will return only the relevance
score. If "all", will return all scores from the vector and FTS search along
with the relevance score.
"""
if return_score not in ["relevance", "all"]:
raise ValueError("score must be either 'relevance' or 'all'")
self.score = return_score
@abstractmethod
def rerank_hybrid(
query_builder: "lancedb.HybridQueryBuilder",
vector_results: pa.Table,
fts_results: pa.Table,
):
"""
Rerank function receives the individual results from the vector and FTS search
results. You can choose to use any of the results to generate the final results,
allowing maximum flexibility. This is mandatory to implement
Parameters
----------
query_builder : "lancedb.HybridQueryBuilder"
The query builder object that was used to generate the results
vector_results : pa.Table
The results from the vector search
fts_results : pa.Table
The results from the FTS search
"""
pass
def rerank_vector(
query_builder: "lancedb.VectorQueryBuilder", vector_results: pa.Table
):
"""
Rerank function receives the individual results from the vector search.
This isn't mandatory to implement
Parameters
----------
query_builder : "lancedb.VectorQueryBuilder"
The query builder object that was used to generate the results
vector_results : pa.Table
The results from the vector search
"""
raise NotImplementedError("Vector Reranking is not implemented")
def rerank_fts(query_builder: "lancedb.FTSQueryBuilder", fts_results: pa.Table):
"""
Rerank function receives the individual results from the FTS search.
This isn't mandatory to implement
Parameters
----------
query_builder : "lancedb.FTSQueryBuilder"
The query builder object that was used to generate the results
fts_results : pa.Table
The results from the FTS search
"""
raise NotImplementedError("FTS Reranking is not implemented")
def merge_results(self, vector_results: pa.Table, fts_results: pa.Table):
"""
Merge the results from the vector and FTS search. This is a vanilla merging
function that just concatenates the results and removes the duplicates.
NOTE: This doesn't take score into account. It'll keep the instance that was
encountered first. This is designed for rerankers that don't use the score.
In case you want to use the score, or support `return_scores="all"` you'll
have to implement your own merging function.
Parameters
----------
vector_results : pa.Table
The results from the vector search
fts_results : pa.Table
The results from the FTS search
"""
combined = pa.concat_tables([vector_results, fts_results], promote=True)
row_id = combined.column("_rowid")
# deduplicate
mask = np.full((combined.shape[0]), False)
_, mask_indices = np.unique(np.array(row_id), return_index=True)
mask[mask_indices] = True
combined = combined.filter(mask=mask)
return combined

View File

@@ -0,0 +1,85 @@
import os
import typing
from functools import cached_property
from typing import Union
import pyarrow as pa
from ..util import safe_import
from .base import Reranker
if typing.TYPE_CHECKING:
import lancedb
class CohereReranker(Reranker):
"""
Reranks the results using the Cohere Rerank API.
https://docs.cohere.com/docs/rerank-guide
Parameters
----------
model_name : str, default "rerank-english-v2.0"
The name of the cross encoder model to use. Available cohere models are:
- rerank-english-v2.0
- rerank-multilingual-v2.0
column : str, default "text"
The name of the column to use as input to the cross encoder model.
top_n : str, default None
The number of results to return. If None, will return all results.
"""
def __init__(
self,
model_name: str = "rerank-english-v2.0",
column: str = "text",
top_n: Union[int, None] = None,
return_score="relevance",
api_key: Union[str, None] = None,
):
super().__init__(return_score)
self.model_name = model_name
self.column = column
self.top_n = top_n
self.api_key = api_key
@cached_property
def _client(self):
cohere = safe_import("cohere")
if os.environ.get("COHERE_API_KEY") is None and self.api_key is None:
raise ValueError(
"COHERE_API_KEY not set. Either set it in your environment or \
pass it as `api_key` argument to the CohereReranker."
)
return cohere.Client(os.environ.get("COHERE_API_KEY") or self.api_key)
def rerank_hybrid(
self,
query_builder: "lancedb.HybridQueryBuilder",
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
docs = combined_results[self.column].to_pylist()
results = self._client.rerank(
query=query_builder._query,
documents=docs,
top_n=self.top_n,
model=self.model_name,
) # returns list (text, idx, relevance) attributes sorted descending by score
indices, scores = list(
zip(*[(result.index, result.relevance_score) for result in results])
) # tuples
combined_results = combined_results.take(list(indices))
# add the scores
combined_results = combined_results.append_column(
"_relevance_score", pa.array(scores, type=pa.float32())
)
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
elif self.score == "all":
raise NotImplementedError(
"return_score='all' not implemented for cohere reranker"
)
return combined_results

View File

@@ -0,0 +1,78 @@
import typing
from functools import cached_property
from typing import Union
import pyarrow as pa
from ..util import safe_import
from .base import Reranker
if typing.TYPE_CHECKING:
import lancedb
class CrossEncoderReranker(Reranker):
"""
Reranks the results using a cross encoder model. The cross encoder model is
used to score the query and each result. The results are then sorted by the score.
Parameters
----------
model : str, default "cross-encoder/ms-marco-TinyBERT-L-6"
The name of the cross encoder model to use. See the sentence transformers
documentation for a list of available models.
column : str, default "text"
The name of the column to use as input to the cross encoder model.
device : str, default None
The device to use for the cross encoder model. If None, will use "cuda"
if available, otherwise "cpu".
"""
def __init__(
self,
model_name: str = "cross-encoder/ms-marco-TinyBERT-L-6",
column: str = "text",
device: Union[str, None] = None,
return_score="relevance",
):
super().__init__(return_score)
torch = safe_import("torch")
self.model_name = model_name
self.column = column
self.device = device
if self.device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
@cached_property
def model(self):
sbert = safe_import("sentence_transformers")
cross_encoder = sbert.CrossEncoder(self.model_name)
return cross_encoder
def rerank_hybrid(
self,
query_builder: "lancedb.HybridQueryBuilder",
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
passages = combined_results[self.column].to_pylist()
cross_inp = [[query_builder._query, passage] for passage in passages]
cross_scores = self.model.predict(cross_inp)
combined_results = combined_results.append_column(
"_relevance_score", pa.array(cross_scores, type=pa.float32())
)
# sort the results by _score
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
elif self.score == "all":
raise NotImplementedError(
"return_score='all' not implemented for CrossEncoderReranker"
)
combined_results = combined_results.sort_by(
[("_relevance_score", "descending")]
)
return combined_results

View File

@@ -0,0 +1,117 @@
from typing import List
import pyarrow as pa
from .base import Reranker
class LinearCombinationReranker(Reranker):
"""
Reranks the results using a linear combination of the scores from the
vector and FTS search. For missing scores, fill with `fill` value.
Parameters
----------
weight : float, default 0.7
The weight to give to the vector score. Must be between 0 and 1.
fill : float, default 1.0
The score to give to results that are only in one of the two result sets.
This is treated as penalty, so a higher value means a lower score.
TODO: We should just hardcode this--
its pretty confusing as we invert scores to calculate final score
return_score : str, default "relevance"
opntions are "relevance" or "all"
The type of score to return. If "relevance", will return only the relevance
score. If "all", will return all scores from the vector and FTS search along
with the relevance score.
"""
def __init__(
self, weight: float = 0.7, fill: float = 1.0, return_score="relevance"
):
if weight < 0 or weight > 1:
raise ValueError("weight must be between 0 and 1.")
super().__init__(return_score)
self.weight = weight
self.fill = fill
def rerank_hybrid(
self,
query_builder: "lancedb.HybridQueryBuilder", # noqa: F821
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results, self.fill)
return combined_results
def merge_results(
self, vector_results: pa.Table, fts_results: pa.Table, fill: float
):
# If both are empty then just return an empty table
if len(vector_results) == 0 and len(fts_results) == 0:
return vector_results
# If one is empty then return the other
if len(vector_results) == 0:
return fts_results
if len(fts_results) == 0:
return vector_results
# sort both input tables on _rowid
combined_list = []
vector_list = vector_results.sort_by("_rowid").to_pylist()
fts_list = fts_results.sort_by("_rowid").to_pylist()
i, j = 0, 0
while i < len(vector_list):
if j >= len(fts_list):
for vi in vector_list[i:]:
vi["_relevance_score"] = self._combine_score(vi["_distance"], fill)
combined_list.append(vi)
break
vi = vector_list[i]
fj = fts_list[j]
# invert the fts score from relevance to distance
inverted_fts_score = self._invert_score(fj["score"])
if vi["_rowid"] == fj["_rowid"]:
vi["_relevance_score"] = self._combine_score(
vi["_distance"], inverted_fts_score
)
vi["score"] = fj["score"] # keep the original score
combined_list.append(vi)
i += 1
j += 1
elif vector_list[i]["_rowid"] < fts_list[j]["_rowid"]:
vi["_relevance_score"] = self._combine_score(vi["_distance"], fill)
combined_list.append(vi)
i += 1
else:
fj["_relevance_score"] = self._combine_score(inverted_fts_score, fill)
combined_list.append(fj)
j += 1
if j < len(fts_list) - 1:
for fj in fts_list[j:]:
fj["_relevance_score"] = self._combine_score(inverted_fts_score, fill)
combined_list.append(fj)
relevance_score_schema = pa.schema(
[
pa.field("_relevance_score", pa.float32()),
]
)
combined_schema = pa.unify_schemas(
[vector_results.schema, fts_results.schema, relevance_score_schema]
)
tbl = pa.Table.from_pylist(combined_list, schema=combined_schema).sort_by(
[("_relevance_score", "descending")]
)
if self.score == "relevance":
tbl = tbl.drop_columns(["score", "_distance"])
return tbl
def _combine_score(self, score1, score2):
# these scores represent distance
return 1 - (self.weight * score1 + (1 - self.weight) * score2)
def _invert_score(self, scores: List[float]):
# Invert the scores between relevance and distance
return 1 - scores

View File

@@ -16,7 +16,7 @@ from __future__ import annotations
import inspect
from abc import ABC, abstractmethod
from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
import lance
import numpy as np
@@ -33,8 +33,7 @@ from .query import LanceQueryBuilder, Query
from .util import (
fs_from_uri,
join_uri,
safe_import_pandas,
safe_import_polars,
safe_import,
value_to_sql,
)
from .utils.events import register_event
@@ -48,8 +47,8 @@ if TYPE_CHECKING:
from .db import LanceDBConnection
pd = safe_import_pandas()
pl = safe_import_polars()
pd = safe_import("pandas")
pl = safe_import("polars")
def _sanitize_data(
@@ -338,7 +337,7 @@ class Table(ABC):
@abstractmethod
def search(
self,
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: str = VECTOR_COLUMN_NAME,
query_type: str = "auto",
) -> LanceQueryBuilder:
@@ -924,7 +923,7 @@ class LanceTable(Table):
def search(
self,
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: str = VECTOR_COLUMN_NAME,
query_type: str = "auto",
) -> LanceQueryBuilder:
@@ -1194,6 +1193,7 @@ class LanceTable(Table):
"nprobes": query.nprobes,
"refine_factor": query.refine_factor,
},
with_row_id=query.with_row_id,
)
def cleanup_old_versions(

View File

@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import pathlib
from datetime import date, datetime
@@ -114,22 +115,23 @@ def join_uri(base: Union[str, pathlib.Path], *parts: str) -> str:
return "/".join([p.rstrip("/") for p in [base, *parts]])
def safe_import_pandas():
def safe_import(module: str, mitigation=None):
"""
Import the specified module. If the module is not installed,
raise an ImportError with a helpful message.
Parameters
----------
module : str
The name of the module to import
mitigation : Optional[str]
The package(s) to install to mitigate the error.
If not provided then the module name will be used.
"""
try:
import pandas as pd
return pd
return importlib.import_module(module)
except ImportError:
return None
def safe_import_polars():
try:
import polars as pl
return pl
except ImportError:
return None
raise ImportError(f"Please install {mitigation or module}")
@singledispatch

View File

@@ -0,0 +1,168 @@
import os
import numpy as np
import pytest
import lancedb
from lancedb.conftest import MockTextEmbeddingFunction # noqa
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import CohereReranker, CrossEncoderReranker
from lancedb.table import LanceTable
def get_test_table(tmp_path):
db = lancedb.connect(tmp_path)
# Create a LanceDB table schema with a vector and a text column
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
class MyTable(LanceModel):
text: str = emb.SourceField()
vector: Vector(emb.ndims()) = emb.VectorField()
# Initialize the table using the schema
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
)
# Need to test with a bunch of phrases to make sure sorting is consistent
phrases = [
"great kid don't get cocky",
"now that's a name I haven't heard in a long time",
"if you strike me down I shall become more powerful than you imagine",
"I find your lack of faith disturbing",
"I've got a bad feeling about this",
"never tell me the odds",
"I am your father",
"somebody has to save our skins",
"New strategy R2 let the wookiee win",
"Arrrrggghhhhhhh",
"I see a mansard roof through the trees",
"I see a salty message written in the eves",
"the ground beneath my feet",
"the hot garbage and concrete",
"and now the tops of buildings",
"everybody with a worried mind could never forgive the sight",
"of wicked snakes inside a place you thought was dignified",
"I don't wanna live like this",
"but I don't wanna die",
"The templars want control",
"the brotherhood of assassins want freedom",
"if only they could both see the world as it really is",
"there would be peace",
"but the war goes on",
"altair's legacy was a warning",
"Kratos had a son",
"he was a god",
"the god of war",
"but his son was mortal",
"there hasn't been a good battlefield game since 2142",
"I wish they would make another one",
"campains are not as good as they used to be",
"Multiplayer and open world games have destroyed the single player experience",
"Maybe the future is console games",
"I don't know",
]
# Add the phrases and vectors to the table
table.add([{"text": p} for p in phrases])
# Create a fts index
table.create_fts_index("text")
return table, MyTable
## These tests are pretty loose, we should also check for correctness
def test_linear_combination(tmp_path):
table, schema = get_test_table(tmp_path)
# The default reranker
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score")
.to_pydantic(schema)
)
result2 = ( # noqa
table.search("Our father who art in heaven.", query_type="hybrid")
.rerank(normalize="rank")
.to_pydantic(schema)
)
result3 = table.search(
"Our father who art in heaven..", query_type="hybrid"
).to_pydantic(schema)
assert result1 == result3 # 2 & 3 should be the same as they use score as score
result = (
table.search("Our father who art in heaven", query_type="hybrid")
.limit(50)
.rerank(normalize="score")
.to_arrow()
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
@pytest.mark.skipif(
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
)
def test_cohere_reranker(tmp_path):
pytest.importorskip("cohere")
table, schema = get_test_table(tmp_path)
# The default reranker
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=CohereReranker())
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="rank", reranker=CohereReranker())
.to_pydantic(schema)
)
assert result1 == result2
result = (
table.search("Our father who art in heaven", query_type="hybrid")
.limit(50)
.rerank(reranker=CohereReranker())
.to_arrow()
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
def test_cross_encoder_reranker(tmp_path):
pytest.importorskip("sentence_transformers")
table, schema = get_test_table(tmp_path)
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=CrossEncoderReranker())
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="rank", reranker=CrossEncoderReranker())
.to_pydantic(schema)
)
assert result1 == result2
result = (
table.search("Our father who art in heaven", query_type="hybrid")
.limit(50)
.rerank(reranker=CrossEncoderReranker())
.to_arrow()
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)

View File

@@ -682,3 +682,57 @@ def test_count_rows(db):
assert len(table) == 2
assert table.count_rows() == 2
assert table.count_rows(filter="text='bar'") == 1
def test_hybrid_search(db):
# hardcoding temporarily.. this test is failing with tmp_path mockdb.
# Probably not being parsed right by the fts
db = MockDB("~/lancedb_")
# Create a LanceDB table schema with a vector and a text column
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
class MyTable(LanceModel):
text: str = emb.SourceField()
vector: Vector(emb.ndims()) = emb.VectorField()
# Initialize the table using the schema
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
)
# Create a list of 10 unique english phrases
phrases = [
"great kid don't get cocky",
"now that's a name I haven't heard in a long time",
"if you strike me down I shall become more powerful than you imagine",
"I find your lack of faith disturbing",
"I've got a bad feeling about this",
"never tell me the odds",
"I am your father",
"somebody has to save our skins",
"New strategy R2 let the wookiee win",
"Arrrrggghhhhhhh",
]
# Add the phrases and vectors to the table
table.add([{"text": p} for p in phrases])
# Create a fts index
table.create_fts_index("text")
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score")
.to_pydantic(MyTable)
)
result2 = ( # noqa
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="rank")
.to_pydantic(MyTable)
)
result3 = table.search(
"Our father who art in heaven", query_type="hybrid"
).to_pydantic(MyTable)
assert result1 == result3