feat: rust embedding registry (#1259)

Todo:

- [x] add proper documentation
- [x] add unit tests
- [x] better handling of the registry**1
- [x] allow user defined registry**2

**1 The python implementation just uses a global registry so it makes
things a bit easier. I attached it to the db/connection to prevent
future conflicts if running multiple connections/databases. I mostly
modeled the registry & pattern off of datafusion's
[FunctionRegistry](https://docs.rs/datafusion/latest/datafusion/execution/trait.FunctionRegistry.html).

**2 Ideally, the user should be able to provide it's own registry
entirely, but currently it just uses an in memory registry by default
(_which isn't configurable_)

`rust/lancedb/examples/embedding_registry.rs` provides a thorough
example of expected usage.

---

Some additional notes:

This does not provide any of the out of box functionality that the
python registry does.

_i.e there are no built-in embedding functions._ 

You can think of this as the ground work for adding those built in
functions, So while this is part of
https://github.com/lancedb/lancedb/issues/994, it does not yet offer
feature parity.
This commit is contained in:
Cory Grinstead
2024-05-06 18:39:07 -05:00
committed by GitHub
parent fdb5d6fdf1
commit 9d2fb7d602
10 changed files with 838 additions and 18 deletions

2
.gitignore vendored
View File

@@ -6,7 +6,7 @@
venv
.vscode
.zed
rust/target
rust/Cargo.lock

View File

@@ -116,12 +116,19 @@ pub trait IntoArrow {
fn into_arrow(self) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>>;
}
pub type BoxedRecordBatchReader = Box<dyn arrow_array::RecordBatchReader + Send>;
impl<T: arrow_array::RecordBatchReader + Send + 'static> IntoArrow for T {
fn into_arrow(self) -> Result<Box<dyn arrow_array::RecordBatchReader + Send>> {
Ok(Box::new(self))
}
}
impl<S: Stream<Item = Result<arrow_array::RecordBatch>>> SimpleRecordBatchStream<S> {
pub fn new(stream: S, schema: Arc<arrow_schema::Schema>) -> Self {
Self { schema, stream }
}
}
#[cfg(feature = "polars")]
/// An iterator of record batches formed from a Polars DataFrame.
pub struct PolarsDataFrameRecordBatchReader {

View File

@@ -27,9 +27,12 @@ use object_store::{aws::AwsCredential, local::LocalFileSystem};
use snafu::prelude::*;
use crate::arrow::IntoArrow;
use crate::embeddings::{
EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, MemoryRegistry, WithEmbeddings,
};
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
use crate::io::object_store::MirroringObjectStoreWrapper;
use crate::table::{NativeTable, WriteOptions};
use crate::table::{NativeTable, TableDefinition, WriteOptions};
use crate::utils::validate_table_name;
use crate::Table;
@@ -133,9 +136,10 @@ pub struct CreateTableBuilder<const HAS_DATA: bool, T: IntoArrow> {
parent: Arc<dyn ConnectionInternal>,
pub(crate) name: String,
pub(crate) data: Option<T>,
pub(crate) schema: Option<SchemaRef>,
pub(crate) mode: CreateTableMode,
pub(crate) write_options: WriteOptions,
pub(crate) table_definition: Option<TableDefinition>,
pub(crate) embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
}
// Builder methods that only apply when we have initial data
@@ -145,9 +149,10 @@ impl<T: IntoArrow> CreateTableBuilder<true, T> {
parent,
name,
data: Some(data),
schema: None,
mode: CreateTableMode::default(),
write_options: WriteOptions::default(),
table_definition: None,
embeddings: Vec::new(),
}
}
@@ -175,24 +180,43 @@ impl<T: IntoArrow> CreateTableBuilder<true, T> {
parent: self.parent,
name: self.name,
data: None,
schema: self.schema,
table_definition: self.table_definition,
mode: self.mode,
write_options: self.write_options,
embeddings: self.embeddings,
};
Ok((data, builder))
}
pub fn add_embedding(mut self, definition: EmbeddingDefinition) -> Result<Self> {
// Early verification of the embedding name
let embedding_func = self
.parent
.embedding_registry()
.get(&definition.embedding_name)
.ok_or_else(|| Error::EmbeddingFunctionNotFound {
name: definition.embedding_name.to_string(),
reason: "No embedding function found in the connection's embedding_registry"
.to_string(),
})?;
self.embeddings.push((definition, embedding_func));
Ok(self)
}
}
// Builder methods that only apply when we do not have initial data
impl CreateTableBuilder<false, NoData> {
fn new(parent: Arc<dyn ConnectionInternal>, name: String, schema: SchemaRef) -> Self {
let table_definition = TableDefinition::new_from_schema(schema);
Self {
parent,
name,
data: None,
schema: Some(schema),
table_definition: Some(table_definition),
mode: CreateTableMode::default(),
write_options: WriteOptions::default(),
embeddings: Vec::new(),
}
}
@@ -350,6 +374,7 @@ impl OpenTableBuilder {
pub(crate) trait ConnectionInternal:
Send + Sync + std::fmt::Debug + std::fmt::Display + 'static
{
fn embedding_registry(&self) -> &dyn EmbeddingRegistry;
async fn table_names(&self, options: TableNamesBuilder) -> Result<Vec<String>>;
async fn do_create_table(
&self,
@@ -366,7 +391,7 @@ pub(crate) trait ConnectionInternal:
) -> Result<Table> {
let batches = Box::new(RecordBatchIterator::new(
vec![],
options.schema.as_ref().unwrap().clone(),
options.table_definition.clone().unwrap().schema.clone(),
));
self.do_create_table(options, batches).await
}
@@ -453,6 +478,13 @@ impl Connection {
pub async fn drop_db(&self) -> Result<()> {
self.internal.drop_db().await
}
/// Get the in-memory embedding registry.
/// It's important to note that the embedding registry is not persisted across connections.
/// So if a table contains embeddings, you will need to make sure that you are using a connection that has the same embedding functions registered
pub fn embedding_registry(&self) -> &dyn EmbeddingRegistry {
self.internal.embedding_registry()
}
}
#[derive(Debug)]
@@ -486,6 +518,7 @@ pub struct ConnectBuilder {
/// consistency only applies to read operations. Write operations are
/// always consistent.
read_consistency_interval: Option<std::time::Duration>,
embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
}
impl ConnectBuilder {
@@ -498,6 +531,7 @@ impl ConnectBuilder {
host_override: None,
read_consistency_interval: None,
storage_options: HashMap::new(),
embedding_registry: None,
}
}
@@ -516,6 +550,12 @@ impl ConnectBuilder {
self
}
/// Provide a custom [`EmbeddingRegistry`] to use for this connection.
pub fn embedding_registry(mut self, registry: Arc<dyn EmbeddingRegistry>) -> Self {
self.embedding_registry = Some(registry);
self
}
/// [`AwsCredential`] to use when connecting to S3.
#[deprecated(note = "Pass through storage_options instead")]
pub fn aws_creds(mut self, aws_creds: AwsCredential) -> Self {
@@ -642,6 +682,7 @@ struct Database {
// Storage options to be inherited by tables created from this connection
storage_options: HashMap<String, String>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
}
impl std::fmt::Display for Database {
@@ -675,7 +716,12 @@ impl Database {
// TODO: pass params regardless of OS
match parse_res {
Ok(url) if url.scheme().len() == 1 && cfg!(windows) => {
Self::open_path(uri, options.read_consistency_interval).await
Self::open_path(
uri,
options.read_consistency_interval,
options.embedding_registry.clone(),
)
.await
}
Ok(mut url) => {
// iter thru the query params and extract the commit store param
@@ -745,6 +791,10 @@ impl Database {
None => None,
};
let embedding_registry = options
.embedding_registry
.clone()
.unwrap_or_else(|| Arc::new(MemoryRegistry::new()));
Ok(Self {
uri: table_base_uri,
query_string,
@@ -753,20 +803,33 @@ impl Database {
store_wrapper: write_store_wrapper,
read_consistency_interval: options.read_consistency_interval,
storage_options,
embedding_registry,
})
}
Err(_) => Self::open_path(uri, options.read_consistency_interval).await,
Err(_) => {
Self::open_path(
uri,
options.read_consistency_interval,
options.embedding_registry.clone(),
)
.await
}
}
}
async fn open_path(
path: &str,
read_consistency_interval: Option<std::time::Duration>,
embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
) -> Result<Self> {
let (object_store, base_path) = ObjectStore::from_uri(path).await?;
if object_store.is_local() {
Self::try_create_dir(path).context(CreateDirSnafu { path })?;
}
let embedding_registry =
embedding_registry.unwrap_or_else(|| Arc::new(MemoryRegistry::new()));
Ok(Self {
uri: path.to_string(),
query_string: None,
@@ -775,6 +838,7 @@ impl Database {
store_wrapper: None,
read_consistency_interval,
storage_options: HashMap::new(),
embedding_registry,
})
}
@@ -815,6 +879,9 @@ impl Database {
#[async_trait::async_trait]
impl ConnectionInternal for Database {
fn embedding_registry(&self) -> &dyn EmbeddingRegistry {
self.embedding_registry.as_ref()
}
async fn table_names(&self, options: TableNamesBuilder) -> Result<Vec<String>> {
let mut f = self
.object_store
@@ -851,7 +918,7 @@ impl ConnectionInternal for Database {
data: Box<dyn RecordBatchReader + Send>,
) -> Result<Table> {
let table_uri = self.table_uri(&options.name)?;
let embedding_registry = self.embedding_registry.clone();
// Inherit storage options from the connection
let storage_options = options
.write_options
@@ -866,6 +933,11 @@ impl ConnectionInternal for Database {
storage_options.insert(key.clone(), value.clone());
}
}
let data = if options.embeddings.is_empty() {
data
} else {
Box::new(WithEmbeddings::new(data, options.embeddings))
};
let mut write_params = options.write_options.lance_write_params.unwrap_or_default();
if matches!(&options.mode, CreateTableMode::Overwrite) {
@@ -882,7 +954,10 @@ impl ConnectionInternal for Database {
)
.await
{
Ok(table) => Ok(Table::new(Arc::new(table))),
Ok(table) => Ok(Table::new_with_embedding_registry(
Arc::new(table),
embedding_registry,
)),
Err(Error::TableAlreadyExists { name }) => match options.mode {
CreateTableMode::Create => Err(Error::TableAlreadyExists { name }),
CreateTableMode::ExistOk(callback) => {

View File

@@ -0,0 +1,307 @@
// Copyright 2024 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use lance::arrow::RecordBatchExt;
use std::{
borrow::Cow,
collections::{HashMap, HashSet},
sync::{Arc, RwLock},
};
use arrow_array::{Array, RecordBatch, RecordBatchReader};
use arrow_schema::{DataType, Field, SchemaBuilder};
// use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::{
error::Result,
table::{ColumnDefinition, ColumnKind, TableDefinition},
Error,
};
/// Trait for embedding functions
///
/// An embedding function is a function that is applied to a column of input data
/// to produce an "embedding" of that input. This embedding is then stored in the
/// database alongside (or instead of) the original input.
///
/// An "embedding" is often a lower-dimensional representation of the input data.
/// For example, sentence-transformers can be used to embed sentences into a 768-dimensional
/// vector space. This is useful for tasks like similarity search, where we want to find
/// similar sentences to a query sentence.
///
/// To use an embedding function you must first register it with the `EmbeddingsRegistry`.
/// Then you can define it on a column in the table schema. That embedding will then be used
/// to embed the data in that column.
pub trait EmbeddingFunction: std::fmt::Debug + Send + Sync {
fn name(&self) -> &str;
/// The type of the input data
fn source_type(&self) -> Result<Cow<DataType>>;
/// The type of the output data
/// This should **always** match the output of the `embed` function
fn dest_type(&self) -> Result<Cow<DataType>>;
/// Embed the input
fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>>;
}
/// Defines an embedding from input data into a lower-dimensional space
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct EmbeddingDefinition {
/// The name of the column in the input data
pub source_column: String,
/// The name of the embedding column, if not specified
/// it will be the source column with `_embedding` appended
pub dest_column: Option<String>,
/// The name of the embedding function to apply
pub embedding_name: String,
}
impl EmbeddingDefinition {
pub fn new<S: Into<String>>(source_column: S, embedding_name: S, dest: Option<S>) -> Self {
Self {
source_column: source_column.into(),
dest_column: dest.map(|d| d.into()),
embedding_name: embedding_name.into(),
}
}
}
/// A registry of embedding
pub trait EmbeddingRegistry: Send + Sync + std::fmt::Debug {
/// Return the names of all registered embedding functions
fn functions(&self) -> HashSet<String>;
/// Register a new [`EmbeddingFunction
/// Returns an error if the function can not be registered
fn register(&self, name: &str, function: Arc<dyn EmbeddingFunction>) -> Result<()>;
/// Get an embedding function by name
fn get(&self, name: &str) -> Option<Arc<dyn EmbeddingFunction>>;
}
/// A [`EmbeddingRegistry`] that uses in-memory [`HashMap`]s
#[derive(Debug, Default, Clone)]
pub struct MemoryRegistry {
functions: Arc<RwLock<HashMap<String, Arc<dyn EmbeddingFunction>>>>,
}
impl EmbeddingRegistry for MemoryRegistry {
fn functions(&self) -> HashSet<String> {
self.functions.read().unwrap().keys().cloned().collect()
}
fn register(&self, name: &str, function: Arc<dyn EmbeddingFunction>) -> Result<()> {
self.functions
.write()
.unwrap()
.insert(name.to_string(), function);
Ok(())
}
fn get(&self, name: &str) -> Option<Arc<dyn EmbeddingFunction>> {
self.functions.read().unwrap().get(name).cloned()
}
}
impl MemoryRegistry {
/// Create a new `MemoryRegistry`
pub fn new() -> Self {
Self::default()
}
}
/// A record batch reader that has embeddings applied to it
/// This is a wrapper around another record batch reader that applies an embedding function
/// when reading from the record batch
pub struct WithEmbeddings<R: RecordBatchReader> {
inner: R,
embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
}
/// A record batch that might have embeddings applied to it.
pub enum MaybeEmbedded<R: RecordBatchReader> {
/// The record batch reader has embeddings applied to it
Yes(WithEmbeddings<R>),
/// The record batch reader does not have embeddings applied to it
/// The inner record batch reader is returned as-is
No(R),
}
impl<R: RecordBatchReader> MaybeEmbedded<R> {
/// Create a new RecordBatchReader with embeddings applied to it if the table definition
/// specifies an embedding column and the registry contains an embedding function with that name
/// Otherwise, this is a no-op and the inner RecordBatchReader is returned.
pub fn try_new(
inner: R,
table_definition: TableDefinition,
registry: Option<Arc<dyn EmbeddingRegistry>>,
) -> Result<Self> {
if let Some(registry) = registry {
let mut embeddings = Vec::with_capacity(table_definition.column_definitions.len());
for cd in table_definition.column_definitions.iter() {
if let ColumnKind::Embedding(embedding_def) = &cd.kind {
match registry.get(&embedding_def.embedding_name) {
Some(func) => {
embeddings.push((embedding_def.clone(), func));
}
None => {
return Err(Error::EmbeddingFunctionNotFound {
name: embedding_def.embedding_name.to_string(),
reason: format!(
"Table was defined with an embedding column `{}` but no embedding function was found with that name within the registry.",
embedding_def.embedding_name
),
});
}
}
}
}
if !embeddings.is_empty() {
return Ok(Self::Yes(WithEmbeddings { inner, embeddings }));
}
};
// No embeddings to apply
Ok(Self::No(inner))
}
}
impl<R: RecordBatchReader> WithEmbeddings<R> {
pub fn new(
inner: R,
embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
) -> Self {
Self { inner, embeddings }
}
}
impl<R: RecordBatchReader> WithEmbeddings<R> {
fn dest_fields(&self) -> Result<Vec<Field>> {
let schema = self.inner.schema();
self.embeddings
.iter()
.map(|(ed, func)| {
let src_field = schema.field_with_name(&ed.source_column).unwrap();
let field_name = ed
.dest_column
.clone()
.unwrap_or_else(|| format!("{}_embedding", &ed.source_column));
Ok(Field::new(
field_name,
func.dest_type()?.into_owned(),
src_field.is_nullable(),
))
})
.collect()
}
fn column_defs(&self) -> Vec<ColumnDefinition> {
let base_schema = self.inner.schema();
base_schema
.fields()
.iter()
.map(|_| ColumnDefinition {
kind: ColumnKind::Physical,
})
.chain(self.embeddings.iter().map(|(ed, _)| ColumnDefinition {
kind: ColumnKind::Embedding(ed.clone()),
}))
.collect::<Vec<_>>()
}
pub fn table_definition(&self) -> Result<TableDefinition> {
let base_schema = self.inner.schema();
let output_fields = self.dest_fields()?;
let column_definitions = self.column_defs();
let mut sb: SchemaBuilder = base_schema.as_ref().into();
sb.extend(output_fields);
let schema = Arc::new(sb.finish());
Ok(TableDefinition {
schema,
column_definitions,
})
}
}
impl<R: RecordBatchReader> Iterator for MaybeEmbedded<R> {
type Item = std::result::Result<RecordBatch, arrow_schema::ArrowError>;
fn next(&mut self) -> Option<Self::Item> {
match self {
Self::Yes(inner) => inner.next(),
Self::No(inner) => inner.next(),
}
}
}
impl<R: RecordBatchReader> RecordBatchReader for MaybeEmbedded<R> {
fn schema(&self) -> Arc<arrow_schema::Schema> {
match self {
Self::Yes(inner) => inner.schema(),
Self::No(inner) => inner.schema(),
}
}
}
impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
type Item = std::result::Result<RecordBatch, arrow_schema::ArrowError>;
fn next(&mut self) -> Option<Self::Item> {
let batch = self.inner.next()?;
match batch {
Ok(mut batch) => {
// todo: parallelize this
for (fld, func) in self.embeddings.iter() {
let src_column = batch.column_by_name(&fld.source_column).unwrap();
let embedding = match func.embed(src_column.clone()) {
Ok(embedding) => embedding,
Err(e) => {
return Some(Err(arrow_schema::ArrowError::ComputeError(format!(
"Error computing embedding: {}",
e
))))
}
};
let dst_field_name = fld
.dest_column
.clone()
.unwrap_or_else(|| format!("{}_embedding", &fld.source_column));
let dst_field = Field::new(
dst_field_name,
embedding.data_type().clone(),
embedding.nulls().is_some(),
);
match batch.try_with_column(dst_field.clone(), embedding) {
Ok(b) => batch = b,
Err(e) => return Some(Err(e)),
};
}
Some(Ok(batch))
}
Err(e) => Some(Err(e)),
}
}
}
impl<R: RecordBatchReader> RecordBatchReader for WithEmbeddings<R> {
fn schema(&self) -> Arc<arrow_schema::Schema> {
self.table_definition()
.expect("table definition should be infallible at this point")
.into_rich_schema()
}
}

View File

@@ -26,6 +26,9 @@ pub enum Error {
InvalidInput { message: String },
#[snafu(display("Table '{name}' was not found"))]
TableNotFound { name: String },
#[snafu(display("Embedding function '{name}' was not found. : {reason}"))]
EmbeddingFunctionNotFound { name: String, reason: String },
#[snafu(display("Table '{name}' already exists"))]
TableAlreadyExists { name: String },
#[snafu(display("Unable to created lance dataset at {path}: {source}"))]

View File

@@ -194,6 +194,7 @@
pub mod arrow;
pub mod connection;
pub mod data;
pub mod embeddings;
pub mod error;
pub mod index;
pub mod io;

View File

@@ -23,6 +23,7 @@ use tokio::task::spawn_blocking;
use crate::connection::{
ConnectionInternal, CreateTableBuilder, NoData, OpenTableBuilder, TableNamesBuilder,
};
use crate::embeddings::EmbeddingRegistry;
use crate::error::Result;
use crate::Table;
@@ -115,4 +116,8 @@ impl ConnectionInternal for RemoteDatabase {
async fn drop_db(&self) -> Result<()> {
todo!()
}
fn embedding_registry(&self) -> &dyn EmbeddingRegistry {
todo!()
}
}

View File

@@ -10,7 +10,7 @@ use crate::{
query::{Query, QueryExecutionOptions, VectorQuery},
table::{
merge::MergeInsertBuilder, AddDataBuilder, NativeTable, OptimizeAction, OptimizeStats,
TableInternal, UpdateBuilder,
TableDefinition, TableInternal, UpdateBuilder,
},
};
@@ -120,4 +120,7 @@ impl TableInternal for RemoteTable {
async fn list_indices(&self) -> Result<Vec<IndexConfig>> {
todo!()
}
async fn table_definition(&self) -> Result<TableDefinition> {
todo!()
}
}

View File

@@ -41,10 +41,12 @@ use lance::io::WrappingObjectStore;
use lance_index::IndexType;
use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
use log::info;
use serde::{Deserialize, Serialize};
use snafu::whatever;
use crate::arrow::IntoArrow;
use crate::connection::NoData;
use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry};
use crate::error::{Error, Result};
use crate::index::vector::{IvfPqIndexBuilder, VectorIndex, VectorIndexStatistics};
use crate::index::IndexConfig;
@@ -63,6 +65,79 @@ use self::merge::MergeInsertBuilder;
pub(crate) mod dataset;
pub mod merge;
/// Defines the type of column
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ColumnKind {
/// Columns populated by data from the user (this is the most common case)
Physical,
/// Columns populated by applying an embedding function to the input
Embedding(EmbeddingDefinition),
}
/// Defines a column in a table
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnDefinition {
/// The source of the column data
pub kind: ColumnKind,
}
#[derive(Debug, Clone)]
pub struct TableDefinition {
pub column_definitions: Vec<ColumnDefinition>,
pub schema: SchemaRef,
}
impl TableDefinition {
pub fn new(schema: SchemaRef, column_definitions: Vec<ColumnDefinition>) -> Self {
Self {
column_definitions,
schema,
}
}
pub fn new_from_schema(schema: SchemaRef) -> Self {
let column_definitions = schema
.fields()
.iter()
.map(|_| ColumnDefinition {
kind: ColumnKind::Physical,
})
.collect();
Self::new(schema, column_definitions)
}
pub fn try_from_rich_schema(schema: SchemaRef) -> Result<Self> {
let column_definitions = schema.metadata.get("lancedb::column_definitions");
if let Some(column_definitions) = column_definitions {
let column_definitions: Vec<ColumnDefinition> =
serde_json::from_str(column_definitions).map_err(|e| Error::Runtime {
message: format!("Failed to deserialize column definitions: {}", e),
})?;
Ok(Self::new(schema, column_definitions))
} else {
let column_definitions = schema
.fields()
.iter()
.map(|_| ColumnDefinition {
kind: ColumnKind::Physical,
})
.collect();
Ok(Self::new(schema, column_definitions))
}
}
pub fn into_rich_schema(self) -> SchemaRef {
// We have full control over the structure of column definitions. This should
// not fail, except for a bug
let lancedb_metadata = serde_json::to_string(&self.column_definitions).unwrap();
let mut schema_with_metadata = (*self.schema).clone();
schema_with_metadata
.metadata
.insert("lancedb::column_definitions".to_string(), lancedb_metadata);
Arc::new(schema_with_metadata)
}
}
/// Optimize the dataset.
///
/// Similar to `VACUUM` in PostgreSQL, it offers different options to
@@ -132,6 +207,7 @@ pub struct AddDataBuilder<T: IntoArrow> {
pub(crate) data: T,
pub(crate) mode: AddDataMode,
pub(crate) write_options: WriteOptions,
embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
}
impl<T: IntoArrow> std::fmt::Debug for AddDataBuilder<T> {
@@ -163,6 +239,7 @@ impl<T: IntoArrow> AddDataBuilder<T> {
mode: self.mode,
parent: self.parent,
write_options: self.write_options,
embedding_registry: self.embedding_registry,
};
parent.add(without_data, data).await
}
@@ -280,6 +357,7 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
async fn checkout(&self, version: u64) -> Result<()>;
async fn checkout_latest(&self) -> Result<()>;
async fn restore(&self) -> Result<()>;
async fn table_definition(&self) -> Result<TableDefinition>;
}
/// A Table is a collection of strong typed Rows.
@@ -288,6 +366,7 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
#[derive(Clone)]
pub struct Table {
inner: Arc<dyn TableInternal>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
}
impl std::fmt::Display for Table {
@@ -298,7 +377,20 @@ impl std::fmt::Display for Table {
impl Table {
pub(crate) fn new(inner: Arc<dyn TableInternal>) -> Self {
Self { inner }
Self {
inner,
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
pub(crate) fn new_with_embedding_registry(
inner: Arc<dyn TableInternal>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
) -> Self {
Self {
inner,
embedding_registry,
}
}
/// Cast as [`NativeTable`], or return None it if is not a [`NativeTable`].
@@ -340,6 +432,7 @@ impl Table {
data: batches,
mode: AddDataMode::Append,
write_options: WriteOptions::default(),
embedding_registry: Some(self.embedding_registry.clone()),
}
}
@@ -743,11 +836,10 @@ impl Table {
impl From<NativeTable> for Table {
fn from(table: NativeTable) -> Self {
Self {
inner: Arc::new(table),
}
Self::new(Arc::new(table))
}
}
/// A table in a LanceDB database.
#[derive(Debug, Clone)]
pub struct NativeTable {
@@ -918,7 +1010,6 @@ impl NativeTable {
Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
None => params,
};
let storage_options = params
.store_params
.clone()
@@ -1342,6 +1433,11 @@ impl TableInternal for NativeTable {
Ok(Arc::new(Schema::from(&lance_schema)))
}
async fn table_definition(&self) -> Result<TableDefinition> {
let schema = self.schema().await?;
TableDefinition::try_from_rich_schema(schema)
}
async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
Ok(self.dataset.get().await?.count_rows(filter).await?)
}
@@ -1351,6 +1447,9 @@ impl TableInternal for NativeTable {
add: AddDataBuilder<NoData>,
data: Box<dyn RecordBatchReader + Send>,
) -> Result<()> {
let data =
MaybeEmbedded::try_new(data, self.table_definition().await?, add.embedding_registry)?;
let mut lance_params = add.write_options.lance_write_params.unwrap_or(WriteParams {
mode: match add.mode {
AddDataMode::Append => WriteMode::Append,
@@ -1378,8 +1477,8 @@ impl TableInternal for NativeTable {
};
self.dataset.ensure_mutable().await?;
let dataset = Dataset::write(data, &self.uri, Some(lance_params)).await?;
self.dataset.set_latest(dataset).await;
Ok(())
}

View File

@@ -0,0 +1,320 @@
use std::{
borrow::Cow,
collections::{HashMap, HashSet},
iter::repeat,
sync::Arc,
};
use arrow::buffer::NullBuffer;
use arrow_array::{
Array, FixedSizeListArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use futures::StreamExt;
use lancedb::{
arrow::IntoArrow,
connect,
embeddings::{EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry},
query::ExecutableQuery,
Error, Result,
};
#[tokio::test]
async fn test_custom_func() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();
let tempdir = tempdir.path().to_str().unwrap();
let db = connect(tempdir).execute().await?;
let embed_fun = MockEmbed::new("embed_fun".to_string(), 1);
db.embedding_registry()
.register("embed_fun", Arc::new(embed_fun.clone()))?;
let tbl = db
.create_table("test", create_some_records()?)
.add_embedding(EmbeddingDefinition::new(
"text",
&embed_fun.name,
Some("embeddings"),
))?
.execute()
.await?;
let mut res = tbl.query().execute().await?;
while let Some(Ok(batch)) = res.next().await {
let embeddings = batch.column_by_name("embeddings");
assert!(embeddings.is_some());
let embeddings = embeddings.unwrap();
assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref());
}
// now make sure the embeddings are applied when
// we add new records too
tbl.add(create_some_records()?).execute().await?;
let mut res = tbl.query().execute().await?;
while let Some(Ok(batch)) = res.next().await {
let embeddings = batch.column_by_name("embeddings");
assert!(embeddings.is_some());
let embeddings = embeddings.unwrap();
assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref());
}
Ok(())
}
#[tokio::test]
async fn test_custom_registry() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();
let tempdir = tempdir.path().to_str().unwrap();
let db = connect(tempdir)
.embedding_registry(Arc::new(MyRegistry::default()))
.execute()
.await?;
let tbl = db
.create_table("test", create_some_records()?)
.add_embedding(EmbeddingDefinition::new(
"text",
"func_1",
Some("embeddings"),
))?
.execute()
.await?;
let mut res = tbl.query().execute().await?;
while let Some(Ok(batch)) = res.next().await {
let embeddings = batch.column_by_name("embeddings");
assert!(embeddings.is_some());
let embeddings = embeddings.unwrap();
assert_eq!(
embeddings.data_type(),
MockEmbed::new("func_1".to_string(), 1)
.dest_type()?
.as_ref()
);
}
Ok(())
}
#[tokio::test]
async fn test_multiple_embeddings() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();
let tempdir = tempdir.path().to_str().unwrap();
let db = connect(tempdir).execute().await?;
let func_1 = MockEmbed::new("func_1".to_string(), 1);
let func_2 = MockEmbed::new("func_2".to_string(), 10);
db.embedding_registry()
.register(&func_1.name, Arc::new(func_1.clone()))?;
db.embedding_registry()
.register(&func_2.name, Arc::new(func_2.clone()))?;
let tbl = db
.create_table("test", create_some_records()?)
.add_embedding(EmbeddingDefinition::new(
"text",
&func_1.name,
Some("first_embeddings"),
))?
.add_embedding(EmbeddingDefinition::new(
"text",
&func_2.name,
Some("second_embeddings"),
))?
.execute()
.await?;
let mut res = tbl.query().execute().await?;
while let Some(Ok(batch)) = res.next().await {
let embeddings = batch.column_by_name("first_embeddings");
assert!(embeddings.is_some());
let second_embeddings = batch.column_by_name("second_embeddings");
assert!(second_embeddings.is_some());
let embeddings = embeddings.unwrap();
assert_eq!(embeddings.data_type(), func_1.dest_type()?.as_ref());
let second_embeddings = second_embeddings.unwrap();
assert_eq!(second_embeddings.data_type(), func_2.dest_type()?.as_ref());
}
// now make sure the embeddings are applied when
// we add new records too
tbl.add(create_some_records()?).execute().await?;
let mut res = tbl.query().execute().await?;
while let Some(Ok(batch)) = res.next().await {
let embeddings = batch.column_by_name("first_embeddings");
assert!(embeddings.is_some());
let second_embeddings = batch.column_by_name("second_embeddings");
assert!(second_embeddings.is_some());
let embeddings = embeddings.unwrap();
assert_eq!(embeddings.data_type(), func_1.dest_type()?.as_ref());
let second_embeddings = second_embeddings.unwrap();
assert_eq!(second_embeddings.data_type(), func_2.dest_type()?.as_ref());
}
Ok(())
}
#[tokio::test]
async fn test_no_func_in_registry() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();
let tempdir = tempdir.path().to_str().unwrap();
let db = connect(tempdir).execute().await?;
let res = db
.create_table("test", create_some_records()?)
.add_embedding(EmbeddingDefinition::new(
"text",
"some_func",
Some("first_embeddings"),
));
assert!(res.is_err());
assert!(matches!(
res.err().unwrap(),
Error::EmbeddingFunctionNotFound { .. }
));
Ok(())
}
#[tokio::test]
async fn test_no_func_in_registry_on_add() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();
let tempdir = tempdir.path().to_str().unwrap();
let db = connect(tempdir).execute().await?;
db.embedding_registry().register(
"some_func",
Arc::new(MockEmbed::new("some_func".to_string(), 1)),
)?;
db.create_table("test", create_some_records()?)
.add_embedding(EmbeddingDefinition::new(
"text",
"some_func",
Some("first_embeddings"),
))?
.execute()
.await?;
let db = connect(tempdir).execute().await?;
let tbl = db.open_table("test").execute().await?;
// This should fail because 'tbl' is expecting "some_func" to be in the registry
let res = tbl.add(create_some_records()?).execute().await;
assert!(res.is_err());
assert!(matches!(
res.unwrap_err(),
crate::Error::EmbeddingFunctionNotFound { .. }
));
Ok(())
}
fn create_some_records() -> Result<impl IntoArrow> {
const TOTAL: usize = 2;
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("text", DataType::Utf8, true),
]));
// Create a RecordBatch stream.
let batches = RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
Arc::new(StringArray::from_iter(
repeat(Some("hello world".to_string())).take(TOTAL),
)),
],
)
.unwrap()]
.into_iter()
.map(Ok),
schema.clone(),
);
Ok(Box::new(batches))
}
#[derive(Debug)]
struct MyRegistry {
functions: HashMap<String, Arc<dyn EmbeddingFunction>>,
}
impl Default for MyRegistry {
fn default() -> Self {
let funcs: Vec<Arc<dyn EmbeddingFunction>> = vec![
Arc::new(MockEmbed::new("func_1".to_string(), 1)),
Arc::new(MockEmbed::new("func_2".to_string(), 10)),
];
Self {
functions: funcs
.into_iter()
.map(|f| (f.name().to_string(), f))
.collect(),
}
}
}
/// a mock registry that only has one function called `embed_fun`
impl EmbeddingRegistry for MyRegistry {
fn functions(&self) -> HashSet<String> {
self.functions.keys().cloned().collect()
}
fn register(&self, _name: &str, _function: Arc<dyn EmbeddingFunction>) -> Result<()> {
Err(Error::Other {
message: "MyRegistry is read-only".to_string(),
source: None,
})
}
fn get(&self, name: &str) -> Option<Arc<dyn EmbeddingFunction>> {
self.functions.get(name).cloned()
}
}
#[derive(Debug, Clone)]
struct MockEmbed {
source_type: DataType,
dest_type: DataType,
name: String,
dim: usize,
}
impl MockEmbed {
pub fn new(name: String, dim: usize) -> Self {
Self {
source_type: DataType::Utf8,
dest_type: DataType::new_fixed_size_list(DataType::Float32, dim as _, true),
name,
dim,
}
}
}
impl EmbeddingFunction for MockEmbed {
fn name(&self) -> &str {
&self.name
}
fn source_type(&self) -> Result<Cow<DataType>> {
Ok(Cow::Borrowed(&self.source_type))
}
fn dest_type(&self) -> Result<Cow<DataType>> {
Ok(Cow::Borrowed(&self.dest_type))
}
fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
// We can't use the FixedSizeListBuilder here because it always adds a null bitmap
// and we want to explicitly work with non-nullable arrays.
let len = source.len();
let inner = Arc::new(Float32Array::from(vec![Some(1.0); len * self.dim]));
let field = Field::new("item", inner.data_type().clone(), false);
let arr = FixedSizeListArray::new(
Arc::new(field),
self.dim as _,
inner,
Some(NullBuffer::new_valid(len)),
);
Ok(Arc::new(arr))
}
}