diff --git a/python/lancedb/pydantic.py b/python/lancedb/pydantic.py index 2616f52f..4ca93903 100644 --- a/python/lancedb/pydantic.py +++ b/python/lancedb/pydantic.py @@ -327,7 +327,12 @@ class LanceModel(pydantic.BaseModel): for vec, func in vec_and_function: for source, field_info in cls.safe_get_fields().items(): src_func = get_extras(field_info, "source_column_for") - if src_func == func: + if src_func is func: + # note we can't use == here since the function is a pydantic + # model so two instances of the same function are ==, so if you + # have multiple vector columns from multiple sources, both will + # be mapped to the same source column + # GH594 configs.append( EmbeddingFunctionConfig( source_column=source, vector_column=vec, function=func diff --git a/python/tests/test_embeddings_slow.py b/python/tests/test_embeddings_slow.py index 607a346d..b0078397 100644 --- a/python/tests/test_embeddings_slow.py +++ b/python/tests/test_embeddings_slow.py @@ -33,10 +33,13 @@ def test_sentence_transformer(alias, tmp_path): db = lancedb.connect(tmp_path) registry = get_registry() func = registry.get(alias).create() + func2 = registry.get(alias).create() class Words(LanceModel): text: str = func.SourceField() + text2: str = func2.SourceField() vector: Vector(func.ndims()) = func.VectorField() + vector2: Vector(func2.ndims()) = func2.VectorField() table = db.create_table("words", schema=Words) table.add( @@ -50,7 +53,16 @@ def test_sentence_transformer(alias, tmp_path): "foo", "bar", "baz", - ] + ], + "text2": [ + "to be or not to be", + "that is the question", + "for whether tis nobler", + "in the mind to suffer", + "the slings and arrows", + "of outrageous fortune", + "or to take arms", + ], } ) ) @@ -62,6 +74,13 @@ def test_sentence_transformer(alias, tmp_path): expected = table.search(vec).limit(1).to_pydantic(Words)[0] assert actual.text == expected.text assert actual.text == "hello world" + assert not np.allclose(actual.vector, actual.vector2) + + actual = ( + table.search(query, vector_column_name="vector2").limit(1).to_pydantic(Words)[0] + ) + assert actual.text != "hello world" + assert not np.allclose(actual.vector, actual.vector2) @pytest.mark.slow