mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 05:19:58 +00:00
Compare commits
8 Commits
add-python
...
hybrid_que
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
887ac0d79d | ||
|
|
3ad4992282 | ||
|
|
51cc422799 | ||
|
|
a696dbc8f4 | ||
|
|
9ca0260d54 | ||
|
|
6486ec870b | ||
|
|
64db2393f7 | ||
|
|
bd4e8341fe |
@@ -117,23 +117,36 @@ class LanceQueryBuilder(ABC):
|
||||
query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
|
||||
query_type: str,
|
||||
vector_column_name: str,
|
||||
vector: Optional[VEC] = None,
|
||||
text: Optional[str] = None,
|
||||
) -> LanceQueryBuilder:
|
||||
if query is None:
|
||||
if query is None and vector is None and text is None:
|
||||
return LanceEmptyQueryBuilder(table)
|
||||
|
||||
if query_type == "hybrid":
|
||||
# hybrid fts and vector query
|
||||
return LanceHybridQueryBuilder(table, query, vector_column_name)
|
||||
return LanceHybridQueryBuilder(
|
||||
table, query, vector_column_name, vector, text
|
||||
)
|
||||
|
||||
# convert "auto" query_type to "vector", "fts"
|
||||
# or "hybrid" and convert the query to vector if needed
|
||||
# Resolve hybrid query with explicit vector and text params here to avoid
|
||||
# adding them as params in the BaseQueryBuilder class
|
||||
if vector is not None or text is not None:
|
||||
if query_type not in ["hybrid", "auto"]:
|
||||
raise ValueError(
|
||||
"If `vector` and `text` are provided, then `query_type`\
|
||||
must be 'hybrid' or 'auto'"
|
||||
)
|
||||
return LanceHybridQueryBuilder(
|
||||
table, query, vector_column_name, vector, text
|
||||
)
|
||||
|
||||
# convert "auto" query_type to "vector" or "fts"
|
||||
# and convert the query to vector if needed
|
||||
query, query_type = cls._resolve_query(
|
||||
table, query, query_type, vector_column_name
|
||||
)
|
||||
|
||||
if query_type == "hybrid":
|
||||
return LanceHybridQueryBuilder(table, query, vector_column_name)
|
||||
|
||||
if isinstance(query, str):
|
||||
# fts
|
||||
return LanceFtsQueryBuilder(table, query)
|
||||
@@ -161,8 +174,6 @@ class LanceQueryBuilder(ABC):
|
||||
elif query_type == "auto":
|
||||
if isinstance(query, (list, np.ndarray)):
|
||||
return query, "vector"
|
||||
if isinstance(query, tuple):
|
||||
return query, "hybrid"
|
||||
else:
|
||||
conf = table.embedding_functions.get(vector_column_name)
|
||||
if conf is not None:
|
||||
@@ -628,12 +639,20 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
|
||||
class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
def __init__(self, table: "Table", query: str, vector_column: str):
|
||||
def __init__(
|
||||
self,
|
||||
table: "Table",
|
||||
query: str,
|
||||
vector_column: str,
|
||||
vector: Optional[VEC] = None,
|
||||
text: Optional[str] = None,
|
||||
):
|
||||
super().__init__(table)
|
||||
self._validate_fts_index()
|
||||
vector_query, fts_query = self._validate_query(query)
|
||||
vector_query, fts_query = self._validate_query(
|
||||
query, vector_column, vector, text
|
||||
)
|
||||
self._fts_query = LanceFtsQueryBuilder(table, fts_query)
|
||||
vector_query = self._query_to_vector(table, vector_query, vector_column)
|
||||
self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column)
|
||||
self._norm = "score"
|
||||
self._reranker = LinearCombinationReranker(weight=0.7, fill=1.0)
|
||||
@@ -644,23 +663,31 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
"Please create a full-text search index " "to perform hybrid search."
|
||||
)
|
||||
|
||||
def _validate_query(self, query):
|
||||
# Temp hack to support vectorized queries for hybrid search
|
||||
if isinstance(query, str):
|
||||
return query, query
|
||||
elif isinstance(query, tuple):
|
||||
if len(query) != 2:
|
||||
def _validate_query(self, query, vector_column, vector, text):
|
||||
if query is not None:
|
||||
if vector is not None or text is not None:
|
||||
raise ValueError(
|
||||
"The query must be a tuple of (vector_query, fts_query)."
|
||||
"Either pass `query` or `vector` and `text` separately, not both."
|
||||
)
|
||||
if not isinstance(query[0], (list, np.ndarray, pa.Array, pa.ChunkedArray)):
|
||||
else:
|
||||
if vector is None or text is None:
|
||||
raise ValueError(
|
||||
"Either pass `query` or `vector` and `text` separately, not both."
|
||||
)
|
||||
|
||||
if vector is not None and text is not None:
|
||||
if not isinstance(vector, (list, np.ndarray, pa.Array, pa.ChunkedArray)):
|
||||
raise ValueError(f"The vector query must be one of {VEC}.")
|
||||
if not isinstance(query[1], str):
|
||||
if not isinstance(text, str):
|
||||
raise ValueError("The fts query must be a string.")
|
||||
return query[0], query[1]
|
||||
return vector, text
|
||||
if isinstance(query, str):
|
||||
vector = self._query_to_vector(self._table, query, vector_column)
|
||||
return vector, query
|
||||
else:
|
||||
raise ValueError(
|
||||
"The query must be either a string or a tuple of (vector, string)."
|
||||
f"For hybrid search `query` must be a string or `vector` and `text` \
|
||||
must be provided explicitly of types {VEC} and str respectively."
|
||||
)
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
|
||||
@@ -418,6 +418,8 @@ class Table(ABC):
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: str = "auto",
|
||||
vector: Optional[VEC] = None,
|
||||
text: Optional[str] = None,
|
||||
) -> LanceQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector. We currently support [vector search][search]
|
||||
@@ -1253,6 +1255,8 @@ class LanceTable(Table):
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: str = "auto",
|
||||
vector: Optional[VEC] = None,
|
||||
text: Optional[str] = None,
|
||||
) -> LanceQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector. We currently support [vector search][search]
|
||||
@@ -1307,6 +1311,10 @@ class LanceTable(Table):
|
||||
or raise an error if no corresponding embedding function is found.
|
||||
If the `query` is a string, then the query type is "vector" if the
|
||||
table has embedding functions, else the query type is "fts"
|
||||
vector: list/np.ndarray, default None
|
||||
vector query for hybrid search
|
||||
text: str, default None
|
||||
text query for hybrid search
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -1316,11 +1324,17 @@ class LanceTable(Table):
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
if vector_column_name is None and query is not None:
|
||||
is_query_defined = query is not None or vector is not None or text is not None
|
||||
if vector_column_name is None and is_query_defined:
|
||||
vector_column_name = inf_vector_column_query(self.schema)
|
||||
register_event("search_table")
|
||||
return LanceQueryBuilder.create(
|
||||
self, query, query_type, vector_column_name=vector_column_name
|
||||
self,
|
||||
query,
|
||||
query_type,
|
||||
vector_column_name=vector_column_name,
|
||||
vector=vector,
|
||||
text=text,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -104,7 +104,7 @@ def test_linear_combination(tmp_path):
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query))
|
||||
table.search(vector=query_vector, text=query, query_type="hybrid")
|
||||
.limit(30)
|
||||
.rerank(normalize="score")
|
||||
.to_arrow()
|
||||
@@ -118,6 +118,32 @@ def test_linear_combination(tmp_path):
|
||||
"be descending."
|
||||
)
|
||||
|
||||
# automatically deduce the query type
|
||||
result = (
|
||||
table.search(vector=query_vector, text=query)
|
||||
.limit(30)
|
||||
.rerank(normalize="score")
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
# wrong query type raises an error
|
||||
with pytest.raises(ValueError):
|
||||
table.search(vector=query_vector, text=query, query_type="vector").rerank(
|
||||
normalize="score"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
table.search(vector=query_vector, text=query, query_type="fts").rerank(
|
||||
normalize="score"
|
||||
)
|
||||
|
||||
# raise an error if only vector or text is provided
|
||||
with pytest.raises(ValueError):
|
||||
table.search(vector=query_vector).to_arrow()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
table.search(text=query).to_arrow()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
||||
@@ -141,7 +167,7 @@ def test_cohere_reranker(tmp_path):
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query))
|
||||
table.search(vector=query_vector, text=query)
|
||||
.limit(30)
|
||||
.rerank(reranker=CohereReranker())
|
||||
.to_arrow()
|
||||
@@ -175,7 +201,7 @@ def test_cross_encoder_reranker(tmp_path):
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query), query_type="hybrid")
|
||||
table.search(vector=query_vector, text=query, query_type="hybrid")
|
||||
.limit(30)
|
||||
.rerank(reranker=CrossEncoderReranker())
|
||||
.to_arrow()
|
||||
@@ -209,7 +235,7 @@ def test_colbert_reranker(tmp_path):
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query))
|
||||
table.search(vector=query_vector, text=query)
|
||||
.limit(30)
|
||||
.rerank(reranker=ColbertReranker())
|
||||
.to_arrow()
|
||||
@@ -246,7 +272,7 @@ def test_openai_reranker(tmp_path):
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query))
|
||||
table.search(vector=query_vector, text=query)
|
||||
.limit(30)
|
||||
.rerank(reranker=OpenaiReranker())
|
||||
.to_arrow()
|
||||
|
||||
Reference in New Issue
Block a user