From 5f6d13e958637466bb8efe434b1a2ff7cb947577 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 19 Jan 2024 13:09:14 -0800 Subject: [PATCH] ci: lint and enforce linting (#829) @eddyxu added instructions for linting here: https://github.com/lancedb/lancedb/blob/7af213801a091cd5afb0f0814e184fc0b852de47/python/README.md?plain=1#L45-L50 However, we had a lot of failures and weren't checking this in CI. This PR fixes all lints and adds a check to CI to keep us in compliance with the lints. --- .github/workflows/python.yml | 4 +- python/lancedb/conftest.py | 1 - python/lancedb/context.py | 3 +- python/lancedb/embeddings/__init__.py | 2 +- python/lancedb/embeddings/gemini_text.py | 40 +++++++++------ python/lancedb/embeddings/instructor.py | 51 +++++++++++-------- python/lancedb/embeddings/open_clip.py | 6 ++- python/lancedb/embeddings/registry.py | 7 ++- .../embeddings/sentence_transformers.py | 1 - python/lancedb/embeddings/utils.py | 19 ++++--- python/lancedb/fts.py | 4 +- python/lancedb/pydantic.py | 31 ++++++++--- python/lancedb/query.py | 35 ++++++++----- python/lancedb/remote/client.py | 4 +- python/lancedb/remote/db.py | 10 ++-- python/lancedb/remote/table.py | 16 +++--- python/lancedb/table.py | 14 +++-- python/lancedb/util.py | 2 +- python/tests/test_db.py | 3 +- python/tests/test_embeddings.py | 2 +- python/tests/test_pydantic.py | 51 ++++++++++--------- 21 files changed, 183 insertions(+), 123 deletions(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 835cd41e..dc1c2ace 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -38,8 +38,10 @@ jobs: pip install -e .[tests] pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985 pip install pytest pytest-mock ruff - - name: Lint + - name: Format check run: ruff format --check . + - name: Lint + run: ruff . - name: Run tests run: pytest -m "not slow" -x -v --durations=30 tests - name: doctest diff --git a/python/lancedb/conftest.py b/python/lancedb/conftest.py index df4907a7..273afbf7 100644 --- a/python/lancedb/conftest.py +++ b/python/lancedb/conftest.py @@ -1,6 +1,5 @@ import os import time -from typing import Any import numpy as np import pytest diff --git a/python/lancedb/context.py b/python/lancedb/context.py index 02051614..bd7b04c8 100644 --- a/python/lancedb/context.py +++ b/python/lancedb/context.py @@ -59,7 +59,8 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer: 8 dog I love 1 9 I love sandwiches 2 10 love sandwiches 2 - >>> contextualize(data).window(7).stride(1).min_window_size(7).text_col('token').to_pandas() + >>> (contextualize(data).window(7).stride(1).min_window_size(7) + ... .text_col('token').to_pandas()) token document_id 0 The quick brown fox jumped over the 1 1 quick brown fox jumped over the lazy 1 diff --git a/python/lancedb/embeddings/__init__.py b/python/lancedb/embeddings/__init__.py index cea0f381..ff60893e 100644 --- a/python/lancedb/embeddings/__init__.py +++ b/python/lancedb/embeddings/__init__.py @@ -14,10 +14,10 @@ # ruff: noqa: F401 from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction from .cohere import CohereEmbeddingFunction +from .gemini_text import GeminiText from .instructor import InstructorEmbeddingFunction from .open_clip import OpenClipEmbeddings from .openai import OpenAIEmbeddings from .registry import EmbeddingFunctionRegistry, get_registry from .sentence_transformers import SentenceTransformerEmbeddings -from .gemini_text import GeminiText from .utils import with_embeddings diff --git a/python/lancedb/embeddings/gemini_text.py b/python/lancedb/embeddings/gemini_text.py index d71ec7f5..e0f103c2 100644 --- a/python/lancedb/embeddings/gemini_text.py +++ b/python/lancedb/embeddings/gemini_text.py @@ -13,40 +13,50 @@ import os from functools import cached_property -from typing import List, Union, Any +from typing import List, Union import numpy as np +from lancedb.pydantic import PYDANTIC_VERSION + from .base import TextEmbeddingFunction from .registry import register -from .utils import api_key_not_found_help, TEXT -from lancedb.pydantic import PYDANTIC_VERSION +from .utils import TEXT, api_key_not_found_help @register("gemini-text") class GeminiText(TextEmbeddingFunction): """ - An embedding function that uses the Google's Gemini API. Requires GOOGLE_API_KEY to be set. + An embedding function that uses the Google's Gemini API. Requires GOOGLE_API_KEY to + be set. https://ai.google.dev/docs/embeddings_guide Supports various tasks types: - | Task Type | Description | - |-------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------| - | "`retrieval_query`" | Specifies the given text is a query in a search/retrieval setting. | - | "`retrieval_document`" | Specifies the given text is a document in a search/retrieval setting. Using this task type requires a title but is automatically proided by Embeddings API | - | "`semantic_similarity`" | Specifies the given text will be used for Semantic Textual Similarity (STS). | - | "`classification`" | Specifies that the embeddings will be used for classification. | - | "`clusering`" | Specifies that the embeddings will be used for clustering. | + | Task Type | Description | + |-------------------------|--------------------------------------------------------| + | "`retrieval_query`" | Specifies the given text is a query in a | + | | search/retrieval setting. | + | "`retrieval_document`" | Specifies the given text is a document in a | + | | search/retrieval setting. Using this task type | + | | requires a title but is automatically provided by | + | | Embeddings API | + | "`semantic_similarity`" | Specifies the given text will be used for Semantic | + | | Textual Similarity (STS). | + | "`classification`" | Specifies that the embeddings will be used for | + | | classification. | + | "`clustering`" | Specifies that the embeddings will be used for | + | | clustering. | - - Note: The supported task types might change in the Gemini API, but as long as a supported task type and its argument set is provided, - those will be delegated to the API calls. + Note: The supported task types might change in the Gemini API, but as long as a + supported task type and its argument set is provided, those will be delegated + to the API calls. Parameters ---------- name: str, default "models/embedding-001" - The name of the model to use. See the Gemini documentation for a list of available models. + The name of the model to use. See the Gemini documentation for a list of + available models. query_task_type: str, default "retrieval_query" Sets the task type for the queries. diff --git a/python/lancedb/embeddings/instructor.py b/python/lancedb/embeddings/instructor.py index 53be8ccb..c2058b27 100644 --- a/python/lancedb/embeddings/instructor.py +++ b/python/lancedb/embeddings/instructor.py @@ -22,22 +22,29 @@ from .utils import TEXT, weak_lru @register("instructor") class InstructorEmbeddingFunction(TextEmbeddingFunction): """ - An embedding function that uses the InstructorEmbedding library. Instructor models support multi-task learning, and can be used for a - variety of tasks, including text classification, sentence similarity, and document retrieval. - If you want to calculate customized embeddings for specific sentences, you may follow the unified template to write instructions: + An embedding function that uses the InstructorEmbedding library. Instructor models + support multi-task learning, and can be used for a variety of tasks, including + text classification, sentence similarity, and document retrieval. If you want to + calculate customized embeddings for specific sentences, you may follow the unified + template to write instructions: "Represent the `domain` `text_type` for `task_objective`": - * domain is optional, and it specifies the domain of the text, e.g., science, finance, medicine, etc. - * text_type is required, and it specifies the encoding unit, e.g., sentence, document, paragraph, etc. - * task_objective is optional, and it specifies the objective of embedding, e.g., retrieve a document, classify the sentence, etc. + * domain is optional, and it specifies the domain of the text, e.g., science, + finance, medicine, etc. + * text_type is required, and it specifies the encoding unit, e.g., sentence, + document, paragraph, etc. + * task_objective is optional, and it specifies the objective of embedding, + e.g., retrieve a document, classify the sentence, etc. - For example, if you want to calculate embeddings for a document, you may write the instruction as follows: - "Represent the document for retreival" + For example, if you want to calculate embeddings for a document, you may write the + instruction as follows: + "Represent the document for retrieval" Parameters ---------- name: str - The name of the model to use. Available models are listed at https://github.com/xlang-ai/instructor-embedding#model-list; + The name of the model to use. Available models are listed at + https://github.com/xlang-ai/instructor-embedding#model-list; The default model is hkunlp/instructor-base batch_size: int, default 32 The batch size to use when generating embeddings @@ -49,21 +56,24 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction): Whether to normalize the embeddings quantize: bool, default False Whether to quantize the model - source_instruction: str, default "represent the docuement for retreival" + source_instruction: str, default "represent the document for retrieval" The instruction for the source column - query_instruction: str, default "represent the document for retreiving the most similar documents" + query_instruction: str, default "represent the document for retrieving the most + similar documents" The instruction for the query Examples -------- + import lancedb from lancedb.pydantic import LanceModel, Vector from lancedb.embeddings import get_registry, InstuctorEmbeddingFunction instructor = get_registry().get("instructor").create( - source_instruction="represent the docuement for retreival", - query_instruction="represent the document for retreiving the most similar documents" - ) + source_instruction="represent the document for retrieval", + query_instruction="represent the document for retrieving the most " + "similar documents" + ) class Schema(LanceModel): vector: Vector(instructor.ndims()) = instructor.VectorField() @@ -72,9 +82,12 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction): db = lancedb.connect("~/.lancedb") tbl = db.create_table("test", schema=Schema, mode="overwrite") - texts = [{"text": "Capitalism has been dominant in the Western world since the end of feudalism, but most feel[who?] that..."}, - {"text": "The disparate impact theory is especially controversial under the Fair Housing Act because the Act..."}, - {"text": "Disparate impact in United States labor law refers to practices in employment, housing, and other areas that.."}] + texts = [{"text": "Capitalism has been dominant in the Western world since the " + "end of feudalism, but most feel[who?] that..."}, + {"text": "The disparate impact theory is especially controversial under " + "the Fair Housing Act because the Act..."}, + {"text": "Disparate impact in United States labor law refers to practices " + "in employment, housing, and other areas that.."}] tbl.add(texts) @@ -103,9 +116,7 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction): def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]: texts = self.sanitize_input(texts) - texts_formatted = [] - for text in texts: - texts_formatted.append([self.source_instruction, text]) + texts_formatted = [[self.source_instruction, text] for text in texts] return self.generate_embeddings(texts_formatted) def generate_embeddings(self, texts: List) -> List: diff --git a/python/lancedb/embeddings/open_clip.py b/python/lancedb/embeddings/open_clip.py index 1e9aeb6e..6392b0ef 100644 --- a/python/lancedb/embeddings/open_clip.py +++ b/python/lancedb/embeddings/open_clip.py @@ -14,7 +14,7 @@ import concurrent.futures import io import os import urllib.parse as urlparse -from typing import List, Union +from typing import TYPE_CHECKING, List, Union import numpy as np import pyarrow as pa @@ -25,6 +25,10 @@ from .base import EmbeddingFunction from .registry import register from .utils import IMAGES, url_retrieve +if TYPE_CHECKING: + import PIL + import torch + @register("open-clip") class OpenClipEmbeddings(EmbeddingFunction): diff --git a/python/lancedb/embeddings/registry.py b/python/lancedb/embeddings/registry.py index af7600dc..d5ab1f35 100644 --- a/python/lancedb/embeddings/registry.py +++ b/python/lancedb/embeddings/registry.py @@ -23,7 +23,9 @@ class EmbeddingFunctionRegistry: You can implement your own embedding function by subclassing EmbeddingFunction or TextEmbeddingFunction and registering it with the registry. - NOTE: Here TEXT is a type alias for Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray] + NOTE: Here TEXT is a type alias for Union[str, List[str], pa.Array, + pa.ChunkedArray, np.ndarray] + Examples -------- >>> registry = EmbeddingFunctionRegistry.get_instance() @@ -164,7 +166,8 @@ __REGISTRY__ = EmbeddingFunctionRegistry() # @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8 -register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name) +def register(name): + return __REGISTRY__.get_instance().register(name) def get_registry(): diff --git a/python/lancedb/embeddings/sentence_transformers.py b/python/lancedb/embeddings/sentence_transformers.py index 045fe2eb..d958e054 100644 --- a/python/lancedb/embeddings/sentence_transformers.py +++ b/python/lancedb/embeddings/sentence_transformers.py @@ -13,7 +13,6 @@ from typing import List, Union import numpy as np -from cachetools import cached from .base import TextEmbeddingFunction from .registry import register diff --git a/python/lancedb/embeddings/utils.py b/python/lancedb/embeddings/utils.py index 5aea89d2..325145f4 100644 --- a/python/lancedb/embeddings/utils.py +++ b/python/lancedb/embeddings/utils.py @@ -112,7 +112,8 @@ class FunctionWrapper: v = int(sys.version_info.minor) if v >= 11: print( - "WARNING: rate limit only support up to 3.10, proceeding without rate limiter" + "WARNING: rate limit only support up to 3.10, proceeding " + "without rate limiter" ) else: import ratelimiter @@ -168,8 +169,8 @@ class FunctionWrapper: def weak_lru(maxsize=128): """ - LRU cache that keeps weak references to the objects it caches. Only caches the latest instance of the objects to make sure memory usage - is bounded. + LRU cache that keeps weak references to the objects it caches. Only caches the + latest instance of the objects to make sure memory usage is bounded. Parameters ---------- @@ -234,15 +235,17 @@ def retry_with_exponential_backoff( num_retries = 0 delay = initial_delay - # Loop until a successful response or max_retries is hit or an exception is raised + # Loop until a successful response or max_retries is hit or an exception + # is raised while True: try: return func(*args, **kwargs) - # Currently retrying on all exceptions as there is no way to know the format of the error msgs used by different APIs - # We'll log the error and say that it is assumed that if this portion errors out, it's due to rate limit but the user - # should check the error message to be sure - except Exception as e: + # Currently retrying on all exceptions as there is no way to know the + # format of the error msgs used by different APIs. We'll log the error + # and say that it is assumed that if this portion errors out, it's due + # to rate limit but the user should check the error message to be sure. + except Exception as e: # noqa: PERF203 num_retries += 1 if num_retries > max_retries: diff --git a/python/lancedb/fts.py b/python/lancedb/fts.py index f9667fcc..750e3076 100644 --- a/python/lancedb/fts.py +++ b/python/lancedb/fts.py @@ -13,7 +13,7 @@ """Full text search index using tantivy-py""" import os -from typing import List, Optional, Tuple +from typing import List, Tuple import pyarrow as pa @@ -21,7 +21,7 @@ try: import tantivy except ImportError: raise ImportError( - "Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature." + "Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature." # noqa: E501 ) from .table import LanceTable diff --git a/python/lancedb/pydantic.py b/python/lancedb/pydantic.py index 859eeaa8..2a550032 100644 --- a/python/lancedb/pydantic.py +++ b/python/lancedb/pydantic.py @@ -20,15 +20,22 @@ import sys import types from abc import ABC, abstractmethod from datetime import date, datetime -from typing import Any, Callable, Dict, Generator, List, Type, Union, _GenericAlias +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generator, + List, + Type, + Union, + _GenericAlias, +) import numpy as np import pyarrow as pa import pydantic import semver -from pydantic.fields import FieldInfo - -from .embeddings import EmbeddingFunctionRegistry PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__) try: @@ -37,6 +44,11 @@ except ImportError: if PYDANTIC_VERSION >= (2,): raise +if TYPE_CHECKING: + from pydantic.fields import FieldInfo + + from .embeddings import EmbeddingFunctionConfig + class FixedSizeListMixin(ABC): @staticmethod @@ -190,7 +202,7 @@ else: ] -def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType: +def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType: """Convert a Pydantic FieldInfo to Arrow DataType""" if isinstance(field.annotation, _GenericAlias) or ( @@ -221,7 +233,7 @@ def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType: return _py_type_to_arrow_type(field.annotation, field) -def is_nullable(field: pydantic.fields.FieldInfo) -> bool: +def is_nullable(field: FieldInfo) -> bool: """Check if a Pydantic FieldInfo is nullable.""" if isinstance(field.annotation, _GenericAlias): origin = field.annotation.__origin__ @@ -237,7 +249,7 @@ def is_nullable(field: pydantic.fields.FieldInfo) -> bool: return False -def _pydantic_to_field(name: str, field: pydantic.fields.FieldInfo) -> pa.Field: +def _pydantic_to_field(name: str, field: FieldInfo) -> pa.Field: """Convert a Pydantic field to a PyArrow Field.""" dt = _pydantic_to_arrow_type(field) return pa.field(name, dt, is_nullable(field)) @@ -309,6 +321,9 @@ class LanceModel(pydantic.BaseModel): schema = pydantic_to_schema(cls) functions = cls.parse_embedding_functions() if len(functions) > 0: + # Prevent circular import + from .embeddings import EmbeddingFunctionRegistry + metadata = EmbeddingFunctionRegistry.get_instance().get_table_metadata( functions ) @@ -359,7 +374,7 @@ class LanceModel(pydantic.BaseModel): return configs -def get_extras(field_info: pydantic.fields.FieldInfo, key: str) -> Any: +def get_extras(field_info: FieldInfo, key: str) -> Any: """ Get the extra metadata from a Pydantic FieldInfo. """ diff --git a/python/lancedb/query.py b/python/lancedb/query.py index 51eab73a..23e76c0d 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -27,7 +27,11 @@ from .common import VECTOR_COLUMN_NAME from .util import safe_import_pandas if TYPE_CHECKING: + import PIL + import polars as pl + from .pydantic import LanceModel + from .table import Table pd = safe_import_pandas() @@ -60,7 +64,7 @@ class Query(pydantic.BaseModel): - See discussion in [Querying an ANN Index][querying-an-ann-index] for tuning advice. refine_factor : Optional[int] - Refine the results by reading extra elements and re-ranking them in memory - optional + Refine the results by reading extra elements and re-ranking them in memory. - A higher number makes search more accurate but also slower. @@ -104,7 +108,7 @@ class LanceQueryBuilder(ABC): @classmethod def create( cls, - table: "lancedb.table.Table", + table: "Table", query: Optional[Union[np.ndarray, str, "PIL.Image.Image"]], query_type: str, vector_column_name: str, @@ -163,7 +167,7 @@ class LanceQueryBuilder(ABC): f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}" ) - def __init__(self, table: "lancedb.table.Table"): + def __init__(self, table: "Table"): self._table = table self._limit = 10 self._columns = None @@ -205,7 +209,6 @@ class LanceQueryBuilder(ABC): if flatten is True: while True: tbl = tbl.flatten() - has_struct = False # loop through all columns to check if there is any struct column if any(pa.types.is_struct(col.type) for col in tbl.schema): continue @@ -214,7 +217,8 @@ class LanceQueryBuilder(ABC): elif isinstance(flatten, int): if flatten <= 0: raise ValueError( - "Please specify a positive integer for flatten or the boolean value `True`" + "Please specify a positive integer for flatten or the boolean " + "value `True`" ) while flatten > 0: tbl = tbl.flatten() @@ -361,7 +365,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): def __init__( self, - table: "lancedb.table.Table", + table: "Table", query: Union[np.ndarray, list, "PIL.Image.Image"], vector_column: str = VECTOR_COLUMN_NAME, ): @@ -486,7 +490,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): class LanceFtsQueryBuilder(LanceQueryBuilder): """A builder for full text search for LanceDB.""" - def __init__(self, table: "lancedb.table.Table", query: str): + def __init__(self, table: "Table", query: str): super().__init__(table) self._query = query self._phrase_query = False @@ -513,7 +517,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): import tantivy except ImportError: raise ImportError( - "Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature." + "Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature." # noqa: E501 ) from .fts import search_index @@ -523,8 +527,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): # check if the index exist if not Path(index_path).exists(): raise FileNotFoundError( - "Fts index does not exist." - f"Please first call table.create_fts_index(['']) to create the fts index." + "Fts index does not exist. " + "Please first call table.create_fts_index(['']) to " + "create the fts index." ) # open the index index = tantivy.Index.open(index_path) @@ -543,19 +548,21 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): if self._where is not None: try: - # TODO would be great to have Substrait generate pyarrow compute expressions - # or conversely have pyarrow support SQL expressions using Substrait + # TODO would be great to have Substrait generate pyarrow compute + # expressions or conversely have pyarrow support SQL expressions + # using Substrait import duckdb output_tbl = ( - duckdb.sql(f"SELECT * FROM output_tbl") + duckdb.sql("SELECT * FROM output_tbl") .filter(self._where) .to_arrow_table() ) except ImportError: - import lance import tempfile + import lance + # TODO Use "memory://" instead once that's supported with tempfile.TemporaryDirectory() as tmp: ds = lance.write_dataset(output_tbl, tmp) diff --git a/python/lancedb/remote/client.py b/python/lancedb/remote/client.py index 6f9bf292..9d0ea9d0 100644 --- a/python/lancedb/remote/client.py +++ b/python/lancedb/remote/client.py @@ -13,12 +13,12 @@ import functools -from typing import Any, Callable, Dict, Iterable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from urllib.parse import urljoin -import requests import attrs import pyarrow as pa +import requests from pydantic import BaseModel from lancedb.common import Credential diff --git a/python/lancedb/remote/db.py b/python/lancedb/remote/db.py index 337406db..3a88152f 100644 --- a/python/lancedb/remote/db.py +++ b/python/lancedb/remote/db.py @@ -11,7 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import inspect import logging import uuid @@ -102,8 +101,10 @@ class RemoteDBConnection(DBConnection): except LanceDBClientError as err: if str(err).startswith("Not found"): logging.error( - f"Table {name} does not exist. " - f"Please first call db.create_table({name}, data)" + "Table %s does not exist. Please first call " + "db.create_table(%s, data).", + name, + name, ) return RemoteTable(self, name) @@ -160,7 +161,8 @@ class RemoteDBConnection(DBConnection): Can create with list of tuples or dictionaries: >>> import lancedb - >>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP + >>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP + ... region="...") # doctest: +SKIP >>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7}, ... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}] >>> db.create_table("my_table", data) # doctest: +SKIP diff --git a/python/lancedb/remote/table.py b/python/lancedb/remote/table.py index 63572ebb..4878bc1d 100644 --- a/python/lancedb/remote/table.py +++ b/python/lancedb/remote/table.py @@ -11,7 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import uuid from functools import cached_property from typing import Dict, Optional, Union @@ -89,7 +88,8 @@ class RemoteTable(Table): >>> import lancedb >>> import uuid >>> from lancedb.schema import vector - >>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP + >>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP + ... region="...") # doctest: +SKIP >>> table_name = uuid.uuid4().hex >>> schema = pa.schema( ... [ @@ -125,7 +125,8 @@ class RemoteTable(Table): on_bad_vectors: str = "error", fill_value: float = 0.0, ) -> int: - """Add more data to the [Table](Table). It has the same API signature as the OSS version. + """Add more data to the [Table](Table). It has the same API signature as + the OSS version. Parameters ---------- @@ -176,7 +177,8 @@ class RemoteTable(Table): Examples -------- >>> import lancedb - >>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP + >>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP + ... region="...") # doctest: +SKIP >>> data = [ ... {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]}, ... {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]}, @@ -265,7 +267,8 @@ class RemoteTable(Table): ... {"x": 2, "vector": [3, 4]}, ... {"x": 3, "vector": [5, 6]} ... ] - >>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP + >>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP + ... region="...") # doctest: +SKIP >>> table = db.create_table("my_table", data) # doctest: +SKIP >>> table.search([10,10]).to_pandas() # doctest: +SKIP x vector _distance # doctest: +SKIP @@ -323,7 +326,8 @@ class RemoteTable(Table): ... {"x": 2, "vector": [3, 4]}, ... {"x": 3, "vector": [5, 6]} ... ] - >>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP + >>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP + ... region="...") # doctest: +SKIP >>> table = db.create_table("my_table", data) # doctest: +SKIP >>> table.to_pandas() # doctest: +SKIP x vector # doctest: +SKIP diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 141db4fb..930b65fc 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -14,7 +14,6 @@ from __future__ import annotations import inspect -import os from abc import ABC, abstractmethod from functools import cached_property from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union @@ -33,17 +32,20 @@ from .pydantic import LanceModel, model_to_dict from .query import LanceQueryBuilder, Query from .util import ( fs_from_uri, + join_uri, safe_import_pandas, safe_import_polars, value_to_sql, - join_uri, ) if TYPE_CHECKING: from datetime import timedelta + import PIL from lance.dataset import CleanupStats, ReaderLike + from .db import LanceDBConnection + pd = safe_import_pandas() pl = safe_import_polars() @@ -525,9 +527,7 @@ class LanceTable(Table): A table in a LanceDB database. """ - def __init__( - self, connection: "lancedb.db.LanceDBConnection", name: str, version: int = None - ): + def __init__(self, connection: "LanceDBConnection", name: str, version: int = None): self._conn = connection self.name = name self._version = version @@ -781,9 +781,7 @@ class LanceTable(Table): index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound if index_exists: if not replace: - raise ValueError( - f"Index already exists. Use replace=True to overwrite." - ) + raise ValueError("Index already exists. Use replace=True to overwrite.") fs.delete_dir(path) index = create_index(self._get_fts_index_path(), field_names) diff --git a/python/lancedb/util.py b/python/lancedb/util.py index 6122e4e0..5b4f3c65 100644 --- a/python/lancedb/util.py +++ b/python/lancedb/util.py @@ -12,9 +12,9 @@ # limitations under the License. import os +import pathlib from datetime import date, datetime from functools import singledispatch -import pathlib from typing import Tuple, Union from urllib.parse import urlparse diff --git a/python/tests/test_db.py b/python/tests/test_db.py index 700b34d3..7b716f22 100644 --- a/python/tests/test_db.py +++ b/python/tests/test_db.py @@ -130,7 +130,8 @@ def test_ingest_iterator(tmp_path): PydanticSchema(vector=[3.1, 4.1], item="foo", price=10.0), PydanticSchema(vector=[5.9, 26.5], item="bar", price=20.0), ], - # TODO: test pydict separately. it is unique column number and names contraint + # TODO: test pydict separately. it is unique column number and + # name constraints ] def run_tests(schema): diff --git a/python/tests/test_embeddings.py b/python/tests/test_embeddings.py index 03af14eb..a4c84fb0 100644 --- a/python/tests/test_embeddings.py +++ b/python/tests/test_embeddings.py @@ -18,7 +18,7 @@ import pyarrow as pa import pytest import lancedb -from lancedb.conftest import MockRateLimitedEmbeddingFunction, MockTextEmbeddingFunction +from lancedb.conftest import MockTextEmbeddingFunction from lancedb.embeddings import ( EmbeddingFunctionConfig, EmbeddingFunctionRegistry, diff --git a/python/tests/test_pydantic.py b/python/tests/test_pydantic.py index c6376dce..b37373ee 100644 --- a/python/tests/test_pydantic.py +++ b/python/tests/test_pydantic.py @@ -13,7 +13,6 @@ import json -import pytz import sys from datetime import date, datetime from typing import List, Optional, Tuple @@ -49,18 +48,19 @@ def test_pydantic_to_arrow(): dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"}) # d: dict - m = TestModel( - id=1, - s="hello", - vec=[1.0, 2.0, 3.0], - li=[2, 3, 4], - lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]], - litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)], - st=StructModel(a="a", b=1.0), - dt=date.today(), - dtt=datetime.now(), - dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")), - ) + # TODO: test we can actually convert the model into data. + # m = TestModel( + # id=1, + # s="hello", + # vec=[1.0, 2.0, 3.0], + # li=[2, 3, 4], + # lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]], + # litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)], + # st=StructModel(a="a", b=1.0), + # dt=date.today(), + # dtt=datetime.now(), + # dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")), + # ) schema = pydantic_to_schema(TestModel) @@ -133,18 +133,19 @@ def test_pydantic_to_arrow_py38(): dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"}) # d: dict - m = TestModel( - id=1, - s="hello", - vec=[1.0, 2.0, 3.0], - li=[2, 3, 4], - lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]], - litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)], - st=StructModel(a="a", b=1.0), - dt=date.today(), - dtt=datetime.now(), - dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")), - ) + # TODO: test we can actually convert the model to Arrow data. + # m = TestModel( + # id=1, + # s="hello", + # vec=[1.0, 2.0, 3.0], + # li=[2, 3, 4], + # lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]], + # litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)], + # st=StructModel(a="a", b=1.0), + # dt=date.today(), + # dtt=datetime.now(), + # dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")), + # ) schema = pydantic_to_schema(TestModel)