diff --git a/.github/workflows/docs_test.yml b/.github/workflows/docs_test.yml
index f99cadaf..72d995c6 100644
--- a/.github/workflows/docs_test.yml
+++ b/.github/workflows/docs_test.yml
@@ -49,6 +49,9 @@ jobs:
   test-node:
     name: Test doc nodejs code
     runs-on: "ubuntu-latest"
+    timeout-minutes: 45
+    strategy:
+      fail-fast: false
     steps:
       - name: Checkout
         uses: actions/checkout@v4
@@ -66,6 +69,12 @@ jobs:
         uses: swatinem/rust-cache@v2
       - name: Install node dependencies
         run: |
+          sudo swapoff -a
+          sudo fallocate -l 8G /swapfile
+          sudo chmod 600 /swapfile
+          sudo mkswap /swapfile
+          sudo swapon /swapfile
+          sudo swapon --show
           cd node
           npm ci
           npm run build-release
diff --git a/docs/src/hybrid_search.md b/docs/src/hybrid_search.md
index 282b0d60..c6d26656 100644
--- a/docs/src/hybrid_search.md
+++ b/docs/src/hybrid_search.md
@@ -130,6 +130,60 @@ Arguments
 ----------------
 Only returns `_relevance_score`. Does not support `return_score = "all"`.
 
+### ColBERT Reranker
+This reranker uses the ColBERT model to combine the results of semantic and full-text search. You can use it by passing `ColbertReranker()` to the `rerank()` method.
+
+The ColBERT reranker calculates the relevance of the given documents against the query and does not take the existing FTS and vector search scores into account, so it currently only supports `return_score="relevance"`. By default, it uses the `text` column as input to the model, but you can specify a different column as described below.
+
+```python
+from lancedb.rerankers import ColbertReranker
+
+reranker = ColbertReranker()
+
+results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
+```
+
+Arguments
+----------------
+* `model_name` : `str`, default `"colbert-ir/colbertv2.0"`
+    The name of the ColBERT model to use.
+* `column` : `str`, default `"text"`
+    The name of the column to use as input to the model.
+* `return_score` : `str`, default `"relevance"`
+    Options are `"relevance"` or `"all"`. Only `"relevance"` is supported for now.
+
+!!! Note
+    Only returns `_relevance_score`. Does not support `return_score = "all"`.
+
+### OpenAI Reranker
+This reranker uses the OpenAI API to combine the results of semantic and full-text search. You can use it by passing `OpenaiReranker()` to the `rerank()` method.
+
+!!! Note
+    This prompts a chat model to rerank the results; it is not a dedicated reranker API and should be treated as experimental.
+
+!!! Tip
+    You might exceed the model's token limit, so set the search `limit` according to your token budget (see the example below).
+
+```python
+from lancedb.rerankers import OpenaiReranker
+
+reranker = OpenaiReranker()
+
+results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
+```
+
+Arguments
+----------------
+* `model_name` : `str`, default `"gpt-3.5-turbo-1106"`
+    The name of the chat model to use.
+* `column` : `str`, default `"text"`
+    The name of the column to use as input to the model.
+* `return_score` : `str`, default `"relevance"`
+    Options are `"relevance"` or `"all"`. Only `"relevance"` is supported for now.
+* `api_key` : `str`, default `None`
+    The API key to use. If None, will use the OPENAI_API_KEY environment variable.
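+
+For example, you can bound the prompt size by reducing the number of candidates sent to the model (a minimal sketch, using the same table as above):
+
+```python
+results = (
+    table.search("harmony hall", query_type="hybrid")
+    .limit(10)  # fewer candidates means fewer tokens in the reranking prompt
+    .rerank(reranker=OpenaiReranker())
+    .to_pandas()
+)
+```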
+
 
 ## Building Custom Rerankers
 You can build your own custom reranker by subclassing the `Reranker` class and implementing the `rerank_hybrid()` method. Here's an example of a custom reranker that combines the results of semantic and full-text search using a linear combination of the scores.
@@ -146,7 +200,7 @@ class MyReranker(Reranker):
         self.param1 = param1
         self.param2 = param2
 
-    def rerank_hybrid(self, vector_results: pa.Table, fts_results: pa.Table):
+    def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table):
         # Use the built-in merging function
         combined_result = self.merge_results(vector_results, fts_results)
 
@@ -168,7 +222,7 @@ import pyarrow as pa
 class MyReranker(Reranker):
     ...
 
-    def rerank_hybrid(self, vector_results: pa.Table, fts_results: pa.Table, filter: str):
+    def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table, filter: str):
         # Use the built-in merging function
         combined_result = self.merge_results(vector_results, fts_results)
 
diff --git a/python/lancedb/query.py b/python/lancedb/query.py
index ef28eac9..8ce7b5f2 100644
--- a/python/lancedb/query.py
+++ b/python/lancedb/query.py
@@ -626,7 +626,6 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
     def __init__(self, table: "Table", query: str, vector_column: str):
         super().__init__(table)
         self._validate_fts_index()
-        self._query = query
         vector_query, fts_query = self._validate_query(query)
         self._fts_query = LanceFtsQueryBuilder(table, fts_query)
         vector_query = self._query_to_vector(table, vector_query, vector_column)
@@ -679,12 +678,18 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
         # rerankers might need to preserve this score to support `return_score="all"`
         fts_results = self._normalize_scores(fts_results, "score")
 
-        results = self._reranker.rerank_hybrid(self, vector_results, fts_results)
+        results = self._reranker.rerank_hybrid(
+            self._fts_query._query, vector_results, fts_results
+        )
+
         if not isinstance(results, pa.Table):  # Enforce type
             raise TypeError(
                 f"rerank_hybrid must return a pyarrow.Table, got {type(results)}"
             )
+
+        # apply limit after reranking
+        results = results.slice(length=self._limit)
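+        # (pa.Table.slice is zero-copy and clamps `length` to the number of
+        # available rows, so this is safe even if the reranker returned fewer
+        # than `limit` rows.)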
+
         if not self._with_row_id:
             results = results.drop(["_rowid"])
         return results
@@ -776,6 +781,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
         """
         self._vector_query.limit(limit)
         self._fts_query.limit(limit)
+        self._limit = limit
+
         return self
 
     def select(self, columns: list) -> LanceHybridQueryBuilder:
diff --git a/python/lancedb/rerankers/__init__.py b/python/lancedb/rerankers/__init__.py
index 6d43636a..af833fd7 100644
--- a/python/lancedb/rerankers/__init__.py
+++ b/python/lancedb/rerankers/__init__.py
@@ -1,11 +1,15 @@
 from .base import Reranker
 from .cohere import CohereReranker
+from .colbert import ColbertReranker
 from .cross_encoder import CrossEncoderReranker
 from .linear_combination import LinearCombinationReranker
+from .openai import OpenaiReranker
 
 __all__ = [
     "Reranker",
     "CrossEncoderReranker",
     "CohereReranker",
     "LinearCombinationReranker",
+    "OpenaiReranker",
+    "ColbertReranker",
 ]
diff --git a/python/lancedb/rerankers/base.py b/python/lancedb/rerankers/base.py
index b1036ec5..96479dbd 100644
--- a/python/lancedb/rerankers/base.py
+++ b/python/lancedb/rerankers/base.py
@@ -1,12 +1,8 @@
-import typing
 from abc import ABC, abstractmethod
 
 import numpy as np
 import pyarrow as pa
 
-if typing.TYPE_CHECKING:
-    import lancedb
-
 
 class Reranker(ABC):
     def __init__(self, return_score: str = "relevance"):
@@ -30,7 +26,7 @@ class Reranker(ABC):
 
     @abstractmethod
     def rerank_hybrid(
-        query_builder: "lancedb.HybridQueryBuilder",
+        query: str,
         vector_results: pa.Table,
         fts_results: pa.Table,
     ):
@@ -41,8 +37,8 @@ class Reranker(ABC):
 
         Parameters
         ----------
-        query_builder : "lancedb.HybridQueryBuilder"
-            The query builder object that was used to generate the results
+        query : str
+            The input query
         vector_results : pa.Table
             The results from the vector search
         fts_results : pa.Table
@@ -50,36 +46,6 @@ class Reranker(ABC):
         """
         pass
 
-    def rerank_vector(
-        query_builder: "lancedb.VectorQueryBuilder", vector_results: pa.Table
-    ):
-        """
-        Rerank function receives the individual results from the vector search.
-        This isn't mandatory to implement
-
-        Parameters
-        ----------
-        query_builder : "lancedb.VectorQueryBuilder"
-            The query builder object that was used to generate the results
-        vector_results : pa.Table
-            The results from the vector search
-        """
-        raise NotImplementedError("Vector Reranking is not implemented")
-
-    def rerank_fts(query_builder: "lancedb.FTSQueryBuilder", fts_results: pa.Table):
-        """
-        Rerank function receives the individual results from the FTS search.
-        This isn't mandatory to implement
-
-        Parameters
-        ----------
-        query_builder : "lancedb.FTSQueryBuilder"
-            The query builder object that was used to generate the results
-        fts_results : pa.Table
-            The results from the FTS search
-        """
-        raise NotImplementedError("FTS Reranking is not implemented")
-
     def merge_results(self, vector_results: pa.Table, fts_results: pa.Table):
         """
         Merge the results from the vector and FTS search. This is a vanilla merging
diff --git a/python/lancedb/rerankers/cohere.py b/python/lancedb/rerankers/cohere.py
index 22363bc2..db5449e7 100644
--- a/python/lancedb/rerankers/cohere.py
+++ b/python/lancedb/rerankers/cohere.py
@@ -1,5 +1,4 @@
 import os
-import typing
 from functools import cached_property
 from typing import Union
 
@@ -8,9 +7,6 @@ import pyarrow as pa
 from ..util import safe_import
 from .base import Reranker
 
-if typing.TYPE_CHECKING:
-    import lancedb
-
 
 class CohereReranker(Reranker):
     """
@@ -55,14 +51,14 @@ class CohereReranker(Reranker):
 
     def rerank_hybrid(
         self,
-        query_builder: "lancedb.HybridQueryBuilder",
+        query: str,
         vector_results: pa.Table,
         fts_results: pa.Table,
     ):
         combined_results = self.merge_results(vector_results, fts_results)
         docs = combined_results[self.column].to_pylist()
         results = self._client.rerank(
-            query=query_builder._query,
+            query=query,
             documents=docs,
             top_n=self.top_n,
             model=self.model_name,
diff --git a/python/lancedb/rerankers/colbert.py b/python/lancedb/rerankers/colbert.py
new file mode 100644
index 00000000..308b7473
--- /dev/null
+++ b/python/lancedb/rerankers/colbert.py
@@ -0,0 +1,107 @@
+from functools import cached_property
+
+import pyarrow as pa
+
+from ..util import safe_import
+from .base import Reranker
+
+
+class ColbertReranker(Reranker):
+    """
+    Reranks the results using the ColBERT model.
+
+    Parameters
+    ----------
+    model_name : str, default "colbert-ir/colbertv2.0"
+        The name of the ColBERT model to use.
+    column : str, default "text"
+        The name of the column to use as input to the model.
+    return_score : str, default "relevance"
+        Options are "relevance" or "all". Only "relevance" is supported for now.
+    """
+
+    def __init__(
+        self,
+        model_name: str = "colbert-ir/colbertv2.0",
+        column: str = "text",
+        return_score="relevance",
+    ):
+        super().__init__(return_score)
+        self.model_name = model_name
+        self.column = column
+        self.torch = safe_import("torch")  # import here for faster ops later
+
+    def rerank_hybrid(
+        self,
+        query: str,
+        vector_results: pa.Table,
+        fts_results: pa.Table,
+    ):
+        combined_results = self.merge_results(vector_results, fts_results)
+        docs = combined_results[self.column].to_pylist()
+
+        tokenizer, model = self._model
+
+        # Encode the query
+        query_encoding = tokenizer(query, return_tensors="pt")
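+        # Mean-pool the token embeddings into a single query vector. (Note:
+        # canonical ColBERT keeps per-token query embeddings for MaxSim; the
+        # pooled query used here is a simpler approximation.)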
+ """ + + def __init__( + self, + model_name: str = "colbert-ir/colbertv2.0", + column: str = "text", + return_score="relevance", + ): + super().__init__(return_score) + self.model_name = model_name + self.column = column + self.torch = safe_import("torch") # import here for faster ops later + + def rerank_hybrid( + self, + query: str, + vector_results: pa.Table, + fts_results: pa.Table, + ): + combined_results = self.merge_results(vector_results, fts_results) + docs = combined_results[self.column].to_pylist() + + tokenizer, model = self._model + + # Encode the query + query_encoding = tokenizer(query, return_tensors="pt") + query_embedding = model(**query_encoding).last_hidden_state.mean(dim=1) + scores = [] + # Get score for each document + for document in docs: + document_encoding = tokenizer( + document, return_tensors="pt", truncation=True, max_length=512 + ) + document_embedding = model(**document_encoding).last_hidden_state + # Calculate MaxSim score + score = self.maxsim(query_embedding.unsqueeze(0), document_embedding) + scores.append(score.item()) + + # replace the self.column column with the docs + combined_results = combined_results.drop(self.column) + combined_results = combined_results.append_column( + self.column, pa.array(docs, type=pa.string()) + ) + # add the scores + combined_results = combined_results.append_column( + "_relevance_score", pa.array(scores, type=pa.float32()) + ) + if self.score == "relevance": + combined_results = combined_results.drop_columns(["score", "_distance"]) + elif self.score == "all": + raise NotImplementedError( + "OpenAI Reranker does not support score='all' yet" + ) + + combined_results = combined_results.sort_by( + [("_relevance_score", "descending")] + ) + + return combined_results + + @cached_property + def _model(self): + transformers = safe_import("transformers") + tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name) + model = transformers.AutoModel.from_pretrained(self.model_name) + + return tokenizer, model + + def maxsim(self, query_embedding, document_embedding): + # Expand dimensions for broadcasting + # Query: [batch, length, size] -> [batch, query, 1, size] + # Document: [batch, length, size] -> [batch, 1, length, size] + expanded_query = query_embedding.unsqueeze(2) + expanded_doc = document_embedding.unsqueeze(1) + + # Compute cosine similarity across the embedding dimension + sim_matrix = self.torch.nn.functional.cosine_similarity( + expanded_query, expanded_doc, dim=-1 + ) + + # Take the maximum similarity for each query token (across all document tokens) + # sim_matrix shape: [batch_size, query_length, doc_length] + max_sim_scores, _ = self.torch.max(sim_matrix, dim=2) + + # Average these maximum scores across all query tokens + avg_max_sim = self.torch.mean(max_sim_scores, dim=1) + return avg_max_sim diff --git a/python/lancedb/rerankers/cross_encoder.py b/python/lancedb/rerankers/cross_encoder.py index 4d7e1c42..08c89096 100644 --- a/python/lancedb/rerankers/cross_encoder.py +++ b/python/lancedb/rerankers/cross_encoder.py @@ -1,4 +1,3 @@ -import typing from functools import cached_property from typing import Union @@ -7,9 +6,6 @@ import pyarrow as pa from ..util import safe_import from .base import Reranker -if typing.TYPE_CHECKING: - import lancedb - class CrossEncoderReranker(Reranker): """ @@ -52,13 +48,13 @@ class CrossEncoderReranker(Reranker): def rerank_hybrid( self, - query_builder: "lancedb.HybridQueryBuilder", + query: str, vector_results: pa.Table, fts_results: pa.Table, ): combined_results = 
         passages = combined_results[self.column].to_pylist()
-        cross_inp = [[query_builder._query, passage] for passage in passages]
+        cross_inp = [[query, passage] for passage in passages]
         cross_scores = self.model.predict(cross_inp)
         combined_results = combined_results.append_column(
             "_relevance_score", pa.array(cross_scores, type=pa.float32())
diff --git a/python/lancedb/rerankers/linear_combination.py b/python/lancedb/rerankers/linear_combination.py
index d5032999..4f4110fa 100644
--- a/python/lancedb/rerankers/linear_combination.py
+++ b/python/lancedb/rerankers/linear_combination.py
@@ -36,7 +36,7 @@ class LinearCombinationReranker(Reranker):
 
     def rerank_hybrid(
         self,
-        query_builder: "lancedb.HybridQueryBuilder",  # noqa: F821
+        query: str,
         vector_results: pa.Table,
         fts_results: pa.Table,
     ):
diff --git a/python/lancedb/rerankers/openai.py b/python/lancedb/rerankers/openai.py
new file mode 100644
index 00000000..0e99beb0
--- /dev/null
+++ b/python/lancedb/rerankers/openai.py
@@ -0,0 +1,102 @@
+import json
+import os
+from functools import cached_property
+from typing import Optional
+
+import pyarrow as pa
+
+from ..util import safe_import
+from .base import Reranker
+
+
+class OpenaiReranker(Reranker):
+    """
+    Reranks the results using the OpenAI API.
+    WARNING: This is a prompt-based reranker that uses a chat model, not a
+    dedicated reranker API. It should be treated as experimental.
+
+    Parameters
+    ----------
+    model_name : str, default "gpt-3.5-turbo-1106"
+        The name of the chat model to use.
+    column : str, default "text"
+        The name of the column to use as input to the model.
+    return_score : str, default "relevance"
+        Options are "relevance" or "all". Only "relevance" is supported for now.
+    api_key : str, default None
+        The API key to use. If None, will use the OPENAI_API_KEY environment variable.
+    """
+
+    def __init__(
+        self,
+        model_name: str = "gpt-3.5-turbo-1106",
+        column: str = "text",
+        return_score="relevance",
+        api_key: Optional[str] = None,
+    ):
+        super().__init__(return_score)
+        self.model_name = model_name
+        self.column = column
+        self.api_key = api_key
+
+    def rerank_hybrid(
+        self,
+        query: str,
+        vector_results: pa.Table,
+        fts_results: pa.Table,
+    ):
+        combined_results = self.merge_results(vector_results, fts_results)
+        docs = combined_results[self.column].to_pylist()
+        response = self._client.chat.completions.create(
+            model=self.model_name,
+            response_format={"type": "json_object"},
+            temperature=0,
+            messages=[
+                {
+                    "role": "system",
+                    "content": (
+                        "You are an expert relevance ranker. Given a list of "
+                        "documents and a query, your job is to determine how "
+                        "relevant each document is for answering the query. "
+                        "Your output is JSON, which is a list of documents. "
+                        "Each document has two fields, content and "
+                        "relevance_score. relevance_score is from 0.0 to 1.0 "
+                        "indicating the relevance of the text to the given "
+                        "query. Make sure to include all documents in the "
+                        "response."
+                    ),
+                },
+                {"role": "user", "content": f"Query: {query} Docs: {docs}"},
+            ],
+        )
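+        # The system prompt asks for a JSON object with a top-level "documents"
+        # list of {content, relevance_score} objects; parse it and unzip into
+        # parallel doc/score tuples.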
+        results = json.loads(response.choices[0].message.content)["documents"]
+        docs, scores = list(
+            zip(*[(result["content"], result["relevance_score"]) for result in results])
+        )  # tuples
+        # replace the self.column column with the docs
+        combined_results = combined_results.drop(self.column)
+        combined_results = combined_results.append_column(
+            self.column, pa.array(docs, type=pa.string())
+        )
+        # add the scores
+        combined_results = combined_results.append_column(
+            "_relevance_score", pa.array(scores, type=pa.float32())
+        )
+        if self.score == "relevance":
+            combined_results = combined_results.drop_columns(["score", "_distance"])
+        elif self.score == "all":
+            raise NotImplementedError(
+                "OpenAI Reranker does not support score='all' yet"
+            )
+
+        combined_results = combined_results.sort_by(
+            [("_relevance_score", "descending")]
+        )
+
+        return combined_results
+
+    @cached_property
+    def _client(self):
+        openai = safe_import("openai")  # TODO: force version or handle versions < 1.0
+        if os.environ.get("OPENAI_API_KEY") is None and self.api_key is None:
+            raise ValueError(
+                "OPENAI_API_KEY not set. Either set it in your environment or "
+                "pass it as the `api_key` argument to the OpenaiReranker."
+            )
+        return openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY") or self.api_key)
diff --git a/python/lancedb/table.py b/python/lancedb/table.py
index 99a2da15..38f28a06 100644
--- a/python/lancedb/table.py
+++ b/python/lancedb/table.py
@@ -447,7 +447,7 @@ class Table(ABC):
             *default "vector"*
         query_type: str
             *default "auto"*.
-            Acceptable types are: "vector", "fts", or "auto"
+            Acceptable types are: "vector", "fts", "hybrid", or "auto"
 
             - If "auto" then the query type is inferred from the query;
diff --git a/python/tests/test_rerankers.py b/python/tests/test_rerankers.py
index 19be1f8e..5d28e412 100644
--- a/python/tests/test_rerankers.py
+++ b/python/tests/test_rerankers.py
@@ -7,7 +7,12 @@ import lancedb
 from lancedb.conftest import MockTextEmbeddingFunction  # noqa
 from lancedb.embeddings import EmbeddingFunctionRegistry
 from lancedb.pydantic import LanceModel, Vector
-from lancedb.rerankers import CohereReranker, CrossEncoderReranker
+from lancedb.rerankers import (
+    CohereReranker,
+    ColbertReranker,
+    CrossEncoderReranker,
+    OpenaiReranker,
+)
 from lancedb.table import LanceTable
 
 
@@ -75,7 +80,6 @@ def get_test_table(tmp_path):
     return table, MyTable
 
 
-## These tests are pretty loose, we should also check for correctness
 def test_linear_combination(tmp_path):
     table, schema = get_test_table(tmp_path)
     # The default reranker
@@ -95,14 +99,19 @@
     assert result1 == result3  # 2 & 3 should be the same as they use score as score
 
+    query = "Our father who art in heaven"
+    query_vector = table.to_pandas()["vector"][0]
     result = (
-        table.search("Our father who art in heaven", query_type="hybrid")
-        .limit(50)
+        table.search((query_vector, query))
+        .limit(30)
         .rerank(normalize="score")
         .to_arrow()
     )
+
+    assert len(result) == 30
+
     assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
-        "The _score column of the results returned by the reranker "
+        "The _relevance_score column of the results returned by the reranker "
         "represents the relevance of the result to the query & should "
         "be descending."
     )
@@ -122,19 +131,24 @@ def test_cohere_reranker(tmp_path):
     )
     result2 = (
         table.search("Our father who art in heaven", query_type="hybrid")
-        .rerank(normalize="rank", reranker=CohereReranker())
+        .rerank(reranker=CohereReranker())
         .to_pydantic(schema)
     )
     assert result1 == result2
 
+    query = "Our father who art in heaven"
+    query_vector = table.to_pandas()["vector"][0]
     result = (
-        table.search("Our father who art in heaven", query_type="hybrid")
-        .limit(50)
+        table.search((query_vector, query))
+        .limit(30)
         .rerank(reranker=CohereReranker())
         .to_arrow()
     )
+
+    assert len(result) == 30
+
     assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
-        "The _score column of the results returned by the reranker "
+        "The _relevance_score column of the results returned by the reranker "
         "represents the relevance of the result to the query & should "
         "be descending."
     )
@@ -150,19 +164,96 @@ def test_cross_encoder_reranker(tmp_path):
     )
     result2 = (
         table.search("Our father who art in heaven", query_type="hybrid")
-        .rerank(normalize="rank", reranker=CrossEncoderReranker())
+        .rerank(reranker=CrossEncoderReranker())
         .to_pydantic(schema)
     )
     assert result1 == result2
 
+    # test explicit hybrid query
+    query = "Our father who art in heaven"
+    query_vector = table.to_pandas()["vector"][0]
     result = (
-        table.search("Our father who art in heaven", query_type="hybrid")
-        .limit(50)
+        table.search((query_vector, query), query_type="hybrid")
+        .limit(30)
         .rerank(reranker=CrossEncoderReranker())
         .to_arrow()
     )
+
+    assert len(result) == 30
+
     assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
-        "The _score column of the results returned by the reranker "
+        "The _relevance_score column of the results returned by the reranker "
+        "represents the relevance of the result to the query & should "
+        "be descending."
+    )
+
+
+def test_colbert_reranker(tmp_path):
+    pytest.importorskip("transformers")
+    table, schema = get_test_table(tmp_path)
+    result1 = (
+        table.search("Our father who art in heaven", query_type="hybrid")
+        .rerank(normalize="score", reranker=ColbertReranker())
+        .to_pydantic(schema)
+    )
+    result2 = (
+        table.search("Our father who art in heaven", query_type="hybrid")
+        .rerank(reranker=ColbertReranker())
+        .to_pydantic(schema)
+    )
+    assert result1 == result2
+
+    # test explicit hybrid query
+    query = "Our father who art in heaven"
+    query_vector = table.to_pandas()["vector"][0]
+    result = (
+        table.search((query_vector, query))
+        .limit(30)
+        .rerank(reranker=ColbertReranker())
+        .to_arrow()
+    )
+
+    assert len(result) == 30
+
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+        "The _relevance_score column of the results returned by the reranker "
         "represents the relevance of the result to the query & should "
         "be descending."
     )
+
+
+@pytest.mark.skipif(
+    os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
+)
+def test_openai_reranker(tmp_path):
+    pytest.importorskip("openai")
+    table, schema = get_test_table(tmp_path)
+    result1 = (
+        table.search("Our father who art in heaven", query_type="hybrid")
+        .rerank(normalize="score", reranker=OpenaiReranker())
+        .to_pydantic(schema)
+    )
+    result2 = (
+        table.search("Our father who art in heaven", query_type="hybrid")
+        .rerank(reranker=OpenaiReranker())
+        .to_pydantic(schema)
+    )
+    assert result1 == result2
+
+    # test explicit hybrid query
+    query = "Our father who art in heaven"
+    query_vector = table.to_pandas()["vector"][0]
+    result = (
+        table.search((query_vector, query))
+        .limit(30)
+        .rerank(reranker=OpenaiReranker())
+        .to_arrow()
+    )
+
+    assert len(result) == 30
+
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+        "The _relevance_score column of the results returned by the reranker "
         "represents the relevance of the result to the query & should "
         "be descending."
     )