feat: support IVF_HNSW_SQ (#1284)

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
BubbleCal
2024-05-16 14:28:06 +08:00
committed by GitHub
parent 6eaaee59f8
commit 5e01810438
5 changed files with 294 additions and 3 deletions

View File

@@ -19,10 +19,12 @@ use snafu::Snafu;
#[derive(Debug, Snafu)]
pub enum Error {
#[allow(dead_code)]
#[snafu(display("column '{name}' is missing"))]
MissingColumn { name: String },
#[snafu(display("{name}: {message}"))]
OutOfRange { name: String, message: String },
#[allow(dead_code)]
#[snafu(display("{index_type} is not a valid index type"))]
InvalidIndexType { index_type: String },

View File

@@ -19,6 +19,7 @@ use neon::prelude::*;
pub trait JsObjectExt {
fn get_opt_u32(&self, cx: &mut FunctionContext, key: &str) -> Result<Option<u32>>;
fn get_usize(&self, cx: &mut FunctionContext, key: &str) -> Result<usize>;
#[allow(dead_code)]
fn get_opt_usize(&self, cx: &mut FunctionContext, key: &str) -> Result<Option<usize>>;
}

View File

@@ -16,7 +16,10 @@ use std::sync::Arc;
use crate::{table::TableInternal, Result};
use self::{scalar::BTreeIndexBuilder, vector::IvfPqIndexBuilder};
use self::{
scalar::BTreeIndexBuilder,
vector::{IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
};
pub mod scalar;
pub mod vector;
@@ -25,6 +28,7 @@ pub enum Index {
Auto,
BTree(BTreeIndexBuilder),
IvfPq(IvfPqIndexBuilder),
IvfHnswSq(IvfHnswSqIndexBuilder),
}
/// Builder for the create_index operation
@@ -65,6 +69,7 @@ impl IndexBuilder {
#[derive(Debug, Clone, PartialEq)]
pub enum IndexType {
IvfPq,
IvfHnswSq,
BTree,
}

View File

@@ -83,10 +83,14 @@ pub struct VectorIndexStatistics {
#[derive(Debug, Clone)]
pub struct IvfPqIndexBuilder {
pub(crate) distance_type: DistanceType,
// IVF
pub(crate) num_partitions: Option<u32>,
pub(crate) num_sub_vectors: Option<u32>,
pub(crate) sample_rate: u32,
pub(crate) max_iterations: u32,
// PQ
pub(crate) num_sub_vectors: Option<u32>,
}
impl Default for IvfPqIndexBuilder {
@@ -201,3 +205,124 @@ pub(crate) fn suggested_num_sub_vectors(dim: u32) -> u32 {
1
}
}
/// Builder for an IVF_HNSW_SQ index.
///
/// This index is a combination of IVF and HNSW.
/// The IVF part is the same as the IVF PQ index.
/// For each IVF partition, this builds a HNSW graph, the graph is used to
/// quickly find the closest vectors to a query vector.
///
/// The SQ (scalar quantizer) is used to compress the vectors,
/// each vector is mapped to a 8-bit integer vector, 4x compression ratio for float32 vector.
#[derive(Debug, Clone)]
pub struct IvfHnswSqIndexBuilder {
// IVF
pub(crate) distance_type: DistanceType,
pub(crate) num_partitions: Option<u32>,
pub(crate) sample_rate: u32,
pub(crate) max_iterations: u32,
// HNSW
pub(crate) m: u32,
pub(crate) ef_construction: u32,
// SQ
// TODO add num_bits for SQ after it supports another num_bits besides 8
}
impl Default for IvfHnswSqIndexBuilder {
fn default() -> Self {
Self {
distance_type: DistanceType::L2,
num_partitions: None,
sample_rate: 256,
max_iterations: 50,
m: 20,
ef_construction: 300,
}
}
}
impl IvfHnswSqIndexBuilder {
/// [DistanceType] to use to build the index.
///
/// Default value is [DistanceType::L2].
///
/// This is used when training the index to calculate the IVF partitions (vectors are
/// grouped in partitions with similar vectors according to this distance type)
///
/// The metric type used to train an index MUST match the metric type used to search the
/// index. Failure to do so will yield inaccurate results.
///
/// Now IVF_HNSW_SQ only supports L2 and Cosine distance types.
pub fn distance_type(mut self, distance_type: DistanceType) -> Self {
self.distance_type = distance_type;
self
}
/// The number of IVF partitions to create.
///
/// This value should generally scale with the number of rows in the dataset. By default
/// the number of partitions is the square root of the number of rows.
///
/// If this value is too large then the first part of the search (picking the right partition)
/// will be slow. If this value is too small then the second part of the search (searching
/// within a partition) will be slow.
pub fn num_partitions(mut self, num_partitions: u32) -> Self {
self.num_partitions = Some(num_partitions);
self
}
/// The rate used to calculate the number of training vectors for kmeans and SQ.
///
/// When an IVF_HNSW_SQ index is trained, we need to calculate partitions and min/max value of vectors. These are groups
/// of vectors that are similar to each other. To do this we use an algorithm called kmeans.
///
/// Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
/// random sample of the data. This parameter controls the size of the sample. The total
/// number of vectors used to train the IVF is `sample_rate * num_partitions`.
///
/// The total number of vectors used to train the SQ is `sample_rate * 2^{num_bits}`.
///
/// Increasing this value might improve the quality of the index but in most cases the
/// default should be sufficient.
///
/// The default value is 256.
pub fn sample_rate(mut self, sample_rate: u32) -> Self {
self.sample_rate = sample_rate;
self
}
/// Max iterations to train kmeans.
///
/// When training an IVF index we use kmeans to calculate the partitions. This parameter
/// controls how many iterations of kmeans to run.
///
/// Increasing this might improve the quality of the index but in most cases the parameter
/// is unused because kmeans will converge with fewer iterations. The parameter is only
/// used in cases where kmeans does not appear to converge. In those cases it is unlikely
/// that setting this larger will lead to the index converging anyways.
///
/// The default value is 50.
pub fn max_iterations(mut self, max_iterations: u32) -> Self {
self.max_iterations = max_iterations;
self
}
/// The number of neighbors to select for each vector in the HNSW graph.
/// Bumping this number will increase the recall of the search but also increase the build/search time.
/// The default value is 20.
pub fn m(mut self, m: u32) -> Self {
self.m = m;
self
}
/// The number of candidates to evaluate during the construction of the HNSW graph.
/// Bumping this number will increase the recall of the search but also increase the build/search time.
/// This value should be not less than `ef` in the search phase.
/// The default value is 300.
pub fn ef_construction(mut self, ef_construction: u32) -> Self {
self.ef_construction = ef_construction;
self
}
}

View File

@@ -38,6 +38,9 @@ use lance::dataset::{
};
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
use lance::io::WrappingObjectStore;
use lance_index::vector::hnsw::builder::HnswBuildParams;
use lance_index::vector::ivf::IvfBuildParams;
use lance_index::vector::sq::builder::SQBuildParams;
use lance_index::IndexType;
use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
use log::info;
@@ -48,7 +51,9 @@ use crate::arrow::IntoArrow;
use crate::connection::NoData;
use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry};
use crate::error::{Error, Result};
use crate::index::vector::{IvfPqIndexBuilder, VectorIndex, VectorIndexStatistics};
use crate::index::vector::{
IvfHnswSqIndexBuilder, IvfPqIndexBuilder, VectorIndex, VectorIndexStatistics,
};
use crate::index::IndexConfig;
use crate::index::{
vector::{suggested_num_partitions, suggested_num_sub_vectors},
@@ -312,6 +317,7 @@ impl UpdateBuilder {
#[async_trait]
pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Sync {
#[allow(dead_code)]
fn as_any(&self) -> &dyn std::any::Any;
/// Cast as [`NativeTable`], or return None it if is not a [`NativeTable`].
fn as_native(&self) -> Option<&NativeTable>;
@@ -1254,6 +1260,58 @@ impl NativeTable {
Ok(())
}
async fn create_ivf_hnsw_sq_index(
&self,
index: IvfHnswSqIndexBuilder,
field: &Field,
replace: bool,
) -> Result<()> {
if !Self::supported_vector_data_type(field.data_type()) {
return Err(Error::InvalidInput {
message: format!(
"An IVF HNSW SQ index cannot be created on the column `{}` which has data type {}",
field.name(),
field.data_type()
),
});
}
let num_partitions = if let Some(n) = index.num_partitions {
n
} else {
suggested_num_partitions(self.count_rows(None).await?)
};
let mut dataset = self.dataset.get_mut().await?;
let mut ivf_params = IvfBuildParams::new(num_partitions as usize);
ivf_params.sample_rate = index.sample_rate as usize;
ivf_params.max_iters = index.max_iterations as usize;
let hnsw_params = HnswBuildParams::default()
.num_edges(index.m as usize)
.max_num_edges(index.m as usize * 2)
.ef_construction(index.ef_construction as usize);
let sq_params = SQBuildParams {
sample_rate: index.sample_rate as usize,
..Default::default()
};
let lance_idx_params = lance::index::vector::VectorIndexParams::with_ivf_hnsw_sq_params(
index.distance_type.into(),
ivf_params,
hnsw_params,
sq_params,
);
dataset
.create_index(
&[field.name()],
IndexType::Vector,
None,
&lance_idx_params,
replace,
)
.await?;
Ok(())
}
async fn create_auto_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> {
if Self::supported_vector_data_type(field.data_type()) {
self.create_ivf_pq_index(IvfPqIndexBuilder::default(), field, opts.replace)
@@ -1497,6 +1555,10 @@ impl TableInternal for NativeTable {
Index::Auto => self.create_auto_index(field, opts).await,
Index::BTree(_) => self.create_btree_index(field, opts).await,
Index::IvfPq(ivf_pq) => self.create_ivf_pq_index(ivf_pq, field, opts.replace).await,
Index::IvfHnswSq(ivf_hnsw_sq) => {
self.create_ivf_hnsw_sq_index(ivf_hnsw_sq, field, opts.replace)
.await
}
}
}
@@ -2357,6 +2419,102 @@ mod tests {
);
}
#[tokio::test]
async fn test_create_index_ivf_hnsw_sq() {
use arrow_array::RecordBatch;
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use rand;
use std::iter::repeat_with;
use arrow_array::Float32Array;
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let conn = connect(uri).execute().await.unwrap();
let dimension = 16;
let schema = Arc::new(ArrowSchema::new(vec![Field::new(
"embeddings",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
dimension,
),
false,
)]));
let mut rng = rand::thread_rng();
let float_arr = Float32Array::from(
repeat_with(|| rng.gen::<f32>())
.take(512 * dimension as usize)
.collect::<Vec<f32>>(),
);
let vectors = Arc::new(create_fixed_size_list(float_arr, dimension).unwrap());
let batches = RecordBatchIterator::new(
vec![RecordBatch::try_new(schema.clone(), vec![vectors.clone()]).unwrap()]
.into_iter()
.map(Ok),
schema,
);
let table = conn.create_table("test", batches).execute().await.unwrap();
assert_eq!(
table
.as_native()
.unwrap()
.count_indexed_rows("my_index")
.await
.unwrap(),
None
);
assert_eq!(
table
.as_native()
.unwrap()
.count_unindexed_rows("my_index")
.await
.unwrap(),
None
);
let index = IvfHnswSqIndexBuilder::default();
table
.create_index(&["embeddings"], Index::IvfHnswSq(index))
.execute()
.await
.unwrap();
let index_configs = table.list_indices().await.unwrap();
assert_eq!(index_configs.len(), 1);
let index = index_configs.into_iter().next().unwrap();
assert_eq!(index.index_type, crate::index::IndexType::IvfPq);
assert_eq!(index.columns, vec!["embeddings".to_string()]);
assert_eq!(table.count_rows(None).await.unwrap(), 512);
assert_eq!(table.name(), "test");
let indices = table.as_native().unwrap().load_indices().await.unwrap();
let index_uuid = &indices[0].index_uuid;
assert_eq!(
table
.as_native()
.unwrap()
.count_indexed_rows(index_uuid)
.await
.unwrap(),
Some(512)
);
assert_eq!(
table
.as_native()
.unwrap()
.count_unindexed_rows(index_uuid)
.await
.unwrap(),
Some(0)
);
}
fn create_fixed_size_list<T: Array>(values: T, list_size: i32) -> Result<FixedSizeListArray> {
let list_type = DataType::FixedSizeList(
Arc::new(Field::new("item", values.data_type().clone(), true)),