fix(rust): add embedding_registry on open_table (#2086)

# Description

Fix for: https://github.com/lancedb/lancedb/issues/1581

This is the same implementation as
https://github.com/lancedb/lancedb/pull/1781 but with the addition of a
unit test and rustfmt.
This commit is contained in:
Jeff Simpson
2025-01-31 11:48:02 -05:00
committed by GitHub
parent e05c0cd87e
commit 555fa26147
2 changed files with 47 additions and 1 deletions

View File

@@ -1081,6 +1081,7 @@ impl ConnectionInternal for Database {
async fn do_open_table(&self, mut options: OpenTableBuilder) -> Result<Table> {
let table_uri = self.table_uri(&options.name)?;
let embedding_registry = self.embedding_registry.clone();
// Inherit storage options from the connection
let storage_options = options
@@ -1117,7 +1118,10 @@ impl ConnectionInternal for Database {
)
.await?,
);
Ok(Table::new(native_table))
Ok(Table::new_with_embedding_registry(
native_table,
embedding_registry,
))
}
async fn rename_table(&self, _old_name: &str, _new_name: &str) -> Result<()> {

View File

@@ -155,6 +155,48 @@ async fn test_multiple_embeddings() -> Result<()> {
Ok(())
}
#[tokio::test]
async fn test_open_table_embeddings() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();
let tempdir = tempdir.path().to_str().unwrap();
let db = connect(tempdir).execute().await?;
let embed_fun = MockEmbed::new("embed_fun".to_string(), 1);
db.embedding_registry()
.register("embed_fun", Arc::new(embed_fun.clone()))?;
db.create_table("test", create_some_records()?)
.add_embedding(EmbeddingDefinition::new(
"text",
&embed_fun.name,
Some("embeddings"),
))?
.execute()
.await?;
// now open the table and check the embeddings
let tbl = db.open_table("test").execute().await?;
let mut res = tbl.query().execute().await?;
while let Some(Ok(batch)) = res.next().await {
let embeddings = batch.column_by_name("embeddings");
assert!(embeddings.is_some());
let embeddings = embeddings.unwrap();
assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref());
}
// now make sure the embeddings are applied when
// we add new records too
tbl.add(create_some_records()?).execute().await?;
let mut res = tbl.query().execute().await?;
while let Some(Ok(batch)) = res.next().await {
let embeddings = batch.column_by_name("embeddings");
assert!(embeddings.is_some());
let embeddings = embeddings.unwrap();
assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref());
}
Ok(())
}
#[tokio::test]
async fn test_no_func_in_registry() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();