fix(python): fix multiple embedding functions bug (#597)

Closes #594

The embedding functions are pydantic models so multiple instances with
the same parameters are considered ==, which means that if you have
multiple embedding columns it's possible for the embeddings to get
overwritten. Instead we use `is` instead of == to avoid this problem.

testing: modified unit test to include this case
This commit is contained in:
Chang She
2023-10-24 13:05:05 -04:00
committed by GitHub
parent 26a97ba997
commit cd9debc3b7
2 changed files with 26 additions and 2 deletions

View File

@@ -327,7 +327,12 @@ class LanceModel(pydantic.BaseModel):
for vec, func in vec_and_function:
for source, field_info in cls.safe_get_fields().items():
src_func = get_extras(field_info, "source_column_for")
if src_func == func:
if src_func is func:
# note we can't use == here since the function is a pydantic
# model so two instances of the same function are ==, so if you
# have multiple vector columns from multiple sources, both will
# be mapped to the same source column
# GH594
configs.append(
EmbeddingFunctionConfig(
source_column=source, vector_column=vec, function=func

View File

@@ -33,10 +33,13 @@ def test_sentence_transformer(alias, tmp_path):
db = lancedb.connect(tmp_path)
registry = get_registry()
func = registry.get(alias).create()
func2 = registry.get(alias).create()
class Words(LanceModel):
text: str = func.SourceField()
text2: str = func2.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
vector2: Vector(func2.ndims()) = func2.VectorField()
table = db.create_table("words", schema=Words)
table.add(
@@ -50,7 +53,16 @@ def test_sentence_transformer(alias, tmp_path):
"foo",
"bar",
"baz",
]
],
"text2": [
"to be or not to be",
"that is the question",
"for whether tis nobler",
"in the mind to suffer",
"the slings and arrows",
"of outrageous fortune",
"or to take arms",
],
}
)
)
@@ -62,6 +74,13 @@ def test_sentence_transformer(alias, tmp_path):
expected = table.search(vec).limit(1).to_pydantic(Words)[0]
assert actual.text == expected.text
assert actual.text == "hello world"
assert not np.allclose(actual.vector, actual.vector2)
actual = (
table.search(query, vector_column_name="vector2").limit(1).to_pydantic(Words)[0]
)
assert actual.text != "hello world"
assert not np.allclose(actual.vector, actual.vector2)
@pytest.mark.slow