diff --git a/python/python/lancedb/embeddings/openai.py b/python/python/lancedb/embeddings/openai.py index 7da8f2c3..2fca549d 100644 --- a/python/python/lancedb/embeddings/openai.py +++ b/python/python/lancedb/embeddings/openai.py @@ -1,17 +1,9 @@ -# Copyright (c) 2023. LanceDB Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The LanceDB Authors + from functools import cached_property from typing import TYPE_CHECKING, List, Optional, Union +import logging from ..util import attempt_import_or_raise from .base import TextEmbeddingFunction @@ -89,17 +81,26 @@ class OpenAIEmbeddings(TextEmbeddingFunction): texts: list[str] or np.ndarray (of str) The texts to embed """ + openai = attempt_import_or_raise("openai") + # TODO retry, rate limit, token limit - if self.name == "text-embedding-ada-002": - rs = self._openai_client.embeddings.create(input=texts, model=self.name) - else: - kwargs = { - "input": texts, - "model": self.name, - } - if self.dim: - kwargs["dimensions"] = self.dim - rs = self._openai_client.embeddings.create(**kwargs) + try: + if self.name == "text-embedding-ada-002": + rs = self._openai_client.embeddings.create(input=texts, model=self.name) + else: + kwargs = { + "input": texts, + "model": self.name, + } + if self.dim: + kwargs["dimensions"] = self.dim + rs = self._openai_client.embeddings.create(**kwargs) + except openai.BadRequestError: + logging.exception("Bad request: %s", texts) + return [None] * len(texts) + except Exception: + logging.exception("OpenAI embeddings error") + raise return [v.embedding for v in rs.data] @cached_property