#!/usr/bin/env python
#
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataset hf://poloclub/diffusiondb
|
|
"""
|
|
|
|
import io
from argparse import ArgumentParser
from multiprocessing import Pool

import lance
import lancedb
import pyarrow as pa
from datasets import load_dataset
from PIL import Image
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast

MODEL_ID = "openai/clip-vit-base-patch32"

device = "cuda"

# Load the CLIP tokenizer, model, and processor once at module import time.
tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
model = CLIPModel.from_pretrained(MODEL_ID).to(device)
processor = CLIPProcessor.from_pretrained(MODEL_ID)

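# The "vector" field below is a fixed-size list of 512 floats, matching the
# embedding size produced by CLIP ViT-B/32 via `model.get_image_features`.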
schema = pa.schema(
    [
        pa.field("prompt", pa.string()),
        pa.field("seed", pa.uint32()),
        pa.field("step", pa.uint16()),
        pa.field("cfg", pa.float32()),
        pa.field("sampler", pa.string()),
        pa.field("width", pa.uint16()),
        pa.field("height", pa.uint16()),
        pa.field("timestamp", pa.timestamp("s")),
        pa.field("image_nsfw", pa.float32()),
        pa.field("prompt_nsfw", pa.float32()),
        pa.field("vector", pa.list_(pa.float32(), 512)),
        pa.field("image", pa.binary()),
    ]
)


def pil_to_bytes(img) -> bytes:
    """Serialize a PIL image into PNG-encoded bytes."""
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()


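# With `datasets.map(batched=True)`, `batch` is a dict mapping column names to
# lists of values; assigning to a new key (here "vector" and "image_bytes")
# adds a new column to the mapped dataset.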
def generate_clip_embeddings(batch) -> dict:
    """Compute CLIP image embeddings and PNG bytes for a batch of images."""
    image = processor(text=None, images=batch["image"], return_tensors="pt")[
        "pixel_values"
    ].to(device)
    img_emb = model.get_image_features(image)
    batch["vector"] = img_emb.cpu().tolist()

    # Encode the PIL images to PNG bytes in parallel.
    with Pool() as p:
        batch["image_bytes"] = p.map(pil_to_bytes, batch["image"])
    return batch


def datagen(args):
    """Generate the DiffusionDB dataset and use the CLIP model to create image embeddings."""
    dataset = load_dataset("poloclub/diffusiondb", args.subset)
    data = []
    for b in dataset.map(
        generate_clip_embeddings, batched=True, batch_size=256, remove_columns=["image"]
    )["train"]:
        # Store the PNG bytes under the "image" column name expected by the schema.
        b["image"] = b["image_bytes"]
        del b["image_bytes"]
        data.append(b)
    tbl = pa.Table.from_pylist(data, schema=schema)
    return tbl


def main():
    parser = ArgumentParser()
    parser.add_argument(
        "-o", "--output", metavar="DIR", help="Output lance directory", required=True
    )
    parser.add_argument(
        "-s",
        "--subset",
        choices=["2m_all", "2m_first_10k", "2m_first_100k"],
        default="2m_first_10k",
        help="subset of the HF dataset",
    )

    args = parser.parse_args()

    batches = datagen(args)
    lance.write_dataset(batches, args.output)


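# Example invocation (a sketch: assumes a CUDA-capable GPU and access to the
# poloclub/diffusiondb dataset on the Hugging Face Hub; the script filename is
# a placeholder):
#
#   python generate_diffusiondb.py -o ./diffusiondb.lance -s 2m_first_10k
#
# The written dataset can afterwards be opened with `lance.dataset(<output dir>)`.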
if __name__ == "__main__":
    main()