From 650f173236cdf45b5a33266a5e440e935aa43687 Mon Sep 17 00:00:00 2001 From: Shengan Zhang Date: Mon, 11 May 2026 15:08:32 -0700 Subject: [PATCH] feat(python): add IVF_HNSW_FLAT vector index support (#3366) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Wire up `IVF_HNSW_FLAT` in the Rust core and Python SDK. The index was documented at https://docs.lancedb.com/indexing/vector-index but `lancedb.Table.create_index(index_type="IVF_HNSW_FLAT")` raised `ValueError: Unknown index type IVF_HNSW_FLAT` — the underlying `pylance` already accepted it, only the LanceDB wrapper was missing the wiring. **Rust core (`rust/lancedb`):** - Add `Index::IvfHnswFlat` / `IndexType::IvfHnswFlat` variants and the `IvfHnswFlatIndexBuilder` (modelled on `IvfHnswSqIndexBuilder`). - Build Lance params via the existing `VectorIndexParams::ivf_hnsw(...)` helper, keeping symmetry with the other `IVF_HNSW_*` variants. - Forward the variant in `RemoteTable::create_index` and add two parametrised tests (default + customised config) for the JSON serialisation. - New `NativeTable` integration test (`test_create_index_ivf_hnsw_flat`). **Python binding (`python/`):** - New `HnswFlat` dataclass + backwards-compat `IvfHnswFlat` alias. - PyO3 `extract_index_params` recognises the `HnswFlat` config. - `LanceTable.create_index(index_type="IVF_HNSW_FLAT", …)` and the sync `RemoteTable.create_index` both dispatch to the new config. - `IndexStatistics.index_type` `Literal` and `_lancedb.pyi` stubs cover the new type so `pyright`/`make check` stays clean. - Async integration tests (`HnswFlat` + `IvfHnswFlat` alias) and a sync dispatcher test, mirroring the existing `IVF_HNSW_SQ` coverage. - Existing `test_index_statistics_index_type_lists_all_supported_values` updated to include `IVF_HNSW_FLAT`. A matching Node.js / TypeScript binding is in a follow-up PR. Closes #3331 ## Test plan - [ ] \`cargo check --quiet --features remote --tests --examples\` - [ ] \`cargo test --quiet --features remote -p lancedb\` (covers the new \`test_create_index_ivf_hnsw_flat\` and the two new parametrised \`RemoteTable::create_index\` cases) - [ ] \`cargo fmt --all\` / \`cargo clippy --quiet --features remote --tests --examples\` - [ ] \`cd python && make develop && make check && make test\` (covers the two new async tests, the alias test, the dispatcher test, and the updated \`test_index_statistics_index_type_lists_all_supported_values\` assertion) --- python/python/lancedb/_lancedb.pyi | 3 + python/python/lancedb/index.py | 91 +++++++++++++++++++++++++++ python/python/lancedb/remote/table.py | 5 +- python/python/lancedb/table.py | 34 +++++++++- python/python/lancedb/types.py | 2 + python/python/tests/test_index.py | 18 ++++++ python/python/tests/test_table.py | 17 ++++- python/src/index.rs | 37 ++++++++++- rust/lancedb/src/index.rs | 13 +++- rust/lancedb/src/index/vector.rs | 43 +++++++++++++ rust/lancedb/src/remote/table.rs | 33 +++++++++- rust/lancedb/src/table.rs | 71 ++++++++++++++++++++- 12 files changed, 357 insertions(+), 10 deletions(-) diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index b33f89e40..8811723e2 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -12,6 +12,7 @@ from .index import ( LabelList, HnswPq, HnswSq, + HnswFlat, FTS, ) from lance_namespace import ( @@ -25,6 +26,7 @@ from .remote import ClientConfig IvfHnswPq: type[HnswPq] = HnswPq IvfHnswSq: type[HnswSq] = HnswSq +IvfHnswFlat: type[HnswFlat] = HnswFlat class PyExpr: """A type-safe DataFusion expression node (Rust-side handle).""" @@ -180,6 +182,7 @@ class Table: IvfPq, HnswPq, HnswSq, + HnswFlat, BTree, Bitmap, LabelList, diff --git a/python/python/lancedb/index.py b/python/python/lancedb/index.py index 4fbffc50d..f3f4d6a6e 100644 --- a/python/python/lancedb/index.py +++ b/python/python/lancedb/index.py @@ -388,9 +388,98 @@ class HnswSq: target_partition_size: Optional[int] = None +@dataclass +class HnswFlat: + """Describe a HNSW-FLAT index configuration. + + HNSW-FLAT stands for Hierarchical Navigable Small World without quantization. + It stores raw vectors in the HNSW graph, providing the highest recall among + the IVF_HNSW family at the cost of more memory and disk space compared to + :class:`HnswSq` or :class:`HnswPq`. + + Parameters + ---------- + + distance_type: str, default "l2" + + The distance metric used to train the index. + + The following distance types are available: + + "l2" - Euclidean distance. This is a very common distance metric that + accounts for both magnitude and direction when determining the distance + between vectors. l2 distance has a range of [0, ∞). + + "cosine" - Cosine distance. Cosine distance is a distance metric + calculated from the cosine similarity between two vectors. Cosine + similarity is a measure of similarity between two non-zero vectors of an + inner product space. It is defined to equal the cosine of the angle + between them. Unlike l2, the cosine distance is not affected by the + magnitude of the vectors. Cosine distance has a range of [0, 2]. + + "dot" - Dot product. Dot distance is the dot product of two vectors. Dot + distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their + l2 norm is 1), then dot distance is equivalent to the cosine distance. + + num_partitions, default sqrt(num_rows) + + The number of IVF partitions to create. + + For HNSW, we recommend a small number of partitions. Setting this to 1 + works well for most tables. For very large tables, training just one HNSW + graph will require too much memory. Each partition becomes its own HNSW + graph, so setting this value higher reduces the peak memory use of + training. + + max_iterations, default 50 + + Max iterations to train kmeans. + + When training an IVF index we use kmeans to calculate the partitions. + This parameter controls how many iterations of kmeans to run. + + sample_rate, default 256 + + The rate used to calculate the number of training vectors for kmeans. + + m, default 20 + + The number of neighbors to select for each vector in the HNSW graph. + + This value controls the tradeoff between search speed and accuracy. + The higher the value the more accurate the search but the slower it + will be. + + ef_construction, default 300 + + The number of candidates to evaluate during the construction of the HNSW + graph. + + This value controls the tradeoff between build speed and accuracy. + The higher the value the more accurate the build but the slower it will + be. 150 to 300 is the typical range. 100 is a minimum for good quality + search results. In most cases, there is no benefit to setting this higher + than 500. This value should be set to a value that is not less than `ef` + in the search phase. + + target_partition_size, default is 1,048,576 + + The target size of each partition. + """ + + distance_type: Literal["l2", "cosine", "dot"] = "l2" + num_partitions: Optional[int] = None + max_iterations: int = 50 + sample_rate: int = 256 + m: int = 20 + ef_construction: int = 300 + target_partition_size: Optional[int] = None + + # Backwards-compatible aliases IvfHnswPq = HnswPq IvfHnswSq = HnswSq +IvfHnswFlat = HnswFlat @dataclass @@ -710,11 +799,13 @@ __all__ = [ "IvfPq", "IvfHnswPq", "IvfHnswSq", + "IvfHnswFlat", "IvfSq", "IvfRq", "IvfFlat", "HnswPq", "HnswSq", + "HnswFlat", "IndexConfig", "FTS", "Bitmap", diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index a33166937..f4237110d 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -22,6 +22,7 @@ from lancedb.index import ( FTS, BTree, Bitmap, + HnswFlat, HnswSq, IvfFlat, IvfPq, @@ -285,13 +286,15 @@ class RemoteTable(Table): ) elif index_type == "IVF_HNSW_SQ": config = HnswSq(distance_type=metric, num_partitions=num_partitions) + elif index_type == "IVF_HNSW_FLAT": + config = HnswFlat(distance_type=metric, num_partitions=num_partitions) elif index_type == "IVF_FLAT": config = IvfFlat(distance_type=metric, num_partitions=num_partitions) else: raise ValueError( f"Unknown vector index type: {index_type}. Valid options are" " 'IVF_FLAT', 'IVF_PQ', 'IVF_RQ', 'IVF_SQ'," - " 'IVF_HNSW_PQ', 'IVF_HNSW_SQ'" + " 'IVF_HNSW_PQ', 'IVF_HNSW_SQ', 'IVF_HNSW_FLAT'" ) LOOP.run( diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 82768197c..87b14e434 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -57,6 +57,7 @@ from .index import ( LabelList, HnswPq, HnswSq, + HnswFlat, FTS, ) from .merge import LanceMergeInsertBuilder @@ -2236,7 +2237,13 @@ class LanceTable(Table): index_cache_size: Optional[int] = None, num_bits: int = 8, index_type: Literal[ - "IVF_FLAT", "IVF_SQ", "IVF_PQ", "IVF_RQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ" + "IVF_FLAT", + "IVF_SQ", + "IVF_PQ", + "IVF_RQ", + "IVF_HNSW_SQ", + "IVF_HNSW_PQ", + "IVF_HNSW_FLAT", ] = "IVF_PQ", max_iterations: int = 50, sample_rate: int = 256, @@ -2323,6 +2330,16 @@ class LanceTable(Table): ef_construction=ef_construction, target_partition_size=target_partition_size, ) + elif index_type == "IVF_HNSW_FLAT": + config = HnswFlat( + distance_type=metric, + num_partitions=num_partitions, + max_iterations=max_iterations, + sample_rate=sample_rate, + m=m, + ef_construction=ef_construction, + target_partition_size=target_partition_size, + ) else: raise ValueError(f"Unknown index type {index_type}") @@ -3873,7 +3890,18 @@ class AsyncTable: *, replace: Optional[bool] = None, config: Optional[ - Union[IvfFlat, IvfPq, IvfRq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS] + Union[ + IvfFlat, + IvfPq, + IvfRq, + HnswPq, + HnswSq, + HnswFlat, + BTree, + Bitmap, + LabelList, + FTS, + ] ] = None, wait_timeout: Optional[timedelta] = None, name: Optional[str] = None, @@ -3920,6 +3948,7 @@ class AsyncTable: IvfRq, HnswPq, HnswSq, + HnswFlat, BTree, Bitmap, LabelList, @@ -5090,6 +5119,7 @@ class IndexStatistics: "IVF_RQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ", + "IVF_HNSW_FLAT", "FTS", "BTREE", "BITMAP", diff --git a/python/python/lancedb/types.py b/python/python/lancedb/types.py index 2e26e5630..8bab57e0e 100644 --- a/python/python/lancedb/types.py +++ b/python/python/lancedb/types.py @@ -24,6 +24,7 @@ VectorIndexType = Literal[ "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ", + "IVF_HNSW_FLAT", "IVF_RQ", ] ScalarIndexType = Literal["BTREE", "BITMAP", "LABEL_LIST"] @@ -31,6 +32,7 @@ IndexType = Literal[ "IVF_PQ", "IVF_HNSW_PQ", "IVF_HNSW_SQ", + "IVF_HNSW_FLAT", "IVF_SQ", "FTS", "BTREE", diff --git a/python/python/tests/test_index.py b/python/python/tests/test_index.py index 8dfa55a77..80f9f2673 100644 --- a/python/python/tests/test_index.py +++ b/python/python/tests/test_index.py @@ -16,11 +16,13 @@ from lancedb.index import ( IvfSq, IvfHnswPq, IvfHnswSq, + IvfHnswFlat, IvfRq, Bitmap, LabelList, HnswPq, HnswSq, + HnswFlat, FTS, ) from lancedb.table import IndexStatistics @@ -250,6 +252,21 @@ async def test_create_hnswpq_alias_index(some_table: AsyncTable): assert indices[0].index_type in {"HnswPq", "IvfHnswPq"} +@pytest.mark.asyncio +async def test_create_hnswflat_index(some_table: AsyncTable): + await some_table.create_index("vector", config=HnswFlat(num_partitions=10)) + indices = await some_table.list_indices() + assert len(indices) == 1 + + +@pytest.mark.asyncio +async def test_create_hnswflat_alias_index(some_table: AsyncTable): + await some_table.create_index("vector", config=IvfHnswFlat(num_partitions=5)) + indices = await some_table.list_indices() + assert len(indices) == 1 + assert indices[0].index_type in {"HnswFlat", "IvfHnswFlat"} + + @pytest.mark.asyncio async def test_create_ivfsq_index(some_table: AsyncTable): await some_table.create_index("vector", config=IvfSq(num_partitions=10)) @@ -295,6 +312,7 @@ def test_index_statistics_index_type_lists_all_supported_values(): "IVF_RQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ", + "IVF_HNSW_FLAT", "FTS", "BTREE", "BITMAP", diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index 4e20d2cfc..3e27d0e69 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -11,7 +11,7 @@ from unittest.mock import patch import lancedb from lancedb.dependencies import _PANDAS_AVAILABLE -from lancedb.index import HnswPq, HnswSq, IvfPq +from lancedb.index import HnswFlat, HnswPq, HnswSq, IvfPq import numpy as np import polars as pl import pyarrow as pa @@ -917,6 +917,21 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection): "my_vector", replace=True, config=expected_config, name=None, train=True ) + table.create_index( + vector_column_name="my_vector", + metric="cosine", + index_type="IVF_HNSW_FLAT", + sample_rate=0.1, + m=29, + ef_construction=10, + ) + expected_config = HnswFlat( + distance_type="cosine", sample_rate=0.1, m=29, ef_construction=10 + ) + mock_create_index.assert_called_with( + "my_vector", replace=True, config=expected_config, name=None, train=True + ) + @patch("lancedb.table.AsyncTable.create_index") def test_create_index_name_and_train_parameters( diff --git a/python/src/index.rs b/python/src/index.rs index ce90280b0..508b10f17 100644 --- a/python/src/index.rs +++ b/python/src/index.rs @@ -1,11 +1,13 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors -use lancedb::index::vector::{IvfFlatIndexBuilder, IvfRqIndexBuilder, IvfSqIndexBuilder}; +use lancedb::index::vector::{ + IvfFlatIndexBuilder, IvfHnswFlatIndexBuilder, IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, + IvfPqIndexBuilder, IvfRqIndexBuilder, IvfSqIndexBuilder, +}; use lancedb::index::{ Index as LanceDbIndex, scalar::{BTreeIndexBuilder, FtsIndexBuilder}, - vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder}, }; use pyo3::IntoPyObject; use pyo3::types::PyStringMethods; @@ -162,8 +164,26 @@ pub fn extract_index_params(source: &Option>) -> PyResult { + let params = source.extract::()?; + let distance_type = parse_distance_type(params.distance_type)?; + let mut hnsw_flat_builder = IvfHnswFlatIndexBuilder::default() + .distance_type(distance_type) + .max_iterations(params.max_iterations) + .sample_rate(params.sample_rate) + .num_edges(params.m) + .ef_construction(params.ef_construction); + if let Some(num_partitions) = params.num_partitions { + hnsw_flat_builder = hnsw_flat_builder.num_partitions(num_partitions); + } + if let Some(target_partition_size) = params.target_partition_size { + hnsw_flat_builder = + hnsw_flat_builder.target_partition_size(target_partition_size); + } + Ok(LanceDbIndex::IvfHnswFlat(hnsw_flat_builder)) + } not_supported => Err(PyValueError::new_err(format!( - "Invalid index type '{}'. Must be one of BTree, Bitmap, LabelList, FTS, IvfPq, IvfSq, IvfHnswPq, or IvfHnswSq", + "Invalid index type '{}'. Must be one of BTree, Bitmap, LabelList, FTS, IvfPq, IvfSq, IvfHnswPq, IvfHnswSq, or IvfHnswFlat", not_supported ))), } @@ -250,6 +270,17 @@ struct IvfHnswSqParams { target_partition_size: Option, } +#[derive(FromPyObject)] +struct IvfHnswFlatParams { + distance_type: String, + num_partitions: Option, + max_iterations: u32, + sample_rate: u32, + m: u32, + ef_construction: u32, + target_partition_size: Option, +} + #[pyclass(get_all)] /// A description of an index currently configured on a column pub struct IndexConfig { diff --git a/rust/lancedb/src/index.rs b/rust/lancedb/src/index.rs index a012fd4cd..3a55eeedf 100644 --- a/rust/lancedb/src/index.rs +++ b/rust/lancedb/src/index.rs @@ -13,7 +13,10 @@ use crate::{DistanceType, Error, Result, table::BaseTable}; use self::{ scalar::{BTreeIndexBuilder, BitmapIndexBuilder, LabelListIndexBuilder}, - vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder, IvfSqIndexBuilder}, + vector::{ + IvfHnswFlatIndexBuilder, IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder, + IvfSqIndexBuilder, + }, }; pub mod scalar; @@ -67,6 +70,10 @@ pub enum Index { /// IVF-HNSW index with Scalar Quantization /// It is a variant of the HNSW algorithm that uses scalar quantization to compress the vectors. IvfHnswSq(IvfHnswSqIndexBuilder), + + /// IVF-HNSW index without quantization. + /// Stores raw vectors, providing the highest recall at the cost of more memory and disk space. + IvfHnswFlat(IvfHnswFlatIndexBuilder), } /// Builder for the create_index operation @@ -290,6 +297,8 @@ pub enum IndexType { IvfHnswPq, #[serde(alias = "IVF_HNSW_SQ")] IvfHnswSq, + #[serde(alias = "IVF_HNSW_FLAT")] + IvfHnswFlat, // Scalar #[serde(alias = "BTREE")] BTree, @@ -311,6 +320,7 @@ impl std::fmt::Display for IndexType { Self::IvfRq => write!(f, "IVF_RQ"), Self::IvfHnswPq => write!(f, "IVF_HNSW_PQ"), Self::IvfHnswSq => write!(f, "IVF_HNSW_SQ"), + Self::IvfHnswFlat => write!(f, "IVF_HNSW_FLAT"), Self::BTree => write!(f, "BTREE"), Self::Bitmap => write!(f, "BITMAP"), Self::LabelList => write!(f, "LABEL_LIST"), @@ -334,6 +344,7 @@ impl std::str::FromStr for IndexType { "IVF_RQ" => Ok(Self::IvfRq), "IVF_HNSW_PQ" => Ok(Self::IvfHnswPq), "IVF_HNSW_SQ" => Ok(Self::IvfHnswSq), + "IVF_HNSW_FLAT" => Ok(Self::IvfHnswFlat), _ => Err(Error::InvalidInput { message: format!("the input value {} is not a valid IndexType", value), }), diff --git a/rust/lancedb/src/index/vector.rs b/rust/lancedb/src/index/vector.rs index a5507f41c..7e62c9f6c 100644 --- a/rust/lancedb/src/index/vector.rs +++ b/rust/lancedb/src/index/vector.rs @@ -474,3 +474,46 @@ impl IvfHnswSqIndexBuilder { impl_ivf_params_setter!(); impl_hnsw_params_setter!(); } + +/// Builder for an IVF_HNSW_FLAT index. +/// +/// This index combines IVF partitioning with an HNSW graph per partition, +/// storing raw (unquantized) vectors. It offers the highest recall among +/// the IVF_HNSW family at the cost of more memory and disk space compared +/// to [`IvfHnswSqIndexBuilder`] or [`IvfHnswPqIndexBuilder`]. +#[derive(Debug, Clone, Serialize)] +pub struct IvfHnswFlatIndexBuilder { + // IVF + #[serde(rename = "metric_type")] + pub(crate) distance_type: DistanceType, + #[serde(skip_serializing_if = "Option::is_none")] + pub(crate) num_partitions: Option, + pub(crate) sample_rate: u32, + pub(crate) max_iterations: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub(crate) target_partition_size: Option, + + // HNSW + pub(crate) m: u32, + pub(crate) ef_construction: u32, +} + +impl Default for IvfHnswFlatIndexBuilder { + fn default() -> Self { + Self { + distance_type: DistanceType::L2, + num_partitions: None, + sample_rate: 256, + max_iterations: 50, + m: 20, + ef_construction: 300, + target_partition_size: None, + } + } +} + +impl IvfHnswFlatIndexBuilder { + impl_distance_type_setter!(); + impl_ivf_params_setter!(); + impl_hnsw_params_setter!(); +} diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index c9b807505..b991ed335 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -1540,6 +1540,7 @@ impl BaseTable for RemoteTable { Index::IvfPq(p) => ("IVF_PQ", Some(to_json(p)?)), Index::IvfSq(p) => ("IVF_SQ", Some(to_json(p)?)), Index::IvfHnswSq(p) => ("IVF_HNSW_SQ", Some(to_json(p)?)), + Index::IvfHnswFlat(p) => ("IVF_HNSW_FLAT", Some(to_json(p)?)), Index::IvfRq(p) => ("IVF_RQ", Some(to_json(p)?)), Index::BTree(p) => ("BTREE", Some(to_json(p)?)), Index::Bitmap(p) => ("BITMAP", Some(to_json(p)?)), @@ -2068,7 +2069,8 @@ mod tests { use serde_json::json; use crate::index::vector::{ - IvfFlatIndexBuilder, IvfHnswSqIndexBuilder, IvfRqIndexBuilder, IvfSqIndexBuilder, + IvfFlatIndexBuilder, IvfHnswFlatIndexBuilder, IvfHnswSqIndexBuilder, IvfRqIndexBuilder, + IvfSqIndexBuilder, }; use crate::remote::JSON_CONTENT_TYPE; use crate::remote::db::DEFAULT_SERVER_VERSION; @@ -3321,6 +3323,35 @@ mod tests { .ef_construction(500), ), ), + ( + "IVF_HNSW_FLAT", + json!({ + "metric_type": "l2", + "sample_rate": 256, + "max_iterations": 50, + "m": 20, + "ef_construction": 300, + }), + Index::IvfHnswFlat(Default::default()), + ), + ( + "IVF_HNSW_FLAT", + json!({ + "metric_type": "cosine", + "num_partitions": 64, + "sample_rate": 256, + "max_iterations": 50, + "m": 40, + "ef_construction": 500, + }), + Index::IvfHnswFlat( + IvfHnswFlatIndexBuilder::default() + .distance_type(DistanceType::Cosine) + .num_partitions(64) + .num_edges(40) + .ef_construction(500), + ), + ), ( "IVF_SQ", json!({ diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 73415e89b..29bcaea26 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -2033,6 +2033,24 @@ impl NativeTable { ); Ok(Box::new(lance_idx_params)) } + Index::IvfHnswFlat(index) => { + Self::validate_index_type(field, "IVF HNSW FLAT", supported_vector_data_type)?; + let ivf_params = Self::build_ivf_params( + index.num_partitions, + index.target_partition_size, + index.sample_rate, + index.max_iterations, + ); + let hnsw_params = HnswBuildParams::default() + .num_edges(index.m as usize) + .ef_construction(index.ef_construction as usize); + let lance_idx_params = VectorIndexParams::ivf_hnsw( + index.distance_type.into(), + ivf_params, + hnsw_params, + ); + Ok(Box::new(lance_idx_params)) + } } } @@ -2058,7 +2076,8 @@ impl NativeTable { | Index::IvfPq(_) | Index::IvfRq(_) | Index::IvfHnswPq(_) - | Index::IvfHnswSq(_) => IndexType::Vector, + | Index::IvfHnswSq(_) + | Index::IvfHnswFlat(_) => IndexType::Vector, } } @@ -3176,6 +3195,56 @@ mod tests { assert_eq!(stats.num_unindexed_rows, 0); } + #[tokio::test] + async fn test_create_index_ivf_hnsw_flat() { + use arrow_array::RecordBatch; + use arrow_schema::{DataType, Field, Schema as ArrowSchema}; + use rand; + use std::iter::repeat_with; + + use crate::index::vector::IvfHnswFlatIndexBuilder; + use arrow_array::Float32Array; + + let tmp_dir = tempdir().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let conn = connect(uri).execute().await.unwrap(); + + let dimension = 16; + let schema = Arc::new(ArrowSchema::new(vec![Field::new( + "embeddings", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + dimension, + ), + false, + )])); + + let float_arr = Float32Array::from( + repeat_with(rand::random::) + .take(512 * dimension as usize) + .collect::>(), + ); + + let vectors = Arc::new(create_fixed_size_list(float_arr, dimension).unwrap()); + let batch = RecordBatch::try_new(schema.clone(), vec![vectors.clone()]).unwrap(); + + let table = conn.create_table("test", batch).execute().await.unwrap(); + + let index = IvfHnswFlatIndexBuilder::default(); + table + .create_index(&["embeddings"], Index::IvfHnswFlat(index)) + .execute() + .await + .unwrap(); + + let index_configs = table.list_indices().await.unwrap(); + assert_eq!(index_configs.len(), 1); + let index = index_configs.into_iter().next().unwrap(); + assert_eq!(index.index_type, crate::index::IndexType::IvfHnswFlat); + assert_eq!(index.columns, vec!["embeddings".to_string()]); + assert_eq!(table.count_rows(None).await.unwrap(), 512); + } + fn create_fixed_size_list(values: T, list_size: i32) -> Result { let list_type = DataType::FixedSizeList( Arc::new(Field::new("item", values.data_type().clone(), true)),