Mirror of https://github.com/lancedb/lancedb.git (synced 2025-12-22 21:09:58 +00:00)
Add tutorial notebook
Convert the contextualization and embeddings functionality, and use it in the converted notebook for video transcript search.
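At a glance, the pieces added here fit together roughly like this (a condensed sketch of the notebook below; `raw_df` stands for the transcript DataFrame and `embed_func` for the embedding callable defined there):

```python
import lancedb
from lancedb.context import contextualize
from lancedb.embeddings import with_embeddings

# Merge transcript sentences into overlapping 20-sentence contexts, per video title.
contexts = (contextualize(raw_df)
            .groupby("title").text_col("text")
            .window(20).stride(4)
            .to_df())

# Embed the "text" column and append the result as a "vector" column.
data = with_embeddings(embed_func, contexts, show_progress=True)

db = lancedb.connect("/tmp/lancedb")
tbl = db.create_table("chatbot", data)
tbl.create_index(num_partitions=64, num_sub_vectors=96)

# Query with an embedded question.
hits = tbl.search(embed_func(["what is a vector database?"])[0]).limit(3).to_df()
```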
.gitignore (vendored): 2 lines changed
@@ -13,3 +13,5 @@ site
python/build
python/dist

notebooks/.ipynb_checkpoints
@@ -1,6 +1,6 @@
<div align="center">
<p align="center">
<img width="275" alt="LanceDB Logo" src="https://user-images.githubusercontent.com/917119/226205734-6063d87a-1ecc-45fe-85be-1dea6383a3d8.png">

**Serverless, low-latency vector database for AI applications**
@@ -45,5 +45,5 @@ db = lancedb.connect(uri)
table = db.create_table("my_table",
                        data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
                              {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}])
result = table.search([100, 100]).where("price < 15").limit(1).to_df()
result = table.search([100, 100]).limit(2).to_df()
```
@@ -7,6 +7,7 @@ theme:
plugins:
- search
- mkdocstrings
- mkdocs-jupyter

nav:
- Home: index.md
@@ -1,3 +1,4 @@
mkdocs==1.4.2
mkdocs-jupyter==0.24.1
mkdocs-material==9.1.3
mkdocstrings[python]==0.20.0
@@ -14,7 +14,26 @@ The key features of LanceDB include:

LanceDB's core is written in Rust 🦀 and is built using Lance, an open-source columnar format designed for performant ML workloads.

## Installation

```shell
pip install lancedb
```

## Quickstart

```python
import lancedb

db = lancedb.connect(".")
table = db.create_table("my_table",
                        data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
                              {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}])
result = table.search([100, 100]).limit(2).to_df()
```

## Documentation Quick Links

* `Quick start` - search and filter a hello world vector dataset with LanceDB using the Python SDK.
* [`API Reference`](python.md) - detailed documentation for the LanceDB Python SDK.
notebooks/youtube_transcript_search.ipynb (new file): 418 lines
@@ -0,0 +1,418 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "42bf01fb",
   "metadata": {},
   "source": [
    "# We're going to build a question and answer bot\n",
    "\n",
    "It allows you to search through YouTube transcripts using natural language"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48547ddb",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install --quiet openai datasets lancedb"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22e570f4",
   "metadata": {},
   "source": [
    "## Download the data\n",
    "700 videos and 208619 sentences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a8987fcb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset json (/Users/changshe/.cache/huggingface/datasets/jamescalam___json/jamescalam--youtube-transcriptions-08d889f6a5386b9b/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       " features: ['title', 'published', 'url', 'video_id', 'channel_id', 'id', 'text', 'start', 'end'],\n",
       " num_rows: 208619\n",
       "})"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "data = load_dataset('jamescalam/youtube-transcriptions', split='train')\n",
    "data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ac2b6a3",
   "metadata": {},
   "source": [
    "## Prepare context\n",
    "\n",
    "Create contexts of 20 sentences each"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "121a7087",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       " .dataframe tbody tr th:only-of-type {\n",
       " vertical-align: middle;\n",
       " }\n",
       "\n",
       " .dataframe tbody tr th {\n",
       " vertical-align: top;\n",
       " }\n",
       "\n",
       " .dataframe thead th {\n",
       " text-align: right;\n",
       " }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       " <thead>\n",
       " <tr style=\"text-align: right;\">\n",
       " <th></th>\n",
       " <th>title</th>\n",
       " <th>published</th>\n",
       " <th>url</th>\n",
       " <th>video_id</th>\n",
       " <th>channel_id</th>\n",
       " <th>id</th>\n",
       " <th>text</th>\n",
       " <th>start</th>\n",
       " <th>end</th>\n",
       " </tr>\n",
       " </thead>\n",
       " <tbody>\n",
       " <tr>\n",
       " <th>177622</th>\n",
       " <td>$5 MILLION AI for FREE</td>\n",
       " <td>2022-08-12 15:18:07</td>\n",
       " <td>https://youtu.be/3EjtHs_lXnk</td>\n",
       " <td>3EjtHs_lXnk</td>\n",
       " <td>UCfzlCWGWYyIQ0aLC5w48gBQ</td>\n",
       " <td>3EjtHs_lXnk-t0.0</td>\n",
       " <td>Imagine an AI where all in the same model you ...</td>\n",
       " <td>0.0</td>\n",
       " <td>24.0</td>\n",
       " </tr>\n",
       " </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       " title published \\\n",
       "177622 $5 MILLION AI for FREE 2022-08-12 15:18:07 \n",
       "\n",
       " url video_id channel_id \\\n",
       "177622 https://youtu.be/3EjtHs_lXnk 3EjtHs_lXnk UCfzlCWGWYyIQ0aLC5w48gBQ \n",
       "\n",
       " id text \\\n",
       "177622 3EjtHs_lXnk-t0.0 Imagine an AI where all in the same model you ... \n",
       "\n",
       " start end \n",
       "177622 0.0 24.0 "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from lancedb.context import contextualize\n",
    "\n",
    "df = (contextualize(data.to_pandas())\n",
    "      .groupby(\"title\").text_col(\"text\")\n",
    "      .window(20).stride(4)\n",
    "      .to_df())\n",
    "df.head(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3044e0b0",
   "metadata": {},
   "source": [
    "## Create embedding function\n",
    "We'll call the OpenAI embeddings API to get embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8eefc159",
   "metadata": {},
   "outputs": [],
   "source": [
    "import openai\n",
    "\n",
    "# Configure environment variable OPENAI_API_KEY\n",
    "# OR add variable openai.api_key = \"sk-...\"\n",
    "\n",
    "def embed_func(c):\n",
    "    rs = openai.Embedding.create(input=c, engine=\"text-embedding-ada-002\")\n",
    "    return [record[\"embedding\"] for record in rs[\"data\"]]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2106b5bb",
   "metadata": {},
   "source": [
    "## Create the LanceDB Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "13f15068",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Building vector index: IVF64,OPQ96, metric=l2\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<lance.dataset.LanceDataset at 0x13fd38dc0>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sample 16384 out of 48935 to train kmeans of 1536 dim, 64 clusters\n"
     ]
    }
   ],
   "source": [
    "import lancedb\n",
    "from lancedb.embeddings import with_embeddings\n",
    "\n",
    "data = with_embeddings(embed_func, df, show_progress=True)\n",
    "\n",
    "db = lancedb.connect(\"/tmp/lancedb\")  # LanceDB data will be stored under this directory\n",
    "tbl = db.create_table(\"chatbot\", data)\n",
    "tbl.create_index(num_partitions=64, num_sub_vectors=96)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23afc2f9",
   "metadata": {},
   "source": [
    "## Create and answer the prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "06d8b867",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_prompt(query, context):\n",
    "    limit = 3750\n",
    "\n",
    "    prompt_start = (\n",
    "        \"Answer the question based on the context below.\\n\\n\"+\n",
    "        \"Context:\\n\"\n",
    "    )\n",
    "    prompt_end = (\n",
    "        f\"\\n\\nQuestion: {query}\\nAnswer:\"\n",
    "    )\n",
    "    # append contexts until hitting limit\n",
    "    for i in range(1, len(context)):\n",
    "        if len(\"\\n\\n---\\n\\n\".join(context.text[:i])) >= limit:\n",
    "            prompt = (\n",
    "                prompt_start +\n",
    "                \"\\n\\n---\\n\\n\".join(context.text[:i-1]) +\n",
    "                prompt_end\n",
    "            )\n",
    "            break\n",
    "        elif i == len(context)-1:\n",
    "            prompt = (\n",
    "                prompt_start +\n",
    "                \"\\n\\n---\\n\\n\".join(context.text) +\n",
    "                prompt_end\n",
    "            )\n",
    "    return prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e09c5142",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'The 12th person on the moon was Harrison Schmitt, and he landed on December 11, 1972.'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def complete(prompt):\n",
    "    # query text-davinci-003\n",
    "    res = openai.Completion.create(\n",
    "        engine='text-davinci-003',\n",
    "        prompt=prompt,\n",
    "        temperature=0,\n",
    "        max_tokens=400,\n",
    "        top_p=1,\n",
    "        frequency_penalty=0,\n",
    "        presence_penalty=0,\n",
    "        stop=None\n",
    "    )\n",
    "    return res['choices'][0]['text'].strip()\n",
    "\n",
    "# check that it works\n",
    "query = \"who was the 12th person on the moon and when did they land?\"\n",
    "complete(query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "8fcef773",
   "metadata": {},
   "outputs": [],
   "source": [
    "def answer(question):\n",
    "    emb = embed_func(question)[0]\n",
    "    context = (tbl.search(emb).limit(3)\n",
    "               .nprobes(20).refine_factor(100)\n",
    "               .to_df())\n",
    "    prompt = create_prompt(question, context)\n",
    "    return complete(prompt), context.reset_index()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28705959",
   "metadata": {},
   "source": [
    "## Show the answer and play the video at the right place"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "25714299",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NLI with multiple negative ranking loss.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       " <iframe\n",
       " width=\"400\"\n",
       " height=\"300\"\n",
       " src=\"https://www.youtube.com/embed/pNvujJ1XyeQ?start=289.76\"\n",
       " frameborder=\"0\"\n",
       " allowfullscreen\n",
       " \n",
       " ></iframe>\n",
       " "
      ],
      "text/plain": [
       "<IPython.lib.display.YouTubeVideo at 0x12f58afb0>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython.display import YouTubeVideo\n",
    "\n",
    "query = (\"Which training method should I use for sentence transformers \"\n",
    "         \"when I only have pairs of related sentences?\")\n",
    "completion, context = answer(query)\n",
    "\n",
    "print(completion)\n",
    "top_match = context.iloc[0]\n",
    "YouTubeVideo(top_match[\"url\"].split(\"/\")[-1], start=top_match[\"start\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78b7eb11",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
python/lancedb/context.py (new file): 61 lines
@@ -0,0 +1,61 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pandas as pd


def contextualize(raw_df):
    """Create a Contextualizer builder for the given DataFrame."""
    return Contextualizer(raw_df)


class Contextualizer:
    """Build rolling-window "contexts" from a DataFrame of text rows."""

    def __init__(self, raw_df):
        self._text_col = None
        self._groupby = None
        self._stride = None
        self._window = None
        self._raw_df = raw_df

    def window(self, window):
        """Set the number of rows merged into each context."""
        self._window = window
        return self

    def stride(self, stride):
        """Set the number of rows to advance between consecutive contexts."""
        self._stride = stride
        return self

    def groupby(self, groupby):
        """Set the column used to group rows (e.g. one group per video)."""
        self._groupby = groupby
        return self

    def text_col(self, text_col):
        """Set the column containing the text to merge."""
        self._text_col = text_col
        return self

    def to_df(self):
        def process_group(grp):
            # For each video, create the text rolling window
            text = grp[self._text_col].values
            contexts = grp.iloc[: -self._window : self._stride, :].copy()
            contexts[self._text_col] = [
                " ".join(text[start_i : start_i + self._window])
                for start_i in range(0, len(grp) - self._window, self._stride)
            ]
            return contexts

        if self._groupby is None:
            return process_group(self._raw_df)
        # concat result from all groups
        return pd.concat(
            [process_group(grp) for _, grp in self._raw_df.groupby(self._groupby)]
        )
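A tiny, self-contained illustration of the builder above (the data here is made up):

```python
import pandas as pd
from lancedb.context import contextualize

df = pd.DataFrame({
    "title": ["vid1"] * 6,
    "text": ["s0", "s1", "s2", "s3", "s4", "s5"],
})

# 3-sentence windows, starting every 2 sentences, built per title.
contexts = (contextualize(df)
            .groupby("title").text_col("text")
            .window(3).stride(2)
            .to_df())

print(contexts["text"].tolist())  # ['s0 s1 s2', 's2 s3 s4']
```

Each context row keeps the other columns of its first sentence (for the transcript data, the url and start time), which is what lets the notebook jump back to the right place in the video.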
@@ -29,6 +29,7 @@ class LanceDBConnection:
        if isinstance(uri, str):
            uri = Path(uri)
        uri = uri.expanduser().absolute()
        Path(uri).mkdir(parents=True, exist_ok=True)
        self._uri = str(uri)

    @property
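With the added `mkdir` call, connecting to a local path that does not exist yet just works; for example (the path here is illustrative):

```python
import lancedb

# "~/lancedb-demo" is expanded to an absolute path and created if missing.
db = lancedb.connect("~/lancedb-demo")
```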
python/lancedb/embeddings.py (new file): 105 lines
@@ -0,0 +1,105 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import ratelimiter
from retry import retry
from typing import Callable, Union

from lance.vector import vec_to_table
import numpy as np
import pandas as pd
import pyarrow as pa


def with_embeddings(
    func: Callable,
    data: Union[pa.Table, pd.DataFrame],
    column: str = "text",
    wrap_api: bool = True,
    show_progress: bool = False,
    batch_size: int = 1000,
):
    """Embed `column` with `func` and append the result as a "vector" column."""
    func = EmbeddingFunction(func)
    if wrap_api:
        func = func.retry().rate_limit().batch_size(batch_size)
    if show_progress:
        func = func.show_progress()
    if isinstance(data, pd.DataFrame):
        data = pa.Table.from_pandas(data)
    embeddings = func(data[column].to_numpy())
    table = vec_to_table(np.array(embeddings))
    return data.append_column("vector", table["vector"])


class EmbeddingFunction:
    """Wraps a user embedding callable with retry, rate limiting, and batching."""

    def __init__(self, func: Callable):
        self.func = func
        self.rate_limiter_kwargs = {}
        self.retry_kwargs = {}
        self._batch_size = None
        self._progress = False

    def __call__(self, text):
        # Get the embedding with retry
        @retry(**self.retry_kwargs)
        def embed_func(c):
            return self.func(c.tolist())

        max_calls = self.rate_limiter_kwargs["max_calls"]
        limiter = ratelimiter.RateLimiter(
            max_calls, period=self.rate_limiter_kwargs["period"]
        )
        rate_limited = limiter(embed_func)
        batches = self.to_batches(text)
        embeds = [emb for c in batches for emb in rate_limited(c)]
        return embeds

    def __repr__(self):
        return f"EmbeddingFunction(func={self.func})"

    def rate_limit(self, max_calls=0.9, period=1.0):
        self.rate_limiter_kwargs = dict(max_calls=max_calls, period=period)
        return self

    def retry(self, tries=10, delay=1, max_delay=30, backoff=3, jitter=1):
        self.retry_kwargs = dict(
            tries=tries,
            delay=delay,
            max_delay=max_delay,
            backoff=backoff,
            jitter=jitter,
        )
        return self

    def batch_size(self, batch_size):
        self._batch_size = batch_size
        return self

    def show_progress(self):
        self._progress = True
        return self

    def to_batches(self, arr):
        length = len(arr)

        def _chunker(arr):
            for start_i in range(0, len(arr), self._batch_size):
                yield arr[start_i : start_i + self._batch_size]

        if self._progress:
            from tqdm.auto import tqdm

            yield from tqdm(_chunker(arr), total=math.ceil(length / self._batch_size))
        else:
            # to_batches is a generator, so yield the chunks here as well
            # (a bare `return _chunker(arr)` would end iteration with no batches)
            yield from _chunker(arr)
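A minimal sketch of driving the wrapper with a stand-in embedding function (a real pipeline would pass an API-backed callable such as the OpenAI one used in the notebook):

```python
import pandas as pd
from lancedb.embeddings import with_embeddings

def fake_embed(batch):
    # Stand-in: one fixed-size vector per input string.
    return [[float(len(s)), 1.0] for s in batch]

df = pd.DataFrame({"text": ["hello world", "vector databases are fun"]})

# Wraps fake_embed with retry + rate limiting, batches the "text" column,
# and appends the embeddings as a "vector" column on a pyarrow Table.
table = with_embeddings(fake_embed, df, column="text", show_progress=True)
print(table.schema)
```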
@@ -24,6 +24,8 @@ class LanceQueryBuilder:
    """

    def __init__(self, table: "lancedb.table.LanceTable", query: np.ndarray):
        self._nprobes = 20
        self._refine_factor = None
        self._table = table
        self._query = query
        self._limit = 10
@@ -75,6 +77,36 @@ class LanceQueryBuilder:
        self._where = where
        return self

    def nprobes(self, nprobes: int) -> LanceQueryBuilder:
        """Set the number of probes to use.

        Parameters
        ----------
        nprobes: int
            The number of probes to use.

        Returns
        -------
        The LanceQueryBuilder object.
        """
        self._nprobes = nprobes
        return self

    def refine_factor(self, refine_factor: int) -> LanceQueryBuilder:
        """Set the refine factor to use.

        Parameters
        ----------
        refine_factor: int
            The refine factor to use.

        Returns
        -------
        The LanceQueryBuilder object.
        """
        self._refine_factor = refine_factor
        return self

    def to_df(self) -> pd.DataFrame:
        """Execute the query and return the results as a pandas DataFrame."""
        ds = self._table.to_lance()
@@ -82,6 +114,12 @@ class LanceQueryBuilder:
        tbl = ds.to_table(
            columns=self._columns,
            filter=self._where,
            nearest={"column": VECTOR_COLUMN_NAME, "q": self._query, "k": self._limit},
            nearest={
                "column": VECTOR_COLUMN_NAME,
                "q": self._query,
                "k": self._limit,
                "nprobes": self._nprobes,
                "refine_factor": self._refine_factor,
            },
        )
        return tbl.to_pandas()
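For reference, the new knobs compose with the existing builder methods like this (assuming `tbl` is a LanceTable with an indexed "vector" column and `query_vec` is an embedding of the question, as in the notebook above):

```python
results = (
    tbl.search(query_vec)    # ANN search against the "vector" column
       .limit(3)             # return the top 3 matches
       .nprobes(20)          # probe more IVF partitions for better recall
       .refine_factor(100)   # fetch extra candidates and re-rank for accuracy
       .to_df()              # execute and return a pandas DataFrame
)
```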
@@ -59,6 +59,14 @@ class LanceTable:
    def _dataset_uri(self) -> str:
        return os.path.join(self._conn.uri, f"{self.name}.lance")

    def create_index(self, num_partitions=256, num_sub_vectors=96):
        return self._dataset.create_index(
            column=VECTOR_COLUMN_NAME,
            index_type="IVF_PQ",
            num_partitions=num_partitions,
            num_sub_vectors=num_sub_vectors,
        )

    @cached_property
    def _dataset(self) -> LanceDataset:
        return lance.dataset(self._dataset_uri)
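Usage mirrors the notebook; the sizes below are illustrative, and `num_sub_vectors` should evenly divide the embedding dimension (e.g. 1536 / 96 = 16 for the ada-002 vectors used above):

```python
# Build an IVF_PQ index over the table's "vector" column.
tbl.create_index(
    num_partitions=64,   # number of IVF partitions (coarse clusters)
    num_sub_vectors=96,  # number of PQ sub-vectors per embedding
)
```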
@@ -1,7 +1,7 @@
[project]
name = "lancedb"
version = "0.0.1"
dependencies = ["pylance"]
dependencies = ["pylance", "ratelimiter", "retry", "tqdm"]
description = "lancedb"
authors = [
    { name = "Lance Devs", email = "dev@eto.ai" },
@@ -43,7 +43,7 @@ dev = [
    "ruff", "pre-commit", "black"
]
docs = [
    "mkdocs", "mkdocs-material", "mkdocstrings[python]"
    "mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"
]

[build-system]