Compare commits

...

8 Commits

Author SHA1 Message Date
ayush chaurasia
887ac0d79d Merge branch 'main' of https://github.com/lancedb/lancedb into hybrid_query 2024-03-01 11:14:06 +05:30
ayush chaurasia
3ad4992282 update 2024-02-23 14:11:58 +05:30
ayush chaurasia
51cc422799 update 2024-02-23 14:03:21 +05:30
ayush chaurasia
a696dbc8f4 update 2024-02-23 13:54:44 +05:30
ayush chaurasia
9ca0260d54 update 2024-02-23 03:03:39 +05:30
ayush chaurasia
6486ec870b update 2024-02-23 03:02:05 +05:30
ayush chaurasia
64db2393f7 update 2024-02-22 16:28:17 +05:30
ayush chaurasia
bd4e8341fe update 2024-02-21 21:43:23 +05:30
3 changed files with 97 additions and 30 deletions

View File

@@ -117,23 +117,36 @@ class LanceQueryBuilder(ABC):
query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
query_type: str,
vector_column_name: str,
vector: Optional[VEC] = None,
text: Optional[str] = None,
) -> LanceQueryBuilder:
if query is None:
if query is None and vector is None and text is None:
return LanceEmptyQueryBuilder(table)
if query_type == "hybrid":
# hybrid fts and vector query
return LanceHybridQueryBuilder(table, query, vector_column_name)
return LanceHybridQueryBuilder(
table, query, vector_column_name, vector, text
)
# convert "auto" query_type to "vector", "fts"
# or "hybrid" and convert the query to vector if needed
# Resolve hybrid query with explicit vector and text params here to avoid
# adding them as params in the BaseQueryBuilder class
if vector is not None or text is not None:
if query_type not in ["hybrid", "auto"]:
raise ValueError(
"If `vector` and `text` are provided, then `query_type`\
must be 'hybrid' or 'auto'"
)
return LanceHybridQueryBuilder(
table, query, vector_column_name, vector, text
)
# convert "auto" query_type to "vector" or "fts"
# and convert the query to vector if needed
query, query_type = cls._resolve_query(
table, query, query_type, vector_column_name
)
if query_type == "hybrid":
return LanceHybridQueryBuilder(table, query, vector_column_name)
if isinstance(query, str):
# fts
return LanceFtsQueryBuilder(table, query)
@@ -161,8 +174,6 @@ class LanceQueryBuilder(ABC):
elif query_type == "auto":
if isinstance(query, (list, np.ndarray)):
return query, "vector"
if isinstance(query, tuple):
return query, "hybrid"
else:
conf = table.embedding_functions.get(vector_column_name)
if conf is not None:
@@ -628,12 +639,20 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
class LanceHybridQueryBuilder(LanceQueryBuilder):
def __init__(self, table: "Table", query: str, vector_column: str):
def __init__(
self,
table: "Table",
query: str,
vector_column: str,
vector: Optional[VEC] = None,
text: Optional[str] = None,
):
super().__init__(table)
self._validate_fts_index()
vector_query, fts_query = self._validate_query(query)
vector_query, fts_query = self._validate_query(
query, vector_column, vector, text
)
self._fts_query = LanceFtsQueryBuilder(table, fts_query)
vector_query = self._query_to_vector(table, vector_query, vector_column)
self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column)
self._norm = "score"
self._reranker = LinearCombinationReranker(weight=0.7, fill=1.0)
@@ -644,23 +663,31 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
"Please create a full-text search index " "to perform hybrid search."
)
def _validate_query(self, query):
# Temp hack to support vectorized queries for hybrid search
if isinstance(query, str):
return query, query
elif isinstance(query, tuple):
if len(query) != 2:
def _validate_query(self, query, vector_column, vector, text):
if query is not None:
if vector is not None or text is not None:
raise ValueError(
"The query must be a tuple of (vector_query, fts_query)."
"Either pass `query` or `vector` and `text` separately, not both."
)
if not isinstance(query[0], (list, np.ndarray, pa.Array, pa.ChunkedArray)):
else:
if vector is None or text is None:
raise ValueError(
"Either pass `query` or `vector` and `text` separately, not both."
)
if vector is not None and text is not None:
if not isinstance(vector, (list, np.ndarray, pa.Array, pa.ChunkedArray)):
raise ValueError(f"The vector query must be one of {VEC}.")
if not isinstance(query[1], str):
if not isinstance(text, str):
raise ValueError("The fts query must be a string.")
return query[0], query[1]
return vector, text
if isinstance(query, str):
vector = self._query_to_vector(self._table, query, vector_column)
return vector, query
else:
raise ValueError(
"The query must be either a string or a tuple of (vector, string)."
f"For hybrid search `query` must be a string or `vector` and `text` \
must be provided explicitly of types {VEC} and str respectively."
)
def to_arrow(self) -> pa.Table:

