fix(python): typing (#2167)
@wjones127 is there a standard way you guys set up your virtualenv? I can
either relist all the dependencies in the pyright pre-commit section,
specify a venv, or require the user to be inside the virtual environment
when they run `git commit`. If the venv location were standardized, or a
Python manager like `uv` were used, it would be easier to avoid duplicating
the pyright dependency list.
Per your suggestion, in `pyproject.toml` I added all the currently passing
files to the `include` list.
For ruff I upgraded the version and removed `"TCH"`, which doesn't exist as
an option.
I added a `pyright_report.csv` listing all files sorted by pyright error
count, ascending, as a todo list to work through.
I fixed about 30 issues in `table.py` stemming from plain `str`s being
passed into methods that require one of a set of string `Literal`s, by
extracting those `Literal` types into `types.py`.
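
The pattern is the one visible throughout the `table.py` diff below; a minimal sketch of why it helps, with a toy `add` standing in for the real `Table.add`:

``` python
from typing import Literal

# Named alias, as defined in the new types.py
OnBadVectorsType = Literal["error", "drop", "fill", "null"]


def add(data: list, on_bad_vectors: OnBadVectorsType = "error") -> None:
    """Toy stand-in for Table.add; only the annotation matters here."""


add([], on_bad_vectors="drop")    # OK
add([], on_bad_vectors="ignore")  # pyright error: not one of the Literal values
```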
Can you verify in the Rust bridge that the schema should be a property
and not a method here? If it's a method, then there's another place in
the code where `inner.schema` should be `inner.schema()`:
``` python
class RecordBatchStream:
@property
def schema(self) -> pa.Schema: ...
```
Also, unless the `_lancedb.pyi` file is wrong, there is no `__anext__`
on `_inner` when it's not an `AsyncGenerator`, and only `next` is
defined:
``` python
async def __anext__(self) -> pa.RecordBatch:
    if isinstance(self._inner, AsyncGenerator):
        batch = await self._inner.__anext__()
    else:
        batch = await self._inner.next()
    if batch is None:
        raise StopAsyncIteration
    return batch
```
In the `else` branch, `_inner` is a `RecordBatchStream`:
```python
class RecordBatchStream:
@property
def schema(self) -> pa.Schema: ...
async def next(self) -> Optional[pa.RecordBatch]: ...
```
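
For what it's worth, once `__aiter__`/`__anext__` are declared on `RecordBatchStream` (as the updated `_lancedb.pyi` below does), callers can drop the `isinstance` branching entirely; a minimal consumption sketch, assuming `stream` is any object implementing that protocol:

``` python
import pyarrow as pa


async def collect(stream) -> list[pa.RecordBatch]:
    # `async for` drives __aiter__/__anext__; the wrapper raises
    # StopAsyncIteration internally once next() returns None.
    return [batch async for batch in stream]
```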
---------
Co-authored-by: Will Jones <willjones127@gmail.com>
`.github/workflows/python.yml` (32 changes):

```diff
@@ -33,11 +33,41 @@ jobs:
          python-version: "3.12"
      - name: Install ruff
        run: |
-          pip install ruff==0.8.4
+          pip install ruff==0.9.9
      - name: Format check
        run: ruff format --check .
      - name: Lint
        run: ruff check .
 
+  type-check:
+    name: "Type Check"
+    timeout-minutes: 30
+    runs-on: "ubuntu-22.04"
+    defaults:
+      run:
+        shell: bash
+        working-directory: python
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+          lfs: true
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.12"
+      - name: Install protobuf compiler
+        run: |
+          sudo apt update
+          sudo apt install -y protobuf-compiler
+          pip install toml
+      - name: Install dependencies
+        run: |
+          python ../ci/parse_requirements.py pyproject.toml --extras dev,tests,embeddings > requirements.txt
+          pip install -r requirements.txt
+      - name: Run pyright
+        run: pyright
+
   doctest:
     name: "Doctest"
     timeout-minutes: 30
```
`.pre-commit-config.yaml`:

```diff
@@ -1,21 +1,27 @@
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v3.2.0
     hooks:
       - id: check-yaml
       - id: end-of-file-fixer
       - id: trailing-whitespace
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.8.4
+    rev: v0.9.9
     hooks:
       - id: ruff
+  # - repo: https://github.com/RobertCraigie/pyright-python
+  #   rev: v1.1.395
+  #   hooks:
+  #     - id: pyright
+  #       args: ["--project", "python"]
+  #       additional_dependencies: [pyarrow-stubs]
   - repo: local
     hooks:
       - id: local-biome-check
         name: biome check
         entry: npx @biomejs/biome@1.8.3 check --config-path nodejs/biome.json nodejs/
         language: system
         types: [text]
         files: "nodejs/.*"
         exclude: nodejs/lancedb/native.d.ts|nodejs/dist/.*|nodejs/examples/.*
```
`ci/parse_requirements.py` (new file, 41 lines):

```python
import argparse
import toml


def parse_dependencies(pyproject_path, extras=None):
    with open(pyproject_path, "r") as file:
        pyproject = toml.load(file)

    dependencies = pyproject.get("project", {}).get("dependencies", [])
    for dependency in dependencies:
        print(dependency)

    optional_dependencies = pyproject.get("project", {}).get(
        "optional-dependencies", {}
    )

    if extras:
        for extra in extras.split(","):
            for dep in optional_dependencies.get(extra, []):
                print(dep)


def main():
    parser = argparse.ArgumentParser(
        description="Generate requirements.txt from pyproject.toml"
    )
    parser.add_argument("path", type=str, help="Path to pyproject.toml")
    parser.add_argument(
        "--extras",
        type=str,
        help="Comma-separated list of extras to include",
        default="",
    )

    args = parser.parse_args()

    parse_dependencies(args.path, args.extras)


if __name__ == "__main__":
    main()
```
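
As wired into the type-check workflow above, the script prints one requirement specifier per line to stdout. A hypothetical equivalent invocation from Python, assuming the repo root as the working directory:

```python
import runpy
import sys

# Same as: python ci/parse_requirements.py python/pyproject.toml --extras dev,tests > requirements.txt
sys.argv = ["parse_requirements.py", "python/pyproject.toml", "--extras", "dev,tests"]
runpy.run_path("ci/parse_requirements.py", run_name="__main__")
```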
`pyright_report.csv` (new file, 56 lines):

```csv
file,errors,warnings,total_issues
python/python/lancedb/arrow.py,0,0,0
python/python/lancedb/background_loop.py,0,0,0
python/python/lancedb/embeddings/__init__.py,0,0,0
python/python/lancedb/exceptions.py,0,0,0
python/python/lancedb/index.py,0,0,0
python/python/lancedb/integrations/__init__.py,0,0,0
python/python/lancedb/remote/__init__.py,0,0,0
python/python/lancedb/remote/errors.py,0,0,0
python/python/lancedb/rerankers/__init__.py,0,0,0
python/python/lancedb/rerankers/answerdotai.py,0,0,0
python/python/lancedb/rerankers/cohere.py,0,0,0
python/python/lancedb/rerankers/colbert.py,0,0,0
python/python/lancedb/rerankers/cross_encoder.py,0,0,0
python/python/lancedb/rerankers/openai.py,0,0,0
python/python/lancedb/rerankers/util.py,0,0,0
python/python/lancedb/rerankers/voyageai.py,0,0,0
python/python/lancedb/schema.py,0,0,0
python/python/lancedb/types.py,0,0,0
python/python/lancedb/__init__.py,0,1,1
python/python/lancedb/conftest.py,1,0,1
python/python/lancedb/embeddings/bedrock.py,1,0,1
python/python/lancedb/merge.py,1,0,1
python/python/lancedb/rerankers/base.py,1,0,1
python/python/lancedb/rerankers/jinaai.py,0,1,1
python/python/lancedb/rerankers/linear_combination.py,1,0,1
python/python/lancedb/embeddings/instructor.py,2,0,2
python/python/lancedb/embeddings/openai.py,2,0,2
python/python/lancedb/embeddings/watsonx.py,2,0,2
python/python/lancedb/embeddings/registry.py,3,0,3
python/python/lancedb/embeddings/sentence_transformers.py,3,0,3
python/python/lancedb/integrations/pyarrow.py,3,0,3
python/python/lancedb/rerankers/rrf.py,3,0,3
python/python/lancedb/dependencies.py,4,0,4
python/python/lancedb/embeddings/gemini_text.py,4,0,4
python/python/lancedb/embeddings/gte.py,4,0,4
python/python/lancedb/embeddings/gte_mlx_model.py,4,0,4
python/python/lancedb/embeddings/ollama.py,4,0,4
python/python/lancedb/embeddings/transformers.py,4,0,4
python/python/lancedb/remote/db.py,5,0,5
python/python/lancedb/context.py,6,0,6
python/python/lancedb/embeddings/cohere.py,6,0,6
python/python/lancedb/fts.py,6,0,6
python/python/lancedb/db.py,9,0,9
python/python/lancedb/embeddings/utils.py,9,0,9
python/python/lancedb/common.py,11,0,11
python/python/lancedb/util.py,13,0,13
python/python/lancedb/embeddings/imagebind.py,14,0,14
python/python/lancedb/embeddings/voyageai.py,15,0,15
python/python/lancedb/embeddings/open_clip.py,16,0,16
python/python/lancedb/pydantic.py,16,0,16
python/python/lancedb/embeddings/base.py,17,0,17
python/python/lancedb/embeddings/jinaai.py,18,1,19
python/python/lancedb/remote/table.py,23,0,23
python/python/lancedb/query.py,47,1,48
python/python/lancedb/table.py,61,0,61
```
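
The commit doesn't include the script that generated this report, but a report of the same shape can be produced from pyright's JSON output; a hypothetical sketch (files with zero issues would have to be appended separately):

```python
import json
import subprocess
from collections import Counter

# Tally pyright diagnostics per file from `pyright --outputjson`.
out = subprocess.run(["pyright", "--outputjson"], capture_output=True, text=True)
diagnostics = json.loads(out.stdout)["generalDiagnostics"]

errors: Counter = Counter()
warnings: Counter = Counter()
for diag in diagnostics:
    if diag["severity"] == "error":
        errors[diag["file"]] += 1
    elif diag["severity"] == "warning":
        warnings[diag["file"]] += 1

print("file,errors,warnings,total_issues")
for path in sorted(set(errors) | set(warnings), key=lambda p: errors[p] + warnings[p]):
    print(f"{path},{errors[path]},{warnings[path]},{errors[path] + warnings[path]}")
```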
`python/CONTRIBUTING.md`:

````diff
@@ -8,9 +8,9 @@ For general contribution guidelines, see [CONTRIBUTING.md](../CONTRIBUTING.md).
 The Python package is a wrapper around the Rust library, `lancedb`. We use
 [pyo3](https://pyo3.rs/) to create the bindings between Rust and Python.
 
-* `src/`: Rust bindings source code
-* `python/lancedb`: Python package source code
-* `python/tests`: Unit tests
+- `src/`: Rust bindings source code
+- `python/lancedb`: Python package source code
+- `python/tests`: Unit tests
 
 ## Development environment
 
@@ -61,6 +61,12 @@ make test
 make doctest
 ```
 
+Run type checking:
+
+```shell
+make typecheck
+```
+
 To run a single test, you can use the `pytest` command directly. Provide the path
 to the test file, and optionally the test name after `::`.
````
`python/Makefile`:

```diff
@@ -23,6 +23,10 @@ check: ## Check formatting and lints.
 fix: ## Fix python lints
 	ruff check python --fix
 
+.PHONY: typecheck
+typecheck: ## Run type checking with pyright.
+	pyright
+
 .PHONY: doctest
 doctest: ## Run documentation tests.
 	pytest --doctest-modules python/lancedb
```
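
`make typecheck` just shells out to `pyright`, which, absent a `pyrightconfig.json`, picks up the `[tool.pyright]` configuration from `pyproject.toml` below.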
`python/pyproject.toml`:

```diff
@@ -92,7 +92,7 @@ requires = ["maturin>=1.4"]
 build-backend = "maturin"
 
 [tool.ruff.lint]
-select = ["F", "E", "W", "G", "TCH", "PERF"]
+select = ["F", "E", "W", "G", "PERF"]
 
 [tool.pytest.ini_options]
 addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py"
@@ -103,5 +103,28 @@ markers = [
 ]
 
 [tool.pyright]
-include = ["python/lancedb/table.py"]
+include = [
+    "python/lancedb/index.py",
+    "python/lancedb/rerankers/util.py",
+    "python/lancedb/rerankers/__init__.py",
+    "python/lancedb/rerankers/voyageai.py",
+    "python/lancedb/rerankers/jinaai.py",
+    "python/lancedb/rerankers/openai.py",
+    "python/lancedb/rerankers/cross_encoder.py",
+    "python/lancedb/rerankers/colbert.py",
+    "python/lancedb/rerankers/answerdotai.py",
+    "python/lancedb/rerankers/cohere.py",
+    "python/lancedb/arrow.py",
+    "python/lancedb/__init__.py",
+    "python/lancedb/types.py",
+    "python/lancedb/integrations/__init__.py",
+    "python/lancedb/exceptions.py",
+    "python/lancedb/background_loop.py",
+    "python/lancedb/schema.py",
+    "python/lancedb/remote/__init__.py",
+    "python/lancedb/remote/errors.py",
+    "python/lancedb/embeddings/__init__.py",
+    "python/lancedb/_lancedb.pyi",
+]
+exclude = ["python/tests/"]
 pythonVersion = "3.12"
```
`python/python/lancedb/__init__.py`:

```diff
@@ -14,6 +14,7 @@ from ._lancedb import connect as lancedb_connect
 from .common import URI, sanitize_uri
 from .db import AsyncConnection, DBConnection, LanceDBConnection
 from .remote import ClientConfig
+from .remote.db import RemoteDBConnection
 from .schema import vector
 from .table import AsyncTable
 
@@ -86,8 +87,6 @@ def connect(
     conn : DBConnection
         A connection to a LanceDB database.
     """
-    from .remote.db import RemoteDBConnection
-
     if isinstance(uri, str) and uri.startswith("db://"):
         if api_key is None:
             api_key = os.environ.get("LANCEDB_API_KEY")
```
`python/python/lancedb/_lancedb.pyi`:

```diff
@@ -3,6 +3,7 @@ from typing import Dict, List, Optional, Tuple, Any, Union, Literal
 import pyarrow as pa
 
 from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
+from .remote import ClientConfig
 
 class Connection(object):
     uri: str
@@ -71,11 +72,15 @@ async def connect(
     region: Optional[str],
     host_override: Optional[str],
     read_consistency_interval: Optional[float],
+    client_config: Optional[Union[ClientConfig, Dict[str, Any]]],
+    storage_options: Optional[Dict[str, str]],
 ) -> Connection: ...
 
 class RecordBatchStream:
+    @property
     def schema(self) -> pa.Schema: ...
-    async def next(self) -> Optional[pa.RecordBatch]: ...
+    def __aiter__(self) -> "RecordBatchStream": ...
+    async def __anext__(self) -> pa.RecordBatch: ...
 
 class Query:
     def where(self, filter: str): ...
```
`python/python/lancedb/remote/db.py`:

```diff
@@ -9,7 +9,8 @@ from typing import Any, Dict, Iterable, List, Optional, Union
 from urllib.parse import urlparse
 import warnings
 
-from lancedb import connect_async
+# Remove this import to fix circular dependency
+# from lancedb import connect_async
 from lancedb.remote import ClientConfig
 import pyarrow as pa
 from overrides import override
@@ -78,6 +79,9 @@ class RemoteDBConnection(DBConnection):
 
         self.client_config = client_config
 
+        # Import connect_async here to avoid circular import
+        from lancedb import connect_async
+
         self._conn = LOOP.run(
             connect_async(
                 db_url,
```
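
The `remote/db.py` change above is the standard deferred-import fix for a circular import: `lancedb/__init__.py` now imports `RemoteDBConnection` when the package loads, so `remote/db.py` can no longer import `connect_async` back at module level. A minimal sketch of the pattern, using hypothetical modules rather than the real ones:

```python
# pkg/__init__.py
from pkg.remote import RemoteConnection  # runs while pkg is initializing


def connect_async(uri: str) -> RemoteConnection:
    return RemoteConnection(uri)


# pkg/remote.py
class RemoteConnection:
    def __init__(self, uri: str):
        # Deferred import: by the time __init__ is *called*, pkg has finished
        # importing, so this avoids the pkg -> pkg.remote -> pkg cycle that a
        # module-level import would create.
        from pkg import connect_async  # noqa: F401

        self.uri = uri
```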
`python/python/lancedb/table.py`:

```diff
@@ -76,12 +76,21 @@ if TYPE_CHECKING:
     from .index import IndexConfig
     import pandas
     import PIL
+    from .types import (
+        QueryType,
+        OnBadVectorsType,
+        AddMode,
+        CreateMode,
+        VectorIndexType,
+        ScalarIndexType,
+        BaseTokenizerType,
+        DistanceType,
+    )
 
 
 pd = safe_import_pandas()
 pl = safe_import_polars()
 
-QueryType = Literal["vector", "fts", "hybrid", "auto"]
-
 
 def _into_pyarrow_reader(data) -> pa.RecordBatchReader:
     from lancedb.dependencies import datasets
@@ -178,7 +187,7 @@ def _sanitize_data(
     data: "DATA",
     target_schema: Optional[pa.Schema] = None,
     metadata: Optional[dict] = None,  # embedding metadata
-    on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error",
+    on_bad_vectors: OnBadVectorsType = "error",
     fill_value: float = 0.0,
     *,
     allow_subschema: bool = False,
@@ -324,7 +333,7 @@ def sanitize_create_table(
     data,
     schema: Union[pa.Schema, LanceModel],
     metadata=None,
-    on_bad_vectors: str = "error",
+    on_bad_vectors: OnBadVectorsType = "error",
     fill_value: float = 0.0,
 ):
     if inspect.isclass(schema) and issubclass(schema, LanceModel):
@@ -576,9 +585,7 @@ class Table(ABC):
         accelerator: Optional[str] = None,
         index_cache_size: Optional[int] = None,
         *,
-        index_type: Literal[
-            "IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
-        ] = "IVF_PQ",
+        index_type: VectorIndexType = "IVF_PQ",
         num_bits: int = 8,
         max_iterations: int = 50,
         sample_rate: int = 256,
@@ -643,7 +650,7 @@ class Table(ABC):
         column: str,
         *,
         replace: bool = True,
-        index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE",
+        index_type: ScalarIndexType = "BTREE",
     ):
         """Create a scalar index on a column.
 
@@ -708,7 +715,7 @@ class Table(ABC):
         tokenizer_name: Optional[str] = None,
         with_position: bool = True,
         # tokenizer configs:
-        base_tokenizer: Literal["simple", "raw", "whitespace"] = "simple",
+        base_tokenizer: BaseTokenizerType = "simple",
         language: str = "English",
         max_token_length: Optional[int] = 40,
         lower_case: bool = True,
@@ -777,8 +784,8 @@ class Table(ABC):
     def add(
         self,
         data: DATA,
-        mode: str = "append",
-        on_bad_vectors: str = "error",
+        mode: AddMode = "append",
+        on_bad_vectors: OnBadVectorsType = "error",
         fill_value: float = 0.0,
     ):
         """Add more data to the [Table](Table).
@@ -960,7 +967,7 @@ class Table(ABC):
         self,
         merge: LanceMergeInsertBuilder,
         new_data: DATA,
-        on_bad_vectors: str,
+        on_bad_vectors: OnBadVectorsType,
         fill_value: float,
     ): ...
 
@@ -1572,10 +1579,10 @@ class LanceTable(Table):
 
     def create_index(
         self,
-        metric="L2",
+        metric: DistanceType = "l2",
         num_partitions=None,
         num_sub_vectors=None,
-        vector_column_name=VECTOR_COLUMN_NAME,
+        vector_column_name: str = VECTOR_COLUMN_NAME,
         replace: bool = True,
         accelerator: Optional[str] = None,
         index_cache_size: Optional[int] = None,
@@ -1661,7 +1668,7 @@ class LanceTable(Table):
         column: str,
         *,
         replace: bool = True,
-        index_type: Literal["BTREE", "BITMAP", "LABEL_LIST"] = "BTREE",
+        index_type: ScalarIndexType = "BTREE",
     ):
         if index_type == "BTREE":
             config = BTree()
@@ -1686,7 +1693,7 @@ class LanceTable(Table):
         tokenizer_name: Optional[str] = None,
         with_position: bool = True,
         # tokenizer configs:
-        base_tokenizer: str = "simple",
+        base_tokenizer: BaseTokenizerType = "simple",
         language: str = "English",
         max_token_length: Optional[int] = 40,
         lower_case: bool = True,
@@ -1820,8 +1827,8 @@ class LanceTable(Table):
     def add(
         self,
         data: DATA,
-        mode: str = "append",
-        on_bad_vectors: str = "error",
+        mode: AddMode = "append",
+        on_bad_vectors: OnBadVectorsType = "error",
         fill_value: float = 0.0,
     ):
         """Add data to the table.
@@ -2059,7 +2066,7 @@ class LanceTable(Table):
             query_type,
             vector_column_name=vector_column_name,
             ordering_field_name=ordering_field_name,
-            fts_columns=fts_columns,
+            fts_columns=fts_columns or [],
         )
 
     @classmethod
@@ -2069,13 +2076,13 @@ class LanceTable(Table):
         name: str,
         data: Optional[DATA] = None,
         schema: Optional[pa.Schema] = None,
-        mode: Literal["create", "overwrite"] = "create",
+        mode: CreateMode = "create",
         exist_ok: bool = False,
-        on_bad_vectors: str = "error",
+        on_bad_vectors: OnBadVectorsType = "error",
         fill_value: float = 0.0,
         embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
         *,
-        storage_options: Optional[Dict[str, str]] = None,
+        storage_options: Optional[Dict[str, str | bool]] = None,
         data_storage_version: Optional[str] = None,
         enable_v2_manifest_paths: Optional[bool] = None,
     ):
@@ -2229,7 +2236,7 @@ class LanceTable(Table):
         self,
         merge: LanceMergeInsertBuilder,
         new_data: DATA,
-        on_bad_vectors: str,
+        on_bad_vectors: OnBadVectorsType,
         fill_value: float,
     ):
         LOOP.run(self._table._do_merge(merge, new_data, on_bad_vectors, fill_value))
@@ -2880,7 +2887,7 @@ class AsyncTable:
         data: DATA,
         *,
         mode: Optional[Literal["append", "overwrite"]] = "append",
-        on_bad_vectors: Optional[str] = None,
+        on_bad_vectors: Optional[OnBadVectorsType] = None,
         fill_value: Optional[float] = None,
     ):
         """Add more data to the [Table](Table).
@@ -2986,7 +2993,7 @@ class AsyncTable:
     @overload
     async def search(
         self,
-        query: Optional[Union[str]] = None,
+        query: Optional[str] = None,
         vector_column_name: Optional[str] = None,
         query_type: Literal["auto"] = ...,
         ordering_field_name: Optional[str] = None,
@@ -2996,7 +3003,7 @@ class AsyncTable:
     @overload
     async def search(
         self,
-        query: Optional[Union[str]] = None,
+        query: Optional[str] = None,
         vector_column_name: Optional[str] = None,
         query_type: Literal["hybrid"] = ...,
         ordering_field_name: Optional[str] = None,
@@ -3040,7 +3047,7 @@ class AsyncTable:
         query_type: QueryType = "auto",
         ordering_field_name: Optional[str] = None,
         fts_columns: Optional[Union[str, List[str]]] = None,
-    ) -> AsyncQuery:
+    ) -> Union[AsyncHybridQuery | AsyncFTSQuery | AsyncVectorQuery]:
         """Create a search query to find the nearest neighbors
         of the given query vector. We currently support [vector search][search]
         and [full-text search][experimental-full-text-search].
@@ -3279,7 +3286,7 @@ class AsyncTable:
         self,
         merge: LanceMergeInsertBuilder,
         new_data: DATA,
-        on_bad_vectors: str,
+        on_bad_vectors: OnBadVectorsType,
         fill_value: float,
     ):
         schema = await self.schema()
```
`python/python/lancedb/types.py` (new file, 28 lines):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

from typing import Literal

# Query type literals
QueryType = Literal["vector", "fts", "hybrid", "auto"]

# Distance type literals
DistanceType = Literal["l2", "cosine", "dot"]
DistanceTypeWithHamming = Literal["l2", "cosine", "dot", "hamming"]

# Vector handling literals
OnBadVectorsType = Literal["error", "drop", "fill", "null"]

# Mode literals
AddMode = Literal["append", "overwrite"]
CreateMode = Literal["create", "overwrite"]

# Index type literals
VectorIndexType = Literal["IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"]
ScalarIndexType = Literal["BTREE", "BITMAP", "LABEL_LIST"]
IndexType = Literal[
    "IVF_PQ", "IVF_HNSW_PQ", "IVF_HNSW_SQ", "FTS", "BTREE", "BITMAP", "LABEL_LIST"
]

# Tokenizer literals
BaseTokenizerType = Literal["simple", "raw", "whitespace"]
```
Test files under `python/python/tests/`:

```diff
@@ -419,17 +419,17 @@ def test_embedding_function_safe_model_dump(embedding_type):
 
     dumped_model = model.safe_model_dump()
 
-    assert all(
-        not k.startswith("_") for k in dumped_model.keys()
-    ), f"{embedding_type}: Dumped model contains keys starting with underscore"
+    assert all(not k.startswith("_") for k in dumped_model.keys()), (
+        f"{embedding_type}: Dumped model contains keys starting with underscore"
+    )
 
-    assert (
-        "max_retries" in dumped_model
-    ), f"{embedding_type}: Essential field 'max_retries' is missing from dumped model"
+    assert "max_retries" in dumped_model, (
+        f"{embedding_type}: Essential field 'max_retries' is missing from dumped model"
+    )
 
-    assert isinstance(
-        dumped_model, dict
-    ), f"{embedding_type}: Dumped model is not a dictionary"
+    assert isinstance(dumped_model, dict), (
+        f"{embedding_type}: Dumped model is not a dictionary"
+    )
 
     for key in model.__dict__:
         if key.startswith("_"):
@@ -129,6 +129,6 @@ def test_normalize_scores():
     if invert:
         expected = pc.subtract(1.0, expected)
 
-    assert pc.equal(
-        result, expected
-    ), f"Expected {expected} but got {result} for invert={invert}"
+    assert pc.equal(result, expected), (
+        f"Expected {expected} but got {result} for invert={invert}"
+    )
@@ -784,8 +784,7 @@ async def test_query_search_auto(mem_db_async: AsyncConnection):
     with pytest.raises(
         Exception,
         match=(
-            "Cannot perform full text search unless an INVERTED index has "
-            "been created"
+            "Cannot perform full text search unless an INVERTED index has been created"
         ),
     ):
         query = await (await tbl2.search("0.1")).to_arrow()
@@ -131,9 +131,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
         "represents the relevance of the result to the query & should "
         "be descending."
     )
-    assert np.all(
-        np.diff(result.column("_relevance_score").to_numpy()) <= 0
-    ), ascending_relevance_err
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+        ascending_relevance_err
+    )
 
     # Vector search setting
     result = (
@@ -143,9 +143,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
         .to_arrow()
     )
     assert len(result) == 30
-    assert np.all(
-        np.diff(result.column("_relevance_score").to_numpy()) <= 0
-    ), ascending_relevance_err
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+        ascending_relevance_err
+    )
     result_explicit = (
         table.search(query_vector, vector_column_name="vector")
         .rerank(reranker=reranker, query_string=query)
@@ -168,9 +168,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
         .to_arrow()
     )
     assert len(result) > 0
-    assert np.all(
-        np.diff(result.column("_relevance_score").to_numpy()) <= 0
-    ), ascending_relevance_err
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+        ascending_relevance_err
+    )
 
     # empty FTS results
     query = "abcxyz" * 100
@@ -185,9 +185,9 @@ def _run_test_reranker(reranker, table, query, query_vector, schema):
 
     # should return _relevance_score column
     assert "_relevance_score" in result.column_names
-    assert np.all(
-        np.diff(result.column("_relevance_score").to_numpy()) <= 0
-    ), ascending_relevance_err
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+        ascending_relevance_err
+    )
 
     # Multi-vector search setting
     rs1 = table.search(query, vector_column_name="vector").limit(10).with_row_id(True)
@@ -262,9 +262,9 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
         "represents the relevance of the result to the query & should "
         "be descending."
     )
-    assert np.all(
-        np.diff(result.column("_relevance_score").to_numpy()) <= 0
-    ), ascending_relevance_err
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+        ascending_relevance_err
+    )
 
     # Test with empty FTS results
     query = "abcxyz" * 100
@@ -278,9 +278,9 @@ def _run_test_hybrid_reranker(reranker, tmp_path, use_tantivy):
     )
     # should return _relevance_score column
    assert "_relevance_score" in result.column_names
-    assert np.all(
-        np.diff(result.column("_relevance_score").to_numpy()) <= 0
-    ), ascending_relevance_err
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+        ascending_relevance_err
+    )
 
 
 @pytest.mark.parametrize("use_tantivy", [True, False])
```