diff --git a/Cargo.toml b/Cargo.toml index ee308b842..d97737f26 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -201,3 +201,7 @@ harness = false [[bench]] name = "regex_all_terms" harness = false + +[[bench]] +name = "intersection_bench" +harness = false diff --git a/benches/intersection_bench.rs b/benches/intersection_bench.rs new file mode 100644 index 000000000..b0051908c --- /dev/null +++ b/benches/intersection_bench.rs @@ -0,0 +1,149 @@ +// Benchmarks top-K intersection of term scorers (block_wand_intersection). +// +// What's measured: +// - Conjunctive queries (+a +b, +a +b +c) with top-10 by score +// - Varying doc-frequency balance between terms (balanced, skewed, very skewed) +// - Realistic term frequencies (geometric distribution, mostly low) +// - 1M-doc single segment +// +// Run with: cargo bench --bench intersection_bench + +use binggan::{black_box, BenchRunner}; +use rand::prelude::*; +use rand::rngs::StdRng; +use rand::SeedableRng; +use tantivy::collector::TopDocs; +use tantivy::query::QueryParser; +use tantivy::schema::{Schema, TEXT}; +use tantivy::{doc, Index, ReloadPolicy, Searcher}; + +const NUM_DOCS: usize = 1_000_000; + +struct BenchIndex { + searcher: Searcher, + query_parser: QueryParser, +} + +/// Generate term frequency from a geometric-like distribution. +/// Most values are 1, a few are 2-3, rarely higher. +/// p controls the decay: higher p → more weight on tf=1. +fn random_term_freq(rng: &mut StdRng, p: f64) -> u32 { + let mut tf = 1u32; + while tf < 10 && rng.random_bool(1.0 - p) { + tf += 1; + } + tf +} + +/// Build an index with three terms (a, b, c) with given doc-frequency probabilities. +/// Each term occurrence has a realistic term frequency (geometric distribution). +/// Field length is padded with filler tokens to create varied fieldnorms. +fn build_index(p_a: f64, p_b: f64, p_c: f64) -> BenchIndex { + let mut schema_builder = Schema::builder(); + let body = schema_builder.add_text_field("body", TEXT); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema); + + let mut rng = StdRng::from_seed([42u8; 32]); + + { + let mut writer = index.writer_with_num_threads(1, 500_000_000).unwrap(); + for _ in 0..NUM_DOCS { + let mut tokens: Vec = Vec::new(); + + if rng.random_bool(p_a) { + let tf = random_term_freq(&mut rng, 0.7); + for _ in 0..tf { + tokens.push("aaa".to_string()); + } + } + if rng.random_bool(p_b) { + let tf = random_term_freq(&mut rng, 0.7); + for _ in 0..tf { + tokens.push("bbb".to_string()); + } + } + if rng.random_bool(p_c) { + let tf = random_term_freq(&mut rng, 0.7); + for _ in 0..tf { + tokens.push("ccc".to_string()); + } + } + + // Pad with filler to create varied field lengths (5-30 tokens). + let filler_count = rng.random_range(5u32..30u32); + for _ in 0..filler_count { + tokens.push("filler".to_string()); + } + + let text = tokens.join(" "); + writer.add_document(doc!(body => text)).unwrap(); + } + writer.commit().unwrap(); + } + + let reader = index + .reader_builder() + .reload_policy(ReloadPolicy::Manual) + .try_into() + .unwrap(); + let searcher = reader.searcher(); + let query_parser = QueryParser::for_index(&index, vec![body]); + + BenchIndex { + searcher, + query_parser, + } +} + +fn main() { + // Scenarios: (label, p_a, p_b, p_c) + // + // "balanced": all terms ~10% → intersection ~1% of docs + // "skewed": one common (50%), one rare (2%) → intersection ~1% + // "very_skewed": one very common (80%), one very rare (0.5%) → intersection ~0.4% + // "three_balanced": three terms ~20% each → intersection ~0.8% + // "three_skewed": 50% / 10% / 2% → intersection ~0.1% + let scenarios: Vec<(&str, f64, f64, f64)> = vec![ + ("balanced_10%_10%", 0.10, 0.10, 0.0), + ("skewed_50%_2%", 0.50, 0.02, 0.0), + ("very_skewed_80%_0.5%", 0.80, 0.005, 0.0), + ("three_balanced_20%_20%_20%", 0.20, 0.20, 0.20), + ("three_skewed_50%_10%_2%", 0.50, 0.10, 0.02), + ]; + + let mut runner = BenchRunner::new(); + + for (label, p_a, p_b, p_c) in &scenarios { + let bench_index = build_index(*p_a, *p_b, *p_c); + + let mut group = runner.new_group(); + group.set_name(format!("intersection — {label}")); + + // Two-term intersection + if *p_a > 0.0 && *p_b > 0.0 { + let query_str = "+aaa +bbb"; + let query = bench_index.query_parser.parse_query(query_str).unwrap(); + let searcher = bench_index.searcher.clone(); + group.register(format!("{query_str} top10"), move |_| { + let collector = TopDocs::with_limit(10).order_by_score(); + black_box(searcher.search(&query, &collector).unwrap()); + 1usize + }); + } + + // Three-term intersection + if *p_c > 0.0 { + let query_str = "+aaa +bbb +ccc"; + let query = bench_index.query_parser.parse_query(query_str).unwrap(); + let searcher = bench_index.searcher.clone(); + group.register(format!("{query_str} top10"), move |_| { + let collector = TopDocs::with_limit(10).order_by_score(); + black_box(searcher.search(&query, &collector).unwrap()); + 1usize + }); + } + + group.run(); + } +} diff --git a/src/postings/block_segment_postings.rs b/src/postings/block_segment_postings.rs index 47ace9975..0596ef58e 100644 --- a/src/postings/block_segment_postings.rs +++ b/src/postings/block_segment_postings.rs @@ -25,7 +25,7 @@ fn max_score>(mut it: I) -> Option { pub struct BlockSegmentPostings { pub(crate) doc_decoder: BlockDecoder, block_loaded: bool, - freq_decoder: BlockDecoder, + pub(crate) freq_decoder: BlockDecoder, freq_reading_option: FreqReadingOption, block_max_score_cache: Option, doc_freq: u32, @@ -291,11 +291,13 @@ impl BlockSegmentPostings { /// `.load_block()` needs to be called manually afterwards. /// If all docs are smaller than target, the block loaded may be empty, /// or be the last an incomplete VInt block. - pub(crate) fn seek_block(&mut self, target_doc: DocId) { + #[inline] + pub(crate) fn seek_block(&mut self, target_doc: DocId) -> bool { if self.skip_reader.seek(target_doc) { self.block_max_score_cache = None; self.block_loaded = false; } + self.skip_reader.remaining_docs != 0 } pub(crate) fn block_is_loaded(&self) -> bool { diff --git a/src/postings/skip.rs b/src/postings/skip.rs index 3900fd40e..27c8a8723 100644 --- a/src/postings/skip.rs +++ b/src/postings/skip.rs @@ -96,7 +96,7 @@ pub(crate) struct SkipReader { owned_read: OwnedBytes, skip_info: IndexRecordOption, byte_offset: usize, - remaining_docs: u32, // number of docs remaining, including the + pub remaining_docs: u32, // number of docs remaining, including the // documents in the current block. block_info: BlockInfo, diff --git a/src/query/boolean_query/block_wand_intersection.rs b/src/query/boolean_query/block_wand_intersection.rs new file mode 100644 index 000000000..5617b2e76 --- /dev/null +++ b/src/query/boolean_query/block_wand_intersection.rs @@ -0,0 +1,435 @@ +use crate::postings::compression::COMPRESSION_BLOCK_SIZE; +use crate::query::term_query::TermScorer; +use crate::query::Scorer; +use crate::{DocId, DocSet, Score, TERMINATED}; + +/// Block-max pruning for top-K over intersection of term scorers. +/// +/// Uses the least-frequent term as "leader" to define 128-doc processing windows. +/// For each window, the sum of block_max_scores is compared to the current threshold; +/// if the block can't beat it, the entire block is skipped. +/// +/// Within non-skipped blocks, individual documents are pruned by checking whether +/// leader_score + sum(secondary block_max_scores) can exceed the threshold before +/// performing the expensive intersection membership check (seeking into secondary scorers). +/// +/// # Preconditions +/// - `scorers` has at least 2 elements +/// - All scorers read frequencies (`FreqReadingOption::ReadFreq`) +pub fn block_wand_intersection( + mut scorers: Vec, + mut threshold: Score, + callback: &mut dyn FnMut(DocId, Score) -> Score, +) { + assert!(scorers.len() >= 2); + + // Sort by cost (ascending). scorers[0] becomes the "leader" (rarest term). + scorers.sort_by_key(TermScorer::size_hint); + + let (leader, secondaries) = scorers.split_first_mut().unwrap(); + + // Precompute global max scores for early termination checks. + let leader_max_score: Score = leader.max_score(); + let secondaries_global_max_sum: Score = secondaries.iter().map(TermScorer::max_score).sum(); + + // Early exit: no document can possibly beat the threshold. + if leader_max_score + secondaries_global_max_sum <= threshold { + return; + } + + // Borrow fieldnorm reader and BM25 weight before the main loop. + // These are immutable references to disjoint fields from block_cursor, + // but Rust's borrow checker can't see through method calls, so we + // extract them once upfront. + let fieldnorm_reader = leader.fieldnorm_reader().clone(); + let bm25_weight = leader.bm25_weight().clone(); + + let mut doc = leader.doc(); + + while doc < TERMINATED { + // --- Phase 1: Block-level pruning --- + // + // Position all skip readers on the block containing `doc`. + // seek_block is cheap: it only advances the skip reader, no block decompression. + leader.seek_block(doc); + let leader_block_max: Score = leader.block_max_score(); + + // Compute the window end as the minimum last_doc_in_block across all scorers. + // This ensures the block_max values are valid for all docs in [doc, window_end]. + // Different scorers have independently aligned blocks, so we must use the + // smallest window where all block_max values hold. + let mut window_end: DocId = leader.last_doc_in_block(); + + let mut secondary_block_max_sum: Score = 0.0; + for secondary in secondaries.iter_mut() { + if !secondary.block_cursor().seek_block(doc) { + return; + } + window_end = window_end.min(secondary.last_doc_in_block()); + secondary_block_max_sum += secondary.block_max_score(); + } + + if leader_block_max + secondary_block_max_sum <= threshold { + // The entire window cannot beat the threshold. Skip past it. + doc = window_end + 1; + continue; + } + + // --- Phase 2: Batch processing within the window --- + // + // Score-first approach: decode the leader's block, filter by threshold, + // then check intersection membership only for survivors. This avoids expensive + // secondary seeks for docs that can't beat the threshold. + let block_cursor = leader.block_cursor(); + // seek loads the block and returns the in-block index of the first doc >= `doc`. + let start_idx = block_cursor.seek(doc); + + // Use the branchless binary search on the doc decoder to find the first + // index past window_end. + let end_idx = block_cursor + .doc_decoder + .seek_within_block(window_end + 1) + .min(block_cursor.block_len()); + + let block_docs = &block_cursor.doc_decoder.output_array()[start_idx..end_idx]; + let block_freqs = &block_cursor.freq_decoder.output_array()[start_idx..end_idx]; + + // Pass 1: Batch-compute leader BM25 scores and branchlessly filter + // candidates that can't beat the threshold. + // + // The trick: always write to the buffer at `num_candidates`, then + // conditionally advance the count. The compiler can turn this into + // a cmov instead of a branch, avoiding misprediction costs. + let score_threshold = threshold - secondary_block_max_sum; + let mut candidate_doc_ids = [0u32; COMPRESSION_BLOCK_SIZE]; + let mut candidate_scores = [0.0f32; COMPRESSION_BLOCK_SIZE]; + let mut num_candidates = 0usize; + + for (candidate_doc, term_freq) in + block_docs.iter().copied().zip(block_freqs.iter().copied()) + { + let fieldnorm_id = fieldnorm_reader.fieldnorm_id(candidate_doc); + let leader_score = bm25_weight.score(fieldnorm_id, term_freq); + candidate_doc_ids[num_candidates] = candidate_doc; + candidate_scores[num_candidates] = leader_score; + num_candidates += (leader_score > score_threshold) as usize; + } + + // Pass 2: Check intersection membership only for survivors. + // score_threshold may be stale (threshold can increase from callbacks), + // but that's conservative — we may check a few extra candidates, never miss one. + 'next_candidate: for candidate_idx in 0..num_candidates { + let candidate_doc = candidate_doc_ids[candidate_idx]; + let mut total_score: Score = candidate_scores[candidate_idx]; + + for secondary in secondaries.iter_mut() { + // If a previous candidate already advanced this secondary past + // candidate_doc, the candidate can't be in the intersection. + if secondary.doc() > candidate_doc { + continue 'next_candidate; + } + let seek_result = secondary.seek(candidate_doc); + if seek_result != candidate_doc { + continue 'next_candidate; + } + total_score += secondary.score(); + } + + // All secondaries matched. + if total_score > threshold { + threshold = callback(candidate_doc, total_score); + + if leader_max_score + secondaries_global_max_sum <= threshold { + return; + } + } + } + + doc = window_end + 1; + } +} + +#[cfg(test)] +mod tests { + use std::cmp::Ordering; + use std::collections::BinaryHeap; + + use proptest::prelude::*; + + use crate::query::term_query::TermScorer; + use crate::query::{Bm25Weight, Scorer}; + use crate::{DocId, DocSet, Score, TERMINATED}; + + struct Float(Score); + + impl Eq for Float {} + + impl PartialEq for Float { + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == Ordering::Equal + } + } + + impl PartialOrd for Float { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + + impl Ord for Float { + fn cmp(&self, other: &Self) -> Ordering { + other.0.partial_cmp(&self.0).unwrap_or(Ordering::Equal) + } + } + + fn nearly_equals(left: Score, right: Score) -> bool { + (left - right).abs() < 0.0001 * (left + right).abs() + } + + /// Run block_wand_intersection and collect (doc, score) pairs above threshold. + fn compute_checkpoints_block_wand_intersection( + term_scorers: Vec, + top_k: usize, + ) -> Vec<(DocId, Score)> { + let mut heap: BinaryHeap = BinaryHeap::with_capacity(top_k); + let mut checkpoints: Vec<(DocId, Score)> = Vec::new(); + let mut limit: Score = 0.0; + + let callback = &mut |doc, score| { + heap.push(Float(score)); + if heap.len() > top_k { + heap.pop().unwrap(); + } + if heap.len() == top_k { + limit = heap.peek().unwrap().0; + } + if !nearly_equals(score, limit) { + checkpoints.push((doc, score)); + } + limit + }; + + super::block_wand_intersection(term_scorers, Score::MIN, callback); + checkpoints + } + + /// Naive baseline: intersect by iterating all docs. + fn compute_checkpoints_naive_intersection( + mut term_scorers: Vec, + top_k: usize, + ) -> Vec<(DocId, Score)> { + let mut heap: BinaryHeap = BinaryHeap::with_capacity(top_k); + let mut checkpoints: Vec<(DocId, Score)> = Vec::new(); + let mut limit = Score::MIN; + + // Sort by cost to use the cheapest as driver. + term_scorers.sort_by_key(|s| s.cost()); + + let (leader, secondaries) = term_scorers.split_first_mut().unwrap(); + + let mut doc = leader.doc(); + while doc != TERMINATED { + let mut all_match = true; + for secondary in secondaries.iter_mut() { + let secondary_doc = secondary.doc(); + let seek_result = if secondary_doc <= doc { + secondary.seek(doc) + } else { + secondary_doc + }; + if seek_result != doc { + all_match = false; + break; + } + } + + if all_match { + let score: Score = + leader.score() + secondaries.iter_mut().map(|s| s.score()).sum::(); + + if score > limit { + heap.push(Float(score)); + if heap.len() > top_k { + heap.pop().unwrap(); + } + if heap.len() == top_k { + limit = heap.peek().unwrap().0; + } + if !nearly_equals(score, limit) { + checkpoints.push((doc, score)); + } + } + } + doc = leader.advance(); + } + checkpoints + } + + const MAX_TERM_FREQ: u32 = 100u32; + + fn posting_list(max_doc: u32) -> BoxedStrategy> { + (1..max_doc + 1) + .prop_flat_map(move |doc_freq| { + ( + proptest::bits::bitset::sampled(doc_freq as usize, 0..max_doc as usize), + proptest::collection::vec(1u32..MAX_TERM_FREQ, doc_freq as usize), + ) + }) + .prop_map(|(docset, term_freqs)| { + docset + .iter() + .map(|doc| doc as u32) + .zip(term_freqs.iter().cloned()) + .collect::>() + }) + .boxed() + } + + #[expect(clippy::type_complexity)] + fn gen_term_scorers(num_scorers: usize) -> BoxedStrategy<(Vec>, Vec)> { + (1u32..100u32) + .prop_flat_map(move |max_doc: u32| { + ( + proptest::collection::vec(posting_list(max_doc), num_scorers), + proptest::collection::vec(2u32..10u32 * MAX_TERM_FREQ, max_doc as usize), + ) + }) + .boxed() + } + + fn test_block_wand_intersection_aux(posting_lists: &[Vec<(DocId, u32)>], fieldnorms: &[u32]) { + // Repeat docs 64 times to create multi-block scenarios, matching block_wand.rs test + // strategy. + const REPEAT: usize = 64; + let fieldnorms_expanded: Vec = fieldnorms + .iter() + .cloned() + .flat_map(|fieldnorm| std::iter::repeat_n(fieldnorm, REPEAT)) + .collect(); + + let postings_lists_expanded: Vec> = posting_lists + .iter() + .map(|posting_list| { + posting_list + .iter() + .cloned() + .flat_map(|(doc, term_freq)| { + (0_u32..REPEAT as u32).map(move |offset| { + ( + doc * (REPEAT as u32) + offset, + if offset == 0 { term_freq } else { 1 }, + ) + }) + }) + .collect::>() + }) + .collect(); + + let total_fieldnorms: u64 = fieldnorms_expanded + .iter() + .cloned() + .map(|fieldnorm| fieldnorm as u64) + .sum(); + let average_fieldnorm = (total_fieldnorms as Score) / (fieldnorms_expanded.len() as Score); + let max_doc = fieldnorms_expanded.len(); + + let make_scorers = || -> Vec { + postings_lists_expanded + .iter() + .map(|postings| { + let bm25_weight = Bm25Weight::for_one_term( + postings.len() as u64, + max_doc as u64, + average_fieldnorm, + ); + TermScorer::create_for_test(postings, &fieldnorms_expanded[..], bm25_weight) + }) + .collect() + }; + + for top_k in 1..4 { + let checkpoints_optimized = + compute_checkpoints_block_wand_intersection(make_scorers(), top_k); + let checkpoints_naive = compute_checkpoints_naive_intersection(make_scorers(), top_k); + assert_eq!( + checkpoints_optimized.len(), + checkpoints_naive.len(), + "Mismatch in checkpoint count for top_k={top_k}" + ); + for (&(left_doc, left_score), &(right_doc, right_score)) in + checkpoints_optimized.iter().zip(checkpoints_naive.iter()) + { + assert_eq!(left_doc, right_doc); + assert!( + nearly_equals(left_score, right_score), + "Score mismatch for doc {left_doc}: {left_score} vs {right_score}" + ); + } + } + } + + proptest! { + #![proptest_config(ProptestConfig::with_cases(500))] + #[test] + fn test_block_wand_intersection_two_scorers( + (posting_lists, fieldnorms) in gen_term_scorers(2) + ) { + test_block_wand_intersection_aux(&posting_lists[..], &fieldnorms[..]); + } + } + + proptest! { + #![proptest_config(ProptestConfig::with_cases(500))] + #[test] + fn test_block_wand_intersection_three_scorers( + (posting_lists, fieldnorms) in gen_term_scorers(3) + ) { + test_block_wand_intersection_aux(&posting_lists[..], &fieldnorms[..]); + } + } + + #[test] + fn test_block_wand_intersection_disjoint() { + // Two posting lists with no overlap — intersection is empty. + let fieldnorms: Vec = vec![10; 200]; + let average_fieldnorm = 10.0; + let postings_a: Vec<(DocId, u32)> = (0..100).map(|d| (d, 1)).collect(); + let postings_b: Vec<(DocId, u32)> = (100..200).map(|d| (d, 1)).collect(); + + let scorer_a = TermScorer::create_for_test( + &postings_a, + &fieldnorms, + Bm25Weight::for_one_term(100, 200, average_fieldnorm), + ); + let scorer_b = TermScorer::create_for_test( + &postings_b, + &fieldnorms, + Bm25Weight::for_one_term(100, 200, average_fieldnorm), + ); + + let checkpoints = compute_checkpoints_block_wand_intersection(vec![scorer_a, scorer_b], 10); + assert!(checkpoints.is_empty()); + } + + #[test] + fn test_block_wand_intersection_all_overlap() { + // Two posting lists with full overlap. + let fieldnorms: Vec = vec![10; 50]; + let average_fieldnorm = 10.0; + let postings: Vec<(DocId, u32)> = (0..50).map(|d| (d, 3)).collect(); + + let make_scorer = || { + TermScorer::create_for_test( + &postings, + &fieldnorms, + Bm25Weight::for_one_term(50, 50, average_fieldnorm), + ) + }; + + let checkpoints_opt = + compute_checkpoints_block_wand_intersection(vec![make_scorer(), make_scorer()], 5); + let checkpoints_naive = + compute_checkpoints_naive_intersection(vec![make_scorer(), make_scorer()], 5); + assert_eq!(checkpoints_opt.len(), checkpoints_naive.len()); + } +} diff --git a/src/query/boolean_query/block_wand.rs b/src/query/boolean_query/block_wand_union.rs similarity index 100% rename from src/query/boolean_query/block_wand.rs rename to src/query/boolean_query/block_wand_union.rs diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs index 062449b8a..0df4f9abe 100644 --- a/src/query/boolean_query/boolean_weight.rs +++ b/src/query/boolean_query/boolean_weight.rs @@ -16,6 +16,7 @@ use crate::{DocId, Score}; enum SpecializedScorer { TermUnion(Vec), + TermIntersection(Vec), Other(Box), } @@ -93,6 +94,13 @@ fn into_box_scorer( BufferedUnionScorer::build(term_scorers, score_combiner_fn, num_docs); Box::new(union_scorer) } + SpecializedScorer::TermIntersection(term_scorers) => { + let boxed_scorers: Vec> = term_scorers + .into_iter() + .map(|s| Box::new(s) as Box) + .collect(); + intersect_scorers(boxed_scorers, num_docs) + } SpecializedScorer::Other(scorer) => scorer, } } @@ -297,14 +305,43 @@ impl BooleanWeight { // Result depends entirely on MUST + any removed AllScorers. let combined_all_scorer_count = must_special_scorer_counts.num_all_scorers + should_special_scorer_counts.num_all_scorers; - let boxed_scorer: Box = effective_must_scorer( - must_scorers, - combined_all_scorer_count, - reader.max_doc(), - num_docs, - ) - .unwrap_or_else(|| Box::new(EmptyScorer)); - SpecializedScorer::Other(boxed_scorer) + + // Try to detect a pure TermScorer intersection for block-max optimization. + // Preconditions: no removed AllScorers, at least 2 scorers, all TermScorer + // with frequency reading enabled. + if combined_all_scorer_count == 0 + && must_scorers.len() >= 2 + && must_scorers.iter().all(|s| s.is::()) + { + let term_scorers: Vec = must_scorers + .into_iter() + .map(|s| *(s.downcast::().map_err(|_| ()).unwrap())) + .collect(); + if term_scorers + .iter() + .all(|s| s.freq_reading_option() == FreqReadingOption::ReadFreq) + { + SpecializedScorer::TermIntersection(term_scorers) + } else { + let must_scorers: Vec> = term_scorers + .into_iter() + .map(|s| Box::new(s) as Box) + .collect(); + let boxed_scorer: Box = + effective_must_scorer(must_scorers, 0, reader.max_doc(), num_docs) + .unwrap_or_else(|| Box::new(EmptyScorer)); + SpecializedScorer::Other(boxed_scorer) + } + } else { + let boxed_scorer: Box = effective_must_scorer( + must_scorers, + combined_all_scorer_count, + reader.max_doc(), + num_docs, + ) + .unwrap_or_else(|| Box::new(EmptyScorer)); + SpecializedScorer::Other(boxed_scorer) + } } (ShouldScorersCombinationMethod::Optional(should_scorer), must_scorers) => { // Optional SHOULD: contributes to scoring but not required for matching. @@ -463,15 +500,21 @@ impl Weight for BooleanWeight crate::Result<()> { let scorer = self.complex_scorer(reader, 1.0, &self.score_combiner_fn)?; + let num_docs = reader.num_docs(); match scorer { SpecializedScorer::TermUnion(term_scorers) => { - let mut union_scorer = BufferedUnionScorer::build( - term_scorers, - &self.score_combiner_fn, - reader.num_docs(), - ); + let mut union_scorer = + BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn, num_docs); for_each_scorer(&mut union_scorer, callback); } + SpecializedScorer::TermIntersection(term_scorers) => { + let mut intersection = into_box_scorer( + SpecializedScorer::TermIntersection(term_scorers), + &self.score_combiner_fn, + num_docs, + ); + for_each_scorer(intersection.as_mut(), callback); + } SpecializedScorer::Other(mut scorer) => { for_each_scorer(scorer.as_mut(), callback); } @@ -485,17 +528,23 @@ impl Weight for BooleanWeight crate::Result<()> { let scorer = self.complex_scorer(reader, 1.0, || DoNothingCombiner)?; + let num_docs = reader.num_docs(); let mut buffer = [0u32; COLLECT_BLOCK_BUFFER_LEN]; match scorer { SpecializedScorer::TermUnion(term_scorers) => { - let mut union_scorer = BufferedUnionScorer::build( - term_scorers, - &self.score_combiner_fn, - reader.num_docs(), - ); + let mut union_scorer = + BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn, num_docs); for_each_docset_buffered(&mut union_scorer, &mut buffer, callback); } + SpecializedScorer::TermIntersection(term_scorers) => { + let mut intersection = into_box_scorer( + SpecializedScorer::TermIntersection(term_scorers), + DoNothingCombiner::default, + num_docs, + ); + for_each_docset_buffered(intersection.as_mut(), &mut buffer, callback); + } SpecializedScorer::Other(mut scorer) => { for_each_docset_buffered(scorer.as_mut(), &mut buffer, callback); } @@ -524,6 +573,9 @@ impl Weight for BooleanWeight { super::block_wand(term_scorers, threshold, callback); } + SpecializedScorer::TermIntersection(term_scorers) => { + super::block_wand_intersection(term_scorers, threshold, callback); + } SpecializedScorer::Other(mut scorer) => { for_each_pruning_scorer(scorer.as_mut(), threshold, callback); } diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index 681881c11..e0aada767 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -1,8 +1,10 @@ -mod block_wand; +mod block_wand_intersection; +mod block_wand_union; mod boolean_query; mod boolean_weight; -pub(crate) use self::block_wand::{block_wand, block_wand_single_scorer}; +pub(crate) use self::block_wand_intersection::block_wand_intersection; +pub(crate) use self::block_wand_union::{block_wand, block_wand_single_scorer}; pub use self::boolean_query::BooleanQuery; pub use self::boolean_weight::BooleanWeight; diff --git a/src/query/term_query/term_scorer.rs b/src/query/term_query/term_scorer.rs index a75648348..9cdadd7d4 100644 --- a/src/query/term_query/term_scorer.rs +++ b/src/query/term_query/term_scorer.rs @@ -1,6 +1,6 @@ use crate::docset::DocSet; use crate::fieldnorm::FieldNormReader; -use crate::postings::{FreqReadingOption, Postings, SegmentPostings}; +use crate::postings::{BlockSegmentPostings, FreqReadingOption, Postings, SegmentPostings}; use crate::query::bm25::Bm25Weight; use crate::query::{Explanation, Scorer}; use crate::{DocId, Score}; @@ -95,6 +95,21 @@ impl TermScorer { pub fn last_doc_in_block(&self) -> DocId { self.postings.block_cursor.skip_reader().last_doc_in_block() } + + /// Returns a mutable reference to the underlying block cursor. + pub(crate) fn block_cursor(&mut self) -> &mut BlockSegmentPostings { + &mut self.postings.block_cursor + } + + /// Returns a reference to the fieldnorm reader for batch lookups. + pub(crate) fn fieldnorm_reader(&self) -> &FieldNormReader { + &self.fieldnorm_reader + } + + /// Returns a reference to the BM25 weight for batch score computation. + pub(crate) fn bm25_weight(&self) -> &Bm25Weight { + &self.similarity_weight + } } impl DocSet for TermScorer {