mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-31 10:50:40 +00:00
fix(rust): add embedding_registry on open_table (#2086)
# Description Fix for: https://github.com/lancedb/lancedb/issues/1581 This is the same implementation as https://github.com/lancedb/lancedb/pull/1781 but with the addition of a unit test and rustfmt.
This commit is contained in:
@@ -1081,6 +1081,7 @@ impl ConnectionInternal for Database {
|
||||
|
||||
async fn do_open_table(&self, mut options: OpenTableBuilder) -> Result<Table> {
|
||||
let table_uri = self.table_uri(&options.name)?;
|
||||
let embedding_registry = self.embedding_registry.clone();
|
||||
|
||||
// Inherit storage options from the connection
|
||||
let storage_options = options
|
||||
@@ -1117,7 +1118,10 @@ impl ConnectionInternal for Database {
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
Ok(Table::new(native_table))
|
||||
Ok(Table::new_with_embedding_registry(
|
||||
native_table,
|
||||
embedding_registry,
|
||||
))
|
||||
}
|
||||
|
||||
async fn rename_table(&self, _old_name: &str, _new_name: &str) -> Result<()> {
|
||||
|
||||
@@ -155,6 +155,48 @@ async fn test_multiple_embeddings() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_open_table_embeddings() -> Result<()> {
|
||||
let tempdir = tempfile::tempdir().unwrap();
|
||||
let tempdir = tempdir.path().to_str().unwrap();
|
||||
|
||||
let db = connect(tempdir).execute().await?;
|
||||
let embed_fun = MockEmbed::new("embed_fun".to_string(), 1);
|
||||
db.embedding_registry()
|
||||
.register("embed_fun", Arc::new(embed_fun.clone()))?;
|
||||
|
||||
db.create_table("test", create_some_records()?)
|
||||
.add_embedding(EmbeddingDefinition::new(
|
||||
"text",
|
||||
&embed_fun.name,
|
||||
Some("embeddings"),
|
||||
))?
|
||||
.execute()
|
||||
.await?;
|
||||
|
||||
// now open the table and check the embeddings
|
||||
let tbl = db.open_table("test").execute().await?;
|
||||
|
||||
let mut res = tbl.query().execute().await?;
|
||||
while let Some(Ok(batch)) = res.next().await {
|
||||
let embeddings = batch.column_by_name("embeddings");
|
||||
assert!(embeddings.is_some());
|
||||
let embeddings = embeddings.unwrap();
|
||||
assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref());
|
||||
}
|
||||
// now make sure the embeddings are applied when
|
||||
// we add new records too
|
||||
tbl.add(create_some_records()?).execute().await?;
|
||||
let mut res = tbl.query().execute().await?;
|
||||
while let Some(Ok(batch)) = res.next().await {
|
||||
let embeddings = batch.column_by_name("embeddings");
|
||||
assert!(embeddings.is_some());
|
||||
let embeddings = embeddings.unwrap();
|
||||
assert_eq!(embeddings.data_type(), embed_fun.dest_type()?.as_ref());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_no_func_in_registry() -> Result<()> {
|
||||
let tempdir = tempfile::tempdir().unwrap();
|
||||
|
||||
Reference in New Issue
Block a user