feat(python): AWS Bedrock embeddings integration (#822)
Supports the Amazon Titan, Cohere English, and Cohere Multilingual base models.
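
A minimal usage sketch, mirroring the docstring example added in this commit; it assumes AWS credentials are already configured and that boto3/botocore are installed, and the connection path and table name are illustrative only:

import lancedb
import pandas as pd

from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

# "bedrock-text" is the registry key added by this commit; the default model
# is amazon.titan-embed-text-v1 (1536-dimensional embeddings).
model = get_registry().get("bedrock-text").create()


class TextModel(LanceModel):
    text: str = model.SourceField()
    vector: Vector(model.ndims()) = model.VectorField()


db = lancedb.connect("/tmp/bedrock-demo")  # illustrative path
tbl = db.create_table("demo", schema=TextModel, mode="overwrite")
tbl.add(pd.DataFrame({"text": ["hello world", "goodbye world"]}))

print(tbl.search("hello").limit(1).to_pandas()["text"][0])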
committed by Weston Pace
parent f2e29eb004
commit 545a03d7f9
@@ -13,6 +13,7 @@
# ruff: noqa: F401
from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
from .bedrock import BedRockText
from .cohere import CohereEmbeddingFunction
from .gemini_text import GeminiText
from .instructor import InstructorEmbeddingFunction
223
python/lancedb/embeddings/bedrock.py
Normal file
@@ -0,0 +1,223 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from functools import cached_property
from typing import List, Union

import numpy as np

from lancedb.pydantic import PYDANTIC_VERSION

from .base import TextEmbeddingFunction
from .registry import register
from .utils import TEXT


@register("bedrock-text")
class BedRockText(TextEmbeddingFunction):
    """
    Parameters
    ----------
    name: str, default "amazon.titan-embed-text-v1"
        The model ID of the bedrock model to use. Supported models are:
            - amazon.titan-embed-text-v1
            - cohere.embed-english-v3
            - cohere.embed-multilingual-v3
    region: str, default "us-east-1"
        Optional name of the AWS Region in which the service should be called.
    profile_name: str, default None
        Optional name of the AWS profile to use for calling the Bedrock service.
        If not specified, the default profile will be used.
    assumed_role: str, default None
        Optional ARN of an AWS IAM role to assume for calling the Bedrock service.
        If not specified, the current active credentials will be used.
    role_session_name: str, default "lancedb-embeddings"
        Optional name of the AWS IAM role session to use for calling the Bedrock
        service. If not specified, the "lancedb-embeddings" name will be used.

    Examples
    --------
    import lancedb
    import pandas as pd
    from lancedb.pydantic import LanceModel, Vector
    from lancedb.embeddings import get_registry

    model = get_registry().get("bedrock-text").create()

    class TextModel(LanceModel):
        text: str = model.SourceField()
        vector: Vector(model.ndims()) = model.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect("tmp_path")
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)

    rs = tbl.search("hello").limit(1).to_pandas()
    """

    name: str = "amazon.titan-embed-text-v1"
    region: str = "us-east-1"
    assumed_role: Union[str, None] = None
    profile_name: Union[str, None] = None
    role_session_name: str = "lancedb-embeddings"

    if PYDANTIC_VERSION < (2, 0):  # Pydantic 1.x compat

        class Config:
            keep_untouched = (cached_property,)

    def ndims(self):
        # return len(self._generate_embedding("test"))
        # TODO: fix hardcoding
        if self.name == "amazon.titan-embed-text-v1":
            return 1536
        elif self.name in {"cohere.embed-english-v3", "cohere.embed-multilingual-v3"}:
            return 1024
        else:
            raise ValueError(f"Unknown model name: {self.name}")

    def compute_query_embeddings(
        self, query: str, *args, **kwargs
    ) -> List[List[float]]:
        return self.compute_source_embeddings(query)

    def compute_source_embeddings(
        self, texts: TEXT, *args, **kwargs
    ) -> List[List[float]]:
        texts = self.sanitize_input(texts)
        return self.generate_embeddings(texts)

    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray], *args, **kwargs
    ) -> List[List[float]]:
        """
        Get the embeddings for the given texts

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed

        Returns
        -------
        list[list[float]]
            The embeddings for the given texts
        """
        results = []
        for text in texts:
            response = self._generate_embedding(text)
            results.append(response)
        return results

    def _generate_embedding(self, text: str) -> List[float]:
        """
        Get the embedding for a single text.

        Parameters
        ----------
        text: str
            The text to embed

        Returns
        -------
        list[float]
            The embedding for the given text
        """
        # format input body for provider
        provider = self.name.split(".")[0]
        _model_kwargs = {}
        input_body = {**_model_kwargs}
        if provider == "cohere":
            if "input_type" not in input_body.keys():
                input_body["input_type"] = "search_document"
            input_body["texts"] = [text]
        else:
            # includes common provider == "amazon"
            input_body["inputText"] = text
        body = json.dumps(input_body)

        try:
            # invoke bedrock API
            response = self.client.invoke_model(
                body=body,
                modelId=self.name,
                accept="application/json",
                contentType="application/json",
            )

            # format output based on provider
            response_body = json.loads(response.get("body").read())
            if provider == "cohere":
                return response_body.get("embeddings")[0]
            else:
                # includes common provider == "amazon"
                return response_body.get("embedding")
        except Exception as e:
            help_txt = """
                boto3 client failed to invoke the bedrock API. In case of an
                AWS credentials error:
                - Please check your AWS credentials and ensure that you have access.
                  You can set up AWS credentials using the `aws configure` command and
                  verify them by running `aws sts get-caller-identity` in your terminal.
                """
            raise ValueError(f"Error raised by boto3 client: {e}. \n {help_txt}")

    @cached_property
    def client(self):
        """Create a boto3 client for the Amazon Bedrock service.

        Returns
        -------
        boto3.client
            The boto3 client for the Amazon Bedrock service
        """
        botocore = self.safe_import("botocore")
        boto3 = self.safe_import("boto3")

        session_kwargs = {"region_name": self.region}
        client_kwargs = {**session_kwargs}

        if self.profile_name:
            session_kwargs["profile_name"] = self.profile_name

        retry_config = botocore.config.Config(
            region_name=self.region,
            retries={
                "max_attempts": 0,  # disable this as retries are handled elsewhere
                "mode": "standard",
            },
        )
        session = (
            boto3.Session(**session_kwargs) if self.profile_name else boto3.Session()
        )
        if self.assumed_role:  # if not using default credentials
            sts = session.client("sts")
            response = sts.assume_role(
                RoleArn=str(self.assumed_role),
                RoleSessionName=self.role_session_name,
            )
            client_kwargs["aws_access_key_id"] = response["Credentials"]["AccessKeyId"]
            client_kwargs["aws_secret_access_key"] = response["Credentials"][
                "SecretAccessKey"
            ]
            client_kwargs["aws_session_token"] = response["Credentials"]["SessionToken"]

        service_name = "bedrock-runtime"

        bedrock_client = session.client(
            service_name=service_name, config=retry_config, **client_kwargs
        )

        return bedrock_client
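
For reference, a minimal sketch of the per-provider request bodies that the _generate_embedding method above builds and the response fields it reads back, shown against a raw bedrock-runtime client; the region and example text are illustrative assumptions:

import json

import boto3  # assumes boto3 is installed and AWS credentials are configured

client = boto3.client("bedrock-runtime", region_name="us-east-1")

# Amazon Titan: a single "inputText" goes in, an "embedding" list (1536 floats) comes back.
titan_resp = client.invoke_model(
    body=json.dumps({"inputText": "hello world"}),
    modelId="amazon.titan-embed-text-v1",
    accept="application/json",
    contentType="application/json",
)
titan_vec = json.loads(titan_resp["body"].read())["embedding"]

# Cohere: a list of "texts" plus an "input_type" go in, an "embeddings" list
# (one 1024-float vector per text) comes back.
cohere_resp = client.invoke_model(
    body=json.dumps({"texts": ["hello world"], "input_type": "search_document"}),
    modelId="cohere.embed-english-v3",
    accept="application/json",
    contentType="application/json",
)
cohere_vec = json.loads(cohere_resp["body"].read())["embeddings"][0]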
@@ -49,7 +49,8 @@ tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "d
dev = ["ruff", "pre-commit"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"]
embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "InstructorEmbedding"]
embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere",
    "InstructorEmbedding", "google.generativeai", "boto3>=1.28.57", "awscli>=1.29.57", "botocore>=1.31.57" ]

[build-system]
requires = ["setuptools", "wheel"]
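
The new Bedrock dependencies are added to the embeddings extra rather than the core package, so installing them presumably looks like pip install "lancedb[embeddings]" (the install command is an assumption based on the extra's name). A quick sanity check that they resolved:

# Assumes the "embeddings" extra, or boto3/botocore directly, has been installed.
import boto3
import botocore

print("boto3", boto3.__version__)
print("botocore", botocore.__version__)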
@@ -202,3 +202,38 @@ def test_gemini_embedding(tmp_path):
    tbl.add(df)
    assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
    assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"


def aws_setup():
    try:
        import boto3

        sts = boto3.client("sts")
        sts.get_caller_identity()
        return True
    except Exception:
        return False


@pytest.mark.slow
@pytest.mark.skipif(
    not aws_setup(), reason="AWS credentials not set or libraries not installed"
)
def test_bedrock_embedding(tmp_path):
    for name in [
        "amazon.titan-embed-text-v1",
        "cohere.embed-english-v3",
        "cohere.embed-multilingual-v3",
    ]:
        model = get_registry().get("bedrock-text").create(max_retries=0, name=name)

        class TextModel(LanceModel):
            text: str = model.SourceField()
            vector: Vector(model.ndims()) = model.VectorField()

        df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
        db = lancedb.connect(tmp_path)
        tbl = db.create_table("test", schema=TextModel, mode="overwrite")

        tbl.add(df)
        assert len(tbl.to_pandas()["vector"][0]) == model.ndims()