diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml index b9ac536f..ea606727 100644 --- a/rust/lancedb/Cargo.toml +++ b/rust/lancedb/Cargo.toml @@ -38,6 +38,7 @@ url.workspace = true regex.workspace = true serde = { version = "^1" } serde_json = { version = "1" } +async-openai = { version = "0.20.0", optional = true } serde_with = { version = "3.8.1" } # For remote feature reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true } @@ -62,4 +63,10 @@ default = [] remote = ["dep:reqwest"] fp16kernels = ["lance-linalg/fp16kernels"] s3-test = [] +openai = ["dep:async-openai", "dep:reqwest"] polars = ["dep:polars-arrow", "dep:polars"] + + +[[example]] +name = "openai" +required-features = ["openai"] diff --git a/rust/lancedb/examples/openai.rs b/rust/lancedb/examples/openai.rs new file mode 100644 index 00000000..ce5f811f --- /dev/null +++ b/rust/lancedb/examples/openai.rs @@ -0,0 +1,82 @@ +use std::{iter::once, sync::Arc}; + +use arrow_array::{Float64Array, Int32Array, RecordBatch, RecordBatchIterator, StringArray}; +use arrow_schema::{DataType, Field, Schema}; +use futures::StreamExt; +use lancedb::{ + arrow::IntoArrow, + connect, + embeddings::{openai::OpenAIEmbeddingFunction, EmbeddingDefinition, EmbeddingFunction}, + query::{ExecutableQuery, QueryBase}, + Result, +}; + +#[tokio::main] +async fn main() -> Result<()> { + let tempdir = tempfile::tempdir().unwrap(); + let tempdir = tempdir.path().to_str().unwrap(); + let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY is not set"); + let embedding = Arc::new(OpenAIEmbeddingFunction::new_with_model( + api_key, + "text-embedding-3-large", + )?); + + let db = connect(tempdir).execute().await?; + db.embedding_registry() + .register("openai", embedding.clone())?; + + let table = db + .create_table("vectors", make_data()) + .add_embedding(EmbeddingDefinition::new( + "text", + "openai", + Some("embeddings"), + ))? 
+ .execute() + .await?; + + // there is no equivalent to '.search()' yet + let query = Arc::new(StringArray::from_iter_values(once("something warm"))); + let query_vector = embedding.compute_query_embeddings(query)?; + let mut results = table + .vector_search(query_vector)? + .limit(1) + .execute() + .await?; + + let rb = results.next().await.unwrap()?; + let out = rb + .column_by_name("text") + .unwrap() + .as_any() + .downcast_ref::<StringArray>() + .unwrap(); + let text = out.iter().next().unwrap().unwrap(); + println!("Closest match: {}", text); + + Ok(()) +} + +fn make_data() -> impl IntoArrow { + let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("text", DataType::Utf8, false), + Field::new("price", DataType::Float64, false), + ]); + + let id = Int32Array::from(vec![1, 2, 3, 4]); + let text = StringArray::from_iter_values(vec![ + "Black T-Shirt", + "Leather Jacket", + "Winter Parka", + "Hooded Sweatshirt", + ]); + let price = Float64Array::from(vec![10.0, 50.0, 100.0, 30.0]); + let schema = Arc::new(schema); + let rb = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(id), Arc::new(text), Arc::new(price)], + ) + .unwrap(); + Box::new(RecordBatchIterator::new(vec![Ok(rb)], schema)) +} diff --git a/rust/lancedb/src/embeddings.rs b/rust/lancedb/src/embeddings.rs index 007d0543..614f2ec3 100644 --- a/rust/lancedb/src/embeddings.rs +++ b/rust/lancedb/src/embeddings.rs @@ -11,6 +11,8 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
+#[cfg(feature = "openai")] +pub mod openai; use lance::arrow::RecordBatchExt; use std::{ @@ -51,8 +53,10 @@ pub trait EmbeddingFunction: std::fmt::Debug + Send + Sync { /// The type of the output data /// This should **always** match the output of the `embed` function fn dest_type(&self) -> Result<Cow<DataType>>; - /// Embed the input - fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>>; + /// Compute the embeddings for the source column in the database + fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>>; + /// Compute the embeddings for a given user query + fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>>; } /// Defines an embedding from input data into a lower-dimensional space @@ -266,7 +270,7 @@ impl Iterator for WithEmbeddings { // todo: parallelize this for (fld, func) in self.embeddings.iter() { let src_column = batch.column_by_name(&fld.source_column).unwrap(); - let embedding = match func.embed(src_column.clone()) { + let embedding = match func.compute_source_embeddings(src_column.clone()) { Ok(embedding) => embedding, Err(e) => { return Some(Err(arrow_schema::ArrowError::ComputeError(format!( diff --git a/rust/lancedb/src/embeddings/openai.rs b/rust/lancedb/src/embeddings/openai.rs new file mode 100644 index 00000000..88528919 --- /dev/null +++ b/rust/lancedb/src/embeddings/openai.rs @@ -0,0 +1,257 @@ +use std::{borrow::Cow, fmt::Formatter, str::FromStr, sync::Arc}; + +use arrow::array::{AsArray, Float32Builder}; +use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array}; +use arrow_data::ArrayData; +use arrow_schema::DataType; +use async_openai::{ + config::OpenAIConfig, + types::{CreateEmbeddingRequest, Embedding, EmbeddingInput, EncodingFormat}, + Client, +}; +use tokio::{runtime::Handle, task}; + +use crate::{Error, Result}; + +use super::EmbeddingFunction; + +#[derive(Debug)] +pub enum EmbeddingModel { + TextEmbeddingAda002, + TextEmbedding3Small, + TextEmbedding3Large, +} + +impl EmbeddingModel { + fn ndims(&self) -> usize { + match self { + 
Self::TextEmbeddingAda002 => 1536, + Self::TextEmbedding3Small => 1536, + Self::TextEmbedding3Large => 3072, + } + } +} + +impl FromStr for EmbeddingModel { + type Err = Error; + + fn from_str(s: &str) -> std::result::Result<Self, Self::Err> { + match s { + "text-embedding-ada-002" => Ok(Self::TextEmbeddingAda002), + "text-embedding-3-small" => Ok(Self::TextEmbedding3Small), + "text-embedding-3-large" => Ok(Self::TextEmbedding3Large), + _ => Err(Error::InvalidInput { + message: "Invalid input. Available models are: 'text-embedding-3-small', 'text-embedding-ada-002', 'text-embedding-3-large' ".to_string() + }), + } + } +} + +impl std::fmt::Display for EmbeddingModel { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + match self { + Self::TextEmbeddingAda002 => write!(f, "text-embedding-ada-002"), + Self::TextEmbedding3Small => write!(f, "text-embedding-3-small"), + Self::TextEmbedding3Large => write!(f, "text-embedding-3-large"), + } + } +} + +impl TryFrom<&str> for EmbeddingModel { + type Error = Error; + + fn try_from(value: &str) -> std::result::Result<Self, Self::Error> { + value.parse() + } +} + +pub struct OpenAIEmbeddingFunction { + model: EmbeddingModel, + api_key: String, + api_base: Option<String>, + org_id: Option<String>, +} + +impl std::fmt::Debug for OpenAIEmbeddingFunction { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + // let's be safe and not print the full API key + let creds_display = if self.api_key.len() > 6 { + format!( + "{}***{}", + &self.api_key[0..2], + &self.api_key[self.api_key.len() - 4..] 
+ ) + } else { + "[INVALID]".to_string() + }; + + f.debug_struct("OpenAI") + .field("model", &self.model) + .field("api_key", &creds_display) + .field("api_base", &self.api_base) + .field("org_id", &self.org_id) + .finish() + } +} + +impl OpenAIEmbeddingFunction { + /// Create a new OpenAIEmbeddingFunction + pub fn new<A: Into<String>>(api_key: A) -> Self { + Self::new_impl(api_key.into(), EmbeddingModel::TextEmbeddingAda002) + } + + pub fn new_with_model<A: Into<String>, M: TryInto<EmbeddingModel>>( + api_key: A, + model: M, + ) -> crate::Result<Self> + where + M::Error: Into<Error>, + { + Ok(Self::new_impl( + api_key.into(), + model.try_into().map_err(|e| e.into())?, + )) + } + + /// concrete implementation to reduce monomorphization + fn new_impl(api_key: String, model: EmbeddingModel) -> Self { + Self { + model, + api_key, + api_base: None, + org_id: None, + } + } + + /// To use a API base url different from default "https://api.openai.com/v1" + pub fn api_base<S: Into<String>>(mut self, api_base: S) -> Self { + self.api_base = Some(api_base.into()); + self + } + + /// To use a different OpenAI organization id other than default + pub fn org_id<S: Into<String>>(mut self, org_id: S) -> Self { + self.org_id = Some(org_id.into()); + self + } +} + +impl EmbeddingFunction for OpenAIEmbeddingFunction { + fn name(&self) -> &str { + "openai" + } + + fn source_type(&self) -> Result<Cow<DataType>> { + Ok(Cow::Owned(DataType::Utf8)) + } + + fn dest_type(&self) -> Result<Cow<DataType>> { + let n_dims = self.model.ndims(); + Ok(Cow::Owned(DataType::new_fixed_size_list( + DataType::Float32, + n_dims as i32, + false, + ))) + } + + fn compute_source_embeddings(&self, source: ArrayRef) -> crate::Result<ArrayRef> { + let len = source.len(); + let n_dims = self.model.ndims(); + let inner = self.compute_inner(source)?; + + let fsl = DataType::new_fixed_size_list(DataType::Float32, n_dims as i32, false); + + // We can't use the FixedSizeListBuilder here because it always adds a null bitmap + // and we want to explicitly work with non-nullable arrays. 
+ let array_data = ArrayData::builder(fsl) + .len(len) + .add_child_data(inner.into_data()) + .build()?; + + Ok(Arc::new(FixedSizeListArray::from(array_data))) + } + + fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>> { + let arr = self.compute_inner(input)?; + Ok(Arc::new(arr)) + } +} +impl OpenAIEmbeddingFunction { + fn compute_inner(&self, source: Arc<dyn Array>) -> Result<Float32Array> { + // OpenAI only supports non-nullable string arrays + if source.is_nullable() { + return Err(crate::Error::InvalidInput { + message: "Expected non-nullable data type".to_string(), + }); + } + + // OpenAI only supports string arrays + if !matches!(source.data_type(), DataType::Utf8 | DataType::LargeUtf8) { + return Err(crate::Error::InvalidInput { + message: "Expected Utf8 data type".to_string(), + }); + }; + + let mut creds = OpenAIConfig::new().with_api_key(self.api_key.clone()); + + if let Some(api_base) = &self.api_base { + creds = creds.with_api_base(api_base.clone()); + } + if let Some(org_id) = &self.org_id { + creds = creds.with_org_id(org_id.clone()); + } + + let input = match source.data_type() { + DataType::Utf8 => { + let array = source + .as_string::<i32>() + .into_iter() + .map(|s| { + s.expect("we already asserted that the array is non-nullable") + .to_string() + }) + .collect::<Vec<String>>(); + EmbeddingInput::StringArray(array) + } + DataType::LargeUtf8 => { + let array = source + .as_string::<i64>() + .into_iter() + .map(|s| { + s.expect("we already asserted that the array is non-nullable") + .to_string() + }) + .collect::<Vec<String>>(); + EmbeddingInput::StringArray(array) + } + _ => unreachable!("This should not happen. 
We already checked the data type."), + }; + + let client = Client::with_config(creds); + let embed = client.embeddings(); + let req = CreateEmbeddingRequest { + model: self.model.to_string(), + input, + encoding_format: Some(EncodingFormat::Float), + user: None, + dimensions: None, + }; + + // TODO: request batching and retry logic + task::block_in_place(move || { + Handle::current().block_on(async { + let mut builder = Float32Builder::new(); + + let res = embed.create(req).await.map_err(|e| crate::Error::Runtime { + message: format!("OpenAI embed request failed: {e}"), + })?; + + for Embedding { embedding, .. } in res.data.iter() { + builder.append_slice(embedding); + } + + Ok(builder.finish()) + }) + }) + } +} diff --git a/rust/lancedb/tests/embedding_registry_test.rs b/rust/lancedb/tests/embedding_registry_test.rs index 36ac5dd1..f2e8c9de 100644 --- a/rust/lancedb/tests/embedding_registry_test.rs +++ b/rust/lancedb/tests/embedding_registry_test.rs @@ -302,7 +302,7 @@ impl EmbeddingFunction for MockEmbed { fn dest_type(&self) -> Result<Cow<DataType>> { Ok(Cow::Borrowed(&self.dest_type)) } - fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> { + fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> { // We can't use the FixedSizeListBuilder here because it always adds a null bitmap // and we want to explicitly work with non-nullable arrays. let len = source.len(); @@ -317,4 +317,8 @@ impl EmbeddingFunction for MockEmbed { Ok(Arc::new(arr)) } + + fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>> { + unimplemented!() + } }