diff --git a/.gitignore b/.gitignore index 9389ad0f..82107c69 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,5 @@ site python/build python/dist + +notebooks/.ipynb_checkpoints diff --git a/README.md b/README.md index fc7e4965..a61ea789 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@

- + LanceDB Logo **Serverless, low-latency vector database for AI applications** @@ -45,5 +45,5 @@ db = lancedb.connect(uri) table = db.create_table("my_table", data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}]) -result = table.search([100, 100]).where("price < 15").limit(1).to_df() +result = table.search([100, 100]).limit(2).to_df() ``` diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 92243b79..d5054ca5 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -7,6 +7,7 @@ theme: plugins: - search - mkdocstrings +- mkdocs-jupyter nav: - Home: index.md diff --git a/docs/requirements.txt b/docs/requirements.txt index 1fbe6277..c2e2ab41 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,3 +1,4 @@ mkdocs==1.4.2 +mkdocs-jupyter==0.24.1 mkdocs-material==9.1.3 mkdocstrings[python]==0.20.0 diff --git a/docs/src/index.md b/docs/src/index.md index e4c909a6..a9446ea7 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -14,7 +14,26 @@ The key features of LanceDB include: LanceDB's core is written in Rust 🦀 and is built using Lance, an open-source columnar format designed for performant ML workloads. + +## Installation + +```shell +pip install lancedb +``` + +## Quickstart + +```python +import lancedb + +db = lancedb.connect(".") +table = db.create_table("my_table", + data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, + {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}]) +result = table.search([100, 100]).limit(2).to_df() +``` + + ## Documentation Quick Links -* `Quick start` - search and filter a hello world vector dataset with LanceDB using the Python SDK. * [`API Reference`](python.md) - detailed documentation for the LanceDB Python SDK. diff --git a/notebooks/youtube_transcript_search.ipynb b/notebooks/youtube_transcript_search.ipynb new file mode 100644 index 00000000..8d9565f2 --- /dev/null +++ b/notebooks/youtube_transcript_search.ipynb @@ -0,0 +1,418 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "42bf01fb", + "metadata": {}, + "source": [ + "# We're going to build question and answer bot\n", + "\n", + "That allow you to search through youtube transcripts using natural language" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48547ddb", + "metadata": {}, + "outputs": [], + "source": [ + "pip install --quiet openai datasets lancedb" + ] + }, + { + "cell_type": "markdown", + "id": "22e570f4", + "metadata": {}, + "source": [ + "## Download the data\n", + "700 videos and 208619 sentences" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "a8987fcb", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Found cached dataset json (/Users/changshe/.cache/huggingface/datasets/jamescalam___json/jamescalam--youtube-transcriptions-08d889f6a5386b9b/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)\n" + ] + }, + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['title', 'published', 'url', 'video_id', 'channel_id', 'id', 'text', 'start', 'end'],\n", + " num_rows: 208619\n", + "})" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from datasets import load_dataset\n", + "\n", + "data = load_dataset('jamescalam/youtube-transcriptions', split='train')\n", + "data" + ] + }, + { + "cell_type": "markdown", + "id": "5ac2b6a3", + "metadata": {}, + "source": [ + "## Prepare context\n", + "\n", + "Create context of 20 sentences" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "121a7087", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "

\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
titlepublishedurlvideo_idchannel_ididtextstartend
177622$5 MILLION AI for FREE2022-08-12 15:18:07https://youtu.be/3EjtHs_lXnk3EjtHs_lXnkUCfzlCWGWYyIQ0aLC5w48gBQ3EjtHs_lXnk-t0.0Imagine an AI where all in the same model you ...0.024.0
\n", + "
" + ], + "text/plain": [ + " title published \\\n", + "177622 $5 MILLION AI for FREE 2022-08-12 15:18:07 \n", + "\n", + " url video_id channel_id \\\n", + "177622 https://youtu.be/3EjtHs_lXnk 3EjtHs_lXnk UCfzlCWGWYyIQ0aLC5w48gBQ \n", + "\n", + " id text \\\n", + "177622 3EjtHs_lXnk-t0.0 Imagine an AI where all in the same model you ... \n", + "\n", + " start end \n", + "177622 0.0 24.0 " + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from lancedb.context import contextualize\n", + "\n", + "df = (contextualize(data.to_pandas())\n", + " .groupby(\"title\").text_col(\"text\")\n", + " .window(20).stride(4)\n", + " .to_df())\n", + "df.head(1)" + ] + }, + { + "cell_type": "markdown", + "id": "3044e0b0", + "metadata": {}, + "source": [ + "## Create embedding function\n", + "We'll call the OpenAI embeddings API to get embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "8eefc159", + "metadata": {}, + "outputs": [], + "source": [ + "import openai\n", + "\n", + "# Configure environment variable OPENAI_API_KEY\n", + "# OR add variable openai.api_key = \"sk-...\"\n", + "\n", + "def embed_func(c): \n", + " rs = openai.Embedding.create(input=c, engine=\"text-embedding-ada-002\")\n", + " return [record[\"embedding\"] for record in rs[\"data\"]]" + ] + }, + { + "cell_type": "markdown", + "id": "2106b5bb", + "metadata": {}, + "source": [ + "## Create the LanceDB Table" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "13f15068", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Building vector index: IVF64,OPQ96, metric=l2\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sample 16384 out of 48935 to train kmeans of 1536 dim, 64 clusters\n" + ] + } + ], + "source": [ + "import lancedb\n", + "from lancedb.embeddings import with_embeddings\n", + "\n", + "data = with_embeddings(embed_func, df, show_progress=True)\n", + "\n", + "db = lancedb.connect(\"/tmp/lancedb\") # current directory\n", + "tbl = db.create_table(\"chatbot\", data)\n", + "tbl.create_index(num_partitions=64, num_sub_vectors=96)" + ] + }, + { + "cell_type": "markdown", + "id": "23afc2f9", + "metadata": {}, + "source": [ + "## Create and answer the prompt" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "06d8b867", + "metadata": {}, + "outputs": [], + "source": [ + "def create_prompt(query, context):\n", + " limit = 3750\n", + "\n", + " prompt_start = (\n", + " \"Answer the question based on the context below.\\n\\n\"+\n", + " \"Context:\\n\"\n", + " )\n", + " prompt_end = (\n", + " f\"\\n\\nQuestion: {query}\\nAnswer:\"\n", + " )\n", + " # append contexts until hitting limit\n", + " for i in range(1, len(context)):\n", + " if len(\"\\n\\n---\\n\\n\".join(context.text[:i])) >= limit:\n", + " prompt = (\n", + " prompt_start +\n", + " \"\\n\\n---\\n\\n\".join(context.text[:i-1]) +\n", + " prompt_end\n", + " )\n", + " break\n", + " elif i == len(context)-1:\n", + " prompt = (\n", + " prompt_start +\n", + " \"\\n\\n---\\n\\n\".join(context.text) +\n", + " prompt_end\n", + " ) \n", + " return prompt" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "e09c5142", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'The 12th person on the moon was Harrison Schmitt, and he landed on December 11, 1972.'" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def complete(prompt):\n", + " # query text-davinci-003\n", + " res = openai.Completion.create(\n", + " engine='text-davinci-003',\n", + " prompt=prompt,\n", + " temperature=0,\n", + " max_tokens=400,\n", + " top_p=1,\n", + " frequency_penalty=0,\n", + " presence_penalty=0,\n", + " stop=None\n", + " )\n", + " return res['choices'][0]['text'].strip()\n", + "\n", + "# check that it works\n", + "query = \"who was the 12th person on the moon and when did they land?\"\n", + "complete(query)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "8fcef773", + "metadata": {}, + "outputs": [], + "source": [ + "def answer(question):\n", + " emb = embed_func(query)[0]\n", + " context = (tbl.search(emb).limit(3)\n", + " .nprobes(20).refine_factor(100)\n", + " .to_df())\n", + " prompt = create_prompt(question, context)\n", + " return complete(prompt), context.reset_index()" + ] + }, + { + "cell_type": "markdown", + "id": "28705959", + "metadata": {}, + "source": [ + "## Show the answer and show the video at the right place" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "25714299", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NLI with multiple negative ranking loss.\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from IPython.display import YouTubeVideo\n", + "\n", + "query = (\"Which training method should I use for sentence transformers \"\n", + " \"when I only have pairs of related sentences?\")\n", + "completion, context = answer(query)\n", + "\n", + "print(completion)\n", + "top_match = context.iloc[0]\n", + "YouTubeVideo(top_match[\"url\"].split(\"/\")[-1], start=top_match[\"start\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "78b7eb11", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/python/lancedb/context.py b/python/lancedb/context.py new file mode 100644 index 00000000..3a2a4c2d --- /dev/null +++ b/python/lancedb/context.py @@ -0,0 +1,61 @@ +# Copyright 2023 LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd + + +def contextualize(raw_df): + return Contextualizer(raw_df) + + +class Contextualizer: + def __init__(self, raw_df): + self._text_col = None + self._groupby = None + self._stride = None + self._window = None + self._raw_df = raw_df + + def window(self, window): + self._window = window + return self + + def stride(self, stride): + self._stride = stride + return self + + def groupby(self, groupby): + self._groupby = groupby + return self + + def text_col(self, text_col): + self._text_col = text_col + return self + + def to_df(self): + def process_group(grp): + # For each video, create the text rolling window + text = grp[self._text_col].values + contexts = grp.iloc[: -self._window : self._stride, :].copy() + contexts[self._text_col] = [ + " ".join(text[start_i : start_i + self._window]) + for start_i in range(0, len(grp) - self._window, self._stride) + ] + return contexts + + if self._groupby is None: + return process_group(self._raw_df) + # concat result from all groups + return pd.concat( + [process_group(grp) for _, grp in self._raw_df.groupby(self._groupby)] + ) diff --git a/python/lancedb/db.py b/python/lancedb/db.py index 3db1c583..d869575b 100644 --- a/python/lancedb/db.py +++ b/python/lancedb/db.py @@ -29,6 +29,7 @@ class LanceDBConnection: if isinstance(uri, str): uri = Path(uri) uri = uri.expanduser().absolute() + Path(uri).mkdir(parents=True, exist_ok=True) self._uri = str(uri) @property diff --git a/python/lancedb/embeddings.py b/python/lancedb/embeddings.py new file mode 100644 index 00000000..404d8c82 --- /dev/null +++ b/python/lancedb/embeddings.py @@ -0,0 +1,105 @@ +# Copyright 2023 LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import ratelimiter +from retry import retry +from typing import Callable, Union + +from lance.vector import vec_to_table +import numpy as np +import pandas as pd +import pyarrow as pa + + +def with_embeddings( + func: Callable, + data: Union[pa.Table, pd.DataFrame], + column: str = "text", + wrap_api: bool = True, + show_progress: bool = False, + batch_size: int = 1000, +): + func = EmbeddingFunction(func) + if wrap_api: + func = func.retry().rate_limit().batch_size(batch_size) + if show_progress: + func = func.show_progress() + if isinstance(data, pd.DataFrame): + data = pa.Table.from_pandas(data) + embeddings = func(data[column].to_numpy()) + table = vec_to_table(np.array(embeddings)) + return data.append_column("vector", table["vector"]) + + +class EmbeddingFunction: + def __init__(self, func: Callable): + self.func = func + self.rate_limiter_kwargs = {} + self.retry_kwargs = {} + self._batch_size = None + self._progress = False + + def __call__(self, text): + # Get the embedding with retry + @retry(**self.retry_kwargs) + def embed_func(c): + return self.func(c.tolist()) + + max_calls = self.rate_limiter_kwargs["max_calls"] + limiter = ratelimiter.RateLimiter( + max_calls, period=self.rate_limiter_kwargs["period"] + ) + rate_limited = limiter(embed_func) + batches = self.to_batches(text) + embeds = [emb for c in batches for emb in rate_limited(c)] + return embeds + + def __repr__(self): + return f"EmbeddingFunction(func={self.func})" + + def rate_limit(self, max_calls=0.9, period=1.0): + self.rate_limiter_kwargs = dict(max_calls=max_calls, period=period) + return self + + def retry(self, tries=10, delay=1, max_delay=30, backoff=3, jitter=1): + self.retry_kwargs = dict( + tries=tries, + delay=delay, + max_delay=max_delay, + backoff=backoff, + jitter=jitter, + ) + return self + + def batch_size(self, batch_size): + self._batch_size = batch_size + return self + + def show_progress(self): + self._progress = True + return self + + def to_batches(self, arr): + length = len(arr) + + def _chunker(arr): + for start_i in range(0, len(arr), self._batch_size): + yield arr[start_i : start_i + self._batch_size] + + if self._progress: + from tqdm.auto import tqdm + + yield from tqdm(_chunker(arr), total=math.ceil(length / self._batch_size)) + else: + return _chunker(arr) diff --git a/python/lancedb/query.py b/python/lancedb/query.py index 14ac2083..21333bec 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -24,6 +24,8 @@ class LanceQueryBuilder: """ def __init__(self, table: "lancedb.table.LanceTable", query: np.ndarray): + self._nprobes = 20 + self._refine_factor = None self._table = table self._query = query self._limit = 10 @@ -75,6 +77,36 @@ class LanceQueryBuilder: self._where = where return self + def nprobes(self, nprobes: int) -> LanceQueryBuilder: + """Set the number of probes to use. + + Parameters + ---------- + nprobes: int + The number of probes to use. + + Returns + ------- + The LanceQueryBuilder object. + """ + self._nprobes = nprobes + return self + + def refine_factor(self, refine_factor: int) -> LanceQueryBuilder: + """Set the refine factor to use. + + Parameters + ---------- + refine_factor: int + The refine factor to use. + + Returns + ------- + The LanceQueryBuilder object. + """ + self._refine_factor = refine_factor + return self + def to_df(self) -> pd.DataFrame: """Execute the query and return the results as a pandas DataFrame.""" ds = self._table.to_lance() @@ -82,6 +114,12 @@ class LanceQueryBuilder: tbl = ds.to_table( columns=self._columns, filter=self._where, - nearest={"column": VECTOR_COLUMN_NAME, "q": self._query, "k": self._limit}, + nearest={ + "column": VECTOR_COLUMN_NAME, + "q": self._query, + "k": self._limit, + "nprobes": self._nprobes, + "refine_factor": self._refine_factor, + }, ) return tbl.to_pandas() diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 7840f396..b5fd6a1f 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -59,6 +59,14 @@ class LanceTable: def _dataset_uri(self) -> str: return os.path.join(self._conn.uri, f"{self.name}.lance") + def create_index(self, num_partitions=256, num_sub_vectors=96): + return self._dataset.create_index( + column=VECTOR_COLUMN_NAME, + index_type="IVF_PQ", + num_partitions=num_partitions, + num_sub_vectors=num_sub_vectors, + ) + @cached_property def _dataset(self) -> LanceDataset: return lance.dataset(self._dataset_uri) diff --git a/python/pyproject.toml b/python/pyproject.toml index 6c1804cf..e2986da3 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "lancedb" version = "0.0.1" -dependencies = ["pylance"] +dependencies = ["pylance", "ratelimiter", "retry", "tqdm"] description = "lancedb" authors = [ { name = "Lance Devs", email = "dev@eto.ai" }, @@ -43,7 +43,7 @@ dev = [ "ruff", "pre-commit", "black" ] docs = [ - "mkdocs", "mkdocs-material", "mkdocstrings[python]" + "mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]" ] [build-system]