mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-22 21:09:58 +00:00
feat: support IVF_HNSW_SQ (#1284)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
@@ -19,10 +19,12 @@ use snafu::Snafu;
|
||||
|
||||
#[derive(Debug, Snafu)]
|
||||
pub enum Error {
|
||||
#[allow(dead_code)]
|
||||
#[snafu(display("column '{name}' is missing"))]
|
||||
MissingColumn { name: String },
|
||||
#[snafu(display("{name}: {message}"))]
|
||||
OutOfRange { name: String, message: String },
|
||||
#[allow(dead_code)]
|
||||
#[snafu(display("{index_type} is not a valid index type"))]
|
||||
InvalidIndexType { index_type: String },
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ use neon::prelude::*;
|
||||
pub trait JsObjectExt {
|
||||
fn get_opt_u32(&self, cx: &mut FunctionContext, key: &str) -> Result<Option<u32>>;
|
||||
fn get_usize(&self, cx: &mut FunctionContext, key: &str) -> Result<usize>;
|
||||
#[allow(dead_code)]
|
||||
fn get_opt_usize(&self, cx: &mut FunctionContext, key: &str) -> Result<Option<usize>>;
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,10 @@ use std::sync::Arc;
|
||||
|
||||
use crate::{table::TableInternal, Result};
|
||||
|
||||
use self::{scalar::BTreeIndexBuilder, vector::IvfPqIndexBuilder};
|
||||
use self::{
|
||||
scalar::BTreeIndexBuilder,
|
||||
vector::{IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
|
||||
};
|
||||
|
||||
pub mod scalar;
|
||||
pub mod vector;
|
||||
@@ -25,6 +28,7 @@ pub enum Index {
|
||||
Auto,
|
||||
BTree(BTreeIndexBuilder),
|
||||
IvfPq(IvfPqIndexBuilder),
|
||||
IvfHnswSq(IvfHnswSqIndexBuilder),
|
||||
}
|
||||
|
||||
/// Builder for the create_index operation
|
||||
@@ -65,6 +69,7 @@ impl IndexBuilder {
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum IndexType {
|
||||
IvfPq,
|
||||
IvfHnswSq,
|
||||
BTree,
|
||||
}
|
||||
|
||||
|
||||
@@ -83,10 +83,14 @@ pub struct VectorIndexStatistics {
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IvfPqIndexBuilder {
|
||||
pub(crate) distance_type: DistanceType,
|
||||
|
||||
// IVF
|
||||
pub(crate) num_partitions: Option<u32>,
|
||||
pub(crate) num_sub_vectors: Option<u32>,
|
||||
pub(crate) sample_rate: u32,
|
||||
pub(crate) max_iterations: u32,
|
||||
|
||||
// PQ
|
||||
pub(crate) num_sub_vectors: Option<u32>,
|
||||
}
|
||||
|
||||
impl Default for IvfPqIndexBuilder {
|
||||
@@ -201,3 +205,124 @@ pub(crate) fn suggested_num_sub_vectors(dim: u32) -> u32 {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for an IVF_HNSW_SQ index.
|
||||
///
|
||||
/// This index is a combination of IVF and HNSW.
|
||||
/// The IVF part is the same as the IVF PQ index.
|
||||
/// For each IVF partition, this builds an HNSW graph; the graph is used to
|
||||
/// quickly find the closest vectors to a query vector.
|
||||
///
|
||||
/// The SQ (scalar quantizer) is used to compress the vectors,
|
||||
/// each vector is mapped to an 8-bit integer vector, giving a 4x compression ratio for float32 vectors.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IvfHnswSqIndexBuilder {
|
||||
// IVF
|
||||
pub(crate) distance_type: DistanceType,
|
||||
pub(crate) num_partitions: Option<u32>,
|
||||
pub(crate) sample_rate: u32,
|
||||
pub(crate) max_iterations: u32,
|
||||
|
||||
// HNSW
|
||||
pub(crate) m: u32,
|
||||
pub(crate) ef_construction: u32,
|
||||
// SQ
|
||||
// TODO add num_bits for SQ after it supports another num_bits besides 8
|
||||
}
|
||||
|
||||
impl Default for IvfHnswSqIndexBuilder {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
distance_type: DistanceType::L2,
|
||||
num_partitions: None,
|
||||
sample_rate: 256,
|
||||
max_iterations: 50,
|
||||
m: 20,
|
||||
ef_construction: 300,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IvfHnswSqIndexBuilder {
|
||||
/// [DistanceType] to use to build the index.
|
||||
///
|
||||
/// Default value is [DistanceType::L2].
|
||||
///
|
||||
/// This is used when training the index to calculate the IVF partitions (vectors are
|
||||
/// grouped in partitions with similar vectors according to this distance type)
|
||||
///
|
||||
/// The metric type used to train an index MUST match the metric type used to search the
|
||||
/// index. Failure to do so will yield inaccurate results.
|
||||
///
|
||||
/// Now IVF_HNSW_SQ only supports L2 and Cosine distance types.
|
||||
pub fn distance_type(mut self, distance_type: DistanceType) -> Self {
|
||||
self.distance_type = distance_type;
|
||||
self
|
||||
}
|
||||
|
||||
/// The number of IVF partitions to create.
|
||||
///
|
||||
/// This value should generally scale with the number of rows in the dataset. By default
|
||||
/// the number of partitions is the square root of the number of rows.
|
||||
///
|
||||
/// If this value is too large then the first part of the search (picking the right partition)
|
||||
/// will be slow. If this value is too small then the second part of the search (searching
|
||||
/// within a partition) will be slow.
|
||||
pub fn num_partitions(mut self, num_partitions: u32) -> Self {
|
||||
self.num_partitions = Some(num_partitions);
|
||||
self
|
||||
}
|
||||
|
||||
/// The rate used to calculate the number of training vectors for kmeans and SQ.
|
||||
///
|
||||
/// When an IVF_HNSW_SQ index is trained, we need to calculate the partitions and the min/max values of the vectors.
|
||||
/// Partitions are groups of vectors that are similar to each other; to find them we use an algorithm called kmeans.
|
||||
///
|
||||
/// Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
|
||||
/// random sample of the data. This parameter controls the size of the sample. The total
|
||||
/// number of vectors used to train the IVF is `sample_rate * num_partitions`.
|
||||
///
|
||||
/// The total number of vectors used to train the SQ is `sample_rate * 2^{num_bits}`.
|
||||
///
|
||||
/// Increasing this value might improve the quality of the index but in most cases the
|
||||
/// default should be sufficient.
|
||||
///
|
||||
/// The default value is 256.
|
||||
pub fn sample_rate(mut self, sample_rate: u32) -> Self {
|
||||
self.sample_rate = sample_rate;
|
||||
self
|
||||
}
|
||||
|
||||
/// Max iterations to train kmeans.
|
||||
///
|
||||
/// When training an IVF index we use kmeans to calculate the partitions. This parameter
|
||||
/// controls how many iterations of kmeans to run.
|
||||
///
|
||||
/// Increasing this might improve the quality of the index but in most cases the parameter
|
||||
/// is unused because kmeans will converge with fewer iterations. The parameter is only
|
||||
/// used in cases where kmeans does not appear to converge. In those cases it is unlikely
|
||||
/// that setting this larger will lead to the index converging anyways.
|
||||
///
|
||||
/// The default value is 50.
|
||||
pub fn max_iterations(mut self, max_iterations: u32) -> Self {
|
||||
self.max_iterations = max_iterations;
|
||||
self
|
||||
}
|
||||
|
||||
/// The number of neighbors to select for each vector in the HNSW graph.
|
||||
/// Bumping this number will increase the recall of the search but also increase the build/search time.
|
||||
/// The default value is 20.
|
||||
pub fn m(mut self, m: u32) -> Self {
|
||||
self.m = m;
|
||||
self
|
||||
}
|
||||
|
||||
/// The number of candidates to evaluate during the construction of the HNSW graph.
|
||||
/// Bumping this number will increase the recall of the search but also increase the build/search time.
|
||||
/// This value should not be less than the `ef` used in the search phase.
|
||||
/// The default value is 300.
|
||||
pub fn ef_construction(mut self, ef_construction: u32) -> Self {
|
||||
self.ef_construction = ef_construction;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,6 +38,9 @@ use lance::dataset::{
|
||||
};
|
||||
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
|
||||
use lance::io::WrappingObjectStore;
|
||||
use lance_index::vector::hnsw::builder::HnswBuildParams;
|
||||
use lance_index::vector::ivf::IvfBuildParams;
|
||||
use lance_index::vector::sq::builder::SQBuildParams;
|
||||
use lance_index::IndexType;
|
||||
use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
|
||||
use log::info;
|
||||
@@ -48,7 +51,9 @@ use crate::arrow::IntoArrow;
|
||||
use crate::connection::NoData;
|
||||
use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::index::vector::{IvfPqIndexBuilder, VectorIndex, VectorIndexStatistics};
|
||||
use crate::index::vector::{
|
||||
IvfHnswSqIndexBuilder, IvfPqIndexBuilder, VectorIndex, VectorIndexStatistics,
|
||||
};
|
||||
use crate::index::IndexConfig;
|
||||
use crate::index::{
|
||||
vector::{suggested_num_partitions, suggested_num_sub_vectors},
|
||||
@@ -312,6 +317,7 @@ impl UpdateBuilder {
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Sync {
|
||||
#[allow(dead_code)]
|
||||
fn as_any(&self) -> &dyn std::any::Any;
|
||||
/// Cast as [`NativeTable`], or return None if it is not a [`NativeTable`].
|
||||
fn as_native(&self) -> Option<&NativeTable>;
|
||||
@@ -1254,6 +1260,58 @@ impl NativeTable {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_ivf_hnsw_sq_index(
|
||||
&self,
|
||||
index: IvfHnswSqIndexBuilder,
|
||||
field: &Field,
|
||||
replace: bool,
|
||||
) -> Result<()> {
|
||||
if !Self::supported_vector_data_type(field.data_type()) {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"An IVF HNSW SQ index cannot be created on the column `{}` which has data type {}",
|
||||
field.name(),
|
||||
field.data_type()
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
let num_partitions = if let Some(n) = index.num_partitions {
|
||||
n
|
||||
} else {
|
||||
suggested_num_partitions(self.count_rows(None).await?)
|
||||
};
|
||||
|
||||
let mut dataset = self.dataset.get_mut().await?;
|
||||
let mut ivf_params = IvfBuildParams::new(num_partitions as usize);
|
||||
ivf_params.sample_rate = index.sample_rate as usize;
|
||||
ivf_params.max_iters = index.max_iterations as usize;
|
||||
let hnsw_params = HnswBuildParams::default()
|
||||
.num_edges(index.m as usize)
|
||||
.max_num_edges(index.m as usize * 2)
|
||||
.ef_construction(index.ef_construction as usize);
|
||||
let sq_params = SQBuildParams {
|
||||
sample_rate: index.sample_rate as usize,
|
||||
..Default::default()
|
||||
};
|
||||
let lance_idx_params = lance::index::vector::VectorIndexParams::with_ivf_hnsw_sq_params(
|
||||
index.distance_type.into(),
|
||||
ivf_params,
|
||||
hnsw_params,
|
||||
sq_params,
|
||||
);
|
||||
dataset
|
||||
.create_index(
|
||||
&[field.name()],
|
||||
IndexType::Vector,
|
||||
None,
|
||||
&lance_idx_params,
|
||||
replace,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_auto_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> {
|
||||
if Self::supported_vector_data_type(field.data_type()) {
|
||||
self.create_ivf_pq_index(IvfPqIndexBuilder::default(), field, opts.replace)
|
||||
@@ -1497,6 +1555,10 @@ impl TableInternal for NativeTable {
|
||||
Index::Auto => self.create_auto_index(field, opts).await,
|
||||
Index::BTree(_) => self.create_btree_index(field, opts).await,
|
||||
Index::IvfPq(ivf_pq) => self.create_ivf_pq_index(ivf_pq, field, opts.replace).await,
|
||||
Index::IvfHnswSq(ivf_hnsw_sq) => {
|
||||
self.create_ivf_hnsw_sq_index(ivf_hnsw_sq, field, opts.replace)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2357,6 +2419,102 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_index_ivf_hnsw_sq() {
|
||||
use arrow_array::RecordBatch;
|
||||
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
|
||||
use rand;
|
||||
use std::iter::repeat_with;
|
||||
|
||||
use arrow_array::Float32Array;
|
||||
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let uri = tmp_dir.path().to_str().unwrap();
|
||||
let conn = connect(uri).execute().await.unwrap();
|
||||
|
||||
let dimension = 16;
|
||||
let schema = Arc::new(ArrowSchema::new(vec![Field::new(
|
||||
"embeddings",
|
||||
DataType::FixedSizeList(
|
||||
Arc::new(Field::new("item", DataType::Float32, true)),
|
||||
dimension,
|
||||
),
|
||||
false,
|
||||
)]));
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
let float_arr = Float32Array::from(
|
||||
repeat_with(|| rng.gen::<f32>())
|
||||
.take(512 * dimension as usize)
|
||||
.collect::<Vec<f32>>(),
|
||||
);
|
||||
|
||||
let vectors = Arc::new(create_fixed_size_list(float_arr, dimension).unwrap());
|
||||
let batches = RecordBatchIterator::new(
|
||||
vec![RecordBatch::try_new(schema.clone(), vec![vectors.clone()]).unwrap()]
|
||||
.into_iter()
|
||||
.map(Ok),
|
||||
schema,
|
||||
);
|
||||
|
||||
let table = conn.create_table("test", batches).execute().await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
table
|
||||
.as_native()
|
||||
.unwrap()
|
||||
.count_indexed_rows("my_index")
|
||||
.await
|
||||
.unwrap(),
|
||||
None
|
||||
);
|
||||
assert_eq!(
|
||||
table
|
||||
.as_native()
|
||||
.unwrap()
|
||||
.count_unindexed_rows("my_index")
|
||||
.await
|
||||
.unwrap(),
|
||||
None
|
||||
);
|
||||
|
||||
let index = IvfHnswSqIndexBuilder::default();
|
||||
table
|
||||
.create_index(&["embeddings"], Index::IvfHnswSq(index))
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let index_configs = table.list_indices().await.unwrap();
|
||||
assert_eq!(index_configs.len(), 1);
|
||||
let index = index_configs.into_iter().next().unwrap();
|
||||
assert_eq!(index.index_type, crate::index::IndexType::IvfPq);
|
||||
assert_eq!(index.columns, vec!["embeddings".to_string()]);
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 512);
|
||||
assert_eq!(table.name(), "test");
|
||||
|
||||
let indices = table.as_native().unwrap().load_indices().await.unwrap();
|
||||
let index_uuid = &indices[0].index_uuid;
|
||||
assert_eq!(
|
||||
table
|
||||
.as_native()
|
||||
.unwrap()
|
||||
.count_indexed_rows(index_uuid)
|
||||
.await
|
||||
.unwrap(),
|
||||
Some(512)
|
||||
);
|
||||
assert_eq!(
|
||||
table
|
||||
.as_native()
|
||||
.unwrap()
|
||||
.count_unindexed_rows(index_uuid)
|
||||
.await
|
||||
.unwrap(),
|
||||
Some(0)
|
||||
);
|
||||
}
|
||||
|
||||
fn create_fixed_size_list<T: Array>(values: T, list_size: i32) -> Result<FixedSizeListArray> {
|
||||
let list_type = DataType::FixedSizeList(
|
||||
Arc::new(Field::new("item", values.data_type().clone(), true)),
|
||||
|
||||
Reference in New Issue
Block a user