fix: support pyarrow input types (#1628)

fixes #1625 
Support pa.RecordBatch, pa.dataset.Dataset, pa.dataset.Scanner, and
pa.RecordBatchReader as input data types.
Authored by LuQQiu on 2024-09-12 10:59:18 -07:00, committed by GitHub
parent b3bf6386c3
commit c7732585bf
3 changed files with 419 additions and 51 deletions
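As a quick usage sketch of what this enables (the connection path, table names, and dataset location below are illustrative; the pattern mirrors the new test added in this commit), each of the following inputs is now accepted and coerced to a pyarrow Table internally:

import pyarrow as pa
import pyarrow.dataset as pa_dataset
import lancedb

db = lancedb.connect("/tmp/lancedb-example")  # hypothetical local database

batch = pa.RecordBatch.from_pydict({"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"]})
table = pa.Table.from_batches([batch])
reader = pa.RecordBatchReader.from_batches(batch.schema, [batch])
pa_dataset.write_dataset(table, "/tmp/pa-example", format="parquet")  # hypothetical path
dataset = pa_dataset.dataset("/tmp/pa-example", format="parquet")
scanner = dataset.scanner()

# RecordBatch, RecordBatchReader, Dataset, and Scanner all work as `data` now.
db.create_table("from_batch", data=batch)
db.create_table("from_reader", data=reader)
db.create_table("from_dataset", data=dataset)
db.create_table("from_scanner", data=scanner)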


@@ -0,0 +1,259 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The Lance Authors
#
# The following code is originally from https://github.com/pola-rs/polars/blob/ea4389c31b0e87ddf20a85e4c3797b285966edb6/py-polars/polars/dependencies.py
# and is licensed under the MIT license:
#
# License: MIT, Copyright (c) 2020 Ritchie Vink
# https://github.com/pola-rs/polars/blob/main/LICENSE
#
# It has been modified by the LanceDB developers
# to fit the needs of the LanceDB project.
from __future__ import annotations
import re
import sys
from functools import lru_cache
from importlib import import_module
from importlib.util import find_spec
from types import ModuleType
from typing import TYPE_CHECKING, Any, ClassVar, Hashable, cast
_NUMPY_AVAILABLE = True
_PANDAS_AVAILABLE = True
_POLARS_AVAILABLE = True
_TORCH_AVAILABLE = True
_HUGGING_FACE_AVAILABLE = True
_TENSORFLOW_AVAILABLE = True
_RAY_AVAILABLE = True
class _LazyModule(ModuleType):
"""
Module that can act both as a lazy-loader and as a proxy.
Notes
-----
We do NOT register this module with `sys.modules` so as not to cause
confusion in the global environment. This way we have a valid proxy
module for our own use, but it lives _exclusively_ within lance.
"""
__lazy__ = True
_mod_pfx: ClassVar[dict[str, str]] = {
"numpy": "np.",
"pandas": "pd.",
"polars": "pl.",
"torch": "torch.",
"tensorflow": "tf.",
"ray": "ray.",
}
def __init__(
self,
module_name: str,
*,
module_available: bool,
) -> None:
"""
Initialise lazy-loading proxy module.
Parameters
----------
module_name : str
the name of the module to lazy-load (if available).
module_available : bool
indicate if the referenced module is actually available (we will proxy it
in both cases, but raise a helpful error when invoked if it doesn't exist).
"""
self._module_available = module_available
self._module_name = module_name
self._globals = globals()
super().__init__(module_name)
def _import(self) -> ModuleType:
# import the referenced module, replacing the proxy in this module's globals
module = import_module(self.__name__)
self._globals[self._module_name] = module
self.__dict__.update(module.__dict__)
return module
def __getattr__(self, attr: Any) -> Any:
# have "hasattr('__wrapped__')" return False without triggering import
# (it's for decorators, not modules, but keeps "make doctest" happy)
if attr == "__wrapped__":
raise AttributeError(
f"{self._module_name!r} object has no attribute {attr!r}"
)
# accessing the proxy module's attributes triggers import of the real thing
if self._module_available:
# import the module and return the requested attribute
module = self._import()
return getattr(module, attr)
# user has not installed the proxied/lazy module
elif attr == "__name__":
return self._module_name
elif re.match(r"^__\w+__$", attr) and attr != "__version__":
# allow some minimal introspection on private module
# attrs to avoid unnecessary error-handling elsewhere
return None
else:
# all other attribute access raises a helpful exception
pfx = self._mod_pfx.get(self._module_name, "")
raise ModuleNotFoundError(
f"{pfx}{attr} requires {self._module_name!r} module to be installed"
) from None
def _lazy_import(module_name: str) -> tuple[ModuleType, bool]:
"""
Lazy import the given module; avoids up-front import costs.
Parameters
----------
module_name : str
name of the module to import, eg: "polars".
Notes
-----
If the requested module is not available (eg: has not been installed), a proxy
module is created in its place, which raises an exception on any attribute
access. This allows for import and use as normal, without requiring explicit
guard conditions - if the module is never used, no exception occurs; if it
is, then a helpful exception is raised.
Returns
-------
tuple of (Module, bool)
A lazy-loading module and a boolean indicating if the requested/underlying
module exists (if not, the returned module is a proxy).
"""
# check if module is LOADED
if module_name in sys.modules:
return sys.modules[module_name], True
# check if module is AVAILABLE
try:
module_spec = find_spec(module_name)
module_available = not (module_spec is None or module_spec.loader is None)
except ModuleNotFoundError:
module_available = False
# create lazy/proxy module that imports the real one on first use
# (or raises an explanatory ModuleNotFoundError if not available)
return (
_LazyModule(
module_name=module_name,
module_available=module_available,
),
module_available,
)
if TYPE_CHECKING:
import datasets
import numpy
import pandas
import polars
import ray
import tensorflow
import torch
else:
# heavy/optional third party libs
numpy, _NUMPY_AVAILABLE = _lazy_import("numpy")
pandas, _PANDAS_AVAILABLE = _lazy_import("pandas")
polars, _POLARS_AVAILABLE = _lazy_import("polars")
torch, _TORCH_AVAILABLE = _lazy_import("torch")
datasets, _HUGGING_FACE_AVAILABLE = _lazy_import("datasets")
tensorflow, _TENSORFLOW_AVAILABLE = _lazy_import("tensorflow")
ray, _RAY_AVAILABLE = _lazy_import("ray")
@lru_cache(maxsize=None)
def _might_be(cls: type, type_: str) -> bool:
# infer whether the given class "might" be associated with the given
# module (in which case it's reasonable to do a real isinstance check)
try:
return any(f"{type_}." in str(o) for o in cls.mro())
except TypeError:
return False
def _check_for_numpy(obj: Any, *, check_type: bool = True) -> bool:
return _NUMPY_AVAILABLE and _might_be(
cast(Hashable, type(obj) if check_type else obj), "numpy"
)
def _check_for_pandas(obj: Any, *, check_type: bool = True) -> bool:
return _PANDAS_AVAILABLE and _might_be(
cast(Hashable, type(obj) if check_type else obj), "pandas"
)
def _check_for_polars(obj: Any, *, check_type: bool = True) -> bool:
return _POLARS_AVAILABLE and _might_be(
cast(Hashable, type(obj) if check_type else obj), "polars"
)
def _check_for_torch(obj: Any, *, check_type: bool = True) -> bool:
return _TORCH_AVAILABLE and _might_be(
cast(Hashable, type(obj) if check_type else obj), "torch"
)
def _check_for_hugging_face(obj: Any, *, check_type: bool = True) -> bool:
return _HUGGING_FACE_AVAILABLE and _might_be(
cast(Hashable, type(obj) if check_type else obj), "datasets"
)
def _check_for_tensorflow(obj: Any, *, check_type: bool = True) -> bool:
return _TENSORFLOW_AVAILABLE and _might_be(
cast(Hashable, type(obj) if check_type else obj), "tensorflow"
)
def _check_for_ray(obj: Any, *, check_type: bool = True) -> bool:
return _RAY_AVAILABLE and _might_be(
cast(Hashable, type(obj) if check_type else obj), "ray"
)
__all__ = [
# lazy-load third party libs
"datasets",
"numpy",
"pandas",
"polars",
"ray",
"tensorflow",
"torch",
# lazy utilities
"_check_for_hugging_face",
"_check_for_numpy",
"_check_for_pandas",
"_check_for_polars",
"_check_for_tensorflow",
"_check_for_torch",
"_check_for_ray",
"_LazyModule",
# exported flags/guards
"_NUMPY_AVAILABLE",
"_PANDAS_AVAILABLE",
"_POLARS_AVAILABLE",
"_TORCH_AVAILABLE",
"_HUGGING_FACE_AVAILABLE",
"_TENSORFLOW_AVAILABLE",
"_RAY_AVAILABLE",
]
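The module above is meant to be consumed through its lazy proxies rather than by importing the heavy libraries directly. A minimal sketch of that pattern, assuming the file lands at lancedb/dependencies.py (the helper function name and the `data` argument are hypothetical; the table.py changes below use the same idiom):

import pyarrow as pa
from lancedb.dependencies import pandas as pd, _check_for_pandas

def to_arrow_table(data):
    # _might_be() only inspects the class MRO as strings, so pandas is not
    # imported just to run this check; the lazy `pd` proxy performs the real
    # import only when `data` actually looks like a pandas object.
    if _check_for_pandas(data) and isinstance(data, pd.DataFrame):
        return pa.Table.from_pandas(data, preserve_index=False)
    raise TypeError(f"Unsupported data type: {type(data)}")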


@@ -23,6 +23,7 @@ from typing import (
from urllib.parse import urlparse
import lance
from .dependencies import _check_for_pandas
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
@@ -53,38 +54,23 @@ if TYPE_CHECKING:
from .db import LanceDBConnection
from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS
pd = safe_import_pandas()
pl = safe_import_polars()
def _sanitize_data(
data,
schema: Optional[pa.Schema],
metadata: Optional[dict],
on_bad_vectors: str,
fill_value: Any,
):
def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
if _check_for_hugging_face(data):
# Huggingface datasets
from lance.dependencies import datasets
if isinstance(data, datasets.dataset_dict.DatasetDict):
if schema is None:
schema = _schema_from_hf(data, schema)
data = _to_record_batch_generator(
_to_batches_with_split(data),
schema,
metadata,
on_bad_vectors,
fill_value,
)
elif isinstance(data, datasets.Dataset):
if isinstance(data, datasets.Dataset):
if schema is None:
schema = data.features.arrow_schema
data = _to_record_batch_generator(
data.data.to_batches(), schema, metadata, on_bad_vectors, fill_value
)
return pa.Table.from_batches(data.data.to_batches(), schema=schema)
elif isinstance(data, datasets.dataset_dict.DatasetDict):
if schema is None:
schema = _schema_from_hf(data, schema)
return pa.Table.from_batches(_to_batches_with_split(data), schema=schema)
if isinstance(data, LanceModel):
raise ValueError("Cannot add a single LanceModel to a table. Use a list.")
@@ -95,40 +81,66 @@ def _sanitize_data(
if schema is None:
schema = data[0].__class__.to_arrow_schema()
data = [model_to_dict(d) for d in data]
data = pa.Table.from_pylist(data, schema=schema)
return pa.Table.from_pylist(data, schema=schema)
else:
data = pa.Table.from_pylist(data)
return pa.Table.from_pylist(data)
elif isinstance(data, dict):
data = vec_to_table(data)
elif pd is not None and isinstance(data, pd.DataFrame):
data = pa.Table.from_pandas(data, preserve_index=False)
return vec_to_table(data)
elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
# Do not add schema here, since the schema may contain the vector column
table = pa.Table.from_pandas(data, preserve_index=False)
# Do not serialize Pandas metadata
meta = data.schema.metadata if data.schema.metadata is not None else {}
meta = table.schema.metadata if table.schema.metadata is not None else {}
meta = {k: v for k, v in meta.items() if k != b"pandas"}
data = data.replace_schema_metadata(meta)
elif pl is not None and isinstance(data, pl.DataFrame):
data = data.to_arrow()
if isinstance(data, pa.Table):
if metadata:
data = _append_vector_col(data, metadata, schema)
metadata.update(data.schema.metadata or {})
data = data.replace_schema_metadata(metadata)
data = _sanitize_schema(
data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
if schema is None:
schema = data.schema
return table.replace_schema_metadata(meta)
elif isinstance(data, pa.Table):
return data
elif isinstance(data, pa.RecordBatch):
return pa.Table.from_batches([data])
elif isinstance(data, LanceDataset):
return data.scanner().to_table()
elif isinstance(data, pa.dataset.Dataset):
return data.to_table()
elif isinstance(data, pa.dataset.Scanner):
return data.to_table()
elif isinstance(data, pa.RecordBatchReader):
return data.read_all()
elif (
type(data).__module__.startswith("polars")
and data.__class__.__name__ == "DataFrame"
):
return data.to_arrow()
elif isinstance(data, Iterable):
data = _to_record_batch_generator(
data, schema, metadata, on_bad_vectors, fill_value
)
if schema is None:
data, schema = _generator_to_data_and_schema(data)
if schema is None:
raise ValueError("Cannot infer schema from generator data")
return _process_iterator(data, schema)
else:
raise TypeError(f"Unsupported data type: {type(data)}")
raise TypeError(
f"Unknown data type {type(data)}. "
"Please check "
"https://lancedb.github.io/lancedb/python/python/ "
"to see supported types."
)
def _sanitize_data(
data: Any,
schema: Optional[pa.Schema] = None,
metadata: Optional[dict] = None, # embedding metadata
on_bad_vectors: str = "error",
fill_value: float = 0.0,
):
data = _coerce_to_table(data, schema)
if metadata:
data = _append_vector_col(data, metadata, schema)
metadata.update(data.schema.metadata or {})
data = data.replace_schema_metadata(metadata)
# TODO: improve the logic in _sanitize_schema
data = _sanitize_schema(data, schema, on_bad_vectors, fill_value)
if schema is None:
schema = data.schema
_validate_schema(schema)
return data, schema
@@ -2015,6 +2027,55 @@ def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name
return data
def _validate_schema(schema: pa.Schema):
"""
Make sure the metadata is valid utf8
"""
if schema.metadata is not None:
_validate_metadata(schema.metadata)
def _validate_metadata(metadata: dict):
"""
Make sure the metadata values are valid utf8 (can be nested)
Raises ValueError if not valid utf8
"""
for k, v in metadata.items():
if isinstance(v, bytes):
try:
v.decode("utf8")
except UnicodeDecodeError:
raise ValueError(
f"Metadata key {k} is not valid utf8. "
"Consider base64 encode for generic binary metadata."
)
elif isinstance(v, dict):
_validate_metadata(v)
def _process_iterator(data: Iterable, schema: Optional[pa.Schema] = None) -> pa.Table:
batches = []
for batch in data:
batch_table = _coerce_to_table(batch, schema)
if schema is not None:
if batch_table.schema != schema:
try:
batch_table = batch_table.cast(schema)
except pa.lib.ArrowInvalid:
raise ValueError(
f"Input iterator yielded a batch with schema that "
f"does not match the expected schema.\nExpected:\n{schema}\n"
f"Got:\n{batch_table.schema}"
)
batches.append(batch_table)
if batches:
return pa.concat_tables(batches)
else:
raise ValueError("Input iterable is empty")
class AsyncTable:
"""
An AsyncTable is a collection of Records in a LanceDB Database.
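For iterables, the new `_coerce_to_table` path hands the data to `_process_iterator`, which coerces every yielded item to a Table, casts it to the expected schema when it differs, and concatenates the pieces (raising ValueError on an uncastable mismatch or an empty iterable). A rough sketch of that path, calling the internal helper directly for illustration (the import path assumes the helpers above live in lancedb/table.py):

import pyarrow as pa
from lancedb.table import _sanitize_data  # internal helper shown in the diff above

schema = pa.schema([pa.field("id", pa.int64())])

def batch_gen():
    # hypothetical generator of record batches
    yield pa.RecordBatch.from_pydict({"id": [1, 2]})
    yield pa.RecordBatch.from_pydict({"id": [3]})

tbl, out_schema = _sanitize_data(batch_gen(), schema=schema)
assert tbl.num_rows == 3 and out_schema == schema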


@@ -64,6 +64,55 @@ def test_basic(db):
assert table.to_lance().to_table() == ds.to_table()
def test_input_data_type(db, tmp_path):
schema = pa.schema(
[
pa.field("id", pa.int64()),
pa.field("name", pa.string()),
pa.field("age", pa.int32()),
]
)
data = {
"id": [1, 2, 3, 4, 5],
"name": ["Alice", "Bob", "Charlie", "David", "Eve"],
"age": [25, 30, 35, 40, 45],
}
record_batch = pa.RecordBatch.from_pydict(data, schema=schema)
pa_reader = pa.RecordBatchReader.from_batches(record_batch.schema, [record_batch])
pa_table = pa.Table.from_batches([record_batch])
def create_dataset(tmp_path):
path = os.path.join(tmp_path, "test_source_dataset")
pa.dataset.write_dataset(pa_table, path, format="parquet")
return pa.dataset.dataset(path, format="parquet")
pa_dataset = create_dataset(tmp_path)
pa_scanner = pa_dataset.scanner()
input_types = [
("RecordBatchReader", pa_reader),
("RecordBatch", record_batch),
("Table", pa_table),
("Dataset", pa_dataset),
("Scanner", pa_scanner),
]
for input_type, input_data in input_types:
table_name = f"test_{input_type.lower()}"
ds = LanceTable.create(db, table_name, data=input_data).to_lance()
assert ds.schema == schema
assert ds.count_rows() == 5
assert ds.schema.field("id").type == pa.int64()
assert ds.schema.field("name").type == pa.string()
assert ds.schema.field("age").type == pa.int32()
result_table = ds.to_table()
assert result_table.column("id").to_pylist() == data["id"]
assert result_table.column("name").to_pylist() == data["name"]
assert result_table.column("age").to_pylist() == data["age"]
@pytest.mark.asyncio
async def test_close(db_async: AsyncConnection):
table = await db_async.create_table("some_table", data=[{"id": 0}])
@@ -274,7 +323,6 @@ def test_polars(db):
def _add(table, schema):
# table = LanceTable(db, "test")
assert len(table) == 2
table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])