feat(python): add support new openai embedding functions (#912)

@PrashantDixit0

---------

Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2024-02-05 07:49:42 +05:30
committed by GitHub
parent 0b0f42537e
commit 738511c5f2
2 changed files with 73 additions and 9 deletions

View File

@@ -12,7 +12,7 @@
# limitations under the License.
import os
from functools import cached_property
from typing import List, Union
from typing import List, Optional, Union
import numpy as np
@@ -30,10 +30,21 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
"""
name: str = "text-embedding-ada-002"
dim: Optional[int] = None
def ndims(self):
# TODO don't hardcode this
return 1536
return self._ndims
@cached_property
def _ndims(self):
if self.name == "text-embedding-ada-002":
return 1536
elif self.name == "text-embedding-3-large":
return self.dim or 3072
elif self.name == "text-embedding-3-small":
return self.dim or 1536
else:
raise ValueError(f"Unknown model name {self.name}")
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
@@ -47,7 +58,12 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
The texts to embed
"""
# TODO retry, rate limit, token limit
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
if self.name == "text-embedding-ada-002":
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
else:
rs = self._openai_client.embeddings.create(
input=texts, model=self.name, dimensions=self.ndims()
)
return [v.embedding for v in rs.data]
@cached_property

View File

@@ -23,11 +23,6 @@ import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
try:
if importlib.util.find_spec("mlx.core") is not None:
_mlx = True
except ImportError:
_mlx = None
# These are integration tests for embedding functions.
# They are slow because they require downloading models
# or connection to external api
@@ -210,6 +205,13 @@ def test_gemini_embedding(tmp_path):
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
try:
if importlib.util.find_spec("mlx.core") is not None:
_mlx = True
except ImportError:
_mlx = None
@pytest.mark.skipif(
_mlx is None,
reason="mlx tests only required for apple users.",
@@ -266,3 +268,49 @@ def test_bedrock_embedding(tmp_path):
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)
def test_openai_embedding(tmp_path):
def _get_table(model):
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
return tbl
model = get_registry().get("openai").create(max_retries=0)
tbl = _get_table(model)
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
model = (
get_registry()
.get("openai")
.create(max_retries=0, name="text-embedding-3-large")
)
tbl = _get_table(model)
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
model = (
get_registry()
.get("openai")
.create(max_retries=0, name="text-embedding-3-large", dim=1024)
)
tbl = _get_table(model)
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"