feat: support IVF_HNSW_SQ (#1284)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
@@ -19,10 +19,12 @@ use snafu::Snafu;
 
 #[derive(Debug, Snafu)]
 pub enum Error {
+    #[allow(dead_code)]
     #[snafu(display("column '{name}' is missing"))]
     MissingColumn { name: String },
     #[snafu(display("{name}: {message}"))]
     OutOfRange { name: String, message: String },
+    #[allow(dead_code)]
     #[snafu(display("{index_type} is not a valid index type"))]
     InvalidIndexType { index_type: String },
 
@@ -19,6 +19,7 @@ use neon::prelude::*;
 pub trait JsObjectExt {
     fn get_opt_u32(&self, cx: &mut FunctionContext, key: &str) -> Result<Option<u32>>;
     fn get_usize(&self, cx: &mut FunctionContext, key: &str) -> Result<usize>;
+    #[allow(dead_code)]
     fn get_opt_usize(&self, cx: &mut FunctionContext, key: &str) -> Result<Option<usize>>;
 }
 
@@ -16,7 +16,10 @@ use std::sync::Arc;
 
 use crate::{table::TableInternal, Result};
 
-use self::{scalar::BTreeIndexBuilder, vector::IvfPqIndexBuilder};
+use self::{
+    scalar::BTreeIndexBuilder,
+    vector::{IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
+};
 
 pub mod scalar;
 pub mod vector;
@@ -25,6 +28,7 @@ pub enum Index {
     Auto,
     BTree(BTreeIndexBuilder),
     IvfPq(IvfPqIndexBuilder),
+    IvfHnswSq(IvfHnswSqIndexBuilder),
 }
 
 /// Builder for the create_index operation
@@ -65,6 +69,7 @@ impl IndexBuilder {
 #[derive(Debug, Clone, PartialEq)]
 pub enum IndexType {
     IvfPq,
+    IvfHnswSq,
     BTree,
 }
 
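For context, the test added at the end of this commit exercises the new variant through the public API. Below is a minimal sketch (not part of the diff) of that usage, assuming an already-open lancedb table with an "embeddings" vector column; the crate paths follow the `use crate::index::...` statements in this diff.

// Minimal sketch, not part of the diff: build an IVF_HNSW_SQ index with default
// parameters on an existing table, mirroring the test added later in this commit.
use lancedb::index::{vector::IvfHnswSqIndexBuilder, Index};
use lancedb::table::Table;

async fn build_ivf_hnsw_sq(table: &Table) -> lancedb::error::Result<()> {
    table
        .create_index(
            &["embeddings"],
            Index::IvfHnswSq(IvfHnswSqIndexBuilder::default()),
        )
        .execute()
        .await
}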
@@ -83,10 +83,14 @@ pub struct VectorIndexStatistics {
 #[derive(Debug, Clone)]
 pub struct IvfPqIndexBuilder {
     pub(crate) distance_type: DistanceType,
+
+    // IVF
     pub(crate) num_partitions: Option<u32>,
-    pub(crate) num_sub_vectors: Option<u32>,
     pub(crate) sample_rate: u32,
     pub(crate) max_iterations: u32,
+
+    // PQ
+    pub(crate) num_sub_vectors: Option<u32>,
 }
 
 impl Default for IvfPqIndexBuilder {
@@ -201,3 +205,124 @@ pub(crate) fn suggested_num_sub_vectors(dim: u32) -> u32 {
         1
     }
 }
+
+/// Builder for an IVF_HNSW_SQ index.
+///
+/// This index combines IVF and HNSW.
+/// The IVF part is the same as in the IVF_PQ index.
+/// For each IVF partition, this builds an HNSW graph; the graph is used to
+/// quickly find the vectors closest to a query vector.
+///
+/// The SQ (scalar quantizer) is used to compress the vectors:
+/// each vector is mapped to an 8-bit integer vector, a 4x compression ratio for float32 vectors.
+#[derive(Debug, Clone)]
+pub struct IvfHnswSqIndexBuilder {
+    // IVF
+    pub(crate) distance_type: DistanceType,
+    pub(crate) num_partitions: Option<u32>,
+    pub(crate) sample_rate: u32,
+    pub(crate) max_iterations: u32,
+
+    // HNSW
+    pub(crate) m: u32,
+    pub(crate) ef_construction: u32,
+    // SQ
+    // TODO: add num_bits for SQ once it supports values other than 8
+}
+
+impl Default for IvfHnswSqIndexBuilder {
+    fn default() -> Self {
+        Self {
+            distance_type: DistanceType::L2,
+            num_partitions: None,
+            sample_rate: 256,
+            max_iterations: 50,
+            m: 20,
+            ef_construction: 300,
+        }
+    }
+}
+
+impl IvfHnswSqIndexBuilder {
+    /// [DistanceType] to use to build the index.
+    ///
+    /// The default value is [DistanceType::L2].
+    ///
+    /// This is used when training the index to calculate the IVF partitions (vectors are
+    /// grouped in partitions with similar vectors according to this distance type).
+    ///
+    /// The distance type used to train an index MUST match the distance type used to search the
+    /// index. Failure to do so will yield inaccurate results.
+    ///
+    /// Currently IVF_HNSW_SQ only supports the L2 and Cosine distance types.
+    pub fn distance_type(mut self, distance_type: DistanceType) -> Self {
+        self.distance_type = distance_type;
+        self
+    }
+
+    /// The number of IVF partitions to create.
+    ///
+    /// This value should generally scale with the number of rows in the dataset. By default
+    /// the number of partitions is the square root of the number of rows.
+    ///
+    /// If this value is too large then the first part of the search (picking the right partition)
+    /// will be slow. If this value is too small then the second part of the search (searching
+    /// within a partition) will be slow.
+    pub fn num_partitions(mut self, num_partitions: u32) -> Self {
+        self.num_partitions = Some(num_partitions);
+        self
+    }
+
+    /// The rate used to calculate the number of training vectors for kmeans and SQ.
+    ///
+    /// When an IVF_HNSW_SQ index is trained, we need to calculate the IVF partitions (groups
+    /// of vectors that are similar to each other) and the min/max values of the vectors for SQ.
+    /// The partitions are computed with an algorithm called kmeans.
+    ///
+    /// Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
+    /// random sample of the data. This parameter controls the size of the sample. The total
+    /// number of vectors used to train the IVF is `sample_rate * num_partitions`.
+    ///
+    /// The total number of vectors used to train the SQ is `sample_rate * 2^{num_bits}`.
+    ///
+    /// Increasing this value might improve the quality of the index but in most cases the
+    /// default should be sufficient.
+    ///
+    /// The default value is 256.
+    pub fn sample_rate(mut self, sample_rate: u32) -> Self {
+        self.sample_rate = sample_rate;
+        self
+    }
+
+    /// Max iterations to train kmeans.
+    ///
+    /// When training an IVF index we use kmeans to calculate the partitions. This parameter
+    /// controls how many iterations of kmeans to run.
+    ///
+    /// Increasing this might improve the quality of the index but in most cases the parameter
+    /// is unused because kmeans will converge with fewer iterations. The parameter is only
+    /// used in cases where kmeans does not appear to converge. In those cases it is unlikely
+    /// that setting this larger will lead to the index converging anyway.
+    ///
+    /// The default value is 50.
+    pub fn max_iterations(mut self, max_iterations: u32) -> Self {
+        self.max_iterations = max_iterations;
+        self
+    }
+
+    /// The number of neighbors to select for each vector in the HNSW graph.
+    ///
+    /// Increasing this number improves the recall of the search but also increases the build/search time.
+    ///
+    /// The default value is 20.
+    pub fn m(mut self, m: u32) -> Self {
+        self.m = m;
+        self
+    }
+
+    /// The number of candidates to evaluate during the construction of the HNSW graph.
+    ///
+    /// Increasing this number improves the recall of the search but also increases the build/search time.
+    /// This value should not be less than `ef` in the search phase.
+    ///
+    /// The default value is 300.
+    pub fn ef_construction(mut self, ef_construction: u32) -> Self {
+        self.ef_construction = ef_construction;
+        self
+    }
+}
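The builder methods above are chainable setters. The sketch below (not part of the diff) shows one way they might be combined before handing the builder to create_index; the parameter values are arbitrary examples, and `lancedb::DistanceType` is assumed to be the crate's re-export of the DistanceType used in this file.

// Sketch, not part of the diff: customize the IVF and HNSW parameters.
// Values are illustrative only; the defaults are usually a reasonable start.
use lancedb::index::vector::IvfHnswSqIndexBuilder;
use lancedb::DistanceType; // assumed re-export of the DistanceType used above

fn tuned_ivf_hnsw_sq() -> IvfHnswSqIndexBuilder {
    IvfHnswSqIndexBuilder::default()
        .distance_type(DistanceType::Cosine) // only L2 (default) and Cosine are supported
        .num_partitions(256) // defaults to sqrt(row count) when left unset
        .sample_rate(256) // training sample multiplier for kmeans and SQ
        .m(20) // neighbors per node in each partition's HNSW graph
        .ef_construction(300) // candidates evaluated while building the graph
}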
@@ -38,6 +38,9 @@ use lance::dataset::{
 };
 use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
 use lance::io::WrappingObjectStore;
+use lance_index::vector::hnsw::builder::HnswBuildParams;
+use lance_index::vector::ivf::IvfBuildParams;
+use lance_index::vector::sq::builder::SQBuildParams;
 use lance_index::IndexType;
 use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
 use log::info;
@@ -48,7 +51,9 @@ use crate::arrow::IntoArrow;
 use crate::connection::NoData;
 use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry};
 use crate::error::{Error, Result};
-use crate::index::vector::{IvfPqIndexBuilder, VectorIndex, VectorIndexStatistics};
+use crate::index::vector::{
+    IvfHnswSqIndexBuilder, IvfPqIndexBuilder, VectorIndex, VectorIndexStatistics,
+};
 use crate::index::IndexConfig;
 use crate::index::{
     vector::{suggested_num_partitions, suggested_num_sub_vectors},
@@ -312,6 +317,7 @@ impl UpdateBuilder {
 
 #[async_trait]
 pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Sync {
+    #[allow(dead_code)]
     fn as_any(&self) -> &dyn std::any::Any;
     /// Cast as [`NativeTable`], or return None if it is not a [`NativeTable`].
     fn as_native(&self) -> Option<&NativeTable>;
@@ -1254,6 +1260,58 @@ impl NativeTable {
         Ok(())
     }
 
+    async fn create_ivf_hnsw_sq_index(
+        &self,
+        index: IvfHnswSqIndexBuilder,
+        field: &Field,
+        replace: bool,
+    ) -> Result<()> {
+        if !Self::supported_vector_data_type(field.data_type()) {
+            return Err(Error::InvalidInput {
+                message: format!(
+                    "An IVF HNSW SQ index cannot be created on the column `{}` which has data type {}",
+                    field.name(),
+                    field.data_type()
+                ),
+            });
+        }
+
+        let num_partitions = if let Some(n) = index.num_partitions {
+            n
+        } else {
+            suggested_num_partitions(self.count_rows(None).await?)
+        };
+
+        let mut dataset = self.dataset.get_mut().await?;
+        let mut ivf_params = IvfBuildParams::new(num_partitions as usize);
+        ivf_params.sample_rate = index.sample_rate as usize;
+        ivf_params.max_iters = index.max_iterations as usize;
+        let hnsw_params = HnswBuildParams::default()
+            .num_edges(index.m as usize)
+            .max_num_edges(index.m as usize * 2)
+            .ef_construction(index.ef_construction as usize);
+        let sq_params = SQBuildParams {
+            sample_rate: index.sample_rate as usize,
+            ..Default::default()
+        };
+        let lance_idx_params = lance::index::vector::VectorIndexParams::with_ivf_hnsw_sq_params(
+            index.distance_type.into(),
+            ivf_params,
+            hnsw_params,
+            sq_params,
+        );
+        dataset
+            .create_index(
+                &[field.name()],
+                IndexType::Vector,
+                None,
+                &lance_idx_params,
+                replace,
+            )
+            .await?;
+        Ok(())
+    }
+
     async fn create_auto_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> {
         if Self::supported_vector_data_type(field.data_type()) {
             self.create_ivf_pq_index(IvfPqIndexBuilder::default(), field, opts.replace)
@@ -1497,6 +1555,10 @@ impl TableInternal for NativeTable {
             Index::Auto => self.create_auto_index(field, opts).await,
             Index::BTree(_) => self.create_btree_index(field, opts).await,
            Index::IvfPq(ivf_pq) => self.create_ivf_pq_index(ivf_pq, field, opts.replace).await,
+            Index::IvfHnswSq(ivf_hnsw_sq) => {
+                self.create_ivf_hnsw_sq_index(ivf_hnsw_sq, field, opts.replace)
+                    .await
+            }
         }
     }
 
@@ -2357,6 +2419,102 @@ mod tests {
         );
     }
 
+    #[tokio::test]
+    async fn test_create_index_ivf_hnsw_sq() {
+        use arrow_array::RecordBatch;
+        use arrow_schema::{DataType, Field, Schema as ArrowSchema};
+        use rand;
+        use std::iter::repeat_with;
+
+        use arrow_array::Float32Array;
+
+        let tmp_dir = tempdir().unwrap();
+        let uri = tmp_dir.path().to_str().unwrap();
+        let conn = connect(uri).execute().await.unwrap();
+
+        let dimension = 16;
+        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
+            "embeddings",
+            DataType::FixedSizeList(
+                Arc::new(Field::new("item", DataType::Float32, true)),
+                dimension,
+            ),
+            false,
+        )]));
+
+        let mut rng = rand::thread_rng();
+        let float_arr = Float32Array::from(
+            repeat_with(|| rng.gen::<f32>())
+                .take(512 * dimension as usize)
+                .collect::<Vec<f32>>(),
+        );
+
+        let vectors = Arc::new(create_fixed_size_list(float_arr, dimension).unwrap());
+        let batches = RecordBatchIterator::new(
+            vec![RecordBatch::try_new(schema.clone(), vec![vectors.clone()]).unwrap()]
+                .into_iter()
+                .map(Ok),
+            schema,
+        );
+
+        let table = conn.create_table("test", batches).execute().await.unwrap();
+
+        assert_eq!(
+            table
+                .as_native()
+                .unwrap()
+                .count_indexed_rows("my_index")
+                .await
+                .unwrap(),
+            None
+        );
+        assert_eq!(
+            table
+                .as_native()
+                .unwrap()
+                .count_unindexed_rows("my_index")
+                .await
+                .unwrap(),
+            None
+        );
+
+        let index = IvfHnswSqIndexBuilder::default();
+        table
+            .create_index(&["embeddings"], Index::IvfHnswSq(index))
+            .execute()
+            .await
+            .unwrap();
+
+        let index_configs = table.list_indices().await.unwrap();
+        assert_eq!(index_configs.len(), 1);
+        let index = index_configs.into_iter().next().unwrap();
+        assert_eq!(index.index_type, crate::index::IndexType::IvfPq);
+        assert_eq!(index.columns, vec!["embeddings".to_string()]);
+        assert_eq!(table.count_rows(None).await.unwrap(), 512);
+        assert_eq!(table.name(), "test");
+
+        let indices = table.as_native().unwrap().load_indices().await.unwrap();
+        let index_uuid = &indices[0].index_uuid;
+        assert_eq!(
+            table
+                .as_native()
+                .unwrap()
+                .count_indexed_rows(index_uuid)
+                .await
+                .unwrap(),
+            Some(512)
+        );
+        assert_eq!(
+            table
+                .as_native()
+                .unwrap()
+                .count_unindexed_rows(index_uuid)
+                .await
+                .unwrap(),
+            Some(0)
+        );
+    }
+
     fn create_fixed_size_list<T: Array>(values: T, list_size: i32) -> Result<FixedSizeListArray> {
         let list_type = DataType::FixedSizeList(
             Arc::new(Field::new("item", values.data_type().clone(), true)),