feat: support IVF_FLAT, binary vectors and hamming distance (#1955)

Binary vectors and Hamming distance currently work only with the IVF_FLAT
index, so this PR introduces all three together.

---------

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
BubbleCal
2024-12-25 02:36:20 +08:00
committed by GitHub
parent ac0068b80e
commit e70fd4fecc
14 changed files with 390 additions and 35 deletions

View File

@@ -0,0 +1,44 @@
import shutil
# --8<-- [start:imports]
import lancedb
import numpy as np
import pytest
# --8<-- [end:imports]
# Remove any database directory left behind by an earlier run before the tests
# below connect to it; ignore_errors=True makes a missing directory a no-op.
shutil.rmtree("data/binary_lancedb", ignore_errors=True)
def test_binary_vector():
    """Run the synchronous binary-vector example end to end.

    Creates a 1024-row table whose ``vector`` column holds 16 random integers
    in ``[0, 256)`` per row, issues a single search with another random
    vector, and finally drops the table.

    The ``# --8<--`` comments look like snippet-section markers (presumably
    consumed by the documentation build); keep them wrapped around the
    example code.

    NOTE(review): ``np.random.randint`` is called without ``dtype``, so the
    arrays are not explicitly ``uint8`` -- confirm the resulting column is
    really stored as a binary (uint8) vector.
    """
    # --8<-- [start:sync_binary_vector]
    db = lancedb.connect("data/binary_lancedb")
    rows = []
    for row_id in range(1024):
        vector = np.random.randint(0, 256, size=16)
        rows.append({"id": row_id, "vector": vector})
    table = db.create_table("my_binary_vectors", data=rows)
    query_vector = np.random.randint(0, 256, size=16)
    table.search(query_vector).to_arrow()
    # --8<-- [end:sync_binary_vector]
    db.drop_table("my_binary_vectors")
@pytest.mark.asyncio
async def test_binary_vector_async():
    """Run the asynchronous binary-vector example end to end.

    Mirrors the synchronous example with the async API: connects with
    ``connect_async``, creates a 1024-row table of 16-element random integer
    vectors in ``[0, 256)``, runs one ``nearest_to`` query, and drops the
    table.

    The ``# --8<--`` comments look like snippet-section markers (presumably
    consumed by the documentation build); keep them wrapped around the
    example code.

    NOTE(review): ``np.random.randint`` is called without ``dtype``, so the
    arrays are not explicitly ``uint8`` -- confirm the resulting column is
    really stored as a binary (uint8) vector.
    """
    # --8<-- [start:async_binary_vector]
    db = await lancedb.connect_async("data/binary_lancedb")
    rows = []
    for row_id in range(1024):
        vector = np.random.randint(0, 256, size=16)
        rows.append({"id": row_id, "vector": vector})
    table = await db.create_table("my_binary_vectors", data=rows)
    query_vector = np.random.randint(0, 256, size=16)
    await table.query().nearest_to(query_vector).to_arrow()
    # --8<-- [end:async_binary_vector]
    await db.drop_table("my_binary_vectors")

View File

@@ -8,7 +8,7 @@ import pyarrow as pa
import pytest
import pytest_asyncio
from lancedb import AsyncConnection, AsyncTable, connect_async
from lancedb.index import BTree, IvfPq, Bitmap, LabelList, HnswPq, HnswSq
from lancedb.index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq
@pytest_asyncio.fixture
@@ -42,6 +42,27 @@ async def some_table(db_async):
)
@pytest_asyncio.fixture
async def binary_table(db_async):
    """Create and return the ``binary_table`` table used by binary-vector tests.

    Row ``i`` (for ``i`` in ``range(NROWS)``) has ``id == i`` and a ``vector``
    of 128 entries that all equal ``i``.  The explicit schema declares ``id``
    as ``int64`` and ``vector`` as a 128-element list of ``uint8``.
    """
    schema = pa.schema(
        [
            pa.field("id", pa.int64()),
            pa.field("vector", pa.list_(pa.uint8(), 128)),
        ]
    )
    rows = [{"id": row_id, "vector": [row_id] * 128} for row_id in range(NROWS)]
    return await db_async.create_table("binary_table", rows, schema=schema)
@pytest.mark.asyncio
async def test_create_scalar_index(some_table: AsyncTable):
# Can create
@@ -143,3 +164,27 @@ async def test_create_hnswsq_index(some_table: AsyncTable):
await some_table.create_index("vector", config=HnswSq(num_partitions=10))
indices = await some_table.list_indices()
assert len(indices) == 1
@pytest.mark.asyncio
async def test_create_index_with_binary_vectors(binary_table: AsyncTable):
    """Create an IvfFlat index with hamming distance and check the result.

    Verifies the index listing (type, columns, name), the index statistics
    (type, distance type, row counts, number of indices), and that querying
    with a constant vector returns the row whose id equals that constant.
    """
    index_config = IvfFlat(distance_type="hamming", num_partitions=10)
    await binary_table.create_index("vector", config=index_config)

    indices = await binary_table.list_indices()
    assert len(indices) == 1
    index = indices[0]
    assert index.index_type == "IvfFlat"
    assert index.columns == ["vector"]
    assert index.name == "vector_idx"

    stats = await binary_table.index_stats("vector_idx")
    assert stats.index_type == "IVF_FLAT"
    assert stats.distance_type == "hamming"
    assert stats.num_indexed_rows == await binary_table.count_rows()
    assert stats.num_unindexed_rows == 0
    assert stats.num_indices == 1

    # The table is expected to hold one constant vector for every value from
    # 0 to 255 (this relies on NROWS, defined outside this view, being 256),
    # so the nearest neighbour of [value] * 128 should be the row with that id.
    for value in range(256):
        query_vector = [value] * 128
        nearest = await binary_table.query().nearest_to(query_vector).to_arrow()
        assert nearest["id"][0].as_py() == value