From b2f88f0b29d4035add2d5dd0aba2068a0f0d5705 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Tue, 19 Nov 2024 23:12:25 +0800 Subject: [PATCH] feat: support to specify ef search param (#1844) Signed-off-by: BubbleCal --- nodejs/__test__/table.test.ts | 48 +++++++++++++++++++ nodejs/lancedb/query.ts | 14 ++++++ nodejs/src/query.rs | 5 ++ python/python/lancedb/query.py | 66 +++++++++++++++++++++++++++ python/python/lancedb/table.py | 3 ++ python/python/tests/test_remote_db.py | 3 ++ python/src/query.rs | 4 ++ rust/lancedb/src/query.rs | 16 +++++++ rust/lancedb/src/remote/table.rs | 3 ++ rust/lancedb/src/table.rs | 3 ++ 10 files changed, 165 insertions(+) diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 39289c0d..6e002361 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -477,6 +477,54 @@ describe("When creating an index", () => { expect(rst.numRows).toBe(1); }); + it("should create and search IVF_HNSW indices", async () => { + await tbl.createIndex("vec", { + config: Index.hnswSq(), + }); + + // check index directory + const indexDir = path.join(tmpDir.name, "test.lance", "_indices"); + expect(fs.readdirSync(indexDir)).toHaveLength(1); + const indices = await tbl.listIndices(); + expect(indices.length).toBe(1); + expect(indices[0]).toEqual({ + name: "vec_idx", + indexType: "IvfHnswSq", + columns: ["vec"], + }); + + // Search without specifying the column + let rst = await tbl + .query() + .limit(2) + .nearestTo(queryVec) + .distanceType("dot") + .toArrow(); + expect(rst.numRows).toBe(2); + + // Search using `vectorSearch` + rst = await tbl.vectorSearch(queryVec).limit(2).toArrow(); + expect(rst.numRows).toBe(2); + + // Search with specifying the column + const rst2 = await tbl + .query() + .limit(2) + .nearestTo(queryVec) + .column("vec") + .toArrow(); + expect(rst2.numRows).toBe(2); + expect(rst.toString()).toEqual(rst2.toString()); + + // test offset + rst = await
tbl.query().limit(2).offset(1).nearestTo(queryVec).toArrow(); + expect(rst.numRows).toBe(1); + + // test ef + rst = await tbl.query().limit(2).nearestTo(queryVec).ef(100).toArrow(); + expect(rst.numRows).toBe(2); + }); + it("should be able to query unindexed data", async () => { await tbl.createIndex("vec"); await tbl.add([ diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts index d29babb3..25fabf70 100644 --- a/nodejs/lancedb/query.ts +++ b/nodejs/lancedb/query.ts @@ -385,6 +385,20 @@ export class VectorQuery extends QueryBase { return this; } + /** + * Set the number of candidates to consider during the search + * + * This argument is only used when the vector column has an HNSW index. + * If there is no index then this value is ignored. + * + * Increasing this value will increase the recall of your query but will + * also increase the latency of your query. The default value is 1.5*limit. + */ + ef(ef: number): VectorQuery { + super.doCall((inner) => inner.ef(ef)); + return this; + } + /** * Set the vector column to query * diff --git a/nodejs/src/query.rs b/nodejs/src/query.rs index 57eb24c4..fd8e3b48 100644 --- a/nodejs/src/query.rs +++ b/nodejs/src/query.rs @@ -167,6 +167,11 @@ impl VectorQuery { self.inner = self.inner.clone().nprobes(nprobe as usize); } + #[napi] + pub fn ef(&mut self, ef: u32) { + self.inner = self.inner.clone().ef(ef as usize); + } + #[napi] pub fn bypass_vector_index(&mut self) { self.inner = self.inner.clone().bypass_vector_index() diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index e9886f45..bee9f1ce 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -131,6 +131,8 @@ class Query(pydantic.BaseModel): fast_search: bool = False + ef: Optional[int] = None + class LanceQueryBuilder(ABC): """An abstract query builder. 
Subclasses are defined for vector search, @@ -257,6 +259,7 @@ class LanceQueryBuilder(ABC): self._with_row_id = False self._vector = None self._text = None + self._ef = None @deprecation.deprecated( deprecated_in="0.3.1", @@ -638,6 +641,28 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): self._nprobes = nprobes return self + def ef(self, ef: int) -> LanceVectorQueryBuilder: + """Set the number of candidates to consider during search. + + Higher values will yield better recall (more likely to find vectors if + they exist) at the expense of latency. + + This only applies to the HNSW-related index. + The default value is 1.5 * limit. + + Parameters + ---------- + ef: int + The number of candidates to consider during search. + + Returns + ------- + LanceVectorQueryBuilder + The LanceQueryBuilder object. + """ + self._ef = ef + return self + def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder: """Set the refine factor to use, increasing the number of vectors sampled. @@ -700,6 +725,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): with_row_id=self._with_row_id, offset=self._offset, fast_search=self._fast_search, + ef=self._ef, ) result_set = self._table._execute_query(query, batch_size) if self._reranker is not None: @@ -1071,6 +1097,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): self._vector_query.nprobes(self._nprobes) if self._refine_factor: self._vector_query.refine_factor(self._refine_factor) + if self._ef: + self._vector_query.ef(self._ef) with ThreadPoolExecutor() as executor: fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow) @@ -1197,6 +1225,29 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): self._nprobes = nprobes return self + def ef(self, ef: int) -> LanceHybridQueryBuilder: + """ + Set the number of candidates to consider during search. + + Higher values will yield better recall (more likely to find vectors if + they exist) at the expense of latency. 
+ + This only applies to the HNSW-related index. + The default value is 1.5 * limit. + + Parameters + ---------- + ef: int + The number of candidates to consider during search. + + Returns + ------- + LanceHybridQueryBuilder + The LanceHybridQueryBuilder object. + """ + self._ef = ef + return self + def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceHybridQueryBuilder: """Set the distance metric to use. @@ -1644,6 +1695,21 @@ class AsyncVectorQuery(AsyncQueryBase): self._inner.nprobes(nprobes) return self + def ef(self, ef: int) -> AsyncVectorQuery: + """ + Set the number of candidates to consider during search + + This argument is only used when the vector column has an HNSW index. + If there is no index then this value is ignored. + + Increasing this value will increase the recall of your query but will also + increase the latency of your query. The default value is 1.5 * limit. This + default is good for many cases but the best value to use will depend on your + data and the recall that you need to achieve. 
+ """ + self._inner.ef(ef) + return self + def refine_factor(self, refine_factor: int) -> AsyncVectorQuery: """ A multiplier to control how many additional rows are taken during the refine diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index e2bd64c2..eee14dd9 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -1959,6 +1959,7 @@ class LanceTable(Table): "metric": query.metric, "nprobes": query.nprobes, "refine_factor": query.refine_factor, + "ef": query.ef, } return ds.scanner( columns=query.columns, @@ -2736,6 +2737,8 @@ class AsyncTable: async_query = async_query.refine_factor(query.refine_factor) if query.vector_column: async_query = async_query.column(query.vector_column) + if query.ef: + async_query = async_query.ef(query.ef) if not query.prefilter: async_query = async_query.postfilter() diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index 7ac26103..cd0691e8 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -185,6 +185,7 @@ def test_query_sync_minimal(): "k": 10, "prefilter": False, "refine_factor": None, + "ef": None, "vector": [1.0, 2.0, 3.0], "nprobes": 20, } @@ -223,6 +224,7 @@ def test_query_sync_maximal(): "refine_factor": 10, "vector": [1.0, 2.0, 3.0], "nprobes": 5, + "ef": None, "filter": "id > 0", "columns": ["id", "name"], "vector_column": "vector2", @@ -318,6 +320,7 @@ def test_query_sync_hybrid(): "refine_factor": None, "vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "nprobes": 20, + "ef": None, "with_row_id": True, } return pa.table({"_rowid": [1, 2, 3], "_distance": [0.1, 0.2, 0.3]}) diff --git a/python/src/query.rs b/python/src/query.rs index 8e4042df..5eeb4aa8 100644 --- a/python/src/query.rs +++ b/python/src/query.rs @@ -195,6 +195,10 @@ impl VectorQuery { self.inner = self.inner.clone().nprobes(nprobe as usize); } + pub fn ef(&mut self, ef: u32) { + self.inner = 
self.inner.clone().ef(ef as usize); + } + pub fn bypass_vector_index(&mut self) { self.inner = self.inner.clone().bypass_vector_index() } diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index db9bf311..54c344e1 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -704,6 +704,9 @@ pub struct VectorQuery { // IVF PQ - ANN search. pub(crate) query_vector: Vec<Arc<dyn Array>>, pub(crate) nprobes: usize, + // The number of candidates to consider during the search for HNSW, + // defaults to 1.5 * limit. + pub(crate) ef: Option<usize>, pub(crate) refine_factor: Option<u32>, pub(crate) distance_type: Option<DistanceType>, /// Default is true. Set to false to enforce a brute force search. @@ -717,6 +720,7 @@ impl VectorQuery { column: None, query_vector: Vec::new(), nprobes: 20, + ef: None, refine_factor: None, distance_type: None, use_index: true, @@ -776,6 +780,18 @@ impl VectorQuery { self } + /// Set the number of candidates to consider during the search for HNSW + /// + /// This argument is only used when the vector column has an HNSW index. + /// If there is no index then this value is ignored. + /// + /// Increasing this value will increase the recall of your query but will + /// also increase the latency of your query. The default value is 1.5*limit. + pub fn ef(mut self, ef: usize) -> Self { + self.ef = Some(ef); + self + } + /// A multiplier to control how many additional rows are taken during the refine step /// /// This argument is only used when the vector column has an IVF PQ index.
diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 55cabf95..30fb59e2 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -196,6 +196,7 @@ impl RemoteTable { body["prefilter"] = query.base.prefilter.into(); body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default()); body["nprobes"] = query.nprobes.into(); + body["ef"] = query.ef.into(); body["refine_factor"] = query.refine_factor.into(); if let Some(vector_column) = query.column.as_ref() { body["vector_column"] = serde_json::Value::String(vector_column.clone()); @@ -1121,6 +1122,7 @@ mod tests { "prefilter": true, "distance_type": "l2", "nprobes": 20, + "ef": Option::<usize>::None, "refine_factor": null, }); // Pass vector separately to make sure it matches f32 precision. @@ -1166,6 +1168,7 @@ mod tests { "bypass_vector_index": true, "columns": ["a", "b"], "nprobes": 12, + "ef": Option::<usize>::None, "refine_factor": 2, }); // Pass vector separately to make sure it matches f32 precision. diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 11415e52..8b4f9cee 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -1904,6 +1904,9 @@ impl TableInternal for NativeTable { query.base.offset.map(|offset| offset as i64), )?; scanner.nprobs(query.nprobes); + if let Some(ef) = query.ef { + scanner.ef(ef); + } scanner.use_index(query.use_index); scanner.prefilter(query.base.prefilter); match query.base.select {