mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
fix: make pylance optional again (#2209)
The two remaining blockers were:
* A method `with_embeddings` that was deprecated a year ago
* A typecheck for `LanceDataset`
This commit is contained in:
@@ -7,10 +7,9 @@ from typing import Iterable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.dataset
|
||||
|
||||
from .util import safe_import_pandas
|
||||
|
||||
pd = safe_import_pandas()
|
||||
from .dependencies import pandas as pd
|
||||
|
||||
DATA = Union[List[dict], "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]]
|
||||
VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]
|
||||
|
||||
@@ -8,9 +8,7 @@ import deprecation
|
||||
|
||||
from . import __version__
|
||||
from .exceptions import MissingColumnError, MissingValueError
|
||||
from .util import safe_import_pandas
|
||||
|
||||
pd = safe_import_pandas()
|
||||
from .dependencies import pandas as pd
|
||||
|
||||
|
||||
def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
|
||||
|
||||
@@ -30,6 +30,7 @@ _TORCH_AVAILABLE = True
|
||||
_HUGGING_FACE_AVAILABLE = True
|
||||
_TENSORFLOW_AVAILABLE = True
|
||||
_RAY_AVAILABLE = True
|
||||
_LANCE_AVAILABLE = True
|
||||
|
||||
|
||||
class _LazyModule(ModuleType):
|
||||
@@ -53,6 +54,7 @@ class _LazyModule(ModuleType):
|
||||
"torch": "torch.",
|
||||
"tensorflow": "tf.",
|
||||
"ray": "ray.",
|
||||
"lance": "lance.",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
@@ -169,6 +171,7 @@ if TYPE_CHECKING:
|
||||
import ray
|
||||
import tensorflow
|
||||
import torch
|
||||
import lance
|
||||
else:
|
||||
# heavy/optional third party libs
|
||||
numpy, _NUMPY_AVAILABLE = _lazy_import("numpy")
|
||||
@@ -178,6 +181,7 @@ else:
|
||||
datasets, _HUGGING_FACE_AVAILABLE = _lazy_import("datasets")
|
||||
tensorflow, _TENSORFLOW_AVAILABLE = _lazy_import("tensorflow")
|
||||
ray, _RAY_AVAILABLE = _lazy_import("ray")
|
||||
lance, _LANCE_AVAILABLE = _lazy_import("lance")
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
@@ -232,6 +236,12 @@ def _check_for_ray(obj: Any, *, check_type: bool = True) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def _check_for_lance(obj: Any, *, check_type: bool = True) -> bool:
|
||||
return _LANCE_AVAILABLE and _might_be(
|
||||
cast(Hashable, type(obj) if check_type else obj), "lance"
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
# lazy-load third party libs
|
||||
"datasets",
|
||||
@@ -241,6 +251,7 @@ __all__ = [
|
||||
"ray",
|
||||
"tensorflow",
|
||||
"torch",
|
||||
"lance",
|
||||
# lazy utilities
|
||||
"_check_for_hugging_face",
|
||||
"_check_for_numpy",
|
||||
@@ -249,6 +260,7 @@ __all__ = [
|
||||
"_check_for_tensorflow",
|
||||
"_check_for_torch",
|
||||
"_check_for_ray",
|
||||
"_check_for_lance",
|
||||
"_LazyModule",
|
||||
# exported flags/guards
|
||||
"_NUMPY_AVAILABLE",
|
||||
@@ -258,4 +270,5 @@ __all__ = [
|
||||
"_HUGGING_FACE_AVAILABLE",
|
||||
"_TENSORFLOW_AVAILABLE",
|
||||
"_RAY_AVAILABLE",
|
||||
"_LANCE_AVAILABLE",
|
||||
]
|
||||
|
||||
@@ -16,7 +16,6 @@ from .sentence_transformers import SentenceTransformerEmbeddings
|
||||
from .gte import GteEmbeddings
|
||||
from .transformers import TransformersEmbeddingFunction, ColbertEmbeddings
|
||||
from .imagebind import ImageBindEmbeddings
|
||||
from .utils import with_embeddings
|
||||
from .jinaai import JinaEmbeddings
|
||||
from .watsonx import WatsonxEmbeddings
|
||||
from .voyageai import VoyageAIEmbeddingFunction
|
||||
|
||||
@@ -16,9 +16,8 @@ from functools import wraps
|
||||
from typing import Callable, List, Union
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
from lance.vector import vec_to_table
|
||||
|
||||
from ..util import deprecated, safe_import_pandas
|
||||
from ..dependencies import pandas as pd
|
||||
|
||||
|
||||
# ruff: noqa: PERF203
|
||||
@@ -41,8 +40,6 @@ def retry(tries=10, delay=1, max_delay=30, backoff=3, jitter=1):
|
||||
return wrapper
|
||||
|
||||
|
||||
pd = safe_import_pandas()
|
||||
|
||||
DATA = Union[pa.Table, "pd.DataFrame"]
|
||||
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
|
||||
IMAGES = Union[
|
||||
@@ -87,52 +84,6 @@ class RateLimiter:
|
||||
return wrapper
|
||||
|
||||
|
||||
@deprecated
|
||||
def with_embeddings(
|
||||
func: Callable,
|
||||
data: DATA,
|
||||
column: str = "text",
|
||||
wrap_api: bool = True,
|
||||
show_progress: bool = False,
|
||||
batch_size: int = 1000,
|
||||
) -> pa.Table:
|
||||
"""Add a vector column to a table using the given embedding function.
|
||||
|
||||
The new columns will be called "vector".
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : Callable
|
||||
A function that takes a list of strings and returns a list of vectors.
|
||||
data : pa.Table or pd.DataFrame
|
||||
The data to add an embedding column to.
|
||||
column : str, default "text"
|
||||
The name of the column to use as input to the embedding function.
|
||||
wrap_api : bool, default True
|
||||
Whether to wrap the embedding function in a retry and rate limiter.
|
||||
show_progress : bool, default False
|
||||
Whether to show a progress bar.
|
||||
batch_size : int, default 1000
|
||||
The number of row values to pass to each call of the embedding function.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pa.Table
|
||||
The input table with a new column called "vector" containing the embeddings.
|
||||
"""
|
||||
func = FunctionWrapper(func)
|
||||
if wrap_api:
|
||||
func = func.retry().rate_limit()
|
||||
func = func.batch_size(batch_size)
|
||||
if show_progress:
|
||||
func = func.show_progress()
|
||||
if pd is not None and isinstance(data, pd.DataFrame):
|
||||
data = pa.Table.from_pandas(data, preserve_index=False)
|
||||
embeddings = func(data[column].to_numpy())
|
||||
table = vec_to_table(np.array(embeddings))
|
||||
return data.append_column("vector", table["vector"])
|
||||
|
||||
|
||||
class FunctionWrapper:
|
||||
"""
|
||||
A wrapper for embedding functions that adds rate limiting, retries, and batching.
|
||||
|
||||
@@ -5,6 +5,7 @@ import logging
|
||||
from typing import Any, List, Optional, Tuple, Union, Literal
|
||||
|
||||
import pyarrow as pa
|
||||
import pyarrow.dataset
|
||||
|
||||
from ..table import Table
|
||||
|
||||
|
||||
@@ -26,10 +26,11 @@ import pydantic
|
||||
|
||||
from . import __version__
|
||||
from .arrow import AsyncRecordBatchReader
|
||||
from .dependencies import pandas as pd
|
||||
from .rerankers.base import Reranker
|
||||
from .rerankers.rrf import RRFReranker
|
||||
from .rerankers.util import check_reranker_result
|
||||
from .util import safe_import_pandas, flatten_columns
|
||||
from .util import flatten_columns
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import sys
|
||||
@@ -49,8 +50,6 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
from typing_extensions import Self
|
||||
|
||||
pd = safe_import_pandas()
|
||||
|
||||
|
||||
class Query(pydantic.BaseModel):
|
||||
"""The LanceDB Query
|
||||
|
||||
@@ -28,12 +28,19 @@ from urllib.parse import urlparse
|
||||
from . import __version__
|
||||
from lancedb.arrow import peek_reader
|
||||
from lancedb.background_loop import LOOP
|
||||
from .dependencies import _check_for_hugging_face, _check_for_pandas
|
||||
from .dependencies import (
|
||||
_check_for_hugging_face,
|
||||
_check_for_lance,
|
||||
_check_for_pandas,
|
||||
lance,
|
||||
pandas as pd,
|
||||
polars as pl,
|
||||
)
|
||||
import pyarrow as pa
|
||||
import pyarrow.dataset
|
||||
import pyarrow.compute as pc
|
||||
import pyarrow.fs as pa_fs
|
||||
import numpy as np
|
||||
from lance import LanceDataset
|
||||
|
||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
@@ -58,8 +65,6 @@ from .util import (
|
||||
get_uri_scheme,
|
||||
infer_vector_column_name,
|
||||
join_uri,
|
||||
safe_import_pandas,
|
||||
safe_import_polars,
|
||||
value_to_sql,
|
||||
)
|
||||
from .index import lang_mapping
|
||||
@@ -88,10 +93,6 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
|
||||
pd = safe_import_pandas()
|
||||
pl = safe_import_polars()
|
||||
|
||||
|
||||
def _into_pyarrow_reader(data) -> pa.RecordBatchReader:
|
||||
from lancedb.dependencies import datasets
|
||||
|
||||
@@ -130,7 +131,7 @@ def _into_pyarrow_reader(data) -> pa.RecordBatchReader:
|
||||
return data.to_reader()
|
||||
elif isinstance(data, pa.RecordBatch):
|
||||
return pa.RecordBatchReader.from_batches(data.schema, [data])
|
||||
elif isinstance(data, LanceDataset):
|
||||
elif _check_for_lance(data) and isinstance(data, lance.LanceDataset):
|
||||
return data.scanner().to_reader()
|
||||
elif isinstance(data, pa.dataset.Dataset):
|
||||
return data.scanner().to_reader()
|
||||
@@ -1440,7 +1441,7 @@ class LanceTable(Table):
|
||||
# Cacheable since it's deterministic
|
||||
return _table_path(self._conn.uri, self.name)
|
||||
|
||||
def to_lance(self, **kwargs) -> LanceDataset:
|
||||
def to_lance(self, **kwargs) -> lance.LanceDataset:
|
||||
"""Return the LanceDataset backing this table."""
|
||||
try:
|
||||
import lance
|
||||
|
||||
@@ -157,24 +157,6 @@ def attempt_import_or_raise(module: str, mitigation=None):
|
||||
raise ImportError(f"Please install {mitigation or module}")
|
||||
|
||||
|
||||
def safe_import_pandas():
|
||||
try:
|
||||
import pandas as pd
|
||||
|
||||
return pd
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def safe_import_polars():
|
||||
try:
|
||||
import polars as pl
|
||||
|
||||
return pl
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def flatten_columns(tbl: pa.Table, flatten: Optional[Union[int, bool]] = None):
|
||||
"""
|
||||
Flatten all struct columns in a table.
|
||||
|
||||
@@ -15,7 +15,6 @@ from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.embeddings import (
|
||||
EmbeddingFunctionConfig,
|
||||
EmbeddingFunctionRegistry,
|
||||
with_embeddings,
|
||||
)
|
||||
from lancedb.embeddings.base import TextEmbeddingFunction
|
||||
from lancedb.embeddings.registry import get_registry, register
|
||||
@@ -27,23 +26,6 @@ def mock_embed_func(input_data):
|
||||
return [np.random.randn(128).tolist() for _ in range(len(input_data))]
|
||||
|
||||
|
||||
def test_with_embeddings():
|
||||
for wrap_api in [True, False]:
|
||||
data = pa.Table.from_arrays(
|
||||
[
|
||||
pa.array(["foo", "bar"]),
|
||||
pa.array([10.0, 20.0]),
|
||||
],
|
||||
names=["text", "price"],
|
||||
)
|
||||
data = with_embeddings(mock_embed_func, data, wrap_api=wrap_api)
|
||||
assert data.num_columns == 3
|
||||
assert data.num_rows == 2
|
||||
assert data.column_names == ["text", "price", "vector"]
|
||||
assert data.column("text").to_pylist() == ["foo", "bar"]
|
||||
assert data.column("price").to_pylist() == [10.0, 20.0]
|
||||
|
||||
|
||||
def test_embedding_function(tmp_path):
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
|
||||
|
||||
@@ -8,13 +8,13 @@ from time import sleep
|
||||
from typing import List
|
||||
from unittest.mock import patch
|
||||
|
||||
import lance
|
||||
import lancedb
|
||||
from lancedb.index import HnswPq, HnswSq, IvfPq
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import polars as pl
|
||||
import pyarrow as pa
|
||||
import pyarrow.dataset
|
||||
import pytest
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.db import AsyncConnection, DBConnection
|
||||
@@ -650,6 +650,9 @@ def test_restore(mem_db: DBConnection):
|
||||
|
||||
|
||||
def test_merge(tmp_db: DBConnection, tmp_path):
|
||||
pytest.importorskip("lance")
|
||||
import lance
|
||||
|
||||
table = tmp_db.create_table(
|
||||
"my_table",
|
||||
schema=pa.schema(
|
||||
@@ -1145,6 +1148,7 @@ def test_search_with_schema_inf_multiple_vector(mem_db: DBConnection):
|
||||
|
||||
|
||||
def test_compact_cleanup(tmp_db: DBConnection):
|
||||
pytest.importorskip("lance")
|
||||
table = tmp_db.create_table(
|
||||
"my_table",
|
||||
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
|
||||
@@ -1222,6 +1226,7 @@ def setup_hybrid_search_table(db: DBConnection, embedding_func):
|
||||
def test_hybrid_search(tmp_db: DBConnection):
|
||||
# This test uses an FTS index
|
||||
pytest.importorskip("lancedb.fts")
|
||||
pytest.importorskip("lance")
|
||||
|
||||
table, MyTable, emb = setup_hybrid_search_table(tmp_db, "test")
|
||||
|
||||
@@ -1292,6 +1297,7 @@ def test_hybrid_search(tmp_db: DBConnection):
|
||||
def test_hybrid_search_metric_type(tmp_db: DBConnection):
|
||||
# This test uses an FTS index
|
||||
pytest.importorskip("lancedb.fts")
|
||||
pytest.importorskip("lance")
|
||||
|
||||
# Need to use nonnorm as the embedding function so l2 and dot results
|
||||
# are different
|
||||
|
||||
Reference in New Issue
Block a user