This commit is contained in:
ayush chaurasia
2024-02-23 13:54:44 +05:30
parent 9ca0260d54
commit a696dbc8f4
3 changed files with 41 additions and 23 deletions

View File

@@ -125,22 +125,26 @@ class LanceQueryBuilder(ABC):
if query_type == "hybrid":
# hybrid fts and vector query
return LanceHybridQueryBuilder(table, query, vector_column_name, vector, text)
return LanceHybridQueryBuilder(
table, query, vector_column_name, vector, text
)
# Resolve hybrid query with explicit vector and text params here to avoid
# adding them as instance attributes in the LanceQueryBuilder subclasses
# adding them as params in the BaseQueryBuilder class
if vector is not None or text is not None:
# If vector and/or text are provided, then query_type must be 'hybrid'.
if query_type not in ["hybrid", "auto"]:
raise ValueError(
"If `vector` and `text` are provided, then `query_type` must be 'hybrid'."
"If `vector` and `text` are provided, then `query_type`\
must be 'hybrid' or 'auto'"
)
return LanceHybridQueryBuilder(table, query, vector_column_name, vector, text)
return LanceHybridQueryBuilder(
table, query, vector_column_name, vector, text
)
# convert "auto" query_type to "vector", "fts"
# or "hybrid" and convert the query to vector if needed
# convert "auto" query_type to "vector" or "fts"
# and convert the query to vector if needed
query, query_type = cls._resolve_query(
table, query, query_type, vector_column_name, vector, text
table, query, query_type, vector_column_name
)
if isinstance(query, str):
@@ -170,8 +174,6 @@ class LanceQueryBuilder(ABC):
elif query_type == "auto":
if isinstance(query, (list, np.ndarray)):
return query, "vector"
if isinstance(query, tuple):
return query, "hybrid"
else:
conf = table.embedding_functions.get(vector_column_name)
if conf is not None:
@@ -632,10 +634,19 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
class LanceHybridQueryBuilder(LanceQueryBuilder):
def __init__(self, table: "Table", query: str, vector_column: str, vector: Optional[VEC] = None, text: Optional[str] = None):
def __init__(
self,
table: "Table",
query: str,
vector_column: str,
vector: Optional[VEC] = None,
text: Optional[str] = None,
):
super().__init__(table)
self._validate_fts_index()
vector_query, fts_query = self._validate_query(query, vector_column, vector, text)
vector_query, fts_query = self._validate_query(
query, vector_column, vector, text
)
self._fts_query = LanceFtsQueryBuilder(table, fts_query)
self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column)
self._norm = "score"
@@ -650,11 +661,15 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
def _validate_query(self, query, vector_column, vector, text):
if query is not None:
if vector is not None or text is not None:
raise ValueError("Either pass `query` or `vector` and `text` separately, not both.")
raise ValueError(
"Either pass `query` or `vector` and `text` separately, not both."
)
else:
if vector is None or text is None:
raise ValueError("Either pass `query` or `vector` and `text` separately, not both.")
raise ValueError(
"Either pass `query` or `vector` and `text` separately, not both."
)
if vector is not None and text is not None:
if not isinstance(vector, (list, np.ndarray, pa.Array, pa.ChunkedArray)):
raise ValueError(f"The vector query must be one of {VEC}.")
@@ -666,11 +681,10 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
return vector, query
else:
raise ValueError(
f"For hybrid search `query` must be a string or `vector` and `text` must be provided explicitly \
of types {VEC} and str respectively."
f"For hybrid search `query` must be a string or `vector` and `text` \
must be provided explicitly of types {VEC} and str respectively."
)
def to_arrow(self) -> pa.Table:
with ThreadPoolExecutor() as executor:
fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)

View File

@@ -1277,7 +1277,12 @@ class LanceTable(Table):
vector_column_name = inf_vector_column_query(self.schema)
register_event("search_table")
return LanceQueryBuilder.create(
self, query, query_type, vector_column_name=vector_column_name, vector=vector, text=text
self,
query,
query_type,
vector_column_name=vector_column_name,
vector=vector,
text=text,
)
@classmethod

View File

@@ -129,14 +129,13 @@ def test_linear_combination(tmp_path):
table.search(vector=query_vector, text=query, query_type="vector").rerank(
normalize="score"
)
# raise an error if only vector or text is provided
with pytest.raises(ValueError):
table.search(vector=query_vector).to_arrow()
with pytest.raises(ValueError):
table.search(text=query).to_arrow()
@pytest.mark.skipif(