diff --git a/python/lancedb/query.py b/python/lancedb/query.py index 743602ad..ebfbd11e 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -16,6 +16,12 @@ from __future__ import annotations from abc import ABC, abstractmethod from typing import TYPE_CHECKING, List, Literal, Optional, Type, Union +try: + # Python 3.11+ + from typing import Self +except ImportError: + from typing_extensions import Self + import deprecation import numpy as np import pyarrow as pa @@ -275,7 +281,7 @@ class LanceQueryBuilder(ABC): self._limit = limit return self - def select(self, columns: list) -> LanceQueryBuilder: + def select(self, columns: list) -> Self: """Set the columns to return. Parameters @@ -291,7 +297,7 @@ class LanceQueryBuilder(ABC): self._columns = columns return self - def where(self, where: str, prefilter: bool = False) -> LanceQueryBuilder: + def where(self, where: str, prefilter: bool = False) -> Self: """Set the where clause. Parameters @@ -351,7 +357,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): self._vector_column = vector_column self._prefilter = False - def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder: + def metric(self, metric: Literal["L2", "cosine"]) -> Self: """Set the distance metric to use. Parameters @@ -367,7 +373,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): self._metric = metric return self - def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder: + def nprobes(self, nprobes: int) -> Self: """Set the number of probes to use. Higher values will yield better recall (more likely to find vectors if @@ -389,7 +395,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): self._nprobes = nprobes return self - def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder: + def refine_factor(self, refine_factor: int) -> Self: """Set the refine factor to use, increasing the number of vectors sampled. As an example, a refine factor of 2 will sample 2x as many vectors as @@ -434,7 +440,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): ) return self._table._execute_query(query) - def where(self, where: str, prefilter: bool = False) -> LanceVectorQueryBuilder: + def where(self, where: str, prefilter: bool = False) -> Self: """Set the where clause. Parameters diff --git a/python/pyproject.toml b/python/pyproject.toml index f3c3142d..32c1393b 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -15,7 +15,8 @@ dependencies = [ "pyyaml>=6.0", "click>=8.1.7", "requests>=2.31.0", - "overrides>=0.7" + "overrides>=0.7", + "typing_extensions>=4.7", ] description = "lancedb" authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }] @@ -49,11 +50,27 @@ classifiers = [ repository = "https://github.com/lancedb/lancedb" [project.optional-dependencies] -tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests", "duckdb", "pytz"] +tests = [ + "pandas>=1.4", + "pytest", + "pytest-mock", + "pytest-asyncio", + "requests", + "duckdb", + "pytz" +] dev = ["ruff", "pre-commit", "black"] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] clip = ["torch", "pillow", "open-clip"] -embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "InstructorEmbedding"] +embeddings = [ + "openai>=1.6.1", + "sentence-transformers", + "torch", + "pillow", + "open-clip-torch", + "cohere", + "InstructorEmbedding" +] [project.scripts] lancedb = "lancedb.cli.cli:cli"