From b91139d3c706efae2f87749db54af8978c4c610f Mon Sep 17 00:00:00 2001
From: Chang She <759245+changhiskhan@users.noreply.github.com>
Date: Thu, 23 Mar 2023 15:07:58 -0700
Subject: [PATCH 1/2] Add tutorial notebook
Convert contextualization and embeddings functionality.
And use it with converted notebook for video search
---
.gitignore | 2 +
README.md | 4 +-
docs/mkdocs.yml | 1 +
docs/requirements.txt | 1 +
docs/src/index.md | 21 +-
notebooks/youtube_transcript_search.ipynb | 418 ++++++++++++++++++++++
python/lancedb/context.py | 61 ++++
python/lancedb/db.py | 1 +
python/lancedb/embeddings.py | 105 ++++++
python/lancedb/query.py | 40 ++-
python/lancedb/table.py | 8 +
python/pyproject.toml | 4 +-
12 files changed, 660 insertions(+), 6 deletions(-)
create mode 100644 notebooks/youtube_transcript_search.ipynb
create mode 100644 python/lancedb/context.py
create mode 100644 python/lancedb/embeddings.py
diff --git a/.gitignore b/.gitignore
index 9389ad0f..82107c69 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,3 +13,5 @@ site
python/build
python/dist
+
+notebooks/.ipynb_checkpoints
diff --git a/README.md b/README.md
index fc7e4965..a61ea789 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
-
+
**Serverless, low-latency vector database for AI applications**
@@ -45,5 +45,5 @@ db = lancedb.connect(uri)
table = db.create_table("my_table",
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}])
-result = table.search([100, 100]).where("price < 15").limit(1).to_df()
+result = table.search([100, 100]).limit(2).to_df()
```
diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml
index 92243b79..d5054ca5 100644
--- a/docs/mkdocs.yml
+++ b/docs/mkdocs.yml
@@ -7,6 +7,7 @@ theme:
plugins:
- search
- mkdocstrings
+- mkdocs-jupyter
nav:
- Home: index.md
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 1fbe6277..c2e2ab41 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,3 +1,4 @@
mkdocs==1.4.2
+mkdocs-jupyter==0.24.1
mkdocs-material==9.1.3
mkdocstrings[python]==0.20.0
diff --git a/docs/src/index.md b/docs/src/index.md
index e4c909a6..a9446ea7 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -14,7 +14,26 @@ The key features of LanceDB include:
LanceDB's core is written in Rust 🦀 and is built using Lance, an open-source columnar format designed for performant ML workloads.
+
+## Installation
+
+```shell
+pip install lancedb
+```
+
+## Quickstart
+
+```python
+import lancedb
+
+db = lancedb.connect(".")
+table = db.create_table("my_table",
+ data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
+ {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}])
+result = table.search([100, 100]).limit(2).to_df()
+```
+
+
## Documentation Quick Links
-* `Quick start` - search and filter a hello world vector dataset with LanceDB using the Python SDK.
* [`API Reference`](python.md) - detailed documentation for the LanceDB Python SDK.
diff --git a/notebooks/youtube_transcript_search.ipynb b/notebooks/youtube_transcript_search.ipynb
new file mode 100644
index 00000000..8d9565f2
--- /dev/null
+++ b/notebooks/youtube_transcript_search.ipynb
@@ -0,0 +1,418 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "42bf01fb",
+ "metadata": {},
+ "source": [
+ "# We're going to build a question-and-answer bot\n",
+ "\n",
+ "It lets you search through YouTube transcripts using natural language"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "48547ddb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pip install --quiet openai datasets lancedb"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "22e570f4",
+ "metadata": {},
+ "source": [
+ "## Download the data\n",
+ "700 videos and 208619 sentences"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "a8987fcb",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Found cached dataset json (/Users/changshe/.cache/huggingface/datasets/jamescalam___json/jamescalam--youtube-transcriptions-08d889f6a5386b9b/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "Dataset({\n",
+ " features: ['title', 'published', 'url', 'video_id', 'channel_id', 'id', 'text', 'start', 'end'],\n",
+ " num_rows: 208619\n",
+ "})"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from datasets import load_dataset\n",
+ "\n",
+ "data = load_dataset('jamescalam/youtube-transcriptions', split='train')\n",
+ "data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5ac2b6a3",
+ "metadata": {},
+ "source": [
+ "## Prepare context\n",
+ "\n",
+ "Create context of 20 sentences"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "121a7087",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " title | \n",
+ " published | \n",
+ " url | \n",
+ " video_id | \n",
+ " channel_id | \n",
+ " id | \n",
+ " text | \n",
+ " start | \n",
+ " end | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 177622 | \n",
+ " $5 MILLION AI for FREE | \n",
+ " 2022-08-12 15:18:07 | \n",
+ " https://youtu.be/3EjtHs_lXnk | \n",
+ " 3EjtHs_lXnk | \n",
+ " UCfzlCWGWYyIQ0aLC5w48gBQ | \n",
+ " 3EjtHs_lXnk-t0.0 | \n",
+ " Imagine an AI where all in the same model you ... | \n",
+ " 0.0 | \n",
+ " 24.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " title published \\\n",
+ "177622 $5 MILLION AI for FREE 2022-08-12 15:18:07 \n",
+ "\n",
+ " url video_id channel_id \\\n",
+ "177622 https://youtu.be/3EjtHs_lXnk 3EjtHs_lXnk UCfzlCWGWYyIQ0aLC5w48gBQ \n",
+ "\n",
+ " id text \\\n",
+ "177622 3EjtHs_lXnk-t0.0 Imagine an AI where all in the same model you ... \n",
+ "\n",
+ " start end \n",
+ "177622 0.0 24.0 "
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from lancedb.context import contextualize\n",
+ "\n",
+ "df = (contextualize(data.to_pandas())\n",
+ " .groupby(\"title\").text_col(\"text\")\n",
+ " .window(20).stride(4)\n",
+ " .to_df())\n",
+ "df.head(1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3044e0b0",
+ "metadata": {},
+ "source": [
+ "## Create embedding function\n",
+ "We'll call the OpenAI embeddings API to get embeddings"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "8eefc159",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import openai\n",
+ "\n",
+ "# Configure environment variable OPENAI_API_KEY\n",
+ "# OR add variable openai.api_key = \"sk-...\"\n",
+ "\n",
+ "def embed_func(c): \n",
+ " rs = openai.Embedding.create(input=c, engine=\"text-embedding-ada-002\")\n",
+ " return [record[\"embedding\"] for record in rs[\"data\"]]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2106b5bb",
+ "metadata": {},
+ "source": [
+ "## Create the LanceDB Table"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "13f15068",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Building vector index: IVF64,OPQ96, metric=l2\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "
"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Sample 16384 out of 48935 to train kmeans of 1536 dim, 64 clusters\n"
+ ]
+ }
+ ],
+ "source": [
+ "import lancedb\n",
+ "from lancedb.embeddings import with_embeddings\n",
+ "\n",
+ "data = with_embeddings(embed_func, df, show_progress=True)\n",
+ "\n",
+ "db = lancedb.connect(\"/tmp/lancedb\") # path of the local database directory\n",
+ "tbl = db.create_table(\"chatbot\", data)\n",
+ "tbl.create_index(num_partitions=64, num_sub_vectors=96)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "23afc2f9",
+ "metadata": {},
+ "source": [
+ "## Create and answer the prompt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "06d8b867",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def create_prompt(query, context):\n",
+ " limit = 3750\n",
+ "\n",
+ " prompt_start = (\n",
+ " \"Answer the question based on the context below.\\n\\n\"+\n",
+ " \"Context:\\n\"\n",
+ " )\n",
+ " prompt_end = (\n",
+ " f\"\\n\\nQuestion: {query}\\nAnswer:\"\n",
+ " )\n",
+ " # append contexts until hitting limit\n",
+ " for i in range(1, len(context)):\n",
+ " if len(\"\\n\\n---\\n\\n\".join(context.text[:i])) >= limit:\n",
+ " prompt = (\n",
+ " prompt_start +\n",
+ " \"\\n\\n---\\n\\n\".join(context.text[:i-1]) +\n",
+ " prompt_end\n",
+ " )\n",
+ " break\n",
+ " elif i == len(context)-1:\n",
+ " prompt = (\n",
+ " prompt_start +\n",
+ " \"\\n\\n---\\n\\n\".join(context.text) +\n",
+ " prompt_end\n",
+ " ) \n",
+ " return prompt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "e09c5142",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'The 12th person on the moon was Harrison Schmitt, and he landed on December 11, 1972.'"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "def complete(prompt):\n",
+ " # query text-davinci-003\n",
+ " res = openai.Completion.create(\n",
+ " engine='text-davinci-003',\n",
+ " prompt=prompt,\n",
+ " temperature=0,\n",
+ " max_tokens=400,\n",
+ " top_p=1,\n",
+ " frequency_penalty=0,\n",
+ " presence_penalty=0,\n",
+ " stop=None\n",
+ " )\n",
+ " return res['choices'][0]['text'].strip()\n",
+ "\n",
+ "# check that it works\n",
+ "query = \"who was the 12th person on the moon and when did they land?\"\n",
+ "complete(query)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "8fcef773",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def answer(question):\n",
+ " emb = embed_func(question)[0]\n",
+ " context = (tbl.search(emb).limit(3)\n",
+ " .nprobes(20).refine_factor(100)\n",
+ " .to_df())\n",
+ " prompt = create_prompt(question, context)\n",
+ " return complete(prompt), context.reset_index()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "28705959",
+ "metadata": {},
+ "source": [
+ "## Show the answer and show the video at the right place"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "25714299",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "NLI with multiple negative ranking loss.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from IPython.display import YouTubeVideo\n",
+ "\n",
+ "query = (\"Which training method should I use for sentence transformers \"\n",
+ " \"when I only have pairs of related sentences?\")\n",
+ "completion, context = answer(query)\n",
+ "\n",
+ "print(completion)\n",
+ "top_match = context.iloc[0]\n",
+ "YouTubeVideo(top_match[\"url\"].split(\"/\")[-1], start=top_match[\"start\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "78b7eb11",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/python/lancedb/context.py b/python/lancedb/context.py
new file mode 100644
index 00000000..3a2a4c2d
--- /dev/null
+++ b/python/lancedb/context.py
@@ -0,0 +1,61 @@
+# Copyright 2023 LanceDB Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pandas as pd
+
+
+# Convenience factory: wraps raw_df in a Contextualizer so that the builder
+# methods (window / stride / groupby / text_col) can be chained off the result.
+def contextualize(raw_df):
+ return Contextualizer(raw_df)
+
+
+class Contextualizer:
+ """Builder that turns a DataFrame of text rows into rolling context windows.
+
+ Configure it with the chainable setters window(), stride(), text_col() and,
+ optionally, groupby(); each one stores its argument and returns self.
+ to_df() then builds the result: within each group (or over the whole frame
+ when no groupby column was set) it keeps every stride-th row and replaces
+ that row's text column with the space-joined text of the `window` rows
+ starting at that row.
+
+ All settings start out as None. to_df() uses window, stride and text_col
+ directly, so they are expected to be set before it is called.
+ """
+
+ def __init__(self, raw_df):
+ self._text_col = None
+ self._groupby = None
+ self._stride = None
+ self._window = None
+ self._raw_df = raw_df
+
+ def window(self, window):
+ self._window = window
+ return self
+
+ def stride(self, stride):
+ self._stride = stride
+ return self
+
+ def groupby(self, groupby):
+ self._groupby = groupby
+ return self
+
+ def text_col(self, text_col):
+ self._text_col = text_col
+ return self
+
+ def to_df(self):
+ def process_group(grp):
+ # For each video, create the text rolling window
+ text = grp[self._text_col].values
+ contexts = grp.iloc[: -self._window : self._stride, :].copy()
+ contexts[self._text_col] = [
+ # NOTE(review): both the iloc slice above and this range stop
+ # before len(grp) - window, so the window that starts exactly at
+ # len(grp) - window is never emitted and a group with at most
+ # `window` rows contributes no rows at all -- confirm intended.
+ " ".join(text[start_i : start_i + self._window])
+ for start_i in range(0, len(grp) - self._window, self._stride)
+ ]
+ return contexts
+
+ # No groupby column configured: treat the whole frame as a single group.
+ if self._groupby is None:
+ return process_group(self._raw_df)
+ # concat result from all groups
+ return pd.concat(
+ [process_group(grp) for _, grp in self._raw_df.groupby(self._groupby)]
+ )
diff --git a/python/lancedb/db.py b/python/lancedb/db.py
index 3db1c583..d869575b 100644
--- a/python/lancedb/db.py
+++ b/python/lancedb/db.py
@@ -29,6 +29,7 @@ class LanceDBConnection:
if isinstance(uri, str):
uri = Path(uri)
uri = uri.expanduser().absolute()
+ Path(uri).mkdir(parents=True, exist_ok=True)
self._uri = str(uri)
@property
diff --git a/python/lancedb/embeddings.py b/python/lancedb/embeddings.py
new file mode 100644
index 00000000..404d8c82
--- /dev/null
+++ b/python/lancedb/embeddings.py
@@ -0,0 +1,105 @@
+# Copyright 2023 LanceDB Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import ratelimiter
+from retry import retry
+from typing import Callable, Union
+
+from lance.vector import vec_to_table
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+
+
+def with_embeddings(
+    func: Callable,
+    data: Union[pa.Table, pd.DataFrame],
+    column: str = "text",
+    wrap_api: bool = True,
+    show_progress: bool = False,
+    batch_size: int = 1000,
+):
+    """Add a "vector" column holding the embedding of each row's text.
+
+    Parameters
+    ----------
+    func: Callable
+        Embedding function. It is called with a list of values taken from
+        ``column`` (one batch per call) and must return one embedding per
+        input value.
+    data: pa.Table or pd.DataFrame
+        The source rows. A DataFrame is converted with
+        ``pa.Table.from_pandas`` first.
+    column: str
+        Name of the column whose values are passed to ``func``.
+    wrap_api: bool
+        If True, calls to ``func`` are retried and rate limited using the
+        defaults of ``EmbeddingFunction.retry`` and
+        ``EmbeddingFunction.rate_limit``.
+    show_progress: bool
+        If True, show a progress bar over the batches.
+    batch_size: int
+        Number of values sent to ``func`` per call.
+
+    Returns
+    -------
+    pa.Table
+        ``data`` as an Arrow table with the "vector" column appended.
+    """
+    func = EmbeddingFunction(func)
+    if wrap_api:
+        func = func.retry().rate_limit()
+    # Batching is independent of the retry / rate-limit wrapping, so the
+    # requested batch size is applied whether or not ``wrap_api`` is set
+    # (it used to be dropped when ``wrap_api`` was False).
+    func = func.batch_size(batch_size)
+    if show_progress:
+        func = func.show_progress()
+    if isinstance(data, pd.DataFrame):
+        data = pa.Table.from_pandas(data)
+    embeddings = func(data[column].to_numpy())
+    table = vec_to_table(np.array(embeddings))
+    return data.append_column("vector", table["vector"])
+
+
+class EmbeddingFunction:
+    """Callable wrapper adding batching, retry, rate limiting and an optional
+    progress bar around a user supplied embedding function.
+
+    The wrapper is configured through chainable methods, each of which
+    returns ``self``::
+
+        EmbeddingFunction(func).retry().rate_limit().batch_size(1000)
+
+    Retry and rate limiting are only applied once ``retry()`` /
+    ``rate_limit()`` have been called. Without ``batch_size()`` the whole
+    input is sent to ``func`` in a single call.
+    """
+
+    def __init__(self, func: Callable):
+        """
+        Parameters
+        ----------
+        func: Callable
+            Called with a list of inputs (one batch); must return an
+            iterable with one embedding per input.
+        """
+        self.func = func
+        self.rate_limiter_kwargs = {}
+        self.retry_kwargs = {}
+        self._batch_size = None
+        self._progress = False
+
+    def __call__(self, text):
+        """Embed every element of ``text`` and return the embeddings as a list.
+
+        ``text`` must support ``len()``, slicing and -- on each slice --
+        ``tolist()`` (for example a numpy array).
+        """
+
+        def embed_func(c):
+            return self.func(c.tolist())
+
+        # An empty ``retry_kwargs`` used to be expanded into a bare
+        # ``@retry()``, i.e. the retry package's own defaults rather than
+        # "no retry"; only wrap when retry() was actually configured.
+        if self.retry_kwargs:
+            embed_func = retry(**self.retry_kwargs)(embed_func)
+
+        # Likewise an unconfigured rate limiter used to raise KeyError on
+        # the "max_calls" lookup; skip the limiter when it was not set up.
+        if self.rate_limiter_kwargs:
+            limiter = ratelimiter.RateLimiter(
+                self.rate_limiter_kwargs["max_calls"],
+                period=self.rate_limiter_kwargs["period"],
+            )
+            embed_func = limiter(embed_func)
+
+        return [emb for c in self.to_batches(text) for emb in embed_func(c)]
+
+    def __repr__(self):
+        return f"EmbeddingFunction(func={self.func})"
+
+    def rate_limit(self, max_calls=0.9, period=1.0):
+        """Enable rate limiting; the values are passed to
+        ``ratelimiter.RateLimiter`` when the function is called.
+        """
+        self.rate_limiter_kwargs = dict(max_calls=max_calls, period=period)
+        return self
+
+    def retry(self, tries=10, delay=1, max_delay=30, backoff=3, jitter=1):
+        """Enable retries; the values are passed as keyword arguments to the
+        ``retry`` decorator when the function is called.
+        """
+        self.retry_kwargs = dict(
+            tries=tries,
+            delay=delay,
+            max_delay=max_delay,
+            backoff=backoff,
+            jitter=jitter,
+        )
+        return self
+
+    def batch_size(self, batch_size):
+        """Set how many inputs are sent to the wrapped function per call."""
+        self._batch_size = batch_size
+        return self
+
+    def show_progress(self):
+        """Show a tqdm progress bar over the batches."""
+        self._progress = True
+        return self
+
+    def to_batches(self, arr):
+        """Yield consecutive slices of ``arr`` of at most ``batch_size`` items.
+
+        When no batch size was configured, the whole of ``arr`` is yielded
+        as a single batch. Nothing is yielded for empty input.
+        """
+        length = len(arr)
+        # max(..., 1) keeps the range() step valid for empty, unbatched input.
+        size = self._batch_size if self._batch_size is not None else max(length, 1)
+
+        def _chunker(arr):
+            for start_i in range(0, length, size):
+                yield arr[start_i : start_i + size]
+
+        batches = _chunker(arr)
+        if self._progress:
+            from tqdm.auto import tqdm
+
+            batches = tqdm(batches, total=math.ceil(length / size))
+        # This method is itself a generator, so the batches must be yielded:
+        # a plain ``return _chunker(arr)`` here produces no batches at all.
+        yield from batches
diff --git a/python/lancedb/query.py b/python/lancedb/query.py
index 14ac2083..21333bec 100644
--- a/python/lancedb/query.py
+++ b/python/lancedb/query.py
@@ -24,6 +24,8 @@ class LanceQueryBuilder:
"""
def __init__(self, table: "lancedb.table.LanceTable", query: np.ndarray):
+ self._nprobes = 20
+ self._refine_factor = None
self._table = table
self._query = query
self._limit = 10
@@ -75,6 +77,36 @@ class LanceQueryBuilder:
self._where = where
return self
+    def nprobes(self, nprobes: int) -> LanceQueryBuilder:
+        """Choose how many probes the search uses.
+
+        The value is stored on the builder and forwarded as ``nprobes`` in
+        the ``nearest`` query when the query is executed.
+
+        Parameters
+        ----------
+        nprobes: int
+            Number of probes.
+
+        Returns
+        -------
+        LanceQueryBuilder
+            This builder, so that calls can be chained.
+        """
+        self._nprobes = nprobes
+        return self
+
+    def refine_factor(self, refine_factor: int) -> LanceQueryBuilder:
+        """Choose the refine factor the search uses.
+
+        The value is stored on the builder and forwarded as
+        ``refine_factor`` in the ``nearest`` query when the query is executed.
+
+        Parameters
+        ----------
+        refine_factor: int
+            The refine factor.
+
+        Returns
+        -------
+        LanceQueryBuilder
+            This builder, so that calls can be chained.
+        """
+        self._refine_factor = refine_factor
+        return self
+
def to_df(self) -> pd.DataFrame:
"""Execute the query and return the results as a pandas DataFrame."""
ds = self._table.to_lance()
@@ -82,6 +114,12 @@ class LanceQueryBuilder:
tbl = ds.to_table(
columns=self._columns,
filter=self._where,
- nearest={"column": VECTOR_COLUMN_NAME, "q": self._query, "k": self._limit},
+ nearest={
+ "column": VECTOR_COLUMN_NAME,
+ "q": self._query,
+ "k": self._limit,
+ "nprobes": self._nprobes,
+ "refine_factor": self._refine_factor,
+ },
)
return tbl.to_pandas()
diff --git a/python/lancedb/table.py b/python/lancedb/table.py
index 7840f396..b5fd6a1f 100644
--- a/python/lancedb/table.py
+++ b/python/lancedb/table.py
@@ -59,6 +59,14 @@ class LanceTable:
def _dataset_uri(self) -> str:
return os.path.join(self._conn.uri, f"{self.name}.lance")
+ def create_index(self, num_partitions=256, num_sub_vectors=96):
+ return self._dataset.create_index(
+ column=VECTOR_COLUMN_NAME,
+ index_type="IVF_PQ",
+ # The index is always of type IVF_PQ on the VECTOR_COLUMN_NAME column;
+ # only the two sizing parameters below come from the caller, and the
+ # result of the dataset's create_index call is returned as-is.
+ num_partitions=num_partitions,
+ num_sub_vectors=num_sub_vectors,
+ )
+
@cached_property
def _dataset(self) -> LanceDataset:
return lance.dataset(self._dataset_uri)
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 6c1804cf..e2986da3 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "lancedb"
version = "0.0.1"
-dependencies = ["pylance"]
+dependencies = ["pylance", "ratelimiter", "retry", "tqdm"]
description = "lancedb"
authors = [
{ name = "Lance Devs", email = "dev@eto.ai" },
@@ -43,7 +43,7 @@ dev = [
"ruff", "pre-commit", "black"
]
docs = [
- "mkdocs", "mkdocs-material", "mkdocstrings[python]"
+ "mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"
]
[build-system]
From 826fe320bb823ccc7243b300f7685280708ccf86 Mon Sep 17 00:00:00 2001
From: Chang She <759245+changhiskhan@users.noreply.github.com>
Date: Thu, 23 Mar 2023 17:31:24 -0700
Subject: [PATCH 2/2] address PR comments
---
python/lancedb/context.py | 49 +++++++++++++++++++++++++++++++++------
python/lancedb/db.py | 5 ++++
python/lancedb/table.py | 11 +++++++++
3 files changed, 58 insertions(+), 7 deletions(-)
diff --git a/python/lancedb/context.py b/python/lancedb/context.py
index 3a2a4c2d..25090195 100644
--- a/python/lancedb/context.py
+++ b/python/lancedb/context.py
@@ -10,11 +10,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations
import pandas as pd
-def contextualize(raw_df):
+def contextualize(raw_df: pd.DataFrame) -> Contextualizer:
+ """Create a Contextualizer object for the given DataFrame.
+ Used to create context windows.
+ """
return Contextualizer(raw_df)
@@ -26,25 +30,56 @@ class Contextualizer:
self._window = None
self._raw_df = raw_df
- def window(self, window):
+ def window(self, window: int) -> Contextualizer:
+ """Set the window size. i.e., how many rows to include in each window.
+
+ Parameters
+ ----------
+ window: int
+ The window size.
+ """
self._window = window
return self
- def stride(self, stride):
+ def stride(self, stride: int) -> Contextualizer:
+ """Set the stride, i.e., how many rows to advance between the starts of consecutive windows.
+
+ Parameters
+ ----------
+ stride: int
+ The stride.
+ """
self._stride = stride
return self
- def groupby(self, groupby):
+ def groupby(self, groupby: str) -> Contextualizer:
+ """Set the groupby column. i.e., how to group the rows.
+ Windows don't cross groups
+
+ Parameters
+ ----------
+ groupby: str
+ The groupby column.
+ """
self._groupby = groupby
return self
- def text_col(self, text_col):
+ def text_col(self, text_col: str) -> Contextualizer:
+ """Set the text column used to make the context window.
+
+ Parameters
+ ----------
+ text_col: str
+ The text column.
+ """
self._text_col = text_col
return self
- def to_df(self):
+ def to_df(self) -> pd.DataFrame:
+ """Create the context windows and return a DataFrame."""
+
def process_group(grp):
- # For each video, create the text rolling window
+ # For each group, create the text rolling window
text = grp[self._text_col].values
contexts = grp.iloc[: -self._window : self._stride, :].copy()
contexts[self._text_col] = [
diff --git a/python/lancedb/db.py b/python/lancedb/db.py
index d869575b..cf408e21 100644
--- a/python/lancedb/db.py
+++ b/python/lancedb/db.py
@@ -68,6 +68,11 @@ class LanceDBConnection:
schema: pyarrow.Schema; optional
The schema of the table.
+ Note
+ ----
+ The vector index won't be created by default.
+ To create the index, call the `create_index` method on the table.
+
Returns
-------
A LanceTable object representing the table.
diff --git a/python/lancedb/table.py b/python/lancedb/table.py
index b5fd6a1f..e4eea8f6 100644
--- a/python/lancedb/table.py
+++ b/python/lancedb/table.py
@@ -60,6 +60,17 @@ class LanceTable:
return os.path.join(self._conn.uri, f"{self.name}.lance")
def create_index(self, num_partitions=256, num_sub_vectors=96):
+ """Create an index on the table.
+
+ Parameters
+ ----------
+ num_partitions: int
+ The number of IVF partitions to use when creating the index.
+ Default is 256.
+ num_sub_vectors: int
+ The number of PQ sub-vectors to use when creating the index.
+ Default is 96.
+ """
return self._dataset.create_index(
column=VECTOR_COLUMN_NAME,
index_type="IVF_PQ",