mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-14 02:20:40 +00:00
fix(embeddings): stop retrying OpenAI 401 authentication errors (#2995)
## Summary

Fixes #1679

This PR prevents the OpenAI embedding function from retrying when receiving a 401 Unauthorized error. Authentication errors are permanent failures that won't be fixed by retrying, yet the current implementation retries all exceptions up to 7 times by default.

## Changes

- Modified `retry_with_exponential_backoff` in `utils.py` to check for non-retryable errors before retrying
- Added `_is_non_retryable_error` helper function that detects:
  - Exceptions with name `AuthenticationError` (OpenAI's 401 error)
  - Exceptions with `status_code` attribute of 401 or 403
- Enhanced OpenAI embeddings to explicitly catch and re-raise `AuthenticationError` with better logging
- Added unit test `test_openai_no_retry_on_401` to verify authentication errors don't trigger retries

## Test Plan

- Added test that verifies:
  1. A function raising `AuthenticationError` is only called once
  2. No retry delays occur (sleep is never called)
- Existing tests continue to pass
- Formatting applied via `make format`

## Example Behavior

**Before**: With an invalid API key, users would see 7 retry attempts over ~2 minutes:

```
WARNING:root:Error occurred: Error code: 401 - {'error': {'message': 'Incorrect API key provided...'}} Retrying in 3.97 seconds (retry 1 of 7)
WARNING:root:Error occurred: Error code: 401... Retrying in 7.94 seconds (retry 2 of 7)
...
```

**After**: With an invalid API key, the error is raised immediately:

```
ERROR:root:Authentication failed: Invalid API key provided
AuthenticationError: Error code: 401 - {'error': {'message': 'Incorrect API key provided...'}}
```

This provides better UX and prevents unnecessary API calls that would fail anyway.

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
@@ -110,6 +110,9 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
|
||||
valid_embeddings = {
|
||||
idx: v.embedding for v, idx in zip(rs.data, valid_indices)
|
||||
}
|
||||
except openai.AuthenticationError:
|
||||
logging.error("Authentication failed: Invalid API key provided")
|
||||
raise
|
||||
except openai.BadRequestError:
|
||||
logging.exception("Bad request: %s", texts)
|
||||
return [None] * len(texts)
|
||||
|
||||
@@ -269,6 +269,11 @@ def retry_with_exponential_backoff(
|
||||
# and say that it is assumed that if this portion errors out, it's due
|
||||
# to rate limit but the user should check the error message to be sure.
|
||||
except Exception as e: # noqa: PERF203
|
||||
# Don't retry on authentication errors (e.g., OpenAI 401)
|
||||
# These are permanent failures that won't be fixed by retrying
|
||||
if _is_non_retryable_error(e):
|
||||
raise
|
||||
|
||||
num_retries += 1
|
||||
|
||||
if num_retries > max_retries:
|
||||
@@ -289,6 +294,29 @@ def retry_with_exponential_backoff(
|
||||
return wrapper
|
||||
|
||||
|
||||
def _is_non_retryable_error(error: Exception) -> bool:
|
||||
"""Check if an error should not be retried.
|
||||
|
||||
Args:
|
||||
error: The exception to check
|
||||
|
||||
Returns:
|
||||
True if the error should not be retried, False otherwise
|
||||
"""
|
||||
# Check for OpenAI authentication errors
|
||||
error_type = type(error).__name__
|
||||
if error_type == "AuthenticationError":
|
||||
return True
|
||||
|
||||
# Check for other common non-retryable HTTP status codes
|
||||
# 401 Unauthorized, 403 Forbidden
|
||||
if hasattr(error, "status_code"):
|
||||
if error.status_code in (401, 403):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def url_retrieve(url: str):
|
||||
"""
|
||||
Parameters
|
||||
|
||||
@@ -515,3 +515,34 @@ def test_openai_propagates_api_key(monkeypatch):
|
||||
query = "greetings"
|
||||
actual = table.search(query).limit(1).to_pydantic(Words)[0]
|
||||
assert len(actual.text) > 0
|
||||
|
||||
|
||||
@patch("time.sleep")
def test_openai_no_retry_on_401(mock_sleep):
    """
    Authentication (401) failures must surface immediately: the wrapped
    callable runs exactly once and no backoff sleep is performed.
    """
    from lancedb.embeddings.utils import retry_with_exponential_backoff

    class MockAuthenticationError(Exception):
        """Mock OpenAI AuthenticationError"""

    # The retry helper identifies authentication errors by class name, so
    # the stand-in exception has to report the same name as the real one.
    MockAuthenticationError.__name__ = "AuthenticationError"

    failing_call = MagicMock(
        side_effect=MockAuthenticationError("Invalid API key"),
    )
    guarded_call = retry_with_exponential_backoff(failing_call, max_retries=3)

    # The error must propagate out of the wrapper unchanged.
    with pytest.raises(MockAuthenticationError):
        guarded_call()

    # One invocation and zero sleeps means no retry was attempted.
    failing_call.assert_called_once()
    mock_sleep.assert_not_called()
|
||||
|
||||
@@ -292,18 +292,14 @@ class TestModel(lancedb.pydantic.LanceModel):
|
||||
lambda: pa.table({"a": [1], "b": [2]}),
|
||||
lambda: pa.table({"a": [1], "b": [2]}).to_reader(),
|
||||
lambda: iter(pa.table({"a": [1], "b": [2]}).to_batches()),
|
||||
lambda: (
|
||||
lance.write_dataset(
|
||||
pa.table({"a": [1], "b": [2]}),
|
||||
"memory://test",
|
||||
)
|
||||
),
|
||||
lambda: (
|
||||
lance.write_dataset(
|
||||
pa.table({"a": [1], "b": [2]}),
|
||||
"memory://test",
|
||||
).scanner()
|
||||
lambda: lance.write_dataset(
|
||||
pa.table({"a": [1], "b": [2]}),
|
||||
"memory://test",
|
||||
),
|
||||
lambda: lance.write_dataset(
|
||||
pa.table({"a": [1], "b": [2]}),
|
||||
"memory://test",
|
||||
).scanner(),
|
||||
lambda: pd.DataFrame({"a": [1], "b": [2]}),
|
||||
lambda: pl.DataFrame({"a": [1], "b": [2]}),
|
||||
lambda: pl.LazyFrame({"a": [1], "b": [2]}),
|
||||
|
||||
Reference in New Issue
Block a user