Adding implem for filter-vec for neon as suggested by Adam.

This commit is contained in:
Paul Masurel
2026-05-27 09:46:43 +02:00
parent 46b3fb9ed3
commit 465a761b2f
5 changed files with 467 additions and 8 deletions

View File

@@ -14,7 +14,6 @@ mod tests {
let mut bitpacker = BitPacker::new();
let mut buffer = Vec::new();
for _ in 0..num_els {
// the values do not matter.
bitpacker.write(0u64, bit_width, &mut buffer).unwrap();
bitpacker.flush(&mut buffer).unwrap();
}
@@ -62,4 +61,124 @@ mod tests {
blocked_bitpacker
});
}
// --- filter_vec benchmarks ---
//
// We use a large N so that the vec clone is a smaller fraction of the total time,
// and so L2/L3 cache effects are representative of real workloads.
// Values are spread uniformly in [0, MAX_VAL].
const N: usize = 100_000;
const MAX_VAL: u32 = 1_000;
fn make_values(n: usize, max_val: u32) -> Vec<u32> {
(0..n as u32)
.map(|i| (i as u64 * max_val as u64 / n as u64) as u32)
.collect()
}
#[bench]
fn bench_filter_vec_dense(b: &mut Bencher) {
// ~50% of values match [250, 750]
let vals = make_values(N, MAX_VAL);
b.iter(|| {
let mut v = vals.clone();
tantivy_bitpacker::filter_vec::filter_vec_in_place(250..=750, 0, &mut v);
v
});
}
#[bench]
fn bench_filter_vec_sparse(b: &mut Bencher) {
// ~5% of values match [0, 50]
let vals = make_values(N, MAX_VAL);
b.iter(|| {
let mut v = vals.clone();
tantivy_bitpacker::filter_vec::filter_vec_in_place(0..=50, 0, &mut v);
v
});
}
#[bench]
fn bench_filter_vec_full(b: &mut Bencher) {
// 100% of values match
let vals = make_values(N, MAX_VAL);
b.iter(|| {
let mut v = vals.clone();
tantivy_bitpacker::filter_vec::filter_vec_in_place(0..=MAX_VAL, 0, &mut v);
v
});
}
#[bench]
fn bench_filter_vec_scalar_dense(b: &mut Bencher) {
let vals = make_values(N, MAX_VAL);
b.iter(|| {
let mut v = vals.clone();
tantivy_bitpacker::filter_vec::scalar_filter_vec_in_place(250..=750, 0, &mut v);
v
});
}
#[bench]
fn bench_filter_vec_scalar_sparse(b: &mut Bencher) {
let vals = make_values(N, MAX_VAL);
b.iter(|| {
let mut v = vals.clone();
tantivy_bitpacker::filter_vec::scalar_filter_vec_in_place(0..=50, 0, &mut v);
v
});
}
// --- NEON / SVE / SVE2 (aarch64 only) ---
#[bench]
#[cfg(target_arch = "aarch64")]
fn bench_filter_vec_neon_dense(b: &mut Bencher) {
let vals = make_values(N, MAX_VAL);
b.iter(|| {
let mut v = vals.clone();
tantivy_bitpacker::filter_vec::neon_filter_vec_in_place(250..=750, 0, &mut v);
v
});
}
#[bench]
#[cfg(target_arch = "aarch64")]
fn bench_filter_vec_neon_sparse(b: &mut Bencher) {
let vals = make_values(N, MAX_VAL);
b.iter(|| {
let mut v = vals.clone();
tantivy_bitpacker::filter_vec::neon_filter_vec_in_place(0..=50, 0, &mut v);
v
});
}
#[bench]
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
fn bench_filter_vec_sve_dense(b: &mut Bencher) {
if !std::arch::is_aarch64_feature_detected!("sve") {
return;
}
let vals = make_values(N, MAX_VAL);
b.iter(|| {
let mut v = vals.clone();
tantivy_bitpacker::filter_vec::sve_filter_vec_in_place(250..=750, 0, &mut v);
v
});
}
#[bench]
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
fn bench_filter_vec_sve_sparse(b: &mut Bencher) {
if !std::arch::is_aarch64_feature_detected!("sve") {
return;
}
let vals = make_values(N, MAX_VAL);
b.iter(|| {
let mut v = vals.clone();
tantivy_bitpacker::filter_vec::sve_filter_vec_in_place(0..=50, 0, &mut v);
v
});
}
}

View File

@@ -1,8 +1,17 @@
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
use std::arch::is_aarch64_feature_detected;
use std::ops::RangeInclusive;
#[cfg(target_arch = "x86_64")]
mod avx2;
#[cfg(target_arch = "aarch64")]
mod neon;
// SVE intrinsics are not exposed on aarch64-apple-darwin; only include on Linux/other.
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
mod sve;
mod scalar;
#[derive(Clone, Copy, Eq, PartialEq, Debug)]
@@ -10,6 +19,10 @@ mod scalar;
enum FilterImplPerInstructionSet {
#[cfg(target_arch = "x86_64")]
AVX2 = 0u8,
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
Sve = 3u8,
#[cfg(target_arch = "aarch64")]
Neon = 2u8,
Scalar = 1u8,
}
@@ -19,29 +32,56 @@ impl FilterImplPerInstructionSet {
match *self {
#[cfg(target_arch = "x86_64")]
FilterImplPerInstructionSet::AVX2 => is_x86_feature_detected!("avx2"),
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
FilterImplPerInstructionSet::Sve => is_aarch64_feature_detected!("sve"),
#[cfg(target_arch = "aarch64")]
FilterImplPerInstructionSet::Neon => true,
FilterImplPerInstructionSet::Scalar => true,
}
}
}
// List of available implementation in preferred order.
// List of available implementations in preferred order.
#[cfg(target_arch = "x86_64")]
const IMPLS: [FilterImplPerInstructionSet; 2] = [
FilterImplPerInstructionSet::AVX2,
FilterImplPerInstructionSet::Scalar,
];
#[cfg(not(target_arch = "x86_64"))]
// Non-Apple aarch64 (Graviton, etc.): try SVE, NEON, Scalar.
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
const IMPLS: [FilterImplPerInstructionSet; 3] = [
FilterImplPerInstructionSet::Sve,
FilterImplPerInstructionSet::Neon,
FilterImplPerInstructionSet::Scalar,
];
// Apple aarch64 (M-series): SVE not available; use NEON or Scalar.
#[cfg(all(target_arch = "aarch64", target_vendor = "apple"))]
const IMPLS: [FilterImplPerInstructionSet; 2] = [
FilterImplPerInstructionSet::Neon,
FilterImplPerInstructionSet::Scalar,
];
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
const IMPLS: [FilterImplPerInstructionSet; 1] = [FilterImplPerInstructionSet::Scalar];
impl FilterImplPerInstructionSet {
#[inline]
#[allow(unused_variables)] // on non-x86_64, code is unused.
#[allow(unused_variables)]
fn from(code: u8) -> FilterImplPerInstructionSet {
#[cfg(target_arch = "x86_64")]
if code == FilterImplPerInstructionSet::AVX2 as u8 {
return FilterImplPerInstructionSet::AVX2;
}
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
if code == FilterImplPerInstructionSet::Sve as u8 {
return FilterImplPerInstructionSet::Sve;
}
#[cfg(target_arch = "aarch64")]
if code == FilterImplPerInstructionSet::Neon as u8 {
return FilterImplPerInstructionSet::Neon;
}
FilterImplPerInstructionSet::Scalar
}
@@ -50,6 +90,10 @@ impl FilterImplPerInstructionSet {
match self {
#[cfg(target_arch = "x86_64")]
FilterImplPerInstructionSet::AVX2 => avx2::filter_vec_in_place(range, offset, output),
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
FilterImplPerInstructionSet::Sve => sve::filter_vec_in_place(range, offset, output),
#[cfg(target_arch = "aarch64")]
FilterImplPerInstructionSet::Neon => neon::filter_vec_in_place(range, offset, output),
FilterImplPerInstructionSet::Scalar => {
scalar::filter_vec_in_place(range, offset, output)
}
@@ -63,7 +107,6 @@ fn get_best_available_instruction_set() -> FilterImplPerInstructionSet {
static INSTRUCTION_SET_BYTE: AtomicU8 = AtomicU8::new(u8::MAX);
let instruction_set_byte: u8 = INSTRUCTION_SET_BYTE.load(Ordering::Relaxed);
if instruction_set_byte == u8::MAX {
// Let's initialize the instruction set and cache it.
let instruction_set = IMPLS
.into_iter()
.find(FilterImplPerInstructionSet::is_available)
@@ -78,14 +121,29 @@ pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut
get_best_available_instruction_set().filter_vec_in_place(range, offset, output)
}
#[doc(hidden)]
pub fn scalar_filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
scalar::filter_vec_in_place(range, offset, output);
}
#[doc(hidden)]
#[cfg(target_arch = "aarch64")]
pub fn neon_filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
neon::filter_vec_in_place(range, offset, output);
}
#[doc(hidden)]
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
pub fn sve_filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
sve::filter_vec_in_place(range, offset, output);
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_get_best_available_instruction_set() {
// This does not test much unfortunately.
// We just make sure the function returns without crashing and returns the same result.
let instruction_set = get_best_available_instruction_set();
assert_eq!(get_best_available_instruction_set(), instruction_set);
}
@@ -102,6 +160,31 @@ mod tests {
}
}
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
#[test]
fn test_instruction_set_to_code_from_code() {
for instruction_set in [
FilterImplPerInstructionSet::Sve,
FilterImplPerInstructionSet::Neon,
FilterImplPerInstructionSet::Scalar,
] {
let code = instruction_set as u8;
assert_eq!(instruction_set, FilterImplPerInstructionSet::from(code));
}
}
#[cfg(all(target_arch = "aarch64", target_vendor = "apple"))]
#[test]
fn test_instruction_set_to_code_from_code() {
for instruction_set in [
FilterImplPerInstructionSet::Neon,
FilterImplPerInstructionSet::Scalar,
] {
let code = instruction_set as u8;
assert_eq!(instruction_set, FilterImplPerInstructionSet::from(code));
}
}
fn test_filter_impl_empty_aux(filter_impl: FilterImplPerInstructionSet) {
let mut output = vec![];
filter_impl.filter_vec_in_place(0..=u32::MAX, 0, &mut output);
@@ -141,6 +224,20 @@ mod tests {
}
}
#[test]
#[cfg(all(target_arch = "aarch64", not(target_vendor = "apple")))]
fn test_filter_implementation_sve() {
if FilterImplPerInstructionSet::Sve.is_available() {
test_filter_impl_test_suite(FilterImplPerInstructionSet::Sve);
}
}
#[test]
#[cfg(target_arch = "aarch64")]
fn test_filter_implementation_neon() {
test_filter_impl_test_suite(FilterImplPerInstructionSet::Neon);
}
#[test]
fn test_filter_implementation_scalar() {
test_filter_impl_test_suite(FilterImplPerInstructionSet::Scalar);
@@ -162,4 +259,19 @@ mod tests {
}
}
}
#[cfg(target_arch = "aarch64")]
proptest::proptest! {
#[test]
fn test_filter_compare_scalar_and_neon_impl_proptest(
start in proptest::prelude::any::<u32>(),
end in proptest::prelude::any::<u32>(),
offset in 0u32..2u32,
mut vals in proptest::collection::vec(0..u32::MAX, 0..30)) {
let mut vals_clone = vals.clone();
FilterImplPerInstructionSet::Neon.filter_vec_in_place(start..=end, offset, &mut vals);
FilterImplPerInstructionSet::Scalar.filter_vec_in_place(start..=end, offset, &mut vals_clone);
assert_eq!(&vals, &vals_clone);
}
}
}

View File

@@ -0,0 +1,113 @@
use std::arch::aarch64::*;
use std::ops::RangeInclusive;
const NUM_LANES: usize = 4;
// Compacts matching lanes to the front using a byte-level shuffle.
// `mask` is a 4-bit value: bit k=1 means lane k should appear in the output.
#[inline]
#[target_feature(enable = "neon")]
unsafe fn compact(data: uint32x4_t, mask: u8) -> uint32x4_t {
unsafe {
// SAFETY: mask is always in [0, 15] by construction (max sum of [1,2,4,8]).
// BYTE_SHUFFLE_TABLE has 16 entries, so this is always in bounds.
let shuffle = BYTE_SHUFFLE_TABLE.get_unchecked(mask as usize);
let shuffle_vec = vld1q_u8(shuffle.as_ptr());
vreinterpretq_u32_u8(vqtbl1q_u8(vreinterpretq_u8_u32(data), shuffle_vec))
}
}
#[inline(never)]
pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
let num_words = output.len() / NUM_LANES;
let mut output_len = unsafe {
filter_vec_neon_aux(
output.as_ptr(),
range.clone(),
output.as_mut_ptr(),
offset,
num_words,
)
};
let remainder_start = num_words * NUM_LANES;
for i in remainder_start..output.len() {
let val = output[i];
output[output_len] = offset + i as u32;
output_len += if range.contains(&val) { 1 } else { 0 };
}
output.truncate(output_len);
}
#[target_feature(enable = "neon")]
unsafe fn filter_vec_neon_aux(
input: *const u32,
range: RangeInclusive<u32>,
output: *mut u32,
offset: u32,
num_words: usize,
) -> usize {
unsafe {
let mut input = input;
let mut output_tail = output;
let range_start_simd = vdupq_n_u32(*range.start());
let range_end_simd = vdupq_n_u32(*range.end());
let mut ids = vld1q_u32([offset, offset + 1, offset + 2, offset + 3].as_ptr());
let shift = vdupq_n_u32(NUM_LANES as u32);
let bit_weights = vld1q_u32([1u32, 2, 4, 8].as_ptr());
for _ in 0..num_words {
let word = vld1q_u32(input);
// Unsigned compares: CMHS (compare higher or same) tests `word >= start`
// and `end >= word`. ANDing both gives the inside-range mask directly,
// which is cheaper than computing `outside` and then negating.
let ge_start = vcgeq_u32(word, range_start_simd);
let le_end = vcleq_u32(word, range_end_simd);
// inside[k] = 0xFFFFFFFF if val[k] is in range, 0 otherwise.
let inside = vandq_u32(ge_start, le_end);
// Build the 4-bit mask: AND bit_weights with the inside lane mask, so each
// inside lane contributes its bit_weight (1, 2, 4, or 8). Summing yields the
// 4-bit mask in one addv.
let inside_bits = vandq_u32(bit_weights, inside);
let mask = vaddvq_u32(inside_bits) as u8;
// mask is mathematically bounded: max value is 1+2+4+8=15 (all lanes match)
debug_assert!(mask <= 15, "mask must fit in 4 bits: {}", mask);
// Count of matching lanes = popcount(mask). Derives the count directly from
// the mask instead of running a parallel SIMD reduction over `outside`.
let added_len = mask.count_ones() as usize;
// Safe because mask is guaranteed to be in [0, 15]
let filtered_ids = compact(ids, mask);
vst1q_u32(output_tail, filtered_ids);
output_tail = output_tail.add(added_len);
ids = vaddq_u32(ids, shift);
input = input.add(NUM_LANES);
}
output_tail.offset_from(output) as usize
}
}
// Byte shuffle patterns to compact matching lanes to the front of the vector.
// Index is a 4-bit mask: bit k=1 means lane k (bytes 4k..4k+3) is in-range.
// The j-th set bit determines which input lane goes to output position j.
const BYTE_SHUFFLE_TABLE: [[u8; 16]; 16] = [
[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], // 0b0000: none
[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], // 0b0001: lane 0
[4, 5, 6, 7, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], // 0b0010: lane 1
[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 0, 1, 2, 3], // 0b0011: lanes 0,1
[8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], // 0b0100: lane 2
[0, 1, 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3], // 0b0101: lanes 0,2
[4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3], // 0b0110: lanes 1,2
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3], // 0b0111: lanes 0,1,2
[12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], // 0b1000: lane 3
[0, 1, 2, 3, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3], // 0b1001: lanes 0,3
[4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3], // 0b1010: lanes 1,3
[0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3], // 0b1011: lanes 0,1,3
[8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3], // 0b1100: lanes 2,3
[0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3], // 0b1101: lanes 0,2,3
[4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3], // 0b1110: lanes 1,2,3
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], // 0b1111: all lanes
];

View File

@@ -0,0 +1,108 @@
use std::arch::aarch64::*;
use std::ops::RangeInclusive;
// SVE vector length (in u32 lanes) is not a compile-time constant; query at runtime.
// Safe to call only when SVE is confirmed available via is_aarch64_feature_detected!("sve").
#[target_feature(enable = "sve")]
fn num_lanes() -> usize {
svcntw() as usize
}
pub fn filter_vec_in_place(range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
let vl = unsafe { num_lanes() };
let num_words = output.len() / vl;
let range_start = *range.start();
// Unsigned subtraction trick: val ∈ [lo, hi] ↔ (val - lo) ≤ᵤ (hi - lo).
// Values below lo wrap around to large u32, so the single unsigned ≤ excludes them.
let range_width = range.end().wrapping_sub(range_start);
let mut output_len = unsafe {
filter_vec_sve_aux(
output.as_ptr(),
range_start,
range_width,
output.as_mut_ptr(),
offset,
num_words,
vl,
)
};
let remainder_start = num_words * vl;
for i in remainder_start..output.len() {
let val = output[i];
output[output_len] = offset + i as u32;
output_len += if range.contains(&val) { 1 } else { 0 };
}
output.truncate(output_len);
}
#[target_feature(enable = "sve")]
unsafe fn filter_vec_sve_aux(
input: *const u32,
range_start: u32,
range_width: u32,
output: *mut u32,
offset: u32,
num_words: usize,
vl: usize,
) -> usize {
unsafe {
let all_true = svptrue_b32();
let range_start_simd = svdup_n_u32(range_start);
let range_width_simd = svdup_n_u32(range_width);
// ids_a covers [offset .. offset+vl), ids_b covers the next vl ids.
// Keeping them separate breaks the loop-carried dependency through ids so
// both compact/cntp chains are fully independent within each unrolled body.
let mut ids_a = svindex_u32(offset, 1);
let step = svdup_n_u32(vl as u32);
let step2 = svdup_n_u32(2 * vl as u32);
let mut ids_b = svadd_u32_x(all_true, ids_a, step);
let mut input = input;
let mut output_tail = output;
// Unrolled ×2: both cntp calls have independent inputs and execute in parallel.
// The two output_tail updates are sequential but together cost 4+1+1=6 cy per
// pair vs 5+5=10 cy for two scalar iterations, breaking the cntp latency chain.
let num_pairs = num_words / 2;
for _ in 0..num_pairs {
let word_a = svld1_u32(all_true, input);
let word_b = svld1_u32(all_true, input.add(vl));
let shifted_a = svsub_u32_x(all_true, word_a, range_start_simd);
let shifted_b = svsub_u32_x(all_true, word_b, range_start_simd);
let in_range_a = svcmple_u32(all_true, shifted_a, range_width_simd);
let in_range_b = svcmple_u32(all_true, shifted_b, range_width_simd);
let compacted_a = svcompact_u32(in_range_a, ids_a);
let compacted_b = svcompact_u32(in_range_b, ids_b);
// cntp_a and cntp_b have independent inputs: OOO engine issues them in parallel.
let added_len_a = svcntp_b32(all_true, in_range_a) as usize;
let added_len_b = svcntp_b32(all_true, in_range_b) as usize;
// Write the full vector — only the first added_len slots are valid.
// Subsequent iterations overwrite the trailing zeros before truncate.
svst1_u32(all_true, output_tail, compacted_a);
output_tail = output_tail.add(added_len_a);
svst1_u32(all_true, output_tail, compacted_b);
output_tail = output_tail.add(added_len_b);
ids_a = svadd_u32_x(all_true, ids_a, step2);
ids_b = svadd_u32_x(all_true, ids_b, step2);
input = input.add(2 * vl);
}
// Handle an odd trailing word.
if num_words % 2 == 1 {
let word = svld1_u32(all_true, input);
let shifted = svsub_u32_x(all_true, word, range_start_simd);
let in_range = svcmple_u32(all_true, shifted, range_width_simd);
let added_len = svcntp_b32(all_true, in_range) as usize;
let compacted_ids = svcompact_u32(in_range, ids_a);
svst1_u32(all_true, output_tail, compacted_ids);
output_tail = output_tail.add(added_len);
}
output_tail.offset_from(output) as usize
}
}

View File

@@ -1,6 +1,13 @@
// SVE/SVE2 intrinsics are nightly-only; enable them on non-Apple aarch64 targets.
#![cfg_attr(
all(target_arch = "aarch64", not(target_vendor = "apple")),
feature(stdarch_aarch64_sve)
)]
mod bitpacker;
mod blocked_bitpacker;
mod filter_vec;
#[doc(hidden)]
pub mod filter_vec;
use std::cmp::Ordering;