mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-25 14:29:56 +00:00
closes #1194 #1172 #1124 #1208 @wjones127 : `if query_type != "fts":` is needed because both fts and vector search create `LanceQueryBuilder` which has `vector_column_name` as a required attribute.
228 lines
7.7 KiB
Python
228 lines
7.7 KiB
Python
# Copyright (c) 2023. LanceDB Developers
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import json
|
|
from functools import cached_property
|
|
from typing import List, Union
|
|
|
|
import numpy as np
|
|
|
|
from lancedb.pydantic import PYDANTIC_VERSION
|
|
|
|
from ..util import attempt_import_or_raise
|
|
from .base import TextEmbeddingFunction
|
|
from .registry import register
|
|
from .utils import TEXT
|
|
|
|
|
|
@register("bedrock-text")
class BedRockText(TextEmbeddingFunction):
    """
    Text embedding function backed by embedding models on Amazon Bedrock.

    Each text is sent to the Bedrock runtime ``invoke_model`` API through a
    lazily created boto3 client, and the vector found in the JSON response is
    returned as the embedding.

    Parameters
    ----------
    name: str, default "amazon.titan-embed-text-v1"
        The model ID of the bedrock model to use. Supported models are:
            - amazon.titan-embed-text-v1
            - cohere.embed-english-v3
            - cohere.embed-multilingual-v3
    region: str, default "us-east-1"
        Optional name of the AWS Region in which the service should be called.
    profile_name: str, default None
        Optional name of the AWS profile to use for calling the Bedrock service.
        If not specified, the default profile will be used.
    assumed_role: str, default None
        Optional ARN of an AWS IAM role to assume for calling the Bedrock service.
        If not specified, the current active credentials will be used.
    role_session_name: str, default "lancedb-embeddings"
        Optional name of the AWS IAM role session to use for calling the Bedrock
        service. If not specified, "lancedb-embeddings" name will be used.

    Examples
    --------
    import lancedb
    import pandas as pd
    from lancedb.embeddings import get_registry
    from lancedb.pydantic import LanceModel, Vector

    model = get_registry().get("bedrock-text").create()

    class TextModel(LanceModel):
        text: str = model.SourceField()
        vector: Vector(model.ndims()) = model.VectorField()

    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
    db = lancedb.connect("tmp_path")
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")

    tbl.add(df)

    rs = tbl.search("hello").limit(1).to_pandas()
    """

    name: str = "amazon.titan-embed-text-v1"
    region: str = "us-east-1"
    assumed_role: Union[str, None] = None
    profile_name: Union[str, None] = None
    role_session_name: str = "lancedb-embeddings"

    # ``client`` below is a ``cached_property``; tell pydantic not to treat it
    # as a model field (the option name differs between pydantic 1.x and 2.x).
    if PYDANTIC_VERSION < (2, 0):  # Pydantic 1.x compat

        class Config:
            keep_untouched = (cached_property,)

    else:
        model_config = dict()
        model_config["ignored_types"] = (cached_property,)

    def ndims(self):
        """
        Return the embedding dimension of the configured model.

        Raises
        ------
        ValueError
            If ``name`` is not one of the supported model IDs.
        """
        # Dimensions are hardcoded per model ID instead of being probed by
        # embedding a sample text (which would require a live Bedrock call).
        # TODO: fix hardcoding
        if self.name == "amazon.titan-embed-text-v1":
            return 1536
        if self.name in {"cohere.embed-english-v3", "cohere.embed-multilingual-v3"}:
            return 1024
        raise ValueError(f"Unknown model name: {self.name}")

    def compute_query_embeddings(
        self, query: str, *args, **kwargs
    ) -> List[List[float]]:
        """
        Compute the embeddings for a search query.

        The query is sanitized the same way as source texts. For Cohere
        models the request is tagged with ``input_type="search_query"`` so
        that queries are embedded as queries rather than as documents; other
        providers ignore the input type.

        Parameters
        ----------
        query: str
            The query text to embed

        Returns
        -------
        list[list[float]]
            One embedding per sanitized query text
        """
        texts = self.sanitize_input(query)
        return self.generate_embeddings(texts, input_type="search_query")

    def compute_source_embeddings(
        self, texts: TEXT, *args, **kwargs
    ) -> List[List[float]]:
        """
        Compute the embeddings for source (document) texts.

        Parameters
        ----------
        texts: TEXT
            The text or texts to embed; normalized with ``sanitize_input``

        Returns
        -------
        list[list[float]]
            The embeddings for the given texts
        """
        texts = self.sanitize_input(texts)
        return self.generate_embeddings(texts)

    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray], *args, **kwargs
    ) -> List[List[float]]:
        """
        Get the embeddings for the given texts

        One Bedrock request is issued per text.

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed
        input_type: str, optional keyword, default "search_document"
            Value sent as ``input_type`` for Cohere models; unused for
            other providers.

        Returns
        -------
        list[list[float]]
            The embeddings for the given texts
        """
        input_type = kwargs.get("input_type", "search_document")
        return [
            self._generate_embedding(text, input_type=input_type) for text in texts
        ]

    def _generate_embedding(
        self, text: str, input_type: str = "search_document"
    ) -> List[float]:
        """
        Get the embedding for a single text

        Parameters
        ----------
        text: str
            The text to embed
        input_type: str, default "search_document"
            Value sent as ``input_type`` when the provider is Cohere

        Returns
        -------
        list[float]
            The embedding for the given text

        Raises
        ------
        ValueError
            If creating the client, invoking the model or decoding the
            response fails; the original exception is chained as the cause.
        """
        # The provider is the part of the model ID before the first dot and
        # decides the request/response JSON layout.
        provider = self.name.split(".")[0]
        if provider == "cohere":
            input_body = {"input_type": input_type, "texts": [text]}
        else:
            # includes common provider == "amazon"
            input_body = {"inputText": text}
        body = json.dumps(input_body)

        try:
            # invoke bedrock API
            response = self.client.invoke_model(
                body=body,
                modelId=self.name,
                accept="application/json",
                contentType="application/json",
            )

            # format output based on provider
            response_body = json.loads(response.get("body").read())
            if provider == "cohere":
                # a single text was sent, so take the first (only) embedding
                return response_body.get("embeddings")[0]
            # includes common provider == "amazon"
            return response_body.get("embedding")
        except Exception as e:
            help_txt = """
                boto3 client failed to invoke the bedrock API. In case of
                AWS credentials error:
                    - Please check your AWS credentials and ensure that you have access.
                    You can set up aws credentials using `aws configure` command and
                    verify by running `aws sts get-caller-identity` in your terminal.
                """
            # Keep ValueError for callers, but chain the underlying error so
            # the original traceback is preserved as the explicit cause.
            raise ValueError(
                f"Error raised by boto3 client: {e}. \n {help_txt}"
            ) from e

    @cached_property
    def client(self):
        """Create a boto3 client for Amazon Bedrock service

        The client is built on first access and cached on the instance.
        If ``assumed_role`` is set, temporary credentials are obtained via
        STS ``assume_role`` and passed to the Bedrock client.

        Returns
        -------
        boto3.client
            The boto3 client for Amazon Bedrock service
        """
        botocore = attempt_import_or_raise("botocore")
        boto3 = attempt_import_or_raise("boto3")

        # The region is always passed to the Bedrock client; the session only
        # receives it (together with the profile) when a profile is configured.
        client_kwargs = {"region_name": self.region}
        if self.profile_name:
            session = boto3.Session(
                region_name=self.region, profile_name=self.profile_name
            )
        else:
            session = boto3.Session()

        retry_config = botocore.config.Config(
            region_name=self.region,
            retries={
                # NOTE(review): the original comment says retries are disabled
                # here because they are handled elsewhere - confirm where.
                "max_attempts": 0,
                "mode": "standard",
            },
        )

        if self.assumed_role:  # if not using default credentials
            sts = session.client("sts")
            response = sts.assume_role(
                RoleArn=str(self.assumed_role),
                RoleSessionName=self.role_session_name,
            )
            credentials = response["Credentials"]
            client_kwargs["aws_access_key_id"] = credentials["AccessKeyId"]
            client_kwargs["aws_secret_access_key"] = credentials["SecretAccessKey"]
            client_kwargs["aws_session_token"] = credentials["SessionToken"]

        return session.client(
            service_name="bedrock-runtime", config=retry_config, **client_kwargs
        )
|