diff --git a/python/lancedb/common.py b/python/lancedb/common.py index 6be2f228..54c7c9e0 100644 --- a/python/lancedb/common.py +++ b/python/lancedb/common.py @@ -16,9 +16,9 @@ from typing import Iterable, List, Union import numpy as np import pyarrow as pa -from .util import safe_import +from .util import safe_import_pandas -pd = safe_import("pandas") +pd = safe_import_pandas() DATA = Union[List[dict], dict, "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]] VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray] diff --git a/python/lancedb/context.py b/python/lancedb/context.py index 20484c47..bd7b04c8 100644 --- a/python/lancedb/context.py +++ b/python/lancedb/context.py @@ -16,9 +16,9 @@ import deprecation from . import __version__ from .exceptions import MissingColumnError, MissingValueError -from .util import safe_import +from .util import safe_import_pandas -pd = safe_import("pandas") +pd = safe_import_pandas() def contextualize(raw_df: "pd.DataFrame") -> Contextualizer: diff --git a/python/lancedb/embeddings/utils.py b/python/lancedb/embeddings/utils.py index 4708dfd7..325145f4 100644 --- a/python/lancedb/embeddings/utils.py +++ b/python/lancedb/embeddings/utils.py @@ -26,10 +26,10 @@ import pyarrow as pa from lance.vector import vec_to_table from retry import retry -from ..util import safe_import +from ..util import safe_import_pandas from ..utils.general import LOGGER -pd = safe_import("pandas") +pd = safe_import_pandas() DATA = Union[pa.Table, "pd.DataFrame"] TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray] diff --git a/python/lancedb/query.py b/python/lancedb/query.py index d564e736..ef28eac9 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -27,7 +27,7 @@ from . import __version__ from .common import VEC, VECTOR_COLUMN_NAME from .rerankers.base import Reranker from .rerankers.linear_combination import LinearCombinationReranker -from .util import safe_import +from .util import safe_import_pandas if TYPE_CHECKING: import PIL @@ -36,7 +36,7 @@ if TYPE_CHECKING: from .pydantic import LanceModel from .table import Table -pd = safe_import("pandas") +pd = safe_import_pandas() class Query(pydantic.BaseModel): diff --git a/python/lancedb/table.py b/python/lancedb/table.py index a589a298..5bd6db16 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -34,7 +34,8 @@ from .query import LanceQueryBuilder, Query from .util import ( fs_from_uri, join_uri, - safe_import, + safe_import_pandas, + safe_import_polars, value_to_sql, ) @@ -47,8 +48,8 @@ if TYPE_CHECKING: from .db import LanceDBConnection -pd = safe_import("pandas") -pl = safe_import("polars") +pd = safe_import_pandas() +pl = safe_import_polars() def _sanitize_data( diff --git a/python/lancedb/util.py b/python/lancedb/util.py index 02f37095..7eb80ea8 100644 --- a/python/lancedb/util.py +++ b/python/lancedb/util.py @@ -134,6 +134,24 @@ def safe_import(module: str, mitigation=None): raise ImportError(f"Please install {mitigation or module}") +def safe_import_pandas(): + try: + import pandas as pd + + return pd + except ImportError: + return None + + +def safe_import_polars(): + try: + import polars as pl + + return pl + except ImportError: + return None + + @singledispatch def value_to_sql(value): raise NotImplementedError("SQL conversion is not implemented for this type")