From 74f660d22342e0e0ad90c50b017fcf239ff13b71 Mon Sep 17 00:00:00 2001 From: StevenSu Date: Fri, 15 Nov 2024 03:04:59 +0800 Subject: [PATCH] feat: add new feature, add amazon bedrock embedding function (#1788) Add amazon bedrock embedding function to rust sdk. 1. Add BedrockEmbeddingModel ( lancedb/src/embeddings/bedrock.rs) 2. Add example lancedb/examples/bedrock.rs --- rust/lancedb/Cargo.toml | 7 + rust/lancedb/examples/bedrock.rs | 89 +++++++++++ rust/lancedb/src/embeddings.rs | 3 + rust/lancedb/src/embeddings/bedrock.rs | 210 +++++++++++++++++++++++++ 4 files changed, 309 insertions(+) create mode 100644 rust/lancedb/examples/bedrock.rs create mode 100644 rust/lancedb/src/embeddings/bedrock.rs diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml index 35316ef4..1fd7298a 100644 --- a/rust/lancedb/Cargo.toml +++ b/rust/lancedb/Cargo.toml @@ -46,6 +46,7 @@ serde = { version = "^1" } serde_json = { version = "1" } async-openai = { version = "0.20.0", optional = true } serde_with = { version = "3.8.1" } +aws-sdk-bedrockruntime = { version = "1.27.0", optional = true } # For remote feature reqwest = { version = "0.12.0", features = ["gzip", "json", "stream"], optional = true } rand = { version = "0.8.3", features = ["small_rng"], optional = true} @@ -72,11 +73,13 @@ aws-config = { version = "1.0" } aws-smithy-runtime = { version = "1.3" } http-body = "1" # Matching reqwest + [features] default = [] remote = ["dep:reqwest", "dep:http", "dep:rand", "dep:uuid"] fp16kernels = ["lance-linalg/fp16kernels"] s3-test = [] +bedrock = ["dep:aws-sdk-bedrockruntime"] openai = ["dep:async-openai", "dep:reqwest"] polars = ["dep:polars-arrow", "dep:polars"] sentence-transformers = [ @@ -94,3 +97,7 @@ required-features = ["openai"] [[example]] name = "sentence_transformers" required-features = ["sentence-transformers"] + +[[example]] +name = "bedrock" +required-features = ["bedrock"] diff --git a/rust/lancedb/examples/bedrock.rs b/rust/lancedb/examples/bedrock.rs new file mode 100644 index 00000000..3b9c7a23 --- /dev/null +++ b/rust/lancedb/examples/bedrock.rs @@ -0,0 +1,89 @@ +use std::{iter::once, sync::Arc}; + +use arrow_array::{Float64Array, Int32Array, RecordBatch, RecordBatchIterator, StringArray}; +use arrow_schema::{DataType, Field, Schema}; +use aws_config::Region; +use aws_sdk_bedrockruntime::Client; +use futures::StreamExt; +use lancedb::{ + arrow::IntoArrow, + connect, + embeddings::{bedrock::BedrockEmbeddingFunction, EmbeddingDefinition, EmbeddingFunction}, + query::{ExecutableQuery, QueryBase}, + Result, +}; + +#[tokio::main] +async fn main() -> Result<()> { + let tempdir = tempfile::tempdir().unwrap(); + let tempdir = tempdir.path().to_str().unwrap(); + + // create Bedrock embedding function + let region: String = "us-east-1".to_string(); + let config = aws_config::defaults(aws_config::BehaviorVersion::latest()) + .region(Region::new(region)) + .load() + .await; + + let embedding = Arc::new(BedrockEmbeddingFunction::new( + Client::new(&config), // AWS Region + )); + + let db = connect(tempdir).execute().await?; + db.embedding_registry() + .register("bedrock", embedding.clone())?; + + let table = db + .create_table("vectors", make_data()) + .add_embedding(EmbeddingDefinition::new( + "text", + "bedrock", + Some("embeddings"), + ))? + .execute() + .await?; + + // execute vector search + let query = Arc::new(StringArray::from_iter_values(once("something warm"))); + let query_vector = embedding.compute_query_embeddings(query)?; + let mut results = table + .vector_search(query_vector)? + .limit(1) + .execute() + .await?; + + let rb = results.next().await.unwrap()?; + let out = rb + .column_by_name("text") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let text = out.iter().next().unwrap().unwrap(); + println!("Closest match: {}", text); + Ok(()) +} + +fn make_data() -> impl IntoArrow { + let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("text", DataType::Utf8, false), + Field::new("price", DataType::Float64, false), + ]); + + let id = Int32Array::from(vec![1, 2, 3, 4]); + let text = StringArray::from_iter_values(vec![ + "Black T-Shirt", + "Leather Jacket", + "Winter Parka", + "Hooded Sweatshirt", + ]); + let price = Float64Array::from(vec![10.0, 50.0, 100.0, 30.0]); + let schema = Arc::new(schema); + let rb = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(id), Arc::new(text), Arc::new(price)], + ) + .unwrap(); + Box::new(RecordBatchIterator::new(vec![Ok(rb)], schema)) +} diff --git a/rust/lancedb/src/embeddings.rs b/rust/lancedb/src/embeddings.rs index 5093e931..832b4278 100644 --- a/rust/lancedb/src/embeddings.rs +++ b/rust/lancedb/src/embeddings.rs @@ -17,6 +17,9 @@ pub mod openai; #[cfg(feature = "sentence-transformers")] pub mod sentence_transformers; +#[cfg(feature = "bedrock")] +pub mod bedrock; + use lance::arrow::RecordBatchExt; use std::{ borrow::Cow, diff --git a/rust/lancedb/src/embeddings/bedrock.rs b/rust/lancedb/src/embeddings/bedrock.rs new file mode 100644 index 00000000..d0fbc921 --- /dev/null +++ b/rust/lancedb/src/embeddings/bedrock.rs @@ -0,0 +1,210 @@ +use aws_sdk_bedrockruntime::Client as BedrockClient; +use std::{borrow::Cow, fmt::Formatter, str::FromStr, sync::Arc}; + +use arrow::array::{AsArray, Float32Builder}; +use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array}; +use arrow_data::ArrayData; +use arrow_schema::DataType; +use serde_json::{json, Value}; + +use super::EmbeddingFunction; +use crate::{Error, Result}; + +use tokio::runtime::Handle; +use tokio::task::block_in_place; + +#[derive(Debug)] +pub enum BedrockEmbeddingModel { + TitanEmbedding, + CohereLarge, +} + +impl BedrockEmbeddingModel { + fn ndims(&self) -> usize { + match self { + Self::TitanEmbedding => 1536, + Self::CohereLarge => 1024, + } + } + + fn model_id(&self) -> &str { + match self { + Self::TitanEmbedding => "amazon.titan-embed-text-v1", + Self::CohereLarge => "cohere.embed-english-v3", + } + } +} + +impl FromStr for BedrockEmbeddingModel { + type Err = Error; + + fn from_str(s: &str) -> std::result::Result { + match s { + "titan-embed-text-v1" => Ok(Self::TitanEmbedding), + "cohere-embed-english-v3" => Ok(Self::CohereLarge), + _ => Err(Error::InvalidInput { + message: "Invalid model. Available models are: 'titan-embed-text-v1', 'cohere-embed-english-v3'".to_string() + }), + } + } +} + +pub struct BedrockEmbeddingFunction { + model: BedrockEmbeddingModel, + client: BedrockClient, +} + +impl BedrockEmbeddingFunction { + pub fn new(client: BedrockClient) -> Self { + Self { + model: BedrockEmbeddingModel::TitanEmbedding, + client, + } + } + + pub fn with_model(client: BedrockClient, model: BedrockEmbeddingModel) -> Self { + Self { model, client } + } +} + +impl EmbeddingFunction for BedrockEmbeddingFunction { + fn name(&self) -> &str { + "bedrock" + } + + fn source_type(&self) -> Result> { + Ok(Cow::Owned(DataType::Utf8)) + } + + fn dest_type(&self) -> Result> { + let n_dims = self.model.ndims(); + Ok(Cow::Owned(DataType::new_fixed_size_list( + DataType::Float32, + n_dims as i32, + false, + ))) + } + + fn compute_source_embeddings(&self, source: ArrayRef) -> Result { + let len = source.len(); + let n_dims = self.model.ndims(); + let inner = self.compute_inner(source)?; + + let fsl = DataType::new_fixed_size_list(DataType::Float32, n_dims as i32, false); + + let array_data = ArrayData::builder(fsl) + .len(len) + .add_child_data(inner.into_data()) + .build()?; + + Ok(Arc::new(FixedSizeListArray::from(array_data))) + } + + fn compute_query_embeddings(&self, input: Arc) -> Result> { + let arr = self.compute_inner(input)?; + Ok(Arc::new(arr)) + } +} + +impl std::fmt::Debug for BedrockEmbeddingFunction { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BedrockEmbeddingFunction") + .field("model", &self.model) + // Skip client field as it doesn't implement Debug + .finish() + } +} + +impl BedrockEmbeddingFunction { + fn compute_inner(&self, source: Arc) -> Result { + if source.is_nullable() { + return Err(Error::InvalidInput { + message: "Expected non-nullable data type".to_string(), + }); + } + + if !matches!(source.data_type(), DataType::Utf8 | DataType::LargeUtf8) { + return Err(Error::InvalidInput { + message: "Expected Utf8 data type".to_string(), + }); + } + + let mut builder = Float32Builder::new(); + + let texts = match source.data_type() { + DataType::Utf8 => source + .as_string::() + .into_iter() + .map(|s| s.expect("array is non-nullable").to_string()) + .collect::>(), + DataType::LargeUtf8 => source + .as_string::() + .into_iter() + .map(|s| s.expect("array is non-nullable").to_string()) + .collect::>(), + _ => unreachable!(), + }; + + for text in texts { + let request_body = match self.model { + BedrockEmbeddingModel::TitanEmbedding => { + json!({ + "inputText": text + }) + } + BedrockEmbeddingModel::CohereLarge => { + json!({ + "texts": [text], + "input_type": "search_document" + }) + } + }; + + let client = self.client.clone(); + let model_id = self.model.model_id().to_string(); + let request_body = request_body.clone(); + + let response = block_in_place(move || { + Handle::current().block_on(async move { + client + .invoke_model() + .model_id(model_id) + .body(aws_sdk_bedrockruntime::primitives::Blob::new( + serde_json::to_vec(&request_body).unwrap(), + )) + .send() + .await + }) + }) + .unwrap(); + + let response_json: Value = + serde_json::from_slice(response.body.as_ref()).map_err(|e| Error::Runtime { + message: format!("Failed to parse response: {}", e), + })?; + + let embedding = match self.model { + BedrockEmbeddingModel::TitanEmbedding => response_json["embedding"] + .as_array() + .ok_or_else(|| Error::Runtime { + message: "Missing embedding in response".to_string(), + })? + .iter() + .map(|v| v.as_f64().unwrap() as f32) + .collect::>(), + BedrockEmbeddingModel::CohereLarge => response_json["embeddings"][0] + .as_array() + .ok_or_else(|| Error::Runtime { + message: "Missing embeddings in response".to_string(), + })? + .iter() + .map(|v| v.as_f64().unwrap() as f32) + .collect::>(), + }; + + builder.append_slice(&embedding); + } + + Ok(builder.finish()) + } +}