mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-20 21:40:43 +00:00
## Summary Wire up `IVF_HNSW_FLAT` in the Rust core and Python SDK. The index was documented at https://docs.lancedb.com/indexing/vector-index but `lancedb.Table.create_index(index_type="IVF_HNSW_FLAT")` raised `ValueError: Unknown index type IVF_HNSW_FLAT` — the underlying `pylance` already accepted it, only the LanceDB wrapper was missing the wiring. **Rust core (`rust/lancedb`):** - Add `Index::IvfHnswFlat` / `IndexType::IvfHnswFlat` variants and the `IvfHnswFlatIndexBuilder` (modelled on `IvfHnswSqIndexBuilder`). - Build Lance params via the existing `VectorIndexParams::ivf_hnsw(...)` helper, keeping symmetry with the other `IVF_HNSW_*` variants. - Forward the variant in `RemoteTable::create_index` and add two parametrised tests (default + customised config) for the JSON serialisation. - New `NativeTable` integration test (`test_create_index_ivf_hnsw_flat`). **Python binding (`python/`):** - New `HnswFlat` dataclass + backwards-compat `IvfHnswFlat` alias. - PyO3 `extract_index_params` recognises the `HnswFlat` config. - `LanceTable.create_index(index_type="IVF_HNSW_FLAT", …)` and the sync `RemoteTable.create_index` both dispatch to the new config. - `IndexStatistics.index_type` `Literal` and `_lancedb.pyi` stubs cover the new type so `pyright`/`make check` stays clean. - Async integration tests (`HnswFlat` + `IvfHnswFlat` alias) and a sync dispatcher test, mirroring the existing `IVF_HNSW_SQ` coverage. - Existing `test_index_statistics_index_type_lists_all_supported_values` updated to include `IVF_HNSW_FLAT`. A matching Node.js / TypeScript binding is in a follow-up PR. 
Closes #3331

## Test plan

- [ ] `cargo check --quiet --features remote --tests --examples`
- [ ] `cargo test --quiet --features remote -p lancedb` (covers the new `test_create_index_ivf_hnsw_flat` and the two new parametrised `RemoteTable::create_index` cases)
- [ ] `cargo fmt --all` / `cargo clippy --quiet --features remote --tests --examples`
- [ ] `cd python && make develop && make check && make test` (covers the two new async tests, the alias test, the dispatcher test, and the updated `test_index_statistics_index_type_lists_all_supported_values` assertion)
326 lines
11 KiB
Python
326 lines
11 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
|
|
|
from datetime import timedelta
|
|
import random
|
|
from typing import get_args, get_type_hints
|
|
|
|
import pyarrow as pa
|
|
import pytest
|
|
import pytest_asyncio
|
|
from lancedb import AsyncConnection, AsyncTable, connect_async
|
|
from lancedb.index import (
|
|
BTree,
|
|
IvfFlat,
|
|
IvfPq,
|
|
IvfSq,
|
|
IvfHnswPq,
|
|
IvfHnswSq,
|
|
IvfHnswFlat,
|
|
IvfRq,
|
|
Bitmap,
|
|
LabelList,
|
|
HnswPq,
|
|
HnswSq,
|
|
HnswFlat,
|
|
FTS,
|
|
)
|
|
from lancedb.table import IndexStatistics
|
|
|
|
|
|
@pytest_asyncio.fixture
async def db_async(tmp_path) -> AsyncConnection:
    """Open an async LanceDB connection rooted in pytest's per-test temp dir.

    A zero ``read_consistency_interval`` is passed — presumably so that every
    read observes the latest table version within a test (confirm against the
    ``connect_async`` docs).
    """
    consistency = timedelta(seconds=0)
    return await connect_async(tmp_path, read_consistency_interval=consistency)
|
|
|
|
|
|
def sample_fixed_size_list_array(nrows, dim):
    """Build a ``FixedSizeListArray`` of ``nrows`` float32 vectors of width ``dim``.

    The flat values are the consecutive numbers ``0 .. nrows * dim - 1``,
    chunked into lists of ``dim`` elements, so row ``r`` holds
    ``[r * dim, ..., r * dim + dim - 1]``.
    """
    flat_values = pa.array(
        [float(value) for value in range(nrows * dim)], type=pa.float32()
    )
    return pa.FixedSizeListArray.from_arrays(flat_values, dim)
