diff --git a/rust/lancedb/src/index/vector.rs b/rust/lancedb/src/index/vector.rs index d23c7cc1..20f3a847 100644 --- a/rust/lancedb/src/index/vector.rs +++ b/rust/lancedb/src/index/vector.rs @@ -6,8 +6,6 @@ //! //! Vector indices are only supported on fixed-size-list (tensor) columns of floating point //! values -use std::cmp::max; - use lance::table::format::{IndexMetadata, Manifest}; use crate::DistanceType; @@ -266,16 +264,6 @@ impl IvfPqIndexBuilder { impl_pq_params_setter!(); } -pub(crate) fn suggested_num_partitions(rows: usize) -> u32 { - let num_partitions = (rows as f64).sqrt() as u32; - max(1, num_partitions) -} - -pub(crate) fn suggested_num_partitions_for_hnsw(rows: usize, dim: u32) -> u32 { - let num_partitions = (((rows as u64) * (dim as u64)) / (256 * 5_000_000)) as u32; - max(1, num_partitions) -} - pub(crate) fn suggested_num_sub_vectors(dim: u32) -> u32 { if dim % 16 == 0 { // Should be more aggressive than this default. diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 0e126d31..17d9fb74 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -33,6 +33,7 @@ use lance::io::WrappingObjectStore; use lance_datafusion::exec::{analyze_plan as lance_analyze_plan, execute_plan}; use lance_datafusion::utils::StreamingWriteSource; use lance_index::scalar::{BuiltinIndexType, ScalarIndexParams}; +use lance_index::vector::bq::RQBuildParams; use lance_index::vector::hnsw::builder::HnswBuildParams; use lance_index::vector::ivf::IvfBuildParams; use lance_index::vector::pq::PQBuildParams; @@ -53,12 +54,9 @@ use crate::connection::NoData; use crate::database::Database; use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry}; use crate::error::{Error, Result}; -use crate::index::vector::{suggested_num_partitions_for_hnsw, VectorIndex}; +use crate::index::vector::VectorIndex; use crate::index::IndexStatistics; -use crate::index::{ - vector::{suggested_num_partitions, suggested_num_sub_vectors}, - Index, IndexBuilder, -}; +use crate::index::{vector::suggested_num_sub_vectors, Index, IndexBuilder}; use crate::index::{IndexConfig, IndexStatisticsImpl}; use crate::query::{ IntoQueryVector, Query, QueryExecutionOptions, QueryFilter, QueryRequest, Select, TakeQuery, @@ -1757,28 +1755,23 @@ impl NativeTable { Ok(()) } - // Helper to get num_partitions with default calculation - async fn get_num_partitions( - &self, - provided: Option, - for_hnsw: bool, - dim: Option, - ) -> Result { - if let Some(n) = provided { - Ok(n) - } else { - let row_count = self.count_rows(None).await?; - if for_hnsw { - Ok(suggested_num_partitions_for_hnsw( - row_count, - dim.ok_or_else(|| Error::InvalidInput { - message: "Vector dimension required for HNSW partitioning".to_string(), - })?, - )) - } else { - Ok(suggested_num_partitions(row_count)) + // Helper to build IVF params honoring table options. + fn build_ivf_params( + num_partitions: Option, + target_partition_size: Option, + sample_rate: u32, + max_iterations: u32, + ) -> IvfBuildParams { + let mut ivf_params = match (num_partitions, target_partition_size) { + (Some(num_partitions), _) => IvfBuildParams::new(num_partitions as usize), + (None, Some(target_partition_size)) => { + IvfBuildParams::with_target_partition_size(target_partition_size as usize) } - } + (None, None) => IvfBuildParams::default(), + }; + ivf_params.sample_rate = sample_rate as usize; + ivf_params.max_iters = max_iterations as usize; + ivf_params } // Helper to get num_sub_vectors with default calculation @@ -1805,15 +1798,16 @@ impl NativeTable { if supported_vector_data_type(field.data_type()) { // Use IvfPq as the default for auto vector indices let dim = Self::get_vector_dimension(field)?; - let num_partitions = self.get_num_partitions(None, false, None).await?; + let ivf_params = lance_index::vector::ivf::IvfBuildParams::default(); let num_sub_vectors = Self::get_num_sub_vectors(None, dim); - let lance_idx_params = lance::index::vector::VectorIndexParams::ivf_pq( - num_partitions as usize, - /*num_bits=*/ 8, - num_sub_vectors as usize, - lance_linalg::distance::MetricType::L2, - /*max_iterations=*/ 50, - ); + let pq_params = + lance_index::vector::pq::PQBuildParams::new(num_sub_vectors as usize, 8); + let lance_idx_params = + lance::index::vector::VectorIndexParams::with_ivf_pq_params( + lance_linalg::distance::MetricType::L2, + ivf_params, + pq_params, + ); Ok(Box::new(lance_idx_params)) } else if supported_btree_data_type(field.data_type()) { Ok(Box::new(ScalarIndexParams::for_builtin( @@ -1853,60 +1847,69 @@ impl NativeTable { } Index::IvfFlat(index) => { Self::validate_index_type(field, "IVF Flat", supported_vector_data_type)?; - let num_partitions = self - .get_num_partitions(index.num_partitions, false, None) - .await?; - let lance_idx_params = VectorIndexParams::ivf_flat( - num_partitions as usize, - index.distance_type.into(), + let ivf_params = Self::build_ivf_params( + index.num_partitions, + index.target_partition_size, + index.sample_rate, + index.max_iterations, ); + let lance_idx_params = + VectorIndexParams::with_ivf_flat_params(index.distance_type.into(), ivf_params); Ok(Box::new(lance_idx_params)) } Index::IvfPq(index) => { Self::validate_index_type(field, "IVF PQ", supported_vector_data_type)?; let dim = Self::get_vector_dimension(field)?; - let num_partitions = self - .get_num_partitions(index.num_partitions, false, None) - .await?; + let ivf_params = Self::build_ivf_params( + index.num_partitions, + index.target_partition_size, + index.sample_rate, + index.max_iterations, + ); let num_sub_vectors = Self::get_num_sub_vectors(index.num_sub_vectors, dim); - let lance_idx_params = VectorIndexParams::ivf_pq( - num_partitions as usize, - /*num_bits=*/ 8, - num_sub_vectors as usize, + let num_bits = index.num_bits.unwrap_or(8) as usize; + let mut pq_params = PQBuildParams::new(num_sub_vectors as usize, num_bits); + pq_params.max_iters = index.max_iterations as usize; + let lance_idx_params = VectorIndexParams::with_ivf_pq_params( index.distance_type.into(), - index.max_iterations as usize, + ivf_params, + pq_params, ); Ok(Box::new(lance_idx_params)) } Index::IvfRq(index) => { Self::validate_index_type(field, "IVF RQ", supported_vector_data_type)?; - let num_partitions = self - .get_num_partitions(index.num_partitions, false, None) - .await?; - let lance_idx_params = VectorIndexParams::ivf_rq( - num_partitions as usize, - index.num_bits.unwrap_or(1) as u8, + let ivf_params = Self::build_ivf_params( + index.num_partitions, + index.target_partition_size, + index.sample_rate, + index.max_iterations, + ); + let rq_params = RQBuildParams::new(index.num_bits.unwrap_or(1) as u8); + let lance_idx_params = VectorIndexParams::with_ivf_rq_params( index.distance_type.into(), + ivf_params, + rq_params, ); Ok(Box::new(lance_idx_params)) } Index::IvfHnswPq(index) => { Self::validate_index_type(field, "IVF HNSW PQ", supported_vector_data_type)?; let dim = Self::get_vector_dimension(field)?; - let num_partitions = self - .get_num_partitions(index.num_partitions, true, Some(dim)) - .await?; + let ivf_params = Self::build_ivf_params( + index.num_partitions, + index.target_partition_size, + index.sample_rate, + index.max_iterations, + ); let num_sub_vectors = Self::get_num_sub_vectors(index.num_sub_vectors, dim); - let mut ivf_params = IvfBuildParams::new(num_partitions as usize); - ivf_params.sample_rate = index.sample_rate as usize; - ivf_params.max_iters = index.max_iterations as usize; let hnsw_params = HnswBuildParams::default() .num_edges(index.m as usize) .ef_construction(index.ef_construction as usize); - let pq_params = PQBuildParams { - num_sub_vectors: num_sub_vectors as usize, - ..Default::default() - }; + let pq_params = PQBuildParams::new( + num_sub_vectors as usize, + index.num_bits.unwrap_or(8) as usize, + ); let lance_idx_params = VectorIndexParams::with_ivf_hnsw_pq_params( index.distance_type.into(), ivf_params, @@ -1917,13 +1920,12 @@ impl NativeTable { } Index::IvfHnswSq(index) => { Self::validate_index_type(field, "IVF HNSW SQ", supported_vector_data_type)?; - let dim = Self::get_vector_dimension(field)?; - let num_partitions = self - .get_num_partitions(index.num_partitions, true, Some(dim)) - .await?; - let mut ivf_params = IvfBuildParams::new(num_partitions as usize); - ivf_params.sample_rate = index.sample_rate as usize; - ivf_params.max_iters = index.max_iterations as usize; + let ivf_params = Self::build_ivf_params( + index.num_partitions, + index.target_partition_size, + index.sample_rate, + index.max_iterations, + ); let hnsw_params = HnswBuildParams::default() .num_edges(index.m as usize) .ef_construction(index.ef_construction as usize); @@ -3566,6 +3568,78 @@ mod tests { assert_eq!(table.list_indices().await.unwrap().len(), 0); } + #[tokio::test] + async fn test_ivf_pq_uses_default_partition_size_for_num_partitions() { + use arrow_array::{Float32Array, RecordBatch}; + use arrow_schema::{DataType, Field, Schema as ArrowSchema}; + + use crate::index::vector::IvfPqIndexBuilder; + + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let conn = connect(uri).execute().await.unwrap(); + + const PARTITION_SIZE: usize = 8192; + let num_rows = PARTITION_SIZE * 2; + let dimension = 8usize; + let schema = Arc::new(ArrowSchema::new(vec![Field::new( + "embeddings", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + dimension as i32, + ), + false, + )])); + + let float_arr = + Float32Array::from_iter_values((0..(num_rows * dimension)).map(|v| v as f32)); + let vectors = Arc::new(create_fixed_size_list(float_arr, dimension as i32).unwrap()); + let batches = RecordBatchIterator::new( + vec![RecordBatch::try_new(schema.clone(), vec![vectors]).unwrap()] + .into_iter() + .map(Ok), + schema, + ); + + let table = conn.create_table("test", batches).execute().await.unwrap(); + let native_table = table.as_native().unwrap(); + let builder = IvfPqIndexBuilder::default(); + table + .create_index(&["embeddings"], Index::IvfPq(builder)) + .execute() + .await + .unwrap(); + table + .wait_for_index(&["embeddings_idx"], std::time::Duration::from_secs(30)) + .await + .unwrap(); + + use lance::index::vector::ivf::v2::IvfPq as LanceIvfPq; + use lance::index::DatasetIndexInternalExt; + use lance_index::metrics::NoOpMetricsCollector; + use lance_index::vector::VectorIndex as LanceVectorIndex; + + let indices = native_table.load_indices().await.unwrap(); + let index_uuid = indices[0].index_uuid.clone(); + + let dataset_guard = native_table.dataset.get().await.unwrap(); + let dataset = (*dataset_guard).clone(); + drop(dataset_guard); + + let lance_index = dataset + .open_vector_index("embeddings", &index_uuid, &NoOpMetricsCollector) + .await + .unwrap(); + let ivf_index = lance_index + .as_any() + .downcast_ref::() + .expect("expected IvfPq index"); + let partition_count = ivf_index.ivf_model().num_partitions(); + + let expected_partitions = num_rows / PARTITION_SIZE; + assert_eq!(partition_count, expected_partitions); + } + #[tokio::test] async fn test_create_index_ivf_hnsw_sq() { use arrow_array::RecordBatch;