mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 22:59:57 +00:00
Slow roasted over 12 hours, Pairs well with #111 --------- Co-authored-by: Chang She <chang@lancedb.com>
241 lines
6.8 KiB
Plaintext
241 lines
6.8 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1.2\u001b[0m\n",
|
|
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1.2\u001b[0m\n",
|
|
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"!pip install --quiet -U lancedb\n",
|
|
"!pip install --quiet gradio transformers torch torchvision"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 60,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import io\n",
|
|
"import PIL\n",
|
|
"import duckdb\n",
|
|
"import lancedb"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## First-run setup: download and pre-process the data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 30,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"<lance.dataset.LanceDataset at 0x3045db590>"
|
|
]
|
|
},
|
|
"execution_count": 30,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# One-time preprocessing: remove rows with null prompts\n",
|
|
"import lance\n",
|
|
"import pyarrow.compute as pc\n",
|
|
"\n",
|
|
"# Prerequisite: download s3://eto-public/datasets/diffusiondb/small_10k.lance to ~/datasets/rawdata.lance\n",
|
|
"data = lance.dataset(\"~/datasets/rawdata.lance\").to_table()\n",
|
|
"\n",
|
|
"# First-time data processing: create the table from the non-null prompts and build a full-text-search index on the prompt column\n",
|
|
"db = lancedb.connect(\"~/datasets/demo\")\n",
|
|
"tbl = db.create_table(\"diffusiondb\", data.filter(~pc.field(\"prompt\").is_null()))\n",
|
|
"tbl = tbl.create_fts_index([\"prompt\"])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Create / Open LanceDB Table"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 62,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"db = lancedb.connect(\"~/datasets/demo\")\n",
|
|
"tbl = db.open_table(\"diffusiondb\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Create a CLIP embedding function for text queries"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 63,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast\n",
|
|
"\n",
|
|
"MODEL_ID = \"openai/clip-vit-base-patch32\"\n",
|
|
"\n",
|
|
"tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)\n",
|
|
"model = CLIPModel.from_pretrained(MODEL_ID)\n",
|
|
"processor = CLIPProcessor.from_pretrained(MODEL_ID)\n",
|
|
"\n",
|
|
"def embed_func(query):\n",
|
|
" inputs = tokenizer([query], padding=True, return_tensors=\"pt\")\n",
|
|
" text_features = model.get_text_features(**inputs)\n",
|
|
" return text_features.detach().numpy()[0]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Search functions for Gradio"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 64,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def find_image_vectors(query):\n",
|
|
" emb = embed_func(query)\n",
|
|
" return _extract(tbl.search(emb).limit(9).to_df())\n",
|
|
"\n",
|
|
"def find_image_keywords(query):\n",
|
|
" return _extract(tbl.search(query).limit(9).to_df())\n",
|
|
"\n",
|
|
"def find_image_sql(query):\n",
|
|
" diffusiondb = tbl.to_lance()\n",
|
|
" return _extract(duckdb.query(query).to_df())\n",
|
|
"\n",
|
|
"def _extract(df):\n",
|
|
" image_col = \"image\"\n",
|
|
" return [(PIL.Image.open(io.BytesIO(row[image_col])), row[\"prompt\"]) for _, row in df.iterrows()]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Setup Gradio interface"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 65,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Running on local URL: http://127.0.0.1:7867\n",
|
|
"\n",
|
|
"To create a public link, set `share=True` in `launch()`.\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div><iframe src=\"http://127.0.0.1:7867/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": []
|
|
},
|
|
"execution_count": 65,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"import gradio as gr\n",
|
|
"\n",
|
|
"\n",
|
|
"with gr.Blocks() as demo:\n",
|
|
"\n",
|
|
" with gr.Row():\n",
|
|
" with gr.Tab(\"Embeddings\"):\n",
|
|
" vector_query = gr.Textbox(value=\"portraits of a person\", show_label=False)\n",
|
|
" b1 = gr.Button(\"Submit\")\n",
|
|
" with gr.Tab(\"Keywords\"):\n",
|
|
" keyword_query = gr.Textbox(value=\"ninja turtle\", show_label=False)\n",
|
|
" b2 = gr.Button(\"Submit\")\n",
|
|
" with gr.Tab(\"SQL\"):\n",
|
|
" sql_query = gr.Textbox(value=\"SELECT * from diffusiondb WHERE image_nsfw >= 2 LIMIT 9\", show_label=False)\n",
|
|
" b3 = gr.Button(\"Submit\")\n",
|
|
" with gr.Row():\n",
|
|
" gallery = gr.Gallery(\n",
|
|
" label=\"Found images\", show_label=False, elem_id=\"gallery\"\n",
|
|
" ).style(columns=[3], rows=[3], object_fit=\"contain\", height=\"auto\") \n",
|
|
" \n",
|
|
" b1.click(find_image_vectors, inputs=vector_query, outputs=gallery)\n",
|
|
" b2.click(find_image_keywords, inputs=keyword_query, outputs=gallery)\n",
|
|
" b3.click(find_image_sql, inputs=sql_query, outputs=gallery)\n",
|
|
" \n",
|
|
"demo.launch()"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 1
|
|
}
|