mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 07:09:57 +00:00
feat(python): support .rerank() on non-hybrid queries in Async API (WIP) (#1972)
Fixes https://github.com/lancedb/lancedb/issues/1950 --------- Co-authored-by: Renato Marroquin <renato.marroquin@oracle.com>
This commit is contained in:
@@ -6,14 +6,18 @@ from datetime import timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import lancedb
|
||||
from lancedb.index import IvfPq
|
||||
from lancedb.index import IvfPq, FTS
|
||||
import numpy as np
|
||||
import pandas.testing as tm
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.query import AsyncQueryBase, LanceVectorQueryBuilder, Query
|
||||
from lancedb.query import (
|
||||
AsyncQueryBase,
|
||||
LanceVectorQueryBuilder,
|
||||
Query,
|
||||
)
|
||||
from lancedb.table import AsyncTable, LanceTable
|
||||
|
||||
|
||||
@@ -47,6 +51,7 @@ async def table_async(tmp_path) -> AsyncTable:
|
||||
"id": pa.array([1, 2]),
|
||||
"str_field": pa.array(["a", "b"]),
|
||||
"float_field": pa.array([1.0, 2.0]),
|
||||
"text": pa.array(["a", "dog"]),
|
||||
}
|
||||
)
|
||||
return await conn.create_table("test", data)
|
||||
@@ -314,7 +319,7 @@ async def test_query_async(table_async: AsyncTable):
|
||||
await check_query(
|
||||
table_async.query(),
|
||||
expected_num_rows=2,
|
||||
expected_columns=["vector", "id", "str_field", "float_field"],
|
||||
expected_columns=["vector", "id", "str_field", "float_field", "text"],
|
||||
)
|
||||
await check_query(table_async.query().where("id = 2"), expected_num_rows=1)
|
||||
await check_query(
|
||||
@@ -383,32 +388,42 @@ async def test_query_async(table_async: AsyncTable):
|
||||
expected_columns=["id", "vector", "_rowid"],
|
||||
)
|
||||
|
||||
# FTS with rerank
|
||||
await table_async.create_index("text", config=FTS(with_position=False))
|
||||
await check_query(
|
||||
table_async.query().nearest_to_text("dog").rerank(),
|
||||
expected_num_rows=1,
|
||||
)
|
||||
|
||||
# Vector query with rerank
|
||||
await check_query(table_async.vector_search([1, 2]).rerank(), expected_num_rows=2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_to_arrow_async(table_async: AsyncTable):
|
||||
table = await table_async.to_arrow()
|
||||
assert table.num_rows == 2
|
||||
assert table.num_columns == 4
|
||||
assert table.num_columns == 5
|
||||
|
||||
table = await table_async.query().to_arrow()
|
||||
assert table.num_rows == 2
|
||||
assert table.num_columns == 4
|
||||
assert table.num_columns == 5
|
||||
|
||||
table = await table_async.query().where("id < 0").to_arrow()
|
||||
assert table.num_rows == 0
|
||||
assert table.num_columns == 4
|
||||
assert table.num_columns == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_to_pandas_async(table_async: AsyncTable):
|
||||
df = await table_async.to_pandas()
|
||||
assert df.shape == (2, 4)
|
||||
assert df.shape == (2, 5)
|
||||
|
||||
df = await table_async.query().to_pandas()
|
||||
assert df.shape == (2, 4)
|
||||
assert df.shape == (2, 5)
|
||||
|
||||
df = await table_async.query().where("id < 0").to_pandas()
|
||||
assert df.shape == (0, 4)
|
||||
assert df.shape == (0, 5)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user