Add HNSW-PQ and HNSW-SQ index configuration to the Node.js and Python bindings.

What the patch does:
  * nodejs/src/index.rs and python/src/index.rs gain `hnsw_pq` / `hnsw_sq` factories.
    Each starts from `IvfHnswPqIndexBuilder::default()` / `IvfHnswSqIndexBuilder::default()`,
    applies only the arguments that are `Some`, and wraps the builder in
    `LanceDbIndex::IvfHnswPq` / `LanceDbIndex::IvfHnswSq`. The `m` argument is forwarded to
    the builder's `num_edges`; `distance_type` goes through `parse_distance_type`, whose
    error is propagated with `?`.
  * nodejs/lancedb/indices.ts adds `HnswPqOptions` / `HnswSqOptions` and the
    `Index.hnswPq` / `Index.hnswSq` helpers that pass the options positionally.
  * python/python/lancedb/index.py adds keyword-only `HnswPq` / `HnswSq` config classes.
  * Tests on both sides create one index with `numPartitions` / `num_partitions` = 10 and
    check that exactly one index exists afterwards.

NOTE(review): generic type arguments (`Option<String>`, `Option<u32>`, `napi::Result<Self>`,
`PyResult<Self>`, `Mutex<Option<LanceDbIndex>>`, `Partial<...Options>`) were reconstructed
from usage -- confirm them against the upstream commit. Doc-comment wording in the added
lines was normalised to "HNSW-PQ" / "HNSW-SQ"; line counts per hunk are unchanged.

diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts
index 3c979031..31743be9 100644
--- a/nodejs/__test__/table.test.ts
+++ b/nodejs/__test__/table.test.ts
@@ -444,6 +444,26 @@ describe("When creating an index", () => {
     expect(fs.readdirSync(indexDir)).toHaveLength(1);
   });
 
