diff --git a/rust/lancedb/src/lib.rs b/rust/lancedb/src/lib.rs
index 0c279d52b..469292afb 100644
--- a/rust/lancedb/src/lib.rs
+++ b/rust/lancedb/src/lib.rs
@@ -184,12 +184,13 @@ pub mod table;
 pub mod test_utils;
 pub mod utils;
 
-use std::fmt::Display;
+use std::{fmt::Display, str::FromStr};
 
 use serde::{Deserialize, Serialize};
 
 pub use connection::{ConnectNamespaceBuilder, Connection};
 pub use error::{Error, Result};
+use lance_index::vector::ApproxMode as LanceApproxMode;
 use lance_linalg::distance::DistanceType as LanceDistanceType;
 pub use table::Table;
 
@@ -258,6 +259,79 @@ impl Display for DistanceType {
     }
 }
 
+/// Controls the speed / accuracy tradeoff for approximate vector search.
+///
+/// This currently only affects RQ-quantized vector indexes, such as IVF_RQ.
+/// Other index types ignore this setting.
+#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
+#[non_exhaustive]
+#[serde(rename_all = "lowercase")]
+pub enum ApproxMode {
+    /// Prefer lower query latency, which can reduce recall.
+    Fast,
+    /// Use the default balance between query latency and recall.
+    #[default]
+    Normal,
+    /// Prefer higher recall, which can increase query latency.
+    Accurate,
+}
+
+impl From<ApproxMode> for LanceApproxMode {
+    fn from(value: ApproxMode) -> Self {
+        match value {
+            ApproxMode::Fast => Self::Fast,
+            ApproxMode::Normal => Self::Normal,
+            ApproxMode::Accurate => Self::Accurate,
+        }
+    }
+}
+
+impl From<LanceApproxMode> for ApproxMode {
+    fn from(value: LanceApproxMode) -> Self {
+        match value {
+            LanceApproxMode::Fast => Self::Fast,
+            LanceApproxMode::Normal => Self::Normal,
+            LanceApproxMode::Accurate => Self::Accurate,
+        }
+    }
+}
+
+impl<'a> TryFrom<&'a str> for ApproxMode {
+    type Error = Error;
+
+    fn try_from(value: &str) -> std::prelude::v1::Result<Self, Self::Error> {
+        Self::from_str(value)
+    }
+}
+
+impl FromStr for ApproxMode {
+    type Err = Error;
+
+    fn from_str(value: &str) -> std::prelude::v1::Result<Self, Self::Err> {
+        match value.to_ascii_lowercase().as_str() {
+            "fast" => Ok(Self::Fast),
+            "normal" => Ok(Self::Normal),
+            "accurate" => Ok(Self::Accurate),
+            _ => Err(Error::InvalidInput {
+                message: format!(
+                    "approx_mode must be one of 'fast', 'normal', or 'accurate', got '{}'",
+                    value
+                ),
+            }),
+        }
+    }
+}
+
+impl Display for ApproxMode {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::Fast => write!(f, "fast"),
+            Self::Normal => write!(f, "normal"),
+            Self::Accurate => write!(f, "accurate"),
+        }
+    }
+}
+
 /// Connect to a database
 pub use connection::connect;
 /// Connect to a namespace-backed database
diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs
index 7f82f5517..893ea7f9b 100644
--- a/rust/lancedb/src/query.rs
+++ b/rust/lancedb/src/query.rs
@@ -20,12 +20,12 @@ use lance_index::scalar::FullTextSearchQuery;
 use lance_index::scalar::inverted::SCORE_COL;
 use lance_index::vector::DIST_COL;
 
-use crate::DistanceType;
 use crate::error::{Error, Result};
 use crate::rerankers::rrf::RRFReranker;
 use crate::rerankers::{NormalizeMethod, Reranker, check_reranker_result};
 use crate::table::BaseTable;
 use crate::utils::{MaxBatchLengthStream, TimeoutStream};
