mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 06:39:57 +00:00
feat: refactor the query API and add query support to the python async API (#1113)
In addition, this PR changes the docstrings of a number of existing nodejs methods, because it adds a jsdoc linter.
This commit is contained in:
@@ -12,16 +12,19 @@
|
||||
# limitations under the License.
|
||||
|
||||
import unittest.mock as mock
|
||||
from datetime import timedelta
|
||||
|
||||
import lance
|
||||
import lancedb
|
||||
import numpy as np
|
||||
import pandas.testing as tm
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.query import LanceVectorQueryBuilder, Query
|
||||
from lancedb.table import LanceTable
|
||||
from lancedb.query import AsyncQueryBase, LanceVectorQueryBuilder, Query
|
||||
from lancedb.table import AsyncTable, LanceTable
|
||||
|
||||
|
||||
class MockTable:
|
||||
@@ -65,6 +68,24 @@ def table(tmp_path) -> MockTable:
|
||||
return MockTable(tmp_path)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def table_async(tmp_path) -> AsyncTable:
    """Create the two-row async table named "test" used by the async query tests.

    The connection is opened on ``tmp_path`` with a zero-second
    ``read_consistency_interval``. The table has four columns: a fixed-size
    (2) float32 list column ``vector``, plus ``id``, ``str_field`` and
    ``float_field``.
    """
    db = await lancedb.connect_async(
        tmp_path, read_consistency_interval=timedelta(seconds=0)
    )
    vector_type = pa.list_(pa.float32(), list_size=2)
    columns = {
        "vector": pa.array([[1, 2], [3, 4]], type=vector_type),
        "id": pa.array([1, 2]),
        "str_field": pa.array(["a", "b"]),
        "float_field": pa.array([1.0, 2.0]),
    }
    return await db.create_table("test", pa.table(columns))
|
||||
|
||||
|
||||
def test_cast(table):
|
||||
class TestModel(LanceModel):
|
||||
vector: Vector(2)
|
||||
@@ -184,3 +205,109 @@ def test_query_builder_with_different_vector_column():
|
||||
|
||||
def cosine_distance(vec1, vec2):
    """Return the cosine distance between two vectors.

    Computed as one minus the dot product of the vectors divided by the
    product of their norms.
    """
    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    similarity = np.dot(vec1, vec2) / norm_product
    return 1 - similarity
|
||||
|
||||
|
||||
async def check_query(
    query: AsyncQueryBase, *, expected_num_rows=None, expected_columns=None
):
    """Run ``query`` and assert on the record batches it streams back.

    Parameters
    ----------
    query:
        The query to execute via ``to_batches()``.
    expected_num_rows:
        When not None, the total row count summed over all batches must
        equal this value.
    expected_columns:
        When not None, the schema names of every returned batch must equal
        this list. Nothing is checked if the query yields no batches.
    """
    total_rows = 0
    async for record_batch in await query.to_batches():
        if expected_columns is not None:
            assert record_batch.schema.names == expected_columns
        total_rows += record_batch.num_rows
    if expected_num_rows is None:
        return
    assert total_rows == expected_num_rows
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_async(table_async: AsyncTable):
    """Exercise the async query builder against the two-row fixture table.

    Covers plain scans, filters, column selection (by list and by
    name-to-expression mapping), limits, nearest-neighbour search with
    several vector input types, the vector-query tuning knobs, and an
    empty result.
    """
    all_columns = ["vector", "id", "str_field", "float_field"]

    def nearest():
        # Build a fresh vector query for each check.
        return table_async.query().nearest_to(pa.array([1, 2]))

    await check_query(
        table_async.query(), expected_num_rows=2, expected_columns=all_columns
    )

    filtered = table_async.query().where("id = 2")
    await check_query(filtered, expected_num_rows=1)

    projected = table_async.query().select(["id", "vector"])
    await check_query(projected, expected_columns=["id", "vector"])

    renamed = table_async.query().select({"foo": "id", "bar": "id + 1"})
    await check_query(renamed, expected_columns=["foo", "bar"])

    limited = table_async.query().limit(1)
    await check_query(limited, expected_num_rows=1)

    await check_query(nearest(), expected_num_rows=2)

    # The vector query should accept several kinds of input.
    vector_inputs = (
        [1, 2],
        [1.0, 2.0],
        np.array([1, 2]),
        (1, 2),
    )
    for vector_query in vector_inputs:
        knn = table_async.query().nearest_to(vector_query)
        await check_query(knn, expected_num_rows=2)

    # There is no easy way to verify that these vector query parameters do
    # what they say. We only check that they do not raise and assume the
    # behaviour is tested at a lower level.
    postfiltered = (
        table_async.query().where("id = 2").nearest_to(pa.array([1, 2])).postfilter()
    )
    await check_query(postfiltered, expected_num_rows=1)

    tuning_knobs = (
        lambda q: q.refine_factor(1),
        lambda q: q.nprobes(10),
        lambda q: q.bypass_vector_index(),
        lambda q: q.distance_type("dot"),
        lambda q: q.distance_type("DoT"),
    )
    for configure in tuning_knobs:
        await check_query(configure(nearest()), expected_num_rows=2)

    # A vector query must also work as a base query (e.g. `limit` can be
    # called on it); this doubles as a check that `vector_search` works.
    top_one = table_async.vector_search([1, 2]).limit(1)
    await check_query(top_one, expected_num_rows=1)

    # Finally, a query that matches nothing.
    no_match = table_async.query().where("id < 0")
    await check_query(no_match, expected_num_rows=0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_to_arrow_async(table_async: AsyncTable):
    """Check ``to_arrow`` on the table, on a plain query and on an empty query.

    All three results must have four columns; the first two have two rows
    and the filtered query (``id < 0``) has none.
    """
    whole_table = await table_async.to_arrow()
    assert (whole_table.num_rows, whole_table.num_columns) == (2, 4)

    unfiltered = await table_async.query().to_arrow()
    assert (unfiltered.num_rows, unfiltered.num_columns) == (2, 4)

    no_match = await table_async.query().where("id < 0").to_arrow()
    assert (no_match.num_rows, no_match.num_columns) == (0, 4)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_to_pandas_async(table_async: AsyncTable):
    """Check ``to_pandas`` on the table, on a plain query and on an empty query.

    The frame shape must be ``(2, 4)`` for the first two and ``(0, 4)`` for
    the query filtered with ``id < 0``.
    """
    full_shape = (2, 4)
    empty_shape = (0, 4)

    whole_frame = await table_async.to_pandas()
    assert whole_frame.shape == full_shape

    unfiltered_frame = await table_async.query().to_pandas()
    assert unfiltered_frame.shape == full_shape

    no_match_frame = await table_async.query().where("id < 0").to_pandas()
    assert no_match_frame.shape == empty_shape
|
||||
|
||||
Reference in New Issue
Block a user