From 4bdbc013ba962b3092b730a562b8f73630948924 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Thu, 26 Mar 2026 15:17:36 -0400 Subject: [PATCH] Optimizing top K using Adrien Grand's ideas https://jpountz.github.io/2025/08/28/compiled-vs-vectorized-search-engine-edition.html --- .../boolean_query/block_wand_intersection.rs | 418 ++++++++++++++++++ src/query/boolean_query/boolean_weight.rs | 88 +++- src/query/boolean_query/mod.rs | 2 + 3 files changed, 490 insertions(+), 18 deletions(-) create mode 100644 src/query/boolean_query/block_wand_intersection.rs diff --git a/src/query/boolean_query/block_wand_intersection.rs b/src/query/boolean_query/block_wand_intersection.rs new file mode 100644 index 000000000..3a77f5877 --- /dev/null +++ b/src/query/boolean_query/block_wand_intersection.rs @@ -0,0 +1,418 @@ +use crate::query::term_query::TermScorer; +use crate::query::Scorer; +use crate::{DocId, DocSet, Score, TERMINATED}; + +/// Block-max pruning for top-K over intersection of term scorers. +/// +/// Uses the least-frequent term as "leader" to define 128-doc processing windows. +/// For each window, the sum of block_max_scores is compared to the current threshold; +/// if the block can't beat it, the entire block is skipped. +/// +/// Within non-skipped blocks, individual documents are pruned by checking whether +/// leader_score + sum(secondary block_max_scores) can exceed the threshold before +/// performing the expensive intersection membership check (seeking into secondary scorers). +/// +/// # Preconditions +/// - `scorers` has at least 2 elements +/// - All scorers read frequencies (`FreqReadingOption::ReadFreq`) +pub fn block_wand_intersection( + mut scorers: Vec, + mut threshold: Score, + callback: &mut dyn FnMut(DocId, Score) -> Score, +) { + assert!(scorers.len() >= 2); + + // Sort by cost (ascending). scorers[0] becomes the "leader" (rarest term). + scorers.sort_by_key(TermScorer::size_hint); + + let (leader, secondaries) = scorers.split_first_mut().unwrap(); + + // Precompute global max scores for early termination checks. + let secondaries_global_max_sum: Score = secondaries.iter().map(|s| s.max_score()).sum(); + + // Early exit: no document can possibly beat the threshold. + if leader.max_score() + secondaries_global_max_sum <= threshold { + return; + } + + let mut doc = leader.doc(); + if doc == TERMINATED { + return; + } + + loop { + // --- Phase 1: Block-level pruning --- + // + // Position all skip readers on the block containing `doc`. + // seek_block is cheap: it only advances the skip reader, no block decompression. + leader.seek_block(doc); + let leader_block_max: Score = leader.block_max_score(); + + // Compute the window end as the minimum last_doc_in_block across all scorers. + // This ensures the block_max values are valid for all docs in [doc, window_end]. + // Different scorers have independently aligned blocks, so we must use the + // smallest window where all block_max values hold. + let mut window_end: DocId = leader.last_doc_in_block(); + + let mut secondary_block_max_sum: Score = 0.0; + for secondary in secondaries.iter_mut() { + secondary.seek_block(doc); + secondary_block_max_sum += secondary.block_max_score(); + window_end = window_end.min(secondary.last_doc_in_block()); + } + + if leader_block_max + secondary_block_max_sum <= threshold { + // The entire window cannot beat the threshold. Skip past it. + if window_end == TERMINATED { + return; + } + doc = window_end + 1; + continue; + } + + // --- Phase 2: Doc-level processing within the window --- + // + // Load the leader's block and iterate through its documents up to window_end. + doc = leader.seek(doc); + if doc == TERMINATED { + return; + } + + 'next_doc: while doc <= window_end { + let leader_score: Score = leader.score(); + + // Doc-level pruning: can leader_score + best possible secondary contribution + // beat the threshold? + if leader_score + secondary_block_max_sum <= threshold { + doc = leader.advance(); + if doc == TERMINATED { + return; + } + continue; + } + + // Check intersection membership in secondaries. + let mut total_score: Score = leader_score; + for secondary in secondaries.iter_mut() { + // seek() requires target >= self.doc(). If the secondary is already + // past `doc` from a previous seek, this doc is not in the intersection. + let secondary_doc = secondary.doc(); + let seek_result = if secondary_doc <= doc { + secondary.seek(doc) + } else { + secondary_doc + }; + if seek_result != doc { + doc = leader.advance(); + if doc == TERMINATED { + return; + } + continue 'next_doc; + } + total_score += secondary.score(); + } + + // All secondaries matched. + if total_score > threshold { + threshold = callback(doc, total_score); + + // Re-check global early termination after threshold update. + if leader.max_score() + secondaries_global_max_sum <= threshold { + return; + } + } + + doc = leader.advance(); + if doc == TERMINATED { + return; + } + } + // `doc` is now past window_end but not TERMINATED. + // Loop back to Phase 1 with this new doc. + } +} + +#[cfg(test)] +mod tests { + use std::cmp::Ordering; + use std::collections::BinaryHeap; + + use proptest::prelude::*; + + use crate::query::term_query::TermScorer; + use crate::query::{Bm25Weight, Scorer}; + use crate::{DocId, DocSet, Score, TERMINATED}; + + struct Float(Score); + + impl Eq for Float {} + + impl PartialEq for Float { + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == Ordering::Equal + } + } + + impl PartialOrd for Float { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + + impl Ord for Float { + fn cmp(&self, other: &Self) -> Ordering { + other.0.partial_cmp(&self.0).unwrap_or(Ordering::Equal) + } + } + + fn nearly_equals(left: Score, right: Score) -> bool { + (left - right).abs() < 0.0001 * (left + right).abs() + } + + /// Run block_wand_intersection and collect (doc, score) pairs above threshold. + fn compute_checkpoints_block_wand_intersection( + term_scorers: Vec, + top_k: usize, + ) -> Vec<(DocId, Score)> { + let mut heap: BinaryHeap = BinaryHeap::with_capacity(top_k); + let mut checkpoints: Vec<(DocId, Score)> = Vec::new(); + let mut limit: Score = 0.0; + + let callback = &mut |doc, score| { + heap.push(Float(score)); + if heap.len() > top_k { + heap.pop().unwrap(); + } + if heap.len() == top_k { + limit = heap.peek().unwrap().0; + } + if !nearly_equals(score, limit) { + checkpoints.push((doc, score)); + } + limit + }; + + super::block_wand_intersection(term_scorers, Score::MIN, callback); + checkpoints + } + + /// Naive baseline: intersect by iterating all docs. + fn compute_checkpoints_naive_intersection( + mut term_scorers: Vec, + top_k: usize, + ) -> Vec<(DocId, Score)> { + let mut heap: BinaryHeap = BinaryHeap::with_capacity(top_k); + let mut checkpoints: Vec<(DocId, Score)> = Vec::new(); + let mut limit = Score::MIN; + + // Sort by cost to use the cheapest as driver. + term_scorers.sort_by_key(|s| s.cost()); + + let (leader, secondaries) = term_scorers.split_first_mut().unwrap(); + + let mut doc = leader.doc(); + while doc != TERMINATED { + let mut all_match = true; + for secondary in secondaries.iter_mut() { + let secondary_doc = secondary.doc(); + let seek_result = if secondary_doc <= doc { + secondary.seek(doc) + } else { + secondary_doc + }; + if seek_result != doc { + all_match = false; + break; + } + } + + if all_match { + let score: Score = + leader.score() + secondaries.iter_mut().map(|s| s.score()).sum::(); + + if score > limit { + heap.push(Float(score)); + if heap.len() > top_k { + heap.pop().unwrap(); + } + if heap.len() == top_k { + limit = heap.peek().unwrap().0; + } + if !nearly_equals(score, limit) { + checkpoints.push((doc, score)); + } + } + } + doc = leader.advance(); + } + checkpoints + } + + const MAX_TERM_FREQ: u32 = 100u32; + + fn posting_list(max_doc: u32) -> BoxedStrategy> { + (1..max_doc + 1) + .prop_flat_map(move |doc_freq| { + ( + proptest::bits::bitset::sampled(doc_freq as usize, 0..max_doc as usize), + proptest::collection::vec(1u32..MAX_TERM_FREQ, doc_freq as usize), + ) + }) + .prop_map(|(docset, term_freqs)| { + docset + .iter() + .map(|doc| doc as u32) + .zip(term_freqs.iter().cloned()) + .collect::>() + }) + .boxed() + } + + #[expect(clippy::type_complexity)] + fn gen_term_scorers(num_scorers: usize) -> BoxedStrategy<(Vec>, Vec)> { + (1u32..100u32) + .prop_flat_map(move |max_doc: u32| { + ( + proptest::collection::vec(posting_list(max_doc), num_scorers), + proptest::collection::vec(2u32..10u32 * MAX_TERM_FREQ, max_doc as usize), + ) + }) + .boxed() + } + + fn test_block_wand_intersection_aux(posting_lists: &[Vec<(DocId, u32)>], fieldnorms: &[u32]) { + // Repeat docs 64 times to create multi-block scenarios, matching block_wand.rs test + // strategy. + const REPEAT: usize = 64; + let fieldnorms_expanded: Vec = fieldnorms + .iter() + .cloned() + .flat_map(|fieldnorm| std::iter::repeat_n(fieldnorm, REPEAT)) + .collect(); + + let postings_lists_expanded: Vec> = posting_lists + .iter() + .map(|posting_list| { + posting_list + .iter() + .cloned() + .flat_map(|(doc, term_freq)| { + (0_u32..REPEAT as u32).map(move |offset| { + ( + doc * (REPEAT as u32) + offset, + if offset == 0 { term_freq } else { 1 }, + ) + }) + }) + .collect::>() + }) + .collect(); + + let total_fieldnorms: u64 = fieldnorms_expanded + .iter() + .cloned() + .map(|fieldnorm| fieldnorm as u64) + .sum(); + let average_fieldnorm = (total_fieldnorms as Score) / (fieldnorms_expanded.len() as Score); + let max_doc = fieldnorms_expanded.len(); + + let make_scorers = || -> Vec { + postings_lists_expanded + .iter() + .map(|postings| { + let bm25_weight = Bm25Weight::for_one_term( + postings.len() as u64, + max_doc as u64, + average_fieldnorm, + ); + TermScorer::create_for_test(postings, &fieldnorms_expanded[..], bm25_weight) + }) + .collect() + }; + + for top_k in 1..4 { + let checkpoints_optimized = + compute_checkpoints_block_wand_intersection(make_scorers(), top_k); + let checkpoints_naive = compute_checkpoints_naive_intersection(make_scorers(), top_k); + assert_eq!( + checkpoints_optimized.len(), + checkpoints_naive.len(), + "Mismatch in checkpoint count for top_k={top_k}" + ); + for (&(left_doc, left_score), &(right_doc, right_score)) in + checkpoints_optimized.iter().zip(checkpoints_naive.iter()) + { + assert_eq!(left_doc, right_doc); + assert!( + nearly_equals(left_score, right_score), + "Score mismatch for doc {left_doc}: {left_score} vs {right_score}" + ); + } + } + } + + proptest! { + #![proptest_config(ProptestConfig::with_cases(500))] + #[test] + fn test_block_wand_intersection_two_scorers( + (posting_lists, fieldnorms) in gen_term_scorers(2) + ) { + test_block_wand_intersection_aux(&posting_lists[..], &fieldnorms[..]); + } + } + + proptest! { + #![proptest_config(ProptestConfig::with_cases(500))] + #[test] + fn test_block_wand_intersection_three_scorers( + (posting_lists, fieldnorms) in gen_term_scorers(3) + ) { + test_block_wand_intersection_aux(&posting_lists[..], &fieldnorms[..]); + } + } + + #[test] + fn test_block_wand_intersection_disjoint() { + // Two posting lists with no overlap — intersection is empty. + let fieldnorms: Vec = vec![10; 200]; + let average_fieldnorm = 10.0; + let postings_a: Vec<(DocId, u32)> = (0..100).map(|d| (d, 1)).collect(); + let postings_b: Vec<(DocId, u32)> = (100..200).map(|d| (d, 1)).collect(); + + let scorer_a = TermScorer::create_for_test( + &postings_a, + &fieldnorms, + Bm25Weight::for_one_term(100, 200, average_fieldnorm), + ); + let scorer_b = TermScorer::create_for_test( + &postings_b, + &fieldnorms, + Bm25Weight::for_one_term(100, 200, average_fieldnorm), + ); + + let checkpoints = compute_checkpoints_block_wand_intersection(vec![scorer_a, scorer_b], 10); + assert!(checkpoints.is_empty()); + } + + #[test] + fn test_block_wand_intersection_all_overlap() { + // Two posting lists with full overlap. + let fieldnorms: Vec = vec![10; 50]; + let average_fieldnorm = 10.0; + let postings: Vec<(DocId, u32)> = (0..50).map(|d| (d, 3)).collect(); + + let make_scorer = || { + TermScorer::create_for_test( + &postings, + &fieldnorms, + Bm25Weight::for_one_term(50, 50, average_fieldnorm), + ) + }; + + let checkpoints_opt = + compute_checkpoints_block_wand_intersection(vec![make_scorer(), make_scorer()], 5); + let checkpoints_naive = + compute_checkpoints_naive_intersection(vec![make_scorer(), make_scorer()], 5); + assert_eq!(checkpoints_opt.len(), checkpoints_naive.len()); + } +} diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index 062449b8a..0df4f9abe 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -16,6 +16,7 @@ use crate::{DocId, Score}; enum SpecializedScorer { TermUnion(Vec), + TermIntersection(Vec), Other(Box), } @@ -93,6 +94,13 @@ fn into_box_scorer( BufferedUnionScorer::build(term_scorers, score_combiner_fn, num_docs); Box::new(union_scorer) } + SpecializedScorer::TermIntersection(term_scorers) => { + let boxed_scorers: Vec> = term_scorers + .into_iter() + .map(|s| Box::new(s) as Box) + .collect(); + intersect_scorers(boxed_scorers, num_docs) + } SpecializedScorer::Other(scorer) => scorer, } } @@ -297,14 +305,43 @@ impl BooleanWeight { // Result depends entirely on MUST + any removed AllScorers. let combined_all_scorer_count = must_special_scorer_counts.num_all_scorers + should_special_scorer_counts.num_all_scorers; - let boxed_scorer: Box = effective_must_scorer( - must_scorers, - combined_all_scorer_count, - reader.max_doc(), - num_docs, - ) - .unwrap_or_else(|| Box::new(EmptyScorer)); - SpecializedScorer::Other(boxed_scorer) + + // Try to detect a pure TermScorer intersection for block-max optimization. + // Preconditions: no removed AllScorers, at least 2 scorers, all TermScorer + // with frequency reading enabled. + if combined_all_scorer_count == 0 + && must_scorers.len() >= 2 + && must_scorers.iter().all(|s| s.is::()) + { + let term_scorers: Vec = must_scorers + .into_iter() + .map(|s| *(s.downcast::().map_err(|_| ()).unwrap())) + .collect(); + if term_scorers + .iter() + .all(|s| s.freq_reading_option() == FreqReadingOption::ReadFreq) + { + SpecializedScorer::TermIntersection(term_scorers) + } else { + let must_scorers: Vec> = term_scorers + .into_iter() + .map(|s| Box::new(s) as Box) + .collect(); + let boxed_scorer: Box = + effective_must_scorer(must_scorers, 0, reader.max_doc(), num_docs) + .unwrap_or_else(|| Box::new(EmptyScorer)); + SpecializedScorer::Other(boxed_scorer) + } + } else { + let boxed_scorer: Box = effective_must_scorer( + must_scorers, + combined_all_scorer_count, + reader.max_doc(), + num_docs, + ) + .unwrap_or_else(|| Box::new(EmptyScorer)); + SpecializedScorer::Other(boxed_scorer) + } } (ShouldScorersCombinationMethod::Optional(should_scorer), must_scorers) => { // Optional SHOULD: contributes to scoring but not required for matching. @@ -463,15 +500,21 @@ impl Weight for BooleanWeight crate::Result<()> { let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?; + let num_docs = reader.num_docs(); match scorer { SpecializedScorer::TermUnion(term_scorers) => { - let mut union_scorer = BufferedUnionScorer::build( - term_scorers, - &self.score_combiner_fn, - reader.num_docs(), - ); + let mut union_scorer = + BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn, num_docs); for_each_scorer(&mut union_scorer, callback); } + SpecializedScorer::TermIntersection(term_scorers) => { + let mut intersection = into_box_scorer( + SpecializedScorer::TermIntersection(term_scorers), + &self.score_combiner_fn, + num_docs, + ); + for_each_scorer(intersection.as_mut(), callback); + } SpecializedScorer::Other(mut scorer) => { for_each_scorer(scorer.as_mut(), callback); } @@ -485,17 +528,23 @@ impl Weight for BooleanWeight crate::Result<()> { let scorer = self.complex_scorer(reader, 1.0, || DoNothingCombiner)?; + let num_docs = reader.num_docs(); let mut buffer = [0u32; COLLECT_BLOCK_BUFFER_LEN]; match scorer { SpecializedScorer::TermUnion(term_scorers) => { - let mut union_scorer = BufferedUnionScorer::build( - term_scorers, - &self.score_combiner_fn, - reader.num_docs(), - ); + let mut union_scorer = + BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn, num_docs); for_each_docset_buffered(&mut union_scorer, &mut buffer, callback); } + SpecializedScorer::TermIntersection(term_scorers) => { + let mut intersection = into_box_scorer( + SpecializedScorer::TermIntersection(term_scorers), + DoNothingCombiner::default, + num_docs, + ); + for_each_docset_buffered(intersection.as_mut(), &mut buffer, callback); + } SpecializedScorer::Other(mut scorer) => { for_each_docset_buffered(scorer.as_mut(), &mut buffer, callback); } @@ -524,6 +573,9 @@ impl Weight for BooleanWeight { super::block_wand(term_scorers, threshold, callback); } + SpecializedScorer::TermIntersection(term_scorers) => { + super::block_wand_intersection(term_scorers, threshold, callback); + } SpecializedScorer::Other(mut scorer) => { for_each_pruning_scorer(scorer.as_mut(), threshold, callback); } diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index 681881c11..63faa4f57 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -1,8 +1,10 @@ mod block_wand; +mod block_wand_intersection; mod boolean_query; mod boolean_weight; pub(crate) use self::block_wand::{block_wand, block_wand_single_scorer}; +pub(crate) use self::block_wand_intersection::block_wand_intersection; pub use self::boolean_query::BooleanQuery; pub use self::boolean_weight::BooleanWeight;