+use crate::{ApproxMode, DistanceType};
 use crate::{
     arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
     table::AnyQuery,
@@ -935,6 +935,8 @@ pub struct VectorQueryRequest {
     pub refine_factor: Option<u32>,
     /// The distance type to use for the search
     pub distance_type: Option<DistanceType>,
+    /// The speed / accuracy tradeoff to use for approximate vector search
+    pub approx_mode: Option<ApproxMode>,
     /// Default is true. Set to false to enforce a brute force search.
     pub use_index: bool,
 }
@@ -952,6 +954,7 @@ impl Default for VectorQueryRequest {
             ef: None,
             refine_factor: None,
             distance_type: None,
+            approx_mode: None,
             use_index: true,
         }
     }
@@ -1192,6 +1195,15 @@ impl VectorQuery {
         self
     }
 
+    /// Set the speed / accuracy tradeoff for approximate vector search.
+    ///
+    /// This setting is currently only used by RQ-quantized indexes, such as
+    /// IVF_RQ. Other index types ignore this setting.
+    pub fn approx_mode(mut self, approx_mode: ApproxMode) -> Self {
+        self.request.approx_mode = Some(approx_mode);
+        self
+    }
+
     /// If this is called then any vector index is skipped
     ///
     /// An exhaustive (flat) search will be performed. The query vector will
