Files
lancedb/notebooks/multimodal_search.ipynb
Chang She 03013a4434 Multimodal search demo (#118)
Slow roasted over 12 hours, Pairs well with #111

---------

Co-authored-by: Chang She <chang@lancedb.com>
2023-06-01 10:34:08 -07:00

241 lines
6.8 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1.2\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1.2\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
]
}
],
"source": [
"!pip install --quiet -U lancedb\n",
"!pip install --quiet gradio transformers torch torchvision"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"import io\n",
"import PIL\n",
"import duckdb\n",
"import lancedb"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## First run setup: Download data and pre-process"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<lance.dataset.LanceDataset at 0x3045db590>"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# remove null prompts\n",
"import lance\n",
"import pyarrow.compute as pc\n",
"\n",
"# download s3://eto-public/datasets/diffusiondb/small_10k.lance to this uri\n",
"data = lance.dataset(\"~/datasets/rawdata.lance\").to_table()\n",
"\n",
"# First data processing and full-text-search index\n",
"db = lancedb.connect(\"~/datasets/demo\")\n",
"tbl = db.create_table(\"diffusiondb\", data.filter(~pc.field(\"prompt\").is_null()))\n",
"tbl = tbl.create_fts_index([\"prompt\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create / Open LanceDB Table"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"db = lancedb.connect(\"~/datasets/demo\")\n",
"tbl = db.open_table(\"diffusiondb\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create CLIP embedding function for the text"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast\n",
"\n",
"MODEL_ID = \"openai/clip-vit-base-patch32\"\n",
"\n",
"tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)\n",
"model = CLIPModel.from_pretrained(MODEL_ID)\n",
"processor = CLIPProcessor.from_pretrained(MODEL_ID)\n",
"\n",
"def embed_func(query):\n",
" inputs = tokenizer([query], padding=True, return_tensors=\"pt\")\n",
" text_features = model.get_text_features(**inputs)\n",
" return text_features.detach().numpy()[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Search functions for Gradio"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"def find_image_vectors(query):\n",
" emb = embed_func(query)\n",
" return _extract(tbl.search(emb).limit(9).to_df())\n",
"\n",
"def find_image_keywords(query):\n",
" return _extract(tbl.search(query).limit(9).to_df())\n",
"\n",
"def find_image_sql(query):\n",
" diffusiondb = tbl.to_lance()\n",
" return _extract(duckdb.query(query).to_df())\n",
"\n",
"def _extract(df):\n",
" image_col = \"image\"\n",
" return [(PIL.Image.open(io.BytesIO(row[image_col])), row[\"prompt\"]) for _, row in df.iterrows()]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup Gradio interface"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running on local URL: http://127.0.0.1:7867\n",
"\n",
"To create a public link, set `share=True` in `launch()`.\n"
]
},
{
"data": {
"text/html": [
"<div><iframe src=\"http://127.0.0.1:7867/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": []
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import gradio as gr\n",
"\n",
"\n",
"with gr.Blocks() as demo:\n",
"\n",
" with gr.Row():\n",
" with gr.Tab(\"Embeddings\"):\n",
" vector_query = gr.Textbox(value=\"portraits of a person\", show_label=False)\n",
" b1 = gr.Button(\"Submit\")\n",
" with gr.Tab(\"Keywords\"):\n",
" keyword_query = gr.Textbox(value=\"ninja turtle\", show_label=False)\n",
" b2 = gr.Button(\"Submit\")\n",
" with gr.Tab(\"SQL\"):\n",
" sql_query = gr.Textbox(value=\"SELECT * from diffusiondb WHERE image_nsfw >= 2 LIMIT 9\", show_label=False)\n",
" b3 = gr.Button(\"Submit\")\n",
" with gr.Row():\n",
" gallery = gr.Gallery(\n",
" label=\"Found images\", show_label=False, elem_id=\"gallery\"\n",
" ).style(columns=[3], rows=[3], object_fit=\"contain\", height=\"auto\") \n",
" \n",
" b1.click(find_image_vectors, inputs=vector_query, outputs=gallery)\n",
" b2.click(find_image_keywords, inputs=keyword_query, outputs=gallery)\n",
" b3.click(find_image_sql, inputs=sql_query, outputs=gallery)\n",
" \n",
"demo.launch()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 1
}