mirror of https://github.com/lancedb/lancedb.git
fix: support pyarrow input types (#1628)
Fixes #1625. Support pa.RecordBatch, pa.dataset.Dataset, pa.dataset.Scanner, and pa.RecordBatchReader as input types.
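For context, a minimal sketch of what this enables through the public Python API (the database path and table name below are made up for illustration; `lancedb.connect`, `create_table`, and `to_arrow` are the existing entry points):

    import lancedb
    import pyarrow as pa

    batch = pa.RecordBatch.from_pydict({"id": [1, 2], "name": ["a", "b"]})
    reader = pa.RecordBatchReader.from_batches(batch.schema, [batch])

    db = lancedb.connect("/tmp/example-db")
    # Any of pa.Table, pa.RecordBatch, pa.dataset.Dataset, pa.dataset.Scanner,
    # or pa.RecordBatchReader can now be passed as `data`.
    tbl = db.create_table("demo", data=reader)
    assert tbl.to_arrow().num_rows == 2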
python/python/lancedb/dependencies.py (new file, 259 lines)
@@ -0,0 +1,259 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The Lance Authors
#
# The following code is originally from https://github.com/pola-rs/polars/blob/ea4389c31b0e87ddf20a85e4c3797b285966edb6/py-polars/polars/dependencies.py
# and is licensed under the MIT license:
#
# License: MIT, Copyright (c) 2020 Ritchie Vink
# https://github.com/pola-rs/polars/blob/main/LICENSE
#
# It has been modified by the LanceDB developers
# to fit the needs of the LanceDB project.

from __future__ import annotations

import re
import sys
from functools import lru_cache
from importlib import import_module
from importlib.util import find_spec
from types import ModuleType
from typing import TYPE_CHECKING, Any, ClassVar, Hashable, cast

_NUMPY_AVAILABLE = True
_PANDAS_AVAILABLE = True
_POLARS_AVAILABLE = True
_TORCH_AVAILABLE = True
_HUGGING_FACE_AVAILABLE = True
_TENSORFLOW_AVAILABLE = True
_RAY_AVAILABLE = True


class _LazyModule(ModuleType):
    """
    Module that can act both as a lazy-loader and as a proxy.

    Notes
    -----
    We do NOT register this module with `sys.modules` so as not to cause
    confusion in the global environment. This way we have a valid proxy
    module for our own use, but it lives _exclusively_ within lance.
    """

    __lazy__ = True

    _mod_pfx: ClassVar[dict[str, str]] = {
        "numpy": "np.",
        "pandas": "pd.",
        "polars": "pl.",
        "torch": "torch.",
        "tensorflow": "tf.",
        "ray": "ray.",
    }

    def __init__(
        self,
        module_name: str,
        *,
        module_available: bool,
    ) -> None:
        """
        Initialise lazy-loading proxy module.

        Parameters
        ----------
        module_name : str
            the name of the module to lazy-load (if available).

        module_available : bool
            indicate if the referenced module is actually available (we will proxy it
            in both cases, but raise a helpful error when invoked if it doesn't exist).
        """
        self._module_available = module_available
        self._module_name = module_name
        self._globals = globals()
        super().__init__(module_name)

    def _import(self) -> ModuleType:
        # import the referenced module, replacing the proxy in this module's globals
        module = import_module(self.__name__)
        self._globals[self._module_name] = module
        self.__dict__.update(module.__dict__)
        return module

    def __getattr__(self, attr: Any) -> Any:
        # have "hasattr('__wrapped__')" return False without triggering import
        # (it's for decorators, not modules, but keeps "make doctest" happy)
        if attr == "__wrapped__":
            raise AttributeError(
                f"{self._module_name!r} object has no attribute {attr!r}"
            )

        # accessing the proxy module's attributes triggers import of the real thing
        if self._module_available:
            # import the module and return the requested attribute
            module = self._import()
            return getattr(module, attr)

        # user has not installed the proxied/lazy module
        elif attr == "__name__":
            return self._module_name
        elif re.match(r"^__\w+__$", attr) and attr != "__version__":
            # allow some minimal introspection on private module
            # attrs to avoid unnecessary error-handling elsewhere
            return None
        else:
            # all other attribute access raises a helpful exception
            pfx = self._mod_pfx.get(self._module_name, "")
            raise ModuleNotFoundError(
                f"{pfx}{attr} requires {self._module_name!r} module to be installed"
            ) from None


def _lazy_import(module_name: str) -> tuple[ModuleType, bool]:
    """
    Lazy import the given module; avoids up-front import costs.

    Parameters
    ----------
    module_name : str
        name of the module to import, eg: "polars".

    Notes
    -----
    If the requested module is not available (eg: has not been installed), a proxy
    module is created in its place, which raises an exception on any attribute
    access. This allows for import and use as normal, without requiring explicit
    guard conditions - if the module is never used, no exception occurs; if it
    is, then a helpful exception is raised.

    Returns
    -------
    tuple of (Module, bool)
        A lazy-loading module and a boolean indicating if the requested/underlying
        module exists (if not, the returned module is a proxy).
    """
    # check if module is LOADED
    if module_name in sys.modules:
        return sys.modules[module_name], True

    # check if module is AVAILABLE
    try:
        module_spec = find_spec(module_name)
        module_available = not (module_spec is None or module_spec.loader is None)
    except ModuleNotFoundError:
        module_available = False

    # create lazy/proxy module that imports the real one on first use
    # (or raises an explanatory ModuleNotFoundError if not available)
    return (
        _LazyModule(
            module_name=module_name,
            module_available=module_available,
        ),
        module_available,
    )


if TYPE_CHECKING:
    import datasets
    import numpy
    import pandas
    import polars
    import ray
    import tensorflow
    import torch
else:
    # heavy/optional third party libs
    numpy, _NUMPY_AVAILABLE = _lazy_import("numpy")
    pandas, _PANDAS_AVAILABLE = _lazy_import("pandas")
    polars, _POLARS_AVAILABLE = _lazy_import("polars")
    torch, _TORCH_AVAILABLE = _lazy_import("torch")
    datasets, _HUGGING_FACE_AVAILABLE = _lazy_import("datasets")
    tensorflow, _TENSORFLOW_AVAILABLE = _lazy_import("tensorflow")
    ray, _RAY_AVAILABLE = _lazy_import("ray")


@lru_cache(maxsize=None)
def _might_be(cls: type, type_: str) -> bool:
    # infer whether the given class "might" be associated with the given
    # module (in which case it's reasonable to do a real isinstance check)
    try:
        return any(f"{type_}." in str(o) for o in cls.mro())
    except TypeError:
        return False


def _check_for_numpy(obj: Any, *, check_type: bool = True) -> bool:
    return _NUMPY_AVAILABLE and _might_be(
        cast(Hashable, type(obj) if check_type else obj), "numpy"
    )


def _check_for_pandas(obj: Any, *, check_type: bool = True) -> bool:
    return _PANDAS_AVAILABLE and _might_be(
        cast(Hashable, type(obj) if check_type else obj), "pandas"
    )


def _check_for_polars(obj: Any, *, check_type: bool = True) -> bool:
    return _POLARS_AVAILABLE and _might_be(
        cast(Hashable, type(obj) if check_type else obj), "polars"
    )


def _check_for_torch(obj: Any, *, check_type: bool = True) -> bool:
    return _TORCH_AVAILABLE and _might_be(
        cast(Hashable, type(obj) if check_type else obj), "torch"
    )


def _check_for_hugging_face(obj: Any, *, check_type: bool = True) -> bool:
    return _HUGGING_FACE_AVAILABLE and _might_be(
        cast(Hashable, type(obj) if check_type else obj), "datasets"
    )


def _check_for_tensorflow(obj: Any, *, check_type: bool = True) -> bool:
    return _TENSORFLOW_AVAILABLE and _might_be(
        cast(Hashable, type(obj) if check_type else obj), "tensorflow"
    )


def _check_for_ray(obj: Any, *, check_type: bool = True) -> bool:
    return _RAY_AVAILABLE and _might_be(
        cast(Hashable, type(obj) if check_type else obj), "ray"
    )


__all__ = [
    # lazy-load third party libs
    "datasets",
    "numpy",
    "pandas",
    "polars",
    "ray",
    "tensorflow",
    "torch",
    # lazy utilities
    "_check_for_hugging_face",
    "_check_for_numpy",
    "_check_for_pandas",
    "_check_for_polars",
    "_check_for_tensorflow",
    "_check_for_torch",
    "_check_for_ray",
    "_LazyModule",
    # exported flags/guards
    "_NUMPY_AVAILABLE",
    "_PANDAS_AVAILABLE",
    "_POLARS_AVAILABLE",
    "_TORCH_AVAILABLE",
    "_HUGGING_FACE_AVAILABLE",
    "_TENSORFLOW_AVAILABLE",
    "_RAY_AVAILABLE",
]
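A quick note on the vendored module above: callers import the (possibly proxied) library and its guard together, and the real import only happens on first attribute access. A small sketch of the intended usage (the object passed in is illustrative):

    from lancedb.dependencies import _check_for_pandas, pandas

    obj = {"a": [1, 2]}
    # Cheap check: inspects the MRO of type(obj), never imports pandas.
    assert _check_for_pandas(obj) is False

    # Touching an attribute on the proxy triggers the real import, or raises
    # a helpful ModuleNotFoundError if pandas is not installed.
    df = pandas.DataFrame(obj)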
python/python/lancedb/table.py
@@ -23,6 +23,7 @@ from typing import (
 from urllib.parse import urlparse

 import lance
+from .dependencies import _check_for_pandas
 import numpy as np
 import pyarrow as pa
 import pyarrow.compute as pc
@@ -53,38 +54,23 @@ if TYPE_CHECKING:
     from .db import LanceDBConnection
     from .index import BTree, IndexConfig, IvfPq, Bitmap, LabelList, FTS


 pd = safe_import_pandas()
 pl = safe_import_polars()


-def _sanitize_data(
-    data,
-    schema: Optional[pa.Schema],
-    metadata: Optional[dict],
-    on_bad_vectors: str,
-    fill_value: Any,
-):
+def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
     if _check_for_hugging_face(data):
         # Huggingface datasets
         from lance.dependencies import datasets

-        if isinstance(data, datasets.dataset_dict.DatasetDict):
-            if schema is None:
-                schema = _schema_from_hf(data, schema)
-            data = _to_record_batch_generator(
-                _to_batches_with_split(data),
-                schema,
-                metadata,
-                on_bad_vectors,
-                fill_value,
-            )
-        elif isinstance(data, datasets.Dataset):
+        if isinstance(data, datasets.Dataset):
             if schema is None:
                 schema = data.features.arrow_schema
-            data = _to_record_batch_generator(
-                data.data.to_batches(), schema, metadata, on_bad_vectors, fill_value
-            )
+            return pa.Table.from_batches(data.data.to_batches(), schema=schema)
+        elif isinstance(data, datasets.dataset_dict.DatasetDict):
+            if schema is None:
+                schema = _schema_from_hf(data, schema)
+            return pa.Table.from_batches(_to_batches_with_split(data), schema=schema)

     if isinstance(data, LanceModel):
         raise ValueError("Cannot add a single LanceModel to a table. Use a list.")
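Worth noting in the hunk above: Hugging Face inputs are now materialized eagerly with pa.Table.from_batches(...) rather than wrapped in a record-batch generator. A hedged sketch of that path (the dataset contents, database path, and table name are made up):

    import datasets  # Hugging Face datasets library
    import lancedb

    hf = datasets.Dataset.from_dict({"id": [1, 2, 3]})
    db = lancedb.connect("/tmp/example-db")
    tbl = db.create_table("from_hf", data=hf)  # schema comes from hf.features.arrow_schema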
@@ -95,40 +81,66 @@ def _sanitize_data(
             if schema is None:
                 schema = data[0].__class__.to_arrow_schema()
             data = [model_to_dict(d) for d in data]
-            data = pa.Table.from_pylist(data, schema=schema)
+            return pa.Table.from_pylist(data, schema=schema)
         else:
-            data = pa.Table.from_pylist(data)
+            return pa.Table.from_pylist(data)
     elif isinstance(data, dict):
-        data = vec_to_table(data)
-    elif pd is not None and isinstance(data, pd.DataFrame):
-        data = pa.Table.from_pandas(data, preserve_index=False)
+        return vec_to_table(data)
+    elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
+        # Do not add schema here, since schema may contain the vector column
+        table = pa.Table.from_pandas(data, preserve_index=False)
         # Do not serialize Pandas metadata
-        meta = data.schema.metadata if data.schema.metadata is not None else {}
+        meta = table.schema.metadata if table.schema.metadata is not None else {}
         meta = {k: v for k, v in meta.items() if k != b"pandas"}
-        data = data.replace_schema_metadata(meta)
-    elif pl is not None and isinstance(data, pl.DataFrame):
-        data = data.to_arrow()
-
-    if isinstance(data, pa.Table):
-        if metadata:
-            data = _append_vector_col(data, metadata, schema)
-            metadata.update(data.schema.metadata or {})
-            data = data.replace_schema_metadata(metadata)
-        data = _sanitize_schema(
-            data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
-        )
-        if schema is None:
-            schema = data.schema
+        return table.replace_schema_metadata(meta)
+    elif isinstance(data, pa.Table):
+        return data
+    elif isinstance(data, pa.RecordBatch):
+        return pa.Table.from_batches([data])
+    elif isinstance(data, LanceDataset):
+        return data.scanner().to_table()
+    elif isinstance(data, pa.dataset.Dataset):
+        return data.to_table()
+    elif isinstance(data, pa.dataset.Scanner):
+        return data.to_table()
+    elif isinstance(data, pa.RecordBatchReader):
+        return data.read_all()
+    elif (
+        type(data).__module__.startswith("polars")
+        and data.__class__.__name__ == "DataFrame"
+    ):
+        return data.to_arrow()
     elif isinstance(data, Iterable):
-        data = _to_record_batch_generator(
-            data, schema, metadata, on_bad_vectors, fill_value
-        )
-        if schema is None:
-            data, schema = _generator_to_data_and_schema(data)
-            if schema is None:
-                raise ValueError("Cannot infer schema from generator data")
+        return _process_iterator(data, schema)
     else:
-        raise TypeError(f"Unsupported data type: {type(data)}")
+        raise TypeError(
+            f"Unknown data type {type(data)}. "
+            "Please check "
+            "https://lancedb.github.io/lancedb/python/python/ "
+            "to see supported types."
+        )
+
+
+def _sanitize_data(
+    data: Any,
+    schema: Optional[pa.Schema] = None,
+    metadata: Optional[dict] = None,  # embedding metadata
+    on_bad_vectors: str = "error",
+    fill_value: float = 0.0,
+):
+    data = _coerce_to_table(data, schema)
+
+    if metadata:
+        data = _append_vector_col(data, metadata, schema)
+        metadata.update(data.schema.metadata or {})
+        data = data.replace_schema_metadata(metadata)
+
+    # TODO improve the logic in _sanitize_schema
+    data = _sanitize_schema(data, schema, on_bad_vectors, fill_value)
+    if schema is None:
+        schema = data.schema

     _validate_schema(schema)
     return data, schema
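Net effect of the hunk above: the old monolithic `_sanitize_data` is split in two. `_coerce_to_table` normalizes every supported input into a `pa.Table` (dicts take the `vec_to_table` path, readers go through `read_all()`, datasets and scanners through `to_table()`, iterables through `_process_iterator`), while the slimmer `_sanitize_data` layers embedding metadata and schema sanitization on top. A hedged sketch of the resulting internal call (argument values are illustrative; the defaults match the diff):

    table, schema = _sanitize_data(
        {"id": [1.0, 2.0, 3.0]},   # dicts take the vec_to_table path shown above
        schema=None,               # inferred from the coerced table
        metadata=None,
        on_bad_vectors="error",
        fill_value=0.0,
    )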
@@ -2015,6 +2027,55 @@ def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name
     return data


+def _validate_schema(schema: pa.Schema):
+    """
+    Make sure the metadata is valid utf8
+    """
+    if schema.metadata is not None:
+        _validate_metadata(schema.metadata)
+
+
+def _validate_metadata(metadata: dict):
+    """
+    Make sure the metadata values are valid utf8 (can be nested)
+
+    Raises ValueError if not valid utf8
+    """
+    for k, v in metadata.items():
+        if isinstance(v, bytes):
+            try:
+                v.decode("utf8")
+            except UnicodeDecodeError:
+                raise ValueError(
+                    f"Metadata key {k} is not valid utf8. "
+                    "Consider base64 encode for generic binary metadata."
+                )
+        elif isinstance(v, dict):
+            _validate_metadata(v)
+
+
+def _process_iterator(data: Iterable, schema: Optional[pa.Schema] = None) -> pa.Table:
+    batches = []
+    for batch in data:
+        batch_table = _coerce_to_table(batch, schema)
+        if schema is not None:
+            if batch_table.schema != schema:
+                try:
+                    batch_table = batch_table.cast(schema)
+                except pa.lib.ArrowInvalid:
+                    raise ValueError(
+                        f"Input iterator yielded a batch with schema that "
+                        f"does not match the expected schema.\nExpected:\n{schema}\n"
+                        f"Got:\n{batch_table.schema}"
+                    )
+        batches.append(batch_table)
+
+    if batches:
+        return pa.concat_tables(batches)
+    else:
+        raise ValueError("Input iterable is empty")
+
+
 class AsyncTable:
     """
     An AsyncTable is a collection of Records in a LanceDB Database.
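One more behavioral note before the test changes: generator and iterator inputs now flow through `_process_iterator`, which coerces each yielded item with `_coerce_to_table` and casts it to the expected schema when one is given, so streams of batches keep working. A small self-contained sketch (the path, table name, and generator are made up):

    import lancedb
    import pyarrow as pa

    def batch_gen():
        # each yielded RecordBatch is coerced and schema-checked individually
        for i in range(3):
            yield pa.RecordBatch.from_pydict({"id": [i]})

    db = lancedb.connect("/tmp/example-db")
    tbl = db.create_table("from_generator", data=batch_gen())
    assert tbl.to_arrow().num_rows == 3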
python/python/tests/test_table.py
@@ -64,6 +64,55 @@ def test_basic(db):
     assert table.to_lance().to_table() == ds.to_table()


+def test_input_data_type(db, tmp_path):
+    schema = pa.schema(
+        [
+            pa.field("id", pa.int64()),
+            pa.field("name", pa.string()),
+            pa.field("age", pa.int32()),
+        ]
+    )
+
+    data = {
+        "id": [1, 2, 3, 4, 5],
+        "name": ["Alice", "Bob", "Charlie", "David", "Eve"],
+        "age": [25, 30, 35, 40, 45],
+    }
+    record_batch = pa.RecordBatch.from_pydict(data, schema=schema)
+    pa_reader = pa.RecordBatchReader.from_batches(record_batch.schema, [record_batch])
+    pa_table = pa.Table.from_batches([record_batch])
+
+    def create_dataset(tmp_path):
+        path = os.path.join(tmp_path, "test_source_dataset")
+        pa.dataset.write_dataset(pa_table, path, format="parquet")
+        return pa.dataset.dataset(path, format="parquet")
+
+    pa_dataset = create_dataset(tmp_path)
+    pa_scanner = pa_dataset.scanner()
+
+    input_types = [
+        ("RecordBatchReader", pa_reader),
+        ("RecordBatch", record_batch),
+        ("Table", pa_table),
+        ("Dataset", pa_dataset),
+        ("Scanner", pa_scanner),
+    ]
+    for input_type, input_data in input_types:
+        table_name = f"test_{input_type.lower()}"
+        ds = LanceTable.create(db, table_name, data=input_data).to_lance()
+        assert ds.schema == schema
+        assert ds.count_rows() == 5
+
+        assert ds.schema.field("id").type == pa.int64()
+        assert ds.schema.field("name").type == pa.string()
+        assert ds.schema.field("age").type == pa.int32()
+
+        result_table = ds.to_table()
+        assert result_table.column("id").to_pylist() == data["id"]
+        assert result_table.column("name").to_pylist() == data["name"]
+        assert result_table.column("age").to_pylist() == data["age"]
+
+
 @pytest.mark.asyncio
 async def test_close(db_async: AsyncConnection):
     table = await db_async.create_table("some_table", data=[{"id": 0}])
@@ -274,7 +323,6 @@ def test_polars(db):


 def _add(table, schema):
-    # table = LanceTable(db, "test")
     assert len(table) == 2

     table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])