From 029b01bbbf800a96a9764fa6c93acf10dfad040b Mon Sep 17 00:00:00 2001 From: James Wu Date: Fri, 6 Sep 2024 20:28:05 -0700 Subject: [PATCH] feat: enable phrase_query(bool) for hybrid search queries (#1578) first off, apologies for any folly since i'm new to contributing to lancedb. this PR is the continuation of [a discord thread](https://discord.com/channels/1030247538198061086/1030247538667827251/1278844345713299599): ## user story here's the lance db search query i'd like to run: ``` def search(phrase): logger.info(f'Searching for phrase: {phrase}') phrase_embedding = get_embedding(phrase) df = (table.search((phrase_embedding, phrase), query_type='hybrid') .limit(10).to_list()) logger.info(f'Success search with row count: {len(df)}') search('howdy (howdy)') search('howdy(howdy)') ``` the second search fails due to `ValueError: Syntax Error: howdy(howdy)` i saw on the [docs](https://lancedb.github.io/lancedb/fts/#phrase-queries-vs-terms-queries) that i can use `phrase_query()` to [enable a flag](https://github.com/lancedb/lancedb/blob/main/python/python/lancedb/query.py#L790-L792) to wrap the query in double quotes (as well as sanitize single quotes) prior to sending the query to search. this works for [normal FTS](https://lancedb.github.io/lancedb/fts/), but the command is unavailable on [hybrid search](https://lancedb.github.io/lancedb/hybrid_search/hybrid_search/). ## changes i added `phrase_query()` function to `LanceHybridQueryBuilder` by propagating the call down to its `self. _fts_query` object. i'm not too familiar with the codebase and am not sure if this is the best way to implement the functionality. feel free to riff on this PR or discard ## tests ``` (lancedb) JamesMPB:python james$ pwd /Users/james/src/lancedb/python (lancedb) JamesMPB:python james$ pytest python/tests/test_table.py python/tests/test_table.py ....................................... [100%] ====================================================== 39 passed, 1 warning in 2.23s ======================================================= ``` --- python/python/lancedb/query.py | 22 +++++++++++++++++++++- python/python/tests/test_table.py | 12 +++++++++++- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index c6b14c0f..13b0460c 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -42,9 +42,9 @@ if TYPE_CHECKING: import PIL import polars as pl - from .common import VEC from ._lancedb import Query as LanceQuery from ._lancedb import VectorQuery as LanceVectorQuery + from .common import VEC from .pydantic import LanceModel from .table import Table @@ -965,6 +965,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): self._reranker = RRFReranker() self._nprobes = None self._refine_factor = None + self._phrase_query = False def _validate_query(self, query, vector=None, text=None): if query is not None and (vector is not None or text is not None): @@ -986,6 +987,23 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): return vector_query, text_query + def phrase_query(self, phrase_query: bool = True) -> LanceHybridQueryBuilder: + """Set whether to use phrase query. + + Parameters + ---------- + phrase_query: bool, default True + If True, then the query will be wrapped in quotes and + double quotes replaced by single quotes. + + Returns + ------- + LanceHybridQueryBuilder + The LanceHybridQueryBuilder object. + """ + self._phrase_query = phrase_query + return self + def to_arrow(self) -> pa.Table: vector_query, fts_query = self._validate_query( self._query, self._vector, self._text @@ -1012,6 +1030,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): if self._with_row_id: self._vector_query.with_row_id(True) self._fts_query.with_row_id(True) + if self._phrase_query: + self._fts_query.phrase_query(True) if self._nprobes: self._vector_query.nprobes(self._nprobes) if self._refine_factor: diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index 6ca2f5f1..65cf0c9d 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -2,13 +2,13 @@ # SPDX-FileCopyrightText: Copyright The Lance Authors import functools +import os from copy import copy from datetime import date, datetime, timedelta from pathlib import Path from time import sleep from typing import List from unittest.mock import PropertyMock, patch -import os import lance import lancedb @@ -907,6 +907,16 @@ def test_hybrid_search(db, tmp_path): "Our father who art in heaven", query_type="hybrid" ).to_pydantic(MyTable) + # Test that double and single quote characters are handled with phrase_query() + ( + table.search( + '"Aren\'t you a little short for a stormtrooper?" -- Leia', + query_type="hybrid", + ) + .phrase_query(True) + .to_pydantic(MyTable) + ) + assert result1 == result3 # with post filters