mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
feat(python): support writing huggingface dataset and dataset dict (#1110)
HuggingFace Dataset is written as arrow batches. For DatasetDict, all splits are written with a "split" column appended. - [x] what if the dataset schema already has a `split` column - [x] add unit tests
This commit is contained in:
@@ -37,6 +37,7 @@ import pyarrow as pa
|
||||
import pyarrow.compute as pc
|
||||
import pyarrow.fs as pa_fs
|
||||
from lance import LanceDataset
|
||||
from lance.dependencies import _check_for_hugging_face
|
||||
from lance.vector import vec_to_table
|
||||
|
||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
@@ -73,6 +74,27 @@ def _sanitize_data(
|
||||
on_bad_vectors: str,
|
||||
fill_value: Any,
|
||||
):
|
||||
if _check_for_hugging_face(data):
|
||||
# Huggingface datasets
|
||||
from lance.dependencies import datasets
|
||||
|
||||
if isinstance(data, datasets.dataset_dict.DatasetDict):
|
||||
if schema is None:
|
||||
schema = _schema_from_hf(data, schema)
|
||||
data = _to_record_batch_generator(
|
||||
_to_batches_with_split(data),
|
||||
schema,
|
||||
metadata,
|
||||
on_bad_vectors,
|
||||
fill_value,
|
||||
)
|
||||
elif isinstance(data, datasets.Dataset):
|
||||
if schema is None:
|
||||
schema = data.features.arrow_schema
|
||||
data = _to_record_batch_generator(
|
||||
data.data.to_batches(), schema, metadata, on_bad_vectors, fill_value
|
||||
)
|
||||
|
||||
if isinstance(data, list):
|
||||
# convert to list of dict if data is a bunch of LanceModels
|
||||
if isinstance(data[0], LanceModel):
|
||||
@@ -109,6 +131,37 @@ def _sanitize_data(
|
||||
return data
|
||||
|
||||
|
||||
def _schema_from_hf(data, schema):
|
||||
"""
|
||||
Extract pyarrow schema from HuggingFace DatasetDict
|
||||
and validate that they're all the same schema between
|
||||
splits
|
||||
"""
|
||||
for dataset in data.values():
|
||||
if schema is None:
|
||||
schema = dataset.features.arrow_schema
|
||||
elif schema != dataset.features.arrow_schema:
|
||||
msg = "All datasets in a HuggingFace DatasetDict must have the same schema"
|
||||
raise TypeError(msg)
|
||||
return schema
|
||||
|
||||
|
||||
def _to_batches_with_split(data):
|
||||
"""
|
||||
Return a generator of RecordBatches from a HuggingFace DatasetDict
|
||||
with an extra `split` column
|
||||
"""
|
||||
for key, dataset in data.items():
|
||||
for batch in dataset.data.to_batches():
|
||||
table = pa.Table.from_batches([batch])
|
||||
if "split" not in table.column_names:
|
||||
table = table.append_column(
|
||||
"split", pa.array([key] * batch.num_rows, pa.string())
|
||||
)
|
||||
for b in table.to_batches():
|
||||
yield b
|
||||
|
||||
|
||||
def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schema]):
|
||||
"""
|
||||
Use the embedding function to automatically embed the source column and add the
|
||||
@@ -143,12 +196,13 @@ def _to_record_batch_generator(
|
||||
data: Iterable, schema, metadata, on_bad_vectors, fill_value
|
||||
):
|
||||
for batch in data:
|
||||
if not isinstance(batch, pa.RecordBatch):
|
||||
table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
|
||||
for batch in table.to_batches():
|
||||
yield batch
|
||||
else:
|
||||
yield batch
|
||||
# always convert to table because we need to sanitize the data
|
||||
# and do things like add the vector column etc
|
||||
if isinstance(batch, pa.RecordBatch):
|
||||
batch = pa.Table.from_batches([batch])
|
||||
batch = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
|
||||
for b in batch.to_batches():
|
||||
yield b
|
||||
|
||||
|
||||
class Table(ABC):
|
||||
|
||||
126
python/python/tests/test_huggingface.py
Normal file
126
python/python/tests/test_huggingface.py
Normal file
@@ -0,0 +1,126 @@
|
||||
# Copyright 2024 Lance Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import lancedb
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from lancedb.embeddings import get_registry
|
||||
from lancedb.embeddings.base import TextEmbeddingFunction
|
||||
from lancedb.embeddings.registry import register
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
datasets = pytest.importorskip("datasets")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def mock_embedding_function():
|
||||
@register("random")
|
||||
class MockTextEmbeddingFunction(TextEmbeddingFunction):
|
||||
def generate_embeddings(self, texts):
|
||||
return [np.random.randn(128).tolist() for _ in range(len(texts))]
|
||||
|
||||
def ndims(self):
|
||||
return 128
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_hf_dataset():
|
||||
# Create pyarrow table with `text` and `label` columns
|
||||
train = datasets.Dataset(
|
||||
pa.table(
|
||||
{
|
||||
"text": ["foo", "bar"],
|
||||
"label": [0, 1],
|
||||
}
|
||||
),
|
||||
split="train",
|
||||
)
|
||||
|
||||
test = datasets.Dataset(
|
||||
pa.table(
|
||||
{
|
||||
"text": ["fizz", "buzz"],
|
||||
"label": [0, 1],
|
||||
}
|
||||
),
|
||||
split="test",
|
||||
)
|
||||
return datasets.DatasetDict({"train": train, "test": test})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hf_dataset_with_split():
|
||||
# Create pyarrow table with `text` and `label` columns
|
||||
train = datasets.Dataset(
|
||||
pa.table(
|
||||
{"text": ["foo", "bar"], "label": [0, 1], "split": ["train", "train"]}
|
||||
),
|
||||
split="train",
|
||||
)
|
||||
|
||||
test = datasets.Dataset(
|
||||
pa.table(
|
||||
{"text": ["fizz", "buzz"], "label": [0, 1], "split": ["test", "test"]}
|
||||
),
|
||||
split="test",
|
||||
)
|
||||
return datasets.DatasetDict({"train": train, "test": test})
|
||||
|
||||
|
||||
def test_write_hf_dataset(tmp_path: Path, mock_embedding_function, mock_hf_dataset):
|
||||
db = lancedb.connect(tmp_path)
|
||||
emb = get_registry().get("random").create()
|
||||
|
||||
class Schema(LanceModel):
|
||||
text: str = emb.SourceField()
|
||||
label: int
|
||||
vector: Vector(emb.ndims()) = emb.VectorField()
|
||||
|
||||
train_table = db.create_table("train", schema=Schema)
|
||||
train_table.add(mock_hf_dataset["train"])
|
||||
|
||||
class WithSplit(LanceModel):
|
||||
text: str = emb.SourceField()
|
||||
label: int
|
||||
vector: Vector(emb.ndims()) = emb.VectorField()
|
||||
split: str
|
||||
|
||||
full_table = db.create_table("full", schema=WithSplit)
|
||||
full_table.add(mock_hf_dataset)
|
||||
|
||||
assert len(train_table) == mock_hf_dataset["train"].num_rows
|
||||
assert len(full_table) == sum(ds.num_rows for ds in mock_hf_dataset.values())
|
||||
|
||||
rt_train_table = full_table.to_lance().to_table(
|
||||
columns=["text", "label"], filter="split='train'"
|
||||
)
|
||||
assert rt_train_table.to_pylist() == mock_hf_dataset["train"].data.to_pylist()
|
||||
|
||||
|
||||
def test_bad_hf_dataset(tmp_path: Path, mock_embedding_function, hf_dataset_with_split):
|
||||
db = lancedb.connect(tmp_path)
|
||||
emb = get_registry().get("random").create()
|
||||
|
||||
class Schema(LanceModel):
|
||||
text: str = emb.SourceField()
|
||||
label: int
|
||||
vector: Vector(emb.ndims()) = emb.VectorField()
|
||||
split: str
|
||||
|
||||
train_table = db.create_table("train", schema=Schema)
|
||||
# this should still work because we don't add the split column
|
||||
# if it already exists
|
||||
train_table.add(hf_dataset_with_split)
|
||||
Reference in New Issue
Block a user