From 6ea6f4bfcda62d166277821a7b27de0e24a1da2e Mon Sep 17 00:00:00 2001 From: Rob Young Date: Wed, 20 May 2020 14:25:24 +0100 Subject: [PATCH] Add offset to TopDocsCollector (#826) * Add offset to TopDocsCollector Add an offset to TopDocsCollector and TopDocs to make it clearer how to handle pagination. Closes #822 * Address review comments - Make Debug formatting of TopDocs clearer. - Add unit tests for limit and offset on TopCollector. - Change API for using offset to a fluent interface. - Add some context to the docstring to clarify what limit and offset are equivalent to in other projects. * Changes required by rebase on e25284 - Pass Collector into TweakedScoreTopCollector and CustomScoreTopCollector. - Add std:: qualifier to f32, i32 etc. Not sure why this was not failing already. - Add unit tests for TopDocs with offset including for tweaked and custom score collectors. In order to convert a TopCollector to a TopCollector I had to add a `into_tscore` method to `TopCollector`. This is a hack but I don't know how to avoid it. --- src/collector/custom_score_top_collector.rs | 4 +- src/collector/top_collector.rs | 78 +++++++++- src/collector/top_score_collector.rs | 156 ++++++++++++++++++-- src/collector/tweak_score_top_collector.rs | 4 +- src/docset.rs | 2 +- 5 files changed, 224 insertions(+), 20 deletions(-) diff --git a/src/collector/custom_score_top_collector.rs b/src/collector/custom_score_top_collector.rs index 18eac1cce..0400eb80e 100644 --- a/src/collector/custom_score_top_collector.rs +++ b/src/collector/custom_score_top_collector.rs @@ -13,11 +13,11 @@ where { pub fn new( custom_scorer: TCustomScorer, - limit: usize, + collector: TopCollector, ) -> CustomScoreTopCollector { CustomScoreTopCollector { custom_scorer, - collector: TopCollector::with_limit(limit), + collector, } } } diff --git a/src/collector/top_collector.rs b/src/collector/top_collector.rs index e930d9024..8e603d66b 100644 --- a/src/collector/top_collector.rs +++ b/src/collector/top_collector.rs @@ -57,6 +57,7 @@ impl Eq for ComparableDoc {} pub(crate) struct TopCollector { pub limit: usize, + pub offset: usize, _marker: PhantomData, } @@ -72,14 +73,20 @@ where if limit < 1 { panic!("Limit must be strictly greater than 0."); } - TopCollector { + Self { limit, + offset: 0, _marker: PhantomData, } } - pub fn limit(&self) -> usize { - self.limit + /// Skip the first "offset" documents when collecting. + /// + /// This is equivalent to `OFFSET` in MySQL or PostgreSQL and `start` in + /// Lucene's TopDocsCollector. + pub fn and_offset(mut self, offset: usize) -> TopCollector { + self.offset = offset; + self } pub fn merge_fruits( @@ -92,7 +99,7 @@ where let mut top_collector = BinaryHeap::new(); for child_fruit in children { for (feature, doc) in child_fruit { - if top_collector.len() < self.limit { + if top_collector.len() < (self.limit + self.offset) { top_collector.push(ComparableDoc { feature, doc }); } else if let Some(mut head) = top_collector.peek_mut() { if head.feature < feature { @@ -104,6 +111,7 @@ where Ok(top_collector .into_sorted_vec() .into_iter() + .skip(self.offset) .map(|cdoc| (cdoc.feature, cdoc.doc)) .collect()) } @@ -113,7 +121,23 @@ where segment_id: SegmentLocalId, _: &SegmentReader, ) -> crate::Result> { - Ok(TopSegmentCollector::new(segment_id, self.limit)) + Ok(TopSegmentCollector::new( + segment_id, + self.limit + self.offset, + )) + } + + /// Create a new TopCollector with the same limit and offset. + /// + /// Ideally we would use Into but the blanket implementation seems to cause the Scorer traits + /// to fail. + #[doc(hidden)] + pub(crate) fn into_tscore(self) -> TopCollector { + TopCollector { + limit: self.limit, + offset: self.offset, + _marker: PhantomData, + } } } @@ -187,7 +211,7 @@ impl TopSegmentCollector { #[cfg(test)] mod tests { - use super::TopSegmentCollector; + use super::{TopCollector, TopSegmentCollector}; use crate::DocAddress; #[test] @@ -248,6 +272,48 @@ mod tests { top_collector_limit_3.harvest()[..2].to_vec(), ); } + + #[test] + fn test_top_collector_with_limit_and_offset() { + let collector = TopCollector::with_limit(2).and_offset(1); + + let results = collector + .merge_fruits(vec![vec![ + (0.9, DocAddress(0, 1)), + (0.8, DocAddress(0, 2)), + (0.7, DocAddress(0, 3)), + (0.6, DocAddress(0, 4)), + (0.5, DocAddress(0, 5)), + ]]) + .unwrap(); + + assert_eq!( + results, + vec![(0.8, DocAddress(0, 2)), (0.7, DocAddress(0, 3)),] + ); + } + + #[test] + fn test_top_collector_with_limit_larger_than_set_and_offset() { + let collector = TopCollector::with_limit(2).and_offset(1); + + let results = collector + .merge_fruits(vec![vec![(0.9, DocAddress(0, 1)), (0.8, DocAddress(0, 2))]]) + .unwrap(); + + assert_eq!(results, vec![(0.8, DocAddress(0, 2)),]); + } + + #[test] + fn test_top_collector_with_limit_and_offset_larger_than_set() { + let collector = TopCollector::with_limit(2).and_offset(20); + + let results = collector + .merge_fruits(vec![vec![(0.9, DocAddress(0, 1)), (0.8, DocAddress(0, 2))]]) + .unwrap(); + + assert_eq!(results, vec![]); + } } #[cfg(all(test, feature = "unstable"))] diff --git a/src/collector/top_score_collector.rs b/src/collector/top_score_collector.rs index 260103e98..f5145f407 100644 --- a/src/collector/top_score_collector.rs +++ b/src/collector/top_score_collector.rs @@ -60,7 +60,11 @@ pub struct TopDocs(TopCollector); impl fmt::Debug for TopDocs { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "TopDocs({})", self.0.limit()) + write!( + f, + "TopDocs(limit={}, offset={})", + self.0.limit, self.0.offset + ) } } @@ -104,6 +108,45 @@ impl TopDocs { TopDocs(TopCollector::with_limit(limit)) } + /// Skip the first "offset" documents when collecting. + /// + /// This is equivalent to `OFFSET` in MySQL or PostgreSQL and `start` in + /// Lucene's TopDocsCollector. + /// + /// ```rust + /// use tantivy::collector::TopDocs; + /// use tantivy::query::QueryParser; + /// use tantivy::schema::{Schema, TEXT}; + /// use tantivy::{doc, DocAddress, Index}; + /// + /// let mut schema_builder = Schema::builder(); + /// let title = schema_builder.add_text_field("title", TEXT); + /// let schema = schema_builder.build(); + /// let index = Index::create_in_ram(schema); + /// + /// let mut index_writer = index.writer_with_num_threads(1, 3_000_000).unwrap(); + /// index_writer.add_document(doc!(title => "The Name of the Wind")); + /// index_writer.add_document(doc!(title => "The Diary of Muadib")); + /// index_writer.add_document(doc!(title => "A Dairy Cow")); + /// index_writer.add_document(doc!(title => "The Diary of a Young Girl")); + /// index_writer.add_document(doc!(title => "The Diary of Lena Mukhina")); + /// assert!(index_writer.commit().is_ok()); + /// + /// let reader = index.reader().unwrap(); + /// let searcher = reader.searcher(); + /// + /// let query_parser = QueryParser::for_index(&index, vec![title]); + /// let query = query_parser.parse_query("diary").unwrap(); + /// let top_docs = searcher.search(&query, &TopDocs::with_limit(2).and_offset(1)).unwrap(); + /// + /// assert_eq!(top_docs.len(), 2); + /// assert_eq!(&top_docs[0], &(0.5204813, DocAddress(0, 4))); + /// assert_eq!(&top_docs[1], &(0.4793185, DocAddress(0, 3))); + /// ``` + pub fn and_offset(self, offset: usize) -> TopDocs { + TopDocs(self.0.and_offset(offset)) + } + /// Set top-K to rank documents by a given fast field. /// /// ```rust @@ -284,7 +327,7 @@ impl TopDocs { TScoreSegmentTweaker: ScoreSegmentTweaker + 'static, TScoreTweaker: ScoreTweaker, { - TweakedScoreTopCollector::new(score_tweaker, self.0.limit()) + TweakedScoreTopCollector::new(score_tweaker, self.0.into_tscore()) } /// Ranks the documents using a custom score. @@ -398,7 +441,7 @@ impl TopDocs { TCustomSegmentScorer: CustomSegmentScorer + 'static, TCustomScorer: CustomScorer, { - CustomScoreTopCollector::new(custom_score, self.0.limit()) + CustomScoreTopCollector::new(custom_score, self.0.into_tscore()) } } @@ -434,10 +477,10 @@ impl Collector for TopDocs { segment_reader: &SegmentReader, ) -> crate::Result<::Fruit> { let mut heap: BinaryHeap> = - BinaryHeap::with_capacity(self.0.limit); + BinaryHeap::with_capacity(self.0.limit + self.0.offset); // first we fill the heap with the first `limit` elements. let mut doc = scorer.doc(); - while doc != TERMINATED && heap.len() < self.0.limit { + while doc != TERMINATED && heap.len() < (self.0.limit + self.0.offset) { if !segment_reader.is_deleted(doc) { let score = scorer.score(); heap.push(ComparableDoc { @@ -448,7 +491,7 @@ impl Collector for TopDocs { doc = scorer.advance(); } - let threshold = heap.peek().map(|el| el.feature).unwrap_or(f32::MIN); + let threshold = heap.peek().map(|el| el.feature).unwrap_or(std::f32::MIN); if let Some(delete_bitset) = segment_reader.delete_bitset() { scorer.for_each_pruning(threshold, &mut |doc, score| { @@ -458,7 +501,7 @@ impl Collector for TopDocs { doc, }; } - heap.peek().map(|el| el.feature).unwrap_or(f32::MIN) + heap.peek().map(|el| el.feature).unwrap_or(std::f32::MIN) }); } else { scorer.for_each_pruning(threshold, &mut |doc, score| { @@ -466,7 +509,7 @@ impl Collector for TopDocs { feature: score, doc, }; - heap.peek().map(|el| el.feature).unwrap_or(f32::MIN) + heap.peek().map(|el| el.feature).unwrap_or(std::f32::MIN) }); } @@ -501,7 +544,7 @@ mod tests { use crate::collector::Collector; use crate::query::{AllQuery, Query, QueryParser}; use crate::schema::{Field, Schema, FAST, STORED, TEXT}; - use crate::DocAddress; + use crate::{DocAddress, DocId, SegmentReader}; use crate::Index; use crate::IndexWriter; use crate::Score; @@ -544,6 +587,26 @@ mod tests { ); } + #[test] + fn test_top_collector_not_at_capacity_with_offset() { + let index = make_index(); + let field = index.schema().get_field("text").unwrap(); + let query_parser = QueryParser::for_index(&index, vec![field]); + let text_query = query_parser.parse_query("droopy tax").unwrap(); + let score_docs: Vec<(Score, DocAddress)> = index + .reader() + .unwrap() + .searcher() + .search(&text_query, &TopDocs::with_limit(4).and_offset(2)) + .unwrap(); + assert_eq!( + score_docs, + vec![ + (0.48527452, DocAddress(0, 0)) + ] + ); + } + #[test] fn test_top_collector_at_capacity() { let index = make_index(); @@ -565,6 +628,27 @@ mod tests { ); } + #[test] + fn test_top_collector_at_capacity_with_offset() { + let index = make_index(); + let field = index.schema().get_field("text").unwrap(); + let query_parser = QueryParser::for_index(&index, vec![field]); + let text_query = query_parser.parse_query("droopy tax").unwrap(); + let score_docs: Vec<(Score, DocAddress)> = index + .reader() + .unwrap() + .searcher() + .search(&text_query, &TopDocs::with_limit(2).and_offset(1)) + .unwrap(); + assert_eq!( + score_docs, + vec![ + (0.5376842, DocAddress(0u32, 2)), + (0.48527452, DocAddress(0, 0)) + ] + ); + } + #[test] fn test_top_collector_stable_sorting() { let index = make_index(); @@ -678,6 +762,60 @@ mod tests { } } + #[test] + fn test_tweak_score_top_collector_with_offset() { + let index = make_index(); + let field = index.schema().get_field("text").unwrap(); + let query_parser = QueryParser::for_index(&index, vec![field]); + let text_query = query_parser.parse_query("droopy tax").unwrap(); + let collector = TopDocs::with_limit(2).and_offset(1).tweak_score(move |_segment_reader: &SegmentReader| { + move |doc: DocId, _original_score: Score| { + doc + } + }); + let score_docs: Vec<(u32, DocAddress)> = index + .reader() + .unwrap() + .searcher() + .search(&text_query, &collector) + .unwrap(); + + assert_eq!( + score_docs, + vec![ + (1, DocAddress(0, 1)), + (0, DocAddress(0, 0)), + ] + ); + } + + #[test] + fn test_custom_score_top_collector_with_offset() { + let index = make_index(); + let field = index.schema().get_field("text").unwrap(); + let query_parser = QueryParser::for_index(&index, vec![field]); + let text_query = query_parser.parse_query("droopy tax").unwrap(); + let collector = TopDocs::with_limit(2).and_offset(1).custom_score(move |_segment_reader: &SegmentReader| { + move |doc: DocId| { + doc + } + }); + let score_docs: Vec<(u32, DocAddress)> = index + .reader() + .unwrap() + .searcher() + .search(&text_query, &collector) + .unwrap(); + + assert_eq!( + score_docs, + vec![ + (1, DocAddress(0, 1)), + (0, DocAddress(0, 0)), + ] + ); + } + fn index( query: &str, query_field: Field, diff --git a/src/collector/tweak_score_top_collector.rs b/src/collector/tweak_score_top_collector.rs index 6fdced2ec..3ec438996 100644 --- a/src/collector/tweak_score_top_collector.rs +++ b/src/collector/tweak_score_top_collector.rs @@ -14,11 +14,11 @@ where { pub fn new( score_tweaker: TScoreTweaker, - limit: usize, + collector: TopCollector, ) -> TweakedScoreTopCollector { TweakedScoreTopCollector { score_tweaker, - collector: TopCollector::with_limit(limit), + collector, } } } diff --git a/src/docset.rs b/src/docset.rs index cab25ebe1..68d38681d 100644 --- a/src/docset.rs +++ b/src/docset.rs @@ -7,7 +7,7 @@ use std::borrow::BorrowMut; /// /// This is not u32::MAX as one would have expected, due to the lack of SSE2 instructions /// to compare [u32; 4]. -pub const TERMINATED: DocId = i32::MAX as u32; +pub const TERMINATED: DocId = std::i32::MAX as u32; /// Represents an iterable set of sorted doc ids. pub trait DocSet {