diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 7689b3d4..efd3f4d2 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -117,7 +117,8 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata) for vector_column, conf in functions.items(): func = conf.function - if vector_column not in data.column_names: + no_vector_column = vector_column not in data.column_names + if no_vector_column or pc.all(pc.is_null(data[vector_column])).as_py(): col_data = func.compute_source_embeddings_with_retry( data[conf.source_column] ) @@ -125,9 +126,16 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem dtype = schema.field(vector_column).type else: dtype = pa.list_(pa.float32(), len(col_data[0])) - data = data.append_column( - pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype) - ) + if no_vector_column: + data = data.append_column( + pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype) + ) + else: + data = data.set_column( + data.column_names.index(vector_column), + pa.field(vector_column, type=dtype), + pa.array(col_data, type=dtype), + ) return data diff --git a/python/python/tests/test_embeddings.py b/python/python/tests/test_embeddings.py index af442a16..ed7b105a 100644 --- a/python/python/tests/test_embeddings.py +++ b/python/python/tests/test_embeddings.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import sys +from typing import List, Union import lance import lancedb @@ -23,6 +24,8 @@ from lancedb.embeddings import ( EmbeddingFunctionRegistry, with_embeddings, ) +from lancedb.embeddings.base import TextEmbeddingFunction +from lancedb.embeddings.registry import get_registry, register from lancedb.pydantic import LanceModel, Vector @@ -112,3 +115,34 @@ def test_embedding_function_rate_limit(tmp_path): table.add([{"text": "hello world"}]) table.add([{"text": "hello world"}]) assert len(table) == 2 + + +def test_add_optional_vector(tmp_path): + @register("mock-embedding") + class MockEmbeddingFunction(TextEmbeddingFunction): + def ndims(self): + return 128 + + def generate_embeddings( + self, texts: Union[List[str], np.ndarray] + ) -> List[np.array]: + """ + Generate the embeddings for the given texts + """ + return [np.random.randn(self.ndims()).tolist() for _ in range(len(texts))] + + registry = get_registry() + model = registry.get("mock-embedding").create() + + class LanceSchema(LanceModel): + id: str + vector: Vector(model.ndims()) = model.VectorField(default=None) + text: str = model.SourceField() + + db = lancedb.connect(tmp_path) + tbl = db.create_table("optional_vector", schema=LanceSchema) + + # add works + expected = LanceSchema(id="id", text="text") + tbl.add([expected]) + assert not (np.abs(tbl.to_pandas()["vector"][0]) < 1e-6).all()