This PR adds ColPali support via a `ColPaliEmbeddings` class (registered under the tag "colpali") that uses ColQwen2.5 to produce multi-vector text and image embeddings. It also adds a `MultiVector` Pydantic type to handle the resulting lists of vectors. Integration tests cover the embedding model, and unit tests cover the new Pydantic type. The class could also serve as a template for other ColPali variants, at least until transformers 🤗 supports them natively.

Still `TODO`:

- [ ] Documentation
- [ ] Add an example

_Images as queries could also be supported, but this did not work well in testing._

[ColPali-Engine](https://github.com/illuin-tech/colpali) version: 0.3.9.dev17+g3faee24

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- **New Features**
  - Introduced support for ColPali-based multimodal multi-vector embeddings for both text and images.
  - Added a new embedding class for generating multi-vector embeddings, configurable for various model and processing options.
  - Added a new Pydantic type for multi-vector embeddings, supporting validation and schema generation for lists of fixed-dimension vectors.
- **Bug Fixes**
  - Ensured proper asynchronous index creation in query tests for improved reliability.
- **Tests**
  - Added integration tests for ColPali embeddings, including text-to-image search and validation of multi-vector fields.
  - Added comprehensive tests for the new multi-vector Pydantic type, covering schema, validation, and default value behavior.
- **Chores**
  - Updated optional dependencies to include the ColPali engine.
  - Added a utility to check for availability of flash attention support.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
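A minimal usage sketch of how the new pieces are meant to fit together. The `MultiVector` import path and call signature, and the `SourceField`/`VectorField` wiring, are assumptions for illustration rather than the final API:

```python
# Illustrative sketch only: exact MultiVector signature/import path may differ.
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, MultiVector  # MultiVector added in this PR

colpali = get_registry().get("colpali").create()  # "colpali" tag added in this PR


class Page(LanceModel):
    image_uri: str = colpali.SourceField()
    # One fixed-dimension vector per token/patch -> a list of vectors per row.
    vectors: MultiVector(colpali.ndims()) = colpali.VectorField()


db = lancedb.connect("./colpali-demo")
tbl = db.create_table("pages", schema=Page)
tbl.add([{"image_uri": "report_page_001.png"}])

# Text query embedded with the same model; multi-vector (late-interaction) search.
hits = tbl.search("total revenue in 2023").limit(5).to_list()
```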
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

import functools
import logging
import math
import random
import socket
import sys
import threading
import time
import urllib.error
import urllib.request
import weakref
from functools import wraps
from typing import Callable, List, Union

import numpy as np
import pyarrow as pa

from ..dependencies import pandas as pd
from ..util import attempt_import_or_raise

# ruff: noqa: PERF203


def retry(tries=10, delay=1, max_delay=30, backoff=3, jitter=1):
    def wrapper(fn):
        @wraps(fn)
        def wrapped(*args, **kwargs):
            for i in range(tries):
                try:
                    return fn(*args, **kwargs)
                except Exception:
                    if i + 1 == tries:
                        raise
                    else:
                        sleep = min(delay * (backoff**i) + jitter, max_delay)
                        time.sleep(sleep)

        return wrapped

    return wrapper
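

# Usage sketch (illustrative only, `fetch_page` is a hypothetical caller):
# retry a flaky network call a few times with exponential backoff before
# letting the last exception propagate.
#
#   @retry(tries=3, delay=0.5)
#   def fetch_page(url):
#       return url_retrieve(url)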


DATA = Union[pa.Table, "pd.DataFrame"]
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
IMAGES = Union[
    str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
]
AUDIO = Union[str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray]


class RateLimiter:
    def __init__(self, max_calls: int = 1, period: float = 1.0):
        self.period = period
        self.max_calls = max(1, min(sys.maxsize, math.floor(max_calls)))

        self._last_reset = time.time()
        self._num_calls = 0
        self._lock = threading.RLock()

    def _check_sleep(self) -> float:
        current_time = time.time()
        elapsed = current_time - self._last_reset
        period_remaining = self.period - elapsed

        # If the time window has elapsed then reset.
        if period_remaining <= 0:
            self._num_calls = 0
            self._last_reset = current_time

        self._num_calls += 1

        if self._num_calls > self.max_calls:
            return period_remaining

        return 0.0

    def __call__(self, func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with self._lock:
                time.sleep(self._check_sleep())
            return func(*args, **kwargs)

        return wrapper
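

# Usage sketch (illustrative only, `embed_batch` is a hypothetical callable):
# cap a remote embedding endpoint at 5 calls per second by decorating the
# call site with a RateLimiter instance.
#
#   limiter = RateLimiter(max_calls=5, period=1.0)
#
#   @limiter
#   def embed_batch(batch):
#       ...  # remote API call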


class FunctionWrapper:
    """
    A wrapper for embedding functions that adds rate limiting, retries, and batching.
    """

    def __init__(self, func: Callable):
        self.func = func
        self.rate_limiter_kwargs = {}
        self.retry_kwargs = {}
        self._batch_size = None
        self._progress = False

    def __call__(self, text):
        # Get the embedding with retry
        if len(self.retry_kwargs) > 0:

            @retry(**self.retry_kwargs)
            def embed_func(c):
                return self.func(c.tolist())

        else:

            def embed_func(c):
                return self.func(c.tolist())

        if self.rate_limiter_kwargs:
            limiter = RateLimiter(
                max_calls=self.rate_limiter_kwargs["max_calls"],
                period=self.rate_limiter_kwargs["period"],
            )
            embed_func = limiter(embed_func)
        batches = self.to_batches(text)
        embeds = [emb for c in batches for emb in embed_func(c)]
        return embeds

    def __repr__(self):
        return f"EmbeddingFunction(func={self.func})"

    def rate_limit(self, max_calls=0.9, period=1.0):
        self.rate_limiter_kwargs = dict(max_calls=max_calls, period=period)
        return self

    def retry(self, tries=10, delay=1, max_delay=30, backoff=3, jitter=1):
        self.retry_kwargs = dict(
            tries=tries,
            delay=delay,
            max_delay=max_delay,
            backoff=backoff,
            jitter=jitter,
        )
        return self

    def batch_size(self, batch_size):
        self._batch_size = batch_size
        return self

    def show_progress(self):
        self._progress = True
        return self

    def to_batches(self, arr):
        length = len(arr)

        def _chunker(arr):
            for start_i in range(0, len(arr), self._batch_size):
                yield arr[start_i : start_i + self._batch_size]

        if self._progress:
            from tqdm.auto import tqdm

            yield from tqdm(_chunker(arr), total=math.ceil(length / self._batch_size))
        else:
            yield from _chunker(arr)
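

# Usage sketch (illustrative only, `my_embed` is a hypothetical function that
# maps a list of strings to vectors): chain the fluent configuration methods
# to add batching, retries, and rate limiting around a plain embedding callable.
# Note the input is a numpy array (or pyarrow array) because batches are passed
# to the wrapped function via `.tolist()`.
#
#   wrapped = (
#       FunctionWrapper(my_embed)
#       .batch_size(64)
#       .retry(tries=5)
#       .rate_limit(max_calls=5, period=1.0)
#   )
#   vectors = wrapped(np.array(["hello", "world"]))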


def weak_lru(maxsize=128):
    """
    LRU cache that keeps weak references to the objects it caches. Only caches the
    latest instance of the objects to make sure memory usage is bounded.

    Parameters
    ----------
    maxsize : int, default 128
        The maximum number of objects to cache.

    Returns
    -------
    Callable
        A decorator that can be applied to a method.

    Examples
    --------
    >>> class Foo:
    ...     @weak_lru()
    ...     def bar(self, x):
    ...         return x
    >>> foo = Foo()
    >>> foo.bar(1)
    1
    >>> foo.bar(2)
    2
    >>> foo.bar(1)
    1
    """

    def wrapper(func):
        @functools.lru_cache(maxsize)
        def _func(_self, *args, **kwargs):
            return func(_self(), *args, **kwargs)

        @functools.wraps(func)
        def inner(self, *args, **kwargs):
            return _func(weakref.ref(self), *args, **kwargs)

        return inner

    return wrapper


def retry_with_exponential_backoff(
    func,
    initial_delay: float = 1,
    exponential_base: float = 2,
    jitter: bool = True,
    max_retries: int = 7,
):
    """Retry a function with exponential backoff.

    Args:
        func (function): The function to be retried.
        initial_delay (float): Initial delay in seconds (default is 1).
        exponential_base (float): The base for exponential backoff (default is 2).
        jitter (bool): Whether to add jitter to the delay (default is True).
        max_retries (int): Maximum number of retries (default is 7).

    Returns:
        function: The decorated function.
    """

    def wrapper(*args, **kwargs):
        num_retries = 0
        delay = initial_delay

        # Loop until a successful response or max_retries is hit or an exception
        # is raised
        while True:
            try:
                return func(*args, **kwargs)

            # Currently retrying on all exceptions as there is no way to know the
            # format of the error msgs used by different APIs. We'll log the error
            # and say that it is assumed that if this portion errors out, it's due
            # to rate limit but the user should check the error message to be sure.
            except Exception as e:  # noqa: PERF203
                num_retries += 1

                if num_retries > max_retries:
                    raise Exception(
                        f"Maximum number of retries ({max_retries}) exceeded.", e
                    )

                delay *= exponential_base * (1 + jitter * random.random())
                logging.warning(
                    "Error occurred: %s \n Retrying in %s seconds (retry %s of %s) \n",
                    e,
                    delay,
                    num_retries,
                    max_retries,
                )
                time.sleep(delay)

    return wrapper
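

# Usage sketch (illustrative only, `call_api` and `payload` are hypothetical):
# wrap a rate-limited API call so that each failure roughly multiplies the
# wait by the exponential base (plus jitter) before retrying.
#
#   safe_call = retry_with_exponential_backoff(call_api, max_retries=5)
#   result = safe_call(payload)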


def url_retrieve(url: str):
    """
    Download the contents of the given URL and return them as bytes.

    Parameters
    ----------
    url: str
        URL to download from
    """
    try:
        with urllib.request.urlopen(url) as conn:
            return conn.read()
    except (socket.gaierror, urllib.error.URLError) as err:
        raise ConnectionError("could not download {} due to {}".format(url, err))


def api_key_not_found_help(provider):
    logging.error("Could not find API key for %s", provider)
    raise ValueError(f"Please set the {provider.upper()}_API_KEY environment variable.")


def is_flash_attn_2_available():
    try:
        attempt_import_or_raise("flash_attn", "flash_attn")
        return True
    except ImportError:
        return False
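

# Usage sketch (illustrative only): callers such as the ColPali embedding class
# can select an attention implementation based on this check; `AutoModel` and
# `model_name` below are placeholders, not part of this module.
#
#   attn = "flash_attention_2" if is_flash_attn_2_available() else "eager"
#   model = AutoModel.from_pretrained(model_name, attn_implementation=attn)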