feat: support 4bit PQ (#1916)

This commit is contained in:
BubbleCal
2024-12-10 10:36:03 +08:00
committed by GitHub
parent ab5316b4fa
commit 3324e7d525
10 changed files with 94 additions and 5 deletions

View File

@@ -47,12 +47,13 @@ impl Index {
#[pymethods]
impl Index {
#[pyo3(signature = (distance_type=None, num_partitions=None, num_sub_vectors=None, max_iterations=None, sample_rate=None))]
#[pyo3(signature = (distance_type=None, num_partitions=None, num_sub_vectors=None,num_bits=None, max_iterations=None, sample_rate=None))]
#[staticmethod]
pub fn ivf_pq(
distance_type: Option<String>,
num_partitions: Option<u32>,
num_sub_vectors: Option<u32>,
num_bits: Option<u32>,
max_iterations: Option<u32>,
sample_rate: Option<u32>,
) -> PyResult<Self> {
@@ -75,6 +76,9 @@ impl Index {
if let Some(num_sub_vectors) = num_sub_vectors {
ivf_pq_builder = ivf_pq_builder.num_sub_vectors(num_sub_vectors);
}
if let Some(num_bits) = num_bits {
ivf_pq_builder = ivf_pq_builder.num_bits(num_bits);
}
if let Some(max_iterations) = max_iterations {
ivf_pq_builder = ivf_pq_builder.max_iterations(max_iterations);
}
@@ -148,12 +152,14 @@ impl Index {
}
}
#[pyo3(signature = (distance_type=None, num_partitions=None, num_sub_vectors=None, max_iterations=None, sample_rate=None, m=None, ef_construction=None))]
#[pyo3(signature = (distance_type=None, num_partitions=None, num_sub_vectors=None,num_bits=None, max_iterations=None, sample_rate=None, m=None, ef_construction=None))]
#[staticmethod]
#[allow(clippy::too_many_arguments)]
pub fn hnsw_pq(
distance_type: Option<String>,
num_partitions: Option<u32>,
num_sub_vectors: Option<u32>,
num_bits: Option<u32>,
max_iterations: Option<u32>,
sample_rate: Option<u32>,
m: Option<u32>,
@@ -170,6 +176,9 @@ impl Index {
if let Some(num_sub_vectors) = num_sub_vectors {
hnsw_pq_builder = hnsw_pq_builder.num_sub_vectors(num_sub_vectors);
}
if let Some(num_bits) = num_bits {
hnsw_pq_builder = hnsw_pq_builder.num_bits(num_bits);
}
if let Some(max_iterations) = max_iterations {
hnsw_pq_builder = hnsw_pq_builder.max_iterations(max_iterations);
}