feat(python): add support new openai embedding functions (#912)

@PrashantDixit0

---------

Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2024-02-05 07:49:42 +05:30
committed by GitHub
parent 0b0f42537e
commit 738511c5f2
2 changed files with 73 additions and 9 deletions

View File

@@ -23,11 +23,6 @@ import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
try:
if importlib.util.find_spec("mlx.core") is not None:
_mlx = True
except ImportError:
_mlx = None
# These are integration tests for embedding functions.
# They are slow because they require downloading models
# or connection to external api
@@ -210,6 +205,13 @@ def test_gemini_embedding(tmp_path):
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
try:
if importlib.util.find_spec("mlx.core") is not None:
_mlx = True
except ImportError:
_mlx = None
@pytest.mark.skipif(
_mlx is None,
reason="mlx tests only required for apple users.",
@@ -266,3 +268,49 @@ def test_bedrock_embedding(tmp_path):
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)
def test_openai_embedding(tmp_path):
def _get_table(model):
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
return tbl
model = get_registry().get("openai").create(max_retries=0)
tbl = _get_table(model)
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
model = (
get_registry()
.get("openai")
.create(max_retries=0, name="text-embedding-3-large")
)
tbl = _get_table(model)
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
model = (
get_registry()
.get("openai")
.create(max_retries=0, name="text-embedding-3-large", dim=1024)
)
tbl = _get_table(model)
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"