typing Self

This commit is contained in:
Lei Xu
2023-12-29 09:16:46 -08:00
parent 392777952f
commit 3dc8b3305e
2 changed files with 32 additions and 9 deletions

View File

@@ -16,6 +16,12 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Literal, Optional, Type, Union
try:
# Python 3.11+
from typing import Self
except ImportError:
from typing_extensions import Self
import deprecation
import numpy as np
import pyarrow as pa
@@ -275,7 +281,7 @@ class LanceQueryBuilder(ABC):
self._limit = limit
return self
def select(self, columns: list) -> LanceQueryBuilder:
def select(self, columns: list) -> Self:
"""Set the columns to return.
Parameters
@@ -291,7 +297,7 @@ class LanceQueryBuilder(ABC):
self._columns = columns
return self
def where(self, where: str, prefilter: bool = False) -> LanceQueryBuilder:
def where(self, where: str, prefilter: bool = False) -> Self:
"""Set the where clause.
Parameters
@@ -351,7 +357,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._vector_column = vector_column
self._prefilter = False
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
def metric(self, metric: Literal["L2", "cosine"]) -> Self:
"""Set the distance metric to use.
Parameters
@@ -367,7 +373,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._metric = metric
return self
def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder:
def nprobes(self, nprobes: int) -> Self:
"""Set the number of probes to use.
Higher values will yield better recall (more likely to find vectors if
@@ -389,7 +395,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._nprobes = nprobes
return self
def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder:
def refine_factor(self, refine_factor: int) -> Self:
"""Set the refine factor to use, increasing the number of vectors sampled.
As an example, a refine factor of 2 will sample 2x as many vectors as
@@ -434,7 +440,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
)
return self._table._execute_query(query)
def where(self, where: str, prefilter: bool = False) -> LanceVectorQueryBuilder:
def where(self, where: str, prefilter: bool = False) -> Self:
"""Set the where clause.
Parameters

View File

@@ -15,7 +15,8 @@ dependencies = [
"pyyaml>=6.0",
"click>=8.1.7",
"requests>=2.31.0",
"overrides>=0.7"
"overrides>=0.7",
"typing_extensions>=4.7",
]
description = "lancedb"
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
@@ -49,11 +50,27 @@ classifiers = [
repository = "https://github.com/lancedb/lancedb"
[project.optional-dependencies]
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests", "duckdb", "pytz"]
tests = [
"pandas>=1.4",
"pytest",
"pytest-mock",
"pytest-asyncio",
"requests",
"duckdb",
"pytz"
]
dev = ["ruff", "pre-commit", "black"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"]
embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "InstructorEmbedding"]
embeddings = [
"openai>=1.6.1",
"sentence-transformers",
"torch",
"pillow",
"open-clip-torch",
"cohere",
"InstructorEmbedding"
]
[project.scripts]
lancedb = "lancedb.cli.cli:cli"