Compare commits

...

1 Commits

Author SHA1 Message Date
BubbleCal
b05b33bfbc feat(query): add approx mode to vector queries 2026-06-16 21:00:08 +08:00
4 changed files with 229 additions and 2 deletions

View File

@@ -184,12 +184,13 @@ pub mod table;
pub mod test_utils;
pub mod utils;
use std::fmt::Display;
use std::{fmt::Display, str::FromStr};
use serde::{Deserialize, Serialize};
pub use connection::{ConnectNamespaceBuilder, Connection};
pub use error::{Error, Result};
use lance_index::vector::ApproxMode as LanceApproxMode;
use lance_linalg::distance::DistanceType as LanceDistanceType;
pub use table::Table;
@@ -258,6 +259,79 @@ impl Display for DistanceType {
}
}
/// Controls the speed / accuracy tradeoff for approximate vector search.
///
/// This currently only affects RQ-quantized vector indexes, such as IVF_RQ.
/// Other index types ignore this setting.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[non_exhaustive]
#[serde(rename_all = "lowercase")]
pub enum ApproxMode {
/// Prefer lower query latency, which can reduce recall.
Fast,
/// Use the default balance between query latency and recall.
#[default]
Normal,
/// Prefer higher recall, which can increase query latency.
Accurate,
}
impl From<ApproxMode> for LanceApproxMode {
fn from(value: ApproxMode) -> Self {
match value {
ApproxMode::Fast => Self::Fast,
ApproxMode::Normal => Self::Normal,
ApproxMode::Accurate => Self::Accurate,
}
}
}
impl From<LanceApproxMode> for ApproxMode {
fn from(value: LanceApproxMode) -> Self {
match value {
LanceApproxMode::Fast => Self::Fast,
LanceApproxMode::Normal => Self::Normal,
LanceApproxMode::Accurate => Self::Accurate,
}
}
}
impl<'a> TryFrom<&'a str> for ApproxMode {
type Error = Error;
fn try_from(value: &str) -> std::prelude::v1::Result<Self, Self::Error> {
Self::from_str(value)
}
}
impl FromStr for ApproxMode {
type Err = Error;
fn from_str(value: &str) -> std::prelude::v1::Result<Self, Self::Err> {
match value.to_ascii_lowercase().as_str() {
"fast" => Ok(Self::Fast),
"normal" => Ok(Self::Normal),
"accurate" => Ok(Self::Accurate),
_ => Err(Error::InvalidInput {
message: format!(
"approx_mode must be one of 'fast', 'normal', or 'accurate', got '{}'",
value
),
}),
}
}
}
impl Display for ApproxMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Fast => write!(f, "fast"),
Self::Normal => write!(f, "normal"),
Self::Accurate => write!(f, "accurate"),
}
}
}
/// Connect to a database
pub use connection::connect;
/// Connect to a namespace-backed database

View File

