From ed7bd45c17034e37f697ef1fb861957e4039fde7 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Thu, 18 Jul 2024 21:04:59 +0530 Subject: [PATCH] chore: choose appropriate args for concat_table based on pyarrow version & refactor reranker tests (#1455) --- python/python/lancedb/conftest.py | 2 +- python/python/lancedb/rerankers/base.py | 12 +- python/python/tests/test_rerankers.py | 290 ++++++------------------ 3 files changed, 82 insertions(+), 222 deletions(-) diff --git a/python/python/lancedb/conftest.py b/python/python/lancedb/conftest.py index 273afbf7..7a6a5fd1 100644 --- a/python/python/lancedb/conftest.py +++ b/python/python/lancedb/conftest.py @@ -35,7 +35,7 @@ class MockTextEmbeddingFunction(TextEmbeddingFunction): def _compute_one_embedding(self, row): emb = np.array([float(hash(c)) for c in row[:10]]) emb /= np.linalg.norm(emb) - return emb + return emb if len(emb) == 10 else [0] * 10 def ndims(self): return 10 diff --git a/python/python/lancedb/rerankers/base.py b/python/python/lancedb/rerankers/base.py index d3881741..75a767f4 100644 --- a/python/python/lancedb/rerankers/base.py +++ b/python/python/lancedb/rerankers/base.py @@ -1,8 +1,11 @@ from abc import ABC, abstractmethod +from packaging.version import Version import numpy as np import pyarrow as pa +ARROW_VERSION = Version(pa.__version__) + class Reranker(ABC): def __init__(self, return_score: str = "relevance"): @@ -23,6 +26,11 @@ class Reranker(ABC): if return_score not in ["relevance", "all"]: raise ValueError("score must be either 'relevance' or 'all'") self.score = return_score + # Set the merge args based on the arrow version here to avoid checking it at + # each query + self._concat_tables_args = {"promote_options": "default"} + if ARROW_VERSION.major <= 13: + self._concat_tables_args = {"promote": True} def rerank_vector( self, @@ -119,7 +127,9 @@ class Reranker(ABC): fts_results : pa.Table The results from the FTS search """ - combined = pa.concat_tables([vector_results, fts_results], promote=True) + combined = pa.concat_tables( + [vector_results, fts_results], **self._concat_tables_args + ) row_id = combined.column("_rowid") # deduplicate diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py index 7775d598..1e303905 100644 --- a/python/python/tests/test_rerankers.py +++ b/python/python/tests/test_rerankers.py @@ -11,6 +11,7 @@ from lancedb.rerankers import ( ColbertReranker, CrossEncoderReranker, OpenaiReranker, + JinaReranker, ) from lancedb.table import LanceTable @@ -82,6 +83,63 @@ def get_test_table(tmp_path): return table, MyTable +def _run_test_reranker(reranker, table, query, query_vector, schema): + # Hybrid search setting + result1 = ( + table.search(query, query_type="hybrid") + .rerank(normalize="score", reranker=reranker) + .to_pydantic(schema) + ) + result2 = ( + table.search(query, query_type="hybrid") + .rerank(reranker=reranker) + .to_pydantic(schema) + ) + assert result1 == result2 + + query_vector = table.to_pandas()["vector"][0] + result = ( + table.search((query_vector, query)) + .limit(30) + .rerank(reranker=reranker) + .to_arrow() + ) + + assert len(result) == 30 + err = ( + "The _relevance_score column of the results returned by the reranker " + "represents the relevance of the result to the query & should " + "be descending." + ) + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + + # Vector search setting + result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow() + assert len(result) == 30 + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + result_explicit = ( + table.search(query_vector) + .rerank(reranker=reranker, query_string=query) + .limit(30) + .to_arrow() + ) + assert len(result_explicit) == 30 + with pytest.raises( + ValueError + ): # This raises an error because vector query is provided without reanking query + table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow() + + # FTS search setting + result = ( + table.search(query, query_type="fts") + .rerank(reranker=reranker) + .limit(30) + .to_arrow() + ) + assert len(result) > 0 + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + + def test_linear_combination(tmp_path): table, schema = get_test_table(tmp_path) # The default reranker @@ -126,185 +184,21 @@ def test_cohere_reranker(tmp_path): pytest.importorskip("cohere") reranker = CohereReranker() table, schema = get_test_table(tmp_path) - # Hybrid search setting - result1 = ( - table.search("Our father who art in heaven", query_type="hybrid") - .rerank(normalize="score", reranker=CohereReranker()) - .to_pydantic(schema) - ) - result2 = ( - table.search("Our father who art in heaven", query_type="hybrid") - .rerank(reranker=reranker) - .to_pydantic(schema) - ) - assert result1 == result2 - - query = "Our father who art in heaven" - query_vector = table.to_pandas()["vector"][0] - result = ( - table.search((query_vector, query)) - .limit(30) - .rerank(reranker=reranker) - .to_arrow() - ) - - assert len(result) == 30 - err = ( - "The _relevance_score column of the results returned by the reranker " - "represents the relevance of the result to the query & should " - "be descending." - ) - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err - - # Vector search setting - query = "Our father who art in heaven" - result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow() - assert len(result) == 30 - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err - result_explicit = ( - table.search(query_vector) - .rerank(reranker=reranker, query_string=query) - .limit(30) - .to_arrow() - ) - assert len(result_explicit) == 30 - with pytest.raises( - ValueError - ): # This raises an error because vector query is provided without reanking query - table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow() - - # FTS search setting - result = ( - table.search(query, query_type="fts") - .rerank(reranker=reranker) - .limit(30) - .to_arrow() - ) - assert len(result) > 0 - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + _run_test_reranker(reranker, table, "single player experience", None, schema) def test_cross_encoder_reranker(tmp_path): pytest.importorskip("sentence_transformers") reranker = CrossEncoderReranker() table, schema = get_test_table(tmp_path) - result1 = ( - table.search("Our father who art in heaven", query_type="hybrid") - .rerank(normalize="score", reranker=reranker) - .to_pydantic(schema) - ) - result2 = ( - table.search("Our father who art in heaven", query_type="hybrid") - .rerank(reranker=reranker) - .to_pydantic(schema) - ) - assert result1 == result2 - - query = "Our father who art in heaven" - query_vector = table.to_pandas()["vector"][0] - result = ( - table.search((query_vector, query), query_type="hybrid") - .limit(30) - .rerank(reranker=reranker) - .to_arrow() - ) - - assert len(result) == 30 - - err = ( - "The _relevance_score column of the results returned by the reranker " - "represents the relevance of the result to the query & should " - "be descending." - ) - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err - - # Vector search setting - result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow() - assert len(result) == 30 - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err - - result_explicit = ( - table.search(query_vector) - .rerank(reranker=reranker, query_string=query) - .limit(30) - .to_arrow() - ) - assert len(result_explicit) == 30 - with pytest.raises( - ValueError - ): # This raises an error because vector query is provided without reanking query - table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow() - - # FTS search setting - result = ( - table.search(query, query_type="fts") - .rerank(reranker=reranker) - .limit(30) - .to_arrow() - ) - assert len(result) > 0 - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + _run_test_reranker(reranker, table, "single player experience", None, schema) def test_colbert_reranker(tmp_path): pytest.importorskip("transformers") reranker = ColbertReranker() table, schema = get_test_table(tmp_path) - result1 = ( - table.search("Our father who art in heaven", query_type="hybrid") - .rerank(normalize="score", reranker=reranker) - .to_pydantic(schema) - ) - result2 = ( - table.search("Our father who art in heaven", query_type="hybrid") - .rerank(reranker=reranker) - .to_pydantic(schema) - ) - assert result1 == result2 - - # test explicit hybrid query - query = "Our father who art in heaven" - query_vector = table.to_pandas()["vector"][0] - result = ( - table.search((query_vector, query)) - .limit(30) - .rerank(reranker=reranker) - .to_arrow() - ) - - assert len(result) == 30 - err = ( - "The _relevance_score column of the results returned by the reranker " - "represents the relevance of the result to the query & should " - "be descending." - ) - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err - - # Vector search setting - result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow() - assert len(result) == 30 - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err - result_explicit = ( - table.search(query_vector) - .rerank(reranker=reranker, query_string=query) - .limit(30) - .to_arrow() - ) - assert len(result_explicit) == 30 - with pytest.raises( - ValueError - ): # This raises an error because vector query is provided without reanking query - table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow() - - # FTS search setting - result = ( - table.search(query, query_type="fts") - .rerank(reranker=reranker) - .limit(30) - .to_arrow() - ) - assert len(result) > 0 - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err + _run_test_reranker(reranker, table, "single player experience", None, schema) @pytest.mark.skipif( @@ -314,58 +208,14 @@ def test_openai_reranker(tmp_path): pytest.importorskip("openai") table, schema = get_test_table(tmp_path) reranker = OpenaiReranker() - result1 = ( - table.search("Our father who art in heaven", query_type="hybrid") - .rerank(normalize="score", reranker=reranker) - .to_pydantic(schema) - ) - result2 = ( - table.search("Our father who art in heaven", query_type="hybrid") - .rerank(reranker=OpenaiReranker()) - .to_pydantic(schema) - ) - assert result1 == result2 + _run_test_reranker(reranker, table, "single player experience", None, schema) - # test explicit hybrid query - query = "Our father who art in heaven" - query_vector = table.to_pandas()["vector"][0] - result = ( - table.search((query_vector, query)) - .limit(30) - .rerank(reranker=reranker) - .to_arrow() - ) - assert len(result) == 30 - - err = ( - "The _relevance_score column of the results returned by the reranker " - "represents the relevance of the result to the query & should " - "be descending." - ) - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err - - # Vector search setting - result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow() - assert len(result) == 30 - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err - result_explicit = ( - table.search(query_vector) - .rerank(reranker=reranker, query_string=query) - .limit(30) - .to_arrow() - ) - assert len(result_explicit) == 30 - with pytest.raises( - ValueError - ): # This raises an error because vector query is provided without reanking query - table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow() - # FTS search setting - result = ( - table.search(query, query_type="fts") - .rerank(reranker=reranker) - .limit(30) - .to_arrow() - ) - assert len(result) > 0 - assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err +@pytest.mark.skipif( + os.environ.get("JINA_API_KEY") is None, reason="JINA_API_KEY not set" +) +def test_jina_reranker(tmp_path): + pytest.importorskip("jina") + table, schema = get_test_table(tmp_path) + reranker = JinaReranker() + _run_test_reranker(reranker, table, "single player experience", None, schema)