diff --git a/python/lancedb/query.py b/python/lancedb/query.py index 44809da9..7418b2a6 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -117,23 +117,32 @@ class LanceQueryBuilder(ABC): query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]], query_type: str, vector_column_name: str, + vector: Optional[VEC] = None, + text: Optional[str] = None, ) -> LanceQueryBuilder: - if query is None: + if query is None and vector is None and text is None: return LanceEmptyQueryBuilder(table) if query_type == "hybrid": # hybrid fts and vector query - return LanceHybridQueryBuilder(table, query, vector_column_name) + return LanceHybridQueryBuilder(table, query, vector_column_name, vector, text) + + # Resolve hybrid query with explicit vector and text params here to avoid + # adding them as instance attributes in the LanceQueryBuilder subclasses + if vector is not None or text is not None: + # Explicit vector/text only make sense for hybrid search; 'auto' resolves to it. + if query_type not in ["hybrid", "auto"]: + raise ValueError( + "If `vector` and/or `text` are provided, then `query_type` must be 'hybrid' or 'auto'." 
+ ) + return LanceHybridQueryBuilder(table, query, vector_column_name, vector, text) # convert "auto" query_type to "vector", "fts" # or "hybrid" and convert the query to vector if needed query, query_type = cls._resolve_query( table, query, query_type, vector_column_name ) - if query_type == "hybrid": - return LanceHybridQueryBuilder(table, query, vector_column_name) - if isinstance(query, str): # fts return LanceFtsQueryBuilder(table, query) @@ -623,12 +632,11 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder): class LanceHybridQueryBuilder(LanceQueryBuilder): - def __init__(self, table: "Table", query: str, vector_column: str): + def __init__(self, table: "Table", query: Optional[str], vector_column: str, vector: Optional[VEC] = None, text: Optional[str] = None): super().__init__(table) self._validate_fts_index() - vector_query, fts_query = self._validate_query(query) + vector_query, fts_query = self._validate_query(query, vector_column, vector, text) self._fts_query = LanceFtsQueryBuilder(table, fts_query) - vector_query = self._query_to_vector(table, vector_query, vector_column) self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column) self._norm = "score" self._reranker = LinearCombinationReranker(weight=0.7, fill=1.0) @@ -639,25 +647,30 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): "Please create a full-text search index " "to perform hybrid search." ) - def _validate_query(self, query): - # Temp hack to support vectorized queries for hybrid search - if isinstance(query, str): - return query, query - elif isinstance(query, tuple): - if len(query) != 2: - raise ValueError( - "The query must be a tuple of (vector_query, fts_query)." 
- ) - if not isinstance(query[0], (list, np.ndarray, pa.Array, pa.ChunkedArray)): + def _validate_query(self, query, vector_column, vector, text): + if query is not None: + if vector is not None or text is not None: + raise ValueError("Either pass `query` or `vector` and `text` separately, not both.") + else: + if vector is None or text is None: + raise ValueError("Both `vector` and `text` must be provided when `query` is not passed.") + + if vector is not None and text is not None: + if not isinstance(vector, (list, np.ndarray, pa.Array, pa.ChunkedArray)): raise ValueError(f"The vector query must be one of {VEC}.") - if not isinstance(query[1], str): + if not isinstance(text, str): raise ValueError("The fts query must be a string.") - return query[0], query[1] + return vector, text + if isinstance(query, str): + vector = self._query_to_vector(self._table, query, vector_column) + return vector, query else: raise ValueError( - "The query must be either a string or a tuple of (vector, string)." + "For hybrid search `query` must be a string or `vector` and `text` " + f"must be provided explicitly of types {VEC} and str respectively." ) + def to_arrow(self) -> pa.Table: with ThreadPoolExecutor() as executor: fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow) diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 04bad713..0c74f560 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -416,6 +416,8 @@ class Table(ABC): query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, vector_column_name: Optional[str] = None, query_type: str = "auto", + vector: Optional[VEC] = None, + text: Optional[str] = None, ) -> LanceQueryBuilder: """Create a search query to find the nearest neighbors of the given query vector. 
We currently support [vector search][search] @@ -1201,6 +1203,8 @@ class LanceTable(Table): query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, vector_column_name: Optional[str] = None, query_type: str = "auto", + vector: Optional[VEC] = None, + text: Optional[str] = None, ) -> LanceQueryBuilder: """Create a search query to find the nearest neighbors of the given query vector. We currently support [vector search][search] @@ -1255,6 +1259,10 @@ class LanceTable(Table): or raise an error if no corresponding embedding function is found. If the `query` is a string, then the query type is "vector" if the table has embedding functions, else the query type is "fts" + vector: list/np.ndarray, default None + Query vector for hybrid search; pass together with `text` instead of `query`. + text: str, default None + Full-text query for hybrid search; pass together with `vector` instead of `query`. Returns ------- @@ -1264,11 +1272,12 @@ class LanceTable(Table): and also the "_distance" column which is the distance between the query vector and the returned vector. """ - if vector_column_name is None and query is not None: + is_query_defined = query is not None or (vector is not None and text is not None) + if vector_column_name is None and is_query_defined: vector_column_name = inf_vector_column_query(self.schema) register_event("search_table") return LanceQueryBuilder.create( - self, query, query_type, vector_column_name=vector_column_name + self, query, query_type, vector_column_name=vector_column_name, vector=vector, text=text ) @classmethod diff --git a/python/tests/test_rerankers.py b/python/tests/test_rerankers.py index 5d28e412..28f422a6 100644 --- a/python/tests/test_rerankers.py +++ b/python/tests/test_rerankers.py @@ -102,7 +102,7 @@ def test_linear_combination(tmp_path): query = "Our father who art in heaven" query_vector = table.to_pandas()["vector"][0] result = ( - table.search((query_vector, query)) + table.search(vector=query_vector, text=query, query_type="hybrid") .limit(30) .rerank(normalize="score") .to_arrow() @@ -116,6 +116,13 @@ def 
test_linear_combination(tmp_path): "be descending." ) + result = ( + table.search(vector=query_vector, text=query) + .limit(30) + .rerank(normalize="score") + .to_arrow() + ) + @pytest.mark.skipif( os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set" @@ -139,7 +146,7 @@ def test_cohere_reranker(tmp_path): query = "Our father who art in heaven" query_vector = table.to_pandas()["vector"][0] result = ( - table.search((query_vector, query)) + table.search(vector=query_vector, text=query) .limit(30) .rerank(reranker=CohereReranker()) .to_arrow() @@ -173,7 +180,7 @@ def test_cross_encoder_reranker(tmp_path): query = "Our father who art in heaven" query_vector = table.to_pandas()["vector"][0] result = ( - table.search((query_vector, query), query_type="hybrid") + table.search(vector=query_vector, text=query, query_type="hybrid") .limit(30) .rerank(reranker=CrossEncoderReranker()) .to_arrow() @@ -207,7 +214,7 @@ def test_colbert_reranker(tmp_path): query = "Our father who art in heaven" query_vector = table.to_pandas()["vector"][0] result = ( - table.search((query_vector, query)) + table.search(vector=query_vector, text=query) .limit(30) .rerank(reranker=ColbertReranker()) .to_arrow() @@ -244,7 +251,7 @@ def test_openai_reranker(tmp_path): query = "Our father who art in heaven" query_vector = table.to_pandas()["vector"][0] result = ( - table.search((query_vector, query)) + table.search(vector=query_vector, text=query) .limit(30) .rerank(reranker=OpenaiReranker()) .to_arrow()