mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-05 19:32:56 +00:00
fix(python): force deduce vector column name if running explicit hybrid query (#1692)
Right now when passing vector and query explicitly for hybrid search , vector_column_name is not deduced. (https://lancedb.github.io/lancedb/hybrid_search/hybrid_search/#hybrid-search-in-lancedb ). Because vector and query can be both none when initialising the QueryBuilder in this case. This PR forces deduction of query type if it is set to "hybrid"
This commit is contained in:
@@ -1630,7 +1630,9 @@ class LanceTable(Table):
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
if vector_column_name is None and query is not None and query_type != "fts":
|
||||
if (
|
||||
vector_column_name is None and query is not None and query_type != "fts"
|
||||
) or (vector_column_name is None and query_type == "hybrid"):
|
||||
try:
|
||||
vector_column_name = inf_vector_column_query(self.schema)
|
||||
except Exception as e:
|
||||
|
||||
@@ -973,7 +973,36 @@ def test_hybrid_search(db, tmp_path):
|
||||
.where("text='Arrrrggghhhhhhh'")
|
||||
.to_list()
|
||||
)
|
||||
len(result) == 1
|
||||
assert len(result) == 1
|
||||
|
||||
# with explicit query type
|
||||
vector_query = list(range(emb.ndims()))
|
||||
result = (
|
||||
table.search(query_type="hybrid")
|
||||
.vector(vector_query)
|
||||
.text("Arrrrggghhhhhhh")
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert "_relevance_score" in result.column_names
|
||||
|
||||
# with vector_column_name
|
||||
result = (
|
||||
table.search(query_type="hybrid", vector_column_name="vector")
|
||||
.vector(vector_query)
|
||||
.text("Arrrrggghhhhhhh")
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert "_relevance_score" in result.column_names
|
||||
|
||||
# fail if only text or vector is provided
|
||||
with pytest.raises(ValueError):
|
||||
table.search(query_type="hybrid").to_list()
|
||||
with pytest.raises(ValueError):
|
||||
table.search(query_type="hybrid").vector(vector_query).to_list()
|
||||
with pytest.raises(ValueError):
|
||||
table.search(query_type="hybrid").text("Arrrrggghhhhhhh").to_list()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user