Mirror of https://github.com/lancedb/lancedb.git (synced 2026-01-08 12:52:58 +00:00)
feat: rust embedding registry (#1259)
Todo:

- [x] add proper documentation
- [x] add unit tests
- [x] better handling of the registry¹
- [x] allow a user-defined registry²

¹ The Python implementation just uses a global registry, which makes things a bit easier there. Here the registry is attached to the db/connection to prevent future conflicts when running multiple connections/databases. The registry and its pattern are largely modeled on DataFusion's [FunctionRegistry](https://docs.rs/datafusion/latest/datafusion/execution/trait.FunctionRegistry.html).

² Ideally, the user should be able to provide their own registry entirely; by default an in-memory registry is used (a custom registry can be supplied through the connection builder).

`rust/lancedb/examples/embedding_registry.rs` provides a thorough example of the expected usage.

---

Some additional notes: this does not provide any of the out-of-the-box functionality that the Python registry does, _i.e. there are no built-in embedding functions._ Think of this as the groundwork for adding those built-in functions. So while this is part of https://github.com/lancedb/lancedb/issues/994, it does not yet offer feature parity.
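For a quick sense of the intended API, here is a minimal sketch condensed from the new `rust/lancedb/tests/embedding_registry_test.rs`. The names used here (`my_embedder`, the `text` and `embeddings` columns, the database path) and the `embedder`/`data` parameters are placeholders for illustration, not part of this PR:

```rust
use std::sync::Arc;

use lancedb::{
    arrow::IntoArrow,
    connect,
    embeddings::{EmbeddingDefinition, EmbeddingFunction},
    Result,
};

// `embedder` is any user-defined type implementing `EmbeddingFunction`;
// `data` is anything convertible into Arrow record batches.
async fn create_embedded_table(
    embedder: Arc<dyn EmbeddingFunction>,
    data: impl IntoArrow,
) -> Result<()> {
    // The registry is attached to the connection rather than being global.
    let db = connect("/tmp/lancedb").execute().await?;
    db.embedding_registry().register("my_embedder", embedder)?;

    // `add_embedding` resolves the function in the connection's registry up front
    // and returns `Error::EmbeddingFunctionNotFound` if it is not registered.
    db.create_table("my_table", data)
        .add_embedding(EmbeddingDefinition::new(
            "text",             // source column
            "my_embedder",      // registered function name
            Some("embeddings"), // destination column; defaults to `<source>_embedding`
        ))?
        .execute()
        .await?;
    Ok(())
}
```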
.gitignore: 2 changes (vendored)
@@ -6,7 +6,7 @@
venv

.vscode

.zed
rust/target
rust/Cargo.lock
@@ -116,12 +116,19 @@ pub trait IntoArrow {
    fn into_arrow(self) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>>;
}

pub type BoxedRecordBatchReader = Box<dyn arrow_array::RecordBatchReader + Send>;

impl<T: arrow_array::RecordBatchReader + Send + 'static> IntoArrow for T {
    fn into_arrow(self) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>> {
        Ok(Box::new(self))
    }
}

impl<S: Stream<Item = Result<arrow_array::RecordBatch>>> SimpleRecordBatchStream<S> {
    pub fn new(stream: S, schema: Arc<arrow_schema::Schema>) -> Self {
        Self { schema, stream }
    }
}
#[cfg(feature = "polars")]
/// An iterator of record batches formed from a Polars DataFrame.
pub struct PolarsDataFrameRecordBatchReader {
@@ -27,9 +27,12 @@ use object_store::{aws::AwsCredential, local::LocalFileSystem};
use snafu::prelude::*;

use crate::arrow::IntoArrow;
use crate::embeddings::{
    EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, MemoryRegistry, WithEmbeddings,
};
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
use crate::io::object_store::MirroringObjectStoreWrapper;
use crate::table::{NativeTable, WriteOptions};
use crate::table::{NativeTable, TableDefinition, WriteOptions};
use crate::utils::validate_table_name;
use crate::Table;
@@ -133,9 +136,10 @@ pub struct CreateTableBuilder<const HAS_DATA: bool, T: IntoArrow> {
    parent: Arc<dyn ConnectionInternal>,
    pub(crate) name: String,
    pub(crate) data: Option<T>,
    pub(crate) schema: Option<SchemaRef>,
    pub(crate) mode: CreateTableMode,
    pub(crate) write_options: WriteOptions,
    pub(crate) table_definition: Option<TableDefinition>,
    pub(crate) embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
}

// Builder methods that only apply when we have initial data
@@ -145,9 +149,10 @@ impl<T: IntoArrow> CreateTableBuilder<true, T> {
            parent,
            name,
            data: Some(data),
            schema: None,
            mode: CreateTableMode::default(),
            write_options: WriteOptions::default(),
            table_definition: None,
            embeddings: Vec::new(),
        }
    }
@@ -175,24 +180,43 @@ impl<T: IntoArrow> CreateTableBuilder<true, T> {
            parent: self.parent,
            name: self.name,
            data: None,
            schema: self.schema,
            table_definition: self.table_definition,
            mode: self.mode,
            write_options: self.write_options,
            embeddings: self.embeddings,
        };
        Ok((data, builder))
    }

    pub fn add_embedding(mut self, definition: EmbeddingDefinition) -> Result<Self> {
        // Early verification of the embedding name
        let embedding_func = self
            .parent
            .embedding_registry()
            .get(&definition.embedding_name)
            .ok_or_else(|| Error::EmbeddingFunctionNotFound {
                name: definition.embedding_name.to_string(),
                reason: "No embedding function found in the connection's embedding_registry"
                    .to_string(),
            })?;

        self.embeddings.push((definition, embedding_func));
        Ok(self)
    }
}

// Builder methods that only apply when we do not have initial data
impl CreateTableBuilder<false, NoData> {
    fn new(parent: Arc<dyn ConnectionInternal>, name: String, schema: SchemaRef) -> Self {
        let table_definition = TableDefinition::new_from_schema(schema);
        Self {
            parent,
            name,
            data: None,
            schema: Some(schema),
            table_definition: Some(table_definition),
            mode: CreateTableMode::default(),
            write_options: WriteOptions::default(),
            embeddings: Vec::new(),
        }
    }
@@ -350,6 +374,7 @@ impl OpenTableBuilder {
pub(crate) trait ConnectionInternal:
    Send + Sync + std::fmt::Debug + std::fmt::Display + 'static
{
    fn embedding_registry(&self) -> &dyn EmbeddingRegistry;
    async fn table_names(&self, options: TableNamesBuilder) -> Result<Vec<String>>;
    async fn do_create_table(
        &self,
@@ -366,7 +391,7 @@ pub(crate) trait ConnectionInternal:
    ) -> Result<Table> {
        let batches = Box::new(RecordBatchIterator::new(
            vec![],
            options.schema.as_ref().unwrap().clone(),
            options.table_definition.clone().unwrap().schema.clone(),
        ));
        self.do_create_table(options, batches).await
    }
@@ -453,6 +478,13 @@ impl Connection {
    pub async fn drop_db(&self) -> Result<()> {
        self.internal.drop_db().await
    }

    /// Get the in-memory embedding registry.
    /// It's important to note that the embedding registry is not persisted across connections.
    /// So if a table contains embeddings, you will need to make sure that you are using a connection that has the same embedding functions registered
    pub fn embedding_registry(&self) -> &dyn EmbeddingRegistry {
        self.internal.embedding_registry()
    }
}

#[derive(Debug)]
@@ -486,6 +518,7 @@ pub struct ConnectBuilder {
    /// consistency only applies to read operations. Write operations are
    /// always consistent.
    read_consistency_interval: Option<std::time::Duration>,
    embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
}

impl ConnectBuilder {
@@ -498,6 +531,7 @@ impl ConnectBuilder {
            host_override: None,
            read_consistency_interval: None,
            storage_options: HashMap::new(),
            embedding_registry: None,
        }
    }
@@ -516,6 +550,12 @@ impl ConnectBuilder {
        self
    }

    /// Provide a custom [`EmbeddingRegistry`] to use for this connection.
    pub fn embedding_registry(mut self, registry: Arc<dyn EmbeddingRegistry>) -> Self {
        self.embedding_registry = Some(registry);
        self
    }

    /// [`AwsCredential`] to use when connecting to S3.
    #[deprecated(note = "Pass through storage_options instead")]
    pub fn aws_creds(mut self, aws_creds: AwsCredential) -> Self {
@@ -642,6 +682,7 @@ struct Database {

    // Storage options to be inherited by tables created from this connection
    storage_options: HashMap<String, String>,
    embedding_registry: Arc<dyn EmbeddingRegistry>,
}

impl std::fmt::Display for Database {
@@ -675,7 +716,12 @@ impl Database {
        // TODO: pass params regardless of OS
        match parse_res {
            Ok(url) if url.scheme().len() == 1 && cfg!(windows) => {
                Self::open_path(uri, options.read_consistency_interval).await
                Self::open_path(
                    uri,
                    options.read_consistency_interval,
                    options.embedding_registry.clone(),
                )
                .await
            }
            Ok(mut url) => {
                // iter thru the query params and extract the commit store param
@@ -745,6 +791,10 @@ impl Database {
                    None => None,
                };

                let embedding_registry = options
                    .embedding_registry
                    .clone()
                    .unwrap_or_else(|| Arc::new(MemoryRegistry::new()));
                Ok(Self {
                    uri: table_base_uri,
                    query_string,
@@ -753,20 +803,33 @@ impl Database {
                    store_wrapper: write_store_wrapper,
                    read_consistency_interval: options.read_consistency_interval,
                    storage_options,
                    embedding_registry,
                })
            }
            Err(_) => Self::open_path(uri, options.read_consistency_interval).await,
            Err(_) => {
                Self::open_path(
                    uri,
                    options.read_consistency_interval,
                    options.embedding_registry.clone(),
                )
                .await
            }
        }
    }

    async fn open_path(
        path: &str,
        read_consistency_interval: Option<std::time::Duration>,
        embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
    ) -> Result<Self> {
        let (object_store, base_path) = ObjectStore::from_uri(path).await?;
        if object_store.is_local() {
            Self::try_create_dir(path).context(CreateDirSnafu { path })?;
        }

        let embedding_registry =
            embedding_registry.unwrap_or_else(|| Arc::new(MemoryRegistry::new()));

        Ok(Self {
            uri: path.to_string(),
            query_string: None,
@@ -775,6 +838,7 @@ impl Database {
            store_wrapper: None,
            read_consistency_interval,
            storage_options: HashMap::new(),
            embedding_registry,
        })
    }
@@ -815,6 +879,9 @@ impl Database {

#[async_trait::async_trait]
impl ConnectionInternal for Database {
    fn embedding_registry(&self) -> &dyn EmbeddingRegistry {
        self.embedding_registry.as_ref()
    }
    async fn table_names(&self, options: TableNamesBuilder) -> Result<Vec<String>> {
        let mut f = self
            .object_store
@@ -851,7 +918,7 @@ impl ConnectionInternal for Database {
        data: Box<dyn RecordBatchReader + Send>,
    ) -> Result<Table> {
        let table_uri = self.table_uri(&options.name)?;

        let embedding_registry = self.embedding_registry.clone();
        // Inherit storage options from the connection
        let storage_options = options
            .write_options
@@ -866,6 +933,11 @@ impl ConnectionInternal for Database {
                storage_options.insert(key.clone(), value.clone());
            }
        }
        let data = if options.embeddings.is_empty() {
            data
        } else {
            Box::new(WithEmbeddings::new(data, options.embeddings))
        };

        let mut write_params = options.write_options.lance_write_params.unwrap_or_default();
        if matches!(&options.mode, CreateTableMode::Overwrite) {
@@ -882,7 +954,10 @@ impl ConnectionInternal for Database {
        )
        .await
        {
            Ok(table) => Ok(Table::new(Arc::new(table))),
            Ok(table) => Ok(Table::new_with_embedding_registry(
                Arc::new(table),
                embedding_registry,
            )),
            Err(Error::TableAlreadyExists { name }) => match options.mode {
                CreateTableMode::Create => Err(Error::TableAlreadyExists { name }),
                CreateTableMode::ExistOk(callback) => {
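As a side note, the `ConnectBuilder::embedding_registry` method added above is what enables a user-supplied registry. A minimal sketch of wiring one in, mirroring `test_custom_registry` in the new test file further down (the path and the `my_registry` parameter are placeholders), might look like:

```rust
use std::sync::Arc;

use lancedb::{connect, embeddings::EmbeddingRegistry, Result};

// `my_registry` is any user type implementing `EmbeddingRegistry`
// (the new test file below defines one such registry, `MyRegistry`).
async fn connect_with_registry(my_registry: Arc<dyn EmbeddingRegistry>) -> Result<()> {
    // If no registry is supplied, the connection falls back to an in-memory
    // `MemoryRegistry` created on the fly.
    let _db = connect("/tmp/lancedb")
        .embedding_registry(my_registry)
        .execute()
        .await?;
    Ok(())
}
```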
rust/lancedb/src/embeddings.rs: 307 lines (new file)
@@ -0,0 +1,307 @@
// Copyright 2024 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use lance::arrow::RecordBatchExt;
use std::{
    borrow::Cow,
    collections::{HashMap, HashSet},
    sync::{Arc, RwLock},
};

use arrow_array::{Array, RecordBatch, RecordBatchReader};
use arrow_schema::{DataType, Field, SchemaBuilder};
// use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::{
    error::Result,
    table::{ColumnDefinition, ColumnKind, TableDefinition},
    Error,
};

/// Trait for embedding functions
///
/// An embedding function is a function that is applied to a column of input data
/// to produce an "embedding" of that input. This embedding is then stored in the
/// database alongside (or instead of) the original input.
///
/// An "embedding" is often a lower-dimensional representation of the input data.
/// For example, sentence-transformers can be used to embed sentences into a 768-dimensional
/// vector space. This is useful for tasks like similarity search, where we want to find
/// similar sentences to a query sentence.
///
/// To use an embedding function you must first register it with the `EmbeddingsRegistry`.
/// Then you can define it on a column in the table schema. That embedding will then be used
/// to embed the data in that column.
pub trait EmbeddingFunction: std::fmt::Debug + Send + Sync {
    fn name(&self) -> &str;
    /// The type of the input data
    fn source_type(&self) -> Result<Cow<DataType>>;
    /// The type of the output data
    /// This should **always** match the output of the `embed` function
    fn dest_type(&self) -> Result<Cow<DataType>>;
    /// Embed the input
    fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>>;
}

/// Defines an embedding from input data into a lower-dimensional space
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct EmbeddingDefinition {
    /// The name of the column in the input data
    pub source_column: String,
    /// The name of the embedding column, if not specified
    /// it will be the source column with `_embedding` appended
    pub dest_column: Option<String>,
    /// The name of the embedding function to apply
    pub embedding_name: String,
}

impl EmbeddingDefinition {
    pub fn new<S: Into<String>>(source_column: S, embedding_name: S, dest: Option<S>) -> Self {
        Self {
            source_column: source_column.into(),
            dest_column: dest.map(|d| d.into()),
            embedding_name: embedding_name.into(),
        }
    }
}

/// A registry of embedding functions
pub trait EmbeddingRegistry: Send + Sync + std::fmt::Debug {
    /// Return the names of all registered embedding functions
    fn functions(&self) -> HashSet<String>;
    /// Register a new [`EmbeddingFunction`].
    /// Returns an error if the function can not be registered
    fn register(&self, name: &str, function: Arc<dyn EmbeddingFunction>) -> Result<()>;
    /// Get an embedding function by name
    fn get(&self, name: &str) -> Option<Arc<dyn EmbeddingFunction>>;
}

/// An [`EmbeddingRegistry`] that uses in-memory [`HashMap`]s
#[derive(Debug, Default, Clone)]
pub struct MemoryRegistry {
    functions: Arc<RwLock<HashMap<String, Arc<dyn EmbeddingFunction>>>>,
}

impl EmbeddingRegistry for MemoryRegistry {
    fn functions(&self) -> HashSet<String> {
        self.functions.read().unwrap().keys().cloned().collect()
    }
    fn register(&self, name: &str, function: Arc<dyn EmbeddingFunction>) -> Result<()> {
        self.functions
            .write()
            .unwrap()
            .insert(name.to_string(), function);

        Ok(())
    }

    fn get(&self, name: &str) -> Option<Arc<dyn EmbeddingFunction>> {
        self.functions.read().unwrap().get(name).cloned()
    }
}

impl MemoryRegistry {
    /// Create a new `MemoryRegistry`
    pub fn new() -> Self {
        Self::default()
    }
}

/// A record batch reader that has embeddings applied to it
/// This is a wrapper around another record batch reader that applies an embedding function
/// when reading from the record batch
pub struct WithEmbeddings<R: RecordBatchReader> {
    inner: R,
    embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
}

/// A record batch reader that might have embeddings applied to it.
pub enum MaybeEmbedded<R: RecordBatchReader> {
    /// The record batch reader has embeddings applied to it
    Yes(WithEmbeddings<R>),
    /// The record batch reader does not have embeddings applied to it
    /// The inner record batch reader is returned as-is
    No(R),
}

impl<R: RecordBatchReader> MaybeEmbedded<R> {
    /// Create a new RecordBatchReader with embeddings applied to it if the table definition
    /// specifies an embedding column and the registry contains an embedding function with that name
    /// Otherwise, this is a no-op and the inner RecordBatchReader is returned.
    pub fn try_new(
        inner: R,
        table_definition: TableDefinition,
        registry: Option<Arc<dyn EmbeddingRegistry>>,
    ) -> Result<Self> {
        if let Some(registry) = registry {
            let mut embeddings = Vec::with_capacity(table_definition.column_definitions.len());
            for cd in table_definition.column_definitions.iter() {
                if let ColumnKind::Embedding(embedding_def) = &cd.kind {
                    match registry.get(&embedding_def.embedding_name) {
                        Some(func) => {
                            embeddings.push((embedding_def.clone(), func));
                        }
                        None => {
                            return Err(Error::EmbeddingFunctionNotFound {
                                name: embedding_def.embedding_name.to_string(),
                                reason: format!(
                                    "Table was defined with an embedding column `{}` but no embedding function was found with that name within the registry.",
                                    embedding_def.embedding_name
                                ),
                            });
                        }
                    }
                }
            }

            if !embeddings.is_empty() {
                return Ok(Self::Yes(WithEmbeddings { inner, embeddings }));
            }
        };

        // No embeddings to apply
        Ok(Self::No(inner))
    }
}

impl<R: RecordBatchReader> WithEmbeddings<R> {
    pub fn new(
        inner: R,
        embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
    ) -> Self {
        Self { inner, embeddings }
    }
}

impl<R: RecordBatchReader> WithEmbeddings<R> {
    fn dest_fields(&self) -> Result<Vec<Field>> {
        let schema = self.inner.schema();
        self.embeddings
            .iter()
            .map(|(ed, func)| {
                let src_field = schema.field_with_name(&ed.source_column).unwrap();

                let field_name = ed
                    .dest_column
                    .clone()
                    .unwrap_or_else(|| format!("{}_embedding", &ed.source_column));
                Ok(Field::new(
                    field_name,
                    func.dest_type()?.into_owned(),
                    src_field.is_nullable(),
                ))
            })
            .collect()
    }

    fn column_defs(&self) -> Vec<ColumnDefinition> {
        let base_schema = self.inner.schema();
        base_schema
            .fields()
            .iter()
            .map(|_| ColumnDefinition {
                kind: ColumnKind::Physical,
            })
            .chain(self.embeddings.iter().map(|(ed, _)| ColumnDefinition {
                kind: ColumnKind::Embedding(ed.clone()),
            }))
            .collect::<Vec<_>>()
    }

    pub fn table_definition(&self) -> Result<TableDefinition> {
        let base_schema = self.inner.schema();

        let output_fields = self.dest_fields()?;
        let column_definitions = self.column_defs();

        let mut sb: SchemaBuilder = base_schema.as_ref().into();
        sb.extend(output_fields);

        let schema = Arc::new(sb.finish());
        Ok(TableDefinition {
            schema,
            column_definitions,
        })
    }
}

impl<R: RecordBatchReader> Iterator for MaybeEmbedded<R> {
    type Item = std::result::Result<RecordBatch, arrow_schema::ArrowError>;
    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::Yes(inner) => inner.next(),
            Self::No(inner) => inner.next(),
        }
    }
}

impl<R: RecordBatchReader> RecordBatchReader for MaybeEmbedded<R> {
    fn schema(&self) -> Arc<arrow_schema::Schema> {
        match self {
            Self::Yes(inner) => inner.schema(),
            Self::No(inner) => inner.schema(),
        }
    }
}

impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
    type Item = std::result::Result<RecordBatch, arrow_schema::ArrowError>;

    fn next(&mut self) -> Option<Self::Item> {
        let batch = self.inner.next()?;
        match batch {
            Ok(mut batch) => {
                // todo: parallelize this
                for (fld, func) in self.embeddings.iter() {
                    let src_column = batch.column_by_name(&fld.source_column).unwrap();
                    let embedding = match func.embed(src_column.clone()) {
                        Ok(embedding) => embedding,
                        Err(e) => {
                            return Some(Err(arrow_schema::ArrowError::ComputeError(format!(
                                "Error computing embedding: {}",
                                e
                            ))))
                        }
                    };
                    let dst_field_name = fld
                        .dest_column
                        .clone()
                        .unwrap_or_else(|| format!("{}_embedding", &fld.source_column));

                    let dst_field = Field::new(
                        dst_field_name,
                        embedding.data_type().clone(),
                        embedding.nulls().is_some(),
                    );

                    match batch.try_with_column(dst_field.clone(), embedding) {
                        Ok(b) => batch = b,
                        Err(e) => return Some(Err(e)),
                    };
                }
                Some(Ok(batch))
            }
            Err(e) => Some(Err(e)),
        }
    }
}

impl<R: RecordBatchReader> RecordBatchReader for WithEmbeddings<R> {
    fn schema(&self) -> Arc<arrow_schema::Schema> {
        self.table_definition()
            .expect("table definition should be infallible at this point")
            .into_rich_schema()
    }
}
@@ -26,6 +26,9 @@ pub enum Error {
    InvalidInput { message: String },
    #[snafu(display("Table '{name}' was not found"))]
    TableNotFound { name: String },
    #[snafu(display("Embedding function '{name}' was not found. : {reason}"))]
    EmbeddingFunctionNotFound { name: String, reason: String },

    #[snafu(display("Table '{name}' already exists"))]
    TableAlreadyExists { name: String },
    #[snafu(display("Unable to created lance dataset at {path}: {source}"))]
@@ -194,6 +194,7 @@
pub mod arrow;
pub mod connection;
pub mod data;
pub mod embeddings;
pub mod error;
pub mod index;
pub mod io;
@@ -23,6 +23,7 @@ use tokio::task::spawn_blocking;
use crate::connection::{
    ConnectionInternal, CreateTableBuilder, NoData, OpenTableBuilder, TableNamesBuilder,
};
use crate::embeddings::EmbeddingRegistry;
use crate::error::Result;
use crate::Table;
@@ -115,4 +116,8 @@ impl ConnectionInternal for RemoteDatabase {
    async fn drop_db(&self) -> Result<()> {
        todo!()
    }

    fn embedding_registry(&self) -> &dyn EmbeddingRegistry {
        todo!()
    }
}
@@ -10,7 +10,7 @@ use crate::{
    query::{Query, QueryExecutionOptions, VectorQuery},
    table::{
        merge::MergeInsertBuilder, AddDataBuilder, NativeTable, OptimizeAction, OptimizeStats,
        TableInternal, UpdateBuilder,
        TableDefinition, TableInternal, UpdateBuilder,
    },
};
@@ -120,4 +120,7 @@ impl TableInternal for RemoteTable {
    async fn list_indices(&self) -> Result<Vec<IndexConfig>> {
        todo!()
    }
    async fn table_definition(&self) -> Result<TableDefinition> {
        todo!()
    }
}
@@ -41,10 +41,12 @@ use lance::io::WrappingObjectStore;
use lance_index::IndexType;
use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
use log::info;
use serde::{Deserialize, Serialize};
use snafu::whatever;

use crate::arrow::IntoArrow;
use crate::connection::NoData;
use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry};
use crate::error::{Error, Result};
use crate::index::vector::{IvfPqIndexBuilder, VectorIndex, VectorIndexStatistics};
use crate::index::IndexConfig;
@@ -63,6 +65,79 @@ use self::merge::MergeInsertBuilder;
pub(crate) mod dataset;
pub mod merge;

/// Defines the type of column
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ColumnKind {
    /// Columns populated by data from the user (this is the most common case)
    Physical,
    /// Columns populated by applying an embedding function to the input
    Embedding(EmbeddingDefinition),
}

/// Defines a column in a table
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnDefinition {
    /// The source of the column data
    pub kind: ColumnKind,
}

#[derive(Debug, Clone)]
pub struct TableDefinition {
    pub column_definitions: Vec<ColumnDefinition>,
    pub schema: SchemaRef,
}

impl TableDefinition {
    pub fn new(schema: SchemaRef, column_definitions: Vec<ColumnDefinition>) -> Self {
        Self {
            column_definitions,
            schema,
        }
    }

    pub fn new_from_schema(schema: SchemaRef) -> Self {
        let column_definitions = schema
            .fields()
            .iter()
            .map(|_| ColumnDefinition {
                kind: ColumnKind::Physical,
            })
            .collect();
        Self::new(schema, column_definitions)
    }

    pub fn try_from_rich_schema(schema: SchemaRef) -> Result<Self> {
        let column_definitions = schema.metadata.get("lancedb::column_definitions");
        if let Some(column_definitions) = column_definitions {
            let column_definitions: Vec<ColumnDefinition> =
                serde_json::from_str(column_definitions).map_err(|e| Error::Runtime {
                    message: format!("Failed to deserialize column definitions: {}", e),
                })?;
            Ok(Self::new(schema, column_definitions))
        } else {
            let column_definitions = schema
                .fields()
                .iter()
                .map(|_| ColumnDefinition {
                    kind: ColumnKind::Physical,
                })
                .collect();
            Ok(Self::new(schema, column_definitions))
        }
    }

    pub fn into_rich_schema(self) -> SchemaRef {
        // We have full control over the structure of column definitions. This should
        // not fail, except for a bug
        let lancedb_metadata = serde_json::to_string(&self.column_definitions).unwrap();
        let mut schema_with_metadata = (*self.schema).clone();
        schema_with_metadata
            .metadata
            .insert("lancedb::column_definitions".to_string(), lancedb_metadata);
        Arc::new(schema_with_metadata)
    }
}

/// Optimize the dataset.
///
/// Similar to `VACUUM` in PostgreSQL, it offers different options to
@@ -132,6 +207,7 @@ pub struct AddDataBuilder<T: IntoArrow> {
    pub(crate) data: T,
    pub(crate) mode: AddDataMode,
    pub(crate) write_options: WriteOptions,
    embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
}

impl<T: IntoArrow> std::fmt::Debug for AddDataBuilder<T> {
@@ -163,6 +239,7 @@ impl<T: IntoArrow> AddDataBuilder<T> {
            mode: self.mode,
            parent: self.parent,
            write_options: self.write_options,
            embedding_registry: self.embedding_registry,
        };
        parent.add(without_data, data).await
    }
@@ -280,6 +357,7 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
    async fn checkout(&self, version: u64) -> Result<()>;
    async fn checkout_latest(&self) -> Result<()>;
    async fn restore(&self) -> Result<()>;
    async fn table_definition(&self) -> Result<TableDefinition>;
}

/// A Table is a collection of strong typed Rows.
@@ -288,6 +366,7 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
#[derive(Clone)]
pub struct Table {
    inner: Arc<dyn TableInternal>,
    embedding_registry: Arc<dyn EmbeddingRegistry>,
}

impl std::fmt::Display for Table {
@@ -298,7 +377,20 @@ impl std::fmt::Display for Table {

impl Table {
    pub(crate) fn new(inner: Arc<dyn TableInternal>) -> Self {
        Self { inner }
        Self {
            inner,
            embedding_registry: Arc::new(MemoryRegistry::new()),
        }
    }

    pub(crate) fn new_with_embedding_registry(
        inner: Arc<dyn TableInternal>,
        embedding_registry: Arc<dyn EmbeddingRegistry>,
    ) -> Self {
        Self {
            inner,
            embedding_registry,
        }
    }

    /// Cast as [`NativeTable`], or return None it if is not a [`NativeTable`].
@@ -340,6 +432,7 @@ impl Table {
            data: batches,
            mode: AddDataMode::Append,
            write_options: WriteOptions::default(),
            embedding_registry: Some(self.embedding_registry.clone()),
        }
    }
@@ -743,11 +836,10 @@ impl Table {

impl From<NativeTable> for Table {
    fn from(table: NativeTable) -> Self {
        Self {
            inner: Arc::new(table),
        }
        Self::new(Arc::new(table))
    }
}

/// A table in a LanceDB database.
#[derive(Debug, Clone)]
pub struct NativeTable {
@@ -918,7 +1010,6 @@ impl NativeTable {
            Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
            None => params,
        };

        let storage_options = params
            .store_params
            .clone()
@@ -1342,6 +1433,11 @@ impl TableInternal for NativeTable {
        Ok(Arc::new(Schema::from(&lance_schema)))
    }

    async fn table_definition(&self) -> Result<TableDefinition> {
        let schema = self.schema().await?;
        TableDefinition::try_from_rich_schema(schema)
    }

    async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
        Ok(self.dataset.get().await?.count_rows(filter).await?)
    }
@@ -1351,6 +1447,9 @@ impl TableInternal for NativeTable {
        add: AddDataBuilder<NoData>,
        data: Box<dyn RecordBatchReader + Send>,
    ) -> Result<()> {
        let data =
            MaybeEmbedded::try_new(data, self.table_definition().await?, add.embedding_registry)?;

        let mut lance_params = add.write_options.lance_write_params.unwrap_or(WriteParams {
            mode: match add.mode {
                AddDataMode::Append => WriteMode::Append,
@@ -1378,8 +1477,8 @@ impl TableInternal for NativeTable {
        };

        self.dataset.ensure_mutable().await?;

        let dataset = Dataset::write(data, &self.uri, Some(lance_params)).await?;

        self.dataset.set_latest(dataset).await;
        Ok(())
    }
rust/lancedb/tests/embedding_registry_test.rs: 320 lines (new file)
@@ -0,0 +1,320 @@
use std::{
    borrow::Cow,
    collections::{HashMap, HashSet},
    iter::repeat,
    sync::Arc,
};

use arrow::buffer::NullBuffer;
use arrow_array::{
    Array, FixedSizeListArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
    StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use futures::StreamExt;
use lancedb::{
    arrow::IntoArrow,
    connect,
    embeddings::{EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry},
    query::ExecutableQuery,
    Error, Result,
};

#[tokio::test]
async fn test_custom_func() -> Result<()> {
    let tempdir = tempfile::tempdir().unwrap();
    let tempdir = tempdir.path().to_str().unwrap();
    let db = connect(tempdir).execute().await?;
    let embed_fun = MockEmbed::new("embed_fun".to_string(), 1);
    db.embedding_registry()
        .register("embed_fun", Arc::new(embed_fun.clone()))?;

    let tbl = db
        .create_table("test", create_some_records()?)
        .add_embedding(EmbeddingDefinition::new(
            "text",
            &embed_fun.name,
            Some("embeddings"),
        ))?
        .execute()
        .await?;
    let mut res = tbl.query().execute().await?;
    while let Some(Ok(batch)) = res.next().await {
        let embeddings = batch.column_by_name("embeddings");
        assert!(embeddings.is_some());
        let embeddings = embeddings.unwrap();
        assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref());
    }
    // now make sure the embeddings are applied when
    // we add new records too
    tbl.add(create_some_records()?).execute().await?;
    let mut res = tbl.query().execute().await?;
    while let Some(Ok(batch)) = res.next().await {
        let embeddings = batch.column_by_name("embeddings");
        assert!(embeddings.is_some());
        let embeddings = embeddings.unwrap();
        assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref());
    }
    Ok(())
}

#[tokio::test]
async fn test_custom_registry() -> Result<()> {
    let tempdir = tempfile::tempdir().unwrap();
    let tempdir = tempdir.path().to_str().unwrap();

    let db = connect(tempdir)
        .embedding_registry(Arc::new(MyRegistry::default()))
        .execute()
        .await?;

    let tbl = db
        .create_table("test", create_some_records()?)
        .add_embedding(EmbeddingDefinition::new(
            "text",
            "func_1",
            Some("embeddings"),
        ))?
        .execute()
        .await?;
    let mut res = tbl.query().execute().await?;
    while let Some(Ok(batch)) = res.next().await {
        let embeddings = batch.column_by_name("embeddings");
        assert!(embeddings.is_some());
        let embeddings = embeddings.unwrap();
        assert_eq!(
            embeddings.data_type(),
            MockEmbed::new("func_1".to_string(), 1)
                .dest_type()?
                .as_ref()
        );
    }
    Ok(())
}

#[tokio::test]
async fn test_multiple_embeddings() -> Result<()> {
    let tempdir = tempfile::tempdir().unwrap();
    let tempdir = tempdir.path().to_str().unwrap();

    let db = connect(tempdir).execute().await?;
    let func_1 = MockEmbed::new("func_1".to_string(), 1);
    let func_2 = MockEmbed::new("func_2".to_string(), 10);
    db.embedding_registry()
        .register(&func_1.name, Arc::new(func_1.clone()))?;
    db.embedding_registry()
        .register(&func_2.name, Arc::new(func_2.clone()))?;

    let tbl = db
        .create_table("test", create_some_records()?)
        .add_embedding(EmbeddingDefinition::new(
            "text",
            &func_1.name,
            Some("first_embeddings"),
        ))?
        .add_embedding(EmbeddingDefinition::new(
            "text",
            &func_2.name,
            Some("second_embeddings"),
        ))?
        .execute()
        .await?;
    let mut res = tbl.query().execute().await?;
    while let Some(Ok(batch)) = res.next().await {
        let embeddings = batch.column_by_name("first_embeddings");
        assert!(embeddings.is_some());
        let second_embeddings = batch.column_by_name("second_embeddings");
        assert!(second_embeddings.is_some());

        let embeddings = embeddings.unwrap();
        assert_eq!(embeddings.data_type(), func_1.dest_type()?.as_ref());

        let second_embeddings = second_embeddings.unwrap();
        assert_eq!(second_embeddings.data_type(), func_2.dest_type()?.as_ref());
    }

    // now make sure the embeddings are applied when
    // we add new records too
    tbl.add(create_some_records()?).execute().await?;
    let mut res = tbl.query().execute().await?;
    while let Some(Ok(batch)) = res.next().await {
        let embeddings = batch.column_by_name("first_embeddings");
        assert!(embeddings.is_some());
        let second_embeddings = batch.column_by_name("second_embeddings");
        assert!(second_embeddings.is_some());

        let embeddings = embeddings.unwrap();
        assert_eq!(embeddings.data_type(), func_1.dest_type()?.as_ref());

        let second_embeddings = second_embeddings.unwrap();
        assert_eq!(second_embeddings.data_type(), func_2.dest_type()?.as_ref());
    }
    Ok(())
}

#[tokio::test]
async fn test_no_func_in_registry() -> Result<()> {
    let tempdir = tempfile::tempdir().unwrap();
    let tempdir = tempdir.path().to_str().unwrap();

    let db = connect(tempdir).execute().await?;

    let res = db
        .create_table("test", create_some_records()?)
        .add_embedding(EmbeddingDefinition::new(
            "text",
            "some_func",
            Some("first_embeddings"),
        ));
    assert!(res.is_err());
    assert!(matches!(
        res.err().unwrap(),
        Error::EmbeddingFunctionNotFound { .. }
    ));

    Ok(())
}

#[tokio::test]
async fn test_no_func_in_registry_on_add() -> Result<()> {
    let tempdir = tempfile::tempdir().unwrap();
    let tempdir = tempdir.path().to_str().unwrap();

    let db = connect(tempdir).execute().await?;
    db.embedding_registry().register(
        "some_func",
        Arc::new(MockEmbed::new("some_func".to_string(), 1)),
    )?;

    db.create_table("test", create_some_records()?)
        .add_embedding(EmbeddingDefinition::new(
            "text",
            "some_func",
            Some("first_embeddings"),
        ))?
        .execute()
        .await?;

    let db = connect(tempdir).execute().await?;

    let tbl = db.open_table("test").execute().await?;
    // This should fail because 'tbl' is expecting "some_func" to be in the registry
    let res = tbl.add(create_some_records()?).execute().await;
    assert!(res.is_err());
    assert!(matches!(
        res.unwrap_err(),
        crate::Error::EmbeddingFunctionNotFound { .. }
    ));

    Ok(())
}

fn create_some_records() -> Result<impl IntoArrow> {
    const TOTAL: usize = 2;

    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("text", DataType::Utf8, true),
    ]));

    // Create a RecordBatch stream.
    let batches = RecordBatchIterator::new(
        vec![RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
                Arc::new(StringArray::from_iter(
                    repeat(Some("hello world".to_string())).take(TOTAL),
                )),
            ],
        )
        .unwrap()]
        .into_iter()
        .map(Ok),
        schema.clone(),
    );
    Ok(Box::new(batches))
}

#[derive(Debug)]
struct MyRegistry {
    functions: HashMap<String, Arc<dyn EmbeddingFunction>>,
}
impl Default for MyRegistry {
    fn default() -> Self {
        let funcs: Vec<Arc<dyn EmbeddingFunction>> = vec![
            Arc::new(MockEmbed::new("func_1".to_string(), 1)),
            Arc::new(MockEmbed::new("func_2".to_string(), 10)),
        ];
        Self {
            functions: funcs
                .into_iter()
                .map(|f| (f.name().to_string(), f))
                .collect(),
        }
    }
}

/// a mock registry that only has one function called `embed_fun`
impl EmbeddingRegistry for MyRegistry {
    fn functions(&self) -> HashSet<String> {
        self.functions.keys().cloned().collect()
    }

    fn register(&self, _name: &str, _function: Arc<dyn EmbeddingFunction>) -> Result<()> {
        Err(Error::Other {
            message: "MyRegistry is read-only".to_string(),
            source: None,
        })
    }

    fn get(&self, name: &str) -> Option<Arc<dyn EmbeddingFunction>> {
        self.functions.get(name).cloned()
    }
}

#[derive(Debug, Clone)]
struct MockEmbed {
    source_type: DataType,
    dest_type: DataType,
    name: String,
    dim: usize,
}

impl MockEmbed {
    pub fn new(name: String, dim: usize) -> Self {
        Self {
            source_type: DataType::Utf8,
            dest_type: DataType::new_fixed_size_list(DataType::Float32, dim as _, true),
            name,
            dim,
        }
    }
}

impl EmbeddingFunction for MockEmbed {
    fn name(&self) -> &str {
        &self.name
    }
    fn source_type(&self) -> Result<Cow<DataType>> {
        Ok(Cow::Borrowed(&self.source_type))
    }
    fn dest_type(&self) -> Result<Cow<DataType>> {
        Ok(Cow::Borrowed(&self.dest_type))
    }
    fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
        // We can't use the FixedSizeListBuilder here because it always adds a null bitmap
        // and we want to explicitly work with non-nullable arrays.
        let len = source.len();
        let inner = Arc::new(Float32Array::from(vec![Some(1.0); len * self.dim]));
        let field = Field::new("item", inner.data_type().clone(), false);
        let arr = FixedSizeListArray::new(
            Arc::new(field),
            self.dim as _,
            inner,
            Some(NullBuffer::new_valid(len)),
        );

        Ok(Arc::new(arr))
    }
}