feat(python): support .rerank() on non-hybrid queries in Async API (WIP) (#1972)

Fixes https://github.com/lancedb/lancedb/issues/1950

---------

Co-authored-by: Renato Marroquin <renato.marroquin@oracle.com>
This commit is contained in:
Renato Marroquin
2025-01-08 22:42:47 +01:00
committed by GitHub
parent c557e77f09
commit ea5c2266b8
2 changed files with 59 additions and 9 deletions

View File

@@ -6,14 +6,18 @@ from datetime import timedelta
from pathlib import Path
import lancedb
from lancedb.index import IvfPq
from lancedb.index import IvfPq, FTS
import numpy as np
import pandas.testing as tm
import pyarrow as pa
import pytest
import pytest_asyncio
from lancedb.pydantic import LanceModel, Vector
from lancedb.query import AsyncQueryBase, LanceVectorQueryBuilder, Query
from lancedb.query import (
AsyncQueryBase,
LanceVectorQueryBuilder,
Query,
)
from lancedb.table import AsyncTable, LanceTable
@@ -47,6 +51,7 @@ async def table_async(tmp_path) -> AsyncTable:
"id": pa.array([1, 2]),
"str_field": pa.array(["a", "b"]),
"float_field": pa.array([1.0, 2.0]),
"text": pa.array(["a", "dog"]),
}
)
return await conn.create_table("test", data)
@@ -314,7 +319,7 @@ async def test_query_async(table_async: AsyncTable):
await check_query(
table_async.query(),
expected_num_rows=2,
expected_columns=["vector", "id", "str_field", "float_field"],
expected_columns=["vector", "id", "str_field", "float_field", "text"],
)
await check_query(table_async.query().where("id = 2"), expected_num_rows=1)
await check_query(
@@ -383,32 +388,42 @@ async def test_query_async(table_async: AsyncTable):
expected_columns=["id", "vector", "_rowid"],
)
# FTS with rerank
await table_async.create_index("text", config=FTS(with_position=False))
await check_query(
table_async.query().nearest_to_text("dog").rerank(),
expected_num_rows=1,
)
# Vector query with rerank
await check_query(table_async.vector_search([1, 2]).rerank(), expected_num_rows=2)
@pytest.mark.asyncio
async def test_query_to_arrow_async(table_async: AsyncTable):
table = await table_async.to_arrow()
assert table.num_rows == 2
assert table.num_columns == 4
assert table.num_columns == 5
table = await table_async.query().to_arrow()
assert table.num_rows == 2
assert table.num_columns == 4
assert table.num_columns == 5
table = await table_async.query().where("id < 0").to_arrow()
assert table.num_rows == 0
assert table.num_columns == 4
assert table.num_columns == 5
@pytest.mark.asyncio
async def test_query_to_pandas_async(table_async: AsyncTable):
df = await table_async.to_pandas()
assert df.shape == (2, 4)
assert df.shape == (2, 5)
df = await table_async.query().to_pandas()
assert df.shape == (2, 4)
assert df.shape == (2, 5)
df = await table_async.query().where("id < 0").to_pandas()
assert df.shape == (0, 4)
assert df.shape == (0, 5)
@pytest.mark.asyncio