fix(embeddings): stop retrying OpenAI 401 authentication errors (#2995)

## Summary
Fixes #1679

This PR prevents the OpenAI embedding function from retrying when
receiving a 401 Unauthorized error. Authentication errors are permanent
failures that won't be fixed by retrying, yet the current implementation
retries all exceptions up to 7 times by default.

## Changes
- Modified `retry_with_exponential_backoff` in `utils.py` to check for
non-retryable errors before retrying
- Added `_is_non_retryable_error` helper function that detects:
  - Exceptions with name `AuthenticationError` (OpenAI's 401 error)
  - Exceptions with `status_code` attribute of 401 or 403
- Enhanced OpenAI embeddings to explicitly catch and re-raise
`AuthenticationError` with better logging
- Added unit test `test_openai_no_retry_on_401` to verify authentication
errors don't trigger retries

## Test Plan
- Added test that verifies:
  1. A function raising `AuthenticationError` is only called once
  2. No retry delays occur (sleep is never called)
- Existing tests continue to pass
- Formatting applied via `make format`

## Example Behavior

**Before**: With an invalid API key, users would see 7 retry attempts
over ~2 minutes:
```
WARNING:root:Error occurred: Error code: 401 - {'error': {'message': 'Incorrect API key provided...'}}
 Retrying in 3.97 seconds (retry 1 of 7)
WARNING:root:Error occurred: Error code: 401...
 Retrying in 7.94 seconds (retry 2 of 7)
...
```

**After**: With an invalid API key, the error is raised immediately:
```
ERROR:root:Authentication failed: Invalid API key provided
AuthenticationError: Error code: 401 - {'error': {'message': 'Incorrect API key provided...'}}
```

This provides better UX and prevents unnecessary API calls that would
fail anyway.

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
Varun Chawla
2026-02-19 09:20:54 -08:00
committed by GitHub
parent 37bbb0dba1
commit 2802764092
4 changed files with 69 additions and 11 deletions

View File

@@ -110,6 +110,9 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
valid_embeddings = {
idx: v.embedding for v, idx in zip(rs.data, valid_indices)
}
except openai.AuthenticationError:
logging.error("Authentication failed: Invalid API key provided")
raise
except openai.BadRequestError:
logging.exception("Bad request: %s", texts)
return [None] * len(texts)

View File

@@ -269,6 +269,11 @@ def retry_with_exponential_backoff(
# and say that it is assumed that if this portion errors out, it's due
# to rate limit but the user should check the error message to be sure.
except Exception as e: # noqa: PERF203
# Don't retry on authentication errors (e.g., OpenAI 401)
# These are permanent failures that won't be fixed by retrying
if _is_non_retryable_error(e):
raise
num_retries += 1
if num_retries > max_retries:
@@ -289,6 +294,29 @@ def retry_with_exponential_backoff(
return wrapper
def _is_non_retryable_error(error: Exception) -> bool:
"""Check if an error should not be retried.
Args:
error: The exception to check
Returns:
True if the error should not be retried, False otherwise
"""
# Check for OpenAI authentication errors
error_type = type(error).__name__
if error_type == "AuthenticationError":
return True
# Check for other common non-retryable HTTP status codes
# 401 Unauthorized, 403 Forbidden
if hasattr(error, "status_code"):
if error.status_code in (401, 403):
return True
return False
def url_retrieve(url: str):
"""
Parameters

View File

@@ -515,3 +515,34 @@ def test_openai_propagates_api_key(monkeypatch):
query = "greetings"
actual = table.search(query).limit(1).to_pydantic(Words)[0]
assert len(actual.text) > 0
@patch("time.sleep")
def test_openai_no_retry_on_401(mock_sleep):
"""
Test that OpenAI embedding function does not retry on 401 authentication
errors.
"""
from lancedb.embeddings.utils import retry_with_exponential_backoff
# Create a mock that raises an AuthenticationError
class MockAuthenticationError(Exception):
"""Mock OpenAI AuthenticationError"""
pass
MockAuthenticationError.__name__ = "AuthenticationError"
mock_func = MagicMock(side_effect=MockAuthenticationError("Invalid API key"))
# Wrap the function with retry logic
wrapped_func = retry_with_exponential_backoff(mock_func, max_retries=3)
# Should raise without retrying
with pytest.raises(MockAuthenticationError):
wrapped_func()
# Verify that the function was only called once (no retries)
assert mock_func.call_count == 1
# Verify that sleep was never called (no retries)
assert mock_sleep.call_count == 0

View File

@@ -292,18 +292,14 @@ class TestModel(lancedb.pydantic.LanceModel):
lambda: pa.table({"a": [1], "b": [2]}),
lambda: pa.table({"a": [1], "b": [2]}).to_reader(),
lambda: iter(pa.table({"a": [1], "b": [2]}).to_batches()),
lambda: (
lance.write_dataset(
pa.table({"a": [1], "b": [2]}),
"memory://test",
)
),
lambda: (
lance.write_dataset(
pa.table({"a": [1], "b": [2]}),
"memory://test",
).scanner()
lambda: lance.write_dataset(
pa.table({"a": [1], "b": [2]}),
"memory://test",
),
lambda: lance.write_dataset(
pa.table({"a": [1], "b": [2]}),
"memory://test",
).scanner(),
lambda: pd.DataFrame({"a": [1], "b": [2]}),
lambda: pl.DataFrame({"a": [1], "b": [2]}),
lambda: pl.LazyFrame({"a": [1], "b": [2]}),