From 280276409233c18e0903176466399a97bc4c7125 Mon Sep 17 00:00:00 2001 From: Varun Chawla <34209028+veeceey@users.noreply.github.com> Date: Thu, 19 Feb 2026 09:20:54 -0800 Subject: [PATCH] fix(embeddings): stop retrying OpenAI 401 authentication errors (#2995) ## Summary Fixes #1679 This PR prevents the OpenAI embedding function from retrying when receiving a 401 Unauthorized error. Authentication errors are permanent failures that won't be fixed by retrying, yet the current implementation retries all exceptions up to 7 times by default. ## Changes - Modified `retry_with_exponential_backoff` in `utils.py` to check for non-retryable errors before retrying - Added `_is_non_retryable_error` helper function that detects: - Exceptions with name `AuthenticationError` (OpenAI's 401 error) - Exceptions with `status_code` attribute of 401 or 403 - Enhanced OpenAI embeddings to explicitly catch and re-raise `AuthenticationError` with better logging - Added unit test `test_openai_no_retry_on_401` to verify authentication errors don't trigger retries ## Test Plan - Added test that verifies: 1. A function raising `AuthenticationError` is only called once 2. No retry delays occur (sleep is never called) - Existing tests continue to pass - Formatting applied via `make format` ## Example Behavior **Before**: With an invalid API key, users would see 7 retry attempts over ~2 minutes: ``` WARNING:root:Error occurred: Error code: 401 - {'error': {'message': 'Incorrect API key provided...'}} Retrying in 3.97 seconds (retry 1 of 7) WARNING:root:Error occurred: Error code: 401... Retrying in 7.94 seconds (retry 2 of 7) ... ``` **After**: With an invalid API key, the error is raised immediately: ``` ERROR:root:Authentication failed: Invalid API key provided AuthenticationError: Error code: 401 - {'error': {'message': 'Incorrect API key provided...'}} ``` This provides better UX and prevents unnecessary API calls that would fail anyway. --------- Co-authored-by: Will Jones --- python/python/lancedb/embeddings/openai.py | 3 +++ python/python/lancedb/embeddings/utils.py | 28 +++++++++++++++++++ python/python/tests/test_embeddings.py | 31 ++++++++++++++++++++++ python/python/tests/test_util.py | 18 +++++-------- 4 files changed, 69 insertions(+), 11 deletions(-) diff --git a/python/python/lancedb/embeddings/openai.py b/python/python/lancedb/embeddings/openai.py index 9b18e45cd..8f3073d49 100644 --- a/python/python/lancedb/embeddings/openai.py +++ b/python/python/lancedb/embeddings/openai.py @@ -110,6 +110,9 @@ class OpenAIEmbeddings(TextEmbeddingFunction): valid_embeddings = { idx: v.embedding for v, idx in zip(rs.data, valid_indices) } + except openai.AuthenticationError: + logging.error("Authentication failed: Invalid API key provided") + raise except openai.BadRequestError: logging.exception("Bad request: %s", texts) return [None] * len(texts) diff --git a/python/python/lancedb/embeddings/utils.py b/python/python/lancedb/embeddings/utils.py index 8b892a065..1fefc78bf 100644 --- a/python/python/lancedb/embeddings/utils.py +++ b/python/python/lancedb/embeddings/utils.py @@ -269,6 +269,11 @@ def retry_with_exponential_backoff( # and say that it is assumed that if this portion errors out, it's due # to rate limit but the user should check the error message to be sure. except Exception as e: # noqa: PERF203 + # Don't retry on authentication errors (e.g., OpenAI 401) + # These are permanent failures that won't be fixed by retrying + if _is_non_retryable_error(e): + raise + num_retries += 1 if num_retries > max_retries: @@ -289,6 +294,29 @@ def retry_with_exponential_backoff( return wrapper +def _is_non_retryable_error(error: Exception) -> bool: + """Check if an error should not be retried. + + Args: + error: The exception to check + + Returns: + True if the error should not be retried, False otherwise + """ + # Check for OpenAI authentication errors + error_type = type(error).__name__ + if error_type == "AuthenticationError": + return True + + # Check for other common non-retryable HTTP status codes + # 401 Unauthorized, 403 Forbidden + if hasattr(error, "status_code"): + if error.status_code in (401, 403): + return True + + return False + + def url_retrieve(url: str): """ Parameters diff --git a/python/python/tests/test_embeddings.py b/python/python/tests/test_embeddings.py index 2f01cf3b8..c78b822f1 100644 --- a/python/python/tests/test_embeddings.py +++ b/python/python/tests/test_embeddings.py @@ -515,3 +515,34 @@ def test_openai_propagates_api_key(monkeypatch): query = "greetings" actual = table.search(query).limit(1).to_pydantic(Words)[0] assert len(actual.text) > 0 + + +@patch("time.sleep") +def test_openai_no_retry_on_401(mock_sleep): + """ + Test that OpenAI embedding function does not retry on 401 authentication + errors. + """ + from lancedb.embeddings.utils import retry_with_exponential_backoff + + # Create a mock that raises an AuthenticationError + class MockAuthenticationError(Exception): + """Mock OpenAI AuthenticationError""" + + pass + + MockAuthenticationError.__name__ = "AuthenticationError" + + mock_func = MagicMock(side_effect=MockAuthenticationError("Invalid API key")) + + # Wrap the function with retry logic + wrapped_func = retry_with_exponential_backoff(mock_func, max_retries=3) + + # Should raise without retrying + with pytest.raises(MockAuthenticationError): + wrapped_func() + + # Verify that the function was only called once (no retries) + assert mock_func.call_count == 1 + # Verify that sleep was never called (no retries) + assert mock_sleep.call_count == 0 diff --git a/python/python/tests/test_util.py b/python/python/tests/test_util.py index a847deaca..e7ba8bf86 100644 --- a/python/python/tests/test_util.py +++ b/python/python/tests/test_util.py @@ -292,18 +292,14 @@ class TestModel(lancedb.pydantic.LanceModel): lambda: pa.table({"a": [1], "b": [2]}), lambda: pa.table({"a": [1], "b": [2]}).to_reader(), lambda: iter(pa.table({"a": [1], "b": [2]}).to_batches()), - lambda: ( - lance.write_dataset( - pa.table({"a": [1], "b": [2]}), - "memory://test", - ) - ), - lambda: ( - lance.write_dataset( - pa.table({"a": [1], "b": [2]}), - "memory://test", - ).scanner() + lambda: lance.write_dataset( + pa.table({"a": [1], "b": [2]}), + "memory://test", ), + lambda: lance.write_dataset( + pa.table({"a": [1], "b": [2]}), + "memory://test", + ).scanner(), lambda: pd.DataFrame({"a": [1], "b": [2]}), lambda: pl.DataFrame({"a": [1], "b": [2]}), lambda: pl.LazyFrame({"a": [1], "b": [2]}),