mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-22 21:09:58 +00:00
fix(python): typing (#2167)
@wjones127 is there a standard way you guys set up your virtualenv? I can
either relist all the dependencies in the pyright precommit section, or
specify a venv, or the user has to be in the virtual environment when
they run git commit. If the venv location was standardized or a python
manager like `uv` was used it would be easier to avoid duplicating the
pyright dependency list.
Per your suggestion, in `pyproject.toml` I added in all the passing
files to the `includes` section.
For ruff, I upgraded the version and removed "TCH", which doesn't exist as
an option.
I added a `pyright_report.csv` which contains a list of all files sorted
by pyright errors ascending as a todo list to work on.
I fixed about 30 issues in `table.py` stemming from plain strings being
passed into methods that required a string within a set of string Literals,
by extracting those Literal types into `types.py`.
Can you verify in the rust bridge that the schema should be a property
and not a method here? If it's a method, then there's another place in
the code where `inner.schema` should be `inner.schema()`
``` python
class RecordBatchStream:
@property
def schema(self) -> pa.Schema: ...
```
Also, unless the `_lancedb.pyi` file is wrong, there is no
`__anext__` here for `_inner` when it's not an `AsyncGenerator`, and
only `next` is defined:
``` python
async def __anext__(self) -> pa.RecordBatch:
return await self._inner.__anext__()
if isinstance(self._inner, AsyncGenerator):
batch = await self._inner.__anext__()
else:
batch = await self._inner.next()
if batch is None:
raise StopAsyncIteration
return batch
```
in the else statement, `_inner` is a `RecordBatchStream`
```python
class RecordBatchStream:
@property
def schema(self) -> pa.Schema: ...
async def next(self) -> Optional[pa.RecordBatch]: ...
```
---------
Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
32
.github/workflows/python.yml
vendored
32
.github/workflows/python.yml
vendored
@@ -33,11 +33,41 @@ jobs:
|
||||
python-version: "3.12"
|
||||
- name: Install ruff
|
||||
run: |
|
||||
pip install ruff==0.8.4
|
||||
pip install ruff==0.9.9
|
||||
- name: Format check
|
||||
run: ruff format --check .
|
||||
- name: Lint
|
||||
run: ruff check .
|
||||
|
||||
type-check:
|
||||
name: "Type Check"
|
||||
timeout-minutes: 30
|
||||
runs-on: "ubuntu-22.04"
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: python
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
lfs: true
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
- name: Install protobuf compiler
|
||||
run: |
|
||||
sudo apt update
|
||||
sudo apt install -y protobuf-compiler
|
||||
pip install toml
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python ../ci/parse_requirements.py pyproject.toml --extras dev,tests,embeddings > requirements.txt
|
||||
pip install -r requirements.txt
|
||||
- name: Run pyright
|
||||
run: pyright
|
||||
|
||||
doctest:
|
||||
name: "Doctest"
|
||||
timeout-minutes: 30
|
||||
|
||||
@@ -1,21 +1,27 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v3.2.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: v0.8.4
|
||||
rev: v0.9.9
|
||||
hooks:
|
||||
- id: ruff
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: local-biome-check
|
||||
name: biome check
|
||||
entry: npx @biomejs/biome@1.8.3 check --config-path nodejs/biome.json nodejs/
|
||||
language: system
|
||||
types: [text]
|
||||
files: "nodejs/.*"
|
||||
exclude: nodejs/lancedb/native.d.ts|nodejs/dist/.*|nodejs/examples/.*
|
||||
- id: ruff
|
||||
# - repo: https://github.com/RobertCraigie/pyright-python
|
||||
# rev: v1.1.395
|
||||
# hooks:
|
||||
# - id: pyright
|
||||
# args: ["--project", "python"]
|
||||
# additional_dependencies: [pyarrow-stubs]
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: local-biome-check
|
||||
name: biome check
|
||||
entry: npx @biomejs/biome@1.8.3 check --config-path nodejs/biome.json nodejs/
|
||||
language: system
|
||||
types: [text]
|
||||
files: "nodejs/.*"
|
||||
exclude: nodejs/lancedb/native.d.ts|nodejs/dist/.*|nodejs/examples/.*
|
||||
|
||||
41
ci/parse_requirements.py
Normal file
41
ci/parse_requirements.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import argparse
|
||||
import toml
|
||||
|
||||
|
||||
def parse_dependencies(pyproject_path, extras=None):
|
||||
with open(pyproject_path, "r") as file:
|
||||
pyproject = toml.load(file)
|
||||
|
||||
dependencies = pyproject.get("project", {}).get("dependencies", [])
|
||||
for dependency in dependencies:
|
||||
print(dependency)
|
||||
|
||||
optional_dependencies = pyproject.get("project", {}).get(
|
||||
"optional-dependencies", {}
|
||||
)
|
||||
|
||||
if extras:
|
||||
for extra in extras.split(","):
|
||||
for dep in optional_dependencies.get(extra, []):
|
||||
print(dep)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate requirements.txt from pyproject.toml"
|
||||
)
|
||||
parser.add_argument("path", type=str, help="Path to pyproject.toml")
|
||||
parser.add_argument(
|
||||
"--extras",
|
||||
type=str,
|
||||
help="Comma-separated list of extras to include",
|
||||
default="",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
parse_dependencies(args.path, args.extras)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
56
pyright_report.csv
Normal file
56
pyright_report.csv
Normal file
@@ -0,0 +1,56 @@
|
||||
file,errors,warnings,total_issues
|
||||
python/python/lancedb/arrow.py,0,0,0
|
||||
python/python/lancedb/background_loop.py,0,0,0
|
||||
python/python/lancedb/embeddings/__init__.py,0,0,0
|
||||
python/python/lancedb/exceptions.py,0,0,0
|
||||
python/python/lancedb/index.py,0,0,0
|
||||
python/python/lancedb/integrations/__init__.py,0,0,0
|
||||
python/python/lancedb/remote/__init__.py,0,0,0
|
||||
python/python/lancedb/remote/errors.py,0,0,0
|
||||
python/python/lancedb/rerankers/__init__.py,0,0,0
|
||||
python/python/lancedb/rerankers/answerdotai.py,0,0,0
|
||||
python/python/lancedb/rerankers/cohere.py,0,0,0
|
||||
python/python/lancedb/rerankers/colbert.py,0,0,0
|
||||
python/python/lancedb/rerankers/cross_encoder.py,0,0,0
|
||||
python/python/lancedb/rerankers/openai.py,0,0,0
|
||||
python/python/lancedb/rerankers/util.py,0,0,0
|
||||
python/python/lancedb/rerankers/voyageai.py,0,0,0
|
||||
python/python/lancedb/schema.py,0,0,0
|
||||
python/python/lancedb/types.py,0,0,0
|
||||
python/python/lancedb/__init__.py,0,1,1
|
||||
python/python/lancedb/conftest.py,1,0,1
|
||||
python/python/lancedb/embeddings/bedrock.py,1,0,1
|
||||
python/python/lancedb/merge.py,1,0,1
|
||||
python/python/lancedb/rerankers/base.py,1,0,1
|
||||
python/python/lancedb/rerankers/jinaai.py,0,1,1
|
||||
python/python/lancedb/rerankers/linear_combination.py,1,0,1
|
||||
python/python/lancedb/embeddings/instructor.py,2,0,2
|
||||
python/python/lancedb/embeddings/openai.py,2,0,2
|
||||
python/python/lancedb/embeddings/watsonx.py,2,0,2
|
||||
python/python/lancedb/embeddings/registry.py,3,0,3
|
||||
python/python/lancedb/embeddings/sentence_transformers.py,3,0,3
|
||||
python/python/lancedb/integrations/pyarrow.py,3,0,3
|
||||
python/python/lancedb/rerankers/rrf.py,3,0,3
|
||||
python/python/lancedb/dependencies.py,4,0,4
|
||||
python/python/lancedb/embeddings/gemini_text.py,4,0,4
|
||||
python/python/lancedb/embeddings/gte.py,4,0,4
|
||||
python/python/lancedb/embeddings/gte_mlx_model.py,4,0,4
|
||||
python/python/lancedb/embeddings/ollama.py,4,0,4
|
||||
python/python/lancedb/embeddings/transformers.py,4,0,4
|
||||
python/python/lancedb/remote/db.py,5,0,5
|
||||
python/python/lancedb/context.py,6,0,6
|
||||
python/python/lancedb/embeddings/cohere.py,6,0,6
|
||||
python/python/lancedb/fts.py,6,0,6
|
||||
python/python/lancedb/db.py,9,0,9
|
||||
python/python/lancedb/embeddings/utils.py,9,0,9
|
||||
python/python/lancedb/common.py,11,0,11
|
||||
python/python/lancedb/util.py,13,0,13
|
||||
python/python/lancedb/embeddings/imagebind.py,14,0,14
|
||||
python/python/lancedb/embeddings/voyageai.py,15,0,15
|
||||
python/python/lancedb/embeddings/open_clip.py,16,0,16
|
||||
python/python/lancedb/pydantic.py,16,0,16
|
||||
python/python/lancedb/embeddings/base.py,17,0,17
|
||||
python/python/lancedb/embeddings/jinaai.py,18,1,19
|
||||
python/python/lancedb/remote/table.py,23,0,23
|
||||
python/python/lancedb/query.py,47,1,48
|
||||
python/python/lancedb/table.py,61,0,61
|
||||
|
@@ -8,9 +8,9 @@ For general contribution guidelines, see [CONTRIBUTING.md](../CONTRIBUTING.md).
|
||||
The Python package is a wrapper around the Rust library, `lancedb`. We use
|
||||
[pyo3](https://pyo3.rs/) to create the bindings between Rust and Python.
|
||||
|
||||
* `src/`: Rust bindings source code
|
||||
* `python/lancedb`: Python package source code
|
||||
* `python/tests`: Unit tests
|
||||
- `src/`: Rust bindings source code
|
||||
- `python/lancedb`: Python package source code
|
||||
- `python/tests`: Unit tests
|
||||
|
||||
## Development environment
|
||||
|
||||
@@ -61,6 +61,12 @@ make test
|
||||
make doctest
|
||||
```
|
||||
|
||||
Run type checking:
|
||||
|
||||
```shell
|
||||
make typecheck
|
||||
```
|
||||
|
||||
To run a single test, you can use the `pytest` command directly. Provide the path
|
||||
to the test file, and optionally the test name after `::`.
|
||||
|
||||
|
||||
@@ -23,6 +23,10 @@ check: ## Check formatting and lints.
|
||||
fix: ## Fix python lints
|
||||
ruff check python --fix
|
||||
|
||||
.PHONY: typecheck
|
||||
typecheck: ## Run type checking with pyright.
|
||||
pyright
|
||||
|
||||
.PHONY: doctest
|
||||
doctest: ## Run documentation tests.
|
||||
pytest --doctest-modules python/lancedb
|
||||
|
||||
@@ -92,7 +92,7 @@ requires = ["maturin>=1.4"]
|
||||
build-backend = "maturin"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["F", "E", "W", "G", "TCH", "PERF"]
|
||||
select = ["F", "E", "W", "G", "PERF"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py"
|
||||
@@ -103,5 +103,28 @@ markers = [
|
||||
]
|
||||
|
||||
[tool.pyright]
|
||||
include = ["python/lancedb/table.py"]
|
||||
include = [
|
||||
"python/lancedb/index.py",
|
||||
"python/lancedb/rerankers/util.py",
|
||||
"python/lancedb/rerankers/__init__.py",
|
||||
"python/lancedb/rerankers/voyageai.py",
|
||||
"python/lancedb/rerankers/jinaai.py",
|
||||
"python/lancedb/rerankers/openai.py",
|
||||
"python/lancedb/rerankers/cross_encoder.py",
|
||||
"python/lancedb/rerankers/colbert.py",
|
||||
"python/lancedb/rerankers/answerdotai.py",
|
||||
"python/lancedb/rerankers/cohere.py",
|
||||
"python/lancedb/arrow.py",
|
||||
"python/lancedb/__init__.py",
|
||||
"python/lancedb/types.py",
|
||||
"python/lancedb/integrations/__init__.py",
|
||||
"python/lancedb/exceptions.py",
|
||||
"python/lancedb/background_loop.py",
|
||||
"python/lancedb/schema.py",
|
||||
"python/lancedb/remote/__init__.py",
|
||||
"python/lancedb/remote/errors.py",
|
||||
"python/lancedb/embeddings/__init__.py",
|
||||
"python/lancedb/_lancedb.pyi",
|
||||
]
|
||||
exclude = ["python/tests/"]
|
||||
pythonVersion = "3.12"
|
||||
|
||||
@@ -14,6 +14,7 @@ from ._lancedb import connect as lancedb_connect
|
||||
from .common import URI, sanitize_uri
|
||||
from .db import AsyncConnection, DBConnection, LanceDBConnection
|
||||
from .remote import ClientConfig
|
||||
from .remote.db import RemoteDBConnection
|
||||
from .schema import vector
|
||||
from .table import AsyncTable
|
||||
|
||||
@@ -86,8 +87,6 @@ def connect(
|
||||
conn : DBConnection
|
||||
A connection to a LanceDB database.
|
||||
"""
|
||||
from .remote.db import RemoteDBConnection
|
||||
|
||||
if isinstance(uri, str) and uri.startswith("db://"):
|
||||
if api_key is None:
|
||||
api_key = os.environ.get("LANCEDB_API_KEY")
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Dict, List, Optional, Tuple, Any, Union, Literal
|
||||
import pyarrow as pa
|
||||
|
||||
from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
|
||||
from .remote import ClientConfig
|
||||
|
||||
class Connection(object):
|
||||
uri: str
|
||||
@@ -71,11 +72,15 @@ async def connect(
|
||||
region: Optional[str],
|
||||
host_override: Optional[str],
|
||||
read_consistency_interval: Optional[float],
|
||||
client_config: Optional[Union[ClientConfig, Dict[str, Any]]],
|
||||
storage_options: Optional[Dict[str, str]],
|
||||
) -> Connection: ...
|
||||
|
||||
class RecordBatchStream:
|
||||
@property
|
||||
def schema(self) -> pa.Schema: ...
|
||||
async def next(self) -> Optional[pa.RecordBatch]: ...
|
||||
def __aiter__(self) -> "RecordBatchStream": ...
|
||||
async def __anext__(self) -> pa.RecordBatch: ...
|
||||
|
||||
class Query:
|
||||
def where(self, filter: str): ...
|
||||
|
||||
@@ -9,7 +9,8 @@ from typing import Any, Dict, Iterable, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
import warnings
|
||||
|
||||
from lancedb import connect_async
|
||||
# Remove this import to fix circular dependency
|
||||
# from lancedb import connect_async
|
||||
from lancedb.remote import ClientConfig
|
||||
import pyarrow as pa
|
||||
from overrides import override
|
||||
@@ -78,6 +79,9 @@ class RemoteDBConnection(DBConnection):
|
||||
|
||||
self.client_config = client_config
|
||||
|
||||
# Import connect_async here to avoid circular import
|
||||
from lancedb import connect_async
|
||||
|
||||
self._conn = LOOP.run(
|
||||
connect_async(
|
||||
db_url,
|
||||
|
||||
@@ -76,12 +76,21 @@ if TYPE_CHECKING:
|
||||
from .index import IndexConfig
|
||||
import pandas
|
||||
import PIL
|
||||
from .types import (
|
||||
QueryType,
|
||||
OnBadVectorsType,
|
||||
AddMode,
|
||||
CreateMode,
|
||||
VectorIndexType,
|
||||
ScalarIndexType,
|
||||
BaseTokenizerType,
|
||||
DistanceType,
|
||||
)
|
||||
|
||||
|
||||
pd = safe_import_pandas()
|
||||
pl = safe_import_polars()
|
||||
|
||||
QueryType = Literal["vector", "fts", "hybrid", "auto"]
|
||||
|
||||
|
||||
def _into_pyarrow_reader(data) -> pa.RecordBatchReader:
|
||||
from lancedb.dependencies import datasets
|
||||
@@ -178,7 +187,7 @@ def _sanitize_data(
|
||||
data: "DATA",
|
||||
target_schema: Optional[pa.Schema] = None,
|
||||
metadata: Optional[dict] = None, # embedding metadata
|
||||
on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error",
|
||||
on_bad_vectors: OnBadVectorsType = "error",
|
||||
fill_value: float = 0.0,
|
||||
*,
|
||||
allow_subschema: bool = False,
|
||||
@@ -324,7 +333,7 @@ def sanitize_create_table(
|
||||
data,
|
||||
schema: Union[pa.Schema, LanceModel],
|
||||
metadata=None,
|
||||
on_bad_vectors: str = "error",
|
||||
on_bad_vectors: OnBadVectorsType = "error",
|
||||
fill_value: float = 0.0,
|
||||
):
|
||||
if inspect.isclass(schema) and issubclass(schema, LanceModel):
|
||||
@@ -576,9 +585,7 @@ class Table(ABC):
|
||||
accelerator: Optional[str] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
*,
|
||||
index_type: Literal[
|
||||
"IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
|
||||
] = "IVF_PQ",
|
||||
index_type: VectorIndexType = "IVF_PQ",
|
||||
num_bits: int = 8,
|
||||
max_iterations: int = 50,
|
||||
sample_rate: int = 256,
|
||||
@@ -643,7 +650,7 @@ class Table(ABC):
|
||||
column: str,
|
||||
*,
|
||||
replace: bool = True,
|
||||
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE",
|
||||
index_type: ScalarIndexType = "BTREE",
|
||||
):
|
||||
"""Create a scalar index on a column.
|
||||
|
||||
@@ -708,7 +715,7 @@ class Table(ABC):
|
||||
tokenizer_name: Optional[str] = None,
|
||||
with_position: bool = True,
|
||||
# tokenizer configs:
|
||||
base_tokenizer: Literal["simple", "raw", "whitespace"] = "simple",
|
||||
base_tokenizer: BaseTokenizerType = "simple",
|
||||
language: str = "English",
|
||||
max_token_length: Optional[int] = 40,
|
||||
lower_case: bool = True,
|
||||
@@ -777,8 +784,8 @@ class Table(ABC):
|
||||
def add(
|
||||
self,
|
||||
data: DATA,
|
||||
mode: str = "append",
|
||||
on_bad_vectors: str = "error",
|
||||
mode: AddMode = "append",
|
||||
on_bad_vectors: OnBadVectorsType = "error",
|
||||
fill_value: float = 0.0,
|
||||
):
|
||||
"""Add more data to the [Table](Table).
|
||||
@@ -960,7 +967,7 @@ class Table(ABC):
|
||||
self,
|
||||
merge: LanceMergeInsertBuilder,
|
||||
new_data: DATA,
|
||||
on_bad_vectors: str,
|
||||
on_bad_vectors: OnBadVectorsType,
|
||||
fill_value: float,
|
||||
): ...
|
||||
|
||||
@@ -1572,10 +1579,10 @@ class LanceTable(Table):
|
||||
|
||||
def create_index(
|
||||
self,
|
||||
metric="L2",
|
||||
metric: DistanceType = "l2",
|
||||
num_partitions=None,
|
||||
num_sub_vectors=None,
|
||||
vector_column_name=VECTOR_COLUMN_NAME,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
replace: bool = True,
|
||||
accelerator: Optional[str] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
@@ -1661,7 +1668,7 @@ class LanceTable(Table):
|
||||
column: str,
|
||||
*,
|
||||
replace: bool = True,
|
||||
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE",
|
||||
index_type: ScalarIndexType = "BTREE",
|
||||
):
|
||||
if index_type == "BTREE":
|
||||
config = BTree()
|
||||
@@ -1686,7 +1693,7 @@ class LanceTable(Table):
|
||||
tokenizer_name: Optional[str] = None,
|
||||
with_position: bool = True,
|
||||
# tokenizer configs:
|
||||
base_tokenizer: str = "simple",
|
||||
base_tokenizer: BaseTokenizerType = "simple",
|
||||
language: str = "English",
|
||||
max_token_length: Optional[int] = 40,
|
||||
lower_case: bool = True,
|
||||
@@ -1820,8 +1827,8 @@ class LanceTable(Table):
|
||||
def add(
|
||||
self,
|
||||
data: DATA,
|
||||
mode: str = "append",
|
||||
on_bad_vectors: str = "error",
|
||||
mode: AddMode = "append",
|
||||
on_bad_vectors: OnBadVectorsType = "error",
|
||||
fill_value: float = 0.0,
|
||||
):
|
||||
"""Add data to the table.
|
||||
@@ -2059,7 +2066,7 @@ class LanceTable(Table):
|
||||
query_type,
|
||||
vector_column_name=vector_column_name,
|
||||
ordering_field_name=ordering_field_name,
|
||||
fts_columns=fts_columns,
|
||||
fts_columns=fts_columns or [],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -2069,13 +2076,13 @@ class LanceTable(Table):
|
||||
name: str,
|
||||
data: Optional[DATA] = None,
|
||||
schema: Optional[pa.Schema] = None,
|
||||
mode: Literal["create", "overwrite"] = "create",
|
||||
mode: CreateMode = "create",
|
||||
exist_ok: bool = False,
|
||||
on_bad_vectors: str = "error",
|
||||
on_bad_vectors: OnBadVectorsType = "error",
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
*,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
storage_options: Optional[Dict[str, str | bool]] = None,
|
||||
data_storage_version: Optional[str] = None,
|
||||
enable_v2_manifest_paths: Optional[bool] = None,
|
||||
):
|
||||
@@ -2229,7 +2236,7 @@ class LanceTable(Table):
|
||||
self,
|
||||
merge: LanceMergeInsertBuilder,
|
||||
new_data: DATA,
|
||||
on_bad_vectors: str,
|
||||
on_bad_vectors: OnBadVectorsType,
|
||||
fill_value: float,
|
||||
):
|
||||
LOOP.run(self._table._do_merge(merge, new_data, on_bad_vectors, fill_value))
|
||||
@@ -2880,7 +2887,7 @@ class AsyncTable:
|
||||
data: DATA,
|
||||
*,
|
||||
mode: Optional[Literal["append", "overwrite"]] = "append",
|
||||
on_bad_vectors: Optional[str] = None,
|
||||
on_bad_vectors: Optional[OnBadVectorsType] = None,
|
||||
fill_value: Optional[float] = None,
|
||||
):
|
||||
"""Add more data to the [Table](Table).
|
||||
@@ -2986,7 +2993,7 @@ class AsyncTable:
|
||||
@overload
|
||||
async def search(
|
||||
self,
|
||||
query: Optional[Union[str]] = None,
|
||||
query: Optional[str] = None,
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: Literal["auto"] = ...,
|
||||
ordering_field_name: Optional[str] = None,
|
||||
@@ -2996,7 +3003,7 @@ class AsyncTable:
|
||||
@overload
|
||||
async def search(
|
||||
self,
|
||||
query: Optional[Union[str]] = None,
|
||||
query: Optional[str] = None,
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: Literal["hybrid"] = ...,
|
||||
ordering_field_name: Optional[str] = None,
|
||||
@@ -3040,7 +3047,7 @@ class AsyncTable:
|
||||
query_type: QueryType = "auto",
|
||||
ordering_field_name: Optional[str] = None,
|
||||
fts_columns: Optional[Union[str, List[str]]] = None,
|
||||
) -> AsyncQuery:
|
||||
) -> Union[AsyncHybridQuery | AsyncFTSQuery | AsyncVectorQuery]:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector. We currently support [vector search][search]
|
||||
and [full-text search][experimental-full-text-search].
|
||||
@@ -3279,7 +3286,7 @@ class AsyncTable:
|
||||
self,
|
||||
merge: LanceMergeInsertBuilder,
|
||||
new_data: DATA,
|
||||
on_bad_vectors: str,
|
||||
on_bad_vectors: OnBadVectorsType,
|
||||
fill_value: float,
|
||||
):
|
||||
schema = await self.schema()
|
||||
|
||||
28
python/python/lancedb/types.py
Normal file
28
python/python/lancedb/types.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from typing import Literal
|
||||
|
||||
# Query type literals
|
||||
QueryType = Literal["vector", "fts", "hybrid", "auto"]
|
||||
|
||||
# Distance type literals
|
||||
DistanceType = Literal["l2", "cosine", "dot"]
|
||||
DistanceTypeWithHamming = Literal["l2", "cosine", "dot", "hamming"]
|
||||
|
||||
# Vector handling literals
|
||||
OnBadVectorsType = Literal["error", "drop", "fill", "null"]
|
||||
|
||||
# Mode literals
|
||||
AddMode = Literal["append", "overwrite"]
|
||||
CreateMode = Literal["create", "overwrite"]
|
||||
|
||||
# Index type literals
|
||||
VectorIndexType = Literal["IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"]
|
||||
ScalarIndexType = Literal["BTREE", "BITMAP", "LABEL_LIST"]
|
||||
IndexType = Literal[
|
||||
"IVF_PQ", "IVF_HNSW_PQ", "IVF_HNSW_SQ", "FTS", "BTREE", "BITMAP", "LABEL_LIST"
|
||||
]
|
||||
|
||||
# Tokenizer literals
|
||||
BaseTokenizerType = Literal["simple", "raw", "whitespace"]
|
||||
@@ -419,17 +419,17 @@ def test_embedding_function_safe_model_dump(embedding_type):
|
||||
|
||||
dumped_model = model.safe_model_dump()
|
||||
|
||||
assert all(
|
||||
not k.startswith("_") for k in dumped_model.keys()
|
||||
), f"{embedding_type}: Dumped model contains keys starting with underscore"
|
||||
assert all(not k.startswith("_") for k in dumped_model.keys()), (
|
||||
f"{embedding_type}: Dumped model contains keys starting with underscore"
|
||||
)
|
||||
|
||||
assert (
|
||||
"max_retries" in dumped_model
|
||||
), f"{embedding_type}: Essential field 'max_retries' is missing from dumped model"
|
||||
assert "max_retries" in dumped_model, (
|
||||
f"{embedding_type}: Essential field 'max_retries' is missing from dumped model"
|
||||
)
|
||||
|
||||
assert isinstance(
|
||||
dumped_model, dict
|
||||
), f"{embedding_type}: Dumped model is not a dictionary"
|
||||
assert isinstance(dumped_model, dict), (
|
||||
f"{embedding_type}: Dumped model is not a dictionary"
|
||||
)
|
||||
|
||||
for key in model.__dict__:
|
||||
if key.startswith("_"):
|
||||
|
||||
@@ -129,6 +129,6 @@ def test_normalize_scores():
|
||||
if invert:
|
||||
expected = pc.subtract(1.0, expected)
|
||||
|
||||
assert pc.equal(
|
||||
result, expected
|
||||
), f"Expected {expected} but got {result} for invert={invert}"
|
||||
assert pc.equal(result, expected), (
|
||||
f"Expected {expected} but got {result} for invert={invert}"
|
||||
)
|
||||
|
||||
@@ -784,8 +784,7 @@ async def test_query_search_auto(mem_db_async: AsyncConnection):
|
||||
with pytest.raises(
|
||||
Exception,
|
||||
match=(
|
||||
"Cannot perform full text search unless an INVERTED index has "
|
||||
"been created"
|
||||
"Cannot perform full text search unless an INVERTED index has been created"
|
||||
),
|
||||
):
|
||||
query = await (await tbl2.search("0.1")).to_arrow()
|
||||
|
||||
@@ -131,9 +131,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
ascending_relevance_err
|
||||
)
|
||||
|
||||
# Vector search setting
|
||||
result = (
|
||||
@@ -143,9 +143,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) == 30
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
ascending_relevance_err
|
||||
)
|
||||
result_explicit = (
|
||||
table.search(query_vector, vector_column_name="vector")
|
||||
.rerank(reranker=reranker, query_string=query)
|
||||
@@ -168,9 +168,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
ascending_relevance_err
|
||||
)
|
||||
|
||||
# empty FTS results
|
||||
query = "abcxyz" * 100
|
||||
@@ -185,9 +185,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
|
||||
|
||||
# should return _relevance_score column
|
||||
assert "_relevance_score" in result.column_names
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
ascending_relevance_err
|
||||
)
|
||||
|
||||
# Multi-vector search setting
|
||||
rs1 = table.search(query, vector_column_name="vector").limit(10).with_row_id(True)
|
||||
@@ -262,9 +262,9 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
ascending_relevance_err
|
||||
)
|
||||
|
||||
# Test with empty FTS results
|
||||
query = "abcxyz" * 100
|
||||
@@ -278,9 +278,9 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
|
||||
)
|
||||
# should return _relevance_score column
|
||||
assert "_relevance_score" in result.column_names
|
||||
assert np.all(
|
||||
np.diff(result.column("_relevance_score").to_numpy()) <= 0
|
||||
), ascending_relevance_err
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
ascending_relevance_err
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
|
||||
Reference in New Issue
Block a user