diff --git a/src/docset.rs b/src/docset.rs index 01ea1125a..06ecc46b3 100644 --- a/src/docset.rs +++ b/src/docset.rs @@ -51,31 +51,55 @@ pub trait DocSet: Send { doc } - /// Seeks to the target if possible and returns true if the target is in the DocSet. + /// !!!Dragons ahead!!! + /// In spirit, this is an approximate and dangerous version of `seek`. + /// + /// It can leave the DocSet in an `invalid` state and might return a + /// lower bound of what the result of Seek would have been. + /// + /// + /// More accurately it returns either: + /// - Found if the target is in the docset. In that case, the DocSet is left in a valid state. + /// - SeekLowerBound(seek_lower_bound) if the target is not in the docset. In that case, The + /// DocSet can be the left in a invalid state. The DocSet should then only receives call to + /// `seek_danger(..)` until it returns `Found`, and get back to a valid state. + /// + /// `seek_lower_bound` can be any `DocId` (in the docset or not) as long as it is in + /// `(target .. seek_result]` where `seek_result` is the first document in the docset greater + /// than to `target`. + /// + /// `seek_danger` may return `SeekLowerBound(TERMINATED)`. + /// + /// Calling `seek_danger` with TERMINATED as a target is allowed, + /// and should always return NewTarget(TERMINATED) or anything larger as TERMINATED is NOT in + /// the DocSet. /// /// DocSets that already have an efficient `seek` method don't need to implement - /// `seek_into_the_danger_zone`. All wrapper DocSets should forward - /// `seek_into_the_danger_zone` to the underlying DocSet. + /// `seek_danger`. /// - /// ## API Behaviour - /// If `seek_into_the_danger_zone` is returning true, a call to `doc()` has to return target. - /// If `seek_into_the_danger_zone` is returning false, a call to `doc()` may return any doc - /// between the last doc that matched and target or a doc that is a valid next hit after - /// target. The DocSet is considered to be in an invalid state until - /// `seek_into_the_danger_zone` returns true again. - /// - /// `target` needs to be equal or larger than `doc` when in a valid state. - /// - /// Consecutive calls are not allowed to have decreasing `target` values. - /// - /// # Warning - /// This is an advanced API used by intersection. The API contract is tricky, avoid using it. - fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { - let current_doc = self.doc(); - if current_doc < target { - self.seek(target); + /// Consecutive calls to seek_danger are guaranteed to have strictly increasing `target` + /// values. + fn seek_danger(&mut self, target: DocId) -> SeekDangerResult { + if target >= TERMINATED { + debug_assert!(target == TERMINATED); + // No need to advance. + return SeekDangerResult::SeekLowerBound(target); + } + + // The default implementation does not include any + // `danger zone` behavior. + // + // It does not leave the scorer in an invalid state. + // For this reason, we can safely call `self.doc()`. + let mut doc = self.doc(); + if doc < target { + doc = self.seek(target); + } + if doc == target { + SeekDangerResult::Found + } else { + SeekDangerResult::SeekLowerBound(self.doc()) } - self.doc() == target } /// Fills a given mutable buffer with the next doc ids from the @@ -166,6 +190,17 @@ pub trait DocSet: Send { } } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum SeekDangerResult { + /// The target was found in the DocSet. + Found, + /// The target was not found in the DocSet. + /// We return a range in which the value could be. + /// The given target can be any DocId, that is <= than the first document + /// in the docset after the target. + SeekLowerBound(DocId), +} + impl DocSet for &mut dyn DocSet { fn advance(&mut self) -> u32 { (**self).advance() @@ -175,8 +210,8 @@ impl DocSet for &mut dyn DocSet { (**self).seek(target) } - fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { - (**self).seek_into_the_danger_zone(target) + fn seek_danger(&mut self, target: DocId) -> SeekDangerResult { + (**self).seek_danger(target) } fn doc(&self) -> u32 { @@ -211,9 +246,9 @@ impl DocSet for Box { unboxed.seek(target) } - fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { + fn seek_danger(&mut self, target: DocId) -> SeekDangerResult { let unboxed: &mut TDocSet = self.borrow_mut(); - unboxed.seek_into_the_danger_zone(target) + unboxed.seek_danger(target) } fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize { diff --git a/src/query/boost_query.rs b/src/query/boost_query.rs index cc4c10f7a..69847d750 100644 --- a/src/query/boost_query.rs +++ b/src/query/boost_query.rs @@ -1,6 +1,6 @@ use std::fmt; -use crate::docset::COLLECT_BLOCK_BUFFER_LEN; +use crate::docset::{SeekDangerResult, COLLECT_BLOCK_BUFFER_LEN}; use crate::fastfield::AliveBitSet; use crate::query::{EnableScoring, Explanation, Query, Scorer, Weight}; use crate::{DocId, DocSet, Score, SegmentReader, Term}; @@ -104,8 +104,8 @@ impl DocSet for BoostScorer { fn seek(&mut self, target: DocId) -> DocId { self.underlying.seek(target) } - fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { - self.underlying.seek_into_the_danger_zone(target) + fn seek_danger(&mut self, target: DocId) -> SeekDangerResult { + self.underlying.seek_danger(target) } fn fill_buffer(&mut self, buffer: &mut [DocId; COLLECT_BLOCK_BUFFER_LEN]) -> usize { diff --git a/src/query/disjunction.rs b/src/query/disjunction.rs index ca7eab20d..2b4b54c00 100644 --- a/src/query/disjunction.rs +++ b/src/query/disjunction.rs @@ -1,6 +1,7 @@ use std::cmp::Ordering; use std::collections::BinaryHeap; +use crate::docset::SeekDangerResult; use crate::query::score_combiner::DoNothingCombiner; use crate::query::{ScoreCombiner, Scorer}; use crate::{DocId, DocSet, Score, TERMINATED}; @@ -67,10 +68,12 @@ impl DocSet for ScorerWrapper { self.current_doc = doc_id; doc_id } - fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { - let found = self.scorer.seek_into_the_danger_zone(target); - self.current_doc = self.scorer.doc(); - found + fn seek_danger(&mut self, target: DocId) -> SeekDangerResult { + let result = self.scorer.seek_danger(target); + if result == SeekDangerResult::Found { + self.current_doc = target; + } + result } fn doc(&self) -> DocId { diff --git a/src/query/intersection.rs b/src/query/intersection.rs index d536dcf05..fcf61ee2a 100644 --- a/src/query/intersection.rs +++ b/src/query/intersection.rs @@ -1,5 +1,5 @@ use super::size_hint::estimate_intersection; -use crate::docset::{DocSet, TERMINATED}; +use crate::docset::{DocSet, SeekDangerResult, TERMINATED}; use crate::query::term_query::TermScorer; use crate::query::{EmptyScorer, Scorer}; use crate::{DocId, Score}; @@ -108,46 +108,63 @@ impl DocSet for Intersection DocId { let (left, right) = (&mut self.left, &mut self.right); - let mut candidate = left.advance(); - if candidate == TERMINATED { - return TERMINATED; - } - loop { - // In the first part we look for a document in the intersection - // of the two rarest `DocSet` in the intersection. + // Invariant: + // - candidate is always <= to the next document in the intersection. + // - candidate strictly increases at every occurence of the loop. + let mut candidate = 0; - loop { - if right.seek_into_the_danger_zone(candidate) { - break; - } - let right_doc = right.doc(); - // TODO: Think about which value would make sense here - // It depends on the DocSet implementation, when a seek would outweigh an advance. - if right_doc > candidate.wrapping_add(100) { - candidate = left.seek(right_doc); - } else { - candidate = left.advance(); - } - if candidate == TERMINATED { - return TERMINATED; - } - } + // Termination: candidate strictly increases. + 'outer: while candidate < TERMINATED { + // As we enter the loop, we should always have candidate < next_doc. - debug_assert_eq!(left.doc(), right.doc()); - // test the remaining scorers - if self - .others - .iter_mut() - .all(|docset| docset.seek_into_the_danger_zone(candidate)) + // This step always increases candidate. + // + // TODO: Think about which value would make sense here + // It depends on the DocSet implementation, when a seek would outweigh an advance. + candidate = if candidate > left.doc().wrapping_add(100) { + left.seek(candidate) + } else { + left.advance() + }; + + // Left is positionned on `candidate`. + debug_assert_eq!(left.doc(), candidate); + + if let SeekDangerResult::SeekLowerBound(seek_lower_bound) = right.seek_danger(candidate) { - debug_assert_eq!(candidate, self.left.doc()); - debug_assert_eq!(candidate, self.right.doc()); - debug_assert!(self.others.iter().all(|docset| docset.doc() == candidate)); - return candidate; + // The max is technically useless but it makes the invariant + // easier to proofread. + debug_assert!(seek_lower_bound >= candidate); + candidate = seek_lower_bound; + continue; } - candidate = left.advance(); + + // Left and right are positionned on `candidate`. + debug_assert_eq!(right.doc(), candidate); + + for other in &mut self.others { + if let SeekDangerResult::SeekLowerBound(seek_lower_bound) = + other.seek_danger(candidate) + { + // One of the scorer does not match, let's restart at the top of the loop. + debug_assert!(seek_lower_bound >= candidate); + candidate = seek_lower_bound; + continue 'outer; + } + } + + // At this point all scorers are in a valid state, aligned on the next document in the + // intersection. + debug_assert!(self.others.iter().all(|docset| docset.doc() == candidate)); + return candidate; } + + // We make sure our docset is in a valid state. + // In particular, we want .doc() to return TERMINATED. + left.seek(TERMINATED); + + TERMINATED } fn seek(&mut self, target: DocId) -> DocId { @@ -166,13 +183,19 @@ impl DocSet for Intersection bool { - self.left.seek_into_the_danger_zone(target) - && self.right.seek_into_the_danger_zone(target) - && self - .others - .iter_mut() - .all(|docset| docset.seek_into_the_danger_zone(target)) + fn seek_danger(&mut self, target: DocId) -> SeekDangerResult { + if let SeekDangerResult::SeekLowerBound(new_target) = self.left.seek_danger(target) { + return SeekDangerResult::SeekLowerBound(new_target); + } + if let SeekDangerResult::SeekLowerBound(new_target) = self.right.seek_danger(target) { + return SeekDangerResult::SeekLowerBound(new_target); + } + for docset in &mut self.others { + if let SeekDangerResult::SeekLowerBound(new_target) = docset.seek_danger(target) { + return SeekDangerResult::SeekLowerBound(new_target); + } + } + SeekDangerResult::Found } #[inline] @@ -304,6 +327,58 @@ mod tests { assert_eq!(intersection.doc(), TERMINATED); } + #[test] + fn test_intersection_abc() { + let a = VecDocSet::from(vec![2, 3, 6]); + let b = VecDocSet::from(vec![1, 3, 5]); + let c = VecDocSet::from(vec![1, 3, 5]); + let mut intersection = Intersection::new(vec![c, b, a], 10); + let mut docs = Vec::new(); + use crate::DocSet; + while intersection.doc() != TERMINATED { + docs.push(intersection.doc()); + intersection.advance(); + } + assert_eq!(&docs, &[3]); + } + + #[test] + fn test_intersection_termination() { + use crate::query::score_combiner::DoNothingCombiner; + use crate::query::{BufferedUnionScorer, ConstScorer, VecDocSet}; + + let a1 = ConstScorer::new(VecDocSet::from(vec![0u32, 10000]), 1.0); + let a2 = ConstScorer::new(VecDocSet::from(vec![0u32, 10000]), 1.0); + + let mut b_scorers = vec![]; + for _ in 0..2 { + // Union matches 0 and 10000. + b_scorers.push(ConstScorer::new(VecDocSet::from(vec![0, 10000]), 1.0)); + } + // That's the union of two scores matching 0, and 10_000. + let union = BufferedUnionScorer::build(b_scorers, DoNothingCombiner::default, 30000); + + // Mismatching scorer: matches 0 and 20000. We then append more docs at the end to ensure it + // is last. + let mut m_docs = vec![0, 20000]; + for i in 30000..30100 { + m_docs.push(i); + } + let m = ConstScorer::new(VecDocSet::from(m_docs), 1.0); + + // Costs: A1=2, A2=2, Union=4, M=102. + // Sorted: A1, A2, Union, M. + // Left=A1, Right=A2, Others=[Union, M]. + let mut intersection = crate::query::intersect_scorers( + vec![Box::new(a1), Box::new(a2), Box::new(union), Box::new(m)], + 40000, + ); + + while intersection.doc() != TERMINATED { + intersection.advance(); + } + } + // Strategy to generate sorted and deduplicated vectors of u32 document IDs fn sorted_deduped_vec(max_val: u32, max_size: usize) -> impl Strategy> { prop::collection::vec(0..max_val, 0..max_size).prop_map(|mut vec| { @@ -335,6 +410,5 @@ mod tests { } assert_eq!(intersection.doc(), TERMINATED); } - } } diff --git a/src/query/phrase_prefix_query/phrase_prefix_scorer.rs b/src/query/phrase_prefix_query/phrase_prefix_scorer.rs index 8b03089fa..f2df3433d 100644 --- a/src/query/phrase_prefix_query/phrase_prefix_scorer.rs +++ b/src/query/phrase_prefix_query/phrase_prefix_scorer.rs @@ -1,4 +1,4 @@ -use crate::docset::{DocSet, TERMINATED}; +use crate::docset::{DocSet, SeekDangerResult, TERMINATED}; use crate::fieldnorm::FieldNormReader; use crate::postings::Postings; use crate::query::bm25::Bm25Weight; @@ -194,11 +194,16 @@ impl DocSet for PhrasePrefixScorer { self.advance() } - fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { - if self.phrase_scorer.seek_into_the_danger_zone(target) { - self.matches_prefix() + fn seek_danger(&mut self, target: DocId) -> SeekDangerResult { + let seek_res = self.phrase_scorer.seek_danger(target); + if seek_res != SeekDangerResult::Found { + return seek_res; + } + // The intersection matched. Now let's see if we match the prefix. + if self.matches_prefix() { + SeekDangerResult::Found } else { - false + SeekDangerResult::SeekLowerBound(target + 1) } } diff --git a/src/query/phrase_query/phrase_scorer.rs b/src/query/phrase_query/phrase_scorer.rs index 108783b40..7460fcf79 100644 --- a/src/query/phrase_query/phrase_scorer.rs +++ b/src/query/phrase_query/phrase_scorer.rs @@ -1,6 +1,6 @@ use std::cmp::Ordering; -use crate::docset::{DocSet, TERMINATED}; +use crate::docset::{DocSet, SeekDangerResult, TERMINATED}; use crate::fieldnorm::FieldNormReader; use crate::postings::Postings; use crate::query::bm25::Bm25Weight; @@ -530,12 +530,18 @@ impl DocSet for PhraseScorer { self.advance() } - fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { + fn seek_danger(&mut self, target: DocId) -> SeekDangerResult { debug_assert!(target >= self.doc()); - if self.intersection_docset.seek_into_the_danger_zone(target) && self.phrase_match() { - return true; + let seek_res = self.intersection_docset.seek_danger(target); + if seek_res != SeekDangerResult::Found { + return seek_res; + } + // The intersection matched. Now let's see if we match the phrase. + if self.phrase_match() { + SeekDangerResult::Found + } else { + SeekDangerResult::SeekLowerBound(target + 1) } - false } fn doc(&self) -> DocId { diff --git a/src/query/reqopt_scorer.rs b/src/query/reqopt_scorer.rs index bed99f5b7..50b701d6d 100644 --- a/src/query/reqopt_scorer.rs +++ b/src/query/reqopt_scorer.rs @@ -1,6 +1,6 @@ use std::marker::PhantomData; -use crate::docset::DocSet; +use crate::docset::{DocSet, SeekDangerResult}; use crate::query::score_combiner::ScoreCombiner; use crate::query::Scorer; use crate::{DocId, Score}; @@ -56,9 +56,9 @@ where self.req_scorer.seek(target) } - fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { + fn seek_danger(&mut self, target: DocId) -> SeekDangerResult { self.score_cache = None; - self.req_scorer.seek_into_the_danger_zone(target) + self.req_scorer.seek_danger(target) } fn doc(&self) -> DocId { diff --git a/src/query/union/buffered_union.rs b/src/query/union/buffered_union.rs index ee554e357..9ab980c9a 100644 --- a/src/query/union/buffered_union.rs +++ b/src/query/union/buffered_union.rs @@ -1,6 +1,6 @@ use common::TinySet; -use crate::docset::{DocSet, TERMINATED}; +use crate::docset::{DocSet, SeekDangerResult, TERMINATED}; use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner}; use crate::query::size_hint::estimate_union; use crate::query::Scorer; @@ -225,25 +225,47 @@ where } } - fn seek_into_the_danger_zone(&mut self, target: DocId) -> bool { + fn seek_danger(&mut self, target: DocId) -> SeekDangerResult { + if target >= TERMINATED { + return SeekDangerResult::SeekLowerBound(TERMINATED); + } if self.is_in_horizon(target) { // Our value is within the buffered horizon and the docset may already have been // processed and removed, so we need to use seek, which uses the regular advance. - self.seek(target) == target - } else { - // The docsets are not in the buffered range, so we can use seek_into_the_danger_zone - // of the underlying docsets - let is_hit = self - .docsets - .iter_mut() - .any(|docset| docset.seek_into_the_danger_zone(target)); + let seek_doc = self.seek(target); + if seek_doc == target { + return SeekDangerResult::Found; + } else { + return SeekDangerResult::SeekLowerBound(seek_doc); + }; + } - // The API requires the DocSet to be in a valid state when `seek_into_the_danger_zone` - // returns true. - if is_hit { - self.seek(target); + // The docsets are not in the buffered range, so we can use seek_into_the_danger_zone + // of the underlying docsets + let mut is_hit = false; + let mut min_new_target = TERMINATED; + + for docset in self.docsets.iter_mut() { + match docset.seek_danger(target) { + SeekDangerResult::Found => { + is_hit = true; + break; + } + SeekDangerResult::SeekLowerBound(new_target) => { + min_new_target = min_new_target.min(new_target); + } } - is_hit + } + + // The API requires the DocSet to be in a valid state when `seek_into_the_danger_zone` + // returns Found. + if is_hit { + // The doc is found. Let's make sure we position the union on the target + // to bring it back to a valid state. + self.seek(target); + SeekDangerResult::Found + } else { + SeekDangerResult::SeekLowerBound(min_new_target) } } diff --git a/src/query/union/mod.rs b/src/query/union/mod.rs index 539c6c387..825ee219b 100644 --- a/src/query/union/mod.rs +++ b/src/query/union/mod.rs @@ -14,7 +14,7 @@ mod tests { use common::BitSet; use super::{SimpleUnion, *}; - use crate::docset::{DocSet, TERMINATED}; + use crate::docset::{DocSet, SeekDangerResult, TERMINATED}; use crate::postings::tests::test_skip_against_unoptimized; use crate::query::score_combiner::DoNothingCombiner; use crate::query::union::bitset_union::BitSetPostingUnion; @@ -254,6 +254,27 @@ mod tests { vec![1, 2, 3, 7, 8, 9, 99, 100, 101, 500, 20000], ); } + + #[test] + fn test_buffered_union_seek_into_danger_zone_terminated() { + let scorer1 = ConstScorer::new(VecDocSet::from(vec![1, 2]), 1.0); + let scorer2 = ConstScorer::new(VecDocSet::from(vec![2, 3]), 1.0); + + let mut union_scorer = + BufferedUnionScorer::build(vec![scorer1, scorer2], DoNothingCombiner::default, 100); + + // Advance to end + while union_scorer.doc() != TERMINATED { + union_scorer.advance(); + } + + assert_eq!(union_scorer.doc(), TERMINATED); + + assert_eq!( + union_scorer.seek_danger(TERMINATED), + SeekDangerResult::SeekLowerBound(TERMINATED) + ); + } } #[cfg(all(test, feature = "unstable"))] diff --git a/src/query/vec_docset.rs b/src/query/vec_docset.rs index 9dafa3ffd..1b6336183 100644 --- a/src/query/vec_docset.rs +++ b/src/query/vec_docset.rs @@ -17,6 +17,9 @@ pub struct VecDocSet { impl From> for VecDocSet { fn from(doc_ids: Vec) -> VecDocSet { + // We do not use `slice::is_sorted`, as we want to check for doc ids to be strictly + // sorted. + assert!(doc_ids.windows(2).all(|w| w[0] < w[1])); VecDocSet { doc_ids, cursor: 0 } } }