mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
feat: support specifying the ef search param (#1844)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
@@ -131,6 +131,8 @@ class Query(pydantic.BaseModel):
|
||||
|
||||
fast_search: bool = False
|
||||
|
||||
ef: Optional[int] = None
|
||||
|
||||
|
||||
class LanceQueryBuilder(ABC):
|
||||
"""An abstract query builder. Subclasses are defined for vector search,
|
||||
@@ -257,6 +259,7 @@ class LanceQueryBuilder(ABC):
|
||||
self._with_row_id = False
|
||||
self._vector = None
|
||||
self._text = None
|
||||
self._ef = None
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.3.1",
|
||||
@@ -638,6 +641,28 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
self._nprobes = nprobes
|
||||
return self
|
||||
|
||||
def ef(self, ef: int) -> LanceVectorQueryBuilder:
|
||||
"""Set the number of candidates to consider during search.
|
||||
|
||||
Higher values will yield better recall (more likely to find vectors if
|
||||
they exist) at the expense of latency.
|
||||
|
||||
This only applies to the HNSW-related index.
|
||||
The default value is 1.5 * limit.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ef: int
|
||||
The number of candidates to consider during search.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceVectorQueryBuilder
|
||||
The LanceVectorQueryBuilder object.
|
||||
"""
|
||||
self._ef = ef
|
||||
return self
|
||||
|
||||
def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder:
|
||||
"""Set the refine factor to use, increasing the number of vectors sampled.
|
||||
|
||||
@@ -700,6 +725,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
with_row_id=self._with_row_id,
|
||||
offset=self._offset,
|
||||
fast_search=self._fast_search,
|
||||
ef=self._ef,
|
||||
)
|
||||
result_set = self._table._execute_query(query, batch_size)
|
||||
if self._reranker is not None:
|
||||
@@ -1071,6 +1097,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._vector_query.nprobes(self._nprobes)
|
||||
if self._refine_factor:
|
||||
self._vector_query.refine_factor(self._refine_factor)
|
||||
if self._ef:
|
||||
self._vector_query.ef(self._ef)
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
|
||||
@@ -1197,6 +1225,29 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._nprobes = nprobes
|
||||
return self
|
||||
|
||||
def ef(self, ef: int) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Set the number of candidates to consider during search.
|
||||
|
||||
Higher values will yield better recall (more likely to find vectors if
|
||||
they exist) at the expense of latency.
|
||||
|
||||
This only applies to the HNSW-related index.
|
||||
The default value is 1.5 * limit.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ef: int
|
||||
The number of candidates to consider during search.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
self._ef = ef
|
||||
return self
|
||||
|
||||
def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceHybridQueryBuilder:
|
||||
"""Set the distance metric to use.
|
||||
|
||||
@@ -1644,6 +1695,21 @@ class AsyncVectorQuery(AsyncQueryBase):
|
||||
self._inner.nprobes(nprobes)
|
||||
return self
|
||||
|
||||
def ef(self, ef: int) -> AsyncVectorQuery:
|
||||
"""
|
||||
Set the number of candidates to consider during search.
|
||||
|
||||
This argument is only used when the vector column has an HNSW index.
|
||||
If there is no index then this value is ignored.
|
||||
|
||||
Increasing this value will increase the recall of your query but will also
|
||||
increase the latency of your query. The default value is 1.5 * limit. This
|
||||
default is good for many cases but the best value to use will depend on your
|
||||
data and the recall that you need to achieve.
|
||||
"""
|
||||
self._inner.ef(ef)
|
||||
return self
|
||||
|
||||
def refine_factor(self, refine_factor: int) -> AsyncVectorQuery:
|
||||
"""
|
||||
A multiplier to control how many additional rows are taken during the refine
|
||||
|
||||
@@ -1959,6 +1959,7 @@ class LanceTable(Table):
|
||||
"metric": query.metric,
|
||||
"nprobes": query.nprobes,
|
||||
"refine_factor": query.refine_factor,
|
||||
"ef": query.ef,
|
||||
}
|
||||
return ds.scanner(
|
||||
columns=query.columns,
|
||||
@@ -2736,6 +2737,8 @@ class AsyncTable:
|
||||
async_query = async_query.refine_factor(query.refine_factor)
|
||||
if query.vector_column:
|
||||
async_query = async_query.column(query.vector_column)
|
||||
if query.ef:
|
||||
async_query = async_query.ef(query.ef)
|
||||
|
||||
if not query.prefilter:
|
||||
async_query = async_query.postfilter()
|
||||
|
||||
@@ -185,6 +185,7 @@ def test_query_sync_minimal():
|
||||
"k": 10,
|
||||
"prefilter": False,
|
||||
"refine_factor": None,
|
||||
"ef": None,
|
||||
"vector": [1.0, 2.0, 3.0],
|
||||
"nprobes": 20,
|
||||
}
|
||||
@@ -223,6 +224,7 @@ def test_query_sync_maximal():
|
||||
"refine_factor": 10,
|
||||
"vector": [1.0, 2.0, 3.0],
|
||||
"nprobes": 5,
|
||||
"ef": None,
|
||||
"filter": "id > 0",
|
||||
"columns": ["id", "name"],
|
||||
"vector_column": "vector2",
|
||||
@@ -318,6 +320,7 @@ def test_query_sync_hybrid():
|
||||
"refine_factor": None,
|
||||
"vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
"nprobes": 20,
|
||||
"ef": None,
|
||||
"with_row_id": True,
|
||||
}
|
||||
return pa.table({"_rowid": [1, 2, 3], "_distance": [0.1, 0.2, 0.3]})
|
||||
|
||||
Reference in New Issue
Block a user