feat: add IVF SQ index support and HNSW aliases (#2832)

Adds IVF_SQ index configuration to the Rust core and the Python bindings,
plus the alias names IvfHnswSq and IvfHnswPq for backward compatibility.
Updates the remote/table helpers and type definitions to accept the new
index type. Includes tests covering IVF_SQ index creation and alias usage.
This commit is contained in:
BubbleCal
2025-12-04 00:25:44 +08:00
committed by GitHub
parent b0170ea86a
commit a61461331c
11 changed files with 237 additions and 13 deletions

View File

@@ -3,7 +3,17 @@ from typing import Dict, List, Optional, Tuple, Any, TypedDict, Union, Literal
import pyarrow as pa import pyarrow as pa
from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS from .index import (
BTree,
IvfFlat,
IvfPq,
IvfSq,
Bitmap,
LabelList,
HnswPq,
HnswSq,
FTS,
)
from .io import StorageOptionsProvider from .io import StorageOptionsProvider
from lance_namespace import ( from lance_namespace import (
ListNamespacesResponse, ListNamespacesResponse,
@@ -14,6 +24,9 @@ from lance_namespace import (
) )
from .remote import ClientConfig from .remote import ClientConfig
# Alias names kept for backward compatibility. Each one is annotated as, and
# bound to, the class object it aliases (HnswPq / HnswSq).
IvfHnswPq: type[HnswPq] = HnswPq
IvfHnswSq: type[HnswSq] = HnswSq
class Session: class Session:
def __init__( def __init__(
self, self,
@@ -131,7 +144,17 @@ class Table:
async def create_index( async def create_index(
self, self,
column: str, column: str,
index: Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS], index: Union[
IvfFlat,
IvfSq,
IvfPq,
HnswPq,
HnswSq,
BTree,
Bitmap,
LabelList,
FTS,
],
replace: Optional[bool], replace: Optional[bool],
wait_timeout: Optional[object], wait_timeout: Optional[object],
*, *,

View File

@@ -376,6 +376,11 @@ class HnswSq:
target_partition_size: Optional[int] = None target_partition_size: Optional[int] = None
# Backwards-compatible aliases. These names are bound to the very same class
# objects as HnswPq / HnswSq (no subclass is created), so a config built
# through an alias is an ordinary HnswPq / HnswSq instance.
IvfHnswPq = HnswPq
IvfHnswSq = HnswSq
@dataclass @dataclass
class IvfFlat: class IvfFlat:
"""Describes an IVF Flat Index """Describes an IVF Flat Index
@@ -475,6 +480,36 @@ class IvfFlat:
target_partition_size: Optional[int] = None target_partition_size: Optional[int] = None
@dataclass
class IvfSq:
    """Describes an IVF Scalar Quantization (SQ) index.

    This index applies scalar quantization to compress vectors and organizes the
    quantized vectors into IVF partitions. It offers a balance between search
    speed and storage efficiency while keeping good recall.

    An instance is a plain configuration object: it only carries the settings
    below and is handed to ``create_index`` as the index configuration.

    Attributes
    ----------
    distance_type: str, default "l2"
        The distance metric used to train and search the index. Supported values
        are "l2", "cosine", and "dot".
    num_partitions: int, default sqrt(num_rows)
        Number of IVF partitions to create. Leaving this as ``None`` means no
        explicit value is supplied and the default applies.
    max_iterations: int, default 50
        Maximum iterations for kmeans during partition training.
    sample_rate: int, default 256
        Controls the number of training vectors: sample_rate * num_partitions.
    target_partition_size: int, optional
        Target size for each partition; adjusts the balance between speed and
        accuracy. Leaving this as ``None`` means no explicit target is supplied.
    """

    # Field order defines the positional order of the generated constructor.
    distance_type: Literal["l2", "cosine", "dot"] = "l2"
    num_partitions: Optional[int] = None
    max_iterations: int = 50
    sample_rate: int = 256
    target_partition_size: Optional[int] = None
@dataclass @dataclass
class IvfPq: class IvfPq:
"""Describes an IVF PQ Index """Describes an IVF PQ Index
@@ -661,6 +696,9 @@ class IvfRq:
__all__ = [ __all__ = [
"BTree", "BTree",
"IvfPq", "IvfPq",
"IvfHnswPq",
"IvfHnswSq",
"IvfSq",
"IvfRq", "IvfRq",
"IvfFlat", "IvfFlat",
"HnswPq", "HnswPq",

View File

@@ -18,7 +18,7 @@ from lancedb._lancedb import (
UpdateResult, UpdateResult,
) )
from lancedb.embeddings.base import EmbeddingFunctionConfig from lancedb.embeddings.base import EmbeddingFunctionConfig
from lancedb.index import FTS, BTree, Bitmap, HnswSq, IvfFlat, IvfPq, LabelList from lancedb.index import FTS, BTree, Bitmap, HnswSq, IvfFlat, IvfPq, IvfSq, LabelList
from lancedb.remote.db import LOOP from lancedb.remote.db import LOOP
import pyarrow as pa import pyarrow as pa
@@ -265,6 +265,8 @@ class RemoteTable(Table):
num_sub_vectors=num_sub_vectors, num_sub_vectors=num_sub_vectors,
num_bits=num_bits, num_bits=num_bits,
) )
elif index_type == "IVF_SQ":
config = IvfSq(distance_type=metric, num_partitions=num_partitions)
elif index_type == "IVF_HNSW_PQ": elif index_type == "IVF_HNSW_PQ":
raise ValueError( raise ValueError(
"IVF_HNSW_PQ is not supported on LanceDB cloud." "IVF_HNSW_PQ is not supported on LanceDB cloud."
@@ -277,7 +279,7 @@ class RemoteTable(Table):
else: else:
raise ValueError( raise ValueError(
f"Unknown vector index type: {index_type}. Valid options are" f"Unknown vector index type: {index_type}. Valid options are"
" 'IVF_FLAT', 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'" " 'IVF_FLAT', 'IVF_SQ', 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
) )
LOOP.run( LOOP.run(

View File

@@ -44,7 +44,18 @@ import numpy as np
from .common import DATA, VEC, VECTOR_COLUMN_NAME from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from .index import BTree, IvfFlat, IvfPq, Bitmap, IvfRq, LabelList, HnswPq, HnswSq, FTS from .index import (
BTree,
IvfFlat,
IvfPq,
IvfSq,
Bitmap,
IvfRq,
LabelList,
HnswPq,
HnswSq,
FTS,
)
from .merge import LanceMergeInsertBuilder from .merge import LanceMergeInsertBuilder
from .pydantic import LanceModel, model_to_dict from .pydantic import LanceModel, model_to_dict
from .query import ( from .query import (
@@ -2054,7 +2065,7 @@ class LanceTable(Table):
index_cache_size: Optional[int] = None, index_cache_size: Optional[int] = None,
num_bits: int = 8, num_bits: int = 8,
index_type: Literal[ index_type: Literal[
"IVF_FLAT", "IVF_PQ", "IVF_RQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ" "IVF_FLAT", "IVF_SQ", "IVF_PQ", "IVF_RQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
] = "IVF_PQ", ] = "IVF_PQ",
max_iterations: int = 50, max_iterations: int = 50,
sample_rate: int = 256, sample_rate: int = 256,
@@ -2092,6 +2103,14 @@ class LanceTable(Table):
sample_rate=sample_rate, sample_rate=sample_rate,
target_partition_size=target_partition_size, target_partition_size=target_partition_size,
) )
elif index_type == "IVF_SQ":
config = IvfSq(
distance_type=metric,
num_partitions=num_partitions,
max_iterations=max_iterations,
sample_rate=sample_rate,
target_partition_size=target_partition_size,
)
elif index_type == "IVF_PQ": elif index_type == "IVF_PQ":
config = IvfPq( config = IvfPq(
distance_type=metric, distance_type=metric,
@@ -3456,11 +3475,22 @@ class AsyncTable:
if config is not None: if config is not None:
if not isinstance( if not isinstance(
config, config,
(IvfFlat, IvfPq, IvfRq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS), (
IvfFlat,
IvfSq,
IvfPq,
IvfRq,
HnswPq,
HnswSq,
BTree,
Bitmap,
LabelList,
FTS,
),
): ):
raise TypeError( raise TypeError(
"config must be an instance of IvfPq, IvfRq, HnswPq, HnswSq, BTree," "config must be an instance of IvfSq, IvfPq, IvfRq, HnswPq, HnswSq,"
" Bitmap, LabelList, or FTS, but got " + str(type(config)) " BTree, Bitmap, LabelList, or FTS, but got " + str(type(config))
) )
try: try:
await self._inner.create_index( await self._inner.create_index(

View File

@@ -18,12 +18,20 @@ AddMode = Literal["append", "overwrite"]
CreateMode = Literal["create", "overwrite"] CreateMode = Literal["create", "overwrite"]
# Index type literals # Index type literals
VectorIndexType = Literal["IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ", "IVF_RQ"] VectorIndexType = Literal[
"IVF_FLAT",
"IVF_SQ",
"IVF_PQ",
"IVF_HNSW_SQ",
"IVF_HNSW_PQ",
"IVF_RQ",
]
ScalarIndexType = Literal["BTREE", "BITMAP", "LABEL_LIST"] ScalarIndexType = Literal["BTREE", "BITMAP", "LABEL_LIST"]
IndexType = Literal[ IndexType = Literal[
"IVF_PQ", "IVF_PQ",
"IVF_HNSW_PQ", "IVF_HNSW_PQ",
"IVF_HNSW_SQ", "IVF_HNSW_SQ",
"IVF_SQ",
"FTS", "FTS",
"BTREE", "BTREE",
"BITMAP", "BITMAP",

View File

@@ -12,6 +12,9 @@ from lancedb.index import (
BTree, BTree,
IvfFlat, IvfFlat,
IvfPq, IvfPq,
IvfSq,
IvfHnswPq,
IvfHnswSq,
IvfRq, IvfRq,
Bitmap, Bitmap,
LabelList, LabelList,
@@ -229,6 +232,35 @@ async def test_create_hnswsq_index(some_table: AsyncTable):
assert len(indices) == 1 assert len(indices) == 1
@pytest.mark.asyncio
async def test_create_hnswsq_alias_index(some_table: AsyncTable):
    """Build an index on "vector" through the IvfHnswSq alias and check the listing.

    Exactly one index must be listed afterwards, and its reported type must be
    one of the two accepted spellings.
    """
    alias_config = IvfHnswSq(num_partitions=5)
    await some_table.create_index("vector", config=alias_config)

    listed = await some_table.list_indices()
    assert len(listed) == 1

    # The reported name may be either the canonical or the alias spelling.
    accepted_names = {"HnswSq", "IvfHnswSq"}
    assert listed[0].index_type in accepted_names
@pytest.mark.asyncio
async def test_create_hnswpq_alias_index(some_table: AsyncTable):
    """Build an index on "vector" through the IvfHnswPq alias and check the listing.

    Exactly one index must be listed afterwards, and its reported type must be
    one of the two accepted spellings.
    """
    alias_config = IvfHnswPq(num_partitions=5)
    await some_table.create_index("vector", config=alias_config)

    listed = await some_table.list_indices()
    assert len(listed) == 1

    # The reported name may be either the canonical or the alias spelling.
    accepted_names = {"HnswPq", "IvfHnswPq"}
    assert listed[0].index_type in accepted_names
@pytest.mark.asyncio
async def test_create_ivfsq_index(some_table: AsyncTable):
    """Create an IvfSq index on "vector" and verify its listing and statistics.

    Checks that a single index is listed with type "IvfSq", and that its stats
    report type "IVF_SQ", the "l2" distance, every row indexed and none left
    unindexed.
    """
    await some_table.create_index("vector", config=IvfSq(num_partitions=10))

    listed = await some_table.list_indices()
    assert len(listed) == 1
    created = listed[0]
    assert created.index_type == "IvfSq"

    stats = await some_table.index_stats(created.name)
    assert stats.index_type == "IVF_SQ"
    assert stats.distance_type == "l2"

    # Every row of the table is expected to be covered by the new index.
    total_rows = await some_table.count_rows()
    assert stats.num_indexed_rows == total_rows
    assert stats.num_unindexed_rows == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_index_with_binary_vectors(binary_table: AsyncTable): async def test_create_index_with_binary_vectors(binary_table: AsyncTable):
await binary_table.create_index( await binary_table.create_index(

View File

@@ -1,7 +1,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors // SPDX-FileCopyrightText: Copyright The LanceDB Authors
use lancedb::index::vector::{IvfFlatIndexBuilder, IvfRqIndexBuilder}; use lancedb::index::vector::{IvfFlatIndexBuilder, IvfRqIndexBuilder, IvfSqIndexBuilder};
use lancedb::index::{ use lancedb::index::{
scalar::{BTreeIndexBuilder, FtsIndexBuilder}, scalar::{BTreeIndexBuilder, FtsIndexBuilder},
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder}, vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
@@ -87,6 +87,21 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
} }
Ok(LanceDbIndex::IvfPq(ivf_pq_builder)) Ok(LanceDbIndex::IvfPq(ivf_pq_builder))
}, },
"IvfSq" => {
let params = source.extract::<IvfSqParams>()?;
let distance_type = parse_distance_type(params.distance_type)?;
let mut ivf_sq_builder = IvfSqIndexBuilder::default()
.distance_type(distance_type)
.max_iterations(params.max_iterations)
.sample_rate(params.sample_rate);
if let Some(num_partitions) = params.num_partitions {
ivf_sq_builder = ivf_sq_builder.num_partitions(num_partitions);
}
if let Some(target_partition_size) = params.target_partition_size {
ivf_sq_builder = ivf_sq_builder.target_partition_size(target_partition_size);
}
Ok(LanceDbIndex::IvfSq(ivf_sq_builder))
},
"IvfRq" => { "IvfRq" => {
let params = source.extract::<IvfRqParams>()?; let params = source.extract::<IvfRqParams>()?;
let distance_type = parse_distance_type(params.distance_type)?; let distance_type = parse_distance_type(params.distance_type)?;
@@ -142,7 +157,7 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
Ok(LanceDbIndex::IvfHnswSq(hnsw_sq_builder)) Ok(LanceDbIndex::IvfHnswSq(hnsw_sq_builder))
}, },
not_supported => Err(PyValueError::new_err(format!( not_supported => Err(PyValueError::new_err(format!(
"Invalid index type '{}'. Must be one of BTree, Bitmap, LabelList, FTS, IvfPq, IvfHnswPq, or IvfHnswSq", "Invalid index type '{}'. Must be one of BTree, Bitmap, LabelList, FTS, IvfPq, IvfSq, IvfHnswPq, or IvfHnswSq",
not_supported not_supported
))), ))),
} }
@@ -186,6 +201,15 @@ struct IvfPqParams {
target_partition_size: Option<u32>, target_partition_size: Option<u32>,
} }
/// Values extracted from a Python `IvfSq` config object.
///
/// `extract_index_params` extracts this struct for the `"IvfSq"` case and
/// forwards the values to an `IvfSqIndexBuilder`.
#[derive(FromPyObject)]
struct IvfSqParams {
    /// Distance metric name; converted with `parse_distance_type`.
    distance_type: String,
    /// Optional partition count; only applied to the builder when present.
    num_partitions: Option<u32>,
    /// Forwarded to the builder's `max_iterations` setter.
    max_iterations: u32,
    /// Forwarded to the builder's `sample_rate` setter.
    sample_rate: u32,
    /// Optional target partition size; only applied to the builder when present.
    target_partition_size: Option<u32>,
}
#[derive(FromPyObject)] #[derive(FromPyObject)]
struct IvfRqParams { struct IvfRqParams {
distance_type: String, distance_type: String,

View File

@@ -13,7 +13,7 @@ use crate::{table::BaseTable, DistanceType, Error, Result};
use self::{ use self::{
scalar::{BTreeIndexBuilder, BitmapIndexBuilder, LabelListIndexBuilder}, scalar::{BTreeIndexBuilder, BitmapIndexBuilder, LabelListIndexBuilder},
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder}, vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder, IvfSqIndexBuilder},
}; };
pub mod scalar; pub mod scalar;
@@ -54,6 +54,9 @@ pub enum Index {
/// IVF index with Product Quantization /// IVF index with Product Quantization
IvfPq(IvfPqIndexBuilder), IvfPq(IvfPqIndexBuilder),
/// IVF index with Scalar Quantization
IvfSq(IvfSqIndexBuilder),
/// IVF index with RabitQ Quantization /// IVF index with RabitQ Quantization
IvfRq(IvfRqIndexBuilder), IvfRq(IvfRqIndexBuilder),
@@ -277,6 +280,8 @@ pub enum IndexType {
// Vector // Vector
#[serde(alias = "IVF_FLAT")] #[serde(alias = "IVF_FLAT")]
IvfFlat, IvfFlat,
#[serde(alias = "IVF_SQ")]
IvfSq,
#[serde(alias = "IVF_PQ")] #[serde(alias = "IVF_PQ")]
IvfPq, IvfPq,
#[serde(alias = "IVF_RQ")] #[serde(alias = "IVF_RQ")]
@@ -301,6 +306,7 @@ impl std::fmt::Display for IndexType {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self { match self {
Self::IvfFlat => write!(f, "IVF_FLAT"), Self::IvfFlat => write!(f, "IVF_FLAT"),
Self::IvfSq => write!(f, "IVF_SQ"),
Self::IvfPq => write!(f, "IVF_PQ"), Self::IvfPq => write!(f, "IVF_PQ"),
Self::IvfRq => write!(f, "IVF_RQ"), Self::IvfRq => write!(f, "IVF_RQ"),
Self::IvfHnswPq => write!(f, "IVF_HNSW_PQ"), Self::IvfHnswPq => write!(f, "IVF_HNSW_PQ"),
@@ -323,6 +329,7 @@ impl std::str::FromStr for IndexType {
"LABEL_LIST" | "LABELLIST" => Ok(Self::LabelList), "LABEL_LIST" | "LABELLIST" => Ok(Self::LabelList),
"FTS" | "INVERTED" => Ok(Self::FTS), "FTS" | "INVERTED" => Ok(Self::FTS),
"IVF_FLAT" => Ok(Self::IvfFlat), "IVF_FLAT" => Ok(Self::IvfFlat),
"IVF_SQ" => Ok(Self::IvfSq),
"IVF_PQ" => Ok(Self::IvfPq), "IVF_PQ" => Ok(Self::IvfPq),
"IVF_RQ" => Ok(Self::IvfRq), "IVF_RQ" => Ok(Self::IvfRq),
"IVF_HNSW_PQ" => Ok(Self::IvfHnswPq), "IVF_HNSW_PQ" => Ok(Self::IvfHnswPq),

View File

@@ -209,6 +209,38 @@ impl IvfFlatIndexBuilder {
impl_ivf_params_setter!(); impl_ivf_params_setter!();
} }
/// Builder for an IVF SQ index.
///
/// This index compresses vectors using scalar quantization and groups them into IVF partitions.
/// It offers a balance between search performance and storage footprint.
///
/// Initial values for every field come from the `Default` implementation.
#[derive(Debug, Clone)]
pub struct IvfSqIndexBuilder {
    /// Distance metric the index is built with.
    pub(crate) distance_type: DistanceType,

    // IVF
    /// Number of IVF partitions; `None` means no explicit value has been set.
    pub(crate) num_partitions: Option<u32>,
    /// Sample rate used when training the index.
    pub(crate) sample_rate: u32,
    /// Upper bound on training iterations.
    pub(crate) max_iterations: u32,
    /// Target size of each partition; `None` means no explicit value has been set.
    pub(crate) target_partition_size: Option<u32>,
}
impl Default for IvfSqIndexBuilder {
    /// Creates a builder with L2 distance, a sample rate of 256, a maximum of
    /// 50 iterations, and neither partition setting specified.
    fn default() -> Self {
        // Both partition settings start out unset.
        let num_partitions = None;
        let target_partition_size = None;

        Self {
            distance_type: DistanceType::L2,
            num_partitions,
            sample_rate: 256,
            max_iterations: 50,
            target_partition_size,
        }
    }
}
impl IvfSqIndexBuilder {
    // Builder methods come from the shared macros below, which are defined
    // elsewhere in the crate.
    // NOTE(review): presumably these expand to the `distance_type`,
    // `num_partitions`, `sample_rate`, `max_iterations` and
    // `target_partition_size` setters — confirm against the macro definitions.
    impl_distance_type_setter!();
    impl_ivf_params_setter!();
}
/// Builder for an IVF PQ index. /// Builder for an IVF PQ index.
/// ///
/// This index stores a compressed (quantized) copy of every vector. These vectors /// This index stores a compressed (quantized) copy of every vector. These vectors

View File

@@ -1072,6 +1072,14 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
body["num_bits"] = serde_json::Value::Number(num_bits.into()); body["num_bits"] = serde_json::Value::Number(num_bits.into());
} }
} }
Index::IvfSq(index) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_SQ".to_string());
body[METRIC_TYPE_KEY] =
serde_json::Value::String(index.distance_type.to_string().to_lowercase());
if let Some(num_partitions) = index.num_partitions {
body["num_partitions"] = serde_json::Value::Number(num_partitions.into());
}
}
Index::IvfHnswSq(index) => { Index::IvfHnswSq(index) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_HNSW_SQ".to_string()); body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_HNSW_SQ".to_string());
body[METRIC_TYPE_KEY] = body[METRIC_TYPE_KEY] =

View File

@@ -1946,6 +1946,25 @@ impl NativeTable {
VectorIndexParams::with_ivf_flat_params(index.distance_type.into(), ivf_params); VectorIndexParams::with_ivf_flat_params(index.distance_type.into(), ivf_params);
Ok(Box::new(lance_idx_params)) Ok(Box::new(lance_idx_params))
} }
Index::IvfSq(index) => {
Self::validate_index_type(field, "IVF SQ", supported_vector_data_type)?;
let ivf_params = Self::build_ivf_params(
index.num_partitions,
index.target_partition_size,
index.sample_rate,
index.max_iterations,
);
let sq_params = SQBuildParams {
sample_rate: index.sample_rate as usize,
..Default::default()
};
let lance_idx_params = VectorIndexParams::with_ivf_sq_params(
index.distance_type.into(),
ivf_params,
sq_params,
);
Ok(Box::new(lance_idx_params))
}
Index::IvfPq(index) => { Index::IvfPq(index) => {
Self::validate_index_type(field, "IVF PQ", supported_vector_data_type)?; Self::validate_index_type(field, "IVF PQ", supported_vector_data_type)?;
let dim = Self::get_vector_dimension(field)?; let dim = Self::get_vector_dimension(field)?;
@@ -2053,6 +2072,7 @@ impl NativeTable {
Index::LabelList(_) => IndexType::LabelList, Index::LabelList(_) => IndexType::LabelList,
Index::FTS(_) => IndexType::Inverted, Index::FTS(_) => IndexType::Inverted,
Index::IvfFlat(_) Index::IvfFlat(_)
| Index::IvfSq(_)
| Index::IvfPq(_) | Index::IvfPq(_)
| Index::IvfRq(_) | Index::IvfRq(_)
| Index::IvfHnswPq(_) | Index::IvfHnswPq(_)