Compare commits

..

1 Commit

Author SHA1 Message Date
PSeitz
129c40f8ec Improve Union Performance for non-score unions (#2863)
* enhance and_or_queries bench

* optimize unions for count/non-score, bitset fix for ARM

Benchmarks run on M4 Max
```
single_field_only_union_5%_OR_1%
count                Avg: 0.1100ms (-17.46%)    Median: 0.1079ms (-14.08%)    [0.1045ms .. 0.1410ms]    Output: 54_110
top10_inv_idx        Avg: 0.1663ms (+0.79%)     Median: 0.1660ms (+0.75%)     [0.1634ms .. 0.1702ms]    Output: 10
count+top10          Avg: 0.2639ms (-1.24%)     Median: 0.2634ms (-0.31%)     [0.2512ms .. 0.2813ms]    Output: 54_110
top10_by_ff          Avg: 0.2875ms (-8.67%)     Median: 0.2852ms (-8.80%)     [0.2737ms .. 0.3083ms]    Output: 10
top10_by_2ff         Avg: 0.3137ms (-5.79%)     Median: 0.3128ms (-0.35%)     [0.3044ms .. 0.3313ms]    Output: 10
single_field_only_union_5%_OR_1%_OR_15%
count                Avg: 0.4122ms (-33.05%)    Median: 0.4140ms (-32.20%)    [0.3940ms .. 0.4341ms]    Output: 181_663
top10_inv_idx        Avg: 0.3999ms (+2.39%)     Median: 0.3987ms (+2.02%)     [0.3939ms .. 0.4160ms]    Output: 10
count+top10          Avg: 0.8520ms (-8.63%)     Median: 0.8516ms (-8.65%)     [0.8413ms .. 0.8676ms]    Output: 181_663
top10_by_ff          Avg: 0.9694ms (-13.06%)    Median: 0.9645ms (-13.77%)    [0.9403ms .. 1.0122ms]    Output: 10
top10_by_2ff         Avg: 0.9880ms (-13.01%)    Median: 0.9838ms (-13.59%)    [0.9781ms .. 1.0306ms]    Output: 10
single_field_only_union_5%_OR_30%
count                Avg: 0.7364ms (-33.11%)    Median: 0.7347ms (-33.19%)    [0.7233ms .. 0.7547ms]    Output: 303_337
top10_inv_idx        Avg: 0.8932ms (-0.89%)     Median: 0.8919ms (-0.75%)     [0.8861ms .. 0.9249ms]    Output: 10
count+top10          Avg: 1.3611ms (-9.23%)     Median: 1.3598ms (-9.39%)     [1.3426ms .. 1.3891ms]    Output: 303_337
top10_by_ff          Avg: 1.6575ms (-18.64%)    Median: 1.6224ms (-20.81%)    [1.6051ms .. 1.7560ms]    Output: 10
top10_by_2ff         Avg: 1.6800ms (-16.24%)    Median: 1.6769ms (-15.72%)    [1.6661ms .. 1.7229ms]    Output: 10
single_field_only_union_30%_OR_0.01%
count                Avg: 0.6471ms (-33.73%)    Median: 0.6464ms (-33.46%)    [0.6375ms .. 0.6604ms]    Output: 270_268
top10_inv_idx        Avg: 0.0338ms (-0.27%)     Median: 0.0338ms (+0.11%)     [0.0331ms .. 0.0351ms]    Output: 10
count+top10          Avg: 1.2209ms (-9.27%)     Median: 1.2207ms (-9.25%)     [1.2158ms .. 1.2351ms]    Output: 270_268
top10_by_ff          Avg: 1.4808ms (-17.20%)    Median: 1.4690ms (-17.91%)    [1.4384ms .. 1.5553ms]    Output: 10
top10_by_2ff         Avg: 1.5011ms (-14.30%)    Median: 1.4992ms (-13.88%)    [1.4891ms .. 1.5320ms]    Output: 10
multi_field_only_union_5%_OR_1%
count                Avg: 0.1196ms (-17.67%)    Median: 0.1166ms (-14.83%)    [0.1123ms .. 0.1462ms]    Output: 60_183
top10_inv_idx        Avg: 0.2356ms (-0.21%)     Median: 0.2355ms (+0.23%)     [0.2330ms .. 0.2406ms]    Output: 10
count+top10          Avg: 0.2985ms (-5.06%)     Median: 0.2957ms (-5.79%)     [0.2875ms .. 0.3186ms]    Output: 60_183
top10_by_ff          Avg: 0.3102ms (-9.44%)     Median: 0.3031ms (-11.09%)    [0.2994ms .. 0.3324ms]    Output: 10
top10_by_2ff         Avg: 0.3435ms (-0.91%)     Median: 0.3447ms (-0.62%)     [0.3342ms .. 0.3530ms]    Output: 10
multi_field_only_union_5%_OR_1%_OR_15%
count                Avg: 0.4465ms (-35.41%)    Median: 0.4456ms (-36.25%)    [0.4250ms .. 0.4936ms]    Output: 201_114
top10_inv_idx        Avg: 1.1542ms (+2.38%)     Median: 1.1560ms (+2.96%)     [1.1193ms .. 1.1912ms]    Output: 10
count+top10          Avg: 0.9334ms (-8.89%)     Median: 0.9330ms (-8.95%)     [0.9191ms .. 0.9542ms]    Output: 201_114
top10_by_ff          Avg: 1.0590ms (-14.10%)    Median: 1.0424ms (-15.08%)    [1.0304ms .. 1.1174ms]    Output: 10
top10_by_2ff         Avg: 1.0779ms (-17.06%)    Median: 1.0754ms (-17.40%)    [1.0650ms .. 1.1155ms]    Output: 10
multi_field_only_union_5%_OR_30%
count                Avg: 0.8137ms (-33.48%)    Median: 0.7976ms (-34.84%)    [0.7734ms .. 1.0855ms]    Output: 335_682
top10_inv_idx        Avg: 1.5108ms (+0.36%)     Median: 1.4943ms (-0.72%)     [1.4805ms .. 1.5865ms]    Output: 10
count+top10          Avg: 1.4985ms (-9.75%)     Median: 1.4936ms (-9.63%)     [1.4784ms .. 1.5472ms]    Output: 335_682
top10_by_ff          Avg: 1.8531ms (-15.70%)    Median: 1.8583ms (-16.30%)    [1.7467ms .. 2.2297ms]    Output: 10
top10_by_2ff         Avg: 1.8735ms (-16.67%)    Median: 1.8421ms (-18.05%)    [1.8146ms .. 2.3650ms]    Output: 10
multi_field_only_union_30%_OR_0.01%
count                Avg: 0.7020ms (-34.40%)    Median: 0.7004ms (-34.05%)    [0.6943ms .. 0.7156ms]    Output: 300_315
top10_inv_idx        Avg: 0.1445ms (-1.57%)     Median: 0.1442ms (-1.35%)     [0.1426ms .. 0.1478ms]    Output: 10
count+top10          Avg: 1.3309ms (-9.84%)     Median: 1.3284ms (-9.71%)     [1.3234ms .. 1.3549ms]    Output: 300_315
top10_by_ff          Avg: 1.6152ms (-17.39%)    Median: 1.6037ms (-18.72%)    [1.5778ms .. 1.7227ms]    Output: 10
top10_by_2ff         Avg: 1.6479ms (-17.10%)    Median: 1.6444ms (-15.46%)    [1.6307ms .. 1.6901ms]    Output: 10
```

* add comment

* fix comment

* remove inline(never), bounds check
2026-03-27 08:00:26 +01:00
6 changed files with 179 additions and 565 deletions

View File

@@ -22,7 +22,7 @@ use rand::rngs::StdRng;
use rand::SeedableRng;
use tantivy::collector::sort_key::SortByStaticFastValue;
use tantivy::collector::{Collector, Count, TopDocs};
use tantivy::query::{Query, QueryParser};
use tantivy::query::QueryParser;
use tantivy::schema::{Schema, FAST, TEXT};
use tantivy::{doc, Index, Order, ReloadPolicy, Searcher};
@@ -38,7 +38,7 @@ struct BenchIndex {
/// return two BenchIndex views:
/// - single_field: QueryParser defaults to only "body"
/// - multi_field: QueryParser defaults to ["title", "body"]
fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (BenchIndex, BenchIndex) {
fn build_index(num_docs: usize, terms: &[(&str, f32)]) -> (BenchIndex, BenchIndex) {
// Unified schema (two text fields)
let mut schema_builder = Schema::builder();
let f_title = schema_builder.add_text_field("title", TEXT);
@@ -55,32 +55,17 @@ fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (Bench
{
let mut writer = index.writer_with_num_threads(1, 500_000_000).unwrap();
for _ in 0..num_docs {
let has_a = rng.random_bool(p_a as f64);
let has_b = rng.random_bool(p_b as f64);
let has_c = rng.random_bool(p_c as f64);
let score = rng.random_range(0u64..100u64);
let score2 = rng.random_range(0u64..100_000u64);
let mut title_tokens: Vec<&str> = Vec::new();
let mut body_tokens: Vec<&str> = Vec::new();
if has_a {
if rng.random_bool(0.1) {
title_tokens.push("a");
} else {
body_tokens.push("a");
}
}
if has_b {
if rng.random_bool(0.1) {
title_tokens.push("b");
} else {
body_tokens.push("b");
}
}
if has_c {
if rng.random_bool(0.1) {
title_tokens.push("c");
} else {
body_tokens.push("c");
for &(tok, prob) in terms {
if rng.random_bool(prob as f64) {
if rng.random_bool(0.1) {
title_tokens.push(tok);
} else {
body_tokens.push(tok);
}
}
}
if title_tokens.is_empty() && body_tokens.is_empty() {
@@ -110,59 +95,97 @@ fn build_shared_indices(num_docs: usize, p_a: f32, p_b: f32, p_c: f32) -> (Bench
let qp_single = QueryParser::for_index(&index, vec![f_body]);
let qp_multi = QueryParser::for_index(&index, vec![f_title, f_body]);
let single_view = BenchIndex {
let only_title = BenchIndex {
index: index.clone(),
searcher: searcher.clone(),
query_parser: qp_single,
};
let multi_view = BenchIndex {
let title_and_body = BenchIndex {
index,
searcher,
query_parser: qp_multi,
};
(single_view, multi_view)
(only_title, title_and_body)
}
fn format_pct(p: f32) -> String {
let pct = (p as f64) * 100.0;
let rounded = (pct * 1_000_000.0).round() / 1_000_000.0;
if rounded.fract() <= 0.001 {
format!("{}%", rounded as u64)
} else {
format!("{}%", rounded)
}
}
fn query_label(query_str: &str, term_pcts: &[(&str, String)]) -> String {
let mut label = query_str.to_string();
for (term, pct) in term_pcts {
label = label.replace(term, pct);
}
label.replace(' ', "_")
}
fn main() {
// Prepare corpora with varying selectivity. Build one index per corpus
// and derive two views (single-field vs multi-field) from it.
let scenarios = vec![
// terms with varying selectivity, ordered from rarest to most common.
// With 1M docs, we expect:
// a: 0.01% (100), b: 1% (10k), c: 5% (50k), d: 15% (150k), e: 30% (300k)
let num_docs = 1_000_000;
let terms: &[(&str, f32)] = &[
("a", 0.0001),
("b", 0.01),
("c", 0.05),
("d", 0.15),
("e", 0.30),
];
let queries: &[(&str, &[&str])] = &[
(
"N=1M, p(a)=5%, p(b)=1%, p(c)=15%".to_string(),
1_000_000,
0.05,
0.01,
0.15,
"only_union",
&["c OR b", "c OR b OR d", "c OR e", "e OR a"] as &[&str],
),
(
"N=1M, p(a)=1%, p(b)=1%, p(c)=15%".to_string(),
1_000_000,
0.01,
0.01,
0.15,
"only_intersection",
&["+c +b", "+c +b +d", "+c +e", "+e +a"] as &[&str],
),
(
"union_intersection",
&["+c +(b OR d)", "+e +(c OR a)", "+(c OR b) +(d OR e)"] as &[&str],
),
];
let queries = &["a", "+a +b", "+a +b +c", "a OR b", "a OR b OR c"];
let mut runner = BenchRunner::new();
for (label, n, pa, pb, pc) in scenarios {
let (single_view, multi_view) = build_shared_indices(n, pa, pb, pc);
let (only_title, title_and_body) = build_index(num_docs, terms);
let term_pcts: Vec<(&str, String)> = terms
.iter()
.map(|&(term, p)| (term, format_pct(p)))
.collect();
for (view_name, bench_index) in [("single_field", single_view), ("multi_field", multi_view)]
{
// Single-field group: default field is body only
let mut group = runner.new_group();
group.set_name(format!("{}{}", view_name, label));
for query_str in queries {
for (view_name, bench_index) in [
("single_field", only_title),
("multi_field", title_and_body),
] {
for (category_name, category_queries) in queries {
for query_str in *category_queries {
let mut group = runner.new_group();
let query_label = query_label(query_str, &term_pcts);
group.set_name(format!("{}_{}_{}", view_name, category_name, query_label));
add_bench_task(&mut group, &bench_index, query_str, Count, "count");
add_bench_task(
&mut group,
&bench_index,
query_str,
TopDocs::with_limit(10).order_by_score(),
"top10",
"top10_inv_idx",
);
add_bench_task(
&mut group,
&bench_index,
query_str,
(Count, TopDocs::with_limit(10).order_by_score()),
"count+top10",
);
add_bench_task(
&mut group,
&bench_index,
@@ -180,39 +203,47 @@ fn main() {
)),
"top10_by_2ff",
);
group.run();
}
group.run();
}
}
}
trait FruitCount {
fn count(&self) -> usize;
}
impl FruitCount for usize {
fn count(&self) -> usize {
*self
}
}
impl<T> FruitCount for Vec<T> {
fn count(&self) -> usize {
self.len()
}
}
impl<A: FruitCount, B> FruitCount for (A, B) {
fn count(&self) -> usize {
self.0.count()
}
}
fn add_bench_task<C: Collector + 'static>(
bench_group: &mut BenchGroup,
bench_index: &BenchIndex,
query_str: &str,
collector: C,
collector_name: &str,
) {
let task_name = format!("{}_{}", query_str.replace(" ", "_"), collector_name);
) where
C::Fruit: FruitCount,
{
let query = bench_index.query_parser.parse_query(query_str).unwrap();
let search_task = SearchTask {
searcher: bench_index.searcher.clone(),
collector,
query,
};
bench_group.register(task_name, move |_| black_box(search_task.run()));
}
struct SearchTask<C: Collector> {
searcher: Searcher,
collector: C,
query: Box<dyn Query>,
}
impl<C: Collector> SearchTask<C> {
#[inline(never)]
pub fn run(&self) -> usize {
self.searcher.search(&self.query, &self.collector).unwrap();
1
}
let searcher = bench_index.searcher.clone();
bench_group.register(collector_name.to_string(), move |_| {
black_box(searcher.search(&query, &collector).unwrap().count())
});
}

View File

@@ -153,7 +153,22 @@ impl TinySet {
None
} else {
let lowest = self.0.trailing_zeros();
self.0 ^= TinySet::singleton(lowest).0;
// Kernighan's trick: `n &= n - 1` clears the lowest set bit
// without depending on `lowest`. This lets the CPU execute
// `trailing_zeros` and the bit-clear in parallel instead of
// serializing them.
//
// The previous form `self.0 ^= 1 << lowest` needs the result of
// `trailing_zeros` before it can shift, creating a dependency chain:
// ARM64: rbit → clz → lsl → eor
// x86: tzcnt → btc
//
// With Kernighan's trick the clear path is independent of the count:
// ARM64: sub → and (trailing_zeros runs in parallel)
// x86: blsr (tzcnt runs in parallel)
//
// https://godbolt.org/z/fnfrP1T5f
self.0 &= self.0 - 1;
Some(lowest)
}
}

View File

@@ -1,418 +0,0 @@
use crate::query::term_query::TermScorer;
use crate::query::Scorer;
use crate::{DocId, DocSet, Score, TERMINATED};
/// Block-max pruning for top-K over intersection of term scorers.
///
/// Uses the least-frequent term as "leader" to define 128-doc processing windows.
/// For each window, the sum of block_max_scores is compared to the current threshold;
/// if the block can't beat it, the entire block is skipped.
///
/// Within non-skipped blocks, individual documents are pruned by checking whether
/// leader_score + sum(secondary block_max_scores) can exceed the threshold before
/// performing the expensive intersection membership check (seeking into secondary scorers).
///
/// # Preconditions
/// - `scorers` has at least 2 elements
/// - All scorers read frequencies (`FreqReadingOption::ReadFreq`)
pub fn block_wand_intersection(
mut scorers: Vec<TermScorer>,
mut threshold: Score,
callback: &mut dyn FnMut(DocId, Score) -> Score,
) {
assert!(scorers.len() >= 2);
// Sort by cost (ascending). scorers[0] becomes the "leader" (rarest term).
scorers.sort_by_key(TermScorer::size_hint);
let (leader, secondaries) = scorers.split_first_mut().unwrap();
// Precompute global max scores for early termination checks.
let secondaries_global_max_sum: Score = secondaries.iter().map(|s| s.max_score()).sum();
// Early exit: no document can possibly beat the threshold.
if leader.max_score() + secondaries_global_max_sum <= threshold {
return;
}
let mut doc = leader.doc();
if doc == TERMINATED {
return;
}
loop {
// --- Phase 1: Block-level pruning ---
//
// Position all skip readers on the block containing `doc`.
// seek_block is cheap: it only advances the skip reader, no block decompression.
leader.seek_block(doc);
let leader_block_max: Score = leader.block_max_score();
// Compute the window end as the minimum last_doc_in_block across all scorers.
// This ensures the block_max values are valid for all docs in [doc, window_end].
// Different scorers have independently aligned blocks, so we must use the
// smallest window where all block_max values hold.
let mut window_end: DocId = leader.last_doc_in_block();
let mut secondary_block_max_sum: Score = 0.0;
for secondary in secondaries.iter_mut() {
secondary.seek_block(doc);
secondary_block_max_sum += secondary.block_max_score();
window_end = window_end.min(secondary.last_doc_in_block());
}
if leader_block_max + secondary_block_max_sum <= threshold {
// The entire window cannot beat the threshold. Skip past it.
if window_end == TERMINATED {
return;
}
doc = window_end + 1;
continue;
}
// --- Phase 2: Doc-level processing within the window ---
//
// Load the leader's block and iterate through its documents up to window_end.
doc = leader.seek(doc);
if doc == TERMINATED {
return;
}
'next_doc: while doc <= window_end {
let leader_score: Score = leader.score();
// Doc-level pruning: can leader_score + best possible secondary contribution
// beat the threshold?
if leader_score + secondary_block_max_sum <= threshold {
doc = leader.advance();
if doc == TERMINATED {
return;
}
continue;
}
// Check intersection membership in secondaries.
let mut total_score: Score = leader_score;
for secondary in secondaries.iter_mut() {
// seek() requires target >= self.doc(). If the secondary is already
// past `doc` from a previous seek, this doc is not in the intersection.
let secondary_doc = secondary.doc();
let seek_result = if secondary_doc <= doc {
secondary.seek(doc)
} else {
secondary_doc
};
if seek_result != doc {
doc = leader.advance();
if doc == TERMINATED {
return;
}
continue 'next_doc;
}
total_score += secondary.score();
}
// All secondaries matched.
if total_score > threshold {
threshold = callback(doc, total_score);
// Re-check global early termination after threshold update.
if leader.max_score() + secondaries_global_max_sum <= threshold {
return;
}
}
doc = leader.advance();
if doc == TERMINATED {
return;
}
}
// `doc` is now past window_end but not TERMINATED.
// Loop back to Phase 1 with this new doc.
}
}
#[cfg(test)]
mod tests {
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use proptest::prelude::*;
use crate::query::term_query::TermScorer;
use crate::query::{Bm25Weight, Scorer};
use crate::{DocId, DocSet, Score, TERMINATED};
struct Float(Score);
impl Eq for Float {}
impl PartialEq for Float {
fn eq(&self, other: &Self) -> bool {
self.cmp(other) == Ordering::Equal
}
}
impl PartialOrd for Float {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Float {
fn cmp(&self, other: &Self) -> Ordering {
other.0.partial_cmp(&self.0).unwrap_or(Ordering::Equal)
}
}
fn nearly_equals(left: Score, right: Score) -> bool {
(left - right).abs() < 0.0001 * (left + right).abs()
}
/// Run block_wand_intersection and collect (doc, score) pairs above threshold.
fn compute_checkpoints_block_wand_intersection(
term_scorers: Vec<TermScorer>,
top_k: usize,
) -> Vec<(DocId, Score)> {
let mut heap: BinaryHeap<Float> = BinaryHeap::with_capacity(top_k);
let mut checkpoints: Vec<(DocId, Score)> = Vec::new();
let mut limit: Score = 0.0;
let callback = &mut |doc, score| {
heap.push(Float(score));
if heap.len() > top_k {
heap.pop().unwrap();
}
if heap.len() == top_k {
limit = heap.peek().unwrap().0;
}
if !nearly_equals(score, limit) {
checkpoints.push((doc, score));
}
limit
};
super::block_wand_intersection(term_scorers, Score::MIN, callback);
checkpoints
}
/// Naive baseline: intersect by iterating all docs.
fn compute_checkpoints_naive_intersection(
mut term_scorers: Vec<TermScorer>,
top_k: usize,
) -> Vec<(DocId, Score)> {
let mut heap: BinaryHeap<Float> = BinaryHeap::with_capacity(top_k);
let mut checkpoints: Vec<(DocId, Score)> = Vec::new();
let mut limit = Score::MIN;
// Sort by cost to use the cheapest as driver.
term_scorers.sort_by_key(|s| s.cost());
let (leader, secondaries) = term_scorers.split_first_mut().unwrap();
let mut doc = leader.doc();
while doc != TERMINATED {
let mut all_match = true;
for secondary in secondaries.iter_mut() {
let secondary_doc = secondary.doc();
let seek_result = if secondary_doc <= doc {
secondary.seek(doc)
} else {
secondary_doc
};
if seek_result != doc {
all_match = false;
break;
}
}
if all_match {
let score: Score =
leader.score() + secondaries.iter_mut().map(|s| s.score()).sum::<Score>();
if score > limit {
heap.push(Float(score));
if heap.len() > top_k {
heap.pop().unwrap();
}
if heap.len() == top_k {
limit = heap.peek().unwrap().0;
}
if !nearly_equals(score, limit) {
checkpoints.push((doc, score));
}
}
}
doc = leader.advance();
}
checkpoints
}
const MAX_TERM_FREQ: u32 = 100u32;
fn posting_list(max_doc: u32) -> BoxedStrategy<Vec<(DocId, u32)>> {
(1..max_doc + 1)
.prop_flat_map(move |doc_freq| {
(
proptest::bits::bitset::sampled(doc_freq as usize, 0..max_doc as usize),
proptest::collection::vec(1u32..MAX_TERM_FREQ, doc_freq as usize),
)
})
.prop_map(|(docset, term_freqs)| {
docset
.iter()
.map(|doc| doc as u32)
.zip(term_freqs.iter().cloned())
.collect::<Vec<_>>()
})
.boxed()
}
#[expect(clippy::type_complexity)]
fn gen_term_scorers(num_scorers: usize) -> BoxedStrategy<(Vec<Vec<(DocId, u32)>>, Vec<u32>)> {
(1u32..100u32)
.prop_flat_map(move |max_doc: u32| {
(
proptest::collection::vec(posting_list(max_doc), num_scorers),
proptest::collection::vec(2u32..10u32 * MAX_TERM_FREQ, max_doc as usize),
)
})
.boxed()
}
fn test_block_wand_intersection_aux(posting_lists: &[Vec<(DocId, u32)>], fieldnorms: &[u32]) {
// Repeat docs 64 times to create multi-block scenarios, matching block_wand.rs test
// strategy.
const REPEAT: usize = 64;
let fieldnorms_expanded: Vec<u32> = fieldnorms
.iter()
.cloned()
.flat_map(|fieldnorm| std::iter::repeat_n(fieldnorm, REPEAT))
.collect();
let postings_lists_expanded: Vec<Vec<(DocId, u32)>> = posting_lists
.iter()
.map(|posting_list| {
posting_list
.iter()
.cloned()
.flat_map(|(doc, term_freq)| {
(0_u32..REPEAT as u32).map(move |offset| {
(
doc * (REPEAT as u32) + offset,
if offset == 0 { term_freq } else { 1 },
)
})
})
.collect::<Vec<(DocId, u32)>>()
})
.collect();
let total_fieldnorms: u64 = fieldnorms_expanded
.iter()
.cloned()
.map(|fieldnorm| fieldnorm as u64)
.sum();
let average_fieldnorm = (total_fieldnorms as Score) / (fieldnorms_expanded.len() as Score);
let max_doc = fieldnorms_expanded.len();
let make_scorers = || -> Vec<TermScorer> {
postings_lists_expanded
.iter()
.map(|postings| {
let bm25_weight = Bm25Weight::for_one_term(
postings.len() as u64,
max_doc as u64,
average_fieldnorm,
);
TermScorer::create_for_test(postings, &fieldnorms_expanded[..], bm25_weight)
})
.collect()
};
for top_k in 1..4 {
let checkpoints_optimized =
compute_checkpoints_block_wand_intersection(make_scorers(), top_k);
let checkpoints_naive = compute_checkpoints_naive_intersection(make_scorers(), top_k);
assert_eq!(
checkpoints_optimized.len(),
checkpoints_naive.len(),
"Mismatch in checkpoint count for top_k={top_k}"
);
for (&(left_doc, left_score), &(right_doc, right_score)) in
checkpoints_optimized.iter().zip(checkpoints_naive.iter())
{
assert_eq!(left_doc, right_doc);
assert!(
nearly_equals(left_score, right_score),
"Score mismatch for doc {left_doc}: {left_score} vs {right_score}"
);
}
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(500))]
#[test]
fn test_block_wand_intersection_two_scorers(
(posting_lists, fieldnorms) in gen_term_scorers(2)
) {
test_block_wand_intersection_aux(&posting_lists[..], &fieldnorms[..]);
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(500))]
#[test]
fn test_block_wand_intersection_three_scorers(
(posting_lists, fieldnorms) in gen_term_scorers(3)
) {
test_block_wand_intersection_aux(&posting_lists[..], &fieldnorms[..]);
}
}
#[test]
fn test_block_wand_intersection_disjoint() {
// Two posting lists with no overlap — intersection is empty.
let fieldnorms: Vec<u32> = vec![10; 200];
let average_fieldnorm = 10.0;
let postings_a: Vec<(DocId, u32)> = (0..100).map(|d| (d, 1)).collect();
let postings_b: Vec<(DocId, u32)> = (100..200).map(|d| (d, 1)).collect();
let scorer_a = TermScorer::create_for_test(
&postings_a,
&fieldnorms,
Bm25Weight::for_one_term(100, 200, average_fieldnorm),
);
let scorer_b = TermScorer::create_for_test(
&postings_b,
&fieldnorms,
Bm25Weight::for_one_term(100, 200, average_fieldnorm),
);
let checkpoints = compute_checkpoints_block_wand_intersection(vec![scorer_a, scorer_b], 10);
assert!(checkpoints.is_empty());
}
#[test]
fn test_block_wand_intersection_all_overlap() {
// Two posting lists with full overlap.
let fieldnorms: Vec<u32> = vec![10; 50];
let average_fieldnorm = 10.0;
let postings: Vec<(DocId, u32)> = (0..50).map(|d| (d, 3)).collect();
let make_scorer = || {
TermScorer::create_for_test(
&postings,
&fieldnorms,
Bm25Weight::for_one_term(50, 50, average_fieldnorm),
)
};
let checkpoints_opt =
compute_checkpoints_block_wand_intersection(vec![make_scorer(), make_scorer()], 5);
let checkpoints_naive =
compute_checkpoints_naive_intersection(vec![make_scorer(), make_scorer()], 5);
assert_eq!(checkpoints_opt.len(), checkpoints_naive.len());
}
}

View File

@@ -16,7 +16,6 @@ use crate::{DocId, Score};
enum SpecializedScorer {
TermUnion(Vec<TermScorer>),
TermIntersection(Vec<TermScorer>),
Other(Box<dyn Scorer>),
}
@@ -94,13 +93,6 @@ fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
BufferedUnionScorer::build(term_scorers, score_combiner_fn, num_docs);
Box::new(union_scorer)
}
SpecializedScorer::TermIntersection(term_scorers) => {
let boxed_scorers: Vec<Box<dyn Scorer>> = term_scorers
.into_iter()
.map(|s| Box::new(s) as Box<dyn Scorer>)
.collect();
intersect_scorers(boxed_scorers, num_docs)
}
SpecializedScorer::Other(scorer) => scorer,
}
}
@@ -305,43 +297,14 @@ impl<TScoreCombiner: ScoreCombiner> BooleanWeight<TScoreCombiner> {
// Result depends entirely on MUST + any removed AllScorers.
let combined_all_scorer_count = must_special_scorer_counts.num_all_scorers
+ should_special_scorer_counts.num_all_scorers;
// Try to detect a pure TermScorer intersection for block-max optimization.
// Preconditions: no removed AllScorers, at least 2 scorers, all TermScorer
// with frequency reading enabled.
if combined_all_scorer_count == 0
&& must_scorers.len() >= 2
&& must_scorers.iter().all(|s| s.is::<TermScorer>())
{
let term_scorers: Vec<TermScorer> = must_scorers
.into_iter()
.map(|s| *(s.downcast::<TermScorer>().map_err(|_| ()).unwrap()))
.collect();
if term_scorers
.iter()
.all(|s| s.freq_reading_option() == FreqReadingOption::ReadFreq)
{
SpecializedScorer::TermIntersection(term_scorers)
} else {
let must_scorers: Vec<Box<dyn Scorer>> = term_scorers
.into_iter()
.map(|s| Box::new(s) as Box<dyn Scorer>)
.collect();
let boxed_scorer: Box<dyn Scorer> =
effective_must_scorer(must_scorers, 0, reader.max_doc(), num_docs)
.unwrap_or_else(|| Box::new(EmptyScorer));
SpecializedScorer::Other(boxed_scorer)
}
} else {
let boxed_scorer: Box<dyn Scorer> = effective_must_scorer(
must_scorers,
combined_all_scorer_count,
reader.max_doc(),
num_docs,
)
.unwrap_or_else(|| Box::new(EmptyScorer));
SpecializedScorer::Other(boxed_scorer)
}
let boxed_scorer: Box<dyn Scorer> = effective_must_scorer(
must_scorers,
combined_all_scorer_count,
reader.max_doc(),
num_docs,
)
.unwrap_or_else(|| Box::new(EmptyScorer));
SpecializedScorer::Other(boxed_scorer)
}
(ShouldScorersCombinationMethod::Optional(should_scorer), must_scorers) => {
// Optional SHOULD: contributes to scoring but not required for matching.
@@ -500,20 +463,14 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
callback: &mut dyn FnMut(DocId, Score),
) -> crate::Result<()> {
let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?;
let num_docs = reader.num_docs();
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let mut union_scorer =
BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn, num_docs);
for_each_scorer(&mut union_scorer, callback);
}
SpecializedScorer::TermIntersection(term_scorers) => {
let mut intersection = into_box_scorer(
SpecializedScorer::TermIntersection(term_scorers),
let mut union_scorer = BufferedUnionScorer::build(
term_scorers,
&self.score_combiner_fn,
num_docs,
reader.num_docs(),
);
for_each_scorer(intersection.as_mut(), callback);
for_each_scorer(&mut union_scorer, callback);
}
SpecializedScorer::Other(mut scorer) => {
for_each_scorer(scorer.as_mut(), callback);
@@ -528,22 +485,16 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
callback: &mut dyn FnMut(&[DocId]),
) -> crate::Result<()> {
let scorer = self.complex_scorer(reader, 1.0, || DoNothingCombiner)?;
let num_docs = reader.num_docs();
let mut buffer = [0u32; COLLECT_BLOCK_BUFFER_LEN];
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let mut union_scorer =
BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn, num_docs);
for_each_docset_buffered(&mut union_scorer, &mut buffer, callback);
}
SpecializedScorer::TermIntersection(term_scorers) => {
let mut intersection = into_box_scorer(
SpecializedScorer::TermIntersection(term_scorers),
DoNothingCombiner::default,
num_docs,
let mut union_scorer = BufferedUnionScorer::build(
term_scorers,
&self.score_combiner_fn,
reader.num_docs(),
);
for_each_docset_buffered(intersection.as_mut(), &mut buffer, callback);
for_each_docset_buffered(&mut union_scorer, &mut buffer, callback);
}
SpecializedScorer::Other(mut scorer) => {
for_each_docset_buffered(scorer.as_mut(), &mut buffer, callback);
@@ -573,9 +524,6 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombin
SpecializedScorer::TermUnion(term_scorers) => {
super::block_wand(term_scorers, threshold, callback);
}
SpecializedScorer::TermIntersection(term_scorers) => {
super::block_wand_intersection(term_scorers, threshold, callback);
}
SpecializedScorer::Other(mut scorer) => {
for_each_pruning_scorer(scorer.as_mut(), threshold, callback);
}

View File

@@ -1,10 +1,8 @@
mod block_wand;
mod block_wand_intersection;
mod boolean_query;
mod boolean_weight;
pub(crate) use self::block_wand::{block_wand, block_wand_single_scorer};
pub(crate) use self::block_wand_intersection::block_wand_intersection;
pub use self::boolean_query::BooleanQuery;
pub use self::boolean_weight::BooleanWeight;

View File

@@ -1,6 +1,6 @@
use common::TinySet;
use crate::docset::{DocSet, SeekDangerResult, TERMINATED};
use crate::docset::{DocSet, SeekDangerResult, COLLECT_BLOCK_BUFFER_LEN, TERMINATED};
use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
use crate::query::size_hint::estimate_union;
use crate::query::Scorer;
@@ -172,6 +172,46 @@ where
self.doc
}
fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize {
if self.doc == TERMINATED {
return 0;
}
// The current doc (self.doc) has already been popped from the bitsets,
// so the loop below won't yield it. Emit it here first.
buffer[0] = self.doc;
let mut count = 1;
loop {
// Drain docs directly from the pre-computed bitsets.
while self.bucket_idx < HORIZON_NUM_TINYBITSETS {
// Move bitset to a local variable to avoid read/store on self.bitsets while
// iterating through the bits.
let mut tinyset: TinySet = self.bitsets[self.bucket_idx];
while let Some(val) = tinyset.pop_lowest() {
let delta = val + (self.bucket_idx as u32) * 64;
self.doc = self.window_start_doc + delta;
if count >= COLLECT_BLOCK_BUFFER_LEN {
// Buffer full; put remaining bits back.
self.bitsets[self.bucket_idx] = tinyset;
return COLLECT_BLOCK_BUFFER_LEN;
}
buffer[count] = self.doc;
count += 1;
}
self.bitsets[self.bucket_idx] = TinySet::empty();
self.bucket_idx += 1;
}
// Current window exhausted, refill.
if !self.refill() {
self.doc = TERMINATED;
return count;
}
}
}
fn seek(&mut self, target: DocId) -> DocId {
if self.doc >= target {
return self.doc;