From 99cbda8b07a73f302ee14e21345f1e260ebb868a Mon Sep 17 00:00:00 2001
From: Lei Xu
Date: Thu, 1 Jun 2023 10:23:29 -0700
Subject: [PATCH] Generate diffusiondb embeddings (#111)

---
 notebooks/diffusiondb/datagen.py       | 108 +++++++++++++++++++++++++
 notebooks/diffusiondb/requirements.txt |   9 +++
 2 files changed, 117 insertions(+)
 create mode 100755 notebooks/diffusiondb/datagen.py
 create mode 100644 notebooks/diffusiondb/requirements.txt

diff --git a/notebooks/diffusiondb/datagen.py b/notebooks/diffusiondb/datagen.py
new file mode 100755
index 00000000..493b81d8
--- /dev/null
+++ b/notebooks/diffusiondb/datagen.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python
+#
+# Copyright 2023 LanceDB Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Generate CLIP image embeddings for hf://poloclub/diffusiondb."""
+
+import io
+from argparse import ArgumentParser
+from multiprocessing import Pool
+
+import lance
+import pyarrow as pa
+import torch
+from datasets import load_dataset
+from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast
+
+MODEL_ID = "openai/clip-vit-base-patch32"
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
+model = CLIPModel.from_pretrained(MODEL_ID).to(device)
+processor = CLIPProcessor.from_pretrained(MODEL_ID)
+
+schema = pa.schema(
+    [
+        pa.field("prompt", pa.string()),
+        pa.field("seed", pa.uint32()),
+        pa.field("step", pa.uint16()),
+        pa.field("cfg", pa.float32()),
+        pa.field("sampler", pa.string()),
+        pa.field("width", pa.uint16()),
+        pa.field("height", pa.uint16()),
+        pa.field("timestamp", pa.timestamp("s")),
+        pa.field("image_nsfw", pa.float32()),
+        pa.field("prompt_nsfw", pa.float32()),
+        pa.field("vector", pa.list_(pa.float32(), 512)),
+        pa.field("image", pa.binary()),
+    ]
+)
+
+
+def pil_to_bytes(img) -> bytes:
+    """Serialize a PIL image to PNG-encoded bytes."""
+    buf = io.BytesIO()
+    img.save(buf, format="PNG")
+    return buf.getvalue()
+
+
+def generate_clip_embeddings(batch) -> dict:
+    """Add CLIP image embeddings and PNG-encoded images to a batch."""
+    images = processor(text=None, images=batch["image"], return_tensors="pt")[
+        "pixel_values"
+    ].to(device)
+    with torch.no_grad():
+        img_emb = model.get_image_features(images)
+    batch["vector"] = img_emb.cpu().tolist()
+
+    # PNG encoding is CPU-bound; fan it out across worker processes.
+    with Pool() as p:
+        batch["image_bytes"] = p.map(pil_to_bytes, batch["image"])
+    return batch
+
+
+def datagen(args) -> pa.Table:
+    """Generate the DiffusionDB subset with CLIP image embeddings."""
+    dataset = load_dataset("poloclub/diffusiondb", args.subset)
+    data = []
+    for b in dataset.map(
+        generate_clip_embeddings, batched=True, batch_size=256, remove_columns=["image"]
+    )["train"]:
+        b["image"] = b["image_bytes"]
+        del b["image_bytes"]
+        data.append(b)
+    return pa.Table.from_pylist(data, schema=schema)
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "-o", "--output", metavar="DIR", help="Output lance directory", required=True
+    )
+    parser.add_argument(
+        "-s",
+        "--subset",
+        choices=["2m_all", "2m_first_10k", "2m_first_100k"],
+        default="2m_first_10k",
+        help="subset of the HF dataset",
+    )
+    args = parser.parse_args()
+
+    table = datagen(args)
+    lance.write_dataset(table, args.output)
+
+
+if __name__ == "__main__":
+    main()

diff --git a/notebooks/diffusiondb/requirements.txt b/notebooks/diffusiondb/requirements.txt
new file mode 100644
index 00000000..4baee306
--- /dev/null
+++ b/notebooks/diffusiondb/requirements.txt
@@ -0,0 +1,9 @@
+datasets
+Pillow
+lancedb
+isort
+black
+transformers
+--extra-index-url https://download.pytorch.org/whl/cu118
+torch
+torchvision
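
For anyone trying the script out, a minimal sketch of querying the generated
embeddings with a CLIP text vector, assuming the script was run as
`./datagen.py -o ./diffusiondb.lance`. The output path and the query string
are illustrative only, and the nearest-neighbor call assumes pylance's
Dataset.to_table(nearest=...) API; this is not part of the patch itself.

    import lance
    import torch
    from transformers import CLIPModel, CLIPTokenizerFast

    MODEL_ID = "openai/clip-vit-base-patch32"
    tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
    model = CLIPModel.from_pretrained(MODEL_ID)

    # Embed the text query with the same CLIP checkpoint used for the images,
    # so text and image vectors live in the same 512-d space.
    with torch.no_grad():
        tokens = tokenizer(["a watercolor mountain landscape"], return_tensors="pt")
        query = model.get_text_features(**tokens)[0].numpy()

    # Nearest-neighbor scan over the stored image vectors (top 5 hits).
    ds = lance.dataset("./diffusiondb.lance")
    hits = ds.to_table(nearest={"column": "vector", "q": query, "k": 5})
    print(hits.column("prompt").to_pylist())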