feat(rust): openai embedding function (#1275)

part of https://github.com/lancedb/lancedb/issues/994. 

Adds the ability to use the openai embedding functions.


the example can be run by the following

```sh
> export OPENAI_API_KEY="sk-..."
> cargo run --example openai --features=openai
```

which should output
```
Closest match: Winter Parka
```
This commit is contained in:
Cory Grinstead
2024-05-30 15:55:55 -05:00
committed by GitHub
parent 1e85b57c82
commit 01dd6c5e75
5 changed files with 358 additions and 4 deletions

View File

@@ -38,6 +38,7 @@ url.workspace = true
regex.workspace = true
serde = { version = "^1" }
serde_json = { version = "1" }
async-openai = { version = "0.20.0", optional = true }
serde_with = { version = "3.8.1" }
# For remote feature
reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
@@ -62,4 +63,10 @@ default = []
remote = ["dep:reqwest"]
fp16kernels = ["lance-linalg/fp16kernels"]
s3-test = []
openai = ["dep:async-openai", "dep:reqwest"]
polars = ["dep:polars-arrow", "dep:polars"]
[[example]]
name = "openai"
required-features = ["openai"]

View File

@@ -0,0 +1,82 @@
use std::{iter::once, sync::Arc};
use arrow_array::{Float64Array, Int32Array, RecordBatch, RecordBatchIterator, StringArray};
use arrow_schema::{DataType, Field, Schema};
use futures::StreamExt;
use lancedb::{
arrow::IntoArrow,
connect,
embeddings::{openai::OpenAIEmbeddingFunction, EmbeddingDefinition, EmbeddingFunction},
query::{ExecutableQuery, QueryBase},
Result,
};
#[tokio::main]
async fn main() -> Result<()> {
    // Keep the demo database in a throwaway directory.
    let dir = tempfile::tempdir().unwrap();
    let uri = dir.path().to_str().unwrap();

    let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY is not set");
    let embedding = Arc::new(OpenAIEmbeddingFunction::new_with_model(
        api_key,
        "text-embedding-3-large",
    )?);

    let db = connect(uri).execute().await?;
    db.embedding_registry()
        .register("openai", embedding.clone())?;

    // The "text" column is embedded into "embeddings" via the registered
    // "openai" function when rows are written.
    let table = db
        .create_table("vectors", make_data())
        .add_embedding(EmbeddingDefinition::new(
            "text",
            "openai",
            Some("embeddings"),
        ))?
        .execute()
        .await?;

    // there is no equivalent to '.search(<query>)' yet
    let query = Arc::new(StringArray::from_iter_values(once("something warm")));
    let query_vector = embedding.compute_query_embeddings(query)?;
    let mut results = table
        .vector_search(query_vector)?
        .limit(1)
        .execute()
        .await?;

    let batch = results.next().await.unwrap()?;
    let texts = batch
        .column_by_name("text")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    let best = texts.iter().next().unwrap().unwrap();
    println!("Closest match: {}", best);
    Ok(())
}
/// Build a tiny in-memory table (id, text, price) for the example.
fn make_data() -> impl IntoArrow {
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, true),
        Field::new("text", DataType::Utf8, false),
        Field::new("price", DataType::Float64, false),
    ]));
    let ids = Int32Array::from(vec![1, 2, 3, 4]);
    let names = StringArray::from_iter_values(vec![
        "Black T-Shirt",
        "Leather Jacket",
        "Winter Parka",
        "Hooded Sweatshirt",
    ]);
    let prices = Float64Array::from(vec![10.0, 50.0, 100.0, 30.0]);
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(ids), Arc::new(names), Arc::new(prices)],
    )
    .unwrap();
    Box::new(RecordBatchIterator::new(vec![Ok(batch)], schema))
}

View File

