feat: support to sepcify ef search param (#1844)

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
BubbleCal
2024-11-19 23:12:25 +08:00
committed by GitHub
parent f2e3989831
commit b2f88f0b29
10 changed files with 165 additions and 0 deletions

View File

@@ -477,6 +477,54 @@ describe("When creating an index", () => {
expect(rst.numRows).toBe(1);
});
it("should create and search IVF_HNSW indices", async () => {
await tbl.createIndex("vec", {
config: Index.hnswSq(),
});
// check index directory
const indexDir = path.join(tmpDir.name, "test.lance", "_indices");
expect(fs.readdirSync(indexDir)).toHaveLength(1);
const indices = await tbl.listIndices();
expect(indices.length).toBe(1);
expect(indices[0]).toEqual({
name: "vec_idx",
indexType: "IvfHnswSq",
columns: ["vec"],
});
// Search without specifying the column
let rst = await tbl
.query()
.limit(2)
.nearestTo(queryVec)
.distanceType("dot")
.toArrow();
expect(rst.numRows).toBe(2);
// Search using `vectorSearch`
rst = await tbl.vectorSearch(queryVec).limit(2).toArrow();
expect(rst.numRows).toBe(2);
// Search with specifying the column
const rst2 = await tbl
.query()
.limit(2)
.nearestTo(queryVec)
.column("vec")
.toArrow();
expect(rst2.numRows).toBe(2);
expect(rst.toString()).toEqual(rst2.toString());
// test offset
rst = await tbl.query().limit(2).offset(1).nearestTo(queryVec).toArrow();
expect(rst.numRows).toBe(1);
// test ef
rst = await tbl.query().limit(2).nearestTo(queryVec).ef(100).toArrow();
expect(rst.numRows).toBe(2);
});
it("should be able to query unindexed data", async () => {
await tbl.createIndex("vec");
await tbl.add([

View File

@@ -385,6 +385,20 @@ export class VectorQuery extends QueryBase<NativeVectorQuery> {
return this;
}
/**
* Set the number of candidates to consider during the search
*
* This argument is only used when the vector column has an HNSW index.
* If there is no index then this value is ignored.
*
* Increasing this value will increase the recall of your query but will
* also increase the latency of your query. The default value is 1.5*limit.
*/
ef(ef: number): VectorQuery {
super.doCall((inner) => inner.ef(ef));
return this;
}
/**
* Set the vector column to query
*

View File

@@ -167,6 +167,11 @@ impl VectorQuery {
self.inner = self.inner.clone().nprobes(nprobe as usize);
}
#[napi]
pub fn ef(&mut self, ef: u32) {
self.inner = self.inner.clone().ef(ef as usize);
}
#[napi]
pub fn bypass_vector_index(&mut self) {
self.inner = self.inner.clone().bypass_vector_index()

View File

@@ -131,6 +131,8 @@ class Query(pydantic.BaseModel):
fast_search: bool = False
ef: Optional[int] = None
class LanceQueryBuilder(ABC):
"""An abstract query builder. Subclasses are defined for vector search,
@@ -257,6 +259,7 @@ class LanceQueryBuilder(ABC):
self._with_row_id = False
self._vector = None
self._text = None
self._ef = None
@deprecation.deprecated(
deprecated_in="0.3.1",
@@ -638,6 +641,28 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._nprobes = nprobes
return self
def ef(self, ef: int) -> LanceVectorQueryBuilder:
"""Set the number of candidates to consider during search.
Higher values will yield better recall (more likely to find vectors if
they exist) at the expense of latency.
This only applies to the HNSW-related index.
The default value is 1.5 * limit.
Parameters
----------
ef: int
The number of candidates to consider during search.
Returns
-------
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._ef = ef
return self
def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder:
"""Set the refine factor to use, increasing the number of vectors sampled.
@@ -700,6 +725,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
with_row_id=self._with_row_id,
offset=self._offset,
fast_search=self._fast_search,
ef=self._ef,
)
result_set = self._table._execute_query(query, batch_size)
if self._reranker is not None:
@@ -1071,6 +1097,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._vector_query.nprobes(self._nprobes)
if self._refine_factor:
self._vector_query.refine_factor(self._refine_factor)
if self._ef:
self._vector_query.ef(self._ef)
with ThreadPoolExecutor() as executor:
fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
@@ -1197,6 +1225,29 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._nprobes = nprobes
return self
def ef(self, ef: int) -> LanceHybridQueryBuilder:
"""
Set the number of candidates to consider during search.
Higher values will yield better recall (more likely to find vectors if
they exist) at the expense of latency.
This only applies to the HNSW-related index.
The default value is 1.5 * limit.
Parameters
----------
ef: int
The number of candidates to consider during search.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._ef = ef
return self
def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceHybridQueryBuilder:
"""Set the distance metric to use.
@@ -1644,6 +1695,21 @@ class AsyncVectorQuery(AsyncQueryBase):
self._inner.nprobes(nprobes)
return self
def ef(self, ef: int) -> AsyncVectorQuery:
"""
Set the number of candidates to consider during search
This argument is only used when the vector column has an HNSW index.
If there is no index then this value is ignored.
Increasing this value will increase the recall of your query but will also
increase the latency of your query. The default value is 1.5 * limit. This
default is good for many cases but the best value to use will depend on your
data and the recall that you need to achieve.
"""
self._inner.ef(ef)
return self
def refine_factor(self, refine_factor: int) -> AsyncVectorQuery:
"""
A multiplier to control how many additional rows are taken during the refine

View File

@@ -1959,6 +1959,7 @@ class LanceTable(Table):
"metric": query.metric,
"nprobes": query.nprobes,
"refine_factor": query.refine_factor,
"ef": query.ef,
}
return ds.scanner(
columns=query.columns,
@@ -2736,6 +2737,8 @@ class AsyncTable:
async_query = async_query.refine_factor(query.refine_factor)
if query.vector_column:
async_query = async_query.column(query.vector_column)
if query.ef:
async_query = async_query.ef(query.ef)
if not query.prefilter:
async_query = async_query.postfilter()

View File

@@ -185,6 +185,7 @@ def test_query_sync_minimal():
"k": 10,
"prefilter": False,
"refine_factor": None,
"ef": None,
"vector": [1.0, 2.0, 3.0],
"nprobes": 20,
}
@@ -223,6 +224,7 @@ def test_query_sync_maximal():
"refine_factor": 10,
"vector": [1.0, 2.0, 3.0],
"nprobes": 5,
"ef": None,
"filter": "id > 0",
"columns": ["id", "name"],
"vector_column": "vector2",
@@ -318,6 +320,7 @@ def test_query_sync_hybrid():
"refine_factor": None,
"vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
"nprobes": 20,
"ef": None,
"with_row_id": True,
}
return pa.table({"_rowid": [1, 2, 3], "_distance": [0.1, 0.2, 0.3]})

View File

@@ -195,6 +195,10 @@ impl VectorQuery {
self.inner = self.inner.clone().nprobes(nprobe as usize);
}
pub fn ef(&mut self, ef: u32) {
self.inner = self.inner.clone().ef(ef as usize);
}
pub fn bypass_vector_index(&mut self) {
self.inner = self.inner.clone().bypass_vector_index()
}

View File

@@ -704,6 +704,9 @@ pub struct VectorQuery {
// IVF PQ - ANN search.
pub(crate) query_vector: Vec<Arc<dyn Array>>,
pub(crate) nprobes: usize,
// The number of candidates to return during the refine step for HNSW,
// defaults to 1.5 * limit.
pub(crate) ef: Option<usize>,
pub(crate) refine_factor: Option<u32>,
pub(crate) distance_type: Option<DistanceType>,
/// Default is true. Set to false to enforce a brute force search.
@@ -717,6 +720,7 @@ impl VectorQuery {
column: None,
query_vector: Vec::new(),
nprobes: 20,
ef: None,
refine_factor: None,
distance_type: None,
use_index: true,
@@ -776,6 +780,18 @@ impl VectorQuery {
self
}
/// Set the number of candidates to return during the refine step for HNSW
///
/// This argument is only used when the vector column has an HNSW index.
/// If there is no index then this value is ignored.
///
/// Increasing this value will increase the recall of your query but will
/// also increase the latency of your query. The default value is 1.5*limit.
pub fn ef(mut self, ef: usize) -> Self {
self.ef = Some(ef);
self
}
/// A multiplier to control how many additional rows are taken during the refine step
///
/// This argument is only used when the vector column has an IVF PQ index.

View File

@@ -196,6 +196,7 @@ impl<S: HttpSend> RemoteTable<S> {
body["prefilter"] = query.base.prefilter.into();
body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
body["nprobes"] = query.nprobes.into();
body["ef"] = query.ef.into();
body["refine_factor"] = query.refine_factor.into();
if let Some(vector_column) = query.column.as_ref() {
body["vector_column"] = serde_json::Value::String(vector_column.clone());
@@ -1121,6 +1122,7 @@ mod tests {
"prefilter": true,
"distance_type": "l2",
"nprobes": 20,
"ef": Option::<usize>::None,
"refine_factor": null,
});
// Pass vector separately to make sure it matches f32 precision.
@@ -1166,6 +1168,7 @@ mod tests {
"bypass_vector_index": true,
"columns": ["a", "b"],
"nprobes": 12,
"ef": Option::<usize>::None,
"refine_factor": 2,
});
// Pass vector separately to make sure it matches f32 precision.

View File

@@ -1904,6 +1904,9 @@ impl TableInternal for NativeTable {
query.base.offset.map(|offset| offset as i64),
)?;
scanner.nprobs(query.nprobes);
if let Some(ef) = query.ef {
scanner.ef(ef);
}
scanner.use_index(query.use_index);
scanner.prefilter(query.base.prefilter);
match query.base.select {