Files
lancedb/python/python/lancedb/embeddings/utils.py
Varun Chawla 2802764092 fix(embeddings): stop retrying OpenAI 401 authentication errors (#2995)
## Summary
Fixes #1679

This PR prevents the OpenAI embedding function from retrying when
receiving a 401 Unauthorized error. Authentication errors are permanent
failures that won't be fixed by retrying, yet the current implementation
retries all exceptions up to 7 times by default.

## Changes
- Modified `retry_with_exponential_backoff` in `utils.py` to check for
non-retryable errors before retrying
- Added `_is_non_retryable_error` helper function that detects:
  - Exceptions with name `AuthenticationError` (OpenAI's 401 error)
  - Exceptions with `status_code` attribute of 401 or 403
- Enhanced OpenAI embeddings to explicitly catch and re-raise
`AuthenticationError` with better logging
- Added unit test `test_openai_no_retry_on_401` to verify authentication
errors don't trigger retries

## Test Plan
- Added test that verifies:
  1. A function raising `AuthenticationError` is only called once
  2. No retry delays occur (sleep is never called)
- Existing tests continue to pass
- Formatting applied via `make format`

## Example Behavior

**Before**: With an invalid API key, users would see 7 retry attempts
over ~2 minutes:
```
WARNING:root:Error occurred: Error code: 401 - {'error': {'message': 'Incorrect API key provided...'}}
 Retrying in 3.97 seconds (retry 1 of 7)
WARNING:root:Error occurred: Error code: 401...
 Retrying in 7.94 seconds (retry 2 of 7)
...
```

**After**: With an invalid API key, the error is raised immediately:
```
ERROR:root:Authentication failed: Invalid API key provided
AuthenticationError: Error code: 401 - {'error': {'message': 'Incorrect API key provided...'}}
```

This provides better UX and prevents unnecessary API calls that would
fail anyway.

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
2026-02-19 09:20:54 -08:00

346 lines
9.5 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import functools
import math
import random
import socket
import sys
import threading
import time
import urllib.error
import weakref
import logging
from functools import wraps
from typing import Callable, List, Union
import numpy as np
import pyarrow as pa
from ..dependencies import pandas as pd
from ..util import attempt_import_or_raise
def create_import_stub(module_name: str, package_name: str = None):
    """
    Build a placeholder object for an optional dependency that is missing.

    The placeholder lets a module be imported (for example during doctest
    collection) even though the optional dependency is not installed; the
    failure is deferred until the placeholder itself is called.

    Parameters
    ----------
    module_name : str
        Name of the module being stubbed out.
    package_name : str, optional
        Package name to mention in the error message. When omitted (or
        otherwise falsy), ``module_name`` is used instead.

    Returns
    -------
    object
        A stub instance. Any attribute looked up on it resolves to the stub
        *class*, and calling the instance raises ``ImportError``.
    """
    # The closure inputs never change after this call, so the message can be
    # prepared once up front.
    install_hint = (
        f"You need to install {package_name or module_name} to use this functionality"
    )

    class _ImportStub:
        def __call__(self, *args, **kwargs):
            raise ImportError(install_hint)

        def __getattr__(self, name):
            # Hand back the class itself rather than an instance, so that an
            # attribute such as ``nn.Module`` can be used as a base class.
            return _ImportStub

    return _ImportStub()
# ruff: noqa: PERF203
def retry(tries=10, delay=1, max_delay=30, backoff=3, jitter=1):
    """
    Decorator factory that re-invokes a function when it raises.

    The decorated function is attempted up to ``tries`` times. Any
    ``Exception`` from an attempt other than the last one is swallowed and
    followed by a pause of ``min(delay * backoff**k + jitter, max_delay)``
    seconds, where ``k`` is the zero-based index of the failed attempt. The
    exception from the last attempt propagates unchanged.

    Parameters
    ----------
    tries : int, default 10
        Total number of attempts. With ``tries=0`` the function is never
        called and the wrapper returns ``None``.
    delay : float, default 1
        Base pause in seconds.
    max_delay : float, default 30
        Upper bound applied to every pause.
    backoff : float, default 3
        Multiplier raised to the attempt index.
    jitter : float, default 1
        Constant number of seconds added to every pause (it is not random).

    Returns
    -------
    Callable
        A decorator to apply to the function that should be retried.
    """

    def decorator(fn):
        @wraps(fn)
        def attempt_with_retries(*args, **kwargs):
            final_attempt = tries - 1
            for attempt in range(tries):
                try:
                    return fn(*args, **kwargs)
                except Exception:
                    # Out of attempts: let the caller see the real error.
                    if attempt == final_attempt:
                        raise
                    pause = min(delay * (backoff**attempt) + jitter, max_delay)
                    time.sleep(pause)

        return attempt_with_retries

    return decorator
# Tabular input: a pyarrow Table or a pandas DataFrame. The pandas type is
# given as a string, so it is a forward reference that is not evaluated here.
DATA = Union[pa.Table, "pd.DataFrame"]
# Text input: a single string, a list of strings, a pyarrow (chunked) array,
# or a numpy array.
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
# Image input: one str/bytes value, a list of them, a pyarrow (chunked) array,
# or a numpy array.
# NOTE(review): presumably str values are paths/URIs and bytes are encoded
# image data -- confirm against the callers.
IMAGES = Union[
str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
]
# Audio input: accepts the same set of container types as IMAGES.
AUDIO = Union[str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray]
class RateLimiter:
    """
    Decorator object that limits how often a function is invoked.

    At most ``max_calls`` invocations are let through per ``period`` seconds
    without waiting. An invocation that exceeds the budget sleeps for the time
    remaining in the current window before the wrapped function runs.

    Parameters
    ----------
    max_calls : int, default 1
        Allowed calls per window. The value is floored and clamped to the
        range ``[1, sys.maxsize]``, so fractional values below 1 become 1.
    period : float, default 1.0
        Length of the window in seconds.
    """

    def __init__(self, max_calls: int = 1, period: float = 1.0):
        self.period = period
        self.max_calls = max(1, min(sys.maxsize, math.floor(max_calls)))
        # A monotonic clock is used for the window: wall-clock time can jump
        # (NTP, manual changes), which would distort the computed wait.
        self._last_reset = time.monotonic()
        self._num_calls = 0
        self._lock = threading.RLock()

    def _check_sleep(self) -> float:
        """Record one call and return how many seconds it has to wait."""
        now = time.monotonic()
        period_remaining = self.period - (now - self._last_reset)

        # If the time window has elapsed then start a fresh one.
        if period_remaining <= 0:
            self._num_calls = 0
            self._last_reset = now

        self._num_calls += 1
        # After a reset the count is 1 (never above max_calls), so the value
        # returned here is always the positive remainder of a live window.
        if self._num_calls > self.max_calls:
            return period_remaining
        return 0.0

    def __call__(self, func):
        """Wrap ``func`` so that every call first passes through the limiter."""

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # The wait happens while holding the lock, so callers on other
            # threads queue up behind it; ``func`` itself runs outside the lock.
            with self._lock:
                time.sleep(self._check_sleep())
            return func(*args, **kwargs)

        return wrapper
class FunctionWrapper:
    """
    A wrapper for embedding functions that adds rate limiting, retries, and batching.

    The wrapper is configured through chainable methods (``rate_limit``,
    ``retry``, ``batch_size``, ``show_progress``), each of which returns the
    wrapper itself. Calling the wrapper splits the input into batches, passes
    ``batch.tolist()`` to the wrapped function and concatenates the results.

    Parameters
    ----------
    func : Callable
        The embedding function. It receives one batch as a Python list and
        must return an iterable of embeddings for that batch.
    """

    def __init__(self, func: Callable):
        self.func = func
        self.rate_limiter_kwargs = {}
        self.retry_kwargs = {}
        self._batch_size = None
        self._progress = False

    def __call__(self, text):
        """
        Embed ``text`` batch by batch and return the embeddings as one list.

        ``text`` must support ``len()`` and slicing, and its slices must
        provide ``.tolist()`` (a numpy array, for example).
        """

        def embed_func(c):
            return self.func(c.tolist())

        # Retries are applied first so that the rate limiter (the outer layer)
        # is consulted once per batch rather than once per attempt.
        if self.retry_kwargs:
            embed_func = retry(**self.retry_kwargs)(embed_func)

        if self.rate_limiter_kwargs:
            limiter = RateLimiter(
                max_calls=self.rate_limiter_kwargs["max_calls"],
                period=self.rate_limiter_kwargs["period"],
            )
            embed_func = limiter(embed_func)

        return [emb for batch in self.to_batches(text) for emb in embed_func(batch)]

    def __repr__(self):
        return f"EmbeddingFunction(func={self.func})"

    def rate_limit(self, max_calls=0.9, period=1.0):
        """Enable rate limiting with the given budget; returns ``self``."""
        self.rate_limiter_kwargs = dict(max_calls=max_calls, period=period)
        return self

    def retry(self, tries=10, delay=1, max_delay=30, backoff=3, jitter=1):
        """Enable retries with the given settings; returns ``self``."""
        self.retry_kwargs = dict(
            tries=tries,
            delay=delay,
            max_delay=max_delay,
            backoff=backoff,
            jitter=jitter,
        )
        return self

    def batch_size(self, batch_size):
        """Set the number of items per batch; returns ``self``."""
        self._batch_size = batch_size
        return self

    def show_progress(self):
        """Show a tqdm progress bar while batching; returns ``self``."""
        self._progress = True
        return self

    def to_batches(self, arr):
        """
        Yield consecutive slices of ``arr`` of at most the configured size.

        When no batch size has been configured the whole input is yielded as
        a single batch. An empty input yields no batches.
        """
        length = len(arr)
        # Without a configured size, use the input length as the step; the
        # floor of 1 keeps ``range`` valid when the input is empty.
        step = self._batch_size if self._batch_size is not None else max(length, 1)

        def _chunker(seq):
            for start_i in range(0, length, step):
                yield seq[start_i : start_i + step]

        if self._progress:
            # Imported lazily so tqdm is only needed when progress is requested.
            from tqdm.auto import tqdm

            yield from tqdm(_chunker(arr), total=math.ceil(length / step))
        else:
            yield from _chunker(arr)
def weak_lru(maxsize=128):
    """
    Method decorator providing an LRU cache that does not pin instances.

    The cache key contains a weak reference to the instance instead of the
    instance itself, so cached entries do not keep the object alive. One
    cache, holding at most ``maxsize`` entries, is shared by all instances
    that use the decorated method.

    Parameters
    ----------
    maxsize : int, default 128
        The maximum number of entries kept in the cache.

    Returns
    -------
    Callable
        A decorator that can be applied to a method.

    Examples
    --------
    >>> class Foo:
    ...     @weak_lru()
    ...     def bar(self, x):
    ...         return x
    >>> foo = Foo()
    >>> foo.bar(1)
    1
    >>> foo.bar(2)
    2
    >>> foo.bar(1)
    1
    """

    def decorator(method):
        @functools.lru_cache(maxsize)
        def cached_call(_self, *args, **kwargs):
            # ``_self`` is a weakref here; calling it recovers the instance.
            return method(_self(), *args, **kwargs)

        @functools.wraps(method)
        def entry_point(self, *args, **kwargs):
            instance_ref = weakref.ref(self)
            return cached_call(instance_ref, *args, **kwargs)

        return entry_point

    return decorator
def retry_with_exponential_backoff(
    func,
    initial_delay: float = 1,
    exponential_base: float = 2,
    jitter: bool = True,
    max_retries: int = 7,
):
    """Retry a function with exponential backoff.

    Errors classified as non-retryable by ``_is_non_retryable_error`` are
    re-raised immediately and unchanged. Every other ``Exception`` triggers a
    retry until ``max_retries`` retries have been used up.

    Args:
        func (function): The function to be retried.
        initial_delay (float): Initial delay in seconds (default is 1). The
            delay is multiplied before each sleep, including the first one,
            so the first wait is already longer than ``initial_delay``.
        exponential_base (float): The base for exponential backoff (default is 2).
        jitter (bool): Whether to add jitter to the delay (default is True).
            With jitter, each multiplication factor is scaled by a random
            value in ``[1, 2)``.
        max_retries (int): Maximum number of retries (default is 7). With 0,
            the first retryable failure is reported as exhausted.

    Returns:
        function: The decorated function.

    Raises:
        Exception: When the retry budget is exhausted. Its ``args`` are the
            message and the last underlying exception, which is also set as
            ``__cause__``.
    """

    # Keep the wrapped function's name/docstring visible on the wrapper.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        num_retries = 0
        delay = initial_delay

        # Loop until a successful response or max_retries is hit or an exception
        # is raised
        while True:
            try:
                return func(*args, **kwargs)

            # Apart from the non-retryable errors filtered out below, every
            # exception is retried, as there is no way to know the format of
            # the error msgs used by different APIs. The error is logged on the
            # assumption that it is due to a rate limit, but the user should
            # check the error message to be sure.
            except Exception as e:  # noqa: PERF203
                # Don't retry on authentication errors (e.g., OpenAI 401)
                # These are permanent failures that won't be fixed by retrying
                if _is_non_retryable_error(e):
                    raise

                num_retries += 1

                if num_retries > max_retries:
                    raise Exception(
                        f"Maximum number of retries ({max_retries}) exceeded.", e
                    ) from e

                delay *= exponential_base * (1 + jitter * random.random())
                logging.warning(
                    "Error occurred: %s \n Retrying in %s seconds (retry %s of %s) \n",
                    e,
                    delay,
                    num_retries,
                    max_retries,
                )
                time.sleep(delay)

    return wrapper
def _is_non_retryable_error(error: Exception) -> bool:
"""Check if an error should not be retried.
Args:
error: The exception to check
Returns:
True if the error should not be retried, False otherwise
"""
# Check for OpenAI authentication errors
error_type = type(error).__name__
if error_type == "AuthenticationError":
return True
# Check for other common non-retryable HTTP status codes
# 401 Unauthorized, 403 Forbidden
if hasattr(error, "status_code"):
if error.status_code in (401, 403):
return True
return False
def url_retrieve(url: str):
    """
    Download the resource at ``url`` and return its raw content.

    Parameters
    ----------
    url: str
        URL to download from

    Returns
    -------
    bytes
        The full body read from the opened URL.

    Raises
    ------
    ConnectionError
        If the URL cannot be opened or read because of a name-resolution
        failure or a ``urllib.error.URLError``. The original error is kept
        as ``__cause__``.
    """
    # ``import urllib.error`` does not load the ``urllib.request`` submodule,
    # so import it explicitly instead of relying on another module having
    # loaded it already.
    import urllib.request

    try:
        with urllib.request.urlopen(url) as conn:
            return conn.read()
    except (socket.gaierror, urllib.error.URLError) as err:
        raise ConnectionError(
            "could not download {} due to {}".format(url, err)
        ) from err
def api_key_not_found_help(provider):
    """
    Report a missing API key for ``provider``; this function always raises.

    An error is logged first, then a ``ValueError`` is raised whose message
    names the environment variable to set: the upper-cased provider name
    followed by ``_API_KEY``.

    Parameters
    ----------
    provider : str
        Name of the provider whose key is missing.

    Raises
    ------
    ValueError
        Always.
    """
    logging.error("Could not find API key for %s", provider)
    message = f"Please set the {provider.upper()}_API_KEY environment variable."
    raise ValueError(message)
def is_flash_attn_2_available():
    """
    Report whether the ``flash_attn`` package can be used.

    Returns
    -------
    bool
        ``False`` if ``attempt_import_or_raise`` raises ``ImportError`` for
        ``flash_attn``, ``True`` otherwise.
    """
    # NOTE(review): the helper presumably imports the module and raises
    # ImportError when it is missing -- confirm in ``..util``.
    try:
        attempt_import_or_raise("flash_attn", "flash_attn")
    except ImportError:
        return False
    else:
        return True