mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 06:39:57 +00:00
fix(python): various fixes for async query builders (#2048)
This includes several improvements and fixes to the Python Async query builders: 1. The API reference docs show all the methods for each builder 2. The hybrid query builder now has all the same setter methods as the vector search one, so you can now set things like `.distance_type()` on a hybrid query. 3. Re-rankers are now properly hooked up and tested for FTS and vector search. Previously the re-rankers were accidentally bypassed in unit tests, because the builders overrode `.to_arrow()`, but the unit test called `.to_batches()` which was only defined in the base class. Now all builders implement `.to_batches()` and leave `.to_arrow()` to the base class. 4. The `AsyncQueryBase` and `AsyncVectoryQueryBase` setter methods now return `Self`, which provides the appropriate subclass as the type hint return value. Previously, `AsyncQueryBase` had them all hard-coded to `AsyncQuery`, which was unfortunate. (This required bringing in `typing-extensions` for older Python version, but I think it's worth it.)
This commit is contained in:
@@ -7,6 +7,7 @@ from pathlib import Path
|
||||
|
||||
import lancedb
|
||||
from lancedb.index import IvfPq, FTS
|
||||
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
|
||||
import numpy as np
|
||||
import pandas.testing as tm
|
||||
import pyarrow as pa
|
||||
@@ -515,15 +516,24 @@ async def test_query_async(table_async: AsyncTable):
|
||||
expected_columns=["id", "vector", "_rowid"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.slow
|
||||
async def test_query_reranked_async(table_async: AsyncTable):
|
||||
# FTS with rerank
|
||||
await table_async.create_index("text", config=FTS(with_position=False))
|
||||
await check_query(
|
||||
table_async.query().nearest_to_text("dog").rerank(),
|
||||
table_async.query().nearest_to_text("dog").rerank(CrossEncoderReranker()),
|
||||
expected_num_rows=1,
|
||||
)
|
||||
|
||||
# Vector query with rerank
|
||||
await check_query(table_async.vector_search([1, 2]).rerank(), expected_num_rows=2)
|
||||
await check_query(
|
||||
table_async.vector_search([1, 2]).rerank(
|
||||
CrossEncoderReranker(), query_string="dog"
|
||||
),
|
||||
expected_num_rows=2,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user