fix(python): force deduce vector column name if running explicit hybrid query (#1692)

Right now, when the vector and the text query are passed explicitly for hybrid search
(https://lancedb.github.io/lancedb/hybrid_search/hybrid_search/#hybrid-search-in-lancedb),
vector_column_name is not deduced, because the vector and the query can both be
None when initialising the QueryBuilder in this case. This PR forces deduction
of the vector column name whenever the query type is set to "hybrid".
This commit is contained in:
Ayush Chaurasia
2024-09-24 19:02:56 +05:30
committed by GitHub
parent f5c25b6fff
commit f81ce68e41
2 changed files with 33 additions and 2 deletions

View File

@@ -1630,7 +1630,9 @@ class LanceTable(Table):
and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
if vector_column_name is None and query is not None and query_type != "fts":
if (
vector_column_name is None and query is not None and query_type != "fts"
) or (vector_column_name is None and query_type == "hybrid"):
try:
vector_column_name = inf_vector_column_query(self.schema)
except Exception as e:

View File

@@ -973,7 +973,36 @@ def test_hybrid_search(db, tmp_path):
.where("text='Arrrrggghhhhhhh'")
.to_list()
)
len(result) == 1
assert len(result) == 1
# with explicit query type
vector_query = list(range(emb.ndims()))
result = (
table.search(query_type="hybrid")
.vector(vector_query)
.text("Arrrrggghhhhhhh")
.to_arrow()
)
assert len(result) > 0
assert "_relevance_score" in result.column_names
# with vector_column_name
result = (
table.search(query_type="hybrid", vector_column_name="vector")
.vector(vector_query)
.text("Arrrrggghhhhhhh")
.to_arrow()
)
assert len(result) > 0
assert "_relevance_score" in result.column_names
# fail unless both text and vector are provided
with pytest.raises(ValueError):
table.search(query_type="hybrid").to_list()
with pytest.raises(ValueError):
table.search(query_type="hybrid").vector(vector_query).to_list()
with pytest.raises(ValueError):
table.search(query_type="hybrid").text("Arrrrggghhhhhhh").to_list()
@pytest.mark.parametrize(