mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-14 18:40:39 +00:00
update
This commit is contained in:
@@ -125,22 +125,26 @@ class LanceQueryBuilder(ABC):
|
||||
|
||||
if query_type == "hybrid":
|
||||
# hybrid fts and vector query
|
||||
return LanceHybridQueryBuilder(table, query, vector_column_name, vector, text)
|
||||
|
||||
return LanceHybridQueryBuilder(
|
||||
table, query, vector_column_name, vector, text
|
||||
)
|
||||
|
||||
# Resolve hybrid query with explicit vector and text params here to avoid
|
||||
# adding them as instance attributes in the LanceQueryBuilder subclasses
|
||||
# adding them as params in the BaseQueryBuilder class
|
||||
if vector is not None or text is not None:
|
||||
# If vector and/or text are provided, then query_type must be 'hybrid'.
|
||||
if query_type not in ["hybrid", "auto"]:
|
||||
raise ValueError(
|
||||
"If `vector` and `text` are provided, then `query_type` must be 'hybrid'."
|
||||
"If `vector` and `text` are provided, then `query_type`\
|
||||
must be 'hybrid' or 'auto'"
|
||||
)
|
||||
return LanceHybridQueryBuilder(table, query, vector_column_name, vector, text)
|
||||
return LanceHybridQueryBuilder(
|
||||
table, query, vector_column_name, vector, text
|
||||
)
|
||||
|
||||
# convert "auto" query_type to "vector", "fts"
|
||||
# or "hybrid" and convert the query to vector if needed
|
||||
# convert "auto" query_type to "vector" or "fts"
|
||||
# and convert the query to vector if needed
|
||||
query, query_type = cls._resolve_query(
|
||||
table, query, query_type, vector_column_name, vector, text
|
||||
table, query, query_type, vector_column_name
|
||||
)
|
||||
|
||||
if isinstance(query, str):
|
||||
@@ -170,8 +174,6 @@ class LanceQueryBuilder(ABC):
|
||||
elif query_type == "auto":
|
||||
if isinstance(query, (list, np.ndarray)):
|
||||
return query, "vector"
|
||||
if isinstance(query, tuple):
|
||||
return query, "hybrid"
|
||||
else:
|
||||
conf = table.embedding_functions.get(vector_column_name)
|
||||
if conf is not None:
|
||||
@@ -632,10 +634,19 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
|
||||
class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
def __init__(self, table: "Table", query: str, vector_column: str, vector: Optional[VEC] = None, text: Optional[str] = None):
|
||||
def __init__(
|
||||
self,
|
||||
table: "Table",
|
||||
query: str,
|
||||
vector_column: str,
|
||||
vector: Optional[VEC] = None,
|
||||
text: Optional[str] = None,
|
||||
):
|
||||
super().__init__(table)
|
||||
self._validate_fts_index()
|
||||
vector_query, fts_query = self._validate_query(query, vector_column, vector, text)
|
||||
vector_query, fts_query = self._validate_query(
|
||||
query, vector_column, vector, text
|
||||
)
|
||||
self._fts_query = LanceFtsQueryBuilder(table, fts_query)
|
||||
self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column)
|
||||
self._norm = "score"
|
||||
@@ -650,11 +661,15 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
def _validate_query(self, query, vector_column, vector, text):
|
||||
if query is not None:
|
||||
if vector is not None or text is not None:
|
||||
raise ValueError("Either pass `query` or `vector` and `text` separately, not both.")
|
||||
raise ValueError(
|
||||
"Either pass `query` or `vector` and `text` separately, not both."
|
||||
)
|
||||
else:
|
||||
if vector is None or text is None:
|
||||
raise ValueError("Either pass `query` or `vector` and `text` separately, not both.")
|
||||
|
||||
raise ValueError(
|
||||
"Either pass `query` or `vector` and `text` separately, not both."
|
||||
)
|
||||
|
||||
if vector is not None and text is not None:
|
||||
if not isinstance(vector, (list, np.ndarray, pa.Array, pa.ChunkedArray)):
|
||||
raise ValueError(f"The vector query must be one of {VEC}.")
|
||||
@@ -666,11 +681,10 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
return vector, query
|
||||
else:
|
||||
raise ValueError(
|
||||
f"For hybrid search `query` must be a string or `vector` and `text` must be provided explicitly \
|
||||
of types {VEC} and str respectively."
|
||||
f"For hybrid search `query` must be a string or `vector` and `text` \
|
||||
must be provided explicitly of types {VEC} and str respectively."
|
||||
)
|
||||
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
with ThreadPoolExecutor() as executor:
|
||||
fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
|
||||
|
||||
@@ -1277,7 +1277,12 @@ class LanceTable(Table):
|
||||
vector_column_name = inf_vector_column_query(self.schema)
|
||||
register_event("search_table")
|
||||
return LanceQueryBuilder.create(
|
||||
self, query, query_type, vector_column_name=vector_column_name, vector=vector, text=text
|
||||
self,
|
||||
query,
|
||||
query_type,
|
||||
vector_column_name=vector_column_name,
|
||||
vector=vector,
|
||||
text=text,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -129,14 +129,13 @@ def test_linear_combination(tmp_path):
|
||||
table.search(vector=query_vector, text=query, query_type="vector").rerank(
|
||||
normalize="score"
|
||||
)
|
||||
|
||||
|
||||
# raise an error if only vector or text is provided
|
||||
with pytest.raises(ValueError):
|
||||
table.search(vector=query_vector).to_arrow()
|
||||
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
table.search(text=query).to_arrow()
|
||||
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
|
||||
Reference in New Issue
Block a user