mirror of
https://github.com/lancedb/lancedb.git
synced 2026-03-30 04:20:39 +00:00
## Summary

Fixes #1679

This PR prevents the OpenAI embedding function from retrying when receiving a 401 Unauthorized error. Authentication errors are permanent failures that won't be fixed by retrying, yet the current implementation retries all exceptions up to 7 times by default.

## Changes

- Modified `retry_with_exponential_backoff` in `utils.py` to check for non-retryable errors before retrying
- Added `_is_non_retryable_error` helper function that detects:
  - Exceptions with name `AuthenticationError` (OpenAI's 401 error)
  - Exceptions with `status_code` attribute of 401 or 403
- Enhanced OpenAI embeddings to explicitly catch and re-raise `AuthenticationError` with better logging
- Added unit test `test_openai_no_retry_on_401` to verify authentication errors don't trigger retries

## Test Plan

- Added test that verifies:
  1. A function raising `AuthenticationError` is only called once
  2. No retry delays occur (sleep is never called)
- Existing tests continue to pass
- Formatting applied via `make format`

## Example Behavior

**Before**: With an invalid API key, users would see 7 retry attempts over ~2 minutes:

```
WARNING:root:Error occurred: Error code: 401 - {'error': {'message': 'Incorrect API key provided...'}}
Retrying in 3.97 seconds (retry 1 of 7)
WARNING:root:Error occurred: Error code: 401...
Retrying in 7.94 seconds (retry 2 of 7)
...
```

**After**: With an invalid API key, the error is raised immediately:

```
ERROR:root:Authentication failed: Invalid API key provided
AuthenticationError: Error code: 401 - {'error': {'message': 'Incorrect API key provided...'}}
```

This provides better UX and prevents unnecessary API calls that would fail anyway.

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
346 lines
9.5 KiB
Python
346 lines
9.5 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
|
|
|
|
|
import functools
|
|
import math
|
|
import random
|
|
import socket
|
|
import sys
|
|
import threading
|
|
import time
|
|
import urllib.error
|
|
import weakref
|
|
import logging
|
|
from functools import wraps
|
|
from typing import Callable, List, Union
|
|
import numpy as np
|
|
import pyarrow as pa
|
|
|
|
from ..dependencies import pandas as pd
|
|
from ..util import attempt_import_or_raise
|
|
|
|
|
|
def create_import_stub(module_name: str, package_name: str = None):
    """
    Build a placeholder object standing in for an optional module.

    Attribute access on the placeholder yields the stub *class*, so code such
    as ``class Model(stub.nn.Module)`` can still be defined at import time.
    Calling a stub instance raises ``ImportError`` naming the package to
    install.

    Parameters
    ----------
    module_name : str
        The name of the module to create a stub for
    package_name : str, optional
        The package name to suggest in the error message; falls back to
        ``module_name`` when empty or omitted.

    Returns
    -------
    object
        A stub object that can be used in place of the module
    """
    # Resolve the name shown in the error once, up front.
    pkg = package_name if package_name else module_name

    class _ImportStub:
        def __call__(self, *args, **kwargs):
            raise ImportError(f"You need to install {pkg} to use this functionality")

        def __getattr__(self, name):
            # Hand back the class (not an instance) so the attribute can be
            # used as a base class, e.g. nn.Module.
            return _ImportStub

    return _ImportStub()
|
|
|
|
|
|
# ruff: noqa: PERF203
|
|
def retry(tries=10, delay=1, max_delay=30, backoff=3, jitter=1):
    """
    Build a decorator that re-invokes a function whenever it raises.

    Parameters
    ----------
    tries : int, default 10
        Total number of attempts. Values below 1 are treated as 1 so the
        wrapped function is always called at least once.
    delay : float, default 1
        Base delay in seconds.
    max_delay : float, default 30
        Upper bound on the sleep between two attempts.
    backoff : float, default 3
        Growth factor; the wait before attempt ``i + 1`` is
        ``delay * backoff**i + jitter``, capped at ``max_delay``.
    jitter : float, default 1
        Constant number of seconds added to every wait.

    Returns
    -------
    Callable
        A decorator. The decorated function returns the first successful
        result and re-raises the last exception once all attempts failed.
    """
    # Always make at least one attempt: with a non-positive count the loop
    # body would never run and the call would silently return None.
    attempts = max(1, tries)

    def wrapper(fn):
        @wraps(fn)
        def wrapped(*args, **kwargs):
            for i in range(attempts):
                try:
                    return fn(*args, **kwargs)
                except Exception:
                    # Out of attempts: surface the original exception.
                    if i + 1 == attempts:
                        raise
                    sleep = min(delay * (backoff**i) + jitter, max_delay)
                    time.sleep(sleep)

        return wrapped

    return wrapper
|
|
|
|
|
|
# Alias for a pyarrow Table or a pandas DataFrame. The pandas member is a
# string forward reference, so evaluating this alias does not access `pd`.
DATA = Union[pa.Table, "pd.DataFrame"]
|
|
# Alias for text inputs: a single string, a list of strings, a pyarrow
# Array/ChunkedArray, or a numpy array.
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
|
|
# Alias for image inputs: str or bytes (single or list), a pyarrow
# Array/ChunkedArray, or a numpy array.
# NOTE(review): presumably str values are paths/URLs and bytes are raw image
# data - not visible here, confirm against callers.
IMAGES = Union[
|
|
str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
|
|
]
|
|
# Alias for audio inputs; it lists the same member types as IMAGES.
AUDIO = Union[str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray]
|
|
|
|
|
|
class RateLimiter:
    """
    Decorator that throttles a function using a fixed time window.

    Calls are counted per window of ``period`` seconds. Once more than
    ``max_calls`` calls have been seen in the current window, the caller
    sleeps for the remainder of that window before the wrapped function runs.

    Parameters
    ----------
    max_calls : int, default 1
        Call budget per window. The value is floored and clamped to the range
        ``[1, sys.maxsize]``, so fractional budgets such as ``0.9`` become 1.
    period : float, default 1.0
        Window length in seconds.
    """

    def __init__(self, max_calls: int = 1, period: float = 1.0):
        self.period = period
        self.max_calls = max(1, min(sys.maxsize, math.floor(max_calls)))

        # Use the monotonic clock for window arithmetic: wall-clock time can
        # jump (NTP, manual changes), which would yield a sleep longer than
        # ``period`` or an early reset.
        self._last_reset = time.monotonic()
        self._num_calls = 0
        self._lock = threading.RLock()

    def _check_sleep(self) -> float:
        """
        Register one call and report how long the caller must wait.

        Returns
        -------
        float
            ``0.0`` while the call count is within budget, otherwise the
            number of seconds left in the current window.
        """
        current_time = time.monotonic()
        elapsed = current_time - self._last_reset
        period_remaining = self.period - elapsed

        # If the time window has elapsed then reset.
        if period_remaining <= 0:
            self._num_calls = 0
            self._last_reset = current_time

        self._num_calls += 1

        # After a reset the count is 1 <= max_calls, so a value returned here
        # is always a positive remainder of the current window.
        if self._num_calls > self.max_calls:
            return period_remaining

        return 0.0

    def __call__(self, func):
        """Wrap ``func`` so every call first passes through the limiter."""

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # The lock covers the counter update and the throttle sleep.
            with self._lock:
                time.sleep(self._check_sleep())
            return func(*args, **kwargs)

        return wrapper
|
|
|
|
|
|
class FunctionWrapper:
    """
    A wrapper for embedding functions that adds rate limiting, retries, and batching.

    Configuration uses chained builder methods (``rate_limit``, ``retry``,
    ``batch_size``, ``show_progress``), each of which returns ``self``. Calling
    the wrapper splits the input into batches, passes each batch to ``func``
    as a Python list, and concatenates the per-batch results.
    """

    def __init__(self, func: Callable):
        self.func = func
        # Empty dicts mean "feature disabled".
        self.rate_limiter_kwargs = {}
        self.retry_kwargs = {}
        # None means "no batching": the whole input is sent as one batch.
        self._batch_size = None
        self._progress = False

    def __call__(self, text):
        """
        Embed ``text`` batch by batch.

        Parameters
        ----------
        text
            A sized, sliceable container whose slices expose ``.tolist()``
            (the wrapper calls ``len``, slices it, and calls ``tolist`` on
            each slice).

        Returns
        -------
        list
            The items returned by ``func`` for every batch, in order.
        """

        def embed_func(c):
            return self.func(c.tolist())

        # Retries are applied innermost so a retried batch goes through the
        # rate limiter only once.
        if self.retry_kwargs:
            embed_func = retry(**self.retry_kwargs)(embed_func)

        if self.rate_limiter_kwargs:
            limiter = RateLimiter(
                max_calls=self.rate_limiter_kwargs["max_calls"],
                period=self.rate_limiter_kwargs["period"],
            )
            embed_func = limiter(embed_func)

        return [emb for c in self.to_batches(text) for emb in embed_func(c)]

    def __repr__(self):
        return f"EmbeddingFunction(func={self.func})"

    def rate_limit(self, max_calls=0.9, period=1.0):
        """Enable rate limiting with the given budget per ``period`` seconds."""
        self.rate_limiter_kwargs = dict(max_calls=max_calls, period=period)
        return self

    def retry(self, tries=10, delay=1, max_delay=30, backoff=3, jitter=1):
        """Enable retries; the arguments are forwarded to the ``retry`` decorator."""
        self.retry_kwargs = dict(
            tries=tries,
            delay=delay,
            max_delay=max_delay,
            backoff=backoff,
            jitter=jitter,
        )
        return self

    def batch_size(self, batch_size):
        """
        Set the number of items sent to ``func`` per call.

        Raises
        ------
        ValueError
            If ``batch_size`` is smaller than 1. Zero would fail later inside
            ``range`` and a negative value would silently produce no batches.
        """
        if batch_size is not None and batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")
        self._batch_size = batch_size
        return self

    def show_progress(self):
        """Display a tqdm progress bar while iterating over batches."""
        self._progress = True
        return self

    def to_batches(self, arr):
        """
        Yield consecutive slices of ``arr`` holding at most one batch each.

        When no batch size is configured the whole input is yielded as a
        single batch; empty input yields nothing.
        """
        length = len(arr)
        # Fall back to one batch spanning the input; max(..., 1) keeps the
        # range step valid for empty input.
        size = self._batch_size if self._batch_size is not None else max(length, 1)

        def _chunker(arr):
            for start_i in range(0, length, size):
                yield arr[start_i : start_i + size]

        if self._progress:
            # Imported lazily so tqdm is only needed when progress is shown.
            from tqdm.auto import tqdm

            yield from tqdm(_chunker(arr), total=math.ceil(length / size))
        else:
            yield from _chunker(arr)
|
|
|
|
|
|
def weak_lru(maxsize=128):
    """
    Method decorator providing an LRU cache keyed on a weak reference to
    ``self`` plus the call arguments, so the cache does not hold a strong
    reference to the instance itself.

    Parameters
    ----------
    maxsize : int, default 128
        The maximum number of entries kept by the underlying
        ``functools.lru_cache``.

    Returns
    -------
    Callable
        A decorator that can be applied to a method.

    Examples
    --------
    >>> class Foo:
    ...     @weak_lru()
    ...     def bar(self, x):
    ...         return x
    >>> foo = Foo()
    >>> foo.bar(1)
    1
    >>> foo.bar(2)
    2
    >>> foo.bar(1)
    1
    """

    def wrapper(func):
        def _call_via_ref(ref, *args, **kwargs):
            # Dereference the weak reference to recover the real instance.
            return func(ref(), *args, **kwargs)

        cached = functools.lru_cache(maxsize)(_call_via_ref)

        def inner(self, *args, **kwargs):
            return cached(weakref.ref(self), *args, **kwargs)

        return functools.wraps(func)(inner)

    return wrapper
|
|
|
|
|
|
def retry_with_exponential_backoff(
    func,
    initial_delay: float = 1,
    exponential_base: float = 2,
    jitter: bool = True,
    max_retries: int = 7,
):
    """Retry a function with exponential backoff.

    Args:
        func (function): The function to be retried.
        initial_delay (float): Initial delay in seconds (default is 1). The
            delay is multiplied before every sleep, including the first one.
        exponential_base (float): The base for exponential backoff (default is 2).
        jitter (bool): Whether to add jitter to the delay (default is True).
        max_retries (int): Maximum number of retries (default is 7).

    Returns:
        function: The decorated function.

    Raises:
        Exception: From the returned function, with args
            ``(message, last_error)``, once ``max_retries`` retries have all
            failed. Errors classified as non-retryable by
            ``_is_non_retryable_error`` are re-raised unchanged and at once.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        num_retries = 0
        delay = initial_delay

        # Loop until a successful response or max_retries is hit or an exception
        # is raised
        while True:
            try:
                return func(*args, **kwargs)

            # Every exception except the non-retryable ones is retried, as
            # there is no way to know the format of the error msgs used by
            # different APIs. We'll log the error and assume it is due to a
            # rate limit, but the user should check the error message to be
            # sure.
            except Exception as e:  # noqa: PERF203
                # Don't retry on authentication errors (e.g., OpenAI 401)
                # These are permanent failures that won't be fixed by retrying
                if _is_non_retryable_error(e):
                    raise

                num_retries += 1

                if num_retries > max_retries:
                    # Chain explicitly so the traceback marks the last error
                    # as the direct cause.
                    raise Exception(
                        f"Maximum number of retries ({max_retries}) exceeded.", e
                    ) from e

                # With jitter the growth factor is drawn from
                # [exponential_base, 2 * exponential_base).
                delay *= exponential_base * (1 + jitter * random.random())
                logging.warning(
                    "Error occurred: %s \n Retrying in %s seconds (retry %s of %s) \n",
                    e,
                    delay,
                    num_retries,
                    max_retries,
                )
                time.sleep(delay)

    return wrapper
|
|
|
|
|
|
def _is_non_retryable_error(error: Exception) -> bool:
|
|
"""Check if an error should not be retried.
|
|
|
|
Args:
|
|
error: The exception to check
|
|
|
|
Returns:
|
|
True if the error should not be retried, False otherwise
|
|
"""
|
|
# Check for OpenAI authentication errors
|
|
error_type = type(error).__name__
|
|
if error_type == "AuthenticationError":
|
|
return True
|
|
|
|
# Check for other common non-retryable HTTP status codes
|
|
# 401 Unauthorized, 403 Forbidden
|
|
if hasattr(error, "status_code"):
|
|
if error.status_code in (401, 403):
|
|
return True
|
|
|
|
return False
|
|
|
|
|
|
def url_retrieve(url: str):
    """
    Download the resource at ``url`` and return its raw content.

    Parameters
    ----------
    url: str
        URL to download from

    Returns
    -------
    bytes
        The full response body as returned by ``read()``.

    Raises
    ------
    ConnectionError
        If the host cannot be resolved or ``urlopen`` reports a ``URLError``;
        the original error is attached as the cause.
    """
    # Import the submodule explicitly: ``import urllib.error`` alone does not
    # load ``urllib.request``, so ``urllib.request.urlopen`` could otherwise
    # fail with AttributeError.
    import urllib.request

    try:
        with urllib.request.urlopen(url) as conn:
            return conn.read()
    except (socket.gaierror, urllib.error.URLError) as err:
        raise ConnectionError(
            "could not download {} due to {}".format(url, err)
        ) from err
|
|
|
|
|
|
def api_key_not_found_help(provider):
    """
    Report a missing API key for ``provider``; this function never returns.

    Logs an error, then raises ``ValueError`` naming the environment variable
    ``<PROVIDER>_API_KEY`` (provider name upper-cased) that should be set.

    Parameters
    ----------
    provider : str
        Provider name used in the log line and to build the variable name.

    Raises
    ------
    ValueError
        Always.
    """
    logging.error("Could not find API key for %s", provider)
    message = f"Please set the {provider.upper()}_API_KEY environment variable."
    raise ValueError(message)
|
|
|
|
|
|
def is_flash_attn_2_available():
    """
    Report whether the ``flash_attn`` package can be imported.

    Returns
    -------
    bool
        True when ``attempt_import_or_raise("flash_attn", "flash_attn")``
        succeeds, False when it raises ``ImportError``.
    """
    try:
        attempt_import_or_raise("flash_attn", "flash_attn")
    except ImportError:
        return False
    return True
|