mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
chore: add error handling for openai embedding generation (#1680)
This commit is contained in:
@@ -1,17 +1,9 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
import logging
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
@@ -89,17 +81,26 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
|
||||
texts: list[str] or np.ndarray (of str)
|
||||
The texts to embed
|
||||
"""
|
||||
openai = attempt_import_or_raise("openai")
|
||||
|
||||
# TODO retry, rate limit, token limit
|
||||
if self.name == "text-embedding-ada-002":
|
||||
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
|
||||
else:
|
||||
kwargs = {
|
||||
"input": texts,
|
||||
"model": self.name,
|
||||
}
|
||||
if self.dim:
|
||||
kwargs["dimensions"] = self.dim
|
||||
rs = self._openai_client.embeddings.create(**kwargs)
|
||||
try:
|
||||
if self.name == "text-embedding-ada-002":
|
||||
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
|
||||
else:
|
||||
kwargs = {
|
||||
"input": texts,
|
||||
"model": self.name,
|
||||
}
|
||||
if self.dim:
|
||||
kwargs["dimensions"] = self.dim
|
||||
rs = self._openai_client.embeddings.create(**kwargs)
|
||||
except openai.BadRequestError:
|
||||
logging.exception("Bad request: %s", texts)
|
||||
return [None] * len(texts)
|
||||
except Exception:
|
||||
logging.exception("OpenAI embeddings error")
|
||||
raise
|
||||
return [v.embedding for v in rs.data]
|
||||
|
||||
@cached_property
|
||||
|
||||
Reference in New Issue
Block a user