diff --git a/notebooks/youtube_transcript_search.ipynb b/notebooks/youtube_transcript_search.ipynb index 8d9565f2..7174de89 100644 --- a/notebooks/youtube_transcript_search.ipynb +++ b/notebooks/youtube_transcript_search.ipynb @@ -31,7 +31,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "a8987fcb", "metadata": {}, "outputs": [ @@ -51,7 +51,7 @@ "})" ] }, - "execution_count": 2, + "execution_count": 1, "metadata": {}, "output_type": "execute_result" } @@ -75,7 +75,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "id": "121a7087", "metadata": {}, "outputs": [ @@ -142,7 +142,7 @@ "177622 0.0 24.0 " ] }, - "execution_count": 3, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -166,6 +166,24 @@ "We'll call the OpenAI embeddings API to get embeddings" ] }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c8104467", + "metadata": {}, + "outputs": [], + "source": [ + "import openai\n", + "import os\n", + "\n", + "# Configuring the environment variable OPENAI_API_KEY\n", + "if \"OPENAI_API_KEY\" not in os.environ:\n", + " # OR set the key here as a variable\n", + " openai.api_key = \"sk-...\"\n", + " \n", + "assert len(openai.Model.list()[\"data\"]) > 0" + ] + }, { "cell_type": "code", "execution_count": 4, @@ -173,11 +191,7 @@ "metadata": {}, "outputs": [], "source": [ - "import openai\n", - "\n", - "# Configure environment variable OPENAI_API_KEY\n", - "# OR add variable openai.api_key = \"sk-...\"\n", - "\n", + "import numpy as np\n", "def embed_func(c): \n", " rs = openai.Embedding.create(input=c, engine=\"text-embedding-ada-002\")\n", " return [record[\"embedding\"] for record in rs[\"data\"]]" @@ -193,33 +207,94 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "id": "13f15068", - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Building vector index: 
IVF64,OPQ96, metric=l2\n" - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c4fb6f5a4ccc40ddb89d9df497213292", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/49 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
titlepublishedurlvideo_idchannel_ididtextstartendvector
0$5 MILLION AI for FREE2022-08-12 15:18:07https://youtu.be/3EjtHs_lXnk3EjtHs_lXnkUCfzlCWGWYyIQ0aLC5w48gBQ3EjtHs_lXnk-t0.0Imagine an AI where all in the same model you ...0.024.0[-0.024402587, -0.00087673456, 0.016499246, -0...
\n", + "" + ], "text/plain": [ - "" + " title published url \\\n", + "0 $5 MILLION AI for FREE 2022-08-12 15:18:07 https://youtu.be/3EjtHs_lXnk \n", + "\n", + " video_id channel_id id \\\n", + "0 3EjtHs_lXnk UCfzlCWGWYyIQ0aLC5w48gBQ 3EjtHs_lXnk-t0.0 \n", + "\n", + " text start end \\\n", + "0 Imagine an AI where all in the same model you ... 0.0 24.0 \n", + "\n", + " vector \n", + "0 [-0.024402587, -0.00087673456, 0.016499246, -0... " ] }, - "execution_count": 7, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Sample 16384 out of 48935 to train kmeans of 1536 dim, 64 clusters\n" - ] } ], "source": [ @@ -227,10 +302,110 @@ "from lancedb.embeddings import with_embeddings\n", "\n", "data = with_embeddings(embed_func, df, show_progress=True)\n", - "\n", + "data.to_pandas().head(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "92d53abd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "48935" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ "db = lancedb.connect(\"/tmp/lancedb\") # current directory\n", "tbl = db.create_table(\"chatbot\", data)\n", - "tbl.create_index(num_partitions=64, num_sub_vectors=96)" + "len(tbl)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "22892cfd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
titlepublishedurlvideo_idchannel_ididtextstartendvector
0$5 MILLION AI for FREE2022-08-12 15:18:07https://youtu.be/3EjtHs_lXnk3EjtHs_lXnkUCfzlCWGWYyIQ0aLC5w48gBQ3EjtHs_lXnk-t0.0Imagine an AI where all in the same model you ...0.024.0[-0.024402587, -0.00087673456, 0.016499246, -0...
\n", + "
" + ], + "text/plain": [ + " title published url \\\n", + "0 $5 MILLION AI for FREE 2022-08-12 15:18:07 https://youtu.be/3EjtHs_lXnk \n", + "\n", + " video_id channel_id id \\\n", + "0 3EjtHs_lXnk UCfzlCWGWYyIQ0aLC5w48gBQ 3EjtHs_lXnk-t0.0 \n", + "\n", + " text start end \\\n", + "0 Imagine an AI where all in the same model you ... 0.0 24.0 \n", + "\n", + " vector \n", + "0 [-0.024402587, -0.00087673456, 0.016499246, -0... " + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tbl.to_pandas().head(1)" ] }, { @@ -313,43 +488,76 @@ "complete(query)" ] }, - { - "cell_type": "code", - "execution_count": 12, - "id": "8fcef773", - "metadata": {}, - "outputs": [], - "source": [ - "def answer(question):\n", - " emb = embed_func(query)[0]\n", - " context = (tbl.search(emb).limit(3)\n", - " .nprobes(20).refine_factor(100)\n", - " .to_df())\n", - " prompt = create_prompt(question, context)\n", - " return complete(prompt), context.reset_index()" - ] - }, { "cell_type": "markdown", "id": "28705959", "metadata": {}, "source": [ - "## Show the answer and show the video at the right place" + "## Use LanceDB to find the answer and show the video at the right place" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "421a678d", + "metadata": {}, + "outputs": [], + "source": [ + "query = (\"Which training method should I use for sentence transformers \"\n", + " \"when I only have pairs of related sentences?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "80b160f0", + "metadata": {}, + "outputs": [], + "source": [ + "# Embed the question\n", + "emb = embed_func(query)[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "7c3ed619", + "metadata": {}, + "outputs": [], + "source": [ + "# Use LanceDB to get top 3 most relevant context\n", + "context = tbl.search(emb).limit(3).to_df()" ] }, { "cell_type": "code", "execution_count": 13, - "id": "25714299", + "id": 
"8fcef773", "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "NLI with multiple negative ranking loss.\n" - ] - }, + "data": { + "text/plain": [ + "'NLI with multiple negative ranking loss.'" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Get the answer from completion API\n", + "prompt = create_prompt(query, context)\n", + "complete(prompt)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "25714299", + "metadata": {}, + "outputs": [ { "data": { "text/html": [ @@ -365,10 +573,10 @@ " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 13, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -376,11 +584,6 @@ "source": [ "from IPython.display import YouTubeVideo\n", "\n", - "query = (\"Which training method should I use for sentence transformers \"\n", - " \"when I only have pairs of related sentences?\")\n", - "completion, context = answer(query)\n", - "\n", - "print(completion)\n", "top_match = context.iloc[0]\n", "YouTubeVideo(top_match[\"url\"].split(\"/\")[-1], start=top_match[\"start\"])" ] diff --git a/python/lancedb/embeddings.py b/python/lancedb/embeddings.py index 404d8c82..0f7c1eb6 100644 --- a/python/lancedb/embeddings.py +++ b/python/lancedb/embeddings.py @@ -36,7 +36,7 @@ def with_embeddings( if show_progress: func = func.show_progress() if isinstance(data, pd.DataFrame): - data = pa.Table.from_pandas(data) + data = pa.Table.from_pandas(data, preserve_index=False) embeddings = func(data[column].to_numpy()) table = vec_to_table(np.array(embeddings)) return data.append_column("vector", table["vector"]) @@ -102,4 +102,4 @@ class EmbeddingFunction: yield from tqdm(_chunker(arr), total=math.ceil(length / self._batch_size)) else: - return _chunker(arr) + yield from _chunker(arr) diff --git a/python/lancedb/table.py b/python/lancedb/table.py index e4eea8f6..a1c82967 100644 --- 
a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -55,6 +55,27 @@ class LanceTable: """Return the schema of the table.""" return self._dataset.schema + def __len__(self): + return self._dataset.count_rows() + + def __repr__(self) -> str: + return f"LanceTable({self.name})" + + def __str__(self) -> str: + return self.__repr__() + + def head(self, n=5) -> pa.Table: + """Return the first n rows of the table.""" + return self._dataset.head(n) + + def to_pandas(self) -> pd.DataFrame: + """Return the table as a pandas DataFrame.""" + return self.to_arrow().to_pandas() + + def to_arrow(self) -> pa.Table: + """Return the table as a pyarrow Table.""" + return self._dataset.to_table() + @property def _dataset_uri(self) -> str: return os.path.join(self._conn.uri, f"{self.name}.lance") diff --git a/python/tests/test_embeddings.py b/python/tests/test_embeddings.py new file mode 100644 index 00000000..2740ecb2 --- /dev/null +++ b/python/tests/test_embeddings.py @@ -0,0 +1,36 @@ +# Copyright 2023 LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import numpy as np
import pyarrow as pa

from lancedb.embeddings import with_embeddings

# Length of each fake embedding vector returned by the mock below.
EMBEDDING_DIM = 128


def mock_embed_func(input_data):
    """Return one random ``EMBEDDING_DIM``-length vector per input item.

    Stands in for a real embedding service so the test makes no network
    calls. The values are random; only the count and the length of the
    returned vectors are meaningful to the test.
    """
    return [np.random.randn(EMBEDDING_DIM).tolist() for _ in input_data]


def test_with_embeddings():
    """``with_embeddings`` appends a ``vector`` column and keeps the input columns."""
    data = pa.Table.from_arrays(
        [
            pa.array(["foo", "bar"]),
            pa.array([10.0, 20.0]),
        ],
        names=["text", "price"],
    )
    data = with_embeddings(mock_embed_func, data)

    # The table gains exactly one column and no rows.
    assert data.num_columns == 3
    assert data.num_rows == 2
    assert data.column_names == ["text", "price", "vector"]

    # The original columns pass through untouched.
    assert data.column("text").to_pylist() == ["foo", "bar"]
    assert data.column("price").to_pylist() == [10.0, 20.0]

    # Every row carries a full-length embedding; without this check a
    # truncated or mis-shaped vector column would go unnoticed.
    vectors = data.column("vector").to_pylist()
    assert len(vectors) == 2
    assert all(len(vec) == EMBEDDING_DIM for vec in vectors)