diff --git a/rust/ffi/node/src/error.rs b/rust/ffi/node/src/error.rs index ae165c12..2777cb7b 100644 --- a/rust/ffi/node/src/error.rs +++ b/rust/ffi/node/src/error.rs @@ -19,10 +19,12 @@ use snafu::Snafu; #[derive(Debug, Snafu)] pub enum Error { + #[allow(dead_code)] #[snafu(display("column '{name}' is missing"))] MissingColumn { name: String }, #[snafu(display("{name}: {message}"))] OutOfRange { name: String, message: String }, + #[allow(dead_code)] #[snafu(display("{index_type} is not a valid index type"))] InvalidIndexType { index_type: String }, diff --git a/rust/ffi/node/src/neon_ext/js_object_ext.rs b/rust/ffi/node/src/neon_ext/js_object_ext.rs index 88deb915..fb08a4e1 100644 --- a/rust/ffi/node/src/neon_ext/js_object_ext.rs +++ b/rust/ffi/node/src/neon_ext/js_object_ext.rs @@ -19,6 +19,7 @@ use neon::prelude::*; pub trait JsObjectExt { fn get_opt_u32(&self, cx: &mut FunctionContext, key: &str) -> Result>; fn get_usize(&self, cx: &mut FunctionContext, key: &str) -> Result; + #[allow(dead_code)] fn get_opt_usize(&self, cx: &mut FunctionContext, key: &str) -> Result>; } diff --git a/rust/lancedb/src/index.rs b/rust/lancedb/src/index.rs index 6ddfbd42..d29ef9cf 100644 --- a/rust/lancedb/src/index.rs +++ b/rust/lancedb/src/index.rs @@ -16,7 +16,10 @@ use std::sync::Arc; use crate::{table::TableInternal, Result}; -use self::{scalar::BTreeIndexBuilder, vector::IvfPqIndexBuilder}; +use self::{ + scalar::BTreeIndexBuilder, + vector::{IvfHnswSqIndexBuilder, IvfPqIndexBuilder}, +}; pub mod scalar; pub mod vector; @@ -25,6 +28,7 @@ pub enum Index { Auto, BTree(BTreeIndexBuilder), IvfPq(IvfPqIndexBuilder), + IvfHnswSq(IvfHnswSqIndexBuilder), } /// Builder for the create_index operation @@ -65,6 +69,7 @@ impl IndexBuilder { #[derive(Debug, Clone, PartialEq)] pub enum IndexType { IvfPq, + IvfHnswSq, BTree, } diff --git a/rust/lancedb/src/index/vector.rs b/rust/lancedb/src/index/vector.rs index c2637378..4be1a5e6 100644 --- a/rust/lancedb/src/index/vector.rs +++ b/rust/lancedb/src/index/vector.rs @@ -83,10 +83,14 @@ pub struct VectorIndexStatistics { #[derive(Debug, Clone)] pub struct IvfPqIndexBuilder { pub(crate) distance_type: DistanceType, + + // IVF pub(crate) num_partitions: Option, - pub(crate) num_sub_vectors: Option, pub(crate) sample_rate: u32, pub(crate) max_iterations: u32, + + // PQ + pub(crate) num_sub_vectors: Option, } impl Default for IvfPqIndexBuilder { @@ -201,3 +205,124 @@ pub(crate) fn suggested_num_sub_vectors(dim: u32) -> u32 { 1 } } + +/// Builder for an IVF_HNSW_SQ index. +/// +/// This index is a combination of IVF and HNSW. +/// The IVF part is the same as the IVF PQ index. +/// For each IVF partition, this builds a HNSW graph, the graph is used to +/// quickly find the closest vectors to a query vector. +/// +/// The SQ (scalar quantizer) is used to compress the vectors, +/// each vector is mapped to a 8-bit integer vector, 4x compression ratio for float32 vector. +#[derive(Debug, Clone)] +pub struct IvfHnswSqIndexBuilder { + // IVF + pub(crate) distance_type: DistanceType, + pub(crate) num_partitions: Option, + pub(crate) sample_rate: u32, + pub(crate) max_iterations: u32, + + // HNSW + pub(crate) m: u32, + pub(crate) ef_construction: u32, + // SQ + // TODO add num_bits for SQ after it supports another num_bits besides 8 +} + +impl Default for IvfHnswSqIndexBuilder { + fn default() -> Self { + Self { + distance_type: DistanceType::L2, + num_partitions: None, + sample_rate: 256, + max_iterations: 50, + m: 20, + ef_construction: 300, + } + } +} + +impl IvfHnswSqIndexBuilder { + /// [DistanceType] to use to build the index. + /// + /// Default value is [DistanceType::L2]. + /// + /// This is used when training the index to calculate the IVF partitions (vectors are + /// grouped in partitions with similar vectors according to this distance type) + /// + /// The metric type used to train an index MUST match the metric type used to search the + /// index. Failure to do so will yield inaccurate results. + /// + /// Now IVF_HNSW_SQ only supports L2 and Cosine distance types. + pub fn distance_type(mut self, distance_type: DistanceType) -> Self { + self.distance_type = distance_type; + self + } + + /// The number of IVF partitions to create. + /// + /// This value should generally scale with the number of rows in the dataset. By default + /// the number of partitions is the square root of the number of rows. + /// + /// If this value is too large then the first part of the search (picking the right partition) + /// will be slow. If this value is too small then the second part of the search (searching + /// within a partition) will be slow. + pub fn num_partitions(mut self, num_partitions: u32) -> Self { + self.num_partitions = Some(num_partitions); + self + } + + /// The rate used to calculate the number of training vectors for kmeans and SQ. + /// + /// When an IVF_HNSW_SQ index is trained, we need to calculate partitions and min/max value of vectors. These are groups + /// of vectors that are similar to each other. To do this we use an algorithm called kmeans. + /// + /// Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a + /// random sample of the data. This parameter controls the size of the sample. The total + /// number of vectors used to train the IVF is `sample_rate * num_partitions`. + /// + /// The total number of vectors used to train the SQ is `sample_rate * 2^{num_bits}`. + /// + /// Increasing this value might improve the quality of the index but in most cases the + /// default should be sufficient. + /// + /// The default value is 256. + pub fn sample_rate(mut self, sample_rate: u32) -> Self { + self.sample_rate = sample_rate; + self + } + + /// Max iterations to train kmeans. + /// + /// When training an IVF index we use kmeans to calculate the partitions. This parameter + /// controls how many iterations of kmeans to run. + /// + /// Increasing this might improve the quality of the index but in most cases the parameter + /// is unused because kmeans will converge with fewer iterations. The parameter is only + /// used in cases where kmeans does not appear to converge. In those cases it is unlikely + /// that setting this larger will lead to the index converging anyways. + /// + /// The default value is 50. + pub fn max_iterations(mut self, max_iterations: u32) -> Self { + self.max_iterations = max_iterations; + self + } + + /// The number of neighbors to select for each vector in the HNSW graph. + /// Bumping this number will increase the recall of the search but also increase the build/search time. + /// The default value is 20. + pub fn m(mut self, m: u32) -> Self { + self.m = m; + self + } + + /// The number of candidates to evaluate during the construction of the HNSW graph. + /// Bumping this number will increase the recall of the search but also increase the build/search time. + /// This value should be not less than `ef` in the search phase. + /// The default value is 300. + pub fn ef_construction(mut self, ef_construction: u32) -> Self { + self.ef_construction = ef_construction; + self + } +} diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 522f5dc4..3e45aeff 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -38,6 +38,9 @@ use lance::dataset::{ }; use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource}; use lance::io::WrappingObjectStore; +use lance_index::vector::hnsw::builder::HnswBuildParams; +use lance_index::vector::ivf::IvfBuildParams; +use lance_index::vector::sq::builder::SQBuildParams; use lance_index::IndexType; use lance_index::{optimize::OptimizeOptions, DatasetIndexExt}; use log::info; @@ -48,7 +51,9 @@ use crate::arrow::IntoArrow; use crate::connection::NoData; use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry}; use crate::error::{Error, Result}; -use crate::index::vector::{IvfPqIndexBuilder, VectorIndex, VectorIndexStatistics}; +use crate::index::vector::{ + IvfHnswSqIndexBuilder, IvfPqIndexBuilder, VectorIndex, VectorIndexStatistics, +}; use crate::index::IndexConfig; use crate::index::{ vector::{suggested_num_partitions, suggested_num_sub_vectors}, @@ -312,6 +317,7 @@ impl UpdateBuilder { #[async_trait] pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Sync { + #[allow(dead_code)] fn as_any(&self) -> &dyn std::any::Any; /// Cast as [`NativeTable`], or return None it if is not a [`NativeTable`]. fn as_native(&self) -> Option<&NativeTable>; @@ -1254,6 +1260,58 @@ impl NativeTable { Ok(()) } + async fn create_ivf_hnsw_sq_index( + &self, + index: IvfHnswSqIndexBuilder, + field: &Field, + replace: bool, + ) -> Result<()> { + if !Self::supported_vector_data_type(field.data_type()) { + return Err(Error::InvalidInput { + message: format!( + "An IVF HNSW SQ index cannot be created on the column `{}` which has data type {}", + field.name(), + field.data_type() + ), + }); + } + + let num_partitions = if let Some(n) = index.num_partitions { + n + } else { + suggested_num_partitions(self.count_rows(None).await?) + }; + + let mut dataset = self.dataset.get_mut().await?; + let mut ivf_params = IvfBuildParams::new(num_partitions as usize); + ivf_params.sample_rate = index.sample_rate as usize; + ivf_params.max_iters = index.max_iterations as usize; + let hnsw_params = HnswBuildParams::default() + .num_edges(index.m as usize) + .max_num_edges(index.m as usize * 2) + .ef_construction(index.ef_construction as usize); + let sq_params = SQBuildParams { + sample_rate: index.sample_rate as usize, + ..Default::default() + }; + let lance_idx_params = lance::index::vector::VectorIndexParams::with_ivf_hnsw_sq_params( + index.distance_type.into(), + ivf_params, + hnsw_params, + sq_params, + ); + dataset + .create_index( + &[field.name()], + IndexType::Vector, + None, + &lance_idx_params, + replace, + ) + .await?; + Ok(()) + } + async fn create_auto_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> { if Self::supported_vector_data_type(field.data_type()) { self.create_ivf_pq_index(IvfPqIndexBuilder::default(), field, opts.replace) @@ -1497,6 +1555,10 @@ impl TableInternal for NativeTable { Index::Auto => self.create_auto_index(field, opts).await, Index::BTree(_) => self.create_btree_index(field, opts).await, Index::IvfPq(ivf_pq) => self.create_ivf_pq_index(ivf_pq, field, opts.replace).await, + Index::IvfHnswSq(ivf_hnsw_sq) => { + self.create_ivf_hnsw_sq_index(ivf_hnsw_sq, field, opts.replace) + .await + } } } @@ -2357,6 +2419,102 @@ mod tests { ); } + #[tokio::test] + async fn test_create_index_ivf_hnsw_sq() { + use arrow_array::RecordBatch; + use arrow_schema::{DataType, Field, Schema as ArrowSchema}; + use rand; + use std::iter::repeat_with; + + use arrow_array::Float32Array; + + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let conn = connect(uri).execute().await.unwrap(); + + let dimension = 16; + let schema = Arc::new(ArrowSchema::new(vec![Field::new( + "embeddings", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + dimension, + ), + false, + )])); + + let mut rng = rand::thread_rng(); + let float_arr = Float32Array::from( + repeat_with(|| rng.gen::()) + .take(512 * dimension as usize) + .collect::>(), + ); + + let vectors = Arc::new(create_fixed_size_list(float_arr, dimension).unwrap()); + let batches = RecordBatchIterator::new( + vec![RecordBatch::try_new(schema.clone(), vec![vectors.clone()]).unwrap()] + .into_iter() + .map(Ok), + schema, + ); + + let table = conn.create_table("test", batches).execute().await.unwrap(); + + assert_eq!( + table + .as_native() + .unwrap() + .count_indexed_rows("my_index") + .await + .unwrap(), + None + ); + assert_eq!( + table + .as_native() + .unwrap() + .count_unindexed_rows("my_index") + .await + .unwrap(), + None + ); + + let index = IvfHnswSqIndexBuilder::default(); + table + .create_index(&["embeddings"], Index::IvfHnswSq(index)) + .execute() + .await + .unwrap(); + + let index_configs = table.list_indices().await.unwrap(); + assert_eq!(index_configs.len(), 1); + let index = index_configs.into_iter().next().unwrap(); + assert_eq!(index.index_type, crate::index::IndexType::IvfPq); + assert_eq!(index.columns, vec!["embeddings".to_string()]); + assert_eq!(table.count_rows(None).await.unwrap(), 512); + assert_eq!(table.name(), "test"); + + let indices = table.as_native().unwrap().load_indices().await.unwrap(); + let index_uuid = &indices[0].index_uuid; + assert_eq!( + table + .as_native() + .unwrap() + .count_indexed_rows(index_uuid) + .await + .unwrap(), + Some(512) + ); + assert_eq!( + table + .as_native() + .unwrap() + .count_unindexed_rows(index_uuid) + .await + .unwrap(), + Some(0) + ); + } + fn create_fixed_size_list(values: T, list_size: i32) -> Result { let list_type = DataType::FixedSizeList( Arc::new(Field::new("item", values.data_type().clone(), true)),