mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-09 05:12:58 +00:00
feat: add new feature, add amazon bedrock embedding function (#1788)
Add amazon bedrock embedding function to rust sdk. 1. Add BedrockEmbeddingModel ( lancedb/src/embeddings/bedrock.rs) 2. Add example lancedb/examples/bedrock.rs
This commit is contained in:
@@ -46,6 +46,7 @@ serde = { version = "^1" }
|
||||
serde_json = { version = "1" }
|
||||
async-openai = { version = "0.20.0", optional = true }
|
||||
serde_with = { version = "3.8.1" }
|
||||
aws-sdk-bedrockruntime = { version = "1.27.0", optional = true }
|
||||
# For remote feature
|
||||
reqwest = { version = "0.12.0", features = ["gzip", "json", "stream"], optional = true }
|
||||
rand = { version = "0.8.3", features = ["small_rng"], optional = true}
|
||||
@@ -72,11 +73,13 @@ aws-config = { version = "1.0" }
|
||||
aws-smithy-runtime = { version = "1.3" }
|
||||
http-body = "1" # Matching reqwest
|
||||
|
||||
|
||||
[features]
|
||||
default = []
|
||||
remote = ["dep:reqwest", "dep:http", "dep:rand", "dep:uuid"]
|
||||
fp16kernels = ["lance-linalg/fp16kernels"]
|
||||
s3-test = []
|
||||
bedrock = ["dep:aws-sdk-bedrockruntime"]
|
||||
openai = ["dep:async-openai", "dep:reqwest"]
|
||||
polars = ["dep:polars-arrow", "dep:polars"]
|
||||
sentence-transformers = [
|
||||
@@ -94,3 +97,7 @@ required-features = ["openai"]
|
||||
[[example]]
|
||||
name = "sentence_transformers"
|
||||
required-features = ["sentence-transformers"]
|
||||
|
||||
[[example]]
|
||||
name = "bedrock"
|
||||
required-features = ["bedrock"]
|
||||
|
||||
89
rust/lancedb/examples/bedrock.rs
Normal file
89
rust/lancedb/examples/bedrock.rs
Normal file
@@ -0,0 +1,89 @@
|
||||
use std::{iter::once, sync::Arc};
|
||||
|
||||
use arrow_array::{Float64Array, Int32Array, RecordBatch, RecordBatchIterator, StringArray};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use aws_config::Region;
|
||||
use aws_sdk_bedrockruntime::Client;
|
||||
use futures::StreamExt;
|
||||
use lancedb::{
|
||||
arrow::IntoArrow,
|
||||
connect,
|
||||
embeddings::{bedrock::BedrockEmbeddingFunction, EmbeddingDefinition, EmbeddingFunction},
|
||||
query::{ExecutableQuery, QueryBase},
|
||||
Result,
|
||||
};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
let tempdir = tempfile::tempdir().unwrap();
|
||||
let tempdir = tempdir.path().to_str().unwrap();
|
||||
|
||||
// create Bedrock embedding function
|
||||
let region: String = "us-east-1".to_string();
|
||||
let config = aws_config::defaults(aws_config::BehaviorVersion::latest())
|
||||
.region(Region::new(region))
|
||||
.load()
|
||||
.await;
|
||||
|
||||
let embedding = Arc::new(BedrockEmbeddingFunction::new(
|
||||
Client::new(&config), // AWS Region
|
||||
));
|
||||
|
||||
let db = connect(tempdir).execute().await?;
|
||||
db.embedding_registry()
|
||||
.register("bedrock", embedding.clone())?;
|
||||
|
||||
let table = db
|
||||
.create_table("vectors", make_data())
|
||||
.add_embedding(EmbeddingDefinition::new(
|
||||
"text",
|
||||
"bedrock",
|
||||
Some("embeddings"),
|
||||
))?
|
||||
.execute()
|
||||
.await?;
|
||||
|
||||
// execute vector search
|
||||
let query = Arc::new(StringArray::from_iter_values(once("something warm")));
|
||||
let query_vector = embedding.compute_query_embeddings(query)?;
|
||||
let mut results = table
|
||||
.vector_search(query_vector)?
|
||||
.limit(1)
|
||||
.execute()
|
||||
.await?;
|
||||
|
||||
let rb = results.next().await.unwrap()?;
|
||||
let out = rb
|
||||
.column_by_name("text")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap();
|
||||
let text = out.iter().next().unwrap().unwrap();
|
||||
println!("Closest match: {}", text);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn make_data() -> impl IntoArrow {
|
||||
let schema = Schema::new(vec![
|
||||
Field::new("id", DataType::Int32, true),
|
||||
Field::new("text", DataType::Utf8, false),
|
||||
Field::new("price", DataType::Float64, false),
|
||||
]);
|
||||
|
||||
let id = Int32Array::from(vec![1, 2, 3, 4]);
|
||||
let text = StringArray::from_iter_values(vec![
|
||||
"Black T-Shirt",
|
||||
"Leather Jacket",
|
||||
"Winter Parka",
|
||||
"Hooded Sweatshirt",
|
||||
]);
|
||||
let price = Float64Array::from(vec![10.0, 50.0, 100.0, 30.0]);
|
||||
let schema = Arc::new(schema);
|
||||
let rb = RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![Arc::new(id), Arc::new(text), Arc::new(price)],
|
||||
)
|
||||
.unwrap();
|
||||
Box::new(RecordBatchIterator::new(vec![Ok(rb)], schema))
|
||||
}
|
||||
@@ -17,6 +17,9 @@ pub mod openai;
|
||||
#[cfg(feature = "sentence-transformers")]
|
||||
pub mod sentence_transformers;
|
||||
|
||||
#[cfg(feature = "bedrock")]
|
||||
pub mod bedrock;
|
||||
|
||||
use lance::arrow::RecordBatchExt;
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
|
||||
210
rust/lancedb/src/embeddings/bedrock.rs
Normal file
210
rust/lancedb/src/embeddings/bedrock.rs
Normal file
@@ -0,0 +1,210 @@
|
||||
use aws_sdk_bedrockruntime::Client as BedrockClient;
|
||||
use std::{borrow::Cow, fmt::Formatter, str::FromStr, sync::Arc};
|
||||
|
||||
use arrow::array::{AsArray, Float32Builder};
|
||||
use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array};
|
||||
use arrow_data::ArrayData;
|
||||
use arrow_schema::DataType;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use super::EmbeddingFunction;
|
||||
use crate::{Error, Result};
|
||||
|
||||
use tokio::runtime::Handle;
|
||||
use tokio::task::block_in_place;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum BedrockEmbeddingModel {
|
||||
TitanEmbedding,
|
||||
CohereLarge,
|
||||
}
|
||||
|
||||
impl BedrockEmbeddingModel {
|
||||
fn ndims(&self) -> usize {
|
||||
match self {
|
||||
Self::TitanEmbedding => 1536,
|
||||
Self::CohereLarge => 1024,
|
||||
}
|
||||
}
|
||||
|
||||
fn model_id(&self) -> &str {
|
||||
match self {
|
||||
Self::TitanEmbedding => "amazon.titan-embed-text-v1",
|
||||
Self::CohereLarge => "cohere.embed-english-v3",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for BedrockEmbeddingModel {
|
||||
type Err = Error;
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
match s {
|
||||
"titan-embed-text-v1" => Ok(Self::TitanEmbedding),
|
||||
"cohere-embed-english-v3" => Ok(Self::CohereLarge),
|
||||
_ => Err(Error::InvalidInput {
|
||||
message: "Invalid model. Available models are: 'titan-embed-text-v1', 'cohere-embed-english-v3'".to_string()
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BedrockEmbeddingFunction {
|
||||
model: BedrockEmbeddingModel,
|
||||
client: BedrockClient,
|
||||
}
|
||||
|
||||
impl BedrockEmbeddingFunction {
|
||||
pub fn new(client: BedrockClient) -> Self {
|
||||
Self {
|
||||
model: BedrockEmbeddingModel::TitanEmbedding,
|
||||
client,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_model(client: BedrockClient, model: BedrockEmbeddingModel) -> Self {
|
||||
Self { model, client }
|
||||
}
|
||||
}
|
||||
|
||||
impl EmbeddingFunction for BedrockEmbeddingFunction {
|
||||
fn name(&self) -> &str {
|
||||
"bedrock"
|
||||
}
|
||||
|
||||
fn source_type(&self) -> Result<Cow<DataType>> {
|
||||
Ok(Cow::Owned(DataType::Utf8))
|
||||
}
|
||||
|
||||
fn dest_type(&self) -> Result<Cow<DataType>> {
|
||||
let n_dims = self.model.ndims();
|
||||
Ok(Cow::Owned(DataType::new_fixed_size_list(
|
||||
DataType::Float32,
|
||||
n_dims as i32,
|
||||
false,
|
||||
)))
|
||||
}
|
||||
|
||||
fn compute_source_embeddings(&self, source: ArrayRef) -> Result<ArrayRef> {
|
||||
let len = source.len();
|
||||
let n_dims = self.model.ndims();
|
||||
let inner = self.compute_inner(source)?;
|
||||
|
||||
let fsl = DataType::new_fixed_size_list(DataType::Float32, n_dims as i32, false);
|
||||
|
||||
let array_data = ArrayData::builder(fsl)
|
||||
.len(len)
|
||||
.add_child_data(inner.into_data())
|
||||
.build()?;
|
||||
|
||||
Ok(Arc::new(FixedSizeListArray::from(array_data)))
|
||||
}
|
||||
|
||||
fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
|
||||
let arr = self.compute_inner(input)?;
|
||||
Ok(Arc::new(arr))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for BedrockEmbeddingFunction {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("BedrockEmbeddingFunction")
|
||||
.field("model", &self.model)
|
||||
// Skip client field as it doesn't implement Debug
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl BedrockEmbeddingFunction {
|
||||
fn compute_inner(&self, source: Arc<dyn Array>) -> Result<Float32Array> {
|
||||
if source.is_nullable() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Expected non-nullable data type".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if !matches!(source.data_type(), DataType::Utf8 | DataType::LargeUtf8) {
|
||||
return Err(Error::InvalidInput {
|
||||
message: "Expected Utf8 data type".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut builder = Float32Builder::new();
|
||||
|
||||
let texts = match source.data_type() {
|
||||
DataType::Utf8 => source
|
||||
.as_string::<i32>()
|
||||
.into_iter()
|
||||
.map(|s| s.expect("array is non-nullable").to_string())
|
||||
.collect::<Vec<String>>(),
|
||||
DataType::LargeUtf8 => source
|
||||
.as_string::<i64>()
|
||||
.into_iter()
|
||||
.map(|s| s.expect("array is non-nullable").to_string())
|
||||
.collect::<Vec<String>>(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
for text in texts {
|
||||
let request_body = match self.model {
|
||||
BedrockEmbeddingModel::TitanEmbedding => {
|
||||
json!({
|
||||
"inputText": text
|
||||
})
|
||||
}
|
||||
BedrockEmbeddingModel::CohereLarge => {
|
||||
json!({
|
||||
"texts": [text],
|
||||
"input_type": "search_document"
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
let client = self.client.clone();
|
||||
let model_id = self.model.model_id().to_string();
|
||||
let request_body = request_body.clone();
|
||||
|
||||
let response = block_in_place(move || {
|
||||
Handle::current().block_on(async move {
|
||||
client
|
||||
.invoke_model()
|
||||
.model_id(model_id)
|
||||
.body(aws_sdk_bedrockruntime::primitives::Blob::new(
|
||||
serde_json::to_vec(&request_body).unwrap(),
|
||||
))
|
||||
.send()
|
||||
.await
|
||||
})
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let response_json: Value =
|
||||
serde_json::from_slice(response.body.as_ref()).map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to parse response: {}", e),
|
||||
})?;
|
||||
|
||||
let embedding = match self.model {
|
||||
BedrockEmbeddingModel::TitanEmbedding => response_json["embedding"]
|
||||
.as_array()
|
||||
.ok_or_else(|| Error::Runtime {
|
||||
message: "Missing embedding in response".to_string(),
|
||||
})?
|
||||
.iter()
|
||||
.map(|v| v.as_f64().unwrap() as f32)
|
||||
.collect::<Vec<f32>>(),
|
||||
BedrockEmbeddingModel::CohereLarge => response_json["embeddings"][0]
|
||||
.as_array()
|
||||
.ok_or_else(|| Error::Runtime {
|
||||
message: "Missing embeddings in response".to_string(),
|
||||
})?
|
||||
.iter()
|
||||
.map(|v| v.as_f64().unwrap() as f32)
|
||||
.collect::<Vec<f32>>(),
|
||||
};
|
||||
|
||||
builder.append_slice(&embedding);
|
||||
}
|
||||
|
||||
Ok(builder.finish())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user