feat: expose hnsw indices (#1595)

PR closes #1522

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
Gagan Bhullar
2024-09-10 12:08:13 -06:00
committed by GitHub
parent 2bde5401eb
commit 205fc530cf
6 changed files with 291 additions and 3 deletions

View File

@@ -444,6 +444,26 @@ describe("When creating an index", () => {
expect(fs.readdirSync(indexDir)).toHaveLength(1);
});
test("create a hnswPq index", async () => {
await tbl.createIndex("vec", {
config: Index.hnswPq({
numPartitions: 10,
}),
});
const indexDir = path.join(tmpDir.name, "test.lance", "_indices");
expect(fs.readdirSync(indexDir)).toHaveLength(1);
});
test("create a HnswSq index", async () => {
await tbl.createIndex("vec", {
config: Index.hnswSq({
numPartitions: 10,
}),
});
const indexDir = path.join(tmpDir.name, "test.lance", "_indices");
expect(fs.readdirSync(indexDir)).toHaveLength(1);
});
test("create a label list index", async () => {
await tbl.createIndex("tags", {
config: Index.labelList(),

View File

@@ -113,6 +113,25 @@ export interface IvfPqOptions {
sampleRate?: number;
}
export interface HnswPqOptions {
distanceType?: "l2" | "cosine" | "dot";
numPartitions?: number;
numSubVectors?: number;
maxIterations?: number;
sampleRate?: number;
m?: number;
efConstruction?: number;
}
export interface HnswSqOptions {
distanceType?: "l2" | "cosine" | "dot";
numPartitions?: number;
maxIterations?: number;
sampleRate?: number;
m?: number;
efConstruction?: number;
}
/**
* Options to create a full text search index
*/
@@ -227,6 +246,43 @@ export class Index {
static fts(options?: Partial<FtsOptions>) {
return new Index(LanceDbIndex.fts(options?.withPositions));
}
/**
*
* Create a hnswpq index
*
*/
static hnswPq(options?: Partial<HnswPqOptions>) {
return new Index(
LanceDbIndex.hnswPq(
options?.distanceType,
options?.numPartitions,
options?.numSubVectors,
options?.maxIterations,
options?.sampleRate,
options?.m,
options?.efConstruction,
),
);
}
/**
*
* Create a hnswsq index
*
*/
static hnswSq(options?: Partial<HnswSqOptions>) {
return new Index(
LanceDbIndex.hnswSq(
options?.distanceType,
options?.numPartitions,
options?.maxIterations,
options?.sampleRate,
options?.m,
options?.efConstruction,
),
);
}
}
export interface IndexOptions {

View File

@@ -15,7 +15,7 @@
use std::sync::Mutex;
use lancedb::index::scalar::{BTreeIndexBuilder, FtsIndexBuilder};
use lancedb::index::vector::IvfPqIndexBuilder;
use lancedb::index::vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder};
use lancedb::index::Index as LanceDbIndex;
use napi_derive::napi;
@@ -101,4 +101,76 @@ impl Index {
inner: Mutex::new(Some(LanceDbIndex::FTS(opts))),
}
}
#[napi(factory)]
pub fn hnsw_pq(
distance_type: Option<String>,
num_partitions: Option<u32>,
num_sub_vectors: Option<u32>,
max_iterations: Option<u32>,
sample_rate: Option<u32>,
m: Option<u32>,
ef_construction: Option<u32>,
) -> napi::Result<Self> {
let mut hnsw_pq_builder = IvfHnswPqIndexBuilder::default();
if let Some(distance_type) = distance_type {
let distance_type = parse_distance_type(distance_type)?;
hnsw_pq_builder = hnsw_pq_builder.distance_type(distance_type);
}
if let Some(num_partitions) = num_partitions {
hnsw_pq_builder = hnsw_pq_builder.num_partitions(num_partitions);
}
if let Some(num_sub_vectors) = num_sub_vectors {
hnsw_pq_builder = hnsw_pq_builder.num_sub_vectors(num_sub_vectors);
}
if let Some(max_iterations) = max_iterations {
hnsw_pq_builder = hnsw_pq_builder.max_iterations(max_iterations);
}
if let Some(sample_rate) = sample_rate {
hnsw_pq_builder = hnsw_pq_builder.sample_rate(sample_rate);
}
if let Some(m) = m {
hnsw_pq_builder = hnsw_pq_builder.num_edges(m);
}
if let Some(ef_construction) = ef_construction {
hnsw_pq_builder = hnsw_pq_builder.ef_construction(ef_construction);
}
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::IvfHnswPq(hnsw_pq_builder))),
})
}
#[napi(factory)]
pub fn hnsw_sq(
distance_type: Option<String>,
num_partitions: Option<u32>,
max_iterations: Option<u32>,
sample_rate: Option<u32>,
m: Option<u32>,
ef_construction: Option<u32>,
) -> napi::Result<Self> {
let mut hnsw_sq_builder = IvfHnswSqIndexBuilder::default();
if let Some(distance_type) = distance_type {
let distance_type = parse_distance_type(distance_type)?;
hnsw_sq_builder = hnsw_sq_builder.distance_type(distance_type);
}
if let Some(num_partitions) = num_partitions {
hnsw_sq_builder = hnsw_sq_builder.num_partitions(num_partitions);
}
if let Some(max_iterations) = max_iterations {
hnsw_sq_builder = hnsw_sq_builder.max_iterations(max_iterations);
}
if let Some(sample_rate) = sample_rate {
hnsw_sq_builder = hnsw_sq_builder.sample_rate(sample_rate);
}
if let Some(m) = m {
hnsw_sq_builder = hnsw_sq_builder.num_edges(m);
}
if let Some(ef_construction) = ef_construction {
hnsw_sq_builder = hnsw_sq_builder.ef_construction(ef_construction);
}
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::IvfHnswSq(hnsw_sq_builder))),
})
}
}

