feat(rust): openai embedding function (#1275)
Part of https://github.com/lancedb/lancedb/issues/994. Adds the ability to use the OpenAI embedding function from Rust. The example can be run with:

```sh
export OPENAI_API_KEY="sk-..."
cargo run --example openai --features=openai
```

which should output:

```
Closest match: Winter Parka
```
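In short, the commit lets you register the OpenAI embedding function in the database's embedding registry and attach it to a source column via `EmbeddingDefinition`; LanceDB then fills the embedding column when rows are written. A condensed sketch of the API added in this diff (the database path and sample data are illustrative; the full program is `rust/lancedb/examples/openai.rs` below):

```rust
use std::sync::Arc;

use arrow_array::{RecordBatch, RecordBatchIterator, StringArray};
use arrow_schema::{DataType, Field, Schema};
use lancedb::connect;
use lancedb::embeddings::{openai::OpenAIEmbeddingFunction, EmbeddingDefinition};

#[tokio::main]
async fn main() -> lancedb::Result<()> {
    let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY is not set");
    let embedding = Arc::new(OpenAIEmbeddingFunction::new_with_model(
        api_key,
        "text-embedding-3-large",
    )?);

    // Illustrative path; any LanceDB URI works here.
    let db = connect("data/openai-demo").execute().await?;
    db.embedding_registry().register("openai", embedding)?;

    // A single "text" column; the "embeddings" column is computed by the
    // registered OpenAI function when the rows are written.
    let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(StringArray::from_iter_values(["Winter Parka"]))],
    )
    .unwrap();

    db.create_table("vectors", Box::new(RecordBatchIterator::new(vec![Ok(batch)], schema)))
        .add_embedding(EmbeddingDefinition::new("text", "openai", Some("embeddings")))?
        .execute()
        .await?;
    Ok(())
}
```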

@@ -38,6 +38,7 @@ url.workspace = true
 regex.workspace = true
 serde = { version = "^1" }
 serde_json = { version = "1" }
+async-openai = { version = "0.20.0", optional = true }
 serde_with = { version = "3.8.1" }
 # For remote feature
 reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
@@ -62,4 +63,10 @@ default = []
 remote = ["dep:reqwest"]
 fp16kernels = ["lance-linalg/fp16kernels"]
 s3-test = []
+openai = ["dep:async-openai", "dep:reqwest"]
 polars = ["dep:polars-arrow", "dep:polars"]
+
+
+[[example]]
+name = "openai"
+required-features = ["openai"]

rust/lancedb/examples/openai.rs (new file, 82 lines)
@@ -0,0 +1,82 @@
+use std::{iter::once, sync::Arc};
+
+use arrow_array::{Float64Array, Int32Array, RecordBatch, RecordBatchIterator, StringArray};
+use arrow_schema::{DataType, Field, Schema};
+use futures::StreamExt;
+use lancedb::{
+    arrow::IntoArrow,
+    connect,
+    embeddings::{openai::OpenAIEmbeddingFunction, EmbeddingDefinition, EmbeddingFunction},
+    query::{ExecutableQuery, QueryBase},
+    Result,
+};
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let tempdir = tempfile::tempdir().unwrap();
+    let tempdir = tempdir.path().to_str().unwrap();
+    let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY is not set");
+    let embedding = Arc::new(OpenAIEmbeddingFunction::new_with_model(
+        api_key,
+        "text-embedding-3-large",
+    )?);
+
+    let db = connect(tempdir).execute().await?;
+    db.embedding_registry()
+        .register("openai", embedding.clone())?;
+
+    let table = db
+        .create_table("vectors", make_data())
+        .add_embedding(EmbeddingDefinition::new(
+            "text",
+            "openai",
+            Some("embeddings"),
+        ))?
+        .execute()
+        .await?;
+
+    // there is no equivalent to '.search(<query>)' yet
+    let query = Arc::new(StringArray::from_iter_values(once("something warm")));
+    let query_vector = embedding.compute_query_embeddings(query)?;
+    let mut results = table
+        .vector_search(query_vector)?
+        .limit(1)
+        .execute()
+        .await?;
+
+    let rb = results.next().await.unwrap()?;
+    let out = rb
+        .column_by_name("text")
+        .unwrap()
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .unwrap();
+    let text = out.iter().next().unwrap().unwrap();
+    println!("Closest match: {}", text);
+
+    Ok(())
+}
+
+fn make_data() -> impl IntoArrow {
+    let schema = Schema::new(vec![
+        Field::new("id", DataType::Int32, true),
+        Field::new("text", DataType::Utf8, false),
+        Field::new("price", DataType::Float64, false),
+    ]);
+
+    let id = Int32Array::from(vec![1, 2, 3, 4]);
+    let text = StringArray::from_iter_values(vec![
+        "Black T-Shirt",
+        "Leather Jacket",
+        "Winter Parka",
+        "Hooded Sweatshirt",
+    ]);
+    let price = Float64Array::from(vec![10.0, 50.0, 100.0, 30.0]);
+    let schema = Arc::new(schema);
+    let rb = RecordBatch::try_new(
+        schema.clone(),
+        vec![Arc::new(id), Arc::new(text), Arc::new(price)],
+    )
+    .unwrap();
+    Box::new(RecordBatchIterator::new(vec![Ok(rb)], schema))
+}
@@ -11,6 +11,8 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#[cfg(feature = "openai")]
+pub mod openai;
 
 use lance::arrow::RecordBatchExt;
 use std::{
@@ -51,8 +53,10 @@ pub trait EmbeddingFunction: std::fmt::Debug + Send + Sync {
     /// The type of the output data
     /// This should **always** match the output of the `embed` function
     fn dest_type(&self) -> Result<Cow<DataType>>;
-    /// Embed the input
-    fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>>;
+    /// Compute the embeddings for the source column in the database
+    fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>>;
+    /// Compute the embeddings for a given user query
+    fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>>;
 }
 
 /// Defines an embedding from input data into a lower-dimensional space
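For downstream implementors, the practical effect of this hunk is that a custom embedding function now supplies `compute_source_embeddings` and `compute_query_embeddings` rather than the old single `embed` method. A minimal sketch against the trait as shown above (the `ConstEmbed` type and its zero-vector output are purely illustrative and not part of this commit):

```rust
use std::{borrow::Cow, sync::Arc};

use arrow_array::{Array, FixedSizeListArray, Float32Array};
use arrow_schema::{DataType, Field};
use lancedb::embeddings::EmbeddingFunction;
use lancedb::Result;

/// Hypothetical embedder that maps every string to a zero vector.
#[derive(Debug)]
struct ConstEmbed {
    dims: i32,
}

impl EmbeddingFunction for ConstEmbed {
    fn name(&self) -> &str {
        "const"
    }

    fn source_type(&self) -> Result<Cow<DataType>> {
        Ok(Cow::Owned(DataType::Utf8))
    }

    fn dest_type(&self) -> Result<Cow<DataType>> {
        Ok(Cow::Owned(DataType::new_fixed_size_list(
            DataType::Float32,
            self.dims,
            false,
        )))
    }

    fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
        // One `dims`-wide zero vector per input row.
        let values = Float32Array::from(vec![0.0_f32; source.len() * self.dims as usize]);
        let arr = FixedSizeListArray::try_new(
            Arc::new(Field::new("item", DataType::Float32, false)),
            self.dims,
            Arc::new(values),
            None,
        )?;
        Ok(Arc::new(arr))
    }

    fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
        self.compute_source_embeddings(input)
    }
}
```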
@@ -266,7 +270,7 @@ impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
         // todo: parallelize this
         for (fld, func) in self.embeddings.iter() {
             let src_column = batch.column_by_name(&fld.source_column).unwrap();
-            let embedding = match func.embed(src_column.clone()) {
+            let embedding = match func.compute_source_embeddings(src_column.clone()) {
                 Ok(embedding) => embedding,
                 Err(e) => {
                     return Some(Err(arrow_schema::ArrowError::ComputeError(format!(

rust/lancedb/src/embeddings/openai.rs (new file, 257 lines)
@@ -0,0 +1,257 @@
+use std::{borrow::Cow, fmt::Formatter, str::FromStr, sync::Arc};
+
+use arrow::array::{AsArray, Float32Builder};
+use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array};
+use arrow_data::ArrayData;
+use arrow_schema::DataType;
+use async_openai::{
+    config::OpenAIConfig,
+    types::{CreateEmbeddingRequest, Embedding, EmbeddingInput, EncodingFormat},
+    Client,
+};
+use tokio::{runtime::Handle, task};
+
+use crate::{Error, Result};
+
+use super::EmbeddingFunction;
+
+#[derive(Debug)]
+pub enum EmbeddingModel {
+    TextEmbeddingAda002,
+    TextEmbedding3Small,
+    TextEmbedding3Large,
+}
+
+impl EmbeddingModel {
+    fn ndims(&self) -> usize {
+        match self {
+            Self::TextEmbeddingAda002 => 1536,
+            Self::TextEmbedding3Small => 1536,
+            Self::TextEmbedding3Large => 3072,
+        }
+    }
+}
+
+impl FromStr for EmbeddingModel {
+    type Err = Error;
+
+    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
+        match s {
+            "text-embedding-ada-002" => Ok(Self::TextEmbeddingAda002),
+            "text-embedding-3-small" => Ok(Self::TextEmbedding3Small),
+            "text-embedding-3-large" => Ok(Self::TextEmbedding3Large),
+            _ => Err(Error::InvalidInput {
+                message: "Invalid input. Available models are: 'text-embedding-3-small', 'text-embedding-ada-002', 'text-embedding-3-large' ".to_string()
+            }),
+        }
+    }
+}
+
+impl std::fmt::Display for EmbeddingModel {
+    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+        match self {
+            Self::TextEmbeddingAda002 => write!(f, "text-embedding-ada-002"),
+            Self::TextEmbedding3Small => write!(f, "text-embedding-3-small"),
+            Self::TextEmbedding3Large => write!(f, "text-embedding-3-large"),
+        }
+    }
+}
+
+impl TryFrom<&str> for EmbeddingModel {
+    type Error = Error;
+
+    fn try_from(value: &str) -> std::result::Result<Self, Self::Error> {
+        value.parse()
+    }
+}
+
+pub struct OpenAIEmbeddingFunction {
+    model: EmbeddingModel,
+    api_key: String,
+    api_base: Option<String>,
+    org_id: Option<String>,
+}
+
+impl std::fmt::Debug for OpenAIEmbeddingFunction {
+    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+        // let's be safe and not print the full API key
+        let creds_display = if self.api_key.len() > 6 {
+            format!(
+                "{}***{}",
+                &self.api_key[0..2],
+                &self.api_key[self.api_key.len() - 4..]
+            )
+        } else {
+            "[INVALID]".to_string()
+        };
+
+        f.debug_struct("OpenAI")
+            .field("model", &self.model)
+            .field("api_key", &creds_display)
+            .field("api_base", &self.api_base)
+            .field("org_id", &self.org_id)
+            .finish()
+    }
+}
+
+impl OpenAIEmbeddingFunction {
+    /// Create a new OpenAIEmbeddingFunction
+    pub fn new<A: Into<String>>(api_key: A) -> Self {
+        Self::new_impl(api_key.into(), EmbeddingModel::TextEmbeddingAda002)
+    }
+
+    pub fn new_with_model<A: Into<String>, M: TryInto<EmbeddingModel>>(
+        api_key: A,
+        model: M,
+    ) -> crate::Result<Self>
+    where
+        M::Error: Into<crate::Error>,
+    {
+        Ok(Self::new_impl(
+            api_key.into(),
+            model.try_into().map_err(|e| e.into())?,
+        ))
+    }
+
+    /// concrete implementation to reduce monomorphization
+    fn new_impl(api_key: String, model: EmbeddingModel) -> Self {
+        Self {
+            model,
+            api_key,
+            api_base: None,
+            org_id: None,
+        }
+    }
+
+    /// To use a API base url different from default "https://api.openai.com/v1"
+    pub fn api_base<S: Into<String>>(mut self, api_base: S) -> Self {
+        self.api_base = Some(api_base.into());
+        self
+    }
+
+    /// To use a different OpenAI organization id other than default
+    pub fn org_id<S: Into<String>>(mut self, org_id: S) -> Self {
+        self.org_id = Some(org_id.into());
+        self
+    }
+}
+
+impl EmbeddingFunction for OpenAIEmbeddingFunction {
+    fn name(&self) -> &str {
+        "openai"
+    }
+
+    fn source_type(&self) -> Result<Cow<DataType>> {
+        Ok(Cow::Owned(DataType::Utf8))
+    }
+
+    fn dest_type(&self) -> Result<Cow<DataType>> {
+        let n_dims = self.model.ndims();
+        Ok(Cow::Owned(DataType::new_fixed_size_list(
+            DataType::Float32,
+            n_dims as i32,
+            false,
+        )))
+    }
+
+    fn compute_source_embeddings(&self, source: ArrayRef) -> crate::Result<ArrayRef> {
+        let len = source.len();
+        let n_dims = self.model.ndims();
+        let inner = self.compute_inner(source)?;
+
+        let fsl = DataType::new_fixed_size_list(DataType::Float32, n_dims as i32, false);
+
+        // We can't use the FixedSizeListBuilder here because it always adds a null bitmap
+        // and we want to explicitly work with non-nullable arrays.
+        let array_data = ArrayData::builder(fsl)
+            .len(len)
+            .add_child_data(inner.into_data())
+            .build()?;
+
+        Ok(Arc::new(FixedSizeListArray::from(array_data)))
+    }
+
+    fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
+        let arr = self.compute_inner(input)?;
+        Ok(Arc::new(arr))
+    }
+}
+impl OpenAIEmbeddingFunction {
+    fn compute_inner(&self, source: Arc<dyn Array>) -> Result<Float32Array> {
+        // OpenAI only supports non-nullable string arrays
+        if source.is_nullable() {
+            return Err(crate::Error::InvalidInput {
+                message: "Expected non-nullable data type".to_string(),
+            });
+        }
+
+        // OpenAI only supports string arrays
+        if !matches!(source.data_type(), DataType::Utf8 | DataType::LargeUtf8) {
+            return Err(crate::Error::InvalidInput {
+                message: "Expected Utf8 data type".to_string(),
+            });
+        };
+
+        let mut creds = OpenAIConfig::new().with_api_key(self.api_key.clone());
+
+        if let Some(api_base) = &self.api_base {
+            creds = creds.with_api_base(api_base.clone());
+        }
+        if let Some(org_id) = &self.org_id {
+            creds = creds.with_org_id(org_id.clone());
+        }
+
+        let input = match source.data_type() {
+            DataType::Utf8 => {
+                let array = source
+                    .as_string::<i32>()
+                    .into_iter()
+                    .map(|s| {
+                        s.expect("we already asserted that the array is non-nullable")
+                            .to_string()
+                    })
+                    .collect::<Vec<String>>();
+                EmbeddingInput::StringArray(array)
+            }
+            DataType::LargeUtf8 => {
+                let array = source
+                    .as_string::<i64>()
+                    .into_iter()
+                    .map(|s| {
+                        s.expect("we already asserted that the array is non-nullable")
+                            .to_string()
+                    })
+                    .collect::<Vec<String>>();
+                EmbeddingInput::StringArray(array)
+            }
+            _ => unreachable!("This should not happen. We already checked the data type."),
+        };
+
+        let client = Client::with_config(creds);
+        let embed = client.embeddings();
+        let req = CreateEmbeddingRequest {
+            model: self.model.to_string(),
+            input,
+            encoding_format: Some(EncodingFormat::Float),
+            user: None,
+            dimensions: None,
+        };
+
+        // TODO: request batching and retry logic
+        task::block_in_place(move || {
+            Handle::current().block_on(async {
+                let mut builder = Float32Builder::new();
+
+                let res = embed.create(req).await.map_err(|e| crate::Error::Runtime {
+                    message: format!("OpenAI embed request failed: {e}"),
+                })?;
+
+                for Embedding { embedding, .. } in res.data.iter() {
+                    builder.append_slice(embedding);
+                }
+
+                Ok(builder.finish())
+            })
+        })
+    }
+}
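As a usage note, the `api_base` and `org_id` builders above allow pointing the function at any OpenAI-compatible endpoint. A hedged sketch (the key, endpoint, and organization id are placeholders, and the call must run inside a multi-threaded Tokio runtime because `compute_inner` uses `block_in_place`):

```rust
use std::sync::Arc;

use arrow_array::{Array, StringArray};
use lancedb::embeddings::{openai::OpenAIEmbeddingFunction, EmbeddingFunction};

#[tokio::main]
async fn main() -> lancedb::Result<()> {
    // Placeholder credentials and endpoint, for illustration only.
    let func = OpenAIEmbeddingFunction::new_with_model(
        std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY is not set"),
        "text-embedding-3-small",
    )?
    .api_base("https://example-proxy.internal/v1")
    .org_id("org-example");

    // Query embeddings come back as a flat Float32 array (ndims values per row).
    let query = Arc::new(StringArray::from_iter_values(["something warm"]));
    let vector = func.compute_query_embeddings(query)?;
    println!("embedding length: {}", vector.len());
    Ok(())
}
```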
@@ -302,7 +302,7 @@ impl EmbeddingFunction for MockEmbed {
     fn dest_type(&self) -> Result<Cow<DataType>> {
        Ok(Cow::Borrowed(&self.dest_type))
     }
-    fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
+    fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
         // We can't use the FixedSizeListBuilder here because it always adds a null bitmap
         // and we want to explicitly work with non-nullable arrays.
         let len = source.len();
@@ -317,4 +317,8 @@ impl EmbeddingFunction for MockEmbed {
 
         Ok(Arc::new(arr))
     }
+
+    fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
+        unimplemented!()
+    }
 }