From 555fa261477744bae06957d5d1a0121395fc5ed8 Mon Sep 17 00:00:00 2001 From: Jeff Simpson Date: Fri, 31 Jan 2025 11:48:02 -0500 Subject: [PATCH] fix(rust): add embedding_registry on open_table (#2086) # Description Fix for: https://github.com/lancedb/lancedb/issues/1581 This is the same implementation as https://github.com/lancedb/lancedb/pull/1781 but with the addition of a unit test and rustfmt. --- rust/lancedb/src/connection.rs | 6 ++- rust/lancedb/tests/embedding_registry_test.rs | 42 +++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index c03f3c423..6100abfca 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -1081,6 +1081,7 @@ impl ConnectionInternal for Database { async fn do_open_table(&self, mut options: OpenTableBuilder) -> Result { let table_uri = self.table_uri(&options.name)?; + let embedding_registry = self.embedding_registry.clone(); // Inherit storage options from the connection let storage_options = options @@ -1117,7 +1118,10 @@ impl ConnectionInternal for Database { ) .await?, ); - Ok(Table::new(native_table)) + Ok(Table::new_with_embedding_registry( + native_table, + embedding_registry, + )) } async fn rename_table(&self, _old_name: &str, _new_name: &str) -> Result<()> { diff --git a/rust/lancedb/tests/embedding_registry_test.rs b/rust/lancedb/tests/embedding_registry_test.rs index c6386aa82..b51917770 100644 --- a/rust/lancedb/tests/embedding_registry_test.rs +++ b/rust/lancedb/tests/embedding_registry_test.rs @@ -155,6 +155,48 @@ async fn test_multiple_embeddings() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_open_table_embeddings() -> Result<()> { + let tempdir = tempfile::tempdir().unwrap(); + let tempdir = tempdir.path().to_str().unwrap(); + + let db = connect(tempdir).execute().await?; + let embed_fun = MockEmbed::new("embed_fun".to_string(), 1); + db.embedding_registry() + .register("embed_fun", Arc::new(embed_fun.clone()))?; + + db.create_table("test", create_some_records()?) + .add_embedding(EmbeddingDefinition::new( + "text", + &embed_fun.name, + Some("embeddings"), + ))? + .execute() + .await?; + + // now open the table and check the embeddings + let tbl = db.open_table("test").execute().await?; + + let mut res = tbl.query().execute().await?; + while let Some(Ok(batch)) = res.next().await { + let embeddings = batch.column_by_name("embeddings"); + assert!(embeddings.is_some()); + let embeddings = embeddings.unwrap(); + assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref()); + } + // now make sure the embeddings are applied when + // we add new records too + tbl.add(create_some_records()?).execute().await?; + let mut res = tbl.query().execute().await?; + while let Some(Ok(batch)) = res.next().await { + let embeddings = batch.column_by_name("embeddings"); + assert!(embeddings.is_some()); + let embeddings = embeddings.unwrap(); + assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref()); + } + Ok(()) +} + #[tokio::test] async fn test_no_func_in_registry() -> Result<()> { let tempdir = tempfile::tempdir().unwrap();