mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 23:12:58 +00:00
feat: support IVF_FLAT, binary vectors and hamming distance (#1955)
binary vectors and hamming distance can work on only IVF_FLAT, so introduce them all in this PR. --------- Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use lancedb::index::vector::IvfFlatIndexBuilder;
|
||||
use lancedb::index::{
|
||||
scalar::{BTreeIndexBuilder, FtsIndexBuilder, TokenizerConfig},
|
||||
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
|
||||
@@ -59,6 +60,18 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
|
||||
opts.tokenizer_configs = inner_opts;
|
||||
Ok(LanceDbIndex::FTS(opts))
|
||||
},
|
||||
"IvfFlat" => {
|
||||
let params = source.extract::<IvfFlatParams>()?;
|
||||
let distance_type = parse_distance_type(params.distance_type)?;
|
||||
let mut ivf_flat_builder = IvfFlatIndexBuilder::default()
|
||||
.distance_type(distance_type)
|
||||
.max_iterations(params.max_iterations)
|
||||
.sample_rate(params.sample_rate);
|
||||
if let Some(num_partitions) = params.num_partitions {
|
||||
ivf_flat_builder = ivf_flat_builder.num_partitions(num_partitions);
|
||||
}
|
||||
Ok(LanceDbIndex::IvfFlat(ivf_flat_builder))
|
||||
},
|
||||
"IvfPq" => {
|
||||
let params = source.extract::<IvfPqParams>()?;
|
||||
let distance_type = parse_distance_type(params.distance_type)?;
|
||||
@@ -129,6 +142,14 @@ struct FtsParams {
|
||||
ascii_folding: bool,
|
||||
}
|
||||
|
||||
#[derive(FromPyObject)]
|
||||
struct IvfFlatParams {
|
||||
distance_type: String,
|
||||
num_partitions: Option<u32>,
|
||||
max_iterations: u32,
|
||||
sample_rate: u32,
|
||||
}
|
||||
|
||||
#[derive(FromPyObject)]
|
||||
struct IvfPqParams {
|
||||
distance_type: String,
|
||||
|
||||
@@ -43,8 +43,9 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> PyResult<DistanceT
|
||||
"l2" => Ok(DistanceType::L2),
|
||||
"cosine" => Ok(DistanceType::Cosine),
|
||||
"dot" => Ok(DistanceType::Dot),
|
||||
"hamming" => Ok(DistanceType::Hamming),
|
||||
_ => Err(PyValueError::new_err(format!(
|
||||
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
|
||||
"Invalid distance type '{}'. Must be one of l2, cosine, dot, or hamming",
|
||||
distance_type.as_ref()
|
||||
))),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user