mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-06 20:02:58 +00:00
Add tutorial notebook
Convert contextualization and embeddings functionality. And use it with converted notebook for video search
This commit is contained in:
61
python/lancedb/context.py
Normal file
61
python/lancedb/context.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def contextualize(raw_df):
|
||||
return Contextualizer(raw_df)
|
||||
|
||||
|
||||
class Contextualizer:
|
||||
def __init__(self, raw_df):
|
||||
self._text_col = None
|
||||
self._groupby = None
|
||||
self._stride = None
|
||||
self._window = None
|
||||
self._raw_df = raw_df
|
||||
|
||||
def window(self, window):
|
||||
self._window = window
|
||||
return self
|
||||
|
||||
def stride(self, stride):
|
||||
self._stride = stride
|
||||
return self
|
||||
|
||||
def groupby(self, groupby):
|
||||
self._groupby = groupby
|
||||
return self
|
||||
|
||||
def text_col(self, text_col):
|
||||
self._text_col = text_col
|
||||
return self
|
||||
|
||||
def to_df(self):
|
||||
def process_group(grp):
|
||||
# For each video, create the text rolling window
|
||||
text = grp[self._text_col].values
|
||||
contexts = grp.iloc[: -self._window : self._stride, :].copy()
|
||||
contexts[self._text_col] = [
|
||||
" ".join(text[start_i : start_i + self._window])
|
||||
for start_i in range(0, len(grp) - self._window, self._stride)
|
||||
]
|
||||
return contexts
|
||||
|
||||
if self._groupby is None:
|
||||
return process_group(self._raw_df)
|
||||
# concat result from all groups
|
||||
return pd.concat(
|
||||
[process_group(grp) for _, grp in self._raw_df.groupby(self._groupby)]
|
||||
)
|
||||
@@ -29,6 +29,7 @@ class LanceDBConnection:
|
||||
if isinstance(uri, str):
|
||||
uri = Path(uri)
|
||||
uri = uri.expanduser().absolute()
|
||||
Path(uri).mkdir(parents=True, exist_ok=True)
|
||||
self._uri = str(uri)
|
||||
|
||||
@property
|
||||
|
||||
105
python/lancedb/embeddings.py
Normal file
105
python/lancedb/embeddings.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import ratelimiter
|
||||
from retry import retry
|
||||
from typing import Callable, Union
|
||||
|
||||
from lance.vector import vec_to_table
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
|
||||
|
||||
def with_embeddings(
|
||||
func: Callable,
|
||||
data: Union[pa.Table, pd.DataFrame],
|
||||
column: str = "text",
|
||||
wrap_api: bool = True,
|
||||
show_progress: bool = False,
|
||||
batch_size: int = 1000,
|
||||
):
|
||||
func = EmbeddingFunction(func)
|
||||
if wrap_api:
|
||||
func = func.retry().rate_limit().batch_size(batch_size)
|
||||
if show_progress:
|
||||
func = func.show_progress()
|
||||
if isinstance(data, pd.DataFrame):
|
||||
data = pa.Table.from_pandas(data)
|
||||
embeddings = func(data[column].to_numpy())
|
||||
table = vec_to_table(np.array(embeddings))
|
||||
return data.append_column("vector", table["vector"])
|
||||
|
||||
|
||||
class EmbeddingFunction:
|
||||
def __init__(self, func: Callable):
|
||||
self.func = func
|
||||
self.rate_limiter_kwargs = {}
|
||||
self.retry_kwargs = {}
|
||||
self._batch_size = None
|
||||
self._progress = False
|
||||
|
||||
def __call__(self, text):
|
||||
# Get the embedding with retry
|
||||
@retry(**self.retry_kwargs)
|
||||
def embed_func(c):
|
||||
return self.func(c.tolist())
|
||||
|
||||
max_calls = self.rate_limiter_kwargs["max_calls"]
|
||||
limiter = ratelimiter.RateLimiter(
|
||||
max_calls, period=self.rate_limiter_kwargs["period"]
|
||||
)
|
||||
rate_limited = limiter(embed_func)
|
||||
batches = self.to_batches(text)
|
||||
embeds = [emb for c in batches for emb in rate_limited(c)]
|
||||
return embeds
|
||||
|
||||
def __repr__(self):
|
||||
return f"EmbeddingFunction(func={self.func})"
|
||||
|
||||
def rate_limit(self, max_calls=0.9, period=1.0):
|
||||
self.rate_limiter_kwargs = dict(max_calls=max_calls, period=period)
|
||||
return self
|
||||
|
||||
def retry(self, tries=10, delay=1, max_delay=30, backoff=3, jitter=1):
|
||||
self.retry_kwargs = dict(
|
||||
tries=tries,
|
||||
delay=delay,
|
||||
max_delay=max_delay,
|
||||
backoff=backoff,
|
||||
jitter=jitter,
|
||||
)
|
||||
return self
|
||||
|
||||
def batch_size(self, batch_size):
|
||||
self._batch_size = batch_size
|
||||
return self
|
||||
|
||||
def show_progress(self):
|
||||
self._progress = True
|
||||
return self
|
||||
|
||||
def to_batches(self, arr):
|
||||
length = len(arr)
|
||||
|
||||
def _chunker(arr):
|
||||
for start_i in range(0, len(arr), self._batch_size):
|
||||
yield arr[start_i : start_i + self._batch_size]
|
||||
|
||||
if self._progress:
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
yield from tqdm(_chunker(arr), total=math.ceil(length / self._batch_size))
|
||||
else:
|
||||
return _chunker(arr)
|
||||
@@ -24,6 +24,8 @@ class LanceQueryBuilder:
|
||||
"""
|
||||
|
||||
def __init__(self, table: "lancedb.table.LanceTable", query: np.ndarray):
|
||||
self._nprobes = 20
|
||||
self._refine_factor = None
|
||||
self._table = table
|
||||
self._query = query
|
||||
self._limit = 10
|
||||
@@ -75,6 +77,36 @@ class LanceQueryBuilder:
|
||||
self._where = where
|
||||
return self
|
||||
|
||||
def nprobes(self, nprobes: int) -> LanceQueryBuilder:
|
||||
"""Set the number of probes to use.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
nprobes: int
|
||||
The number of probes to use.
|
||||
|
||||
Returns
|
||||
-------
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._nprobes = nprobes
|
||||
return self
|
||||
|
||||
def refine_factor(self, refine_factor: int) -> LanceQueryBuilder:
|
||||
"""Set the refine factor to use.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
refine_factor: int
|
||||
The refine factor to use.
|
||||
|
||||
Returns
|
||||
-------
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._refine_factor = refine_factor
|
||||
return self
|
||||
|
||||
def to_df(self) -> pd.DataFrame:
|
||||
"""Execute the query and return the results as a pandas DataFrame."""
|
||||
ds = self._table.to_lance()
|
||||
@@ -82,6 +114,12 @@ class LanceQueryBuilder:
|
||||
tbl = ds.to_table(
|
||||
columns=self._columns,
|
||||
filter=self._where,
|
||||
nearest={"column": VECTOR_COLUMN_NAME, "q": self._query, "k": self._limit},
|
||||
nearest={
|
||||
"column": VECTOR_COLUMN_NAME,
|
||||
"q": self._query,
|
||||
"k": self._limit,
|
||||
"nprobes": self._nprobes,
|
||||
"refine_factor": self._refine_factor,
|
||||
},
|
||||
)
|
||||
return tbl.to_pandas()
|
||||
|
||||
@@ -59,6 +59,14 @@ class LanceTable:
|
||||
def _dataset_uri(self) -> str:
|
||||
return os.path.join(self._conn.uri, f"{self.name}.lance")
|
||||
|
||||
def create_index(self, num_partitions=256, num_sub_vectors=96):
|
||||
return self._dataset.create_index(
|
||||
column=VECTOR_COLUMN_NAME,
|
||||
index_type="IVF_PQ",
|
||||
num_partitions=num_partitions,
|
||||
num_sub_vectors=num_sub_vectors,
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def _dataset(self) -> LanceDataset:
|
||||
return lance.dataset(self._dataset_uri)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[project]
|
||||
name = "lancedb"
|
||||
version = "0.0.1"
|
||||
dependencies = ["pylance"]
|
||||
dependencies = ["pylance", "ratelimiter", "retry", "tqdm"]
|
||||
description = "lancedb"
|
||||
authors = [
|
||||
{ name = "Lance Devs", email = "dev@eto.ai" },
|
||||
@@ -43,7 +43,7 @@ dev = [
|
||||
"ruff", "pre-commit", "black"
|
||||
]
|
||||
docs = [
|
||||
"mkdocs", "mkdocs-material", "mkdocstrings[python]"
|
||||
"mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"
|
||||
]
|
||||
|
||||
[build-system]
|
||||
|
||||
Reference in New Issue
Block a user