From 5d7832c8a5d4f55bc98a4251ffe52ccc482ba362 Mon Sep 17 00:00:00 2001
From: Chang She <759245+changhiskhan@users.noreply.github.com>
Date: Fri, 24 Mar 2023 18:16:29 -0700
Subject: [PATCH] update for release
---
notebooks/youtube_transcript_search.ipynb | 319 ++++++++++++++++++----
python/lancedb/embeddings.py | 4 +-
python/lancedb/table.py | 21 ++
python/tests/test_embeddings.py | 36 +++
4 files changed, 320 insertions(+), 60 deletions(-)
create mode 100644 python/tests/test_embeddings.py
diff --git a/notebooks/youtube_transcript_search.ipynb b/notebooks/youtube_transcript_search.ipynb
index 8d9565f2..7174de89 100644
--- a/notebooks/youtube_transcript_search.ipynb
+++ b/notebooks/youtube_transcript_search.ipynb
@@ -31,7 +31,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 1,
"id": "a8987fcb",
"metadata": {},
"outputs": [
@@ -51,7 +51,7 @@
"})"
]
},
- "execution_count": 2,
+ "execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
@@ -75,7 +75,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 2,
"id": "121a7087",
"metadata": {},
"outputs": [
@@ -142,7 +142,7 @@
"177622 0.0 24.0 "
]
},
- "execution_count": 3,
+ "execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
@@ -166,6 +166,24 @@
"We'll call the OpenAI embeddings API to get embeddings"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "c8104467",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import openai\n",
+ "import os\n",
+ "\n",
+ "# Configuring the environment variable OPENAI_API_KEY\n",
+ "if \"OPENAI_API_KEY\" not in os.environ:\n",
+ " # OR set the key here as a variable\n",
+ " openai.api_key = \"sk-...\"\n",
+ " \n",
+ "assert len(openai.Model.list()[\"data\"]) > 0"
+ ]
+ },
{
"cell_type": "code",
"execution_count": 4,
@@ -173,11 +191,7 @@
"metadata": {},
"outputs": [],
"source": [
- "import openai\n",
- "\n",
- "# Configure environment variable OPENAI_API_KEY\n",
- "# OR add variable openai.api_key = \"sk-...\"\n",
- "\n",
+ "import numpy as np\n",
"def embed_func(c): \n",
" rs = openai.Embedding.create(input=c, engine=\"text-embedding-ada-002\")\n",
" return [record[\"embedding\"] for record in rs[\"data\"]]"
@@ -193,33 +207,94 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 5,
"id": "13f15068",
- "metadata": {},
+ "metadata": {
+ "scrolled": false
+ },
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Building vector index: IVF64,OPQ96, metric=l2\n"
- ]
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c4fb6f5a4ccc40ddb89d9df497213292",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/49 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
},
{
"data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " title | \n",
+ " published | \n",
+ " url | \n",
+ " video_id | \n",
+ " channel_id | \n",
+ " id | \n",
+ " text | \n",
+ " start | \n",
+ " end | \n",
+ " vector | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " $5 MILLION AI for FREE | \n",
+ " 2022-08-12 15:18:07 | \n",
+ " https://youtu.be/3EjtHs_lXnk | \n",
+ " 3EjtHs_lXnk | \n",
+ " UCfzlCWGWYyIQ0aLC5w48gBQ | \n",
+ " 3EjtHs_lXnk-t0.0 | \n",
+ " Imagine an AI where all in the same model you ... | \n",
+ " 0.0 | \n",
+ " 24.0 | \n",
+ " [-0.024402587, -0.00087673456, 0.016499246, -0... | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
"text/plain": [
- ""
+ " title published url \\\n",
+ "0 $5 MILLION AI for FREE 2022-08-12 15:18:07 https://youtu.be/3EjtHs_lXnk \n",
+ "\n",
+ " video_id channel_id id \\\n",
+ "0 3EjtHs_lXnk UCfzlCWGWYyIQ0aLC5w48gBQ 3EjtHs_lXnk-t0.0 \n",
+ "\n",
+ " text start end \\\n",
+ "0 Imagine an AI where all in the same model you ... 0.0 24.0 \n",
+ "\n",
+ " vector \n",
+ "0 [-0.024402587, -0.00087673456, 0.016499246, -0... "
]
},
- "execution_count": 7,
+ "execution_count": 5,
"metadata": {},
"output_type": "execute_result"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Sample 16384 out of 48935 to train kmeans of 1536 dim, 64 clusters\n"
- ]
}
],
"source": [
@@ -227,10 +302,110 @@
"from lancedb.embeddings import with_embeddings\n",
"\n",
"data = with_embeddings(embed_func, df, show_progress=True)\n",
- "\n",
+ "data.to_pandas().head(1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "92d53abd",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "48935"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
"db = lancedb.connect(\"/tmp/lancedb\") # current directory\n",
"tbl = db.create_table(\"chatbot\", data)\n",
- "tbl.create_index(num_partitions=64, num_sub_vectors=96)"
+ "len(tbl)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "22892cfd",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " title | \n",
+ " published | \n",
+ " url | \n",
+ " video_id | \n",
+ " channel_id | \n",
+ " id | \n",
+ " text | \n",
+ " start | \n",
+ " end | \n",
+ " vector | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " $5 MILLION AI for FREE | \n",
+ " 2022-08-12 15:18:07 | \n",
+ " https://youtu.be/3EjtHs_lXnk | \n",
+ " 3EjtHs_lXnk | \n",
+ " UCfzlCWGWYyIQ0aLC5w48gBQ | \n",
+ " 3EjtHs_lXnk-t0.0 | \n",
+ " Imagine an AI where all in the same model you ... | \n",
+ " 0.0 | \n",
+ " 24.0 | \n",
+ " [-0.024402587, -0.00087673456, 0.016499246, -0... | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " title published url \\\n",
+ "0 $5 MILLION AI for FREE 2022-08-12 15:18:07 https://youtu.be/3EjtHs_lXnk \n",
+ "\n",
+ " video_id channel_id id \\\n",
+ "0 3EjtHs_lXnk UCfzlCWGWYyIQ0aLC5w48gBQ 3EjtHs_lXnk-t0.0 \n",
+ "\n",
+ " text start end \\\n",
+ "0 Imagine an AI where all in the same model you ... 0.0 24.0 \n",
+ "\n",
+ " vector \n",
+ "0 [-0.024402587, -0.00087673456, 0.016499246, -0... "
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tbl.to_pandas().head(1)"
]
},
{
@@ -313,43 +488,76 @@
"complete(query)"
]
},
- {
- "cell_type": "code",
- "execution_count": 12,
- "id": "8fcef773",
- "metadata": {},
- "outputs": [],
- "source": [
- "def answer(question):\n",
- " emb = embed_func(query)[0]\n",
- " context = (tbl.search(emb).limit(3)\n",
- " .nprobes(20).refine_factor(100)\n",
- " .to_df())\n",
- " prompt = create_prompt(question, context)\n",
- " return complete(prompt), context.reset_index()"
- ]
- },
{
"cell_type": "markdown",
"id": "28705959",
"metadata": {},
"source": [
- "## Show the answer and show the video at the right place"
+ "## Use LanceDB to find the answer and show the video at the right place"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "421a678d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "query = (\"Which training method should I use for sentence transformers \"\n",
+ " \"when I only have pairs of related sentences?\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "80b160f0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Embed the question\n",
+ "emb = embed_func(query)[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "7c3ed619",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Use LanceDB to get top 3 most relevant context\n",
+ "context = tbl.search(emb).limit(3).to_df()"
]
},
{
"cell_type": "code",
"execution_count": 13,
- "id": "25714299",
+ "id": "8fcef773",
"metadata": {},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "NLI with multiple negative ranking loss.\n"
- ]
- },
+ "data": {
+ "text/plain": [
+ "'NLI with multiple negative ranking loss.'"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Get the answer from completion API\n",
+ "prompt = create_prompt(query, context)\n",
+ "complete(prompt)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "25714299",
+ "metadata": {},
+ "outputs": [
{
"data": {
"text/html": [
@@ -365,10 +573,10 @@
" "
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 13,
+ "execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
@@ -376,11 +584,6 @@
"source": [
"from IPython.display import YouTubeVideo\n",
"\n",
- "query = (\"Which training method should I use for sentence transformers \"\n",
- " \"when I only have pairs of related sentences?\")\n",
- "completion, context = answer(query)\n",
- "\n",
- "print(completion)\n",
"top_match = context.iloc[0]\n",
"YouTubeVideo(top_match[\"url\"].split(\"/\")[-1], start=top_match[\"start\"])"
]
diff --git a/python/lancedb/embeddings.py b/python/lancedb/embeddings.py
index 404d8c82..0f7c1eb6 100644
--- a/python/lancedb/embeddings.py
+++ b/python/lancedb/embeddings.py
@@ -36,7 +36,7 @@ def with_embeddings(
if show_progress:
func = func.show_progress()
if isinstance(data, pd.DataFrame):
- data = pa.Table.from_pandas(data)
+ data = pa.Table.from_pandas(data, preserve_index=False)
embeddings = func(data[column].to_numpy())
table = vec_to_table(np.array(embeddings))
return data.append_column("vector", table["vector"])
@@ -102,4 +102,4 @@ class EmbeddingFunction:
yield from tqdm(_chunker(arr), total=math.ceil(length / self._batch_size))
else:
- return _chunker(arr)
+ yield from _chunker(arr)
diff --git a/python/lancedb/table.py b/python/lancedb/table.py
index e4eea8f6..a1c82967 100644
--- a/python/lancedb/table.py
+++ b/python/lancedb/table.py
@@ -55,6 +55,27 @@ class LanceTable:
"""Return the schema of the table."""
return self._dataset.schema
+ def __len__(self):
+ return self._dataset.count_rows()
+
+ def __repr__(self) -> str:
+ return f"LanceTable({self.name})"
+
+ def __str__(self) -> str:
+ return self.__repr__()
+
+ def head(self, n=5) -> pa.Table:
+ """Return the first n rows of the table."""
+ return self._dataset.head(n)
+
+ def to_pandas(self) -> pd.DataFrame:
+ """Return the table as a pandas DataFrame."""
+ return self.to_arrow().to_pandas()
+
+ def to_arrow(self) -> pa.Table:
+ """Return the table as a pyarrow Table."""
+ return self._dataset.to_table()
+
@property
def _dataset_uri(self) -> str:
return os.path.join(self._conn.uri, f"{self.name}.lance")
diff --git a/python/tests/test_embeddings.py b/python/tests/test_embeddings.py
new file mode 100644
index 00000000..2740ecb2
--- /dev/null
+++ b/python/tests/test_embeddings.py
@@ -0,0 +1,36 @@
+# Copyright 2023 LanceDB Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import pyarrow as pa
+
+from lancedb.embeddings import with_embeddings
+
+
+def mock_embed_func(input_data):
+ return [np.random.randn(128).tolist() for _ in range(len(input_data))]
+
+
+def test_with_embeddings():
+ data = pa.Table.from_arrays(
+ [
+ pa.array(["foo", "bar"]),
+ pa.array([10.0, 20.0]),
+ ],
+ names=["text", "price"],
+ )
+ data = with_embeddings(mock_embed_func, data)
+ assert data.num_columns == 3
+ assert data.num_rows == 2
+ assert data.column_names == ["text", "price", "vector"]
+ assert data.column("text").to_pylist() == ["foo", "bar"]
+ assert data.column("price").to_pylist() == [10.0, 20.0]