From f81ce68e41bc97648a0ec2e8ae4f9e0f6fafea9d Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Tue, 24 Sep 2024 19:02:56 +0530 Subject: [PATCH] fix(python): force deduce vector column name if running explicit hybrid query (#1692) Right now when passing vector and query explicitly for hybrid search , vector_column_name is not deduced. (https://lancedb.github.io/lancedb/hybrid_search/hybrid_search/#hybrid-search-in-lancedb ). Because vector and query can be both none when initialising the QueryBuilder in this case. This PR forces deduction of query type if it is set to "hybrid" --- python/python/lancedb/table.py | 4 +++- python/python/tests/test_table.py | 31 ++++++++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 615dd27e..d0bd1f38 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -1630,7 +1630,9 @@ class LanceTable(Table): and also the "_distance" column which is the distance between the query vector and the returned vector. """ - if vector_column_name is None and query is not None and query_type != "fts": + if ( + vector_column_name is None and query is not None and query_type != "fts" + ) or (vector_column_name is None and query_type == "hybrid"): try: vector_column_name = inf_vector_column_query(self.schema) except Exception as e: diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index 7e89ac82..c32a5c98 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -973,7 +973,36 @@ def test_hybrid_search(db, tmp_path): .where("text='Arrrrggghhhhhhh'") .to_list() ) - len(result) == 1 + assert len(result) == 1 + + # with explicit query type + vector_query = list(range(emb.ndims())) + result = ( + table.search(query_type="hybrid") + .vector(vector_query) + .text("Arrrrggghhhhhhh") + .to_arrow() + ) + assert len(result) > 0 + assert "_relevance_score" in result.column_names + + # with vector_column_name + result = ( + table.search(query_type="hybrid", vector_column_name="vector") + .vector(vector_query) + .text("Arrrrggghhhhhhh") + .to_arrow() + ) + assert len(result) > 0 + assert "_relevance_score" in result.column_names + + # fail if only text or vector is provided + with pytest.raises(ValueError): + table.search(query_type="hybrid").to_list() + with pytest.raises(ValueError): + table.search(query_type="hybrid").vector(vector_query).to_list() + with pytest.raises(ValueError): + table.search(query_type="hybrid").text("Arrrrggghhhhhhh").to_list() @pytest.mark.parametrize(