mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 06:39:57 +00:00
@wjones127 Is there a standard way you set up your virtualenv? I can
either relist all the dependencies in the pyright pre-commit section,
specify a venv, or require the user to be in the virtual environment
when they run `git commit`. If the venv location were standardized, or a
Python manager like `uv` were used, it would be easier to avoid
duplicating the pyright dependency list.
Per your suggestion, I added all the passing files to the `includes`
section in `pyproject.toml`.
For ruff, I upgraded the version and removed "TCH", which doesn't exist
as an option.
I added a `pyright_report.csv` that lists all files sorted by pyright
error count (ascending), to serve as a to-do list.
I fixed about 30 issues in `table.py` that stemmed from plain `str`
values being passed into methods that require one of a fixed set of
string `Literal`s; the fix was to extract those `Literal` types into
`types.py`.
Can you verify in the Rust bridge that `schema` should be a property
and not a method here? If it's a method, then there's another place in
the code where `inner.schema` should be `inner.schema()`:
```python
class RecordBatchStream:
    @property
    def schema(self) -> pa.Schema: ...
```
Also, unless the `_lancedb.pyi` file is wrong, there is no `__anext__`
here for `_inner` when it's not an `AsyncGenerator`; only `next` is
defined:
```python
async def __anext__(self) -> pa.RecordBatch:
    return await self._inner.__anext__()
    if isinstance(self._inner, AsyncGenerator):
        batch = await self._inner.__anext__()
    else:
        batch = await self._inner.next()
    if batch is None:
        raise StopAsyncIteration
    return batch
```
In the `else` branch, `_inner` is a `RecordBatchStream`:
```python
class RecordBatchStream:
    @property
    def schema(self) -> pa.Schema: ...
    async def next(self) -> Optional[pa.RecordBatch]: ...
```
---------
Co-authored-by: Will Jones <willjones127@gmail.com>
135 lines
4.0 KiB
Python
135 lines
4.0 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
|
|
|
import lancedb
|
|
|
|
from lancedb.query import LanceHybridQueryBuilder
|
|
import pyarrow as pa
|
|
import pyarrow.compute as pc
|
|
import pytest
|
|
import pytest_asyncio
|
|
|
|
from lancedb.index import FTS
|
|
from lancedb.table import AsyncTable
|
|
|
|
|
|
@pytest_asyncio.fixture
async def table(tmpdir_factory) -> AsyncTable:
    """Build a four-row table named "test" with an FTS index on ``text``.

    Each row pairs a short text label with a 2-element float32 vector, so
    both the vector side and the full-text side of a hybrid query have
    data to match against.
    """
    db_uri = str(tmpdir_factory.mktemp("data"))
    db = await lancedb.connect_async(db_uri)

    texts = pa.array(["a", "b", "cat", "dog"])
    vectors = pa.array(
        [[0.1, 0.1], [2, 2], [-0.1, -0.1], [0.5, -0.5]],
        type=pa.list_(pa.float32(), list_size=2),
    )
    tbl = await db.create_table("test", pa.table({"text": texts, "vector": vectors}))

    # Full-text index on the text column, built without token positions.
    await tbl.create_index("text", config=FTS(with_position=False))
    return tbl
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_async_hybrid_query(table: AsyncTable):
    """A hybrid (vector + full-text) query returns the best joint matches.

    Also checks that the ``_rowid`` column is not included unless it is
    explicitly requested.
    """
    result = await (
        table.query().nearest_to([0.0, 0.4]).nearest_to_text("dog").limit(2).to_arrow()
    )
    assert len(result) == 2
    # ensure we get results that would match well for text and vector
    assert result["text"].to_pylist() == ["a", "dog"]

    # ensure there is no rowid by default. Check the column names explicitly:
    # `in` on a pa.Table is not a documented column-name lookup, and the
    # fallback sequence protocol compares column objects to the string, so
    # the old `"_rowid" not in result` check could never fail.
    assert "_rowid" not in result.column_names
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_async_hybrid_query_with_row_ids(table: AsyncTable):
    """A hybrid query honours ``with_row_id()`` and returns ``_rowid``."""
    query = (
        table.query()
        .nearest_to([0.0, 0.4])
        .nearest_to_text("dog")
        .limit(2)
        .with_row_id()
    )
    result = await query.to_arrow()

    # Two rows that score well on both the text and the vector side.
    assert len(result) == 2
    assert result["text"].to_pylist() == ["a", "dog"]
    # Row ids line up with the fixture rows "a" (index 0) and "dog" (index 3).
    assert result["_rowid"].to_pylist() == [0, 3]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_async_hybrid_query_filters(table: AsyncTable):
    """Parameters set on the hybrid builder reach both child searches.

    The ``where`` filter, ``distance_type`` and ``limit`` are configured once
    on the top-level builder; they must be forwarded to the vector and
    full-text child builders for the expected rows to come back.
    """
    builder = table.query().where("text not in ('a', 'dog')")
    builder = builder.nearest_to([0.3, 0.3]).nearest_to_text("*a*")
    builder = builder.distance_type("l2").limit(2)
    result = await builder.to_arrow()

    assert len(result) == 2
    # With "a" and "dog" filtered out, the best remaining joint matches.
    assert result["text"].to_pylist() == ["cat", "b"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_async_hybrid_query_default_limit(table: AsyncTable):
    """Without an explicit ``limit`` a hybrid query returns 10 rows."""
    # Add 100 new rows: the first two sit exactly on the query vector, and
    # the remaining 98 are pushed progressively further away.
    close_rows = [{"text": "close_vec", "vector": [0.1, 0.1]} for _ in range(2)]
    far_rows = [
        {"text": "far_vec", "vector": [5 * i, 5 * i]} for i in range(2, 100)
    ]
    await table.add(close_rows + far_rows)

    result = await (
        table.query().nearest_to_text("dog").nearest_to([0.1, 0.1]).to_arrow()
    )

    # assert we got the default limit of 10
    assert len(result) == 10

    # assert we got the closest vectors and the text searched for; the
    # fixture row "a" also has vector [0.1, 0.1], so it is expected as well.
    texts = result["text"].to_pylist()
    assert texts.count("close_vec") == 2
    assert texts.count("dog") == 1
    assert texts.count("a") == 1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_explain_plan(table: AsyncTable):
    """``explain_plan(True)`` describes both halves of a hybrid query."""
    hybrid = table.query().nearest_to_text("dog").nearest_to([0.1, 0.1])
    plan = await hybrid.explain_plan(True)

    # Every fragment must appear somewhere in the rendered plan text.
    expected_fragments = (
        "Vector Search Plan",
        "KNNVectorDistance",
        "FTS Search Plan",
        "LanceScan",
    )
    for fragment in expected_fragments:
        assert fragment in plan
|
|
|
|
|
|
def test_normalize_scores():
    """``_normalize_scores`` min-max scales scores, optionally inverted.

    Each case pairs raw scores with their expected normalized values; the
    inverted expectation is ``1 - expected``.
    """
    cases = [
        (pa.array([0.1, 0.4]), pa.array([0.0, 1.0])),
        (pa.array([2.0, 10.0, 20.0]), pa.array([0.0, 8.0 / 18.0, 1.0])),
        # constant input is expected to normalize to all zeros
        (pa.array([0.0, 0.0, 0.0]), pa.array([0.0, 0.0, 0.0])),
        # near-constant input is expected to be treated like constant input
        (pa.array([10.0, 9.9999999999999]), pa.array([0.0, 0.0])),
    ]

    for scores, expected in cases:
        for invert in [True, False]:
            result = LanceHybridQueryBuilder._normalize_scores(scores, invert)

            # Derive the per-iteration expectation without rebinding
            # ``expected``; rebinding it would leak the inverted values into
            # the following ``invert=False`` iteration.
            want = pc.subtract(1.0, expected) if invert else expected

            # Asserting the BooleanArray returned by ``pc.equal`` checks the
            # array's truthiness, not its elements, so compare values instead.
            assert result.to_pylist() == pytest.approx(want.to_pylist()), (
                f"Expected {want} but got {result} for invert={invert}"
            )
|