mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-09 13:22:58 +00:00
feat(python): support .rerank() on non-hybrid queries in Async API (WIP) (#1972)
Fixes https://github.com/lancedb/lancedb/issues/1950 --------- Co-authored-by: Renato Marroquin <renato.marroquin@oracle.com>
This commit is contained in:
@@ -1802,10 +1802,22 @@ class AsyncFTSQuery(AsyncQueryBase):
|
||||
def __init__(self, inner: LanceFTSQuery):
|
||||
super().__init__(inner)
|
||||
self._inner = inner
|
||||
self._reranker = None
|
||||
|
||||
def get_query(self):
|
||||
self._inner.get_query()
|
||||
|
||||
def rerank(
|
||||
self,
|
||||
reranker: Reranker = RRFReranker(),
|
||||
) -> AsyncFTSQuery:
|
||||
if reranker and not isinstance(reranker, Reranker):
|
||||
raise ValueError("reranker must be an instance of Reranker class.")
|
||||
|
||||
self._reranker = reranker
|
||||
|
||||
return self
|
||||
|
||||
def nearest_to(
|
||||
self,
|
||||
query_vector: Union[VEC, Tuple, List[VEC]],
|
||||
@@ -1876,6 +1888,12 @@ class AsyncFTSQuery(AsyncQueryBase):
|
||||
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
|
||||
)
|
||||
|
||||
async def to_arrow(self) -> pa.Table:
|
||||
results = await super().to_arrow()
|
||||
if self._reranker:
|
||||
results = self._reranker.rerank_fts(results)
|
||||
return results
|
||||
|
||||
|
||||
class AsyncVectorQuery(AsyncQueryBase):
|
||||
def __init__(self, inner: LanceVectorQuery):
|
||||
@@ -1890,6 +1908,7 @@ class AsyncVectorQuery(AsyncQueryBase):
|
||||
"""
|
||||
super().__init__(inner)
|
||||
self._inner = inner
|
||||
self._reranker = None
|
||||
|
||||
def column(self, column: str) -> AsyncVectorQuery:
|
||||
"""
|
||||
@@ -2035,6 +2054,16 @@ class AsyncVectorQuery(AsyncQueryBase):
|
||||
self._inner.bypass_vector_index()
|
||||
return self
|
||||
|
||||
def rerank(
|
||||
self, reranker: Reranker = RRFReranker(), query_string: Optional[str] = None
|
||||
) -> AsyncHybridQuery:
|
||||
if reranker and not isinstance(reranker, Reranker):
|
||||
raise ValueError("reranker must be an instance of Reranker class.")
|
||||
|
||||
self._reranker = reranker
|
||||
|
||||
return self
|
||||
|
||||
def nearest_to_text(
|
||||
self, query: str, columns: Union[str, List[str]] = []
|
||||
) -> AsyncHybridQuery:
|
||||
@@ -2068,6 +2097,12 @@ class AsyncVectorQuery(AsyncQueryBase):
|
||||
self._inner.nearest_to_text({"query": query, "columns": columns})
|
||||
)
|
||||
|
||||
async def to_arrow(self) -> pa.Table:
|
||||
results = await super().to_arrow()
|
||||
if self._reranker:
|
||||
results = self._reranker.rerank_vector(results)
|
||||
return results
|
||||
|
||||
|
||||
class AsyncHybridQuery(AsyncQueryBase):
|
||||
"""
|
||||
|
||||
@@ -6,14 +6,18 @@ from datetime import timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import lancedb
|
||||
from lancedb.index import IvfPq
|
||||
from lancedb.index import IvfPq, FTS
|
||||
import numpy as np
|
||||
import pandas.testing as tm
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.query import AsyncQueryBase, LanceVectorQueryBuilder, Query
|
||||
from lancedb.query import (
|
||||
AsyncQueryBase,
|
||||
LanceVectorQueryBuilder,
|
||||
Query,
|
||||
)
|
||||
from lancedb.table import AsyncTable, LanceTable
|
||||
|
||||
|
||||
@@ -47,6 +51,7 @@ async def table_async(tmp_path) -> AsyncTable:
|
||||
"id": pa.array([1, 2]),
|
||||
"str_field": pa.array(["a", "b"]),
|
||||
"float_field": pa.array([1.0, 2.0]),
|
||||
"text": pa.array(["a", "dog"]),
|
||||
}
|
||||
)
|
||||
return await conn.create_table("test", data)
|
||||
@@ -314,7 +319,7 @@ async def test_query_async(table_async: AsyncTable):
|
||||
await check_query(
|
||||
table_async.query(),
|
||||
expected_num_rows=2,
|
||||
expected_columns=["vector", "id", "str_field", "float_field"],
|
||||
expected_columns=["vector", "id", "str_field", "float_field", "text"],
|
||||
)
|
||||
await check_query(table_async.query().where("id = 2"), expected_num_rows=1)
|
||||
await check_query(
|
||||
@@ -383,32 +388,42 @@ async def test_query_async(table_async: AsyncTable):
|
||||
expected_columns=["id", "vector", "_rowid"],
|
||||
)
|
||||
|
||||
# FTS with rerank
|
||||
await table_async.create_index("text", config=FTS(with_position=False))
|
||||
await check_query(
|
||||
table_async.query().nearest_to_text("dog").rerank(),
|
||||
expected_num_rows=1,
|
||||
)
|
||||
|
||||
# Vector query with rerank
|
||||
await check_query(table_async.vector_search([1, 2]).rerank(), expected_num_rows=2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_to_arrow_async(table_async: AsyncTable):
|
||||
table = await table_async.to_arrow()
|
||||
assert table.num_rows == 2
|
||||
assert table.num_columns == 4
|
||||
assert table.num_columns == 5
|
||||
|
||||
table = await table_async.query().to_arrow()
|
||||
assert table.num_rows == 2
|
||||
assert table.num_columns == 4
|
||||
assert table.num_columns == 5
|
||||
|
||||
table = await table_async.query().where("id < 0").to_arrow()
|
||||
assert table.num_rows == 0
|
||||
assert table.num_columns == 4
|
||||
assert table.num_columns == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_to_pandas_async(table_async: AsyncTable):
|
||||
df = await table_async.to_pandas()
|
||||
assert df.shape == (2, 4)
|
||||
assert df.shape == (2, 5)
|
||||
|
||||
df = await table_async.query().to_pandas()
|
||||
assert df.shape == (2, 4)
|
||||
assert df.shape == (2, 5)
|
||||
|
||||
df = await table_async.query().where("id < 0").to_pandas()
|
||||
assert df.shape == (0, 4)
|
||||
assert df.shape == (0, 5)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user