feat(python): Aws Bedrock embeddings integration (#822)

Supports amazon titan, cohere english & cohere multi-lingual base
models.
This commit is contained in:
Ayush Chaurasia
2024-01-28 02:04:15 +05:30
committed by Weston Pace
parent f2e29eb004
commit 545a03d7f9
5 changed files with 307 additions and 2 deletions

View File

@@ -119,7 +119,7 @@ texts = [{"text": "Capitalism has been dominant in the Western world since the e
tbl.add(texts)
```
## Gemini Embedding Function
### Gemini Embeddings
With Google's Gemini, you can represent text (words, sentences, and blocks of text) in a vectorized form, making it easier to compare and contrast embeddings. For example, two texts that share a similar subject matter or sentiment should have similar embeddings, which can be identified through mathematical comparison techniques such as cosine similarity. For more on how and why you should use embeddings, refer to the Embeddings guide.
The Gemini Embedding Model API supports various task types:
@@ -155,6 +155,51 @@ tbl.add(df)
rs = tbl.search("hello").limit(1).to_pandas()
```
### AWS Bedrock Text Embedding Functions
AWS Bedrock supports multiple base models for generating text embeddings. You need to setup the AWS credentials to use this embedding function.
You can do so by using `awscli` and also add your session_token:
```shell
aws configure
aws configure set aws_session_token "<your_session_token>"
```
to ensure that the credentials are set up correctly, you can run the following command:
```shell
aws sts get-caller-identity
```
Supported Embedding modelIDs are:
* `amazon.titan-embed-text-v1`
* `cohere.embed-english-v3`
* `cohere.embed-multilingual-v3`
Supported paramters (to be passed in `create` method) are:
| Parameter | Type | Default Value | Description |
|---|---|---|---|
| **name** | str | "amazon.titan-embed-text-v1" | The model ID of the bedrock model to use. Supported base models for Text Embeddings: amazon.titan-embed-text-v1, cohere.embed-english-v3, cohere.embed-multilingual-v3 |
| **region** | str | "us-east-1" | Optional name of the AWS Region in which the service should be called (e.g., "us-east-1"). |
| **profile_name** | str | None | Optional name of the AWS profile to use for calling the Bedrock service. If not specified, the default profile will be used. |
| **assumed_role** | str | None | Optional ARN of an AWS IAM role to assume for calling the Bedrock service. If not specified, the current active credentials will be used. |
| **role_session_name** | str | "lancedb-embeddings" | Optional name of the AWS IAM role session to use for calling the Bedrock service. If not specified, a "lancedb-embeddings" name will be used. |
| **runtime** | bool | True | Optional choice of getting different client to perform operations with the Amazon Bedrock service. |
| **max_retries** | int | 7 | Optional number of retries to perform when a request fails. |
Usage Example:
```python
model = get_registry().get("bedrock-text").create()
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect("tmp_path")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
rs = tbl.search("hello").limit(1).to_pandas()
```
## Multi-modal embedding functions
Multi-modal embedding functions allow you to query your table using both images and text.

View File

@@ -13,6 +13,7 @@
# ruff: noqa: F401
from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
from .bedrock import BedRockText
from .cohere import CohereEmbeddingFunction
from .gemini_text import GeminiText
from .instructor import InstructorEmbeddingFunction

View File

@@ -0,0 +1,223 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from functools import cached_property
from typing import List, Union
import numpy as np
from lancedb.pydantic import PYDANTIC_VERSION
from .base import TextEmbeddingFunction
from .registry import register
from .utils import TEXT
@register("bedrock-text")
class BedRockText(TextEmbeddingFunction):
"""
Parameters
----------
name: str, default "amazon.titan-embed-text-v1"
The model ID of the bedrock model to use. Supported models for are:
- amazon.titan-embed-text-v1
- cohere.embed-english-v3
- cohere.embed-multilingual-v3
region: str, default "us-east-1"
Optional name of the AWS Region in which the service should be called.
profile_name: str, default None
Optional name of the AWS profile to use for calling the Bedrock service.
If not specified, the default profile will be used.
assumed_role: str, default None
Optional ARN of an AWS IAM role to assume for calling the Bedrock service.
If not specified, the current active credentials will be used.
role_session_name: str, default "lancedb-embeddings"
Optional name of the AWS IAM role session to use for calling the Bedrock
service. If not specified, "lancedb-embeddings" name will be used.
Examples
--------
import lancedb
import pandas as pd
from lancedb.pydantic import LanceModel, Vector
model = get_registry().get("bedrock-text").create()
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect("tmp_path")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
rs = tbl.search("hello").limit(1).to_pandas()
"""
name: str = "amazon.titan-embed-text-v1"
region: str = "us-east-1"
assumed_role: Union[str, None] = None
profile_name: Union[str, None] = None
role_session_name: str = "lancedb-embeddings"
if PYDANTIC_VERSION < (2, 0): # Pydantic 1.x compat
class Config:
keep_untouched = (cached_property,)
def ndims(self):
# return len(self._generate_embedding("test"))
# TODO: fix hardcoding
if self.name == "amazon.titan-embed-text-v1":
return 1536
elif self.name in {"cohere.embed-english-v3", "cohere.embed-multilingual-v3"}:
return 1024
else:
raise ValueError(f"Unknown model name: {self.name}")
def compute_query_embeddings(
self, query: str, *args, **kwargs
) -> List[List[float]]:
return self.compute_source_embeddings(query)
def compute_source_embeddings(
self, texts: TEXT, *args, **kwargs
) -> List[List[float]]:
texts = self.sanitize_input(texts)
return self.generate_embeddings(texts)
def generate_embeddings(
self, texts: Union[List[str], np.ndarray], *args, **kwargs
) -> List[List[float]]:
"""
Get the embeddings for the given texts
Parameters
----------
texts: list[str] or np.ndarray (of str)
The texts to embed
Returns
-------
list[list[float]]
The embeddings for the given texts
"""
results = []
for text in texts:
response = self._generate_embedding(text)
results.append(response)
return results
def _generate_embedding(self, text: str) -> List[float]:
"""
Get the embeddings for the given texts
Parameters
----------
texts: str
The texts to embed
Returns
-------
list[float]
The embeddings for the given texts
"""
# format input body for provider
provider = self.name.split(".")[0]
_model_kwargs = {}
input_body = {**_model_kwargs}
if provider == "cohere":
if "input_type" not in input_body.keys():
input_body["input_type"] = "search_document"
input_body["texts"] = [text]
else:
# includes common provider == "amazon"
input_body["inputText"] = text
body = json.dumps(input_body)
try:
# invoke bedrock API
response = self.client.invoke_model(
body=body,
modelId=self.name,
accept="application/json",
contentType="application/json",
)
# format output based on provider
response_body = json.loads(response.get("body").read())
if provider == "cohere":
return response_body.get("embeddings")[0]
else:
# includes common provider == "amazon"
return response_body.get("embedding")
except Exception as e:
help_txt = """
boto3 client failed to invoke the bedrock API. In case of
AWS credentials error:
- Please check your AWS credentials and ensure that you have access.
You can set up aws credentials using `aws configure` command and
verify by running `aws sts get-caller-identity` in your terminal.
"""
raise ValueError(f"Error raised by boto3 client: {e}. \n {help_txt}")
@cached_property
def client(self):
"""Create a boto3 client for Amazon Bedrock service
Returns
-------
boto3.client
The boto3 client for Amazon Bedrock service
"""
botocore = self.safe_import("botocore")
boto3 = self.safe_import("boto3")
session_kwargs = {"region_name": self.region}
client_kwargs = {**session_kwargs}
if self.profile_name:
session_kwargs["profile_name"] = self.profile_name
retry_config = botocore.config.Config(
region_name=self.region,
retries={
"max_attempts": 0, # disable this as retries retries are handled
"mode": "standard",
},
)
session = (
boto3.Session(**session_kwargs) if self.profile_name else boto3.Session()
)
if self.assumed_role: # if not using default credentials
sts = session.client("sts")
response = sts.assume_role(
RoleArn=str(self.assumed_role),
RoleSessionName=self.role_session_name,
)
client_kwargs["aws_access_key_id"] = response["Credentials"]["AccessKeyId"]
client_kwargs["aws_secret_access_key"] = response["Credentials"][
"SecretAccessKey"
]
client_kwargs["aws_session_token"] = response["Credentials"]["SessionToken"]
service_name = "bedrock-runtime"
bedrock_client = session.client(
service_name=service_name, config=retry_config, **client_kwargs
)
return bedrock_client

View File

@@ -49,7 +49,8 @@ tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "d
dev = ["ruff", "pre-commit"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"]
embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "InstructorEmbedding"]
embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere",
"InstructorEmbedding", "google.generativeai", "boto3>=1.28.57", "awscli>=1.29.57", "botocore>=1.31.57" ]
[build-system]
requires = ["setuptools", "wheel"]

View File

@@ -202,3 +202,38 @@ def test_gemini_embedding(tmp_path):
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
def aws_setup():
try:
import boto3
sts = boto3.client("sts")
sts.get_caller_identity()
return True
except Exception:
return False
@pytest.mark.slow
@pytest.mark.skipif(
not aws_setup(), reason="AWS credentials not set or libraries not installed"
)
def test_bedrock_embedding(tmp_path):
for name in [
"amazon.titan-embed-text-v1",
"cohere.embed-english-v3",
"cohere.embed-multilingual-v3",
]:
model = get_registry().get("bedrock-text").create(max_retries=0, name=name)
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()