feat(python): support .rerank() on non-hybrid queries in Async API (WIP) (#1972)

Fixes https://github.com/lancedb/lancedb/issues/1950

---------

Co-authored-by: Renato Marroquin <renato.marroquin@oracle.com>
This commit is contained in:
Renato Marroquin
2025-01-08 22:42:47 +01:00
committed by GitHub
parent c557e77f09
commit ea5c2266b8
2 changed files with 59 additions and 9 deletions

View File

@@ -1802,10 +1802,22 @@ class AsyncFTSQuery(AsyncQueryBase):
def __init__(self, inner: LanceFTSQuery):
super().__init__(inner)
self._inner = inner
self._reranker = None
def get_query(self):
self._inner.get_query()
def rerank(
self,
reranker: Reranker = RRFReranker(),
) -> AsyncFTSQuery:
if reranker and not isinstance(reranker, Reranker):
raise ValueError("reranker must be an instance of Reranker class.")
self._reranker = reranker
return self
def nearest_to(
self,
query_vector: Union[VEC, Tuple, List[VEC]],
@@ -1876,6 +1888,12 @@ class AsyncFTSQuery(AsyncQueryBase):
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
)
async def to_arrow(self) -> pa.Table:
results = await super().to_arrow()
if self._reranker:
results = self._reranker.rerank_fts(results)
return results
class AsyncVectorQuery(AsyncQueryBase):
def __init__(self, inner: LanceVectorQuery):
@@ -1890,6 +1908,7 @@ class AsyncVectorQuery(AsyncQueryBase):
"""
super().__init__(inner)
self._inner = inner
self._reranker = None
def column(self, column: str) -> AsyncVectorQuery:
"""
@@ -2035,6 +2054,16 @@ class AsyncVectorQuery(AsyncQueryBase):
self._inner.bypass_vector_index()
return self
def rerank(
self, reranker: Reranker = RRFReranker(), query_string: Optional[str] = None
) -> AsyncHybridQuery:
if reranker and not isinstance(reranker, Reranker):
raise ValueError("reranker must be an instance of Reranker class.")
self._reranker = reranker
return self
def nearest_to_text(
self, query: str, columns: Union[str, List[str]] = []
) -> AsyncHybridQuery:
@@ -2068,6 +2097,12 @@ class AsyncVectorQuery(AsyncQueryBase):
self._inner.nearest_to_text({"query": query, "columns": columns})
)
async def to_arrow(self) -> pa.Table:
results = await super().to_arrow()
if self._reranker:
results = self._reranker.rerank_vector(results)
return results
class AsyncHybridQuery(AsyncQueryBase):
"""

View File

@@ -6,14 +6,18 @@ from datetime import timedelta
from pathlib import Path
import lancedb
from lancedb.index import IvfPq
from lancedb.index import IvfPq, FTS
import numpy as np
import pandas.testing as tm
import pyarrow as pa
import pytest
import pytest_asyncio
from lancedb.pydantic import LanceModel, Vector
from lancedb.query import AsyncQueryBase, LanceVectorQueryBuilder, Query
from lancedb.query import (
AsyncQueryBase,
LanceVectorQueryBuilder,
Query,
)
from lancedb.table import AsyncTable, LanceTable
@@ -47,6 +51,7 @@ async def table_async(tmp_path) -> AsyncTable:
"id": pa.array([1, 2]),
"str_field": pa.array(["a", "b"]),
"float_field": pa.array([1.0, 2.0]),
"text": pa.array(["a", "dog"]),
}
)
return await conn.create_table("test", data)
@@ -314,7 +319,7 @@ async def test_query_async(table_async: AsyncTable):
await check_query(
table_async.query(),
expected_num_rows=2,
expected_columns=["vector", "id", "str_field", "float_field"],
expected_columns=["vector", "id", "str_field", "float_field", "text"],
)
await check_query(table_async.query().where("id = 2"), expected_num_rows=1)
await check_query(
@@ -383,32 +388,42 @@ async def test_query_async(table_async: AsyncTable):
expected_columns=["id", "vector", "_rowid"],
)
# FTS with rerank
await table_async.create_index("text", config=FTS(with_position=False))
await check_query(
table_async.query().nearest_to_text("dog").rerank(),
expected_num_rows=1,
)
# Vector query with rerank
await check_query(table_async.vector_search([1, 2]).rerank(), expected_num_rows=2)
@pytest.mark.asyncio
async def test_query_to_arrow_async(table_async: AsyncTable):
table = await table_async.to_arrow()
assert table.num_rows == 2
assert table.num_columns == 4
assert table.num_columns == 5
table = await table_async.query().to_arrow()
assert table.num_rows == 2
assert table.num_columns == 4
assert table.num_columns == 5
table = await table_async.query().where("id < 0").to_arrow()
assert table.num_rows == 0
assert table.num_columns == 4
assert table.num_columns == 5
@pytest.mark.asyncio
async def test_query_to_pandas_async(table_async: AsyncTable):
df = await table_async.to_pandas()
assert df.shape == (2, 4)
assert df.shape == (2, 5)
df = await table_async.query().to_pandas()
assert df.shape == (2, 4)
assert df.shape == (2, 5)
df = await table_async.query().where("id < 0").to_pandas()
assert df.shape == (0, 4)
assert df.shape == (0, 5)
@pytest.mark.asyncio