diff --git a/src/postings/loaded_postings.rs b/src/postings/loaded_postings.rs
new file mode 100644
index 000000000..7212be4ac
--- /dev/null
+++ b/src/postings/loaded_postings.rs
@@ -0,0 +1,155 @@
+use crate::docset::{DocSet, TERMINATED};
+use crate::postings::{Postings, SegmentPostings};
+use crate::DocId;
+
+/// `LoadedPostings` is a `DocSet` and `Postings` implementation.
+/// It represents the postings of a term in memory.
+/// It is suitable when there are only a few documents for a term.
+///
+/// It exists mainly to reduce memory usage.
+/// `SegmentPostings` uses 1840 bytes per instance due to its caches.
+/// If you need to keep many terms around, each with only a few docs, it's cheaper to load
+/// all the postings into memory.
+///
+/// This is relevant for `RegexPhraseQuery`, which may expand to a lot of
+/// terms.
+/// E.g. 100_000 terms would need 184MB due to `SegmentPostings`.
+pub struct LoadedPostings {
+    doc_ids: Box<[DocId]>,
+    position_offsets: Box<[u32]>,
+    positions: Box<[u32]>,
+    cursor: usize,
+}
+
+impl LoadedPostings {
+    /// Creates a new `LoadedPostings` from a `SegmentPostings`.
+    ///
+    /// It will also preload positions, if positions are available in the `SegmentPostings`.
+    pub fn load(segment_postings: &mut SegmentPostings) -> LoadedPostings {
+        let num_docs = segment_postings.doc_freq() as usize;
+        let mut doc_ids = Vec::with_capacity(num_docs);
+        let mut positions = Vec::with_capacity(num_docs);
+        let mut position_offsets = Vec::with_capacity(num_docs);
+        while segment_postings.doc() != TERMINATED {
+            position_offsets.push(positions.len() as u32);
+            doc_ids.push(segment_postings.doc());
+            segment_postings.append_positions_with_offset(0, &mut positions);
+            segment_postings.advance();
+        }
+        position_offsets.push(positions.len() as u32);
+        LoadedPostings {
+            doc_ids: doc_ids.into_boxed_slice(),
+            positions: positions.into_boxed_slice(),
+            position_offsets: position_offsets.into_boxed_slice(),
+            cursor: 0,
+        }
+    }
+}
+
+#[cfg(test)]
+impl From<(Vec<DocId>, Vec<Vec<u32>>)> for LoadedPostings {
+    fn from(doc_ids_and_positions: (Vec<DocId>, Vec<Vec<u32>>)) -> LoadedPostings {
+        let mut position_offsets = Vec::new();
+        let mut all_positions = Vec::new();
+        let (doc_ids, docid_positions) = doc_ids_and_positions;
+        for positions in docid_positions {
+            position_offsets.push(all_positions.len() as u32);
+            all_positions.extend_from_slice(&positions);
+        }
+        position_offsets.push(all_positions.len() as u32);
+        LoadedPostings {
+            doc_ids: doc_ids.into_boxed_slice(),
+            positions: all_positions.into_boxed_slice(),
+            position_offsets: position_offsets.into_boxed_slice(),
+            cursor: 0,
+        }
+    }
+}
+
+impl DocSet for LoadedPostings {
+    fn advance(&mut self) -> DocId {
+        self.cursor += 1;
+        if self.cursor >= self.doc_ids.len() {
+            self.cursor = self.doc_ids.len();
+            return TERMINATED;
+        }
+        self.doc()
+    }
+
+    fn doc(&self) -> DocId {
+        if self.cursor >= self.doc_ids.len() {
+            return TERMINATED;
+        }
+        self.doc_ids[self.cursor]
+    }
+
+    fn size_hint(&self) -> u32 {
+        self.doc_ids.len() as u32
+    }
+}
+
+impl Postings for LoadedPostings {
+    fn term_freq(&self) -> u32 {
+        let start = self.position_offsets[self.cursor] as usize;
+        let end = self.position_offsets[self.cursor + 1] as usize;
+        (end - start) as u32
+    }
+
+    fn append_positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>) {
+        let start = self.position_offsets[self.cursor] as usize;
+        let end = self.position_offsets[self.cursor + 1] as usize;
+        for pos in &self.positions[start..end] {
+            output.push(*pos + offset);
+        }
+    }
+}
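// ---------------------------------------------------------------------------
// [Editor's note — not part of the diff] A minimal sketch of how the
// `DocSet`/`Postings` surface of `LoadedPostings` above is consumed. The
// helper name is illustrative; it relies only on the methods shown in this
// file.
fn collect_docs_and_positions(mut postings: LoadedPostings) -> Vec<(DocId, Vec<u32>)> {
    let mut out = Vec::new();
    while postings.doc() != TERMINATED {
        let mut positions = Vec::new();
        // `append_positions_with_offset` appends, so start from an empty buffer.
        postings.append_positions_with_offset(0, &mut positions);
        out.push((postings.doc(), positions));
        postings.advance();
    }
    out
}
// ---------------------------------------------------------------------------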
+#[cfg(test)]
+pub mod tests {
+
+    use super::*;
+
+    #[test]
+    pub fn test_vec_postings() {
+        let doc_ids: Vec<DocId> = (0u32..1024u32).map(|e| e * 3).collect();
+        let mut postings = LoadedPostings::from((doc_ids, vec![]));
+        assert_eq!(postings.doc(), 0u32);
+        assert_eq!(postings.advance(), 3u32);
+        assert_eq!(postings.doc(), 3u32);
+        assert_eq!(postings.seek(14u32), 15u32);
+        assert_eq!(postings.doc(), 15u32);
+        assert_eq!(postings.seek(300u32), 300u32);
+        assert_eq!(postings.doc(), 300u32);
+        assert_eq!(postings.seek(6000u32), TERMINATED);
+    }
+
+    #[test]
+    pub fn test_vec_postings2() {
+        let doc_ids: Vec<DocId> = (0u32..1024u32).map(|e| e * 3).collect();
+        let mut positions = Vec::new();
+        positions.resize(1024, Vec::new());
+        positions[0] = vec![1u32, 2u32, 3u32];
+        positions[1] = vec![30u32];
+        positions[2] = vec![10u32];
+        positions[4] = vec![50u32];
+        let mut postings = LoadedPostings::from((doc_ids, positions));
+
+        let load = |postings: &mut LoadedPostings| {
+            let mut loaded_positions = Vec::new();
+            postings.positions(loaded_positions.as_mut());
+            loaded_positions
+        };
+        assert_eq!(postings.doc(), 0u32);
+        assert_eq!(load(&mut postings), vec![1u32, 2u32, 3u32]);
+
+        assert_eq!(postings.advance(), 3u32);
+        assert_eq!(postings.doc(), 3u32);
+
+        assert_eq!(load(&mut postings), vec![30u32]);
+
+        assert_eq!(postings.seek(14u32), 15u32);
+        assert_eq!(postings.doc(), 15u32);
+        assert_eq!(postings.seek(300u32), 300u32);
+        assert_eq!(postings.doc(), 300u32);
+        assert_eq!(postings.seek(6000u32), TERMINATED);
+    }
+}
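// ---------------------------------------------------------------------------
// [Editor's note — not part of the diff] Back-of-the-envelope for the memory
// claim above: a `SegmentPostings` costs ~1840 bytes of caches regardless of
// doc_freq, while a `LoadedPostings` is proportional to its docs/positions
// (three boxed `u32` slices). A hedged sketch:
fn approx_loaded_postings_bytes(num_docs: usize, num_positions: usize) -> usize {
    // doc_ids + position_offsets (num_docs + 1 entries) + positions, 4 bytes each.
    4 * num_docs + 4 * (num_docs + 1) + 4 * num_positions
}
// For a term with 10 docs and 20 positions this is ~164 bytes vs ~1840 bytes,
// which is where the "100_000 terms would need 184MB" figure above comes from.
// ---------------------------------------------------------------------------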
diff --git a/src/postings/mod.rs b/src/postings/mod.rs
index 5fd90032d..7060916bd 100644
--- a/src/postings/mod.rs
+++ b/src/postings/mod.rs
@@ -8,6 +8,7 @@ mod block_segment_postings;
 pub(crate) mod compression;
 mod indexing_context;
 mod json_postings_writer;
+mod loaded_postings;
 mod per_field_postings_writer;
 mod postings;
 mod postings_writer;
@@ -17,6 +18,7 @@ mod serializer;
 mod skip;
 mod term_info;
 
+pub(crate) use loaded_postings::LoadedPostings;
 pub(crate) use stacker::compute_table_memory_size;
 
 pub use self::block_segment_postings::BlockSegmentPostings;
diff --git a/src/postings/postings.rs b/src/postings/postings.rs
index 682e61393..8606f00a9 100644
--- a/src/postings/postings.rs
+++ b/src/postings/postings.rs
@@ -17,7 +17,14 @@ pub trait Postings: DocSet + 'static {
     /// Returns the positions offsetted with a given value.
     /// It is not necessary to clear the `output` before calling this method.
     /// The output vector will be resized to the `term_freq`.
-    fn positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>);
+    fn positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>) {
+        output.clear();
+        self.append_positions_with_offset(offset, output);
+    }
+
+    /// Returns the positions offsetted with a given value.
+    /// Data will be appended to the output.
+    fn append_positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>);
 
     /// Returns the positions of the term in the given document.
     /// The output vector will be resized to the `term_freq`.
@@ -25,3 +32,13 @@ pub trait Postings: DocSet + 'static {
         self.positions_with_offset(0u32, output);
     }
 }
+
+impl Postings for Box<dyn Postings> {
+    fn term_freq(&self) -> u32 {
+        (**self).term_freq()
+    }
+
+    fn append_positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>) {
+        (**self).append_positions_with_offset(offset, output);
+    }
+}
diff --git a/src/postings/segment_postings.rs b/src/postings/segment_postings.rs
index 3d91cf2ee..51194a356 100644
--- a/src/postings/segment_postings.rs
+++ b/src/postings/segment_postings.rs
@@ -237,8 +237,9 @@ impl Postings for SegmentPostings {
         self.block_cursor.freq(self.cur)
     }
 
-    fn positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>) {
+    fn append_positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>) {
         let term_freq = self.term_freq();
+        let prev_len = output.len();
         if let Some(position_reader) = self.position_reader.as_mut() {
             debug_assert!(
                 !self.block_cursor.freqs().is_empty(),
@@ -249,15 +250,14 @@
                 .iter()
                 .cloned()
                 .sum::<u32>() as u64);
-            output.resize(term_freq as usize, 0u32);
-            position_reader.read(read_offset, &mut output[..]);
+            // TODO: instead of zeroing the output, we could use MaybeUninit or similar.
+            output.resize(prev_len + term_freq as usize, 0u32);
+            position_reader.read(read_offset, &mut output[prev_len..]);
             let mut cum = offset;
-            for output_mut in output.iter_mut() {
+            for output_mut in output[prev_len..].iter_mut() {
                 cum += *output_mut;
                 *output_mut = cum;
             }
-        } else {
-            output.clear();
         }
     }
 }
diff --git a/src/query/automaton_weight.rs b/src/query/automaton_weight.rs
index ef675864b..5f1053fb6 100644
--- a/src/query/automaton_weight.rs
+++ b/src/query/automaton_weight.rs
@@ -6,6 +6,7 @@ use tantivy_fst::Automaton;
 
 use super::phrase_prefix_query::prefix_end;
 use crate::index::SegmentReader;
+use crate::postings::TermInfo;
 use crate::query::{BitSetDocSet, ConstScorer, Explanation, Scorer, Weight};
 use crate::schema::{Field, IndexRecordOption};
 use crate::termdict::{TermDictionary, TermStreamer};
@@ -64,6 +65,18 @@ where
         term_stream_builder.into_stream()
     }
+
+    /// Returns the term infos that match the automaton.
+    pub fn get_match_term_infos(&self, reader: &SegmentReader) -> crate::Result<Vec<TermInfo>> {
+        let inverted_index = reader.inverted_index(self.field)?;
+        let term_dict = inverted_index.terms();
+        let mut term_stream = self.automaton_stream(term_dict)?;
+        let mut term_infos = Vec::new();
+        while term_stream.advance() {
+            term_infos.push(term_stream.value().clone());
+        }
+        Ok(term_infos)
+    }
 }
 
 impl<A> Weight for AutomatonWeight<A>
diff --git a/src/query/boolean_query/block_wand.rs b/src/query/boolean_query/block_wand.rs
index ad9a8b2ba..59e22caa1 100644
--- a/src/query/boolean_query/block_wand.rs
+++ b/src/query/boolean_query/block_wand.rs
@@ -308,7 +308,7 @@ mod tests {
     use crate::query::score_combiner::SumCombiner;
     use crate::query::term_query::TermScorer;
-    use crate::query::{Bm25Weight, Scorer, Union};
+    use crate::query::{Bm25Weight, BufferedUnionScorer, Scorer};
     use crate::{DocId, DocSet, Score, TERMINATED};
 
     struct Float(Score);
@@ -371,7 +371,7 @@ fn compute_checkpoints_manual(term_scorers: Vec<TermScorer>, n: usize) -> Vec<(DocId, Score)> {
         let mut heap: BinaryHeap<Float> = BinaryHeap::with_capacity(n);
         let mut checkpoints: Vec<(DocId, Score)> = Vec::new();
-        let mut scorer = Union::build(term_scorers, SumCombiner::default);
+        let mut scorer = BufferedUnionScorer::build(term_scorers, SumCombiner::default);
 
         let mut limit = Score::MIN;
         loop {
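// ---------------------------------------------------------------------------
// [Editor's note — not part of the diff] With the trait change above, an
// implementor only has to provide `append_positions_with_offset`;
// `positions_with_offset` (and `positions`) fall out of the defaults. A toy
// sketch, assuming the `Postings`/`DocSet` traits as shown in this PR:
struct SingleDocPostings {
    doc: DocId,
    positions: Vec<u32>,
}

impl DocSet for SingleDocPostings {
    fn advance(&mut self) -> DocId {
        self.doc = TERMINATED;
        TERMINATED
    }

    fn doc(&self) -> DocId {
        self.doc
    }

    fn size_hint(&self) -> u32 {
        1
    }
}

impl Postings for SingleDocPostings {
    fn term_freq(&self) -> u32 {
        self.positions.len() as u32
    }

    fn append_positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>) {
        // Append (not overwrite), matching the contract relied on by the unions below.
        output.extend(self.positions.iter().map(|p| p + offset));
    }
}
// ---------------------------------------------------------------------------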
diff --git a/src/query/boolean_query/boolean_weight.rs b/src/query/boolean_query/boolean_weight.rs
index c0a5e2c37..7b617866f 100644
--- a/src/query/boolean_query/boolean_weight.rs
+++ b/src/query/boolean_query/boolean_weight.rs
@@ -9,8 +9,8 @@ use crate::query::score_combiner::{DoNothingCombiner, ScoreCombiner};
 use crate::query::term_query::TermScorer;
 use crate::query::weight::{for_each_docset_buffered, for_each_pruning_scorer, for_each_scorer};
 use crate::query::{
-    intersect_scorers, EmptyScorer, Exclude, Explanation, Occur, RequiredOptionalScorer, Scorer,
-    Union, Weight,
+    intersect_scorers, BufferedUnionScorer, EmptyScorer, Exclude, Explanation, Occur,
+    RequiredOptionalScorer, Scorer, Weight,
 };
 use crate::{DocId, Score};
 
@@ -65,14 +65,17 @@
                 // Block wand is only available if we read frequencies.
                 return SpecializedScorer::TermUnion(scorers);
             } else {
-                return SpecializedScorer::Other(Box::new(Union::build(
+                return SpecializedScorer::Other(Box::new(BufferedUnionScorer::build(
                     scorers,
                     score_combiner_fn,
                 )));
             }
         }
     }
-    SpecializedScorer::Other(Box::new(Union::build(scorers, score_combiner_fn)))
+    SpecializedScorer::Other(Box::new(BufferedUnionScorer::build(
+        scorers,
+        score_combiner_fn,
+    )))
 }
 
 fn into_box_scorer<TScoreCombiner: ScoreCombiner>(
@@ -81,7 +84,7 @@
 ) -> Box<dyn Scorer> {
     match scorer {
         SpecializedScorer::TermUnion(term_scorers) => {
-            let union_scorer = Union::build(term_scorers, score_combiner_fn);
+            let union_scorer = BufferedUnionScorer::build(term_scorers, score_combiner_fn);
             Box::new(union_scorer)
         }
         SpecializedScorer::Other(scorer) => scorer,
@@ -296,7 +299,8 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombiner> {
-                let mut union_scorer = Union::build(term_scorers, &self.score_combiner_fn);
+                let mut union_scorer =
+                    BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn);
                 for_each_scorer(&mut union_scorer, callback);
             }
             SpecializedScorer::Other(mut scorer) => {
@@ -316,7 +320,8 @@ impl<TScoreCombiner: ScoreCombiner + Sync> Weight for BooleanWeight<TScoreCombiner> {
-                let mut union_scorer = Union::build(term_scorers, &self.score_combiner_fn);
+                let mut union_scorer =
+                    BufferedUnionScorer::build(term_scorers, &self.score_combiner_fn);
                 for_each_docset_buffered(&mut union_scorer, &mut buffer, callback);
             }
             SpecializedScorer::Other(mut scorer) => {
diff --git a/src/query/mod.rs b/src/query/mod.rs
index 5e99354ff..23e64f189 100644
--- a/src/query/mod.rs
+++ b/src/query/mod.rs
@@ -51,6 +51,7 @@ pub use self::fuzzy_query::FuzzyTermQuery;
 pub use self::intersection::{intersect_scorers, Intersection};
 pub use self::more_like_this::{MoreLikeThisQuery, MoreLikeThisQueryBuilder};
 pub use self::phrase_prefix_query::PhrasePrefixQuery;
+pub use self::phrase_query::regex_phrase_query::{wildcard_query_to_regex_str, RegexPhraseQuery};
 pub use self::phrase_query::PhraseQuery;
 pub use self::query::{EnableScoring, Query, QueryClone};
 pub use self::query_parser::{QueryParser, QueryParserError};
@@ -61,7 +62,7 @@ pub use self::score_combiner::{DisjunctionMaxCombiner, ScoreCombiner, SumCombine
 pub use self::scorer::Scorer;
 pub use self::set_query::TermSetQuery;
 pub use self::term_query::TermQuery;
-pub use self::union::Union;
+pub use self::union::BufferedUnionScorer;
 #[cfg(test)]
 pub use self::vec_docset::VecDocSet;
 pub use self::weight::Weight;
diff --git a/src/query/phrase_query/mod.rs b/src/query/phrase_query/mod.rs
index 7b8d3e007..f37c39b15 100644
--- a/src/query/phrase_query/mod.rs
+++ b/src/query/phrase_query/mod.rs
@@ -1,6 +1,8 @@
 mod phrase_query;
 mod phrase_scorer;
 mod phrase_weight;
+pub mod regex_phrase_query;
+mod regex_phrase_weight;
 
 pub use self::phrase_query::PhraseQuery;
 pub(crate) use self::phrase_scorer::intersection_count;
@@ -19,15 +21,15 @@ pub mod tests {
     use crate::schema::{Schema, Term, TEXT};
     use crate::{assert_nearly_equals, DocAddress, DocId, IndexWriter, TERMINATED};
 
-    pub fn create_index(texts: &[&'static str]) -> crate::Result<Index> {
+    pub fn create_index<S: AsRef<str>>(texts: &[S]) -> crate::Result<Index> {
         let mut schema_builder = Schema::builder();
         let text_field = schema_builder.add_text_field("text", TEXT);
         let schema = schema_builder.build();
         let index = Index::create_in_ram(schema);
         {
             let mut index_writer: IndexWriter = index.writer_for_tests()?;
-            for &text in texts {
-                let doc = doc!(text_field=>text);
+            for text in texts {
+                let doc = doc!(text_field=>text.as_ref());
                 index_writer.add_document(doc)?;
             }
             index_writer.commit()?;
diff --git a/src/query/phrase_query/phrase_weight.rs b/src/query/phrase_query/phrase_weight.rs
index 6e97bca7f..4118f79f6 100644
--- a/src/query/phrase_query/phrase_weight.rs
+++ b/src/query/phrase_query/phrase_weight.rs
@@ -50,27 +50,14 @@ impl PhraseWeight {
             .map(|similarity_weight| similarity_weight.boost_by(boost));
         let fieldnorm_reader = self.fieldnorm_reader(reader)?;
         let mut term_postings_list = Vec::new();
-        if reader.has_deletes() {
-            for &(offset, ref term) in &self.phrase_terms {
-                if let Some(postings) = reader
-                    .inverted_index(term.field())?
-                    .read_postings(term, IndexRecordOption::WithFreqsAndPositions)?
-                {
-                    term_postings_list.push((offset, postings));
-                } else {
-                    return Ok(None);
-                }
-            }
-        } else {
-            for &(offset, ref term) in &self.phrase_terms {
-                if let Some(postings) = reader
-                    .inverted_index(term.field())?
-                    .read_postings_no_deletes(term, IndexRecordOption::WithFreqsAndPositions)?
-                {
-                    term_postings_list.push((offset, postings));
-                } else {
-                    return Ok(None);
-                }
+        for &(offset, ref term) in &self.phrase_terms {
+            if let Some(postings) = reader
+                .inverted_index(term.field())?
+                .read_postings(term, IndexRecordOption::WithFreqsAndPositions)?
+            {
+                term_postings_list.push((offset, postings));
+            } else {
+                return Ok(None);
             }
         }
         Ok(Some(PhraseScorer::new(
diff --git a/src/query/phrase_query/regex_phrase_query.rs b/src/query/phrase_query/regex_phrase_query.rs
new file mode 100644
index 000000000..27096fcf1
--- /dev/null
+++ b/src/query/phrase_query/regex_phrase_query.rs
@@ -0,0 +1,172 @@
+use super::regex_phrase_weight::RegexPhraseWeight;
+use crate::query::bm25::Bm25Weight;
+use crate::query::{EnableScoring, Query, Weight};
+use crate::schema::{Field, IndexRecordOption, Term, Type};
+
+/// `RegexPhraseQuery` matches a specific sequence of regex queries.
+///
+/// For instance, the phrase query for `"pa.* time"` will match
+/// the sentence:
+///
+/// **Alan just got a part time job.**
+///
+/// On the other hand, it will not match the sentence:
+///
+/// **This is my favorite part of the job.**
+///
+/// [Slop](RegexPhraseQuery::set_slop) allows leniency in term proximity
+/// for some performance trade-off.
+///
+/// Using a `RegexPhraseQuery` on a field requires positions
+/// to be indexed for this field.
+#[derive(Clone, Debug)]
+pub struct RegexPhraseQuery {
+    field: Field,
+    phrase_terms: Vec<(usize, String)>,
+    slop: u32,
+    max_expansions: u32,
+}
+
+/// Transforms a wildcard query into a regex string.
+///
+/// `AB*CD`, for example, is converted to `AB.*CD`.
+///
+/// All other chars are regex-escaped.
+pub fn wildcard_query_to_regex_str(term: &str) -> String {
+    regex::escape(term).replace(r"\*", ".*")
+}
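// ---------------------------------------------------------------------------
// [Editor's note — not part of the diff] Expected behavior of
// `wildcard_query_to_regex_str`, derived from `regex::escape` semantics: `*`
// becomes `.*`, every other regex metacharacter is escaped.
#[test]
fn wildcard_query_to_regex_str_examples() {
    assert_eq!(wildcard_query_to_regex_str("AB*CD"), "AB.*CD");
    // `.` is escaped, so only the `*` stays a wildcard.
    assert_eq!(wildcard_query_to_regex_str("a.b*"), r"a\.b.*");
}
// ---------------------------------------------------------------------------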
+
+impl RegexPhraseQuery {
+    /// Creates a new `RegexPhraseQuery` given a list of terms.
+    ///
+    /// There must be at least two terms, and all terms
+    /// must belong to the same field.
+    ///
+    /// The offset of each term is its index in the `Vec`.
+    pub fn new(field: Field, terms: Vec<String>) -> RegexPhraseQuery {
+        let terms_with_offset = terms.into_iter().enumerate().collect();
+        RegexPhraseQuery::new_with_offset(field, terms_with_offset)
+    }
+
+    /// Creates a new `RegexPhraseQuery` given a list of terms and their offsets.
+    ///
+    /// Can be used to provide a custom offset for each term.
+    pub fn new_with_offset(field: Field, terms: Vec<(usize, String)>) -> RegexPhraseQuery {
+        RegexPhraseQuery::new_with_offset_and_slop(field, terms, 0)
+    }
+
+    /// Creates a new `RegexPhraseQuery` given a list of terms, their offsets and a slop.
+    pub fn new_with_offset_and_slop(
+        field: Field,
+        mut terms: Vec<(usize, String)>,
+        slop: u32,
+    ) -> RegexPhraseQuery {
+        assert!(
+            terms.len() > 1,
+            "A phrase query is required to have strictly more than one term."
+        );
+        terms.sort_by_key(|&(offset, _)| offset);
+        RegexPhraseQuery {
+            field,
+            phrase_terms: terms,
+            slop,
+            max_expansions: 1 << 14,
+        }
+    }
+
+    /// Slop allowed for the phrase.
+    ///
+    /// The query will match if its terms are separated by `slop` terms at most.
+    /// The slop can be considered a budget between all terms.
+    /// E.g. "A B C" with slop 1 allows "A X B C" and "A B X C", but not "A X B X C".
+    ///
+    /// Transposition costs 2: e.g. "A B" with slop 1 will not match "B A", but it would
+    /// with slop 2.
+    /// Transposition is not a special case: in the example above, A is moved 1 position and B is
+    /// moved 1 position, so the slop is 2.
+    ///
+    /// As a result, slop works in both directions, so the order of the terms may change as long
+    /// as they respect the slop.
+    ///
+    /// By default the slop is 0, meaning query terms need to be adjacent.
+    pub fn set_slop(&mut self, value: u32) {
+        self.slop = value;
+    }
+
+    /// Sets the maximum number of term expansions across all regex terms of the phrase.
+    /// Once the limit is exceeded, an error is returned.
+    pub fn set_max_expansions(&mut self, value: u32) {
+        self.max_expansions = value;
+    }
+
+    /// The [`Field`] this `RegexPhraseQuery` is targeting.
+    pub fn field(&self) -> Field {
+        self.field
+    }
+
+    /// `Term`s in the phrase without the associated offsets.
+    pub fn phrase_terms(&self) -> Vec<Term> {
+        self.phrase_terms
+            .iter()
+            .map(|(_, term)| Term::from_field_text(self.field, term))
+            .collect::<Vec<Term>>()
+    }
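// ---------------------------------------------------------------------------
// [Editor's note — not part of the diff] A hedged usage sketch: a query for a
// word starting with "pa" immediately followed by "time", as in the doc
// comment above. `text_field` is assumed to be a TEXT field with positions.
fn part_time_query(text_field: Field) -> RegexPhraseQuery {
    RegexPhraseQuery::new(
        text_field,
        vec![wildcard_query_to_regex_str("pa*"), "time".to_string()],
    )
}
// The result can then be passed to `Searcher::search` like any other `Query`.
// ---------------------------------------------------------------------------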
+
+    /// Returns the [`RegexPhraseWeight`] for the given phrase query given a specific `searcher`.
+    ///
+    /// This function is the same as [`Query::weight()`] except it returns
+    /// a specialized type [`RegexPhraseWeight`] instead of a boxed trait object.
+    pub(crate) fn regex_phrase_weight(
+        &self,
+        enable_scoring: EnableScoring<'_>,
+    ) -> crate::Result<RegexPhraseWeight> {
+        let schema = enable_scoring.schema();
+        let field_type = schema.get_field_entry(self.field).field_type().value_type();
+        if field_type != Type::Str {
+            return Err(crate::TantivyError::SchemaError(format!(
+                "RegexPhraseQuery can currently only be used with a field of type text, but got \
+                 {:?}",
+                field_type
+            )));
+        }
+
+        let field_entry = schema.get_field_entry(self.field);
+        let has_positions = field_entry
+            .field_type()
+            .get_index_record_option()
+            .map(IndexRecordOption::has_positions)
+            .unwrap_or(false);
+        if !has_positions {
+            let field_name = field_entry.name();
+            return Err(crate::TantivyError::SchemaError(format!(
+                "Applied phrase query on field {field_name:?}, which does not have positions \
+                 indexed"
+            )));
+        }
+        let terms = self.phrase_terms();
+        let bm25_weight_opt = match enable_scoring {
+            EnableScoring::Enabled {
+                statistics_provider,
+                ..
+            } => Some(Bm25Weight::for_terms(statistics_provider, &terms)?),
+            EnableScoring::Disabled { .. } => None,
+        };
+        let weight = RegexPhraseWeight::new(
+            self.field,
+            self.phrase_terms.clone(),
+            bm25_weight_opt,
+            self.max_expansions,
+            self.slop,
+        );
+        Ok(weight)
+    }
+}
+
+impl Query for RegexPhraseQuery {
+    /// Create the weight associated with a query.
+    ///
+    /// See [`Weight`].
+    fn weight(&self, enable_scoring: EnableScoring<'_>) -> crate::Result<Box<dyn Weight>> {
+        let phrase_weight = self.regex_phrase_weight(enable_scoring)?;
+        Ok(Box::new(phrase_weight))
+    }
+}
diff --git a/src/query/phrase_query/regex_phrase_weight.rs b/src/query/phrase_query/regex_phrase_weight.rs
new file mode 100644
index 000000000..53959c644
--- /dev/null
+++ b/src/query/phrase_query/regex_phrase_weight.rs
@@ -0,0 +1,475 @@
+use std::sync::Arc;
+
+use common::BitSet;
+use tantivy_fst::Regex;
+
+use super::PhraseScorer;
+use crate::fieldnorm::FieldNormReader;
+use crate::index::SegmentReader;
+use crate::postings::{LoadedPostings, Postings, SegmentPostings, TermInfo};
+use crate::query::bm25::Bm25Weight;
+use crate::query::explanation::does_not_match;
+use crate::query::union::{BitSetPostingUnion, SimpleUnion};
+use crate::query::{AutomatonWeight, BitSetDocSet, EmptyScorer, Explanation, Scorer, Weight};
+use crate::schema::{Field, IndexRecordOption};
+use crate::{DocId, DocSet, InvertedIndexReader, Score};
+
+type UnionType = SimpleUnion<Box<dyn Postings>>;
+
+/// The `RegexPhraseWeight` is the weight associated with a regex phrase query.
+/// See [`RegexPhraseWeight::get_union_from_term_infos`] for some design decisions.
+pub struct RegexPhraseWeight {
+    field: Field,
+    phrase_terms: Vec<(usize, String)>,
+    similarity_weight_opt: Option<Bm25Weight>,
+    slop: u32,
+    max_expansions: u32,
+}
+
+impl RegexPhraseWeight {
+    /// Creates a new phrase weight.
+    /// If `similarity_weight_opt` is `None`, scoring is disabled.
+    pub fn new(
+        field: Field,
+        phrase_terms: Vec<(usize, String)>,
+        similarity_weight_opt: Option<Bm25Weight>,
+        max_expansions: u32,
+        slop: u32,
+    ) -> RegexPhraseWeight {
+        RegexPhraseWeight {
+            field,
+            phrase_terms,
+            similarity_weight_opt,
+            slop,
+            max_expansions,
+        }
+    }
+
+    fn fieldnorm_reader(&self, reader: &SegmentReader) -> crate::Result<FieldNormReader> {
+        if self.similarity_weight_opt.is_some() {
+            if let Some(fieldnorm_reader) = reader.fieldnorms_readers().get_field(self.field)? {
+                return Ok(fieldnorm_reader);
+            }
+        }
+        Ok(FieldNormReader::constant(reader.max_doc(), 1))
+    }
+
+    pub(crate) fn phrase_scorer(
+        &self,
+        reader: &SegmentReader,
+        boost: Score,
+    ) -> crate::Result<Option<PhraseScorer<UnionType>>> {
+        let similarity_weight_opt = self
+            .similarity_weight_opt
+            .as_ref()
+            .map(|similarity_weight| similarity_weight.boost_by(boost));
+        let fieldnorm_reader = self.fieldnorm_reader(reader)?;
+        let mut posting_lists = Vec::new();
+        let inverted_index = reader.inverted_index(self.field)?;
+        let mut num_terms = 0;
+        for &(offset, ref term) in &self.phrase_terms {
+            let regex = Regex::new(term)
+                .map_err(|e| crate::TantivyError::InvalidArgument(format!("Invalid regex: {e}")))?;
+
+            let automaton: AutomatonWeight<Regex> =
+                AutomatonWeight::new(self.field, Arc::new(regex));
+            let term_infos = automaton.get_match_term_infos(reader)?;
+            // If term_infos is empty, the phrase cannot match any documents.
+            if term_infos.is_empty() {
+                return Ok(None);
+            }
+            num_terms += term_infos.len();
+            if num_terms > self.max_expansions as usize {
+                return Err(crate::TantivyError::InvalidArgument(format!(
+                    "Phrase query exceeded max expansions {}",
+                    num_terms
+                )));
+            }
+            let union = Self::get_union_from_term_infos(&term_infos, reader, &inverted_index)?;
+
+            posting_lists.push((offset, union));
+        }
+
+        Ok(Some(PhraseScorer::new(
+            posting_lists,
+            similarity_weight_opt,
+            fieldnorm_reader,
+            self.slop,
+        )))
+    }
+
+    /// Adds all docs of the term to the given bitset.
+    fn add_to_bitset(
+        inverted_index: &InvertedIndexReader,
+        term_info: &TermInfo,
+        doc_bitset: &mut BitSet,
+    ) -> crate::Result<()> {
+        let mut block_segment_postings = inverted_index
+            .read_block_postings_from_terminfo(term_info, IndexRecordOption::Basic)?;
+        loop {
+            let docs = block_segment_postings.docs();
+            if docs.is_empty() {
+                break;
+            }
+            for &doc in docs {
+                doc_bitset.insert(doc);
+            }
+            block_segment_postings.advance();
+        }
+        Ok(())
+    }
+
+    /// This function generates a union of the document sets of multiple terms
+    /// (`TermInfo`).
+    ///
+    /// It uses bucketing based on term frequency to optimize query performance and memory usage.
+    /// The terms are divided into buckets based on their document frequency (the number of
+    /// documents they appear in).
+    ///
+    /// ### Bucketing Strategy:
+    /// Once a bucket contains 512 terms, it is moved to the end of the list and replaced
+    /// with a new empty bucket.
+    ///
+    /// - **Sparse Term Buckets**: Terms with document frequency `< 100`.
+    ///
+    ///   Each sparse bucket contains:
+    ///   - A `BitSet` to efficiently track which document IDs are present in the bucket, which
+    ///     is used to drive the `DocSet`.
+    ///   - A `Vec<LoadedPostings>` to store the postings for each term in that bucket.
+    ///
+    /// - **Other Term Buckets**:
+    ///   - **Bucket 0**: Terms appearing in less than `0.1%` of documents.
+    ///   - **Bucket 1**: Terms appearing in `0.1%` to `1%` of documents.
+    ///   - **Bucket 2**: Terms appearing in `1%` to `10%` of documents.
+    ///   - **Bucket 3**: Terms appearing in more than `10%` of documents.
+    ///
+    ///   Each bucket contains:
+    ///   - A `BitSet` to efficiently track which document IDs are present in the bucket.
+    ///   - A `Vec<SegmentPostings>` to store the postings for each term in that bucket.
+    ///
+    /// ### Design Choices:
+    /// The main cost for an _unbucketed_ regex phrase query with a medium/high number of terms
+    /// is the `append_positions_with_offset` call from `Postings`.
+    /// We don't know which docsets hit, so we need to scan all of them to check whether they
+    /// contain the doc id.
+    /// The bucketing strategy groups less common docsets together, so we can rule out a
+    /// whole docset group in many cases.
+    ///
+    /// E.g. consider the phrase "th* world".
+    /// It contains the term "the", which may occur in almost all documents.
+    /// It may also contain 10_000s of very rare terms like "theologian".
+    ///
+    /// For very low-frequency terms (sparse terms), we use `LoadedPostings` and aggregate
+    /// their document IDs into a `BitSet`, which is more memory-efficient than using
+    /// `SegmentPostings`. E.g. 100_000 terms with `SegmentPostings` would consume 184MB;
+    /// a `SegmentPostings` uses memory equivalent to 460 doc ids. The 100-docs limit should
+    /// be fine as long as a term doesn't have too many positions per doc.
+    ///
+    /// ### Future Optimization:
+    /// A larger performance improvement would be an additional vertical partitioning of the
+    /// space into `u16::MAX` blocks, marking which docset ords have values in each block.
+    /// E.g. in an index with 5 million documents, this would reduce the number of docsets to
+    /// scan to around 1/20 in the sparse term bucket, where the terms only have a few docs.
+    /// For higher-cardinality buckets this is irrelevant, as they occur in most blocks.
+    ///
+    /// Another option is to use Roaring bitmaps for sparse terms; the full bitvec is currently
+    /// the main memory consumer.
+    pub(crate) fn get_union_from_term_infos(
+        term_infos: &[TermInfo],
+        reader: &SegmentReader,
+        inverted_index: &InvertedIndexReader,
+    ) -> crate::Result<UnionType> {
+        let max_doc = reader.max_doc();
+
+        // Buckets for sparse terms
+        let mut sparse_buckets: Vec<(BitSet, Vec<LoadedPostings>)> =
+            vec![(BitSet::with_max_value(max_doc), Vec::new())];
+
+        // Buckets for other terms based on document frequency percentages:
+        // - Bucket 0: Terms appearing in less than 0.1% of documents
+        // - Bucket 1: Terms appearing in 0.1% to 1% of documents
+        // - Bucket 2: Terms appearing in 1% to 10% of documents
+        // - Bucket 3: Terms appearing in more than 10% of documents
+        let mut buckets: Vec<(BitSet, Vec<SegmentPostings>)> = (0..4)
+            .map(|_| (BitSet::with_max_value(max_doc), Vec::new()))
+            .collect();
+
+        const SPARSE_TERM_DOC_THRESHOLD: u32 = 100;
+
+        for term_info in term_infos {
+            let mut term_posting = inverted_index
+                .read_postings_from_terminfo(term_info, IndexRecordOption::WithFreqsAndPositions)?;
+            let num_docs = term_posting.doc_freq();
+
+            if num_docs < SPARSE_TERM_DOC_THRESHOLD {
+                let current_bucket = &mut sparse_buckets[0];
+                Self::add_to_bitset(inverted_index, term_info, &mut current_bucket.0)?;
+                let docset = LoadedPostings::load(&mut term_posting);
+                current_bucket.1.push(docset);
+
+                // Move the bucket to the end if the term limit is reached
+                if current_bucket.1.len() == 512 {
+                    sparse_buckets.push((BitSet::with_max_value(max_doc), Vec::new()));
+                    let end_index = sparse_buckets.len() - 1;
+                    sparse_buckets.swap(0, end_index);
+                }
+            } else {
+                // Calculate the percentage of documents the term appears in
+                let doc_freq_percentage = (num_docs as f32) / (max_doc as f32) * 100.0;
+
+                // Determine the appropriate bucket based on percentage thresholds
+                let bucket_index = if doc_freq_percentage < 0.1 {
+                    0
+                } else if doc_freq_percentage < 1.0 {
+                    1
+                } else if doc_freq_percentage < 10.0 {
+                    2
+                } else {
+                    3
+                };
+                let bucket = &mut buckets[bucket_index];
+
+                // Add term postings to the appropriate bucket
+                Self::add_to_bitset(inverted_index, term_info, &mut bucket.0)?;
+                bucket.1.push(term_posting);
+
+                // Move the bucket to the end if the term limit is reached
+                if bucket.1.len() == 512 {
+                    buckets.push((BitSet::with_max_value(max_doc), Vec::new()));
+                    let end_index = buckets.len() - 1;
+                    buckets.swap(bucket_index, end_index);
+                }
+            }
+        }
+
+        // Build unions for sparse term buckets
+        let sparse_term_docsets: Vec<_> = sparse_buckets
+            .into_iter()
+            .filter(|(_, postings)| !postings.is_empty())
+            .map(|(bitset, postings)| {
+                BitSetPostingUnion::build(postings, BitSetDocSet::from(bitset))
+            })
+            .collect();
+        let sparse_term_unions = SimpleUnion::build(sparse_term_docsets);
+
+        // Build unions for other term buckets
+        let bitset_unions_per_bucket: Vec<_> = buckets
+            .into_iter()
+            .filter(|(_, postings)| !postings.is_empty())
+            .map(|(bitset, postings)| {
+                BitSetPostingUnion::build(postings, BitSetDocSet::from(bitset))
+            })
+            .collect();
+        let other_union = SimpleUnion::build(bitset_unions_per_bucket);
+
+        // Return a union of sparse term unions and other term unions
+        let union: SimpleUnion<Box<dyn Postings>> =
+            SimpleUnion::build(vec![Box::new(sparse_term_unions), Box::new(other_union)]);
+        Ok(union)
+    }
+}
+
+impl Weight for RegexPhraseWeight {
+    fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> {
+        if let Some(scorer) = self.phrase_scorer(reader, boost)? {
+            Ok(Box::new(scorer))
+        } else {
+            Ok(Box::new(EmptyScorer))
+        }
+    }
+
+    fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> {
+        let scorer_opt = self.phrase_scorer(reader, 1.0)?;
+        if scorer_opt.is_none() {
+            return Err(does_not_match(doc));
+        }
+        let mut scorer = scorer_opt.unwrap();
+        if scorer.seek(doc) != doc {
+            return Err(does_not_match(doc));
+        }
+        let fieldnorm_reader = self.fieldnorm_reader(reader)?;
+        let fieldnorm_id = fieldnorm_reader.fieldnorm_id(doc);
+        let phrase_count = scorer.phrase_count();
+        let mut explanation = Explanation::new("Phrase Scorer", scorer.score());
+        if let Some(similarity_weight) = self.similarity_weight_opt.as_ref() {
+            explanation.add_detail(similarity_weight.explain(fieldnorm_id, phrase_count));
+        }
+        Ok(explanation)
+    }
+}
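// ---------------------------------------------------------------------------
// [Editor's note — not part of the diff] The doc-frequency bucketing above in
// isolation — a sketch of how a non-sparse term's doc_freq maps to a bucket:
fn bucket_index(num_docs: u32, max_doc: u32) -> usize {
    let doc_freq_percentage = (num_docs as f32) / (max_doc as f32) * 100.0;
    if doc_freq_percentage < 0.1 {
        0
    } else if doc_freq_percentage < 1.0 {
        1
    } else if doc_freq_percentage < 10.0 {
        2
    } else {
        3
    }
}
// E.g. with max_doc = 1_000_000: doc_freq 500 -> bucket 0, 5_000 -> bucket 1,
// 50_000 -> bucket 2, 500_000 -> bucket 3; anything below 100 docs would have
// gone into a sparse bucket instead.
// ---------------------------------------------------------------------------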
+
+#[cfg(test)]
+mod tests {
+    use proptest::prelude::*;
+    use rand::seq::SliceRandom;
+
+    use super::super::tests::create_index;
+    use crate::docset::TERMINATED;
+    use crate::query::{wildcard_query_to_regex_str, EnableScoring, RegexPhraseQuery};
+    use crate::DocSet;
+
+    proptest! {
+        #![proptest_config(ProptestConfig::with_cases(50))]
+        #[test]
+        fn test_phrase_regex_with_random_strings(
+            mut random_strings in proptest::collection::vec("[c-z ]{0,10}", 1..100),
+            num_occurrences in 1..150_usize,
+        ) {
+            let mut rng = rand::thread_rng();
+
+            // Insert "aaa ccc" the specified number of times into the list
+            for _ in 0..num_occurrences {
+                random_strings.push("aaa ccc".to_string());
+            }
+            // Shuffle the list, which now contains random strings and the inserted "aaa ccc"
+            random_strings.shuffle(&mut rng);
+
+            // Compute the positions of "aaa ccc" after the shuffle
+            let aaa_ccc_positions: Vec<usize> = random_strings
+                .iter()
+                .enumerate()
+                .filter_map(|(idx, s)| if s == "aaa ccc" { Some(idx) } else { None })
+                .collect();
+
+            // Create the index with random strings and the fixed string "aaa ccc"
+            let index = create_index(&random_strings.iter().map(AsRef::as_ref).collect::<Vec<&str>>())?;
+            let schema = index.schema();
+            let text_field = schema.get_field("text").unwrap();
+            let searcher = index.reader()?.searcher();
+
+            let phrase_query = RegexPhraseQuery::new(
+                text_field,
+                vec![
+                    wildcard_query_to_regex_str("a*"),
+                    wildcard_query_to_regex_str("c*"),
+                ],
+            );
+
+            let enable_scoring = EnableScoring::enabled_from_searcher(&searcher);
+            let phrase_weight = phrase_query.regex_phrase_weight(enable_scoring).unwrap();
+            let mut phrase_scorer = phrase_weight
+                .phrase_scorer(searcher.segment_reader(0u32), 1.0)?
+                .unwrap();
+
+            // Check that the scorer returns the correct document positions for "aaa ccc"
+            for expected_doc in aaa_ccc_positions {
+                prop_assert_eq!(phrase_scorer.doc(), expected_doc as u32);
+                prop_assert_eq!(phrase_scorer.phrase_count(), 1);
+                phrase_scorer.advance();
+            }
+            prop_assert_eq!(phrase_scorer.advance(), TERMINATED);
+        }
+    }
+
+    #[test]
+    pub fn test_phrase_count() -> crate::Result<()> {
+        let index = create_index(&["a c", "a a b d a b c", " a b"])?;
+        let schema = index.schema();
+        let text_field = schema.get_field("text").unwrap();
+        let searcher = index.reader()?.searcher();
+        let phrase_query = RegexPhraseQuery::new(text_field, vec!["a".into(), "b".into()]);
+        let enable_scoring = EnableScoring::enabled_from_searcher(&searcher);
+        let phrase_weight = phrase_query.regex_phrase_weight(enable_scoring).unwrap();
+        let mut phrase_scorer = phrase_weight
+            .phrase_scorer(searcher.segment_reader(0u32), 1.0)?
+            .unwrap();
+        assert_eq!(phrase_scorer.doc(), 1);
+        assert_eq!(phrase_scorer.phrase_count(), 2);
+        assert_eq!(phrase_scorer.advance(), 2);
+        assert_eq!(phrase_scorer.doc(), 2);
+        assert_eq!(phrase_scorer.phrase_count(), 1);
+        assert_eq!(phrase_scorer.advance(), TERMINATED);
+        Ok(())
+    }
+
+    #[test]
+    pub fn test_phrase_wildcard() -> crate::Result<()> {
+        let index = create_index(&["a c", "a aa b d ad b c", " ac b", "bac b"])?;
+        let schema = index.schema();
+        let text_field = schema.get_field("text").unwrap();
+        let searcher = index.reader()?.searcher();
+        let phrase_query = RegexPhraseQuery::new(text_field, vec!["a.*".into(), "b".into()]);
+        let enable_scoring = EnableScoring::enabled_from_searcher(&searcher);
+        let phrase_weight = phrase_query.regex_phrase_weight(enable_scoring).unwrap();
+        let mut phrase_scorer = phrase_weight
+            .phrase_scorer(searcher.segment_reader(0u32), 1.0)?
+ .unwrap(); + assert_eq!(phrase_scorer.doc(), 1); + assert_eq!(phrase_scorer.phrase_count(), 2); + assert_eq!(phrase_scorer.advance(), 2); + assert_eq!(phrase_scorer.doc(), 2); + assert_eq!(phrase_scorer.phrase_count(), 1); + assert_eq!(phrase_scorer.advance(), TERMINATED); + + Ok(()) + } + + #[test] + pub fn test_phrase_regex() -> crate::Result<()> { + let index = create_index(&["ba b", "a aa b d ad b c", "bac b"])?; + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + let searcher = index.reader()?.searcher(); + let phrase_query = RegexPhraseQuery::new(text_field, vec!["b?a.*".into(), "b".into()]); + let enable_scoring = EnableScoring::enabled_from_searcher(&searcher); + let phrase_weight = phrase_query.regex_phrase_weight(enable_scoring).unwrap(); + let mut phrase_scorer = phrase_weight + .phrase_scorer(searcher.segment_reader(0u32), 1.0)? + .unwrap(); + assert_eq!(phrase_scorer.doc(), 0); + assert_eq!(phrase_scorer.phrase_count(), 1); + assert_eq!(phrase_scorer.advance(), 1); + assert_eq!(phrase_scorer.phrase_count(), 2); + assert_eq!(phrase_scorer.advance(), 2); + assert_eq!(phrase_scorer.doc(), 2); + assert_eq!(phrase_scorer.phrase_count(), 1); + assert_eq!(phrase_scorer.advance(), TERMINATED); + + Ok(()) + } + + #[test] + pub fn test_phrase_regex_with_slop() -> crate::Result<()> { + let index = create_index(&["aaa bbb ccc ___ abc ddd bbb ccc"])?; + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + let searcher = index.reader()?.searcher(); + let mut phrase_query = RegexPhraseQuery::new(text_field, vec!["a.*".into(), "c.*".into()]); + phrase_query.set_slop(1); + let enable_scoring = EnableScoring::enabled_from_searcher(&searcher); + let phrase_weight = phrase_query.regex_phrase_weight(enable_scoring).unwrap(); + let mut phrase_scorer = phrase_weight + .phrase_scorer(searcher.segment_reader(0u32), 1.0)? + .unwrap(); + assert_eq!(phrase_scorer.doc(), 0); + assert_eq!(phrase_scorer.phrase_count(), 1); + assert_eq!(phrase_scorer.advance(), TERMINATED); + + phrase_query.set_slop(2); + let enable_scoring = EnableScoring::enabled_from_searcher(&searcher); + let phrase_weight = phrase_query.regex_phrase_weight(enable_scoring).unwrap(); + let mut phrase_scorer = phrase_weight + .phrase_scorer(searcher.segment_reader(0u32), 1.0)? + .unwrap(); + assert_eq!(phrase_scorer.doc(), 0); + assert_eq!(phrase_scorer.phrase_count(), 2); + assert_eq!(phrase_scorer.advance(), TERMINATED); + + Ok(()) + } + + #[test] + pub fn test_phrase_regex_double_wildcard() -> crate::Result<()> { + let index = create_index(&["baaab bccccb"])?; + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + let searcher = index.reader()?.searcher(); + let phrase_query = RegexPhraseQuery::new( + text_field, + vec![ + wildcard_query_to_regex_str("*a*"), + wildcard_query_to_regex_str("*c*"), + ], + ); + let enable_scoring = EnableScoring::enabled_from_searcher(&searcher); + let phrase_weight = phrase_query.regex_phrase_weight(enable_scoring).unwrap(); + let mut phrase_scorer = phrase_weight + .phrase_scorer(searcher.segment_reader(0u32), 1.0)? 
+            .unwrap();
+        assert_eq!(phrase_scorer.doc(), 0);
+        assert_eq!(phrase_scorer.phrase_count(), 1);
+        assert_eq!(phrase_scorer.advance(), TERMINATED);
+        Ok(())
+    }
+}
diff --git a/src/query/union/bitset_union.rs b/src/query/union/bitset_union.rs
new file mode 100644
index 000000000..8af1703ee
--- /dev/null
+++ b/src/query/union/bitset_union.rs
@@ -0,0 +1,89 @@
+use std::cell::RefCell;
+
+use crate::docset::DocSet;
+use crate::postings::Postings;
+use crate::query::BitSetDocSet;
+use crate::DocId;
+
+/// A `Postings` implementation that uses the bitset for hits and the docsets for the
+/// posting lists.
+///
+/// It is used for the regex phrase query, where we need the union of a large number of
+/// terms, but also need to keep the docsets around to load positions.
+pub struct BitSetPostingUnion<TDocSet> {
+    /// The docsets are required to load positions.
+    ///
+    /// `RefCell` because we mutate in `term_freq`.
+    docsets: RefCell<Vec<TDocSet>>,
+    /// The already unionized `BitSet` of the docsets
+    bitset: BitSetDocSet,
+}
+
+impl<TDocSet: DocSet> BitSetPostingUnion<TDocSet> {
+    pub(crate) fn build(
+        docsets: Vec<TDocSet>,
+        bitset: BitSetDocSet,
+    ) -> BitSetPostingUnion<TDocSet> {
+        BitSetPostingUnion {
+            docsets: RefCell::new(docsets),
+            bitset,
+        }
+    }
+}
+
+impl<TDocSet: Postings> Postings for BitSetPostingUnion<TDocSet> {
+    fn term_freq(&self) -> u32 {
+        let curr_doc = self.bitset.doc();
+        let mut term_freq = 0;
+        let mut docsets = self.docsets.borrow_mut();
+        for docset in docsets.iter_mut() {
+            if docset.doc() < curr_doc {
+                docset.seek(curr_doc);
+            }
+            if docset.doc() == curr_doc {
+                term_freq += docset.term_freq();
+            }
+        }
+        term_freq
+    }
+
+    fn append_positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>) {
+        let curr_doc = self.bitset.doc();
+        let mut docsets = self.docsets.borrow_mut();
+        for docset in docsets.iter_mut() {
+            if docset.doc() < curr_doc {
+                docset.seek(curr_doc);
+            }
+            if docset.doc() == curr_doc {
+                docset.append_positions_with_offset(offset, output);
+            }
+        }
+        debug_assert!(
+            !output.is_empty(),
+            "this method should only be called if positions are available"
+        );
+        output.sort_unstable();
+        output.dedup();
+    }
+}
+
+impl<TDocSet: DocSet> DocSet for BitSetPostingUnion<TDocSet> {
+    fn advance(&mut self) -> DocId {
+        self.bitset.advance()
+    }
+
+    fn seek(&mut self, target: DocId) -> DocId {
+        self.bitset.seek(target)
+    }
+
+    fn doc(&self) -> DocId {
+        self.bitset.doc()
+    }
+
+    fn size_hint(&self) -> u32 {
+        self.bitset.size_hint()
+    }
+
+    fn count_including_deleted(&mut self) -> u32 {
+        self.bitset.count_including_deleted()
+    }
+}
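// ---------------------------------------------------------------------------
// [Editor's note — not part of the diff] Construction sketch: the caller must
// uphold the invariant that `bitset` is exactly the union of the doc ids of
// `docsets` (in this PR, `RegexPhraseWeight::add_to_bitset` establishes it).
// The helper name is illustrative.
fn union_postings<TDocSet: Postings>(
    docsets: Vec<TDocSet>,
    bitset: BitSet,
) -> BitSetPostingUnion<TDocSet> {
    BitSetPostingUnion::build(docsets, BitSetDocSet::from(bitset))
}
// Iteration (`advance`/`seek`/`doc`) is then driven entirely by the cheap
// bitset; the underlying postings are only sought lazily when `term_freq` or
// positions are requested for the current doc.
// ---------------------------------------------------------------------------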
diff --git a/src/query/union.rs b/src/query/union/buffered_union.rs
similarity index 50%
rename from src/query/union.rs
rename to src/query/union/buffered_union.rs
index b1f23156a..5fc946ee1 100644
--- a/src/query/union.rs
+++ b/src/query/union/buffered_union.rs
@@ -26,7 +26,7 @@ where P: FnMut(&mut T) -> bool {
 }
 
 /// Creates a `DocSet` that iterates through the union of two or more `DocSet`s.
-pub struct Union<TScorer, TScoreCombiner> {
+pub struct BufferedUnionScorer<TScorer, TScoreCombiner> {
     docsets: Vec<TScorer>,
     bitsets: Box<[TinySet; HORIZON_NUM_TINYBITSETS]>,
     scores: Box<[TScoreCombiner; HORIZON as usize]>,
@@ -61,16 +61,16 @@ fn refill<TScorer: Scorer, TScoreCombiner: ScoreCombiner>(
     });
 }
 
-impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Union<TScorer, TScoreCombiner> {
+impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> BufferedUnionScorer<TScorer, TScoreCombiner> {
     pub(crate) fn build(
         docsets: Vec<TScorer>,
         score_combiner_fn: impl FnOnce() -> TScoreCombiner,
-    ) -> Union<TScorer, TScoreCombiner> {
+    ) -> BufferedUnionScorer<TScorer, TScoreCombiner> {
         let non_empty_docsets: Vec<TScorer> = docsets
             .into_iter()
            .filter(|docset| docset.doc() != TERMINATED)
            .collect();
-        let mut union = Union {
+        let mut union = BufferedUnionScorer {
             docsets: non_empty_docsets,
             bitsets: Box::new([TinySet::empty(); HORIZON_NUM_TINYBITSETS]),
             scores: Box::new([score_combiner_fn(); HORIZON as usize]),
@@ -121,7 +121,7 @@ impl<TScorer: Scorer, TScoreCombiner: ScoreCombiner> Union<TScorer, TScoreCombiner>
-impl<TScorer, TScoreCombiner> DocSet for Union<TScorer, TScoreCombiner>
+impl<TScorer, TScoreCombiner> DocSet for BufferedUnionScorer<TScorer, TScoreCombiner>
 where
     TScorer: Scorer,
     TScoreCombiner: ScoreCombiner,
@@ -230,7 +230,7 @@ where
     }
 }
 
-impl<TScoreCombiner, TScorer> Scorer for Union<TScorer, TScoreCombiner>
+impl<TScoreCombiner, TScorer> Scorer for BufferedUnionScorer<TScorer, TScoreCombiner>
 where
     TScoreCombiner: ScoreCombiner,
     TScorer: Scorer,
@@ -239,205 +239,3 @@ where
         self.score
     }
 }
-
-#[cfg(test)]
-mod tests {
-
-    use std::collections::BTreeSet;
-
-    use super::{Union, HORIZON};
-    use crate::docset::{DocSet, TERMINATED};
-    use crate::postings::tests::test_skip_against_unoptimized;
-    use crate::query::score_combiner::DoNothingCombiner;
-    use crate::query::{ConstScorer, VecDocSet};
-    use crate::{tests, DocId};
-
-    fn aux_test_union(vals: Vec<Vec<u32>>) {
-        let mut val_set: BTreeSet<u32> = BTreeSet::new();
-        for vs in &vals {
-            for &v in vs {
-                val_set.insert(v);
-            }
-        }
-        let union_vals: Vec<u32> = val_set.into_iter().collect();
-        let mut union_expected = VecDocSet::from(union_vals);
-        let make_union = || {
-            Union::build(
-                vals.iter()
-                    .cloned()
-                    .map(VecDocSet::from)
-                    .map(|docset| ConstScorer::new(docset, 1.0))
-                    .collect::<Vec<ConstScorer<VecDocSet>>>(),
-                DoNothingCombiner::default,
-            )
-        };
-        let mut union: Union<_, DoNothingCombiner> = make_union();
-        let mut count = 0;
-        while union.doc() != TERMINATED {
-            assert_eq!(union_expected.doc(), union.doc());
-            assert_eq!(union_expected.advance(), union.advance());
-            count += 1;
-        }
-        assert_eq!(union_expected.advance(), TERMINATED);
-        assert_eq!(count, make_union().count_including_deleted());
-    }
-
-    #[test]
-    fn test_union() {
-        aux_test_union(vec![
-            vec![1, 3333, 100000000u32],
-            vec![1, 2, 100000000u32],
-            vec![1, 2, 100000000u32],
-            vec![],
-        ]);
-        aux_test_union(vec![
-            vec![1, 3333, 100000000u32],
-            vec![1, 2, 100000000u32],
-            vec![1, 2, 100000000u32],
-            vec![],
-        ]);
-        aux_test_union(vec![
-            tests::sample_with_seed(100_000, 0.01, 1),
-            tests::sample_with_seed(100_000, 0.05, 2),
-            tests::sample_with_seed(100_000, 0.001, 3),
-        ]);
-    }
-
-    fn test_aux_union_skip(docs_list: &[Vec<DocId>], skip_targets: Vec<DocId>) {
-        let mut btree_set = BTreeSet::new();
-        for docs in docs_list {
-            btree_set.extend(docs.iter().cloned());
-        }
-        let docset_factory = || {
-            let res: Box<dyn DocSet> = Box::new(Union::build(
-                docs_list
-                    .iter()
-                    .cloned()
-                    .map(VecDocSet::from)
-                    .map(|docset| ConstScorer::new(docset, 1.0))
-                    .collect::<Vec<_>>(),
-                DoNothingCombiner::default,
-            ));
-            res
-        };
-        let mut docset = docset_factory();
-        for el in btree_set {
-            assert_eq!(el, docset.doc());
-            docset.advance();
-        }
-        assert_eq!(docset.doc(), TERMINATED);
-        test_skip_against_unoptimized(docset_factory, skip_targets);
-    }
-
-    #[test]
-    fn test_union_skip_corner_case() {
-        test_aux_union_skip(&[vec![165132, 167382], vec![25029, 25091]], vec![25029]);
-    }
-
-    #[test]
-    fn test_union_skip_corner_case2() {
-        test_aux_union_skip(
-            &[vec![1u32, 1u32 + HORIZON], vec![2u32, 1000u32, 10_000u32]],
-            vec![0u32, 1u32, 2u32, 3u32, 1u32 + HORIZON, 2u32 + HORIZON],
-        );
-    }
-
-    #[test]
-    fn test_union_skip_corner_case3() {
-        let mut docset = Union::build(
-            vec![
-                ConstScorer::from(VecDocSet::from(vec![0u32, 5u32])),
-                ConstScorer::from(VecDocSet::from(vec![1u32, 4u32])),
-            ],
-            DoNothingCombiner::default,
-        );
-        assert_eq!(docset.doc(), 0u32);
-        assert_eq!(docset.seek(0u32), 0u32);
-        assert_eq!(docset.seek(0u32), 0u32);
-        assert_eq!(docset.doc(), 0u32)
-    }
-
-    #[test]
-    fn test_union_skip_random() {
-        test_aux_union_skip(
-            &[
-                vec![1, 2, 3, 7],
-                vec![1, 3, 9, 10000],
-                vec![1, 3, 8, 9, 100],
-            ],
-            vec![1, 2, 3, 5, 6, 7, 8, 100],
-        );
-        test_aux_union_skip(
-            &[
-                tests::sample_with_seed(100_000, 0.001, 1),
-                tests::sample_with_seed(100_000, 0.002, 2),
-                tests::sample_with_seed(100_000, 0.005, 3),
-            ],
-            tests::sample_with_seed(100_000, 0.01, 4),
-        );
-    }
-
-    #[test]
-    fn test_union_skip_specific() {
-        test_aux_union_skip(
-            &[
-                vec![1, 2, 3, 7],
-                vec![1, 3, 9, 10000],
-                vec![1, 3, 8, 9, 100],
-            ],
-            vec![1, 2, 3, 7, 8, 9, 99, 100, 101, 500, 20000],
-        );
-    }
-}
-
-#[cfg(all(test, feature = "unstable"))]
-mod bench {
-
-    use test::Bencher;
-
-    use crate::query::score_combiner::DoNothingCombiner;
-    use crate::query::{ConstScorer, Union, VecDocSet};
-    use crate::{tests, DocId, DocSet, TERMINATED};
-
-    #[bench]
-    fn bench_union_3_high(bench: &mut Bencher) {
-        let union_docset: Vec<Vec<DocId>> = vec![
-            tests::sample_with_seed(100_000, 0.1, 0),
-            tests::sample_with_seed(100_000, 0.2, 1),
-        ];
-        bench.iter(|| {
-            let mut v = Union::build(
-                union_docset
-                    .iter()
-                    .map(|doc_ids| VecDocSet::from(doc_ids.clone()))
-                    .map(|docset| ConstScorer::new(docset, 1.0))
-                    .collect::<Vec<_>>(),
-                DoNothingCombiner::default,
-            );
-            while v.doc() != TERMINATED {
-                v.advance();
-            }
-        });
-    }
-    #[bench]
-    fn bench_union_3_low(bench: &mut Bencher) {
-        let union_docset: Vec<Vec<DocId>> = vec![
-            tests::sample_with_seed(100_000, 0.01, 0),
-            tests::sample_with_seed(100_000, 0.05, 1),
-            tests::sample_with_seed(100_000, 0.001, 2),
-        ];
-        bench.iter(|| {
-            let mut v = Union::build(
-                union_docset
-                    .iter()
-                    .map(|doc_ids| VecDocSet::from(doc_ids.clone()))
-                    .map(|docset| ConstScorer::new(docset, 1.0))
-                    .collect::<Vec<_>>(),
-                DoNothingCombiner::default,
-            );
-            while v.doc() != TERMINATED {
-                v.advance();
-            }
-        });
-    }
-}
diff --git a/src/query/union/mod.rs b/src/query/union/mod.rs
new file mode 100644
index 000000000..84153e272
--- /dev/null
+++ b/src/query/union/mod.rs
@@ -0,0 +1,303 @@
+mod bitset_union;
+mod buffered_union;
+mod simple_union;
+
+pub use bitset_union::BitSetPostingUnion;
+pub use buffered_union::BufferedUnionScorer;
+pub use simple_union::SimpleUnion;
+
+#[cfg(test)]
+mod tests {
+
+    use std::collections::BTreeSet;
+
+    use common::BitSet;
+
+    use super::{SimpleUnion, *};
+    use crate::docset::{DocSet, TERMINATED};
+    use crate::postings::tests::test_skip_against_unoptimized;
+    use crate::query::score_combiner::DoNothingCombiner;
+    use crate::query::union::bitset_union::BitSetPostingUnion;
+    use crate::query::{BitSetDocSet, ConstScorer, VecDocSet};
+    use crate::{tests, DocId};
+
+    fn vec_doc_set_from_docs_list(
+        docs_list: &[Vec<DocId>],
+    ) -> impl Iterator<Item = VecDocSet> + '_ {
+        docs_list.iter().cloned().map(VecDocSet::from)
+    }
+
+    fn union_from_docs_list(docs_list: &[Vec<DocId>]) -> Box<dyn DocSet> {
+        Box::new(BufferedUnionScorer::build(
+            vec_doc_set_from_docs_list(docs_list)
+                .map(|docset| ConstScorer::new(docset, 1.0))
+                .collect::<Vec<ConstScorer<VecDocSet>>>(),
+            DoNothingCombiner::default,
+        ))
+    }
+
+    fn posting_list_union_from_docs_list(docs_list: &[Vec<DocId>]) -> Box<dyn DocSet> {
+        Box::new(BitSetPostingUnion::build(
+            vec_doc_set_from_docs_list(docs_list).collect::<Vec<_>>(),
+            bitset_from_docs_list(docs_list),
+        ))
+    }
+
+    fn simple_union_from_docs_list(docs_list: &[Vec<DocId>]) -> Box<dyn DocSet> {
+        Box::new(SimpleUnion::build(
+            vec_doc_set_from_docs_list(docs_list).collect::<Vec<_>>(),
+        ))
+    }
+
+    fn bitset_from_docs_list(docs_list: &[Vec<DocId>]) -> BitSetDocSet {
+        let max_doc = docs_list
+            .iter()
+            .flat_map(|docs| docs.iter().copied())
+            .max()
+            .unwrap_or(0);
+        let mut doc_bitset = BitSet::with_max_value(max_doc + 1);
+        for docs in docs_list {
+            for &doc in docs {
+                doc_bitset.insert(doc);
+            }
+        }
+        BitSetDocSet::from(doc_bitset)
+    }
+
+    fn aux_test_union(docs_list: &[Vec<DocId>]) {
+        for constructor in [
+            posting_list_union_from_docs_list,
+            simple_union_from_docs_list,
+            union_from_docs_list,
+        ] {
+            aux_test_union_with_constructor(constructor, docs_list);
+        }
+    }
+
+    fn aux_test_union_with_constructor<F>(constructor: F, docs_list: &[Vec<DocId>])
+    where F: Fn(&[Vec<DocId>]) -> Box<dyn DocSet> {
+        let mut val_set: BTreeSet<u32> = BTreeSet::new();
+        for vs in docs_list {
+            for &v in vs {
+                val_set.insert(v);
+            }
+        }
+        let union_vals: Vec<u32> = val_set.into_iter().collect();
+        let mut union_expected = VecDocSet::from(union_vals);
+        let make_union = || constructor(docs_list);
+        let mut union = make_union();
+        let mut count = 0;
+        while union.doc() != TERMINATED {
+            assert_eq!(union_expected.doc(), union.doc());
+            assert_eq!(union_expected.advance(), union.advance());
+            count += 1;
+        }
+        assert_eq!(union_expected.advance(), TERMINATED);
+        assert_eq!(count, make_union().count_including_deleted());
+    }
+
+    use proptest::prelude::*;
+
+    proptest! {
+        #[test]
+        fn test_union_is_same(
+            vecs in prop::collection::vec(
+                prop::collection::vec(0u32..100, 1..10)
+                    .prop_map(|mut inner| {
+                        inner.sort_unstable();
+                        inner.dedup();
+                        inner
+                    }),
+                1..10
+            ),
+            seek_docids in prop::collection::vec(0u32..100, 0..10).prop_map(|mut inner| {
+                inner.sort_unstable();
+                inner
+            })
+        ) {
+            test_docid_with_skip(&vecs, &seek_docids);
+        }
+    }
+
+    fn test_docid_with_skip(vecs: &[Vec<u32>], skip_targets: &[DocId]) {
+        let mut union1 = posting_list_union_from_docs_list(vecs);
+        let mut union2 = simple_union_from_docs_list(vecs);
+        let mut union3 = union_from_docs_list(vecs);
+
+        // Check initial sequential advance
+        while union1.doc() != TERMINATED {
+            assert_eq!(union1.doc(), union2.doc());
+            assert_eq!(union1.doc(), union3.doc());
+            assert_eq!(union1.advance(), union2.advance());
+            assert_eq!(union1.doc(), union3.advance());
+        }
+
+        // Reset and test seek functionality
+        let mut union1 = posting_list_union_from_docs_list(vecs);
+        let mut union2 = simple_union_from_docs_list(vecs);
+        let mut union3 = union_from_docs_list(vecs);
+
+        for &seek_docid in skip_targets {
+            union1.seek(seek_docid);
+            union2.seek(seek_docid);
+            union3.seek(seek_docid);
+
+            // Verify that all unions are at the same document after seeking
+            assert_eq!(union3.doc(), union1.doc());
+            assert_eq!(union3.doc(), union2.doc());
+        }
+    }
+
+    #[test]
+    fn test_union() {
+        aux_test_union(&[
+            vec![1, 3333, 100000000u32],
+            vec![1, 2, 100000000u32],
+            vec![1, 2, 100000000u32],
+            vec![],
+        ]);
+        aux_test_union(&[
+            vec![1, 3333, 100000000u32],
+            vec![1, 2, 100000000u32],
+            vec![1, 2, 100000000u32],
+            vec![],
+        ]);
+        aux_test_union(&[
+            tests::sample_with_seed(100_000, 0.01, 1),
+            tests::sample_with_seed(100_000, 0.05, 2),
+            tests::sample_with_seed(100_000, 0.001, 3),
+        ]);
+    }
+
+    fn test_aux_union_skip(docs_list: &[Vec<DocId>], skip_targets: Vec<DocId>) {
+        for constructor in [
+            posting_list_union_from_docs_list,
+            simple_union_from_docs_list,
+            union_from_docs_list,
+        ] {
+            test_aux_union_skip_with_constructor(constructor, docs_list, skip_targets.clone());
+        }
+    }
+
+    fn test_aux_union_skip_with_constructor<F>(
+        constructor: F,
+        docs_list: &[Vec<DocId>],
+        skip_targets: Vec<DocId>,
+    ) where
+        F: Fn(&[Vec<DocId>]) -> Box<dyn DocSet>,
+    {
+        let mut btree_set = BTreeSet::new();
+        for docs in docs_list {
+            btree_set.extend(docs.iter().cloned());
+        }
+        let docset_factory = || {
+            let res: Box<dyn DocSet> = constructor(docs_list);
+            res
+        };
+        let mut docset = constructor(docs_list);
+        for el in btree_set {
+            assert_eq!(el, docset.doc());
+            docset.advance();
+        }
+        assert_eq!(docset.doc(), TERMINATED);
+        test_skip_against_unoptimized(docset_factory, skip_targets);
+    }
+
+    #[test]
+    fn test_union_skip_corner_case() {
+        test_aux_union_skip(&[vec![165132, 167382], vec![25029, 25091]], vec![25029]);
+    }
+
+    #[test]
+    fn test_union_skip_corner_case2() {
+        test_aux_union_skip(
+            &[vec![1u32, 1u32 + 100], vec![2u32, 1000u32, 10_000u32]],
+            vec![0u32, 1u32, 2u32, 3u32, 1u32 + 100, 2u32 + 100],
+        );
+    }
+
+    #[test]
+    fn test_union_skip_corner_case3() {
+        let mut docset = posting_list_union_from_docs_list(&[vec![0u32, 5u32], vec![1u32, 4u32]]);
+        assert_eq!(docset.doc(), 0u32);
+        assert_eq!(docset.seek(0u32), 0u32);
+        assert_eq!(docset.seek(0u32), 0u32);
+        assert_eq!(docset.doc(), 0u32)
+    }
+
+    #[test]
+    fn test_union_skip_random() {
+        test_aux_union_skip(
+            &[
+                vec![1, 2, 3, 7],
+                vec![1, 3, 9, 10000],
+                vec![1, 3, 8, 9, 100],
+            ],
+            vec![1, 2, 3, 5, 6, 7, 8, 100],
+        );
+        test_aux_union_skip(
+            &[
+                tests::sample_with_seed(100_000, 0.001, 1),
+                tests::sample_with_seed(100_000, 0.002, 2),
+                tests::sample_with_seed(100_000, 0.005, 3),
+            ],
+            tests::sample_with_seed(100_000, 0.01, 4),
+        );
+    }
+
+    #[test]
+    fn test_union_skip_specific() {
+        test_aux_union_skip(
+            &[
+                vec![1, 2, 3, 7],
+                vec![1, 3, 9, 10000],
+                vec![1, 3, 8, 9, 100],
+            ],
+            vec![1, 2, 3, 7, 8, 9, 99, 100, 101, 500, 20000],
+        );
+    }
+}
+
+#[cfg(all(test, feature = "unstable"))]
+mod bench {
+
+    use test::Bencher;
+
+    use crate::query::score_combiner::DoNothingCombiner;
+    use crate::query::{BufferedUnionScorer, ConstScorer, VecDocSet};
+    use crate::{tests, DocId, DocSet, TERMINATED};
+
+    #[bench]
+    fn bench_union_3_high(bench: &mut Bencher) {
+        let union_docset: Vec<Vec<DocId>> = vec![
+            tests::sample_with_seed(100_000, 0.1, 0),
+            tests::sample_with_seed(100_000, 0.2, 1),
+        ];
+        bench.iter(|| {
+            let mut v = BufferedUnionScorer::build(
+                union_docset
+                    .iter()
+                    .map(|doc_ids| VecDocSet::from(doc_ids.clone()))
+                    .map(|docset| ConstScorer::new(docset, 1.0))
+                    .collect::<Vec<_>>(),
+                DoNothingCombiner::default,
+            );
+            while v.doc() != TERMINATED {
+                v.advance();
+            }
+        });
+    }
+
+    #[bench]
+    fn bench_union_3_low(bench: &mut Bencher) {
+        let union_docset: Vec<Vec<DocId>> = vec![
+            tests::sample_with_seed(100_000, 0.01, 0),
+            tests::sample_with_seed(100_000, 0.05, 1),
+            tests::sample_with_seed(100_000, 0.001, 2),
+        ];
+        bench.iter(|| {
+            let mut v = BufferedUnionScorer::build(
+                union_docset
+                    .iter()
+                    .map(|doc_ids| VecDocSet::from(doc_ids.clone()))
+                    .map(|docset| ConstScorer::new(docset, 1.0))
+                    .collect::<Vec<_>>(),
+                DoNothingCombiner::default,
+            );
+            while v.doc() != TERMINATED {
+                v.advance();
+            }
+        });
+    }
+}
diff --git a/src/query/union/simple_union.rs b/src/query/union/simple_union.rs
new file mode 100644
index 000000000..041d4c90e
--- /dev/null
+++ b/src/query/union/simple_union.rs
@@ -0,0 +1,112 @@
+use crate::docset::{DocSet, TERMINATED};
+use crate::postings::Postings;
+use crate::DocId;
+
+/// A `SimpleUnion` is a `DocSet` that is the union of multiple `DocSet`s.
+/// Unlike `BufferedUnionScorer`, it doesn't do any horizon precomputation.
+///
+/// For that reason, `SimpleUnion` is a good choice for queries that skip a lot.
+pub struct SimpleUnion<TDocSet> {
+    docsets: Vec<TDocSet>,
+    doc: DocId,
+}
+
+impl<TDocSet: DocSet> SimpleUnion<TDocSet> {
+    pub(crate) fn build(mut docsets: Vec<TDocSet>) -> SimpleUnion<TDocSet> {
+        docsets.retain(|docset| docset.doc() != TERMINATED);
+        let mut docset = SimpleUnion { docsets, doc: 0 };
+
+        docset.initialize_first_doc_id();
+
+        docset
+    }
+
+    fn initialize_first_doc_id(&mut self) {
+        let mut next_doc = TERMINATED;
+
+        for docset in &self.docsets {
+            next_doc = next_doc.min(docset.doc());
+        }
+        self.doc = next_doc;
+    }
+
+    fn advance_to_next(&mut self) -> DocId {
+        let mut next_doc = TERMINATED;
+
+        for docset in &mut self.docsets {
+            if docset.doc() <= self.doc {
+                docset.advance();
+            }
+            next_doc = next_doc.min(docset.doc());
+        }
+        self.doc = next_doc;
+        self.doc
+    }
+}
+
+impl<TDocSet: Postings> Postings for SimpleUnion<TDocSet> {
+    fn term_freq(&self) -> u32 {
+        let mut term_freq = 0;
+        for docset in &self.docsets {
+            let doc = docset.doc();
+            if doc == self.doc {
+                term_freq += docset.term_freq();
+            }
+        }
+        term_freq
+    }
+
+    fn append_positions_with_offset(&mut self, offset: u32, output: &mut Vec<u32>) {
+        for docset in &mut self.docsets {
+            let doc = docset.doc();
+            if doc == self.doc {
+                docset.append_positions_with_offset(offset, output);
+            }
+        }
+        output.sort_unstable();
+        output.dedup();
+    }
+}
+
+impl<TDocSet: DocSet> DocSet for SimpleUnion<TDocSet> {
+    fn advance(&mut self) -> DocId {
+        self.advance_to_next();
+        self.doc
+    }
+
+    fn seek(&mut self, target: DocId) -> DocId {
+        self.doc = TERMINATED;
+        for docset in &mut self.docsets {
+            if docset.doc() < target {
+                docset.seek(target);
+            }
+            if docset.doc() < self.doc {
+                self.doc = docset.doc();
+            }
+        }
+        self.doc
+    }
+
+    fn doc(&self) -> DocId {
+        self.doc
+    }
+
+    fn size_hint(&self) -> u32 {
+        self.docsets
+            .iter()
+            .map(|docset| docset.size_hint())
+            .max()
+            .unwrap_or(0u32)
+    }
+
+    fn count_including_deleted(&mut self) -> u32 {
+        if self.doc == TERMINATED {
+            return 0u32;
+        }
+        let mut count = 1u32;
+        while self.advance_to_next() != TERMINATED {
+            count += 1;
+        }
+        count
+    }
+}
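// ---------------------------------------------------------------------------
// [Editor's note — not part of the diff] A minimal sketch of `SimpleUnion` as
// a plain `DocSet` union (no positions involved); `VecDocSet` is the test
// docset used elsewhere in this PR, and `build` is crate-private.
fn simple_union_of(doc_lists: Vec<Vec<DocId>>) -> SimpleUnion<VecDocSet> {
    SimpleUnion::build(doc_lists.into_iter().map(VecDocSet::from).collect())
}
// Compared to `BufferedUnionScorer`, there is no horizon buffering: `advance`
// just advances every docset sitting at or below the current doc and takes
// the minimum, which keeps `seek`-heavy access patterns cheap.
// ---------------------------------------------------------------------------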