mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-22 21:09:58 +00:00
feat: add IVF SQ index support and HNSW aliases (#2832)
Adds IVF_SQ index config through Rust core and Python bindings, plus alias names IvfHnswSq/Pq for backward compatibility. Updates remote/table helpers and types to accept the new index type. Includes tests covering IVF SQ creation and alias usage.
This commit is contained in:
@@ -3,7 +3,17 @@ from typing import Dict, List, Optional, Tuple, Any, TypedDict, Union, Literal
|
|||||||
|
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
|
|
||||||
from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
|
from .index import (
|
||||||
|
BTree,
|
||||||
|
IvfFlat,
|
||||||
|
IvfPq,
|
||||||
|
IvfSq,
|
||||||
|
Bitmap,
|
||||||
|
LabelList,
|
||||||
|
HnswPq,
|
||||||
|
HnswSq,
|
||||||
|
FTS,
|
||||||
|
)
|
||||||
from .io import StorageOptionsProvider
|
from .io import StorageOptionsProvider
|
||||||
from lance_namespace import (
|
from lance_namespace import (
|
||||||
ListNamespacesResponse,
|
ListNamespacesResponse,
|
||||||
@@ -14,6 +24,9 @@ from lance_namespace import (
|
|||||||
)
|
)
|
||||||
from .remote import ClientConfig
|
from .remote import ClientConfig
|
||||||
|
|
||||||
|
IvfHnswPq: type[HnswPq] = HnswPq
|
||||||
|
IvfHnswSq: type[HnswSq] = HnswSq
|
||||||
|
|
||||||
class Session:
|
class Session:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -131,7 +144,17 @@ class Table:
|
|||||||
async def create_index(
|
async def create_index(
|
||||||
self,
|
self,
|
||||||
column: str,
|
column: str,
|
||||||
index: Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS],
|
index: Union[
|
||||||
|
IvfFlat,
|
||||||
|
IvfSq,
|
||||||
|
IvfPq,
|
||||||
|
HnswPq,
|
||||||
|
HnswSq,
|
||||||
|
BTree,
|
||||||
|
Bitmap,
|
||||||
|
LabelList,
|
||||||
|
FTS,
|
||||||
|
],
|
||||||
replace: Optional[bool],
|
replace: Optional[bool],
|
||||||
wait_timeout: Optional[object],
|
wait_timeout: Optional[object],
|
||||||
*,
|
*,
|
||||||
|
|||||||
@@ -376,6 +376,11 @@ class HnswSq:
|
|||||||
target_partition_size: Optional[int] = None
|
target_partition_size: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
# Backwards-compatible aliases
|
||||||
|
IvfHnswPq = HnswPq
|
||||||
|
IvfHnswSq = HnswSq
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class IvfFlat:
|
class IvfFlat:
|
||||||
"""Describes an IVF Flat Index
|
"""Describes an IVF Flat Index
|
||||||
@@ -475,6 +480,36 @@ class IvfFlat:
|
|||||||
target_partition_size: Optional[int] = None
|
target_partition_size: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class IvfSq:
|
||||||
|
"""Describes an IVF Scalar Quantization (SQ) index.
|
||||||
|
|
||||||
|
This index applies scalar quantization to compress vectors and organizes the
|
||||||
|
quantized vectors into IVF partitions. It offers a balance between search
|
||||||
|
speed and storage efficiency while keeping good recall.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
distance_type: str, default "l2"
|
||||||
|
The distance metric used to train and search the index. Supported values
|
||||||
|
are "l2", "cosine", and "dot".
|
||||||
|
num_partitions: int, default sqrt(num_rows)
|
||||||
|
Number of IVF partitions to create.
|
||||||
|
max_iterations: int, default 50
|
||||||
|
Maximum iterations for kmeans during partition training.
|
||||||
|
sample_rate: int, default 256
|
||||||
|
Controls the number of training vectors: sample_rate * num_partitions.
|
||||||
|
target_partition_size: int, optional
|
||||||
|
Target size for each partition; adjusts the balance between speed and accuracy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
distance_type: Literal["l2", "cosine", "dot"] = "l2"
|
||||||
|
num_partitions: Optional[int] = None
|
||||||
|
max_iterations: int = 50
|
||||||
|
sample_rate: int = 256
|
||||||
|
target_partition_size: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class IvfPq:
|
class IvfPq:
|
||||||
"""Describes an IVF PQ Index
|
"""Describes an IVF PQ Index
|
||||||
@@ -661,6 +696,9 @@ class IvfRq:
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"BTree",
|
"BTree",
|
||||||
"IvfPq",
|
"IvfPq",
|
||||||
|
"IvfHnswPq",
|
||||||
|
"IvfHnswSq",
|
||||||
|
"IvfSq",
|
||||||
"IvfRq",
|
"IvfRq",
|
||||||
"IvfFlat",
|
"IvfFlat",
|
||||||
"HnswPq",
|
"HnswPq",
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from lancedb._lancedb import (
|
|||||||
UpdateResult,
|
UpdateResult,
|
||||||
)
|
)
|
||||||
from lancedb.embeddings.base import EmbeddingFunctionConfig
|
from lancedb.embeddings.base import EmbeddingFunctionConfig
|
||||||
from lancedb.index import FTS, BTree, Bitmap, HnswSq, IvfFlat, IvfPq, LabelList
|
from lancedb.index import FTS, BTree, Bitmap, HnswSq, IvfFlat, IvfPq, IvfSq, LabelList
|
||||||
from lancedb.remote.db import LOOP
|
from lancedb.remote.db import LOOP
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
|
|
||||||
@@ -265,6 +265,8 @@ class RemoteTable(Table):
|
|||||||
num_sub_vectors=num_sub_vectors,
|
num_sub_vectors=num_sub_vectors,
|
||||||
num_bits=num_bits,
|
num_bits=num_bits,
|
||||||
)
|
)
|
||||||
|
elif index_type == "IVF_SQ":
|
||||||
|
config = IvfSq(distance_type=metric, num_partitions=num_partitions)
|
||||||
elif index_type == "IVF_HNSW_PQ":
|
elif index_type == "IVF_HNSW_PQ":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"IVF_HNSW_PQ is not supported on LanceDB cloud."
|
"IVF_HNSW_PQ is not supported on LanceDB cloud."
|
||||||
@@ -277,7 +279,7 @@ class RemoteTable(Table):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown vector index type: {index_type}. Valid options are"
|
f"Unknown vector index type: {index_type}. Valid options are"
|
||||||
" 'IVF_FLAT', 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
|
" 'IVF_FLAT', 'IVF_SQ', 'IVF_PQ', 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'"
|
||||||
)
|
)
|
||||||
|
|
||||||
LOOP.run(
|
LOOP.run(
|
||||||
|
|||||||
@@ -44,7 +44,18 @@ import numpy as np
|
|||||||
|
|
||||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||||
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||||
from .index import BTree, IvfFlat, IvfPq, Bitmap, IvfRq, LabelList, HnswPq, HnswSq, FTS
|
from .index import (
|
||||||
|
BTree,
|
||||||
|
IvfFlat,
|
||||||
|
IvfPq,
|
||||||
|
IvfSq,
|
||||||
|
Bitmap,
|
||||||
|
IvfRq,
|
||||||
|
LabelList,
|
||||||
|
HnswPq,
|
||||||
|
HnswSq,
|
||||||
|
FTS,
|
||||||
|
)
|
||||||
from .merge import LanceMergeInsertBuilder
|
from .merge import LanceMergeInsertBuilder
|
||||||
from .pydantic import LanceModel, model_to_dict
|
from .pydantic import LanceModel, model_to_dict
|
||||||
from .query import (
|
from .query import (
|
||||||
@@ -2054,7 +2065,7 @@ class LanceTable(Table):
|
|||||||
index_cache_size: Optional[int] = None,
|
index_cache_size: Optional[int] = None,
|
||||||
num_bits: int = 8,
|
num_bits: int = 8,
|
||||||
index_type: Literal[
|
index_type: Literal[
|
||||||
"IVF_FLAT", "IVF_PQ", "IVF_RQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
|
"IVF_FLAT", "IVF_SQ", "IVF_PQ", "IVF_RQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
|
||||||
] = "IVF_PQ",
|
] = "IVF_PQ",
|
||||||
max_iterations: int = 50,
|
max_iterations: int = 50,
|
||||||
sample_rate: int = 256,
|
sample_rate: int = 256,
|
||||||
@@ -2092,6 +2103,14 @@ class LanceTable(Table):
|
|||||||
sample_rate=sample_rate,
|
sample_rate=sample_rate,
|
||||||
target_partition_size=target_partition_size,
|
target_partition_size=target_partition_size,
|
||||||
)
|
)
|
||||||
|
elif index_type == "IVF_SQ":
|
||||||
|
config = IvfSq(
|
||||||
|
distance_type=metric,
|
||||||
|
num_partitions=num_partitions,
|
||||||
|
max_iterations=max_iterations,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
target_partition_size=target_partition_size,
|
||||||
|
)
|
||||||
elif index_type == "IVF_PQ":
|
elif index_type == "IVF_PQ":
|
||||||
config = IvfPq(
|
config = IvfPq(
|
||||||
distance_type=metric,
|
distance_type=metric,
|
||||||
@@ -3456,11 +3475,22 @@ class AsyncTable:
|
|||||||
if config is not None:
|
if config is not None:
|
||||||
if not isinstance(
|
if not isinstance(
|
||||||
config,
|
config,
|
||||||
(IvfFlat, IvfPq, IvfRq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS),
|
(
|
||||||
|
IvfFlat,
|
||||||
|
IvfSq,
|
||||||
|
IvfPq,
|
||||||
|
IvfRq,
|
||||||
|
HnswPq,
|
||||||
|
HnswSq,
|
||||||
|
BTree,
|
||||||
|
Bitmap,
|
||||||
|
LabelList,
|
||||||
|
FTS,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"config must be an instance of IvfPq, IvfRq, HnswPq, HnswSq, BTree,"
|
"config must be an instance of IvfSq, IvfPq, IvfRq, HnswPq, HnswSq,"
|
||||||
" Bitmap, LabelList, or FTS, but got " + str(type(config))
|
" BTree, Bitmap, LabelList, or FTS, but got " + str(type(config))
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
await self._inner.create_index(
|
await self._inner.create_index(
|
||||||
|
|||||||
@@ -18,12 +18,20 @@ AddMode = Literal["append", "overwrite"]
|
|||||||
CreateMode = Literal["create", "overwrite"]
|
CreateMode = Literal["create", "overwrite"]
|
||||||
|
|
||||||
# Index type literals
|
# Index type literals
|
||||||
VectorIndexType = Literal["IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ", "IVF_RQ"]
|
VectorIndexType = Literal[
|
||||||
|
"IVF_FLAT",
|
||||||
|
"IVF_SQ",
|
||||||
|
"IVF_PQ",
|
||||||
|
"IVF_HNSW_SQ",
|
||||||
|
"IVF_HNSW_PQ",
|
||||||
|
"IVF_RQ",
|
||||||
|
]
|
||||||
ScalarIndexType = Literal["BTREE", "BITMAP", "LABEL_LIST"]
|
ScalarIndexType = Literal["BTREE", "BITMAP", "LABEL_LIST"]
|
||||||
IndexType = Literal[
|
IndexType = Literal[
|
||||||
"IVF_PQ",
|
"IVF_PQ",
|
||||||
"IVF_HNSW_PQ",
|
"IVF_HNSW_PQ",
|
||||||
"IVF_HNSW_SQ",
|
"IVF_HNSW_SQ",
|
||||||
|
"IVF_SQ",
|
||||||
"FTS",
|
"FTS",
|
||||||
"BTREE",
|
"BTREE",
|
||||||
"BITMAP",
|
"BITMAP",
|
||||||
|
|||||||
@@ -12,6 +12,9 @@ from lancedb.index import (
|
|||||||
BTree,
|
BTree,
|
||||||
IvfFlat,
|
IvfFlat,
|
||||||
IvfPq,
|
IvfPq,
|
||||||
|
IvfSq,
|
||||||
|
IvfHnswPq,
|
||||||
|
IvfHnswSq,
|
||||||
IvfRq,
|
IvfRq,
|
||||||
Bitmap,
|
Bitmap,
|
||||||
LabelList,
|
LabelList,
|
||||||
@@ -229,6 +232,35 @@ async def test_create_hnswsq_index(some_table: AsyncTable):
|
|||||||
assert len(indices) == 1
|
assert len(indices) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_hnswsq_alias_index(some_table: AsyncTable):
|
||||||
|
await some_table.create_index("vector", config=IvfHnswSq(num_partitions=5))
|
||||||
|
indices = await some_table.list_indices()
|
||||||
|
assert len(indices) == 1
|
||||||
|
assert indices[0].index_type in {"HnswSq", "IvfHnswSq"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_hnswpq_alias_index(some_table: AsyncTable):
|
||||||
|
await some_table.create_index("vector", config=IvfHnswPq(num_partitions=5))
|
||||||
|
indices = await some_table.list_indices()
|
||||||
|
assert len(indices) == 1
|
||||||
|
assert indices[0].index_type in {"HnswPq", "IvfHnswPq"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_ivfsq_index(some_table: AsyncTable):
|
||||||
|
await some_table.create_index("vector", config=IvfSq(num_partitions=10))
|
||||||
|
indices = await some_table.list_indices()
|
||||||
|
assert len(indices) == 1
|
||||||
|
assert indices[0].index_type == "IvfSq"
|
||||||
|
stats = await some_table.index_stats(indices[0].name)
|
||||||
|
assert stats.index_type == "IVF_SQ"
|
||||||
|
assert stats.distance_type == "l2"
|
||||||
|
assert stats.num_indexed_rows == await some_table.count_rows()
|
||||||
|
assert stats.num_unindexed_rows == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_index_with_binary_vectors(binary_table: AsyncTable):
|
async def test_create_index_with_binary_vectors(binary_table: AsyncTable):
|
||||||
await binary_table.create_index(
|
await binary_table.create_index(
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
|
|
||||||
use lancedb::index::vector::{IvfFlatIndexBuilder, IvfRqIndexBuilder};
|
use lancedb::index::vector::{IvfFlatIndexBuilder, IvfRqIndexBuilder, IvfSqIndexBuilder};
|
||||||
use lancedb::index::{
|
use lancedb::index::{
|
||||||
scalar::{BTreeIndexBuilder, FtsIndexBuilder},
|
scalar::{BTreeIndexBuilder, FtsIndexBuilder},
|
||||||
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
|
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
|
||||||
@@ -87,6 +87,21 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
|
|||||||
}
|
}
|
||||||
Ok(LanceDbIndex::IvfPq(ivf_pq_builder))
|
Ok(LanceDbIndex::IvfPq(ivf_pq_builder))
|
||||||
},
|
},
|
||||||
|
"IvfSq" => {
|
||||||
|
let params = source.extract::<IvfSqParams>()?;
|
||||||
|
let distance_type = parse_distance_type(params.distance_type)?;
|
||||||
|
let mut ivf_sq_builder = IvfSqIndexBuilder::default()
|
||||||
|
.distance_type(distance_type)
|
||||||
|
.max_iterations(params.max_iterations)
|
||||||
|
.sample_rate(params.sample_rate);
|
||||||
|
if let Some(num_partitions) = params.num_partitions {
|
||||||
|
ivf_sq_builder = ivf_sq_builder.num_partitions(num_partitions);
|
||||||
|
}
|
||||||
|
if let Some(target_partition_size) = params.target_partition_size {
|
||||||
|
ivf_sq_builder = ivf_sq_builder.target_partition_size(target_partition_size);
|
||||||
|
}
|
||||||
|
Ok(LanceDbIndex::IvfSq(ivf_sq_builder))
|
||||||
|
},
|
||||||
"IvfRq" => {
|
"IvfRq" => {
|
||||||
let params = source.extract::<IvfRqParams>()?;
|
let params = source.extract::<IvfRqParams>()?;
|
||||||
let distance_type = parse_distance_type(params.distance_type)?;
|
let distance_type = parse_distance_type(params.distance_type)?;
|
||||||
@@ -142,7 +157,7 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
|
|||||||
Ok(LanceDbIndex::IvfHnswSq(hnsw_sq_builder))
|
Ok(LanceDbIndex::IvfHnswSq(hnsw_sq_builder))
|
||||||
},
|
},
|
||||||
not_supported => Err(PyValueError::new_err(format!(
|
not_supported => Err(PyValueError::new_err(format!(
|
||||||
"Invalid index type '{}'. Must be one of BTree, Bitmap, LabelList, FTS, IvfPq, IvfHnswPq, or IvfHnswSq",
|
"Invalid index type '{}'. Must be one of BTree, Bitmap, LabelList, FTS, IvfPq, IvfSq, IvfHnswPq, or IvfHnswSq",
|
||||||
not_supported
|
not_supported
|
||||||
))),
|
))),
|
||||||
}
|
}
|
||||||
@@ -186,6 +201,15 @@ struct IvfPqParams {
|
|||||||
target_partition_size: Option<u32>,
|
target_partition_size: Option<u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(FromPyObject)]
|
||||||
|
struct IvfSqParams {
|
||||||
|
distance_type: String,
|
||||||
|
num_partitions: Option<u32>,
|
||||||
|
max_iterations: u32,
|
||||||
|
sample_rate: u32,
|
||||||
|
target_partition_size: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(FromPyObject)]
|
#[derive(FromPyObject)]
|
||||||
struct IvfRqParams {
|
struct IvfRqParams {
|
||||||
distance_type: String,
|
distance_type: String,
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ use crate::{table::BaseTable, DistanceType, Error, Result};
|
|||||||
|
|
||||||
use self::{
|
use self::{
|
||||||
scalar::{BTreeIndexBuilder, BitmapIndexBuilder, LabelListIndexBuilder},
|
scalar::{BTreeIndexBuilder, BitmapIndexBuilder, LabelListIndexBuilder},
|
||||||
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
|
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder, IvfSqIndexBuilder},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub mod scalar;
|
pub mod scalar;
|
||||||
@@ -54,6 +54,9 @@ pub enum Index {
|
|||||||
/// IVF index with Product Quantization
|
/// IVF index with Product Quantization
|
||||||
IvfPq(IvfPqIndexBuilder),
|
IvfPq(IvfPqIndexBuilder),
|
||||||
|
|
||||||
|
/// IVF index with Scalar Quantization
|
||||||
|
IvfSq(IvfSqIndexBuilder),
|
||||||
|
|
||||||
/// IVF index with RabitQ Quantization
|
/// IVF index with RabitQ Quantization
|
||||||
IvfRq(IvfRqIndexBuilder),
|
IvfRq(IvfRqIndexBuilder),
|
||||||
|
|
||||||
@@ -277,6 +280,8 @@ pub enum IndexType {
|
|||||||
// Vector
|
// Vector
|
||||||
#[serde(alias = "IVF_FLAT")]
|
#[serde(alias = "IVF_FLAT")]
|
||||||
IvfFlat,
|
IvfFlat,
|
||||||
|
#[serde(alias = "IVF_SQ")]
|
||||||
|
IvfSq,
|
||||||
#[serde(alias = "IVF_PQ")]
|
#[serde(alias = "IVF_PQ")]
|
||||||
IvfPq,
|
IvfPq,
|
||||||
#[serde(alias = "IVF_RQ")]
|
#[serde(alias = "IVF_RQ")]
|
||||||
@@ -301,6 +306,7 @@ impl std::fmt::Display for IndexType {
|
|||||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
Self::IvfFlat => write!(f, "IVF_FLAT"),
|
Self::IvfFlat => write!(f, "IVF_FLAT"),
|
||||||
|
Self::IvfSq => write!(f, "IVF_SQ"),
|
||||||
Self::IvfPq => write!(f, "IVF_PQ"),
|
Self::IvfPq => write!(f, "IVF_PQ"),
|
||||||
Self::IvfRq => write!(f, "IVF_RQ"),
|
Self::IvfRq => write!(f, "IVF_RQ"),
|
||||||
Self::IvfHnswPq => write!(f, "IVF_HNSW_PQ"),
|
Self::IvfHnswPq => write!(f, "IVF_HNSW_PQ"),
|
||||||
@@ -323,6 +329,7 @@ impl std::str::FromStr for IndexType {
|
|||||||
"LABEL_LIST" | "LABELLIST" => Ok(Self::LabelList),
|
"LABEL_LIST" | "LABELLIST" => Ok(Self::LabelList),
|
||||||
"FTS" | "INVERTED" => Ok(Self::FTS),
|
"FTS" | "INVERTED" => Ok(Self::FTS),
|
||||||
"IVF_FLAT" => Ok(Self::IvfFlat),
|
"IVF_FLAT" => Ok(Self::IvfFlat),
|
||||||
|
"IVF_SQ" => Ok(Self::IvfSq),
|
||||||
"IVF_PQ" => Ok(Self::IvfPq),
|
"IVF_PQ" => Ok(Self::IvfPq),
|
||||||
"IVF_RQ" => Ok(Self::IvfRq),
|
"IVF_RQ" => Ok(Self::IvfRq),
|
||||||
"IVF_HNSW_PQ" => Ok(Self::IvfHnswPq),
|
"IVF_HNSW_PQ" => Ok(Self::IvfHnswPq),
|
||||||
|
|||||||
@@ -209,6 +209,38 @@ impl IvfFlatIndexBuilder {
|
|||||||
impl_ivf_params_setter!();
|
impl_ivf_params_setter!();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Builder for an IVF SQ index.
|
||||||
|
///
|
||||||
|
/// This index compresses vectors using scalar quantization and groups them into IVF partitions.
|
||||||
|
/// It offers a balance between search performance and storage footprint.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct IvfSqIndexBuilder {
|
||||||
|
pub(crate) distance_type: DistanceType,
|
||||||
|
|
||||||
|
// IVF
|
||||||
|
pub(crate) num_partitions: Option<u32>,
|
||||||
|
pub(crate) sample_rate: u32,
|
||||||
|
pub(crate) max_iterations: u32,
|
||||||
|
pub(crate) target_partition_size: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for IvfSqIndexBuilder {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
distance_type: DistanceType::L2,
|
||||||
|
num_partitions: None,
|
||||||
|
sample_rate: 256,
|
||||||
|
max_iterations: 50,
|
||||||
|
target_partition_size: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IvfSqIndexBuilder {
|
||||||
|
impl_distance_type_setter!();
|
||||||
|
impl_ivf_params_setter!();
|
||||||
|
}
|
||||||
|
|
||||||
/// Builder for an IVF PQ index.
|
/// Builder for an IVF PQ index.
|
||||||
///
|
///
|
||||||
/// This index stores a compressed (quantized) copy of every vector. These vectors
|
/// This index stores a compressed (quantized) copy of every vector. These vectors
|
||||||
|
|||||||
@@ -1072,6 +1072,14 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
|||||||
body["num_bits"] = serde_json::Value::Number(num_bits.into());
|
body["num_bits"] = serde_json::Value::Number(num_bits.into());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Index::IvfSq(index) => {
|
||||||
|
body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_SQ".to_string());
|
||||||
|
body[METRIC_TYPE_KEY] =
|
||||||
|
serde_json::Value::String(index.distance_type.to_string().to_lowercase());
|
||||||
|
if let Some(num_partitions) = index.num_partitions {
|
||||||
|
body["num_partitions"] = serde_json::Value::Number(num_partitions.into());
|
||||||
|
}
|
||||||
|
}
|
||||||
Index::IvfHnswSq(index) => {
|
Index::IvfHnswSq(index) => {
|
||||||
body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_HNSW_SQ".to_string());
|
body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_HNSW_SQ".to_string());
|
||||||
body[METRIC_TYPE_KEY] =
|
body[METRIC_TYPE_KEY] =
|
||||||
|
|||||||
@@ -1946,6 +1946,25 @@ impl NativeTable {
|
|||||||
VectorIndexParams::with_ivf_flat_params(index.distance_type.into(), ivf_params);
|
VectorIndexParams::with_ivf_flat_params(index.distance_type.into(), ivf_params);
|
||||||
Ok(Box::new(lance_idx_params))
|
Ok(Box::new(lance_idx_params))
|
||||||
}
|
}
|
||||||
|
Index::IvfSq(index) => {
|
||||||
|
Self::validate_index_type(field, "IVF SQ", supported_vector_data_type)?;
|
||||||
|
let ivf_params = Self::build_ivf_params(
|
||||||
|
index.num_partitions,
|
||||||
|
index.target_partition_size,
|
||||||
|
index.sample_rate,
|
||||||
|
index.max_iterations,
|
||||||
|
);
|
||||||
|
let sq_params = SQBuildParams {
|
||||||
|
sample_rate: index.sample_rate as usize,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let lance_idx_params = VectorIndexParams::with_ivf_sq_params(
|
||||||
|
index.distance_type.into(),
|
||||||
|
ivf_params,
|
||||||
|
sq_params,
|
||||||
|
);
|
||||||
|
Ok(Box::new(lance_idx_params))
|
||||||
|
}
|
||||||
Index::IvfPq(index) => {
|
Index::IvfPq(index) => {
|
||||||
Self::validate_index_type(field, "IVF PQ", supported_vector_data_type)?;
|
Self::validate_index_type(field, "IVF PQ", supported_vector_data_type)?;
|
||||||
let dim = Self::get_vector_dimension(field)?;
|
let dim = Self::get_vector_dimension(field)?;
|
||||||
@@ -2053,6 +2072,7 @@ impl NativeTable {
|
|||||||
Index::LabelList(_) => IndexType::LabelList,
|
Index::LabelList(_) => IndexType::LabelList,
|
||||||
Index::FTS(_) => IndexType::Inverted,
|
Index::FTS(_) => IndexType::Inverted,
|
||||||
Index::IvfFlat(_)
|
Index::IvfFlat(_)
|
||||||
|
| Index::IvfSq(_)
|
||||||
| Index::IvfPq(_)
|
| Index::IvfPq(_)
|
||||||
| Index::IvfRq(_)
|
| Index::IvfRq(_)
|
||||||
| Index::IvfHnswPq(_)
|
| Index::IvfHnswPq(_)
|
||||||
|
|||||||
Reference in New Issue
Block a user