ci: lint and enforce linting (#829)
@eddyxu added instructions for linting here:
7af213801a/python/README.md (L45-L50)
However, we had a lot of lint failures and weren't checking them in CI. This
PR fixes all existing failures and adds a lint check to CI to keep us in
compliance.
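
For anyone reproducing the new check locally, a minimal sketch (this assumes the
ruff setup implied by the `# ruff: noqa` comments in the diff below; the
authoritative commands are the ones in python/README.md):

    # Hypothetical local equivalent of the CI lint step; assumes ruff is installed.
    import subprocess

    # check=True makes the script fail on any lint violation, as CI would.
    subprocess.run(["ruff", "check", "python"], check=True)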
@@ -1,6 +1,5 @@
import os
import time
from typing import Any

import numpy as np
import pytest

@@ -59,7 +59,8 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
8 dog I love 1
9 I love sandwiches 2
10 love sandwiches 2
>>> contextualize(data).window(7).stride(1).min_window_size(7).text_col('token').to_pandas()
>>> (contextualize(data).window(7).stride(1).min_window_size(7)
... .text_col('token').to_pandas())
token document_id
0 The quick brown fox jumped over the 1
1 quick brown fox jumped over the lazy 1

@@ -14,10 +14,10 @@
# ruff: noqa: F401
from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
from .cohere import CohereEmbeddingFunction
from .gemini_text import GeminiText
from .instructor import InstructorEmbeddingFunction
from .open_clip import OpenClipEmbeddings
from .openai import OpenAIEmbeddings
from .registry import EmbeddingFunctionRegistry, get_registry
from .sentence_transformers import SentenceTransformerEmbeddings
from .gemini_text import GeminiText
from .utils import with_embeddings

@@ -13,40 +13,50 @@

import os
from functools import cached_property
from typing import List, Union, Any
from typing import List, Union

import numpy as np

from lancedb.pydantic import PYDANTIC_VERSION

from .base import TextEmbeddingFunction
from .registry import register
from .utils import api_key_not_found_help, TEXT
from lancedb.pydantic import PYDANTIC_VERSION
from .utils import TEXT, api_key_not_found_help


@register("gemini-text")
class GeminiText(TextEmbeddingFunction):
"""
An embedding function that uses the Google's Gemini API. Requires GOOGLE_API_KEY to be set.
An embedding function that uses the Google's Gemini API. Requires GOOGLE_API_KEY to
be set.

https://ai.google.dev/docs/embeddings_guide

Supports various tasks types:
| Task Type | Description |
|-------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------|
| "`retrieval_query`" | Specifies the given text is a query in a search/retrieval setting. |
| "`retrieval_document`" | Specifies the given text is a document in a search/retrieval setting. Using this task type requires a title but is automatically proided by Embeddings API |
| "`semantic_similarity`" | Specifies the given text will be used for Semantic Textual Similarity (STS). |
| "`classification`" | Specifies that the embeddings will be used for classification. |
| "`clusering`" | Specifies that the embeddings will be used for clustering. |
| Task Type | Description |
|-------------------------|--------------------------------------------------------|
| "`retrieval_query`" | Specifies the given text is a query in a |
| | search/retrieval setting. |
| "`retrieval_document`" | Specifies the given text is a document in a |
| | search/retrieval setting. Using this task type |
| | requires a title but is automatically provided by |
| | Embeddings API |
| "`semantic_similarity`" | Specifies the given text will be used for Semantic |
| | Textual Similarity (STS). |
| "`classification`" | Specifies that the embeddings will be used for |
| | classification. |
| "`clustering`" | Specifies that the embeddings will be used for |
| | clustering. |


Note: The supported task types might change in the Gemini API, but as long as a supported task type and its argument set is provided,
those will be delegated to the API calls.
Note: The supported task types might change in the Gemini API, but as long as a
supported task type and its argument set is provided, those will be delegated
to the API calls.

Parameters
----------
name: str, default "models/embedding-001"
The name of the model to use. See the Gemini documentation for a list of available models.
The name of the model to use. See the Gemini documentation for a list of
available models.

query_task_type: str, default "retrieval_query"
Sets the task type for the queries.

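A usage sketch tying the task types above to the registry pattern shown
elsewhere in this diff (hedged: the create() keyword names are taken from the
Parameters section and the values are placeholders):

    import os
    from lancedb.embeddings import get_registry

    os.environ["GOOGLE_API_KEY"] = "..."  # GeminiText requires this to be set

    gemini = get_registry().get("gemini-text").create(
        name="models/embedding-001",        # default model per the docstring
        query_task_type="retrieval_query",  # task type applied to query text
    )
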
@@ -22,22 +22,29 @@ from .utils import TEXT, weak_lru
@register("instructor")
class InstructorEmbeddingFunction(TextEmbeddingFunction):
"""
An embedding function that uses the InstructorEmbedding library. Instructor models support multi-task learning, and can be used for a
variety of tasks, including text classification, sentence similarity, and document retrieval.
If you want to calculate customized embeddings for specific sentences, you may follow the unified template to write instructions:
An embedding function that uses the InstructorEmbedding library. Instructor models
support multi-task learning, and can be used for a variety of tasks, including
text classification, sentence similarity, and document retrieval. If you want to
calculate customized embeddings for specific sentences, you may follow the unified
template to write instructions:
"Represent the `domain` `text_type` for `task_objective`":

* domain is optional, and it specifies the domain of the text, e.g., science, finance, medicine, etc.
* text_type is required, and it specifies the encoding unit, e.g., sentence, document, paragraph, etc.
* task_objective is optional, and it specifies the objective of embedding, e.g., retrieve a document, classify the sentence, etc.
* domain is optional, and it specifies the domain of the text, e.g., science,
finance, medicine, etc.
* text_type is required, and it specifies the encoding unit, e.g., sentence,
document, paragraph, etc.
* task_objective is optional, and it specifies the objective of embedding,
e.g., retrieve a document, classify the sentence, etc.

For example, if you want to calculate embeddings for a document, you may write the instruction as follows:
"Represent the document for retreival"
For example, if you want to calculate embeddings for a document, you may write the
instruction as follows:
"Represent the document for retrieval"

Parameters
----------
name: str
The name of the model to use. Available models are listed at https://github.com/xlang-ai/instructor-embedding#model-list;
The name of the model to use. Available models are listed at
https://github.com/xlang-ai/instructor-embedding#model-list;
The default model is hkunlp/instructor-base
batch_size: int, default 32
The batch size to use when generating embeddings

@@ -49,21 +56,24 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
Whether to normalize the embeddings
quantize: bool, default False
Whether to quantize the model
source_instruction: str, default "represent the docuement for retreival"
source_instruction: str, default "represent the document for retrieval"
The instruction for the source column
query_instruction: str, default "represent the document for retreiving the most similar documents"
query_instruction: str, default "represent the document for retrieving the most
similar documents"
The instruction for the query

Examples
--------

import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry, InstuctorEmbeddingFunction

instructor = get_registry().get("instructor").create(
source_instruction="represent the docuement for retreival",
query_instruction="represent the document for retreiving the most similar documents"
)
source_instruction="represent the document for retrieval",
query_instruction="represent the document for retrieving the most "
"similar documents"
)

class Schema(LanceModel):
vector: Vector(instructor.ndims()) = instructor.VectorField()

@@ -72,9 +82,12 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
db = lancedb.connect("~/.lancedb")
tbl = db.create_table("test", schema=Schema, mode="overwrite")

texts = [{"text": "Capitalism has been dominant in the Western world since the end of feudalism, but most feel[who?] that..."},
{"text": "The disparate impact theory is especially controversial under the Fair Housing Act because the Act..."},
{"text": "Disparate impact in United States labor law refers to practices in employment, housing, and other areas that.."}]
texts = [{"text": "Capitalism has been dominant in the Western world since the "
"end of feudalism, but most feel[who?] that..."},
{"text": "The disparate impact theory is especially controversial under "
"the Fair Housing Act because the Act..."},
{"text": "Disparate impact in United States labor law refers to practices "
"in employment, housing, and other areas that.."}]

tbl.add(texts)

@@ -103,9 +116,7 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):

def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
texts = self.sanitize_input(texts)
texts_formatted = []
for text in texts:
texts_formatted.append([self.source_instruction, text])
texts_formatted = [[self.source_instruction, text] for text in texts]
return self.generate_embeddings(texts_formatted)

def generate_embeddings(self, texts: List) -> List:

@@ -14,7 +14,7 @@ import concurrent.futures
import io
import os
import urllib.parse as urlparse
from typing import List, Union
from typing import TYPE_CHECKING, List, Union

import numpy as np
import pyarrow as pa

@@ -25,6 +25,10 @@ from .base import EmbeddingFunction
from .registry import register
from .utils import IMAGES, url_retrieve

if TYPE_CHECKING:
import PIL
import torch


@register("open-clip")
class OpenClipEmbeddings(EmbeddingFunction):

@@ -23,7 +23,9 @@ class EmbeddingFunctionRegistry:
You can implement your own embedding function by subclassing EmbeddingFunction
or TextEmbeddingFunction and registering it with the registry.

NOTE: Here TEXT is a type alias for Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
NOTE: Here TEXT is a type alias for Union[str, List[str], pa.Array,
pa.ChunkedArray, np.ndarray]

Examples
--------
>>> registry = EmbeddingFunctionRegistry.get_instance()

@@ -164,7 +166,8 @@ __REGISTRY__ = EmbeddingFunctionRegistry()


# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8
register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name)
def register(name):
return __REGISTRY__.get_instance().register(name)


def get_registry():

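The lambda-to-def change keeps usage identical: register("name") still returns a
class decorator. A sketch of registering a custom function (the method names
follow TextEmbeddingFunction as used elsewhere in this diff; the class body is
illustrative only):

    from lancedb.embeddings import TextEmbeddingFunction
    from lancedb.embeddings.registry import register

    @register("my-embedder")  # stores the class under this name in the registry
    class MyEmbedder(TextEmbeddingFunction):
        def ndims(self):
            return 2  # dimensionality reported to schema builders

        def generate_embeddings(self, texts):
            return [[0.0, 0.0] for _ in texts]  # toy constant embeddings
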
@@ -13,7 +13,6 @@
from typing import List, Union

import numpy as np
from cachetools import cached

from .base import TextEmbeddingFunction
from .registry import register

@@ -112,7 +112,8 @@ class FunctionWrapper:
v = int(sys.version_info.minor)
if v >= 11:
print(
"WARNING: rate limit only support up to 3.10, proceeding without rate limiter"
"WARNING: rate limit only support up to 3.10, proceeding "
"without rate limiter"
)
else:
import ratelimiter

@@ -168,8 +169,8 @@ class FunctionWrapper:

def weak_lru(maxsize=128):
"""
LRU cache that keeps weak references to the objects it caches. Only caches the latest instance of the objects to make sure memory usage
is bounded.
LRU cache that keeps weak references to the objects it caches. Only caches the
latest instance of the objects to make sure memory usage is bounded.

Parameters
----------

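For context, a sketch of the standard recipe this docstring describes (the
actual implementation in this file may differ in detail): the LRU key holds only
a weak reference to self, so the cache never keeps instances alive.

    import functools
    import weakref

    def weak_lru(maxsize=128):
        def wrapper(func):
            @functools.lru_cache(maxsize)
            def _cached(self_ref, *args, **kwargs):
                # Dereference the weakref; the instance is alive for the
                # duration of any call made through the wrapper below.
                return func(self_ref(), *args, **kwargs)

            @functools.wraps(func)
            def inner(self, *args, **kwargs):
                return _cached(weakref.ref(self), *args, **kwargs)

            return inner

        return wrapper
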
@@ -234,15 +235,17 @@ def retry_with_exponential_backoff(
num_retries = 0
delay = initial_delay

# Loop until a successful response or max_retries is hit or an exception is raised
# Loop until a successful response or max_retries is hit or an exception
# is raised
while True:
try:
return func(*args, **kwargs)

# Currently retrying on all exceptions as there is no way to know the format of the error msgs used by different APIs
# We'll log the error and say that it is assumed that if this portion errors out, it's due to rate limit but the user
# should check the error message to be sure
except Exception as e:
# Currently retrying on all exceptions as there is no way to know the
# format of the error msgs used by different APIs. We'll log the error
# and say that it is assumed that if this portion errors out, it's due
# to rate limit but the user should check the error message to be sure.
except Exception as e: # noqa: PERF203
num_retries += 1

if num_retries > max_retries:

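The reformatted comments describe the retry loop; a condensed sketch of that
control flow (parameter names beyond initial_delay/max_retries and the exact
backoff formula are assumptions):

    import random
    import time

    def retry_with_exponential_backoff(
        func, initial_delay=1.0, max_retries=7, exponential_base=2.0, jitter=True
    ):
        def wrapped(*args, **kwargs):
            num_retries, delay = 0, initial_delay
            while True:
                try:
                    return func(*args, **kwargs)
                except Exception:  # retry on everything, per the comment above
                    num_retries += 1
                    if num_retries > max_retries:
                        raise
                    # Grow the delay each round, with optional random jitter.
                    delay *= exponential_base * (1 + jitter * random.random())
                    time.sleep(delay)

        return wrapped
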
@@ -13,7 +13,7 @@

"""Full text search index using tantivy-py"""
import os
from typing import List, Optional, Tuple
from typing import List, Tuple

import pyarrow as pa

@@ -21,7 +21,7 @@ try:
import tantivy
except ImportError:
raise ImportError(
"Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature."
"Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature." # noqa: E501
)

from .table import LanceTable

@@ -20,15 +20,22 @@ import sys
import types
from abc import ABC, abstractmethod
from datetime import date, datetime
from typing import Any, Callable, Dict, Generator, List, Type, Union, _GenericAlias
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
List,
Type,
Union,
_GenericAlias,
)

import numpy as np
import pyarrow as pa
import pydantic
import semver
from pydantic.fields import FieldInfo

from .embeddings import EmbeddingFunctionRegistry

PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
try:

@@ -37,6 +44,11 @@ except ImportError:
if PYDANTIC_VERSION >= (2,):
raise

if TYPE_CHECKING:
from pydantic.fields import FieldInfo

from .embeddings import EmbeddingFunctionConfig


class FixedSizeListMixin(ABC):
@staticmethod

@@ -190,7 +202,7 @@ else:
]


def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType:
def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
"""Convert a Pydantic FieldInfo to Arrow DataType"""

if isinstance(field.annotation, _GenericAlias) or (

@@ -221,7 +233,7 @@ def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType:
return _py_type_to_arrow_type(field.annotation, field)


def is_nullable(field: pydantic.fields.FieldInfo) -> bool:
def is_nullable(field: FieldInfo) -> bool:
"""Check if a Pydantic FieldInfo is nullable."""
if isinstance(field.annotation, _GenericAlias):
origin = field.annotation.__origin__

@@ -237,7 +249,7 @@ def is_nullable(field: pydantic.fields.FieldInfo) -> bool:
return False


def _pydantic_to_field(name: str, field: pydantic.fields.FieldInfo) -> pa.Field:
def _pydantic_to_field(name: str, field: FieldInfo) -> pa.Field:
"""Convert a Pydantic field to a PyArrow Field."""
dt = _pydantic_to_arrow_type(field)
return pa.field(name, dt, is_nullable(field))

@@ -309,6 +321,9 @@ class LanceModel(pydantic.BaseModel):
schema = pydantic_to_schema(cls)
functions = cls.parse_embedding_functions()
if len(functions) > 0:
# Prevent circular import
from .embeddings import EmbeddingFunctionRegistry

metadata = EmbeddingFunctionRegistry.get_instance().get_table_metadata(
functions
)

@@ -359,7 +374,7 @@ class LanceModel(pydantic.BaseModel):
return configs


def get_extras(field_info: pydantic.fields.FieldInfo, key: str) -> Any:
def get_extras(field_info: FieldInfo, key: str) -> Any:
"""
Get the extra metadata from a Pydantic FieldInfo.
"""

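A short sketch of what these helpers produce end to end (Vector and
pydantic_to_schema appear elsewhere in this diff; the assertion is illustrative):

    import pyarrow as pa
    from lancedb.pydantic import LanceModel, Vector, pydantic_to_schema

    class Item(LanceModel):
        id: int
        text: str
        vector: Vector(2)  # fixed-size list column of length 2

    schema = pydantic_to_schema(Item)
    assert pa.types.is_fixed_size_list(schema.field("vector").type)
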
@@ -27,7 +27,11 @@ from .common import VECTOR_COLUMN_NAME
from .util import safe_import_pandas

if TYPE_CHECKING:
import PIL
import polars as pl

from .pydantic import LanceModel
from .table import Table

pd = safe_import_pandas()

@@ -60,7 +64,7 @@ class Query(pydantic.BaseModel):
- See discussion in [Querying an ANN Index][querying-an-ann-index] for
tuning advice.
refine_factor : Optional[int]
Refine the results by reading extra elements and re-ranking them in memory - optional
Refine the results by reading extra elements and re-ranking them in memory.

- A higher number makes search more accurate but also slower.

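A hedged usage example of the knobs documented above (assumes a table tbl with a
default vector column; nprobes/refine_factor are the builder methods this
docstring refers to):

    results = (
        tbl.search([0.1, 0.2])
        .nprobes(20)         # probe more partitions of the ANN index
        .refine_factor(10)   # re-rank 10x the requested rows in memory
        .limit(5)
        .to_pandas()
    )
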
@@ -104,7 +108,7 @@ class LanceQueryBuilder(ABC):
@classmethod
def create(
cls,
table: "lancedb.table.Table",
table: "Table",
query: Optional[Union[np.ndarray, str, "PIL.Image.Image"]],
query_type: str,
vector_column_name: str,

@@ -163,7 +167,7 @@ class LanceQueryBuilder(ABC):
f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
)

def __init__(self, table: "lancedb.table.Table"):
def __init__(self, table: "Table"):
self._table = table
self._limit = 10
self._columns = None

@@ -205,7 +209,6 @@ class LanceQueryBuilder(ABC):
if flatten is True:
while True:
tbl = tbl.flatten()
has_struct = False
# loop through all columns to check if there is any struct column
if any(pa.types.is_struct(col.type) for col in tbl.schema):
continue

@@ -214,7 +217,8 @@ class LanceQueryBuilder(ABC):
elif isinstance(flatten, int):
if flatten <= 0:
raise ValueError(
"Please specify a positive integer for flatten or the boolean value `True`"
"Please specify a positive integer for flatten or the boolean "
"value `True`"
)
while flatten > 0:
tbl = tbl.flatten()

@@ -361,7 +365,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):

def __init__(
self,
table: "lancedb.table.Table",
table: "Table",
query: Union[np.ndarray, list, "PIL.Image.Image"],
vector_column: str = VECTOR_COLUMN_NAME,
):

@@ -486,7 +490,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
class LanceFtsQueryBuilder(LanceQueryBuilder):
"""A builder for full text search for LanceDB."""

def __init__(self, table: "lancedb.table.Table", query: str):
def __init__(self, table: "Table", query: str):
super().__init__(table)
self._query = query
self._phrase_query = False

@@ -513,7 +517,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
import tantivy
except ImportError:
raise ImportError(
"Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature."
"Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature." # noqa: E501
)

from .fts import search_index

@@ -523,8 +527,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
# check if the index exist
if not Path(index_path).exists():
raise FileNotFoundError(
"Fts index does not exist."
f"Please first call table.create_fts_index(['<field_names>']) to create the fts index."
"Fts index does not exist. "
"Please first call table.create_fts_index(['<field_names>']) to "
"create the fts index."
)
# open the index
index = tantivy.Index.open(index_path)

@@ -543,19 +548,21 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):

if self._where is not None:
try:
# TODO would be great to have Substrait generate pyarrow compute expressions
# or conversely have pyarrow support SQL expressions using Substrait
# TODO would be great to have Substrait generate pyarrow compute
# expressions or conversely have pyarrow support SQL expressions
# using Substrait
import duckdb

output_tbl = (
duckdb.sql(f"SELECT * FROM output_tbl")
duckdb.sql("SELECT * FROM output_tbl")
.filter(self._where)
.to_arrow_table()
)
except ImportError:
import lance
import tempfile

import lance

# TODO Use "memory://" instead once that's supported
with tempfile.TemporaryDirectory() as tmp:
ds = lance.write_dataset(output_tbl, tmp)

@@ -13,12 +13,12 @@


import functools
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
from urllib.parse import urljoin

import requests
import attrs
import pyarrow as pa
import requests
from pydantic import BaseModel

from lancedb.common import Credential

@@ -11,7 +11,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import inspect
import logging
import uuid

@@ -102,8 +101,10 @@ class RemoteDBConnection(DBConnection):
except LanceDBClientError as err:
if str(err).startswith("Not found"):
logging.error(
f"Table {name} does not exist. "
f"Please first call db.create_table({name}, data)"
"Table %s does not exist. Please first call "
"db.create_table(%s, data).",
name,
name,
)
return RemoteTable(self, name)

@@ -160,7 +161,8 @@ class RemoteDBConnection(DBConnection):
Can create with list of tuples or dictionaries:

>>> import lancedb
>>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
>>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP
... region="...") # doctest: +SKIP
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
>>> db.create_table("my_table", data) # doctest: +SKIP

@@ -11,7 +11,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import uuid
from functools import cached_property
from typing import Dict, Optional, Union

@@ -89,7 +88,8 @@ class RemoteTable(Table):
>>> import lancedb
>>> import uuid
>>> from lancedb.schema import vector
>>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
>>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP
... region="...") # doctest: +SKIP
>>> table_name = uuid.uuid4().hex
>>> schema = pa.schema(
... [

@@ -125,7 +125,8 @@ class RemoteTable(Table):
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> int:
"""Add more data to the [Table](Table). It has the same API signature as the OSS version.
"""Add more data to the [Table](Table). It has the same API signature as
the OSS version.

Parameters
----------

@@ -176,7 +177,8 @@ class RemoteTable(Table):
Examples
--------
>>> import lancedb
>>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
>>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP
... region="...") # doctest: +SKIP
>>> data = [
... {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]},
... {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]},

@@ -265,7 +267,8 @@ class RemoteTable(Table):
... {"x": 2, "vector": [3, 4]},
... {"x": 3, "vector": [5, 6]}
... ]
>>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
>>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP
... region="...") # doctest: +SKIP
>>> table = db.create_table("my_table", data) # doctest: +SKIP
>>> table.search([10,10]).to_pandas() # doctest: +SKIP
x vector _distance # doctest: +SKIP

@@ -323,7 +326,8 @@ class RemoteTable(Table):
... {"x": 2, "vector": [3, 4]},
... {"x": 3, "vector": [5, 6]}
... ]
>>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
>>> db = lancedb.connect("db://...", api_key="...", # doctest: +SKIP
... region="...") # doctest: +SKIP
>>> table = db.create_table("my_table", data) # doctest: +SKIP
>>> table.to_pandas() # doctest: +SKIP
x vector # doctest: +SKIP

@@ -14,7 +14,6 @@
from __future__ import annotations

import inspect
import os
from abc import ABC, abstractmethod
from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union

@@ -33,17 +32,20 @@ from .pydantic import LanceModel, model_to_dict
from .query import LanceQueryBuilder, Query
from .util import (
fs_from_uri,
join_uri,
safe_import_pandas,
safe_import_polars,
value_to_sql,
join_uri,
)

if TYPE_CHECKING:
from datetime import timedelta

import PIL
from lance.dataset import CleanupStats, ReaderLike

from .db import LanceDBConnection


pd = safe_import_pandas()
pl = safe_import_polars()

@@ -525,9 +527,7 @@ class LanceTable(Table):
A table in a LanceDB database.
"""

def __init__(
self, connection: "lancedb.db.LanceDBConnection", name: str, version: int = None
):
def __init__(self, connection: "LanceDBConnection", name: str, version: int = None):
self._conn = connection
self.name = name
self._version = version

@@ -781,9 +781,7 @@ class LanceTable(Table):
index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound
if index_exists:
if not replace:
raise ValueError(
f"Index already exists. Use replace=True to overwrite."
)
raise ValueError("Index already exists. Use replace=True to overwrite.")
fs.delete_dir(path)

index = create_index(self._get_fts_index_path(), field_names)

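The cleaned-up error message reflects the behavior of create_fts_index: rebuilding
an existing index requires opting in. A hedged example:

    tbl.create_fts_index(["text"])                # first call builds the index
    tbl.create_fts_index(["text"], replace=True)  # rebuild instead of raising
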
@@ -12,9 +12,9 @@
# limitations under the License.

import os
import pathlib
from datetime import date, datetime
from functools import singledispatch
import pathlib
from typing import Tuple, Union
from urllib.parse import urlparse

@@ -130,7 +130,8 @@ def test_ingest_iterator(tmp_path):
PydanticSchema(vector=[3.1, 4.1], item="foo", price=10.0),
PydanticSchema(vector=[5.9, 26.5], item="bar", price=20.0),
],
# TODO: test pydict separately. it is unique column number and names contraint
# TODO: test pydict separately. it is unique column number and
# name constraints
]

def run_tests(schema):

@@ -18,7 +18,7 @@ import pyarrow as pa
import pytest

import lancedb
from lancedb.conftest import MockRateLimitedEmbeddingFunction, MockTextEmbeddingFunction
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.embeddings import (
EmbeddingFunctionConfig,
EmbeddingFunctionRegistry,

@@ -13,7 +13,6 @@


import json
import pytz
import sys
from datetime import date, datetime
from typing import List, Optional, Tuple

@@ -49,18 +48,19 @@ def test_pydantic_to_arrow():
dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"})
# d: dict

m = TestModel(
id=1,
s="hello",
vec=[1.0, 2.0, 3.0],
li=[2, 3, 4],
lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
st=StructModel(a="a", b=1.0),
dt=date.today(),
dtt=datetime.now(),
dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
)
# TODO: test we can actually convert the model into data.
# m = TestModel(
# id=1,
# s="hello",
# vec=[1.0, 2.0, 3.0],
# li=[2, 3, 4],
# lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
# litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
# st=StructModel(a="a", b=1.0),
# dt=date.today(),
# dtt=datetime.now(),
# dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
# )

schema = pydantic_to_schema(TestModel)

@@ -133,18 +133,19 @@ def test_pydantic_to_arrow_py38():
dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"})
# d: dict

m = TestModel(
id=1,
s="hello",
vec=[1.0, 2.0, 3.0],
li=[2, 3, 4],
lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
st=StructModel(a="a", b=1.0),
dt=date.today(),
dtt=datetime.now(),
dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
)
# TODO: test we can actually convert the model to Arrow data.
# m = TestModel(
# id=1,
# s="hello",
# vec=[1.0, 2.0, 3.0],
# li=[2, 3, 4],
# lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
# litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
# st=StructModel(a="a", b=1.0),
# dt=date.today(),
# dtt=datetime.now(),
# dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
# )

schema = pydantic_to_schema(TestModel)