|
|
|
|
|
|
# Width of the vectors in the ``some_table`` fixture's ``vector`` column.
DIM = 8
# Row count used by both table fixtures. It must stay <= 256: row ids are also
# used as single byte values (``bytes([i])`` in ``some_table`` and uint8
# vectors in ``binary_table``).
NROWS = 256
|
|
|
|
|
|
@pytest_asyncio.fixture
async def some_table(db_async):
    """Create ``some_table`` with one column for each index kind tested here.

    Columns:
      - ``id``: integer row ids (BTree / Bitmap tests)
      - ``vector``: ``DIM``-wide float32 vectors (vector-index tests)
      - ``fsb``: 1-byte fixed-size binary (BTree on fixed-size binary)
      - ``tags``: two random ``tagN`` strings per row, N in 0..8 (LabelList / FTS)
      - ``is_active``: random booleans (Bitmap)
      - ``data``: random variable-length binary of 0-128 bytes (Bitmap)

    The random columns come from a fixed-seed local generator rather than the
    global ``random`` module. Every run then sees the same data, so failures
    can be reproduced, and the global RNG state used by other tests is left
    untouched.
    """
    rng = random.Random(42)
    data = pa.Table.from_pydict(
        {
            "id": list(range(NROWS)),
            "vector": sample_fixed_size_list_array(NROWS, DIM),
            # bytes([i]) requires i <= 255, hence NROWS <= 256.
            "fsb": pa.array([bytes([i]) for i in range(NROWS)], pa.binary(1)),
            "tags": [
                [f"tag{rng.randint(0, 8)}" for _ in range(2)] for _ in range(NROWS)
            ],
            "is_active": [rng.choice([True, False]) for _ in range(NROWS)],
            "data": [rng.randbytes(rng.randint(0, 128)) for _ in range(NROWS)],
        }
    )
    return await db_async.create_table(
        "some_table",
        data,
    )
