From bf27a7b3a40dc4e77b0433f5f2aaf7df95ae2029 Mon Sep 17 00:00:00 2001 From: Paul Masurel Date: Tue, 12 May 2020 16:14:23 +0900 Subject: [PATCH] tried simplifying intersection code. --- src/postings/mod.rs | 8 +- src/query/boolean_query/mod.rs | 24 --- src/query/intersection.rs | 214 ++++-------------------- src/query/phrase_query/phrase_scorer.rs | 2 +- src/query/union.rs | 25 ++- 5 files changed, 62 insertions(+), 211 deletions(-) diff --git a/src/postings/mod.rs b/src/postings/mod.rs index b66beb413..c732f17ae 100644 --- a/src/postings/mod.rs +++ b/src/postings/mod.rs @@ -673,10 +673,10 @@ mod bench { .read_postings(&*TERM_D, IndexRecordOption::Basic) .unwrap(); let mut intersection = Intersection::new(vec![ - segment_postings_a, - segment_postings_b, - segment_postings_c, - segment_postings_d, + segment_postings_a.into(), + segment_postings_b.into(), + segment_postings_c.into(), + segment_postings_d.into(), ]); while intersection.advance() {} }); diff --git a/src/query/boolean_query/mod.rs b/src/query/boolean_query/mod.rs index 1c6a341ef..7c5977666 100644 --- a/src/query/boolean_query/mod.rs +++ b/src/query/boolean_query/mod.rs @@ -10,7 +10,6 @@ mod tests { use crate::collector::tests::TEST_COLLECTOR_WITH_SCORE; use crate::query::score_combiner::SumWithCoordsCombiner; use crate::query::term_query::TermScorer; - use crate::query::Intersection; use crate::query::Occur; use crate::query::Query; use crate::query::QueryParser; @@ -64,29 +63,6 @@ mod tests { assert!(scorer.is::()); } - #[test] - pub fn test_boolean_termonly_intersection() { - let (index, text_field) = aux_test_helper(); - let query_parser = QueryParser::for_index(&index, vec![text_field]); - let searcher = index.reader().unwrap().searcher(); - { - let query = query_parser.parse_query("+a +b +c").unwrap(); - let weight = query.weight(&searcher, true).unwrap(); - let scorer = weight - .scorer(searcher.segment_reader(0u32), 1.0f32) - .unwrap(); - assert!(scorer.is::>()); - } - { - let query = query_parser.parse_query("+a +(b c)").unwrap(); - let weight = query.weight(&searcher, true).unwrap(); - let scorer = weight - .scorer(searcher.segment_reader(0u32), 1.0f32) - .unwrap(); - assert!(scorer.is::>>()); - } - } - #[test] pub fn test_boolean_reqopt() { let (index, text_field) = aux_test_helper(); diff --git a/src/query/intersection.rs b/src/query/intersection.rs index 9140f7fd5..dda5aacf9 100644 --- a/src/query/intersection.rs +++ b/src/query/intersection.rs @@ -1,5 +1,4 @@ use crate::docset::{DocSet, SkipResult}; -use crate::query::term_query::TermScorer; use crate::query::EmptyScorer; use crate::query::Scorer; use crate::DocId; @@ -21,207 +20,68 @@ pub fn intersect_scorers(mut scorers: Vec>) -> Box { return scorers.pop().unwrap(); } // We know that we have at least 2 elements. - let num_docsets = scorers.len(); - scorers.sort_by(|left, right| right.size_hint().cmp(&left.size_hint())); - let left = scorers.pop().unwrap(); - let right = scorers.pop().unwrap(); - scorers.reverse(); - let all_term_scorers = [&left, &right] - .iter() - .all(|&scorer| scorer.is::()); - if all_term_scorers { - return Box::new(Intersection { - left: *(left.downcast::().map_err(|_| ()).unwrap()), - right: *(right.downcast::().map_err(|_| ()).unwrap()), - others: scorers, - num_docsets, - }); - } - Box::new(Intersection { - left, - right, - others: scorers, - num_docsets, - }) + Box::new(Intersection::new(scorers)) } /// Creates a `DocSet` that iterate through the intersection of two or more `DocSet`s. -pub struct Intersection> { - left: TDocSet, - right: TDocSet, - others: Vec, - num_docsets: usize, +pub struct Intersection { + docsets: Vec, } -impl Intersection { - pub(crate) fn new(mut docsets: Vec) -> Intersection { - let num_docsets = docsets.len(); - assert!(num_docsets >= 2); - docsets.sort_by(|left, right| right.size_hint().cmp(&left.size_hint())); - let left = docsets.pop().unwrap(); - let right = docsets.pop().unwrap(); - docsets.reverse(); - Intersection { - left, - right, - others: docsets, - num_docsets, - } +impl Intersection { + pub(crate) fn new(mut docsets: Vec) -> Intersection { + assert!(docsets.len() >= 2); + docsets.sort_by_key(|scorer| scorer.size_hint()); + Intersection { docsets } + } + + pub fn docset_mut_specialized(&mut self, ord: usize) -> &mut TDocSet { + &mut self.docsets[ord] } } -impl Intersection { - pub(crate) fn docset_mut_specialized(&mut self, ord: usize) -> &mut TDocSet { - match ord { - 0 => &mut self.left, - 1 => &mut self.right, - n => &mut self.others[n - 2], - } - } -} - -impl Intersection { - pub(crate) fn docset_mut(&mut self, ord: usize) -> &mut dyn DocSet { - match ord { - 0 => &mut self.left, - 1 => &mut self.right, - n => &mut self.others[n - 2], - } - } -} - -impl DocSet for Intersection { +impl DocSet for Intersection { fn advance(&mut self) -> bool { - let (left, right) = (&mut self.left, &mut self.right); - - if !left.advance() { + if !self.docsets[0].advance() { return false; } - - let mut candidate = left.doc(); - let mut other_candidate_ord: usize = usize::max_value(); - + let mut candidate_emitter = 0; + let mut candidate = self.docsets[0].doc(); 'outer: loop { - // In the first part we look for a document in the intersection - // of the two rarest `DocSet` in the intersection. - loop { - match right.skip_next(candidate) { - SkipResult::Reached => { - break; - } - SkipResult::OverStep => { - candidate = right.doc(); - other_candidate_ord = usize::max_value(); - } - SkipResult::End => { - return false; - } - } - match left.skip_next(candidate) { - SkipResult::Reached => { - break; - } - SkipResult::OverStep => { - candidate = left.doc(); - other_candidate_ord = usize::max_value(); - } - SkipResult::End => { - return false; - } - } - } - // test the remaining scorers; - for (ord, docset) in self.others.iter_mut().enumerate() { - if ord == other_candidate_ord { + for (i, docset) in self.docsets.iter_mut().enumerate() { + if i == candidate_emitter { continue; } - // `candidate_ord` is already at the - // right position. - // - // Calling `skip_next` would advance this docset - // and miss it. match docset.skip_next(candidate) { - SkipResult::Reached => {} - SkipResult::OverStep => { - // this is not in the intersection, - // let's update our candidate. - candidate = docset.doc(); - match left.skip_next(candidate) { - SkipResult::Reached => { - other_candidate_ord = ord; - } - SkipResult::OverStep => { - candidate = left.doc(); - other_candidate_ord = usize::max_value(); - } - SkipResult::End => { - return false; - } - } - continue 'outer; - } SkipResult::End => { return false; } + SkipResult::OverStep => { + candidate = docset.doc(); + candidate_emitter = i; + continue 'outer; + } + SkipResult::Reached => {} } } return true; } } - fn skip_next(&mut self, target: DocId) -> SkipResult { - // We optimize skipping by skipping every single member - // of the intersection to target. - let mut current_target: DocId = target; - let mut current_ord = self.num_docsets; - - 'outer: loop { - for ord in 0..self.num_docsets { - let docset = self.docset_mut(ord); - if ord == current_ord { - continue; - } - match docset.skip_next(current_target) { - SkipResult::End => { - return SkipResult::End; - } - SkipResult::OverStep => { - // update the target - // for the remaining members of the intersection. - current_target = docset.doc(); - current_ord = ord; - continue 'outer; - } - SkipResult::Reached => {} - } - } - if target == current_target { - return SkipResult::Reached; - } else { - assert!(current_target > target); - return SkipResult::OverStep; - } - } - } + // TODO implement skip_next fn doc(&self) -> DocId { - self.left.doc() + self.docsets[0].doc() } fn size_hint(&self) -> u32 { - self.left.size_hint() + self.docsets[0].size_hint() } } -impl Scorer for Intersection -where - TScorer: Scorer, - TOtherScorer: Scorer, -{ +impl Scorer for Intersection { fn score(&mut self) -> Score { - self.left.score() - + self.right.score() - + self.others.iter_mut().map(Scorer::score).sum::() + self.docsets.iter_mut().map(Scorer::score).sum::() } } @@ -237,7 +97,7 @@ mod tests { { let left = VecDocSet::from(vec![1, 3, 9]); let right = VecDocSet::from(vec![3, 4, 9, 18]); - let mut intersection = Intersection::new(vec![left, right]); + let mut intersection = Intersection::new(vec![Box::new(left), Box::new(right)]); assert!(intersection.advance()); assert_eq!(intersection.doc(), 3); assert!(intersection.advance()); @@ -245,9 +105,9 @@ mod tests { assert!(!intersection.advance()); } { - let a = VecDocSet::from(vec![1, 3, 9]); - let b = VecDocSet::from(vec![3, 4, 9, 18]); - let c = VecDocSet::from(vec![1, 5, 9, 111]); + let a = Box::new(VecDocSet::from(vec![1, 3, 9])); + let b = Box::new(VecDocSet::from(vec![3, 4, 9, 18])); + let c = Box::new(VecDocSet::from(vec![1, 5, 9, 111])); let mut intersection = Intersection::new(vec![a, b, c]); assert!(intersection.advance()); assert_eq!(intersection.doc(), 9); @@ -257,8 +117,8 @@ mod tests { #[test] fn test_intersection_zero() { - let left = VecDocSet::from(vec![0]); - let right = VecDocSet::from(vec![0]); + let left = Box::new(VecDocSet::from(vec![0])); + let right = Box::new(VecDocSet::from(vec![0])); let mut intersection = Intersection::new(vec![left, right]); assert!(intersection.advance()); assert_eq!(intersection.doc(), 0); @@ -266,8 +126,8 @@ mod tests { #[test] fn test_intersection_skip() { - let left = VecDocSet::from(vec![0, 1, 2, 4]); - let right = VecDocSet::from(vec![2, 5]); + let left = Box::new(VecDocSet::from(vec![0, 1, 2, 4])); + let right = Box::new(VecDocSet::from(vec![2, 5])); let mut intersection = Intersection::new(vec![left, right]); assert_eq!(intersection.skip_next(2), SkipResult::Reached); assert_eq!(intersection.doc(), 2); diff --git a/src/query/phrase_query/phrase_scorer.rs b/src/query/phrase_query/phrase_scorer.rs index 3a0902f91..6731445e7 100644 --- a/src/query/phrase_query/phrase_scorer.rs +++ b/src/query/phrase_query/phrase_scorer.rs @@ -43,7 +43,7 @@ impl DocSet for PostingsWithOffset { } pub struct PhraseScorer { - intersection_docset: Intersection, PostingsWithOffset>, + intersection_docset: Intersection>, num_terms: usize, left: Vec, right: Vec, diff --git a/src/query/union.rs b/src/query/union.rs index 7e27ac877..5bdb660fb 100644 --- a/src/query/union.rs +++ b/src/query/union.rs @@ -137,12 +137,11 @@ where if self.advance_buffered() { return true; } - if self.refill() { - self.advance(); - true - } else { - false + if !self.refill() { + return false; } + self.advance(); + true } fn skip_next(&mut self, target: DocId) -> SkipResult { @@ -260,6 +259,22 @@ where fn score(&mut self) -> Score { self.score } + + fn for_each(&mut self, callback: &mut dyn FnMut(DocId, Score)) { + while self.refill() { + let offset = self.offset; + for (cursor, bitset) in self.bitsets.iter_mut().enumerate() { + while let Some(val) = bitset.pop_lowest() { + let delta = val + 64 * cursor as u32; + let doc: DocId = offset + delta; + let score_combiner = &mut self.scores[delta as usize]; + let score = score_combiner.score(); + score_combiner.clear(); + callback(doc, score); + } + } + } + } } #[cfg(test)]