mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-10 05:42:58 +00:00
feat: support specifying the ef search param (#1844)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
@@ -477,6 +477,54 @@ describe("When creating an index", () => {
|
||||
expect(rst.numRows).toBe(1);
|
||||
});
|
||||
|
||||
it("should create and search IVF_HNSW indices", async () => {
|
||||
await tbl.createIndex("vec", {
|
||||
config: Index.hnswSq(),
|
||||
});
|
||||
|
||||
// check index directory
|
||||
const indexDir = path.join(tmpDir.name, "test.lance", "_indices");
|
||||
expect(fs.readdirSync(indexDir)).toHaveLength(1);
|
||||
const indices = await tbl.listIndices();
|
||||
expect(indices.length).toBe(1);
|
||||
expect(indices[0]).toEqual({
|
||||
name: "vec_idx",
|
||||
indexType: "IvfHnswSq",
|
||||
columns: ["vec"],
|
||||
});
|
||||
|
||||
// Search without specifying the column
|
||||
let rst = await tbl
|
||||
.query()
|
||||
.limit(2)
|
||||
.nearestTo(queryVec)
|
||||
.distanceType("dot")
|
||||
.toArrow();
|
||||
expect(rst.numRows).toBe(2);
|
||||
|
||||
// Search using `vectorSearch`
|
||||
rst = await tbl.vectorSearch(queryVec).limit(2).toArrow();
|
||||
expect(rst.numRows).toBe(2);
|
||||
|
||||
// Search with specifying the column
|
||||
const rst2 = await tbl
|
||||
.query()
|
||||
.limit(2)
|
||||
.nearestTo(queryVec)
|
||||
.column("vec")
|
||||
.toArrow();
|
||||
expect(rst2.numRows).toBe(2);
|
||||
expect(rst.toString()).toEqual(rst2.toString());
|
||||
|
||||
// test offset
|
||||
rst = await tbl.query().limit(2).offset(1).nearestTo(queryVec).toArrow();
|
||||
expect(rst.numRows).toBe(1);
|
||||
|
||||
// test ef
|
||||
rst = await tbl.query().limit(2).nearestTo(queryVec).ef(100).toArrow();
|
||||
expect(rst.numRows).toBe(2);
|
||||
});
|
||||
|
||||
it("should be able to query unindexed data", async () => {
|
||||
await tbl.createIndex("vec");
|
||||
await tbl.add([
|
||||
|
||||
@@ -385,6 +385,20 @@ export class VectorQuery extends QueryBase<NativeVectorQuery> {
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the number of candidates to consider during the search
|
||||
*
|
||||
* This argument is only used when the vector column has an HNSW index.
|
||||
* If there is no index then this value is ignored.
|
||||
*
|
||||
* Increasing this value will increase the recall of your query but will
|
||||
* also increase the latency of your query. The default value is 1.5*limit.
|
||||
*/
|
||||
ef(ef: number): VectorQuery {
|
||||
super.doCall((inner) => inner.ef(ef));
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the vector column to query
|
||||
*
|
||||
|
||||
@@ -167,6 +167,11 @@ impl VectorQuery {
|
||||
self.inner = self.inner.clone().nprobes(nprobe as usize);
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn ef(&mut self, ef: u32) {
|
||||
self.inner = self.inner.clone().ef(ef as usize);
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn bypass_vector_index(&mut self) {
|
||||
self.inner = self.inner.clone().bypass_vector_index()
|
||||
|
||||
@@ -131,6 +131,8 @@ class Query(pydantic.BaseModel):
|
||||
|
||||
fast_search: bool = False
|
||||
|
||||
ef: Optional[int] = None
|
||||
|
||||
|
||||
class LanceQueryBuilder(ABC):
|
||||
"""An abstract query builder. Subclasses are defined for vector search,
|
||||
@@ -257,6 +259,7 @@ class LanceQueryBuilder(ABC):
|
||||
self._with_row_id = False
|
||||
self._vector = None
|
||||
self._text = None
|
||||
self._ef = None
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.3.1",
|
||||
@@ -638,6 +641,28 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
self._nprobes = nprobes
|
||||
return self
|
||||
|
||||
def ef(self, ef: int) -> LanceVectorQueryBuilder:
|
||||
"""Set the number of candidates to consider during search.
|
||||
|
||||
Higher values will yield better recall (more likely to find vectors if
|
||||
they exist) at the expense of latency.
|
||||
|
||||
This only applies to the HNSW-related index.
|
||||
The default value is 1.5 * limit.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ef: int
|
||||
The number of candidates to consider during search.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._ef = ef
|
||||
return self
|
||||
|
||||
def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder:
|
||||
"""Set the refine factor to use, increasing the number of vectors sampled.
|
||||
|
||||
@@ -700,6 +725,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
with_row_id=self._with_row_id,
|
||||
offset=self._offset,
|
||||
fast_search=self._fast_search,
|
||||
ef=self._ef,
|
||||
)
|
||||
result_set = self._table._execute_query(query, batch_size)
|
||||
if self._reranker is not None:
|
||||
@@ -1071,6 +1097,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._vector_query.nprobes(self._nprobes)
|
||||
if self._refine_factor:
|
||||
self._vector_query.refine_factor(self._refine_factor)
|
||||
if self._ef:
|
||||
self._vector_query.ef(self._ef)
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
|
||||
@@ -1197,6 +1225,29 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._nprobes = nprobes
|
||||
return self
|
||||
|
||||
def ef(self, ef: int) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Set the number of candidates to consider during search.
|
||||
|
||||
Higher values will yield better recall (more likely to find vectors if
|
||||
they exist) at the expense of latency.
|
||||
|
||||
This only applies to the HNSW-related index.
|
||||
The default value is 1.5 * limit.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ef: int
|
||||
The number of candidates to consider during search.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceHybridQueryBuilder
|
||||
The LanceHybridQueryBuilder object.
|
||||
"""
|
||||
self._ef = ef
|
||||
return self
|
||||
|
||||
def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceHybridQueryBuilder:
|
||||
"""Set the distance metric to use.
|
||||
|
||||
@@ -1644,6 +1695,21 @@ class AsyncVectorQuery(AsyncQueryBase):
|
||||
self._inner.nprobes(nprobes)
|
||||
return self
|
||||
|
||||
def ef(self, ef: int) -> AsyncVectorQuery:
|
||||
"""
|
||||
Set the number of candidates to consider during search
|
||||
|
||||
This argument is only used when the vector column has an HNSW index.
|
||||
If there is no index then this value is ignored.
|
||||
|
||||
Increasing this value will increase the recall of your query but will also
|
||||
increase the latency of your query. The default value is 1.5 * limit. This
|
||||
default is good for many cases but the best value to use will depend on your
|
||||
data and the recall that you need to achieve.
|
||||
"""
|
||||
self._inner.ef(ef)
|
||||
return self
|
||||
|
||||
def refine_factor(self, refine_factor: int) -> AsyncVectorQuery:
|
||||
"""
|
||||
A multiplier to control how many additional rows are taken during the refine
|
||||
|
||||
@@ -1959,6 +1959,7 @@ class LanceTable(Table):
|
||||
"metric": query.metric,
|
||||
"nprobes": query.nprobes,
|
||||
"refine_factor": query.refine_factor,
|
||||
"ef": query.ef,
|
||||
}
|
||||
return ds.scanner(
|
||||
columns=query.columns,
|
||||
@@ -2736,6 +2737,8 @@ class AsyncTable:
|
||||
async_query = async_query.refine_factor(query.refine_factor)
|
||||
if query.vector_column:
|
||||
async_query = async_query.column(query.vector_column)
|
||||
if query.ef:
|
||||
async_query = async_query.ef(query.ef)
|
||||
|
||||
if not query.prefilter:
|
||||
async_query = async_query.postfilter()
|
||||
|
||||
@@ -185,6 +185,7 @@ def test_query_sync_minimal():
|
||||
"k": 10,
|
||||
"prefilter": False,
|
||||
"refine_factor": None,
|
||||
"ef": None,
|
||||
"vector": [1.0, 2.0, 3.0],
|
||||
"nprobes": 20,
|
||||
}
|
||||
@@ -223,6 +224,7 @@ def test_query_sync_maximal():
|
||||
"refine_factor": 10,
|
||||
"vector": [1.0, 2.0, 3.0],
|
||||
"nprobes": 5,
|
||||
"ef": None,
|
||||
"filter": "id > 0",
|
||||
"columns": ["id", "name"],
|
||||
"vector_column": "vector2",
|
||||
@@ -318,6 +320,7 @@ def test_query_sync_hybrid():
|
||||
"refine_factor": None,
|
||||
"vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
"nprobes": 20,
|
||||
"ef": None,
|
||||
"with_row_id": True,
|
||||
}
|
||||
return pa.table({"_rowid": [1, 2, 3], "_distance": [0.1, 0.2, 0.3]})
|
||||
|
||||
@@ -195,6 +195,10 @@ impl VectorQuery {
|
||||
self.inner = self.inner.clone().nprobes(nprobe as usize);
|
||||
}
|
||||
|
||||
pub fn ef(&mut self, ef: u32) {
|
||||
self.inner = self.inner.clone().ef(ef as usize);
|
||||
}
|
||||
|
||||
pub fn bypass_vector_index(&mut self) {
|
||||
self.inner = self.inner.clone().bypass_vector_index()
|
||||
}
|
||||
|
||||
@@ -704,6 +704,9 @@ pub struct VectorQuery {
|
||||
// IVF PQ - ANN search.
|
||||
pub(crate) query_vector: Vec<Arc<dyn Array>>,
|
||||
pub(crate) nprobes: usize,
|
||||
// The number of candidates to consider during the search for HNSW,
|
||||
// defaults to 1.5 * limit.
|
||||
pub(crate) ef: Option<usize>,
|
||||
pub(crate) refine_factor: Option<u32>,
|
||||
pub(crate) distance_type: Option<DistanceType>,
|
||||
/// Default is true. Set to false to enforce a brute force search.
|
||||
@@ -717,6 +720,7 @@ impl VectorQuery {
|
||||
column: None,
|
||||
query_vector: Vec::new(),
|
||||
nprobes: 20,
|
||||
ef: None,
|
||||
refine_factor: None,
|
||||
distance_type: None,
|
||||
use_index: true,
|
||||
@@ -776,6 +780,18 @@ impl VectorQuery {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the number of candidates to consider during the search for HNSW
|
||||
///
|
||||
/// This argument is only used when the vector column has an HNSW index.
|
||||
/// If there is no index then this value is ignored.
|
||||
///
|
||||
/// Increasing this value will increase the recall of your query but will
|
||||
/// also increase the latency of your query. The default value is 1.5*limit.
|
||||
pub fn ef(mut self, ef: usize) -> Self {
|
||||
self.ef = Some(ef);
|
||||
self
|
||||
}
|
||||
|
||||
/// A multiplier to control how many additional rows are taken during the refine step
|
||||
///
|
||||
/// This argument is only used when the vector column has an IVF PQ index.
|
||||
|
||||
@@ -196,6 +196,7 @@ impl<S: HttpSend> RemoteTable<S> {
|
||||
body["prefilter"] = query.base.prefilter.into();
|
||||
body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
|
||||
body["nprobes"] = query.nprobes.into();
|
||||
body["ef"] = query.ef.into();
|
||||
body["refine_factor"] = query.refine_factor.into();
|
||||
if let Some(vector_column) = query.column.as_ref() {
|
||||
body["vector_column"] = serde_json::Value::String(vector_column.clone());
|
||||
@@ -1121,6 +1122,7 @@ mod tests {
|
||||
"prefilter": true,
|
||||
"distance_type": "l2",
|
||||
"nprobes": 20,
|
||||
"ef": Option::<usize>::None,
|
||||
"refine_factor": null,
|
||||
});
|
||||
// Pass vector separately to make sure it matches f32 precision.
|
||||
@@ -1166,6 +1168,7 @@ mod tests {
|
||||
"bypass_vector_index": true,
|
||||
"columns": ["a", "b"],
|
||||
"nprobes": 12,
|
||||
"ef": Option::<usize>::None,
|
||||
"refine_factor": 2,
|
||||
});
|
||||
// Pass vector separately to make sure it matches f32 precision.
|
||||
|
||||
@@ -1904,6 +1904,9 @@ impl TableInternal for NativeTable {
|
||||
query.base.offset.map(|offset| offset as i64),
|
||||
)?;
|
||||
scanner.nprobs(query.nprobes);
|
||||
if let Some(ef) = query.ef {
|
||||
scanner.ef(ef);
|
||||
}
|
||||
scanner.use_index(query.use_index);
|
||||
scanner.prefilter(query.base.prefilter);
|
||||
match query.base.select {
|
||||
|
||||
Reference in New Issue
Block a user