feat: support to sepcify ef search param (#1844)

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
BubbleCal
2024-11-19 23:12:25 +08:00
committed by GitHub
parent f2e3989831
commit b2f88f0b29
10 changed files with 165 additions and 0 deletions

View File

@@ -131,6 +131,8 @@ class Query(pydantic.BaseModel):
fast_search: bool = False
ef: Optional[int] = None
class LanceQueryBuilder(ABC):
"""An abstract query builder. Subclasses are defined for vector search,
@@ -257,6 +259,7 @@ class LanceQueryBuilder(ABC):
self._with_row_id = False
self._vector = None
self._text = None
self._ef = None
@deprecation.deprecated(
deprecated_in="0.3.1",
@@ -638,6 +641,28 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._nprobes = nprobes
return self
def ef(self, ef: int) -> LanceVectorQueryBuilder:
"""Set the number of candidates to consider during search.
Higher values will yield better recall (more likely to find vectors if
they exist) at the expense of latency.
This only applies to the HNSW-related index.
The default value is 1.5 * limit.
Parameters
----------
ef: int
The number of candidates to consider during search.
Returns
-------
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._ef = ef
return self
def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder:
"""Set the refine factor to use, increasing the number of vectors sampled.
@@ -700,6 +725,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
with_row_id=self._with_row_id,
offset=self._offset,
fast_search=self._fast_search,
ef=self._ef,
)
result_set = self._table._execute_query(query, batch_size)
if self._reranker is not None:
@@ -1071,6 +1097,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._vector_query.nprobes(self._nprobes)
if self._refine_factor:
self._vector_query.refine_factor(self._refine_factor)
if self._ef:
self._vector_query.ef(self._ef)
with ThreadPoolExecutor() as executor:
fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
@@ -1197,6 +1225,29 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._nprobes = nprobes
return self
def ef(self, ef: int) -> LanceHybridQueryBuilder:
"""
Set the number of candidates to consider during search.
Higher values will yield better recall (more likely to find vectors if
they exist) at the expense of latency.
This only applies to the HNSW-related index.
The default value is 1.5 * limit.
Parameters
----------
ef: int
The number of candidates to consider during search.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._ef = ef
return self
def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceHybridQueryBuilder:
"""Set the distance metric to use.
@@ -1644,6 +1695,21 @@ class AsyncVectorQuery(AsyncQueryBase):
self._inner.nprobes(nprobes)
return self
def ef(self, ef: int) -> AsyncVectorQuery:
"""
Set the number of candidates to consider during search
This argument is only used when the vector column has an HNSW index.
If there is no index then this value is ignored.
Increasing this value will increase the recall of your query but will also
increase the latency of your query. The default value is 1.5 * limit. This
default is good for many cases but the best value to use will depend on your
data and the recall that you need to achieve.
"""
self._inner.ef(ef)
return self
def refine_factor(self, refine_factor: int) -> AsyncVectorQuery:
"""
A multiplier to control how many additional rows are taken during the refine

View File

@@ -1959,6 +1959,7 @@ class LanceTable(Table):
"metric": query.metric,
"nprobes": query.nprobes,
"refine_factor": query.refine_factor,
"ef": query.ef,
}
return ds.scanner(
columns=query.columns,
@@ -2736,6 +2737,8 @@ class AsyncTable:
async_query = async_query.refine_factor(query.refine_factor)
if query.vector_column:
async_query = async_query.column(query.vector_column)
if query.ef:
async_query = async_query.ef(query.ef)
if not query.prefilter:
async_query = async_query.postfilter()

View File

@@ -185,6 +185,7 @@ def test_query_sync_minimal():
"k": 10,
"prefilter": False,
"refine_factor": None,
"ef": None,
"vector": [1.0, 2.0, 3.0],
"nprobes": 20,
}
@@ -223,6 +224,7 @@ def test_query_sync_maximal():
"refine_factor": 10,
"vector": [1.0, 2.0, 3.0],
"nprobes": 5,
"ef": None,
"filter": "id > 0",
"columns": ["id", "name"],
"vector_column": "vector2",
@@ -318,6 +320,7 @@ def test_query_sync_hybrid():
"refine_factor": None,
"vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
"nprobes": 20,
"ef": None,
"with_row_id": True,
}
return pa.table({"_rowid": [1, 2, 3], "_distance": [0.1, 0.2, 0.3]})