From 7ff72022ddf33639abe7a3a9a928a2987bd5ccfd Mon Sep 17 00:00:00 2001 From: Yang Cen Date: Wed, 17 Jun 2026 19:28:36 +0800 Subject: [PATCH] feat(query): add approx mode to vector queries (#3549) ## Feature ### What is the new feature? Adds Rust core API support for configuring vector query approximation mode with `ApproxMode::{Fast, Normal, Accurate}`. ### Why do we need this feature? Lance already exposes `lance_index::vector::ApproxMode` and scanner support for controlling the speed/accuracy tradeoff for approximate vector search. LanceDB Rust queries need to expose and pass this setting through for local/native and remote vector searches. ### How does it work? - Adds public `ApproxMode` in `rust/lancedb`, with lowercase serde, `Default::Normal`, parse/display, and conversions to/from Lance's `ApproxMode`. - Adds `approx_mode: Option` to `VectorQueryRequest` and a `VectorQuery::approx_mode(...)` builder. - Applies the mode to native/local Lance scanners after `nearest(...)` when explicitly set. - Sends `approx_mode` in remote query JSON only when explicitly set; default requests omit it. ## Validation - `cargo fmt --all` - `cargo test --quiet --features remote approx_mode` - `cargo test --quiet --features remote test_query_vector_default_values` - `cargo check --quiet --features remote --tests --examples` - `git diff --check` --- rust/lancedb/src/lib.rs | 76 ++++++++++- rust/lancedb/src/query.rs | 55 +++++++- rust/lancedb/src/remote/table.rs | 58 +++++++++ rust/lancedb/src/table/query.rs | 211 ++++++++++++++++++++++++++++++- 4 files changed, 391 insertions(+), 9 deletions(-) diff --git a/rust/lancedb/src/lib.rs b/rust/lancedb/src/lib.rs index 0c279d52b..e1fa0ec40 100644 --- a/rust/lancedb/src/lib.rs +++ b/rust/lancedb/src/lib.rs @@ -184,12 +184,13 @@ pub mod table; pub mod test_utils; pub mod utils; -use std::fmt::Display; +use std::{fmt::Display, str::FromStr}; use serde::{Deserialize, Serialize}; pub use connection::{ConnectNamespaceBuilder, Connection}; pub use error::{Error, Result}; +use lance_index::vector::ApproxMode as LanceApproxMode; use lance_linalg::distance::DistanceType as LanceDistanceType; pub use table::Table; @@ -258,6 +259,79 @@ impl Display for DistanceType { } } +/// Controls the speed / accuracy tradeoff for approximate vector search. +/// +/// This currently only affects RQ-quantized vector indexes, such as IVF_RQ. +/// Other index types ignore this setting. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] +#[non_exhaustive] +#[serde(rename_all = "lowercase")] +pub enum ApproxMode { + /// Prefer lower query latency, which can reduce recall. + Fast, + /// Use the default balance between query latency and recall. + #[default] + Normal, + /// Prefer higher recall, which can increase query latency. + Accurate, +} + +impl From for LanceApproxMode { + fn from(value: ApproxMode) -> Self { + match value { + ApproxMode::Fast => Self::Fast, + ApproxMode::Normal => Self::Normal, + ApproxMode::Accurate => Self::Accurate, + } + } +} + +impl From for ApproxMode { + fn from(value: LanceApproxMode) -> Self { + match value { + LanceApproxMode::Fast => Self::Fast, + LanceApproxMode::Normal => Self::Normal, + LanceApproxMode::Accurate => Self::Accurate, + } + } +} + +impl TryFrom<&str> for ApproxMode { + type Error = Error; + + fn try_from(value: &str) -> std::prelude::v1::Result { + Self::from_str(value) + } +} + +impl FromStr for ApproxMode { + type Err = Error; + + fn from_str(value: &str) -> std::prelude::v1::Result { + match value.to_ascii_lowercase().as_str() { + "fast" => Ok(Self::Fast), + "normal" => Ok(Self::Normal), + "accurate" => Ok(Self::Accurate), + _ => Err(Error::InvalidInput { + message: format!( + "approx_mode must be one of 'fast', 'normal', or 'accurate', got '{}'", + value + ), + }), + } + } +} + +impl Display for ApproxMode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Fast => write!(f, "fast"), + Self::Normal => write!(f, "normal"), + Self::Accurate => write!(f, "accurate"), + } + } +} + /// Connect to a database pub use connection::connect; /// Connect to a namespace-backed database diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index 7f82f5517..893ea7f9b 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -20,12 +20,12 @@ use lance_index::scalar::FullTextSearchQuery; use lance_index::scalar::inverted::SCORE_COL; use lance_index::vector::DIST_COL; -use crate::DistanceType; use crate::error::{Error, Result}; use crate::rerankers::rrf::RRFReranker; use crate::rerankers::{NormalizeMethod, Reranker, check_reranker_result}; use crate::table::BaseTable; use crate::utils::{MaxBatchLengthStream, TimeoutStream}; +use crate::{ApproxMode, DistanceType}; use crate::{ arrow::{SendableRecordBatchStream, SimpleRecordBatchStream}, table::AnyQuery, @@ -935,6 +935,8 @@ pub struct VectorQueryRequest { pub refine_factor: Option, /// The distance type to use for the search pub distance_type: Option, + /// The speed / accuracy tradeoff to use for approximate vector search + pub approx_mode: Option, /// Default is true. Set to false to enforce a brute force search. pub use_index: bool, } @@ -952,6 +954,7 @@ impl Default for VectorQueryRequest { ef: None, refine_factor: None, distance_type: None, + approx_mode: None, use_index: true, } } @@ -1192,6 +1195,15 @@ impl VectorQuery { self } + /// Set the speed / accuracy tradeoff for approximate vector search. + /// + /// This setting is currently only used by RQ-quantized indexes, such as + /// IVF_RQ. Other index types ignore this setting. + pub fn approx_mode(mut self, approx_mode: ApproxMode) -> Self { + self.request.approx_mode = Some(approx_mode); + self + } + /// If this is called then any vector index is skipped /// /// An exhaustive (flat) search will be performed. The query vector will @@ -1546,6 +1558,7 @@ mod tests { .nprobes(1000) .postfilter() .distance_type(DistanceType::Cosine) + .approx_mode(ApproxMode::Accurate) .refine_factor(999); assert_eq!( @@ -1564,9 +1577,49 @@ mod tests { assert_eq!(query.request.maximum_nprobes, Some(1000)); assert!(query.request.use_index); assert_eq!(query.request.distance_type, Some(DistanceType::Cosine)); + assert_eq!(query.request.approx_mode, Some(ApproxMode::Accurate)); assert_eq!(query.request.refine_factor, Some(999)); } + #[test] + fn test_approx_mode_serde_parse_default_and_display() { + assert_eq!(ApproxMode::default(), ApproxMode::Normal); + assert_eq!( + serde_json::to_string(&ApproxMode::Fast).unwrap(), + "\"fast\"" + ); + assert_eq!( + serde_json::from_str::("\"accurate\"").unwrap(), + ApproxMode::Accurate + ); + assert_eq!("normal".parse::().unwrap(), ApproxMode::Normal); + assert_eq!(ApproxMode::try_from("FAST").unwrap(), ApproxMode::Fast); + assert_eq!(ApproxMode::Accurate.to_string(), "accurate"); + assert!(ApproxMode::try_from("invalid").is_err()); + } + + #[tokio::test] + async fn test_vector_query_approx_mode_builder() { + let tmp_dir = tempdir().unwrap(); + let dataset_path = tmp_dir.path().join("test.lance"); + let uri = dataset_path.to_str().unwrap(); + + let conn = connect(uri).execute().await.unwrap(); + let table = conn + .create_table("my_table", make_test_batches()) + .execute() + .await + .unwrap(); + + let query = table + .query() + .nearest_to(&[0.1, 0.2]) + .unwrap() + .approx_mode(ApproxMode::Fast); + + assert_eq!(query.request.approx_mode, Some(ApproxMode::Fast)); + } + #[tokio::test] async fn test_execute() { // TODO: Switch back to memory://foo after https://github.com/lancedb/lancedb/issues/1051 diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index f11f13957..0e016cc4c 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -706,6 +706,9 @@ impl RemoteTable { if let Some(distance_type) = query.distance_type { body["distance_type"] = serde_json::json!(distance_type); } + if let Some(approx_mode) = query.approx_mode { + body["approx_mode"] = serde_json::json!(approx_mode); + } // In 0.23.1 we migrated from `nprobes` to `minimum_nprobes` and `maximum_nprobes`. // Old client / new server: since minimum_nprobes is missing, fallback to nprobes // New client / old server: old server will only see nprobes, make sure to set both @@ -3610,6 +3613,61 @@ mod tests { assert_eq!(data[0].as_ref().unwrap(), &expected_data); } + #[tokio::test] + async fn test_query_vector_approx_mode_sent_when_set() { + let expected_data = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let expected_data_ref = expected_data.clone(); + + let table = Table::new_with_handler("my_table", move |request| { + assert_eq!(request.method(), "POST"); + assert_eq!(request.url().path(), "/v1/table/my_table/query/"); + assert_eq!( + request.headers().get("Content-Type").unwrap(), + JSON_CONTENT_TYPE + ); + + let body = request.body().unwrap().as_bytes().unwrap(); + let body: serde_json::Value = serde_json::from_slice(body).unwrap(); + let mut expected_body = serde_json::json!({ + "prefilter": true, + "nprobes": 20, + "minimum_nprobes": 20, + "maximum_nprobes": 20, + "approx_mode": "accurate", + "lower_bound": Option::::None, + "upper_bound": Option::::None, + "k": 10, + "ef": Option::::None, + "refine_factor": null, + "version": null, + }); + expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into(); + assert_eq!(body, expected_body); + + let response_body = write_ipc_file(&expected_data_ref); + http::Response::builder() + .status(200) + .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE) + .body(response_body) + .unwrap() + }); + + let data = table + .query() + .nearest_to(vec![0.1, 0.2, 0.3]) + .unwrap() + .approx_mode(crate::ApproxMode::Accurate) + .execute() + .await; + let data = data.unwrap().collect::>().await; + assert_eq!(data.len(), 1); + assert_eq!(data[0].as_ref().unwrap(), &expected_data); + } + #[tokio::test] async fn test_query_fts_default_values() { let expected_data = RecordBatch::try_new( diff --git a/rust/lancedb/src/table/query.rs b/rust/lancedb/src/table/query.rs index b136de2cd..04961f101 100644 --- a/rust/lancedb/src/table/query.rs +++ b/rust/lancedb/src/table/query.rs @@ -44,17 +44,35 @@ pub async fn execute_query( // QueryTable pushdown runs the query server-side, but only on the main // branch: the namespace request carries no branch yet, so a branch handle // must fall through to local execution. - if table - .pushdown_operations - .contains(&NamespaceClientPushdownOperation::QueryTable) + if can_execute_namespace_query(table, query) && let Some(ref namespace_client) = table.namespace_client - && table.dataset.current_branch().is_none() { return execute_namespace_query(table, namespace_client.clone(), query, options).await; } execute_generic_query(table, query, options).await } +fn can_execute_namespace_query(table: &NativeTable, query: &AnyQuery) -> bool { + table + .pushdown_operations + .contains(&NamespaceClientPushdownOperation::QueryTable) + && table.namespace_client.is_some() + && table.dataset.current_branch().is_none() + && !requires_local_namespace_execution(query) +} + +fn requires_local_namespace_execution(query: &AnyQuery) -> bool { + // The namespace QueryTable request has no approx_mode field yet, so + // pushing this query down would silently ignore the user's setting. + matches!( + query, + AnyQuery::VectorQuery(VectorQueryRequest { + approx_mode: Some(_), + .. + }) + ) +} + pub async fn analyze_query_plan( table: &NativeTable, query: &AnyQuery, @@ -167,6 +185,10 @@ pub async fn create_plan( scanner.nearest(&column, query_vector.as_ref(), top_k)?; } + if let Some(approx_mode) = query.approx_mode { + scanner.approx_mode(approx_mode.into()); + } + scanner.minimum_nprobes(query.minimum_nprobes); if let Some(maximum_nprobes) = query.maximum_nprobes { scanner.maximum_nprobes(maximum_nprobes); @@ -587,12 +609,20 @@ async fn parse_arrow_ipc_response(bytes: bytes::Bytes) -> Result, dimension: i32) -> FixedSizeListArray { + FixedSizeListArray::try_new_from_values(Float32Array::from(values), dimension).unwrap() + } #[test] fn test_convert_to_namespace_query_vector() { @@ -715,6 +745,80 @@ mod tests { assert_eq!(count, 2); // 4 and 5 } + #[derive(Debug, Default)] + struct CountingNamespaceClient { + query_table_calls: AtomicUsize, + } + + #[async_trait::async_trait] + impl LanceNamespace for CountingNamespaceClient { + fn namespace_id(&self) -> String { + "counting".to_string() + } + + async fn query_table(&self, _request: NsQueryTableRequest) -> lance::Result { + self.query_table_calls.fetch_add(1, Ordering::SeqCst); + panic!("approx_mode queries must not be pushed down to namespace query_table"); + } + } + + #[tokio::test] + async fn test_execute_query_approx_mode_with_namespace_pushdown_runs_locally() { + use crate::connect; + use crate::table::query::execute_query; + use arrow_array::{Int32Array, RecordBatch}; + use arrow_schema::{DataType, Field, Schema}; + + let conn = connect("memory://").execute().await.unwrap(); + + let vectors = Arc::new(fixed_size_list_array( + vec![0.0, 0.0, 10.0, 10.0, 20.0, 20.0], + 2, + )); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("vector", vectors.data_type().clone(), false), + ])); + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(Int32Array::from(vec![1, 2, 3])), vectors], + ) + .unwrap(); + + let table = conn + .create_table("test_approx_mode_namespace_fallback", batch) + .execute() + .await + .unwrap(); + let namespace_client = Arc::new(CountingNamespaceClient::default()); + let mut native_table = table.as_native().unwrap().clone(); + native_table.namespace_client = Some(namespace_client.clone()); + native_table + .pushdown_operations + .insert(NamespaceClientPushdownOperation::QueryTable); + + let query_vector = Arc::new(Float32Array::from(vec![0.0, 0.0])); + let query = AnyQuery::VectorQuery(VectorQueryRequest { + base: QueryRequest { + limit: Some(1), + ..Default::default() + }, + column: Some("vector".to_string()), + query_vector: vec![query_vector as ArrayRef], + approx_mode: Some(crate::ApproxMode::Accurate), + ..Default::default() + }); + + let stream = execute_query(&native_table, &query, QueryExecutionOptions::default()) + .await + .unwrap(); + let batches = stream.try_collect::>().await.unwrap(); + let count: usize = batches.iter().map(|b| b.num_rows()).sum(); + + assert_eq!(count, 1); + assert_eq!(namespace_client.query_table_calls.load(Ordering::SeqCst), 0); + } + #[tokio::test] async fn test_create_plan_multivector_structure() { use arrow_array::{Float32Array, RecordBatch}; @@ -779,4 +883,97 @@ mod tests { "Plan should add query_index column" ); } + + #[tokio::test] + async fn test_create_plan_applies_approx_mode_to_ann_query() { + use arrow_array::RecordBatch; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_physical_plan::ExecutionPlan; + use lance::io::exec::{ANNIvfPartitionExec, ANNIvfSubIndexExec}; + use lance_index::vector::ApproxMode; + + use crate::connect; + use crate::index::{Index, vector::IvfRqIndexBuilder}; + use crate::table::query::create_plan; + + fn find_ann_approx_mode(plan: &dyn ExecutionPlan) -> Option { + if let Some(ann) = plan.as_any().downcast_ref::() { + return Some(ann.query().approx_mode); + } + if let Some(ann) = plan.as_any().downcast_ref::() { + return Some(ann.query.approx_mode); + } + plan.children() + .into_iter() + .find_map(|child| find_ann_approx_mode(child.as_ref())) + } + + let conn = connect("memory://").execute().await.unwrap(); + let dimension = 8; + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new( + "vector", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + dimension, + ), + false, + ), + ])); + + let vectors = Arc::new(fixed_size_list_array( + (0..512 * dimension) + .map(|value| value as f32 / dimension as f32) + .collect(), + dimension, + )); + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(arrow_array::Int32Array::from_iter_values(0..512)), + vectors, + ], + ) + .unwrap(); + let table = conn + .create_table("test_approx_mode_plan", vec![batch]) + .execute() + .await + .unwrap(); + table + .create_index( + &["vector"], + Index::IvfRq( + IvfRqIndexBuilder::default() + .num_partitions(1) + .sample_rate(1) + .max_iterations(1) + .num_bits(1), + ), + ) + .execute() + .await + .unwrap(); + let native_table = table.as_native().unwrap(); + let query_vector = Arc::new(Float32Array::from(vec![0.0; dimension as usize])); + let query = AnyQuery::VectorQuery(VectorQueryRequest { + column: Some("vector".to_string()), + query_vector: vec![query_vector as ArrayRef], + base: QueryRequest { + limit: Some(1), + ..Default::default() + }, + approx_mode: Some(crate::ApproxMode::Accurate), + ..Default::default() + }); + + let plan = create_plan(native_table, &query, QueryExecutionOptions::default()) + .await + .unwrap(); + assert_eq!( + find_ann_approx_mode(plan.as_ref()), + Some(ApproxMode::Accurate) + ); + } }