Compare commits

...

2 Commits

Author SHA1 Message Date
Paul Masurel
da8316d5c6 simplification of intersection 2020-05-12 16:56:02 +09:00
Paul Masurel
bf27a7b3a4 tried simplifying intersection code. 2020-05-12 16:14:23 +09:00
5 changed files with 98 additions and 212 deletions

View File

@@ -673,10 +673,10 @@ mod bench {
.read_postings(&*TERM_D, IndexRecordOption::Basic)
.unwrap();
let mut intersection = Intersection::new(vec![
segment_postings_a,
segment_postings_b,
segment_postings_c,
segment_postings_d,
segment_postings_a.into(),
segment_postings_b.into(),
segment_postings_c.into(),
segment_postings_d.into(),
]);
while intersection.advance() {}
});

View File

@@ -10,7 +10,6 @@ mod tests {
use crate::collector::tests::TEST_COLLECTOR_WITH_SCORE;
use crate::query::score_combiner::SumWithCoordsCombiner;
use crate::query::term_query::TermScorer;
use crate::query::Intersection;
use crate::query::Occur;
use crate::query::Query;
use crate::query::QueryParser;
@@ -64,29 +63,6 @@ mod tests {
assert!(scorer.is::<TermScorer>());
}
#[test]
pub fn test_boolean_termonly_intersection() {
let (index, text_field) = aux_test_helper();
let query_parser = QueryParser::for_index(&index, vec![text_field]);
let searcher = index.reader().unwrap().searcher();
{
let query = query_parser.parse_query("+a +b +c").unwrap();
let weight = query.weight(&searcher, true).unwrap();
let scorer = weight
.scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap();
assert!(scorer.is::<Intersection<TermScorer>>());
}
{
let query = query_parser.parse_query("+a +(b c)").unwrap();
let weight = query.weight(&searcher, true).unwrap();
let scorer = weight
.scorer(searcher.segment_reader(0u32), 1.0f32)
.unwrap();
assert!(scorer.is::<Intersection<Box<dyn Scorer>>>());
}
}
#[test]
pub fn test_boolean_reqopt() {
let (index, text_field) = aux_test_helper();

View File

@@ -1,5 +1,4 @@
use crate::docset::{DocSet, SkipResult};
use crate::query::term_query::TermScorer;
use crate::query::EmptyScorer;
use crate::query::Scorer;
use crate::DocId;
@@ -21,208 +20,104 @@ pub fn intersect_scorers(mut scorers: Vec<Box<dyn Scorer>>) -> Box<dyn Scorer> {
return scorers.pop().unwrap();
}
// We know that we have at least 2 elements.
let num_docsets = scorers.len();
scorers.sort_by(|left, right| right.size_hint().cmp(&left.size_hint()));
let left = scorers.pop().unwrap();
let right = scorers.pop().unwrap();
scorers.reverse();
let all_term_scorers = [&left, &right]
.iter()
.all(|&scorer| scorer.is::<TermScorer>());
if all_term_scorers {
return Box::new(Intersection {
left: *(left.downcast::<TermScorer>().map_err(|_| ()).unwrap()),
right: *(right.downcast::<TermScorer>().map_err(|_| ()).unwrap()),
others: scorers,
num_docsets,
});
}
Box::new(Intersection {
left,
right,
others: scorers,
num_docsets,
})
Box::new(Intersection::new(scorers))
}
/// Creates a `DocSet` that iterates through the intersection of two or more `DocSet`s.
pub struct Intersection<TDocSet: DocSet, TOtherDocSet: DocSet = Box<dyn Scorer>> {
left: TDocSet,
right: TDocSet,
others: Vec<TOtherDocSet>,
num_docsets: usize,
pub struct Intersection<TDocSet: DocSet> {
docsets: Vec<TDocSet>,
}
impl<TDocSet: DocSet> Intersection<TDocSet, TDocSet> {
pub(crate) fn new(mut docsets: Vec<TDocSet>) -> Intersection<TDocSet, TDocSet> {
let num_docsets = docsets.len();
assert!(num_docsets >= 2);
docsets.sort_by(|left, right| right.size_hint().cmp(&left.size_hint()));
let left = docsets.pop().unwrap();
let right = docsets.pop().unwrap();
docsets.reverse();
Intersection {
left,
right,
others: docsets,
num_docsets,
}
impl<TDocSet: DocSet> Intersection<TDocSet> {
/// Builds an intersection over the given docsets.
///
/// The docsets are reordered by ascending `size_hint()`, so the docset
/// reporting the smallest hint ends up at index 0.
///
/// # Panics
///
/// Panics if fewer than two docsets are supplied.
pub(crate) fn new(mut docsets: Vec<TDocSet>) -> Intersection<TDocSet> {
assert!(docsets.len() >= 2);
docsets.sort_by_key(|scorer| scorer.size_hint());
Intersection { docsets }
}
/// Returns a mutable reference to the docset stored at index `ord`.
///
/// `ord` indexes the internal vector as reordered by `new` (ascending
/// `size_hint()`), not the order in which the caller passed the docsets.
///
/// # Panics
///
/// Panics if `ord` is out of bounds.
pub(crate) fn docset_mut(&mut self, ord: usize) -> &mut TDocSet {
&mut self.docsets[ord]
}
}
impl<TDocSet: DocSet> Intersection<TDocSet, TDocSet> {
pub(crate) fn docset_mut_specialized(&mut self, ord: usize) -> &mut TDocSet {
match ord {
0 => &mut self.left,
1 => &mut self.right,
n => &mut self.others[n - 2],
}
}
}
impl<TDocSet: DocSet, TOtherDocSet: DocSet> Intersection<TDocSet, TOtherDocSet> {
pub(crate) fn docset_mut(&mut self, ord: usize) -> &mut dyn DocSet {
match ord {
0 => &mut self.left,
1 => &mut self.right,
n => &mut self.others[n - 2],
}
}
}
impl<TDocSet: DocSet, TOtherDocSet: DocSet> DocSet for Intersection<TDocSet, TOtherDocSet> {
impl<TDocSet: DocSet> DocSet for Intersection<TDocSet> {
fn advance(&mut self) -> bool {
let (left, right) = (&mut self.left, &mut self.right);
if !left.advance() {
if !self.docsets[0].advance() {
return false;
}
let mut candidate = left.doc();
let mut other_candidate_ord: usize = usize::max_value();
let mut candidate_emitter = 0;
let mut candidate = self.docsets[0].doc();
'outer: loop {
// In the first part we look for a document in the intersection
// of the two rarest `DocSet` in the intersection.
loop {
match right.skip_next(candidate) {
SkipResult::Reached => {
break;
}
SkipResult::OverStep => {
candidate = right.doc();
other_candidate_ord = usize::max_value();
}
SkipResult::End => {
return false;
}
}
match left.skip_next(candidate) {
SkipResult::Reached => {
break;
}
SkipResult::OverStep => {
candidate = left.doc();
other_candidate_ord = usize::max_value();
}
SkipResult::End => {
return false;
}
}
}
// test the remaining scorers;
for (ord, docset) in self.others.iter_mut().enumerate() {
if ord == other_candidate_ord {
for (i, docset) in self.docsets.iter_mut().enumerate() {
if i == candidate_emitter {
continue;
}
// `candidate_ord` is already at the
// right position.
//
// Calling `skip_next` would advance this docset
// and miss it.
match docset.skip_next(candidate) {
SkipResult::Reached => {}
SkipResult::OverStep => {
// this is not in the intersection,
// let's update our candidate.
candidate = docset.doc();
match left.skip_next(candidate) {
SkipResult::Reached => {
other_candidate_ord = ord;
}
SkipResult::OverStep => {
candidate = left.doc();
other_candidate_ord = usize::max_value();
}
SkipResult::End => {
return false;
}
}
continue 'outer;
}
SkipResult::End => {
return false;
}
SkipResult::OverStep => {
candidate = docset.doc();
candidate_emitter = i;
continue 'outer;
}
SkipResult::Reached => {}
}
}
return true;
}
}
fn skip_next(&mut self, target: DocId) -> SkipResult {
// We optimize skipping by skipping every single member
// of the intersection to target.
let mut current_target: DocId = target;
let mut current_ord = self.num_docsets;
// TODO implement skip_next
/// Returns the document the first docset (`docsets[0]`) is currently
/// positioned on.
// NOTE(review): this is the intersection's current document only while
// every member is aligned on `docsets[0]` — confirm `advance` guarantees it.
fn doc(&self) -> DocId {
self.docsets[0].doc()
}
/// Returns the size hint of the first docset (`docsets[0]`).
///
/// `new` sorts the docsets by ascending `size_hint()`, so this is the
/// smallest hint among the members as of construction time.
fn size_hint(&self) -> u32 {
self.docsets[0].size_hint()
}
}
impl<TDocSet: Scorer + DocSet> Scorer for Intersection<TDocSet> {
/// Returns the sum of `score()` over every member docset.
fn score(&mut self) -> Score {
self.docsets.iter_mut().map(Scorer::score).sum::<Score>()
}
fn for_each(&mut self, callback: &mut dyn FnMut(DocId, Score)) {
if !self.docsets[0].advance() {
return;
}
let mut candidate_emitter = 0;
let mut candidate = self.docsets[0].doc();
'outer: loop {
for ord in 0..self.num_docsets {
let docset = self.docset_mut(ord);
if ord == current_ord {
for (i, docset) in self.docsets.iter_mut().enumerate() {
if i == candidate_emitter {
continue;
}
match docset.skip_next(current_target) {
match docset.skip_next(candidate) {
SkipResult::End => {
return SkipResult::End;
return;
}
SkipResult::OverStep => {
// update the target
// for the remaining members of the intersection.
current_target = docset.doc();
current_ord = ord;
candidate = docset.doc();
candidate_emitter = i;
continue 'outer;
}
SkipResult::Reached => {}
}
}
if target == current_target {
return SkipResult::Reached;
} else {
assert!(current_target > target);
return SkipResult::OverStep;
callback(candidate, self.score());
if !self.docsets[0].advance() {
return;
}
candidate_emitter = 0;
candidate = self.docsets[0].doc();
}
}
fn doc(&self) -> DocId {
self.left.doc()
}
fn size_hint(&self) -> u32 {
self.left.size_hint()
}
}
impl<TScorer, TOtherScorer> Scorer for Intersection<TScorer, TOtherScorer>
where
TScorer: Scorer,
TOtherScorer: Scorer,
{
fn score(&mut self) -> Score {
self.left.score()
+ self.right.score()
+ self.others.iter_mut().map(Scorer::score).sum::<Score>()
}
}
#[cfg(test)]
@@ -237,7 +132,7 @@ mod tests {
{
let left = VecDocSet::from(vec![1, 3, 9]);
let right = VecDocSet::from(vec![3, 4, 9, 18]);
let mut intersection = Intersection::new(vec![left, right]);
let mut intersection = Intersection::new(vec![Box::new(left), Box::new(right)]);
assert!(intersection.advance());
assert_eq!(intersection.doc(), 3);
assert!(intersection.advance());
@@ -245,9 +140,9 @@ mod tests {
assert!(!intersection.advance());
}
{
let a = VecDocSet::from(vec![1, 3, 9]);
let b = VecDocSet::from(vec![3, 4, 9, 18]);
let c = VecDocSet::from(vec![1, 5, 9, 111]);
let a = Box::new(VecDocSet::from(vec![1, 3, 9]));
let b = Box::new(VecDocSet::from(vec![3, 4, 9, 18]));
let c = Box::new(VecDocSet::from(vec![1, 5, 9, 111]));
let mut intersection = Intersection::new(vec![a, b, c]);
assert!(intersection.advance());
assert_eq!(intersection.doc(), 9);
@@ -257,8 +152,8 @@ mod tests {
#[test]
fn test_intersection_zero() {
let left = VecDocSet::from(vec![0]);
let right = VecDocSet::from(vec![0]);
let left = Box::new(VecDocSet::from(vec![0]));
let right = Box::new(VecDocSet::from(vec![0]));
let mut intersection = Intersection::new(vec![left, right]);
assert!(intersection.advance());
assert_eq!(intersection.doc(), 0);
@@ -266,8 +161,8 @@ mod tests {
#[test]
fn test_intersection_skip() {
let left = VecDocSet::from(vec![0, 1, 2, 4]);
let right = VecDocSet::from(vec![2, 5]);
let left = Box::new(VecDocSet::from(vec![0, 1, 2, 4]));
let right = Box::new(VecDocSet::from(vec![2, 5]));
let mut intersection = Intersection::new(vec![left, right]);
assert_eq!(intersection.skip_next(2), SkipResult::Reached);
assert_eq!(intersection.doc(), 2);

View File

@@ -43,7 +43,7 @@ impl<TPostings: Postings> DocSet for PostingsWithOffset<TPostings> {
}
pub struct PhraseScorer<TPostings: Postings> {
intersection_docset: Intersection<PostingsWithOffset<TPostings>, PostingsWithOffset<TPostings>>,
intersection_docset: Intersection<PostingsWithOffset<TPostings>>,
num_terms: usize,
left: Vec<u32>,
right: Vec<u32>,
@@ -177,13 +177,13 @@ impl<TPostings: Postings> PhraseScorer<TPostings> {
fn phrase_exists(&mut self) -> bool {
self.intersection_docset
.docset_mut_specialized(0)
.docset_mut(0)
.positions(&mut self.left);
let mut intersection_len = self.left.len();
for i in 1..self.num_terms - 1 {
{
self.intersection_docset
.docset_mut_specialized(i)
.docset_mut(i)
.positions(&mut self.right);
}
intersection_len = intersection(&mut self.left[..intersection_len], &self.right[..]);
@@ -193,7 +193,7 @@ impl<TPostings: Postings> PhraseScorer<TPostings> {
}
self.intersection_docset
.docset_mut_specialized(self.num_terms - 1)
.docset_mut(self.num_terms - 1)
.positions(&mut self.right);
intersection_exists(&self.left[..intersection_len], &self.right[..])
}
@@ -201,14 +201,14 @@ impl<TPostings: Postings> PhraseScorer<TPostings> {
fn compute_phrase_count(&mut self) -> u32 {
{
self.intersection_docset
.docset_mut_specialized(0)
.docset_mut(0)
.positions(&mut self.left);
}
let mut intersection_len = self.left.len();
for i in 1..self.num_terms - 1 {
{
self.intersection_docset
.docset_mut_specialized(i)
.docset_mut(i)
.positions(&mut self.right);
}
intersection_len = intersection(&mut self.left[..intersection_len], &self.right[..]);
@@ -218,7 +218,7 @@ impl<TPostings: Postings> PhraseScorer<TPostings> {
}
self.intersection_docset
.docset_mut_specialized(self.num_terms - 1)
.docset_mut(self.num_terms - 1)
.positions(&mut self.right);
intersection_count(&self.left[..intersection_len], &self.right[..]) as u32
}

View File

@@ -137,12 +137,11 @@ where
if self.advance_buffered() {
return true;
}
if self.refill() {
self.advance();
true
} else {
false
if !self.refill() {
return false;
}
self.advance();
true
}
fn skip_next(&mut self, target: DocId) -> SkipResult {
@@ -260,6 +259,22 @@ where
fn score(&mut self) -> Score {
self.score
}
/// Invokes `callback(doc, score)` for every entry popped from the bitsets,
/// repeating for as long as `refill()` returns `true`.
///
/// Within one refill, entries are visited bitset by bitset in index order,
/// and within a bitset in the order `pop_lowest()` yields them.
fn for_each(&mut self, callback: &mut dyn FnMut(DocId, Score)) {
// NOTE(review): presumably `refill()` repopulates `self.bitsets` and
// `self.scores` for the next window and updates `self.offset` — confirm
// against its definition, which is not visible here.
while self.refill() {
// Copy the window base before mutably borrowing the bitsets and scores.
let offset = self.offset;
for (cursor, bitset) in self.bitsets.iter_mut().enumerate() {
// Drain this bitset; the loop ends when `pop_lowest()` returns `None`.
while let Some(val) = bitset.pop_lowest() {
// `as` binds tighter than `*`: this is `val + 64 * (cursor as u32)`,
// i.e. bitset number `cursor` is shifted by `64 * cursor` positions.
let delta = val + 64 * cursor as u32;
let doc: DocId = offset + delta;
// The score slot is addressed by the same window-relative position.
let score_combiner = &mut self.scores[delta as usize];
let score = score_combiner.score();
// Reset the slot right after reading it.
// NOTE(review): presumably so the next refill starts from a clean
// combiner — confirm.
score_combiner.clear();
callback(doc, score);
}
}
}
}
}
#[cfg(test)]