chore(python): update embedding API to use openai 1.6.1 (#751)

The OpenAI Python API has changed significantly; in particular,
`openai.Embedding.create` no longer exists.
https://github.com/openai/openai-python/discussions/742

Update the OpenAI embedding function to use the new client API and set a
minimum version for the openai SDK.
This commit is contained in:
Chang She
2023-12-28 15:05:57 -08:00
committed by Weston Pace
parent 7bac1131fb
commit c97ae6b787
4 changed files with 10 additions and 5 deletions

View File

@@ -10,6 +10,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import List, Union
import numpy as np
@@ -44,6 +45,10 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
The texts to embed
"""
# TODO retry, rate limit, token limit
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
return [v.embedding for v in rs.data]
@cached_property
def _openai_client(self):
openai = self.safe_import("openai")
rs = openai.Embedding.create(input=texts, model=self.name)["data"]
return [v["embedding"] for v in rs]
return openai.OpenAI()

View File

@@ -249,7 +249,7 @@ def retry_with_exponential_backoff(
if num_retries > max_retries:
raise Exception(
f"Maximum number of retries ({max_retries}) exceeded."
f"Maximum number of retries ({max_retries}) exceeded.", e
)
delay *= exponential_base * (1 + jitter * random.random())

View File

@@ -50,7 +50,7 @@ tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests", "
dev = ["ruff", "pre-commit", "black"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"]
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "InstructorEmbedding"]
embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "InstructorEmbedding"]
[build-system]
requires = ["setuptools", "wheel"]

View File

@@ -29,7 +29,7 @@ from lancedb.pydantic import LanceModel, Vector
@pytest.mark.slow
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
def test_sentence_transformer(alias, tmp_path):
def test_basic_text_embeddings(alias, tmp_path):
db = lancedb.connect(tmp_path)
registry = get_registry()
func = registry.get(alias).create(max_retries=0)