Files
lancedb/python/python/lancedb/embeddings/openai.py
Varun Chawla 2802764092 fix(embeddings): stop retrying OpenAI 401 authentication errors (#2995)
## Summary
Fixes #1679

This PR prevents the OpenAI embedding function from retrying when
receiving a 401 Unauthorized error. Authentication errors are permanent
failures that won't be fixed by retrying, yet the current implementation
retries all exceptions up to 7 times by default.

## Changes
- Modified `retry_with_exponential_backoff` in `utils.py` to check for
non-retryable errors before retrying
- Added `_is_non_retryable_error` helper function that detects:
  - Exceptions with name `AuthenticationError` (OpenAI's 401 error)
  - Exceptions with `status_code` attribute of 401 or 403
- Enhanced OpenAI embeddings to explicitly catch and re-raise
`AuthenticationError` with better logging
- Added unit test `test_openai_no_retry_on_401` to verify authentication
errors don't trigger retries

## Test Plan
- Added test that verifies:
  1. A function raising `AuthenticationError` is only called once
  2. No retry delays occur (sleep is never called)
- Existing tests continue to pass
- Formatting applied via `make format`

## Example Behavior

**Before**: With an invalid API key, users would see 7 retry attempts
over ~2 minutes:
```
WARNING:root:Error occurred: Error code: 401 - {'error': {'message': 'Incorrect API key provided...'}}
 Retrying in 3.97 seconds (retry 1 of 7)
WARNING:root:Error occurred: Error code: 401...
 Retrying in 7.94 seconds (retry 2 of 7)
...
```

**After**: With an invalid API key, the error is raised immediately:
```
ERROR:root:Authentication failed: Invalid API key provided
AuthenticationError: Error code: 401 - {'error': {'message': 'Incorrect API key provided...'}}
```

This provides better UX and prevents unnecessary API calls that would
fail anyway.

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
2026-02-19 09:20:54 -08:00

141 lines
4.0 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from functools import cached_property
from typing import TYPE_CHECKING, List, Optional, Union
import logging
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
if TYPE_CHECKING:
import numpy as np
@register("openai")
class OpenAIEmbeddings(TextEmbeddingFunction):
    """
    An embedding function that uses the OpenAI API

    https://platform.openai.com/docs/guides/embeddings

    This can also be used for open source models that
    are compatible with the OpenAI API.

    Notes
    -----
    If you're running an Ollama server locally,
    you can just override the `base_url` parameter
    and provide the Ollama embedding model you want
    to use (https://ollama.com/library):

    ```python
    from lancedb.embeddings import get_registry

    openai = get_registry().get("openai")
    embedding_function = openai.create(
        name="<ollama-embedding-model-name>",
        base_url="http://localhost:11434",
    )
    ```
    """

    # Name of the embedding model sent as the `model` request parameter.
    name: str = "text-embedding-ada-002"
    # Optional output dimensionality; only forwarded for non-ada models.
    dim: Optional[int] = None
    # Optional client settings; each is passed to the client only when truthy.
    base_url: Optional[str] = None
    default_headers: Optional[dict] = None
    organization: Optional[str] = None
    api_key: Optional[str] = None

    # Set true to use Azure OpenAI API
    use_azure: bool = False

    def ndims(self):
        """Return the dimensionality of the vectors produced by this model.

        Raises
        ------
        ValueError
            If `name` is not one of the model names handled by `_ndims`.
        """
        return self._ndims

    @staticmethod
    def sensitive_keys():
        """Return the names of the fields that hold secrets."""
        return ["api_key"]

    @staticmethod
    def model_names():
        """Return the OpenAI model names this class knows the dimensions of."""
        return [
            "text-embedding-ada-002",
            "text-embedding-3-large",
            "text-embedding-3-small",
        ]

    @cached_property
    def _ndims(self):
        """Resolve (once per instance) the embedding dimension for `name`.

        The `text-embedding-3-*` models honour a user supplied `dim` and
        otherwise fall back to their native size; ada-002 is always 1536.
        """
        # NOTE(review): any other model name (e.g. a custom Ollama model)
        # raises here even when `dim` is set - confirm this is intended.
        if self.name == "text-embedding-ada-002":
            return 1536
        elif self.name == "text-embedding-3-large":
            return self.dim or 3072
        elif self.name == "text-embedding-3-small":
            return self.dim or 1536
        else:
            raise ValueError(f"Unknown model name {self.name}")

    def generate_embeddings(
        self, texts: Union[List[str], "np.ndarray"]
    ) -> List["np.array"]:
        """
        Get the embeddings for the given texts

        Falsy entries (empty strings, ``None``) are not sent to the API;
        their slot in the result is ``None``.

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed

        Returns
        -------
        list
            One entry per input text, in input order: the embedding returned
            by the API, or ``None`` for skipped texts. If the API rejects the
            request as a bad request, every entry is ``None``.

        Raises
        ------
        openai.AuthenticationError
            Re-raised after logging when the API rejects the credentials.
        Exception
            Any other client error is logged and re-raised.
        """
        openai = attempt_import_or_raise("openai")

        # Remember the original position of every non-empty text so the
        # result can be re-aligned with the input afterwards.
        valid_texts = []
        valid_indices = []
        for idx, text in enumerate(texts):
            if text:
                valid_texts.append(text)
                valid_indices.append(idx)

        # Nothing to embed: skip the request entirely instead of sending an
        # empty `input`; the outcome is the all-None result either way.
        if not valid_texts:
            return [None] * len(texts)

        kwargs = {
            "input": valid_texts,
            "model": self.name,
        }
        # Only forward `dimensions` when the user actually chose one, so the
        # request never carries an explicit null for this optional parameter.
        if self.name != "text-embedding-ada-002" and self.dim is not None:
            kwargs["dimensions"] = self.dim

        # TODO retry, rate limit, token limit
        try:
            rs = self._openai_client.embeddings.create(**kwargs)
            valid_embeddings = {
                idx: v.embedding for v, idx in zip(rs.data, valid_indices)
            }
        except openai.AuthenticationError:
            # Credentials problems are permanent; surface them immediately.
            logging.error("Authentication failed: Invalid API key provided")
            raise
        except openai.BadRequestError:
            # Best effort: a rejected batch yields no embeddings at all.
            logging.exception("Bad request: %s", texts)
            return [None] * len(texts)
        except Exception:
            logging.exception("OpenAI embeddings error")
            raise

        return [valid_embeddings.get(idx, None) for idx in range(len(texts))]

    @cached_property
    def _openai_client(self):
        """Build (once per instance) the OpenAI or Azure OpenAI client.

        Only options that are set (truthy) are passed to the constructor;
        unset ones are omitted - presumably so the SDK applies its own
        defaults for them (TODO confirm against the SDK).
        """
        openai = attempt_import_or_raise("openai")

        kwargs = {}
        if self.base_url:
            kwargs["base_url"] = self.base_url
        if self.default_headers:
            kwargs["default_headers"] = self.default_headers
        if self.organization:
            kwargs["organization"] = self.organization
        if self.api_key:
            kwargs["api_key"] = self.api_key

        if self.use_azure:
            return openai.AzureOpenAI(**kwargs)
        else:
            return openai.OpenAI(**kwargs)