View File

@@ -82,6 +82,54 @@ class FTS:
self._inner = LanceDbIndex.fts(with_position=with_position)
class HnswPq:
"""Describe a Hnswpq index configuration."""
def __init__(
self,
*,
distance_type: Optional[str] = None,
num_partitions: Optional[int] = None,
num_sub_vectors: Optional[int] = None,
max_iterations: Optional[int] = None,
sample_rate: Optional[int] = None,
m: Optional[int] = None,
ef_construction: Optional[int] = None,
):
self._inner = LanceDbIndex.hnsw_pq(
distance_type=distance_type,
num_partitions=num_partitions,
num_sub_vectors=num_sub_vectors,
max_iterations=max_iterations,
sample_rate=sample_rate,
m=m,
ef_construction=ef_construction,
)
class HnswSq:
"""Describe a HNSW-SQ index configuration."""
def __init__(
self,
*,
distance_type: Optional[str] = None,
num_partitions: Optional[int] = None,
max_iterations: Optional[int] = None,
sample_rate: Optional[int] = None,
m: Optional[int] = None,
ef_construction: Optional[int] = None,
):
self._inner = LanceDbIndex.hnsw_sq(
distance_type=distance_type,
num_partitions=num_partitions,
max_iterations=max_iterations,
sample_rate=sample_rate,
m=m,
ef_construction=ef_construction,
)
class IvfPq:
"""Describes an IVF PQ Index

View File

@@ -8,7 +8,7 @@ import pyarrow as pa
import pytest
import pytest_asyncio
from lancedb import AsyncConnection, AsyncTable, connect_async
from lancedb.index import BTree, IvfPq, Bitmap, LabelList
from lancedb.index import BTree, IvfPq, Bitmap, LabelList, HnswPq, HnswSq
@pytest_asyncio.fixture
@@ -91,3 +91,17 @@ async def test_create_vector_index(some_table: AsyncTable):
assert len(indices) == 1
assert indices[0].index_type == "IvfPq"
assert indices[0].columns == ["vector"]
@pytest.mark.asyncio
async def test_create_hnswpq_index(some_table: AsyncTable):
await some_table.create_index("vector", config=HnswPq(num_partitions=10))
indices = await some_table.list_indices()
assert len(indices) == 1
@pytest.mark.asyncio
async def test_create_hnswsq_index(some_table: AsyncTable):
await some_table.create_index("vector", config=HnswSq(num_partitions=10))
indices = await some_table.list_indices()
assert len(indices) == 1

View File

@@ -16,7 +16,11 @@ use std::sync::Mutex;
use lancedb::index::scalar::FtsIndexBuilder;
use lancedb::{
index::{scalar::BTreeIndexBuilder, vector::IvfPqIndexBuilder, Index as LanceDbIndex},
index::{
scalar::BTreeIndexBuilder,
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
Index as LanceDbIndex,
},
DistanceType,
};
use pyo3::{
@@ -24,6 +28,8 @@ use pyo3::{
pyclass, pymethods, PyResult,
};
use crate::util::parse_distance_type;
#[pyclass]
pub struct Index {
inner: Mutex<Option<LanceDbIndex>>,
@@ -110,6 +116,78 @@ impl Index {
inner: Mutex::new(Some(LanceDbIndex::FTS(opts))),
}
}
#[staticmethod]
pub fn hnsw_pq(
distance_type: Option<String>,
num_partitions: Option<u32>,
num_sub_vectors: Option<u32>,
max_iterations: Option<u32>,
sample_rate: Option<u32>,
m: Option<u32>,
ef_construction: Option<u32>,
) -> PyResult<Self> {
let mut hnsw_pq_builder = IvfHnswPqIndexBuilder::default();
if let Some(distance_type) = distance_type {
let distance_type = parse_distance_type(distance_type)?;
hnsw_pq_builder = hnsw_pq_builder.distance_type(distance_type);
}
if let Some(num_partitions) = num_partitions {
hnsw_pq_builder = hnsw_pq_builder.num_partitions(num_partitions);
}
if let Some(num_sub_vectors) = num_sub_vectors {
hnsw_pq_builder = hnsw_pq_builder.num_sub_vectors(num_sub_vectors);
}
if let Some(max_iterations) = max_iterations {
hnsw_pq_builder = hnsw_pq_builder.max_iterations(max_iterations);
}
if let Some(sample_rate) = sample_rate {
hnsw_pq_builder = hnsw_pq_builder.sample_rate(sample_rate);
}
if let Some(m) = m {
hnsw_pq_builder = hnsw_pq_builder.num_edges(m);
}
if let Some(ef_construction) = ef_construction {
hnsw_pq_builder = hnsw_pq_builder.ef_construction(ef_construction);
}
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::IvfHnswPq(hnsw_pq_builder))),
})
}
#[staticmethod]
pub fn hnsw_sq(
distance_type: Option<String>,
num_partitions: Option<u32>,
max_iterations: Option<u32>,
sample_rate: Option<u32>,
m: Option<u32>,
ef_construction: Option<u32>,
) -> PyResult<Self> {
let mut hnsw_sq_builder = IvfHnswSqIndexBuilder::default();
if let Some(distance_type) = distance_type {
let distance_type = parse_distance_type(distance_type)?;
hnsw_sq_builder = hnsw_sq_builder.distance_type(distance_type);
}
if let Some(num_partitions) = num_partitions {
hnsw_sq_builder = hnsw_sq_builder.num_partitions(num_partitions);
}
if let Some(max_iterations) = max_iterations {
hnsw_sq_builder = hnsw_sq_builder.max_iterations(max_iterations);
}
if let Some(sample_rate) = sample_rate {
hnsw_sq_builder = hnsw_sq_builder.sample_rate(sample_rate);
}
if let Some(m) = m {
hnsw_sq_builder = hnsw_sq_builder.num_edges(m);
}
if let Some(ef_construction) = ef_construction {
hnsw_sq_builder = hnsw_sq_builder.ef_construction(ef_construction);
}
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::IvfHnswSq(hnsw_sq_builder))),
})
}
}
#[pyclass(get_all)]