mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 06:39:57 +00:00
test(python): test with_row_id in sync query (#1835)
Also remove weird `MockTable` fixture.
This commit is contained in:
@@ -13,9 +13,7 @@
|
||||
|
||||
import unittest.mock as mock
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
|
||||
import lance
|
||||
import lancedb
|
||||
from lancedb.index import IvfPq
|
||||
import numpy as np
|
||||
@@ -23,41 +21,15 @@ import pandas.testing as tm
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.query import AsyncQueryBase, LanceVectorQueryBuilder, Query
|
||||
from lancedb.table import AsyncTable, LanceTable
|
||||
|
||||
|
||||
class MockTable:
|
||||
def __init__(self, tmp_path):
|
||||
self.uri = tmp_path
|
||||
self._conn = LanceDBConnection(self.uri)
|
||||
|
||||
def to_lance(self):
|
||||
return lance.dataset(self.uri)
|
||||
|
||||
def _execute_query(self, query, batch_size: Optional[int] = None):
|
||||
ds = self.to_lance()
|
||||
return ds.scanner(
|
||||
columns=query.columns,
|
||||
filter=query.filter,
|
||||
prefilter=query.prefilter,
|
||||
nearest={
|
||||
"column": query.vector_column,
|
||||
"q": query.vector,
|
||||
"k": query.k,
|
||||
"metric": query.metric,
|
||||
"nprobes": query.nprobes,
|
||||
"refine_factor": query.refine_factor,
|
||||
},
|
||||
batch_size=batch_size,
|
||||
offset=query.offset,
|
||||
).to_reader()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def table(tmp_path) -> MockTable:
|
||||
@pytest.fixture(scope="module")
|
||||
def table(tmpdir_factory) -> lancedb.table.Table:
|
||||
tmp_path = str(tmpdir_factory.mktemp("data"))
|
||||
db = lancedb.connect(tmp_path)
|
||||
df = pa.table(
|
||||
{
|
||||
"vector": pa.array(
|
||||
@@ -68,8 +40,7 @@ def table(tmp_path) -> MockTable:
|
||||
"float_field": pa.array([1.0, 2.0]),
|
||||
}
|
||||
)
|
||||
lance.write_dataset(df, tmp_path)
|
||||
return MockTable(tmp_path)
|
||||
return db.create_table("test", df)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@@ -126,6 +97,12 @@ def test_query_builder(table):
|
||||
assert all(np.array(rs[0]["vector"]) == [1, 2])
|
||||
|
||||
|
||||
def test_with_row_id(table: lancedb.table.Table):
|
||||
rs = table.search().with_row_id(True).to_arrow()
|
||||
assert "_rowid" in rs.column_names
|
||||
assert rs["_rowid"].to_pylist() == [0, 1]
|
||||
|
||||
|
||||
def test_vector_query_with_no_limit(table):
|
||||
with pytest.raises(ValueError):
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(0).select(
|
||||
|
||||
Reference in New Issue
Block a user