@@ -20,12 +20,12 @@ use lance_index::scalar::FullTextSearchQuery;
use lance_index::scalar::inverted::SCORE_COL;
use lance_index::vector::DIST_COL;
use crate::DistanceType;
use crate::error::{Error, Result};
use crate::rerankers::rrf::RRFReranker;
use crate::rerankers::{NormalizeMethod, Reranker, check_reranker_result};
use crate::table::BaseTable;
use crate::utils::{MaxBatchLengthStream, TimeoutStream};
use crate::{ApproxMode, DistanceType};
use crate::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
table::AnyQuery,
@@ -935,6 +935,8 @@ pub struct VectorQueryRequest {
pub refine_factor: Option<u32>,
/// The distance type to use for the search
pub distance_type: Option<DistanceType>,
/// The speed / accuracy tradeoff to use for approximate vector search
pub approx_mode: Option<ApproxMode>,
/// Default is true. Set to false to enforce a brute force search.
pub use_index: bool,
}
@@ -952,6 +954,7 @@ impl Default for VectorQueryRequest {
ef: None,
refine_factor: None,
distance_type: None,
approx_mode: None,
use_index: true,
}
}
@@ -1192,6 +1195,15 @@ impl VectorQuery {
self
}
/// Set the speed / accuracy tradeoff for approximate vector search.
///
/// This setting is currently only used by RQ-quantized indexes, such as
/// IVF_RQ. Other index types ignore this setting.
pub fn approx_mode(mut self, approx_mode: ApproxMode) -> Self {
self.request.approx_mode = Some(approx_mode);
self
}
/// If this is called then any vector index is skipped
///
/// An exhaustive (flat) search will be performed. The query vector will
@@ -1546,6 +1558,7 @@ mod tests {
.nprobes(1000)
.postfilter()
.distance_type(DistanceType::Cosine)
.approx_mode(ApproxMode::Accurate)
.refine_factor(999);
assert_eq!(
@@ -1564,9 +1577,49 @@ mod tests {
assert_eq!(query.request.maximum_nprobes, Some(1000));
assert!(query.request.use_index);
assert_eq!(query.request.distance_type, Some(DistanceType::Cosine));
assert_eq!(query.request.approx_mode, Some(ApproxMode::Accurate));
assert_eq!(query.request.refine_factor, Some(999));
}
#[test]
fn test_approx_mode_serde_parse_default_and_display() {
assert_eq!(ApproxMode::default(), ApproxMode::Normal);
assert_eq!(
serde_json::to_string(&ApproxMode::Fast).unwrap(),
"\"fast\""
);
assert_eq!(
serde_json::from_str::<ApproxMode>("\"accurate\"").unwrap(),
ApproxMode::Accurate
);
assert_eq!("normal".parse::<ApproxMode>().unwrap(), ApproxMode::Normal);
assert_eq!(ApproxMode::try_from("FAST").unwrap(), ApproxMode::Fast);
assert_eq!(ApproxMode::Accurate.to_string(), "accurate");
assert!(ApproxMode::try_from("invalid").is_err());
}
#[tokio::test]
async fn test_vector_query_approx_mode_builder() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path().join("test.lance");
let uri = dataset_path.to_str().unwrap();
let conn = connect(uri).execute().await.unwrap();
let table = conn
.create_table("my_table", make_test_batches())
.execute()
.await
.unwrap();
let query = table
.query()
.nearest_to(&[0.1, 0.2])
.unwrap()
.approx_mode(ApproxMode::Fast);
assert_eq!(query.request.approx_mode, Some(ApproxMode::Fast));
}
#[tokio::test]
async fn test_execute() {
// TODO: Switch back to memory://foo after https://github.com/lancedb/lancedb/issues/1051

View File

@@ -706,6 +706,9 @@ impl<S: HttpSend> RemoteTable<S> {
if let Some(distance_type) = query.distance_type {
body["distance_type"] = serde_json::json!(distance_type);
}
if let Some(approx_mode) = query.approx_mode {
body["approx_mode"] = serde_json::json!(approx_mode);
}
// In 0.23.1 we migrated from `nprobes` to `minimum_nprobes` and `maximum_nprobes`.
// Old client / new server: since minimum_nprobes is missing, fallback to nprobes
// New client / old server: old server will only see nprobes, make sure to set both
@@ -3610,6 +3613,61 @@ mod tests {
assert_eq!(data[0].as_ref().unwrap(), &expected_data);
}
#[tokio::test]
async fn test_query_vector_approx_mode_sent_when_set() {
let expected_data = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let expected_data_ref = expected_data.clone();
let table = Table::new_with_handler("my_table", move |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/v1/table/my_table/query/");
assert_eq!(
request.headers().get("Content-Type").unwrap(),
JSON_CONTENT_TYPE
);
let body = request.body().unwrap().as_bytes().unwrap();
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
let mut expected_body = serde_json::json!({
"prefilter": true,
"nprobes": 20,
"minimum_nprobes": 20,
"maximum_nprobes": 20,
"approx_mode": "accurate",
"lower_bound": Option::<f32>::None,
"upper_bound": Option::<f32>::None,
"k": 10,
"ef": Option::<usize>::None,
"refine_factor": null,
"version": null,
});
expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into();
assert_eq!(body, expected_body);
let response_body = write_ipc_file(&expected_data_ref);
http::Response::builder()
.status(200)
.header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
.body(response_body)
.unwrap()
});
let data = table
.query()
.nearest_to(vec![0.1, 0.2, 0.3])
.unwrap()
.approx_mode(crate::ApproxMode::Accurate)
.execute()
.await;
let data = data.unwrap().collect::<Vec<_>>().await;
assert_eq!(data.len(), 1);
assert_eq!(data[0].as_ref().unwrap(), &expected_data);
}
#[tokio::test]
async fn test_query_fts_default_values() {
let expected_data = RecordBatch::try_new(

View File

@@ -167,6 +167,10 @@ pub async fn create_plan(
scanner.nearest(&column, query_vector.as_ref(), top_k)?;
}
if let Some(approx_mode) = query.approx_mode {
scanner.approx_mode(approx_mode.into());
}
scanner.minimum_nprobes(query.minimum_nprobes);
if let Some(maximum_nprobes) = query.maximum_nprobes {
scanner.maximum_nprobes(maximum_nprobes);
@@ -779,4 +783,42 @@ mod tests {
"Plan should add query_index column"
);
}
#[tokio::test]
async fn test_create_plan_accepts_approx_mode() {
use arrow_array::{Float32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use crate::connect;
use crate::table::query::create_plan;
let conn = connect("memory://").execute().await.unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new(
"vector",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
false,
),
]));
let batch = RecordBatch::new_empty(schema);
let table = conn
.create_table("test_approx_mode_plan", vec![batch])
.execute()
.await
.unwrap();
let native_table = table.as_native().unwrap();
let query_vector = Arc::new(Float32Array::from(vec![1.0, 2.0]));
let query = AnyQuery::VectorQuery(VectorQueryRequest {
column: Some("vector".to_string()),
query_vector: vec![query_vector],
approx_mode: Some(crate::ApproxMode::Accurate),
..Default::default()
});
create_plan(native_table, &query, QueryExecutionOptions::default())
.await
.unwrap();
}
}