diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index b3f0d26a..ade3b7d5 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -13,9 +13,7 @@ import unittest.mock as mock from datetime import timedelta -from typing import Optional -import lance import lancedb from lancedb.index import IvfPq import numpy as np @@ -23,41 +21,15 @@ import pandas.testing as tm import pyarrow as pa import pytest import pytest_asyncio -from lancedb.db import LanceDBConnection from lancedb.pydantic import LanceModel, Vector from lancedb.query import AsyncQueryBase, LanceVectorQueryBuilder, Query from lancedb.table import AsyncTable, LanceTable -class MockTable: - def __init__(self, tmp_path): - self.uri = tmp_path - self._conn = LanceDBConnection(self.uri) - - def to_lance(self): - return lance.dataset(self.uri) - - def _execute_query(self, query, batch_size: Optional[int] = None): - ds = self.to_lance() - return ds.scanner( - columns=query.columns, - filter=query.filter, - prefilter=query.prefilter, - nearest={ - "column": query.vector_column, - "q": query.vector, - "k": query.k, - "metric": query.metric, - "nprobes": query.nprobes, - "refine_factor": query.refine_factor, - }, - batch_size=batch_size, - offset=query.offset, - ).to_reader() - - -@pytest.fixture -def table(tmp_path) -> MockTable: +@pytest.fixture(scope="module") +def table(tmpdir_factory) -> lancedb.table.Table: + tmp_path = str(tmpdir_factory.mktemp("data")) + db = lancedb.connect(tmp_path) df = pa.table( { "vector": pa.array( @@ -68,8 +40,7 @@ def table(tmp_path) -> MockTable: "float_field": pa.array([1.0, 2.0]), } ) - lance.write_dataset(df, tmp_path) - return MockTable(tmp_path) + return db.create_table("test", df) @pytest_asyncio.fixture @@ -126,6 +97,12 @@ def test_query_builder(table): assert all(np.array(rs[0]["vector"]) == [1, 2]) +def test_with_row_id(table: lancedb.table.Table): + rs = table.search().with_row_id(True).to_arrow() + assert "_rowid" in rs.column_names + assert rs["_rowid"].to_pylist() == [0, 1] + + def test_vector_query_with_no_limit(table): with pytest.raises(ValueError): LanceVectorQueryBuilder(table, [0, 0], "vector").limit(0).select(