From 9d2fb7d602276dc5e41b7ce06aba21b3197c969d Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Mon, 6 May 2024 18:39:07 -0500 Subject: [PATCH] feat: rust embedding registry (#1259) Todo: - [x] add proper documentation - [x] add unit tests - [x] better handling of the registry**1 - [x] allow user-defined registry**2 **1 The Python implementation just uses a global registry, which makes things a bit easier. I attached it to the db/connection to prevent future conflicts when running multiple connections/databases. I mostly modeled the registry and its pattern on DataFusion's [FunctionRegistry](https://docs.rs/datafusion/latest/datafusion/execution/trait.FunctionRegistry.html). **2 Ideally, the user should be able to provide their own registry entirely, but currently it just uses an in-memory registry by default (_which isn't configurable_). `rust/lancedb/examples/embedding_registry.rs` provides a thorough example of expected usage. --- Some additional notes: This does not provide any of the out-of-the-box functionality that the Python registry does, _i.e., there are no built-in embedding functions._ You can think of this as the groundwork for adding those built-in functions. So while this is part of https://github.com/lancedb/lancedb/issues/994, it does not yet offer feature parity. 
--- .gitignore | 2 +- rust/lancedb/src/arrow.rs | 7 + rust/lancedb/src/connection.rs | 95 +++++- rust/lancedb/src/embeddings.rs | 307 +++++++++++++++++ rust/lancedb/src/error.rs | 3 + rust/lancedb/src/lib.rs | 1 + rust/lancedb/src/remote/db.rs | 5 + rust/lancedb/src/remote/table.rs | 5 +- rust/lancedb/src/table.rs | 111 +++++- rust/lancedb/tests/embedding_registry_test.rs | 320 ++++++++++++++++++ 10 files changed, 838 insertions(+), 18 deletions(-) create mode 100644 rust/lancedb/src/embeddings.rs create mode 100644 rust/lancedb/tests/embedding_registry_test.rs diff --git a/.gitignore b/.gitignore index da3b594f..f61eaccc 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ venv .vscode - +.zed rust/target rust/Cargo.lock diff --git a/rust/lancedb/src/arrow.rs b/rust/lancedb/src/arrow.rs index f9440bed..2c0fad5e 100644 --- a/rust/lancedb/src/arrow.rs +++ b/rust/lancedb/src/arrow.rs @@ -116,12 +116,19 @@ pub trait IntoArrow { fn into_arrow(self) -> Result>; } +pub type BoxedRecordBatchReader = Box; + impl IntoArrow for T { fn into_arrow(self) -> Result> { Ok(Box::new(self)) } } +impl>> SimpleRecordBatchStream { + pub fn new(stream: S, schema: Arc) -> Self { + Self { schema, stream } + } +} #[cfg(feature = "polars")] /// An iterator of record batches formed from a Polars DataFrame. 
pub struct PolarsDataFrameRecordBatchReader { diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index 26e345ce..baea0cea 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -27,9 +27,12 @@ use object_store::{aws::AwsCredential, local::LocalFileSystem}; use snafu::prelude::*; use crate::arrow::IntoArrow; +use crate::embeddings::{ + EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, MemoryRegistry, WithEmbeddings, +}; use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result}; use crate::io::object_store::MirroringObjectStoreWrapper; -use crate::table::{NativeTable, WriteOptions}; +use crate::table::{NativeTable, TableDefinition, WriteOptions}; use crate::utils::validate_table_name; use crate::Table; @@ -133,9 +136,10 @@ pub struct CreateTableBuilder { parent: Arc, pub(crate) name: String, pub(crate) data: Option, - pub(crate) schema: Option, pub(crate) mode: CreateTableMode, pub(crate) write_options: WriteOptions, + pub(crate) table_definition: Option, + pub(crate) embeddings: Vec<(EmbeddingDefinition, Arc)>, } // Builder methods that only apply when we have initial data @@ -145,9 +149,10 @@ impl CreateTableBuilder { parent, name, data: Some(data), - schema: None, mode: CreateTableMode::default(), write_options: WriteOptions::default(), + table_definition: None, + embeddings: Vec::new(), } } @@ -175,24 +180,43 @@ impl CreateTableBuilder { parent: self.parent, name: self.name, data: None, - schema: self.schema, + table_definition: self.table_definition, mode: self.mode, write_options: self.write_options, + embeddings: self.embeddings, }; Ok((data, builder)) } + + pub fn add_embedding(mut self, definition: EmbeddingDefinition) -> Result { + // Early verification of the embedding name + let embedding_func = self + .parent + .embedding_registry() + .get(&definition.embedding_name) + .ok_or_else(|| Error::EmbeddingFunctionNotFound { + name: definition.embedding_name.to_string(), + reason: 
"No embedding function found in the connection's embedding_registry" + .to_string(), + })?; + + self.embeddings.push((definition, embedding_func)); + Ok(self) + } } // Builder methods that only apply when we do not have initial data impl CreateTableBuilder { fn new(parent: Arc, name: String, schema: SchemaRef) -> Self { + let table_definition = TableDefinition::new_from_schema(schema); Self { parent, name, data: None, - schema: Some(schema), + table_definition: Some(table_definition), mode: CreateTableMode::default(), write_options: WriteOptions::default(), + embeddings: Vec::new(), } } @@ -350,6 +374,7 @@ impl OpenTableBuilder { pub(crate) trait ConnectionInternal: Send + Sync + std::fmt::Debug + std::fmt::Display + 'static { + fn embedding_registry(&self) -> &dyn EmbeddingRegistry; async fn table_names(&self, options: TableNamesBuilder) -> Result>; async fn do_create_table( &self, @@ -366,7 +391,7 @@ pub(crate) trait ConnectionInternal: ) -> Result { let batches = Box::new(RecordBatchIterator::new( vec![], - options.schema.as_ref().unwrap().clone(), + options.table_definition.clone().unwrap().schema.clone(), )); self.do_create_table(options, batches).await } @@ -453,6 +478,13 @@ impl Connection { pub async fn drop_db(&self) -> Result<()> { self.internal.drop_db().await } + + /// Get the in-memory embedding registry. + /// It's important to note that the embedding registry is not persisted across connections. + /// So if a table contains embeddings, you will need to make sure that you are using a connection that has the same embedding functions registered + pub fn embedding_registry(&self) -> &dyn EmbeddingRegistry { + self.internal.embedding_registry() + } } #[derive(Debug)] @@ -486,6 +518,7 @@ pub struct ConnectBuilder { /// consistency only applies to read operations. Write operations are /// always consistent. 
read_consistency_interval: Option, + embedding_registry: Option>, } impl ConnectBuilder { @@ -498,6 +531,7 @@ impl ConnectBuilder { host_override: None, read_consistency_interval: None, storage_options: HashMap::new(), + embedding_registry: None, } } @@ -516,6 +550,12 @@ impl ConnectBuilder { self } + /// Provide a custom [`EmbeddingRegistry`] to use for this connection. + pub fn embedding_registry(mut self, registry: Arc) -> Self { + self.embedding_registry = Some(registry); + self + } + /// [`AwsCredential`] to use when connecting to S3. #[deprecated(note = "Pass through storage_options instead")] pub fn aws_creds(mut self, aws_creds: AwsCredential) -> Self { @@ -642,6 +682,7 @@ struct Database { // Storage options to be inherited by tables created from this connection storage_options: HashMap, + embedding_registry: Arc, } impl std::fmt::Display for Database { @@ -675,7 +716,12 @@ impl Database { // TODO: pass params regardless of OS match parse_res { Ok(url) if url.scheme().len() == 1 && cfg!(windows) => { - Self::open_path(uri, options.read_consistency_interval).await + Self::open_path( + uri, + options.read_consistency_interval, + options.embedding_registry.clone(), + ) + .await } Ok(mut url) => { // iter thru the query params and extract the commit store param @@ -745,6 +791,10 @@ impl Database { None => None, }; + let embedding_registry = options + .embedding_registry + .clone() + .unwrap_or_else(|| Arc::new(MemoryRegistry::new())); Ok(Self { uri: table_base_uri, query_string, @@ -753,20 +803,33 @@ impl Database { store_wrapper: write_store_wrapper, read_consistency_interval: options.read_consistency_interval, storage_options, + embedding_registry, }) } - Err(_) => Self::open_path(uri, options.read_consistency_interval).await, + Err(_) => { + Self::open_path( + uri, + options.read_consistency_interval, + options.embedding_registry.clone(), + ) + .await + } } } async fn open_path( path: &str, read_consistency_interval: Option, + embedding_registry: Option>, ) 
-> Result { let (object_store, base_path) = ObjectStore::from_uri(path).await?; if object_store.is_local() { Self::try_create_dir(path).context(CreateDirSnafu { path })?; } + + let embedding_registry = + embedding_registry.unwrap_or_else(|| Arc::new(MemoryRegistry::new())); + Ok(Self { uri: path.to_string(), query_string: None, @@ -775,6 +838,7 @@ impl Database { store_wrapper: None, read_consistency_interval, storage_options: HashMap::new(), + embedding_registry, }) } @@ -815,6 +879,9 @@ impl Database { #[async_trait::async_trait] impl ConnectionInternal for Database { + fn embedding_registry(&self) -> &dyn EmbeddingRegistry { + self.embedding_registry.as_ref() + } async fn table_names(&self, options: TableNamesBuilder) -> Result> { let mut f = self .object_store @@ -851,7 +918,7 @@ impl ConnectionInternal for Database { data: Box, ) -> Result
{ let table_uri = self.table_uri(&options.name)?; - + let embedding_registry = self.embedding_registry.clone(); // Inherit storage options from the connection let storage_options = options .write_options @@ -866,6 +933,11 @@ impl ConnectionInternal for Database { storage_options.insert(key.clone(), value.clone()); } } + let data = if options.embeddings.is_empty() { + data + } else { + Box::new(WithEmbeddings::new(data, options.embeddings)) + }; let mut write_params = options.write_options.lance_write_params.unwrap_or_default(); if matches!(&options.mode, CreateTableMode::Overwrite) { @@ -882,7 +954,10 @@ impl ConnectionInternal for Database { ) .await { - Ok(table) => Ok(Table::new(Arc::new(table))), + Ok(table) => Ok(Table::new_with_embedding_registry( + Arc::new(table), + embedding_registry, + )), Err(Error::TableAlreadyExists { name }) => match options.mode { CreateTableMode::Create => Err(Error::TableAlreadyExists { name }), CreateTableMode::ExistOk(callback) => { diff --git a/rust/lancedb/src/embeddings.rs b/rust/lancedb/src/embeddings.rs new file mode 100644 index 00000000..07a5725a --- /dev/null +++ b/rust/lancedb/src/embeddings.rs @@ -0,0 +1,307 @@ +// Copyright 2024 LanceDB Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use lance::arrow::RecordBatchExt; +use std::{ + borrow::Cow, + collections::{HashMap, HashSet}, + sync::{Arc, RwLock}, +}; + +use arrow_array::{Array, RecordBatch, RecordBatchReader}; +use arrow_schema::{DataType, Field, SchemaBuilder}; +// use async_trait::async_trait; +use serde::{Deserialize, Serialize}; + +use crate::{ + error::Result, + table::{ColumnDefinition, ColumnKind, TableDefinition}, + Error, +}; + +/// Trait for embedding functions +/// +/// An embedding function is a function that is applied to a column of input data +/// to produce an "embedding" of that input. This embedding is then stored in the +/// database alongside (or instead of) the original input. +/// +/// An "embedding" is often a lower-dimensional representation of the input data. +/// For example, sentence-transformers can be used to embed sentences into a 768-dimensional +/// vector space. This is useful for tasks like similarity search, where we want to find +/// similar sentences to a query sentence. +/// +/// To use an embedding function you must first register it with the `EmbeddingsRegistry`. +/// Then you can define it on a column in the table schema. That embedding will then be used +/// to embed the data in that column. 
+pub trait EmbeddingFunction: std::fmt::Debug + Send + Sync { + fn name(&self) -> &str; + /// The type of the input data + fn source_type(&self) -> Result>; + /// The type of the output data + /// This should **always** match the output of the `embed` function + fn dest_type(&self) -> Result>; + /// Embed the input + fn embed(&self, source: Arc) -> Result>; +} + +/// Defines an embedding from input data into a lower-dimensional space +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct EmbeddingDefinition { + /// The name of the column in the input data + pub source_column: String, + /// The name of the embedding column, if not specified + /// it will be the source column with `_embedding` appended + pub dest_column: Option, + /// The name of the embedding function to apply + pub embedding_name: String, +} + +impl EmbeddingDefinition { + pub fn new>(source_column: S, embedding_name: S, dest: Option) -> Self { + Self { + source_column: source_column.into(), + dest_column: dest.map(|d| d.into()), + embedding_name: embedding_name.into(), + } + } +} + +/// A registry of embedding +pub trait EmbeddingRegistry: Send + Sync + std::fmt::Debug { + /// Return the names of all registered embedding functions + fn functions(&self) -> HashSet; + /// Register a new [`EmbeddingFunction + /// Returns an error if the function can not be registered + fn register(&self, name: &str, function: Arc) -> Result<()>; + /// Get an embedding function by name + fn get(&self, name: &str) -> Option>; +} + +/// A [`EmbeddingRegistry`] that uses in-memory [`HashMap`]s +#[derive(Debug, Default, Clone)] +pub struct MemoryRegistry { + functions: Arc>>>, +} + +impl EmbeddingRegistry for MemoryRegistry { + fn functions(&self) -> HashSet { + self.functions.read().unwrap().keys().cloned().collect() + } + fn register(&self, name: &str, function: Arc) -> Result<()> { + self.functions + .write() + .unwrap() + .insert(name.to_string(), function); + + Ok(()) + } + + fn get(&self, 
name: &str) -> Option> { + self.functions.read().unwrap().get(name).cloned() + } +} + +impl MemoryRegistry { + /// Create a new `MemoryRegistry` + pub fn new() -> Self { + Self::default() + } +} + +/// A record batch reader that has embeddings applied to it +/// This is a wrapper around another record batch reader that applies an embedding function +/// when reading from the record batch +pub struct WithEmbeddings { + inner: R, + embeddings: Vec<(EmbeddingDefinition, Arc)>, +} + +/// A record batch that might have embeddings applied to it. +pub enum MaybeEmbedded { + /// The record batch reader has embeddings applied to it + Yes(WithEmbeddings), + /// The record batch reader does not have embeddings applied to it + /// The inner record batch reader is returned as-is + No(R), +} + +impl MaybeEmbedded { + /// Create a new RecordBatchReader with embeddings applied to it if the table definition + /// specifies an embedding column and the registry contains an embedding function with that name + /// Otherwise, this is a no-op and the inner RecordBatchReader is returned. 
+ pub fn try_new( + inner: R, + table_definition: TableDefinition, + registry: Option>, + ) -> Result { + if let Some(registry) = registry { + let mut embeddings = Vec::with_capacity(table_definition.column_definitions.len()); + for cd in table_definition.column_definitions.iter() { + if let ColumnKind::Embedding(embedding_def) = &cd.kind { + match registry.get(&embedding_def.embedding_name) { + Some(func) => { + embeddings.push((embedding_def.clone(), func)); + } + None => { + return Err(Error::EmbeddingFunctionNotFound { + name: embedding_def.embedding_name.to_string(), + reason: format!( + "Table was defined with an embedding column `{}` but no embedding function was found with that name within the registry.", + embedding_def.embedding_name + ), + }); + } + } + } + } + + if !embeddings.is_empty() { + return Ok(Self::Yes(WithEmbeddings { inner, embeddings })); + } + }; + + // No embeddings to apply + Ok(Self::No(inner)) + } +} + +impl WithEmbeddings { + pub fn new( + inner: R, + embeddings: Vec<(EmbeddingDefinition, Arc)>, + ) -> Self { + Self { inner, embeddings } + } +} + +impl WithEmbeddings { + fn dest_fields(&self) -> Result> { + let schema = self.inner.schema(); + self.embeddings + .iter() + .map(|(ed, func)| { + let src_field = schema.field_with_name(&ed.source_column).unwrap(); + + let field_name = ed + .dest_column + .clone() + .unwrap_or_else(|| format!("{}_embedding", &ed.source_column)); + Ok(Field::new( + field_name, + func.dest_type()?.into_owned(), + src_field.is_nullable(), + )) + }) + .collect() + } + + fn column_defs(&self) -> Vec { + let base_schema = self.inner.schema(); + base_schema + .fields() + .iter() + .map(|_| ColumnDefinition { + kind: ColumnKind::Physical, + }) + .chain(self.embeddings.iter().map(|(ed, _)| ColumnDefinition { + kind: ColumnKind::Embedding(ed.clone()), + })) + .collect::>() + } + + pub fn table_definition(&self) -> Result { + let base_schema = self.inner.schema(); + + let output_fields = self.dest_fields()?; + let 
column_definitions = self.column_defs(); + + let mut sb: SchemaBuilder = base_schema.as_ref().into(); + sb.extend(output_fields); + + let schema = Arc::new(sb.finish()); + Ok(TableDefinition { + schema, + column_definitions, + }) + } +} + +impl Iterator for MaybeEmbedded { + type Item = std::result::Result; + fn next(&mut self) -> Option { + match self { + Self::Yes(inner) => inner.next(), + Self::No(inner) => inner.next(), + } + } +} + +impl RecordBatchReader for MaybeEmbedded { + fn schema(&self) -> Arc { + match self { + Self::Yes(inner) => inner.schema(), + Self::No(inner) => inner.schema(), + } + } +} + +impl Iterator for WithEmbeddings { + type Item = std::result::Result; + + fn next(&mut self) -> Option { + let batch = self.inner.next()?; + match batch { + Ok(mut batch) => { + // todo: parallelize this + for (fld, func) in self.embeddings.iter() { + let src_column = batch.column_by_name(&fld.source_column).unwrap(); + let embedding = match func.embed(src_column.clone()) { + Ok(embedding) => embedding, + Err(e) => { + return Some(Err(arrow_schema::ArrowError::ComputeError(format!( + "Error computing embedding: {}", + e + )))) + } + }; + let dst_field_name = fld + .dest_column + .clone() + .unwrap_or_else(|| format!("{}_embedding", &fld.source_column)); + + let dst_field = Field::new( + dst_field_name, + embedding.data_type().clone(), + embedding.nulls().is_some(), + ); + + match batch.try_with_column(dst_field.clone(), embedding) { + Ok(b) => batch = b, + Err(e) => return Some(Err(e)), + }; + } + Some(Ok(batch)) + } + Err(e) => Some(Err(e)), + } + } +} + +impl RecordBatchReader for WithEmbeddings { + fn schema(&self) -> Arc { + self.table_definition() + .expect("table definition should be infallible at this point") + .into_rich_schema() + } +} diff --git a/rust/lancedb/src/error.rs b/rust/lancedb/src/error.rs index 1f14ef57..1e0d1511 100644 --- a/rust/lancedb/src/error.rs +++ b/rust/lancedb/src/error.rs @@ -26,6 +26,9 @@ pub enum Error { InvalidInput { 
message: String }, #[snafu(display("Table '{name}' was not found"))] TableNotFound { name: String }, + #[snafu(display("Embedding function '{name}' was not found. : {reason}"))] + EmbeddingFunctionNotFound { name: String, reason: String }, + #[snafu(display("Table '{name}' already exists"))] TableAlreadyExists { name: String }, #[snafu(display("Unable to created lance dataset at {path}: {source}"))] diff --git a/rust/lancedb/src/lib.rs b/rust/lancedb/src/lib.rs index 5e0d8f17..01b3aeb3 100644 --- a/rust/lancedb/src/lib.rs +++ b/rust/lancedb/src/lib.rs @@ -194,6 +194,7 @@ pub mod arrow; pub mod connection; pub mod data; +pub mod embeddings; pub mod error; pub mod index; pub mod io; diff --git a/rust/lancedb/src/remote/db.rs b/rust/lancedb/src/remote/db.rs index 9511475e..b8062367 100644 --- a/rust/lancedb/src/remote/db.rs +++ b/rust/lancedb/src/remote/db.rs @@ -23,6 +23,7 @@ use tokio::task::spawn_blocking; use crate::connection::{ ConnectionInternal, CreateTableBuilder, NoData, OpenTableBuilder, TableNamesBuilder, }; +use crate::embeddings::EmbeddingRegistry; use crate::error::Result; use crate::Table; @@ -115,4 +116,8 @@ impl ConnectionInternal for RemoteDatabase { async fn drop_db(&self) -> Result<()> { todo!() } + + fn embedding_registry(&self) -> &dyn EmbeddingRegistry { + todo!() + } } diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index cbfa4fac..84b2c247 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -10,7 +10,7 @@ use crate::{ query::{Query, QueryExecutionOptions, VectorQuery}, table::{ merge::MergeInsertBuilder, AddDataBuilder, NativeTable, OptimizeAction, OptimizeStats, - TableInternal, UpdateBuilder, + TableDefinition, TableInternal, UpdateBuilder, }, }; @@ -120,4 +120,7 @@ impl TableInternal for RemoteTable { async fn list_indices(&self) -> Result> { todo!() } + async fn table_definition(&self) -> Result { + todo!() + } } diff --git a/rust/lancedb/src/table.rs 
b/rust/lancedb/src/table.rs index 8efbd115..522f5dc4 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -41,10 +41,12 @@ use lance::io::WrappingObjectStore; use lance_index::IndexType; use lance_index::{optimize::OptimizeOptions, DatasetIndexExt}; use log::info; +use serde::{Deserialize, Serialize}; use snafu::whatever; use crate::arrow::IntoArrow; use crate::connection::NoData; +use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry}; use crate::error::{Error, Result}; use crate::index::vector::{IvfPqIndexBuilder, VectorIndex, VectorIndexStatistics}; use crate::index::IndexConfig; @@ -63,6 +65,79 @@ use self::merge::MergeInsertBuilder; pub(crate) mod dataset; pub mod merge; +/// Defines the type of column +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ColumnKind { + /// Columns populated by data from the user (this is the most common case) + Physical, + /// Columns populated by applying an embedding function to the input + Embedding(EmbeddingDefinition), +} + +/// Defines a column in a table +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ColumnDefinition { + /// The source of the column data + pub kind: ColumnKind, +} + +#[derive(Debug, Clone)] +pub struct TableDefinition { + pub column_definitions: Vec, + pub schema: SchemaRef, +} + +impl TableDefinition { + pub fn new(schema: SchemaRef, column_definitions: Vec) -> Self { + Self { + column_definitions, + schema, + } + } + + pub fn new_from_schema(schema: SchemaRef) -> Self { + let column_definitions = schema + .fields() + .iter() + .map(|_| ColumnDefinition { + kind: ColumnKind::Physical, + }) + .collect(); + Self::new(schema, column_definitions) + } + + pub fn try_from_rich_schema(schema: SchemaRef) -> Result { + let column_definitions = schema.metadata.get("lancedb::column_definitions"); + if let Some(column_definitions) = column_definitions { + let column_definitions: Vec = + 
serde_json::from_str(column_definitions).map_err(|e| Error::Runtime { + message: format!("Failed to deserialize column definitions: {}", e), + })?; + Ok(Self::new(schema, column_definitions)) + } else { + let column_definitions = schema + .fields() + .iter() + .map(|_| ColumnDefinition { + kind: ColumnKind::Physical, + }) + .collect(); + Ok(Self::new(schema, column_definitions)) + } + } + + pub fn into_rich_schema(self) -> SchemaRef { + // We have full control over the structure of column definitions. This should + // not fail, except for a bug + let lancedb_metadata = serde_json::to_string(&self.column_definitions).unwrap(); + let mut schema_with_metadata = (*self.schema).clone(); + schema_with_metadata + .metadata + .insert("lancedb::column_definitions".to_string(), lancedb_metadata); + Arc::new(schema_with_metadata) + } +} + /// Optimize the dataset. /// /// Similar to `VACUUM` in PostgreSQL, it offers different options to @@ -132,6 +207,7 @@ pub struct AddDataBuilder { pub(crate) data: T, pub(crate) mode: AddDataMode, pub(crate) write_options: WriteOptions, + embedding_registry: Option>, } impl std::fmt::Debug for AddDataBuilder { @@ -163,6 +239,7 @@ impl AddDataBuilder { mode: self.mode, parent: self.parent, write_options: self.write_options, + embedding_registry: self.embedding_registry, }; parent.add(without_data, data).await } @@ -280,6 +357,7 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn async fn checkout(&self, version: u64) -> Result<()>; async fn checkout_latest(&self) -> Result<()>; async fn restore(&self) -> Result<()>; + async fn table_definition(&self) -> Result; } /// A Table is a collection of strong typed Rows. 
@@ -288,6 +366,7 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn #[derive(Clone)] pub struct Table { inner: Arc, + embedding_registry: Arc, } impl std::fmt::Display for Table { @@ -298,7 +377,20 @@ impl std::fmt::Display for Table { impl Table { pub(crate) fn new(inner: Arc) -> Self { - Self { inner } + Self { + inner, + embedding_registry: Arc::new(MemoryRegistry::new()), + } + } + + pub(crate) fn new_with_embedding_registry( + inner: Arc, + embedding_registry: Arc, + ) -> Self { + Self { + inner, + embedding_registry, + } } /// Cast as [`NativeTable`], or return None it if is not a [`NativeTable`]. @@ -340,6 +432,7 @@ impl Table { data: batches, mode: AddDataMode::Append, write_options: WriteOptions::default(), + embedding_registry: Some(self.embedding_registry.clone()), } } @@ -743,11 +836,10 @@ impl Table { impl From for Table { fn from(table: NativeTable) -> Self { - Self { - inner: Arc::new(table), - } + Self::new(Arc::new(table)) } } + /// A table in a LanceDB database. #[derive(Debug, Clone)] pub struct NativeTable { @@ -918,7 +1010,6 @@ impl NativeTable { Some(wrapper) => params.patch_with_store_wrapper(wrapper)?, None => params, }; - let storage_options = params .store_params .clone() @@ -1342,6 +1433,11 @@ impl TableInternal for NativeTable { Ok(Arc::new(Schema::from(&lance_schema))) } + async fn table_definition(&self) -> Result { + let schema = self.schema().await?; + TableDefinition::try_from_rich_schema(schema) + } + async fn count_rows(&self, filter: Option) -> Result { Ok(self.dataset.get().await?.count_rows(filter).await?) 
} @@ -1351,6 +1447,9 @@ impl TableInternal for NativeTable { add: AddDataBuilder, data: Box, ) -> Result<()> { + let data = + MaybeEmbedded::try_new(data, self.table_definition().await?, add.embedding_registry)?; + let mut lance_params = add.write_options.lance_write_params.unwrap_or(WriteParams { mode: match add.mode { AddDataMode::Append => WriteMode::Append, @@ -1378,8 +1477,8 @@ impl TableInternal for NativeTable { }; self.dataset.ensure_mutable().await?; - let dataset = Dataset::write(data, &self.uri, Some(lance_params)).await?; + self.dataset.set_latest(dataset).await; Ok(()) } diff --git a/rust/lancedb/tests/embedding_registry_test.rs b/rust/lancedb/tests/embedding_registry_test.rs new file mode 100644 index 00000000..36ac5dd1 --- /dev/null +++ b/rust/lancedb/tests/embedding_registry_test.rs @@ -0,0 +1,320 @@ +use std::{ + borrow::Cow, + collections::{HashMap, HashSet}, + iter::repeat, + sync::Arc, +}; + +use arrow::buffer::NullBuffer; +use arrow_array::{ + Array, FixedSizeListArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator, + StringArray, +}; +use arrow_schema::{DataType, Field, Schema}; +use futures::StreamExt; +use lancedb::{ + arrow::IntoArrow, + connect, + embeddings::{EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry}, + query::ExecutableQuery, + Error, Result, +}; + +#[tokio::test] +async fn test_custom_func() -> Result<()> { + let tempdir = tempfile::tempdir().unwrap(); + let tempdir = tempdir.path().to_str().unwrap(); + let db = connect(tempdir).execute().await?; + let embed_fun = MockEmbed::new("embed_fun".to_string(), 1); + db.embedding_registry() + .register("embed_fun", Arc::new(embed_fun.clone()))?; + + let tbl = db + .create_table("test", create_some_records()?) + .add_embedding(EmbeddingDefinition::new( + "text", + &embed_fun.name, + Some("embeddings"), + ))? 
+ .execute() + .await?; + let mut res = tbl.query().execute().await?; + while let Some(Ok(batch)) = res.next().await { + let embeddings = batch.column_by_name("embeddings"); + assert!(embeddings.is_some()); + let embeddings = embeddings.unwrap(); + assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref()); + } + // now make sure the embeddings are applied when + // we add new records too + tbl.add(create_some_records()?).execute().await?; + let mut res = tbl.query().execute().await?; + while let Some(Ok(batch)) = res.next().await { + let embeddings = batch.column_by_name("embeddings"); + assert!(embeddings.is_some()); + let embeddings = embeddings.unwrap(); + assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref()); + } + Ok(()) +} + +#[tokio::test] +async fn test_custom_registry() -> Result<()> { + let tempdir = tempfile::tempdir().unwrap(); + let tempdir = tempdir.path().to_str().unwrap(); + + let db = connect(tempdir) + .embedding_registry(Arc::new(MyRegistry::default())) + .execute() + .await?; + + let tbl = db + .create_table("test", create_some_records()?) + .add_embedding(EmbeddingDefinition::new( + "text", + "func_1", + Some("embeddings"), + ))? + .execute() + .await?; + let mut res = tbl.query().execute().await?; + while let Some(Ok(batch)) = res.next().await { + let embeddings = batch.column_by_name("embeddings"); + assert!(embeddings.is_some()); + let embeddings = embeddings.unwrap(); + assert_eq!( + embeddings.data_type(), + MockEmbed::new("func_1".to_string(), 1) + .dest_type()? 
+ .as_ref() + ); + } + Ok(()) +} + +#[tokio::test] +async fn test_multiple_embeddings() -> Result<()> { + let tempdir = tempfile::tempdir().unwrap(); + let tempdir = tempdir.path().to_str().unwrap(); + + let db = connect(tempdir).execute().await?; + let func_1 = MockEmbed::new("func_1".to_string(), 1); + let func_2 = MockEmbed::new("func_2".to_string(), 10); + db.embedding_registry() + .register(&func_1.name, Arc::new(func_1.clone()))?; + db.embedding_registry() + .register(&func_2.name, Arc::new(func_2.clone()))?; + + let tbl = db + .create_table("test", create_some_records()?) + .add_embedding(EmbeddingDefinition::new( + "text", + &func_1.name, + Some("first_embeddings"), + ))? + .add_embedding(EmbeddingDefinition::new( + "text", + &func_2.name, + Some("second_embeddings"), + ))? + .execute() + .await?; + let mut res = tbl.query().execute().await?; + while let Some(Ok(batch)) = res.next().await { + let embeddings = batch.column_by_name("first_embeddings"); + assert!(embeddings.is_some()); + let second_embeddings = batch.column_by_name("second_embeddings"); + assert!(second_embeddings.is_some()); + + let embeddings = embeddings.unwrap(); + assert_eq!(embeddings.data_type(), func_1.dest_type()?.as_ref()); + + let second_embeddings = second_embeddings.unwrap(); + assert_eq!(second_embeddings.data_type(), func_2.dest_type()?.as_ref()); + } + + // now make sure the embeddings are applied when + // we add new records too + tbl.add(create_some_records()?).execute().await?; + let mut res = tbl.query().execute().await?; + while let Some(Ok(batch)) = res.next().await { + let embeddings = batch.column_by_name("first_embeddings"); + assert!(embeddings.is_some()); + let second_embeddings = batch.column_by_name("second_embeddings"); + assert!(second_embeddings.is_some()); + + let embeddings = embeddings.unwrap(); + assert_eq!(embeddings.data_type(), func_1.dest_type()?.as_ref()); + + let second_embeddings = second_embeddings.unwrap(); + 
assert_eq!(second_embeddings.data_type(), func_2.dest_type()?.as_ref()); + } + Ok(()) +} + +#[tokio::test] +async fn test_no_func_in_registry() -> Result<()> { + let tempdir = tempfile::tempdir().unwrap(); + let tempdir = tempdir.path().to_str().unwrap(); + + let db = connect(tempdir).execute().await?; + + let res = db + .create_table("test", create_some_records()?) + .add_embedding(EmbeddingDefinition::new( + "text", + "some_func", + Some("first_embeddings"), + )); + assert!(res.is_err()); + assert!(matches!( + res.err().unwrap(), + Error::EmbeddingFunctionNotFound { .. } + )); + + Ok(()) +} + +#[tokio::test] +async fn test_no_func_in_registry_on_add() -> Result<()> { + let tempdir = tempfile::tempdir().unwrap(); + let tempdir = tempdir.path().to_str().unwrap(); + + let db = connect(tempdir).execute().await?; + db.embedding_registry().register( + "some_func", + Arc::new(MockEmbed::new("some_func".to_string(), 1)), + )?; + + db.create_table("test", create_some_records()?) + .add_embedding(EmbeddingDefinition::new( + "text", + "some_func", + Some("first_embeddings"), + ))? + .execute() + .await?; + + let db = connect(tempdir).execute().await?; + + let tbl = db.open_table("test").execute().await?; + // This should fail because 'tbl' is expecting "some_func" to be in the registry + let res = tbl.add(create_some_records()?).execute().await; + assert!(res.is_err()); + assert!(matches!( + res.unwrap_err(), + crate::Error::EmbeddingFunctionNotFound { .. } + )); + + Ok(()) +} + +fn create_some_records() -> Result { + const TOTAL: usize = 2; + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("text", DataType::Utf8, true), + ])); + + // Create a RecordBatch stream. 
+ let batches = RecordBatchIterator::new( + vec![RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)), + Arc::new(StringArray::from_iter( + repeat(Some("hello world".to_string())).take(TOTAL), + )), + ], + ) + .unwrap()] + .into_iter() + .map(Ok), + schema.clone(), + ); + Ok(Box::new(batches)) +} + +#[derive(Debug)] +struct MyRegistry { + functions: HashMap>, +} +impl Default for MyRegistry { + fn default() -> Self { + let funcs: Vec> = vec![ + Arc::new(MockEmbed::new("func_1".to_string(), 1)), + Arc::new(MockEmbed::new("func_2".to_string(), 10)), + ]; + Self { + functions: funcs + .into_iter() + .map(|f| (f.name().to_string(), f)) + .collect(), + } + } +} + +/// a mock registry that only has one function called `embed_fun` +impl EmbeddingRegistry for MyRegistry { + fn functions(&self) -> HashSet { + self.functions.keys().cloned().collect() + } + + fn register(&self, _name: &str, _function: Arc) -> Result<()> { + Err(Error::Other { + message: "MyRegistry is read-only".to_string(), + source: None, + }) + } + + fn get(&self, name: &str) -> Option> { + self.functions.get(name).cloned() + } +} + +#[derive(Debug, Clone)] +struct MockEmbed { + source_type: DataType, + dest_type: DataType, + name: String, + dim: usize, +} + +impl MockEmbed { + pub fn new(name: String, dim: usize) -> Self { + Self { + source_type: DataType::Utf8, + dest_type: DataType::new_fixed_size_list(DataType::Float32, dim as _, true), + name, + dim, + } + } +} + +impl EmbeddingFunction for MockEmbed { + fn name(&self) -> &str { + &self.name + } + fn source_type(&self) -> Result> { + Ok(Cow::Borrowed(&self.source_type)) + } + fn dest_type(&self) -> Result> { + Ok(Cow::Borrowed(&self.dest_type)) + } + fn embed(&self, source: Arc) -> Result> { + // We can't use the FixedSizeListBuilder here because it always adds a null bitmap + // and we want to explicitly work with non-nullable arrays. 
+ let len = source.len(); + let inner = Arc::new(Float32Array::from(vec![Some(1.0); len * self.dim])); + let field = Field::new("item", inner.data_type().clone(), false); + let arr = FixedSizeListArray::new( + Arc::new(field), + self.dim as _, + inner, + Some(NullBuffer::new_valid(len)), + ); + + Ok(Arc::new(arr)) + } +}