feat: support IVF_FLAT, binary vectors and hamming distance (#1955)

Binary vectors and Hamming distance are only supported by the IVF_FLAT index,
so all three are introduced together in this PR.

---------

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
BubbleCal
2024-12-25 02:36:20 +08:00
committed by GitHub
parent ac0068b80e
commit e70fd4fecc
14 changed files with 390 additions and 35 deletions

View File

@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use lancedb::index::vector::IvfFlatIndexBuilder;
use lancedb::index::{
scalar::{BTreeIndexBuilder, FtsIndexBuilder, TokenizerConfig},
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
@@ -59,6 +60,18 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
opts.tokenizer_configs = inner_opts;
Ok(LanceDbIndex::FTS(opts))
},
"IvfFlat" => {
let params = source.extract::<IvfFlatParams>()?;
let distance_type = parse_distance_type(params.distance_type)?;
let mut ivf_flat_builder = IvfFlatIndexBuilder::default()
.distance_type(distance_type)
.max_iterations(params.max_iterations)
.sample_rate(params.sample_rate);
if let Some(num_partitions) = params.num_partitions {
ivf_flat_builder = ivf_flat_builder.num_partitions(num_partitions);
}
Ok(LanceDbIndex::IvfFlat(ivf_flat_builder))
},
"IvfPq" => {
let params = source.extract::<IvfPqParams>()?;
let distance_type = parse_distance_type(params.distance_type)?;
@@ -129,6 +142,14 @@ struct FtsParams {
ascii_folding: bool,
}
/// Parameters for building an IVF_FLAT index, extracted from the Python-side
/// configuration object through the derived `FromPyObject` implementation.
#[derive(FromPyObject)]
struct IvfFlatParams {
/// Name of the distance metric; converted with `parse_distance_type` before
/// being handed to the index builder.
distance_type: String,
/// Optional partition count; only applied to the builder when present.
num_partitions: Option<u32>,
/// Forwarded to the builder's `max_iterations` setting.
max_iterations: u32,
/// Forwarded to the builder's `sample_rate` setting.
sample_rate: u32,
}
#[derive(FromPyObject)]
struct IvfPqParams {
distance_type: String,

View File

@@ -43,8 +43,9 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> PyResult<DistanceT
"l2" => Ok(DistanceType::L2),
"cosine" => Ok(DistanceType::Cosine),
"dot" => Ok(DistanceType::Dot),
"hamming" => Ok(DistanceType::Hamming),
_ => Err(PyValueError::new_err(format!(
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
"Invalid distance type '{}'. Must be one of l2, cosine, dot, or hamming",
distance_type.as_ref()
))),
}