From d012db24c2dd11b7cc1c2485372ebd1a53b5b0c3 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 19 Jan 2024 13:09:14 -0800 Subject: [PATCH] ci: lint and enforce linting (#829) @eddyxu added instructions for linting here: https://github.com/lancedb/lancedb/blob/7af213801a091cd5afb0f0814e184fc0b852de47/python/README.md?plain=1#L45-L50 However, we had a lot of failures and weren't checking this in CI. This PR fixes all lints and adds a check to CI to keep us in compliance with the lints. --- .github/workflows/python.yml | 4 +- python/lancedb/cli/cli.py | 7 +-- python/lancedb/conftest.py | 1 - python/lancedb/context.py | 3 +- python/lancedb/embeddings/__init__.py | 2 +- python/lancedb/embeddings/gemini_text.py | 40 +++++++++------ python/lancedb/embeddings/instructor.py | 51 +++++++++++-------- python/lancedb/embeddings/open_clip.py | 6 ++- python/lancedb/embeddings/registry.py | 7 ++- .../embeddings/sentence_transformers.py | 1 - python/lancedb/embeddings/utils.py | 19 ++++--- python/lancedb/fts.py | 4 +- python/lancedb/pydantic.py | 31 ++++++++--- python/lancedb/query.py | 35 ++++++++----- python/lancedb/remote/client.py | 4 +- python/lancedb/remote/db.py | 10 ++-- python/lancedb/remote/table.py | 16 +++--- python/lancedb/table.py | 14 +++-- python/lancedb/util.py | 2 +- python/lancedb/utils/config.py | 12 +++-- python/lancedb/utils/events.py | 23 +++++---- python/lancedb/utils/general.py | 35 ++++++++----- python/lancedb/utils/sentry_log.py | 24 +++++---- python/tests/test_cli.py | 9 ++-- python/tests/test_db.py | 3 +- python/tests/test_embeddings.py | 2 +- python/tests/test_pydantic.py | 51 ++++++++++--------- python/tests/test_telemetry.py | 3 +- 28 files changed, 250 insertions(+), 169 deletions(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 835cd41e..dc1c2ace 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -38,8 +38,10 @@ jobs: pip install -e .[tests] pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985 pip install pytest pytest-mock ruff - - name: Lint + - name: Format check run: ruff format --check . + - name: Lint + run: ruff . - name: Run tests run: pytest -m "not slow" -x -v --durations=30 tests - name: doctest diff --git a/python/lancedb/cli/cli.py b/python/lancedb/cli/cli.py index 5f51148b..06c54078 100644 --- a/python/lancedb/cli/cli.py +++ b/python/lancedb/cli/cli.py @@ -23,9 +23,10 @@ def cli(): diagnostics_help = """ -Enable or disable LanceDB diagnostics. When enabled, LanceDB will send anonymous events to help us improve LanceDB. -These diagnostics are used only for error reporting and no data is collected. You can find more about diagnosis on -our docs: https://lancedb.github.io/lancedb/cli_config/ +Enable or disable LanceDB diagnostics. When enabled, LanceDB will send anonymous events +to help us improve LanceDB. These diagnostics are used only for error reporting and no +data is collected. You can find more about diagnosis on our docs: +https://lancedb.github.io/lancedb/cli_config/ """ diff --git a/python/lancedb/conftest.py b/python/lancedb/conftest.py index df4907a7..273afbf7 100644 --- a/python/lancedb/conftest.py +++ b/python/lancedb/conftest.py @@ -1,6 +1,5 @@ import os import time -from typing import Any import numpy as np import pytest diff --git a/python/lancedb/context.py b/python/lancedb/context.py index 02051614..bd7b04c8 100644 --- a/python/lancedb/context.py +++ b/python/lancedb/context.py @@ -59,7 +59,8 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer: 8 dog I love 1 9 I love sandwiches 2 10 love sandwiches 2 - >>> contextualize(data).window(7).stride(1).min_window_size(7).text_col('token').to_pandas() + >>> (contextualize(data).window(7).stride(1).min_window_size(7) + ... .text_col('token').to_pandas()) token document_id 0 The quick brown fox jumped over the 1 1 quick brown fox jumped over the lazy 1 diff --git a/python/lancedb/embeddings/__init__.py b/python/lancedb/embeddings/__init__.py index cea0f381..ff60893e 100644 --- a/python/lancedb/embeddings/__init__.py +++ b/python/lancedb/embeddings/__init__.py @@ -14,10 +14,10 @@ # ruff: noqa: F401 from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction from .cohere import CohereEmbeddingFunction +from .gemini_text import GeminiText from .instructor import InstructorEmbeddingFunction from .open_clip import OpenClipEmbeddings from .openai import OpenAIEmbeddings from .registry import EmbeddingFunctionRegistry, get_registry from .sentence_transformers import SentenceTransformerEmbeddings -from .gemini_text import GeminiText from .utils import with_embeddings diff --git a/python/lancedb/embeddings/gemini_text.py b/python/lancedb/embeddings/gemini_text.py index d71ec7f5..e0f103c2 100644 --- a/python/lancedb/embeddings/gemini_text.py +++ b/python/lancedb/embeddings/gemini_text.py @@ -13,40 +13,50 @@ import os from functools import cached_property -from typing import List, Union, Any +from typing import List, Union import numpy as np +from lancedb.pydantic import PYDANTIC_VERSION + from .base import TextEmbeddingFunction from .registry import register -from .utils import api_key_not_found_help, TEXT -from lancedb.pydantic import PYDANTIC_VERSION +from .utils import TEXT, api_key_not_found_help @register("gemini-text") class GeminiText(TextEmbeddingFunction): """ - An embedding function that uses the Google's Gemini API. Requires GOOGLE_API_KEY to be set. + An embedding function that uses the Google's Gemini API. Requires GOOGLE_API_KEY to + be set. https://ai.google.dev/docs/embeddings_guide Supports various tasks types: - | Task Type | Description | - |-------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------| - | "`retrieval_query`" | Specifies the given text is a query in a search/retrieval setting. | - | "`retrieval_document`" | Specifies the given text is a document in a search/retrieval setting. Using this task type requires a title but is automatically proided by Embeddings API | - | "`semantic_similarity`" | Specifies the given text will be used for Semantic Textual Similarity (STS). | - | "`classification`" | Specifies that the embeddings will be used for classification. | - | "`clusering`" | Specifies that the embeddings will be used for clustering. | + | Task Type | Description | + |-------------------------|--------------------------------------------------------| + | "`retrieval_query`" | Specifies the given text is a query in a | + | | search/retrieval setting. | + | "`retrieval_document`" | Specifies the given text is a document in a | + | | search/retrieval setting. Using this task type | + | | requires a title but is automatically provided by | + | | Embeddings API | + | "`semantic_similarity`" | Specifies the given text will be used for Semantic | + | | Textual Similarity (STS). | + | "`classification`" | Specifies that the embeddings will be used for | + | | classification. | + | "`clustering`" | Specifies that the embeddings will be used for | + | | clustering. | - - Note: The supported task types might change in the Gemini API, but as long as a supported task type and its argument set is provided, - those will be delegated to the API calls. + Note: The supported task types might change in the Gemini API, but as long as a + supported task type and its argument set is provided, those will be delegated + to the API calls. Parameters ---------- name: str, default "models/embedding-001" - The name of the model to use. See the Gemini documentation for a list of available models. + The name of the model to use. See the Gemini documentation for a list of + available models. query_task_type: str, default "retrieval_query" Sets the task type for the queries. diff --git a/python/lancedb/embeddings/instructor.py b/python/lancedb/embeddings/instructor.py index 53be8ccb..c2058b27 100644 --- a/python/lancedb/embeddings/instructor.py +++ b/python/lancedb/embeddings/instructor.py @@ -22,22 +22,29 @@ from .utils import TEXT, weak_lru @register("instructor") class InstructorEmbeddingFunction(TextEmbeddingFunction): """ - An embedding function that uses the InstructorEmbedding library. Instructor models support multi-task learning, and can be used for a - variety of tasks, including text classification, sentence similarity, and document retrieval. - If you want to calculate customized embeddings for specific sentences, you may follow the unified template to write instructions: + An embedding function that uses the InstructorEmbedding library. Instructor models + support multi-task learning, and can be used for a variety of tasks, including + text classification, sentence similarity, and document retrieval. If you want to + calculate customized embeddings for specific sentences, you may follow the unified + template to write instructions: "Represent the `domain` `text_type` for `task_objective`": - * domain is optional, and it specifies the domain of the text, e.g., science, finance, medicine, etc. - * text_type is required, and it specifies the encoding unit, e.g., sentence, document, paragraph, etc. - * task_objective is optional, and it specifies the objective of embedding, e.g., retrieve a document, classify the sentence, etc. + * domain is optional, and it specifies the domain of the text, e.g., science, + finance, medicine, etc. + * text_type is required, and it specifies the encoding unit, e.g., sentence, + document, paragraph, etc. + * task_objective is optional, and it specifies the objective of embedding, + e.g., retrieve a document, classify the sentence, etc. - For example, if you want to calculate embeddings for a document, you may write the instruction as follows: - "Represent the document for retreival" + For example, if you want to calculate embeddings for a document, you may write the + instruction as follows: + "Represent the document for retrieval" Parameters ---------- name: str - The name of the model to use. Available models are listed at https://github.com/xlang-ai/instructor-embedding#model-list; + The name of the model to use. Available models are listed at + https://github.com/xlang-ai/instructor-embedding#model-list; The default model is hkunlp/instructor-base batch_size: int, default 32 The batch size to use when generating embeddings @@ -49,21 +56,24 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction): Whether to normalize the embeddings quantize: bool, default False Whether to quantize the model - source_instruction: str, default "represent the docuement for retreival" + source_instruction: str, default "represent the document for retrieval" The instruction for the source column - query_instruction: str, default "represent the document for retreiving the most similar documents" + query_instruction: str, default "represent the document for retrieving the most + similar documents" The instruction for the query Examples -------- + import lancedb from lancedb.pydantic import LanceModel, Vector from lancedb.embeddings import get_registry, InstuctorEmbeddingFunction instructor = get_registry().get("instructor").create( - source_instruction="represent the docuement for retreival", - query_instruction="represent the document for retreiving the most similar documents" - ) + source_instruction="represent the document for retrieval", + query_instruction="represent the document for retrieving the most " + "similar documents" + ) class Schema(LanceModel): vector: Vector(instructor.ndims()) = instructor.VectorField() @@ -72,9 +82,12 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction): db = lancedb.connect("~/.lancedb") tbl = db.create_table("test", schema=Schema, mode="overwrite") - texts = [{"text": "Capitalism has been dominant in the Western world since the end of feudalism, but most feel[who?] that..."}, - {"text": "The disparate impact theory is especially controversial under the Fair Housing Act because the Act..."}, - {"text": "Disparate impact in United States labor law refers to practices in employment, housing, and other areas that.."}] + texts = [{"text": "Capitalism has been dominant in the Western world since the " + "end of feudalism, but most feel[who?] that..."}, + {"text": "The disparate impact theory is especially controversial under " + "the Fair Housing Act because the Act..."}, + {"text": "Disparate impact in United States labor law refers to practices " + "in employment, housing, and other areas that.."}] tbl.add(texts) @@ -103,9 +116,7 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction): def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]: texts = self.sanitize_input(texts) - texts_formatted = [] - for text in texts: - texts_formatted.append([self.source_instruction, text]) + texts_formatted = [[self.source_instruction, text] for text in texts] return self.generate_embeddings(texts_formatted) def generate_embeddings(self, texts: List) -> List: diff --git a/python/lancedb/embeddings/open_clip.py b/python/lancedb/embeddings/open_clip.py index 1e9aeb6e..6392b0ef 100644 --- a/python/lancedb/embeddings/open_clip.py +++ b/python/lancedb/embeddings/open_clip.py @@ -14,7 +14,7 @@ import concurrent.futures import io import os import urllib.parse as urlparse -from typing import List, Union +from typing import TYPE_CHECKING, List, Union import numpy as np import pyarrow as pa @@ -25,6 +25,10 @@ from .base import EmbeddingFunction from .registry import register from .utils import IMAGES, url_retrieve +if TYPE_CHECKING: + import PIL + import torch + @register("open-clip") class OpenClipEmbeddings(EmbeddingFunction): diff --git a/python/lancedb/embeddings/registry.py b/python/lancedb/embeddings/registry.py index af7600dc..d5ab1f35 100644 --- a/python/lancedb/embeddings/registry.py +++ b/python/lancedb/embeddings/registry.py @@ -23,7 +23,9 @@ class EmbeddingFunctionRegistry: You can implement your own embedding function by subclassing EmbeddingFunction or TextEmbeddingFunction and registering it with the registry. - NOTE: Here TEXT is a type alias for Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray] + NOTE: Here TEXT is a type alias for Union[str, List[str], pa.Array, + pa.ChunkedArray, np.ndarray] + Examples -------- >>> registry = EmbeddingFunctionRegistry.get_instance() @@ -164,7 +166,8 @@ __REGISTRY__ = EmbeddingFunctionRegistry() # @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8 -register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name) +def register(name): + return __REGISTRY__.get_instance().register(name) def get_registry(): diff --git a/python/lancedb/embeddings/sentence_transformers.py b/python/lancedb/embeddings/sentence_transformers.py index 045fe2eb..d958e054 100644 --- a/python/lancedb/embeddings/sentence_transformers.py +++ b/python/lancedb/embeddings/sentence_transformers.py @@ -13,7 +13,6 @@ from typing import List, Union import numpy as np -from cachetools import cached from .base import TextEmbeddingFunction from .registry import register diff --git a/python/lancedb/embeddings/utils.py b/python/lancedb/embeddings/utils.py index 5aea89d2..325145f4 100644 --- a/python/lancedb/embeddings/utils.py +++ b/python/lancedb/embeddings/utils.py @@ -112,7 +112,8 @@ class FunctionWrapper: v = int(sys.version_info.minor) if v >= 11: print( - "WARNING: rate limit only support up to 3.10, proceeding without rate limiter" + "WARNING: rate limit only support up to 3.10, proceeding " + "without rate limiter" ) else: import ratelimiter @@ -168,8 +169,8 @@ class FunctionWrapper: def weak_lru(maxsize=128): """ - LRU cache that keeps weak references to the objects it caches. Only caches the latest instance of the objects to make sure memory usage - is bounded. + LRU cache that keeps weak references to the objects it caches. Only caches the + latest instance of the objects to make sure memory usage is bounded. Parameters ---------- @@ -234,15 +235,17 @@ def retry_with_exponential_backoff( num_retries = 0 delay = initial_delay - # Loop until a successful response or max_retries is hit or an exception is raised + # Loop until a successful response or max_retries is hit or an exception + # is raised while True: try: return func(*args, **kwargs) - # Currently retrying on all exceptions as there is no way to know the format of the error msgs used by different APIs - # We'll log the error and say that it is assumed that if this portion errors out, it's due to rate limit but the user - # should check the error message to be sure - except Exception as e: + # Currently retrying on all exceptions as there is no way to know the + # format of the error msgs used by different APIs. We'll log the error + # and say that it is assumed that if this portion errors out, it's due + # to rate limit but the user should check the error message to be sure. + except Exception as e: # noqa: PERF203 num_retries += 1 if num_retries > max_retries: diff --git a/python/lancedb/fts.py b/python/lancedb/fts.py index f9667fcc..750e3076 100644 --- a/python/lancedb/fts.py +++ b/python/lancedb/fts.py @@ -13,7 +13,7 @@ """Full text search index using tantivy-py""" import os -from typing import List, Optional, Tuple +from typing import List, Tuple import pyarrow as pa @@ -21,7 +21,7 @@ try: import tantivy except ImportError: raise ImportError( - "Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature." + "Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature." # noqa: E501 ) from .table import LanceTable diff --git a/python/lancedb/pydantic.py b/python/lancedb/pydantic.py index 859eeaa8..2a550032 100644 --- a/python/lancedb/pydantic.py +++ b/python/lancedb/pydantic.py @@ -20,15 +20,22 @@ import sys import types from abc import ABC, abstractmethod from datetime import date, datetime -from typing import Any, Callable, Dict, Generator, List, Type, Union, _GenericAlias +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generator, + List, + Type, + Union, + _GenericAlias, +) import numpy as np import pyarrow as pa import pydantic import semver -from pydantic.fields import FieldInfo - -from .embeddings import EmbeddingFunctionRegistry PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__) try: @@ -37,6 +44,11 @@ except ImportError: if PYDANTIC_VERSION >= (2,): raise +if TYPE_CHECKING: + from pydantic.fields import FieldInfo + + from .embeddings import EmbeddingFunctionConfig + class FixedSizeListMixin(ABC): @staticmethod @@ -190,7 +202,7 @@ else: ] -def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType: +def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType: """Convert a Pydantic FieldInfo to Arrow DataType""" if isinstance(field.annotation, _GenericAlias) or ( @@ -221,7 +233,7 @@ def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType: return _py_type_to_arrow_type(field.annotation, field) -def is_nullable(field: pydantic.fields.FieldInfo) -> bool: +def is_nullable(field: FieldInfo) -> bool: """Check if a Pydantic FieldInfo is nullable.""" if isinstance(field.annotation, _GenericAlias): origin = field.annotation.__origin__ @@ -237,7 +249,7 @@ def is_nullable(field: pydantic.fields.FieldInfo) -> bool: return False -def _pydantic_to_field(name: str, field: pydantic.fields.FieldInfo) -> pa.Field: +def _pydantic_to_field(name: str, field: FieldInfo) -> pa.Field: """Convert a Pydantic field to a PyArrow Field.""" dt = _pydantic_to_arrow_type(field) return pa.field(name, dt, is_nullable(field)) @@ -309,6 +321,9 @@ class LanceModel(pydantic.BaseModel): schema = pydantic_to_schema(cls) functions = cls.parse_embedding_functions() if len(functions) > 0: + # Prevent circular import + from .embeddings import EmbeddingFunctionRegistry + metadata = EmbeddingFunctionRegistry.get_instance().get_table_metadata( functions ) @@ -359,7 +374,7 @@ class LanceModel(pydantic.BaseModel): return configs -def get_extras(field_info: pydantic.fields.FieldInfo, key: str) -> Any: +def get_extras(field_info: FieldInfo, key: str) -> Any: """ Get the extra metadata from a Pydantic FieldInfo. """ diff --git a/python/lancedb/query.py b/python/lancedb/query.py index 51eab73a..23e76c0d 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -27,7 +27,11 @@ from .common import VECTOR_COLUMN_NAME from .util import safe_import_pandas if TYPE_CHECKING: + import PIL + import polars as pl + from .pydantic import LanceModel + from .table import Table pd = safe_import_pandas() @@ -60,7 +64,7 @@ class Query(pydantic.BaseModel): - See discussion in [Querying an ANN Index][querying-an-ann-index] for tuning advice. refine_factor : Optional[int] - Refine the results by reading extra elements and re-ranking them in memory - optional + Refine the results by reading extra elements and re-ranking them in memory. - A higher number makes search more accurate but also slower. @@ -104,7 +108,7 @@ class LanceQueryBuilder(ABC): @classmethod def create( cls, - table: "lancedb.table.Table", + table: "Table", query: Optional[Union[np.ndarray, str, "PIL.Image.Image"]], query_type: str, vector_column_name: str, @@ -163,7 +167,7 @@ class LanceQueryBuilder(ABC): f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}" ) - def __init__(self, table: "lancedb.table.Table"): + def __init__(self, table: "Table"): self._table = table self._limit = 10 self._columns = None @@ -205,7 +209,6 @@ class LanceQueryBuilder(ABC): if flatten is True: while True: tbl = tbl.flatten() - has_struct = False # loop through all columns to check if there is any struct column if any(pa.types.is_struct(col.type) for col in tbl.schema): continue @@ -214,7 +217,8 @@ class LanceQueryBuilder(ABC): elif isinstance(flatten, int): if flatten <= 0: raise ValueError( - "Please specify a positive integer for flatten or the boolean value `True`" + "Please specify a positive integer for flatten or the boolean " + "value `True`" ) while flatten > 0: tbl = tbl.flatten() @@ -361,7 +365,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): def __init__( self, - table: "lancedb.table.Table", + table: "Table", query: Union[np.ndarray, list, "PIL.Image.Image"], vector_column: str = VECTOR_COLUMN_NAME, ): @@ -486,7 +490,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): class LanceFtsQueryBuilder(LanceQueryBuilder): """A builder for full text search for LanceDB.""" - def __init__(self, table: "lancedb.table.Table", query: str): + def __init__(self, table: "Table", query: str): super().__init__(table) self._query = query self._phrase_query = False @@ -513,7 +517,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): import tantivy except ImportError: raise ImportError( - "Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature." + "Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature." # noqa: E501 ) from .fts import search_index @@ -523,8 +527,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): # check if the index exist if not Path(index_path).exists(): raise FileNotFoundError( - "Fts index does not exist." - f"Please first call table.create_fts_index(['']) to create the fts index." + "Fts index does not exist. " + "Please first call table.create_fts_index(['']) to " + "create the fts index." ) # open the index index = tantivy.Index.open(index_path) @@ -543,19 +548,21 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): if self._where is not None: try: - # TODO would be great to have Substrait generate pyarrow compute expressions - # or conversely have pyarrow support SQL expressions using Substrait + # TODO would be great to have Substrait generate pyarrow compute + # expressions or conversely have pyarrow support SQL expressions + # using Substrait import duckdb output_tbl = ( - duckdb.sql(f"SELECT * FROM output_tbl") + duckdb.sql("SELECT * FROM output_tbl") .filter(self._where) .to_arrow_table() ) except ImportError: - import lance import tempfile + import lance + # TODO Use "memory://" instead once that's supported with tempfile.TemporaryDirectory() as tmp: ds = lance.write_dataset(output_tbl, tmp) diff --git a/python/lancedb/remote/client.py b/python/lancedb/remote/client.py index 6f9bf292..9d0ea9d0 100644 --- a/python/lancedb/remote/client.py +++ b/python/lancedb/remote/client.py @@ -13,12 +13,12 @@ import functools -from typing import Any, Callable, Dict, Iterable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from urllib.parse import urljoin -import requests import attrs import pyarrow as pa +import requests from pydantic import BaseModel from lancedb.common import Credential diff --git a/python/lancedb/remote/db.py b/python/lancedb/remote/db.py index 337406db..3a88152f 100644 --- a/python/lancedb/remote/db.py +++ b/python/lancedb/remote/db.py @@ -11,7 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import inspect import logging import uuid @@ -102,8 +101,10 @@ class RemoteDBConnection(DBConnection): except LanceDBClientError as err: if str(err).startswith("Not found"): logging.error( - f"Table {name} does not exist. " - f"Please first call db.create_table({name}, data)" + "Table %s does not exist. Please first call " + "db.create_table(%s, data).", + name, + name, ) return RemoteTable(self, name) @@ -160,7 +161,8 @@ class RemoteDBConnection(DBConnection): Can create with list of tuples or dictionaries: >>> import lancedb - >>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP + >>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP + ... region="...") # doctest: +SKIP >>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7}, ... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}] >>> db.create_table("my_table", data) # doctest: +SKIP diff --git a/python/lancedb/remote/table.py b/python/lancedb/remote/table.py index 63572ebb..4878bc1d 100644 --- a/python/lancedb/remote/table.py +++ b/python/lancedb/remote/table.py @@ -11,7 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import uuid from functools import cached_property from typing import Dict, Optional, Union @@ -89,7 +88,8 @@ class RemoteTable(Table): >>> import lancedb >>> import uuid >>> from lancedb.schema import vector - >>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP + >>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP + ... region="...") # doctest: +SKIP >>> table_name = uuid.uuid4().hex >>> schema = pa.schema( ... [ @@ -125,7 +125,8 @@ class RemoteTable(Table): on_bad_vectors: str = "error", fill_value: float = 0.0, ) -> int: - """Add more data to the [Table](Table). It has the same API signature as the OSS version. + """Add more data to the [Table](Table). It has the same API signature as + the OSS version. Parameters ---------- @@ -176,7 +177,8 @@ class RemoteTable(Table): Examples -------- >>> import lancedb - >>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP + >>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP + ... region="...") # doctest: +SKIP >>> data = [ ... {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]}, ... {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]}, @@ -265,7 +267,8 @@ class RemoteTable(Table): ... {"x": 2, "vector": [3, 4]}, ... {"x": 3, "vector": [5, 6]} ... ] - >>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP + >>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP + ... region="...") # doctest: +SKIP >>> table = db.create_table("my_table", data) # doctest: +SKIP >>> table.search([10,10]).to_pandas() # doctest: +SKIP x vector _distance # doctest: +SKIP @@ -323,7 +326,8 @@ class RemoteTable(Table): ... {"x": 2, "vector": [3, 4]}, ... {"x": 3, "vector": [5, 6]} ... ] - >>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP + >>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP + ... region="...") # doctest: +SKIP >>> table = db.create_table("my_table", data) # doctest: +SKIP >>> table.to_pandas() # doctest: +SKIP x vector # doctest: +SKIP diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 39d1ef74..f42435f3 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -14,7 +14,6 @@ from __future__ import annotations import inspect -import os from abc import ABC, abstractmethod from functools import cached_property from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union @@ -33,18 +32,21 @@ from .pydantic import LanceModel, model_to_dict from .query import LanceQueryBuilder, Query from .util import ( fs_from_uri, + join_uri, safe_import_pandas, safe_import_polars, value_to_sql, - join_uri, ) from .utils.events import register_event if TYPE_CHECKING: from datetime import timedelta + import PIL from lance.dataset import CleanupStats, ReaderLike + from .db import LanceDBConnection + pd = safe_import_pandas() pl = safe_import_polars() @@ -526,9 +528,7 @@ class LanceTable(Table): A table in a LanceDB database. """ - def __init__( - self, connection: "lancedb.db.LanceDBConnection", name: str, version: int = None - ): + def __init__(self, connection: "LanceDBConnection", name: str, version: int = None): self._conn = connection self.name = name self._version = version @@ -783,9 +783,7 @@ class LanceTable(Table): index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound if index_exists: if not replace: - raise ValueError( - f"Index already exists. Use replace=True to overwrite." - ) + raise ValueError("Index already exists. Use replace=True to overwrite.") fs.delete_dir(path) index = create_index(self._get_fts_index_path(), field_names) diff --git a/python/lancedb/util.py b/python/lancedb/util.py index 6122e4e0..5b4f3c65 100644 --- a/python/lancedb/util.py +++ b/python/lancedb/util.py @@ -12,9 +12,9 @@ # limitations under the License. import os +import pathlib from datetime import date, datetime from functools import singledispatch -import pathlib from typing import Tuple, Union from urllib.parse import urlparse diff --git a/python/lancedb/utils/config.py b/python/lancedb/utils/config.py index c4d94df1..a59fddfd 100644 --- a/python/lancedb/utils/config.py +++ b/python/lancedb/utils/config.py @@ -44,8 +44,9 @@ def get_user_config_dir(sub_dir="lancedb"): # GCP and AWS lambda fix, only /tmp is writeable if not is_dir_writeable(path.parent): LOGGER.warning( - f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD." - "Alternatively you can define a LANCEDB_CONFIG_DIR environment variable for this path." + f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting " + "to '/tmp' or CWD. Alternatively you can define a LANCEDB_CONFIG_DIR " + "environment variable for this path." ) path = ( Path("/tmp") / sub_dir @@ -68,7 +69,8 @@ class Config(dict): Manages lancedb config stored in a YAML file. Args: - file (str | Path): Path to the lancedb config YAML file. Default is USER_CONFIG_DIR / 'config.yaml'. + file (str | Path): Path to the lancedb config YAML file. Default is + USER_CONFIG_DIR / 'config.yaml'. """ def __init__(self, file=CONFIG_FILE): @@ -90,8 +92,8 @@ class Config(dict): ) if not (correct_keys and correct_types): LOGGER.warning( - "WARNING ⚠️ LanceDB settings reset to default values. This may be due to a possible problem " - "with your settings or a recent package update. " + "WARNING ⚠️ LanceDB settings reset to default values. This may be due " + "to a possible problem with your settings or a recent package update. " f"\nView settings & usage with 'lancedb settings' or at '{self.file}'" ) self.reset() diff --git a/python/lancedb/utils/events.py b/python/lancedb/utils/events.py index 5f63dfa1..cf1f6e35 100644 --- a/python/lancedb/utils/events.py +++ b/python/lancedb/utils/events.py @@ -35,10 +35,11 @@ from .general import ( class _Events: """ - A class for collecting anonymous event analytics. Event analytics are enabled when ``diagnostics=True`` in config and - disabled when ``diagnostics=False``. + A class for collecting anonymous event analytics. Event analytics are enabled when + ``diagnostics=True`` in config and disabled when ``diagnostics=False``. - You can enable or disable diagnostics by running ``lancedb diagnostics --enabled`` or ``lancedb diagnostics --disabled``. + You can enable or disable diagnostics by running ``lancedb diagnostics --enabled`` + or ``lancedb diagnostics --disabled``. Attributes ---------- @@ -61,7 +62,8 @@ class _Events: def __init__(self): """ - Initializes the Events object with default values for events, rate_limit, and metadata. + Initializes the Events object with default values for events, rate_limit, + and metadata. """ self.events = [] # events list self.throttled_event_names = ["search_table"] @@ -83,7 +85,8 @@ class _Events: "version": importlib.metadata.version("lancedb"), "platforms": PLATFORMS, "session_id": round(random.random() * 1e15), - # 'engagement_time_msec': 1000 # TODO: In future we might be interested in this metric + # TODO: In future we might be interested in this metric + # 'engagement_time_msec': 1000 } TESTS_RUNNING = is_pytest_running() or is_github_actions_ci() @@ -100,7 +103,8 @@ class _Events: def __call__(self, event_name, params={}): """ - Attempts to add a new event to the events list and send events if the rate limit is reached. + Attempts to add a new event to the events list and send events if the rate + limit is reached. Args ---- @@ -109,7 +113,8 @@ class _Events: params : dict, optional A dictionary of additional parameters to be logged with the event. """ - ### NOTE: We might need a way to tag a session with a label to check usage from a source. Setting label should be exposed to the user. + ### NOTE: We might need a way to tag a session with a label to check usage + ### from a source. Setting label should be exposed to the user. if not self.enabled: return if ( @@ -141,8 +146,8 @@ class _Events: "batch": self.events, } # POST equivalent to requests.post(self.url, json=data). - # threaded request is used to avoid blocking, retries are disabled, and verbose is disabled - # to avoid any possible disruption in the console. + # threaded request is used to avoid blocking, retries are disabled, and + # verbose is disabled to avoid any possible disruption in the console. threaded_request( method="post", url=self.url, diff --git a/python/lancedb/utils/general.py b/python/lancedb/utils/general.py index 14141f8c..dcd3f3e4 100644 --- a/python/lancedb/utils/general.py +++ b/python/lancedb/utils/general.py @@ -82,7 +82,8 @@ def is_pip_package(filepath: str = __name__) -> bool: # Get the spec for the module spec = importlib.util.find_spec(filepath) - # Return whether the spec is not None and the origin is not None (indicating it is a package) + # Return whether the spec is not None and the origin is not None (indicating + # it is a package) return spec is not None and spec.origin is not None @@ -108,7 +109,8 @@ def is_github_actions_ci() -> bool: Returns ------- bool - True if the current environment is a GitHub Actions CI Python runner, False otherwise. + True if the current environment is a GitHub Actions CI Python runner, + False otherwise. """ return ( @@ -145,7 +147,7 @@ def is_online() -> bool: for host in "1.1.1.1", "8.8.8.8", "223.5.5.5": # Cloudflare, Google, AliDNS: try: test_connection = socket.create_connection(address=(host, 53), timeout=2) - except (socket.timeout, socket.gaierror, OSError): + except (socket.timeout, socket.gaierror, OSError): # noqa: PERF203 continue else: # If the connection was successful, close it to avoid a ResourceWarning @@ -227,7 +229,8 @@ def is_docker() -> bool: def get_git_dir(): - """Determine whether the current file is part of a git repository and if so, returns the repository root directory. + """Determine whether the current file is part of a git repository and if so, + returns the repository root directory. If the current file is not part of a git repository, returns None. Returns @@ -336,7 +339,7 @@ def yaml_print(yaml_file: Union[str, Path, dict]) -> None: yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file ) dump = yaml.dump(yaml_dict, sort_keys=False, allow_unicode=True) - LOGGER.info(f"Printing '{yaml_file}'\n\n{dump}") + LOGGER.info("Printing '%s'\n\n%s", yaml_file, dump) PLATFORMS = [platform.system()] @@ -375,7 +378,7 @@ class TryExcept(contextlib.ContextDecorator): def __exit__(self, exc_type, value, traceback): if self.verbose and value: - LOGGER.info(f"{self.msg}{': ' if self.msg else ''}{value}") + LOGGER.info("%s%s%s", self.msg, ": " if self.msg else "", value) return True @@ -383,7 +386,8 @@ def threaded_request( method, url, retry=3, timeout=30, thread=True, code=-1, verbose=True, **kwargs ): """ - Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout. + Makes an HTTP request using the 'requests' library, with exponential backoff + retries up to a specified timeout. Parameters ---------- @@ -394,7 +398,8 @@ def threaded_request( retry : int, optional Number of retries to attempt before giving up, by default 3. timeout : int, optional - Timeout in seconds after which the function will give up retrying, by default 30. + Timeout in seconds after which the function will give up retrying, + by default 30. thread : bool, optional Whether to execute the request in a separate daemon thread, by default True. code : int, optional @@ -405,13 +410,17 @@ def threaded_request( Returns ------- requests.Response - The HTTP response object. If the request is executed in a separate thread, returns the thread itself. + The HTTP response object. If the request is executed in a separate thread, + returns the thread itself. """ - retry_codes = () # retry only these codes TODO: add codes if needed in future (500, 408) + # retry only these codes TODO: add codes if needed in future (500, 408) + retry_codes = () @TryExcept(verbose=verbose) def func(method, url, **kwargs): - """Make HTTP requests with retries and timeouts, with optional progress tracking.""" + """Make HTTP requests with retries and timeouts, with optional progress + tracking. + """ response = None t0 = time.time() for i in range(retry + 1): @@ -428,9 +437,9 @@ def threaded_request( if response.status_code in retry_codes: m += f" Retrying {retry}x for {timeout}s." if retry else "" elif response.status_code == 429: # rate limit - m = f"Rate limit reached" + m = "Rate limit reached" if verbose: - LOGGER.warning(f"{response.status_code} #{code}") + LOGGER.warning("%s #%s", response.status_code, m) if response.status_code not in retry_codes: return response time.sleep(2**i) # exponential standoff diff --git a/python/lancedb/utils/sentry_log.py b/python/lancedb/utils/sentry_log.py index 25f13d68..0e92edb4 100644 --- a/python/lancedb/utils/sentry_log.py +++ b/python/lancedb/utils/sentry_log.py @@ -33,10 +33,12 @@ from .general import ( @TryExcept(verbose=False) def set_sentry(): """ - Initialize the Sentry SDK for error tracking and reporting. Only used if sentry_sdk package is installed and - sync=True in settings. Run 'lancedb settings' to see and update settings YAML file. + Initialize the Sentry SDK for error tracking and reporting. Only used if + sentry_sdk package is installed and sync=True in settings. Run 'lancedb settings' + to see and update settings YAML file. - Conditions required to send errors (ALL conditions must be met or no errors will be reported): + Conditions required to send errors (ALL conditions must be met or no errors will + be reported): - sentry_sdk package is installed - sync=True in settings - pytest is not running @@ -44,22 +46,26 @@ def set_sentry(): - running in a non-git directory - online environment - The function also configures Sentry SDK to ignore KeyboardInterrupt and FileNotFoundError - exceptions for now. + The function also configures Sentry SDK to ignore KeyboardInterrupt and + FileNotFoundError exceptions for now. - Additionally, the function sets custom tags and user information for Sentry events. + Additionally, the function sets custom tags and user information for Sentry + events. """ def before_send(event, hint): """ - Modify the event before sending it to Sentry based on specific exception types and messages. + Modify the event before sending it to Sentry based on specific exception + types and messages. Args: event (dict): The event dictionary containing information about the error. - hint (dict): A dictionary containing additional information about the error. + hint (dict): A dictionary containing additional information about + the error. Returns: - dict: The modified event or None if the event should not be sent to Sentry. + dict: The modified event or None if the event should not be sent + to Sentry. """ if "exc_info" in hint: exc_type, exc_value, tb = hint["exc_info"] diff --git a/python/tests/test_cli.py b/python/tests/test_cli.py index 8181ce1f..f43e0fb3 100644 --- a/python/tests/test_cli.py +++ b/python/tests/test_cli.py @@ -15,11 +15,11 @@ def test_diagnostics(): runner = CliRunner() result = runner.invoke(cli, ["diagnostics", "--disabled"]) assert result.exit_code == 0 # Main check - assert CONFIG["diagnostics"] == False + assert not CONFIG["diagnostics"] result = runner.invoke(cli, ["diagnostics", "--enabled"]) assert result.exit_code == 0 # Main check - assert CONFIG["diagnostics"] == True + assert CONFIG["diagnostics"] def test_config(): @@ -28,8 +28,5 @@ def test_config(): assert result.exit_code == 0 # Main check cfg = CONFIG.copy() cfg.pop("uuid") - for ( - item, - _, - ) in cfg.items(): # check for keys only as formatting is subject to change + for item in cfg: # check for keys only as formatting is subject to change assert item in result.output diff --git a/python/tests/test_db.py b/python/tests/test_db.py index 700b34d3..7b716f22 100644 --- a/python/tests/test_db.py +++ b/python/tests/test_db.py @@ -130,7 +130,8 @@ def test_ingest_iterator(tmp_path): PydanticSchema(vector=[3.1, 4.1], item="foo", price=10.0), PydanticSchema(vector=[5.9, 26.5], item="bar", price=20.0), ], - # TODO: test pydict separately. it is unique column number and names contraint + # TODO: test pydict separately. it is unique column number and + # name constraints ] def run_tests(schema): diff --git a/python/tests/test_embeddings.py b/python/tests/test_embeddings.py index 03af14eb..a4c84fb0 100644 --- a/python/tests/test_embeddings.py +++ b/python/tests/test_embeddings.py @@ -18,7 +18,7 @@ import pyarrow as pa import pytest import lancedb -from lancedb.conftest import MockRateLimitedEmbeddingFunction, MockTextEmbeddingFunction +from lancedb.conftest import MockTextEmbeddingFunction from lancedb.embeddings import ( EmbeddingFunctionConfig, EmbeddingFunctionRegistry, diff --git a/python/tests/test_pydantic.py b/python/tests/test_pydantic.py index c6376dce..b37373ee 100644 --- a/python/tests/test_pydantic.py +++ b/python/tests/test_pydantic.py @@ -13,7 +13,6 @@ import json -import pytz import sys from datetime import date, datetime from typing import List, Optional, Tuple @@ -49,18 +48,19 @@ def test_pydantic_to_arrow(): dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"}) # d: dict - m = TestModel( - id=1, - s="hello", - vec=[1.0, 2.0, 3.0], - li=[2, 3, 4], - lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]], - litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)], - st=StructModel(a="a", b=1.0), - dt=date.today(), - dtt=datetime.now(), - dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")), - ) + # TODO: test we can actually convert the model into data. + # m = TestModel( + # id=1, + # s="hello", + # vec=[1.0, 2.0, 3.0], + # li=[2, 3, 4], + # lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]], + # litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)], + # st=StructModel(a="a", b=1.0), + # dt=date.today(), + # dtt=datetime.now(), + # dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")), + # ) schema = pydantic_to_schema(TestModel) @@ -133,18 +133,19 @@ def test_pydantic_to_arrow_py38(): dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"}) # d: dict - m = TestModel( - id=1, - s="hello", - vec=[1.0, 2.0, 3.0], - li=[2, 3, 4], - lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]], - litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)], - st=StructModel(a="a", b=1.0), - dt=date.today(), - dtt=datetime.now(), - dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")), - ) + # TODO: test we can actually convert the model to Arrow data. + # m = TestModel( + # id=1, + # s="hello", + # vec=[1.0, 2.0, 3.0], + # li=[2, 3, 4], + # lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]], + # litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)], + # st=StructModel(a="a", b=1.0), + # dt=date.today(), + # dtt=datetime.now(), + # dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")), + # ) schema = pydantic_to_schema(TestModel) diff --git a/python/tests/test_telemetry.py b/python/tests/test_telemetry.py index 256d25c9..b8923408 100644 --- a/python/tests/test_telemetry.py +++ b/python/tests/test_telemetry.py @@ -46,7 +46,8 @@ def test_event_reporting(monkeypatch, request_log_path, tmp_path) -> None: with open(request_log_path, "r") as f: json_data = json.load(f) - # TODO: don't hardcode these here. Instead create a module level json scehma in lancedb.utils.events for better evolvability + # TODO: don't hardcode these here. Instead create a module level json scehma in + # lancedb.utils.events for better evolvability batch_keys = ["api_key", "distinct_id", "batch"] event_keys = ["event", "properties", "timestamp", "distinct_id"] property_keys = ["cli", "install", "platforms", "version", "session_id"]