This commit is contained in:
ayush chaurasia
2024-02-21 21:43:23 +05:30
parent 22c196b3e3
commit bd4e8341fe
3 changed files with 58 additions and 29 deletions

View File

@@ -117,23 +117,32 @@ class LanceQueryBuilder(ABC):
query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
query_type: str,
vector_column_name: str,
vector: Optional[VEC] = None,
text: Optional[str] = None,
) -> LanceQueryBuilder:
if query is None:
if query is None and vector is None and text is None:
return LanceEmptyQueryBuilder(table)
if query_type == "hybrid":
# hybrid fts and vector query
return LanceHybridQueryBuilder(table, query, vector_column_name)
return LanceHybridQueryBuilder(table, query, vector_column_name, vector, text)
# Resolve hybrid query with explicit vector and text params here to avoid
# adding them as instance attributes in the LanceQueryBuilder subclasses
if vector is not None or text is not None:
# If vector and/or text are provided, then query_type must be 'hybrid'.
if query_type not in ["hybrid", "auto"]:
raise ValueError(
"If `vector` and `text` are provided, then `query_type` must be 'hybrid'."
)
return LanceHybridQueryBuilder(table, query, vector_column_name, vector, text)
# convert "auto" query_type to "vector", "fts"
# or "hybrid" and convert the query to vector if needed
query, query_type = cls._resolve_query(
table, query, query_type, vector_column_name
table, query, query_type, vector_column_name, vector, text
)
if query_type == "hybrid":
return LanceHybridQueryBuilder(table, query, vector_column_name)
if isinstance(query, str):
# fts
return LanceFtsQueryBuilder(table, query)
@@ -623,12 +632,11 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
class LanceHybridQueryBuilder(LanceQueryBuilder):
def __init__(self, table: "Table", query: str, vector_column: str):
def __init__(self, table: "Table", query: str, vector_column: str, vector: Optional[VEC] = None, text: Optional[str] = None):
super().__init__(table)
self._validate_fts_index()
vector_query, fts_query = self._validate_query(query)
vector_query, fts_query = self._validate_query(query, vector_column, vector, text)
self._fts_query = LanceFtsQueryBuilder(table, fts_query)
vector_query = self._query_to_vector(table, vector_query, vector_column)
self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column)
self._norm = "score"
self._reranker = LinearCombinationReranker(weight=0.7, fill=1.0)
@@ -639,25 +647,30 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
"Please create a full-text search index " "to perform hybrid search."
)
def _validate_query(self, query):
# Temp hack to support vectorized queries for hybrid search
if isinstance(query, str):
return query, query
elif isinstance(query, tuple):
if len(query) != 2:
raise ValueError(
"The query must be a tuple of (vector_query, fts_query)."
)
if not isinstance(query[0], (list, np.ndarray, pa.Array, pa.ChunkedArray)):
def _validate_query(self, query, vector_column, vector, text):
if query is not None:
if vector is not None or text is not None:
raise ValueError("Either pass `query` or `vector` and `text` separately, not both.")
else:
if vector is None or text is None:
raise ValueError("Either pass `query` or `vector` and `text` separately, not both.")
if vector is not None and text is not None:
if not isinstance(vector, (list, np.ndarray, pa.Array, pa.ChunkedArray)):
raise ValueError(f"The vector query must be one of {VEC}.")
if not isinstance(query[1], str):
if not isinstance(text, str):
raise ValueError("The fts query must be a string.")
return query[0], query[1]
return vector, text
if isinstance(query, str):
vector = self._query_to_vector(self._table, query, vector_column)
return vector, query
else:
raise ValueError(
"The query must be either a string or a tuple of (vector, string)."
f"For hybrid search `query` must be a string or `vector` and `text` must be provided explicitly \
of types {VEC} and str respectively."
)
def to_arrow(self) -> pa.Table:
with ThreadPoolExecutor() as executor:
fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)

View File

@@ -416,6 +416,8 @@ class Table(ABC):
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: str = "auto",
vector: Optional[VEC] = None,
text: Optional[str] = None,
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search]
@@ -1201,6 +1203,8 @@ class LanceTable(Table):
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: str = "auto",
vector: Optional[VEC] = None,
text: Optional[str] = None,
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search]
@@ -1255,6 +1259,10 @@ class LanceTable(Table):
or raise an error if no corresponding embedding function is found.
If the `query` is a string, then the query type is "vector" if the
table has embedding functions, else the query type is "fts"
vector: list/np.ndarray, default None
vector query for hybrid search
text: str, default None
text query for hybrid search
Returns
-------
@@ -1264,11 +1272,12 @@ class LanceTable(Table):
and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
if vector_column_name is None and query is not None:
is_query_defined = query is not None or (vector is not None and text is not None)
if vector_column_name is None and is_query_defined:
vector_column_name = inf_vector_column_query(self.schema)
register_event("search_table")
return LanceQueryBuilder.create(
self, query, query_type, vector_column_name=vector_column_name
self, query, query_type, vector_column_name=vector_column_name, vector=vector, text=text
)
@classmethod

View File

@@ -102,7 +102,7 @@ def test_linear_combination(tmp_path):
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
table.search(vector=query_vector, text=query, query_type="vector")
.limit(30)
.rerank(normalize="score")
.to_arrow()
@@ -116,6 +116,13 @@ def test_linear_combination(tmp_path):
"be descending."
)
result = (
table.search(vector=query_vector, text=query)
.limit(30)
.rerank(normalize="score")
.to_arrow()
)
@pytest.mark.skipif(
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
@@ -139,7 +146,7 @@ def test_cohere_reranker(tmp_path):
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
table.search(vector=query_vector, text=query)
.limit(30)
.rerank(reranker=CohereReranker())
.to_arrow()
@@ -173,7 +180,7 @@ def test_cross_encoder_reranker(tmp_path):
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query), query_type="hybrid")
table.search(vector=query_vector, text=query, query_type="hybrid")
.limit(30)
.rerank(reranker=CrossEncoderReranker())
.to_arrow()
@@ -207,7 +214,7 @@ def test_colbert_reranker(tmp_path):
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
table.search(vector=query_vector, text=query)
.limit(30)
.rerank(reranker=ColbertReranker())
.to_arrow()
@@ -244,7 +251,7 @@ def test_openai_reranker(tmp_path):
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
table.search(vector=query_vector, text=query)
.limit(30)
.rerank(reranker=OpenaiReranker())
.to_arrow()