mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-03 02:12:56 +00:00
feat: support 4bit PQ (#1916)
This commit is contained in:
@@ -178,6 +178,12 @@ class HnswPq:
|
||||
If the dimension is not divisible by 8 then we use 1 subvector. This is not
|
||||
ideal and will likely result in poor performance.
|
||||
|
||||
num_bits: int, default 8
|
||||
Number of bits to encode each sub-vector.
|
||||
|
||||
This value controls how much the sub-vectors are compressed. The more bits
|
||||
the more accurate the index but the slower search. Only 4 and 8 are supported.
|
||||
|
||||
max_iterations: int, default 50
|
||||
|
||||
Max iterations to train kmeans.
|
||||
@@ -232,6 +238,7 @@ class HnswPq:
|
||||
distance_type: Optional[str] = None,
|
||||
num_partitions: Optional[int] = None,
|
||||
num_sub_vectors: Optional[int] = None,
|
||||
num_bits: Optional[int] = None,
|
||||
max_iterations: Optional[int] = None,
|
||||
sample_rate: Optional[int] = None,
|
||||
m: Optional[int] = None,
|
||||
@@ -241,6 +248,7 @@ class HnswPq:
|
||||
distance_type=distance_type,
|
||||
num_partitions=num_partitions,
|
||||
num_sub_vectors=num_sub_vectors,
|
||||
num_bits=num_bits,
|
||||
max_iterations=max_iterations,
|
||||
sample_rate=sample_rate,
|
||||
m=m,
|
||||
@@ -387,6 +395,7 @@ class IvfPq:
|
||||
distance_type: Optional[str] = None,
|
||||
num_partitions: Optional[int] = None,
|
||||
num_sub_vectors: Optional[int] = None,
|
||||
num_bits: Optional[int] = None,
|
||||
max_iterations: Optional[int] = None,
|
||||
sample_rate: Optional[int] = None,
|
||||
):
|
||||
@@ -449,6 +458,12 @@ class IvfPq:
|
||||
|
||||
If the dimension is not divisible by 8 then we use 1 subvector. This is not
|
||||
ideal and will likely result in poor performance.
|
||||
num_bits: int, default 8
|
||||
Number of bits to encode each sub-vector.
|
||||
|
||||
This value controls how much the sub-vectors are compressed. The more bits
|
||||
the more accurate the index but the slower search. The default is 8
|
||||
bits. Only 4 and 8 are supported.
|
||||
max_iterations: int, default 50
|
||||
Max iterations to train kmeans.
|
||||
|
||||
@@ -482,6 +497,7 @@ class IvfPq:
|
||||
distance_type=distance_type,
|
||||
num_partitions=num_partitions,
|
||||
num_sub_vectors=num_sub_vectors,
|
||||
num_bits=num_bits,
|
||||
max_iterations=max_iterations,
|
||||
sample_rate=sample_rate,
|
||||
)
|
||||
|
||||
@@ -413,6 +413,8 @@ class Table(ABC):
|
||||
replace: bool = True,
|
||||
accelerator: Optional[str] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
*,
|
||||
num_bits: int = 8,
|
||||
):
|
||||
"""Create an index on the table.
|
||||
|
||||
@@ -439,6 +441,9 @@ class Table(ABC):
|
||||
Only support "cuda" for now.
|
||||
index_cache_size : int, optional
|
||||
The size of the index cache in number of entries. Default value is 256.
|
||||
num_bits: int
|
||||
The number of bits to encode sub-vectors. Only used with the IVF_PQ index.
|
||||
Only 4 and 8 are supported.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1430,6 +1435,8 @@ class LanceTable(Table):
|
||||
accelerator: Optional[str] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
index_type="IVF_PQ",
|
||||
*,
|
||||
num_bits: int = 8,
|
||||
):
|
||||
"""Create an index on the table."""
|
||||
self._dataset_mut.create_index(
|
||||
@@ -1441,6 +1448,7 @@ class LanceTable(Table):
|
||||
replace=replace,
|
||||
accelerator=accelerator,
|
||||
index_cache_size=index_cache_size,
|
||||
num_bits=num_bits,
|
||||
)
|
||||
|
||||
def create_scalar_index(
|
||||
|
||||
@@ -108,6 +108,29 @@ async def test_create_vector_index(some_table: AsyncTable):
|
||||
assert stats.num_indices == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_4bit_ivfpq_index(some_table: AsyncTable):
|
||||
# Can create
|
||||
await some_table.create_index("vector", config=IvfPq(num_bits=4))
|
||||
# Can recreate if replace=True
|
||||
await some_table.create_index("vector", config=IvfPq(num_bits=4), replace=True)
|
||||
# Can't recreate if replace=False
|
||||
with pytest.raises(RuntimeError, match="already exists"):
|
||||
await some_table.create_index("vector", replace=False)
|
||||
indices = await some_table.list_indices()
|
||||
assert len(indices) == 1
|
||||
assert indices[0].index_type == "IvfPq"
|
||||
assert indices[0].columns == ["vector"]
|
||||
assert indices[0].name == "vector_idx"
|
||||
|
||||
stats = await some_table.index_stats("vector_idx")
|
||||
assert stats.index_type == "IVF_PQ"
|
||||
assert stats.distance_type == "l2"
|
||||
assert stats.num_indexed_rows == await some_table.count_rows()
|
||||
assert stats.num_unindexed_rows == 0
|
||||
assert stats.num_indices == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_hnswpq_index(some_table: AsyncTable):
|
||||
await some_table.create_index("vector", config=HnswPq(num_partitions=10))
|
||||
|
||||
@@ -530,6 +530,7 @@ def test_create_index_method():
|
||||
replace=True,
|
||||
accelerator=None,
|
||||
index_cache_size=256,
|
||||
num_bits=8,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user