diff --git a/python/lancedb/embeddings/openai.py b/python/lancedb/embeddings/openai.py index 406ed40f..678fe417 100644 --- a/python/lancedb/embeddings/openai.py +++ b/python/lancedb/embeddings/openai.py @@ -10,6 +10,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import List, Union import numpy as np @@ -44,6 +45,10 @@ class OpenAIEmbeddings(TextEmbeddingFunction): The texts to embed """ # TODO retry, rate limit, token limit + rs = self._openai_client.embeddings.create(input=texts, model=self.name) + return [v.embedding for v in rs.data] + + @cached_property + def _openai_client(self): openai = self.safe_import("openai") - rs = openai.Embedding.create(input=texts, model=self.name)["data"] - return [v["embedding"] for v in rs] + return openai.OpenAI() diff --git a/python/lancedb/embeddings/utils.py b/python/lancedb/embeddings/utils.py index 59ed0460..8f893142 100644 --- a/python/lancedb/embeddings/utils.py +++ b/python/lancedb/embeddings/utils.py @@ -249,7 +249,7 @@ def retry_with_exponential_backoff( if num_retries > max_retries: raise Exception( - f"Maximum number of retries ({max_retries}) exceeded." + f"Maximum number of retries ({max_retries}) exceeded.", e ) delay *= exponential_base * (1 + jitter * random.random()) diff --git a/python/pyproject.toml b/python/pyproject.toml index bf16676f..73a0a625 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -53,7 +53,7 @@ tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests", " dev = ["ruff", "pre-commit", "black"] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] clip = ["torch", "pillow", "open-clip"] -embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "InstructorEmbedding"] +embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "InstructorEmbedding"] [project.scripts] lancedb = "lancedb.cli.cli:cli" diff --git a/python/tests/test_embeddings_slow.py b/python/tests/test_embeddings_slow.py index 2e116827..826934f9 100644 --- a/python/tests/test_embeddings_slow.py +++ b/python/tests/test_embeddings_slow.py @@ -29,7 +29,7 @@ from lancedb.pydantic import LanceModel, Vector @pytest.mark.slow @pytest.mark.parametrize("alias", ["sentence-transformers", "openai"]) -def test_sentence_transformer(alias, tmp_path): +def test_basic_text_embeddings(alias, tmp_path): db = lancedb.connect(tmp_path) registry = get_registry() func = registry.get(alias).create(max_retries=0)