From cc81f3e1a520343426a088006b5db9dc9201e9f3 Mon Sep 17 00:00:00 2001 From: msu-reevo Date: Mon, 10 Mar 2025 09:01:23 -0700 Subject: [PATCH] fix(python): typing (#2167) @wjones127 is there a standard way you guys setup your virtualenv? I can either relist all the dependencies in the pyright precommit section, or specify a venv, or the user has to be in the virtual environment when they run git commit. If the venv location was standardized or a python manager like `uv` was used it would be easier to avoid duplicating the pyright dependency list. Per your suggestion, in `pyproject.toml` I added in all the passing files to the `includes` section. For ruff I upgraded the version and removed "TCH" which doesn't exist as an option. I added a `pyright_report.csv` which contains a list of all files sorted by pyright errors ascending as a todo list to work on. I fixed about 30 issues in `table.py` stemming from str's being passed into methods that required a string within a set of string Literals by extracting them into `types.py` Can you verify in the rust bridge that the schema should be a property and not a method here? If it's a method, then there's another place in the code where `inner.schema` should be `inner.schema()` ``` python class RecordBatchStream: @property def schema(self) -> pa.Schema: ... ``` Also unless the `_lancedb.pyi` file is wrong, then there is no `__anext__` here for `__inner` when it's not an `AsyncGenerator` and only `next` is defined: ``` python async def __anext__(self) -> pa.RecordBatch: return await self._inner.__anext__() if isinstance(self._inner, AsyncGenerator): batch = await self._inner.__anext__() else: batch = await self._inner.next() if batch is None: raise StopAsyncIteration return batch ``` in the else statement, `_inner` is a `RecordBatchStream` ```python class RecordBatchStream: @property def schema(self) -> pa.Schema: ... async def next(self) -> Optional[pa.RecordBatch]: ... ``` --------- Co-authored-by: Will Jones --- .github/workflows/python.yml | 32 +++++++++++- .pre-commit-config.yaml | 38 ++++++++------ ci/parse_requirements.py | 41 +++++++++++++++ pyright_report.csv | 56 +++++++++++++++++++++ python/CONTRIBUTING.md | 12 +++-- python/Makefile | 4 ++ python/pyproject.toml | 27 +++++++++- python/python/lancedb/__init__.py | 3 +- python/python/lancedb/_lancedb.pyi | 7 ++- python/python/lancedb/remote/db.py | 6 ++- python/python/lancedb/table.py | 63 +++++++++++++----------- python/python/lancedb/types.py | 28 +++++++++++ python/python/tests/test_embeddings.py | 18 +++---- python/python/tests/test_hybrid_query.py | 6 +-- python/python/tests/test_query.py | 3 +- python/python/tests/test_rerankers.py | 36 +++++++------- 16 files changed, 294 insertions(+), 86 deletions(-) create mode 100644 ci/parse_requirements.py create mode 100644 pyright_report.csv create mode 100644 python/python/lancedb/types.py diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index cdee5dbe..c30bac92 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -33,11 +33,41 @@ jobs: python-version: "3.12" - name: Install ruff run: | - pip install ruff==0.8.4 + pip install ruff==0.9.9 - name: Format check run: ruff format --check . - name: Lint run: ruff check . + + type-check: + name: "Type Check" + timeout-minutes: 30 + runs-on: "ubuntu-22.04" + defaults: + run: + shell: bash + working-directory: python + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + lfs: true + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + - name: Install protobuf compiler + run: | + sudo apt update + sudo apt install -y protobuf-compiler + pip install toml + - name: Install dependencies + run: | + python ../ci/parse_requirements.py pyproject.toml --extras dev,tests,embeddings > requirements.txt + pip install -r requirements.txt + - name: Run pyright + run: pyright + doctest: name: "Doctest" timeout-minutes: 30 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dd5eb5d9..bef53f90 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,21 +1,27 @@ repos: -- repo: https://github.com/pre-commit/pre-commit-hooks + - repo: https://github.com/pre-commit/pre-commit-hooks rev: v3.2.0 hooks: - - id: check-yaml - - id: end-of-file-fixer - - id: trailing-whitespace -- repo: https://github.com/astral-sh/ruff-pre-commit + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.8.4 + rev: v0.9.9 hooks: - - id: ruff -- repo: local - hooks: - - id: local-biome-check - name: biome check - entry: npx @biomejs/biome@1.8.3 check --config-path nodejs/biome.json nodejs/ - language: system - types: [text] - files: "nodejs/.*" - exclude: nodejs/lancedb/native.d.ts|nodejs/dist/.*|nodejs/examples/.* + - id: ruff + # - repo: https://github.com/RobertCraigie/pyright-python + # rev: v1.1.395 + # hooks: + # - id: pyright + # args: ["--project", "python"] + # additional_dependencies: [pyarrow-stubs] + - repo: local + hooks: + - id: local-biome-check + name: biome check + entry: npx @biomejs/biome@1.8.3 check --config-path nodejs/biome.json nodejs/ + language: system + types: [text] + files: "nodejs/.*" + exclude: nodejs/lancedb/native.d.ts|nodejs/dist/.*|nodejs/examples/.* diff --git a/ci/parse_requirements.py b/ci/parse_requirements.py new file mode 100644 index 00000000..4a1075a0 --- /dev/null +++ b/ci/parse_requirements.py @@ -0,0 +1,41 @@ +import argparse +import toml + + +def parse_dependencies(pyproject_path, extras=None): + with open(pyproject_path, "r") as file: + pyproject = toml.load(file) + + dependencies = pyproject.get("project", {}).get("dependencies", []) + for dependency in dependencies: + print(dependency) + + optional_dependencies = pyproject.get("project", {}).get( + "optional-dependencies", {} + ) + + if extras: + for extra in extras.split(","): + for dep in optional_dependencies.get(extra, []): + print(dep) + + +def main(): + parser = argparse.ArgumentParser( + description="Generate requirements.txt from pyproject.toml" + ) + parser.add_argument("path", type=str, help="Path to pyproject.toml") + parser.add_argument( + "--extras", + type=str, + help="Comma-separated list of extras to include", + default="", + ) + + args = parser.parse_args() + + parse_dependencies(args.path, args.extras) + + +if __name__ == "__main__": + main() diff --git a/pyright_report.csv b/pyright_report.csv new file mode 100644 index 00000000..01f4f46c --- /dev/null +++ b/pyright_report.csv @@ -0,0 +1,56 @@ +file,errors,warnings,total_issues +python/python/lancedb/arrow.py,0,0,0 +python/python/lancedb/background_loop.py,0,0,0 +python/python/lancedb/embeddings/__init__.py,0,0,0 +python/python/lancedb/exceptions.py,0,0,0 +python/python/lancedb/index.py,0,0,0 +python/python/lancedb/integrations/__init__.py,0,0,0 +python/python/lancedb/remote/__init__.py,0,0,0 +python/python/lancedb/remote/errors.py,0,0,0 +python/python/lancedb/rerankers/__init__.py,0,0,0 +python/python/lancedb/rerankers/answerdotai.py,0,0,0 +python/python/lancedb/rerankers/cohere.py,0,0,0 +python/python/lancedb/rerankers/colbert.py,0,0,0 +python/python/lancedb/rerankers/cross_encoder.py,0,0,0 +python/python/lancedb/rerankers/openai.py,0,0,0 +python/python/lancedb/rerankers/util.py,0,0,0 +python/python/lancedb/rerankers/voyageai.py,0,0,0 +python/python/lancedb/schema.py,0,0,0 +python/python/lancedb/types.py,0,0,0 +python/python/lancedb/__init__.py,0,1,1 +python/python/lancedb/conftest.py,1,0,1 +python/python/lancedb/embeddings/bedrock.py,1,0,1 +python/python/lancedb/merge.py,1,0,1 +python/python/lancedb/rerankers/base.py,1,0,1 +python/python/lancedb/rerankers/jinaai.py,0,1,1 +python/python/lancedb/rerankers/linear_combination.py,1,0,1 +python/python/lancedb/embeddings/instructor.py,2,0,2 +python/python/lancedb/embeddings/openai.py,2,0,2 +python/python/lancedb/embeddings/watsonx.py,2,0,2 +python/python/lancedb/embeddings/registry.py,3,0,3 +python/python/lancedb/embeddings/sentence_transformers.py,3,0,3 +python/python/lancedb/integrations/pyarrow.py,3,0,3 +python/python/lancedb/rerankers/rrf.py,3,0,3 +python/python/lancedb/dependencies.py,4,0,4 +python/python/lancedb/embeddings/gemini_text.py,4,0,4 +python/python/lancedb/embeddings/gte.py,4,0,4 +python/python/lancedb/embeddings/gte_mlx_model.py,4,0,4 +python/python/lancedb/embeddings/ollama.py,4,0,4 +python/python/lancedb/embeddings/transformers.py,4,0,4 +python/python/lancedb/remote/db.py,5,0,5 +python/python/lancedb/context.py,6,0,6 +python/python/lancedb/embeddings/cohere.py,6,0,6 +python/python/lancedb/fts.py,6,0,6 +python/python/lancedb/db.py,9,0,9 +python/python/lancedb/embeddings/utils.py,9,0,9 +python/python/lancedb/common.py,11,0,11 +python/python/lancedb/util.py,13,0,13 +python/python/lancedb/embeddings/imagebind.py,14,0,14 +python/python/lancedb/embeddings/voyageai.py,15,0,15 +python/python/lancedb/embeddings/open_clip.py,16,0,16 +python/python/lancedb/pydantic.py,16,0,16 +python/python/lancedb/embeddings/base.py,17,0,17 +python/python/lancedb/embeddings/jinaai.py,18,1,19 +python/python/lancedb/remote/table.py,23,0,23 +python/python/lancedb/query.py,47,1,48 +python/python/lancedb/table.py,61,0,61 diff --git a/python/CONTRIBUTING.md b/python/CONTRIBUTING.md index 135a597a..588399bc 100644 --- a/python/CONTRIBUTING.md +++ b/python/CONTRIBUTING.md @@ -8,9 +8,9 @@ For general contribution guidelines, see [CONTRIBUTING.md](../CONTRIBUTING.md). The Python package is a wrapper around the Rust library, `lancedb`. We use [pyo3](https://pyo3.rs/) to create the bindings between Rust and Python. -* `src/`: Rust bindings source code -* `python/lancedb`: Python package source code -* `python/tests`: Unit tests +- `src/`: Rust bindings source code +- `python/lancedb`: Python package source code +- `python/tests`: Unit tests ## Development environment @@ -61,6 +61,12 @@ make test make doctest ``` +Run type checking: + +```shell +make typecheck +``` + To run a single test, you can use the `pytest` command directly. Provide the path to the test file, and optionally the test name after `::`. diff --git a/python/Makefile b/python/Makefile index a22023f2..8ac38ac4 100644 --- a/python/Makefile +++ b/python/Makefile @@ -23,6 +23,10 @@ check: ## Check formatting and lints. fix: ## Fix python lints ruff check python --fix +.PHONY: typecheck +typecheck: ## Run type checking with pyright. + pyright + .PHONY: doctest doctest: ## Run documentation tests. pytest --doctest-modules python/lancedb diff --git a/python/pyproject.toml b/python/pyproject.toml index e90e27c1..39e7b981 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -92,7 +92,7 @@ requires = ["maturin>=1.4"] build-backend = "maturin" [tool.ruff.lint] -select = ["F", "E", "W", "G", "TCH", "PERF"] +select = ["F", "E", "W", "G", "PERF"] [tool.pytest.ini_options] addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py" @@ -103,5 +103,28 @@ markers = [ ] [tool.pyright] -include = ["python/lancedb/table.py"] +include = [ + "python/lancedb/index.py", + "python/lancedb/rerankers/util.py", + "python/lancedb/rerankers/__init__.py", + "python/lancedb/rerankers/voyageai.py", + "python/lancedb/rerankers/jinaai.py", + "python/lancedb/rerankers/openai.py", + "python/lancedb/rerankers/cross_encoder.py", + "python/lancedb/rerankers/colbert.py", + "python/lancedb/rerankers/answerdotai.py", + "python/lancedb/rerankers/cohere.py", + "python/lancedb/arrow.py", + "python/lancedb/__init__.py", + "python/lancedb/types.py", + "python/lancedb/integrations/__init__.py", + "python/lancedb/exceptions.py", + "python/lancedb/background_loop.py", + "python/lancedb/schema.py", + "python/lancedb/remote/__init__.py", + "python/lancedb/remote/errors.py", + "python/lancedb/embeddings/__init__.py", + "python/lancedb/_lancedb.pyi", +] +exclude = ["python/tests/"] pythonVersion = "3.12" diff --git a/python/python/lancedb/__init__.py b/python/python/lancedb/__init__.py index 8eddee18..5eff9638 100644 --- a/python/python/lancedb/__init__.py +++ b/python/python/lancedb/__init__.py @@ -14,6 +14,7 @@ from ._lancedb import connect as lancedb_connect from .common import URI, sanitize_uri from .db import AsyncConnection, DBConnection, LanceDBConnection from .remote import ClientConfig +from .remote.db import RemoteDBConnection from .schema import vector from .table import AsyncTable @@ -86,8 +87,6 @@ def connect( conn : DBConnection A connection to a LanceDB database. """ - from .remote.db import RemoteDBConnection - if isinstance(uri, str) and uri.startswith("db://"): if api_key is None: api_key = os.environ.get("LANCEDB_API_KEY") diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index 8ac3ec07..cab8486b 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -3,6 +3,7 @@ from typing import Dict, List, Optional, Tuple, Any, Union, Literal import pyarrow as pa from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS +from .remote import ClientConfig class Connection(object): uri: str @@ -71,11 +72,15 @@ async def connect( region: Optional[str], host_override: Optional[str], read_consistency_interval: Optional[float], + client_config: Optional[Union[ClientConfig, Dict[str, Any]]], + storage_options: Optional[Dict[str, str]], ) -> Connection: ... class RecordBatchStream: + @property def schema(self) -> pa.Schema: ... - async def next(self) -> Optional[pa.RecordBatch]: ... + def __aiter__(self) -> "RecordBatchStream": ... + async def __anext__(self) -> pa.RecordBatch: ... class Query: def where(self, filter: str): ... diff --git a/python/python/lancedb/remote/db.py b/python/python/lancedb/remote/db.py index e01c49e6..44d614e5 100644 --- a/python/python/lancedb/remote/db.py +++ b/python/python/lancedb/remote/db.py @@ -9,7 +9,8 @@ from typing import Any, Dict, Iterable, List, Optional, Union from urllib.parse import urlparse import warnings -from lancedb import connect_async +# Remove this import to fix circular dependency +# from lancedb import connect_async from lancedb.remote import ClientConfig import pyarrow as pa from overrides import override @@ -78,6 +79,9 @@ class RemoteDBConnection(DBConnection): self.client_config = client_config + # Import connect_async here to avoid circular import + from lancedb import connect_async + self._conn = LOOP.run( connect_async( db_url, diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 1aa3725d..a4d7c40d 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -76,12 +76,21 @@ if TYPE_CHECKING: from .index import IndexConfig import pandas import PIL + from .types import ( + QueryType, + OnBadVectorsType, + AddMode, + CreateMode, + VectorIndexType, + ScalarIndexType, + BaseTokenizerType, + DistanceType, + ) + pd = safe_import_pandas() pl = safe_import_polars() -QueryType = Literal["vector", "fts", "hybrid", "auto"] - def _into_pyarrow_reader(data) -> pa.RecordBatchReader: from lancedb.dependencies import datasets @@ -178,7 +187,7 @@ def _sanitize_data( data: "DATA", target_schema: Optional[pa.Schema] = None, metadata: Optional[dict] = None, # embedding metadata - on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error", + on_bad_vectors: OnBadVectorsType = "error", fill_value: float = 0.0, *, allow_subschema: bool = False, @@ -324,7 +333,7 @@ def sanitize_create_table( data, schema: Union[pa.Schema, LanceModel], metadata=None, - on_bad_vectors: str = "error", + on_bad_vectors: OnBadVectorsType = "error", fill_value: float = 0.0, ): if inspect.isclass(schema) and issubclass(schema, LanceModel): @@ -576,9 +585,7 @@ class Table(ABC): accelerator: Optional[str] = None, index_cache_size: Optional[int] = None, *, - index_type: Literal[ - "IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ" - ] = "IVF_PQ", + index_type: VectorIndexType = "IVF_PQ", num_bits: int = 8, max_iterations: int = 50, sample_rate: int = 256, @@ -643,7 +650,7 @@ class Table(ABC): column: str, *, replace: bool = True, - index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE", + index_type: ScalarIndexType = "BTREE", ): """Create a scalar index on a column. @@ -708,7 +715,7 @@ class Table(ABC): tokenizer_name: Optional[str] = None, with_position: bool = True, # tokenizer configs: - base_tokenizer: Literal["simple", "raw", "whitespace"] = "simple", + base_tokenizer: BaseTokenizerType = "simple", language: str = "English", max_token_length: Optional[int] = 40, lower_case: bool = True, @@ -777,8 +784,8 @@ class Table(ABC): def add( self, data: DATA, - mode: str = "append", - on_bad_vectors: str = "error", + mode: AddMode = "append", + on_bad_vectors: OnBadVectorsType = "error", fill_value: float = 0.0, ): """Add more data to the [Table](Table). @@ -960,7 +967,7 @@ class Table(ABC): self, merge: LanceMergeInsertBuilder, new_data: DATA, - on_bad_vectors: str, + on_bad_vectors: OnBadVectorsType, fill_value: float, ): ... @@ -1572,10 +1579,10 @@ class LanceTable(Table): def create_index( self, - metric="L2", + metric: DistanceType = "l2", num_partitions=None, num_sub_vectors=None, - vector_column_name=VECTOR_COLUMN_NAME, + vector_column_name: str = VECTOR_COLUMN_NAME, replace: bool = True, accelerator: Optional[str] = None, index_cache_size: Optional[int] = None, @@ -1661,7 +1668,7 @@ class LanceTable(Table): column: str, *, replace: bool = True, - index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE", + index_type: ScalarIndexType = "BTREE", ): if index_type == "BTREE": config = BTree() @@ -1686,7 +1693,7 @@ class LanceTable(Table): tokenizer_name: Optional[str] = None, with_position: bool = True, # tokenizer configs: - base_tokenizer: str = "simple", + base_tokenizer: BaseTokenizerType = "simple", language: str = "English", max_token_length: Optional[int] = 40, lower_case: bool = True, @@ -1820,8 +1827,8 @@ class LanceTable(Table): def add( self, data: DATA, - mode: str = "append", - on_bad_vectors: str = "error", + mode: AddMode = "append", + on_bad_vectors: OnBadVectorsType = "error", fill_value: float = 0.0, ): """Add data to the table. @@ -2059,7 +2066,7 @@ class LanceTable(Table): query_type, vector_column_name=vector_column_name, ordering_field_name=ordering_field_name, - fts_columns=fts_columns, + fts_columns=fts_columns or [], ) @classmethod @@ -2069,13 +2076,13 @@ class LanceTable(Table): name: str, data: Optional[DATA] = None, schema: Optional[pa.Schema] = None, - mode: Literal["create", "overwrite"] = "create", + mode: CreateMode = "create", exist_ok: bool = False, - on_bad_vectors: str = "error", + on_bad_vectors: OnBadVectorsType = "error", fill_value: float = 0.0, embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None, *, - storage_options: Optional[Dict[str, str]] = None, + storage_options: Optional[Dict[str, str | bool]] = None, data_storage_version: Optional[str] = None, enable_v2_manifest_paths: Optional[bool] = None, ): @@ -2229,7 +2236,7 @@ class LanceTable(Table): self, merge: LanceMergeInsertBuilder, new_data: DATA, - on_bad_vectors: str, + on_bad_vectors: OnBadVectorsType, fill_value: float, ): LOOP.run(self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)) @@ -2880,7 +2887,7 @@ class AsyncTable: data: DATA, *, mode: Optional[Literal["append", "overwrite"]] = "append", - on_bad_vectors: Optional[str] = None, + on_bad_vectors: Optional[OnBadVectorsType] = None, fill_value: Optional[float] = None, ): """Add more data to the [Table](Table). @@ -2986,7 +2993,7 @@ class AsyncTable: @overload async def search( self, - query: Optional[Union[str]] = None, + query: Optional[str] = None, vector_column_name: Optional[str] = None, query_type: Literal["auto"] = ..., ordering_field_name: Optional[str] = None, @@ -2996,7 +3003,7 @@ class AsyncTable: @overload async def search( self, - query: Optional[Union[str]] = None, + query: Optional[str] = None, vector_column_name: Optional[str] = None, query_type: Literal["hybrid"] = ..., ordering_field_name: Optional[str] = None, @@ -3040,7 +3047,7 @@ class AsyncTable: query_type: QueryType = "auto", ordering_field_name: Optional[str] = None, fts_columns: Optional[Union[str, List[str]]] = None, - ) -> AsyncQuery: + ) -> Union[AsyncHybridQuery | AsyncFTSQuery | AsyncVectorQuery]: """Create a search query to find the nearest neighbors of the given query vector. We currently support [vector search][search] and [full-text search][experimental-full-text-search]. @@ -3279,7 +3286,7 @@ class AsyncTable: self, merge: LanceMergeInsertBuilder, new_data: DATA, - on_bad_vectors: str, + on_bad_vectors: OnBadVectorsType, fill_value: float, ): schema = await self.schema() diff --git a/python/python/lancedb/types.py b/python/python/lancedb/types.py new file mode 100644 index 00000000..456c5364 --- /dev/null +++ b/python/python/lancedb/types.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The LanceDB Authors + +from typing import Literal + +# Query type literals +QueryType = Literal["vector", "fts", "hybrid", "auto"] + +# Distance type literals +DistanceType = Literal["l2", "cosine", "dot"] +DistanceTypeWithHamming = Literal["l2", "cosine", "dot", "hamming"] + +# Vector handling literals +OnBadVectorsType = Literal["error", "drop", "fill", "null"] + +# Mode literals +AddMode = Literal["append", "overwrite"] +CreateMode = Literal["create", "overwrite"] + +# Index type literals +VectorIndexType = Literal["IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"] +ScalarIndexType = Literal["BTREE", "BITMAP", "LABEL_LIST"] +IndexType = Literal[ + "IVF_PQ", "IVF_HNSW_PQ", "IVF_HNSW_SQ", "FTS", "BTREE", "BITMAP", "LABEL_LIST" +] + +# Tokenizer literals +BaseTokenizerType = Literal["simple", "raw", "whitespace"] diff --git a/python/python/tests/test_embeddings.py b/python/python/tests/test_embeddings.py index 168189dc..c3cc32b5 100644 --- a/python/python/tests/test_embeddings.py +++ b/python/python/tests/test_embeddings.py @@ -419,17 +419,17 @@ def test_embedding_function_safe_model_dump(embedding_type): dumped_model = model.safe_model_dump() - assert all( - not k.startswith("_") for k in dumped_model.keys() - ), f"{embedding_type}: Dumped model contains keys starting with underscore" + assert all(not k.startswith("_") for k in dumped_model.keys()), ( + f"{embedding_type}: Dumped model contains keys starting with underscore" + ) - assert ( - "max_retries" in dumped_model - ), f"{embedding_type}: Essential field 'max_retries' is missing from dumped model" + assert "max_retries" in dumped_model, ( + f"{embedding_type}: Essential field 'max_retries' is missing from dumped model" + ) - assert isinstance( - dumped_model, dict - ), f"{embedding_type}: Dumped model is not a dictionary" + assert isinstance(dumped_model, dict), ( + f"{embedding_type}: Dumped model is not a dictionary" + ) for key in model.__dict__: if key.startswith("_"): diff --git a/python/python/tests/test_hybrid_query.py b/python/python/tests/test_hybrid_query.py index 3edce4a3..15e41a89 100644 --- a/python/python/tests/test_hybrid_query.py +++ b/python/python/tests/test_hybrid_query.py @@ -129,6 +129,6 @@ def test_normalize_scores(): if invert: expected = pc.subtract(1.0, expected) - assert pc.equal( - result, expected - ), f"Expected {expected} but got {result} for invert={invert}" + assert pc.equal(result, expected), ( + f"Expected {expected} but got {result} for invert={invert}" + ) diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index 47ad6f05..f98ba7ad 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -784,8 +784,7 @@ async def test_query_search_auto(mem_db_async: AsyncConnection): with pytest.raises( Exception, match=( - "Cannot perform full text search unless an INVERTED index has " - "been created" + "Cannot perform full text search unless an INVERTED index has been created" ), ): query = await (await tbl2.search("0.1")).to_arrow() diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py index 21d697a3..a4cd1290 100644 --- a/python/python/tests/test_rerankers.py +++ b/python/python/tests/test_rerankers.py @@ -131,9 +131,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema): "represents the relevance of the result to the query & should " "be descending." ) - assert np.all( - np.diff(result.column("_relevance_score").to_numpy()) <= 0 - ), ascending_relevance_err + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), ( + ascending_relevance_err + ) # Vector search setting result = ( @@ -143,9 +143,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema): .to_arrow() ) assert len(result) == 30 - assert np.all( - np.diff(result.column("_relevance_score").to_numpy()) <= 0 - ), ascending_relevance_err + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), ( + ascending_relevance_err + ) result_explicit = ( table.search(query_vector, vector_column_name="vector") .rerank(reranker=reranker, query_string=query) @@ -168,9 +168,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema): .to_arrow() ) assert len(result) > 0 - assert np.all( - np.diff(result.column("_relevance_score").to_numpy()) <= 0 - ), ascending_relevance_err + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), ( + ascending_relevance_err + ) # empty FTS results query = "abcxyz" * 100 @@ -185,9 +185,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema): # should return _relevance_score column assert "_relevance_score" in result.column_names - assert np.all( - np.diff(result.column("_relevance_score").to_numpy()) <= 0 - ), ascending_relevance_err + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), ( + ascending_relevance_err + ) # Multi-vector search setting rs1 = table.search(query, vector_column_name="vector").limit(10).with_row_id(True) @@ -262,9 +262,9 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy): "represents the relevance of the result to the query & should " "be descending." ) - assert np.all( - np.diff(result.column("_relevance_score").to_numpy()) <= 0 - ), ascending_relevance_err + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), ( + ascending_relevance_err + ) # Test with empty FTS results query = "abcxyz" * 100 @@ -278,9 +278,9 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy): ) # should return _relevance_score column assert "_relevance_score" in result.column_names - assert np.all( - np.diff(result.column("_relevance_score").to_numpy()) <= 0 - ), ascending_relevance_err + assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), ( + ascending_relevance_err + ) @pytest.mark.parametrize("use_tantivy", [True, False])