diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index 294035281..44a6b443c 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -196,22 +196,6 @@ impl CreateTableBuilder { }; Ok((data, builder)) } - - pub fn add_embedding(mut self, definition: EmbeddingDefinition) -> Result { - // Early verification of the embedding name - let embedding_func = self - .parent - .embedding_registry() - .get(&definition.embedding_name) - .ok_or_else(|| Error::EmbeddingFunctionNotFound { - name: definition.embedding_name.clone(), - reason: "No embedding function found in the connection's embedding_registry" - .to_string(), - })?; - - self.embeddings.push((definition, embedding_func)); - Ok(self) - } } // Builder methods that only apply when we do not have initial data @@ -329,6 +313,26 @@ impl CreateTableBuilder { }; self } + + /// Add an embedding definition to the table. + /// + /// The `embedding_name` must match the name of an embedding function that + /// was previously registered with the connection's [`EmbeddingRegistry`]. + pub fn add_embedding(mut self, definition: EmbeddingDefinition) -> Result { + // Early verification of the embedding name + let embedding_func = self + .parent + .embedding_registry() + .get(&definition.embedding_name) + .ok_or_else(|| Error::EmbeddingFunctionNotFound { + name: definition.embedding_name.clone(), + reason: "No embedding function found in the connection's embedding_registry" + .to_string(), + })?; + + self.embeddings.push((definition, embedding_func)); + Ok(self) + } } #[derive(Clone, Debug)]