From 738511c5f2206d0eaede77112f03cf84e417c168 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Mon, 5 Feb 2024 07:49:42 +0530 Subject: [PATCH] feat(python): add support new openai embedding functions (#912) @PrashantDixit0 --------- Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com> --- python/lancedb/embeddings/openai.py | 24 ++++++++++-- python/tests/test_embeddings_slow.py | 58 +++++++++++++++++++++++++--- 2 files changed, 73 insertions(+), 9 deletions(-) diff --git a/python/lancedb/embeddings/openai.py b/python/lancedb/embeddings/openai.py index c4d2f384..12200edd 100644 --- a/python/lancedb/embeddings/openai.py +++ b/python/lancedb/embeddings/openai.py @@ -12,7 +12,7 @@ # limitations under the License. import os from functools import cached_property -from typing import List, Union +from typing import List, Optional, Union import numpy as np @@ -30,10 +30,21 @@ class OpenAIEmbeddings(TextEmbeddingFunction): """ name: str = "text-embedding-ada-002" + dim: Optional[int] = None def ndims(self): - # TODO don't hardcode this - return 1536 + return self._ndims + + @cached_property + def _ndims(self): + if self.name == "text-embedding-ada-002": + return 1536 + elif self.name == "text-embedding-3-large": + return self.dim or 3072 + elif self.name == "text-embedding-3-small": + return self.dim or 1536 + else: + raise ValueError(f"Unknown model name {self.name}") def generate_embeddings( self, texts: Union[List[str], np.ndarray] @@ -47,7 +58,12 @@ class OpenAIEmbeddings(TextEmbeddingFunction): The texts to embed """ # TODO retry, rate limit, token limit - rs = self._openai_client.embeddings.create(input=texts, model=self.name) + if self.name == "text-embedding-ada-002": + rs = self._openai_client.embeddings.create(input=texts, model=self.name) + else: + rs = self._openai_client.embeddings.create( + input=texts, model=self.name, dimensions=self.ndims() + ) return [v.embedding for v in rs.data] @cached_property diff --git a/python/tests/test_embeddings_slow.py b/python/tests/test_embeddings_slow.py index e3724b03..cfdb1247 100644 --- a/python/tests/test_embeddings_slow.py +++ b/python/tests/test_embeddings_slow.py @@ -23,11 +23,6 @@ import lancedb from lancedb.embeddings import get_registry from lancedb.pydantic import LanceModel, Vector -try: - if importlib.util.find_spec("mlx.core") is not None: - _mlx = True -except ImportError: - _mlx = None # These are integration tests for embedding functions. # They are slow because they require downloading models # or connection to external api @@ -210,6 +205,13 @@ def test_gemini_embedding(tmp_path): assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world" +try: + if importlib.util.find_spec("mlx.core") is not None: + _mlx = True +except ImportError: + _mlx = None + + @pytest.mark.skipif( _mlx is None, reason="mlx tests only required for apple users.", @@ -266,3 +268,49 @@ def test_bedrock_embedding(tmp_path): tbl.add(df) assert len(tbl.to_pandas()["vector"][0]) == model.ndims() + + +@pytest.mark.slow +@pytest.mark.skipif( + os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set" +) +def test_openai_embedding(tmp_path): + def _get_table(model): + class TextModel(LanceModel): + text: str = model.SourceField() + vector: Vector(model.ndims()) = model.VectorField() + + db = lancedb.connect(tmp_path) + tbl = db.create_table("test", schema=TextModel, mode="overwrite") + + return tbl + + model = get_registry().get("openai").create(max_retries=0) + tbl = _get_table(model) + df = pd.DataFrame({"text": ["hello world", "goodbye world"]}) + + tbl.add(df) + assert len(tbl.to_pandas()["vector"][0]) == model.ndims() + assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world" + + model = ( + get_registry() + .get("openai") + .create(max_retries=0, name="text-embedding-3-large") + ) + tbl = _get_table(model) + + tbl.add(df) + assert len(tbl.to_pandas()["vector"][0]) == model.ndims() + assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world" + + model = ( + get_registry() + .get("openai") + .create(max_retries=0, name="text-embedding-3-large", dim=1024) + ) + tbl = _get_table(model) + + tbl.add(df) + assert len(tbl.to_pandas()["vector"][0]) == model.ndims() + assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"