+  test("create a hnswPq index", async () => {
+    await tbl.createIndex("vec", {
+      config: Index.hnswPq({
+        numPartitions: 10,
+      }),
+    });
+    const indexDir = path.join(tmpDir.name, "test.lance", "_indices");
+    expect(fs.readdirSync(indexDir)).toHaveLength(1);
+  });
+
+  test("create a HnswSq index", async () => {
+    await tbl.createIndex("vec", {
+      config: Index.hnswSq({
+        numPartitions: 10,
+      }),
+    });
+    const indexDir = path.join(tmpDir.name, "test.lance", "_indices");
+    expect(fs.readdirSync(indexDir)).toHaveLength(1);
+  });
+
   test("create a label list index", async () => {
     await tbl.createIndex("tags", {
       config: Index.labelList(),
diff --git a/nodejs/lancedb/indices.ts b/nodejs/lancedb/indices.ts
index 8eb20b71..5b3b9225 100644
--- a/nodejs/lancedb/indices.ts
+++ b/nodejs/lancedb/indices.ts
@@ -113,6 +113,25 @@ export interface IvfPqOptions {
   sampleRate?: number;
 }
 
+export interface HnswPqOptions {
+  distanceType?: "l2" | "cosine" | "dot";
+  numPartitions?: number;
+  numSubVectors?: number;
+  maxIterations?: number;
+  sampleRate?: number;
+  m?: number;
+  efConstruction?: number;
+}
+
+export interface HnswSqOptions {
+  distanceType?: "l2" | "cosine" | "dot";
+  numPartitions?: number;
+  maxIterations?: number;
+  sampleRate?: number;
+  m?: number;
+  efConstruction?: number;
+}
+
 /**
  * Options to create a full text search index
  */
@@ -227,6 +246,43 @@ export class Index {
   static fts(options?: Partial<FtsOptions>) {
     return new Index(LanceDbIndex.fts(options?.withPositions));
   }
+
+  /**
+   *
+   * Create an HNSW-PQ vector index
+   *
+   */
+  static hnswPq(options?: Partial<HnswPqOptions>) {
+    return new Index(
+      LanceDbIndex.hnswPq(
+        options?.distanceType,
+        options?.numPartitions,
+        options?.numSubVectors,
+        options?.maxIterations,
+        options?.sampleRate,
+        options?.m,
+        options?.efConstruction,
+      ),
+    );
+  }
+
+  /**
+   *
+   * Create an HNSW-SQ vector index
+   *
+   */
+  static hnswSq(options?: Partial<HnswSqOptions>) {
+    return new Index(
+      LanceDbIndex.hnswSq(
+        options?.distanceType,
+        options?.numPartitions,
+        options?.maxIterations,
+        options?.sampleRate,
+        options?.m,
+        options?.efConstruction,
+      ),
+    );
+  }
 }
 
 export interface IndexOptions {
diff --git a/nodejs/src/index.rs b/nodejs/src/index.rs
index 56c68ae8..c828f20c 100644
--- a/nodejs/src/index.rs
+++ b/nodejs/src/index.rs
@@ -15,7 +15,7 @@
 use std::sync::Mutex;
 
 use lancedb::index::scalar::{BTreeIndexBuilder, FtsIndexBuilder};
-use lancedb::index::vector::IvfPqIndexBuilder;
+use lancedb::index::vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder};
 use lancedb::index::Index as LanceDbIndex;
 use napi_derive::napi;
 
@@ -101,4 +101,76 @@ impl Index {
             inner: Mutex::new(Some(LanceDbIndex::FTS(opts))),
         }
     }
+
+    #[napi(factory)]
+    pub fn hnsw_pq(
+        distance_type: Option<String>,
+        num_partitions: Option<u32>,
+        num_sub_vectors: Option<u32>,
+        max_iterations: Option<u32>,
+        sample_rate: Option<u32>,
+        m: Option<u32>,
+        ef_construction: Option<u32>,
+    ) -> napi::Result<Self> {
+        let mut hnsw_pq_builder = IvfHnswPqIndexBuilder::default();
+        if let Some(distance_type) = distance_type {
+            let distance_type = parse_distance_type(distance_type)?;
+            hnsw_pq_builder = hnsw_pq_builder.distance_type(distance_type);
+        }
+        if let Some(num_partitions) = num_partitions {
+            hnsw_pq_builder = hnsw_pq_builder.num_partitions(num_partitions);
+        }
+        if let Some(num_sub_vectors) = num_sub_vectors {
+            hnsw_pq_builder = hnsw_pq_builder.num_sub_vectors(num_sub_vectors);
+        }
+        if let Some(max_iterations) = max_iterations {
+            hnsw_pq_builder = hnsw_pq_builder.max_iterations(max_iterations);
+        }
+        if let Some(sample_rate) = sample_rate {
+            hnsw_pq_builder = hnsw_pq_builder.sample_rate(sample_rate);
+        }
+        if let Some(m) = m {
+            hnsw_pq_builder = hnsw_pq_builder.num_edges(m);
+        }
+        if let Some(ef_construction) = ef_construction {
+            hnsw_pq_builder = hnsw_pq_builder.ef_construction(ef_construction);
+        }
+        Ok(Self {
+            inner: Mutex::new(Some(LanceDbIndex::IvfHnswPq(hnsw_pq_builder))),
+        })
+    }
+
+    #[napi(factory)]
+    pub fn hnsw_sq(
+        distance_type: Option<String>,
+        num_partitions: Option<u32>,
+        max_iterations: Option<u32>,
+        sample_rate: Option<u32>,
+        m: Option<u32>,
+        ef_construction: Option<u32>,
+    ) -> napi::Result<Self> {
+        let mut hnsw_sq_builder = IvfHnswSqIndexBuilder::default();
+        if let Some(distance_type) = distance_type {
+            let distance_type = parse_distance_type(distance_type)?;
+            hnsw_sq_builder = hnsw_sq_builder.distance_type(distance_type);
+        }
+        if let Some(num_partitions) = num_partitions {
+            hnsw_sq_builder = hnsw_sq_builder.num_partitions(num_partitions);
+        }
+        if let Some(max_iterations) = max_iterations {
+            hnsw_sq_builder = hnsw_sq_builder.max_iterations(max_iterations);
+        }
+        if let Some(sample_rate) = sample_rate {
+            hnsw_sq_builder = hnsw_sq_builder.sample_rate(sample_rate);
+        }
+        if let Some(m) = m {
+            hnsw_sq_builder = hnsw_sq_builder.num_edges(m);
+        }
+        if let Some(ef_construction) = ef_construction {
+            hnsw_sq_builder = hnsw_sq_builder.ef_construction(ef_construction);
+        }
+        Ok(Self {
+            inner: Mutex::new(Some(LanceDbIndex::IvfHnswSq(hnsw_sq_builder))),
+        })
+    }
 }
diff --git a/python/python/lancedb/index.py b/python/python/lancedb/index.py
index aab8948d..bedbb097 100644
--- a/python/python/lancedb/index.py
+++ b/python/python/lancedb/index.py
@@ -82,6 +82,54 @@ class FTS:
         self._inner = LanceDbIndex.fts(with_position=with_position)
 
 
+class HnswPq:
+    """Describe a HNSW-PQ index configuration."""
+
+    def __init__(
+        self,
+        *,
+        distance_type: Optional[str] = None,
+        num_partitions: Optional[int] = None,
+        num_sub_vectors: Optional[int] = None,
+        max_iterations: Optional[int] = None,
+        sample_rate: Optional[int] = None,
+        m: Optional[int] = None,
+        ef_construction: Optional[int] = None,
+    ):
+        self._inner = LanceDbIndex.hnsw_pq(
+            distance_type=distance_type,
+            num_partitions=num_partitions,
+            num_sub_vectors=num_sub_vectors,
+            max_iterations=max_iterations,
+            sample_rate=sample_rate,
+            m=m,
+            ef_construction=ef_construction,
+        )
+
+
+class HnswSq:
+    """Describe a HNSW-SQ index configuration."""
+
+    def __init__(
+        self,
+        *,
+        distance_type: Optional[str] = None,
+        num_partitions: Optional[int] = None,
+        max_iterations: Optional[int] = None,
+        sample_rate: Optional[int] = None,
+        m: Optional[int] = None,
+        ef_construction: Optional[int] = None,
+    ):
+        self._inner = LanceDbIndex.hnsw_sq(
+            distance_type=distance_type,
+            num_partitions=num_partitions,
+            max_iterations=max_iterations,
+            sample_rate=sample_rate,
+            m=m,
+            ef_construction=ef_construction,
+        )
+
+
 class IvfPq:
     """Describes an IVF PQ Index
 
diff --git a/python/python/tests/test_index.py b/python/python/tests/test_index.py
index 283ffd27..6ce6391b 100644
--- a/python/python/tests/test_index.py
+++ b/python/python/tests/test_index.py
@@ -8,7 +8,7 @@ import pyarrow as pa
 import pytest
 import pytest_asyncio
 from lancedb import AsyncConnection, AsyncTable, connect_async
-from lancedb.index import BTree, IvfPq, Bitmap, LabelList
+from lancedb.index import BTree, IvfPq, Bitmap, LabelList, HnswPq, HnswSq
 
 
 @pytest_asyncio.fixture
@@ -91,3 +91,17 @@ async def test_create_vector_index(some_table: AsyncTable):
     assert len(indices) == 1
     assert indices[0].index_type == "IvfPq"
     assert indices[0].columns == ["vector"]
+
+
+@pytest.mark.asyncio
+async def test_create_hnswpq_index(some_table: AsyncTable):
+    await some_table.create_index("vector", config=HnswPq(num_partitions=10))
+    indices = await some_table.list_indices()
+    assert len(indices) == 1
+
+
+@pytest.mark.asyncio
+async def test_create_hnswsq_index(some_table: AsyncTable):
+    await some_table.create_index("vector", config=HnswSq(num_partitions=10))
+    indices = await some_table.list_indices()
+    assert len(indices) == 1
diff --git a/python/src/index.rs b/python/src/index.rs
index 58474b33..aa8ffda8 100644
--- a/python/src/index.rs
+++ b/python/src/index.rs
@@ -16,7 +16,11 @@ use std::sync::Mutex;
 
 use lancedb::index::scalar::FtsIndexBuilder;
 use lancedb::{
-    index::{scalar::BTreeIndexBuilder, vector::IvfPqIndexBuilder, Index as LanceDbIndex},
+    index::{
+        scalar::BTreeIndexBuilder,
+        vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
+        Index as LanceDbIndex,
+    },
     DistanceType,
 };
 use pyo3::{
@@ -24,6 +28,8 @@ use pyo3::{
     pyclass, pymethods, PyResult,
 };
 
+use crate::util::parse_distance_type;
+
 #[pyclass]
 pub struct Index {
     inner: Mutex<Option<LanceDbIndex>>,
@@ -110,6 +116,78 @@ impl Index {
             inner: Mutex::new(Some(LanceDbIndex::FTS(opts))),
         }
     }
+
+    #[staticmethod]
+    pub fn hnsw_pq(
+        distance_type: Option<String>,
+        num_partitions: Option<u32>,
+        num_sub_vectors: Option<u32>,
+        max_iterations: Option<u32>,
+        sample_rate: Option<u32>,
+        m: Option<u32>,
+        ef_construction: Option<u32>,
+    ) -> PyResult<Self> {
+        let mut hnsw_pq_builder = IvfHnswPqIndexBuilder::default();
+        if let Some(distance_type) = distance_type {
+            let distance_type = parse_distance_type(distance_type)?;
+            hnsw_pq_builder = hnsw_pq_builder.distance_type(distance_type);
+        }
+        if let Some(num_partitions) = num_partitions {
+            hnsw_pq_builder = hnsw_pq_builder.num_partitions(num_partitions);
+        }
+        if let Some(num_sub_vectors) = num_sub_vectors {
+            hnsw_pq_builder = hnsw_pq_builder.num_sub_vectors(num_sub_vectors);
+        }
+        if let Some(max_iterations) = max_iterations {
+            hnsw_pq_builder = hnsw_pq_builder.max_iterations(max_iterations);
+        }
+        if let Some(sample_rate) = sample_rate {
+            hnsw_pq_builder = hnsw_pq_builder.sample_rate(sample_rate);
+        }
+        if let Some(m) = m {
+            hnsw_pq_builder = hnsw_pq_builder.num_edges(m);
+        }
+        if let Some(ef_construction) = ef_construction {
+            hnsw_pq_builder = hnsw_pq_builder.ef_construction(ef_construction);
+        }
+        Ok(Self {
+            inner: Mutex::new(Some(LanceDbIndex::IvfHnswPq(hnsw_pq_builder))),
+        })
+    }
+
+    #[staticmethod]
+    pub fn hnsw_sq(
+        distance_type: Option<String>,
+        num_partitions: Option<u32>,
+        max_iterations: Option<u32>,
+        sample_rate: Option<u32>,
+        m: Option<u32>,
+        ef_construction: Option<u32>,
+    ) -> PyResult<Self> {
+        let mut hnsw_sq_builder = IvfHnswSqIndexBuilder::default();
+        if let Some(distance_type) = distance_type {
+            let distance_type = parse_distance_type(distance_type)?;
+            hnsw_sq_builder = hnsw_sq_builder.distance_type(distance_type);
+        }
+        if let Some(num_partitions) = num_partitions {
+            hnsw_sq_builder = hnsw_sq_builder.num_partitions(num_partitions);
+        }
+        if let Some(max_iterations) = max_iterations {
+            hnsw_sq_builder = hnsw_sq_builder.max_iterations(max_iterations);
+        }
+        if let Some(sample_rate) = sample_rate {
+            hnsw_sq_builder = hnsw_sq_builder.sample_rate(sample_rate);
+        }
+        if let Some(m) = m {
+            hnsw_sq_builder = hnsw_sq_builder.num_edges(m);
+        }
+        if let Some(ef_construction) = ef_construction {
+            hnsw_sq_builder = hnsw_sq_builder.ef_construction(ef_construction);
+        }
+        Ok(Self {
+            inner: Mutex::new(Some(LanceDbIndex::IvfHnswSq(hnsw_sq_builder))),
+        })
+    }
 }
 
 #[pyclass(get_all)]