Add tutorial notebook

Convert contextualization and embeddings functionality.
And use it with converted notebook for video search
This commit is contained in:
Chang She
2023-03-23 15:07:58 -07:00
parent 98606b4621
commit b91139d3c7
12 changed files with 660 additions and 6 deletions

2
.gitignore vendored
View File

@@ -13,3 +13,5 @@ site
python/build
python/dist
notebooks/.ipynb_checkpoints

View File

@@ -45,5 +45,5 @@ db = lancedb.connect(uri)
table = db.create_table("my_table",
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}])
result = table.search([100, 100]).where("price < 15").limit(1).to_df()
result = table.search([100, 100]).limit(2).to_df()
```

View File

@@ -7,6 +7,7 @@ theme:
plugins:
- search
- mkdocstrings
- mkdocs-jupyter
nav:
- Home: index.md

View File

@@ -1,3 +1,4 @@
mkdocs==1.4.2
mkdocs-jupyter==0.24.1
mkdocs-material==9.1.3
mkdocstrings[python]==0.20.0

View File

@@ -14,7 +14,26 @@ The key features of LanceDB include:
LanceDB's core is written in Rust 🦀 and is built using Lance, an open-source columnar format designed for performant ML workloads.
## Installation
```shell
pip install lancedb
```
## Quickstart
```python
import lancedb
db = lancedb.connect(".")
table = db.create_table("my_table",
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}])
result = table.search([100, 100]).limit(2).to_df()
```
## Documentation Quick Links
* `Quick start` - search and filter a hello world vector dataset with LanceDB using the Python SDK.
* [`API Reference`](python.md) - detailed documentation for the LanceDB Python SDK.

View File

@@ -0,0 +1,418 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "42bf01fb",
"metadata": {},
"source": [
"# We're going to build question and answer bot\n",
"\n",
"That allow you to search through youtube transcripts using natural language"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "48547ddb",
"metadata": {},
"outputs": [],
"source": [
"pip install --quiet openai datasets lancedb"
]
},
{
"cell_type": "markdown",
"id": "22e570f4",
"metadata": {},
"source": [
"## Download the data\n",
"700 videos and 208619 sentences"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a8987fcb",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found cached dataset json (/Users/changshe/.cache/huggingface/datasets/jamescalam___json/jamescalam--youtube-transcriptions-08d889f6a5386b9b/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)\n"
]
},
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['title', 'published', 'url', 'video_id', 'channel_id', 'id', 'text', 'start', 'end'],\n",
" num_rows: 208619\n",
"})"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from datasets import load_dataset\n",
"\n",
"data = load_dataset('jamescalam/youtube-transcriptions', split='train')\n",
"data"
]
},
{
"cell_type": "markdown",
"id": "5ac2b6a3",
"metadata": {},
"source": [
"## Prepare context\n",
"\n",
"Create context of 20 sentences"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "121a7087",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>title</th>\n",
" <th>published</th>\n",
" <th>url</th>\n",
" <th>video_id</th>\n",
" <th>channel_id</th>\n",
" <th>id</th>\n",
" <th>text</th>\n",
" <th>start</th>\n",
" <th>end</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>177622</th>\n",
" <td>$5 MILLION AI for FREE</td>\n",
" <td>2022-08-12 15:18:07</td>\n",
" <td>https://youtu.be/3EjtHs_lXnk</td>\n",
" <td>3EjtHs_lXnk</td>\n",
" <td>UCfzlCWGWYyIQ0aLC5w48gBQ</td>\n",
" <td>3EjtHs_lXnk-t0.0</td>\n",
" <td>Imagine an AI where all in the same model you ...</td>\n",
" <td>0.0</td>\n",
" <td>24.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" title published \\\n",
"177622 $5 MILLION AI for FREE 2022-08-12 15:18:07 \n",
"\n",
" url video_id channel_id \\\n",
"177622 https://youtu.be/3EjtHs_lXnk 3EjtHs_lXnk UCfzlCWGWYyIQ0aLC5w48gBQ \n",
"\n",
" id text \\\n",
"177622 3EjtHs_lXnk-t0.0 Imagine an AI where all in the same model you ... \n",
"\n",
" start end \n",
"177622 0.0 24.0 "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from lancedb.context import contextualize\n",
"\n",
"df = (contextualize(data.to_pandas())\n",
" .groupby(\"title\").text_col(\"text\")\n",
" .window(20).stride(4)\n",
" .to_df())\n",
"df.head(1)"
]
},
{
"cell_type": "markdown",
"id": "3044e0b0",
"metadata": {},
"source": [
"## Create embedding function\n",
"We'll call the OpenAI embeddings API to get embeddings"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "8eefc159",
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"\n",
"# Configure environment variable OPENAI_API_KEY\n",
"# OR add variable openai.api_key = \"sk-...\"\n",
"\n",
"def embed_func(c): \n",
" rs = openai.Embedding.create(input=c, engine=\"text-embedding-ada-002\")\n",
" return [record[\"embedding\"] for record in rs[\"data\"]]"
]
},
{
"cell_type": "markdown",
"id": "2106b5bb",
"metadata": {},
"source": [
"## Create the LanceDB Table"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "13f15068",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Building vector index: IVF64,OPQ96, metric=l2\n"
]
},
{
"data": {
"text/plain": [
"<lance.dataset.LanceDataset at 0x13fd38dc0>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sample 16384 out of 48935 to train kmeans of 1536 dim, 64 clusters\n"
]
}
],
"source": [
"import lancedb\n",
"from lancedb.embeddings import with_embeddings\n",
"\n",
"data = with_embeddings(embed_func, df, show_progress=True)\n",
"\n",
"db = lancedb.connect(\"/tmp/lancedb\") # current directory\n",
"tbl = db.create_table(\"chatbot\", data)\n",
"tbl.create_index(num_partitions=64, num_sub_vectors=96)"
]
},
{
"cell_type": "markdown",
"id": "23afc2f9",
"metadata": {},
"source": [
"## Create and answer the prompt"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "06d8b867",
"metadata": {},
"outputs": [],
"source": [
"def create_prompt(query, context):\n",
" limit = 3750\n",
"\n",
" prompt_start = (\n",
" \"Answer the question based on the context below.\\n\\n\"+\n",
" \"Context:\\n\"\n",
" )\n",
" prompt_end = (\n",
" f\"\\n\\nQuestion: {query}\\nAnswer:\"\n",
" )\n",
" # append contexts until hitting limit\n",
" for i in range(1, len(context)):\n",
" if len(\"\\n\\n---\\n\\n\".join(context.text[:i])) >= limit:\n",
" prompt = (\n",
" prompt_start +\n",
" \"\\n\\n---\\n\\n\".join(context.text[:i-1]) +\n",
" prompt_end\n",
" )\n",
" break\n",
" elif i == len(context)-1:\n",
" prompt = (\n",
" prompt_start +\n",
" \"\\n\\n---\\n\\n\".join(context.text) +\n",
" prompt_end\n",
" ) \n",
" return prompt"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "e09c5142",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'The 12th person on the moon was Harrison Schmitt, and he landed on December 11, 1972.'"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def complete(prompt):\n",
" # query text-davinci-003\n",
" res = openai.Completion.create(\n",
" engine='text-davinci-003',\n",
" prompt=prompt,\n",
" temperature=0,\n",
" max_tokens=400,\n",
" top_p=1,\n",
" frequency_penalty=0,\n",
" presence_penalty=0,\n",
" stop=None\n",
" )\n",
" return res['choices'][0]['text'].strip()\n",
"\n",
"# check that it works\n",
"query = \"who was the 12th person on the moon and when did they land?\"\n",
"complete(query)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "8fcef773",
"metadata": {},
"outputs": [],
"source": [
"def answer(question):\n",
" emb = embed_func(query)[0]\n",
" context = (tbl.search(emb).limit(3)\n",
" .nprobes(20).refine_factor(100)\n",
" .to_df())\n",
" prompt = create_prompt(question, context)\n",
" return complete(prompt), context.reset_index()"
]
},
{
"cell_type": "markdown",
"id": "28705959",
"metadata": {},
"source": [
"## Show the answer and show the video at the right place"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "25714299",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"NLI with multiple negative ranking loss.\n"
]
},
{
"data": {
"text/html": [
"\n",
" <iframe\n",
" width=\"400\"\n",
" height=\"300\"\n",
" src=\"https://www.youtube.com/embed/pNvujJ1XyeQ?start=289.76\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
" \n",
" ></iframe>\n",
" "
],
"text/plain": [
"<IPython.lib.display.YouTubeVideo at 0x12f58afb0>"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from IPython.display import YouTubeVideo\n",
"\n",
"query = (\"Which training method should I use for sentence transformers \"\n",
" \"when I only have pairs of related sentences?\")\n",
"completion, context = answer(query)\n",
"\n",
"print(completion)\n",
"top_match = context.iloc[0]\n",
"YouTubeVideo(top_match[\"url\"].split(\"/\")[-1], start=top_match[\"start\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "78b7eb11",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

61
python/lancedb/context.py Normal file
View File

@@ -0,0 +1,61 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pandas as pd
def contextualize(raw_df):
return Contextualizer(raw_df)
class Contextualizer:
def __init__(self, raw_df):
self._text_col = None
self._groupby = None
self._stride = None
self._window = None
self._raw_df = raw_df
def window(self, window):
self._window = window
return self
def stride(self, stride):
self._stride = stride
return self
def groupby(self, groupby):
self._groupby = groupby
return self
def text_col(self, text_col):
self._text_col = text_col
return self
def to_df(self):
def process_group(grp):
# For each video, create the text rolling window
text = grp[self._text_col].values
contexts = grp.iloc[: -self._window : self._stride, :].copy()
contexts[self._text_col] = [
" ".join(text[start_i : start_i + self._window])
for start_i in range(0, len(grp) - self._window, self._stride)
]
return contexts
if self._groupby is None:
return process_group(self._raw_df)
# concat result from all groups
return pd.concat(
[process_group(grp) for _, grp in self._raw_df.groupby(self._groupby)]
)

View File

@@ -29,6 +29,7 @@ class LanceDBConnection:
if isinstance(uri, str):
uri = Path(uri)
uri = uri.expanduser().absolute()
Path(uri).mkdir(parents=True, exist_ok=True)
self._uri = str(uri)
@property

View File

@@ -0,0 +1,105 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import ratelimiter
from retry import retry
from typing import Callable, Union
from lance.vector import vec_to_table
import numpy as np
import pandas as pd
import pyarrow as pa
def with_embeddings(
func: Callable,
data: Union[pa.Table, pd.DataFrame],
column: str = "text",
wrap_api: bool = True,
show_progress: bool = False,
batch_size: int = 1000,
):
func = EmbeddingFunction(func)
if wrap_api:
func = func.retry().rate_limit().batch_size(batch_size)
if show_progress:
func = func.show_progress()
if isinstance(data, pd.DataFrame):
data = pa.Table.from_pandas(data)
embeddings = func(data[column].to_numpy())
table = vec_to_table(np.array(embeddings))
return data.append_column("vector", table["vector"])
class EmbeddingFunction:
def __init__(self, func: Callable):
self.func = func
self.rate_limiter_kwargs = {}
self.retry_kwargs = {}
self._batch_size = None
self._progress = False
def __call__(self, text):
# Get the embedding with retry
@retry(**self.retry_kwargs)
def embed_func(c):
return self.func(c.tolist())
max_calls = self.rate_limiter_kwargs["max_calls"]
limiter = ratelimiter.RateLimiter(
max_calls, period=self.rate_limiter_kwargs["period"]
)
rate_limited = limiter(embed_func)
batches = self.to_batches(text)
embeds = [emb for c in batches for emb in rate_limited(c)]
return embeds
def __repr__(self):
return f"EmbeddingFunction(func={self.func})"
def rate_limit(self, max_calls=0.9, period=1.0):
self.rate_limiter_kwargs = dict(max_calls=max_calls, period=period)
return self
def retry(self, tries=10, delay=1, max_delay=30, backoff=3, jitter=1):
self.retry_kwargs = dict(
tries=tries,
delay=delay,
max_delay=max_delay,
backoff=backoff,
jitter=jitter,
)
return self
def batch_size(self, batch_size):
self._batch_size = batch_size
return self
def show_progress(self):
self._progress = True
return self
def to_batches(self, arr):
length = len(arr)
def _chunker(arr):
for start_i in range(0, len(arr), self._batch_size):
yield arr[start_i : start_i + self._batch_size]
if self._progress:
from tqdm.auto import tqdm
yield from tqdm(_chunker(arr), total=math.ceil(length / self._batch_size))
else:
return _chunker(arr)

View File

@@ -24,6 +24,8 @@ class LanceQueryBuilder:
"""
def __init__(self, table: "lancedb.table.LanceTable", query: np.ndarray):
self._nprobes = 20
self._refine_factor = None
self._table = table
self._query = query
self._limit = 10
@@ -75,6 +77,36 @@ class LanceQueryBuilder:
self._where = where
return self
def nprobes(self, nprobes: int) -> LanceQueryBuilder:
"""Set the number of probes to use.
Parameters
----------
nprobes: int
The number of probes to use.
Returns
-------
The LanceQueryBuilder object.
"""
self._nprobes = nprobes
return self
def refine_factor(self, refine_factor: int) -> LanceQueryBuilder:
"""Set the refine factor to use.
Parameters
----------
refine_factor: int
The refine factor to use.
Returns
-------
The LanceQueryBuilder object.
"""
self._refine_factor = refine_factor
return self
def to_df(self) -> pd.DataFrame:
"""Execute the query and return the results as a pandas DataFrame."""
ds = self._table.to_lance()
@@ -82,6 +114,12 @@ class LanceQueryBuilder:
tbl = ds.to_table(
columns=self._columns,
filter=self._where,
nearest={"column": VECTOR_COLUMN_NAME, "q": self._query, "k": self._limit},
nearest={
"column": VECTOR_COLUMN_NAME,
"q": self._query,
"k": self._limit,
"nprobes": self._nprobes,
"refine_factor": self._refine_factor,
},
)
return tbl.to_pandas()

View File

@@ -59,6 +59,14 @@ class LanceTable:
def _dataset_uri(self) -> str:
return os.path.join(self._conn.uri, f"{self.name}.lance")
def create_index(self, num_partitions=256, num_sub_vectors=96):
return self._dataset.create_index(
column=VECTOR_COLUMN_NAME,
index_type="IVF_PQ",
num_partitions=num_partitions,
num_sub_vectors=num_sub_vectors,
)
@cached_property
def _dataset(self) -> LanceDataset:
return lance.dataset(self._dataset_uri)

View File

@@ -1,7 +1,7 @@
[project]
name = "lancedb"
version = "0.0.1"
dependencies = ["pylance"]
dependencies = ["pylance", "ratelimiter", "retry", "tqdm"]
description = "lancedb"
authors = [
{ name = "Lance Devs", email = "dev@eto.ai" },
@@ -43,7 +43,7 @@ dev = [
"ruff", "pre-commit", "black"
]
docs = [
"mkdocs", "mkdocs-material", "mkdocstrings[python]"
"mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"
]
[build-system]