From c1f8feb6edd115ee818c9b84639b808163d28fbb Mon Sep 17 00:00:00 2001 From: Chang She <759245+changhiskhan@users.noreply.github.com> Date: Mon, 31 Jul 2023 14:08:58 -0400 Subject: [PATCH] make pandas an optional dependency in lancedb as well (#385) --- python/lancedb/common.py | 11 +++++------ python/lancedb/context.py | 13 +++++++++---- python/lancedb/db.py | 7 ++----- python/lancedb/embeddings.py | 10 +++++++--- python/lancedb/query.py | 8 +++++--- python/lancedb/remote/db.py | 1 - python/lancedb/remote/table.py | 4 ++-- python/lancedb/schema.py | 4 ---- python/lancedb/table.py | 11 ++++++----- python/lancedb/util.py | 10 +++++++++- 10 files changed, 45 insertions(+), 34 deletions(-) diff --git a/python/lancedb/common.py b/python/lancedb/common.py index f50451a5..54c7c9e0 100644 --- a/python/lancedb/common.py +++ b/python/lancedb/common.py @@ -11,19 +11,18 @@ # See the License for the specific language governing permissions and # limitations under the License. from pathlib import Path -from typing import List, Union +from typing import Iterable, List, Union import numpy as np -import pandas as pd import pyarrow as pa -from .pydantic import LanceModel +from .util import safe_import_pandas +pd = safe_import_pandas() + +DATA = Union[List[dict], dict, "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]] VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray] URI = Union[str, Path] - -# TODO support generator -DATA = Union[List[dict], List[LanceModel], dict, pd.DataFrame] VECTOR_COLUMN_NAME = "vector" diff --git a/python/lancedb/context.py b/python/lancedb/context.py index 219f90f7..b29946c6 100644 --- a/python/lancedb/context.py +++ b/python/lancedb/context.py @@ -12,12 +12,13 @@ # limitations under the License. from __future__ import annotations -import pandas as pd - from .exceptions import MissingColumnError, MissingValueError +from .util import safe_import_pandas + +pd = safe_import_pandas() -def contextualize(raw_df: pd.DataFrame) -> Contextualizer: +def contextualize(raw_df: "pd.DataFrame") -> Contextualizer: """Create a Contextualizer object for the given DataFrame. Used to create context windows. Context windows are rolling subsets of text @@ -175,8 +176,12 @@ class Contextualizer: self._min_window_size = min_window_size return self - def to_df(self) -> pd.DataFrame: + def to_df(self) -> "pd.DataFrame": """Create the context windows and return a DataFrame.""" + if pd is None: + raise ImportError( + "pandas is required to create context windows using lancedb" + ) if self._text_col not in self._raw_df.columns.tolist(): raise MissingColumnError(self._text_col) diff --git a/python/lancedb/db.py b/python/lancedb/db.py index 548a1c2d..e87ea533 100644 --- a/python/lancedb/db.py +++ b/python/lancedb/db.py @@ -16,9 +16,8 @@ from __future__ import annotations import os from abc import ABC, abstractmethod from pathlib import Path -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Optional -import pandas as pd import pyarrow as pa from pyarrow import fs @@ -39,9 +38,7 @@ class DBConnection(ABC): def create_table( self, name: str, - data: Optional[ - Union[List[dict], dict, pd.DataFrame, pa.Table, Iterable[pa.RecordBatch]], - ] = None, + data: Optional[DATA] = None, schema: Optional[pa.Schema] = None, mode: str = "create", on_bad_vectors: str = "error", diff --git a/python/lancedb/embeddings.py b/python/lancedb/embeddings.py index 03568101..f8c419a0 100644 --- a/python/lancedb/embeddings.py +++ b/python/lancedb/embeddings.py @@ -16,15 +16,19 @@ import sys from typing import Callable, Union import numpy as np -import pandas as pd import pyarrow as pa from lance.vector import vec_to_table from retry import retry +from .util import safe_import_pandas + +pd = safe_import_pandas() +DATA = Union[pa.Table, "pd.DataFrame"] + def with_embeddings( func: Callable, - data: Union[pa.Table, pd.DataFrame], + data: DATA, column: str = "text", wrap_api: bool = True, show_progress: bool = False, @@ -60,7 +64,7 @@ def with_embeddings( func = func.batch_size(batch_size) if show_progress: func = func.show_progress() - if isinstance(data, pd.DataFrame): + if pd is not None and isinstance(data, pd.DataFrame): data = pa.Table.from_pandas(data, preserve_index=False) embeddings = func(data[column].to_numpy()) table = vec_to_table(np.array(embeddings)) diff --git a/python/lancedb/query.py b/python/lancedb/query.py index d2fd5afd..3092a760 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -16,12 +16,14 @@ from __future__ import annotations from typing import List, Literal, Optional, Type, Union import numpy as np -import pandas as pd import pyarrow as pa import pydantic from .common import VECTOR_COLUMN_NAME from .pydantic import LanceModel +from .util import safe_import_pandas + +pd = safe_import_pandas() class Query(pydantic.BaseModel): @@ -199,7 +201,7 @@ class LanceQueryBuilder: self._refine_factor = refine_factor return self - def to_df(self) -> pd.DataFrame: + def to_df(self) -> "pd.DataFrame": """ Execute the query and return the results as a pandas DataFrame. In addition to the selected columns, LanceDB also returns a vector @@ -250,7 +252,7 @@ class LanceQueryBuilder: class LanceFtsQueryBuilder(LanceQueryBuilder): - def to_arrow(self) -> pd.Table: + def to_arrow(self) -> pa.Table: try: import tantivy except ImportError: diff --git a/python/lancedb/remote/db.py b/python/lancedb/remote/db.py index dcbb3332..4ea889fa 100644 --- a/python/lancedb/remote/db.py +++ b/python/lancedb/remote/db.py @@ -20,7 +20,6 @@ import pyarrow as pa from lancedb.common import DATA from lancedb.db import DBConnection -from lancedb.schema import schema_to_json from lancedb.table import Table, _sanitize_data from .arrow import to_ipc_binary diff --git a/python/lancedb/remote/table.py b/python/lancedb/remote/table.py index f0618527..66bb224b 100644 --- a/python/lancedb/remote/table.py +++ b/python/lancedb/remote/table.py @@ -16,11 +16,11 @@ from functools import cached_property from typing import Union import pyarrow as pa +from lance import json_to_schema from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME -from ..query import LanceQueryBuilder, Query -from ..schema import json_to_schema +from ..query import LanceQueryBuilder from ..table import Query, Table, _sanitize_data from .arrow import to_ipc_binary from .client import ARROW_STREAM_CONTENT_TYPE diff --git a/python/lancedb/schema.py b/python/lancedb/schema.py index 8d8a77a4..9b5dd5e7 100644 --- a/python/lancedb/schema.py +++ b/python/lancedb/schema.py @@ -12,11 +12,7 @@ # limitations under the License. """Schema related utilities.""" - -from typing import Any, Dict, Type - import pyarrow as pa -from lance import json_to_schema, schema_to_json def vector(dimension: int, value_type: pa.DataType = pa.float32()) -> pa.DataType: diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 5be962b9..3a08af53 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -20,7 +20,6 @@ from typing import Iterable, List, Union import lance import numpy as np -import pandas as pd import pyarrow as pa import pyarrow.compute as pc from lance import LanceDataset @@ -29,7 +28,9 @@ from lance.vector import vec_to_table from .common import DATA, VEC, VECTOR_COLUMN_NAME from .pydantic import LanceModel from .query import LanceFtsQueryBuilder, LanceQueryBuilder, Query -from .util import fs_from_uri +from .util import fs_from_uri, safe_import_pandas + +pd = safe_import_pandas() def _sanitize_data(data, schema, on_bad_vectors, fill_value): @@ -44,7 +45,7 @@ def _sanitize_data(data, schema, on_bad_vectors, fill_value): ) if isinstance(data, dict): data = vec_to_table(data) - if isinstance(data, pd.DataFrame): + if pd is not None and isinstance(data, pd.DataFrame): data = pa.Table.from_pandas(data) data = _sanitize_schema( data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value @@ -99,7 +100,7 @@ class Table(ABC): """ raise NotImplementedError - def to_pandas(self) -> pd.DataFrame: + def to_pandas(self): """Return the table as a pandas DataFrame. Returns @@ -333,7 +334,7 @@ class LanceTable(Table): """Return the first n rows of the table.""" return self._dataset.head(n) - def to_pandas(self) -> pd.DataFrame: + def to_pandas(self) -> "pd.DataFrame": """Return the table as a pandas DataFrame. Returns diff --git a/python/lancedb/util.py b/python/lancedb/util.py index 1d844aef..1abd8dc7 100644 --- a/python/lancedb/util.py +++ b/python/lancedb/util.py @@ -15,7 +15,6 @@ import os from typing import Tuple from urllib.parse import urlparse -import pyarrow as pa import pyarrow.fs as pa_fs @@ -76,3 +75,12 @@ def fs_from_uri(uri: str) -> Tuple[pa_fs.FileSystem, str]: return fs, path return pa_fs.FileSystem.from_uri(uri) + + +def safe_import_pandas(): + try: + import pandas as pd + + return pd + except ImportError: + return None