diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py
index 81363a1b..2d337f05 100644
--- a/python/python/lancedb/query.py
+++ b/python/python/lancedb/query.py
@@ -1802,10 +1802,41 @@ class AsyncFTSQuery(AsyncQueryBase):
     def __init__(self, inner: LanceFTSQuery):
         super().__init__(inner)
         self._inner = inner
+        self._reranker = None
 
     def get_query(self):
-        self._inner.get_query()
+        return self._inner.get_query()
 
+    def rerank(
+        self,
+        reranker: Reranker = RRFReranker(),
+    ) -> AsyncFTSQuery:
+        """
+        Rerank the full-text search results when the query is executed.
+
+        Parameters
+        ----------
+        reranker: Reranker, default RRFReranker()
+            Reranker applied to the results returned by `to_arrow`.
+            Passing None turns reranking off.
+
+        Returns
+        -------
+        AsyncFTSQuery
+            This query, so that calls can be chained.
+
+        Raises
+        ------
+        ValueError
+            If `reranker` is neither None nor a `Reranker` instance.
+        """
+        if reranker is not None and not isinstance(reranker, Reranker):
+            raise ValueError("reranker must be an instance of Reranker class.")
+
+        self._reranker = reranker
+
+        return self
+
     def nearest_to(
         self,
         query_vector: Union[VEC, Tuple, List[VEC]],
@@ -1876,6 +1907,18 @@ class AsyncFTSQuery(AsyncQueryBase):
             self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
         )
 
+    async def to_arrow(self) -> pa.Table:
+        """
+        Execute the query and return the results as a pyarrow Table.
+
+        If a reranker was set with `rerank`, the results are passed through
+        its `rerank_fts` method before they are returned.
+        """
+        results = await super().to_arrow()
+        if self._reranker is not None:
+            results = self._reranker.rerank_fts(self.get_query(), results)
+        return results
+
 
 class AsyncVectorQuery(AsyncQueryBase):
     def __init__(self, inner: LanceVectorQuery):
@@ -1890,6 +1933,8 @@ class AsyncVectorQuery(AsyncQueryBase):
         """
         super().__init__(inner)
         self._inner = inner
+        self._reranker = None
+        self._query_string = None
 
     def column(self, column: str) -> AsyncVectorQuery:
         """
@@ -2035,6 +2080,38 @@ class AsyncVectorQuery(AsyncQueryBase):
         self._inner.bypass_vector_index()
         return self
 
+    def rerank(
+        self, reranker: Reranker = RRFReranker(), query_string: Optional[str] = None
+    ) -> AsyncVectorQuery:
+        """
+        Rerank the vector search results when the query is executed.
+
+        Parameters
+        ----------
+        reranker: Reranker, default RRFReranker()
+            Reranker applied to the results returned by `to_arrow`.
+            Passing None turns reranking off.
+        query_string: str, optional
+            Query text handed to the reranker together with the results.
+
+        Returns
+        -------
+        AsyncVectorQuery
+            This query, so that calls can be chained.
+
+        Raises
+        ------
+        ValueError
+            If `reranker` is neither None nor a `Reranker` instance.
+        """
+        if reranker is not None and not isinstance(reranker, Reranker):
+            raise ValueError("reranker must be an instance of Reranker class.")
+
+        self._reranker = reranker
+        self._query_string = query_string
+
+        return self
+
     def nearest_to_text(
         self, query: str, columns: Union[str, List[str]] = []
     ) -> AsyncHybridQuery:
@@ -2068,6 +2145,18 @@ class AsyncVectorQuery(AsyncQueryBase):
             self._inner.nearest_to_text({"query": query, "columns": columns})
         )
 
+    async def to_arrow(self) -> pa.Table:
+        """
+        Execute the query and return the results as a pyarrow Table.
+
+        If a reranker was set with `rerank`, the results are passed through
+        its `rerank_vector` method before they are returned.
+        """
+        results = await super().to_arrow()
+        if self._reranker is not None:
+            results = self._reranker.rerank_vector(self._query_string, results)
+        return results
+
 
 class AsyncHybridQuery(AsyncQueryBase):
     """
diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py
index cc796545..3f651fad 100644
--- a/python/python/tests/test_query.py
+++ b/python/python/tests/test_query.py
@@ -6,14 +6,18 @@ from datetime import timedelta
 from pathlib import Path
 
 import lancedb
-from lancedb.index import IvfPq
+from lancedb.index import IvfPq, FTS
 import numpy as np
 import pandas.testing as tm
 import pyarrow as pa
 import pytest
 import pytest_asyncio
 from lancedb.pydantic import LanceModel, Vector
-from lancedb.query import AsyncQueryBase, LanceVectorQueryBuilder, Query
+from lancedb.query import (
+    AsyncQueryBase,
+    LanceVectorQueryBuilder,
+    Query,
+)
 from lancedb.table import AsyncTable, LanceTable
 
 
@@ -47,6 +51,7 @@ async def table_async(tmp_path) -> AsyncTable:
             "id": pa.array([1, 2]),
             "str_field": pa.array(["a", "b"]),
             "float_field": pa.array([1.0, 2.0]),
+            "text": pa.array(["a", "dog"]),
         }
     )
     return await conn.create_table("test", data)
@@ -314,7 +319,7 @@ async def test_query_async(table_async: AsyncTable):
     await check_query(
         table_async.query(),
         expected_num_rows=2,
-        expected_columns=["vector", "id", "str_field", "float_field"],
+        expected_columns=["vector", "id", "str_field", "float_field", "text"],
     )
     await check_query(table_async.query().where("id = 2"), expected_num_rows=1)
     await check_query(
@@ -383,32 +388,42 @@ async def test_query_async(table_async: AsyncTable):
         expected_columns=["id", "vector", "_rowid"],
     )
 
+    # FTS with rerank
+    await table_async.create_index("text", config=FTS(with_position=False))
+    await check_query(
+        table_async.query().nearest_to_text("dog").rerank(),
+        expected_num_rows=1,
+    )
+
+    # Vector query with rerank
+    await check_query(table_async.vector_search([1, 2]).rerank(), expected_num_rows=2)
+
 
 @pytest.mark.asyncio
 async def test_query_to_arrow_async(table_async: AsyncTable):
     table = await table_async.to_arrow()
     assert table.num_rows == 2
-    assert table.num_columns == 4
+    assert table.num_columns == 5
 
     table = await table_async.query().to_arrow()
     assert table.num_rows == 2
-    assert table.num_columns == 4
+    assert table.num_columns == 5
 
     table = await table_async.query().where("id < 0").to_arrow()
     assert table.num_rows == 0
-    assert table.num_columns == 4
+    assert table.num_columns == 5
 
 
 @pytest.mark.asyncio
 async def test_query_to_pandas_async(table_async: AsyncTable):
     df = await table_async.to_pandas()
-    assert df.shape == (2, 4)
+    assert df.shape == (2, 5)
 
     df = await table_async.query().to_pandas()
-    assert df.shape == (2, 4)
+    assert df.shape == (2, 5)
 
     df = await table_async.query().where("id < 0").to_pandas()
-    assert df.shape == (0, 4)
+    assert df.shape == (0, 5)
 
 
 @pytest.mark.asyncio