diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs
index c03f3c423..6100abfca 100644
--- a/rust/lancedb/src/connection.rs
+++ b/rust/lancedb/src/connection.rs
@@ -1081,6 +1081,7 @@ impl ConnectionInternal for Database {
async fn do_open_table(&self, mut options: OpenTableBuilder) -> Result
{
let table_uri = self.table_uri(&options.name)?;
+ let embedding_registry = self.embedding_registry.clone();
// Inherit storage options from the connection
let storage_options = options
@@ -1117,7 +1118,10 @@ impl ConnectionInternal for Database {
)
.await?,
);
- Ok(Table::new(native_table))
+ Ok(Table::new_with_embedding_registry(
+ native_table,
+ embedding_registry,
+ ))
}
async fn rename_table(&self, _old_name: &str, _new_name: &str) -> Result<()> {
diff --git a/rust/lancedb/tests/embedding_registry_test.rs b/rust/lancedb/tests/embedding_registry_test.rs
index c6386aa82..b51917770 100644
--- a/rust/lancedb/tests/embedding_registry_test.rs
+++ b/rust/lancedb/tests/embedding_registry_test.rs
@@ -155,6 +155,48 @@ async fn test_multiple_embeddings() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn test_open_table_embeddings() -> Result<()> {
+ let tempdir = tempfile::tempdir().unwrap();
+ let tempdir = tempdir.path().to_str().unwrap();
+
+ let db = connect(tempdir).execute().await?;
+ let embed_fun = MockEmbed::new("embed_fun".to_string(), 1);
+ db.embedding_registry()
+ .register("embed_fun", Arc::new(embed_fun.clone()))?;
+
+ db.create_table("test", create_some_records()?)
+ .add_embedding(EmbeddingDefinition::new(
+ "text",
+ &embed_fun.name,
+ Some("embeddings"),
+ ))?
+ .execute()
+ .await?;
+
+ // now open the table and check the embeddings
+ let tbl = db.open_table("test").execute().await?;
+
+ let mut res = tbl.query().execute().await?;
+ while let Some(Ok(batch)) = res.next().await {
+ let embeddings = batch.column_by_name("embeddings");
+ assert!(embeddings.is_some());
+ let embeddings = embeddings.unwrap();
+ assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref());
+ }
+ // now make sure the embeddings are applied when
+ // we add new records too
+ tbl.add(create_some_records()?).execute().await?;
+ let mut res = tbl.query().execute().await?;
+ while let Some(Ok(batch)) = res.next().await {
+ let embeddings = batch.column_by_name("embeddings");
+ assert!(embeddings.is_some());
+ let embeddings = embeddings.unwrap();
+ assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref());
+ }
+ Ok(())
+}
+
#[tokio::test]
async fn test_no_func_in_registry() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();