@@ -1546,6 +1558,7 @@ mod tests {
             .nprobes(1000)
             .postfilter()
             .distance_type(DistanceType::Cosine)
+            .approx_mode(ApproxMode::Accurate)
             .refine_factor(999);
 
         assert_eq!(
@@ -1564,9 +1577,49 @@ mod tests {
         assert_eq!(query.request.maximum_nprobes, Some(1000));
         assert!(query.request.use_index);
         assert_eq!(query.request.distance_type, Some(DistanceType::Cosine));
+        assert_eq!(query.request.approx_mode, Some(ApproxMode::Accurate));
         assert_eq!(query.request.refine_factor, Some(999));
     }
 
+    #[test]
+    fn test_approx_mode_serde_parse_default_and_display() {
+        assert_eq!(ApproxMode::default(), ApproxMode::Normal);
+        assert_eq!(
+            serde_json::to_string(&ApproxMode::Fast).unwrap(),
+            "\"fast\""
+        );
+        assert_eq!(
+            serde_json::from_str::<ApproxMode>("\"accurate\"").unwrap(),
+            ApproxMode::Accurate
+        );
+        assert_eq!("normal".parse::<ApproxMode>().unwrap(), ApproxMode::Normal);
+        assert_eq!(ApproxMode::try_from("FAST").unwrap(), ApproxMode::Fast);
+        assert_eq!(ApproxMode::Accurate.to_string(), "accurate");
+        assert!(ApproxMode::try_from("invalid").is_err());
+    }
+
+    #[tokio::test]
+    async fn test_vector_query_approx_mode_builder() {
+        let tmp_dir = tempdir().unwrap();
+        let dataset_path = tmp_dir.path().join("test.lance");
+        let uri = dataset_path.to_str().unwrap();
+
+        let conn = connect(uri).execute().await.unwrap();
+        let table = conn
+            .create_table("my_table", make_test_batches())
+            .execute()
+            .await
+            .unwrap();
+
+        let query = table
+            .query()
+            .nearest_to(&[0.1, 0.2])
+            .unwrap()
+            .approx_mode(ApproxMode::Fast);
+
+        assert_eq!(query.request.approx_mode, Some(ApproxMode::Fast));
+    }
+
     #[tokio::test]
     async fn test_execute() {
         // TODO: Switch back to memory://foo after https://github.com/lancedb/lancedb/issues/1051
diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs
index f11f13957..0e016cc4c 100644
--- a/rust/lancedb/src/remote/table.rs
+++ b/rust/lancedb/src/remote/table.rs
@@ -706,6 +706,9 @@ impl<S: HttpSend> RemoteTable<S> {
         if let Some(distance_type) = query.distance_type {
             body["distance_type"] = serde_json::json!(distance_type);
         }
+        if let Some(approx_mode) = query.approx_mode {
+            body["approx_mode"] = serde_json::json!(approx_mode);
+        }
         // In 0.23.1 we migrated from `nprobes` to `minimum_nprobes` and `maximum_nprobes`.
         // Old client / new server: since minimum_nprobes is missing, fallback to nprobes
         // New client / old server: old server will only see nprobes, make sure to set both
@@ -3610,6 +3613,61 @@ mod tests {
         assert_eq!(data[0].as_ref().unwrap(), &expected_data);
     }
 
+    #[tokio::test]
+    async fn test_query_vector_approx_mode_sent_when_set() {
+        let expected_data = RecordBatch::try_new(
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
+            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
+        )
+        .unwrap();
+        let expected_data_ref = expected_data.clone();
+
+        let table = Table::new_with_handler("my_table", move |request| {
+            assert_eq!(request.method(), "POST");
+            assert_eq!(request.url().path(), "/v1/table/my_table/query/");
+            assert_eq!(
+                request.headers().get("Content-Type").unwrap(),
+                JSON_CONTENT_TYPE
+            );
+
+            let body = request.body().unwrap().as_bytes().unwrap();
+            let body: serde_json::Value = serde_json::from_slice(body).unwrap();
+            let mut expected_body = serde_json::json!({
+                "prefilter": true,
+                "nprobes": 20,
+                "minimum_nprobes": 20,
+                "maximum_nprobes": 20,
+                "approx_mode": "accurate",
+                "lower_bound": Option::<f32>::None,
+                "upper_bound": Option::<f32>::None,
+                "k": 10,
+                "ef": Option::<usize>::None,
+                "refine_factor": null,
+                "version": null,
+            });
+            expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into();
+            assert_eq!(body, expected_body);
+
+            let response_body = write_ipc_file(&expected_data_ref);
+            http::Response::builder()
+                .status(200)
+                .header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
+                .body(response_body)
+                .unwrap()
+        });
+
+        let data = table
+            .query()
+            .nearest_to(vec![0.1, 0.2, 0.3])
+            .unwrap()
+            .approx_mode(crate::ApproxMode::Accurate)
+            .execute()
+            .await;
+        let data = data.unwrap().collect::<Vec<_>>().await;
+        assert_eq!(data.len(), 1);
+        assert_eq!(data[0].as_ref().unwrap(), &expected_data);
+    }
+
     #[tokio::test]
     async fn test_query_fts_default_values() {
         let expected_data = RecordBatch::try_new(
diff --git a/rust/lancedb/src/table/query.rs b/rust/lancedb/src/table/query.rs
index b136de2cd..0e7135e28 100644
--- a/rust/lancedb/src/table/query.rs
+++ b/rust/lancedb/src/table/query.rs
@@ -167,6 +167,10 @@ pub async fn create_plan(
         scanner.nearest(&column, query_vector.as_ref(), top_k)?;
     }
 
+    if let Some(approx_mode) = query.approx_mode {
+        scanner.approx_mode(approx_mode.into());
+    }
+
     scanner.minimum_nprobes(query.minimum_nprobes);
     if let Some(maximum_nprobes) = query.maximum_nprobes {
         scanner.maximum_nprobes(maximum_nprobes);
@@ -779,4 +783,42 @@ mod tests {
             "Plan should add query_index column"
         );
     }
+
+    #[tokio::test]
+    async fn test_create_plan_accepts_approx_mode() {
+        use arrow_array::{Float32Array, RecordBatch};
+        use arrow_schema::{DataType, Field, Schema};
+
+        use crate::connect;
+        use crate::table::query::create_plan;
+
+        let conn = connect("memory://").execute().await.unwrap();
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new(
+                "vector",
+                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
+                false,
+            ),
+        ]));
+
+        let batch = RecordBatch::new_empty(schema);
+        let table = conn
+            .create_table("test_approx_mode_plan", vec![batch])
+            .execute()
+            .await
+            .unwrap();
+        let native_table = table.as_native().unwrap();
+        let query_vector = Arc::new(Float32Array::from(vec![1.0, 2.0]));
+        let query = AnyQuery::VectorQuery(VectorQueryRequest {
+            column: Some("vector".to_string()),
+            query_vector: vec![query_vector],
+            approx_mode: Some(crate::ApproxMode::Accurate),
+            ..Default::default()
+        });
+
+        create_plan(native_table, &query, QueryExecutionOptions::default())
+            .await
+            .unwrap();
+    }
 }