lancedb/python/python/lancedb/embeddings/utils.py
Magnus 4f07fea6df feat: add ColPali embedding support with MultiVector type (#2170)
This PR adds ColPali support via a ColPaliEmbeddings class (tagged
"colpali") that uses ColQwen2.5 to produce multi-vector text/image embeddings. It
also adds a MultiVector Pydantic type to handle the resulting lists of vectors.
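
A minimal sketch of how the new pieces might fit together (illustrative only; the `"colpali"` registry tag and the `MultiVector` name come from this PR, while the import path of `MultiVector` and the exact field wiring are assumptions):

```python
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, MultiVector  # MultiVector path assumed

# Create the ColPali embedding function from the registry (tag from this PR).
colpali = get_registry().get("colpali").create()


# A table schema with a multi-vector field; ndims() is assumed to report the
# per-vector dimension of the ColQwen2.5 embeddings.
class Page(LanceModel):
    image_uri: str = colpali.SourceField()
    embedding: MultiVector(colpali.ndims()) = colpali.VectorField()
```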

I've added integration tests for the embedding model and unit tests for the
new Pydantic type. It could serve as a template for other ColPali variants as
well, at least until transformers🤗 starts supporting them natively.


Still `TODO`:

- [ ] Documentation
- [ ] Add an example

_Images could also be allowed as queries, but this didn't work well in testing._

[ColPali-Engine](https://github.com/illuin-tech/colpali) version:
0.3.9.dev17+g3faee24

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
  - Introduced support for ColPali-based multimodal multi-vector embeddings for both text and images.
  - Added a new embedding class for generating multi-vector embeddings, configurable for various model and processing options.
  - Added a new Pydantic type for multi-vector embeddings, supporting validation and schema generation for lists of fixed-dimension vectors.

- **Bug Fixes**
  - Ensured proper asynchronous index creation in query tests for improved reliability.

- **Tests**
  - Added integration tests for ColPali embeddings, including text-to-image search and validation of multi-vector fields.
  - Added comprehensive tests for the new multi-vector Pydantic type, covering schema, validation, and default value behavior.

- **Chores**
  - Updated optional dependencies to include the ColPali engine.
  - Added utility to check for availability of flash attention support.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2025-04-21 11:47:37 +08:00


# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

import functools
import math
import random
import socket
import sys
import threading
import time
import urllib.error
import urllib.request
import weakref
import logging
from functools import wraps
from typing import Callable, List, Union

import numpy as np
import pyarrow as pa

from ..dependencies import pandas as pd
from ..util import attempt_import_or_raise

# ruff: noqa: PERF203

def retry(tries=10, delay=1, max_delay=30, backoff=3, jitter=1):
    def wrapper(fn):
        @wraps(fn)
        def wrapped(*args, **kwargs):
            for i in range(tries):
                try:
                    return fn(*args, **kwargs)
                except Exception:
                    if i + 1 == tries:
                        raise
                    else:
                        sleep = min(delay * (backoff**i) + jitter, max_delay)
                        time.sleep(sleep)

        return wrapped

    return wrapper
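
# Illustrative usage of the retry decorator (example only, not part of the
# original module): wrap a flaky callable so each failure is retried with
# exponential backoff before the final attempt re-raises.
#
#     @retry(tries=3, delay=0.5, max_delay=5, backoff=2, jitter=0)
#     def fetch_batch(batch):
#         return call_remote_embedding_api(batch)  # hypothetical remote call
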
DATA = Union[pa.Table, "pd.DataFrame"]
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
IMAGES = Union[
    str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
]
AUDIO = Union[str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray]

class RateLimiter:
    def __init__(self, max_calls: int = 1, period: float = 1.0):
        self.period = period
        self.max_calls = max(1, min(sys.maxsize, math.floor(max_calls)))
        self._last_reset = time.time()
        self._num_calls = 0
        self._lock = threading.RLock()

    def _check_sleep(self) -> float:
        current_time = time.time()
        elapsed = current_time - self._last_reset
        period_remaining = self.period - elapsed

        # If the time window has elapsed then reset.
        if period_remaining <= 0:
            self._num_calls = 0
            self._last_reset = current_time

        self._num_calls += 1

        if self._num_calls > self.max_calls:
            return period_remaining

        return 0.0

    def __call__(self, func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with self._lock:
                time.sleep(self._check_sleep())
            return func(*args, **kwargs)

        return wrapper
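
# Illustrative usage of RateLimiter (example only, not part of the original
# module): allow at most 5 calls per second to a hypothetical endpoint by
# using the instance as a decorator.
#
#     limiter = RateLimiter(max_calls=5, period=1.0)
#
#     @limiter
#     def embed(batch):
#         return call_embedding_api(batch)  # hypothetical remote call
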

class FunctionWrapper:
    """
    A wrapper for embedding functions that adds rate limiting, retries, and batching.
    """

    def __init__(self, func: Callable):
        self.func = func
        self.rate_limiter_kwargs = {}
        self.retry_kwargs = {}
        self._batch_size = None
        self._progress = False

    def __call__(self, text):
        # Get the embedding with retry
        if len(self.retry_kwargs) > 0:

            @retry(**self.retry_kwargs)
            def embed_func(c):
                return self.func(c.tolist())
        else:

            def embed_func(c):
                return self.func(c.tolist())

        if self.rate_limiter_kwargs:
            limiter = RateLimiter(
                max_calls=self.rate_limiter_kwargs["max_calls"],
                period=self.rate_limiter_kwargs["period"],
            )
            embed_func = limiter(embed_func)
        batches = self.to_batches(text)
        embeds = [emb for c in batches for emb in embed_func(c)]
        return embeds

    def __repr__(self):
        return f"EmbeddingFunction(func={self.func})"

    def rate_limit(self, max_calls=0.9, period=1.0):
        self.rate_limiter_kwargs = dict(max_calls=max_calls, period=period)
        return self

    def retry(self, tries=10, delay=1, max_delay=30, backoff=3, jitter=1):
        self.retry_kwargs = dict(
            tries=tries,
            delay=delay,
            max_delay=max_delay,
            backoff=backoff,
            jitter=jitter,
        )
        return self

    def batch_size(self, batch_size):
        self._batch_size = batch_size
        return self

    def show_progress(self):
        self._progress = True
        return self

    def to_batches(self, arr):
        length = len(arr)

        def _chunker(arr):
            for start_i in range(0, len(arr), self._batch_size):
                yield arr[start_i : start_i + self._batch_size]

        if self._progress:
            from tqdm.auto import tqdm

            yield from tqdm(_chunker(arr), total=math.ceil(length / self._batch_size))
        else:
            yield from _chunker(arr)
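
# Illustrative usage of FunctionWrapper (example only, not part of the
# original module): wrap a raw embedding callable with batching, a rate
# limit, and retries via the fluent interface, then call it on an array.
#
#     wrapped = (
#         FunctionWrapper(embed_texts)  # embed_texts: hypothetical callable
#         .batch_size(64)
#         .rate_limit(max_calls=5, period=1.0)
#         .retry(tries=5)
#     )
#     vectors = wrapped(np.array(["hello", "world"]))
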

def weak_lru(maxsize=128):
    """
    LRU cache that keeps weak references to the objects it caches. Only caches the
    latest instance of the objects to make sure memory usage is bounded.

    Parameters
    ----------
    maxsize : int, default 128
        The maximum number of objects to cache.

    Returns
    -------
    Callable
        A decorator that can be applied to a method.

    Examples
    --------
    >>> class Foo:
    ...     @weak_lru()
    ...     def bar(self, x):
    ...         return x
    >>> foo = Foo()
    >>> foo.bar(1)
    1
    >>> foo.bar(2)
    2
    >>> foo.bar(1)
    1
    """

    def wrapper(func):
        @functools.lru_cache(maxsize)
        def _func(_self, *args, **kwargs):
            return func(_self(), *args, **kwargs)

        @functools.wraps(func)
        def inner(self, *args, **kwargs):
            return _func(weakref.ref(self), *args, **kwargs)

        return inner

    return wrapper

def retry_with_exponential_backoff(
    func,
    initial_delay: float = 1,
    exponential_base: float = 2,
    jitter: bool = True,
    max_retries: int = 7,
):
    """Retry a function with exponential backoff.

    Args:
        func (function): The function to be retried.
        initial_delay (float): Initial delay in seconds (default is 1).
        exponential_base (float): The base for exponential backoff (default is 2).
        jitter (bool): Whether to add jitter to the delay (default is True).
        max_retries (int): Maximum number of retries (default is 7).

    Returns:
        function: The decorated function.
    """

    def wrapper(*args, **kwargs):
        num_retries = 0
        delay = initial_delay

        # Loop until a successful response or max_retries is hit or an exception
        # is raised
        while True:
            try:
                return func(*args, **kwargs)

            # Currently retrying on all exceptions as there is no way to know the
            # format of the error msgs used by different APIs. We'll log the error
            # and say that it is assumed that if this portion errors out, it's due
            # to rate limit but the user should check the error message to be sure.
            except Exception as e:  # noqa: PERF203
                num_retries += 1

                if num_retries > max_retries:
                    raise Exception(
                        f"Maximum number of retries ({max_retries}) exceeded.", e
                    )

                delay *= exponential_base * (1 + jitter * random.random())
                logging.warning(
                    "Error occurred: %s \n Retrying in %s seconds (retry %s of %s) \n",
                    e,
                    delay,
                    num_retries,
                    max_retries,
                )

                time.sleep(delay)

    return wrapper
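
# Illustrative usage of retry_with_exponential_backoff (example only, not part
# of the original module): it wraps an existing callable directly rather than
# acting as a decorator factory.
#
#     safe_call = retry_with_exponential_backoff(call_embedding_api, max_retries=3)
#     result = safe_call(["some text"])  # call_embedding_api is hypothetical
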

def url_retrieve(url: str):
    """
    Parameters
    ----------
    url: str
        URL to download from
    """
    try:
        with urllib.request.urlopen(url) as conn:
            return conn.read()
    except (socket.gaierror, urllib.error.URLError) as err:
        raise ConnectionError("could not download {} due to {}".format(url, err))


def api_key_not_found_help(provider):
    logging.error("Could not find API key for %s", provider)
    raise ValueError(f"Please set the {provider.upper()}_API_KEY environment variable.")

def is_flash_attn_2_available():
    try:
        attempt_import_or_raise("flash_attn", "flash_attn")
        return True
    except ImportError:
        return False
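
# Illustrative usage of is_flash_attn_2_available (example only, not part of
# the original module): choose an attention implementation when loading a
# Hugging Face model, e.g. the ColPali / ColQwen2.5 checkpoints used by the
# new embedding class.
#
#     attn = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
#     model = AutoModel.from_pretrained(model_name, attn_implementation=attn)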