feat: let lance determine the default num_partitions param (#2775)

This commit is contained in:
BubbleCal
2025-11-12 09:43:19 +08:00
committed by GitHub
parent 1ff594a6a4
commit 3e42a43bbf
2 changed files with 145 additions and 83 deletions

View File

@@ -6,8 +6,6 @@
//!
//! Vector indices are only supported on fixed-size-list (tensor) columns of floating point
//! values
use std::cmp::max;
use lance::table::format::{IndexMetadata, Manifest};
use crate::DistanceType;
@@ -266,16 +264,6 @@ impl IvfPqIndexBuilder {
impl_pq_params_setter!();
}
/// Suggests a partition count for a table holding `rows` rows.
///
/// The suggestion is the square root of the row count, truncated toward zero,
/// and is never smaller than 1 (so an empty table still gets one partition).
pub(crate) fn suggested_num_partitions(rows: usize) -> u32 {
    let sqrt_rows = (rows as f64).sqrt();
    max(1, sqrt_rows as u32)
}
/// Suggests a partition count for an IVF-HNSW index over `rows` vectors of
/// dimension `dim`.
///
/// The suggestion is `rows * dim / (256 * 5_000_000)` using integer division,
/// and is never smaller than 1.
///
/// The product is computed with saturating arithmetic and the quotient is
/// clamped to `u32::MAX`, so extreme inputs neither panic on overflow (debug
/// builds) nor wrap/truncate to a meaningless small value (release builds).
pub(crate) fn suggested_num_partitions_for_hnsw(rows: usize, dim: u32) -> u32 {
    // Divisor carried over unchanged from the original heuristic.
    const DIVISOR: u64 = 256 * 5_000_000;

    let total_values = (rows as u64).saturating_mul(u64::from(dim));
    // Clamp before narrowing so the cast below can never truncate.
    let clamped = (total_values / DIVISOR).min(u64::from(u32::MAX));
    max(1, clamped as u32)
}
pub(crate) fn suggested_num_sub_vectors(dim: u32) -> u32 {
if dim % 16 == 0 {
// Should be more aggressive than this default.

View File

@@ -33,6 +33,7 @@ use lance::io::WrappingObjectStore;
use lance_datafusion::exec::{analyze_plan as lance_analyze_plan, execute_plan};
use lance_datafusion::utils::StreamingWriteSource;
use lance_index::scalar::{BuiltinIndexType, ScalarIndexParams};
use lance_index::vector::bq::RQBuildParams;
use lance_index::vector::hnsw::builder::HnswBuildParams;
use lance_index::vector::ivf::IvfBuildParams;
use lance_index::vector::pq::PQBuildParams;
@@ -53,12 +54,9 @@ use crate::connection::NoData;
use crate::database::Database;
use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry};
use crate::error::{Error, Result};
use crate::index::vector::{suggested_num_partitions_for_hnsw, VectorIndex};
use crate::index::vector::VectorIndex;
use crate::index::IndexStatistics;
use crate::index::{
vector::{suggested_num_partitions, suggested_num_sub_vectors},
Index, IndexBuilder,
};
use crate::index::{vector::suggested_num_sub_vectors, Index, IndexBuilder};
use crate::index::{IndexConfig, IndexStatisticsImpl};
use crate::query::{
IntoQueryVector, Query, QueryExecutionOptions, QueryFilter, QueryRequest, Select, TakeQuery,
@@ -1757,28 +1755,23 @@ impl NativeTable {
Ok(())
}
// Helper to get num_partitions with default calculation
async fn get_num_partitions(
&self,
provided: Option<u32>,
for_hnsw: bool,
dim: Option<u32>,
) -> Result<u32> {
if let Some(n) = provided {
Ok(n)
} else {
let row_count = self.count_rows(None).await?;
if for_hnsw {
Ok(suggested_num_partitions_for_hnsw(
row_count,
dim.ok_or_else(|| Error::InvalidInput {
message: "Vector dimension required for HNSW partitioning".to_string(),
})?,
))
} else {
Ok(suggested_num_partitions(row_count))
// Helper to build IVF params honoring table options.
//
// Partition selection, in priority order:
//   1. an explicit `num_partitions` is used as-is, even when
//      `target_partition_size` is also provided;
//   2. otherwise a provided `target_partition_size` is forwarded via
//      `IvfBuildParams::with_target_partition_size`;
//   3. otherwise `IvfBuildParams::default()` is used and no partition count is
//      computed here (presumably left for Lance to decide -- confirm against
//      the `IvfBuildParams` documentation).
//
// `sample_rate` and `max_iterations` are always copied onto the returned
// params (the latter into the `max_iters` field), overriding whatever the
// chosen constructor set for them.
fn build_ivf_params(
num_partitions: Option<u32>,
target_partition_size: Option<u32>,
sample_rate: u32,
max_iterations: u32,
) -> IvfBuildParams {
let mut ivf_params = match (num_partitions, target_partition_size) {
// Explicit partition count wins; any target partition size is ignored.
(Some(num_partitions), _) => IvfBuildParams::new(num_partitions as usize),
(None, Some(target_partition_size)) => {
IvfBuildParams::with_target_partition_size(target_partition_size as usize)
}
}
// Neither option given: start from the library defaults.
(None, None) => IvfBuildParams::default(),
};
// Training knobs are applied regardless of how the partitioning was chosen.
ivf_params.sample_rate = sample_rate as usize;
ivf_params.max_iters = max_iterations as usize;
ivf_params
}
// Helper to get num_sub_vectors with default calculation
@@ -1805,15 +1798,16 @@ impl NativeTable {
if supported_vector_data_type(field.data_type()) {
// Use IvfPq as the default for auto vector indices
let dim = Self::get_vector_dimension(field)?;
let num_partitions = self.get_num_partitions(None, false, None).await?;
let ivf_params = lance_index::vector::ivf::IvfBuildParams::default();
let num_sub_vectors = Self::get_num_sub_vectors(None, dim);
let lance_idx_params = lance::index::vector::VectorIndexParams::ivf_pq(
num_partitions as usize,
/*num_bits=*/ 8,
num_sub_vectors as usize,
lance_linalg::distance::MetricType::L2,
/*max_iterations=*/ 50,
);
let pq_params =
lance_index::vector::pq::PQBuildParams::new(num_sub_vectors as usize, 8);
let lance_idx_params =
lance::index::vector::VectorIndexParams::with_ivf_pq_params(
lance_linalg::distance::MetricType::L2,
ivf_params,
pq_params,
);
Ok(Box::new(lance_idx_params))
} else if supported_btree_data_type(field.data_type()) {
Ok(Box::new(ScalarIndexParams::for_builtin(
@@ -1853,60 +1847,69 @@ impl NativeTable {
}
Index::IvfFlat(index) => {
Self::validate_index_type(field, "IVF Flat", supported_vector_data_type)?;
let num_partitions = self
.get_num_partitions(index.num_partitions, false, None)
.await?;
let lance_idx_params = VectorIndexParams::ivf_flat(
num_partitions as usize,
index.distance_type.into(),
let ivf_params = Self::build_ivf_params(
index.num_partitions,
index.target_partition_size,
index.sample_rate,
index.max_iterations,
);
let lance_idx_params =
VectorIndexParams::with_ivf_flat_params(index.distance_type.into(), ivf_params);
Ok(Box::new(lance_idx_params))
}
Index::IvfPq(index) => {
Self::validate_index_type(field, "IVF PQ", supported_vector_data_type)?;
let dim = Self::get_vector_dimension(field)?;
let num_partitions = self
.get_num_partitions(index.num_partitions, false, None)
.await?;
let ivf_params = Self::build_ivf_params(
index.num_partitions,
index.target_partition_size,
index.sample_rate,
index.max_iterations,
);
let num_sub_vectors = Self::get_num_sub_vectors(index.num_sub_vectors, dim);
let lance_idx_params = VectorIndexParams::ivf_pq(
num_partitions as usize,
/*num_bits=*/ 8,
num_sub_vectors as usize,
let num_bits = index.num_bits.unwrap_or(8) as usize;
let mut pq_params = PQBuildParams::new(num_sub_vectors as usize, num_bits);
pq_params.max_iters = index.max_iterations as usize;
let lance_idx_params = VectorIndexParams::with_ivf_pq_params(
index.distance_type.into(),
index.max_iterations as usize,
ivf_params,
pq_params,
);
Ok(Box::new(lance_idx_params))
}
Index::IvfRq(index) => {
Self::validate_index_type(field, "IVF RQ", supported_vector_data_type)?;
let num_partitions = self
.get_num_partitions(index.num_partitions, false, None)
.await?;
let lance_idx_params = VectorIndexParams::ivf_rq(
num_partitions as usize,
index.num_bits.unwrap_or(1) as u8,
let ivf_params = Self::build_ivf_params(
index.num_partitions,
index.target_partition_size,
index.sample_rate,
index.max_iterations,
);
let rq_params = RQBuildParams::new(index.num_bits.unwrap_or(1) as u8);
let lance_idx_params = VectorIndexParams::with_ivf_rq_params(
index.distance_type.into(),
ivf_params,
rq_params,
);
Ok(Box::new(lance_idx_params))
}
Index::IvfHnswPq(index) => {
Self::validate_index_type(field, "IVF HNSW PQ", supported_vector_data_type)?;
let dim = Self::get_vector_dimension(field)?;
let num_partitions = self
.get_num_partitions(index.num_partitions, true, Some(dim))
.await?;
let ivf_params = Self::build_ivf_params(
index.num_partitions,
index.target_partition_size,
index.sample_rate,
index.max_iterations,
);
let num_sub_vectors = Self::get_num_sub_vectors(index.num_sub_vectors, dim);
let mut ivf_params = IvfBuildParams::new(num_partitions as usize);
ivf_params.sample_rate = index.sample_rate as usize;
ivf_params.max_iters = index.max_iterations as usize;
let hnsw_params = HnswBuildParams::default()
.num_edges(index.m as usize)
.ef_construction(index.ef_construction as usize);
let pq_params = PQBuildParams {
num_sub_vectors: num_sub_vectors as usize,
..Default::default()
};
let pq_params = PQBuildParams::new(
num_sub_vectors as usize,
index.num_bits.unwrap_or(8) as usize,
);
let lance_idx_params = VectorIndexParams::with_ivf_hnsw_pq_params(
index.distance_type.into(),
ivf_params,
@@ -1917,13 +1920,12 @@ impl NativeTable {
}
Index::IvfHnswSq(index) => {
Self::validate_index_type(field, "IVF HNSW SQ", supported_vector_data_type)?;
let dim = Self::get_vector_dimension(field)?;
let num_partitions = self
.get_num_partitions(index.num_partitions, true, Some(dim))
.await?;
let mut ivf_params = IvfBuildParams::new(num_partitions as usize);
ivf_params.sample_rate = index.sample_rate as usize;
ivf_params.max_iters = index.max_iterations as usize;
let ivf_params = Self::build_ivf_params(
index.num_partitions,
index.target_partition_size,
index.sample_rate,
index.max_iterations,
);
let hnsw_params = HnswBuildParams::default()
.num_edges(index.m as usize)
.ef_construction(index.ef_construction as usize);
@@ -3566,6 +3568,78 @@ mod tests {
assert_eq!(table.list_indices().await.unwrap().len(), 0);
}
/// Creates an IVF_PQ index from `IvfPqIndexBuilder::default()` on a table of
/// `2 * PARTITION_SIZE` rows, then opens the resulting Lance index and checks
/// that its IVF model has `num_rows / PARTITION_SIZE` partitions.
#[tokio::test]
async fn test_ivf_pq_uses_default_partition_size_for_num_partitions() {
    use crate::index::vector::IvfPqIndexBuilder;
    use arrow_array::{Float32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
    use lance::index::vector::ivf::v2::IvfPq as LanceIvfPq;
    use lance::index::DatasetIndexInternalExt;
    use lance_index::metrics::NoOpMetricsCollector;
    use lance_index::vector::VectorIndex as LanceVectorIndex;

    // NOTE(review): presumably mirrors the partition size Lance targets when
    // no partitioning option is given -- confirm against Lance.
    const PARTITION_SIZE: usize = 8192;
    const DIMENSION: usize = 8;
    let num_rows = PARTITION_SIZE * 2;
    let expected_partitions = num_rows / PARTITION_SIZE;

    let tmp_dir = tempdir().unwrap();
    let conn = connect(tmp_dir.path().to_str().unwrap())
        .execute()
        .await
        .unwrap();

    // Single non-nullable fixed-size-list column of f32 values.
    let list_type = DataType::FixedSizeList(
        Arc::new(Field::new("item", DataType::Float32, true)),
        DIMENSION as i32,
    );
    let schema = Arc::new(ArrowSchema::new(vec![Field::new(
        "embeddings",
        list_type,
        false,
    )]));

    let values = Float32Array::from_iter_values((0..(num_rows * DIMENSION)).map(|i| i as f32));
    let vectors = Arc::new(create_fixed_size_list(values, DIMENSION as i32).unwrap());
    let batch = RecordBatch::try_new(schema.clone(), vec![vectors]).unwrap();
    let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema);

    let table = conn.create_table("test", batches).execute().await.unwrap();
    let native_table = table.as_native().unwrap();

    // Builder defaults only: no partitioning option is set by this test.
    table
        .create_index(
            &["embeddings"],
            Index::IvfPq(IvfPqIndexBuilder::default()),
        )
        .execute()
        .await
        .unwrap();
    table
        .wait_for_index(&["embeddings_idx"], std::time::Duration::from_secs(30))
        .await
        .unwrap();

    let indices = native_table.load_indices().await.unwrap();
    let index_uuid = indices[0].index_uuid.clone();

    // Clone the dataset out of the guard so the guard is released before the
    // index is opened.
    let dataset = {
        let guard = native_table.dataset.get().await.unwrap();
        (*guard).clone()
    };

    let opened = dataset
        .open_vector_index("embeddings", &index_uuid, &NoOpMetricsCollector)
        .await
        .unwrap();
    let ivf_pq = opened
        .as_any()
        .downcast_ref::<LanceIvfPq>()
        .expect("expected IvfPq index");

    let partition_count = ivf_pq.ivf_model().num_partitions();
    assert_eq!(partition_count, expected_partitions);
}
#[tokio::test]
async fn test_create_index_ivf_hnsw_sq() {
use arrow_array::RecordBatch;