From 2861f3398271f72ff774b8895d61c3e14da79547 Mon Sep 17 00:00:00 2001 From: Chang She <759245+changhiskhan@users.noreply.github.com> Date: Tue, 24 Oct 2023 13:05:05 -0400 Subject: [PATCH] fix(python): fix multiple embedding functions bug (#597) Closes #594 The embedding functions are pydantic models so multiple instances with the same parameters are considered ==, which means that if you have multiple embedding columns it's possible for the embeddings to get overwritten. We now compare the functions by identity (`is`) rather than by equality (`==`) to avoid this problem. testing: modified unit test to include this case --- python/lancedb/pydantic.py | 7 ++++++- python/tests/test_embeddings_slow.py | 21 ++++++++++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/python/lancedb/pydantic.py b/python/lancedb/pydantic.py index 2616f52f..4ca93903 100644 --- a/python/lancedb/pydantic.py +++ b/python/lancedb/pydantic.py @@ -327,7 +327,12 @@ class LanceModel(pydantic.BaseModel): for vec, func in vec_and_function: for source, field_info in cls.safe_get_fields().items(): src_func = get_extras(field_info, "source_column_for") - if src_func == func: + if src_func is func: + # note we can't use == here since the function is a pydantic + # model so two instances of the same function are ==, so if you + # have multiple vector columns from multiple sources, both will + # be mapped to the same source column + # GH594 configs.append( EmbeddingFunctionConfig( source_column=source, vector_column=vec, function=func diff --git a/python/tests/test_embeddings_slow.py b/python/tests/test_embeddings_slow.py index 607a346d..b0078397 100644 --- a/python/tests/test_embeddings_slow.py +++ b/python/tests/test_embeddings_slow.py @@ -33,10 +33,13 @@ def test_sentence_transformer(alias, tmp_path): db = lancedb.connect(tmp_path) registry = get_registry() func = registry.get(alias).create() + func2 = registry.get(alias).create() class Words(LanceModel): text: str = func.SourceField() + text2: str = func2.SourceField() 
vector: Vector(func.ndims()) = func.VectorField() + vector2: Vector(func2.ndims()) = func2.VectorField() table = db.create_table("words", schema=Words) table.add( @@ -50,7 +53,16 @@ def test_sentence_transformer(alias, tmp_path): "foo", "bar", "baz", - ] + ], + "text2": [ + "to be or not to be", + "that is the question", + "for whether tis nobler", + "in the mind to suffer", + "the slings and arrows", + "of outrageous fortune", + "or to take arms", + ], } ) ) @@ -62,6 +74,13 @@ def test_sentence_transformer(alias, tmp_path): expected = table.search(vec).limit(1).to_pydantic(Words)[0] assert actual.text == expected.text assert actual.text == "hello world" + assert not np.allclose(actual.vector, actual.vector2) + + actual = ( + table.search(query, vector_column_name="vector2").limit(1).to_pydantic(Words)[0] + ) + assert actual.text != "hello world" + assert not np.allclose(actual.vector, actual.vector2) @pytest.mark.slow