fix(python): typing (#2167)

@wjones127 is there a standard way you set up your virtualenvs? I can
either relist all the dependencies in the pyright pre-commit section,
specify a venv, or require the user to be in the virtual environment
when they run `git commit`. If the venv location were standardized, or a
Python manager like `uv` were used, it would be easier to avoid
duplicating the pyright dependency list.

Per your suggestion, I added all of the currently passing files to the
`include` section of `pyproject.toml`.

For Ruff, I upgraded the version (0.8.4 to 0.9.9) and removed the "TCH"
selector, which is no longer a valid option.

I added a `pyright_report.csv` which lists all files sorted by pyright
issue count, ascending, as a todo list to work through.
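
(For the record, a report like this can be generated from pyright's JSON output. The sketch below is illustrative rather than the exact script used; it assumes `pyright --outputjson` and its `generalDiagnostics` fields, which can vary a bit across pyright versions.)

```python
import csv
import json
import subprocess
from collections import Counter

# Run pyright and parse its machine-readable output (assumed available).
result = subprocess.run(["pyright", "--outputjson"], capture_output=True, text=True)
diagnostics = json.loads(result.stdout).get("generalDiagnostics", [])

errors: Counter = Counter()
warnings: Counter = Counter()
for diag in diagnostics:
    path = diag.get("file") or diag.get("uri", "")
    if diag.get("severity") == "error":
        errors[path] += 1
    elif diag.get("severity") == "warning":
        warnings[path] += 1

# Note: files with zero issues (the bulk of the checked-in report) would
# have to come from the pyright `include` list, not from the diagnostics.
files = sorted(set(errors) | set(warnings), key=lambda f: errors[f] + warnings[f])
with open("pyright_report.csv", "w", newline="") as fh:
    writer = csv.writer(fh)
    writer.writerow(["file", "errors", "warnings", "total_issues"])
    for f in files:
        writer.writerow([f, errors[f], warnings[f], errors[f] + warnings[f]])
```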

I fixed about 30 issues in `table.py` stemming from plain `str`s being
passed into methods that require one of a fixed set of string
`Literal`s, by extracting those `Literal` aliases into `types.py`.
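
The pattern, roughly, is (a minimal sketch using one of the new aliases):

```python
from typing import Literal

# One of the aliases extracted into types.py
OnBadVectorsType = Literal["error", "drop", "fill", "null"]


def add(data: list, on_bad_vectors: OnBadVectorsType = "error") -> None:
    """Toy stand-in for Table.add, just to show the annotation."""


add([], on_bad_vectors="drop")  # OK: a member of the Literal
add([], on_bad_vectors="skip")  # pyright error: "skip" is not assignable
```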

Can you verify against the Rust bridge that `schema` should be a
property and not a method here? If it's a method, then there's another
place in the code where `inner.schema` should be `inner.schema()`:
``` python
class RecordBatchStream:
    @property
    def schema(self) -> pa.Schema: ...
```

Also, unless the `_lancedb.pyi` file is wrong, there is no `__anext__`
on `_inner` when it isn't an `AsyncGenerator`; only `next` is defined,
so I dispatch on the type:
``` python
    async def __anext__(self) -> pa.RecordBatch:
        if isinstance(self._inner, AsyncGenerator):
            batch = await self._inner.__anext__()
        else:
            batch = await self._inner.next()
        if batch is None:
            raise StopAsyncIteration
        return batch
```
In the `else` branch, `_inner` is a `RecordBatchStream`:
```python
class RecordBatchStream:
    @property
    def schema(self) -> pa.Schema: ...
    async def next(self) -> Optional[pa.RecordBatch]: ...
```
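
For reference, this is the adapter shape the dispatch above reduces to once `_inner` is a `RecordBatchStream`: wrap a stream that only exposes `async next() -> Optional[batch]` so it satisfies Python's async-iterator protocol. A minimal self-contained sketch with hypothetical class names:

```python
from typing import Optional

import pyarrow as pa


class NextOnlyStream:
    """Stand-in (hypothetical) for a pyo3 stream exposing only `next()`."""

    def __init__(self, batches: list[pa.RecordBatch]):
        self._batches = list(batches)

    async def next(self) -> Optional[pa.RecordBatch]:
        # Mirrors RecordBatchStream.next: None signals exhaustion.
        return self._batches.pop(0) if self._batches else None


class BatchIterator:
    """Adapts a next()-style stream to the async-iterator protocol."""

    def __init__(self, inner: NextOnlyStream):
        self._inner = inner

    def __aiter__(self) -> "BatchIterator":
        return self

    async def __anext__(self) -> pa.RecordBatch:
        batch = await self._inner.next()
        if batch is None:
            # Translate the None sentinel into normal async-for termination.
            raise StopAsyncIteration
        return batch
```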

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
Authored by msu-reevo on 2025-03-10 09:01:23 -07:00; committed by GitHub.
parent bc49c4db82
commit cc81f3e1a5
16 changed files with 294 additions and 86 deletions


@@ -33,11 +33,41 @@ jobs:
           python-version: "3.12"
       - name: Install ruff
         run: |
-          pip install ruff==0.8.4
+          pip install ruff==0.9.9
       - name: Format check
         run: ruff format --check .
       - name: Lint
         run: ruff check .
+  type-check:
+    name: "Type Check"
+    timeout-minutes: 30
+    runs-on: "ubuntu-22.04"
+    defaults:
+      run:
+        shell: bash
+        working-directory: python
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+          lfs: true
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.12"
+      - name: Install protobuf compiler
+        run: |
+          sudo apt update
+          sudo apt install -y protobuf-compiler
+          pip install toml
+      - name: Install dependencies
+        run: |
+          python ../ci/parse_requirements.py pyproject.toml --extras dev,tests,embeddings > requirements.txt
+          pip install -r requirements.txt
+      - name: Run pyright
+        run: pyright
   doctest:
     name: "Doctest"
     timeout-minutes: 30


@@ -1,21 +1,27 @@
 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
   rev: v3.2.0
   hooks:
   - id: check-yaml
   - id: end-of-file-fixer
   - id: trailing-whitespace
 - repo: https://github.com/astral-sh/ruff-pre-commit
   # Ruff version.
-  rev: v0.8.4
+  rev: v0.9.9
   hooks:
   - id: ruff
+# - repo: https://github.com/RobertCraigie/pyright-python
+#   rev: v1.1.395
+#   hooks:
+#   - id: pyright
+#     args: ["--project", "python"]
+#     additional_dependencies: [pyarrow-stubs]
 - repo: local
   hooks:
   - id: local-biome-check
     name: biome check
     entry: npx @biomejs/biome@1.8.3 check --config-path nodejs/biome.json nodejs/
     language: system
     types: [text]
     files: "nodejs/.*"
     exclude: nodejs/lancedb/native.d.ts|nodejs/dist/.*|nodejs/examples/.*

ci/parse_requirements.py (new file, 41 lines)

@@ -0,0 +1,41 @@
import argparse

import toml


def parse_dependencies(pyproject_path, extras=None):
    with open(pyproject_path, "r") as file:
        pyproject = toml.load(file)

    dependencies = pyproject.get("project", {}).get("dependencies", [])
    for dependency in dependencies:
        print(dependency)

    optional_dependencies = pyproject.get("project", {}).get(
        "optional-dependencies", {}
    )
    if extras:
        for extra in extras.split(","):
            for dep in optional_dependencies.get(extra, []):
                print(dep)


def main():
    parser = argparse.ArgumentParser(
        description="Generate requirements.txt from pyproject.toml"
    )
    parser.add_argument("path", type=str, help="Path to pyproject.toml")
    parser.add_argument(
        "--extras",
        type=str,
        help="Comma-separated list of extras to include",
        default="",
    )
    args = parser.parse_args()
    parse_dependencies(args.path, args.extras)


if __name__ == "__main__":
    main()
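
The new type-check CI job above uses this script to materialize the dependency list before running pyright: `python ../ci/parse_requirements.py pyproject.toml --extras dev,tests,embeddings > requirements.txt`, followed by `pip install -r requirements.txt`.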

pyright_report.csv (new file, 56 lines)

@@ -0,0 +1,56 @@
file,errors,warnings,total_issues
python/python/lancedb/arrow.py,0,0,0
python/python/lancedb/background_loop.py,0,0,0
python/python/lancedb/embeddings/__init__.py,0,0,0
python/python/lancedb/exceptions.py,0,0,0
python/python/lancedb/index.py,0,0,0
python/python/lancedb/integrations/__init__.py,0,0,0
python/python/lancedb/remote/__init__.py,0,0,0
python/python/lancedb/remote/errors.py,0,0,0
python/python/lancedb/rerankers/__init__.py,0,0,0
python/python/lancedb/rerankers/answerdotai.py,0,0,0
python/python/lancedb/rerankers/cohere.py,0,0,0
python/python/lancedb/rerankers/colbert.py,0,0,0
python/python/lancedb/rerankers/cross_encoder.py,0,0,0
python/python/lancedb/rerankers/openai.py,0,0,0
python/python/lancedb/rerankers/util.py,0,0,0
python/python/lancedb/rerankers/voyageai.py,0,0,0
python/python/lancedb/schema.py,0,0,0
python/python/lancedb/types.py,0,0,0
python/python/lancedb/__init__.py,0,1,1
python/python/lancedb/conftest.py,1,0,1
python/python/lancedb/embeddings/bedrock.py,1,0,1
python/python/lancedb/merge.py,1,0,1
python/python/lancedb/rerankers/base.py,1,0,1
python/python/lancedb/rerankers/jinaai.py,0,1,1
python/python/lancedb/rerankers/linear_combination.py,1,0,1
python/python/lancedb/embeddings/instructor.py,2,0,2
python/python/lancedb/embeddings/openai.py,2,0,2
python/python/lancedb/embeddings/watsonx.py,2,0,2
python/python/lancedb/embeddings/registry.py,3,0,3
python/python/lancedb/embeddings/sentence_transformers.py,3,0,3
python/python/lancedb/integrations/pyarrow.py,3,0,3
python/python/lancedb/rerankers/rrf.py,3,0,3
python/python/lancedb/dependencies.py,4,0,4
python/python/lancedb/embeddings/gemini_text.py,4,0,4
python/python/lancedb/embeddings/gte.py,4,0,4
python/python/lancedb/embeddings/gte_mlx_model.py,4,0,4
python/python/lancedb/embeddings/ollama.py,4,0,4
python/python/lancedb/embeddings/transformers.py,4,0,4
python/python/lancedb/remote/db.py,5,0,5
python/python/lancedb/context.py,6,0,6
python/python/lancedb/embeddings/cohere.py,6,0,6
python/python/lancedb/fts.py,6,0,6
python/python/lancedb/db.py,9,0,9
python/python/lancedb/embeddings/utils.py,9,0,9
python/python/lancedb/common.py,11,0,11
python/python/lancedb/util.py,13,0,13
python/python/lancedb/embeddings/imagebind.py,14,0,14
python/python/lancedb/embeddings/voyageai.py,15,0,15
python/python/lancedb/embeddings/open_clip.py,16,0,16
python/python/lancedb/pydantic.py,16,0,16
python/python/lancedb/embeddings/base.py,17,0,17
python/python/lancedb/embeddings/jinaai.py,18,1,19
python/python/lancedb/remote/table.py,23,0,23
python/python/lancedb/query.py,47,1,48
python/python/lancedb/table.py,61,0,61


@@ -8,9 +8,9 @@ For general contribution guidelines, see [CONTRIBUTING.md](../CONTRIBUTING.md).
 The Python package is a wrapper around the Rust library, `lancedb`. We use
 [pyo3](https://pyo3.rs/) to create the bindings between Rust and Python.

-* `src/`: Rust bindings source code
-* `python/lancedb`: Python package source code
-* `python/tests`: Unit tests
+- `src/`: Rust bindings source code
+- `python/lancedb`: Python package source code
+- `python/tests`: Unit tests

 ## Development environment
@@ -61,6 +61,12 @@ make test
 make doctest
 ```

+Run type checking:
+
+```shell
+make typecheck
+```
+
 To run a single test, you can use the `pytest` command directly. Provide the path
 to the test file, and optionally the test name after `::`.


@@ -23,6 +23,10 @@ check: ## Check formatting and lints.
 fix: ## Fix python lints
 	ruff check python --fix

+.PHONY: typecheck
+typecheck: ## Run type checking with pyright.
+	pyright
+
 .PHONY: doctest
 doctest: ## Run documentation tests.
 	pytest --doctest-modules python/lancedb


@@ -92,7 +92,7 @@ requires = ["maturin>=1.4"]
 build-backend = "maturin"

 [tool.ruff.lint]
-select = ["F", "E", "W", "G", "TCH", "PERF"]
+select = ["F", "E", "W", "G", "PERF"]

 [tool.pytest.ini_options]
 addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py"
@@ -103,5 +103,28 @@ markers = [
 ]

 [tool.pyright]
-include = ["python/lancedb/table.py"]
+include = [
+    "python/lancedb/index.py",
+    "python/lancedb/rerankers/util.py",
+    "python/lancedb/rerankers/__init__.py",
+    "python/lancedb/rerankers/voyageai.py",
+    "python/lancedb/rerankers/jinaai.py",
+    "python/lancedb/rerankers/openai.py",
+    "python/lancedb/rerankers/cross_encoder.py",
+    "python/lancedb/rerankers/colbert.py",
+    "python/lancedb/rerankers/answerdotai.py",
+    "python/lancedb/rerankers/cohere.py",
+    "python/lancedb/arrow.py",
+    "python/lancedb/__init__.py",
+    "python/lancedb/types.py",
+    "python/lancedb/integrations/__init__.py",
+    "python/lancedb/exceptions.py",
+    "python/lancedb/background_loop.py",
+    "python/lancedb/schema.py",
+    "python/lancedb/remote/__init__.py",
+    "python/lancedb/remote/errors.py",
+    "python/lancedb/embeddings/__init__.py",
+    "python/lancedb/_lancedb.pyi",
+]
 exclude = ["python/tests/"]
 pythonVersion = "3.12"


@@ -14,6 +14,7 @@ from ._lancedb import connect as lancedb_connect
 from .common import URI, sanitize_uri
 from .db import AsyncConnection, DBConnection, LanceDBConnection
 from .remote import ClientConfig
+from .remote.db import RemoteDBConnection
 from .schema import vector
 from .table import AsyncTable
@@ -86,8 +87,6 @@ def connect(
     conn : DBConnection
         A connection to a LanceDB database.
     """
-    from .remote.db import RemoteDBConnection
     if isinstance(uri, str) and uri.startswith("db://"):
         if api_key is None:
             api_key = os.environ.get("LANCEDB_API_KEY")


@@ -3,6 +3,7 @@ from typing import Dict, List, Optional, Tuple, Any, Union, Literal
 import pyarrow as pa
 from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
+from .remote import ClientConfig

 class Connection(object):
     uri: str
@@ -71,11 +72,15 @@ async def connect(
     region: Optional[str],
     host_override: Optional[str],
     read_consistency_interval: Optional[float],
+    client_config: Optional[Union[ClientConfig, Dict[str, Any]]],
     storage_options: Optional[Dict[str, str]],
 ) -> Connection: ...

 class RecordBatchStream:
     @property
     def schema(self) -> pa.Schema: ...
     async def next(self) -> Optional[pa.RecordBatch]: ...
+    def __aiter__(self) -> "RecordBatchStream": ...
+    async def __anext__(self) -> pa.RecordBatch: ...

 class Query:
     def where(self, filter: str): ...


@@ -9,7 +9,8 @@ from typing import Any, Dict, Iterable, List, Optional, Union
 from urllib.parse import urlparse
 import warnings
-from lancedb import connect_async
+# Remove this import to fix circular dependency
+# from lancedb import connect_async
 from lancedb.remote import ClientConfig
 import pyarrow as pa
 from overrides import override
@@ -78,6 +79,9 @@ class RemoteDBConnection(DBConnection):
         self.client_config = client_config
+        # Import connect_async here to avoid circular import
+        from lancedb import connect_async
         self._conn = LOOP.run(
             connect_async(
                 db_url,


@@ -76,12 +76,21 @@ if TYPE_CHECKING:
     from .index import IndexConfig
     import pandas
     import PIL
+    from .types import (
+        QueryType,
+        OnBadVectorsType,
+        AddMode,
+        CreateMode,
+        VectorIndexType,
+        ScalarIndexType,
+        BaseTokenizerType,
+        DistanceType,
+    )

 pd = safe_import_pandas()
 pl = safe_import_polars()

-QueryType = Literal["vector", "fts", "hybrid", "auto"]

 def _into_pyarrow_reader(data) -> pa.RecordBatchReader:
     from lancedb.dependencies import datasets
@@ -178,7 +187,7 @@ def _sanitize_data(
     data: "DATA",
     target_schema: Optional[pa.Schema] = None,
     metadata: Optional[dict] = None,  # embedding metadata
-    on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error",
+    on_bad_vectors: OnBadVectorsType = "error",
     fill_value: float = 0.0,
     *,
     allow_subschema: bool = False,
@@ -324,7 +333,7 @@ def sanitize_create_table(
     data,
     schema: Union[pa.Schema, LanceModel],
     metadata=None,
-    on_bad_vectors: str = "error",
+    on_bad_vectors: OnBadVectorsType = "error",
     fill_value: float = 0.0,
 ):
     if inspect.isclass(schema) and issubclass(schema, LanceModel):
@@ -576,9 +585,7 @@ class Table(ABC):
         accelerator: Optional[str] = None,
         index_cache_size: Optional[int] = None,
         *,
-        index_type: Literal[
-            "IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
-        ] = "IVF_PQ",
+        index_type: VectorIndexType = "IVF_PQ",
         num_bits: int = 8,
         max_iterations: int = 50,
         sample_rate: int = 256,
@@ -643,7 +650,7 @@
         column: str,
         *,
         replace: bool = True,
-        index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE",
+        index_type: ScalarIndexType = "BTREE",
     ):
         """Create a scalar index on a column.
@@ -708,7 +715,7 @@
         tokenizer_name: Optional[str] = None,
         with_position: bool = True,
         # tokenizer configs:
-        base_tokenizer: Literal["simple", "raw", "whitespace"] = "simple",
+        base_tokenizer: BaseTokenizerType = "simple",
         language: str = "English",
         max_token_length: Optional[int] = 40,
         lower_case: bool = True,
@@ -777,8 +784,8 @@
     def add(
         self,
         data: DATA,
-        mode: str = "append",
-        on_bad_vectors: str = "error",
+        mode: AddMode = "append",
+        on_bad_vectors: OnBadVectorsType = "error",
         fill_value: float = 0.0,
     ):
         """Add more data to the [Table](Table).
@@ -960,7 +967,7 @@
         self,
         merge: LanceMergeInsertBuilder,
         new_data: DATA,
-        on_bad_vectors: str,
+        on_bad_vectors: OnBadVectorsType,
         fill_value: float,
     ): ...
@@ -1572,10 +1579,10 @@ class LanceTable(Table):
     def create_index(
         self,
-        metric="L2",
+        metric: DistanceType = "l2",
         num_partitions=None,
         num_sub_vectors=None,
-        vector_column_name=VECTOR_COLUMN_NAME,
+        vector_column_name: str = VECTOR_COLUMN_NAME,
         replace: bool = True,
         accelerator: Optional[str] = None,
         index_cache_size: Optional[int] = None,
@@ -1661,7 +1668,7 @@ class LanceTable(Table):
         column: str,
         *,
         replace: bool = True,
-        index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE",
+        index_type: ScalarIndexType = "BTREE",
     ):
         if index_type == "BTREE":
             config = BTree()
@@ -1686,7 +1693,7 @@ class LanceTable(Table):
         tokenizer_name: Optional[str] = None,
         with_position: bool = True,
         # tokenizer configs:
-        base_tokenizer: str = "simple",
+        base_tokenizer: BaseTokenizerType = "simple",
         language: str = "English",
         max_token_length: Optional[int] = 40,
         lower_case: bool = True,
@@ -1820,8 +1827,8 @@ class LanceTable(Table):
     def add(
         self,
         data: DATA,
-        mode: str = "append",
-        on_bad_vectors: str = "error",
+        mode: AddMode = "append",
+        on_bad_vectors: OnBadVectorsType = "error",
         fill_value: float = 0.0,
     ):
         """Add data to the table.
@@ -2059,7 +2066,7 @@ class LanceTable(Table):
             query_type,
             vector_column_name=vector_column_name,
             ordering_field_name=ordering_field_name,
-            fts_columns=fts_columns,
+            fts_columns=fts_columns or [],
         )

     @classmethod
@@ -2069,13 +2076,13 @@ class LanceTable(Table):
         name: str,
         data: Optional[DATA] = None,
         schema: Optional[pa.Schema] = None,
-        mode: Literal["create", "overwrite"] = "create",
+        mode: CreateMode = "create",
         exist_ok: bool = False,
-        on_bad_vectors: str = "error",
+        on_bad_vectors: OnBadVectorsType = "error",
         fill_value: float = 0.0,
         embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
         *,
-        storage_options: Optional[Dict[str, str]] = None,
+        storage_options: Optional[Dict[str, str | bool]] = None,
         data_storage_version: Optional[str] = None,
         enable_v2_manifest_paths: Optional[bool] = None,
     ):
@@ -2229,7 +2236,7 @@ class LanceTable(Table):
         self,
         merge: LanceMergeInsertBuilder,
         new_data: DATA,
-        on_bad_vectors: str,
+        on_bad_vectors: OnBadVectorsType,
         fill_value: float,
     ):
         LOOP.run(self._table._do_merge(merge, new_data, on_bad_vectors, fill_value))
@@ -2880,7 +2887,7 @@ class AsyncTable:
         data: DATA,
         *,
         mode: Optional[Literal["append", "overwrite"]] = "append",
-        on_bad_vectors: Optional[str] = None,
+        on_bad_vectors: Optional[OnBadVectorsType] = None,
         fill_value: Optional[float] = None,
     ):
         """Add more data to the [Table](Table).
@@ -2986,7 +2993,7 @@ class AsyncTable:
     @overload
     async def search(
         self,
-        query: Optional[Union[str]] = None,
+        query: Optional[str] = None,
         vector_column_name: Optional[str] = None,
         query_type: Literal["auto"] = ...,
         ordering_field_name: Optional[str] = None,
@@ -2996,7 +3003,7 @@ class AsyncTable:
     @overload
     async def search(
         self,
-        query: Optional[Union[str]] = None,
+        query: Optional[str] = None,
         vector_column_name: Optional[str] = None,
         query_type: Literal["hybrid"] = ...,
         ordering_field_name: Optional[str] = None,
@@ -3040,7 +3047,7 @@ class AsyncTable:
         query_type: QueryType = "auto",
         ordering_field_name: Optional[str] = None,
         fts_columns: Optional[Union[str, List[str]]] = None,
-    ) -> AsyncQuery:
+    ) -> Union[AsyncHybridQuery | AsyncFTSQuery | AsyncVectorQuery]:
         """Create a search query to find the nearest neighbors
         of the given query vector. We currently support [vector search][search]
         and [full-text search][experimental-full-text-search].
@@ -3279,7 +3286,7 @@ class AsyncTable:
         self,
         merge: LanceMergeInsertBuilder,
         new_data: DATA,
-        on_bad_vectors: str,
+        on_bad_vectors: OnBadVectorsType,
         fill_value: float,
     ):
         schema = await self.schema()


@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

from typing import Literal

# Query type literals
QueryType = Literal["vector", "fts", "hybrid", "auto"]

# Distance type literals
DistanceType = Literal["l2", "cosine", "dot"]
DistanceTypeWithHamming = Literal["l2", "cosine", "dot", "hamming"]

# Vector handling literals
OnBadVectorsType = Literal["error", "drop", "fill", "null"]

# Mode literals
AddMode = Literal["append", "overwrite"]
CreateMode = Literal["create", "overwrite"]

# Index type literals
VectorIndexType = Literal["IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"]
ScalarIndexType = Literal["BTREE", "BITMAP", "LABEL_LIST"]
IndexType = Literal[
    "IVF_PQ", "IVF_HNSW_PQ", "IVF_HNSW_SQ", "FTS", "BTREE", "BITMAP", "LABEL_LIST"
]

# Tokenizer literals
BaseTokenizerType = Literal["simple", "raw", "whitespace"]
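
With the aliases centralized here, the `table.py` changes above can replace repeated inline `Literal[...]` annotations with named types such as `OnBadVectorsType` and `ScalarIndexType`, and pyright will flag any call site passing a string outside the allowed set.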


@@ -419,17 +419,17 @@ def test_embedding_function_safe_model_dump(embedding_type):
     dumped_model = model.safe_model_dump()
-    assert all(
-        not k.startswith("_") for k in dumped_model.keys()
-    ), f"{embedding_type}: Dumped model contains keys starting with underscore"
+    assert all(not k.startswith("_") for k in dumped_model.keys()), (
+        f"{embedding_type}: Dumped model contains keys starting with underscore"
+    )
-    assert (
-        "max_retries" in dumped_model
-    ), f"{embedding_type}: Essential field 'max_retries' is missing from dumped model"
+    assert "max_retries" in dumped_model, (
+        f"{embedding_type}: Essential field 'max_retries' is missing from dumped model"
+    )
-    assert isinstance(
-        dumped_model, dict
-    ), f"{embedding_type}: Dumped model is not a dictionary"
+    assert isinstance(dumped_model, dict), (
+        f"{embedding_type}: Dumped model is not a dictionary"
+    )
     for key in model.__dict__:
         if key.startswith("_"):


@@ -129,6 +129,6 @@ def test_normalize_scores():
         if invert:
             expected = pc.subtract(1.0, expected)
-        assert pc.equal(
-            result, expected
-        ), f"Expected {expected} but got {result} for invert={invert}"
+        assert pc.equal(result, expected), (
+            f"Expected {expected} but got {result} for invert={invert}"
+        )


@@ -784,8 +784,7 @@ async def test_query_search_auto(mem_db_async: AsyncConnection):
     with pytest.raises(
         Exception,
         match=(
-            "Cannot perform full text search unless an INVERTED index has "
-            "been created"
+            "Cannot perform full text search unless an INVERTED index has been created"
         ),
     ):
         query = await (await tbl2.search("0.1")).to_arrow()


@@ -131,9 +131,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
         "represents the relevance of the result to the query & should "
         "be descending."
     )
-    assert np.all(
-        np.diff(result.column("_relevance_score").to_numpy()) <= 0
-    ), ascending_relevance_err
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+        ascending_relevance_err
+    )
     # Vector search setting
     result = (
@@ -143,9 +143,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
         .to_arrow()
     )
     assert len(result) == 30
-    assert np.all(
-        np.diff(result.column("_relevance_score").to_numpy()) <= 0
-    ), ascending_relevance_err
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+        ascending_relevance_err
+    )
     result_explicit = (
         table.search(query_vector, vector_column_name="vector")
         .rerank(reranker=reranker, query_string=query)
@@ -168,9 +168,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
         .to_arrow()
     )
     assert len(result) > 0
-    assert np.all(
-        np.diff(result.column("_relevance_score").to_numpy()) <= 0
-    ), ascending_relevance_err
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+        ascending_relevance_err
+    )
     # empty FTS results
     query = "abcxyz" * 100
@@ -185,9 +185,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
     # should return _relevance_score column
     assert "_relevance_score" in result.column_names
-    assert np.all(
-        np.diff(result.column("_relevance_score").to_numpy()) <= 0
-    ), ascending_relevance_err
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+        ascending_relevance_err
+    )
     # Multi-vector search setting
     rs1 = table.search(query, vector_column_name="vector").limit(10).with_row_id(True)
@@ -262,9 +262,9 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
         "represents the relevance of the result to the query & should "
         "be descending."
     )
-    assert np.all(
-        np.diff(result.column("_relevance_score").to_numpy()) <= 0
-    ), ascending_relevance_err
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+        ascending_relevance_err
+    )
     # Test with empty FTS results
     query = "abcxyz" * 100
@@ -278,9 +278,9 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
     )
     # should return _relevance_score column
     assert "_relevance_score" in result.column_names
-    assert np.all(
-        np.diff(result.column("_relevance_score").to_numpy()) <= 0
-    ), ascending_relevance_err
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+        ascending_relevance_err
+    )

 @pytest.mark.parametrize("use_tantivy", [True, False])