make pandas an optional dependency in lancedb as well (#385)

This commit is contained in:
Chang She
2023-07-31 14:08:58 -04:00
committed by GitHub
parent cada35d5b7
commit c1f8feb6ed
10 changed files with 45 additions and 34 deletions

View File

@@ -11,19 +11,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import List, Union
from typing import Iterable, List, Union
import numpy as np
import pandas as pd
import pyarrow as pa
from .pydantic import LanceModel
from .util import safe_import_pandas
pd = safe_import_pandas()
DATA = Union[List[dict], dict, "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]]
VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]
URI = Union[str, Path]
# TODO support generator
DATA = Union[List[dict], List[LanceModel], dict, pd.DataFrame]
VECTOR_COLUMN_NAME = "vector"

View File

@@ -12,12 +12,13 @@
# limitations under the License.
from __future__ import annotations
import pandas as pd
from .exceptions import MissingColumnError, MissingValueError
from .util import safe_import_pandas
pd = safe_import_pandas()
def contextualize(raw_df: pd.DataFrame) -> Contextualizer:
def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
"""Create a Contextualizer object for the given DataFrame.
Used to create context windows. Context windows are rolling subsets of text
@@ -175,8 +176,12 @@ class Contextualizer:
self._min_window_size = min_window_size
return self
def to_df(self) -> pd.DataFrame:
def to_df(self) -> "pd.DataFrame":
"""Create the context windows and return a DataFrame."""
if pd is None:
raise ImportError(
"pandas is required to create context windows using lancedb"
)
if self._text_col not in self._raw_df.columns.tolist():
raise MissingColumnError(self._text_col)

View File

@@ -16,9 +16,8 @@ from __future__ import annotations
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union
from typing import Optional
import pandas as pd
import pyarrow as pa
from pyarrow import fs
@@ -39,9 +38,7 @@ class DBConnection(ABC):
def create_table(
self,
name: str,
data: Optional[
Union[List[dict], dict, pd.DataFrame, pa.Table, Iterable[pa.RecordBatch]],
] = None,
data: Optional[DATA] = None,
schema: Optional[pa.Schema] = None,
mode: str = "create",
on_bad_vectors: str = "error",

View File

@@ -16,15 +16,19 @@ import sys
from typing import Callable, Union
import numpy as np
import pandas as pd
import pyarrow as pa
from lance.vector import vec_to_table
from retry import retry
from .util import safe_import_pandas
pd = safe_import_pandas()
DATA = Union[pa.Table, "pd.DataFrame"]
def with_embeddings(
func: Callable,
data: Union[pa.Table, pd.DataFrame],
data: DATA,
column: str = "text",
wrap_api: bool = True,
show_progress: bool = False,
@@ -60,7 +64,7 @@ def with_embeddings(
func = func.batch_size(batch_size)
if show_progress:
func = func.show_progress()
if isinstance(data, pd.DataFrame):
if pd is not None and isinstance(data, pd.DataFrame):
data = pa.Table.from_pandas(data, preserve_index=False)
embeddings = func(data[column].to_numpy())
table = vec_to_table(np.array(embeddings))

View File

@@ -16,12 +16,14 @@ from __future__ import annotations
from typing import List, Literal, Optional, Type, Union
import numpy as np
import pandas as pd
import pyarrow as pa
import pydantic
from .common import VECTOR_COLUMN_NAME
from .pydantic import LanceModel
from .util import safe_import_pandas
pd = safe_import_pandas()
class Query(pydantic.BaseModel):
@@ -199,7 +201,7 @@ class LanceQueryBuilder:
self._refine_factor = refine_factor
return self
def to_df(self) -> pd.DataFrame:
def to_df(self) -> "pd.DataFrame":
"""
Execute the query and return the results as a pandas DataFrame.
In addition to the selected columns, LanceDB also returns a vector
@@ -250,7 +252,7 @@ class LanceQueryBuilder:
class LanceFtsQueryBuilder(LanceQueryBuilder):
def to_arrow(self) -> pd.Table:
def to_arrow(self) -> pa.Table:
try:
import tantivy
except ImportError:

View File

@@ -20,7 +20,6 @@ import pyarrow as pa
from lancedb.common import DATA
from lancedb.db import DBConnection
from lancedb.schema import schema_to_json
from lancedb.table import Table, _sanitize_data
from .arrow import to_ipc_binary

View File

@@ -16,11 +16,11 @@ from functools import cached_property
from typing import Union
import pyarrow as pa
from lance import json_to_schema
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from ..query import LanceQueryBuilder, Query
from ..schema import json_to_schema
from ..query import LanceQueryBuilder
from ..table import Query, Table, _sanitize_data
from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE

View File

@@ -12,11 +12,7 @@
# limitations under the License.
"""Schema related utilities."""
from typing import Any, Dict, Type
import pyarrow as pa
from lance import json_to_schema, schema_to_json
def vector(dimension: int, value_type: pa.DataType = pa.float32()) -> pa.DataType:

View File

@@ -20,7 +20,6 @@ from typing import Iterable, List, Union
import lance
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
from lance import LanceDataset
@@ -29,7 +28,9 @@ from lance.vector import vec_to_table
from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .pydantic import LanceModel
from .query import LanceFtsQueryBuilder, LanceQueryBuilder, Query
from .util import fs_from_uri
from .util import fs_from_uri, safe_import_pandas
pd = safe_import_pandas()
def _sanitize_data(data, schema, on_bad_vectors, fill_value):
@@ -44,7 +45,7 @@ def _sanitize_data(data, schema, on_bad_vectors, fill_value):
)
if isinstance(data, dict):
data = vec_to_table(data)
if isinstance(data, pd.DataFrame):
if pd is not None and isinstance(data, pd.DataFrame):
data = pa.Table.from_pandas(data)
data = _sanitize_schema(
data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
@@ -99,7 +100,7 @@ class Table(ABC):
"""
raise NotImplementedError
def to_pandas(self) -> pd.DataFrame:
def to_pandas(self):
"""Return the table as a pandas DataFrame.
Returns
@@ -333,7 +334,7 @@ class LanceTable(Table):
"""Return the first n rows of the table."""
return self._dataset.head(n)
def to_pandas(self) -> pd.DataFrame:
def to_pandas(self) -> "pd.DataFrame":
"""Return the table as a pandas DataFrame.
Returns

View File

@@ -15,7 +15,6 @@ import os
from typing import Tuple
from urllib.parse import urlparse
import pyarrow as pa
import pyarrow.fs as pa_fs
@@ -76,3 +75,12 @@ def fs_from_uri(uri: str) -> Tuple[pa_fs.FileSystem, str]:
return fs, path
return pa_fs.FileSystem.from_uri(uri)
def safe_import_pandas():
try:
import pandas as pd
return pd
except ImportError:
return None