|
|
|
|
|
|
@pytest_asyncio.fixture
async def binary_table(db_async):
    """Create ``binary_table``, where row ``i`` stores the vector ``[i] * 128``.

    The ``vector`` column is a fixed-size list of 128 ``uint8`` values. The
    binary-vector test below indexes it with the ``hamming`` distance. With
    ``NROWS`` rows, each byte value in ``range(NROWS)`` appears as exactly one
    vector.
    """
    rows = [{"id": row_id, "vector": [row_id] * 128} for row_id in range(NROWS)]
    schema = pa.schema(
        [
            pa.field("id", pa.int64()),
            pa.field("vector", pa.list_(pa.uint8(), 128)),
        ]
    )
    return await db_async.create_table("binary_table", rows, schema=schema)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_scalar_index(some_table: AsyncTable):
    """Walk a BTree index through create, replace, duplicate rejection and drop."""
    # With no config, an index on the integer ``id`` column is a BTree.
    await some_table.create_index("id")
    # replace=True tolerates the index that already exists.
    await some_table.create_index("id", replace=True)

    listed = await some_table.list_indices()
    assert str(listed) == '[Index(BTree, columns=["id"], name="id_idx")]'
    assert len(listed) == 1
    first = listed[0]
    assert first.index_type == "BTree"
    assert first.columns == ["id"]

    # replace=False must refuse to overwrite the existing index.
    with pytest.raises(RuntimeError, match="already exists"):
        await some_table.create_index("id", replace=False)

    # An explicit BTree config is accepted too.
    await some_table.create_index("id", config=BTree())

    await some_table.drop_index("id_idx")
    remaining = await some_table.list_indices()
    assert len(remaining) == 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_fixed_size_binary_index(some_table: AsyncTable):
    """Build a BTree index on the fixed-size binary ``fsb`` column."""
    await some_table.create_index("fsb", config=BTree())

    listed = await some_table.list_indices()
    assert str(listed) == '[Index(BTree, columns=["fsb"], name="fsb_idx")]'
    assert len(listed) == 1
    fsb_index = listed[0]
    assert fsb_index.index_type == "BTree"
    assert fsb_index.columns == ["fsb"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_bitmap_index(some_table: AsyncTable):
    """Build Bitmap indices on int, bool and binary columns; check stats and plan."""
    bitmap_columns = ["id", "is_active", "data"]
    for column in bitmap_columns:
        await some_table.create_index(column, config=Bitmap())

    indices = await some_table.list_indices()
    assert len(indices) == 3
    # Indices are expected to be listed in creation order.
    for index, column in zip(indices, bitmap_columns):
        assert index.index_type == "Bitmap"
        assert index.columns == [column]

    stats = await some_table.index_stats(indices[0].name)
    assert stats.index_type == "BITMAP"
    assert stats.distance_type is None
    assert stats.num_indexed_rows == await some_table.count_rows()
    assert stats.num_unindexed_rows == 0
    assert stats.num_indices == 1

    # A filter on the indexed boolean column should be answered by the index.
    plan = await some_table.query().where("is_active = TRUE").explain_plan()
    assert "ScalarIndexQuery" in plan
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_label_list_index(some_table: AsyncTable):
    """Build a LabelList index on the list-of-strings ``tags`` column."""
    await some_table.create_index("tags", config=LabelList())
    expected = '[Index(LabelList, columns=["tags"], name="tags_idx")]'
    listed = await some_table.list_indices()
    assert str(listed) == expected
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_full_text_search_index(some_table: AsyncTable):
    """Build an FTS index on ``tags``, prewarm it, and run a text search."""
    await some_table.create_index("tags", config=FTS(with_position=False))
    listed = await some_table.list_indices()
    assert str(listed) == '[Index(FTS, columns=["tags"], name="tags_idx")]'

    await some_table.prewarm_index("tags_idx")

    query = await some_table.search("tag0")
    hits = await query.to_arrow()
    # NOTE(review): this relies on at least one fixture row having drawn "tag0"
    # in its randomly generated ``tags`` column.
    assert hits.num_rows > 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_vector_index(some_table: AsyncTable):
    """Walk a vector index through its lifecycle and check IVF_PQ statistics."""
    # Create a default vector index, then replace it in place.
    await some_table.create_index("vector")
    await some_table.create_index("vector", replace=True)

    # replace=False refuses to overwrite it.
    with pytest.raises(RuntimeError, match="already exists"):
        await some_table.create_index("vector", replace=False)

    # An explicit config, without replace=False, swaps in an IvfPq index.
    await some_table.create_index("vector", config=IvfPq(num_partitions=100))

    indices = await some_table.list_indices()
    assert len(indices) == 1
    vector_index = indices[0]
    assert vector_index.index_type == "IvfPq"
    assert vector_index.columns == ["vector"]
    assert vector_index.name == "vector_idx"

    stats = await some_table.index_stats("vector_idx")
    assert stats.index_type == "IVF_PQ"
    assert stats.distance_type == "l2"
    assert stats.num_indexed_rows == await some_table.count_rows()
    assert stats.num_unindexed_rows == 0
    assert stats.num_indices == 1
    assert stats.loss >= 0.0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_4bit_ivfpq_index(some_table: AsyncTable):
    """Walk a 4-bit IVF_PQ index through create, replace and duplicate rejection."""
    await some_table.create_index("vector", config=IvfPq(num_bits=4))
    # Re-creating with replace=True overwrites the existing index.
    await some_table.create_index("vector", config=IvfPq(num_bits=4), replace=True)

    # With replace=False, a second index on the same column is rejected.
    with pytest.raises(RuntimeError, match="already exists"):
        await some_table.create_index("vector", replace=False)

    indices = await some_table.list_indices()
    assert len(indices) == 1
    pq_index = indices[0]
    assert pq_index.index_type == "IvfPq"
    assert pq_index.columns == ["vector"]
    assert pq_index.name == "vector_idx"

    stats = await some_table.index_stats("vector_idx")
    assert stats.index_type == "IVF_PQ"
    assert stats.distance_type == "l2"
    assert stats.num_indexed_rows == await some_table.count_rows()
    assert stats.num_unindexed_rows == 0
    assert stats.num_indices == 1
    assert stats.loss >= 0.0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_ivfrq_index(some_table: AsyncTable):
    """Build an IvfRq index with ``num_bits=1`` on ``vector``."""
    await some_table.create_index("vector", config=IvfRq(num_bits=1))

    indices = await some_table.list_indices()
    assert len(indices) == 1
    rq_index = indices[0]
    assert rq_index.index_type == "IvfRq"
    assert rq_index.columns == ["vector"]
    assert rq_index.name == "vector_idx"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_hnswpq_index(some_table: AsyncTable):
    """Check that an ``HnswPq`` config builds exactly one HNSW-PQ index on ``vector``.

    The reported index type is checked, not just the count. A count-only
    check would still pass if the config were silently dispatched to a
    different index kind.
    """
    await some_table.create_index("vector", config=HnswPq(num_partitions=10))
    indices = await some_table.list_indices()
    assert len(indices) == 1
    # Accept the same spellings as the ``IvfHnswPq`` alias test.
    assert indices[0].index_type in {"HnswPq", "IvfHnswPq"}
    assert indices[0].columns == ["vector"]
    assert indices[0].name == "vector_idx"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_hnswsq_index(some_table: AsyncTable):
    """Check that an ``HnswSq`` config builds exactly one HNSW-SQ index on ``vector``.

    The reported index type is checked, not just the count. A count-only
    check would still pass if the config were silently dispatched to a
    different index kind.
    """
    await some_table.create_index("vector", config=HnswSq(num_partitions=10))
    indices = await some_table.list_indices()
    assert len(indices) == 1
    # Accept the same spellings as the ``IvfHnswSq`` alias test.
    assert indices[0].index_type in {"HnswSq", "IvfHnswSq"}
    assert indices[0].columns == ["vector"]
    assert indices[0].name == "vector_idx"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_hnswsq_alias_index(some_table: AsyncTable):
    """Check that the ``IvfHnswSq`` alias config builds an HNSW-SQ index on ``vector``."""
    alias_config = IvfHnswSq(num_partitions=5)
    await some_table.create_index("vector", config=alias_config)

    indices = await some_table.list_indices()
    assert len(indices) == 1
    accepted_types = {"HnswSq", "IvfHnswSq"}
    assert indices[0].index_type in accepted_types
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_hnswpq_alias_index(some_table: AsyncTable):
    """Check that the ``IvfHnswPq`` alias config builds an HNSW-PQ index on ``vector``."""
    alias_config = IvfHnswPq(num_partitions=5)
    await some_table.create_index("vector", config=alias_config)

    indices = await some_table.list_indices()
    assert len(indices) == 1
    accepted_types = {"HnswPq", "IvfHnswPq"}
    assert indices[0].index_type in accepted_types
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_hnswflat_index(some_table: AsyncTable):
    """Check that an ``HnswFlat`` config builds exactly one HNSW-FLAT index on ``vector``.

    The reported index type is checked, not just the count. A count-only
    check would still pass if ``HnswFlat`` were silently dispatched to a
    different index kind, which is exactly the wiring this config depends on.
    """
    await some_table.create_index("vector", config=HnswFlat(num_partitions=10))
    indices = await some_table.list_indices()
    assert len(indices) == 1
    # Accept the same spellings as the ``IvfHnswFlat`` alias test.
    assert indices[0].index_type in {"HnswFlat", "IvfHnswFlat"}
    assert indices[0].columns == ["vector"]
    assert indices[0].name == "vector_idx"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_hnswflat_alias_index(some_table: AsyncTable):
    """Check that the ``IvfHnswFlat`` alias config builds an HNSW-FLAT index on ``vector``."""
    alias_config = IvfHnswFlat(num_partitions=5)
    await some_table.create_index("vector", config=alias_config)

    indices = await some_table.list_indices()
    assert len(indices) == 1
    accepted_types = {"HnswFlat", "IvfHnswFlat"}
    assert indices[0].index_type in accepted_types
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_ivfsq_index(some_table: AsyncTable):
    """Check that an IVF_SQ index is listed as ``IvfSq`` and covers every row in stats."""
    await some_table.create_index("vector", config=IvfSq(num_partitions=10))

    indices = await some_table.list_indices()
    assert len(indices) == 1
    sq_index = indices[0]
    assert sq_index.index_type == "IvfSq"

    stats = await some_table.index_stats(sq_index.name)
    assert stats.index_type == "IVF_SQ"
    assert stats.distance_type == "l2"
    assert stats.num_indexed_rows == await some_table.count_rows()
    assert stats.num_unindexed_rows == 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_index_with_binary_vectors(binary_table: AsyncTable):
    """Check that IVF_FLAT with hamming distance over uint8 vectors finds exact matches."""
    hamming_config = IvfFlat(distance_type="hamming", num_partitions=10)
    await binary_table.create_index("vector", config=hamming_config)

    indices = await binary_table.list_indices()
    assert len(indices) == 1
    flat_index = indices[0]
    assert flat_index.index_type == "IvfFlat"
    assert flat_index.columns == ["vector"]
    assert flat_index.name == "vector_idx"

    stats = await binary_table.index_stats("vector_idx")
    assert stats.index_type == "IVF_FLAT"
    assert stats.distance_type == "hamming"
    assert stats.num_indexed_rows == await binary_table.count_rows()
    assert stats.num_unindexed_rows == 0
    assert stats.num_indices == 1

    # Row ``v`` stores ``[v] * 128`` for every byte value 0-255, so probing
    # with that exact vector must return row ``v`` as the nearest hit.
    for byte_value in range(256):
        probe = [byte_value] * 128
        result = await binary_table.query().nearest_to(probe).to_arrow()
        assert result["id"][0].as_py() == byte_value
|
|
|
|
|
|
def test_index_statistics_index_type_lists_all_supported_values():
    """Check that the ``IndexStatistics.index_type`` Literal enumerates every index type.

    This guards against adding an index type (e.g. ``IVF_HNSW_FLAT``) without
    also extending the type hint that type checkers such as pyright see.
    """
    vector_index_types = {
        "IVF_FLAT",
        "IVF_SQ",
        "IVF_PQ",
        "IVF_RQ",
        "IVF_HNSW_SQ",
        "IVF_HNSW_PQ",
        "IVF_HNSW_FLAT",
    }
    other_index_types = {"FTS", "BTREE", "BITMAP", "LABEL_LIST"}

    hint = get_type_hints(IndexStatistics)["index_type"]
    assert set(get_args(hint)) == vector_index_types | other_index_types
|