diff --git a/docs/src/embeddings/default_embedding_functions.md b/docs/src/embeddings/default_embedding_functions.md index 0ef58925..95122437 100644 --- a/docs/src/embeddings/default_embedding_functions.md +++ b/docs/src/embeddings/default_embedding_functions.md @@ -518,6 +518,82 @@ tbl.add(df) rs = tbl.search("hello").limit(1).to_pandas() ``` +# IBM watsonx.ai Embeddings + +Generate text embeddings using IBM's watsonx.ai platform. + +## Supported Models + +You can find a list of supported models at [IBM watsonx.ai Documentation](https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models-embed.html?context=wx). The currently supported model names are: + +- `ibm/slate-125m-english-rtrvr` +- `ibm/slate-30m-english-rtrvr` +- `sentence-transformers/all-minilm-l12-v2` +- `intfloat/multilingual-e5-large` + +## Parameters + +The following parameters can be passed to the `create` method: + +| Parameter | Type | Default Value | Description | +|------------|----------|----------------------------------|-----------------------------------------------------------| +| name | str | "ibm/slate-125m-english-rtrvr" | The model ID of the watsonx.ai model to use | +| api_key | str | None | Optional IBM Cloud API key (or set `WATSONX_API_KEY`) | +| project_id | str | None | Optional watsonx project ID (or set `WATSONX_PROJECT_ID`) | +| url | str | None | Optional custom URL for the watsonx.ai instance | +| params | dict | None | Optional additional parameters for the embedding model | + +## Usage Example + +First, the watsonx.ai library is an optional dependency, so it must be installed separately: + +``` +pip install ibm-watsonx-ai +``` + +Optionally set environment variables (if not passing credentials to `create` directly): + +```sh +export WATSONX_API_KEY="YOUR_WATSONX_API_KEY" +export WATSONX_PROJECT_ID="YOUR_WATSONX_PROJECT_ID" +``` + +```python +import os +import lancedb +from lancedb.pydantic import LanceModel, Vector +from lancedb.embeddings import 
EmbeddingFunctionRegistry + +registry = EmbeddingFunctionRegistry.get_instance() +watsonx_cls = registry.get("watsonx") + +watsonx_embed = watsonx_cls.create( + name="ibm/slate-125m-english-rtrvr", + # Uncomment and set these if not using environment variables + # api_key="your_api_key_here", + # project_id="your_project_id_here", + # url="your_watsonx_url_here", + # params={...}, +) + +class TextModel(LanceModel): + text: str = watsonx_embed.SourceField() + vector: Vector(watsonx_embed.ndims()) = watsonx_embed.VectorField() + +data = [ + {"text": "hello world"}, + {"text": "goodbye world"}, +] + +db = lancedb.connect("~/.lancedb") +tbl = db.create_table("watsonx_test", schema=TextModel, mode="overwrite") + +tbl.add(data) + +rs = tbl.search("hello").limit(1).to_pandas() +print(rs) +``` + ## Multi-modal embedding functions Multi-modal embedding functions allow you to query your table using both images and text. @@ -721,4 +797,4 @@ Usage Example: table.add( pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes}) ) -``` \ No newline at end of file +``` diff --git a/python/pyproject.toml b/python/pyproject.toml index 485f47f9..4e7ab4c6 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -76,6 +76,7 @@ embeddings = [ "awscli>=1.29.57", "botocore>=1.31.57", "ollama", + "ibm-watsonx-ai>=1.1.2", ] azure = ["adlfs>=2024.2.0"] diff --git a/python/python/lancedb/embeddings/__init__.py b/python/python/lancedb/embeddings/__init__.py index 3bb6d56d..76da3ab4 100644 --- a/python/python/lancedb/embeddings/__init__.py +++ b/python/python/lancedb/embeddings/__init__.py @@ -26,3 +26,4 @@ from .transformers import TransformersEmbeddingFunction, ColbertEmbeddings from .imagebind import ImageBindEmbeddings from .utils import with_embeddings from .jinaai import JinaEmbeddings +from .watsonx import WatsonxEmbeddings diff --git a/python/python/lancedb/embeddings/watsonx.py b/python/python/lancedb/embeddings/watsonx.py new file mode 100644 index 00000000..35aa132f --- /dev/null +++ 
b/python/python/lancedb/embeddings/watsonx.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2023. LanceDB Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from functools import cached_property
+from typing import List, Optional, Dict, Union
+
+from ..util import attempt_import_or_raise
+from .base import TextEmbeddingFunction
+from .registry import register
+
+import numpy as np
+
+# Endpoint used when no `url` is configured on the embedding function.
+DEFAULT_WATSONX_URL = "https://us-south.ml.cloud.ibm.com"
+
+# Supported model IDs mapped to the dimensionality of their embeddings.
+MODELS_DIMS = {
+    "ibm/slate-125m-english-rtrvr": 768,
+    "ibm/slate-30m-english-rtrvr": 384,
+    "sentence-transformers/all-minilm-l12-v2": 384,
+    "intfloat/multilingual-e5-large": 1024,
+}
+
+
+@register("watsonx")
+class WatsonxEmbeddings(TextEmbeddingFunction):
+    """
+    An embedding function that uses the IBM watsonx.ai text embeddings API.
+
+    Parameters
+    ----------
+    name: str, default "ibm/slate-125m-english-rtrvr"
+        The ID of the watsonx.ai model to use. `ndims` raises a ValueError
+        for names that are not keys of `MODELS_DIMS`.
+    api_key: str, default None
+        The IBM Cloud API key. When not passed, the WATSONX_API_KEY
+        environment variable is used.
+    project_id: str, default None
+        The watsonx project ID. When not passed, the WATSONX_PROJECT_ID
+        environment variable is used.
+    url: str, default None
+        The URL of the watsonx.ai instance. When not passed,
+        `DEFAULT_WATSONX_URL` is used.
+    params: dict, default None
+        Additional parameters passed to the watsonx.ai `Embeddings` client.
+
+    API Docs:
+    ---------
+    https://cloud.ibm.com/apidocs/watsonx-ai#text-embeddings
+
+    Supported embedding models:
+    ---------------------------
+    https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models-embed.html?context=wx
+    """
+
+    name: str = "ibm/slate-125m-english-rtrvr"
+    api_key: Optional[str] = None
+    project_id: Optional[str] = None
+    url: Optional[str] = None
+    params: Optional[Dict] = None
+
+    @staticmethod
+    def model_names():
+        """Return the IDs of the supported embedding models."""
+        # Derived from MODELS_DIMS so the two cannot drift apart.
+        return list(MODELS_DIMS)
+
+    def ndims(self):
+        """Return the embedding dimensionality of the configured model."""
+        return self._ndims
+
+    @cached_property
+    def _ndims(self):
+        if self.name not in MODELS_DIMS:
+            raise ValueError(f"Unknown model name {self.name}")
+        return MODELS_DIMS[self.name]
+
+    def generate_embeddings(
+        self,
+        texts: Union[List[str], np.ndarray],
+        *args,
+        **kwargs,
+    ) -> List[List[float]]:
+        """
+        Get the embeddings for the given texts.
+
+        Parameters
+        ----------
+        texts: list[str] or np.ndarray
+            The texts to embed.
+        *args, **kwargs
+            Forwarded to the client's `embed_documents` call.
+        """
+        # Pass `texts` positionally: as a keyword placed before `*args` it
+        # would collide with the first extra positional argument.
+        return self._watsonx_client.embed_documents(list(texts), *args, **kwargs)
+
+    @cached_property
+    def _watsonx_client(self):
+        """
+        Build the watsonx.ai `Embeddings` client; cached after first access.
+
+        Raises
+        ------
+        ValueError
+            If the project ID or the API key is neither passed nor set in
+            the environment.
+        """
+        ibm_watsonx_ai = attempt_import_or_raise("ibm_watsonx_ai")
+        ibm_watsonx_ai_foundation_models = attempt_import_or_raise(
+            "ibm_watsonx_ai.foundation_models"
+        )
+
+        kwargs = {"model_id": self.name}
+        if self.params:
+            kwargs["params"] = self.params
+        if self.project_id:
+            kwargs["project_id"] = self.project_id
+        elif "WATSONX_PROJECT_ID" in os.environ:
+            kwargs["project_id"] = os.environ["WATSONX_PROJECT_ID"]
+        else:
+            raise ValueError("WATSONX_PROJECT_ID must be set or passed")
+
+        creds_kwargs = {}
+        if self.api_key:
+            creds_kwargs["api_key"] = self.api_key
+        elif "WATSONX_API_KEY" in os.environ:
+            creds_kwargs["api_key"] = os.environ["WATSONX_API_KEY"]
+        else:
+            raise ValueError("WATSONX_API_KEY must be set or passed")
+        # An explicitly configured URL wins over the default endpoint.
+        creds_kwargs["url"] = self.url or DEFAULT_WATSONX_URL
+        kwargs["credentials"] = ibm_watsonx_ai.Credentials(**creds_kwargs)
+
+        return ibm_watsonx_ai_foundation_models.Embeddings(**kwargs)
diff --git a/python/python/tests/test_embeddings_slow.py b/python/python/tests/test_embeddings_slow.py
index 5be77ed7..6762ae10 100644
--- a/python/python/tests/test_embeddings_slow.py
+++ b/python/python/tests/test_embeddings_slow.py
@@ -417,3 +417,30 @@ def test_openai_embedding(tmp_path):
     tbl.add(df)
     assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
     assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
+
+
+@pytest.mark.slow
+@pytest.mark.skipif(
+    os.environ.get("WATSONX_API_KEY") is None
+    or os.environ.get("WATSONX_PROJECT_ID") is None,
+    reason="WATSONX_API_KEY and WATSONX_PROJECT_ID not set",
+)
+def test_watsonx_embedding(tmp_path):
+    """Round-trip every supported watsonx model through a temporary table."""
+    from lancedb.embeddings import WatsonxEmbeddings
+
+    for name in WatsonxEmbeddings.model_names():
+        model = get_registry().get("watsonx").create(max_retries=0, name=name)
+
+        class TextModel(LanceModel):
+            text: str = model.SourceField()
+            vector: Vector(model.ndims()) = model.VectorField()
+
+        # Use the per-test directory rather than the user's ~/.lancedb.
+        db = lancedb.connect(tmp_path)
+        tbl = db.create_table("watsonx_test", schema=TextModel, mode="overwrite")
+        df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
+
+        tbl.add(df)
+        assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
+        assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"