fix(python): various fixes for async query builders (#2048)

This includes several improvements and fixes to the Python Async query
builders:

1. The API reference docs show all the methods for each builder
2. The hybrid query builder now has all the same setter methods as the
vector search one, so you can now set things like `.distance_type()` on
a hybrid query.
3. Re-rankers are now properly hooked up and tested for FTS and vector
search. Previously the re-rankers were accidentally bypassed in unit
tests, because the builders overrode `.to_arrow()`, but the unit test
called `.to_batches()` which was only defined in the base class. Now all
builders implement `.to_batches()` and leave `.to_arrow()` to the base
class.
4. The `AsyncQueryBase` and `AsyncVectorQueryBase` setter methods now
return `Self`, which provides the appropriate subclass as the type hint
return value. Previously, `AsyncQueryBase` had them all hard-coded to
`AsyncQuery`, which was unfortunate. (This required bringing in
`typing-extensions` for older Python versions, but I think it's worth
it.)
This commit is contained in:
Will Jones
2025-01-20 16:14:34 -08:00
committed by GitHub
parent 214d0debf5
commit bcfc93cc88
9 changed files with 153 additions and 73 deletions

View File

@@ -7,6 +7,7 @@ from pathlib import Path
import lancedb
from lancedb.index import IvfPq, FTS
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
import numpy as np
import pandas.testing as tm
import pyarrow as pa
@@ -515,15 +516,24 @@ async def test_query_async(table_async: AsyncTable):
expected_columns=["id", "vector", "_rowid"],
)
@pytest.mark.asyncio
@pytest.mark.slow
async def test_query_reranked_async(table_async: AsyncTable):
# FTS with rerank
await table_async.create_index("text", config=FTS(with_position=False))
await check_query(
table_async.query().nearest_to_text("dog").rerank(),
table_async.query().nearest_to_text("dog").rerank(CrossEncoderReranker()),
expected_num_rows=1,
)
# Vector query with rerank
await check_query(table_async.vector_search([1, 2]).rerank(), expected_num_rows=2)
await check_query(
table_async.vector_search([1, 2]).rerank(
CrossEncoderReranker(), query_string="dog"
),
expected_num_rows=2,
)
@pytest.mark.asyncio