feat(python): add watsonx embeddings to registry (#1486)

Related issue: https://github.com/lancedb/lancedb/issues/1412

---------

Co-authored-by: Robby <h0rv@users.noreply.github.com>
This commit is contained in:
Robby
2024-08-06 01:28:33 -04:00
committed by GitHub
parent 61c05b51a0
commit 8d2ff7b210
5 changed files with 215 additions and 1 deletions

View File

@@ -417,3 +417,28 @@ def test_openai_embedding(tmp_path):
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("WATSONX_API_KEY") is None
or os.environ.get("WATSONX_PROJECT_ID") is None,
reason="WATSONX_API_KEY and WATSONX_PROJECT_ID not set",
)
def test_watsonx_embedding(tmp_path):
from lancedb.embeddings import WatsonxEmbeddings
for name in WatsonxEmbeddings.model_names():
model = get_registry().get("watsonx").create(max_retries=0, name=name)
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
db = lancedb.connect("~/.lancedb")
tbl = db.create_table("watsonx_test", schema=TextModel, mode="overwrite")
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"