Mirror of https://github.com/lancedb/lancedb.git
Merge pull request #14 from lancedb/changhiskhan/updates
update for release

@@ -31,7 +31,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 2,
+"execution_count": 1,
 "id": "a8987fcb",
 "metadata": {},
 "outputs": [
@@ -51,7 +51,7 @@
 "})"
 ]
 },
-"execution_count": 2,
+"execution_count": 1,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -75,7 +75,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": 2,
 "id": "121a7087",
 "metadata": {},
 "outputs": [
@@ -142,7 +142,7 @@
 "177622 0.0 24.0 "
 ]
 },
-"execution_count": 3,
+"execution_count": 2,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -166,6 +166,24 @@
 "We'll call the OpenAI embeddings API to get embeddings"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": 3,
+"id": "c8104467",
+"metadata": {},
+"outputs": [],
+"source": [
+"import openai\n",
+"import os\n",
+"\n",
+"# Configuring the environment variable OPENAI_API_KEY\n",
+"if \"OPENAI_API_KEY\" not in os.environ:\n",
+" # OR set the key here as a variable\n",
+" openai.api_key = \"sk-...\"\n",
+" \n",
+"assert len(openai.Model.list()[\"data\"]) > 0"
+]
+},
 {
 "cell_type": "code",
 "execution_count": 4,
@@ -173,11 +191,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"import openai\n",
-"\n",
-"# Configure environment variable OPENAI_API_KEY\n",
-"# OR add variable openai.api_key = \"sk-...\"\n",
-"\n",
+"import numpy as np\n",
 "def embed_func(c): \n",
 " rs = openai.Embedding.create(input=c, engine=\"text-embedding-ada-002\")\n",
 " return [record[\"embedding\"] for record in rs[\"data\"]]"
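
A rough, self-contained sketch of the embedding helper these hunks converge on, assuming the pre-1.0 `openai` client (which exposes `openai.Embedding.create`) and an `OPENAI_API_KEY` already configured as in the new setup cell:

    import openai

    # Legacy (openai<1.0) client, as used in the notebook; ada-002 returns one
    # 1536-dimensional vector per input string.
    def embed_func(c):
        rs = openai.Embedding.create(input=c, engine="text-embedding-ada-002")
        return [record["embedding"] for record in rs["data"]]

    vectors = embed_func(["hello world", "lancedb stores vectors"])
    print(len(vectors), len(vectors[0]))  # expect: 2 1536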
@@ -193,33 +207,94 @@
 },
 {
 "cell_type": "code",
-"execution_count": 7,
+"execution_count": 5,
 "id": "13f15068",
-"metadata": {},
+"metadata": {
+"scrolled": false
+},
 "outputs": [
 {
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"Building vector index: IVF64,OPQ96, metric=l2\n"
-]
+"data": {
+"application/vnd.jupyter.widget-view+json": {
+"model_id": "c4fb6f5a4ccc40ddb89d9df497213292",
+"version_major": 2,
+"version_minor": 0
+},
+"text/plain": [
+" 0%| | 0/49 [00:00<?, ?it/s]"
+]
+},
+"metadata": {},
+"output_type": "display_data"
 },
 {
 "data": {
+"text/html": [
+"<div>\n",
+"<style scoped>\n",
+" .dataframe tbody tr th:only-of-type {\n",
+" vertical-align: middle;\n",
+" }\n",
+"\n",
+" .dataframe tbody tr th {\n",
+" vertical-align: top;\n",
+" }\n",
+"\n",
+" .dataframe thead th {\n",
+" text-align: right;\n",
+" }\n",
+"</style>\n",
+"<table border=\"1\" class=\"dataframe\">\n",
+" <thead>\n",
+" <tr style=\"text-align: right;\">\n",
+" <th></th>\n",
+" <th>title</th>\n",
+" <th>published</th>\n",
+" <th>url</th>\n",
+" <th>video_id</th>\n",
+" <th>channel_id</th>\n",
+" <th>id</th>\n",
+" <th>text</th>\n",
+" <th>start</th>\n",
+" <th>end</th>\n",
+" <th>vector</th>\n",
+" </tr>\n",
+" </thead>\n",
+" <tbody>\n",
+" <tr>\n",
+" <th>0</th>\n",
+" <td>$5 MILLION AI for FREE</td>\n",
+" <td>2022-08-12 15:18:07</td>\n",
+" <td>https://youtu.be/3EjtHs_lXnk</td>\n",
+" <td>3EjtHs_lXnk</td>\n",
+" <td>UCfzlCWGWYyIQ0aLC5w48gBQ</td>\n",
+" <td>3EjtHs_lXnk-t0.0</td>\n",
+" <td>Imagine an AI where all in the same model you ...</td>\n",
+" <td>0.0</td>\n",
+" <td>24.0</td>\n",
+" <td>[-0.024402587, -0.00087673456, 0.016499246, -0...</td>\n",
+" </tr>\n",
+" </tbody>\n",
+"</table>\n",
+"</div>"
+],
 "text/plain": [
-"<lance.dataset.LanceDataset at 0x13fd38dc0>"
+" title published url \\\n",
+"0 $5 MILLION AI for FREE 2022-08-12 15:18:07 https://youtu.be/3EjtHs_lXnk \n",
+"\n",
+" video_id channel_id id \\\n",
+"0 3EjtHs_lXnk UCfzlCWGWYyIQ0aLC5w48gBQ 3EjtHs_lXnk-t0.0 \n",
+"\n",
+" text start end \\\n",
+"0 Imagine an AI where all in the same model you ... 0.0 24.0 \n",
+"\n",
+" vector \n",
+"0 [-0.024402587, -0.00087673456, 0.016499246, -0... "
 ]
 },
-"execution_count": 7,
+"execution_count": 5,
 "metadata": {},
 "output_type": "execute_result"
-},
-{
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"Sample 16384 out of 48935 to train kmeans of 1536 dim, 64 clusters\n"
-]
 }
 ],
 "source": [
@@ -227,10 +302,110 @@
 "from lancedb.embeddings import with_embeddings\n",
 "\n",
 "data = with_embeddings(embed_func, df, show_progress=True)\n",
-"\n",
+"data.to_pandas().head(1)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 6,
+"id": "92d53abd",
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"48935"
+]
+},
+"execution_count": 6,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
 "db = lancedb.connect(\"/tmp/lancedb\") # current directory\n",
 "tbl = db.create_table(\"chatbot\", data)\n",
-"tbl.create_index(num_partitions=64, num_sub_vectors=96)"
+"len(tbl)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 7,
+"id": "22892cfd",
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/html": [
+"<div>\n",
+"<style scoped>\n",
+" .dataframe tbody tr th:only-of-type {\n",
+" vertical-align: middle;\n",
+" }\n",
+"\n",
+" .dataframe tbody tr th {\n",
+" vertical-align: top;\n",
+" }\n",
+"\n",
+" .dataframe thead th {\n",
+" text-align: right;\n",
+" }\n",
+"</style>\n",
+"<table border=\"1\" class=\"dataframe\">\n",
+" <thead>\n",
+" <tr style=\"text-align: right;\">\n",
+" <th></th>\n",
+" <th>title</th>\n",
+" <th>published</th>\n",
+" <th>url</th>\n",
+" <th>video_id</th>\n",
+" <th>channel_id</th>\n",
+" <th>id</th>\n",
+" <th>text</th>\n",
+" <th>start</th>\n",
+" <th>end</th>\n",
+" <th>vector</th>\n",
+" </tr>\n",
+" </thead>\n",
+" <tbody>\n",
+" <tr>\n",
+" <th>0</th>\n",
+" <td>$5 MILLION AI for FREE</td>\n",
+" <td>2022-08-12 15:18:07</td>\n",
+" <td>https://youtu.be/3EjtHs_lXnk</td>\n",
+" <td>3EjtHs_lXnk</td>\n",
+" <td>UCfzlCWGWYyIQ0aLC5w48gBQ</td>\n",
+" <td>3EjtHs_lXnk-t0.0</td>\n",
+" <td>Imagine an AI where all in the same model you ...</td>\n",
+" <td>0.0</td>\n",
+" <td>24.0</td>\n",
+" <td>[-0.024402587, -0.00087673456, 0.016499246, -0...</td>\n",
+" </tr>\n",
+" </tbody>\n",
+"</table>\n",
+"</div>"
+],
+"text/plain": [
+" title published url \\\n",
+"0 $5 MILLION AI for FREE 2022-08-12 15:18:07 https://youtu.be/3EjtHs_lXnk \n",
+"\n",
+" video_id channel_id id \\\n",
+"0 3EjtHs_lXnk UCfzlCWGWYyIQ0aLC5w48gBQ 3EjtHs_lXnk-t0.0 \n",
+"\n",
+" text start end \\\n",
+"0 Imagine an AI where all in the same model you ... 0.0 24.0 \n",
+"\n",
+" vector \n",
+"0 [-0.024402587, -0.00087673456, 0.016499246, -0... "
+]
+},
+"execution_count": 7,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"tbl.to_pandas().head(1)"
 ]
 },
 {
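
The hunk above reshapes the ingest flow: embeddings are attached with `with_embeddings`, the table is created, and the notebook now checks `len(tbl)` and previews rows instead of building an `IVF64,OPQ96` index. A condensed sketch of that flow, assuming `embed_func` from the earlier cell and `df` as the transcript DataFrame loaded earlier:

    import lancedb
    from lancedb.embeddings import with_embeddings

    # df: transcript DataFrame from earlier cells; embed_func: the OpenAI helper above.
    data = with_embeddings(embed_func, df, show_progress=True)

    db = lancedb.connect("/tmp/lancedb")
    tbl = db.create_table("chatbot", data)
    print(len(tbl))                  # row count via the new LanceTable.__len__
    print(tbl.to_pandas().head(1))   # preview stored rows, vector column included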
@@ -313,43 +488,76 @@
 "complete(query)"
 ]
 },
-{
-"cell_type": "code",
-"execution_count": 12,
-"id": "8fcef773",
-"metadata": {},
-"outputs": [],
-"source": [
-"def answer(question):\n",
-" emb = embed_func(query)[0]\n",
-" context = (tbl.search(emb).limit(3)\n",
-" .nprobes(20).refine_factor(100)\n",
-" .to_df())\n",
-" prompt = create_prompt(question, context)\n",
-" return complete(prompt), context.reset_index()"
-]
-},
 {
 "cell_type": "markdown",
 "id": "28705959",
 "metadata": {},
 "source": [
-"## Show the answer and show the video at the right place"
+"## Use LanceDB to find the answer and show the video at the right place"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 10,
+"id": "c71f5b31",
+"metadata": {},
+"outputs": [],
+"source": [
+"query = (\"Which training method should I use for sentence transformers \"\n",
+" \"when I only have pairs of related sentences?\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 11,
+"id": "603ba92c",
+"metadata": {},
+"outputs": [],
+"source": [
+"# Embed the question\n",
+"emb = embed_func(query)[0]"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 12,
+"id": "80db5c15",
+"metadata": {},
+"outputs": [],
+"source": [
+"# Use LanceDB to get top 3 most relevant context\n",
+"context = tbl.search(emb).limit(3).to_df()"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": 13,
-"id": "25714299",
+"id": "8fcef773",
 "metadata": {},
 "outputs": [
 {
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"NLI with multiple negative ranking loss.\n"
-]
-},
+"data": {
+"text/plain": [
+"'NLI with multiple negative ranking loss.'"
+]
+},
+"execution_count": 13,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"# Get the answer from completion API\n",
+"prompt = create_prompt(query, context)\n",
+"complete(prompt)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 14,
+"id": "25714299",
+"metadata": {},
+"outputs": [
+{
 "data": {
 "text/html": [
@@ -365,10 +573,10 @@
 " "
 ],
 "text/plain": [
-"<IPython.lib.display.YouTubeVideo at 0x12f58afb0>"
+"<IPython.lib.display.YouTubeVideo at 0x1258aeaa0>"
 ]
 },
-"execution_count": 13,
+"execution_count": 14,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -376,11 +584,6 @@
 "source": [
 "from IPython.display import YouTubeVideo\n",
 "\n",
-"query = (\"Which training method should I use for sentence transformers \"\n",
-" \"when I only have pairs of related sentences?\")\n",
-"completion, context = answer(query)\n",
-"\n",
-"print(completion)\n",
 "top_match = context.iloc[0]\n",
 "YouTubeVideo(top_match[\"url\"].split(\"/\")[-1], start=top_match[\"start\"])"
 ]
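
These hunks drop the `answer()` helper (which searched with `.nprobes(20).refine_factor(100)` against the index) in favour of three explicit steps: embed the question, search, and complete. A minimal end-to-end sketch of the revised flow, assuming `embed_func`, `tbl`, `create_prompt`, and `complete` as defined earlier in the notebook:

    # All names below come from earlier notebook cells (assumption).
    query = ("Which training method should I use for sentence transformers "
             "when I only have pairs of related sentences?")

    emb = embed_func(query)[0]                  # embed the question
    context = tbl.search(emb).limit(3).to_df()  # top-3 most relevant transcript chunks
    prompt = create_prompt(query, context)
    print(complete(prompt))                     # e.g. "NLI with multiple negative ranking loss."

    top_match = context.iloc[0]                 # used below to jump into the source video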

@@ -12,7 +12,6 @@
 # limitations under the License.
 
 import math
-import ratelimiter
 from retry import retry
 from typing import Callable, Union
 
@@ -32,11 +31,12 @@ def with_embeddings(
 ):
     func = EmbeddingFunction(func)
     if wrap_api:
-        func = func.retry().rate_limit().batch_size(batch_size)
+        func = func.retry().rate_limit()
+    func = func.batch_size(batch_size)
     if show_progress:
         func = func.show_progress()
     if isinstance(data, pd.DataFrame):
-        data = pa.Table.from_pandas(data)
+        data = pa.Table.from_pandas(data, preserve_index=False)
     embeddings = func(data[column].to_numpy())
     table = vec_to_table(np.array(embeddings))
     return data.append_column("vector", table["vector"])
@@ -52,23 +52,38 @@ class EmbeddingFunction:
 
     def __call__(self, text):
         # Get the embedding with retry
-        @retry(**self.retry_kwargs)
-        def embed_func(c):
-            return self.func(c.tolist())
+        if len(self.retry_kwargs) > 0:
 
-        max_calls = self.rate_limiter_kwargs["max_calls"]
-        limiter = ratelimiter.RateLimiter(
-            max_calls, period=self.rate_limiter_kwargs["period"]
-        )
-        rate_limited = limiter(embed_func)
+            @retry(**self.retry_kwargs)
+            def embed_func(c):
+                return self.func(c.tolist())
+
+        else:
+
+            def embed_func(c):
+                return self.func(c.tolist())
+
+        if len(self.rate_limiter_kwargs) > 0:
+            import ratelimiter
+
+            max_calls = self.rate_limiter_kwargs["max_calls"]
+            limiter = ratelimiter.RateLimiter(
+                max_calls, period=self.rate_limiter_kwargs["period"]
+            )
+            embed_func = limiter(embed_func)
         batches = self.to_batches(text)
-        embeds = [emb for c in batches for emb in rate_limited(c)]
+        embeds = [emb for c in batches for emb in embed_func(c)]
         return embeds
 
     def __repr__(self):
         return f"EmbeddingFunction(func={self.func})"
 
     def rate_limit(self, max_calls=0.9, period=1.0):
+        import sys
+
+        v = int(sys.version_info.minor)
+        if v >= 11:
+            raise ValueError("rate limit only support up to 3.10")
         self.rate_limiter_kwargs = dict(max_calls=max_calls, period=period)
         return self
 
@@ -102,4 +117,4 @@
 
             yield from tqdm(_chunker(arr), total=math.ceil(length / self._batch_size))
         else:
-            return _chunker(arr)
+            yield from _chunker(arr)
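
In `with_embeddings`, batching is now always applied, while retry and rate limiting are only added under `wrap_api`, with `ratelimiter` imported lazily and `rate_limit()` rejecting Python 3.11+ (the package does not work there). A small usage sketch under those assumptions, with a stand-in embedding function:

    import numpy as np
    import pandas as pd

    from lancedb.embeddings import with_embeddings

    def fake_embed(batch):
        # stand-in for the OpenAI call: one 128-d random vector per input string
        return [np.random.randn(128).tolist() for _ in batch]

    df = pd.DataFrame({"text": ["foo", "bar"]})

    # wrap_api=False skips retry/rate limiting, which also sidesteps the 3.11+ restriction
    data = with_embeddings(fake_embed, df, wrap_api=False)
    print(data.column_names)  # ['text', 'vector']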

@@ -55,6 +55,27 @@ class LanceTable:
         """Return the schema of the table."""
         return self._dataset.schema
 
+    def __len__(self):
+        return self._dataset.count_rows()
+
+    def __repr__(self) -> str:
+        return f"LanceTable({self.name})"
+
+    def __str__(self) -> str:
+        return self.__repr__()
+
+    def head(self, n=5) -> pa.Table:
+        """Return the first n rows of the table."""
+        return self._dataset.head(n)
+
+    def to_pandas(self) -> pd.DataFrame:
+        """Return the table as a pandas DataFrame."""
+        return self.to_arrow().to_pandas()
+
+    def to_arrow(self) -> pa.Table:
+        """Return the table as a pyarrow Table."""
+        return self._dataset.to_table()
+
     @property
     def _dataset_uri(self) -> str:
         return os.path.join(self._conn.uri, f"{self.name}.lance")
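
The new `LanceTable` methods above are thin wrappers over the underlying Lance dataset. A quick sketch of how they read back a table created as in the notebook (table name and path taken from the example; `data` as produced by `with_embeddings`):

    import lancedb

    db = lancedb.connect("/tmp/lancedb")
    tbl = db.create_table("chatbot", data)  # data: Arrow table with a "vector" column

    print(len(tbl))         # __len__ -> count_rows()
    print(tbl)              # __repr__ -> LanceTable(chatbot)
    first5 = tbl.head()     # pyarrow.Table with the first 5 rows
    pdf = tbl.to_pandas()   # whole table as a pandas DataFrame (via to_arrow())
    arrow = tbl.to_arrow()  # whole table as a pyarrow.Table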

python/tests/test_embeddings.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+# Copyright 2023 LanceDB Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+import numpy as np
+import pyarrow as pa
+
+from lancedb.embeddings import with_embeddings
+
+
+def mock_embed_func(input_data):
+    return [np.random.randn(128).tolist() for _ in range(len(input_data))]
+
+
+def test_with_embeddings():
+    for wrap_api in [True, False]:
+        if wrap_api and sys.version_info.minor >= 11:
+            # ratelimiter package doesn't work on 3.11
+            continue
+        data = pa.Table.from_arrays(
+            [
+                pa.array(["foo", "bar"]),
+                pa.array([10.0, 20.0]),
+            ],
+            names=["text", "price"],
+        )
+        data = with_embeddings(mock_embed_func, data, wrap_api=wrap_api)
+        assert data.num_columns == 3
+        assert data.num_rows == 2
+        assert data.column_names == ["text", "price", "vector"]
+        assert data.column("text").to_pylist() == ["foo", "bar"]
+        assert data.column("price").to_pylist() == [10.0, 20.0]
||||||
Reference in New Issue
Block a user