View File

@@ -418,6 +418,8 @@ class Table(ABC):
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: str = "auto",
vector: Optional[VEC] = None,
text: Optional[str] = None,
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search]
@@ -1253,6 +1255,8 @@ class LanceTable(Table):
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: str = "auto",
vector: Optional[VEC] = None,
text: Optional[str] = None,
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search]
@@ -1307,6 +1311,10 @@ class LanceTable(Table):
or raise an error if no corresponding embedding function is found.
If the `query` is a string, then the query type is "vector" if the
table has embedding functions, else the query type is "fts"
vector: list/np.ndarray, default None
vector query for hybrid search
text: str, default None
text query for hybrid search
Returns
-------
@@ -1316,11 +1324,17 @@ class LanceTable(Table):
and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
if vector_column_name is None and query is not None:
is_query_defined = query is not None or vector is not None or text is not None
if vector_column_name is None and is_query_defined:
vector_column_name = inf_vector_column_query(self.schema)
register_event("search_table")
return LanceQueryBuilder.create(
self, query, query_type, vector_column_name=vector_column_name
self,
query,
query_type,
vector_column_name=vector_column_name,
vector=vector,
text=text,
)
@classmethod

View File

@@ -104,7 +104,7 @@ def test_linear_combination(tmp_path):
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
table.search(vector=query_vector, text=query, query_type="hybrid")
.limit(30)
.rerank(normalize="score")
.to_arrow()
@@ -118,6 +118,32 @@ def test_linear_combination(tmp_path):
"be descending."
)
# automatically deduce the query type
result = (
table.search(vector=query_vector, text=query)
.limit(30)
.rerank(normalize="score")
.to_arrow()
)
# wrong query type raises an error
with pytest.raises(ValueError):
table.search(vector=query_vector, text=query, query_type="vector").rerank(
normalize="score"
)
with pytest.raises(ValueError):
table.search(vector=query_vector, text=query, query_type="fts").rerank(
normalize="score"
)
# raise an error if only vector or text is provided
with pytest.raises(ValueError):
table.search(vector=query_vector).to_arrow()
with pytest.raises(ValueError):
table.search(text=query).to_arrow()
@pytest.mark.skipif(
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
@@ -141,7 +167,7 @@ def test_cohere_reranker(tmp_path):
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
table.search(vector=query_vector, text=query)
.limit(30)
.rerank(reranker=CohereReranker())
.to_arrow()
@@ -175,7 +201,7 @@ def test_cross_encoder_reranker(tmp_path):
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query), query_type="hybrid")
table.search(vector=query_vector, text=query, query_type="hybrid")
.limit(30)
.rerank(reranker=CrossEncoderReranker())
.to_arrow()
@@ -209,7 +235,7 @@ def test_colbert_reranker(tmp_path):
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
table.search(vector=query_vector, text=query)
.limit(30)
.rerank(reranker=ColbertReranker())
.to_arrow()
@@ -246,7 +272,7 @@ def test_openai_reranker(tmp_path):
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
table.search(vector=query_vector, text=query)
.limit(30)
.rerank(reranker=OpenaiReranker())
.to_arrow()