Mirror of https://github.com/lancedb/lancedb.git
Synced 2026-01-04 10:52:56 +00:00

ci: lint and enforce linting (#829)

@eddyxu added instructions for linting here:
7af213801a/python/README.md (L45-L50)

However, we had a lot of failures and weren't checking this in CI. This PR fixes all lints and adds a check to CI to keep us in compliance with the lints.

.github/workflows/python.yml

@@ -38,8 +38,10 @@ jobs:
       pip install -e .[tests]
       pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
       pip install pytest pytest-mock ruff
-      - name: Lint
+      - name: Format check
         run: ruff format --check .
+      - name: Lint
+        run: ruff .
       - name: Run tests
         run: pytest -m "not slow" -x -v --durations=30 tests
       - name: doctest

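For contributors who want to reproduce the new CI steps locally before pushing, a minimal sketch of a runner script (a hypothetical helper, not part of this PR; assumes ruff is installed as in the workflow above):

# Hypothetical local runner mirroring the new CI steps; assumes ruff and
# pytest are installed (pip install pytest pytest-mock ruff).
import subprocess
import sys

def run_checks() -> int:
    for cmd in (
        ["ruff", "format", "--check", "."],  # the new "Format check" step
        ["ruff", "."],                       # the new "Lint" step
    ):
        code = subprocess.run(cmd).returncode
        if code != 0:
            return code
    return 0

if __name__ == "__main__":
    sys.exit(run_checks())
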
@@ -23,9 +23,10 @@ def cli():

 diagnostics_help = """
-Enable or disable LanceDB diagnostics. When enabled, LanceDB will send anonymous events to help us improve LanceDB.
-These diagnostics are used only for error reporting and no data is collected. You can find more about diagnosis on
-our docs: https://lancedb.github.io/lancedb/cli_config/
+Enable or disable LanceDB diagnostics. When enabled, LanceDB will send anonymous events
+to help us improve LanceDB. These diagnostics are used only for error reporting and no
+data is collected. You can find more about diagnosis on our docs:
+https://lancedb.github.io/lancedb/cli_config/
 """

@@ -1,6 +1,5 @@
 import os
 import time
-from typing import Any

 import numpy as np
 import pytest

@@ -59,7 +59,8 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
     8         dog I love            1
     9   I love sandwiches           2
     10    love sandwiches           2
-    >>> contextualize(data).window(7).stride(1).min_window_size(7).text_col('token').to_pandas()
+    >>> (contextualize(data).window(7).stride(1).min_window_size(7)
+    ...     .text_col('token').to_pandas())
                                       token document_id
     0   The quick brown fox jumped over the           1
     1  quick brown fox jumped over the lazy           1

@@ -14,10 +14,10 @@
 # ruff: noqa: F401
 from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
 from .cohere import CohereEmbeddingFunction
+from .gemini_text import GeminiText
 from .instructor import InstructorEmbeddingFunction
 from .open_clip import OpenClipEmbeddings
 from .openai import OpenAIEmbeddings
 from .registry import EmbeddingFunctionRegistry, get_registry
 from .sentence_transformers import SentenceTransformerEmbeddings
-from .gemini_text import GeminiText
 from .utils import with_embeddings

@@ -13,40 +13,50 @@

 import os
 from functools import cached_property
-from typing import List, Union, Any
+from typing import List, Union

 import numpy as np

+from lancedb.pydantic import PYDANTIC_VERSION
+
 from .base import TextEmbeddingFunction
 from .registry import register
-from .utils import api_key_not_found_help, TEXT
-from lancedb.pydantic import PYDANTIC_VERSION
+from .utils import TEXT, api_key_not_found_help


 @register("gemini-text")
 class GeminiText(TextEmbeddingFunction):
     """
-    An embedding function that uses the Google's Gemini API. Requires GOOGLE_API_KEY to be set.
+    An embedding function that uses the Google's Gemini API. Requires GOOGLE_API_KEY to
+    be set.

     https://ai.google.dev/docs/embeddings_guide

     Supports various tasks types:
-    | Task Type               | Description |
-    |-------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------|
-    | "`retrieval_query`"     | Specifies the given text is a query in a search/retrieval setting. |
-    | "`retrieval_document`"  | Specifies the given text is a document in a search/retrieval setting. Using this task type requires a title but is automatically proided by Embeddings API |
-    | "`semantic_similarity`" | Specifies the given text will be used for Semantic Textual Similarity (STS). |
-    | "`classification`"      | Specifies that the embeddings will be used for classification. |
-    | "`clusering`"           | Specifies that the embeddings will be used for clustering. |
+    | Task Type               | Description                                            |
+    |-------------------------|--------------------------------------------------------|
+    | "`retrieval_query`"     | Specifies the given text is a query in a               |
+    |                         | search/retrieval setting.                              |
+    | "`retrieval_document`"  | Specifies the given text is a document in a            |
+    |                         | search/retrieval setting. Using this task type         |
+    |                         | requires a title but is automatically provided by      |
+    |                         | Embeddings API                                         |
+    | "`semantic_similarity`" | Specifies the given text will be used for Semantic     |
+    |                         | Textual Similarity (STS).                              |
+    | "`classification`"      | Specifies that the embeddings will be used for         |
+    |                         | classification.                                        |
+    | "`clustering`"          | Specifies that the embeddings will be used for         |
+    |                         | clustering.                                            |

-    Note: The supported task types might change in the Gemini API, but as long as a supported task type and its argument set is provided,
-    those will be delegated to the API calls.
+    Note: The supported task types might change in the Gemini API, but as long as a
+    supported task type and its argument set is provided, those will be delegated
+    to the API calls.

     Parameters
     ----------
     name: str, default "models/embedding-001"
-        The name of the model to use. See the Gemini documentation for a list of available models.
+        The name of the model to use. See the Gemini documentation for a list of
+        available models.

     query_task_type: str, default "retrieval_query"
         Sets the task type for the queries.

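For orientation, a sketch of how this class would be consumed through the registry (hypothetical usage based on the `@register("gemini-text")` name and the `query_task_type` parameter documented above; requires GOOGLE_API_KEY and the Gemini client at runtime):

from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

# Hypothetical usage; "gemini-text" matches the @register name above, and
# query_task_type is the parameter documented in the docstring.
gemini = get_registry().get("gemini-text").create(query_task_type="retrieval_query")

class Words(LanceModel):
    text: str = gemini.SourceField()  # column to embed automatically (assumed helper)
    vector: Vector(gemini.ndims()) = gemini.VectorField()
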
@@ -22,22 +22,29 @@ from .utils import TEXT, weak_lru
 @register("instructor")
 class InstructorEmbeddingFunction(TextEmbeddingFunction):
     """
-    An embedding function that uses the InstructorEmbedding library. Instructor models support multi-task learning, and can be used for a
-    variety of tasks, including text classification, sentence similarity, and document retrieval.
-    If you want to calculate customized embeddings for specific sentences, you may follow the unified template to write instructions:
+    An embedding function that uses the InstructorEmbedding library. Instructor models
+    support multi-task learning, and can be used for a variety of tasks, including
+    text classification, sentence similarity, and document retrieval. If you want to
+    calculate customized embeddings for specific sentences, you may follow the unified
+    template to write instructions:
     "Represent the `domain` `text_type` for `task_objective`":

-    * domain is optional, and it specifies the domain of the text, e.g., science, finance, medicine, etc.
-    * text_type is required, and it specifies the encoding unit, e.g., sentence, document, paragraph, etc.
-    * task_objective is optional, and it specifies the objective of embedding, e.g., retrieve a document, classify the sentence, etc.
+    * domain is optional, and it specifies the domain of the text, e.g., science,
+      finance, medicine, etc.
+    * text_type is required, and it specifies the encoding unit, e.g., sentence,
+      document, paragraph, etc.
+    * task_objective is optional, and it specifies the objective of embedding,
+      e.g., retrieve a document, classify the sentence, etc.

-    For example, if you want to calculate embeddings for a document, you may write the instruction as follows:
-    "Represent the document for retreival"
+    For example, if you want to calculate embeddings for a document, you may write the
+    instruction as follows:
+    "Represent the document for retrieval"

     Parameters
     ----------
     name: str
-        The name of the model to use. Available models are listed at https://github.com/xlang-ai/instructor-embedding#model-list;
+        The name of the model to use. Available models are listed at
+        https://github.com/xlang-ai/instructor-embedding#model-list;
         The default model is hkunlp/instructor-base
     batch_size: int, default 32
         The batch size to use when generating embeddings

@@ -49,21 +56,24 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
         Whether to normalize the embeddings
     quantize: bool, default False
         Whether to quantize the model
-    source_instruction: str, default "represent the docuement for retreival"
+    source_instruction: str, default "represent the document for retrieval"
         The instruction for the source column
-    query_instruction: str, default "represent the document for retreiving the most similar documents"
+    query_instruction: str, default "represent the document for retrieving the most
+        similar documents"
         The instruction for the query

     Examples
     --------

     import lancedb
     from lancedb.pydantic import LanceModel, Vector
     from lancedb.embeddings import get_registry, InstuctorEmbeddingFunction

     instructor = get_registry().get("instructor").create(
-                source_instruction="represent the docuement for retreival",
-                query_instruction="represent the document for retreiving the most similar documents"
-    )
+                source_instruction="represent the document for retrieval",
+                query_instruction="represent the document for retrieving the most "
+                                  "similar documents"
+    )

     class Schema(LanceModel):
         vector: Vector(instructor.ndims()) = instructor.VectorField()

@@ -72,9 +82,12 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
     db = lancedb.connect("~/.lancedb")
     tbl = db.create_table("test", schema=Schema, mode="overwrite")

-    texts = [{"text": "Capitalism has been dominant in the Western world since the end of feudalism, but most feel[who?] that..."},
-             {"text": "The disparate impact theory is especially controversial under the Fair Housing Act because the Act..."},
-             {"text": "Disparate impact in United States labor law refers to practices in employment, housing, and other areas that.."}]
+    texts = [{"text": "Capitalism has been dominant in the Western world since the "
+                      "end of feudalism, but most feel[who?] that..."},
+             {"text": "The disparate impact theory is especially controversial under "
+                      "the Fair Housing Act because the Act..."},
+             {"text": "Disparate impact in United States labor law refers to practices "
+                      "in employment, housing, and other areas that.."}]

     tbl.add(texts)

@@ -103,9 +116,7 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):

     def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
         texts = self.sanitize_input(texts)
-        texts_formatted = []
-        for text in texts:
-            texts_formatted.append([self.source_instruction, text])
+        texts_formatted = [[self.source_instruction, text] for text in texts]
         return self.generate_embeddings(texts_formatted)

     def generate_embeddings(self, texts: List) -> List:

@@ -14,7 +14,7 @@ import concurrent.futures
 import io
 import os
 import urllib.parse as urlparse
-from typing import List, Union
+from typing import TYPE_CHECKING, List, Union

 import numpy as np
 import pyarrow as pa

@@ -25,6 +25,10 @@ from .base import EmbeddingFunction
 from .registry import register
 from .utils import IMAGES, url_retrieve

+if TYPE_CHECKING:
+    import PIL
+    import torch
+

 @register("open-clip")
 class OpenClipEmbeddings(EmbeddingFunction):

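The `if TYPE_CHECKING:` block added above is the standard pattern for keeping heavy optional imports out of the runtime path while still satisfying type checkers; a minimal self-contained illustration:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only while a type checker (mypy, pyright) analyzes the file;
    # never executed at runtime, so torch stays an optional dependency.
    import torch

def to_tensor(x) -> "torch.Tensor":  # forward reference as a string
    import torch  # real import deferred to call time

    return torch.as_tensor(x)
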
@@ -23,7 +23,9 @@ class EmbeddingFunctionRegistry:
     You can implement your own embedding function by subclassing EmbeddingFunction
     or TextEmbeddingFunction and registering it with the registry.

-    NOTE: Here TEXT is a type alias for Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
+    NOTE: Here TEXT is a type alias for Union[str, List[str], pa.Array,
+    pa.ChunkedArray, np.ndarray]

     Examples
     --------
     >>> registry = EmbeddingFunctionRegistry.get_instance()

@@ -164,7 +166,8 @@ __REGISTRY__ = EmbeddingFunctionRegistry()


 # @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8
-register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name)
+def register(name):
+    return __REGISTRY__.get_instance().register(name)


 def get_registry():

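This is the decorator applied as `@register("open-clip")` and `@register("gemini-text")` elsewhere in this PR. A minimal sketch of the same register-by-name pattern, with hypothetical names (the real version delegates to EmbeddingFunctionRegistry):

# Minimal sketch of a name-based registry decorator (hypothetical standalone
# version; _FUNCTIONS stands in for the real registry object).
_FUNCTIONS: dict = {}

def register(name: str):
    def decorator(cls: type) -> type:
        _FUNCTIONS[name] = cls  # remember the class under its public name
        return cls              # return it unchanged so the decorator is transparent
    return decorator

@register("my-embedding")
class MyEmbedding:
    pass

assert _FUNCTIONS["my-embedding"] is MyEmbedding
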
@@ -13,7 +13,6 @@
 from typing import List, Union

 import numpy as np
-from cachetools import cached

 from .base import TextEmbeddingFunction
 from .registry import register

@@ -112,7 +112,8 @@ class FunctionWrapper:
         v = int(sys.version_info.minor)
         if v >= 11:
             print(
-                "WARNING: rate limit only support up to 3.10, proceeding without rate limiter"
+                "WARNING: rate limit only support up to 3.10, proceeding "
+                "without rate limiter"
             )
         else:
             import ratelimiter

@@ -168,8 +169,8 @@ class FunctionWrapper:

 def weak_lru(maxsize=128):
     """
-    LRU cache that keeps weak references to the objects it caches. Only caches the latest instance of the objects to make sure memory usage
-    is bounded.
+    LRU cache that keeps weak references to the objects it caches. Only caches the
+    latest instance of the objects to make sure memory usage is bounded.

     Parameters
     ----------

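The docstring above describes the usual weakref-plus-lru_cache idiom for caching methods without pinning `self` in memory; a self-contained sketch of that technique (an assumption, since this hunk does not show the implementation body):

import functools
import weakref

def weak_lru(maxsize: int = 128):
    """LRU-cache a method while holding only a weak reference to ``self``."""
    def wrapper(func):
        @functools.lru_cache(maxsize)
        def _cached(self_ref, *args, **kwargs):
            # Dereference the weakref at call time; the cache key holds no
            # strong reference, so instances can still be garbage collected.
            return func(self_ref(), *args, **kwargs)

        @functools.wraps(func)
        def inner(self, *args, **kwargs):
            return _cached(weakref.ref(self), *args, **kwargs)

        return inner
    return wrapper
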
@@ -234,15 +235,17 @@ def retry_with_exponential_backoff(
     num_retries = 0
     delay = initial_delay

-    # Loop until a successful response or max_retries is hit or an exception is raised
+    # Loop until a successful response or max_retries is hit or an exception
+    # is raised
     while True:
         try:
             return func(*args, **kwargs)

-        # Currently retrying on all exceptions as there is no way to know the format of the error msgs used by different APIs
-        # We'll log the error and say that it is assumed that if this portion errors out, it's due to rate limit but the user
-        # should check the error message to be sure
-        except Exception as e:
+        # Currently retrying on all exceptions as there is no way to know the
+        # format of the error msgs used by different APIs. We'll log the error
+        # and say that it is assumed that if this portion errors out, it's due
+        # to rate limit but the user should check the error message to be sure.
+        except Exception as e:  # noqa: PERF203
             num_retries += 1

             if num_retries > max_retries:

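A compact sketch of the exponential-backoff pattern this function implements, reusing the names visible in the hunk (num_retries, delay, initial_delay, max_retries); the exact signature and jitter behavior are assumptions:

import random
import time

def retry_with_exponential_backoff(func, max_retries=7, initial_delay=1.0,
                                   exponential_base=2.0, jitter=True):
    """Call ``func`` repeatedly, growing the sleep after each failure."""
    def wrapped(*args, **kwargs):
        num_retries = 0
        delay = initial_delay
        while True:
            try:
                return func(*args, **kwargs)
            except Exception:  # noqa: PERF203 - retry loop by design
                num_retries += 1
                if num_retries > max_retries:
                    raise
                # Grow the delay geometrically, with optional random jitter.
                delay *= exponential_base * (1 + jitter * random.random())
                time.sleep(delay)
    return wrapped
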
@@ -13,7 +13,7 @@

 """Full text search index using tantivy-py"""
 import os
-from typing import List, Optional, Tuple
+from typing import List, Tuple

 import pyarrow as pa

@@ -21,7 +21,7 @@ try:
     import tantivy
 except ImportError:
     raise ImportError(
-        "Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature."
+        "Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature."  # noqa: E501
     )

 from .table import LanceTable

@@ -20,15 +20,22 @@ import sys
 import types
 from abc import ABC, abstractmethod
 from datetime import date, datetime
-from typing import Any, Callable, Dict, Generator, List, Type, Union, _GenericAlias
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Type,
+    Union,
+    _GenericAlias,
+)

 import numpy as np
 import pyarrow as pa
 import pydantic
 import semver
-from pydantic.fields import FieldInfo
-
-from .embeddings import EmbeddingFunctionRegistry

 PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
 try:

@@ -37,6 +44,11 @@ except ImportError:
     if PYDANTIC_VERSION >= (2,):
         raise

+if TYPE_CHECKING:
+    from pydantic.fields import FieldInfo
+
+    from .embeddings import EmbeddingFunctionConfig
+

 class FixedSizeListMixin(ABC):
     @staticmethod

@@ -190,7 +202,7 @@ else:
     ]


-def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType:
+def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
     """Convert a Pydantic FieldInfo to Arrow DataType"""

     if isinstance(field.annotation, _GenericAlias) or (

@@ -221,7 +233,7 @@ def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType:
     return _py_type_to_arrow_type(field.annotation, field)


-def is_nullable(field: pydantic.fields.FieldInfo) -> bool:
+def is_nullable(field: FieldInfo) -> bool:
     """Check if a Pydantic FieldInfo is nullable."""
     if isinstance(field.annotation, _GenericAlias):
         origin = field.annotation.__origin__

@@ -237,7 +249,7 @@ def is_nullable(field: pydantic.fields.FieldInfo) -> bool:
     return False


-def _pydantic_to_field(name: str, field: pydantic.fields.FieldInfo) -> pa.Field:
+def _pydantic_to_field(name: str, field: FieldInfo) -> pa.Field:
     """Convert a Pydantic field to a PyArrow Field."""
     dt = _pydantic_to_arrow_type(field)
     return pa.field(name, dt, is_nullable(field))

@@ -309,6 +321,9 @@ class LanceModel(pydantic.BaseModel):
         schema = pydantic_to_schema(cls)
         functions = cls.parse_embedding_functions()
         if len(functions) > 0:
+            # Prevent circular import
+            from .embeddings import EmbeddingFunctionRegistry
+
             metadata = EmbeddingFunctionRegistry.get_instance().get_table_metadata(
                 functions
             )

@@ -359,7 +374,7 @@ class LanceModel(pydantic.BaseModel):
         return configs


-def get_extras(field_info: pydantic.fields.FieldInfo, key: str) -> Any:
+def get_extras(field_info: FieldInfo, key: str) -> Any:
     """
     Get the extra metadata from a Pydantic FieldInfo.
     """

@@ -27,7 +27,11 @@ from .common import VECTOR_COLUMN_NAME
 from .util import safe_import_pandas

 if TYPE_CHECKING:
     import PIL
+    import polars as pl
+
+    from .pydantic import LanceModel
+    from .table import Table

 pd = safe_import_pandas()

@@ -60,7 +64,7 @@ class Query(pydantic.BaseModel):
         - See discussion in [Querying an ANN Index][querying-an-ann-index] for
           tuning advice.
     refine_factor : Optional[int]
-        Refine the results by reading extra elements and re-ranking them in memory - optional
+        Refine the results by reading extra elements and re-ranking them in memory.

         - A higher number makes search more accurate but also slower.

@@ -104,7 +108,7 @@ class LanceQueryBuilder(ABC):
     @classmethod
     def create(
         cls,
-        table: "lancedb.table.Table",
+        table: "Table",
         query: Optional[Union[np.ndarray, str, "PIL.Image.Image"]],
         query_type: str,
         vector_column_name: str,

@@ -163,7 +167,7 @@ class LanceQueryBuilder(ABC):
             f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
         )

-    def __init__(self, table: "lancedb.table.Table"):
+    def __init__(self, table: "Table"):
         self._table = table
         self._limit = 10
         self._columns = None

@@ -205,7 +209,6 @@ class LanceQueryBuilder(ABC):
         if flatten is True:
             while True:
                 tbl = tbl.flatten()
-                has_struct = False
                 # loop through all columns to check if there is any struct column
                 if any(pa.types.is_struct(col.type) for col in tbl.schema):
                     continue

@@ -214,7 +217,8 @@ class LanceQueryBuilder(ABC):
         elif isinstance(flatten, int):
             if flatten <= 0:
                 raise ValueError(
-                    "Please specify a positive integer for flatten or the boolean value `True`"
+                    "Please specify a positive integer for flatten or the boolean "
+                    "value `True`"
                 )
             while flatten > 0:
                 tbl = tbl.flatten()

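For readers unfamiliar with pyarrow's Table.flatten(), which the loop above calls until no struct columns remain, a small standalone example:

import pyarrow as pa

# A table with one struct column holding two nested fields.
tbl = pa.table({"point": [{"x": 1, "y": 2}, {"x": 3, "y": 4}]})
print(tbl.schema)         # point: struct<x: int64, y: int64>

flat = tbl.flatten()      # one level of structs becomes top-level columns
print(flat.column_names)  # ['point.x', 'point.y']
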
@@ -361,7 +365,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):

     def __init__(
         self,
-        table: "lancedb.table.Table",
+        table: "Table",
         query: Union[np.ndarray, list, "PIL.Image.Image"],
         vector_column: str = VECTOR_COLUMN_NAME,
     ):

@@ -486,7 +490,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
 class LanceFtsQueryBuilder(LanceQueryBuilder):
     """A builder for full text search for LanceDB."""

-    def __init__(self, table: "lancedb.table.Table", query: str):
+    def __init__(self, table: "Table", query: str):
         super().__init__(table)
         self._query = query
         self._phrase_query = False

@@ -513,7 +517,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
             import tantivy
         except ImportError:
             raise ImportError(
-                "Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature."
+                "Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature."  # noqa: E501
             )

         from .fts import search_index

@@ -523,8 +527,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
         # check if the index exist
         if not Path(index_path).exists():
             raise FileNotFoundError(
-                "Fts index does not exist."
-                f"Please first call table.create_fts_index(['<field_names>']) to create the fts index."
+                "Fts index does not exist. "
+                "Please first call table.create_fts_index(['<field_names>']) to "
+                "create the fts index."
             )
         # open the index
         index = tantivy.Index.open(index_path)

@@ -543,19 +548,21 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):

         if self._where is not None:
             try:
-                # TODO would be great to have Substrait generate pyarrow compute expressions
-                # or conversely have pyarrow support SQL expressions using Substrait
+                # TODO would be great to have Substrait generate pyarrow compute
+                # expressions or conversely have pyarrow support SQL expressions
+                # using Substrait
                 import duckdb

                 output_tbl = (
-                    duckdb.sql(f"SELECT * FROM output_tbl")
+                    duckdb.sql("SELECT * FROM output_tbl")
                     .filter(self._where)
                     .to_arrow_table()
                 )
             except ImportError:
-                import lance
                 import tempfile

+                import lance
+
                 # TODO Use "memory://" instead once that's supported
                 with tempfile.TemporaryDirectory() as tmp:
                     ds = lance.write_dataset(output_tbl, tmp)

@@ -13,12 +13,12 @@


 import functools
-from typing import Any, Callable, Dict, Iterable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 from urllib.parse import urljoin

-import requests
 import attrs
 import pyarrow as pa
+import requests
 from pydantic import BaseModel

 from lancedb.common import Credential

@@ -11,7 +11,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import asyncio
 import inspect
 import logging
 import uuid

@@ -102,8 +101,10 @@ class RemoteDBConnection(DBConnection):
         except LanceDBClientError as err:
             if str(err).startswith("Not found"):
                 logging.error(
-                    f"Table {name} does not exist. "
-                    f"Please first call db.create_table({name}, data)"
+                    "Table %s does not exist. Please first call "
+                    "db.create_table(%s, data).",
+                    name,
+                    name,
                 )
         return RemoteTable(self, name)

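The logging change above swaps eager f-string formatting for logging's lazy %-style arguments, so the message is only built if a handler actually emits the record; a minimal before/after:

import logging

name = "my_table"

# Eager: the f-string is formatted even when ERROR logging is disabled.
logging.error(f"Table {name} does not exist.")

# Lazy: formatting is deferred until a handler actually emits the record,
# which is also what ruff's logging-format rules check for.
logging.error("Table %s does not exist.", name)
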
@@ -160,7 +161,8 @@ class RemoteDBConnection(DBConnection):
         Can create with list of tuples or dictionaries:

         >>> import lancedb
-        >>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
+        >>> db = lancedb.connect("db://...", api_key="...",  # doctest: +SKIP
+        ...                      region="...")  # doctest: +SKIP
         >>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
         ...         {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
         >>> db.create_table("my_table", data) # doctest: +SKIP

@@ -11,7 +11,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import asyncio
 import uuid
 from functools import cached_property
 from typing import Dict, Optional, Union

@@ -89,7 +88,8 @@ class RemoteTable(Table):
         >>> import lancedb
         >>> import uuid
         >>> from lancedb.schema import vector
-        >>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
+        >>> db = lancedb.connect("db://...", api_key="...",  # doctest: +SKIP
+        ...                      region="...")  # doctest: +SKIP
         >>> table_name = uuid.uuid4().hex
         >>> schema = pa.schema(
         ...     [

@@ -125,7 +125,8 @@ class RemoteTable(Table):
         on_bad_vectors: str = "error",
         fill_value: float = 0.0,
     ) -> int:
-        """Add more data to the [Table](Table). It has the same API signature as the OSS version.
+        """Add more data to the [Table](Table). It has the same API signature as
+        the OSS version.

         Parameters
         ----------

@@ -176,7 +177,8 @@ class RemoteTable(Table):
         Examples
         --------
         >>> import lancedb
-        >>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
+        >>> db = lancedb.connect("db://...", api_key="...",  # doctest: +SKIP
+        ...                      region="...")  # doctest: +SKIP
         >>> data = [
         ...    {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]},
         ...    {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]},

@@ -265,7 +267,8 @@ class RemoteTable(Table):
         ...    {"x": 2, "vector": [3, 4]},
         ...    {"x": 3, "vector": [5, 6]}
         ... ]
-        >>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
+        >>> db = lancedb.connect("db://...", api_key="...",  # doctest: +SKIP
+        ...                      region="...")  # doctest: +SKIP
         >>> table = db.create_table("my_table", data) # doctest: +SKIP
         >>> table.search([10,10]).to_pandas() # doctest: +SKIP
              x     vector  _distance  # doctest: +SKIP

@@ -323,7 +326,8 @@ class RemoteTable(Table):
         ...    {"x": 2, "vector": [3, 4]},
         ...    {"x": 3, "vector": [5, 6]}
         ... ]
-        >>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
+        >>> db = lancedb.connect("db://...", api_key="...",  # doctest: +SKIP
+        ...                      region="...")  # doctest: +SKIP
         >>> table = db.create_table("my_table", data) # doctest: +SKIP
         >>> table.to_pandas() # doctest: +SKIP
            x      vector  # doctest: +SKIP

@@ -14,7 +14,6 @@
 from __future__ import annotations

-import inspect
 import os
 from abc import ABC, abstractmethod
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union

@@ -33,18 +32,21 @@ from .pydantic import LanceModel, model_to_dict
 from .query import LanceQueryBuilder, Query
 from .util import (
     fs_from_uri,
+    join_uri,
     safe_import_pandas,
     safe_import_polars,
     value_to_sql,
-    join_uri,
 )
 from .utils.events import register_event

 if TYPE_CHECKING:
+    from datetime import timedelta
+
     import PIL
     from lance.dataset import CleanupStats, ReaderLike

+    from .db import LanceDBConnection
+

 pd = safe_import_pandas()
 pl = safe_import_polars()

@@ -526,9 +528,7 @@ class LanceTable(Table):
     A table in a LanceDB database.
     """

-    def __init__(
-        self, connection: "lancedb.db.LanceDBConnection", name: str, version: int = None
-    ):
+    def __init__(self, connection: "LanceDBConnection", name: str, version: int = None):
         self._conn = connection
         self.name = name
         self._version = version

@@ -783,9 +783,7 @@ class LanceTable(Table):
         index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound
         if index_exists:
             if not replace:
-                raise ValueError(
-                    f"Index already exists. Use replace=True to overwrite."
-                )
+                raise ValueError("Index already exists. Use replace=True to overwrite.")
             fs.delete_dir(path)

         index = create_index(self._get_fts_index_path(), field_names)

@@ -12,9 +12,9 @@
 # limitations under the License.

 import os
+import pathlib
 from datetime import date, datetime
 from functools import singledispatch
-import pathlib
 from typing import Tuple, Union
 from urllib.parse import urlparse

@@ -44,8 +44,9 @@ def get_user_config_dir(sub_dir="lancedb"):
     # GCP and AWS lambda fix, only /tmp is writeable
     if not is_dir_writeable(path.parent):
         LOGGER.warning(
-            f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD."
-            "Alternatively you can define a LANCEDB_CONFIG_DIR environment variable for this path."
+            f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting "
+            "to '/tmp' or CWD. Alternatively you can define a LANCEDB_CONFIG_DIR "
+            "environment variable for this path."
         )
         path = (
             Path("/tmp") / sub_dir

@@ -68,7 +69,8 @@ class Config(dict):
     Manages lancedb config stored in a YAML file.

     Args:
-        file (str | Path): Path to the lancedb config YAML file. Default is USER_CONFIG_DIR / 'config.yaml'.
+        file (str | Path): Path to the lancedb config YAML file. Default is
+            USER_CONFIG_DIR / 'config.yaml'.
     """

     def __init__(self, file=CONFIG_FILE):

@@ -90,8 +92,8 @@ class Config(dict):
         )
         if not (correct_keys and correct_types):
             LOGGER.warning(
-                "WARNING ⚠️ LanceDB settings reset to default values. This may be due to a possible problem "
-                "with your settings or a recent package update. "
+                "WARNING ⚠️ LanceDB settings reset to default values. This may be due "
+                "to a possible problem with your settings or a recent package update. "
                 f"\nView settings & usage with 'lancedb settings' or at '{self.file}'"
             )
             self.reset()

@@ -35,10 +35,11 @@ from .general import (

 class _Events:
     """
-    A class for collecting anonymous event analytics. Event analytics are enabled when ``diagnostics=True`` in config and
-    disabled when ``diagnostics=False``.
+    A class for collecting anonymous event analytics. Event analytics are enabled when
+    ``diagnostics=True`` in config and disabled when ``diagnostics=False``.

-    You can enable or disable diagnostics by running ``lancedb diagnostics --enabled`` or ``lancedb diagnostics --disabled``.
+    You can enable or disable diagnostics by running ``lancedb diagnostics --enabled``
+    or ``lancedb diagnostics --disabled``.

     Attributes
     ----------

@@ -61,7 +62,8 @@ class _Events:

     def __init__(self):
         """
-        Initializes the Events object with default values for events, rate_limit, and metadata.
+        Initializes the Events object with default values for events, rate_limit,
+        and metadata.
         """
         self.events = []  # events list
         self.throttled_event_names = ["search_table"]

@@ -83,7 +85,8 @@ class _Events:
             "version": importlib.metadata.version("lancedb"),
             "platforms": PLATFORMS,
             "session_id": round(random.random() * 1e15),
-            # 'engagement_time_msec': 1000 # TODO: In future we might be interested in this metric
+            # TODO: In future we might be interested in this metric
+            # 'engagement_time_msec': 1000
         }

         TESTS_RUNNING = is_pytest_running() or is_github_actions_ci()

@@ -100,7 +103,8 @@ class _Events:

     def __call__(self, event_name, params={}):
         """
-        Attempts to add a new event to the events list and send events if the rate limit is reached.
+        Attempts to add a new event to the events list and send events if the rate
+        limit is reached.

         Args
         ----

@@ -109,7 +113,8 @@ class _Events:
         params : dict, optional
             A dictionary of additional parameters to be logged with the event.
         """
-        ### NOTE: We might need a way to tag a session with a label to check usage from a source. Setting label should be exposed to the user.
+        ### NOTE: We might need a way to tag a session with a label to check usage
+        ### from a source. Setting label should be exposed to the user.
         if not self.enabled:
             return
         if (

@@ -141,8 +146,8 @@ class _Events:
             "batch": self.events,
         }
         # POST equivalent to requests.post(self.url, json=data).
-        # threaded request is used to avoid blocking, retries are disabled, and verbose is disabled
-        # to avoid any possible disruption in the console.
+        # threaded request is used to avoid blocking, retries are disabled, and
+        # verbose is disabled to avoid any possible disruption in the console.
         threaded_request(
             method="post",
             url=self.url,

@@ -82,7 +82,8 @@ def is_pip_package(filepath: str = __name__) -> bool:
     # Get the spec for the module
     spec = importlib.util.find_spec(filepath)

-    # Return whether the spec is not None and the origin is not None (indicating it is a package)
+    # Return whether the spec is not None and the origin is not None (indicating
+    # it is a package)
     return spec is not None and spec.origin is not None

@@ -108,7 +109,8 @@ def is_github_actions_ci() -> bool:
     Returns
     -------
     bool
-        True if the current environment is a GitHub Actions CI Python runner, False otherwise.
+        True if the current environment is a GitHub Actions CI Python runner,
+        False otherwise.
     """

     return (

@@ -145,7 +147,7 @@ def is_online() -> bool:
     for host in "1.1.1.1", "8.8.8.8", "223.5.5.5":  # Cloudflare, Google, AliDNS:
         try:
             test_connection = socket.create_connection(address=(host, 53), timeout=2)
-        except (socket.timeout, socket.gaierror, OSError):
+        except (socket.timeout, socket.gaierror, OSError):  # noqa: PERF203
             continue
         else:
             # If the connection was successful, close it to avoid a ResourceWarning

@@ -227,7 +229,8 @@ def is_docker() -> bool:


 def get_git_dir():
-    """Determine whether the current file is part of a git repository and if so, returns the repository root directory.
+    """Determine whether the current file is part of a git repository and if so,
+    returns the repository root directory.
     If the current file is not part of a git repository, returns None.

     Returns

@@ -336,7 +339,7 @@ def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
         yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file
     )
     dump = yaml.dump(yaml_dict, sort_keys=False, allow_unicode=True)
-    LOGGER.info(f"Printing '{yaml_file}'\n\n{dump}")
+    LOGGER.info("Printing '%s'\n\n%s", yaml_file, dump)


 PLATFORMS = [platform.system()]

@@ -375,7 +378,7 @@ class TryExcept(contextlib.ContextDecorator):

     def __exit__(self, exc_type, value, traceback):
         if self.verbose and value:
-            LOGGER.info(f"{self.msg}{': ' if self.msg else ''}{value}")
+            LOGGER.info("%s%s%s", self.msg, ": " if self.msg else "", value)
         return True

@@ -383,7 +386,8 @@ def threaded_request(
     method, url, retry=3, timeout=30, thread=True, code=-1, verbose=True, **kwargs
 ):
     """
-    Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
+    Makes an HTTP request using the 'requests' library, with exponential backoff
+    retries up to a specified timeout.

     Parameters
     ----------

@@ -394,7 +398,8 @@ def threaded_request(
     retry : int, optional
         Number of retries to attempt before giving up, by default 3.
     timeout : int, optional
-        Timeout in seconds after which the function will give up retrying, by default 30.
+        Timeout in seconds after which the function will give up retrying,
+        by default 30.
     thread : bool, optional
         Whether to execute the request in a separate daemon thread, by default True.
     code : int, optional

@@ -405,13 +410,17 @@ def threaded_request(
     Returns
     -------
     requests.Response
-        The HTTP response object. If the request is executed in a separate thread, returns the thread itself.
+        The HTTP response object. If the request is executed in a separate thread,
+        returns the thread itself.
     """
-    retry_codes = ()  # retry only these codes TODO: add codes if needed in future (500, 408)
+    # retry only these codes TODO: add codes if needed in future (500, 408)
+    retry_codes = ()

     @TryExcept(verbose=verbose)
     def func(method, url, **kwargs):
-        """Make HTTP requests with retries and timeouts, with optional progress tracking."""
+        """Make HTTP requests with retries and timeouts, with optional progress
+        tracking.
+        """
         response = None
         t0 = time.time()
         for i in range(retry + 1):

@@ -428,9 +437,9 @@ def threaded_request(
             if response.status_code in retry_codes:
                 m += f" Retrying {retry}x for {timeout}s." if retry else ""
             elif response.status_code == 429:  # rate limit
-                m = f"Rate limit reached"
+                m = "Rate limit reached"
             if verbose:
-                LOGGER.warning(f"{response.status_code} #{code}")
+                LOGGER.warning("%s #%s", response.status_code, m)
             if response.status_code not in retry_codes:
                 return response
             time.sleep(2**i)  # exponential standoff

@@ -33,10 +33,12 @@ from .general import (
 @TryExcept(verbose=False)
 def set_sentry():
     """
-    Initialize the Sentry SDK for error tracking and reporting. Only used if sentry_sdk package is installed and
-    sync=True in settings. Run 'lancedb settings' to see and update settings YAML file.
+    Initialize the Sentry SDK for error tracking and reporting. Only used if
+    sentry_sdk package is installed and sync=True in settings. Run 'lancedb settings'
+    to see and update settings YAML file.

-    Conditions required to send errors (ALL conditions must be met or no errors will be reported):
+    Conditions required to send errors (ALL conditions must be met or no errors will
+    be reported):
     - sentry_sdk package is installed
     - sync=True in settings
     - pytest is not running

@@ -44,22 +46,26 @@ def set_sentry():
     - running in a non-git directory
     - online environment

-    The function also configures Sentry SDK to ignore KeyboardInterrupt and FileNotFoundError
-    exceptions for now.
+    The function also configures Sentry SDK to ignore KeyboardInterrupt and
+    FileNotFoundError exceptions for now.

-    Additionally, the function sets custom tags and user information for Sentry events.
+    Additionally, the function sets custom tags and user information for Sentry
+    events.
     """

     def before_send(event, hint):
         """
-        Modify the event before sending it to Sentry based on specific exception types and messages.
+        Modify the event before sending it to Sentry based on specific exception
+        types and messages.

         Args:
             event (dict): The event dictionary containing information about the error.
-            hint (dict): A dictionary containing additional information about the error.
+            hint (dict): A dictionary containing additional information about
+                the error.

         Returns:
-            dict: The modified event or None if the event should not be sent to Sentry.
+            dict: The modified event or None if the event should not be sent
+                to Sentry.
         """
         if "exc_info" in hint:
             exc_type, exc_value, tb = hint["exc_info"]

@@ -15,11 +15,11 @@ def test_diagnostics():
     runner = CliRunner()
     result = runner.invoke(cli, ["diagnostics", "--disabled"])
     assert result.exit_code == 0  # Main check
-    assert CONFIG["diagnostics"] == False
+    assert not CONFIG["diagnostics"]

     result = runner.invoke(cli, ["diagnostics", "--enabled"])
     assert result.exit_code == 0  # Main check
-    assert CONFIG["diagnostics"] == True
+    assert CONFIG["diagnostics"]

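The assertion changes above follow ruff's E712 rule (no equality comparison against True/False); the preferred forms rely on truthiness directly:

enabled = True
disabled = False

# Flagged by ruff (E712): equality comparison with a boolean literal.
# assert enabled == True
# assert disabled == False

# Preferred: rely on truthiness directly.
assert enabled
assert not disabled
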
@@ -28,8 +28,5 @@ def test_config():
     assert result.exit_code == 0  # Main check
     cfg = CONFIG.copy()
     cfg.pop("uuid")
-    for (
-        item,
-        _,
-    ) in cfg.items():  # check for keys only as formatting is subject to change
+    for item in cfg:  # check for keys only as formatting is subject to change
         assert item in result.output

@@ -130,7 +130,8 @@ def test_ingest_iterator(tmp_path):
         PydanticSchema(vector=[3.1, 4.1], item="foo", price=10.0),
         PydanticSchema(vector=[5.9, 26.5], item="bar", price=20.0),
     ],
-    # TODO: test pydict separately. it is unique column number and names contraint
+    # TODO: test pydict separately. it is unique column number and
+    # name constraints
 ]

 def run_tests(schema):

@@ -18,7 +18,7 @@ import pyarrow as pa
 import pytest

 import lancedb
-from lancedb.conftest import MockRateLimitedEmbeddingFunction, MockTextEmbeddingFunction
+from lancedb.conftest import MockTextEmbeddingFunction
 from lancedb.embeddings import (
     EmbeddingFunctionConfig,
     EmbeddingFunctionRegistry,

@@ -13,7 +13,6 @@


 import json
-import pytz
 import sys
 from datetime import date, datetime
 from typing import List, Optional, Tuple

@@ -49,18 +48,19 @@ def test_pydantic_to_arrow():
         dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"})
         # d: dict

-    m = TestModel(
-        id=1,
-        s="hello",
-        vec=[1.0, 2.0, 3.0],
-        li=[2, 3, 4],
-        lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
-        litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
-        st=StructModel(a="a", b=1.0),
-        dt=date.today(),
-        dtt=datetime.now(),
-        dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
-    )
+    # TODO: test we can actually convert the model into data.
+    # m = TestModel(
+    #     id=1,
+    #     s="hello",
+    #     vec=[1.0, 2.0, 3.0],
+    #     li=[2, 3, 4],
+    #     lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
+    #     litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
+    #     st=StructModel(a="a", b=1.0),
+    #     dt=date.today(),
+    #     dtt=datetime.now(),
+    #     dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
+    # )

     schema = pydantic_to_schema(TestModel)

@@ -133,18 +133,19 @@ def test_pydantic_to_arrow_py38():
         dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"})
         # d: dict

-    m = TestModel(
-        id=1,
-        s="hello",
-        vec=[1.0, 2.0, 3.0],
-        li=[2, 3, 4],
-        lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
-        litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
-        st=StructModel(a="a", b=1.0),
-        dt=date.today(),
-        dtt=datetime.now(),
-        dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
-    )
+    # TODO: test we can actually convert the model to Arrow data.
+    # m = TestModel(
+    #     id=1,
+    #     s="hello",
+    #     vec=[1.0, 2.0, 3.0],
+    #     li=[2, 3, 4],
+    #     lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
+    #     litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
+    #     st=StructModel(a="a", b=1.0),
+    #     dt=date.today(),
+    #     dtt=datetime.now(),
+    #     dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
+    # )

     schema = pydantic_to_schema(TestModel)

@@ -46,7 +46,8 @@ def test_event_reporting(monkeypatch, request_log_path, tmp_path) -> None:
     with open(request_log_path, "r") as f:
        json_data = json.load(f)

-    # TODO: don't hardcode these here. Instead create a module level json scehma in lancedb.utils.events for better evolvability
+    # TODO: don't hardcode these here. Instead create a module level json scehma in
+    # lancedb.utils.events for better evolvability
     batch_keys = ["api_key", "distinct_id", "batch"]
     event_keys = ["event", "properties", "timestamp", "distinct_id"]
     property_keys = ["cli", "install", "platforms", "version", "session_id"]