@@ -11,6 +11,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#[cfg(feature = "openai")]
pub mod openai;
use lance::arrow::RecordBatchExt;
use std::{
@@ -51,8 +53,10 @@ pub trait EmbeddingFunction: std::fmt::Debug + Send + Sync {
/// The type of the output data
/// This should **always** match the output of the `embed` function
fn dest_type(&self) -> Result<Cow<DataType>>;
/// Embed the input
fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>>;
/// Compute the embeddings for the source column in the database
fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>>;
/// Compute the embeddings for a given user query
fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>>;
}
/// Defines an embedding from input data into a lower-dimensional space
@@ -266,7 +270,7 @@ impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
// todo: parallelize this
for (fld, func) in self.embeddings.iter() {
let src_column = batch.column_by_name(&fld.source_column).unwrap();
let embedding = match func.embed(src_column.clone()) {
let embedding = match func.compute_source_embeddings(src_column.clone()) {
Ok(embedding) => embedding,
Err(e) => {
return Some(Err(arrow_schema::ArrowError::ComputeError(format!(

View File

@@ -0,0 +1,257 @@
use std::{borrow::Cow, fmt::Formatter, str::FromStr, sync::Arc};
use arrow::array::{AsArray, Float32Builder};
use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array};
use arrow_data::ArrayData;
use arrow_schema::DataType;
use async_openai::{
config::OpenAIConfig,
types::{CreateEmbeddingRequest, Embedding, EmbeddingInput, EncodingFormat},
Client,
};
use tokio::{runtime::Handle, task};
use crate::{Error, Result};
use super::EmbeddingFunction;
/// The OpenAI embedding models this integration can call.
#[derive(Debug)]
pub enum EmbeddingModel {
    /// "text-embedding-ada-002" (1536 dimensions)
    TextEmbeddingAda002,
    /// "text-embedding-3-small" (1536 dimensions)
    TextEmbedding3Small,
    /// "text-embedding-3-large" (3072 dimensions)
    TextEmbedding3Large,
}
impl EmbeddingModel {
    /// Dimensionality of the vectors produced by this model.
    fn ndims(&self) -> usize {
        match self {
            Self::TextEmbeddingAda002 | Self::TextEmbedding3Small => 1536,
            Self::TextEmbedding3Large => 3072,
        }
    }
}
impl FromStr for EmbeddingModel {
type Err = Error;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s {
"text-embedding-ada-002" => Ok(Self::TextEmbeddingAda002),
"text-embedding-3-small" => Ok(Self::TextEmbedding3Small),
"text-embedding-3-large" => Ok(Self::TextEmbedding3Large),
_ => Err(Error::InvalidInput {
message: "Invalid input. Available models are: 'text-embedding-3-small', 'text-embedding-ada-002', 'text-embedding-3-large' ".to_string()
}),
}
}
}
impl std::fmt::Display for EmbeddingModel {
    /// Render the model as the exact name the OpenAI API expects.
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        let name = match self {
            Self::TextEmbeddingAda002 => "text-embedding-ada-002",
            Self::TextEmbedding3Small => "text-embedding-3-small",
            Self::TextEmbedding3Large => "text-embedding-3-large",
        };
        write!(f, "{}", name)
    }
}
impl TryFrom<&str> for EmbeddingModel {
type Error = Error;
fn try_from(value: &str) -> std::result::Result<Self, Self::Error> {
value.parse()
}
}
/// An embedding function backed by the OpenAI embeddings API.
pub struct OpenAIEmbeddingFunction {
    // Which OpenAI embedding model to call.
    model: EmbeddingModel,
    // Secret API key; masked by the Debug impl so it is never logged in full.
    api_key: String,
    // Optional override for the default API base url.
    api_base: Option<String>,
    // Optional OpenAI organization id.
    org_id: Option<String>,
}
impl std::fmt::Debug for OpenAIEmbeddingFunction {
    /// Format for diagnostics, masking the API key so it never appears in
    /// logs in full.
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        let key = &self.api_key;
        // let's be safe and not print the full API key. The char-boundary
        // guards also prevent a slicing panic if the key happens to contain
        // multi-byte characters at the cut points.
        let creds_display = if key.len() > 6
            && key.is_char_boundary(2)
            && key.is_char_boundary(key.len() - 4)
        {
            format!("{}***{}", &key[..2], &key[key.len() - 4..])
        } else {
            "[INVALID]".to_string()
        };
        f.debug_struct("OpenAI")
            .field("model", &self.model)
            .field("api_key", &creds_display)
            .field("api_base", &self.api_base)
            .field("org_id", &self.org_id)
            .finish()
    }
}
impl OpenAIEmbeddingFunction {
    /// Create a new OpenAIEmbeddingFunction using the default model,
    /// `text-embedding-ada-002`.
    pub fn new<A: Into<String>>(api_key: A) -> Self {
        Self::new_impl(api_key.into(), EmbeddingModel::TextEmbeddingAda002)
    }

    /// Create a function targeting a specific embedding model; `model` may
    /// be anything convertible into [`EmbeddingModel`] (e.g. a model name).
    pub fn new_with_model<A: Into<String>, M: TryInto<EmbeddingModel>>(
        api_key: A,
        model: M,
    ) -> crate::Result<Self>
    where
        M::Error: Into<crate::Error>,
    {
        let model = model.try_into().map_err(Into::into)?;
        Ok(Self::new_impl(api_key.into(), model))
    }

    /// Non-generic constructor shared by the public entry points to reduce
    /// monomorphization.
    fn new_impl(api_key: String, model: EmbeddingModel) -> Self {
        Self {
            model,
            api_key,
            api_base: None,
            org_id: None,
        }
    }

    /// To use a API base url different from default "https://api.openai.com/v1"
    pub fn api_base<S: Into<String>>(mut self, api_base: S) -> Self {
        self.api_base = Some(api_base.into());
        self
    }

    /// To use a different OpenAI organization id other than default
    pub fn org_id<S: Into<String>>(mut self, org_id: S) -> Self {
        self.org_id = Some(org_id.into());
        self
    }
}
impl EmbeddingFunction for OpenAIEmbeddingFunction {
    fn name(&self) -> &str {
        "openai"
    }

    fn source_type(&self) -> Result<Cow<DataType>> {
        // Inputs are strings to be embedded.
        Ok(Cow::Owned(DataType::Utf8))
    }

    fn dest_type(&self) -> Result<Cow<DataType>> {
        // Outputs are non-nullable fixed-size lists of f32, sized by the model.
        Ok(Cow::Owned(DataType::new_fixed_size_list(
            DataType::Float32,
            self.model.ndims() as i32,
            false,
        )))
    }

    fn compute_source_embeddings(&self, source: ArrayRef) -> crate::Result<ArrayRef> {
        let row_count = source.len();
        let dims = self.model.ndims() as i32;
        let values = self.compute_inner(source)?;
        // We can't use the FixedSizeListBuilder here because it always adds a null bitmap
        // and we want to explicitly work with non-nullable arrays.
        let list_type = DataType::new_fixed_size_list(DataType::Float32, dims, false);
        let data = ArrayData::builder(list_type)
            .len(row_count)
            .add_child_data(values.into_data())
            .build()?;
        Ok(Arc::new(FixedSizeListArray::from(data)))
    }

    fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
        // A query embedding stays a flat Float32Array (a single vector),
        // unlike the FixedSizeList produced for source columns.
        let values = self.compute_inner(input)?;
        Ok(Arc::new(values))
    }
}
impl OpenAIEmbeddingFunction {
    /// Call the OpenAI embeddings API for every string in `source` and
    /// return the results as one flat `Float32Array` holding
    /// `source.len() * ndims` values, in row order.
    ///
    /// # Errors
    /// Returns `Error::InvalidInput` if `source` is nullable or is not a
    /// (Large)Utf8 array, and `Error::Runtime` if the API request fails.
    fn compute_inner(&self, source: Arc<dyn Array>) -> Result<Float32Array> {
        // OpenAI only supports non-nullable string arrays
        if source.is_nullable() {
            return Err(crate::Error::InvalidInput {
                message: "Expected non-nullable data type".to_string(),
            });
        }
        // OpenAI only supports string arrays
        if !matches!(source.data_type(), DataType::Utf8 | DataType::LargeUtf8) {
            return Err(crate::Error::InvalidInput {
                message: "Expected Utf8 data type".to_string(),
            });
        };
        let mut creds = OpenAIConfig::new().with_api_key(self.api_key.clone());
        if let Some(api_base) = &self.api_base {
            creds = creds.with_api_base(api_base.clone());
        }
        if let Some(org_id) = &self.org_id {
            creds = creds.with_org_id(org_id.clone());
        }

        // Shared conversion for both offset widths so the Utf8 and LargeUtf8
        // arms don't duplicate the same body.
        fn collect_strings<'a>(iter: impl Iterator<Item = Option<&'a str>>) -> Vec<String> {
            iter.map(|s| {
                s.expect("we already asserted that the array is non-nullable")
                    .to_string()
            })
            .collect()
        }

        let input = match source.data_type() {
            DataType::Utf8 => {
                EmbeddingInput::StringArray(collect_strings(source.as_string::<i32>().into_iter()))
            }
            DataType::LargeUtf8 => {
                EmbeddingInput::StringArray(collect_strings(source.as_string::<i64>().into_iter()))
            }
            _ => unreachable!("This should not happen. We already checked the data type."),
        };
        let client = Client::with_config(creds);
        let embed = client.embeddings();
        let req = CreateEmbeddingRequest {
            model: self.model.to_string(),
            input,
            encoding_format: Some(EncodingFormat::Float),
            user: None,
            dimensions: None,
        };
        // Pre-size the output: one embedding of `ndims` floats per input row.
        let capacity = source.len() * self.model.ndims();
        // TODO: request batching and retry logic
        // NOTE(review): `block_in_place` panics when called on a
        // current-thread tokio runtime; callers must use a multi-threaded
        // runtime — confirm this is documented for users.
        task::block_in_place(move || {
            Handle::current().block_on(async {
                let mut builder = Float32Builder::with_capacity(capacity);
                let res = embed.create(req).await.map_err(|e| crate::Error::Runtime {
                    message: format!("OpenAI embed request failed: {e}"),
                })?;
                for Embedding { embedding, .. } in res.data.iter() {
                    builder.append_slice(embedding);
                }
                Ok(builder.finish())
            })
        })
    }
}

View File

@@ -302,7 +302,7 @@ impl EmbeddingFunction for MockEmbed {
fn dest_type(&self) -> Result<Cow<DataType>> {
Ok(Cow::Borrowed(&self.dest_type))
}
fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
// We can't use the FixedSizeListBuilder here because it always adds a null bitmap
// and we want to explicitly work with non-nullable arrays.
let len = source.len();
@@ -317,4 +317,8 @@ impl EmbeddingFunction for MockEmbed {
Ok(Arc::new(arr))
}
fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
unimplemented!()
}
}