ci: lint and enforce linting (#829)

@eddyxu added instructions for linting here:


7af213801a/python/README.md (L45-L50)

However, we had a lot of lint failures and weren't checking them in CI. This
PR fixes all existing lint failures and adds a CI check to keep us compliant
going forward.
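
For reference, the two checks enforced below can be run locally with `ruff format --check .` and `ruff .`.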
Will Jones
2024-01-19 13:09:14 -08:00
committed by GitHub
parent 7af213801a
commit d012db24c2
28 changed files with 250 additions and 169 deletions

View File

@@ -38,8 +38,10 @@ jobs:
pip install -e .[tests]
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
pip install pytest pytest-mock ruff
- name: Lint
- name: Format check
run: ruff format --check .
- name: Lint
run: ruff .
- name: Run tests
run: pytest -m "not slow" -x -v --durations=30 tests
- name: doctest

View File

@@ -23,9 +23,10 @@ def cli():
diagnostics_help = """
Enable or disable LanceDB diagnostics. When enabled, LanceDB will send anonymous events to help us improve LanceDB.
These diagnostics are used only for error reporting and no data is collected. You can find more about diagnosis on
our docs: https://lancedb.github.io/lancedb/cli_config/
Enable or disable LanceDB diagnostics. When enabled, LanceDB will send anonymous events
to help us improve LanceDB. These diagnostics are used only for error reporting and no
data is collected. You can find more about diagnostics in our docs:
https://lancedb.github.io/lancedb/cli_config/
"""

View File

@@ -1,6 +1,5 @@
import os
import time
from typing import Any
import numpy as np
import pytest

View File

@@ -59,7 +59,8 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
8 dog I love 1
9 I love sandwiches 2
10 love sandwiches 2
>>> contextualize(data).window(7).stride(1).min_window_size(7).text_col('token').to_pandas()
>>> (contextualize(data).window(7).stride(1).min_window_size(7)
... .text_col('token').to_pandas())
token document_id
0 The quick brown fox jumped over the 1
1 quick brown fox jumped over the lazy 1

View File

@@ -14,10 +14,10 @@
# ruff: noqa: F401
from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
from .cohere import CohereEmbeddingFunction
from .gemini_text import GeminiText
from .instructor import InstructorEmbeddingFunction
from .open_clip import OpenClipEmbeddings
from .openai import OpenAIEmbeddings
from .registry import EmbeddingFunctionRegistry, get_registry
from .sentence_transformers import SentenceTransformerEmbeddings
from .gemini_text import GeminiText
from .utils import with_embeddings

View File

@@ -13,40 +13,50 @@
import os
from functools import cached_property
from typing import List, Union, Any
from typing import List, Union
import numpy as np
from lancedb.pydantic import PYDANTIC_VERSION
from .base import TextEmbeddingFunction
from .registry import register
from .utils import api_key_not_found_help, TEXT
from lancedb.pydantic import PYDANTIC_VERSION
from .utils import TEXT, api_key_not_found_help
@register("gemini-text")
class GeminiText(TextEmbeddingFunction):
"""
An embedding function that uses the Google's Gemini API. Requires GOOGLE_API_KEY to be set.
An embedding function that uses Google's Gemini API. Requires GOOGLE_API_KEY to
be set.
https://ai.google.dev/docs/embeddings_guide
Supports various task types:
| Task Type | Description |
|-------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------|
| "`retrieval_query`" | Specifies the given text is a query in a search/retrieval setting. |
| "`retrieval_document`" | Specifies the given text is a document in a search/retrieval setting. Using this task type requires a title but is automatically proided by Embeddings API |
| "`semantic_similarity`" | Specifies the given text will be used for Semantic Textual Similarity (STS). |
| "`classification`" | Specifies that the embeddings will be used for classification. |
| "`clusering`" | Specifies that the embeddings will be used for clustering. |
| Task Type | Description |
|-------------------------|--------------------------------------------------------|
| "`retrieval_query`" | Specifies the given text is a query in a |
| | search/retrieval setting. |
| "`retrieval_document`" | Specifies the given text is a document in a |
| | search/retrieval setting. Using this task type |
| | requires a title but is automatically provided by |
| | Embeddings API |
| "`semantic_similarity`" | Specifies the given text will be used for Semantic |
| | Textual Similarity (STS). |
| "`classification`" | Specifies that the embeddings will be used for |
| | classification. |
| "`clustering`" | Specifies that the embeddings will be used for |
| | clustering. |
Note: The supported task types might change in the Gemini API, but as long as a supported task type and its argument set is provided,
those will be delegated to the API calls.
Note: The supported task types might change in the Gemini API, but as long as a
supported task type and its argument set is provided, those will be delegated
to the API calls.
Parameters
----------
name: str, default "models/embedding-001"
The name of the model to use. See the Gemini documentation for a list of available models.
The name of the model to use. See the Gemini documentation for a list of
available models.
query_task_type: str, default "retrieval_query"
Sets the task type for the queries.
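
A minimal usage sketch based on the parameters documented above (assumes the Gemini client dependency is installed and GOOGLE_API_KEY is set; it follows the registry pattern used elsewhere in this diff):

from lancedb.embeddings import get_registry

# "gemini-text" is the name registered above; create() forwards the
# documented keyword arguments to the embedding function.
gemini = get_registry().get("gemini-text").create(
    query_task_type="retrieval_query",  # one of the task types tabulated above
)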

View File

@@ -22,22 +22,29 @@ from .utils import TEXT, weak_lru
@register("instructor")
class InstructorEmbeddingFunction(TextEmbeddingFunction):
"""
An embedding function that uses the InstructorEmbedding library. Instructor models support multi-task learning, and can be used for a
variety of tasks, including text classification, sentence similarity, and document retrieval.
If you want to calculate customized embeddings for specific sentences, you may follow the unified template to write instructions:
An embedding function that uses the InstructorEmbedding library. Instructor models
support multi-task learning, and can be used for a variety of tasks, including
text classification, sentence similarity, and document retrieval. If you want to
calculate customized embeddings for specific sentences, you may follow the unified
template to write instructions:
"Represent the `domain` `text_type` for `task_objective`":
* domain is optional, and it specifies the domain of the text, e.g., science, finance, medicine, etc.
* text_type is required, and it specifies the encoding unit, e.g., sentence, document, paragraph, etc.
* task_objective is optional, and it specifies the objective of embedding, e.g., retrieve a document, classify the sentence, etc.
* domain is optional, and it specifies the domain of the text, e.g., science,
finance, medicine, etc.
* text_type is required, and it specifies the encoding unit, e.g., sentence,
document, paragraph, etc.
* task_objective is optional, and it specifies the objective of embedding,
e.g., retrieve a document, classify the sentence, etc.
For example, if you want to calculate embeddings for a document, you may write the instruction as follows:
"Represent the document for retreival"
For example, if you want to calculate embeddings for a document, you may write the
instruction as follows:
"Represent the document for retrieval"
Parameters
----------
name: str
The name of the model to use. Available models are listed at https://github.com/xlang-ai/instructor-embedding#model-list;
The name of the model to use. Available models are listed at
https://github.com/xlang-ai/instructor-embedding#model-list;
The default model is hkunlp/instructor-base
batch_size: int, default 32
The batch size to use when generating embeddings
@@ -49,21 +56,24 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
Whether to normalize the embeddings
quantize: bool, default False
Whether to quantize the model
source_instruction: str, default "represent the docuement for retreival"
source_instruction: str, default "represent the document for retrieval"
The instruction for the source column
query_instruction: str, default "represent the document for retreiving the most similar documents"
query_instruction: str, default "represent the document for retrieving the most
similar documents"
The instruction for the query
Examples
--------
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry, InstructorEmbeddingFunction
instructor = get_registry().get("instructor").create(
source_instruction="represent the docuement for retreival",
query_instruction="represent the document for retreiving the most similar documents"
)
source_instruction="represent the document for retrieval",
query_instruction="represent the document for retrieving the most "
"similar documents"
)
class Schema(LanceModel):
vector: Vector(instructor.ndims()) = instructor.VectorField()
@@ -72,9 +82,12 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
db = lancedb.connect("~/.lancedb")
tbl = db.create_table("test", schema=Schema, mode="overwrite")
texts = [{"text": "Capitalism has been dominant in the Western world since the end of feudalism, but most feel[who?] that..."},
{"text": "The disparate impact theory is especially controversial under the Fair Housing Act because the Act..."},
{"text": "Disparate impact in United States labor law refers to practices in employment, housing, and other areas that.."}]
texts = [{"text": "Capitalism has been dominant in the Western world since the "
"end of feudalism, but most feel[who?] that..."},
{"text": "The disparate impact theory is especially controversial under "
"the Fair Housing Act because the Act..."},
{"text": "Disparate impact in United States labor law refers to practices "
"in employment, housing, and other areas that.."}]
tbl.add(texts)
@@ -103,9 +116,7 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
texts = self.sanitize_input(texts)
texts_formatted = []
for text in texts:
texts_formatted.append([self.source_instruction, text])
texts_formatted = [[self.source_instruction, text] for text in texts]
return self.generate_embeddings(texts_formatted)
def generate_embeddings(self, texts: List) -> List:

View File

@@ -14,7 +14,7 @@ import concurrent.futures
import io
import os
import urllib.parse as urlparse
from typing import List, Union
from typing import TYPE_CHECKING, List, Union
import numpy as np
import pyarrow as pa
@@ -25,6 +25,10 @@ from .base import EmbeddingFunction
from .registry import register
from .utils import IMAGES, url_retrieve
if TYPE_CHECKING:
import PIL
import torch
@register("open-clip")
class OpenClipEmbeddings(EmbeddingFunction):

View File

@@ -23,7 +23,9 @@ class EmbeddingFunctionRegistry:
You can implement your own embedding function by subclassing EmbeddingFunction
or TextEmbeddingFunction and registering it with the registry.
NOTE: Here TEXT is a type alias for Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
NOTE: Here TEXT is a type alias for Union[str, List[str], pa.Array,
pa.ChunkedArray, np.ndarray]
Examples
--------
>>> registry = EmbeddingFunctionRegistry.get_instance()
@@ -164,7 +166,8 @@ __REGISTRY__ = EmbeddingFunctionRegistry()
# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8
register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name)
def register(name):
return __REGISTRY__.get_instance().register(name)
def get_registry():
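
For context, registration is used as a decorator throughout this diff (e.g. @register("gemini-text")); a hedged sketch of wiring up a custom function, with import paths and method names assumed from the package layout shown here:

from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.registry import register

@register("my-embedder")  # hypothetical name
class MyEmbedder(TextEmbeddingFunction):
    # Toy three-dimensional embedding, for illustration only.
    def ndims(self):
        return 3

    def generate_embeddings(self, texts):
        return [[0.0, 0.0, 0.0] for _ in texts]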

View File

@@ -13,7 +13,6 @@
from typing import List, Union
import numpy as np
from cachetools import cached
from .base import TextEmbeddingFunction
from .registry import register

View File

@@ -112,7 +112,8 @@ class FunctionWrapper:
v = int(sys.version_info.minor)
if v >= 11:
print(
"WARNING: rate limit only support up to 3.10, proceeding without rate limiter"
"WARNING: rate limit only support up to 3.10, proceeding "
"without rate limiter"
)
else:
import ratelimiter
@@ -168,8 +169,8 @@ class FunctionWrapper:
def weak_lru(maxsize=128):
"""
LRU cache that keeps weak references to the objects it caches. Only caches the latest instance of the objects to make sure memory usage
is bounded.
LRU cache that keeps weak references to the objects it caches. Only caches the
latest instance of the objects to make sure memory usage is bounded.
Parameters
----------
@@ -234,15 +235,17 @@ def retry_with_exponential_backoff(
num_retries = 0
delay = initial_delay
# Loop until a successful response or max_retries is hit or an exception is raised
# Loop until a successful response or max_retries is hit or an exception
# is raised
while True:
try:
return func(*args, **kwargs)
# Currently retrying on all exceptions as there is no way to know the format of the error msgs used by different APIs
# We'll log the error and say that it is assumed that if this portion errors out, it's due to rate limit but the user
# should check the error message to be sure
except Exception as e:
# Currently retrying on all exceptions as there is no way to know the
# format of the error msgs used by different APIs. We'll log the error
# and say that it is assumed that if this portion errors out, it's due
# to rate limit but the user should check the error message to be sure.
except Exception as e: # noqa: PERF203
num_retries += 1
if num_retries > max_retries:
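
The loop whose comments are being rewrapped here is the standard exponential-backoff retry; a condensed, self-contained sketch of the same control flow (parameter names beyond max_retries are assumed):

import time

def retry_with_exponential_backoff(func, initial_delay=1.0, exponential_base=2.0, max_retries=10):
    # Condensed sketch: the real implementation above also logs the error
    # before retrying, since it cannot tell rate-limit errors apart from
    # other failures.
    def wrapper(*args, **kwargs):
        num_retries, delay = 0, initial_delay
        while True:
            try:
                return func(*args, **kwargs)
            except Exception:
                num_retries += 1
                if num_retries > max_retries:
                    raise
                delay *= exponential_base
                time.sleep(delay)

    return wrapper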

View File

@@ -13,7 +13,7 @@
"""Full text search index using tantivy-py"""
import os
from typing import List, Optional, Tuple
from typing import List, Tuple
import pyarrow as pa
@@ -21,7 +21,7 @@ try:
import tantivy
except ImportError:
raise ImportError(
"Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature."
"Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature." # noqa: E501
)
from .table import LanceTable

View File

@@ -20,15 +20,22 @@ import sys
import types
from abc import ABC, abstractmethod
from datetime import date, datetime
from typing import Any, Callable, Dict, Generator, List, Type, Union, _GenericAlias
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
List,
Type,
Union,
_GenericAlias,
)
import numpy as np
import pyarrow as pa
import pydantic
import semver
from pydantic.fields import FieldInfo
from .embeddings import EmbeddingFunctionRegistry
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
try:
@@ -37,6 +44,11 @@ except ImportError:
if PYDANTIC_VERSION >= (2,):
raise
if TYPE_CHECKING:
from pydantic.fields import FieldInfo
from .embeddings import EmbeddingFunctionConfig
class FixedSizeListMixin(ABC):
@staticmethod
@@ -190,7 +202,7 @@ else:
]
def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType:
def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
"""Convert a Pydantic FieldInfo to Arrow DataType"""
if isinstance(field.annotation, _GenericAlias) or (
@@ -221,7 +233,7 @@ def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType:
return _py_type_to_arrow_type(field.annotation, field)
def is_nullable(field: pydantic.fields.FieldInfo) -> bool:
def is_nullable(field: FieldInfo) -> bool:
"""Check if a Pydantic FieldInfo is nullable."""
if isinstance(field.annotation, _GenericAlias):
origin = field.annotation.__origin__
@@ -237,7 +249,7 @@ def is_nullable(field: pydantic.fields.FieldInfo) -> bool:
return False
def _pydantic_to_field(name: str, field: pydantic.fields.FieldInfo) -> pa.Field:
def _pydantic_to_field(name: str, field: FieldInfo) -> pa.Field:
"""Convert a Pydantic field to a PyArrow Field."""
dt = _pydantic_to_arrow_type(field)
return pa.field(name, dt, is_nullable(field))
@@ -309,6 +321,9 @@ class LanceModel(pydantic.BaseModel):
schema = pydantic_to_schema(cls)
functions = cls.parse_embedding_functions()
if len(functions) > 0:
# Prevent circular import
from .embeddings import EmbeddingFunctionRegistry
metadata = EmbeddingFunctionRegistry.get_instance().get_table_metadata(
functions
)
@@ -359,7 +374,7 @@ class LanceModel(pydantic.BaseModel):
return configs
def get_extras(field_info: pydantic.fields.FieldInfo, key: str) -> Any:
def get_extras(field_info: FieldInfo, key: str) -> Any:
"""
Get the extra metadata from a Pydantic FieldInfo.
"""

View File

@@ -27,7 +27,11 @@ from .common import VECTOR_COLUMN_NAME
from .util import safe_import_pandas
if TYPE_CHECKING:
import PIL
import polars as pl
from .pydantic import LanceModel
from .table import Table
pd = safe_import_pandas()
@@ -60,7 +64,7 @@ class Query(pydantic.BaseModel):
- See discussion in [Querying an ANN Index][querying-an-ann-index] for
tuning advice.
refine_factor : Optional[int]
Refine the results by reading extra elements and re-ranking them in memory - optional
Refine the results by reading extra elements and re-ranking them in memory.
- A higher number makes search more accurate but also slower.
@@ -104,7 +108,7 @@ class LanceQueryBuilder(ABC):
@classmethod
def create(
cls,
table: "lancedb.table.Table",
table: "Table",
query: Optional[Union[np.ndarray, str, "PIL.Image.Image"]],
query_type: str,
vector_column_name: str,
@@ -163,7 +167,7 @@ class LanceQueryBuilder(ABC):
f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
)
def __init__(self, table: "lancedb.table.Table"):
def __init__(self, table: "Table"):
self._table = table
self._limit = 10
self._columns = None
@@ -205,7 +209,6 @@ class LanceQueryBuilder(ABC):
if flatten is True:
while True:
tbl = tbl.flatten()
has_struct = False
# loop through all columns to check if there is any struct column
if any(pa.types.is_struct(col.type) for col in tbl.schema):
continue
@@ -214,7 +217,8 @@ class LanceQueryBuilder(ABC):
elif isinstance(flatten, int):
if flatten <= 0:
raise ValueError(
"Please specify a positive integer for flatten or the boolean value `True`"
"Please specify a positive integer for flatten or the boolean "
"value `True`"
)
while flatten > 0:
tbl = tbl.flatten()
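
For reference, the boolean branch above keeps flattening until no struct columns remain; a condensed standalone sketch of that loop:

import pyarrow as pa

tbl = pa.table({"s": [{"a": 1, "b": {"c": 2}}]})

# Flatten nested struct columns one level at a time until none remain.
while any(pa.types.is_struct(col.type) for col in tbl.schema):
    tbl = tbl.flatten()

print(tbl.schema)  # s.a: int64, s.b.c: int64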
@@ -361,7 +365,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
def __init__(
self,
table: "lancedb.table.Table",
table: "Table",
query: Union[np.ndarray, list, "PIL.Image.Image"],
vector_column: str = VECTOR_COLUMN_NAME,
):
@@ -486,7 +490,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
class LanceFtsQueryBuilder(LanceQueryBuilder):
"""A builder for full text search for LanceDB."""
def __init__(self, table: "lancedb.table.Table", query: str):
def __init__(self, table: "Table", query: str):
super().__init__(table)
self._query = query
self._phrase_query = False
@@ -513,7 +517,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
import tantivy
except ImportError:
raise ImportError(
"Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature."
"Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature." # noqa: E501
)
from .fts import search_index
@@ -523,8 +527,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
# check if the index exist
if not Path(index_path).exists():
raise FileNotFoundError(
"Fts index does not exist."
f"Please first call table.create_fts_index(['<field_names>']) to create the fts index."
"Fts index does not exist. "
"Please first call table.create_fts_index(['<field_names>']) to "
"create the fts index."
)
# open the index
index = tantivy.Index.open(index_path)
@@ -543,19 +548,21 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
if self._where is not None:
try:
# TODO would be great to have Substrait generate pyarrow compute expressions
# or conversely have pyarrow support SQL expressions using Substrait
# TODO would be great to have Substrait generate pyarrow compute
# expressions or conversely have pyarrow support SQL expressions
# using Substrait
import duckdb
output_tbl = (
duckdb.sql(f"SELECT * FROM output_tbl")
duckdb.sql("SELECT * FROM output_tbl")
.filter(self._where)
.to_arrow_table()
)
except ImportError:
import lance
import tempfile
import lance
# TODO Use "memory://" instead once that's supported
with tempfile.TemporaryDirectory() as tmp:
ds = lance.write_dataset(output_tbl, tmp)
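
The happy path above relies on DuckDB's replacement scans, which let SQL reference a PyArrow table that is in scope by name; a small standalone sketch:

import duckdb
import pyarrow as pa

output_tbl = pa.table({"x": [1, 2, 3]})

# DuckDB resolves `output_tbl` from the enclosing Python scope.
filtered = duckdb.sql("SELECT * FROM output_tbl").filter("x > 1").to_arrow_table()
print(filtered)  # rows where x > 1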

View File

@@ -13,12 +13,12 @@
import functools
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
from urllib.parse import urljoin
import requests
import attrs
import pyarrow as pa
import requests
from pydantic import BaseModel
from lancedb.common import Credential

View File

@@ -11,7 +11,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import inspect
import logging
import uuid
@@ -102,8 +101,10 @@ class RemoteDBConnection(DBConnection):
except LanceDBClientError as err:
if str(err).startswith("Not found"):
logging.error(
f"Table {name} does not exist. "
f"Please first call db.create_table({name}, data)"
"Table %s does not exist. Please first call "
"db.create_table(%s, data).",
name,
name,
)
return RemoteTable(self, name)
@@ -160,7 +161,8 @@ class RemoteDBConnection(DBConnection):
Can create with list of tuples or dictionaries:
>>> import lancedb
>>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
>>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP
... region="...") # doctest: +SKIP
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
>>> db.create_table("my_table", data) # doctest: +SKIP
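
The logging call rewritten above also switches from f-strings to lazy %-style arguments, so the message is only formatted if the record is actually emitted:

import logging

name = "my_table"
# Arguments are interpolated by the logging framework, and only when the
# record passes the level and handler checks.
logging.error(
    "Table %s does not exist. Please first call db.create_table(%s, data).",
    name,
    name,
)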

View File

@@ -11,7 +11,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import uuid
from functools import cached_property
from typing import Dict, Optional, Union
@@ -89,7 +88,8 @@ class RemoteTable(Table):
>>> import lancedb
>>> import uuid
>>> from lancedb.schema import vector
>>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
>>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP
... region="...") # doctest: +SKIP
>>> table_name = uuid.uuid4().hex
>>> schema = pa.schema(
... [
@@ -125,7 +125,8 @@ class RemoteTable(Table):
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> int:
"""Add more data to the [Table](Table). It has the same API signature as the OSS version.
"""Add more data to the [Table](Table). It has the same API signature as
the OSS version.
Parameters
----------
@@ -176,7 +177,8 @@ class RemoteTable(Table):
Examples
--------
>>> import lancedb
>>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
>>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP
... region="...") # doctest: +SKIP
>>> data = [
... {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]},
... {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]},
@@ -265,7 +267,8 @@ class RemoteTable(Table):
... {"x": 2, "vector": [3, 4]},
... {"x": 3, "vector": [5, 6]}
... ]
>>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
>>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP
... region="...") # doctest: +SKIP
>>> table = db.create_table("my_table", data) # doctest: +SKIP
>>> table.search([10,10]).to_pandas() # doctest: +SKIP
x vector _distance # doctest: +SKIP
@@ -323,7 +326,8 @@ class RemoteTable(Table):
... {"x": 2, "vector": [3, 4]},
... {"x": 3, "vector": [5, 6]}
... ]
>>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
>>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP
... region="...") # doctest: +SKIP
>>> table = db.create_table("my_table", data) # doctest: +SKIP
>>> table.to_pandas() # doctest: +SKIP
x vector # doctest: +SKIP

View File

@@ -14,7 +14,6 @@
from __future__ import annotations
import inspect
import os
from abc import ABC, abstractmethod
from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
@@ -33,18 +32,21 @@ from .pydantic import LanceModel, model_to_dict
from .query import LanceQueryBuilder, Query
from .util import (
fs_from_uri,
join_uri,
safe_import_pandas,
safe_import_polars,
value_to_sql,
join_uri,
)
from .utils.events import register_event
if TYPE_CHECKING:
from datetime import timedelta
import PIL
from lance.dataset import CleanupStats, ReaderLike
from .db import LanceDBConnection
pd = safe_import_pandas()
pl = safe_import_polars()
@@ -526,9 +528,7 @@ class LanceTable(Table):
A table in a LanceDB database.
"""
def __init__(
self, connection: "lancedb.db.LanceDBConnection", name: str, version: int = None
):
def __init__(self, connection: "LanceDBConnection", name: str, version: int = None):
self._conn = connection
self.name = name
self._version = version
@@ -783,9 +783,7 @@ class LanceTable(Table):
index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound
if index_exists:
if not replace:
raise ValueError(
f"Index already exists. Use replace=True to overwrite."
)
raise ValueError("Index already exists. Use replace=True to overwrite.")
fs.delete_dir(path)
index = create_index(self._get_fts_index_path(), field_names)

View File

@@ -12,9 +12,9 @@
# limitations under the License.
import os
import pathlib
from datetime import date, datetime
from functools import singledispatch
import pathlib
from typing import Tuple, Union
from urllib.parse import urlparse

View File

@@ -44,8 +44,9 @@ def get_user_config_dir(sub_dir="lancedb"):
# GCP and AWS lambda fix, only /tmp is writeable
if not is_dir_writeable(path.parent):
LOGGER.warning(
f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD."
"Alternatively you can define a LANCEDB_CONFIG_DIR environment variable for this path."
f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting "
"to '/tmp' or CWD. Alternatively you can define a LANCEDB_CONFIG_DIR "
"environment variable for this path."
)
path = (
Path("/tmp") / sub_dir
@@ -68,7 +69,8 @@ class Config(dict):
Manages lancedb config stored in a YAML file.
Args:
file (str | Path): Path to the lancedb config YAML file. Default is USER_CONFIG_DIR / 'config.yaml'.
file (str | Path): Path to the lancedb config YAML file. Default is
USER_CONFIG_DIR / 'config.yaml'.
"""
def __init__(self, file=CONFIG_FILE):
@@ -90,8 +92,8 @@ class Config(dict):
)
if not (correct_keys and correct_types):
LOGGER.warning(
"WARNING ⚠️ LanceDB settings reset to default values. This may be due to a possible problem "
"with your settings or a recent package update. "
"WARNING ⚠️ LanceDB settings reset to default values. This may be due "
"to a possible problem with your settings or a recent package update. "
f"\nView settings & usage with 'lancedb settings' or at '{self.file}'"
)
self.reset()

View File

@@ -35,10 +35,11 @@ from .general import (
class _Events:
"""
A class for collecting anonymous event analytics. Event analytics are enabled when ``diagnostics=True`` in config and
disabled when ``diagnostics=False``.
A class for collecting anonymous event analytics. Event analytics are enabled when
``diagnostics=True`` in config and disabled when ``diagnostics=False``.
You can enable or disable diagnostics by running ``lancedb diagnostics --enabled`` or ``lancedb diagnostics --disabled``.
You can enable or disable diagnostics by running ``lancedb diagnostics --enabled``
or ``lancedb diagnostics --disabled``.
Attributes
----------
@@ -61,7 +62,8 @@ class _Events:
def __init__(self):
"""
Initializes the Events object with default values for events, rate_limit, and metadata.
Initializes the Events object with default values for events, rate_limit,
and metadata.
"""
self.events = [] # events list
self.throttled_event_names = ["search_table"]
@@ -83,7 +85,8 @@ class _Events:
"version": importlib.metadata.version("lancedb"),
"platforms": PLATFORMS,
"session_id": round(random.random() * 1e15),
# 'engagement_time_msec': 1000 # TODO: In future we might be interested in this metric
# TODO: In future we might be interested in this metric
# 'engagement_time_msec': 1000
}
TESTS_RUNNING = is_pytest_running() or is_github_actions_ci()
@@ -100,7 +103,8 @@ class _Events:
def __call__(self, event_name, params={}):
"""
Attempts to add a new event to the events list and send events if the rate limit is reached.
Attempts to add a new event to the events list and send events if the rate
limit is reached.
Args
----
@@ -109,7 +113,8 @@ class _Events:
params : dict, optional
A dictionary of additional parameters to be logged with the event.
"""
### NOTE: We might need a way to tag a session with a label to check usage from a source. Setting label should be exposed to the user.
### NOTE: We might need a way to tag a session with a label to check usage
### from a source. Setting label should be exposed to the user.
if not self.enabled:
return
if (
@@ -141,8 +146,8 @@ class _Events:
"batch": self.events,
}
# POST equivalent to requests.post(self.url, json=data).
# threaded request is used to avoid blocking, retries are disabled, and verbose is disabled
# to avoid any possible disruption in the console.
# threaded request is used to avoid blocking, retries are disabled, and
# verbose is disabled to avoid any possible disruption in the console.
threaded_request(
method="post",
url=self.url,

View File

@@ -82,7 +82,8 @@ def is_pip_package(filepath: str = __name__) -> bool:
# Get the spec for the module
spec = importlib.util.find_spec(filepath)
# Return whether the spec is not None and the origin is not None (indicating it is a package)
# Return whether the spec is not None and the origin is not None (indicating
# it is a package)
return spec is not None and spec.origin is not None
@@ -108,7 +109,8 @@ def is_github_actions_ci() -> bool:
Returns
-------
bool
True if the current environment is a GitHub Actions CI Python runner, False otherwise.
True if the current environment is a GitHub Actions CI Python runner,
False otherwise.
"""
return (
@@ -145,7 +147,7 @@ def is_online() -> bool:
for host in "1.1.1.1", "8.8.8.8", "223.5.5.5": # Cloudflare, Google, AliDNS:
try:
test_connection = socket.create_connection(address=(host, 53), timeout=2)
except (socket.timeout, socket.gaierror, OSError):
except (socket.timeout, socket.gaierror, OSError): # noqa: PERF203
continue
else:
# If the connection was successful, close it to avoid a ResourceWarning
@@ -227,7 +229,8 @@ def is_docker() -> bool:
def get_git_dir():
"""Determine whether the current file is part of a git repository and if so, returns the repository root directory.
"""Determine whether the current file is part of a git repository and if so,
returns the repository root directory.
If the current file is not part of a git repository, returns None.
Returns
@@ -336,7 +339,7 @@ def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file
)
dump = yaml.dump(yaml_dict, sort_keys=False, allow_unicode=True)
LOGGER.info(f"Printing '{yaml_file}'\n\n{dump}")
LOGGER.info("Printing '%s'\n\n%s", yaml_file, dump)
PLATFORMS = [platform.system()]
@@ -375,7 +378,7 @@ class TryExcept(contextlib.ContextDecorator):
def __exit__(self, exc_type, value, traceback):
if self.verbose and value:
LOGGER.info(f"{self.msg}{': ' if self.msg else ''}{value}")
LOGGER.info("%s%s%s", self.msg, ": " if self.msg else "", value)
return True
@@ -383,7 +386,8 @@ def threaded_request(
method, url, retry=3, timeout=30, thread=True, code=-1, verbose=True, **kwargs
):
"""
Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
Makes an HTTP request using the 'requests' library, with exponential backoff
retries up to a specified timeout.
Parameters
----------
@@ -394,7 +398,8 @@ def threaded_request(
retry : int, optional
Number of retries to attempt before giving up, by default 3.
timeout : int, optional
Timeout in seconds after which the function will give up retrying, by default 30.
Timeout in seconds after which the function will give up retrying,
by default 30.
thread : bool, optional
Whether to execute the request in a separate daemon thread, by default True.
code : int, optional
@@ -405,13 +410,17 @@ def threaded_request(
Returns
-------
requests.Response
The HTTP response object. If the request is executed in a separate thread, returns the thread itself.
The HTTP response object. If the request is executed in a separate thread,
returns the thread itself.
"""
retry_codes = () # retry only these codes TODO: add codes if needed in future (500, 408)
# retry only these codes TODO: add codes if needed in future (500, 408)
retry_codes = ()
@TryExcept(verbose=verbose)
def func(method, url, **kwargs):
"""Make HTTP requests with retries and timeouts, with optional progress tracking."""
"""Make HTTP requests with retries and timeouts, with optional progress
tracking.
"""
response = None
t0 = time.time()
for i in range(retry + 1):
@@ -428,9 +437,9 @@ def threaded_request(
if response.status_code in retry_codes:
m += f" Retrying {retry}x for {timeout}s." if retry else ""
elif response.status_code == 429: # rate limit
m = f"Rate limit reached"
m = "Rate limit reached"
if verbose:
LOGGER.warning(f"{response.status_code} #{code}")
LOGGER.warning("%s #%s", response.status_code, m)
if response.status_code not in retry_codes:
return response
time.sleep(2**i) # exponential backoff

View File

@@ -33,10 +33,12 @@ from .general import (
@TryExcept(verbose=False)
def set_sentry():
"""
Initialize the Sentry SDK for error tracking and reporting. Only used if sentry_sdk package is installed and
sync=True in settings. Run 'lancedb settings' to see and update settings YAML file.
Initialize the Sentry SDK for error tracking and reporting. Only used if
sentry_sdk package is installed and sync=True in settings. Run 'lancedb settings'
to see and update the settings YAML file.
Conditions required to send errors (ALL conditions must be met or no errors will be reported):
Conditions required to send errors (ALL conditions must be met or no errors will
be reported):
- sentry_sdk package is installed
- sync=True in settings
- pytest is not running
@@ -44,22 +46,26 @@ def set_sentry():
- running in a non-git directory
- online environment
The function also configures Sentry SDK to ignore KeyboardInterrupt and FileNotFoundError
exceptions for now.
The function also configures Sentry SDK to ignore KeyboardInterrupt and
FileNotFoundError exceptions for now.
Additionally, the function sets custom tags and user information for Sentry events.
Additionally, the function sets custom tags and user information for Sentry
events.
"""
def before_send(event, hint):
"""
Modify the event before sending it to Sentry based on specific exception types and messages.
Modify the event before sending it to Sentry based on specific exception
types and messages.
Args:
event (dict): The event dictionary containing information about the error.
hint (dict): A dictionary containing additional information about the error.
hint (dict): A dictionary containing additional information about
the error.
Returns:
dict: The modified event or None if the event should not be sent to Sentry.
dict: The modified event or None if the event should not be sent
to Sentry.
"""
if "exc_info" in hint:
exc_type, exc_value, tb = hint["exc_info"]
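
The docstring above says KeyboardInterrupt and FileNotFoundError are ignored; a hedged sketch of how a before_send hook typically drops such events (the real filtering logic is truncated in this hunk):

def before_send(event, hint):
    # Returning None tells the Sentry SDK to drop the event.
    if "exc_info" in hint:
        exc_type, exc_value, tb = hint["exc_info"]
        if exc_type in (KeyboardInterrupt, FileNotFoundError):
            return None
    return event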

View File

@@ -15,11 +15,11 @@ def test_diagnostics():
runner = CliRunner()
result = runner.invoke(cli, ["diagnostics", "--disabled"])
assert result.exit_code == 0 # Main check
assert CONFIG["diagnostics"] == False
assert not CONFIG["diagnostics"]
result = runner.invoke(cli, ["diagnostics", "--enabled"])
assert result.exit_code == 0 # Main check
assert CONFIG["diagnostics"] == True
assert CONFIG["diagnostics"]
def test_config():
@@ -28,8 +28,5 @@ def test_config():
assert result.exit_code == 0 # Main check
cfg = CONFIG.copy()
cfg.pop("uuid")
for (
item,
_,
) in cfg.items(): # check for keys only as formatting is subject to change
for item in cfg: # check for keys only as formatting is subject to change
assert item in result.output

View File

@@ -130,7 +130,8 @@ def test_ingest_iterator(tmp_path):
PydanticSchema(vector=[3.1, 4.1], item="foo", price=10.0),
PydanticSchema(vector=[5.9, 26.5], item="bar", price=20.0),
],
# TODO: test pydict separately. it is unique column number and names contraint
# TODO: test pydict separately. It has unique column number and
# name constraints.
]
def run_tests(schema):

View File

@@ -18,7 +18,7 @@ import pyarrow as pa
import pytest
import lancedb
from lancedb.conftest import MockRateLimitedEmbeddingFunction, MockTextEmbeddingFunction
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.embeddings import (
EmbeddingFunctionConfig,
EmbeddingFunctionRegistry,

View File

@@ -13,7 +13,6 @@
import json
import pytz
import sys
from datetime import date, datetime
from typing import List, Optional, Tuple
@@ -49,18 +48,19 @@ def test_pydantic_to_arrow():
dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"})
# d: dict
m = TestModel(
id=1,
s="hello",
vec=[1.0, 2.0, 3.0],
li=[2, 3, 4],
lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
st=StructModel(a="a", b=1.0),
dt=date.today(),
dtt=datetime.now(),
dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
)
# TODO: test we can actually convert the model into data.
# m = TestModel(
# id=1,
# s="hello",
# vec=[1.0, 2.0, 3.0],
# li=[2, 3, 4],
# lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
# litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
# st=StructModel(a="a", b=1.0),
# dt=date.today(),
# dtt=datetime.now(),
# dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
# )
schema = pydantic_to_schema(TestModel)
@@ -133,18 +133,19 @@ def test_pydantic_to_arrow_py38():
dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"})
# d: dict
m = TestModel(
id=1,
s="hello",
vec=[1.0, 2.0, 3.0],
li=[2, 3, 4],
lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
st=StructModel(a="a", b=1.0),
dt=date.today(),
dtt=datetime.now(),
dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
)
# TODO: test we can actually convert the model to Arrow data.
# m = TestModel(
# id=1,
# s="hello",
# vec=[1.0, 2.0, 3.0],
# li=[2, 3, 4],
# lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
# litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
# st=StructModel(a="a", b=1.0),
# dt=date.today(),
# dtt=datetime.now(),
# dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
# )
schema = pydantic_to_schema(TestModel)

View File

@@ -46,7 +46,8 @@ def test_event_reporting(monkeypatch, request_log_path, tmp_path) -> None:
with open(request_log_path, "r") as f:
json_data = json.load(f)
# TODO: don't hardcode these here. Instead create a module level json scehma in lancedb.utils.events for better evolvability
# TODO: don't hardcode these here. Instead create a module-level JSON schema in
# lancedb.utils.events for better evolvability
batch_keys = ["api_key", "distinct_id", "batch"]
event_keys = ["event", "properties", "timestamp", "distinct_id"]
property_keys = ["cli", "install", "platforms", "version", "session_id"]