From 4466cfa95875b03e78db6904c82ebc47b56d0251 Mon Sep 17 00:00:00 2001
From: Chang She <759245+changhiskhan@users.noreply.github.com>
Date: Wed, 20 Mar 2024 00:22:03 -0700
Subject: [PATCH] feat(python): support writing huggingface dataset and dataset dict (#1110)

A HuggingFace Dataset is written as Arrow batches. For a DatasetDict, all
splits are written with a "split" column appended.

- [x] handle the case where the dataset schema already has a `split` column
- [x] add unit tests
---
 python/python/lancedb/table.py          |  66 +++++++++++--
 python/python/tests/test_huggingface.py | 126 ++++++++++++++++++++++++
 2 files changed, 186 insertions(+), 6 deletions(-)
 create mode 100644 python/python/tests/test_huggingface.py

diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py
index 3501ae60..9e5227d5 100644
--- a/python/python/lancedb/table.py
+++ b/python/python/lancedb/table.py
@@ -37,6 +37,7 @@ import pyarrow as pa
 import pyarrow.compute as pc
 import pyarrow.fs as pa_fs
 from lance import LanceDataset
+from lance.dependencies import _check_for_hugging_face
 from lance.vector import vec_to_table
 
 from .common import DATA, VEC, VECTOR_COLUMN_NAME
@@ -73,6 +74,27 @@ def _sanitize_data(
     on_bad_vectors: str,
     fill_value: Any,
 ):
+    if _check_for_hugging_face(data):
+        # Huggingface datasets
+        from lance.dependencies import datasets
+
+        if isinstance(data, datasets.dataset_dict.DatasetDict):
+            if schema is None:
+                schema = _schema_from_hf(data, schema)
+            data = _to_record_batch_generator(
+                _to_batches_with_split(data),
+                schema,
+                metadata,
+                on_bad_vectors,
+                fill_value,
+            )
+        elif isinstance(data, datasets.Dataset):
+            if schema is None:
+                schema = data.features.arrow_schema
+            data = _to_record_batch_generator(
+                data.data.to_batches(), schema, metadata, on_bad_vectors, fill_value
+            )
+
     if isinstance(data, list):
         # convert to list of dict if data is a bunch of LanceModels
         if isinstance(data[0], LanceModel):
@@ -109,6 +131,37 @@ def _sanitize_data(
     return data
 
 
+def _schema_from_hf(data, schema):
+    """
+    Extract the pyarrow schema from a HuggingFace DatasetDict
+    and validate that all splits in the DatasetDict
+    share the same schema
+    """
+    for dataset in data.values():
+        if schema is None:
+            schema = dataset.features.arrow_schema
+        elif schema != dataset.features.arrow_schema:
+            msg = "All datasets in a HuggingFace DatasetDict must have the same schema"
+            raise TypeError(msg)
+    return schema
+
+
+def _to_batches_with_split(data):
+    """
+    Return a generator of RecordBatches from a HuggingFace DatasetDict
+    with an extra `split` column
+    """
+    for key, dataset in data.items():
+        for batch in dataset.data.to_batches():
+            table = pa.Table.from_batches([batch])
+            if "split" not in table.column_names:
+                table = table.append_column(
+                    "split", pa.array([key] * batch.num_rows, pa.string())
+                )
+            for b in table.to_batches():
+                yield b
+
+
 def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schema]):
     """
     Use the embedding function to automatically embed the source column and add the
@@ -143,12 +196,13 @@ def _to_record_batch_generator(
     data: Iterable, schema, metadata, on_bad_vectors, fill_value
 ):
     for batch in data:
-        if not isinstance(batch, pa.RecordBatch):
-            table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
-            for batch in table.to_batches():
-                yield batch
-        else:
-            yield batch
+        # always convert to a table because we need to sanitize the data
+        # and do things like add the vector column, etc.
+        if isinstance(batch, pa.RecordBatch):
+            batch = pa.Table.from_batches([batch])
+        batch = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
+        for b in batch.to_batches():
+            yield b
 
 
 class Table(ABC):
diff --git a/python/python/tests/test_huggingface.py b/python/python/tests/test_huggingface.py
new file mode 100644
index 00000000..b5377149
--- /dev/null
+++ b/python/python/tests/test_huggingface.py
@@ -0,0 +1,126 @@
+# Copyright 2024 Lance Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+
+import lancedb
+import numpy as np
+import pyarrow as pa
+import pytest
+from lancedb.embeddings import get_registry
+from lancedb.embeddings.base import TextEmbeddingFunction
+from lancedb.embeddings.registry import register
+from lancedb.pydantic import LanceModel, Vector
+
+datasets = pytest.importorskip("datasets")
+
+
+@pytest.fixture(scope="session")
+def mock_embedding_function():
+    @register("random")
+    class MockTextEmbeddingFunction(TextEmbeddingFunction):
+        def generate_embeddings(self, texts):
+            return [np.random.randn(128).tolist() for _ in range(len(texts))]
+
+        def ndims(self):
+            return 128
+
+
+@pytest.fixture
+def mock_hf_dataset():
+    # Create pyarrow table with `text` and `label` columns
+    train = datasets.Dataset(
+        pa.table(
+            {
+                "text": ["foo", "bar"],
+                "label": [0, 1],
+            }
+        ),
+        split="train",
+    )
+
+    test = datasets.Dataset(
+        pa.table(
+            {
+                "text": ["fizz", "buzz"],
+                "label": [0, 1],
+            }
+        ),
+        split="test",
+    )
+    return datasets.DatasetDict({"train": train, "test": test})
+
+
+@pytest.fixture
+def hf_dataset_with_split():
+    # Create pyarrow table with `text`, `label`, and `split` columns
+    train = datasets.Dataset(
+        pa.table(
+            {"text": ["foo", "bar"], "label": [0, 1], "split": ["train", "train"]}
+        ),
+        split="train",
+    )
+
+    test = datasets.Dataset(
+        pa.table(
+            {"text": ["fizz", "buzz"], "label": [0, 1], "split": ["test", "test"]}
+        ),
+        split="test",
+    )
+    return datasets.DatasetDict({"train": train, "test": test})
+
+
+def test_write_hf_dataset(tmp_path: Path, mock_embedding_function, mock_hf_dataset):
+    db = lancedb.connect(tmp_path)
+    emb = get_registry().get("random").create()
+
+    class Schema(LanceModel):
+        text: str = emb.SourceField()
+        label: int
+        vector: Vector(emb.ndims()) = emb.VectorField()
+
+    train_table = db.create_table("train", schema=Schema)
+    train_table.add(mock_hf_dataset["train"])
+
+    class WithSplit(LanceModel):
+        text: str = emb.SourceField()
+        label: int
+        vector: Vector(emb.ndims()) = emb.VectorField()
+        split: str
+
+    full_table = db.create_table("full", schema=WithSplit)
+    full_table.add(mock_hf_dataset)
+
+    assert len(train_table) == mock_hf_dataset["train"].num_rows
+    assert len(full_table) == sum(ds.num_rows for ds in mock_hf_dataset.values())
+
+    rt_train_table = full_table.to_lance().to_table(
+        columns=["text", "label"], filter="split='train'"
+    )
+    assert rt_train_table.to_pylist() == mock_hf_dataset["train"].data.to_pylist()
+
+
+def test_bad_hf_dataset(tmp_path: Path, mock_embedding_function, hf_dataset_with_split):
+    db = lancedb.connect(tmp_path)
+    emb = get_registry().get("random").create()
+
+    class Schema(LanceModel):
+        text: str = emb.SourceField()
+        label: int
+        vector: Vector(emb.ndims()) = emb.VectorField()
+        split: str
+
+    train_table = db.create_table("train", schema=Schema)
+    # this should still work because we don't add the split column
+    # if it already exists
+    train_table.add(hf_dataset_with_split)
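
A minimal sketch of how the new path can be exercised end to end, for context. The local database path (`./hf_demo_db`) and table names (`hf_train`, `hf_full`) are illustrative only; the `split` filter mirrors the one used in the unit test above.

```python
import datasets
import lancedb
import pyarrow as pa

# Build a tiny DatasetDict; any HuggingFace dataset works the same way.
train = datasets.Dataset(
    pa.table({"text": ["foo", "bar"], "label": [0, 1]}), split="train"
)
test = datasets.Dataset(
    pa.table({"text": ["fizz", "buzz"], "label": [0, 1]}), split="test"
)
dataset_dict = datasets.DatasetDict({"train": train, "test": test})

db = lancedb.connect("./hf_demo_db")  # illustrative local path

# A single Dataset is written as Arrow batches, with no extra column.
train_table = db.create_table(
    "hf_train", schema=pa.schema([("text", pa.string()), ("label", pa.int64())])
)
train_table.add(train)

# A DatasetDict gets a "split" column appended, so the table schema includes it.
full_table = db.create_table(
    "hf_full",
    schema=pa.schema(
        [("text", pa.string()), ("label", pa.int64()), ("split", pa.string())]
    ),
)
full_table.add(dataset_dict)

# Rows can then be filtered by split, as the unit test does.
rows = full_table.to_lance().to_table(
    columns=["text", "label"], filter="split='test'"
)
print(len(train_table), len(full_table), rows.to_pylist())
```

A DatasetDict whose splits already carry a `split` column is written as-is; the existing column is left untouched, which is the case `test_bad_hf_dataset` covers.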