From a696dbc8f4d7307c22ad3813922d74061780b553 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Fri, 23 Feb 2024 13:54:44 +0530 Subject: [PATCH] update --- python/lancedb/query.py | 52 +++++++++++++++++++++------------- python/lancedb/table.py | 7 ++++- python/tests/test_rerankers.py | 5 ++-- 3 files changed, 41 insertions(+), 23 deletions(-) diff --git a/python/lancedb/query.py b/python/lancedb/query.py index 7418b2a6..1338f0f2 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -125,22 +125,26 @@ class LanceQueryBuilder(ABC): if query_type == "hybrid": # hybrid fts and vector query - return LanceHybridQueryBuilder(table, query, vector_column_name, vector, text) - + return LanceHybridQueryBuilder( + table, query, vector_column_name, vector, text + ) + # Resolve hybrid query with explicit vector and text params here to avoid - # adding them as instance attributes in the LanceQueryBuilder subclasses + # adding them as params in the BaseQueryBuilder class if vector is not None or text is not None: - # If vector and/or text are provided, then query_type must be 'hybrid'. if query_type not in ["hybrid", "auto"]: raise ValueError( - "If `vector` and `text` are provided, then `query_type` must be 'hybrid'." + "If `vector` and `text` are provided, then `query_type`\ + must be 'hybrid' or 'auto'" ) - return LanceHybridQueryBuilder(table, query, vector_column_name, vector, text) + return LanceHybridQueryBuilder( + table, query, vector_column_name, vector, text + ) - # convert "auto" query_type to "vector", "fts" - # or "hybrid" and convert the query to vector if needed + # convert "auto" query_type to "vector" or "fts" + # and convert the query to vector if needed query, query_type = cls._resolve_query( - table, query, query_type, vector_column_name, vector, text + table, query, query_type, vector_column_name ) if isinstance(query, str): @@ -170,8 +174,6 @@ class LanceQueryBuilder(ABC): elif query_type == "auto": if isinstance(query, (list, np.ndarray)): return query, "vector" - if isinstance(query, tuple): - return query, "hybrid" else: conf = table.embedding_functions.get(vector_column_name) if conf is not None: @@ -632,10 +634,19 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder): class LanceHybridQueryBuilder(LanceQueryBuilder): - def __init__(self, table: "Table", query: str, vector_column: str, vector: Optional[VEC] = None, text: Optional[str] = None): + def __init__( + self, + table: "Table", + query: str, + vector_column: str, + vector: Optional[VEC] = None, + text: Optional[str] = None, + ): super().__init__(table) self._validate_fts_index() - vector_query, fts_query = self._validate_query(query, vector_column, vector, text) + vector_query, fts_query = self._validate_query( + query, vector_column, vector, text + ) self._fts_query = LanceFtsQueryBuilder(table, fts_query) self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column) self._norm = "score" @@ -650,11 +661,15 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): def _validate_query(self, query, vector_column, vector, text): if query is not None: if vector is not None or text is not None: - raise ValueError("Either pass `query` or `vector` and `text` separately, not both.") + raise ValueError( + "Either pass `query` or `vector` and `text` separately, not both." + ) else: if vector is None or text is None: - raise ValueError("Either pass `query` or `vector` and `text` separately, not both.") - + raise ValueError( + "Either pass `query` or `vector` and `text` separately, not both." + ) + if vector is not None and text is not None: if not isinstance(vector, (list, np.ndarray, pa.Array, pa.ChunkedArray)): raise ValueError(f"The vector query must be one of {VEC}.") @@ -666,11 +681,10 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): return vector, query else: raise ValueError( - f"For hybrid search `query` must be a string or `vector` and `text` must be provided explicitly \ - of types {VEC} and str respectively." + f"For hybrid search `query` must be a string or `vector` and `text` \ + must be provided explicitly of types {VEC} and str respectively." ) - def to_arrow(self) -> pa.Table: with ThreadPoolExecutor() as executor: fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow) diff --git a/python/lancedb/table.py b/python/lancedb/table.py index bd9fc3c6..c6bf3332 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -1277,7 +1277,12 @@ class LanceTable(Table): vector_column_name = inf_vector_column_query(self.schema) register_event("search_table") return LanceQueryBuilder.create( - self, query, query_type, vector_column_name=vector_column_name, vector=vector, text=text + self, + query, + query_type, + vector_column_name=vector_column_name, + vector=vector, + text=text, ) @classmethod diff --git a/python/tests/test_rerankers.py b/python/tests/test_rerankers.py index f4b0fb05..105c53a1 100644 --- a/python/tests/test_rerankers.py +++ b/python/tests/test_rerankers.py @@ -129,14 +129,13 @@ def test_linear_combination(tmp_path): table.search(vector=query_vector, text=query, query_type="vector").rerank( normalize="score" ) - + # raise an error if only vector or text is provided with pytest.raises(ValueError): table.search(vector=query_vector).to_arrow() - + with pytest.raises(ValueError): table.search(text=query).to_arrow() - @pytest.mark.skipif(