fix: make pylance optional again (#2209)

The two remaining blockers were:

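* A function, `with_embeddings`, that was deprecated a year ago and imported `vec_to_table` from `lance.vector` at module import time
* An `isinstance` check against `LanceDataset`, which required a top-level `from lance import LanceDataset`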
This commit is contained in:
Will Jones
2025-03-21 11:26:32 -07:00
committed by GitHub
parent bdb6c09c3b
commit b2a38ac366
14 changed files with 49 additions and 111 deletions

View File

@@ -7,10 +7,9 @@ from typing import Iterable, List, Optional, Union
import numpy as np
import pyarrow as pa
import pyarrow.dataset
from .util import safe_import_pandas
pd = safe_import_pandas()
from .dependencies import pandas as pd
DATA = Union[List[dict], "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]]
VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]

View File

@@ -8,9 +8,7 @@ import deprecation
from . import __version__
from .exceptions import MissingColumnError, MissingValueError
from .util import safe_import_pandas
pd = safe_import_pandas()
from .dependencies import pandas as pd
def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:

View File

@@ -30,6 +30,7 @@ _TORCH_AVAILABLE = True
_HUGGING_FACE_AVAILABLE = True
_TENSORFLOW_AVAILABLE = True
_RAY_AVAILABLE = True
_LANCE_AVAILABLE = True
class _LazyModule(ModuleType):
@@ -53,6 +54,7 @@ class _LazyModule(ModuleType):
"torch": "torch.",
"tensorflow": "tf.",
"ray": "ray.",
"lance": "lance.",
}
def __init__(
@@ -169,6 +171,7 @@ if TYPE_CHECKING:
import ray
import tensorflow
import torch
import lance
else:
# heavy/optional third party libs
numpy, _NUMPY_AVAILABLE = _lazy_import("numpy")
@@ -178,6 +181,7 @@ else:
datasets, _HUGGING_FACE_AVAILABLE = _lazy_import("datasets")
tensorflow, _TENSORFLOW_AVAILABLE = _lazy_import("tensorflow")
ray, _RAY_AVAILABLE = _lazy_import("ray")
lance, _LANCE_AVAILABLE = _lazy_import("lance")
@lru_cache(maxsize=None)
@@ -232,6 +236,12 @@ def _check_for_ray(obj: Any, *, check_type: bool = True) -> bool:
)
def _check_for_lance(obj: Any, *, check_type: bool = True) -> bool:
    """Cheaply test whether ``obj`` could come from the ``lance`` package.

    Parameters
    ----------
    obj : Any
        The object to inspect.
    check_type : bool, default True
        When true, ``type(obj)`` is handed to ``_might_be``; otherwise
        ``obj`` itself is.

    Returns
    -------
    bool
        Falsy whenever ``_LANCE_AVAILABLE`` is falsy; otherwise whatever
        ``_might_be(..., "lance")`` reports for the chosen target.
    """
    target = type(obj) if check_type else obj
    # Short-circuit on the availability flag so `_might_be` is only consulted
    # when the flag is truthy.
    return _LANCE_AVAILABLE and _might_be(cast(Hashable, target), "lance")
__all__ = [
# lazy-load third party libs
"datasets",
@@ -241,6 +251,7 @@ __all__ = [
"ray",
"tensorflow",
"torch",
"lance",
# lazy utilities
"_check_for_hugging_face",
"_check_for_numpy",
@@ -249,6 +260,7 @@ __all__ = [
"_check_for_tensorflow",
"_check_for_torch",
"_check_for_ray",
"_check_for_lance",
"_LazyModule",
# exported flags/guards
"_NUMPY_AVAILABLE",
@@ -258,4 +270,5 @@ __all__ = [
"_HUGGING_FACE_AVAILABLE",
"_TENSORFLOW_AVAILABLE",
"_RAY_AVAILABLE",
"_LANCE_AVAILABLE",
]

View File

@@ -16,7 +16,6 @@ from .sentence_transformers import SentenceTransformerEmbeddings
from .gte import GteEmbeddings
from .transformers import TransformersEmbeddingFunction, ColbertEmbeddings
from .imagebind import ImageBindEmbeddings
from .utils import with_embeddings
from .jinaai import JinaEmbeddings
from .watsonx import WatsonxEmbeddings
from .voyageai import VoyageAIEmbeddingFunction

View File

@@ -16,9 +16,8 @@ from functools import wraps
from typing import Callable, List, Union
import numpy as np
import pyarrow as pa
from lance.vector import vec_to_table
from ..util import deprecated, safe_import_pandas
from ..dependencies import pandas as pd
# ruff: noqa: PERF203
@@ -41,8 +40,6 @@ def retry(tries=10, delay=1, max_delay=30, backoff=3, jitter=1):
return wrapper
pd = safe_import_pandas()
DATA = Union[pa.Table, "pd.DataFrame"]
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
IMAGES = Union[
@@ -87,52 +84,6 @@ class RateLimiter:
return wrapper
@deprecated
def with_embeddings(
    func: Callable,
    data: DATA,
    column: str = "text",
    wrap_api: bool = True,
    show_progress: bool = False,
    batch_size: int = 1000,
) -> pa.Table:
    """Append a "vector" column computed by an embedding function.

    Parameters
    ----------
    func : Callable
        A function that takes a list of strings and returns a list of vectors.
    data : pa.Table or pd.DataFrame
        The data to add an embedding column to. A pandas DataFrame is first
        converted to a ``pa.Table`` (without its index).
    column : str, default "text"
        The name of the column whose values are fed to ``func``.
    wrap_api : bool, default True
        Whether to wrap the embedding function in a retry and rate limiter.
    show_progress : bool, default False
        Whether to show a progress bar.
    batch_size : int, default 1000
        The number of row values to pass to each call of the embedding function.

    Returns
    -------
    pa.Table
        The input table with a new column called "vector" containing the
        embeddings.
    """
    # Build the call pipeline: optional retry + rate limit, then batching,
    # then the optional progress display.
    embedder = FunctionWrapper(func)
    if wrap_api:
        embedder = embedder.retry().rate_limit()
    embedder = embedder.batch_size(batch_size)
    if show_progress:
        embedder = embedder.show_progress()

    # Normalise pandas input to Arrow; `pd` may be None when pandas is absent.
    is_dataframe = pd is not None and isinstance(data, pd.DataFrame)
    table = pa.Table.from_pandas(data, preserve_index=False) if is_dataframe else data

    inputs = table[column].to_numpy()
    vectors = vec_to_table(np.array(embedder(inputs)))
    return table.append_column("vector", vectors["vector"])
class FunctionWrapper:
"""
A wrapper for embedding functions that adds rate limiting, retries, and batching.

View File

@@ -5,6 +5,7 @@ import logging
from typing import Any, List, Optional, Tuple, Union, Literal
import pyarrow as pa
import pyarrow.dataset
from ..table import Table

View File

@@ -26,10 +26,11 @@ import pydantic
from . import __version__
from .arrow import AsyncRecordBatchReader
from .dependencies import pandas as pd
from .rerankers.base import Reranker
from .rerankers.rrf import RRFReranker
from .rerankers.util import check_reranker_result
from .util import safe_import_pandas, flatten_columns
from .util import flatten_columns
if TYPE_CHECKING:
import sys
@@ -49,8 +50,6 @@ if TYPE_CHECKING:
else:
from typing_extensions import Self
pd = safe_import_pandas()
class Query(pydantic.BaseModel):
"""The LanceDB Query

View File

@@ -28,12 +28,19 @@ from urllib.parse import urlparse
from . import __version__
from lancedb.arrow import peek_reader
from lancedb.background_loop import LOOP
from .dependencies import _check_for_hugging_face, _check_for_pandas
from .dependencies import (
_check_for_hugging_face,
_check_for_lance,
_check_for_pandas,
lance,
pandas as pd,
polars as pl,
)
import pyarrow as pa
import pyarrow.dataset
import pyarrow.compute as pc
import pyarrow.fs as pa_fs
import numpy as np
from lance import LanceDataset
from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
@@ -58,8 +65,6 @@ from .util import (
get_uri_scheme,
infer_vector_column_name,
join_uri,
safe_import_pandas,
safe_import_polars,
value_to_sql,
)
from .index import lang_mapping
@@ -88,10 +93,6 @@ if TYPE_CHECKING:
)
pd = safe_import_pandas()
pl = safe_import_polars()
def _into_pyarrow_reader(data) -> pa.RecordBatchReader:
from lancedb.dependencies import datasets
@@ -130,7 +131,7 @@ def _into_pyarrow_reader(data) -> pa.RecordBatchReader:
return data.to_reader()
elif isinstance(data, pa.RecordBatch):
return pa.RecordBatchReader.from_batches(data.schema, [data])
elif isinstance(data, LanceDataset):
elif _check_for_lance(data) and isinstance(data, lance.LanceDataset):
return data.scanner().to_reader()
elif isinstance(data, pa.dataset.Dataset):
return data.scanner().to_reader()
@@ -1440,7 +1441,7 @@ class LanceTable(Table):
# Cacheable since it's deterministic
return _table_path(self._conn.uri, self.name)
def to_lance(self, **kwargs) -> LanceDataset:
def to_lance(self, **kwargs) -> lance.LanceDataset:
"""Return the LanceDataset backing this table."""
try:
import lance

View File

@@ -157,24 +157,6 @@ def attempt_import_or_raise(module: str, mitigation=None):
raise ImportError(f"Please install {mitigation or module}")
def safe_import_pandas():
    """Return the ``pandas`` module, or ``None`` if importing it fails.

    Only ``ImportError`` is swallowed; any other exception raised while
    importing pandas propagates to the caller.
    """
    try:
        import pandas as pandas_module
    except ImportError:
        return None
    else:
        return pandas_module
def safe_import_polars():
    """Return the ``polars`` module, or ``None`` if importing it fails.

    Only ``ImportError`` is swallowed; any other exception raised while
    importing polars propagates to the caller.
    """
    try:
        import polars as polars_module
    except ImportError:
        return None
    else:
        return polars_module
def flatten_columns(tbl: pa.Table, flatten: Optional[Union[int, bool]] = None):
"""
Flatten all struct columns in a table.

View File

@@ -15,7 +15,6 @@ from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.embeddings import (
EmbeddingFunctionConfig,
EmbeddingFunctionRegistry,
with_embeddings,
)
from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.registry import get_registry, register
@@ -27,23 +26,6 @@ def mock_embed_func(input_data):
return [np.random.randn(128).tolist() for _ in range(len(input_data))]
def test_with_embeddings():
    """Check ``with_embeddings`` with and without the retry/rate-limit wrapper.

    For both settings of ``wrap_api`` the result must keep the two input
    columns unchanged and gain a third column named "vector".
    """
    for wrap_api in (True, False):
        source = pa.Table.from_arrays(
            [pa.array(["foo", "bar"]), pa.array([10.0, 20.0])],
            names=["text", "price"],
        )

        embedded = with_embeddings(mock_embed_func, source, wrap_api=wrap_api)

        # Shape: the two original columns plus the new one, same row count.
        assert embedded.num_columns == 3
        assert embedded.num_rows == 2
        assert embedded.column_names == ["text", "price", "vector"]
        # The original column contents are carried through untouched.
        assert embedded.column("text").to_pylist() == ["foo", "bar"]
        assert embedded.column("price").to_pylist() == [10.0, 20.0]
def test_embedding_function(tmp_path):
registry = EmbeddingFunctionRegistry.get_instance()

View File

@@ -8,13 +8,13 @@ from time import sleep
from typing import List
from unittest.mock import patch
import lance
import lancedb
from lancedb.index import HnswPq, HnswSq, IvfPq
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pyarrow.dataset
import pytest
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.db import AsyncConnection, DBConnection
@@ -650,6 +650,9 @@ def test_restore(mem_db: DBConnection):
def test_merge(tmp_db: DBConnection, tmp_path):
pytest.importorskip("lance")
import lance
table = tmp_db.create_table(
"my_table",
schema=pa.schema(
@@ -1145,6 +1148,7 @@ def test_search_with_schema_inf_multiple_vector(mem_db: DBConnection):
def test_compact_cleanup(tmp_db: DBConnection):
pytest.importorskip("lance")
table = tmp_db.create_table(
"my_table",
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
@@ -1222,6 +1226,7 @@ def setup_hybrid_search_table(db: DBConnection, embedding_func):
def test_hybrid_search(tmp_db: DBConnection):
# This test uses an FTS index
pytest.importorskip("lancedb.fts")
pytest.importorskip("lance")
table, MyTable, emb = setup_hybrid_search_table(tmp_db, "test")
@@ -1292,6 +1297,7 @@ def test_hybrid_search(tmp_db: DBConnection):
def test_hybrid_search_metric_type(tmp_db: DBConnection):
# This test uses an FTS index
pytest.importorskip("lancedb.fts")
pytest.importorskip("lance")
# Need to use nonnorm as the embedding function so l2 and dot results
# are different