mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-17 03:00:41 +00:00
Compare commits
1 Commits
codex/upda
...
yang/appro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b05b33bfbc |
@@ -184,12 +184,13 @@ pub mod table;
|
||||
pub mod test_utils;
|
||||
pub mod utils;
|
||||
|
||||
use std::fmt::Display;
|
||||
use std::{fmt::Display, str::FromStr};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub use connection::{ConnectNamespaceBuilder, Connection};
|
||||
pub use error::{Error, Result};
|
||||
use lance_index::vector::ApproxMode as LanceApproxMode;
|
||||
use lance_linalg::distance::DistanceType as LanceDistanceType;
|
||||
pub use table::Table;
|
||||
|
||||
@@ -258,6 +259,79 @@ impl Display for DistanceType {
|
||||
}
|
||||
}
|
||||
|
||||
/// Controls the speed / accuracy tradeoff for approximate vector search.
|
||||
///
|
||||
/// This currently only affects RQ-quantized vector indexes, such as IVF_RQ.
|
||||
/// Other index types ignore this setting.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[non_exhaustive]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ApproxMode {
|
||||
/// Prefer lower query latency, which can reduce recall.
|
||||
Fast,
|
||||
/// Use the default balance between query latency and recall.
|
||||
#[default]
|
||||
Normal,
|
||||
/// Prefer higher recall, which can increase query latency.
|
||||
Accurate,
|
||||
}
|
||||
|
||||
impl From<ApproxMode> for LanceApproxMode {
|
||||
fn from(value: ApproxMode) -> Self {
|
||||
match value {
|
||||
ApproxMode::Fast => Self::Fast,
|
||||
ApproxMode::Normal => Self::Normal,
|
||||
ApproxMode::Accurate => Self::Accurate,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<LanceApproxMode> for ApproxMode {
|
||||
fn from(value: LanceApproxMode) -> Self {
|
||||
match value {
|
||||
LanceApproxMode::Fast => Self::Fast,
|
||||
LanceApproxMode::Normal => Self::Normal,
|
||||
LanceApproxMode::Accurate => Self::Accurate,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TryFrom<&'a str> for ApproxMode {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(value: &str) -> std::prelude::v1::Result<Self, Self::Error> {
|
||||
Self::from_str(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for ApproxMode {
|
||||
type Err = Error;
|
||||
|
||||
fn from_str(value: &str) -> std::prelude::v1::Result<Self, Self::Err> {
|
||||
match value.to_ascii_lowercase().as_str() {
|
||||
"fast" => Ok(Self::Fast),
|
||||
"normal" => Ok(Self::Normal),
|
||||
"accurate" => Ok(Self::Accurate),
|
||||
_ => Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"approx_mode must be one of 'fast', 'normal', or 'accurate', got '{}'",
|
||||
value
|
||||
),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for ApproxMode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Fast => write!(f, "fast"),
|
||||
Self::Normal => write!(f, "normal"),
|
||||
Self::Accurate => write!(f, "accurate"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect to a database
|
||||
pub use connection::connect;
|
||||
/// Connect to a namespace-backed database
|
||||
|
||||
@@ -20,12 +20,12 @@ use lance_index::scalar::FullTextSearchQuery;
|
||||
use lance_index::scalar::inverted::SCORE_COL;
|
||||
use lance_index::vector::DIST_COL;
|
||||
|
||||
use crate::DistanceType;
|
||||
use crate::error::{Error, Result};
|
||||
use crate::rerankers::rrf::RRFReranker;
|
||||
use crate::rerankers::{NormalizeMethod, Reranker, check_reranker_result};
|
||||
use crate::table::BaseTable;
|
||||
use crate::utils::{MaxBatchLengthStream, TimeoutStream};
|
||||
use crate::{ApproxMode, DistanceType};
|
||||
use crate::{
|
||||
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
|
||||
table::AnyQuery,
|
||||
@@ -935,6 +935,8 @@ pub struct VectorQueryRequest {
|
||||
pub refine_factor: Option<u32>,
|
||||
/// The distance type to use for the search
|
||||
pub distance_type: Option<DistanceType>,
|
||||
/// The speed / accuracy tradeoff to use for approximate vector search
|
||||
pub approx_mode: Option<ApproxMode>,
|
||||
/// Default is true. Set to false to enforce a brute force search.
|
||||
pub use_index: bool,
|
||||
}
|
||||
@@ -952,6 +954,7 @@ impl Default for VectorQueryRequest {
|
||||
ef: None,
|
||||
refine_factor: None,
|
||||
distance_type: None,
|
||||
approx_mode: None,
|
||||
use_index: true,
|
||||
}
|
||||
}
|
||||
@@ -1192,6 +1195,15 @@ impl VectorQuery {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the speed / accuracy tradeoff for approximate vector search.
|
||||
///
|
||||
/// This setting is currently only used by RQ-quantized indexes, such as
|
||||
/// IVF_RQ. Other index types ignore this setting.
|
||||
pub fn approx_mode(mut self, approx_mode: ApproxMode) -> Self {
|
||||
self.request.approx_mode = Some(approx_mode);
|
||||
self
|
||||
}
|
||||
|
||||
/// If this is called then any vector index is skipped
|
||||
///
|
||||
/// An exhaustive (flat) search will be performed. The query vector will
|
||||
@@ -1546,6 +1558,7 @@ mod tests {
|
||||
.nprobes(1000)
|
||||
.postfilter()
|
||||
.distance_type(DistanceType::Cosine)
|
||||
.approx_mode(ApproxMode::Accurate)
|
||||
.refine_factor(999);
|
||||
|
||||
assert_eq!(
|
||||
@@ -1564,9 +1577,49 @@ mod tests {
|
||||
assert_eq!(query.request.maximum_nprobes, Some(1000));
|
||||
assert!(query.request.use_index);
|
||||
assert_eq!(query.request.distance_type, Some(DistanceType::Cosine));
|
||||
assert_eq!(query.request.approx_mode, Some(ApproxMode::Accurate));
|
||||
assert_eq!(query.request.refine_factor, Some(999));
|
||||
}
|
||||
|
||||
#[test]
fn test_approx_mode_serde_parse_default_and_display() {
    // `Normal` is the default variant.
    assert_eq!(ApproxMode::default(), ApproxMode::Normal);

    // serde reads and writes the lowercase variant names.
    let encoded = serde_json::to_string(&ApproxMode::Fast).unwrap();
    assert_eq!(encoded, "\"fast\"");
    let decoded: ApproxMode = serde_json::from_str("\"accurate\"").unwrap();
    assert_eq!(decoded, ApproxMode::Accurate);

    // String parsing goes through `FromStr` / `TryFrom<&str>` and ignores case.
    let parsed: ApproxMode = "normal".parse().unwrap();
    assert_eq!(parsed, ApproxMode::Normal);
    let shouted = ApproxMode::try_from("FAST").unwrap();
    assert_eq!(shouted, ApproxMode::Fast);

    // `Display` yields the lowercase name.
    let rendered = ApproxMode::Accurate.to_string();
    assert_eq!(rendered, "accurate");

    // Unknown names are rejected.
    let rejected = ApproxMode::try_from("invalid");
    assert!(rejected.is_err());
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_query_approx_mode_builder() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let dataset_path = tmp_dir.path().join("test.lance");
|
||||
let uri = dataset_path.to_str().unwrap();
|
||||
|
||||
let conn = connect(uri).execute().await.unwrap();
|
||||
let table = conn
|
||||
.create_table("my_table", make_test_batches())
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let query = table
|
||||
.query()
|
||||
.nearest_to(&[0.1, 0.2])
|
||||
.unwrap()
|
||||
.approx_mode(ApproxMode::Fast);
|
||||
|
||||
assert_eq!(query.request.approx_mode, Some(ApproxMode::Fast));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute() {
|
||||
// TODO: Switch back to memory://foo after https://github.com/lancedb/lancedb/issues/1051
|
||||
|
||||
@@ -706,6 +706,9 @@ impl<S: HttpSend> RemoteTable<S> {
|
||||
if let Some(distance_type) = query.distance_type {
|
||||
body["distance_type"] = serde_json::json!(distance_type);
|
||||
}
|
||||
if let Some(approx_mode) = query.approx_mode {
|
||||
body["approx_mode"] = serde_json::json!(approx_mode);
|
||||
}
|
||||
// In 0.23.1 we migrated from `nprobes` to `minimum_nprobes` and `maximum_nprobes`.
|
||||
// Old client / new server: since minimum_nprobes is missing, fallback to nprobes
|
||||
// New client / old server: old server will only see nprobes, make sure to set both
|
||||
@@ -3610,6 +3613,61 @@ mod tests {
|
||||
assert_eq!(data[0].as_ref().unwrap(), &expected_data);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_query_vector_approx_mode_sent_when_set() {
|
||||
let expected_data = RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
|
||||
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
|
||||
)
|
||||
.unwrap();
|
||||
let expected_data_ref = expected_data.clone();
|
||||
|
||||
let table = Table::new_with_handler("my_table", move |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(request.url().path(), "/v1/table/my_table/query/");
|
||||
assert_eq!(
|
||||
request.headers().get("Content-Type").unwrap(),
|
||||
JSON_CONTENT_TYPE
|
||||
);
|
||||
|
||||
let body = request.body().unwrap().as_bytes().unwrap();
|
||||
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
|
||||
let mut expected_body = serde_json::json!({
|
||||
"prefilter": true,
|
||||
"nprobes": 20,
|
||||
"minimum_nprobes": 20,
|
||||
"maximum_nprobes": 20,
|
||||
"approx_mode": "accurate",
|
||||
"lower_bound": Option::<f32>::None,
|
||||
"upper_bound": Option::<f32>::None,
|
||||
"k": 10,
|
||||
"ef": Option::<usize>::None,
|
||||
"refine_factor": null,
|
||||
"version": null,
|
||||
});
|
||||
expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into();
|
||||
assert_eq!(body, expected_body);
|
||||
|
||||
let response_body = write_ipc_file(&expected_data_ref);
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
|
||||
.body(response_body)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let data = table
|
||||
.query()
|
||||
.nearest_to(vec![0.1, 0.2, 0.3])
|
||||
.unwrap()
|
||||
.approx_mode(crate::ApproxMode::Accurate)
|
||||
.execute()
|
||||
.await;
|
||||
let data = data.unwrap().collect::<Vec<_>>().await;
|
||||
assert_eq!(data.len(), 1);
|
||||
assert_eq!(data[0].as_ref().unwrap(), &expected_data);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_query_fts_default_values() {
|
||||
let expected_data = RecordBatch::try_new(
|
||||
|
||||
@@ -167,6 +167,10 @@ pub async fn create_plan(
|
||||
scanner.nearest(&column, query_vector.as_ref(), top_k)?;
|
||||
}
|
||||
|
||||
if let Some(approx_mode) = query.approx_mode {
|
||||
scanner.approx_mode(approx_mode.into());
|
||||
}
|
||||
|
||||
scanner.minimum_nprobes(query.minimum_nprobes);
|
||||
if let Some(maximum_nprobes) = query.maximum_nprobes {
|
||||
scanner.maximum_nprobes(maximum_nprobes);
|
||||
@@ -779,4 +783,42 @@ mod tests {
|
||||
"Plan should add query_index column"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_plan_accepts_approx_mode() {
|
||||
use arrow_array::{Float32Array, RecordBatch};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
|
||||
use crate::connect;
|
||||
use crate::table::query::create_plan;
|
||||
|
||||
let conn = connect("memory://").execute().await.unwrap();
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new("id", DataType::Int32, false),
|
||||
Field::new(
|
||||
"vector",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
|
||||
false,
|
||||
),
|
||||
]));
|
||||
|
||||
let batch = RecordBatch::new_empty(schema);
|
||||
let table = conn
|
||||
.create_table("test_approx_mode_plan", vec![batch])
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
let native_table = table.as_native().unwrap();
|
||||
let query_vector = Arc::new(Float32Array::from(vec![1.0, 2.0]));
|
||||
let query = AnyQuery::VectorQuery(VectorQueryRequest {
|
||||
column: Some("vector".to_string()),
|
||||
query_vector: vec![query_vector],
|
||||
approx_mode: Some(crate::ApproxMode::Accurate),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
create_plan(native_table, &query, QueryExecutionOptions::default())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user