feat: expose hnsw indices (#1595)

PR closes #1522

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
Gagan Bhullar
2024-09-10 12:08:13 -06:00
committed by GitHub
parent 2bde5401eb
commit 205fc530cf
6 changed files with 291 additions and 3 deletions

View File

@@ -82,6 +82,54 @@ class FTS:
self._inner = LanceDbIndex.fts(with_position=with_position)
class HnswPq:
    """Configuration object for an HNSW-PQ index.

    Every option is keyword-only and defaults to ``None``. The values are
    handed, unchanged, to ``LanceDbIndex.hnsw_pq`` and the object it returns
    is stored on ``self._inner``.

    NOTE(review): the meaning and the effective default of each option are
    decided by ``LanceDbIndex.hnsw_pq``, which is not visible from this
    class -- confirm against the native binding before documenting them in
    more detail.
    """

    def __init__(
        self,
        *,
        distance_type: Optional[str] = None,
        num_partitions: Optional[int] = None,
        num_sub_vectors: Optional[int] = None,
        max_iterations: Optional[int] = None,
        sample_rate: Optional[int] = None,
        m: Optional[int] = None,
        ef_construction: Optional[int] = None,
    ):
        # Gather the options once so the forwarding call stays a one-liner;
        # keys mirror the keyword names expected by the native builder.
        options = {
            "distance_type": distance_type,
            "num_partitions": num_partitions,
            "num_sub_vectors": num_sub_vectors,
            "max_iterations": max_iterations,
            "sample_rate": sample_rate,
            "m": m,
            "ef_construction": ef_construction,
        }
        self._inner = LanceDbIndex.hnsw_pq(**options)
class HnswSq:
    """Configuration object for an HNSW-SQ index.

    Every option is keyword-only and defaults to ``None``. The values are
    handed, unchanged, to ``LanceDbIndex.hnsw_sq`` and the object it returns
    is stored on ``self._inner``.

    NOTE(review): the meaning and the effective default of each option are
    decided by ``LanceDbIndex.hnsw_sq``, which is not visible from this
    class -- confirm against the native binding before documenting them in
    more detail.
    """

    def __init__(
        self,
        *,
        distance_type: Optional[str] = None,
        num_partitions: Optional[int] = None,
        max_iterations: Optional[int] = None,
        sample_rate: Optional[int] = None,
        m: Optional[int] = None,
        ef_construction: Optional[int] = None,
    ):
        # Gather the options once so the forwarding call stays a one-liner;
        # keys mirror the keyword names expected by the native builder.
        options = {
            "distance_type": distance_type,
            "num_partitions": num_partitions,
            "max_iterations": max_iterations,
            "sample_rate": sample_rate,
            "m": m,
            "ef_construction": ef_construction,
        }
        self._inner = LanceDbIndex.hnsw_sq(**options)
class IvfPq:
"""Describes an IVF PQ Index

View File

@@ -8,7 +8,7 @@ import pyarrow as pa
import pytest
import pytest_asyncio
from lancedb import AsyncConnection, AsyncTable, connect_async
from lancedb.index import BTree, IvfPq, Bitmap, LabelList
from lancedb.index import BTree, IvfPq, Bitmap, LabelList, HnswPq, HnswSq
@pytest_asyncio.fixture
@@ -91,3 +91,17 @@ async def test_create_vector_index(some_table: AsyncTable):
assert len(indices) == 1
assert indices[0].index_type == "IvfPq"
assert indices[0].columns == ["vector"]
@pytest.mark.asyncio
async def test_create_hnswpq_index(some_table: AsyncTable):
    """Create an HnswPq index on ``vector`` and check that it is listed."""
    await some_table.create_index("vector", config=HnswPq(num_partitions=10))
    indices = await some_table.list_indices()
    assert len(indices) == 1
    # Counting alone would also pass if the index landed on another column;
    # pin the indexed column the same way test_create_vector_index does.
    assert indices[0].columns == ["vector"]
@pytest.mark.asyncio
async def test_create_hnswsq_index(some_table: AsyncTable):
    """Create an HnswSq index on ``vector`` and check that it is listed."""
    await some_table.create_index("vector", config=HnswSq(num_partitions=10))
    indices = await some_table.list_indices()
    assert len(indices) == 1
    # Counting alone would also pass if the index landed on another column;
    # pin the indexed column the same way test_create_vector_index does.
    assert indices[0].columns == ["vector"]