diff --git a/notebooks/multimodal_search.ipynb b/notebooks/multimodal_search.ipynb
new file mode 100644
index 00000000..6fa5d552
--- /dev/null
+++ b/notebooks/multimodal_search.ipynb
@@ -0,0 +1,240 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install --quiet -U lancedb\n",
+    "!pip install --quiet gradio transformers torch torchvision"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 60,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import io\n",
+    "import PIL.Image\n",
+    "import duckdb\n",
+    "import lancedb"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## First run setup: Download data and pre-process"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# remove null prompts\n",
+    "import lance\n",
+    "import pyarrow.compute as pc\n",
+    "\n",
+    "# download s3://eto-public/datasets/diffusiondb/small_10k.lance to this uri\n",
+    "data = lance.dataset(\"~/datasets/rawdata.lance\").to_table()\n",
+    "\n",
+    "# Initial data processing and full-text-search index\n",
+    "db = lancedb.connect(\"~/datasets/demo\")\n",
+    "tbl = db.create_table(\"diffusiondb\", data.filter(~pc.field(\"prompt\").is_null()))\n",
+    "tbl.create_fts_index([\"prompt\"])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Create / Open LanceDB Table"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 62,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "db = lancedb.connect(\"~/datasets/demo\")\n",
+    "tbl = db.open_table(\"diffusiondb\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Create CLIP embedding function for the text"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 63,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast\n",
+    "\n",
+    "MODEL_ID = \"openai/clip-vit-base-patch32\"\n",
+    "\n",
+    "tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)\n",
+    "model = CLIPModel.from_pretrained(MODEL_ID)\n",
+    "processor = CLIPProcessor.from_pretrained(MODEL_ID)\n",
+    "\n",
+    "def embed_func(query):\n",
+    "    # Project the text query into CLIP's shared text/image embedding space\n",
+    "    inputs = tokenizer([query], padding=True, return_tensors=\"pt\")\n",
+    "    text_features = model.get_text_features(**inputs)\n",
+    "    return text_features.detach().numpy()[0]"
+   ]
+  },
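+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick sanity check (a minimal sketch added for illustration, not part of the original demo), we can embed one query and confirm the vector has the 512 dimensions expected of CLIP ViT-B/32 text features:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative smoke test: embed a sample query and inspect the result\n",
+    "vec = embed_func(\"portraits of a person\")\n",
+    "print(vec.shape)  # expected: (512,) for openai/clip-vit-base-patch32"
+   ]
+  },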
"execution_count": 64, + "metadata": {}, + "outputs": [], + "source": [ + "def find_image_vectors(query):\n", + " emb = embed_func(query)\n", + " return _extract(tbl.search(emb).limit(9).to_df())\n", + "\n", + "def find_image_keywords(query):\n", + " return _extract(tbl.search(query).limit(9).to_df())\n", + "\n", + "def find_image_sql(query):\n", + " diffusiondb = tbl.to_lance()\n", + " return _extract(duckdb.query(query).to_df())\n", + "\n", + "def _extract(df):\n", + " image_col = \"image\"\n", + " return [(PIL.Image.open(io.BytesIO(row[image_col])), row[\"prompt\"]) for _, row in df.iterrows()]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup Gradio interface" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running on local URL: http://127.0.0.1:7867\n", + "\n", + "To create a public link, set `share=True` in `launch()`.\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 65, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import gradio as gr\n", + "\n", + "\n", + "with gr.Blocks() as demo:\n", + "\n", + " with gr.Row():\n", + " with gr.Tab(\"Embeddings\"):\n", + " vector_query = gr.Textbox(value=\"portraits of a person\", show_label=False)\n", + " b1 = gr.Button(\"Submit\")\n", + " with gr.Tab(\"Keywords\"):\n", + " keyword_query = gr.Textbox(value=\"ninja turtle\", show_label=False)\n", + " b2 = gr.Button(\"Submit\")\n", + " with gr.Tab(\"SQL\"):\n", + " sql_query = gr.Textbox(value=\"SELECT * from diffusiondb WHERE image_nsfw >= 2 LIMIT 9\", show_label=False)\n", + " b3 = gr.Button(\"Submit\")\n", + " with gr.Row():\n", + " gallery = gr.Gallery(\n", + " label=\"Found images\", show_label=False, elem_id=\"gallery\"\n", + " ).style(columns=[3], rows=[3], object_fit=\"contain\", height=\"auto\") \n", + " \n", + " b1.click(find_image_vectors, inputs=vector_query, outputs=gallery)\n", + " b2.click(find_image_keywords, inputs=keyword_query, outputs=gallery)\n", + " b3.click(find_image_sql, inputs=sql_query, outputs=gallery)\n", + " \n", + "demo.launch()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +}