Compare commits


2 Commits

Author        SHA1        Message                                                Date
Pascal Seitz  5d53b11a2c  change debug_assert to assert in term length           2024-04-04 22:54:28 +08:00
                          (change `debug_assert` to `assert` since the bug with
                          the truncated Terms only occurs in prod)
PSeitz        4e79e11007  add collect_block to BoxableSegmentCollector (#2331)   2024-03-21 09:10:25 +01:00
7 changed files with 30 additions and 156 deletions
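
A note on the first commit's reasoning: `debug_assert!` is only compiled into builds with debug assertions enabled (the default for dev and test profiles), so in a release binary the length check vanishes and a truncated term slips through silently; hence the commit body's "only occurs in prod". Promoting it to `assert!` makes the check unconditional. A minimal sketch of the difference:

    fn check_term(serialized: &[u8]) {
        // Compiled out when debug_assertions are off: never fires in release builds.
        debug_assert!(serialized.len() >= 4);
        // Always compiled in: a truncated term also panics in production.
        assert!(serialized.len() >= 4, "term too short: {:?}", serialized);
    }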

View File

@@ -8,7 +8,6 @@ mod tests {
     use rand::thread_rng;
     use tantivy_bitpacker::{BitPacker, BitUnpacker, BlockedBitpacker};
     use test::Bencher;
-    use tantivy_bitpacker::filter_vec;
 
     #[inline(never)]
     fn create_bitpacked_data(bit_width: u8, num_els: u32) -> Vec<u8> {
@@ -63,40 +62,4 @@ mod tests {
             blocked_bitpacker
         });
     }
-
-    fn bench_filter_vec(//values: Vec<u32>,
-        filter_impl: filter_vec::FilterImplPerInstructionSet) -> u32{
-        let mut values = vec![0u32; 1_000_000];
-        //let mut values = values;
-        filter_impl.filter_vec_in_place(0..=10, 0, &mut values);
-        values[0]
-    }
-
-    #[bench]
-    fn bench_filter_vec_avx512(b: &mut Bencher) {
-        //let values = vec![0u32; 1_000_000];
-        if filter_vec::FilterImplPerInstructionSet::AVX512.is_available() {
-            b.iter(|| {
-                bench_filter_vec(filter_vec::FilterImplPerInstructionSet::AVX512)
-            });
-        }
-    }
-
-    #[bench]
-    fn bench_filter_vec_avx2(b: &mut Bencher) {
-        //let values = vec![0u32; 1_000_000];
-        if filter_vec::FilterImplPerInstructionSet::AVX2.is_available() {
-            b.iter(|| {
-                bench_filter_vec(filter_vec::FilterImplPerInstructionSet::AVX2)
-            });
-        }
-    }
-
-    #[bench]
-    fn bench_filter_vec_scalar(b: &mut Bencher) {
-        //let values = vec![0u32; 1_000_000];
-        if filter_vec::FilterImplPerInstructionSet::Scalar.is_available() {
-            b.iter(|| {
-                bench_filter_vec(filter_vec::FilterImplPerInstructionSet::Scalar)
-            });
-        }
-    }
 }
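
(The benches deleted above were the only code exercising `filter_vec` from outside the crate; they go because the AVX512 variant is removed and, as the lib.rs diff further down shows, `filter_vec` also stops being a public module, so a bench target could no longer reach it.)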

View File

@@ -1,7 +1,5 @@
-// Copyright 2024 The Tantivy Authors. All Rights Reserved.
-// https://quickwit.io/blog/simd-range
+//! SIMD filtering of a vector as described in the following blog post.
+//! <https://quickwit.io/blog/filtering%20a%20vector%20with%20simd%20instructions%20avx-2%20and%20avx-512>
 use std::arch::x86_64::{
     __m256i as DataType, _mm256_add_epi32 as op_add, _mm256_cmpgt_epi32 as op_greater,
     _mm256_lddqu_si256 as load_unaligned, _mm256_or_si256 as op_or, _mm256_set1_epi32 as set1,
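
Only the module doc changes here, but the imports show the constraint the AVX2 path works around: AVX2 offers a signed greater-than compare (`_mm256_cmpgt_epi32`) but no less-or-equal, so a range test has to be phrased through its negation. A scalar model of that predicate, my reading of the linked blog post rather than code from this diff:

    fn in_range_model(v: i32, start: i32, end: i32) -> bool {
        // With only `>` (op_greater) and `|` (op_or) available,
        // "start <= v && v <= end" becomes "not ((start > v) | (v > end))".
        !((start > v) | (v > end))
    }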

View File

@@ -1,86 +0,0 @@
-// https://quickwit.io/blog/simd-range
-use std::ops::RangeInclusive;
-use std::arch::x86_64::_mm512_add_epi32 as op_add;
-use std::arch::x86_64::_mm512_cmple_epi32_mask as op_less_or_equal;
-use std::arch::x86_64::_mm512_loadu_epi32 as load_unaligned;
-use std::arch::x86_64::_mm512_set1_epi32 as set1;
-use std::arch::x86_64::_mm512_mask_compressstoreu_epi32 as compress;
-use std::arch::x86_64::__m512i;
-
-const NUM_LANES: usize = 16;
-
-pub fn filter_vec_in_place(//input: &[u32],
-    range: RangeInclusive<u32>, offset: u32,
-    output: &mut Vec<u32>) {
-    //assert_eq!(output.len() % NUM_LANES, 0); // Not required. // but maybe we need some padding on the output for avx512 to work well?
-    // We restrict the accepted boundary, because unsigned integers & SIMD don't
-    // play well.
-    // TODO.
-    let accepted_range = 0u32..(i32::MAX as u32);
-    assert!(accepted_range.contains(range.start()), "{:?}", range);
-    assert!(accepted_range.contains(range.end()), "{:?}", range);
-    //output.clear();
-    //output.reserve(input.len());
-    let num_words = output.len() / NUM_LANES;
-    let mut output_len = unsafe {
-        filter_vec_avx512_aux(
-            //output.as_ptr() as *const __m512i,
-            output.as_ptr(),
-            range.clone(),
-            output.as_mut_ptr(),
-            offset,
-            num_words,
-        )
-    };
-    let reminder_start = num_words * NUM_LANES;
-    for i in reminder_start..output.len() {
-        let val = output[i];
-        output[output_len] = offset + i as u32;
-        //output[output_len] = i as u32;
-        output_len += if range.contains(&val) { 1 } else { 0 };
-    }
-    output.truncate(output_len);
-}
-
-#[target_feature(enable = "avx512f")]
-pub unsafe fn filter_vec_avx512_aux(
-    mut input: *const u32,
-    range: RangeInclusive<u32>,
-    output: *mut u32,
-    offset: u32,
-    num_words: usize,
-) -> usize {
-    let mut output_end = output;
-    let range_simd =
-        set1(*range.start() as i32)..=set1(*range.end() as i32);
-    let mut ids = from_u32x16([offset + 0, offset + 1, offset + 2, offset + 3, offset + 4, offset + 5, offset + 6, offset + 7,
-        offset + 8, offset + 9, offset + 10, offset + 11, offset + 12, offset + 13, offset + 14, offset + 15]);
-    const SHIFT: __m512i = from_u32x16([NUM_LANES as u32; NUM_LANES]);
-    for _ in 0..num_words {
-        let word = load_unaligned(input as *const i32);
-        let keeper_bitset = compute_filter_bitset(word, range_simd.clone());
-        compress(output_end as *mut u8, keeper_bitset, ids);
-        let added_len = keeper_bitset.count_ones();
-        output_end = output_end.offset(added_len as isize);
-        ids = op_add(ids, SHIFT);
-        input = input.offset(16);
-    }
-    output_end.offset_from(output) as usize
-}
-
-#[inline]
-unsafe fn compute_filter_bitset(
-    val: __m512i,
-    range: RangeInclusive<__m512i>) -> u16 {
-    let low = op_less_or_equal(*range.start(), val);
-    let high = op_less_or_equal(val, *range.end());
-    low & high
-}
-
-const fn from_u32x16(vals: [u32; NUM_LANES]) -> __m512i {
-    union U8x64 {
-        vector: __m512i,
-        vals: [u32; NUM_LANES],
-    }
-    unsafe { U8x64 { vals }.vector }
-}
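
The deleted file is the compress-store kernel described in the blog post: compare 16 lanes against the range bounds at once, then `_mm512_mask_compressstoreu_epi32` packs the ids of the matching lanes contiguously into the output, with `count_ones()` on the mask advancing the write pointer. A scalar model of what one call computes (the function name here is illustrative, not from the crate):

    use std::ops::RangeInclusive;

    /// Keep `offset + i` for every `values[i]` inside `range`,
    /// writing the surviving ids contiguously, like the kernel above.
    fn filter_vec_model(values: &[u32], range: RangeInclusive<u32>, offset: u32) -> Vec<u32> {
        values
            .iter()
            .copied()
            .enumerate()
            .filter(|(_, v)| range.contains(v))
            .map(|(i, _)| offset + i as u32)
            .collect()
    }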

View File

@@ -2,17 +2,15 @@ use std::ops::RangeInclusive;
 
 #[cfg(target_arch = "x86_64")]
 mod avx2;
-mod avx512;
 mod scalar;
 
 #[derive(Clone, Copy, Eq, PartialEq, Debug)]
 #[repr(u8)]
-pub enum FilterImplPerInstructionSet {
+enum FilterImplPerInstructionSet {
     #[cfg(target_arch = "x86_64")]
-    AVX512 = 0u8,
-    AVX2 = 1u8,
-    Scalar = 2u8,
+    AVX2 = 0u8,
+    Scalar = 1u8,
 }
 
 impl FilterImplPerInstructionSet {
@@ -20,9 +18,6 @@ impl FilterImplPerInstructionSet {
     pub fn is_available(&self) -> bool {
         match *self {
             #[cfg(target_arch = "x86_64")]
-            FilterImplPerInstructionSet::AVX512 => is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512vl"),
-            //FilterImplPerInstructionSet::AVX512 => false,
             FilterImplPerInstructionSet::AVX2 => is_x86_feature_detected!("avx2"),
             FilterImplPerInstructionSet::Scalar => true,
         }
@@ -31,8 +26,7 @@ impl FilterImplPerInstructionSet {
 
     // List of available implementation in preferred order.
     #[cfg(target_arch = "x86_64")]
-    const IMPLS: [FilterImplPerInstructionSet; 3] = [
-        FilterImplPerInstructionSet::AVX512,
+    const IMPLS: [FilterImplPerInstructionSet; 2] = [
         FilterImplPerInstructionSet::AVX2,
         FilterImplPerInstructionSet::Scalar,
     ];
@@ -45,9 +39,6 @@ impl FilterImplPerInstructionSet {
     #[inline]
     fn from(code: u8) -> FilterImplPerInstructionSet {
         #[cfg(target_arch = "x86_64")]
-        if code == FilterImplPerInstructionSet::AVX512 as u8 {
-            return FilterImplPerInstructionSet::AVX512;
-        }
         if code == FilterImplPerInstructionSet::AVX2 as u8 {
             return FilterImplPerInstructionSet::AVX2;
         }
@@ -55,10 +46,9 @@ impl FilterImplPerInstructionSet {
     }
 
     #[inline]
-    pub fn filter_vec_in_place(self, range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
+    fn filter_vec_in_place(self, range: RangeInclusive<u32>, offset: u32, output: &mut Vec<u32>) {
         match self {
             #[cfg(target_arch = "x86_64")]
-            FilterImplPerInstructionSet::AVX512 => avx512::filter_vec_in_place(range, offset, output),
             FilterImplPerInstructionSet::AVX2 => avx2::filter_vec_in_place(range, offset, output),
             FilterImplPerInstructionSet::Scalar => {
                 scalar::filter_vec_in_place(range, offset, output)
@@ -104,7 +94,6 @@ mod tests {
     #[test]
     fn test_instruction_set_to_code_from_code() {
         for instruction_set in [
-            FilterImplPerInstructionSet::AVX512,
             FilterImplPerInstructionSet::AVX2,
             FilterImplPerInstructionSet::Scalar,
         ] {
@@ -138,10 +127,10 @@ mod tests {
     }
 
     fn test_filter_impl_test_suite(filter_impl: FilterImplPerInstructionSet) {
-        //test_filter_impl_empty_aux(filter_impl);
+        test_filter_impl_empty_aux(filter_impl);
         test_filter_impl_simple_aux(filter_impl);
         test_filter_impl_simple_aux_shifted(filter_impl);
-        // test_filter_impl_simple_outside_i32_range(filter_impl);
+        test_filter_impl_simple_outside_i32_range(filter_impl);
     }
 
     #[test]
@@ -151,13 +140,6 @@ mod tests {
             test_filter_impl_test_suite(FilterImplPerInstructionSet::AVX2);
         }
     }
 
-    #[test]
-    #[cfg(target_arch = "x86_64")]
-    fn test_filter_implementation_avx512() {
-        if FilterImplPerInstructionSet::AVX512.is_available() {
-            test_filter_impl_test_suite(FilterImplPerInstructionSet::AVX512);
-        }
-    }
-
     #[test]
     fn test_filter_implementation_scalar() {
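
What survives in this module is a plain runtime-dispatch scheme: `IMPLS` lists the implementations in preference order, `is_available()` probes CPU features at runtime, and the enum round-trips through its `u8` discriminant for tests. The selection step sits outside the visible hunks, but it presumably amounts to something like this sketch:

    fn best_available() -> FilterImplPerInstructionSet {
        FilterImplPerInstructionSet::IMPLS
            .iter()
            .copied()
            .find(|filter_impl| filter_impl.is_available())
            .expect("Scalar is always available")
    }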

View File

@@ -1,9 +1,6 @@
-#![feature(stdarch_x86_avx512)]
-#![feature(avx512_target_feature)]
-
 mod bitpacker;
 mod blocked_bitpacker;
-pub mod filter_vec;
+mod filter_vec;
 
 use std::cmp::Ordering;
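
(`stdarch_x86_avx512` and `avx512_target_feature` are nightly-only feature gates at the time of this compare, so dropping the AVX512 path is what lets tantivy-bitpacker build on stable Rust again; demoting `filter_vec` to a private module shrinks the public surface to match.)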

View File

@@ -52,10 +52,16 @@ impl<TCollector: Collector> Collector for CollectorWrapper<TCollector> {
 impl SegmentCollector for Box<dyn BoxableSegmentCollector> {
     type Fruit = Box<dyn Fruit>;
 
     #[inline]
     fn collect(&mut self, doc: u32, score: Score) {
         self.as_mut().collect(doc, score);
     }
 
+    #[inline]
+    fn collect_block(&mut self, docs: &[DocId]) {
+        self.as_mut().collect_block(docs);
+    }
+
     fn harvest(self) -> Box<dyn Fruit> {
         BoxableSegmentCollector::harvest_from_box(self)
     }
@@ -63,6 +69,11 @@ impl SegmentCollector for Box<dyn BoxableSegmentCollector> {
 pub trait BoxableSegmentCollector {
     fn collect(&mut self, doc: u32, score: Score);
+    fn collect_block(&mut self, docs: &[DocId]) {
+        for &doc in docs {
+            self.collect(doc, 0.0);
+        }
+    }
     fn harvest_from_box(self: Box<Self>) -> Box<dyn Fruit>;
 }
@@ -71,9 +82,14 @@ pub struct SegmentCollectorWrapper<TSegmentCollector: SegmentCollector>(TSegmentCollector);
 impl<TSegmentCollector: SegmentCollector> BoxableSegmentCollector
     for SegmentCollectorWrapper<TSegmentCollector>
 {
     #[inline]
     fn collect(&mut self, doc: u32, score: Score) {
         self.0.collect(doc, score);
     }
 
+    #[inline]
+    fn collect_block(&mut self, docs: &[DocId]) {
+        self.0.collect_block(docs);
+    }
+
     fn harvest_from_box(self: Box<Self>) -> Box<dyn Fruit> {
         Box::new(self.0.harvest())
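
This file is the second commit in one place: `collect_block` lets block-oriented execution hand a whole slice of doc ids across the `Box<dyn BoxableSegmentCollector>` boundary in one virtual call instead of one per document, while the trait's default body keeps every existing implementor compiling by falling back to per-doc `collect` with a score of 0.0. A hypothetical implementor that profits from overriding it (the type is illustrative, and it assumes a plain `usize` satisfies the `Fruit` bounds):

    struct CountingSegmentCollector {
        count: usize,
    }

    impl BoxableSegmentCollector for CountingSegmentCollector {
        fn collect(&mut self, _doc: u32, _score: Score) {
            self.count += 1;
        }

        // Overrides the default per-doc loop: one addition per block.
        fn collect_block(&mut self, docs: &[DocId]) {
            self.count += docs.len();
        }

        fn harvest_from_box(self: Box<Self>) -> Box<dyn Fruit> {
            Box::new(self.count)
        }
    }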

View File

@@ -204,7 +204,11 @@ impl<Rec: Recorder> SpecializedPostingsWriter<Rec> {
 impl<Rec: Recorder> PostingsWriter for SpecializedPostingsWriter<Rec> {
     #[inline]
     fn subscribe(&mut self, doc: DocId, position: u32, term: &Term, ctx: &mut IndexingContext) {
-        debug_assert!(term.serialized_term().len() >= 4);
+        assert!(
+            term.serialized_term().len() >= 4,
+            "Term too short expect >=4 but got {:?}",
+            term.serialized_term()
+        );
         self.total_num_tokens += 1;
         let (term_index, arena) = (&mut ctx.term_index, &mut ctx.arena);
         term_index.mutate_or_create(term.serialized_term(), |opt_recorder: Option<Rec>| {
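
For context on the `>= 4` bound: a serialized tantivy term begins with a fixed-width field header (the 4-byte field id) ahead of the type code and value bytes, so anything shorter is necessarily truncated, which is exactly the corruption the strengthened assert is meant to surface in production. A sketch of that layout assumption (the helper is illustrative, and big-endian is my reading of the format):

    /// Illustrative only: read the field id a well-formed serialized
    /// term carries in its first four bytes.
    fn field_id(serialized_term: &[u8]) -> Option<u32> {
        let header: [u8; 4] = serialized_term.get(..4)?.try_into().ok()?;
        Some(u32::from_be_bytes(header))
    }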