mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-09 05:12:58 +00:00
feat: expose hnsw indices (#1595)
PR closes #1522 --------- Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
@@ -444,6 +444,26 @@ describe("When creating an index", () => {
|
||||
expect(fs.readdirSync(indexDir)).toHaveLength(1);
|
||||
});
|
||||
|
||||
test("create a hnswPq index", async () => {
|
||||
await tbl.createIndex("vec", {
|
||||
config: Index.hnswPq({
|
||||
numPartitions: 10,
|
||||
}),
|
||||
});
|
||||
const indexDir = path.join(tmpDir.name, "test.lance", "_indices");
|
||||
expect(fs.readdirSync(indexDir)).toHaveLength(1);
|
||||
});
|
||||
|
||||
test("create a HnswSq index", async () => {
|
||||
await tbl.createIndex("vec", {
|
||||
config: Index.hnswSq({
|
||||
numPartitions: 10,
|
||||
}),
|
||||
});
|
||||
const indexDir = path.join(tmpDir.name, "test.lance", "_indices");
|
||||
expect(fs.readdirSync(indexDir)).toHaveLength(1);
|
||||
});
|
||||
|
||||
test("create a label list index", async () => {
|
||||
await tbl.createIndex("tags", {
|
||||
config: Index.labelList(),
|
||||
|
||||
@@ -113,6 +113,25 @@ export interface IvfPqOptions {
|
||||
sampleRate?: number;
|
||||
}
|
||||
|
||||
export interface HnswPqOptions {
|
||||
distanceType?: "l2" | "cosine" | "dot";
|
||||
numPartitions?: number;
|
||||
numSubVectors?: number;
|
||||
maxIterations?: number;
|
||||
sampleRate?: number;
|
||||
m?: number;
|
||||
efConstruction?: number;
|
||||
}
|
||||
|
||||
export interface HnswSqOptions {
|
||||
distanceType?: "l2" | "cosine" | "dot";
|
||||
numPartitions?: number;
|
||||
maxIterations?: number;
|
||||
sampleRate?: number;
|
||||
m?: number;
|
||||
efConstruction?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Options to create a full text search index
|
||||
*/
|
||||
@@ -227,6 +246,43 @@ export class Index {
|
||||
static fts(options?: Partial<FtsOptions>) {
|
||||
return new Index(LanceDbIndex.fts(options?.withPositions));
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* Create a hnswpq index
|
||||
*
|
||||
*/
|
||||
static hnswPq(options?: Partial<HnswPqOptions>) {
|
||||
return new Index(
|
||||
LanceDbIndex.hnswPq(
|
||||
options?.distanceType,
|
||||
options?.numPartitions,
|
||||
options?.numSubVectors,
|
||||
options?.maxIterations,
|
||||
options?.sampleRate,
|
||||
options?.m,
|
||||
options?.efConstruction,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* Create a hnswsq index
|
||||
*
|
||||
*/
|
||||
static hnswSq(options?: Partial<HnswSqOptions>) {
|
||||
return new Index(
|
||||
LanceDbIndex.hnswSq(
|
||||
options?.distanceType,
|
||||
options?.numPartitions,
|
||||
options?.maxIterations,
|
||||
options?.sampleRate,
|
||||
options?.m,
|
||||
options?.efConstruction,
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
export interface IndexOptions {
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
use std::sync::Mutex;
|
||||
|
||||
use lancedb::index::scalar::{BTreeIndexBuilder, FtsIndexBuilder};
|
||||
use lancedb::index::vector::IvfPqIndexBuilder;
|
||||
use lancedb::index::vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder};
|
||||
use lancedb::index::Index as LanceDbIndex;
|
||||
use napi_derive::napi;
|
||||
|
||||
@@ -101,4 +101,76 @@ impl Index {
|
||||
inner: Mutex::new(Some(LanceDbIndex::FTS(opts))),
|
||||
}
|
||||
}
|
||||
|
||||
#[napi(factory)]
|
||||
pub fn hnsw_pq(
|
||||
distance_type: Option<String>,
|
||||
num_partitions: Option<u32>,
|
||||
num_sub_vectors: Option<u32>,
|
||||
max_iterations: Option<u32>,
|
||||
sample_rate: Option<u32>,
|
||||
m: Option<u32>,
|
||||
ef_construction: Option<u32>,
|
||||
) -> napi::Result<Self> {
|
||||
let mut hnsw_pq_builder = IvfHnswPqIndexBuilder::default();
|
||||
if let Some(distance_type) = distance_type {
|
||||
let distance_type = parse_distance_type(distance_type)?;
|
||||
hnsw_pq_builder = hnsw_pq_builder.distance_type(distance_type);
|
||||
}
|
||||
if let Some(num_partitions) = num_partitions {
|
||||
hnsw_pq_builder = hnsw_pq_builder.num_partitions(num_partitions);
|
||||
}
|
||||
if let Some(num_sub_vectors) = num_sub_vectors {
|
||||
hnsw_pq_builder = hnsw_pq_builder.num_sub_vectors(num_sub_vectors);
|
||||
}
|
||||
if let Some(max_iterations) = max_iterations {
|
||||
hnsw_pq_builder = hnsw_pq_builder.max_iterations(max_iterations);
|
||||
}
|
||||
if let Some(sample_rate) = sample_rate {
|
||||
hnsw_pq_builder = hnsw_pq_builder.sample_rate(sample_rate);
|
||||
}
|
||||
if let Some(m) = m {
|
||||
hnsw_pq_builder = hnsw_pq_builder.num_edges(m);
|
||||
}
|
||||
if let Some(ef_construction) = ef_construction {
|
||||
hnsw_pq_builder = hnsw_pq_builder.ef_construction(ef_construction);
|
||||
}
|
||||
Ok(Self {
|
||||
inner: Mutex::new(Some(LanceDbIndex::IvfHnswPq(hnsw_pq_builder))),
|
||||
})
|
||||
}
|
||||
|
||||
#[napi(factory)]
|
||||
pub fn hnsw_sq(
|
||||
distance_type: Option<String>,
|
||||
num_partitions: Option<u32>,
|
||||
max_iterations: Option<u32>,
|
||||
sample_rate: Option<u32>,
|
||||
m: Option<u32>,
|
||||
ef_construction: Option<u32>,
|
||||
) -> napi::Result<Self> {
|
||||
let mut hnsw_sq_builder = IvfHnswSqIndexBuilder::default();
|
||||
if let Some(distance_type) = distance_type {
|
||||
let distance_type = parse_distance_type(distance_type)?;
|
||||
hnsw_sq_builder = hnsw_sq_builder.distance_type(distance_type);
|
||||
}
|
||||
if let Some(num_partitions) = num_partitions {
|
||||
hnsw_sq_builder = hnsw_sq_builder.num_partitions(num_partitions);
|
||||
}
|
||||
if let Some(max_iterations) = max_iterations {
|
||||
hnsw_sq_builder = hnsw_sq_builder.max_iterations(max_iterations);
|
||||
}
|
||||
if let Some(sample_rate) = sample_rate {
|
||||
hnsw_sq_builder = hnsw_sq_builder.sample_rate(sample_rate);
|
||||
}
|
||||
if let Some(m) = m {
|
||||
hnsw_sq_builder = hnsw_sq_builder.num_edges(m);
|
||||
}
|
||||
if let Some(ef_construction) = ef_construction {
|
||||
hnsw_sq_builder = hnsw_sq_builder.ef_construction(ef_construction);
|
||||
}
|
||||
Ok(Self {
|
||||
inner: Mutex::new(Some(LanceDbIndex::IvfHnswSq(hnsw_sq_builder))),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,6 +82,54 @@ class FTS:
|
||||
self._inner = LanceDbIndex.fts(with_position=with_position)
|
||||
|
||||
|
||||
class HnswPq:
|
||||
"""Describe a Hnswpq index configuration."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
distance_type: Optional[str] = None,
|
||||
num_partitions: Optional[int] = None,
|
||||
num_sub_vectors: Optional[int] = None,
|
||||
max_iterations: Optional[int] = None,
|
||||
sample_rate: Optional[int] = None,
|
||||
m: Optional[int] = None,
|
||||
ef_construction: Optional[int] = None,
|
||||
):
|
||||
self._inner = LanceDbIndex.hnsw_pq(
|
||||
distance_type=distance_type,
|
||||
num_partitions=num_partitions,
|
||||
num_sub_vectors=num_sub_vectors,
|
||||
max_iterations=max_iterations,
|
||||
sample_rate=sample_rate,
|
||||
m=m,
|
||||
ef_construction=ef_construction,
|
||||
)
|
||||
|
||||
|
||||
class HnswSq:
|
||||
"""Describe a HNSW-SQ index configuration."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
distance_type: Optional[str] = None,
|
||||
num_partitions: Optional[int] = None,
|
||||
max_iterations: Optional[int] = None,
|
||||
sample_rate: Optional[int] = None,
|
||||
m: Optional[int] = None,
|
||||
ef_construction: Optional[int] = None,
|
||||
):
|
||||
self._inner = LanceDbIndex.hnsw_sq(
|
||||
distance_type=distance_type,
|
||||
num_partitions=num_partitions,
|
||||
max_iterations=max_iterations,
|
||||
sample_rate=sample_rate,
|
||||
m=m,
|
||||
ef_construction=ef_construction,
|
||||
)
|
||||
|
||||
|
||||
class IvfPq:
|
||||
"""Describes an IVF PQ Index
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import pyarrow as pa
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from lancedb import AsyncConnection, AsyncTable, connect_async
|
||||
from lancedb.index import BTree, IvfPq, Bitmap, LabelList
|
||||
from lancedb.index import BTree, IvfPq, Bitmap, LabelList, HnswPq, HnswSq
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@@ -91,3 +91,17 @@ async def test_create_vector_index(some_table: AsyncTable):
|
||||
assert len(indices) == 1
|
||||
assert indices[0].index_type == "IvfPq"
|
||||
assert indices[0].columns == ["vector"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_hnswpq_index(some_table: AsyncTable):
|
||||
await some_table.create_index("vector", config=HnswPq(num_partitions=10))
|
||||
indices = await some_table.list_indices()
|
||||
assert len(indices) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_hnswsq_index(some_table: AsyncTable):
|
||||
await some_table.create_index("vector", config=HnswSq(num_partitions=10))
|
||||
indices = await some_table.list_indices()
|
||||
assert len(indices) == 1
|
||||
|
||||
@@ -16,7 +16,11 @@ use std::sync::Mutex;
|
||||
|
||||
use lancedb::index::scalar::FtsIndexBuilder;
|
||||
use lancedb::{
|
||||
index::{scalar::BTreeIndexBuilder, vector::IvfPqIndexBuilder, Index as LanceDbIndex},
|
||||
index::{
|
||||
scalar::BTreeIndexBuilder,
|
||||
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
|
||||
Index as LanceDbIndex,
|
||||
},
|
||||
DistanceType,
|
||||
};
|
||||
use pyo3::{
|
||||
@@ -24,6 +28,8 @@ use pyo3::{
|
||||
pyclass, pymethods, PyResult,
|
||||
};
|
||||
|
||||
use crate::util::parse_distance_type;
|
||||
|
||||
#[pyclass]
|
||||
pub struct Index {
|
||||
inner: Mutex<Option<LanceDbIndex>>,
|
||||
@@ -110,6 +116,78 @@ impl Index {
|
||||
inner: Mutex::new(Some(LanceDbIndex::FTS(opts))),
|
||||
}
|
||||
}
|
||||
|
||||
#[staticmethod]
|
||||
pub fn hnsw_pq(
|
||||
distance_type: Option<String>,
|
||||
num_partitions: Option<u32>,
|
||||
num_sub_vectors: Option<u32>,
|
||||
max_iterations: Option<u32>,
|
||||
sample_rate: Option<u32>,
|
||||
m: Option<u32>,
|
||||
ef_construction: Option<u32>,
|
||||
) -> PyResult<Self> {
|
||||
let mut hnsw_pq_builder = IvfHnswPqIndexBuilder::default();
|
||||
if let Some(distance_type) = distance_type {
|
||||
let distance_type = parse_distance_type(distance_type)?;
|
||||
hnsw_pq_builder = hnsw_pq_builder.distance_type(distance_type);
|
||||
}
|
||||
if let Some(num_partitions) = num_partitions {
|
||||
hnsw_pq_builder = hnsw_pq_builder.num_partitions(num_partitions);
|
||||
}
|
||||
if let Some(num_sub_vectors) = num_sub_vectors {
|
||||
hnsw_pq_builder = hnsw_pq_builder.num_sub_vectors(num_sub_vectors);
|
||||
}
|
||||
if let Some(max_iterations) = max_iterations {
|
||||
hnsw_pq_builder = hnsw_pq_builder.max_iterations(max_iterations);
|
||||
}
|
||||
if let Some(sample_rate) = sample_rate {
|
||||
hnsw_pq_builder = hnsw_pq_builder.sample_rate(sample_rate);
|
||||
}
|
||||
if let Some(m) = m {
|
||||
hnsw_pq_builder = hnsw_pq_builder.num_edges(m);
|
||||
}
|
||||
if let Some(ef_construction) = ef_construction {
|
||||
hnsw_pq_builder = hnsw_pq_builder.ef_construction(ef_construction);
|
||||
}
|
||||
Ok(Self {
|
||||
inner: Mutex::new(Some(LanceDbIndex::IvfHnswPq(hnsw_pq_builder))),
|
||||
})
|
||||
}
|
||||
|
||||
#[staticmethod]
|
||||
pub fn hnsw_sq(
|
||||
distance_type: Option<String>,
|
||||
num_partitions: Option<u32>,
|
||||
max_iterations: Option<u32>,
|
||||
sample_rate: Option<u32>,
|
||||
m: Option<u32>,
|
||||
ef_construction: Option<u32>,
|
||||
) -> PyResult<Self> {
|
||||
let mut hnsw_sq_builder = IvfHnswSqIndexBuilder::default();
|
||||
if let Some(distance_type) = distance_type {
|
||||
let distance_type = parse_distance_type(distance_type)?;
|
||||
hnsw_sq_builder = hnsw_sq_builder.distance_type(distance_type);
|
||||
}
|
||||
if let Some(num_partitions) = num_partitions {
|
||||
hnsw_sq_builder = hnsw_sq_builder.num_partitions(num_partitions);
|
||||
}
|
||||
if let Some(max_iterations) = max_iterations {
|
||||
hnsw_sq_builder = hnsw_sq_builder.max_iterations(max_iterations);
|
||||
}
|
||||
if let Some(sample_rate) = sample_rate {
|
||||
hnsw_sq_builder = hnsw_sq_builder.sample_rate(sample_rate);
|
||||
}
|
||||
if let Some(m) = m {
|
||||
hnsw_sq_builder = hnsw_sq_builder.num_edges(m);
|
||||
}
|
||||
if let Some(ef_construction) = ef_construction {
|
||||
hnsw_sq_builder = hnsw_sq_builder.ef_construction(ef_construction);
|
||||
}
|
||||
Ok(Self {
|
||||
inner: Mutex::new(Some(LanceDbIndex::IvfHnswSq(hnsw_sq_builder))),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(get_all)]
|
||||
|
||||
Reference in New Issue
Block a user